### Assignment 2: Understanding Transfer Learning and Fine-Tuning

# Imports and Setup

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets
from torchvision.models import resnet18, ResNet18_Weights
from sklearn.metrics import accuracy_score, confusion_matrix, precision_score, recall_score
import matplotlib.pyplot as plt
import numpy as np
from torch.cuda.amp import autocast, GradScaler

# Report the installed PyTorch build so results can be tied to a version.
print(f"[INFO] Torch version: {torch.__version__}")

# Reproducibility: fix the global torch RNG (weight init, shuffling, splits).
SEED = 42
torch.manual_seed(SEED)

# Prefer the GPU when one is visible; otherwise fall back to the CPU.
_use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if _use_cuda else torch.device("cpu")
print(f"[INFO] Using device: {device}")

# Model Class Definition

In [None]:
class ResNet18Transfer(nn.Module):
    """ResNet-18 backbone with a configurable MLP head for transfer learning.

    The pretrained network's final ``fc`` layer is dropped and every module
    before it becomes ``self.features``. A stack of ``Linear -> ReLU ->
    Dropout`` blocks (one per entry in ``dense_units``) followed by a final
    ``Linear`` maps the flattened backbone output to
    ``number_of_output_classes`` logits.

    Args:
        number_of_output_classes: Number of output logits.
        freeze_backbone: If True, set ``requires_grad=False`` on every
            parameter of ``self.features``. The head is always trainable.
        dense_units: Hidden-layer widths of the MLP head, in order. An empty
            sequence maps backbone features straight to the output layer, in
            which case ``dropout_probabilities`` is unused.
        dropout_probabilities: Dropout rate applied after each hidden layer.
        weights: torchvision weights passed to ``resnet18``; also the source
            of the preprocessing returned by :meth:`get_transform`.
    """

    # Position of ``layer4`` among the backbone's children
    # (conv1, bn1, relu, maxpool, layer1, layer2, layer3, layer4, avgpool).
    _LAYER4_INDEX = 7

    def __init__(self,
                 number_of_output_classes: int,
                 freeze_backbone: bool = True,
                 dense_units: "list[int] | tuple[int, ...]" = (256,),
                 dropout_probabilities: float = 0.3,
                 weights=ResNet18_Weights.DEFAULT):
        super().__init__()
        self.weights = weights

        # Feature extractor: the pretrained network without its final fc layer.
        backbone = resnet18(weights=weights)
        self.features = nn.Sequential(*list(backbone.children())[:-1])
        self.backbone_out = backbone.fc.in_features  # 512 for ResNet18

        if freeze_backbone:
            for p in self.features.parameters():
                p.requires_grad = False

        # Classifier head. The default is an immutable tuple so that one
        # mutable list is not shared across every instantiation.
        mlp = []
        in_features = self.backbone_out
        for width in dense_units:
            mlp += [
                nn.Linear(in_features, width),
                nn.ReLU(inplace=True),
                nn.Dropout(p=dropout_probabilities),
            ]
            in_features = width

        # Newly created layers have requires_grad=True by default, so the head
        # is trainable regardless of freeze_backbone.
        self.classifier = nn.Sequential(*mlp)
        self.final_classifier = nn.Linear(in_features, number_of_output_classes)

    def forward(self, x):
        """Return class logits for a batch of preprocessed images ``x``."""
        x = self.features(x)
        x = torch.flatten(x, 1)  # drop the trailing spatial dims of the pooled output
        x = self.classifier(x)
        return self.final_classifier(x)

    def unfreeze_layer4(self):
        """Unfreeze the last ResNet block (``layer4``) for fine-tuning."""
        for p in self.features[self._LAYER4_INDEX].parameters():
            p.requires_grad = True

    def get_transform(self):
        """Return the preprocessing pipeline recommended for ``self.weights``.

        Raises:
            ValueError: if the model was built with ``weights=None``; there is
                then no recommended transform to return.
        """
        if self.weights is None:
            raise ValueError(
                "get_transform() requires pretrained weights; "
                "this model was built with weights=None"
            )
        return self.weights.transforms()

# Training and Evaluation Functions

In [None]:
def train_one_epoch(model, loader, criterion, optimizer, device, scaler):
    """Train for one epoch with mixed precision (CUDA only).

    Args:
        model: Network to train; it is switched to ``train()`` mode here.
        loader: Iterable of ``(images, labels)`` batches supporting ``len``.
        criterion: Loss taking ``(outputs, labels)``. It is assumed to return
            the batch *mean*, as ``nn.CrossEntropyLoss`` does by default.
        optimizer: Optimizer over the trainable parameters.
        device: Device the batches are moved to. Autocast is enabled only
            when this is a CUDA device.
        scaler: ``GradScaler`` used to scale the loss before backward.

    Returns:
        ``(avg_loss, accuracy)`` over all samples seen this epoch. The loss is
        weighted by batch size, so a short final batch is not over-counted.
        Both values are ``nan`` if the loader yields no samples.

    NOTE(review): ``model.train()`` also puts BatchNorm layers inside a frozen
    backbone into training mode, so their running statistics keep updating.
    Confirm this is intended for the frozen-backbone phase.
    """
    model.train()
    use_amp = torch.device(device).type == "cuda"
    amp_device = "cuda" if use_amp else "cpu"

    total_loss = 0.0
    n_correct = 0
    n_seen = 0

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()

        with torch.autocast(device_type=amp_device, enabled=use_amp):
            outputs = model(images)
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_size = labels.size(0)
        # criterion returns a per-batch mean; re-weight it into a per-sample sum.
        total_loss += loss.item() * batch_size
        n_correct += (outputs.argmax(dim=1) == labels).sum().item()
        n_seen += batch_size

    if n_seen == 0:
        return float("nan"), float("nan")
    return total_loss / n_seen, n_correct / n_seen


def evaluate(model, loader, criterion, device):
    """Evaluate the model on a dataloader without updating any weights.

    Args:
        model: Network to evaluate; it is switched to ``eval()`` mode here.
        loader: Iterable of ``(images, labels)`` batches supporting ``len``.
        criterion: Loss taking ``(outputs, labels)``. It is assumed to return
            the batch *mean*, as ``nn.CrossEntropyLoss`` does by default.
        device: Device the batches are moved to. Autocast is enabled only
            when this is a CUDA device.

    Returns:
        ``(avg_loss, accuracy)`` over all samples. The loss is weighted by
        batch size, so a short final batch is not over-counted. Both values
        are ``nan`` if the loader yields no samples.
    """
    model.eval()
    use_amp = torch.device(device).type == "cuda"
    amp_device = "cuda" if use_amp else "cpu"

    total_loss = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            with torch.autocast(device_type=amp_device, enabled=use_amp):
                outputs = model(images)
                loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            # criterion returns a per-batch mean; re-weight it into a per-sample sum.
            total_loss += loss.item() * batch_size
            n_correct += (outputs.argmax(dim=1) == labels).sum().item()
            n_seen += batch_size

    if n_seen == 0:
        return float("nan"), float("nan")
    return total_loss / n_seen, n_correct / n_seen


def train_with_early_stopping(model, train_loader, val_loader, criterion, optimizer,
                              device, scaler, epochs, patience=5):
    """Train for up to ``epochs`` epochs, stopping early on a validation-loss plateau.

    After each epoch the validation loss is compared with the best seen so
    far, and a snapshot of the weights is taken whenever it improves. Training
    stops once the loss has failed to improve for ``patience`` consecutive
    epochs. Whether training stops early or runs all epochs, the best snapshot
    is loaded back into ``model`` before returning.

    Returns:
        ``(train_losses, val_losses, train_accs, val_accs)``, with one entry
        per epoch actually run.
    """
    train_losses, val_losses = [], []
    train_accs, val_accs = [], []

    best_val_loss = float('inf')
    best_epoch = None
    patience_counter = 0
    best_model_state = None

    for epoch in range(1, epochs + 1):
        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device, scaler)
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accs.append(train_acc)
        val_accs.append(val_acc)

        # Early stopping check
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = epoch
            patience_counter = 0
            # state_dict() holds references to the live tensors, and a shallow
            # dict.copy() would keep aliasing them. Clone so that later
            # optimizer steps cannot overwrite the snapshot.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"\n[EARLY STOP] No improvement for {patience} epochs. Stopping at epoch {epoch}.")
            break

        if epoch % 5 == 0 or epoch == 1:
            print(f"Epoch {epoch:3d}/{epochs} | "
                  f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
                  f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

    # Restore the best weights. best_model_state stays None only if no
    # epoch ever produced a finite validation loss (e.g. all NaN).
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print(f"[INFO] Restored weights from epoch {best_epoch} (val loss {best_val_loss:.4f}).")

    return train_losses, val_losses, train_accs, val_accs


def predict(model, loader, device):
    """Get predictions for an entire dataset.

    Args:
        model: Network to run; it is switched to ``eval()`` mode here.
        loader: Iterable of ``(images, labels)`` batches whose labels are
            CPU tensors (``labels.numpy()`` is called on them directly).
        device: Device the images are moved to. Autocast is enabled only
            when this is a CUDA device.

    Returns:
        ``(preds, labels)`` as NumPy arrays aligned sample-for-sample.
    """
    model.eval()
    use_amp = torch.device(device).type == "cuda"
    amp_device = "cuda" if use_amp else "cpu"
    all_preds, all_labels = [], []

    with torch.no_grad():
        for images, labels in loader:
            with torch.autocast(device_type=amp_device, enabled=use_amp):
                outputs = model(images.to(device))

            all_preds.extend(outputs.argmax(dim=1).cpu().numpy())
            all_labels.extend(labels.numpy())

    return np.array(all_preds), np.array(all_labels)

# Dataset Configuration and Loading

In [None]:
# CNN configuration for the Oxford-IIIT Pet task. With no hidden layers the
# head is a single 512 -> 37 linear map, so the dropout rate below is unused.
cnn_configuration = dict(
    number_of_output_classes=37,  # OxfordIIITPet has 37 classes
    freeze_backbone=True,
    dense_units=[],
    dropout_probabilities=0.3,
)

# Build the model on the target device and take the preprocessing pipeline
# that matches its pretrained weights.
model = ResNet18Transfer(**cnn_configuration)
model = model.to(device)
transform = model.get_transform()

print(model)

# Split the data

In [None]:
# Load the official trainval / test splits, both preprocessed with the
# transform that matches the pretrained weights.
full_train = datasets.OxfordIIITPet(root="./data", split="trainval", download=True, transform=transform)
test_ds = datasets.OxfordIIITPet(root="./data", split="test", download=True, transform=transform)

# Train/val split: hold out 20% of trainval for validation. A dedicated,
# seeded generator keeps the split fixed no matter how much of the global RNG
# earlier cells consumed (e.g. weight initialisation when building the model).
val_ratio = 0.2
val_size = int(len(full_train) * val_ratio)
train_size = len(full_train) - val_size
split_generator = torch.Generator().manual_seed(42)
train_ds, val_ds = random_split(full_train, [train_size, val_size], generator=split_generator)

print(f"Train size:      {len(train_ds)}")
print(f"Validation size: {len(val_ds)}")
print(f"Test size:       {len(test_ds)}")

# Create DataLoaders from the datasets

In [None]:
BATCH_SIZE = 64

# Settings shared by all three loaders. Pinned host memory only speeds up
# host-to-GPU copies, so enable it only when training on CUDA; on CPU it
# just triggers a warning.
loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=4,
    pin_memory=device.type == "cuda",
    persistent_workers=True,
)

train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_ds, shuffle=False, **loader_kwargs)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches:   {len(val_loader)}")
print(f"Test batches:  {len(test_loader)}")

# Question 1
- Why do we freeze the backbone? Why not the classifier?

## Answer to Question 1

We **freeze the backbone** because it was already trained on ImageNet and has learned rich, general-purpose features (edges, textures, shapes). Freezing it means those weights are not updated during training, which:
1. Saves computation and memory — no gradients are computed or stored for the backbone, and only the small classifier is updated. The forward pass through the backbone still runs.
2. Prevents noisy gradients from our small dataset from destroying the learned features, and reduces the risk of overfitting.

We do **not freeze the classifier** because it is randomly initialised and knows nothing about our dataset. It needs to be trained from scratch to map the backbone's features to our 37 pet-breed classes.

### Phase 1: Transfer learning

# Freeze all layers except the classifier.
- Train and evaluate the model for 50 epochs (the cell below runs up to 30 epochs with early stopping).
- Remember to save the validation loss and the training loss.

In [None]:
# Optimizer and loss. Only parameters that are still trainable (the head,
# since the backbone was frozen) are handed to Adam.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=1e-3, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()
scaler = GradScaler()

# Train with early stopping
print("\n=== Phase 1: Transfer Learning ===")
phase1_history = train_with_early_stopping(
    model, train_loader, val_loader, criterion, optimizer, device, scaler,
    epochs=30, patience=5,
)
phase1_train_losses, phase1_val_losses, phase1_train_accs, phase1_val_accs = phase1_history

# Plot accuracy and loss

In [None]:
epochs_range = range(1, len(phase1_train_losses) + 1)

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# One entry per panel: (axis, title, y-label, [(series, legend label), ...])
phase1_panels = [
    (axes[0], "Phase 1 – Loss", "Loss",
     [(phase1_train_losses, "Train Loss"), (phase1_val_losses, "Val Loss")]),
    (axes[1], "Phase 1 – Accuracy", "Accuracy",
     [(phase1_train_accs, "Train Acc"), (phase1_val_accs, "Val Acc")]),
]

for panel_ax, title, ylabel, curves in phase1_panels:
    for values, label in curves:
        panel_ax.plot(epochs_range, values, label=label)
    panel_ax.set_title(title)
    panel_ax.set_xlabel("Epoch")
    panel_ax.set_ylabel(ylabel)
    panel_ax.legend()

plt.tight_layout()
plt.show()

# Plot predictions

In [None]:
# Show a grid of 8 test images with predicted vs true class names
class_names = test_ds.classes

# Grab one batch from the test loader and classify it. Switch to eval mode
# explicitly rather than relying on the mode left by the last evaluate()
# call, and disable autograd because this pass is for display only.
images_batch, labels_batch = next(iter(test_loader))
model.eval()
with torch.no_grad():
    outputs = model(images_batch.to(device))
preds_batch = outputs.argmax(dim=1).cpu()

# ImageNet normalisation - reverse it for display.
# NOTE(review): assumes `transform` normalised with exactly these statistics.
mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])

fig, axes = plt.subplots(2, 4, figsize=(14, 7))
for i, ax in enumerate(axes.flat):
    ax.axis("off")
    if i >= len(images_batch):
        continue  # batch smaller than the grid: leave this panel blank
    # CHW -> HWC, then undo the per-channel normalisation
    img = images_batch[i].permute(1, 2, 0) * std + mean
    img = img.clamp(0, 1).numpy()
    ax.imshow(img)
    true_name = class_names[labels_batch[i]]
    pred_name = class_names[preds_batch[i]]
    color = "green" if labels_batch[i] == preds_batch[i] else "red"
    ax.set_title(f"True: {true_name}\nPred: {pred_name}", color=color, fontsize=8)

plt.suptitle("Phase 1 Predictions (green = correct, red = wrong)", fontsize=12)
plt.tight_layout()
plt.show()

# Calculate test accuracy

In [None]:
# Score the Phase 1 (frozen-backbone) model on the held-out test split.
phase1_preds, phase1_labels = predict(model, test_loader, device)
phase1_test_acc = accuracy_score(y_true=phase1_labels, y_pred=phase1_preds)
print(f"Phase 1 Test Accuracy: {phase1_test_acc:.4f}")

# Calculate the confusion matrix, precision, and recall

In [None]:
# Confusion matrix over every class index, so it is always square
# (one row/column per breed) even if a class never appears in the
# labels or predictions.
phase1_class_ids = np.arange(len(test_ds.classes))
cm = confusion_matrix(phase1_labels, phase1_preds, labels=phase1_class_ids)

fig, ax = plt.subplots(figsize=(14, 12))
im = ax.imshow(cm, cmap="Blues")
plt.colorbar(im, ax=ax)
ax.set_title("Phase 1 – Confusion Matrix")
ax.set_xlabel("Predicted")
ax.set_ylabel("True")
# Label rows/columns with breed names instead of bare indices.
ax.set_xticks(phase1_class_ids)
ax.set_xticklabels(test_ds.classes, rotation=90, fontsize=7)
ax.set_yticks(phase1_class_ids)
ax.set_yticklabels(test_ds.classes, fontsize=7)
plt.tight_layout()
plt.show()

# Macro-averaged precision and recall: the unweighted mean of the per-class scores.
precision_macro = precision_score(phase1_labels, phase1_preds, average="macro", zero_division=0)
recall_macro = recall_score(phase1_labels, phase1_preds, average="macro", zero_division=0)

print(f"Phase 1 Macro Precision: {precision_macro:.4f}")
print(f"Phase 1 Macro Recall:    {recall_macro:.4f}")

### Phase 2: Unfreeze layer4 - Fine-tuning

# Starting from the frozen CNN, unfreeze `layer4`.

In [None]:
# Make the last ResNet block (layer4) trainable again.
model.unfreeze_layer4()

# Fine-tuning uses per-group learning rates. The pretrained layer4 gets a
# tiny step size so its features shift only slightly; the head keeps a larger one.
layer4_lr = 1e-5
head_lr = 5e-4
phase2_param_groups = [
    {"params": model.features[7].parameters(), "lr": layer4_lr},       # layer4
    {"params": model.classifier.parameters(), "lr": head_lr},          # MLP head
    {"params": model.final_classifier.parameters(), "lr": head_lr},    # final layer
]
optimizer = optim.Adam(phase2_param_groups, weight_decay=1e-4)

# Start the new phase with a fresh gradient scaler.
scaler = GradScaler()

print("layer4 and classifier are now trainable.")

# Train for 50 epochs (the cell below runs up to 30 epochs with early stopping)

In [None]:
# Fine-tune with early stopping; patience is a little longer than in Phase 1.
print("\n=== Phase 2: Fine-Tuning ===")
phase2_history = train_with_early_stopping(
    model, train_loader, val_loader, criterion, optimizer, device, scaler,
    epochs=30, patience=7,
)
(phase2_train_losses, phase2_val_losses,
 phase2_train_accs, phase2_val_accs) = phase2_history

# Plot curves

In [None]:
epochs_range2 = range(1, len(phase2_train_losses) + 1)

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# One entry per panel: (axis, title, y-label, [(series, legend label), ...])
phase2_panels = [
    (axes[0], "Phase 2 – Loss", "Loss",
     [(phase2_train_losses, "Train Loss"), (phase2_val_losses, "Val Loss")]),
    (axes[1], "Phase 2 – Accuracy", "Accuracy",
     [(phase2_train_accs, "Train Acc"), (phase2_val_accs, "Val Acc")]),
]

for panel_ax, title, ylabel, curves in phase2_panels:
    for values, label in curves:
        panel_ax.plot(epochs_range2, values, label=label)
    panel_ax.set_title(title)
    panel_ax.set_xlabel("Epoch")
    panel_ax.set_ylabel(ylabel)
    panel_ax.legend()

plt.tight_layout()
plt.show()

# Visualize predictions

In [None]:
# Reuse the same batch for comparison. Switch to eval mode explicitly and
# disable autograd because this pass is for display only.
model.eval()
with torch.no_grad():
    outputs = model(images_batch.to(device))
preds_batch = outputs.argmax(dim=1).cpu()

fig, axes = plt.subplots(2, 4, figsize=(14, 7))
for i, ax in enumerate(axes.flat):
    ax.axis("off")
    if i >= len(images_batch):
        continue  # batch smaller than the grid: leave this panel blank
    # CHW -> HWC, then undo the per-channel normalisation
    img = images_batch[i].permute(1, 2, 0) * std + mean
    img = img.clamp(0, 1).numpy()
    ax.imshow(img)
    true_name = class_names[labels_batch[i]]
    pred_name = class_names[preds_batch[i]]
    color = "green" if labels_batch[i] == preds_batch[i] else "red"
    ax.set_title(f"True: {true_name}\nPred: {pred_name}", color=color, fontsize=8)

plt.suptitle("Phase 2 Predictions (green = correct, red = wrong)", fontsize=12)
plt.tight_layout()
plt.show()

# Calculate test accuracy score 

In [None]:
# Score the fine-tuned model on the test split and compare with Phase 1.
phase2_preds, phase2_labels = predict(model, test_loader, device)
phase2_test_acc = accuracy_score(y_true=phase2_labels, y_pred=phase2_preds)
test_acc_gain = phase2_test_acc - phase1_test_acc

print(f"Phase 1 Test Accuracy: {phase1_test_acc:.4f}")
print(f"Phase 2 Test Accuracy: {phase2_test_acc:.4f}")
print(f"Improvement:           {test_acc_gain:+.4f}")

# Calculate the confusion matrix, precision, and recall

In [None]:
# Confusion matrix over every class index, so it is always square
# (one row/column per breed) even if a class never appears in the
# labels or predictions.
phase2_class_ids = np.arange(len(test_ds.classes))
cm2 = confusion_matrix(phase2_labels, phase2_preds, labels=phase2_class_ids)

fig, ax = plt.subplots(figsize=(14, 12))
im = ax.imshow(cm2, cmap="Blues")
plt.colorbar(im, ax=ax)
ax.set_title("Phase 2 – Confusion Matrix")
ax.set_xlabel("Predicted")
ax.set_ylabel("True")
# Label rows/columns with breed names instead of bare indices.
ax.set_xticks(phase2_class_ids)
ax.set_xticklabels(test_ds.classes, rotation=90, fontsize=7)
ax.set_yticks(phase2_class_ids)
ax.set_yticklabels(test_ds.classes, fontsize=7)
plt.tight_layout()
plt.show()

# Macro-averaged precision and recall: the unweighted mean of the per-class scores.
precision_macro2 = precision_score(phase2_labels, phase2_preds, average="macro", zero_division=0)
recall_macro2 = recall_score(phase2_labels, phase2_preds, average="macro", zero_division=0)

print(f"Phase 2 Macro Precision: {precision_macro2:.4f}")
print(f"Phase 2 Macro Recall:    {recall_macro2:.4f}")

# Question 2: What did you learn? What is the difference between transfer learning and fine-tuning?

## Answer to Question 2

**Transfer learning** (Phase 1) means taking a model that was pre-trained on a large dataset (ImageNet) and re-using its feature-extraction layers as-is, frozen. Only a new classification head on top is trained. This is fast and works well even with limited data, because the frozen layers already know how to detect general visual features.

**Fine-tuning** (Phase 2) goes one step further. After the classifier has been trained, we unfreeze some of the later backbone layers (here `layer4`) and continue training with a *much lower* learning rate. This lets the model adapt the backbone's high-level features to the specifics of our dataset, which usually pushes accuracy higher. We use a very small learning rate for the unfrozen block (1e-5 for `layer4`, versus 1e-3 for the classifier head in Phase 1). If the learning rate is too high, large updates can overwrite the valuable pretrained features (catastrophic forgetting) and make the model overfit our small training set.

**Key takeaway:** Transfer learning gets you a good baseline quickly; fine tuning squeezes out extra performance by letting the network specialise its features to your task. The critical requirement for fine-tuning is using conservative learning rates (typically 10-100x lower than Phase 1) to prevent overfitting and preserve pretrained knowledge.