In [3]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.special import softmax
import pickle
import pandas as pd
from entmax import sparsemax
import torch

In [2]:
# Experiment configuration.
# alpha is used by the calibration cell as the miscoverage level (1 - alpha
# enters the quantile level); the four paths are the pickles read by get_data.
alpha = 0.1
pred_cal_path, pred_test_path = (
    '../../predictions/CIFAR10_cal_NLLLoss_softmax_proba.pickle',
    '../../predictions/CIFAR10_test_NLLLoss_softmax_proba.pickle',
)
true_cal_path, true_test_path = (
    '../../predictions/CIFAR10_cal_true.pickle',
    '../../predictions/CIFAR10_test_true.pickle',
)

def get_data(pred_cal_path, pred_test_path, true_cal_path, true_test_path):
    """Unpickle four files and return their contents in argument order.

    Each argument is a path that is opened in binary mode and read with
    ``pickle.load``; the files are read one after another, in the order
    the parameters are listed.

    Returns:
        A 4-tuple ``(pred_cal, pred_test, true_cal, true_test)`` holding
        whatever objects the respective pickles contain.

    NOTE(review): ``pickle.load`` can execute arbitrary code while
    deserializing -- only point this at files from a trusted source.
    """
    loaded = []
    for path in (pred_cal_path, pred_test_path, true_cal_path, true_test_path):
        with open(path, 'rb') as handle:
            loaded.append(pickle.load(handle))
    return tuple(loaded)

In [6]:
# Read the calibration/test predictions and labels from the configured pickles.
pred_cal, pred_test, true_cal, true_test = get_data(
    pred_cal_path=pred_cal_path,
    pred_test_path=pred_test_path,
    true_cal_path=true_cal_path,
    true_test_path=true_test_path,
)
# pred_cal's shape is unpacked as exactly two axes: (calibration rows, classes).
n_cal, n_classes = pred_cal.shape
# Number of test rows.
n_test = pred_test.shape[0]

In [7]:
# Get calibration quantile
# Non-conformity score of a calibration sample = 1 - the predicted value
# selected by the label mask (presumably one-hot labels, so this is the
# probability of the true class -- TODO confirm against how the pickles were built).
true_mask = true_cal.astype(bool)
cal_scores = 1 - pred_cal[true_mask]

# Finite-sample corrected quantile level: ceil((n + 1) * (1 - alpha)) / n.
q_level = np.ceil((n_cal + 1) * (1 - alpha)) / n_cal
if q_level > 1:
    # With too few calibration points for the requested alpha the level exceeds 1
    # and np.quantile would raise ValueError. The required order statistic does
    # not exist, so the threshold is +inf: every class passes `score <= qhat`.
    qhat = np.inf
else:
    # method='higher' returns an actual calibration score (the upper neighbour)
    # instead of interpolating between two scores.
    qhat = np.quantile(cal_scores, q_level, method='higher')

# test scores: one score per (test row, class) entry of pred_test
test_scores = 1 - pred_test
#alternative
#test_scores = ((1 - pred_test)/(n_classes-pred_test.astype(bool).sum(axis=1).reshape((n_test,1))))
# Boolean membership matrix: True where the class score is within the threshold.
test_match = test_scores <= qhat

In [24]:
# Column index of the largest entry in each row of true_test
# (presumably the class id of a one-hot label row -- TODO confirm).
np.argmax(true_test, axis=1)

array([3, 8, 8, ..., 5, 1, 7])

In [34]:
# Toy example: rank the entries of each row, 0 = largest value.
array = np.array([
    [4, 2, 7, 1],
    [10, 2, 5, 4],
])
# Ascending argsort with its columns reversed lists indices in descending
# value order; argsort of that listing gives each column's rank.
ranks = np.argsort(np.argsort(array, axis=1)[:, ::-1], axis=-1)


In [35]:
# Display the ranks computed for the toy array (0 marks the largest entry of a row).
ranks

array([[1, 2, 0, 3],
       [0, 3, 1, 2]])

In [39]:
# Display the test labels returned by get_data (object loaded from true_test_path).
true_test

array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 1, 0],
       [0, 0, 0, ..., 0, 1, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 1, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 1, 0, 0]])

In [18]:
# Display the test predictions returned by get_data (object loaded from pred_test_path).
pred_test

array([[4.9429428e-04, 2.8393635e-05, 1.0591329e-08, ..., 5.5816861e-05,
        3.3313660e-05, 4.4173135e-05],
       [1.0369331e-05, 1.8023030e-04, 7.0902301e-10, ..., 1.7914086e-07,
        9.9979967e-01, 7.7196755e-06],
       [3.1236430e-06, 2.2516808e-06, 4.2893913e-07, ..., 1.9568512e-07,
        9.9999166e-01, 2.4531684e-07],
       ...,
       [7.6206739e-08, 8.6769283e-08, 2.9525172e-06, ..., 1.7726909e-04,
        3.0087048e-07, 1.0447869e-06],
       [1.0492793e-06, 9.9995935e-01, 3.9540984e-07, ..., 2.7276397e-07,
        1.9206888e-07, 4.2753959e-07],
       [8.4794256e-06, 1.9544952e-07, 2.8989442e-07, ..., 9.9998569e-01,
        1.2514633e-06, 1.3111909e-08]], dtype=float32)

In [36]:
# Rank every entry of pred_test within its row (0 = largest value), using the
# reversed-argsort-then-argsort recipe. Note: this rebinds `ranks`.
ranks = np.argsort(np.argsort(pred_test, axis=1)[:, ::-1], axis=-1)
ranks

array([[2, 6, 9, ..., 3, 5, 4],
       [2, 1, 8, ..., 6, 0, 3],
       [1, 2, 5, ..., 7, 0, 6],
       ...,
       [8, 7, 4, ..., 1, 6, 5],
       [2, 0, 5, ..., 6, 7, 4],
       [1, 5, 4, ..., 0, 3, 9]])

In [8]:
# Sweep a scaling factor beta applied to pred_test before sparsemax. For each
# beta, a class is "in the set" when its sparsemax output is strictly positive.
# Records, per beta:
#   coverages -- count of label-mask entries that are in the set, divided by n_test
#   avg_sizes -- mean number of in-set classes per test row
betas = np.linspace(0.1, 1.5, 20)
coverages = []
avg_sizes = []

# Loop-invariant inputs: build the tensor copy and the boolean label mask once
# rather than on every iteration of the sweep.
pred_test_tensor = torch.tensor(pred_test)
true_test_mask = true_test.astype(bool)

for beta in betas:
    # Multiplication yields a new tensor, so pred_test_tensor is never modified.
    sparse_pred = sparsemax(pred_test_tensor * beta, dim=-1).numpy()
    pred_match = sparse_pred > 0
    coverages.append(pred_match[true_test_mask].sum() / n_test)
    avg_sizes.append(pred_match.sum(axis=1).mean())