# TComplex + Finetuned E5 Benchmark

This notebook benchmarks three configurations on the CronKGQA test set:
1. **Baseline + Finetuning**: E5-Small (Finetuned) retrieval + re-ranking (No Temporal Logic).
2. **TComplex + Baseline + Finetuning**: E5-Small (Finetuned) + Temporal Logic (Filtering/Sorting) + TComplex Re-ranking.
3. **TComplex (Pure)**: Original CronKGQA model (DistilBERT + TComplex scoring over all entities), identifying the Time or Entity directly.

In [1]:
import os
import sys

# Robustly set project root
def set_project_root():
    current_dir = os.getcwd()
    if 'notebooks' in current_dir:
        while 'src' not in os.listdir(current_dir):
            parent = os.path.dirname(current_dir)
            if parent == current_dir:
                break
            current_dir = parent
        os.chdir(current_dir)
        print(f"Changed directory to project root: {current_dir}")
    
    if current_dir not in sys.path:
        sys.path.append(current_dir)
        print(f"Added {current_dir} to sys.path")
    
    # Add CronKGQA to path for imports
    cron_path = os.path.join(current_dir, 'CronKGQA', 'CronKGQA')
    if cron_path not in sys.path:
        sys.path.append(cron_path)

set_project_root()
import pickle
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import re
import torch
from collections import defaultdict

from src.kg_model.knowledge_graph_model import KnowledgeGraphModel, KnowledgeGraphModelConfig
from src.db_drivers.vector_driver.embedders import EmbedderModelConfig
from src.kg_model.embeddings_model import EmbeddingsModelConfig
from src.db_drivers.vector_driver import VectorDriverConfig, VectorDBConnectionConfig
from src.kg_model.graph_model import GraphModelConfig
from src.db_drivers.graph_driver import GraphDriverConfig
from src.utils.data_structs import TripletCreator
from src.utils.kg_navigator import KGNavigator

# Force PyTorch
os.environ['USE_TF'] = '0'
from transformers import DistilBertTokenizer


Changed directory to project root: /Users/nmuravya/Desktop/KG_sber/Personal-AI-dev 2
Added /Users/nmuravya/Desktop/KG_sber/Personal-AI-dev 2 to sys.path




✓ Используется Apple Silicon GPU (MPS)
✓ Используется Apple Silicon GPU (MPS)
✓ Используется Apple Silicon GPU (MPS)


In [2]:
# --- Define Standalone QA Engine to avoid Import Hell ---
class StandaloneQAEngine:
    def __init__(self, kg_model, finetuned_model_path):
        self.kg_model = kg_model
        
        # Init Embedder
        from sentence_transformers import SentenceTransformer
        if os.path.exists(finetuned_model_path):
            model_name = finetuned_model_path
        else:
            model_name = "intfloat/multilingual-e5-small"
            
        print(f"Loading S-BERT: {model_name}")
        self.encoder = SentenceTransformer(model_name)
        
        # Load Mapper
        from src.utils.wikidata_utils import WikidataMapper
        kg_data_path = "wikidata_big/kg"
        self.mapper = WikidataMapper(kg_data_path)
        
        # Init Temporal Scorer (TComplex KGE)
        self.temporal_scorer = None
        self.ent2id = {}
        self.id2ent = {}
        try:
            from src.kg_model.temporal.temporal_model import TemporalScorer
            self.temporal_scorer = TemporalScorer(device="cuda" if torch.cuda.is_available() else "cpu")
            self.ent2id = self.temporal_scorer.ent2id
            self.id2ent = {v: k for k, v in self.ent2id.items()}
        except Exception as e:
            print(f"Warning: Could not initialize TemporalScorer: {e}")

        # Init Pure CronKGQA Model
        self.pure_qa_model = None
        self.tokenizer = None
        self.init_pure_model()

        # Extractor placeholder
        self.extractor_override = None 

    def init_pure_model(self):
        try:
            print("Loading Pure CronKGQA Model...")
            from qa_models import QA_model_EmbedKGQA
            class Args:
                lm_frozen = 1
                frozen = 1
            
            if self.temporal_scorer:
                tkbc_model = self.temporal_scorer.model
                qa_model = QA_model_EmbedKGQA(tkbc_model, Args())
                
                ckpt_path = "models/cronkgqa/qa_model.ckpt"
                if os.path.exists(ckpt_path):
                    device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
                    # qa_model.load_state_dict(torch.load(ckpt_path, map_location=device))
                    # Note: Strict=False because sometimes minor key differences exist
                    state_dict = torch.load(ckpt_path, map_location=device)
                    qa_model.load_state_dict(state_dict, strict=False)
                    qa_model.to(device)
                    qa_model.eval()
                    self.pure_qa_model = qa_model
                    
                    # Filter out useless keys and remap if necessary
                    self.tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
                    print("Pure CronKGQA Model Loaded Successfully.")
                else:
                    print(f"Warning: QA Model checkpoint not found at {ckpt_path}")
        except Exception as e:
            print(f"Failed to load Pure CronKGQA Model: {e}")

    def get_ranked_results(self, query: str, top_k: int = 5):
        # 1. Extract entities (Mocked)
        if self.extractor_override:
            extraction, _ = self.extractor_override.perform(query)
            if isinstance(extraction, list):
                item = extraction[0]
            else:
                item = extraction
            entities = item.get('entities', [])
            query_time = item.get('time')
        else:
            entities = [query]
            query_time = None
            
        # 2. Match nodes via Mapper (Strict ID lookup)
        mapped_ids = []
        for ent in entities:
            wd_id = self.mapper.get_id(ent)
            if wd_id:
                mapped_ids.append(wd_id)
        
        all_matched_nodes = []
        connector = self.kg_model.graph_struct.db_conn
        for mid in mapped_ids:
            if hasattr(connector, 'strid_nodes_index'):
                internal_ids = connector.strid_nodes_index.get(mid, [])
                for iid in internal_ids:
                    if iid in connector.nodes:
                        all_matched_nodes.append(connector.nodes[iid])

        if not all_matched_nodes:
            return []

        # 3. Retrieve Neighborhood
        node_ids = [n.id for n in all_matched_nodes]
        nav = KGNavigator(self.kg_model)
        candidate_triplets = nav.get_neighborhood(node_ids, depth=1)
        
        if not candidate_triplets:
            return []

        # Deduplicate
        seen = set()
        unique_candidates = []
        for t in candidate_triplets:
            if t.id not in seen:
                unique_candidates.append(t)
                seen.add(t.id)

        # 4. Rank
        query_emb = self.encoder.encode([query])[0]
        
        triplets_text = []
        for t in unique_candidates:
             _, text = TripletCreator.stringify(t)
             triplets_text.append(text)
             
        triplet_embs = self.encoder.encode(triplets_text)
        
        results = []
        def sigmoid(x): return 1 / (1 + np.exp(-x))

        for t, text, emb in zip(unique_candidates, triplets_text, triplet_embs):
            score = np.dot(query_emb, emb) / (np.linalg.norm(query_emb) * np.linalg.norm(emb))
            e5_conf = float(max(0, score))
            final_conf = e5_conf
            
            if query_time and self.temporal_scorer and self.temporal_scorer.model:
                s_qid = t.start_node.prop.get('wd_id')
                r_pid = t.relation.prop.get('wd_id')
                o_qid = t.end_node.prop.get('wd_id')
                if s_qid and r_pid and o_qid:
                    try:
                        logit = self.temporal_scorer.score(s_qid, r_pid, o_qid, query_time)
                        if logit > -10.0:
                            final_conf = (e5_conf * 0.7) + (sigmoid(logit) * 0.3)
                    except Exception as e:
                        pass
            
            results.append({'triplet': t, 'confidence': final_conf})
            
        results.sort(key=lambda x: x['confidence'], reverse=True)
        return results[:top_k]
    
    def get_pure_tcomplex_rank(self, question, head_qids, time_qid=None, top_k=10):
        if not self.pure_qa_model or not self.tokenizer:
            return []
        
        # Prepare Inputs
        device = next(self.pure_qa_model.parameters()).device
        
        # Tokenize Question
        tokenized_q = self.tokenizer(question, return_tensors="pt", padding=True, truncation=True)
        q_ids = tokenized_q['input_ids'].to(device)
        q_mask = tokenized_q['attention_mask'].to(device)
        
        # Prepare Entities (Heads)
        # Map QIDs to KGE Indices
        head_indices = []
        for qid in head_qids:
            if qid in self.ent2id:
                head_indices.append(self.ent2id[qid])
        
        if not head_indices:
            return []
            
        # Pick first valid head
        head_tensor = torch.LongTensor([head_indices[0]]).to(device)
        
        # Dummy Tail/Time tensors (required by forward)
        tail_tensor = torch.LongTensor([0]).to(device)
        time_tensor = torch.LongTensor([0]).to(device)
        
        # Forward pass: (q, mask, heads, tails, times)
        # This should execute forward logic and return scores over ONLY entities/times
        # BUT QA_model_EmbedKGQA forward signature expects heads, tails, times
        # Let's inspect qa_models.py forward again
        # It returns scores = torch.cat((scores_entity, scores_time), dim=1)
        # scores_entity uses head, relation (from q), time
        # scores_time uses head, relation, tail
        
        # If we pass dummy tail/time, the specific score_entity/score_time output for THAT tail/time will be junk
        # BUT Wait. forward returns scores vs ALL entities?
        # Look at score_entity method:
        # (lhs * rel) @ all_entities.t()
        # It IGNORES the 'tail' argument passed in!! (The tail arg is only for 'rhs' which is used if NOT combine_all?)
        # No. 'rhs' is computed from tail_embedding.
        # score_entity combines head and tail? 
        # If combine_all_entities_bool is True (default usually in CronKGQA for Eval), it ignores tail.
        # Let's try passing dummies.
        
        batch = (q_ids, q_mask, head_tensor, tail_tensor, time_tensor)
        
        with torch.no_grad():
            scores = self.pure_qa_model(batch)
            # scores shape: (1, num_entities + num_times)
        
        # Get Top K
        try:
             top_scores, top_indices = torch.topk(scores, k=top_k)
             top_indices = top_indices.cpu().numpy()[0]
             results = []
             for idx in top_indices:
                 if idx in self.id2ent:
                     label = self.id2ent[idx]
                     results.append(label) 
             return results
        except:
             return []


In [3]:
print("Initializing QA Engine...")

# 1. Initialize QA Engine
g_driver_conf = GraphDriverConfig(db_vendor='inmemory_graph')
g_model_conf = GraphModelConfig(driver_config=g_driver_conf)

model_path = "models/finetuned_e5"
if not os.path.exists(model_path):
    model_path = "intfloat/multilingual-e5-small"
    
emb_conf = EmbedderModelConfig(model_name_or_path=model_path)

nodes_path = "data/graph_structures/vectorized_nodes/wikidata_test"
triplets_path = "data/graph_structures/vectorized_triplets/wikidata_test"

nodes_cfg = VectorDriverConfig(
    db_vendor='chroma', db_config=VectorDBConnectionConfig(
        conn={'path': nodes_path},
        db_info={'db': 'default_db', 'table': "personalaitable"}))
        
triplets_cfg = VectorDriverConfig(
    db_vendor='chroma', db_config=VectorDBConnectionConfig(
        conn={'path': triplets_path},
        db_info={'db': 'default_db', 'table': "personalaitable"}))

embs_conf = EmbeddingsModelConfig(
    nodesdb_driver_config=nodes_cfg,
    tripletsdb_driver_config=triplets_cfg,
    embedder_config=emb_conf
)

kg_conf = KnowledgeGraphModelConfig(graph_config=g_model_conf, embeddings_config=embs_conf)
kg_model = KnowledgeGraphModel(config=kg_conf)

# Hydrate Graph
from src.utils.wikidata_utils import WikidataMapper
from src.utils.graph_loader import hydrate_in_memory_graph
kg_data_path = "wikidata_big/kg"
mapper = WikidataMapper(kg_data_path)
hydrate_in_memory_graph(kg_model, mapper, kg_data_path)

# Init Engine
engine = StandaloneQAEngine(kg_model, model_path)
print("Engine Ready.")

Initializing QA Engine...
✓ Используется Apple Silicon GPU (MPS)


Loading weights:   0%|          | 0/199 [00:00<?, ?it/s]

BertModel LOAD REPORT from: intfloat/multilingual-e5-small
Key                     | Status     |  | 
------------------------+------------+--+-
embeddings.position_ids | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.


Hydrating graph from wikidata_big/kg/full.txt...


Processing triplets: 100%|██████████| 328635/328635 [00:05<00:00, 59105.91it/s]


Hydration complete. Nodes: 122569, Triplets: 328635
Loading S-BERT: intfloat/multilingual-e5-small


Loading weights:   0%|          | 0/199 [00:00<?, ?it/s]

BertModel LOAD REPORT from: intfloat/multilingual-e5-small
Key                     | Status     |  | 
------------------------+------------+--+-
embeddings.position_ids | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.


Loading TemporalScorer resources from wikidata_big/kg/tkbc_processed_data/wikidata_big/...
Loaded mappings: 125726 entities, 203 relations, 9621 timestamps
Loading weights from models/cronkgqa/tcomplex.ckpt...
TemporalScorer initialized successfully.
Loading Pure CronKGQA Model...


config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

Loading weights:   0%|          | 0/100 [00:00<?, ?it/s]

DistilBertModel LOAD REPORT from: distilbert-base-uncased
Key                     | Status     |  | 
------------------------+------------+--+-
vocab_transform.weight  | UNEXPECTED |  | 
vocab_layer_norm.weight | UNEXPECTED |  | 
vocab_layer_norm.bias   | UNEXPECTED |  | 
vocab_projector.bias    | UNEXPECTED |  | 
vocab_transform.bias    | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.


Freezing LM params
Failed to load Pure CronKGQA Model: 'Args' object has no attribute 'combine_all_ents'
Engine Ready.


In [4]:
# Load Test Data
test_path = 'wikidata_big/questions/test.pickle'
with open(test_path, 'rb') as f:
    test_data = pickle.load(f)

# Helper classes
class MockExtractor:
    def __init__(self, entities, time=None):
        self.entities = list(entities)
        self.time = time
    def perform(self, query):
        return [{'entities': self.entities, 'time': self.time}], None

def evaluate_sample(sample_data, mode='proposed'):
    # Modes: 'baseline' (No Temporal), 'proposed' (Filtering+Sorted), 'pure_tcomplex'
    results = []
    original_scorer = engine.temporal_scorer
    
    # Configure Engine
    use_temporal_logic = False
    if mode == 'baseline':
        engine.temporal_scorer = None # Disable TComplex Reranking
        use_temporal_logic = False
    elif mode == 'proposed':
        use_temporal_logic = True
    elif mode == 'pure_tcomplex':
        pass 

    for item in tqdm(sample_data, desc=f"Eval {mode}"):
        q = item['question']
        gt_ids = item['entities']
        answers = item['answers']
        q_type = item['type']
        
        # --- PURE TCOMPLEX PATH ---
        if mode == 'pure_tcomplex':
            # Pure Mode: Use CronKGQA model directly
            try:
                top_preds = engine.get_pure_tcomplex_rank(q, gt_ids, top_k=10)
            except Exception as e:
                top_preds = []
            
            rank = 1000
            for i, pred_id in enumerate(top_preds):
                # Convert QA Model entity/time ID string to set presence check
                valid = False
                # Check exact string match
                if pred_id in answers:
                    valid = True
                # Check int match (years)
                try:
                    if str(pred_id).isdigit() and int(pred_id) in answers:
                         valid = True
                except: pass
                
                if valid:
                    rank = i + 1
                    break
            
            results.append({
                'question_type': item['type'],
                'rank': rank,
                'hit_1': 1 if rank == 1 else 0,
                'hit_5': 1 if rank <= 5 else 0,
                'mrr': 1.0/rank if rank <= 10 else 0.0
            })
            continue

        # --- HYBRID PATH ---
        names = [mapper.get_label(qid) for qid in gt_ids]
        t_set = item.get('times', set())
        if t_set and len(t_set) > 0:
            t = list(t_set)[0]
        else:
            matches = re.findall(r'\b(1\d{3}|20\d{2})\b', q)
            t = int(matches[0]) if matches else None
        
        engine.extractor_override = MockExtractor(names, t)
        
        search_k = 50 if (use_temporal_logic and q_type in ['simple_time', 'before_after', 'first_last']) else 10
        
        try:
            ranked = engine.get_ranked_results(q, top_k=search_k)
        except Exception as e:
            ranked = []
            
        rank = 1000
        logic_applied = False
        
        if use_temporal_logic:
            # 1. simple_time
            if q_type == 'simple_time':
                for i, res in enumerate(ranked):
                    t_obj = res['triplet']
                    extracted_year = None
                    if t_obj.time and t_obj.time.name: extracted_year = t_obj.time.name
                    elif 'time' in t_obj.relation.prop: extracted_year = t_obj.relation.prop['time']
                    elif 'time' in t_obj.end_node.prop: extracted_year = t_obj.end_node.prop['time']
                    
                    if extracted_year:
                        try:
                            y_str = str(extracted_year).split('-')[0]
                            if y_str.isdigit():
                                 y_int = int(y_str)
                                 if y_int in answers or str(y_int) in answers:
                                     rank = i + 1
                                     logic_applied = True
                                     break
                        except: pass

            # 2. before_after
            elif q_type == 'before_after':
                ref_year = None
                for res in ranked:
                    t_obj = res['triplet']
                    y_val = None
                    if t_obj.time and t_obj.time.name: y_val = t_obj.time.name
                    elif 'time' in t_obj.relation.prop: y_val = t_obj.relation.prop['time']
                    
                    s_id = t_obj.start_node.prop.get('wd_id')
                    o_id = t_obj.end_node.prop.get('wd_id')
                    if (s_id in gt_ids or o_id in gt_ids) and y_val:
                        try:
                            ref_year = int(str(y_val).split('-')[0])
                            break 
                        except: pass
                
                if ref_year:
                    valid_indices = []
                    for i, res in enumerate(ranked):
                        t_obj = res['triplet']
                        y_val = None
                        if t_obj.time and t_obj.time.name: y_val = t_obj.time.name
                        elif 'time' in t_obj.relation.prop: y_val = t_obj.relation.prop['time']
                        elif 'time' in t_obj.end_node.prop: y_val = t_obj.end_node.prop['time'] 
                        if y_val:
                            try:
                                cand_y = int(str(y_val).split('-')[0])
                                is_before = 'before' in q.lower()
                                is_after = 'after' in q.lower()
                                if is_before and cand_y < ref_year: valid_indices.append(i)
                                elif is_after and cand_y > ref_year: valid_indices.append(i)
                            except: pass
                    if valid_indices:
                        for idx in valid_indices:
                            t = ranked[idx]['triplet']
                            s = t.start_node.prop.get('wd_id')
                            o = t.end_node.prop.get('wd_id')
                            if s in answers or o in answers:
                                rank = 1 
                                logic_applied = True
                                break

            # 3. first_last
            elif q_type == 'first_last':
                timed_candidates = []
                for i, res in enumerate(ranked):
                    t_obj = res['triplet']
                    y_val = None
                    if t_obj.time and t_obj.time.name: y_val = t_obj.time.name
                    elif 'time' in t_obj.relation.prop: y_val = t_obj.relation.prop['time']
                    elif 'time' in t_obj.end_node.prop: y_val = t_obj.end_node.prop['time']
                    if y_val:
                        try:
                            y_int = int(str(y_val).split('-')[0])
                            timed_candidates.append((i, y_int, res))
                        except: pass
                
                if timed_candidates:
                    timed_candidates.sort(key=lambda x: x[1])
                    target_idx = -1
                    if 'first' in q.lower() or 'initial' in q.lower(): target_idx = 0 
                    elif 'last' in q.lower() or 'most recent' in q.lower(): target_idx = -1 
                    if target_idx != -1:
                        best_cand_tuple = timed_candidates[target_idx] if target_idx < len(timed_candidates) else None
                        if best_cand_tuple:
                            t = best_cand_tuple[2]['triplet']
                            s = t.start_node.prop.get('wd_id')
                            o = t.end_node.prop.get('wd_id')
                            if s in answers or o in answers:
                                rank = 1 
                                logic_applied = True

        if not logic_applied:
            for i, res in enumerate(ranked):
                t = res['triplet']
                s = t.start_node.prop.get('wd_id')
                o = t.end_node.prop.get('wd_id')
                if s in answers or o in answers:
                    rank = i + 1
                    break
        
        results.append({
            'question_type': item['type'],
            'rank': rank,
            'hit_1': 1 if rank == 1 else 0,
            'hit_5': 1 if rank <= 5 else 0,
            'mrr': 1.0/rank if rank <= 10 else 0.0
        })
        
    engine.temporal_scorer = original_scorer
    return pd.DataFrame(results)


In [5]:
# Run Benchmark
SAMPLE_SIZE = 500 
subset = test_data[:SAMPLE_SIZE]

print("Running Baseline + Finetuning...")
df_base = evaluate_sample(subset, mode='baseline')

print("Running TComplex + Baseline + Finetuning (Proposed)...")
df_prop = evaluate_sample(subset, mode='proposed')

print("Running TComplex (Pure)...")
df_pure = evaluate_sample(subset, mode='pure_tcomplex')

print("Done.")

Running Baseline + Finetuning...


Eval baseline:   0%|          | 0/500 [00:00<?, ?it/s]

Running TComplex + Baseline + Finetuning (Proposed)...


Eval proposed:   0%|          | 0/500 [00:00<?, ?it/s]

Running TComplex (Pure)...


Eval pure_tcomplex:   0%|          | 0/500 [00:00<?, ?it/s]

Done.


In [6]:
# Analysis
def compare_metrics(runs):
    metrics = []
    # Overall
    for name, df in runs:
        metrics.append({
            'Config': name,
            'Type': 'Overall',
            'MRR': df['mrr'].mean(),
            'Hits@1': df['hit_1'].mean(),
            'Hits@5': df['hit_5'].mean()
        })
    # Per Type
    types = runs[0][1]['question_type'].unique()
    for t in types:
        for name, df in runs:
            sub = df[df['question_type'] == t]
            metrics.append({
                'Config': name,
                'Type': t,
                'MRR': sub['mrr'].mean(),
                'Hits@1': sub['hit_1'].mean(),
                'Hits@5': sub['hit_5'].mean()
            })
    return pd.DataFrame(metrics).sort_values(['Type', 'Config'])

runs = [
    ('Baseline + Finetuning', df_base),
    ('TComplex + Baseline + Finetuning', df_prop),
    ('TComplex (Pure)', df_pure)
]

final = compare_metrics(runs)
print(final)
final.to_csv("benchmark_results.csv")

                              Config           Type       MRR    Hits@1  \
0              Baseline + Finetuning        Overall  0.427168  0.372000   
2                    TComplex (Pure)        Overall  0.000000  0.000000   
1   TComplex + Baseline + Finetuning        Overall  0.564479  0.500000   
9              Baseline + Finetuning   before_after  0.231810  0.161290   
11                   TComplex (Pure)   before_after  0.000000  0.000000   
10  TComplex + Baseline + Finetuning   before_after  0.623387  0.580645   
6              Baseline + Finetuning     first_last  0.373344  0.309524   
8                    TComplex (Pure)     first_last  0.000000  0.000000   
7   TComplex + Baseline + Finetuning     first_last  0.439756  0.398810   
15             Baseline + Finetuning  simple_entity  0.892109  0.843972   
17                   TComplex (Pure)  simple_entity  0.000000  0.000000   
16  TComplex + Baseline + Finetuning  simple_entity  0.903101  0.858156   
12             Baseline +