In [136]:
from __future__ import unicode_literals, print_function
import pickle
import plac
import random
from pathlib import Path
import spacy
from spacy.util import minibatch, compounding
import srsly
import pickle
import os
import csv
import numpy as np
#!python -m spacy download en_core_web_md 

In [144]:
# creating a folder to hold the data files, model results, output etc.
model_dir_name = "custom_entity_extractor"
model_dir_parent = "C://Users//buchh//OneDrive//Desktop"
output_dir = os.path.join(model_dir_parent, model_dir_name)
os.makedirs(output_dir, exist_ok=True)
print("created", output_dir)

In [138]:
# Create Training data
# 70% training data and 30% validation/testing data

# spacy requires data to be in a specific format which is what the code below does. Here is the acceptable formart:
# [('Heat stress and persistent dehydration can cause kidney damage. IMPLIED_BASE IMPLIED_BASE', {'entities': [(5, 11, 'base'), (27, 38, 'base'), (49, 55, 'base')]})]

file_name_answers = "checkin_answers"
file_path_answers = "C://Users//buchh//OneDrive/Desktop//"+file_name_answers+".jsonl"
data = srsly.read_jsonl(file_path_answers)
final_sent = []

for entry in data:
    if "text" in entry:
        text = entry["text"]
    else:
        text = ""
        throw("NO 'text' field encountered! This field is necessary for the rest of the script to work! Please fix this and then run this script.")

    label_arr = []
    label_tup = ()
    if entry['answer'] == "accept":
        for relation in entry['spans']:
            if ("label" in relation) and ("start" in relation) and ("end" in relation):
                child_span_start = relation["start"]
                child_span_end = relation["end"]
                word = text[child_span_start:child_span_end]
                if relation["label"] == "base":
                    tmp_tuple = (child_span_start, child_span_end, relation["label"])
                    label_arr.append(tmp_tuple)

    label_tup = (text, {"entities": label_arr})
    final_sent.append(label_tup)

training_data_index = int(len(final_sent)*0.3)

with open(os.path.join(output_dir, 'training_data.txt'), 'wb') as file:
    pickle.dump(final_sent[:training_data_index], file)
    
with open(os.path.join(output_dir, 'testing_data.txt'), 'wb') as file:
    pickle.dump(final_sent[training_data_index:], file)

In [139]:
LABEL = ['base']

with open (os.path.join(output_dir, 'training_data.txt'), 'rb') as file:
    TRAIN_DATA = pickle.load(file)

with open (os.path.join(output_dir, 'testing_data.txt'), 'rb') as file:
    TEST_DATA = pickle.load(file)

In [140]:
model = "en_core_web_md"
new_model_name = "custom_entity_extraction"
n_iter = 1000

In [141]:
if model is not None:
    nlp = spacy.load(model)
    print("Loaded model '%s'" % model)
else:
    nlp = spacy.blank('en')
    print("Created blank 'en' model")
    
if 'ner' not in nlp.pipe_names:
    ner = nlp.create_pipe('ner')
    nlp.add_pipe(ner)
else:
    ner = nlp.get_pipe('ner')

for i in LABEL:
    # Adding new entity
    ner.add_label(i)

if model is None:
    optimizer = nlp.begin_training()
else:
    optimizer = nlp.entity.create_optimizer()

# Get names of other pipes to disable them during training to train only NER
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != 'ner']
with nlp.disable_pipes(*other_pipes):  # only train NER
    for itn in range(n_iter):
        random.shuffle(TRAIN_DATA)
        losses = {}
        batches = minibatch(TRAIN_DATA, size=compounding(4., 32., 1.001))
        for batch in batches:
            texts, annotations = zip(*batch)
            nlp.update(texts, annotations, sgd=optimizer, drop=0.2, losses=losses)
        print('Losses', losses)

# Save model 
if output_dir is not None:
    output_dir = Path(output_dir)
    if not output_dir.exists():
        output_dir.mkdir()
    nlp.meta['name'] = new_model_name  # rename model
    nlp.to_disk(output_dir)
    print("Saved model to", output_dir)

Loaded model 'en_core_web_md'
Losses {'ner': 1050.1765359640121}
Losses {'ner': 924.4315682500601}
Losses {'ner': 775.7670805826783}
Losses {'ner': 747.6329988241196}
Losses {'ner': 744.945388393011}
Losses {'ner': 668.8180352461059}
Losses {'ner': 652.2744732196516}
Losses {'ner': 599.4143710053177}
Losses {'ner': 561.4986105194994}
Losses {'ner': 581.6835314521741}
Losses {'ner': 567.6546531309586}
Losses {'ner': 552.918059484994}
Losses {'ner': 537.5492045386892}
Losses {'ner': 491.6366693851564}
Losses {'ner': 503.89471460808454}
Losses {'ner': 503.9402866243581}
Losses {'ner': 489.0869802023981}
Losses {'ner': 509.7391119321983}
Losses {'ner': 480.8796471943824}
Losses {'ner': 487.97410559051787}
Losses {'ner': 482.59252879705673}
Losses {'ner': 476.18344833268566}
Losses {'ner': 429.16890970396344}
Losses {'ner': 489.09408143326436}
Losses {'ner': 485.4263989555184}
Losses {'ner': 507.39537772902986}
Losses {'ner': 488.86752171972876}
Losses {'ner': 500.4962157192349}
Losses {'ne

Losses {'ner': 389.619315656455}
Losses {'ner': 412.0240764591408}
Losses {'ner': 374.73955387485694}
Losses {'ner': 418.0651379770716}
Losses {'ner': 417.1892020589148}
Losses {'ner': 416.17797461582813}
Losses {'ner': 388.08272186068643}
Losses {'ner': 394.5200201686421}
Losses {'ner': 408.14621790498904}
Losses {'ner': 406.7582841566764}
Losses {'ner': 373.89383524204095}
Losses {'ner': 410.6861808561953}
Losses {'ner': 410.0100015248354}
Losses {'ner': 398.3820003117726}
Losses {'ner': 400.9443883430795}
Losses {'ner': 355.71964893742245}
Losses {'ner': 417.6988459236791}
Losses {'ner': 412.93654895933287}
Losses {'ner': 406.4276845443528}
Losses {'ner': 393.16487875796383}
Losses {'ner': 394.0740433170722}
Losses {'ner': 420.72640501950264}
Losses {'ner': 410.4299867677764}
Losses {'ner': 421.67421904497314}
Losses {'ner': 402.9190059893806}
Losses {'ner': 401.8307788405182}
Losses {'ner': 432.0752811433049}
Losses {'ner': 382.57045997342937}
Losses {'ner': 387.69421905584204}
Los

Losses {'ner': 422.4062588561792}
Losses {'ner': 418.50966013756624}
Losses {'ner': 405.53156171205046}
Losses {'ner': 402.1248072385788}
Losses {'ner': 424.36865452991697}
Losses {'ner': 352.5720933144803}
Losses {'ner': 405.7786371095191}
Losses {'ner': 412.7263518415639}
Losses {'ner': 443.2962375921252}
Losses {'ner': 431.53844522818235}
Losses {'ner': 444.68084190020727}
Losses {'ner': 489.989705677277}
Losses {'ner': 381.83615658884446}
Losses {'ner': 395.6130304449507}
Losses {'ner': 364.3313596138323}
Losses {'ner': 396.86332926961325}
Losses {'ner': 391.6988048254573}
Losses {'ner': 395.4435127128381}
Losses {'ner': 416.67769038758706}
Losses {'ner': 359.4597391951065}
Losses {'ner': 369.76665423229883}
Losses {'ner': 378.48824351471626}
Losses {'ner': 406.3602748606042}
Losses {'ner': 371.6429631860441}
Losses {'ner': 356.19541569136254}
Losses {'ner': 410.1106550359365}
Losses {'ner': 378.6905965666909}
Losses {'ner': 395.4025647428789}
Losses {'ner': 400.1344664440938}
Loss

Losses {'ner': 413.634201331594}
Losses {'ner': 408.1818405566611}
Losses {'ner': 371.83912291751767}
Losses {'ner': 388.15421620605775}
Losses {'ner': 357.9153829163172}
Losses {'ner': 403.6063999343719}
Losses {'ner': 427.5461749615831}
Losses {'ner': 347.2248698166675}
Losses {'ner': 356.5225587328973}
Losses {'ner': 405.2072965711732}
Losses {'ner': 370.2417946958376}
Losses {'ner': 393.17412297985476}
Losses {'ner': 412.94140098727075}
Losses {'ner': 407.53716361631086}
Losses {'ner': 404.2360964769414}
Losses {'ner': 371.6553256330735}
Losses {'ner': 406.17716284384073}
Losses {'ner': 395.1843309670785}
Losses {'ner': 388.3159817339747}
Losses {'ner': 397.0468021833764}
Losses {'ner': 362.5263695483495}
Losses {'ner': 445.98219318795407}
Losses {'ner': 397.0086313921388}
Losses {'ner': 389.05590821927774}
Losses {'ner': 407.1787067853402}
Losses {'ner': 418.7652225606871}
Losses {'ner': 440.9926184066426}
Losses {'ner': 389.664186428098}
Losses {'ner': 386.03615799306226}
Losses 

Losses {'ner': 381.18019790052864}
Losses {'ner': 406.1977795657174}
Losses {'ner': 373.66159165824416}
Losses {'ner': 398.59811431882}
Losses {'ner': 404.40742921492665}
Losses {'ner': 386.21166957845344}
Losses {'ner': 387.7250378017279}
Losses {'ner': 408.03095490162343}
Losses {'ner': 373.52642063082567}
Losses {'ner': 400.65298708066985}
Losses {'ner': 400.25484329961864}
Losses {'ner': 383.01899984800184}
Losses {'ner': 381.33591232961476}
Losses {'ner': 415.1343220148947}
Losses {'ner': 400.82294668890177}
Losses {'ner': 383.95903044051374}
Losses {'ner': 390.17112372828524}
Losses {'ner': 382.02810134220636}
Losses {'ner': 405.9872016831588}
Losses {'ner': 385.17577282625643}
Losses {'ner': 378.4746066629624}
Losses {'ner': 405.9413395966985}
Losses {'ner': 350.5269228421052}
Losses {'ner': 370.5352796209263}
Losses {'ner': 387.3044901125445}
Losses {'ner': 390.9300961187146}
Losses {'ner': 393.18308497177804}
Losses {'ner': 380.9298005050657}
Losses {'ner': 391.021723827429}
L

In [145]:
# Test the saved model
print("Loading from", output_dir)
nlp_test = spacy.load(output_dir)

Loading from C://Users//buchh//OneDrive//Desktop\custom_entity_extractor


In [146]:
result = os.path.join(output_dir, "result.csv")
all_bases = 0
correct_bases = 0
false_pos = 0
false_neg = 0

headers = ["text", "predicted_base", "actual_base"]
final_res = []
final_res.append(headers)
for sent in TEST_DATA:
    doc2 = nlp_test(sent[0])
    res = [i.text for i in list(doc2.ents)]
    actual_base = []
    for tokens in sent[1]['entities']:
        actual_base.append(sent[0][tokens[0]:tokens[1]])
    all_bases += len(actual_base)
    correct_bases += len(np.intersect1d(actual_base, res))
    false_pos += len(list(sorted(set(actual_base) - set(res))))
    false_neg += len(list(sorted(set(res) - set(actual_base))))
    final_res.append([sent, res, actual_base])

acc = ((correct_bases*100)/all_bases)
print("False pos: ", false_pos)
print("False neg: ", false_neg)
print("Marked {} out of {} bases correctly. Accuracy: {}".format(correct_bases, all_bases, acc))
    
with open(result, 'w', newline='') as csvfile:
    writer = csv.writer(csvfile)
    writer.writerows(final_res)

print("created result file", result)

False pos:  29
False neg:  36
Marked 44 out of 76 bases correctly. Accuracy: 57.89473684210526
created result file C://Users//buchh//OneDrive//Desktop\custom_entity_extractor\result.csv
