In [102]:
import os
import pandas as pd
from graphdatascience import GraphDataScience
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from torch_geometric.transforms import RandomNodeSplit
import random
import numpy as np

from src.shared.database_wrapper import DatabaseWrapper
from src.model.GAT.gat_encoder import GATv2Encoder
from src.model.GAT.gat_decoder import GATv2Decoder
from src.shared.graph_schema import NodeType, EdgeType, node_one_hot, edge_one_hot
from src.shared import config

In [103]:
# Seed every random number generator this notebook touches so repeated runs match
SEED = 42
for seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_rng(SEED)

In [104]:
#def project_single(n):
#    return torch.hstack((node_to_one_hot[list(n.labels)[0]], torch.tensor(n['vec'])))



In [105]:
# Open a Graph Data Science client; URI and credentials are read from the
# project's config module (src.shared.config) rather than written out here.
auth = (config.DB_USER, config.DB_PASSWORD)
gds = GraphDataScience(config.DB_URI, auth=auth)

In [106]:
# Node and relationship members selected for the graph projection
included_nodes = [NodeType.PUBLICATION, NodeType.VENUE, NodeType.ORGANIZATION]
included_edges = [
    EdgeType.PUB_VENUE,
    EdgeType.PUB_ORG,
    EdgeType.SIM_VENUE,
    EdgeType.SIM_ORG,
    EdgeType.ORG_PUB,
    EdgeType.VENUE_PUB,
]

# Collect the `.value` of every selected member; these two lists are what the
# projection call receives as node_spec / relationship_spec.
node_spec = [member.value for member in included_nodes]
relationship_spec = [member.value for member in included_edges]
print(node_spec, relationship_spec, sep="\n")

['Publication', 'Venue', 'Organization']
['PubVenue', 'PubOrg', 'SimilarVenue', 'SimilarOrg', 'OrgPub', 'VenuePub']


In [107]:
# Drop whatever is currently registered as 'graph_sample' — presumably so the
# projection below can reuse the name when this cell is re-run.
gds.graph.drop('graph_sample')
# Project the selected node labels and relationship types, carrying the 'vec'
# node property along. The second return value is not needed and is discarded.
G, _ = gds.graph.project(
    graph_name='graph_sample',
    node_spec=node_spec,
    relationship_spec=relationship_spec,
    nodeProperties=['vec']
)
print(G)

Graph(name=graph_sample, node_count=2313, relationship_count=6262)


In [108]:
# Sampler settings: single-threaded with a fixed random seed — presumably so the
# drawn sample is the same on every run (TODO confirm against the GDS docs).
configuration = {
    "concurrency": 1,
    "randomSeed": 42,
    #"start_nodes": [node['id']],
}
# Drop whatever is currently registered as 'graph_rwr', then sample a new graph
# of that name from G with the `sample.rwr` procedure.
gds.graph.drop("graph_rwr")
G_sample, _ = gds.graph.sample.rwr("graph_rwr", G, configuration=configuration)
# Stream the sampled graph's relationships back into a DataFrame and show it.
sample_topology_df = gds.beta.graph.relationships.stream(G_sample)
display(sample_topology_df)

Unnamed: 0,sourceNodeId,targetNodeId,relationshipType
0,20,56,SimilarOrg
1,20,70,SimilarOrg
2,20,204,SimilarOrg
3,20,330,SimilarOrg
4,20,477,SimilarOrg
...,...,...,...
1147,7607,7608,PubVenue
1148,7608,7607,VenuePub
1149,7652,1025,SimilarVenue
1150,7652,4301,SimilarVenue


In [109]:
# Stream the 'vec' property of the sampled graph's nodes into a DataFrame.
# node_labels=["*"] requests all labels; listNodeLabels=True asks for each
# node's labels as well (they are read from the "nodeLabels" column later on).
sample_node_properties = gds.graph.nodeProperties.stream(
    G_sample,
    node_properties=["vec"],
    node_labels=["*"],
    separate_property_columns=True,
    listNodeLabels=True
)
display(sample_node_properties)

Unnamed: 0,nodeId,vec,nodeLabels
0,20,"[-7.349286079406738, -6.137350082397461, 1.769...",[Organization]
1,56,"[-7.838522434234619, -6.077364444732666, 1.122...",[Organization]
2,68,"[-2.6143743991851807, -2.6950881481170654, -1....",[Venue]
3,70,"[-9.003606796264648, -5.801609039306641, 1.042...",[Organization]
4,100,"[-16.875104904174805, -6.252501010894775, -2.3...",[Publication]
...,...,...,...
342,7599,"[-12.758501052856445, -6.827425956726074, -2.9...",[Publication]
343,7600,"[-10.055512428283691, -5.252875328063965, -0.0...",[Venue]
344,7607,"[-12.508116722106934, -4.270413398742676, -3.1...",[Publication]
345,7608,"[0.7140827775001526, -5.582761287689209, 0.461...",[Venue]


In [110]:
# Group the streamed relationships by relationship type.
# NOTE(review): downstream code consumes this as
# {relationship type: [source node ids, target node ids]} with the ids still being
# Neo4j node ids (the remapping to PyG indices happens in normalize_topology) —
# confirm the exact return shape against the graphdatascience client docs.
sample_topology = sample_topology_df.by_rel_type()
#print(sample_topology)

In [111]:
def normalize_topology(new_idx_to_old, topology):
    """Rewrite every node id in ``topology`` using the inverse of ``new_idx_to_old``.

    Args:
        new_idx_to_old: Mapping from new index to original node id.
        topology: Mapping from relationship type to a list of node-id lists.

    Returns:
        A new dict with the same keys and nesting as ``topology`` in which each
        original node id has been replaced by its new index.

    Raises:
        KeyError: If ``topology`` contains an id that is not a value of
            ``new_idx_to_old``.
    """
    # Invert the lookup: original id -> new index
    old_idx_to_new = {old: new for new, old in new_idx_to_old.items()}

    normalized = {}
    for rel_type, node_lists in topology.items():
        normalized[rel_type] = [
            [old_idx_to_new[node_id] for node_id in node_list]
            for node_list in node_lists
        ]
    return normalized

def create_edge_index(topology):
    """Build the edge index and the per-edge feature matrix from a per-type topology.

    Args:
        topology: Mapping from relationship type to a ``(source_nodes,
            target_nodes)`` pair of node-index sequences.

    Returns:
        A tuple ``(edge_index, edge_features)``: the per-type ``[sources, targets]``
        long tensors concatenated along dim 1, and the vertical stack of
        ``edge_one_hot[rel_type]``, repeated once per edge, in the same order.
    """
    index_parts = []
    feature_rows = []
    for rel_type, (src_nodes, dst_nodes) in topology.items():
        index_parts.append(torch.tensor([src_nodes, dst_nodes], dtype=torch.long))
        # Every edge of this relationship type carries the same feature vector
        feature_rows += [edge_one_hot[rel_type]] * len(src_nodes)
    return torch.cat(index_parts, dim=1), torch.vstack(feature_rows)


# dict(Series) maps each row's index label to that row's nodeId. Passed as
# new_idx_to_old, it makes normalize_topology replace the node ids in the
# topology by the corresponding row labels of sample_node_properties.
normalized_topology = normalize_topology(dict(sample_node_properties["nodeId"]), sample_topology)
edge_index, edge_features = create_edge_index(normalized_topology)

display(edge_index)
display(edge_features)

tensor([[  3,   6,   6,  ..., 340, 343, 345],
        [179,   4,  13,  ..., 339, 342, 344]])

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 1.]])

In [112]:
def project_node_embeddings(node_df):
    """Prefix each row's ``vec`` with the one-hot encoding of the row's first label.

    Args:
        node_df: DataFrame with a ``nodeLabels`` column (indexable collection of
            labels per row) and a ``vec`` column (numeric vector per row).

    Returns:
        The result of ``node_df.apply(..., axis=1)``, where each row is turned
        into ``hstack(node_one_hot[first label], tensor(vec))``.
    """
    def stack_one_hot(row):
        # Only the first label of a node is used to pick its one-hot encoding
        label_encoding = node_one_hot[row["nodeLabels"][0]]
        embedding = torch.tensor(row["vec"])
        return torch.hstack((label_encoding, embedding))

    return node_df.apply(stack_one_hot, axis=1)

# Store the projected per-node tensors as a new column; note that this mutates
# sample_node_properties in place.
sample_node_properties["vec_projected"] = project_node_embeddings(sample_node_properties)
# Stack the per-node tensors vertically, in DataFrame row order, into one matrix.
node_features = torch.vstack(sample_node_properties["vec_projected"].tolist())

display(sample_node_properties)
display(node_features)

Unnamed: 0,nodeId,vec,nodeLabels,vec_projected
0,20,"[-7.349286079406738, -6.137350082397461, 1.769...",[Organization],"[tensor(0.), tensor(0.), tensor(0.), tensor(0...."
1,56,"[-7.838522434234619, -6.077364444732666, 1.122...",[Organization],"[tensor(0.), tensor(0.), tensor(0.), tensor(0...."
2,68,"[-2.6143743991851807, -2.6950881481170654, -1....",[Venue],"[tensor(0.), tensor(0.), tensor(0.), tensor(0...."
3,70,"[-9.003606796264648, -5.801609039306641, 1.042...",[Organization],"[tensor(0.), tensor(0.), tensor(0.), tensor(0...."
4,100,"[-16.875104904174805, -6.252501010894775, -2.3...",[Publication],"[tensor(1.), tensor(0.), tensor(0.), tensor(0...."
...,...,...,...,...
342,7599,"[-12.758501052856445, -6.827425956726074, -2.9...",[Publication],"[tensor(1.), tensor(0.), tensor(0.), tensor(0...."
343,7600,"[-10.055512428283691, -5.252875328063965, -0.0...",[Venue],"[tensor(0.), tensor(0.), tensor(0.), tensor(0...."
344,7607,"[-12.508116722106934, -4.270413398742676, -3.1...",[Publication],"[tensor(1.), tensor(0.), tensor(0.), tensor(0...."
345,7608,"[0.7140827775001526, -5.582761287689209, 0.461...",[Venue],"[tensor(0.), tensor(0.), tensor(0.), tensor(0...."


tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0575,  0.0671, -0.0406],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0630, -0.4268, -0.3316],
        [ 0.0000,  0.0000,  0.0000,  ..., -0.5694, -0.4004,  0.6518],
        ...,
        [ 1.0000,  0.0000,  0.0000,  ...,  0.0658, -1.4814, -0.9582],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.5818, -1.0339, -0.4346],
        [ 0.0000,  0.0000,  0.0000,  ..., -0.2475, -1.7157, -0.7256]])

In [113]:
# Read the feature widths from the tensors that are actually fed to the model.
# The earlier `sample_node_properties["vec_projected"][0]` was a label-based
# Series lookup and raises KeyError as soon as the frame's index no longer
# contains the label 0 (e.g. after filtering rows).
node_feature_dim = node_features.shape[1]
edge_feature_dim = edge_features.shape[1]
gat_embedding_dim = 32  # output width of the encoder / input width of the decoder

# Encoder: node features -> gat_embedding_dim, with edge features of width edge_feature_dim
encoder = GATv2Encoder(
    in_channels=node_feature_dim,
    out_channels=gat_embedding_dim,
    edge_dim=edge_feature_dim,
    add_self_loops=False
)

# Decoder: maps the embedding back to the original node feature width
decoder = GATv2Decoder(
    in_channels=gat_embedding_dim,
    out_channels=node_feature_dim
)

In [114]:
def train_gat(encoder, decoder, data, epochs=1000, lr=0.01, log_every=10):
    """Jointly train ``encoder`` and ``decoder`` to reconstruct ``data.x``.

    Every epoch performs one full-batch step: encode, decode, compute the
    mean-squared error between the decoder output and ``data.x``, backpropagate,
    and apply one SGD update to the parameters of both modules.

    Args:
        encoder: Module called as ``encoder(data.x, data.edge_index, data.edge_attr)``.
        decoder: Module called as ``decoder(encoded, data.edge_index, data.edge_attr)``;
            its output is compared against ``data.x`` with ``MSELoss``.
        data: Object exposing ``x``, ``edge_index`` and ``edge_attr``.
        epochs: Number of optimisation steps.
        lr: Learning rate of the SGD optimizer.
        log_every: Print the loss on every ``log_every``-th epoch, starting with
            epoch 0. ``None`` or a value <= 0 disables printing. The default of
            10 matches the interval that used to be hard-coded.

    Returns:
        None. Both modules are updated in place and left in training mode.
    """
    # A single optimizer updates the parameters of both modules
    parameters = [*encoder.parameters(), *decoder.parameters()]
    optimizer = optim.SGD(parameters, lr=lr)
    criterion = torch.nn.MSELoss()

    # Switching to training mode does not depend on the epoch, so do it once
    encoder.train()
    decoder.train()

    logging_enabled = bool(log_every) and log_every > 0

    for epoch in range(epochs):
        optimizer.zero_grad()

        # Encode the graph, then try to reconstruct the input node features
        encoded_nodes = encoder(data.x, data.edge_index, data.edge_attr)
        decoded_graph = decoder(encoded_nodes, data.edge_index, data.edge_attr)

        # Reconstruction error against the original node features
        loss = criterion(decoded_graph, data.x)
        loss.backward()
        optimizer.step()

        if logging_enabled and epoch % log_every == 0:
            print(f'Epoch {epoch}, Loss: {loss.item()}')

In [115]:
# Bundle node features, edge connectivity and edge features into one PyG Data object
data = Data(
    x=node_features,
    edge_index=edge_index,
    edge_attr=edge_features
)

# Train with train_gat's default epochs and learning rate; encoder and decoder
# are updated in place.
train_gat(encoder, decoder, data)

Epoch 0, Loss: 6.02380895614624
Epoch 10, Loss: 4.913343906402588
Epoch 20, Loss: 4.351069450378418
Epoch 30, Loss: 3.6309823989868164
Epoch 40, Loss: 2.8314192295074463
Epoch 50, Loss: 2.2657554149627686
Epoch 60, Loss: 1.9944164752960205
Epoch 70, Loss: 1.8855009078979492
Epoch 80, Loss: 1.8361690044403076
Epoch 90, Loss: 1.805999517440796
Epoch 100, Loss: 1.7815426588058472
Epoch 110, Loss: 1.75917387008667
Epoch 120, Loss: 1.7381595373153687
Epoch 130, Loss: 1.7185533046722412
Epoch 140, Loss: 1.7003612518310547
Epoch 150, Loss: 1.683487892150879
Epoch 160, Loss: 1.6676430702209473
Epoch 170, Loss: 1.6525752544403076
Epoch 180, Loss: 1.6381007432937622
Epoch 190, Loss: 1.6239831447601318
Epoch 200, Loss: 1.610060691833496
Epoch 210, Loss: 1.5962764024734497
Epoch 220, Loss: 1.5826225280761719
Epoch 230, Loss: 1.569061517715454
Epoch 240, Loss: 1.5556806325912476
Epoch 250, Loss: 1.5424778461456299
Epoch 260, Loss: 1.529500126838684
Epoch 270, Loss: 1.5168428421020508
Epoch 280, Los

In [116]:
# Save the trained encoder's weights (state_dict only; the decoder is not saved here)
encoder_path = './data/models/gat_encoder.pth'
# torch.save does not create missing parent directories, so make sure the
# target folder exists before writing.
os.makedirs(os.path.dirname(encoder_path), exist_ok=True)
torch.save(encoder.state_dict(), encoder_path)

In [117]:
# Drop both graph objects (the sample and the projection it was drawn from).
# The return values are bound to `_`, presumably to keep the cell from echoing them.
_ = G_sample.drop()
_ = G.drop()