In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader
from torchvision import models
from torchvision import transforms
from pytorch_lightning import LightningModule, LightningDataModule, Trainer, seed_everything
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar
from torchmetrics.functional import auroc
from PIL import Image
from medmnist.info import INFO
from medmnist.dataset import MedMNIST

We will be using the PneumoniaMNIST dataset from [MedMNIST](https://medmnist.com/), a collection of real medical images pre-processed into the small, standardised format of MNIST. The small image size allows efficient experimentation. The dataset contains real chest X-ray images downsampled to 28 × 28 pixels, with binary labels indicating the presence of [pneumonia](https://www.nhs.uk/conditions/pneumonia/) (an inflammation of the lungs).

In [None]:
class SimCLRPneumoniaMNISTDataset(MedMNIST):
    def __init__(self, split = 'train', positive_pair = True):
        ''' Dataset class for PneumoniaMNIST used for SimCLR pre-training.
        The init function automatically downloads the necessary files the
        first time the class is initialised.

        :param split: 'train', 'val' or 'test', select subset
        :param positive_pair: if True, each item is a pair of independently
            augmented views ``(view1, view2)`` of the same image; if False,
            each item is ``(image, label)`` where ``image`` is the un-augmented
            image converted to a tensor.
        '''
        self.flag = "pneumoniamnist"
        self.size = 28
        self.size_flag = ""
        self.root = './data/coursework/'
        self.info = INFO[self.flag]
        self.download()
        self.positive_pair = positive_pair

        npz_file = np.load(os.path.join(self.root, "pneumoniamnist.npz"))

        self.split = split

        # Load all the images
        assert self.split in ['train','val','test']

        self.imgs = npz_file[f'{self.split}_images']
        self.labels = npz_file[f'{self.split}_labels']

        # Self-supervised learning needs two views of the same image. Relevant augmentations
        # should make them hard enough to match that the encoder learns useful features.
        #
        # The original SimCLR paper uses RandomCrop with Resize, Random Horizontal Flip,
        # Random Color Distortion and Random Gaussian Blur.
        # MedMNIST is grayscale, so colour distortion is not needed.
        #
        # Observations from a sample of the training images, used to design the pipeline:
        # 1. The heart is always on the same side, which can lead to overfitting
        #    -> random horizontal flips account for left/right orientation variation.
        # 2. The rotation/angle of the sternum varies (patient positioning in front of the
        #    X-ray varies in practice) -> small random rotations.
        # 3. A few images are more zoomed in than others -> random resized cropping creates
        #    more images with different zoom levels.
        # 4. Pneumonia can affect one or both lungs, in patches or spread out (lobar or
        #    bronchopneumonia) -> random crop/resize also covers this variation.
        # 5. Image blurriness varies -> Gaussian blur.
        #
        # The default RandomResizedCrop scale of (0.08, 1) produced crops that were too small
        # (often a single uniform colour), so the lower bound is raised to 0.2.
        self.to_tensor = transforms.ToTensor()
        self.augmentation_pipeline = transforms.Compose([
            transforms.ToTensor(),
            transforms.RandomResizedCrop(size=28,scale=(0.2,1.0)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
            transforms.RandomAffine(degrees=10, translate=(0.1, 0.1))
        ])

    def __len__(self):
        return self.imgs.shape[0]

    def __getitem__(self, index):
        '''Return one sample.

        :returns: ``(img_view1, img_view2)``, two independently augmented float
            tensors of shape [1, 28, 28], when ``positive_pair`` is True;
            otherwise ``(image, label)`` with the un-augmented image as a float
            tensor and the raw label entry.
        '''
        img = self.imgs[index]

        if not self.positive_pair:
            # Convert to a tensor so the output has the same layout as the augmented
            # views and can be passed directly to the encoder.
            return self.to_tensor(img), self.labels[index]

        # Each call draws fresh random parameters, giving two different views.
        img1 = self.augmentation_pipeline(img)
        img2 = self.augmentation_pipeline(img)

        return img1, img2



In [None]:
class SimCLRPneumoniaMNISTDataModule(LightningDataModule):
    '''Lightning data module serving SimCLR positive pairs for each PneumoniaMNIST split.

    The three datasets are built eagerly at construction time, which triggers the
    dataset download on first use.

    :param batch_size: number of samples per batch for all three loaders.
    '''
    def __init__(self, batch_size: int = 8):
        super().__init__()
        self.batch_size = batch_size
        self.train_set, self.val_set, self.test_set = (
            SimCLRPneumoniaMNISTDataset(split=split_name)
            for split_name in ('train', 'val', 'test')
        )

    def _make_loader(self, dataset, shuffle):
        # Shared DataLoader construction for all three splits.
        return DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=shuffle)

    def train_dataloader(self):
        # Only the training data is shuffled.
        return self._make_loader(self.train_set, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.val_set, shuffle=False)

    def test_dataloader(self):
        return self._make_loader(self.test_set, shuffle=False)

#### Check the dataset implementation

In [None]:

# Sanity check: show both augmented views of the first 8 images of one training batch
# (top row: view 1, bottom row: view 2).
datamodule = SimCLRPneumoniaMNISTDataModule()
train_dataloader = datamodule.train_dataloader()
batch = next(iter(train_dataloader))

view1, view2 = batch
f, ax = plt.subplots(2, 8, figsize=(12,4))
for row, (views, title) in enumerate(((view1, 'view 1'), (view2, 'view 2'))):
    for col in range(8):
        ax[row, col].imshow(views[col, 0], cmap='gray')
        ax[row, col].set_title(title)
        ax[row, col].axis("off")

Implement the SimCLR contrastive loss (NT-Xent) as defined in the [SimCLR paper](https://arxiv.org/abs/2002.05709).

In [None]:
def simclr_loss(embedding_view1, embedding_view2, tau=1.0):
    '''SimCLR (NT-Xent) contrastive loss for a batch of positive pairs.

    Row ``i`` of ``embedding_view1`` and row ``i`` of ``embedding_view2`` are two
    views of the same sample (a positive pair). All other embeddings in the batch
    act as negatives. With all 2N embeddings stacked, each anchor ``i`` has
    positive ``j`` and contributes

        l(i, j) = -log( exp(sim(i, j) / tau) / sum_{k != i} exp(sim(i, k) / tau) )

    where ``sim`` is cosine similarity. The result is the mean over all 2N anchors.
    The computation is done in log space (``logsumexp``), so small temperatures do
    not overflow ``exp``.

    :param embedding_view1: tensor of shape [N, D], embeddings of the first views.
    :param embedding_view2: tensor of shape [N, D], embeddings of the second views.
    :param tau: temperature.
    :returns: scalar loss tensor.
    :raises ValueError: if the inputs have different shapes or N == 0.
    '''
    if embedding_view1.shape != embedding_view2.shape:
        raise ValueError(
            f"simclr_loss expects both views to have the same shape, got "
            f"{tuple(embedding_view1.shape)} and {tuple(embedding_view2.shape)}"
        )
    n = embedding_view1.size(0)
    if n == 0:
        raise ValueError("simclr_loss needs at least one positive pair")

    # Cosine similarity is the dot product of L2-normalised vectors.
    combined_embeddings = torch.cat(
        [F.normalize(embedding_view1, p=2, dim=1), F.normalize(embedding_view2, p=2, dim=1)],
        dim=0,
    )  # [2N, D]
    logits = torch.mm(combined_embeddings, combined_embeddings.T) / tau  # [2N, 2N]

    # Remove self-similarities from the denominator: exp(-inf) == 0.
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=logits.device)
    logits = logits.masked_fill(self_mask, float('-inf'))

    # The positive of anchor i is i + N for the first half and i - N for the second half.
    anchors = torch.arange(2 * n, device=logits.device)
    positives = (anchors + n) % (2 * n)

    # -log(exp(a) / sum_k exp(b_k)) == logsumexp(b) - a
    loss_per_anchor = torch.logsumexp(logits, dim=1) - logits[anchors, positives]
    return loss_per_anchor.mean()



Test the loss implementation against precomputed reference values to check that it behaves correctly.

In [None]:
# DO NOT MODIFY THIS CELL! IT IS FOR CHECKING THE IMPLEMENTATION ONLY.
# Compares simclr_loss (tau=0.5) against precomputed expected values for four
# (N, feature_dim) combinations of random embeddings. The fixed seed makes the
# random inputs, and therefore the expected values, reproducible.

seed_everything(33)

expected_results = [torch.tensor(1.7518), torch.tensor(1.6376), torch.tensor(4.194),  torch.tensor(4.1754)]
for i, (N, feature_dim) in enumerate(zip([3, 3, 33, 33], [5, 125, 5, 125])):
  print(f"{N=} and {feature_dim=}")
  embedding_view1 = torch.rand((N, feature_dim))
  embedding_view2 = torch.rand((N, feature_dim))
  # Clones are passed so the check does not depend on whether simclr_loss mutates its inputs.
  loss = simclr_loss(embedding_view1.clone(), embedding_view2.clone(), tau=0.5)
  print(f"Expected loss: {expected_results[i]}, Computed loss: {loss}")
  assert torch.isclose(loss, expected_results[i], rtol=1e-3)
print("Passed all tests successfully !")

Implement a CNN backbone and train it with the SimCLR contrastive loss objective.

In [None]:
class ImageEncoder(torch.nn.Module):
    '''ResNet-50 trunk adapted to single-channel input, returning pooled feature vectors.

    The torchvision ResNet-50 is created without pretrained weights. Its
    classification head (``fc``) is removed, and its first convolution is replaced
    by one that takes a single input channel (grayscale images).
    '''
    def __init__(self) -> None:
        super().__init__()
        backbone = models.resnet50(weights=None)
        del backbone.fc
        backbone.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.net = backbone

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''Run the stem, the four residual stages and global average pooling, then flatten to [B, features].'''
        trunk = self.net
        pipeline = (
            trunk.conv1, trunk.bn1, trunk.relu, trunk.maxpool,       # stem
            trunk.layer1, trunk.layer2, trunk.layer3, trunk.layer4,  # residual stages
            trunk.avgpool,
        )
        for stage in pipeline:
            x = stage(x)
        return torch.flatten(x, start_dim=1)
    
    
class SimCLRModel(LightningModule):
    '''SimCLR pre-training module: ``ImageEncoder`` backbone followed by an MLP projection head.

    :param learning_rate: learning rate for the Adam optimiser.
    :param temperature: temperature ``tau`` passed to ``simclr_loss``. The default
        of 1.0 is the value ``simclr_loss`` uses when ``tau`` is not given.
    '''
    def __init__(self, learning_rate: float = 0.001, temperature: float = 1.0):
        super().__init__()
        self.learning_rate = learning_rate
        self.temperature = temperature

        self.encoder = ImageEncoder()

        # Projection head: maps the 2048-d pooled ResNet-50 features to the 128-d
        # space in which the contrastive loss is computed.
        self.projector = torch.nn.Sequential(
            torch.nn.Linear(2048, 1024),
            torch.nn.ReLU(),
            torch.nn.Linear(1024, 128),
        )

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    def process_batch(self, batch):
        '''Compute the SimCLR loss for a batch of positive pairs.

        Train/eval mode is deliberately not changed here. Lightning puts the module
        in train mode for ``training_step`` and in eval mode for ``validation_step``.
        Forcing ``train()`` would make validation use batch statistics in BatchNorm
        and update its running statistics with validation data.

        :param batch: ``(view1, view2)`` image tensors.
        :returns: scalar loss tensor.
        '''
        view1, view2 = batch
        # Lightning normally moves the batch to the module's device already; this keeps
        # the method usable when it is called directly.
        view1 = view1.to(device=self.device)
        view2 = view2.to(device=self.device)

        # Encode each view, then project the features into the contrastive space.
        embedding_view1 = self.encoder(view1)
        embedding_view2 = self.encoder(view2)
        projection_view1 = self.projector(embedding_view1)
        projection_view2 = self.projector(embedding_view2)

        return simclr_loss(projection_view1, projection_view2, tau=self.temperature)

    def training_step(self, batch, batch_idx):
        loss = self.process_batch(batch)
        self.log('train_loss', loss, prog_bar=True)
        # At the first batch of every epoch, log the first 4 pairs of views as an image
        # grid (top row: view 1, bottom row: view 2). Skipped when there is no logger.
        if batch_idx == 0 and self.logger is not None:
            grid = torchvision.utils.make_grid(torch.cat((batch[0][0:4, ...], batch[1][0:4, ...]), dim=0), nrow=4, normalize=True)
            self.logger.experiment.add_image('train_images', grid, self.global_step)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.process_batch(batch)
        self.log('val_loss', loss, prog_bar=True)

In [None]:
# Pre-train the SimCLR model on PneumoniaMNIST positive pairs.
# The seed is set before the data module and model are built so initialisation is reproducible.
seed_everything(33, workers=True)

data = SimCLRPneumoniaMNISTDataModule(batch_size=32)
model = SimCLRModel()

simclr_logger = TensorBoardLogger(save_dir='./lightning_logs/coursework/', name='simclr')
simclr_callbacks = [
    ModelCheckpoint(monitor='val_loss', mode='min'),  # keep the checkpoint with the lowest validation loss
    TQDMProgressBar(refresh_rate=10),
]

trainer = Trainer(
    max_epochs=5,
    accelerator='auto',
    devices=1,
    logger=simclr_logger,
    callbacks=simclr_callbacks,
)
trainer.fit(model=model, datamodule=data)

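Compare two ways of adapting a pre-trained encoder to the PneumoniaMNIST image classification task: fine-tuning the whole encoder versus linear probing (training only a classifier on top of the frozen encoder).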

In [None]:
class PneumoniaMNISTDataset(MedMNIST):
    def __init__(self, split = 'train', augmentation: bool = False):
        ''' Dataset class for PneumoniaMNIST image classification.
        The init function automatically downloads the necessary files the
        first time the class is initialised.

        :param split: 'train', 'val' or 'test', select subset
        :param augmentation: if True, apply random augmentations suited to
            classification; otherwise images are only converted to tensors.
        '''
        self.flag = "pneumoniamnist"
        self.size = 28
        self.size_flag = ""
        self.root = './data/coursework/'
        self.info = INFO[self.flag]
        self.download()

        npz_file = np.load(os.path.join(self.root, "pneumoniamnist.npz"))

        self.split = split

        # Load all the images
        assert self.split in ['train','val','test']

        self.imgs = npz_file[f'{self.split}_images']
        self.labels = npz_file[f'{self.split}_labels']

        self.do_augment = augmentation

        # SimCLR relies on aggressive augmentations that keep the core semantics despite
        # heavy alterations. For classification, the model should instead focus on
        # discriminative features and be robust to minor real-world variations, without
        # transformations strong enough to change the essential nature of the class.
        # The pipeline therefore mirrors the SimCLR one with milder settings.
        if self.do_augment:
            self.augmentation_pipeline = transforms.Compose([
                transforms.ToTensor(),
                transforms.RandomResizedCrop(size=28,scale=(0.7,1.0)), # larger minimum crop scale (0.7)
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 1.0)), # reduced blurring
                transforms.RandomAffine(degrees=5, translate=(0.1, 0.1)) # reduced rotation
            ])
        else:
            # No augmentation: only convert the image to a float tensor.
            self.augmentation_pipeline = transforms.ToTensor()

    def __len__(self):
        return self.imgs.shape[0]

    def __getitem__(self, index):
        '''Return ``(image, label)`` for the given index.

        ``image`` is the float tensor produced by ``augmentation_pipeline``. ``label``
        is returned as a Python int, so the default collate function builds an int64
        tensor. That is the target dtype ``F.cross_entropy`` requires for class
        indices; a numpy uint8 label would collate to a uint8 tensor.
        '''
        img = self.augmentation_pipeline(self.imgs[index])
        label = int(self.labels[index][0])
        return img, label

In [None]:
class PneumoniaMNISTDataModule(LightningDataModule):
    '''Lightning data module for PneumoniaMNIST classification.

    Only the training split is augmented. Validation and test images are used
    without augmentation. All datasets are built eagerly in ``__init__``.

    :param batch_size: number of samples per batch for all three loaders.
    '''
    def __init__(self, batch_size: int = 32):
        super().__init__()
        self.batch_size = batch_size
        split_settings = (('train', True), ('val', False), ('test', False))
        self.train_set, self.val_set, self.test_set = (
            PneumoniaMNISTDataset(split=split_name, augmentation=use_augmentation)
            for split_name, use_augmentation in split_settings
        )

    def _make_loader(self, dataset, shuffle):
        # Shared DataLoader construction for all three splits.
        return DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=shuffle)

    def train_dataloader(self):
        # Only the training data is shuffled.
        return self._make_loader(self.train_set, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.val_set, shuffle=False)

    def test_dataloader(self):
        return self._make_loader(self.test_set, shuffle=False)

In [None]:
# DO NOT MODIFY THIS CELL! IT IS FOR CHECKING THE IMPLEMENTATION ONLY.
# Shows the first 8 (augmented) training images of one batch with their labels.

# Initialise data module
datamodule = PneumoniaMNISTDataModule()
# Get train dataloader
train_dataloader = datamodule.train_dataloader()
# Get first batch
batch = next(iter(train_dataloader))
# Visualise the images
images, labels = batch
f, ax = plt.subplots(1, 8, figsize=(12,4))
for i in range(8):
  ax[i].imshow(images[i, 0], cmap='gray')
  ax[i].set_title('label: ' + str(labels[i].item()))
  ax[i].axis("off")

Use a pre-trained encoder as the starting point for classification.

In [None]:
def load_encoder_from_checkpoint(checkpoint_path):
    '''Load a SimCLR checkpoint and return its image encoder in eval mode.

    Accepts either a raw ``SimCLRModel`` state dict or a PyTorch Lightning
    checkpoint (as written by ``ModelCheckpoint``), whose weights are stored under
    the ``'state_dict'`` key.

    :param checkpoint_path: path to the checkpoint file.
    :returns: the ``encoder`` of a ``SimCLRModel`` loaded from the checkpoint, in eval mode.
    '''
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    ckpt = torch.load(checkpoint_path, map_location='cpu')
    state_dict = ckpt.get('state_dict', ckpt) if isinstance(ckpt, dict) else ckpt
    simclr_module = SimCLRModel()
    # Prints the missing/unexpected keys so a mismatched checkpoint is visible.
    print(simclr_module.load_state_dict(state_dict=state_dict))
    return simclr_module.encoder.eval()

# NOTE(review): the datasets are stored under './data/coursework/' but these checkpoints
# are read from '../data/coursework/' — confirm this location is intended.
imagenet_model = '../data/coursework/model_imagenet.ckpt'
chestxray_model = '../data/coursework/model_chestxray.ckpt'

In [None]:

from typing import Any


class ImageClassifier(LightningModule):
    '''Linear classification head on top of a (pre-trained) image encoder.

    Used for fine-tuning (``freeze_encoder=False``) and for linear probing
    (``freeze_encoder=True``).

    :param pretrained_encoder: module mapping images to feature vectors of size ``feature_dim``.
    :param freeze_encoder: if True, encoder parameters are not trained, and the encoder
        is kept in eval mode so its BatchNorm running statistics are not updated.
    :param output_dim: number of classes.
    :param learning_rate: learning rate for the Adam optimiser.
    :param feature_dim: size of the encoder output (2048 for the ResNet-50 ``ImageEncoder``).
    '''
    def __init__(self, pretrained_encoder: torch.nn.Module, freeze_encoder: bool = True, output_dim: int = 2, learning_rate: float = 0.001, feature_dim: int = 2048):
        super().__init__()
        # The encoder is an nn.Module and is already saved in the state_dict, so it is
        # excluded from the hyperparameters.
        self.save_hyperparameters(ignore=['pretrained_encoder'])

        self.encoder = pretrained_encoder
        self.learning_rate = learning_rate
        self.freeze_encoder = freeze_encoder

        if freeze_encoder:
            for param in self.encoder.parameters():
                param.requires_grad = False
            self.encoder.eval()

        self.classifier = nn.Linear(feature_dim, output_dim)

        # Per-stage buffers of (scores, targets) used to compute AUROC over a whole epoch.
        # Per-batch AUROC is undefined for single-class batches, and its average is not the
        # epoch AUROC.
        self._eval_outputs = {'val': [], 'test': []}

    def forward(self, x):
        '''Return unnormalised class logits of shape [B, output_dim].

        ``F.cross_entropy`` applies log-softmax internally, so no sigmoid/softmax is
        applied here.
        '''
        features = self.encoder(x)
        return self.classifier(features)

    def train(self, mode: bool = True):
        '''Set train/eval mode, keeping a frozen encoder in eval mode.'''
        super().train(mode)
        if getattr(self, 'freeze_encoder', False):
            self.encoder.eval()
        return self

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        # cross_entropy needs int64 class indices.
        loss = F.cross_entropy(logits, y.long())
        self.log('train_loss', loss)
        return loss

    def _eval_step(self, batch, stage: str):
        '''Shared validation/test step: log the loss and buffer scores for epoch-level AUROC.'''
        x, y = batch
        y = y.long()
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log(f'{stage}_loss', loss, prog_bar=True)

        probs = F.softmax(logits, dim=1)
        # Binary AUROC uses the score of the positive class only.
        scores = probs[:, 1] if self.classifier.out_features == 2 else probs
        self._eval_outputs[stage].append((scores.detach(), y.detach()))
        return loss

    def _log_epoch_auroc(self, stage: str):
        '''Compute AUROC over all buffered predictions of the epoch, log it and clear the buffer.'''
        outputs = self._eval_outputs[stage]
        if not outputs:
            return
        scores = torch.cat([s for s, _ in outputs])
        targets = torch.cat([t for _, t in outputs])
        outputs.clear()

        num_classes = self.classifier.out_features
        if num_classes == 2:
            roc_auc = auroc(scores, targets, task="binary")
        else:
            roc_auc = auroc(scores, targets, task="multiclass", num_classes=num_classes, average='macro')
        self.log(f'{stage}_roc_auc', roc_auc, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        return self._eval_step(batch, 'val')

    def on_validation_epoch_end(self):
        self._log_epoch_auroc('val')

    def test_step(self, batch, batch_idx):
        return self._eval_step(batch, 'test')

    def on_test_epoch_end(self):
        self._log_epoch_auroc('test')

    def configure_optimizers(self):
        # Only optimise parameters that are being trained (excludes a frozen encoder).
        trainable = [p for p in self.parameters() if p.requires_grad]
        return torch.optim.Adam(trainable, lr=self.hparams.learning_rate)

Start by fine-tuning the encoder.

In [None]:
# Fine-tune the ImageNet-pretrained encoder end-to-end for PneumoniaMNIST classification.
seed_everything(33, workers=True)

data = PneumoniaMNISTDataModule(batch_size=32)

# TASK: Implement the model finetuning training and testing routines.

# IMAGENET
imagenet_encoder = load_encoder_from_checkpoint(imagenet_model)
finetuned_imagenet_model = ImageClassifier(pretrained_encoder=imagenet_encoder, freeze_encoder=False, output_dim=2, learning_rate=0.001)
trainer = Trainer(
    max_epochs=25,
    accelerator='auto',
    devices=1,
    # Use a dedicated run name so fine-tuning runs are not interleaved with the SimCLR
    # pre-training runs logged under 'simclr'.
    logger=TensorBoardLogger(save_dir='./lightning_logs/coursework/', name='finetune_imagenet'),
    callbacks=[ModelCheckpoint(monitor='val_loss', mode='min'), TQDMProgressBar(refresh_rate=10)],
)
trainer.fit(model=finetuned_imagenet_model, datamodule=data)