# Objective of This Practical

To build a complete end-to-end CNN training pipeline using PyTorch Lightning with a clean, modular, and industry-style workflow.

To demonstrate training, validation, accuracy tracking, checkpoint saving, and model loading in a structured ML workflow.

# What the Learner Will Learn

1. How to prepare image datasets using torchvision.transforms and FakeData.

2. How to create train/validation splits and DataLoaders.

3. How to build a simple Convolutional Neural Network (CNN) in PyTorch.

4. How to wrap a PyTorch model inside a LightningModule for cleaner training.

5. How PyTorch Lightning handles:

   - Training loops

   - Validation loops

   - Logging metrics (loss & accuracy)

   - Optimizer setup

6. How to use ModelCheckpoint to save the best model.

7. How to reload the saved checkpoint and perform inference.

8. How to ensure reproducibility with proper seeding.

9. How to run everything efficiently on CPU/GPU automatically.

In [21]:
# Install required packages (run once in notebook).
# The leading "!" is notebook shell-escape syntax, so this line is not valid in a plain .py script.
!pip install torch torchvision




In [22]:
# Import PyTorch core library for tensors and computations
import torch

# Import neural network module (layers like Conv2d, Linear, etc.)
import torch.nn as nn

# Import optimization algorithms such as Adam, SGD, etc.
import torch.optim as optim

# Import DataLoader for batching & shuffling datasets
from torch.utils.data import DataLoader

# Import image transforms and standard datasets like MNIST/CIFAR
from torchvision import transforms, datasets

# Import Python's random module for reproducibility
import random

# Import NumPy for array operations and seeding
import numpy as np

# Pick the compute device: CUDA when PyTorch reports a usable GPU, CPU otherwise
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Report the device that all later cells will use
print("Device:", device)


Device: cpu


In [23]:
# Single seed value shared by every random number generator used below
seed = 42

# Seed, in order: PyTorch, NumPy, and Python's built-in `random` module
for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    seed_fn(seed)

# When the selected device is a GPU, seed all CUDA devices too
if device.type == "cuda":
    torch.cuda.manual_seed_all(seed)


In [24]:
# Build a synthetic CIFAR-like dataset (3x32x32 images) for a quick pipeline demo.
# FakeData generates random images and labels, so nothing has to be downloaded.

# Per-image preprocessing, applied in the order listed
transform = transforms.Compose(
    [
        # Image -> tensor of shape (C, H, W) with values in [0, 1]
        transforms.ToTensor(),
        # Shift/scale every channel so values land roughly in [-1, 1]
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Number of samples per mini-batch
batch_size = 32

# 1000 random samples, 3 channels, 32x32 pixels, 10 classes;
# random_offset=0 keeps the generated samples the same on every run
dataset = datasets.FakeData(
    size=1000,
    image_size=(3, 32, 32),
    num_classes=10,
    transform=transform,
    random_offset=0,
)

# Batches the dataset and reshuffles the sample order on each pass
loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)


In [25]:
# Define a simple CNN architecture
class SimpleCNN(nn.Module):
    """Small two-block CNN classifier for 3x32x32 images.

    Layout: (3x3 conv -> ReLU -> 2x2 max-pool) twice, followed by two fully
    connected layers. The forward pass returns raw logits (no softmax).

    Args:
        num_classes: Number of output classes, i.e. the width of the logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # --- Convolution Block 1 ---
        # 3 -> 16 channels; 3x3 kernel with padding=1 keeps the spatial size
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        # 2x2 max pooling halves the spatial size: 32x32 -> 16x16
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # --- Convolution Block 2 ---
        # 16 -> 32 channels, same kernel/padding as block 1
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        # Pooling: 16x16 -> 8x8
        self.pool2 = nn.MaxPool2d(2, 2)

        # --- Fully Connected Layers ---
        # Flattened 32*8*8 feature map -> 128 hidden features
        self.fc1 = nn.Linear(32 * 8 * 8, 128)
        self.relu3 = nn.ReLU()
        # Hidden features -> one logit per class
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: Input batch of shape (batch_size, 3, 32, 32).

        Returns:
            Tensor of shape (batch_size, num_classes) with raw logits.
        """
        x = self.pool1(self.relu1(self.conv1(x)))   # conv1 -> relu1 -> pool1
        x = self.pool2(self.relu2(self.conv2(x)))   # conv2 -> relu2 -> pool2

        # Flatten to (batch, 32*8*8). torch.flatten is used instead of .view()
        # because .view() raises on non-contiguous activations, which occur
        # e.g. when the input batch uses the channels_last memory format.
        x = torch.flatten(x, start_dim=1)

        # Hidden fully connected layer + activation
        x = self.relu3(self.fc1(x))

        # Output layer -> raw class logits
        return self.fc2(x)

# Build the network for 10 classes, then place its parameters on the selected device
model = SimpleCNN(num_classes=10)
model = model.to(device)

# Show the layer-by-layer architecture
print(model)


SimpleCNN(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu1): ReLU()
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu2): ReLU()
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=2048, out_features=128, bias=True)
  (relu3): ReLU()
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)


In [26]:
# Multi-class classification loss; it takes the raw logits produced by the model
criterion = nn.CrossEntropyLoss()

# Adam updates every model parameter with a learning rate of 0.001
optimizer = optim.Adam(params=model.parameters(), lr=0.001)


In [18]:
# Set number of training epochs
num_epochs = 5

# Put model in training mode (matters for dropout/batchnorm layers if present)
model.train()

# Loop over epochs
for epoch in range(num_epochs):

    # Sum of per-batch losses for this epoch
    running_loss = 0.0

    # Iterate the loader directly: the batch index is never used
    for inputs, labels in loader:

        # Move the image and label batches to CPU/GPU
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Clear gradients left over from the previous step
        optimizer.zero_grad()

        # --- Forward pass ---
        outputs = model(inputs)                # Model predictions (logits)
        loss = criterion(outputs, labels)      # Loss for this batch

        # --- Backward + Optimize ---
        loss.backward()                        # Backpropagate to compute gradients
        optimizer.step()                       # Update model weights

        # Add loss to running total
        running_loss += loss.item()

    # Mean of the per-batch losses; max(..., 1) avoids a ZeroDivisionError
    # when the loader yields no batches
    avg_loss = running_loss / max(len(loader), 1)

    # Print epoch result
    print(f"Epoch [{epoch+1}/{num_epochs}] - Avg Loss: {avg_loss:.4f}")


Epoch [1/5] - Avg Loss: 2.3108
Epoch [2/5] - Avg Loss: 2.2977
Epoch [3/5] - Avg Loss: 2.2919
Epoch [4/5] - Avg Loss: 2.2827
Epoch [5/5] - Avg Loss: 2.2666


In [19]:
# Quick evaluation on a few samples (not a real test split, just demo)
model.eval()   # Set model to evaluation mode

# Disable gradient calculation (faster + saves memory)
with torch.no_grad():

    # Get one batch of data from the loader
    sample_inputs, sample_labels = next(iter(loader))

    # Move inputs and labels to CPU/GPU
    sample_inputs = sample_inputs.to(device)
    sample_labels = sample_labels.to(device)

    # Forward pass to get predictions
    logits = model(sample_inputs)

    # Convert logits to predicted class indices
    preds = logits.argmax(dim=1)

# Print up to the first 10 true vs predicted labels. Slicing (rather than a
# fixed range(10)) avoids an IndexError when the batch holds fewer than 10 samples.
for true_label, pred_label in zip(sample_labels[:10].tolist(), preds[:10].tolist()):
    print(f"True: {true_label}  Pred: {pred_label}")


True: 5  Pred: 3
True: 6  Pred: 3
True: 6  Pred: 3
True: 2  Pred: 8
True: 6  Pred: 3
True: 6  Pred: 3
True: 1  Pred: 3
True: 8  Pred: 3
True: 7  Pred: 3
True: 8  Pred: 8


# Single, ready-to-run PyTorch Lightning notebook cell that includes:

> train / val split (validation split),

> batch & epoch accuracy logging,

> saving best checkpoint (by val_loss),

> loading that checkpoint for inference,

> saving state_dict separately.

In [20]:
# Install required packages; --quiet reduces pip output.
# NOTE: this line is active as written. The leading "!" is notebook shell-escape syntax,
# so comment it out once the packages are installed or when running outside a notebook.
!pip install torch torchvision pytorch_lightning --quiet

# OS utilities (file paths, dirs)
import os
# Python random utilities for reproducibility
import random
# NumPy for numeric ops and seeding
import numpy as np
# Core PyTorch library
import torch
# Neural network modules (layers, losses, etc.)
import torch.nn as nn
# PyTorch optimizers (Adam, SGD, ...)
import torch.optim as optim
# DataLoader and random_split for batching and splitting datasets
from torch.utils.data import DataLoader, random_split
# Image transforms and toy datasets from torchvision
from torchvision import transforms, datasets
# PyTorch Lightning high-level training framework
import pytorch_lightning as pl
# Lightning callback to save model checkpoints
from pytorch_lightning.callbacks import ModelCheckpoint

# ----------------------------
# Reproducibility & device
# ----------------------------
# One seed value shared by all random number generators
seed = 42

# Seed, in order: Python's `random`, NumPy, and PyTorch
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(seed)

# Query CUDA availability before seeding the GPUs and choosing the device
has_cuda = torch.cuda.is_available()
if has_cuda:
    # Seed every CUDA device as well
    torch.cuda.manual_seed_all(seed)

# GPU when available, CPU otherwise
device = torch.device("cuda") if has_cuda else torch.device("cpu")
# Report the device used for the inference demo further below
print("Device:", device)

# ----------------------------
# Dataset & transforms
# ----------------------------
# Preprocessing applied to every sample, in the order listed
transform = transforms.Compose(
    [
        # Image -> tensor of shape (C, H, W) with values in [0, 1]
        transforms.ToTensor(),
        # Shift/scale every channel so values land roughly in [-1, 1]
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Synthetic CIFAR-like dataset: 1000 samples of shape 3x32x32 across 10 classes;
# random_offset=0 keeps the generated samples repeatable
full_dataset = datasets.FakeData(
    size=1000,
    image_size=(3, 32, 32),
    num_classes=10,
    transform=transform,
    random_offset=0,
)

# 80% of the samples go to training, the remainder to validation
total_len = len(full_dataset)
train_len = int(0.8 * total_len)
val_len = total_len - train_len

# A dedicated, seeded generator makes the split identical on every run
split_generator = torch.Generator()
split_generator.manual_seed(seed)
train_dataset, val_dataset = random_split(
    full_dataset,
    lengths=[train_len, val_len],
    generator=split_generator,
)

# Batch size for training/validation
batch_size = 32

# Settings common to both loaders (num_workers=0 loads data in the main process)
loader_kwargs = {"batch_size": batch_size, "num_workers": 0}

# Training loader reshuffles each epoch; validation loader keeps a fixed order
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

# ----------------------------
# Simple CNN definition
# ----------------------------
class SimpleCNN(nn.Module):
    """Small two-block CNN classifier for 3x32x32 images.

    Layout: (3x3 conv -> ReLU -> 2x2 max-pool) twice, followed by two fully
    connected layers. The forward pass returns raw logits; no softmax is
    applied because CrossEntropyLoss expects raw logits.

    Args:
        num_classes: Number of output classes, i.e. the width of the logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # Block 1: 3 -> 16 channels; padding=1 keeps the spatial size
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        # 2x2 max pooling halves the spatial dims (32x32 -> 16x16)
        self.pool1 = nn.MaxPool2d(2, 2)

        # Block 2: 16 -> 32 channels; spatial size kept until pooling
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        # Second pooling (16x16 -> 8x8)
        self.pool2 = nn.MaxPool2d(2, 2)

        # Classifier head: flattened 32*8*8 features -> 128 -> num_classes
        self.fc1 = nn.Linear(32 * 8 * 8, 128)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: Input batch of shape (batch, 3, 32, 32).

        Returns:
            Tensor of shape (batch, num_classes) with raw logits.
        """
        # conv1 -> relu -> pool1, then conv2 -> relu -> pool2
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        # (batch, channels, H, W) -> (batch, channels*H*W). torch.flatten is
        # used instead of .view() because .view() raises on non-contiguous
        # activations, e.g. when the input uses the channels_last memory format.
        x = torch.flatten(x, start_dim=1)
        # Hidden fully connected layer + activation
        x = self.relu3(self.fc1(x))
        # Output logits
        return self.fc2(x)

# ----------------------------
# Lightning Module with accuracy calculation
# ----------------------------
class LitCNN(pl.LightningModule):
    """LightningModule that trains a classifier with cross-entropy loss.

    Logs epoch-level `train_loss`/`train_acc` and `val_loss`/`val_acc`
    to the progress bar.

    Args:
        model: Underlying PyTorch network that maps inputs to class logits.
        lr: Learning rate passed to the Adam optimizer.
    """

    def __init__(self, model: nn.Module, lr: float = 1e-3):
        super().__init__()
        # Underlying network (e.g. SimpleCNN)
        self.model = model
        # Learning rate used in configure_optimizers
        self.lr = lr
        # Loss function for multi-class classification
        self.criterion = nn.CrossEntropyLoss()

        # Record hyperparameters; the model object itself is excluded
        self.save_hyperparameters(ignore=['model'])

    def forward(self, x):
        """Delegate the forward pass to the wrapped model."""
        return self.model(x)

    def _shared_step(self, batch, stage):
        """Compute and log loss and accuracy for one batch.

        Args:
            batch: Tuple of (inputs, targets).
            stage: Metric-name prefix, 'train' or 'val'.

        Returns:
            Tuple of (loss, batch accuracy as a fraction).
        """
        inputs, targets = batch
        logits = self(inputs)
        loss = self.criterion(logits, targets)

        # Fraction of samples whose arg-max class equals the target
        preds = torch.argmax(logits, dim=1)
        batch_acc = (preds == targets).float().mean()

        # on_step=False / on_epoch=True: report epoch-level values only
        self.log(f'{stage}_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log(f'{stage}_acc', batch_acc, on_step=False, on_epoch=True, prog_bar=True)
        return loss, batch_acc

    def training_step(self, batch, batch_idx):
        """Run one training batch and return the loss for the optimizer step."""
        loss, _ = self._shared_step(batch, 'train')
        return loss

    def validation_step(self, batch, batch_idx):
        """Run one validation batch and return its loss and accuracy."""
        loss, batch_acc = self._shared_step(batch, 'val')
        return {"val_loss": loss, "val_acc": batch_acc}

    def configure_optimizers(self):
        """Return an Adam optimizer over all module parameters."""
        return optim.Adam(self.parameters(), lr=self.lr)

# ----------------------------
# Instantiate model, LightningModule, and checkpoint callback
# ----------------------------
# Create SimpleCNN instance with 10 classes
cnn = SimpleCNN(num_classes=10)
# Wrap the PyTorch model into the LightningModule with learning rate 1e-3
lit_model = LitCNN(cnn, lr=1e-3)

# Directory to store checkpoints
checkpoint_dir = "checkpoints"
# Create directory if it does not exist
os.makedirs(checkpoint_dir, exist_ok=True)

# ModelCheckpoint callback: save the best model based on minimum val_loss
ckpt_callback = ModelCheckpoint(
    monitor='val_loss',                                 # metric to monitor
    dirpath=checkpoint_dir,                             # where to save
    filename='best-checkpoint-{epoch:02d}-{val_loss:.4f}', # filename template
    save_top_k=1,                                       # keep only best checkpoint
    mode='min'                                          # minimize monitored metric
)

# ----------------------------
# Trainer: run training
# ----------------------------
trainer = pl.Trainer(
    max_epochs=5,                      # number of epochs to train
    callbacks=[ckpt_callback],         # checkpoint callback to save best model
    log_every_n_steps=20,              # log frequency in steps
    accelerator='auto',                # automatically choose CPU/GPU/TPU
    devices=1,                         # one device; the earlier conditional gave 1 on both branches
    deterministic=True                 # try to make runs deterministic
)

# Start training using train and validation dataloaders
trainer.fit(lit_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

# ----------------------------
# After training: path to best checkpoint & demonstration of loading
# ----------------------------
# Path of the best checkpoint recorded by the callback (empty string if none saved)
best_ckpt = ckpt_callback.best_model_path
# Show the path, or a placeholder message when it is empty
print("Best checkpoint path:", best_ckpt or "No checkpoint saved")

if not best_ckpt:
    # Nothing was saved, so there is nothing to reload
    print("No checkpoint to load.")
else:
    # Rebuild the LightningModule from the checkpoint; a fresh SimpleCNN is
    # supplied because `model` was excluded from the saved hyperparameters
    loaded = LitCNN.load_from_checkpoint(best_ckpt, model=SimpleCNN(num_classes=10))
    # Place the module on the chosen device and switch to evaluation mode
    loaded.to(device)
    loaded.eval()

    # Take the first validation batch for the demo and move it to the device
    batch_inputs, batch_labels = next(iter(val_loader))
    batch_inputs = batch_inputs.to(device)
    batch_labels = batch_labels.to(device)

    # Inference only: no gradients needed
    with torch.no_grad():
        logits = loaded(batch_inputs)
        # Arg-max over the class dimension gives the predicted class index
        preds = torch.argmax(logits, dim=1)

    # Print true vs predicted labels for at most the first 10 samples
    print("\nSample predictions (first 10):")
    for true_label, pred_label in zip(batch_labels[:10].tolist(), preds[:10].tolist()):
        print(f"True: {true_label}  Pred: {pred_label}")

    # Also save just the underlying network's weights as a plain state_dict
    final_state_path = os.path.join(checkpoint_dir, "final_model_state_dict.pth")
    torch.save(loaded.model.state_dict(), final_state_path)
    # Confirm saved path
    print("Saved model state_dict to:", final_state_path)

# ----------------------------
# Notes:
# - Checkpoints saved under the 'checkpoints' directory.
# - Lightning logs (train_loss/train_acc/val_loss/val_acc) appear in the progress bar and can be used by loggers.
# - To use a real dataset replace FakeData with CIFAR10 or MNIST from torchvision.datasets (will download data).
# ----------------------------


INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores


Device: cpu


/usr/local/lib/python3.12/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py:881: Checkpoint directory /content/checkpoints exists and is not empty.


Output()

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.


Best checkpoint path: /content/checkpoints/best-checkpoint-epoch=01-val_loss=2.3032-v1.ckpt

Sample predictions (first 10):
True: 2  Pred: 8
True: 5  Pred: 8
True: 6  Pred: 8
True: 2  Pred: 8
True: 0  Pred: 8
True: 8  Pred: 8
True: 3  Pred: 8
True: 7  Pred: 8
True: 1  Pred: 8
True: 5  Pred: 8
Saved model state_dict to: checkpoints/final_model_state_dict.pth
