# Notebook 1: Setup and Data Loading

## Purpose
This notebook handles all setup, data loading, and model initialization. It prepares the data and extracts logits that will be used in all subsequent notebooks.

## What This Notebook Does
1. Import all necessary libraries
2. Set random seeds for reproducibility
3. Load CIFAR-10 dataset
4. Load pre-trained ResNet56 model
5. Split the CIFAR-10 test set into validation and test halves, then extract logits and labels for each
6. Save preprocessed data for use in other notebooks

## Output
- `./data/processed/logits_val.npy` - Validation set logits
- `./data/processed/labels_val.npy` - Validation set labels
- `./data/processed/logits_test.npy` - Test set logits
- `./data/processed/labels_test.npy` - Test set labels


In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
import os

# Seed the NumPy and PyTorch random number generators (and every CUDA device,
# when CUDA is present) with the same fixed value.
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

# Choose the compute backend. MPS is probed first, then CUDA; CPU is the fallback.
if torch.backends.mps.is_available():
    backend_name, backend_note = 'mps', 'Using Apple Metal (MPS)'
elif torch.cuda.is_available():
    backend_name, backend_note = 'cuda', 'Using CUDA'
else:
    backend_name, backend_note = 'cpu', 'Using CPU'

device = torch.device(backend_name)
print(backend_note)

# Record the environment next to the results.
for environment_line in (
    f'Device: {device}',
    f'PyTorch version: {torch.__version__}',
    f'NumPy version: {np.__version__}',
):
    print(environment_line)


Using Apple Metal (MPS)
Device: mps
PyTorch version: 2.9.1
NumPy version: 2.2.6


In [2]:
section_rule = '=' * 60
print(section_rule)
print('STEP 1: LOAD PRE-TRAINED RESNET56 MODEL')
print(section_rule)

# Fetch the model through torch.hub, move it to the selected device and switch
# it to evaluation mode. Any failure is reported and then re-raised so the
# notebook stops here instead of continuing without a model.
try:
    resnet_model = torch.hub.load(
        "chenyaofo/pytorch-cifar-models",
        "cifar10_resnet56",
        pretrained=True,
    )
    resnet_model = resnet_model.to(device)
    resnet_model.eval()

    parameter_total = 0
    for weight in resnet_model.parameters():
        parameter_total += weight.numel()
    first_weight = next(resnet_model.parameters())

    print('✓ ResNet56 loaded successfully')
    print(f'  Parameters: {parameter_total:,}')
    print(f'  Model on device: {first_weight.device}')
except Exception as e:
    print(f'✗ Failed to load model: {e}')
    raise


STEP 1: LOAD PRE-TRAINED RESNET56 MODEL
✓ ResNet56 loaded successfully
  Parameters: 855,770
  Model on device: mps:0


Using cache found in /Users/Studies/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master


In [None]:
section_rule = '=' * 60
print(section_rule)
print('STEP 2: LOAD CIFAR-10 DATASET')
print(section_rule)

# Per-channel mean and standard deviation handed to Normalize.
# NOTE(review): presumably the CIFAR-10 statistics the pre-trained model was
# trained with - confirm against the model's repository.
channel_mean = (0.4914, 0.4822, 0.4465)
channel_std = (0.2023, 0.1994, 0.2010)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
])

# Only the held-out split (train=False) is loaded; it is downloaded to ./data
# when missing. shuffle=False keeps the dataset order fixed.
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
testloader = DataLoader(
    testset,
    batch_size=100,
    shuffle=False,
    num_workers=0,
)

print('✓ CIFAR-10 test set loaded')
print(f'  Total test samples: {len(testset)}')
print(f'  Number of classes: {len(testset.classes)}')

STEP 2: LOAD CIFAR-10 DATASET
✓ CIFAR-10 test set loaded
  Total test samples: 10000
  Number of classes: 10


In [None]:
section_rule = '=' * 60
print(section_rule)
print('STEP 3: SPLIT DATA INTO VALIDATION AND TEST SETS')
print(section_rule)

# Walk the loader once and keep every (images, labels) batch in memory.
with torch.no_grad():
    loaded_batches = [(images, labels) for images, labels in testloader]

# Stitch the batches back together into one image tensor and one label tensor.
all_test_data = torch.cat([images for images, _ in loaded_batches], dim=0)
all_test_labels = torch.cat([labels for _, labels in loaded_batches], dim=0)

# Shuffle sample positions with a fixed seed, then cut the permutation in half:
# the first half becomes the validation set, the remainder the test set.
sample_count = len(all_test_data)
indices = np.arange(sample_count)
np.random.seed(42)
np.random.shuffle(indices)

split_idx = sample_count // 2
val_indices, test_indices = indices[:split_idx], indices[split_idx:]

val_data, val_labels = all_test_data[val_indices], all_test_labels[val_indices]
test_data, test_labels = all_test_data[test_indices], all_test_labels[test_indices]

print('✓ Data split completed')
print(f'  Validation set: {len(val_data)} samples')
print(f'  Test set: {len(test_data)} samples')
print(f'  Total: {len(val_data) + len(test_data)} samples')

STEP 3: SPLIT DATA INTO VALIDATION AND TEST SETS
✓ Data split completed
  Validation set: 5000 samples
  Test set: 5000 samples
  Total: 10000 samples


In [None]:
print('='*60)
print('STEP 4: EXTRACT LOGITS AND LABELS')
print('='*60)


def _extract_logits(model, loader, device):
    """Run ``model`` over every batch of ``loader`` and collect the results.

    Args:
        model: ``torch.nn.Module`` mapping a batch of images to a logits tensor.
        loader: Iterable yielding ``(images, labels)`` tensor batches.
        device: Device the images are moved to before the forward pass.

    Returns:
        Tuple ``(logits, labels)`` of NumPy arrays, each the per-batch results
        concatenated along axis 0 in loader order.
    """
    # Re-assert evaluation mode so re-running this cell gives the same logits
    # even if the model was switched to training mode in between.
    model.eval()

    logits_batches = []
    labels_batches = []
    # Inference only: no autograd graph is needed for the forward passes.
    with torch.no_grad():
        for images, labels in loader:
            logits = model(images.to(device))
            # Bring results back to host memory before converting to NumPy.
            logits_batches.append(logits.cpu().numpy())
            labels_batches.append(labels.numpy())

    return (
        np.concatenate(logits_batches, axis=0),
        np.concatenate(labels_batches, axis=0),
    )


# shuffle=False keeps logits row i aligned with label i of the source tensors.
val_loader = DataLoader(torch.utils.data.TensorDataset(val_data, val_labels), batch_size=100, shuffle=False)
test_loader = DataLoader(torch.utils.data.TensorDataset(test_data, test_labels), batch_size=100, shuffle=False)

print('Extracting logits from validation set...')
logits_val, labels_val = _extract_logits(resnet_model, val_loader, device)

print('Extracting logits from test set...')
logits_test, labels_test = _extract_logits(resnet_model, test_loader, device)

# Every sample must have exactly one logits row and one label.
assert logits_val.shape[0] == labels_val.shape[0] == len(val_data)
assert logits_test.shape[0] == labels_test.shape[0] == len(test_data)

print('\n✓ Logits extracted successfully')
print(f'  Validation logits shape: {logits_val.shape}')
print(f'  Validation labels shape: {labels_val.shape}')
print(f'  Test logits shape: {logits_test.shape}')
print(f'  Test labels shape: {labels_test.shape}')

STEP 4: EXTRACT LOGITS AND LABELS
Extracting logits from validation set...
Extracting logits from test set...

✓ Logits extracted successfully
  Validation logits shape: (5000, 10)
  Validation labels shape: (5000,)
  Test logits shape: (5000, 10)
  Test labels shape: (5000,)


In [None]:
section_rule = '=' * 60
print(section_rule)
print('STEP 5: SAVE PREPROCESSED DATA')
print(section_rule)

# Create the output directory if it is not there yet.
os.makedirs('./data/processed', exist_ok=True)

# Destination path -> array, written in this order.
arrays_to_save = (
    ('./data/processed/logits_val.npy', logits_val),
    ('./data/processed/labels_val.npy', labels_val),
    ('./data/processed/logits_test.npy', logits_test),
    ('./data/processed/labels_test.npy', labels_test),
)
for target_path, array in arrays_to_save:
    np.save(target_path, array)

# Confirmation of what was written, followed by hints for the next notebook.
closing_messages = (
    '✓ Data saved successfully',
    '  Files saved to ./data/processed/',
    '  - logits_val.npy',
    '  - labels_val.npy',
    '  - logits_test.npy',
    '  - labels_test.npy',
    '\n============================================================',
    'SETUP COMPLETE!',
    '============================================================',
    '\nNext steps:',
    '  1. Run Notebook 2: Baseline Calibration Methods',
    '  2. Or load saved data in other notebooks using:',
    '     logits_val = np.load("./data/processed/logits_val.npy")',
    '     labels_val = np.load("./data/processed/labels_val.npy")',
)
for message in closing_messages:
    print(message)

STEP 5: SAVE PREPROCESSED DATA
✓ Data saved successfully
  Files saved to ./data/processed/
  - logits_val.npy
  - labels_val.npy
  - logits_test.npy
  - labels_test.npy

SETUP COMPLETE!

Next steps:
  1. Run Notebook 2: Baseline Calibration Methods
  2. Or load saved data in other notebooks using:
     logits_val = np.load("./data/processed/logits_val.npy")
     labels_val = np.load("./data/processed/labels_val.npy")
