<a href="https://colab.research.google.com/github/anshika0601/pytorchz-learn/blob/main/Day14/PyTorch_Lightning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

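# PyTorch Lightning

Train a small CNN on CIFAR-10 with PyTorch Lightning: define a `LightningModule` and a `LightningDataModule`, fit them with a `Trainer`, then re-train with model checkpointing and early stopping.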

In [3]:
pip install pytorch-lightning


Collecting pytorch-lightning
  Downloading pytorch_lightning-2.5.4-py3-none-any.whl.metadata (20 kB)
Collecting torchmetrics>0.7.0 (from pytorch-lightning)
  Downloading torchmetrics-1.8.1-py3-none-any.whl.metadata (22 kB)
Collecting lightning-utilities>=0.10.0 (from pytorch-lightning)
  Downloading lightning_utilities-0.15.2-py3-none-any.whl.metadata (5.7 kB)
Downloading pytorch_lightning-2.5.4-py3-none-any.whl (829 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m829.2/829.2 kB[0m [31m26.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lightning_utilities-0.15.2-py3-none-any.whl (29 kB)
Downloading torchmetrics-1.8.1-py3-none-any.whl (982 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m983.0/983.0 kB[0m [31m57.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: lightning-utilities, torchmetrics, pytorch-lightning
Successfully installed lightning-utilities-0.15.2 pytorch-lightning-2.5.4 torchmetrics-1.8.1


In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import pytorch_lightning as pl
import torch.optim as optim

In [5]:
# ---------------------------
# Lightning Model
# ---------------------------
class LitCNN(pl.LightningModule):
    """Small two-block CNN classifier trained with cross-entropy and Adam.

    Architecture: [conv3x3 -> BN -> ReLU -> maxpool(2) -> dropout] x 2,
    then flatten -> Linear(64*8*8, 128) -> BN -> ReLU -> dropout -> Linear(128, num_classes).

    Because the two 2x poolings must leave an 8x8 feature map for ``fc1``,
    inputs are expected to be 3-channel 32x32 images (e.g. CIFAR-10).

    Args:
        num_classes: Number of output logits.
        lr: Learning rate for the Adam optimizer.
        dropout: Dropout probability applied after each block (default 0.5).
    """

    def __init__(self, num_classes=10, lr=1e-3, dropout=0.5):
        super().__init__()
        # Records the __init__ arguments in self.hparams; Lightning also stores
        # them in checkpoints and forwards them to the logger.
        self.save_hyperparameters()

        # Feature extractor: padding=1 keeps the spatial size, pooling halves it.
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.bn1   = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2   = nn.BatchNorm2d(64)
        # 64 channels * 8 * 8 spatial positions after two 2x poolings of a 32x32 input.
        self.fc1   = nn.Linear(64 * 8 * 8, 128)
        self.bn3   = nn.BatchNorm1d(128)
        self.fc2   = nn.Linear(128, num_classes)

        self.dropout = nn.Dropout(dropout)
        self.lr = lr

    def forward(self, x):
        """Return unnormalized class logits of shape (batch, num_classes)."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2)
        x = self.dropout(x)

        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2)
        x = self.dropout(x)

        x = torch.flatten(x, 1)  # (batch, 64, 8, 8) -> (batch, 4096)
        x = F.relu(self.bn3(self.fc1(x)))
        x = self.dropout(x)
        return self.fc2(x)

    def _shared_step(self, batch, stage):
        """Compute loss/accuracy for one batch and log them as ``{stage}_loss`` / ``{stage}_acc``.

        Returns the cross-entropy loss tensor.
        """
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        self.log(f"{stage}_loss", loss, prog_bar=True)
        self.log(f"{stage}_acc", acc, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        # Returned loss is what Lightning backpropagates.
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        # Logs "val_loss", which ModelCheckpoint/EarlyStopping can monitor.
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        # Enables trainer.test(...) with a datamodule that provides test_dataloader().
        return self._shared_step(batch, "test")

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.lr)


In [6]:
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

class CIFAR10DataModule(pl.LightningDataModule):
    """CIFAR-10 with a seeded train/val split of the training set, plus the test set.

    Args:
        data_dir: Directory where torchvision stores / downloads CIFAR-10.
        batch_size: Batch size used by all three dataloaders.
        val_size: Number of training images held out for validation
            (default 5000, i.e. a 45000/5000 split).
        seed: Seed for the train/val split, so every ``setup()`` call and every
            notebook re-run produces the same partition.
        num_workers: DataLoader worker processes (0 = load in the main process).
    """

    def __init__(self, data_dir="./data", batch_size=64, val_size=5000, seed=42, num_workers=0):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.val_size = val_size
        self.seed = seed
        self.num_workers = num_workers
        # ToTensor scales pixels to [0, 1]; Normalize with mean=std=0.5 on each
        # of the three RGB channels maps them to [-1, 1].
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    def prepare_data(self):
        """Download both CIFAR-10 splits (Lightning runs this hook once, before setup)."""
        datasets.CIFAR10(self.data_dir, train=True, download=True)
        datasets.CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the train/val/test datasets.

        ``stage`` is accepted for Lightning's hook signature; all splits are
        built regardless of its value.
        """
        # download=True only verifies the files once prepare_data() has run; it
        # is kept so a standalone setup() call still works.
        full_train = datasets.CIFAR10(self.data_dir, train=True, download=True, transform=self.transform)
        n_total = len(full_train)
        if not 0 <= self.val_size < n_total:
            raise ValueError(f"val_size must be in [0, {n_total}), got {self.val_size}")
        # A fixed generator keeps the split identical across setup() calls;
        # an unseeded split would move images between train and val each time.
        self.train_set, self.val_set = random_split(
            full_train,
            [n_total - self.val_size, self.val_size],
            generator=torch.Generator().manual_seed(self.seed),
        )
        self.test_set = datasets.CIFAR10(self.data_dir, train=False, download=True, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True,
                          num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test_set, batch_size=self.batch_size, num_workers=self.num_workers)


In [7]:
from pytorch_lightning import Trainer

# Seed Python, NumPy and torch RNGs (and dataloader workers) so weight init and
# shuffling repeat across re-runs; some GPU kernels may still be nondeterministic.
pl.seed_everything(42, workers=True)

# Initialize data + model
datamodule = CIFAR10DataModule()
model = LitCNN(num_classes=10)

# Trainer handles device placement, the training/validation loops and logging.
trainer = Trainer(
    max_epochs=5,
    accelerator="auto",   # picks a GPU if one is available
    devices="auto",
)

# Train
trainer.fit(model, datamodule=datamodule)


INFO:pytorch_lightning.utilities.rank_zero:💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
100%|██████████| 170M/170M [00:03<00:00, 49.4MB/s]
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type        | Params | Mode 
------------------------------------------------
0 | conv1   | Conv2d      | 896    | train
1 | bn1     | BatchNorm2d | 64     | train
2 | conv2   | Conv2d      | 18.5 K | train
3 | bn2     | BatchNorm2d | 128    | train
4 | fc1     | Linear      | 524 K  | train
5 | bn3     | Batc

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.


In [8]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# Keep only the checkpoint with the lowest validation loss.
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",       # metric logged in LitCNN.validation_step
    mode="min",               # "min" for loss, "max" for accuracy
    save_top_k=1,             # keep the single best checkpoint
    filename="best-checkpoint-{epoch:02d}-{val_loss:.2f}",
)
# Stop training once val_loss has not improved for `patience` validation epochs.
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    patience=3,     # epochs without improvement before stopping
    mode="min",     # minimize validation loss
)

trainer = Trainer(
    max_epochs=5,
    accelerator="auto",
    devices="auto",
    callbacks=[checkpoint_callback, early_stop_callback],
    log_every_n_steps=10,
)

# Re-seed so this run is reproducible independently of earlier cells.
pl.seed_everything(42, workers=True)
datamodule = CIFAR10DataModule()
model = LitCNN(num_classes=10)

trainer.fit(model, datamodule=datamodule)

# Display where the best (lowest val_loss) checkpoint was written.
checkpoint_callback.best_model_path


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type        | Params | Mode 
------------------------------------------------
0 | conv1   | Conv2d      | 896    | train
1 | bn1     | BatchNorm2d | 64     | train
2 | conv2   | Conv2d      | 18.5 K | train
3 | bn2     | BatchNorm2d | 128    | train
4 | fc1     | Linear      | 524 K  | train
5 | bn3     | BatchNorm1d | 256    | train
6 | fc2     | Linear      | 1.3 K  | train
7 | dropout | Dropout     | 0      | train
------------------------------------------------
545 K     Trainable params
0         Non-trainable params
545 K     Total params
2.182     Total estimated model params size (MB)
8        

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.
