PyTorch Lightning
* It imposes a standard structure on your code, which makes code review easier.
* It removes boilerplate.
* It abstracts away a lot of the engineering beyond the core model code, such as logging and parallelization on GPUs or even TPUs. Parallelization can be enabled just by passing a few flags to the `Trainer` rather than refactoring your code.

In [None]:
# High-level outline of the steps this notebook follows (a bare string literal; it has no effect on the code).
"""
Dataset
Build a model
Define loss_func and optimizer
Define trainer
Define test
Run trainer and test
"""

In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import pytorch_lightning as pl

In [None]:
# Get dataset: MNIST images converted to tensors with ToTensor(), stored under the
# relative directory 'data3' (download=True fetches the files if they are missing).
# The official MNIST test split (train=False) is used as the validation set here.
train_ds = MNIST(root='data3', train=True, download=True, transform=ToTensor())
valid_ds = MNIST(root='data3', train=False, download=True, transform=ToTensor())

bs = 64
# Only the training loader shuffles. Batch order does not change epoch-level
# validation metrics, so a fixed order keeps validation reproducible.
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)
valid_dl = DataLoader(valid_ds, batch_size=bs, shuffle=False)

In [None]:
# Build the model
class MNISTModel(pl.LightningModule):  # a LightningModule is an nn.Module with extra training hooks
    """Multinomial logistic regression on flattened MNIST images.

    A single linear layer maps the 784 flattened input values to 10 class logits.
    Cross-entropy is the loss and plain SGD is the optimizer.

    Args:
        lr: learning rate handed to ``optim.SGD`` in ``configure_optimizers``.
    """

    def __init__(self, lr=0.5):
        super().__init__()
        self.lin = nn.Linear(784, 10)
        self.lr = lr

    def forward(self, xb):
        """Return class logits for a batch of inputs.

        ``xb`` is flattened from dim 1 onward, so each sample must contain exactly
        784 values (e.g. a (batch, 1, 28, 28) MNIST image batch).
        """
        xb = xb.flatten(1, -1)
        return self.lin(xb)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss; Lightning optimizes the returned value."""
        return self.shared_step(batch, train=True)

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss and accuracy for one batch."""
        return self.shared_step(batch, train=False)

    def shared_step(self, batch, train):
        """Run the forward pass, compute cross-entropy loss, and log loss/accuracy.

        Args:
            batch: an ``(inputs, targets)`` pair.
            train: selects the metric-name prefix (``train_`` vs ``valid_``) and
                whether metrics are logged per step (train) or per epoch (valid).

        Returns:
            The scalar cross-entropy loss for the batch.
        """
        xb, yb = batch
        logits = self(xb)  # call the module rather than .forward() so nn.Module hooks run
        loss = F.cross_entropy(logits, yb)

        # argmax of the logits equals argmax of their softmax, so accuracy needs no
        # softmax and no torchmetrics dependency.
        acc = (logits.argmax(dim=-1) == yb).float().mean()

        prefix = 'train' if train else 'valid'
        self.log(f'{prefix}_loss', loss, on_step=train, on_epoch=not train, prog_bar=True)
        self.log(f'{prefix}_accuracy', acc, on_step=train, on_epoch=not train, prog_bar=True)

        return loss

    # def test_step(self, ....) -> for production use

    def configure_optimizers(self):
        """Return plain SGD over all parameters, using the learning rate set at construction."""
        return optim.SGD(self.parameters(), lr=self.lr)

In [None]:
# TODO: Could not get TensorBoardLogger to work, so all related code is commented out.
# from pytorch_lightning.loggers import TensorBoardLogger

# logger
# tb_logger = TensorBoardLogger('tb_logs')

In [None]:
# init model
mnist_model = MNISTModel()

# init trainer
trainer = pl.Trainer(
    max_epochs=2,
    # logger=tb_logger
)

# Train the model. Passing valid_dl as the validation loader makes Lightning run
# validation_step over it; without it, the validation set built earlier is never used.
trainer.fit(mnist_model, train_dl, val_dataloaders=valid_dl)

# optionally: run test. This requires MNISTModel to implement test_step (currently only
# sketched as a comment) and a test dataloader, e.g.:
# trainer.test(mnist_model, dataloaders=test_dl)