# Equivariant Subgraph Aggregation Networks (ESAN): Simple Code

In [15]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import to_dense_adj

# ESANLayer: embeds every node by mean-pooling encoded features over its k-hop ego subgraph
class ESANLayer(nn.Module):
    """Ego-subgraph encoder producing one pooled embedding per node.

    For each node ``i`` the layer collects the nodes reachable from ``i`` in at
    most ``k_hop`` steps along entries with ``adj[u, v] > 0`` (``i`` itself is
    always included). It encodes their features with
    ``relu(node_encoder(x))`` and averages them. Mean pooling makes the result
    independent of the order of nodes inside each subgraph.

    Args:
        input_dim: Number of input features per node.
        output_dim: Size of each pooled node embedding.
        k_hop: Radius of the ego subgraph; ``0`` (or a negative value) keeps
            only the centre node.
    """

    def __init__(self, input_dim, output_dim, k_hop=1):
        super(ESANLayer, self).__init__()
        self.k_hop = k_hop
        self.node_encoder = nn.Linear(input_dim, output_dim)

    def extract_subgraph(self, x, adj, node_idx):
        """Return ``(ego_x, ego_adj, subgraph_nodes)`` for one node's k-hop ego subgraph.

        ``subgraph_nodes`` holds, in ascending order, the indices of all nodes
        within ``k_hop`` steps of ``node_idx`` (including it). ``ego_x`` and
        ``ego_adj`` are ``x`` and ``adj`` restricted to those nodes.
        """
        visited = torch.zeros(adj.size(0), dtype=torch.bool, device=adj.device)
        visited[node_idx] = True
        frontier = visited.clone()
        for _ in range(self.k_hop):
            # Nodes adjacent to any frontier node that were not visited yet.
            frontier = (adj[frontier] > 0).any(dim=0) & ~visited
            if not frontier.any():
                break  # the ego subgraph cannot grow any further
            visited |= frontier
        subgraph_nodes = visited.nonzero(as_tuple=True)[0]
        ego_x = x[subgraph_nodes.to(x.device)]
        ego_adj = adj[subgraph_nodes][:, subgraph_nodes]
        return ego_x, ego_adj, subgraph_nodes

    def _khop_reachability(self, adj):
        """Return a boolean ``[N, N]`` mask whose row ``i`` marks node ``i``'s ego subgraph."""
        num_nodes = adj.size(0)
        step = (adj > 0).to(torch.float32)
        reach = torch.eye(num_nodes, dtype=torch.bool, device=adj.device)
        frontier = reach
        for _ in range(self.k_hop):
            # One BFS step for every centre at once. A count > 0 means the target
            # node is adjacent to at least one node on that row's frontier.
            touched = (frontier.to(step.dtype) @ step) > 0
            frontier = touched & ~reach
            if not frontier.any():
                break  # every ego subgraph is already complete
            reach = reach | frontier
        return reach

    def forward(self, x, adj):
        """Return pooled ego-subgraph embeddings of shape ``[num_nodes, output_dim]``.

        This equals calling ``extract_subgraph`` for every node and averaging
        ``relu(node_encoder(ego_x))``, but it is vectorised. Each node is
        encoded once, and all per-subgraph means come from a single masked
        matrix product.

        Raises:
            ValueError: if ``adj`` is not a ``[num_nodes, num_nodes]`` matrix.
        """
        num_nodes = x.size(0)
        if adj.dim() != 2 or adj.size(0) != num_nodes or adj.size(1) != num_nodes:
            raise ValueError(
                f"adj must have shape [{num_nodes}, {num_nodes}], got {list(adj.shape)}"
            )
        # relu(Linear(.)) acts row-wise, so encoding each node once yields the same
        # per-node features as encoding it separately inside every subgraph.
        h = F.relu(self.node_encoder(x))
        membership = self._khop_reachability(adj).to(device=h.device, dtype=h.dtype)
        # Each row is a non-empty indicator (it always contains its centre node),
        # so dividing the masked sum by the row count gives the subgraph mean.
        return (membership @ h) / membership.sum(dim=1, keepdim=True)

In [16]:
# ESANModel: an ESANLayer that produces pooled ego-subgraph node embeddings,
# followed by a linear layer mapping each embedding to class logits.
class ESANModel(nn.Module):
    """Node classifier built from a single ESANLayer and a linear read-out.

    Args:
        input_dim: Number of input features per node.
        hidden_dim: Size of the embeddings produced by the ESANLayer.
        num_classes: Number of output logits per node.
        k_hop: Ego-subgraph radius, forwarded to ESANLayer.
    """

    def __init__(self, input_dim, hidden_dim, num_classes, k_hop=1):
        super().__init__()
        # Create the subgraph encoder before the read-out so that parameter
        # registration (and random initialisation) happens in the same order.
        self.esan_layer = ESANLayer(input_dim, hidden_dim, k_hop)
        self.output_layer = nn.Linear(hidden_dim, num_classes)

    def forward(self, x, adj):
        """Return per-node class logits of shape ``[num_nodes, num_classes]``."""
        node_repr = self.esan_layer(x, adj)
        logits = self.output_layer(node_repr)
        return logits

In [17]:
# Load the Cora dataset (Planetoid) and take its single graph
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]
# Dense [num_nodes, num_nodes] adjacency, kept dense for simplicity. Passing
# max_num_nodes pins the size to the real node count. Without it, to_dense_adj
# infers the size from the largest node index that appears in an edge, which
# would drop trailing isolated nodes and misalign adj_matrix with data.x.
adj_matrix = to_dense_adj(data.edge_index, max_num_nodes=data.num_nodes)[0]

# Initialize model, loss, and optimizer
model = ESANModel(input_dim=dataset.num_features, hidden_dim=64, num_classes=dataset.num_classes)
optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

In [18]:
# Training loop: 10 full-batch epochs. Each epoch takes one optimizer step on
# the training nodes, then measures accuracy on the test nodes using the
# freshly updated weights.
# NOTE(review): accuracy is reported on the test mask every epoch; use a
# validation split for any early stopping or model selection.
for epoch in range(10):
    model.train()
    optimizer.zero_grad()
    train_nodes = data.train_mask
    out = model(data.x, adj_matrix)
    loss = criterion(out[train_nodes], data.y[train_nodes])
    loss.backward()
    optimizer.step()

    # Evaluate in eval mode without tracking gradients
    model.eval()
    with torch.no_grad():
        out = model(data.x, adj_matrix)
        test_nodes = data.test_mask
        pred = out[test_nodes].max(dim=1).indices
        correct = (pred == data.y[test_nodes]).sum().item()
        acc = correct / test_nodes.sum().item()

    print(f"Epoch {epoch+1:03d}, Loss: {loss.item():.4f}, Test Accuracy: {acc:.4f}")


Epoch 001, Loss: 1.9465, Test Accuracy: 0.2350
Epoch 002, Loss: 1.8215, Test Accuracy: 0.5370
Epoch 003, Loss: 1.6361, Test Accuracy: 0.5860
Epoch 004, Loss: 1.4048, Test Accuracy: 0.6010
Epoch 005, Loss: 1.1539, Test Accuracy: 0.6190
Epoch 006, Loss: 0.9023, Test Accuracy: 0.6550
Epoch 007, Loss: 0.6689, Test Accuracy: 0.7040
Epoch 008, Loss: 0.4706, Test Accuracy: 0.7440
Epoch 009, Loss: 0.3166, Test Accuracy: 0.7710
Epoch 010, Loss: 0.2059, Test Accuracy: 0.7830
