In [None]:
# train_cifar.py
import os
import random
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# -----------------------
# Reproducibility / device
# -----------------------
# One fixed seed for every RNG this script touches, so repeated runs start
# from the same random state.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# Pick the GPU when one is visible (seeding its generators as well),
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device:", device)

# -----------------------
# Hyperparams
# -----------------------
batch_size = 128  # samples per mini-batch for the train and val DataLoaders
num_epochs = 30  # number of passes of the main training loop
learning_rate = 0.01  # initial learning rate handed to the SGD optimizer
num_workers = 4  # DataLoader worker processes; set 0 on Windows or when debugging
save_dir = "./checkpoints"  # directory the best-model checkpoint is written to
os.makedirs(save_dir, exist_ok=True)  # create it up front; no error if it already exists

# -----------------------
# Data transforms + loaders
# -----------------------
# Per-channel mean / std handed to Normalize for both splits.
cifar_mean = [0.4914, 0.4822, 0.4465]
cifar_std  = [0.2470, 0.2435, 0.2616]

# Training pipeline: random 32x32 crop after 4-pixel padding plus random
# horizontal flip for augmentation, then tensor conversion and normalization.
train_transforms = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(cifar_mean, cifar_std),
])

# Evaluation pipeline: no augmentation, only tensor conversion and normalization.
val_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cifar_mean, cifar_std),
])

# download=True fetches the data into ./data when it is not already present.
train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=train_transforms)
val_dataset   = datasets.CIFAR10(root="./data", train=False, download=True, transform=val_transforms)

# Pinned (page-locked) host memory only benefits host-to-GPU copies, so request
# it only when actually running on CUDA; the CPU fallback then does not ask for
# pinned memory it cannot use.
use_pinned_memory = device.type == "cuda"

# Only the training loader shuffles; the validation order stays fixed.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=use_pinned_memory,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=use_pinned_memory,
)

# -----------------------
# Model: small CNN
# -----------------------
class SimpleCNN(nn.Module):
    """Small three-stage convolutional classifier for 3-channel images.

    Every stage is Conv(3x3, padding=1) -> BatchNorm -> ReLU followed by a
    pooling layer. The first two stages use 2x2 max pooling; the last one uses
    adaptive average pooling to 1x1, so the classifier always receives 128
    features per image.

    Args:
        num_classes: Number of output logits produced by the final linear layer.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Layers are created in the same order as a hand-written Sequential, so
        # module indices (and therefore state_dict keys) are features.0 .. features.11.
        layers = [
            *self._stage(3, 32, nn.MaxPool2d(2)),    # 16x16 for 32x32 inputs
            *self._stage(32, 64, nn.MaxPool2d(2)),   # 8x8 for 32x32 inputs
            *self._stage(64, 128, nn.AdaptiveAvgPool2d((1, 1))),  # 1x1
        ]
        self.features = nn.Sequential(*layers)
        self.classifier = nn.Linear(128, num_classes)

    @staticmethod
    def _stage(in_channels, out_channels, pool):
        """Return the layers of one Conv -> BatchNorm -> ReLU -> pool stage."""
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            pool,
        ]

    def forward(self, x):
        """Map a batch of images to a (batch, num_classes) tensor of logits."""
        pooled = self.features(x)
        # Collapse the (128, 1, 1) feature map of each sample into a flat vector.
        flat = pooled.view(pooled.size(0), -1)
        return self.classifier(flat)

model = SimpleCNN(num_classes=10)
model = model.to(device)

# -----------------------
# Loss, optimizer, scheduler
# -----------------------
criterion = nn.CrossEntropyLoss()

# SGD with momentum and L2 weight decay over all model parameters.
optimizer = optim.SGD(
    model.parameters(),
    lr=learning_rate,
    momentum=0.9,
    weight_decay=5e-4,
)

# Multiply the learning rate by 0.1 after every 15 scheduler steps.
scheduler = optim.lr_scheduler.StepLR(
    optimizer,
    step_size=15,
    gamma=0.1,
)

# -----------------------
# Train / validate loops
# -----------------------
def train_one_epoch(epoch):
    """Run one optimization pass over ``train_loader``.

    Operates on the module-level ``model``, ``criterion``, ``optimizer`` and
    ``device``.

    Args:
        epoch: Epoch number; used only to label the progress bar.

    Returns:
        ``(mean_loss, accuracy_percent)`` computed over every sample seen
        during the pass.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    progress = tqdm(train_loader, desc=f"Train {epoch}")
    for inputs, targets in progress:
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Standard step: clear old gradients, forward, backward, update.
        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        # criterion returns the batch mean, so weight it by the batch size to
        # accumulate a per-sample sum.
        loss_sum += batch_loss.item() * inputs.size(0)
        predicted = logits.max(1)[1]
        n_seen += targets.size(0)
        n_correct += (predicted == targets).sum().item()

        # Show running averages on the progress bar.
        progress.set_postfix(loss=loss_sum / n_seen, acc=100. * n_correct / n_seen)

    mean_loss = loss_sum / n_seen
    accuracy = 100. * n_correct / n_seen
    return mean_loss, accuracy

def validate(epoch):
    """Evaluate the module-level ``model`` on ``val_loader`` without gradients.

    Args:
        epoch: Accepted for symmetry with ``train_one_epoch``; not used here.

    Returns:
        ``(mean_loss, accuracy_percent)`` over the whole validation loader.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            logits = model(inputs)
            batch_loss = criterion(logits, targets)

            # Weight the batch-mean loss by batch size to get a per-sample sum.
            loss_sum += batch_loss.item() * inputs.size(0)
            predicted = logits.max(1)[1]
            n_seen += targets.size(0)
            n_correct += (predicted == targets).sum().item()

    mean_loss = loss_sum / n_seen
    accuracy = 100. * n_correct / n_seen
    return mean_loss, accuracy

# Main loop: train, validate, step the LR schedule, and keep the checkpoint
# with the highest validation accuracy seen so far.
best_val_acc = 0.0
best_ckpt_path = os.path.join(save_dir, "best_cifar_model.pth")

for epoch in range(1, num_epochs + 1):
    train_loss, train_acc = train_one_epoch(epoch)
    val_loss, val_acc = validate(epoch)
    scheduler.step()  # one scheduler step per epoch

    print(f"Epoch {epoch}: Train loss {train_loss:.4f}, acc {train_acc:.2f}% | Val loss {val_loss:.4f}, acc {val_acc:.2f}%")

    # Save best model: only a strict improvement overwrites the checkpoint.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        snapshot = dict(
            epoch=epoch,
            model_state=model.state_dict(),
            optimizer_state=optimizer.state_dict(),
            val_acc=val_acc,
        )
        torch.save(snapshot, best_ckpt_path)
        print(f"Saved new best model (val acc {val_acc:.2f}%)")

print("Training finished. Best val acc:", best_val_acc)

Device: cuda


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 170M/170M [00:07<00:00, 22.5MB/s]
Train 1: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 391/391 [00:10<00:00, 37.20it/s, acc=39.4, loss=1.66]


Epoch 1: Train loss 1.6605, acc 39.37% | Val loss 1.4456, acc 46.71%
Saved new best model (val acc 46.71%)


Train 2: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 391/391 [00:03<00:00, 126.00it/s, acc=52.1, loss=1.34]


Epoch 2: Train loss 1.3352, acc 52.06% | Val loss 1.4071, acc 50.06%
Saved new best model (val acc 50.06%)


Train 3: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 391/391 [00:03<00:00, 119.93it/s, acc=56.9, loss=1.21]


Epoch 3: Train loss 1.2050, acc 56.88% | Val loss 1.2258, acc 55.32%
Saved new best model (val acc 55.32%)


Train 4: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 391/391 [00:03<00:00, 120.01it/s, acc=59.7, loss=1.14]


Epoch 4: Train loss 1.1378, acc 59.70% | Val loss 1.1411, acc 58.96%
Saved new best model (val acc 58.96%)


Train 5: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 391/391 [00:03<00:00, 120.84it/s, acc=61.3, loss=1.08]


Epoch 5: Train loss 1.0848, acc 61.29% | Val loss 1.2236, acc 56.08%


Train 6: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 391/391 [00:03<00:00, 117.13it/s, acc=63.4, loss=1.04]


Epoch 6: Train loss 1.0374, acc 63.41% | Val loss 1.1371, acc 59.04%
Saved new best model (val acc 59.04%)


Train 7: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 391/391 [00:03<00:00, 116.66it/s, acc=64.6, loss=1]


Epoch 7: Train loss 1.0046, acc 64.64% | Val loss 1.0778, acc 62.28%
Saved new best model (val acc 62.28%)


Train 8: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 391/391 [00:03<00:00, 118.40it/s, acc=65.9, loss=0.972]


Epoch 8: Train loss 0.9717, acc 65.90% | Val loss 0.9868, acc 64.42%
Saved new best model (val acc 64.42%)


Train 9: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 391/391 [00:03<00:00, 114.80it/s, acc=66.9, loss=0.942]


Epoch 9: Train loss 0.9423, acc 66.89% | Val loss 1.0437, acc 63.41%


Train 10: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 391/391 [00:03<00:00, 114.55it/s, acc=67.6, loss=0.922]


Epoch 10: Train loss 0.9223, acc 67.64% | Val loss 1.0304, acc 64.01%


Train 11: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 391/391 [00:03<00:00, 115.16it/s, acc=68.8, loss=0.896]


Epoch 11: Train loss 0.8965, acc 68.76% | Val loss 0.9729, acc 65.36%
Saved new best model (val acc 65.36%)


Train 12: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 391/391 [00:03<00:00, 116.38it/s, acc=69.7, loss=0.869]


Epoch 12: Train loss 0.8691, acc 69.68% | Val loss 0.9408, acc 67.09%
Saved new best model (val acc 67.09%)


Train 13: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 391/391 [00:03<00:00, 117.43it/s, acc=70.8, loss=0.848]


In [None]:
# Load best checkpoint
# Derive the path from save_dir, the same way the training loop builds it when
# saving, instead of repeating the literal; changing save_dir then cannot
# leave this cell pointing at a stale location.
best_ckpt_path = os.path.join(save_dir, "best_cifar_model.pth")
checkpoint = torch.load(best_ckpt_path, map_location=device)

# Rebuild the architecture and restore the saved weights into it.
model = SimpleCNN(num_classes=10).to(device)
model.load_state_dict(checkpoint["model_state"])
model.eval()  # inference mode for the evaluation cells that follow

print("Loaded model from epoch:", checkpoint["epoch"], "with val acc:", checkpoint["val_acc"])

In [None]:
# NOTE(review): train=False selects the same CIFAR-10 split that val_dataset
# uses, and the best checkpoint was picked by accuracy on that split, so the
# number printed below is not an independent held-out estimate.
test_dataset = datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=val_transforms  # same normalization you defined earlier
)

test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

# Put the model in inference mode here as well, so the result does not depend
# on which cell last touched it (BatchNorm behaves differently in train mode).
model.eval()

correct, total = 0, 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, preds = outputs.max(1)  # index of the largest logit per sample
        total += labels.size(0)
        correct += preds.eq(labels).sum().item()

print(f"Test accuracy: {100. * correct / total:.2f}%")

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import datasets, transforms

# CIFAR-10 classes
# Human-readable class names, indexed by the integer label / predicted index.
classes = ['airplane','automobile','bird','cat','deer',
           'dog','frog','horse','ship','truck']

# Load best model
# The path is derived from save_dir (as the training loop does when saving)
# rather than repeated as a literal, so the two cannot drift apart.
best_ckpt_path = os.path.join(save_dir, "best_cifar_model.pth")
checkpoint = torch.load(best_ckpt_path, map_location=device)
model = SimpleCNN(num_classes=10).to(device)
model.load_state_dict(checkpoint['model_state'])
model.eval()

# Test dataset (train=False split, evaluation transforms only)
test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=val_transforms)

# Function to display predictions with color feedback
def show_predictions(model, dataset, num_images=12):
    """Plot the first ``num_images`` samples of ``dataset`` with their predictions.

    Each image is shown un-normalized, titled with the predicted (``P``) and
    true (``T``) class names; the title is green when they match and red
    otherwise. The grid always has 4 columns and as many rows as needed, so
    any positive ``num_images`` works (12 gives a 3x4 grid).

    Uses the module-level ``device``, ``classes``, ``cifar_mean`` and
    ``cifar_std``.

    Args:
        model: Network called on one image at a time; it is switched to eval mode.
        dataset: Indexable dataset yielding ``(image_tensor, label)`` pairs.
        num_images: Number of leading samples to display. Nothing is drawn
            when this is zero or negative.
    """
    if num_images <= 0:
        return

    model.eval()

    # Size the grid from the request instead of assuming exactly 12 images.
    n_cols = 4
    n_rows = (num_images + n_cols - 1) // n_cols  # ceiling division
    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 3 * n_rows), squeeze=False)
    axes = axes.flatten()

    std = np.array(cifar_std)
    mean = np.array(cifar_mean)

    for i in range(num_images):
        img, label = dataset[i]
        with torch.no_grad():
            # Add a batch dimension: the model is called on a single image.
            output = model(img.unsqueeze(0).to(device))
            pred = output.argmax(1).item()

        # Unnormalize image: channels-last layout, undo Normalize, clamp to [0, 1].
        img_np = img.permute(1, 2, 0).cpu().numpy() * std + mean
        img_np = np.clip(img_np, 0, 1)

        axes[i].imshow(img_np)
        axes[i].axis("off")

        # Set color: green if correct, red if wrong
        color = 'green' if pred == label else 'red'
        axes[i].set_title(f"P: {classes[pred]}\nT: {classes[label]}", color=color)

    # Hide the unused cells of the last row so no empty frames are drawn.
    for ax in axes[num_images:]:
        ax.axis("off")

    plt.tight_layout()
    plt.show()

# Show predictions for the first 12 test images (green title = correct, red = wrong)
show_predictions(model, test_dataset, num_images=12)