# TComplex + Finetuned E5 + Anonymized RAG Benchmark

This notebook benchmarks **five** configurations on the CronKGQA test set:
1. **Baseline (Raw)**: Original E5-Small retrieval (No Finetuning, No Temporal Logic).
2. **Baseline + Finetuning**: E5-Small (Finetuned) retrieval + re-ranking (No Temporal Logic).
3. **TComplex + Baseline + Finetuning**: E5-Small (Finetuned) + Temporal Logic (Filtering/Sorting) + TComplex Re-ranking.
4. **TComplex (Pure)**: Original CronKGQA model (DistilBERT + TComplex scoring over all entities), identifying the Time or Entity directly.
5. **Anonymized RAG (LLM)**: Multi-hop retrieval + ID-based Context + LLM Reasoning (Ollama). Entities are masked (using raw IDs) to prevent data leakage.

**IMPORTANT:**
To run experiments, set `CURRENT_CONFIG_NAME` in the configuration-selector cell below to either `'standard'` or `'finetuned'`.

In [1]:
import os
import sys

# Robustly set project root
def set_project_root():
    """Make the project root the working directory and importable.

    When the current working directory path contains ``notebooks``, walk up
    the directory tree until a directory that contains an entry named ``src``
    is found, ``chdir`` into it and report the change.  If no such directory
    exists the working directory is left untouched (previously the walk ended
    at the filesystem root, which then became the working directory and was
    added to ``sys.path``).

    In every case the resulting directory and its ``CronKGQA/CronKGQA``
    sub-directory are appended to ``sys.path`` if not already present.
    """
    start_dir = os.getcwd()
    current_dir = start_dir

    if 'notebooks' in start_dir:
        probe = start_dir
        root_found = False
        while True:
            if 'src' in os.listdir(probe):
                root_found = True
                break
            parent = os.path.dirname(probe)
            if parent == probe:
                # Reached the filesystem root without finding `src`.
                break
            probe = parent

        if root_found:
            current_dir = probe
            os.chdir(current_dir)
            print(f"Changed directory to project root: {current_dir}")
        else:
            print(f"Warning: no ancestor of {start_dir} contains 'src'; staying in place.")

    if current_dir not in sys.path:
        sys.path.append(current_dir)
        print(f"Added {current_dir} to sys.path")

    # Add CronKGQA to path for imports
    cron_path = os.path.join(current_dir, 'CronKGQA', 'CronKGQA')
    if cron_path not in sys.path:
        sys.path.append(cron_path)

set_project_root()
import pickle
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import re
import torch
from collections import defaultdict

from src.kg_model.knowledge_graph_model import KnowledgeGraphModel, KnowledgeGraphModelConfig
from src.db_drivers.vector_driver.embedders import EmbedderModelConfig
from src.kg_model.embeddings_model import EmbeddingsModelConfig
from src.db_drivers.vector_driver import VectorDriverConfig, VectorDBConnectionConfig
from src.kg_model.graph_model import GraphModelConfig
from src.db_drivers.graph_driver import GraphDriverConfig
from src.utils.data_structs import QuadrupletCreator
from src.utils.kg_navigator import KGNavigator

# Force PyTorch
os.environ['USE_TF'] = '0'
from transformers import DistilBertTokenizer

import requests
import json
from sentence_transformers import util

import torch
def get_device():
    """Return the torch device to run on: CUDA if available, else MPS, else CPU."""
    # The MPS backend is only probed when CUDA is unavailable.
    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.backends.mps.is_available():
        backend = "mps"
    else:
        backend = "cpu"
    return torch.device(backend)


Changed directory to project root: /Users/nmuravya/Desktop/KG_sber/Personal-AI-dev 2
Added /Users/nmuravya/Desktop/KG_sber/Personal-AI-dev 2 to sys.path




✓ Используется Apple Silicon GPU (MPS)
✓ Используется Apple Silicon GPU (MPS)
✓ Используется Apple Silicon GPU (MPS)


In [2]:
# --- Define Standalone QA Engine to avoid Import Hell ---
class StandaloneQAEngine:
    """Self-contained retrieval / ranking engine used by the benchmark cells.

    Attributes set by ``__init__``:
        kg_model: the knowledge-graph model passed in by the caller.
        encoder: SentenceTransformer used to embed queries and quadruplets.
        mapper: WikidataMapper used to turn entity strings into Wikidata ids.
        temporal_scorer: TemporalScorer instance, or None if it failed to load.
        ent2id / id2ent: entity-id <-> embedding-index maps (may be empty).
        pure_qa_model / tokenizer: CronKGQA model and its DistilBERT tokenizer,
            or None when loading failed.
        extractor_override: optional object whose ``perform(query)`` supplies
            the entities (and time) for ``get_ranked_results``.
    """

    def __init__(self, kg_model, finetuned_model_path, kg_data_path="wikidata_big/kg"):
        """Load the encoder, mapper, temporal scorer and (best effort) QA model.

        Args:
            kg_model: knowledge-graph model whose ``graph_struct.db_conn`` is
                queried during retrieval.
            finetuned_model_path: local path of the sentence encoder; when the
                path does not exist, ``intfloat/multilingual-e5-small`` is
                loaded instead.
            kg_data_path: directory handed to ``WikidataMapper``.
        """
        self.kg_model = kg_model

        # Init Embedder (imported lazily, as in the rest of this notebook)
        from sentence_transformers import SentenceTransformer
        if os.path.exists(finetuned_model_path):
            model_name = finetuned_model_path
        else:
            model_name = "intfloat/multilingual-e5-small"

        print(f"Loading S-BERT: {model_name}")
        self.encoder = SentenceTransformer(model_name)

        # Load Mapper
        from src.utils.wikidata_utils import WikidataMapper
        self.mapper = WikidataMapper(kg_data_path)

        # Init Temporal Scorer (TComplex KGE); failure is reported, not fatal.
        self.temporal_scorer = None
        self.ent2id = {}
        self.id2ent = {}
        try:
            from src.kg_model.temporal.temporal_model import TemporalScorer
            self.temporal_scorer = TemporalScorer(device=get_device())
            self.ent2id = self.temporal_scorer.ent_id
            self.id2ent = {v: k for k, v in self.ent2id.items()}
        except Exception as e:
            print(f"Warning: TemporalScorer init failed: {e}")

        # Fallback load ent2id if empty
        if not self.ent2id:
            self._load_ent2id_fallback()

        # Init Pure CronKGQA Model
        self.pure_qa_model = None
        self.tokenizer = None
        self.init_pure_model()

        # Extractor placeholder
        self.extractor_override = None

    def _load_ent2id_fallback(self):
        """Try a fixed list of pickle files to populate ``ent2id`` / ``id2ent``."""
        print("Attempting fallback load for ent2id...")
        possible_paths = [
            "CronKGQA/data/cronkgqa/ent2id.pkl",
            "data/cronkgqa/ent2id.pkl",
            "wikidata_big/kg/ent2id.pkl",
            "models/cronkgqa/ent2id.pkl"
        ]
        for p in possible_paths:
            if not os.path.exists(p):
                continue
            try:
                # NOTE(security): pickle executes arbitrary code on load; only
                # point these paths at trusted project artefacts.
                with open(p, 'rb') as f:
                    self.ent2id = pickle.load(f)
                self.id2ent = {v: k for k, v in self.ent2id.items()}
                print(f"Loaded ent2id from {p}. Size: {len(self.ent2id)}")
                break
            except Exception as e:
                print(f"Failed to load {p}: {e}")

        if not self.ent2id:
            print("CRITICAL WARNING: ent2id is empty! Pure TComplex will fail.")

    def init_pure_model(self):
        """Best-effort load of the CronKGQA QA model and its tokenizer.

        On success ``self.pure_qa_model`` and ``self.tokenizer`` are set.  Any
        failure is printed and leaves whatever was not yet assigned as None;
        nothing is raised to the caller.
        """
        try:
            print("Loading Pure CronKGQA Model...")
            from qa_models import QA_model_EmbedKGQA

            class Args:
                lm_frozen = 1
                frozen = 1
                combine_all_ents = 1

            if not self.temporal_scorer:
                print("Warning: TemporalScorer unavailable; Pure CronKGQA Model not loaded.")
                return

            tkbc_model = self.temporal_scorer.model
            qa_model = QA_model_EmbedKGQA(tkbc_model, Args())

            # Checked in order; the first existing file wins.
            candidate_paths = [
                "models/cronkgqa/cronkgqa_trained.ckpt",
                "models/cronkgqa/qa_model.ckpt",
            ]
            ckpt_path = next((p for p in candidate_paths if os.path.exists(p)), None)
            if ckpt_path is None:
                print(f"Warning: QA Model checkpoint not found. Tried: {candidate_paths}")
                return
            print(f"Found checkpoint at {ckpt_path}")

            device = get_device()
            # NOTE(security): torch.load unpickles; use trusted checkpoints only.
            state_dict = torch.load(ckpt_path, map_location=device)
            # strict=False tolerates key differences; report them so a
            # mismatched checkpoint does not go unnoticed.
            load_report = qa_model.load_state_dict(state_dict, strict=False)
            if load_report.missing_keys or load_report.unexpected_keys:
                print(
                    f"Warning: checkpoint key mismatch "
                    f"(missing={len(load_report.missing_keys)}, "
                    f"unexpected={len(load_report.unexpected_keys)})"
                )
            qa_model.to(device)
            qa_model.eval()
            self.pure_qa_model = qa_model

            self.tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
            print("Pure CronKGQA Model Loaded Successfully.")
        except Exception as e:
            print(f"Failed to load Pure CronKGQA Model: {e}")

    def get_ranked_results(self, query: str, top_k: int = 5, alpha: float = 0.3, filter_temporal: bool = False):
        """Retrieve the 1-hop neighbourhood of the query entities and rank it.

        Args:
            query: question text; embedded for cosine ranking.
            top_k: number of results returned.
            alpha: weight of the temporal score when blending it with the
                embedding similarity (``(1 - alpha) * e5 + alpha * sigmoid``).
            filter_temporal: when True, every candidate that carries any time
                information gets ``+1.0`` added to its confidence.

        Returns:
            A list of ``{'quadruplet': ..., 'confidence': float}`` dicts sorted
            by descending confidence, at most ``top_k`` long.  Empty when no
            entity could be matched or no neighbourhood was found.
        """
        # 1. Entities come from the override extractor if one is installed,
        #    otherwise the whole query string is treated as a single entity.
        if self.extractor_override:
            extraction, _ = self.extractor_override.perform(query)
            if isinstance(extraction, list):
                if not extraction:
                    # Nothing extracted -> nothing to retrieve.
                    return []
                item = extraction[0]
            else:
                item = extraction
            entities = item.get('entities', [])
            query_time = item.get('time')
        else:
            entities = [query]
            query_time = None

        # 2. Match nodes via Mapper (strict ID lookup)
        mapped_ids = [wd_id for wd_id in (self.mapper.get_id(ent) for ent in entities) if wd_id]

        all_matched_nodes = []
        connector = self.kg_model.graph_struct.db_conn
        if hasattr(connector, 'strid_nodes_index'):
            for mid in mapped_ids:
                for iid in connector.strid_nodes_index.get(mid, []):
                    if iid in connector.nodes:
                        all_matched_nodes.append(connector.nodes[iid])

        if not all_matched_nodes:
            return []

        # 3. Retrieve Neighborhood
        node_ids = [n.id for n in all_matched_nodes]
        nav = KGNavigator(self.kg_model)
        candidate_quadruplets = nav.get_neighborhood(node_ids, depth=1)

        if not candidate_quadruplets:
            return []

        # Deduplicate by quadruplet id, keeping first occurrence order.
        seen = set()
        unique_candidates = []
        for t in candidate_quadruplets:
            if t.id not in seen:
                unique_candidates.append(t)
                seen.add(t.id)

        # 4. Rank
        query_emb = self.encoder.encode([query])[0]
        quadruplets_text = [QuadrupletCreator.stringify(t)[1] for t in unique_candidates]
        quadruplet_embs = self.encoder.encode(quadruplets_text)
        query_norm = np.linalg.norm(query_emb)

        def sigmoid(x):
            return 1 / (1 + np.exp(-x))

        use_temporal = bool(query_time and self.temporal_scorer and self.temporal_scorer.model)

        results = []
        for t, emb in zip(unique_candidates, quadruplet_embs):
            denom = query_norm * np.linalg.norm(emb)
            # A zero-norm embedding has no direction; score it as 0 instead of NaN.
            score = float(np.dot(query_emb, emb) / denom) if denom > 0 else 0.0
            e5_conf = max(0.0, score)
            final_conf = e5_conf

            if use_temporal:
                s_qid = t.start_node.prop.get('wd_id')
                r_pid = t.relation.prop.get('wd_id')
                o_qid = t.end_node.prop.get('wd_id')
                if s_qid and r_pid and o_qid:
                    try:
                        logit = self.temporal_scorer.score(s_qid, r_pid, o_qid, query_time)
                        # Scores at or below -10 are not blended in — presumably a
                        # sentinel for "unknown" from the scorer; TODO confirm.
                        if logit > -10.0:
                            final_conf = (e5_conf * (1.0 - alpha)) + (sigmoid(logit) * alpha)
                    except Exception:
                        # Deliberate best effort: a candidate the scorer cannot
                        # handle keeps its embedding-only confidence.
                        pass

            results.append({'quadruplet': t, 'confidence': final_conf})

        if filter_temporal:
            # Boost candidates that have ANY time information so they survive
            # the top_k cut even when the encoder ranks them low.
            for res in results:
                t_cand = res['quadruplet']
                has_time = False
                if t_cand.time and t_cand.time.name:
                    has_time = True
                elif 'time' in t_cand.relation.prop:
                    has_time = True
                elif 'time' in t_cand.end_node.prop:
                    has_time = True

                if has_time:
                    res['confidence'] += 1.0

        results.sort(key=lambda x: x['confidence'], reverse=True)
        return results[:top_k]

    def get_pure_tcomplex_rank(self, question, head_qids, time_qid=None, top_k=10):
        """Rank entities for a question with the CronKGQA model alone.

        Args:
            question: question text, tokenized with ``self.tokenizer``.
            head_qids: candidate head entity ids; the first one present in
                ``self.ent2id`` is used as the head.
            time_qid: accepted for interface compatibility; not used.
            top_k: maximum number of predictions; clamped to the number of
                scores the model returns.

        Returns:
            Entity labels (from ``self.id2ent``) for the highest scoring
            indices, best first.  Indices that have no ``id2ent`` entry are
            skipped, so fewer than ``top_k`` labels may come back.  Returns
            ``[]`` when the model is not loaded or no head id is known.
        """
        if self.pure_qa_model is None or self.tokenizer is None:
            print('DEBUG Pure: Model or Tokenizer not loaded.')
            return []

        device = next(self.pure_qa_model.parameters()).device

        # Tokenize Question
        tokenized_q = self.tokenizer(question, return_tensors="pt", padding=True, truncation=True)
        q_ids = tokenized_q['input_ids'].to(device)
        q_mask = tokenized_q['attention_mask'].to(device)

        # Map QIDs to KGE indices
        head_indices = [self.ent2id[qid] for qid in head_qids if qid in self.ent2id]
        if not head_indices:
            print(f"DEBUG Pure: No head ent2id match for {head_qids}. Ent2id size: {len(self.ent2id)}")
            return []

        # Pick first valid head
        head_tensor = torch.tensor([head_indices[0]], dtype=torch.long, device=device)

        # The model's batch also carries tail and time ids.  Index 0 is passed
        # as a placeholder for both — presumably ignored when the model scores
        # against all entities (combine_all_ents); TODO confirm in qa_models.py.
        tail_tensor = torch.zeros(1, dtype=torch.long, device=device)
        time_tensor = torch.zeros(1, dtype=torch.long, device=device)

        batch = (q_ids, q_mask, head_tensor, tail_tensor, time_tensor)

        with torch.no_grad():
            scores = self.pure_qa_model(batch)

        try:
            # topk raises when k exceeds the number of scores, so clamp it.
            k = min(top_k, scores.shape[-1])
            if k <= 0:
                return []
            _, top_indices = torch.topk(scores, k=k)
            ranked_ids = top_indices[0].cpu().tolist()
        except Exception as e:
            print(f"DEBUG Pure: ranking failed: {e}")
            return []

        # Indices without an id2ent entry (presumably the time slots that
        # follow the entity scores — TODO confirm) are dropped.
        return [self.id2ent[idx] for idx in ranked_ids if idx in self.id2ent]


In [3]:

import requests
import json

def query_ollama(prompt, model='llama3.2', context_window=4096, timeout=120):
    """Send a prompt to the local Ollama server and return its text answer.

    The fixed system instructions are prepended to ``prompt`` and the request
    is made with ``stream`` disabled, temperature 0 and at most 20 generated
    tokens.

    Args:
        prompt: user prompt (context + question).
        model: Ollama model name.
        context_window: context size forwarded to Ollama as the ``num_ctx``
            option (previously accepted but never sent).
        timeout: seconds to wait for the HTTP request before giving up, so a
            stalled server cannot hang the benchmark loop.

    Returns:
        The stripped ``response`` field on HTTP 200, ``"Error: <status>"`` for
        any other status, or ``"Error connection: <exc>"`` if the request
        itself raised.
    """
    url = "http://localhost:11434/api/generate"

    # Strict System Prompt for Logical Reasoning
    system_prompt = """You are a pure logical reasoning engine. 
You will be given a context containing facts about entities identified ONLY by their IDs (e.g., Q12345).
Your task is to answer the user's question by selecting the correct Entity ID from the context.

CONSTRAINTS & RULES:
1. Answer Format: Output EXTREMELY concise answers. ONLY the Entity ID (e.g., "Q12345"). Do not write full sentences.
2. Source of Truth: Rely SOLELY on the provided Context. Do NOT use your internal knowledge about the world, entities, or dates.
3. Temporal Logic: Ignore your internal clock. Treat the years in the context as absolute values for comparison.
4. Unknown: If the answer cannot be logically deduced from context, output "NULL"."""

    full_prompt = f"{system_prompt}\n\n{prompt}"

    data = {
        "model": model,
        "prompt": full_prompt,
        "stream": False,
        "options": {
            "num_predict": 20,  # Short answer
            "temperature": 0.0,  # Deterministic
            "num_ctx": context_window,
        }
    }

    try:
        response = requests.post(url, json=data, timeout=timeout)
        if response.status_code == 200:
            return response.json().get('response', '').strip()
        return f"Error: {response.status_code}"
    except Exception as e:
        # Errors are returned as text so one failed call does not abort a run.
        return f"Error connection: {e}"

def build_anonymized_context(ranked_results, top_k=10):
    """Render ranked quadruplets as numbered, ID-only fact lines.

    Each fact uses the ``wd_id`` property of the start node, relation and end
    node, falling back to the object's ``name`` when ``wd_id`` is absent.  The
    time is taken from the quadruplet's time name, then the relation's
    ``time`` property, then the end node's ``time`` property — the same
    precedence the ranking code in this notebook uses — and is ``"Unknown"``
    when none is present.

    Args:
        ranked_results: iterable of dicts holding a ``'quadruplet'`` entry.
        top_k: maximum number of distinct facts to emit.

    Returns:
        The fact lines joined by newlines (empty string for no facts).
    """
    lines = []
    seen = set()

    for res in ranked_results:
        if len(lines) >= top_k:
            break

        q = res['quadruplet']
        # Use RAW IDs. No mapping to names.
        s_id = q.start_node.prop.get('wd_id', q.start_node.name)
        o_id = q.end_node.prop.get('wd_id', q.end_node.name)
        r_id = q.relation.prop.get('wd_id', q.relation.name)

        # A time object with an empty name counts as "no time here" so the
        # property fallbacks still get a chance.
        t_val = "Unknown"
        if q.time and q.time.name:
            t_val = q.time.name
        elif 'time' in q.relation.prop:
            t_val = q.relation.prop['time']
        elif 'time' in q.end_node.prop:
            t_val = q.end_node.prop['time']

        # Skip facts that would render identically.
        fact_sig = f"{s_id}-{r_id}-{o_id}-{t_val}"
        if fact_sig in seen:
            continue
        seen.add(fact_sig)

        lines.append(f"Fact {len(lines)+1}: {s_id} --[{r_id}]--> {o_id} (Time: {t_val})")

    return "\n".join(lines)


In [4]:

class _FixedEntityExtractor:
    """Extractor stand-in that always reports a fixed entity list and no time."""

    def __init__(self, entity_names):
        self._entity_names = list(entity_names)

    def perform(self, query):
        # Same (extraction, extra) shape that get_ranked_results unpacks.
        return [{'entities': self._entity_names, 'time': None}], None


def retrieve_multi_hop(engine, query, entities, top_k=50):
    """Two-hop candidate retrieval seeded from the question's entities.

    Hop 1 asks ``engine.get_ranked_results`` for the top ``top_k`` facts.
    When ``entities`` is non-empty, the engine's ``extractor_override`` is
    temporarily replaced so that retrieval starts from those entities (mapped
    to labels with ``engine.mapper.get_label``); the previous override is
    restored afterwards.  Previously ``entities`` was ignored and hop 1 used
    whatever override happened to be left on the engine.

    Hop 2 expands the start/end nodes of the ten best hop-1 facts by one more
    step and scores the new facts by cosine similarity to the query.

    Args:
        engine: object exposing ``get_ranked_results``, ``mapper``,
            ``encoder``, ``kg_model`` and ``extractor_override``.
        query: question text.
        entities: entity ids to seed hop 1; falsy leaves the engine untouched.
        top_k: number of hop-1 results requested.

    Returns:
        List of ``{'quadruplet', 'score', 'hop'}`` dicts sorted by descending
        score.  Hop-1 scores are the engine's confidence values, hop-2 scores
        are raw cosine similarities.
    """
    # 1. Hop 1 retrieval, seeded from the supplied entities when present.
    previous_override = getattr(engine, 'extractor_override', None)
    if entities:
        labels = (engine.mapper.get_label(qid) for qid in entities)
        seed_names = [label for label in labels if label]
        engine.extractor_override = _FixedEntityExtractor(seed_names)
    try:
        hop1_results = engine.get_ranked_results(query, top_k=top_k)
    finally:
        if entities:
            engine.extractor_override = previous_override

    candidates = [
        {'quadruplet': res['quadruplet'], 'score': res.get('confidence', 0.0), 'hop': 1}
        for res in (hop1_results or [])
    ]

    # 2. Hop 2 expansion: pivots are the end/start nodes of the top 10 facts.
    pivots = []
    for cand in candidates[:10]:
        q = cand['quadruplet']
        pivots.append(q.end_node.id)
        pivots.append(q.start_node.id)

    # Order-preserving dedup keeps the expansion reproducible across runs.
    unique_pivots = list(dict.fromkeys(pivots))

    if unique_pivots:
        nav = KGNavigator(engine.kg_model)
        # Limit expansion to avoid memory issues
        hop2_quads = nav.get_neighborhood(unique_pivots, depth=1, limit=10)

        if hop2_quads:
            # Only keep facts hop 1 did not already return.
            seen_ids = {c['quadruplet'].id for c in candidates}
            new_quads = []
            quad_texts = []
            for q in hop2_quads:
                if q.id in seen_ids:
                    continue
                seen_ids.add(q.id)
                new_quads.append(q)
                quad_texts.append(QuadrupletCreator.stringify(q)[1])

            if new_quads:
                query_emb = engine.encoder.encode([query])[0]
                cand_embs = engine.encoder.encode(quad_texts)
                scores = util.cos_sim(query_emb, cand_embs)[0]
                for i, q in enumerate(new_quads):
                    candidates.append({
                        'quadruplet': q,
                        'score': scores[i].item(),
                        'hop': 2
                    })

    # 3. Best score first.
    candidates.sort(key=lambda x: x['score'], reverse=True)
    return candidates


In [5]:

# --- CONFIGURATION SELECTOR ---
# 'model_path' is either a local directory or a Hugging Face model id.
# The 'standard' entry names the E5-Small baseline described in the intro.
# It previously said e5-base, but the engine-initialisation cell replaces any
# path that does not exist locally with "intfloat/multilingual-e5-small",
# so e5-small is what was loaded in practice.
CONFIGS = {
    'standard': {'model_path': 'intfloat/multilingual-e5-small', 'finetuned': False},
    'finetuned': {'model_path': 'models/wikidata_finetuned', 'finetuned': True}
}

# Change this to switch experiments
CURRENT_CONFIG_NAME = 'finetuned'
if CURRENT_CONFIG_NAME not in CONFIGS:
    raise KeyError(
        f"Unknown configuration {CURRENT_CONFIG_NAME!r}; choose one of {sorted(CONFIGS)}"
    )
config = CONFIGS[CURRENT_CONFIG_NAME]
model_path = config['model_path']

print(f"Running Benchmark with Configuration: {CURRENT_CONFIG_NAME}")


Running Benchmark with Configuration: finetuned


In [6]:

# --- ENGINE INITIALIZATION AND RESTART ---
# Builds the knowledge-graph model (in-memory graph + two Chroma vector
# stores), hydrates the graph from the Wikidata dump and creates the global
# `engine` and `mapper` that the evaluation cells use.

print("Initializing QA Engine...")

# Ensure model_path is set from config if not already.
# NOTE(review): this guard only matters when the cell is run without the
# configuration-selector cell; on a fresh Run All `model_path` already exists.
if 'model_path' not in locals():
    if 'CONFIGS' in globals() and 'CURRENT_CONFIG_NAME' in globals():
        model_path = CONFIGS[CURRENT_CONFIG_NAME]['model_path']
    else:
        # Fallback default
        model_path = "models/wikidata_finetuned"

# Validating path: anything that is not an existing local path (this includes
# Hugging Face model ids) is replaced by the e5-small default.
if not os.path.exists(model_path):
    print(f"Warning: Model path {model_path} does not exist locally. Using default.")
    model_path = "intfloat/multilingual-e5-small"

# 1. Initialize QA Engine
# Graph side: in-memory graph driver.
g_driver_conf = GraphDriverConfig(db_vendor='inmemory_graph')
g_model_conf = GraphModelConfig(driver_config=g_driver_conf)

# Embedding side: the embedder uses the (possibly defaulted) model_path.
emb_conf = EmbedderModelConfig(model_name_or_path=model_path)

# Paths are relative to the working directory (the project root).
nodes_path = "data/graph_structures/vectorized_nodes/wikidata_test"
quadruplets_path = "data/graph_structures/vectorized_quadruplets/wikidata_test"

# Nodes and quadruplets live in separate Chroma stores that share the same
# db / table names.
nodes_cfg = VectorDriverConfig(
    db_vendor='chroma', db_config=VectorDBConnectionConfig(
        conn={'path': nodes_path},
        db_info={'db': 'default_db', 'table': "personalaitable"}))

quadruplets_cfg = VectorDriverConfig(
    db_vendor='chroma', db_config=VectorDBConnectionConfig(
        conn={'path': quadruplets_path},
        db_info={'db': 'default_db', 'table': "personalaitable"}))

embs_conf = EmbeddingsModelConfig(
    nodesdb_driver_config=nodes_cfg,
    quadrupletsdb_driver_config=quadruplets_cfg,
    embedder_config=emb_conf
)

kg_conf = KnowledgeGraphModelConfig(graph_config=g_model_conf, embeddings_config=embs_conf)
kg_model = KnowledgeGraphModel(config=kg_conf)

# Hydrate Graph
# `mapper` stays a notebook-level global; StandaloneQAEngine builds its own
# WikidataMapper from the same directory as well.
from src.utils.wikidata_utils import WikidataMapper
from src.utils.graph_loader import hydrate_in_memory_graph
kg_data_path = "wikidata_big/kg"
mapper = WikidataMapper(kg_data_path)
hydrate_in_memory_graph(kg_model, mapper, kg_data_path)

# Init Engine
engine = StandaloneQAEngine(kg_model, model_path)
print("Engine Ready.")


Initializing QA Engine...
✓ Используется Apple Silicon GPU (MPS)


Loading weights:   0%|          | 0/199 [00:00<?, ?it/s]

Hydrating graph from wikidata_big/kg/full.txt...


Processing quadruplets: 100%|██████████| 328635/328635 [00:05<00:00, 60267.54it/s]


Hydration complete. Nodes: 122569, Quadruplets: 328635
Loading S-BERT: models/wikidata_finetuned


Loading weights:   0%|          | 0/199 [00:00<?, ?it/s]

Loading TemporalScorer resources from /Users/nmuravya/Desktop/KG_sber/Personal-AI-dev 2/wikidata_big/kg/tkbc_processed_data/wikidata_big/...
Loaded mappings: 125726 entities, 203 relations, 9621 timestamps
Loading weights from /Users/nmuravya/Desktop/KG_sber/Personal-AI-dev 2/models/cronkgqa/tcomplex.ckpt...
Attempting fallback load for ent2id...
Loading Pure CronKGQA Model...
Failed to load Pure CronKGQA Model: [Errno 60] Operation timed out
Engine Ready.


In [7]:

def compare_metrics(runs):
    """Summarise MRR / Hits@1 / Hits@5 for several benchmark runs.

    Args:
        runs: sequence of ``(config_name, results_df)`` pairs; every frame
            needs ``question_type``, ``mrr``, ``hit_1`` and ``hit_5`` columns.

    Returns:
        A DataFrame with one ``'Overall'`` row per run plus one row per run
        for each question type found in the first run's frame, sorted by
        ``Type`` then ``Config``.  An empty DataFrame when ``runs`` is empty.
    """
    if not runs:
        return pd.DataFrame()

    def summarise(config_name, type_label, frame):
        # Mean of each metric column over the given slice of results.
        return {
            'Config': config_name,
            'Type': type_label,
            'MRR': frame['mrr'].mean(),
            'Hits@1': frame['hit_1'].mean(),
            'Hits@5': frame['hit_5'].mean()
        }

    # Overall rows first, then the per-type breakdown.
    rows = [summarise(run_name, 'Overall', run_df) for run_name, run_df in runs]

    # Question types are taken from the first run only.
    question_types = runs[0][1]['question_type'].unique()
    rows.extend(
        summarise(run_name, q_type, run_df[run_df['question_type'] == q_type])
        for q_type in question_types
        for run_name, run_df in runs
    )

    summary = pd.DataFrame(rows)
    return summary.sort_values(['Type', 'Config'])


In [8]:
# Load Test Data
# The evaluation loop reads each item through the keys 'question',
# 'entities', 'answers' and 'type', plus an optional 'times'.
# NOTE(security): pickle.load runs arbitrary code from the file; only load a
# trusted copy of the dataset.
test_path = 'wikidata_big/questions/test.pickle'
with open(test_path, 'rb') as f:
    test_data = pickle.load(f)

# Helper classes
class MockExtractor:
    """Extractor stub that hands back a fixed entity list and time.

    Installed as the engine's ``extractor_override`` so retrieval starts from
    known entities instead of running a real extraction step.
    """

    def __init__(self, entities, time=None):
        # Materialise the iterable once so it can be reused on every call.
        self.time = time
        self.entities = [entity for entity in entities]

    def perform(self, query):
        """Return ``([{'entities': ..., 'time': ...}], None)``; ``query`` is ignored."""
        extraction = {'entities': self.entities, 'time': self.time}
        return [extraction], None

def evaluate_sample(sample_data, mode='proposed'):
    # Modes: 'baseline', 'proposed' (Adaptive), 'pure_tcomplex', 'llm_rag'
    results = []
    original_scorer = engine.temporal_scorer
    
    # Configure Engine
    use_temporal_logic = False
    if mode == 'baseline':
        engine.temporal_scorer = None 
        use_temporal_logic = False
    elif mode == 'proposed':
        use_temporal_logic = True
    elif mode == 'pure_tcomplex':
        pass 

    for item in tqdm(sample_data, desc=f"Eval {mode}"):
        q = item['question']
        gt_ids = item['entities']
        answers = item['answers']
        q_type = item['type']

        # --- LLM RAG ANONYMIZED PATH ---
        if mode == 'llm_rag':
            try:
                candidates = retrieve_multi_hop(engine, q, gt_ids, top_k=50)
                context_str = build_anonymized_context(candidates, top_k=15)
                user_prompt = f"Context:\n{context_str}\n\nQuestion: {q}\nAnswer:"
                llm_response = query_ollama(user_prompt, model='llama3.2')
                pred_id = None
                match = re.search(r'Q\d+', llm_response)
                if match: pred_id = match.group(0)
                
                rank = 1 if (pred_id and pred_id in answers) else 100
                results.append({
                    'question_type': item['type'], 'rank': rank,
                    'hit_1': 1 if rank == 1 else 0, 'hit_5': 1 if rank <= 5 else 0, 
                    'mrr': 1.0/rank
                })
            except Exception as e:
                pass
            continue

        # --- PURE TCOMPLEX PATH ---
        if mode == 'pure_tcomplex':
            try:
                top_preds = engine.get_pure_tcomplex_rank(q, gt_ids, top_k=10)
            except Exception as e:
                top_preds = []
            
            rank = 1000
            for i, pred_id in enumerate(top_preds):
                valid = False
                if pred_id in answers: valid = True
                try:
                    if str(pred_id).isdigit() and int(pred_id) in answers: valid = True
                except: pass
                if valid:
                    rank = i + 1
                    break
            
            results.append({
                'question_type': item['type'], 'rank': rank,
                'hit_1': 1 if rank == 1 else 0, 'hit_5': 1 if rank <= 5 else 0,
                'mrr': 1.0/rank if rank <= 10 else 0.0
            })
            continue

        # --- HYBRID (BASELINE / PROPOSED) PATH ---
        names = [mapper.get_label(qid) for qid in gt_ids]
        t_set = item.get('times', set())
        if t_set and len(t_set) > 0:
            t = list(t_set)[0]
        else:
            matches = re.findall(r'\b(1\d{3}|20\d{2})\b', q)
            t = int(matches[0]) if matches else None
        
        engine.extractor_override = MockExtractor(names, t)
        
        # Adaptive Alpha Logic
        alpha = 0.3 # Default
        if use_temporal_logic:
            if q_type == 'simple_time': alpha = 0.6
            elif q_type == 'before_after': alpha = 0.45
            elif q_type == 'time_join': alpha = 0.5
            elif q_type == 'first_last': alpha = 0.5

        
        filter_temporal = False
        if use_temporal_logic and q_type in ['simple_time', 'before_after', 'time_join']:
            filter_temporal = True
            
        search_k = 50 if (use_temporal_logic and q_type in ['simple_time', 'before_after', 'first_last', 'time_join']) else 10
        
        try:
            # Pass alpha to get_ranked_results
            ranked = engine.get_ranked_results(q, top_k=search_k, alpha=alpha, filter_temporal=filter_temporal)
        except Exception as e:
            ranked = []
            
        rank = 1000
        logic_applied = False
        
        if use_temporal_logic:
            # 1. simple_time
            if q_type == 'simple_time':
                for i, res in enumerate(ranked):
                    t_obj = res['quadruplet']
                    y_val = None
                    if t_obj.time and t_obj.time.name: y_val = t_obj.time.name
                    elif 'time' in t_obj.relation.prop: y_val = t_obj.relation.prop['time']
                    elif 'time' in t_obj.end_node.prop: y_val = t_obj.end_node.prop['time']
                    
                    if y_val:
                        try:
                            y_str = str(y_val).split('-')[0]
                            if y_str.isdigit() and (int(y_str) in answers or str(int(y_str)) in answers):
                                rank = i + 1
                                logic_applied = True
                                break
                        except: pass

            # 2. before_after
            elif q_type == 'before_after':
                ref_year = None
                # Try to find ref year in top results
                for res in ranked:
                    t_obj = res['quadruplet']
                    y_val = None
                    if t_obj.time and t_obj.time.name: y_val = t_obj.time.name
                    elif 'time' in t_obj.relation.prop: y_val = t_obj.relation.prop['time']
                    
                    s_id = t_obj.start_node.prop.get('wd_id')
                    o_id = t_obj.end_node.prop.get('wd_id')
                    if (s_id in gt_ids or o_id in gt_ids) and y_val:
                        try:
                            ref_year = int(str(y_val).split('-')[0])
                            break 
                        except: pass
                
                if ref_year:
                    valid_indices = []
                    for i, res in enumerate(ranked):
                        t_obj = res['quadruplet']
                        y_val = None
                        if t_obj.time and t_obj.time.name: y_val = t_obj.time.name
                        elif 'time' in t_obj.relation.prop: y_val = t_obj.relation.prop['time']
                        elif 'time' in t_obj.end_node.prop: y_val = t_obj.end_node.prop['time']
                        if y_val:
                            try:
                                cand_y = int(str(y_val).split('-')[0])
                                is_before = 'before' in q.lower()
                                is_after = 'after' in q.lower()
                                if is_before and cand_y < ref_year: valid_indices.append(i)
                                elif is_after and cand_y > ref_year: valid_indices.append(i)
                            except: pass
                    if valid_indices:
                        for idx in valid_indices:
                            t = ranked[idx]['quadruplet']
                            s = t.start_node.prop.get('wd_id')
                            o = t.end_node.prop.get('wd_id')
                            if s in answers or o in answers:
                                rank = 1 
                                logic_applied = True
                                break

            # 3. first_last
            elif q_type == 'first_last':
                timed_candidates = []
                for i, res in enumerate(ranked):
                    t_obj = res['quadruplet']
                    y_val = None
                    if t_obj.time and t_obj.time.name: y_val = t_obj.time.name
                    elif 'time' in t_obj.relation.prop: y_val = t_obj.relation.prop['time']
                    elif 'time' in t_obj.end_node.prop: y_val = t_obj.end_node.prop['time']
                    if y_val:
                        try:
                            y_int = int(str(y_val).split('-')[0])
                            timed_candidates.append((i, y_int, res))
                        except: pass
                
                if timed_candidates:
                    timed_candidates.sort(key=lambda x: x[1])
                    target_idx = -1
                    if 'first' in q.lower() or 'initial' in q.lower(): target_idx = 0 
                    elif 'last' in q.lower() or 'most recent' in q.lower(): target_idx = -1 
                    if target_idx != -1:
                        best_cand = timed_candidates[target_idx] if target_idx < len(timed_candidates) else None
                        if best_cand:
                            t = best_cand[2]['quadruplet']
                            s = t.start_node.prop.get('wd_id')
                            o = t.end_node.prop.get('wd_id')
                            if s in answers or o in answers:
                                rank = 1 
                                logic_applied = True

            # 4. time_join (Reverted to First Signal Logic)
            elif q_type == 'time_join':
                reference_time = None
                # Look for first strong time signal in top 5
                for res in ranked[:5]:
                    q_obj = res['quadruplet']
                    t_val = None
                    if q_obj.time and q_obj.time.name: t_val = q_obj.time.name
                    elif 'time' in q_obj.relation.prop: t_val = q_obj.relation.prop['time']
                    elif 'time' in q_obj.end_node.prop: t_val = q_obj.end_node.prop['time']
                    
                    if t_val:
                        reference_time = t_val
                        break
                
                if reference_time:
                    # Boost candidates that happen at this time
                    same_time_indices = []
                    for i, res in enumerate(ranked):
                        q_obj = res['quadruplet']
                        t_val = None
                        if q_obj.time and q_obj.time.name: t_val = q_obj.time.name
                        elif 'time' in q_obj.relation.prop: t_val = q_obj.relation.prop['time']
                        elif 'time' in q_obj.end_node.prop: t_val = q_obj.end_node.prop['time']
                        
                        if t_val == reference_time:
                            same_time_indices.append(i)
                            
                    if same_time_indices:
                        for idx in same_time_indices:
                            t = ranked[idx]['quadruplet']
                            s = t.start_node.prop.get('wd_id')
                            o = t.end_node.prop.get('wd_id')
                            if s in answers or o in answers:
                                rank = 1
                                logic_applied = True
                                break

        if not logic_applied:
            for i, res in enumerate(ranked):
                t = res['quadruplet']
                s = t.start_node.prop.get('wd_id')
                o = t.end_node.prop.get('wd_id')
                if s in answers or o in answers:
                    rank = i + 1
                    break
        
        results.append({
            'question_type': item['type'],
            'rank': rank,
            'hit_1': 1 if rank == 1 else 0,
            'hit_5': 1 if rank <= 5 else 0,
            'mrr': 1.0/rank if rank <= 10 else 0.0
        })
        
    engine.temporal_scorer = original_scorer
    return pd.DataFrame(results).sort_values(['Type', 'Config']) if not results else pd.DataFrame(results)


In [9]:
# Benchmark setup shared by the evaluation cells below.
SAMPLE_SIZE = 500  # number of leading test examples to evaluate
subset = test_data[:SAMPLE_SIZE]
results_storage = dict()  # per-configuration result frames, combined later


In [10]:

# # 1. Baseline (Raw)
# print("--- Running Raw Baseline (Standard E5) ---")
# # Re-init engine with standard model
# raw_model_path = "intfloat/multilingual-e5-small"
# print(f"Loading Raw Model: {raw_model_path}")
# engine = StandaloneQAEngine(kg_model, raw_model_path)
# # Ensure temporal scorer disabled
# engine.temporal_scorer = None 

# df_raw = evaluate_sample(subset, mode='baseline')
# results_storage['Baseline (Raw)'] = df_raw
# print("Raw Baseline Complete.")


In [11]:

# # 2. Baseline + Finetuning
# print("--- Running Baseline + Finetuning ---")
# # Re-init engine with FINETUNED model
# finetuned_path = "models/wikidata_finetuned"
# print(f"Loading Finetuned Model: {finetuned_path}")

# # Init engine
# engine = StandaloneQAEngine(kg_model, finetuned_path)
# # Ensure temporal scorer disabled
# engine.temporal_scorer = None

# df_finetuned = evaluate_sample(subset, mode='baseline')
# results_storage['Baseline + Finetuning'] = df_finetuned
# print("Baseline + Finetuning Complete.")


In [12]:

# 3. Proposed: TComplex + Baseline + Finetuning
# This configuration needs the finetuned encoder together with the temporal scorer.
# The engine is rebuilt from scratch here instead of being reused, so the cell does
# not depend on whatever state (e.g. a disabled temporal_scorer) an earlier cell
# may have left on a previous engine instance.
print("--- Running Proposed (TComplex + Finetuning) ---")

finetuned_path = "models/wikidata_finetuned"
print(f"Loading Finetuned Model + TComplex: {finetuned_path}")
engine = StandaloneQAEngine(kg_model, finetuned_path)

# The engine constructor is expected to attach a temporal scorer when one is
# available; warn (without aborting) if it did not.
if not engine.temporal_scorer:
    print("WARNING: Temporal Scorer not loaded! Proposed method will fail to use logic.")

# Evaluate the sampled subset and keep the frame both as a named variable and in
# the shared results dictionary under the 'Proposed' key.
df_proposed = results_storage['Proposed'] = evaluate_sample(subset, mode='proposed')
print("Proposed Complete.")


--- Running Proposed (TComplex + Finetuning) ---
Loading Finetuned Model + TComplex: models/wikidata_finetuned
Loading S-BERT: models/wikidata_finetuned


Loading weights:   0%|          | 0/199 [00:00<?, ?it/s]

Loading TemporalScorer resources from /Users/nmuravya/Desktop/KG_sber/Personal-AI-dev 2/wikidata_big/kg/tkbc_processed_data/wikidata_big/...
Loaded mappings: 125726 entities, 203 relations, 9621 timestamps
Loading weights from /Users/nmuravya/Desktop/KG_sber/Personal-AI-dev 2/models/cronkgqa/tcomplex.ckpt...
Attempting fallback load for ent2id...
Loading Pure CronKGQA Model...
Failed to load Pure CronKGQA Model: [Errno 60] Operation timed out


Eval proposed:   0%|          | 0/500 [00:00<?, ?it/s]

Proposed Complete.


In [13]:

# # 4. Pure TComplex
# print("--- Running Pure TComplex ---")
# # Uses CronKGQA model loaded in engine.pure_qa_model
# # StandaloneQAEngine init calls init_pure_model(), but the engine may have been
# # re-initialised since and that load may have failed, so retry it here if needed.
# if not engine.pure_qa_model:
#     print("Attempting to load Pure QA Model manually...")
#     engine.init_pure_model()

# df_pure = evaluate_sample(subset, mode='pure_tcomplex')
# results_storage['Pure TComplex'] = df_pure
# print("Pure TComplex Complete.")


In [14]:

# # 5. Anonymized RAG
# print("--- Running Anonymized RAG (LLM) ---")
# # Uses Ollama + Finetuned Retrieval
# # We should ensure we have the best retrieval model active (Finetuned)
# # Since we ran Proposed/Pure before, engine is likely Finetuned.
# print(f"Using current engine encoder: {type(engine.encoder)}")

# df_rag = evaluate_sample(subset, mode='llm_rag')
# results_storage['RAG'] = df_rag
# print("RAG Complete.")
