In [8]:
import os
import json
from time import sleep

import pandas as pd
from graphdatascience import GraphDataScience
from neo4j import GraphDatabase
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv
from torch_geometric.transforms import RandomNodeSplit
import random
import numpy as np

from src.shared.database_wrapper import DatabaseWrapper
from src.datasets.who_is_who import WhoIsWhoDataset
from src.model.GAT.gat_encoder import GATv2Encoder
from src.model.GAT.gat_decoder import GATv2Decoder
from src.shared.graph_schema import NodeType, EdgeType, node_one_hot, edge_one_hot
from src.model.loss.triplet_loss import TripletLoss
from src.shared import config

import networkx as nx
import plotly.graph_objects as go

random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

In [9]:
node_feature_dim = NodeType.PUBLICATION.one_hot().shape[0] + 32
edge_feature_dim = EdgeType.SIM_TITLE.one_hot().shape[0]
gat_embedding_dim = 32

device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'cpu'
)
print(device)

encoder = GATv2Encoder(
    in_channels=node_feature_dim,
    hidden_channels=32,
    out_channels=gat_embedding_dim,
    edge_dim=edge_feature_dim,
    add_self_loops=False
)
encoder.to(device)

cpu


GATv2Encoder(
  (conv1): GATv2Conv(37, 32, heads=5)
  (conv2): GATv2Conv(160, 32, heads=5)
  (linear_output): Sequential(
    (0): Linear(in_features=160, out_features=32, bias=True)
    (1): Dropout(p=0.0, inplace=False)
    (2): Linear(in_features=32, out_features=32, bias=True)
  )
)

In [10]:
class TripletDataset:
    def __init__(self, dataset_path):
        self.dataset_path = dataset_path
        self.batch_files = os.listdir(dataset_path)
        
    def iter_triplets(self):
        for batch_file in self.batch_files:
            with open(os.path.join(self.dataset_path, batch_file), 'r') as f:
                batch = json.load(f)
                for triplet in batch:
                    yield triplet
        

In [11]:
def train_gat(encoder, ds: TripletDataset, epochs=1000, lr=0.01):
    # Define the optimizer for the gat model
    optimizer = optim.SGD(list(encoder.parameters()), lr=lr)
    
    # Define a loss function
    criterion = torch.nn.MSELoss()
    triplet_loss = TripletLoss()
    
    # Training loop
    for epoch in range(epochs):
        encoder.train()
        
        total_loss = 0
        
        for triplet in ds.iter_triplets():
            anchor, pos, neg = triplet['anchor'], triplet['pos'], triplet['neg']
            print(anchor)
            print(pos)
            print(neg)
            print(anchor.get('x'))
            print(anchor.get('edge_index'))
            print(anchor.get('edge_attr'))
            
            anchor.to(device)
            pos.to(device)
            neg.to(device)
            
            # Zero gradients
            optimizer.zero_grad()

            # Forward pass through the encoder
            anchor_emb = encoder(anchor.x, anchor.edge_index, anchor.edge_attr)
            pos_emb = encoder(pos.x, pos.edge_index, pos.edge_attr)
            neg_emb = encoder(neg.x, neg.edge_index, neg.edge_attr)
            
            # Compute loss 
            #loss = criterion()
            loss = triplet_loss.forward(anchor_emb, pos_emb, neg_emb)

            # Backward pass
            loss.backward()

            # Optimize the parameters
            optimizer.step()

            total_loss += loss.item()

        # Print loss every 10 epochs
        if epoch % 10 == 0:
            print(f'Epoch {epoch}, Loss: {total_loss / len(dataloader)}')

In [12]:
ds = TripletDataset("./data/triplet_dataset")
train_gat(encoder, ds, epochs=1000, lr=0.01)

{'x': [[1.0, 0.0, 0.0, 0.0, 0.0, -15.819159507751465, -4.691801071166992, -0.5335561633110046, -2.8367950916290283, -0.8650607466697693, -0.21454881131649017, 4.108437538146973, -0.40054208040237427, -1.0187499523162842, 1.0962814092636108, 4.073404788970947, -1.3514854907989502, -0.0996280387043953, 2.801846742630005, 1.2038699388504028, -0.07721380889415741, 1.0996593236923218, 0.4039984941482544, -0.4622865319252014, -0.3228813409805298, -3.8246653079986572, 1.0145349502563477, 0.29665258526802063, 0.6286441087722778, 1.0685454607009888, 0.060539521276950836, -0.8760803937911987, 0.8104182481765747, -1.7079638242721558, 0.52536940574646, -0.32546406984329224, 0.661298930644989], [0.0, 0.0, 0.0, 0.0, 1.0, -1.0593839883804321, -5.586912631988525, 1.1636171340942383, -3.7086212635040283, -1.1894723176956177, 0.4946366846561432, 0.831084668636322, -0.801983118057251, 0.44755661487579346, 1.4017194509506226, 2.0668883323669434, -0.6301614046096802, 0.1719898134469986, 1.7921817302703857,

AttributeError: 'dict' object has no attribute 'to'