# Reformulate Method Evaluation

This notebook evaluates the Reformulate query-expansion method, using BM25 and TF-IDF retrieval, on the TREC-COVID and Climate-FEVER datasets and compares each run against the corresponding baseline run.

## How to Run

1. Press "Run All" to execute all cells
2. Any missing ingest outputs, baseline runs, and Reformulate runs are created automatically
3. Metrics and plots will be saved to `data/eval/reformulate/`


In [None]:
# Setup: imports and path configuration
import sys
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# --- Locations ---------------------------------------------------------------
# The notebook lives at notebook/eval/reformulate.ipynb, so the project root is
# taken to be two levels above the current working directory.
# NOTE(review): this assumes the kernel's working directory is the notebook's
# own folder -- confirm when running from elsewhere.
project_root = Path.cwd().parents[1]
data_base = project_root.joinpath("data")

# Put <project_root>/src at the front of the module search path before the
# project imports below are executed.
src_dir = project_root.joinpath("src")
sys.path.insert(0, str(src_dir))

# --- Experiment grid ---------------------------------------------------------
method_name = "reformulate"
datasets = ["trec_covid", "climate_fever"]
retrieval_methods = ["bm25", "tfidf"]

# Forwarded as `max_queries` when the method runs are ensured.
# Set to an int (e.g. 50) to limit Groq calls.
auto_run_max_queries = None

from eval import (
    compute_metrics_from_files,
    save_metrics_to_csv,
    load_run_file,
    load_qrels_file,
    compute_per_query_metric,
    compare_runs,
)
from eval.utils import (
    ensure_directory,
    find_top_delta_queries,
    create_summary_table,
)
from notebook.run_api import ensure_baseline_runs, ensure_method_runs

from llm_qe.expander import GroqQueryExpander, ExpansionStrategy

print("Setup complete! Project root:", project_root)

## Ensure ingest outputs and method runs

Running the next cell will:
- Ensure ingested artifacts exist under `data/ingest/{dataset}`
- Ensure baseline runs exist under `data/retrieval/baseline/`
- Expand the real dataset queries using Groq
- Run retrieval (BM25 + TF-IDF) for the expanded queries and save to `data/retrieval/reformulate/`

In [None]:
# Make sure the output folders for this method's run files and evaluation
# results exist before anything is written to them.
for subdir in ("retrieval", "eval"):
    ensure_directory(data_base / subdir / method_name)

# Arguments shared by the baseline helper and the method helper.
run_grid = dict(datasets=datasets, retrieval_methods=retrieval_methods, top_k=100)

baseline_runs = ensure_baseline_runs(**run_grid)
print("Baseline runs ensured:\n", json.dumps(baseline_runs, indent=2, default=str))

# The same REFORMULATE strategy is given both to the expander and to the
# run helper.
expander = GroqQueryExpander(strategy=ExpansionStrategy.REFORMULATE)
method_runs = ensure_method_runs(
    method_name=method_name,
    strategy=ExpansionStrategy.REFORMULATE,
    expander=expander,
    max_queries=auto_run_max_queries,
    **run_grid,
)
print(f"{method_name} runs ensured:\n", json.dumps(method_runs, indent=2, default=str))

## Compute Metrics for All 4 Dataset × Retrieval Combinations


In [None]:
# Evaluate every reformulate run file at cutoff k=10, write one metrics CSV per
# (dataset, retrieval) pair, and keep the results in `all_metrics` (keyed by
# (dataset, method, retrieval)) for the cells that follow.
all_metrics = {}
runs_dir = data_base / "retrieval" / method_name
eval_dir = data_base / "eval" / method_name

for dataset in datasets:
    # One qrels file per dataset, shared by both retrieval methods.
    qrels_path = data_base / "ingest" / dataset / "qrels.csv"

    for retrieval in retrieval_methods:
        # Run file and metrics file share the same "<dataset>_<retrieval>.csv" name.
        csv_name = f"{dataset}_{retrieval}.csv"

        metrics = compute_metrics_from_files(str(runs_dir / csv_name), str(qrels_path), k=10)
        save_metrics_to_csv(
            metrics,
            str(eval_dir / csv_name),
            dataset=dataset,
            method=method_name,
            retrieval=retrieval,
        )
        all_metrics[(dataset, method_name, retrieval)] = metrics

        print(f"{retrieval} × {dataset}: nDCG@10={metrics['ndcg@10']:.4f}, MAP={metrics['map']:.4f}")

print("\nMetrics computation complete!")

## Compare with Baseline


In [None]:
# Compare each reformulate run with its baseline run (statistical test on
# nDCG@10) and collect one result row per (dataset, retrieval) combination.
comparison_results = []
baseline_dir = data_base / "retrieval" / "baseline"
method_dir = data_base / "retrieval" / method_name

for dataset in datasets:
    qrels_path = data_base / "ingest" / dataset / "qrels.csv"

    for retrieval in retrieval_methods:
        run_file = f"{dataset}_{retrieval}.csv"

        # Baseline run is passed first, the reformulate run second.
        stats = compare_runs(
            str(baseline_dir / run_file),
            str(method_dir / run_file),
            str(qrels_path),
            metric="ndcg@10",
            k=10,
        )

        row = dict(
            dataset=dataset,
            retrieval=retrieval,
            baseline_mean=stats["baseline_mean"],
            reformulate_mean=stats["system_mean"],
            difference=stats["mean_difference"],
            p_value=stats["p_value"],
            # The confidence bounds are read with .get(), so a missing bound
            # is stored as None instead of raising.
            ci_lower=stats.get("ci_lower"),
            ci_upper=stats.get("ci_upper"),
        )
        comparison_results.append(row)

        print(f"{retrieval} × {dataset}: Δ={stats['mean_difference']:.4f}, p={stats['p_value']:.4f}")

comparison_df = pd.DataFrame(comparison_results)
print("\nComparison Summary:")
print(comparison_df.to_string(index=False))

# Persist the raw comparison rows next to the other evaluation outputs.
pvals_path = data_base / "eval" / method_name / "pvals.json"
ensure_directory(pvals_path.parent)
payload = {"method": method_name, "comparisons": comparison_results}
with pvals_path.open("w", encoding="utf-8") as f:
    json.dump(payload, f, indent=2)
print(f"\nP-values and CIs saved to: {pvals_path}")

## Summary and Top Delta Queries


In [None]:
# Build the summary table from the metrics collected earlier in `all_metrics`
# and save it next to the other evaluation outputs.
summary_df = create_summary_table(all_metrics)
summary_path = data_base / "eval" / method_name / "summary.csv"
summary_df.to_csv(summary_path, index=False)

print("Summary Metrics Table:")
print(summary_df.to_string(index=False))
print(f"\nSaved to: {summary_path}")

# Show the queries with the largest per-query nDCG@10 change for one example
# combination: the first dataset with the first retrieval method.
dataset_example = datasets[0]
retrieval_example = retrieval_methods[0]
example_file = f"{dataset_example}_{retrieval_example}.csv"

qrels = load_qrels_file(str(data_base / "ingest" / dataset_example / "qrels.csv"))
baseline_run = load_run_file(str(data_base / "retrieval" / "baseline" / example_file))
method_run = load_run_file(str(data_base / "retrieval" / method_name / example_file))

baseline_scores = compute_per_query_metric(baseline_run, qrels, metric="ndcg@10", k=10)
method_scores = compute_per_query_metric(method_run, qrels, metric="ndcg@10", k=10)

top_positive, top_negative = find_top_delta_queries(baseline_scores, method_scores, top_n=10)

# Deltas are printed with the "+" sign flag so the sign always comes from the
# value itself. A hard-coded "+" prefix would render a non-positive entry in
# the positive list as "+-0.0123".
print(f"\nTop 5 Positive Δ (nDCG@10) - {retrieval_example} × {dataset_example}:")
for qid, delta in top_positive[:5]:
    print(f"  {qid}: {delta:+.4f}")

print(f"\nTop 5 Negative Δ (nDCG@10) - {retrieval_example} × {dataset_example}:")
for qid, delta in top_negative[:5]:
    print(f"  {qid}: {delta:+.4f}")

## Plot nDCG@10 Comparison


In [None]:
# Grouped bar chart of nDCG@10: baseline vs reformulate, one panel per
# retrieval method and one pair of bars per dataset.
from eval.compute_metrics import compute_metrics_from_files as load_metrics

# Baseline metrics are computed here from the baseline run files; the
# reformulate metrics are read from `all_metrics`, filled by an earlier cell.
baseline_metrics = {
    (dataset, retrieval): load_metrics(
        str(data_base / "retrieval" / "baseline" / f"{dataset}_{retrieval}.csv"),
        str(data_base / "ingest" / dataset / "qrels.csv"),
        k=10,
    )
    for dataset in datasets
    for retrieval in retrieval_methods
}

fig, axes = plt.subplots(1, len(retrieval_methods), figsize=(14, 6), squeeze=False)
fig.suptitle("Reformulate Method vs Baseline: nDCG@10", fontsize=16, fontweight="bold")

# Bar positions and width are the same for every panel.
x = np.arange(len(datasets))
width = 0.35

# squeeze=False keeps `axes` two-dimensional, so axes[0] is the single row.
for ax, retrieval in zip(axes[0], retrieval_methods):
    baseline_heights = [baseline_metrics[(d, retrieval)]["ndcg@10"] for d in datasets]
    method_heights = [all_metrics[(d, method_name, retrieval)]["ndcg@10"] for d in datasets]

    # Baseline bars sit left of each tick, reformulate bars right of it.
    bars1 = ax.bar(x - width / 2, baseline_heights, width, label="Baseline", alpha=0.8)
    bars2 = ax.bar(x + width / 2, method_heights, width, label="Reformulate", alpha=0.8)

    ax.set_xlabel("Dataset", fontsize=12)
    ax.set_ylabel("nDCG@10", fontsize=12)
    ax.set_title(f"{retrieval.upper()}", fontsize=13, fontweight="bold")
    ax.set_xticks(x)
    ax.set_xticklabels(datasets)
    ax.legend()
    ax.grid(axis="y", alpha=0.3)

    # Write each bar's value just above its top edge.
    for bar in (*bars1, *bars2):
        height = bar.get_height()
        ax.text(
            bar.get_x() + bar.get_width() / 2.0,
            height,
            f"{height:.3f}",
            ha="center",
            va="bottom",
            fontsize=9,
        )

plt.tight_layout()
plot_path = data_base / "eval" / method_name / "ndcg.png"
plt.savefig(plot_path, dpi=150, bbox_inches="tight")
print(f"Plot saved to: {plot_path}")
plt.show()