# Training a classification model on MNIST with PyTorch Lightning

In [1]:
%%capture
!pip install pytorch-lightning

In [7]:
import torch
from torch import nn
from torch import optim
from torchvision import datasets, transforms
from torch.utils.data import random_split, DataLoader
import pytorch_lightning as pl
import torchmetrics

In [8]:
train_accuracy = torchmetrics.Accuracy()
valid_accuracy = torchmetrics.Accuracy(compute_on_step=False)

class ResNet(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(28 * 28, 64)
        self.l2 = nn.Linear(64, 64)
        self.l3 = nn.Linear(64, 10)
        self.do = nn.Dropout(0.1)

        self.loss = nn.CrossEntropyLoss()

    def forward(self, x):
        h1 = nn.functional.relu(self.l1(x))
        h2 = nn.functional.relu(self.l2(h1))
        do = self.do(h2 + h1)
        logits = self.l3(do)
        return logits

    def configure_optimizers(self):
        optimizer = optim.SGD(self.parameters(), lr=1e-2)
        return optimizer

    def training_step(self, batch, batch_idx):
        x, y = batch

        # x: b x 1 x 28 x 28 (b & w image)
        b = x.size(0)
        x = x.view(b, -1)

        # 1 forward
        l = self(x) # l: logits
        # import pdb; pdb.set_trace() # debugging
        J = self.loss(l, y)
        acc = train_accuracy(l, y)
        pbar = {'train_acc': acc}
        return {'loss': J, 'progress_bar': pbar}

    def validation_step(self, batch, batch_idx):
        results = self.training_step(batch, batch_idx)
        results['progress_bar']['val_acc'] = results['progress_bar']['train_acc']
        del results['progress_bar']['train_acc']
        return results

    def validation_epoch_end(self, val_step_outputs):
        avg_val_loss = torch.tensor([x['loss'] for x in val_step_outputs]).mean()
        avg_val_acc = torch.tensor([x['progress_bar']['val_acc'] for x in val_step_outputs]).mean()

        pbar = {'avg_val_acc': avg_val_acc}
        return {'val_loss': avg_val_loss, 'progress_bar': pbar}

In [9]:
dataset = datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor())
mnist_train, mnist_val = random_split(dataset, [55000, 5000])
train_loader = DataLoader(mnist_train, batch_size=32, num_workers=4)
val_loader = DataLoader(mnist_val, batch_size=32, num_workers=4)

# model
model = ResNet()

In [10]:
trainer = pl.Trainer(progress_bar_refresh_rate=20, max_epochs=10)
trainer.fit(model, train_loader, val_loader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name | Type             | Params
------------------------------------------
0 | l1   | Linear           | 50.2 K
1 | l2   | Linear           | 4.2 K 
2 | l3   | Linear           | 650   
3 | do   | Dropout          | 0     
4 | loss | CrossEntropyLoss | 0     
------------------------------------------
55.1 K    Trainable params
0         Non-trainable params
55.1 K    Total params
0.220     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]