# Overall Pipeline Evaluation

This notebook evaluates the performance of the overall Chemimap pipeline by comparing its predicted positive labels with the gold labels. This run covers the 500 examples in the dev set.

## 1. Imports

In [1]:
import csv

In [2]:
## Helper: split a flat list of fields into fixed-size records
def chunks(l, n):
    """
    Yield successive n-sized chunks from l.
    Raises AssertionError if len(l) is not a multiple of n.
    """
    for i in range(0, len(l), n):
        chunk = l[i:i + n]
        assert len(chunk) == n, "record length is not a multiple of n"
        yield chunk

## 2. Loading Results

In [30]:
## Extracting the gold positive labels (dev set)
test_set_goldfile = "./data_scibert_version/test_filter.data"    # not used in this run
train_set_goldfile = "./data_scibert_version/train_filter.data"  # not used in this run
dev_set_goldfile = "./data_scibert_version/dev_filter.data"
pos_label_count = 0
pos_goldlabel = {}  # PMID -> list of {'Chemical': names, 'Disease': names}

with open(dev_set_goldfile, 'r') as f:
    lines = [l.strip().split('\t') for l in f]

    for l in lines:
        pmid = l[0]

        # After the first two fields, each relation record spans 17 fields:
        # r[0] = label, r[6]/r[7] = chemical mention names ('|'-separated) / entity type,
        # r[12]/r[13] = disease mention names / entity type
        for r in chunks(l[2:], 17):
            if r[0] == '1:CID:2':
                assert r[7] == 'Chemical' and r[13] == 'Disease'
                entry = {'Chemical': set(r[6].split('|')),
                         'Disease': set(r[12].split('|'))}
                pos_goldlabel.setdefault(pmid, []).append(entry)
                pos_label_count += 1

print(f"Gold labels loading completed.  Number of positive examples: {pos_label_count}")
print(f"\nShowing the extracted gold positive relationships for the 3rd example:\n{pos_goldlabel[list(pos_goldlabel.keys())[2]]}")

Gold labels loading completed.  Number of positive examples: 1012

Showing the extracted gold positive relationships for the 3rd example:
[{'Chemical': {'NH3', 'Ammonia', 'ammonia'}, 'Disease': {'drowsiness'}}, {'Chemical': {'valproic acid', 'Valproic acid', 'VPA'}, 'Disease': {'drowsiness'}}]
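The parsing cell above implies a record layout: each line holds the PMID, a second field the notebook skips, and then a flat run of relation records of 17 tab-separated fields each, where field 0 is the label, fields 6/7 the chemical mention names and type, and fields 12/13 the disease mention names and type. The sketch below restates that logic as a reusable function and runs it on a synthetic line. The function names, the `x` placeholders in the fields the notebook never reads, and the `1:NR:2` negative label are hypothetical.

```python
from typing import Dict, List, Set

RECORD_LEN = 17  # fields per relation record, as assumed by the notebook


def chunks(seq, n):
    """Yield successive n-sized chunks from seq; fail loudly on a ragged tail."""
    for i in range(0, len(seq), n):
        chunk = seq[i:i + n]
        assert len(chunk) == n, "record length is not a multiple of n"
        yield chunk


def parse_gold_line(line: str) -> List[Dict[str, Set[str]]]:
    """Return the positive chemical-disease pairs encoded in one .data line."""
    fields = line.strip().split('\t')
    positives = []
    for r in chunks(fields[2:], RECORD_LEN):
        if r[0] == '1:CID:2':
            assert r[7] == 'Chemical' and r[13] == 'Disease'
            positives.append({'Chemical': set(r[6].split('|')),
                              'Disease': set(r[12].split('|'))})
    return positives


def make_record(label, chem_names, dis_names):
    """Build a synthetic 17-field record; unread fields are 'x' placeholders."""
    r = ['x'] * RECORD_LEN
    r[0], r[6], r[7], r[12], r[13] = label, chem_names, 'Chemical', dis_names, 'Disease'
    return r


# Hypothetical line: PMID, a placeholder second field, then two relation records.
line = '\t'.join(['123', 'doc']
                 + make_record('1:CID:2', 'VPA|valproic acid', 'drowsiness')
                 + make_record('1:NR:2', 'NH3', 'epileptic'))
print(parse_gold_line(line))
```

Only the first record is kept, since the second carries a non-`CID` label.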


In [31]:
## Extracting the predicted positive relationships from the PubTator-format output
test_set_predfile = "./overall_results/full_set_results.pubtator"
pos_pred_count = 0
pos_pred = {}          # PMID -> list of {'Entity1': names, 'Entity2': names}
entities_mapping = {}  # PMID -> {MeSH ID: [mention names]}

with open(test_set_predfile, 'r') as f:
    lines = [l.strip().split('\t') for l in f]

    for l in lines:
        pmid = l[0]
        # Keep only documents in the gold set (this also skips title/abstract lines, which contain no tab)
        if pmid not in pos_goldlabel:
            continue

        doc_entities = entities_mapping.setdefault(pmid, {})

        # Entity annotation line: PMID, start, end, mention, type, MeSH ID
        if len(l) == 6:
            MeSHID = l[5]
            entity_name = l[3]
            doc_entities.setdefault(MeSHID, []).append(entity_name)

        # Predicted relation line: PMID, CID, ID1, ID2, 'predict'.
        # Relation lines always follow all entity lines of the same document,
        # so the MeSH ID lookup below sees the complete mapping.
        if len(l) == 5:
            assert l[1] == "CID", f"Error found in example with PubMed ID {pmid}: Not CID"
            assert l[4] == 'predict', f"Error found in example with PubMed ID {pmid}: Not prediction"

            entry = {'Entity1': set(doc_entities[l[2]]),
                     'Entity2': set(doc_entities[l[3]])}
            pos_pred.setdefault(pmid, []).append(entry)
            pos_pred_count += 1
        
print(f"Prediction loading completed.  Number of positive predictions: {pos_pred_count}")
print(f"\nShowing the extracted predicted positive relationships for the 3rd example:\n{pos_pred[list(pos_pred.keys())[2]]}")      
        

Prediction loading completed.  Number of positive predictions: 862

Showing the extracted predicted positive relationships for the 3rd example:
[{'Entity1': {'NH3', 'Ammonia', 'ammonia'}, 'Entity2': {'epileptic'}}, {'Entity1': {'NH3', 'Ammonia', 'ammonia'}, 'Entity2': {'drowsiness'}}]
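The prediction file is read as PubTator-style text: title and abstract lines (`PMID|t|...`, `PMID|a|...`) contain no tab and are skipped, 6-field lines are entity annotations, and 5-field lines are `CID` relations tagged `predict`. The sketch below repackages that logic as a function and runs it on a made-up three-line snippet; the function name and the placeholder MeSH IDs are hypothetical. One deliberate difference from the cell above: it collects all annotations in a first pass and resolves relations in a second, so it no longer relies on relation lines following the entity list.

```python
import io
from collections import defaultdict
from typing import Dict, Iterable, List, Set


def parse_predictions(lines: Iterable[str], keep_pmids: Set[str]) -> Dict[str, List[dict]]:
    """Map each kept PMID to its predicted CID pairs, each entity given as a set of mention names."""
    rows = [l.strip().split('\t') for l in lines]

    # Pass 1: collect MeSH ID -> mention names for every kept document.
    mentions: Dict[str, Dict[str, Set[str]]] = defaultdict(lambda: defaultdict(set))
    for r in rows:
        if len(r) == 6 and r[0] in keep_pmids:
            mentions[r[0]][r[5]].add(r[3])

    # Pass 2: resolve each predicted relation through the completed mapping,
    # so relation lines may appear anywhere in the file.
    preds: Dict[str, List[dict]] = defaultdict(list)
    for r in rows:
        if len(r) == 5 and r[0] in keep_pmids:
            pmid, rel, id1, id2, tag = r
            assert rel == 'CID' and tag == 'predict', f"Unexpected relation line for PubMed ID {pmid}"
            known = mentions[pmid]
            if id1 not in known or id2 not in known:
                raise KeyError(f"PubMed ID {pmid}: relation references an unannotated MeSH ID")
            preds[pmid].append({'Entity1': set(known[id1]), 'Entity2': set(known[id2])})
    return dict(preds)


# Made-up PubTator-style snippet; MESH_CHEM and MESH_DIS are placeholder IDs.
snippet = (
    "123|t|Ammonia and drowsiness\n"
    "123\t0\t7\tAmmonia\tChemical\tMESH_CHEM\n"
    "123\t12\t22\tdrowsiness\tDisease\tMESH_DIS\n"
    "123\tCID\tMESH_CHEM\tMESH_DIS\tpredict\n"
)
print(parse_predictions(io.StringIO(snippet), keep_pmids={'123'}))
# {'123': [{'Entity1': {'Ammonia'}, 'Entity2': {'drowsiness'}}]}
```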


In [32]:
## Sanity check: every gold PMID should have at least one positive prediction
for pmid in pos_goldlabel:
    if pmid not in pos_pred:
        print(f"PubMed ID {pmid} not found in predictions")

In [33]:
len(pos_goldlabel), len(pos_pred)  # PMIDs with at least one gold / predicted positive

(500, 500)
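The two cells above check coverage in one direction only. A set-based variant, sketched below with a hypothetical `coverage_gaps` helper and toy dictionaries standing in for `pos_goldlabel` and `pos_pred`, reports gaps in both directions at once. Predicted-only PMIDs cannot occur in this notebook, because the loader skips documents absent from the gold file, but the second list becomes useful if that filter is ever removed.

```python
from typing import Dict, List, Tuple


def coverage_gaps(gold: Dict[str, list], pred: Dict[str, list]) -> Tuple[List[str], List[str]]:
    """Return (gold PMIDs with no positive prediction, predicted PMIDs absent from gold)."""
    gold_ids, pred_ids = set(gold), set(pred)
    return sorted(gold_ids - pred_ids), sorted(pred_ids - gold_ids)


# Toy inputs standing in for pos_goldlabel / pos_pred.
gold = {'1': [{}], '2': [{}], '3': [{}]}
pred = {'1': [{}], '3': [{}], '4': [{}]}
missing, extra = coverage_gaps(gold, pred)
print(missing, extra)  # ['2'] ['4']
```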

## 3. Actual Evaluation
- Any differences in the NER and MeSH-mapping steps change which chemical-disease pairs the RE model is asked to classify, so the pipeline and the gold labels are not evaluated over the same set of candidate pairs. It is therefore not possible to compare all positive and negative predictions directly using the conventional definitions of precision and recall, as there is no single common candidate set to serve as the denominator.
- Instead, as our focus is on the positive labels only (N.B. negative labels are not shown in the app, nor are they of concern to users), we adopt the following altered definitions:
- ***Precision***: number of correct positive predictions / number of positive predictions made
- ***Recall***: number of correct positive predictions / number of positive gold labels
- ***F1***: the harmonic mean of the altered precision and recall above
- A positive prediction counts as ***correct*** if, for some gold pair in the same document, it shares at least one mention name with the gold chemical and at least one with the gold disease, in either entity order.
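Writing $C$ for the number of correct positive predictions, $N_{\text{pred}}$ for the number of positive predictions and $N_{\text{gold}}$ for the number of positive gold labels, the altered definitions read as follows; substituting $P$ and $R$ into the harmonic mean (for $C > 0$) shows that F1 depends only on the three counts:

$$
P = \frac{C}{N_{\text{pred}}}, \qquad
R = \frac{C}{N_{\text{gold}}}, \qquad
F_1 = \frac{2PR}{P + R} = \frac{2C}{N_{\text{pred}} + N_{\text{gold}}}
$$

With the counts obtained in this notebook ($C = 343$, $N_{\text{pred}} = 862$, $N_{\text{gold}} = 1012$), this gives $686 / 1874 \approx 0.3661$, matching the F1 reported in the final cell.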

In [34]:
## Counting correct positive predictions.
## A prediction is correct if it shares at least one mention name with the chemical
## AND at least one with the disease of the same gold pair, in either entity order.
## Each prediction is counted at most once.
correct_preds = 0

for pmid, preds in pos_pred.items():
    golds = pos_goldlabel.get(pmid)
    if golds is None:
        continue

    for pred in preds:
        entity1, entity2 = pred['Entity1'], pred['Entity2']
        for gold in golds:
            gold_chem, gold_dis = gold['Chemical'], gold['Disease']
            if (entity1 & gold_chem and entity2 & gold_dis) or \
               (entity2 & gold_chem and entity1 & gold_dis):
                correct_preds += 1
                break  # stop at the first matching gold pair

print(f"Completed.  Number of correct positive predictions: {correct_preds}")

Completed.  Number of correct positive predictions: 343
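Because the same count `correct_preds` is reused as the recall numerator, two predictions that both match one gold pair (for instance, if normalization had split `VPA` and `valproic acid` into separate MeSH IDs) would each be credited, so recall may overstate the share of gold pairs actually recovered. The sketch below, with hypothetical helper names and toy data, factors out the matching predicate and additionally counts distinct gold pairs matched by any prediction, which is one possible stricter recall numerator. It is not what this notebook computes, and no result for the dev set is implied.

```python
from typing import Dict, List, Set, Tuple

Pair = Dict[str, Set[str]]


def matches(pred: Pair, gold: Pair) -> bool:
    """True if the prediction overlaps the gold chemical and disease, in either order."""
    e1, e2 = pred['Entity1'], pred['Entity2']
    chem, dis = gold['Chemical'], gold['Disease']
    return bool((e1 & chem and e2 & dis) or (e2 & chem and e1 & dis))


def count_matches(pos_pred: Dict[str, List[Pair]],
                  pos_gold: Dict[str, List[Pair]]) -> Tuple[int, int]:
    """Return (correct predictions, distinct gold pairs matched by any prediction)."""
    correct_preds, matched_golds = 0, 0
    for pmid, golds in pos_gold.items():
        preds = pos_pred.get(pmid, [])
        correct_preds += sum(any(matches(p, g) for g in golds) for p in preds)
        matched_golds += sum(any(matches(p, g) for p in preds) for g in golds)
    return correct_preds, matched_golds


# Toy data: two predictions hit the same gold pair, one prediction hits nothing.
gold = {'1': [{'Chemical': {'VPA', 'valproic acid'}, 'Disease': {'drowsiness'}},
              {'Chemical': {'ammonia'}, 'Disease': {'coma'}}]}
pred = {'1': [{'Entity1': {'VPA'}, 'Entity2': {'drowsiness'}},
              {'Entity1': {'drowsiness'}, 'Entity2': {'valproic acid'}},
              {'Entity1': {'ammonia'}, 'Entity2': {'epilepsy'}}]}
print(count_matches(pred, gold))  # (2, 1)
```

Here the first number corresponds to how the notebook counts correct predictions, while the second would credit the duplicated gold pair only once.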


In [35]:
## Score calculations:
precision = correct_preds / pos_pred_count
recall = correct_preds / pos_label_count
f1 = 2 * (precision * recall) / (precision + recall)

print(f"Results:\n- Precision: {round(precision*100, 2)}\n- Recall: {round(recall*100, 2)}\n- F1: {round(f1*100, 2)}")

Results:
- Precision: 39.79
- Recall: 33.89
- F1: 36.61
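The score cell above raises `ZeroDivisionError` if a run produces no positive predictions, or no correct ones (then `precision + recall` is zero). A small reusable wrapper, sketched below under the hypothetical name `prf1`, falls back to 0 on empty denominators and reproduces the reported figures from the notebook's counts.

```python
from typing import Tuple


def prf1(correct: int, n_pred: int, n_gold: int) -> Tuple[float, float, float]:
    """Altered precision, recall and F1 as percentages, defaulting to 0 on empty denominators."""
    precision = correct / n_pred if n_pred else 0.0
    recall = correct / n_gold if n_gold else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return tuple(round(x * 100, 2) for x in (precision, recall, f1))


print(prf1(343, 862, 1012))  # (39.79, 33.89, 36.61)
```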
