In [1]:
# ! pip install torch_geometric

In [2]:
import os
import torch
import torch.nn.functional as F
from torch_geometric.nn import DenseGCNConv, dense_diff_pool
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pickle
import networkx as nx
import sys

from torch_geometric.datasets import Entities, TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_dense_adj


from utils import *
from model import *

In [3]:
# Load Dataset
# dataset = Entities(root='data/Entities', name='MUTAG')
dataset = TUDataset(root='data/TUDataset', name='MUTAG')
# dataset = TUDataset(root='data/TUDataset', name='Mutagenicity')

print()
print(f'Dataset: {dataset}:')
print('====================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

data = dataset[0]  # Get the first graph object.

print()
print(data)
print('=============================================================')

# Gather some statistics about the first graph.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Has isolated nodes: {data.has_isolated_nodes()}')
print(f'Has self-loops: {data.has_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')


Dataset: MUTAG(188):
Number of graphs: 188
Number of features: 7
Number of classes: 2

Data(edge_index=[2, 38], x=[17, 7], edge_attr=[38, 4], y=[1])
Number of nodes: 17
Number of edges: 38
Average node degree: 2.24
Has isolated nodes: False
Has self-loops: False
Is undirected: True


In [4]:
torch.manual_seed(12345)

# Train test split
train_split = 0.8
batch_size = 16
train_loader, test_loader = prepare_data(dataset, train_split, batch_size)


Class split - Training 0: 53 1: 97, Test 0: 10 1: 28


In [5]:
def train(model, loader, optimizer, criterion):
    model.train()
    total_loss = 0

    for data in loader:
        adj = to_dense_adj(data.edge_index)[0]
        optimizer.zero_grad()

        # Create a dummy batch vector if your data doesn't have one
        dummy_batch = torch.zeros(data.num_nodes, dtype=torch.long, device=data.x.device)

        out, l1, e2 = model(data.x, adj, dummy_batch)
        loss = criterion(out, data.y) + l1 + e2
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(loader)



In [6]:
def test(model, loader, criterion):
    model.eval()
    total_loss = 0
    correct = 0

    with torch.no_grad():
        for data in loader:
            adj = to_dense_adj(data.edge_index)[0]
             # Create a dummy batch vector if your data doesn't have one
            dummy_batch = torch.zeros(data.num_nodes, dtype=torch.long, device=data.x.device)
            out, _, _ = model(data.x, adj, dummy_batch)
            loss = criterion(out, data.y)
            total_loss += loss.item()
            pred = out.max(dim=1)[1]
            correct += pred.eq(data.y).sum().item()

    return total_loss / len(loader), correct / len(loader.dataset)



In [7]:
pooling_ratio = 0.25

model = DiffPoolGNN(dataset.num_features, dataset.num_classes, pooling_ratio)  


lr = 0.01
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = torch.nn.CrossEntropyLoss()

In [8]:
epochs = 10
train_losses = []
test_losses = []

for epoch in range(epochs):
    train_loss = train(model, train_loader, optimizer, criterion)
    test_loss, test_acc = test(model, test_loader, criterion)

    train_losses.append(train_loss)
    test_losses.append(test_loss)

    print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}')


plt.figure(figsize=(10, 6))
plt.plot(train_losses, label='Train Loss')
plt.plot(test_losses, label='Test Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Train vs Test Loss')
plt.legend()
plt.show()

RuntimeError: Expected index [294] to be smaller than self [1] apart from dimension 0 and to be smaller size than src [1]