# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing, global_mean_pool
from torch_geometric.utils import add_self_loops

class MPNNLayer(MessagePassing):
    """
    Message Passing Neural Network Layer
    This layer performs message passing operations on graph-structured data.

    For every edge j -> i a message is computed from
    ``[x_i, x_j, edge_attr, pos_i - pos_j]``. Messages are summed per target
    node and added to a linear transform of the node's own features. A
    self-loop with all-zero edge features is added for every node before
    propagation.
    """

    class _FunctionActivation(nn.Module):
        """Adapt a plain tensor function (e.g. ``F.relu``) to the ``nn.Module`` interface."""

        def __init__(self, fn):
            super().__init__()
            self.fn = fn

        def forward(self, x):
            return self.fn(x)

        def extra_repr(self):
            return getattr(self.fn, '__name__', repr(self.fn))

    @classmethod
    def _build_activation(cls, activation):
        """
        Turn ``activation`` into an ``nn.Module`` that can be placed in ``nn.Sequential``
        and also called directly on a tensor.
        :param activation: An ``nn.Module`` subclass (instantiated with no arguments),
            an ``nn.Module`` instance (returned as-is, i.e. shared), or any other callable
            mapping a tensor to a tensor (wrapped)
        :return: An ``nn.Module`` applying the activation
        :raises TypeError: If ``activation`` is none of the above
        """
        if isinstance(activation, type) and issubclass(activation, nn.Module):
            return activation()
        # Check for module instances before generic callables: modules are callable too.
        if isinstance(activation, nn.Module):
            return activation
        if callable(activation):
            return cls._FunctionActivation(activation)
        raise TypeError(
            "activation must be an nn.Module subclass, an nn.Module instance or a callable, "
            f"got {activation!r}"
        )

    def __init__(self, hidden_dim, edge_dim, activation=F.relu, dropout=0.0):
        """
        Initialize the MPNN Layer
        :param hidden_dim: Dimension of node features
        :param edge_dim: Dimension of edge features
        :param activation: Activation to use (default: ``F.relu``). May be an ``nn.Module``
            subclass such as ``nn.ReLU``, an ``nn.Module`` instance, or a plain function
            such as ``F.relu``.
        :param dropout: Dropout rate (default: 0.0, no dropout)
        """
        super().__init__(aggr='add')  # Sum incoming messages per target node
        self.node_mlp = nn.Linear(hidden_dim, hidden_dim)  # Transform of each node's own features
        self.message_mlp = nn.Sequential(
            # Input: x_i and x_j (hidden_dim each), edge features, and the 3-D relative position
            nn.Linear(2 * hidden_dim + edge_dim + 3, hidden_dim),
            nn.BatchNorm1d(hidden_dim),  # Normalizes across messages (one row per edge)
            self._build_activation(activation),
            nn.Dropout(dropout) if dropout > 0 else nn.Identity(),  # Optional dropout
            nn.Linear(hidden_dim, hidden_dim),
        )
        # Separate module for the node update. It is a fresh instance when `activation`
        # is a class, so parameterized activations are not shared with the message MLP.
        self.activation = self._build_activation(activation)
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

    def forward(self, x, edge_index, edge_attr, pos):
        """
        Forward pass of the MPNN Layer
        :param x: Node features, expected shape [num_nodes, hidden_dim]
        :param edge_index: Graph connectivity (PyG COO format, presumably [2, num_edges])
        :param edge_attr: Edge features, expected shape [num_edges, edge_dim]
        :param pos: Node positions in 3D space, expected shape [num_nodes, 3]
        :return: Updated node features, shape [num_nodes, hidden_dim]
        """
        num_nodes = x.size(0)
        # add_self_loops appends one (i, i) edge per node after the existing edges.
        # The zero attributes for those loops are appended in the same order.
        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
        self_loop_attr = edge_attr.new_zeros((num_nodes,) + tuple(edge_attr.shape[1:]))
        edge_attr = torch.cat([edge_attr, self_loop_attr], dim=0)
        return self.propagate(edge_index, x=x, edge_attr=edge_attr, pos=pos)

    def message(self, x_i, x_j, edge_attr, pos_i, pos_j):
        """
        Compute one message per edge j -> i
        :param x_i: Features of target nodes
        :param x_j: Features of source nodes
        :param edge_attr: Edge features (zeros for the added self-loops)
        :param pos_i: Positions of target nodes
        :param pos_j: Positions of source nodes
        :return: Messages, one row per edge
        """
        rel_pos = pos_i - pos_j  # Vector from the source node to the target node
        return self.message_mlp(torch.cat([x_i, x_j, edge_attr, rel_pos], dim=1))

    def update(self, aggr_out, x):
        """
        Combine summed messages with each node's own transformed features
        :param aggr_out: Sum of incoming messages per node
        :param x: Current node features
        :return: dropout(activation(node_mlp(x) + aggr_out))
        """
        return self.dropout(self.activation(self.node_mlp(x) + aggr_out))

class MPNN(nn.Module):
    """
    Message Passing Neural Network (MPNN) model
    This model projects node features to ``hidden_dim``, applies a stack of MPNN
    layers with residual connections, mean-pools node embeddings per graph, and
    maps the pooled embedding through an output MLP.

    Note: the output head contains ``nn.BatchNorm1d`` over pooled graph embeddings.
    In training mode PyTorch requires more than one graph per batch for it.
    """

    class _FunctionActivation(nn.Module):
        """Adapt a plain tensor function (e.g. ``F.relu``) to the ``nn.Module`` interface."""

        def __init__(self, fn):
            super().__init__()
            self.fn = fn

        def forward(self, x):
            return self.fn(x)

        def extra_repr(self):
            return getattr(self.fn, '__name__', repr(self.fn))

    @classmethod
    def _build_activation(cls, activation):
        """
        Turn ``activation`` into an ``nn.Module`` usable inside ``nn.Sequential``.
        :param activation: An ``nn.Module`` subclass (instantiated with no arguments),
            an ``nn.Module`` instance (returned as-is), or any other callable mapping a
            tensor to a tensor (wrapped)
        :return: An ``nn.Module`` applying the activation
        :raises TypeError: If ``activation`` is none of the above
        """
        if isinstance(activation, type) and issubclass(activation, nn.Module):
            return activation()
        # Check for module instances before generic callables: modules are callable too.
        if isinstance(activation, nn.Module):
            return activation
        if callable(activation):
            return cls._FunctionActivation(activation)
        raise TypeError(
            "activation must be an nn.Module subclass, an nn.Module instance or a callable, "
            f"got {activation!r}"
        )

    def __init__(self, node_dim, edge_dim, hidden_dim, output_dim, num_layers=3, activation=F.relu, dropout=0.0):
        """
        Initialize the MPNN model
        :param node_dim: Initial dimension of node features
        :param edge_dim: Dimension of edge features
        :param hidden_dim: Hidden dimension used throughout the network
        :param output_dim: Output dimension of the model
        :param num_layers: Number of MPNN layers to use
        :param activation: Activation to use (default: ``F.relu``). May be an ``nn.Module``
            subclass such as ``nn.ReLU``, an ``nn.Module`` instance, or a plain function.
            It is passed unchanged to every MPNN layer.
        :param dropout: Dropout rate
        """
        super().__init__()
        self.input_proj = nn.Linear(node_dim, hidden_dim)  # Project initial node features to hidden dimension

        self.layers = nn.ModuleList(
            MPNNLayer(hidden_dim, edge_dim, activation, dropout) for _ in range(num_layers)
        )

        # Output MLP applied to the pooled per-graph embeddings
        self.output = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            self._build_activation(activation),
            nn.Dropout(dropout) if dropout > 0 else nn.Identity(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, data):
        """
        Forward pass of the MPNN model
        :param data: PyTorch Geometric data object exposing ``x``, ``edge_index``,
            ``edge_attr``, ``pos`` and (optionally) ``batch``
        :return: Output predictions, one row per graph
        """
        x, edge_index, edge_attr, pos = data.x, data.edge_index, data.edge_attr, data.pos
        batch = getattr(data, 'batch', None)
        if batch is None:
            # A single un-batched graph: every node belongs to graph 0.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        h = self.input_proj(x)  # Initial projection of node features

        for layer in self.layers:
            h = h + layer(h, edge_index, edge_attr, pos)  # Residual connection

        h = global_mean_pool(h, batch)  # Average node embeddings per graph
        return self.output(h)

# Usage example
# model = MPNN(node_dim=32, edge_dim=16, hidden_dim=64, output_dim=10, num_layers=3, activation=F.relu, dropout=0.1)
# output = model(data)  # data: a PyTorch Geometric Batch (e.g. from a DataLoader) with x, edge_index, edge_attr, pos and batch attributes