## Translation-Invariant GNN


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing

In [None]:
class TranslationInvariantGNN(MessagePassing):
    """Message-passing layer whose output is invariant to translating ``pos``.

    For every edge a message is built from the two endpoint feature vectors
    and the squared Euclidean distance between the endpoint positions:

        m_ij = phi_e(h_i, h_j, ||x_i - x_j||^2)

    Messages are summed per receiving node and combined with that node's own
    features:

        h_i' = phi_h(h_i, sum_j m_ij)

    Positions enter only through pairwise differences, so adding the same
    offset to every position leaves the output unchanged.

    Args:
        node_features: Size of the input node feature vectors.
        hidden_channels: Width of the hidden layers and of the output features.
    """

    def __init__(self, node_features, hidden_channels):
        # Incoming messages are combined by summation.
        super().__init__(aggr='add')

        # Edge operation phi_e: [h_i, h_j, ||x_i - x_j||^2] -> message.
        # The extra "+ 1" input column carries the squared distance.
        self.edge_mlp = nn.Sequential(
            nn.Linear(2 * node_features + 1, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, hidden_channels),
        )

        # Node operation phi_h: [h_i, aggregated messages] -> new features.
        self.node_mlp = nn.Sequential(
            nn.Linear(node_features + hidden_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, hidden_channels),
        )

    def forward(self, x, edge_index, pos):
        """Run one round of message passing.

        Args:
            x: Node features, shape ``[N, node_features]``.
            edge_index: Graph connectivity, shape ``[2, E]``.
            pos: Node positions, shape ``[N, D]``. A 1-D tensor of shape
                ``[N]`` is accepted and treated as ``[N, 1]`` (scalar
                coordinates). Positions could be learned or fixed.

        Returns:
            Updated node features, shape ``[N, hidden_channels]``.
        """
        if pos.dim() == 1:
            # Scalar coordinates: add a trailing coordinate axis so the
            # reduction in ``message`` sums over coordinates, not over edges.
            pos = pos.unsqueeze(-1)
        return self.propagate(edge_index, x=x, pos=pos)

    def message(self, x_i, x_j, pos_i, pos_j):
        """Build one message per edge from endpoint features and squared distance.

        ``x_i``/``pos_i`` and ``x_j``/``pos_j`` are the per-edge endpoint
        values that ``propagate`` gathers from ``x`` and ``pos`` via
        ``edge_index``.
        """
        # Only the position *difference* is used, which is what makes the
        # layer translation invariant. keepdim leaves a [E, 1] column that can
        # be concatenated with the feature vectors.
        sq_dist = (pos_i - pos_j).pow(2).sum(dim=-1, keepdim=True)

        msg_features = torch.cat([x_i, x_j, sq_dist], dim=-1)
        return self.edge_mlp(msg_features)

    def update(self, aggr_out, x):
        """Combine each node's current features with its summed messages."""
        update_features = torch.cat([x, aggr_out], dim=-1)
        return self.node_mlp(update_features)