Skip to content

Commit

Permalink
🚧 Work in progress.
Browse files Browse the repository at this point in the history
  • Loading branch information
BrikerMan committed May 9, 2020
1 parent 2b26333 commit d7e4f72
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 19 deletions.
26 changes: 15 additions & 11 deletions kashgari/callbacks/eval_callBack.py
Expand Up @@ -16,33 +16,37 @@
class EvalCallBack(keras.callbacks.Callback):

def __init__(self,
             kash_model: ABCTaskModel,
             x_data: List[Any],
             y_data: List[Any],
             *,
             step: int = 5,
             truncating: bool = False,
             batch_size: int = 256) -> None:
    """
    Evaluate callback, calculate precision, recall and f1 every ``step`` epochs.

    Args:
        kash_model: the kashgari task model to evaluate
        x_data: feature data for evaluation
        y_data: label data for evaluation
        step: run an evaluation every ``step`` epochs, default 5
        truncating: remove values from sequences larger than
            ``model.embedding.sequence_length``, default False
        batch_size: batch size used for prediction during evaluation, default 256
    """
    super(EvalCallBack, self).__init__()
    self.kash_model: ABCTaskModel = kash_model
    self.x_data = x_data
    self.y_data = y_data
    self.step = step
    self.truncating = truncating
    self.batch_size = batch_size
    # Accumulates one report dict per evaluation run (see on_epoch_end).
    self.logs: List[Dict] = []

def on_epoch_end(self, epoch: int, logs: Any = None) -> None:
if (epoch + 1) % self.step == 0:
report = self.task_model.evaluate(self.valid_x,
self.valid_y,
report = self.task_model.evaluate(self.x_data,
self.y_data,
truncating=self.truncating,
batch_size=self.batch_size)

self.logs.append({
Expand Down
4 changes: 3 additions & 1 deletion kashgari/tasks/abs_task_model.py
Expand Up @@ -180,7 +180,9 @@ def evaluate(self,
*,
batch_size: int = 32,
digits: int = 4,
debug_info: bool = False, ) -> Dict:
truncating: bool = False,
debug_info: bool = False,
**kwargs: Dict) -> Dict:
raise NotImplementedError


Expand Down
5 changes: 3 additions & 2 deletions kashgari/tasks/classification/abc_model.py
Expand Up @@ -241,15 +241,16 @@ def predict(self, # type: ignore[override]
print('output argmax: {}'.format(pred.argmax(-1)))
return res

def evaluate(self,
def evaluate(self, # type: ignore[override]
x_data: TextSamplesVar,
y_data: Union[ClassificationLabelVar, MultiLabelClassificationLabelVar],
*,
batch_size: int = 32,
digits: int = 4,
multi_label_threshold: float = 0.5,
truncating: bool = False,
debug_info: bool = False) -> Dict:
debug_info: bool = False,
**kwargs: Dict) -> Dict:
y_pred = self.predict(x_data,
batch_size=batch_size,
truncating=truncating,
Expand Down
3 changes: 2 additions & 1 deletion kashgari/tasks/labeling/abc_model.py
Expand Up @@ -271,7 +271,8 @@ def evaluate(self,
batch_size: int = 32,
digits: int = 4,
truncating: bool = False,
debug_info: bool = False) -> Dict:
debug_info: bool = False,
**kwargs: Dict) -> Dict:
"""
Build a text report showing the main labeling metrics.
Expand Down
3 changes: 0 additions & 3 deletions kashgari/tasks/labeling/bi_gru_model.py
Expand Up @@ -17,8 +17,6 @@

class BiGRU_Model(ABCLabelingModel):

__task__ = 'labeling'

@classmethod
def default_hyper_parameters(cls) -> Dict[str, Dict[str, Any]]:
return {
Expand All @@ -44,7 +42,6 @@ def build_model_arc(self) -> None:
layer_stack = [
L.Bidirectional(L.GRU(**config['layer_bgru']), name='layer_bgru'),
L.Dropout(**config['layer_dropout'], name='layer_dropout'),
# L.Dense(output_dim, **config['layer_time_distributed']),
L.TimeDistributed(L.Dense(output_dim, **config['layer_time_distributed']), name='layer_time_distributed'),
L.Activation(**config['layer_activation'])
]
Expand Down
1 change: 0 additions & 1 deletion kashgari/tasks/labeling/cnn_lstm_model.py
Expand Up @@ -42,7 +42,6 @@ def build_model_arc(self) -> None:
layer_stack = [
L.Bidirectional(L.GRU(**config['layer_bgru']), name='layer_bgru'),
L.Dropout(**config['layer_dropout'], name='layer_dropout'),
# L.Dense(output_dim, **config['layer_time_distributed']),
L.TimeDistributed(L.Dense(output_dim, **config['layer_time_distributed']), name='layer_time_distributed'),
L.Activation(**config['layer_activation'])
]
Expand Down

0 comments on commit d7e4f72

Please sign in to comment.