In [132]:
# Inputs for the evaluation below: the pickled label tree, the predictions
# file (read line by line) and the ground-truth file (one JSON record per line).
tree_path, predicted_path, truth_path = (
    "../NeuralNLP-NeuralClassifier-master/data/cpc_label_tree.pkl",
    "../../../../predict.txt",
    "../../../../test.json",
)

In [2]:
%matplotlib inline
import pickle
import copy
import json
import pandas as pd
import matplotlib.pyplot as plt

In [134]:
def add_confusion_to_dict(cur_dict, in_predicted, in_true):
    """Record one document's outcome for a single tree node.

    Ensures cur_dict carries all four counters ('tp', 'tn', 'fp', 'fn'),
    creating missing ones at 0, then increments exactly one of them
    according to the truthiness of in_predicted / in_true. Mutates
    cur_dict in place and returns None.
    """
    # Create every counter up front: a counter that is never incremented
    # would otherwise be absent from the node.
    for counter in ('tp', 'tn', 'fp', 'fn'):
        cur_dict.setdefault(counter, 0)

    if in_true:
        bucket = 'tp' if in_predicted else 'fn'
    else:
        bucket = 'fp' if in_predicted else 'tn'
    cur_dict[bucket] += 1
        
def update_metrics_tracker(metrics_tracker, index, cur_dict):
    """Fold one node's confusion counts into the tracker for its tree level.

    Args:
        metrics_tracker: list of per-level dicts, each holding integer
            'tp'/'tn'/'fp'/'fn' totals and 'precisions'/'recalls'/'f1s' lists.
        index: position of the node's level in metrics_tracker.
        cur_dict: the node's dict; its 'tp'/'tn'/'fp'/'fn' entries are read.
            Missing counters are treated as 0 rather than raising KeyError
            (e.g. a node that never received a confusion update).

    Side effects:
        Adds the node's counts to the level totals (used for micro averaging)
        and appends the node's own precision, recall and F1 to the level's
        lists (used for macro averaging).
    """
    counts = {metric: cur_dict.get(metric, 0) for metric in ('tp', 'tn', 'fp', 'fn')}
    level = metrics_tracker[index]
    for metric, value in counts.items():
        level[metric] += value

    tp, fp, fn = counts['tp'], counts['fp'], counts['fn']

    # https://github.com/dice-group/gerbil/wiki/Precision,-Recall-and-F1-measure
    # If tp, fp and fn are all 0 (nothing to find and nothing predicted), the
    # precision, recall and F1-measure are defined as 1. If tp is 0 and either
    # fp or fn is larger than 0, all three are defined as 0.
    if tp == 0 and fp == 0 and fn == 0:
        precision = recall = f1 = 1
    elif tp == 0:
        precision = recall = f1 = 0
    else:
        # tp > 0 here, so neither denominator can be zero.
        precision = tp / (tp + fp)
        recall = tp / (tp + fn)
        f1 = (2 * precision * recall) / (precision + recall)

    level['precisions'].append(precision)
    level['recalls'].append(recall)
    level['f1s'].append(f1)

    
def pretty_print_describe(stats):
    """Print the mean, std, min, median ('50%') and max entries of stats.

    stats is indexed by those key names (e.g. the result of
    pandas.Series.describe()); each entry is printed on its own indented line.
    """
    template = "      {}:\t{}"
    for key in ('mean', 'std', 'min', '50%', 'max'):
        print(template.format(key, stats[key]))

In [135]:
# Counter keys that add_confusion_to_dict stores on each tree node alongside
# the node's children; they must be skipped when walking the tree.
_CONFUSION_KEYS = ('tp', 'tn', 'fp', 'fn')


def _labels_by_level(labels, n_levels):
    """Split '--'-separated labels into one set of codes per tree level."""
    levels = [set() for _ in range(n_levels)]
    for label in labels:
        # Parts deeper than n_levels are never evaluated; truncating avoids
        # indexing past the end of `levels` for unusually deep labels.
        for depth, part in enumerate(label.split("--")[:n_levels]):
            levels[depth].add(part)
    return levels


def _iter_tree(node, n_levels, depth=0):
    """Yield (depth, code, child_dict) for every node, pre-order, down to n_levels."""
    if depth >= n_levels:
        return
    for code, child in node.items():
        if code in _CONFUSION_KEYS:
            continue
        yield depth, code, child
        yield from _iter_tree(child, n_levels, depth + 1)


def _prf_from_counts(tp, fp, fn):
    """Return (precision, recall, f1) using the GERBIL zero-count conventions."""
    if tp == 0:
        # No true positives: all-zero counts score 1 (nothing to find, nothing
        # predicted); any fp/fn scores 0. This also avoids division by zero.
        score = 1 if fp == 0 and fn == 0 else 0
        return score, score, score
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return precision, recall, (2 * precision * recall) / (precision + recall)


def print_evaluation_results(tree_path, predicted_path, truth_path, n_levels=3):
    """Print per-level macro and micro precision/recall/F1 over a label tree.

    Args:
        tree_path: pickle file holding a dict whose 'Root' entry maps each
            code to a dict of child codes (nested per level). NOTE: pickle.load
            can execute arbitrary code -- only load trusted files.
        predicted_path: text file with one prediction per line; labels are
            separated by ';' and levels inside a label by '--'.
        truth_path: file with one JSON record per line, each carrying a
            "doc_label" list in the same '--' format. Line i is paired with
            line i of predicted_path.
        n_levels: number of tree levels to evaluate and report (default 3).

    Raises:
        ValueError: if the two files do not have the same number of records.
    """
    with open(tree_path, 'rb') as f:
        tree_dict = pickle.load(f)
    root_dict = tree_dict['Root']

    # Pass 1: accumulate a confusion matrix on every node, one document at a time.
    with open(predicted_path, 'r', encoding='utf-8') as predicted_file, \
            open(truth_path, 'r', encoding='utf-8') as truth_file:
        for line_no, predicted in enumerate(predicted_file, start=1):
            truth = next(truth_file, None)
            if truth is None:
                raise ValueError("{} has more lines than {} (line {})".format(
                    predicted_path, truth_path, line_no))

            # Drop empty strings so a blank prediction line means "no labels".
            predicted_labels = [label for label in predicted.strip().split(";") if label]
            true_labels = json.loads(truth)["doc_label"]

            predicted_levels = _labels_by_level(predicted_labels, n_levels)
            true_levels = _labels_by_level(true_labels, n_levels)

            for depth, code, node in _iter_tree(root_dict, n_levels):
                add_confusion_to_dict(node,
                                      code in predicted_levels[depth],
                                      code in true_levels[depth])

        # Leftover (non-blank) truth records would otherwise be silently ignored.
        if any(line.strip() for line in truth_file):
            raise ValueError("{} has more records than {}".format(truth_path, predicted_path))

    # Pass 2: fold every node's counts into its level's totals and macro lists.
    metrics_tracker = [{'tp': 0, 'fp': 0,
                        'tn': 0, 'fn': 0,
                        'precisions': [], 'recalls': [], 'f1s': []} for _ in range(n_levels)]
    for depth, _code, node in _iter_tree(root_dict, n_levels):
        update_metrics_tracker(metrics_tracker, depth, node)

    for depth, level_metrics in enumerate(metrics_tracker):
        print("level {}:".format(depth + 1))
        print("  macro:")
        for title, key in (("precision", 'precisions'), ("recall", 'recalls'), ("f1", 'f1s')):
            print("    {}:".format(title))
            # dtype=float keeps describe() numeric (with 'mean', 'std', ...)
            # even for a level that has no nodes.
            pretty_print_describe(pd.Series(level_metrics[key], dtype=float).describe())
        print("  micro:")

        precision, recall, f1 = _prf_from_counts(
            level_metrics['tp'], level_metrics['fp'], level_metrics['fn'])

        print("    precision:\t{}".format(precision))
        print("    recall:\t{}".format(recall))
        print("    f1:\t\t{}".format(f1))
        print("")
        print("")

In [136]:
print_evaluation_results(tree_path=tree_path, predicted_path=predicted_path, truth_path=truth_path)

level 1:
  macro:
    precision:
      mean:	0.18277834924597014
      std:	0.0950387671537092
      min:	0.053029507508138195
      50%:	0.17889946490189867
      max:	0.32995907458448176
    recall:
      mean:	0.9623123793287331
      std:	0.08405541456498014
      min:	0.7448377581120944
      50%:	0.996999322427645
      max:	1.0
    f1:
      mean:	0.29911400546192374
      std:	0.13534696721502898
      min:	0.09900990099009903
      50%:	0.30349391287134314
      max:	0.49619432791579793
  micro:
    precision:	0.20764828182306633
    recall:	0.9946957319666078
    f1:		0.343573648324804

level 2:
  macro:
    precision:
      mean:	0.04439217424789303
      std:	0.0901732811172928
      min:	0.0
      50%:	0.027574223766714513
      max:	1.0
    recall:
      mean:	0.7086050862040237
      std:	0.257600501579483
      min:	0.0
      50%:	0.7824427480916031
      max:	1.0
    f1:
      mean:	0.07595782480970638
      std:	0.09813426037780547
      min:	0.0
      50%:	0.05286940

In [137]:
print_evaluation_results(tree_path,
                         predicted_path="../../../../model2_predict.txt",
                         truth_path=truth_path)

level 1:
  macro:
    precision:
      mean:	0.8586884618775475
      std:	0.10046112566502963
      min:	0.5981203007518797
      50%:	0.891543178334185
      max:	0.9264069264069265
    recall:
      mean:	0.5688828683547849
      std:	0.17636600592400353
      min:	0.2964412148313769
      50%:	0.5916878172588832
      max:	0.7911487355336476
    f1:
      mean:	0.6731226963512075
      std:	0.15188459541653226
      min:	0.39641210913168057
      50%:	0.7125119388729703
      max:	0.8265323257766582
  micro:
    precision:	0.8600873600873601
    recall:	0.6345985989570511
    f1:		0.7303343292808572

level 2:
  macro:
    precision:
      mean:	0.7587013609338736
      std:	0.2165222688309608
      min:	0.0
      50%:	0.803448275862069
      max:	1.0
    recall:
      mean:	0.3543226008506042
      std:	0.21700703386334672
      min:	0.0
      50%:	0.3333333333333333
      max:	1.0
    f1:
      mean:	0.4535701773810979
      std:	0.22725942775417868
      min:	0.0
      50%:	0.462

In [138]:
print_evaluation_results(tree_path,
                         predicted_path="../../../../model2_predict_all.txt",
                         truth_path=truth_path)

level 1:
  macro:
    precision:
      mean:	0.6927069460817629
      std:	0.10431756023040202
      min:	0.4500542888165038
      50%:	0.7118055555555556
      max:	0.8060626549470363
    recall:
      mean:	0.7862272868566605
      std:	0.13090079065086824
      min:	0.5406185951183157
      50%:	0.8166787644164852
      max:	0.9231675953707672
    f1:
      mean:	0.734307782406566
      std:	0.10984436409922531
      min:	0.4911968850516336
      50%:	0.7443631039531479
      max:	0.8307781649245064
  micro:
    precision:	0.6993252514928693
    recall:	0.8269398625814104
    f1:		0.7577974783495955

level 2:
  macro:
    precision:
      mean:	0.6513888507830401
      std:	0.20501617252990478
      min:	0.0
      50%:	0.6756756756756757
      max:	1.0
    recall:
      mean:	0.4343691284571295
      std:	0.2357339277919186
      min:	0.0
      50%:	0.44397759103641454
      max:	1.0
    f1:
      mean:	0.4935344768787844
      std:	0.22264662083466916
      min:	0.0
      50%:	0.50

In [139]:
print_evaluation_results(tree_path,
                         predicted_path="../../../../predict_n100.txt",
                         truth_path=truth_path)

level 1:
  macro:
    precision:
      mean:	0.16909459840873053
      std:	0.10466109926800289
      min:	0.02552974214960429
      50%:	0.17223500383448367
      max:	0.32921666666666666
    recall:
      mean:	0.985401793804436
      std:	0.037980061941852196
      min:	0.8849557522123894
      50%:	1.0
      max:	1.0
    f1:
      mean:	0.2771141980498963
      std:	0.15302654835091023
      min:	0.04962779156327544
      50%:	0.29385746590246464
      max:	0.49535440673078124
  micro:
    precision:	0.18355077367350423
    recall:	0.9986347664555404
    f1:		0.31010391817204097

level 2:
  macro:
    precision:
      mean:	0.029591138316574467
      std:	0.08979679541176429
      min:	0.0
      50%:	0.015211879753712423
      max:	1.0
    recall:
      mean:	0.8241816900658034
      std:	0.20664932680477244
      min:	0.0
      50%:	0.8972602739726028
      max:	1.0
    f1:
      mean:	0.049283685323533166
      std:	0.09579173246168839
      min:	0.0
      50%:	0.0297471129951761