# CNN Classifier Training

Purpose: Train classifier on real datasets only.

Includes: MNIST, EMNIST, FashionMNIST  
Excludes: generated data and VAE training


In [1]:
import sys
from pathlib import Path
import torch
import torch.nn as nn
from torch.optim import Adam
from torchvision import models

current = Path().resolve()
while not (current / "src").exists():
    current = current.parent

sys.path.append(str(current))
print("Project root:", current)


  warn(


Project root: /workspace


In [2]:
from src.datasets.grayscale_datasets import get_grayscale_loader

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

batch_size = 64
epochs = 8        # CPU-safe, enough for validation
lr = 1e-3


Using device: cpu


In [4]:
model = models.resnet18(weights=None)

# change first conv to 1-channel
model.conv1 = nn.Conv2d(
    1, 64,
    kernel_size=7,
    stride=2,
    padding=3,
    bias=False
)

# number of classes depends on dataset
model = model.to(device)


In [5]:
def train_classifier(model, loader, num_classes, epochs):
    model.fc = nn.Linear(model.fc.in_features, num_classes).to(device)
    optimizer = Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    model.train()

    for epoch in range(epochs):
        total_loss = 0
        correct = 0
        total = 0

        for x, y in loader:
            x = x.to(device)
            y = y.to(device)

            optimizer.zero_grad()
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            preds = logits.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)

        acc = 100 * correct / total
        print(
            f"Epoch [{epoch+1}/{epochs}] | "
            f"Loss: {total_loss:.2f} | "
            f"Acc: {acc:.2f}%"
        )


In [6]:
datasets = {
    "mnist": 10,
    "fashion": 10,
    "emnist": 26   # letters
}

for ds, num_classes in datasets.items():
    print(f"\n=== TRAINING CNN ON {ds.upper()} ===")

    loader = get_grayscale_loader(
        dataset_name=ds,
        root=current / "data" / "raw",
        batch_size=batch_size
    )

    model = models.resnet18(weights=None)
    model.conv1 = nn.Conv2d(
        1, 64,
        kernel_size=7,
        stride=2,
        padding=3,
        bias=False
    )
    model = model.to(device)

    train_classifier(model, loader, num_classes, epochs)

    ckpt_path = current / "checkpoints" / "grayscale" / f"resnet18_{ds}.pt"
    torch.save(model.state_dict(), ckpt_path)
    print(f"Saved CNN → {ckpt_path}")



=== TRAINING CNN ON MNIST ===
Epoch [1/8] | Loss: 127.06 | Acc: 95.85%
Epoch [2/8] | Loss: 55.63 | Acc: 98.25%
Epoch [3/8] | Loss: 41.51 | Acc: 98.70%
Epoch [4/8] | Loss: 36.44 | Acc: 98.83%
Epoch [5/8] | Loss: 32.20 | Acc: 99.00%
Epoch [6/8] | Loss: 26.27 | Acc: 99.18%
Epoch [7/8] | Loss: 25.63 | Acc: 99.22%
Epoch [8/8] | Loss: 21.61 | Acc: 99.26%
Saved CNN → /workspace/checkpoints/grayscale/resnet18_mnist.pt

=== TRAINING CNN ON FASHION ===
Epoch [1/8] | Loss: 402.31 | Acc: 84.47%
Epoch [2/8] | Loss: 286.54 | Acc: 88.78%
Epoch [3/8] | Loss: 251.40 | Acc: 90.04%
Epoch [4/8] | Loss: 223.17 | Acc: 91.16%
Epoch [5/8] | Loss: 203.71 | Acc: 92.02%
Epoch [6/8] | Loss: 185.84 | Acc: 92.56%
Epoch [7/8] | Loss: 174.18 | Acc: 93.05%
Epoch [8/8] | Loss: 155.68 | Acc: 93.73%
Saved CNN → /workspace/checkpoints/grayscale/resnet18_fashion.pt

=== TRAINING CNN ON EMNIST ===


IndexError: Target 26 is out of bounds.