# Imports

In [1]:
from functools import partial

import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics import Accuracy
from torchvision import datasets
from torchvision.transforms import v2

In [2]:
cuda_available = torch.cuda.is_available()
print(f"CUDA Available: {cuda_available}")

CUDA Available: True


In [3]:
torch.set_float32_matmul_precision('high')

# Hyperparameters

In [4]:
batch_size = 256
learning_rate = 0.01

# Data loading

In [5]:
transform = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
])

# Load CIFAR10 dataset
trainset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
testset = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)
trainset, testset

(Dataset CIFAR10
     Number of datapoints: 50000
     Root location: ./data
     Split: Train
     StandardTransform
 Transform: Compose(
                  ToImage()
                  ToDtype(scale=True)
            ),
 Dataset CIFAR10
     Number of datapoints: 10000
     Root location: ./data
     Split: Test
     StandardTransform
 Transform: Compose(
                  ToImage()
                  ToDtype(scale=True)
            ))

In [6]:
trainset[0][0].shape

torch.Size([3, 32, 32])

In [7]:
# Split trainset into train and validation
trainset, valset = torch.utils.data.random_split(trainset, [0.8, 0.2])
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4, 
    pin_memory=True,
)
valloader = torch.utils.data.DataLoader(
    valset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4, 
    pin_memory=True,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4, 
    pin_memory=True,
)

# Model definition

In [8]:
DefaultConv2d = partial(nn.Conv2d, kernel_size=3, padding=1, bias=False)
DownSample = partial(nn.Conv2d, kernel_size=1, stride=2, bias=False)

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        self.net = nn.Sequential(
            DefaultConv2d(in_channels, out_channels, stride=stride),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            DefaultConv2d(out_channels, out_channels),
            nn.BatchNorm2d(out_channels),
        )
        self.downsample = downsample
    def forward(self, x):
        residual = x
        x = self.net(x)
        if self.downsample:
            residual = self.downsample(residual)
        x += residual
        return F.relu(x)

class ResNet(nn.Module):
    def __init__(self, num_classes, input_channels=3, lpb=3, filters=32, shrinks=2, dropout=0.5):
        super().__init__()
        self.net = nn.Sequential()
        self.net.add_module('conv', DefaultConv2d(input_channels, filters))
        for i in range(shrinks):
            for j in range(lpb):
                self.net.add_module(f'resblock{j}_{filters}', ResidualBlock(filters, filters))
            self.net.add_module(f'resblock_down{i}', ResidualBlock(filters, filters*2, stride=2, downsample=DownSample(filters, filters*2)))
            filters *= 2
        for i in range(lpb):
            self.net.add_module(f'resblock{i}_{filters}', ResidualBlock(filters, filters))
        self.net.add_module('avgpool', nn.AdaptiveAvgPool2d(1))
        self.net.add_module('flatten', nn.Flatten())
        self.net.add_module('linear', nn.Linear(filters, filters*2))
        self.net.add_module('relu', nn.ReLU())
        self.net.add_module('dropout', nn.Dropout(dropout))
        self.net.add_module('output', nn.Linear(filters*2, num_classes))
    def forward(self, x):
        return self.net(x)


In [9]:
def init_weights(m):
    if type(m) == nn.Linear or type(m) == nn.Conv2d:
        nn.init.kaiming_normal_(m.weight, nonlinearity='relu')

In [10]:
model = ResNet(num_classes=10, lpb=5, filters=64, shrinks=1, dropout=0.2)
model.apply(init_weights)
print(model)

ResNet(
  (net): Sequential(
    (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (resblock0_64): ResidualBlock(
      (net): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (resblock1_64): ResidualBlock(
      (net): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True

# Lightning

In [11]:
class CIFARLightning(L.LightningModule):
    def __init__(self, model, lr=1e-3):
        super().__init__()
        self.model = model
        self.lr = lr
        self.accuracy = Accuracy(task='multiclass', num_classes=10)

    def forward(self, x):
        return self.model(x)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("train_loss", loss)
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("val_loss", loss)
        self.log("val_acc", self.accuracy(y_hat, y), on_epoch=True, prog_bar=True)
        self.log('learning_rate', self.lr_schedulers().get_last_lr()[0], on_epoch=True, prog_bar=True)
    
    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("test_loss", loss)
        self.log("test_acc", self.accuracy(y_hat, y))
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        sch = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": sch,
                "interval": "epoch",
                "monitor": "val_acc"
            }
        }

In [12]:
early_stop = L.pytorch.callbacks.EarlyStopping(monitor='val_acc', patience=10, mode='max')

model = CIFARLightning(model, learning_rate)
trainer = L.Trainer(callbacks=[early_stop])
trainer.fit(model, trainloader, valloader)
trainer.test(model, testloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


/home/chandon/miniconda3/envs/pyto/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
/home/chandon/miniconda3/envs/pyto/lib/python3.11/site-packages/lightning/pytorch/loops/utilities.py:73: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type               | Params | Mode 
--------------------------------------------------------
0 | model    | ResNet             | 2.1 M  | train
1 | accuracy | MulticlassAccuracy | 0 

Epoch 59: 100%|██████████| 157/157 [00:20<00:00,  7.65it/s, v_num=25, val_acc=0.859, learning_rate=0.000625]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]



Testing DataLoader 0: 100%|██████████| 40/40 [00:01<00:00, 30.67it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.8544999957084656
        test_loss           1.2127819061279297
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 1.2127819061279297, 'test_acc': 0.8544999957084656}]