In [None]:
import argparse
import math
import random
import os
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# -----------------------
# Reproducibility helpers
# -----------------------
def set_seed(seed: int = 42, deterministic: bool = False):
    """Seed Python's and PyTorch's RNGs and configure cuDNN.

    Args:
        seed: value used for ``random``, ``torch`` (CPU) and every CUDA device.
        deterministic: if True, force deterministic cuDNN kernels and disable
            autotuning so GPU runs are repeatable. The default (False) keeps the
            speed-oriented setting: cuDNN may pick non-deterministic algorithms,
            so GPU results can differ slightly between runs with the same seed.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn.deterministic = deterministic
    # benchmark=True autotunes conv algorithms per input shape: faster for conv
    # nets, but the chosen algorithm (and thus the numerics) can vary per run.
    cudnn.benchmark = not deterministic

# -----------------------
# Data
# -----------------------
def get_cifar10_loaders(batch_size: int = 128, num_workers: int = 4,
                        root: str = "./data") -> Tuple[DataLoader, DataLoader]:
    """Create CIFAR-10 train and test DataLoaders (downloading into ``root`` if missing).

    Training images get random-crop (padding 4) + horizontal-flip augmentation;
    test images are only converted and normalised. Both splits share the same
    per-channel mean/std normalisation.

    Args:
        batch_size: samples per batch for both loaders.
        num_workers: number of DataLoader worker processes.
        root: directory where torchvision stores / looks for the dataset.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.
    """
    # Per-channel (R, G, B) normalisation constants commonly used for CIFAR-10.
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2470, 0.2435, 0.2616)

    train_tf = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    test_tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train_ds = datasets.CIFAR10(root=root, train=True, download=True, transform=train_tf)
    test_ds = datasets.CIFAR10(root=root, train=False, download=True, transform=test_tf)

    # Pinned host memory only speeds up host->GPU copies; without CUDA it is
    # wasted work (and recent PyTorch versions warn about it).
    pin_memory = torch.cuda.is_available()

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=pin_memory)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers, pin_memory=pin_memory)
    return train_loader, test_loader

In [None]:
# -----------------------
# Train / Eval
# -----------------------
@torch.no_grad()
def evaluate(model, loader, device):
    """Compute the average loss and accuracy of ``model`` over every batch in ``loader``.

    The model is switched to eval mode, and gradient tracking is disabled by the
    decorator.

    Args:
        model: classifier expected to return logits of shape (batch, num_classes).
        loader: iterable of ``(images, labels)`` batches.
        device: device each batch is moved to (should match the model's device).

    Returns:
        ``(mean_loss, accuracy)``, where accuracy is a fraction in [0, 1].

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    # Summing per-sample losses keeps the mean exact when the last batch is smaller.
    criterion = nn.CrossEntropyLoss(reduction="sum")
    total, correct, loss_sum = 0, 0, 0.0
    for images, labels in loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        outputs = model(images)
        loss_sum += criterion(outputs, labels).item()
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
    if total == 0:
        raise ValueError("evaluate() received an empty loader; cannot compute metrics")
    return loss_sum / total, correct / total

def train_one_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader`` and return running training metrics.

    Each batch's loss and predictions come from the forward pass *before* that
    batch's optimizer step. The returned numbers therefore mix weights from
    across the epoch; they are not a re-evaluation with the final weights.

    Args:
        model: classifier expected to return logits of shape (batch, num_classes).
        loader: iterable of ``(images, labels)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        device: device each batch is moved to (should match the model's device).

    Returns:
        ``(mean_loss, accuracy)``, where accuracy is a fraction in [0, 1].

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    criterion = nn.CrossEntropyLoss()
    loss_sum, correct, total = 0.0, 0, 0
    for images, labels in loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch = labels.size(0)
        # criterion returns the batch mean; re-weight by batch size for an exact epoch mean.
        loss_sum += loss.item() * batch
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += batch

    if total == 0:
        raise ValueError("train_one_epoch() received an empty loader; cannot compute metrics")
    return loss_sum / total, correct / total

In [None]:
# -----------------------
# Model: Inception for CIFAR-10
# -----------------------
class ConvBNReLU(nn.Module):
    """Conv2d (no bias) -> BatchNorm2d -> in-place ReLU.

    Args:
        in_c: input channel count.
        out_c: output channel count.
        k: convolution kernel size.
        s: convolution stride.
        p: convolution zero-padding.
    """

    def __init__(self, in_c, out_c, k=3, s=1, p=1):
        super().__init__()
        # The conv bias is omitted because BatchNorm's learnable shift makes it redundant.
        # Attribute names (conv / bn / act) define the state_dict keys.
        self.conv = nn.Conv2d(in_c, out_c, kernel_size=k, stride=s, padding=p, bias=False)
        self.bn = nn.BatchNorm2d(num_features=out_c)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        return self.act(out)

class InceptionBlock(nn.Module):
    """
    v1-style block: four parallel branches concatenated along the channel dim.
      - b1: 1x1
      - b2: 1x1 -> 3x3
      - b3: 1x1 -> 5x5 (implemented as 3x3 -> 3x3)
      - b4: 3x3 pool -> 1x1

    Every branch uses stride 1 with "same" padding, so the output keeps the
    input's spatial size and has ``out_1x1 + out_3x3 + out_5x5 + pool_proj``
    channels.

    Args:
        in_c: input channel count.
        out_1x1: output channels of the 1x1 branch.
        red_3x3, out_3x3: reduction / output channels of the 3x3 branch.
        red_5x5, out_5x5: reduction / output channels of the stacked-3x3 branch.
        pool_proj: output channels of the 1x1 projection after pooling.
        pool_type: "avg" or "max", the pooling used in the b4 branch.

    Raises:
        ValueError: if ``pool_type`` is neither "avg" nor "max".
    """
    _POOLS = {"avg": nn.AvgPool2d, "max": nn.MaxPool2d}

    def __init__(self, in_c, out_1x1, red_3x3, out_3x3, red_5x5, out_5x5, pool_proj,
                 pool_type="avg"):
        super().__init__()
        # Validate up front: previously any unknown string (e.g. a typo) silently
        # fell through to max-pooling.
        if pool_type not in self._POOLS:
            raise ValueError(f"pool_type must be 'avg' or 'max', got {pool_type!r}")

        # 1x1
        self.b1 = ConvBNReLU(in_c, out_1x1, k=1, p=0)

        # 1x1 -> 3x3
        self.b2_reduce = ConvBNReLU(in_c, red_3x3, k=1, p=0)
        self.b2_conv   = ConvBNReLU(red_3x3, out_3x3, k=3, p=1)

        # 1x1 -> (3x3 -> 3x3) ≈ 5x5 receptive field with fewer weights
        self.b3_reduce = ConvBNReLU(in_c, red_5x5, k=1, p=0)
        self.b3_conv1  = ConvBNReLU(red_5x5, out_5x5, k=3, p=1)
        self.b3_conv2  = ConvBNReLU(out_5x5, out_5x5, k=3, p=1)

        # pool -> 1x1 (stride 1, padding 1 keeps spatial size)
        self.pool = self._POOLS[pool_type](kernel_size=3, stride=1, padding=1)
        self.b4_proj = ConvBNReLU(in_c, pool_proj, k=1, p=0)

    def forward(self, x):
        b1 = self.b1(x)

        r2 = self.b2_reduce(x)
        b2 = self.b2_conv(r2)

        r3 = self.b3_reduce(x)
        b3 = self.b3_conv2(self.b3_conv1(r3))

        b4 = self.b4_proj(self.pool(x))

        return torch.cat([b1, b2, b3, b4], dim=1)

class InceptionNetCIFAR(nn.Module):
    """
    A slim GoogLeNet-v1 style network sized for 32x32 CIFAR-10.

    Layout: stem (32->16) -> inc1 -> maxpool (16->8) -> inc2 -> inc3 ->
    global average pool -> dropout -> linear classifier.

    Args:
        num_classes: number of output logits.
        dropout: drop probability applied to the pooled 320-d feature vector
            right before the classifier, following GoogLeNet's head.
    """
    def __init__(self, num_classes=10, dropout=0.2):
        super().__init__()
        # Stem (keep strides=1 for 32x32)
        self.stem = nn.Sequential(
            ConvBNReLU(3, 64, k=3, s=1, p=1),
            ConvBNReLU(64, 64, k=3, s=1, p=1),
            nn.MaxPool2d(3, stride=2, padding=1),  # 32->16
        )

        # Inception stack (channels tuned for small model)
        # After stem: 64ch
        self.inc1 = InceptionBlock(
            in_c=64,
            out_1x1=32,
            red_3x3=32, out_3x3=48,
            red_5x5=8,  out_5x5=16,
            pool_proj=16,
            pool_type="avg"
        )  # -> 32+48+16+16 = 112 ch

        self.down1 = nn.MaxPool2d(3, stride=2, padding=1)  # 16->8

        self.inc2 = InceptionBlock(
            in_c=112,
            out_1x1=64,
            red_3x3=48, out_3x3=64,
            red_5x5=16, out_5x5=32,
            pool_proj=32,
            pool_type="avg"
        )  # -> 64+64+32+32 = 192 ch

        self.inc3 = InceptionBlock(
            in_c=192,
            out_1x1=96,
            red_3x3=64, out_3x3=96,
            red_5x5=24, out_5x5=64,
            pool_proj=64,
            pool_type="avg"
        )  # -> 96+96+64+64 = 320 ch

        # Pool first, then drop. Element-wise dropout on the 8x8 map would be
        # largely averaged out by the global pool, weakening regularisation.
        # The head holds no parameters, so state_dict keys are unaffected.
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # 8x8 -> 1x1
            nn.Dropout(p=dropout),
        )
        self.fc = nn.Linear(320, num_classes)  # 320 = inc3 output channels

        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal convs, unit/zero BatchNorm, small-normal Linear weights."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1.)
                nn.init.constant_(m.bias, 0.)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0.)

    def forward(self, x):
        x = self.stem(x)
        x = self.inc1(x)
        x = self.down1(x)
        x = self.inc2(x)
        x = self.inc3(x)
        x = self.head(x)           # (B, 320, 1, 1)
        x = torch.flatten(x, 1)    # (B, 320)
        return self.fc(x)


class AlexNetCIFAR(nn.Module):
    """
    AlexNet adapted for 32x32 inputs:
    - Use 3x3 convs (stride 1) instead of 11x11/5x5
    - Slightly reduced channels to fit CIFAR-10 scale
    - Replace the FC classifier with a pointwise (1x1) convolution followed by
      global average pooling

    Args:
        num_classes: number of output logits.
        dropout: drop probability applied to the final 4x4 feature map before
            the pointwise classifier (active only in train mode).
    """
    def __init__(self, num_classes: int = 10, dropout: float = 0.5):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),  # 32x32 -> 32x32
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),                  # 32 -> 16

            nn.Conv2d(64, 192, kernel_size=3, padding=1),           # 16 -> 16
            nn.BatchNorm2d(192),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),                  # 16 -> 8

            nn.Conv2d(192, 384, kernel_size=3, padding=1),          # 8 -> 8
            nn.BatchNorm2d(384),
            nn.ReLU(inplace=True),

            nn.Conv2d(384, 256, kernel_size=3, padding=1),          # 8 -> 8
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            nn.Conv2d(256, 256, kernel_size=3, padding=1),          # 8 -> 8
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),                  # 8 -> 4
        )

        # The ``dropout`` argument was previously accepted but never used. It
        # is kept as its own module, not inside ``head``, so head's state_dict
        # keys ("head.0.*") and the parameter count stay unchanged.
        self.dropout = nn.Dropout(p=dropout)

        self.head = nn.Sequential(
            nn.Conv2d(256, num_classes, kernel_size=1),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.dropout(x)
        return self.head(x)

In [None]:
def count_parameters(module: nn.Module) -> int:
    """Return the total number of scalar parameters (trainable or not) in ``module``."""
    return sum(p.numel() for p in module.parameters())


# One loop instead of a copy-pasted count/print pair per architecture.
for arch_name, arch_cls in [("AlexNet", AlexNetCIFAR), ("InceptionNet", InceptionNetCIFAR)]:
    print(f"Total number of parameters in {arch_name}: {count_parameters(arch_cls())}")

Total number of parameters in AlexNet: 2256458
Total number of parameters in InceptionNet: 279818


In [None]:
# Training configuration
epochs, batch_size = 10, 128     # passes over the training set; samples per batch
lr, weight_decay = 5e-4, 5e-4    # Adam learning rate; L2 penalty passed to Adam
num_workers = 4                  # DataLoader worker processes
seed = 42                        # value handed to set_seed()

In [None]:
# Seed before anything stochastic (data shuffling, weight init) happens.
set_seed(seed)
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
train_loader, test_loader = get_cifar10_loaders(
    batch_size=batch_size,
    num_workers=num_workers,
)

100%|██████████| 170M/170M [00:03<00:00, 48.6MB/s]


In [None]:
def train_and_evaluate(model_name='InceptionNet', num_classes=10):
    """Train the architecture named ``model_name`` and report its test accuracy.

    Reads the notebook-level config (``epochs``, ``lr``, ``weight_decay``,
    ``device``) and loaders (``train_loader``, ``test_loader``). After every
    epoch the model is evaluated on ``test_loader``. Whenever accuracy improves,
    the weights are saved to ``checkpoints/<model_name>_best.pt``.

    NOTE(review): ``test_loader`` doubles as the validation set, so
    "Best Val Acc" is selected on test data and is an optimistic estimate.
    Hold out part of the training set if an unbiased number is needed.

    Args:
        model_name: "InceptionNet" or "AlexNet" (case-insensitive). Selects the
            architecture and names the checkpoint file.
        num_classes: number of output classes for the model.

    Raises:
        ValueError: if ``model_name`` is not a known architecture.
    """
    # Previously the model was always InceptionNet regardless of ``model_name``,
    # so e.g. "AlexNet" trained Inception and saved it as AlexNet_best.pt.
    architectures = {
        "inceptionnet": InceptionNetCIFAR,
        "alexnet": AlexNetCIFAR,
    }
    try:
        model_cls = architectures[model_name.lower()]
    except KeyError:
        raise ValueError(
            f"Unknown model_name {model_name!r}; expected 'InceptionNet' or 'AlexNet'"
        ) from None

    model = model_cls(num_classes=num_classes).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    ckpt_dir = "checkpoints"
    os.makedirs(ckpt_dir, exist_ok=True)
    ckpt_path = os.path.join(ckpt_dir, f"{model_name}_best.pt")

    best_acc = 0.0
    for epoch in range(1, epochs + 1):
        train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, device)
        val_loss, val_acc = evaluate(model, test_loader, device)

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save({"model": model.state_dict(),
                        "epoch": epoch,
                        "acc": best_acc},
                       ckpt_path)

        print(f"Epoch {epoch:02d}/{epochs} | "
              f"Train Loss {train_loss:.4f} Acc {train_acc*100:.2f}% | "
              f"Val Loss {val_loss:.4f} Acc {val_acc*100:.2f}% | "
              f"Best Val Acc {best_acc*100:.2f}%")

    # Final evaluation uses the last-epoch weights, not the best checkpoint.
    test_loss, test_acc = evaluate(model, test_loader, device)
    print(f"\nFinal {model_name.upper()} Test Accuracy: {test_acc*100:.2f}% (loss {test_loss:.4f})")

In [None]:
train_and_evaluate(model_name="InceptionNet", num_classes=10)

Epoch 01/10 | Train Loss 1.4470 Acc 46.93% | Val Loss 1.1514 Acc 57.68% | Best Val Acc 57.68%
Epoch 02/10 | Train Loss 1.0420 Acc 62.73% | Val Loss 0.9709 Acc 65.36% | Best Val Acc 65.36%
Epoch 03/10 | Train Loss 0.8906 Acc 68.29% | Val Loss 0.8735 Acc 69.10% | Best Val Acc 69.10%
Epoch 04/10 | Train Loss 0.7869 Acc 72.24% | Val Loss 0.7886 Acc 72.50% | Best Val Acc 72.50%
Epoch 05/10 | Train Loss 0.7167 Acc 75.01% | Val Loss 0.7514 Acc 74.11% | Best Val Acc 74.11%
Epoch 06/10 | Train Loss 0.6651 Acc 76.85% | Val Loss 0.7387 Acc 74.87% | Best Val Acc 74.87%
Epoch 07/10 | Train Loss 0.6177 Acc 78.60% | Val Loss 0.6703 Acc 76.89% | Best Val Acc 76.89%
Epoch 08/10 | Train Loss 0.5781 Acc 79.99% | Val Loss 0.6868 Acc 76.63% | Best Val Acc 76.89%
Epoch 09/10 | Train Loss 0.5561 Acc 80.75% | Val Loss 0.5807 Acc 80.27% | Best Val Acc 80.27%
Epoch 10/10 | Train Loss 0.5316 Acc 81.56% | Val Loss 0.6179 Acc 78.99% | Best Val Acc 80.27%

Final INCEPTIONNET Test Accuracy: 78.99% (loss 0.6179)


# Conclusion

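The Inception-inspired architecture used here has far fewer parameters than networks such as AlexNet: about 280K versus about 2.26M for our CIFAR-adapted AlexNet, roughly 8× fewer. It still reaches about 79–80% test accuracy after 10 epochs. That is comparable to, if not better than, larger networks trained on the same dataset with similar training hyperparameters.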

The contrast is stark when compared with the other networks trained [here](https://github.com/ajhalthor/computer-vision-101/blob/main/pointwise_convolution/pointwise_convolutions.ipynb).