In [1]:
from bedoner.models import *

In [2]:
from bedoner.ner_labels.utils import make_biluo_labels
from bedoner.ner_labels.labels_irex import ALL_LABELS
from spacy.util import minibatch
from tqdm import tqdm

In [3]:
import json
import random

In [4]:
# Load the first `ndata` lines of the IREX JSONL file as training examples.
# Each line is parsed as one JSON value; later cells unpack every record as a (text, gold) pair.
data = []
ndata = 10000
# Decode explicitly as UTF-8: the corpus is Japanese text, and without `encoding`
# open() falls back to the platform locale, which is not UTF-8 everywhere.
with open("../data/gsk-ene-1.1-bccwj/irex/irex.jsonl", encoding="utf-8") as f:
    for i, line in enumerate(f):
        # Stop once `ndata` lines (indices 0 .. ndata-1) have been read.
        if i == ndata:
            break
        data.append(json.loads(line))

In [6]:
# Build the pipeline with `pytt_bert_ner`, passing the label set that
# `make_biluo_labels` derives from the IREX `ALL_LABELS`.
nlp = pytt_bert_ner(
    labels=make_biluo_labels(ALL_LABELS),
)

In [None]:
# Train for `niter` passes over `data`; after each pass, save the pipeline
# to disk and record/print the loss summed over that pass.
niter = 5
optim = nlp.resume_training(t_total=niter, enable_scheduler=False)
losses = []
for epoch in range(niter):
    # Shuffle the training examples in place before every pass.
    random.shuffle(data)
    epoch_loss = 0
    for batch in tqdm(minibatch(data, size=32)):
        # Every example is a (text, gold) pair; split the batch into two parallel tuples.
        texts, golds = zip(*batch)
        docs = list(map(nlp.make_doc, texts))
        nlp.update(docs, golds, optim)
        # `doc._.loss` is read right after the update — presumably populated by
        # `nlp.update`; TODO confirm against the pipeline implementation.
        epoch_loss += sum(doc._.loss for doc in docs)
    # One checkpoint directory per pass: irex0, irex1, ...
    nlp.to_disk(f"irex{epoch}")
    losses.append(epoch_loss)
    print(epoch_loss)

50it [50:47, 20.58s/it] 

# Evaluation: exact-match NER accuracy on held-out sentences

In [44]:
# Load the `ntestdata` lines that follow the first `ndata` lines of the same
# JSONL file, so the evaluation examples do not overlap the training slice.
# NOTE: this rebinds `data` (the training examples are dropped) and relies on
# `ndata` still being defined from the training-data cell.
data = []
ntestdata = 100
# Decode explicitly as UTF-8: the corpus is Japanese text, and without `encoding`
# open() falls back to the platform locale, which is not UTF-8 everywhere.
with open("../data/gsk-ene-1.1-bccwj/irex/irex.jsonl", encoding="utf-8") as f:
    for i, line in enumerate(f):
        # Skip the lines that were used for training.
        if i < ndata:
            continue
        # Stop after `ntestdata` evaluation lines have been collected.
        if i == ndata + ntestdata:
            break
        data.append(json.loads(line))

In [45]:
from spacy.gold import spans_from_biluo_tags, GoldParse
from itertools import zip_longest

In [46]:
def is_same(ents1, ents2):
    """Return True when both iterables yield pairwise-equal items and have equal length.

    The two inputs are walked in lockstep with ``zip_longest``; the shorter one
    is padded with ``None``, so a length difference surfaces as an item being
    compared against ``None``. Evaluation stops at the first pair for which
    ``!=`` is truthy.
    """
    mismatches = (left != right for left, right in zip_longest(ents1, ents2))
    return not any(mismatches)

In [57]:
# Exact-match evaluation: a sentence counts as correct only when the predicted
# entities equal the gold entities according to `is_same`.
count = 0   # sentences evaluated (those with at least one gold entity)
acc = 0     # sentences whose predicted entities matched the gold ones
fails = []  # (text, gold, predicted_ents) for every mismatch
for text, gold in tqdm(data):
    # Sentences without gold entities are left out of the evaluation entirely.
    if len(gold["entities"]) == 0:
        continue
    doc = nlp(text)
    # Turn the gold annotation into spans on this doc so it can be compared with `doc.ents`.
    gold_parse = GoldParse(doc, **gold)
    gold_ents = spans_from_biluo_tags(doc, gold_parse.ner)
    count += 1
    if is_same(doc.ents, gold_ents):
        acc += 1
        continue
    fails.append((text, gold, doc.ents))


  0%|          | 0/100 [00:00<?, ?it/s][A
  8%|▊         | 8/100 [00:00<00:01, 71.90it/s][A
 11%|█         | 11/100 [00:00<00:02, 35.12it/s][A
 18%|█▊        | 18/100 [00:00<00:02, 39.97it/s][A
 22%|██▏       | 22/100 [00:00<00:02, 37.15it/s][A
 28%|██▊       | 28/100 [00:00<00:01, 41.09it/s][A
 36%|███▌      | 36/100 [00:00<00:01, 47.05it/s][A
 41%|████      | 41/100 [00:00<00:01, 35.75it/s][A
 46%|████▌     | 46/100 [00:01<00:01, 33.78it/s][A
 53%|█████▎    | 53/100 [00:01<00:01, 36.45it/s][A
 57%|█████▋    | 57/100 [00:01<00:01, 35.50it/s][A
 61%|██████    | 61/100 [00:01<00:01, 30.85it/s][A
 72%|███████▏  | 72/100 [00:01<00:00, 36.38it/s][A
 81%|████████  | 81/100 [00:01<00:00, 43.88it/s][A
 88%|████████▊ | 88/100 [00:02<00:00, 43.29it/s][A
 94%|█████████▍| 94/100 [00:02<00:00, 36.48it/s][A
100%|██████████| 100/100 [00:02<00:00, 41.49it/s][A


In [58]:
# (sentences evaluated, exact matches, exact-match rate)
(count, acc, acc / count)

(35, 9, 0.2571428571428571)

In [59]:
# Each entry of `fails` is (text, gold, predicted_ents); print the predicted
# entities (surface text and label) of failure number 8.
for ent in fails[8][2]:
    print(ent.text, ent.label_)

In [60]:
# List every failed sentence with its index: (text, gold annotation, predicted entities).
for idx, failure in enumerate(fails):
    print(idx, failure)

0 ('２交代制の-結構きつい仕事を-文句ひとつ言わずに頑張ってくれている主人に対して申し訳なさで一杯です。', {'entities': [[0, 4, 'ARTIFACT']]}, ())
1 ('ジャパンネットバンクの「郵貯Ｗｅｂ送金」のことですよね？', {'entities': [[0, 10, 'ORGANIZATION'], [12, 19, 'ARTIFACT']]}, (郵貯Ｗｅｂ送金,))
2 ('ＶＩＳＡとマスターカードの違いってなんですか？？？？？', {'entities': [[0, 4, 'ARTIFACT'], [5, 12, 'ARTIFACT']]}, (ＶＩＳＡ,))
3 ('ＶＩＳＡ社とマスターカード社。', {'entities': [[0, 5, 'ORGANIZATION'], [6, 14, 'ORGANIZATION']]}, (ＶＩＳＡ,))
4 ('ＧＷ中も情報収集を欠かさずやっているのですか？', {'entities': [[0, 3, 'DATE']]}, (ＧＷ,))
5 ('例えば１＄＝１００円だったとして、これが１＄＝１１０円になったとします。', {'entities': [[3, 5, 'MONEY'], [6, 10, 'MONEY'], [20, 22, 'MONEY'], [23, 27, 'MONEY']]}, (１００円, １１０円))
6 ('僕が生まれる前にあったバブル期と比べればすごく円安です。', {'entities': [[11, 15, 'DATE']]}, ())
7 ('「ＣＨＯＫＫＡ」の平成電電関連の上場銘柄はありますか？', {'entities': [[1, 7, 'ARTIFACT'], [9, 13, 'ORGANIZATION']]}, (ＣＨＯＫＫＡ,))
8 ('ドリテク・・・・・・・・・・・・・・・・・・・・・・・・・・・・・・。', {'entities': [[0, 4, 'ORGANIZATION']]}, ())
9 ('日経ビジネス読まれてる方。', {'entities': [[0, 6, 'ARTIFACT']]}, (日経,))
10 ('来年の春までは、まだまだイケイケですかね〜', {'entities': [[3, 4, 'DAT