# 2 Working faster and cleaner with PyTorch Lightning

## 2.1 Introduction

In the previous tutorial, we learned how to build a simple neural network using PyTorch. In this tutorial, we will learn how to use PyTorch Lightning to make our code cleaner and more efficient.

PyTorch Lightning is a lightweight wrapper around PyTorch that helps you organize your code and decouple the science code from the engineering code. It provides a high-level interface for training and testing models, making it easier to write clean and maintainable code. It also helps you to scale your code to multiple GPUs and TPUs, and to run it on different platforms (e.g., cloud, local, etc.) without changing your code. 

We will also briefly look at Weights & Biases (wandb), a tool for tracking experiments and visualizing results. PyTorch Lightning works fine without it, but it is a great tool to have in your toolbox. We will use it to log our training and validation metrics, and to visualize our results.

## 2.2 PyTorch Lightning

### 2.2.1 Repeating the basics

First, let's get some code from the previous tutorial. We will use the same dataset and model, but we will refactor the code to use PyTorch Lightning.

In [None]:
# Install necessary packages (uncomment if needed)
# !pip install lightning wandb -q

In [1]:
import torch
import torch.nn as nn
import pandas as pd
from sklearn.datasets import load_breast_cancer
from torch.utils.data import Dataset, DataLoader

_ = torch.manual_seed(42)

class CancerDataset(Dataset):
    def __init__(self, features: pd.DataFrame, targets: pd.Series):
        self.features = torch.tensor(features.values, dtype=torch.float32)
        self.targets = torch.tensor(targets.values, dtype=torch.float32)

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        x = self.features[idx]
        y = self.targets[idx]
        return x, y


class BreastCancerClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(30, 20)
        self.layer2 = nn.Linear(20, 10)
        self.layer3 = nn.Linear(10, 1)
        self.activation = nn.ReLU()

    def forward(self, x):
        x = self.layer1(x)
        x = self.activation(x)
        x = self.layer2(x)
        x = self.activation(x)
        x = self.layer3(x)
        return x


cancer_data = load_breast_cancer(as_frame=True)
cancer_dataset = CancerDataset(cancer_data['data'], cancer_data['target'])

data_train, data_validation = torch.utils.data.random_split(cancer_dataset, [0.8, 0.2])
dataloader_train = DataLoader(data_train, batch_size=32, shuffle=True)
dataloader_validation = DataLoader(data_validation, batch_size=32, shuffle=False)
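
As a quick sanity check on the split and the data loaders, the following sketch repeats the same steps on a synthetic stand-in dataset with the same shape as the breast cancer data (569 rows, 30 features). It assumes a PyTorch version in which `random_split` accepts fractions (1.13 or later); the variable names are for illustration only.

```python
import math

import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Synthetic stand-in with the same shape as the breast cancer data
features = torch.randn(569, 30)
targets = torch.randint(0, 2, (569,)).float()
dataset = TensorDataset(features, targets)

# Fractions are converted to lengths that add up to the dataset size
train_split, validation_split = random_split(dataset, [0.8, 0.2])
train_loader = DataLoader(train_split, batch_size=32, shuffle=True)

# Batches per epoch: split size divided by batch size, rounded up
expected_batches = math.ceil(len(train_split) / 32)
print(len(train_split), len(validation_split), len(train_loader), expected_batches)

# Each batch is a (features, targets) pair
x, y = next(iter(train_loader))
print(x.shape, y.shape)
```

The last batch of an epoch can be smaller than 32, because the split size is generally not a multiple of the batch size.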

### 2.2.2 Setting up the Lightning module

Take a look again at the training loop code from the previous tutorial. It contained quite a bit of boilerplate code that would be highly similar across different ML projects. We will now refactor this code to use PyTorch Lightning, which removes much of this boilerplate.

To do this, we will create a new class that inherits from [`LightningModule`](https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#lightningmodule). This class will contain all the code for training and validating our model. Similar to the `nn.Module` class, we will need to implement the `__init__` method to initialize our model and the `forward` method to define the forward pass. However, we will also need to implement some additional methods that are specific to PyTorch Lightning:

- `training_step`: This method will be called for each batch of training data. It will contain the code for the forward pass and the loss calculation.
- `validation_step`: This method will be called for each batch of validation data. It will contain the code for the forward pass and the loss calculation.
- `configure_optimizers`: This method will be called to configure the optimizer and, optionally, a learning rate scheduler.


In [2]:
import lightning as L

class BreastCancerModule(L.LightningModule):
    def __init__(self, learning_rate=0.001):
        super().__init__()
        self.save_hyperparameters()
        self.model = BreastCancerClassifier()
        self.loss_fn = nn.BCEWithLogitsLoss()
        self.learning_rate = learning_rate
        self.optimizer = torch.optim.Adam

    def forward(self, x):
        return self.model(x).squeeze(-1)  # drop only the last dimension so a batch of one keeps its shape

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.log('val_loss', loss, prog_bar=True)

        return loss

    def configure_optimizers(self):
        return self.optimizer(self.parameters(), lr=self.learning_rate)

Note that the resulting code is much cleaner and more organized than the training loop code we had before.
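
To make concrete which boilerplate has moved out of our code, the following sketch shows, in plain PyTorch, roughly how a trainer could call the three hooks. It is a simplified illustration and not Lightning's actual implementation: `TinyModule` and `manual_fit` are hypothetical names, the data is random, and the real `Trainer` additionally handles device placement, logging, checkpointing, and more.

```python
import torch
import torch.nn as nn


class TinyModule(nn.Module):
    """Hypothetical stand-in exposing the same three hooks as a LightningModule."""

    def __init__(self):
        super().__init__()
        self.model = nn.Linear(4, 1)
        self.loss_fn = nn.BCEWithLogitsLoss()

    def forward(self, x):
        return self.model(x).squeeze(-1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return self.loss_fn(self(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        return self.loss_fn(self(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.01)


def manual_fit(module, train_batches, val_batches, max_epochs=2):
    """Simplified version of the loop that a trainer runs on our behalf."""
    optimizer = module.configure_optimizers()
    history = []
    for epoch in range(max_epochs):
        module.train()
        for batch_idx, batch in enumerate(train_batches):
            loss = module.training_step(batch, batch_idx)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Validation runs without gradient tracking
        module.eval()
        with torch.no_grad():
            val_losses = [
                module.validation_step(batch, batch_idx).item()
                for batch_idx, batch in enumerate(val_batches)
            ]
        history.append(sum(val_losses) / len(val_losses))
    return history


torch.manual_seed(0)
batches = [(torch.randn(8, 4), torch.randint(0, 2, (8,)).float()) for _ in range(3)]
val_history = manual_fit(TinyModule(), batches[:2], batches[2:], max_epochs=2)
print(val_history)  # one mean validation loss per epoch
```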

### 2.2.3 Setting up the Lightning trainer

All configuration of the training process is done in the [`Trainer`](https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#trainer) class. This class takes care of the training and validation loops, as well as logging and checkpointing. We will create an instance of this class here and, in the next section, pass our model and the training and validation data loaders to its `fit` method.

*Assignment: Browse through the documentation for the trainer class and try to understand the different arguments of the class. Pay special attention to the `accelerator` and `devices` arguments, which are some of the most useful features of PyTorch Lightning.* 

In [3]:
trainer = L.Trainer(
    max_epochs=10,
    accelerator="auto",  # Automatically selects GPU or CPU
)

/shared/home/ralfgabriels/.local/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /shared/software/miniconda/envs/python-3.12/lib/pyth ...
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/shared/home/ralfgabriels/.local/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default 

Note that the trainer immediately tells us whether we are using a GPU or not. This is a great feature of PyTorch Lightning, as it allows us to write code that is agnostic to the hardware we are using. We can run the same code on a CPU, a single GPU, or multiple GPUs without changing anything in our code.
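
For comparison, this is a minimal sketch of the manual device handling that plain PyTorch code would need instead. The layer and batch sizes are only for illustration.

```python
import torch
import torch.nn as nn

# Pick the device by hand, which is what accelerator="auto" spares us from
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The model must be moved to the device explicitly...
model = nn.Linear(30, 1).to(device)

# ...and so must every batch before the forward pass
batch = torch.randn(32, 30).to(device)
logits = model(batch)
print(logits.device.type, logits.shape)
```

With Lightning, none of these `.to(device)` calls appear in the module code, which is why the same module runs unchanged on different hardware.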

### 2.2.4 Fitting the model

Now we can simply call the `fit` method of the trainer to start training our model. It takes the model, the training data loader, and the validation data loader as arguments.

In [None]:
trainer.fit(
    BreastCancerModule(learning_rate=0.001),
    train_dataloaders=dataloader_train,
    val_dataloaders=dataloader_validation,
)

While we had to take care of logging the progress ourselves in the previous tutorial, PyTorch Lightning takes care of this for us.

## 2.3 Logging with Weights & Biases

Weights & Biases (wandb) allows us to log and visualize training and validation metrics across different training runs. It also has a feature for hyperparameter tuning, called Sweeps, which allows us to automatically search for the best hyperparameters for our model.

To use wandb, first go to the [wandb website](https://wandb.ai/) and create an account. You can easily sign in with an existing GitHub, Google, or Microsoft account. After signing in, you will be taken to the dashboard, where you can create a new project. You can also create a new API key, which you will need to use wandb in your code.



In [None]:
import wandb
wandb.init(project="breast-cancer-classification")

Now we can add the wandb logger to our PyTorch Lightning trainer with the `logger` argument:

In [None]:
trainer = L.Trainer(
    max_epochs=10,
    accelerator="auto",
    logger=L.pytorch.loggers.WandbLogger(
        project="breast-cancer-classification",
        log_model=True,
    ))

trainer.fit(
    BreastCancerModule(learning_rate=0.001),
    train_dataloaders=dataloader_train,
    val_dataloaders=dataloader_validation,
)

wandb.finish()  # Finish the wandb run

Go to the run URL as logged by wandb and check out the results. You should see a new run in the project you named in the code. You can click on it to see the details of the run, including the training and validation metrics, the model checkpoints, and the hyperparameters used for the run.

## 2.4 Performing a hyperparameter sweep

Now that we have set up wandb, we can use it to perform a hyperparameter sweep. This will allow us to automatically search for the best hyperparameters for our model, such as the number of layers, the number of neurons per layer, and the learning rate.

To do this, we must first update the model to accept hyperparameters as arguments. To make the number of layers and the number of neurons per layer configurable, we will implement a loop in the `__init__` method that creates the layers based on the hyperparameters.

In [None]:
class BreastCancerClassifier(nn.Module):
    def __init__(self, hidden_layers=1, hidden_neurons=10):
        super().__init__()
        self.layers = nn.ModuleList()
        in_features = 30
        # One Linear layer per hidden layer; hidden_layers=0 gives a purely linear model
        for _ in range(hidden_layers):
            self.layers.append(nn.Linear(in_features, hidden_neurons))
            in_features = hidden_neurons
        # Output layer producing a single logit
        self.layers.append(nn.Linear(in_features, 1))
        self.activation = nn.ReLU()

    def forward(self, x):
        # Iterate over all layers except the last one and add activations
        for layer in self.layers[:-1]:
            x = layer(x)
            x = self.activation(x)
        # Last layer without activation
        x = self.layers[-1](x)
        return x
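
To get a feeling for how these two hyperparameters change the size of the model, the sketch below counts the trainable parameters for a few settings with one or more hidden layers. `build_layers` and `count_parameters` are hypothetical helpers written for this illustration; they mirror the loop-based construction above rather than reuse the class itself.

```python
import torch.nn as nn


def build_layers(n_inputs, hidden_layers, hidden_neurons):
    """Hypothetical helper: a stack of Linear layers ending in a single output."""
    sizes = [n_inputs] + [hidden_neurons] * hidden_layers + [1]
    return nn.ModuleList(
        [nn.Linear(n_in, n_out) for n_in, n_out in zip(sizes[:-1], sizes[1:])]
    )


def count_parameters(layers):
    """Total number of weights and biases in all layers."""
    return sum(p.numel() for p in layers.parameters())


for hidden_layers, hidden_neurons in [(1, 8), (2, 8), (2, 32)]:
    layers = build_layers(30, hidden_layers, hidden_neurons)
    print(hidden_layers, hidden_neurons, count_parameters(layers))
```

Each `nn.Linear(n_in, n_out)` contributes `n_in * n_out` weights plus `n_out` biases, so the number of neurons per layer affects the model size much more than the number of layers does.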

We must also slightly modify the Lightning module to accept the hyperparameters as arguments and pass them on to the model and the optimizer.

In [None]:
class BreastCancerModule(L.LightningModule):
    def __init__(self, learning_rate=0.001, hidden_layers=1, hidden_neurons=10):
        super().__init__()
        self.save_hyperparameters()
        self.model = BreastCancerClassifier(hidden_layers, hidden_neurons)
        self.loss_fn = nn.BCEWithLogitsLoss()
        self.learning_rate = learning_rate
        self.optimizer = torch.optim.Adam

    def forward(self, x):
        return self.model(x).squeeze(-1)  # drop only the last dimension so a batch of one keeps its shape

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.log('val_loss', loss, prog_bar=True)

        return loss

    def configure_optimizers(self):
        return self.optimizer(self.parameters(), lr=self.learning_rate)

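The following function reads the hyperparameters chosen by the sweep from `wandb.config` and sets up the training run. Note that `wandb.agent` calls this function without arguments, so the hyperparameters must come from the run configuration:

In [None]:
def sweep():
    # Start a run; inside a sweep agent, wandb.config holds the sampled hyperparameters
    wandb.init(project="breast-cancer-classification")
    config = wandb.config

    # Create the trainer; the logger reuses the run started above
    trainer = L.Trainer(
        max_epochs=10,
        accelerator="auto",
        logger=L.pytorch.loggers.WandbLogger(
            project="breast-cancer-classification",
            log_model=True,
        ))

    trainer.fit(
        BreastCancerModule(
            learning_rate=config.learning_rate,
            hidden_layers=config.hidden_layers,
            hidden_neurons=config.hidden_neurons,
        ),
        train_dataloaders=dataloader_train,
        val_dataloaders=dataloader_validation,
    )
    wandb.finish()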

Now, we can create a sweep configuration. This configuration defines the hyperparameters we want to search over and the values we want to try. Here, `method` selects the search strategy, `metric` names the logged value to optimize, and `parameters` lists the search space. For all options, check out the [wandb documentation](https://docs.wandb.ai/guides/sweeps).

In [None]:
sweep_configuration = {
    "method": "random",
    "metric": {"goal": "minimize", "name": "val_loss"},
    "parameters": {
        "hidden_layers": {
            "values": [0, 1, 2]
        },
        "hidden_neurons": {
            "values": [2, 4, 8, 16, 32]
        },
        "learning_rate": {
            "min": 0.0001,
            "max": 0.01
        },
    },
}
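
To build some intuition for `"method": "random"`, the sketch below draws a few configurations from a search space like the one above, using only the Python standard library. It is an illustration, not wandb's implementation: it assumes that list-valued parameters are picked uniformly from `values` and that `min`/`max` ranges are sampled uniformly, and `sample_configuration` is a hypothetical helper.

```python
import random

search_space = {
    "hidden_layers": {"values": [0, 1, 2]},
    "hidden_neurons": {"values": [2, 4, 8, 16, 32]},
    "learning_rate": {"min": 0.0001, "max": 0.01},
}


def sample_configuration(space, rng):
    """Draw one value per hyperparameter from the search space."""
    config = {}
    for name, spec in space.items():
        if "values" in spec:
            # Discrete choice among the listed values
            config[name] = rng.choice(spec["values"])
        else:
            # Continuous value between min and max
            config[name] = rng.uniform(spec["min"], spec["max"])
    return config


rng = random.Random(0)
samples = [sample_configuration(search_space, rng) for _ in range(3)]
for sample in samples:
    print(sample)
```

Each sweep run corresponds to one such draw, so the number of runs determines how densely the search space is covered.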

Here, we initialize the sweep with its configuration:

In [None]:
sweep_id = wandb.sweep(sweep=sweep_configuration, project="breast-cancer-classification")

With the `wandb.agent` function, we can start the sweep. Note that the `count` argument defines how many runs we want to perform. If it were not set, the random search would keep testing new hyperparameter combinations until we stop it manually.

In [None]:
wandb.agent(sweep_id, function=sweep, count=20)

*Assignment: Go to the wandb dashboard and check out the results of the sweep. You should see a new sweep in the project you named in the code. You can click on it to see the details of the sweep, including the training and validation metrics, the model checkpoints, and the hyperparameters used for each run.*