In [None]:
# import required packages
import ast
import io
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from matplotlib.patches import Patch
from matplotlib.ticker import MultipleLocator
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

from models import *

# Run on the GPU when torch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

In [None]:
class CustomImageDataset(Dataset):
    def __init__(self, csv_file=None, transform=None):
        if csv_file:
            self.data_frame = pd.read_csv(csv_file, delimiter=',')  # Read CSV file
        else:
            self.data_frame = pd.DataFrame()
        self.transform = transform

    def __len__(self):
        # return the total number of samples
        return len(self.data_frame)

    def __getitem__(self, idx):

        # get the binary image data and label
        image_bytes = self.data_frame.iloc[idx, 3]  # the image data is in the fourth column
        label = int(self.data_frame.iloc[idx, 2])  # the label is in the second column
        user_id = self.data_frame.iloc[idx, 0]

        # convert the binary data to an image
        png_binary = eval(image_bytes)  
        image = Image.open(io.BytesIO(png_binary)) 

        # apply transformations if any
        if self.transform:
            image = self.transform(image)

        return image, label, user_id
    
    def filter_indices_by_user(self, user_id):
        return self.data_frame[self.data_frame.iloc[:, 0] == user_id].index.tolist()

# Preprocessing for every decoded image: conversion to a tensor only
# (no resizing, augmentation or normalisation).
transform = transforms.Compose(
    [transforms.ToTensor()]
)

In [None]:
# Held-out samples read from test.csv; the conformal cell resamples them into
# calibration / final-test subsets on every iteration.
test_dataset = CustomImageDataset(transform=transform, csv_file='test.csv')

In [None]:
net = LeNet62().to(device)

criterion = nn.CrossEntropyLoss()  # cross-entropy, mean-reduced over the batch

checkpoint_path = "weights/LeNet_0.1_100_512_SGD"
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted source
# (torch >= 1.13 accepts weights_only=True to restrict loading to tensors).
checkpoint = torch.load(checkpoint_path, map_location=device)

# The checkpoint may have been saved from a DataParallel-wrapped model (keys prefixed with
# "module.") or from a bare one. Strip any such prefix and load into the *bare* network
# before wrapping, so the same file loads on both CPU and GPU.
torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(checkpoint, "module.")
net.load_state_dict(checkpoint)

if device == 'cuda':
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True

net.eval();

In [None]:
# Conformal prediction experiment. For each of `Niter` random splits of `test_dataset`:
#   1. score N_CALIBRATION calibration samples (cross-entropy at their true label);
#   2. draw one point from the next N_FINAL_TEST samples and score it against every label;
#   3. for every alpha, keep the labels whose ratio (n+1)*S / (sum_cal + S) is below 1/alpha;
#   4. take the smallest alpha whose set has at most C labels and check whether that set
#      contains the true label.

softmax = nn.Softmax(dim=1)  # NOTE(review): not used in this cell; kept in case other cells rely on it

Niter = 100                           # number of random calibration/final-test splits
C = 3                                 # max target size of the conformal set
alphas = np.arange(0.01,0.31,0.01)    # candidate alpha levels, increasing
N_CALIBRATION = 2000                  # calibration samples per split
N_FINAL_TEST = 2000                   # size of the pool the single test point is drawn from
N_CLASSES = 62                        # candidate labels 0 .. N_CLASSES-1
SEED = None                           # set an int to make the random splits reproducible
PLOT_CONFORMAL_SETS = False           # True draws one label-vs-alpha figure per split

if SEED is not None:
    np.random.seed(SEED)

assert len(test_dataset) > N_CALIBRATION, "test_dataset is too small for the calibration split"

# The network's weights do not change in this cell, so a row's scores are the same in every
# split. Cache them per dataset row so each image goes through the network at most once
# across all splits. The cache is rebuilt whenever this cell runs.
_class_score_cache = {}


def class_scores(row):
    """Return ``(scores, label)`` for the sample at position ``row`` of ``test_dataset``.

    ``scores[k]`` is the cross-entropy of the sample's logits against label ``k``, i.e.
    ``-log_softmax(logits)[k]`` -- the value ``nn.CrossEntropyLoss()(logits, [k])`` gives
    for a single-sample batch. ``scores`` is a float64 numpy array, ``label`` an int.
    """
    row = int(row)
    if row not in _class_score_cache:
        x, y, _ = test_dataset[row]
        with torch.no_grad():
            logits = net(x.unsqueeze(0).to(device))
            scores = -torch.log_softmax(logits, dim=1)[0]
        _class_score_cache[row] = (scores.double().cpu().numpy(), int(y))
    return _class_score_cache[row]


def plot_conformal_sets(conformal_sets, alphas):
    """Plot which labels belong to the conformal set at each alpha, plus the set size."""
    characters = list('0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz')

    fig, ax1 = plt.subplots(figsize=(15, 10))

    binary_matrix = np.zeros((len(characters), len(alphas)))
    for col_idx, conformal_set in enumerate(conformal_sets):
        for label in conformal_set:
            binary_matrix[label, col_idx] = 1

    ax1.imshow(binary_matrix, cmap="Blues", aspect="auto", interpolation="nearest")
    for col_idx in range(len(alphas)):
        ax1.axvline(col_idx - 0.5, color="gray", linestyle="--", linewidth=0.5, alpha=0.7)

    ax1.set_xticks(range(len(alphas)))
    ax1.set_xticklabels([f"{a:.2f}" for a in alphas], fontsize=10)
    ax1.set_xlabel(r"$\alpha$", fontsize=20)
    ax1.set_yticks(range(len(characters)))
    ax1.set_yticklabels(characters, fontsize=14)
    ax1.set_ylabel("Labels", fontsize=20)

    ax2 = ax1.twinx()
    conformal_set_sizes = [len(s) for s in conformal_sets]
    ax2.plot(range(len(alphas)), conformal_set_sizes, color="#E41A1C", label="Conformal Set Size", linewidth=3)
    ax2.set_ylabel("Conformal Set Size", fontsize=14, color="#E41A1C")
    ax2.tick_params(axis="y", labelcolor="#E41A1C", labelsize=12)

    merged_handles = [
        Patch(color=plt.cm.Blues(1.0), label="Conformal Set"),
        plt.Line2D([0], [0], color="#E41A1C", linestyle="-", linewidth=2, markersize=6, label="Conformal Set Size"),
    ]
    ax1.legend(handles=merged_handles, loc="upper right", fontsize=16)

    ax2.yaxis.set_major_locator(MultipleLocator(5))
    ax2.yaxis.set_minor_locator(MultipleLocator(1))
    ax2.set_ylim(0, len(characters))

    plt.tight_layout()
    plt.show()
    plt.close(fig)


alphas_mins = []   # per split: smallest alpha whose conformal set has <= C labels
ratios = []        # per split: 0 if the chosen set covers the true label, else 1/alpha_min
proba = 0          # number of splits whose chosen set covered the true label

for it in range(Niter):
    print(it)

    # randomize calibration/final test sets (positions into test_dataset.data_frame)
    indices = np.random.permutation(len(test_dataset.data_frame))
    calibration_indices = indices[:N_CALIBRATION]
    final_test_indices = indices[N_CALIBRATION:N_CALIBRATION + N_FINAL_TEST]

    # calibration nonconformity scores: cross-entropy at each sample's true label
    scores_calibration = []
    for row in calibration_indices:
        row_scores, row_label = class_scores(row)
        scores_calibration.append(float(row_scores[row_label]))
    n_label = len(scores_calibration)
    sum_label = sum(scores_calibration)

    # sample one random element from the final test set and score it against every label
    random_idx = np.random.choice(len(final_test_indices))
    test_scores, true_label = class_scores(final_test_indices[random_idx])
    S = test_scores[:N_CLASSES]

    # the ratio for each candidate label does not depend on alpha, so compute it once
    e_ratio = (n_label + 1) * S / (sum_label + S)

    # conformal set for each alpha: the labels whose ratio stays below 1/alpha
    conformal_sets = [np.flatnonzero(e_ratio < 1 / alpha).tolist() for alpha in alphas]

    # Smallest alpha such that Card(conformal set) <= C (sets shrink as alpha grows, since
    # the threshold 1/alpha decreases). If no alpha in the grid qualifies, index -1 selects
    # the largest alpha, whose set may still hold more than C labels.
    min_index = next((i for i, s in enumerate(conformal_sets) if len(s) <= C), -1)
    alpha_min = alphas[min_index]
    alphas_mins.append(alpha_min)

    if true_label in conformal_sets[min_index]:
        proba += 1
        ratios.append(0)
    else:
        ratios.append(1 / alpha_min)

    if PLOT_CONFORMAL_SETS:
        plot_conformal_sets(conformal_sets, alphas)

proba = proba / Niter                    # empirical coverage of the chosen sets
alpha_mean = sum(alphas_mins) / Niter
print(proba)
print(1 - alpha_mean)

In [None]:
# Histogram of 1 - alpha_min over the splits. Red dashed line: empirical coverage (`proba`);
# black dashed line: mean of 1 - alpha_min.
one_minus_alpha = 1 - np.array(alphas_mins)
grid = 1 - np.array(alphas[::-1])  # every value 1 - alpha_min can take, increasing
step = float(np.min(np.diff(grid))) if grid.size > 1 else 0.01
# One bin per grid value, with edges half a step on either side. Using the grid values
# themselves as edges merged the two largest values into one bar, because numpy's last
# histogram bin is closed on the right.
bin_edges = np.append(grid - step / 2, grid[-1] + step / 2)

fig_alpha, ax_alpha = plt.subplots(figsize=(7, 4))
counts, _, _ = ax_alpha.hist(one_minus_alpha, bins=bin_edges, color='green', edgecolor='black', alpha=0.6)
ax_alpha.axvline(x=proba, color='red', linestyle='--', linewidth=3, label='Empirical coverage')
ax_alpha.axvline(x=np.mean(one_minus_alpha), color='black', linestyle='--', linewidth=3,
                 label=r'Mean $1-\tilde{\alpha}$')
ax_alpha.set_xlabel(r"$1-\tilde{\alpha}$", fontsize=16)
ax_alpha.set_ylabel("Frequency", fontsize=16)
ax_alpha.set_title("Distribution of " + r"$1-\tilde{\alpha}$" + f" over {len(alphas_mins)} splits", fontsize=14)

# NOTE: values of 1 - alpha_min below 0.8 fall outside this x-range.
ax_alpha.set_xlim(0.8, 1)
# keep the 0-45 range unless a bar would otherwise be clipped
ax_alpha.set_ylim(0, max(45, 1.05 * counts.max()))

ax_alpha.tick_params(axis='both', labelsize=14)

ax_alpha.set_xticks(np.linspace(0.8, 1.0, 11))
ax_alpha.set_xticks(np.linspace(0.8, 1.0, 21), minor=True)

ax_alpha.grid(True, linestyle="--", alpha=0.4)
ax_alpha.legend(fontsize=12)

plt.show()