In [8]:
import torch
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_dense_adj

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [9]:
dataset = TUDataset(root="data/MUTAG", name="MUTAG")
dataset = dataset.shuffle()

train_dataset = dataset[:150]
test_dataset  = dataset[150:]

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
test_loader  = DataLoader(test_dataset, batch_size=1)


In [None]:
dataset.num_classes # 2
dataset.num_node_features # 7

7

In [26]:
class GCNLayer(torch.nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.W = torch.nn.Parameter(torch.randn(in_dim, out_dim) * 0.01)

    def forward(self, A_norm, X):
        return A_norm @ X @ self.W
    
class GraphGCN(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim, num_classes):
        super().__init__()
        self.gcn1 = GCNLayer(in_dim, hidden_dim)
        self.gcn2 = GCNLayer(hidden_dim, hidden_dim)
        self.classifier = torch.nn.Linear(hidden_dim, num_classes)

    def forward(self, A_norm, X):
        H = torch.relu(self.gcn1(A_norm, X))
        H = torch.relu(self.gcn2(A_norm, H))

        # Graph-level pooling (mean)
        g = H.mean(dim=0, keepdim=True)

        return self.classifier(g)

def normalize_adj(edge_index, num_nodes):
    A = to_dense_adj(edge_index, max_num_nodes=num_nodes)[0]

    I = torch.eye(num_nodes, device = device)
    A_hat = A + I

    D = torch.diag(A_hat.sum(dim=1))
    D_inv_sqrt = torch.linalg.inv(torch.sqrt(D))

    return D_inv_sqrt @ A_hat @ D_inv_sqrt


In [27]:

model = GraphGCN(
    in_dim=dataset.num_node_features,
    hidden_dim=32,
    num_classes=dataset.num_classes
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()


In [28]:
def train():
    model.train()
    total_loss = 0

    for data in train_loader:
        data = data.to(device)

        A_norm = normalize_adj(data.edge_index, data.num_nodes).to(device)
        X = data.x

        optimizer.zero_grad()
        out = model(A_norm, X)
        loss = loss_fn(out, data.y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(train_loader)

def test(loader):
    model.eval()
    correct = 0

    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            A_norm = normalize_adj(data.edge_index, data.num_nodes).to(device)
            X = data.x

            out = model(A_norm, X)
            pred = out.argmax(dim=1)

            correct += (pred == data.y).sum().item()

    return correct / len(loader.dataset)



In [29]:
epochs = 500
for epoch in range(epochs):
    loss = train()
    train_acc = test(train_loader)
    test_acc = test(test_loader)

    if epoch % 10 == 0:
        print(
            f"Epoch {epoch:03d} | "
            f"Loss {loss:.4f} | "
            f"Train Acc {train_acc:.3f} | "
            f"Test Acc {test_acc:.3f}"
        )


Epoch 000 | Loss 0.6728 | Train Acc 0.693 | Test Acc 0.553
Epoch 010 | Loss 0.5822 | Train Acc 0.693 | Test Acc 0.553
Epoch 020 | Loss 0.5616 | Train Acc 0.700 | Test Acc 0.553
Epoch 030 | Loss 0.5560 | Train Acc 0.713 | Test Acc 0.579
Epoch 040 | Loss 0.5459 | Train Acc 0.707 | Test Acc 0.711
Epoch 050 | Loss 0.5512 | Train Acc 0.727 | Test Acc 0.763
Epoch 060 | Loss 0.5421 | Train Acc 0.740 | Test Acc 0.763
Epoch 070 | Loss 0.5389 | Train Acc 0.747 | Test Acc 0.763
Epoch 080 | Loss 0.5396 | Train Acc 0.733 | Test Acc 0.711
Epoch 090 | Loss 0.5331 | Train Acc 0.740 | Test Acc 0.711
Epoch 100 | Loss 0.5340 | Train Acc 0.733 | Test Acc 0.711
Epoch 110 | Loss 0.5334 | Train Acc 0.740 | Test Acc 0.737
Epoch 120 | Loss 0.5318 | Train Acc 0.733 | Test Acc 0.711
Epoch 130 | Loss 0.5307 | Train Acc 0.733 | Test Acc 0.737
Epoch 140 | Loss 0.5245 | Train Acc 0.747 | Test Acc 0.763
Epoch 150 | Loss 0.5304 | Train Acc 0.727 | Test Acc 0.711
Epoch 160 | Loss 0.5293 | Train Acc 0.727 | Test Acc 0.7