<a href="https://colab.research.google.com/github/Hamza-Ali0237/PyTorch-Projects/blob/main/Intermediate/PyTorch-FineGrained-TransferLearning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

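# Fine-Grained Transfer Learning

Train a new 102-way classifier head on top of a frozen, ImageNet-pretrained torchvision backbone for the Oxford Flowers-102 fine-grained classification dataset.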

In [None]:
!pip install torchmetrics

In [None]:
import torch
from torch import nn
import torchvision as tv
import torchmetrics as tm
from torchvision import transforms
from torchvision.datasets import Flowers102
from torch.utils.data import Dataset, DataLoader

import random
import numpy as np
from PIL import Image
import seaborn as sns
import matplotlib.pyplot as plt
from collections import Counter

## Configuration

In [None]:
# Central run configuration: every tunable knob lives here so later cells
# read from one place instead of hard-coding values.
conf_dict = dict(
    # Reproducibility
    seed=42,

    # Computing device (GPU when available)
    device="cuda" if torch.cuda.is_available() else "cpu",

    # Data settings
    root_dir="./data",
    image_size=224,
    batch_size=32,
    num_workers=2,
    pin_memory=True,

    # Model settings
    architecture="resnet50",  # backbone architecture
    pretrained=True,
    num_classes=102,

    # Training hyperparameters
    epochs=25,
    lr_head=1e-3,       # learning rate for the classifier head
    lr_backbone=1e-4,   # lower learning rate for fine-tuning the backbone
    weight_decay=1e-4,
    label_smoothing=0.0,

    # Scheduler
    scheduler="onecycle",

    # Imbalance handling
    use_class_weights=True,
    use_weighted_sampler=False,

    # Logging & saving
    log_every_n_steps=20,
    save_best_by="val_loss",
    checkpoint_dir="No checkpoints for now",
)

print(conf_dict)

In [None]:
# Compute device string read once from the config and reused by later cells
device: str = conf_dict["device"]

In [None]:
# Seed every RNG the notebook touches so runs are repeatable.
SEED = conf_dict["seed"]

for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(SEED)

if torch.cuda.is_available():
    # Explicit CUDA seeding: current device first, then all devices
    for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        cuda_seed_fn(SEED)

# Deterministic cuDNN behaviour. For speed over strict determinism,
# set deterministic=False and benchmark=True instead.
torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False

## Data Transforms

In [None]:
# ImageNet channel statistics, matching the normalisation used for the
# pretrained backbone's weights.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Crop size comes from the config; the eval-time resize keeps the usual
# 256/224 resize-to-crop ratio (i.e. 256 when image_size is 224).
image_size = conf_dict["image_size"]
resize_size = round(image_size * 256 / 224)

# Training transforms (augmentation + normalization)
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(image_size),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(
        brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1
    ),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

# Validation & test transforms: deterministic resize + center crop, no augmentation
val_transforms = transforms.Compose([
    transforms.Resize(resize_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

# The test split is preprocessed exactly like the validation split
test_transforms = val_transforms

## Load Data

In [None]:
def _load_flowers_split(split, transform):
    """Load one of the Flowers102 splits from the configured root directory."""
    return Flowers102(
        root=conf_dict["root_dir"],
        split=split,
        transform=transform,
        download=True,
    )


# One dataset per official split, each with its own preprocessing pipeline
train_dataset = _load_flowers_split("train", train_transforms)
val_dataset = _load_flowers_split("val", val_transforms)
test_dataset = _load_flowers_split("test", test_transforms)

In [None]:
# DataLoader settings shared by every split, pulled from the config
_loader_kwargs = dict(
    batch_size=conf_dict["batch_size"],
    num_workers=conf_dict["num_workers"],
    pin_memory=conf_dict["pin_memory"],
)

# Only the training loader shuffles; val/test keep a fixed order for evaluation
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

In [None]:
# Check class imbalance on the training split and derive per-class loss weights.

# torchvision's Flowers102 keeps its labels in the private `_labels` list and
# may not expose a public `targets` attribute, so fall back to `_labels`.
targets = getattr(train_dataset, "targets", None)
if targets is None:
    targets = train_dataset._labels
class_counts = Counter(targets)

# Class names are optional metadata; the class count is taken from the config
classes = getattr(train_dataset, "classes", None)
num_classes = conf_dict["num_classes"]
total_samples = len(train_dataset)

# Every label must be a valid 0-based index for the weight vector below
assert set(class_counts) <= set(range(num_classes)), "labels outside [0, num_classes)"

# Balanced weights: total_samples / (num_classes * count[class]).
# A class with no training samples gets 0.0 instead of raising ZeroDivisionError.
class_weights = torch.tensor(
    [
        total_samples / (num_classes * class_counts[c]) if class_counts[c] else 0.0
        for c in range(num_classes)
    ],
    dtype=torch.float,
).to(device)

# Loss: class weighting and label smoothing both follow the config
criterion = nn.CrossEntropyLoss(
    weight=class_weights if conf_dict["use_class_weights"] else None,
    label_smoothing=conf_dict["label_smoothing"],
)

In [None]:
# Plot distributions (counts vs weights)

# Counts as a list aligned with class indices (0 for classes with no samples)
counts_list = [class_counts[c] for c in range(num_classes)]
weights_list = class_weights.detach().cpu().tolist()

# Quick sanity check on the first few classes (safe when there are < 10 classes)
for i in range(min(10, num_classes)):
    print(f"Class {i:3d} | Count: {counts_list[i]:4d} | Weight: {weights_list[i]:.4f}")

fig, ax1 = plt.subplots(figsize=(14, 6))

# Counts as bars on the left y-axis
ax1.bar(range(num_classes), counts_list, alpha=0.6, label="Class Counts")
ax1.set_xlabel("Class Index")
ax1.set_ylabel("Number of Samples")
ax1.legend(loc="upper left")

# Weights as a line on a twin right y-axis sharing the class-index x-axis
ax2 = ax1.twinx()
ax2.plot(range(num_classes), weights_list, "r-", label="Class Weights")
ax2.set_ylabel("Class Weight")
ax2.legend(loc="upper right")

ax1.set_title("Class Distribution vs. Computed Weights")
plt.show()

## Model Definition (Classifier Head)

In [None]:
# Build the backbone named in the config (e.g. "resnet50"), using torchvision's
# default pretrained weights when conf_dict["pretrained"] is set.
backbone_fn = getattr(tv.models, conf_dict["architecture"])
model = backbone_fn(weights="DEFAULT" if conf_dict["pretrained"] else None)

# Freeze all model parameters: only the new head will be trained
model.requires_grad_(False)

# Replace the final layer so it outputs one logit per class.
# NOTE(review): assumes a ResNet-style model whose head is `model.fc` —
# revisit if `architecture` is switched to a different model family.
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, conf_dict["num_classes"])

# Unfreeze only the classifier parameters
model.fc.requires_grad_(True)

model.to(device)

# Double-check which params will be trained
for name, param in model.named_parameters():
    if param.requires_grad:
        print("Trainable:", name)

In [None]:
# Show the full layer-by-layer structure, including the replaced head
model_summary = str(model)
print(model_summary)

## Model Training

In [None]:
# Define optimizer: only the classifier head's parameters are optimised,
# using the head learning rate from the config (key "lr_head").
optimizer = torch.optim.Adam(
    model.fc.parameters(),
    lr=conf_dict["lr_head"],
    weight_decay=conf_dict["weight_decay"],
)

# Define scheduler: OneCycleLR needs the number of optimiser steps per epoch,
# which equals the number of training batches because the loop steps per batch.
steps_per_epoch = len(train_loader)

scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=conf_dict["lr_head"],
    epochs=conf_dict["epochs"],
    steps_per_epoch=steps_per_epoch,
)

In [None]:
# Training Loop

def _train_one_epoch(model, loader, criterion, optimizer, scheduler, device):
    """Run one optimisation pass over `loader`.

    The scheduler is stepped after every batch, matching how OneCycleLR was
    configured (steps_per_epoch = number of training batches).

    Returns:
        (mean_loss, accuracy): per-sample mean loss and top-1 accuracy.
    """
    model.train()
    loss_sum, correct, seen = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()

        batch_size = labels.size(0)
        # Weight each batch mean by its size so a short final batch is not over-counted
        loss_sum += loss.item() * batch_size
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        seen += batch_size
    seen = max(seen, 1)  # avoid ZeroDivisionError on an empty loader
    return loss_sum / seen, correct / seen


def _evaluate(model, loader, criterion, device):
    """Score `model` on `loader` without gradients.

    Returns:
        (mean_loss, accuracy): per-sample mean loss and top-1 accuracy.
    """
    model.eval()
    loss_sum, correct, seen = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += batch_size
    seen = max(seen, 1)
    return loss_sum / seen, correct / seen


# Per-epoch history, used by the curve plots below
train_loss, val_loss = [], []
train_acc, val_acc = [], []

for epoch in range(conf_dict["epochs"]):
    epoch_train_loss, epoch_train_acc = _train_one_epoch(
        model, train_loader, criterion, optimizer, scheduler, device
    )
    train_loss.append(epoch_train_loss)
    train_acc.append(epoch_train_acc)

    epoch_val_loss, epoch_val_acc = _evaluate(model, val_loader, criterion, device)
    val_loss.append(epoch_val_loss)
    val_acc.append(epoch_val_acc)

    print(f"Epoch [{epoch+1}/{conf_dict['epochs']}], "
          f"Train Loss: {epoch_train_loss:.4f}, Acc: {epoch_train_acc:.4f}, "
          f"Val Loss: {epoch_val_loss:.4f}, Acc: {epoch_val_acc:.4f}"
    )

In [None]:
def _plot_train_vs_val(train_series, val_series, train_label, val_label, ylabel, title):
    """Draw one 8x6 figure comparing a training curve with its validation curve."""
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(train_series, label=train_label)
    ax.plot(val_series, label=val_label)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True)
    plt.show()


# Loss curve
_plot_train_vs_val(
    train_loss, val_loss,
    train_label="Train Loss", val_label="Validation Loss",
    ylabel="Loss", title="Training vs Validation Loss",
)

# Accuracy curve
_plot_train_vs_val(
    train_acc, val_acc,
    train_label="Train Accuracy", val_label="Validation Accuracy",
    ylabel="Accuracy", title="Training vs Validation Accuracy",
)

## Evaluation

In [None]:
# Evaluation metrics, all moved to the compute device. Every metric shares the
# multiclass task and class count; precision/recall/F1 are macro-averaged.
num_classes = conf_dict["num_classes"]
_metric_kwargs = dict(task="multiclass", num_classes=num_classes)

accuracy = tm.Accuracy(**_metric_kwargs).to(device)
precision = tm.Precision(**_metric_kwargs, average="macro").to(device)
recall = tm.Recall(**_metric_kwargs, average="macro").to(device)
f1 = tm.F1Score(**_metric_kwargs, average="macro").to(device)
conf_mat = tm.ConfusionMatrix(**_metric_kwargs).to(device)