In [1]:
import argparse
import math
import random
import os
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# -----------------------
# Reproducibility helpers
# -----------------------
def set_seed(seed: int = 42):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn.deterministic = False
    cudnn.benchmark = True  # faster on GPUs for conv nets

# -----------------------
# Data
# -----------------------
def get_cifar10_loaders(batch_size: int = 128, num_workers: int = 4) -> Tuple[DataLoader, DataLoader]:
    mean = (0.4914, 0.4822, 0.4465)
    std  = (0.2470, 0.2435, 0.2616)

    train_tf = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    test_tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train_ds = datasets.CIFAR10(root="./data", train=True, download=True, transform=train_tf)
    test_ds  = datasets.CIFAR10(root="./data", train=False, download=True, transform=test_tf)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,  num_workers=num_workers, pin_memory=True)
    test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
    return train_loader, test_loader

In [2]:
class AlexNetPlainCIFAR(nn.Module):
    """
    AlexNet adapted for 32x32 inputs:
    - Use 3x3 convs (stride 1) instead of 11x11/5x5
    - Slightly reduced channels to fit CIFAR-10 scale
    """
    def __init__(self, num_classes: int = 10, dropout: float = 0.5):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),  # 32x32 -> 32x32
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),                  # 32 -> 16

            nn.Conv2d(64, 192, kernel_size=3, padding=1),           # 16 -> 16
            nn.BatchNorm2d(192),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),                  # 16 -> 8

            nn.Conv2d(192, 384, kernel_size=3, padding=1),          # 8 -> 8
            nn.BatchNorm2d(384),
            nn.ReLU(inplace=True),

            nn.Conv2d(384, 256, kernel_size=3, padding=1),          # 8 -> 8
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            nn.Conv2d(256, 256, kernel_size=3, padding=1),          # 8 -> 8
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),                  # 8 -> 4
        )

        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(256 * 4 * 4, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)


class AlexNetCIFAR(nn.Module):
    """
    AlexNet adapted for 32x32 inputs:
    - Use 3x3 convs (stride 1) instead of 11x11/5x5
    - Slightly reduced channels to fit CIFAR-10 scale
    - Removing the FC layers to use a pointwise convolution and Global Average Pooling
    """
    def __init__(self, num_classes: int = 10, dropout: float = 0.5):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),  # 32x32 -> 32x32
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),                  # 32 -> 16

            nn.Conv2d(64, 192, kernel_size=3, padding=1),           # 16 -> 16
            nn.BatchNorm2d(192),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),                  # 16 -> 8

            nn.Conv2d(192, 384, kernel_size=3, padding=1),          # 8 -> 8
            nn.BatchNorm2d(384),
            nn.ReLU(inplace=True),

            nn.Conv2d(384, 256, kernel_size=3, padding=1),          # 8 -> 8
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            nn.Conv2d(256, 256, kernel_size=3, padding=1),          # 8 -> 8
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),                  # 8 -> 4
        )


        self.head = nn.Sequential(
            nn.Conv2d(256, num_classes, kernel_size=1),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )

    def forward(self, x):
        x = self.features(x)
        return self.head(x)

class NiNCIFAR(nn.Module):
    """
    AlexNet modified to include pointwise convolutions throughout the network.
    """
    def __init__(self, num_classes: int = 10, dropout: float = 0.5):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 128, kernel_size=1, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 192, kernel_size=1, stride=1, padding=1), # expand (ensuring the 128 x 192 uses pointwise convolution)
            nn.BatchNorm2d(192),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(192, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, kernel_size=1, padding=1),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 384, kernel_size=1, padding=1), # expand
            nn.BatchNorm2d(384),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(384, 256, kernel_size=1, padding=1),  # contract (ensuring the 384 x 256 uses pointwise convolution)
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, kernel_size=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )


        self.head = nn.Sequential(
            nn.Conv2d(256, num_classes, kernel_size=1), # contract
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )

    def forward(self, x):
        x = self.features(x)
        return self.head(x)

In [3]:
model = AlexNetPlainCIFAR()
total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters in ALEXNET (plain): {total_params}")

model = AlexNetCIFAR()
total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters in ALEXNET (no FC layers): {total_params}")

model = NiNCIFAR()
total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters in NiN: {total_params}")

Total number of parameters in ALEXNET (plain): 6979146
Total number of parameters in ALEXNET (no FC layers): 2256458
Total number of parameters in NiN: 1485386


In [4]:
# -----------------------
# Train / Eval
# -----------------------
@torch.no_grad()
def evaluate(model, loader, device):
    model.eval()
    total, correct, running_loss = 0, 0, 0.0
    criterion = nn.CrossEntropyLoss()
    for images, labels in loader:
        images, labels = images.to(device, non_blocking=True), labels.to(device, non_blocking=True)
        outputs = model(images)
        loss = criterion(outputs, labels)
        running_loss += loss.item() * images.size(0)
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    return running_loss / total, correct / total

def train_one_epoch(model, loader, optimizer, device):
    model.train()
    criterion = nn.CrossEntropyLoss()
    running_loss, correct, total = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device, non_blocking=True), labels.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    return running_loss / total, correct / total

In [5]:
epochs = 10
batch_size = 126
lr = 0.0005
weight_decay = 5e-4
seed = 42

set_seed()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
train_loader, test_loader = get_cifar10_loaders(batch_size=batch_size)

Using device: cuda


100%|██████████| 170M/170M [00:02<00:00, 77.5MB/s]


In [6]:
def train_and_evaluate(model_name='alexnet'):
  if model_name == "alexnet":
    model = AlexNetCIFAR()
  elif model_name == 'alexnet_plain':
    model = AlexNetPlainCIFAR()
  else:
    model = NiNCIFAR()

  model.to(device)
  optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=weight_decay)

  best_acc = 0.0
  for epoch in range(1, epochs + 1):
      train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, device)
      val_loss,  val_acc  = evaluate(model, test_loader, device)

      if val_acc > best_acc:
          best_acc = val_acc
          os.makedirs("checkpoints", exist_ok=True)
          torch.save({"model": model.state_dict(),
                      "epoch": epoch,
                      "acc": best_acc
                      },
                      f"checkpoints/{model_name}_best.pt")

      print(f"Epoch {epoch:02d}/{epochs} | "
            f"Train Loss {train_loss:.4f} Acc {train_acc*100:.2f}% | "
            f"Val Loss {val_loss:.4f} Acc {val_acc*100:.2f}% | "
            f"Best Val Acc {best_acc*100:.2f}%")

  # Final evaluation
  test_loss, test_acc = evaluate(model, test_loader, device)
  print(f"\nFinal {model_name.upper()} Test Accuracy: {test_acc*100:.2f}% (loss {test_loss:.4f})")

In [7]:
train_and_evaluate(model_name='alexnet_plain')



Epoch 01/10 | Train Loss 1.5478 Acc 42.47% | Val Loss 1.2026 Acc 56.20% | Best Val Acc 56.20%
Epoch 02/10 | Train Loss 1.1516 Acc 58.80% | Val Loss 1.0272 Acc 63.28% | Best Val Acc 63.28%
Epoch 03/10 | Train Loss 0.9975 Acc 64.88% | Val Loss 1.0372 Acc 63.49% | Best Val Acc 63.49%
Epoch 04/10 | Train Loss 0.9002 Acc 68.46% | Val Loss 0.8523 Acc 71.69% | Best Val Acc 71.69%
Epoch 05/10 | Train Loss 0.8167 Acc 71.88% | Val Loss 0.9406 Acc 70.55% | Best Val Acc 71.69%
Epoch 06/10 | Train Loss 0.7516 Acc 74.48% | Val Loss 0.7157 Acc 75.55% | Best Val Acc 75.55%
Epoch 07/10 | Train Loss 0.7014 Acc 76.37% | Val Loss 0.6688 Acc 77.41% | Best Val Acc 77.41%
Epoch 08/10 | Train Loss 0.6614 Acc 77.89% | Val Loss 0.6266 Acc 79.05% | Best Val Acc 79.05%
Epoch 09/10 | Train Loss 0.6185 Acc 79.45% | Val Loss 0.6859 Acc 77.80% | Best Val Acc 79.05%
Epoch 10/10 | Train Loss 0.5916 Acc 80.16% | Val Loss 0.6424 Acc 78.58% | Best Val Acc 79.05%

Final ALEXNET_PLAIN Test Accuracy: 78.58% (loss 0.6424)


In [20]:
train_and_evaluate(model_name='alexnet')

Epoch 01/10 | Train Loss 1.3292 Acc 51.58% | Val Loss 1.5289 Acc 48.95% | Best Val Acc 48.95%
Epoch 02/10 | Train Loss 0.9533 Acc 66.22% | Val Loss 0.9590 Acc 66.35% | Best Val Acc 66.35%
Epoch 03/10 | Train Loss 0.7942 Acc 72.14% | Val Loss 1.0031 Acc 65.33% | Best Val Acc 66.35%
Epoch 04/10 | Train Loss 0.7082 Acc 75.37% | Val Loss 0.9508 Acc 68.23% | Best Val Acc 68.23%
Epoch 05/10 | Train Loss 0.6496 Acc 77.54% | Val Loss 0.9267 Acc 70.45% | Best Val Acc 70.45%
Epoch 06/10 | Train Loss 0.5981 Acc 79.37% | Val Loss 0.6574 Acc 77.09% | Best Val Acc 77.09%
Epoch 07/10 | Train Loss 0.5586 Acc 80.80% | Val Loss 0.8094 Acc 73.58% | Best Val Acc 77.09%
Epoch 08/10 | Train Loss 0.5263 Acc 81.88% | Val Loss 0.6889 Acc 76.89% | Best Val Acc 77.09%
Epoch 09/10 | Train Loss 0.4966 Acc 83.04% | Val Loss 0.6744 Acc 78.41% | Best Val Acc 78.41%
Epoch 10/10 | Train Loss 0.4738 Acc 83.71% | Val Loss 0.5613 Acc 81.14% | Best Val Acc 81.14%

Final ALEXNET Test Accuracy: 81.14% (loss 0.5613)


In [19]:
train_and_evaluate(model_name='NiN')

Epoch 01/10 | Train Loss 1.4436 Acc 47.06% | Val Loss 1.4685 Acc 47.14% | Best Val Acc 47.14%
Epoch 02/10 | Train Loss 1.1525 Acc 58.51% | Val Loss 1.2084 Acc 57.39% | Best Val Acc 57.39%
Epoch 03/10 | Train Loss 1.0611 Acc 62.12% | Val Loss 1.0820 Acc 61.19% | Best Val Acc 61.19%
Epoch 04/10 | Train Loss 0.9811 Acc 64.91% | Val Loss 1.0550 Acc 62.91% | Best Val Acc 62.91%
Epoch 05/10 | Train Loss 0.9309 Acc 67.10% | Val Loss 0.9825 Acc 65.69% | Best Val Acc 65.69%
Epoch 06/10 | Train Loss 0.8748 Acc 69.26% | Val Loss 0.9450 Acc 68.04% | Best Val Acc 68.04%
Epoch 07/10 | Train Loss 0.8292 Acc 70.85% | Val Loss 0.8811 Acc 69.52% | Best Val Acc 69.52%
Epoch 08/10 | Train Loss 0.7979 Acc 72.09% | Val Loss 1.0280 Acc 65.16% | Best Val Acc 69.52%
Epoch 09/10 | Train Loss 0.7637 Acc 73.49% | Val Loss 0.9304 Acc 67.89% | Best Val Acc 69.52%
Epoch 10/10 | Train Loss 0.7444 Acc 74.03% | Val Loss 0.7516 Acc 74.27% | Best Val Acc 74.27%

Final NIN Test Accuracy: 74.27% (loss 0.7516)


# Conclusion
- During convolution operations, large number of filters appear in deeper layers. Thus each additional convolution layer becomes expensive as its adds potentially millions of parameters to the network.
- With pointwise convolutions, we are able to control the channel depth at any point in the network without adding too many learnable parameters to the network.
- In the example above, we control the number of parameters in the network to drop by 30% compared to AlexNet (5x compared to AlexNet plain). Yet, training the exact same amount of time on the same data yeilds comparable performance.