Merge pull request #440 from SharpKoi/v2-dev
fix a bug and add a new feature
BrikerMan committed Dec 9, 2020
2 parents 70d40a1 + 8d180e9 commit bb8ce93
Showing 4 changed files with 112 additions and 8 deletions.
103 changes: 103 additions & 0 deletions kashgari/callbacks/save_callback.py
@@ -0,0 +1,103 @@
import os
import numpy as np
from typing import Union, Any, AnyStr

import tensorflow as tf
from kashgari.tasks.abs_task_model import ABCTaskModel
from kashgari.logger import logger


class KashgariModelCheckpoint(tf.keras.callbacks.ModelCheckpoint):
    """Save the model after every epoch.

    Arguments:
        filepath: string, path to save the model file.
        monitor: quantity to monitor.
        verbose: verbosity mode, 0 or 1.
        save_best_only: if `save_best_only=True`, the latest best model according
            to the quantity monitored will not be overwritten.
        mode: one of {auto, min, max}. If `save_best_only=True`, the decision to
            overwrite the current save file is made based on either the maximization
            or the minimization of the monitored quantity. For `val_acc`, this
            should be `max`, for `val_loss` this should be `min`, etc. In `auto`
            mode, the direction is automatically inferred from the name of the
            monitored quantity.
        save_weights_only: if True, then only the model's weights will be saved
            (`model.save_weights(filepath)`), else the full model is saved
            (`model.save(filepath)`).
        save_freq: `'epoch'` or integer. When using `'epoch'`, the callback saves
            the model after each epoch. When using an integer, the callback saves
            the model at the end of the batch at which this many samples have been
            seen since the last save. Note that if the saving isn't aligned to
            epochs, the monitored metric may potentially be less reliable (it could
            reflect as little as 1 batch, since the metrics get reset every epoch).
            Defaults to `'epoch'`.
        kash_model: the Kashgari task model to save when `save_weights_only` is False.
        **kwargs: Additional arguments for backwards compatibility. Possible key
            is `period`.
    """

    def __init__(self,
                 filepath: AnyStr,
                 monitor: str = 'val_loss',
                 verbose: int = 1,
                 save_best_only: bool = False,
                 save_weights_only: bool = False,
                 mode: str = 'auto',
                 save_freq: Union[str, int] = 'epoch',
                 kash_model: ABCTaskModel = None,
                 **kwargs: Any) -> None:
        super(KashgariModelCheckpoint, self).__init__(
            filepath=filepath,
            monitor=monitor,
            verbose=verbose,
            save_best_only=save_best_only,
            save_weights_only=save_weights_only,
            mode=mode,
            save_freq=save_freq,
            **kwargs)
        self.kash_model = kash_model

    def _save_model(self, epoch: int, logs: dict) -> None:
        """Saves the model.

        Arguments:
            epoch: the epoch this iteration is in.
            logs: the `logs` dict passed in to `on_batch_end` or `on_epoch_end`.
        """
        logs = logs or {}

        if isinstance(self.save_freq,
                      int) or self.epochs_since_last_save >= self.period:
            self.epochs_since_last_save: int = 0
            filepath = self._get_file_path(epoch, logs)

            if self.save_best_only:
                current = logs.get(self.monitor)
                if current is None:
                    logger.warning('Can save best model only with %s available, skipping.', self.monitor)
                else:
                    if self.monitor_op(current, self.best):
                        if self.verbose > 0:
                            print('\nEpoch %d: %s improved from %0.5f to %0.5f,'
                                  ' saving model to %s' % (epoch + 1, self.monitor, self.best,
                                                           current, filepath))
                        self.best: float = current
                        if self.save_weights_only:
                            filepath = os.path.join(filepath, 'cp')
                            self.model.save_weights(filepath, overwrite=True)
                            logger.info(f'checkpoint saved to {filepath}')
                        else:
                            # Save the full Kashgari model (config + weights), not just the tf model.
                            self.kash_model.save(filepath)
                    else:
                        if self.verbose > 0:
                            print('\nEpoch %d: %s did not improve from %0.5f' %
                                  (epoch + 1, self.monitor, self.best))
            else:
                if self.verbose > 0:
                    print('\nEpoch %d: saving model to %s' % (epoch + 1, filepath))
                if self.save_weights_only:
                    filepath = os.path.join(filepath, 'cp')
                    self.model.save_weights(filepath, overwrite=True)
                    logger.info(f'checkpoint saved to {filepath}')
                else:
                    self.kash_model.save(filepath)

        self._maybe_remove_file()
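
Note: usage is not shown in this commit. Below is a minimal sketch of how the callback might be wired up; the model class, toy data, and the `checkpoints` path are placeholders, and it assumes the task model's `fit` forwards `callbacks` to the underlying Keras `fit`.

from kashgari.callbacks.save_callback import KashgariModelCheckpoint
from kashgari.tasks.classification import BiLSTM_Model

# Toy data, purely for illustration.
train_x = [['all', 'work', 'and', 'no', 'play'], ['makes', 'jack', 'dull']]
train_y = ['a', 'b']

model = BiLSTM_Model()
checkpoint = KashgariModelCheckpoint(filepath='checkpoints',
                                     monitor='loss',         # no validation set here, so not 'val_loss'
                                     save_best_only=True,
                                     save_weights_only=False,
                                     kash_model=model)       # required for full-model saves
model.fit(train_x, train_y, epochs=2, callbacks=[checkpoint])
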
4 changes: 2 additions & 2 deletions kashgari/metrics/sequence_labeling.py
@@ -320,8 +320,8 @@ def sequence_labeling_report(y_true: List[List[str]],
     pred_entities = set(bulk_get_entities(y_pred, suffix=suffix))
 
     name_width = 0
-    d1 = defaultdict(set)
-    d2 = defaultdict(set)
+    d1: Dict = defaultdict(set)
+    d2: Dict = defaultdict(set)
     for e in true_entities:
         d1[e[0]].add((e[1], e[2]))
         name_width = max(name_width, len(e[0]))
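
Note: the hunk above sits inside `sequence_labeling_report`, which buckets gold and predicted entities by type with a `defaultdict(set)` before computing per-type scores. A self-contained illustration of that grouping pattern (the entity tuples are made up):

from collections import defaultdict
from typing import Dict, Set, Tuple

# Entities as (type, start, end), e.g. decoded from a tagging model's output.
entities = [('PER', 0, 1), ('LOC', 3, 4), ('PER', 6, 8)]

by_type: Dict[str, Set[Tuple[int, int]]] = defaultdict(set)
for etype, start, end in entities:
    by_type[etype].add((start, end))

print(sorted(by_type['PER']))  # [(0, 1), (6, 8)]
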
11 changes: 6 additions & 5 deletions kashgari/tasks/abs_task_model.py
@@ -76,11 +76,11 @@ def default_hyper_parameters(cls) -> Dict[str, Dict[str, Any]]:
"""
raise NotImplementedError

def save(self, model_path: str) -> str:
def save(self, model_path: str, encoding: str = 'utf-8') -> str:
pathlib.Path(model_path).mkdir(exist_ok=True, parents=True)
model_path = os.path.abspath(model_path)

with open(os.path.join(model_path, 'model_config.json'), 'w') as f:
with open(os.path.join(model_path, 'model_config.json'), 'w', encoding=encoding) as f:
f.write(json.dumps(self.to_dict(), indent=2, ensure_ascii=False))
f.close()

@@ -91,15 +91,16 @@ def save(self, model_path: str) -> str:

     @classmethod
     def load_model(cls, model_path: str,
-                   custom_objects: Dict = None) -> Union["ABCLabelingModel", "ABCClassificationModel"]:
+                   custom_objects: Dict = None,
+                   encoding: str = 'utf-8') -> Union["ABCLabelingModel", "ABCClassificationModel"]:
         if custom_objects is None:
             custom_objects = {}
 
         if cls.__name__ not in custom_objects:
             custom_objects[cls.__name__] = cls
 
         model_config_path = os.path.join(model_path, 'model_config.json')
-        model_config = json.loads(open(model_config_path, 'r').read())
+        model_config = json.loads(open(model_config_path, 'r', encoding=encoding).read())
         model = load_data_object(model_config, custom_objects)
 
         model.embedding = load_data_object(model_config['embedding'], custom_objects)
@@ -112,7 +113,7 @@ def load_model(cls, model_path: str,
                                           custom_objects=kashgari.custom_objects)
 
         if isinstance(model.tf_model.layers[-1], KConditionalRandomField):
-            model.layer_crf = model.tf_model.layers[-1]
+            model.crf_layer = model.tf_model.layers[-1]
 
         model.tf_model.load_weights(os.path.join(model_path, 'model_weights.h5'))
         model.embedding.embed_model.load_weights(os.path.join(model_path, 'embed_model_weights.h5'))
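
Note: the `encoding` parameter added above defaults to `'utf-8'`, so existing call sites are unaffected; it matters on platforms where `open()` would otherwise use a locale-dependent default (e.g. cp936 on Chinese-locale Windows) and `model_config.json` contains non-ASCII labels. A minimal round-trip sketch; the model class and path are placeholders:

from kashgari.tasks.classification import BiLSTM_Model

model = BiLSTM_Model()
# ... train the model before saving, so the tf model and weights exist ...
model.save('saved_model', encoding='utf-8')   # writes model_config.json as UTF-8

# Load anywhere, regardless of the machine's default locale encoding.
loaded = BiLSTM_Model.load_model('saved_model', encoding='utf-8')
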
2 changes: 1 addition & 1 deletion kashgari/tokenizers/bert_tokenizer.py
@@ -98,7 +98,7 @@ def _tokenize(self, text: str) -> List[str]:
                 spaced += ch
 
         if len(self._token_dict) > 0:
-            tokens = []
+            tokens: List[str] = []
             for word in spaced.strip().split():
                 tokens += self._word_piece_tokenize(word)
             return tokens
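
Note: `_word_piece_tokenize` is not shown in this hunk. For context, here is a sketch of the conventional greedy longest-match-first WordPiece algorithm such a helper typically implements; the function name, vocabulary, and fallback behaviour are illustrative, not Kashgari's exact code.

from typing import Dict, List

def word_piece_tokenize(word: str, token_dict: Dict[str, int]) -> List[str]:
    # Greedy longest-match-first: take the longest vocab entry at each position;
    # non-initial pieces carry the '##' continuation prefix.
    if word in token_dict:
        return [word]
    tokens: List[str] = []
    start = 0
    while start < len(word):
        end = len(word)
        while end > start:
            sub = word[start:end]
            if start > 0:
                sub = '##' + sub
            if sub in token_dict:
                break
            end -= 1
        if end == start:
            # No vocab match: fall back to a single-character piece and move on.
            sub = word[start] if start == 0 else '##' + word[start]
            end = start + 1
        tokens.append(sub)
        start = end
    return tokens

vocab = {'un': 0, '##aff': 1, '##able': 2}
print(word_piece_tokenize('unaffable', vocab))  # ['un', '##aff', '##able']
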
