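# CNN on MNIST + Clean Eval Function

Trains a small two-convolution CNN on MNIST with plain SGD for 6 epochs, reporting training loss/accuracy and test-set accuracy after every epoch.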

In [1]:
import torch
import torch.nn as nn

## 1. Define a small CNN

In [2]:
class CNNClassifier(nn.Module):
    """Small two-convolution CNN producing 10 class logits.

    Pipeline: conv(1->16) -> ReLU -> conv(16->32) -> ReLU -> maxpool -> maxpool
    -> flatten -> linear(32*7*7 -> 64) -> ReLU -> linear(64 -> 10).

    Expects (batch, 1, H, W) input; fc1's 32*7*7 input size means H and W must
    shrink to 7 after the two 2x2 poolings (i.e. 28x28 MNIST images).
    """

    def __init__(self):
        super().__init__()
        # 3x3 kernels with padding=1 leave the spatial size unchanged.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        # One 2x2 / stride-2 pooling module, reused twice in forward().
        self.pool = nn.MaxPool2d(2, 2)
        # 32 channels on a 7x7 grid after both poolings.
        self.fc1 = nn.Linear(32 * 7 * 7, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        features = torch.relu(self.conv2(torch.relu(self.conv1(x))))
        # 28x28 -> 14x14 -> 7x7
        features = self.pool(self.pool(features))
        flat = features.view(features.size(0), -1)
        hidden = torch.relu(self.fc1(flat))
        # Raw logits (no softmax); pair with nn.CrossEntropyLoss.
        return self.fc2(hidden)


## 2. Evaluation Function

In [4]:
def evaluate(model, dataloader, device):
    """Return the classification accuracy of ``model`` over ``dataloader``.

    Runs in eval mode without gradient tracking. The model's previous
    train/eval mode is restored afterwards, so a training step that follows
    is not silently run in eval mode.

    Args:
        model: classifier returning logits whose argmax over dim 1 is the
            predicted class. It is not moved here; it should already be on
            ``device``.
        dataloader: iterable of ``(x, y)`` batches with integer labels ``y``.
        device: device each batch is moved to before the forward pass.

    Returns:
        Fraction of correctly classified samples (float in [0, 1]).

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for x, y in dataloader:
                x, y = x.to(device), y.to(device)
                logits = model(x)
                preds = logits.argmax(dim=1)
                correct += (preds == y).sum().item()
                total += y.size(0)
    finally:
        # Restore the caller's mode even if the forward pass raised.
        model.train(was_training)
    if total == 0:
        raise ValueError("evaluate: dataloader yielded no samples")
    return correct / total


## 3. Train and Log

In [12]:
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

In [25]:
def train_one_epoch(model, optim, data_loader, loss_func, device=None):
    """Run one optimization pass over ``data_loader``.

    Args:
        model: network to train; switched to training mode before the loop.
        optim: optimizer over ``model``'s parameters; stepped once per batch.
        data_loader: iterable of ``(x, y)`` batches.
        loss_func: callable ``loss_func(logits, y)`` returning a scalar loss.
            It is assumed to be a per-sample mean (e.g. ``nn.CrossEntropyLoss()``),
            which is why it is re-weighted by batch size below.
        device: optional device each batch is moved to; ``None`` (default)
            leaves batches where the loader produced them.

    Returns:
        ``(epoch_acc, epoch_loss)``: accuracy and sample-weighted mean loss
        over the epoch, both measured on the logits computed *before* each
        optimizer step.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    # A previous model.eval() call (e.g. from an evaluation pass) would
    # otherwise persist into training.
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for x, y in data_loader:
        if device is not None:
            x, y = x.to(device), y.to(device)
        optim.zero_grad()
        y_pred = model(x)
        loss = loss_func(y_pred, y)
        loss.backward()
        optim.step()

        batch_size = y.size(0)
        # Undo the per-batch mean so the epoch loss is a true per-sample mean
        # even when the last batch is smaller.
        running_loss += loss.item() * batch_size
        correct += (y_pred.argmax(dim=1) == y).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError("train_one_epoch: data_loader yielded no samples")
    epoch_loss = running_loss / total
    epoch_acc = correct / total
    return epoch_acc, epoch_loss


In [26]:
# Only ToTensor is applied (no Normalize), so pixels stay in ToTensor's [0, 1] range.
transform = transforms.ToTensor()

# Train split first, then test split; both are downloaded into ./data if missing.
train_ds, test_ds = (
    datasets.MNIST(root="./data", train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Shuffle only the training data; test order is fixed for evaluation.
train_loader = DataLoader(dataset=train_ds, batch_size=128, shuffle=True)
test_loader = DataLoader(dataset=test_ds, batch_size=256, shuffle=False)

In [27]:
# Seed the global torch RNG so weight initialization is reproducible. DataLoader
# shuffling without an explicit generator also draws from this global RNG.
torch.manual_seed(0)
cnn = CNNClassifier()
# Plain SGD (no momentum / weight decay) with a fixed learning rate.
optim = torch.optim.SGD(params=cnn.parameters(), lr=0.1)

# Takes the model's raw logits and integer class labels.
loss_func = nn.CrossEntropyLoss()

In [29]:
epochs = 6

for e in range(epochs):
    # Enter training mode explicitly each epoch rather than relying on
    # whatever mode the previous evaluation pass left the model in.
    cnn.train()
    acc, loss = train_one_epoch(cnn, optim, train_loader, loss_func)
    # NOTE: "Validation Acc" is measured on the MNIST test split; there is no
    # separate validation set in this notebook.
    val_acc = evaluate(cnn, test_loader, 'cpu')
    print(f"Epoch: {e} | Loss: {loss} | Acc: {acc} | Validation Acc: {val_acc}")

Epoch: 0 | Loss: 0.10820321278572083 | Acc: 0.9668333333333333 | Validation Acc: 0.9761
Epoch: 1 | Loss: 0.07534426254431406 | Acc: 0.9766666666666667 | Validation Acc: 0.9786
Epoch: 2 | Loss: 0.06033348806599776 | Acc: 0.98155 | Validation Acc: 0.9815
Epoch: 3 | Loss: 0.04994450241004427 | Acc: 0.9846166666666667 | Validation Acc: 0.9857
Epoch: 4 | Loss: 0.04305717636048794 | Acc: 0.9868 | Validation Acc: 0.9872
Epoch: 5 | Loss: 0.036738085391620795 | Acc: 0.9885666666666667 | Validation Acc: 0.9875
