In [None]:
import pandas as pd
import numpy as np
import torch
import ot
import copy
from tqdm import tqdm
import pickle
import torch_geometric.transforms as T
import matplotlib.pyplot as plt

from torch_geometric.datasets import TUDataset

import networkx as nx

# Root directory where the TU datasets are downloaded to / read from.
DATA_ROOT = "data"

# Materialize each dataset as a plain list of Data objects so the per-graph
# edits below are made on objects the notebook keeps hold of.
mutag = list(TUDataset(root=DATA_ROOT, name="MUTAG"))
enzymes = list(TUDataset(root=DATA_ROOT, name="ENZYMES"))
imdb = list(TUDataset(root=DATA_ROOT, name="IMDB-BINARY"))

# IMDB-BINARY graphs are given a constant 1-dimensional node feature (all ones)
# because downstream code such as get_computation_tree indexes graph.x directly.
# NOTE(review): presumably this dataset is loaded without node features — confirm.
for graph in imdb:
    n = graph.num_nodes
    graph.x = torch.ones((n, 1))

# Sanity check: every graph used below must carry node features.
assert all(g.x is not None for g in mutag + enzymes + imdb), \
    "every graph needs node features (graph.x)"

In [None]:
def get_computation_tree(graph, node, depth):
    """
    Returns the computation (unfolding) tree of a given node in the graph
    up to a certain depth.

    The root is ``node``. Every tree node at level ``d < depth`` gets one child
    for *each* of its graph neighbours, so a graph node may appear several
    times in the tree (e.g. the root reappears at level 2 below each of its
    neighbours). Because graph nodes repeat, tree nodes are keyed by fresh
    integer ids and the originating graph node is stored as an attribute.
    The tree size can therefore grow exponentially with ``depth``.

    Parameters
    ----------
    graph : torch_geometric.data.Data
        The graph to extract the computation tree from. The neighbours of a
        node ``u`` are the entries of ``edge_index[1]`` at positions where
        ``edge_index[0] == u``.

    node : int
        The node to extract the computation tree from.

    depth : int
        The depth of the computation tree. ``depth <= 0`` yields only the root.

    Returns
    -------
    computation_tree : nx.DiGraph
        The computation tree of the given node. Tree node ids are integers
        assigned in breadth-first order (the root is ``0``). Each tree node
        carries the attributes:

        - ``node_id``: index of the corresponding node in ``graph``;
        - ``features``: ``graph.x[node_id]``, or ``None`` if ``graph.x`` is None;
        - ``depth``: distance of the tree node from the root.

        Edges point from child to parent.
    """
    node = int(node)
    x = graph.x

    # Build the adjacency list once (preserving edge_index order) instead of
    # rescanning the whole edge_index for every tree node on every level.
    adjacency = {}
    sources = graph.edge_index[0].tolist()
    targets = graph.edge_index[1].tolist()
    for src, dst in zip(sources, targets):
        adjacency.setdefault(src, []).append(dst)

    def _features(graph_node):
        # Graphs without node features still yield a tree; features are None.
        return None if x is None else x[graph_node]

    # initialize the computation tree and add the root node with its features
    computation_tree = nx.DiGraph()
    computation_tree.add_node(0, node_id=node, features=_features(node), depth=0)

    # (tree id, graph node) pairs on the deepest level built so far; only
    # these are expanded, so no leaf test over the whole tree is needed.
    frontier = [(0, node)]
    next_id = 1
    for level in range(1, depth + 1):
        if not frontier:
            break
        next_frontier = []
        for parent_id, parent_node in frontier:
            for neighbor in adjacency.get(parent_node, ()):
                # Every neighbour becomes a new tree node, even if the same
                # graph node already appears elsewhere in the tree.
                computation_tree.add_node(
                    next_id,
                    node_id=neighbor,
                    features=_features(neighbor),
                    depth=level,
                )
                computation_tree.add_edge(next_id, parent_id)
                next_frontier.append((next_id, neighbor))
                next_id += 1
        frontier = next_frontier

    return computation_tree


In [None]:
# visualize the first graph in the mutag dataset using networkx
graph = mutag[0]
G = nx.Graph()
G.add_nodes_from(range(graph.num_nodes))
# edge_index has shape (2, num_edges); transpose to a list of [src, dst] pairs.
# If an undirected edge is listed in both directions, nx.Graph keeps it once.
edges = graph.edge_index.t().tolist()
G.add_edges_from(edges)
fig, ax = plt.subplots()
nx.draw(G, ax=ax, with_labels=True)
ax.set_title("MUTAG[0]")
plt.show()

In [None]:
# compute the computation tree of the first node in the graph
computation_tree = get_computation_tree(graph, node=0, depth=2)
# Label each tree node with the graph node it stands for when that is recorded
# in a ``node_id`` attribute; otherwise fall back to the tree node key itself.
labels = {
    tree_node: attrs.get("node_id", tree_node)
    for tree_node, attrs in computation_tree.nodes(data=True)
}
fig, ax = plt.subplots()
nx.draw(computation_tree, ax=ax, labels=labels, with_labels=True)
ax.set_title("Depth-2 computation tree of node 0 in MUTAG[0]")
plt.show()