In [55]:
import os
import json
from time import sleep

import pandas as pd
from graphdatascience import GraphDataScience
from neo4j import GraphDatabase
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv
from torch_geometric.transforms import RandomNodeSplit
from torch_geometric.data import HeteroData
import random
import numpy as np

from src.shared.database_wrapper import DatabaseWrapper
from src.datasets.who_is_who import WhoIsWhoDataset
from src.model.GAT.gat_encoder import GATv2Encoder
from src.model.GAT.gat_decoder import GATv2Decoder
from src.shared.graph_schema import NodeType, EdgeType, node_one_hot, edge_one_hot, edge_pyg_key_vals
from src.model.loss.triplet_loss import TripletLoss
from src.shared import config

import networkx as nx
import plotly.graph_objects as go

random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

In [56]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv, HeteroConv
from torch_geometric.nn import Linear


class TestGATv2Encoder(torch.nn.Module):
    def __init__(
            self,
            in_channels,
            hidden_channels,
            out_channels,
            edge_feature_dim,
            edge_types,
            node_types,
            heads=5,
            concat=True,
            negative_slope=0.2,
            dropout=0.0,
            add_self_loops=True
    ):
        super(TestGATv2Encoder, self).__init__()

        self.conv_1 = HeteroConv({
            edge_type: GATv2Conv(
                in_channels=in_channels,
                out_channels=hidden_channels,
                heads=heads,
                concat=concat,
                negative_slope=negative_slope,
                dropout=dropout,
                add_self_loops=add_self_loops,
                edge_dim=edge_feature_dim
            )
            for edge_type in edge_types
        }, aggr='mean')

        self.conv_2 = HeteroConv({
            edge_type: GATv2Conv(
                in_channels=heads * hidden_channels if concat else hidden_channels,
                out_channels=hidden_channels,
                heads=heads,
                concat=concat,
                negative_slope=negative_slope,
                dropout=dropout,
                add_self_loops=add_self_loops,
                edge_dim=edge_feature_dim
            )
            for edge_type in edge_types
        }, aggr='mean')

        self.lin_out = torch.nn.ModuleDict()
        for node_type in node_types:
            self.lin_out[node_type] = torch.nn.Sequential(
                Linear(heads * hidden_channels, hidden_channels),
                torch.nn.Dropout(dropout),
                Linear(hidden_channels, out_channels)
            )

    def forward(self, x_dict, edge_index_dict, edge_feature_dict):
        """
        :param x_dict: dict of torch.Tensor
            Node feature vectors for each node type.
        :param edge_index_dict: dict of torch.Tensor
            Edge indices for each edge type.
        :param edge_feature_dict: dict of torch.Tensor
            Edge attribute vectors for each edge type.
        """
        
        x_dict = self.conv_1(x_dict, edge_index_dict, edge_feature_dict)
        for node_type in x_dict.keys():
            x_dict[node_type] = F.dropout(F.relu(x_dict[node_type]), p=0.5, training=self.training)

        x_dict = self.conv_2(x_dict, edge_index_dict, edge_feature_dict)
        for node_type in x_dict.keys():
            x_dict[node_type] = F.dropout(F.relu(x_dict[node_type]), p=0.5, training=self.training)

        out_dict = {}
        for node_type in x_dict.keys():
            out_dict[node_type] = self.lin_out[node_type](x_dict[node_type])

        return out_dict

In [57]:
node_feature_dim = 32
edge_feature_dim = EdgeType.SIM_TITLE.one_hot().shape[0]
gat_embedding_dim = 32

included_nodes = [
    NodeType.PUBLICATION, 
    NodeType.VENUE, 
    NodeType.ORGANIZATION,
    NodeType.AUTHOR,
    NodeType.CO_AUTHOR
]
included_edges = [
    EdgeType.PUB_VENUE,
    EdgeType.VENUE_PUB,
    EdgeType.PUB_ORG,
    EdgeType.ORG_PUB, 
    EdgeType.PUB_AUTHOR,
    EdgeType.AUTHOR_PUB,
    EdgeType.AUTHOR_ORG,
    EdgeType.ORG_AUTHOR,
    EdgeType.PUB_ORG,
    EdgeType.ORG_PUB,
]

device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'cpu'
)
print(device)

encoder = TestGATv2Encoder(
    in_channels=node_feature_dim,
    hidden_channels=32,
    out_channels=gat_embedding_dim,
    edge_feature_dim=edge_feature_dim,
    edge_types=[edge_pyg_key_vals[edge_type] for edge_type in included_edges],
    node_types=[node_type.value for node_type in included_nodes],
    add_self_loops=False
)
encoder.to(device)

cuda


TestGATv2Encoder(
  (conv_1): HeteroConv(num_relations=8)
  (conv_2): HeteroConv(num_relations=8)
  (lin_out): ModuleDict(
    (Publication): Sequential(
      (0): Linear(160, 32, bias=True)
      (1): Dropout(p=0.0, inplace=False)
      (2): Linear(32, 32, bias=True)
    )
    (Venue): Sequential(
      (0): Linear(160, 32, bias=True)
      (1): Dropout(p=0.0, inplace=False)
      (2): Linear(32, 32, bias=True)
    )
    (Organization): Sequential(
      (0): Linear(160, 32, bias=True)
      (1): Dropout(p=0.0, inplace=False)
      (2): Linear(32, 32, bias=True)
    )
    (Author): Sequential(
      (0): Linear(160, 32, bias=True)
      (1): Dropout(p=0.0, inplace=False)
      (2): Linear(32, 32, bias=True)
    )
    (CoAuthor): Sequential(
      (0): Linear(160, 32, bias=True)
      (1): Dropout(p=0.0, inplace=False)
      (2): Linear(32, 32, bias=True)
    )
  )
)

In [58]:
class TripletDataset:
    def __init__(self, dataset_path):
        self.dataset_path = dataset_path
        self.batch_files = os.listdir(dataset_path)
        
    def iter_triplets(self):
        for batch_file in self.batch_files:
            file_path = os.path.join(self.dataset_path, batch_file)
            batch = torch.load(file_path)
            for triplet in batch:
                yield triplet
                
    def __len__(self, batch_size):
        return len(self.batch_files) * batch_size
        

In [59]:
from torch.nn.modules.loss import TripletMarginLoss

def train_gat(encoder, ds: TripletDataset, epochs=1000, lr=0.01):
    # Define the optimizer for the gat model
    optimizer = optim.SGD(list(encoder.parameters()), lr=lr)
    
    # Loss function
    criterion = TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
    
    # Training loop
    for epoch in range(epochs):
        encoder.train()
        
        total_loss = 0
        
        for triplet in ds.iter_triplets():
            anchor, pos, neg = triplet['anchor'], triplet['pos'], triplet['neg']
            
            anchor_data = anchor['data']
            pos_data = pos['data']
            neg_data = neg['data']
            
            anchor_id_map = anchor['node_id_map']
            pos_id_map = pos['node_id_map']
            neg_id_map = neg['node_id_map']
            
            anchor_data.to(device)
            pos_data.to(device)
            neg_data.to(device)
            
            # Zero gradients
            optimizer.zero_grad()

            # Forward pass through the encoder
            #print(f"Anchor node: {anchor['pub_node_id']}")
            #print(f"Node features: {anchor_data.x_dict}")
            #print(f"Edge index: {anchor_data.edge_index_dict}")
            
            anchor_emb = encoder.forward(anchor_data.x_dict, anchor_data.edge_index_dict, anchor_data.edge_attr_dict)
            pos_emb = encoder.forward(pos_data.x_dict, pos_data.edge_index_dict, pos_data.edge_attr_dict)
            neg_emb = encoder.forward(neg_data.x_dict, neg_data.edge_index_dict, neg_data.edge_attr_dict)
            
            # Retrieve embedding of respective start nodes
            anchor_emb = anchor_emb["Publication"][anchor_id_map[anchor["pub_node_id"]]]
            pos_emb = pos_emb["Publication"][pos_id_map[pos["pub_node_id"]]]
            neg_emb = neg_emb["Publication"][neg_id_map[neg["pub_node_id"]]]
            #print(anchor_emb.shape, pos_emb.shape, neg_emb.shape)
            
            # loss = triplet_loss.forward(anchor_emb, pos_emb, neg_emb)
            loss = criterion(anchor_emb, pos_emb, neg_emb)
            
            # Backward pass
            loss.backward()

            # Optimize the parameters
            optimizer.step()

            total_loss += loss.item()

        if epoch % 10 == 0:
            print(f'Epoch {epoch}, Loss: {total_loss / ds.__len__(10)}')

In [None]:
ds = TripletDataset("./data/triplet_dataset")
train_gat(encoder, ds, epochs=1000, lr=0.01)

Epoch 0, Loss: 1.1363968921013368
Epoch 10, Loss: 0.8635011682143579
Epoch 20, Loss: 0.7420051929278252
