# 05: ResNet from Scratch for CIFAR-10

Deep learning paper implementation from scratch using PyTorch.
## 0. Setup and Imports


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import time
import copy

# Seed both PyTorch and NumPy with the same value so runs are repeatable.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


## 1. Data Loading

In [None]:
# Per-channel mean / std used to normalize CIFAR-10 images
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)

# Tensor conversion followed by normalization is the common tail of both pipelines.
to_normalized_tensor = [
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
]

# Training pipeline: random padded crop + horizontal flip, then the common tail.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    *to_normalized_tensor,
])

# Test pipeline: no augmentation, only the common tail.
test_transform = transforms.Compose(list(to_normalized_tensor))

# Datasets (downloaded into ./data when not already present)
dataset_kwargs = {'root': './data', 'download': True}
train_dataset = torchvision.datasets.CIFAR10(
    train=True, transform=train_transform, **dataset_kwargs
)
test_dataset = torchvision.datasets.CIFAR10(
    train=False, transform=test_transform, **dataset_kwargs
)

# Dataloaders: only the training loader shuffles.
BATCH_SIZE = 128
loader_kwargs = {'batch_size': BATCH_SIZE, 'num_workers': 2}
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

# Human-readable class names (presumably in label-index order; confirm
# against train_dataset.classes before relying on the ordering).
CLASSES = (
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")

## 2. BasicBlock Implementation

The BasicBlock is the fundamental building unit of ResNet-18/34:

```
Input (x)
    │
    ├──────────────────────┐
    │                      │ (identity or projection)
    ▼                      │
Conv 3x3 → BN → ReLU       │
    │                      │
    ▼                      │
Conv 3x3 → BN              │
    │                      │
    ▼                      ▼
    └────── + ◄────────────┘
            │
            ▼
          ReLU
            │
            ▼
         Output
```

When stride > 1 or channels change, we need a **projection shortcut** (1x1 conv) to match dimensions.

In [None]:
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions plus a skip connection.

    The main path is conv -> BN -> ReLU -> conv -> BN. Its result is added to
    the shortcut output and passed through a final ReLU.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by both 3x3 convolutions.
        stride: Stride of the first 3x3 convolution and, when a projection
            is needed, of the 1x1 shortcut convolution.
    """

    # Multiplier applied to ``out_channels`` by callers that need this
    # block's real output width (for BasicBlock it is the identity).
    expansion = 1

    def __init__(self, in_channels: int, out_channels: int, stride: int = 1):
        super().__init__()

        # Main path: only the first convolution may change the resolution.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Shortcut: an empty Sequential acts as the identity. A 1x1 conv + BN
        # projection is used whenever the main path changes shape.
        shape_preserved = stride == 1 and in_channels == out_channels
        if shape_preserved:
            skip_layers = []
        else:
            skip_layers = [
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            ]
        self.shortcut = nn.Sequential(*skip_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``ReLU(main_path(x) + shortcut(x))``."""
        residual = self.conv1(x)
        residual = F.relu(self.bn1(residual))
        residual = self.bn2(self.conv2(residual))
        return F.relu(residual + self.shortcut(x))


# Smoke-test BasicBlock on the three shortcut configurations
print("Testing BasicBlock...")

# Identity shortcut (same channels, stride=1); the probe input is created
# right after it so random draws happen in the same order as before.
block1 = BasicBlock(64, 64, stride=1)
x = torch.randn(2, 64, 16, 16)

# Projection shortcut (different channels) and downsampling (stride=2)
block2 = BasicBlock(64, 128, stride=1)
block3 = BasicBlock(64, 128, stride=2)

shape_checks = (
    ("Identity", block1),
    ("Channel change", block2),
    ("Downsample", block3),
)
for label, blk in shape_checks:
    y = blk(x)
    print(f"{label}: {x.shape} -> {y.shape}")

## 3. ResNet Architecture

ResNet-18 structure for CIFAR-10 (adapted for 32x32 input):
- Initial conv: 3x3, 64 channels (no 7x7 or maxpool since images are small)
- Layer 1: 2 BasicBlocks, 64 channels
- Layer 2: 2 BasicBlocks, 128 channels, stride 2
- Layer 3: 2 BasicBlocks, 256 channels, stride 2
- Layer 4: 2 BasicBlocks, 512 channels, stride 2
- Global average pool + FC

In [None]:
class ResNet(nn.Module):
    """ResNet with a 3x3 stem and four residual stages.

    Args:
        block: Block factory called as ``block(in_channels, out_channels,
            stride)``; must expose an ``expansion`` attribute.
        num_blocks: Number of blocks per stage; entries 0..3 are read.
        num_classes: Size of the final linear layer's output.
        base_channels: Width of the stem and of the first stage; later stages
            use 2x, 4x and 8x this value.
    """

    def __init__(self, block, num_blocks: list, num_classes: int = 10,
                 base_channels: int = 64):
        super().__init__()
        # Input width for the next block; advanced by _make_layer.
        self.in_channels = base_channels

        # Stem: one 3x3 convolution at stride 1 (3x3 for CIFAR)
        self.conv1 = nn.Conv2d(3, base_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(base_channels)

        # Residual stages as (width multiplier, stride of the first block).
        # They are registered as layer1..layer4, in that order.
        stage_specs = ((1, 1), (2, 2), (4, 2), (8, 2))
        for idx, (width_mult, first_stride) in enumerate(stage_specs):
            stage = self._make_layer(block, base_channels * width_mult,
                                     num_blocks[idx], stride=first_stride)
            setattr(self, f'layer{idx + 1}', stage)

        # Classifier head
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(base_channels * 8 * block.expansion, num_classes)

        self._init_weights()

    def _make_layer(self, block, out_channels: int, num_blocks: int, stride: int):
        """Build one stage; only its first block receives ``stride``."""
        per_block_strides = [stride, *([1] * (num_blocks - 1))]
        next_width = out_channels * block.expansion

        stage_blocks = []
        for block_stride in per_block_strides:
            stage_blocks.append(block(self.in_channels, out_channels, block_stride))
            self.in_channels = next_width

        return nn.Sequential(*stage_blocks)

    def _init_weights(self):
        """Initialize conv, batch-norm and linear layers in module order."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
                continue
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape ``(batch, num_classes)``."""
        feats = F.relu(self.bn1(self.conv1(x)))

        # With a 32x32 input the stages run at 32x32, 16x16, 8x8 and 4x4.
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            feats = stage(feats)

        pooled = self.avgpool(feats)
        return self.fc(torch.flatten(pooled, 1))


def ResNet18(num_classes=10):
    """Build a ResNet from BasicBlocks with two blocks in each of the four stages."""
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet(BasicBlock, blocks_per_stage, num_classes=num_classes)


# Build the model on the selected device and report its size
model = ResNet18()
model = model.to(device)

total_params = sum(param.numel() for param in model.parameters())
print(f"ResNet-18 parameters: {total_params:,}")

# Forward a small random batch to confirm the output shape
x = torch.randn(2, 3, 32, 32)
x = x.to(device)
y = model(x)
print(f"Input: {x.shape} -> Output: {y.shape}")

In [None]:
# Verify skip connections work (gradient flow)
print("\nVerifying gradient flow through skip connections...")

model = ResNet18().to(device)

# Move the input to the device first and enable gradients afterwards.
# Calling .to(device) on a tensor that already requires grad returns a
# non-leaf copy when the device differs (e.g. CUDA); backward() would then
# leave x.grad as None and the norm below would raise.
x = torch.randn(2, 3, 32, 32).to(device).requires_grad_(True)
y = model(x)
loss = y.sum()
loss.backward()

# Gradient norms at the input, the stem, and the first conv of the last stage
grad_norms = {
    "Input": x.grad.norm().item(),
    "Conv1 weight": model.conv1.weight.grad.norm().item(),
    "Layer4 block0 conv1": model.layer4[0].conv1.weight.grad.norm().item(),
}
for label, value in grad_norms.items():
    print(f"{label} gradient norm: {value:.4f}")

# Only claim success when every inspected gradient is finite and non-zero.
if all(np.isfinite(value) and value > 0 for value in grad_norms.values()):
    print("Gradients are flowing through all layers!")
else:
    print("WARNING: at least one inspected gradient norm is zero or non-finite.")

## 4. Training and Evaluation Functions

In [None]:
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimization pass over ``train_loader``.

    Args:
        model: Module to train; switched to training mode.
        train_loader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Loss callable applied to ``(outputs, targets)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(loss, accuracy)``: the batch losses weighted by batch size
        and divided by the number of samples, and the percentage of correct
        top-1 predictions.
    """
    model.train()
    loss_sum, num_correct, num_seen = 0.0, 0, 0

    for inputs, targets in train_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by its size so a smaller final batch does not
        # skew the average (assumes the criterion returns a batch mean).
        loss_sum += batch_loss.item() * inputs.size(0)
        predictions = logits.max(1)[1]
        num_seen += targets.size(0)
        num_correct += predictions.eq(targets).sum().item()

    return loss_sum / num_seen, 100. * num_correct / num_seen


def evaluate(model, test_loader, criterion, device):
    """Measure loss and top-1 accuracy over ``test_loader`` without gradients.

    Args:
        model: Module to evaluate; switched to evaluation mode.
        test_loader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Loss callable applied to ``(outputs, targets)``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(loss, accuracy)``: the batch losses weighted by batch size
        and divided by the number of samples, and the percentage of correct
        top-1 predictions.
    """
    model.eval()
    loss_sum, num_correct, num_seen = 0.0, 0, 0

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            logits = model(inputs)
            batch_loss = criterion(logits, targets)

            # Same size-weighted accumulation as in training
            # (assumes the criterion returns a batch mean).
            loss_sum += batch_loss.item() * inputs.size(0)
            predictions = logits.max(1)[1]
            num_seen += targets.size(0)
            num_correct += predictions.eq(targets).sum().item()

    return loss_sum / num_seen, 100. * num_correct / num_seen


def train_model(model, train_loader, test_loader, epochs, lr, weight_decay,
                device, scheduler_type='cosine', verbose=True):
    """Train ``model`` with SGD (momentum 0.9) and evaluate after every epoch.

    Args:
        model: Module to train in place.
        train_loader: Batches used for optimization.
        test_loader: Batches used for the per-epoch evaluation.
        epochs: Number of epochs to run.
        lr: Initial learning rate.
        weight_decay: Weight decay passed to SGD.
        device: Device the batches are moved to.
        scheduler_type: ``'cosine'`` selects cosine annealing over ``epochs``;
            any other value selects step decay (x0.1 at 1/2 and 3/4 of
            ``epochs``).
        verbose: Print progress every 10th epoch and a final summary.

    Returns:
        Tuple ``(history, best_state, best_acc, total_time)``. ``history``
        maps metric names to per-epoch lists; ``best_state`` is a deep copy of
        the state dict at the best test accuracy (``None`` if accuracy never
        rose above 0); ``total_time`` is the elapsed wall-clock seconds.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)

    if scheduler_type != 'cosine':  # step
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[epochs//2, 3*epochs//4], gamma=0.1
        )
    else:
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    metric_names = ('train_loss', 'train_acc', 'test_loss', 'test_acc', 'lr')
    history = {name: [] for name in metric_names}
    best_acc, best_state = 0.0, None

    start_time = time.time()

    for epoch_num in range(1, epochs + 1):
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        test_loss, test_acc = evaluate(model, test_loader, criterion, device)

        # Record the rate this epoch actually trained with, i.e. before the
        # scheduler advances it for the next epoch.
        current_lr = optimizer.param_groups[0]['lr']
        scheduler.step()

        epoch_metrics = (train_loss, train_acc, test_loss, test_acc, current_lr)
        for name, value in zip(metric_names, epoch_metrics):
            history[name].append(value)

        # Snapshot the weights whenever test accuracy strictly improves.
        if test_acc > best_acc:
            best_acc = test_acc
            best_state = copy.deepcopy(model.state_dict())

        if verbose and epoch_num % 10 == 0:
            print(f"Epoch {epoch_num:3d}/{epochs} | "
                  f"Train: {train_acc:.2f}% | Test: {test_acc:.2f}% | LR: {current_lr:.6f}")

    total_time = time.time() - start_time
    if verbose:
        print(f"\nTraining complete in {total_time:.1f}s. Best accuracy: {best_acc:.2f}%")

    return history, best_state, best_acc, total_time


## 5. Train ResNet-18

In [None]:
# Hyperparameters of the main ResNet-18 run
NUM_EPOCHS = 100
LEARNING_RATE = 0.1
WEIGHT_DECAY = 5e-4

print(f"Training ResNet-18 for {NUM_EPOCHS} epochs...")
print("=" * 70)

# Train a fresh model with the cosine schedule, then unpack the results.
resnet_model = ResNet18().to(device)
resnet_run = train_model(
    resnet_model,
    train_loader,
    test_loader,
    epochs=NUM_EPOCHS,
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    device=device,
    scheduler_type='cosine',
)
resnet_history, resnet_best_state, resnet_best_acc, resnet_time = resnet_run

## 6. Comparison: Simple CNN vs ResNet


In [None]:
# Simple CNN baseline (from notebook 04)
class SimpleCNN(nn.Module):
    """Plain convolutional baseline without skip connections.

    Four conv -> BN -> ReLU stages (32, 64, 128, 256 channels), with 2x2 max
    pooling after each of the first three, followed by global average pooling
    and one linear layer.

    Args:
        num_classes: Size of the final linear layer's output.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        widths = (3, 32, 64, 128, 256)
        final_stage = len(widths) - 2

        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers += [
                nn.Conv2d(c_in, c_out, 3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]
            # Every stage but the last halves the resolution.
            if stage != final_stage:
                layers.append(nn.MaxPool2d(2))
        self.features = nn.Sequential(*layers)

        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return class logits for a batch of images."""
        feature_map = self.features(x)
        return self.classifier(feature_map)

# Train the baseline CNN on a shorter schedule for a quick comparison
print("Training Simple CNN baseline (50 epochs for comparison)...")
print("=" * 70)

cnn_model = SimpleCNN().to(device)
cnn_run = train_model(
    cnn_model,
    train_loader,
    test_loader,
    epochs=50,
    lr=0.1,
    weight_decay=5e-4,
    device=device,
    scheduler_type='cosine',
)
cnn_history, cnn_best_state, cnn_best_acc, cnn_time = cnn_run

# Parameter counts of both models, used by the comparison below
cnn_params, resnet_params = (
    sum(param.numel() for param in net.parameters())
    for net in (cnn_model, resnet_model)
)

In [None]:
# Comparison table
# Read the epoch counts from the recorded histories (one entry per epoch)
# so the table always reflects the runs that actually produced the numbers.
cnn_epochs = len(cnn_history['test_acc'])
resnet_epochs = len(resnet_history['test_acc'])

rule = "=" * 70
print("\n" + rule)
print("MODEL COMPARISON: Simple CNN vs ResNet-18")
print(rule)
# Header cells use the same width and right alignment as the value cells.
print(f"{'Metric':<25} {'Simple CNN':>15} {'ResNet-18':>15}")
print("-" * 70)
print(f"{'Parameters':<25} {cnn_params:>15,} {resnet_params:>15,}")
print(f"{'Training Epochs':<25} {cnn_epochs:>15} {resnet_epochs:>15}")
print(f"{'Best Test Accuracy':<25} {cnn_best_acc:>14.2f}% {resnet_best_acc:>14.2f}%")
print(f"{'Training Time (s)':<25} {cnn_time:>15.1f} {resnet_time:>15.1f}")
print(rule)

# The difference of two percentages is measured in percentage points.
improvement = resnet_best_acc - cnn_best_acc
print(f"\nResNet-18 improvement over CNN: {improvement:+.2f} percentage points")

In [None]:
# Plot comparison: accuracy curves on the left, best accuracy bars on the right
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
curve_ax, bar_ax = axes

# Left panel: per-epoch test accuracy of both models
curves = (
    (cnn_history['test_acc'], 'Simple CNN (50 ep)'),
    (resnet_history['test_acc'], f'ResNet-18 ({NUM_EPOCHS} ep)'),
)
for series, legend_label in curves:
    curve_ax.plot(series, label=legend_label, alpha=0.8)
curve_ax.set_xlabel('Epoch')
curve_ax.set_ylabel('Test Accuracy (%)')
curve_ax.set_title('Test Accuracy Comparison')
curve_ax.legend()
curve_ax.grid(True, alpha=0.3)

# Right panel: best test accuracy per model
models = ['Simple CNN', 'ResNet-18']
accs = [cnn_best_acc, resnet_best_acc]
colors = ['#ff6b6b', '#4ecdc4']

bars = bar_ax.bar(models, accs, color=colors)
bar_ax.set_ylabel('Best Test Accuracy (%)')
bar_ax.set_title('Best Accuracy Comparison')
# Zoom the y-axis around the two values so the difference is visible.
lowest, highest = min(accs), max(accs)
bar_ax.set_ylim([lowest - 5, highest + 5])

# Write each value just above its bar.
for rect, value in zip(bars, accs):
    x_center = rect.get_x() + rect.get_width()/2
    bar_ax.text(x_center, rect.get_height() + 0.5,
                f'{value:.2f}%', ha='center', fontsize=11)

plt.tight_layout()
plt.show()

## 7. Ablation: Weight Decay 5e-4 vs 1e-4

In [None]:
def run_weight_decay_ablation(weight_decay, epochs=100):
    """Train a fresh ResNet-18 with the given weight decay and summarize the run.

    Args:
        weight_decay: Weight decay forwarded to ``train_model``.
        epochs: Number of training epochs.

    Returns:
        Dict with keys ``'weight_decay'``, ``'best_acc'``, ``'train_time'``
        and ``'history'``.
    """
    net = ResNet18().to(device)
    # The best-weights snapshot is not needed for the ablation summary.
    history, _best_state, best_acc, train_time = train_model(
        net,
        train_loader,
        test_loader,
        epochs=epochs,
        lr=0.1,
        weight_decay=weight_decay,
        device=device,
        scheduler_type='cosine',
        verbose=False,
    )
    summary = dict(
        weight_decay=weight_decay,
        best_acc=best_acc,
        train_time=train_time,
        history=history,
    )
    return summary

print("Running ablation study: Weight Decay comparison")
print("=" * 60)

# The two settings under comparison: the standard value and a lighter one
wd_standard, wd_light = 5e-4, 1e-4

print("\nTraining with weight_decay=5e-4...")
results_wd5e4 = run_weight_decay_ablation(wd_standard, epochs=100)
best_standard = results_wd5e4['best_acc']
print(f"  Best accuracy: {best_standard:.2f}%")

print("\nTraining with weight_decay=1e-4...")
results_wd1e4 = run_weight_decay_ablation(wd_light, epochs=100)
best_light = results_wd1e4['best_acc']
print(f"  Best accuracy: {best_light:.2f}%")

In [None]:
# Ablation results
rule = "=" * 60
print("\n" + rule)
print("ABLATION STUDY: Weight Decay Comparison")
print(rule)
# Header cells match the width and right alignment of the value cells below
# (the accuracy column is 15 characters plus the trailing '%').
print(f"{'Configuration':<25} {'Best Test Acc':>16} {'Train Time (s)':>15}")
print("-" * 60)
print(f"{'Weight Decay = 5e-4':<25} {results_wd5e4['best_acc']:>15.2f}% {results_wd5e4['train_time']:>15.1f}")
print(f"{'Weight Decay = 1e-4':<25} {results_wd1e4['best_acc']:>15.2f}% {results_wd1e4['train_time']:>15.1f}")
print(rule)

diff = results_wd5e4['best_acc'] - results_wd1e4['best_acc']
# Report a tie explicitly instead of crediting 1e-4 when both runs match.
if diff > 0:
    better = "5e-4"
elif diff < 0:
    better = "1e-4"
else:
    better = "neither (tie)"
print(f"\nDifference: {abs(diff):.2f}%")
print(f"Better weight decay: {better}")

In [None]:
# Plot ablation: test accuracy on the left, train-test gap on the right
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
acc_ax, gap_ax = axes

hist_5e4 = results_wd5e4['history']
hist_1e4 = results_wd1e4['history']

# Left panel: per-epoch test accuracy of both settings
acc_ax.plot(hist_5e4['test_acc'], label='WD=5e-4', alpha=0.8)
acc_ax.plot(hist_1e4['test_acc'], label='WD=1e-4', alpha=0.8)
acc_ax.set_xlabel('Epoch')
acc_ax.set_ylabel('Test Accuracy (%)')
acc_ax.set_title('Test Accuracy: Weight Decay Comparison')
acc_ax.legend()
acc_ax.grid(True, alpha=0.3)

# Right panel: per-epoch difference between train and test accuracy
gap_5e4 = [train - test for train, test in zip(hist_5e4['train_acc'], hist_5e4['test_acc'])]
gap_1e4 = [train - test for train, test in zip(hist_1e4['train_acc'], hist_1e4['test_acc'])]

gap_ax.plot(gap_5e4, label='WD=5e-4', alpha=0.8)
gap_ax.plot(gap_1e4, label='WD=1e-4', alpha=0.8)
gap_ax.set_xlabel('Epoch')
gap_ax.set_ylabel('Train - Test Accuracy (%)')
gap_ax.set_title('Generalization Gap')
gap_ax.legend()
gap_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Gap at the last recorded epoch of each run
print("Final generalization gap:")
print(f"  WD=5e-4: {gap_5e4[-1]:.2f}%")
print(f"  WD=1e-4: {gap_1e4[-1]:.2f}%")