# Train CNN on MNIST

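We train a CNN on the one-hot encoded MNIST labels. The loss was originally a regression loss (SmoothL1Loss); the training cell below now uses CrossEntropyLoss with the one-hot labels as targets.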

In [None]:
import torch
import torch.optim as optim
from torchvision.datasets import MNIST
from torch.utils.data import TensorDataset, DataLoader

# ---- Load MNIST data ----
# Build each split once and read both the images and the labels from it,
# instead of constructing the dataset object separately for every attribute.
_mnist_train = MNIST(root="./data", train=True, download=True)
_mnist_test = MNIST(root="./data", train=False, download=True)
mnist_digits = _mnist_train.data
mnist_labels = _mnist_train.targets
test_digits = _mnist_test.data
test_labels = _mnist_test.targets

# ---- Convert to PyTorch tensors ----
# Add a channel axis and rescale the raw pixel values by 1/255.
X = mnist_digits.float().reshape(-1, 1, 28, 28) / 255.0
y = mnist_labels.long()
X_test = test_digits.float().reshape(-1, 1, 28, 28) / 255.0
y_test = test_labels.long()

## Randomly shuffle
indices = torch.randperm(X.shape[0])
X = X[indices]
y = y[indices]


# Training targets are one-hot vectors; evaluation compares argmax predictions
# against integer class labels, so the test set keeps plain labels.
train_ds = TensorDataset(
    X,
    torch.nn.functional.one_hot(y, num_classes=10),
)
# Evaluate on the held-out test split. Previously this dataset was built from
# the training tensors, so the reported "test accuracy" was training accuracy.
test_ds = TensorDataset(X_test, y_test)

train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=256)

# ---- Initialize model, loss, optimizer ----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
import base64
import io
import json
from pathlib import Path

class TrainingInfo:
    """Training bookkeeping that is persisted to a JSON file.

    Can be used as a context manager: on entry the record is loaded from
    ``json_path`` if that file exists, and on exit it is written back
    (also when the body raised, so partial progress is kept).

    Attributes are plain Python containers:
      * ``epochs``          - number of epochs recorded so far
      * ``train_losses``    - list of per-epoch training losses
      * ``test_accuracies`` - dict mapping epoch (int) -> accuracy
      * ``leverage_scores`` - dict mapping epoch (int) -> leverage-score array
    """

    def __init__(self, json_path):
        """Create an empty record bound to ``json_path`` (nothing is read yet)."""
        self.json_path = json_path
        self.epochs = 0
        self.train_losses = []
        self.test_accuracies = {}  # epoch -> accuracy
        self.leverage_scores = {}  # epoch -> leverage scores array

    def to_json(self):
        """Return a JSON-serializable dict describing this record.

        Leverage-score arrays are serialized with ``torch.save`` and encoded
        as base85 ASCII text so they can live inside the JSON document.
        """
        leverage_scores_serialized = {}
        for k, arr in self.leverage_scores.items():
            buf = io.BytesIO()
            torch.save(arr, buf)
            leverage_scores_serialized[str(k)] = base64.b85encode(buf.getvalue()).decode('ascii')
        return {
            "epochs": self.epochs,
            "train_losses": self.train_losses,
            # JSON object keys are always strings; make that explicit here so
            # both epoch-keyed dicts are handled the same way.
            "test_accuracies": {str(k): v for k, v in self.test_accuracies.items()},
            "leverage_scores": leverage_scores_serialized
        }

    def from_json(self, data: dict):
        """Populate this record from a dict produced by ``to_json``."""
        self.epochs = data["epochs"]
        self.train_losses = data["train_losses"]
        # Restore integer epoch keys. Without this, a reloaded record holds
        # string keys while new entries use int keys, and the next save can
        # emit duplicate keys for the same epoch.
        self.test_accuracies = {int(k): v for k, v in data["test_accuracies"].items()}
        leverage_scores_deserialized = {}
        for k, b85str in data["leverage_scores"].items():
            byte_data = base64.b85decode(b85str.encode('ascii'))
            buf = io.BytesIO(byte_data)
            # SECURITY: weights_only=False performs full unpickling, which can
            # execute arbitrary code. Only load info files you created yourself.
            leverage_scores_deserialized[int(k)] = torch.load(buf, weights_only=False)
        self.leverage_scores = leverage_scores_deserialized

    def save(self):
        """Write the record to ``json_path``, creating parent directories if needed."""
        Path(self.json_path).parent.mkdir(parents=True, exist_ok=True)
        with open(self.json_path, "w") as f:
            json.dump(self.to_json(), f)

    def load(self):
        """Read the record from ``json_path``; raises if the file does not exist."""
        with open(self.json_path, "r") as f:
            data = json.load(f)
        self.from_json(data)

    def __enter__(self):
        """Load existing data (if any) and return the record itself."""
        if Path(self.json_path).exists():
            self.load()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Persist the record; exceptions from the body are not suppressed."""
        self.save()

In [None]:
import torch
import torch.optim as optim
from torchvision.datasets import MNIST
from torch.utils.data import TensorDataset, DataLoader

# ---- Define CNN model ----
from models.mnist_cnn import MnistConvNet

# Stack the training and test images into one tensor on `device`, add a
# channel axis, and rescale the raw pixel values by 1/255.
combined_digits = (
    torch.cat((mnist_digits, test_digits))
    .to(device)
    .float()
    .reshape(-1, 1, 28, 28)
    / 255.0
)

# Wrap the combined images in an unshuffled loader so that the rows of any
# per-sample result follow the order of `combined_digits`.
combined_dataset = TensorDataset(combined_digits)
combined_loader = DataLoader(combined_dataset, shuffle=False, batch_size=256)

def leverage_scores(model: MnistConvNet, data_loader: DataLoader):
    """Return one leverage score per sample yielded by ``data_loader``.

    Each batch is embedded with ``model.embed`` (gradients disabled) and the
    results are gathered on the CPU. The stacked embedding matrix is then
    QR-factorized, and the score of a sample is the squared norm of its row
    of Q.
    """
    with torch.no_grad():
        chunks = [model.embed(batch[0].to(device)).cpu() for batch in data_loader]

    embedding_matrix = torch.cat(chunks, dim=0)
    q_factor = torch.linalg.qr(embedding_matrix)[0]
    return (q_factor ** 2).sum(dim=1)


# Train (or continue training) one MnistConvNet per value of R, recording the
# losses, leverage scores and accuracy of each run in its TrainingInfo file.
for i in [10, 20, 50, 100, 150, 200, 250, 300, 400, 500, 1000]:
    model_path = Path(f"models/mnist_cnn_R{i}_classify.pth")
    info_path = Path(f"training_info/mnist_cnn_R{i}_classify_info.json")
    # The info file is written when the `with` block exits, so its directory
    # has to exist by then.
    info_path.parent.mkdir(parents=True, exist_ok=True)
    with TrainingInfo(info_path) as info:
        model = MnistConvNet(i).to(device)
        if model_path.exists():
            # map_location lets a checkpoint saved on a GPU machine be loaded
            # on a CPU-only machine (and vice versa).
            model.load_state_dict(torch.load(model_path, map_location=device))
            print("Loaded pre-trained model.")
        else:
            print("No pre-trained model found. Initializing new model.")
            print("Measuring initial leverage scores")
            # Measure leverage scores of the untrained model as a baseline.
            model.eval()
            info.leverage_scores[info.epochs] = leverage_scores(model, combined_loader).cpu().numpy()

        print(f"Training MnistConvNet with R={i}...")
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=1e-3)

        # ---- Training loop ----
        for epoch in range(2):
            model.train()
            total_loss = 0
            for xb, yb in train_loader:
                xb, yb = xb.to(device), yb.to(device)

                optimizer.zero_grad()
                preds = model(xb)
                # yb holds one-hot vectors; CrossEntropyLoss accepts them as
                # class-probability targets, which must be floating point.
                loss = criterion(preds, yb.float())
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            epoch_loss = total_loss / len(train_loader)
            print(f"  Epoch {epoch+1}: train_loss = {epoch_loss:.4f}")
            info.epochs += 1
            info.train_losses.append(epoch_loss)

        # ---- Evaluation ----
        model.eval()

        # Measure leverage scores after training (leverage_scores disables
        # gradient tracking itself).
        info.leverage_scores[info.epochs] = leverage_scores(model, combined_loader).cpu().numpy()

        # Accuracy over test_loader: fraction of samples whose argmax
        # prediction equals the integer label.
        correct = 0
        total = 0
        with torch.no_grad():
            for xb, yb in test_loader:
                xb, yb = xb.to(device), yb.to(device)
                preds = model(xb)
                predicted = torch.argmax(preds, 1)
                total += yb.size(0)
                correct += (predicted == yb).sum().item()

        accuracy = correct / total
        print(f"Test accuracy: {accuracy:.4f}")
        info.test_accuracies[info.epochs] = accuracy

        # ---- Save model ----
        torch.save(model.state_dict(), model_path)

In [None]:
import numpy as np
from matplotlib import pyplot as plt

# Plot, for every embedding dimension, the sorted leverage scores recorded at
# each saved epoch (one subplot per dimension).
dimensions = [10,20,50,100,150,200,250,300,400,500,1000]

# Derive the grid from the number of dimensions instead of hard-coding 3x4.
n_cols = 4
n_rows = -(-len(dimensions) // n_cols)  # ceiling division

fig, axes = plt.subplots(
    n_rows,
    n_cols,
    figsize=(5 * n_cols, 5 * n_rows),
    squeeze=False,  # always a 2-D array of axes, even for a single row
)

# Loop variables are named so they do not overwrite other notebook globals:
# the previous names `leverage_scores`, `y` and `i` replaced the
# leverage_scores() function and the label tensor, which broke re-running the
# training cell after this one.
for plot_idx, dimension in enumerate(dimensions):
    info_path = Path(f"training_info/mnist_cnn_R{dimension}_classify_info.json")
    info = TrainingInfo(info_path)
    info.load()
    ax = axes[plot_idx // n_cols, plot_idx % n_cols]

    for epoch, epoch_scores in info.leverage_scores.items():
        sorted_scores = np.sort(torch.as_tensor(epoch_scores).cpu().numpy())[::-1]
        # Ranks start at 1: position 0 cannot be drawn on a log-scaled axis,
        # so a 0-based index silently dropped the largest leverage score.
        ranks = np.arange(1, sorted_scores.shape[0] + 1)
        ax.scatter(ranks, sorted_scores, label=f"Epoch {epoch}", marker="x")
    ax.set_title(f"Dimension {dimension}")
    ax.set_xlabel("Rank (1 = largest)")
    ax.set_xscale("log")
    ax.set_ylabel("Leverage Score")
    ax.legend()
    ax.grid(True)

# Hide grid cells that have no dimension assigned to them.
for unused_ax in axes.flat[len(dimensions):]:
    unused_ax.set_visible(False)

plt.tight_layout()
plt.savefig("mnist_leverage_scores_by_dimension.png")