In [2]:
#setup

%load_ext autoreload
%autoreload 2

import random
import json
from pathlib import Path

import pandas as pd
import matplotlib.pyplot as plt


import sys


# Make the project's `src/` directory importable.
# Assumes the kernel's working directory is `<project>/notebooks`, so the
# parent of the cwd is the project root (notebooks → jasp-multimodal-rag).
project_root = Path.cwd().parent
src_path = project_root / "src"
# Guard the append so that re-running this cell does not pile up duplicate
# entries in sys.path.
if str(src_path) not in sys.path:
    sys.path.append(str(src_path))

import retrieval.retrieval as retr  # adjust if needed


[32m2025-11-21 16:50:35.177[0m | [1mINFO    [0m | [36mretrieval.retrieval[0m:[36m<module>[0m:[36m56[0m - [1m✅ Loaded FlagEmbeddingReranker from llama_index.postprocessor.flag_embedding_reranker[0m


In [3]:
import pandas as pd

def evaluate_config(cfg: dict, test_set: list):
    """Run retrieval for every test query under ``cfg`` and compute H1a, H1b, H2.

    Parameters
    ----------
    cfg : dict
        Hyper-parameter configuration. It is passed to ``apply_config`` and its
        ``"TOP_FINAL"`` entry is used as ``top_k`` for every retrieval call.
    test_set : list
        Items (dicts) providing the keys ``"id"``, ``"query"``,
        ``"answerable"`` and ``"relevant_chunk_ids"``.

    Returns
    -------
    tuple
        ``(H1a, H1b, H2)``:

        * ``H1a`` – share of answerable queries with at least one relevant id
          among the retrieved ids.
        * ``H1b`` – share of answerable queries whose first retrieved id is
          relevant.
        * ``H2``  – share of unanswerable queries for which nothing was
          retrieved.

        A metric is NaN when the subset it is computed on is empty; this
        includes the case of an empty ``test_set``.
    """
    # `apply_config` and `retr` are notebook-level names defined in other cells.
    apply_config(cfg)
    top_k = cfg["TOP_FINAL"]

    records = []
    for item in test_set:
        query = item["query"]
        relevant_ids = set(item["relevant_chunk_ids"])

        hits = retr.retrieve_top_k(query, top_k=top_k)

        retrieved_ids = []
        for hit in hits:
            # Treat a missing or falsy `metadata` attribute as an empty mapping.
            meta = getattr(hit, "metadata", {}) or {}
            # Prefer doc_id, fall back to section_id (None if neither is set).
            retrieved_ids.append(meta.get("doc_id") or meta.get("section_id"))

        records.append({
            "id": item["id"],
            "query": query,
            "answerable": item["answerable"],
            "retrieved_ids": retrieved_ids,
            "success_at_k": any(rid in relevant_ids for rid in retrieved_ids[:top_k]),
            "top1_relevant": retrieved_ids[0] in relevant_ids if retrieved_ids else False,
        })

    # Passing the columns explicitly keeps the frame well-formed even when there
    # are no records; without them an empty test set yields a column-less frame
    # and the "answerable" lookup below raises KeyError.
    columns = ["id", "query", "answerable", "retrieved_ids", "success_at_k", "top1_relevant"]
    df = pd.DataFrame(records, columns=columns)
    df_answerable = df[df["answerable"] == True]
    df_unanswerable = df[df["answerable"] == False]

    if df_answerable.empty:
        H1a = float("nan")
        H1b = float("nan")
    else:
        H1a = df_answerable["success_at_k"].mean()
        H1b = df_answerable["top1_relevant"].mean()

    if df_unanswerable.empty:
        H2 = float("nan")
    else:
        # `.str.len()` also works on list-valued cells: an empty list has length 0.
        H2 = (df_unanswerable["retrieved_ids"].str.len() == 0).mean()

    return H1a, H1b, H2


In [4]:
#Step 1 – Load test_set
# Resolve the test set relative to the project root (defined in the setup cell)
# instead of a machine-specific absolute path, so the notebook also runs on
# other checkouts of the repository.
TEST_JSON = project_root / "data" / "test_QA" / "QA_filled_1.json"

# Read explicitly as UTF-8 so the result does not depend on the platform's
# default encoding.
with open(TEST_JSON, "r", encoding="utf-8") as f:
    test_set = json.load(f)


# Show the keys of the first item and the number of test queries.
print(list(test_set[0]))
len(test_set)

['id', 'query', 'answerable', 'ground_truth_answer', 'relevant_chunk_ids']


12

In [5]:
#Step 2 – Helper to apply a config to the retrieval module
def apply_config(cfg: dict):
    """Copy the hyper-parameters in ``cfg`` onto attributes of the ``retr`` module.

    ``cfg`` must provide the keys ``K_BM25``, ``boost_weight``, ``K_SEMANTIC``,
    ``RRF_K``, ``TOP_AFTER_RRF``, ``score_threshold`` and ``TOP_FINAL``.
    Returns ``None``.

    NOTE(review): this only has an effect if the retrieval code reads these
    module attributes at call time — confirm against ``retrieval.retrieval``.
    """
    # (attribute on retr, key in cfg). Assignments happen in this order, so a
    # missing key raises KeyError after the earlier attributes were already set.
    bindings = (
        ("K_BM25", "K_BM25"),
        ("BOOST_WEIGHT", "boost_weight"),
        ("K_SEMANTIC", "K_SEMANTIC"),
        ("RRF_K", "RRF_K"),
        ("TOP_AFTER_RRF", "TOP_AFTER_RRF"),
        ("SCORE_THRESHOLD", "score_threshold"),
        ("TOP_FINAL", "TOP_FINAL"),
    )
    for attr_name, cfg_key in bindings:
        setattr(retr, attr_name, cfg[cfg_key])

In [6]:
import json
import os
import csv
import hashlib
import traceback

import pandas as pd
from itertools import product
from tqdm import tqdm


def cfg_hash(cfg: dict) -> str:
    """Return a deterministic MD5 hex digest that identifies ``cfg``.

    The configuration is serialised to JSON with sorted keys, so two dicts with
    the same items hash identically regardless of insertion order. The digest
    is used only as an identifier, not for security.
    """
    canonical = json.dumps(cfg, sort_keys=True)
    digest = hashlib.md5(canonical.encode())
    return digest.hexdigest()


def load_completed_hashes(save_path: str) -> set:
    """Return the set of config hashes already recorded in the results CSV.

    Parameters
    ----------
    save_path : str
        Path of the results CSV written by the grid search.

    Returns
    -------
    set
        The string values of the ``config_hash`` column. The set is empty when
        the file does not exist, cannot be parsed, or has no ``config_hash``
        column (the latter two cases print a warning).
    """
    if not os.path.exists(save_path):
        return set()

    try:
        # Read the hash column as text. With type inference, a hash that happens
        # to look numeric (e.g. all digits with leading zeros) would be parsed
        # as a number and never match the freshly computed hash again.
        df = pd.read_csv(save_path, dtype={"config_hash": str})
    except Exception as e:
        # Best effort: an unreadable file is treated as "nothing completed yet".
        print(f"⚠️ Warning: Could not read existing results file ({e}). Restarting fresh.")
        return set()

    if "config_hash" not in df.columns:
        print("⚠️ Existing file has no 'config_hash' column, treating as empty.")
        return set()

    # Rows without a hash cannot identify a completed config; skip them.
    return set(df["config_hash"].dropna().astype(str))


def append_result_row(save_path: str, row: dict, header_written: bool):
    """Append one result row to the CSV at ``save_path`` and return ``True``.

    The file is opened, written and closed on every call, so rows already
    appended survive an interrupted run.

    Parameters
    ----------
    save_path : str
        CSV file to append to; created if it does not exist.
    row : dict
        Column name → value. The keys define the column order of this row
        (and of the header, when one is written).
    header_written : bool
        Kept for backward compatibility. The header decision is made from the
        file's actual state instead: a header is written exactly when the file
        is missing or empty. This avoids both a header-less file (flag ``True``
        but file empty) and a duplicated header in the middle of the file
        (flag ``False`` but file already populated).
    """
    file_has_content = os.path.exists(save_path) and os.path.getsize(save_path) > 0
    write_header = not file_has_content

    # newline="" lets the csv module control line endings itself.
    with open(save_path, "a", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=list(row.keys()))
        if write_header:
            writer.writeheader()
        writer.writerow(row)

    return True


def build_all_configs(param_space: dict):
    """Expand ``param_space`` into every combination of its candidate values.

    Returns a list of ``(cfg, config_hash)`` tuples, one per element of the
    Cartesian product of the value lists, in ``itertools.product`` order.
    Each ``cfg`` maps parameter name → chosen value and does not itself contain
    the hash, so it can be passed straight to ``evaluate_config``.
    """
    names = list(param_space.keys())
    value_lists = [param_space[name] for name in names]

    pairs = []
    for combo in product(*value_lists):
        cfg = dict(zip(names, combo))
        pairs.append((cfg, cfg_hash(cfg)))
    return pairs


def run_full_grid_search_resumable(param_space, test_set, save_path="grid_search_results.csv"):
    """Run a grid search that can be interrupted and resumed.

    Every successfully evaluated configuration is appended to ``save_path``
    immediately. On a later call, configurations whose hash is already present
    in the file's ``config_hash`` column are skipped.

    Parameters
    ----------
    param_space : dict
        Parameter name → list of candidate values (see ``build_all_configs``).
    test_set : list
        Test items forwarded unchanged to ``evaluate_config``.
    save_path : str
        CSV file that accumulates one row per configuration: the
        hyper-parameters, ``config_hash`` and the metrics H1a, H1b, H2.

    Returns
    -------
    pandas.DataFrame
        The contents of ``save_path``; an empty frame when no result has been
        written (e.g. every configuration failed).

    Notes
    -----
    A configuration that raises is reported and skipped. Its hash is not
    recorded, so it is attempted again on the next run.
    NOTE(review): rows are appended with the current ``param_space`` keys; if
    ``save_path`` was produced with different keys the columns will not line
    up — use a fresh file in that case.
    """
    # 1) Hashes of configurations finished in earlier runs.
    completed = load_completed_hashes(save_path)
    print(f"Found {len(completed)} previously completed configurations.")

    # 2) Full list of (cfg, hash) pairs.
    configs = build_all_configs(param_space)
    print(f"Total combinations: {len(configs)}")

    # 3) Evaluate what is still missing. An existing but empty file still
    #    needs its header row, so check the size and not only the existence.
    header_written = os.path.exists(save_path) and os.path.getsize(save_path) > 0
    failed = 0

    for cfg, h in tqdm(configs, desc="Grid Search"):
        if h in completed:
            continue

        try:
            # cfg does NOT contain config_hash.
            H1a, H1b, H2 = evaluate_config(cfg, test_set)

            row = {
                **cfg,
                "config_hash": h,
                "H1a": H1a,
                "H1b": H1b,
                "H2": H2,
            }

            # Persist right away so an interruption loses at most this config.
            append_result_row(save_path, row, header_written)
            header_written = True
            completed.add(h)

        except Exception as e:
            failed += 1
            print(f"❌ Error evaluating config {cfg} (hash={h}): {e}")
            traceback.print_exc()
            # Continue with the next config even if this one fails.

    if failed:
        print(f"⚠️ {failed} configuration(s) failed and were not saved; re-run to retry them.")

    # Reading a missing or zero-byte file would raise; report "no results".
    if not os.path.exists(save_path) or os.path.getsize(save_path) == 0:
        print(f"⚠️ No results were written to {save_path}.")
        return pd.DataFrame()

    print(f"✅ Finished. Results saved to {save_path}")
    return pd.read_csv(save_path)


In [7]:
#Step 5 – Parameter space for tuning
# Search space: parameter name → list of candidate values.
# The keys must match the ones `apply_config` reads from a config dict.
# Single-element lists are held fixed; only `score_threshold` and
# `boost_weight` are swept here (9 × 5 = 45 combinations).
# Keep each literal's int/float form as is: `cfg_hash` serialises values with
# json.dumps, so e.g. 0 and 0.0 would produce different config hashes.
param_space = {
    "K_BM25": [10],
    "K_SEMANTIC": [10],
    "RRF_K": [120],
    "TOP_AFTER_RRF": [10],
    "TOP_FINAL": [5],
    "score_threshold": [-3, -2, -1, -0.5, 0, 0.5, 1, 1.5, 2.0],
    "boost_weight": [0.0, 1.0, 3.0, 4.0, 5.0],
}


from itertools import product

# Echo the search space so the cell output documents what is being swept.
print("Current param_space:")
for name, candidates in param_space.items():
    print(f"  {name}: {candidates}")

# Expand the space into one dict per combination (Cartesian product of the
# candidate lists), purely as a sanity check of the grid size.
param_items = list(param_space.items())
param_names = [name for name, _ in param_items]
configs = [
    dict(zip(param_names, combo))
    for combo in product(*(candidates for _, candidates in param_items))
]

print("Total combinations:", len(configs))
print("First (and only?) config:", configs[0])

Current param_space:
  K_BM25: [10]
  K_SEMANTIC: [10]
  RRF_K: [120]
  TOP_AFTER_RRF: [10]
  TOP_FINAL: [5]
  score_threshold: [-3, -2, -1, -0.5, 0, 0.5, 1, 1.5, 2.0]
  boost_weight: [0.0, 1.0, 3.0, 4.0, 5.0]
Total combinations: 45
First (and only?) config: {'K_BM25': 10, 'K_SEMANTIC': 10, 'RRF_K': 120, 'TOP_AFTER_RRF': 10, 'TOP_FINAL': 5, 'score_threshold': -3, 'boost_weight': 0.0}


In [None]:
# Launch (or resume) the grid search. Results accumulate in the CSV below, and
# configurations already recorded there are skipped on a re-run.
df_results = run_full_grid_search_resumable(
    param_space,
    test_set,
    save_path="grid_search_results.csv",
)
