In [6]:
import numpy as np
import torch
import torchvision.transforms as transforms               # include image preprocess tools
from torchvision.datasets import CIFAR10        # for loading images from Pytorch CIFAR
from torch.utils.data import DataLoader
import detectors
import timm
from src.saps import split_data_set, saps_scores, saps_classification, eval_aps
from src.temperature_scaling import ModelWithTemperature

# check GPU status
print("Is CUDA available:", torch.cuda.is_available())
print("Device count:", torch.cuda.device_count())
print("Device name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No GPU")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# load fine-tuned model
model = timm.create_model("resnet50_cifar10", pretrained=True)
model = model.to(device)

# reprocess the images from CIFAR
data_transform = transforms.Compose([
    transforms.ToTensor(),          # transfer to tensor
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])  # normalize
])
# load images from CIFAR10
dataset = CIFAR10(root="./data", train=False, download=True, transform=data_transform)

# temperature scaling
temp_scal_loader = DataLoader(dataset, batch_size=32, shuffle=True)
model = ModelWithTemperature(model, temperature=5.0).to(device)
model.set_temperature(temp_scal_loader)

model.eval()

# The number of times the experiment is going to be repeated
num_runs = 10

# error rate
alpha = 0.1
lambda_ = 2.0

# construct and evaluate repeatedly
all_avg_set_sizes = []
all_avg_coverages = []
print("SAPS Classification, Start!\n")
for i in range(num_runs):
    print(f"Running experiment {i+1}/{num_runs}...")

    # splite dataset
    calib_dataset, test_dataset = split_data_set(dataset, random_seed=i)

    # load data set respectively
    calib_loader = DataLoader(calib_dataset, batch_size=32, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    # calculate q_hat
    calib_scores, _ = saps_scores(model, calib_loader, alpha, lambda_, device)
    t_cal = np.quantile(calib_scores, 1 - 0.1)  # calculate 1-alpha quantile
    print(f"t_cal = {t_cal}")

    # construct APS
    aps, aps_labels, true_labels = saps_classification(model, test_loader, t_cal, lambda_, device)

    # evaluate APS
    avg_set_size, avg_coverage = eval_aps(aps_labels, true_labels)
    print(f"Average Prediction Set Size After APS in runs {i+1}: {avg_set_size}")
    print(f"Average Coverage Rate in runs {i+1}: {avg_coverage}\n")

    # record current result
    all_avg_set_sizes.append(avg_set_size)
    all_avg_coverages.append(avg_coverage)

# calculate the final average result
final_avg_set_size = np.mean(all_avg_set_sizes)
final_avg_coverage = np.mean(all_avg_coverages)

print(f"Final Average Prediction Set Size: {final_avg_set_size}")
print(f"Final Average Coverage: {final_avg_coverage}")

  from .autonotebook import tqdm as notebook_tqdm


Is CUDA available: True
Device count: 1
Device name: NVIDIA GeForce RTX 3060 Ti
Files already downloaded and verified
Before temperature - NLL: 0.351, ECE: 0.046
Optimal temperature: 4.908
After temperature - NLL: 0.830, ECE: 0.421
SAPS Classification, Start!

Running experiment 1/10...
Samples amount: 10000
t_cal = 0.5872043251991272
Total set size: 5073
Total coverage sets: 4496
Total samples amount: 5000
Average Prediction Set Size After APS in runs 1: 1.0146
Average Coverage Rate in runs 1: 0.8992

Running experiment 2/10...
Samples amount: 10000
t_cal = 0.5821468651294708
Total set size: 5027
Total coverage sets: 4436
Total samples amount: 5000
Average Prediction Set Size After APS in runs 2: 1.0054
Average Coverage Rate in runs 2: 0.8872

Running experiment 3/10...
Samples amount: 10000
t_cal = 0.5883177399635316
Total set size: 5068
Total coverage sets: 4474
Total samples amount: 5000
Average Prediction Set Size After APS in runs 3: 1.0136
Average Coverage Rate in runs 3: 0.8948

# Result

- Final Average **Prediction Set Size：1.015 / 10**
- Final Average **Coverage: 89.85% ($\alpha$=0.1)**