# Assignment 1, Task 1: CNN vs. ViT Inductive Biases
**By:** Muhammad Abdullah Sohail, Muhammad Haseeb, and Salaar Masood.

**Models:** ResNet-50 vs. ViT-S/16  
**Dataset:** CIFAR-10

This notebook contains the complete code for investigating and comparing the inductive biases of a Convolutional Neural Network (ResNet-50) and a Vision Transformer (ViT-S/16). The experiments are structured to test semantic, architectural, and generalization biases as outlined in the assignment brief.

### Setup and Imports

In [None]:
# Install required libraries, especially 'timm' for the ViT-S/16 model
!pip install -q timm scikit-learn seaborn umap-learn

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, Dataset
import timm # PyTorch Image Models library

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import time
import copy
import os
from PIL import Image
from tqdm.notebook import tqdm

from sklearn.manifold import TSNE
import umap

# --- Configuration ---
# Seed the torch and NumPy RNGs for reproducibility. torch's seed drives weight
# init of the new classifier layers, DataLoader shuffling and the random test-time
# perturbations used later. GPU kernels (e.g. cuDNN) may still be nondeterministic,
# and Python's built-in `random` module is not seeded here.
torch.manual_seed(42)
np.random.seed(42)

# Global device used by the helper functions below (GPU is highly recommended:
# both models are fine-tuned on 224x224 inputs).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Matplotlib style ('seaborn-v0_8-*' are matplotlib's names for the old seaborn styles).
plt.style.use('seaborn-v0_8-whitegrid')
# IPython magic: render figures inline in the notebook (not valid in a plain .py script).
%matplotlib inline

### Helper Functions (Training & Evaluation)

In [None]:
# This cell contains reusable helper functions for training and evaluating models.

def _run_phase(model, dataloader, criterion, optimizer, train, desc):
    """Run one full pass over ``dataloader`` and return ``(mean_loss, accuracy)`` as floats.

    When ``train`` is True the model is put in training mode and the optimizer is
    stepped after every batch; otherwise the model is in eval mode and no gradients
    are tracked.
    """
    model.train(train)  # train(False) is equivalent to eval()
    running_loss = 0.0
    running_corrects = 0

    for inputs, labels in tqdm(dataloader, desc=desc):
        inputs, labels = inputs.to(device), labels.to(device)
        if train:
            optimizer.zero_grad()

        with torch.set_grad_enabled(train):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if train:
                loss.backward()
                optimizer.step()

        preds = outputs.argmax(dim=1)
        # Assumes a mean-reduced criterion (as built in train_model): scale back to a per-sample sum.
        running_loss += loss.item() * inputs.size(0)
        running_corrects += (preds == labels).sum().item()

    num_samples = len(dataloader.dataset)
    return running_loss / num_samples, running_corrects / num_samples


def train_model(model, train_loader, val_loader, num_epochs=10, lr=0.001, model_save_path='model.pth'):
    """Fine-tune ``model`` with SGD and keep the weights with the best validation accuracy.

    Every parameter with ``requires_grad=True`` is optimized; for models built by
    ``get_model`` that is the whole network, not only the new classifier.

    Args:
        model: Model already placed on ``device``.
        train_loader: Batches used for the weight updates.
        val_loader: Batches used only to select the best epoch.
        num_epochs: Number of passes over ``train_loader``.
        lr: SGD learning rate (momentum is fixed at 0.9).
        model_save_path: File the best ``state_dict`` is written to whenever
            validation accuracy improves.

    Returns:
        ``(model, history)``: the model with its best weights loaded, and a dict of
        per-epoch float lists under ``'train_loss'``, ``'train_acc'``, ``'val_loss'``
        and ``'val_acc'``.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD((p for p in model.parameters() if p.requires_grad), lr=lr, momentum=0.9)

    best_model_wts = copy.deepcopy(model.state_dict())
    # None means "nothing saved yet": the first validation pass always writes a checkpoint,
    # so later cells that load `model_save_path` never find the file missing.
    best_acc = None

    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        for phase, dataloader in (('train', train_loader), ('val', val_loader)):
            epoch_loss, epoch_acc = _run_phase(
                model, dataloader, criterion, optimizer,
                train=(phase == 'train'),
                desc=f"{phase.capitalize()} Epoch {epoch+1}",
            )

            history[f'{phase}_loss'].append(epoch_loss)
            history[f'{phase}_acc'].append(epoch_acc)

            print(f'{phase.capitalize()} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            if phase == 'val' and (best_acc is None or epoch_acc > best_acc):
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save(model.state_dict(), model_save_path)
                print(f"New best model saved to {model_save_path} with accuracy: {best_acc:.4f}")

    if best_acc is not None:
        print(f'Best val Acc: {best_acc:.4f}')
    model.load_state_dict(best_model_wts)
    return model, history

def evaluate_model(model, dataloader, name="Test"):
    """Compute, print and return the top-1 accuracy of ``model`` on ``dataloader``.

    The model is switched to eval mode and run without gradient tracking.

    Returns:
        Accuracy as a Python float (correct predictions / dataset size).
    """
    model.eval()
    n_correct = 0

    with torch.no_grad():
        for batch_x, batch_y in tqdm(dataloader, desc=f"Evaluating on {name}"):
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            predicted = model(batch_x).max(dim=1).indices
            n_correct += (predicted == batch_y).sum()

    accuracy = n_correct.double() / len(dataloader.dataset)
    print(f'{name} Accuracy: {accuracy:.4f}')
    return accuracy.item()

## Step 1: Model Fine-tuning on CIFAR-10

First, we fine-tune ImageNet-pre-trained ResNet-50 and ViT-S/16 models on the CIFAR-10 training data, which establishes our baseline in-distribution performance. We replace the final classifier of each model with a new 10-way layer for the CIFAR-10 classes, fine-tune all weights for 10 epochs, and keep the checkpoint with the best validation accuracy.

In [None]:
# --- Data Preparation ---
# Both backbones come with ImageNet-style pre-training, so images are resized to
# 224x224 and normalized with the standard ImageNet channel statistics.
# NOTE(review): timm's vit_small_patch16_224 pretrained config may specify its own
# mean/std (see model.pretrained_cfg); fine-tuning largely absorbs the difference.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Train split first, then test split (downloaded on first use).
train_dataset, test_dataset = (
    datasets.CIFAR10(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

loader_kwargs = dict(batch_size=64, num_workers=2)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

# --- Model Definition ---
def get_model(model_name, num_classes=10, pretrained=True):
    """Build a backbone with a freshly initialized ``num_classes``-way classifier.

    Args:
        model_name: ``'resnet50'`` (torchvision) or ``'vit_s_16'`` (timm
            ``vit_small_patch16_224``).
        num_classes: Output size of the new classification layer.
        pretrained: Load pretrained backbone weights (torchvision's
            ``ResNet50_Weights.DEFAULT`` / timm's default pretrained tag).

    Returns:
        The model moved to the global ``device``.

    Raises:
        ValueError: If ``model_name`` is not supported.
    """
    if model_name == 'vit_s_16':
        # timm builds the new classifier itself when num_classes is given.
        net = timm.create_model('vit_small_patch16_224', pretrained=pretrained, num_classes=num_classes)
    elif model_name == 'resnet50':
        net = models.resnet50(weights=models.ResNet50_Weights.DEFAULT if pretrained else None)
        # Full fine-tuning: make sure no backbone parameter is frozen.
        net.requires_grad_(True)
        net.fc = nn.Linear(net.fc.in_features, num_classes)
    else:
        raise ValueError("Model not supported")

    return net.to(device)

# Build both backbones with fresh 10-way heads (ResNet-50 first, then ViT-S/16).
resnet50, vit_s16 = get_model('resnet50'), get_model('vit_s_16')

print("ResNet-50 and ViT-S/16 models created.")

In [None]:
# NOTE: Training will take a significant amount of time. 
# Run this once and the best models will be saved.
# If you have already trained models, you can skip this cell and load them in the next one.

# Select the best epoch on a held-out slice of the *training* set. Selecting it on the
# test set would leak test information into model selection and inflate the clean-test
# baseline that every robustness drop below is measured against.
from torch.utils.data import random_split

val_size = 5000
train_subset, val_subset = random_split(
    train_dataset,
    [len(train_dataset) - val_size, val_size],
    generator=torch.Generator().manual_seed(42),  # identical split on every run
)
fit_loader = DataLoader(train_subset, batch_size=64, shuffle=True, num_workers=2)
val_loader = DataLoader(val_subset, batch_size=64, shuffle=False, num_workers=2)

print("--- Training ResNet-50 ---")
resnet50_trained, resnet_history = train_model(resnet50, fit_loader, val_loader, num_epochs=10, model_save_path='resnet50_cifar10.pth')

print("\n--- Training ViT-S/16 ---")
vit_s16_trained, vit_history = train_model(vit_s16, fit_loader, val_loader, num_epochs=10, model_save_path='vit_s16_cifar10.pth')

## Step 2 & 3: In-Distribution Performance and Color Bias Test

Now, we load our fine-tuned models and evaluate their baseline accuracy on the clean CIFAR-10 test set. Then, we test them on a grayscale version of the test set to measure their reliance on color cues. A significant drop in accuracy indicates a strong color bias.

In [None]:
# --- Load Fine-tuned Models ---
# If you skipped the training cell, run this to load your saved models.
# map_location lets checkpoints written on a GPU machine load on a CPU-only one (and vice versa).
resnet50 = get_model('resnet50')
resnet50.load_state_dict(torch.load('resnet50_cifar10.pth', map_location=device))

vit_s16 = get_model('vit_s_16')
vit_s16.load_state_dict(torch.load('vit_s16_cifar10.pth', map_location=device))

# --- Create Grayscale Test Set ---
# Same preprocessing as the clean test set, but colour is removed. The luminance channel
# is replicated 3x so the input still has the 3 channels the models expect.
grayscale_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=3), # Crucial for model input shape
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
grayscale_test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=grayscale_transform)
grayscale_loader = DataLoader(grayscale_test_dataset, batch_size=64, shuffle=False, num_workers=2)

# --- Evaluate ---
results = {}

print("--- Evaluating ResNet-50 ---")
results['resnet_clean'] = evaluate_model(resnet50, test_loader, "Clean CIFAR-10")
results['resnet_gray'] = evaluate_model(resnet50, grayscale_loader, "Grayscale CIFAR-10")

print("\n--- Evaluating ViT-S/16 ---")
results['vit_clean'] = evaluate_model(vit_s16, test_loader, "Clean CIFAR-10")
results['vit_gray'] = evaluate_model(vit_s16, grayscale_loader, "Grayscale CIFAR-10")

# --- Analyze Color Bias ---
# Absolute drop in accuracy (clean minus grayscale); larger means more reliance on colour.
resnet_drop = results['resnet_clean'] - results['resnet_gray']
vit_drop = results['vit_clean'] - results['vit_gray']

print("\nAccuracy Drop (Color Bias):")
print(f"ResNet-50: {resnet_drop*100:.2f}%")
print(f"ViT-S/16: {vit_drop*100:.2f}%")

## Step 5: Translation Invariance Test

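Here, we test the models' robustness to small spatial shifts of the input image. We create a test set in which every image is shifted horizontally by the same fixed number of pixels. Convolutions share weights across spatial positions, which makes CNN feature maps translation-equivariant; combined with global average pooling, this yields approximate translation invariance. CNNs are therefore expected to be more robust to such shifts than ViTs, which rely on learned absolute position embeddings.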

In [None]:
# --- Create Translated Test Set ---
from functools import partial

shift_pixels = 16 # A noticeable shift (in pixels of the 224x224 model input)
image_size = 224

# Shift EVERY image horizontally by exactly `shift_pixels`. RandomAffine(translate=...)
# would instead sample a random shift in [-shift_pixels, +shift_pixels] per image, so many
# images would barely move and the accuracy would change from run to run. Vacated pixels
# are filled with 0 (torchvision's default fill). `partial` (unlike a lambda) stays
# picklable for DataLoader worker processes.
# NOTE(review): 16 px is exactly one ViT-S/16 patch, so patch contents are unchanged and
# only their positions move; a non-multiple of 16 (e.g. 8 px) would also be informative.
fixed_shift = partial(
    transforms.functional.affine,
    angle=0.0,
    translate=[shift_pixels, 0],
    scale=1.0,
    shear=[0.0, 0.0],
)

translate_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.Lambda(fixed_shift),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
translated_test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=translate_transform)
translated_loader = DataLoader(translated_test_dataset, batch_size=64, shuffle=False, num_workers=2)

# --- Evaluate ---
print("--- Evaluating ResNet-50 on Translated Data ---")
results['resnet_translated'] = evaluate_model(resnet50, translated_loader, f"Translated ({shift_pixels}px)")

print("\n--- Evaluating ViT-S/16 on Translated Data ---")
results['vit_translated'] = evaluate_model(vit_s16, translated_loader, f"Translated ({shift_pixels}px)")

# --- Analyze Translation Invariance ---
resnet_drop_translate = results['resnet_clean'] - results['resnet_translated']
vit_drop_translate = results['vit_clean'] - results['vit_translated']

print("\nAccuracy Drop (Translation):")
print(f"ResNet-50: {resnet_drop_translate*100:.2f}%")
print(f"ViT-S/16: {vit_drop_translate*100:.2f}%")

## Step 6: Permutation / Occlusion Test

This section tests how models react when the global structure of an image is disrupted.
1.  **Patch Shuffling:** We split the image into a grid of non-overlapping 16×16 patches (matching the ViT-S/16 patch grid) and randomly permute their positions.
2.  **Occlusion:** We erase a randomly placed rectangular region covering 10–20% of the image area.

This helps reveal whether the models rely more on local features (which are preserved in shuffling) or global context (which is destroyed).

In [None]:
# --- Custom Transform for Patch Shuffling ---
class PatchShuffler:
    """Randomly permute the non-overlapping square patches of a ``(C, H, W)`` tensor.

    The image is cut into a ``(H // patch_size) x (W // patch_size)`` grid. Whole
    patches (all channels together) are moved to new grid positions while their
    interior pixels are left untouched. If H or W is not a multiple of
    ``patch_size``, the leftover bottom rows / right columns are not shuffled. If
    the patch is larger than the image, an unchanged copy is returned.

    Args:
        patch_size: Side length of each square patch, in pixels.
        generator: Optional ``torch.Generator`` for reproducible permutations;
            ``None`` uses torch's global RNG.
    """

    def __init__(self, patch_size=16, generator=None):
        self.patch_size = patch_size
        self.generator = generator

    def __call__(self, img_tensor):
        c, h, w = img_tensor.shape
        ps = self.patch_size
        rows, cols = h // ps, w // ps

        out = img_tensor.clone()
        if rows == 0 or cols == 0:
            return out  # no complete patch fits; nothing to shuffle

        grid_h, grid_w = rows * ps, cols * ps
        region = img_tensor[:, :grid_h, :grid_w]

        # [C, rows, cols, ps, ps] -> [C, rows*cols, ps, ps], patches in row-major grid order
        patches = region.unfold(1, ps, ps).unfold(2, ps, ps).reshape(c, rows * cols, ps, ps)

        perm = torch.randperm(rows * cols, generator=self.generator)
        shuffled = patches[:, perm].reshape(c, rows, cols, ps, ps)

        # Interleave grid and in-patch axes back into image layout: [C, rows, ps, cols, ps]
        out[:, :grid_h, :grid_w] = shuffled.permute(0, 1, 3, 2, 4).reshape(c, grid_h, grid_w)
        return out

# --- Create Perturbed Datasets ---
patch_shuffle_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    PatchShuffler(patch_size=16), # Custom transform
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

occlusion_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # Built-in transform for occlusion: always erase one rectangle covering 10-20% of the area
    transforms.RandomErasing(p=1.0, scale=(0.1, 0.2), ratio=(0.5, 2.0)), 
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

shuffled_test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=patch_shuffle_transform)
occluded_test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=occlusion_transform)

# num_workers is deliberately left at its default (0): the random perturbations then run
# in this process and draw from the global torch RNG, which is re-seeded below.
shuffled_loader = DataLoader(shuffled_test_dataset, batch_size=64, shuffle=False)
occluded_loader = DataLoader(occluded_test_dataset, batch_size=64, shuffle=False)

PERTURBATION_SEED = 0


def _evaluate_with_fixed_perturbations(model, loader, name):
    """Re-seed torch, then evaluate, so every model is scored on identical perturbed images.

    PatchShuffler (torch.randperm) and RandomErasing draw new randomness each time an
    image is loaded; without re-seeding, ResNet-50 and ViT-S/16 would see different
    shuffles/occlusions and their accuracies would not be directly comparable.
    """
    torch.manual_seed(PERTURBATION_SEED)
    return evaluate_model(model, loader, name)


# --- Evaluate ---
print("--- Evaluating on Patch-Shuffled Data ---")
results['resnet_shuffled'] = _evaluate_with_fixed_perturbations(resnet50, shuffled_loader, "Patch Shuffled")
results['vit_shuffled'] = _evaluate_with_fixed_perturbations(vit_s16, shuffled_loader, "Patch Shuffled")

print("\n--- Evaluating on Occluded Data ---")
results['resnet_occluded'] = _evaluate_with_fixed_perturbations(resnet50, occluded_loader, "Occluded")
results['vit_occluded'] = _evaluate_with_fixed_perturbations(vit_s16, occluded_loader, "Occluded")

## Step 7: Feature Representation Analysis

To understand *how* the models represent data internally, we extract the penultimate layer features for a subset of test images. We then use UMAP to visualize the 2D feature space, coloring points by their class. This can reveal whether one model creates more semantically meaningful and separable clusters.

In [None]:
def get_features(model, dataloader, num_samples=500):
    """Extract penultimate-layer (pre-classifier) features for the first ``num_samples`` items.

    * torchvision ResNet: output of every child module except ``fc`` (i.e. the pooled
      last-stage activations), flattened to ``(B, C)``.
    * timm VisionTransformer: ``forward_head(forward_features(x), pre_logits=True)``,
      i.e. the pooled token representation the classifier head receives (the CLS token
      for the default ``global_pool='token'`` setting).

    Args:
        model: A ResNet or timm VisionTransformer on ``device``.
        dataloader: Source of ``(inputs, labels)`` batches; read in order until
            ``num_samples`` items are collected.
        num_samples: Number of rows to return (fewer if the loader runs out first).

    Returns:
        ``(features, labels)`` as NumPy arrays with the same number of rows.

    Raises:
        TypeError: For unsupported model types.
    """
    model.eval()

    if isinstance(model, models.ResNet):
        backbone = nn.Sequential(*list(model.children())[:-1])  # drop the final fc layer

        def extract(x):
            return torch.flatten(backbone(x), 1)  # (B, C, 1, 1) -> (B, C)
    elif isinstance(model, timm.models.VisionTransformer):
        def extract(x):
            return model.forward_head(model.forward_features(x), pre_logits=True)
    else:
        raise TypeError("Unsupported model type for feature extraction")

    features_list = []
    labels_list = []
    collected = 0
    expected_batches = int(np.ceil(num_samples / dataloader.batch_size))

    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc="Extracting Features", total=expected_batches):
            features_list.append(extract(inputs.to(device)).cpu())
            labels_list.append(labels)
            collected += labels.size(0)
            if collected >= num_samples:
                break

    # The last batch may overshoot; trim so exactly num_samples rows are returned.
    features = torch.cat(features_list)[:num_samples]
    labels = torch.cat(labels_list)[:num_samples]
    return features.numpy(), labels.numpy()

# --- Extract Features ---
resnet_features, resnet_labels = get_features(resnet50, test_loader)
vit_features, vit_labels = get_features(vit_s16, test_loader)

# --- Visualize with UMAP ---
# Each model gets its own independently fitted embedding: compare cluster structure,
# not coordinates, across the two panels.
print("Running UMAP on ResNet-50 features...")
resnet_2d = umap.UMAP(n_components=2, random_state=42).fit_transform(resnet_features)

print("Running UMAP on ViT-S/16 features...")
vit_2d = umap.UMAP(n_components=2, random_state=42).fit_transform(vit_features)

# Plotting
fig, axes = plt.subplots(1, 2, figsize=(16, 7))
class_names = test_dataset.classes

panels = [
    (axes[0], resnet_2d, resnet_labels, 'ResNet-50 Feature Space (UMAP)'),
    (axes[1], vit_2d, vit_labels, 'ViT-S/16 Feature Space (UMAP)'),
]
for ax, embedding, point_labels, title in panels:
    # hue_order pins the class -> colour mapping so both panels use the same colours
    # regardless of which class happens to appear first in the data.
    sns.scatterplot(x=embedding[:, 0], y=embedding[:, 1],
                    hue=[class_names[idx] for idx in point_labels], hue_order=class_names,
                    palette='tab10', s=10, alpha=0.7, ax=ax)
    ax.set_title(title)
    ax.legend(title='Class', markerscale=1, fontsize='small')

plt.tight_layout()
plt.savefig("feature_space_comparison.png")
plt.show()

## Task 1 Summary of Quantitative Results

Finally, we compile all our accuracy measurements into a single table to provide a clear, quantitative comparison of the two models across all tests.

In [None]:
# Create a pandas DataFrame for a clean summary.
# Tests that were not run in this session appear as "n/a" rather than as 0.00% accuracy
# (which would also show a fake, near-100% "drop").
test_rows = [
    ('Clean', 'clean'),
    ('Grayscale', 'gray'),
    ('Translated', 'translated'),
    ('Shuffled', 'shuffled'),
    ('Occluded', 'occluded'),
]
summary_df = pd.DataFrame({
    'Test': [label for label, _ in test_rows],
    'ResNet-50 Acc': [results.get(f'resnet_{key}', np.nan) for _, key in test_rows],
    'ViT-S/16 Acc': [results.get(f'vit_{key}', np.nan) for _, key in test_rows],
})

# Calculate accuracy drop from the clean baseline (row 0); NaN propagates if either is missing.
for model_label in ('ResNet-50', 'ViT-S/16'):
    acc = summary_df[f'{model_label} Acc']
    summary_df[f'{model_label} Drop'] = acc.iloc[0] - acc

# Format as percentage
for col in summary_df.columns[1:]:
    summary_df[col] = summary_df[col].map(lambda x: 'n/a' if pd.isna(x) else f"{x*100:.2f}%")

print("--- Summary of Model Performance on Perturbation Tests ---")
display(summary_df)