In [None]:
import sys, os

# Make modules at the repository root (two directories above the notebook's
# working directory) importable; only add it once so re-running is harmless.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir, os.pardir))
if project_root not in sys.path:
    sys.path.append(project_root)

from pykeen.pipeline import pipeline
from pykeen.triples import TriplesFactory
import csv
import pandas as pd
import numpy as np

In [None]:
# Connection settings come from the environment so credentials stay out of the notebook.
url = os.getenv('NEO4J_URI')
username = 'neo4j'
password = os.getenv('NEO4J_AUTH')
# NOTE(review): the official Neo4j Docker image expects NEO4J_AUTH as "neo4j/<password>";
# if this variable is shared with that container it must be split before use — confirm.

# Fail fast with a clear message instead of an opaque driver error on a None URI/password.
_missing_env = [
    name
    for name, value in (('NEO4J_URI', url), ('NEO4J_AUTH', password))
    if value is None
]
if _missing_env:
    raise RuntimeError(f"Missing required environment variable(s): {', '.join(_missing_env)}")

from neo4j import GraphDatabase, Result

driver = GraphDatabase.driver(url, auth=(username, password), keep_alive=True)

#### Export node relationships to CSV

In [None]:
def export_all(path='triple.csv'):
    """Export every relationship between two ``Test_embedding`` nodes as CSV triples.

    Each row is ``head, relation, tail``: the ``id`` property of the source node,
    the relationship type, and the ``id`` property of the target node. Uses the
    module-level ``driver``.

    Args:
        path: Output CSV file (overwritten). Defaults to ``'triple.csv'``, the
            file the training step reads.

    Returns:
        The number of triples written (excluding the header row).
    """
    # Skip endpoints without an `id`: csv.writer writes None as an empty string,
    # so all such nodes would otherwise collapse into a single "" entity.
    triples_query = """
    MATCH (h:Test_embedding)-[r]->(t:Test_embedding)
    WHERE h.id IS NOT NULL AND t.id IS NOT NULL
    RETURN h.id AS head, type(r) AS rel, t.id AS tail
    """

    n_written = 0
    with driver.session() as session:
        result = session.run(triples_query)

        with open(path, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(["head", "relation", "tail"])

            # Stream the records while the session is still open.
            for record in result:
                writer.writerow([record["head"], record["rel"], record["tail"]])
                n_written += 1

    return n_written


export_all()

#### Train ComplEx

In [None]:
# Training hyperparameters.
EMBEDDING_DIM = 100
NUM_EPOCHS = 100
RANDOM_SEED = 42  # fixed so re-running the notebook reproduces the same embeddings

# Read every column as text so node ids survive the CSV -> TSV round trip verbatim
# (numeric-looking ids such as "007" would otherwise be parsed as numbers and re-written as "7").
df = pd.read_csv('triple.csv', dtype=str, keep_default_na=False)

triples = df[['head', 'relation', 'tail']].values.tolist()

# Save temporary triples file in TSV format (PyKEEN requirement)
triples_tsv = "triples_temp.tsv"
df[['head', 'relation', 'tail']].to_csv(triples_tsv, sep="\t", index=False, header=False)

# Create a single TriplesFactory for all splits: the goal is an embedding for every
# entity, so all triples are used for training. Evaluating on those same triples
# therefore yields optimistic, training-set metrics — a sanity check only, not an
# estimate of generalization.
tf = TriplesFactory.from_path(triples_tsv)

result = pipeline(
    model='ComplEx',
    training=tf,
    validation=tf,
    testing=tf,
    model_kwargs={'embedding_dim': EMBEDDING_DIM},
    training_kwargs={'num_epochs': NUM_EPOCHS},
    random_seed=RANDOM_SEED,
)

model = result.model

# Extract entity embeddings; rows are presumably indexed by PyKEEN entity id.
entity_embeddings = model.entity_representations[0]().cpu().detach().numpy()
entities = list(tf.entity_to_id.keys())

# Look up each entity's row through its id rather than assuming the mapping's
# key order matches the embedding row order.
final_embeddings = {
    entity: entity_embeddings[idx]
    for entity, idx in tf.entity_to_id.items()
}

print("Entities:", len(final_embeddings))
print("ComplEx embedding dim:", len(next(iter(final_embeddings.values()))))

INFO:pykeen.pipeline.api:Using device: None
INFO:pykeen.nn.representation:Inferred unique=False for Embedding(
  (regularizer): LpRegularizer()
)
INFO:pykeen.nn.representation:Inferred unique=False for Embedding(
  (regularizer): LpRegularizer()
)
Training epochs on cpu: 100%|██████████| 100/100 [00:51<00:00,  1.93epoch/s, loss=0.919, prev_loss=0.943]
Evaluating on cpu: 100%|██████████| 3.48k/3.48k [00:01<00:00, 1.92ktriple/s]
INFO:pykeen.evaluation.evaluator:Evaluation took 1.90s seconds


Entities: 2368
ComplEx embedding dim: 100


#### Convert ComplEx embeddings to real vectors by concatenating their real and imaginary parts

In [None]:
# Keep the label -> id map and re-read the trained entity representation as a
# NumPy array (presumably one row per PyKEEN entity id — confirm).
entity_ids = tf.entity_to_id
entity_embeddings = model.entity_representations[0]().detach().cpu().numpy()

def complex_to_real_vector(complex_array):
    """Map complex vectors to real ones by concatenating real and imaginary parts.

    Args:
        complex_array: complex array-like of shape ``(..., dim)`` — a single
            vector of shape ``(dim,)`` or a stack of vectors.

    Returns:
        Real ndarray of shape ``(..., 2 * dim)`` laid out as
        ``[Re(z_1..z_dim), Im(z_1..z_dim)]`` along the last axis.
    """
    complex_array = np.asarray(complex_array)
    # Concatenate along the last axis: 1-D input behaves exactly as before, and
    # 2-D input yields one real row per complex row (axis=0 would stack rows instead).
    return np.concatenate([complex_array.real, complex_array.imag], axis=-1)

# One row per entity label with its complex embedding. Embedding rows are looked
# up through the label -> id map instead of relying on the map's key order
# matching the row order of entity_embeddings.
df_complex = pd.DataFrame({
    'node': list(entity_ids),
    'complex_embedding': [entity_embeddings[idx] for idx in entity_ids.values()],
})

# Flatten each complex vector into a real one: [Re(z), Im(z)].
df_complex['real_embedding'] = df_complex['complex_embedding'].apply(complex_to_real_vector)

#### Fetch node ids and their `original_embedding` from Neo4j

In [None]:
# Fetch every Test_embedding node's id together with its stored original_embedding.
neo_query = """
MATCH (n:Test_embedding)
RETURN n.id AS node,
       n.original_embedding AS original_embedding
"""

with driver.session() as session:
    neo_data = session.run(neo_query).data()

# Name the columns explicitly so an empty result still yields a frame with the
# 'node' column that the merge step joins on.
df_neo = pd.DataFrame(neo_data, columns=['node', 'original_embedding'])

#### Concatenate the original and ComplEx embeddings and write the result back to Neo4j

In [None]:
# Entity labels in df_complex come from PyKEEN, which reads them from the TSV as
# strings, while Neo4j may return `id` as an integer. Join on a string key, but keep
# the original typed id in 'node' so the write-back MATCH compares like with like.
# NOTE(review): float ids (e.g. a column upcast by missing values) would stringify
# as "1.0" and not match — confirm ids are strings or integers.
df_neo_keyed = df_neo.assign(node_key=df_neo['node'].astype(str))
df_complex_keyed = df_complex.rename(columns={'node': 'node_key'})
df = df_neo_keyed.merge(df_complex_keyed, on='node_key', how='inner')

# A node without original_embedding would otherwise fail with an opaque TypeError below.
missing_original = df['original_embedding'].isna()
if missing_original.any():
    raise ValueError(
        f"{int(missing_original.sum())} node(s) have no original_embedding, e.g. "
        f"{df.loc[missing_original, 'node'].head(5).tolist()}"
    )

# combined = original_embedding followed by [Re(ComplEx), Im(ComplEx)] as plain floats.
df["combined_embedding"] = [
    list(original) + real.tolist()
    for original, real in zip(df["original_embedding"], df["real_embedding"])
]

records = [
    {"node_id": node_id, "embedding": embedding}
    for node_id, embedding in zip(df["node"], df["combined_embedding"])
]

# Restrict the MATCH to Test_embedding nodes: an unlabeled MATCH scans every node and
# would also overwrite nodes of other labels that happen to share the same id.
query = """
UNWIND $rows AS r
MATCH (n:Test_embedding {id: r.node_id})
SET n.combined_embedding = r.embedding
"""

with driver.session() as session:
    # consume() waits for the write to finish and returns its summary.
    summary = session.run(query, rows=records).consume()

print("combined_embedding properties set:", summary.counters.properties_set)

driver.close()