# TComplex + Finetuned E5 Benchmark

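This notebook benchmarks TComplEx temporal re-scoring layered on top of a fine-tuned E5 retriever against an E5-only baseline. Both configurations are evaluated on the first `SAMPLE_SIZE` (500 by default) questions of the CronKGQA test set, and MRR, Hits@1 and Hits@5 are reported overall and per question type.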

In [1]:
import os
import sys

# Robustly set project root
def set_project_root():
    """Make the project root the working directory and importable.

    If the current working directory path contains ``'notebooks'``, walk
    upwards until a directory containing an entry named ``src`` is found and
    ``chdir`` into it. Whatever directory ends up selected is appended to
    ``sys.path`` (if not already present) so that ``from src...`` imports work.

    If no ancestor contains ``src``, the working directory is left untouched
    and a warning is printed, rather than ending up at the filesystem root.
    """
    start_dir = os.getcwd()
    current_dir = start_dir

    if 'notebooks' in current_dir:
        candidate = current_dir
        while 'src' not in os.listdir(candidate):
            parent = os.path.dirname(candidate)
            if parent == candidate:
                # Reached the filesystem root without finding the project.
                candidate = None
                break
            candidate = parent

        if candidate is None:
            print(f"Warning: no 'src' directory found above {start_dir}; staying in place")
        else:
            current_dir = candidate
            os.chdir(current_dir)
            print(f"Changed directory to project root: {current_dir}")

    if current_dir not in sys.path:
        sys.path.append(current_dir)
        print(f"Added {current_dir} to sys.path")

set_project_root()
import pickle
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import re
import torch

from src.kg_model.knowledge_graph_model import KnowledgeGraphModel, KnowledgeGraphModelConfig
from src.db_drivers.vector_driver.embedders import EmbedderModelConfig
from src.kg_model.embeddings_model import EmbeddingsModelConfig
from src.db_drivers.vector_driver import VectorDriverConfig, VectorDBConnectionConfig
from src.kg_model.graph_model import GraphModelConfig
from src.db_drivers.graph_driver import GraphDriverConfig
from src.utils.data_structs import TripletCreator
from src.utils.kg_navigator import KGNavigator

# Force PyTorch
# NOTE(review): this assignment runs after torch and the project modules above
# have already been imported. If any of them reads USE_TF at import time
# (presumably transformers does — TODO confirm), the setting arrives too late;
# consider moving it above the imports.
os.environ['USE_TF'] = '0'


Changed directory to project root: /Users/nmuravya/Desktop/KG_sber/Personal-AI-dev 2
Added /Users/nmuravya/Desktop/KG_sber/Personal-AI-dev 2 to sys.path




✓ Используется Apple Silicon GPU (MPS)


✓ Используется Apple Silicon GPU (MPS)
✓ Используется Apple Silicon GPU (MPS)


In [2]:
# --- Define Standalone QA Engine to avoid Import Hell ---
class StandaloneQAEngine:
    """Self-contained retrieve-and-rank QA engine used by this benchmark.

    Pipeline of ``get_ranked_results``: obtain entities (and an optional query
    time) from ``extractor_override``, map them to graph nodes through the
    Wikidata mapper, collect the depth-1 neighbourhood triplets, and rank them
    by cosine similarity between the query embedding and each triplet's text.
    When a query time and a temporal scorer are both available, the semantic
    score is blended with the temporal score.
    """

    # Weights of the semantic (embedding) and temporal components in the blend.
    SEMANTIC_WEIGHT = 0.7
    TEMPORAL_WEIGHT = 0.3
    # Temporal logits at or below this value are ignored; presumably they mark
    # facts the temporal model does not know — TODO confirm against TemporalScorer.
    MIN_VALID_LOGIT = -10.0

    def __init__(self, kg_model, finetuned_model_path, kg_data_path="wikidata_big/kg"):
        """Load the sentence encoder, the Wikidata mapper and the temporal scorer.

        Args:
            kg_model: knowledge-graph model whose ``graph_struct.db_conn`` is
                queried for nodes.
            finetuned_model_path: local path of the fine-tuned encoder; if the
                path does not exist, ``intfloat/multilingual-e5-small`` is
                loaded instead.
            kg_data_path: directory handed to ``WikidataMapper``.
        """
        self.kg_model = kg_model

        # Init Embedder (imported lazily, as in the rest of this class, to keep
        # heavy dependencies out of module import).
        from sentence_transformers import SentenceTransformer
        if os.path.exists(finetuned_model_path):
            model_name = finetuned_model_path
        else:
            model_name = "intfloat/multilingual-e5-small"

        print(f"Loading S-BERT: {model_name}")
        self.encoder = SentenceTransformer(model_name)

        # Load Mapper
        from src.utils.wikidata_utils import WikidataMapper
        self.mapper = WikidataMapper(kg_data_path)

        # Init Temporal Scorer; the engine stays usable (semantic-only) if this fails.
        self.temporal_scorer = None
        try:
            from src.kg_model.temporal.temporal_model import TemporalScorer
            self.temporal_scorer = TemporalScorer(device="cuda" if torch.cuda.is_available() else "cpu")
        except Exception as e:
            print(f"Warning: Could not initialize TemporalScorer: {e}")

        # Extractor placeholder; callers assign an object with ``perform(query)``.
        self.extractor_override = None
        # Ensures a temporal-scoring failure is reported once, not per triplet.
        self._temporal_error_reported = False

    def get_ranked_results(self, query: str, top_k: int = 5):
        """Return up to ``top_k`` ``{'triplet', 'confidence'}`` dicts, best first.

        Returns an empty list when no entity maps to a graph node or when the
        matched nodes have no neighbouring triplets.
        """
        # 1. Extract entities (Mocked)
        entities, query_time = self._extract(query)

        # 2. Match nodes via Mapper (Strict ID lookup)
        matched_nodes = self._match_nodes(entities)
        if not matched_nodes:
            return []

        # 3. Retrieve Neighborhood
        nav = KGNavigator(self.kg_model)
        candidate_triplets = nav.get_neighborhood([n.id for n in matched_nodes], depth=1)
        if not candidate_triplets:
            return []

        # Deduplicate by triplet id, keeping first-seen order.
        seen = set()
        unique_candidates = []
        for t in candidate_triplets:
            if t.id not in seen:
                unique_candidates.append(t)
                seen.add(t.id)

        # 4. Rank
        query_emb = self.encoder.encode([query])[0]

        triplets_text = []
        for t in unique_candidates:
            _, text = TripletCreator.stringify(t)
            triplets_text.append(text)

        triplet_embs = self.encoder.encode(triplets_text)

        query_norm = np.linalg.norm(query_emb)  # loop-invariant
        results = []
        for t, emb in zip(unique_candidates, triplet_embs):
            score = np.dot(query_emb, emb) / (query_norm * np.linalg.norm(emb))
            # Negative (and NaN) similarities are clamped to zero confidence.
            e5_conf = float(max(0, score))
            final_conf = e5_conf

            if query_time and self.temporal_scorer:
                final_conf = self._blend_temporal(t, query_time, e5_conf)

            results.append({'triplet': t, 'confidence': final_conf})

        results.sort(key=lambda x: x['confidence'], reverse=True)
        return results[:top_k]

    def _extract(self, query):
        """Return ``(entities, query_time)`` from the override extractor, or the raw query."""
        if not self.extractor_override:
            return [query], None

        extraction, _ = self.extractor_override.perform(query)
        # Handle list or dict return
        if isinstance(extraction, list):
            if not extraction:
                # Nothing extracted: no entities to look up.
                return [], None
            item = extraction[0]
        else:
            item = extraction
        return item.get('entities', []), item.get('time')

    def _match_nodes(self, entities):
        """Resolve entity strings to graph node objects via Wikidata ids."""
        mapped_ids = []
        for ent in entities:
            wd_id = self.mapper.get_id(ent)
            if wd_id:
                mapped_ids.append(wd_id)

        connector = self.kg_model.graph_struct.db_conn
        if not hasattr(connector, 'strid_nodes_index'):
            return []

        matched = []
        for mid in mapped_ids:
            for iid in connector.strid_nodes_index.get(mid, []):
                if iid in connector.nodes:
                    matched.append(connector.nodes[iid])
        return matched

    def _blend_temporal(self, triplet, query_time, e5_conf):
        """Blend ``e5_conf`` with the temporal score; fall back to ``e5_conf`` if unavailable."""
        s_qid = triplet.start_node.prop.get('wd_id')
        r_pid = triplet.relation.prop.get('wd_id')
        o_qid = triplet.end_node.prop.get('wd_id')
        if not (s_qid and r_pid and o_qid):
            return e5_conf

        try:
            logit = self.temporal_scorer.score(s_qid, r_pid, o_qid, query_time)
            # Check if logit is valid (not super negative indicating unknown)
            if logit > self.MIN_VALID_LOGIT:
                temporal_conf = 1 / (1 + np.exp(-logit))
                return (e5_conf * self.SEMANTIC_WEIGHT) + (temporal_conf * self.TEMPORAL_WEIGHT)
        except Exception as e:
            # Best-effort: keep the semantic score, but make the failure visible
            # once so a silently inactive temporal model does not go unnoticed.
            if not getattr(self, '_temporal_error_reported', False):
                print(f"Warning: temporal scoring failed, using semantic score only: {e!r}")
                self._temporal_error_reported = True
        return e5_conf


In [3]:
print("Initializing QA Engine...")

# 1. Initialize QA Engine with headless config
# The graph lives in the in-memory driver and is filled by the hydration step below.
g_driver_conf = GraphDriverConfig(db_vendor='inmemory_graph')
g_model_conf = GraphModelConfig(driver_config=g_driver_conf)

# Prefer the local fine-tuned encoder; fall back to the public base model when
# the directory is absent. StandaloneQAEngine.__init__ repeats the same
# existence check on the path it receives.
model_path = "models/finetuned_e5"
if not os.path.exists(model_path):
    model_path = "intfloat/multilingual-e5-small"
    
emb_conf = EmbedderModelConfig(model_name_or_path=model_path)

# Correct paths for vector DBs
nodes_path = "data/graph_structures/vectorized_nodes/wikidata_test"
triplets_path = "data/graph_structures/vectorized_triplets/wikidata_test"

# Both stores use the same db/table names and differ only in their on-disk path.
nodes_cfg = VectorDriverConfig(
    db_vendor='chroma', db_config=VectorDBConnectionConfig(
        conn={'path': nodes_path},
        db_info={'db': 'default_db', 'table': "personalaitable"}))
        
triplets_cfg = VectorDriverConfig(
    db_vendor='chroma', db_config=VectorDBConnectionConfig(
        conn={'path': triplets_path},
        db_info={'db': 'default_db', 'table': "personalaitable"}))

embs_conf = EmbeddingsModelConfig(
    nodesdb_driver_config=nodes_cfg,
    tripletsdb_driver_config=triplets_cfg,
    embedder_config=emb_conf
)

kg_conf = KnowledgeGraphModelConfig(graph_config=g_model_conf, embeddings_config=embs_conf)
kg_model = KnowledgeGraphModel(config=kg_conf)

# Hydrate Graph (Crucial for ID mapping)
# NOTE(review): these imports sit mid-notebook rather than in the import cell.
# `mapper` is also used as a global by evaluate_sample; StandaloneQAEngine
# builds its own second WikidataMapper from the same path.
from src.utils.wikidata_utils import WikidataMapper
from src.utils.graph_loader import hydrate_in_memory_graph
kg_data_path = "wikidata_big/kg"
mapper = WikidataMapper(kg_data_path)
hydrate_in_memory_graph(kg_model, mapper, kg_data_path)

# Init Engine using Standalone class
engine = StandaloneQAEngine(kg_model, model_path)
print("Engine Ready.")

Initializing QA Engine...
✓ Используется Apple Silicon GPU (MPS)




Loading weights:   0%|          | 0/199 [00:00<?, ?it/s]

BertModel LOAD REPORT from: intfloat/multilingual-e5-small
Key                     | Status     |  | 
------------------------+------------+--+-
embeddings.position_ids | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.


Hydrating graph from wikidata_big/kg/full.txt...


Processing triplets:   0%| 

Processing triplets:   2%| 

Processing triplets:   5%| 

Processing triplets:   7%| 

Processing triplets:   9%| 

Processing triplets:  12%| 

Processing triplets:  15%|▏

Processing triplets:  17%|▏

Processing triplets:  20%|▏

Processing triplets:  23%|▏

Processing triplets:  25%|▎

Processing triplets:  28%|▎

Processing triplets:  31%|▎

Processing triplets:  34%|▎

Processing triplets:  36%|▎

Processing triplets:  39%|▍

Processing triplets:  42%|▍

Processing triplets:  45%|▍

Processing triplets:  48%|▍

Processing triplets:  51%|▌

Processing triplets:  53%|▌

Processing triplets:  56%|▌

Processing triplets:  59%|▌

Processing triplets:  62%|▌

Processing triplets:  65%|▋

Processing triplets:  68%|▋

Processing triplets:  71%|▋

Processing triplets:  74%|▋

Processing triplets:  77%|▊

Processing triplets:  80%|▊

Processing triplets:  83%|▊

Processing triplets:  85%|▊

Processing triplets:  88%|▉

Processing triplets:  91%|▉

Processing triplets:  94%|▉

Processing triplets:  97%|▉

Processing triplets: 100%|█




Hydration complete. Nodes: 122569, Triplets: 328635
Loading S-BERT: intfloat/multilingual-e5-small


Loading weights:   0%|          | 0/199 [00:00<?, ?it/s]

BertModel LOAD REPORT from: intfloat/multilingual-e5-small
Key                     | Status     |  | 
------------------------+------------+--+-
embeddings.position_ids | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.


Loading TemporalScorer resources from wikidata_big/kg/tkbc_processed_data/wikidata_big/...


Loaded mappings: 125726 entities, 203 relations, 9621 timestamps


Loading weights from models/cronkgqa/tcomplex.ckpt...


TemporalScorer initialized successfully.
Engine Ready.


In [4]:
# Load Test Data
# NOTE: pickle.load can execute arbitrary code while deserialising; only point
# this at a trusted file. The path is relative to the current working directory.
# Items are later read via the keys 'question', 'entities', 'answers', 'type'
# and (optionally) 'times'.
test_path = 'wikidata_big/questions/test.pickle'
with open(test_path, 'rb') as f:
    test_data = pickle.load(f)

# Helper classes
class MockExtractor:
    """Stand-in extractor that replays pre-supplied entities and time for any query."""

    def __init__(self, entities, time=None):
        # Materialise the iterable once so it can be returned repeatedly.
        self.time = time
        self.entities = [*entities]

    def perform(self, query):
        """Return ``([extraction], None)``; ``query`` is ignored."""
        extraction = {'entities': self.entities, 'time': self.time}
        return [extraction], None

def evaluate_sample(sample_data, use_temporal=True):
    """Run the global ``engine`` over ``sample_data`` and score each question.

    For every item, the gold entity ids are turned into labels via the global
    ``mapper`` and injected through a ``MockExtractor``, so only retrieval and
    ranking are measured. The rank of the first correct triplet among the
    top 10 results is recorded; a miss gets rank 1000.

    Args:
        sample_data: iterable of dicts with keys ``'question'``, ``'entities'``,
            ``'answers'``, ``'type'`` and optionally ``'times'``.
        use_temporal: when False, ``engine.temporal_scorer`` is set to ``None``
            for the duration of the run (semantic-only baseline).

    Returns:
        ``pd.DataFrame`` with one row per item and columns ``question_type``,
        ``rank``, ``hit_1``, ``hit_5``, ``mrr``.

    Side effects:
        Leaves ``engine.extractor_override`` set to the last item's extractor.
        ``engine.temporal_scorer`` is always restored, even if the run raises.
    """
    miss_rank = 1000
    results = []
    failed_queries = 0
    first_error = None

    # Toggle scorer
    original_scorer = engine.temporal_scorer
    if not use_temporal:
        engine.temporal_scorer = None

    try:
        for item in tqdm(sample_data):
            q = item['question']
            gt_ids = item['entities']
            answers = item['answers']

            # Map IDs to Names for Search
            names = [mapper.get_label(qid) for qid in gt_ids]

            # Extract time from dataset 'times' field if available,
            # falling back to a year regex on the question text.
            t_set = item.get('times', set())
            if t_set and len(t_set) > 0:
                # Take the first year found. NOTE(review): for a set this is
                # iteration order, not necessarily the smallest year.
                query_time = next(iter(t_set))
            else:
                matches = re.findall(r'\b(1\d{3}|20\d{2})\b', q)
                query_time = int(matches[0]) if matches else None

            engine.extractor_override = MockExtractor(names, query_time)

            try:
                ranked = engine.get_ranked_results(q, top_k=10)
            except Exception as e:
                # Best-effort: count the question as a miss, but remember the
                # failure so it can be reported instead of silently skewing metrics.
                ranked = []
                failed_queries += 1
                if first_error is None:
                    first_error = e

            rank = miss_rank

            # Special logic for simple_time (Answer is a Year, not Entity ID)
            if item['type'] == 'simple_time':
                for i, res in enumerate(ranked):
                    t_obj = res['triplet']
                    extracted_year = None

                    # Try to find time in triplet properties
                    # 1. Triplet.time node
                    if t_obj.time and t_obj.time.name:
                        extracted_year = t_obj.time.name
                    # 2. Relation property
                    elif 'time' in t_obj.relation.prop:
                        extracted_year = t_obj.relation.prop['time']
                    # 3. End node property (common for episodic)
                    elif 'time' in t_obj.end_node.prop:
                        extracted_year = t_obj.end_node.prop['time']

                    # Check if extracted year matches any answer
                    if extracted_year:
                        try:
                            # Handle date strings like "2008-01-01" or "2008"
                            y_str = str(extracted_year).split('-')[0]
                            if y_str.isdigit():
                                y_int = int(y_str)
                                # Answers are a set of ints or strings.
                                if y_int in answers or str(y_int) in answers:
                                    rank = i + 1
                                    break
                        except Exception:
                            # Unparseable value: skip this triplet. (Not a bare
                            # except, so Ctrl-C still interrupts the benchmark.)
                            pass
            else:
                # Standard Entity Logic: either endpoint may be the answer.
                for i, res in enumerate(ranked):
                    triplet = res['triplet']
                    s = triplet.start_node.prop.get('wd_id')
                    o = triplet.end_node.prop.get('wd_id')
                    if s in answers or o in answers:
                        rank = i + 1
                        break

            results.append({
                'question_type': item['type'],
                'rank': rank,
                'hit_1': 1 if rank == 1 else 0,
                'hit_5': 1 if rank <= 5 else 0,
                'mrr': 1.0/rank if rank <= 10 else 0.0
            })
    finally:
        engine.temporal_scorer = original_scorer

    if failed_queries:
        print(f"Warning: {failed_queries} question(s) raised during ranking and were "
              f"scored as misses; first error: {first_error!r}")

    return pd.DataFrame(results)


In [5]:
# Run Benchmark
# Both configurations are evaluated on the same leading slice of the test set,
# so their rows are directly comparable.
SAMPLE_SIZE = 500 # Adjust as needed
subset = test_data[:SAMPLE_SIZE]

# Baseline: temporal scorer disabled, ranking by embedding similarity only.
print("Running Baseline...")
df_base = evaluate_sample(subset, use_temporal=False)

# Proposed: temporal scorer enabled (if it was initialised successfully).
print("Running Proposed...")
df_prop = evaluate_sample(subset, use_temporal=True)

print("Done.")

Running Baseline...


  0%|          | 0/500 [00:00<?, ?it/s]

Running Proposed...


  0%|          | 0/500 [00:00<?, ?it/s]

Done.


In [6]:
# Analysis
def compare_metrics(df_base, df_prop):
    """Summarise MRR / Hits@1 / Hits@5 for both configurations.

    Produces one 'Overall' row per configuration, followed by one row per
    configuration for every question type present in ``df_base``. The result
    is sorted by ``Type`` and then ``Config``.

    Args:
        df_base: per-question results of the baseline run.
        df_prop: per-question results of the proposed run.

    Returns:
        ``pd.DataFrame`` with columns ``Config``, ``Type``, ``MRR``,
        ``Hits@1``, ``Hits@5``.
    """
    configs = (('Baseline', df_base), ('Proposed', df_prop))

    def summarise(label, scope, frame):
        # Mean of each per-question metric over the given slice.
        return {
            'Config': label,
            'Type': scope,
            'MRR': frame['mrr'].mean(),
            'Hits@1': frame['hit_1'].mean(),
            'Hits@5': frame['hit_5'].mean(),
        }

    # Overall rows first, then per-type rows in order of first appearance.
    rows = [summarise(label, 'Overall', frame) for label, frame in configs]
    for qtype in df_base['question_type'].unique():
        rows.extend(
            summarise(label, qtype, frame[frame['question_type'] == qtype])
            for label, frame in configs
        )

    summary = pd.DataFrame(rows)
    return summary.sort_values(['Type', 'Config'])

# Build the comparison table, show it, and persist it.
final = compare_metrics(df_base, df_prop)
print(final)
# Written relative to the current working directory (which set_project_root may
# have changed); the DataFrame index is included as the first CSV column.
final.to_csv("benchmark_results.csv")

      Config           Type       MRR    Hits@1    Hits@5
0   Baseline        Overall  0.514787  0.440000  0.620000
1   Proposed        Overall  0.517887  0.444000  0.620000
6   Baseline   before_after  0.231810  0.161290  0.322581
7   Proposed   before_after  0.231810  0.161290  0.322581
4   Baseline     first_last  0.373344  0.309524  0.446429
5   Proposed     first_last  0.373344  0.309524  0.446429
10  Baseline  simple_entity  0.892109  0.843972  0.957447
11  Proposed  simple_entity  0.903101  0.858156  0.957447
8   Baseline    simple_time  0.425335  0.330097  0.572816
9   Proposed    simple_time  0.425335  0.330097  0.572816
2   Baseline      time_join  0.313840  0.175439  0.543860
3   Proposed      time_join  0.313840  0.175439  0.543860
