In [None]:
import os
import pickle
import random
import sys

import numpy as np
import torch
from tqdm.auto import tqdm

# Add src to path
sys.path.append('../../')

from src.kg_model.embeddings_model import EmbeddingsModel, EmbeddingsModelConfig, EmbedderModelConfig
from src.db_drivers.vector_driver import VectorDriverConfig, VectorDBConnectionConfig
from src.utils.data_structs import TripletCreator, NodeCreator


# Metrics Evaluation (Existing Graph)
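This notebook evaluates **Link Prediction** metrics (Hits@1, Hits@5, Hits@10, MRR) on the existing Knowledge Graph stored in VectorDB. For each sampled triplet, the subject and relation are embedded as a single query, and the rank of the true object among the nodes retrieved from the nodes vector store is measured.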

In [None]:
# Configuration: on-disk locations of the two vector stores and of the embedder weights.
NODES_DB_PATH = '../../data/graph_structures/vectorized_nodes/testing'
TRIPLETS_DB_PATH = '../../data/graph_structures/vectorized_triplets/testing'
EMBEDDER_MODEL_PATH = '../../models/intfloat/multilingual-e5-small'

# Run on the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print(f"Using device: {DEVICE}")


def _vector_driver_config(db_path):
    """Wrap a vector-DB path in a connection config and then a driver config."""
    connection = VectorDBConnectionConfig(path=db_path)
    return VectorDriverConfig(db_config=connection)


embedder_config = EmbedderModelConfig(
    model_name_or_path=EMBEDDER_MODEL_PATH,
    device=DEVICE,
)
nodes_driver = _vector_driver_config(NODES_DB_PATH)
triplets_driver = _vector_driver_config(TRIPLETS_DB_PATH)

# Bundle the embedder and both vector-store drivers into one model configuration.
config = EmbeddingsModelConfig(
    embedder_config=embedder_config,
    nodesdb_driver_config=nodes_driver,
    tripletsdb_driver_config=triplets_driver,
    verbose=True,
)

model = EmbeddingsModel(config)
# Initialise the embedder explicitly so it is loaded before the evaluation cells call it.
model.embedder.init_model()

In [None]:
# Load Triplets for Evaluation
# The evaluation needs the original Triplet objects (subject / relation / object),
# which are read back from the pickled graph, the same 'all_triplets.pkl' that
# built_graph.ipynb works with. `pickle` is imported in the notebook's import cell.

# Assumed location of the pickle; adjust this path if the file lives elsewhere.
PKL_GRAPH_PATH = '../../data/pickled_graphs/all_triplets.pkl'

# SECURITY: pickle.load can execute arbitrary code embedded in the file.
# Only point PKL_GRAPH_PATH at pickles produced by this project or another trusted source.
try:
    with open(PKL_GRAPH_PATH, 'rb') as f:
        triplets = pickle.load(f)
    print(f"Loaded {len(triplets)} triplets from pickle.")
except FileNotFoundError:
    # Deliberate best-effort fallback: keep the notebook runnable with an empty
    # triplet list instead of failing when the pickle is missing.
    print("Pickle not found. Metrics will only work if we can query the DB blind or generating dummy queries.")
    triplets = []

In [None]:
def calculate_link_prediction_metrics(model, test_triplets, k_values=(1, 5, 10), top_k=100):
    """Score tail-entity link prediction (Hits@k and MRR) against the nodes vector DB.

    For every triplet the text "<subject> <relation>" is embedded and used to
    retrieve the nearest entries of ``model.vectordbs['nodes']``. The 1-based
    position of the triplet's end node among the returned candidates is its
    rank. No filtering of the subject itself or of other correct end nodes is
    applied to the candidate list.

    Args:
        model: Object exposing ``embedder.encode_passages(list_of_str)`` and
            ``vectordbs['nodes'].retrieve(...)``.
        test_triplets: Sized iterable of triplets with ``start_node``,
            ``relation`` and ``end_node`` attributes. The start node and the
            relation must have ``name``; the end node must have ``id``.
        k_values: Cutoffs for Hits@k. Repeated values are collapsed so that a
            cutoff is never counted twice.
        top_k: Minimum number of candidates retrieved per query. It is raised
            to the largest cutoff when needed, otherwise Hits@k for k > top_k
            could never be satisfied.

    Returns:
        tuple: ``(hits, mrr)`` where ``hits`` maps each cutoff k to the
        fraction of triplets whose end node was ranked within the first k
        candidates, and ``mrr`` is the mean reciprocal rank. An end node that
        is not among the retrieved candidates contributes 0 to both. All
        values are 0 when ``test_triplets`` is empty.
    """
    # Function-scope import: the query has to be wrapped in a VectorDBInstance for retrieval.
    from src.db_drivers.vector_driver import VectorDBInstance

    # Collapse duplicate cutoffs (order preserved) so no cutoff is incremented or averaged twice.
    cutoffs = list(dict.fromkeys(k_values))
    hits = {k: 0 for k in cutoffs}
    reciprocal_rank_sum = 0.0
    count = 0

    # Retrieve at least as many candidates as the largest requested cutoff.
    n_results = max([top_k, *cutoffs])

    print(f"Evaluating on {len(test_triplets)} triplets...")

    for triplet in tqdm(test_triplets):
        s = triplet.start_node
        p = triplet.relation
        o = triplet.end_node

        # Query text is the heuristic "Subject Relation".
        # NOTE(review): add_str_props presumably appends string properties to the name - confirm.
        s_str = NodeCreator.add_str_props(s, s.name)
        p_str = NodeCreator.add_str_props(p, p.name)
        query_text = f"{s_str} {p_str}"

        # The query is embedded with encode_passages (which takes a list of strings).
        # NOTE(review): E5 models distinguish 'query' and 'passage' inputs - confirm
        # this matches how the stored node embeddings were produced.
        query_embedding = model.embedder.encode_passages([query_text])[0]
        query_inst = VectorDBInstance(id='query', document=query_text, embedding=query_embedding)

        results = model.vectordbs['nodes'].retrieve(
            [query_inst], n_results=n_results, includes=['ids', 'metadatas']
        )

        # One query was sent, so results[0] holds its (distance, instance) candidates in rank order.
        candidates = results[0]

        # 1-based rank of the true end node, or None when it was not retrieved.
        rank = next(
            (position for position, (_, inst) in enumerate(candidates, start=1) if inst.id == o.id),
            None,
        )

        if rank is not None:
            reciprocal_rank_sum += 1.0 / rank
            for k in cutoffs:
                if rank <= k:
                    hits[k] += 1

        count += 1

    # Average over all evaluated triplets; guard against division by zero on empty input.
    denominator = count if count > 0 else 1
    mrr = reciprocal_rank_sum / denominator
    hits = {k: total / denominator for k, total in hits.items()}

    return hits, mrr

In [None]:
# Run Evaluation (Sample)
SAMPLE_SIZE = 100
RANDOM_SEED = 42  # fixed seed so the sampled subset, and therefore the metrics, are reproducible

if len(triplets) > SAMPLE_SIZE:
    # A dedicated RNG makes the draw repeatable on every Restart & Run All
    # without reseeding the global `random` state.
    test_sample = random.Random(RANDOM_SEED).sample(triplets, SAMPLE_SIZE)
else:
    # Fewer triplets than the sample size: evaluate on all of them.
    test_sample = triplets

if test_sample:
    hits, mrr = calculate_link_prediction_metrics(model, test_sample)
    print("Results:")
    print(f"MRR: {mrr:.4f}")
    for k, v in hits.items():
        print(f"Hits@{k}: {v:.4f}")
else:
    # Reached when the triplet list is empty (e.g. the pickle could not be loaded).
    print("No triplets to test.")