### Evaluate Models


##### Imports

In [None]:
import torch
import pickle
import json
import multiprocessing
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn import metrics
from datasets import load_from_disk, Dataset
from transformers import AutoTokenizer

##### Evaluation Parameters

In [None]:
# Score cutoff used to turn scores into 0/1 predictions. It is fixed rather
# than tuned: we currently don't maximize validation F1 to find the threshold;
# doing so would require grabbing scores for all the validation sets.
threshold = 0.5
# Number of standard deviations on each side of the mean for reported
# intervals (1.96 matches the "95% CI" label printed with the results).
num_std = 1.96
# Number of bootstrap resamples drawn per model.
num_bootstrap = 1000
# Plot styling values. NOTE(review): none of these are referenced in the
# visible cells — presumably consumed by plotting code elsewhere; confirm.
line_width = 2
alpha = 0.2
font_size = 16
legend_size = 10
x_size = 10
y_size = 10

##### Initialize Score, Model, and Color Arrays

In [None]:
# Master accumulators, one entry per model. Each is its own empty list.
# all_colors is created here but not filled by any cell visible in this file.
all_y_trues = []
all_y_scores = []
all_model_names = []
all_colors = []

##### Load Fine-Tuned Torch LM Results

In [None]:
# Each entry is (pickled labels path, pickled scores path, model display name).
file_info = [
    ('a', 'b', 'c'),
    ('x', 'y', 'z'),
]

for label_file, score_file, model_name in file_info:
    # Unpickle the labels first, then the scores, one file at a time.
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    loaded = []
    for path in (label_file, score_file):
        with open(path, "rb") as f:
            loaded.append(pickle.load(f))
    labels, scores = loaded

    # In the case of the 2048 model, keep column 1: the score for the 1 label.
    if "RoBERTa (2048)" in model_name:
        scores = scores[:, 1]

    # Append to all three master lists so they stay index-aligned.
    all_y_trues.append(labels)
    all_y_scores.append(scores)
    all_model_names.append(model_name)

##### Define Recall at Precision Metric

In [None]:
def recall_at_precision(scores, labels, target_precision):
    """Return the largest recall reached at or above a precision floor.

    Args:
        scores: Predicted scores, forwarded to
            ``metrics.precision_recall_curve``.
        labels: Ground-truth labels, forwarded to the same call.
        target_precision: Minimum precision a curve point must have to count.

    Returns:
        The maximum recall over all precision-recall curve points whose
        precision is >= ``target_precision``.
    """
    # The thresholds returned by the curve are not needed here.
    curve_precision, curve_recall, _ = metrics.precision_recall_curve(labels, scores)

    # Boolean mask of the curve points that satisfy the precision floor.
    meets_target = curve_precision >= target_precision

    return curve_recall[meets_target].max()

##### Define a Function to Print the Mean and Confidence Interval for a Given Metric

In [None]:
def print_mean_ci_of_metric_list(metric_list, metric_name, num_std):
    """Print the mean of a metric with a symmetric interval around it.

    The interval is ``mean ± num_std * std``. Its lower end is floored at 0
    and its upper end is capped at 1. All three numbers are rounded to three
    decimals before printing.

    Args:
        metric_list: Metric values (e.g. one per bootstrap iteration).
        metric_name: Label printed in front of the numbers.
        num_std: Number of standard deviations for the interval half-width.
            The printed text always says "95% CI", whatever value is passed.
    """
    center = np.mean(metric_list)
    half_width = np.std(metric_list) * num_std

    # Floor only the lower bound and cap only the upper bound.
    lower = np.maximum(center - half_width, 0)
    upper = np.minimum(center + half_width, 1)

    mean_txt, low_txt, high_txt = (round(value, 3) for value in (center, lower, upper))
    print(f"{metric_name}: {mean_txt} ([{low_txt} - {high_txt}] 95% CI)")

##### Define a Function to Select a Threshold

In [None]:
def get_threshold_of_best_val_f1(val_scores, val_labels):
    """Pick the decision threshold that maximizes F1 on a validation set.

    Candidate thresholds are 0.00, 0.01, ..., 0.99. A score counts as a
    positive prediction when it is >= the candidate threshold.

    Args:
        val_scores: Validation scores, one per example.
        val_labels: Validation ground-truth labels, passed to
            ``metrics.f1_score`` as ``y_true``.

    Returns:
        The first threshold that reaches the highest F1 seen (ties keep the
        lower threshold), or 0 if no candidate yields an F1 above 0.
    """
    print("  Computing best threshold for F1 on validation set...")

    # Convert once, outside the loop. The original read an undefined name
    # (`val_probs`) here instead of the `val_scores` argument.
    scores = np.asarray(val_scores)

    best_val_f1 = 0
    best_threshold = 0
    for int_threshold in range(0, 100, 1):
        threshold = int_threshold / 100
        sample_preds = (scores >= threshold).astype(int)
        f1 = metrics.f1_score(y_true=val_labels, y_pred=sample_preds)
        # Strict comparison: only a real improvement replaces the current best.
        if f1 > best_val_f1:
            print(f"    Found new best F1 {f1:.4f} at threshold {threshold}")
            best_val_f1 = f1
            best_threshold = threshold

    return best_threshold

##### Print Performance for all Metrics for all Models

In [None]:
# Shared x-grids onto which each bootstrap ROC / PR curve is interpolated.
mean_fpr_linspace = np.linspace(0, 1, 100)
mean_recall_linspace = np.linspace(0, 1, 100)

# Maps model name -> DataFrame with one row per bootstrap iteration.
model2metric_df = {}
for y_trues, y_scores, name in zip(
    all_y_trues, all_y_scores, all_model_names
):
    # Per-model accumulators, one entry per bootstrap iteration.
    # recalls / precisions / f1s stay empty while their metrics are disabled.
    accuracies, recalls, precisions, f1s = [], [], [], []
    aps, roc_aucs, rs_at_p90 = [], [], []
    interp_ps, interp_tprs = [], []
    static_fprs, static_tprs = [], []

    # Convert once per model rather than once per bootstrap iteration.
    labels_array = np.array(y_trues)
    scores_array = np.array(y_scores)
    num_records = len(y_trues)

    for i in range(num_bootstrap):

        # Sample N records with replacement where N is the total number of records
        sample_indices = np.random.choice(num_records, num_records)
        sample_labels = labels_array[sample_indices]
        sample_scores = scores_array[sample_indices]

        # Generate thresholded predictions with the fixed global threshold.
        # threshold = get_threshold_of_best_val_f1(val_scores=y_val_scores, val_labels=y_val_trues)
        sample_preds = (sample_scores >= threshold).astype(int)

        accuracy = metrics.accuracy_score(y_true=sample_labels, y_pred=sample_preds)
        accuracies.append(accuracy)

        # Threshold-dependent metrics are currently disabled:
        # recall = metrics.recall_score(y_true=sample_labels, y_pred=sample_preds)
        # recalls.append(recall)
        # precision = metrics.precision_score(y_true=sample_labels, y_pred=sample_preds)
        # precisions.append(precision)
        # f1 = metrics.f1_score(y_true=sample_labels, y_pred=sample_preds)
        # f1s.append(f1)

        ap = metrics.average_precision_score(y_true=sample_labels, y_score=sample_scores)
        aps.append(ap)

        # Positional arguments: the `probas_pred` keyword was deprecated in
        # scikit-learn 1.5 in favor of `y_score`, so neither keyword works on
        # every version.
        p, r, _ = metrics.precision_recall_curve(sample_labels, sample_scores)
        # np.interp needs increasing x; the curve is reversed to get recall
        # in ascending order.
        interp_p = np.interp(mean_recall_linspace, r[::-1], p[::-1])
        interp_ps.append(interp_p)

        roc_auc = metrics.roc_auc_score(y_true=sample_labels, y_score=sample_scores)
        roc_aucs.append(roc_auc)

        fpr, tpr, _ = metrics.roc_curve(y_true=sample_labels, y_score=sample_scores)

        # For these models, record the second ROC point (index 1) as a single
        # operating point; other models get a None placeholder.
        if 'GPT-4' in name or 'Text Gen' in name:
            static_fprs.append(fpr[1])
            static_tprs.append(tpr[1])
        else:
            static_fprs.append(None)
            static_tprs.append(None)

        interp_tpr = np.interp(mean_fpr_linspace, fpr, tpr)
        interp_tpr[0] = 0.0  # force the interpolated curve to start at TPR 0
        interp_tprs.append(interp_tpr)

        r_at_p90 = recall_at_precision(scores=sample_scores, labels=sample_labels, target_precision=0.9)
        rs_at_p90.append(r_at_p90)

    metric_df = pd.DataFrame({
        # "recalls": recalls,
        # "precisions": precisions,
        # "f1s": f1s,
        "aps": aps,
        "roc_aucs": roc_aucs,
    })
    model2metric_df[name] = metric_df

    print(f"\nResults for {name}\n")
    # print_mean_ci_of_metric_list(recalls, metric_name="Recall", num_std=num_std)
    # print_mean_ci_of_metric_list(precisions, metric_name="Precision", num_std=num_std)
    # print_mean_ci_of_metric_list(f1s, metric_name="F1", num_std=num_std)
    print_mean_ci_of_metric_list(aps, metric_name="Average Precision", num_std=num_std)
    print_mean_ci_of_metric_list(roc_aucs, metric_name="ROC AUC", num_std=num_std)

# Save the per-model bootstrap metrics to disk.
with open("./model2metric_df.pkl", "wb") as f:
    pickle.dump(model2metric_df, f)

In [None]:
# Drop every model whose name contains 'Max'.
model2metric_df = {
    model_name: metric_df
    for model_name, metric_df in model2metric_df.items()
    if 'Max' not in model_name
}

In [None]:
def plot_mean_with_95_ci(ax, data, metric, condition):
    """Draw one horizontal bar per model: the metric mean with a ±1.96·std error bar.

    Args:
        ax: Axes-like object to draw on; ``barh``, ``set_yticks``,
            ``set_yticklabels``, ``set_xlabel`` and ``set_title`` are called
            on it.
        data: Dict mapping model name to a DataFrame that has a ``metric``
            column.
        metric: Column to summarize; must be ``'aps'`` or ``'roc_aucs'``.
        condition: Substring a model name must contain to be plotted; it is
            also used in the subplot title.

    Raises:
        KeyError: If ``metric`` is not ``'aps'`` or ``'roc_aucs'``, or a
            selected DataFrame lacks that column.
    """
    metric_dict = {'aps': 'PR AUC', 'roc_aucs': 'ROC AUC'}
    colors = ['blue', 'green', 'red', 'cyan', 'magenta', 'yellow', 'black']

    # Keep only the models whose name mentions this condition.
    filtered_data = {k: v for k, v in data.items() if condition in k}

    for i, (model, df) in enumerate(filtered_data.items()):
        mean = df[metric].mean()
        ci = 1.96 * df[metric].std()

        # Cycle through the palette so more than len(colors) models do not
        # raise IndexError.
        # NOTE(review): map_model_name is not defined in the visible part of
        # this file; presumably it maps a model key to a display name — confirm.
        ax.barh(
            i,
            mean,
            xerr=ci,
            color=colors[i % len(colors)],
            capsize=10,
            label=f'M{i}: {map_model_name(model)}',
        )

    # Bars are labeled M0, M1, ... on the axis; the legend label holds the name.
    y_pos = np.arange(len(filtered_data))
    ax.set_yticks(y_pos)
    ax.set_yticklabels(['M' + str(i) for i in range(len(filtered_data))])
    ax.set_xlabel(metric_dict[metric])
    ax.set_title(f'{metric_dict[metric]} for {condition} Prediction')

conditions = ['x', 'y', 'z']
# Named `metric_names`, not `metrics`, so the imported sklearn `metrics`
# module is not overwritten; earlier cells keep working when re-run.
metric_names = ['aps', 'roc_aucs']

# One row per condition, one column per metric. squeeze=False keeps `axs`
# two-dimensional even if either list has a single entry.
fig, axs = plt.subplots(
    len(conditions), len(metric_names), figsize=(10, 12), squeeze=False
)

for i, condition in enumerate(conditions):
    for j, metric in enumerate(metric_names):
        plot_mean_with_95_ci(axs[i][j], model2metric_df, metric, condition)

# Add a single legend for the entire plot, taken from the first subplot
handles, labels = axs[0][0].get_legend_handles_labels()
fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 1.05),
           ncol=len(handles), fancybox=True, shadow=True)

# Add a single title for the entire plot
fig.suptitle("Test Set Performance (1,000 Bootstrap Iterations)", fontsize=14, y=1.07)

plt.tight_layout()
plt.subplots_adjust(top=0.99)
plt.show()