# Image Classification with PyTorch Lightning
This notebook demonstrates how to train a ResNet-50 model on the CIFAR-10 dataset using [PyTorch Lightning](https://www.pytorchlightning.ai/).

## Setup
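Install the required libraries (for example, `pip install torch torchvision pytorch-lightning torchmetrics scikit-learn matplotlib`) and then import them. If the import cell raises `ModuleNotFoundError`, the named package is not installed in the environment the notebook kernel is running in.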

In [1]:
import torch
import torchvision
from torchvision import transforms, datasets, models
import pytorch_lightning as pl
from torch.utils.data import random_split, DataLoader
import torch.nn as nn
import torchmetrics


ModuleNotFoundError: No module named 'pytorch_lightning'

## Data Loading
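Download the CIFAR-10 dataset and split it into training, validation, and test sets: the 50,000 official training images are randomly divided into 45,000 for training and 5,000 for validation, and the 10,000 official test images are used as the test set.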

In [None]:

class CIFAR10DataModule(pl.LightningDataModule):
    """Serve CIFAR-10 as training, validation, and test dataloaders.

    The official training split is divided at random into a training subset
    and a validation subset of ``val_size`` images; the official test split
    is used unchanged for testing.

    Args:
        data_dir: Directory the dataset is downloaded to and read from.
        batch_size: Batch size used by all three dataloaders.
        val_size: Number of training images held out for validation.
        num_workers: Worker processes per dataloader (0 loads in the main
            process, which is the ``DataLoader`` default).
    """

    def __init__(self, data_dir='./data', batch_size=64, val_size=5000, num_workers=0):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.val_size = val_size
        self.num_workers = num_workers
        # ToTensor scales pixels to [0, 1]; Normalize with mean/std 0.5 per
        # channel then maps them to [-1, 1].
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # Filled in lazily by setup(); None means "not built yet".
        self.train_set = None
        self.val_set = None
        self.test_set = None

    def prepare_data(self):
        """Download both CIFAR-10 splits into ``data_dir``."""
        datasets.CIFAR10(self.data_dir, train=True, download=True)
        datasets.CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the datasets needed for ``stage`` (all of them when ``None``).

        Raises:
            ValueError: If ``val_size`` is negative or leaves no training images.
        """
        # 'validate' needs val_set as well, so it shares the 'fit' branch.
        needs_train_val = stage in ('fit', 'validate') or stage is None
        # Only split once: re-drawing the random split on a later setup() call
        # would move former validation images into the training subset.
        if needs_train_val and self.train_set is None:
            full_train = datasets.CIFAR10(self.data_dir, train=True, transform=self.transform)
            total = len(full_train)
            if not 0 <= self.val_size < total:
                raise ValueError(
                    f"val_size must be in [0, {total}), got {self.val_size}"
                )
            self.train_set, self.val_set = random_split(
                full_train, [total - self.val_size, self.val_size]
            )
        if (stage == 'test' or stage is None) and self.test_set is None:
            self.test_set = datasets.CIFAR10(self.data_dir, train=False, transform=self.transform)

    def _loader(self, dataset, shuffle=False):
        """Wrap ``dataset`` in a DataLoader using this module's settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        """Return the shuffled training dataloader."""
        return self._loader(self.train_set, shuffle=True)

    def val_dataloader(self):
        """Return the validation dataloader (fixed order)."""
        return self._loader(self.val_set)

    def test_dataloader(self):
        """Return the test dataloader (fixed order)."""
        return self._loader(self.test_set)


## Model
Define a LightningModule wrapping a ResNet-50 model for classification.

In [None]:

class LitResNet(pl.LightningModule):
    """ResNet-50 classifier for 10 classes, trained from scratch with Adam.

    Args:
        lr: Learning rate passed to the Adam optimizer. It is stored in
            ``self.hparams`` by ``save_hyperparameters()``.
    """

    def __init__(self, lr=1e-3):
        super().__init__()
        self.save_hyperparameters()
        self.model = models.resnet50(weights=None, num_classes=10)
        self.criterion = nn.CrossEntropyLoss()
        # One metric object per stage. Metric objects accumulate state across
        # calls, so sharing a single instance would mix training, validation,
        # and test statistics. ``accuracy`` keeps its original name and is the
        # training metric.
        self.accuracy = torchmetrics.classification.MulticlassAccuracy(num_classes=10)
        self.val_accuracy = torchmetrics.classification.MulticlassAccuracy(num_classes=10)
        self.test_accuracy = torchmetrics.classification.MulticlassAccuracy(num_classes=10)

    def forward(self, x):
        """Return the class logits for a batch of images."""
        return self.model(x)

    def _shared_step(self, batch, metric):
        """Compute the loss for ``batch`` and update ``metric`` with its predictions."""
        x, y = batch
        # Call the module rather than .forward() so registered hooks run.
        logits = self(x)
        loss = self.criterion(logits, y)
        # The metric takes the argmax itself, so logits can be passed
        # directly; softmax would not change the predicted class.
        metric(logits, y)
        return loss

    def training_step(self, batch, batch_idx):
        """Log training loss/accuracy and return the loss for backpropagation."""
        loss = self._shared_step(batch, self.accuracy)
        self.log('train_loss', loss)
        # Logging the metric object (not a tensor) lets Lightning handle
        # compute/reset of its accumulated state.
        self.log('train_acc', self.accuracy, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss and accuracy for one batch."""
        loss = self._shared_step(batch, self.val_accuracy)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', self.val_accuracy, prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Log test loss and accuracy for one batch."""
        loss = self._shared_step(batch, self.test_accuracy)
        self.log('test_loss', loss)
        self.log('test_acc', self.test_accuracy)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with the configured ``lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)


## Training
Create the datamodule and model, then fit using the PyTorch Lightning trainer.

In [None]:

# Build the model and data source with their default hyperparameters.
model = LitResNet()
data_module = CIFAR10DataModule()

# A single epoch keeps this demo short.
trainer = pl.Trainer(max_epochs=1)
trainer.fit(model, datamodule=data_module)


## Evaluation
Evaluate the trained model on the test set and display metrics.

In [None]:

# Run the test loop on the held-out test split and show the logged metrics
# ('test_loss' and 'test_acc').
results = trainer.test(model=model, datamodule=data_module)
print(results)


You can visualize the results further using a confusion matrix.

In [None]:

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Switch to evaluation mode so BatchNorm uses its running statistics instead
# of per-batch statistics (and does not update them during this loop).
model.eval()

preds, targets = [], []
# No gradients are needed for inference; skipping autograd saves memory.
with torch.no_grad():
    for x, y in data_module.test_dataloader():
        # Run on whichever device the model currently lives on, then bring
        # the predictions back to the CPU for scikit-learn.
        logits = model(x.to(model.device))
        preds.append(logits.argmax(dim=1).cpu())
        targets.append(y.cpu())

preds = torch.cat(preds)
targets = torch.cat(targets)

# Rows are true classes, columns are predicted classes.
cm = confusion_matrix(targets.numpy(), preds.numpy())
ConfusionMatrixDisplay(cm).plot()
plt.show()
