In [1]:
!pip install scikit-learn==1.5.0 nltk sklearn_crfsuite

Defaulting to user installation because normal site-packages is not writeable


In [15]:
import sys
sys.path.append("./NER-Evaluation")

import nltk
import sklearn_crfsuite

from copy import deepcopy
from collections import defaultdict

from sklearn_crfsuite.metrics import flat_classification_report

from ner_evaluation.ner_eval import compute_metrics
from ner_evaluation.ner_eval import compute_precision_recall_wrapper
from ner_evaluation.ner_eval import find_overlap

import pandas as pd
import json

In [3]:
from collections import namedtuple
Entity = namedtuple("Entity", "e_type start_offset end_offset")

### Implement our own version of collect_named_entities

**Use code from coref_reformat.ipynb from crosslingual_coref**

In [4]:
# Get FAA data in format {c5_id:{0: word0, 1: word1, ..., n: wordn}} using word tokenization from faa.conll

faa = {}

with open('../../data/FAA_data/faa.conll') as f:
    text = f.read()

docs = text.split('#begin document ')

for doc in docs:
    if doc[:5] == '(faa/':
        word_count = 0
        c5_id = doc.split('_')[1][:15]
        faa[c5_id] = {}
        lines = doc.split('\n')
        for line in lines[1:]:
            if 'faa' in line:
                faa[c5_id][word_count] = line.split()[3].upper()
                word_count = word_count + 1
# Fix known err
faa['19980620030289I'] = {0: 'MR.', 1: 'KADERA', 2: 'THEN', 3: 'ATTEMPTED', 4: 'TO', 5: 'LAND', 6: 'IN', 7: 'A', 8: 'FIELD', 9: 'BUT', 10: 'WAS', 11: 'FORCED', 12: 'TO', 13: 'LAND', 14: 'ON', 15: 'HIGHWAY', 16: '93', 17: '.', 18: 'THREE', 19: 'MILES', 20: 'EAST', 21: 'OF', 22: 'SUNMER', 23: ',', 24: 'IOWA'}

In [5]:
faa['19980620030289I'].values()

dict_values(['MR.', 'KADERA', 'THEN', 'ATTEMPTED', 'TO', 'LAND', 'IN', 'A', 'FIELD', 'BUT', 'WAS', 'FORCED', 'TO', 'LAND', 'ON', 'HIGHWAY', '93', '.', 'THREE', 'MILES', 'EAST', 'OF', 'SUNMER', ',', 'IOWA'])

In [179]:
def get_spans(mentions, words):
    ''' Input:
    - mentions:['MENTION1','MENTION2',...]
    - words: ['This','is','a','sentence','.','This','is','another','sentence','.'] (dict values)
        Output: [[startidx_mention1, end_idxmention1], [startidx_mention2, end_idxmention2], ...]
    '''

    mention_spans = []

    if "'S" in words.values():
        idx = list(words.values()).index("'S")
        words[idx-1] = words[idx-1] + "'S"
        del words[idx]
    
    for imention, mention in enumerate(mentions):

        mention = mention.replace('(', ' ( ').replace(')',' ) ').replace('  ',' ')
        mention = mention.replace(',',' , ').replace('  ',' ')

        mention_span = [-1, -1]

        # if this mention has shown up before, find its end_idx in mention_spans
        # ilastmntn (index of last mention) will be -1 if not found, or the index in mention_spans of the last one if found
        ilastmntn = -1
        keep_going = True
        while keep_going:
            try:
                ilastmntn = mentions[ilastmntn+1:imention].index(mention)
            except:
                keep_going=False

        # create lists of words and indices
        # check_idxs maintains original token indices of words (doesn't always correspond to place in word if list if list is a slice or words have been removed)
        check_words = list(words.values())
        check_idxs = list(words.keys())
        if ilastmntn > 0:
            prev_idx = mention_spans[ilastmntn][1]
            check_words = check_words[prev_idx:]
            check_idxs = check_idxs[prev_idx:]

        # loop through to find start and end indices of mention
        # finds index of first word and then index of last word
        # if we're looking for 'cargo door' and 'cargo' by itself appears previously in the sentence, it loop until it finds cargo door, not 'cargo xyzwkladsjfkl cargo door'
        # doesn't account for case in which end word of mention occurs in the middle of the mention as well
        start_idx = -1
        end_idx = -1
        marker = 0
        prev_marker = -1
        if mention in ' '.join(check_words):
            while mention != ' '.join(check_words[start_idx:end_idx+1]) and marker != prev_marker:
                try:
                    start_idx = check_words[marker:].index(mention.split()[0])
                    end_idx = check_words[start_idx:].index(mention.split()[-1]) + start_idx                
                    mention_span = [check_idxs[start_idx], check_idxs[end_idx]+1]
                except:
                    break
                prev_marker = marker
                marker = start_idx + 1
        
        mention_spans.append(mention_span)

    return mention_spans

In [228]:
def get_spans(mentions, words):
    ''' Input:
    - mentions:['MENTION1','MENTION2',...]
    - words: ['This','is','a','sentence','.','This','is','another','sentence','.'] (dict values)
        Output: [[startidx_mention1, end_idxmention1], [startidx_mention2, end_idxmention2], ...]
    '''

    mention_spans = []

    repeat_mentions = {}

    if "'S" in words.values():
        idx = list(words.values()).index("'S")
        words[idx-1] = words[idx-1] + "'S"
        del words[idx]
    
    for imention, mention in enumerate(mentions):

        mention = mention.replace('(', ' ( ').replace(')',' ) ').replace('  ',' ')
        mention = mention.replace(',',' , ').replace('  ',' ')

        mention_span = [-1, -1] # if conditions below aren't met, [-1,-1 is returned]

        tokens = list(words.values())
        idxs = list(words.keys())

        # Check if mention has been seen before. If has, start_idx already stored in repeat_mentions
        # Matched sequentially 
        if mention in repeat_mentions:
            if len(repeat_mentions[mention]) > 0:
                start_idx = repeat_mentions[mention].pop(0) # get start_idx and pop off list
                end_idx = start_idx + len(mention.split()) - 1
                mention_span = [idxs[start_idx],idxs[end_idx]+1]

        # Normal case, where it has not been seen before, and we search for the start of the phrase in check_words
        else:
            start_indices = [i for i in range(len(tokens)) if tokens[i:i+len(mention.split())] == mention.split()]

            # If start_indices contains multiple idxs, get start_idx from front of list (first occurance) and save rest to repeat_mentions
            # If start_indices contains just one idx, that is the start_idx
            if len(start_indices) > 0:
                
                if len(start_indices) > 1:
                    repeat_mentions[mention] = start_indices[1:]
                
                start_idx = start_indices[0]
                end_idx = start_idx + len(mention.split()) - 1
                mention_span = [idxs[start_idx],idxs[end_idx]+1]

        mention_spans.append(mention_span)

    return mention_spans

In [226]:
get_spans(gold_df[gold_df['id']=='19850315007389A']['entities'].to_list(), faa['19850315007389A'])

[[0, 1], [4, 6], [8, 9], [13, 14], [-1, -1], [18, 19]]

**Now, very simple to make our own version of collect_named_entities()**

In [222]:
def collect_named_entities(entities, labels, tokens):
    """
    Creates a list of Entity named-tuples, storing the entity type and the start and end
    offsets of the entity.

    Parameters:
    - entities: ["ENT1","ENT2"...] All entities for a doc
    - labels: ["LABEL1","LABEL2"...] All corresponding labels for a doc
    - tokens: dict_values(['TOW', 'PLANE', 'BECAME', ...]) Tokenized doc. Result of faa[doc_id].values()

    Returns: a list of Entity named-tuples
    """

    ent_spans = get_spans(entities, tokens)

    named_entities = []
    for ient, ent_span in enumerate(ent_spans):
        named_entities.append(Entity(labels[ient], ent_span[0], ent_span[1]))

    return named_entities

In [223]:
collect_named_entities(gold_df[gold_df['id']=='19850315007389A']['entities'].to_list(),gold_df[gold_df['id']=='19850315007389A']['labels'].to_list(),faa['19850315007389A'])

[Entity(e_type='ORG', start_offset=0, end_offset=1),
 Entity(e_type='ORG', start_offset=4, end_offset=6),
 Entity(e_type='ORG', start_offset=8, end_offset=9),
 Entity(e_type='ORG', start_offset=13, end_offset=14),
 Entity(e_type='ORG', start_offset=-1, end_offset=-1),
 Entity(e_type='ORG', start_offset=18, end_offset=19)]

### Get Predicted and Gold Data

In [9]:
abbrevs = {'FACILITY':'FAC','ORGANIZATION':'ORG','PERSON':'PER','LOCATION':'LOC','VEH':'VEHICLE'}
result_df = pd.read_csv('../../data/results/pl-marker/pl-marker_scierc_scibert_NER_jun17.csv')
result_df['labels'] = result_df['labels'].apply(lambda x: abbrevs[x] if x in abbrevs else x)
result_df.rename(columns={'c5_unique_id':'id', 'c5_id':'id'},inplace=True)
result_df['entities'] = result_df['entities'].apply(str.upper)
result_df.head()

Unnamed: 0,id,doc_idx,sent_idx,c119_input,sentence,entities,labels
0,19750315005389A,0,0,TAILWHEEL COCKED RIGHT PRIOR TO TKOF. ...,tailwheel cocked right prior to tkof .,TAILWHEEL,OtherScientificTerm
1,19751209037899A,3,1,PLT NOTED SOFT R BRAKE PEDAL DRG TAXI TO TKOF....,flt rtnd springfield due soft brake strong win...,SOFT BRAKE,OtherScientificTerm
2,19760215007629A,7,0,MTNS OBSCURED.FLT TO CK VOR REC REPTD INOP PRI...,mtns obscured .,MTNS,OtherScientificTerm
3,19760531024959A,9,1,MAINT NOT PERFORMED DUE PARTS NOT AVAILABLE. T...,turbocharger waste gate did not actuate .,TURBOCHARGER WASTE GATE,OtherScientificTerm
4,19760507010379A,10,1,LEFT ENG OIL SUPPLY EXHAUSTED.GEAR-UP LDG IN M...,gear-up ldg in mesquite brush .,GEAR-UP LDG,Method


In [10]:
gold_df = pd.read_csv('../../gold_standard/processed/ner.csv')
if 'labels' not in gold_df.columns:
    gold_df['labels'] = ['ORG']*len(gold_df) # Add dummy labels for aviation mentions-only gs
gold_df.head()

Unnamed: 0,id,sample,entities,labels
0,19990213001379A,ACFT WAS TAXIING FOR TAKE OFF WHEN IT LOST CON...,ACFT,ORG
1,19990213001379A,ACFT WAS TAXIING FOR TAKE OFF WHEN IT LOST CON...,DITCH,ORG
2,19990213001379A,ACFT WAS TAXIING FOR TAKE OFF WHEN IT LOST CON...,TREE,ORG
3,19800217031649I,"AFTER TAKEOFF, ENGINE QUIT. WING FUEL TANK SUM...",TAKEOFF,ORG
4,19800217031649I,"AFTER TAKEOFF, ENGINE QUIT. WING FUEL TANK SUM...",ENGINE,ORG


In [347]:
#gold_df['entities'].iat[106] = 'BARTLESVILLE' # typo in ACE-2005 Gold Df

In [372]:
#gold_df['entities'].iat[86] = 'BARTLESVILLE' # typo in OntoNotes Gold Df

In [16]:
#gold_df['entities'].iat[209] = 'ATTITUDE' # typo in our Gold Df
#gold_df['entities'].iat[222] = 'ACCELARATION'
#gold_df['entities'].iat[439] = 'MR. KADERA'

In [11]:
id_col = 'id'
len(result_df[id_col].unique())

1104

In [12]:
len(gold_df['id'].unique())

100

**Create all_true_ents and all_pred_ents lists**

In the example in example-full-named-entity-evaluation.ipynb, they use test_sents_labels and y_pred, which are lists of the tokens to input to collect_named_entities to get true and pred, respectively. However, our collect_named_entities() takes more input, so it's easier to do that process ahead of time, and have lists of true's and pred's ready to go

In [172]:
def sort_ents(true_ents):
    final_ents = {ent:0 for ent in true_ents}
    for ent_a in true_ents:
        rest = [ent for ent in true_ents if ent != ent_a]
        for ent_b in rest:
            overlap = find_overlap(range(ent_a[1],ent_a[2]),range(ent_b[1],ent_b[2]))
            #print(f"Comparing {ent_a} with {ent_b}, overlap = {overlap}")
            if len(overlap) > 0 and len(range(ent_a[1],ent_a[2])) > len(range(ent_b[1],ent_b[2])):
                final_ents[ent_a] += 1
    return [ent[0] for ent in sorted(final_ents.items(), key=lambda x: x[1])]

In [229]:
all_true_ents = []
all_pred_ents = []

for doc_id in gold_df['id'].unique():
    true_rows = gold_df.dropna()[gold_df.dropna()['id']==doc_id]
    pred_rows = result_df.dropna()[result_df.dropna()[id_col]==doc_id]

    true_ents = collect_named_entities(true_rows['entities'].to_list(),true_rows['labels'].to_list(),faa[doc_id])
    pred_ents = collect_named_entities(pred_rows['entities'].to_list(),pred_rows['labels'].to_list(),faa[doc_id])
    
    all_true_ents.append(sort_ents(true_ents))
    all_pred_ents.append(pred_ents)

In [230]:
for idoc, doc_id in enumerate(gold_df['id'].unique()):
    for ent in all_true_ents[idoc]:
        if ent[1] == -1:
            print(f"True: {doc_id}")
            print(all_true_ents[idoc])
            print()
    for ent in all_pred_ents[idoc]:
        if ent[1] == -1:
            print(f"Pred: {doc_id}")
            print(all_pred_ents[idoc])
            print()

True: 19850315007389A
[Entity(e_type='ORG', start_offset=0, end_offset=1), Entity(e_type='ORG', start_offset=4, end_offset=6), Entity(e_type='ORG', start_offset=8, end_offset=9), Entity(e_type='ORG', start_offset=13, end_offset=14), Entity(e_type='ORG', start_offset=-1, end_offset=-1), Entity(e_type='ORG', start_offset=18, end_offset=19)]

True: 19801230089799I
[Entity(e_type='ORG', start_offset=-1, end_offset=-1), Entity(e_type='ORG', start_offset=12, end_offset=13), Entity(e_type='ORG', start_offset=17, end_offset=18), Entity(e_type='ORG', start_offset=12, end_offset=14)]

True: 19920405008919A
[Entity(e_type='ORG', start_offset=1, end_offset=2), Entity(e_type='ORG', start_offset=3, end_offset=5), Entity(e_type='ORG', start_offset=8, end_offset=11), Entity(e_type='ORG', start_offset=-1, end_offset=-1), Entity(e_type='ORG', start_offset=6, end_offset=7), Entity(e_type='ORG', start_offset=3, end_offset=7)]

True: 19950314029269I
[Entity(e_type='ORG', start_offset=3, end_offset=7), Enti

### Apply Example

In [80]:
conll_tags = ['PER', 'ORG', 'MISC', 'LOC']
ace_tags = ['PER','ORG','LOC','FAC','GPE'] # RESTRICTED SET
on_tags = ['PER','ORG','LOC','FAC','GPE','PRODUCT','NORP','QUANTITY','EVENT','WORK_OF_ART','CARDINAL','DATE','PERCENT','TIME','ORDINAL','MONEY','LAW','LANGUAGE']

In [169]:
tags = ace_tags

metrics_results = {'correct': 0, 'incorrect': 0, 'partial': 0,
                   'missed': 0, 'spurious': 0, 'possible': 0, 'actual': 0, 'precision': 0, 'recall': 0}

# overall results
results = {'strict': deepcopy(metrics_results),
           'ent_type': deepcopy(metrics_results),
           'partial':deepcopy(metrics_results),
           'exact':deepcopy(metrics_results)
          }


# results aggregated by entity type
evaluation_agg_entities_type = {e: deepcopy(results) for e in tags}

for true_ents, pred_ents in zip(all_true_ents, all_pred_ents):
    
    # compute results for one message
    tmp_results, tmp_agg_results = compute_metrics(
        true_ents, pred_ents,  tags
    )
    
    #print(tmp_results)

    # aggregate overall results
    for eval_schema in results.keys():
        for metric in metrics_results.keys():
            results[eval_schema][metric] += tmp_results[eval_schema][metric]
            
    # Calculate global precision and recall
        
    results = compute_precision_recall_wrapper(results)


    # aggregate results by entity type
 
    for e_type in tags:

        for eval_schema in tmp_agg_results[e_type]:

            for metric in tmp_agg_results[e_type][eval_schema]:
                
                evaluation_agg_entities_type[e_type][eval_schema][metric] += tmp_agg_results[e_type][eval_schema][metric]
                
        # Calculate precision recall at the individual entity level
                
        evaluation_agg_entities_type[e_type] = compute_precision_recall_wrapper(evaluation_agg_entities_type[e_type])

In [82]:
def print_results_labeled(tool_name, results):

    scores = {'exact':0.0,'strict':0.0,'partial':0.0,'ent_type':0.0}
    for score in scores:
        prec = results[score]['precision']
        rec = results[score]['recall']
        scores[score] = f"{2*prec*rec/(prec+rec):.4}" 

    print('|                                         | Strict  | Exact  | Partial  | Type    |')
    print('|-----------------------------------------|---------|--------|----------|---------|')
    print(f"| {tool_name:40}| {scores['strict']:8}| {scores['exact']:7}| {scores['partial']:9}| {scores['ent_type']:8}|")

In [83]:
def print_results_unlabeled(tool_name, results):
    scores = {'exact':0.0,'partial':0.0}
    for score in scores:
        prec = results[score]['precision']
        rec = results[score]['recall']
        scores[score] = {'prec':f"{prec:.4}", 'rec':f"{rec:.4}", 'f1':f"{(2*prec*rec/(prec+rec)):.4}"}

    print('|                                         | Precision (Weak) | Recall (Weak) | F1 (Weak)     | Precision (Strong) | Recall (Strong) | F1 (Strong) |')
    print('|-----------------------------------------|------------------|---------------|---------------|--------------------|-----------------|-------------|')
    print(f"| {tool_name:40}| {scores['partial']['prec']:17}| {scores['partial']['rec']:14}| {scores['partial']['f1']:14}| {scores['exact']['prec']:19}| {scores['exact']['rec']:16}| {scores['exact']['f1']:12}|")

In [84]:
print_results_unlabeled('nltk (uppercased)',results)

|                                         | Precision (Weak) | Recall (Weak) | F1 (Weak)     | Precision (Strong) | Recall (Strong) | F1 (Strong) |
|-----------------------------------------|------------------|---------------|---------------|--------------------|-----------------|-------------|
| nltk (uppercased)                       | 0.1245           | 0.4286        | 0.1929        | 0.0786             | 0.2707          | 0.1218      |


In [85]:
print_results_labeled('nltk (uppercased)',results)

|                                         | Strict  | Exact  | Partial  | Type    |
|-----------------------------------------|---------|--------|----------|---------|
| nltk (uppercased)                       | 0.01015 | 0.1218 | 0.1929   | 0.03723 |


In [86]:
results

{'ent_type': {'correct': 11,
  'incorrect': 67,
  'partial': 0,
  'missed': 55,
  'spurious': 380,
  'possible': 133,
  'actual': 458,
  'precision': 0.024017467248908297,
  'recall': 0.08270676691729323},
 'partial': {'correct': 36,
  'incorrect': 0,
  'partial': 42,
  'missed': 55,
  'spurious': 380,
  'possible': 133,
  'actual': 458,
  'precision': 0.12445414847161572,
  'recall': 0.42857142857142855},
 'strict': {'correct': 3,
  'incorrect': 75,
  'partial': 0,
  'missed': 55,
  'spurious': 380,
  'possible': 133,
  'actual': 458,
  'precision': 0.006550218340611353,
  'recall': 0.022556390977443608},
 'exact': {'correct': 36,
  'incorrect': 42,
  'partial': 0,
  'missed': 55,
  'spurious': 380,
  'possible': 133,
  'actual': 458,
  'precision': 0.07860262008733625,
  'recall': 0.2706766917293233}}

In [87]:
def orig_calculate_precision_recall_f1(gs, df_tool, id_col, ent_col, matching="STRONG"):
    """
    Calculate precision and recall based on entities comparison between gs (ground truth) and df_tool (answers).
    
    Parameters:
    - gs: DataFrame with columns ['id', 'sample', 'entities'] representing the ground truth.
    - df_tool: DataFrame with columns ['id', 'sample', 'entities', 'labels'] representing the tool's answers.
    
    Returns:
    - A tuple containing precision and recall.
    """
    TP = 0  # True Positives
    FP = 0  # False Positives
    FN = 0  # False Negatives
    
    # Check for True Positives and False Negatives by iterating over gs
    for index, gs_row in gs.dropna().iterrows():
        gs_id, gs_entity = gs_row['id'], gs_row['entities']
        tool_entities = [entity.upper() for entity in df_tool.loc[df_tool[id_col] == gs_id, ent_col].tolist()] # get all the entities the tool generated for the gs_id entry
        
        # In strong matching, we only count a tool-generated entity as correct if it exactly matches the gold standard entity
        if matching=="STRONG":
            if gs_entity in tool_entities:
                TP += 1
            else:
                FN += 1
        # In weak matching, we count a tool-generated entity as correct if it is a subspan of the gold standard entity, or if the gold standard entity is a subspan of it
        elif matching=="WEAK":
            if any(gs_entity in tool_entity or tool_entity in gs_entity for tool_entity in tool_entities):
                TP += 1
            else:
                FN += 1
        else:
            print("Error: Matching must = STRONG or WEAK")
            return None
    
    # Check for False Positives by iterating over df_tool
    for index, tool_row in df_tool[df_tool[id_col].isin(gs['id'].unique())].iterrows():
        tool_id, tool_entity = tool_row[id_col], tool_row[ent_col]
        gs_entities = gs.dropna().loc[gs.dropna()['id'] == tool_id, 'entities'].tolist()

        #  strong matching
        if matching=="STRONG":
            if tool_entity not in gs_entities:
                FP += 1
        else:
            if not any(tool_entity in gs_entity or gs_entity in tool_entity for gs_entity in gs_entities):
                FP += 1
    
    # Calculate precision and recall
    precision = TP / (TP + FP) if (TP + FP) > 0 else 0
    recall = TP / (TP + FN) if (TP + FN) > 0 else 0
    #print(f"TP={TP}, FP={FP}, FN={FN}, prec={precision}, rec={recall}")
    
    # Calculating the F1 score
    f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    
    return TP, FP, FN, precision, recall, f1_score

In [174]:
orig_calculate_precision_recall_f1(pd.read_csv('../../gold_standard/processed/ner_ace_nltk.csv'), pd.read_csv('../../data/results/nltk/nltk_ner_uppercased.csv'), id_col, 'entities', matching="WEAK")

(72, 374, 48, 0.16143497757847533, 0.6, 0.254416961130742)

In [89]:
def find_overlap(true_range, pred_range):
    """Find the overlap between two ranges

    Find the overlap between two ranges. Return the overlapping values if
    present, else return an empty set().

    Examples:

    >>> find_overlap((1, 2), (2, 3))
    2
    >>> find_overlap((1, 2), (3, 4))
    set()
    """

    true_set = set(true_range)
    pred_set = set(pred_range)

    overlaps = true_set.intersection(pred_set)

    return overlaps

In [143]:
def calculate_precision_recall_f1(all_true_ents, all_pred_ents, matching="STRONG"):
    """
    Calculate precision and recall based on entities comparison between gs (ground truth) and df_tool (answers).
    
    Parameters:
    - gs: DataFrame with columns ['id', 'sample', 'entities'] representing the ground truth.
    - df_tool: DataFrame with columns ['id', 'sample', 'entities', 'labels'] representing the tool's answers.
    
    Returns:
    - A tuple containing precision and recall.
    """
    COR = 0  # correct
    PAR = 0
    FP = 0  # False Positives
    FN = 0  # False Negatives

    # Go doc by doc:
    for true_ents, pred_ents in zip(all_true_ents, all_pred_ents):

        unmatched_preds = [(ent[1],ent[2]) for ent in pred_ents] # copy

        # Check CORs, PARs, and FNs by iterating over true ents
        for true_ent in true_ents:
            true_bounds = (true_ent[1],true_ent[2])
            found_match = False
            if true_bounds in unmatched_preds:
                COR +=1
                found_match = True
                unmatched_preds.remove(true_bounds)
            else:
                for pred_ent in unmatched_preds:
                    pred_bounds = (pred_ent[0],pred_ent[1])
                    overlaps = find_overlap(true_bounds, pred_bounds)
                    if len(overlaps) > 0:
                        PAR += 1
                        found_match = True
                        unmatched_preds.remove(pred_ent)
                        break
            if not found_match:
                FN += 1

        FP += len(unmatched_preds)
    
    # Calculate precision and recall
    if matching=="STRONG":
        TP = COR
        FP += PAR
        FN += PAR
    else:
        TP = PAR + COR
    precision = TP / (TP + FP) if (TP + FP) > 0 else 0
    recall = TP / (TP + FN) if (TP + FN) > 0 else 0
    #print(f"TP={TP}, FP={FP}, FN={FN}, prec={precision}, rec={recall}")
    
    # Calculating the F1 score
    f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    
    return COR, PAR, FP, FN, precision, recall, f1_score

In [175]:
calculate_precision_recall_f1(all_true_ents, all_pred_ents,matching="WEAK")

(40, 39, 379, 41, 0.17248908296943233, 0.6583333333333333, 0.27335640138408307)

In [176]:
davids = {'exact':{'correct':[],'incorrect':[],'partial':[],'missed':[],'spurious':[],'precision':[],'recall':[],'actual':[],'possible':[]},'partial':{'correct':[],'incorrect':[],'partial':[],'missed':[],'spurious':[],'precision':[],'recall':[],'actual':[],'possible':[]}}
my_new = {'strong':{'COR':[],'PAR':[],'FP':[],'FN':[],'precision':[],'recall':[]},'weak':{'COR':[],'PAR':[],'FP':[],'FN':[],'precision':[],'recall':[]}}
my_old = {'strong':{'TP':[],'FP':[],'FN':[],'precision':[],'recall':[]},'weak':{'TP':[],'FP':[],'FN':[],'precision':[],'recall':[]}}

for idoc, doc_id in enumerate(gold_df['id'].unique()):

    # David's way:
    true_ents = all_true_ents[idoc]
    pred_ents = all_pred_ents[idoc]
    tmp_results, tmp_agg_results = compute_metrics(
        true_ents, pred_ents,  tags
    )
    for type in ['exact','partial']:
        for key in tmp_results[type].keys():
            davids[type][key].append(tmp_results[type][key])

    # My way:
    gs = gold_df[gold_df['id']==doc_id].reset_index(drop=True)
    df_tool = result_df[result_df[id_col]==doc_id].reset_index(drop=True)
    for type in ['strong','weak']:
        TP, FP, FN, prec, rec, f1 = orig_calculate_precision_recall_f1(gs, df_tool, id_col,'entities',matching=type.upper())
        my_old[type]['TP'].append(TP)
        my_old[type]['FP'].append(FP)
        my_old[type]['FN'].append(FN)
        my_old[type]['precision'].append(prec)
        my_old[type]['recall'].append(rec)
        
        COR, PAR, FP, FN, prec, rec, f1 = calculate_precision_recall_f1([true_ents], [pred_ents],matching=type.upper())
        my_new[type]['COR'].append(COR)
        my_new[type]['PAR'].append(PAR)
        my_new[type]['FP'].append(FP)
        my_new[type]['FN'].append(FN)
        my_new[type]['precision'].append(prec)
        my_new[type]['recall'].append(rec)

**Compare**

TP = correct + partial\
FP = spurious + incorrect?\
FN = missed + incorrect?

In [121]:
sum(davids['exact']['correct'])

36

In [128]:
sum(my_old['strong']['TP'])

41

In [130]:
sum(my_new['strong']['COR'])

35

In [109]:
my['strong']['FP'][53]

3

In [108]:
davids['exact']['spurious'][53]

2

In [170]:
import numpy as np
np.mean(my['strong']['recall']) # macro

0.20616666666666666

In [173]:
doc_ids = gold_df['id'].unique()
old_strong = my_old['strong']['TP']
new_strong = my_new['strong']['COR']
davids_strong = davids['exact']['correct']
disc = map(lambda x: "" if x==True else "DISC", [old_strong[i] == new_strong[i] == davids_strong[i] for i in range(len(doc_ids))])

pd.DataFrame({'id':doc_ids,'gold':[gold_df[gold_df['id']==doc_id]['entities'].to_list() for doc_id in doc_ids],'pred':[result_df[result_df[id_col]==doc_id]['entities'].to_list() for doc_id in doc_ids],'My Old':old_strong,'My New':new_strong,'Davids':davids_strong, 'DISC':disc}).to_csv('check_correct.csv')

In [178]:
doc_ids = gold_df['id'].unique()
old_weak = my_old['weak']['TP']
new_weak = [my_new['weak']['COR'][i] + my_new['weak']['PAR'][i] for i in range(len(my_new['weak']['COR']))]
davids_weak = [davids['partial']['correct'][i] + davids['partial']['partial'][i] for i in range(len(davids['partial']['correct']))]
disc = map(lambda x: "" if x==True else "DISC", [old_weak[i] == new_weak[i] == davids_weak[i] for i in range(len(doc_ids))])

pd.DataFrame({'id':doc_ids,'gold':[gold_df[gold_df['id']==doc_id]['entities'].to_list() for doc_id in doc_ids],'pred':[result_df[result_df[id_col]==doc_id]['entities'].to_list() for doc_id in doc_ids],'My Old':old_weak,'My New':new_weak,'Davids':davids_weak, 'DISC':disc}).to_csv('check_correct.csv')

In [110]:
i = 53
tmp_results, tmp_agg_results = compute_metrics(
        all_true_ents[i], all_pred_ents[i],  tags
    )
tmp_results['exact']

{'correct': 2,
 'incorrect': 0,
 'partial': 0,
 'missed': 0,
 'spurious': 2,
 'precision': 0,
 'recall': 0,
 'actual': 4,
 'possible': 2}

In [111]:
all_true_ents[i]

[Entity(e_type='LOC', start_offset=1, end_offset=2),
 Entity(e_type='PER', start_offset=6, end_offset=7)]

In [112]:
all_pred_ents[i]

[Entity(e_type='ORG', start_offset=0, end_offset=1),
 Entity(e_type='ORG', start_offset=1, end_offset=2),
 Entity(e_type='ORG', start_offset=6, end_offset=7),
 Entity(e_type='ORG', start_offset=10, end_offset=11)]

In [113]:
gold_df[gold_df['id']==gold_df['id'].unique()[i]]

Unnamed: 0,id,sample,entities,labels
78,19861114075329I,ENTERED TCA WITHOUT ATC COMMUNICATION. PILOT W...,TCA,LOC
79,19861114075329I,ENTERED TCA WITHOUT ATC COMMUNICATION. PILOT W...,PILOT,PER


In [114]:
result_df[result_df[id_col]==gold_df['id'].unique()[i]]

Unnamed: 0,index,c5_unique_id,c119_text,entities,POS tags,labels
236,1117,19861114075329I,ENTERED TCA WITHOUT ATC COMMUNICATION. PILOT W...,ENTERED,NNP,ORG
237,1117,19861114075329I,ENTERED TCA WITHOUT ATC COMMUNICATION. PILOT W...,TCA,NNP,ORG
238,1117,19861114075329I,ENTERED TCA WITHOUT ATC COMMUNICATION. PILOT W...,PILOT,NNP,ORG
239,1117,19861114075329I,ENTERED TCA WITHOUT ATC COMMUNICATION. PILOT W...,MALFUNCTIONING,NNP,ORG


In [115]:
faa[gold_df['id'].unique()[i]]

{0: 'ENTERED',
 1: 'TCA',
 2: 'WITHOUT',
 3: 'ATC',
 4: 'COMMUNICATION',
 5: '/.',
 6: 'PILOT',
 7: 'WAS',
 8: 'AWARE',
 9: 'OF',
 10: 'MALFUNCTIONING',
 11: 'ENCODING',
 12: 'ALTIMETER',
 13: '/.'}