<a href="https://colab.research.google.com/github/arkeodev/pytorch-tutorial/blob/main/pytorch_lightning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Lightning

## Introduction

Lightning is a high-level framework that builds on top of PyTorch, one of the most popular deep learning libraries. It's designed to decouple the science code from the engineering code, helping researchers and developers focus on the core aspects of their models by abstracting away the boilerplate code typically associated with model training, validation, and testing. This approach not only makes the code more readable and maintainable but also significantly speeds up the development process for complex deep learning projects.

## Main Advantages of Using Lightning Over Plain PyTorch

1. **Reduced Boilerplate Code**: Lightning automates much of the setup code needed in PyTorch, such as training loops, validation loops, and testing loops, allowing developers to focus on the model's architecture and data rather than the mechanics of the training process.

2. **Reproducibility**: It makes experiments easier to reproduce by standardizing how models are trained. The structured framework encourages best practices and reduces the chances of making errors.

3. **Scalability**: Lightning simplifies the process of scaling your models to run on more GPUs, TPUs, or across multiple nodes. This makes it easier to scale your experiments without having to deeply understand distributed computing.

4. **Flexibility**: Despite the high-level abstractions, Lightning offers flexibility, allowing advanced users to customize the training loop and other components when needed. This means you can start with the simple, high-level interface and dive deeper as your project's complexity grows.

5. **Built-in Advanced Features**: Lightning comes with many advanced features out of the box, such as support for mixed precision training, which can significantly speed up computations and reduce memory usage, and automatic checkpointing, which makes it easy to save and resume training sessions.

6. **Community and Ecosystem**: Lightning has a vibrant and growing community, with a wide range of plugins and integrations available. This ecosystem includes support for popular tools and platforms, making it easier to incorporate things like logging, monitoring, and model serving into your workflow.

In summary, Lightning is designed to make deep learning projects simpler, faster, and more efficient, without sacrificing the power and flexibility that PyTorch provides. By abstracting away the engineering details, it enables researchers and developers to allocate more time to the scientific aspects of their projects, resulting in faster experimentation and development cycles.

## Imports

In [None]:
! pip install lightning -q

In [None]:
import random

import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from torchvision.datasets import MNIST

import lightning as L

## Core Concepts of Lightning

### LightningModule

The `LightningModule` is the central concept in Lightning. It is a subclass of PyTorch's `nn.Module`, so it can be used anywhere a regular module can, but it also gathers the forward pass and the training, validation, and testing steps into a single class. This structure separates the model's computations from the experimental setup (devices, loops, logging), which the `Trainer` handles.

A `LightningModule` defines:
- **Model Architecture**: How the inputs are processed to produce outputs, encapsulated in the `forward` method.
- **Training Step**: The logic for a single iteration of the training loop: the forward pass and the loss calculation. Lightning uses the returned loss to run backpropagation and the optimizer step.
- **Validation and Testing Steps**: Procedures for evaluating the model on validation and test datasets, used to monitor performance and detect overfitting.
- **Optimizers and Schedulers**: Configuration of optimizers and learning rate schedulers, specifying how weights are updated and how the learning rate changes over time.

By integrating these aspects into a unified class, `LightningModule` streamlines model development, making the code more modular and easier to read and maintain, while promoting best practices in deep learning research and development.

### Trainer

The `Trainer` in Lightning is a powerful engine that abstracts the complexity of writing the training loop and integrates your PyTorch code with the rich ecosystem of Lightning features. It is responsible for managing the training process, including running the training, validation, and testing loops, handling device placement (CPU, GPU, TPU), and facilitating distributed training.

Key features of the `Trainer` include:
- **Automatic Training Loop**: It automates the training process, managing everything from the start of training to its conclusion, including calling the appropriate steps defined in the `LightningModule`.
- **Checkpointing**: Automatically saves the model's state to checkpoints during training and can resume from a checkpoint (`trainer.fit(..., ckpt_path=...)`), so long experiments can be paused and restarted without losing progress.
- **Logging and Monitoring**: Integrates with popular logging and visualization tools (e.g., TensorBoard, MLflow), enabling easy tracking of experiments and model performance.
- **Distributed Training**: Simplifies scaling up your training to multiple GPUs, TPUs, or nodes without the need to deeply understand the underlying distributed computing frameworks.

### DataModule

The `DataModule` is a data handling class that abstracts the complexity of data loading, preparation, and preprocessing in Lightning. It allows for a clean separation of data-related logic from the modeling code, making datasets reusable and shareable across projects.

A `DataModule` encapsulates the five steps involved in data processing in PyTorch:
- Download, tokenize, or otherwise process the raw data.
- Clean and (optionally) save it to disk.
- Load it inside a `Dataset`.
- Apply transforms (rotate, tokenize, etc.).
- Wrap it inside a `DataLoader`.

By encapsulating data-related tasks, the `DataModule` promotes a more organized and modular approach to handling datasets in PyTorch projects, making it easier to adapt to new data sources or experiment with different preprocessing techniques.

In [None]:
class MNISTDataModule(L.LightningDataModule):
    def __init__(self, data_dir: str = "./"):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: str):
        # Assign train/val datasets for use in dataloaders ("validate" needs the val split too)
        if stage in ("fit", "validate"):
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
            )

        # Assign test dataset for use in dataloader(s)
        if stage == "test":
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

        if stage == "predict":
            self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        # Shuffle training data every epoch
        return DataLoader(self.mnist_train, batch_size=32, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=32)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=32)

    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=32)

In Lightning, the data-related methods of a `LightningDataModule` are called in a specific order when you start training. Knowing this order helps you decide which work belongs in which method. For the `MNISTDataModule` above, the `Trainer` calls them as follows:

1. **`prepare_data`:** Called first, from a single process per node (local rank 0) rather than from every process, which prevents each process from downloading the same data (set `prepare_data_per_node = False` to run it only once globally). Use it for one-time work such as downloading or tokenizing. Because other processes do not run it, it should not assign state (for example, `self.mnist_train = ...`); such assignments would be missing in the other processes.

2. **`setup`:** Called after `prepare_data`, once in every process: on multiple GPUs or nodes, each process runs it separately. This is where you load datasets, split them, and attach transforms. The `Trainer` passes a `stage` argument (`'fit'`, `'validate'`, `'test'`, or `'predict'`), so you can build only what the current phase needs, as the `MNISTDataModule` above does.

3. **DataLoader methods (`train_dataloader`, `val_dataloader`, `test_dataloader`, `predict_dataloader`):** Called after `setup`, right before the corresponding phase begins. Each returns the PyTorch `DataLoader` used for that phase. By default, the training dataloader is requested once per `fit` call rather than once per epoch; the `Trainer` argument `reload_dataloaders_every_n_epochs` makes Lightning rebuild it periodically.

## Key Features of Lightning


### 1. Automation of Common Training Procedures

#### Without Lightning


In standard PyTorch, setting up a training loop requires manually coding the entire process, including forward passes, calculating loss, backpropagation, and updating model parameters.

In [None]:
# Define a simple neural network
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.layer1 = nn.Linear(28 * 28, 128)
        self.layer2 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = nn.functional.relu(self.layer1(x))
        x = self.layer2(x)
        return x

# Load data
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_data = datasets.MNIST('data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)

# Initialize network and optimizer
model = SimpleNet()
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Training loop
for epoch in range(10): # train for 10 epochs
    for images, labels in train_loader:
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
    # Note: this prints the loss of the last batch in the epoch, not the epoch average
    print(f"Epoch {epoch}, Loss: {loss.item()}")

Epoch 0, Loss: 0.14844422042369843
Epoch 1, Loss: 0.1093560978770256
Epoch 2, Loss: 0.10502023249864578
Epoch 3, Loss: 0.19714826345443726
Epoch 4, Loss: 0.010195491835474968
Epoch 5, Loss: 0.13798223435878754
Epoch 6, Loss: 0.02335546910762787
Epoch 7, Loss: 0.23562780022621155
Epoch 8, Loss: 0.004139881581068039
Epoch 9, Loss: 0.005697409622371197


#### With Lightning

With Lightning, you define the training step and the optimizer; the `Trainer` runs the loop, including the backward pass and the optimizer step.

In [None]:
class LightningNet(L.LightningModule):
    def __init__(self):
        super(LightningNet, self).__init__()
        self.layer1 = nn.Linear(28 * 28, 128)
        self.layer2 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = nn.functional.relu(self.layer1(x))
        x = self.layer2(x)
        return x

    def training_step(self, batch, batch_idx):
        images, labels = batch
        output = self(images)
        loss = nn.CrossEntropyLoss()(output, labels)
        return loss

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=0.001)

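# Data module for organizing data loading (redefines the MNISTDataModule above with only a training split)
class MNISTDataModule(L.LightningDataModule):
    def prepare_data(self):
        # Download once; Lightning calls this from a single process
        datasets.MNIST('data', train=True, download=True)

    def train_dataloader(self):
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
        train_data = datasets.MNIST('data', train=True, transform=transform)
        return DataLoader(train_data, batch_size=32, shuffle=True)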

# Setup model and data module
model = LightningNet()
mnist_data = MNISTDataModule()

# Train model
trainer = L.Trainer(max_epochs=10)
trainer.fit(model, mnist_data)


INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: 
  | Name   | Type   | Params
----------------------------------
0 | layer1 | Linear | 100 K 
1 | layer2 | Linear | 1.3 K 
----------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=10` reached.


#### Comparison

- **Automation and Simplification**: With Lightning, the training process is simplified as you only need to define the `training_step` and the `configure_optimizers` method. The `Trainer` object then automates the training loop, including forward and backward passes, optimization, and more.

- **Reduction of Boilerplate Code**: Lightning significantly reduces the amount of boilerplate code required, especially in training loops. This makes the code more readable and easier to maintain.

- **Focus on Model and Data**: Lightning encourages a separation of concerns, allowing you to focus on the model (`LightningModule`) and data (`DataModule`) separately. This results in cleaner, more modular code that's easier to debug and extend.

- **Built-in Best Practices**: Lightning builds in practices such as logging and checkpointing by default and turns others into a single `Trainer` argument (for example, gradient clipping via `gradient_clip_val`), reducing the chance of common mistakes and improving the efficiency of the development process.

### 2. Reproducibility


#### Without Lightning



When not using Lightning, you would manually set seeds for all the relevant libraries you are using to ensure reproducibility. This often includes PyTorch itself, NumPy (if used for data manipulation), and Python’s random module.

In [None]:
def seed_everything(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # For CUDA-enabled GPUs
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(42)

#### With Lightning


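Lightning provides `L.seed_everything`, which seeds Python's `random` module, NumPy, and PyTorch (including CUDA devices) in a single call. Unlike the manual version above, it does not change the cuDNN settings; for deterministic kernels, also pass `deterministic=True` to the `Trainer`.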

In [None]:
L.seed_everything(42)

INFO: Seed set to 42


42

### 3. Easy Experiment Tracking

#### Without Lightning

Tracking experiments without Lightning often requires manually logging metrics, model configurations, and other details to a file, a database, or a tool like TensorBoard. This can quickly become cumbersome and error-prone as the complexity of the experiments grows.

```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()

for epoch in range(10):  # Example training loop
    # Training steps
    # ...
    writer.add_scalar('Loss/train', loss, epoch)
    writer.add_scalar('Accuracy/train', accuracy, epoch)

writer.close()
```

#### With Lightning

Lightning simplifies experiment tracking and integrates with popular tools like TensorBoard, MLflow, Comet ML, and others. You pass a logger to the `Trainer`, and any value recorded with `self.log(...)` inside the `LightningModule` is sent to it, so metrics and hyperparameters are tracked without cluttering your model code.

```python
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import TensorBoardLogger

# Define a logger; event files go to tb_logs/my_model/version_<n>/
logger = TensorBoardLogger("tb_logs", name="my_model")

# Train model with logger; metrics reach it through self.log(...) calls in the LightningModule
trainer = Trainer(logger=logger, max_epochs=10)
trainer.fit(model, mnist_data)

# Directory holding this run's logs (view with: tensorboard --logdir tb_logs)
print("Log directory:", logger.log_dir)
```



### 4. Out-of-the-Box Support for Multi-GPU, TPU, and Distributed Training

#### Without Lightning

Implementing multi-GPU, TPU, or distributed training in PyTorch requires a deep understanding of the underlying mechanisms like `torch.nn.DataParallel`, `torch.nn.parallel.DistributedDataParallel`, or setting up TPU environments. This process can be complex and error-prone, especially for those new to distributed computing concepts.

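```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

def setup(rank, world_size):
    # All processes rendezvous at the same address and port
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    # NCCL is the usual backend for GPUs; gloo also works on CPU
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend, rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(28 * 28, 10)

    def forward(self, x):
        return self.layer(torch.flatten(x, 1))

def train(rank, world_size):
    setup(rank, world_size)
    use_cuda = torch.cuda.is_available()
    device = torch.device(f"cuda:{rank}" if use_cuda else "cpu")
    if use_cuda:
        torch.cuda.set_device(device)
    model = MyModel().to(device)
    # device_ids must be None for CPU modules
    model = nn.parallel.DistributedDataParallel(model, device_ids=[rank] if use_cuda else None)
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Placeholder random data; DistributedSampler gives each process its own shard
    dataset = TensorDataset(torch.randn(256, 1, 28, 28), torch.randint(0, 10, (256,)))
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(dataset, batch_size=32, sampler=sampler)

    for epoch in range(2):
        sampler.set_epoch(epoch)  # reshuffle differently every epoch
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = nn.functional.cross_entropy(model(images), labels)
            loss.backward()  # DDP all-reduces gradients across processes here
            optimizer.step()
    cleanup()

if __name__ == "__main__":
    world_size = 2  # one process per GPU (needs 2 GPUs, or falls back to CPU + gloo)
    torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size, join=True)
```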

#### With Lightning

Lightning simplifies the process, allowing you to easily scale your models across multiple GPUs, TPUs, or nodes with minimal changes to your code. The `Trainer` class handles the complexity of distributed training.

DDP communication hooks are an interface for controlling how gradients are communicated across workers, overriding the standard all-reduce in `DistributedDataParallel`. They can improve performance when communication is the bottleneck, typically in multi-node training. For example, the FP16 compression hook casts gradients to half precision before the all-reduce, halving the data sent, which can improve multi-node throughput:


```python
import lightning as L
from lightning.pytorch.strategies import DDPStrategy
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default

# The model must be a LightningModule, e.g. LightningNet defined earlier
model = LightningNet()
trainer = L.Trainer(
    accelerator="gpu",
    devices=4,  # Number of GPUs
    strategy=DDPStrategy(ddp_comm_hook=default.fp16_compress_hook),
)

trainer.fit(model, mnist_data)
```

### 5. Integration with Popular Computing Resources and Environments

#### Without Lightning

Setting up your PyTorch models to run on cloud platforms (e.g., AWS, GCP, Azure) or specialized computing environments (e.g., HPC clusters) often requires a significant amount of boilerplate code and configuration. This includes managing environments, dependencies, data storage, and compute resources.

#### With Lightning

Lightning detects common cluster launchers automatically, including SLURM, TorchElastic (`torchrun`), Kubeflow, and LSF, and reads the process ranks and world size from them. The same script and `Trainer` configuration (for example `num_nodes` and `devices`) can therefore run on a workstation, a cloud VM, or an HPC cluster with little or no change.

### 6. Automatic Checkpointing and Resuming of Training


#### Without Lightning

Implementing automatic checkpointing and resuming in plain PyTorch requires manually saving the model and optimizer state at specific intervals and then writing additional code to load these checkpoints if training is interrupted or needs to be resumed later.

```python
import os

import torch
from torch import nn, optim

# Save checkpoint
def save_checkpoint(model, optimizer, epoch, filepath):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, filepath)

# Load checkpoint
def load_checkpoint(filepath, model, optimizer):
    checkpoint = torch.load(filepath)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # The saved epoch already finished, so resume from the next one
    return checkpoint['epoch'] + 1

# Example usage (placeholder model and optimizer)
model = nn.Linear(28 * 28, 10)
optimizer = optim.Adam(model.parameters(), lr=0.001)
num_epochs = 10
ckpt_path = 'path_to_checkpoint.pt'

# Resume if a checkpoint exists, otherwise start from scratch
epoch_start = load_checkpoint(ckpt_path, model, optimizer) if os.path.exists(ckpt_path) else 0

for epoch in range(epoch_start, num_epochs):
    # Training loop...
    save_checkpoint(model, optimizer, epoch, ckpt_path)
```

#### With Lightning

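Lightning simplifies checkpointing and resuming. By default, the `Trainer` saves a checkpoint of the latest epoch containing the model weights, optimizer and scheduler states, and loop progress. The `ModelCheckpoint` callback lets you control when and what to save, for example keeping only the best models according to a monitored metric. To resume, pass the checkpoint path to `trainer.fit(..., ckpt_path=...)`; the older `resume_from_checkpoint` argument of the `Trainer` was removed in Lightning 2.0.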



```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

# Setup model checkpointing
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss', # Requires the LightningModule to call self.log("val_loss", ...)
    dirpath='my_model/',
    filename='model-{epoch:02d}-{val_loss:.2f}',
    save_top_k=3, # Save the top 3 models according to val_loss
    mode='min', # Minimize val_loss
)

# Initialize trainer with checkpoint callback
trainer = Trainer(callbacks=[checkpoint_callback], max_epochs=10)

# Train model; ckpt_path is optional and resumes from a specific checkpoint
model = ...  # a LightningModule that logs "val_loss" and provides its dataloaders
trainer.fit(model, ckpt_path='my_model/model-epoch=02-val_loss=0.02.ckpt')
```

### 7. Built-in Support for Mixed-Precision Training


#### Without Lightning


In vanilla PyTorch, enabling mixed-precision training involves using `torch.cuda.amp` for automatic mixed precision (AMP), which can reduce memory usage and speed up training times on compatible hardware. This requires manual management of the AMP context.


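```python
import torch
from torch.cuda.amp import GradScaler, autocast

model = ...        # your nn.Module, already moved to the GPU
optimizer = ...    # e.g. torch.optim.Adam(model.parameters())
data_loader = ...  # your DataLoader
loss_fn = torch.nn.CrossEntropyLoss()
scaler = GradScaler()  # scales the loss so small fp16 gradients do not underflow

for inputs, labels in data_loader:
    inputs, labels = inputs.cuda(), labels.cuda()
    optimizer.zero_grad()
    with autocast():  # run the forward pass and loss in mixed precision
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
    scaler.scale(loss).backward()  # backpropagate the scaled loss
    scaler.step(optimizer)  # unscales gradients; skips the step if they contain inf/NaN
    scaler.update()  # adjust the scale factor for the next iteration
```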


#### With Lightning


Lightning offers a straightforward way to enable mixed-precision training with a single flag in the `Trainer`. This automatically handles the AMP context and scaler under the hood.

```python
from lightning.pytorch import Trainer

# Enable 16-bit mixed precision (use "bf16-mixed" on hardware that supports bfloat16)
trainer = Trainer(precision="16-mixed")

# Train model
model = ...
trainer.fit(model)
```

## Setting Up a Basic Lightning Project




Setting up a basic project with Lightning involves a few straightforward steps that help streamline the development process for deep learning models. Here's a guide to get you started:

#### 1. Install Lightning

In [None]:
# ! pip install lightning -q

#### 2. Define Your Model

   Create a Python file (e.g., `model.py`) and define your model by subclassing `L.LightningModule`.
   
   Implement the required methods such as `__init__`, `forward`, `training_step`, and `configure_optimizers`.

In [None]:
class LitModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        # Define model layers
        self.layer = nn.Linear(28 * 28, 10)

    def forward(self, x):
        # Forward pass: return raw logits (cross_entropy applies log-softmax itself)
        x = x.view(x.size(0), -1)  # Reshape input to [batch_size, 784]
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        # Training logic
        x, y = batch
        y_hat = self(x)
        loss = nn.functional.cross_entropy(y_hat, y)
        return loss

    def configure_optimizers(self):
        # Optimizers
        return torch.optim.Adam(self.parameters(), lr=0.02)

#### 3. Prepare Your Data

   Define your data using PyTorch's `DataLoader` or Lightning's `DataModule`.
   
   A `DataModule` is a shareable, reusable class that encapsulates all data loading logic.

In [None]:
class MNISTDataModule(L.LightningDataModule):
    def __init__(self):
        super().__init__()
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

    def prepare_data(self):
        # Download only once
        MNIST('data', train=True, download=True)
        MNIST('data', train=False, download=True)

    def setup(self, stage=None):
        # Transform and split data; a fixed generator keeps the split identical across runs and processes
        mnist_full = MNIST('data', train=True, transform=self.transform)
        self.mnist_train, self.mnist_val = random_split(
            mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
        )
        self.mnist_test = MNIST('data', train=False, transform=self.transform)

    def train_dataloader(self):
        # Shuffle training data each epoch
        return DataLoader(self.mnist_train, batch_size=32, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=32)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=32)

#### 4. Train Your Model

   Use the `Trainer` class to train your model. Specify any configurations like the number of epochs, GPUs, etc.

In [None]:
from lightning.pytorch import Trainer

model = LitModel()
mnist_data = MNISTDataModule()
trainer = Trainer(max_epochs=2)
trainer.fit(model, mnist_data)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/configuration_validator.py:72: You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.
INFO: 
  | Name  | Type   | Params
---------------------------------
0 | layer | Linear | 7.9 K 
---------------------------------
7.9 K     Trainable params
0         Non-trainable params
7.9 K     Total params
0.031     Total estimated model params size (MB)

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=2` reached.


## Conclusion

Lightning streamlines deep learning development by abstracting boilerplate code, enforcing best practices, and reducing complexity. It scales training across multiple GPUs, TPUs, or nodes through `Trainer` arguments, supports reproducibility with utilities like `seed_everything`, and remains flexible for custom needs. Its active community and rich ecosystem provide extensive resources and support, making Lightning a powerful tool for efficient and reliable deep learning projects.

## Resources and Further Reading

For comprehensive information and resources on Lightning, here are the key places to look:

- **Documentation and Tutorials**: The [official documentation](https://lightning.ai/docs/pytorch/stable/) is a comprehensive resource for getting started with Lightning, offering detailed guides, API references, and tutorials for users of all levels.


