In [None]:
import os
from collections import defaultdict

import pandas as pd
from graphdatascience import GraphDataScience
from neo4j import GraphDatabase
import torch
from tqdm.notebook import tqdm
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import random
import numpy as np

from src.datasets.who_is_who import WhoIsWhoDataset
from src.shared.database_wrapper import DatabaseWrapper
from src.model.GAT.gat_encoder import GATv2Encoder
from src.model.loss.triplet_loss import TripletLoss
from src.model.GAT.gat_decoder import GATv2Decoder
from src.shared.graph_schema import NodeType, EdgeType, node_one_hot, edge_one_hot
from src.shared import config

In [None]:
# Single seed shared by every RNG this notebook touches: Python `random`, NumPy's
# legacy global RNG, torch's CPU generator and all CUDA devices.
# NOTE: this does not pin PYTHONHASHSEED, so iteration order over sets of strings
# can still differ between kernel runs.
SEED = 42

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

In [None]:
def normalize_topology(new_idx_to_old, topology):
    """Rewrite every node id in ``topology`` to its new contiguous index.

    Args:
        new_idx_to_old: Mapping ``new_index -> original_node_id``.
        topology: Mapping ``rel_type -> [source_ids, target_ids]`` (lists of original ids).

    Returns:
        A new dict with the same keys and nesting, where each original id is replaced
        by its new index. Raises ``KeyError`` for an id missing from ``new_idx_to_old``.
    """
    # Invert the mapping once so each lookup below is O(1).
    old_to_new = {old_id: new_idx for new_idx, old_id in new_idx_to_old.items()}

    remapped = {}
    for rel_type, endpoint_lists in topology.items():
        remapped[rel_type] = [
            [old_to_new[node_id] for node_id in endpoints]
            for endpoints in endpoint_lists
        ]
    return remapped

def create_edge_index(topology):
    """Flatten a per-relationship topology into PyG ``edge_index`` / ``edge_attr`` tensors.

    Args:
        topology: Mapping ``rel_type -> [source_indices, target_indices]`` whose entries
            are already contiguous node indices (see ``normalize_topology``).

    Returns:
        ``(edge_index, edge_features)``. ``edge_index`` is a ``torch.long`` tensor of shape
        ``(2, E)``, with relationship types concatenated in the dict's iteration order.
        ``edge_features`` stacks one copy of ``edge_one_hot[rel_type]`` per edge, in the
        same order (presumably ``(E, D)`` for 1-D one-hot vectors — confirm).
        Raises from ``torch.cat`` when ``topology`` is empty.
    """
    index_blocks = []
    feature_rows = []
    for rel_type, (src_nodes, dst_nodes) in topology.items():
        index_blocks.append(torch.tensor([src_nodes, dst_nodes], dtype=torch.long))
        # One (shared) feature vector per edge of this relationship type.
        feature_rows += [edge_one_hot[rel_type]] * len(src_nodes)
    return torch.cat(index_blocks, dim=1), torch.vstack(feature_rows)

def project_node_embeddings(node_df):
    """Build one feature tensor per node: node-type one-hot followed by the node's ``vec``.

    Args:
        node_df: DataFrame with a ``nodeLabels`` column (list of labels) and a ``vec``
            column. Only the first label of each node selects the one-hot encoding.

    Returns:
        A Series aligned with ``node_df`` whose values are the ``torch.hstack`` of
        ``node_one_hot[label]`` and ``torch.tensor(vec)``.
        NOTE(review): a node without ``vec`` (None) makes ``torch.tensor`` raise.
    """
    def _row_to_feature(row):
        primary_label = row["nodeLabels"][0]
        label_encoding = node_one_hot[primary_label]
        embedding = torch.tensor(row["vec"])
        return torch.hstack((label_encoding, embedding))

    return node_df.apply(_row_to_feature, axis=1)


In [None]:
# Module-level Neo4j driver shared by every query in this notebook; the connection
# settings come from src.shared.config.
driver = GraphDatabase.driver(
    config.DB_URI,
    auth=(config.DB_USER, config.DB_PASSWORD),
)

def fetch_n_hop_neighbourhood(start_node_type: NodeType, start_node_id: str, node_attr: str, node_types: list = None, edge_types: list = None, max_level: int = 1):
    """Fetch the subgraph within ``max_level`` hops of a start node using APOC.

    Runs ``apoc.path.subgraphAll`` from the node labelled ``start_node_type.value`` whose
    ``id`` property equals ``start_node_id``. Every relationship type in the filter is
    prefixed with ``<``, so only incoming relationships are followed. The label filter
    whitelists (``+``) the given node types.

    Args:
        start_node_type: Label of the start node.
        start_node_id: Value of the start node's ``id`` property.
        node_attr: Node property copied into the returned DataFrame (``None`` if absent).
        node_types: Labels the traversal may visit; all ``NodeType`` members if ``None``.
        edge_types: Relationship types to follow; all ``EdgeType`` members if ``None``.
        max_level: Maximum traversal depth.

    Returns:
        ``(node_df, edge_dict)``:
        - ``node_df`` has columns ``nodeId``, ``<node_attr>`` and ``nodeLabels``. The labels
          are sorted so that ``nodeLabels[0]`` is reproducible across processes.
        - ``edge_dict`` maps relationship type -> ``[source_ids, target_ids]``.
        Returns ``(None, None)`` when no start node matches.
    """
    node_filter = '|'.join(
        nt.value for nt in (NodeType if node_types is None else node_types)
    )
    edge_filter = '|'.join(
        f"<{et.value}" for et in (EdgeType if edge_types is None else edge_types)
    )

    # Labels cannot be Cypher parameters, so the enum-controlled start label is
    # interpolated. Every other value, in particular the start node id (which comes
    # from the dataset), is bound as a parameter: quotes in an id can then neither
    # break the query nor inject Cypher.
    query = f"""
            MATCH (start:{start_node_type.value} {{id: $start_node_id}})
            CALL apoc.path.subgraphAll(start, {{
              maxLevel: $max_level,
              relationshipFilter: $relationship_filter,
              labelFilter: $label_filter
            }}) YIELD nodes, relationships
            RETURN nodes, relationships
        """
    parameters = {
        "start_node_id": start_node_id,
        "max_level": max_level,
        "relationship_filter": edge_filter,
        "label_filter": f"+{node_filter}",
    }

    with driver.session() as session:
        record = session.run(query, parameters).single()
        # single() returns None when nothing matched the start node; callers treat
        # (None, None) as "skip this node".
        if record is None:
            return None, None

        node_data = [
            {
                "nodeId": node.get("id"),
                node_attr: node.get(node_attr, None),
                # Node.labels is an unordered frozenset. Sort it so that the label picked
                # downstream (nodeLabels[0]) does not depend on string hash randomisation.
                "nodeLabels": sorted(node.labels),
            }
            for node in record["nodes"]
        ]
        node_df = pd.DataFrame(node_data)

        # Group relationship endpoints (by their `id` property) per relationship type.
        edge_dict = {}
        for rel in record["relationships"]:
            sources, targets = edge_dict.setdefault(rel.type, [[], []])
            sources.append(rel.start_node.get("id"))
            targets.append(rel.end_node.get("id"))

    return node_df, edge_dict

In [None]:
# Node labels the neighbourhood traversal may visit (passed as `node_types`
# to fetch_n_hop_neighbourhood by the sampling helpers).
included_nodes = [
    NodeType.PUBLICATION, NodeType.VENUE, NodeType.ORGANIZATION,
    NodeType.AUTHOR, NodeType.CO_AUTHOR,
]

# Relationship types the traversal may follow (passed as `edge_types`).
included_edges = [
    EdgeType.PUB_VENUE, EdgeType.PUB_ORG,
    EdgeType.SIM_VENUE, EdgeType.SIM_ORG,
    EdgeType.ORG_PUB, EdgeType.VENUE_PUB,
    EdgeType.PUB_AUTHOR, EdgeType.AUTHOR_PUB,
    EdgeType.AUTHOR_CO_AUTHOR, EdgeType.CO_AUTHOR_AUTHOR,
    EdgeType.PUB_CO_AUTHOR, EdgeType.CO_AUTHOR_PUB,
    EdgeType.AUTHOR_ORG, EdgeType.ORG_AUTHOR,
    EdgeType.CO_AUTHOR_ORG, EdgeType.ORG_CO_AUTHOR,
]

def _publication_graph_data(node_id, max_level):
    """Build a PyG ``Data`` graph for the ``max_level``-hop neighbourhood of one publication.

    Node features come from ``project_node_embeddings`` (node-type one-hot + ``vec``).
    Edges are re-indexed to the DataFrame row positions of their endpoints.

    Returns:
        ``Data(x, edge_index, edge_attr)``, or ``None`` when any of these holds:
        the publication is not found, its neighbourhood has no nodes, or it has no
        relationships of the included types.
    """
    node_df, topology = fetch_n_hop_neighbourhood(
        start_node_type=NodeType.PUBLICATION,
        start_node_id=node_id,
        node_attr="vec",
        node_types=included_nodes,
        edge_types=included_edges,
        max_level=max_level,
    )
    if node_df is None or len(node_df) == 0:
        return None

    node_df["vec_projected"] = project_node_embeddings(node_df)
    # Row position -> original node id; normalize_topology inverts this mapping.
    normalized_node_ids = dict(enumerate(node_df["nodeId"]))
    normalized_topology = normalize_topology(normalized_node_ids, topology)
    if not normalized_topology:
        return None

    edge_index, edge_features = create_edge_index(normalized_topology)
    node_features = torch.vstack(node_df["vec_projected"].tolist())
    return Data(x=node_features, edge_index=edge_index, edge_attr=edge_features)


def sample_subgraph(node_list, max_level=3):
    """Return a list of ``Data`` graphs, one per publication id in ``node_list``.

    Publications whose neighbourhood cannot be built are silently skipped, so the
    result may be shorter than ``node_list``.
    """
    dataset = []
    for node_id in node_list:
        graph = _publication_graph_data(node_id, max_level)
        if graph is not None:
            dataset.append(graph)
    return dataset


def sample_triplet(triplet_ids, max_level=2):
    """Build ``[anchor, positive, negative]`` graphs for three publication ids.

    Returns ``None`` unless exactly three graphs were built. Fetching stops at the first
    id whose neighbourhood cannot be built, because the triplet is unusable anyway.
    """
    triplet = []
    for node_id in triplet_ids:
        graph = _publication_graph_data(node_id, max_level)
        if graph is None:
            return None
        triplet.append(graph)
        print(f"Sampled triplet {node_id}")
        print(f"Num nodes: {graph.x.size(0)}, node features: {graph.x.shape}")
        # edge_index has shape (2, E): the edge count is its second dimension.
        print(f"Num edges: {graph.edge_index.size(1)}, edge features: {graph.edge_attr.shape}")
    if len(triplet) != 3:
        return None
    return triplet

In [None]:
# NOTE(review): neither `db_wrapper` nor `start_nodes` is referenced by any visible
# cell. DatabaseWrapper() is still constructed here until that is confirmed unneeded.
db_wrapper = DatabaseWrapper()
start_nodes = list()

def iter_triplet_ids_rand(max_triplets=1000, data=None):
    """Yield random ``(anchor_id, positive_id, negative_id)`` publication-id triplets.

    For each of ``max_triplets`` attempts, a random author is drawn. The anchor and
    positive are two distinct entries of that author's ``normal_data``; the negative is
    one of the author's ``outliers``. An attempt is skipped when the author has fewer
    than 3 remaining normal papers or no outliers, so at most ``max_triplets`` triplets
    are yielded.

    Each chosen anchor is removed from a per-run copy of the author's normal papers, so
    it is not drawn again during this generator's lifetime. ``data`` itself is never
    mutated.

    Args:
        max_triplets: Number of sampling attempts.
        data: Mapping ``author_id -> {"normal_data": [...], "outliers": [...]}``.
            Defaults to ``WhoIsWhoDataset.parse_train()``.
    """
    if data is None:
        data = WhoIsWhoDataset.parse_train()

    author_ids = list(data.keys())  # hoisted: the author set never changes
    if not author_ids:
        return
    # Private copies, so removing anchors never drains the caller's dataset.
    normal_pools = {author_id: list(entry["normal_data"]) for author_id, entry in data.items()}

    for _ in range(max_triplets):
        author_id = random.choice(author_ids)
        normal_data = normal_pools[author_id]
        outliers = data[author_id]["outliers"]
        # Check feasibility before drawing: the anchor plus at least two remaining
        # positives are required. Drawing first would let an author's pool shrink to
        # empty and make random.choice raise IndexError.
        if len(normal_data) < 3 or len(outliers) == 0:
            continue

        anchor_id = random.choice(normal_data)
        normal_data.remove(anchor_id)
        pos_id = random.choice(normal_data)
        neg_id = random.choice(outliers)
        yield anchor_id, pos_id, neg_id

def iter_triplets_rand(max_triplets = 1000):
    """Fetch random triplets from Neo4j and yield at most ``max_triplets`` of them.

    Draws up to ``max_triplets * 10`` id triplets from ``iter_triplet_ids_rand`` and turns
    each into graphs via ``sample_triplet``. Triplets that cannot be built are skipped.
    Each yielded item is ``[anchor_graph, positive_graph, negative_graph]``.
    """
    count = 0
    for anchor_id, pos_id, neg_id in iter_triplet_ids_rand(max_triplets * 10):
        # Stop before issuing any further queries once the quota is reached
        # (>= so exactly max_triplets triplets are yielded, none for max_triplets=0).
        if count >= max_triplets:
            break
        print(f"Fetching triplet {anchor_id}, {pos_id}, {neg_id}")
        triplet = sample_triplet([anchor_id, pos_id, neg_id])
        if not triplet:
            print(f"Skipping triplet {anchor_id}, {pos_id}, {neg_id}")
            continue
        yield triplet
        count += 1
        
        

In [None]:
# Node input width: node-type one-hot length plus the stored `vec` that
# project_node_embeddings appends (this assumes every `vec` has 32 entries).
node_feature_dim = NodeType.PUBLICATION.one_hot().shape[0] + 32
# Edge width read off one EdgeType member; presumably all members share it.
edge_feature_dim = EdgeType.SIM_TITLE.one_hot().shape[0]
gat_embedding_dim = 32

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

encoder = GATv2Encoder(
    in_channels=node_feature_dim,
    hidden_channels=32,
    out_channels=gat_embedding_dim,
    edge_dim=edge_feature_dim,
    add_self_loops=False,
)
encoder.to(device)

In [None]:
def train_gat(encoder, dataloader, epochs=1000, lr=0.01):
    """Train ``encoder`` with a triplet loss on ``(anchor, positive, negative)`` batches.

    Args:
        encoder: Model called as ``encoder(x, edge_index, edge_attr)``.
        dataloader: Iterable of ``(anchor, pos, neg)`` batches. Each item exposes ``x``,
            ``edge_index``, ``edge_attr`` and ``.to(device)``. It may be backed by an
            ``IterableDataset`` without ``__len__``, so batches are counted here instead
            of using ``len(dataloader)``.
        epochs: Number of passes over ``dataloader``.
        lr: SGD learning rate.

    Moves batches to the module-level ``device``. Prints the mean batch loss of every
    10th epoch (``nan`` if the epoch produced no batches).
    """
    optimizer = optim.SGD(encoder.parameters(), lr=lr)
    triplet_loss = TripletLoss()

    for epoch in range(epochs):
        encoder.train()

        total_loss = 0.0
        num_batches = 0

        for anchor, pos, neg in dataloader:
            anchor = anchor.to(device)
            pos = pos.to(device)
            neg = neg.to(device)

            optimizer.zero_grad()

            # NOTE(review): the encoder gets no batch vector, so these look like per-node
            # embeddings. The three batches generally have different node counts, so
            # confirm that TripletLoss reduces them to comparable per-graph vectors.
            anchor_emb = encoder(anchor.x, anchor.edge_index, anchor.edge_attr)
            pos_emb = encoder(pos.x, pos.edge_index, pos.edge_attr)
            neg_emb = encoder(neg.x, neg.edge_index, neg.edge_attr)

            loss = triplet_loss.forward(anchor_emb, pos_emb, neg_emb)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        if epoch % 10 == 0:
            mean_loss = total_loss / num_batches if num_batches else float("nan")
            print(f'Epoch {epoch}, Loss: {mean_loss}')

In [None]:
from torch.utils.data import IterableDataset
class TripletIterableDataset(IterableDataset):
    """Streaming dataset of ``[anchor, positive, negative]`` graph triplets.

    Every call to ``__iter__`` starts a fresh ``iter_triplets_rand`` generator with
    ``max_triplets``, so each epoch samples (and fetches from Neo4j) anew.
    """

    def __init__(self, max_triplets=1000):
        super().__init__()
        self.max_triplets = max_triplets

    def __iter__(self):
        # NOTE(review): with DataLoader(num_workers > 0), each worker would run its own
        # copy of this generator. Not an issue for a single-process loader.
        return iter_triplets_rand(self.max_triplets)
    
# Stream triplets straight from Neo4j; the PyG loader batches the anchors, positives
# and negatives of 3 triplets at a time, then the encoder is trained on them.
dataset = TripletIterableDataset(max_triplets=1000)
dataloader = DataLoader(dataset=dataset, batch_size=3)
train_gat(encoder=encoder, dataloader=dataloader, epochs=1000, lr=0.01)

In [None]:
# Idea: predict links between papers purely from graph structure, treating two papers as linked when they were written by the same author.