# Tutorial: From PyTorch to PyTorch Lightning 🚀

This notebook walks through converting a standard PyTorch project to PyTorch Lightning, using the **Fashion MNIST** dataset.

## Introduction to PyTorch Lightning

PyTorch is a fantastic and flexible library for building deep learning models. However, as projects grow, you often end up writing the same boilerplate again and again: training loops, validation loops, device management (`.to(device)`), and so on.

**PyTorch Lightning** is a lightweight wrapper on top of PyTorch that organizes your code and automates the surrounding engineering, such as the training, validation, and test loops. This lets you focus on the research and modeling parts.

**Key Benefits:**
* **Organized Code:** Your code is neatly structured into a `LightningModule` (the model and logic) and a `LightningDataModule` (the data handling).
* **Boilerplate Free:** Lightning handles the training, validation, and testing loops for you. No more manual `for` loops!
* **Hardware Agnostic:** The same code runs on CPUs, GPUs, or TPUs with a simple flag change (e.g., `accelerator='gpu', devices=4`).
* **Reproducibility:** Helpers such as `pl.seed_everything()` and the `Trainer(deterministic=True)` flag make experiments easier to reproduce.

In this tutorial, we will:
1.  Build and train a simple CNN on the **Fashion MNIST** dataset using **standard PyTorch**.
2.  Refactor the same project using **PyTorch Lightning**, highlighting the changes and benefits.

--- 
## Part 1: The Standard PyTorch Way

First, let's build and train our model using only plain PyTorch. This will be our baseline.

### 1.1 Imports and Setup

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets
from torchvision.transforms import ToTensor
import torch.nn.functional as F
import torch.optim as optim
import time

# Set device: CUDA GPU first, then Apple Silicon GPU (MPS), then CPU
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

### 1.2 The Model

Here we define a simple Convolutional Neural Network (CNN).

In [None]:
# A simple CNN model
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        # Convolutional layers
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        # Pooling layer
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Fully connected layers
        # Input features = 64 channels * 7 * 7: the padded 3x3 convs keep the
        # spatial size, and each 2x2 pool halves it (28 -> 14 -> 7)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10) # 10 output classes for Fashion MNIST

    def forward(self, x):
        # Apply conv1 -> ReLU -> pool
        x = self.pool(F.relu(self.conv1(x)))
        # Apply conv2 -> ReLU -> pool
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten the image tensor for the fully connected layers
        x = x.view(-1, 64 * 7 * 7)
        # Apply fc1 -> ReLU
        x = F.relu(self.fc1(x))
        # Apply output layer
        x = self.fc2(x)
        return x
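The `64 * 7 * 7` in `fc1` comes from the standard output-size formula for convolution and pooling layers (with dilation 1): `out = (in + 2 * padding - kernel) // stride + 1`. A minimal pure-Python sketch that traces a 28x28 image through the layers above; the helper name `conv2d_out` is hypothetical:

```python
def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a Conv2d/MaxPool2d layer (dilation=1, floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 28                                      # Fashion MNIST images are 28x28
size = conv2d_out(size, kernel=3, padding=1)   # conv1 -> 28
size = conv2d_out(size, kernel=2, stride=2)    # pool  -> 14
size = conv2d_out(size, kernel=3, padding=1)   # conv2 -> 14
size = conv2d_out(size, kernel=2, stride=2)    # pool  -> 7
flat_features = 64 * size * size               # what fc1 must accept
print(size, flat_features)                     # 7 3136
```

If you change the input resolution or add a layer, rerunning this arithmetic (or passing a dummy tensor through the model) is a quick way to catch a mismatched `nn.Linear` size.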

### 1.3 Data Loading

In [None]:
# --- Data Loading ---

# Download training data from open datasets.
training_data_full = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

# Split the full training data into a smaller training set and a validation set
train_size = int(0.8 * len(training_data_full))
val_size = len(training_data_full) - train_size
training_data, val_data = random_split(training_data_full, [train_size, val_size])

# --- Create DataLoaders ---
batch_size = 64
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)
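The bookkeeping in this cell is simple arithmetic. Fashion MNIST has 60,000 training and 10,000 test images, so the 80/20 split gives 48,000 / 12,000. Because `DataLoader` keeps the final partial batch by default (`drop_last=False`), `len(dataloader)` is the ceiling of samples / batch size, which is what the training loop below divides its summed loss by. A small sketch with hypothetical helper names:

```python
import math


def split_sizes(n: int, train_frac: float = 0.8) -> tuple[int, int]:
    """Train/val sizes computed the same way as the cell above."""
    train = int(train_frac * n)
    return train, n - train


def num_batches(n: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of batches a DataLoader yields for n samples."""
    return n // batch_size if drop_last else math.ceil(n / batch_size)


train_size, val_size = split_sizes(60_000)  # Fashion MNIST training split
print(train_size, val_size)                 # 48000 12000
print(num_batches(train_size, 64))          # 750
print(num_batches(val_size, 64))            # 188 (the last batch holds 32 samples)
print(num_batches(10_000, 64))              # 157 for the test set
```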

### 1.4 The Training Loop (The Boilerplate!)
This is the core part that PyTorch Lightning will automate for us. Notice how much manual work is involved: moving data to the device, zeroing gradients, backpropagation, and manually looping through epochs and batches.

In [None]:
# --- Instantiate Model, Loss, and Optimizer ---
model = SimpleCNN().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# --- Training and Validation Loop ---
num_epochs = 5

for epoch in range(num_epochs):
    start_time = time.time()
    
    # --- Training ---
    model.train() # Set model to training mode
    train_loss = 0.0
    for X, y in train_dataloader:
        # Move data to the selected device
        X, y = X.to(device), y.to(device)

        # 1. Forward pass
        pred = model(X)
        loss = loss_fn(pred, y)
        train_loss += loss.item()

        # 2. Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    avg_train_loss = train_loss / len(train_dataloader)

    # --- Validation ---
    model.eval() # Set model to evaluation mode
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad(): # No need to calculate gradients during validation
        for X, y in val_dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            val_loss += loss_fn(pred, y).item()
            
            # Calculate accuracy
            predicted = pred.argmax(dim=1)  # index of the highest logit per sample
            total += y.size(0)
            correct += (predicted == y).sum().item()
            
    avg_val_loss = val_loss / len(val_dataloader)
    accuracy = 100 * correct / total
    
    end_time = time.time()
    epoch_mins, epoch_secs = divmod(end_time - start_time, 60)

    print(f'Epoch: {epoch+1:02} | Time: {int(epoch_mins)}m {epoch_secs:.2f}s')
    print(f'\tTrain Loss: {avg_train_loss:.3f}')
    print(f'\t Val. Loss: {avg_val_loss:.3f} |  Val. Accuracy: {accuracy:.2f}%')

This works, but it is a lot of code to maintain. Adding multi-GPU training, logging, or gradient clipping would make this section grow even further.
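One way to trim the plain-PyTorch version is to factor the validation block into a reusable helper; the same helper could then score `test_dataloader`, which Part 1 builds but never uses. The sketch below uses a hypothetical `evaluate` function and checks it on a toy identity "model" whose inputs are already confident logits, so it needs no training or download:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate(model: nn.Module, dataloader: DataLoader, loss_fn, device: str = "cpu"):
    """Return (average per-batch loss, accuracy in %) over a dataloader."""
    model.eval()  # evaluation mode (affects dropout / batch norm)
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():  # no autograd graph needed for evaluation
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            logits = model(X)
            total_loss += loss_fn(logits, y).item()
            correct += (logits.argmax(dim=1) == y).sum().item()
            total += y.size(0)
    return total_loss / len(dataloader), 100.0 * correct / total


# Toy check: the "inputs" are logits with a large value on the true class.
y = torch.arange(10).repeat(5)                    # 50 labels covering 10 classes
X = 10.0 * nn.functional.one_hot(y, 10).float()   # confident, correct logits
loader = DataLoader(TensorDataset(X, y), batch_size=16)
avg_loss, acc = evaluate(nn.Identity(), loader, nn.CrossEntropyLoss())
print(f"loss={avg_loss:.4f} acc={acc:.1f}%")
```

In the notebook, a call such as `evaluate(model, test_dataloader, loss_fn, device)` after training would report test-set performance.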

--- 
## Part 2: The PyTorch Lightning Way ✨

Now, let's refactor the code above using PyTorch Lightning. We will separate the data logic from the model logic.

### 2.1 Imports and Setup

We add imports for `pytorch_lightning` and `torchmetrics`. `torchmetrics` provides metric objects that accumulate state across batches (and across devices in distributed runs) and plug directly into Lightning's `self.log()`.

First, let's install them if they aren't already available.

In [None]:
!pip install pytorch-lightning torchmetrics --quiet

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets
from torchvision.transforms import ToTensor
import torch.nn.functional as F

# NEW: Import PyTorch Lightning and TorchMetrics
import pytorch_lightning as pl
import torchmetrics

# For reproducibility
pl.seed_everything(42)
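`pl.seed_everything(42)` seeds the global Python, NumPy, and PyTorch random number generators. For the train/validation split specifically, passing a dedicated `torch.Generator` to `random_split` keeps the split fixed no matter what else has consumed global randomness. A minimal sketch, with a hypothetical `seeded_split` helper splitting plain indices:

```python
import torch
from torch.utils.data import random_split


def seeded_split(n: int, seed: int = 42):
    """Return (train_indices, val_indices) for an 80/20 split of range(n)."""
    n_train = int(0.8 * n)
    gen = torch.Generator().manual_seed(seed)  # local RNG, independent of global state
    train, val = random_split(range(n), [n_train, n - n_train], generator=gen)
    return train.indices, val.indices


a_train, a_val = seeded_split(100)
torch.rand(5)  # consume some global RNG state; the split is unaffected
b_train, b_val = seeded_split(100)
print(a_train == b_train, a_val == b_val)  # True True
```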

### 2.2 The `LightningDataModule` (Organizing Data)
A `LightningDataModule` is a shareable, reusable class that encapsulates all the steps needed to process data.

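* `prepare_data()`: Download and save data to disk here. Lightning calls it from a **single process** (by default, once per node) so the data isn't downloaded multiple times. Don't assign state here (e.g. `self.x = ...`), because other processes won't see it.
* `setup(stage)`: Split, transform, and create the PyTorch `Dataset` objects here. It's called on **every GPU/process**, and `stage` (`"fit"`, `"validate"`, `"test"`, or `"predict"`) tells you which datasets to build.
* `train_dataloader()`, `val_dataloader()`, `test_dataloader()`: These methods return the actual `DataLoader` instances.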

In [None]:
class FashionMNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "./data", batch_size: int = 64):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = ToTensor()

    # Called from a single process (by default, once per node).
    # Use it for downloading the dataset; don't assign state here.
    def prepare_data(self):
        datasets.FashionMNIST(self.data_dir, train=True, download=True)
        datasets.FashionMNIST(self.data_dir, train=False, download=True)

    # This is called on every process (GPU).
    # Use it to split data, create datasets, etc.
    def setup(self, stage=None):
        if stage == "fit" or stage is None:
            # Load the full training dataset
            fashion_mnist_full = datasets.FashionMNIST(
                self.data_dir, train=True, transform=self.transform
            )
            # Split it 80/20 (48,000 / 12,000) to match Part 1; the seeded
            # generator gives every process the same split
            self.fashion_mnist_train, self.fashion_mnist_val = random_split(
                fashion_mnist_full,
                [48000, 12000],
                generator=torch.Generator().manual_seed(42),
            )
        if stage == "test" or stage is None:
            self.fashion_mnist_test = datasets.FashionMNIST(
                self.data_dir, train=False, transform=self.transform
            )

    # Return the training dataloader
    def train_dataloader(self):
        return DataLoader(
            self.fashion_mnist_train,
            batch_size=self.batch_size,
            shuffle=True,
        )

    # Return the validation dataloader
    def val_dataloader(self):
        return DataLoader(self.fashion_mnist_val, batch_size=self.batch_size)

    # Return the test dataloader
    def test_dataloader(self):
        return DataLoader(self.fashion_mnist_test, batch_size=self.batch_size)

### 2.3 The `LightningModule` (Organizing the Model & Logic)

A `LightningModule` is still an `nn.Module`, extended with hooks for training and validation logic. It combines the model architecture, the optimizer, and the training/validation steps into one clean class.

* `__init__()`: Define your model layers, loss function, and metrics here. We use `torchmetrics.Accuracy`.
* `forward()`: The standard forward pass, just like in `nn.Module`.
* `training_step()`: This replaces the body of the inner training loop. It receives a batch, performs a forward pass, and returns the loss. Lightning handles the rest (zeroing gradients, backpropagation, and the optimizer step).
* `validation_step()`: Replaces the body of the inner validation loop. Lightning switches the model to evaluation mode and disables gradients for you; you log your metrics here using `self.log()`.
* `configure_optimizers()`: You define your optimizer here and return it.

In [None]:
class LitFashionMNIST(pl.LightningModule):
    def __init__(self, learning_rate=1e-3):
        super().__init__()
        # Save hyperparameters like learning_rate
        self.save_hyperparameters()
        
        # --- Model Architecture (same as before) ---
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

        # --- Loss Function ---
        self.loss_fn = nn.CrossEntropyLoss()

        # --- Metrics ---
        # We use TorchMetrics to calculate accuracy.
        # 'task' is set to 'multiclass' and we specify the number of classes.
        self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=10)

    # The forward pass for inference
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 7 * 7)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    # This is the training loop.
    # It is called for each batch in the training dataloader.
    def training_step(self, batch, batch_idx):
        X, y = batch
        pred = self(X)  # Same as self.forward(X)
        loss = self.loss_fn(pred, y)

        # Use self.log to record the training loss
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss # This is required

    # This is the validation loop.
    def validation_step(self, batch, batch_idx):
        X, y = batch
        pred = self(X)
        loss = self.loss_fn(pred, y)
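        # Update the metric's running state with this batch
        self.accuracy(pred, y)

        # Log the validation loss, and log the Metric object itself so Lightning
        # computes accuracy over the whole epoch and resets it afterwards
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", self.accuracy, prog_bar=True)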

    # Define the optimizer
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        return optimizer

### 2.4 The `Trainer` (Running the Show)

Finally, the `pl.Trainer` object takes your `LightningModule` and `LightningDataModule` and handles the entire training process.

You just need to tell it how many epochs to run and what hardware to use.

In [None]:
# --- Instantiate DataModule and Model ---
datamodule = FashionMNISTDataModule()
model = LitFashionMNIST()

# --- Instantiate the Trainer ---
# The Trainer automates the training, validation, and testing loops.
# You can specify the number of epochs, accelerator (cpu, gpu, tpu), devices, etc.
trainer = pl.Trainer(
    max_epochs=5,
    accelerator="auto", # Automatically selects GPU if available
    devices="auto"
)

# --- Start Training ---
# The .fit() method takes the model and the datamodule.
# It calls the datamodule's prepare_data() and setup(), requests the
# dataloaders, and runs the training/validation steps.
trainer.fit(model, datamodule)

And that's it! With just these few lines, you get a full training loop with progress bars, logging, and hardware acceleration, all handled for you.

--- 
## Comparison and Conclusion

| Task                 | Standard PyTorch                          | PyTorch Lightning                                                 |
| :------------------- | :---------------------------------------- | :---------------------------------------------------------------- |
| **Model Definition** | `nn.Module` class                         | `pl.LightningModule` class                                        |
| **Data Handling** | Manual `Dataset` and `DataLoader` creation | `pl.LightningDataModule` class                                    |
| **Training Loop** | Manual `for` loop over epochs and batches   | Handled by `Trainer.fit()` and `training_step`                      |
| **Validation Loop** | Manual `for` loop with `torch.no_grad()`    | Handled by `Trainer.fit()` and `validation_step`                    |
| **Optimizer Logic** | Manual `.zero_grad()`, `.backward()`, `.step()` | Handled by `Trainer`; you just define it in `configure_optimizers`  |
| **Device Management**| Manual `.to(device)` calls                | Handled automatically by the `Trainer` (`accelerator` flag)         |
| **Metrics** | Manual calculation                        | Integrated via `torchmetrics` and `self.log()`                      |

By refactoring to PyTorch Lightning, you've made your code cleaner, more organized, and ready to scale to more complex models and hardware without changing the core logic. 