# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

import graphlearning as gl

# ---- Load MNIST data ----
# gl.datasets.load returns the image data and the matching labels; below each
# image is reshaped to a single-channel 28x28 tensor.
mnist_digits, mnist_labels = gl.datasets.load("mnist")

# ---- Convert to PyTorch tensors ----
# NOTE(review): the division by 255 assumes raw pixel values in [0, 255];
# confirm graphlearning does not already return normalized intensities.
X = torch.tensor(mnist_digits, dtype=torch.float32).view(-1, 1, 28, 28).div(255.0)
y = torch.tensor(mnist_labels, dtype=torch.long)

# ---- Split into train/test ----
# The first 60000 samples are used for training and the remainder for testing
# (presumably the standard MNIST train/test ordering — verify upstream).
train_X, train_y = X[:60000], y[:60000]
test_X, test_y = X[60000:], y[60000:]

train_ds, test_ds = TensorDataset(train_X, train_y), TensorDataset(test_X, test_y)

# Shuffle only the training stream; evaluation uses larger, ordered batches.
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=256)

# ---- Define CNN model ----
class ConvNet(nn.Module):
    """Small two-stage convolutional classifier for 28x28 images.

    Two (3x3 conv -> ReLU -> 2x2 max-pool) stages reduce the spatial size
    28 -> 14 -> 7. A 128-unit hidden layer and a final linear layer then
    produce unnormalized class scores (logits).

    Args:
        num_classes: Number of output logits. Defaults to 10 (MNIST digits).
        in_channels: Channel count of the input images. Defaults to 1
            (grayscale).

    The layer order inside ``self.net`` is fixed, so ``state_dict`` keys
    (``net.0.weight``, ``net.3.weight``, ...) stay compatible with checkpoints
    saved from earlier versions of this class.
    """

    def __init__(self, num_classes: int = 10, in_channels: int = 1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),  # keeps HxW
            nn.ReLU(),
            nn.MaxPool2d(2),  # 28x28 -> 14x14
            nn.Conv2d(32, 64, kernel_size=3, padding=1),  # keeps HxW
            nn.ReLU(),
            nn.MaxPool2d(2),  # 14x14 -> 7x7
            nn.Flatten(),
            # 64 channels * 7 * 7 only matches 28x28 inputs.
            nn.Linear(64 * 7 * 7, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Map a (batch, in_channels, 28, 28) tensor to (batch, num_classes) logits."""
        return self.net(x)

# ---- Initialize model, loss, optimizer ----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load pre-trained weights if a checkpoint from a previous run exists.
from pathlib import Path
model_path = Path("mnist_cnn.pth")
model = ConvNet().to(device)
if model_path.exists():
    # map_location remaps stored tensors onto the current device. Without it,
    # a checkpoint written on a CUDA machine fails to load on a CPU-only one.
    # NOTE(review): torch.load unpickles the file — only load checkpoints you
    # trust (newer torch versions accept weights_only=True to restrict this).
    model.load_state_dict(torch.load(model_path, map_location=device))
    print("Loaded pre-trained model.")

criterion = nn.CrossEntropyLoss()
# The optimizer is always created fresh; no optimizer state is restored here,
# so resumed training starts with new Adam moment estimates.
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# ---- Training loop ----
# Trains for a fixed number of epochs; this runs whether or not pre-trained
# weights were loaded, so a loaded model is trained further.
num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        preds = model(xb)
        loss = criterion(preds, yb)
        loss.backward()
        optimizer.step()
        # Assumes criterion averages over the batch (nn.CrossEntropyLoss default
        # reduction="mean"). Weighting by the batch size keeps the epoch
        # average exact even when the last batch is smaller than the rest.
        total_loss += loss.item() * yb.size(0)

    print(f"Epoch {epoch+1}: train_loss = {total_loss / len(train_loader.dataset):.4f}")

# ---- Evaluation ----
# Top-1 accuracy over the held-out loader, computed without autograd tracking.
model.eval()
correct = total = 0
with torch.no_grad():
    for xb, yb in test_loader:
        xb = xb.to(device)
        yb = yb.to(device)
        preds = model(xb)
        # Boolean mask: True where the highest-scoring class equals the label.
        hits = preds.argmax(dim=1).eq(yb)
        correct += hits.sum().item()
        total += hits.numel()

print(f"Test accuracy: {correct / total:.4f}")

# ---- Save model ----
# Write to the same path the loader checks (model_path) so the next run picks
# this checkpoint up. Tensors are copied to CPU so the file also loads on
# machines without a GPU.
torch.save({name: tensor.cpu() for name, tensor in model.state_dict().items()}, model_path)
