# Learn PyTorch Lightning: A Lightweight PyTorch Wrapper

PyTorch Lightning is a lightweight PyTorch wrapper that organizes your PyTorch code to make it more readable and maintainable. It abstracts away much of the boilerplate, letting you focus on the core logic of your models. If you already have plain PyTorch code, the [conversion guide](https://lightning.ai/docs/pytorch/stable/starter/converting.html) explains how to port it. This tutorial walks you through the basics of PyTorch Lightning.

## 1. Introduction to PyTorch Lightning

PyTorch Lightning separates research code (the model and its loss) from engineering code (the training loop, device placement, checkpointing), which makes projects easier to read and to scale. Because it automates the training loop and other common plumbing, results are easier to reproduce and projects are easier to grow.
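To make "automates the training loop" concrete, here is a minimal sketch of the loop you would otherwise write by hand in plain PyTorch. It uses a hypothetical random dataset with MNIST-like shapes so it runs without downloads; the comments mark which parts correspond to Lightning hooks and which parts the `Trainer` takes over.

```python
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical stand-in data: 256 random "images" shaped like MNIST, labels 0-9
dataset = TensorDataset(torch.randn(256, 1, 28, 28), torch.randint(0, 10, (256,)))
loader = DataLoader(dataset, batch_size=32, shuffle=True)

model = nn.Linear(28 * 28, 10)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # ~ configure_optimizers

losses = []
for epoch in range(3):                            # epoch loop (Trainer's job)
    for x, y in loader:                           # batch loop (Trainer's job)
        logits = model(x.view(x.size(0), -1))     # ~ forward
        loss = F.cross_entropy(logits, y)         # ~ training_step
        optimizer.zero_grad()                     # gradient bookkeeping that
        loss.backward()                           # the Trainer performs for you
        optimizer.step()
        losses.append(loss.item())

print(f"{len(losses)} optimizer steps, last loss {losses[-1]:.3f}")
```

In the Lightning version, only the forward pass, the loss, and the optimizer choice stay in your code; the loops and the `zero_grad`/`backward`/`step` calls move into the `Trainer`.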

## 2. Installing PyTorch Lightning

Before you begin, make sure PyTorch is installed. Then install PyTorch Lightning via pip, together with `torchvision`, which the MNIST example below uses:

```bash
pip install pytorch-lightning torchvision
```

## 3. Creating a Lightning Module

A `LightningModule` is where you define your model, just as you would a standard PyTorch `nn.Module`, but it also holds the training step, the validation step, and the optimizer configuration. Here's a simple example:

```python
import pytorch_lightning as pl
import torch
from torch import nn
from torch.nn import functional as F


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(28 * 28, 10)

    def forward(self, x):
        return self.layer(x.view(x.size(0), -1))

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        return loss

    # Uncomment to add validation step
    # def validation_step(self, batch, batch_idx):
    #     x, y = batch
    #     logits = self(x)
    #     loss = F.cross_entropy(logits, y)
    #     # Add logging
    #     self.log('val_loss', loss)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)


model = LitModel()
```
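`forward` flattens each batch of 1×28×28 images into 784-dimensional vectors before the linear layer, and `training_step` feeds the resulting logits to `F.cross_entropy`. The sketch below traces those shapes with a random batch standing in for MNIST (an assumption, so no download is needed), using the same layer configuration as `LitModel`.

```python
import torch
from torch import nn
from torch.nn import functional as F

torch.manual_seed(0)

layer = nn.Linear(28 * 28, 10)         # same layer as in LitModel
x = torch.randn(32, 1, 28, 28)         # random batch with MNIST image shape
y = torch.randint(0, 10, (32,))        # random class labels 0-9

flat = x.view(x.size(0), -1)           # what forward() does: (32, 1, 28, 28) -> (32, 784)
logits = layer(flat)                   # (32, 10): one score per class
loss = F.cross_entropy(logits, y)      # scalar loss, as returned by training_step

print(flat.shape, logits.shape, loss.shape)
```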

## 4. Data Preparation

PyTorch Lightning works with the standard PyTorch DataLoader. Let's load the MNIST dataset as an example:

```python
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Download MNIST into the current directory and convert images to tensors
train_loader = DataLoader(
    datasets.MNIST(".", train=True, download=True, transform=transforms.ToTensor()), batch_size=32, shuffle=True
)

# Check one batch: images are (32, 1, 28, 28), labels are (32,)
x, y = next(iter(train_loader))
print(x.shape, y.shape)
```
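MNIST ships only a training split and a test split, so the validation and checkpointing steps later in this tutorial need a validation set carved out of the training data. One common way is `torch.utils.data.random_split`. The sketch below uses a hypothetical random `TensorDataset` in place of MNIST so it runs offline; with the real data you would pass the `datasets.MNIST(...)` object instead and split its 60,000 images, for example into 55,000 and 5,000. The loader names are chosen so they don't overwrite `train_loader`.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Hypothetical stand-in for datasets.MNIST(..., train=True): 1,000 random samples
full_train = TensorDataset(torch.randn(1000, 1, 28, 28), torch.randint(0, 10, (1000,)))

# Hold out 10% for validation; a seeded generator makes the split reproducible
train_set, val_set = random_split(full_train, [900, 100], generator=torch.Generator().manual_seed(42))

train_split_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(val_set, batch_size=32)  # validation data needs no shuffling

print(len(train_set), len(val_set))
```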

## 5. Training the Model

Training a model with PyTorch Lightning is straightforward. You just need to initialize a `Trainer` and call the `fit` method:

```python
# Initialize our model
model = LitModel()
print(model)

# Initialize a trainer
trainer = pl.Trainer(max_epochs=3)

# Train the model
trainer.fit(model, train_loader)
```

## 6. Validation and Testing

You can easily add validation and test steps to your `LitModel`. For validation, implement the `validation_step` method:

```python
class LitModel(pl.LightningModule):
    # ...
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        # Add logging
        self.log('val_loss', loss)
```

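The `test_step` method follows the same pattern. Note that validation only runs if you also pass a validation `DataLoader` to `fit` through its `val_dataloaders` argument, and test steps only run when you call `trainer.test(model, dataloaders=...)`.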

## 7. Logging and Callbacks

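PyTorch Lightning has built-in support for logging and callbacks. Values you record with `self.log` are sent to the Trainer's logger, which can be TensorBoard or another backend such as MLflow or Comet. Callbacks such as `ModelCheckpoint` hook into the training process; the checkpoint callback below monitors `val_loss`, so it only works once `validation_step` logs that metric and a validation loader is passed to `fit`.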

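```python
from pytorch_lightning.callbacks import ModelCheckpoint

# Keep the checkpoint with the lowest val_loss; this metric must be logged
# via self.log("val_loss", ...) in validation_step for the callback to work
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss", dirpath="./my_model", filename="sample-mnist-{epoch:02d}-{val_loss:.2f}"
)

trainer = pl.Trainer(max_epochs=3, callbacks=[checkpoint_callback])
```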

## 8. Advanced Features

PyTorch Lightning also supports distributed training, mixed-precision training, and more. You enable these features through `Trainer` arguments (make sure the required hardware and software are installed). For example, the following trains on two GPUs with 16-bit mixed precision:

```python
# Lightning 2.x syntax; the older gpus= argument was removed in 2.0
trainer = pl.Trainer(accelerator="gpu", devices=2, precision="16-mixed")
```

If you don't know how many GPUs are available, set `devices=-1` and PyTorch Lightning will use all available GPUs.

For more options on the `Trainer`, check out the [documentation](https://lightning.ai/docs/pytorch/stable/common/trainer.html).

## Conclusion

PyTorch Lightning is a powerful tool for organizing PyTorch code and making it easier to maintain and scale. It abstracts away the engineering details so you can focus on the research. This tutorial covered the basics, but there is much more to explore, including distributed training and integrations with other tools and libraries. Be sure to check out the [official documentation](https://www.pytorchlightning.ai/) for more information.