# 1. load and split image set

In [1]:
from torchvision.datasets import ImageFolder              # for loading images from ImageNet
from torch.utils.data import DataLoader, random_split

def split_data_set(dataset, random_seed):
    if random_seed is not None:
        torch.manual_seed(random_seed)  # set input as random seed

    # split image set ---> half for calibration data set, half for test data set
    dataset_length = len(dataset)
    print(f"Samples amount: {dataset_length}")
    calib_length = dataset_length // 2               
    test_length = dataset_length - calib_length      

    calib_dataset, test_dataset = random_split(dataset, [calib_length, test_length])
    return calib_dataset, test_dataset

# 2. Calculate Conformal Score

In [2]:
import numpy as np

# conformal function s(x,y)
def conformal_scores(model, dataloader, alpha=0.1):
    scores = []  # conformal scores of image sets
    labels = []  # true label sets 
    with torch.no_grad():
        for images, true_labels in dataloader:
            images, true_labels = images.to(device), true_labels.to(device)
            # logistic value
            outputs = model(images)
            # logistic value -> softmax
            # dim=1 : convert logistic values for all the classes of the example to the softmax 
            softmaxs = torch.softmax(outputs, dim=1)
            
            for softmax, true_label in zip(softmaxs, true_labels):
                # descending sort softmax
                sorted_softmax, sorted_index = torch.sort(softmax, descending=True)
                
                # get the position of the true label in the sorted softmax
                true_label_position = (sorted_index == true_label).nonzero(as_tuple=True)[0].item()
                # independent random variable u ~ Uniform(0, 1)
                u = np.random.uniform(0, 1)
                # cumulate sorted softmax
                cumulative_softmax = torch.cumsum(sorted_softmax, dim=0)  # dim=0 -> cumulate by raw direction

                if true_label_position == 0:
                    conformal_score = u * sorted_softmax[true_label_position].item()  # first softmax is true label
                else:
                    conformal_score = cumulative_softmax[true_label_position - 1].item() + u * sorted_softmax[true_label_position].item()
                    
                scores.append(conformal_score)
                labels.append(true_label.item())
    return np.array(scores), np.array(labels)

## 3. Construct APS

In [12]:
def aps_classification(model, dataloader, q_hat):
    aps = []         # probability set
    aps_labels = []  # label set indicated to the probability set
    labels = []      # true label
    with torch.no_grad():
        for images, true_labels in dataloader:
            images, true_labels = images.to(device), true_labels.to(device)
            outputs = model(images)
            softmaxs = torch.softmax(outputs, dim=1)
            for softmax, true_label in zip(softmaxs, true_labels):
                sorted_softmax, sorted_index = torch.sort(softmax, descending=True)
                cumulative_softmax = torch.cumsum(sorted_softmax, dim=0)

                u = torch.rand_like(cumulative_softmax)
                # score of label y = cumulative y-1 + u * probability y
                scores = cumulative_softmax - sorted_softmax + u * sorted_softmax
                
                # cumulate until meet q_hat and then cut off
                cutoff_index = torch.searchsorted(scores, q_hat, right=True)

                # Select all the probabilities and corresponding labels until cut-off index
                prediction_set_prob = sorted_softmax[:cutoff_index].tolist()
                prediction_set_labels = sorted_index[:cutoff_index].tolist()

                aps.append(prediction_set_prob)
                aps_labels.append(prediction_set_labels)
                labels.append(true_label.item())
    return aps, aps_labels, labels

## 4. Evaluate Perdiction Set

In [13]:
def eval_aps(aps_labels,  true_labels):
    total_set_size = 0
    coveraged = 0
    for aps_label, true_label in zip(aps_labels, true_labels):
        # cumulate total set size
        total_set_size += len(aps_label)
        # cumulate the predictions sets if it contains true label
        if true_label in aps_label:
            coveraged += 1

    # calculate average values
    samples_amount = len(true_labels)
    average_set_size = total_set_size / samples_amount
    average_coverage = coveraged / samples_amount
    print(f"Total set size: {total_set_size}")
    print(f"Total coverage sets: {coveraged}")
    print(f"Total samples amount: {samples_amount}")
    return average_set_size, average_coverage

In [14]:
import torch
from torch import nn, optim
from torch.nn import functional as F

class ModelWithTemperature(nn.Module):
    def __init__(self, model, tempreture):
        super(ModelWithTemperature, self).__init__()
        self.model = model
        self.tempreture = tempreture
        self.temperature = nn.Parameter(torch.ones(1) * tempreture)

    def forward(self, input):
        logits = self.model(input)
        return self.temperature_scale(logits)

    def temperature_scale(self, logits):
        """
        Perform temperature scaling on logits
        """
        # Expand temperature to match the size of logits
        temperature = self.temperature.unsqueeze(1).expand(logits.size(0), logits.size(1))
        return logits / temperature

    # This function probably should live outside of this class, but whatever
    def set_temperature(self, valid_loader):
        self.cuda()
        nll_criterion = nn.CrossEntropyLoss().cuda()
        ece_criterion = _ECELoss().cuda()

        # First: collect all the logits and labels for the validation set
        logits_list = []
        labels_list = []
        with torch.no_grad():
            for input, label in valid_loader:
                input = input.cuda()
                logits = self.model(input)
                logits_list.append(logits)
                labels_list.append(label)
            logits = torch.cat(logits_list).cuda()
            labels = torch.cat(labels_list).cuda()

        # Calculate NLL and ECE before temperature scaling
        before_temperature_nll = nll_criterion(logits, labels).item()
        before_temperature_ece = ece_criterion(logits, labels).item()
        print('Before temperature - NLL: %.3f, ECE: %.3f' % (before_temperature_nll, before_temperature_ece))

        # Next: optimize the temperature w.r.t. NLL
        optimizer = optim.LBFGS([self.temperature], lr=0.01, max_iter=50)

        def eval():
            optimizer.zero_grad()
            loss = nll_criterion(self.temperature_scale(logits), labels)
            loss.backward()
            return loss
        optimizer.step(eval)

        # Calculate NLL and ECE after temperature scaling
        after_temperature_nll = nll_criterion(self.temperature_scale(logits), labels).item()
        after_temperature_ece = ece_criterion(self.temperature_scale(logits), labels).item()
        print('Optimal temperature: %.3f' % self.temperature.item())
        print('After temperature - NLL: %.3f, ECE: %.3f' % (after_temperature_nll, after_temperature_ece))

        return self


class _ECELoss(nn.Module):
    def __init__(self, n_bins=15):
        """
        n_bins (int): number of confidence interval bins
        """
        super(_ECELoss, self).__init__()
        bin_boundaries = torch.linspace(0, 1, n_bins + 1)
        self.bin_lowers = bin_boundaries[:-1]
        self.bin_uppers = bin_boundaries[1:]

    def forward(self, logits, labels):
        softmaxes = F.softmax(logits, dim=1)
        confidences, predictions = torch.max(softmaxes, 1)
        accuracies = predictions.eq(labels)

        ece = torch.zeros(1, device=logits.device)
        for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
            # Calculated |confidence - accuracy| in each bin
            in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
            prop_in_bin = in_bin.float().mean()
            if prop_in_bin.item() > 0:
                accuracy_in_bin = accuracies[in_bin].float().mean()
                avg_confidence_in_bin = confidences[in_bin].mean()
                ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

        return ece

## 5. Construct APS repeatedly

In [17]:
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms               # include image preprocess tools
from torchvision.datasets import CIFAR10        # for loading images from Pytorch CIFAR
from torch.utils.data import DataLoader
import detectors
import timm

# check GPU status
print("Is CUDA available:", torch.cuda.is_available())
print("Device count:", torch.cuda.device_count())
print("Device name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No GPU")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# load fine-tuned model
model = timm.create_model("resnet50_cifar10", pretrained=True)
model = model.to(device)

# reprocess the images from CIFAR
data_transform = transforms.Compose([
    transforms.ToTensor(),          # transfer to tensor
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])  # normalize
])
# load images from CIFAR10
dataset = CIFAR10(root="./data", train=False, download=True, transform=data_transform)

# Tempreture Scaling
temp_scal_loader = DataLoader(dataset, batch_size=32, shuffle=True)
model = ModelWithTemperature(model, tempreture=5.0).to(device)
model.set_temperature(temp_scal_loader)

model.eval()

# The number of times the experiment is going to be repeated
num_runs = 10

# error rate
alpha = 0.1

# contruct and evaluate repeatedly
all_avg_set_sizes = []
all_avg_coverages = []
print("APS Classification, Start!\n")
for i in range(num_runs):
    print(f"Running experiment {i+1}/{num_runs}...")

    # splite dataset
    calib_dataset, test_dataset = split_data_set(dataset, random_seed=i)

    # load data set respectively
    calib_loader = DataLoader(calib_dataset, batch_size=32, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    # calculate q_hat
    calib_scores, _ = conformal_scores(model, calib_loader, alpha)
    q_hat = np.quantile(calib_scores, 1 - 0.1)  # calculate 1-alpha quantile
    print(f"q_hat = {q_hat}")

    # construct APS
    aps, aps_labels, true_labels = aps_classification(model, test_loader, q_hat)

    # evaluate APS
    avg_set_size, avg_coverage = eval_aps(aps_labels, true_labels)
    print(f"Average Prediction Set Size After APS in runs {i+1}: {avg_set_size}")
    print(f"Average Coverage Rate in runs {i+1}: {avg_coverage}\n")

    # record current result
    all_avg_set_sizes.append(avg_set_size)
    all_avg_coverages.append(avg_coverage)

# calculate the final average result
final_avg_set_size = np.mean(all_avg_set_sizes)
final_avg_coverage = np.mean(all_avg_coverages)

print(f"Final Average Prediction Set Size: {final_avg_set_size}")
print(f"Final Average Coverage: {final_avg_coverage}")

Is CUDA available: True
Device count: 1
Device name: NVIDIA GeForce RTX 3060 Ti
Files already downloaded and verified
Before temperature - NLL: 0.352, ECE: 0.046
Optimal temperature: 4.908
After temperature - NLL: 0.824, ECE: 0.422
APS Classification, Start!

Running experiment 1/10...
Samples amount: 10000
q_hat = 0.5341354754454465
Total set size: 6974
Total coverage sets: 4533
Total samples amount: 5000
Average Prediction Set Size After APS in runs 1: 1.3948
Average Coverage Rate in runs 1: 0.9066

Running experiment 2/10...
Samples amount: 10000
q_hat = 0.5286978107940244
Total set size: 6916
Total coverage sets: 4519
Total samples amount: 5000
Average Prediction Set Size After APS in runs 2: 1.3832
Average Coverage Rate in runs 2: 0.9038

Running experiment 3/10...
Samples amount: 10000
q_hat = 0.5377961832988696
Total set size: 7112
Total coverage sets: 4526
Total samples amount: 5000
Average Prediction Set Size After APS in runs 3: 1.4224
Average Coverage Rate in runs 3: 0.9052


# Result

- Final Average **Prediction Set Size： 1.38 / 10**
- Final Average **Coverage: 90.31% ($\alpha$=0.1)**