# MNIST Classification using PyTorch Lightning

## Instructions
- **Video Tutorial**: Please watch [this video](https://www.youtube.com/watch?v=XbIN9LaQycQ&list=PLhhyoLH6IjfyL740PTuXef4TstxAK6nGP) to gain an overview of PyTorch Lightning.
- **Reading Materials**: Read the content provided in [Link 1](https://www.datacamp.com/tutorial/pytorch-lightning-tutorial) and [Link 2](https://lightning.ai/docs/pytorch/stable/starter/introduction.html) to deepen your understanding.
- Ensure you have PyTorch Lightning installed in your environment before starting.

---

## Section 1: Conceptual Background

### 1.1 Theoretical Foundation and Key Concepts

**PyTorch Lightning** is a lightweight PyTorch wrapper for high-performance AI research. It provides a structured way to organize PyTorch code, making it easier to scale models, implement best practices, and maintain clean codebases.

**Key Components:**
- **LightningModule**: Encapsulates the model architecture, training, validation, and testing steps.
- **Trainer**: Handles the training loop, including logging, checkpointing, and GPU management.
- **DataModule**: Manages data loading and preprocessing, ensuring reproducibility and modularity.
- **Callbacks**: Allow custom actions during training, such as early stopping or model checkpointing.
- **Logger**: Integrates with tools like TensorBoard for tracking metrics and visualizing training progress.
- **Profiler**: Assists in performance tuning by profiling different parts of the training process.
- **Multi GPU Training**: Facilitates training on multiple GPUs to speed up the process.

### 1.2 Real-World Applications and Relevance

PyTorch Lightning is widely used in both academic research and industry for:
- Rapid prototyping of deep learning models.
- Ensuring reproducibility and scalability in experiments.
- Streamlining the transition from research to production.
- Facilitating collaboration among teams by enforcing a standardized code structure.

**Example Application**: Image classification tasks like MNIST digit recognition, where models need to be trained efficiently and evaluated accurately.

### 1.3 Prerequisite Knowledge

To effectively complete this assignment, students should be familiar with:
- **Python Programming**: Basic syntax and data structures.
- **PyTorch**: Understanding of tensors, autograd, and basic neural network components.
- **Deep Learning Concepts**: Knowledge of Convolutional Neural Networks (CNNs), activation functions, loss functions, and optimization algorithms.

### 1.4 Mathematical Concepts and Formulas

Understanding the following mathematical concepts is essential:

- **Convolution Operation**: A fundamental operation in CNNs that helps in detecting features within input data. It involves sliding a filter over the input to produce feature maps.
  
- **Activation Functions**: Functions like ReLU (Rectified Linear Unit) introduce non-linearity into the model, allowing it to learn complex patterns. ReLU is defined as the maximum of zero and the input value:
  
  ReLU(x) = max(0, x)
  
- **Loss Functions**: For classification tasks, the Cross-Entropy Loss is commonly used. It measures the difference between the predicted probabilities and the actual labels:
  
  Loss = -Σ(y_i * log(predicted_y_i))
  
- **Optimization Algorithms**: Methods like Stochastic Gradient Descent (SGD) adjust the model's weights to minimize the loss function. The update rule for SGD is:
  
  Weight = Weight - Learning_Rate * Gradient_of_Loss

### 1.5 Specific Algorithms and Techniques

- **Convolutional Neural Networks (CNNs)**: Utilized for image data to capture spatial hierarchies.
- **PyTorch Lightning Components**: Leveraging LightningModule, Trainer, DataModule, Callbacks, etc., to structure the workflow.
- **Metrics Calculation**: Tracking accuracy and loss during training and validation.

### 1.6 Common Pitfalls and Misconceptions

- **Overfitting**: Not using validation data to monitor and prevent overfitting.
- **Improper Data Handling**: Neglecting data normalization or shuffling, leading to poor model performance.
- **Ignoring Hardware Utilization**: Failing to leverage GPUs for training, resulting in longer training times.
- **Misunderstanding Callbacks**: Incorrect implementation of callbacks can disrupt the training process.

In [3]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

# Section 2.1: Data Module
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir='./', batch_size=64):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

    def prepare_data(self):
        # Download MNIST data if not already present
        datasets.MNIST(self.data_dir, train=True, download=True)
        datasets.MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Split the data into training and validation sets
        if stage == 'fit' or stage is None:
            mnist_full = datasets.MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        if stage == 'test' or stage is None:
            self.mnist_test = datasets.MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        # Return the training DataLoader
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        # Return the validation DataLoader
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

# Section 2.2: Lightning Module
class LitCNN(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Define the CNN architecture
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # Define the forward pass
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        # Implement the training step
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('val_loss', loss)
        return loss

    def configure_optimizers(self):
        # Configure the optimizer and learning rate scheduler if any
        return torch.optim.Adam(self.parameters(), lr=1e-3)

# Section 2.3: Training Setup
def main():
    # Initialize the DataModule
    dm = MNISTDataModule()

    # Initialize the LightningModule
    model = LitCNN()

    # Set up callbacks (e.g., ModelCheckpoint, EarlyStopping)
    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        dirpath='checkpoints/',
        filename='mnist-{epoch:02d}-{val_loss:.2f}',
        save_top_k=3,
        mode='min'
    )
    early_stop_callback = EarlyStopping(
        monitor='val_loss',
        patience=3,
        mode='min'
    )

    # Initialize the logger
    logger = TensorBoardLogger('tb_logs', name='mnist_model')

    # Initialize the Trainer with appropriate parameters
    trainer = Trainer(
        max_epochs=10,
        callbacks=[checkpoint_callback, early_stop_callback],
        logger=logger
    )

    # Start the training process
    trainer.fit(model, dm)

if __name__ == '__main__':
    main()

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name  | Type   | Params | Mode 
-----------------------------------------
0 | conv1 | Conv2d | 320    | train
1 | conv2 | Conv2d | 18.5 K | train
2 | fc1   | Linear | 1.2 M  | train
3 | fc2   | Linear | 1.3 K  | train
-----------------------------------------
1.2 M     Trainable params
0         Non-trainable params
1.2 M     Total params
4.800     Total estimated model params size (MB)
4         Modules in train mode
0         Modules in eval mode


Epoch 9: 100%|██████████| 860/860 [00:45<00:00, 18.92it/s, v_num=2]

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 860/860 [00:45<00:00, 18.92it/s, v_num=2]
