In [None]:
# Install the notebook's dependencies into the kernel's environment.
# pytorch-lightning and lightning-bolts are pinned to exact versions.
%pip install "pytorch-lightning==1.6.3"
%pip install "lightning-bolts==0.5.0"
# NOTE(review): torchvision is not pinned, so the installed version depends on
# when this cell is run -- consider pinning it for reproducibility.
%pip install torchvision

In [None]:
# Import Torch.
import torch
from torch.nn import functional as F

# Import Lightning and the MNIST data module from Lightning Bolts.
# https://lightning-bolts.readthedocs.io/en/latest/
import pytorch_lightning as pl
from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule

In [None]:
# Subclass LightningModule and override its training, validation, test and optimizer hooks.
class LitClassifier(pl.LightningModule):
    """Two-layer fully connected classifier trained with cross-entropy.

    Each input is flattened to ``28 * 28`` features, passed through one hidden
    layer with a ReLU activation, and projected to 10 raw class scores.

    Args:
        hidden_dim: Width of the hidden layer.
        learning_rate: Learning rate handed to the Adam optimizer.
    """

    def __init__(self, hidden_dim: int = 128, learning_rate: float = 0.0001):
        super().__init__()
        # Makes the constructor arguments available as ``self.hparams.<name>``.
        self.save_hyperparameters()

        self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
        self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)

    def forward(self, x):
        """Return raw (unnormalized) class logits of shape ``(batch, 10)``.

        Args:
            x: Input batch; everything after the first dimension is flattened
                and must total ``28 * 28`` elements per sample.
        """
        x = x.view(x.size(0), -1)
        x = torch.relu(self.l1(x))
        # No activation on the output layer: ``F.cross_entropy`` expects
        # unnormalized logits, and a ReLU here would clamp them to >= 0 and
        # zero the gradient for every negative logit.
        return self.l2(x)

    def training_step(self, batch, batch_idx):
        """Return the cross-entropy loss for one ``(inputs, targets)`` batch."""
        x, y = batch
        logits = self(x)
        return F.cross_entropy(logits, y)

    def _batch_accuracy(self, batch):
        """Run the model on one ``(inputs, targets)`` batch and return its accuracy."""
        x, y = batch
        return self.accuracy(self(x), y)

    def validation_step(self, batch, batch_idx):
        """Return the accuracy of one validation batch."""
        return self._batch_accuracy(batch)

    def test_step(self, batch, batch_idx):
        """Return the accuracy of one test batch."""
        return self._batch_accuracy(batch)

    def accuracy(self, logits, y):
        """Return the fraction of samples whose arg-max class equals the target.

        Args:
            logits: Class scores; the predicted class is the arg-max over the
                last dimension.
            y: Target class indices, one per sample.

        Returns:
            A scalar float32 tensor: number of correct predictions / ``len(y)``.
        """
        correct = torch.eq(torch.argmax(logits, -1), y).to(torch.float32)
        return torch.sum(correct) / len(y)

    def validation_epoch_end(self, outputs) -> None:
        """Log the mean of the per-batch validation accuracies as ``val_acc``."""
        # Unweighted mean over batches: a smaller final batch counts as much
        # as a full one.
        self.log("val_acc", torch.stack(outputs).mean(), prog_bar=True)

    def test_epoch_end(self, outputs) -> None:
        """Log the mean of the per-batch test accuracies as ``test_acc``."""
        self.log("test_acc", torch.stack(outputs).mean())

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using ``hparams.learning_rate``."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

In [None]:
# Seed the Python, NumPy and PyTorch random number generators before any
# stochastic work (weight initialisation, shuffling) so that a
# "Restart Kernel -> Run All" reproduces the same run.
pl.seed_everything(42, workers=True)

# Create the MNIST data module provided by Lightning Bolts and the model to train.
dm = MNISTDataModule(batch_size=32)
model = LitClassifier()

In [None]:
# Single node, CPU based training.
# --------------------------------
# Hand the customized LightningModule and the data module to a Trainer, which
# runs the fit loop and then the test loop. No accelerator argument is passed
# here, so the Trainer's default device selection applies.
# https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#basic-use
trainer = pl.Trainer(
    max_epochs=200,
)
trainer.fit(model=model, datamodule=dm)
trainer.test(model=model, datamodule=dm)

In [None]:
# Single node, GPU based training.
# --------------------------------
# https://pytorch-lightning.readthedocs.io/en/stable/accelerators/gpu.html#single-gpu-training
# Interactive launches can be done only when using a single GPU. For single node/multi GPU and
# multi node/multi GPU training, the job has to be launched through the SM training toolkit.
#
# `max_epochs` is set explicitly to match the CPU cell; when it is omitted (and `max_steps` is
# not given) Lightning falls back to its own default epoch limit, which is 1000 in 1.6.x.
#
# NOTE: `model` is the same object the CPU cell trains. If that cell has already been run,
# this continues from its weights -- re-run the model-creation cell first for a fresh start.
trainer = pl.Trainer(accelerator="gpu", devices=1, max_epochs=200)
trainer.fit(model, datamodule=dm)
trainer.test(model, datamodule=dm)