In [20]:
import os
import pandas as pd
from graphdatascience import GraphDataScience
import torch
from torch_geometric.data import Data
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.transforms import RandomNodeSplit
import random
import numpy as np

from src.shared.database_wrapper import DatabaseWrapper
from src.model.GAT.gat_encoder import GATv2Encoder
from src.model.GAT.gat_decoder import GATv2Decoder
from src.shared.graph_schema import NodeType, EdgeType, node_one_hot, edge_one_hot
from src.shared import config

In [21]:
# Set seeds for consistent results
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

In [22]:
#def project_single(n):
#    return torch.hstack((node_to_one_hot[list(n.labels)[0]], torch.tensor(n['vec'])))



In [23]:
auth = (config.DB_USER, config.DB_PASSWORD)
gds = GraphDataScience(config.DB_URI, auth=auth)

In [24]:
excluded_nodes = [NodeType.TRUE_AUTHOR, NodeType.AUTHOR, NodeType.CO_AUTHOR]
included_edges = [
    EdgeType.PUB_VENUE, 
    EdgeType.PUB_ORG, 
    EdgeType.SIM_VENUE,
    EdgeType.SIM_ORG,
    EdgeType.ORG_PUB, 
    EdgeType.VENUE_PUB
]

node_spec = [node_type.value for node_type in NodeType if node_type not in excluded_nodes]
relationship_spec = [edge_type.value for edge_type in EdgeType if edge_type in included_edges]
print(node_spec)
print(relationship_spec)

['Publication', 'Organization', 'Venue']
['SimilarOrg', 'SimilarVenue', 'PubVenue', 'PubOrg', 'OrgPub', 'VenuePub']


In [25]:
gds.graph.drop('graph_sample')
G, _ = gds.graph.project(
    graph_name='graph_sample',
    node_spec=node_spec,
    relationship_spec=relationship_spec,
    nodeProperties=['vec']
)
print(G)

Graph(name=graph_sample, node_count=2313, relationship_count=6354)


In [26]:
configuration = {
    "concurrency": 1,
    "randomSeed": 42,
    #"start_nodes": [node['id']],
}
gds.graph.drop("graph_rwr")
G_sample, _ = gds.graph.sample.rwr("graph_rwr", G, configuration=configuration)
sample_topology_df = gds.beta.graph.relationships.stream(G_sample)
display(sample_topology_df)

Unnamed: 0,sourceNodeId,targetNodeId,relationshipType
0,6998,7002,PubOrg
1,6998,6999,PubVenue
2,6999,6998,VenuePub
3,6999,7011,VenuePub
4,6999,10319,VenuePub
...,...,...,...
1232,13173,13170,OrgPub
1233,13181,8767,PubOrg
1234,13181,8765,PubVenue
1235,13194,8359,PubOrg


In [27]:
sample_node_properties = gds.graph.nodeProperties.stream(
    G_sample,
    node_properties=["vec"],
    node_labels=["*"],
    separate_property_columns=True,
    listNodeLabels=True
)
display(sample_node_properties)

Unnamed: 0,nodeId,vec,nodeLabels
0,6998,"[4.177231788635254, -3.409796953201294, 0.4951...",[Publication]
1,6999,"[-1.361520528793335, -3.1956822872161865, 2.14...",[Venue]
2,7002,"[-6.341586589813232, -6.61083459854126, 1.9404...",[Organization]
3,7011,"[-15.870619773864746, -6.40643835067749, -1.65...",[Publication]
4,7017,"[-17.103137969970703, -3.749770164489746, -3.2...",[Publication]
...,...,...,...
342,13170,"[-10.912079811096191, -6.778702259063721, -1.3...",[Publication]
343,13171,"[-5.502167701721191, -5.071459770202637, 1.299...",[Venue]
344,13173,"[-6.176707744598389, -5.64370584487915, 1.4108...",[Organization]
345,13181,"[-14.853845596313477, -6.310136795043945, 2.97...",[Publication]


In [28]:
# Map neo4j node ids to PyG node indices
sample_topology = sample_topology_df.by_rel_type()
#print(sample_topology)

In [29]:
def normalize_topology(new_idx_to_old, topology):
    # Reverse index mapping based on new idx -> old idx
    old_idx_to_new = dict((v, k) for k, v in new_idx_to_old.items())
    return {rel_type: [[old_idx_to_new[node_id] for node_id in nodes] for nodes in topology] for rel_type, topology in topology.items()}

def create_edge_index(topology):
    edge_index = []
    for rel_type, nodes in topology.items():
        src_nodes, dst_nodes = nodes
        edges = torch.tensor([src_nodes, dst_nodes], dtype=torch.long)
        edge_index.append(edges)
    return torch.cat(edge_index, dim=1).t().contiguous() # TODO: REMOVE .t() and .contiguous() if not needed


normalized_topology = normalize_topology(dict(sample_node_properties["nodeId"]), sample_topology)
edge_index = create_edge_index(normalized_topology)

display(edge_index)

tensor([[  2,   0],
        [  2,   3],
        [  2, 245],
        ...,
        [336, 335],
        [339, 338],
        [343, 342]])

In [30]:
def project_node_embeddings(node_df):
    def stack_one_hot(row):
        one_hot_enc = node_one_hot[row["nodeLabels"][0]]
        return torch.hstack((one_hot_enc, torch.tensor(row["vec"])))
    return node_df.apply(stack_one_hot, axis=1)

sample_node_properties["vec_projected"] = project_node_embeddings(sample_node_properties)
display(sample_node_properties)

Unnamed: 0,nodeId,vec,nodeLabels,projected_vec
0,6998,"[4.177231788635254, -3.409796953201294, 0.4951...",[Publication],"[tensor(1.), tensor(0.), tensor(0.), tensor(0...."
1,6999,"[-1.361520528793335, -3.1956822872161865, 2.14...",[Venue],"[tensor(0.), tensor(0.), tensor(0.), tensor(0...."
2,7002,"[-6.341586589813232, -6.61083459854126, 1.9404...",[Organization],"[tensor(0.), tensor(0.), tensor(0.), tensor(0...."
3,7011,"[-15.870619773864746, -6.40643835067749, -1.65...",[Publication],"[tensor(1.), tensor(0.), tensor(0.), tensor(0...."
4,7017,"[-17.103137969970703, -3.749770164489746, -3.2...",[Publication],"[tensor(1.), tensor(0.), tensor(0.), tensor(0...."
...,...,...,...,...
342,13170,"[-10.912079811096191, -6.778702259063721, -1.3...",[Publication],"[tensor(1.), tensor(0.), tensor(0.), tensor(0...."
343,13171,"[-5.502167701721191, -5.071459770202637, 1.299...",[Venue],"[tensor(0.), tensor(0.), tensor(0.), tensor(0...."
344,13173,"[-6.176707744598389, -5.64370584487915, 1.4108...",[Organization],"[tensor(0.), tensor(0.), tensor(0.), tensor(0...."
345,13181,"[-14.853845596313477, -6.310136795043945, 2.97...",[Publication],"[tensor(1.), tensor(0.), tensor(0.), tensor(0...."
