In [1]:
import os
from time import sleep

import pandas as pd
from graphdatascience import GraphDataScience
from neo4j import GraphDatabase
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv
from torch_geometric.transforms import RandomNodeSplit
import random
import numpy as np

from src.shared.database_wrapper import DatabaseWrapper
from src.model.GAT.gat_encoder import GATv2Encoder
from src.model.GAT.gat_decoder import GATv2Decoder
from src.shared.graph_schema import NodeType, EdgeType, node_one_hot, edge_one_hot
from src.shared import config

import networkx as nx
import plotly.graph_objects as go

random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

In [2]:
driver = GraphDatabase.driver(config.DB_URI, auth=(config.DB_USER, config.DB_PASSWORD))
gds = GraphDataScience(config.DB_URI, auth=(config.DB_USER, config.DB_PASSWORD))

Unable to connect to the Neo4j DBMS. Trying again...
Unable to connect to the Neo4j DBMS. Trying again...
Unable to connect to the Neo4j DBMS. Trying again...
Unable to connect to the Neo4j DBMS. Trying again...
Unable to connect to the Neo4j DBMS. Trying again...
Unable to connect to the Neo4j DBMS. Trying again...


KeyboardInterrupt: 

In [65]:


def fetch_n_hop_neighbourhood(start_node_type: NodeType, start_node_id: str, node_attr: str, node_types: list = None, edge_types: list = None, max_level: int = 6):
    with driver.session() as session:
        node_filter = '|'.join(
            [nt.value for nt in NodeType] if node_types is None else 
            [nt.value for nt in node_types]
        )
        edge_filter = '|'.join(
            [f"<{et.value}" for et in EdgeType] if edge_types is None else 
            [f"<{et.value}" for et in edge_types]
        )
        
        query = f"""
                MATCH (start:{start_node_type.value} {{id: '{start_node_id}'}})
                CALL apoc.path.subgraphAll(start, {{
                  maxLevel: {max_level},
                  relationshipFilter: '{edge_filter}',
                  labelFilter: '+{node_filter}'
                }}) YIELD nodes, relationships
                RETURN nodes, relationships
            """
        result = session.run(query)
        data = result.single()
        nodes = data["nodes"]
        relationships = data["relationships"]

        # Process nodes
        node_data = []
        for node in nodes:
            node_id = node.get("id")
            attr = node.get(node_attr, None)
            node_data.append({"nodeId": node_id, node_attr: attr, "nodeLabels": list(node.labels)})
        
        node_df = pd.DataFrame(node_data)
        
        # Process relationships
        edge_dict = {}
        for rel in relationships:
            if rel.type not in edge_dict:
                edge_dict[rel.type] = [[], []]
            source_id = rel.start_node.get("id")
            target_id = rel.end_node.get("id")
            
            edge_dict[rel.type][0].append(source_id)
            edge_dict[rel.type][1].append(target_id)
    
    return node_df, edge_dict

In [66]:
def normalize_topology(new_idx_to_old, topology):
    # Reverse index mapping based on new idx -> old idx
    old_idx_to_new = dict((v, k) for k, v in new_idx_to_old.items())
    return {rel_type: [[old_idx_to_new[node_id] for node_id in nodes] for nodes in topology] for rel_type, topology in topology.items()}

def create_edge_index(topology):
    edge_index = []
    edge_features = []
    for rel_type, nodes in topology.items():
        src_nodes, dst_nodes = nodes
        edges = torch.tensor([src_nodes, dst_nodes], dtype=torch.long)
        edge_index.append(edges)
        edge_feature_vec = edge_one_hot[rel_type]
        edge_features.extend([edge_feature_vec for _ in range(len(src_nodes))])
    return torch.cat(edge_index, dim=1), torch.vstack(edge_features)

def project_node_embeddings(node_df):
    def stack_one_hot(row):
        one_hot_enc = node_one_hot[row["nodeLabels"][0]]
        return torch.hstack((one_hot_enc, torch.tensor(row["vec"])))
    return node_df.apply(stack_one_hot, axis=1)

In [67]:
included_nodes = [
    NodeType.PUBLICATION, 
    NodeType.VENUE, 
    NodeType.ORGANIZATION,
    NodeType.AUTHOR,
    NodeType.CO_AUTHOR
]
included_edges = [
    EdgeType.PUB_VENUE, 
    EdgeType.PUB_ORG, 
    EdgeType.SIM_VENUE,
    EdgeType.SIM_ORG,
    EdgeType.ORG_PUB, 
    EdgeType.VENUE_PUB,
    EdgeType.PUB_AUTHOR,
    EdgeType.AUTHOR_PUB,
    EdgeType.AUTHOR_CO_AUTHOR,
    EdgeType.CO_AUTHOR_AUTHOR,
    EdgeType.PUB_CO_AUTHOR,
    EdgeType.CO_AUTHOR_PUB,
    EdgeType.AUTHOR_ORG,
    EdgeType.ORG_AUTHOR,
    EdgeType.CO_AUTHOR_ORG,
    EdgeType.ORG_CO_AUTHOR
]

def sample_subgraph(node_list):
    dataset = []
    for node_id in node_list:
        node_df, topology = fetch_n_hop_neighbourhood(
            start_node_type=NodeType.PUBLICATION, 
            start_node_id=node_id, 
            node_attr="vec",
            node_types=included_nodes,
            edge_types=included_edges,
            max_level=5
        )
        node_df["vec_projected"] = project_node_embeddings(node_df)
        normalized_node_ids = {new_idx: old_idx for new_idx, old_idx in enumerate(node_df["nodeId"])}
        normalized_topology = normalize_topology(normalized_node_ids, topology)
        if len(normalized_topology) == 0:
            continue
            
        edge_index, edge_features = create_edge_index(normalized_topology)
        node_features = torch.vstack(node_df["vec_projected"].tolist())
        
        dataset.append(Data(
            x=node_features,
            edge_index=edge_index,
            edge_attr=edge_features
        ))
    return DataLoader(dataset)

In [68]:
def visualize_n_hop_neighbourhood(start_node_type: NodeType, start_node_id: str, node_attr: str, node_types: list = None, edge_types: list = None, max_level: int = 6):
    with driver.session() as session:
        node_filter = '|'.join(
            [nt.value for nt in NodeType] if node_types is None else 
            [nt.value for nt in node_types]
        )
        edge_filter = '|'.join(
            [f"<{et.value}" for et in EdgeType] if edge_types is None else 
            [f"<{et.value}" for et in edge_types]
        )
        
        query = f"""
                MATCH (start:{start_node_type.value} {{id: '{start_node_id}'}})
                CALL apoc.path.subgraphAll(start, {{
                  maxLevel: {max_level},
                  relationshipFilter: '{edge_filter}',
                  labelFilter: '+{node_filter}'
                }}) YIELD nodes, relationships
                RETURN nodes, relationships
            """
        result = session.run(query)
        data = result.single()
        nodes = data["nodes"]
        relationships = data["relationships"]

        G = nx.Graph()

        for node in nodes:
            node_id = node.get("id")
            attr = node.get(node_attr, None)
            G.add_node(node_id, label=node.labels, vec=attr)

        for rel in relationships:
            source_id = rel.start_node.get("id")
            target_id = rel.end_node.get("id")
            G.add_edge(source_id, target_id, type=rel.type)

        return G

In [69]:
def plot_graph(G):
    pos = nx.spring_layout(G)

    edge_trace = []
    annotations = []

    for edge in G.edges(data=True):
        x0, y0 = pos[edge[0]]
        x1, y1 = pos[edge[1]]

        edge_trace.append(go.Scatter(
            x=[x0, x1, None],
            y=[y0, y1, None],
            line=dict(width=1, color='#888'),
            hoverinfo='none',
            mode='lines'))

        annotations.append(
            dict(
                ax=x0, ay=y0,
                x=x1, y=y1,
                xref='x', yref='y',
                axref='x', ayref='y',
                showarrow=True,
                arrowhead=3,
                arrowsize=1.5,
                arrowwidth=1,
                arrowcolor='#888'
            )
        )

    node_trace = go.Scatter(
        x=[],
        y=[],
        text=[],
        mode='markers+text',
        hoverinfo='text',
        marker=dict(
            showscale=True,
            colorscale='YlGnBu',
            reversescale=True,
            color=[],
            size=10,
            colorbar=dict(
                thickness=15,
                title='Node Connections',
                xanchor='left',
                titleside='right'
            ),
        ))

    for node in G.nodes(data=True):
        x, y = pos[node[0]]
        node_trace['x'] += tuple([x])
        node_trace['y'] += tuple([y])

        node_info = f"{node[0]}<br>{list(node[1]['label'])[0]}"
        node_trace['text'] += tuple([node_info])
        node_trace['marker']['color'] += tuple([len(G[node[0]])])

    fig = go.Figure(
        data=edge_trace + [node_trace],
        layout=go.Layout(
            title='<br>Subgraph',
            titlefont=dict(size=16),
            showlegend=False,
            hovermode='closest',
            margin=dict(b=0, l=0, r=0, t=0),
            annotations=annotations,
            xaxis=dict(showgrid=False, zeroline=False),
            yaxis=dict(showgrid=False, zeroline=False)
        )
    )

    fig.show()

db_wrapper = DatabaseWrapper()
"""
for nodes in db_wrapper.iter_nodes(NodeType.PUBLICATION, ["id"]):
    for node in nodes:
        G = visualize_n_hop_neighbourhood(
            start_node_type=NodeType.PUBLICATION, 
            start_node_id=node.get("id"), 
            node_attr="vec",
            node_types=included_nodes,
            edge_types=included_edges,
            max_level=2
        )
        plot_graph(G)
        sleep(5)
        
"""

2024-08-18 20:04:58,776 - DatabaseWrapper - INFO - Connecting to the database ...
2024-08-18 20:04:58,784 - DatabaseWrapper - INFO - Database ready.


'\nfor nodes in db_wrapper.iter_nodes(NodeType.PUBLICATION, ["id"]):\n    for node in nodes:\n        G = visualize_n_hop_neighbourhood(\n            start_node_type=NodeType.PUBLICATION, \n            start_node_id=node.get("id"), \n            node_attr="vec",\n            node_types=included_nodes,\n            edge_types=included_edges,\n            max_level=2\n        )\n        plot_graph(G)\n        sleep(5)\n        \n'

In [70]:
node_ids = []# ["wgKatLxf"]
for node_id in node_ids:
    G = visualize_n_hop_neighbourhood(
        start_node_type=NodeType.PUBLICATION, 
        start_node_id=node_id, 
        node_attr="vec",
        node_types=included_nodes,
        edge_types=included_edges,
        max_level=2
    )
    print(len(G.nodes) * 37)
    plot_graph(G)
    sleep(5)