In [1]:
import torch
import os
import torch.nn.functional as F
from torch_geometric.data import Dataset, DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
from torchvision.datasets import MNIST # Used for visualization later

In [2]:
class GNN_MNIST(torch.nn.Module):
    """Graph classifier: three GCNConv layers, mean pooling, and a two-layer linear head.

    Args:
        num_node_features: Size of each node's input feature vector.
        num_classes: Number of classes the final linear layer scores.
    """

    def __init__(self, num_node_features, num_classes):
        super().__init__()
        # Graph convolution stack; channel width grows 32 -> 64 -> 128.
        self.conv1 = GCNConv(num_node_features, 32)
        self.conv2 = GCNConv(32, 64)
        self.conv3 = GCNConv(64, 128)

        # Linear head applied to the pooled 128-wide representation.
        self.fc1 = torch.nn.Linear(128, 64)
        self.fc2 = torch.nn.Linear(64, num_classes)

    def forward(self, data):
        """Compute log-probabilities over the classes for a batch object.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch`` attributes.

        Returns:
            The output of ``log_softmax`` over the last dimension of the
            ``fc2`` scores.
        """
        node_feats = data.x
        edges = data.edge_index
        graph_ids = data.batch

        # Each convolution is followed by a ReLU.
        for conv in (self.conv1, self.conv2, self.conv3):
            node_feats = F.relu(conv(node_feats, edges))

        # Collapse node embeddings with mean pooling keyed on `graph_ids`.
        pooled = global_mean_pool(node_feats, graph_ids)

        # Dropout is active only while the module is in training mode.
        pooled = F.dropout(pooled, p=0.5, training=self.training)
        hidden = F.relu(self.fc1(pooled))
        scores = self.fc2(hidden)

        return F.log_softmax(scores, dim=-1)

In [3]:

class MNISTGraphDataset(Dataset):
    """Dataset of graphs stored one per file as ``data_{i}.pt`` in the processed folder.

    Samples are loaded lazily, one file per ``get`` call. This class defines no
    download or processing step, so the files are expected to exist already.

    Args:
        root: Dataset root directory, forwarded to the base class.
        transform: Forwarded to the base class.
        pre_transform: Forwarded to the base class.
        num_graphs: How many ``data_{i}.pt`` files the dataset spans
            (indices ``0 .. num_graphs - 1``). Defaults to 60000, the value
            that used to be hard-coded.

    Raises:
        ValueError: If ``num_graphs`` is negative.
    """

    def __init__(self, root, transform=None, pre_transform=None, num_graphs=60000):
        num_graphs = int(num_graphs)
        if num_graphs < 0:
            raise ValueError(f"num_graphs must be non-negative, got {num_graphs}")
        # Stored before the base constructor runs so that `len()` and
        # `processed_file_names` are usable if the base class consults them.
        self._num_graphs = num_graphs
        super(MNISTGraphDataset, self).__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        """No raw files are tracked by this dataset."""
        return []

    @property
    def processed_file_names(self):
        """Expected processed file names: ``data_0.pt`` .. ``data_{num_graphs-1}.pt``.

        The list is derived from the configured graph count; the directory is
        not scanned.
        """
        return [f'data_{i}.pt' for i in range(self._num_graphs)]

    def len(self):
        """Return the configured number of graphs without rebuilding the name list."""
        return self._num_graphs

    def get(self, idx):
        """Load and return the object stored in ``data_{idx}.pt``."""
        sample_path = os.path.join(self.processed_dir, f'data_{idx}.pt')
        # NOTE(security): weights_only=False performs a full unpickle, which can
        # execute arbitrary code. Only load files you generated yourself.
        data = torch.load(sample_path, weights_only=False)
        return data

In [4]:
def train(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the average loss.

    Args:
        model: Module called on each batch; its output is fed to ``F.nll_loss``.
        loader: Iterable of batches; must also expose ``.dataset`` with a length.
        optimizer: Optimizer stepped once per batch.
        device: Device each batch is moved to before the forward pass.

    Returns:
        Sum of each batch loss weighted by the batch's ``num_graphs``, divided
        by ``len(loader.dataset)``.
    """
    model.train()
    running_loss = 0
    progress = tqdm(loader, desc="Training")
    for batch in progress:
        batch = batch.to(device)
        optimizer.zero_grad()
        log_probs = model(batch)
        batch_loss = F.nll_loss(log_probs, batch.y)
        batch_loss.backward()
        optimizer.step()
        # Weight by graph count so a smaller final batch is not over-represented.
        running_loss += batch_loss.item() * batch.num_graphs
    dataset_size = len(loader.dataset)
    return running_loss / dataset_size

In [5]:
def test(model, loader, device):
    """Evaluate ``model`` on ``loader`` and return the fraction of correct predictions.

    Args:
        model: Module called on each batch; the arg-max over dimension 1 of its
            output is taken as the prediction.
        loader: Iterable of batches; must also expose ``.dataset`` with a length.
        device: Device each batch is moved to before the forward pass.

    Returns:
        Number of predictions equal to ``batch.y`` divided by ``len(loader.dataset)``.
    """
    model.eval()
    n_correct = 0
    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for batch in tqdm(loader, desc="Testing"):
            batch = batch.to(device)
            predicted = model(batch).argmax(dim=1)
            matches = predicted == batch.y
            n_correct += int(matches.sum())
    dataset_size = len(loader.dataset)
    return n_correct / dataset_size

In [6]:
if __name__ == '__main__':
    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # One seed for both the split and torch's global RNG (weight initialisation,
    # loader shuffling, dropout). This makes runs comparable, but does not by
    # itself guarantee bit-identical results on GPU.
    seed = 42
    torch.manual_seed(seed)

    # Load dataset
    dataset = MNISTGraphDataset(root='./mnist_graphs')

    # Split the dataset 80/20 into training and held-out graphs.
    # The fixed random_state keeps the split the same on every run.
    train_indices, test_indices = train_test_split(
        list(range(len(dataset))), test_size=0.2, random_state=seed
    )
    train_dataset = dataset[train_indices]
    test_dataset = dataset[test_indices]

    print(f"Number of training graphs: {len(train_dataset)}")
    print(f"Number of test graphs: {len(test_dataset)}")

    # Create DataLoaders.
    # Pinned host memory only helps host-to-GPU copies, and requesting it
    # without CUDA just triggers a warning, so enable it only on a GPU.
    use_pinned_memory = device.type == 'cuda'
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True,
                              num_workers=0, pin_memory=use_pinned_memory)
    test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False,
                             num_workers=0, pin_memory=use_pinned_memory)

    # Initialize model, optimizer
    # Node features are [intensity, pos_y, pos_x], so num_node_features=3
    model = GNN_MNIST(num_node_features=3, num_classes=10).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Training loop: one training pass and one evaluation pass per epoch.
    epochs = 20
    history = {'loss': [], 'accuracy': []}

    for epoch in range(1, epochs + 1):
        loss = train(model, train_loader, optimizer, device)
        test_acc = test(model, test_loader, device)
        history['loss'].append(loss)
        history['accuracy'].append(test_acc)
        print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Test Accuracy: {test_acc:.4f}')

    # Saving happens inside the guard: `model` only exists when this block ran,
    # so saving at module level would raise NameError on import.
    print("Saving trained model state...")
    torch.save(model.state_dict(), 'gnn_mnist_model.pth')
    print("Model saved to gnn_mnist_model.pth")



Using device: cuda
Number of training graphs: 48000
Number of test graphs: 12000


Training: 100%|██████████| 375/375 [00:51<00:00,  7.33it/s]
Testing: 100%|██████████| 94/94 [00:13<00:00,  6.75it/s]


Epoch: 01, Loss: 2.1115, Test Accuracy: 0.2480


Training: 100%|██████████| 375/375 [00:32<00:00, 11.62it/s]
Testing: 100%|██████████| 94/94 [00:06<00:00, 14.60it/s]


Epoch: 02, Loss: 1.9892, Test Accuracy: 0.2860


Training: 100%|██████████| 375/375 [00:31<00:00, 11.93it/s]
Testing: 100%|██████████| 94/94 [00:13<00:00,  7.10it/s]


Epoch: 03, Loss: 1.9133, Test Accuracy: 0.2938


Training: 100%|██████████| 375/375 [00:53<00:00,  6.98it/s]
Testing: 100%|██████████| 94/94 [00:05<00:00, 16.63it/s]


Epoch: 04, Loss: 1.8619, Test Accuracy: 0.3209


Training: 100%|██████████| 375/375 [00:28<00:00, 12.95it/s]
Testing: 100%|██████████| 94/94 [00:05<00:00, 16.02it/s]


Epoch: 05, Loss: 1.8291, Test Accuracy: 0.3331


Training: 100%|██████████| 375/375 [00:31<00:00, 11.74it/s]
Testing: 100%|██████████| 94/94 [00:05<00:00, 15.96it/s]


Epoch: 06, Loss: 1.8009, Test Accuracy: 0.3420


Training: 100%|██████████| 375/375 [00:26<00:00, 13.92it/s]
Testing: 100%|██████████| 94/94 [00:10<00:00,  8.90it/s]


Epoch: 07, Loss: 1.7669, Test Accuracy: 0.3662


Training: 100%|██████████| 375/375 [00:26<00:00, 13.98it/s]
Testing: 100%|██████████| 94/94 [00:06<00:00, 13.70it/s]


Epoch: 08, Loss: 1.7146, Test Accuracy: 0.4359


Training: 100%|██████████| 375/375 [00:29<00:00, 12.92it/s]
Testing: 100%|██████████| 94/94 [00:05<00:00, 17.39it/s]


Epoch: 09, Loss: 1.6166, Test Accuracy: 0.4943


Training: 100%|██████████| 375/375 [00:24<00:00, 15.05it/s]
Testing: 100%|██████████| 94/94 [00:05<00:00, 17.80it/s]


Epoch: 10, Loss: 1.5309, Test Accuracy: 0.5397


Training: 100%|██████████| 375/375 [00:30<00:00, 12.19it/s]
Testing: 100%|██████████| 94/94 [00:07<00:00, 11.88it/s]


Epoch: 11, Loss: 1.4776, Test Accuracy: 0.5599


Training: 100%|██████████| 375/375 [00:45<00:00,  8.30it/s]
Testing: 100%|██████████| 94/94 [00:05<00:00, 16.56it/s]


Epoch: 12, Loss: 1.4218, Test Accuracy: 0.5789


Training: 100%|██████████| 375/375 [00:32<00:00, 11.37it/s]
Testing: 100%|██████████| 94/94 [00:11<00:00,  8.14it/s]


Epoch: 13, Loss: 1.3768, Test Accuracy: 0.5967


Training: 100%|██████████| 375/375 [01:01<00:00,  6.09it/s]
Testing: 100%|██████████| 94/94 [00:13<00:00,  7.21it/s]


Epoch: 14, Loss: 1.3376, Test Accuracy: 0.5978


Training: 100%|██████████| 375/375 [00:38<00:00,  9.76it/s]
Testing: 100%|██████████| 94/94 [00:14<00:00,  6.52it/s]


Epoch: 15, Loss: 1.3016, Test Accuracy: 0.6171


Training: 100%|██████████| 375/375 [00:44<00:00,  8.41it/s]
Testing: 100%|██████████| 94/94 [00:11<00:00,  8.48it/s]


Epoch: 16, Loss: 1.2689, Test Accuracy: 0.6292


Training: 100%|██████████| 375/375 [00:49<00:00,  7.60it/s]
Testing: 100%|██████████| 94/94 [00:08<00:00, 10.83it/s]


Epoch: 17, Loss: 1.2321, Test Accuracy: 0.6370


Training: 100%|██████████| 375/375 [00:39<00:00,  9.58it/s]
Testing: 100%|██████████| 94/94 [00:12<00:00,  7.25it/s]


Epoch: 18, Loss: 1.1987, Test Accuracy: 0.6505


Training: 100%|██████████| 375/375 [00:46<00:00,  8.05it/s]
Testing: 100%|██████████| 94/94 [00:13<00:00,  6.77it/s]


Epoch: 19, Loss: 1.1731, Test Accuracy: 0.6545


Training: 100%|██████████| 375/375 [00:46<00:00,  8.03it/s]
Testing: 100%|██████████| 94/94 [00:13<00:00,  7.01it/s]

Epoch: 20, Loss: 1.1459, Test Accuracy: 0.6679
Saving trained model state...
Model saved to gnn_mnist_model.pth



