In [1]:
from bedoner.models import *

In [2]:
from bedoner.ner_labels.utils import make_biluo_labels
from bedoner.ner_labels.labels_irex import ALL_LABELS
from spacy.util import minibatch
from tqdm import tqdm
from pathlib import Path

In [3]:
import json
import random

In [13]:
data=[]
with (Path.home() / "datasets/ner/gsk-ene-1.1-bccwj/irex/irex-positive.jsonl").open() as f:
    for i,line in enumerate(f):
        data.append(json.loads(line))

In [14]:
ntrain, neval=10000,100
random.shuffle(data)
train_data=data[:ntrain]
val_data=data[-neval:]

In [39]:
nlp=bert_ner(labels=make_biluo_labels(ALL_LABELS))

# eval

In [40]:
from spacy.gold import spans_from_biluo_tags, GoldParse
from itertools import zip_longest

def is_same(ents1, ents2):
    for e, e2 in zip_longest(ents1,ents2):
        if e != e2:
            return False
    return True

texts, golds = zip(*val_data)
def val(nlp):
    docs=list(nlp.pipe(texts))
    gs=[GoldParse(doc, **gold) for doc,gold in zip(docs,golds)]
    entsl=[spans_from_biluo_tags(doc,g.ner) for g,doc in zip(gs,docs)]
    return sum(is_same(doc.ents, ents) for doc, ents in zip(docs,entsl))

In [41]:
m=nlp.pipeline[1][1]
m.optim_parameters = lambda: []

In [42]:
from torch.optim import SGD

In [43]:
niter=20
nbatch=16
ndata=ntrain
optim=nlp.resume_training(t_total=niter, enable_scheduler=False)

In [44]:
for i in range(niter):
    random.shuffle(train_data)
    epoch_loss=0
    for i,batch in enumerate(minibatch(train_data, size=nbatch)):
        texts, golds=zip(*batch)
        docs=[nlp.make_doc(text) for text in texts]
        nlp.update(docs, golds,optim)
        loss = sum(doc._.loss.detach().item() for doc in docs)
        epoch_loss += loss
        print(f"{i*nbatch}/{ndata} loss: {loss}")
        if i % 10 == 9:
            acc = val(nlp)
            print(f"epoch {i} val: ", acc /neval)
    print(f"epoch {i} loss: ", epoch_loss)
    nlp.to_disk(f"irex{i}")

0/100 loss: 37.774184226989746
10/100 loss: 31.306445121765137
20/100 loss: 28.123074293136597
30/100 loss: 26.081987023353577
40/100 loss: 24.230735659599304
50/100 loss: 19.343295335769653
60/100 loss: 18.15466547012329
70/100 loss: 13.455450296401978
80/100 loss: 18.318995475769043
90/100 loss: 16.408040523529053
epoch 9 val:  0.0
epoch 9 loss:  tensor(233.1969, grad_fn=<AddBackward0>)
tensor(233.1969, grad_fn=<AddBackward0>)
0/100 loss: 12.276712656021118
10/100 loss: 18.361782371997833
20/100 loss: 15.03300678730011
30/100 loss: 12.41399684548378
40/100 loss: 12.147306650876999
50/100 loss: 18.530424565076828
60/100 loss: 10.687222480773926
70/100 loss: 16.670929610729218
80/100 loss: 8.674575388431549
90/100 loss: 9.679763674736023
epoch 9 val:  0.0
epoch 9 loss:  tensor(134.4757, grad_fn=<AddBackward0>)
tensor(134.4757, grad_fn=<AddBackward0>)
0/100 loss: 14.496490597724915
10/100 loss: 11.225960075855255
20/100 loss: 10.024987608194351
30/100 loss: 10.986346185207367
40/100 los

KeyboardInterrupt: 