In [None]:
import sys
import os
import torch
notebook_dir = os.getcwd()
project_root = os.path.abspath(os.path.join(notebook_dir, ".."))
sys.path.append(project_root)

import torch
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.utils import to_networkx
from DeepLearningPH.simulations.n_body_simulation import n_body_simulation
from gnn_model.node_data_list import node_data_list

def plot_graph(data, feature_dim_to_show=-2, with_labels=True, node_size=600, cmap='viridis'):
    """
    Plots a PyTorch Geometric Data object as a graph using networkx.

    Args:
        data (torch_geometric.data.Data): The graph data.
        feature_dim_to_show (int): Which feature dimension to use as color/label.
        with_labels (bool): Whether to annotate nodes with feature values.
        node_size (int): Node size in the plot.
        cmap (str): Matplotlib colormap for node colors.
    """
    # Convert to NetworkX graph
    G = to_networkx(data, to_undirected=False)
    # Layout for nicer appearance
    pos = nx.spring_layout(G, seed=42)

    nx.draw_networkx_edges(
    G, pos,
    arrowsize=20,
    connectionstyle='arc3,rad=0.05',  # curve the edges
    arrow_pos="middle"
    )

    nx.draw_networkx_edges(
    G, pos,
    arrowsize=20,
    connectionstyle='arc3,rad=-0.05',
    arrow_pos="middle"
    )

    # Get node features (for coloring or labeling)
    if hasattr(data, 'x') and data.x is not None:
        node_attrs = data.x[:, feature_dim_to_show].detach().cpu().numpy()
    else:
        node_attrs = None

    # Draw nodes
    nodes = nx.draw_networkx_nodes(G, pos, node_size=node_size,
                                   node_color=node_attrs, cmap=cmap)

    # Optionally add labels
    if with_labels and node_attrs is not None:
        labels = {i: f"{val:.2f}" for i, val in enumerate(node_attrs)}
        nx.draw_networkx_labels(G, pos, labels, font_size=10, font_color="black")

    plt.title(f"{len(data.x[:, feature_dim_to_show])}-body gravity simulation")
    plt.axis('off')
    plt.show()

In [46]:
from torch_geometric.data import Data

# Create example graph with 5 nodes and random features
abcd = n_body_simulation()
abc = node_data_list(abcd, self_loop=False, complete_graph=True)

data = abc[0]

# Plot, using the first feature dimension as color/label
plot_graph(data, feature_dim_to_show=0)


TypeError: draw_networkx_edges() got an unexpected keyword argument 'arrow_pos'