# Step 1: Baseline Model (AlexNet adapted for MNIST)

In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import time
from torch.profiler import profile, record_function, ProfilerActivity
import matplotlib.pyplot as plt

# Set seed and device
torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Data loading
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize MNIST to match AlexNet input size
    transforms.ToTensor()
])

train_data = datasets.MNIST(root="data", train=True, download=True, transform=transform)
test_data = datasets.MNIST(root="data", train=False, download=True, transform=transform)

# DataLoader (baseline)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=1000)

# Define model: Modified AlexNet for grayscale MNIST images
class AlexNetMNIST(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 10),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

model = AlexNetMNIST().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=1.0)
criterion = nn.CrossEntropyLoss()

# Timing functions
def train(model, loader, epochs=5):
    model.train()
    start_time = time.time()
    for epoch in range(epochs):
        for batch_idx, (data, target) in enumerate(loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
    return time.time() - start_time

def inference(model, loader):
    model.eval()
    correct = 0
    start_time = time.time()
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
    acc = correct / len(loader.dataset)
    return acc, time.time() - start_time

# Baseline Training/Inference
baseline_train_time = train(model, train_loader)
baseline_acc, baseline_inference_time = inference(model, test_loader)


Using device: cuda
Baseline Training Time: 208.57s
Baseline Inference Time: 4.3859s | Accuracy: 99.01%


# Baseline Training/Inference


In [8]:
print(f"Baseline Training Time: {baseline_train_time:.2f}s")
print(f"Baseline Inference Time: {baseline_inference_time:.4f}s | Accuracy: {baseline_acc*100:.2f}%")

Baseline Training Time: 208.57s
Baseline Inference Time: 4.3859s | Accuracy: 99.01%


# Step 2: Profiling the Baseline


In [3]:
def profile_model():
    model.train()
    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
        schedule=torch.profiler.schedule(wait=1, warmup=1, active=3)
    ) as prof:
        for step, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            with record_function("train_step"):
                output = model(data)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()
            prof.step()

    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=15))

print("\nProfiling Baseline Model...")
profile_model()


Profiling Baseline Model...
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          ProfilerStep*         2.21%     952.073us        63.48%      27.341ms      13.671ms       0.000us         0.00%      14.602ms       7.301ms        -512 b      -6.13 

## Profiling Results Summary

| Operator / Event | CUDA Time (ms) | % of Total |
|------------------|----------------|------------|
| `Optimizer.step#Adadelta.step` | ~11.790 | 30.55% |
| Multi-tensor ops (`multi_tensor_applier`) | ~5.138–3.460 | ~25% total |
| `aten::convolution_backward` | ~2.465 | 6.39% |
| `cutlass::Kernel...` (likely forward convs) | ~3.649 | 9.46% |

1. **Adadelta Optimizer dominates GPU time (30%)**
   - This is a known issue with Adadelta, it has heavy internal computation. We will do optimizer switch: `Adadelta` → `AdamW`
2. **Multi-tensor operations take ~25%**
   - These are likely from optimizer updates or gradient manipulations.
3. **Convolution backward pass (~6.4%)**
   - Expected, but not the main bottleneck here.
4. **Forward convolution kernel (~9.5%)**
   - Also expected; AlexNet is convolution-heavy.


We also will change:

Mixed Precision Training -  Reduces memory usage and speeds up computation.

Classifier Head Simplification - original large fully connected layers caused slowdowns. Replaced with smaller FC layers.


DataLoader Improvements - Speed up data loading and reduce host-device transfer bottlenecks.


# Step 3: Optimized Model


In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import time
from torch.profiler import profile, record_function, ProfilerActivity
import inspect

torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


# Resize MNIST to fit AlexNet input size
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

train_data = datasets.MNIST(root="data", train=True, download=True, transform=transform)
test_data = datasets.MNIST(root="data", train=False, download=True, transform=transform)

# Optimized DataLoader
train_loader = DataLoader(
    train_data,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    persistent_workers=True
)
test_loader = DataLoader(
    test_data,
    batch_size=1000,
    num_workers=4,
    pin_memory=True,
    persistent_workers=True
)

# Define AlexNet with adaptive classifier head
class AlexNetMNIST(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(256 * 6 * 6, 1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, 10),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

model = AlexNetMNIST().to(device)

# Use AdamW with fused kernels if available
use_fused = 'fused' in inspect.signature(optim.AdamW).parameters
optimizer = optim.AdamW(model.parameters(), lr=3e-4, fused=use_fused) if use_fused else optim.AdamW(model.parameters(), lr=3e-4)

criterion = nn.CrossEntropyLoss()

# GradScaler for mixed precision
scaler = torch.cuda.amp.GradScaler()

# Training function
def train_optimized(model, loader, epochs=5):
    model.train()
    start_time = time.time()
    for epoch in range(epochs):
        for batch_idx, (data, target) in enumerate(loader):
            data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
            optimizer.zero_grad()
            with torch.cuda.amp.autocast():
                output = model(data)
                loss = criterion(output, target)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
    return time.time() - start_time

# Inference function
def inference_optimized(model, loader):
    model.eval()
    correct = 0
    start_time = time.time()
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
            with torch.cuda.amp.autocast():
                output = model(data)
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
    acc = correct / len(loader.dataset)
    return acc, time.time() - start_time

print("\nTraining Optimized Model...")
optimized_train_time = train_optimized(model, train_loader)
optimized_acc, optimized_inference_time = inference_optimized(model, test_loader)

Using device: cuda

Training Optimized Model...


  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


# Run optimized training & inference

In [10]:
print(f"\nBaseline Training Time: {baseline_train_time:.2f}s | Accuracy: {baseline_acc*100:.2f}%")
print(f"Optimized Training Time: {optimized_train_time:.2f}s | Accuracy: {optimized_acc*100:.2f}%")
print(f"Speedup: {baseline_train_time / optimized_train_time:.2f}x")

print(f"\nBaseline Inference Time: {baseline_inference_time:.4f}s")
print(f"Optimized Inference Time: {optimized_inference_time:.4f}s")
print(f"Speedup: {baseline_inference_time / optimized_inference_time:.2f}x")


Baseline Training Time: 208.57s | Accuracy: 99.01%
Optimized Training Time: 42.91s | Accuracy: 99.44%
Speedup: 4.86x

Baseline Inference Time: 4.3859s
Optimized Inference Time: 1.4929s
Speedup: 2.94x
