From 7e942bf99f373d84850b6eefd510bb287d616538 Mon Sep 17 00:00:00 2001
From: BrikerMan
Date: Tue, 21 May 2019 23:25:41 +0800
Subject: [PATCH] :sparkles: Adding predict and evaluate function.

---
 kashgari/embeddings/base_embedding.py         |  4 +
 kashgari/pre_processors/base_processor.py     |  3 +
 .../classification_processor.py               |  3 +
 kashgari/pre_processors/labeling_processor.py | 15 ++++
 kashgari/tasks/labeling/base_model.py         | 73 +++++++++++++++++++
 setup.py                                      |  2 +-
 tests/test_labeling.py                        | 12 +--
 7 files changed, 106 insertions(+), 6 deletions(-)

diff --git a/kashgari/embeddings/base_embedding.py b/kashgari/embeddings/base_embedding.py
index c66a8c52..761d5126 100644
--- a/kashgari/embeddings/base_embedding.py
+++ b/kashgari/embeddings/base_embedding.py
@@ -159,6 +159,10 @@ def process_y_dataset(self,
         """
         return self.processor.process_y_dataset(data, self.sequence_length, subset)
 
+    def reverse_numerize_label_sequences(self,
+                                         sequences,
+                                         lengths=None):
+        return self.processor.reverse_numerize_label_sequences(sequences, lengths)
 
 if __name__ == "__main__":
     print("Hello world")
diff --git a/kashgari/pre_processors/base_processor.py b/kashgari/pre_processors/base_processor.py
index 5fb3ce74..ffc69bae 100644
--- a/kashgari/pre_processors/base_processor.py
+++ b/kashgari/pre_processors/base_processor.py
@@ -142,6 +142,9 @@ def numerize_label_sequences(self,
                                  sequences: List[List[str]]) -> List[List[int]]:
         raise NotImplementedError
 
+    def reverse_numerize_label_sequences(self, sequence, **kwargs):
+        raise NotImplementedError
+
 
 if __name__ == "__main__":
     print("Hello world")
diff --git a/kashgari/pre_processors/classification_processor.py b/kashgari/pre_processors/classification_processor.py
index ad7762f4..6d22c1df 100644
--- a/kashgari/pre_processors/classification_processor.py
+++ b/kashgari/pre_processors/classification_processor.py
@@ -94,6 +94,9 @@ def numerize_label_sequences(self,
             result.append([self.label2idx[label] for label in sequence])
         return result
 
+    def reverse_numerize_label_sequences(self, sequence, **kwargs):
+        print(sequence)
+
 
 if __name__ == "__main__":
     from kashgari.corpus import SMP2018ECDTCorpus
diff --git a/kashgari/pre_processors/labeling_processor.py b/kashgari/pre_processors/labeling_processor.py
index d20b6ccf..acff7919 100644
--- a/kashgari/pre_processors/labeling_processor.py
+++ b/kashgari/pre_processors/labeling_processor.py
@@ -73,6 +73,20 @@ def numerize_label_sequences(self,
             result.append([self.label2idx[label] for label in seq])
         return result
 
+    def reverse_numerize_label_sequences(self,
+                                         sequences,
+                                         lengths=None):
+        result = []
+
+        for index, seq in enumerate(sequences):
+            labels = []
+            for idx in seq:
+                labels.append(self.idx2label[idx])
+            if lengths:
+                labels = labels[:lengths[index]]
+            result.append(labels)
+        return result
+
     def prepare_dicts_if_need(self,
                               corpus: List[List[str]],
                               labels: List[List[str]]):
@@ -110,6 +124,7 @@ def _process_sequence(self,
                           maxlens: Optional[Tuple[int, ...]] = None,
                           subset: Optional[List[int]] = None) -> Union[Tuple[np.ndarray, ...], List[np.ndarray]]:
         result = []
+        data = utils.wrap_as_tuple(data)
         for index, dataset in enumerate(data):
             if subset is not None:
                 target = utils.get_list_subset(dataset, subset)
diff --git a/kashgari/tasks/labeling/base_model.py b/kashgari/tasks/labeling/base_model.py
index 27c90655..108a541b 100644
--- a/kashgari/tasks/labeling/base_model.py
+++ b/kashgari/tasks/labeling/base_model.py
@@ -11,7 +11,11 @@
 from typing import Dict, Any, List, Optional, Union, Tuple
 
 import numpy as np
+import random
+import logging
 from tensorflow import keras
+from seqeval.metrics import classification_report
+from seqeval.metrics.sequence_labeling import get_entities
 
 import kashgari
 from kashgari import utils
@@ -190,9 +194,78 @@ def compile_model(self, **kwargs):
         self.tf_model.compile(**kwargs)
         self.tf_model.summary()
 
+    def predict(self,
+                x_data,
+                batch_size=None,
+                debug_info=False):
+        """
+        Generates output predictions for the input samples.
+
+        Computation is done in batches.
+
+        Args:
+            x_data: The input data, as a Numpy array (or list of Numpy arrays if the model has multiple inputs).
+            batch_size: Integer. If unspecified, it will default to 32.
+            debug_info: Bool, Should print out the logging info
+
+        Returns:
+            array(s) of predictions.
+        """
+        lengths = [len(sen) for sen in x_data]
+        tensor = self.embedding.process_x_dataset(x_data)
+        pred = self.tf_model.predict(tensor, batch_size=batch_size)
+        res = self.embedding.reverse_numerize_label_sequences(pred.argmax(-1),
+                                                              lengths)
+        if debug_info:
+            logging.info('input: {}'.format(tensor))
+            logging.info('output: {}'.format(pred))
+            logging.info('output argmax: {}'.format(pred.argmax(-1)))
+        return res
+
+    def evaluate(self,
+                 x_data,
+                 y_data,
+                 batch_size=None,
+                 digits=4,
+                 debug_info=False) -> str:
+        """
+        Build a text report showing the main classification metrics.
+
+        Args:
+            x_data: The input data, as list of token sequences.
+            y_data: The target labels, as list of label sequences.
+            batch_size: Integer. If unspecified, it will default to 32.
+            digits: Number of digits for formatting floating point values in the report.
+            debug_info: Bool, Should log a few random samples for debugging.
+
+        Returns:
+            Text report of the precision, recall and F1 score for each label type.
+        """
+        y_pred = self.predict(x_data, batch_size=batch_size)
+        y_true = [seq[:self.embedding.sequence_length[0]] for seq in y_data]
+
+        if debug_info:
+            for index in random.sample(list(range(len(x_data))), 5):
+                logging.debug('------ sample {} ------'.format(index))
+                logging.debug('x      : {}'.format(x_data[index]))
+                logging.debug('y_true : {}'.format(y_true[index]))
+                logging.debug('y_pred : {}'.format(y_pred[index]))
+        report = classification_report(y_true, y_pred, digits=digits)
+        print(report)
+        return report
+
     def build_model_arc(self):
         raise NotImplementedError
 
 
 if __name__ == "__main__":
+    from kashgari.tasks.labeling import CNNLSTMModel
+    from kashgari.corpus import ChineseDailyNerCorpus
+
+    train_x, train_y = ChineseDailyNerCorpus.load_data('valid')
+
+    model = CNNLSTMModel()
+    model.fit(train_x[:100], train_y[:100])
+    model.predict(train_x[:5])
+    model.evaluate(train_x[:20], train_y[:20])
     print("Hello world")
diff --git a/setup.py b/setup.py
index de57cad6..3c25f264 100644
--- a/setup.py
+++ b/setup.py
@@ -59,7 +59,7 @@ def find_version(*file_paths):
     # 'scikit-learn>=0.19.1',
     # 'numpy>=1.14.3',
     # 'download>=0.3.3',
-    # 'seqeval >=0.0.3',
+    'seqeval==0.0.10',
     # 'colorlog>=4.0.0',
     'gensim>=3.5.0',
     # # 'bz2file>=0.98',
diff --git a/tests/test_labeling.py b/tests/test_labeling.py
index 502b5c67..81624b24 100644
--- a/tests/test_labeling.py
+++ b/tests/test_labeling.py
@@ -8,7 +8,7 @@
 # time: 2019-05-20 19:03
 
 import unittest
-
+import os
 from tensorflow.python.keras import utils
 
 import kashgari
@@ -16,11 +16,9 @@
 from kashgari.embeddings import WordEmbedding
 from kashgari.tasks.labeling import CNNLSTMModel, BLSTMModel
 
-SAMPLE_WORD2VEC_URL = 'http://storage.eliyar.biz/embedding/word2vec/sample_w2v.txt'
-
 valid_x, valid_y = ChineseDailyNerCorpus.load_data('valid')
 
-sample_w2v_path = utils.get_file('sample_w2v.txt', SAMPLE_WORD2VEC_URL)
+sample_w2v_path = os.path.join(kashgari.utils.get_project_path(), 'tests/test-data/sample_w2v.txt')
 
 w2v_embedding = WordEmbedding(sample_w2v_path, task=kashgari.LABELING)
 w2v_embedding_variable_len = WordEmbedding(sample_w2v_path, task=kashgari.LABELING, sequence_length='variable')
@@ -34,7 +32,11 @@ def setUpClass(cls):
     def test_basic_use_build(self):
         model = self.model_class()
         model.fit(valid_x, valid_y, valid_x, valid_y, epochs=1)
-        assert True
+        res = model.predict(valid_x[:5])
+        for i in range(5):
+            assert len(res[i]) == min(model.embedding.sequence_length[0], len(valid_x[i]))
+
+        model.evaluate(valid_x[:100], valid_y[:100])
 
     def test_w2v_model(self):
         model = self.model_class(embedding=w2v_embedding)