<a href="https://colab.research.google.com/github/AbhiJeet70/PowerfulGNNs/blob/main/SUN_NEW.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@inproceedings{frasca2022understanding,
#title={Understanding and Extending Subgraph GNNs by Rethinking Their Symmetries},
#author={Frasca, Fabrizio and Bevilacqua, Beatrice and Bronstein, Michael M and Maron, Haggai},
#booktitle={Advances in Neural Information Processing Systems},
#year={2022},
#}

# Install necessary packages
!pip install torch torch-geometric

import torch
import torch.nn as nn
from torch_geometric.data import DataLoader
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import to_dense_adj, k_hop_subgraph
from torch_geometric.nn import MessagePassing
import torch.nn.functional as F

class SUNLayer(nn.Module):
    """Simplified SUN-style layer mixing root, neighbourhood and global terms.

    For every node the output is the sum of three linear maps:

    * ``root_mlp`` applied to the node's own features (only for nodes
      selected by ``subgraph_masks``; other nodes get zeros for this term),
    * ``non_root_mlp`` applied to the sum of the features of the nodes it
      has an edge to, and
    * ``global_mlp`` applied to the mean feature vector over all nodes.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Separate transformations for root nodes, neighbourhood aggregates
        # and the graph-wide summary.
        self.root_mlp = nn.Linear(in_channels, out_channels)
        self.non_root_mlp = nn.Linear(in_channels, out_channels)
        self.global_mlp = nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index, subgraph_masks):
        """Compute the updated node features.

        Args:
            x: Node feature matrix of shape ``(num_nodes, in_channels)``.
            edge_index: Integer tensor of shape ``(2, num_edges)``; column
                ``e`` describes an edge ``edge_index[0, e] -> edge_index[1, e]``.
            subgraph_masks: Boolean tensor of shape ``(num_nodes,)`` marking
                the nodes that receive the root transformation.

        Returns:
            Tensor of shape ``(num_nodes, out_channels)``.
        """
        num_nodes = x.size(0)

        # Local message passing: row i accumulates x[j] for every edge (i, j).
        # This equals ``A @ x`` for the (multi-)adjacency matrix A, but never
        # materialises the num_nodes x num_nodes dense matrix, and it follows
        # the dtype/device of ``x``.
        row, col = edge_index[0], edge_index[1]
        local_features = torch.zeros_like(x).index_add_(
            0, row, x.index_select(0, col)
        )

        # Global aggregation: one summary vector shared by all nodes.
        global_features = self.global_mlp(torch.mean(x, dim=0, keepdim=True))
        global_features = global_features.expand(num_nodes, global_features.size(1))

        # Root term: transform only the selected nodes and scatter the result
        # into a zero tensor that matches the transform's dtype and device.
        mask = subgraph_masks.to(device=x.device)
        root_out = self.root_mlp(x[mask])
        root_features = root_out.new_zeros((num_nodes, root_out.size(1)))
        root_features[mask] = root_out

        # Neighbourhood term.
        non_root_features = self.non_root_mlp(local_features)

        # Combine root, non-root, and global updates.
        return root_features + non_root_features + global_features


# Define the SUN model
class SUN(nn.Module):
    """Two-layer network built from ``SUNLayer`` blocks for node classification.

    Args:
        num_features: Size of each input node feature vector.
        num_classes: Number of output scores produced per node.
        hidden_channels: Width of the hidden representation between layers.
    """

    def __init__(self, num_features, num_classes, hidden_channels):
        super().__init__()
        self.layer1 = SUNLayer(num_features, hidden_channels)
        self.layer2 = SUNLayer(hidden_channels, num_classes)

    def forward(self, data):
        """Return per-node class scores of shape ``(num_nodes, num_classes)``.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``.
        """
        x, edge_index = data.x, data.edge_index

        # Node-based policy where every node is treated as a root, so the
        # root mask is all True. Build it directly on the feature device:
        # extracting a 1-hop subgraph per node adds nothing to this mask
        # (k_hop_subgraph's fourth return value is an *edge* mask, not a node
        # mask) and costs one full pass over the edges per node per forward.
        subgraph_masks = torch.ones(x.size(0), dtype=torch.bool, device=x.device)

        # Pass through SUN layers with a ReLU in between.
        x = self.layer1(x, edge_index, subgraph_masks)
        x = F.relu(x)
        x = self.layer2(x, edge_index, subgraph_masks)
        return x

# Training and Evaluation
if __name__ == "__main__":
    # Planetoid citation benchmarks, processed in this order; the final
    # test accuracy of each run is collected in `results`.
    datasets = ['Cora', 'Pubmed', 'CiteSeer']
    results = {}

    for dataset_name in datasets:
        dataset = Planetoid(root=f'./data/{dataset_name}', name=dataset_name)
        data = dataset[0]  # each Planetoid dataset holds a single graph

        # Fresh model, optimizer and loss for every dataset
        model = SUN(dataset.num_features, dataset.num_classes, hidden_channels=32)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
        criterion = nn.CrossEntropyLoss()

        def train():
            """Run 200 full-graph optimisation steps, printing the loss each time."""
            for epoch in range(200):
                model.train()
                optimizer.zero_grad()
                logits = model(data)
                labelled = data.train_mask
                # Loss is computed on the training nodes only
                loss = criterion(logits[labelled], data.y[labelled])
                loss.backward()
                optimizer.step()
                print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

        def test():
            """Return the fraction of test nodes classified correctly as a float."""
            model.eval()
            held_out = data.test_mask
            with torch.no_grad():
                predicted = model(data).argmax(dim=1)
                correct = (predicted[held_out] == data.y[held_out]).sum()
                return (correct / held_out.sum()).item()

        train()
        accuracy = test()
        results[dataset_name] = accuracy

    # Summary of all runs
    for dataset_name, accuracy in results.items():
        print(f"Dataset: {dataset_name}, Test Accuracy: {accuracy:.4f}")


Epoch 1, Loss: 2.3401
Epoch 2, Loss: 1.5984
Epoch 3, Loss: 1.3868
Epoch 4, Loss: 0.8538
Epoch 5, Loss: 0.7661
Epoch 6, Loss: 0.5640
Epoch 7, Loss: 0.4513
Epoch 8, Loss: 0.3978
Epoch 9, Loss: 0.2903
Epoch 10, Loss: 0.1773
Epoch 11, Loss: 0.1312
Epoch 12, Loss: 0.1427
Epoch 13, Loss: 0.0696
Epoch 14, Loss: 0.0513
Epoch 15, Loss: 0.0389
Epoch 16, Loss: 0.0295
Epoch 17, Loss: 0.0224
Epoch 18, Loss: 0.0181
Epoch 19, Loss: 0.0145
Epoch 20, Loss: 0.0093
Epoch 21, Loss: 0.0062
Epoch 22, Loss: 0.0045
Epoch 23, Loss: 0.0034
Epoch 24, Loss: 0.0027
Epoch 25, Loss: 0.0021
Epoch 26, Loss: 0.0018
Epoch 27, Loss: 0.0015
Epoch 28, Loss: 0.0013
Epoch 29, Loss: 0.0012
Epoch 30, Loss: 0.0010
Epoch 31, Loss: 0.0009
Epoch 32, Loss: 0.0008
Epoch 33, Loss: 0.0007
Epoch 34, Loss: 0.0006
Epoch 35, Loss: 0.0005
Epoch 36, Loss: 0.0005
Epoch 37, Loss: 0.0004
Epoch 38, Loss: 0.0004
Epoch 39, Loss: 0.0003
Epoch 40, Loss: 0.0003
Epoch 41, Loss: 0.0003
Epoch 42, Loss: 0.0002
Epoch 43, Loss: 0.0002
Epoch 44, Loss: 0.00

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.test.index
Processing...
Done!


Epoch 1, Loss: 1.1357
Epoch 2, Loss: 0.9954
Epoch 3, Loss: 0.8519
Epoch 4, Loss: 0.7326
Epoch 5, Loss: 0.6005
Epoch 6, Loss: 0.5465
Epoch 7, Loss: 0.4643
Epoch 8, Loss: 0.3931
Epoch 9, Loss: 0.3455
Epoch 10, Loss: 0.2991
Epoch 11, Loss: 0.2557
Epoch 12, Loss: 0.2160
Epoch 13, Loss: 0.1810
Epoch 14, Loss: 0.1509
Epoch 15, Loss: 0.1247
Epoch 16, Loss: 0.1011
Epoch 17, Loss: 0.0801
Epoch 18, Loss: 0.0628
Epoch 19, Loss: 0.0505
Epoch 20, Loss: 0.0397
Epoch 21, Loss: 0.0322
Epoch 22, Loss: 0.0262
Epoch 23, Loss: 0.0211
Epoch 24, Loss: 0.0168
Epoch 25, Loss: 0.0134
Epoch 26, Loss: 0.0108
Epoch 27, Loss: 0.0087
Epoch 28, Loss: 0.0072
Epoch 29, Loss: 0.0060
Epoch 30, Loss: 0.0051
Epoch 31, Loss: 0.0043
Epoch 32, Loss: 0.0037
Epoch 33, Loss: 0.0032
Epoch 34, Loss: 0.0028
Epoch 35, Loss: 0.0024
Epoch 36, Loss: 0.0021
Epoch 37, Loss: 0.0019
Epoch 38, Loss: 0.0017
Epoch 39, Loss: 0.0015
Epoch 40, Loss: 0.0014
Epoch 41, Loss: 0.0012
Epoch 42, Loss: 0.0011
Epoch 43, Loss: 0.0010
Epoch 44, Loss: 0.00

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.test.index
Processing...
Done!


Epoch 1, Loss: 1.9279
Epoch 2, Loss: 1.8588
Epoch 3, Loss: 0.9494
Epoch 4, Loss: 0.5665
Epoch 5, Loss: 0.4185
Epoch 6, Loss: 0.2994
Epoch 7, Loss: 0.1989
Epoch 8, Loss: 0.1201
Epoch 9, Loss: 0.0656
Epoch 10, Loss: 0.0364
Epoch 11, Loss: 0.0222
Epoch 12, Loss: 0.0163
Epoch 13, Loss: 0.0118
Epoch 14, Loss: 0.0060
Epoch 15, Loss: 0.0033
Epoch 16, Loss: 0.0019
Epoch 17, Loss: 0.0012
Epoch 18, Loss: 0.0008
Epoch 19, Loss: 0.0005
Epoch 20, Loss: 0.0004
Epoch 21, Loss: 0.0003
Epoch 22, Loss: 0.0002
Epoch 23, Loss: 0.0001
Epoch 24, Loss: 0.0001
Epoch 25, Loss: 0.0001
Epoch 26, Loss: 0.0001
Epoch 27, Loss: 0.0001
Epoch 28, Loss: 0.0001
Epoch 29, Loss: 0.0000
Epoch 30, Loss: 0.0000
Epoch 31, Loss: 0.0000
Epoch 32, Loss: 0.0000
Epoch 33, Loss: 0.0000
Epoch 34, Loss: 0.0000
Epoch 35, Loss: 0.0000
Epoch 36, Loss: 0.0000
Epoch 37, Loss: 0.0000
Epoch 38, Loss: 0.0000
Epoch 39, Loss: 0.0000
Epoch 40, Loss: 0.0000
Epoch 41, Loss: 0.0000
Epoch 42, Loss: 0.0000
Epoch 43, Loss: 0.0000
Epoch 44, Loss: 0.00