In [24]:
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops
import torch.nn.functional as F
from torch_scatter import scatter_mean
from torch.nn import Sequential, Linear, ReLU

class MPNNLayer(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(MPNNLayer, self).__init__(aggr='mean')  # "mean" aggregation.
        self.mlp = Sequential(Linear(in_channels, out_channels), ReLU())

    def forward(self, x, edge_index):
        # Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Start propagating messages.
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        # x_j refers to the neighbors' node features.
        return self.mlp(x_j)

    def update(self, aggr_out):
        # aggr_out refers to the aggregated message.
        return aggr_out


class MultiLayerMPNN(torch.nn.Module):
    def __init__(self):
        super(MultiLayerMPNN, self).__init__()
        self.layer1 = MPNNLayer(3, 5)  # 3 input features to 5 output features
        self.layer2 = MPNNLayer(5, 7)  # 5 input features to 7 output features
        self.layer3 = MPNNLayer(7, 2)  # 7 input features to 2 output features

    def forward(self, x, edge_index):
        print("Input to layer 1:", x)
        x = self.layer1(x, edge_index)
        print("Output of layer 1 (Input to layer 2):", x)
        
        x = self.layer2(x, edge_index)
        print("Output of layer 2 (Input to layer 3):", x)
        
        x = self.layer3(x, edge_index)
        print("Output of layer 3 (Final output):", x)
        
        return x


In [33]:
from torchviz import make_dot

# Forward pass through the model and visualize.

model.eval()
# with torch.no_grad():
    # Forward pass through the model
output = model(x, edge_index)

# output = model(x, edge_index)

# Visualize the computation graph.
dot = make_dot(output, params=dict(model.named_parameters()))
dot.render("multilayer_mpnn_computation_graph", format="png")


Input to layer 1: tensor([[ 1.1041, -0.5924,  1.2338],
        [-0.9644,  0.9783,  1.6265],
        [ 0.7566, -2.4212,  0.4351],
        [ 0.6107, -1.7176, -0.5074]])
Output of layer 1 (Input to layer 2): tensor([[0.1006, 0.0000, 0.0499, 0.0000, 0.1198],
        [0.0234, 0.0000, 0.0499, 0.0000, 0.0155],
        [0.0351, 0.0000, 0.0000, 0.0000, 0.0232],
        [0.1509, 0.0000, 0.0749, 0.0000, 0.1797]], grad_fn=<DivBackward0>)
Output of layer 2 (Input to layer 3): tensor([[0.1296, 0.0000, 0.0000, 0.0000, 0.2232, 0.1935, 0.1819],
        [0.1486, 0.0000, 0.0000, 0.0000, 0.2026, 0.1949, 0.1879],
        [0.1614, 0.0000, 0.0000, 0.0000, 0.1882, 0.1957, 0.1895],
        [0.1106, 0.0000, 0.0000, 0.0000, 0.2438, 0.1923, 0.1820]],
       grad_fn=<DivBackward0>)
Output of layer 3 (Final output): tensor([[0.3105, 0.0000],
        [0.3161, 0.0000],
        [0.3185, 0.0000],
        [0.3076, 0.0000]], grad_fn=<DivBackward0>)


'multilayer_mpnn_computation_graph.png'

In [38]:
from torch.fx import symbolic_trace
from torch_geometric.nn import GCNConv
model = GCNConv(3,32)

# Trace the modelimport torch.fx

# If arange is causing issues, wrap it:
# torch.fx.wrap('torch.arange')

# Proceed with your model and tracing as before
traced_model = torch.fx.symbolic_trace(model(x,edge_index))


# Print the graph
print(traced_model.graph)


ValueError: `MessagePassing.propagate` only supports integer tensors of shape `[2, num_messages]`, `torch_sparse.SparseTensor` or `torch.sparse.Tensor` for argument `edge_index`.

In [39]:
import torch
from torch_geometric.nn import MessagePassing

def add_hooks_to_message_passing(mpnn_module):
    """
    Adds hooks to the message, update, and forward functions of a MessagePassing module.
    
    Parameters:
    mpnn_module (torch_geometric.nn.MessagePassing): The MPNN module to which hooks will be added.
    """
    def forward_hook(module, input, output):
        print(f"Forward Hook - Module: {module.__class__.__name__}")
        print(f"Input: {input}")
        print(f"Output: {output}")

    def message_hook(module, input):
        print(f"Message Hook - Module: {module.__class__.__name__}")
        print(f"Input to message: {input}")

    def update_hook(module, input, output):
        print(f"Update Hook - Module: {module.__class__.__name__}")
        print(f"Input to update: {input}")
        print(f"Output of update: {output}")

    # Register the hooks
    mpnn_module.register_forward_hook(forward_hook)
    mpnn_module.message = torch.nn.Module.register_forward_hook(mpnn_module, message_hook)
    mpnn_module.update = torch.nn.Module.register_forward_hook(mpnn_module, update_hook)
class SimpleMPNN(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(SimpleMPNN, self).__init__(aggr='mean')
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(in_channels, out_channels),
            torch.nn.ReLU()
        )

    def forward(self, x, edge_index):
        edge_index, _ = torch_geometric.utils.add_self_loops(edge_index, num_nodes=x.size(0))
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        return self.mlp(x_j)

    def update(self, aggr_out):
        return aggr_out

# Instantiate the MPNN
mpnn_instance = SimpleMPNN(in_channels=3, out_channels=2)

# Add hooks to the MPNN instance
add_hooks_to_message_passing(mpnn_instance)

# Dummy input data
x = torch.randn((4, 3))  # 4 nodes with 3 features each
edge_index = torch.tensor([[0, 1, 1, 2, 3, 0], 
                           [1, 0, 2, 1, 0, 3]], dtype=torch.long)

# Run the MPNN with hooks enabled
with torch.no_grad():
    output = mpnn_instance(x, edge_index)

print("Final Output:", output)


TypeError: 'RemovableHandle' object is not callable