## Week 5: Fine-tuning VGG

### 0. Import the necessary libraries and dataset

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as T
from torchvision.models import vgg16, VGG16_Weights
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder
from torch.optim import lr_scheduler
from tqdm import tqdm
from torchvision import models
from PIL import Image
from sklearn.utils.class_weight import compute_class_weight
import numpy as np
import cv2
import kagglehub

# Download latest version
path = kagglehub.dataset_download("lukechugh/best-alzheimer-mri-dataset-99-accuracy")

print("Path to dataset files:", path)

Downloading from https://www.kaggle.com/api/v1/datasets/download/lukechugh/best-alzheimer-mri-dataset-99-accuracy?dataset_version_number=1...


100%|██████████| 71.5M/71.5M [00:02<00:00, 31.4MB/s]

Extracting files...





Path to dataset files: /root/.cache/kagglehub/datasets/lukechugh/best-alzheimer-mri-dataset-99-accuracy/versions/1


### 1. Data loading and preprocessing

#### 1.1 Transformation for training dataset

Here, we create a series of transformations that is applied to every image in the following sequence:
- Convert the image to grayscale with 3 identical channels (to match the 3-channel input that VGG-16 expects)
- Resize the image so that its shorter side is 256 pixels (`T.Resize(256)` preserves the aspect ratio)
- Crop the resized image to 224 x 224 pixels around the center point
- Convert the PIL (Pillow) image to a PyTorch tensor
- Normalize each channel of the tensor with its respective mean and standard deviation. The specific mean and standard deviation values are calculated from the ImageNet dataset, which VGG-16 was originally trained on.
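
To make the normalization step concrete, here is a minimal sketch in plain Python of the per-channel arithmetic. The helper name `normalize_pixel` is made up for illustration; the mean and standard deviation values are the ImageNet statistics used in the pipeline.

```python
# ImageNet channel statistics, as used in the transform pipeline
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

def normalize_pixel(value, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Apply (value - mean) / std per channel to one grayscale intensity.

    `value` is a pixel intensity already scaled to [0, 1], as ToTensor does.
    Because the grayscale image is replicated across 3 channels, the same
    intensity is normalized with three different (mean, std) pairs.
    """
    return [(value - m) / s for m, s in zip(mean, std)]

# A mid-gray pixel (0.5) ends up slightly different in each channel
normalized = normalize_pixel(0.5)
print([round(v, 4) for v in normalized])
```

One consequence worth noticing: after normalization the three channels are no longer identical, even though the grayscale input was.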

In [None]:
# TODO: Implement the above series of transformations for training images
train_transforms = T.Compose([
    T.Grayscale(num_output_channels=3),
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485,0.456,0.406],
                         std=[0.229,0.224,0.225]),
])

#### 1.2 Load and split training dataset

We first load the training dataset from the directory. Note: The data is already split into training and testing sets, so here we only need to load the training set, which lives in the `Combined Dataset/train` folder inside the downloaded dataset directory (`path`).

While loading the dataset, we also add the transformations declared earlier.

Next, we split the loaded data into training and validation sets with an 80-20 split. The validation set is used to evaluate the model during training.
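
As a rough illustration of what an 80-20 random split does at the index level, here is a plain-Python sketch. The function `split_indices` and the seed are assumptions made for this example, not part of the notebook; the sample count of 10240 is the total implied by the split sizes this notebook prints (8192 + 2048).

```python
import random

def split_indices(n_samples, train_fraction=0.8, seed=0):
    """Shuffle sample indices and cut them into train/validation lists."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # seeded so the split is reproducible
    train_len = int(train_fraction * n_samples)
    return indices[:train_len], indices[train_len:]

train_idx, val_idx = split_indices(10240)
print(len(train_idx), len(val_idx))  # 8192 2048
```

PyTorch's `random_split` also accepts a `generator` argument, which can be seeded in the same spirit if a reproducible split is wanted.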

In [None]:
full_dataset = ImageFolder(path + "/Combined Dataset/train", transform=train_transforms)

train_len = int(0.8 * len(full_dataset))
val_len = len(full_dataset) - train_len
train_ds, val_ds = random_split(full_dataset, [train_len, val_len])
print(f"Train: {len(train_ds)} | Validation: {len(val_ds)}")

Train: 8192 | Validation: 2048


In [None]:
train_dl = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=4)
val_dl = DataLoader(val_ds, batch_size=32, shuffle=False, num_workers=4)



### 2. VGG fine-tuning

#### 2.1 Set default weights for the pre-trained VGG-16 model

As mentioned previously, this model was trained on the ImageNet dataset. We will use the weights from this pre-trained model and adapt them to our dataset. In the final classification block (which has 3 fully connected layers), we replace only the last layer so that it outputs 4 classes instead of the original 1000.
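
To get a feel for how small the new output layer is, the sketch below counts the parameters of a fully connected layer with 4096 inputs, for 1000 and for 4 outputs. The helper `linear_param_count` is a made-up name for this illustration.

```python
def linear_param_count(in_features, out_features):
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features

original_head = linear_param_count(4096, 1000)  # ImageNet output layer
new_head = linear_param_count(4096, 4)          # 4 impairment classes
print(original_head, new_head)  # 4097000 16388
```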

In [None]:
# TODO: Load VGG model and pre-trained weights
weights = VGG16_Weights.DEFAULT
model = vgg16(weights=weights)

num_classes = 4
model.classifier[-1] = nn.Linear(in_features=4096, out_features=num_classes)

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth


100%|██████████| 528M/528M [00:03<00:00, 165MB/s]


In [None]:
# TODO: Unfreeze the last convolutional block and the classifier; keep the rest frozen
# features.24, .26, and .28 are the conv layers of the last block
# (.25, .27, and .29 are ReLUs and have no parameters)
last_block = ("features.24.", "features.26.", "features.28.")
for name, param in model.named_parameters():
    param.requires_grad = name.startswith(last_block) or name.startswith("classifier.")

Use the device's GPU if available.

In [None]:
# Prefer CUDA, then Apple's MPS backend, and fall back to the CPU otherwise
device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu"))
print(device)
model = model.to(device)

cuda


#### 2.2 Declare loss function, optimizer, and scheduler

Here, we are using the Cross Entropy Loss function, Adam optimizer, and Learning Rate Scheduler to train the model.

- The Cross Entropy Loss function is a natural choice for fine-tuning VGG-16 on a multi-class problem. Note that `nn.CrossEntropyLoss` expects the raw, unnormalized scores (logits) that the model outputs and applies log-softmax internally, so the model does not need a softmax layer at the end.
- The Adam optimizer is a good option for fine-tuning purposes. Here, we pass it only the parameters that are still trainable (the last convolutional block and the classifier) with a single learning rate of 1e-4. The earlier layers pick up low-level, generic features such as edges and textures, which transfer well from ImageNet, so they stay frozen. The later layers pick up higher-level, more task-specific features, so they are the ones we update. Keeping the learning rate small avoids overwriting the pre-trained weights too abruptly, which helps generalization, an important factor for our use-case with medical data.
- The Learning Rate Scheduler adjusts the learning rate of the optimizer during training. We use `StepLR`, which multiplies the learning rate by `gamma=0.1` every `step_size=7` epochs, so that later epochs take smaller steps and the model can settle into a minimum instead of overshooting it.
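
As a small illustration of a step-decay schedule, the sketch below computes the learning rate for each of 15 epochs in plain Python, using the values this notebook passes to the optimizer and scheduler (base rate 1e-4, step size 7, decay factor 0.1). The function `step_lr` is a hypothetical helper that mirrors the decay rule; it is not the PyTorch implementation.

```python
def step_lr(base_lr, epoch, step_size=7, gamma=0.1):
    """Learning rate in effect during `epoch` (0-indexed) under step decay."""
    return base_lr * gamma ** (epoch // step_size)

schedule = [step_lr(1e-4, epoch) for epoch in range(15)]
for epoch, lr in enumerate(schedule, start=1):
    print(f"epoch {epoch:2d}: lr = {lr:.0e}")
```

Under these settings the first 7 epochs run at 1e-4, the next 7 at 1e-5, and the final epoch at 1e-6.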

Before we declare the loss, optimizer, and scheduler, however, we need to compute class weights to compensate for the significant class imbalance in our dataset. For instance, since the number of images in the 'No impairment' class is significantly higher than the number of images in the 'Moderate impairment' class, we assign a higher weight to the latter class so that errors on 'Moderate impairment' images are penalized more. This steers the model away from simply predicting the 'No impairment' class more often to score a higher accuracy, and towards a more balanced model overall.
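
The 'balanced' heuristic used by scikit-learn assigns each class the weight `n_samples / (n_classes * count_of_that_class)`. Here is a minimal plain-Python sketch of that formula; the function name `balanced_class_weights` and the label counts are hypothetical and chosen only to show the effect of imbalance.

```python
from collections import Counter

def balanced_class_weights(labels):
    """weight_c = n_samples / (n_classes * count_c), the 'balanced' heuristic."""
    counts = Counter(labels)
    n_samples, n_classes = len(labels), len(counts)
    return {c: n_samples / (n_classes * counts[c]) for c in sorted(counts)}

# Hypothetical label counts: classes 0 and 3 are rare, class 2 is common
labels = [0] * 10 + [1] * 30 + [2] * 50 + [3] * 10
weights_by_class = balanced_class_weights(labels)
print(weights_by_class)
```

Rare classes receive weights above 1 and common classes receive weights below 1, so a mistake on a rare class contributes more to the loss.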

In [None]:
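# TODO: Extract labels from training data
# Read the labels from the dataset's target list rather than iterating the
# DataLoader, which would load and transform every image just to get its label
y_train = np.array(full_dataset.targets)[train_ds.indices]

# Named class_weights so it does not overwrite the VGG16 `weights` variable defined earlier
class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(y_train), y=y_train)
criterion = nn.CrossEntropyLoss(weight=torch.tensor(class_weights, dtype=torch.float).to(device))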

trainable_params = [p for p in model.parameters() if p.requires_grad]

optimizer = torch.optim.Adam(trainable_params, lr=1e-4)
scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

#### 2.3 Fine-tuning VGG

We'll be using 15 epochs for fine-tuning. In each epoch, if the validation accuracy exceeds the best value seen so far, we'll save the new model weights.

As always, we begin with the training part of the loop: moving all the inputs to the GPU (if present), zeroing the optimizer's gradients, calculating the loss, computing gradients through backpropagation, and updating the weights with the optimizer. This is not so different from training traditional CNNs.

Moving to the validation phase of each epoch, we set the model to evaluation mode, and we calculate the validation loss and accuracy. This is again similar to training traditional CNNs.
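
One detail of the loop is worth a small illustration: the criterion returns an average over the batch by default, so each batch loss is multiplied by the batch size before being accumulated, and the total is divided by the dataset size at the end. The plain-Python sketch below, with made-up numbers and a hypothetical helper `epoch_average`, shows why this matters when the last batch is smaller than the others.

```python
def epoch_average(batch_means, batch_sizes):
    """Combine per-batch mean losses into a per-sample mean over the epoch."""
    total = sum(mean * size for mean, size in zip(batch_means, batch_sizes))
    return total / sum(batch_sizes)

# Hypothetical values: two full batches of 32 and a final partial batch of 8
batch_means = [0.50, 0.40, 1.00]
batch_sizes = [32, 32, 8]

weighted = epoch_average(batch_means, batch_sizes)
naive = sum(batch_means) / len(batch_means)
print(f"per-sample mean: {weighted:.4f}, naive mean of batches: {naive:.4f}")
```

Averaging the batch means directly would give the small final batch the same influence as a full one.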

In [None]:
best_val_acc = 0.0
num_epochs = 15

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    print("-" * 10)

    # --- Train Phase ---
    model.train()
    running_loss = running_corrects = 0

    for inputs, labels in tqdm(train_dl, desc="Training"):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        preds = outputs.argmax(dim=1)
        running_loss += loss.item() * inputs.size(0)
        running_corrects += (preds == labels).sum().item()

    epoch_loss = running_loss / len(train_dl.dataset)
    epoch_acc = running_corrects / len(train_dl.dataset)
    print(f"Train Loss: {epoch_loss:.4f}, Acc: {epoch_acc:.4f}")

    # --- Validation Phase ---
    model.eval()
    val_loss = val_corrects = 0

    with torch.no_grad():
        for inputs, labels in tqdm(val_dl, desc="Validating"):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            preds = outputs.argmax(dim=1)
            val_loss += loss.item() * inputs.size(0)
            val_corrects += (preds == labels).sum().item()

    val_loss = val_loss / len(val_dl.dataset)
    val_acc = val_corrects / len(val_dl.dataset)
    print(f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}")

    # Step LR
    scheduler.step()

    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "best_vgg_mri.pth")
        print("→ Saved best model")

print("Training complete. Best Validation Accuracy:", best_val_acc)

Epoch 1/15
----------


Training: 100%|██████████| 256/256 [00:58<00:00,  4.37it/s]


Train Loss: 0.5339, Acc: 0.7604


Validating: 100%|██████████| 64/64 [00:11<00:00,  5.74it/s]


Val Loss: 0.3128, Acc: 0.8701
→ Saved best model
Epoch 2/15
----------


Training: 100%|██████████| 256/256 [00:56<00:00,  4.51it/s]


Train Loss: 0.2352, Acc: 0.9056


Validating: 100%|██████████| 64/64 [00:11<00:00,  5.81it/s]


Val Loss: 0.2520, Acc: 0.9106
→ Saved best model
Epoch 3/15
----------


Training: 100%|██████████| 256/256 [00:57<00:00,  4.48it/s]


Train Loss: 0.1019, Acc: 0.9640


Validating: 100%|██████████| 64/64 [00:10<00:00,  5.83it/s]


Val Loss: 0.1856, Acc: 0.9336
→ Saved best model
Epoch 4/15
----------


Training: 100%|██████████| 256/256 [00:57<00:00,  4.47it/s]


Train Loss: 0.0673, Acc: 0.9769


Validating: 100%|██████████| 64/64 [00:10<00:00,  5.85it/s]


Val Loss: 0.1678, Acc: 0.9482
→ Saved best model
Epoch 5/15
----------


Training: 100%|██████████| 256/256 [00:57<00:00,  4.44it/s]


Train Loss: 0.0313, Acc: 0.9883


Validating: 100%|██████████| 64/64 [00:11<00:00,  5.81it/s]


Val Loss: 0.1776, Acc: 0.9458
Epoch 6/15
----------


Training:  20%|██        | 52/256 [00:12<00:47,  4.29it/s]


KeyboardInterrupt: 

In [None]:
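# TODO: Load VGG model without pre-trained weights, modify final classifier to have 4 outputs
model = models.vgg16(weights=None)

model.classifier[6] = nn.Linear(4096, 4)
# map_location lets the checkpoint load even if it was saved on a different device
model.load_state_dict(torch.load('best_vgg_mri.pth', map_location=device))

# In-place ReLUs overwrite their input tensors, which conflicts with the full
# backward hooks used for Grad-CAM later, so swap them for out-of-place ReLUs
def replace_relu_with_out_of_place(module):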
    for name, child in module.named_children():
        if isinstance(child, torch.nn.ReLU) and child.inplace:
            setattr(module, name, torch.nn.ReLU(inplace=False))
        else:
            replace_relu_with_out_of_place(child)

replace_relu_with_out_of_place(model)

model.eval().to(device)

### 3. Testing

#### 3.1 Data loading and pre-processing

We'll begin by declaring another sequence of pre-processing modifications for the testing data. This is identical to the set of pre-processing modifications we used for the training data.

Similar to the training data, we'll load the dataset and loader.

In [None]:
# TODO: Declare pre-processing for testing data
preprocess = T.Compose([
    T.Grayscale(num_output_channels=3),
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485,0.456,0.406],
                         std=[0.229,0.224,0.225]),
])
dataset = ImageFolder(path + "/Combined Dataset/test", transform=preprocess)
loader = DataLoader(dataset, batch_size=16, shuffle=False)

#### 3.2 Generate predictions

As always, we'll walk through the inputs and labels in the test dataset and append the predictions to an output list called `all_preds`. To be able to evaluate the model's performance, we'll also append the actual labels to another list called `all_labels`.

In [None]:
model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for inputs, labels in loader:
        inputs = inputs.to(device)
        outputs = model(inputs)
        preds = outputs.argmax(dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

#### 3.3 Visualizing Model Performance

As with any classification model, we can use the confusion matrix to visualize the performance.

It's interesting to note how the model performs on the test set: its incorrect predictions are often not far from the actual labels. For instance, for the Mild Impairment label, most of the incorrect predictions are classified as Very Mild Impairment. This makes sense because the labels lie on a scale of impairment severity, so neighboring classes look the most alike.
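
As a reminder of how a confusion matrix is built, here is a minimal plain-Python sketch that follows the same convention as scikit-learn (rows are true classes, columns are predicted classes). The function `confusion_counts` and the toy labels are hypothetical and unrelated to the MRI results.

```python
def confusion_counts(true_labels, predicted_labels, n_classes):
    """Return an n_classes x n_classes table: rows are true classes, columns are predictions."""
    matrix = [[0] * n_classes for _ in range(n_classes)]
    for t, p in zip(true_labels, predicted_labels):
        matrix[t][p] += 1
    return matrix

# Hypothetical labels and predictions for a 3-class problem
y_true = [0, 0, 1, 1, 1, 2, 2, 2, 2]
y_pred = [0, 1, 1, 1, 2, 2, 2, 2, 1]
cm_example = confusion_counts(y_true, y_pred, n_classes=3)
for row in cm_example:
    print(row)
```

Off-diagonal entries next to the diagonal correspond to confusions between adjacent classes, which is the pattern described above.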

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report
import matplotlib.pyplot as plt

cm = confusion_matrix(all_labels, all_preds)
disp = ConfusionMatrixDisplay(cm, display_labels=dataset.classes)
disp.plot(cmap="Blues", xticks_rotation="vertical")
plt.title("Confusion Matrix")
plt.show()

print("Classification Report:")
print(classification_report(all_labels, all_preds, target_names=dataset.classes))

### 4. Feature visualization (Bonus)

Here, we'll take a look at the features learned by the VGG-16 model on a sample image. As a reminder, here's the architecture of the VGG-16 model:

<img src="https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTonMOBS41plEHMuaFXJdfaQ9rwHjofx-Sb8g&s" width="200%"/>

Before we get started, recall that every convolutional layer contains a set of learnable filters (kernels), typically of size n×n (like 3×3).
These filters slide over the input image (or feature map), performing dot products to produce activations.
The number of filters in a layer determines the number of output channels, and each filter learns to detect specific features.

To visualize these features, we'll use a technique called GradCAM (Gradient-weighted Class Activation Mapping). The core idea of this method is to (1) access the feature maps from a convolutional layer and (2) weight each feature map according to its importance in predicting the class of the image (hence the 'Class' in GradCAM).

The method of calculating the weights is what distinguishes GradCAM from other techniques to generate feature maps.

Here's a quick overview of the method that GradCAM uses to calculate the weights:

1. Compute the gradients of the class prediction with respect to the activations of the convolutional layer. This provides a measure of how much changing a given element in an activation channel contributes to the prediction of that particular class. The gradients thus also have the same dimensions as the activations.
2. Compute the average of the gradients for each channel in the feature map. This provides a single value as the overall "change" for each channel, which serves as the weight of that channel.
3. Weight each feature map by the average of the gradients. This provides a single weighted feature map (i.e. activation channel).
4. Sum the weighted feature maps together to get the final heatmap. This is effectively the weighted sum of the activations for each channel in the feature map, and is called the GradCAM heatmap.
5. Filter out negative values with the ReLU function. Negative values mark regions that lower the score of the chosen class, and we only want to highlight the regions that support the prediction.
6. Normalize the heatmap values to be between 0 and 1.
7. Upscale the heatmap to the same size as the original image.
8. Apply the heatmap to the original image and display it!
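
Steps 2 to 6 above can be sketched in a few lines of NumPy. This is only an illustration under stated assumptions: the function `grad_cam`, the small `eps` guard against division by zero, and the random arrays standing in for real activations and gradients are all made up for this example.

```python
import numpy as np

def grad_cam(activations, gradients, eps=1e-8):
    """Grad-CAM heatmap from one layer's activations and gradients, both shaped (C, H, W)."""
    weights = gradients.mean(axis=(1, 2))            # step 2: one weight per channel
    weighted = activations * weights[:, None, None]  # step 3: scale each feature map
    heatmap = weighted.sum(axis=0)                   # step 4: combine the channels
    heatmap = np.maximum(heatmap, 0)                 # step 5: ReLU
    return heatmap / (heatmap.max() + eps)           # step 6: scale to [0, 1]

# Hypothetical random arrays standing in for real hook outputs
rng = np.random.default_rng(0)
acts = rng.random((4, 7, 7))
grads = rng.standard_normal((4, 7, 7))
cam = grad_cam(acts, grads)
print(cam.shape, float(cam.min()), float(cam.max()))
```

Whether the channels are summed or averaged in step 4 does not change the final picture, because the normalization in step 6 removes any constant scale factor.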

For a more visual explanation, check out this video: https://www.youtube.com/watch?v=_QiebC9WxOc

#### 4.1 Create forward and backward hooks

A forward hook stores the most recent activations of the target layer after a forward pass. A backward hook stores the gradients with respect to the target layer's output during a backward pass (backpropagation).

The activations provide a set of features, which are areas of importance in the image that are picked up by the layer. The gradients provide a set of weights to highlight the importance of those features. Together, they can be used to create a heatmap of the most important areas in the image.
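
To see the two hooks in isolation, here is a small self-contained sketch on a toy convolutional layer rather than on VGG-16. The layer sizes, the `captured` dictionary, and the function names are assumptions made for illustration only.

```python
import torch
import torch.nn as nn

captured = {}

def save_activations(module, inputs, output):
    # Called after the layer's forward pass
    captured["activations"] = output.detach()

def save_gradients(module, grad_input, grad_output):
    # Called during backpropagation; grad_output holds gradients w.r.t. the layer's output
    captured["gradients"] = grad_output[0].detach()

# Toy stand-in for a convolutional layer (not the VGG-16 layer itself)
conv = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3, padding=1)
fwd_handle = conv.register_forward_hook(save_activations)
bwd_handle = conv.register_full_backward_hook(save_gradients)

x = torch.randn(1, 1, 8, 8, requires_grad=True)
out = conv(x)          # forward pass fires save_activations
out.sum().backward()   # backward pass fires save_gradients

print(captured["activations"].shape, captured["gradients"].shape)

# Remove the hooks when done so they do not fire on later passes
fwd_handle.remove()
bwd_handle.remove()
```

Both stored tensors have the same shape, which is what allows the gradients to be used as weights for the activations.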

In [None]:
activations = None
gradients = None

def forward_hook(module, input, output):
    global activations
    activations = output.detach()

def backward_hook(module, grad_input, grad_output):
    global gradients
    gradients = grad_output[0].detach()

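We will use the output of the last convolutional stage of the fine-tuned VGG16 model to extract the features from the images: `features[29]`, the ReLU that follows the last convolutional layer (`features[28]`). Note that this is before the classifier block in the model.

The reason for picking the last convolutional stage is that it captures the highest-level features while still retaining spatial information about where they occur in the image.

In [None]:
target_layer = model.features[29]  # ReLU right after the last conv layer (features[28]) in VGG16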
target_layer.register_forward_hook(forward_hook)
target_layer.register_full_backward_hook(backward_hook)

#### 4.2 Load and pre-process a sample image

We first load a sample image from the test dataset and convert it to RGB. This conversion guarantees that the image has 3 color channels, which the OpenCV overlay at the end relies on. We run the same preprocessing steps as with the training and testing data for consistency.

We then set the `requires_grad` attribute of the processed image to True so that gradients are tracked all the way back to the input, which allows the backward hook to capture them.

In [None]:
img = Image.open(path + "/Combined Dataset/test/Mild Impairment/1 (2).jpg").convert("RGB")

input_tensor = preprocess(img).unsqueeze(0).to(device)

input_tensor.requires_grad = True

print("Input image shape:", input_tensor.shape)

#### 4.3 Forward and backward passes
Next, we perform a forward pass on the processed image and get the output class. This updates the layer's activations in the forward hook.

In [None]:
output = model(input_tensor)
class_idx = output.argmax().item()  # or specify class manually

- We then reset the gradients in the model and backpropagate only from the score of the output class.
- Resetting the gradients is necessary to avoid accumulating gradients from a previous backward pass.
- Backpropagating from the output class score alone, rather than from a loss over all classes, helps us identify the features that are important in predicting that particular class.

In [None]:
output = model(input_tensor)
loss = output[0, class_idx]
model.zero_grad()
loss.backward()

#### 4.4 Compute GradCAM

The gradients are of the format [B, C, H, W], where
- B is the batch size (the number of images; in our case, 1)
- C is the number of channels/filters picked up by the layer. Each channel is a feature map.
- H is the height of the feature map
- W is the width of the feature map

In [None]:
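# Average the gradients over the batch and spatial dimensions: one weight per channel
pooled_grads = torch.mean(gradients, dim=(0, 2, 3))
# Weight each activation channel (broadcast over H and W) without modifying the hooked tensor
weighted_activations = activations[0] * pooled_grads[:, None, None]

gradcam = torch.mean(weighted_activations, dim=0).cpu().numpy()
gradcam = np.maximum(gradcam, 0)  # ReLU: keep only positive contributions
gradcam /= gradcam.max() + 1e-8   # normalize to [0, 1]; the epsilon guards against an all-zero map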

Upscale the GradCAM to the image resolution and scale its values to 0-255.

Create a heatmap of the GradCAM values on the original image.

In [None]:
# Resize heatmap to original image size
heatmap = cv2.resize(gradcam, (img.width, img.height))
heatmap = np.uint8(255 * heatmap) # scale to 0-255 for colormap
heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)  # apply colormap

Convert the image to BGR and overlay the heatmap on top of it.
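
The overlay itself is a per-pixel weighted sum of the two images. The NumPy sketch below shows roughly what such a blend computes with weights 0.6 and 0.4; the function `blend` and the tiny 2x2 arrays are hypothetical stand-ins, not the OpenCV implementation or real data.

```python
import numpy as np

def blend(image, heatmap, image_weight=0.6, heatmap_weight=0.4):
    """Per-pixel weighted sum of two uint8 images with the same shape."""
    mixed = image_weight * image.astype(np.float32) + heatmap_weight * heatmap.astype(np.float32)
    return np.clip(np.rint(mixed), 0, 255).astype(np.uint8)

# Hypothetical 2x2 BGR arrays standing in for the MRI slice and the colored heatmap
image = np.full((2, 2, 3), 200, dtype=np.uint8)
heat = np.zeros((2, 2, 3), dtype=np.uint8)
heat[..., 2] = 255  # pure red in BGR channel order

overlay_example = blend(image, heat)
print(overlay_example[0, 0])
```

A larger weight on the image keeps the anatomy visible, while a larger weight on the heatmap makes the highlighted regions stand out more.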

In [None]:
img_np = np.array(img)
img_np_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)

# Overlay heatmap on top of the image
overlay = cv2.addWeighted(img_np_bgr, 0.6, heatmap, 0.4, 0)

Finally, we can display the heatmap after converting the image back to RGB!

In [None]:
# Convert back to RGB for matplotlib
overlay_rgb = overlay[..., ::-1]

plt.imshow(overlay_rgb)
plt.axis('off')
plt.title('Grad-CAM Heatmap Overlay')
plt.show()