Skip to content

Commit

Permalink
Merge branch 'develop'
Browse files Browse the repository at this point in the history
  • Loading branch information
amaiya committed May 15, 2020
2 parents 678a74c + 4d1a4f5 commit 061a1ef
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 4 deletions.
13 changes: 13 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,19 @@ Most recent releases are shown at the top. Each release shows:
- **Fixed**: Bug fixes that don't change documented behaviour


## 0.15.2 (2020-05-15)

### New:
- N/A

### Changed
- Added `n_samples` argument to `TextPredictor.explain` to address slowness of `explain` on Google Colab
- Lock to version 0.21.3 of `scikit-learn` to ensure old-style explanations are generated from `TextPredictor.explain`

### Fixed:
- added missing `import pickle` to ensure saved topic models can be loaded


## 0.15.1 (2020-05-14)

### New:
Expand Down
1 change: 1 addition & 0 deletions ktrain/text/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from . import shallownlp
from .qa import SimpleQA
from . import textutils
import pickle

__all__ = [
'text_classifier', 'text_regression_model',
Expand Down
7 changes: 5 additions & 2 deletions ktrain/text/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,17 @@ def predict_proba(self, texts):
return self.predict(texts, return_proba=True)


def explain(self, doc, truncate_len=512, all_targets=False):
def explain(self, doc, truncate_len=512, all_targets=False, n_samples=2500):
"""
Highlights text to explain prediction
Args:
doc (str): text of documnet
truncate_len(int): truncate document to this many words
all_targets(bool): If True, show visualization for
each target.
n_samples(int): number of samples to generate and train on.
Larger values give better results, but will take more time.
Lower this value if explain is taking too long.
"""
is_array, is_pair = detect_text_format(doc)
if is_pair:
Expand All @@ -114,7 +117,7 @@ def explain(self, doc, truncate_len=512, all_targets=False):
doc = self.preproc.process_chinese([doc])
doc = doc[0]
doc = ' '.join(doc.split()[:truncate_len])
te = TextExplainer(random_state=42)
te = TextExplainer(random_state=42, n_samples=n_samples)
_ = te.fit(doc, self.predict_proba)
return te.show_prediction(target_names=self.preproc.get_classes(), targets=prediction)

Expand Down
2 changes: 1 addition & 1 deletion ktrain/version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__all__ = ['__version__']
__version__ = '0.15.1'
__version__ = '0.15.2'
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
keywords = ['tensorflow', 'keras', 'deep learning', 'machine learning'],
install_requires=[
'tensorflow==2.1.0',
'scikit-learn >= 0.21.3',
'scikit-learn==0.21.3', # affects format of predictor.explain
'matplotlib >= 3.0.0',
'pandas >= 1.0.1',
'fastprogress >= 0.1.21',
Expand Down

0 comments on commit 061a1ef

Please sign in to comment.