In [86]:
import torch
import pandas as pd
from graphdatascience import GraphDataScience
from torch_geometric.data import Data
import torch.nn.functional as F

In [87]:
from src.shared.database_wrapper import DatabaseWrapper
from src.model.GAT.gat_encoder import GATEncoder
from src.model.GAT.gat_decoder import GATDecoder
from src.shared.graph_schema import NodeType, EdgeType, AuthorEdge, PublicationEdge

In [88]:
db = DatabaseWrapper()

2024-08-01 14:08:12,977 - DatabaseWrapper - INFO - Connecting to the database ...
2024-08-01 14:08:12,982 - DatabaseWrapper - INFO - Database ready.


In [89]:
node_feature_dim = 8
node_to_one_hot = {
    NodeType.PUBLICATION.value: F.one_hot(torch.tensor(0), node_feature_dim),
    NodeType.AUTHOR.value: F.one_hot(torch.tensor(1), node_feature_dim),
    NodeType.CO_AUTHOR.value: F.one_hot(torch.tensor(2), node_feature_dim),
    NodeType.ORGANIZATION.value: F.one_hot(torch.tensor(3), node_feature_dim),
    NodeType.VENUE.value: F.one_hot(torch.tensor(4), node_feature_dim),
}
edge_feature_dim = 8
edge_to_one_hot = {
    PublicationEdge.AUTHOR: F.one_hot(torch.tensor(0), edge_feature_dim),
    PublicationEdge.VENUE: F.one_hot(torch.tensor(1), edge_feature_dim),
    AuthorEdge.ORGANIZATION: F.one_hot(torch.tensor(2), edge_feature_dim),
    AuthorEdge.PUBLICATION: F.one_hot(torch.tensor(3), edge_feature_dim),
}

def project_single(n):
    return torch.hstack((node_to_one_hot[list(n.labels)[0]], torch.tensor(n['vec'])))

def project_pub_title_and_abstract(nodes):
    pass

projection_map = {
    NodeType.PUBLICATION.value: project_single,
    NodeType.AUTHOR.value: lambda x: None,
    NodeType.CO_AUTHOR.value: lambda x: None,
    NodeType.ORGANIZATION.value: project_single,
    NodeType.VENUE.value: project_single
}

def convert_to_pyg(nodes, relationships):
    # Create a mapping from Neo4j node IDs to consecutive integers
    node_id_mapping = {}
    node_features = []
    for i, n in enumerate(nodes):
        print(n.labels)
        feature_vec = projection_map[list(n.labels)[0]](n)
        if feature_vec is not None:
            node_features.append(feature_vec)
            node_id_mapping[n['id']] = i
            
    x = torch.stack(node_features)

    # Create edge index
    edge_index = []
    edge_features = []
    for rel in relationships:
        if list(rel.labels)[0] not in edge_to_one_hot:
            continue
        if rel.start_node['id'] not in node_id_mapping or rel.end_node['id'] not in node_id_mapping:
            continue
        source = node_id_mapping[rel.start_node['id']]
        target = node_id_mapping[rel.end_node['id']]
        edge_index.append([source, target])
        edge_features.append(edge_to_one_hot[list(rel.labels)[0]])

    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()

    # Create PyTorch Geometric data object
    return Data(x=x, edge_index=edge_index, edge_attr=torch.stack(edge_features))


In [90]:
encoder = GATEncoder(
    node_feature_dim=node_feature_dim + 32, 
    edge_feature_dim=edge_feature_dim, 
    embedding_dim=32
)

decoder = GATDecoder(
    num_features=node_feature_dim + 32, 
    num_edge_features=edge_feature_dim, 
    embedding_dim=32, 
    num_node_types=len(node_to_one_hot.keys()), 
    num_edge_types=len(edge_to_one_hot.keys())
)

def train_gat(model, data):
    

IndentationError: expected an indented block (4001040235.py, line 3)

In [None]:
for nodes in db.iter_nodes(NodeType.PUBLICATION, ["id"]):
    for node in nodes:
        nodes, rels = db.fetch_neighborhood(NodeType.PUBLICATION, node["id"], 5)
        data = convert_to_pyg(nodes, rels)
        break