/
base_model.py
129 lines (102 loc) · 4.09 KB
/
base_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# encoding: utf-8
# author: BrikerMan
# contact: eliyar917@gmail.com
# blog: https://eliyar.biz
# file: base_model.py
# time: 2019-05-20 13:07
from typing import Dict, Any, Tuple
import random
import logging
from seqeval.metrics import classification_report
from seqeval.metrics.sequence_labeling import get_entities
from kashgari.tasks.base_model import BaseModel
class BaseLabelingModel(BaseModel):
    """Base Sequence Labeling Model.

    Concrete labeling models (e.g. BiLSTM) subclass this and provide
    ``get_default_hyper_parameters`` and ``build_model_arc``.
    """
    __task__ = 'labeling'

    @classmethod
    def get_default_hyper_parameters(cls) -> Dict[str, Dict[str, Any]]:
        """Return the default hyper-parameter dict for the concrete model.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError

    def predict_entities(self,
                         x_data,
                         batch_size=None,
                         join_chunk=' ',
                         debug_info=False,
                         predict_kwargs: Dict = None):
        """Gets entities from sequence.

        Args:
            x_data: The input data, as a Numpy array (or list of Numpy arrays if the model has multiple inputs).
            batch_size: Integer. If unspecified, it will default to 32.
            join_chunk: str or False. When a str, entity tokens and the full
                text are joined with it; when False, the raw token sequence is
                returned unjoined.
            debug_info: Bool, Should print out the logging info.
            predict_kwargs: arguments passed to ``predict()`` function of ``tf.keras.Model``

        Returns:
            list: one dict per input sequence with keys ``text``, ``text_raw``
            and ``labels`` (each label has ``entity``/``start``/``end``/``value``).
        """
        if isinstance(x_data, tuple):
            text_seq = x_data[0]
        else:
            text_seq = x_data
        res = self.predict(x_data, batch_size, debug_info, predict_kwargs)
        new_res = [get_entities(seq) for seq in res]
        final_res = []
        for index, seq in enumerate(new_res):
            seq_data = []
            for entity in seq:
                # entity is a (label, start, end) tuple; the slice end is inclusive.
                chunk = text_seq[index][entity[1]:entity[2] + 1]
                if join_chunk is False:
                    # BUG FIX: the original had a trailing comma here
                    # (`value = ...,`), which wrapped the value in a 1-tuple.
                    value = chunk
                else:
                    value = join_chunk.join(chunk)
                seq_data.append({
                    "entity": entity[0],
                    "start": entity[1],
                    "end": entity[2],
                    "value": value,
                })
            final_res.append({
                # BUG FIX: the original called join_chunk.join unconditionally,
                # raising AttributeError whenever join_chunk is False.
                'text': text_seq[index] if join_chunk is False else join_chunk.join(text_seq[index]),
                'text_raw': text_seq[index],
                'labels': seq_data
            })
        return final_res

    def evaluate(self,
                 x_data,
                 y_data,
                 batch_size=None,
                 digits=4,
                 debug_info=False) -> str:
        """Build a text report showing the main classification metrics.

        Args:
            x_data: input sequences to predict on.
            y_data: gold label sequences aligned with ``x_data``.
            batch_size: Integer. If unspecified, it will default to 32.
            digits: number of digits for formatting metric values.
            debug_info: Bool, log a few random samples for inspection.

        Returns:
            str: the seqeval ``classification_report`` text.
            (The original annotation claimed ``Tuple[float, float, Dict]``,
            which did not match the actual return value.)
        """
        y_pred = self.predict(x_data, batch_size=batch_size)
        # Truncate gold labels to prediction length: predictions may be
        # shortened by the model's sequence-length limit.
        y_true = [seq[:len(y_pred[index])] for index, seq in enumerate(y_data)]

        if debug_info:
            # ROBUSTNESS: cap the sample count so datasets smaller than 5
            # no longer raise ValueError from random.sample.
            for index in random.sample(list(range(len(x_data))), min(5, len(x_data))):
                logging.debug('------ sample {} ------'.format(index))
                logging.debug('x      : {}'.format(x_data[index]))
                logging.debug('y_true : {}'.format(y_true[index]))
                logging.debug('y_pred : {}'.format(y_pred[index]))
        # PERF FIX: compute the (potentially expensive) report once instead of
        # calling classification_report twice.
        report = classification_report(y_true, y_pred, digits=digits)
        print(report)
        return report

    def build_model_arc(self):
        """Build the underlying tf.keras architecture.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError
if __name__ == "__main__":
    # Ad-hoc developer smoke test: restore a saved model from a local path.
    logging.basicConfig(level=logging.DEBUG)

    from kashgari.tasks.labeling import BiLSTM_Model
    from kashgari.corpus import ChineseDailyNerCorpus
    from kashgari.utils import load_model

    # Load the NER corpus; keep only the first 5120 training samples.
    x_train, y_train = ChineseDailyNerCorpus.load_data('train', shuffle=False)
    x_valid, y_valid = ChineseDailyNerCorpus.load_data('valid')
    x_train, y_train = x_train[:5120], y_train[:5120]

    # NOTE(review): hard-coded developer-machine path — works only locally.
    model = load_model('/Users/brikerman/Desktop/blstm_model')
    print("Hello world")