# PyTorch Lightning

Is a library built on top of the PyTorch library, to simplify the APIs, remove much of the necessary boilerplate code
to implement things, and more importantly it allows us to use many advanced features such as **multi-GPU** support and
**fast low-precision training**.

In [1]:
# Implementing the model

import pytorch_lightning as pl
import torch
import torch.nn as nn

from torchmetrics import Accuracy

class MultiLayerPerceptron(pl.LightningModule):
    def __init__(self, image_shape = (1,28,28), hidden_units=(32,16)):
        super().__init__()

        self.train_acc = Accuracy(task="multiclass",num_classes=10)
        self.valid_acc = Accuracy(task="multiclass",num_classes=10)
        self.test_acc = Accuracy(task="multiclass",num_classes=10)
        # Utilities to automatically compute accuracies

        input_size = image_shape[0] * image_shape[1] * image_shape[2]
        all_layers = [nn.Flatten()]
        for hidden_unit in hidden_units:
            all_layers.append(nn.Linear(input_size, hidden_unit))
            all_layers.append(nn.ReLU())
            input_size = hidden_unit

        all_layers.append(nn.Linear(hidden_units[-1],10))
        self.model = nn.Sequential(*all_layers)

    def forward(self, x):
        return self.model(x)

    # Method recognized by lightning
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(self(x),y)
        preds = torch.argmax(logits,1)
        self.train_acc.update(preds,y)
        self.log("train_loss",loss, prog_bar=True)
        return loss

    # Method recognized by lightning
    def on_train_epoch_end(self):
        self.log("train_acc",self.train_acc.compute())

    # Method recognized by lightning
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(self(x),y)
        preds = torch.argmax(logits,1)
        self.valid_acc.update(preds,y)
        self.log("valid_loss",loss, prog_bar=True)
        self.log("valid_acc",self.valid_acc.compute(), prog_bar=True)
        return loss

    # Method recognized by lightning
    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(self(x),y)
        preds = torch.argmax(logits,1)
        self.test_acc.update(preds,y)
        self.log("test_loss",loss, prog_bar=True)
        self.log("test_acc",self.test_acc.compute(), prog_bar=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(),lr=0.001)
        return optimizer

# Setting up the data loaders for Lightning

There are three ways to prepare the dataset for Lightning:
1. Make the dataset part of the model;
2. Set up the data loaders as usual and feed them to the `fit` method of a Lightning Trainer;
3. Create a `LightningDataModule`.

In [2]:
# Using the LightningDataModule approach
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
import torchvision.transforms.v2 as transforms

class MnistDataModule(pl.LightningDataModule):
    def __init__(self,data_path="../NNs with PyTorch/"):
        super().__init__()
        self.data_path = data_path
        self.transform = transforms.Compose([transforms.ToImage(),transforms.ToDtype(torch.float32,scale=True)])

    def prepare_data(self):
        # Here we should use this method to download the data and prepare it
        # MNIST(root=self.data_path,download=True)
        pass

    def setup(self, stage):
        # Prepares the data to making use of it. Can create logic based on the provided stage.
        # stage is either 'fit', 'validate', 'test' or 'predict'
        mnist_all = MNIST(root=self.data_path,train=True,transform=self.transform,download=False)
        self.train, self.val = random_split(mnist_all,[55000,5000],generator=torch.Generator().manual_seed(1))
        self.test = MNIST(root=self.data_path,train=False,transform=self.transform,download=False)
    
    def train_dataloader(self):
        return DataLoader(self.train,batch_size=64, num_workers=4,persistent_workers=True)

    def val_dataloader(self):
        return DataLoader(self.val,batch_size=64, num_workers=4,persistent_workers=True)
        
    def test_dataloader(self):
        return DataLoader(self.test,batch_size=64, num_workers=4)

# Initializing the data module for training, validation and testing
torch.manual_seed(1)
mnist_dm = MnistDataModule()

# Training the model using the PyTorch Lightning Trainer class

Lightning implements a `Trainer` class that makes the training model super convenient by taking care of all the
intermediate steps, such as calling `.zero_grad()`, `.backward()` and `.step()`. Also, as a bonus, it lets us easily
specify one or more GPUs to use (if available):

In [3]:
mnistclassifier = MultiLayerPerceptron()

# Using MPS since I'm on apple silicon
trainer = pl.Trainer(accelerator="mps",max_epochs=10)

trainer.fit(model=mnistclassifier,datamodule=mnist_dm)


GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name      | Type               | Params | Mode 
---------------------------------------------------------
0 | train_acc | MulticlassAccuracy | 0      | train
1 | valid_acc | MulticlassAccuracy | 0      | train
2 | test_acc  | MulticlassAccuracy | 0      | train
3 | model     | Sequential         | 25.8 K | train
---------------------------------------------------------
25.8 K    Trainable params
0         Non-trainable params
25.8 K    Total params
0.103     Total estimated model params size (MB)
10        Modules in train mode
0         Modules in eval mode


Epoch 9: 100%|██████████| 860/860 [00:08<00:00, 97.10it/s, v_num=0, train_loss=0.250, valid_loss=0.167, valid_acc=0.936]   

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 860/860 [00:08<00:00, 97.00it/s, v_num=0, train_loss=0.250, valid_loss=0.167, valid_acc=0.936]


# Evaluating the model using TensorBoard

By default, Lightning tracks the training in a subfolder named `lightning_logs`.

Tensorboard can be used to analize the logs using the following commands, targetting the logs directory (during, and
after the training has finished).
```sh
tensorboard --logdir lightning_logs/
```