In [None]:
import os
import pickle
import random
import shutil
import sys

import torch
from tqdm.auto import tqdm

# Add src to path
sys.path.append('../../')

from src.kg_model.embeddings_model import EmbeddingsModel, EmbeddingsModelConfig, EmbedderModelConfig
from src.db_drivers.vector_driver import VectorDriverConfig, VectorDBConnectionConfig
from src.db_drivers.vector_driver import VectorDBInstance
from src.utils.data_structs import TripletCreator, NodeCreator, NodeType, RelationCreator, RelationType


# Metrics Evaluation (Temporal Graph)
This notebook evaluates **Link Prediction** metrics (MRR and Hits@k) on a graph whose triplets are enriched with a **temporal dimension**: every triplet is tagged with a time node holding a randomly assigned year.
We use a separate, temporary vector database so that the main graph is not polluted.

In [None]:
# Configuration for Temporal Test
NODES_DB_PATH = '../../data/graph_structures/vectorized_nodes/temporal_test'
TRIPLETS_DB_PATH = '../../data/graph_structures/vectorized_triplets/temporal_test'
EMBEDDER_MODEL_PATH = '../../models/intfloat/multilingual-e5-small'
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
RANDOM_SEED = 42

# Seed the stdlib RNG: later cells draw random years and a random evaluation
# sample, so without a seed the reported metrics change on every full re-run.
random.seed(RANDOM_SEED)

print(f"Using device: {DEVICE}")

# Ensure fresh start: delete on-disk stores left behind by a previous run.
# (need_to_clear=True is additionally passed to both driver configs below.)
for db_path in (NODES_DB_PATH, TRIPLETS_DB_PATH):
    if os.path.exists(db_path):
        shutil.rmtree(db_path)

embedder_config = EmbedderModelConfig(model_name_or_path=EMBEDDER_MODEL_PATH, device=DEVICE)
nodes_driver = VectorDriverConfig(db_config=VectorDBConnectionConfig(path=NODES_DB_PATH, need_to_clear=True))
triplets_driver = VectorDriverConfig(db_config=VectorDBConnectionConfig(path=TRIPLETS_DB_PATH, need_to_clear=True))

config = EmbeddingsModelConfig(
    nodesdb_driver_config=nodes_driver,
    tripletsdb_driver_config=triplets_driver,
    embedder_config=embedder_config,
    verbose=True
)

model = EmbeddingsModel(config)
# NOTE(review): the explicit device move happens *before* init_model(). If
# init_model() (re)loads the weights, this move may be redundant or undone --
# confirm against the embedder implementation before relying on it.
model.embedder.model = model.embedder.model.to(DEVICE) # Explicit move just in case
model.embedder.init_model()

In [None]:
# Load Triplets and Add Fake Time
PKL_GRAPH_PATH = '../../data/pickled_graphs/all_triplets.pkl'

try:
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only point PKL_GRAPH_PATH at pickles produced by this project.
    with open(PKL_GRAPH_PATH, 'rb') as f:
        original_triplets = pickle.load(f)
    print(f"Loaded {len(original_triplets)} triplets from pickle.")
except FileNotFoundError:
    print("Pickle not found. Using dummy data.")
    # Fall back to ten synthetic "Subject_i related_to Object_i" triplets so
    # the rest of the notebook can still run end to end.
    original_triplets = [
        TripletCreator.create(
            NodeCreator.create(NodeType.object, f"Subject_{i}"),
            RelationCreator.create(RelationType.simple, "related_to"),
            NodeCreator.create(NodeType.object, f"Object_{i}"),
        )
        for i in range(10)
    ]

# Add Temporal Dimension
temporal_triplets = []
years = ['2020', '2021', '2022', '2023', '2024']

for source_triplet in tqdm(original_triplets):
    # Attach a randomly chosen year as the time node of the triplet.
    time_node = NodeCreator.create(NodeType.time, random.choice(years))

    # The original start node, relation and end node objects are reused
    # as-is; only the triplet wrapper (now carrying `time`) is created anew.
    temporal_triplets.append(
        TripletCreator.create(
            start_node=source_triplet.start_node,
            relation=source_triplet.relation,
            end_node=source_triplet.end_node,
            time=time_node
        )
    )

print(f"Created {len(temporal_triplets)} temporal triplets.")
# Guard the sample print: an empty pickle would otherwise raise IndexError.
if temporal_triplets:
    print(f"Sample stringified: {temporal_triplets[0].stringified}")

In [None]:
# Add to VectorDB
# Hand the time-enriched triplets to the model's create_triplets method
# (the updated version that handles time nodes), inserting 64 per batch.
info = model.create_triplets(
    temporal_triplets,
    batch_size=64,
)

In [None]:
# Evaluation Function (Temporal)
def calculate_temporal_metrics(model, test_triplets, k_values=(1, 5, 10), top_k=100):
    """Compute Hits@k and MRR for end-node prediction on temporal triplets.

    For every triplet a text query is built from the start node and the
    relation (prefixed with ``"<time name>: "`` when the triplet has a time
    node), embedded, and used to retrieve candidates from the nodes vector
    store. The 1-based position of the true end node among the candidates is
    its rank; a triplet whose end node is not retrieved adds nothing to any
    metric but still counts in the denominator.

    Args:
        model: Object exposing ``embedder.encode_passages(list_of_texts)`` and
            ``vectordbs['nodes'].retrieve(...)``.
        test_triplets: Sized iterable of triplets with ``start_node``,
            ``relation``, ``end_node`` and ``time`` attributes.
        k_values: Cut-offs for Hits@k. Duplicates are counted once.
        top_k: Minimum number of candidates to retrieve per query. It is
            raised to ``max(k_values)`` when a larger cut-off is requested,
            so that such a cut-off is not silently capped at ``top_k``.

    Returns:
        ``(hits, mrr)`` where ``hits`` maps each cut-off to the fraction of
        triplets ranked within it and ``mrr`` is the mean reciprocal rank.
        Both are 0.0 for an empty ``test_triplets``.
    """
    # Building the dict first de-duplicates the cut-offs; iterating over its
    # keys below ensures each cut-off is counted and normalised exactly once.
    hits = {k: 0 for k in k_values}
    n_candidates = max([top_k, *hits])
    reciprocal_rank_sum = 0.0
    count = 0

    print(f"Evaluating on {len(test_triplets)} triplets...")

    for triplet in tqdm(test_triplets):
        s = triplet.start_node
        p = triplet.relation
        o = triplet.end_node
        t = triplet.time

        # Per the original author's note the query must mirror the triplet
        # stringify logic "{t.name}: {s} {p} {o}" with the end node left out.
        # NOTE(review): the relation is formatted through
        # NodeCreator.add_str_props as well -- confirm this matches how
        # relations are stringified when triplets are indexed.
        s_str = NodeCreator.add_str_props(s, s.name)
        p_str = NodeCreator.add_str_props(p, p.name)

        if t:
            query_text = f"{t.name}: {s_str} {p_str}"
        else:
            query_text = f"{s_str} {p_str}"

        # Embed query
        # NOTE(review): the query is embedded with encode_passages, presumably
        # to share the embedding space of the indexed documents -- confirm.
        query_embedding = model.embedder.encode_passages([query_text])[0]
        query_inst = VectorDBInstance(id='query', document=query_text, embedding=query_embedding)

        # Retrieve candidates from the NODES store. Node strings carry no time
        # information themselves, so this naive setup only checks whether the
        # embedding of e.g. "2021: John related_to" lands close to "Mary".
        results = model.vectordbs['nodes'].retrieve([query_inst], n_results=n_candidates, includes=['ids'])
        candidates = results[0]

        # 1-based rank of the true end node, or None when it was not retrieved.
        rank = next(
            (pos for pos, (_dist, inst) in enumerate(candidates, start=1) if inst.id == o.id),
            None,
        )

        if rank is not None:
            reciprocal_rank_sum += 1.0 / rank
            for k in hits:
                if rank <= k:
                    hits[k] += 1

        count += 1

    # Avoid dividing by zero on an empty test set; metrics are then all 0.0.
    denominator = count if count > 0 else 1
    hits = {k: hit_count / denominator for k, hit_count in hits.items()}
    mrr = reciprocal_rank_sum / denominator

    return hits, mrr

In [None]:
# Run Evaluation (Sample)
SAMPLE_SIZE = 100

# Draw a random subset only when there are more triplets than SAMPLE_SIZE;
# otherwise evaluate on every temporal triplet.
test_sample = (
    random.sample(temporal_triplets, SAMPLE_SIZE)
    if len(temporal_triplets) > SAMPLE_SIZE
    else temporal_triplets
)

if not test_sample:
    print("No triplets to test.")
else:
    hits, mrr = calculate_temporal_metrics(model, test_sample)
    print("Temporal Results:")
    print(f"MRR: {mrr:.4f}")
    for cutoff, hit_rate in hits.items():
        print(f"Hits@{cutoff}: {hit_rate:.4f}")