In [1]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset

# Simple transforms: convert to tensor and normalize
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Download and load CIFAR-10 train+test splits
train_full = datasets.CIFAR10(root="../data", train=True, download=True, transform=transform)
test_set = datasets.CIFAR10(root="../data", train=False, download=True, transform=transform)

print(f"Train set size (full): {len(train_full)}")
print(f"Test set size: {len(test_set)}")


100.0%


Train set size (full): 50000
Test set size: 10000


In [2]:
from torch.utils.data import random_split

# Set split sizes
train_size = int(0.8 * len(train_full))   # 40,000
cal_size = len(train_full) - train_size   # 10,000

# Use random_split (with manual seed for reproducibility)
train_set, cal_set = random_split(train_full, [train_size, cal_size], generator=torch.Generator().manual_seed(42))

print(f"Train set size: {len(train_set)}")
print(f"Calibration set size: {len(cal_set)}")
print(f"Test set size: {len(test_set)}")


Train set size: 40000
Calibration set size: 10000
Test set size: 10000


In [3]:
from torch.utils.data import DataLoader

BATCH_SIZE = 128  # Large enough for fast epochs, but fits in GPU/RAM

# Create DataLoaders for each split
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
cal_loader   = DataLoader(cal_set,   batch_size=BATCH_SIZE, shuffle=False)
test_loader  = DataLoader(test_set,  batch_size=BATCH_SIZE, shuffle=False)

# Print example batch shapes
images, labels = next(iter(train_loader))
print("Train batch image shape:", images.shape)
print("Train batch label shape:", labels.shape)


Train batch image shape: torch.Size([128, 3, 32, 32])
Train batch label shape: torch.Size([128])


In [4]:
import torch.nn as nn
import torchvision.models as models

# Load ResNet-18 without ImageNet pretraining, for fair calibration
model = models.resnet18(pretrained=False)

# Change final FC layer to match CIFAR-10 (10 classes)
model.fc = nn.Linear(model.fc.in_features, 10)

print(model)  # Shows the summary




ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [5]:
# Put model in evaluation mode for testing only
model.eval()

# Grab a batch from your validation loader
sample_images, sample_labels = next(iter(cal_loader))
sample_logits = model(sample_images)

print("Logits shape:", sample_logits.shape)


Logits shape: torch.Size([128, 10])
