# Training ResNet10 on CIFAR10: Unoptimized vs Optimized

### In this notebook, we first train a ResNet10 model on the CIFAR10 dataset without any performance optimizations then we implement multiple strategies to speed up training by roughly 10x. 

## Section 1: Unoptimized Training

#### In this section, we define the ResNet10 model along with a basic training loop using a small batch size and no additional optimizations.


In [8]:
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
import time
from torchvision import datasets, transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from tqdm import tqdm

# Define a BasicBlock for ResNet-10
class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += self.shortcut(identity)
        out = self.relu(out)
        return out

# Define the ResNet-10 model
class ResNet10(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(64, 64, stride=1)
        self.layer2 = self._make_layer(64, 128, stride=2)
        self.layer3 = self._make_layer(128, 256, stride=2)
        self.layer4 = self._make_layer(256, 512, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)
        self.init_weights()

    def _make_layer(self, in_channels, out_channels, stride):
        return nn.Sequential(BasicBlock(in_channels, out_channels, stride))

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                init.xavier_normal_(m.weight)
                if m.bias is not None:
                    init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

# Helper function for testing
def test_model(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    accuracy = 100 * correct / total
    return accuracy

# Data loading without optimizations
transform = transforms.ToTensor()

train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=False, num_workers=0)

test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=0)

# Initialize the model, loss function, and optimizer (unoptimized)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_unoptim = ResNet10().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model_unoptim.parameters(), lr=0.001)  # Simple SGD

# Training loop 
num_epochs = 5
print("Starting training...")
start_time = time.time()

for epoch in range(num_epochs):
    model_unoptim.train()
    epoch_start = time.time()
    train_loop = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}", leave=False)
    for i, (inputs, labels) in enumerate(train_loop):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model_unoptim(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loop.set_postfix({"Loss": loss.item()})
    
    epoch_time = time.time() - epoch_start
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}, Time: {epoch_time:.2f}s")

total_time_unoptimized = time.time() - start_time
print(f"Total unoptimized training time: {total_time_unoptimized:.2f} seconds")

initial_accuracy = test_model(model_unoptim, test_loader, device)

Files already downloaded and verified
Files already downloaded and verified
Starting training...


                                                                             

Epoch 1, Loss: 1.5516, Time: 98.64s


                                                                             

Epoch 2, Loss: 1.3854, Time: 97.92s


                                                                              

Epoch 3, Loss: 1.1749, Time: 97.66s


                                                                              

Epoch 4, Loss: 0.9992, Time: 97.62s


                                                                              

Epoch 5, Loss: 0.7867, Time: 97.67s
Total unoptimized training time: 489.52 seconds


## Section 2: Optimized Training  
### Now we implement various optimization strategies:  
### The strategies include:  

### **DataLoader Optimizations**  
- **Large Batch Size:** Set to 256 for efficient utilization of GPU resources.  
- **Multiple Workers:** Utilizes 4 workers for parallel data loading.  
- **Pinned Memory:** Enables faster data transfer from CPU to GPU by setting `pin_memory=True`.  

### **Mixed Precision Training**  
- **Automatic Mixed Precision:** Uses `torch.cuda.amp.autocast()` for faster computation on GPUs with Tensor Cores.  
- **Gradient Scaling:** Applies `GradScaler()` for smoother accumulation and to prevent underflow issues during backpropagation.  

### **Gradient Accumulation**  
- **Simulated Larger Batch Size:** Divides the loss by `accumulation_steps=4` and accumulates gradients before updating weights.  
- **Loss Normalization:** Divides the loss by accumulation steps for more stable training.  

### **Modern Optimizer**  
- **AdamW Optimizer:** Provides better generalization and faster convergence compared to traditional optimizers.  

### **Learning Rate Scheduler**  
- **OneCycleLR:** Dynamically adjusts the learning rate for stable and faster convergence during training.  

### **Weight Sharing**  
- **Shared Weights in Layer 1:** Reduces redundant computations and lowers parameter count by sharing weights across blocks in `layer1`.  

In [9]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from tqdm import tqdm
import warnings
import torch.optim as optim
from torch.amp import GradScaler, autocast
import time

class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, share_weights=False):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.use_shortcut = stride == 1 and in_channels == out_channels
        self.share_weights = share_weights

    def forward(self, x):
        identity = x if self.use_shortcut else None
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if identity is not None and not self.share_weights:
            out += identity
        out = self.relu(out)
        return out

class ResNet10(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.in_channels = 64

        # Initial layer
        self.conv1 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)

        # Layer 1 shares weights across blocks
        self.layer1 = nn.Sequential(
            BasicBlock(self.in_channels, 64, 1, share_weights=True)
        )
        self.in_channels = 64  # Ensure in_channels is updated

        # Layers with unique parameters
        self.layer2 = self._make_layer(128, 2)
        self.layer3 = self._make_layer(256, 2)
        self.layer4 = self._make_layer(512, 2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, out_channels, stride):
        layer = BasicBlock(self.in_channels, out_channels, stride)
        self.in_channels = out_channels  # Update in_channels after creating the layer
        return nn.Sequential(layer)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

def initialize_weights(model):
    """
    Apply Kaiming Initialization to Convolution and Linear layers
    """
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight, mode='fan_out')
            nn.init.constant_(m.bias, 0)

# Data loading
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=4, pin_memory=True)

test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=4, pin_memory=True)

# Suppress specific warnings
warnings.filterwarnings("ignore", category=FutureWarning, module="torch.cuda.amp")

# Initialize the model, loss function, and optimizer
optim_model = ResNet10().cuda()
initialize_weights(optim_model)  # Apply Kaiming initialization

criterion = nn.CrossEntropyLoss().cuda()
optimizer = optim.AdamW(optim_model.parameters(), lr=0.001, weight_decay=1e-4)
scaler = GradScaler()
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(train_loader), epochs=5)

# Gradient Accumulation
accumulation_steps = 4

start_time = time.time()
# Training loop
num_epochs = 5
start_time = time.time()
for epoch in range(num_epochs):
    optim_model.train()
    epoch_start = time.time()
    train_loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)
    
    for i, (inputs, labels) in enumerate(train_loop):
        inputs, labels = inputs.cuda(), labels.cuda()
        with autocast(device_type='cuda'):
            outputs = optim_model(inputs)
            loss = criterion(outputs, labels) / accumulation_steps

        scaler.scale(loss).backward()
        
        if (i + 1) % accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            optimizer.zero_grad()  # Move after accumulation

    epoch_end = time.time()
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}, Time: {epoch_end - epoch_start:.2f}s")

total_time_optimized = time.time() - start_time
print(f"Total training time: {total_time_optimized:.2f}s")


Files already downloaded and verified
Files already downloaded and verified


                                                            

Epoch 1, Loss: 0.4053, Time: 8.89s


                                                            

Epoch 2, Loss: 0.2961, Time: 9.18s


                                                            

Epoch 3, Loss: 0.2468, Time: 9.61s


                                                            

Epoch 4, Loss: 0.1558, Time: 9.74s


                                                            

Epoch 5, Loss: 0.2252, Time: 9.54s
Total training time: 46.98s




In [12]:
optimized_accuracy = test_model(optim_model, test_loader, device)

In [13]:
optimized_accuracy = test_model(optim_model, test_loader, device)

# ----------------------
# Results Comparison
# ----------------------

print("\nPerformance Comparison:")
print(f"{'Implementation':<20} | {'Accuracy (%)':<15} | {'Training Time (s)':<20}")
print(f"{'Initial':<20} | {initial_accuracy:.2f} | {total_time_unoptimized:.2f}")
print(f"{'Optimized':<20} | {optimized_accuracy:.2f} | {total_time_optimized:.2f}")


Performance Comparison:
Implementation       | Accuracy (%)    | Training Time (s)   
Initial              | 65.65 | 489.52
Optimized            | 66.67 | 46.98
