In [4]:
import lightning as L
import torchmetrics
import torch
from torch import nn
from torch.nn import functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.utils.data.dataset import random_split
from torchvision.datasets import MNIST
from collections import OrderedDict

torch.manual_seed(1)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# train_dataset, val_dataset = random_split(train_dataset, [55000, 5000])
train_dataset = MNIST(root="../../data", download=True, transform=transforms.ToTensor(), train=True)
test_dataset = MNIST(root="../../data", transform=transforms.ToTensor(), train=False)

train_dataset, val_dataset = random_split(train_dataset, [55000, 5000])

train_loader = DataLoader(
  dataset=train_dataset,
  batch_size=128,
  shuffle=True,
  num_workers=11,
)

val_loader = DataLoader(
  dataset=val_dataset,
  batch_size=128,
  num_workers=11,
)

test_loader = DataLoader(
  dataset=test_dataset,
  batch_size=128,
  num_workers=0,
)

class MnistClassifierModel(L.LightningModule):
    def __init__(self):
        super().__init__()

        self.seq = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(1, 32, 3, 1)),
            ('relu1', nn.ReLU()),
            ('maxpool1', nn.MaxPool2d(2)),
            ('conv2', nn.Conv2d(32, 64, 3, 1)),
            ('relu2', nn.ReLU()),
            ('maxpool2', nn.MaxPool2d(2)),
            ('flatten', nn.Flatten()),
            ('lin1', nn.Linear(1600, 512)),
            ('relu3', nn.ReLU()),
            ('dropout2', nn.Dropout(0.5)),
            ('lin2', nn.Linear(512, 10)),
        ]))

        self.train_acc = torchmetrics.Accuracy(task='multiclass', num_classes=10).to(device)
        self.val_acc = torchmetrics.Accuracy(task='multiclass', num_classes=10).to(device)
        self.test_acc = torchmetrics.Accuracy(task='multiclass', num_classes=10).to(device)

    def forward(self, x):
        return self.seq(x)

    
    def _shared_step(self, batch):
        features, true_labels = batch
        logits = self(features)
        loss = F.cross_entropy(logits, true_labels)
        predictions = torch.argmax(logits, dim=1)
        return loss, true_labels, predictions
    
    def training_step(self, batch, batch_idx: int):
        loss, true_labels, predictions = self._shared_step(batch)

        self.log('train_loss', loss)

        # Track accuracy
        self.train_acc(predictions, true_labels)
        self.log('train_acc', self.train_acc, prog_bar=True, on_epoch=True, on_step=False)

        return loss # this is passed to the optimizer for training
    
    def validation_step(self, batch, batch_idx: int):
        loss, true_labels, predictions = self._shared_step(batch)

        self.log('val_loss', loss, prog_bar=True)

        # Track accuracy
        self.val_acc(predictions, true_labels)
        self.log('val_acc', self.val_acc, prog_bar=True)

    def test_step(self, batch, batch_idx: int):
        loss, true_labels, predictions = self._shared_step(batch)

        # Track accuracy
        self.test_acc(predictions, true_labels)
        self.log('accuracy', self.test_acc, prog_bar=True)
    
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters())
        return optimizer


In [5]:
model = MnistClassifierModel().to(device)

print(model)

trainer = L.Trainer(
  max_epochs=3,
  accelerator='gpu',
  devices='auto' # Use all available GPUs if applicable
)

trainer.fit(
  model,
  train_dataloaders=train_loader,
  val_dataloaders=val_loader
)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type               | Params | Mode 
---------------------------------------------------------
0 | seq       | Sequential         | 843 K  | train
1 | train_acc | MulticlassAccuracy | 0      | train
2 | val_acc   | MulticlassAccuracy | 0      | train
3 | test_acc  | MulticlassAccuracy | 0      | train
---------------------------------------------------------
843 K     Trainable params
0         Non-trainable params
843 K     Total params
3.375     Total estimated model params size (MB)


MnistClassifierModel(
  (seq): Sequential(
    (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
    (relu1): ReLU()
    (maxpool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (relu2): ReLU()
    (maxpool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (lin1): Linear(in_features=1600, out_features=512, bias=True)
    (relu3): ReLU()
    (dropout2): Dropout(p=0.5, inplace=False)
    (lin2): Linear(in_features=512, out_features=10, bias=True)
  )
  (train_acc): MulticlassAccuracy()
  (val_acc): MulticlassAccuracy()
  (test_acc): MulticlassAccuracy()
)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=3` reached.


In [6]:
model.eval()
@model.register_forward_hook
def hook(module, input, output):
  print(output.shape)

trainer.test(model, test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size

[{'accuracy': 0.9886000156402588}]