In [None]:
import pandas as pd
import numpy as np
import torch
import ot
import copy
from tqdm import tqdm
import pickle
import torch_geometric.transforms as T
import matplotlib.pyplot as plt

from torch_geometric.datasets import TUDataset

import networkx as nx

# Load the three TU benchmark datasets as plain Python lists of graphs,
# all stored under the same local root directory.
DATA_ROOT = "data"
mutag = list(TUDataset(root=DATA_ROOT, name="MUTAG"))
enzymes = list(TUDataset(root=DATA_ROOT, name="ENZYMES"))
imdb = list(TUDataset(root=DATA_ROOT, name="IMDB-BINARY"))

# Overwrite the node features of every IMDB-BINARY graph with a constant
# all-ones column of shape (num_nodes, 1).
# NOTE(review): presumably needed because this dataset has no node
# attributes of its own -- confirm.
for graph in imdb:
    n = graph.num_nodes
    graph.x = torch.ones(n, 1)

In [None]:
def get_computation_tree(graph, node, depth):
    """
    Returns the computation tree of a given node in the graph
    up to a certain depth.

    The tree is the unrolled neighborhood of ``node``: the root is ``node``,
    its children are its neighbors, their children are the neighbors'
    neighbors (the parent included), and so on for ``depth`` levels. A
    neighbor of ``v`` is every ``edge_index[1][i]`` with
    ``edge_index[0][i] == v``. Edges point from child to parent. Because a
    graph node can occur many times in the tree, the number of tree nodes
    can grow exponentially with ``depth``.

    Parameters
    ----------
    graph : torch_geometric.data.Data
        The graph to extract the computation tree from. Only ``graph.x``
        and ``graph.edge_index`` are read.

    node : int
        The node to extract the computation tree from.

    depth : int
        The depth of the computation tree. Must be non-negative; with
        ``depth == 0`` the tree consists of the root only.

    Returns
    -------
    computation_tree : nx.DiGraph
        The computation tree of the given node. The root is keyed by
        ``node``; every other tree node is keyed by the tuple of graph node
        indices on the path from the root to it, e.g. ``(node, 3, 7)``.
        Each tree node carries the attributes ``features`` (the matching
        row of ``graph.x``) and ``index`` (the graph node it is a copy of).

    Raises
    ------
    ValueError
        If ``depth`` is negative.
    """
    if depth < 0:
        raise ValueError(f"depth must be non-negative, got {depth}")

    # Work with plain Python ints: 0-d tensors hash by object identity and
    # therefore make unreliable networkx node keys.
    root = int(node)

    # Build the adjacency lists once instead of masking edge_index for every
    # expanded tree node.
    neighbors_of = {}
    for source, target in zip(*graph.edge_index.tolist()):
        neighbors_of.setdefault(source, []).append(target)

    # initialize the computation tree and add the root node with its features
    computation_tree = nx.DiGraph()
    computation_tree.add_node(root, features=graph.x[root], index=root)

    # Unroll level by level. Each frontier entry is (tree key, path from the
    # root); the last element of the path is the graph node being expanded.
    frontier = [(root, (root,))]
    for _ in range(depth):
        next_frontier = []
        for parent_key, parent_path in frontier:
            for neighbor in neighbors_of.get(parent_path[-1], ()):
                child_key = parent_path + (neighbor,)
                # A repeated edge in edge_index would produce the same child
                # twice; keep a single copy and expand it only once.
                if child_key in computation_tree:
                    continue
                computation_tree.add_node(
                    child_key, features=graph.x[neighbor], index=neighbor
                )
                # messages flow from the neighbor (child) to the node (parent)
                computation_tree.add_edge(child_key, parent_key)
                next_frontier.append((child_key, child_key))
        frontier = next_frontier

    return computation_tree

In [None]:
# Draw the first MUTAG graph as an undirected networkx graph.
graph = mutag[0]

# Transposing edge_index turns it into one [source, target] pair per edge.
edges = graph.edge_index.t().tolist()

G = nx.Graph()
# Nodes are registered explicitly so that nodes without edges are drawn too.
G.add_nodes_from(range(graph.num_nodes))
G.add_edges_from(map(tuple, edges))

nx.draw(G, with_labels=True)
plt.show()

In [None]:
# Build the depth-2 computation tree rooted at node 0 of `graph` (the graph
# assigned in the visualization cell above) and draw it with node labels.
computation_tree = get_computation_tree(graph, node=0, depth=2)
nx.draw(computation_tree, with_labels=True)
plt.show()