In [None]:
# import required packages
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import Dataset
from torchvision import transforms
from models import *
import matplotlib.pyplot as plt
import numpy as np
import random
from matplotlib.patches import Patch
import pandas as pd
from PIL import Image
import io

%matplotlib inline

# run on the GPU when torch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

In [None]:
class CustomImageDataset(Dataset):
    """Dataset backed by a DataFrame whose rows hold a user id, a label and an image.

    Columns are addressed by position: column 0 is the user id, column 2 the
    integer label and column 3 the encoded image bytes (stored as the text of
    a Python bytes literal; presumably ``b'...'`` as written by ``to_csv`` --
    TODO confirm against the code that produced the CSV).

    Args:
        csv_file: path of the CSV file to load. When falsy, the dataset starts
            with an empty DataFrame and ``data_frame`` can be assigned later.
        transform: optional callable applied to the decoded PIL image.
    """

    def __init__(self, csv_file=None, transform=None):
        if csv_file:
            self.data_frame = pd.read_csv(csv_file, delimiter=',')
        else:
            self.data_frame = pd.DataFrame()
        self.transform = transform

    def __len__(self):
        """Return the total number of samples (rows of ``data_frame``)."""
        return len(self.data_frame)

    @staticmethod
    def _decode_image_bytes(raw):
        """Turn the stored image cell into raw bytes.

        The cell is parsed with ``ast.literal_eval`` instead of ``eval`` so
        that text coming from the CSV can only ever produce a literal and is
        never executed as code. Values that already are bytes are passed
        through unchanged.

        Raises:
            ValueError / SyntaxError: if the text is not a Python literal.
            TypeError: if the literal is not a bytes object.
        """
        import ast  # local import: the notebook's import cell does not provide it

        if isinstance(raw, (bytes, bytearray)):
            return bytes(raw)
        decoded = ast.literal_eval(raw)
        if not isinstance(decoded, (bytes, bytearray)):
            raise TypeError(
                f"expected a bytes literal in the image column, got {type(decoded).__name__}"
            )
        return bytes(decoded)

    def __getitem__(self, idx):
        """Return ``(image, label, user_id)`` for the row at position ``idx``."""
        # columns are read by position (0-based)
        image_bytes = self.data_frame.iloc[idx, 3]  # the image data is in the fourth column
        label = int(self.data_frame.iloc[idx, 2])  # the label is in the third column
        user_id = self.data_frame.iloc[idx, 0]  # the user id is in the first column

        # convert the binary data to an image
        png_binary = self._decode_image_bytes(image_bytes)
        image = Image.open(io.BytesIO(png_binary))

        # apply transformations if any
        if self.transform:
            image = self.transform(image)

        return image, label, user_id

    def filter_indices_by_user(self, user_id):
        """Return the row positions whose first column equals ``user_id``.

        Positions (not index labels) are returned because ``__getitem__``
        addresses rows with ``iloc``; for a default RangeIndex the two are
        the same, for a shuffled or filtered frame only positions are valid.
        """
        matches = (self.data_frame.iloc[:, 0] == user_id).to_numpy()
        return np.flatnonzero(matches).tolist()

# preprocessing pipeline applied to every decoded image: a single ToTensor step
transform = transforms.Compose([transforms.ToTensor()])

In [None]:
# load the held-out samples; later cells re-split this frame into
# calibration and final-test halves
test_dataset = CustomImageDataset(csv_file='test.csv', transform=transform)

In [None]:
# build the network and move it to the selected device
net = LeNet62()

net = net.to(device)
if device == 'cuda':
    # on GPU the model is wrapped, which prefixes every parameter name with "module."
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True

criterion = nn.CrossEntropyLoss()

checkpoint_path = "weights/LeNet_0.1_100_512_SGD"
checkpoint = torch.load(checkpoint_path, map_location=device)

# load the weights into the model
# The checkpoint may have been saved from a DataParallel-wrapped model (keys
# start with "module.") or from a bare one. Strip the prefix if present and
# load into the unwrapped module, so the same file loads on CPU and on GPU.
torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(checkpoint, "module.")
bare_net = net.module if isinstance(net, torch.nn.DataParallel) else net
bare_net.load_state_dict(checkpoint)

# inference mode; as the last expression this also displays the architecture
net.eval()

In [None]:
# experiment settings
softmax = nn.Softmax(dim=1)  # turns logits into class probabilities, which feed the conformal scores
alpha = 0.15                 # a label enters a conformal set while the test statistic stays below 1 / alpha
n_writers = 50               # number of writers sampled; one writer per time step / batch

In [None]:
Niter = 100
# per-run results; each list receives one entry per run
total_lengths = []
total_ratios = []
total_cumul_ratios = []
list_covered = [] # append 1 if all conformal sets (t=1,...,T) contain the true labels

# the writers are drawn once and reused, in the same order, by every run
value_counts = test_dataset.data_frame.iloc[:, 0].value_counts()
random_writers = random.sample(sorted(value_counts.keys()), n_writers)

calibration_size = int(0.5 * len(test_dataset))
final_test_size = len(test_dataset) - calibration_size


def nonconformity_score(probability):
    """Score s(p) = 1 / log(1 + p)^(1/4) of a predicted class probability.

    Works element-wise, so it accepts a single probability or a whole
    probability vector. The score grows as the probability shrinks
    (p = 0 yields inf).
    """
    return 1 / torch.log(1 + probability) ** (1 / 4)


for run in range(Niter):
    print(run)

    # randomize calibration/final test sets
    indices = np.random.permutation(len(test_dataset.data_frame))
    calibration_indices = indices[:calibration_size]
    final_test_indices = indices[calibration_size:]
    calibration_df = test_dataset.data_frame.iloc[calibration_indices]
    final_test_df = test_dataset.data_frame.iloc[final_test_indices]
    calibration_set = CustomImageDataset(csv_file=None, transform=transform)
    calibration_set.data_frame = calibration_df.reset_index(drop=True)
    final_test_set = CustomImageDataset(csv_file=None, transform=transform)
    final_test_set.data_frame = final_test_df.reset_index(drop=True)

    # initialize variables to store results
    conformal_sets = []      # store conformal sets
    cumul_ratios = []        # store martingale (cumulative ratio)
    ratios = []
    model_predictions = []
    true_labels = []
    cumul_ratio = 1          # martingale value before the first step

    # for computing coverage
    covered = 1 # change to 0 if one of the conformal sets does not contain the true label

    # iterate through writers (= batches)
    for t in range(n_writers):
        writer_id = random_writers[t]

        # step 1: compute scores of the true label for this writer's calibration samples
        label_indices = calibration_set.filter_indices_by_user(writer_id)
        n_label = len(label_indices)
        scores_calibration = []
        with torch.no_grad():
            for idx in label_indices:
                x_sample, y_true, _ = calibration_set[idx]
                probs = softmax(net(x_sample.unsqueeze(0).to(device))).squeeze(0)
                scores_calibration.append(nonconformity_score(probs[y_true]))

        # stays the int 0 when the writer has no calibration sample
        sum_label = sum(scores_calibration)

        # step 2: sample one random element of this writer from the final test set
        test_indices_with_label = final_test_set.filter_indices_by_user(writer_id)
        if not test_indices_with_label:
            raise ValueError(
                f"writer {writer_id!r} has no sample in the final test split of run {run}"
            )
        random_idx = np.random.choice(test_indices_with_label)  # randomly select one index
        x_random, y_random, _ = final_test_set[random_idx]
        y_random = int(y_random)
        true_labels.append(y_random)

        # step 3: score every candidate class of the random sample in one pass
        with torch.no_grad():
            logits_random = net(x_random.unsqueeze(0).to(device))
            candidate_scores = nonconformity_score(softmax(logits_random).squeeze(0))
            # ratio of each candidate's score to the mean score of calibration + candidate
            candidate_ratios = (n_label + 1) * candidate_scores / (sum_label + candidate_scores)
        model_prediction = torch.argmax(logits_random).item()
        model_predictions.append(model_prediction)

        # compute the conformal set: every class is tested against the martingale
        # value of the PREVIOUS step; the martingale is only updated afterwards,
        # so membership never depends on where the true label sits among the classes
        in_set = cumul_ratio * candidate_ratios < 1 / alpha
        conformal_set = [k for k, keep in enumerate(in_set.tolist()) if keep]

        # update the martingale with the ratio of the true label
        ratio = candidate_ratios[y_random]
        cumul_ratio = cumul_ratio * ratio  # accumulate the product
        ratios.append(ratio.item())
        cumul_ratios.append(cumul_ratio.item())

        # store results
        conformal_sets.append(conformal_set)

        # coverage
        if y_random not in conformal_set:
            covered = 0

    list_covered.append(covered)

    lengths_this_run = [len(cs) for cs in conformal_sets]

    print("Lengths:", lengths_this_run)
    print("Conformal sets for each label:", conformal_sets)
    print("Model predictions:", model_predictions)
    print("True labels:", true_labels)
    print("Cumulative ratios:", cumul_ratios)
    print("Individual ratios:", ratios)

    total_lengths.append(lengths_this_run)
    total_ratios.append(ratios)
    total_cumul_ratios.append(cumul_ratios)

    ###########################################################
    # PLOT CONFORMAL SETS (binary matrix)
    ###########################################################

    # # convert conformal sequence to a binary matrix
    # num_labels = 62
    # num_steps = len(conformal_sets)
    # binary_matrix = np.zeros((num_labels, num_steps))  # time on X-axis
    # for t, C_t in enumerate(conformal_sets):
    #     for label in C_t:
    #         binary_matrix[label, t] = 1

    # # plot binary matrix
    # plt.figure(figsize=(15, 10))
    # plt.imshow(binary_matrix, cmap="Greys", aspect="auto", interpolation="nearest")

    # # overlay true labels (red stars) and predictions (green triangles)
    # for t, (true_label, prediction) in enumerate(zip(true_labels, model_predictions)):
    #     plt.scatter(t, true_label, color="red", marker="*", s=150, label="True Label" if t == 0 else "")
    #     plt.scatter(t, prediction, color="green", marker="^", s=100, label="Prediction" if t == 0 else "")

    # # labels and formatting
    # plt.xlabel(r"Time Steps $t$", fontsize=12)
    # plt.ylabel("Labels", fontsize=12)
    # plt.xticks(range(num_steps), labels=np.arange(1,num_steps+1)) # start at t=1 and not t=0
    # characters = [
    #     '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
    #     'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J',
    #     'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T',
    #     'U', 'V', 'W', 'X', 'Y', 'Z', 
    #     'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j',
    #     'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't',
    #     'u', 'v', 'w', 'x', 'y', 'z'
    # ]
    # plt.yticks(range(num_labels), labels=characters)
    # conformal_patch = Patch(color="black", label="Conformal Set")
    # true_label_marker = plt.Line2D([0], [0], color="red", marker="*", linestyle="None", markersize=10, label="True Label")
    # prediction_marker = plt.Line2D([0], [0], color="green", marker="^", linestyle="None", markersize=8, label="Prediction")
    # plt.legend(
    #     handles=[conformal_patch, true_label_marker, prediction_marker],
    #     loc="upper right",
    #     fontsize=10
    # )
    # plt.grid(axis="y", linestyle="--", alpha=0.5)
    # plt.tight_layout()
    # plt.show()

    ###########################################################
    # PLOT MARTINGALE
    ###########################################################

    # plt.figure(figsize=(8, 5))
    # plt.plot(range(1, 50 + 1), cumul_ratios, marker="o", linestyle="-", color="blue")

    # plt.xlabel(r"Time steps $t$", fontsize=14)
    # plt.ylabel(r"Martingale $M_t$", fontsize=14)
    # plt.xticks(range(1, 50 + 1, 2), fontsize=12) 
    # plt.xticks(range(1, 50 + 1), fontsize=12, minor=True) 

    # plt.yticks(fontsize=12)
    # plt.grid(True, linestyle="--", alpha=0.6)

    # plt.xlim(1,50)

    # plt.tight_layout()
    
    # plt.show()

In [None]:
# plot histograms of conformal set sizes for each batch
# one panel per batch (writer); the grid is derived from n_writers so the cell
# keeps working when the number of writers is changed (10 x 5 for 50 writers)
n_cols = 5
n_rows = (n_writers + n_cols - 1) // n_cols
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 2 * n_rows),
                         constrained_layout=True, squeeze=False)

# flatten the axes array for easy iteration
axes = axes.ravel()

# bins of width 3 over 0..60, plus a last narrower bin [60, 62]
size_bins = np.append(np.arange(0, 63, 3), 62)

for i in range(n_writers):

    # set size observed at batch i in every run
    lengths = [input_list[i] for input_list in total_lengths]

    axes[i].hist(lengths, bins=size_bins,
                 color='green', alpha=0.6, align='mid')

    axes[i].set_xticks(np.arange(0, 62, 5), minor=False)
    axes[i].set_xticks(np.arange(0, 62, 1), minor=True)

    axes[i].set_yticks(np.arange(0, 61, 10))
    axes[i].set_ylim(0, 61)
    axes[i].set_xlim(0, 62)

    axes[i].tick_params(axis='both', which='major', labelsize=10)

    axes[i].set_title(f"Batch {i+1}", fontsize=12)

# hide the panels of an incomplete last row, if any
for unused_ax in axes[n_writers:]:
    unused_ax.set_visible(False)

plt.show()


In [None]:
# plot histogram of all sizes concatenated
# pool the set sizes of every run (first 50 batches of each) into one array
sizes_per_run = [input_list[:50] for input_list in total_lengths]
all_lengths = np.concatenate(sizes_per_run)
unit_bins = np.arange(0, 63)  # one bin per possible size

plt.figure(figsize=(7, 4))
plt.hist(all_lengths, bins=unit_bins, color='green', alpha=0.6, align='mid')

plt.xlabel("Length", fontsize=12)
plt.ylabel("Frequency", fontsize=12)

plt.xticks(np.arange(0, 62, 5), fontsize=10)
plt.yticks(np.arange(0, 510, 50), fontsize=10)
plt.xlim(0, 62)

plt.grid(True, linestyle="--", alpha=0.7)

plt.show()

In [None]:
# fraction of runs in which every conformal set contained the true label
n_fully_covered_runs = sum(list_covered)
n_fully_covered_runs / len(list_covered)