In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader
from torchvision import models
from torchvision import transforms
from pytorch_lightning import LightningModule, LightningDataModule, Trainer, seed_everything
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar
from torchmetrics.functional import auroc
from PIL import Image
from medmnist.info import INFO
from medmnist.dataset import MedMNIST

We will be using the [MedMNIST Pneumonia](https://medmnist.com/) dataset, a medical-imaging dataset with MNIST-like characteristics. Its small image size allows efficient experimentation. The dataset contains real chest X-ray images downsampled to 28 x 28 pixels, with binary labels indicating the presence of [Pneumonia](https://www.nhs.uk/conditions/pneumonia/) (an inflammation of the lungs).

In [None]:
class SimCLRPneumoniaMNISTDataset(MedMNIST):
    def __init__(self, split = 'train', positive_pair = True):
        ''' Dataset class for PneumoniaMNIST that can serve SimCLR positive pairs.

        The provided init function will automatically download the necessary
        files at the first class initialisation.

        :param split: 'train', 'val' or 'test', select subset
        :param positive_pair: if True (default), ``__getitem__`` returns two
            independently augmented views of the indexed image; if False, it
            returns the un-augmented image array together with its label.

        '''
        # Validate the requested split first so that a typo fails immediately,
        # before any download or file I/O is attempted.
        assert split in ['train','val','test'], \
            f"split must be 'train', 'val' or 'test', got {split!r}"

        # NOTE: the parent ``__init__`` is not invoked here; the attributes
        # below are set by hand (presumably the ones ``download()`` relies on
        # -- verify against the installed medmnist version).
        self.flag = "pneumoniamnist"
        self.size = 28
        self.size_flag = ""
        self.root = './data/coursework/'
        self.info = INFO[self.flag]
        self.split = split
        self.positive_pair = positive_pair
        self.download()

        # ``np.load`` on an .npz archive keeps the underlying file open until it
        # is closed. Indexing the archive materialises the arrays in memory, so
        # the handle can be released as soon as both arrays have been read.
        with np.load(os.path.join(self.root, "pneumoniamnist.npz")) as npz_file:
            self.imgs = npz_file[f'{self.split}_images']
            self.labels = npz_file[f'{self.split}_labels']

        # ---- Augmentation pipeline: design rationale -------------------------
        # For self-supervised learning we need to create two views of the same
        # image, and these should be hard to match, i.e. produced with a variety
        # of relevant augmentations.
        #
        # The original SimCLR paper uses RandomCrop with Resize, Random
        # Horizontal Flip, Random Color Distortion and Random Gaussian Blur.
        # As MedMNIST is grayscale, we don't need Random Color Distortion.
        #
        # From inspecting a sample of the training images, we made the following
        # observations, which drive the design of the pipeline:
        # 1. The images all have the heart on the same side (this can lead to
        #    overfitting) -> random horizontal flips account for left/right
        #    orientation variation.
        # 2. The rotation / angle of the sternum is not always the same (in
        #    reality patient positioning in front of the x-ray can vary)
        #    -> random rotations account for this.
        # 3. A few images are more zoomed in than others -> random resized
        #    cropping creates more images with different zoom levels.
        # 4. From online research, pneumonia can be in one or both lungs and it
        #    can be in patches or spread out (lobar or bronchopneumonia)
        #    -> random crop / resize also accounts for this variation.
        # 5. The blurriness of the images varies -> gaussian blur accounts for
        #    this variation.
        #
        # We started with the default random resized crop scale of (0.08, 1) but
        # found this was generating views that were too small (basically the
        # whole view was a single colour), hence the lower bound is raised to 0.2.
        self.augmentation_pipeline = transforms.Compose([
            transforms.ToTensor(),
            transforms.RandomResizedCrop(size=28,scale=(0.2,1.0)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
            transforms.RandomAffine(degrees=10, translate=(0.1, 0.1))
        ])


    def __len__(self):
        ''' Number of images in the selected split. '''
        return self.imgs.shape[0]

    def __getitem__(self, index):
        ''' Fetch one sample.

        :param index: position of the image within the selected split
        :return: ``(img_view1, img_view2)``, two independently augmented views
            of the same image (expected shape [1, 28, 28] each), when
            ``positive_pair`` is True; otherwise ``(image, label)`` exactly as
            stored in the npz arrays, with no augmentation or tensor conversion.
        '''
        if not self.positive_pair:
            # Plain mode: hand back the stored image and its label untouched.
            return self.imgs[index], self.labels[index]

        img = self.imgs[index]

        # The pipeline is random, so two calls on the same image give two
        # different views -- this is what makes them a positive pair.
        img1 = self.augmentation_pipeline(img)
        img2 = self.augmentation_pipeline(img)

        return img1, img2



In [None]:
class SimCLRPneumoniaMNISTDataModule(LightningDataModule):
    """Lightning data module exposing the train / val / test splits of
    ``SimCLRPneumoniaMNISTDataset`` through ``DataLoader`` objects.

    All three datasets are created eagerly in the constructor, so the
    ``*_dataloader`` methods can be called right after instantiation.
    """

    def __init__(self, batch_size: int = 8):
        """
        :param batch_size: number of samples per batch, shared by all loaders
        """
        super().__init__()
        self.batch_size = batch_size
        # Build the splits in a fixed order: train, then val, then test.
        self.train_set, self.val_set, self.test_set = (
            SimCLRPneumoniaMNISTDataset(split=split_name)
            for split_name in ('train', 'val', 'test')
        )

    def _build_loader(self, subset, reshuffle):
        """Wrap ``subset`` in a DataLoader using the module-wide batch size."""
        return DataLoader(dataset=subset, batch_size=self.batch_size, shuffle=reshuffle)

    def train_dataloader(self):
        """Loader over the training split; sample order is shuffled."""
        return self._build_loader(self.train_set, reshuffle=True)

    def val_dataloader(self):
        """Loader over the validation split; sample order is fixed."""
        return self._build_loader(self.val_set, reshuffle=False)

    def test_dataloader(self):
        """Loader over the test split; sample order is fixed."""
        return self._build_loader(self.test_set, reshuffle=False)

#### **Check** dataset implementation.

In [None]:

# Initialise data module
datamodule = SimCLRPneumoniaMNISTDataModule()
# Get train dataloader
train_dataloader = datamodule.train_dataloader()
# Get first batch: a pair of tensors holding the two augmented views
batch = next(iter(train_dataloader))
view1, view2 = batch

# Visualise the images.
# Show at most 8 pairs, but never more than the batch actually contains, so the
# cell keeps working if the data module's batch size is changed.
n_show = min(8, view1.shape[0])
# squeeze=False keeps `ax` two-dimensional even when a single column is drawn.
f, ax = plt.subplots(2, n_show, figsize=(1.5 * n_show, 4), squeeze=False)
for i in range(n_show):
    # Top row: first view; bottom row: second view of the same source image.
    for row, (views, title) in enumerate(((view1, 'view 1'), (view2, 'view 2'))):
        ax[row, i].imshow(views[i, 0], cmap='gray')
        ax[row, i].set_title(title)
        ax[row, i].axis("off")