Skip to content

Commit

Permalink
Merge branch 'develop'
Browse files Browse the repository at this point in the history
  • Loading branch information
amaiya committed Feb 4, 2020
2 parents ed0668a + 2802d14 commit f25d22c
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 21 deletions.
14 changes: 14 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,20 @@ Most recent releases are shown at the top. Each release shows:
- **Fixed**: Bug fixes that don't change documented behaviour


## 0.9.2 (2020-02-04)

### New:
- N/A

### Changed:
- Removed the exception raised when `distilbert` is selected in `text_classifier` for a non-English language, after
[Hugging Face fixed the reported bug](https://github.com/huggingface/transformers/issues/2462).

### Fixed:
- XLNet models like `xlnet-base-cased` now work after casting input arrays to `int32`
- Modified `TextPredictor.explain` to propagate the correct error message from `eli5` for multilabel text classification.


## 0.9.1 (2020-02-01)

### New:
Expand Down
22 changes: 11 additions & 11 deletions examples/text/ChineseHotelReviews-BERT.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,13 @@
}
],
"source": [
"(x_train, y_train), (x_test, y_test), preproc = text.texts_from_folder('/home/amaiya/data/ChnSentiCorp_htl_ba_6000', \n",
" maxlen=75, \n",
" max_features=30000,\n",
" preprocess_mode='bert',\n",
" train_test_names=['train'],\n",
" val_pct=0.1,\n",
" classes=['pos', 'neg'])"
"trn, val, preproc = text.texts_from_folder('/home/amaiya/data/ChnSentiCorp_htl_ba_6000', \n",
" maxlen=75, \n",
" max_features=30000,\n",
" preprocess_mode='bert',\n",
" train_test_names=['train'],\n",
" val_pct=0.1,\n",
" classes=['pos', 'neg'])"
]
},
{
Expand All @@ -156,10 +156,10 @@
}
],
"source": [
"model = text.text_classifier('bert', (x_train, y_train) , preproc=preproc)\n",
"model = text.text_classifier('bert', trn, preproc=preproc)\n",
"learner = ktrain.get_learner(model, \n",
" train_data=(x_train, y_train), \n",
" val_data=(x_test, y_test), \n",
" train_data=trn, \n",
" val_data=val, \n",
" batch_size=32)"
]
},
Expand Down Expand Up @@ -423,7 +423,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
"version": "3.6.9"
}
},
"nbformat": 4,
Expand Down
6 changes: 3 additions & 3 deletions ktrain/text/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def predict(self, texts, return_proba=False):
elif not isinstance(texts, np.ndarray) and not isinstance(texts, list):
raise ValueError('data must be numpy.ndarray or list (of texts)')
classification, multilabel = U.is_classifier(self.model)
if multilabel: return_proba = True
#if multilabel: return_proba = True
#treat_multilabel = False
#loss = self.model.loss
#if loss != 'categorical_crossentropy' and not return_proba:
Expand All @@ -55,8 +55,8 @@ def predict(self, texts, return_proba=False):
else:
preds = np.squeeze(preds)
if len(preds.shape) == 0: preds = np.expand_dims(preds, -1)
result = preds if return_proba or not self.c else [self.c[np.argmax(pred)] for pred in preds]
if multilabel:
result = preds if return_proba or multilabel or not self.c else [self.c[np.argmax(pred)] for pred in preds]
if multilabel and not return_proba:
result = [list(zip(self.c, r)) for r in result]
if is_str: return result[0]
else: return result
Expand Down
28 changes: 22 additions & 6 deletions ktrain/text/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,8 +793,6 @@ def __init__(self, maxlen, max_features, classes=[],
model_name = 'distilbert-base-uncased'
else:
model_name = 'distilbert-base-multilingual-cased'
raise Exception('currently_unsupported: non-English languages are not currently supported for '+\
'distilbert due to issues with TF2 version of transformers library. ')

super().__init__(model_name,
maxlen, max_features, classes=classes,
Expand Down Expand Up @@ -948,10 +946,28 @@ def to_tfdataset(self, shuffle=True, repeat=True):
"""
convert transformer features to tf.Dataset
"""
tfdataset = tf.data.Dataset.from_tensor_slices((self.x, self.y))
tfdataset = tfdataset.map(lambda x,y: ({'input_ids': x[0],
'attention_mask': x[1],
'token_type_ids': x[2]}, y))
if len(self.y.shape) == 1:
yshape = []
else:
yshape = [None]

def gen():
for idx, data in enumerate(self.x):
yield ({'input_ids': data[0],
'attention_mask': data[1],
'token_type_ids': data[2]},
self.y[idx])

tfdataset= tf.data.Dataset.from_generator(gen,
({'input_ids': tf.int32,
'attention_mask': tf.int32,
'token_type_ids': tf.int32},
tf.int64),
({'input_ids': tf.TensorShape([None]),
'attention_mask': tf.TensorShape([None]),
'token_type_ids': tf.TensorShape([None])},
tf.TensorShape(yshape)))

if shuffle:
tfdataset = tfdataset.shuffle(self.x.shape[0])
tfdataset = tfdataset.batch(self.batch_size)
Expand Down
2 changes: 1 addition & 1 deletion ktrain/version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__all__ = ['__version__']
__version__ = '0.9.1'
__version__ = '0.9.2'

0 comments on commit f25d22c

Please sign in to comment.