In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision
import torchvision.transforms as transforms
from torchvision.transforms import autoaugment
import torch.optim as optim
import torch.nn.init as init

In [2]:
def get_cifar10_loaders(
    batch_size: int = 256,
    num_workers: int = 8,
    pin_memory: bool = True,
    persistent_workers: bool = True,
):
    """Build the CIFAR-10 training and test ``DataLoader`` pair.

    The training pipeline applies random crop (padding 4), horizontal flip and
    the CIFAR-10 AutoAugment policy before tensor conversion and per-channel
    normalisation; the test pipeline only converts and normalises.

    Args:
        batch_size: Batch size used by both loaders.
        num_workers: Number of worker processes per loader. ``0`` loads data
            in the main process.
        pin_memory: Forwarded to ``DataLoader``.
        persistent_workers: Keep worker processes alive between epochs.
            Silently ignored when ``num_workers`` is 0, because ``DataLoader``
            rejects ``persistent_workers=True`` without worker processes.

    Returns:
        ``(trainloader, testloader)``. The training loader shuffles and drops
        the last incomplete batch; the test loader does neither.

    Side effects:
        Downloads the dataset into ``./data`` if it is not already there.
    """
    # Per-channel statistics shared by the train and test pipelines.
    normalize = transforms.Normalize(
        (0.4914, 0.4822, 0.4465),
        (0.2023, 0.1994, 0.2010),
    )

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(p=0.5),
        autoaugment.AutoAugment(autoaugment.AutoAugmentPolicy.CIFAR10),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    # DataLoader raises ValueError for persistent_workers=True with
    # num_workers=0, so only request persistence when workers exist.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=persistent_workers and num_workers > 0,
    )

    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform_train
    )
    trainloader = torch.utils.data.DataLoader(
        trainset,
        shuffle=True,
        drop_last=True,  # keeps every training batch full
        **loader_kwargs,
    )

    testset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test
    )
    testloader = torch.utils.data.DataLoader(
        testset,
        shuffle=False,
        **loader_kwargs,
    )

    return trainloader, testloader

In [3]:
class LayerNorm2d(nn.Module):
    """LayerNorm over the channel axis of a channels-first ``[B, C, H, W]`` tensor.

    ``nn.LayerNorm(channels)`` normalises the trailing dimension, so the
    channel axis is moved to the end for the call and moved back afterwards.
    """

    def __init__(self, channels):
        super().__init__()
        self.norm = nn.LayerNorm(channels)

    def forward(self, x):
        """Normalise ``x`` ([B, C, H, W]) per spatial position; shape is preserved."""
        channels_last = x.permute(0, 2, 3, 1)       # [B, C, H, W] -> [B, H, W, C]
        normalised = self.norm(channels_last)
        return normalised.permute(0, 3, 1, 2)        # back to [B, C, H, W]

In [4]:
class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm layers plus a shortcut branch.

    The shortcut is an empty ``nn.Sequential`` (identity) unless the block
    changes resolution (``stride != 1``) or width (``in_planes != planes``),
    in which case it is a 1x1 strided convolution followed by batch-norm.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        conv3x3 = dict(kernel_size=3, padding=1, bias=False)

        # Main branch: only the first convolution carries the stride.
        self.conv1 = nn.Conv2d(in_planes, planes, stride=stride, **conv3x3)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, stride=1, **conv3x3)
        self.bn2 = nn.BatchNorm2d(planes)

        # Shortcut branch: project only when shapes would otherwise mismatch.
        needs_projection = stride != 1 or in_planes != planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        main = F.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))
        main += self.shortcut(x)
        return F.relu(main)

In [5]:
class ResNet110(nn.Module):
    """CIFAR-style ResNet built from ``BasicBlock``.

    A 3x3 stem convolution is followed by three stages of widths 16, 32 and
    64 (the second and third stage downsample with stride 2), global average
    pooling and a linear classifier. With the default 18 blocks per stage the
    depth is 18 * 2 * 3 + 2 = 110 layers.

    Args:
        num_blocks: Sequence with the number of blocks in each of the three
            stages.
        num_classes: Size of the classifier output.
    """

    # A tuple default avoids sharing one mutable list between all instances.
    def __init__(self, num_blocks=(18, 18, 18), num_classes=10):
        super().__init__()
        self.in_planes = 16

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)

        self.layer1 = self._make_layer(16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(64, num_blocks[2], stride=2)

        self.linear = nn.Linear(64, num_classes)

        self._initialize_weights()

    def _make_layer(self, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one applies ``stride``.

        Updates ``self.in_planes`` so the next stage sees the new width.
        """
        block_strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in block_strides:
            layers.append(BasicBlock(self.in_planes, planes, block_stride))
            self.in_planes = planes
        return nn.Sequential(*layers)

    def _initialize_weights(self):
        """Kaiming-normal convs, unit/zero batch-norm, small-normal linear layer."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, 0, 0.01)
                init.constant_(m.bias, 0)

    def forward(self, x):
        """Map a ``[B, 3, H, W]`` batch to ``[B, num_classes]`` logits."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        # Global average pooling that does not depend on the feature map being
        # square: pooling with a kernel equal to the width left more than one
        # spatial cell whenever H != W, which broke the 64-feature classifier.
        out = F.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)
        return self.linear(out)

In [6]:
def Net(num_classes=10):
    """Build the CIFAR ResNet-110: three stages of 18 blocks each."""
    blocks_per_stage = [18] * 3
    return ResNet110(num_blocks=blocks_per_stage, num_classes=num_classes)

In [7]:
import time
import torch
from torch import nn, optim


def train_model(
    net: torch.nn.Module,
    get_cifar10_loaders,  # existing dataloader factory
    evaluate_model,       # existing eval helper
    batch_size: int = 256,
    epochs: int = 300,
    lr: float = 1e-3,
    eval_interval: int = 10,
    checkpoint_path: str = "model_checkpoint_recursion16.pth",
):
    """Speed‑tuned training loop with per‑epoch timing & throughput report."""

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # compile once for graph fusion
    net = torch.compile(net, mode="reduce-overhead").to(device)

    # data
    trainloader, testloader = get_cifar10_loaders(
        batch_size=batch_size,
        num_workers=8,
        pin_memory=True,
        persistent_workers=True,
    )

    # optimisation
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=epochs, eta_min=1e-6
    )

    # new AMP API
    scaler = torch.amp.GradScaler(enabled=device.type == "cuda")

    best_test_acc = 0.0
    num_samples = len(trainloader.dataset)

    for epoch in range(epochs):
        start_time = time.perf_counter()

        net.train()
        running_loss = 0.0

        for inputs, targets in trainloader:
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            optimizer.zero_grad(set_to_none=True)

            # forward – mixed precision
            with torch.amp.autocast(device_type="cuda"):
                outputs = net(inputs)
                loss = criterion(outputs, targets)

            # backward + step
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=0.5)
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()

        scheduler.step()

        # ensure accurate timing for CUDA
        torch.cuda.synchronize()
        epoch_time = time.perf_counter() - start_time
        avg_loss = running_loss / len(trainloader)
        throughput = num_samples / epoch_time

        print(
            f"Epoch {epoch + 1:03d}/{epochs} ─ {epoch_time:.2f}s | "
            f"loss {avg_loss:.3f} | {throughput:.1f} img/s"
        )

        # evaluation & checkpointing
        if (epoch + 1) % eval_interval == 0 or epoch == epochs - 1:
            print(f"\nEvaluating at epoch {epoch + 1} …")
            train_acc = evaluate_model(net, trainloader, criterion, "Train", device)
            test_acc = evaluate_model(net, testloader, criterion, "Test", device)
            print(
                f"Current learning rate: {optimizer.param_groups[0]['lr']:.2e}\n"
            )

            if test_acc > best_test_acc:
                best_test_acc = test_acc
                torch.save(
                    {
                        "epoch": epoch + 1,
                        "model_state_dict": net.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "scheduler_state_dict": scheduler.state_dict(),
                        "test_acc": test_acc,
                    },
                    checkpoint_path,
                )
                print(
                    f"New best model saved with test accuracy: {test_acc:.2f}%"
                )

    print("Finished Training")
    print(f"Best test accuracy achieved: {best_test_acc:.2f}%")
    return best_test_acc

In [8]:
def evaluate_model(
    net: torch.nn.Module,
    dataloader,
    criterion,
    dataset_name: str = "",
    device=None,
):
    """Run inference on *dataloader* and report accuracy and mean batch loss.

    Mixed precision is used when the target device is CUDA; other devices run
    in full precision. The model's train/eval mode is restored on exit, even
    if evaluation raises.

    Args:
        net: Model to evaluate.
        dataloader: Iterable of ``(images, labels)`` batches.
        criterion: Loss callable returning a scalar per batch.
        dataset_name: Prefix used in the printed report.
        device: ``torch.device`` or device string; defaults to CUDA when
            available, otherwise CPU.

    Returns:
        Accuracy in percent.

    Raises:
        ValueError: If *dataloader* yields no samples.
    """

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(device)
    use_amp = device.type == "cuda"

    was_training = net.training  # remember current mode
    net.eval()

    # Accumulate on-device so each batch does not force a host synchronisation.
    correct = torch.zeros((), dtype=torch.long, device=device)
    loss_sum = torch.zeros((), device=device)
    total = 0
    num_batches = 0

    try:
        with torch.no_grad(), torch.amp.autocast(
            device_type=device.type, enabled=use_amp
        ):
            for images, labels in dataloader:
                images = images.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)

                outputs = net(images)
                loss_sum += criterion(outputs, labels).float()

                predicted = outputs.argmax(dim=1)
                correct += (predicted == labels).sum()
                total += labels.size(0)
                num_batches += 1
    finally:
        if was_training:
            net.train()

    if total == 0:
        raise ValueError("evaluate_model: dataloader yielded no samples")

    accuracy = 100.0 * correct.item() / total
    avg_loss = loss_sum.item() / num_batches  # mean of per-batch losses
    print(f"{dataset_name} Accuracy: {accuracy:.2f}%")
    print(f"{dataset_name} Average Loss: {avg_loss:.4f}")

    return accuracy

In [9]:
def count_parameters(model):
    """Return the total element count of the model's trainable parameters."""
    trainable_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total

# Example usage: build the network and report how many parameters it trains.
model = Net()
print(f"Total Trainable Parameters: {count_parameters(model=model):,}")

Total Trainable Parameters: 1,730,714


In [10]:
# Build a fresh model, enable cuDNN benchmark mode, then start training.
net = Net()
torch.backends.cudnn.benchmark = True
train_model(
    net=net,
    get_cifar10_loaders=get_cifar10_loaders,
    evaluate_model=evaluate_model,
)

Epoch [1] loss: 1.967
Epoch [2] loss: 1.638
Epoch [3] loss: 1.438
Epoch [4] loss: 1.284
Epoch [5] loss: 1.156
Epoch [6] loss: 1.051
Epoch [7] loss: 0.976
Epoch [8] loss: 0.911
Epoch [9] loss: 0.864
Epoch [10] loss: 0.818

Evaluating at epoch 10 …
Train Accuracy: 67.34%
Train Average Loss: 0.9495
Test Accuracy: 73.90%
Test Average Loss: 0.7733
Current learning rate: 9.97e-04

New best model saved with test accuracy: 73.90%
Epoch [11] loss: 0.785
Epoch [12] loss: 0.748
Epoch [13] loss: 0.719
Epoch [14] loss: 0.700
Epoch [15] loss: 0.673
Epoch [16] loss: 0.649
Epoch [17] loss: 0.637
Epoch [18] loss: 0.613
Epoch [19] loss: 0.595
Epoch [20] loss: 0.585

Evaluating at epoch 20 …
Train Accuracy: 75.33%
Train Average Loss: 0.7251
Test Accuracy: 79.45%
Test Average Loss: 0.6370
Current learning rate: 9.89e-04

New best model saved with test accuracy: 79.45%
Epoch [21] loss: 0.567
Epoch [22] loss: 0.551
Epoch [23] loss: 0.537
Epoch [24] loss: 0.537
Epoch [25] loss: 0.522
Epoch [26] loss: 0.508
E