## Evaluation Criteria
* Recall is more important than precision
* Current model has R=0.99 and P=0.96 
* Redefine false positives and false negatives as wrong detections whose priority is higher or lower, respectively, than that of the actual entity
* Work on POS and Dependency parsing 
* Add doc-site occurrence analysis

In [1]:
import spacy
import seaborn as sns
import numpy as np
from spacy.tokens import DocBin
import matplotlib.pyplot as plt

In [2]:
MODEL_PATH, DEV_DATA_PATH = (
    "../output_lg/model-best/",  # passed to spacy.load for the trained pipeline
    "../data/dev.spacy",  # dev set read back with DocBin.from_disk
)

In [3]:
def are_entities_equal(ent_1, ent_2):
    """Return True only if two entities match exactly.

    The entities are compared on their surface ``text``, their ``start_char``
    and ``end_char`` offsets, and their ``label_``. All four comparisons are
    evaluated (in that order) before the result is decided.
    """
    matches = [
        getattr(ent_1, attribute) == getattr(ent_2, attribute)
        for attribute in ("text", "start_char", "end_char", "label_")
    ]
    return all(matches)

In [4]:
class NerModel:
    """Wrap a spaCy pipeline so that each text yields at most one entity.

    When the pipeline finds several entities, the one whose string maps to the
    largest value in ``entity_2_index`` is returned.
    """

    def __init__(self, model_path, entity_2_index):
        self.model_path = model_path
        self.base_model = spacy.load(model_path)
        self.vocab = self.base_model.vocab
        # Maps an entity's string form to its rank; higher rank wins in predict().
        self.entity_2_index = entity_2_index

    def predict(self, text):
        """Return the highest-ranked entity the pipeline finds in ``text``.

        Returns None when no entity is found. Raises KeyError if the string of
        any predicted entity is missing from ``entity_2_index``.
        """
        found = self.base_model(text).ents
        if not found:
            return None
        # max() keeps the first of several equally ranked entities, matching
        # the first-occurrence tie-breaking of np.argmax.
        return max(found, key=lambda entity: self.entity_2_index[str(entity)])

In [6]:
# Load the trained pipeline; its vocab is used to rebuild the gold dev docs.
original_model = spacy.load(MODEL_PATH)

db = DocBin()
doc_bin = db.from_disk(DEV_DATA_PATH)
docs = [*doc_bin.get_docs(original_model.vocab)]

## Create entity ordering

In [7]:
# Gather every entity string that can appear during evaluation: the first gold
# entity of each dev doc plus every entity the pipeline predicts for its text.
all_entities = []

for original_doc in docs:
    all_entities.append(str(original_doc.ents[0]))
    predictions = original_model(original_doc.text)
    all_entities.extend(str(ent) for ent in predictions.ents)

# np.unique already returns the distinct values in sorted order, so a separate
# sorted() pass is redundant. An entity string's position in this lexicographic
# ordering is what the rest of the notebook treats as its "priority".
ordered_entities = np.unique(all_entities)
index_2_entity = dict(enumerate(ordered_entities))
entity_2_index = {entity: index for index, entity in index_2_entity.items()}

## Generate results

In [8]:
# NOTE(review): NerModel.__init__ calls spacy.load(MODEL_PATH) again, so the
# same pipeline already held in `original_model` is loaded a second time.
ner_model = NerModel(model_path=MODEL_PATH, entity_2_index=entity_2_index)

In [9]:
# Split dev docs into exact hits and (gold, predicted) pairs that missed; a
# predicted value of None means the model found no entity at all.
correct_predictions = []
false_prediction_pairs = []

for original_doc in docs:
    doc_text = original_doc.text
    actual_ent = original_doc.ents[0]
    predicted_ent = ner_model.predict(doc_text)

    if predicted_ent is not None and are_entities_equal(actual_ent, predicted_ent):
        correct_predictions.append(actual_ent)
    else:
        false_prediction_pairs.append((actual_ent, predicted_ent))

In [10]:
# Categorize every wrong prediction:
#   missed_entities     - the model predicted no entity at all
#   false_positives     - predicted entity ranks higher than the gold entity
#   false_negatives     - predicted entity ranks lower than the gold entity
#   mismatched_entities - same text (same rank) as the gold entity, but the
#                         offsets or label differ; these were previously
#                         skipped, so the categories did not add up to the
#                         number of incorrect predictions.
missed_entities = []
false_positives = []
false_negatives = []
mismatched_entities = []

for actual_ent, predicted_ent in false_prediction_pairs:
    if predicted_ent is None:
        missed_entities.append(actual_ent)
        continue

    actual_rank = entity_2_index[str(actual_ent)]
    predicted_rank = entity_2_index[str(predicted_ent)]

    if predicted_rank > actual_rank:
        false_positives.append((actual_ent, predicted_ent))
    elif predicted_rank < actual_rank:
        false_negatives.append((actual_ent, predicted_ent))
    else:
        mismatched_entities.append((actual_ent, predicted_ent))

In [11]:
# Incorrect predictions not covered by missed / FP / FN: the predicted text has
# the same rank as the gold text, but the offsets or label differ. Derived here
# so the printed breakdown always reconciles with the incorrect count.
other_mismatches = (
    len(false_prediction_pairs)
    - len(missed_entities)
    - len(false_positives)
    - len(false_negatives)
)
# Guard against an empty dev set instead of raising ZeroDivisionError.
accuracy = 100 * (len(correct_predictions) / len(docs)) if docs else 0.0

print(f"Samples tested: {len(docs)}")
print(f"Samples predicted correctly: {len(correct_predictions)}")
print(f"Samples predicted incorrectly: {len(false_prediction_pairs)}")
print(f"Missed entities: {len(missed_entities)}")
print(f"False positives: {len(false_positives)}")
print(f"False negatives: {len(false_negatives)}")
print(f"Other mismatches: {other_mismatches}")
print(f"Accuracy: {accuracy:.2f}%")

Samples tested: 156
Samples predicted correctly: 140
Samples predicted incorrectly: 16
Missed entities: 9
False positives: 3
False negatives: 4
Accuracy: 89.74%
