Commit 2b26333

🚧 Work in progress.
BrikerMan committed May 8, 2020
1 parent a4a7e3e commit 2b26333
Showing 6 changed files with 99 additions and 94 deletions.
13 changes: 7 additions & 6 deletions kashgari/callbacks/eval_callBack.py
@@ -10,17 +10,18 @@
from tensorflow import keras

from kashgari.tasks.abs_task_model import ABCTaskModel
from typing import List, Any, Dict


class EvalCallBack(keras.callbacks.Callback):

def __init__(self,
*,
task_model: ABCTaskModel,
valid_x,
valid_y,
step=5,
batch_size=256):
valid_x: List[Any],
valid_y: List[Any],
step: int = 5,
batch_size: int = 256) -> None:
"""
Evaluation callback: computes precision, recall and F1 on the validation set every ``step`` epochs
Args:
@@ -36,9 +37,9 @@ def __init__(self,
self.valid_y = valid_y
self.step = step
self.batch_size = batch_size
self.logs = []
self.logs: List[Dict] = []

def on_epoch_end(self, epoch, logs=None):
def on_epoch_end(self, epoch: int, logs: Any = None) -> None:
if (epoch + 1) % self.step == 0:
report = self.task_model.evaluate(self.valid_x,
self.valid_y,
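
A minimal usage sketch for the now fully-typed callback (toy data; BiLSTM_Model and the callbacks argument to fit are assumptions based on the released Kashgari 2.x API, not part of this commit):

    from kashgari.callbacks.eval_callBack import EvalCallBack
    from kashgari.tasks.labeling import BiLSTM_Model

    train_x = [['I', 'live', 'in', 'Berlin'], ['Bob', 'visited', 'Paris']]
    train_y = [['O', 'O', 'O', 'B-LOC'], ['B-PER', 'O', 'B-LOC']]
    valid_x, valid_y = train_x, train_y  # toy reuse, only to keep the sketch self-contained

    model = BiLSTM_Model()
    # Report precision/recall/F1 on the validation set every 5 epochs.
    eval_callback = EvalCallBack(task_model=model,
                                 valid_x=valid_x,
                                 valid_y=valid_y,
                                 step=5,
                                 batch_size=256)
    model.fit(train_x, train_y, callbacks=[eval_callback], epochs=10)
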
3 changes: 0 additions & 3 deletions kashgari/corpus.py
@@ -283,6 +283,3 @@ def load_data(self,

for i in y[:20]:
print(i)

import numpy as np
np.sum()
@@ -7,6 +7,5 @@
# file: __init__.py
# time: 10:44 下午


if __name__ == "__main__":
pass
@@ -20,24 +20,27 @@
from collections import defaultdict

import numpy as np
from typing import List, Dict, Tuple, Any


def get_entities(seq, suffix=False):
def bulk_get_entities(seq_list: List[List[str]], *, suffix: bool = False) -> List[Tuple[str, int, int]]:
seq = [item for sublist in seq_list for item in sublist + ['O']]
return get_entities(seq, suffix=suffix)


def get_entities(seq: List[str], *, suffix: bool = False) -> List[Tuple[str, int, int]]:
"""Gets entities from sequence.
Args:
seq (list): sequence of labels.
seq: sequence of labels.
suffix: if True, tags are type-first (e.g. ``PER-B``) instead of prefix-first (e.g. ``B-PER``).
Returns:
list: list of (chunk_type, chunk_start, chunk_end).
Example:
>>> from seqeval.metrics.sequence_labeling import get_entities
>>> from kashgari.metrics.sequence_labeling import get_entities
>>> seq = ['B-PER', 'I-PER', 'O', 'B-LOC']
>>> get_entities(seq)
[('PER', 0, 1), ('LOC', 3, 3)]
"""
# for nested list
if any(isinstance(s, list) for s in seq):
seq = [item for sublist in seq for item in sublist + ['O']]

prev_tag = 'O'
prev_type = ''
begin_offset = 0
@@ -60,7 +63,7 @@ def get_entities(seq, suffix=False):
return chunks
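
The nested-list handling moves out of get_entities into the new bulk_get_entities, which flattens the sentences itself, inserting an 'O' after each one so that no chunk can span a sentence boundary, and then delegates to get_entities. A quick sketch of which to call (toy labels):

    from kashgari.metrics.sequence_labeling import bulk_get_entities, get_entities

    nested = [['B-PER', 'I-PER'], ['B-LOC']]
    # Per-sentence labels: flatten via bulk_get_entities.
    print(bulk_get_entities(nested))                       # [('PER', 0, 1), ('LOC', 3, 3)]
    # Already-flat labels: call get_entities directly.
    print(get_entities(['B-PER', 'I-PER', 'O', 'B-LOC']))  # [('PER', 0, 1), ('LOC', 3, 3)]
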


def end_of_chunk(prev_tag, tag, prev_type, type_):
def end_of_chunk(prev_tag: str, tag: str, prev_type: str, type_: str) -> bool:
"""Checks if a chunk ended between the previous and current word.
Args:
prev_tag: previous chunk tag.
@@ -72,23 +75,31 @@ def end_of_chunk(prev_tag, tag, prev_type, type_):
"""
chunk_end = False

if prev_tag == 'E': chunk_end = True
if prev_tag == 'S': chunk_end = True
if prev_tag == 'E':
chunk_end = True
if prev_tag == 'S':
chunk_end = True

if prev_tag == 'B' and tag == 'B': chunk_end = True
if prev_tag == 'B' and tag == 'S': chunk_end = True
if prev_tag == 'B' and tag == 'O': chunk_end = True
if prev_tag == 'I' and tag == 'B': chunk_end = True
if prev_tag == 'I' and tag == 'S': chunk_end = True
if prev_tag == 'I' and tag == 'O': chunk_end = True
if prev_tag == 'B' and tag == 'B':
chunk_end = True
if prev_tag == 'B' and tag == 'S':
chunk_end = True
if prev_tag == 'B' and tag == 'O':
chunk_end = True
if prev_tag == 'I' and tag == 'B':
chunk_end = True
if prev_tag == 'I' and tag == 'S':
chunk_end = True
if prev_tag == 'I' and tag == 'O':
chunk_end = True

if prev_tag != 'O' and prev_tag != '.' and prev_type != type_:
chunk_end = True

return chunk_end


def start_of_chunk(prev_tag, tag, prev_type, type_):
def start_of_chunk(prev_tag: str, tag: str, prev_type: str, type_: str) -> bool:
"""Checks if a chunk started between the previous and current word.
Args:
prev_tag: previous chunk tag.
@@ -100,8 +111,10 @@ def start_of_chunk(prev_tag, tag, prev_type, type_):
"""
chunk_start = False

if tag == 'B': chunk_start = True
if tag == 'S': chunk_start = True
if tag == 'B':
chunk_start = True
if tag == 'S':
chunk_start = True

if prev_tag == 'E' and tag == 'E': chunk_start = True
if prev_tag == 'E' and tag == 'I': chunk_start = True
@@ -116,27 +129,30 @@ def start_of_chunk(prev_tag, tag, prev_type, type_):
return chunk_start
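
end_of_chunk and start_of_chunk encode the IOBES transition rules that get_entities consults for each adjacent tag pair: a chunk closes after E or S, when B or I is followed by B, S, or O, or when the entity type changes; a chunk opens on B or S, on the mirrored transitions, or on a type change. A small sketch (the position prefix and the entity type are passed separately):

    from kashgari.metrics.sequence_labeling import end_of_chunk, start_of_chunk

    # 'B-PER' followed by 'B-LOC': the PER chunk ends and a LOC chunk starts.
    print(end_of_chunk('B', 'B', 'PER', 'LOC'))    # True
    print(start_of_chunk('B', 'B', 'PER', 'LOC'))  # True

    # 'B-PER' followed by 'I-PER': the chunk simply continues.
    print(end_of_chunk('B', 'I', 'PER', 'PER'))    # False
    print(start_of_chunk('B', 'I', 'PER', 'PER'))  # False
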


def f1_score(y_true, y_pred, average='micro', suffix=False):
def f1_score(y_true: List[List[str]],
y_pred: List[List[str]],
suffix: bool = False) -> float:
"""Compute the F1 score.
The F1 score can be interpreted as a weighted average of the precision and
recall, where an F1 score reaches its best value at 1 and worst score at 0.
The relative contribution of precision and recall to the F1 score are
equal. The formula for the F1 score is::
F1 = 2 * (precision * recall) / (precision + recall)
Args:
y_true : 2d array. Ground truth (correct) target values.
y_pred : 2d array. Estimated targets as returned by a tagger.
y_true: 2d array. Ground truth (correct) target values.
y_pred: 2d array. Estimated targets as returned by a tagger.
suffix: if True, tags are type-first (e.g. ``PER-B``) instead of prefix-first (e.g. ``B-PER``).
Returns:
score : float.
Example:
>>> from seqeval.metrics import f1_score
>>> from kashgari.metrics.sequence_labeling import f1_score
>>> y_true = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
>>> y_pred = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
>>> f1_score(y_true, y_pred)
0.5
"""
true_entities = set(get_entities(y_true, suffix))
pred_entities = set(get_entities(y_pred, suffix))
true_entities = set(bulk_get_entities(y_true, suffix=suffix))
pred_entities = set(bulk_get_entities(y_pred, suffix=suffix))

nb_correct = len(true_entities & pred_entities)
nb_pred = len(pred_entities)
@@ -149,7 +165,7 @@ def f1_score(y_true, y_pred, average='micro', suffix=False):
return score
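
The score is micro-averaged over exact entity matches. Tracing the docstring example by hand (entity tuples as produced by bulk_get_entities):

    true_entities = {('MISC', 3, 5), ('PER', 8, 9)}  # gold entities after flattening
    pred_entities = {('MISC', 2, 5), ('PER', 8, 9)}  # predicted MISC starts one token early
    nb_correct = len(true_entities & pred_entities)  # 1: only PER matches exactly
    p = nb_correct / len(pred_entities)              # 0.5
    r = nb_correct / len(true_entities)              # 0.5
    print(2 * p * r / (p + r))                       # 0.5
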


def accuracy_score(y_true, y_pred):
def accuracy_score(y_true: List[List[str]], y_pred: List[List[str]]) -> float:
"""Accuracy classification score.
In multilabel classification, this function computes subset accuracy:
the set of labels predicted for a sample must *exactly* match the
@@ -160,25 +176,26 @@ def accuracy_score(y_true, y_pred):
Returns:
score : float.
Example:
>>> from seqeval.metrics import accuracy_score
>>> from kashgari.metrics.sequence_labeling import accuracy_score
>>> y_true = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
>>> y_pred = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
>>> accuracy_score(y_true, y_pred)
0.8
"""
if any(isinstance(s, list) for s in y_true):
y_true = [item for sublist in y_true for item in sublist]
y_pred = [item for sublist in y_pred for item in sublist]
y_true_all = [item for sublist in y_true for item in sublist]
y_pred_all = [item for sublist in y_pred for item in sublist]

nb_correct = sum(y_t == y_p for y_t, y_p in zip(y_true, y_pred))
nb_correct = sum(y_t == y_p for y_t, y_p in zip(y_true_all, y_pred_all))
nb_true = len(y_true_all)  # token count; len(y_true) would count sequences, not tokens

score = nb_correct / nb_true

return score
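
Accuracy here is token-level: both label sets are flattened and compared position by position. Checking the docstring example by hand (with the token-count denominator used above):

    y_true_all = ['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O', 'B-PER', 'I-PER', 'O']
    y_pred_all = ['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O', 'B-PER', 'I-PER', 'O']
    # Only positions 2 and 3 disagree, so 8 of 10 tokens match.
    print(sum(t == p for t, p in zip(y_true_all, y_pred_all)) / len(y_true_all))  # 0.8
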


def precision_score(y_true, y_pred, average='micro', suffix=False):
def precision_score(y_true: List[List[str]],
y_pred: List[List[str]],
suffix: bool = False) -> float:
"""Compute the precision.
The precision is the ratio ``tp / (tp + fp)`` where ``tp`` is the number of
true positives and ``fp`` the number of false positives. The precision is
@@ -190,14 +207,14 @@ def precision_score(y_true, y_pred, average='micro', suffix=False):
Returns:
score : float.
Example:
>>> from seqeval.metrics import precision_score
>>> from kashgari.metrics.sequence_labeling import precision_score
>>> y_true = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
>>> y_pred = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
>>> precision_score(y_true, y_pred)
0.5
"""
true_entities = set(get_entities(y_true, suffix))
pred_entities = set(get_entities(y_pred, suffix))
true_entities = set(bulk_get_entities(y_true, suffix=suffix))
pred_entities = set(bulk_get_entities(y_pred, suffix=suffix))

nb_correct = len(true_entities & pred_entities)
nb_pred = len(pred_entities)
@@ -207,7 +224,9 @@ def precision_score(y_true, y_pred, average='micro', suffix=False):
return score
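
The suffix flag is for corpora whose tags are type-first. A sketch, assuming the suffix branch inside get_entities mirrors seqeval's (position marker last, e.g. 'PER-B'):

    from kashgari.metrics.sequence_labeling import precision_score

    y_true = [['PER-B', 'PER-I', 'O']]
    y_pred = [['PER-B', 'PER-I', 'O']]
    print(precision_score(y_true, y_pred, suffix=True))  # 1.0
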


def recall_score(y_true, y_pred, average='micro', suffix=False):
def recall_score(y_true: List[List[str]],
y_pred: List[List[str]],
suffix: bool = False) -> float:
"""Compute the recall.
The recall is the ratio ``tp / (tp + fn)`` where ``tp`` is the number of
true positives and ``fn`` the number of false negatives. The recall is
@@ -219,14 +238,14 @@ def recall_score(y_true, y_pred, average='micro', suffix=False):
Returns:
score : float.
Example:
>>> from seqeval.metrics import recall_score
>>> from kashgari.metrics.sequence_labeling import recall_score
>>> y_true = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
>>> y_pred = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
>>> recall_score(y_true, y_pred)
0.5
"""
true_entities = set(get_entities(y_true, suffix))
pred_entities = set(get_entities(y_pred, suffix))
true_entities = set(bulk_get_entities(y_true, suffix=suffix))
pred_entities = set(bulk_get_entities(y_pred, suffix=suffix))

nb_correct = len(true_entities & pred_entities)
nb_true = len(true_entities)
@@ -236,7 +255,8 @@ def recall_score(y_true, y_pred, average='micro', suffix=False):
return score
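
Precision and recall diverge as soon as the tagger under- or over-predicts entities. A quick sketch of the asymmetry (toy labels):

    from kashgari.metrics.sequence_labeling import precision_score, recall_score

    y_true = [['B-PER', 'I-PER', 'O', 'B-LOC']]
    y_pred = [['B-PER', 'I-PER', 'O', 'O']]  # the LOC entity is missed entirely
    print(precision_score(y_true, y_pred))   # 1.0: the one predicted entity is correct
    print(recall_score(y_true, y_pred))      # 0.5: one of two gold entities found
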


def performance_measure(y_true, y_pred):
def performance_measure(y_true: List[List[str]],
y_pred: List[List[str]]) -> Dict[str, int]:
"""
Compute the performance metrics: TP, FP, FN, TN
Args:
@@ -245,32 +265,32 @@ def performance_measure(y_true, y_pred):
Returns:
performance_dict : dict
Example:
>>> from seqeval.metrics import performance_measure
>>> from kashgari.metrics.sequence_labeling import performance_measure
>>> y_true = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'O', 'B-ORG'], ['B-PER', 'I-PER', 'O']]
>>> y_pred = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O', 'O'], ['B-PER', 'I-PER', 'O']]
>>> performance_measure(y_true, y_pred)
{'TP': 3, 'FP': 3, 'FN': 1, 'TN': 4}
"""
performace_dict = dict()
if any(isinstance(s, list) for s in y_true):
y_true = [item for sublist in y_true for item in sublist]
y_pred = [item for sublist in y_pred for item in sublist]
performace_dict['TP'] = sum(y_t == y_p for y_t, y_p in zip(y_true, y_pred)
if ((y_t != 'O') or (y_p != 'O')))
performace_dict['FP'] = sum(y_t != y_p for y_t, y_p in zip(y_true, y_pred))
performace_dict['FN'] = sum(((y_t != 'O') and (y_p == 'O'))
for y_t, y_p in zip(y_true, y_pred))
performace_dict['TN'] = sum((y_t == y_p == 'O')
for y_t, y_p in zip(y_true, y_pred))

return performace_dict


def sequence_labeling_report(y_true,
y_pred,
digits=2,
suffix=False,
verbose=1):
performance_dict = dict()
y_true_all = [item for sublist in y_true for item in sublist]
y_pred_all = [item for sublist in y_pred for item in sublist]

performance_dict['TP'] = sum(y_t == y_p for y_t, y_p in zip(y_true_all, y_pred_all)
if ((y_t != 'O') or (y_p != 'O')))
performance_dict['FP'] = sum(y_t != y_p for y_t, y_p in zip(y_true_all, y_pred_all))
performance_dict['FN'] = sum(((y_t != 'O') and (y_p == 'O'))
for y_t, y_p in zip(y_true_all, y_pred_all))
performance_dict['TN'] = sum((y_t == y_p == 'O')
for y_t, y_p in zip(y_true_all, y_pred_all))

return performance_dict
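
Note the counting conventions: TP is agreeing non-O tokens, FP is any disagreement, FN is gold non-O tokens predicted as O, and TN is agreeing O tokens, so one token can count toward both FP and FN. Tracing the docstring example (values only):

    # Flattened (y_true, y_pred) pairs, 10 tokens in total:
    #   TP = 3  (I-MISC, B-PER, I-PER agree and are non-O)
    #   FP = 3  (positions 2, 3 and 6 disagree)
    #   FN = 1  (position 6: gold B-ORG predicted as O)
    #   TN = 4  (positions 0, 1, 5 and 9 agree on O)
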


def sequence_labeling_report(y_true: List[List[str]],
y_pred: List[List[str]],
digits: int = 2,
suffix: bool = False,
verbose: int = 1) -> Dict[str, Any]:
"""Build a text report showing the main classification metrics.
Args:
@@ -283,7 +303,7 @@ def sequence_labeling_report(y_true,
report: precision, recall and F1 for each class, printed as a table when ``verbose`` is set and also returned as a dict.
Examples:
>>> from kashgari.toolkits.metrics.sequence_labeling import sequence_labeling_report
>>> from kashgari.metrics.sequence_labeling import sequence_labeling_report
>>> y_true = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
>>> y_pred = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
>>> report = sequence_labeling_report(y_true, y_pred)
Expand All @@ -296,8 +316,8 @@ def sequence_labeling_report(y_true,
macro avg 0.50 0.50 0.50 2
<BLANKLINE>
"""
true_entities = set(get_entities(y_true, suffix))
pred_entities = set(get_entities(y_pred, suffix))
true_entities = set(bulk_get_entities(y_true, suffix=suffix))
pred_entities = set(bulk_get_entities(y_pred, suffix=suffix))

name_width = 0
d1 = defaultdict(set)
@@ -318,16 +338,16 @@ def sequence_labeling_report(y_true,

row_fmt = u'{:>{width}s} ' + u' {:>9.{digits}f}' * 3 + u' {:>9}\n'

report_dic = {
report_dic: Dict[str, Any] = {
'detail': {}
}

ps, rs, f1s, s = [], [], [], []
for type_name, true_entities in d1.items():
pred_entities = d2[type_name]
nb_correct = len(true_entities & pred_entities)
nb_pred = len(pred_entities)
nb_true = len(true_entities)
for type_name, t_true_entities in d1.items():
t_pred_entities = d2[type_name]
nb_correct = len(t_true_entities & t_pred_entities)
nb_pred = len(t_pred_entities)
nb_true = len(t_true_entities)

p = nb_correct / nb_pred if nb_pred > 0 else 0
r = nb_correct / nb_true if nb_true > 0 else 0
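
Besides printing the table, the report is returned as a structured dict. A usage sketch (assuming only the return shape visible above: a 'detail' mapping per entity type plus aggregate fields):

    from kashgari.metrics.sequence_labeling import sequence_labeling_report

    y_true = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
    y_pred = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
    report = sequence_labeling_report(y_true, y_pred)  # verbose=1 prints the table
    print(report['detail'])                           # per-type precision/recall/f1/support
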
4 changes: 2 additions & 2 deletions kashgari/tasks/labeling/abc_model.py
@@ -19,8 +19,8 @@
from kashgari.generators import CorpusGenerator
from kashgari.processors import SequenceProcessor
from kashgari.tasks.abs_task_model import ABCTaskModel
from kashgari.toolkits.metrics.sequence_labeling import get_entities
from kashgari.toolkits.metrics.sequence_labeling import sequence_labeling_report
from kashgari.metrics.sequence_labeling import get_entities
from kashgari.metrics.sequence_labeling import sequence_labeling_report
from kashgari.types import TextSamplesVar

if TYPE_CHECKING:
