# KAN-GPSConv: Integrating Kolmogorov-Arnold Networks with Graph Positional Signatures

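This notebook implements a graph neural network that integrates a Kolmogorov-Arnold Network (KAN)-inspired layer into a Graph Positional Signatures (GPS) network and trains it on both node classification and graph classification tasks. Concretely, the model stacks two GCN convolutions with the KAN layer placed between them.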

## Setup and Imports

First, we'll install the necessary packages and import the required libraries.

In [None]:
# Install required packages
!pip install torch torch_geometric wandb
!pip install ogb  # For OGB datasets if needed

# Import necessary libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.datasets import Planetoid, TUDataset
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.loader import DataLoader
import wandb
import argparse
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

## Weights & Biases Integration

Before running experiments, make sure to log in to Weights & Biases:

In [None]:
# Authenticate with Weights & Biases. main() below calls wandb.init / wandb.log,
# so this cell must be run (and succeed) before starting an experiment.
wandb.login()

## Model Implementation

### KAN Layer

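The Kolmogorov-Arnold Network (KAN) layer is implemented as a custom PyTorch module. It is a simplified, KAN-inspired design with three steps:

1. The input is linearly projected to `num_neurons` hidden units.
2. A fixed sine activation is applied.
3. The result is linearly projected to the output size.

It does not learn spline-based univariate functions on each edge, as the original KAN formulation does.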

In [None]:
class KANLayer(nn.Module):
    """Simplified KAN-inspired layer: a sine-activated hidden projection.

    Computes ``sin(x @ W1.T + b) @ W2.T`` where ``W1`` has shape
    ``(num_neurons, in_features)``, ``b`` has shape ``(num_neurons,)`` and
    ``W2`` has shape ``(out_features, num_neurons)``. The nonlinearity is a
    fixed ``torch.sin``; only the linear projections are learned.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        num_neurons: Width of the hidden sine layer.
    """

    def __init__(self, in_features, out_features, num_neurons=10):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_neurons = num_neurons

        # torch.empty only allocates storage; reset_parameters() fills it.
        self.weights = nn.Parameter(torch.empty(num_neurons, in_features))
        self.biases = nn.Parameter(torch.empty(num_neurons))
        self.output_weights = nn.Parameter(torch.empty(out_features, num_neurons))

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform init for both weight matrices; zero the hidden bias."""
        nn.init.xavier_uniform_(self.weights)
        nn.init.xavier_uniform_(self.output_weights)
        nn.init.zeros_(self.biases)

    def forward(self, x):
        """Map a ``(..., in_features)`` tensor to ``(..., out_features)``."""
        # Inner projection plus bias, then the fixed sine activation.
        hidden = torch.sin(F.linear(x, self.weights, self.biases))
        # Outer projection has no bias term.
        return F.linear(hidden, self.output_weights)

    def extra_repr(self):
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"num_neurons={self.num_neurons}"
        )

### GPS Network

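The Graph Positional Signature (GPS) Network integrates the KAN layer with graph convolutions. The forward pass is GCNConv → ReLU → KAN layer → GCNConv → ReLU. For graph classification, the node embeddings are then mean-pooled per graph. A final linear layer followed by log-softmax produces class log-probabilities.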

In [None]:
class GPSNetwork(nn.Module):
    """Two-layer GCN with a KANLayer between the convolutions.

    Pipeline: GCNConv -> ReLU -> KANLayer -> GCNConv -> ReLU ->
    [global_mean_pool when task == 'graph'] -> Linear -> log_softmax.

    Args:
        num_node_features: Input node feature dimension.
        num_classes: Number of output classes.
        hidden_dim: Width of the hidden node embeddings.
        task: 'node' for per-node predictions, or 'graph' for one prediction
            per graph (node embeddings are mean-pooled using ``data.batch``).

    Raises:
        ValueError: if ``task`` is not 'node' or 'graph'.
    """

    def __init__(self, num_node_features, num_classes, hidden_dim=64, task='node'):
        super().__init__()
        # Fail fast instead of silently treating an unknown task as 'node'.
        if task not in ('node', 'graph'):
            raise ValueError("Task must be either 'node' or 'graph'")
        self.task = task

        self.conv1 = GCNConv(num_node_features, hidden_dim)
        self.kan = KANLayer(hidden_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        # The classifier is the same for both tasks; only the pooling differs.
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, data):
        """Return class log-probabilities.

        There is one row per node for ``task='node'``, or one row per pooled
        graph for ``task='graph'``.
        """
        x, edge_index = data.x, data.edge_index

        x = F.relu(self.conv1(x, edge_index))
        x = self.kan(x)
        x = F.relu(self.conv2(x, edge_index))

        if self.task == 'graph':
            # Pool node embeddings per graph according to data.batch.
            x = global_mean_pool(x, data.batch)

        return F.log_softmax(self.fc(x), dim=1)

## Data Loading and Preprocessing

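These functions handle dataset loading and preprocessing for both node and graph classification tasks. Note that `split_data` draws a fresh random 80/10/10 train/validation/test node split, replacing any masks that already exist on the graph.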

In [None]:
def load_dataset(name, root='/content/data', task='node'):
    """Load a PyG dataset with the ``NormalizeFeatures`` transform applied.

    Args:
        name: Dataset name. For ``task='node'`` it must be one of the
            Planetoid citation graphs (Cora, CiteSeer, PubMed), matched
            case-insensitively. For ``task='graph'`` it is passed
            straight to ``TUDataset``.
        root: Directory the datasets are stored under.
        task: 'node' or 'graph'.

    Returns:
        ``(data, num_classes)`` for node tasks, where ``data`` is the dataset's
        single graph; ``(dataset, num_classes)`` for graph tasks.

    Raises:
        ValueError: for an unknown ``task`` or an unsupported node dataset.
    """
    # Validate everything before constructing transforms or touching disk.
    if task == 'node':
        if name.lower() not in ('cora', 'citeseer', 'pubmed'):
            raise ValueError(f"Node classification dataset {name} not recognized")
        dataset = Planetoid(root=root, name=name, transform=NormalizeFeatures())
        return dataset[0], dataset.num_classes

    if task == 'graph':
        dataset = TUDataset(root=root, name=name, transform=NormalizeFeatures())
        return dataset, dataset.num_classes

    raise ValueError("Task must be either 'node' or 'graph'")

def split_data(data, val_ratio=0.1, test_ratio=0.1, seed=None):
    """Randomly assign every node to exactly one of train / val / test.

    Writes boolean ``train_mask``, ``val_mask`` and ``test_mask`` attributes
    (each of length ``data.num_nodes``) onto ``data``, replacing any existing
    masks, and returns the same object. The test and validation sizes are
    ``int(num_nodes * ratio)``; all remaining nodes go to training.

    Args:
        data: Graph object exposing ``num_nodes``; mutated in place.
        val_ratio: Fraction of nodes used for validation, in [0, 1].
        test_ratio: Fraction of nodes used for testing, in [0, 1].
        seed: Optional integer seed for a reproducible permutation. When
            ``None`` the global torch RNG is used.

    Raises:
        ValueError: if a ratio lies outside [0, 1], or if the validation and
            test splits together would need more nodes than exist.
    """
    if not (0.0 <= val_ratio <= 1.0 and 0.0 <= test_ratio <= 1.0):
        raise ValueError("val_ratio and test_ratio must lie in [0, 1]")

    num_nodes = data.num_nodes
    test_size = int(num_nodes * test_ratio)
    val_size = int(num_nodes * val_ratio)
    if test_size + val_size > num_nodes:
        raise ValueError("val_ratio + test_ratio must not exceed 1")

    generator = None if seed is None else torch.Generator().manual_seed(seed)
    indices = torch.randperm(num_nodes, generator=generator)

    def _mask(selected):
        # Boolean mask over all nodes with True at the selected indices.
        mask = torch.zeros(num_nodes, dtype=torch.bool)
        mask[selected] = True
        return mask

    data.test_mask = _mask(indices[:test_size])
    data.val_mask = _mask(indices[test_size:test_size + val_size])
    data.train_mask = _mask(indices[test_size + val_size:])

    return data

def prepare_graph_data(dataset, batch_size=32, split_ratios=(0.8, 0.1, 0.1), seed=None):
    """Shuffle a graph dataset and wrap train/val/test splits in DataLoaders.

    Args:
        dataset: Indexable graph dataset (e.g. the TUDataset returned by
            ``load_dataset``). It must support indexing with a LongTensor
            permutation.
        batch_size: Number of graphs per mini-batch.
        split_ratios: ``(train, val, test)`` fractions. The train and val sizes
            are floored, and the test split receives all remaining graphs, so
            the third entry is informational only.
        seed: Optional integer seed for the shuffle. When ``None`` the global
            torch RNG is used.

    Returns:
        ``(train_loader, val_loader, test_loader)``. Only the training loader
        reshuffles between epochs.

    Raises:
        ValueError: if the train and val fractions together exceed 1.
    """
    num_graphs = len(dataset)
    num_train = int(num_graphs * split_ratios[0])
    num_val = int(num_graphs * split_ratios[1])
    if num_train + num_val > num_graphs:
        raise ValueError("split_ratios[0] + split_ratios[1] must not exceed 1")

    # Shuffle before slicing. Contiguous slices of an ordered dataset (for
    # example, one stored sorted by label) would give class-skewed splits.
    generator = None if seed is None else torch.Generator().manual_seed(seed)
    dataset = dataset[torch.randperm(num_graphs, generator=generator)]

    train_dataset = dataset[:num_train]
    val_dataset = dataset[num_train:num_train + num_val]
    test_dataset = dataset[num_train + num_val:]

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    return train_loader, val_loader, test_loader

## Training and Evaluation Functions

These functions handle the training and evaluation processes for both node and graph classification tasks.

In [None]:
def train_node(model, data, optimizer):
    """Take one full-batch optimisation step on the training nodes.

    Returns:
        The NLL loss over ``data.train_mask``, computed before the parameter
        update, as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    train_mask = data.train_mask
    log_probs = model(data)
    nll = F.nll_loss(log_probs[train_mask], data.y[train_mask])

    nll.backward()
    optimizer.step()
    return nll.item()

def train_graph(model, loader, optimizer, device):
    """Run one training epoch over mini-batches of graphs.

    Returns:
        The mean per-graph loss for the epoch. Each batch's loss is weighted
        by its ``num_graphs``, and the sum is divided by ``len(loader.dataset)``.
    """
    model.train()
    weighted_sum = 0
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        batch_loss = F.nll_loss(model(batch), batch.y)
        batch_loss.backward()
        optimizer.step()
        # Undo the per-batch mean so that differently sized batches weigh correctly.
        weighted_sum += batch_loss.item() * batch.num_graphs
    num_graphs_total = len(loader.dataset)
    return weighted_sum / num_graphs_total

def test_node(model, data, mask):
    model.eval()
    with torch.no_grad():
        out = model(data)
        pred = out.argmax(dim=1)
        correct = pred[mask] == data.y[mask]
        acc = int(correct.sum()) / int(mask.sum())
    return acc

def test_graph(model, loader, device):
    model.eval()
    correct = 0
    for data in loader:
        data = data.to(device)
        with torch.no_grad():
            out = model(data)
            pred = out.argmax(dim=1)
            correct += int((pred == data.y).sum())
    return correct / len(loader.dataset)

## Main Training Loop

This function orchestrates the entire training process, including data loading, model initialization, training, and evaluation.

In [None]:
def _gpu_memory_stats(device):
    """Return CUDA memory counters for W&B logging, or an empty dict when not on a GPU."""
    if device.type != 'cuda':
        return {}
    return {
        "gpu_memory_allocated": torch.cuda.memory_allocated(device),
        # torch.cuda.memory_cached() is deprecated in favour of memory_reserved().
        # The W&B key name is kept so existing dashboards keep working.
        "gpu_memory_cached": torch.cuda.memory_reserved(device),
    }


def _log_epoch(epoch, loss, train_acc, val_acc, test_acc, device):
    """Send one epoch's metrics to W&B and print a summary every 10 epochs."""
    wandb.log({
        "epoch": epoch,
        "loss": loss,
        "train_acc": train_acc,
        "val_acc": val_acc,
        "test_acc": test_acc,
        **_gpu_memory_stats(device),
    })
    if epoch % 10 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}')


def _run_node_task(args, device, lr, weight_decay, num_epochs):
    """Train and evaluate GPSNetwork on a single-graph node classification dataset."""
    data, num_classes = load_dataset(args.dataset, task='node')
    data = split_data(data).to(device)

    wandb.config.update({
        "model_type": "KAN-GPS",
        "dataset": args.dataset,
        "task": args.task,
        "num_features": data.num_features,
        "num_classes": num_classes,
        "learning_rate": lr,
        "weight_decay": weight_decay,
        "num_epochs": num_epochs,
    })

    model = GPSNetwork(data.num_features, num_classes, task='node').to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    for epoch in range(num_epochs):
        loss = train_node(model, data, optimizer)
        train_acc = test_node(model, data, data.train_mask)
        val_acc = test_node(model, data, data.val_mask)
        test_acc = test_node(model, data, data.test_mask)
        _log_epoch(epoch, loss, train_acc, val_acc, test_acc, device)


def _run_graph_task(args, device, lr, weight_decay, num_epochs, batch_size):
    """Train and evaluate GPSNetwork on a multi-graph classification dataset."""
    dataset, num_classes = load_dataset(args.dataset, task='graph')
    train_loader, val_loader, test_loader = prepare_graph_data(dataset, batch_size=batch_size)

    wandb.config.update({
        "model_type": "KAN-GPS",
        "dataset": args.dataset,
        "task": args.task,
        "num_features": dataset.num_node_features,
        "num_classes": num_classes,
        "learning_rate": lr,
        "weight_decay": weight_decay,
        "num_epochs": num_epochs,
        "batch_size": train_loader.batch_size,
    })

    model = GPSNetwork(dataset.num_node_features, num_classes, task='graph').to(device)
    # Use the same weight decay that is recorded in the W&B config above.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    for epoch in range(num_epochs):
        loss = train_graph(model, train_loader, optimizer, device)
        train_acc = test_graph(model, train_loader, device)
        val_acc = test_graph(model, val_loader, device)
        test_acc = test_graph(model, test_loader, device)
        _log_epoch(epoch, loss, train_acc, val_acc, test_acc, device)


def main(args):
    """Train and evaluate GPSNetwork, logging every epoch to Weights & Biases.

    Args:
        args: Namespace with ``dataset`` and ``task`` ('node' or 'graph').
            Optional attributes override the training hyperparameters:
            ``lr`` (default 0.01), ``weight_decay`` (default 5e-4),
            ``epochs`` (default 200), and ``batch_size`` (default 32,
            graph task only).

    Raises:
        ValueError: if ``args.task`` is not 'node' or 'graph'. This is checked
            before any W&B run is started.
    """
    if args.task not in ('node', 'graph'):
        raise ValueError("Task must be either 'node' or 'graph'")

    # Hyperparameters: taken from args when present, otherwise the defaults.
    lr = getattr(args, 'lr', 0.01)
    weight_decay = getattr(args, 'weight_decay', 5e-4)
    num_epochs = getattr(args, 'epochs', 200)
    batch_size = getattr(args, 'batch_size', 32)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    wandb.init(project="KAN-GPSConv", config=args)
    try:
        if args.task == 'node':
            _run_node_task(args, device, lr, weight_decay, num_epochs)
        else:
            _run_graph_task(args, device, lr, weight_decay, num_epochs, batch_size)
    except BaseException:
        # Close the run as failed rather than leaving it open.
        wandb.finish(exit_code=1)
        raise
    wandb.finish()

## Run the Experiment

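To run an experiment, uncomment one of the example calls below and modify it as needed. Alternatively, run this file as a script with the `--dataset` and `--task` command-line arguments.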

In [None]:
# Example usage:
# main(argparse.Namespace(dataset='Cora', task='node'))
# main(argparse.Namespace(dataset='MUTAG', task='graph'))

# Command-line entry point. parse_known_args() is used instead of parse_args()
# so this cell also runs inside a notebook kernel. There, sys.argv carries
# kernel-launch arguments (typically "-f <connection file>") that this parser
# does not define and would otherwise reject.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train GPS Network on node or graph classification datasets')
    parser.add_argument('--dataset', type=str, default='Cora', help='Dataset name')
    parser.add_argument('--task', type=str, default='node', choices=['node', 'graph'], help='Task type: node or graph classification')
    parser.add_argument('--lr', type=float, default=0.01, help='Adam learning rate')
    parser.add_argument('--weight_decay', type=float, default=5e-4, help='Adam weight decay')
    parser.add_argument('--epochs', type=int, default=200, help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=32, help='Graphs per mini-batch (graph task only)')
    args, _unknown = parser.parse_known_args()

    main(args)