In [10]:
import os
import pandas as pd
from graphdatascience import GraphDataScience
from neo4j import GraphDatabase
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv
from torch_geometric.transforms import RandomNodeSplit
import random
import numpy as np

from src.shared.database_wrapper import DatabaseWrapper
from src.model.GAT.gat_encoder import GATv2Encoder
from src.model.GAT.gat_decoder import GATv2Decoder
from src.shared.graph_schema import NodeType, EdgeType, node_one_hot, edge_one_hot
from src.shared import config

In [11]:
# Seed every random number generator used in this notebook with one value
# so that repeated runs start from the same state.
SEED = 42
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(SEED)

In [12]:
# Node and relationship types that are passed to the subgraph sampling below.
included_nodes = [NodeType.PUBLICATION, NodeType.VENUE, NodeType.ORGANIZATION]
included_edges = [
    EdgeType.PUB_VENUE,
    EdgeType.PUB_ORG,
    EdgeType.SIM_VENUE,
    EdgeType.SIM_ORG,
    EdgeType.ORG_PUB,
    EdgeType.VENUE_PUB,
]

# The `.value` of each selected enum member, in the same order as above.
node_spec = [selected.value for selected in included_nodes]
relationship_spec = [selected.value for selected in included_edges]

# Graph Data Science client, authenticated with the project configuration.
auth = (config.DB_USER, config.DB_PASSWORD)
gds = GraphDataScience(config.DB_URI, auth=auth)

In [13]:
def normalize_topology(new_idx_to_old, topology):
    """Re-express every node id in ``topology`` as its new index.

    Args:
        new_idx_to_old: Mapping of new index -> original node id.
        topology: Mapping of relationship type -> sequence of node-id
            sequences (the caller passes ``[source_ids, target_ids]``).

    Returns:
        A dict with the same keys as ``topology`` whose values are nested
        lists holding the new index of each original node id.

    Raises:
        KeyError: If a node id in ``topology`` is not a value of
            ``new_idx_to_old``.
    """
    # Invert the mapping so original ids can be looked up directly.
    old_idx_to_new = {old_id: new_idx for new_idx, old_id in new_idx_to_old.items()}

    normalized = {}
    for rel_type, node_lists in topology.items():
        normalized[rel_type] = [
            [old_idx_to_new[node_id] for node_id in node_list]
            for node_list in node_lists
        ]
    return normalized

def create_edge_index(topology):
    """Build the edge index and per-edge feature tensors for a topology.

    Args:
        topology: Mapping of relationship type -> ``[src_nodes, dst_nodes]``,
            two equally long lists of integer node indices.

    Returns:
        A tuple ``(edge_index, edge_features)``. ``edge_index`` is the
        column-wise concatenation of one two-row ``torch.long`` tensor per
        relationship type. ``edge_features`` stacks, for every edge, the
        vector stored in ``edge_one_hot`` for that edge's relationship type.
    """
    index_blocks = []
    feature_rows = []
    for rel_type, (src_nodes, dst_nodes) in topology.items():
        index_blocks.append(torch.tensor([src_nodes, dst_nodes], dtype=torch.long))
        # Every edge of this relationship type shares the same feature vector.
        feature_rows += [edge_one_hot[rel_type]] * len(src_nodes)

    edge_index = torch.cat(index_blocks, dim=1)
    edge_features = torch.vstack(feature_rows)
    return edge_index, edge_features

def project_node_embeddings(node_df):
    """Prefix each node's ``vec`` with the encoding of its first label.

    Args:
        node_df: DataFrame with a ``nodeLabels`` column (a sequence per row)
            and a ``vec`` column.

    Returns:
        The result of applying the row-wise projection to ``node_df``: for
        each row, ``node_one_hot[first label]`` horizontally stacked with
        ``vec`` converted to a tensor.
    """
    def _prepend_label_encoding(row):
        # Only the first entry of nodeLabels selects the encoding.
        label_encoding = node_one_hot[row["nodeLabels"][0]]
        embedding = torch.tensor(row["vec"])
        return torch.hstack((label_encoding, embedding))

    return node_df.apply(_prepend_label_encoding, axis=1)


In [14]:
# Module-level Neo4j driver; fetch_n_hop_neighbourhood opens its sessions from it.
driver = GraphDatabase.driver(
    config.DB_URI,
    auth=(config.DB_USER, config.DB_PASSWORD),
)

def fetch_n_hop_neighbourhood(start_node_type: NodeType, start_node_id: str, node_attr: str, node_types: list = None, edge_types: list = None, max_level: int = 6):
    """Fetch the subgraph around one start node via ``apoc.path.subgraphAll``.

    Args:
        start_node_type: Label of the start node (its ``.value`` is used in
            the ``MATCH`` clause).
        start_node_id: Value of the start node's ``id`` property. Passed to
            the database as a query parameter, never spliced into the query.
        node_attr: Name of the node property to extract for every node.
        node_types: Node types allowed in the subgraph; all ``NodeType``
            members when ``None``.
        edge_types: Relationship types allowed in the subgraph; all
            ``EdgeType`` members when ``None``.
        max_level: Value passed as APOC's ``maxLevel``.

    Returns:
        A tuple ``(node_df, edge_dict)``. ``node_df`` has one row per returned
        node with the columns ``nodeId``, ``<node_attr>`` and ``nodeLabels``.
        ``edge_dict`` maps relationship type -> ``[source_ids, target_ids]``,
        two parallel lists of ``id`` property values.

    Raises:
        ValueError: If no node with the given label and id exists.
    """
    node_filter = '|'.join(
        nt.value for nt in (NodeType if node_types is None else node_types)
    )
    edge_filter = '|'.join(
        et.value for et in (EdgeType if edge_types is None else edge_types)
    )

    # Labels cannot be query parameters in Cypher, so the label is formatted
    # in; the id is a real parameter so quotes in it cannot break the query.
    # NOTE(review): APOC reads direction markers per relationship type, so
    # '<A|B|C>' may constrain only the first and last type rather than all of
    # them - confirm this is the intended filter before changing it.
    query = f"""
            MATCH (start:{start_node_type.value} {{id: $start_node_id}})
            CALL apoc.path.subgraphAll(start, {{
              maxLevel: {max_level},
              relationshipFilter: '<{edge_filter}>',
              labelFilter: '+{node_filter}'
            }}) YIELD nodes, relationships
            RETURN nodes, relationships
        """

    with driver.session() as session:
        record = session.run(query, parameters={"start_node_id": start_node_id}).single()
        if record is None:
            # No row means the MATCH found no start node.
            raise ValueError(
                f"No {start_node_type.value} node with id {start_node_id!r} was found."
            )

        # Process nodes
        node_df = pd.DataFrame([
            {
                "nodeId": node.get("id"),
                node_attr: node.get(node_attr, None),
                "nodeLabels": list(node.labels),
            }
            for node in record["nodes"]
        ])

        # Process relationships: group endpoints by relationship type.
        edge_dict = {}
        for rel in record["relationships"]:
            sources, targets = edge_dict.setdefault(rel.type, [[], []])
            sources.append(rel.start_node.get("id"))
            targets.append(rel.end_node.get("id"))

    return node_df, edge_dict
        


In [15]:
def sample_subgraph(node_list, max_level=5):
    """Build one graph sample per start publication and wrap them in a loader.

    For every id in ``node_list`` the neighbourhood of that publication is
    fetched (restricted to ``included_nodes`` / ``included_edges``), node ids
    are re-indexed to their row position in the fetched node table, and the
    result is stored as a ``Data`` object. Neighbourhoods without any
    relationship are skipped.

    Args:
        node_list: Iterable of publication ``id`` values to start from.
        max_level: Forwarded to ``fetch_n_hop_neighbourhood``; defaults to
            the previously hard-coded value of 5.

    Returns:
        A ``DataLoader`` (default settings) over the collected samples.
    """
    dataset = []
    for node_id in node_list:
        node_df, topology = fetch_n_hop_neighbourhood(
            start_node_type=NodeType.PUBLICATION,
            start_node_id=node_id,
            node_attr="vec",
            node_types=included_nodes,
            edge_types=included_edges,
            max_level=max_level
        )
        # Skip edgeless neighbourhoods before doing any per-node tensor work.
        if not topology:
            continue

        node_df["vec_projected"] = project_node_embeddings(node_df)

        # A node's row position becomes its index inside this sample.
        new_idx_to_node_id = dict(enumerate(node_df["nodeId"]))
        normalized_topology = normalize_topology(new_idx_to_node_id, topology)

        edge_index, edge_features = create_edge_index(normalized_topology)
        node_features = torch.vstack(node_df["vec_projected"].tolist())

        dataset.append(Data(
            x=node_features,
            edge_index=edge_index,
            edge_attr=edge_features
        ))
    return DataLoader(dataset)

In [16]:
db_wrapper = DatabaseWrapper()

# Take the publication ids from the first batch that iter_nodes yields only;
# the loop is left immediately after that batch.
start_nodes = []
for nodes in db_wrapper.iter_nodes(NodeType.PUBLICATION, ["id"]):
    start_nodes.extend(node["id"] for node in nodes)
    break

dataset = sample_subgraph(start_nodes)

2024-08-04 10:40:41,627 - DatabaseWrapper - INFO - Connecting to the database ...
2024-08-04 10:40:41,627 - DatabaseWrapper - INFO - Database ready.


In [17]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

# Feature sizes for the autoencoder.
# NOTE(review): 38 is hard-coded; presumably the label encoding width plus the
# length of the `vec` property - confirm against the data.
node_feature_dim = 38
edge_feature_dim = EdgeType.PUB_YEAR.one_hot().shape[0]
gat_embedding_dim = 32

# Encoder: node features -> embedding, with edge features of size edge_feature_dim.
encoder = GATv2Encoder(
    in_channels=node_feature_dim,
    out_channels=gat_embedding_dim,
    edge_dim=edge_feature_dim,
    add_self_loops=False,
)
encoder.to(device)

# Decoder: embedding -> vector of the original node feature size.
decoder = GATv2Decoder(
    in_channels=gat_embedding_dim,
    out_channels=node_feature_dim,
)
decoder.to(device)

Device: cuda


GATv2Decoder(
  (linear1): Linear(in_features=32, out_features=16, bias=True)
  (linear2): Linear(in_features=16, out_features=38, bias=True)
)

In [18]:
def train_gat(encoder, decoder, dataloader, epochs=1000, lr=0.01):
    """Train the encoder/decoder pair to reconstruct node features.

    Each batch is encoded, decoded, and the decoder output is compared with
    the batch's original node features ``batch.x`` using mean squared error.
    Both modules are optimised jointly with SGD. Batches are moved to the
    module-level ``device`` before the forward pass.

    Args:
        encoder: Module called as ``encoder(x, edge_index, edge_attr)``.
        decoder: Module called as ``decoder(z, edge_index, edge_attr)``.
        dataloader: Sized iterable of batches exposing ``x``, ``edge_index``,
            ``edge_attr`` and ``to(device)``.
        epochs: Number of passes over ``dataloader``.
        lr: SGD learning rate.

    Raises:
        ValueError: If ``dataloader`` is empty (previously this surfaced as a
            ``ZeroDivisionError`` when the first loss was printed).
    """
    # Fail early with a clear message instead of dividing by zero below.
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("train_gat received an empty dataloader; there is nothing to train on.")

    # One optimizer over the parameters of both modules.
    optimizer = optim.SGD(list(encoder.parameters()) + list(decoder.parameters()), lr=lr)

    # Reconstruction loss.
    criterion = torch.nn.MSELoss()

    for epoch in range(epochs):
        encoder.train()
        decoder.train()

        total_loss = 0

        for batch in dataloader:
            # Rebind so the loop does not depend on `to` working in place.
            batch = batch.to(device)

            optimizer.zero_grad()

            # Encode, then decode back to the node feature space.
            encoded_nodes = encoder(batch.x, batch.edge_index, batch.edge_attr)
            decoded_graph = decoder(encoded_nodes, batch.edge_index, batch.edge_attr)

            # Compare the reconstruction with the original node features.
            loss = criterion(decoded_graph, batch.x)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # Print the mean batch loss every 10 epochs
        if epoch % 10 == 0:
            print(f'Epoch {epoch}, Loss: {total_loss / num_batches}')


# Run training with the default number of epochs and learning rate.
train_gat(encoder=encoder, decoder=decoder, dataloader=dataset)


Epoch 0, Loss: 1.6991263862904902
Epoch 10, Loss: 0.7182479475329562
Epoch 20, Loss: 0.6474938346539291
Epoch 30, Loss: 0.6019467816655841
Epoch 40, Loss: 0.5914328539402096
Epoch 50, Loss: 0.5862988346903101
Epoch 60, Loss: 0.5829303072758296
Epoch 70, Loss: 0.5797445497957365
Epoch 80, Loss: 0.5682526034839805
Epoch 90, Loss: 0.5630976225372609
Epoch 100, Loss: 0.5603424912660511
Epoch 110, Loss: 0.5584748582448212
Epoch 120, Loss: 0.5572052574853995
Epoch 130, Loss: 0.5562624087573934
Epoch 140, Loss: 0.5555271729539999
Epoch 150, Loss: 0.5549195307822221
Epoch 160, Loss: 0.5543473370000241
Epoch 170, Loss: 0.5538498581182666
Epoch 180, Loss: 0.5533817627960972
Epoch 190, Loss: 0.5529147990631353
Epoch 200, Loss: 0.5524237294213916
Epoch 210, Loss: 0.5519596584456876
Epoch 220, Loss: 0.5514514290743217
Epoch 230, Loss: 0.5509896007711989
Epoch 240, Loss: 0.5505499010750203
Epoch 250, Loss: 0.550146761310881
Epoch 260, Loss: 0.549777979090982
Epoch 270, Loss: 0.5494318231462055
Epoch