In [67]:
import json
import os
from time import sleep

import pandas as pd
from graphdatascience import GraphDataScience
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import random
import numpy as np

from src.shared.database_wrapper import DatabaseWrapper

import networkx as nx
import plotly.graph_objects as go
from neo4j import GraphDatabase, Result
from torch_geometric.data import HeteroData

from src.shared.graph_schema import NodeType, EdgeType, node_one_hot, edge_one_hot, edge_val_to_pyg_key_vals
from src.shared import config


random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

In [68]:
driver = GraphDatabase.driver(config.DB_URI, auth=(config.DB_USER, config.DB_PASSWORD))
gds = GraphDataScience(config.DB_URI, auth=(config.DB_USER, config.DB_PASSWORD))

In [69]:


def fetch_n_hop_neighbourhood(start_node_type: NodeType, start_node_id: str, node_attr: str, node_types: list = None, edge_types: list = None, max_level: int = 6):
    with driver.session() as session:
        node_filter = '|'.join(
            [nt.value for nt in NodeType] if node_types is None else 
            [nt.value for nt in node_types]
        )
        edge_filter = '|'.join(
            [f"<{et.value}" for et in EdgeType] if edge_types is None else 
            [f"<{et.value}" for et in edge_types]
        )
        
        query = f"""
                MATCH (start:{start_node_type.value} {{id: '{start_node_id}'}})
                CALL apoc.path.subgraphAll(start, {{
                  maxLevel: {max_level},
                  relationshipFilter: '{edge_filter}',
                  labelFilter: '+{node_filter}'
                }}) YIELD nodes, relationships
                RETURN nodes, relationships
            """
        result = session.run(query)
        data = result.single()
        nodes = data["nodes"]
        relationships = data["relationships"]

        # Process nodes
        node_data = []
        for node in nodes:
            node_id = node.get("id")
            attr = node.get(node_attr, None)
            node_data.append({"nodeId": node_id, node_attr: attr, "nodeLabels": list(node.labels)})
        
        node_df = pd.DataFrame(node_data)
        
        # Process relationships
        edge_dict = {}
        for rel in relationships:
            if rel.type not in edge_dict:
                edge_dict[rel.type] = [[], []]
            source_id = rel.start_node.get("id")
            target_id = rel.end_node.get("id")
            
            edge_dict[rel.type][0].append(source_id)
            edge_dict[rel.type][1].append(target_id)
    
    return node_df, edge_dict

In [70]:
def normalize_topology(new_idx_to_old, topology):
    # Reverse index mapping based on new idx -> old idx
    old_idx_to_new = dict((v, k) for k, v in new_idx_to_old.items())
    return {rel_type: [[old_idx_to_new[node_id] for node_id in nodes] for nodes in topology] for rel_type, topology in topology.items()}

def create_edge_index(topology):
    edge_index = []
    edge_features = []
    for rel_type, nodes in topology.items():
        src_nodes, dst_nodes = nodes
        edges = torch.tensor([src_nodes, dst_nodes], dtype=torch.long)
        edge_index.append(edges)
        edge_feature_vec = edge_one_hot[rel_type]
        edge_features.extend([edge_feature_vec for _ in range(len(src_nodes))])
    return torch.cat(edge_index, dim=1), torch.vstack(edge_features)

def project_node_embeddings(node_df):
    def stack_one_hot(row):
        one_hot_enc = node_one_hot[row["nodeLabels"][0]]
        return torch.hstack((one_hot_enc, torch.tensor(row["vec"])))
    return node_df.apply(stack_one_hot, axis=1)

In [71]:
included_nodes = [
    NodeType.PUBLICATION, 
    NodeType.VENUE, 
    NodeType.ORGANIZATION,
    NodeType.AUTHOR,
    NodeType.CO_AUTHOR
]
included_edges = [
    EdgeType.PUB_VENUE,
    EdgeType.VENUE_PUB,
    EdgeType.PUB_ORG,
    EdgeType.ORG_PUB, 
    EdgeType.PUB_AUTHOR,
    EdgeType.AUTHOR_PUB,
    EdgeType.AUTHOR_ORG,
    EdgeType.ORG_AUTHOR,
    EdgeType.PUB_ORG,
    EdgeType.ORG_PUB,
]

def sample_subgraph(node_list):
    dataset = []
    for node_id in node_list:
        node_df, topology = fetch_n_hop_neighbourhood(
            start_node_type=NodeType.PUBLICATION, 
            start_node_id=node_id, 
            node_attr="vec",
            node_types=included_nodes,
            edge_types=included_edges,
            max_level=5
        )
        node_df["vec_projected"] = project_node_embeddings(node_df)
        normalized_node_ids = {new_idx: old_idx for new_idx, old_idx in enumerate(node_df["nodeId"])}
        normalized_topology = normalize_topology(normalized_node_ids, topology)
        if len(normalized_topology) == 0:
            continue
            
        edge_index, edge_features = create_edge_index(normalized_topology)
        node_features = torch.vstack(node_df["vec_projected"].tolist())
        
        dataset.append(Data(
            x=node_features,
            edge_index=edge_index,
            edge_attr=edge_features
        ))
    return DataLoader(dataset)

In [72]:
def visualize_n_hop_neighbourhood(start_node_type: NodeType, start_node_id: str, node_attr: str, node_types: list = None, edge_types: list = None, max_level: int = 6):
    with driver.session() as session:
        node_filter = '|'.join(
            [nt.value for nt in NodeType] if node_types is None else 
            [nt.value for nt in node_types]
        )
        edge_filter = '|'.join(
            [f"<{et.value}" for et in EdgeType] if edge_types is None else 
            [f"<{et.value}" for et in edge_types]
        )
        
        query = f"""
                MATCH (start:{start_node_type.value} {{id: '{start_node_id}'}})
                CALL apoc.path.subgraphAll(start, {{
                  maxLevel: {max_level},
                  relationshipFilter: '{edge_filter}',
                  labelFilter: '+{node_filter}'
                }}) YIELD nodes, relationships
                RETURN nodes, relationships
            """
        result = session.run(query)
        data = result.single()
        nodes = data["nodes"]
        relationships = data["relationships"]

        G = nx.Graph()

        for node in nodes:
            node_id = node.get("id")
            attr = node.get(node_attr, None)
            G.add_node(node_id, label=node.labels, vec=attr)

        for rel in relationships:
            source_id = rel.start_node.get("id")
            target_id = rel.end_node.get("id")
            G.add_edge(source_id, target_id, type=rel.type)

        return G

In [73]:
def plot_graph(G):
    pos = nx.spring_layout(G)

    edge_trace = []
    annotations = []

    for edge in G.edges(data=True):
        x0, y0 = pos[edge[0]]
        x1, y1 = pos[edge[1]]

        edge_trace.append(go.Scatter(
            x=[x0, x1, None],
            y=[y0, y1, None],
            line=dict(width=1, color='#888'),
            hoverinfo='none',
            mode='lines'))

        annotations.append(
            dict(
                ax=x0, ay=y0,
                x=x1, y=y1,
                xref='x', yref='y',
                axref='x', ayref='y',
                showarrow=True,
                arrowhead=3,
                arrowsize=1.5,
                arrowwidth=1,
                arrowcolor='#888'
            )
        )

    node_trace = go.Scatter(
        x=[],
        y=[],
        text=[],
        mode='markers+text',
        hoverinfo='text',
        marker=dict(
            showscale=True,
            colorscale='YlGnBu',
            reversescale=True,
            color=[],
            size=10,
            colorbar=dict(
                thickness=15,
                title='Node Connections',
                xanchor='left',
                titleside='right'
            ),
        ))

    for node in G.nodes(data=True):
        x, y = pos[node[0]]
        node_trace['x'] += tuple([x])
        node_trace['y'] += tuple([y])

        node_info = f"{node[0]}<br>{list(node[1]['label'])[0]}"
        node_trace['text'] += tuple([node_info])
        node_trace['marker']['color'] += tuple([len(G[node[0]])])

    fig = go.Figure(
        data=edge_trace + [node_trace],
        layout=go.Layout(
            title='<br>Subgraph',
            titlefont=dict(size=16),
            showlegend=False,
            hovermode='closest',
            margin=dict(b=0, l=0, r=0, t=0),
            annotations=annotations,
            xaxis=dict(showgrid=False, zeroline=False),
            yaxis=dict(showgrid=False, zeroline=False)
        )
    )

    fig.show()

db_wrapper = DatabaseWrapper()
"""
for nodes in db_wrapper.iter_nodes(NodeType.PUBLICATION, ["id"]):
    for node in nodes:
        G = visualize_n_hop_neighbourhood(
            start_node_type=NodeType.PUBLICATION, 
            start_node_id=node.get("id"), 
            node_attr="vec",
            node_types=included_nodes,
            edge_types=included_edges,
            max_level=2
        )
        plot_graph(G)
        sleep(5)
        
"""

2024-09-08 08:40:28,174 - DatabaseWrapper - INFO - Connecting to the database ...
2024-09-08 08:40:28,181 - DatabaseWrapper - INFO - Database ready.


'\nfor nodes in db_wrapper.iter_nodes(NodeType.PUBLICATION, ["id"]):\n    for node in nodes:\n        G = visualize_n_hop_neighbourhood(\n            start_node_type=NodeType.PUBLICATION, \n            start_node_id=node.get("id"), \n            node_attr="vec",\n            node_types=included_nodes,\n            edge_types=included_edges,\n            max_level=2\n        )\n        plot_graph(G)\n        sleep(5)\n        \n'

In [74]:
node_ids = []# ["wgKatLxf"]
for node_id in node_ids:
    G = visualize_n_hop_neighbourhood(
        start_node_type=NodeType.PUBLICATION, 
        start_node_id=node_id, 
        node_attr="vec",
        node_types=included_nodes,
        edge_types=included_edges,
        max_level=2
    )
    print(len(G.nodes) * 37)
    plot_graph(G)
    sleep(5)

In [75]:
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib as mpl
from torch_geometric.utils import to_networkx

def visualize_heterodata(data, node_colors=None, node_size=300, font_size=12):
    # Convert HeteroData to NetworkX graph
    G = to_networkx(data, node_attrs=['x'], edge_attrs=['edge_attr'])

    # Create a color map for the nodes based on type
    if node_colors is None:
        cmap = mpl.colormaps.get_cmap('tab20')
        node_colors = {key: cmap(i % cmap.N) for i, key in enumerate(data.node_types)}

    color_map = []
    for node in G.nodes(data=True):
        node_type = node[1]['type']
        color_map.append(node_colors[node_type])
    
    pos = nx.spring_layout(G)
    
    plt.figure(figsize=(12, 8))
    nx.draw(G, pos, with_labels=True, node_color=color_map, node_size=node_size, font_size=font_size, cmap=plt.get_cmap('tab20'))

    for node_type, color in node_colors.items():
        plt.scatter([], [], c=[color], label=node_type, s=node_size)
    plt.legend(scatterpoints=1, frameon=False, labelspacing=1, loc='upper left')

    plt.show()

In [76]:
class GraphSampling:
    def __init__(self, node_spec: list, relationship_spec: list, node_properties: list):
        self.driver = GraphDatabase.driver(config.DB_URI, auth=(config.DB_USER, config.DB_PASSWORD))
        self.node_spec = node_spec
        self.relationship_spec = relationship_spec
        self.node_properties = node_properties
        
    def spanning_tree(
            self,
            start_node_type: NodeType,
            start_node_id: str,
            node_types: list = None,
            edge_types: list = None,
            max_level: int = 3,
            limit: int = 300
    ):

        with self.driver.session() as session:
            node_filter = '|'.join(
                [nt.value for nt in NodeType] if node_types is None else
                [nt.value for nt in node_types]
            )
            edge_filter = '|'.join(
                [f"<{et.value}" for et in EdgeType] if edge_types is None else
                [f"<{et.value}" for et in edge_types]
            )

            query = f"""
                MATCH (start:{start_node_type.value} {{id: '{start_node_id}'}})
                CALL apoc.path.spanningTree(start, {{
                    maxLevel: {max_level},
                    limit: {limit}, 
                    relationshipFilter: '{edge_filter}',
                    labelFilter: '+{node_filter}'
                }}) YIELD path
                WITH apoc.coll.flatten(collect(nodes(path))) AS allNodes, apoc.coll.flatten(collect(relationships(path))) AS allRels
                UNWIND allNodes AS node
                UNWIND allRels AS rel
                WITH collect(DISTINCT node) AS nodes, collect(DISTINCT rel) AS relationships
                RETURN 
                  [node IN nodes | {{id: node.id, labels: labels(node), vec: node.vec}}] AS nodes, 
                    [rel IN relationships | {{start_node: {{id: startNode(rel).id}}, type: type(rel), end_node: {{endNode(rel).id}}}}] AS relationships
            """
            result = session.run(query)
            data = result.data()
            return data
            
    
    @staticmethod
    def neo_to_pyg(
        data,
        node_attr: str
    ):
        if not data:
            return None, None

        nodes = data["nodes"]
        relationships = data["relationships"]

        print(f"Nodes: {len(nodes)}, Relationships: {len(relationships)}")
        if len(nodes) > 500:
            print(f"Too many nodes: {len(nodes)}")
            return None, None

        # Create data object
        h_data = HeteroData()

        node_features = {}
        node_ids = {}
        node_id_map = {}

        for node in nodes:
            node_id = node.get("id")
            node_feature = node.get(node_attr, None)
            if node_feature is None:
                print(f"Node {node_id} has no attribute {node_attr}")
                continue
            node_label = list(node.labels)[0]
            if node_label not in node_features:
                node_features[node_label] = []
                node_ids[node_label] = []

            # Convert node features to tensors
            node_features[node_label].append(torch.tensor(node_feature, dtype=torch.float32))
            node_ids[node_label].append(node_id)

            # Map node ID to its index in the list
            node_id_map[node_id] = len(node_ids[node_label]) - 1

        # Convert list of features to a single tensor per node type
        for node_label, node_features in node_features.items():
            h_data[node_label].x = torch.vstack(node_features)

        # Process relationships
        edge_dict = {}

        for rel in relationships:
            key = edge_val_to_pyg_key_vals[rel.type]  # edge_val_to_pyg_key_vals maps edge types to tuples (src, dst)
            if key not in edge_dict:
                edge_dict[key] = [[], []]

            source_id = rel.start_node.get("id")
            target_id = rel.end_node.get("id")

            # Append the indices of the source and target nodes
            edge_dict[key][0].append(node_id_map[source_id])
            edge_dict[key][1].append(node_id_map[target_id])

        # Convert edge lists to tensors
        for key in edge_dict:
            h_data[key[0], key[1], key[2]].edge_index = torch.vstack([
                torch.tensor(edge_dict[key][0], dtype=torch.long),
                torch.tensor(edge_dict[key][1], dtype=torch.long)
            ])

            h_data[key[0], key[1], key[2]].edge_attr = torch.vstack(
                [edge_one_hot[key[1]] for _ in range(len(edge_dict[key][0]))])

        return h_data, node_id_map
    
    def __del__(self):
        self.driver.close()
            
"""gs = GraphSampling(
    node_spec=[n.value for n in included_nodes], 
    relationship_spec=[e.value for e in included_edges], 
    node_properties=["vec"]
)
result = gs.spanning_tree(
    start_node_type=NodeType.PUBLICATION, 
    start_node_id="wgKatLxf", 
    node_types=included_nodes,
    edge_types=included_edges,
    max_level=2
)
for res in result:
    print(len(res["nodes"]), len(res["relationships"]))
    for node in res["nodes"]:
        print(node["labels"], node["id"])
    for rel in res["relationships"]:
        print(rel)
#h_data, node_id_map = GraphSampling.neo_to_pyg(result[0], "vec")
#visualize_heterodata(h_data)
"""

'gs = GraphSampling(\n    node_spec=[n.value for n in included_nodes], \n    relationship_spec=[e.value for e in included_edges], \n    node_properties=["vec"]\n)\nresult = gs.spanning_tree(\n    start_node_type=NodeType.PUBLICATION, \n    start_node_id="wgKatLxf", \n    node_types=included_nodes,\n    edge_types=included_edges,\n    max_level=2\n)\nfor res in result:\n    print(len(res["nodes"]), len(res["relationships"]))\n    for node in res["nodes"]:\n        print(node["labels"], node["id"])\n    for rel in res["relationships"]:\n        print(rel)\n#h_data, node_id_map = GraphSampling.neo_to_pyg(result[0], "vec")\n#visualize_heterodata(h_data)\n'

In [77]:
from src.shared.neo_to_pyg import GraphSampling
sampler = GraphSampling(
    node_spec=[n.value for n in included_nodes], 
    relationship_spec=[e.value for e in included_edges], 
    node_properties=["vec"]
)
result = sampler.n_hop_neighbourhood(
    start_node_type=NodeType.PUBLICATION, 
    start_node_id="wgKatLxf", 
    node_types=included_nodes,
    edge_types=included_edges,
    max_level=2
)
#print(json.dumps(result, indent=2))
h_data, node_id_map = GraphSampling.neo_to_pyg(result, "vec")
visualize_heterodata(h_data)

Nodes: 40247, Relationships: 143856
Too many nodes: 40247


AttributeError: 'NoneType' object has no attribute 'node_offsets'