In [None]:
from lightning import LightningModule
from lightning.pytorch import Trainer
from lightning.pytorch.tuner import Tuner

import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader

from torchvision import datasets
from torchvision.transforms import v2

# Data Processing

In [None]:
# PIL image -> tensor image, then uint8 -> float32 with values rescaled to
# [0, 1] (`scale=True`). No normalisation or augmentation is applied.
transform = v2.Compose(
    [
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
    ]
)

In [None]:
# Download CIFAR-10 (skipped if already present under DATA_ROOT) and build one
# dataset per split. `train=` must be passed explicitly: it defaults to True,
# so omitting it would make 'val' a second copy of the training set. The
# official CIFAR-10 test split (train=False) serves as validation data.
DATA_ROOT = '../dataset/'

data = {
    split: datasets.CIFAR10(
        root=DATA_ROOT,
        train=(split == 'train'),
        download=True,
        transform=transform,
    ) for split in ['train', 'val']
}

# Guard against the two splits silently being the same data.
assert len(data['train']) == 50_000 and len(data['val']) == 10_000, \
    "unexpected CIFAR-10 split sizes; check the `train=` flag"

# Model

In [None]:

class AlexNetCIFAR10(LightningModule):
    """AlexNet-style CNN classifier for CIFAR-10, trained with plain SGD.

    The convolutional stack shrinks a 32x32 input through three
    kernel-3/stride-2 max-pools (32 -> 15 -> 7 -> 3). The first linear
    layer's ``256 * 3 * 3`` input size relies on that. The last layer emits
    10 unnormalised logits, one per class.

    The module builds its own data loaders from the notebook-level ``data``
    dict (keys ``'train'`` and ``'val'``), so that cell must run first.

    Args:
        batch_size: Mini-batch size for both loaders. Stored as
            ``self.batch_size`` so ``Tuner.scale_batch_size`` can find and
            overwrite it.
        lr: SGD learning rate. The attribute is named ``lr`` so
            ``Tuner.lr_find`` can locate it as well.
        num_workers: Worker processes per ``DataLoader``.
    """

    def __init__(self, batch_size=1, lr=1e-3, num_workers=4):
        super().__init__()
        self.batch_size = batch_size
        self.lr = lr
        self.num_workers = num_workers
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=5, stride=1, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),

            nn.Conv2d(64, 128, kernel_size=5, stride=1, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),

            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(),  # default p=0.5
            nn.Linear(256 * 3 * 3, 4096),
            nn.ReLU(inplace=True),

            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),

            nn.Linear(4096, 10)
        )

    def forward(self, x):
        """Map images ``(N, 3, 32, 32)`` to class logits ``(N, 10)``."""
        x = self.features(x)
        x = torch.flatten(x, 1)  # (N, 256, 3, 3) -> (N, 2304)
        return self.classifier(x)

    def configure_optimizers(self):
        """Plain SGD (no momentum) over all parameters at ``self.lr``."""
        return optim.SGD(self.parameters(), lr=self.lr)

    def _shared_step(self, batch):
        """Return ``(logits, targets, loss)`` for one ``(images, labels)`` batch."""
        x, y = batch
        # Call the module rather than .forward() so nn.Module hooks still run.
        y_hat = self(x)
        return y_hat, y, F.cross_entropy(y_hat, y)

    def train_dataloader(self):
        """Training loader, reshuffled every epoch."""
        return DataLoader(
            dataset=data['train'],
            batch_size=self.batch_size,
            # Without shuffling, every epoch would see identical batches in
            # the same fixed order.
            shuffle=True,
            num_workers=self.num_workers,
        )

    def training_step(self, train_batch, batch_idx):
        """Log cross-entropy on a training batch and return it for backprop."""
        _, _, loss = self._shared_step(train_batch)
        self.log('train_loss', loss, on_epoch=True)
        return loss

    def val_dataloader(self):
        """Validation loader in fixed (unshuffled) order."""
        return DataLoader(
            dataset=data['val'],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

    def validation_step(self, val_batch, batch_idx):
        """Log cross-entropy loss and top-1 accuracy on a validation batch."""
        y_hat, y, loss = self._shared_step(val_batch)
        acc = (y_hat.argmax(dim=1) == y).float().mean()
        self.log('val_loss', loss, on_epoch=True)
        self.log('val_acc', acc, on_epoch=True)

In [None]:
# Seed torch's global RNG (CPU and CUDA) before building the model, so weight
# initialisation and later torch-RNG draws such as dropout masks are
# reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)
model = AlexNetCIFAR10()

# Training

In [None]:
# One Trainer drives both the batch-size search and the final fit; the Tuner
# is bound to this same trainer instance.
trainer = Trainer(
    max_epochs=5,
    log_every_n_steps=16,
)
tuner = Tuner(trainer=trainer)

In [None]:
# Grow model.batch_size by doubling ("power" mode) to find the largest size
# that fits in memory; the chosen value is displayed as the cell output.
tuner.scale_batch_size(model=model, mode="power")

In [None]:
# Train with the tuned batch size; validation runs through the model's own
# val_dataloader / validation_step.
trainer.fit(model=model)