In [14]:
import os
import json
from time import sleep
from typing import Any

from sklearn.metrics import accuracy_score
from torch.distributions.constraints import positive
from torch_geometric.nn.models.dimenet import triplets
from tqdm.notebook import tqdm
import pandas as pd
from graphdatascience import GraphDataScience
from neo4j import GraphDatabase
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, GATv2Conv, global_mean_pool, HeteroConv
from torch_geometric.transforms import RandomNodeSplit
from torch_geometric.data import HeteroData
import random
import numpy as np
from torch.nn.modules.loss import TripletMarginLoss
import matplotlib.pyplot as plt
from torch_geometric.data import Batch
from torch.optim import Adam
from src.shared.database_wrapper import DatabaseWrapper
from src.datasets.who_is_who import WhoIsWhoDataset
from src.model.GAT.gat_encoder import GATv2Encoder
from src.model.GAT.gat_decoder import GATv2Decoder
from src.shared.graph_schema import NodeType, EdgeType, node_one_hot, edge_one_hot, edge_pyg_key_vals
from src.model.loss.triplet_loss import TripletLoss
from src.shared import config
from torch.utils.data import Dataset
import networkx as nx
import plotly.graph_objects as go
from itertools import combinations, product
from src.shared.graph_sampling import GraphSampling

# Seed every random number generator the notebook touches
# (Python `random`, NumPy, PyTorch CPU and all CUDA devices) so runs are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

## Configurations

In [15]:
# --- Graph sampling configuration ------------------------------------------
# Sample publication nodes only, linked by the similarity edge types in `edge_spec`.
node_spec = [NodeType.PUBLICATION]

edge_spec = [
    EdgeType.SIM_TITLE,
    EdgeType.SIM_ABSTRACT,
    EdgeType.SIM_VENUE,
    EdgeType.SIM_AUTHOR,
]

# Properties fetched for every sampled publication node.
node_properties = [
    'id', 'title', 'abstract', 'venue',
    'title_emb', 'abstract_emb', 'venue_emb',
    'true_author',
]

gs = GraphSampling(
    node_spec=node_spec,
    edge_spec=edge_spec,
    node_properties=node_properties,
    database='homogeneous-graph',
)

# --- Model configuration ---------------------------------------------------
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model_node_feature = 'abstract_emb'     # node property used as the encoder's input features
model_edge_type = EdgeType.SIM_AUTHOR   # passed to db.iter_nodes_with_edge_count when harvesting pairs

## Graph Pair Data Handling

In [16]:
class GraphPairDataset(Dataset):
    """Dataset of publication pairs, each materialised as two neighbourhood graphs.

    Each item is ``(data1, data2, label)``:
      * ``data1`` / ``data2`` are graphs built by ``neo_to_pyg_homogeneous`` around
        each paper, with ``central_node_id`` set to the paper's local node index;
      * ``label`` is a float tensor holding the pair's label.
    Items that cannot be built return ``None`` so the collate function can drop them.

    Args:
        pairs: sequence of ``(paper_id1, paper_id2, label)`` triples.
        gs: graph sampler exposing ``n_hop_neighbourhood``.
        max_level: neighbourhood depth passed to ``n_hop_neighbourhood`` (default 3).
    """

    def __init__(self, pairs, gs, max_level=3):
        self.pairs = pairs  # List of tuples: (paper_id1, paper_id2, label)
        self.gs = gs
        self.max_level = max_level

    def __len__(self):
        return len(self.pairs)

    def _load_graph(self, paper_id):
        """Fetch and convert the neighbourhood of ``paper_id``; return ``None`` if unusable."""
        neighbourhood = self.gs.n_hop_neighbourhood(NodeType.PUBLICATION, paper_id, max_level=self.max_level)
        data, node_map = neo_to_pyg_homogeneous(neighbourhood, model_node_feature)
        if data is None:
            print(f"Paper {paper_id}: neighbourhood could not be converted to a graph.")
            return None
        if paper_id not in node_map:
            # The central paper was dropped during conversion (e.g. it lacks the feature),
            # so there is no local index to point at.
            print(f"Paper {paper_id}: central node missing from its neighbourhood graph.")
            return None
        data.central_node_id = torch.tensor([node_map[paper_id]])
        return data

    def __getitem__(self, idx):
        paper_id1, paper_id2, label = self.pairs[idx]
        try:
            data1 = self._load_graph(paper_id1)
            # Avoid a second database round-trip when the pair is already unusable.
            data2 = self._load_graph(paper_id2) if data1 is not None else None
        except Exception as e:
            # Best effort: a sampling/database failure skips this pair instead of aborting the epoch.
            print(f"Error processing pair ({paper_id1}, {paper_id2}): {e}")
            return None
        if data1 is None or data2 is None:
            return None
        return data1, data2, torch.tensor(label, dtype=torch.float)
        
class PairData(Data):
    """PyG ``Data`` whose ``central_node_id`` attribute batches like a node index.

    During collation ``central_node_id`` tensors are concatenated along dim 0
    (one entry per graph) and incremented by each graph's ``num_nodes``, so in
    the batched graph they still point at each graph's central node.
    """

    def __cat_dim__(self, key, value, *args, **kwargs):
        if key != 'central_node_id':
            return super().__cat_dim__(key, value, *args, **kwargs)
        # Concatenate along the batch dimension.
        return 0

    def __inc__(self, key: str, value: Any, *args, **kwargs) -> Any:
        if key != 'central_node_id':
            return super().__inc__(key, value, *args, **kwargs)
        # Offset by this graph's node count, exactly like edge_index.
        return self.num_nodes

In [17]:
def graph_data_valid(data: "Data") -> bool:
    """Return True if ``data`` is a non-empty, internally consistent homogeneous graph.

    The checks run in order and stop at the first failure:
      * the object is present;
      * ``x`` and ``edge_index`` exist;
      * there is at least one node and one edge;
      * ``x`` rows match ``num_nodes``;
      * ``edge_index`` has shape ``[2, num_edges]``;
      * ``x`` is float32/float64;
      * every edge endpoint lies in ``[0, num_nodes)``.
    The failing check is printed as ``Data check failed: <reason>`` and False is returned.

    Explicit checks are used instead of ``assert`` so validation still happens when
    Python runs with ``-O`` (which strips assert statements).
    """
    # Predicates are evaluated lazily so later checks never touch a missing attribute.
    checks = (
        (lambda d: d is not None, "Data object is None."),
        (lambda d: d.x is not None, "Node features 'x' are missing."),
        (lambda d: d.edge_index is not None, "Edge index 'edge_index' is missing."),
        (lambda d: d.num_nodes > 0, "Number of nodes must be greater than 0."),
        (lambda d: d.num_edges > 0, "Number of edges must be greater than 0."),
        (lambda d: d.x.size(0) == d.num_nodes, "Mismatch between 'x' size and 'num_nodes'."),
        (lambda d: d.edge_index.size(0) == 2, "'edge_index' should have shape [2, num_edges]."),
        (lambda d: d.edge_index.size(1) == d.num_edges, "Mismatch between 'edge_index' and 'num_edges'."),
        (lambda d: d.x.dtype in (torch.float32, torch.float64), "Node features 'x' must be floating point."),
        (lambda d: bool(d.edge_index.min() >= 0) and bool(d.edge_index.max() < d.num_nodes),
         "'edge_index' contains invalid node indices."),
    )
    for passes, reason in checks:
        if not passes(data):
            print(f"Data check failed: {reason}")
            return False
    return True

In [18]:
def neo_to_pyg_homogeneous(
        data,
        node_attr: str,
):
    """Convert a Neo4j neighbourhood result into a homogeneous PyG graph.

    Args:
        data: mapping with ``"nodes"`` (records read via ``.get``) and
            ``"relationships"`` (objects exposing ``start_node`` / ``end_node``).
        node_attr: node property whose value becomes that node's feature row.

    Returns:
        ``(pyg_data, node_id_map)``:
          * ``pyg_data`` is a ``PairData`` with ``x``, ``num_nodes`` and ``edge_index`` set;
          * ``node_id_map`` maps each kept node ``id`` to its row in ``x``.
        Returns ``(None, None)`` if ``data`` is empty, if no node carries
        ``node_attr``, or if no edge joins two kept nodes. Nodes lacking
        ``node_attr``, and edges touching them, are skipped with a printed message.
    """
    if not data:
        return None, None

    nodes, relationships = data["nodes"], data["relationships"]

    # Keep only nodes carrying the feature; insertion order defines the row index.
    id_to_row = {}
    rows = []
    for record in nodes:
        record_id = record.get("id")
        feature = record.get(node_attr, None)
        if feature is not None:
            id_to_row[record_id] = len(rows)
            rows.append(torch.tensor(feature, dtype=torch.float32))
        else:
            print(f"Node {record_id} has no attribute {node_attr}")

    if not rows:
        print("No node features available.")
        return None, None

    graph = PairData()
    graph.x = torch.vstack(rows)
    graph.num_nodes = graph.x.size(0)

    # Translate database ids into row indices, dropping edges to/from skipped nodes.
    sources, targets = [], []
    for rel in relationships:
        src_id = rel.start_node.get("id")
        dst_id = rel.end_node.get("id")
        if src_id in id_to_row and dst_id in id_to_row:
            sources.append(id_to_row[src_id])
            targets.append(id_to_row[dst_id])
        else:
            print(f"Edge from {src_id} to {dst_id} cannot be mapped to node indices.")

    if not sources:
        print("No edges available.")
        return None, None

    graph.edge_index = torch.tensor([sources, targets], dtype=torch.long)
    return graph, id_to_row

### Harvest positive and negative paper pairs from the graph database

In [19]:
class DataHarvester:
    """Builds and caches labelled publication pairs for contrastive training.

    Each pair is ``(paper_id1, paper_id2, label)``:
      * ``label`` is 1 when both papers come from the same author's ``normal_data``;
      * ``label`` is 0 when the second paper is one of that author's outliers or
        a paper outside the author's ``normal_data``.
    Pairs are read from ``pairs_file`` when it exists. Otherwise they are
    generated from the database and written there.

    Args:
        db: database wrapper used to iterate publication nodes.
        gs: graph sampler used to fetch each paper's neighbourhood.
        pairs_file: JSON cache location (default ``./data/pairs.json``).
        max_pairs_per_author: cap applied separately to positive and negative pairs per author.
    """

    def __init__(self, db: "DatabaseWrapper", gs: "GraphSampling",
                 pairs_file: str = './data/pairs.json', max_pairs_per_author: int = 50):
        self.db = db
        self.gs = gs
        self.pairs_file = pairs_file
        self.max_pairs_per_author = max_pairs_per_author
        self.pairs = []
        self.prepare_pairs()

    def prepare_pairs(self):
        """Load cached pairs, or generate and cache them when the file is missing."""
        print("Preparing pairs...")
        try:
            print("Loading pairs...")
            self.load_pairs(self.pairs_file)
            print(f"Loaded {len(self.pairs)} pairs.")
        except FileNotFoundError:
            print("Could not load pairs from file. Generating pairs...")
            self.generate_pairs()
            print(f"Generated {len(self.pairs)} pairs.")
            print("Saving pairs...")
            self.save_pairs(self.pairs_file)
            print("Pairs saved.")

    def load_pairs(self, file_path):
        """Replace ``self.pairs`` with the JSON list at ``file_path``.

        JSON has no tuple type, so pairs come back as 3-element lists.
        """
        with open(file_path, 'r') as f:
            self.pairs = json.load(f)

    def save_pairs(self, file_path):
        """Write ``self.pairs`` as JSON to ``file_path``, creating parent directories first."""
        parent = os.path.dirname(file_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        with open(file_path, 'w') as f:
            json.dump(self.pairs, f)

    def generate_pairs(self):
        """Sample positive/negative pairs for every WhoIsWho author into ``self.pairs``."""
        paper_ids = self._collect_valid_paper_ids()
        print(f"Total papers: {len(paper_ids)}")
        print("Preparing pairs...")
        paper_set = set(paper_ids)

        pairs = []
        for data in WhoIsWhoDataset.parse_train().values():
            # Restrict the author's papers to those that passed the validity check
            # (without mutating the parsed dataset).
            normal_data = [p_id for p_id in data.get('normal_data', []) if p_id in paper_set]
            outliers = [p_id for p_id in data.get('outliers', []) if p_id in paper_set]
            pairs.extend(self._sample_author_pairs(normal_data, outliers, paper_ids))

        print(f"Total pairs: {len(pairs)}. Done.")
        self.pairs = pairs

    def _collect_valid_paper_ids(self):
        """Return ids of publications whose 1-hop neighbourhood passes ``graph_data_valid``."""
        print("Checking data validity...")
        paper_ids = []
        node_batches = self.db.iter_nodes_with_edge_count(NodeType.PUBLICATION, model_edge_type, ['id', 'true_author'])
        for nodes in node_batches:
            for node in nodes:
                # Use the sampler this harvester was given, not a notebook-global one.
                neighbourhood = self.gs.n_hop_neighbourhood(NodeType.PUBLICATION, node['id'], max_level=1)
                data, _ = neo_to_pyg_homogeneous(neighbourhood, model_node_feature)
                if graph_data_valid(data):
                    paper_ids.append(node['id'])
        return paper_ids

    def _sample_author_pairs(self, normal_data, outliers, paper_ids):
        """Return labelled pairs for one author.

        Positives are pairs within ``normal_data``. Negatives are ``normal_data`` x
        ``outliers``, topped up with random papers from ``paper_ids`` outside
        ``normal_data`` until there are as many negatives as positives.
        Each side is capped at ``self.max_pairs_per_author``.
        """
        cap = self.max_pairs_per_author
        pairs = []

        pos_pairs = list(combinations(normal_data, 2))
        if len(pos_pairs) > cap:
            pos_pairs = random.sample(pos_pairs, cap)
        pairs.extend((p1, p2, 1) for p1, p2 in pos_pairs)

        neg_pairs = list(product(normal_data, outliers))
        if len(neg_pairs) > cap:
            neg_pairs = random.sample(neg_pairs, cap)
        elif len(neg_pairs) < len(pos_pairs):
            normal_set = set(normal_data)
            # Rejection sampling below never terminates if every paper belongs to this author.
            if any(p_id not in normal_set for p_id in paper_ids):
                while len(neg_pairs) < len(pos_pairs):
                    p1 = random.choice(normal_data)
                    p2 = random.choice(paper_ids)
                    if p2 not in normal_set:
                        neg_pairs.append((p1, p2))
            else:
                print("No papers outside this author's normal_data; skipping negative top-up.")
        pairs.extend((p1, p2, 0) for p1, p2 in neg_pairs)
        return pairs
                

## GAT Encoder

In [20]:

class GATEncoder(torch.nn.Module):
    """Two-layer GATv2 node encoder.

    The first layer uses ``num_heads`` attention heads; its concatenated output
    (``hidden_channels * num_heads`` features) goes through ELU into a
    single-head second layer that produces ``out_channels`` features per node.
    ``in_channels=-1`` lets the first layer infer its input width lazily.
    """

    def __init__(self, hidden_channels, out_channels, num_heads=8):
        super().__init__()
        concat_width = hidden_channels * num_heads
        self.conv1 = GATv2Conv(-1, hidden_channels, heads=num_heads)
        self.conv2 = GATv2Conv(concat_width, out_channels)

    def forward(self, data):
        """Return one embedding per node of ``data`` (uses ``data.x`` and ``data.edge_index``)."""
        edge_index = data.edge_index
        hidden = F.elu(self.conv1(data.x, edge_index))
        return self.conv2(hidden, edge_index)

In [21]:
def contrastive_loss(embeddings1, embeddings2, labels, margin=1.0):
    """Mean contrastive loss over a batch of embedding pairs.

    Positive pairs (``labels == 1``) are penalised by their squared Euclidean
    distance. Negative pairs (``labels == 0``) are penalised by
    ``max(0, margin - distance) ** 2``, i.e. only while they are closer than ``margin``.
    """
    dist = F.pairwise_distance(embeddings1, embeddings2)
    pull_together = labels * dist ** 2
    push_apart = (1 - labels) * F.relu(margin - dist) ** 2
    return (pull_together + push_apart).mean()


In [22]:
# custom collate function adjusted for GraphPairDataset
def custom_collate(batch):
    """Collate ``GraphPairDataset`` items into ``(batch1, batch2, labels)``.

    ``None`` items (pairs the dataset failed to build) are dropped. If nothing is
    left, ``(None, None, None)`` is returned, so loops written as
    ``for data1, data2, labels in loader`` can skip the batch instead of failing
    to unpack a bare ``None``.
    """
    items = [item for item in batch if item is not None]
    if not items:
        return None, None, None

    graphs1, graphs2, labels = zip(*items)
    return (
        Batch.from_data_list(list(graphs1)),
        Batch.from_data_list(list(graphs2)),
        torch.stack(labels),
    )

In [23]:
# Plot loss
def plot_loss(losses, epoch_marker_pos=None, plot_title="Loss", plot_file=None):
    if plot_file is None:
        plot_file = f'./data/losses/loss.png'
    if epoch_marker_pos is None:
        epoch_marker_pos = []
        
    plt.figure(figsize=(10, 6))
    plt.plot(losses, label=f'Loss')
    plt.xlabel('Iteration')
    plt.ylabel('Loss')
    
    for ix, x_pos in enumerate(epoch_marker_pos):
        plt.axvline(x=x_pos, color='red', linestyle='dotted', linewidth=1)
        plt.text(
        x_pos,
        max(losses),
        f'Epoch {ix}',
        rotation=90,
        verticalalignment='top',
        horizontalalignment='right',
        fontsize=10,
        color='red'
    )

    plt.title(plot_title)
    plt.legend()
    plt.grid(True)
    plt.savefig(plot_file)
    plt.close()

In [24]:
def train(model, batch1, batch2, labels, optimizer, margin):
    """Run one optimisation step on a batch of graph pairs and return its loss.

    Args:
        model: encoder returning one embedding per node of a batched graph.
        batch1, batch2: batched graphs carrying ``central_node_id`` (one index per pair).
        labels: float tensor, 1 for positive pairs and 0 for negative pairs.
        optimizer: optimizer over ``model``'s parameters.
        margin: contrastive-loss margin.

    Returns:
        float: the batch-mean contrastive loss, on the same scale as ``test()``.
    """
    model.train()
    optimizer.zero_grad()

    batch1 = batch1.to(device)
    batch2 = batch2.to(device)
    labels = labels.to(device)

    # Only each pair's queried paper is compared; central_node_id indexes it in the batched graph.
    embeddings1_central = model(batch1)[batch1.central_node_id]
    embeddings2_central = model(batch2)[batch2.central_node_id]

    loss = contrastive_loss(embeddings1_central, embeddings2_central, labels, margin)
    loss.backward()
    optimizer.step()

    # contrastive_loss already averages over the batch; dividing by len(labels) again
    # would shrink the value by the batch size and make it incomparable with test().
    return loss.item()

def test(model, dataloader, margin=1.0):
    """Return the mean contrastive loss of ``model`` over ``dataloader`` without gradients.

    Args:
        model: encoder returning one embedding per node of a batched graph.
        dataloader: yields ``(batch1, batch2, labels)``. Batches that are ``None``,
            or whose graphs are ``None`` (every pair failed to load), are skipped.
        margin: contrastive-loss margin. Defaults to 1.0, the value the training
            cell configures, and is passed explicitly rather than read from a
            notebook global.

    Returns:
        float: average loss over the batches actually evaluated, or NaN if none were.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for batch in dataloader:
            if batch is None:
                continue
            data1, data2, labels = batch
            if data1 is None or data2 is None:
                continue  # Skip empty batches

            data1 = data1.to(device)
            data2 = data2.to(device)
            labels = labels.to(device)

            embeddings1 = model(data1)[data1.central_node_id]
            embeddings2 = model(data2)[data2.central_node_id]

            total_loss += contrastive_loss(embeddings1, embeddings2, labels, margin).item()
            num_batches += 1

    if num_batches == 0:
        print("Test Loss: no valid batches")
        return float('nan')

    # Average only over evaluated batches so skipped ones don't bias the result downwards.
    avg_loss = total_loss / num_batches
    print(f"Test Loss: {avg_loss:.4f}")
    return avg_loss
    

## Training 

In [25]:
db = DatabaseWrapper(database='homogeneous-graph')
data_harvester = DataHarvester(db=db, gs=gs)

# Split the pairs into train (95%) and test (5%).
# Shuffle a copy so the harvester's list is left untouched.
all_pairs = list(data_harvester.pairs)
random.shuffle(all_pairs)
train_size = int(0.95 * len(all_pairs))
train_pairs = all_pairs[:train_size]
test_pairs = all_pairs[train_size:]
# NOTE(review): the split is per pair, so the same paper can occur in both train and test pairs.

train_dataset = GraphPairDataset(train_pairs, gs)
test_dataset = GraphPairDataset(test_pairs, gs)

# Use torch's DataLoader so custom_collate is honoured. torch_geometric.loader.DataLoader
# discards any collate_fn argument in favour of its own Collater, which cannot skip the
# None items GraphPairDataset returns for pairs it fails to load.
batch_size = 32
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate)

# Create model
model = GATEncoder(128, 32).to(device)
optimizer = Adam(model.parameters(), lr=0.005, weight_decay=5e-4)

2024-10-12 12:20:32,162 - DatabaseWrapper - INFO - Connecting to the database ...
2024-10-12 12:20:32,163 - DatabaseWrapper - INFO - Database ready.


Preparing pairs...
Loading pairs...
Loaded 6184 pairs.


In [None]:
num_epochs = 20
margin = 1.0  # Margin for contrastive loss
eval_every = 10  # Run test() and refresh the loss plots every `eval_every` training steps
model_path = './data/models/gat_encoder_homo_edges.pt'
train_losses = []
test_losses = []
best_test_loss = float('inf')

for epoch in range(1, num_epochs + 1):
    # Training-step indices at which each epoch so far started (one marker per epoch).
    epoch_marker_pos = list(range(0, len(train_dataloader) * epoch, len(train_dataloader)))

    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch}/{num_epochs}"):
        # The collate function may yield None or a triple of Nones when every pair failed to load.
        if batch is None:
            continue
        data1, data2, labels = batch
        if data1 is None or data2 is None:
            continue

        if len(train_losses) % eval_every == 0:
            test_loss = test(model, test_dataloader)
            test_losses.append(test_loss)
            # test_losses has one entry per `eval_every` steps, so rescale the step markers.
            test_epoch_marker_pos = [marker / eval_every for marker in epoch_marker_pos if marker != 0]
            plot_loss(test_losses, epoch_marker_pos=test_epoch_marker_pos, plot_title='Test Loss', plot_file='./data/losses/test_loss_homo_edges.png')
            if train_losses:
                plot_loss(train_losses, epoch_marker_pos=epoch_marker_pos, plot_title='Training Loss', plot_file='./data/losses/train_loss_homo_edges.png')

            # Checkpoint as soon as a new best is measured, so the saved weights are
            # exactly the ones that achieved best_test_loss.
            if test_loss < best_test_loss:
                best_test_loss = test_loss
                os.makedirs(os.path.dirname(model_path), exist_ok=True)
                print(f"Saving model at epoch {epoch} (test loss {test_loss:.4f})...")
                torch.save(model.state_dict(), model_path)

        loss = train(model, data1, data2, labels, optimizer, margin)
        train_losses.append(loss)

    # Refresh the training curve at the end of each epoch instead of writing a PNG every step.
    if train_losses:
        plot_loss(train_losses, epoch_marker_pos=epoch_marker_pos, plot_title='Training Loss', plot_file='./data/losses/train_loss_homo_edges.png')

Epoch 1/20:   0%|          | 0/184 [00:00<?, ?it/s]

Test Loss: 0.3545


In [None]:
# Evaluation function
def evaluate(model, dataloader):
    """Collect central-paper embedding distances and labels for every pair in ``dataloader``.

    Returns:
        ``(distances, labels_list)``: parallel lists.
          * ``distances[i]`` is the Euclidean distance between the two
            central-paper embeddings of pair ``i``;
          * ``labels_list[i]`` is that pair's label (1 positive, 0 negative).
        Batches that are ``None`` or contain ``None`` graphs are skipped.
    """
    model.eval()
    distances = []
    labels_list = []
    with torch.no_grad():
        for batch in dataloader:
            if batch is None:
                continue
            batch1, batch2, labels = batch
            if batch1 is None or batch2 is None:
                continue
            batch1 = batch1.to(device)
            batch2 = batch2.to(device)
            # Compare only each pair's queried paper, as train()/test() do. The full node
            # embedding matrices of the two batches generally have different row counts.
            embeddings1 = model(batch1)[batch1.central_node_id]
            embeddings2 = model(batch2)[batch2.central_node_id]
            dist = F.pairwise_distance(embeddings1, embeddings2).cpu().numpy()
            distances.extend(dist)
            labels_list.extend(labels.cpu().numpy())
    return distances, labels_list

# After training
from sklearn.metrics import roc_auc_score

# Score the held-out pairs built in the training-setup cell.
distances, labels_list = evaluate(model, test_dataloader)
roc_auc = roc_auc_score(labels_list, -np.array(distances))  # Negative distances for similarity
print(f"ROC AUC Score: {roc_auc:.4f}")
