<a href="https://colab.research.google.com/github/AbhiJeet70/PowerfulGNNs/blob/main/SubStructureGNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch torch-geometric networkx



In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import k_hop_subgraph
from torch_geometric.data import Data
from torch_geometric.datasets import Planetoid


class SubstructureAwareGNN(nn.Module):
    """Node classifier combining ego-network, cut-subgraph and raw-feature views.

    For every node three representations are built and concatenated:

    * ego view: mean of the input features over the node's k-hop subgraph,
      passed through ``ego_gnn`` on the full graph;
    * cut view: mean of the input features over the node's neighbours after
      a fraction of the edges has been dropped, passed through ``cut_gnn``
      on the full graph;
    * global view: a linear projection of the node's own input features.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Output size of each of the three encoders.
        out_channels: Number of classes scored by the final linear layer.
        num_hops: Hop count used for the ego view (default 2).
        cut_ratio: Fraction of edges dropped for the cut view (default 0.5).
    """

    def __init__(self, in_channels, hidden_channels, out_channels,
                 num_hops=2, cut_ratio=0.5):
        super().__init__()
        self.num_hops = num_hops
        self.cut_ratio = cut_ratio
        self.ego_gnn = MessagePassingLayer(in_channels, hidden_channels)
        self.cut_gnn = MessagePassingLayer(in_channels, hidden_channels)
        self.global_encoder = nn.Linear(in_channels, hidden_channels)
        self.final_fc = nn.Linear(3 * hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Return per-node log-probabilities over the output classes.

        Args:
            x: Node feature matrix, one row per node.
            edge_index: Edge list with sources in row 0 and targets in row 1.

        Returns:
            Tensor with one row per node holding ``log_softmax`` class scores.
        """
        # Extract subgraph-pooled input features.
        ego_features = self.extract_ego_subgraph(x, edge_index, k=self.num_hops)
        cut_features = self.extract_cut_subgraph(
            x, edge_index, cut_ratio=self.cut_ratio
        )

        # Both GNN layers propagate over the full, uncut edge set.
        ego_encoded = self.ego_gnn(ego_features, edge_index)
        cut_encoded = self.cut_gnn(cut_features, edge_index)
        global_encoded = self.global_encoder(x)

        # Concatenate the three views and classify.
        combined_features = torch.cat(
            [ego_encoded, cut_encoded, global_encoded], dim=-1
        )
        output = self.final_fc(combined_features)
        return F.log_softmax(output, dim=1)

    def extract_ego_subgraph(self, x, edge_index, k=2):
        """Average the input features over each node's k-hop subgraph.

        Args:
            x: Node feature matrix, one row per node.
            edge_index: Edge list with sources in row 0 and targets in row 1.
            k: Number of hops around each node (default 2).

        Returns:
            Tensor shaped like ``x`` whose row ``i`` is the mean of ``x`` over
            the node subset returned by ``k_hop_subgraph`` for node ``i``.
        """
        num_nodes = x.size(0)
        ego_features = torch.zeros_like(x)

        # NOTE: the result depends only on (x, edge_index), not on any model
        # parameter, yet it is recomputed with one k_hop_subgraph call per
        # node on every forward pass.
        for node_idx in range(num_nodes):
            # num_nodes is passed explicitly: otherwise it is inferred from
            # the largest index in edge_index, which is too small when the
            # highest-numbered nodes are isolated.
            subset, _, _, _ = k_hop_subgraph(
                node_idx, k, edge_index, relabel_nodes=False,
                num_nodes=num_nodes,
            )
            # The subset always contains node_idx itself, so it is never
            # empty and an isolated node simply keeps its own features.
            ego_features[node_idx] = x[subset].mean(dim=0)

        return ego_features

    def extract_cut_subgraph(self, x, edge_index, cut_ratio=0.5):
        """Average neighbour features after dropping a fraction of the edges.

        Edges are ranked by a score and the top ``cut_ratio`` fraction is
        removed. The score is currently a uniform random draw standing in for
        a real edge-importance measure, so a different edge set is removed on
        every call (including in eval mode) unless the RNG is seeded.

        Args:
            x: Node feature matrix, one row per node.
            edge_index: Edge list with sources in row 0 and targets in row 1.
            cut_ratio: Fraction of edges to remove, in ``[0, 1]``
                (default 0.5, i.e. ``num_edges // 2`` edges).

        Returns:
            Tensor shaped like ``x``. Row ``i`` is the mean of ``x`` over the
            targets of node ``i``'s remaining outgoing edges, or ``x[i]``
            itself when no outgoing edge remains.

        Raises:
            ValueError: If ``cut_ratio`` is outside ``[0, 1]``.
        """
        if not 0.0 <= cut_ratio <= 1.0:
            raise ValueError(f"cut_ratio must be in [0, 1], got {cut_ratio}")

        num_edges = edge_index.size(1)
        # Placeholder scores; replace with actual edge weights.
        edge_weights = torch.rand(num_edges, device=edge_index.device)
        num_edges_to_remove = int(num_edges * cut_ratio)

        # Sort edges by score and mask out the highest-scoring ones.
        _, sorted_indices = edge_weights.sort(descending=True)
        mask = torch.ones(num_edges, dtype=torch.bool, device=edge_index.device)
        mask[sorted_indices[:num_edges_to_remove]] = False
        src, dst = edge_index[:, mask]

        # Scatter-mean over the remaining edges: sum each source's target
        # features in one pass instead of scanning all edges once per node.
        neighbor_sum = torch.zeros_like(x).index_add_(0, src, x[dst])
        out_degree = torch.bincount(src, minlength=x.size(0))
        has_neighbors = (out_degree > 0).unsqueeze(-1)
        # clamp avoids 0/0 for nodes without edges; those rows are replaced
        # by the node's own features below.
        neighbor_mean = neighbor_sum / out_degree.clamp(min=1).unsqueeze(-1)
        return torch.where(has_neighbors, neighbor_mean, x)


class MessagePassingLayer(MessagePassing):
    """Single graph layer: linear projection, sum aggregation, then ReLU.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Size of each output node feature vector.
    """

    def __init__(self, in_channels, out_channels):
        # "add" makes propagate() sum the incoming messages per node.
        super().__init__(aggr="add")
        self.linear = nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        """Project the node features, then propagate them along the edges."""
        projected = self.linear(x)
        return self.propagate(edge_index, x=projected)

    def message(self, x_j):
        """Send the (already projected) neighbour features unchanged."""
        return x_j

    def update(self, aggr_out):
        """Apply the ReLU non-linearity to the aggregated messages."""
        return aggr_out.relu()


# Training and Evaluation
if __name__ == "__main__":
    # Train and evaluate one fresh model per Planetoid dataset.

    # Check for GPU availability
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    max_epochs = 500
    log_every = 10
    patience = 10  # epochs without a new best training loss before stopping

    # Load Planetoid datasets
    datasets = ['Cora', 'CiteSeer', 'Pubmed']
    for dataset_name in datasets:
        dataset = Planetoid(root=f'./data/{dataset_name}', name=dataset_name)
        data = dataset[0].to(device)

        model = SubstructureAwareGNN(
            in_channels=dataset.num_node_features,
            hidden_channels=16,
            out_channels=dataset.num_classes
        ).to(device)

        optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
        criterion = nn.NLLLoss()

        best_loss = float('inf')
        patience_counter = 0

        model.train()
        for epoch in range(max_epochs):
            optimizer.zero_grad()
            out = model(data.x, data.edge_index)
            loss = criterion(out[data.train_mask], data.y[data.train_mask])
            loss.backward()
            optimizer.step()

            # Read the scalar once; each .item() call synchronises with the
            # device.
            current_loss = loss.item()

            # NOTE: early stopping monitors the *training* loss, not a
            # validation metric.
            if current_loss < best_loss:
                best_loss = current_loss
                patience_counter = 0
            else:
                patience_counter += 1

            # Log periodically and on the final epoch of the schedule.
            if epoch % log_every == 0 or epoch == max_epochs - 1:
                print(f"Dataset: {dataset_name}, Epoch: {epoch}, Loss: {current_loss:.4f}")

            if patience_counter >= patience:
                print(f"Early stopping on epoch {epoch} with best loss {best_loss:.4f}")
                break

        # Evaluation needs no gradients; skip building the autograd graph.
        model.eval()
        with torch.no_grad():
            pred = model(data.x, data.edge_index).argmax(dim=1)
        correct = int((pred[data.test_mask] == data.y[data.test_mask]).sum())
        acc = correct / int(data.test_mask.sum())
        print(f"Accuracy on {dataset_name}: {acc:.4f}")


Using device: cuda
Dataset: Cora, Epoch: 0, Loss: 1.9462
Dataset: Cora, Epoch: 10, Loss: 0.0734
Dataset: Cora, Epoch: 20, Loss: 0.0034
Dataset: Cora, Epoch: 30, Loss: 0.0008
Dataset: Cora, Epoch: 40, Loss: 0.0012
Early stopping on epoch 45 with best loss 0.0005
Accuracy on Cora: 0.7690
Dataset: CiteSeer, Epoch: 0, Loss: 1.7983
Dataset: CiteSeer, Epoch: 10, Loss: 0.0093
Dataset: CiteSeer, Epoch: 20, Loss: 0.0005
Dataset: CiteSeer, Epoch: 30, Loss: 0.0001
Dataset: CiteSeer, Epoch: 40, Loss: 0.0007
Early stopping on epoch 40 with best loss 0.0001
Accuracy on CiteSeer: 0.6470
Dataset: Pubmed, Epoch: 0, Loss: 1.0980
Dataset: Pubmed, Epoch: 10, Loss: 0.6330
Dataset: Pubmed, Epoch: 20, Loss: 0.2165
Dataset: Pubmed, Epoch: 30, Loss: 0.0673
Dataset: Pubmed, Epoch: 40, Loss: 0.0365
Dataset: Pubmed, Epoch: 50, Loss: 0.0341
Early stopping on epoch 59 with best loss 0.0328
Accuracy on Pubmed: 0.7850
