In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms

# preprocessing applied to every test image: tensor conversion followed by
# per-channel normalization with the constants below
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.4914, 0.4822, 0.4465),
            std=(0.2023, 0.1994, 0.2010),
        ),
    ]
)

# CIFAR-10H annotations; the file is available from
# https://github.com/jcpeterson/cifar-10h
cifar10h_file = "data/cifar10h-probs.npy"
cifar10h_probs = np.load(cifar10h_file)

# note: CIFAR-10H annotations correspond to the CIFAR-10 test dataset,
# which is why the test split (train=False) is loaded here
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_test,
)

# CIFAR-10 class names, listed in label order
class_names = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]

def unnormalize_image(
    image_tensor,
    mean=(0.4914, 0.4822, 0.4465),
    std=(0.2023, 0.1994, 0.2010),
):
    """Undo channel-wise normalization and return an image as a NumPy array.

    Parameters
    ----------
    image_tensor : torch.Tensor
        Normalized image laid out as (C, H, W). It may live on any device and
        may require grad; it is detached and copied to the CPU before the
        conversion, and the input tensor itself is not modified.
    mean, std : sequence of float, optional
        Per-channel statistics that were used to normalize the image. The
        defaults are the values used by ``transform_test`` in this notebook.

    Returns
    -------
    numpy.ndarray
        Image laid out as (H, W, C) with values clipped to [0, 1].
    """
    mean = np.asarray(mean, dtype=np.float64)
    std = np.asarray(std, dtype=np.float64)
    # .numpy() alone raises for tensors that require grad or sit on the GPU,
    # so detach and move to the CPU first (both are no-ops for plain CPU tensors)
    image = image_tensor.detach().cpu().numpy().transpose((1, 2, 0))  # (C, H, W) -> (H, W, C)
    image = std * image + mean  # invert (x - mean) / std, broadcasting over the channel axis
    return np.clip(image, 0, 1)  # clip to the valid display range [0, 1]

def plot_image_with_distribution(image_tensor, label_distribution):
    """Show an image next to a bar chart of its label distribution.

    ``image_tensor`` is passed through ``unnormalize_image`` before display.
    ``label_distribution`` supplies the bar heights and is plotted against
    the module-level ``class_names``, so it needs one value per class name.
    The figure is displayed with ``plt.show()``; nothing is returned.
    """
    # convert the normalized tensor into a displayable array first
    displayable = unnormalize_image(image_tensor)

    fig, (image_ax, distribution_ax) = plt.subplots(1, 2, figsize=(6, 3))

    # left panel: the image itself, without axes
    image_ax.imshow(displayable)
    image_ax.axis('off')
    image_ax.set_title("Image")

    # right panel: one bar per class, with vertical tick labels
    distribution_ax.bar(class_names, label_distribution, color='blue')
    distribution_ax.set_title("Label Distribution")
    for tick_label in distribution_ax.get_xticklabels():
        tick_label.set_rotation(90)

    fig.tight_layout()
    plt.show()

# plot two test samples whose annotations are ambiguous
ambiguous_examples = (151, 165)
for example_index in ambiguous_examples:
    example_image = testset[example_index][0]  # the dataset label is not needed here
    plot_image_with_distribution(example_image, cifar10h_probs[example_index])

In [None]:
import torch

# run on the GPU when one is visible to PyTorch, otherwise on the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

In [None]:
from models import *
import torch.backends.cudnn as cudnn

# model
net = EfficientNetB0()

net = net.to(device)
if device == 'cuda':
    # on GPU the model is wrapped in DataParallel, which prefixes every
    # state-dict key with "module."
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True

weights_path = "weights/EfficientNetB0_0.1_100_512_SGD_1"

# load weights into the model
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine
state_dict = torch.load(weights_path, map_location=device)

# The checkpoint keys carry a "module." prefix only if the model was wrapped
# in DataParallel when it was saved. If the keys do not match the current
# model, try adding/removing that prefix, and use the remapped keys only when
# they match exactly; otherwise load_state_dict reports the mismatch as before.
expected_keys = set(net.state_dict())
if set(state_dict) != expected_keys:
    prefix = "module."
    if all(key.startswith(prefix) for key in state_dict):
        remapped = {key[len(prefix):]: value for key, value in state_dict.items()}
    else:
        remapped = {prefix + key: value for key, value in state_dict.items()}
    if set(remapped) == expected_keys:
        state_dict = remapped

net.load_state_dict(state_dict)
print("Model weights loaded successfully!")

In [None]:
# loss used as the conformity score in the cells below.
# torch.nn is referenced explicitly: a bare `nn` is only defined if
# `from models import *` happens to export it.
criterion = torch.nn.CrossEntropyLoss()

# switch the network to evaluation mode; as the last expression of the cell
# this also displays the module
net.eval()

In [None]:
# alpha parameterizes both prediction-set rules in the experiment loop:
# the p-value variant uses it in the calibration quantile level (through
# 1 - alpha) and the e-value variant keeps a label when its e-variable is
# below 1/alpha
alpha = 0.3

# number of samples in the full CIFAR-10 test set
testset_size = len(testset)

In [None]:
# keep the samples with ambiguous ground truth only: a sample qualifies when
# at least two classes each receive a probability above 0.1
classes_above_threshold = np.sum(np.array(cifar10h_probs) > 0.1, axis=1)
valid_indices = [
    row for row, count in enumerate(classes_above_threshold) if count >= 2
]

# restrict both the test set and the label distributions to those samples,
# so position k in one refers to the same sample as position k in the other
filtered_testset = torch.utils.data.Subset(testset, valid_indices)
filtered_cifar10h_probs = cifar10h_probs[valid_indices]

# report how many samples survived the filter
print(f"Filtered test set size: {len(filtered_testset)}")

def get_filtered_cifar10h_distributions(subset):
    """Return the CIFAR-10H label distribution of every sample in ``subset``.

    ``subset.indices`` are used as row positions in ``filtered_cifar10h_probs``,
    so ``subset`` must index into ``filtered_testset`` (as the splits produced
    by ``random_split(filtered_testset, ...)`` do), not into the full test set.
    The returned list follows the order of ``subset.indices``.
    """
    distributions = []
    for row in subset.indices:
        distributions.append(filtered_cifar10h_probs[row])
    return distributions

# sizes of the calibration / final-test split of the filtered set: int()
# truncates, so calibration gets at most 30% of the samples and the final
# test part gets the remainder (the two sizes always add up to the total).
# NOTE: this 0.3 is a literal split fraction; it is not tied to `alpha`.
filtered_calibration_size = int(0.3 * len(filtered_testset))
filtered_final_test_size = len(filtered_testset) - filtered_calibration_size

In [None]:
from torch.utils.data import random_split

def _ratio_or_nan(total, count):
    """Return ``total / count``, or NaN when ``count`` is zero."""
    return total / count if count else float("nan")


list_results = []  # list of 2 dictionaries, for m=1 and m=20
# each dict maps a metric name to a list holding one value per random split

Niter = 200  # number of random calibration/test splits per value of m

for m in [1, 20]:  # m = number of labels drawn per calibration sample
    print(f'### m = {m}')

    list_coverage = []
    list_coverage_p = []
    list_avg_len = []
    list_avg_len_p = []
    list_avg_len_correct = []
    list_avg_len_correct_p = []
    list_avg_len_wrong = []
    list_avg_len_wrong_p = []

    for split_seed in range(Niter):
        print(split_seed)

        # Seeded generator for every label draw of this (m, split) pair, so
        # that a rerun reproduces the same numbers (the split is seeded too).
        rng = np.random.default_rng([m, split_seed])

        # split the filtered test set
        filtered_calibration_set, filtered_final_test_set = random_split(
            filtered_testset,
            [filtered_calibration_size, filtered_final_test_size],
            generator=torch.Generator().manual_seed(split_seed)
        )

        # get label distributions for the filtered splits
        filtered_calibration_label_distributions = get_filtered_cifar10h_distributions(filtered_calibration_set)
        filtered_final_test_label_distributions = get_filtered_cifar10h_distributions(filtered_final_test_set)

        # ---- calibration: m scores per sample, stored sample after sample ----
        calibration_scores = []
        with torch.no_grad():
            for i in range(filtered_calibration_size):
                image = filtered_calibration_set[i][0]
                output = net(image.unsqueeze(0).to(device))
                label_distribution = filtered_calibration_label_distributions[i]
                for _ in range(m):
                    # draw a label index according to the annotation probabilities
                    label = rng.choice(len(label_distribution), p=label_distribution)
                    label = torch.tensor([int(label)], device=device)
                    calibration_scores.append(criterion(output, label).item())

        # calibration_sums[j] = sum over calibration samples of their j-th score
        # (the flat list is laid out as sample 0 draws 0..m-1, sample 1 draws ...)
        calibration_sums = np.asarray(calibration_scores, dtype=np.float64).reshape(
            filtered_calibration_size, m
        ).sum(axis=0)

        # threshold of the p-value variant, taken over all m * n calibration scores
        quantile = np.quantile(calibration_scores, (np.ceil(m*(1-alpha)*(filtered_calibration_size+1))-1)/(m*filtered_calibration_size))

        # ---- prediction sets for the final test samples ----
        conformal_sets = []   # e-value variant
        pconformal_sets = []  # p-value variant
        is_correct = []       # does the top-1 prediction equal the dataset label?

        with torch.no_grad():
            for i in range(filtered_final_test_size):
                image, true_label = filtered_final_test_set[i]
                output = net(image.unsqueeze(0).to(device))

                # recorded here so the network is run once per test sample
                is_correct.append(int(torch.argmax(output)) == int(true_label))

                conformal_set = []  # conformal set for final_test_set[i]
                pconformal_set = []

                for y in range(10):  # possible labels y
                    possible_label = torch.tensor([y], device=device)
                    loss = criterion(output, possible_label).item()

                    # e-variable: average over the m draws of
                    # (n + 1) * loss / (calibration_sums[k] + loss)
                    e_var = float(np.sum(
                        (filtered_calibration_size + 1) / m * loss / (calibration_sums + loss)
                    ))

                    # conformal e-prediction criterion
                    if e_var < 1 / alpha:
                        conformal_set.append(y)

                    # conformal p-prediction criterion
                    if loss <= quantile:
                        pconformal_set.append(y)

                conformal_sets.append(conformal_set)
                pconformal_sets.append(pconformal_set)

        # ---- coverage and conformal set size (size=length) ----
        # coverage is measured against a label drawn from the annotation
        # distribution of each test sample, not against the dataset label
        coverage = 0
        coverage_p = 0
        for i in range(filtered_final_test_size):
            label_distribution = filtered_final_test_label_distributions[i]
            label = rng.choice(len(label_distribution), p=label_distribution)
            coverage += label in conformal_sets[i]
            coverage_p += label in pconformal_sets[i]

        coverage = _ratio_or_nan(coverage, filtered_final_test_size)
        coverage_p = _ratio_or_nan(coverage_p, filtered_final_test_size)
        avg_len = _ratio_or_nan(sum(len(s) for s in conformal_sets), filtered_final_test_size)
        avg_len_p = _ratio_or_nan(sum(len(s) for s in pconformal_sets), filtered_final_test_size)

        list_coverage.append(coverage)
        list_coverage_p.append(coverage_p)
        list_avg_len.append(avg_len)
        list_avg_len_p.append(avg_len_p)

        # ---- set size split by whether the top-1 prediction was correct ----
        nb_correct = sum(is_correct)
        nb_wrong = len(is_correct) - nb_correct

        # NaN (instead of ZeroDivisionError) when a split has no correct or no
        # wrong predictions
        avg_len_correct = _ratio_or_nan(
            sum(len(s) for s, ok in zip(conformal_sets, is_correct) if ok), nb_correct
        )
        avg_len_correct_p = _ratio_or_nan(
            sum(len(s) for s, ok in zip(pconformal_sets, is_correct) if ok), nb_correct
        )
        avg_len_wrong = _ratio_or_nan(
            sum(len(s) for s, ok in zip(conformal_sets, is_correct) if not ok), nb_wrong
        )
        avg_len_wrong_p = _ratio_or_nan(
            sum(len(s) for s, ok in zip(pconformal_sets, is_correct) if not ok), nb_wrong
        )

        list_avg_len_correct.append(avg_len_correct)
        list_avg_len_correct_p.append(avg_len_correct_p)
        list_avg_len_wrong.append(avg_len_wrong)
        list_avg_len_wrong_p.append(avg_len_wrong_p)

    # save
    results = {
        "list_coverage": list_coverage,
        "list_coverage_p": list_coverage_p,
        "list_avg_len": list_avg_len,
        "list_avg_len_p": list_avg_len_p,
        "list_avg_len_correct": list_avg_len_correct,
        "list_avg_len_correct_p": list_avg_len_correct_p,
        "list_avg_len_wrong": list_avg_len_wrong,
        "list_avg_len_wrong_p": list_avg_len_wrong_p,
    }

    list_results.append(results)

In [None]:
import os
import pickle

# persist the collected results as a pickle inside the output directory,
# creating the directory if it does not exist yet
save_dir = "output"
os.makedirs(save_dir, exist_ok=True)

results_path = os.path.join(save_dir, "hist.pkl")
with open(results_path, "wb") as results_file:
    pickle.dump(list_results, results_file)

print(f"Results saved in {results_path}")