# PyTorch Lightning Introduction

Welcome to the introduction to [`PyTorch Lightning`](https://www.pytorchlightning.ai/). PyTorch Lightning is a wrapper around PyTorch that focuses on building neural network models quickly by removing boilerplate code. It also extends the functionality of PyTorch, for example with callbacks and with automatically moving computations to the GPU to speed them up.

At its core, PyTorch Lightning imposes a fixed, partly automated structure on your code, but you can customize or switch off the automation for each step if you want to. You benefit from many already implemented features and from code that is clearly structured by design; the price is that you have to get familiar with Lightning's code structure and learn which functions to use.

Our suggestion is to try it out. If you don't like it, you can simply stick to PyTorch and build up your own library of functions that you will use throughout your deep learning journey.

Let's get started by installing PyTorch Lightning.

## (Optional) Mount folder in Colab

Uncomment the following cell to mount your Google Drive if you are using the notebook in Google Colab:

In [1]:
"""
from google.colab import drive
import os

gdrive_path='/content/gdrive/MyDrive/DL_homeworks/homework_06'

# This will mount your google drive under 'MyDrive'
drive.mount('/content/gdrive', force_remount=True)
# In order to access the files in this notebook we have to navigate to the correct folder
os.chdir(gdrive_path)
# Check manually if all files are present
print(sorted(os.listdir()))
"""

In [2]:
# Optional: install correct libraries in google colab
# !python -m pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
# !python -m pip install tensorboard==2.9.1

### Installing PyTorch Lightning

In [1]:
# For automatic file reloading as usual
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import sys

# For Google Colab
# !python -m pip install "pytorch-lightning>=2.0" > /dev/null

# For anaconda/regular python
!{sys.executable} -m pip install "pytorch-lightning>=2.0"


In [1]:
import pytorch_lightning as PL
import tensorboard
import torch
import torchvision
from packaging.version import Version

def check_version(name, installed, minimum):
    # Warn only if the installed version is older than the expected minimum version
    if Version(installed) < Version(minimum):
        print(f"You are using {name} {installed}. We expect {name} >= {minimum}. You can continue with your version,"
              " but it might cause dependency and compatibility issues.")

print(f"Lightning version: {PL.__version__}")
check_version("PyTorch Lightning", PL.__version__, "2.0")
print(f"Tensorboard version: {tensorboard.__version__}")
check_version("Tensorboard", tensorboard.__version__, "2.1")
print(f"PyTorch version Installed: {torch.__version__}\nTorchvision version Installed: {torchvision.__version__}\n")
check_version("PyTorch", torch.__version__, "1.13")
check_version("torchvision", torchvision.__version__, "0.12")

Lightning version: 2.1.1
Tensorboard version: 2.15.1
PyTorch version Installed: 2.1.0
Torchvision version Installed: 0.16.0



# 1. Idea behind PyTorch Lightning

The code in a deep learning project falls into three main categories:

1. **Research code**  
    This is the exciting part of the experiment, where you configure the model architecture and try out different optimizers and target tasks. In PyTorch Lightning, this part is managed by the `LightningModule`.
    
2. **Engineering code**  
    This code stays the same across deep learning projects. Recall the training block of previous notebooks, where we looped over the epochs and mini-batches. The `Trainer` class of PyTorch Lightning takes care of this part.
    
3. **Non-essential code**  
    It is very important to log training metrics and to organize different training runs so that we can experiment with models purposefully. PyTorch Lightning's callbacks and loggers help us with this part.

Let's look at each of these modules in detail.

1. **LightningModules** contain all model-related code. This is the part we work on when creating a new project. The idea is to keep all important code in one module, e.g., the model's architecture and the computation of training and validation metrics. This gives a better overview, because repeated elements such as the training procedure are not part of the code we work on. Lightning also handles the calls to `.to(device)`, `.train()` and `.eval()`, so there is no need to move data between CPU and GPU manually or to keep track of the model's mode. The framework also makes parallel computation on multiple GPUs easy.

2. **Trainer** contains all the code needed for training neural networks that does not change between projects ("one size fits all"). Usually, we don't touch the code automated by this class. Arguments that are specific to one training run, such as the learning rate and batch size, are provided as initialization arguments to the LightningModule (and, in this notebook, to the data module).

3. **Callbacks** automate non-essential parts such as checkpointing or early stopping, while **loggers** such as the TensorBoard logger record hyperparameters and training results. Logging becomes very important for research later, since the results of experiments need to be reproducible.

All in all, PyTorch Lightning is a framework that handles all the (annoying) "engineering" work for you, so that you have more time for exciting research and scientific coding. A further advantage is that the automated parts are widely used and well tested, so you are less likely to introduce a bug in code that runs all the time but is rarely checked.

# 2. Overview of the PyTorch Lightning code

Research relevant code goes into the `LightningModule`. The advantage is that we have all the model building, training & validation steps within a single class. These are the components that usually change based on the projects and tasks.

![](https://drive.google.com/uc?export=view&id=1FqwrbgfMQD8Axowyq6zu3PeeNjjue4lK)

The remaining code is automated by the `Trainer` class, which takes care of the mechanical parts of the training loop, such as iterating through the mini-batches and performing the gradient update steps.

![alt text](https://miro.medium.com/max/700/1*b81_j__xv8M0Bb6nFTXbAA.png)

We can already see how much more readable and concise the code becomes once it is reorganized with PyTorch Lightning.
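For comparison, the following is a simplified, plain-PyTorch sketch of the kind of loop that `Trainer.fit` takes over. It is not Lightning's actual implementation, which additionally handles devices, logging, checkpointing and callbacks; the helper name `fit_plain`, the tiny model and the random data are made up for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


def fit_plain(model, train_loader, val_loader, optimizer, max_epochs, device="cpu"):
    """Hypothetical helper: the boilerplate that Trainer.fit automates."""
    model.to(device)
    history = []
    for epoch in range(max_epochs):
        model.train()                       # training mode (matters for dropout, batch norm, ...)
        for images, targets in train_loader:
            images, targets = images.to(device), targets.to(device)
            loss = F.cross_entropy(model(images), targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        model.eval()                        # evaluation mode
        total_loss, total_count = 0.0, 0
        with torch.no_grad():               # no gradients are needed for validation
            for images, targets in val_loader:
                images, targets = images.to(device), targets.to(device)
                out = model(images)
                total_loss += F.cross_entropy(out, targets, reduction="sum").item()
                total_count += targets.size(0)
        history.append(total_loss / total_count)  # mean validation loss of this epoch
    return history


# Tiny random stand-in for Fashion-MNIST so that the sketch runs anywhere
torch.manual_seed(0)
x, y = torch.randn(64, 1, 28, 28), torch.randint(0, 10, (64,))
loader = DataLoader(TensorDataset(x, y), batch_size=16)
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 32), nn.Sigmoid(), nn.Linear(32, 10))
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
val_losses = fit_plain(model, loader, loader, optimizer, max_epochs=2)
```

Everything in `fit_plain` except the two lines computing the loss is the same for every project, which is exactly the part Lightning moves out of your code.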

Let us now train a neural network model with PyTorch Lightning.

# 4. Training with PyTorch Lightning

We will build a two-layer neural network and train it on the [`Fashion-MNIST`](https://research.zalando.com/welcome/mission/research-projects/fashion-mnist/) dataset in this notebook.

## 4.1 Define A LightningModule

We define our network as a subclass of `L.LightningModule`, which replaces the `nn.Module` base class of our plain `PyTorch` network (`LightningModule` itself inherits from `nn.Module`). Additionally, it contains all the relevant parts that are used for training and evaluating different models on various tasks.

Let's have a look at the implementation of `TwoLayerNet` in `exercise_code.lightning_models`.

The `__init__()` and `forward()` methods, which define the layers and the forward pass, remain the same, so we can simply copy the code from the `nn.Module`.

```python
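# `L` alias assumed here; with the newer `lightning` package, `import lightning as L` is the usual form
import pytorch_lightning as L
import torch.nn as nn


class TwoLayerNet(L.LightningModule):
    def __init__(self, hparams, input_size=1 * 28 * 28, hidden_size=512, num_classes=10):
        super().__init__()
        # Store hparams so they are available as self.hparams and saved with checkpoints
        self.save_hyperparameters(hparams)

        self.model = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.Sigmoid(),
            nn.Linear(hidden_size, num_classes),
        )

    def forward(self, x):
        # Flatten the images from (N, 1, 28, 28) to (N, 784) before sending them to the model
        N, _, _, _ = x.shape
        x = x.view(N, -1)

        x = self.model(x)

        return x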
```
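As a quick check of the architecture, the parameter count can be worked out by hand: the first layer has 784 · 512 weights plus 512 biases and the second 512 · 10 weights plus 10 biases, i.e. 401,920 + 5,130 = 407,050 parameters, which matches the `407 K` total (about 1.628 MB in float32) in the model summary that Lightning prints when training starts. The sketch below rebuilds the same `nn.Sequential` stack in plain PyTorch to confirm this and to show the flattening done in `forward()`; it does not import the course's `exercise_code`.

```python
import torch
import torch.nn as nn

input_size, hidden_size, num_classes = 1 * 28 * 28, 512, 10

# Same layer stack as TwoLayerNet.model
model = nn.Sequential(
    nn.Linear(input_size, hidden_size),
    nn.Sigmoid(),
    nn.Linear(hidden_size, num_classes),
)

num_params = sum(p.numel() for p in model.parameters())
size_mb = num_params * 4 / 1e6  # float32 uses 4 bytes per parameter

# forward() flattens (N, 1, 28, 28) images to (N, 784) before the linear layers
x = torch.randn(8, 1, 28, 28)
logits = model(x.view(x.shape[0], -1))

print(num_params, round(size_mb, 3), tuple(logits.shape))  # 407050 1.628 (8, 10)
```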

We will now define the training and validation steps, since they also vary between tasks and projects. Consequently, it is useful to integrate these parts into our `LightningModule`. The validation loss and accuracy are computed for each validation mini-batch and averaged at the end of the epoch.

```python
    def training_step(self, batch, batch_idx):
        images, targets = batch

        # Perform a forward pass on the network with inputs
        out = self.forward(images)

        # Calculate the loss with the network predictions and ground truth targets
        loss = F.cross_entropy(out, targets)

        # The predicted class is the index of the largest network output (logit) per image
        _, preds = torch.max(out, 1)

        # Calculate the accuracy of predictions
        acc = preds.eq(targets).sum().float() / targets.size(0)

        # Log the accuracy and loss values to the tensorboard
        self.log('loss', loss)
        self.log('acc', acc)

        return {'loss': loss}

    def on_validation_epoch_start(self):
        # Lightning >= 2.0 no longer passes the step outputs to an epoch-end hook,
        # so we collect them ourselves in a list that is reset every validation epoch
        self.validation_step_outputs = []

    def validation_step(self, batch, batch_idx):
        images, targets = batch

        # Perform a forward pass on the network with inputs
        out = self.forward(images)

        # Calculate the loss with the network predictions and ground truth targets
        loss = F.cross_entropy(out, targets)

        # The predicted class is the index of the largest network output (logit) per image
        _, preds = torch.max(out, 1)

        # Calculate the accuracy of predictions
        acc = preds.eq(targets).sum().float() / targets.size(0)

        # Visualise the predictions of the model
        if batch_idx == 0:
            self.visualize_predictions(images, out.detach(), targets)

        output = {'val_loss': loss, 'val_acc': acc}
        self.validation_step_outputs.append(output)
        return output

    def on_validation_epoch_end(self):
        outputs = self.validation_step_outputs

        # Average the loss over the entire validation data from its mini-batches
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        avg_acc = torch.stack([x['val_acc'] for x in outputs]).mean()

        # Log the validation accuracy and loss values to the tensorboard
        self.log('val_loss', avg_loss)
        self.log('val_acc', avg_acc)
```
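To see what the accuracy lines compute, here is a tiny worked example with made-up logits for four images and three classes. `torch.max(out, 1)` returns the maximum value and its index for each row; the index is the predicted class. Since softmax does not change which entry is largest, the argmax of the raw logits is enough. The second part shows why averaging per-batch means equals the dataset mean only when all batches have the same size: with 10,000 validation images and a batch size of 16 there are exactly 625 full batches, so it makes no difference in this notebook.

```python
import torch

# Made-up network outputs (logits) for 4 images and 3 classes
out = torch.tensor([[2.0, 0.1, 0.3],
                    [0.2, 1.5, 0.1],
                    [0.9, 0.2, 0.1],
                    [0.1, 0.2, 3.0]])
targets = torch.tensor([0, 1, 2, 2])

_, preds = torch.max(out, 1)                       # index of the largest logit per row
acc = preds.eq(targets).sum().float() / targets.size(0)
print(preds.tolist(), acc.item())                  # [0, 1, 0, 2] 0.75

# Mean of batch means vs. the true mean when the last batch is smaller
batch_accs = torch.tensor([1.0, 0.5])              # accuracy of a 4-image and a 2-image batch
batch_sizes = torch.tensor([4.0, 2.0])
mean_of_means = batch_accs.mean()                  # (1.0 + 0.5) / 2
weighted_mean = (batch_accs * batch_sizes).sum() / batch_sizes.sum()  # (4 + 1) / 6
print(mean_of_means.item(), round(weighted_mean.item(), 4))  # 0.75 0.8333
```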

The last piece missing in our `LightningModule` is the optimizer, which is created in `configure_optimizers()`. This method has to be defined in every `LightningModule` that is trained with a `Trainer`.

```python
    def configure_optimizers(self):
        optim = torch.optim.SGD(self.model.parameters(), self.hparams["learning_rate"], momentum=0.9)

        return optim
```

Now that we have set up the model and the training steps, we will establish the data pipeline. PyTorch Lightning provides the `LightningDataModule` for setting up the dataloaders.

Let's have a look at the implementation of `FashionMNISTDataModule` in `exercise_code.data_class`.

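The `prepare_data()` method sets up the dataset and its transforms. As before, we download the `FashionMNIST` dataset using `torchvision` and split the full training data into a training and a validation set for tuning hyperparameters. Note that the Lightning documentation recommends using `prepare_data()` only for downloading and assigning state such as the splits in `setup()`; assigning it in `prepare_data()` works here because we train on a single device.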

```python
import pytorch_lightning as L
import torchvision
from torch.utils.data import DataLoader, random_split
from torchvision import transforms


class FashionMNISTDataModule(L.LightningDataModule):

    def __init__(self, hparams):
        super().__init__()
        # The notebook passes the whole hparams dictionary; the batch size is read from it
        self.batch_size = hparams["batch_size"]

    def prepare_data(self):

        # Convert the images to tensors and normalize them from [0, 1] to [-1, 1]
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize((0.5,), (0.5,))])

        # Download the Fashion-MNIST dataset and apply the transform
        fashion_mnist_train_val = torchvision.datasets.FashionMNIST(root='../datasets', train=True,
                                                                   download=True, transform=transform)

        self.fashion_mnist_test = torchvision.datasets.FashionMNIST(root='../datasets', train=False,
                                                                 download=True, transform=transform)

        # Perform the training and validation split
        self.train_dataset, self.val_dataset = random_split(
            fashion_mnist_train_val, [50000, 10000])
```
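`random_split` draws a new random permutation on every call, and `prepare_data()` may run more than once (the notebook calls it explicitly, and the `Trainer` calls it again inside `fit`), so the training/validation split can change between calls. A minimal sketch of a reproducible split, assuming a fixed seed is acceptable for your experiments, passes a seeded `torch.Generator` to `random_split`; a toy dataset of 100 numbers stands in for Fashion-MNIST, and the helper name `split` is hypothetical.

```python
import torch
from torch.utils.data import TensorDataset, random_split

dataset = TensorDataset(torch.arange(100))  # toy stand-in for the 60,000 training images


def split(seed):
    # A seeded generator makes random_split return the same indices on every call
    generator = torch.Generator().manual_seed(seed)
    return random_split(dataset, [80, 20], generator=generator)


train_a, val_a = split(seed=42)
train_b, val_b = split(seed=42)
print(len(train_a), len(val_a), train_a.indices == train_b.indices)  # 80 20 True
```

In the data module above, the same idea would read `random_split(fashion_mnist_train_val, [50000, 10000], generator=torch.Generator().manual_seed(42))`, with `42` being an arbitrary seed.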

Next, we define a `DataLoader` for each of the data splits. The `Trainer` calls these methods to obtain the data loaders during training, validation and testing.

```python

    def train_dataloader(self):
        # Reshuffle the training data in every epoch
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.fashion_mnist_test, batch_size=self.batch_size)

```
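The number of iterations per epoch in the progress bar follows directly from the split sizes and the batch size: a `DataLoader` with the default `drop_last=False` yields ceil(num_samples / batch_size) batches. The sketch below uses zero-filled placeholder tensors of the same sizes (not the real images) to reproduce the 3,125 training batches shown in the progress bar for `batch_size=16`, and the corresponding 625 validation batches.

```python
import math

import torch
from torch.utils.data import DataLoader, TensorDataset

batch_size = 16
train_dataset = TensorDataset(torch.zeros(50000, 1))   # placeholder with 50,000 samples
val_dataset = TensorDataset(torch.zeros(10000, 1))     # placeholder with 10,000 samples

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

# len(DataLoader) == ceil(num_samples / batch_size) when drop_last=False
print(len(train_loader), len(val_loader))                             # 3125 625
print(math.ceil(50000 / batch_size), math.ceil(10000 / batch_size))   # 3125 625
```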

Notice that most of the code in these steps can be copied directly from vanilla PyTorch code; Lightning just rearranges it. This marks the end of the research part of the code.

Now let's see how the `Trainer` class works.

## 4.2 Fitting the model with a Trainer

We will initialize the model and the data with a set of hyperparameters given in the dictionary `hparams`.


In [2]:
from IPython.display import clear_output 

from exercise_code.lightning_models import TwoLayerNet
from exercise_code.data_class import FashionMNISTDataModule

hparams = {
    "batch_size": 16,
    "learning_rate": 1e-3,
    "input_size": 1 * 28 * 28,
    "hidden_size": 512,
    "num_classes": 10,
    "num_workers": 2,    # used by the dataloader, more workers means faster data preparation, but for us this is not a bottleneck here
}


model = TwoLayerNet(hparams)
data=FashionMNISTDataModule(hparams)
data.prepare_data()

PyTorch Lightning offers a lot of flexibility for training through the [`Trainer`](https://pytorch-lightning.readthedocs.io/en/latest/trainer.html) class.
Have a look at the documentation to learn more about its arguments.

Let's initialize it now.

In [8]:
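trainer = PL.Trainer(
    max_epochs=2,
    # By default, Lightning uses an available GPU (or Apple MPS) automatically.
    # To pick specific GPUs, pass a list of indices, e.g. devices=[0]; devices=-1 uses all of them.
    # To force the CPU, pass accelerator="cpu" instead.
    # devices=[0]
)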

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


The argument `max_epochs` sets the maximum number of epochs for training.
The argument `enable_model_summary` (enabled by default) prints a summary of the number of parameters per layer at the beginning of training. Set it to `False` if the summary is not required. In older Lightning versions, this argument was called `weights_summary`.
 

Here comes the actual training cell. The [`fit`](https://pytorch-lightning.readthedocs.io/en/latest/_modules/pytorch_lightning/trainer/trainer.html#Trainer.fit) method takes the model and the data module and trains the model; it also accepts further optional arguments for customization, such as `ckpt_path` to resume training from a checkpoint.

In [9]:
trainer.fit(model, data)


  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 407 K 
-------------------------------------
407 K     Trainable params
0         Non-trainable params
407 K     Total params
1.628     Total estimated model params size (MB)



/Users/tigrangaplanyan/anaconda3/envs/Homeworks/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:436: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.

/Users/tigrangaplanyan/anaconda3/envs/Homeworks/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:436: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.


Epoch 0: 100%|█████████| 3125/3125 [00:17<00:00, 175.89it/s, v_num=6, acc=0.688]

`Trainer.fit` stopped: `max_epochs=2` reached.


Epoch 1: 100%|█| 3125/3125 [00:22<00:00, 137.79it/s, v_num=6, acc=0.812, val_los


Check out the directory `lightning_logs`. For each run, a new directory `version_xx` is created. The `v_num` entry in the progress bar above shows the version of the current run. Each version directory automatically contains a folder with checkpoints, the logs and the hyperparameters of this run.
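If you want to find the most recent run programmatically, a small helper can scan `lightning_logs`. The sketch below assumes the default layout described above (`lightning_logs/version_<N>/checkpoints/*.ckpt`); `latest_version_dir` and `list_checkpoints` are hypothetical helper names, not part of PyTorch Lightning.

```python
from pathlib import Path
from typing import List, Optional


def latest_version_dir(log_root: str = "lightning_logs") -> Optional[Path]:
    """Return the version_<N> directory with the highest N, or None if there is none."""
    root = Path(log_root)
    if not root.is_dir():
        return None
    versions = []
    for d in root.iterdir():
        # Only consider directories named like "version_<integer>"
        if d.is_dir() and d.name.startswith("version_"):
            suffix = d.name[len("version_"):]
            if suffix.isdigit():
                versions.append((int(suffix), d))
    if not versions:
        return None
    # Compare numerically so that version_10 comes after version_9
    return max(versions)[1]


def list_checkpoints(version_dir: Path) -> List[Path]:
    """Return the sorted .ckpt files stored in <version_dir>/checkpoints."""
    return sorted((version_dir / "checkpoints").glob("*.ckpt"))


run = latest_version_dir("lightning_logs")
if run is not None:
    print(run.name, [p.name for p in list_checkpoints(run)])
```

Sorting on the parsed integer rather than on the directory name avoids the classic pitfall where `version_10` sorts before `version_9` alphabetically.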

As in the previous notebook, you can inspect the logs of the runs in TensorBoard. Run the following command in your terminal:
```
tensorboard --logdir lightning_logs
```
Make sure to run it from the `homework_06` directory, i.e. the directory that contains `lightning_logs`.

If you are using Google Colab, uncomment and run the following cell to load the TensorBoard extension and display TensorBoard within the notebook. You may have to scroll back to this cell whenever you want to look at the TensorBoard interface.

In [11]:
# %load_ext tensorboard
# %tensorboard --logdir lightning_logs

## 4.3 Add images to TensorBoard

Inside a `LightningModule`, the TensorBoard logger is available as `self.logger`, and the underlying TensorBoard `SummaryWriter` as `self.logger.experiment`. To add an image to the log, call
```python
self.logger.experiment.add_image('tag', image)
```
where `image` is a tensor of shape `(C, H, W)` (the default `dataformats='CHW'`).


We will log the first batch of validation images in a grid together with the predicted class labels and the ground truth labels. 

```python
# inside validation_step: visualize only the first validation batch
if batch_idx == 0:
    self.visualize_predictions(images, out.detach(), targets)
```

Let's have a look at the implementation of the `visualize_predictions()` method in `exercise_code.lightning_models`.



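```python
    # assumes `import torch` and `import matplotlib.pyplot as plt` at module level
    def visualize_predictions(self, images, preds, targets):
        class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                       'dog', 'frog', 'horse', 'ship', 'truck']

        # move tensors to the CPU so matplotlib can plot them (no-op if already on the CPU)
        images, preds, targets = images.cpu(), preds.cpu(), targets.cpu()
        pred_labels = torch.argmax(preds, dim=-1)

        # determine size of the grid based on the given batch size (plt.subplot needs ints)
        num_rows = int(torch.tensor(len(images)).float().sqrt().floor())
        num_cols = len(images) // num_rows + 1

        fig = plt.figure(figsize=(10, 10))
        for i in range(len(images)):
            plt.subplot(num_rows, num_cols, i + 1)
            # images are stored as (C, H, W); imshow expects (H, W, C)
            plt.imshow(images[i].permute(1, 2, 0))
            # title: predicted class, ground-truth class in brackets
            plt.title(class_names[pred_labels[i]] + f'\n[{class_names[targets[i]]}]')
            plt.axis('off')

        # log the whole figure under the tag 'predictions' at the current training step
        self.logger.experiment.add_figure('predictions', fig, global_step=self.global_step)
```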
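The grid layout in `visualize_predictions()` uses $\lfloor\sqrt{n}\rfloor$ rows and `n // rows + 1` columns for a batch of `n` images. The plain-Python sketch below reproduces that arithmetic with integers (which is what `plt.subplot` expects) and checks that every image gets its own cell; `grid_shape` is a hypothetical helper name, not part of the exercise code.

```python
import math
from typing import Tuple


def grid_shape(n: int) -> Tuple[int, int]:
    """Rows and columns used by visualize_predictions for a batch of n images."""
    if n < 1:
        raise ValueError("need at least one image")
    rows = math.isqrt(n)   # floor(sqrt(n)), computed exactly on integers
    cols = n // rows + 1   # same column rule as in visualize_predictions
    return rows, cols


for batch_size in (1, 8, 16, 32):
    rows, cols = grid_shape(batch_size)
    # every subplot index 1..batch_size must fit into the rows x cols grid
    assert rows * cols >= batch_size
    print(batch_size, "->", (rows, cols))
```

Because of the `+ 1`, a batch of 16 gets a 4 × 5 grid with one empty column; using `math.ceil(n / rows)` for the columns would give a tighter 4 × 4 layout.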

You can view the logged figures in the `IMAGES` tab of TensorBoard.

We have now looked at how to train a model using PyTorch Lightning. PyTorch Lightning is under very active development, and its feature set is continuously expanded and updated.


# 5. Other Features of PyTorch Lightning


### Checking training timings

The `profiler` argument of the `Trainer` class measures the time spent in the different steps of training, such as data loading and the forward and backward passes. You can choose from a variety of profilers [here](https://pytorch-lightning.readthedocs.io/en/stable/advanced/profiler.html).

Run the cell below to see for yourself.

In [14]:
trainer = PL.Trainer(
    profiler='simple',
    max_epochs=1,
    # accelerator='gpu', devices=1,  # request a GPU explicitly; the default 'auto' already uses one if available
)

trainer.fit(model, data)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 407 K 
-------------------------------------
407 K     Trainable params
0         Non-trainable params
407 K     Total params
1.628     Total estimated model params size (MB)


`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|█| 3125/3125 [00:21<00:00, 142.31it/s, v_num=7, acc=0.812, val_los


FIT Profiler Report

-------------------------------------------------------------------------------------
|  Action  |  Mean duration (s)  |  Num calls  |  Total time (s)  |  Percentage %  |
-------------------------------------------------------------------------------------
|  Total   |  -                  |  121373     |

The report gives an overview of the time spent in each step of training.
This makes it easier to detect bottlenecks, for example slow data loading. Profiling becomes especially important later on, when you start to implement custom layers or loss functions.
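To make the columns of the report concrete, here is a toy profiler in plain Python that records how long each named action takes and derives the same kind of statistics: mean duration, number of calls, total time and percentage. It is a simplified sketch for illustration only; `ToyProfiler` and the action names are hypothetical, Lightning's `'simple'` profiler is implemented differently, and here the percentage is relative to the summed time of all recorded actions.

```python
import time
from collections import defaultdict
from contextlib import contextmanager
from typing import Dict, Iterator, List


class ToyProfiler:
    """Records wall-clock durations per action name (illustrative only)."""

    def __init__(self) -> None:
        self.durations: Dict[str, List[float]] = defaultdict(list)

    @contextmanager
    def profile(self, action: str) -> Iterator[None]:
        start = time.perf_counter()
        try:
            yield
        finally:
            # store the elapsed time even if the profiled code raised
            self.durations[action].append(time.perf_counter() - start)

    def summary(self) -> Dict[str, dict]:
        grand_total = sum(sum(d) for d in self.durations.values())
        report = {}
        for action, d in self.durations.items():
            total = sum(d)
            report[action] = {
                "mean_s": total / len(d),
                "num_calls": len(d),
                "total_s": total,
                "percent": 100.0 * total / grand_total if grand_total > 0 else 0.0,
            }
        return report


profiler = ToyProfiler()
for _ in range(3):
    with profiler.profile("get_batch"):
        time.sleep(0.002)  # stand-in for data loading
    with profiler.profile("training_step"):
        time.sleep(0.001)  # stand-in for the forward and backward pass

# most expensive action (the likely bottleneck) first
for action, s in sorted(profiler.summary().items(), key=lambda kv: -kv[1]["total_s"]):
    print(f"{action:15s} mean={s['mean_s']:.4f}s calls={s['num_calls']} "
          f"total={s['total_s']:.4f}s ({s['percent']:.1f}%)")
```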

### Some more debugging options

* [`fast_dev_run`](https://pytorch-lightning.readthedocs.io/en/latest/trainer.html#fast-dev-run): Runs only a single batch (or `n` batches, if set to an integer `n`) of training, validation and test, where the corresponding dataloaders are provided. This is a fast way to check that the whole pipeline (data loading, forward and backward pass, validation metrics) runs without errors, without having to wait for a full epoch. Note that checkpointing and logging are disabled in this mode.

* [`track_grad_norm`](https://pytorch-lightning.readthedocs.io/en/latest/trainer.html#track-grad-norm): Logs the norm of the gradients of each layer (set it to `1` for the $L_1$ norm or `2` for the $L_2$ norm). This lets you check whether the network is actually learning: if the gradients are very small or very large, training will suffer from vanishing or exploding gradients. Note that this `Trainer` argument was removed in Lightning 2.0; there, the `grad_norm` utility from `lightning.pytorch.utilities` can be called in the `on_before_optimizer_step` hook instead.
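The quantity being logged is simply the $p$-norm of each parameter's gradient after `backward()`. The sketch below computes it by hand in plain PyTorch for a small placeholder model (the architecture and the random data are made up for illustration); `gradient_norms` is a hypothetical helper, not a Lightning function.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Placeholder model and data, only used to produce some gradients
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
x, y = torch.randn(16, 4), torch.randn(16, 1)

loss = nn.functional.mse_loss(model(x), y)
loss.backward()


def gradient_norms(module: nn.Module, p: float = 2.0) -> dict:
    """p-norm of every parameter gradient, keyed by parameter name."""
    return {
        name: param.grad.detach().norm(p).item()
        for name, param in module.named_parameters()
        if param.grad is not None
    }


for name, value in gradient_norms(model, p=2).items():
    # values close to 0 hint at vanishing gradients, very large ones at exploding gradients
    print(f"grad_norm/{name}: {value:.4f}")
```

Logging these values once per step and watching them over time is usually more informative than a single snapshot.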


### Other Features

Finally, we want to mention some other useful options in the Trainer class:

* [`resume_from_checkpoint`](https://pytorch-lightning.readthedocs.io/en/latest/trainer.html#resume-from-checkpoint): Starts the training from a previously saved checkpoint; the argument is the path to the checkpoint file. In Lightning 2.x this `Trainer` argument no longer exists, and the path is passed to `fit` instead, e.g. `trainer.fit(model, data, ckpt_path='path/to/file.ckpt')`.
* [`Callbacks`](https://pytorch-lightning.readthedocs.io/en/latest/callbacks.html#callback): Callbacks are self-contained pieces of code that the `Trainer` calls at specific points during training. They automate non-essential logic such as saving model checkpoints, without cluttering the model code.
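Conceptually, a callback is an object whose hook methods the training loop calls at fixed points, so the extra logic lives outside the model. The toy loop below illustrates that pattern in plain Python; `ToyTrainer`, `BestEpochTracker` and the hook name `on_epoch_end` are hypothetical and much simpler than Lightning's `Callback` API, and the metric values are made up.

```python
from typing import Dict, List, Optional


class Callback:
    """Base class: hooks do nothing unless a subclass overrides them."""

    def on_epoch_end(self, trainer: "ToyTrainer", epoch: int, metrics: Dict[str, float]) -> None:
        pass


class BestEpochTracker(Callback):
    """Remembers the epoch with the highest value of a monitored metric."""

    def __init__(self, monitor: str) -> None:
        self.monitor = monitor
        self.best_epoch: Optional[int] = None
        self.best_value = float("-inf")

    def on_epoch_end(self, trainer, epoch, metrics):
        value = metrics[self.monitor]
        if value > self.best_value:
            self.best_value, self.best_epoch = value, epoch


class ToyTrainer:
    def __init__(self, max_epochs: int, callbacks: List[Callback]) -> None:
        self.max_epochs = max_epochs
        self.callbacks = callbacks

    def fit(self, metric_per_epoch: List[float]) -> None:
        # stand-in for training: each "epoch" just yields a precomputed metric
        for epoch in range(self.max_epochs):
            metrics = {"val_acc": metric_per_epoch[epoch]}
            # the loop knows nothing about what the callbacks do
            for cb in self.callbacks:
                cb.on_epoch_end(self, epoch, metrics)


tracker = BestEpochTracker(monitor="val_acc")
ToyTrainer(max_epochs=4, callbacks=[tracker]).fit([0.61, 0.74, 0.71, 0.73])
print(tracker.best_epoch, tracker.best_value)  # prints: 1 0.74
```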

Let's have a look at the [`EarlyStopping`](https://pytorch-lightning.readthedocs.io/en/latest/early_stopping.html#early-stopping-based-on-metric-using-the-earlystopping-callback) callback.

It stops the training if the `monitor` metric does not improve for `patience` consecutive validation checks. By default, validation runs once per epoch, so this corresponds to `patience` epochs.

Below is an example of how to use it.

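```python
# with the older `pytorch_lightning` package: from pytorch_lightning.callbacks import EarlyStopping
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

early_stop_callback = EarlyStopping(
    monitor='val_accuracy',  # must match a metric name logged with self.log(...) during validation
    patience=3,              # validation checks without improvement before stopping
    verbose=False,
    mode='max'               # 'max': higher is better (accuracy); use 'min' for a loss
)

trainer = Trainer(max_epochs=10, callbacks=[early_stop_callback])
```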
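The stopping rule itself is easy to state: remember the best value of the monitored metric and count the validation checks since it last improved; once that count reaches `patience`, stop. The plain-Python function below is a simplified sketch of this rule (it ignores options such as `min_delta` and treats an equal value as no improvement), not Lightning's implementation; `stop_index` is a hypothetical helper name.

```python
from typing import Optional, Sequence


def stop_index(scores: Sequence[float], patience: int, mode: str = "max") -> Optional[int]:
    """Index of the validation check at which training would stop, or None."""
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    best = None
    wait = 0  # validation checks since the last improvement
    for i, score in enumerate(scores):
        improved = best is None or (score > best if mode == "max" else score < best)
        if improved:
            best, wait = score, 0
        else:
            wait += 1
            if wait >= patience:
                return i
    return None


# best value 0.62 at index 1; indices 2, 3 and 4 do not beat it, so with patience=3 we stop at 4
print(stop_index([0.50, 0.62, 0.61, 0.62, 0.60, 0.70], patience=3, mode="max"))  # prints: 4
```

Note that the later improvement to 0.70 is never seen: early stopping trades the chance of a late improvement for saved training time, which is why `patience` should not be set too small.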

PyTorch Lightning ships with many built-in callbacks (for example `ModelCheckpoint` or `LearningRateMonitor`), so it is worth checking the available callbacks before writing such functionality yourself.

## References

1. PyTorch Lightning [`Source Code`](https://github.com/PyTorchLightning/pytorch-lightning), with a nice introduction.
2. PyTorch Lightning [`Documentation`](https://pytorch-lightning.readthedocs.io/en/latest/#): explore it, the features are very well explained.
3. The creator's [`Blogpost`](https://towardsdatascience.com/from-pytorch-to-pytorch-lightning-a-gentle-introduction-b371b7caaf09): a gentle introduction from PyTorch to PyTorch Lightning.