In [261]:
import os
import json
from time import sleep

import pandas as pd
from graphdatascience import GraphDataScience
from neo4j import GraphDatabase
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv
from torch_geometric.transforms import RandomNodeSplit
import random
import numpy as np

from src.shared.database_wrapper import DatabaseWrapper
from src.datasets.who_is_who import WhoIsWhoDataset
from src.model.GAT.gat_encoder import GATv2Encoder
from src.model.GAT.gat_decoder import GATv2Decoder
from src.shared.graph_schema import NodeType, EdgeType, node_one_hot, edge_one_hot
from src.shared import config

import networkx as nx
import plotly.graph_objects as go

# Single seed shared by every RNG the notebook touches (Python, NumPy, torch CPU/CUDA).
SEED = 42

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

In [262]:
# Module-level Neo4j driver used by the query helpers in this notebook.
# NOTE(review): it is never closed here; call driver.close() when finished.
driver = GraphDatabase.driver(
    config.DB_URI,
    auth=(config.DB_USER, config.DB_PASSWORD),
)

In [263]:
def fetch_n_hop_neighbourhood(start_node_type: NodeType, start_node_id: str, node_attr: str, node_types: list = None, edge_types: list = None, max_level: int = 6):
    """Fetch the subgraph around one start node via ``apoc.path.subgraphAll``.

    Uses the module-level Neo4j ``driver``.

    Args:
        start_node_type: Label of the start node (interpolated into the query, so it
            must be a trusted ``NodeType`` member).
        start_node_id: Value of the start node's ``id`` property. Sent as a query
            parameter, so ids containing quotes or Cypher syntax are handled safely.
        node_attr: Node property copied into the returned frame (e.g. ``"vec"``);
            ``None`` for nodes that lack it.
        node_types: Labels allowed in the traversal; every ``NodeType`` if None.
        edge_types: Relationship types to follow, each with a ``<`` (incoming)
            prefix in the APOC filter; every ``EdgeType`` if None.
        max_level: Maximum traversal depth passed as APOC ``maxLevel``.

    Returns:
        ``(node_df, edge_dict)``. ``node_df`` has columns ``nodeId``, ``node_attr``
        and ``nodeLabels``. ``edge_dict`` maps a relationship type to
        ``[source_ids, target_ids]``. Returns ``(None, None)`` when no start node
        matches.
    """
    node_types = list(NodeType) if node_types is None else node_types
    edge_types = list(EdgeType) if edge_types is None else edge_types
    # dict.fromkeys removes duplicate entries while preserving their order.
    node_filter = "|".join(dict.fromkeys(nt.value for nt in node_types))
    edge_filter = "|".join(dict.fromkeys(f"<{et.value}" for et in edge_types))

    # Labels cannot be parameterized in Cypher; everything else is passed as a
    # parameter instead of being spliced into the query text.
    query = f"""
        MATCH (start:{start_node_type.value} {{id: $start_id}})
        CALL apoc.path.subgraphAll(start, $config) YIELD nodes, relationships
        RETURN nodes, relationships
    """
    params = {
        "start_id": start_node_id,
        "config": {
            "maxLevel": max_level,
            "relationshipFilter": edge_filter,
            "labelFilter": f"+{node_filter}",
        },
    }

    with driver.session() as session:
        record = session.run(query, params).single()
        if record is None:
            return None, None

        node_df = pd.DataFrame([
            {"nodeId": node.get("id"), node_attr: node.get(node_attr, None), "nodeLabels": list(node.labels)}
            for node in record["nodes"]
        ])

        # Group edge endpoints (by their "id" property) per relationship type.
        edge_dict = {}
        for rel in record["relationships"]:
            sources, targets = edge_dict.setdefault(rel.type, [[], []])
            sources.append(rel.start_node.get("id"))
            targets.append(rel.end_node.get("id"))

    return node_df, edge_dict

In [264]:
def normalize_topology(new_idx_to_old, topology):
    """Re-key edge endpoints from original node ids to new contiguous indices.

    Args:
        new_idx_to_old: Mapping ``new_index -> original_node_id``.
        topology: Mapping ``rel_type -> [source_ids, target_ids]`` expressed in
            original node ids.

    Returns:
        The same structure with every original id replaced by its new index. If an
        original id appears under several indices, the last one in iteration order
        wins. A ``KeyError`` propagates for endpoints absent from ``new_idx_to_old``.
    """
    old_to_new = {old_id: new_idx for new_idx, old_id in new_idx_to_old.items()}
    remapped = {}
    for rel_type, endpoint_lists in topology.items():
        remapped[rel_type] = [
            [old_to_new[node_id] for node_id in endpoints]
            for endpoints in endpoint_lists
        ]
    return remapped

def create_edge_index(topology):
    """Build a PyG-style edge index and per-edge feature matrix.

    Args:
        topology: Mapping ``rel_type -> [source_indices, target_indices]`` (already
            normalized to node positions).

    Returns:
        ``(edge_index, edge_attr)``. ``edge_index`` is a long tensor of shape
        ``(2, E)``, formed by concatenating one ``[src, dst]`` block per relationship
        type. ``edge_attr`` stacks ``edge_one_hot[rel_type]`` once per edge, in the
        same order. torch raises if ``topology`` is empty (nothing to concatenate).
    """
    index_blocks = []
    feature_rows = []
    for rel_type, (sources, targets) in topology.items():
        index_blocks.append(torch.tensor([sources, targets], dtype=torch.long))
        # One copy of the relation's one-hot vector per edge of that type.
        feature_rows += [edge_one_hot[rel_type]] * len(sources)
    return torch.cat(index_blocks, dim=1), torch.vstack(feature_rows)

def project_node_embeddings(node_df):
    """Prefix each node's feature vector with the one-hot encoding of its label.

    Args:
        node_df: Frame with columns ``nodeLabels`` (list of labels) and ``vec``
            (numeric sequence, or None/NaN when the node has no vector).

    Returns:
        A Series aligned with ``node_df``. Each entry is the tensor
        ``hstack(node_one_hot[label], vec)``, or None when ``vec`` is missing. This
        lets callers detect such rows with ``isnull()``; previously
        ``torch.tensor(None)`` raised.
    """
    def stack_one_hot(row):
        vec = row["vec"]
        if vec is None or (isinstance(vec, float) and np.isnan(vec)):
            return None
        # NOTE(review): node.labels is presumably an unordered set, so for
        # multi-label nodes the "first" label is arbitrary — confirm nodes carry a
        # single label.
        one_hot_enc = node_one_hot[row["nodeLabels"][0]]
        return torch.hstack((one_hot_enc, torch.tensor(vec)))
    return node_df.apply(stack_one_hot, axis=1)

In [265]:
# Node labels the neighbourhood traversal may visit.
included_nodes = [
    NodeType.PUBLICATION,
    NodeType.VENUE,
    NodeType.ORGANIZATION,
    NodeType.AUTHOR,
    NodeType.CO_AUTHOR,
]
# Relationship types the traversal follows; each listed exactly once (PUB_ORG and
# ORG_PUB were previously duplicated). The similarity edges EdgeType.SIM_VENUE and
# EdgeType.SIM_ORG are currently disabled.
included_edges = [
    EdgeType.PUB_VENUE,
    EdgeType.VENUE_PUB,
    EdgeType.PUB_ORG,
    EdgeType.ORG_PUB,
    EdgeType.PUB_AUTHOR,
    EdgeType.AUTHOR_PUB,
    EdgeType.AUTHOR_ORG,
    EdgeType.ORG_AUTHOR,
]

def _fetch_node_graph(node_id, max_level=2):
    """Fetch one publication's neighbourhood as a JSON-serializable graph dict, or None if unusable."""
    node_df, topology = fetch_n_hop_neighbourhood(
        start_node_type=NodeType.PUBLICATION,
        start_node_id=node_id,
        node_attr="vec",
        node_types=included_nodes,
        edge_types=included_edges,
        max_level=max_level,
    )
    if node_df is None or len(node_df) == 0 or not topology:
        return None
    # Skip neighbourhoods containing nodes without a vector; projecting them
    # would call torch.tensor(None).
    if node_df["vec"].isnull().any():
        return None
    node_df["vec_projected"] = project_node_embeddings(node_df)
    if node_df["vec_projected"].isnull().any():
        return None

    new_idx_to_old = dict(enumerate(node_df["nodeId"]))
    normalized_topology = normalize_topology(new_idx_to_old, topology)
    if not normalized_topology:
        return None

    edge_index, edge_features = create_edge_index(normalized_topology)
    node_features = torch.vstack(node_df["vec_projected"].tolist())
    # edge_index has shape (2, E), so len() is always 2; check the edge count instead.
    if node_features.numel() == 0 or edge_index.size(1) == 0 or edge_features.numel() == 0:
        return None

    return {
        "x": node_features.tolist(),
        "edge_index": edge_index.tolist(),
        "edge_attr": edge_features.tolist(),
        "node_id": node_id,
    }


def sample_triplet(anchor, pos, neg, max_level=2):
    """Build the anchor/positive/negative neighbourhood graphs for one triplet.

    Args:
        anchor, pos, neg: Publication ids.
        max_level: Traversal depth for each neighbourhood (default 2).

    Returns:
        ``{"anchor": graph, "pos": graph, "neg": graph}``, where each graph has keys
        ``x``, ``edge_index``, ``edge_attr`` and ``node_id`` (plain lists). Returns
        None if any of the three neighbourhoods is missing or unusable.
    """
    triplet = {}
    for label, node_id in (("anchor", anchor), ("pos", pos), ("neg", neg)):
        graph = _fetch_node_graph(node_id, max_level=max_level)
        if graph is None:
            return None
        triplet[label] = graph
    return triplet

In [266]:
class TripleSampler:
    """Draw (anchor, positive, negative) publication id triplets from WhoIsWho.

    ``self.data`` comes from ``WhoIsWhoDataset.parse_train()`` and is read as
    ``data[author_id]["normal_data"]`` and ``data[author_id]["outliers"]``
    (presumably lists of publication ids — TODO confirm against the parser).

    Sampling is destructive:
    - each drawn anchor is removed from its author's ``normal_data``;
    - authors that can no longer yield a triplet are dropped from ``self.data``.

    A sampler therefore eventually runs dry; every method stops cleanly when that
    happens instead of raising.
    """

    # Maximum attempt index in sample_triplet_ids (attempts 0..MAX_ATTEMPTS are tried).
    MAX_ATTEMPTS = 30

    def __init__(self):
        self.data = WhoIsWhoDataset.parse_train()

    def _draw_from_author(self, author_id):
        """Return ``(anchor, pos, neg)`` for ``author_id``, or None (and drop the author) if ineligible."""
        normal_data = self.data[author_id]["normal_data"]
        outliers = self.data[author_id]["outliers"]
        if len(normal_data) < 2 or len(outliers) == 0:
            self.data.pop(author_id)
            return None
        anchor_id = random.choice(normal_data)
        # Removing the anchor first guarantees the positive is a different paper.
        normal_data.remove(anchor_id)
        pos_sample = random.choice(normal_data)
        neg_sample = random.choice(outliers)
        return anchor_id, pos_sample, neg_sample

    def iter_triplet_ids_rand(self, num_triplets):
        """Yield up to ``num_triplets`` id triplets; stops early once ``self.data`` is empty."""
        produced = 0
        while produced < num_triplets and self.data:
            author_id = random.choice(list(self.data.keys()))
            ids = self._draw_from_author(author_id)
            if ids is None:
                continue
            yield ids
            produced += 1

    def sample_triplet_ids(self, attempt=0):
        """Return one id triplet, or None after too many ineligible draws or when no authors remain."""
        while attempt <= self.MAX_ATTEMPTS and self.data:
            author_id = random.choice(list(self.data.keys()))
            ids = self._draw_from_author(author_id)
            if ids is not None:
                return ids
            attempt += 1
        print("Unable to sample more triples")
        return None

    def iter_triplets_rand(self, num_triplets):
        """Yield up to ``num_triplets`` triplet graphs fetched from Neo4j via ``sample_triplet``.

        Triplets whose neighbourhoods are unusable are skipped. The generator ends
        early once the id sampler is exhausted; previously it crashed unpacking None.
        """
        produced = 0
        while produced < num_triplets:
            ids = self.sample_triplet_ids()
            if ids is None:
                return
            triplet = sample_triplet(*ids)
            if not isinstance(triplet, dict):
                continue
            yield triplet
            produced += 1

In [267]:
# Number of batch files to generate and triplets per file.
NUM_BATCHES = 10000
BATCH_SIZE = 10

triplet_sampler = TripleSampler()
path = "./data/triplet_dataset"
os.makedirs(path, exist_ok=True)

for batch_id in range(NUM_BATCHES):
    # Sample the whole batch before touching the disk. An interrupt or error while
    # sampling then never leaves an empty or truncated JSON file behind.
    data = list(triplet_sampler.iter_triplets_rand(BATCH_SIZE))
    if not data:
        print("Triplet sampler exhausted, stopping")
        break

    file = os.path.join(path, f"triplet_batch_{batch_id}.json")
    print(f"Saving batch {batch_id} to disk")
    # Write to a temporary file and rename, so readers only ever see complete batches.
    tmp_file = file + ".tmp"
    with open(tmp_file, "w") as f:
        json.dump(data, f)
    os.replace(tmp_file, file)

Saving batch 0 to disk
Saving batch 1 to disk
Saving batch 2 to disk
Saving batch 3 to disk
Saving batch 4 to disk
Saving batch 5 to disk
Saving batch 6 to disk
Saving batch 7 to disk
Saving batch 8 to disk
Saving batch 9 to disk
Saving batch 10 to disk
Saving batch 11 to disk
Saving batch 12 to disk
Saving batch 13 to disk
Saving batch 14 to disk
Saving batch 15 to disk
Saving batch 16 to disk
Saving batch 17 to disk
Saving batch 18 to disk
Saving batch 19 to disk
Saving batch 20 to disk
Saving batch 21 to disk
Saving batch 22 to disk
Saving batch 23 to disk
Saving batch 24 to disk
Saving batch 25 to disk
Saving batch 26 to disk
Saving batch 27 to disk
Saving batch 28 to disk
Saving batch 29 to disk
Saving batch 30 to disk
Saving batch 31 to disk
Saving batch 32 to disk
Saving batch 33 to disk
Saving batch 34 to disk
Saving batch 35 to disk
Saving batch 36 to disk
Saving batch 37 to disk
Saving batch 38 to disk
Saving batch 39 to disk
Saving batch 40 to disk
Saving batch 41 to disk
Sa

KeyboardInterrupt: 