Merge branch 'develop'
amaiya committed Dec 5, 2020
2 parents 5c9c6b3 + 68fc86b commit a79103c
Showing 4 changed files with 29 additions and 10 deletions.
13 changes: 13 additions & 0 deletions CHANGELOG.md
@@ -6,6 +6,19 @@ Most recent releases are shown at the top. Each release shows:
- **Changed**: Additional parameters, changes to inputs or outputs, etc
- **Fixed**: Bug fixes that don't change documented behaviour

+## 0.25.2 (2020-12-05)
+
+### New:
+- N/A
+
+### Changed
+- N/A
+
+### Fixed:
+- Added `custom_objects` argument to `load_predictor` to load models with custom loss functions, etc.
+- Fixed bug #286 related to length computation when `use_dynamic_shape=True`
+
+
## 0.25.1 (2020-12-02)

### New:
20 changes: 12 additions & 8 deletions ktrain/core.py
@@ -1438,14 +1438,18 @@ def get_predictor(model, preproc, batch_size=U.DEFAULT_BS):
raise Exception('preproc of type %s not currently supported' % (type(preproc)))


-def load_predictor(fpath, batch_size=U.DEFAULT_BS):
+def load_predictor(fpath, batch_size=U.DEFAULT_BS, custom_objects=None):
"""
Loads a previously saved Predictor instance
Args
fpath(str): predictor path name (value supplied to predictor.save)
From v0.16.x, this is always the path to a folder.
Pre-v0.16.x, this is the base name used to save model and .preproc instance.
batch_size(int): batch size to use for predictions. default:32
+custom_objects(dict): custom objects required to load model.
+This is useful if you compiled the model with a custom loss function, for example.
+For models included with ktrain as is, this is populated automatically
+and can be disregarded.
"""

# load the preprocessor
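For context (not part of this diff), here is a minimal sketch of how the new `custom_objects` argument might be used when reloading a predictor whose model was compiled with a custom loss. The `focal_loss` function and the predictor path are hypothetical stand-ins:

```python
import tensorflow as tf
import ktrain

# Hypothetical custom loss the original model was compiled with.
def focal_loss(y_true, y_pred, gamma=2.0):
    ce = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    p_t = tf.exp(-ce)
    return tf.reduce_mean(tf.pow(1.0 - p_t, gamma) * ce)

# Supply the same objects by name when reloading, so that
# keras.models.load_model can deserialize the saved model.
predictor = ktrain.load_predictor('/tmp/my_predictor',
                                  custom_objects={'focal_loss': focal_loss})
preds = predictor.predict(['some example text'])  # assumes a text predictor
```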
@@ -1462,7 +1466,7 @@ def load_predictor(fpath, batch_size=U.DEFAULT_BS):
raise Exception('Could not find a .preproc file in either the post v0.16.x loction (%s) or pre v0.16.x location (%s)' % (os.path.join(fpath, U.PREPROC_NAME), fpath+'.preproc'))

# load the model
-model = _load_model(fpath, preproc=preproc)
+model = _load_model(fpath, preproc=preproc, custom_objects=custom_objects)


# preprocessing functions in ImageDataGenerators are not pickable
@@ -1537,7 +1541,10 @@ def _load_model(fpath, preproc=None, train_data=None, custom_objects=None):
train_data and U.bert_data_tuple(train_data):
# custom BERT model
from keras_bert import get_custom_objects
-custom_objects = get_custom_objects()
+if isinstance(custom_objects, dict):
+    custom_objects.update(get_custom_objects())
+else:
+    custom_objects = get_custom_objects()
elif (preproc and (isinstance(preproc, NERPreprocessor) or \
type(preproc).__name__ == 'NERPreprocessor')) or \
train_data and U.is_ner(data=train_data):
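As an aside (not part of the commit), the merge logic added above can be pictured in isolation. The `get_custom_objects` below is only a stand-in for `keras_bert.get_custom_objects()`; the real dictionary contains BERT-specific layers and activations:

```python
# Stand-in for keras_bert.get_custom_objects(); keys are illustrative only.
def get_custom_objects():
    return {'TokenEmbedding': object, 'gelu': lambda x: x}

def resolve_custom_objects(custom_objects=None):
    # Same pattern as the updated _load_model: keep caller-supplied entries
    # and layer the BERT-specific ones on top (library entries win on a
    # name collision, since dict.update overwrites existing keys).
    if isinstance(custom_objects, dict):
        custom_objects.update(get_custom_objects())
    else:
        custom_objects = get_custom_objects()
    return custom_objects

print(resolve_custom_objects({'focal_loss': 'my_loss_fn'}))
# contains both 'focal_loss' and the BERT-specific entries
```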
@@ -1568,11 +1575,8 @@ def _load_model(fpath, preproc=None, train_data=None, custom_objects=None):
# for bilstm models without CRF layer on TF2 where CRF is not supported
model = load_model(fpath, custom_objects={'AdamWeightDecay':AdamWeightDecay})
except Exception as e:
-print('Call to keras.models.load_model failed. '
-      'Try using the learner.model.save_weights and '
-      'learner.model.load_weights instead.')
-print('Error was: %s' % (e))
-return
+print('Call to keras.models.load_model failed. Try manually invoking this function to investigate error and report issue if necessary.')
+raise Exception('Error detected: %s' % (e))

# see issue https://github.com/amaiya/ktrain/issues/21
if hasattr(model, '_make_predict_function'):
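One practical consequence of the error-handling change above: `_load_model` now raises instead of printing a message and returning `None`, so a failed load surfaces to the caller as an exception. A small sketch of what that means for calling code (the path is hypothetical):

```python
import ktrain

try:
    predictor = ktrain.load_predictor('/tmp/my_predictor')
except Exception as e:
    # As of this change, load failures are raised rather than printed
    # followed by a None return, so they can be handled here.
    print('Could not load predictor: %s' % e)
```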
4 changes: 3 additions & 1 deletion ktrain/text/preprocessor.py
@@ -340,8 +340,10 @@ def hf_convert_examples(texts, y=None, tokenizer=None,
else:
text_a = text
text_b = None
-sentences.append(tokenizer.tokenize(text_a, text_b))
+sentences.append( tokenizer.convert_ids_to_tokens(tokenizer.encode(text_a, text_b)) )
+#sentences.append(tokenizer.tokenize(text_a, text_b)) # only works for Fast tokenizers
maxlen = len(max([tokens for tokens in sentences], key=len,)) + 2

if maxlen < max_length: max_length = maxlen


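For illustration (not part of the diff), a rough sketch of the fixed length computation using a Hugging Face tokenizer. The model name is an arbitrary choice, and the snippet assumes the transformers package is installed:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
texts = ['a short sentence', 'a somewhat longer example sentence for this batch']
max_length = 512

# Mirrors the fix: encode() works for both slow and fast tokenizers,
# whereas tokenize(text_a, text_b) does not (per the comment in the diff).
sentences = [tokenizer.convert_ids_to_tokens(tokenizer.encode(t, None)) for t in texts]
maxlen = len(max(sentences, key=len)) + 2
if maxlen < max_length:
    max_length = maxlen
print(max_length)  # max_length shrinks when the longest text in the batch is short
```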
2 changes: 1 addition & 1 deletion ktrain/version.py
@@ -1,2 +1,2 @@
__all__ = ['__version__']
-__version__ = '0.25.1'
+__version__ = '0.25.2'
