In [2]:
import torch
from graph_transformer_pytorch import GraphTransformer
from torch_geometric.datasets import TUDataset
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import numpy as np

In [8]:
dataset = TUDataset(root='data/TUDataset', name='MUTAG')

print(f"Number of graphs: {len(dataset)}")
print(f"Number of classes: {dataset.num_classes}")
print(f"Number of features: {dataset.num_features}")

# Example graph
graph = dataset[0]
print(graph)
print(f"Node features shape: {graph.x.shape}")
print(f"Edge index shape: {graph.edge_index.shape}")
print(f"Label: {graph.y}")

Downloading https://www.chrsmrrs.com/graphkerneldatasets/MUTAG.zip


Number of graphs: 188
Number of classes: 2
Number of features: 7
Data(edge_index=[2, 38], x=[17, 7], edge_attr=[38, 4], y=[1])
Node features shape: torch.Size([17, 7])
Edge index shape: torch.Size([2, 38])
Label: tensor([1])


Processing...
Done!


In [15]:
node_feature_dim = dataset.num_node_features
num_classes = dataset.num_classes

def collate_fn(batch):
    max_nodes = max([data.num_nodes for data in batch])
    batch_size = len(batch)
    nodes = torch.zeros((batch_size, max_nodes, node_feature_dim))
    adj_mat = torch.zeros((batch_size, max_nodes, max_nodes))
    mask = torch.zeros((batch_size, max_nodes), dtype=torch.bool)
    labels = torch.zeros(batch_size, dtype=torch.long)

    for i, data in enumerate(batch):
        n = data.num_nodes
        nodes[i, :n, :] = data.x
        mask[i, :n] = 1

        ei = data.edge_index
        adj_mat[i, :n, :n][ei[0], ei[1]] = 1

        labels[i] = data.y.item()

    return nodes, adj_mat, mask, labels

# Create a custom DataLoader with the collate function
train_idx, test_idx = train_test_split(np.arange(len(dataset)), test_size=0.2, random_state=42)
train_dataset = dataset[train_idx]
test_dataset = dataset[test_idx]
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, collate_fn=collate_fn)

In [16]:
model = GraphTransformer(
    dim=128,
    depth=6,
    heads=8,
    edge_dim=512,
    with_feedforwards=True,
    gated_residual=True,
    accept_adjacency_matrix=True,
)

# Move model to the appropriate device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Projectors, Loass and Optimizer
classifier = torch.nn.Linear(128, dataset.num_classes).to(device)
input_projector = torch.nn.Linear(dataset.num_features, 128).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [17]:
epochs = 50
for epoch in range(epochs):
    model.train()
    classifier.train()
    total_loss = 0.0

    for nodes, adj_mat, mask, labels in train_loader:
        nodes = nodes.to(device)
        adj_mat = adj_mat.to(device)
        mask = mask.to(device)
        labels = labels.to(device)

        nodes = input_projector(nodes)

        optimizer.zero_grad()
        output, _ = model(nodes, adj_mat=adj_mat, mask=mask)

        # Graph-level mean pooling
        graph_emb = (output * mask.unsqueeze(-1)).sum(dim=1) / mask.sum(dim=1, keepdim=True)
        logits = classifier(graph_emb)

        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1:02d} | Train Loss: {total_loss / len(train_loader):.4f}")

Epoch 01 | Train Loss: 0.6949
Epoch 02 | Train Loss: 0.6618
Epoch 03 | Train Loss: 0.6534
Epoch 04 | Train Loss: 0.6799
Epoch 05 | Train Loss: 0.6676
Epoch 06 | Train Loss: 0.6564
Epoch 07 | Train Loss: 0.6907
Epoch 08 | Train Loss: 0.6708
Epoch 09 | Train Loss: 0.6455
Epoch 10 | Train Loss: 0.6534
Epoch 11 | Train Loss: 0.6485
Epoch 12 | Train Loss: 0.6499
Epoch 13 | Train Loss: 0.6404
Epoch 14 | Train Loss: 0.6489
Epoch 15 | Train Loss: 0.6417
Epoch 16 | Train Loss: 0.6536
Epoch 17 | Train Loss: 0.6472
Epoch 18 | Train Loss: 0.6565
Epoch 19 | Train Loss: 0.6503
Epoch 20 | Train Loss: 0.6501
Epoch 21 | Train Loss: 0.6496
Epoch 22 | Train Loss: 0.6485
Epoch 23 | Train Loss: 0.6419
Epoch 24 | Train Loss: 0.6428
Epoch 25 | Train Loss: 0.6428
Epoch 26 | Train Loss: 0.6514
Epoch 27 | Train Loss: 0.6439
Epoch 28 | Train Loss: 0.6477
Epoch 29 | Train Loss: 0.6745
Epoch 30 | Train Loss: 0.6698
Epoch 31 | Train Loss: 0.6634
Epoch 32 | Train Loss: 0.6605
Epoch 33 | Train Loss: 0.6383
Epoch 34 |

In [19]:
model.eval()
classifier.eval()
y_true, y_pred = [], []

with torch.no_grad():
    for nodes, adj_mat, mask, labels in test_loader:
        nodes = nodes.to(device)
        adj_mat = adj_mat.to(device)
        mask = mask.to(device)
        labels = labels.to(device)

        nodes = input_projector(nodes)
        output, _ = model(nodes, adj_mat=adj_mat, mask=mask)
        graph_emb = (output * mask.unsqueeze(-1)).sum(dim=1) / mask.sum(dim=1, keepdim=True)
        logits = classifier(graph_emb)
        preds = torch.argmax(logits, dim=1)

        y_true.append(labels.cpu())
        y_pred.append(preds.cpu())

y_true = torch.cat(y_true).numpy()
y_pred = torch.cat(y_pred).numpy()
acc = (y_true == y_pred).mean()
print(f"Test Accuracy: {acc:.4f}")


Test Accuracy: 0.6842
