From d7e4f72b7702ccb838acba409e96c47c3e9bbc25 Mon Sep 17 00:00:00 2001 From: BrikerMan Date: Sat, 9 May 2020 11:34:16 +0800 Subject: [PATCH] :construction: Work in progress. --- kashgari/callbacks/eval_callBack.py | 26 +++++++++++++--------- kashgari/tasks/abs_task_model.py | 4 +++- kashgari/tasks/classification/abc_model.py | 5 +++-- kashgari/tasks/labeling/abc_model.py | 3 ++- kashgari/tasks/labeling/bi_gru_model.py | 3 --- kashgari/tasks/labeling/cnn_lstm_model.py | 1 - 6 files changed, 23 insertions(+), 19 deletions(-) diff --git a/kashgari/callbacks/eval_callBack.py b/kashgari/callbacks/eval_callBack.py index 2934d7e5..ec5d2873 100644 --- a/kashgari/callbacks/eval_callBack.py +++ b/kashgari/callbacks/eval_callBack.py @@ -16,33 +16,37 @@ class EvalCallBack(keras.callbacks.Callback): def __init__(self, + kash_model: ABCTaskModel, + x_data: List[Any], + y_data: List[Any], *, - task_model: ABCTaskModel, - valid_x: List[Any], - valid_y: List[Any], step: int = 5, + truncating: bool = False, batch_size: int = 256) -> None: """ Evaluate callback, calculate precision, recall and f1 Args: - task_model: the kashgari task model to evaluate - valid_x: feature data - valid_y: label data + kash_model: the kashgari task model to evaluate + x_data: feature data for evaluation + y_data: label data for evaluation step: step, default 5 + truncating: remove values from sequences larger than `model.embedding.sequence_length` batch_size: batch size, default 256 """ super(EvalCallBack, self).__init__() - self.task_model: ABCTaskModel = task_model - self.valid_x = valid_x - self.valid_y = valid_y + self.kash_model: ABCTaskModel = kash_model + self.x_data = x_data + self.y_data = y_data self.step = step + self.truncating = truncating self.batch_size = batch_size self.logs: List[Dict] = [] def on_epoch_end(self, epoch: int, logs: Any = None) -> None: if (epoch + 1) % self.step == 0: - report = self.task_model.evaluate(self.valid_x, - self.valid_y, + report = 
self.kash_model.evaluate(self.x_data, + self.y_data, + truncating=self.truncating, batch_size=self.batch_size) self.logs.append({ diff --git a/kashgari/tasks/abs_task_model.py b/kashgari/tasks/abs_task_model.py index a7e15c89..120b2ada 100644 --- a/kashgari/tasks/abs_task_model.py +++ b/kashgari/tasks/abs_task_model.py @@ -180,7 +180,9 @@ def evaluate(self, *, batch_size: int = 32, digits: int = 4, - debug_info: bool = False, ) -> Dict: + truncating: bool = False, + debug_info: bool = False, + **kwargs: Dict) -> Dict: raise NotImplementedError diff --git a/kashgari/tasks/classification/abc_model.py b/kashgari/tasks/classification/abc_model.py index 1c240094..3942fc74 100644 --- a/kashgari/tasks/classification/abc_model.py +++ b/kashgari/tasks/classification/abc_model.py @@ -241,7 +241,7 @@ def predict(self, # type: ignore[override] print('output argmax: {}'.format(pred.argmax(-1))) return res - def evaluate(self, + def evaluate(self, # type: ignore[override] x_data: TextSamplesVar, y_data: Union[ClassificationLabelVar, MultiLabelClassificationLabelVar], *, @@ -249,7 +249,8 @@ def evaluate(self, digits: int = 4, multi_label_threshold: float = 0.5, truncating: bool = False, - debug_info: bool = False) -> Dict: + debug_info: bool = False, + **kwargs: Dict) -> Dict: y_pred = self.predict(x_data, batch_size=batch_size, truncating=truncating, diff --git a/kashgari/tasks/labeling/abc_model.py b/kashgari/tasks/labeling/abc_model.py index 3b74ca81..fb57734c 100644 --- a/kashgari/tasks/labeling/abc_model.py +++ b/kashgari/tasks/labeling/abc_model.py @@ -271,7 +271,8 @@ def evaluate(self, batch_size: int = 32, digits: int = 4, truncating: bool = False, - debug_info: bool = False) -> Dict: + debug_info: bool = False, + **kwargs: Dict) -> Dict: """ Build a text report showing the main labeling metrics. 
diff --git a/kashgari/tasks/labeling/bi_gru_model.py b/kashgari/tasks/labeling/bi_gru_model.py index 07f9bcc3..e769a5d7 100644 --- a/kashgari/tasks/labeling/bi_gru_model.py +++ b/kashgari/tasks/labeling/bi_gru_model.py @@ -17,8 +17,6 @@ class BiGRU_Model(ABCLabelingModel): - __task__ = 'labeling' - @classmethod def default_hyper_parameters(cls) -> Dict[str, Dict[str, Any]]: return { @@ -44,7 +42,6 @@ def build_model_arc(self) -> None: layer_stack = [ L.Bidirectional(L.GRU(**config['layer_bgru']), name='layer_bgru'), L.Dropout(**config['layer_dropout'], name='layer_dropout'), - # L.Dense(output_dim, **config['layer_time_distributed']), L.TimeDistributed(L.Dense(output_dim, **config['layer_time_distributed']), name='layer_time_distributed'), L.Activation(**config['layer_activation']) ] diff --git a/kashgari/tasks/labeling/cnn_lstm_model.py b/kashgari/tasks/labeling/cnn_lstm_model.py index 83265d3e..d64b6767 100644 --- a/kashgari/tasks/labeling/cnn_lstm_model.py +++ b/kashgari/tasks/labeling/cnn_lstm_model.py @@ -42,7 +42,6 @@ def build_model_arc(self) -> None: layer_stack = [ L.Bidirectional(L.GRU(**config['layer_bgru']), name='layer_bgru'), L.Dropout(**config['layer_dropout'], name='layer_dropout'), - # L.Dense(output_dim, **config['layer_time_distributed']), L.TimeDistributed(L.Dense(output_dim, **config['layer_time_distributed']), name='layer_time_distributed'), L.Activation(**config['layer_activation']) ]