<a href="https://colab.research.google.com/github/AbhiJeet70/PowerfulGNNs/blob/main/ESAN_New.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install necessary packages
!pip install torch torch-geometric

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid
from itertools import combinations
import networkx as nx
from torch_geometric.utils import to_networkx, from_networkx

# Select the CUDA GPU when PyTorch reports one as available, otherwise the CPU.
# NOTE(review): none of the functions visible in this cell read this variable
# (ESAN.forward derives its own device from the model parameters) -- confirm
# whether the model/data were meant to be moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_dataset(name):
    """Load the Planetoid dataset called ``name`` from ``./data/<name>``.

    The first graph of the dataset is inspected; when it has no node features
    it is assigned an identity matrix, i.e. one one-hot row per node.

    NOTE(review): the fallback features are written to the object obtained
    from ``dataset[0]`` while the function returns ``dataset`` itself. Confirm
    that indexing hands back a shared object; otherwise callers never see the
    fallback.

    Args:
        name: Dataset name forwarded to ``Planetoid``.

    Returns:
        The ``Planetoid`` dataset object (not the individual graph).
    """
    planetoid = Planetoid(root=f"./data/{name}", name=name)
    first_graph = planetoid[0]

    # Fall back to one-hot node identifiers when the graph ships no features.
    if first_graph.x is None:
        first_graph.x = torch.eye(first_graph.num_nodes)

    return planetoid

# Define the ESAN Model
class ESAN(torch.nn.Module):
    """Shared two-layer GCN applied per subgraph with node-wise averaging.

    Every subgraph is encoded by the same two ``GCNConv`` layers and projected
    to ``output_dim`` scores by one shared linear layer. The scores are written
    back to the original node positions given by ``subgraph.n_id`` and averaged
    over all subgraphs in which a node occurs.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        """Create the two graph convolutions and the shared output projection."""
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.shared_aggregator = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, subgraphs, num_nodes, batch_size=50):
        """Return per-node log-probabilities of shape ``(num_nodes, output_dim)``.

        Args:
            subgraphs: Sequence of graph objects exposing ``x``, ``edge_index``
                and ``n_id`` (positions of the subgraph's nodes in the full
                graph).
            num_nodes: Number of nodes in the full graph.
            batch_size: Length of the slices the sequence is walked in. Each
                subgraph inside a slice is still processed on its own.

        Nodes that occur in no subgraph keep all-zero scores, which turn into
        a uniform distribution after the log-softmax.
        """
        # Work on whatever device the model parameters currently live on.
        target_device = next(self.parameters()).device
        num_classes = self.shared_aggregator.out_features
        score_sum = torch.zeros((num_nodes, num_classes), device=target_device)
        visit_count = torch.zeros(num_nodes, device=target_device)

        for start in range(0, len(subgraphs), batch_size):
            for graph in subgraphs[start:start + batch_size]:
                feats = graph.x.to(target_device)
                edges = graph.edge_index.to(target_device)

                hidden = F.relu(self.conv1(feats, edges))
                hidden = F.dropout(hidden, p=0.5, training=self.training)
                # Project the second-layer embedding to class scores.
                scores = self.shared_aggregator(self.conv2(hidden, edges))

                # Accumulate scores and occurrence counts at the original ids.
                score_sum[graph.n_id] += scores
                visit_count[graph.n_id] += 1

        # Clamp so that nodes never visited divide by 1 instead of 0.
        mean_scores = score_sum / visit_count.clamp(min=1).unsqueeze(1)
        return F.log_softmax(mean_scores, dim=1)



# Train the ESAN model
def train_model(model, subgraphs, data, optimizer, epochs=50):
    """Optimise ``model`` for ``epochs`` full passes and return it.

    Each epoch runs one forward pass over all ``subgraphs``, computes the
    negative log-likelihood on the nodes selected by ``data.train_mask`` and
    applies one optimizer step. Loss and training accuracy are printed every
    tenth epoch and after the final one. The accuracy is taken from the same
    forward pass that produced the loss, i.e. before the parameter update.

    Args:
        model: Callable module invoked as ``model(subgraphs, data.num_nodes)``.
        subgraphs: Passed through to the model unchanged.
        data: Object providing ``num_nodes``, ``train_mask`` and ``y``.
        optimizer: Optimizer whose ``zero_grad``/``step`` drive the update.
        epochs: Number of optimisation steps.

    Returns:
        The same ``model`` object, left in training mode.
    """
    model.train()
    final_epoch = epochs - 1

    for epoch in range(epochs):
        optimizer.zero_grad()
        out = model(subgraphs, data.num_nodes)
        train_out = out[data.train_mask]
        train_labels = data.y[data.train_mask]
        loss = F.nll_loss(train_out, train_labels)
        loss.backward()
        optimizer.step()

        should_report = epoch == final_epoch or (epoch + 1) % 10 == 0
        if not should_report:
            continue

        n_correct = (train_out.argmax(dim=1) == train_labels).sum().item()
        train_acc = n_correct / data.train_mask.sum().item()
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {loss.item():.4f}, Train Accuracy: {train_acc:.4f}")

    return model

# Test the ESAN model
def test_model(model, subgraphs, data):
    """Evaluate ``model`` on the train, validation and test splits.

    The model is switched to evaluation mode (and left in it) and a single
    forward pass over all ``subgraphs`` is run with gradient tracking
    disabled, since no backward pass follows. Accuracy per split is printed
    and collected.

    Args:
        model: Callable module invoked as ``model(subgraphs, data.num_nodes)``.
        subgraphs: Passed through to the model unchanged.
        data: Object providing ``num_nodes``, ``y`` and the boolean masks
            ``train_mask``, ``val_mask`` and ``test_mask``.

    Returns:
        ``[train_acc, val_acc, test_acc]`` as floats. A split whose mask
        selects no node yields ``nan`` rather than raising
        ``ZeroDivisionError``.
    """
    model.eval()
    # Inference only: avoid recording an autograd graph across every subgraph.
    with torch.no_grad():
        logits = model(subgraphs, data.num_nodes)

    splits = (
        ("Train", data.train_mask),
        ("Validation", data.val_mask),
        ("Test", data.test_mask),
    )
    accs = []
    for mask_name, mask in splits:
        total = mask.sum().item()
        if total:
            correct = logits[mask].argmax(dim=1).eq(data.y[mask]).sum().item()
            acc = correct / total
        else:
            # Accuracy is undefined for an empty split.
            acc = float("nan")
        accs.append(acc)
        print(f"{mask_name} Accuracy: {acc:.4f}")
    return accs

_SUBGRAPH_POLICIES = ("edge_deleted", "node_deleted", "ego")


def _to_pyg_subgraph(nx_subgraph, data):
    """Convert a NetworkX subgraph to PyG form with ``n_id`` and ``x`` set."""
    pyg_subgraph = from_networkx(nx_subgraph)
    # n_id records, in node-iteration order, each subgraph node's id in the
    # original graph; the model uses it to scatter predictions back.
    pyg_subgraph.n_id = torch.tensor(list(nx_subgraph.nodes), dtype=torch.long)
    if pyg_subgraph.x is None:
        # No features survived the NetworkX round trip: take them from `data`.
        pyg_subgraph.x = data.x[pyg_subgraph.n_id]
    return pyg_subgraph


def generate_subgraphs(data, policy="edge_deleted", max_subgraphs=300, ego_radius=2):
    """Build up to ``max_subgraphs`` subgraphs of ``data`` under ``policy``.

    Policies:
        ``"edge_deleted"``: one copy of the graph per edge, with that edge
            removed.
        ``"node_deleted"``: one copy of the graph per node, with that node
            removed.
        ``"ego"``: the ego network of radius ``ego_radius`` around each node.
            One extra feature column marks the centre node with 1, and every
            feature row is then L2-normalised.

    Edges or nodes are visited in the iteration order of the undirected
    NetworkX view of ``data``; generation stops once ``max_subgraphs``
    subgraphs exist.

    NOTE(review): for the two deletion policies every subgraph indexes
    ``data.x`` for (nearly) all nodes, so memory grows with
    ``max_subgraphs * data.x`` size -- worth checking on large datasets.

    Args:
        data: Graph object accepted by ``to_networkx`` and exposing ``x``.
        policy: One of ``"edge_deleted"``, ``"node_deleted"`` or ``"ego"``.
        max_subgraphs: Upper bound on the number of subgraphs returned.
        ego_radius: Radius used by the ``"ego"`` policy (default 2).

    Returns:
        A list of PyG graph objects, each carrying ``x`` and ``n_id``.

    Raises:
        ValueError: If ``policy`` is not one of the supported names.
    """
    # Fail loudly instead of silently returning an empty list for a typo.
    if policy not in _SUBGRAPH_POLICIES:
        raise ValueError(
            f"Unknown subgraph policy {policy!r}; expected one of {_SUBGRAPH_POLICIES}"
        )

    graph = to_networkx(data, to_undirected=True)
    subgraphs = []

    if policy == "edge_deleted":
        for edge in graph.edges:
            if len(subgraphs) >= max_subgraphs:
                break
            reduced = graph.copy()
            reduced.remove_edge(*edge)
            subgraphs.append(_to_pyg_subgraph(reduced, data))

    elif policy == "node_deleted":
        for node in graph.nodes:
            if len(subgraphs) >= max_subgraphs:
                break
            reduced = graph.copy()
            reduced.remove_node(node)
            subgraphs.append(_to_pyg_subgraph(reduced, data))

    else:  # policy == "ego"
        for node in graph.nodes:
            if len(subgraphs) >= max_subgraphs:
                break
            ego = nx.ego_graph(graph, node, radius=ego_radius)
            pyg_subgraph = _to_pyg_subgraph(ego, data)

            # Indicator column: 1 for the centre node, 0 for everything else.
            centre_flag = torch.zeros(pyg_subgraph.n_id.numel(), 1)
            centre_flag[pyg_subgraph.n_id == node] = 1

            pyg_subgraph.x = torch.cat([pyg_subgraph.x, centre_flag], dim=1)
            pyg_subgraph.x = F.normalize(pyg_subgraph.x, p=2, dim=1)

            subgraphs.append(pyg_subgraph)

    return subgraphs


# Main execution
def main():
    """Train and evaluate ESAN for every dataset / subgraph-policy pair.

    For each of Cora, CiteSeer and PubMed, and for each of the three subgraph
    policies, this generates up to 300 subgraphs, trains a fresh ESAN model
    with Adam for 50 epochs and prints the per-split accuracies.

    NOTE(review): the model and data are used as constructed; nothing here
    moves them to the module-level ``device``.
    """
    hidden_dim = 16

    for dataset_name in ("Cora", "CiteSeer", "PubMed"):
        dataset = load_dataset(dataset_name)
        data = dataset[0]

        for policy in ("edge_deleted", "node_deleted", "ego"):
            print(f"Processing policy: {policy} on dataset: {dataset_name}")
            subgraphs = generate_subgraphs(data, policy=policy, max_subgraphs=300)
            print(f"Generated {len(subgraphs)} subgraphs using {policy} policy")

            # The ego policy appends one marker column to the node features.
            extra_features = 1 if policy == "ego" else 0
            input_dim = dataset.num_node_features + extra_features
            output_dim = dataset.num_classes

            # A fresh model and optimizer for every dataset/policy pair.
            model = ESAN(input_dim, hidden_dim, output_dim)
            optimizer = torch.optim.Adam(
                model.parameters(),
                lr=0.01,
                weight_decay=5e-4,
            )

            print(f"Training on {dataset_name} with policy {policy}...")
            model = train_model(model, subgraphs, data, optimizer, epochs=50)

            print(f"Testing on {dataset_name} with policy {policy}...")
            test_model(model, subgraphs, data)

# Run the full dataset/policy sweep only when this file is executed directly
# (script or notebook cell), not when it is imported as a module.
if __name__ == "__main__":
    main()


Processing policy: edge_deleted on dataset: Cora
Generated 300 subgraphs using edge_deleted policy
Training on Cora with policy edge_deleted...
Epoch 10/50, Loss: 0.8448, Train Accuracy: 0.8357
Epoch 20/50, Loss: 0.0806, Train Accuracy: 1.0000
Epoch 30/50, Loss: 0.0058, Train Accuracy: 1.0000
Epoch 40/50, Loss: 0.0020, Train Accuracy: 1.0000
Epoch 50/50, Loss: 0.0023, Train Accuracy: 1.0000
Testing on Cora with policy edge_deleted...
Train Accuracy: 1.0000
Validation Accuracy: 0.7540
Test Accuracy: 0.7810
Processing policy: node_deleted on dataset: Cora
Generated 300 subgraphs using node_deleted policy
Training on Cora with policy node_deleted...
Epoch 10/50, Loss: 0.8650, Train Accuracy: 0.9429
Epoch 20/50, Loss: 0.0778, Train Accuracy: 1.0000
Epoch 30/50, Loss: 0.0064, Train Accuracy: 1.0000
Epoch 40/50, Loss: 0.0024, Train Accuracy: 1.0000
Epoch 50/50, Loss: 0.0025, Train Accuracy: 1.0000
Testing on Cora with policy node_deleted...
Train Accuracy: 1.0000
Validation Accuracy: 0.7520
