# Artificial Vision & Feature Separability — 06 · CNN on EMNIST + Representation Probing

**Goal.** Train a small CNN on EMNIST (letters/digits), then **probe internal representations** with PCA and linear probes to study feature separability across layers.

**Outputs.** Training curves, confusion matrix, example predictions, **PCA plots of layer activations**, and linear-probe accuracies (per layer).

**Data.** `torchvision.datasets.EMNIST` (byclass split). The first code cell points `SSL_CERT_FILE` at certifi's CA bundle, which fixes downloads that fail locally with SSL certificate errors.

In [None]:
# --- Reproducibility & Environment (with SSL fix) ---
# Seed every RNG the notebook touches, create the output/data folders, and
# point Python's SSL layer at certifi's CA bundle for the dataset download.
import os
import random
import numpy as np
import certifi
import torch

SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)

for _folder in ("results", "data"):
    os.makedirs(_folder, exist_ok=True)

# Some Python installs cannot locate a CA bundle on their own, which makes the
# HTTPS dataset download fail; certifi ships one we can point them at.
_ca_bundle = certifi.where()
os.environ["SSL_CERT_FILE"] = _ca_bundle
print("SSL_CERT_FILE set to:", os.environ["SSL_CERT_FILE"])
print("Seed set to", SEED)

In [None]:
# --- Imports ---
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset

import torchvision
from torchvision import transforms

from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score, log_loss, roc_auc_score
from sklearn.preprocessing import StandardScaler

## 1. Config
Adjust batch size, epochs, and subset sizes for quick experimentation or full runs.

In [None]:
from dataclasses import dataclass
from typing import Optional


@dataclass
class Cfg:
    """Experiment hyper-parameters.

    Declared as a dataclass so every field is also an *instance* attribute:
    ``vars(cfg)`` below then displays the full configuration (with plain
    class attributes it evaluated to ``{}``). Defaults remain readable as
    class attributes (``Cfg.epochs``) and can be overridden per run, e.g.
    ``Cfg(epochs=10)``.
    """

    split: str = "byclass"                # EMNIST split: 'byclass' has 62 classes (10 digits + 52 letters)
    batch_size: int = 128
    epochs: int = 5                       # increase (e.g., 10–20) for stronger accuracy
    lr: float = 1e-3
    train_subset: Optional[int] = 20000   # set to None for full train
    test_subset: Optional[int] = 5000     # set to None for full test
    num_workers: int = 2                  # DataLoader worker processes


cfg = Cfg()
vars(cfg)

## 2. Data — EMNIST (byclass)
EMNIST images are stored transposed (rows and columns swapped), so they must be reoriented before viewing; the transform below does this. We also optionally take a subset of each split for speed.

In [None]:
# EMNIST stores each glyph transposed (rows and columns swapped), so a single
# H<->W transpose after ToTensor makes characters upright and un-mirrored.
# (The previous `.transpose(1, 2).flip(2)` additionally mirrored every image.)
# `operator.methodcaller` is used instead of a lambda: with num_workers > 0 the
# DataLoader must pickle the transform when multiprocessing uses the "spawn"
# start method (e.g. on Windows/macOS), and lambdas cannot be pickled.
# Pixel values are left in [0, 1] (no mean/std normalisation).
import operator

transform = transforms.Compose([
    transforms.ToTensor(),                                         # PIL -> float tensor in [0, 1]
    transforms.Lambda(operator.methodcaller("transpose", 1, 2)),   # swap H and W -> upright glyph
])

trainset = torchvision.datasets.EMNIST(root="./data", split=cfg.split, train=True, download=True, transform=transform)
testset  = torchvision.datasets.EMNIST(root="./data", split=cfg.split, train=False, download=True, transform=transform)

# Deterministic "first N samples" subsets keep runs comparable; None = full split.
if cfg.train_subset is not None:
    train_idx = list(range(min(cfg.train_subset, len(trainset))))
    trainset = Subset(trainset, train_idx)

if cfg.test_subset is not None:
    test_idx = list(range(min(cfg.test_subset, len(testset))))
    testset = Subset(testset, test_idx)

train_loader = DataLoader(trainset, batch_size=cfg.batch_size, shuffle=True, num_workers=cfg.num_workers)
test_loader  = DataLoader(testset,  batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers)

# Number of classes, read from the underlying EMNIST dataset (unwrapping the
# Subset). torchvision exposes EMNIST class names as single-character strings
# (e.g. 'A'), so they must not go through int() — that raised ValueError.
_base_train = trainset.dataset if isinstance(trainset, Subset) else trainset
n_classes = len(_base_train.classes) if hasattr(_base_train, "classes") else 62
print("Train batches:", len(train_loader), " Test batches:", len(test_loader), " Classes:", n_classes)

## 3. Model — Small CNN with Named Layers
We expose intermediate activations (`conv1`, `conv2`, `fc1`) for probing.

In [None]:
class SmallCNN(nn.Module):
    """Two conv blocks followed by two fully connected layers.

    The flattened size ``64 * 7 * 7`` fed to ``fc1`` implies 1x28x28 inputs
    (two 2x2 max-pools: 28 -> 14 -> 7).

    ``forward(x, return_acts=True)`` additionally returns the post-ReLU
    activations of ``conv1``, ``conv2`` and ``fc1`` (taken before pooling /
    dropout) for representation probing.
    """

    def __init__(self, num_classes=62):
        super().__init__()
        # Registration order (conv1, conv2, pool, drop, fc1, fc2) is kept so
        # seeded parameter initialisation stays the same.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)    # keeps H, W
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)   # keeps H, W
        self.pool = nn.MaxPool2d(2, 2)                             # halves H, W
        self.drop = nn.Dropout(0.25)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x, return_acts=False):
        conv1_act = F.relu(self.conv1(x))
        conv2_act = F.relu(self.conv2(self.pool(conv1_act)))
        pooled = self.drop(self.pool(conv2_act))
        flat = pooled.view(pooled.size(0), -1)
        fc1_act = F.relu(self.fc1(flat))
        logits = self.fc2(self.drop(fc1_act))
        if not return_acts:
            return logits
        return logits, {"conv1": conv1_act, "conv2": conv2_act, "fc1": fc1_act}

# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model = SmallCNN(num_classes=n_classes)
model = model.to(device)
n_params = sum(param.numel() for param in model.parameters())
n_params, device

## 4. Train
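Standard cross-entropy training with Adam. We log **loss** and **accuracy** per epoch for the training set and for the (subsetted) test set. The test set doubles as the validation curve here; there is no separate validation split.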

In [None]:
def train_one_epoch(model, loader, opt, loss_fn, device=None):
    """Run one optimisation pass over ``loader``.

    Args:
        model: network to train; switched to ``train()`` mode.
        loader: iterable of ``(imgs, labels)`` batches.
        opt: optimizer already bound to ``model``'s parameters.
        loss_fn: criterion called as ``loss_fn(logits, labels)``; assumed to
            return a batch mean (as ``nn.CrossEntropyLoss`` does by default).
        device: where each batch is moved. Defaults to the device of the
            model's first parameter (CPU for a parameter-less model), so the
            function no longer depends on a notebook-level ``device`` global.

    Returns:
        ``(mean_loss, accuracy)`` over every sample seen. Both are measured on
        logits computed *before* each optimizer step.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    model.train()
    total, correct, total_loss = 0, 0, 0.0
    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device)
        opt.zero_grad()
        logits = model(imgs)
        loss = loss_fn(logits, labels)
        loss.backward()
        opt.step()
        batch_size = labels.size(0)
        total += batch_size
        # loss is a batch mean; re-weight by batch size for an exact epoch mean
        total_loss += loss.item() * batch_size
        correct += (logits.argmax(1) == labels).sum().item()
    if total == 0:
        raise ValueError("train_one_epoch: loader yielded no samples")
    return total_loss / total, correct / total

@torch.no_grad()
def evaluate(model, loader, loss_fn, device=None):
    """Score ``model`` on ``loader`` in eval mode without gradient tracking.

    Args:
        model: network to evaluate; switched to ``eval()`` mode (dropout off).
        loader: iterable of ``(imgs, labels)`` batches.
        loss_fn: criterion called as ``loss_fn(logits, labels)``; assumed to
            return a batch mean.
        device: where each batch is moved. Defaults to the device of the
            model's first parameter (CPU for a parameter-less model).

    Returns:
        ``(mean_loss, accuracy)`` over every sample in ``loader``.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    model.eval()
    total, correct, total_loss = 0, 0, 0.0
    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device)
        logits = model(imgs)
        loss = loss_fn(logits, labels)
        batch_size = labels.size(0)
        total += batch_size
        # loss is a batch mean; re-weight by batch size for an exact overall mean
        total_loss += loss.item() * batch_size
        correct += (logits.argmax(1) == labels).sum().item()
    if total == 0:
        raise ValueError("evaluate: loader yielded no samples")
    return total_loss / total, correct / total

import time

# Adam + cross-entropy. Note: the `val_*` metrics are computed on the
# (subsetted) test split — there is no separate validation split.
loss_fn = nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters(), lr=cfg.lr)

hist = {key: [] for key in ("train_loss", "train_acc", "val_loss", "val_acc")}
for ep in range(cfg.epochs):
    started = time.time()
    train_loss, train_acc = train_one_epoch(model, train_loader, opt, loss_fn)
    val_loss, val_acc = evaluate(model, test_loader, loss_fn)
    # hist keys iterate in insertion order, matching the tuple below
    for key, value in zip(hist, (train_loss, train_acc, val_loss, val_acc)):
        hist[key].append(value)
    elapsed = time.time() - started
    print(f"Epoch {ep+1:02d}/{cfg.epochs}  train_loss={train_loss:.4f}  train_acc={train_acc:.4f}  val_loss={val_loss:.4f}  val_acc={val_acc:.4f}  ({elapsed:.1f}s)")

## 5. Curves

In [None]:
import matplotlib.pyplot as plt

# Plot against 1-based epoch numbers so the x-axis matches the "Epoch NN/..."
# training log (plt.plot(values) alone would start the axis at 0).
epochs_axis = range(1, len(hist["train_loss"]) + 1)

# (title, y-label, history keys, output path) for each figure
for title, ylabel, keys, path in (
    ("Loss", "loss", ("train_loss", "val_loss"), "results/06_loss_curves.png"),
    ("Accuracy", "acc", ("train_acc", "val_acc"), "results/06_acc_curves.png"),
):
    plt.figure()
    for key in keys:
        plt.plot(epochs_axis, hist[key], marker="o", label=key)
    plt.legend(); plt.title(title); plt.xlabel("epoch"); plt.ylabel(ylabel)
    plt.tight_layout(); plt.savefig(path, dpi=150); plt.show()

## 6. Confusion Matrix & Report

In [None]:
@torch.no_grad()
def collect_preds(model, loader, device=None):
    """Run ``model`` over ``loader`` in eval mode and gather raw outputs.

    Args:
        model: network to run; switched to ``eval()`` mode.
        loader: iterable of ``(imgs, labels)`` batches.
        device: where images are moved. Defaults to the device of the model's
            first parameter (CPU for a parameter-less model).

    Returns:
        ``(logits, labels)`` as NumPy arrays, each concatenated along the batch
        dimension in loader order. Labels are never moved off the CPU.

    Raises:
        ValueError: if ``loader`` yields no batches (``torch.cat`` would
            otherwise fail with a less descriptive error).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    model.eval()
    all_logits, all_labels = [], []
    for imgs, labels in loader:
        all_logits.append(model(imgs.to(device)).cpu())
        all_labels.append(labels)
    if not all_logits:
        raise ValueError("collect_preds: loader yielded no batches")
    return torch.cat(all_logits, 0).numpy(), torch.cat(all_labels, 0).numpy()

logits_test, y_test = collect_preds(model, test_loader)
y_pred = logits_test.argmax(1)

# Fix the label set to all n_classes so the matrix is always
# n_classes x n_classes. A test subset can miss some classes entirely; without
# `labels=` those rows/columns would be dropped and later indices would shift.
all_labels = np.arange(n_classes)
cm = confusion_matrix(y_test, y_pred, labels=all_labels)
acc = accuracy_score(y_test, y_pred)
print(f"Test Accuracy: {acc:.4f}")
# zero_division=0: classes with no support or no predictions report 0 instead
# of emitting an UndefinedMetricWarning for each of them.
print(classification_report(y_test, y_pred, labels=all_labels, zero_division=0))

plt.figure()
plt.imshow(cm, aspect="auto", interpolation="nearest")
plt.title("EMNIST — Confusion Matrix (CNN)")
plt.xlabel("Pred"); plt.ylabel("True")
plt.colorbar(); plt.tight_layout(); plt.savefig("results/06_confusion_cnn_emnist.png", dpi=150); plt.show()

## 7. Representation Probing — PCA on Layer Activations
We project **conv1**, **conv2**, and **fc1** activations to 2D and visualize separability.

In [None]:
@torch.no_grad()
def collect_activations(model, loader, max_batches=50, device=None):
    """Gather per-layer activations and labels for representation probing.

    Calls ``model(imgs, return_acts=True)`` in eval mode and expects the
    returned dict to contain ``"conv1"``, ``"conv2"`` and ``"fc1"``. Every
    activation is flattened to one row per sample.

    Args:
        model: network exposing ``forward(x, return_acts=True)``.
        loader: iterable of ``(imgs, labels)`` batches; labels must be on CPU.
        max_batches: stop after this many batches (at least one batch is always
            taken); ``None`` consumes the whole loader.
        device: where images are moved. Defaults to the device of the model's
            first parameter (CPU for a parameter-less model).

    Returns:
        ``(A, Y)``: ``A`` maps each layer name to a 2-D array (samples x
        features); ``Y`` is the concatenated label array.

    Raises:
        ValueError: if ``loader`` yields no batches.

    Note:
        Flattened conv activations are wide (for SmallCNN on 28x28 inputs,
        conv1 alone is 32*28*28 = 25,088 features per sample), so memory grows
        quickly with ``max_batches``.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    model.eval()
    A = {"conv1": [], "conv2": [], "fc1": []}
    Y = []
    for n_seen, (imgs, labels) in enumerate(loader, start=1):
        _, acts = model(imgs.to(device), return_acts=True)
        batch = imgs.size(0)
        for name in A:
            A[name].append(acts[name].cpu().numpy().reshape(batch, -1))
        Y.append(labels.numpy())
        if max_batches is not None and n_seen >= max_batches:
            break
    if not Y:
        raise ValueError("collect_activations: loader yielded no batches")
    A = {name: np.concatenate(chunks, axis=0) for name, chunks in A.items()}
    return A, np.concatenate(Y, axis=0)

acts, y_small = collect_activations(model, test_loader, max_batches=40)  # limit for speed/memory
for layer in ["conv1", "conv2", "fc1"]:
    X = acts[layer]
    # Standardise features so high-variance units don't dominate the PCs.
    Xz = StandardScaler().fit_transform(X)
    p2 = PCA(n_components=2, random_state=SEED)
    Z = p2.fit_transform(Xz)
    ev = p2.explained_variance_ratio_  # share of variance the 2-D view retains
    plt.figure()
    # Randomly subsample up to 4000 *points* (all classes kept) to reduce overplotting.
    idx = np.random.choice(len(Z), size=min(4000, len(Z)), replace=False)
    sc = plt.scatter(Z[idx, 0], Z[idx, 1], c=y_small[idx], s=6)
    plt.colorbar(sc, label="class index")
    plt.title(f"PCA — {layer} activations")
    plt.xlabel(f"PC1 ({ev[0]:.1%} var)"); plt.ylabel(f"PC2 ({ev[1]:.1%} var)")
    plt.tight_layout(); plt.savefig(f"results/06_pca_{layer}.png", dpi=150); plt.show()

## 8. Linear Probes (per layer)
Fit a **logistic regression** on frozen layer activations to quantify separability.

In [None]:
def linear_probe_acc(X_train, y_train, X_test, y_test, max_iter=200):
    """Accuracy of a logistic-regression probe trained on frozen features.

    Features are standardised with statistics fit on the training split only
    (the test split is only ``transform``-ed, so nothing leaks), then a
    ``LogisticRegression`` with scikit-learn's defaults is fit.

    ``multi_class="multinomial"`` is no longer passed. That argument is
    deprecated since scikit-learn 1.5, and since 0.22 the default lbfgs solver
    already fits a multinomial model when there are more than two classes.

    Args:
        X_train, y_train: probe training features (samples x features) and labels.
        X_test, y_test: held-out features and labels used for scoring.
        max_iter: lbfgs iteration cap. Wide conv features may need more than
            200 iterations to converge (scikit-learn then warns with
            ``ConvergenceWarning``).

    Returns:
        Test accuracy as a float in [0, 1].
    """
    scaler = StandardScaler()
    Xz_tr = scaler.fit_transform(X_train)
    Xz_te = scaler.transform(X_test)
    clf = LogisticRegression(max_iter=max_iter)
    clf.fit(Xz_tr, y_train)
    return accuracy_score(y_test, clf.predict(Xz_te))

# Collect activations for train and test (smaller subsets to keep runtime reasonable)
# Frozen activations for the probes: a slice of the (shuffled) train loader to
# fit on and a slice of the test loader to score on — kept small for runtime.
acts_tr, y_tr_small = collect_activations(model, train_loader, max_batches=80)
acts_te, y_te_small = collect_activations(model, test_loader,  max_batches=40)

probe_scores = {}
for layer in ("conv1", "conv2", "fc1"):
    probe_scores[layer] = linear_probe_acc(acts_tr[layer], y_tr_small, acts_te[layer], y_te_small)
    print(f"Linear probe accuracy — {layer}: {probe_scores[layer]:.4f}")

# Persist a tiny per-layer summary next to the figures.
import csv
with open("results/06_linear_probe_summary.csv", "w", newline="") as f:
    w = csv.writer(f)
    w.writerow(["layer","probe_accuracy"])
    w.writerows(probe_scores.items())

## 9. Sample Predictions

In [None]:
@torch.no_grad()
def sample_preds(model, loader, n=16, device=None):
    """Return the first ``n`` samples of ``loader`` with true and predicted labels.

    Args:
        model: classifier; switched to ``eval()`` mode.
        loader: iterable of ``(imgs, labels)`` batches; iteration stops as soon
            as at least ``n`` samples have been seen.
        n: number of samples to return (fewer if the loader is shorter).
        device: where images are moved for the forward pass. Defaults to the
            device of the model's first parameter (CPU for a parameter-less model).

    Returns:
        ``(imgs, y, yhat)``: the CPU image tensor plus NumPy arrays of true
        labels and argmax predictions.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    model.eval()
    imgs_all, y_all, yhat_all = [], [], []
    seen = 0  # running count instead of re-summing every batch
    for imgs, labels in loader:
        yhat_all.append(model(imgs.to(device)).argmax(1).cpu())
        imgs_all.append(imgs)
        y_all.append(labels)
        seen += imgs.size(0)
        if seen >= n:
            break
    if not imgs_all:
        raise ValueError("sample_preds: loader yielded no batches")
    imgs = torch.cat(imgs_all, 0)[:n]
    y = torch.cat(y_all, 0)[:n]
    yhat = torch.cat(yhat_all, 0)[:n]
    return imgs, y.numpy(), yhat.numpy()

import math

imgs, y_true, y_hat = sample_preds(model, test_loader, n=16)

# Show glyphs rather than raw indices: the EMNIST dataset exposes the
# index -> character mapping as `.classes` (unwrap the Subset first).
_base_test = testset.dataset if isinstance(testset, Subset) else testset
class_names = list(getattr(_base_test, "classes", []))


def _label_name(k):
    """Class name for index ``k``; falls back to the index itself if unknown."""
    k = int(k)
    return class_names[k] if 0 <= k < len(class_names) else str(k)


cols = 8
rows = math.ceil(len(imgs) / cols)
plt.figure(figsize=(cols*1.2, rows*1.2))
for i in range(len(imgs)):
    plt.subplot(rows, cols, i+1)
    plt.imshow(imgs[i, 0], cmap="gray")
    plt.axis("off")
    # Red title marks a misclassification.
    color = "black" if y_true[i] == y_hat[i] else "red"
    plt.title(f"{_label_name(y_true[i])}→{_label_name(y_hat[i])}", fontsize=8, color=color)
plt.suptitle("Sample predictions (true→pred)")
plt.tight_layout(); plt.savefig("results/06_sample_preds.png", dpi=150); plt.show()

## 10. Takeaways
- As depth increases (**conv1 → conv2 → fc1**), **PCA projections** typically show **cleaner class separation**.
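- **Linear probes** quantify separability: later layers typically yield higher probe accuracy. Very wide early-layer features (conv1 has 32×28×28 = 25,088 dimensions) can also be quite linearly separable, so read the per-layer numbers as a comparison rather than a guarantee.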
- This sets up comparisons to **perceptron/logistic baselines** and justifies moving to deeper CNNs.