# Requirements

In [None]:
from sklearn import metrics
import tqdm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from statistics import mean

# Global Vars

In [None]:
# Dataset folder name interpolated into the prediction CSV paths read below.
DATASET = "commonsense"
# Number of class-probability columns; they are addressed as "0" .. str(NUM_LABELS - 1).
NUM_LABELS = 5
# Epoch selector: the evaluation loop iterates range(EPOCH_START - 1, 5).
EPOCH_START = 5

# Function Defs

In [None]:
def get_risk_coverage_info(prob_list, em_list):
    """Build a risk-coverage curve by ranking examples from most to least confident.

    Args:
        prob_list: Per-example confidence scores.
        em_list: Per-example correctness flags (truthy means correct),
            in the same order as ``prob_list``.

    Returns:
        A 6-tuple ``(risks, coverages, auc, sorted_sources, sorted_em,
        sorted_probs)``:
        * ``risks[k]`` is ``1 - correct/covered`` over the ``k + 1`` most
          confident examples;
        * ``coverages[k]`` is ``(k + 1) / total``;
        * ``auc`` is ``metrics.auc(coverages, risks)`` rounded to 4 decimals;
        * ``sorted_sources`` tags each ranked example with 0 if it came from
          the first half of the input (by position) and 1 otherwise;
        * ``sorted_em`` / ``sorted_probs`` are the inputs in ranked order.
    """
    # Tag every example by which half of the input it sits in.
    first_half = len(prob_list) // 2
    second_half = len(prob_list) - first_half
    sources = [0] * first_half + [1] * second_half
    assert len(sources) == len(prob_list)

    # Rank by descending confidence; negating the key keeps ties in input order.
    ranked = sorted(zip(prob_list, em_list, sources), key=lambda row: -row[0])
    sorted_probs = [row[0] for row in ranked]
    sorted_em = [row[1] for row in ranked]
    sorted_sources = [row[2] for row in ranked]

    total_questions = len(sorted_em)
    risks, coverages = [], []
    n_correct = 0
    # Walk down the ranking, growing the covered set one example at a time.
    for n_covered, is_correct in enumerate(sorted_em, start=1):
        if is_correct:
            n_correct += 1
        risks.append(1 - (n_correct / n_covered))
        coverages.append(n_covered / total_questions)

    auc = round(metrics.auc(coverages, risks), 4)
    return risks, coverages, auc, sorted_sources, sorted_em, sorted_probs

def get_coverage_cutoff(risks, accuracy_cutoff):
    """Return the length of the longest prefix of ``risks`` ending below the risk bound.

    Scans ``risks`` from the end and drops trailing entries whose risk is at
    least ``1.0 - accuracy_cutoff``.

    Args:
        risks: Sequence of risk values (indexable, supports ``len``).
        accuracy_cutoff: Required accuracy as a fraction; the risk bound is
            ``1.0 - accuracy_cutoff``.

    Returns:
        The count of leading entries kept, i.e. one past the index of the last
        entry whose risk is strictly below the bound. Returns 0 when ``risks``
        is empty or no entry is below the bound.
    """
    max_risk = 1.0 - accuracy_cutoff
    index = len(risks)
    # Test the bound first so an empty (or fully consumed) sequence never
    # evaluates risks[-1].
    while index > 0 and risks[index - 1] >= max_risk:
        index -= 1
    return index

def auc_show(probs, correct, plot_graph=False):
    """Compute the risk-coverage AUC, as a percentage, for one set of predictions.

    Args:
        probs: Per-example confidence scores.
        correct: Per-example correctness flags, in the same order as ``probs``.
        plot_graph: When true, plot risk against coverage with matplotlib.

    Returns:
        A tuple ``(auc_percent, coverages, risks)`` where ``auc_percent`` is
        the AUC from ``get_risk_coverage_info`` multiplied by 100 and rounded
        to 2 decimals, and ``coverages`` / ``risks`` are the curve points that
        the same call returned.
    """
    risks, coverages, auc, _, _, _ = get_risk_coverage_info(probs, correct)

    if plot_graph:
        plt.plot(coverages, risks)
        plt.show()

    # Wrap in np.float64 so the rounded result is always a NumPy scalar,
    # whatever scalar type get_risk_coverage_info hands back.
    return round(100 * np.float64(auc), 2), coverages, risks


def get_auc(df, plot_graph=False):
    """Summarise accuracy and risk-coverage AUC for a predictions frame.

    Args:
        df: DataFrame with a ``"maxProb"`` column (confidence per example) and
            a ``"correct"`` column (correctness flag per example).
        plot_graph: Forwarded to ``auc_show`` for the achieved curve only.

    Returns:
        ``(accuracy, achieved_auc, min_auc, coverages, risks)``:
        * ``accuracy`` is the mean of ``"correct"`` as a percentage, rounded
          to 2 decimals;
        * ``achieved_auc``, ``coverages`` and ``risks`` come from ``auc_show``
          on the actual confidences;
        * ``min_auc`` comes from ``auc_show`` on a reference ranking in which
          the sorted correctness flags are paired with strictly increasing
          confidences ``i / (n + 1)``.
    """
    correct_flags = df["correct"]
    accuracy = round(100 * correct_flags.mean(), 2)

    achieved_auc, coverages, risks = auc_show(
        list(df["maxProb"]), list(correct_flags), plot_graph=plot_graph
    )

    # Reference ranking: confidences 1/(n+1) .. n/(n+1) matched against the
    # correctness flags in ascending sorted order.
    dev_len = len(df) + 1
    ideal_probs_list = np.arange(1, dev_len) / dev_len
    em_list = np.sort(df["correct"])
    min_auc = auc_show(list(ideal_probs_list), list(em_list))[0]

    return (accuracy, achieved_auc, min_auc, coverages, risks)

# Calculate Accuracy and Selective Prediction AUC


In [None]:

results = []
num_labels = 5

# Probability columns in each predictions CSV are named "0" .. str(NUM_LABELS - 1).
cols = [str(label) for label in range(NUM_LABELS)]
segments = ("baseline", "easy", "amb", "hard", "mixed")

# One result row per (segment, epoch): accuracy plus achieved risk-coverage AUC.
for segment in segments:
    for epoch in range(EPOCH_START - 1, 5):
        csv_path = f"/content/drive/MyDrive/NLP/data/{DATASET}/segments/{segment}/epoch_{epoch}_predictions.csv"
        df = pd.read_csv(csv_path)
        # Confidence is the largest class probability; correctness compares
        # the predicted label with the gold label.
        df["maxProb"] = df[cols].max(axis=1)
        df["correct"] = df["label"] == df["prediction"]
        accuracy, achieved_auc, min_auc, coverages, risks = get_auc(df)
        results.append([segment, epoch, accuracy, achieved_auc])

acc = pd.DataFrame(
    results,
    columns=["segment", "epoch", "accuracy", "selective_pred"],
)

In [None]:
# Bare expression: in a notebook cell this renders the `acc` results frame.
acc

Unnamed: 0,segment,epoch,accuracy,selective_pred
0,baseline,4,38.9,51.78
1,easy,4,51.35,34.71
2,amb,4,34.07,61.3
3,hard,4,39.89,53.45
4,mixed,4,49.39,39.37


In [None]:
# Load the epoch-4 predictions of the "easy" segment for manual inspection.
df = pd.read_csv(f"/content/drive/MyDrive/NLP/data/{DATASET}/segments/easy/epoch_4_predictions.csv")
# Bare expression: in a notebook cell this renders the loaded frame.
df
# Disabled scratch check of the probability column names:
# cols = list(map(str, list(range(NUM_LABELS))))
# cols

Unnamed: 0,label,prediction,0,1,2,3,4
0,0,3,0.000055,0.000018,0.003448,0.996437,0.000042
1,0,0,0.998663,0.001296,0.000012,0.000018,0.000011
2,1,1,0.000001,0.999988,0.000005,0.000003,0.000003
3,0,0,0.999899,0.000095,0.000002,0.000003,0.000001
4,0,1,0.000069,0.483494,0.265376,0.250912,0.000148
...,...,...,...,...,...,...,...
1216,0,3,0.108532,0.396002,0.010684,0.482796,0.001985
1217,1,1,0.026704,0.969329,0.000007,0.003953,0.000007
1218,2,2,0.000002,0.001485,0.998507,0.000003,0.000003
1219,0,0,0.909117,0.000006,0.000004,0.000007,0.090866
