# Train CNN on MNIST

In [None]:
!pip install graphlearning

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

import graphlearning as gl

# ---- Load MNIST data ----
# Fetch the MNIST images and labels through graphlearning.
mnist_digits, mnist_labels = gl.datasets.load("mnist")

# ---- Convert to PyTorch tensors ----
# Reshape every sample to (1, 28, 28) -- a single channel -- and divide by 255.
# NOTE(review): this assumes the loader returns raw 0-255 pixel values; if it
# already returns values in [0, 1] the inputs are scaled down a second time.
# TODO confirm against the graphlearning dataset documentation.
X = torch.tensor(mnist_digits, dtype=torch.float32).reshape(-1, 1, 28, 28) / 255.0
y = torch.tensor(mnist_labels, dtype=torch.long)

## Randomly shuffle
# The permutation is drawn from a dedicated, seeded generator so that the
# train/test split below is identical on every run. This script reloads a
# previously saved checkpoint when one exists; with an unseeded shuffle the
# samples held out for testing in this run would mostly have been training
# samples in the run that produced the checkpoint, inflating test accuracy.
# A private generator leaves the global RNG (used by the DataLoader shuffle
# and weight initialisation) untouched.
split_rng = torch.Generator().manual_seed(0)
indices = torch.randperm(X.shape[0], generator=split_rng)
X = X[indices]
y = y[indices]


# ---- Split into train/test ----
# First 60000 shuffled samples are used for training, the remainder for testing.
train_X, test_X = X[:60000], X[60000:]
train_y, test_y = y[:60000], y[60000:]

train_ds = TensorDataset(train_X, train_y)
test_ds = TensorDataset(test_X, test_y)

# Training batches are reshuffled every epoch; test batches keep a fixed order.
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=256)

# ---- Define CNN model ----
from models.mnist_cnn import ConvNet

# ---- Initialize model, loss, optimizer ----
# Train on the GPU when one is visible to PyTorch, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load pre-trained model if available
from pathlib import Path
model_path = Path("models/mnist_cnn.pth")
# The network is built the same way whether or not a checkpoint exists, so
# construct it once and only load the weights conditionally.
model = ConvNet().to(device)
if model_path.exists():
    # map_location remaps the stored tensors onto the device chosen above, so
    # a checkpoint written on a GPU machine still loads on a CPU-only one.
    model.load_state_dict(torch.load(model_path, map_location=device))
    print("Loaded pre-trained model.")
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# Alternative optimizer that can be swapped in:
# optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)


# ---- Training loop ----
# Make 20 passes over train_loader; after each pass print the loss averaged
# over the number of batches in the loader.
for epoch in range(20):
    model.train()
    batch_losses = []
    for xb, yb in train_loader:
        # Move the mini-batch to the device the model lives on.
        xb = xb.to(device)
        yb = yb.to(device)

        # Reset gradients, run the forward pass, backpropagate, then update.
        optimizer.zero_grad()
        preds = model(xb)
        loss = criterion(preds, yb)
        loss.backward()
        optimizer.step()

        # Keep the scalar value of this batch's loss.
        batch_losses.append(loss.item())

    # Summed left to right from 0, exactly like a running accumulator.
    total_loss = sum(batch_losses)

    print(f"Epoch {epoch+1}: train_loss = {total_loss / len(train_loader):.4f}")

# ---- Evaluation ----
# Switch the model to evaluation mode and count how many test samples get the
# right label. Gradient tracking is disabled because no backward pass is run.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for xb, yb in test_loader:
        xb = xb.to(device)
        yb = yb.to(device)
        preds = model(xb)
        # Predicted class = index of the largest score along dimension 1.
        hits = preds.argmax(dim=1) == yb
        correct += hits.sum().item()
        total += yb.shape[0]

print(f"Test accuracy: {correct / total:.4f}")

# ---- Save model ----
# Write the weights to model_path -- the same location that was probed for a
# checkpoint before training -- instead of repeating the path literal, and
# create the directory first so a finished training run is not lost to a
# missing folder.
model_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), model_path)


Epoch 1: train_loss = 2.3018
Epoch 2: train_loss = 2.2998
Epoch 3: train_loss = 2.2876
Epoch 4: train_loss = 2.1218
Epoch 5: train_loss = 1.6122
Epoch 6: train_loss = 1.1916
Epoch 7: train_loss = 0.9527
Epoch 8: train_loss = 0.8096
Epoch 9: train_loss = 0.7089
Epoch 10: train_loss = 0.6380
Epoch 11: train_loss = 0.5828
Epoch 12: train_loss = 0.5435
Epoch 13: train_loss = 0.5090
Epoch 14: train_loss = 0.4863
Epoch 15: train_loss = 0.4616
Epoch 16: train_loss = 0.4404
Epoch 17: train_loss = 0.4237
Epoch 18: train_loss = 0.4090
Epoch 19: train_loss = 0.3937
Epoch 20: train_loss = 0.3797
Epoch 21: train_loss = 0.3671
Epoch 22: train_loss = 0.3561
Epoch 23: train_loss = 0.3477
Epoch 24: train_loss = 0.3374
Epoch 25: train_loss = 0.3272
Epoch 26: train_loss = 0.3192
Epoch 27: train_loss = 0.3114
Epoch 28: train_loss = 0.3040
Epoch 29: train_loss = 0.2976
Epoch 30: train_loss = 0.2894
Epoch 31: train_loss = 0.2825
Epoch 32: train_loss = 0.2762
Epoch 33: train_loss = 0.2695
Epoch 34: train_los