Replace the two paths below with the correct locations of the image directory and the metadata CSV file.

In [None]:
# Directory that holds the image files, and the metadata CSV describing them.
# Both are absolute paths; edit them to match your environment.
image_path = "/content/extracted_data/last/images-224-subset"
metadata_csv = r"/content/extracted_data/last/Order 8003922 Fina subset_metadata_30000.csv"


In [None]:
import pandas as pd
# Read the metadata table and preview its last rows.
df = pd.read_csv(metadata_csv)
df.tail()

Unnamed: 0,Image Index,Finding Labels,Follow-up #,Patient ID,Patient Age,Patient Sex,View Position,OriginalImage[Width,Height],OriginalImagePixelSpacing[x,y],Primary_Label
29995,00022772_000.png,Mass,0,22772,42,M,AP,3056,2544,0.139,0.139,Mass
29996,00009918_005.png,Infiltration,5,9918,61,M,PA,2992,2991,0.143,0.143,Infiltration
29997,00009429_003.png,Effusion|Infiltration,3,9429,45,F,PA,2048,2500,0.168,0.168,Effusion
29998,00014308_000.png,Mass,0,14308,31,F,PA,2590,2991,0.143,0.143,Mass
29999,00019150_009.png,Infiltration,22,19150,69,M,AP,3056,2544,0.139,0.139,Infiltration


In [None]:
# Keep only the columns used in this preview and drop rows with missing values.
df = df[['Image Index', 'Finding Labels', 'Patient Age', 'Patient Sex']].dropna()
# Exclude 'No Finding' studies. The explicit .copy() makes the result an independent
# frame, so adding a column to it in a later cell does not raise SettingWithCopyWarning.
df = df[df['Finding Labels'] != 'No Finding'].copy()
df.head()

Unnamed: 0,Image Index,Finding Labels,Patient Age,Patient Sex
0,00009305_001.png,Mass,54,M
1,00020703_027.png,Effusion|Mass|Nodule,61,M
2,00012834_008.png,Atelectasis|Consolidation|Effusion|Infiltration,33,M
3,00016064_004.png,Emphysema,55,M
4,00011702_007.png,Infiltration,24,F


In [None]:
# When a study lists several '|'-separated findings, the first one becomes its primary class.
df['Primary_Label'] = df['Finding Labels'].str.split('|', n=1).str[0]
df.head()

Unnamed: 0,Image Index,Finding Labels,Patient Age,Patient Sex,Primary_Label
0,00009305_001.png,Mass,54,M,Mass
1,00020703_027.png,Effusion|Mass|Nodule,61,M,Effusion
2,00012834_008.png,Atelectasis|Consolidation|Effusion|Infiltration,33,M,Atelectasis
3,00016064_004.png,Emphysema,55,M,Emphysema
4,00011702_007.png,Infiltration,24,F,Infiltration


In [None]:
from torchcp.classification.score import THR

In [None]:
from PIL import Image
from torchvision import transforms
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, random_split, Subset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet152, ResNet152_Weights
from torchcp.classification.score import THR,APS, RAPS
from torchcp.classification.predictor import SplitPredictor
import os
import pandas as pd
# !pip install torchcp==1.0.0



In [None]:
# Same two paths as at the top of the notebook, repeated so this section can be run on its own:
# the image directory and the metadata CSV.
image_path = "/content/extracted_data/last/images-224-subset"
metadata_csv = r"/content/extracted_data/last/Order 8003922 Fina subset_metadata_30000.csv"

In [None]:
# Reload the full metadata table (all columns, including Primary_Label) and preview its last rows.
df = pd.read_csv(metadata_csv)
df.tail()

Unnamed: 0,Image Index,Finding Labels,Follow-up #,Patient ID,Patient Age,Patient Sex,View Position,OriginalImage[Width,Height],OriginalImagePixelSpacing[x,y],Primary_Label
29995,00022772_000.png,Mass,0,22772,42,M,AP,3056,2544,0.139,0.139,Mass
29996,00009918_005.png,Infiltration,5,9918,61,M,PA,2992,2991,0.143,0.143,Infiltration
29997,00009429_003.png,Effusion|Infiltration,3,9429,45,F,PA,2048,2500,0.168,0.168,Effusion
29998,00014308_000.png,Mass,0,14308,31,F,PA,2590,2991,0.143,0.143,Mass
29999,00019150_009.png,Infiltration,22,19150,69,M,AP,3056,2544,0.139,0.139,Infiltration


In [None]:
class ChestXrayDataset(Dataset):
    """Chest X-ray dataset backed by a metadata CSV and a directory of images.

    Every item is an ``(image, label, age, sex)`` tuple, where ``label`` is the
    integer index assigned to the row's ``Primary_Label``.
    """

    def __init__(self, csv_file, img_dir, transform=None):
        """Read the metadata CSV and give every primary label an integer index.

        Args:
            csv_file: Path of a CSV with 'Image Index', 'Primary_Label',
                'Patient Age' and 'Patient Sex' columns.
            img_dir: Directory containing the files named in 'Image Index'.
            transform: Optional callable applied to each loaded image.
        """
        self.img_dir = img_dir
        self.transform = transform
        self.df = pd.read_csv(csv_file)
        # Indices follow the order in which labels first appear in the CSV.
        distinct_labels = self.df['Primary_Label'].unique()
        self.label_map = dict(zip(distinct_labels, range(len(distinct_labels))))
        self.df['LabelIndex'] = self.df['Primary_Label'].map(self.label_map)

    def __len__(self):
        """Return the number of rows in the metadata table."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load the image of row ``idx`` as RGB and return it with its label index, age and sex."""
        record = self.df.iloc[idx]
        image = Image.open(os.path.join(self.img_dir, record['Image Index'])).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, record['LabelIndex'], record['Patient Age'], record['Patient Sex']

In [None]:
# Resize every image to 224x224 and convert it to a tensor.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# I use 70% for training, 15% for calibration and the remaining 15% for testing
dataset = ChestXrayDataset(metadata_csv, image_path, transform=transform)
train_size = int(0.7 * len(dataset))
calib_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - calib_size

# Seed the split so that a restarted kernel reproduces the same partition.
# This matters because a previously trained checkpoint is loaded further down:
# with an unseeded split, calibration/test images of a new session can overlap
# the images that checkpoint was trained on.
# NOTE(review): a checkpoint trained before this seed was introduced used a
# different (unknown) partition - retrain or confirm before trusting coverage numbers.
split_generator = torch.Generator().manual_seed(42)
train_set, calib_set, test_set = random_split(
    dataset, [train_size, calib_size, test_size], generator=split_generator
)

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [None]:
#Initializes the model
# ResNet-152 with torchvision's default pretrained weights; the final fully connected
# layer is replaced by a new 14-output linear layer before the model is moved to `device`.
model = resnet152(weights=ResNet152_Weights.DEFAULT)
# NOTE(review): 14 presumably equals the number of distinct Primary_Label values
# (len(dataset.label_map)) - confirm, since CrossEntropyLoss needs every label index < 14.
model.fc = nn.Linear(model.fc.in_features, 14)
model = model.to(device)

# Cross-entropy loss on the raw logits, optimised with Adam over all model parameters.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

Downloading: "https://download.pytorch.org/models/resnet152-f82ba261.pth" to /root/.cache/torch/hub/checkpoints/resnet152-f82ba261.pth
100%|██████████| 230M/230M [00:01<00:00, 179MB/s]


In [None]:
# Shuffled mini-batches drawn from the training split.
train_loader = DataLoader(train_set, shuffle=True, batch_size=32, num_workers=2)

# Fine-tune for 8 epochs and print the mean batch loss after each one.
for epoch in range(8):
    model.train()
    running_loss = 0.0

    for batch in train_loader:
        # A batch is (images, labels, ages, sexes); training only needs the first two.
        images = batch[0].to(device)
        labels = batch[1].to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    avg_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch + 1}, Loss: {avg_loss:.4f}")

In [None]:
# Load the trained model. This is already available in the submitted files.
# The checkpoint is mapped onto the `device` chosen earlier, so it also loads on a
# CPU-only machine (a hard-coded "cuda" target raises there).
# NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
model.load_state_dict(torch.load("/content/extracted_data/last/trained_model.pth", map_location=device))
model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [None]:
# Re-select the device (same rule as earlier in the notebook) and move the model onto it.
# As the last expression, model.to(device) also displays the model's repr.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [None]:
# Calibration and test loaders (no shuffle argument is passed). They are defined in
# their own cell as well, so they can be recreated without re-running the whole
# evaluation cell that follows.
calib_loader = DataLoader(calib_set, batch_size=32, num_workers=2)
test_loader = DataLoader(test_set, batch_size=32, num_workers=2)


In [None]:
# Loaders
calib_loader = DataLoader(calib_set, batch_size=32, num_workers=2)
test_loader = DataLoader(test_set, batch_size=32, num_workers=2)

all_preds = []
all_labels = []
confident_correct = 0
confident_total = 0
confidence_threshold = 0.9
softmax = nn.Softmax(dim=1)  # converts model outputs to probabilities

# Evaluates the model and counts how many predictions reach the confidence threshold
model.eval()
with torch.no_grad():
    for images, labels, *_ in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        logits = model(images)
        probs = softmax(logits)

        max_probs, preds = probs.max(dim=1)
        all_preds.extend(preds.cpu().tolist())
        all_labels.extend(labels.cpu().tolist())

        # Predictions whose top softmax probability reaches the threshold
        confident_mask = max_probs >= confidence_threshold
        confident_total += confident_mask.sum().item()
        confident_correct += ((preds == labels) & confident_mask).sum().item()

# Same quantity as confident_total; the name is kept because it was defined by this cell before.
covered_count = confident_total

# Baseline comparison for the traditional approach, before applying CP
softmax_preds = torch.tensor(all_preds)
softmax_labels = torch.tensor(all_labels)
softmax_acc = (softmax_preds == softmax_labels).float().mean().item()
n_test = len(test_loader.dataset)
# "Coverage" here is the fraction of test samples predicted with confidence >= threshold.
coverage = covered_count / n_test if n_test > 0 else 0.0
covered_accuracy = confident_correct / confident_total if confident_total > 0 else 0.0

print(f"Softmax Accuracy: {softmax_acc:.3f}")
print(f"Softmax Coverage (confidence ≥ {confidence_threshold}): {coverage:.3f}")
# Previously computed but never shown: accuracy restricted to the confident predictions.
print(f"Softmax Accuracy among confident predictions: {covered_accuracy:.3f}")

# APS
aps_score_func = APS()
aps_predictor = SplitPredictor(score_function=aps_score_func, model=model)
aps_predictor.calibrate(calib_loader, alpha=0.1)
aps_results = aps_predictor.evaluate(test_loader)
print(f"APS Coverage: {aps_results['Coverage_rate']:.3f}, Avg Set Size: {aps_results['Average_size']:.2f}")

# RAPS
# The regularisation settings (k_reg = 0, lambda = 0.05) are constructor arguments of the
# score function, as in the other RAPS cells of this notebook. They used to be passed to
# calibrate(), where a caught TypeError silently discarded them.
raps_score_func = RAPS(penalty=0.05, kreg=0)
raps_predictor = SplitPredictor(score_function=raps_score_func, model=model)
raps_predictor.calibrate(calib_loader, alpha=0.1)
raps_results = raps_predictor.evaluate(test_loader)
print(f"RAPS Coverage: {raps_results['Coverage_rate']:.3f}, Avg Set Size: {raps_results['Average_size']:.2f}")

# THR evaluation is done in a later cell.

Softmax Accuracy: 0.754
Softmax Coverage (confidence ≥ 0.9): 0.667
APS Coverage: 0.896, Avg Set Size: 2.46
RAPS Coverage: 0.887, Avg Set Size: 2.28


In [None]:
# Function to extract logits (No need to re-run this as logits already provided)
def extract_logits(model, dataloader):
    """Run ``model`` over ``dataloader`` and collect raw logits and labels.

    The model is switched to evaluation mode for the pass (so dropout and
    batch-norm behave deterministically) and its previous train/eval mode is
    restored afterwards. Inputs are moved to the device of the model's
    parameters; gradients are not tracked.

    Args:
        model: A ``torch.nn.Module`` that maps an image batch to logits.
        dataloader: Iterable of batches whose first two elements are
            ``(images, labels)``; any further elements are ignored.

    Returns:
        Tuple ``(logits, labels)``: the per-batch logits (moved to the CPU)
        and the per-batch labels, each concatenated along dimension 0.
    """
    # Use the device the model actually lives on; a parameter-free model runs on the CPU.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    all_logits = []
    all_labels = []

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for images, labels, *_ in dataloader:
                logits = model(images.to(run_device))  # raw logits (before softmax)
                all_logits.append(logits.cpu())
                all_labels.append(labels)
    finally:
        # Leave the model in the mode the caller had it in.
        model.train(was_training)

    return torch.cat(all_logits), torch.cat(all_labels)

# Extracting logits from calibration and test sets
# Each call returns a (logits, labels) pair for the whole loader.
cal_logits_53, cal_labels_53 = extract_logits(model, calib_loader)

test_logits_53, test_labels_53 = extract_logits(model, test_loader)


In [None]:
# Saving the logits
# The calibration/test logits and labels are bundled into one dict and written to a single file.
torch.save({
    'cal_logits_53': cal_logits_53,
    'cal_labels_53': cal_labels_53,
    'test_logits_53': test_logits_53,
    'test_labels_53': test_labels_53
}, '/content/drive/MyDrive/chest_xray_logits_53.pt') #if re-running the code, save at desired location.

# Print the tensor shapes as a quick check of what was written.
print(f"Saved logits - Cal: {cal_logits_53.shape}, Test: {test_logits_53.shape}")

In [None]:
# Load the saved logits (already available in the submitted files) and unpack
# the four tensors stored in the file.
data = torch.load('/content/drive/MyDrive/chest_xray_logits_53.pt')
cal_logits_53, cal_labels_53, test_logits_53, test_labels_53 = (
    data[key]
    for key in ('cal_logits_53', 'cal_labels_53', 'test_logits_53', 'test_labels_53')
)

In [None]:
# THR using the saved logits
# No model is passed to the predictor: it is calibrated and queried directly with logits.
thr_score = THR(score_type="softmax")
thr_predictor = SplitPredictor(score_function=thr_score)

# Calibrate the threshold on the calibration logits for alpha = 0.1.
thr_predictor.calculate_threshold(cal_logits_53, cal_labels_53, alpha=0.1)

# One row per test sample, one column per class (indexed that way below).
thr_prediction_sets = thr_predictor.predict_with_logits(test_logits_53)

# Calculate metrics
# coverage: mean of the set entry at each sample's true label;
# avg_size: mean of the per-sample row sums.
coverage = thr_prediction_sets[range(len(test_labels_53)), test_labels_53].float().mean().item()
avg_size = thr_prediction_sets.sum(dim=1).float().mean().item()

print(f"Naive Softmax Coverage: {coverage:.3f}, Avg Set Size: {avg_size:.2f}")

Naive Softmax Coverage: 0.896, Avg Set Size: 2.32


In [None]:
# Alpha levels to sweep; the predictor is recalibrated for each one.
alphas = [0.05, 0.1, 0.15, 0.2]

print("=== APS Results Across Different Alpha Levels ===")
for alpha in alphas:
    # Fresh randomized-APS predictor, calibrated directly on the saved calibration logits.
    aps_score = APS(score_type="softmax", randomized=True)
    aps_predictor = SplitPredictor(score_function=aps_score)
    aps_predictor.calculate_threshold(cal_logits_53, cal_labels_53, alpha=alpha)

    aps_prediction_sets = aps_predictor.predict_with_logits(test_logits_53)

    # Coverage: mean of the set entry at each true label. Size: mean row sum.
    true_label_in_set = aps_prediction_sets[range(len(test_labels_53)), test_labels_53]
    coverage = true_label_in_set.float().mean().item()
    set_sizes = aps_prediction_sets.sum(dim=1)
    avg_size = set_sizes.float().mean().item()

    print(f"Alpha: {alpha:.2f} | Coverage: {coverage:.3f} | Avg Set Size: {avg_size:.2f}")

=== APS Results Across Different Alpha Levels ===
Alpha: 0.05 | Coverage: 0.944 | Avg Set Size: 4.19
Alpha: 0.10 | Coverage: 0.896 | Avg Set Size: 2.48
Alpha: 0.15 | Coverage: 0.850 | Avg Set Size: 1.78
Alpha: 0.20 | Coverage: 0.803 | Avg Set Size: 1.45


In [None]:
def calculate_aps_pvalues_with_saved_logits(cal_logits, cal_labels, test_logits):
    """
    Calculate conformal p-values for non-randomized APS using only saved logits.

    For every test sample and every candidate class the p-value is

        (1 + #{calibration scores >= candidate score}) / (n_cal + 1)

    which is the standard conformal p-value. The "+1" and the non-strict
    comparison are what make it valid: counting only strictly larger
    calibration scores without the correction gives p-values that are too
    small (they can be exactly 0).

    Args:
        cal_logits: Calibration logits, one row per sample.
        cal_labels: True class index of every calibration sample.
        test_logits: Test logits, one row per sample.

    Returns:
        Float tensor with one row per test sample and one column per class.
    """
    # Create APS score function
    aps_score = APS(score_type="softmax", randomized=False)

    # Convert logits to probabilities
    cal_probs = F.softmax(cal_logits, dim=1)
    test_probs = F.softmax(test_logits, dim=1)

    # Calibration scores of the true labels, sorted once so counts become binary searches.
    # NOTE: _calculate_single_label is a private torchcp method; it is called on
    # probabilities here and may change between torchcp versions.
    cal_scores = aps_score._calculate_single_label(cal_probs, cal_labels)
    sorted_cal_scores, _ = torch.sort(cal_scores.flatten())
    n_cal = sorted_cal_scores.numel()

    n_test, n_classes = test_probs.shape
    pvalues = torch.empty((n_test, n_classes), dtype=torch.float32)
    if n_test == 0:
        return pvalues

    # One batched score computation per class instead of one per (sample, class) pair.
    for class_idx in range(n_classes):
        # Score every test sample as if `class_idx` were its true class.
        candidate_labels = torch.full(
            (n_test,), class_idx, dtype=torch.long, device=test_probs.device
        )
        candidate_scores = aps_score._calculate_single_label(test_probs, candidate_labels)

        # searchsorted (left) gives the number of calibration scores strictly below the
        # candidate score, so the remaining ones are >= the candidate score.
        n_below = torch.searchsorted(
            sorted_cal_scores, candidate_scores.flatten().contiguous(), right=False
        )
        n_at_least = n_cal - n_below
        pvalues[:, class_idx] = ((n_at_least + 1).float() / (n_cal + 1)).cpu()

    return pvalues

# Use with your saved logits
# `pvalues` holds one row per test sample and one entry per class in each row.
pvalues = calculate_aps_pvalues_with_saved_logits(cal_logits_53, cal_labels_53, test_logits_53)

In [None]:
def display_ranked_predictions(test_logits, test_labels, pvalues, alpha=0.1, sample_idx=0, class_names=None):
    """Print the classes of one test sample ranked by p-value.

    Args:
        test_logits: Unused; kept so existing calls keep working.
        test_labels: Tensor of true class indices, one per test sample.
        pvalues: Per-sample, per-class p-values (indexed ``pvalues[sample_idx][class_idx]``).
        alpha: A class is marked as included when its p-value is >= ``alpha``.
        sample_idx: Index of the test sample to display.
        class_names: Class name for every class index, in the index order used
            when the labels were encoded. When omitted, a built-in list of the
            14 finding names is used.
            NOTE(review): the built-in order is NOT the order produced by
            ``ChestXrayDataset.label_map`` (which follows first appearance in
            the CSV), so pass the dataset's own order to get correct names.
    """
    if class_names is None:
        class_names = [
            'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration', 'Mass',
            'Nodule', 'Pneumonia', 'Pneumothorax', 'Consolidation', 'Edema',
            'Emphysema', 'Fibrosis', 'Pleural_Thickening', 'Hernia'
        ]

    def name_of(class_idx):
        """Return the name for ``class_idx``, or a placeholder when no name is available."""
        if 0 <= class_idx < len(class_names):
            return class_names[class_idx]
        return f"class_{class_idx}"

    sample_pvalues = pvalues[sample_idx]
    true_label = test_labels[sample_idx].item()

    ranked_results = [
        (class_idx, name_of(class_idx), pval, "✓" if pval >= alpha else "✗")
        for class_idx, pval in enumerate(sample_pvalues.tolist())
    ]

    # Sort by p-value (descending)
    ranked_results.sort(key=lambda x: x[2], reverse=True)

    print(f"\n=== Sample {sample_idx} (True label: {name_of(true_label)}) ===")
    print("Rank | Class               | P-value | Included")
    print("-" * 55)
    for rank, (class_idx, class_name, pval, included) in enumerate(ranked_results, 1):
        print(f"{rank:4d} | {class_name:19s} | {pval:7.3f} | {included:8s}")

# Class names ordered by the integer index ChestXrayDataset assigned to them, so that the
# printed names match the label encoding the model was trained with.
index_to_class = [name for name, _ in sorted(dataset.label_map.items(), key=lambda item: item[1])]

# Show at most the first 20 test samples (fewer if the test set is smaller).
for i in range(min(20, len(test_labels_53))):
    display_ranked_predictions(test_logits_53, test_labels_53, pvalues, alpha=0.1, sample_idx=i,
                               class_names=index_to_class)


=== Sample 0 (True label: Mass) ===
Rank | Class               | P-value | Included
-------------------------------------------------------
   1 | Mass                |   0.300 | ✓       
   2 | Pneumothorax        |   0.075 | ✗       
   3 | Edema               |   0.045 | ✗       
   4 | Fibrosis            |   0.033 | ✗       
   5 | Cardiomegaly        |   0.028 | ✗       
   6 | Pleural_Thickening  |   0.022 | ✗       
   7 | Effusion            |   0.017 | ✗       
   8 | Consolidation       |   0.014 | ✗       
   9 | Atelectasis         |   0.009 | ✗       
  10 | Pneumonia           |   0.006 | ✗       
  11 | Infiltration        |   0.003 | ✗       
  12 | Nodule              |   0.000 | ✗       
  13 | Emphysema           |   0.000 | ✗       
  14 | Hernia              |   0.000 | ✗       

=== Sample 1 (True label: Cardiomegaly) ===
Rank | Class               | P-value | Included
-------------------------------------------------------
   1 | Effusion            |   0.955 |

In [None]:
print("\n=== RAPS Results Across Different Alpha Levels ===")
for alpha in alphas:
    # Randomized RAPS with fixed regularisation (penalty=2, kreg=14), recalibrated
    # on the saved calibration logits for every alpha.
    raps_score = RAPS(score_type="softmax", randomized=True, penalty=2, kreg=14)
    raps_predictor = SplitPredictor(score_function=raps_score)
    raps_predictor.calculate_threshold(cal_logits_53, cal_labels_53, alpha=alpha)

    # Prediction sets for the saved test logits
    raps_prediction_sets = raps_predictor.predict_with_logits(test_logits_53)

    # Coverage: mean of the set entry at each true label. Size: mean row sum.
    true_label_in_set = raps_prediction_sets[range(len(test_labels_53)), test_labels_53]
    coverage = true_label_in_set.float().mean().item()
    set_sizes = raps_prediction_sets.sum(dim=1)
    avg_size = set_sizes.float().mean().item()

    print(f"Alpha: {alpha:.2f} | Coverage: {coverage:.3f} | Avg Set Size: {avg_size:.2f}")


=== RAPS Results Across Different Alpha Levels ===
Alpha: 0.05 | Coverage: 0.934 | Avg Set Size: 3.69
Alpha: 0.10 | Coverage: 0.882 | Avg Set Size: 2.20
Alpha: 0.15 | Coverage: 0.845 | Avg Set Size: 1.69
Alpha: 0.20 | Coverage: 0.793 | Avg Set Size: 1.41


In [None]:
# Grid over the RAPS hyperparameters k_reg and lambda (penalty) at a fixed alpha.

kreg_values = [0, 1, 2, 3, 4, 5, 6, 7]
lambda_values = [0.1,0.2,1.0, 1.5, 2.0,5]
alpha = 0.1

for lambda_val in lambda_values:
    # One header per penalty value, followed by one result line per k_reg.
    print(f"\nLambda = {lambda_val}")
    print("-" * 30)

    for kreg in kreg_values:
        raps_score = RAPS(score_type="softmax", randomized=True, penalty=lambda_val, kreg=kreg)
        raps_predictor = SplitPredictor(score_function=raps_score)
        raps_predictor.calculate_threshold(cal_logits_53, cal_labels_53, alpha=alpha)

        raps_prediction_sets = raps_predictor.predict_with_logits(test_logits_53)

        # Coverage: mean of the set entry at each true label. Size: mean row sum.
        true_label_in_set = raps_prediction_sets[range(len(test_labels_53)), test_labels_53]
        coverage = true_label_in_set.float().mean().item()
        set_sizes = raps_prediction_sets.sum(dim=1)
        avg_size = set_sizes.float().mean().item()

        print(f"k_reg: {kreg:2d} | Coverage: {coverage:.3f} | Avg Set Size: {avg_size:.2f}")


Lambda = 0.1
------------------------------
k_reg:  0 | Coverage: 0.895 | Avg Set Size: 3.04
k_reg:  1 | Coverage: 0.895 | Avg Set Size: 3.04
k_reg:  2 | Coverage: 0.895 | Avg Set Size: 3.04
k_reg:  3 | Coverage: 0.896 | Avg Set Size: 3.04
k_reg:  4 | Coverage: 0.890 | Avg Set Size: 2.34
k_reg:  5 | Coverage: 0.890 | Avg Set Size: 2.30
k_reg:  6 | Coverage: 0.887 | Avg Set Size: 2.25
k_reg:  7 | Coverage: 0.886 | Avg Set Size: 2.26

Lambda = 0.2
------------------------------
k_reg:  0 | Coverage: 0.895 | Avg Set Size: 3.03
k_reg:  1 | Coverage: 0.895 | Avg Set Size: 3.03
k_reg:  2 | Coverage: 0.896 | Avg Set Size: 3.04
k_reg:  3 | Coverage: 0.895 | Avg Set Size: 3.04
k_reg:  4 | Coverage: 0.892 | Avg Set Size: 2.38
k_reg:  5 | Coverage: 0.888 | Avg Set Size: 2.32
k_reg:  6 | Coverage: 0.884 | Avg Set Size: 2.25
k_reg:  7 | Coverage: 0.886 | Avg Set Size: 2.29

Lambda = 1.0
------------------------------
k_reg:  0 | Coverage: 0.895 | Avg Set Size: 3.04
k_reg:  1 | Coverage: 0.895 | Av

In [None]:
kreg_values = [0, 1,2,3,4,5,6,7]
lambda_fixed = 1

for kreg in kreg_values:
    # Create RAPS predictor with different k_reg
    raps_score = RAPS(score_type="softmax", randomized=True, penalty=lambda_fixed, kreg=kreg)
    raps_predictor = SplitPredictor(score_function=raps_score)

    # Calculate threshold and predict
    raps_predictor.calculate_threshold(cal_logits_53, cal_labels_53, alpha=alpha)
    raps_prediction_sets = raps_predictor.predict_with_logits(test_logits_53)

    # Calculate metrics
    coverage = raps_prediction_sets[range(len(test_labels_53)), test_labels_53].float().mean().item()
    avg_size = raps_prediction_sets.sum(dim=1).float().mean().item()

    print(f"k_reg: {kreg:4d} | Coverage: {coverage:.3f} | Avg Set Size: {avg_size:.2f}")

k_reg:    0 | Coverage: 0.895 | Avg Set Size: 3.04
k_reg:    1 | Coverage: 0.895 | Avg Set Size: 3.04
k_reg:    2 | Coverage: 0.894 | Avg Set Size: 3.03
k_reg:    3 | Coverage: 0.896 | Avg Set Size: 3.04
k_reg:    4 | Coverage: 0.890 | Avg Set Size: 2.36
k_reg:    5 | Coverage: 0.889 | Avg Set Size: 2.30
k_reg:    6 | Coverage: 0.886 | Avg Set Size: 2.27
k_reg:    7 | Coverage: 0.890 | Avg Set Size: 2.33


In [None]:
from google.colab import files
# save the fine-tuned model, calibration and test labels and logits
def save_logits_and_labels(loader, filename_prefix):
    """Run the global ``model`` over ``loader`` and save its logits and the labels.

    The model is put into evaluation mode, every batch is evaluated on the
    global ``device`` without gradient tracking, and the concatenated logits
    and labels are written to ``"<filename_prefix>_data.pt"`` as a dict with
    the keys ``'logits'`` and ``'labels'``.
    """
    logit_batches = []
    label_batches = []

    model.eval()
    with torch.no_grad():
        for batch in loader:
            # A batch is (images, labels, ...); anything after the labels is ignored.
            images, labels = batch[0], batch[1]
            logit_batches.append(model(images.to(device)).cpu())
            label_batches.append(labels)

    payload = {'logits': torch.cat(logit_batches), 'labels': torch.cat(label_batches)}
    torch.save(payload, f"{filename_prefix}_data.pt")

# save both sets
# Writes "calibration_data.pt" and "test_data.pt" to the current working directory.
save_logits_and_labels(calib_loader, "calibration")
save_logits_and_labels(test_loader, "test")

# save the fine-tuned model
# Only the state dict (the weights) is stored, not the model object.
torch.save(model.state_dict(), "trained_model.pth")

# downloading the files
# files.download comes from google.colab, so this part only works inside Colab.
files.download("calibration_data.pt")
files.download("test_data.pt")
files.download("trained_model.pth")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

#download the fine tuned model

In [None]:
# Reload the weights written to "trained_model.pth" into the existing `model`;
# the checkpoint tensors are mapped to the CPU while loading.
model.load_state_dict(torch.load("trained_model.pth", map_location="cpu"))  # or 'cuda' if using GPU
model.eval()  # Set to evaluation mode


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [None]:
# Select the device again and move the reloaded model onto it.
# As the last expression, model.to(device) also displays the model's repr.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [None]:

# Define the groups
# Each entry names an attribute ("sex" or "age_group") and the value selecting one subgroup;
# the values are the strings returned by get_group_label.
groups = [
    {"attr": "sex", "value": "Male"},
    {"attr": "sex", "value": "Female"},
    {"attr": "age_group", "value": "Young"},
    {"attr": "age_group", "value": "Middle"},
    {"attr": "age_group", "value": "Old"},
]

# extract group label
def get_group_label(age, sex):
    """Map a patient's age and sex code to the demographic groups used for evaluation.

    Ages below 40 are 'Young', ages from 40 to 60 inclusive are 'Middle' and
    anything else is 'Old'. A sex code of 'M' gives 'Male'; every other value
    gives 'Female'.

    Returns:
        Dict with the keys "sex" and "age_group".
    """
    if age < 40:
        age_group = 'Young'
    elif age <= 60:
        age_group = 'Middle'
    else:
        age_group = 'Old'

    if sex == 'M':
        sex_group = 'Male'
    else:
        sex_group = 'Female'

    return {"sex": sex_group, "age_group": age_group}

# calibrate each group separately
def run_group_evaluation(model, calib_set, test_set,thr_cls, device):
    """Calibrate and evaluate a THR conformal predictor separately for every group.

    For each entry of the global ``groups`` list, the calibration and test
    samples belonging to that group are selected, a ``SplitPredictor`` built
    from ``thr_cls()`` is calibrated on the group's calibration samples with
    alpha = 0.1, and its coverage and average set size on the group's test
    samples are printed. Groups without calibration or test samples are skipped.

    Args:
        model: Trained classifier handed to ``SplitPredictor``.
        calib_set: Calibration dataset yielding ``(image, label, age, sex)`` items.
        test_set: Test dataset yielding ``(image, label, age, sex)`` items.
        thr_cls: Score-function class, instantiated without arguments.
        device: Unused; kept so existing calls keep working.
    """
    def group_labels(subset):
        """Return the get_group_label dict of every item of ``subset``, in item order."""
        frame = getattr(getattr(subset, "dataset", None), "df", None)
        positions = getattr(subset, "indices", None)
        has_metadata = (
            frame is not None
            and positions is not None
            and {'Patient Age', 'Patient Sex'} <= set(frame.columns)
        )
        if has_metadata:
            # Fast path for a Subset of ChestXrayDataset: read age/sex straight from the
            # metadata table instead of loading and transforming every image.
            rows = frame.iloc[list(positions)]
            pairs = zip(rows['Patient Age'], rows['Patient Sex'])
        else:
            # Generic fallback: iterate the items themselves.
            pairs = ((age, sex) for _, _, age, sex in subset)
        return [get_group_label(age, sex) for age, sex in pairs]

    # Computed once and reused for every group.
    calib_labels = group_labels(calib_set)
    test_labels = group_labels(test_set)

    for group in groups:
        attr, val = group["attr"], group["value"]

        # Positions (within calib_set / test_set) of the samples in this group
        calib_indices = [i for i, label in enumerate(calib_labels) if label[attr] == val]
        test_indices = [i for i, label in enumerate(test_labels) if label[attr] == val]

        # Nothing to calibrate on or nothing to evaluate
        if not calib_indices or not test_indices:
            print(f"Skipping group {attr} = {val} (no data)")
            continue

        # group-specific loaders
        group_calib_loader = DataLoader(Subset(calib_set, calib_indices), batch_size=32)
        group_test_loader = DataLoader(Subset(test_set, test_indices), batch_size=32)

        # THR Evaluation (Naive Softmax Threshold)
        thr = thr_cls()
        thr_predictor = SplitPredictor(score_function=thr, model=model)
        thr_predictor.calibrate(group_calib_loader, alpha=0.1)
        thr_results = thr_predictor.evaluate(group_test_loader)
        thr_cov = thr_results.get("Coverage_rate", 0)
        thr_sz = thr_results.get("Average_size", 0)
        print(f"THR: {attr} = {val:<12} | Coverage: {thr_cov:.3f}, Set Size: {thr_sz:.2f}")

# Per-group THR evaluation on the calibration/test splits; prints one line per group.
run_group_evaluation(model, calib_set, test_set,THR, device)

THR: sex = Male         | Coverage: 0.891, Set Size: 2.38
THR: sex = Female       | Coverage: 0.901, Set Size: 2.47
THR: age_group = Young        | Coverage: 0.884, Set Size: 2.24
THR: age_group = Middle       | Coverage: 0.899, Set Size: 2.38
THR: age_group = Old          | Coverage: 0.896, Set Size: 2.66


In [None]:

# Define the groups
# Same list as in the THR cell, redefined so this cell can be run on its own.
# Each entry names an attribute ("sex" or "age_group") and the value selecting one subgroup.
groups = [
    {"attr": "sex", "value": "Male"},
    {"attr": "sex", "value": "Female"},
    {"attr": "age_group", "value": "Young"},
    {"attr": "age_group", "value": "Middle"},
    {"attr": "age_group", "value": "Old"},
]

# extract group label
def get_group_label(age, sex):
    """Return the sex group and the age group of a patient.

    Age groups: 'Young' for ages below 40, 'Middle' for ages up to and
    including 60, 'Old' otherwise. Sex groups: 'Male' for the code 'M',
    'Female' for any other value.

    Returns:
        Dict with the keys "sex" and "age_group".
    """
    if age < 40:
        age_group = 'Young'
    elif age <= 60:
        age_group = 'Middle'
    else:
        age_group = 'Old'

    if sex == 'M':
        sex_group = 'Male'
    else:
        sex_group = 'Female'

    return {"sex": sex_group, "age_group": age_group}

# calibrate each group separately
def run_group_evaluation(model, calib_set, test_set, aps_cls, raps_cls, device):
    """Calibrate and evaluate APS and RAPS conformal predictors separately for every group.

    NOTE: this definition replaces the THR-only ``run_group_evaluation`` defined
    in the earlier cell (same name, different parameters).

    For each entry of the global ``groups`` list, the calibration and test
    samples of that group are selected; an APS and a RAPS ``SplitPredictor``
    are calibrated on the group's calibration samples with alpha = 0.1, and
    their coverage and average set size on the group's test samples are
    printed. Groups without calibration or test samples are skipped.

    Args:
        model: Trained classifier handed to ``SplitPredictor``.
        calib_set: Calibration dataset yielding ``(image, label, age, sex)`` items.
        test_set: Test dataset yielding ``(image, label, age, sex)`` items.
        aps_cls: APS score-function class, instantiated without arguments.
        raps_cls: RAPS score-function class; it is instantiated with
            ``penalty=0.05, kreg=0`` and therefore has to accept those keywords.
        device: Unused; kept so existing calls keep working.
    """
    def group_labels(subset):
        """Return the get_group_label dict of every item of ``subset``, in item order."""
        frame = getattr(getattr(subset, "dataset", None), "df", None)
        positions = getattr(subset, "indices", None)
        has_metadata = (
            frame is not None
            and positions is not None
            and {'Patient Age', 'Patient Sex'} <= set(frame.columns)
        )
        if has_metadata:
            # Fast path for a Subset of ChestXrayDataset: read age/sex straight from the
            # metadata table instead of loading and transforming every image.
            rows = frame.iloc[list(positions)]
            pairs = zip(rows['Patient Age'], rows['Patient Sex'])
        else:
            # Generic fallback: iterate the items themselves.
            pairs = ((age, sex) for _, _, age, sex in subset)
        return [get_group_label(age, sex) for age, sex in pairs]

    # Computed once and reused for every group.
    calib_labels = group_labels(calib_set)
    test_labels = group_labels(test_set)

    for group in groups:
        attr, val = group["attr"], group["value"]

        # Positions (within calib_set / test_set) of the samples in this group
        calib_indices = [i for i, label in enumerate(calib_labels) if label[attr] == val]
        test_indices = [i for i, label in enumerate(test_labels) if label[attr] == val]

        # Nothing to calibrate on or nothing to evaluate
        if not calib_indices or not test_indices:
            print(f"Skipping group {attr} = {val} (no data)")
            continue

        # group-specific loaders
        group_calib_loader = DataLoader(Subset(calib_set, calib_indices), batch_size=32)
        group_test_loader = DataLoader(Subset(test_set, test_indices), batch_size=32)

        # APS Evaluation
        aps = aps_cls()
        aps_predictor = SplitPredictor(score_function=aps, model=model)
        aps_predictor.calibrate(group_calib_loader, alpha=0.1)
        aps_results = aps_predictor.evaluate(group_test_loader)
        aps_cov = aps_results.get("Coverage_rate", 0)
        aps_sz = aps_results.get("Average_size", 0)
        print(f"APS: {attr} = {val:<6} | Coverage: {aps_cov:.3f}, Set Size: {aps_sz:.2f}")

        # RAPS Evaluation
        # The regularisation settings (k_reg = 0, lambda = 0.05) belong to the score
        # function's constructor. They used to be passed to calibrate(), where a caught
        # TypeError silently discarded them.
        raps = raps_cls(penalty=0.05, kreg=0)
        raps_predictor = SplitPredictor(score_function=raps, model=model)
        raps_predictor.calibrate(group_calib_loader, alpha=0.1)
        raps_results = raps_predictor.evaluate(group_test_loader)
        raps_cov = raps_results.get("Coverage_rate", 0)
        raps_sz = raps_results.get("Average_size", 0)
        print(f"RAPS: {attr} = {val:<6} | Coverage: {raps_cov:.3f}, Set Size: {raps_sz:.2f}")
        print("-" * 60) # we print a dashed line

# Per-group APS and RAPS evaluation on the calibration/test splits; prints two lines per group.
run_group_evaluation(model, calib_set, test_set, APS, RAPS, device)


APS: sex = Male   | Coverage: 0.897, Set Size: 2.32
RAPS: sex = Male   | Coverage: 0.889, Set Size: 2.17
------------------------------------------------------------
APS: sex = Female | Coverage: 0.904, Set Size: 2.41
RAPS: sex = Female | Coverage: 0.892, Set Size: 2.24
------------------------------------------------------------
APS: age_group = Young  | Coverage: 0.900, Set Size: 2.24
RAPS: age_group = Young  | Coverage: 0.881, Set Size: 2.00
------------------------------------------------------------
APS: age_group = Middle | Coverage: 0.899, Set Size: 2.36
RAPS: age_group = Middle | Coverage: 0.894, Set Size: 2.22
------------------------------------------------------------
APS: age_group = Old    | Coverage: 0.905, Set Size: 2.78
RAPS: age_group = Old    | Coverage: 0.899, Set Size: 2.58
------------------------------------------------------------
