<a href="https://colab.research.google.com/github/Hanjie-Xu/ML_for_Imaging/blob/main/tutorial_1_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Tutorial 1 - Classification

In this tutorial, we first revisit a logistic-regression-like approach to the MNIST classification task, this time using PyTorch and PyTorch Lightning for implementation, training, and testing. The task is then to replace the single-layer neural network with a convolutional neural network such as LeNet. Finally, we compare the performance of both approaches on an out-of-distribution test set in which the digits have been slightly shifted.

**Objective:** Use PyTorch and PyTorch Lightning for the development of deep neural networks for image classification.

In [None]:
# On Google Colab uncomment the following line to install PyTorch Lightning
# ! pip install lightning

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import matplotlib
import matplotlib.pyplot as plt

from torch.utils.data import random_split, DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from pytorch_lightning import LightningModule, LightningDataModule, Trainer, seed_everything
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar
from torchmetrics.functional import accuracy

## Data

We use a [LightningDataModule](https://lightning.ai/docs/pytorch/stable/data/datamodule.html) for handling the MNIST dataset.

In [None]:
class MNISTDataModule(LightningDataModule):
    def __init__(self, data_dir: str = './data', batch_size: int = 32, num_workers: int = 4, transform = transforms.ToTensor()):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.transform = transform

        self.test_set = MNIST(self.data_dir, train=False, transform=self.transform, download=True)
        dev_set = MNIST(self.data_dir, train=True, transform=self.transform, download=True)
        self.train_set, self.val_set = random_split(dev_set, [55000, 5000])

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, persistent_workers=True)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, persistent_workers=True)

    def test_dataloader(self):
        return DataLoader(self.test_set, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, persistent_workers=True)
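
The data module above splits the development set with `random_split`, which draws from the global random state. As a minimal sketch, assuming you want the train/validation partition to be reproducible regardless of when the global seed is set, you can pass an explicit `torch.Generator`. The synthetic `TensorDataset` and the helper `split` are hypothetical stand-ins so that the snippet runs without downloading MNIST.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Synthetic stand-in for MNIST (assumption: 100 grayscale 28x28 images, 10 classes).
images = torch.rand(100, 1, 28, 28)
labels = torch.randint(0, 10, (100,))
dev_set = TensorDataset(images, labels)

def split(seed: int):
    # An explicit generator makes the partition independent of the global RNG state.
    generator = torch.Generator().manual_seed(seed)
    return random_split(dev_set, [90, 10], generator=generator)

train_a, val_a = split(42)
train_b, val_b = split(42)
same_split = train_a.indices == train_b.indices  # same seed, same index lists

# Batches come out as (batch, channels, height, width) images and (batch,) labels.
loader = DataLoader(train_a, batch_size=32, shuffle=False)
x, y = next(iter(loader))
batch_shapes = (tuple(x.shape), tuple(y.shape))
print(same_split, batch_shapes)
```

The same `generator=` argument could be added to the `random_split` call in `MNISTDataModule` if a fixed split is desired.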

## Model

We use a [LightningModule](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html) for implementing the model and its training and testing steps.

**Task:** Replace the single-layer network with a LeNet-like CNN architecture. You need to add the model layers and change the forward pass accordingly.
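
As a starting point for this task, here is one possible sketch of a LeNet-like network for 1×28×28 inputs. It is not the reference solution: the class name `LeNetLike` is hypothetical, and ReLU activations with max pooling are a common modern substitution for the original LeNet design. The layers and forward pass would still need to be moved into the `LightningModule` below.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNetLike(nn.Module):
    """Hypothetical LeNet-style network for 1x28x28 inputs."""

    def __init__(self, output_dim: int = 10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, padding=2)  # 1x28x28 -> 6x28x28
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)            # 6x14x14 -> 16x10x10
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, output_dim)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)  # -> 6x14x14
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)  # -> 16x5x5
        x = x.view(x.size(0), -1)                   # flatten to (batch, 400)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)                          # raw logits, one per class

net = LeNetLike()
logits = net(torch.rand(4, 1, 28, 28))  # random batch of 4 images
print(logits.shape)
```

Because the network returns raw logits, it can be combined with `F.cross_entropy` in the same way as the single-layer model.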

In [None]:
class ImageClassifier(LightningModule):
    def __init__(self, input_dim: tuple[int, int] = (28,28), output_dim: int = 10, learning_rate: float = 0.001):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.learning_rate = learning_rate

        # single-layer network (logistic regression-like)
        # nn.Linear expects a plain int for in_features, so convert the product explicitly
        self.fc_layer = nn.Linear(int(np.prod(self.input_dim)), self.output_dim)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        return self.fc_layer(x)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    def process_batch(self, batch):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)

        probs = torch.softmax(logits, dim=1)
        preds = torch.argmax(probs, dim=1)
        acc = accuracy(preds, y, task='multiclass', num_classes=self.output_dim)

        return loss, acc

    def training_step(self, batch, batch_idx):
        loss, acc = self.process_batch(batch)
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_acc', acc, prog_bar=True)
        if batch_idx == 0:
            grid = torchvision.utils.make_grid(batch[0][0:16, ...], nrow=4, normalize=True)
            self.logger.experiment.add_image('train_images', grid, self.global_step)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc = self.process_batch(batch)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        loss, acc = self.process_batch(batch)
        self.log('test_loss', loss)
        self.log('test_acc', acc)
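
`process_batch` passes raw logits to `F.cross_entropy`, which applies log-softmax internally, and derives predictions via softmax followed by arg max. The sketch below uses made-up logits to illustrate two points: softmax does not change the arg max, and cross-entropy on logits equals the negative log-likelihood on log-softmax outputs. The accuracy is computed by hand here rather than with `torchmetrics`.

```python
import torch
import torch.nn.functional as F

# Made-up logits for a batch of 3 samples and 4 classes (illustrative values only).
logits = torch.tensor([[2.0, 0.5, -1.0, 0.1],
                       [0.2, 0.1, 3.0, -0.5],
                       [-1.0, 1.5, 0.0, 0.3]])
targets = torch.tensor([0, 2, 3])

# Softmax is monotonic within each row, so the predicted class is unchanged by it.
preds_from_logits = torch.argmax(logits, dim=1)
preds_from_probs = torch.argmax(torch.softmax(logits, dim=1), dim=1)

# cross_entropy expects logits; it matches NLL applied to log-softmax outputs.
loss = F.cross_entropy(logits, targets)
loss_manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)

# Accuracy as the fraction of correct predictions in the batch.
acc = (preds_from_logits == targets).float().mean()
print(preds_from_logits.tolist(), loss.item(), acc.item())
```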

## Training

We use the PyTorch Lightning [Trainer](https://lightning.ai/docs/pytorch/stable/common/trainer.html) for training and testing.

In [None]:
seed_everything(42, workers=True)

data = MNISTDataModule(data_dir='./data', batch_size=32)

model = ImageClassifier(input_dim=(28,28), output_dim=10, learning_rate=0.001)

trainer = Trainer(
    max_epochs=10,
    accelerator='auto',
    devices=1,
    logger=TensorBoardLogger(save_dir='./lightning_logs/classification/', name='mnist-logreg'),
    callbacks=[ModelCheckpoint(monitor='val_loss', mode='min'), TQDMProgressBar(refresh_rate=10)],
)
trainer.fit(model=model, datamodule=data)

## Validation

Evaluate the trained model with the best checkpoint on the validation data and report the classification accuracy.

In [None]:
trainer.validate(model=model, datamodule=data, ckpt_path=trainer.checkpoint_callback.best_model_path)

## Testing

Evaluate the trained model with the best checkpoint on the test data and report the classification accuracy.

In [None]:
trainer.test(model=model, datamodule=data, ckpt_path=trainer.checkpoint_callback.best_model_path)

## Testing on out-of-distribution data

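Here we prepare an out-of-distribution test set in which all digits are slightly shifted: each image is center-cropped to 26×26 pixels and then padded by two pixels on the right and bottom to restore the 28×28 size, which moves the digit one pixel up and to the left. We then test the trained model and report the classification accuracy.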

In [None]:
transform = transforms.Compose([
    transforms.CenterCrop((26,26)),
    transforms.Pad((0,0,2,2)),
    transforms.ToTensor()
])

data_ood = MNISTDataModule(data_dir='./data', batch_size=32, transform=transform)

trainer.test(model=model, datamodule=data_ood, ckpt_path=trainer.checkpoint_callback.best_model_path)

## Visualisation

### Test data examples

Let's compare the original (in-distribution) test data with the out-of-distribution test data.

In [None]:
samples_iid = [data.test_set[i][0] for i in range(0,16)]
samples_ood = [data_ood.test_set[i][0] for i in range(0,16)]
grid_iid = torchvision.utils.make_grid(samples_iid, nrow=4, normalize=True).numpy()[0,...].squeeze()
grid_ood = torchvision.utils.make_grid(samples_ood, nrow=4, normalize=True).numpy()[0,...].squeeze()

f, ax = plt.subplots(1,3, figsize=(15, 15))

ax[0].imshow(grid_iid, cmap=matplotlib.cm.gray)
ax[0].axis('off')
ax[0].set_title('in-distribution')

ax[1].imshow(grid_ood, cmap=matplotlib.cm.gray)
ax[1].axis('off')
ax[1].set_title('out-of-distribution')

ax[2].imshow(grid_iid - grid_ood, cmap=matplotlib.cm.bwr)
ax[2].axis('off')
ax[2].set_title('difference')

### Weights

Let's inspect the patterns that the single-layer network has learned for each digit class.

In [None]:
weights = model.fc_layer.weight.detach().cpu().numpy()

f, ax = plt.subplots(1,10, figsize=(15, 15))

for i in range(0,10):
    vmax = np.max(np.abs(weights[i,:])) / 2
    ax[i].imshow(weights[i,:].reshape(28,28), cmap=matplotlib.cm.bwr, clim=(-vmax,vmax))
    ax[i].axis('off')
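
The reason these weight images can be read as per-class patterns is that each row of the weight matrix acts as a 28×28 template: the logit for class `i` is the pixel-wise product of template `i` with the image, summed over all pixels, plus a bias. A minimal sketch with a freshly initialised layer and a random stand-in image (both assumptions, not the trained model) confirms that this matches the layer's output.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
fc_layer = nn.Linear(28 * 28, 10)   # same shape as the single-layer model: 784 inputs, 10 classes
x = torch.rand(1, 1, 28, 28)        # random stand-in for one MNIST image

weights = fc_layer.weight.detach()  # shape (10, 784): one row per digit class
templates = weights.reshape(10, 28, 28)

# Logit for class i = sum over pixels of (template_i * image) + bias_i.
manual_logits = (templates * x[0, 0]).sum(dim=(1, 2)) + fc_layer.bias.detach()
layer_logits = fc_layer(x.view(1, -1)).detach()[0]

n_params = sum(p.numel() for p in fc_layer.parameters())  # 784 * 10 weights + 10 biases
print(tuple(templates.shape), n_params, torch.allclose(manual_logits, layer_logits, atol=1e-5))
```

Since every template is tied to fixed pixel positions, even a small shift of the digit changes which weights it overlaps with.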

## Logging

In [None]:
%load_ext tensorboard
%tensorboard --logdir './lightning_logs/classification/'