In [6]:
import numpy as np
from sklearn.cluster import KMeans
from scipy.optimize import linear_sum_assignment
from torch_geometric.datasets import TUDataset
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from swin_graph_transformer import SwinGraphTransformer
import torch
import warnings
warnings.filterwarnings("ignore")

# Train

In [7]:
dataset = TUDataset(root='data/TUDataset', name='MUTAG')

print(f"Number of graphs: {len(dataset)}")
print(f"Number of classes: {dataset.num_classes}")
print(f"Number of features: {dataset.num_features}")

# Example graph
graph = dataset[0]
print(graph)
print(f"Node features shape: {graph.x.shape}")
print(f"Edge index shape: {graph.edge_index.shape}")
print(f"Label: {graph.y}")

Number of graphs: 188
Number of classes: 2
Number of features: 7
Data(edge_index=[2, 38], x=[17, 7], edge_attr=[38, 4], y=[1])
Node features shape: torch.Size([17, 7])
Edge index shape: torch.Size([2, 38])
Label: tensor([1])


In [8]:
node_feature_dim = dataset.num_node_features
num_classes = dataset.num_classes

def collate_fn(batch):
    max_nodes = max([data.num_nodes for data in batch])
    batch_size = len(batch)
    nodes = torch.zeros((batch_size, max_nodes, node_feature_dim))
    adj_mat = torch.zeros((batch_size, max_nodes, max_nodes))
    mask = torch.zeros((batch_size, max_nodes), dtype=torch.bool)
    labels = torch.zeros(batch_size, dtype=torch.long)

    for i, data in enumerate(batch):
        n = data.num_nodes
        nodes[i, :n, :] = data.x
        mask[i, :n] = 1

        ei = data.edge_index
        adj_mat[i, :n, :n][ei[0], ei[1]] = 1

        labels[i] = data.y.item()

    return nodes, adj_mat, mask, labels

# Create a custom DataLoader with the collate function
train_idx, test_idx = train_test_split(np.arange(len(dataset)), test_size=0.2, random_state=42)
train_dataset = dataset[train_idx]
test_dataset = dataset[test_idx]
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, collate_fn=collate_fn)

In [9]:
def adj_to_edge_index(adj_mat):
    """Convert adjacency matrix to edge_index format"""
    edge_index = torch.nonzero(adj_mat, as_tuple=False).t().contiguous()
    return edge_index.long()  # Ensure it's long tensor

In [10]:
# More efficient training loop
model = SwinGraphTransformer(
    in_dim=node_feature_dim,
    dims=(64, 128, 256),
    heads=(4, 8, 8),
    Ks=(4, 2, 2),  # Reduced cluster numbers for faster computation
    pool_ratio=0.7  # Keep more nodes to avoid over-pooling
)

# Move model to the appropriate device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Projectors, Loss and Optimizer
classifier = torch.nn.Linear(256, dataset.num_classes).to(device)
input_projector = torch.nn.Linear(dataset.num_features, node_feature_dim).to(device)
criterion = torch.nn.CrossEntropyLoss()

# Use separate optimizers with different learning rates
model_optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
classifier_optimizer = torch.optim.Adam(classifier.parameters(), lr=0.01)

epochs = 50
for epoch in range(epochs):
    model.train()
    classifier.train()
    total_loss = 0.0
    num_batches = 0

    for batch_idx, (nodes, adj_mat, mask, labels) in enumerate(train_loader):
        nodes = nodes.to(device)
        adj_mat = adj_mat.to(device)
        mask = mask.to(device)
        labels = labels.to(device)

        batch_size, num_nodes, node_dim = nodes.shape
        
        # Process each graph in the batch
        batch_outputs = []
        
        for b in range(batch_size):
            # Get single graph data
            single_nodes = nodes[b]  # [num_nodes, node_dim]
            single_adj = adj_mat[b]  # [num_nodes, num_nodes]
            single_mask = mask[b]    # [num_nodes]
            
            # Skip graphs with too few nodes
            if single_mask.sum() < 4:
                # Use simple mean pooling for very small graphs
                valid_nodes = single_nodes[single_mask.bool()]
                if len(valid_nodes) > 0:
                    graph_emb = input_projector(valid_nodes).mean(dim=0)
                    # Pad to final dimension
                    if graph_emb.size(0) != 256:
                        graph_emb = torch.nn.functional.linear(graph_emb, torch.eye(256, graph_emb.size(0), device=device))
                else:
                    graph_emb = torch.zeros(256, device=device)
                batch_outputs.append(graph_emb)
                continue
            
            # Convert adjacency matrix to edge_index
            edge_index = adj_to_edge_index(single_adj)
            
            # Apply input projection
            single_nodes = input_projector(single_nodes)
            
            try:
                # Forward pass through model
                output, perm = model(single_nodes, edge_index)
                
                # Apply mask only to the nodes that remain after pooling
                remaining_mask = single_mask[perm]  # Get mask for remaining nodes
                
                # Pool with proper masking
                if remaining_mask.sum() > 0:
                    masked_output = output * remaining_mask.unsqueeze(-1)
                    graph_emb = masked_output.sum(dim=0) / remaining_mask.sum()
                else:
                    # Fallback if no valid nodes remain
                    graph_emb = output.mean(dim=0)
                
            except Exception as e:
                # Fallback for problematic graphs
                print(f"Warning: Graph {b} in batch {batch_idx} failed: {e}")
                valid_nodes = single_nodes[single_mask.bool()]
                graph_emb = valid_nodes.mean(dim=0) if len(valid_nodes) > 0 else torch.zeros(256, device=device)
            
            batch_outputs.append(graph_emb)
        
        # Stack batch outputs
        batch_embeddings = torch.stack(batch_outputs)  # [batch_size, final_dim]
        
        # Classification
        model_optimizer.zero_grad()
        classifier_optimizer.zero_grad()
        
        logits = classifier(batch_embeddings)
        loss = criterion(logits, labels)
        loss.backward()
        
        model_optimizer.step()
        classifier_optimizer.step()

        total_loss += loss.item()
        num_batches += 1
        
        # Print progress every 10 batches
        if batch_idx % 10 == 0:
            print(f"Epoch {epoch+1:02d}, Batch {batch_idx:03d} | Loss: {loss.item():.4f}")

    avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
    print(f"Epoch {epoch+1:02d} | Train Loss: {avg_loss:.4f}")

Epoch 01, Batch 000 | Loss: 0.6738
Epoch 01 | Train Loss: 0.8470
Epoch 02, Batch 000 | Loss: 0.7885
Epoch 02 | Train Loss: 0.6504
Epoch 03, Batch 000 | Loss: 0.7976
Epoch 03 | Train Loss: 0.6516
Epoch 04, Batch 000 | Loss: 0.6005
Epoch 04 | Train Loss: 0.6381
Epoch 05, Batch 000 | Loss: 0.5859
Epoch 05 | Train Loss: 0.6418
Epoch 06, Batch 000 | Loss: 0.3953
Epoch 06 | Train Loss: 0.5975
Epoch 07, Batch 000 | Loss: 0.6053
Epoch 07 | Train Loss: 0.5616
Epoch 08, Batch 000 | Loss: 0.9151
Epoch 08 | Train Loss: 0.6713
Epoch 09, Batch 000 | Loss: 0.6891
Epoch 09 | Train Loss: 0.6494
Epoch 10, Batch 000 | Loss: 0.5698
Epoch 10 | Train Loss: 0.6834
Epoch 11, Batch 000 | Loss: 0.5801
Epoch 11 | Train Loss: 0.6416
Epoch 12, Batch 000 | Loss: 0.5892
Epoch 12 | Train Loss: 0.6499
Epoch 13, Batch 000 | Loss: 0.6821
Epoch 13 | Train Loss: 0.6204
Epoch 14, Batch 000 | Loss: 0.6265
Epoch 14 | Train Loss: 0.6992
Epoch 15, Batch 000 | Loss: 0.5698
Epoch 15 | Train Loss: 0.6612
Epoch 16, Batch 000 | Los

In [11]:
model.eval()
classifier.eval()
y_true, y_pred = [], []

with torch.no_grad():
    for batch_idx, (nodes, adj_mat, mask, labels) in enumerate(test_loader):
        nodes = nodes.to(device)
        adj_mat = adj_mat.to(device)
        mask = mask.to(device)
        labels = labels.to(device)

        batch_size, num_nodes, node_dim = nodes.shape
        batch_outputs = []

        for b in range(batch_size):
            single_nodes = nodes[b]  # [num_nodes, node_dim]
            single_adj = adj_mat[b]  # [num_nodes, num_nodes]
            single_mask = mask[b]    # [num_nodes]

            # Handle very small graphs
            if single_mask.sum() < 4:
                valid_nodes = single_nodes[single_mask.bool()]
                if len(valid_nodes) > 0:
                    graph_emb = input_projector(valid_nodes).mean(dim=0)
                    if graph_emb.size(0) != 256:
                        graph_emb = torch.nn.functional.linear(
                            graph_emb, torch.eye(256, graph_emb.size(0), device=device)
                        )
                else:
                    graph_emb = torch.zeros(256, device=device)
                batch_outputs.append(graph_emb)
                continue

            # Convert adjacency matrix to edge_index
            edge_index = adj_to_edge_index(single_adj)

            # Apply input projection
            single_nodes_proj = input_projector(single_nodes)

            try:
                # Forward pass through model
                output, perm = model(single_nodes_proj, edge_index)

                # Mask remaining nodes after pooling
                remaining_mask = single_mask[perm]
                if remaining_mask.sum() > 0:
                    masked_output = output * remaining_mask.unsqueeze(-1)
                    graph_emb = masked_output.sum(dim=0) / remaining_mask.sum()
                else:
                    graph_emb = output.mean(dim=0)

            except Exception as e:
                print(f"Warning: Graph {b} in batch {batch_idx} failed: {e}")
                valid_nodes = single_nodes_proj[single_mask.bool()]
                graph_emb = valid_nodes.mean(dim=0) if len(valid_nodes) > 0 else torch.zeros(256, device=device)

            batch_outputs.append(graph_emb)

        batch_embeddings = torch.stack(batch_outputs)  # [batch_size, final_dim]

        # Classification
        logits = classifier(batch_embeddings)
        preds = torch.argmax(logits, dim=1)

        y_true.append(labels.cpu())
        y_pred.append(preds.cpu())

y_true = torch.cat(y_true).numpy()
y_pred = torch.cat(y_pred).numpy()
acc = (y_true == y_pred).mean()
print(f"Test Accuracy: {acc:.4f}")

Test Accuracy: 0.6842
