# Lovli - Retrieval Threshold Sweep (Colab T4)

This notebook runs a retrieval-quality sweep over threshold combinations for Lovli.

It is designed for Google Colab with a T4 GPU to avoid local MPS memory pressure.

## What it does
- Loads the project and dependencies
- Builds one `LegalRAGChain`
- Precomputes retrieval/reranker candidates once
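- Holds `retrieval_k_initial`, `reranker_confidence_threshold`, and `reranker_min_doc_score` fixed at a baseline taken from a previous sweep
- Sweeps combinations of the ambiguity-gating parameters:
  - `reranker_ambiguity_min_gap`
  - `reranker_ambiguity_top_score_ceiling`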
- Saves results to `eval/retrieval_sweep_results.json`

## Expected outputs
- Ranked sweep table in notebook output
- JSON file with all combinations and metrics

In [None]:
# Setup cell (Colab): clone the repository if needed, cd into it, and install it.
# NOTE(review): `%cd /content/lovli` runs unconditionally and assumes the Colab path;
# when running elsewhere, start from the repo root and skip this cell.
# If running in Colab, clone your repo first (skip if already cloned).
import os
if not os.path.exists("/content/lovli"):
    !git clone https://github.com/AndreasRamsli/lovli.git
%cd /content/lovli

# Satisfy google-colab's requests pin to avoid dependency conflicts
!pip install -q "requests==2.32.4"
# Editable install of the checkout; the next cell asks for a runtime restart if the
# package still cannot be found afterwards.
!pip install -q -e .

In [None]:
import json
import itertools
import logging
import math
import os
import sys
from pathlib import Path

import pandas as pd
import torch

# Turn off LangChain/LangSmith tracing and Hugging Face Hub telemetry before the
# project package is imported further down.
os.environ.update(
    {
        "LANGCHAIN_TRACING_V2": "false",
        "LANGSMITH_TRACING": "false",
        "HF_HUB_DISABLE_TELEMETRY": "1",
    }
)


def _has_lovli_package(root: Path) -> bool:
    """Return True if *root* contains a ``src/lovli`` directory."""
    return (root / "src" / "lovli").exists()


# Locate the repository root: prefer the Colab clone at /content/lovli, otherwise
# accept the current working directory or its parent.
ROOT_DIR = Path("/content/lovli")
if not _has_lovli_package(ROOT_DIR):
    ROOT_DIR = next(
        (candidate for candidate in (Path.cwd(), Path.cwd().parent) if _has_lovli_package(candidate)),
        ROOT_DIR,
    )
if not _has_lovli_package(ROOT_DIR):
    raise FileNotFoundError(
        "lovli package not found. Run the setup cell above first (clone + pip install), "
        "then Runtime > Restart runtime, then run from this cell."
    )

# Put the checkout's src/ directory first on the import path so `lovli` resolves
# to this repository's source.
_src_dir = str(ROOT_DIR / "src")
if _src_dir not in sys.path:
    sys.path.insert(0, _src_dir)

from lovli.chain import LegalRAGChain
from lovli.config import get_settings

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Report which accelerator the sweep will run on (a T4 is expected on Colab).
_cuda_available = torch.cuda.is_available()
print(f"CUDA available: {_cuda_available}")
if _cuda_available:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# The sweep exercises retrieval and reranking only, never the LLM, so when no real
# key is configured (missing or empty) a placeholder lets Settings load.
os.environ["OPENROUTER_API_KEY"] = (
    os.environ.get("OPENROUTER_API_KEY") or "sweep-placeholder-not-used"
)

# Qdrant connection settings are required. Provide them via a .env file in the repo
# root, or uncomment the lines below to be prompted interactively:
# import getpass
# os.environ["QDRANT_URL"] = input("QDRANT_URL (https://...): ").strip()
# os.environ["QDRANT_API_KEY"] = getpass.getpass("QDRANT_API_KEY: ").strip()

_EVAL_DIR = ROOT_DIR / "eval"
QUESTIONS_PATH = _EVAL_DIR / "questions.jsonl"
OUT_PATH = _EVAL_DIR / "retrieval_sweep_results.json"

# Evaluate only the first N questions for a quick dry run (e.g. 20);
# None evaluates every question.
SAMPLE_SIZE = None

In [None]:
def matches_expected(cited_id: str, expected_set: set[str]) -> bool:
    """Return True if *cited_id* starts with any of the expected article IDs.

    Matching is by prefix, so an expected ID also matches any cited ID that
    extends it. An empty *expected_set* never matches.
    """
    # NOTE(review): prefix matching also accepts IDs that merely share a leading
    # substring (an expected "...-1" would match a cited "...-10"); confirm this is
    # intended for the article-ID format used in eval/questions.jsonl.
    return cited_id.startswith(tuple(expected_set))


def load_questions(path: Path) -> list[dict]:
    """Parse a JSON Lines file into a list of records.

    Each line is stripped of surrounding whitespace; blank lines are skipped and
    every remaining line is decoded with ``json.loads``.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    with open(path, "r", encoding="utf-8") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [json.loads(text) for text in stripped_lines if text]


def _doc_metadata(doc) -> dict:
    """Return a document's metadata whether it exposes ``.metadata`` or is a plain dict."""
    metadata = doc.metadata if hasattr(doc, "metadata") else doc.get("metadata", {})
    # Guard against an explicit ``metadata=None``.
    return metadata or {}


def _doc_text(doc) -> str:
    """Return the text passed to the reranker for *doc*."""
    if hasattr(doc, "page_content"):
        return doc.page_content
    if isinstance(doc, dict) and "page_content" in doc:
        # Dict-shaped docs are accepted by _doc_metadata, so rerank their content
        # rather than the dict's repr.
        return doc["page_content"]
    return str(doc)


def _sigmoid(x: float) -> float:
    """Logistic function that returns 0.0 instead of raising when ``exp(-x)`` overflows."""
    try:
        return 1.0 / (1.0 + math.exp(-x))
    except OverflowError:
        # x below about -709: the true value underflows to 0.0 in double precision.
        return 0.0


def _dedup_by_article(docs: list) -> list:
    """Keep the first hit per (law_id, article_id); docs without an article_id are always kept."""
    seen_keys = set()
    unique_docs = []
    for doc in docs:
        metadata = _doc_metadata(doc)
        article_id = metadata.get("article_id")
        if article_id:
            key = (metadata.get("law_id"), article_id)
            if key in seen_keys:
                continue
            seen_keys.add(key)
        unique_docs.append(doc)
    return unique_docs


def precompute_candidates(chain: "LegalRAGChain", questions: list[dict], max_k_initial: int) -> list[dict]:
    """Retrieve and rerank every question once, caching scored candidates for offline sweeps.

    Args:
        chain: RAG chain providing ``vectorstore``, ``reranker``, ``_route_law_ids`` and
            ``_invoke_retriever``. Side effect: ``chain.retriever`` is replaced with a
            retriever configured for ``k=max_k_initial``.
        questions: records with a ``"question"`` key and optional ``"id"`` and
            ``"expected_articles"`` keys.
        max_k_initial: maximum number of de-duplicated documents kept per question.

    Returns:
        One dict per question with ``id``, ``expected_articles`` and ``candidates``.
        Each candidate is ``{"law_id", "article_id", "score"}``, with score in [0, 1].
        Candidates are listed in the order the retriever returned them, after
        de-duplication. Callers can slice them to emulate a smaller initial k.
    """
    chain.retriever = chain.vectorstore.as_retriever(search_kwargs={"k": max_k_initial})
    cached = []

    for row in questions:
        query = row["question"]
        # NOTE(review): presumably _invoke_retriever uses chain.retriever set above — confirm.
        docs = chain._invoke_retriever(query, routed_law_ids=chain._route_law_ids(query))
        docs = _dedup_by_article(docs)[:max_k_initial]

        candidates = []
        if docs:
            pairs = [[query, _doc_text(doc)] for doc in docs]
            # Without a reranker, every doc gets the same raw score of 1.0.
            raw_scores = chain.reranker.predict(pairs) if chain.reranker else [1.0] * len(docs)
            if hasattr(raw_scores, "tolist"):
                raw_scores = raw_scores.tolist()
            # NOTE(review): assumes predict() returns unbounded logits. If it already
            # outputs probabilities, this squashes them a second time; confirm it
            # matches LegalRAGChain's own score normalization.
            normalized = [_sigmoid(float(s)) for s in raw_scores]

            for doc, score in zip(docs, normalized):
                metadata = _doc_metadata(doc)
                candidates.append({
                    "law_id": metadata.get("law_id", ""),
                    "article_id": metadata.get("article_id", ""),
                    "score": score,
                })

        cached.append({
            "id": row.get("id"),
            "expected_articles": row.get("expected_articles", []),
            "candidates": candidates,
        })

    return cached


def _select_kept_candidates(
    candidates: list[dict],
    retrieval_k_initial: int,
    final_k: int,
    min_doc_score: float,
    min_sources: int,
) -> list[dict]:
    """Emulate post-rerank selection on cached candidates; result is sorted by score, descending."""
    ranked = sorted(candidates[:retrieval_k_initial], key=lambda c: c["score"], reverse=True)[:final_k]
    floor = min(min_sources, len(ranked))
    kept = [c for c in ranked if c["score"] >= min_doc_score]
    # Never keep fewer than min_sources docs (when that many are available).
    if len(kept) < floor:
        kept = ranked[:floor]
    return kept


def _is_gated(scores: list[float], confidence_threshold: float, settings) -> bool:
    """Return True if the score-sorted *scores* would trigger confidence or ambiguity gating."""
    if scores and scores[0] < confidence_threshold:
        return True
    return bool(
        settings.reranker_ambiguity_gating_enabled
        and len(scores) >= 2
        and scores[0] <= settings.reranker_ambiguity_top_score_ceiling
        and (scores[0] - scores[1]) < settings.reranker_ambiguity_min_gap
    )


def evaluate_combo(chain: LegalRAGChain, cached_candidates: list[dict], retrieval_k_initial: int, confidence_threshold: float, min_doc_score: float) -> dict:
    """Score one threshold combination against cached reranker candidates.

    For each cached question, the selection works as follows:
    - Take the first *retrieval_k_initial* candidates.
    - Sort them by score and keep the top ``chain.settings.retrieval_k``.
    - Drop those scoring below *min_doc_score*, but keep at least
      ``min(chain.settings.reranker_min_sources, available)``.

    A question is "gated" in either of these cases:
    - Its top kept score is below *confidence_threshold*.
    - Ambiguity gating is enabled, at least two docs are kept, the top score is
      at most ``reranker_ambiguity_top_score_ceiling``, and the top-two gap is
      below ``reranker_ambiguity_min_gap``.

    The ambiguity parameters are read from ``chain.settings``.
    NOTE(review): this is intended to mirror LegalRAGChain's runtime selection
    logic; keep the two in sync.

    Returns:
        A dict with these keys:
        - ``recall_at_k``: share of questions with ``expected_articles`` whose kept
          docs cite an expected article, ignoring gating (0.0 if there are none).
        - ``ambiguity_success``: share of questions without expected articles that
          are gated or cite nothing (1.0 if there are none).
        - ``avg_top_score``: mean top kept score over questions with any kept docs.
        - ``recall_after_gating``: like ``recall_at_k``, but a hit only counts when
          the question is not gated. This lets a sweep see when gating suppresses
          answerable questions.
    """
    settings = chain.settings
    final_k = settings.retrieval_k
    min_sources = settings.reranker_min_sources

    retrieval_hits = 0
    answered_hits = 0
    retrieval_total = 0
    ambiguity_total = 0
    ambiguity_clean = 0
    top_scores = []

    for row in cached_candidates:
        expected = set(row.get("expected_articles", []))
        kept = _select_kept_candidates(
            row.get("candidates", []), retrieval_k_initial, final_k, min_doc_score, min_sources
        )
        scores = [c["score"] for c in kept]
        cited_ids = [c["article_id"] for c in kept if c.get("article_id")]
        if scores:
            top_scores.append(scores[0])

        # Gating is applied to every question: the chain cannot tell answerable
        # questions from ambiguous ones.
        gated = _is_gated(scores, confidence_threshold, settings)

        if expected:
            retrieval_total += 1
            if any(matches_expected(cid, expected) for cid in cited_ids):
                retrieval_hits += 1
                if not gated:
                    answered_hits += 1
        else:
            ambiguity_total += 1
            if gated or not cited_ids:
                ambiguity_clean += 1

    recall_at_k = retrieval_hits / retrieval_total if retrieval_total else 0.0
    recall_after_gating = answered_hits / retrieval_total if retrieval_total else 0.0
    ambiguity_success = ambiguity_clean / ambiguity_total if ambiguity_total else 1.0
    avg_top_score = sum(top_scores) / len(top_scores) if top_scores else 0.0

    return {
        "recall_at_k": recall_at_k,
        "ambiguity_success": ambiguity_success,
        "avg_top_score": avg_top_score,
        "recall_after_gating": recall_after_gating,
    }

In [None]:
questions = load_questions(QUESTIONS_PATH)
if SAMPLE_SIZE:
    questions = questions[:SAMPLE_SIZE]

settings = get_settings()
chain = LegalRAGChain(settings=settings)

# Fixed baseline taken from a previous sweep over k / confidence / min-doc-score.
BASE_RETRIEVAL_K_INITIAL = 15
BASE_CONFIDENCE = 0.4
BASE_MIN_DOC_SCORE = 0.3

# Ambiguity-focused grid.
MIN_GAP_VALUES = [0.05, 0.08, 0.10, 0.12, 0.15]
TOP_SCORE_CEILING_VALUES = [0.55, 0.60, 0.65, 0.70]

max_k_initial = BASE_RETRIEVAL_K_INITIAL
logger.info("Precomputing candidates once with k=%s on %s questions...", max_k_initial, len(questions))
cached_candidates = precompute_candidates(chain, questions, max_k_initial=max_k_initial)

# Settings are changed only after precomputation, so retrieval runs with the
# settings as loaded. The baseline is identical for every combination, so it is
# applied once here rather than on every iteration.
# NOTE: this mutates the settings object returned by get_settings().
chain.settings.retrieval_k_initial = BASE_RETRIEVAL_K_INITIAL
chain.settings.reranker_confidence_threshold = BASE_CONFIDENCE
chain.settings.reranker_min_doc_score = BASE_MIN_DOC_SCORE

# evaluate_combo applies the gap/ceiling rule only when ambiguity gating is enabled.
# With it disabled, every grid point would yield identical metrics, so turn it on.
if not chain.settings.reranker_ambiguity_gating_enabled:
    logger.warning(
        "reranker_ambiguity_gating_enabled is False; enabling it for this sweep so the "
        "min_gap / top_score_ceiling grid has an effect."
    )
    chain.settings.reranker_ambiguity_gating_enabled = True

rows = []
for min_gap, top_score_ceiling in itertools.product(MIN_GAP_VALUES, TOP_SCORE_CEILING_VALUES):
    # evaluate_combo reads the ambiguity parameters from chain.settings.
    chain.settings.reranker_ambiguity_min_gap = min_gap
    chain.settings.reranker_ambiguity_top_score_ceiling = top_score_ceiling

    metrics = evaluate_combo(
        chain,
        cached_candidates,
        retrieval_k_initial=BASE_RETRIEVAL_K_INITIAL,
        confidence_threshold=BASE_CONFIDENCE,
        min_doc_score=BASE_MIN_DOC_SCORE,
    )
    rows.append({
        "retrieval_k_initial": BASE_RETRIEVAL_K_INITIAL,
        "reranker_confidence_threshold": BASE_CONFIDENCE,
        "reranker_min_doc_score": BASE_MIN_DOC_SCORE,
        "reranker_ambiguity_min_gap": min_gap,
        "reranker_ambiguity_top_score_ceiling": top_score_ceiling,
        **metrics,
    })

# Rank primarily by ambiguity handling. Ties are broken by post-gating recall if
# the metrics include it (otherwise raw recall), then by raw recall.
rows.sort(
    key=lambda r: (
        r["ambiguity_success"],
        r.get("recall_after_gating", r["recall_at_k"]),
        r["recall_at_k"],
    ),
    reverse=True,
)

with open(OUT_PATH, "w", encoding="utf-8") as f:
    json.dump(rows, f, ensure_ascii=False, indent=2)

print(f"Saved: {OUT_PATH}")
print("Top 5 configurations:")
for row in rows[:5]:
    print(row)

df = pd.DataFrame(rows)
df.head(10)

## Notes

- For a quick pass, set `SAMPLE_SIZE = 20` first.
- For final tuning, set `SAMPLE_SIZE = None`.
- If reranker/model loading is still slow, keep this notebook on the T4 runtime rather than running it locally on MPS.
- You can copy the best row values into your `.env` or `Settings` defaults after validation.