<a href="https://colab.research.google.com/github/AbhiJeet70/PowerfulGNNs/blob/main/ESAN_New.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid
from itertools import combinations
import networkx as nx
from torch_geometric.utils import to_networkx, from_networkx

# Installation instructions
# To run this script, make sure the following packages are installed:
# pip install torch torchvision torchaudio
# pip install torch-geometric
# pip install torch-scatter torch-sparse torch-cluster torch-spline-conv
# pip install networkx

# Load Planetoid datasets (Cora, PubMed, Citeseer)
def _ensure_node_features(data):
    """Give ``data`` one-hot node features when it has none.

    Used as a dataset ``transform`` so the fallback is applied to every
    graph object the dataset hands out.
    """
    if data.x is None:
        data.x = torch.eye(data.num_nodes)  # One-hot encoding for nodes
    return data


def load_dataset(name):
    """Load a Planetoid citation dataset, downloading it on first use.

    Args:
        name: Dataset name understood by ``Planetoid`` (the script uses
            "Cora", "CiteSeer" and "PubMed").

    Returns:
        The ``Planetoid`` dataset, stored under ``./data/<name>``. Graphs
        obtained by indexing it are guaranteed to carry node features.
    """
    # Indexing a PyG dataset yields a fresh Data object on each access, so
    # patching ``dataset[0]`` in place here would be lost before the caller
    # indexes the dataset again. Registering the fallback as the dataset
    # transform makes it apply on every access instead.
    return Planetoid(
        root=f"./data/{name}",
        name=name,
        transform=_ensure_node_features,
    )

# Define the ESAN Model
class ESAN(torch.nn.Module):
    """Two GCN layers followed by a shared linear classification head.

    Args:
        input_dim: Size of each input node feature vector.
        hidden_dim: Output width of both GCN layers.
        output_dim: Number of classes scored for every node.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.shared_aggregator = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, data):
        """Return per-node log-probabilities of shape ``[num_nodes, output_dim]``.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``.
        """
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(self.conv2(x, edge_index))
        # Project the hidden representation onto the class dimension. Without
        # this step the model would emit ``hidden_dim`` scores per node and
        # ignore ``output_dim`` entirely.
        x = self.shared_aggregator(x)
        return F.log_softmax(x, dim=1)

# Subgraph Selection Policies
def _iter_policy_subgraphs(graph, policy):
    """Lazily yield the NetworkX subgraphs selected by ``policy``.

    Yields nothing when ``policy`` is not one of the known names.
    """
    if policy == "edge_deleted":
        for edge in graph.edges:
            subgraph = graph.copy()
            subgraph.remove_edge(*edge)
            yield subgraph
    elif policy == "node_deleted":
        for node in graph.nodes:
            subgraph = graph.copy()
            subgraph.remove_node(node)
            yield subgraph
    elif policy == "ego":
        for node in graph.nodes:
            yield nx.ego_graph(graph, node, radius=1)


def _subgraph_to_pyg(subgraph, data):
    """Convert a NetworkX subgraph to PyG, copying matching rows of ``data.x``."""
    pyg_subgraph = from_networkx(subgraph)
    if pyg_subgraph.x is None and data.x is not None:
        # The converted graph numbers its nodes 0..n-1 in iteration order,
        # while the NetworkX node labels are still row indices into data.x.
        # Gather rows by those labels: a plain ``data.x[:n]`` slice would give
        # the wrong features whenever nodes are missing (node_deleted, ego).
        node_ids = torch.tensor(
            list(subgraph.nodes), dtype=torch.long, device=data.x.device
        )
        pyg_subgraph.x = data.x[node_ids]
    return pyg_subgraph


def generate_subgraphs(data, policy="edge_deleted", max_subgraphs=500):
    """Build a bag of subgraphs of ``data`` according to a selection policy.

    Args:
        data: PyG graph to derive subgraphs from; it is converted to an
            undirected NetworkX graph first.
        policy: "edge_deleted" (one subgraph per removed edge),
            "node_deleted" (one per removed node) or "ego" (radius-1 ego
            network of each node). Any other value yields an empty list.
        max_subgraphs: Upper bound on the number of subgraphs returned.

    Returns:
        A list of at most ``max_subgraphs`` PyG graphs, each carrying the
        feature rows of the original nodes it contains when ``data`` has
        features.
    """
    graph = to_networkx(data, to_undirected=True)
    candidates = _iter_policy_subgraphs(graph, policy)

    subgraphs = []
    # Check the cap before pulling the next candidate so that no graph copy
    # is built only to be thrown away.
    while len(subgraphs) < max_subgraphs:
        subgraph = next(candidates, None)
        if subgraph is None:
            break
        subgraphs.append(_subgraph_to_pyg(subgraph, data))

    return subgraphs

# Train the ESAN model
def train_model(model, data, optimizer, epochs=50):
    """Fit ``model`` on the training nodes of ``data`` with full-batch updates.

    Args:
        model: Module called as ``model(data)`` that returns per-node
            log-probabilities.
        data: Graph object exposing ``y`` and ``train_mask``.
        optimizer: Optimizer already bound to the model's parameters.
        epochs: Number of optimisation steps to run.

    Returns:
        The same ``model`` instance, left in training mode.
    """
    model.train()
    final_step = epochs - 1

    for step in range(epochs):
        train_mask = data.train_mask
        targets = data.y[train_mask]

        optimizer.zero_grad()
        train_log_probs = model(data)[train_mask]
        loss = F.nll_loss(train_log_probs, targets)
        loss.backward()
        optimizer.step()

        # Report every tenth epoch and always on the last one.
        epoch_number = step + 1
        if epoch_number % 10 != 0 and step != final_step:
            continue

        num_correct = (train_log_probs.argmax(dim=1) == targets).sum().item()
        train_acc = num_correct / train_mask.sum().item()
        print(f"Epoch {epoch_number}/{epochs}, Loss: {loss.item():.4f}, Train Accuracy: {train_acc:.4f}")

    return model

# Test the ESAN model
def test_model(model, data):
    """Evaluate ``model`` on the train, validation and test splits of ``data``.

    Puts the model in evaluation mode, runs one forward pass and prints the
    accuracy of each split.

    Args:
        model: Module called as ``model(data)`` that returns per-node scores.
        data: Graph object exposing ``y``, ``train_mask``, ``val_mask`` and
            ``test_mask``.

    Returns:
        ``[train_acc, val_acc, test_acc]`` as Python floats.
    """
    model.eval()
    # Evaluation never backpropagates, so skip building the autograd graph.
    with torch.no_grad():
        logits = model(data)

    accs = []
    for mask_name, mask in zip(["Train", "Validation", "Test"], [data.train_mask, data.val_mask, data.test_mask]):
        pred = logits[mask].argmax(dim=1)
        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
        accs.append(acc)
        print(f"{mask_name} Accuracy: {acc:.4f}")
    return accs

# Main execution
def main():
    """Train and evaluate a fresh model for every dataset/policy pair.

    For each Planetoid dataset and each subgraph policy, this generates the
    subgraph bag and reports its size, then trains a new ``ESAN`` model on
    the full graph and prints its accuracies. The generated subgraphs are
    only counted here; they are not passed to training or evaluation.
    """
    hidden_dim = 16

    for dataset_name in ("Cora", "CiteSeer", "PubMed"):
        dataset = load_dataset(dataset_name)
        graph_data = dataset[0]

        for policy in ("edge_deleted", "node_deleted", "ego"):
            print(f"Processing policy: {policy} on dataset: {dataset_name}")
            bag = generate_subgraphs(graph_data, policy=policy, max_subgraphs=500)
            print(f"Generated {len(bag)} subgraphs using {policy} policy")

            # A new model and optimizer per run keeps the policies independent.
            net = ESAN(dataset.num_node_features, hidden_dim, dataset.num_classes)
            adam = torch.optim.Adam(net.parameters(), lr=0.01, weight_decay=5e-4)

            print(f"Training on {dataset_name} with policy {policy}...")
            net = train_model(net, graph_data, adam, epochs=50)

            print(f"Testing on {dataset_name} with policy {policy}...")
            train_acc, val_acc, test_acc = test_model(net, graph_data)

            # Summary block for this dataset/policy combination.
            print(f"Results for {dataset_name} with policy {policy}:")
            print(f"Train Accuracy: {train_acc:.4f}")
            print(f"Validation Accuracy: {val_acc:.4f}")
            print(f"Test Accuracy: {test_acc:.4f}\n")


# Run the experiment only when executed as a script, not on import.
if __name__ == "__main__":
    main()


Processing policy: edge_deleted on dataset: Cora
Generated 500 subgraphs using edge_deleted policy
Training on Cora with policy edge_deleted...
Epoch 10/50, Loss: 1.2796, Train Accuracy: 0.8000
Epoch 20/50, Loss: 0.4616, Train Accuracy: 0.9286
Epoch 30/50, Loss: 0.1704, Train Accuracy: 0.9857
Epoch 40/50, Loss: 0.1003, Train Accuracy: 0.9857
Epoch 50/50, Loss: 0.0893, Train Accuracy: 1.0000
Testing on Cora with policy edge_deleted...
Train Accuracy: 1.0000
Validation Accuracy: 0.7660
Test Accuracy: 0.7850
Results for Cora with policy edge_deleted:
Train Accuracy: 1.0000
Validation Accuracy: 0.7660
Test Accuracy: 0.7850

Processing policy: node_deleted on dataset: Cora
Generated 500 subgraphs using node_deleted policy
Training on Cora with policy node_deleted...
Epoch 10/50, Loss: 1.1993, Train Accuracy: 0.7214
Epoch 20/50, Loss: 0.4902, Train Accuracy: 0.9214
Epoch 30/50, Loss: 0.1949, Train Accuracy: 0.9571
Epoch 40/50, Loss: 0.0877, Train Accuracy: 0.9929
Epoch 50/50, Loss: 0.0702, T

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.test.index
Processing...
Done!


Processing policy: edge_deleted on dataset: CiteSeer
Generated 500 subgraphs using edge_deleted policy
Training on CiteSeer with policy edge_deleted...
Epoch 10/50, Loss: 0.6200, Train Accuracy: 0.8917
Epoch 20/50, Loss: 0.1684, Train Accuracy: 0.9500
Epoch 30/50, Loss: 0.0948, Train Accuracy: 0.9667
Epoch 40/50, Loss: 0.0586, Train Accuracy: 0.9833
Epoch 50/50, Loss: 0.0459, Train Accuracy: 1.0000
Testing on CiteSeer with policy edge_deleted...
Train Accuracy: 1.0000
Validation Accuracy: 0.6800
Test Accuracy: 0.6720
Results for CiteSeer with policy edge_deleted:
Train Accuracy: 1.0000
Validation Accuracy: 0.6800
Test Accuracy: 0.6720

Processing policy: node_deleted on dataset: CiteSeer
Generated 500 subgraphs using node_deleted policy
Training on CiteSeer with policy node_deleted...
Epoch 10/50, Loss: 0.8433, Train Accuracy: 0.7917
Epoch 20/50, Loss: 0.3272, Train Accuracy: 0.9417
Epoch 30/50, Loss: 0.1383, Train Accuracy: 0.9750
Epoch 40/50, Loss: 0.0762, Train Accuracy: 0.9833
Epoc