In [48]:
import os
import json
from time import sleep

import pandas as pd
from graphdatascience import GraphDataScience
from neo4j import GraphDatabase
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv
from torch_geometric.transforms import RandomNodeSplit
import random
import numpy as np

from src.shared.database_wrapper import DatabaseWrapper
from src.datasets.who_is_who import WhoIsWhoDataset
from src.model.GAT.gat_encoder import GATv2Encoder
from src.model.GAT.gat_decoder import GATv2Decoder
from src.shared.graph_schema import NodeType, EdgeType, node_one_hot, edge_one_hot
from src.model.loss.triplet_loss import TripletLoss
from src.shared import config

import networkx as nx
import plotly.graph_objects as go

random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

In [49]:
node_feature_dim = NodeType.PUBLICATION.one_hot().shape[0] + 32
edge_feature_dim = EdgeType.SIM_TITLE.one_hot().shape[0]
gat_embedding_dim = 32

device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'cpu'
)
print(device)

encoder = GATv2Encoder(
    in_channels=node_feature_dim,
    hidden_channels=32,
    out_channels=gat_embedding_dim,
    edge_dim=edge_feature_dim,
    add_self_loops=False
)
encoder.to(device)

cpu


GATv2Encoder(
  (conv1): GATv2Conv(37, 32, heads=5)
  (conv2): GATv2Conv(160, 32, heads=5)
  (linear_output): Sequential(
    (0): Linear(in_features=160, out_features=32, bias=True)
    (1): Dropout(p=0.0, inplace=False)
    (2): Linear(in_features=32, out_features=32, bias=True)
  )
)

In [50]:
class TripletDataset:
    def __init__(self, dataset_path):
        self.dataset_path = dataset_path
        self.batch_files = os.listdir(dataset_path)
        
    def iter_triplets(self):
        for batch_file in self.batch_files:
            file_path = os.path.join(self.dataset_path, batch_file)
            batch = torch.load(file_path)
            for triplet in batch:
                yield triplet
                
    def __len__(self, batch_size):
        return len(self.batch_files) * batch_size
        

In [51]:
def train_gat(encoder, ds: TripletDataset, epochs=1000, lr=0.01):
    # Define the optimizer for the gat model
    optimizer = optim.SGD(list(encoder.parameters()), lr=lr)
    
    # Define a loss function
    criterion = torch.nn.MSELoss()
    triplet_loss = TripletLoss()
    
    # Training loop
    for epoch in range(epochs):
        encoder.train()
        
        total_loss = 0
        
        for triplet in ds.iter_triplets():
            anchor, pos, neg = triplet['anchor'], triplet['pos'], triplet['neg']
            
            anchor = Data(
                x=torch.tensor(anchor['x'], dtype=torch.float32), 
                edge_index=torch.tensor(anchor['edge_index'], dtype=torch.long),
                edge_attr=torch.tensor(anchor['edge_attr'], dtype=torch.float32)
            )
            pos = Data(
                x=torch.tensor(pos['x'], dtype=torch.float32),
                edge_index=torch.tensor(pos['edge_index'], dtype=torch.long),
                edge_attr=torch.tensor(pos['edge_attr'], dtype=torch.float32)
            )
            neg = Data(
                x=torch.tensor(neg['x'], dtype=torch.float32),
                edge_index=torch.tensor(neg['edge_index'], dtype=torch.long),
                edge_attr=torch.tensor(neg['edge_attr'], dtype=torch.float32)
            )
            
            anchor.to(device)
            pos.to(device)
            neg.to(device)
            
            # Zero gradients
            optimizer.zero_grad()

            # Forward pass through the encoder
            anchor_emb = encoder.forward(anchor.x, anchor.edge_index, anchor.edge_attr)
            pos_emb = encoder.forward(pos.x, pos.edge_index, pos.edge_attr)
            neg_emb = encoder.forward(neg.x, neg.edge_index, neg.edge_attr)
            print(anchor_emb.shape, pos_emb.shape, neg_emb.shape)
            
            # Compute loss 
            #loss = criterion()
            loss = triplet_loss.forward(anchor_emb, pos_emb, neg_emb)

            # Backward pass
            loss.backward()

            # Optimize the parameters
            optimizer.step()

            total_loss += loss.item()

        # Print loss every 10 epochs
        if epoch % 10 == 0:
            print(f'Epoch {epoch}, Loss: {total_loss / len(ds.__len__(10))}')

In [52]:
ds = TripletDataset("./data/triplet_dataset")
train_gat(encoder, ds, epochs=1000, lr=0.01)

torch.Size([42, 32]) torch.Size([72, 32]) torch.Size([6, 32])


RuntimeError: The size of tensor a (42) must match the size of tensor b (72) at non-singleton dimension 0