# Train CNN on MNIST

We need to train to do regression on the one-hot MNIST labels.

In [1]:
!pip install graphlearning

Collecting graphlearning
  Downloading graphlearning-1.7.4.tar.gz (93 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/93.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m [32m92.2/93.4 kB[0m [31m6.9 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m [32m92.2/93.4 kB[0m [31m6.9 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m93.4/93.4 kB[0m [31m780.1 kB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: graphlearning
  Building wheel for graphlearning (pyproject.toml) ... [?25l[?25hdone
  Created wheel for graphlearning: filename=graphlearning-1.7.4-cp312-cp312-linux_x86_64.whl size=358029 sha256=9257a65b68a06d477

In [21]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

import graphlearning as gl

# ---- Load MNIST data ----
mnist_digits, mnist_labels = gl.datasets.load("mnist")

# ---- Convert to PyTorch tensors ----
X = torch.tensor(mnist_digits, dtype=torch.float32).reshape(-1, 1, 28, 28) / 255.0
y = torch.tensor(mnist_labels, dtype=torch.long)

## Randomly shuffle
indices = torch.randperm(X.shape[0])
X = X[indices]
y = y[indices]


# ---- Split into train/test ----
train_X, test_X = X[:60000], X[60000:]
train_y, test_y = y[:60000], y[60000:]

train_ds = TensorDataset(
    train_X,
    torch.nn.functional.one_hot(train_y, num_classes=10),
  )

test_ds = TensorDataset(test_X, test_y)

train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=256)

# ---- Define CNN model ----
from models.mnist_cnn import ConvNet

# ---- Initialize model, loss, optimizer ----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load pre-trained model if available
from pathlib import Path
model_path = Path("models/mnist_cnn.pth")
if model_path.exists():
    model = ConvNet().to(device)
    model.load_state_dict(torch.load(model_path))
    print("Loaded pre-trained model.")
else:
    model = ConvNet().to(device)


criterion = torch.nn.SmoothL1Loss()
# criterion = torch.nn.BCEWithLogitsLoss()


optimizer = optim.Adam(model.parameters(), lr=1e-6)
# optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)


# ---- Training loop ----
for epoch in range(1000):
    model.train()
    total_loss = 0
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        preds = model(xb)
        loss = criterion(preds, yb)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch {epoch+1}: train_loss = {total_loss / len(train_loader):.4f}")


# ---- Evaluation ----
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for xb, yb in test_loader:
        xb, yb = xb.to(device), yb.to(device)
        preds = model(xb)
        predicted = torch.argmax(preds, 1)
        total += yb.size(0)
        correct += (predicted == yb).sum().item()

print(f"Test accuracy: {correct / total:.4f}")

# ---- Save model ----
torch.save(model.state_dict(), "models/mnist_cnn.pth")

Loaded pre-trained model.
Epoch 1: train_loss = 0.0027
Epoch 2: train_loss = 0.0027
Epoch 3: train_loss = 0.0027
Epoch 4: train_loss = 0.0027
Epoch 5: train_loss = 0.0027
Epoch 6: train_loss = 0.0027
Epoch 7: train_loss = 0.0027
Epoch 8: train_loss = 0.0027
Epoch 9: train_loss = 0.0027
Epoch 10: train_loss = 0.0027
Epoch 11: train_loss = 0.0027
Epoch 12: train_loss = 0.0027
Epoch 13: train_loss = 0.0027
Epoch 14: train_loss = 0.0027
Epoch 15: train_loss = 0.0027
Epoch 16: train_loss = 0.0027
Epoch 17: train_loss = 0.0027
Epoch 18: train_loss = 0.0027
Epoch 19: train_loss = 0.0027
Epoch 20: train_loss = 0.0027
Epoch 21: train_loss = 0.0027
Epoch 22: train_loss = 0.0027
Epoch 23: train_loss = 0.0027
Epoch 24: train_loss = 0.0027
Epoch 25: train_loss = 0.0027
Epoch 26: train_loss = 0.0027
Epoch 27: train_loss = 0.0027
Epoch 28: train_loss = 0.0027
Epoch 29: train_loss = 0.0027
Epoch 30: train_loss = 0.0027
Epoch 31: train_loss = 0.0027
Epoch 32: train_loss = 0.0027
Epoch 33: train_loss = 

KeyboardInterrupt: 