In [51]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from transformers import CLIPProcessor, CLIPModel, AutoProcessor, AutoModel


In [52]:
torch.set_default_device("cuda")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to("cuda")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")


In [53]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # CLIP expects 224x224 images
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])  # CLIP normalization
])


In [54]:
from torch.utils.data import Subset
cifar_data = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)

In [55]:
class_names = cifar_data.classes  # CIFAR-10 class names
text_inputs = processor(text=class_names, return_tensors="pt", padding=True)


In [56]:
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score

def denormalize(img: torch.Tensor, mean: torch.Tensor, std: torch.Tensor):
    """Denormalizes the image given the mean and standard deviation."""
    return img * torch.tensor(std, device="cpu").view(3, 1, 1) + torch.tensor(mean, device="cpu").view(3, 1, 1)

In [57]:
import numpy as np
from itertools import product

seeds = list(range(10))
results = []

for seed in seeds:

    fractions = [0.9, 0.1]
    total_len = len(cifar_data)
    lengths = [int(f * total_len) for f in fractions]
    lengths[-1] = total_len - sum(lengths[:-1])
    generator = torch.Generator(device="cuda").manual_seed(seed)

    # Split the dataset
    test_data, calib_data = torch.utils.data.random_split(cifar_data, lengths, generator=generator)
    calib_loader = DataLoader(calib_data, batch_size=32, shuffle=False)
    test_loader = DataLoader(test_data, batch_size=32, shuffle=False)

    # Forward pass through CLIP

    all_labels = []
    all_predictions = []
    scores = []

    # Compute nonconformity scores

    for images, labels in calib_loader:

        pil_images = [transforms.ToPILImage()(denormalize(img, processor.image_processor.image_mean, processor.image_processor.image_std)) for img in images]
        
        # Process images using CLIP's processor (automatically normalizes them)
        inputs = processor(images=pil_images, return_tensors="pt").to("cuda")
        input_image_processed = inputs['pixel_values'].squeeze(0)

        outputs = model(**inputs, **text_inputs)
        logits_per_image = outputs.logits_per_image  # Image-to-text similarity scores
        probs = logits_per_image.softmax(dim=1)  # Convert to probabilities
        predictions = probs.argmax(dim=1)
        all_labels.extend(labels.tolist())
        all_predictions.extend(predictions.tolist())
        scores += logits_per_image.take_along_dim(torch.tensor(labels).unsqueeze(-1),dim=1).squeeze().tolist()
        


    print(scores)



    alphas = [0.02, 0.05, 0.1, 0.2]
    for alpha in alphas:
        # Compute the quantile for the nonconformity scores
        n = len(scores)
        threshold = np.quantile(scores, np.ceil((n+1)*(alpha))/n, method="inverted_cdf")
        prediction_sets = []
        all_labels = []
        all_predictions = []

        for images, labels in test_loader:

            pil_images = [transforms.ToPILImage()(denormalize(img, processor.image_processor.image_mean, processor.image_processor.image_std)) for img in images]
            
            # Process images using CLIP's processor (automatically normalizes them)
            inputs = processor(images=pil_images, return_tensors="pt").to("cuda")
            input_image_processed = inputs['pixel_values'].squeeze(0)

            outputs = model(**inputs, **text_inputs)
            logits_per_image = outputs.logits_per_image  # Image-to-text similarity scores
            probs = logits_per_image.softmax(dim=1)  # Convert to probabilities
            predictions = probs.argmax(dim=1)
            all_labels.extend(labels.tolist())
            all_predictions.extend(predictions.tolist())
            indices = (logits_per_image > threshold).nonzero(as_tuple=True)
            row_indices = [indices[1][indices[0] == i] for i in range(logits_per_image.size(0))]
            prediction_sets.extend(row_indices)

        pred_sets = [x.tolist() for x in prediction_sets]
        coverage = np.mean([all_labels[i] in pred_sets[i] for i in range(len(all_labels))])
        avg_set_size = np.mean([len(s) for s in pred_sets])
        median_set_size = np.median([len(s) for s in pred_sets])
        acc_score = accuracy_score(all_labels, all_predictions)
        results.append([seed, alpha, coverage, avg_set_size, median_set_size, acc_score])
        # print(f"accuracy =\t\t {acc_score}")
        # print(f"coverage =\t\t {coverage}")
        # print(f"mean set size =\t\t {avg_set_size}")
        # print(f"median set size =\t {median_set_size}")

        print(seed,len(calib_data), len(test_data))

  return func(*args, **kwargs)


[25.252965927124023, 30.183380126953125, 26.752796173095703, 31.09225082397461, 29.038806915283203, 26.53801155090332, 29.167356491088867, 27.731943130493164, 26.997941970825195, 25.050024032592773, 27.23029899597168, 27.58633804321289, 27.784029006958008, 29.511795043945312, 29.560592651367188, 27.244945526123047, 26.311054229736328, 28.24018096923828, 28.928226470947266, 29.320680618286133, 28.326383590698242, 28.0279483795166, 27.81427001953125, 28.319622039794922, 28.82791519165039, 28.410879135131836, 26.64170265197754, 25.915267944335938, 28.487751007080078, 29.502857208251953, 23.946657180786133, 29.691696166992188, 27.95136833190918, 27.778200149536133, 28.544898986816406, 24.80379867553711, 29.2397403717041, 28.111839294433594, 27.02745819091797, 25.928884506225586, 27.475208282470703, 29.77061653137207, 27.922025680541992, 28.84884262084961, 30.116012573242188, 29.483917236328125, 28.033241271972656, 26.00082778930664, 29.474523544311523, 29.419681549072266, 25.58819007873535

  return func(*args, **kwargs)


[27.71703338623047, 29.552125930786133, 26.67011833190918, 29.140979766845703, 23.742088317871094, 29.541954040527344, 28.838624954223633, 28.09546661376953, 26.370080947875977, 28.586206436157227, 22.781124114990234, 26.009660720825195, 28.896156311035156, 23.787492752075195, 23.82308006286621, 30.06288719177246, 25.226852416992188, 24.237815856933594, 30.233034133911133, 27.749902725219727, 28.756174087524414, 28.84610939025879, 27.079326629638672, 26.591678619384766, 24.894052505493164, 28.559648513793945, 28.743946075439453, 23.5189266204834, 25.296815872192383, 29.87729263305664, 25.099014282226562, 29.132158279418945, 28.32421875, 29.933780670166016, 27.456172943115234, 26.751501083374023, 29.554134368896484, 27.800025939941406, 28.29216766357422, 26.38034439086914, 27.758121490478516, 29.41254997253418, 24.55747413635254, 26.057260513305664, 24.107805252075195, 28.882963180541992, 26.901206970214844, 28.3593807220459, 29.888227462768555, 29.54971694946289, 28.639713287353516, 26

  return func(*args, **kwargs)


[29.623470306396484, 21.657054901123047, 27.120384216308594, 28.649860382080078, 26.33731460571289, 26.21457290649414, 29.08243179321289, 27.1667423248291, 28.284347534179688, 24.23440170288086, 26.068443298339844, 22.775157928466797, 29.87275505065918, 27.493968963623047, 27.820093154907227, 30.273303985595703, 26.395395278930664, 27.330219268798828, 27.177671432495117, 28.033525466918945, 24.33786392211914, 29.61410903930664, 25.44194793701172, 25.979825973510742, 28.605934143066406, 28.155349731445312, 26.59515953063965, 26.974811553955078, 30.408870697021484, 26.25025177001953, 23.912647247314453, 27.570024490356445, 29.507869720458984, 25.39529800415039, 27.183443069458008, 27.655277252197266, 27.696456909179688, 27.459726333618164, 29.570039749145508, 24.999252319335938, 29.90341567993164, 28.639511108398438, 27.49066734313965, 28.85509490966797, 27.052425384521484, 29.661619186401367, 29.183874130249023, 26.658815383911133, 27.18246078491211, 28.985082626342773, 28.9264583587646

  return func(*args, **kwargs)


[27.078702926635742, 25.615083694458008, 28.174325942993164, 25.08989906311035, 27.295854568481445, 26.71687889099121, 26.42276954650879, 30.044227600097656, 30.468666076660156, 26.659982681274414, 27.989370346069336, 25.355806350708008, 28.62241554260254, 27.84307861328125, 27.540735244750977, 28.512826919555664, 25.310802459716797, 28.99250030517578, 28.04493522644043, 25.89023780822754, 27.207645416259766, 29.619327545166016, 29.297157287597656, 28.337961196899414, 25.24496078491211, 29.48351287841797, 30.71615982055664, 27.83784294128418, 26.36544418334961, 24.57988929748535, 28.90879249572754, 24.768646240234375, 30.278148651123047, 26.889631271362305, 25.920900344848633, 28.6366024017334, 27.850914001464844, 26.52779769897461, 25.913578033447266, 27.15170669555664, 26.67626953125, 28.792726516723633, 28.90397834777832, 29.161849975585938, 28.165021896362305, 28.07375144958496, 28.13074493408203, 29.06459617614746, 29.851791381835938, 26.235204696655273, 26.45162010192871, 24.6646

  return func(*args, **kwargs)


[25.368871688842773, 29.66478157043457, 30.272478103637695, 26.55723762512207, 23.90666961669922, 29.970962524414062, 28.256410598754883, 26.87456512451172, 29.113054275512695, 29.54493522644043, 24.3740291595459, 26.56829071044922, 27.327302932739258, 30.776674270629883, 30.499813079833984, 25.15605354309082, 29.541215896606445, 28.665998458862305, 24.685583114624023, 26.351076126098633, 29.131242752075195, 26.426040649414062, 27.74399185180664, 25.422626495361328, 26.086612701416016, 29.1688232421875, 26.213546752929688, 25.63544464111328, 28.33953285217285, 28.046890258789062, 24.481168746948242, 25.53298568725586, 28.99321937561035, 27.881479263305664, 26.214113235473633, 26.980995178222656, 30.44023895263672, 25.38074493408203, 30.34706687927246, 30.047744750976562, 27.49289894104004, 26.864524841308594, 29.004919052124023, 26.86035919189453, 27.796390533447266, 28.866689682006836, 27.959379196166992, 29.0628662109375, 31.21271514892578, 27.481529235839844, 29.547849655151367, 27.

  return func(*args, **kwargs)


[26.21415901184082, 30.03314208984375, 29.85943031311035, 26.454761505126953, 27.509864807128906, 27.517446517944336, 29.0345458984375, 25.928571701049805, 25.201385498046875, 25.816654205322266, 26.486652374267578, 27.6613826751709, 23.85355567932129, 24.48591423034668, 27.015071868896484, 27.690404891967773, 30.558958053588867, 25.860139846801758, 29.898273468017578, 28.52338218688965, 24.32016944885254, 26.519933700561523, 29.49410629272461, 28.91126251220703, 29.21961784362793, 26.80060577392578, 28.9025936126709, 25.519763946533203, 28.3655948638916, 27.23847007751465, 25.59747314453125, 28.32854461669922, 26.841588973999023, 30.183320999145508, 24.237815856933594, 28.176647186279297, 29.140979766845703, 27.039384841918945, 27.303869247436523, 24.85445213317871, 25.258726119995117, 27.95047378540039, 25.499099731445312, 27.635360717773438, 26.86526107788086, 28.652969360351562, 29.098072052001953, 28.80093765258789, 27.363496780395508, 25.313701629638672, 28.407733917236328, 28.45

  return func(*args, **kwargs)


[27.339685440063477, 29.089406967163086, 24.752674102783203, 25.726905822753906, 26.03914451599121, 30.52350616455078, 27.044248580932617, 27.138551712036133, 22.37965965270996, 30.337125778198242, 28.2115421295166, 29.624975204467773, 26.452531814575195, 30.264965057373047, 25.541015625, 27.00617027282715, 26.052955627441406, 29.81424331665039, 26.663869857788086, 30.53702735900879, 23.695024490356445, 28.1962833404541, 25.32269287109375, 28.96825408935547, 29.47199058532715, 28.439367294311523, 28.774028778076172, 30.6400146484375, 29.9622745513916, 28.8768253326416, 26.607507705688477, 27.278549194335938, 26.958168029785156, 29.01482582092285, 28.6894588470459, 26.153480529785156, 29.768735885620117, 29.08829116821289, 28.630596160888672, 29.161319732666016, 30.41303253173828, 28.15997314453125, 25.52660369873047, 28.83027458190918, 29.53525161743164, 29.187223434448242, 28.342575073242188, 31.21038055419922, 26.07210350036621, 28.839223861694336, 28.419214248657227, 29.085290908813

  return func(*args, **kwargs)


[27.190723419189453, 29.183874130249023, 24.912893295288086, 29.69953727722168, 29.07370948791504, 26.647275924682617, 29.892576217651367, 24.72152328491211, 27.067211151123047, 29.262784957885742, 30.83416748046875, 26.80060577392578, 27.559412002563477, 24.28409767150879, 29.51938247680664, 26.201778411865234, 29.493314743041992, 23.912504196166992, 25.528156280517578, 27.846315383911133, 26.191091537475586, 27.850914001464844, 29.001602172851562, 22.787885665893555, 22.890884399414062, 27.04213523864746, 28.420642852783203, 25.655153274536133, 27.50726890563965, 23.474510192871094, 24.86570930480957, 30.536968231201172, 27.71233558654785, 25.27901840209961, 27.052425384521484, 27.230506896972656, 27.86720085144043, 29.369434356689453, 28.204185485839844, 28.06543731689453, 26.21415901184082, 29.42342185974121, 24.550662994384766, 29.838407516479492, 28.250810623168945, 24.08659553527832, 25.457611083984375, 28.97809410095215, 29.49164581298828, 24.2868709564209, 28.02889060974121, 2

  return func(*args, **kwargs)


[27.007814407348633, 29.29060173034668, 27.758121490478516, 29.255237579345703, 24.23723602294922, 25.605846405029297, 27.743022918701172, 26.009660720825195, 22.824153900146484, 29.43044662475586, 24.756011962890625, 26.811174392700195, 28.813631057739258, 29.836400985717773, 27.64727210998535, 28.67323112487793, 29.820329666137695, 26.916826248168945, 29.105836868286133, 30.989221572875977, 24.058731079101562, 26.900625228881836, 29.470699310302734, 28.17317771911621, 28.423368453979492, 30.780052185058594, 29.602140426635742, 28.012523651123047, 28.46643829345703, 30.649152755737305, 26.68900489807129, 27.698776245117188, 29.278139114379883, 29.03980827331543, 29.990779876708984, 27.973634719848633, 28.563356399536133, 30.811668395996094, 29.610624313354492, 28.783958435058594, 25.863039016723633, 29.120630264282227, 28.91157341003418, 28.9851016998291, 27.476146697998047, 24.063091278076172, 25.296588897705078, 29.64324378967285, 25.67626190185547, 23.586910247802734, 26.9385032653

  return func(*args, **kwargs)


[26.69907569885254, 27.726703643798828, 24.79353904724121, 26.568456649780273, 29.708194732666016, 30.811012268066406, 27.700843811035156, 24.060087203979492, 27.48179817199707, 26.129606246948242, 29.385173797607422, 29.836400985717773, 26.792816162109375, 27.418798446655273, 28.909603118896484, 26.479413986206055, 28.510454177856445, 25.338972091674805, 26.45391845703125, 25.493419647216797, 27.70734977722168, 28.425857543945312, 23.77528190612793, 25.99851417541504, 27.070650100708008, 27.951631546020508, 23.72539710998535, 26.1348876953125, 30.496810913085938, 26.206663131713867, 27.437084197998047, 30.48931312561035, 25.336122512817383, 25.443401336669922, 29.995342254638672, 28.79241180419922, 29.482418060302734, 25.33019256591797, 24.492280960083008, 28.595130920410156, 28.245586395263672, 28.028955459594727, 28.017324447631836, 29.12392234802246, 28.270206451416016, 30.67861557006836, 29.501041412353516, 26.198200225830078, 27.627485275268555, 24.35097312927246, 25.492959976196

In [58]:
import pandas as pd
result_df = pd.DataFrame(data=results, columns=["seed", "alpha", "coverage", "avg_set_size", "median_set_size", "acc_score"])

In [59]:
result_df.to_csv("cifar10_clip_results.csv")

In [None]:

agg = result_df.groupby('alpha').agg(['mean', 'std'])

# Step 2: Flatten MultiIndex columns
agg.columns = ['_'.join(col).strip() for col in agg.columns.values]

# Step 3: Format into "mean ± std" for each numeric column
formatted = pd.DataFrame(index=agg.index)

cols = [col for col in result_df.columns if col not in ["alpha", "seed", "acc_score", "median_set_size"]]

for col in cols:  # Skip 'Group'
    mean_col = f"{col}_mean"
    std_col = f"{col}_std"
    formatted[col] = agg[mean_col].round(4).astype(str) + ' ± ' + agg[std_col].round(4).astype(str)

print(formatted)


              coverage     avg_set_size median_set_size
alpha                                                  
0.02   0.9791 ± 0.0036   4.9845 ± 0.245    5.1 ± 0.3162
0.05   0.9492 ± 0.0054  3.5225 ± 0.1504       3.0 ± 0.0
0.10   0.8984 ± 0.0109  2.3808 ± 0.1465       2.0 ± 0.0
0.20   0.7957 ± 0.0117   1.4692 ± 0.061       1.0 ± 0.0


In [86]:
print(formatted.to_latex())

\begin{tabular}{llll}
\toprule
 & coverage & avg_set_size & median_set_size \\
alpha &  &  &  \\
\midrule
0.020000 & 0.9791 ± 0.0036 & 4.9845 ± 0.245 & 5.1 ± 0.3162 \\
0.050000 & 0.9492 ± 0.0054 & 3.5225 ± 0.1504 & 3.0 ± 0.0 \\
0.100000 & 0.8984 ± 0.0109 & 2.3808 ± 0.1465 & 2.0 ± 0.0 \\
0.200000 & 0.7957 ± 0.0117 & 1.4692 ± 0.061 & 1.0 ± 0.0 \\
\bottomrule
\end{tabular}

