In [6]:
import numpy as np
from sklearn.cluster import KMeans
from scipy.optimize import linear_sum_assignment
from torch_geometric.datasets import TUDataset
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from swin_graph_transformer import SwinGraphTransformer
import torch
import warnings
warnings.filterwarnings("ignore")

# Train

In [7]:
dataset = TUDataset(root='data/TUDataset', name='MUTAG')

print(f"Number of graphs: {len(dataset)}")
print(f"Number of classes: {dataset.num_classes}")
print(f"Number of features: {dataset.num_features}")

# Example graph
graph = dataset[0]
print(graph)
print(f"Node features shape: {graph.x.shape}")
print(f"Edge index shape: {graph.edge_index.shape}")
print(f"Label: {graph.y}")

Number of graphs: 188
Number of classes: 2
Number of features: 7
Data(edge_index=[2, 38], x=[17, 7], edge_attr=[38, 4], y=[1])
Node features shape: torch.Size([17, 7])
Edge index shape: torch.Size([2, 38])
Label: tensor([1])


In [8]:
node_feature_dim = dataset.num_node_features
num_classes = dataset.num_classes

def collate_fn(batch):
    max_nodes = max([data.num_nodes for data in batch])
    batch_size = len(batch)
    nodes = torch.zeros((batch_size, max_nodes, node_feature_dim))
    adj_mat = torch.zeros((batch_size, max_nodes, max_nodes))
    mask = torch.zeros((batch_size, max_nodes), dtype=torch.bool)
    labels = torch.zeros(batch_size, dtype=torch.long)

    for i, data in enumerate(batch):
        n = data.num_nodes
        nodes[i, :n, :] = data.x
        mask[i, :n] = 1

        ei = data.edge_index
        adj_mat[i, :n, :n][ei[0], ei[1]] = 1

        labels[i] = data.y.item()

    return nodes, adj_mat, mask, labels

# Create a custom DataLoader with the collate function
train_idx, test_idx = train_test_split(np.arange(len(dataset)), test_size=0.2, random_state=42)
train_dataset = dataset[train_idx]
test_dataset = dataset[test_idx]
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, collate_fn=collate_fn)

In [9]:
def adj_to_edge_index(adj_mat):
    """Convert adjacency matrix to edge_index format"""
    edge_index = torch.nonzero(adj_mat, as_tuple=False).t().contiguous()
    return edge_index.long()  # Ensure it's long tensor

In [10]:
# More efficient training loop
model = SwinGraphTransformer(
    in_dim=node_feature_dim,
    dims=(64, 128, 256),
    heads=(4, 8, 8),
    cluster_size=8,
    pool_ratio=0.7  # Keep more nodes to avoid over-pooling
)

# Move model to the appropriate device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Projectors, Loss and Optimizer
classifier = torch.nn.Linear(256, dataset.num_classes).to(device)
input_projector = torch.nn.Linear(dataset.num_features, node_feature_dim).to(device)
criterion = torch.nn.CrossEntropyLoss()

# Use separate optimizers with different learning rates
model_optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
classifier_optimizer = torch.optim.Adam(classifier.parameters(), lr=0.01)

epochs = 50
for epoch in range(epochs):
    model.train()
    classifier.train()
    total_loss = 0.0
    num_batches = 0

    for batch_idx, (nodes, adj_mat, mask, labels) in enumerate(train_loader):
        nodes = nodes.to(device)
        adj_mat = adj_mat.to(device)
        mask = mask.to(device)
        labels = labels.to(device)

        batch_size, num_nodes, node_dim = nodes.shape
        
        # Process each graph in the batch
        batch_outputs = []
        
        for b in range(batch_size):
            # Get single graph data
            single_nodes = nodes[b]  # [num_nodes, node_dim]
            single_adj = adj_mat[b]  # [num_nodes, num_nodes]
            single_mask = mask[b]    # [num_nodes]
            
            # Convert adjacency matrix to edge_index
            edge_index = adj_to_edge_index(single_adj)
            
            # Apply input projection
            single_nodes = input_projector(single_nodes)

            # Forward pass through model
            output, perm = model(single_nodes, edge_index)
            
            # Apply mask only to the nodes that remain after pooling
            remaining_mask = single_mask[perm]  # Get mask for remaining nodes
            
            # Pool with proper masking
            if remaining_mask.sum() > 0:
                masked_output = output * remaining_mask.unsqueeze(-1)
                graph_emb = masked_output.sum(dim=0) / remaining_mask.sum()
            else:
                graph_emb = output.mean(dim=0)
            
            batch_outputs.append(graph_emb)
        
        batch_embeddings = torch.stack(batch_outputs)  # [batch_size, final_dim]
        
        # Classification
        model_optimizer.zero_grad()
        classifier_optimizer.zero_grad()
        
        logits = classifier(batch_embeddings)
        loss = criterion(logits, labels)
        loss.backward()
        
        model_optimizer.step()
        classifier_optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
    print(f"Epoch {epoch+1:02d} | Train Loss: {avg_loss:.4f}")

Epoch 01 | Train Loss: 1.1538
Epoch 02 | Train Loss: 0.6204
Epoch 03 | Train Loss: 0.5089
Epoch 04 | Train Loss: 0.7711
Epoch 05 | Train Loss: 0.6216
Epoch 06 | Train Loss: 0.6748
Epoch 07 | Train Loss: 0.4915
Epoch 08 | Train Loss: 0.5875
Epoch 09 | Train Loss: 0.4970
Epoch 10 | Train Loss: 0.6133
Epoch 11 | Train Loss: 0.5638
Epoch 12 | Train Loss: 0.4180
Epoch 13 | Train Loss: 0.4503
Epoch 14 | Train Loss: 0.6630
Epoch 15 | Train Loss: 0.5510
Epoch 16 | Train Loss: 0.4410
Epoch 17 | Train Loss: 0.8632
Epoch 18 | Train Loss: 0.5637
Epoch 19 | Train Loss: 0.6816
Epoch 20 | Train Loss: 0.6838
Epoch 21 | Train Loss: 0.6539
Epoch 22 | Train Loss: 0.6375
Epoch 23 | Train Loss: 0.7011
Epoch 24 | Train Loss: 0.6507
Epoch 25 | Train Loss: 0.6316
Epoch 26 | Train Loss: 0.6404
Epoch 27 | Train Loss: 0.6313
Epoch 28 | Train Loss: 0.5921
Epoch 29 | Train Loss: 0.5639
Epoch 30 | Train Loss: 0.6291
Epoch 31 | Train Loss: 0.6455
Epoch 32 | Train Loss: 0.6235
Epoch 33 | Train Loss: 0.6177
Epoch 34 |

In [11]:
# --- Evaluation ---
model.eval()
classifier.eval()
y_true, y_pred = [], []

with torch.no_grad():
    for nodes, adj_mat, mask, labels in test_loader:
        nodes, adj_mat, mask, labels = (
            nodes.to(device),
            adj_mat.to(device),
            mask.to(device),
            labels.to(device),
        )

        batch_size, num_nodes, node_dim = nodes.shape
        batch_outputs = []

        for b in range(batch_size):
            # Single graph data
            single_nodes = nodes[b]      # [num_nodes, node_dim]
            single_adj = adj_mat[b]      # [num_nodes, num_nodes]
            single_mask = mask[b]        # [num_nodes]

            # Convert adjacency matrix to edge_index
            edge_index = adj_to_edge_index(single_adj)

            # Input projection
            single_nodes = input_projector(single_nodes)

            # Forward through model
            output, perm = model(single_nodes, edge_index)

            # Apply mask to remaining nodes
            remaining_mask = single_mask[perm]
            if remaining_mask.sum() > 0:
                masked_output = output * remaining_mask.unsqueeze(-1)
                graph_emb = masked_output.sum(dim=0) / remaining_mask.sum()
            else:
                graph_emb = output.mean(dim=0)

            batch_outputs.append(graph_emb)

        batch_embeddings = torch.stack(batch_outputs)  # [batch_size, final_dim]

        # Classification
        logits = classifier(batch_embeddings)
        preds = torch.argmax(logits, dim=1)

        y_true.append(labels.cpu())
        y_pred.append(preds.cpu())

# Accuracy
acc = (torch.cat(y_true) == torch.cat(y_pred)).float().mean().item()
print(f"Test Accuracy: {acc:.4f}")


Test Accuracy: 0.6842
