# Training a new model


In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("..")

from hc_nlp.pipeline import ThesaurusMatcher, EntityFilter, MapEntityTypes, DateMatcher
from hc_nlp.model_testing import test_ner
from hc_nlp.io import load_text_and_annotations_from_labelstudio
from hc_nlp.spacy_helpers import correct_entity_boundaries
from hc_nlp import constants

import pprint
pp = pprint.PrettyPrinter(indent=2)

import pandas as pd
import time

import random
import warnings

import spacy
from spacy.util import minibatch, compounding


In [2]:
# Load spaCy's pretrained large English pipeline; this is the out-of-the-box model used throughout
nlp = spacy.load('en_core_web_lg')

## 1. Create training set
The training set is all of the text that is not in the test set at `../data/TEST_SET_2020-12-10-12-43-04.zip`.

In [3]:
# Load the Label Studio test-set export; each record is a (text, annotations) pair
test_data = load_text_and_annotations_from_labelstudio(
    '../data/TEST_SET_2020-12-10-12-43-04.zip',
    spacy_model=nlp,
)
# Keep only the raw text of each record so it can be matched against the full corpus
test_text = [record[0] for record in test_data]

In [4]:
# Inspect the first test record (text plus its (start, end, label) annotations) next to its raw text
test_data[0], test_text[0]

(('Soup plate, white ceramic, Newhaven to Dieppe service logo at top. Made by C Mc D Mann & Co Ltd, Hanley. Overall: 44 mm x 253 mm, 0.73kg.',
  [(97, 103, 'LOC'), (27, 35, 'LOC'), (39, 45, 'LOC'), (75, 95, 'ORG')]),
 'Soup plate, white ceramic, Newhaven to Dieppe service logo at top. Made by C Mc D Mann & Co Ltd, Hanley. Overall: 44 mm x 253 mm, 0.73kg.')

In [5]:
# Load the full corpus; its 'text' column is what the train/test split below filters on
text = pd.read_json("../data/text_all.json")

In [6]:
# Rows whose text also appears in the labelled test set
in_test_set = text['text'].isin(test_text)

# text_train is de-duplicated on 'text'; text_test is left as-is, so it may still
# contain duplicates (its descriptions come from labelled data)
text_train = text[~in_test_set].drop_duplicates(subset='text')
text_test = text[in_test_set]

len(text), len(text_train), len(text_test)

(291620, 264098, 2297)

In [7]:
# Set to an integer to train on a random subsample; None keeps the whole training set
train_size = None

if train_size is not None:
    # Fixed random_state so the same subsample is drawn on every run
    text_train = text_train.sample(train_size, random_state=42)

## 2. Train new model

### 2.1 Initialise model with rule- and thesaurus-based matching

- `nlp` is our out-of-the-box model
- `nlp_thes` is our model with the additional components

In [8]:
# Build `nlp_thes`: the stock model with rule-based date matching inserted before 'ner'
# and thesaurus matching inserted after it (the thesaurus matcher is created with overwrite_ents=True).
datematcher = DateMatcher(nlp)
thes_ow = ThesaurusMatcher(nlp, thesaurus_path="../data/labels_all_unambiguous_types_people_orgs.jsonl", 
                                  case_sensitive=False, overwrite_ents=True)
# NOTE(review): presumably removes DATE entities from the final output — confirm against EntityFilter
entityfilter = EntityFilter(ent_labels_ignore=['DATE'])
# Constructed here but not added to either pipeline: both add_pipe calls below are commented out
mapentitytypes = MapEntityTypes(nlp, validate_mapping=False)

nlp_thes = spacy.load("en_core_web_lg")
nlp_thes.add_pipe(datematcher, before='ner')
nlp_thes.add_pipe(thes_ow, after='ner')
nlp_thes.add_pipe(entityfilter, last=True)
# nlp_thes.add_pipe(mapentitytypes)

# nlp.add_pipe(mapentitytypes)

# Show the component order of the stock pipeline and the augmented pipeline side by side
nlp.pipe_names, nlp_thes.pipe_names

2021-01-05 16:32:17,575 - hc_nlp.pipeline - INFO - Loading thesaurus from ../data/labels_all_unambiguous_types_people_orgs.jsonl
2021-01-05 16:32:22,993 - hc_nlp.pipeline - INFO - 17016 term thesaurus imported in 5s


(['tagger', 'parser', 'ner'],
 ['tagger',
  'parser',
  'DateMatcher',
  'ner',
  'ThesaurusMatcher',
  'EntityFilter'])

### 2.2 Create training data

In [96]:
def get_entity_list(doc, correct=True):
    """Return the entities of `doc` as a list of (start_char, end_char, label) tuples.

    Character offsets are derived from the tokens at each entity's boundaries:
    the start is the `idx` of the first token, the end is the `idx` of the last
    token plus the length of its text.

    Args:
        doc: a processed document exposing `.ents`, `.text` and token indexing.
        correct: when True (the default), the offsets are passed through
            `correct_entity_boundaries` together with the module-level `nlp`
            model, and any difference between the uncorrected and corrected
            lists is printed along with the document text.

    Returns:
        The list of (start_char, end_char, label) tuples, corrected if requested.
    """
    spans = []
    for entity in doc.ents:
        first_token = doc[entity.start]
        # TODO: correct end in hc-nlp
        last_token = doc[entity.end - 1]
        span_end = last_token.idx + len(last_token.text)
        spans.append((first_token.idx, span_end, entity.label_))

    if not correct:
        return spans

    corrected = correct_entity_boundaries(nlp, doc.text, spans)
    if spans != corrected:
        # Surface every adjustment so boundary corrections can be reviewed by eye
        print(spans, corrected, doc.text)

    return corrected

# Smoke-test get_entity_list on a single sentence processed by nlp_thes.
# NOTE(review): this rebinds `text`, which earlier held the DataFrame read from
# ../data/text_all.json; re-running the train/test split cell after this one would
# fail because `text` is now a str. Consider using a different name here.
text = "Who is Shaka Khan?"
doc = nlp_thes(text)

get_entity_list(doc)

[(7, 17, 'PERSON')]

In [89]:
# Build training examples in (text, {"entities": [...]}) form by labelling the
# first 5000 training texts with the rule- and thesaurus-augmented pipeline.
model = nlp_thes
TRAIN_DATA = []

start = time.time()
for doc in model.pipe(text_train['text'].head(5000).tolist()):
    entity_annotations = {"entities": get_entity_list(doc)}
    TRAIN_DATA.append((doc.text, entity_annotations))
end = time.time()

print(f"{len(TRAIN_DATA)} records created in {int(end-start)} seconds")

5000 records created in 47 seconds


### 2.3 Train new model

In [92]:
# Load a separate instance of the pretrained pipeline to train, leaving `nlp` and `nlp_thes` untouched
nlp_new = spacy.load('en_core_web_lg')
# Handle on the NER component of the model being trained (not referenced again in the cells shown here)
ner = nlp_new.get_pipe("ner")

# Number of shuffled passes over TRAIN_DATA made by the training loop
n_iter = 100

In [93]:
# Only the NER component is trained; every other pipe is disabled while the
# `with` block runs, and the warning filters are restored when it exits.
required_pipes = ['ner']
disable_pipes = [pipe for pipe in nlp_new.pipe_names if pipe not in required_pipes]

with nlp_new.disable_pipes(*disable_pipes), warnings.catch_warnings():
    # show warnings for misaligned entity spans once
    warnings.filterwarnings("once", category=UserWarning, module='spacy')

    for itn in range(n_iter):
        # Reshuffle every iteration so batches differ between passes
        random.shuffle(TRAIN_DATA)
        losses = {}

        # Batch size compounds from 4 up to 32 over the course of the pass
        batches = minibatch(TRAIN_DATA, size=compounding(4.0, 32.0, 1.001))
        for batch in batches:
            texts, annotations = zip(*batch)
            try:
                nlp_new.update(
                    texts,
                    annotations,
                    drop=0.5, # dropout for regularization
                    losses=losses,
                )
            # TODO: drop this except and let spaCy handle the errors
            except Exception as err:
                # Catch Exception rather than using a bare `except:` so that
                # KeyboardInterrupt / SystemExit still stop the run, and report
                # the error itself alongside the batch that triggered it.
                print(f"{type(err).__name__}: {err}")
                print(batch)
                print()
                # As before, abandon the remaining batches of this iteration
                break

        print("Losses", losses)
        

Losses {'ner': 8752.958746125605}
Losses {'ner': 6644.080891124536}
Losses {'ner': 6149.880774482395}
Losses {'ner': 5765.229031543167}
Losses {'ner': 5513.222365782627}
Losses {'ner': 5328.265107249336}
Losses {'ner': 5129.718474322531}
Losses {'ner': 4897.516061056674}
Losses {'ner': 4790.09224472084}
Losses {'ner': 4546.9993544531935}
Losses {'ner': 4522.889306091676}
Losses {'ner': 4303.650312348633}
Losses {'ner': 4120.754619620491}
Losses {'ner': 4185.125311100743}
Losses {'ner': 4085.752658769397}
Losses {'ner': 3987.9173164355343}
Losses {'ner': 3840.4816759482833}
Losses {'ner': 3851.970716302068}
Losses {'ner': 3708.6101216559177}
Losses {'ner': 3638.098496863698}
Losses {'ner': 3540.1594012227843}
Losses {'ner': 3516.0517860667323}
Losses {'ner': 3416.2412796115245}
Losses {'ner': 3299.033486725564}
Losses {'ner': 3297.038785775817}
Losses {'ner': 3308.0953211830397}
Losses {'ner': 3206.700624895452}
Losses {'ner': 3206.3090628694317}
Losses {'ner': 3019.968656119732}
Losses