Skip to content

Commit

Permalink
Merge branch 'develop'
Browse files Browse the repository at this point in the history
  • Loading branch information
amaiya committed Jun 5, 2020
2 parents 658adb6 + 35d788c commit 7ff72ec
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 5 deletions.
12 changes: 12 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,18 @@ Most recent releases are shown at the top. Each release shows:
- **Fixed**: Bug fixes that don't change documented behaviour


## 0.16.1 (2020-06-05)

### New:
- N/A

### Changed:
- N/A

### Fixed:
- prevent `transformer` tokenizers from being pickled during `predictor.save`, as it causes problems for
some community-uploaded models like `bert-base-japanese-whole-word-masking`.

## 0.16.0 (2020-06-03)

### New:
Expand Down
23 changes: 19 additions & 4 deletions ktrain/text/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,17 +814,30 @@ def __init__(self, model_name,
if "bert-base-japanese" in model_name:
self.tokenizer_type = transformers.BertJapaneseTokenizer

tokenizer = self.tokenizer_type.from_pretrained(model_name)
# NOTE: As of v0.16.1, do not unnecessarily instantiate tokenizer
# as it will be saved/pickled along with Preprocessor, which causes
# problems for some community-uploaded models like bert-base-japanese-whole-word-masking.
#tokenizer = self.tokenizer_type.from_pretrained(model_name)
#self.tok = tokenizer
self.tok = None # not pickled, see __getstate__

self.tok = tokenizer
self.tok_dct = None
self.max_features = max_features # ignored
self.ngram_range = 1 # ignored


def __getstate__(self):
return {k: v for k, v in self.__dict__.items() if k not in ['tok']}


def __setstate__(self, state):
self.__dict__.update(state)
if not hasattr(self, 'tok'): self.tok = None


def get_preprocessor(self):
    """Return the ``(tokenizer, tok_dct)`` pair used for preprocessing.

    The tokenizer is not pickled with this object (see ``__getstate__``),
    so it is instantiated here on first use from the stored model name
    and cached on the instance for subsequent calls.
    """
    tokenizer = self.tok
    if tokenizer is None:
        tokenizer = self.tokenizer_type.from_pretrained(self.model_name)
        self.tok = tokenizer
    return (tokenizer, self.tok_dct)


Expand All @@ -839,6 +852,7 @@ def undo(self, doc):
undoes preprocessing and returns raw data by:
converting a list or array of Word IDs back to words
"""
tok, _ = self.get_preprocessor()
return self.tok.convert_ids_to_tokens(doc)
#raise Exception('currently_unsupported: Transformers.Preprocessor.undo is not yet supported')

Expand Down Expand Up @@ -870,9 +884,10 @@ def preprocess_train(self, texts, y=None, mode='train', verbose=1):
y = self._transform_y(y)

# convert examples
dataset = hf_convert_examples(texts, y=y, tokenizer=self.tok, max_length=self.maxlen,
tok, _ = self.get_preprocessor()
dataset = hf_convert_examples(texts, y=y, tokenizer=tok, max_length=self.maxlen,
pad_on_left=bool(self.name in ['xlnet']),
pad_token=self.tok.convert_tokens_to_ids([self.tok.pad_token][0]),
pad_token=tok.convert_tokens_to_ids([tok.pad_token][0]),
pad_token_segment_id=4 if self.name in ['xlnet'] else 0)
self.set_multilabel(dataset, mode, verbose=verbose)
if mode == 'train': self.preprocess_train_called = True
Expand Down
2 changes: 1 addition & 1 deletion ktrain/version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__all__ = ['__version__']
__version__ = '0.16.0'
__version__ = '0.16.1'

0 comments on commit 7ff72ec

Please sign in to comment.