In [36]:
import os
import json
from time import sleep

import pandas as pd
from graphdatascience import GraphDataScience
from neo4j import GraphDatabase
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv
from torch_geometric.transforms import RandomNodeSplit
import random
import numpy as np

from src.shared.database_wrapper import DatabaseWrapper
from src.datasets.who_is_who import WhoIsWhoDataset
from src.model.GAT.gat_encoder import GATv2Encoder
from src.model.GAT.gat_decoder import GATv2Decoder
from src.shared.graph_schema import NodeType, EdgeType, node_one_hot, edge_one_hot, edge_pyg_key_vals
from src.model.loss.triplet_loss import TripletLoss
from src.shared import config

import networkx as nx
import plotly.graph_objects as go

random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

In [37]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv, HeteroConv
from torch_geometric.nn import Linear


class TestGATv2Encoder(torch.nn.Module):
    def __init__(
            self,
            in_channels,
            hidden_channels,
            out_channels,
            edge_feature_dim,
            edge_types,
            node_types,
            heads=5,
            concat=True,
            negative_slope=0.2,
            dropout=0.0,
            add_self_loops=True
    ):
        super(TestGATv2Encoder, self).__init__()

        self.conv_1 = HeteroConv({
            edge_type: GATv2Conv(
                in_channels=in_channels,
                out_channels=hidden_channels,
                heads=heads,
                concat=concat,
                negative_slope=negative_slope,
                dropout=dropout,
                add_self_loops=add_self_loops,
                edge_dim=edge_feature_dim
            )
            for edge_type in edge_types
        }, aggr='mean')

        self.conv_2 = HeteroConv({
            edge_type: GATv2Conv(
                in_channels=heads * hidden_channels if concat else hidden_channels,
                out_channels=hidden_channels,
                heads=heads,
                concat=concat,
                negative_slope=negative_slope,
                dropout=dropout,
                add_self_loops=add_self_loops,
                edge_dim=edge_feature_dim
            )
            for edge_type in edge_types
        }, aggr='mean')

        self.lin_out = torch.nn.ModuleDict()
        for node_type in node_types:
            self.lin_out[node_type] = torch.nn.Sequential(
                Linear(heads * hidden_channels, hidden_channels),
                torch.nn.Dropout(dropout),
                Linear(hidden_channels, out_channels)
            )

    def forward(self, x_dict, edge_index_dict, edge_feature_dict):
        """
        :param x_dict: dict of torch.Tensor
            Node feature vectors for each node type.
        :param edge_index_dict: dict of torch.Tensor
            Edge indices for each edge type.
        :param edge_feature_dict: dict of torch.Tensor
            Edge attribute vectors for each edge type.
        """
        print("Forward")
        print(x_dict)
        x_dict = self.conv_1(x_dict, edge_index_dict, edge_feature_dict)
        print(x_dict)
        for node_type in x_dict.keys():
            x_dict[node_type] = F.dropout(F.relu(x_dict[node_type]), p=0.5, training=self.training)

        x_dict = self.conv_2(x_dict, edge_index_dict, edge_feature_dict)
        for node_type in x_dict.keys():
            x_dict[node_type] = F.dropout(F.relu(x_dict[node_type]), p=0.5, training=self.training)

        out_dict = {}
        for node_type in x_dict.keys():
            out_dict[node_type] = self.lin_out[node_type](x_dict[node_type])

        return out_dict

In [38]:
node_feature_dim = 32
edge_feature_dim = EdgeType.SIM_TITLE.one_hot().shape[0]
gat_embedding_dim = 32

included_nodes = [
    NodeType.PUBLICATION, 
    NodeType.VENUE, 
    NodeType.ORGANIZATION,
    NodeType.AUTHOR,
    NodeType.CO_AUTHOR
]
included_edges = [
    EdgeType.PUB_VENUE,
    EdgeType.VENUE_PUB,
    EdgeType.PUB_ORG,
    EdgeType.ORG_PUB, 
    EdgeType.PUB_AUTHOR,
    EdgeType.AUTHOR_PUB,
    EdgeType.AUTHOR_ORG,
    EdgeType.ORG_AUTHOR,
    EdgeType.PUB_ORG,
    EdgeType.ORG_PUB,
]

device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'cpu'
)
print(device)

encoder = TestGATv2Encoder(
    in_channels=node_feature_dim,
    hidden_channels=32,
    out_channels=gat_embedding_dim,
    edge_feature_dim=edge_feature_dim,
    edge_types=[edge_pyg_key_vals[edge_type] for edge_type in included_edges],
    node_types=[node_type.value for node_type in included_nodes],
    add_self_loops=False
)
encoder.to(device)

cpu


TestGATv2Encoder(
  (conv_1): HeteroConv(num_relations=8)
  (conv_2): HeteroConv(num_relations=8)
  (lin_out): ModuleDict(
    (Publication): Sequential(
      (0): Linear(160, 32, bias=True)
      (1): Dropout(p=0.0, inplace=False)
      (2): Linear(32, 32, bias=True)
    )
    (Venue): Sequential(
      (0): Linear(160, 32, bias=True)
      (1): Dropout(p=0.0, inplace=False)
      (2): Linear(32, 32, bias=True)
    )
    (Organization): Sequential(
      (0): Linear(160, 32, bias=True)
      (1): Dropout(p=0.0, inplace=False)
      (2): Linear(32, 32, bias=True)
    )
    (Author): Sequential(
      (0): Linear(160, 32, bias=True)
      (1): Dropout(p=0.0, inplace=False)
      (2): Linear(32, 32, bias=True)
    )
    (CoAuthor): Sequential(
      (0): Linear(160, 32, bias=True)
      (1): Dropout(p=0.0, inplace=False)
      (2): Linear(32, 32, bias=True)
    )
  )
)

In [39]:
class TripletDataset:
    def __init__(self, dataset_path):
        self.dataset_path = dataset_path
        self.batch_files = os.listdir(dataset_path)
        
    def iter_triplets(self):
        for batch_file in self.batch_files:
            file_path = os.path.join(self.dataset_path, batch_file)
            batch = torch.load(file_path)
            for triplet in batch:
                yield triplet
                
    def __len__(self, batch_size):
        return len(self.batch_files) * batch_size
        

In [40]:
def train_gat(encoder, ds: TripletDataset, epochs=1000, lr=0.01):
    # Define the optimizer for the gat model
    optimizer = optim.SGD(list(encoder.parameters()), lr=lr)
    
    # Define a loss function
    criterion = torch.nn.MSELoss()
    triplet_loss = TripletLoss()
    
    # Training loop
    for epoch in range(epochs):
        encoder.train()
        
        total_loss = 0
        
        for triplet in ds.iter_triplets():
            anchor, pos, neg = triplet['anchor'], triplet['pos'], triplet['neg']
            
            # Zero gradients
            optimizer.zero_grad()

            # Forward pass through the encoder
            print("Edge index: ", anchor["edge_index_dict"])
            anchor_emb = encoder.forward(anchor["x_dict"], anchor["edge_index_dict"], anchor["edge_feature_dict"])
            pos_emb = encoder.forward(pos["x_dict"], pos["edge_index_dict"], pos["edge_feature_dict"])
            neg_emb = encoder.forward(neg["x_dict"], neg["edge_index_dict"], neg["edge_feature_dict"])
            #print(anchor_emb.shape, pos_emb.shape, neg_emb.shape)
            
            # Compute loss 
            #loss = criterion()
            anchor_pub_node_id = anchor["x_pub_id"]
            pos_pub_node_id = pos["x_pub_id"]
            neg_pub_node_id = neg["x_pub_id"]
            print(anchor_pub_node_id)
            print(anchor_emb)
            
            loss = triplet_loss.forward(anchor_emb, pos_emb, neg_emb)

            # Backward pass
            loss.backward()

            # Optimize the parameters
            optimizer.step()

            total_loss += loss.item()

        # Print loss every 10 epochs
        if epoch % 10 == 0:
            print(f'Epoch {epoch}, Loss: {total_loss / len(ds.__len__(10))}')

In [41]:
ds = TripletDataset("./data/triplet_dataset")
train_gat(encoder, ds, epochs=1000, lr=0.01)

Edge index:  {('Publication', 'PubOrg', 'Organization'): tensor([[140, 136, 138, 149, 137, 124, 139, 147],
        [103, 103, 103, 103, 103, 103, 103, 103]]), ('Publication', 'PubVenue', 'Venue'): tensor([[140,  57,  52,  58,  51,  53,  54],
        [ 56,  56,  56,  56,  56,  56,  56]]), ('Publication', 'PubAuthor', 'Author'): tensor([[140, 140, 140, 136, 136, 136, 136, 138, 138, 149, 149, 149, 149, 149,
         137, 137, 137, 124, 124, 124, 139, 139, 139, 147, 147, 147],
        [151, 146, 132, 127, 128, 155, 122, 153, 125, 135, 141, 144, 142, 143,
         133, 123, 130, 152, 126, 134, 150, 145, 131, 148, 129, 154]]), ('Venue', 'VenuePub', 'Publication'): tensor([[ 56,  56,  56,  56,  56,  56,  56],
        [ 57, 140,  52,  58,  51,  54,  53]]), ('Organization', 'OrgAuthor', 'Author'): tensor([[103, 103, 103, 103, 103, 103, 103, 103, 103, 103, 103, 103, 103, 103,
         103, 103, 103, 103, 103, 103, 103, 103, 103, 103, 103, 103],
        [142, 151, 146, 132, 125, 153, 141, 128, 15

IndexError: Found indices in 'edge_index' that are larger than 13 (got 140). Please ensure that all indices in 'edge_index' point to valid indices in the interval [0, 14) in your node feature matrix and try again.