In [87]:
import os
import random

import torch
import numpy as np
from tqdm.notebook import tqdm
from neo4j import GraphDatabase
from torch_geometric.data import HeteroData

from src.datasets.who_is_who import WhoIsWhoDataset
from src.shared.graph_schema import NodeType, EdgeType, node_one_hot, edge_one_hot, edge_val_to_pyg_key_vals
from src.shared import config

# Fix every random number generator used below so that triplet sampling is
# repeatable across kernel restarts.
torch.manual_seed(42)           # PyTorch CPU generator
torch.cuda.manual_seed_all(42)  # PyTorch generators on all CUDA devices
np.random.seed(42)              # NumPy legacy global generator
random.seed(42)                 # Python stdlib generator (drives TripletSampler)

In [88]:
# Shared Neo4j driver; connection details come from the project config module.
driver = GraphDatabase.driver(
    config.DB_URI,
    auth=(config.DB_USER, config.DB_PASSWORD),
)

# Fetch data from Neo4j and create PyG HeteroData objects

In [89]:
# Node labels admitted into every sampled neighbourhood. sample_triplet passes
# this list to fetch_n_hop_neighbourhood, which turns it into the label filter.
included_nodes = [
    NodeType.PUBLICATION,
    NodeType.VENUE,
    NodeType.ORGANIZATION,
    NodeType.AUTHOR,
    NodeType.CO_AUTHOR,
]

# Relationship types that may be traversed, listed as forward/backward pairs.
# Each type appears exactly once (PUB_ORG / ORG_PUB were previously repeated).
# NOTE(review): NodeType.CO_AUTHOR is whitelisted above, but none of the edge
# types below looks like a co-author relationship, so co-author nodes are
# presumably unreachable. The removed duplicate pair may have been meant to be
# a co-author pair -- confirm against src.shared.graph_schema.
included_edges = [
    EdgeType.PUB_VENUE,
    EdgeType.VENUE_PUB,
    EdgeType.PUB_ORG,
    EdgeType.ORG_PUB,
    EdgeType.PUB_AUTHOR,
    EdgeType.AUTHOR_PUB,
    EdgeType.AUTHOR_ORG,
    EdgeType.ORG_AUTHOR,
]

def verify_hetero_data(h_data: HeteroData) -> bool:
    """Return True when the graph holds at least two feature-bearing nodes.

    The node count is the sum of the first dimension of ``x`` over every node
    type that has an ``x`` attribute; node types without ``x`` contribute
    nothing. A message is printed when the check fails.
    """
    node_count = sum(
        h_data[kind].x.size(0)
        for kind in h_data.node_types
        if 'x' in h_data[kind]
    )

    if node_count >= 2:
        return True

    print(f"Error: The HeteroData object should contain at least 2 nodes, but contains {node_count}.")
    return False
    
def fetch_n_hop_neighbourhood(
        start_node_type: NodeType,
        start_node_id: str,
        node_attr: str,
        node_types: list = None,
        edge_types: list = None,
        max_level: int = 6,
        max_nodes: int = 500
):
    """Fetch the neighbourhood of one node from Neo4j as a PyG ``HeteroData``.

    Runs ``apoc.path.subgraphAll`` from the node with label
    ``start_node_type.value`` and property ``id == start_node_id`` and converts
    the returned nodes and relationships into a heterogeneous graph.

    Args:
        start_node_type: Label of the start node.
        start_node_id: Value of the start node's ``id`` property.
        node_attr: Node property holding the feature vector. Nodes without it
            are reported and left out of the graph.
        node_types: Node types allowed in the subgraph; every ``NodeType``
            when ``None``.
        edge_types: Edge types allowed in the subgraph; every ``EdgeType``
            when ``None``. Each is sent to APOC with a ``<`` prefix.
        max_level: Maximum number of hops from the start node.
        max_nodes: Subgraphs with more nodes than this are rejected.

    Returns:
        ``(h_data, node_id_map)`` where ``node_id_map`` maps a node's ``id``
        property to its row index inside the feature matrix of its own node
        type, or ``(None, None)`` when the query returns no record, the
        subgraph exceeds ``max_nodes``, or fewer than two nodes carry features.
    """
    allowed_nodes = NodeType if node_types is None else node_types
    allowed_edges = EdgeType if edge_types is None else edge_types
    label_filter = '+' + '|'.join(nt.value for nt in allowed_nodes)
    edge_filter = '|'.join(f"<{et.value}" for et in allowed_edges)

    # Labels cannot be Cypher parameters, so only the label is interpolated;
    # every value (in particular the externally supplied id) is passed as a
    # query parameter to rule out injection and quoting problems.
    query = f"""
            MATCH (start:{start_node_type.value} {{id: $start_node_id}})
            CALL apoc.path.subgraphAll(start, {{
              maxLevel: $max_level,
              relationshipFilter: $edge_filter,
              labelFilter: $label_filter
            }}) YIELD nodes, relationships
            RETURN nodes, relationships
        """
    params = {
        "start_node_id": start_node_id,
        "max_level": max_level,
        "edge_filter": edge_filter,
        "label_filter": label_filter,
    }

    with driver.session() as session:
        data = session.run(query, params).single()
        if data is None:
            return None, None

        nodes = data["nodes"]
        relationships = data["relationships"]
        print(f"Start node id: {start_node_id}")
        print(f"Nodes: {len(nodes)}, Relationships: {len(relationships)}")
        if len(nodes) > max_nodes:
            print(f"Too many nodes: {len(nodes)}")
            return None, None

        # Create data object
        h_data = HeteroData()

        features_by_label = {}
        node_id_map = {}

        for node in nodes:
            node_id = node.get("id")
            node_feature = node.get(node_attr, None)
            if node_feature is None:
                print(f"Node {node_id} has no attribute {node_attr}")
                continue
            # NOTE(review): for a node with several labels this picks an
            # arbitrary one -- confirm every node carries exactly one label.
            node_label = list(node.labels)[0]
            label_features = features_by_label.setdefault(node_label, [])

            # A node's index is its position within its own node type.
            node_id_map[node_id] = len(label_features)
            label_features.append(torch.tensor(node_feature, dtype=torch.float32))

        # One stacked feature matrix per node type
        for node_label, label_features in features_by_label.items():
            h_data[node_label].x = torch.vstack(label_features)

        # Process relationships
        edge_dict = {}
        skipped_edges = 0

        for rel in relationships:
            source_id = rel.start_node.get("id")
            target_id = rel.end_node.get("id")

            # Nodes dropped above for lacking `node_attr` have no index, so
            # edges touching them cannot be represented and are left out.
            if source_id not in node_id_map or target_id not in node_id_map:
                skipped_edges += 1
                continue

            # key is the (source type, relation, target type) triple for PyG
            key = edge_val_to_pyg_key_vals[rel.type]
            sources, targets = edge_dict.setdefault(key, ([], []))
            sources.append(node_id_map[source_id])
            targets.append(node_id_map[target_id])

        if skipped_edges:
            print(f"Skipped {skipped_edges} relationships with a node lacking attribute {node_attr}")

        # Convert edge lists to tensors
        for key, (sources, targets) in edge_dict.items():
            src_type, rel_type, dst_type = key
            store = h_data[src_type, rel_type, dst_type]
            store.edge_index = torch.vstack([
                torch.tensor(sources, dtype=torch.long),
                torch.tensor(targets, dtype=torch.long)
            ])
            # Every edge of one relation gets that relation's one-hot row.
            store.edge_attr = torch.vstack([edge_one_hot[rel_type] for _ in range(len(sources))])

    if not verify_hetero_data(h_data):
        return None, None

    return h_data, node_id_map

def sample_triplet(anchor, pos, neg):
    """Build the 2-hop neighbourhood graph for each publication of a triplet.

    Args:
        anchor, pos, neg: Publication ids of the anchor, positive and negative
            sample.

    Returns:
        A dict with keys ``"anchor"``, ``"pos"`` and ``"neg"``; each value is a
        dict holding the graph (``"data"``), its ``"node_id_map"`` and the
        originating ``"pub_node_id"``. ``None`` as soon as one of the three
        neighbourhoods cannot be fetched.
    """
    roles = {"anchor": anchor, "pos": pos, "neg": neg}
    graphs = {}

    for role, pub_id in roles.items():
        neighbourhood, id_map = fetch_n_hop_neighbourhood(
            start_node_type=NodeType.PUBLICATION,
            start_node_id=pub_id,
            node_attr="vec",
            node_types=included_nodes,
            edge_types=included_edges,
            max_level=2
        )
        # Give up on the whole triplet when any member is unusable.
        if not (neighbourhood and id_map):
            return None

        graphs[role] = {
            "data": neighbourhood,
            "node_id_map": id_map,
            "pub_node_id": pub_id
        }

    return graphs

In [90]:
class TripletSampler:
    """Draws (anchor, positive, negative) publication triplets per author.

    ``self.data`` maps an author id to a dict with the lists ``"normal_data"``
    (publications used for anchor/positive) and ``"outliers"`` (publications
    used as negatives). Sampling consumes the data: every drawn anchor is
    removed from its author's ``"normal_data"`` and authors that can no longer
    supply a triplet are dropped.
    """

    # An initial draw plus this many retries are made before giving up.
    _MAX_ATTEMPTS = 30

    def __init__(self):
        self.data = WhoIsWhoDataset.parse_train()

    def sample_triplet_ids(self, attempt = 0):
        """Draw one ``(anchor_id, pos_id, neg_id)`` tuple of publication ids.

        Args:
            attempt: Number of attempts already spent; kept for backward
                compatibility with callers that pass it.

        Returns:
            The id triple, or ``None`` when no author is left or more than
            ``_MAX_ATTEMPTS`` consecutive authors had too little data.
        """
        # `self.data` is checked on every pass: random.choice raises
        # IndexError on an empty list once all authors have been dropped.
        while attempt <= self._MAX_ATTEMPTS and self.data:
            author_id = random.choice(list(self.data.keys()))

            normal_data = self.data[author_id]["normal_data"]
            outliers = self.data[author_id]["outliers"]
            # Needs two normal papers (anchor + positive) and one outlier.
            if len(normal_data) < 2 or len(outliers) == 0:
                self.data.pop(author_id)
                attempt += 1
                continue

            # Remove the anchor first so the positive cannot be the same entry
            # and the anchor is never drawn again.
            anchor_id = random.choice(normal_data)
            normal_data.remove(anchor_id)

            pos_sample = random.choice(normal_data)
            neg_sample = random.choice(outliers)
            return anchor_id, pos_sample, neg_sample

        print("Unable to sample more triples")
        return None

    def iter_triplets_rand(self, num_triplets):
        """Yield up to ``num_triplets`` triplet dicts built by ``sample_triplet``.

        Id triples whose graphs cannot be built are skipped. Iteration ends
        early when ``sample_triplet_ids`` cannot provide another id triple.
        """
        count = 0
        while count < num_triplets:
            ids = self.sample_triplet_ids()
            if ids is None:
                # Nothing left to sample: stop instead of failing on unpack.
                return

            anchor_id, pos_id, neg_id = ids
            triplet = sample_triplet(anchor_id, pos_id, neg_id)
            if not isinstance(triplet, dict):
                continue

            yield triplet
            count += 1

In [None]:
# Size of the generated dataset: number of files and triplets stored per file.
NUM_BATCHES = 10000
TRIPLETS_PER_BATCH = 10

triplet_sampler = TripletSampler()
path = "./data/triplet_dataset"

os.makedirs(path, exist_ok=True)

# Start from an empty folder so batches from an earlier run cannot mix in.
for stale_name in os.listdir(path):
    os.remove(os.path.join(path, stale_name))

with tqdm(total=NUM_BATCHES) as pbar:
    for batch_id in range(NUM_BATCHES):
        file = os.path.join(path, f"triplet_batch_{batch_id}.pt")
        data = list(triplet_sampler.iter_triplets_rand(TRIPLETS_PER_BATCH))

        # Save the triplets to disk (never write an empty batch file)
        if data:
            print(f"Saving batch {batch_id} to disk")
            torch.save(data, file)
            pbar.update(1)

        # A short batch means the sampler could not deliver more triplets,
        # so further iterations would only produce empty files.
        if len(data) < TRIPLETS_PER_BATCH:
            print(f"Stopping early at batch {batch_id}: only {len(data)} triplets sampled")
            break

  0%|          | 0/10000 [00:00<?, ?it/s]

Start node id: WC5A2FBn
Nodes: 121, Relationships: 376
HeteroData(
  Publication={ x=[48, 32] },
  Venue={ x=[1, 32] },
  Organization={ x=[11, 32] },
  Author={ x=[61, 32] },
  (Publication, PubVenue, Venue)={
    edge_index=[2, 48],
    edge_attr=[48, 24],
  },
  (Publication, PubAuthor, Author)={
    edge_index=[2, 61],
    edge_attr=[61, 24],
  },
  (Publication, PubOrg, Organization)={
    edge_index=[2, 18],
    edge_attr=[18, 24],
  },
  (Venue, VenuePub, Publication)={
    edge_index=[2, 48],
    edge_attr=[48, 24],
  },
  (Organization, OrgAuthor, Author)={
    edge_index=[2, 61],
    edge_attr=[61, 24],
  },
  (Organization, OrgPub, Publication)={
    edge_index=[2, 18],
    edge_attr=[18, 24],
  },
  (Author, AuthorOrg, Organization)={
    edge_index=[2, 61],
    edge_attr=[61, 24],
  },
  (Author, AuthorPub, Publication)={
    edge_index=[2, 61],
    edge_attr=[61, 24],
  }
)
Start node id: F9MIZ15O
Nodes: 720, Relationships: 1438
Too many nodes: 720
Start node id: p26CVS2m