In [1]:
import pytorch_lightning as pl
from typing import Any
import torchmetrics
import torch

class CNN(pl.LightningModule):

    def __init__(self, loss: callable, lr: float, classification_head:torch.nn.Module ,num_classes:int =10) -> None:
        super().__init__()
        self.conv_layers = torch.nn.Sequential(
            torch.nn.Conv2d(1, 32, 3), 
            torch.nn.BatchNorm2d(32),
            torch.nn.ReLU(),
            torch.nn.Conv2d(32, 48, 3),
            torch.nn.BatchNorm2d(48),
            torch.nn.ReLU(),
            torch.nn.Conv2d(48, 64, 3),
            torch.nn.BatchNorm2d(64),
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 80, 3),
            torch.nn.BatchNorm2d(80),
            torch.nn.ReLU(),
            torch.nn.Conv2d(80, 96, 3),
            torch.nn.BatchNorm2d(96),
            torch.nn.ReLU(),
            torch.nn.Conv2d(96, 112, 3),
            torch.nn.BatchNorm2d(112),
            torch.nn.ReLU(),
            torch.nn.Conv2d(112, 128, 3),
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),
            torch.nn.Conv2d(128, 144, 3),
            torch.nn.BatchNorm2d(144),
            torch.nn.ReLU(),
            torch.nn.Conv2d(144, 160, 3),
            torch.nn.BatchNorm2d(160),
            torch.nn.ReLU(),
            torch.nn.Conv2d(160, 176, 3),
            torch.nn.BatchNorm2d(176),
            torch.nn.ReLU(),
        )
        self.classification_head = classification_head 
        self.num_classes = num_classes
        self.loss = loss
        self.lr = lr
        
        self.test_auroc = torchmetrics.AUROC(num_classes=self.num_classes)
        self.test_acc = torchmetrics.Accuracy(num_classes=self.num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv_layers(x)
        x = self.classification_head(x)
        return x

    def _step(self, batch) -> torch.Tensor:
        x, y = batch
        pred = self.forward(x)
        loss = self.loss(pred, y)
        return pred, loss

    def training_step(self, batch) -> torch.Tensor:
        pred, loss = self._step(batch)
        self.log("train/loss", loss)
        pred = torch.nn.functional.softmax(pred, dim=1)
        acc = torchmetrics.functional.accuracy(pred, batch[-1], num_classes=self.num_classes)
        self.log("train/acc", acc)
        return loss
    
    def _eval_step(self, batch, auroc, acc):
        pred, loss = self._step(batch)
        pred = torch.nn.functional.softmax(pred, dim=1)
        auroc.update(pred, batch[-1])
        acc.update(pred, batch[-1])
        return loss
        
    def test_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
        loss = self._eval_step(batch, self.test_auroc, self.test_acc)
        self.log("test", loss)

    def test_epoch_end(self, outputs) -> None:
        print(f"Test AUROC: {self.test_auroc.compute().data}")
        print(f"Test Accuracy: {self.test_acc.compute().data}")

    def validation_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
        loss = self._eval_step(batch, self.test_auroc, self.test_acc)
        self.log("test", loss)

    def validation_epoch_end(self, outputs) -> None:
        print(f"Test AUROC: {self.test_auroc.compute().data}")
        print(f"Test Accuracy: {self.test_acc.compute().data}")

        
    def configure_optimizers(self) -> Any:
        optim = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optim

In [2]:
# load MNIST data
import torchvision
from torchvision.datasets import MNIST

transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),]
)

BATCH_SIZE = 2048
train_data = MNIST(
    root="data", 
    download=True, 
    train=True, 
    transform=transform)

test_data = MNIST(
    root="data", 
    download=True, 
    train=False, 
    transform=transform)

dl_train = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
dl_test = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw



In [3]:
loss = torch.nn.CrossEntropyLoss()
lr = 1e-3
num_classes = 10

pool_head = torch.nn.Sequential(
    torch.nn.AdaptiveAvgPool2d((1, 1)),
    torch.nn.Flatten(),
    torch.nn.Linear(176, num_classes),
)

flatten_head = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(176 * 8 * 8, 100),
    torch.nn.Linear(100, num_classes),
)

pool_model = CNN(loss, lr, pool_head, num_classes)
flatten_model = CNN(loss, lr, flatten_head, num_classes)



In [4]:
from pytorch_lightning.callbacks import EarlyStopping

EPOCHS = 15
trainer_pool = pl.Trainer(accelerator="gpu", devices=1, max_epochs=EPOCHS, log_every_n_steps=10)
# trainer.fit(pool_model, dl_test)
trainer_pool.fit(pool_model, dl_train)
trainer_pool.test(pool_model, dl_test)

trainer_flatten = pl.Trainer(accelerator="gpu", devices=1, max_epochs=EPOCHS, log_every_n_steps=10)
# trainer.fit(pool_model, dl_test)
trainer_flatten.fit(flatten_model, dl_train)
trainer_pool.test(flatten_model, dl_test)



Training: 0it [00:00, ?it/s]

Testing: 0it [00:00, ?it/s]

Test AUROC: 0.9999639391899109
Test Accuracy: 0.9944999814033508




Training: 0it [00:00, ?it/s]

Testing: 0it [00:00, ?it/s]

Test AUROC: 0.999951183795929
Test Accuracy: 0.9905999898910522


[{'test': 0.03099948726594448}]

In [2]:
%load_ext tensorboard
%tensorboard --logdir lightning_logs

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 17949), started 0:00:16 ago. (Use '!kill 17949' to kill it.)

# Part b findings

With my models there is not really a difference in performance between the two approaches. This is probably because the Image Size is rather small and it is easy for the flatten approach to learn image structure.

However the performance with the global average pooling is around 0.5% better.

If we look at the tensorboard logs we can also see that the global average pooling method seems to be converging a little bit faster.