Skip to content

Commit

Permalink
✨ Adding predict and evaluate function.
Browse files Browse the repository at this point in the history
  • Loading branch information
BrikerMan committed May 21, 2019
1 parent 44ad63a commit 7e942bf
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 6 deletions.
4 changes: 4 additions & 0 deletions kashgari/embeddings/base_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,10 @@ def process_y_dataset(self,
"""
return self.processor.process_y_dataset(data, self.sequence_length, subset)

def reverse_numerize_label_sequences(self,
                                     sequences,
                                     lengths=None):
    """Convert numeric label sequences back to label strings.

    Thin delegate: forwards to the underlying processor, which owns the
    idx -> label vocabulary.

    Args:
        sequences: sequences of label indices to convert back.
        lengths: optional per-sample lengths used to trim padding.

    Returns:
        Whatever the processor's reverse conversion returns
        (label-string sequences).
    """
    processor = self.processor
    return processor.reverse_numerize_label_sequences(sequences, lengths)

if __name__ == "__main__":
print("Hello world")
3 changes: 3 additions & 0 deletions kashgari/pre_processors/base_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,9 @@ def numerize_label_sequences(self,
sequences: List[List[str]]) -> List[List[int]]:
raise NotImplementedError

def reverse_numerize_label_sequences(self, sequence, **kwargs):
    """Map numeric label sequences back to their original string labels.

    Abstract hook — concrete processors must override this with the inverse
    of ``numerize_label_sequences``.

    Args:
        sequence: sequences of label indices to convert back.
        **kwargs: subclass-specific options (e.g. per-sample lengths).

    Raises:
        NotImplementedError: always, in this base class.
    """
    # Bug fix: the original raised `NotImplemented`, which is a comparison
    # sentinel value, not an exception class — raising it produces a
    # confusing TypeError instead of the intended NotImplementedError.
    raise NotImplementedError


if __name__ == "__main__":
print("Hello world")
3 changes: 3 additions & 0 deletions kashgari/pre_processors/classification_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ def numerize_label_sequences(self,
result.append([self.label2idx[label] for label in sequence])
return result

def reverse_numerize_label_sequences(self, sequence, lengths=None, **kwargs):
    """Map numeric label sequences back to their original string labels.

    Inverse of ``numerize_label_sequences`` above: each inner index is looked
    up in ``self.idx2label``.

    Args:
        sequence: sequences of label indices, as produced by
            ``numerize_label_sequences``.
        lengths: accepted (and ignored) so the embedding layer can pass it
            positionally, matching the labeling processor's signature;
            classification labels are not padded, so nothing is trimmed.
        **kwargs: unused, kept for interface compatibility.

    Returns:
        The label-string sequences.
    """
    # Bug fix: the original was a debug stub that only `print`ed the input
    # and returned None; it also could not accept the positional `lengths`
    # argument the embedding delegate passes.
    # NOTE(review): assumes self.idx2label is the inverse of self.label2idx
    # used in numerize_label_sequences — confirm against the base processor.
    return [[self.idx2label[idx] for idx in seq] for seq in sequence]


if __name__ == "__main__":
from kashgari.corpus import SMP2018ECDTCorpus
Expand Down
15 changes: 15 additions & 0 deletions kashgari/pre_processors/labeling_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,20 @@ def numerize_label_sequences(self,
result.append([self.label2idx[label] for label in seq])
return result

def reverse_numerize_label_sequences(self,
                                     sequences,
                                     lengths=None):
    """Convert padded index sequences back to label strings.

    Each index is looked up in ``self.idx2label``; when ``lengths`` is
    provided, every decoded sequence is trimmed to its sample's original
    (pre-padding) length.

    Args:
        sequences: sequences of label indices.
        lengths: optional per-sample token counts used for trimming.

    Returns:
        List of label-string sequences.
    """
    restored = []
    for position, index_seq in enumerate(sequences):
        decoded = [self.idx2label[token_idx] for token_idx in index_seq]
        if lengths:
            # drop padding positions beyond the sample's true length
            decoded = decoded[:lengths[position]]
        restored.append(decoded)
    return restored

def prepare_dicts_if_need(self,
corpus: List[List[str]],
labels: List[List[str]]):
Expand Down Expand Up @@ -110,6 +124,7 @@ def _process_sequence(self,
maxlens: Optional[Tuple[int, ...]] = None,
subset: Optional[List[int]] = None) -> Union[Tuple[np.ndarray, ...], List[np.ndarray]]:
result = []
data = utils.wrap_as_tuple(data)
for index, dataset in enumerate(data):
if subset is not None:
target = utils.get_list_subset(dataset, subset)
Expand Down
73 changes: 73 additions & 0 deletions kashgari/tasks/labeling/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
from typing import Dict, Any, List, Optional, Union, Tuple

import numpy as np
import random
import logging
from tensorflow import keras
from seqeval.metrics import classification_report
from seqeval.metrics.sequence_labeling import get_entities

import kashgari
from kashgari import utils
Expand Down Expand Up @@ -190,9 +194,78 @@ def compile_model(self, **kwargs):
self.tf_model.compile(**kwargs)
self.tf_model.summary()

def predict(self,
            x_data,
            batch_size=None,
            debug_info=False):
    """
    Generates output predictions for the input samples.
    Computation is done in batches.
    Args:
        x_data: The input data, as a Numpy array (or list of Numpy arrays if the model has multiple inputs).
        batch_size: Integer. If unspecified, it will default to 32.
        debug_info: Bool, Should print out the logging info
    Returns:
        array(s) of predictions.
    """
    # remember each sample's true token count so padding can be trimmed later
    sample_lengths = [len(sample) for sample in x_data]
    x_tensor = self.embedding.process_x_dataset(x_data)
    raw_output = self.tf_model.predict(x_tensor, batch_size=batch_size)
    # pick the highest-scoring class per token, then map indices back to
    # label strings, trimmed to the original sample lengths
    predicted_labels = self.embedding.reverse_numerize_label_sequences(
        raw_output.argmax(-1), sample_lengths)
    if debug_info:
        logging.info('input: {}'.format(x_tensor))
        logging.info('output: {}'.format(raw_output))
        logging.info('output argmax: {}'.format(raw_output.argmax(-1)))
    return predicted_labels

def evaluate(self,
             x_data,
             y_data,
             batch_size=None,
             digits=4,
             debug_info=False) -> str:
    """
    Build a text report showing the main classification metrics.
    Args:
        x_data: input samples to run prediction on.
        y_data: ground-truth label sequences, parallel to ``x_data``.
        batch_size: Integer. If unspecified, it will default to 32.
        digits: number of digits for formatting floating point values in the report.
        debug_info: Bool, log 5 random samples for inspection.
    Returns:
        the seqeval classification report as a string (also printed to stdout).
    """
    y_pred = self.predict(x_data, batch_size=batch_size)
    # trim gold labels to the model's max sequence length so they align
    # with the (possibly truncated) predictions
    y_true = [seq[:self.embedding.sequence_length[0]] for seq in y_data]

    if debug_info:
        for index in random.sample(list(range(len(x_data))), 5):
            logging.debug('------ sample {} ------'.format(index))
            logging.debug('x      : {}'.format(x_data[index]))
            logging.debug('y_true : {}'.format(y_true[index]))
            logging.debug('y_pred : {}'.format(y_pred[index]))
    # Bug fix: build the report once (original called classification_report
    # twice) and correct the return annotation — the original declared
    # Tuple[float, float, Dict] but actually returned the report string.
    report = classification_report(y_true, y_pred, digits=digits)
    print(report)
    return report

def build_model_arc(self):
    """Build the task-specific tf.keras model architecture.

    Abstract hook: concrete labeling models (e.g. CNNLSTMModel) must
    override this to assemble ``self.tf_model``.
    """
    raise NotImplementedError


if __name__ == "__main__":
    # Manual smoke test: train briefly, then exercise the new
    # predict() and evaluate() methods end-to-end.
    from kashgari.tasks.labeling import CNNLSTMModel
    from kashgari.corpus import ChineseDailyNerCorpus

    # 'valid' split is small, which keeps this demo run fast.
    train_x, train_y = ChineseDailyNerCorpus.load_data('valid')

    model = CNNLSTMModel()
    # tiny subset: this only checks the code paths, not model quality
    model.fit(train_x[:100], train_y[:100])
    model.predict(train_x[:5])
    model.evaluate(train_x[:20], train_y[:20])
    print("Hello world")
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def find_version(*file_paths):
# 'scikit-learn>=0.19.1',
# 'numpy>=1.14.3',
# 'download>=0.3.3',
# 'seqeval >=0.0.3',
'seqeval==0.0.10',
# 'colorlog>=4.0.0',
'gensim>=3.5.0',
# # 'bz2file>=0.98',
Expand Down
12 changes: 7 additions & 5 deletions tests/test_labeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,17 @@
# time: 2019-05-20 19:03

import unittest

import os
from tensorflow.python.keras import utils

import kashgari
from kashgari.corpus import ChineseDailyNerCorpus
from kashgari.embeddings import WordEmbedding
from kashgari.tasks.labeling import CNNLSTMModel, BLSTMModel

SAMPLE_WORD2VEC_URL = 'http://storage.eliyar.biz/embedding/word2vec/sample_w2v.txt'

valid_x, valid_y = ChineseDailyNerCorpus.load_data('valid')

sample_w2v_path = utils.get_file('sample_w2v.txt', SAMPLE_WORD2VEC_URL)
sample_w2v_path = os.path.join(kashgari.utils.get_project_path(), 'tests/test-data/sample_w2v.txt')

w2v_embedding = WordEmbedding(sample_w2v_path, task=kashgari.LABELING)
w2v_embedding_variable_len = WordEmbedding(sample_w2v_path, task=kashgari.LABELING, sequence_length='variable')
Expand All @@ -34,7 +32,11 @@ def setUpClass(cls):
def test_basic_use_build(self):
    # Train one epoch, then smoke-test predict() and evaluate().
    model = self.model_class()
    model.fit(valid_x, valid_y, valid_x, valid_y, epochs=1)
    res = model.predict(valid_x[:5])
    for i in range(5):
        # each prediction is trimmed to the sample's true token count,
        # capped by the embedding's max sequence length
        assert len(res[i]) == min(model.embedding.sequence_length[0], len(valid_x[i]))

    model.evaluate(valid_x[:100], valid_y[:100])

def test_w2v_model(self):
model = self.model_class(embedding=w2v_embedding)
Expand Down

0 comments on commit 7e942bf

Please sign in to comment.