Add voting-based label picking for bert_labels2tokens
Ubuntu committed Dec 17, 2018
1 parent 7f6e3a5 commit c3b794d
Showing 3 changed files with 138 additions and 36 deletions.
125 changes: 103 additions & 22 deletions examples/factrueval.ipynb
@@ -91,7 +91,7 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 5,
"metadata": {},
"outputs": [
{
@@ -175,7 +175,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
@@ -184,7 +184,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
@@ -193,7 +193,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 9,
"metadata": {},
"outputs": [
{
@@ -221,7 +221,7 @@
")"
]
},
"execution_count": 17,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
@@ -261,7 +261,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
@@ -270,7 +270,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
@@ -310,7 +310,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
@@ -319,7 +319,7 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": 13,
"metadata": {},
"outputs": [
{
@@ -350,7 +350,7 @@
},
{
"cell_type": "code",
"execution_count": 27,
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
@@ -359,7 +359,7 @@
},
{
"cell_type": "code",
"execution_count": 28,
"execution_count": 15,
"metadata": {},
"outputs": [
{
@@ -397,7 +397,7 @@
},
{
"cell_type": "code",
"execution_count": 29,
"execution_count": 17,
"metadata": {},
"outputs": [
{
@@ -440,6 +440,78 @@
"print(validate_step(learner.data.valid_dl, learner.model, learner.data.id2label, learner.sup_labels))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Tokens report"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [],
"source": [
"from sklearn_crfsuite.metrics import flat_classification_report"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {},
"outputs": [],
"source": [
"from modules.utils.utils import bert_labels2tokens"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {},
"outputs": [],
"source": [
"pred_tokens, pred_labels = bert_labels2tokens(dl, preds)\n",
"true_tokens, true_labels = bert_labels2tokens(dl, [x.labels for x in dl.dataset])"
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {},
"outputs": [],
"source": [
"assert pred_tokens == true_tokens\n",
"tokens_report = flat_classification_report(true_labels, pred_labels)"
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" I_LOC 0.93 0.90 0.92 230\n",
" I_O 0.99 0.99 0.99 7203\n",
" I_ORG 0.92 0.87 0.89 543\n",
" I_PER 0.98 0.98 0.98 321\n",
"\n",
" micro avg 0.98 0.98 0.98 8297\n",
" macro avg 0.96 0.94 0.95 8297\n",
"weighted avg 0.98 0.98 0.98 8297\n",
"\n"
]
}
],
"source": [
"print(tokens_report)"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -449,30 +521,39 @@
},
{
"cell_type": "code",
"execution_count": 30,
"execution_count": 79,
"metadata": {},
"outputs": [],
"source": [
"from modules.utils.utils import voting_choicer"
]
},
{
"cell_type": "code",
"execution_count": 80,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
-" PER 0.911 0.926 0.918 188\n",
-" LOC 0.853 0.849 0.851 192\n",
-" ORG 0.799 0.830 0.814 259\n",
+" ORG 0.809 0.834 0.821 259\n",
+" LOC 0.851 0.859 0.855 192\n",
+" PER 0.936 0.936 0.936 188\n",
"\n",
-" micro avg 0.848 0.864 0.856 639\n",
-" macro avg 0.855 0.868 0.861 639\n",
-"weighted avg 0.848 0.864 0.856 639\n",
+" micro avg 0.858 0.872 0.865 639\n",
+" macro avg 0.865 0.877 0.871 639\n",
+"weighted avg 0.859 0.872 0.865 639\n",
"\n"
]
}
],
"source": [
"from modules.utils.plot_metrics import get_bert_span_report\n",
-"clf_report = get_bert_span_report(dl, preds)\n",
-"print(clf_report)"
+"print(get_bert_span_report(dl, preds, fn=voting_choicer))"
]
},
{
8 changes: 4 additions & 4 deletions modules/utils/plot_metrics.py
@@ -1,7 +1,7 @@
 import numpy as np
 from collections import defaultdict
 from matplotlib import pyplot as plt
-from .utils import tokens2spans, bert_labels2tokens
+from .utils import tokens2spans, bert_labels2tokens, voting_choicer, first_choicer
 from sklearn_crfsuite.metrics import flat_classification_report


@@ -63,10 +63,10 @@ def get_mean_max_metric(history, metric_="f1", return_idx=False):
     return res


-def get_bert_span_report(dl, preds, ignore_labels=["O"]):
-    tokens, labels = bert_labels2tokens(dl, preds)
+def get_bert_span_report(dl, preds, ignore_labels=["O"], fn=first_choicer):
+    tokens, labels = bert_labels2tokens(dl, preds, fn)
     spans_pred = tokens2spans(tokens, labels)
-    tokens, labels = bert_labels2tokens(dl, [x.labels for x in dl.dataset])
+    tokens, labels = bert_labels2tokens(dl, [x.labels for x in dl.dataset], fn)
     spans_true = tokens2spans(tokens, labels)
     set_labels = set()
     for idx in range(len(spans_pred)):
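For reference, a minimal usage sketch of the updated get_bert_span_report (assuming the dl and preds objects from the notebook above). The default fn=first_choicer reproduces the previous behaviour of taking the first sub-token's label, while fn=voting_choicer switches to a majority vote over each word's sub-token labels:

    from modules.utils.plot_metrics import get_bert_span_report
    from modules.utils.utils import first_choicer, voting_choicer

    # previous behaviour: label of the first sub-token of each word (default)
    print(get_bert_span_report(dl, preds, fn=first_choicer))
    # new option: majority vote over all sub-token labels of each word
    print(get_bert_span_report(dl, preds, fn=voting_choicer))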
41 changes: 31 additions & 10 deletions modules/utils/utils.py
@@ -1,24 +1,45 @@
 import sys
 import __main__ as main
+from collections import Counter


 def ipython_info():
     return hasattr(main, '__file__')


-def bert_labels2tokens(dl, labels):
+def voting_choicer(tok_map, labels):
+    label = []
+    prev_idx = 0
+    for origin_idx in tok_map:
+
+        vote_labels = Counter(["I_" + l.split("_")[1] if l not in ["[SEP]", "[CLS]"] else "B_O" for l in labels[prev_idx:origin_idx]])
+        # vote_labels = Counter(c)
+        label.append(sorted(list(vote_labels), key=lambda x: vote_labels[x])[-1])
+        prev_idx = origin_idx
+        if origin_idx < 0:
+            break
+    return label
+
+
+def first_choicer(tok_map, labels):
+    label = []
+    prev_idx = 0
+    for origin_idx in tok_map:
+        if labels[prev_idx] in ["I_O", "[SEP]", "[CLS]"]:
+            labels[prev_idx] = "B_O"
+        label.append(labels[prev_idx])
+        prev_idx = origin_idx
+        if origin_idx < 0:
+            break
+    return label
+
+
+def bert_labels2tokens(dl, labels, fn=voting_choicer):
     res_tokens = []
     res_labels = []
     for f, l in zip(dl.dataset, labels):
-        label = []
-        prev_idx = 0
-        for origin_idx in f.tok_map:
-            if l[prev_idx] in ["I_O", "[SEP]", "[CLS]"]:
-                l[prev_idx] = "B_O"
-            label.append(l[prev_idx])
-            prev_idx = origin_idx
-            if origin_idx < 0:
-                break
+        label = fn(f.tok_map, l)
+
         res_tokens.append(f.tokens[1:-1])
         res_labels.append(label[1:])
     return res_tokens, res_labels
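To show the difference between the two choicers, here is a small hand-made example; the labels and tok_map values below are invented for illustration (not taken from FactRuEval), and tok_map is read the way the committed code walks it, as the index of the first sub-token of each original token, padded with -1:

    from modules.utils.utils import first_choicer, voting_choicer

    # BERT-level predictions for "[CLS] <one word split into three pieces> [SEP]" plus one pad label;
    # the three pieces of the word received conflicting labels
    labels  = ["[CLS]", "B_ORG", "I_LOC", "I_LOC", "[SEP]", "I_O"]
    tok_map = [1, 4, -1]

    print(first_choicer(tok_map, labels))   # ['B_O', 'B_ORG', 'B_O'] (keeps the first piece's label)
    print(voting_choicer(tok_map, labels))  # ['B_O', 'I_LOC', 'B_O'] (majority vote, B_* mapped to I_*)

bert_labels2tokens then strips the leading [CLS] entry (label[1:], tokens[1:-1]) and now defaults to fn=voting_choicer.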
