# ⚡ Intro to PyTorch Lightning (MNIST Classification)

Yo, bro! Ready to classify some handwritten digits with **PyTorch Lightning**? 😎
We’re training a neural net on the MNIST dataset to recognize numbers 0–9 — super clean and fast! ⚡
This is step 8 of your learning path, fresh off your YOLOv8 adventure. Let’s make this model a digit-spotting champ! 🎉

In [None]:
# ✅ Step 1: Install dependencies
!pip install -q pytorch-lightning torchvision torchmetrics

In [2]:
# ✅ Step 2: Imports
import torch
from torch import nn
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import pytorch_lightning as pl
import torchmetrics
print("🎉 Libraries loaded — ready to roll! ⚡")

🎉 Libraries loaded — ready to roll! ⚡


In [3]:
# ✅ Step 3: DataModule for MNIST
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=64):
        super().__init__()
        self.batch_size = batch_size
        self.transform = transforms.Compose([transforms.ToTensor()])

    def setup(self, stage=None):
        self.mnist_train = datasets.MNIST(root=".", train=True, download=True, transform=self.transform)
        self.mnist_val = datasets.MNIST(root=".", train=False, download=True, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)
print("📚 MNIST DataModule set up — digits ready to load!")

📚 MNIST DataModule set up — digits ready to load!


In [4]:
# ✅ Step 4: LightningModule (Model + Training logic)
class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28*28, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )
        self.loss = nn.CrossEntropyLoss()
        self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=10)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        loss = self.loss(preds, y)
        acc = self.accuracy(preds.softmax(dim=-1), y)
        self.log("train_loss", loss, on_step=False, on_epoch=True)
        self.log("train_acc", acc, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        loss = self.loss(preds, y)
        acc = self.accuracy(preds.softmax(dim=-1), y)
        self.log("val_loss", loss, on_step=False, on_epoch=True)
        self.log("val_acc", acc, on_step=False, on_epoch=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
print("🧠 LightningModule ready — model’s set to learn digits!")

🧠 LightningModule ready — model’s set to learn digits!


In [5]:
# ✅ Step 5: Train
mnist = MNISTDataModule()
model = LitModel()
trainer = pl.Trainer(max_epochs=3, accelerator="auto")
trainer.fit(model, datamodule=mnist)
print("🎉 Training complete — our model’s a digit-classifying beast! ⚡")

INFO:pytorch_lightning.utilities.rank_zero:💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
100%|██████████| 9.91M/9.91M [00:00<00:00, 37.2MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.09MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 8.68MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 8.74MB/s]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name     | Type               | Params | Mode 
--------------------------------------------------------
0 | model    | Sequential         | 101 K  | train
1 | loss     | CrossEntropyLoss   | 0      | train
2 | accuracy | MulticlassAccuracy | 0     

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=3` reached.


🎉 Training complete — our model’s a digit-classifying beast! ⚡


# 📚 Tips for Having Fun
- Try increasing `max_epochs` in Step 5 to 5 for better accuracy (takes longer).
- Add a test step to evaluate on the MNIST test set (use `trainer.test()`).
- Play with the model in Step 4: Add more layers (e.g., another `nn.Linear(128, 64)` + `nn.ReLU()`).
- Check out PyTorch Lightning’s docs (https://pytorch-lightning.readthedocs.io/) for more tricks!

# 🚀 What’s Next?
- Save this as your eighth notebook in your learning path.
- Combine with your YOLOv8 notebook (step 7) for a vision project (e.g., digit detection + object detection).
- Try fine-tuning a bigger model or dataset (e.g., CIFAR-10) for more fun!
