# CIFAR-10 Image Classification (PyTorch)

This notebook provides clean, separable blocks you can copy to Colab:
- Installs
- Data setup (transforms + loaders)
- CNN model
- Training loop
- Evaluation loop
- Training run with best-checkpoint saving
- Saving and reloading weights for inference

A GPU runtime is recommended; the code falls back to the CPU when CUDA is unavailable.


In [None]:
# Installs (Colab-friendly). Skip if already installed.
pip -q install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
pip -q install torchmetrics tqdm


In [None]:
# Data setup: transforms and loaders
import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms

BATCH_SIZE = 128
# Notebook/Colab-safe: keep loading in the main process (no worker
# subprocesses) to avoid child process assertions.
NUM_WORKERS = 0

# Standard CIFAR-10 per-channel normalization, shared by both pipelines.
_normalize = transforms.Normalize(
    mean=(0.4914, 0.4822, 0.4465),
    std=(0.2470, 0.2435, 0.2616),
)

# Training pipeline: padded random crop + horizontal flip, then tensor
# conversion and normalization.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    _normalize,
])

# Test pipeline: tensor conversion and normalization only (no augmentation).
transform_test = transforms.Compose([transforms.ToTensor(), _normalize])

# DataLoader options that are identical for the train and test splits.
_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=torch.cuda.is_available(),
    persistent_workers=False,
)

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train,
)
# Only the training split is shuffled.
trainloader = DataLoader(trainset, shuffle=True, **_loader_kwargs)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test,
)
testloader = DataLoader(testset, shuffle=False, **_loader_kwargs)

# Class names as exposed by the dataset object; later cells size the model
# head from this list.
CLASSES = trainset.classes
print('Classes:', CLASSES)


In [None]:
# Model: a simple CNN suitable for CIFAR-10
import torch
import torch.nn as nn
import torch.nn.functional as F

class CifarCNN(nn.Module):
    """Three-stage convolutional classifier sized for 3-channel 32x32 inputs.

    Each stage applies two (3x3 conv -> batch-norm -> ReLU) units followed by
    a 2x2 max-pool. Channel width grows 64 -> 128 -> 256 while a 32x32 input
    shrinks to 16x16, 8x8 and finally 4x4. The flattened 256*4*4 features feed
    a dropout-regularized two-layer fully connected head.

    Args:
        num_classes: Number of output logits produced per sample.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()

        # Build the stack as one flat list so the Sequential indexes every
        # layer directly (features.0, features.1, ...).
        layers = []
        in_channels = 3
        for out_channels in (64, 128, 256):
            for _ in range(2):
                layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
                layers.append(nn.BatchNorm2d(out_channels))
                layers.append(nn.ReLU(inplace=True))
                in_channels = out_channels
            layers.append(nn.MaxPool2d(2))  # halves height and width
        self.features = nn.Sequential(*layers)

        # The first Linear expects 256 channels on a 4x4 grid, i.e. the
        # feature-map size reached from a 32x32 input after three pools.
        head = [
            nn.Dropout(0.5),
            nn.Linear(256 * 4 * 4, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch of images."""
        feature_maps = self.features(x)
        # Keep the batch dimension, collapse everything else.
        flat = feature_maps.flatten(1)
        return self.classifier(flat)

# Init model and optimizer
# Use the GPU when PyTorch can see one; otherwise stay on the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Size the output layer from the dataset's class list, then move to DEVICE.
model = CifarCNN(num_classes=len(CLASSES))
model = model.to(DEVICE)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,
    weight_decay=5e-4,
)
# NOTE(review): T_max=30 is longer than the 15-epoch loop in the training
# cell, so a single run covers only half of the cosine period - confirm
# this is intended.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)

# Last expression of the cell: total number of parameter elements.
sum(param.numel() for param in model.parameters())


In [None]:
# Training loop
from tqdm.auto import tqdm

def train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimization pass over ``loader`` and report running metrics.

    Args:
        model: Network to train; put into train mode and called as
            ``model(images)``.
        loader: Iterable yielding ``(images, targets)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss callable ``criterion(outputs, targets)``. Its value is
            weighted by the batch size, so it is assumed to be the batch mean.
        device: Device each batch is moved to before the forward pass.

    Returns:
        ``(avg_loss, accuracy)``: the sample-weighted mean loss and the
        fraction of samples whose highest-scoring class equals the target.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    progress = tqdm(loader, leave=False)
    for images, targets in progress:
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        # Standard step: clear old gradients, forward, backward, update.
        optimizer.zero_grad(set_to_none=True)
        outputs = model(images)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Accumulate per-sample totals so the averages stay correct even
        # when batches differ in size.
        batch_size = images.size(0)
        loss_sum += loss.item() * batch_size
        preds = outputs.max(1).indices
        n_correct += preds.eq(targets).sum().item()
        n_seen += batch_size

        # Show running averages on the progress bar.
        progress.set_postfix({"loss": loss_sum / n_seen, "acc": n_correct / n_seen})

    return loss_sum / n_seen, n_correct / n_seen



In [None]:
# Evaluation loop
import torch
from torchmetrics.classification import MulticlassAccuracy

@torch.no_grad()
def evaluate(model, loader, criterion, device, num_classes=10):
    """Compute the average loss and overall accuracy of ``model`` on ``loader``.

    Runs with gradient tracking disabled and the model in eval mode.

    Args:
        model: Network to evaluate; called as ``model(images)``.
        loader: Iterable yielding ``(images, targets)`` batches.
        criterion: Loss callable ``criterion(outputs, targets)``. Its value is
            weighted by the batch size, so it is assumed to be the batch mean.
        device: Device the batches and the accuracy metric are moved to.
        num_classes: Number of classes passed to the accuracy metric.

    Returns:
        ``(avg_loss, acc)`` as Python floats: the sample-weighted mean loss
        and the fraction of correctly classified samples.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    total = 0
    # average='micro' is correct / total over all samples, the same
    # definition train_one_epoch reports. The torchmetrics default ('macro')
    # averages per-class accuracies instead, which gives a different number
    # whenever the classes are not equally represented.
    acc_metric = MulticlassAccuracy(num_classes=num_classes, average='micro').to(device)

    for images, targets in loader:
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        outputs = model(images)
        loss = criterion(outputs, targets)
        # Weight by batch size so a smaller final batch is not over-counted.
        total_loss += loss.item() * images.size(0)
        total += images.size(0)
        acc_metric.update(outputs, targets)

    avg_loss = total_loss / total
    acc = acc_metric.compute().item()
    return avg_loss, acc



In [None]:
# Train for a few epochs and evaluate
# NOTE(review): the checkpoint is selected on testloader, so best_acc is
# tuned on the same split it is reported on - confirm this is acceptable or
# hold out a separate validation split.
EPOCHS = 15
best_acc = 0.0

for epoch in range(1, EPOCHS + 1):
    train_loss, train_acc = train_one_epoch(model, trainloader, optimizer, criterion, DEVICE)
    val_loss, val_acc = evaluate(model, testloader, criterion, DEVICE, num_classes=len(CLASSES))
    # Advance the learning-rate schedule once per epoch.
    scheduler.step()

    print(f"Epoch {epoch:02d}/{EPOCHS} | train_loss={train_loss:.4f} acc={train_acc:.4f} | val_loss={val_loss:.4f} acc={val_acc:.4f}")

    # Only epochs that strictly improve on the best accuracy are checkpointed.
    if not val_acc > best_acc:
        continue

    best_acc = val_acc
    checkpoint = {
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'epoch': epoch,
        'classes': CLASSES,
    }
    torch.save(checkpoint, 'cifar10_cnn.pt')
    print(f"Saved checkpoint with acc={best_acc:.4f}")

print(f"Best val acc: {best_acc:.4f}")


In [None]:
# Save only model parameters (state dict) and metadata
# NOTE(review): this writes the weights currently held by `model` (the state
# after the last epoch run), not the best checkpoint in 'cifar10_cnn.pt' -
# confirm which one inference should use.
SAVE_PATH = 'cifar10_cnn_state.pt'
META_PATH = 'cifar10_meta.pt'

weights = model.state_dict()
torch.save(weights, SAVE_PATH)

# Minimal metadata an inference service needs besides the weights: class
# names, the normalization applied to inputs, and the architecture name.
meta = dict(
    classes=CLASSES,
    normalize_mean=(0.4914, 0.4822, 0.4465),
    normalize_std=(0.2470, 0.2435, 0.2616),
    arch='CifarCNN',
)
torch.save(meta, META_PATH)

print('Saved:', SAVE_PATH, 'and', META_PATH)


In [None]:
# Load model parameters later (e.g., in an API)
import torch

# Recreate model architecture
# weights_only=True restricts unpickling to tensors and plain containers
# (dict / list / tuple / str / numbers), which is all these two files hold,
# so a tampered file cannot run arbitrary code while being loaded.
loaded_meta = torch.load('cifar10_meta.pt', map_location='cpu', weights_only=True)
# Fall back to generic "0".."9" labels when the metadata has no class list.
loaded_classes = loaded_meta.get('classes', [str(i) for i in range(10)])

inference_model = CifarCNN(num_classes=len(loaded_classes))
state = torch.load('cifar10_cnn_state.pt', map_location='cpu', weights_only=True)
# strict=True: fail loudly if the file and the architecture disagree on keys.
inference_model.load_state_dict(state, strict=True)
# Switch to eval mode before serving predictions.
inference_model.eval()

print('Loaded model with', len(loaded_classes), 'classes')
