In [2]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.special import softmax
import pickle
import pandas as pd
from entmax import sparsemax
import torch

In [135]:
# Notebook configuration: target miscoverage and the pickled inputs.
# `alpha` is the miscoverage level; the nominal coverage used below is 1 - alpha.
alpha = 0.1

# Calibration split: predicted values and labels.
pred_cal_path = '../../predictions/CIFAR10_cal_NLLLoss_softmax_proba.pickle'
true_cal_path = '../../predictions/CIFAR10_cal_true.pickle'

# Test split: predicted values and labels.
pred_test_path = '../../predictions/CIFAR10_test_NLLLoss_softmax_proba.pickle'
true_test_path = '../../predictions/CIFAR10_test_true.pickle'

def get_data(pred_cal_path, pred_test_path,true_cal_path, true_test_path):
    """Unpickle the four input files and return their contents as a tuple.

    Parameters
    ----------
    pred_cal_path, pred_test_path : path-like
        Pickle files holding the calibration / test predictions.
    true_cal_path, true_test_path : path-like
        Pickle files holding the calibration / test labels.

    Returns
    -------
    tuple
        ``(pred_cal, pred_test, true_cal, true_test)``, loaded in that order.

    Notes
    -----
    ``pickle.load`` can execute arbitrary code; only point this at files
    you produced yourself.
    """
    def _unpickle(path):
        # Binary mode is required by pickle; the file is closed on exit.
        with open(path, 'rb') as handle:
            return pickle.load(handle)

    ordered_paths = (pred_cal_path, pred_test_path, true_cal_path, true_test_path)
    return tuple(_unpickle(path) for path in ordered_paths)

In [136]:
# Load predictions and labels for both splits, then record the array sizes
# (rows of each split and the number of columns of the calibration predictions).
pred_cal, pred_test, true_cal, true_test = get_data(
    pred_cal_path, pred_test_path, true_cal_path, true_test_path
)
n_cal, n_classes = pred_cal.shape
n_test = pred_test.shape[0]

In [139]:
# Calibration quantile for the baseline score s(x, y) = 1 - p_y(x).
# `true_cal` is used as a boolean mask, so each selected entry of `pred_cal`
# is the predicted value at a position marked in the label array.
true_mask = true_cal.astype(bool)
cal_scores = 1 - pred_cal[true_mask]

# Finite-sample corrected level ceil((n+1)(1-alpha))/n. For very small
# calibration sets this exceeds 1, which np.quantile rejects, so clamp it.
q_level = min(np.ceil((n_cal+1)*(1-alpha))/n_cal, 1.0)
# method='higher' returns an actual calibration score (no interpolation).
qhat = np.quantile(cal_scores, q_level, method = 'higher')

# Baseline scores of every (test row, candidate label) pair.
test_scores = 1 - pred_test
# Prediction sets for this score would be `test_scores <= qhat`.
# Alternative (not used): spread the residual mass over the classes with zero prediction:
#test_scores = ((1 - pred_test)/(n_classes-pred_test.astype(bool).sum(axis=1).reshape((n_test,1))))

In [140]:
# Inspect the current value of the threshold `qhat`.
qhat

0.8544148

In [132]:
def custom_score(pred_test,true_test):
    """Rank-based non-conformity score of the label marked in ``true_test``.

    For each row the score is the sum of ``p_j - p_y`` over every class ``j``
    that ranks at or above the marked label ``y`` in the descending ordering
    of the row; classes ranked below ``y`` contribute nothing. A top-ranked
    label therefore scores 0.

    Parameters
    ----------
    pred_test : ndarray, 2-D
        Predicted value per (row, class). Not modified.
    true_test : ndarray, same shape
        Label indicator, cast to bool. Assumed one-hot; the first marked
        column wins if several are marked.

    Returns
    -------
    ndarray, 1-D
        One score per row.
    """
    label_mask = true_test.astype(bool).T

    # Double argsort: position of each class in the descending ordering
    # of its row (0 = largest predicted value).
    descending_order = pred_test.argsort(axis=1)[:, ::-1]
    ranks = descending_order.argsort()

    # Pick, per row, the rank and the predicted value of the marked label.
    label_rank = np.select(label_mask, ranks.T)
    label_pred = np.select(label_mask, pred_test.T)

    # Gap to the marked label; classes ranked below it are zeroed out.
    gaps = pred_test - label_pred[..., np.newaxis]
    ranked_below = ranks > label_rank[..., np.newaxis]
    gaps[ranked_below] = 0
    return gaps.sum(axis=1)

In [120]:
def all_custom_scores(pred_test):
    """Evaluate ``custom_score`` for every candidate label of every row.

    Each column index is treated in turn as the hypothesised label of all
    rows, and the resulting score vectors are assembled column-wise.

    Parameters
    ----------
    pred_test : ndarray, 2-D
        Predicted value per (row, class).

    Returns
    -------
    ndarray
        Entry ``[i, k]`` is the score of row ``i`` under label ``k``; the
        shape equals ``pred_test.shape``.
    """
    def _scores_for(label):
        # Indicator array marking `label` as the label of every row.
        hypothesis = np.zeros(pred_test.shape)
        hypothesis[:, label] = 1
        return custom_score(pred_test, hypothesis)

    per_label = [_scores_for(label) for label in range(pred_test.shape[1])]
    # Stack one row per label, then transpose to (rows, labels).
    return np.vstack(per_label).T
# Score of every test row under each candidate label (rows: samples, columns: labels).
all_custom_scores(pred_test)

array([[0.99833435, 0.9997803 , 0.9999998 , ..., 0.9996498 , 0.9997507 ,
        0.9996964 ],
       [0.9999592 , 0.9996194 , 1.0000001 , ..., 0.9999988 , 0.        ,
        0.9999671 ],
       [0.99998856, 0.9999903 , 0.9999967 , ..., 0.9999982 , 0.        ,
        0.99999785],
       ...,
       [0.99999934, 0.9999993 , 0.9999837 , ..., 0.9995752 , 0.9999977 ,
        0.99999326],
       [0.9999949 , 0.        , 0.9999971 , ..., 0.9999978 , 0.99999833,
        0.9999969 ],
       [0.99997723, 0.9999985 , 0.999998  , ..., 0.        , 0.99999416,
        0.9999998 ]], dtype=float32)

In [142]:

def _conformal_pvalues(cal_scores, scores):
    """Return (#{calibration scores >= s} + 1) / (n_cal + 1) for each s in `scores`."""
    cal_sorted = np.sort(cal_scores)
    # searchsorted(side='left') gives the number of calibration scores strictly below s.
    n_geq = cal_sorted.size - np.searchsorted(cal_sorted, scores, side='left')
    return (n_geq + 1) / (cal_sorted.size + 1)


def _plot_distributions(set_size, cal_scores, qhat):
    """Plot histograms of prediction-set sizes and of calibration scores (threshold marked)."""
    fig, axs = plt.subplots(1,2,figsize=(12,6))

    mean_size = set_size.mean()
    size_peak = np.histogram(set_size, bins=10)[0].max()
    axs[0].hist(set_size)
    axs[0].vlines(mean_size,0,size_peak+10, color='black')
    axs[0].text(mean_size*1.02,(size_peak-10)*0.95,  f'S = {mean_size}', color='black',fontweight='bold')
    axs[0].set_title('Set Size Distribution')

    score_peak = np.histogram(cal_scores, bins=10)[0].max()
    axs[1].hist(cal_scores)
    axs[1].vlines(qhat,0,score_peak+10, color='black')
    axs[1].text(qhat*1.02,(score_peak-10)*0.95, f'q={qhat:.3f}', color='black',fontweight='bold')
    axs[1].set_title('Non-Conf Scores Distribution')
    plt.show()


def _plot_classwise(class_coverage, class_size, coverage, mean_size, alpha):
    """Plot per-class coverage and per-class average set size as bar charts."""
    n_classes = len(class_coverage)
    fig, axs = plt.subplots(1,2,figsize=(12,6))

    axs[0].bar(np.arange(n_classes),class_coverage)
    axs[0].hlines(coverage,0,n_classes-1, color='black')
    axs[0].hlines(1-alpha,0,n_classes-1, color='green')
    axs[0].text(0,coverage, f'Emp. cov. = {coverage:.2f}', color='black',fontweight='bold')
    axs[0].text(0,1-alpha, f'Theo. cov. = {1-alpha:.2f}', color='green',fontweight='bold')
    axs[0].set_title('Class Conditional Coverage')

    axs[1].bar(np.arange(n_classes),class_size)
    # Span the bars actually drawn (was a hard-coded 0..100).
    axs[1].hlines(mean_size,0,n_classes-1, color='black')
    axs[1].text(0,mean_size, f'S={mean_size:.3f}', color='black',fontweight='bold')
    axs[1].set_title('Class Avg Set size')

    plt.show()


def run_cp(pred_cal, pred_test, true_cal, true_test, alpha, plots = False, disallow_empty=False):
    """Split conformal prediction with the rank-based ``custom_score``.

    The threshold is the finite-sample corrected (1 - alpha) quantile of the
    calibration scores; a label enters a test row's prediction set when its
    score is at most that threshold.

    Parameters
    ----------
    pred_cal, pred_test : ndarray, 2-D
        Predicted value per (row, class) for the calibration / test split.
    true_cal, true_test : ndarray, same shapes
        Label indicators, used as boolean masks. Assumed one-hot (exactly
        one marked column per row); the class-wise metrics rely on it.
    alpha : float
        Target miscoverage level.
    plots : bool, default False
        When True, show the set-size / score histograms and the class-wise
        bar charts.
    disallow_empty : bool, default False
        When True, a row with an empty prediction set receives the class
        with the largest predicted value.

    Returns
    -------
    test_match : ndarray of bool
        Prediction-set membership per (test row, class).
    coverage : float
        Fraction of test rows whose marked label is in the set.
    mean_set_size : float
        Average prediction-set size.
    qhat : scalar
        Calibrated score threshold.
    """
    n_cal, n_classes = pred_cal.shape
    n_test = true_test.shape[0]
    true_mask = true_test.astype(bool)

    # Calibration scores must come from the same score function as the test
    # scores; computing them here also removes the dependency on a global
    # `cal_scores` left behind by another cell.
    cal_scores = custom_score(pred_cal, true_cal)

    # ceil((n+1)(1-alpha))/n exceeds 1 for very small calibration sets and
    # np.quantile rejects levels above 1, so clamp.
    q_level = min(np.ceil((n_cal+1)*(1-alpha))/n_cal, 1.0)
    # method='higher' returns an actual calibration score (no interpolation).
    qhat = np.quantile(cal_scores, q_level, method = 'higher')

    test_scores = all_custom_scores(pred_test)
    test_match = test_scores <= qhat

    if disallow_empty:
        # Fall back to the arg-max class wherever the set came out empty.
        empty_rows = np.flatnonzero(~test_match.any(axis=1))
        test_match[empty_rows, pred_test[empty_rows].argmax(axis=1)] = True

    # Conformal p-value of every (test row, candidate label) pair.
    test_pvalues = _conformal_pvalues(cal_scores, test_scores)

    set_size = test_match.sum(axis=1)
    mean_set_size = set_size.mean()
    if plots:
        _plot_distributions(set_size, cal_scores, qhat)

    # Marginal and class-conditional coverage. The boolean mask keeps this
    # valid for int, float and bool label arrays alike.
    covered = test_match & true_mask
    class_count = true_mask.sum(axis=0)
    coverage = covered.sum()/n_test
    class_coverage = covered.sum(axis=0)/class_count

    # Average set size over the rows of each class.
    class_size = (true_mask*set_size[:, None]).sum(axis=0)/class_count

    if plots:
        _plot_classwise(class_coverage, class_size, coverage, mean_set_size, alpha)

    # Observed fuzziness: mean p-value of the unmarked labels. Computed for
    # inspection only; it is not returned so the 4-tuple interface is unchanged.
    of = np.ma.array(test_pvalues, mask=true_mask).mean(axis=1).data.mean()

    return test_match, coverage, mean_set_size, qhat

In [161]:
# Conformal run at miscoverage `alpha` with plots suppressed.
# NOTE: this rebinds the notebook-level `qhat` to the threshold returned by run_cp.
test_match, coverage, mean_set_size, qhat = run_cp(
    pred_cal, pred_test, true_cal, true_test, alpha, plots=False
)
print(coverage, mean_set_size)

0.9015 1.1326


In [162]:
# Sparsemax of the predictions scaled by beta = 1/qhat: the classes with
# non-zero output form the prediction set. In the recorded run the printed
# coverage and mean set size equal those printed after the run_cp call.
# NOTE(review): beta is undefined when qhat == 0 - confirm qhat > 0 before relying on this.
beta = 1 / qhat
scaled_pred = torch.tensor(pred_test) * beta
sparse_pred = sparsemax(scaled_pred, dim=-1).numpy()

# Support of the sparsemax output = predicted label set.
pred_match = sparse_pred > 0
is_true = true_test.astype(bool)
coverage = pred_match[is_true].sum() / n_test
mean_set_size = pred_match.sum(axis=1).mean()
print(coverage, mean_set_size)

0.9015 1.1326
