# P3: CNN for CIFAR-10

**Objective:** Train a small CNN on CIFAR-10 and visualize training metrics.

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models
# Fetch the CIFAR-10 train/test splits bundled with Keras.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Cast the images to float32 and divide by 255 before feeding the network.
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# Two conv/pool stages followed by a dense classifier head with 10 outputs.
model = models.Sequential()
model.add(layers.Input(shape=(32, 32, 3)))
model.add(layers.Conv2D(filters=32, kernel_size=3, activation='relu'))
model.add(layers.MaxPooling2D())
model.add(layers.Conv2D(filters=64, kernel_size=3, activation='relu'))
model.add(layers.MaxPooling2D())
model.add(layers.Flatten())
model.add(layers.Dense(units=128, activation='relu'))
model.add(layers.Dense(units=10, activation='softmax'))

# The sparse loss is used because the labels are passed as loaded, without
# one-hot encoding; the final layer already applies softmax.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Hold out 10% of the training data for validation while fitting.
model.fit(
    x_train,
    y_train,
    batch_size=64,
    epochs=5,
    validation_split=0.1,
)

In [None]:
# Practical 3: CNN on CIFAR-10 (PyTorch)
import torch, torch.nn as nn, torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Preprocessing: convert each image to a tensor, then normalize every channel
# with mean 0.5 and std 0.5.
channel_mean = (0.5, 0.5, 0.5)
channel_std = (0.5, 0.5, 0.5)
tfm = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
])

# Both splits are stored under ./data (downloaded on request) and share the
# same transform.
train_ds = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=tfm,
)
test_ds = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=tfm,
)

# Only the training loader shuffles; the test loader uses a larger batch size.
train_dl = DataLoader(train_ds, shuffle=True, batch_size=128)
test_dl = DataLoader(test_ds, batch_size=256)

class SmallCNN(nn.Module):
    """Small three-stage convolutional classifier that emits 10 logits.

    Every stage is a 3x3 convolution with padding 1 followed by ReLU and a
    pooling layer.  The first two stages use 2x2 max pooling; the last one
    uses adaptive average pooling to a 1x1 map, leaving 128 features per
    sample for the final linear layer.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, pooling layer) for each stage, in order.
        stage_specs = [
            (3, 32, nn.MaxPool2d(2)),
            (32, 64, nn.MaxPool2d(2)),
            (64, 128, nn.AdaptiveAvgPool2d(1)),
        ]
        stack = []
        for in_ch, out_ch, pool in stage_specs:
            stack.append(nn.Conv2d(in_ch, out_ch, 3, padding=1))
            stack.append(nn.ReLU())
            stack.append(pool)
        self.net = nn.Sequential(*stack)
        self.fc = nn.Linear(128, 10)

    def forward(self, x):
        """Return the class logits for the image batch ``x``."""
        batch_size = x.size(0)
        pooled = self.net(x)
        # Collapse the pooled feature map to one feature vector per sample.
        flat = pooled.view(batch_size, -1)
        return self.fc(flat)

# Model, optimizer and loss used for the short training run below.
model = SmallCNN()
opt = optim.Adam(model.parameters(), lr=1e-3)
crit = nn.CrossEntropyLoss()

for epoch in range(2):
    model.train()
    # Accumulate the loss over the whole epoch so the printed value reflects
    # every batch rather than just the last (noisy) mini-batch.
    loss_sum = 0.0
    n_seen = 0
    for xb, yb in train_dl:
        logits = model(xb)
        loss = crit(logits, yb)
        opt.zero_grad()
        loss.backward()
        opt.step()
        # CrossEntropyLoss averages over the batch by default, so weight each
        # batch by its size; the final batch may be smaller than the others.
        n_batch = yb.size(0)
        loss_sum += loss.item() * n_batch
        n_seen += n_batch
    # max(..., 1) keeps the report well-defined if the loader yields nothing.
    print('epoch', epoch, 'loss', loss_sum / max(n_seen, 1))

# Measure top-1 accuracy on the test loader.
model.eval()
correct = 0
total = 0
# Gradients are not needed for evaluation.
with torch.no_grad():
    for xb, yb in test_dl:
        # The predicted class is the index of the largest logit in each row.
        pred = model(xb).argmax(dim=1)
        hits = pred == yb
        correct += hits.sum().item()
        total += yb.numel()
print('Test acc:', correct/total)