# Streamlining Deep Learning with PyTorch Lightning ⚡

---

This notebook introduces **PyTorch Lightning (PL)**, a high-level framework built on top of PyTorch. PL is not covered in detail in Raschka's book, but it is widely used and provides a **structured, reusable, and less verbose way** to organize the training and evaluation steps learned in Chapter 13.

The notebook shows how to move boilerplate code (device transfers, the training loop, and validation handling) into a single, clean class.

### 1. The `pl.LightningModule` Paradigm 🏗️

The core of the notebook is defining a custom model class that inherits from `pl.LightningModule` (e.g., `MLPClassifier`). This class organizes all the necessary PyTorch components:

* **`__init__(self, ...)`:** Defines the network layers (`nn.Linear`) and metrics (like `torchmetrics.Accuracy`).
* **`forward(self, x)`:** Defines the prediction logic (the standard forward pass).
* **`training_step(self, batch, batch_idx)`:** Defines the logic for a single training batch: it computes and returns the loss and logs metrics. Lightning then runs the backward pass and optimizer step automatically.
* **`validation_step(self, batch, batch_idx)`:** Defines the logic for evaluation, automatically run periodically by the Trainer.
* **`configure_optimizers(self)`:** Defines the optimizer (`optim.Adam`) and, optionally, a learning rate scheduler.
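
To make the division of labor concrete, the sketch below shows roughly what happens for each batch in plain PyTorch. The loss computation is the part you write in `training_step`; the surrounding `zero_grad`, `backward`, and `step` calls are what the Trainer issues for you under automatic optimization. The synthetic data, layer sizes, and variable names are assumptions chosen for illustration, and this is a simplification rather than Lightning's actual internals.

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(1)

# Hypothetical stand-ins for a real model and one batch of data.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
x = torch.randn(64, 1, 28, 28)   # a batch of 64 fake "images"
y = torch.randint(0, 10, (64,))  # fake class labels


def training_step(batch):
    """The part you write yourself: forward pass and loss."""
    inputs, targets = batch
    logits = model(inputs)
    return F.cross_entropy(logits, targets)


# The part the Trainer runs around every training_step call.
weights_before = model[1].weight.detach().clone()
optimizer.zero_grad()
loss = training_step((x, y))
loss.backward()
optimizer.step()

weights_changed = not torch.equal(weights_before, model[1].weight)
print(f"loss={loss.item():.4f}, weights changed: {weights_changed}")
```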

### 2. Data Organization with `pl.LightningDataModule`

The notebook groups data handling into a `pl.LightningDataModule` subclass:

* It uses **`torchvision.datasets.MNIST`** for data loading, splitting the training data into training and validation sets (55,000 / 5,000) via `random_split`.
* It sets up the **`DataLoader`**s for batching, which the `Trainer` obtains from the data module passed to `trainer.fit`.
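
A small sketch of the split-and-batch pattern on synthetic data, so it runs without downloading MNIST. The `TensorDataset` of random tensors and the 90/10 sizes are assumptions for illustration; the seeded `torch.Generator` makes the split reproducible.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Synthetic stand-in for MNIST: 100 fake 1x28x28 images with labels 0-9.
features = torch.randn(100, 1, 28, 28)
labels = torch.randint(0, 10, (100,))
dataset = TensorDataset(features, labels)

# A seeded generator makes the split identical on every run.
generator = torch.Generator().manual_seed(1)
train_set, val_set = random_split(dataset, [90, 10], generator=generator)

train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(val_set, batch_size=32)

x_batch, y_batch = next(iter(train_loader))
print(len(train_set), len(val_set), tuple(x_batch.shape))
```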

### 3. The Central `pl.Trainer`

PyTorch Lightning replaces the manual training loop from Chapter 13 with a single, highly configurable `Trainer` object:

* **Initialization:** The `Trainer` is initialized with parameters such as `max_epochs`, `accelerator` and `devices` (for GPU/CPU selection), and optionally `logger` (for tracking results).
* **Training Execution:** A single call to **`trainer.fit(model, datamodule=data_module)`** runs the entire training procedure:
    * Manages epochs, batches, and device transfers (CPU/GPU).
    * Handles checkpointing and logging.
    * Automatically runs validation and saves checkpoints.

### 4. Visualization with TensorBoard 📊

* The notebook leverages the built-in integration of **TensorBoard** (`%tensorboard --logdir lightning_logs/`).
* This demonstrates how PyTorch Lightning writes the metrics passed to `self.log` (such as train/validation loss and accuracy) to a standard directory, so model performance can be monitored and debugged visually without any additional logging setup.

This notebook showcases how PyTorch Lightning significantly reduces complexity, making PyTorch code more scalable and easier to share.

In [1]:
import pytorch_lightning as pl
from torch import nn, optim
import torch
from torchmetrics import Accuracy
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision import transforms
from torchvision.datasets import MNIST

In [7]:
import pytorch_lightning as pl
import torch
import torch.nn as nn
from torchmetrics.classification import Accuracy
import torch.nn.functional as F

class MLPClassifier(pl.LightningModule):
    # Parameter names must match the names used in the body (image_size, hidden_layers).
    def __init__(self, image_size=(1, 28, 28), hidden_layers=(32, 16), lr=0.001):
        super().__init__()
        self.save_hyperparameters()

        # Metrics
        self.train_acc = Accuracy(task="multiclass", num_classes=10)
        self.valid_acc = Accuracy(task="multiclass", num_classes=10)
        self.test_acc = Accuracy(task="multiclass", num_classes=10)

        # Model
        input_size = image_size[0] * image_size[1] * image_size[2]
        layers = [nn.Flatten()]
        for hidden_layer in hidden_layers:
            layers += [nn.Linear(input_size, hidden_layer)]
            input_size = hidden_layer
        layers.append(nn.Linear(hidden_layers[-1], 10))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

    def step(self, batch, stage):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = getattr(self, f"{stage}_acc")(preds, y)
        self.log(f"{stage}_loss", loss, prog_bar=True)
        self.log(f"{stage}_acc", acc, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self.step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self.step(batch, "valid")

    def test_step(self, batch, batch_idx):
        return self.step(batch, "test")

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)


In [8]:
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, batch_size= 64, data_path= './'):
        super().__init__()
        self.batch_size = batch_size
        self.data_path = data_path
        self.transform = transforms.Compose([transforms.ToTensor()])
        
    def prepare_data(self):
        MNIST(root= self.data_path, train= True, download= True)
        MNIST(root= self.data_path, train= False, download= True)
        
    def setup(self, stage= None):
        mnist_full = MNIST(root= self.data_path, 
                           train= True, 
                           transform= self.transform, 
                           download= False)
        self.train, self.val = random_split(mnist_full, [55000, 5000])
        self.test = MNIST(root= self.data_path, 
                          train= False, 
                          transform= self.transform,
                          download= False)
        
    def train_dataloader(self):
        return DataLoader(self.train, batch_size= self.batch_size, shuffle= True)
    
    def val_dataloader(self):
        return DataLoader(self.val, batch_size= self.batch_size)
    
    def test_dataloader(self):
        return DataLoader(self.test, batch_size= self.batch_size)

In [34]:
class MultiLayerPerceptron(pl.LightningModule):
    
    def __init__(self, image_size= (1, 28, 28), hidden_layers= (32, 16)):
        super().__init__()
        
        self.train_acc = Accuracy(task= 'multiclass', num_classes= 10)
        self.val_acc = Accuracy(task= 'multiclass', num_classes= 10)
        self.test_acc = Accuracy(task= 'multiclass', num_classes= 10)
        
        input_size = image_size[0] * image_size[1] * image_size[2]
        layers = [nn.Flatten()]
        for hidden_layer in hidden_layers:
            layers += [nn.Linear(input_size, hidden_layer)]
            input_size = hidden_layer
        layers.append(nn.Linear(hidden_layers[-1], 10))
        self.model = nn.Sequential(*layers)
        
    def forward(self, x):
        return self.model(x)
        
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y)  # reuse logits instead of a second forward pass
        preds = torch.argmax(logits, dim= 1)
        self.train_acc.update(preds, y)
        self.log('train_loss', loss, prog_bar= True)
        
        return loss
    
    def on_train_epoch_end(self):
        self.log('train_acc', self.train_acc.compute(), prog_bar= True)
        self.train_acc.reset()
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y)  # reuse logits instead of a second forward pass
        preds = torch.argmax(logits, dim= 1)
        self.val_acc.update(preds, y)
        self.log('val_loss', loss, prog_bar= True)
        
        return loss
    
    def on_validation_epoch_end(self):
        self.log('val_acc', self.val_acc.compute(), prog_bar= True)
        self.val_acc.reset()
    
    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y)  # reuse logits instead of a second forward pass
        preds = torch.argmax(logits, dim= 1)
        self.test_acc.update(preds, y)
        self.log('test_loss', loss, prog_bar= True)
        
        return loss
    
    def on_test_epoch_end(self):
        self.log('test_acc', self.test_acc.compute(), prog_bar= True)
        self.test_acc.reset()
    
    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr= 0.001)
        
        return optimizer

In [35]:
class MnistDataModule(pl.LightningDataModule):
    def __init__(self, data_path= './'):
        super().__init__()
        self.data_path = data_path
        self.transform = transforms.Compose([transforms.ToTensor()])
    
    def prepare_data(self):
        MNIST(root= self.data_path, download= True)
    
    def setup(self, stage= None):
        mnist_all = MNIST(root= self.data_path, 
                          train= True, 
                          transform= self.transform, 
                          download= False)
        self.train, self.val = random_split(mnist_all, [55000, 5000], generator= torch.Generator().manual_seed(1))
        
        self.test = MNIST(root= self.data_path, 
                          train= False, 
                          transform= self.transform, 
                          download= False)
    def train_dataloader(self):
        return DataLoader(self.train, batch_size= 64, num_workers= 4, persistent_workers= True)
    
    def val_dataloader(self):
        return DataLoader(self.val, batch_size= 64, num_workers= 4, persistent_workers= True)
    
    def test_dataloader(self):
        return DataLoader(self.test, batch_size= 64, num_workers= 4, persistent_workers= True)

In [36]:
if __name__ == '__main__':
    torch.manual_seed(1)
    data_module = MnistDataModule()
    model = MultiLayerPerceptron()
    trainer = pl.Trainer(
        max_epochs= 10,
        accelerator= 'auto',
        devices= 'auto',
        deterministic= True
    )

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [37]:
trainer.fit(model, datamodule= data_module)


  | Name      | Type               | Params | Mode 
---------------------------------------------------------
0 | train_acc | MulticlassAccuracy | 0      | train
1 | val_acc   | MulticlassAccuracy | 0      | train
2 | test_acc  | MulticlassAccuracy | 0      | train
3 | model     | Sequential         | 25.8 K | train
---------------------------------------------------------
25.8 K    Trainable params
0         Non-trainable params
25.8 K    Total params
0.103     Total estimated model params size (MB)
8         Modules in train mode
0         Modules in eval mode
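
The 25.8 K figure in the summary can be checked by hand: each `nn.Linear(in, out)` layer holds `in * out` weights plus `out` biases. A minimal sketch that rebuilds the same layer stack (784 → 32 → 16 → 10, taken from the class defaults) and counts its parameters:

```python
from torch import nn

# Same layer sizes as the default MultiLayerPerceptron: 784 -> 32 -> 16 -> 10.
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(1 * 28 * 28, 32),
    nn.Linear(32, 16),
    nn.Linear(16, 10),
)

by_formula = (784 * 32 + 32) + (32 * 16 + 16) + (16 * 10 + 10)
by_counting = sum(p.numel() for p in model.parameters())
print(by_formula, by_counting)  # 25818 25818
```

At 4 bytes per float32 parameter, 25,818 parameters come to about 0.103 MB, which matches the size estimate in the summary.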


Sanity Checking: |                                                                               | 0/? [00:00<…

Training: |                                                                                      | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

`Trainer.fit` stopped: `max_epochs=10` reached.


In [40]:
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

In [42]:
trainer = pl.Trainer(
    max_epochs= 15,
    accelerator= 'auto',
    devices= 'auto',
    deterministic= True
)
trainer.fit(model, datamodule= data_module, ckpt_path= './lightning_logs/version_3/checkpoints/epoch=9-step=8600.ckpt')

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Restoring states from the checkpoint path at ./lightning_logs/version_3/checkpoints/epoch=9-step=8600.ckpt
C:\Users\98922\anaconda3\lib\site-packages\pytorch_lightning\callbacks\model_checkpoint.py:366: The dirpath has changed from 'C:\\Users\\98922\\Documents\\python_scripts\\AI\\pytorch_raschka\\chap13\\lightning_logs\\version_3\\checkpoints' to 'C:\\Users\\98922\\Documents\\python_scripts\\AI\\pytorch_raschka\\chap13\\lightning_logs\\version_4\\checkpoints', therefore `best_model_score`, `kth_best_model_path`, `kth_value`, `last_model_path` and `best_k_models` won't be reloaded. Only `best_model_path` will be reloaded.

  | Name      | Type               | Params | Mode

Sanity Checking: |                                                                               | 0/? [00:00<…

Training: |                                                                                      | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

`Trainer.fit` stopped: `max_epochs=15` reached.
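
The step count in the checkpoint name `epoch=9-step=8600.ckpt` follows from the data module settings: 55,000 training samples in batches of 64 give 860 optimizer steps per epoch (the last batch is partial), and 10 epochs give 8,600 steps. Resuming with `max_epochs=15` therefore trains for 5 more epochs. A quick arithmetic check:

```python
import math

train_samples = 55_000  # size of the training split
batch_size = 64
epochs_done = 10        # epochs are numbered from 0, so "epoch=9" is the tenth

steps_per_epoch = math.ceil(train_samples / batch_size)  # last batch is partial
total_steps = steps_per_epoch * epochs_done
remaining_epochs = 15 - epochs_done

print(steps_per_epoch, total_steps, remaining_epochs)  # 860 8600 5
```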


In [45]:
%reload_ext tensorboard
%tensorboard --logdir lightning_logs/

Reusing TensorBoard on port 6006 (pid 16496), started 0:57:34 ago. (Use '!kill 16496' to kill it.)