# CIFAR10 Training with ResNet18 and Mixed Precision

This notebook demonstrates how to train a ResNet18 model on the CIFAR10 dataset using PyTorch. It incorporates mixed precision training to potentially accelerate the training process and reduce memory consumption, especially when using a GPU.

The notebook covers:
- Loading and preparing the CIFAR10 dataset.
- Initializing a ResNet18 model, optimizer, and loss function.
- Implementing a training loop with mixed precision using `torch.cuda.amp`.
- Monitoring the training progress and loss.

In [1]:
# ============================================================
# Mixed Precision vs Full Precision — CIFAR10 + ResNet18
# ============================================================

import torch, time, gc
from torch import nn, optim
from torchvision.models import resnet18
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from torch.amp import autocast, GradScaler

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Using device:", device)

Using device: cuda


In [2]:
# ---------------------------
# Data
# ---------------------------
train_loader = DataLoader(
    CIFAR10(root='.', train=True, download=True, transform=ToTensor()),
    batch_size=128, shuffle=True
)

# ---------------------------
# Helper function
# ---------------------------
def train_epoch(model, loader, optimizer, loss_fn, use_amp=False):
    model.train()
    total_loss = 0.0
    scaler = GradScaler('cuda') if use_amp else None

    start = time.time()
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()

        if use_amp:
            with autocast('cuda'):
                pred = model(x)
                loss = loss_fn(pred, y)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            pred = model(x)
            loss = loss_fn(pred, y)
            loss.backward()
            optimizer.step()

        total_loss += loss.item()

    elapsed = time.time() - start
    return total_loss / len(loader), elapsed

100%|██████████| 170M/170M [00:17<00:00, 9.71MB/s]


In [3]:
# ---------------------------
# Run FP32 baseline
# ---------------------------
model_fp32 = resnet18(num_classes=10).to(device)
opt_fp32 = optim.Adam(model_fp32.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

torch.cuda.reset_peak_memory_stats(device)
loss32, time32 = train_epoch(model_fp32, train_loader, opt_fp32, loss_fn, use_amp=False)
mem32 = torch.cuda.max_memory_allocated(device) / 1e6  # MB

# ---------------------------
# Run AMP
# ---------------------------
gc.collect()
torch.cuda.empty_cache()

model_amp = resnet18(num_classes=10).to(device)
opt_amp = optim.Adam(model_amp.parameters(), lr=1e-3)

torch.cuda.reset_peak_memory_stats(device)
loss16, time16 = train_epoch(model_amp, train_loader, opt_amp, loss_fn, use_amp=True)
mem16 = torch.cuda.max_memory_allocated(device) / 1e6  # MB

# ---------------------------
# Compare results
# ---------------------------
print("\n=== Comparison Summary ===")
print(f"FP32  | Time: {time32:.2f}s | Peak Memory: {mem32:.1f} MB | Loss: {loss32:.3f}")
print(f"AMP   | Time: {time16:.2f}s | Peak Memory: {mem16:.1f} MB | Loss: {loss16:.3f}")

speedup = (time32 - time16) / time32 * 100
memsave = (mem32 - mem16) / mem32 * 100
print(f"\nSpeedup: {speedup:.1f}% faster with AMP")
print(f"Memory: {memsave:.1f}% less memory used")



=== Comparison Summary ===
FP32  | Time: 11.81s | Peak Memory: 270.5 MB | Loss: 1.368
AMP   | Time: 11.62s | Peak Memory: 431.4 MB | Loss: 1.380

Speedup: 1.6% faster with AMP
Memory: -59.5% less memory used


In [4]:
# ---------------------------
# Compare results
# ---------------------------
print("\n=== Comparison Summary ===")
print(f"FP32  | Time: {time32:.2f}s | Peak Memory: {mem32:.1f} MB | Loss: {loss32:.3f}")
print(f"AMP   | Time: {time16:.2f}s | Peak Memory: {mem16:.1f} MB | Loss: {loss16:.3f}")

speedup = (time32 - time16) / time32 * 100
memsave = (mem32 - mem16) / mem32 * 100
print(f"\nSpeedup: {speedup:.1f}% faster with AMP")
print(f"Memory: {memsave:.1f}% less memory used")


=== Comparison Summary ===
FP32  | Time: 11.81s | Peak Memory: 270.5 MB | Loss: 1.368
AMP   | Time: 11.62s | Peak Memory: 431.4 MB | Loss: 1.380

Speedup: 1.6% faster with AMP
Memory: -59.5% less memory used
