-
-
Notifications
You must be signed in to change notification settings - Fork 439
/
eval_callBack.py
58 lines (49 loc) · 2.01 KB
/
eval_callBack.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# encoding: utf-8
# author: BrikerMan
# contact: eliyar917@gmail.com
# blog: https://eliyar.biz
# file: eval_callBack.py
# time: 6:53 下午
from tensorflow import keras
from kashgari.tasks.abs_task_model import ABCTaskModel
from typing import List, Any, Dict
class EvalCallBack(keras.callbacks.Callback):
    """Keras callback that periodically evaluates a Kashgari task model.

    Every ``step`` epochs the wrapped Kashgari model is evaluated on the
    supplied data. Precision, recall and f1-score are printed and appended
    to ``self.logs`` as one dict per evaluation run.
    """

    def __init__(self,
                 kash_model: ABCTaskModel,
                 x_data: List[Any],
                 y_data: List[Any],
                 *,
                 step: int = 5,
                 truncating: bool = False,
                 batch_size: int = 256) -> None:
        """
        Evaluate callback, calculate precision, recall and f1

        Args:
            kash_model: the kashgari task model to evaluate
            x_data: feature data for evaluation
            y_data: label data for evaluation
            step: evaluate once every ``step`` epochs, default 5; must be >= 1
            truncating: remove values from sequences larger than
                `model.embedding.sequence_length`
            batch_size: batch size, default 256

        Raises:
            ValueError: if ``step`` is less than 1 (``step == 0`` would
                otherwise crash with ZeroDivisionError at the end of the
                first epoch).
        """
        super().__init__()
        if step < 1:
            raise ValueError(f"step must be a positive integer, got {step!r}")
        self.kash_model: ABCTaskModel = kash_model
        self.x_data = x_data
        self.y_data = y_data
        self.step = step
        self.truncating = truncating
        self.batch_size = batch_size
        # Accumulated evaluation history (one dict per evaluation run). This is
        # distinct from the per-epoch ``logs`` dict Keras passes to the hooks.
        self.logs: List[Dict] = []

    def on_epoch_end(self, epoch: int, logs: Any = None) -> None:
        """Evaluate the wrapped model when ``epoch + 1`` is a multiple of ``step``.

        Args:
            epoch: index of the epoch that just finished, as passed by Keras.
            logs: Keras' per-epoch logs; not read or modified here.
        """
        # ``epoch`` counts from 0, so ``epoch + 1`` is the number of finished
        # epochs; evaluation therefore runs after epochs step, 2*step, ...
        if (epoch + 1) % self.step != 0:
            return

        # The model is stored as ``self.kash_model`` in ``__init__``; there is
        # no ``task_model`` attribute on this callback.
        report = self.kash_model.evaluate(self.x_data,
                                          self.y_data,
                                          truncating=self.truncating,
                                          batch_size=self.batch_size)
        # Assumes ``evaluate`` returns a mapping containing these three keys.
        self.logs.append({
            'precision': report['precision'],
            'recall': report['recall'],
            'f1-score': report['f1-score']
        })
        print(f"\nepoch: {epoch} precision: {report['precision']:.6f},"
              f" recall: {report['recall']:.6f}, f1-score: {report['f1-score']:.6f}")