# Assignment 2 – Transfer Learning with MedMNIST

This notebook implements a complete transfer learning pipeline on the **BloodMNIST** dataset from the MedMNIST collection using **ResNet-18** and PyTorch.

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.transforms import v2, ToTensor
from torchvision import models

import numpy as np
import matplotlib.pyplot as plt

from medmnist import BloodMNIST
import medmnist

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report
from torch.utils.tensorboard import SummaryWriter

import os

# Record the library versions so results can be tied to a specific environment.
print(f'PyTorch version: {torch.__version__}')
print(f'MedMNIST version: {medmnist.__version__}')

In [None]:
# Dataset & basic configuration
DataClass = BloodMNIST
# Metadata entry for the chosen dataset; its 'label' mapping has one item per class.
info = medmnist.INFO[DataClass.flag]
n_classes = len(info['label'])
print('Using dataset:', DataClass.flag)
print('Number of classes:', n_classes)

BATCH_SIZE = 64
download = True

# Seed PyTorch's random number generators so that batch shuffling, the random
# augmentations and the initialisation of the new classification head are
# repeatable across "Restart & Run All".
SEED = 42
torch.manual_seed(SEED)

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

## 1. Transforms

In [None]:
# Channel-wise ImageNet statistics used for normalisation in both pipelines.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: convert to an image tensor, scale to float32 in [0, 1],
# apply light augmentation, then normalise.
# `v2.ToImage()` replaces the legacy v1 `ToTensor()`: it keeps the whole
# pipeline on the v2 API and leaves the uint8 -> float scaling to `ToDtype`,
# which was a no-op when it received an already-scaled float tensor.
train_transforms = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.RandomHorizontalFlip(),
    v2.RandomRotation(10),
    v2.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Evaluation pipeline: identical preprocessing without the random augmentation.
test_transforms = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

## 2. Datasets and Dataloaders

In [None]:
# Options shared by every split: download flag, 224-pixel images, read-only memory mapping.
split_options = dict(download=download, size=224, mmap_mode='r')

# Only the training split receives the augmenting transform.
train_dataset = DataClass(split='train', transform=train_transforms, **split_options)
val_dataset = DataClass(split='val', transform=test_transforms, **split_options)
test_dataset = DataClass(split='test', transform=test_transforms, **split_options)

# Shuffle the training batches; keep validation and test order fixed.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Cell output: number of samples in (train, val, test).
tuple(len(split) for split in (train_dataset, val_dataset, test_dataset))

### Sample batch visualisation

In [None]:
# Draw one training batch and preview its first eight images in a 2x4 grid.
images, labels = next(iter(train_dataloader))
fig, axes = plt.subplots(2, 4, figsize=(10, 5))
axes = axes.ravel()
for ax, img, lbl in zip(axes, images[:8], labels[:8]):
    # Move the channel axis last, as matplotlib expects.
    img_show = img.permute(1, 2, 0).numpy()
    # The tensors are normalised, so rescale each image by its own min/max for
    # display; the small epsilon avoids dividing by zero on a constant image.
    lo, hi = img_show.min(), img_show.max()
    img_show = (img_show - lo) / (hi - lo + 1e-8)
    ax.imshow(img_show)
    ax.set_title(f'label: {int(lbl)}')
    ax.axis('off')
plt.tight_layout()
plt.show()

## 3. Define Transfer Learning Model

In [None]:
# Directory that receives the TensorBoard event files for this run.
log_dir = 'runs/bloodmnist_resnet18'
os.makedirs(log_dir, exist_ok=True)

# Writer used by the training loop to log per-epoch loss and accuracy.
writer = SummaryWriter(log_dir)
writer

In [None]:
# Build a ResNet-18 initialised with ImageNet weights. If this torchvision
# build has no `ResNet18_Weights` enum, fall back to the legacy `pretrained` flag.
try:
    imagenet_weights = models.ResNet18_Weights.IMAGENET1K_V1
    resnet = models.resnet18(weights=imagenet_weights)
except AttributeError:
    resnet = models.resnet18(pretrained=True)

# Turn off gradients for every pretrained parameter (frozen backbone).
resnet.requires_grad_(False)

# Swap the classification head for a fresh linear layer sized for this dataset.
# The new layer is created after the freeze, so its parameters stay trainable.
num_ftrs = resnet.fc.in_features
resnet.fc = nn.Linear(num_ftrs, n_classes)

model = resnet.to(device)
print(model)

# Cross-entropy loss; Adam is given only the new head's parameters.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)

## 4. Optimisation & Training Loop

In [None]:
num_epochs = 10
best_val_acc = 0.0
best_model_path = 'best_bloodmnist_resnet18.pt'


def _run_epoch(model, dataloader, criterion, optimizer=None):
    """Run one full pass over `dataloader` and return `(mean_loss, accuracy)`.

    Args:
        model: Network to evaluate or train; inputs are moved to the
            notebook-level `device` before the forward pass.
        dataloader: Iterable yielding `(inputs, labels)` batches.
        criterion: Loss function called as `criterion(outputs, labels)`.
        optimizer: When given, the pass runs in training mode and the
            optimizer is stepped after every batch. When `None`, the pass
            runs in eval mode with gradients disabled.

    Returns:
        Tuple of Python floats: sample-weighted mean loss and accuracy.
    """
    is_training = optimizer is not None
    # NOTE(review): train mode also lets BatchNorm layers update their running
    # statistics, even for parameters with requires_grad=False. If the backbone
    # is meant to stay strictly frozen, consider keeping it in eval mode.
    model.train(is_training)

    total_loss = 0.0
    total_correct = 0
    total_seen = 0

    with torch.set_grad_enabled(is_training):
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            # flatten() instead of squeeze(): squeeze() would collapse a final
            # batch of size 1 to a 0-dim tensor and break the loss computation.
            labels = labels.flatten().long().to(device)

            if is_training:
                optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            if is_training:
                loss.backward()
                optimizer.step()

            preds = outputs.argmax(dim=1)
            batch_size = inputs.size(0)
            # The criterion returns a batch mean; weight it by the batch size so
            # a smaller final batch does not skew the epoch average.
            total_loss += loss.item() * batch_size
            total_correct += (preds == labels).sum().item()
            total_seen += batch_size

    # max(..., 1) avoids a division by zero on an empty dataloader.
    denominator = max(total_seen, 1)
    return total_loss / denominator, total_correct / denominator


for epoch in range(num_epochs):
    # Training phase followed by validation phase.
    epoch_train_loss, epoch_train_acc = _run_epoch(
        model, train_dataloader, criterion, optimizer
    )
    epoch_val_loss, epoch_val_acc = _run_epoch(model, val_dataloader, criterion)

    # Log to TensorBoard (plain Python floats, not tensors)
    writer.add_scalar('Loss/train', epoch_train_loss, epoch)
    writer.add_scalar('Loss/val', epoch_val_loss, epoch)
    writer.add_scalar('Acc/train', epoch_train_acc, epoch)
    writer.add_scalar('Acc/val', epoch_val_acc, epoch)

    print(
        f'Epoch {epoch+1}/{num_epochs} '
        f'Train loss: {epoch_train_loss:.4f} acc: {epoch_train_acc:.4f} | '
        f'Val loss: {epoch_val_loss:.4f} acc: {epoch_val_acc:.4f}'
    )

    # Save best model (checkpoint whenever validation accuracy improves)
    if epoch_val_acc > best_val_acc:
        best_val_acc = epoch_val_acc
        torch.save(model.state_dict(), best_model_path)

# Push any buffered events to disk so TensorBoard shows the complete run.
writer.flush()

print(f'Best validation accuracy: {best_val_acc:.4f}')

In [None]:
# TensorBoard can be opened inline (e.g. in Jupyter/Colab) with these magics:
#   %load_ext tensorboard
#   %tensorboard --logdir runs
print(f'To view logs, run TensorBoard pointing to: {log_dir}')

## 5. Evaluate Model on Test Set

In [None]:
# Load best model weights.
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state_dict contains, so no arbitrary pickled code can run.
best_state = torch.load(best_model_path, map_location=device, weights_only=True)
model.load_state_dict(best_state)
model.to(device)
model.eval()

y_true = []
y_pred = []

with torch.no_grad():
    for inputs, labels in test_dataloader:
        inputs = inputs.to(device)
        # flatten() instead of squeeze(): squeeze() would turn a final batch of
        # size 1 into a 0-dim tensor, which cannot be used to extend a list.
        labels = labels.flatten().long()

        preds = model(inputs).argmax(dim=1)

        y_true.extend(labels.tolist())
        y_pred.extend(preds.tolist())

# Integer arrays of ground-truth and predicted class indices for the test set.
y_true = np.array(y_true)
y_pred = np.array(y_pred)

# Class names taken from the dataset metadata's 'label' mapping, in its stored order.
classes = list(info['label'].values())
print('Classes:', classes)
print('Test samples:', len(y_true))

In [None]:
# Pass the full label set explicitly. Otherwise the matrix covers only classes
# that occur in y_true/y_pred, and a missing class would make its size
# disagree with `display_labels`.
class_indices = np.arange(len(classes))
cm = confusion_matrix(y_true, y_pred, labels=class_indices)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=classes)
fig, ax = plt.subplots(figsize=(8, 8))
# Rotate the x tick labels so the class names do not overlap.
disp.plot(ax=ax, xticks_rotation=45)
plt.title('Confusion Matrix - BloodMNIST (ResNet-18)')
plt.tight_layout()
plt.show()

In [None]:
print('\nClassification Report:')
# Give the label indices explicitly so `target_names` always lines up with the
# reported rows, even if a class is absent from both y_true and y_pred.
print(classification_report(
    y_true,
    y_pred,
    labels=list(range(len(classes))),
    target_names=classes,
))

## 6. Comments about your Approach

- **Dataset choice**: I selected **BloodMNIST**, which consists of blood cell images with multiple classes. This allows me to demonstrate transfer learning on a realistic medical imaging dataset.
- **Model & transfer learning strategy**: I used **ResNet-18** pretrained on ImageNet. I froze the convolutional backbone and only trained the final fully connected layer to keep training efficient and to avoid overfitting.
- **Transforms & augmentation**: I applied random horizontal flips and small rotations to increase robustness, as well as normalisation using ImageNet mean and standard deviation so that the images match the statistics the pretrained backbone expects.
- **Training setup**: I used Adam with learning rate 1e-3, batch size 64, and trained for 10 epochs. I monitored training and validation loss/accuracy using TensorBoard and saved the model with the best validation accuracy.
- **Results**: The confusion matrix and classification report show how well the model performs per class. Any classes with lower precision/recall are candidates for further investigation.
- **Potential improvements**: I could unfreeze some of the deeper ResNet layers and fine-tune them on BloodMNIST, use stronger data augmentation, or try alternative architectures (e.g., ResNet-34, EfficientNet) to potentially improve performance further.