In [1]:
import spacy
import random
from spacy.util import minibatch, compounding

In [2]:
def iobToSpacySimple(iobDataset):
    
    TRAIN_DATA = []
    
    with open(iobDataset, 'r') as fp:
        for line in fp:    
            wordTagPairsList = []
            
            wordTagStringList = line.split(' ')[:-1]
            
            for wordTagString in wordTagStringList:
                wordTagPairList = wordTagString.split('|')
                wordTagPairsList.append(wordTagPairList)
            
            entityStartIndex = 0
            entityDict = {}
            entityList = []
            sentText = ''
            currentIndex = 0
            
            for index, pair in enumerate(wordTagPairsList):
                
                if pair[1].startswith('B'):
                    entityEndIndex = currentIndex + len(pair[0])
                    entityStartIndex = currentIndex
                    
                    for pairInner in wordTagPairsList[index + 1:]:
                        if pairInner[1].startswith('O'):
                            break
                        
                        entityEndIndex += (len(pairInner[0]) + 1)
                    
                    entityTuple = (entityStartIndex, entityEndIndex, pair[1][2:])
                    entityList.append(entityTuple) 
                
                sentText += pair[0] + ' '             
                currentIndex += len(pair[0]) + 1
                
            entityDict['entities'] = entityList
            trainTuple = (sentText, entityDict)
            TRAIN_DATA.append(trainTuple)
    
    #Remove list elemens in TRAIN_DATA with overlapping entities
    overlappingIndexes = []
    for index, tup in enumerate(TRAIN_DATA):
        endIndex = -1
        _, annotation = tup
        for ent in annotation.get('entities'):
            if (ent[0] < endIndex):
                overlappingIndexes.append(index)
            endIndex = ent[1]
    indexAdjustment = 0        
    for ind in overlappingIndexes:
        TRAIN_DATA.pop(ind - indexAdjustment)
        indexAdjustment += 1
    return TRAIN_DATA

In [8]:
def trainNer(TRAIN_DATA, pathToModel, iterations):
    nlp = spacy.load(pathToModel)
    
    if 'ner' not in nlp.pipe_names:
        ner = nlp.create_pipe('ner')
        nlp.add_pipe(ner, last=True)
    
    for _, annotations in TRAIN_DATA:
         for ent in annotations.get('entities'):
            ner.add_label(ent[2])
    
    other_pipes = [pipe for pipe in nlp.pipe_names if pipe != 'ner']
    with nlp.disable_pipes(*other_pipes):
        nlp.begin_training()
        
        for itn in range(iterations):
            print("Statring iteration " + str(itn))
            random.shuffle(TRAIN_DATA)
            losses = {}
            batches = minibatch(TRAIN_DATA, size=compounding(4.0, 32.0, 1.001))
            for batch in batches:
                texts, annotations = zip(*batch)
                nlp.update(
                    texts,  # batch of texts
                    annotations,  # batch of annotations
                    drop=0.5,  # dropout - make it harder to memorise data
                    losses=losses,
                )
            
            print(losses)
            
    return nlp
        

In [9]:
TRAIN_DATA = iobToSpacySimple('../Data/DatasetsFormatIOB/SUC30IOB2SelTags')
nlp = trainNer(TRAIN_DATA, '../Models/SpacySwedishFastText', 20)

Statring iteration 0
{'ner': 40468.645476032034}
Statring iteration 1
{'ner': 29233.071479270515}
Statring iteration 2
{'ner': 25581.238843535273}
Statring iteration 3
{'ner': 23360.7814324469}
Statring iteration 4
{'ner': 21573.62435180597}
Statring iteration 5
{'ner': 20378.82527422333}
Statring iteration 6
{'ner': 19362.70593440347}
Statring iteration 7
{'ner': 18704.13575881814}
Statring iteration 8
{'ner': 17825.65488828454}
Statring iteration 9
{'ner': 17392.407668748587}
Statring iteration 10
{'ner': 16803.774388647424}
Statring iteration 11
{'ner': 16289.306068179085}
Statring iteration 12
{'ner': 16083.101346311776}
Statring iteration 13
{'ner': 15532.147737091058}
Statring iteration 14
{'ner': 15351.195557645808}
Statring iteration 15
{'ner': 14903.363955556664}
Statring iteration 16
{'ner': 14695.928080697893}
Statring iteration 17
{'ner': 14425.787008845451}
Statring iteration 18
{'ner': 14234.783160544283}
Statring iteration 19
{'ner': 13939.353874403978}


In [13]:
doc = nlp('Våla härad omfattade socknarna Östervåla, Nora, Harbo och Huddunge.Häradsrätten upphörde vid utgången av år 1887 och uppgick i Västmanlands östra domsagas häradsrätt.')
for ent in doc.ents:
    print(ent)

Östervåla
Nora
Harbo
Huddunge
år 1887
Västmanlands
