In [1]:
import argparse
import math
import random
import os
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# -----------------------
# Reproducibility helpers
# -----------------------
def set_seed(seed: int = 42, deterministic: bool = False):
    """Seed Python's and PyTorch's RNGs and configure cuDNN.

    Args:
        seed: value used for ``random``, the torch CPU generator and every CUDA device.
        deterministic: when False (the default, matching the previous behaviour)
            cuDNN autotuning is enabled for speed, so GPU runs are *not*
            bit-for-bit reproducible. Pass True to request deterministic cuDNN
            kernels and disable autotuning (slower, but repeatable).
    """
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # silently ignored when CUDA is unavailable
    cudnn.deterministic = deterministic
    # Autotuning picks possibly non-deterministic algorithms, so it is only
    # enabled when determinism was not requested.
    cudnn.benchmark = not deterministic

# -----------------------
# Data
# -----------------------
def get_cifar10_loaders(batch_size: int = 128, num_workers: int = 4, root: str = "./data") -> Tuple[DataLoader, DataLoader]:
    """Build CIFAR-10 train/test DataLoaders, downloading the dataset into ``root`` if missing.

    The training set uses random-crop (pad 4) + horizontal-flip augmentation.
    Both splits are converted to tensors and normalized per channel.

    Args:
        batch_size: samples per batch for both loaders.
        num_workers: DataLoader worker processes for both loaders.
        root: directory where torchvision stores / looks for the dataset.

    Returns:
        ``(train_loader, test_loader)``. Only the training loader shuffles.
    """
    # Per-channel (RGB) mean/std commonly used to normalize CIFAR-10.
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2470, 0.2435, 0.2616)

    train_tf = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    test_tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train_ds = datasets.CIFAR10(root=root, train=True, download=True, transform=train_tf)
    test_ds = datasets.CIFAR10(root=root, train=False, download=True, transform=test_tf)

    # Pinned host memory only speeds up host->GPU copies; without a GPU,
    # PyTorch just warns, so enable it only when CUDA is available.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
    )
    train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_ds, shuffle=False, **loader_kwargs)
    return train_loader, test_loader

In [2]:
# -----------------------
# Train / Eval
# -----------------------
@torch.no_grad()
def evaluate(model, loader, device):
    """Compute the mean cross-entropy loss and top-1 accuracy of ``model`` over ``loader``.

    Autograd is disabled for the whole call. ``model`` is switched to eval mode
    and left in eval mode afterwards.

    Args:
        model: classifier returning logits; the prediction is ``argmax`` over dim 1.
        loader: iterable of ``(images, labels)`` batches.
        device: device every batch is moved to (should match the model's device).

    Returns:
        ``(mean_loss, accuracy)``; accuracy is a fraction in ``[0, 1]``.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    running_loss, correct, total = 0.0, 0, 0
    for images, labels in loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        outputs = model(images)
        # The criterion returns a per-batch mean; re-weight by batch size so a
        # smaller final batch does not skew the overall average.
        batch_n = labels.size(0)
        running_loss += criterion(outputs, labels).item() * batch_n
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += batch_n

    if total == 0:
        raise ValueError("evaluate() received an empty loader; cannot compute loss/accuracy")
    return running_loss / total, correct / total

def train_one_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader`` using cross-entropy loss.

    ``model`` is put in train mode. For every batch, gradients are cleared,
    the loss is back-propagated and ``optimizer.step()`` is called.

    Args:
        model: classifier returning logits; the prediction is ``argmax`` over dim 1.
        loader: iterable of ``(images, labels)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        device: device every batch is moved to (should match the model's device).

    Returns:
        ``(mean_loss, accuracy)``. Both are accumulated on the fly while the
        weights are being updated, so they summarise the epoch rather than the
        final weights.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    criterion = nn.CrossEntropyLoss()
    running_loss, correct, total = 0.0, 0, 0
    for images, labels in loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        # set_to_none=True releases gradient tensors instead of zero-filling them.
        optimizer.zero_grad(set_to_none=True)
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # The loss is a per-batch mean; re-weight by batch size for an exact epoch average.
        batch_n = labels.size(0)
        running_loss += loss.item() * batch_n
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += batch_n

    if total == 0:
        raise ValueError("train_one_epoch() received an empty loader; no samples were seen")
    return running_loss / total, correct / total

In [22]:
class AlexNetCIFAR(nn.Module):
    """
    AlexNet adapted for 32x32 inputs:
    - Use 3x3 convs (stride 1) instead of 11x11/5x5
    - Slightly reduced channels to fit CIFAR-10 scale
    """
    def __init__(self, num_classes: int = 10, dropout: float = 0.5):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=1, padding=1),  # 32x32 -> 32x32
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),                  # 32 -> 16

            nn.Conv2d(64, 192, kernel_size=5, padding=1),           # 16 -> 16
            nn.BatchNorm2d(192),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),                  # 16 -> 8

            nn.Conv2d(192, 384, kernel_size=5, padding=1),          # 8 -> 8
            nn.BatchNorm2d(384),
            nn.ReLU(inplace=True),

            nn.Conv2d(384, 256, kernel_size=3, padding=1),          # 8 -> 8
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            nn.Conv2d(256, 256, kernel_size=3, padding=1),          # 8 -> 8
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),                  # 8 -> 4
        )

        self.head = nn.Sequential(
            nn.Conv2d(256, num_classes, kernel_size=1),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )

    def forward(self, x):
        x = self.features(x)
        return self.head(x)

In [32]:
class VGGNetCIFAR(nn.Module):
    """VGG-style convnet for 32x32 CIFAR-10 inputs built only from 3x3 convolutions.

    Nine stride-1, padding-1 3x3 convs (each preserves spatial size) are grouped
    into three stages. Each stage ends with a 2x2 max-pool: 32 -> 16 -> 8 -> 4.
    Every conv except the last is followed by BatchNorm, and all use in-place
    ReLU. The classifier is a 1x1 conv to ``num_classes`` channels followed by
    global average pooling.

    Args:
        num_classes: number of output logits.
        dropout: accepted for interface compatibility; currently unused.
    """

    def __init__(self, num_classes: int = 10, dropout: float = 0.5):
        super().__init__()
        # Each entry is (out_channels, use_batchnorm); None inserts a 2x2 max-pool.
        # The order is significant: it fixes the indices inside `features`
        # (hence the state_dict keys) and the parameter-initialisation order.
        plan = (
            (8, True), (16, True), (64, True), None,                   # 32x32 -> 16x16
            (32, True), (192, True), None,                              # 16x16 -> 8x8
            (64, True), (384, True), (256, True), (256, False), None,  # 8x8 -> 4x4
        )
        layers = []
        channels = 3
        for step in plan:
            if step is None:
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            width, with_bn = step
            layers.append(nn.Conv2d(channels, width, kernel_size=3, stride=1, padding=1))
            if with_bn:
                layers.append(nn.BatchNorm2d(width))
            layers.append(nn.ReLU(inplace=True))
            channels = width
        self.features = nn.Sequential(*layers)

        # 1x1 conv to per-class maps, then global average pooling -> logits.
        self.head = nn.Sequential(
            nn.Conv2d(channels, num_classes, kernel_size=1),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        return self.head(self.features(x))

In [33]:
# Compare model sizes: total number of (trainable and non-trainable) parameters.
for label, model_cls in (("ALEXNET", AlexNetCIFAR), ("VGGNet", VGGNetCIFAR)):
    model = model_cls()
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total number of parameters in {label}: {total_params}")

Total number of parameters in ALEXNET: 3639882
Total number of parameters in VGGNet: 1896522


In [27]:
# Training hyper-parameters used by the cells below.
epochs       = 10     # passes over the training set per model
batch_size   = 128    # samples per batch (train and test loaders)
lr           = 1e-3   # Adam learning rate
weight_decay = 5e-4   # weight_decay argument passed to Adam
num_workers  = 2      # DataLoader worker processes
seed         = 42     # RNG seed passed to set_seed

In [6]:
# Seed the RNGs, choose the compute device, and build the CIFAR-10 loaders.
set_seed(seed)
has_gpu = torch.cuda.is_available()
device = "cuda" if has_gpu else "cpu"
train_loader, test_loader = get_cifar10_loaders(
    batch_size=batch_size,
    num_workers=num_workers,
)

100%|██████████| 170M/170M [00:04<00:00, 41.6MB/s]


In [34]:
def train_and_evaluate(model_name, num_classes=10):
    """Train one of the CIFAR-10 models, checkpoint its best epoch, and report test accuracy.

    Uses the notebook-level settings ``epochs``, ``lr``, ``weight_decay``,
    ``seed``, ``device``, ``train_loader`` and ``test_loader``. The weights
    with the highest per-epoch accuracy on ``test_loader`` are saved to
    ``checkpoints/<model_name>_best.pt``.

    NOTE: ``test_loader`` doubles as the validation set, so "Best Val Acc" and
    the saved checkpoint are selected on test data and are optimistically biased.

    Args:
        model_name: ``'alexnet'`` or ``'vggnet'`` (case-insensitive).
        num_classes: number of output classes passed to the model constructor.

    Raises:
        ValueError: if ``model_name`` is not a known architecture.
    """
    architectures = {"alexnet": AlexNetCIFAR, "vggnet": VGGNetCIFAR}
    arch = architectures.get(model_name.lower())
    if arch is None:
        # Previously any unknown name silently trained AlexNet under the wrong label.
        raise ValueError(
            f"Unknown model_name {model_name!r}; expected one of {sorted(architectures)}"
        )

    # Re-seed per run so weight init and shuffling do not depend on which
    # cells (or which other model) were executed before this one.
    set_seed(seed)
    model = arch(num_classes=num_classes).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    os.makedirs("checkpoints", exist_ok=True)
    ckpt_path = f"checkpoints/{model_name}_best.pt"

    best_acc = 0.0
    for epoch in range(1, epochs + 1):
        train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, device)
        val_loss, val_acc = evaluate(model, test_loader, device)

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save({"model": model.state_dict(),
                        "epoch": epoch,
                        "acc": best_acc},
                       ckpt_path)

        print(f"Epoch {epoch:02d}/{epochs} | "
              f"Train Loss {train_loss:.4f} Acc {train_acc*100:.2f}% | "
              f"Val Loss {val_loss:.4f} Acc {val_acc*100:.2f}% | "
              f"Best Val Acc {best_acc*100:.2f}%")

    # Final evaluation of the last-epoch weights (not the best checkpoint).
    test_loss, test_acc = evaluate(model, test_loader, device)
    print(f"\nFinal {model_name.upper()} Test Accuracy: {test_acc*100:.2f}% (loss {test_loss:.4f})")

In [36]:
train_and_evaluate("alexnet")  # AlexNet-style baseline

Epoch 01/10 | Train Loss 1.4239 Acc 47.82% | Val Loss 1.4908 Acc 49.50% | Best Val Acc 49.50%
Epoch 02/10 | Train Loss 1.0814 Acc 61.55% | Val Loss 0.9764 Acc 66.05% | Best Val Acc 66.05%
Epoch 03/10 | Train Loss 0.9357 Acc 67.13% | Val Loss 0.9969 Acc 66.08% | Best Val Acc 66.08%
Epoch 04/10 | Train Loss 0.8378 Acc 70.73% | Val Loss 0.8527 Acc 70.26% | Best Val Acc 70.26%
Epoch 05/10 | Train Loss 0.7738 Acc 72.99% | Val Loss 0.7195 Acc 75.41% | Best Val Acc 75.41%
Epoch 06/10 | Train Loss 0.7237 Acc 75.10% | Val Loss 0.7873 Acc 73.21% | Best Val Acc 75.41%
Epoch 07/10 | Train Loss 0.6834 Acc 76.06% | Val Loss 0.8066 Acc 73.52% | Best Val Acc 75.41%
Epoch 08/10 | Train Loss 0.6482 Acc 77.61% | Val Loss 0.7547 Acc 75.25% | Best Val Acc 75.41%
Epoch 09/10 | Train Loss 0.6223 Acc 78.45% | Val Loss 0.6679 Acc 77.84% | Best Val Acc 77.84%
Epoch 10/10 | Train Loss 0.5892 Acc 79.55% | Val Loss 0.6753 Acc 77.54% | Best Val Acc 77.84%

Final ALEXNET Test Accuracy: 77.54% (loss 0.6753)


In [35]:
train_and_evaluate("vggnet")  # 3x3-only VGG-style variant

Epoch 01/10 | Train Loss 1.4492 Acc 46.38% | Val Loss 1.2956 Acc 53.46% | Best Val Acc 53.46%
Epoch 02/10 | Train Loss 1.0078 Acc 63.93% | Val Loss 1.0559 Acc 64.58% | Best Val Acc 64.58%
Epoch 03/10 | Train Loss 0.8272 Acc 70.70% | Val Loss 1.1721 Acc 62.95% | Best Val Acc 64.58%
Epoch 04/10 | Train Loss 0.7122 Acc 75.27% | Val Loss 0.7277 Acc 75.17% | Best Val Acc 75.17%
Epoch 05/10 | Train Loss 0.6487 Acc 77.58% | Val Loss 0.8033 Acc 72.76% | Best Val Acc 75.17%
Epoch 06/10 | Train Loss 0.5997 Acc 79.48% | Val Loss 0.7025 Acc 77.10% | Best Val Acc 77.10%
Epoch 07/10 | Train Loss 0.5573 Acc 80.76% | Val Loss 0.6816 Acc 77.59% | Best Val Acc 77.59%
Epoch 08/10 | Train Loss 0.5205 Acc 82.14% | Val Loss 0.5812 Acc 80.58% | Best Val Acc 80.58%
Epoch 09/10 | Train Loss 0.4951 Acc 83.03% | Val Loss 0.6399 Acc 79.63% | Best Val Acc 80.58%
Epoch 10/10 | Train Loss 0.4674 Acc 84.05% | Val Loss 0.6458 Acc 79.79% | Best Val Acc 80.58%

Final VGGNET Test Accuracy: 79.79% (loss 0.6458)


# Conclusion

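Replacing AlexNet's 7x7 and 5x5 convolutions with stacks of 3x3 convolutions roughly halves the parameter count (1,896,522 vs. 3,639,882). It also improves accuracy: after 10 epochs the VGG-style network reaches 79.79% test accuracy versus 77.54% for AlexNet (best epoch: 80.58% vs. 77.84%). Caveats:
- The VGG-style model is also deeper (9 conv layers vs. 5) and uses different channel widths.
- Each model was trained once with a single seed.
- The "best" accuracy is selected on the test set.

The gap should therefore be read as indicative rather than conclusive.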