In [1]:
# This notebook computes and plots the results for Table 2 in the paper


import sys
sys.path.append("../../")

import matplotlib.pyplot as plt
import numpy as np
from expected_cost import ec, utils
from expected_cost.data import get_llks_for_multi_classif_task
from sklearn.metrics import roc_curve, roc_auc_score, f1_score, precision_recall_curve, accuracy_score, recall_score, precision_score
from expected_cost.calibration import calibration_with_crossval, calibration_train_on_test
from expected_cost.psrcal_wrappers import Brier, LogLoss, LogLossSE, ECE, ECEbin, L2ECEbin, CalLoss
import re

# Directory that receives the per-dataset PDF figures saved by the main loop
# below. `utils.mkdir_p` presumably creates it when it does not exist yet
# (helper defined in the expected_cost package -- confirm there).
outdir = "outputs/"
utils.mkdir_p(outdir)

In [2]:
# ---------------------------------------------------------------------------
# Configuration for the per-dataset loop that follows in this cell
# ---------------------------------------------------------------------------

# Root folder holding one sub-directory (or file stem) per dataset/system.
data_dir = '../data/'

# Defined here but not referenced anywhere else in this cell.
p_agnews = 0.01
p_iemocap = 0.05

# Display name -> data location, in the order the rows are printed.
# A plain string is a path relative to `data_dir`. A 3-tuple is
# (path, priors, one_vs_other); its last two items are forwarded to the
# loader as keyword arguments of the same names.
datasets = dict([
    ('SST2 GPT2-4sh',        'sst2_gpt2_4shot'),
    ('SST2 GPT2-0sh',        'sst2_gpt2'),
    ('SITW XvPLDA',          'sitw_plda'),
    ('FVCAUS XvPLDA',        'fvcaus_plda'),
    ('CIFAR-1vsO Resnet-20', ('cifar100_resnet-20/', None, 1)),
    ('CIFAR-2vsO Resnet-20', ('cifar100_resnet-20/', None, 2)),
    ('IEMOCAP W2V2',         'iemocap_wav2vec_pt'),
    ('AGNEWS GPT2-0sh',      'agnews_gpt2'),
    ('CIFAR10 Resnet-20',    'cifar10_resnet-20/'),
    ('CIFAR10 Vgg19',        'cifar10_vgg19_bn/'),
    ('CIFAR10 RepVgg-a2',    'cifar10_repvgg_a2/'),
    ('CIFAR100 Resnet-20',   'cifar100_resnet-20/'),
    ('CIFAR100 Vgg19',       'cifar100_vgg19_bn/'),
    ('CIFAR100 RepVgg-a2',   'cifar100_repvgg_a2/'),
])

# Abstention costs swept for every system. With this start/stop/step the
# recorded table header runs from Cab0.00 up to and including Cab1.05.
gammas = np.arange(0.0, 1.1, 0.05)

# Seed handed to the cross-validated calibration calls.
seed = 0

# One matplotlib colour code per score variant drawn in the figures.
colors = dict(raw='k', TScal='r', DPcal='b', OPcal='g')

# Table formatting: field separators for printing to screen, or in latex
# format for the paper. Only the on-screen values are set here.
print_style = ''  # 'latex'
sep, sep2, newline = '  ', '|', ''

# Loop state: `first` makes the table header print only once, and
# `targets_dict` collects the targets of every dataset keyed by display name.
first = True
targets_dict = dict()

# Main loop: for every dataset/system pair, load the scores, build the
# calibrated variants, print one table row per variant (normalized expected
# cost for each abstention cost, cross-entropy and calibration metrics) and
# save a four-panel figure to `outdir`.
for dname, dinfo in datasets.items():

    # First and last space-separated token of the display name.
    # Neither is used further in this cell.
    data_name = re.sub(' .*','',dname)
    system_name = re.sub('.* ','',dname)

    # Tuple entries carry extra loader arguments; plain strings are just a path.
    if isinstance(dinfo, tuple):
        (dpath, priors, one_vs_other) = dinfo
        targets, logpost_raw, _ = get_llks_for_multi_classif_task(data_dir+dpath, logpost=True, priors=priors, one_vs_other=one_vs_other)
    else:
        dpath = dinfo
        targets, logpost_raw, _ = get_llks_for_multi_classif_task(data_dir+dpath, logpost=True)

    targets_dict[dname] = targets

    # Number of classes, and empirical class priors of the loaded targets.
    # Note that this overwrites any `priors` unpacked from a tuple entry above.
    K = logpost_raw.shape[1]
    counts = np.bincount(targets)
    priors = counts/len(targets)

    # Obtain the calibrated versions of the scores: two cross-validated ones
    # (without and with a bias term) and, for binary tasks only, a PAV
    # calibration trained on the test data itself.
    logpost_dict = {'raw': logpost_raw}
    logpost_dict['TScal'] = calibration_with_crossval(logpost_raw, targets, seed=seed, calparams={'bias':False})
    logpost_dict['DPcal'] = calibration_with_crossval(logpost_raw, targets, seed=seed, calparams={'bias':True})
    if K == 2:
        logpost_dict['OPcal'] = calibration_train_on_test(logpost_raw, targets, calmethod='PAV')

    # Define various 0-1 cost matrices with an abstention option, one per gamma
    cost_no_rej = ec.CostMatrix.zero_one_costs(K)
    costs = {}
    for gamma in gammas:
        costs[f'Cab{gamma:.2f}'] = ec.CostMatrix.zero_one_costs(K, abstention_cost=gamma)

    # Print the table header once, before the first row.
    if first:
        print(f'\nSystem                         {sep} {sep2} ', end='')
        for costn in costs:
            print(f" {costn:5s}{sep} ", end='')
        # Column order must match the row printing below: rcl, ece, ecebin.
        # The last column is the one that shows '-' for non-binary tasks,
        # where ECEbin is not computed.
        print(f'{sep2}     CE  {sep} {sep2}  RCL   {sep} ECEmc   {sep} ECEbin {newline}')
        first = False

    # Not used further in this cell.
    ce_raw = LogLoss(logpost_raw, targets)

    fig, axes = plt.subplots(1,4, figsize=(14,5))
    for logpostname, logpost in logpost_dict.items():

        print(f'{dname:25s}{logpostname:5s} {sep} {sep2} ', end='')

        ecval_vector = []
        ecvaln_vector = []
        ecval_sel_vector = []
        coverage_vector = []
        # FPR and FNR for the misclassification detection problem. Each vector will have one value per gamma.
        md_fpr = []
        md_fnr = []

        # Decisions without a reject option
        decisions_no_rej, _ = ec.bayes_decisions(logpost, cost_no_rej, score_type='log_posteriors')

        # Targets for the misclassification detection problem: True where the
        # no-reject decision is wrong. They do not depend on gamma, so the
        # class totals are computed once here rather than inside the loop.
        targets_rej = decisions_no_rej != targets
        tot_pos = np.sum(targets_rej)
        tot_neg = np.sum(~targets_rej)

        # This is the "s" function. It should be generalized for any possible cost function.
        # The one below only works for the 0-1 cost matrix.
        confidences = np.max(logpost, axis=1)

        for gamma in gammas:
            costn = f'Cab{gamma:.2f}'
            cost = costs[costn]
            decisions, _ = ec.bayes_decisions(logpost, cost, score_type='log_posteriors')
            # NOTE(review): in the recorded output the adjusted cost is nan for
            # Cab0.00, together with a warning on `ave_cost / norm_value` --
            # presumably the normalizer is zero when abstaining is free; confirm
            # in ec.average_cost.
            ecvaln = ec.average_cost(targets, decisions, cost, adjusted=True)
            ecval  = ec.average_cost(targets, decisions, cost, adjusted=False)
            ecvaln_vector.append(ecvaln)
            ecval_vector.append(ecval)
            print(f"   {ecvaln:6.3f}{sep}", end='')

            # EC only for the samples that were not rejected
            # (decision index K is the abstention decision).
            sel = decisions != K
            targets_sel = targets[sel]
            decisions_sel = decisions[sel]
            coverage_vector.append(np.sum(sel)/len(decisions))
            if np.sum(sel) > 0:
                ecval_sel_vector.append(ec.average_cost(targets_sel, decisions_sel, cost_no_rej, adjusted=False))
            else:
                ecval_sel_vector.append(0.0)

            # When decisions are equal to the last class, the system decided to reject.
            # These are the decisions for the pseudo-ROC.
            dec_pos = decisions == K
            dec_neg = decisions != K
            md_fpr.append(np.sum(~targets_rej[dec_pos])/tot_neg)
            md_fnr.append(np.sum(targets_rej[dec_neg])/tot_pos)

        ce = LogLoss(logpost, targets)

        # Finally, compute the calibration metrics. The relative calibration
        # loss compares the scores against a cross-validated recalibration
        # (with bias term) of those same scores.
        logpost_cal_axv = calibration_with_crossval(logpost, targets, seed=seed, calparams={'bias':True})
        rcl = CalLoss(LogLoss, logpost, logpost_cal_axv, targets, relative=True)
        ece = ECE(logpost, targets)
        if K==2:
            ecebin = ECEbin(logpost, targets)

        print(f"{sep2}   {ce:5.3f}  {sep}", end='')

        if K==2:
            print(f"{sep2} {rcl:5.1f} {sep}  {ece:5.1f} {sep} {ecebin:5.1f}  ", end='')
        else:
            print(f"{sep2} {rcl:5.1f} {sep}  {ece:5.1f} {sep}  -  ", end='')

        print(f'{newline}')

        color = colors[logpostname]

        # gamma vs EC
        ax = axes[0]
        ax.plot(gammas, ecval_vector, color+'-', label=f'{logpostname} (NCE = {ce:.3f})')
        ax.plot(gammas, ecval_sel_vector, color+'--')
        ax.legend()
        ax.set_xlabel("Gamma")
        ax.set_ylabel("EC (solid) / EC sel (dashed)")

        # gamma vs NEC
        ax = axes[1]
        ax.plot(gammas, ecvaln_vector, color+'-')
        ax.plot(gammas, coverage_vector, color+'--')
        ax.set_xlabel("Gamma")
        ax.set_ylabel("NEC (solid) / Coverage (dashed)")

        # coverage vs risk
        ax = axes[2]
        ax.plot(coverage_vector, ecval_sel_vector, color+'-')
        # Mean selective EC over the gamma sweep; not shown anywhere in this cell.
        auc_rc = np.mean(ecval_sel_vector)
        ax.set_xlabel("Coverage")
        ax.set_ylabel("EC sel")

        # ROC for misclassification detection: dotted line from the Bayes
        # decisions at each gamma, solid line from thresholding the confidence.
        ax = axes[3]
        ax.plot(md_fpr, md_fnr, color+':')
        fpr, tpr, _ = roc_curve(targets_rej, 1-confidences)
        auc_roc = roc_auc_score(targets_rej, 1-confidences)
        # Format the legend entry instead of passing the raw float as the label.
        ax.plot(fpr, 1-tpr, color+'-', label=f'AUC = {auc_roc:.3f}')
        ax.legend()
        ax.set_xlabel("MD FPR")
        ax.set_ylabel("MD FNR")

    fig.suptitle(dname, fontsize=10)
    plt.tight_layout()
    plt.savefig(f'{outdir}/{dname}.pdf')


    print("")





System                            |  Cab0.00    Cab0.05    Cab0.10    Cab0.15    Cab0.20    Cab0.25    Cab0.30    Cab0.35    Cab0.40    Cab0.45    Cab0.50    Cab0.55    Cab0.60    Cab0.65    Cab0.70    Cab0.75    Cab0.80    Cab0.85    Cab0.90    Cab0.95    Cab1.00    Cab1.05   |     CE     |  RCL      ECEbin     ECEmc 
SST2 GPT2-4sh            raw      |       nan      0.974      0.812      0.718      0.888      1.202      1.333      1.296      1.198      1.097      0.996      0.996      0.996      0.996      0.996      0.996      0.996      0.996      0.996      0.996      0.996      0.996  

  return ave_cost / norm_value
  return ave_cost / norm_value


|   1.073    |  62.4      34.4     34.8  
SST2 GPT2-4sh            TScal    |       nan      1.000      1.000      1.000      1.000      0.970      0.822      0.648      0.871      1.047      0.996      0.996      0.996      0.996      0.996      0.996      0.996      0.996      0.996      0.996      0.996      0.996  |   0.931    |  56.4      31.5     31.9  
SST2 GPT2-4sh            DPcal    |       nan      0.718      0.585      0.496      0.420      0.387      0.335      0.306      0.282      0.252      0.226      0.226      0.226      0.226      0.226      0.226      0.226      0.226      0.226      0.226      0.226      0.226  |   0.404    |  -0.6       1.6      2.2  
SST2 GPT2-4sh            OPcal    |       nan      0.673      0.555      0.470      0.405      0.363      0.325      0.297      0.272      0.247      0.223      0.223  

  return ave_cost / norm_value
  return ave_cost / norm_value


    0.223      0.223      0.223      0.223      0.223      0.223      0.223      0.223      0.223      0.223  |   0.385    |  -0.4       0.0      0.0  

SST2 GPT2-0sh            raw      |       nan      0.933      0.818      0.842      0.893      0.936      0.972      0.977      0.952      0.895      0.828      0.828      0.828      0.828      0.828      0.828      0.828      0.828      0.828      0.828      0.828      0.828  |   0.917    |  46.0      20.0     27.3  
SST2 GPT2-0sh            TScal    |       nan      1.000      0.995      0.943      0.843      0.745      0.732      0.779      0.852      0.877      0.828      0.828      0.828      0.828      0.828      0.828      0.828      0.828      0.828  

  return ave_cost / norm_value
  return ave_cost / norm_value


    0.828      0.828      0.828  |   0.858    |  42.1      18.9     26.9  
SST2 GPT2-0sh            DPcal    |       nan      0.811      0.655      0.602      0.528      0.490      0.435      0.399      0.360      0.333      0.310      0.310      0.310      0.310      0.310  

  return ave_cost / norm_value
  return ave_cost / norm_value


    0.310      0.310      0.310      0.310      0.310      0.310      0.310  |   0.495    |  -0.6       1.4      1.5  
SST2 GPT2-0sh            OPcal    |       nan      0.757      0.638      0.580      0.522      0.470      0.429      0.393      0.356      0.325      0.298      0.298      0.298      0.298      0.298      0.298      0.298      0.298      0.298      0.298      0.298      0.298  |   0.478    |  -0.4       0.0      0.0  

SITW XvPLDA              raw      | 

  return ave_cost / norm_value


      nan      0.180      0.225      0.256      0.273      0.288      0.297      0.305      0.311      0.323      0.324      0.324      0.324      0.324      0.324      0.324      0.324      0.324      0.324      0.324      0.324      0.324  |   0.189    |  16.7       0.2      0.2  
SITW XvPLDA              TScal    | 

  return ave_cost / norm_value


      nan      0.152      0.202      0.235      0.259      0.277      0.291      0.303      0.315      0.322      0.324      0.324      0.324      0.324      0.324      0.324      0.324      0.324      0.324      0.324      0.324      0.324  |   0.163    |   3.2       0.0      0.1  
SITW XvPLDA              DPcal    | 

  return ave_cost / norm_value


      nan      0.145      0.190      0.225      0.251      0.268      0.289      0.297      0.303      0.304      0.306      0.306      0.306      0.306      0.306      0.306      0.306      0.306      0.306      0.306      0.306      0.306  