
# Rule Search Evaluation (Names CSV)

This notebook evaluates BM25-only, Semantic-only, Fuzzy-only, and Hybrid search for your Kotlin rule retriever.

- **Primary metric:** MRR@5  
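- **Secondary:** Hit@1/3/5, Coverage (share of prompts whose expected rule appears anywhere in the ranked candidate list)  
- **Tuning:** Leave-one-prompt-out cross-validation (LOOCV) over the hybrid weights, searched on a simplex grid with step 0.1  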
- **CSV contains rule names** → we map names → IDs robustly (case/space-insensitive).


In [None]:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Adjust these imports to your own package/module paths.
from notebooks.rule_retriever_improved import RuleRetriever
from your_package.base import SearchConfig, SearchMode   # <-- change 'your_package'
from your_package.db.manager import DatabaseManager      # <-- change 'your_package'
from your_package.embeddings.embeddings_manager import EmbeddingManager  # <-- change 'your_package'

# Default figure size (width, height) applied to every plot in this notebook.
plt.rcParams.update({'figure.figsize': (8, 5)})



## 1) Load data & initialize retriever


In [None]:

# ---- User inputs ----
# The prompt CSV must provide the columns: prompt, expected_rule (RULE NAME).
CSV_PATH = "../data/prompts.csv"
DB_PATH = "../db/rules.db"
TABLE_NAME = "rules"
EMBEDDING_MODEL = "../models/embeddings/UAE-Large-V1"

# Candidates taken per signal when building the candidate pool.
TOPK_POOL = 20

# ---- Load CSV ----
df = pd.read_csv(CSV_PATH)
_required_columns = {'prompt', 'expected_rule'}
assert _required_columns.issubset(df.columns), "CSV must contain prompt and expected_rule columns"

# ---- Initialize retriever (disable filtering/reranking for eval) ----
config = SearchConfig(
    semantic_weight=0.6,
    bm25_weight=0.35,
    fuzzy_weight=0.05,
    min_similarity=0.0,
    enable_reranking=False,
)
dbm = DatabaseManager(db_path=DB_PATH, table_name=TABLE_NAME)
dbm.init_db()
emb_mgr = EmbeddingManager(model_name=EMBEDDING_MODEL)
retriever = RuleRetriever(embedding_manager=emb_mgr, config=config, db_manager=dbm)

print(f"Loaded {len(df)} prompts and {len(retriever.rules)} rules.")


## 2) Map expected **names** → IDs (robust)


In [None]:

# Build name -> id map (case/space-insensitive); support both 'rule_name' and 'name' fields.
def _normalize_rule_name(value):
    """Return the lookup key for a rule name.

    The key is lower-cased and every run of whitespace (leading, trailing or
    internal) is collapsed to a single space, so 'My  Rule ' and 'my rule'
    match. Missing values (None / NaN) yield '' and therefore never match.
    """
    if value is None or (isinstance(value, float) and value != value):  # NaN != NaN
        return ''
    return ' '.join(str(value).split()).lower()


name_to_id = {}
_ambiguous_names = set()
for r in retriever.rules:
    nm = _normalize_rule_name(r.get('rule_name') or r.get('name'))
    if not nm:
        continue
    rule_id = str(r['rule_id']).strip()
    # Two different rules sharing one normalized name make the expected id ambiguous.
    if name_to_id.get(nm, rule_id) != rule_id:
        _ambiguous_names.add(nm)
    name_to_id[nm] = rule_id  # the last rule seen wins
if _ambiguous_names:
    print(
        f"Warning: {len(_ambiguous_names)} rule name(s) map to more than one rule_id:",
        sorted(_ambiguous_names)[:5],
    )


def expected_to_id_from_name(x):
    """Map an expected rule name from the CSV to its rule_id string, or None if unknown."""
    return name_to_id.get(_normalize_rule_name(x), None)


df['expected_id'] = df['expected_rule'].apply(expected_to_id_from_name)
n_missing = int(df['expected_id'].isna().sum())
print("Missing expected_id mapping:", n_missing)
if n_missing:
    # Show a few unmapped rows so naming problems in the CSV are easy to spot.
    display(df[df['expected_id'].isna()].head())



## 3) Candidate pools


In [None]:

def candidate_pool_ids(prompt, k_each=TOPK_POOL):
    """Return the rule_ids of the candidate pool the retriever builds for `prompt`.

    `k_each` is forwarded unchanged to `retriever.candidate_pool`.
    """
    pool_ids = []
    for rule in retriever.candidate_pool(prompt, k_each=k_each):
        pool_ids.append(rule['rule_id'])
    return pool_ids

# Candidate pool per prompt, keyed by the DataFrame index label of the prompt row.
pools = {
    idx: candidate_pool_ids(prompt)
    for idx, prompt in df['prompt'].items()
}

def expected_in_pool_stats():
    """Print how many prompts have their expected rule inside the candidate pool.

    Reads the notebook-level `df` (column 'expected_id') and `pools`.
    Prompts without a mapped expected id are recorded as 'MISSING_MAPPING';
    mapped prompts whose expected id is not in their pool as 'NOT_IN_POOL'.
    Only the first five issues are printed. Returns None.
    """
    present = 0
    issues = []
    for i, row in df.iterrows():
        exp = row['expected_id']
        if not isinstance(exp, str):
            issues.append((i, 'MISSING_MAPPING'))
            continue
        # expected_id is a stripped string, whereas pool ids keep whatever type
        # the retriever stores (possibly int), so compare on the same string form.
        pool_keys = {str(rid).strip() for rid in pools[i]}
        if exp.strip() in pool_keys:
            present += 1
        else:
            issues.append((i, 'NOT_IN_POOL'))
    print(f"Expected present in pool for {present}/{len(df)} prompts.")
    if issues[:5]:
        print("Examples:", issues[:5])

expected_in_pool_stats()



## 4) Metrics


In [None]:

def rank_of_expected(expected_id, ranked_ids):
    """Return the 1-based position of `expected_id` in `ranked_ids`, or None if absent.

    Ids are compared by their stripped string form, so an expected id held as
    the string "12" matches a ranked id stored as the integer 12. A None
    `expected_id` never matches.
    """
    if expected_id is None:
        return None
    target = str(expected_id).strip()
    for position, rid in enumerate(ranked_ids, start=1):
        if str(rid).strip() == target:
            return position
    return None

def mrr_at_k(ranks, k=5):
    """Mean reciprocal rank with cutoff `k`; a None rank or a rank beyond `k` scores 0."""
    reciprocal = []
    for r in ranks:
        found_in_top_k = r is not None and r <= k
        reciprocal.append(1.0 / r if found_in_top_k else 0.0)
    return float(np.mean(reciprocal))
def hit_at_k(ranks, k=1):
    """Fraction of entries in `ranks` that are present (not None) and at most `k`."""
    hits = [float(r is not None and r <= k) for r in ranks]
    return float(np.mean(hits))
def coverage(ranks):
    """Fraction of entries in `ranks` that are not None."""
    flags = [0.0 if r is None else 1.0 for r in ranks]
    return float(np.mean(flags))



## 5) Ranking helpers


In [None]:

def normalize_per_prompt(score_dict):
    """Scale the scores of one prompt by their maximum.

    An empty (falsy) input is returned unchanged. If the maximum is zero or
    negative every key is mapped to 0.0; otherwise each value is divided by
    the maximum, so the top-scoring key gets 1.0.
    """
    if not score_dict:
        return score_dict
    top = max(score_dict.values())
    if top <= 0:
        return dict.fromkeys(score_dict, 0.0)
    scaled = {}
    for key, value in score_dict.items():
        scaled[key] = value / top
    return scaled

def ranked_ids_by_method(prompt, pool_ids, method, weights=None):
    """Return `pool_ids` ordered best-first by the score of the chosen `method`.

    method:
        'bm25', 'semantic', 'fuzzy', 'hybrid' -- scores come from the matching
        `retriever._*_scores(prompt, rules)` call.
        'hybrid_custom' -- semantic, BM25 and fuzzy scores are each passed
        through `normalize_per_prompt` and combined with `weights`, given as
        (semantic, bm25, fuzzy) and rescaled to sum to 1 when their sum is positive.
    Raises ValueError for any other `method`. Ids without a score count as 0.0.
    """
    rules_by_id = {rule['rule_id']: rule for rule in retriever.rules}
    rules = [rules_by_id[rid] for rid in pool_ids if rid in rules_by_id]

    # Methods that map one-to-one onto a scoring method of the retriever.
    scorer_names = {
        'bm25': '_bm25_scores',
        'semantic': '_semantic_scores',
        'fuzzy': '_fuzzy_scores',
        'hybrid': '_hybrid_scores',
    }

    if method in scorer_names:
        scores = getattr(retriever, scorer_names[method])(prompt, rules)
    elif method == 'hybrid_custom':
        sem = normalize_per_prompt(retriever._semantic_scores(prompt, rules))
        bm = normalize_per_prompt(retriever._bm25_scores(prompt, rules))
        fz = normalize_per_prompt(retriever._fuzzy_scores(prompt, rules))
        w_sem, w_bm, w_fz = weights
        total = w_sem + w_bm + w_fz
        if total <= 0:
            total = 1.0  # leave non-positive weight sums unscaled
        w_sem, w_bm, w_fz = w_sem / total, w_bm / total, w_fz / total
        scores = {}
        for rid in pool_ids:
            scores[rid] = w_sem * sem.get(rid, 0.0) + w_bm * bm.get(rid, 0.0) + w_fz * fz.get(rid, 0.0)
    else:
        raise ValueError("Unknown method")

    def score_of(rid):
        return scores.get(rid, 0.0)

    return sorted(pool_ids, key=score_of, reverse=True)



## 6) Baselines (no tuning)


In [None]:

def evaluate_method(method):
    """Evaluate one ranking `method` over every prompt that has a mapped expected id.

    Prompts whose 'expected_id' is not a string (no name -> id mapping) are left
    out of the averages instead of being counted as misses. This is the same
    population that the weight-tuning code iterates via `indices_all`, so the
    baseline and tuned metrics are directly comparable.

    Returns a dict with MRR@5, Hit@1, Hit@3, Hit@5 and Coverage.
    """
    ranks = []
    for i, row in df.iterrows():
        exp = row['expected_id']
        if not isinstance(exp, str):  # skip missing mapping
            continue
        ranked = ranked_ids_by_method(row['prompt'], pools[i], method)
        ranks.append(rank_of_expected(exp, ranked))
    return {
        "MRR@5": mrr_at_k(ranks, k=5),
        "Hit@1": hit_at_k(ranks, k=1),
        "Hit@3": hit_at_k(ranks, k=3),
        "Hit@5": hit_at_k(ranks, k=5),
        "Coverage": coverage(ranks),
    }

# Baselines: each scorer is evaluated as-is, without any weight tuning.
baseline_results = {
    label: evaluate_method(method)
    for label, method in [
        ("BM25", "bm25"),
        ("Semantic", "semantic"),
        ("Fuzzy", "fuzzy"),
        ("Hybrid (prod)", "hybrid"),
    ]
}
# `pd` comes from the notebook's import cell; no re-import is needed here.
pd.DataFrame(baseline_results).T



## 7) LOOCV tuning for hybrid weights (simplex step=0.1)


In [None]:

# Build simplex grid
# Every weight is a multiple of 1/GRID_STEPS (step = 0.1) and each triple
# (semantic, bm25, fuzzy) sums to 1. Integer counters are used instead of a float
# np.arange so the grid never depends on rounding and contains no -0.0 entries.
GRID_STEPS = 10
grid = [
    (a / GRID_STEPS, b / GRID_STEPS, (GRID_STEPS - a - b) / GRID_STEPS)
    for a in range(GRID_STEPS + 1)
    for b in range(GRID_STEPS + 1 - a)
]

# Row positions of the prompts that have a mapped expected id.
# NOTE: positions are also used as keys into `pools`, which is keyed by index
# label, so this relies on `df` having the default 0..n-1 index.
indices_all = [i for i in range(len(df)) if isinstance(df.iloc[i]['expected_id'], str)]

# (row position, weights) -> rank of the expected rule. Each pair is scored by
# the retriever once instead of once per LOOCV fold; the ranks are unchanged.
_rank_cache = {}


def _rank_with_weights(i, weights):
    """Rank of row `i`'s expected rule under 'hybrid_custom' with `weights` (memoised)."""
    key = (i, tuple(weights))
    if key not in _rank_cache:
        row = df.iloc[i]
        ranked = ranked_ids_by_method(row['prompt'], pools[i], 'hybrid_custom', weights=weights)
        _rank_cache[key] = rank_of_expected(row['expected_id'], ranked)
    return _rank_cache[key]


def eval_weights_on(indices, weights):
    """Return MRR@5 of 'hybrid_custom' with `weights` over the rows at positions `indices`."""
    return mrr_at_k([_rank_with_weights(i, weights) for i in indices], k=5)


# With a single mapped prompt the training fold is empty and no weights can be chosen.
if len(indices_all) == 1:
    raise ValueError("LOOCV needs at least two prompts with a mapped expected_id")

best_weights_per_prompt, ranks_test = [], []
for held_out in indices_all:
    train_idx = [i for i in indices_all if i != held_out]
    best_w, best_score = None, -1.0
    for w in grid:
        score = eval_weights_on(train_idx, w)
        if score > best_score:  # strict '>' keeps the first grid point on ties
            best_score, best_w = score, w
    best_weights_per_prompt.append(best_w)
    # Score the held-out prompt with the weights tuned on all the other prompts.
    ranks_test.append(_rank_with_weights(held_out, best_w))

# Held-out metrics computed from `ranks_test`, the ranks collected by the LOOCV loop.
tuned = {
    "MRR@5": mrr_at_k(ranks_test, k=5),
    "Hit@1": hit_at_k(ranks_test, k=1),
    "Hit@3": hit_at_k(ranks_test, k=3),
    "Hit@5": hit_at_k(ranks_test, k=5),
    "Coverage": coverage(ranks_test),
}
# `pd` comes from the notebook's import cell; no re-import is needed here.
pd.DataFrame([tuned], index=["Hybrid (tuned)"])



## 8) Compare & plot


In [None]:

# One table holding the baselines and the tuned hybrid, plus an MRR@5 bar chart.
baseline_table = pd.DataFrame(baseline_results).T
tuned_table = pd.DataFrame([tuned], index=["Hybrid (tuned)"])
all_results = pd.concat([baseline_table, tuned_table])
display(all_results)

fig, ax = plt.subplots()
ax.bar(all_results.index, all_results["MRR@5"])
ax.set_title("MRR@5 by Method")
ax.set_ylabel("MRR@5")
# Slant the method names so longer labels do not overlap.
plt.setp(ax.get_xticklabels(), rotation=30, ha='right')
fig.tight_layout()
plt.show()



## 9) Ablation & sensitivity


In [None]:

# Summarise the per-fold winning weights by their component-wise median,
# then rescale so the three weights sum to 1.
w_arr = np.array(best_weights_per_prompt)
if len(w_arr):
    w_med = tuple(np.median(w_arr, axis=0))
else:
    w_med = (0.6, 0.35, 0.05)  # fallback when no fold produced weights
s = sum(w_med) or 1.0
w_sem = w_med[0] / s
w_bm = w_med[1] / s
w_fz = w_med[2] / s
print("Median tuned weights:", (w_sem, w_bm, w_fz))

def evaluate_custom_weights(weights):
    """Evaluate 'hybrid_custom' with fixed `weights` on every row position in `indices_all`.

    Returns a dict with MRR@5, Hit@1, Hit@3, Hit@5 and Coverage.
    """
    def rank_for(position):
        row = df.iloc[position]
        ordering = ranked_ids_by_method(
            row['prompt'], pools[position], 'hybrid_custom', weights=weights
        )
        return rank_of_expected(row['expected_id'], ordering)

    ranks = [rank_for(position) for position in indices_all]
    metrics = {"MRR@5": mrr_at_k(ranks, k=5)}
    metrics["Hit@1"] = hit_at_k(ranks, k=1)
    metrics["Hit@3"] = hit_at_k(ranks, k=3)
    metrics["Hit@5"] = hit_at_k(ranks, k=5)
    metrics["Coverage"] = coverage(ranks)
    return metrics

# Ablations
# Zero out one signal at a time, rescale the remaining weights to sum to 1, and re-evaluate.
abl = {}
for label, ws in [
    ("No Semantic", (0.0, w_bm, w_fz)),
    ("No BM25", (w_sem, 0.0, w_fz)),
    ("No Fuzzy", (w_sem, w_bm, 0.0)),
]:
    s = sum(ws) or 1.0
    ws = tuple(w / s for w in ws)
    abl[label] = evaluate_custom_weights(ws)
display(pd.DataFrame(abl).T)

# Sensitivity
def clip01(x):
    """Convert `x` to float and clamp it to the closed interval [0, 1]."""
    capped = min(1.0, float(x))
    return max(0.0, capped)
def sens(center_w, deltas=(-0.2, -0.1, 0.0, 0.1, 0.2)):
    """One-at-a-time sensitivity of MRR@5 around the weights `center_w`.

    For every delta and every weight in turn (semantic, BM25, fuzzy), that
    weight is shifted by the delta and clamped to [0, 1] while the other two
    are kept; the triple is rescaled to sum to 1 and evaluated with
    `evaluate_custom_weights`.

    Args:
        center_w: (semantic, bm25, fuzzy) weights to perturb.
        deltas: shifts to apply; an immutable tuple so the default cannot be
            altered between calls.

    Returns:
        DataFrame with columns "weight" ("sem" / "bm25" / "fuzzy"), "delta"
        and "MRR@5", with rows ordered by delta and then by weight.
    """
    w_sem, w_bm, w_fz = center_w
    base = (w_sem, w_bm, w_fz)
    rows = []
    for d in deltas:
        for pos, label in enumerate(("sem", "bm25", "fuzzy")):
            ws = list(base)
            ws[pos] = clip01(ws[pos] + d)
            s = sum(ws) or 1.0  # avoid dividing by zero when every weight is 0
            ws = tuple(w / s for w in ws)
            rows.append((label, d, evaluate_custom_weights(ws)["MRR@5"]))
    return pd.DataFrame(rows, columns=["weight", "delta", "MRR@5"])

sens_df = sens((w_sem, w_bm, w_fz))
display(sens_df)

# One line chart per perturbed weight: MRR@5 as a function of the applied delta.
for wname in ["sem", "bm25", "fuzzy"]:
    sub = sens_df[sens_df["weight"] == wname]
    fig, ax = plt.subplots()
    ax.plot(sub["delta"], sub["MRR@5"], marker='o')
    ax.set_title(f"Sensitivity: {wname} vs MRR@5")
    ax.set_xlabel("Delta")
    ax.set_ylabel("MRR@5")
    ax.grid(True)
    fig.tight_layout()
    plt.show()
