In [None]:
import os
import pandas as pd
from graphdatascience import GraphDataScience
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from torch_geometric.transforms import RandomNodeSplit
import random
import numpy as np

from src.shared.database_wrapper import DatabaseWrapper
from src.model.GAT.gat_encoder import GATv2Encoder
from src.model.GAT.gat_decoder import GATv2Decoder
from src.shared.graph_schema import NodeType, EdgeType, node_one_hot, edge_one_hot
from src.shared import config

In [None]:
# Seed every random number generator used in this notebook (Python, NumPy,
# PyTorch CPU and all CUDA devices) with the same value so that repeated
# runs produce consistent results.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_rng(42)

In [None]:
#def project_single(n):
#    return torch.hstack((node_to_one_hot[list(n.labels)[0]], torch.tensor(n['vec'])))



In [None]:
# Open the Graph Data Science client session. The connection URI and the
# (user, password) credentials are read from the shared project config module.
auth = (config.DB_USER, config.DB_PASSWORD)
gds = GraphDataScience(config.DB_URI, auth=auth)

In [None]:
# Node types and edge types that take part in the graph projection.
included_nodes = [NodeType.PUBLICATION, NodeType.VENUE, NodeType.ORGANIZATION]
included_edges = [
    EdgeType.PUB_VENUE,
    EdgeType.PUB_ORG,
    EdgeType.SIM_VENUE,
    EdgeType.SIM_ORG,
    EdgeType.ORG_PUB,
    EdgeType.VENUE_PUB,
]

# The projection is specified with each member's `.value`, not the member itself.
node_spec = [selected.value for selected in included_nodes]
relationship_spec = [selected.value for selected in included_edges]

# Echo both specs (one per line) so the projected schema is visible in the output.
print(node_spec, relationship_spec, sep="\n")

In [None]:
# Remove a projection of the same name that an earlier run may have left behind.
gds.graph.drop('graph_sample')

# Project the selected node labels and relationship types into the GDS graph
# catalog, carrying the `vec` node property along with every node.
G, _ = gds.graph.project(
    'graph_sample',
    node_spec,
    relationship_spec,
    nodeProperties=['vec'],
)
print(G)

In [None]:
# Settings for the random-walk-with-restart (RWR) sampling procedure.
# A fixed seed combined with a single thread keeps the sample reproducible.
configuration = {
    "concurrency": 1,
    "randomSeed": 42,
    # NOTE(review): the GDS option for pinning start nodes is presumably
    # `startNodes` (camelCase) -- confirm before re-enabling this line.
    #"start_nodes": [node['id']],
}

# Remove a sample of the same name that an earlier run may have left behind.
gds.graph.drop("graph_rwr")

# The Python client forwards keyword arguments as the procedure configuration,
# so the dict has to be unpacked. Passing it as `configuration=...` would send
# one unknown key named "configuration" instead of the intended settings.
G_sample, _ = gds.graph.sample.rwr("graph_rwr", G, **configuration)

# Stream the relationships (topology) of the sampled subgraph for inspection.
sample_topology_df = gds.beta.graph.relationships.stream(G_sample)
display(sample_topology_df)

In [None]:
# Stream the `vec` property of every node in the sample. The property is
# requested as its own column and the node labels are requested alongside it
# (later cells read the `nodeId`, `vec` and `nodeLabels` columns).
sample_node_properties = gds.graph.nodeProperties.stream(
    G_sample,
    ["vec"],
    node_labels=["*"],
    separate_property_columns=True,
    listNodeLabels=True,
)
display(sample_node_properties)

In [None]:
# Group the streamed relationships by relationship type. The result is consumed
# further down as {rel_type: [source_ids, target_ids]}. The ids are still the
# node ids reported by the database at this point; remapping them to PyG node
# indices happens in a later cell.
sample_topology = sample_topology_df.by_rel_type()
#print(sample_topology)

In [None]:
def normalize_topology(new_idx_to_old, topology):
    """Translate every node id in ``topology`` through the inverse of ``new_idx_to_old``.

    Args:
        new_idx_to_old: Mapping of new index -> original node id.
        topology: Mapping of relationship type -> list of node-id lists.

    Returns:
        A dict with the same keys and nesting as ``topology`` in which each
        original node id has been replaced by its new index.

    Raises:
        KeyError: If ``topology`` contains a node id that is not a value of
            ``new_idx_to_old``.
    """
    # Invert the mapping so original ids can be looked up directly.
    old_idx_to_new = {old_id: new_id for new_id, old_id in new_idx_to_old.items()}

    normalized = {}
    for rel_type, node_lists in topology.items():
        normalized[rel_type] = [
            [old_idx_to_new[node_id] for node_id in node_list]
            for node_list in node_lists
        ]
    return normalized

def create_edge_index(topology):
    """Build the edge index and per-edge feature matrix for all relationship types.

    Args:
        topology: Mapping of relationship type -> ``[source_nodes, target_nodes]``,
            two equally long lists of node indices.

    Returns:
        A tuple ``(edge_index, edge_features)``. ``edge_index`` is a long tensor
        with one row of sources and one row of targets, concatenated over all
        relationship types in iteration order. ``edge_features`` stacks, for
        every edge, the ``edge_one_hot`` entry of its relationship type.
    """
    index_parts = []
    feature_rows = []
    for rel_type, (src_nodes, dst_nodes) in topology.items():
        index_parts.append(torch.tensor([src_nodes, dst_nodes], dtype=torch.long))
        # Every edge of this type carries the same encoding vector.
        feature_rows += [edge_one_hot[rel_type]] * len(src_nodes)

    # The index is assembled directly as (sources, targets) rows, so no
    # transpose is applied to the concatenated result.
    return torch.cat(index_parts, dim=1), torch.vstack(feature_rows)


# dict() over the `nodeId` column gives {row label: node id}; the row labels
# of the node-property frame serve as the new node indices.
row_to_node_id = dict(sample_node_properties["nodeId"])
normalized_topology = normalize_topology(row_to_node_id, sample_topology)

# Turn the re-indexed topology into the edge index and edge feature tensors.
edge_index, edge_features = create_edge_index(normalized_topology)

display(edge_index, edge_features)

In [None]:
def project_node_embeddings(node_df):
    """Build one feature tensor per row of ``node_df``.

    For each row, the ``node_one_hot`` entry of the row's first node label is
    placed in front of the row's ``vec`` values (converted to a tensor).

    Args:
        node_df: DataFrame with a ``nodeLabels`` column (a sequence of labels
            per row) and a ``vec`` column.

    Returns:
        The result of the row-wise ``DataFrame.apply`` -- one tensor per row.
    """
    def stack_one_hot(row):
        # Only the first label is used.
        # NOTE(review): rows with several labels would ignore the rest -- confirm
        # that every node carries exactly one label.
        primary_label = row["nodeLabels"][0]
        label_encoding = node_one_hot[primary_label]
        embedding = torch.tensor(row["vec"])
        return torch.hstack((label_encoding, embedding))

    return node_df.apply(stack_one_hot, axis=1)

# Compute the per-node feature tensors, keep them on the frame for inspection,
# and stack them row-wise into a single node feature matrix.
projected_vectors = project_node_embeddings(sample_node_properties)
sample_node_properties["vec_projected"] = projected_vectors
node_features = torch.vstack(projected_vectors.tolist())

display(sample_node_properties, node_features)

In [None]:
# Input sizes are read off the feature matrices built in the cells above:
# the width of one node feature row and the width of one edge feature row.
node_feature_dim = node_features.shape[1]
edge_feature_dim = edge_features.shape[1]
# Size of the embedding produced by the encoder.
gat_embedding_dim = 32

# Encoder: node features (plus edge features) -> embedding.
encoder = GATv2Encoder(
    in_channels=node_feature_dim,
    out_channels=gat_embedding_dim,
    edge_dim=edge_feature_dim,
    add_self_loops=False,
)

# Decoder: embedding -> a vector as wide as the original node features, so its
# output can be compared against the encoder input.
decoder = GATv2Decoder(
    in_channels=gat_embedding_dim,
    out_channels=node_feature_dim,
)

In [None]:
def train_gat(encoder, decoder, data, epochs=1000, lr=0.01):
    """Jointly train ``encoder`` and ``decoder`` to reconstruct ``data.x``.

    Each epoch runs one full pass: encode the node features, decode the
    embeddings, and take an SGD step on the mean squared error between the
    decoder output and the original node features.

    Args:
        encoder: Module called as ``encoder(x, edge_index, edge_attr)``.
        decoder: Module called as ``decoder(z, edge_index, edge_attr)``; its
            output is compared against ``data.x``.
        data: Object exposing ``x``, ``edge_index`` and ``edge_attr``.
        epochs: Number of optimisation steps.
        lr: Learning rate of the SGD optimizer.

    Returns:
        None. Both modules are updated in place; the loss is printed every
        10 epochs.
    """
    # One optimizer updates the parameters of both networks.
    trainable_params = [*encoder.parameters(), *decoder.parameters()]
    optimizer = optim.SGD(trainable_params, lr=lr)

    # Reconstruction error between decoder output and input features.
    criterion = torch.nn.MSELoss()

    for epoch in range(epochs):
        # Put both networks in training mode.
        for network in (encoder, decoder):
            network.train()

        optimizer.zero_grad()

        # Encode, then decode back to the node feature space.
        latent = encoder(data.x, data.edge_index, data.edge_attr)
        reconstruction = decoder(latent, data.edge_index, data.edge_attr)

        loss = criterion(reconstruction, data.x)
        loss.backward()
        optimizer.step()

        # Report progress on every 10th epoch (starting with epoch 0).
        if not epoch % 10:
            print(f'Epoch {epoch}, Loss: {loss.item()}')

In [None]:
# Bundle node features, connectivity and edge features into one PyG Data object.
data = Data(x=node_features, edge_index=edge_index, edge_attr=edge_features)

# Fit encoder and decoder with the default number of epochs and learning rate.
train_gat(encoder=encoder, decoder=decoder, data=data)

In [None]:
# Save the model
# Only the encoder weights are persisted. torch.save does not create missing
# parent directories, so make sure the target folder exists first; otherwise
# the save fails after the whole training run has completed.
model_path = './data/models/gat_encoder.pth'
os.makedirs(os.path.dirname(model_path), exist_ok=True)
torch.save(encoder.state_dict(), model_path)

In [None]:
# Drop both projections created by this notebook: first the sample, then the
# graph it was sampled from. The return values are discarded on purpose.
for _projection in (G_sample, G):
    _projection.drop()