# Organizing code with PyTorch Lightning

In [None]:
# Install PyTorch Lightning
#!pip install lightning
#!pip install torchmetrics

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataset import random_split
from torchvision import datasets, transforms
import lightning as L
import torchmetrics

# transforms.ToTensor() converts each image into a tensor and scales the values between 0 and 1
mnist_root = "./mnist"
train_ds = datasets.MNIST(mnist_root, train=True, transform=transforms.ToTensor(), download=True)
test_ds = datasets.MNIST(mnist_root, train=False, transform=transforms.ToTensor(), download=True)

# Carve a 5,000-example validation set out of the training data.
# The seed is set right before the split so the partition is repeatable.
torch.manual_seed(2056)
train_ds, val_ds = random_split(train_ds, [55000, 5000])

# Dataloaders feed the model in batches of 64; only the training loader shuffles
loader_kwargs = {"batch_size": 64}
train_dl = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_dl = DataLoader(val_ds, shuffle=False, **loader_kwargs)
test_dl = DataLoader(test_ds, shuffle=False, **loader_kwargs)

# Create a model
class PyTorchMLP(nn.Module):
    """Multilayer perceptron with two hidden layers (50 and 25 units) and ReLU activations."""

    def __init__(self, num_features, num_classes):
        """Build the layer stack.

        Args:
            num_features: Size of each flattened input example.
            num_classes: Number of logits produced by the output layer.
        """
        super().__init__()

        hidden_1, hidden_2 = 50, 25
        layers = [
            nn.Linear(num_features, hidden_1),  # first hidden layer
            nn.ReLU(),
            nn.Linear(hidden_1, hidden_2),  # second hidden layer
            nn.ReLU(),
            nn.Linear(hidden_2, num_classes),  # output layer
        ]
        self.all_layers = nn.Sequential(*layers)

    def forward(self, x):
        """Return the logits for a batch, flattening every dimension except the batch one."""
        flat = x.flatten(start_dim=1)
        return self.all_layers(flat)
    
# Utility function to compute accuracy
def compute_accuracy(model, dataloader):
    """Compute the classification accuracy of ``model`` over ``dataloader``.

    Args:
        model: Module that maps a batch of features to per-class logits.
        dataloader: Iterable yielding ``(features, targets)`` batches.

    Returns:
        A 0-dim tensor holding the fraction of examples whose arg-max
        prediction equals the target.

    Raises:
        ValueError: If ``dataloader`` yields no examples.

    Note:
        The model is switched to eval mode and is left in that mode.
    """
    # Put the model in eval mode
    model = model.eval()

    # Evaluate on whichever device holds the model's weights so that CPU batches
    # do not clash with a model living elsewhere (None for parameter-free models).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    correct, total_examples = 0.0, 0

    with torch.inference_mode():
        for features, targets in dataloader:
            if device is not None:
                features, targets = features.to(device), targets.to(device)

            # Predicted class = index of the largest logit
            predicted_labels = torch.argmax(model(features), dim=1)

            correct += torch.sum((predicted_labels == targets).float())
            total_examples += len(targets)

    # Accuracy is undefined without examples; fail with a clear message
    # instead of an opaque division-by-zero error.
    if total_examples == 0:
        raise ValueError("Cannot compute accuracy: the dataloader yielded no examples.")

    return correct / total_examples

A LightningModule organizes your PyTorch code into 6 sections:

`Initialization (__init__ and setup())`.

`Train Loop (training_step())`

`Validation Loop (validation_step())`

`Test Loop (test_step())`

`Prediction Loop (predict_step())`

`Optimizers and LR Schedulers (configure_optimizers())`

In [12]:
import torchmetrics
# Define a LightningModule that receives Pytorch model as input
class LightningModel(L.LightningModule):
    """LightningModule that wraps a PyTorch classifier and defines its training and validation loops."""

    def __init__(self, model, learning_rate):
        """Store the wrapped model, the SGD learning rate, and the accuracy metrics."""
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate

        # Separate 10-class accuracy metrics for the training and validation loops
        metric_kwargs = {"task": "multiclass", "num_classes": 10}
        self.train_acc = torchmetrics.Accuracy(**metric_kwargs)
        self.val_acc = torchmetrics.Accuracy(**metric_kwargs)

    def forward(self, x):
        """Delegate the forward pass to the wrapped PyTorch model."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Process one training batch: log loss and accuracy, then return the loss."""
        inputs, targets = batch
        logits = self(inputs)
        batch_loss = F.cross_entropy(logits, targets)
        self.log("train_loss", batch_loss)

        # Update the metric with this batch; on_epoch=True / on_step=False asks
        # for the accuracy to be logged per epoch rather than per step
        self.train_acc(logits.argmax(dim=1), targets)
        self.log("train_acc", self.train_acc, prog_bar=True, on_epoch=True, on_step=False)

        # The returned loss is what gets passed on for the gradient computation and weight update
        return batch_loss

    def validation_step(self, batch, batch_idx):
        """Process one validation batch: log loss and accuracy; nothing is returned."""
        inputs, targets = batch
        logits = self(inputs)
        batch_loss = F.cross_entropy(logits, targets)
        self.log("val_loss", batch_loss, prog_bar=True)

        # No on_step/on_epoch flags are passed here, so Lightning's defaults
        # for validation_step apply
        self.val_acc(logits.argmax(dim=1), targets)
        self.log("val_acc", self.val_acc, prog_bar=True)

    def configure_optimizers(self):
        """Return a plain SGD optimizer over all parameters with the configured learning rate."""
        return torch.optim.SGD(self.parameters(), lr=self.learning_rate)

In [6]:
# Add a test loop
class LightningModel(L.LightningModule):
    """LightningModule that wraps a PyTorch classifier with training, validation, and test loops."""

    def __init__(self, model, learning_rate):
        """Store the wrapped model, the SGD learning rate, and one accuracy metric per loop."""
        super().__init__()

        self.model = model
        self.learning_rate = learning_rate

        # Separate 10-class accuracy metrics for the training, validation, and test loops
        metric_kwargs = {"task": "multiclass", "num_classes": 10}
        self.train_acc = torchmetrics.Accuracy(**metric_kwargs)
        self.val_acc = torchmetrics.Accuracy(**metric_kwargs)
        self.test_acc = torchmetrics.Accuracy(**metric_kwargs)

    def forward(self, x):
        """Delegate the forward pass to the wrapped PyTorch model."""
        return self.model(x)

    def _shared_step(self, batch):
        """Return ``(loss, predicted_labels, true_labels)`` for one batch; used by every loop."""
        inputs, targets = batch
        logits = self(inputs)
        return F.cross_entropy(logits, targets), logits.argmax(dim=1), targets

    def training_step(self, batch, batch_idx):
        """Process one training batch: log loss and accuracy, then return the loss."""
        batch_loss, preds, targets = self._shared_step(batch)

        self.log("train_loss", batch_loss)
        self.train_acc(preds, targets)
        # on_epoch=True / on_step=False asks for the accuracy to be logged per epoch
        self.log("train_acc", self.train_acc, prog_bar=True, on_epoch=True, on_step=False)

        # The returned loss is what gets passed on for the gradient computation and weight update
        return batch_loss

    def validation_step(self, batch, batch_idx):
        """Process one validation batch: log loss and accuracy; nothing is returned."""
        batch_loss, preds, targets = self._shared_step(batch)

        self.log("val_loss", batch_loss, prog_bar=True)
        self.val_acc(preds, targets)
        self.log("val_acc", self.val_acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Process one test batch: update and log the test accuracy (the loss is not logged)."""
        _, preds, targets = self._shared_step(batch)
        self.test_acc(preds, targets)
        self.log("test_acc", self.test_acc)

    def configure_optimizers(self):
        """Return a plain SGD optimizer over all parameters with the configured learning rate."""
        return torch.optim.SGD(self.parameters(), lr=self.learning_rate)

TRAINER
Once you’ve organized your PyTorch code into a LightningModule, the Trainer automates everything else.

The Trainer achieves the following:

>You maintain control over all aspects via PyTorch code in your LightningModule.

>The trainer uses best practices embedded by contributors and users from top AI labs such as Facebook AI Research, NYU, MIT, Stanford, etc…



In [None]:
# Wrap a freshly initialised MLP (28*28 input features, 10 classes) in the LightningModule
pytorch_model = PyTorchMLP(num_features=28 * 28, num_classes=10)
lightning_model = LightningModel(model=pytorch_model, learning_rate=0.05)

# Trainer options: "auto" uses a GPU / all GPUs when available
trainer_options = {
    "max_epochs": 10,
    "accelerator": "auto",
    "devices": "auto",
    # "deterministic": True,  # set to True to ensure reproducibility
}
trainer = L.Trainer(**trainer_options)

# Train on the training loader while validating on the held-out loader
trainer.fit(lightning_model, train_dataloaders=train_dl, val_dataloaders=val_dl)

# # Compute accuracy on test dataset
# train_acc = compute_accuracy(pytorch_model, train_dl)
# val_acc = compute_accuracy(pytorch_model, val_dl)
# test_acc = compute_accuracy(pytorch_model, test_dl)
# print(f"\nTrain accuracy: {train_acc*100:.2f}, \nValidation accuracy: {val_acc*100:.2f}, \nTest accuracy: {test_acc*100:.2f}")


In [10]:
# trainer.test() runs test_step over the given dataloader, so the accuracy is
# reported under the "test_acc" key whichever split is passed in.
def _test_step_accuracy(dataloader):
    """Return the "test_acc" entry of the first result dict from trainer.test()."""
    results = trainer.test(dataloaders=dataloader)
    return results[0]["test_acc"]


# Evaluation order: test set first, then the training and validation sets
test_acc = _test_step_accuracy(test_dl)
train_acc = _test_step_accuracy(train_dl)
val_acc = _test_step_accuracy(val_dl)

print(f"\nTrain accuracy: {train_acc*100:.2f}, \nValidation accuracy: {val_acc*100:.2f}, \nTest accuracy: {test_acc*100:.2f}")

  rank_zero_warn(
Restoring states from the checkpoint path at c:\Users\homeuser\Documents\deep_learning_fundamentals\lightning_logs\version_3\checkpoints\epoch=9-step=8600.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at c:\Users\homeuser\Documents\deep_learning_fundamentals\lightning_logs\version_3\checkpoints\epoch=9-step=8600.ckpt
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

Restoring states from the checkpoint path at c:\Users\homeuser\Documents\deep_learning_fundamentals\lightning_logs\version_3\checkpoints\epoch=9-step=8600.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at c:\Users\homeuser\Documents\deep_learning_fundamentals\lightning_logs\version_3\checkpoints\epoch=9-step=8600.ckpt
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

Restoring states from the checkpoint path at c:\Users\homeuser\Documents\deep_learning_fundamentals\lightning_logs\version_3\checkpoints\epoch=9-step=8600.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at c:\Users\homeuser\Documents\deep_learning_fundamentals\lightning_logs\version_3\checkpoints\epoch=9-step=8600.ckpt


Testing: 0it [00:00, ?it/s]


Train accuracy: 97.49, 
Validation accuracy: 96.68, 
Test accuracy: 96.82


In [None]:
# # Saving and loading models
# PATH = "lightning.pt"
# torch.save(pytorch_model.state_dict(), PATH)

# # Load the model
# model = PyTorchMLP(num_features = 28*28, num_classes = 10)
# model.load_state_dict(torch.load(PATH))
# model.eval()

In [4]:


# # Define a LightningModule that receives a PyTorch model as input
# class LightningModel(L.LightningModule): # Define a class that inherits from L.LightningModule

#     def __init__(self, model, learning_rate):
#         super().__init__()

#         self.model = model
#         self.learning_rate = learning_rate

#     def forward(self, x): 
#         return self.model(x)
    
#     def training_step(self, batch, batch_idx):
#         features, true_labels = batch
#         logits = self(features)
#         loss = F.cross_entropy(logits, true_labels)
#         self.log("train_loss", loss)
#         return loss # Return the loss which will be passed to the optimizer to zero out gradients, compute gradients, and update weights
    
#     def validation_step(self, batch, batch_idx):
#         features, true_labels = batch
#         logits = self(features)
#         loss = F.cross_entropy(logits, true_labels)
#         self.log("val_loss", loss, prog_bar=True)
#         # Doesn't return anything because we don't need to compute gradients during validation

#     def configure_optimizers(self):
#         optimizer = torch.optim.SGD(self.parameters(), lr = self.learning_rate)
#         return optimizer



## Adding data modules

LIGHTNINGDATAMODULE
A datamodule is a shareable, reusable class that encapsulates all the steps needed to process data.

A datamodule encapsulates the five steps involved in data processing in PyTorch:

Download / tokenize / process.

Clean and (maybe) save to disk.

Load inside Dataset.

Apply transforms (rotate, tokenize, etc…).

Wrap inside a DataLoader.

To define a datamodule, the following methods are used to create train/val/test/predict dataloaders:

`prepare_data()`: used to download, tokenize, etc… data (only called on 1 GPU/TPU in distributed).

`setup()`: used for data operations that you might want to perform on every GPU, e.g. counting the number of classes, building a vocabulary, performing train/val/test splits, or applying transforms.

`train_dataloader()`: returns the train dataloader. This is the dataloader that the trainer `fit` method uses

`val_dataloader()`: returns the val dataloader(s). This is the dataloader that the trainer `validate` and `fit` methods use.

`test_dataloader()`: returns the test dataloader(s). This is the dataloader that the trainer `test` method uses

`predict_dataloader()`: returns the predict dataloader(s). This is the dataloader that the trainer `predict` method uses

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataset import random_split
from torchvision import datasets, transforms
import lightning as L
import torchmetrics
from lightning.pytorch.loggers import CSVLogger

# Data modules as an optional organization layer
class MNISTDataModule(L.LightningDataModule):
    """Data module that downloads MNIST and serves train/val/test/predict dataloaders."""

    def __init__(self, data_dir = "./mnist", batch_size = 64, split_seed = 2056):
        """
        Args:
            data_dir: Directory the MNIST files are downloaded to and read from.
            batch_size: Batch size used by every dataloader.
            split_seed: Seed of the private generator that draws the
                train/validation split, so that every ``setup()`` call
                reproduces the same partition.
        """
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.split_seed = split_seed

    def prepare_data(self): # Download the dataset - only called on 1 GPU
        """Download both MNIST splits; nothing is stored on ``self`` here."""
        datasets.MNIST(self.data_dir, train = True, download = True)
        datasets.MNIST(self.data_dir, train = False, download = True)

    def setup(self, stage: str): # Split the dataset into training, validation, and test sets - called on every GPU
        """Create the train/val/test/predict datasets.

        All datasets are (re)built on every call, whatever ``stage`` is.
        """
        self.mnist_test = datasets.MNIST(
            self.data_dir, transform = transforms.ToTensor(), train = False
        )
        mnist_full = datasets.MNIST(
            self.data_dir, transform = transforms.ToTensor(), train = True
        )
        # Without an explicit generator random_split draws from the global RNG, so
        # each setup() call would produce a different train/validation partition
        # and examples used for training could later show up in the validation set.
        # A freshly seeded private generator makes the partition identical on every call.
        split_rng = torch.Generator().manual_seed(self.split_seed)
        self.mnist_train, self.mnist_val = random_split(
            mnist_full, [55000, 5000], generator = split_rng
        )
        self.mnist_predict = datasets.MNIST(
            self.data_dir, transform = transforms.ToTensor(), train = False
        )

    def train_dataloader(self):
        """Return the shuffled dataloader over the training split."""
        return DataLoader(dataset = self.mnist_train, batch_size = self.batch_size, shuffle = True)

    def val_dataloader(self):
        """Return the unshuffled dataloader over the validation split."""
        return DataLoader(dataset = self.mnist_val, batch_size = self.batch_size, shuffle = False)

    def test_dataloader(self):
        """Return the unshuffled dataloader over the MNIST test set."""
        return DataLoader(dataset = self.mnist_test, batch_size = self.batch_size, shuffle = False)

    def predict_dataloader(self):
        """Return the unshuffled dataloader over the prediction dataset (the MNIST test set)."""
        return DataLoader(dataset = self.mnist_predict, batch_size = self.batch_size, shuffle = False)
    

# Create a Pytorch model
class PyTorchMLP(nn.Module):
    """Multilayer perceptron with two hidden layers (50 and 25 units) and ReLU activations."""

    def __init__(self, num_features, num_classes):
        """Build the layer stack.

        Args:
            num_features: Size of each flattened input example.
            num_classes: Number of logits produced by the output layer.
        """
        super().__init__()

        hidden_1, hidden_2 = 50, 25
        layers = [
            nn.Linear(num_features, hidden_1),  # first hidden layer
            nn.ReLU(),
            nn.Linear(hidden_1, hidden_2),  # second hidden layer
            nn.ReLU(),
            nn.Linear(hidden_2, num_classes),  # output layer
        ]
        self.all_layers = nn.Sequential(*layers)

    def forward(self, x):
        """Return the logits for a batch, flattening every dimension except the batch one."""
        flat = x.flatten(start_dim=1)
        return self.all_layers(flat)

# Create a lightning module
class LightningModel(L.LightningModule):
    """LightningModule that wraps a PyTorch classifier with training, validation, and test loops."""

    def __init__(self, model, learning_rate):
        """Store the wrapped model and learning rate, record hyperparameters, and create the metrics."""
        super().__init__()

        self.model = model
        self.learning_rate = learning_rate

        # Record the constructor settings with the logs, leaving out the model object itself
        self.save_hyperparameters(ignore=["model"])

        # Separate 10-class accuracy metrics for the training, validation, and test loops
        metric_kwargs = {"task": "multiclass", "num_classes": 10}
        self.train_acc = torchmetrics.Accuracy(**metric_kwargs)
        self.val_acc = torchmetrics.Accuracy(**metric_kwargs)
        self.test_acc = torchmetrics.Accuracy(**metric_kwargs)

    def forward(self, x):
        """Delegate the forward pass to the wrapped PyTorch model."""
        return self.model(x)

    def _shared_step(self, batch):
        """Return ``(loss, predicted_labels, true_labels)`` for one batch; used by every loop."""
        inputs, targets = batch
        logits = self(inputs)
        return F.cross_entropy(logits, targets), logits.argmax(dim=1), targets

    def training_step(self, batch, batch_idx):
        """Process one training batch: log loss and accuracy, then return the loss."""
        batch_loss, preds, targets = self._shared_step(batch)

        self.log("train_loss", batch_loss)
        self.train_acc(preds, targets)
        # on_epoch=True / on_step=False asks for the accuracy to be logged per epoch
        self.log("train_acc", self.train_acc, prog_bar=True, on_epoch=True, on_step=False)

        # The returned loss is what gets passed on for the gradient computation and weight update
        return batch_loss

    def validation_step(self, batch, batch_idx):
        """Process one validation batch: log loss and accuracy; nothing is returned."""
        batch_loss, preds, targets = self._shared_step(batch)

        self.log("val_loss", batch_loss, prog_bar=True)
        self.val_acc(preds, targets)
        self.log("val_acc", self.val_acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Process one test batch: update and log the test accuracy (the loss is not logged)."""
        _, preds, targets = self._shared_step(batch)
        self.test_acc(preds, targets)
        self.log("test_acc", self.test_acc)

    def configure_optimizers(self):
        """Return a plain SGD optimizer over all parameters with the configured learning rate."""
        return torch.optim.SGD(self.parameters(), lr=self.learning_rate)


In [None]:
 
torch.manual_seed(2056)

# Data module, plain PyTorch model, and the LightningModule wrapping that model
dm = MNISTDataModule(data_dir="./mnist", batch_size=64)
pytorch_model = PyTorchMLP(num_features=28 * 28, num_classes=10)
lightning_model = LightningModel(model=pytorch_model, learning_rate=0.05)

# Trainer options: "auto" uses a GPU / all GPUs when available; metrics go to a CSV logger
trainer_options = {
    "max_epochs": 10,
    "accelerator": "auto",
    "devices": "auto",
    "logger": CSVLogger("./logs", name="mnist_logs"),
    # "default_root_dir": "./logs",  # root directory for logs and weights while using tensorboard etc
    # "deterministic": True,  # set to True to ensure reproducibility
}
trainer = L.Trainer(**trainer_options)

# Train the model; the dataloaders are supplied by the data module
trainer.fit(lightning_model, datamodule=dm)


# trainer.test() runs test_step over the given dataloader, so the accuracy is
# reported under the "test_acc" key whichever split is passed in.
def _test_step_accuracy(dataloader):
    """Return the "test_acc" entry of the first result dict from trainer.test()."""
    results = trainer.test(dataloaders=dataloader)
    return results[0]["test_acc"]


# Evaluation order: test set first, then the training and validation sets
test_acc = _test_step_accuracy(dm.test_dataloader())
train_acc = _test_step_accuracy(dm.train_dataloader())
val_acc = _test_step_accuracy(dm.val_dataloader())

print(f"\nTrain accuracy: {train_acc*100:.2f}, \nValidation accuracy: {val_acc*100:.2f}, \nTest accuracy: {test_acc*100:.2f}")

In [None]:
# Open TensorBoard
# %load_ext tensorboard
%tensorboard 

In [7]:
# Delete all current logs.
# Guarded so the cell is a no-op (instead of raising FileNotFoundError) when the
# directory does not exist, e.g. when the cell is run twice or before any training.
import os
import shutil

if os.path.isdir("./logs"):
    shutil.rmtree("./logs")

## Using the model on new data


In [12]:
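# Lightning automatically logs checkpoints but you can manually log checkpoints as well.
# NOTE: the checkpoint-loading cell further below reads "model.ckpt"; uncomment the next
# line (while `trainer` is still in memory) to create that file.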
#trainer.save_checkpoint("model.ckpt")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataset import random_split
from torchvision import datasets, transforms
import lightning as L
import torchmetrics
from lightning.pytorch.loggers import CSVLogger

# Data modules as an optional organization layer
class MNISTDataModule(L.LightningDataModule):
    """Data module that downloads MNIST and serves train/val/test/predict dataloaders."""

    def __init__(self, data_dir = "./mnist", batch_size = 64, split_seed = 2056):
        """
        Args:
            data_dir: Directory the MNIST files are downloaded to and read from.
            batch_size: Batch size used by every dataloader.
            split_seed: Seed of the private generator that draws the
                train/validation split, so that every ``setup()`` call
                reproduces the same partition.
        """
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.split_seed = split_seed

    def prepare_data(self): # Download the dataset - only called on 1 GPU
        """Download both MNIST splits; nothing is stored on ``self`` here."""
        datasets.MNIST(self.data_dir, train = True, download = True)
        datasets.MNIST(self.data_dir, train = False, download = True)

    def setup(self, stage: str): # Split the dataset into training, validation, and test sets - called on every GPU
        """Create the train/val/test/predict datasets.

        All datasets are (re)built on every call, whatever ``stage`` is.
        """
        self.mnist_test = datasets.MNIST(
            self.data_dir, transform = transforms.ToTensor(), train = False
        )
        mnist_full = datasets.MNIST(
            self.data_dir, transform = transforms.ToTensor(), train = True
        )
        # Without an explicit generator random_split draws from the global RNG, so
        # each setup() call would produce a different train/validation partition
        # and examples used for training could later show up in the validation set.
        # A freshly seeded private generator makes the partition identical on every call.
        split_rng = torch.Generator().manual_seed(self.split_seed)
        self.mnist_train, self.mnist_val = random_split(
            mnist_full, [55000, 5000], generator = split_rng
        )
        self.mnist_predict = datasets.MNIST(
            self.data_dir, transform = transforms.ToTensor(), train = False
        )

    def train_dataloader(self):
        """Return the shuffled dataloader over the training split."""
        return DataLoader(dataset = self.mnist_train, batch_size = self.batch_size, shuffle = True)

    def val_dataloader(self):
        """Return the unshuffled dataloader over the validation split."""
        return DataLoader(dataset = self.mnist_val, batch_size = self.batch_size, shuffle = False)

    def test_dataloader(self):
        """Return the unshuffled dataloader over the MNIST test set."""
        return DataLoader(dataset = self.mnist_test, batch_size = self.batch_size, shuffle = False)

    def predict_dataloader(self):
        """Return the unshuffled dataloader over the prediction dataset (the MNIST test set)."""
        return DataLoader(dataset = self.mnist_predict, batch_size = self.batch_size, shuffle = False)
    

# Create a Pytorch model
class PyTorchMLP(nn.Module):
    """Multilayer perceptron with two hidden layers (50 and 25 units) and ReLU activations."""

    def __init__(self, num_features, num_classes):
        """Build the layer stack.

        Args:
            num_features: Size of each flattened input example.
            num_classes: Number of logits produced by the output layer.
        """
        super().__init__()

        hidden_1, hidden_2 = 50, 25
        layers = [
            nn.Linear(num_features, hidden_1),  # first hidden layer
            nn.ReLU(),
            nn.Linear(hidden_1, hidden_2),  # second hidden layer
            nn.ReLU(),
            nn.Linear(hidden_2, num_classes),  # output layer
        ]
        self.all_layers = nn.Sequential(*layers)

    def forward(self, x):
        """Return the logits for a batch, flattening every dimension except the batch one."""
        flat = x.flatten(start_dim=1)
        return self.all_layers(flat)

# Create a lightning module
class LightningModel(L.LightningModule):
    """LightningModule that wraps a PyTorch classifier with training, validation, and test loops."""

    def __init__(self, model, learning_rate):
        """Store the wrapped model and learning rate, record hyperparameters, and create the metrics."""
        super().__init__()

        self.model = model
        self.learning_rate = learning_rate

        # Record the constructor settings with the logs, leaving out the model object itself
        self.save_hyperparameters(ignore=["model"])

        # Separate 10-class accuracy metrics for the training, validation, and test loops
        metric_kwargs = {"task": "multiclass", "num_classes": 10}
        self.train_acc = torchmetrics.Accuracy(**metric_kwargs)
        self.val_acc = torchmetrics.Accuracy(**metric_kwargs)
        self.test_acc = torchmetrics.Accuracy(**metric_kwargs)

    def forward(self, x):
        """Delegate the forward pass to the wrapped PyTorch model."""
        return self.model(x)

    def _shared_step(self, batch):
        """Return ``(loss, predicted_labels, true_labels)`` for one batch; used by every loop."""
        inputs, targets = batch
        logits = self(inputs)
        return F.cross_entropy(logits, targets), logits.argmax(dim=1), targets

    def training_step(self, batch, batch_idx):
        """Process one training batch: log loss and accuracy, then return the loss."""
        batch_loss, preds, targets = self._shared_step(batch)

        self.log("train_loss", batch_loss)
        self.train_acc(preds, targets)
        # on_epoch=True / on_step=False asks for the accuracy to be logged per epoch
        self.log("train_acc", self.train_acc, prog_bar=True, on_epoch=True, on_step=False)

        # The returned loss is what gets passed on for the gradient computation and weight update
        return batch_loss

    def validation_step(self, batch, batch_idx):
        """Process one validation batch: log loss and accuracy; nothing is returned."""
        batch_loss, preds, targets = self._shared_step(batch)

        self.log("val_loss", batch_loss, prog_bar=True)
        self.val_acc(preds, targets)
        self.log("val_acc", self.val_acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Process one test batch: update and log the test accuracy (the loss is not logged)."""
        _, preds, targets = self._shared_step(batch)
        self.test_acc(preds, targets)
        self.log("test_acc", self.test_acc)

    def configure_optimizers(self):
        """Return a plain SGD optimizer over all parameters with the configured learning rate."""
        return torch.optim.SGD(self.parameters(), lr=self.learning_rate)


In [8]:
# Instantiate pytorch model
pytorch_model = PyTorchMLP(num_features = 28*28, num_classes = 10)

# Instantiate the Lightning model from checkpoint.
# The wrapped model is passed in explicitly because LightningModel excludes it
# from the saved hyperparameters (save_hyperparameters(ignore = ["model"])).
lightning_model = LightningModel.load_from_checkpoint(
    checkpoint_path="model.ckpt",
    model = pytorch_model,
)

# Put the model in eval mode
lightning_model = lightning_model.eval() # Can be done using pytorch_model.eval() as well

# Run the batches and the metric on whatever device the restored module lives on.
# NOTE(review): a checkpoint written from a GPU run may be restored onto the GPU,
# in which case CPU batches would raise a device-mismatch error - confirm for the
# installed Lightning version.
device = lightning_model.device

# Instantiate the data module
dm = MNISTDataModule(data_dir = "./mnist", batch_size = 64)
dm.prepare_data() # Make sure the files exist when this cell runs in a fresh environment
dm.setup(stage = "test")

# Define test dl that will pass the test dataset to the model in batches
test_dl = dm.test_dataloader()
acc = torchmetrics.Accuracy(task = "multiclass", num_classes = 10).to(device)

# Test loop
with torch.inference_mode():
    for features, true_labels in test_dl:
        features, true_labels = features.to(device), true_labels.to(device)
        logits = lightning_model(features)
        predicted_labels = torch.argmax(logits, dim = 1)
        acc(predicted_labels, true_labels) # Accumulates the metric state batch by batch

# Compute accuracy over all accumulated batches
test_acc = acc.compute()
print(f"Test accuracy: {test_acc*100:.2f}")
    



Test accuracy: 96.47


Lightning data module: datasets, dataloaders
Lightning module: model, training, validation, testing, prediction