In [2]:
import torch
import os
import torch.nn.functional as F
from torch_geometric.data import Dataset, DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
from torchvision.datasets import MNIST # Used for visualization later

In [3]:
class GNN_MNIST(torch.nn.Module):
    """Graph-level classifier built from three GCN layers and an MLP head.

    Pipeline: ``GCNConv`` x3 (num_node_features -> 32 -> 64 -> 128, each
    followed by ReLU), ``global_mean_pool`` keyed by ``data.batch``, dropout
    (p=0.5, active only in training mode), then ``Linear`` 128 -> 64 -> num_classes.

    Args:
        num_node_features: Width of the per-node input feature vector.
        num_classes: Number of output classes.
    """

    def __init__(self, num_node_features, num_classes):
        super().__init__()
        # Message-passing stack. Attribute names double as state_dict keys,
        # so they must stay stable for saved checkpoints to load.
        self.conv1 = GCNConv(num_node_features, 32)
        self.conv2 = GCNConv(32, 64)
        self.conv3 = GCNConv(64, 128)

        # Classification head applied to the pooled representation.
        self.fc1 = torch.nn.Linear(128, 64)
        self.fc2 = torch.nn.Linear(64, num_classes)

    def forward(self, data):
        """Return log-probabilities over the classes for the graphs in ``data``.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch``
                (presumably a PyG ``Data``/``Batch`` -- confirm with caller).

        Returns:
            Tensor produced by ``log_softmax`` over the last dimension.
        """
        node_feats, edge_index, batch = data.x, data.edge_index, data.batch

        # Every GCN layer is followed by the same ReLU non-linearity.
        for conv in (self.conv1, self.conv2, self.conv3):
            node_feats = F.relu(conv(node_feats, edge_index))

        # Collapse node embeddings into one vector per batch entry.
        pooled = global_mean_pool(node_feats, batch)

        # Dropout is a no-op in eval mode because it follows self.training.
        pooled = F.dropout(pooled, p=0.5, training=self.training)
        hidden = F.relu(self.fc1(pooled))
        logits = self.fc2(hidden)

        return F.log_softmax(logits, dim=-1)

In [4]:

class MNISTGraphDataset(Dataset):
    """On-disk dataset of pre-built graphs stored one per file as ``data_{i}.pt``.

    Files are read lazily from ``self.processed_dir`` (supplied by the PyG
    ``Dataset`` base class) each time an item is requested. Nothing is
    downloaded or generated here: ``raw_file_names`` is empty and no
    ``process`` step is defined, so the files must already exist.

    Args:
        root: Dataset root directory, forwarded to the base class.
        transform: Optional per-access transform, forwarded to the base class.
        pre_transform: Optional pre-processing transform, forwarded to the base class.
        num_graphs: Number of ``data_{i}.pt`` files, indexed ``0 .. num_graphs - 1``.
            Defaults to 60000 (the size the original code hard-coded).

    Raises:
        ValueError: If ``num_graphs`` is negative.
    """

    def __init__(self, root, transform=None, pre_transform=None, num_graphs=60000):
        if num_graphs < 0:
            raise ValueError(f"num_graphs must be non-negative, got {num_graphs}")
        # Set before delegating: the base class may consult
        # processed_file_names / len() while it is being constructed.
        self._graph_count = int(num_graphs)
        super().__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        """No raw files are required; the graphs are expected to be pre-processed."""
        return []

    @property
    def processed_file_names(self):
        """Expected file names, derived from the configured graph count.

        The directory is not scanned: the list is simply ``data_0.pt`` ..
        ``data_{num_graphs - 1}.pt``.
        """
        return [f'data_{i}.pt' for i in range(self._graph_count)]

    def len(self):
        """Return the number of graphs without rebuilding the file-name list."""
        return self._graph_count

    def get(self, idx):
        """Load and return the graph stored in ``data_{idx}.pt``.

        NOTE(security): ``weights_only=False`` performs a full unpickle, which
        can execute arbitrary code. Only point ``root`` at files you created
        or otherwise trust.
        """
        path = os.path.join(self.processed_dir, f'data_{idx}.pt')
        return torch.load(path, weights_only=False)

In [5]:
def train(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean loss per graph.

    Args:
        model: Module called as ``model(batch)``. Its output is fed to
            ``F.nll_loss``, so it is expected to be log-probabilities.
        loader: Iterable of mini-batches exposing ``.to(device)``, ``.y`` and
            ``.num_graphs``.
        optimizer: Optimizer that updates ``model``'s parameters.
        device: Device every batch is moved to before the forward pass.

    Returns:
        float: Loss averaged over the graphs actually processed. This stays
        correct when the loader skips part of its dataset (``drop_last=True``,
        a subset sampler). Returns 0.0 when the loader yields no batch.
    """
    model.train()
    total_loss = 0.0
    graphs_seen = 0
    for data in tqdm(loader, desc="Training"):
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = F.nll_loss(out, data.y)
        loss.backward()
        optimizer.step()
        # nll_loss is a per-batch mean; weight it by the batch size so a
        # smaller final batch does not skew the epoch average.
        total_loss += loss.item() * data.num_graphs
        graphs_seen += data.num_graphs
    # Divide by what was iterated, not by len(loader.dataset): the two differ
    # whenever the loader does not visit every graph exactly once.
    return total_loss / graphs_seen if graphs_seen else 0.0

In [6]:
def test(model, loader, device):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data in tqdm(loader, desc="Testing"):
            data = data.to(device)
            out = model(data)
            pred = out.argmax(dim=1)
            correct += int((pred == data.y).sum())
    return correct / len(loader.dataset)

In [7]:
if __name__ == '__main__':
    # Everything below deliberately stays at module scope (no main() wrapper)
    # so that later notebook cells can keep using `model`, `history`, the
    # loaders, etc.

    # One seed for both the split and torch. Without seeding torch, weight
    # initialisation, dropout masks and shuffle order differ on every run.
    # (Some GPU kernels may still be non-deterministic.)
    SEED = 42
    torch.manual_seed(SEED)

    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load the pre-processed graphs from disk.
    dataset = MNISTGraphDataset(root='./mnist_graphs')

    # 80/20 split with a fixed random_state so it is identical on every run.
    # NOTE: the "test" part is a hold-out of this dataset that is evaluated
    # after every epoch, i.e. it plays the role of a validation set.
    train_indices, test_indices = train_test_split(
        list(range(len(dataset))), test_size=0.2, random_state=SEED
    )
    train_dataset = dataset[train_indices]
    test_dataset = dataset[test_indices]

    print(f"Number of training graphs: {len(train_dataset)}")
    print(f"Number of test graphs: {len(test_dataset)}")

    # Create DataLoaders. Pinned host memory only helps host-to-GPU copies,
    # so it is enabled only when training on CUDA.
    pin_memory = device.type == 'cuda'
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True,
                              num_workers=0, pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False,
                             num_workers=0, pin_memory=pin_memory)

    # Initialize model, optimizer
    # Node features are [intensity, pos_y, pos_x], so num_node_features=3
    model = GNN_MNIST(num_node_features=3, num_classes=10).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Training loop: one training pass, then one evaluation pass, per epoch.
    epochs = 50
    history = {'loss': [], 'accuracy': []}

    for epoch in range(1, epochs + 1):
        loss = train(model, train_loader, optimizer, device)
        test_acc = test(model, test_loader, device)
        history['loss'].append(loss)
        history['accuracy'].append(test_acc)
        print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Test Accuracy: {test_acc:.4f}')

    # Saving lives inside the guard: `model` only exists when the block above
    # ran, so saving at module level would raise NameError on import.
    print("Saving trained model state...")
    torch.save(model.state_dict(), 'gnn_mnist_model.pth')
    print("Model saved to gnn_mnist_model.pth")

Using device: cuda
Number of training graphs: 48000
Number of test graphs: 12000


Training: 100%|██████████| 375/375 [02:10<00:00,  2.87it/s]
Testing: 100%|██████████| 94/94 [00:44<00:00,  2.11it/s]


Epoch: 01, Loss: 2.1160, Test Accuracy: 0.2552


Training: 100%|██████████| 375/375 [01:45<00:00,  3.57it/s]
Testing: 100%|██████████| 94/94 [00:06<00:00, 15.66it/s]


Epoch: 02, Loss: 1.9856, Test Accuracy: 0.2838


Training: 100%|██████████| 375/375 [00:41<00:00,  9.00it/s]
Testing: 100%|██████████| 94/94 [00:14<00:00,  6.38it/s]


Epoch: 03, Loss: 1.9046, Test Accuracy: 0.2994


Training: 100%|██████████| 375/375 [00:47<00:00,  7.84it/s]
Testing: 100%|██████████| 94/94 [00:05<00:00, 16.08it/s]


Epoch: 04, Loss: 1.8570, Test Accuracy: 0.3316


Training: 100%|██████████| 375/375 [00:30<00:00, 12.44it/s]
Testing: 100%|██████████| 94/94 [00:05<00:00, 15.71it/s]


Epoch: 05, Loss: 1.8252, Test Accuracy: 0.3410


Training: 100%|██████████| 375/375 [00:29<00:00, 12.83it/s]
Testing: 100%|██████████| 94/94 [00:05<00:00, 15.97it/s]


Epoch: 06, Loss: 1.8000, Test Accuracy: 0.3538


Training: 100%|██████████| 375/375 [00:27<00:00, 13.56it/s]
Testing: 100%|██████████| 94/94 [00:09<00:00, 10.34it/s]


Epoch: 07, Loss: 1.7741, Test Accuracy: 0.3747


Training: 100%|██████████| 375/375 [00:45<00:00,  8.25it/s]
Testing: 100%|██████████| 94/94 [00:14<00:00,  6.49it/s]


Epoch: 08, Loss: 1.7448, Test Accuracy: 0.3852


Training: 100%|██████████| 375/375 [00:53<00:00,  6.95it/s]
Testing: 100%|██████████| 94/94 [00:11<00:00,  8.49it/s]


Epoch: 09, Loss: 1.6960, Test Accuracy: 0.4382


Training: 100%|██████████| 375/375 [00:55<00:00,  6.70it/s]
Testing: 100%|██████████| 94/94 [00:14<00:00,  6.41it/s]


Epoch: 10, Loss: 1.6277, Test Accuracy: 0.5036


Training: 100%|██████████| 375/375 [00:55<00:00,  6.81it/s]
Testing: 100%|██████████| 94/94 [00:09<00:00,  9.90it/s]


Epoch: 11, Loss: 1.5446, Test Accuracy: 0.5008


Training: 100%|██████████| 375/375 [00:58<00:00,  6.43it/s]
Testing: 100%|██████████| 94/94 [00:14<00:00,  6.61it/s]


Epoch: 12, Loss: 1.4679, Test Accuracy: 0.5765


Training: 100%|██████████| 375/375 [00:45<00:00,  8.27it/s]
Testing: 100%|██████████| 94/94 [00:10<00:00,  9.04it/s]


Epoch: 13, Loss: 1.3975, Test Accuracy: 0.5862


Training: 100%|██████████| 375/375 [00:45<00:00,  8.19it/s]
Testing: 100%|██████████| 94/94 [00:09<00:00, 10.31it/s]


Epoch: 14, Loss: 1.3363, Test Accuracy: 0.6077


Training: 100%|██████████| 375/375 [00:37<00:00, 10.08it/s]
Testing: 100%|██████████| 94/94 [00:07<00:00, 11.88it/s]


Epoch: 15, Loss: 1.2879, Test Accuracy: 0.6229


Training: 100%|██████████| 375/375 [01:00<00:00,  6.24it/s]
Testing: 100%|██████████| 94/94 [00:13<00:00,  6.75it/s]


Epoch: 16, Loss: 1.2439, Test Accuracy: 0.6411


Training: 100%|██████████| 375/375 [00:55<00:00,  6.77it/s]
Testing: 100%|██████████| 94/94 [00:13<00:00,  6.99it/s]


Epoch: 17, Loss: 1.2059, Test Accuracy: 0.6545


Training: 100%|██████████| 375/375 [00:50<00:00,  7.48it/s]
Testing: 100%|██████████| 94/94 [00:09<00:00, 10.17it/s]


Epoch: 18, Loss: 1.1630, Test Accuracy: 0.6680


Training: 100%|██████████| 375/375 [00:46<00:00,  8.13it/s]
Testing: 100%|██████████| 94/94 [00:13<00:00,  7.00it/s]


Epoch: 19, Loss: 1.1230, Test Accuracy: 0.6749


Training: 100%|██████████| 375/375 [01:04<00:00,  5.86it/s]
Testing: 100%|██████████| 94/94 [00:13<00:00,  7.01it/s]


Epoch: 20, Loss: 1.0892, Test Accuracy: 0.6992


Training: 100%|██████████| 375/375 [00:56<00:00,  6.65it/s]
Testing: 100%|██████████| 94/94 [00:12<00:00,  7.27it/s]


Epoch: 21, Loss: 1.0560, Test Accuracy: 0.6987


Training: 100%|██████████| 375/375 [00:49<00:00,  7.57it/s]
Testing: 100%|██████████| 94/94 [00:11<00:00,  8.45it/s]


Epoch: 22, Loss: 1.0346, Test Accuracy: 0.7033


Training: 100%|██████████| 375/375 [00:40<00:00,  9.25it/s]
Testing: 100%|██████████| 94/94 [00:15<00:00,  5.98it/s]


Epoch: 23, Loss: 1.0066, Test Accuracy: 0.7051


Training: 100%|██████████| 375/375 [00:49<00:00,  7.65it/s]
Testing: 100%|██████████| 94/94 [00:16<00:00,  5.62it/s]


Epoch: 24, Loss: 0.9766, Test Accuracy: 0.7179


Training: 100%|██████████| 375/375 [00:50<00:00,  7.39it/s]
Testing: 100%|██████████| 94/94 [00:20<00:00,  4.56it/s]


Epoch: 25, Loss: 0.9595, Test Accuracy: 0.7444


Training: 100%|██████████| 375/375 [00:56<00:00,  6.67it/s]
Testing: 100%|██████████| 94/94 [00:14<00:00,  6.61it/s]


Epoch: 26, Loss: 0.9384, Test Accuracy: 0.7238


Training: 100%|██████████| 375/375 [00:55<00:00,  6.74it/s]
Testing: 100%|██████████| 94/94 [00:13<00:00,  6.80it/s]


Epoch: 27, Loss: 0.9149, Test Accuracy: 0.7588


Training: 100%|██████████| 375/375 [00:55<00:00,  6.76it/s]
Testing: 100%|██████████| 94/94 [00:09<00:00,  9.71it/s]


Epoch: 28, Loss: 0.9020, Test Accuracy: 0.7482


Training: 100%|██████████| 375/375 [00:53<00:00,  7.04it/s]
Testing: 100%|██████████| 94/94 [00:14<00:00,  6.30it/s]


Epoch: 29, Loss: 0.8890, Test Accuracy: 0.7618


Training: 100%|██████████| 375/375 [00:52<00:00,  7.12it/s]
Testing: 100%|██████████| 94/94 [00:13<00:00,  6.81it/s]


Epoch: 30, Loss: 0.8787, Test Accuracy: 0.7675


Training: 100%|██████████| 375/375 [00:54<00:00,  6.93it/s]
Testing: 100%|██████████| 94/94 [00:30<00:00,  3.12it/s]


Epoch: 31, Loss: 0.8543, Test Accuracy: 0.7674


Training: 100%|██████████| 375/375 [01:00<00:00,  6.20it/s]
Testing: 100%|██████████| 94/94 [00:15<00:00,  6.23it/s]


Epoch: 32, Loss: 0.8488, Test Accuracy: 0.7801


Training: 100%|██████████| 375/375 [00:51<00:00,  7.33it/s]
Testing: 100%|██████████| 94/94 [00:15<00:00,  6.26it/s]


Epoch: 33, Loss: 0.8323, Test Accuracy: 0.7768


Training: 100%|██████████| 375/375 [00:50<00:00,  7.45it/s]
Testing: 100%|██████████| 94/94 [00:12<00:00,  7.63it/s]


Epoch: 34, Loss: 0.8149, Test Accuracy: 0.7507


Training: 100%|██████████| 375/375 [00:59<00:00,  6.35it/s]
Testing: 100%|██████████| 94/94 [00:13<00:00,  6.85it/s]


Epoch: 35, Loss: 0.8018, Test Accuracy: 0.7903


Training: 100%|██████████| 375/375 [00:49<00:00,  7.55it/s]
Testing: 100%|██████████| 94/94 [00:06<00:00, 15.00it/s]


Epoch: 36, Loss: 0.7966, Test Accuracy: 0.7926


Training: 100%|██████████| 375/375 [00:46<00:00,  8.02it/s]
Testing: 100%|██████████| 94/94 [00:12<00:00,  7.59it/s]


Epoch: 37, Loss: 0.7842, Test Accuracy: 0.7977


Training: 100%|██████████| 375/375 [00:51<00:00,  7.27it/s]
Testing: 100%|██████████| 94/94 [00:14<00:00,  6.33it/s]


Epoch: 38, Loss: 0.7679, Test Accuracy: 0.8083


Training: 100%|██████████| 375/375 [01:27<00:00,  4.31it/s]
Testing: 100%|██████████| 94/94 [00:17<00:00,  5.53it/s]


Epoch: 39, Loss: 0.7563, Test Accuracy: 0.8033


Training: 100%|██████████| 375/375 [01:00<00:00,  6.19it/s]
Testing: 100%|██████████| 94/94 [00:11<00:00,  8.41it/s]


Epoch: 40, Loss: 0.7517, Test Accuracy: 0.8117


Training: 100%|██████████| 375/375 [00:51<00:00,  7.26it/s]
Testing: 100%|██████████| 94/94 [00:20<00:00,  4.57it/s]


Epoch: 41, Loss: 0.7337, Test Accuracy: 0.8119


Training: 100%|██████████| 375/375 [00:59<00:00,  6.35it/s]
Testing: 100%|██████████| 94/94 [00:11<00:00,  7.84it/s]


Epoch: 42, Loss: 0.7315, Test Accuracy: 0.8141


Training: 100%|██████████| 375/375 [00:56<00:00,  6.66it/s]
Testing: 100%|██████████| 94/94 [00:11<00:00,  7.95it/s]


Epoch: 43, Loss: 0.7166, Test Accuracy: 0.8211


Training: 100%|██████████| 375/375 [00:54<00:00,  6.84it/s]
Testing: 100%|██████████| 94/94 [00:05<00:00, 16.11it/s]


Epoch: 44, Loss: 0.7119, Test Accuracy: 0.8222


Training: 100%|██████████| 375/375 [00:26<00:00, 14.17it/s]
Testing: 100%|██████████| 94/94 [00:05<00:00, 17.71it/s]


Epoch: 45, Loss: 0.6973, Test Accuracy: 0.8202


Training: 100%|██████████| 375/375 [00:33<00:00, 11.12it/s]
Testing: 100%|██████████| 94/94 [00:12<00:00,  7.79it/s]


Epoch: 46, Loss: 0.6908, Test Accuracy: 0.8234


Training: 100%|██████████| 375/375 [00:38<00:00,  9.83it/s]
Testing: 100%|██████████| 94/94 [00:10<00:00,  9.09it/s]


Epoch: 47, Loss: 0.6833, Test Accuracy: 0.8264


Training: 100%|██████████| 375/375 [01:31<00:00,  4.12it/s]
Testing: 100%|██████████| 94/94 [00:16<00:00,  5.60it/s]


Epoch: 48, Loss: 0.6692, Test Accuracy: 0.8256


Training: 100%|██████████| 375/375 [00:50<00:00,  7.43it/s]
Testing: 100%|██████████| 94/94 [00:16<00:00,  5.83it/s]


Epoch: 49, Loss: 0.6634, Test Accuracy: 0.8190


Training: 100%|██████████| 375/375 [00:55<00:00,  6.76it/s]
Testing: 100%|██████████| 94/94 [00:13<00:00,  6.72it/s]


Epoch: 50, Loss: 0.6640, Test Accuracy: 0.8369
Saving trained model state...
Model saved to gnn_mnist_model.pth
