In [73]:
import torch
import pandas as pd
from graphdatascience import GraphDataScience
from torch_geometric.data import Data
import torch.nn.functional as F
import torch.optim as optim

In [74]:
from src.shared.database_wrapper import DatabaseWrapper
from src.model.GAT.gat_encoder import GATv2Encoder
from src.model.GAT.gat_decoder import GATv2Decoder
from src.shared.graph_schema import NodeType, EdgeType, AuthorEdge, PublicationEdge, SimilarityEdge

In [75]:
# Shared database handle; the loop at the end of the notebook reads nodes and
# neighborhoods through it. Constructing it logs "Connecting to the database ..."
# followed by "Database ready." (see the cell output).
db = DatabaseWrapper()

2024-08-02 10:01:04,388 - DatabaseWrapper - INFO - Connecting to the database ...
2024-08-02 10:01:04,409 - DatabaseWrapper - INFO - Database ready.


In [76]:
# Width of the node-type one-hot code; three of the four slots are assigned below.
node_feature_dim = 4
# Node label -> float32 one-hot row. The slot index is the label's position in the tuple.
node_to_one_hot = {
    label: F.one_hot(torch.tensor(slot), node_feature_dim).type(torch.float32)
    for slot, label in enumerate((
        NodeType.PUBLICATION.value,
        NodeType.ORGANIZATION.value,
        NodeType.VENUE.value,
    ))
}

# Width of the edge-type one-hot code; seven of the eight slots are assigned below.
edge_feature_dim = 8
# Relationship type -> float32 one-hot row. The slot index is the type's position in the tuple.
edge_to_one_hot = {
    rel_type: F.one_hot(torch.tensor(slot), edge_feature_dim).type(torch.float32)
    for slot, rel_type in enumerate((
        PublicationEdge.AUTHOR.value,
        PublicationEdge.VENUE.value,
        AuthorEdge.ORGANIZATION.value,
        AuthorEdge.PUBLICATION.value,
        SimilarityEdge.SIM_ORG.value,
        SimilarityEdge.SIM_VENUE.value,
        SimilarityEdge.SIM_TITLE.value,
    ))
}

def project_single(n):
    """Build the feature row for one node.

    The row is the node-type one-hot code (looked up in `node_to_one_hot` under
    the first label that iterating `n.labels` yields) followed by the node's
    'vec' property converted to a tensor.

    Raises whatever the lookups raise (e.g. KeyError when the label has no
    one-hot code).
    """
    first_label = list(n.labels)[0]
    type_code = node_to_one_hot[first_label]
    embedding = torch.tensor(n['vec'])
    return torch.hstack((type_code, embedding))

def project_pub_title_and_abstract(nodes):
    """Placeholder with no implementation yet: ignores `nodes` and returns None."""
    return None

def _no_projection(x):
    """Projector for node types that contribute no feature row; always returns None."""
    return None


# Node label -> callable turning a node into its feature tensor, or into None
# for the labels that convert_to_pyg leaves out of the graph.
projection_map = dict([
    (NodeType.PUBLICATION.value, project_single),
    (NodeType.AUTHOR.value, _no_projection),
    (NodeType.CO_AUTHOR.value, _no_projection),
    (NodeType.ORGANIZATION.value, project_single),
    (NodeType.VENUE.value, project_single),
    (NodeType.TRUE_AUTHOR.value, _no_projection),
])

def convert_to_pyg(nodes, relationships):
    """Convert graph nodes and relationships into a PyTorch Geometric `Data` object.

    Args:
        nodes: iterable of node objects exposing `.labels` and item access
            (`n['id']`); presumably neo4j driver nodes — confirm against
            DatabaseWrapper.fetch_neighborhood.
        relationships: iterable of relationship objects exposing `.type`,
            `.start_node` and `.end_node`.

    Returns:
        A `Data` with `x` (one row per kept node), `edge_index` (2 x num_edges,
        long) and `edge_attr` (one one-hot row per edge), or None when no node
        yields a feature vector or no relationship survives filtering.

    Nodes are kept only if `projection_map` has a projector for their first
    label and that projector returns a tensor. Relationships are kept only if
    their type is in `edge_to_one_hot` and both endpoints were kept.
    """
    # Map node ids to consecutive row indices of `x`.
    node_id_mapping = {}
    node_features = []
    for n in nodes:
        # Labels without a projector (or nodes without labels) are skipped,
        # the same way unmapped relationship types are skipped below, instead
        # of aborting the whole conversion with a KeyError/IndexError.
        project = projection_map.get(next(iter(n.labels), None))
        if project is None:
            continue
        feature_vec = project(n)
        if feature_vec is None:
            continue
        node_id = n['id']
        if node_id in node_id_mapping:
            # The same node was delivered twice. The mapping holds one index
            # per id, so a second row would be an orphan no edge can reach.
            continue
        # The next free row index is simply the number of rows collected so far.
        node_id_mapping[node_id] = len(node_features)
        node_features.append(feature_vec)

    if not node_features:
        return None

    x = torch.stack(node_features)

    # Collect (source, target) index pairs and their one-hot edge attributes.
    edge_index = []
    edge_features = []
    for rel in relationships:
        if rel.type not in edge_to_one_hot:
            continue
        start_id = rel.start_node['id']
        if start_id not in node_id_mapping:
            continue
        end_id = rel.end_node['id']
        if end_id not in node_id_mapping:
            continue

        edge_index.append([node_id_mapping[start_id], node_id_mapping[end_id]])
        edge_features.append(edge_to_one_hot[rel.type])

    if not edge_index:
        return None

    # Transpose the list of pairs into the 2 x num_edges layout.
    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()

    return Data(x=x, edge_index=edge_index, edge_attr=torch.stack(edge_features))


In [77]:
# Passed as the encoder's out_channels and the decoder's in_channels.
gat_embedding_dim = 32
# Number of extra feature columns appended to the node-type one-hot code.
# project_single appends the node's 'vec' property, so this is presumably the
# length of that vector — confirm against the stored embeddings.
node_vec_dim = 32
# Width of one row of Data.x: node-type one-hot followed by 'vec'.
node_input_dim = node_feature_dim + node_vec_dim

encoder = GATv2Encoder(
    in_channels=node_input_dim,
    out_channels=gat_embedding_dim,
    edge_dim=edge_feature_dim,
    add_self_loops=False
)

# train_gat compares the decoder output against data.x with an MSE loss, so the
# decoder's out_channels uses the same width as the encoder input.
decoder = GATv2Decoder(
    in_channels=gat_embedding_dim,
    out_channels=node_input_dim
)
    

In [78]:

def train_gat(encoder, decoder, data, epochs=100, lr=0.01):
    """Jointly fit `encoder` and `decoder` to reconstruct `data.x`.

    Runs `epochs` optimisation steps of `optim.SGD` (learning rate `lr`) over
    the parameters of both modules. Each step encodes `data.x`, decodes the
    result (both calls also receive `data.edge_index` and `data.edge_attr`),
    and minimises the mean squared error between the decoder output and
    `data.x`. The loss is printed on every 10th epoch, starting with epoch 0.

    The modules are updated in place; nothing is returned.
    """
    trainable = [*encoder.parameters(), *decoder.parameters()]
    sgd = optim.SGD(trainable, lr=lr)
    mse = torch.nn.MSELoss()

    for epoch in range(epochs):
        # Both modules are switched to training mode at the start of every step.
        for module in (encoder, decoder):
            module.train()

        sgd.zero_grad()

        # Encode, then decode over the same edges and edge attributes.
        latent = encoder(data.x, data.edge_index, data.edge_attr)
        reconstruction = decoder(latent, data.edge_index, data.edge_attr)

        # Reconstruction error against the original node features.
        loss = mse(reconstruction, data.x)
        loss.backward()
        sgd.step()

        if epoch % 10 != 0:
            continue
        print(f'Epoch {epoch}, Loss: {loss.item()}')
    

In [79]:
# Train the shared encoder/decoder on the neighborhood of every publication.
# The outer loop variable has its own name so that the `nodes` returned by
# fetch_neighborhood no longer rebinds the collection being iterated.
for publication_batch in db.iter_nodes(NodeType.PUBLICATION, ["id"]):
    for node in publication_batch:
        # NOTE(review): the trailing 5 is presumably a hop or size limit for the
        # neighborhood query — confirm against DatabaseWrapper.fetch_neighborhood.
        nodes, rels = db.fetch_neighborhood(NodeType.PUBLICATION, node["id"], 5)
        data = convert_to_pyg(nodes, rels)
        # convert_to_pyg returns None when the neighborhood has no usable nodes
        # or edges. Test for that explicitly rather than relying on the
        # truthiness of a Data object.
        if data is not None:
            train_gat(encoder, decoder, data)

Epoch 0, Loss: 3.529656410217285
Epoch 10, Loss: 2.7048017978668213
Epoch 20, Loss: 1.817945122718811
Epoch 30, Loss: 1.1770896911621094
Epoch 40, Loss: 0.8790652751922607
Epoch 50, Loss: 0.7660320401191711
Epoch 60, Loss: 0.7253664135932922
Epoch 70, Loss: 0.7102127075195312
Epoch 80, Loss: 0.7038132548332214
Epoch 90, Loss: 0.700435221195221
Epoch 0, Loss: 1.5151829719543457
Epoch 10, Loss: 0.8650241494178772
Epoch 20, Loss: 0.5606831312179565
Epoch 30, Loss: 0.43871623277664185
Epoch 40, Loss: 0.39382728934288025
Epoch 50, Loss: 0.37760454416275024
Epoch 60, Loss: 0.3715980648994446
Epoch 70, Loss: 0.36917877197265625
Epoch 80, Loss: 0.3680095374584198
Epoch 90, Loss: 0.36726468801498413
Epoch 0, Loss: 1.8850826025009155
Epoch 10, Loss: 1.0667037963867188
Epoch 20, Loss: 0.9207552075386047
Epoch 30, Loss: 0.8582070469856262
Epoch 40, Loss: 0.8296349048614502
Epoch 50, Loss: 0.8154789209365845
Epoch 60, Loss: 0.8075219988822937
Epoch 70, Loss: 0.8023483753204346
Epoch 80, Loss: 0.798

RuntimeError: stack expects a non-empty TensorList