-
-
Notifications
You must be signed in to change notification settings - Fork 439
/
abc_model.py
305 lines (271 loc) · 13.4 KB
/
abc_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
# encoding: utf-8
# author: BrikerMan
# contact: eliyar917@gmail.com
# blog: https://eliyar.biz
# file: abc_model.py
# time: 4:30 下午
import logging
import random
from abc import ABC
from typing import List, Dict, Any, Union, TYPE_CHECKING
import kashgari
from kashgari.embeddings.abc_embedding import ABCEmbedding
from kashgari.generators import BatchDataGenerator
from kashgari.generators import CorpusGenerator
from kashgari.processors import SequenceProcessor
from kashgari.tasks.abs_task_model import ABCTaskModel
from kashgari.metrics.sequence_labeling import get_entities
from kashgari.metrics.sequence_labeling import sequence_labeling_report
from kashgari.types import TextSamplesVar
if TYPE_CHECKING:
from tensorflow import keras
class ABCLabelingModel(ABCTaskModel, ABC):
    """Abstract base class for sequence-labeling models (e.g. NER, POS tagging).

    Concrete subclasses supply the network architecture via ``build_model``;
    this class provides the shared train / predict / evaluate workflow.
    """

    def __init__(self,
                 embedding: ABCEmbedding = None,
                 sequence_length: int = None,
                 hyper_parameters: Dict[str, Dict[str, Any]] = None,
                 **kwargs: Any):
        """
        Abstract Labeling Model

        Args:
            embedding: embedding object
            sequence_length: target sequence length
            hyper_parameters: hyper_parameters to overwrite
            **kwargs: additional arguments forwarded to ``ABCTaskModel``
        """
        super(ABCLabelingModel, self).__init__(embedding=embedding,
                                               sequence_length=sequence_length,
                                               hyper_parameters=hyper_parameters,
                                               **kwargs)
        # Labels form sequences themselves, so they get their own sequence
        # processor; min_count=1 keeps every label seen in the corpus.
        self.default_labeling_processor = SequenceProcessor(vocab_dict_type='labeling',
                                                            min_count=1)

    def fit(self,
            x_train: TextSamplesVar,
            y_train: TextSamplesVar,
            x_validate: TextSamplesVar = None,
            y_validate: TextSamplesVar = None,
            batch_size: int = 64,
            epochs: int = 5,
            callbacks: List['keras.callbacks.Callback'] = None,
            fit_kwargs: Dict = None) -> 'keras.callbacks.History':
        """
        Trains the model for a given number of epochs with given data set list.

        Args:
            x_train: Array of train feature data (if the model has a single input),
                or tuple of train feature data array (if the model has multiple inputs)
            y_train: Array of train label data
            x_validate: Array of validation feature data (if the model has a single input),
                or tuple of validation feature data array (if the model has multiple inputs)
            y_validate: Array of validation label data
            batch_size: Number of samples per gradient update, default to 64.
            epochs: Number of epochs to train the model.
                An epoch is an iteration over the entire `x` and `y` data provided.
                Note that in conjunction with `initial_epoch`, `epochs` is to be understood as "final epoch".
                The model is not trained for a number of iterations given by `epochs`, but merely until the epoch
                of index `epochs` is reached.
            callbacks: List of `keras.callbacks.Callback` instances.
                List of callbacks to apply during training.
                See `tf.keras.callbacks`.
            fit_kwargs: additional arguments passed to ``fit()`` function from
                ``tensorflow.keras.Model`` - https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit

        Returns:
            A `History` object. Its `History.history` attribute is
            a record of training loss values and metrics values
            at successive epochs, as well as validation loss values
            and validation metrics values (if applicable).
        """
        train_gen = CorpusGenerator(x_train, y_train)
        if x_validate is not None:
            valid_gen = CorpusGenerator(x_validate, y_validate)
        else:
            valid_gen = None
        return self.fit_generator(train_sample_gen=train_gen,
                                  valid_sample_gen=valid_gen,
                                  batch_size=batch_size,
                                  epochs=epochs,
                                  callbacks=callbacks,
                                  fit_kwargs=fit_kwargs)

    def _batch_generator(self,
                         corpus_gen: CorpusGenerator,
                         batch_size: int) -> BatchDataGenerator:
        # Helper: wrap a corpus generator with the model's processors and
        # embedding settings so train/valid generators are built identically.
        return BatchDataGenerator(corpus_gen,
                                  text_processor=self.text_processor,
                                  label_processor=self.label_processor,
                                  segment=self.embedding.segment,
                                  seq_length=self.embedding.sequence_length,
                                  max_position=self.embedding.max_position,
                                  batch_size=batch_size)

    def fit_generator(self,
                      train_sample_gen: CorpusGenerator,
                      valid_sample_gen: CorpusGenerator = None,
                      batch_size: int = 64,
                      epochs: int = 5,
                      callbacks: List['keras.callbacks.Callback'] = None,
                      fit_kwargs: Dict = None) -> 'keras.callbacks.History':
        """
        Trains the model for a given number of epochs with given data generator.

        Data generator must be the subclass of `CorpusGenerator`

        Args:
            train_sample_gen: train data generator.
            valid_sample_gen: valid data generator.
            batch_size: Number of samples per gradient update, default to 64.
            epochs: Number of epochs to train the model.
                An epoch is an iteration over the entire `x` and `y` data provided.
                Note that in conjunction with `initial_epoch`, `epochs` is to be understood as "final epoch".
                The model is not trained for a number of iterations given by `epochs`, but merely until the epoch
                of index `epochs` is reached.
            callbacks: List of `keras.callbacks.Callback` instances.
                List of callbacks to apply during training.
                See `tf.keras.callbacks`.
            fit_kwargs: additional arguments passed to ``fit()`` function from
                ``tensorflow.keras.Model`` - https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit

        Returns:
            A `History` object. Its `History.history` attribute is
            a record of training loss values and metrics values
            at successive epochs, as well as validation loss values
            and validation metrics values (if applicable).
        """
        # build_model needs a pass over the corpus to fit vocabularies.
        self.build_model(train_sample_gen)
        self.tf_model.summary()
        train_gen = self._batch_generator(train_sample_gen, batch_size)

        if fit_kwargs is None:
            fit_kwargs = {}

        if valid_sample_gen:
            valid_gen = self._batch_generator(valid_sample_gen, batch_size)
            fit_kwargs['validation_data'] = valid_gen.generator()
            fit_kwargs['validation_steps'] = len(valid_gen)

        return self.tf_model.fit(train_gen.generator(),
                                 steps_per_epoch=len(train_gen),
                                 epochs=epochs,
                                 callbacks=callbacks,
                                 **fit_kwargs)

    def predict(self,  # type: ignore[override]
                x_data: TextSamplesVar,
                *,
                batch_size: int = 32,
                truncating: bool = False,
                debug_info: bool = False,
                predict_kwargs: Dict = None,
                **kwargs: Any) -> List[List[str]]:
        """
        Generates output predictions for the input samples.

        Computation is done in batches.

        Args:
            x_data: The input data, as a Numpy array (or list of Numpy arrays if the model has multiple inputs).
            batch_size: Integer. If unspecified, it will default to 32.
            truncating: remove values from sequences larger than `model.embedding.sequence_length`
            debug_info: Bool, Should print out the logging info.
            predict_kwargs: arguments passed to ``predict()`` function of ``tf.keras.Model``

        Returns:
            array(s) of predictions.
        """
        if predict_kwargs is None:
            predict_kwargs = {}
        with kashgari.utils.custom_object_scope():
            if truncating:
                seq_length = self.embedding.sequence_length
            else:
                seq_length = None
            # BUGFIX: keyword was misspelled `seq_lengtg`, so truncating never
            # took effect (and would raise TypeError on a strict signature).
            tensor = self.text_processor.transform(x_data,
                                                   segment=self.embedding.segment,
                                                   seq_length=seq_length,
                                                   max_position=self.embedding.max_position)
            logits = self.tf_model.predict(tensor, batch_size=batch_size, **predict_kwargs)
            # Per-token label index = argmax over the last (label) axis.
            pred = logits.argmax(-1)
            # Trim padded positions back to each sentence's real length.
            lengths = [len(sen) for sen in x_data]
            res: List[List[str]] = self.label_processor.inverse_transform(pred,  # type: ignore
                                                                          lengths=lengths)
            if debug_info:
                logging.info('input: {}'.format(tensor))
                # BUGFIX: previously logged argmax of the already-argmaxed
                # array; log raw logits and the label indices instead.
                logging.info('output: {}'.format(logits))
                logging.info('output argmax: {}'.format(pred))
        return res

    def predict_entities(self,
                         x_data: TextSamplesVar,
                         batch_size: int = 32,
                         join_chunk: str = ' ',
                         truncating: bool = False,
                         debug_info: bool = False,
                         predict_kwargs: Dict = None) -> List[Dict]:
        """Gets entities from sequence.

        Args:
            x_data: The input data, as a Numpy array (or list of Numpy arrays if the model has multiple inputs).
            batch_size: Integer. If unspecified, it will default to 32.
            truncating: remove values from sequences larger than `model.embedding.sequence_length`
            join_chunk: str or False, delimiter used to join the tokens of an
                entity value; pass ``False`` to keep the token list.
            debug_info: Bool, Should print out the logging info.
            predict_kwargs: arguments passed to ``predict()`` function of ``tf.keras.Model``

        Returns:
            list: list of entity.
        """
        if isinstance(x_data, tuple):
            text_seq = x_data[0]
        else:
            text_seq = x_data
        res = self.predict(x_data,
                           batch_size=batch_size,
                           truncating=truncating,
                           debug_info=debug_info,
                           predict_kwargs=predict_kwargs)
        new_res = [get_entities(seq) for seq in res]
        final_res = []
        for index, seq in enumerate(new_res):
            seq_data = []
            for entity in seq:
                # entity is (label, start, end); slice is end-inclusive.
                res_entities: List[str] = []
                for i, e in enumerate(text_seq[index][entity[1]:entity[2] + 1]):
                    # Handle bert tokenizer: merge '##' word-piece suffixes
                    # back onto the previous token.
                    if e.startswith('##') and len(res_entities) > 0:
                        res_entities[-1] += e.replace('##', '')
                    else:
                        res_entities.append(e)
                value: Union[str, List[str]]
                if join_chunk is False:
                    value = res_entities
                else:
                    value = join_chunk.join(res_entities)
                seq_data.append({
                    "entity": entity[0],
                    "start": entity[1],
                    "end": entity[2],
                    "value": value,
                })
            final_res.append({
                'tokenized': x_data[index],
                'labels': seq_data
            })
        return final_res

    def evaluate(self,
                 x_data: TextSamplesVar,
                 y_data: TextSamplesVar,
                 batch_size: int = 32,
                 digits: int = 4,
                 truncating: bool = False,
                 debug_info: bool = False,
                 **kwargs: Dict) -> Dict:
        """
        Build a text report showing the main labeling metrics.

        Args:
            x_data: evaluation feature data.
            y_data: gold label sequences, aligned with ``x_data``.
            batch_size: samples per prediction batch, default 32.
            digits: number of digits for formatting floats in the report.
            truncating: truncate sequences longer than the embedding's length.
            debug_info: log a few random samples for inspection.

        Returns:
            Dict report produced by ``sequence_labeling_report``.
        """
        y_pred = self.predict(x_data,
                              batch_size=batch_size,
                              truncating=truncating,
                              debug_info=debug_info)
        # Predictions may be shorter than gold sequences when truncated;
        # clip gold labels to the predicted length for a fair comparison.
        y_true = [seq[:len(y_pred[index])] for index, seq in enumerate(y_data)]

        # Normalize every label to str so the report never mixes types.
        new_y_pred = [[str(i) for i in x] for x in y_pred]
        new_y_true = [[str(i) for i in x] for x in y_true]

        if debug_info:
            # BUGFIX: cap the sample count — random.sample raises ValueError
            # when the population is smaller than the requested size.
            for index in random.sample(range(len(x_data)), min(5, len(x_data))):
                logging.debug('------ sample {} ------'.format(index))
                logging.debug('x      : {}'.format(x_data[index]))
                logging.debug('y_true : {}'.format(y_true[index]))
                logging.debug('y_pred : {}'.format(y_pred[index]))
        # BUGFIX: the stringified lists were built but never used; pass them
        # to the report instead of the raw (possibly non-str) sequences.
        report = sequence_labeling_report(new_y_true, new_y_pred, digits=digits)
        return report
if __name__ == "__main__":
    # This module only defines the abstract model; nothing to run standalone.
    pass