# Agent CFO ‚Äî Performance Optimization & Design

---
This is the starter notebook for your project. Follow the required structure below.


You will design and optimize an Agent CFO assistant for a listed company. The assistant should answer finance/operations questions using RAG (Retrieval-Augmented Generation) + agentic reasoning, with response time (latency) as the primary metric.

Your system must:
*   Ingest the company‚Äôs public filings.
*   Retrieve relevant passages efficiently.
*   Compute ratios/trends via tool calls (calculator, table parsing).
*   Produce answers with valid citations to the correct page/table.


In [2]:
import os

# Best practice: do NOT hardcode API keys in notebook cells.
# If GEMINI_API_KEY is already set in the environment (e.g., via secrets), keep it.
# Otherwise, prompt the user to enter it securely (won't be echoed).
if os.environ.get("GEMINI_API_KEY"):
	print("GEMINI_API_KEY found in environment.")
else:
	try:
		from getpass import getpass
		key = getpass("Enter GEMINI_API_KEY (input hidden): ")
	except Exception:
		# Fallback to input() if getpass is unavailable in this environment
		key = input("Enter GEMINI_API_KEY: ")
	if key:
		os.environ["GEMINI_API_KEY"] = key
		print("GEMINI_API_KEY set for this session (not saved).")
	else:
		raise RuntimeError("GEMINI_API_KEY not provided. Set it via environment variables or re-run this cell.")

GEMINI_API_KEY found in environment.


## 1. Config & Secrets

Fill in your API keys in secrets. **Do not hardcode keys** in cells.

In [3]:
import os

# Example:
# os.environ['GEMINI_API_KEY'] = 'your-key-here'
# os.environ['OPENAI_API_KEY'] = 'your-key-here'

COMPANY_NAME = "DBS Bank"


## 2. Data Download (Dropbox)

*   Annual Reports: last 3‚Äì5 years.
*   Quarterly Results Packs & MD&A (Management Discussion & Analysis).
*   Investor Presentations and Press Releases.
*   These files must be submitted later as a deliverable in the Dropbox data pack.
*   Upload them under `/content/data/`.

Scope limit: each team will ingest minimally 15 PDF files total.


## 3. System Requirements

**Retrieval & RAG**
*   Use a vector index (e.g., FAISS, LlamaIndex) + a keyword filter (BM25/ElasticSearch).
*   Citations must include: report name, year, page number, section/table.

**Agentic Reasoning**
*   Support at least 3 tool types: calculator, table extraction, multi-document compare.
*   Reasoning must follow a plan-then-act pattern (not a single unstructured call).

**Instrumentation**
*   Log timings for: T_ingest, T_retrieve, T_rerank, T_reason, T_generate, T_total.
*   Log: tokens used, cache hits, tools invoked.
*   Record p50/p95 latencies.

 ### Gemini Version 1

In [2]:
# 1. Install the marker library
# This command should be run in your terminal or a Colab cell:
# !pip install marker-pdf -q

# 2. Import necessary components
import subprocess
import shutil
from pathlib import Path
import sys
import hashlib
import re
import cv2
import numpy as np
import pandas as pd


def md5sum(file_path: Path, chunk_size: int = 8192) -> str:
    """Return the hex md5 of a file."""
    h = hashlib.md5()
    with open(file_path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

# === OCR & extraction helpers ===
NUM_PAT = re.compile(r"^[+-]?\d{1,4}(?:[.,]\d+)?%?$")
NIM_KEYWORDS = ["net interest margin", "nim"]

QUARTER_PAT = re.compile(r"\b([1-4Iil|])\s*[QO0]\s*([0-9O]{2,4})\b", re.IGNORECASE)
# Simpler decade-only pattern for quarters, e.g., 2Q24, 1Q25
QUARTER_SIMPLE_PAT = re.compile(r"\b([1-4])Q(2\d)\b", re.IGNORECASE)  # e.g., 2Q24, 1Q25

# --- OCR character normalization for quarter tokens (common OCR mistakes) ---
_CHAR_FIX = str.maketrans({
    "O":"0","o":"0",
    "S":"5","s":"5",
    "I":"1","l":"1","|":"1","!":"1",
    "D":"0",
    "B":"3","8":"3",
    "Z":"2","z":"2"
})
def normalize_token(t: str) -> str:
    t = (t or "").strip()
    return t.translate(_CHAR_FIX).replace(" ", "")

# --- Helper: detect quarter tokens from nearby Markdown file ---
def detect_qlabels_from_md(dest_dir: Path, image_name: str) -> list[str]:
    """
    Scan the figure's markdown file for quarter tokens (e.g., 2Q24, 1Q2025).
    Returns tokens in document order (deduped).
    """
    try:
        md_file = dest_dir / f"{dest_dir.name}.md"
        if not md_file.exists():
            cand = list(dest_dir.glob("*.md"))
            if not cand:
                return []
            md_file = cand[0]
        text = md_file.read_text(encoding="utf-8", errors="ignore")
    except Exception:
        return []
    # Collect all quarter tokens across the document
    tokens = []
    for m in QUARTER_PAT.finditer(text):
        q = f"{m.group(1)}Q{m.group(2)[-2:]}"
        tokens.append(q)
    # Deduplicate preserving order
    seen = set()
    ordered = []
    for q in tokens:
        if q not in seen:
            seen.add(q)
            ordered.append(q)
    return ordered

def load_image(path):
    p = Path(path)
    im = cv2.imread(str(p))
    if im is None:
        raise RuntimeError(f"cv2.imread() failed: {p}")
    return im

def preprocess(img_bgr):
    scale = 2.0
    img = cv2.resize(img_bgr, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    gray = cv2.bilateralFilter(gray, d=7, sigmaColor=50, sigmaSpace=50)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
    gray = clahe.apply(gray)
    thr = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                cv2.THRESH_BINARY, 31, 8)
    return img, gray, thr, scale

def norm_num(s):
    s = s.replace(",", "").strip()
    pct = s.endswith("%")
    if pct:
        s = s[:-1]
    try:
        return float(s), pct
    except:
        return None, pct

def extract_numbers(ocr_results):
    rows = []
    for r in ocr_results or []:
        txt = str(r.get("text","")).strip()
        if NUM_PAT.match(txt):
            val, is_pct = norm_num(txt)
            if val is None:
                continue
            x1,y1,x2,y2 = r["bbox"]
            rows.append({
                "raw": txt, "value": val, "is_pct": is_pct, "conf": r.get("conf", None),
                "x1": int(x1), "y1": int(y1), "x2": int(x2), "y2": int(y2),
                "cx": int((x1+x2)/2), "cy": int((y1+y2)/2)
            })
    df = pd.DataFrame(rows).sort_values(["cy","cx"]).reset_index(drop=True)
    if "is_pct" not in df.columns and not df.empty:
        df["is_pct"] = df["raw"].astype(str).str.endswith("%")
    return df

def kmeans_1d(values, k=2, iters=20):
    values = np.asarray(values, dtype=float).reshape(-1,1)
    centers = np.array([values.min(), values.max()]).reshape(k,1)
    for _ in range(iters):
        d = ((values - centers.T)**2)
        labels = d.argmin(axis=1)
        new_centers = np.array([values[labels==i].mean() if np.any(labels==i) else centers[i] for i in range(k)]).reshape(k,1)
        if np.allclose(new_centers, centers, atol=1e-3):
            break
        centers = new_centers
    return labels, centers.flatten()

def run_easyocr(img_rgb):
    import easyocr
    global _EASY_OCR_READER
    try:
        _EASY_OCR_READER
    except NameError:
        _EASY_OCR_READER = None
    if _EASY_OCR_READER is None:
        _EASY_OCR_READER = easyocr.Reader(['en'], gpu=False, verbose=False)
    results = _EASY_OCR_READER.readtext(img_rgb, detail=1, paragraph=False)
    out = []
    for quad, text, conf in results:
        (x1,y1),(x2,y2),(x3,y3),(x4,y4) = quad
        out.append({"bbox": (int(x1),int(y1),int(x3),int(y3)), "text": str(text), "conf": float(conf)})
    return out

# --- Focused bottom-axis quarter detection using EasyOCR (robust to OCR confusions) ---
def detect_quarters_easyocr(img_bgr):
    """
    Use EasyOCR to read quarter labels along the bottom axis.
    Returns a list of (x_global, 'nQyy') sorted left‚Üíright, with half-year tokens removed.
    """
    H, W = img_bgr.shape[:2]
    y0 = int(H * 0.66)  # bottom ~34%
    crop = img_bgr[y0:H, 0:W]
    gray = cv2.cvtColor(crop, cv2.COLOR_BGR2GRAY)
    gray = cv2.bilateralFilter(gray, d=7, sigmaColor=50, sigmaSpace=50)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
    gray = clahe.apply(gray)
    thr = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                cv2.THRESH_BINARY, 31, 8)
    # kernel = np.ones((3,3), np.uint8)
    # thr = cv2.morphologyEx(thr, cv2.MORPH_CLOSE, kernel, iterations=1)
    up = cv2.resize(thr, None, fx=3.0, fy=3.0, interpolation=cv2.INTER_CUBIC)
    img_rgb = cv2.cvtColor(up, cv2.COLOR_GRAY2RGB)
    ocr = run_easyocr(img_rgb)
    # PASS 1 ‚Äî direct regex on normalized tokens
    tokens = []
    for r in ocr or []:
        raw = str(r.get("text","")).strip()
        x1,y1,x2,y2 = r["bbox"]
        cx_local = (x1 + x2) // 2
        cx_global = int(cx_local / 3.0)  # undo scaling
        tokens.append({"x": cx_global, "raw": raw, "norm": normalize_token(raw)})
    def _is_half_token(t: str) -> bool:
        t = (t or "").lower().replace(" ", "")
        return ("9m" in t) or ("1h" in t) or ("h1" in t) or ("h2" in t) or ("2h" in t)
    quarters = []
    for t in tokens:
        if _is_half_token(t["norm"]):
            continue
        m = QUARTER_PAT.search(t["norm"])
        if m:
            q = f"{m.group(1)}Q{m.group(2)[-2:]}"
            q = normalize_token(q)
            quarters.append((t["x"], q))
    # PASS 2 ‚Äî stitch split tokens if too few quarters were found
    if len(quarters) < 4 and tokens:
        pieces = sorted(tokens, key=lambda d: d["x"])
        digits_1to4 = [p for p in pieces if p["norm"] in ("1","2","3","4")]
        q_only      = [p for p in pieces if p["norm"].upper() == "Q"]
        q_with_year = [p for p in pieces if re.fullmatch(r"Q[0-9O]{2,4}", p["norm"], flags=re.I)]
        years_2d    = [p for p in pieces if re.fullmatch(r"[0-9O]{2,4}", p["norm"])]
        def near(a, b, tol=70):
            return abs(a["x"] - b["x"]) <= tol
        for d in digits_1to4:
            # digit + Qyy
            candidates = [q for q in q_with_year if near(d, q)]
            if candidates:
                qtok = min(candidates, key=lambda q: abs(q["x"]-d["x"]))
                qyy = normalize_token(qtok["norm"])[1:]
                quarters.append(((d["x"]+qtok["x"])//2, f"{d['norm']}Q{qyy[-2:]}"))
                continue
            # digit + Q + yy
            qs = [q for q in q_only if near(d, q)]
            ys = [y for y in years_2d if near(d, y, tol=120)]
            if qs and ys:
                qtok = min(qs, key=lambda q: abs(q["x"]-d["x"]))
                ytok = min(ys, key=lambda y: abs(y["x"]-qtok["x"]))
                yy = normalize_token(ytok["norm"])
                quarters.append(((d["x"]+ytok["x"])//2, f"{d['norm']}Q{yy[-2:]}"))
                continue
    if not quarters:
        return []
    quarters.sort(key=lambda t: t[0])
    deduped, last_x = [], -10**9
    for x,q in quarters:
        if abs(x - last_x) <= 22:
            continue
        deduped.append((x,q))
        last_x = x
    return deduped

# NIM value band (pct) and geometry heuristics for verification
NIM_MIN, NIM_MAX = 1.3, 3.2
TOP_FRACTION = 0.65     # widen band: NIM labels often sit higher than 45%
RIGHT_HALF_ONLY = True  # NIM values appear on right panel in these deck

def is_strict_nim_image(img_path: Path) -> tuple[bool, str]:
    """
    Heuristic re-check:
      1) Title/text contains NIM keywords (coarse gate)
      2) Percent tokens mostly within NIM_MIN..NIM_MAX
      3) Tokens located in the top region (and right half, if enabled)
    Returns (ok, reason)
    """
    try:
        img_bgr = load_image(img_path)
        H, W = img_bgr.shape[:2]
        # 1) quick-text gate (soft): don't return yet; allow numeric signature to validate
        kw_ok = is_relevant_image(img_path, NIM_KEYWORDS)
        # 2) numeric gate on enhanced image
        img_up, gray, thr, scale = preprocess(img_bgr)
        img_rgb = cv2.cvtColor(thr, cv2.COLOR_GRAY2RGB)
        ocr = run_easyocr(img_rgb)
        # --- Semantic gate: accept classic NIM slides based on stable labels ---
        text_lower = " ".join(str(r.get("text", "")).lower() for r in ocr or [])
        has_nim = "net interest margin" in text_lower
        has_cb  = "commercial book" in text_lower
        has_grp = "group" in text_lower
        if has_nim and (has_cb or has_grp):
            which = [w for w, ok in (("nim", has_nim), ("cb", has_cb), ("grp", has_grp)) if ok]
            return (True, f"ok_semantic({'+' .join(which)})")
        df = extract_numbers(ocr)
        if df.empty:
            return (False, "no_numbers")
        # geometry filters (apply before value checks)
        top_cut = int(img_up.shape[0] * 0.62)
        cond_geom = (df["cy"] < top_cut)
        if RIGHT_HALF_ONLY:
            cond_geom &= (df["cx"] > (img_up.shape[1] // 2))

        # 2a) Preferred path: explicit percentage tokens
        df_pct = df[(df["is_pct"] == True) & cond_geom].copy()
        if not df_pct.empty:
            in_band = df_pct["value"].between(NIM_MIN, NIM_MAX)
            ratio = float(in_band.sum()) / float(len(df_pct))
            if ratio >= 0.6:
                return (True, "ok")
            else:
                return (False, f"non_nim_values_out_of_band({ratio:.2f})")

        # 2b) Fallback: some decks omit the % sign near the series values.
        # Accept plain numbers in the NIM range if units are explicit or implied, or if numeric signature is strong.
        title_text = text_lower  # already computed above
        has_units_pct = "(%)" in title_text or "margin (%)" in title_text or has_nim
        df_nums = df[(df["is_pct"] == False) & cond_geom].copy()
        if not df_nums.empty:
            in_band = df_nums["value"].between(NIM_MIN, NIM_MAX)
            ratio = float(in_band.sum()) / float(len(df_nums))
            # Case A: explicit or implied units in title ‚Üí accept when enough in-band hits
            if has_units_pct and ratio >= 0.6 and in_band.sum() >= 3:
                return (True, "ok_no_percent_signs")
            # Case B: title OCR may have missed units; if the quick keyword gate succeeded, accept with a stricter ratio
            if kw_ok and ratio >= 0.7 and in_band.sum() >= 3:
                return (True, "ok_numeric_signature")
            # Case C: strong structural evidence (quarters on bottom) + numeric signature in band
            q_xy_fallback = detect_quarters_easyocr(img_bgr)
            if len(q_xy_fallback) >= 4 and ratio >= 0.6 and in_band.sum() >= 3:
                return (True, "ok_structural_numeric_signature")

        # Final decision: if numeric signature still failed, report clearer reason
        if not kw_ok:
            return (False, "irrelevant_non_nim")
        else:
            return (False, "no_percentages_or_units")
    except Exception as e:
        return (False, f"exception:{e}")


# --- Helper: detect and order quarter labels from OCR ---
def detect_qlabels(ocr_results, img_width: int) -> list[str]:
    """
    Extract quarter tokens like 1Q25, 2Q2025 from OCR and return them left‚Üíright.
    We keep only tokens on the right half (where the series values live in your layout).
    """
    qtokens = []
    mid_x = img_width // 2
    for r in ocr_results or []:
        txt = str(r.get("text","")).strip()
        m = QUARTER_PAT.search(txt)
        if not m:
            continue
        x1,y1,x2,y2 = r["bbox"]
        cx = (x1 + x2) // 2
        if cx <= mid_x:
            continue  # ignore left panel quarters/titles
        q = f"{m.group(1)}Q{m.group(2)[-2:]}"  # normalize to 1Q25 style
        qtokens.append((cx, q))
    # sort by visual x-position and deduplicate by both text and proximity (ignore near-duplicates)
    qtokens.sort(key=lambda x: x[0])
    # Deduplicate by both text and proximity (ignore near-duplicates)
    ordered = []
    last_x = -9999
    last_q = None
    for x, q in qtokens:
        if last_q == q and abs(x - last_x) < 30:
            continue
        ordered.append(q)
        last_x, last_q = x, q
    return ordered

# === Focused bottom-of-chart scan for small quarter labels ===
def detect_qlabels_bottom(img_bgr) -> list[str]:
    """
    Focused pass: crop the bottom ~30% (where quarter labels usually sit),
    enhance contrast, OCR, and extract quarter tokens left‚Üíright.
    """
    try:
        H, W = img_bgr.shape[:2]
        y0 = int(H * 0.60)  # bottom 40%
        crop = img_bgr[y0:H, 0:W]
        # Enhance: grayscale -> bilateral -> CLAHE -> adaptive threshold
        gray = cv2.cvtColor(crop, cv2.COLOR_BGR2GRAY)
        gray = cv2.bilateralFilter(gray, d=7, sigmaColor=50, sigmaSpace=50)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
        gray = clahe.apply(gray)
        thr = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                    cv2.THRESH_BINARY, 31, 8)
        # Morphological close to strengthen thin glyphs
        kernel = np.ones((3,3), np.uint8)
        thr = cv2.morphologyEx(thr, cv2.MORPH_CLOSE, kernel, iterations=1)
        # Upscale for small text
        up = cv2.resize(thr, None, fx=2.5, fy=2.5, interpolation=cv2.INTER_CUBIC)
        img_rgb = cv2.cvtColor(up, cv2.COLOR_GRAY2RGB)
        ocr = run_easyocr(img_rgb)
        # Map bboxes back to global coords: decide single-panel vs split-panel
        mid_x = W // 2
        left_quarters, right_quarters = [], []
        left_tokens_text, right_tokens_text = [], []
        for r in ocr or []:
            raw = str(r.get("text", "")).strip()
            x1,y1,x2,y2 = r["bbox"]
            cx_local = (x1 + x2) // 2
            cx_global = int(cx_local / 2.5)  # undo scale

            if cx_global <= mid_x:
                left_tokens_text.append(raw.lower())
            else:
                right_tokens_text.append(raw.lower())

            m = QUARTER_PAT.search(raw)
            if not m:
                continue
            q = f"{m.group(1)}Q{m.group(2)[-2:]}"
            if cx_global <= mid_x:
                left_quarters.append((cx_global, q))
            else:
                right_quarters.append((cx_global, q))

        def has_halfyear_or_9m(tokens: list[str]) -> bool:
            s = " ".join(tokens)
            return ("9m" in s) or ("1h" in s) or ("h1" in s) or ("h2" in s) or ("2h" in s)

        left_has_h = has_halfyear_or_9m(left_tokens_text)
        # Panel selection logic: prefer both halves unless left clearly half-year and right has ‚â•3 quarters
        if (not left_has_h) and (len(left_quarters) + len(right_quarters) >= 2):
            # Likely single panel or weak OCR on one side ‚Üí use both halves
            qtokens = left_quarters + right_quarters
        elif len(right_quarters) >= 3:
            # Strong right panel signal ‚Üí use right only
            qtokens = right_quarters
        else:
            # Fallback: use everything we found
            qtokens = left_quarters + right_quarters

        # Sort and dedupe close neighbors (‚â§18 px)
        qtokens.sort(key=lambda t: t[0])
        deduped = []
        last_x = -10**9
        for x, q in qtokens:
            if abs(x - last_x) <= 18:
                continue
            deduped.append((x, q))
            last_x = x

        return [q for _, q in deduped]
    except Exception:
        return []

# --- Same as detect_qlabels_bottom, but returns (x, label) for alignment ---
def detect_qlabels_bottom_with_xy(img_bgr) -> list[tuple[int, str]]:
    try:
        H, W = img_bgr.shape[:2]
        y0 = int(H * 0.60)
        crop = img_bgr[y0:H, 0:W]
        gray = cv2.cvtColor(crop, cv2.COLOR_BGR2GRAY)
        gray = cv2.bilateralFilter(gray, d=7, sigmaColor=50, sigmaSpace=50)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
        gray = clahe.apply(gray)
        thr = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                    cv2.THRESH_BINARY, 31, 8)
        kernel = np.ones((3,3), np.uint8)
        thr = cv2.morphologyEx(thr, cv2.MORPH_CLOSE, kernel, iterations=1)
        up = cv2.resize(thr, None, fx=2.5, fy=2.5, interpolation=cv2.INTER_CUBIC)
        img_rgb = cv2.cvtColor(up, cv2.COLOR_GRAY2RGB)
        ocr = run_easyocr(img_rgb)

        mid_x = W // 2
        left_quarters, right_quarters = [], []
        left_tokens_text = []
        for r in ocr or []:
            raw = str(r.get("text", "")).strip()
            x1,y1,x2,y2 = r["bbox"]
            cx_local = (x1 + x2) // 2
            cx_global = int(cx_local / 2.5)
            if cx_global <= mid_x:
                left_tokens_text.append(raw.lower())
            m = QUARTER_PAT.search(raw)
            if not m:
                continue
            q = f"{m.group(1)}Q{m.group(2)[-2:]}"
            if cx_global <= mid_x:
                left_quarters.append((cx_global, q))
            else:
                right_quarters.append((cx_global, q))

        def has_halfyear_or_9m(tokens: list[str]) -> bool:
            s = " ".join(tokens)
            return ("9m" in s) or ("1h" in s) or ("h1" in s) or ("h2" in s) or ("2h" in s)

        left_has_h = has_halfyear_or_9m(left_tokens_text)
        if (not left_has_h) and (len(left_quarters) + len(right_quarters) >= 2):
            # Likely single panel or weak OCR on one side ‚Üí use both halves
            qtokens = left_quarters + right_quarters
        elif len(right_quarters) >= 3:
            # Strong right panel signal ‚Üí use right only
            qtokens = right_quarters
        else:
            # Fallback: use everything we found
            qtokens = left_quarters + right_quarters

        qtokens.sort(key=lambda t: t[0])
        deduped = []
        last_x = -10**9
        for x, q in qtokens:
            if abs(x - last_x) <= 18:
                continue
            deduped.append((x, q))
            last_x = x
        return deduped
    except Exception:
        return []

# --- Merge two ordered quarter lists ---
def _merge_ordered(primary: list[str], secondary: list[str]) -> list[str]:
    """
    Merge two left‚Üíright sequences, keeping 'primary' order and filling with
    any unseen items from 'secondary' in their order.
    """
    out = list(primary)
    seen = set(primary)
    for q in secondary:
        if q not in seen:
            out.append(q)
            seen.add(q)
    return out

# --- Expand a quarter label like '2Q24' forward n quarters ---
def _expand_quarters(start_q: str, n: int) -> list[str]:
    """
    Given a label like '2Q24', produce a forward sequence of n quarters:
    2Q24, 3Q24, 4Q24, 1Q25, 2Q25, ...
    """
    m = QUARTER_PAT.match(start_q) or QUARTER_SIMPLE_PAT.match(start_q)
    if not m:
        return []
    q = int(m.group(1))
    yy = int(m.group(2)[-2:])
    seq = []
    for _ in range(n):
        seq.append(f"{q}Q{yy:02d}")
        q += 1
        if q == 5:
            q = 1
            yy = (yy + 1) % 100
    return seq

# --- Find a plausible anchor quarter like 2Q24 from OCR or markdown tokens ---
def _anchor_quarter_from_texts(ocr_results, md_tokens: list[str]) -> str | None:
    """
    Find any token like 1Q2x..4Q2x from OCR texts or markdown tokens.
    Returns the first plausible anchor (normalized to e.g. 2Q24) or None.
    """
    # prefer bottom/ocr-derived tokens first (already parsed in detect_qlabels_bottom)
    # fallback: scan all OCR texts with simple pattern
    for r in ocr_results or []:
        txt = str(r.get("text","")).strip()
        m = QUARTER_SIMPLE_PAT.search(txt)
        if m:
            return f"{m.group(1)}Q{m.group(2)}"
    # fallback to any markdown token that matches the decade pattern
    for t in md_tokens or []:
        m = QUARTER_SIMPLE_PAT.match(t)
        if m:
            return f"{m.group(1)}Q{m.group(2)}"
    return None

def extract_series_from_df(df, img_up, ocr_results=None, qlabels_hint=None):
    H, W = img_up.shape[:2]
    mid_x = W//2
    top_band_min = int(H * 0.38)
    top_band_max = int(H * 0.58)

    # Detect bottom quarter labels (with x) early to infer layout
    detected_q_bot_xy = detect_quarters_easyocr(img_up)
    left_count  = sum(1 for x, _ in detected_q_bot_xy if x <= mid_x)
    right_count = sum(1 for x, _ in detected_q_bot_xy if x >  mid_x)
    # Heuristic: if we see ‚â•4 quarter tokens spanning both halves, it's a single-panel timeline
    single_panel = (len(detected_q_bot_xy) >= 4 and left_count >= 1 and right_count >= 1)

    # Filter tokens: keep right-half only for split panels; keep all for single panels
    if single_panel:
        pct = df[(df.is_pct==True)].copy()
        nums = df[(df.is_pct==False)].copy()
    else:
        pct = df[(df.is_pct==True) & (df.cx > mid_x)].copy()
        nums = df[(df.is_pct==False) & (df.cx > mid_x)].copy()

    if pct.empty:
        # Fallback for charts that omit the '%' sign on the value dots.
        # Use a wider top band and avoid forcing right-half on single-panel timelines.
        approx_top = int(H * 0.60)
        if single_panel:
            cx_mask = (df.cx > 0)  # keep all x for single panel
        else:
            cx_mask = (df.cx > mid_x)
        cand_pct = df[cx_mask & df.value.between(NIM_MIN, NIM_MAX) & (df.cy < approx_top)].copy()
        if not cand_pct.empty:
            cand_pct["is_pct"] = True
            pct = cand_pct

    nim_df = pd.DataFrame()
    if not pct.empty:
        # Try to split into two horizontal series by Y even when we have only 3 quarters (‚Üí 6 points)
        # Deduplicate by proximity on Y to stabilize clustering
        y_sorted = pct.sort_values("cy")["cy"].to_numpy()
        uniq_y = []
        last_y = -10**9
        for yy in y_sorted:
            if abs(yy - last_y) >= 6:  # 6px tolerance for duplicates
                uniq_y.append(yy)
                last_y = yy
        # Attempt k-means when we have at least 4 points total (‚âà 2 series √ó 2 quarters)
        if pct.shape[0] >= 4 and len(uniq_y) >= 2:
            labels, centers = kmeans_1d(pct["cy"].values, k=2)
            pct["series"] = labels
            order = np.argsort(centers)  # top (commercial) should have smaller y
            remap = {order[0]: "Commercial NIM (%)", order[1]: "Group NIM (%)"}
            pct["series_name"] = pct["series"].map(remap)
            # Sanity: ensure both series have data; else collapse to one
            counts = pct["series_name"].value_counts()
            if any(counts.get(name, 0) == 0 for name in ["Commercial NIM (%)", "Group NIM (%)"]):
                pct["series_name"] = "NIM (%)"
        else:
            pct["series_name"] = "NIM (%)"

        # Reuse bottom-quarter labels captured above
        detected_q_bot = [q for _, q in detected_q_bot_xy]
        detected_q_ocr = detect_qlabels(ocr_results or [], W) if ocr_results is not None else []
        if len(detected_q_bot) > len(detected_q_ocr):
            detected_q = _merge_ordered(detected_q_bot, detected_q_ocr)
        else:
            detected_q = _merge_ordered(detected_q_ocr, detected_q_bot)
        rows = []
        for name, sub in pct.groupby("series_name"):
            # Sort left‚Üíright and collapse near-duplicates (same x within 12px)
            sub_sorted = sub.sort_values("cx")
            uniq_rows = []
            last_x = -10**9
            for r in sub_sorted.itertuples(index=False):
                if abs(r.cx - last_x) < 12:
                    continue
                uniq_rows.append(r)
                last_x = r.cx
            # Keep only the right-panel portion (already ensured by cx>mid_x earlier)
            pick = list(uniq_rows)[-5:]  # cap to 5 most recent positions, but may be <5
            n = len(pick)
            if n == 0:
                continue
            labels = []
            # Robust mapping: map each value x to its nearest bottom quarter label x (right panel).
            # Filter any accidental half-year tokens (1H/2H/H1/H2/9M) just in case OCR returns them.
            def _is_half_token(t: str) -> bool:
                t = (t or "").lower().replace(" ", "")
                return ("9m" in t) or ("1h" in t) or ("h1" in t) or ("h2" in t) or ("2h" in t) or ("h24" in t) or ("h23" in t)

            # detected_q_bot_xy already respects split vs single panel. Keep right-panel positions only here.
            q_xy = []
            for x, q in detected_q_bot_xy:
                if x <= mid_x:
                    continue
                if _is_half_token(q):
                    continue
                q_xy.append((x, q))

            if len(q_xy) < n:
                # Borrow from left panel if they look like quarters (and not half-year)
                for x, q in detected_q_bot_xy:
                    if x > mid_x:
                        continue
                    if _is_half_token(q):
                        continue
                    q_xy.append((x, q))

            if q_xy:
                q_xy.sort(key=lambda t: t[0])  # left‚Üíright
                # Map each picked value to nearest quarter label by x-position
                vx = [rr.cx for rr in pick]
                qx = [x for x, _ in q_xy]
                ql = [q for _, q in q_xy]
                mapped = []
                for x in vx:
                    j = int(np.argmin([abs(x - xx) for xx in qx])) if qx else -1
                    mapped.append(ql[j] if j >= 0 else None)
                labels = mapped
            else:
                detected_q_ocr = detect_qlabels(ocr_results or [], W) if ocr_results is not None else []
                if detected_q_ocr:
                    labels = detected_q_ocr[-n:] if len(detected_q_ocr) >= n else detected_q_ocr

            # If still short, use markdown tokens; else expand from an anchor like 2Q24
            if (not labels) or (len(labels) != n):
                if qlabels_hint:
                    labels = qlabels_hint[-n:] if len(qlabels_hint) >= n else qlabels_hint
            if (not labels) or (len(labels) != n):
                anchor = _anchor_quarter_from_texts(ocr_results, qlabels_hint)
                if anchor:
                    labels = _expand_quarters(anchor, n)
            if (not labels) or (len(labels) != n):
                labels = [f"{i+1}Q??" for i in range(n)]
            # Ensure left‚Üíright order for consistent mapping to labels
            pick = sorted(pick, key=lambda r: r.cx)
            labels = list(labels)[:n]
            for i, r in enumerate(pick):
                if i >= len(labels):
                    break
                rows.append({"Quarter": labels[i], "series": name, "value": r.value})
        if rows:
            nim_table = pd.DataFrame(rows)
            # Guard: drop rows with missing labels
            nim_table = nim_table.dropna(subset=["Quarter", "series"])  
            # If multiple detections map to the same (Quarter, series), average them
            if not nim_table.empty:
                dupe_mask = nim_table.duplicated(subset=["Quarter", "series"], keep=False)
                if dupe_mask.any():
                    # Aggregate duplicates by mean (stable for minor OCR jitter)
                    nim_table = nim_table.groupby(["Quarter", "series"], as_index=False)["value"].mean()
            nim_df = nim_table.pivot(index="Quarter", columns="series", values="value").reset_index()

    # NIM-only mode: skip NII extraction entirely
    nii_df = pd.DataFrame()

    def _sort_q(df_in):
        if df_in is None or df_in.empty or "Quarter" not in df_in.columns:
            return df_in
        # Try to sort by numeric (Q#, year) if labels are like 2Q24; else keep input order
        def _key(q):
            m = QUARTER_PAT.match(str(q))
            if not m:
                return (999, 999)
            qn = int(m.group(1))
            yr = int(m.group(2)[-2:])  # last two digits
            return (yr, qn)
        try:
            return df_in.assign(_k=df_in["Quarter"].map(_key)).sort_values("_k").drop(columns=["_k"]).reset_index(drop=True)
        except Exception:
            return df_in.reset_index(drop=True)

    return _sort_q(nim_df), _sort_q(nii_df)

def _extract_md_context(dest_dir: Path, image_name: str) -> dict:
    """
    Best-effort: read the <pdf_stem>.md in dest_dir, find the <image_name> reference,
    capture nearby headings and a neighbor paragraph to build context.
    """
    try:
        # Prefer "<pdf_stem>.md", else any .md
        md_file = dest_dir / f"{dest_dir.name}.md"
        if not md_file.exists():
            cands = list(dest_dir.glob("*.md"))
            if not cands:
                return {}
            md_file = cands[0]
        lines = md_file.read_text(encoding="utf-8", errors="ignore").splitlines()
    except Exception:
        return {}

    # Find the image line
    idx = None
    for i, line in enumerate(lines):
        if image_name in line:
            idx = i
            break
    if idx is None:
        return {}

    # Walk upward to find up to two headings and a neighbor paragraph
    figure_title = None
    section_title = None
    neighbor_text = None

    # Find the closest preceding heading(s)
    for j in range(idx - 1, -1, -1):
        s = lines[j].strip()
        if not s:
            continue
        # markdown heading levels
        if s.startswith("#"):
            # Remove leading #'s and whitespace
            heading = s.lstrip("#").strip()
            if figure_title is None:
                figure_title = heading
            elif section_title is None:
                section_title = heading
                break

    # Find a non-empty paragraph between the image and last heading
    for j in range(idx - 1, -1, -1):
        s = lines[j].strip()
        if s and not s.startswith("#") and not s.startswith("![]("):
            neighbor_text = s
            break

    out = {}
    if figure_title: out["figure_title"] = figure_title
    if section_title: out["section_title"] = section_title
    if neighbor_text: out["neighbor_text"] = neighbor_text
    return out

def _parse_page_and_figure_from_name(image_name: str) -> dict:
    """
    Extract page/figure indices from names like '_page_0_Figure_2.jpeg'.
    """
    info = {}
    try:
        # Very loose parse
        if "_page_" in image_name:
            after = image_name.split("_page_", 1)[1]
            num = after.split("_", 1)[0]
            info["page"] = int(num) + 1  # 1-based for human readability
        if "Figure_" in image_name:
            after = image_name.split("Figure_", 1)[1]
            num = ""
            for ch in after:
                if ch.isdigit():
                    num += ch
                else:
                    break
            if num:
                info["figure_index"] = int(num)
    except Exception:
        pass
    return info

def is_relevant_image(img_path, keywords):
    """Robust relevance check for NIM slides.
    - Reuse the singleton EasyOCR reader (run_easyocr)
    - Accept split tokens like "Net" / "interest" / "margin" (not only the exact phrase)
    - Fallback: if we see ‚â•4 quarter labels on the bottom AND ‚â•3 top-band percent-like values in NIM range, treat as relevant.
    """
    try:
        img = cv2.imread(str(img_path))
        if img is None:
            return False

        # Pass A: OCR on lightly upscaled original
        view_a = cv2.resize(img, None, fx=1.3, fy=1.3, interpolation=cv2.INTER_CUBIC)
        ocr_a = run_easyocr(cv2.cvtColor(view_a, cv2.COLOR_BGR2RGB))
        tokens_a = [str(r.get("text","")).lower() for r in (ocr_a or [])]
        text_a = " ".join(tokens_a)

        # Quick phrase match (exact keywords like "net interest margin")
        if any(k in text_a for k in keywords):
            return True

        # Pass B: OCR on preprocessed thresholded view (more stable for thin fonts)
        _, _, thr, _ = preprocess(img)
        ocr_b = run_easyocr(cv2.cvtColor(thr, cv2.COLOR_GRAY2RGB))
        tokens_b = [str(r.get("text","")).lower() for r in (ocr_b or [])]
        text_b = " ".join(tokens_b)
        if any(k in text_b for k in keywords):
            return True

        # Token-level split-word check
        tokens = tokens_a + tokens_b
        has_net      = any("net" in t for t in tokens)
        has_interest = any("interest" in t for t in tokens)
        has_margin   = any("margin" in t for t in tokens or [])
        has_nim_abbr = any(re.search(r"\bnim\b", t) for t in tokens)
        has_cb       = any("commercial book" in t for t in tokens)
        has_grp      = any(re.search(r"\bgroup\b", t) for t in tokens)
        if (has_net and has_interest and has_margin) or has_nim_abbr:
            # Strengthen with context words if available
            if has_cb or has_grp:
                return True

        # Structural fallback: quarters + percent values in the NIM band
        q_xy = detect_quarters_easyocr(img)
        if len(q_xy) >= 4:
            # Look for ‚â•3 percent-ish values in the top band within NIM_MIN..NIM_MAX
            df = extract_numbers(ocr_b)
            if not df.empty:
                H, W = view_a.shape[:2]
                top_cut = int(H * 0.55)
                in_top = df["cy"] < top_cut
                in_band = df["value"].between(NIM_MIN, NIM_MAX)
                pctish = in_band  # allow numbers without % (the series sometimes omit it)
                if int((in_top & pctish).sum()) >= 3:
                    return True

        return False
    except Exception:
        return False


# =============== Pluggable OCR Extractor Framework ===============
class BaseChartExtractor:
    """
    Minimal interface for pluggable chart extractors.
    Implement `is_relevant` and `extract_table`, then call `handle_image(...)`.
    """
    name = "base"
    topic = "Generic Chart"
    units = None
    entity = None
    keywords = []

    def is_relevant(self, img_path: Path) -> bool:
        return is_relevant_image(img_path, self.keywords)

    def extract_table(self, img_path: Path, dest_dir: Path, pdf_name: str):
        """
        Return (df, context_dict) or (None, reason) on failure.
        context_dict will be merged into the _context object.
        """
        raise NotImplementedError

    def _build_context(self, pdf_name: str, img_path: Path, dest_dir: Path, extra: dict | None = None) -> dict:
        ctx = {
            "source_pdf": pdf_name,
            "image": img_path.name,
            "topic": self.topic,
        }
        if self.units:  ctx["units"]  = self.units
        if self.entity: ctx["entity"] = self.entity
        ctx.update(_parse_page_and_figure_from_name(img_path.name))
        md_ctx = _extract_md_context(dest_dir, img_path.name)
        if md_ctx: ctx.update(md_ctx)
        if extra:  ctx.update(extra)
        return ctx

    def _write_jsonl(self, out_path: Path, ctx: dict, df: pd.DataFrame):
        import json
        with open(out_path, "w", encoding="utf-8") as f:
            f.write(json.dumps({"_context": ctx}, ensure_ascii=False) + "\n")
            for rec in df.to_dict(orient="records"):
                rec_out = dict(rec)
                rec_out["_meta"] = {"source_pdf": ctx.get("source_pdf"), "image": ctx.get("image")}
                f.write(json.dumps(rec_out, ensure_ascii=False) + "\n")

    def handle_image(self, img_path: Path, dest_dir: Path, pdf_name: str, *, bypass_relevance: bool = False):
        if not bypass_relevance and not self.is_relevant(img_path):
            return False, "Not relevant"
        df, ctx_extra = self.extract_table(img_path, dest_dir, pdf_name)
        if df is None or df.empty:
            return False, ctx_extra if isinstance(ctx_extra, str) else "No data"
        # Build context and summary if possible
        ctx = self._build_context(pdf_name, img_path, dest_dir, extra=ctx_extra if isinstance(ctx_extra, dict) else {})
        try:
            cols = [c for c in df.columns if c != "Quarter"]
            if len(df) >= 2 and cols:
                def _pick_q(s):
                    return s if QUARTER_PAT.match(str(s) or "") else None
                _fq = str(df.iloc[0]["Quarter"])
                _lq = str(df.iloc[-1]["Quarter"])
                first_q = _pick_q(_fq) or (_fq if "??" not in _fq else "start")
                last_q  = _pick_q(_lq) or (_lq if "??" not in _lq else "end")
                pieces = []
                for col in cols[:2]:
                    a = df.iloc[0][col]
                    b = df.iloc[-1][col]
                    if pd.notna(a) and pd.notna(b):
                        suffix = "%" if "NIM" in col or ctx.get("units") == "percent" else ""
                        pieces.append(f"{col}: {a:.2f}{suffix} ‚Üí {b:.2f}{suffix}")
                if pieces:
                    ctx["summary"] = f"Figure shows {', '.join(pieces)} from {first_q} to {last_q}."
        except Exception:
            pass
        out_path = img_path.with_suffix(f".{self.name}.jsonl")
        self._write_jsonl(out_path, ctx, df)
        return True, str(out_path)

class NIMExtractor(BaseChartExtractor):
    name = "nim"
    topic = "Net Interest Margin"
    units = "percent"
    entity = "DBS"
    keywords = NIM_KEYWORDS

    def extract_table(self, img_path: Path, dest_dir: Path, pdf_name: str):
        # Reuse the existing pipeline
        img_bgr = load_image(img_path)
        img_up, gray, thr, scale = preprocess(img_bgr)
        img_rgb = cv2.cvtColor(thr, cv2.COLOR_GRAY2RGB)
        ocr = run_easyocr(img_rgb)
        df_tokens = extract_numbers(ocr)
        if df_tokens.empty:
            return None, "No numeric tokens detected"
        md_q = detect_qlabels_from_md(dest_dir, img_path.name)
        nim_df, _nii_df = extract_series_from_df(df_tokens, img_up, ocr_results=ocr, qlabels_hint=md_q)
        if nim_df is None or nim_df.empty:
            return None, "No NIM table detected"
        return nim_df, {"topic": self.topic, "units": self.units, "entity": self.entity}

# Registry of extractors (add more later)
EXTRACTORS: list[BaseChartExtractor] = [
    NIMExtractor(),
]
# ============= End pluggable extractor framework =============

# === Single-image rebuild/verify mode (optional) ===
# Set single_image_mode=True and point single_image_path to a specific extracted image
# to run the two-stage gate + extraction just for that file, then exit.
single_image_mode = False
single_image_paths: list[Path] = [
   
]
# Optional singular fallback path (legacy): set to a string/Path if you want a single-image override
single_image_path = None

# Legacy fallback (ignored i
 # Toggle: if True ‚Üí normal md5 skip; if False ‚Üí always reprocess
md5_check = True

# 3. Define the path to the directory containing your PDF files
pdf_directory = Path("/Users/marcusfoo/Documents/GitHub/PTO_ICT3113_Grp1/All/")

# === Fast path: single image only ===
# === Fast path: single/multi-image only ===
if single_image_mode:
    paths: list[Path] = []
    if single_image_paths:
        paths = [Path(p) for p in single_image_paths if p is not None]
    elif single_image_path:
        paths = [Path(single_image_path)]

    if not paths:
        print("‚ùå single_image_mode=True but no paths were provided.")
        sys.exit(1)

    print("--- Multi-image mode ---")
    successes = 0
    for img_path in paths:
        if not img_path.exists():
            print(f"‚ùå Missing: {img_path}")
            continue

        dest_dir = img_path.parent
        pdf_name = f"{dest_dir.name}.pdf"
        print(f"\nüñºÔ∏è  Image: {img_path.name}  |  PDF: {pdf_name}")

        # Quick quarter readout (EasyOCR-only, bottom axis)
        try:
            img_bgr_quarters = load_image(img_path)
            q_xy = detect_quarters_easyocr(img_bgr_quarters)
            if q_xy:
                print("   üìé Quarters (EasyOCR):", ", ".join([q for _,q in q_xy]))
            else:
                print("   üìé Quarters (EasyOCR): <none>")
        except Exception as _qe:
            print(f"   üìé Quarters (EasyOCR): error ‚Üí {_qe}")

        any_hit = False

        for ex in EXTRACTORS:
            print(f"   ¬∑ [{ex.name}] quick gate‚Ä¶", end=" ")
            if not ex.is_relevant(img_path):
                print("‚è≠Ô∏è  Not relevant")
                continue
            print("‚úÖ ok; strict gate‚Ä¶", end=" ")
            ok_strict, reason = is_strict_nim_image(img_path)
            if not ok_strict:
                print(f"‚è≠Ô∏è  Failed strict ({reason})")
                continue
            print("‚úÖ Strict OK ‚Äî extracting‚Ä¶")

            # Extract directly so we can print the table; still write JSONL
            df, ctx_extra = ex.extract_table(img_path, dest_dir, pdf_name)
            if df is None or df.empty:
                print("   ‚ö†Ô∏è No data extracted.")
                continue

            any_hit = True
            successes += 1

            # Build context + summary and write JSONL
            ctx = ex._build_context(pdf_name, img_path, dest_dir, extra=ctx_extra if isinstance(ctx_extra, dict) else {})
            try:
                cols = [c for c in df.columns if c != "Quarter"]
                if len(df) >= 2 and cols:
                    def _pick_q(s):
                        return s if QUARTER_PAT.match(str(s) or "") else None
                    _fq = str(df.iloc[0]["Quarter"]); _lq = str(df.iloc[-1]["Quarter"])
                    first_q = _pick_q(_fq) or (_fq if "??" not in _fq else "start")
                    last_q  = _pick_q(_lq) or (_lq if "??" not in _lq else "end")
                    pieces = []
                    for col in cols[:2]:
                        a = df.iloc[0][col]; b = df.iloc[-1][col]
                        if pd.notna(a) and pd.notna(b):
                            suffix = "%" if "NIM" in col or ctx.get("units") == "percent" else ""
                            pieces.append(f"{col}: {a:.2f}{suffix} ‚Üí {b:.2f}{suffix}")
                    if pieces:
                        ctx["summary"] = f"Figure shows {', '.join(pieces)} from {first_q} to {last_q}."
            except Exception:
                pass

            out_path = img_path.with_suffix(f".{ex.name}.jsonl")
            ex._write_jsonl(out_path, ctx, df)
            print(f"   üíæ Saved JSONL ‚Üí {out_path}")

            # Pretty-print the extracted table directly
            try:
                print("\n   üìä Extracted table:")
                print(df.to_string(index=False))
            except Exception:
                print(df)

        if not any_hit:
            print("   ‚è≠Ô∏è  No matching extractors for this image.")

    print(f"\n‚úÖ Done. Extracted from {successes} image(s).")
    # Prevent the pipeline (marker/md5) from running if notebook catches SystemExit
    globals()["_STOP_AFTER_SINGLE"] = True
    sys.exit(0)
    
# Check if the directory exists before proceeding
if not pdf_directory.is_dir():
    print(f"‚ùå ERROR: The directory was not found at '{pdf_directory}'.")
    sys.exit(1) # Exit the script if the directory doesn't exist

# 4. Check if the 'marker_single' command is available
if not shutil.which("marker_single"):
    print("‚ùå ERROR: The 'marker_single' command was not found.")
    print("Please ensure 'marker-pdf' is installed correctly in your environment's PATH.")
    sys.exit(1)

# Loop through every PDF file in the specified directory
for pdf_path in pdf_directory.glob("*.pdf"):
    print(f"--- Processing file: {pdf_path.name} ---")

    # 5. Let Marker create the <pdf_stem>/ subfolder automatically.
    # Point --output_dir to the *parent* folder so we don't end up with Demo PDF/Demo PDF/.
    output_parent = pdf_path.parent  # e.g., .../Demo/

    # Determine the destination folder Marker will create and a checksum sidecar file
    dest_dir = output_parent / pdf_path.stem
    checksum_file = dest_dir / ".marker_md5"

    # Compute the current md5 of the source PDF
    current_md5 = md5sum(pdf_path)

    # Define the expected main outputs (Marker uses the same stem)
    expected_md = dest_dir / f"{pdf_path.stem}.md"
    expected_json = dest_dir / f"{pdf_path.stem}.json"
    outputs_exist = expected_md.exists() and expected_json.exists()

    # md5 two-mode logic
    if md5_check:
        # Normal: skip if checksum matches and key outputs exist
        if dest_dir.is_dir() and checksum_file.exists() and outputs_exist:
            try:
                saved_md5 = checksum_file.read_text().strip()
            except Exception:
                saved_md5 = ""
            if saved_md5 == current_md5:
                print(f"‚è≠Ô∏è  Skipping {pdf_path.name}: up-to-date (md5 match). ‚Üí {dest_dir}")
                continue
            else:
                print(f"‚ôªÔ∏è  md5 mismatch ‚Üí reprocessing {pdf_path.name}")
                print(f"    saved={saved_md5}")
                print(f"    current={current_md5}")
                print(f"    Cleaning old outputs in: {dest_dir}")
                try:
                    shutil.rmtree(dest_dir)
                except Exception as _e:
                    print(f"    ‚ö†Ô∏è  Could not fully clean '{dest_dir}': {_e}")
        else:
            print("‚ÑπÔ∏è  No prior checksum or outputs ‚Üí processing normally.")
    else:
        # Force reprocess regardless of checksum
        print("‚öôÔ∏è  md5_check=False ‚Üí forcing reprocess (marker + OCR).")
        if dest_dir.exists():
            print(f"    Cleaning existing folder: {dest_dir}")
            try:
                shutil.rmtree(dest_dir)
            except Exception as _e:
                print(f"    ‚ö†Ô∏è  Could not fully clean '{dest_dir}': {_e}")

    try:
        # ======================================================================
        # 1. Run the CLI command to generate JSON output (with real-time output)
        # ======================================================================
        print(f"Running CLI command for JSON output on {pdf_path.name}...")
        json_command = [
            "marker_single",
            str(pdf_path),
            "--output_format", "json",
            "--output_dir", str(output_parent)
        ]
        # By removing 'capture_output', the subprocess will stream its output directly to the console in real-time.
        result_json = subprocess.run(json_command, check=True)
        print("‚úÖ JSON file generated successfully by CLI.")


        # ======================================================================
        # 2. Run the CLI command to generate Markdown and Image output (with real-time output)
        # ======================================================================
        print(f"\nRunning CLI command for Markdown and Image output on {pdf_path.name}...")
        md_command = [
            "marker_single",
            str(pdf_path),
            # Default format is markdown, so we don't need to specify it
            "--output_dir", str(output_parent)
        ]
        result_md = subprocess.run(md_command, check=True)
        print("‚úÖ Markdown file and images generated successfully by CLI.")

        print(f"\n‚ú® Files saved under '{output_parent / pdf_path.stem}'.")
        print("Note: Marker creates a subfolder named after the PDF automatically.")

        # === Post-processing: scan Marker images ‚Üí filter relevant ‚Üí save JSONL ===
        print("üîé Scanning extracted images for relevant charts/plots‚Ä¶")
        img_exts = (".png", ".jpg", ".jpeg")
        img_files = [p for p in dest_dir.rglob("*") if p.suffix.lower() in img_exts]
        if not img_files:
            print("   üñºÔ∏è  No images found in extracted folder.")
        for img_path in sorted(img_files):
            print(f"   ‚Ä¢ {img_path.name}")
            any_hit = False
            for ex in EXTRACTORS:
                # Stage 1: quick keyword/title skim
                print(f"      ¬∑ [{ex.name}] quick gate‚Ä¶", end=" ")
                if not ex.is_relevant(img_path):
                    print("‚è≠Ô∏è  Not relevant")
                    continue
                print("‚úÖ ok; strict gate‚Ä¶", end=" ")

                # Stage 2: strict verifier (geometry + numeric band + semantic anchors)
                ok_strict, reason = is_strict_nim_image(img_path)
                if not ok_strict:
                    print(f"‚è≠Ô∏è  Failed strict ({reason})")
                    continue

                any_hit = True
                print("‚úÖ Strict OK ‚Äî extracting‚Ä¶", end=" ")
                ok, msg = ex.handle_image(img_path, dest_dir, pdf_path.name, bypass_relevance=True)
                if ok:
                    print(f"üíæ Saved ‚Üí {msg}")
                else:
                    print(f"‚ö†Ô∏è Skipped ({msg})")
            if not any_hit:
                print("      ‚è≠Ô∏è  No matching extractors for this image.")

        # After OCR completes, write/update checksum sidecar
        try:
            dest_dir.mkdir(parents=True, exist_ok=True)
            checksum_file.write_text(current_md5)
            print(f"üßæ Recorded checksum in: {checksum_file}")
        except Exception as _e:
            print(f"‚ö†Ô∏è  Failed to write checksum file at '{checksum_file}': {_e}")

    except subprocess.CalledProcessError as e:
        print(f"\n‚ùå An error occurred while processing {pdf_path.name}.")
        print(f"Command: '{' '.join(e.cmd)}'")
        print(f"Return Code: {e.returncode}")
        print("Note: Outputs (if any) may be incomplete; checksum not updated.")
    except Exception as e:
        print(f"\nAn unexpected error occurred while processing {pdf_path.name}: {e}")
    
    print(f"--- Finished processing: {pdf_path.name} ---\n")

print("üéâ All PDF files in the directory have been processed.")


# === Stage-1 continuation: Build KB + FAISS (inline; no external scripts) ===
try:
    import sys, subprocess
    # 1) Ensure minimal deps (idempotent)
    for _pkg in ["sentence-transformers", "faiss-cpu", "pandas", "pyarrow", "numpy", "lxml", "tqdm"]:
        try:
            __import__(_pkg.split("-")[0])
        except Exception:
            print(f"üì¶ Installing {_pkg} ‚Ä¶")
            subprocess.check_call([sys.executable, "-m", "pip", "install", _pkg, "-q"])  # noqa: S603,S607

    import re, json, hashlib, time
    import numpy as _np, pandas as _pd, faiss  # type: ignore
    from io import StringIO as _StringIO
    from pathlib import Path as _Path
    from tqdm import tqdm as _tqdm
    from sentence_transformers import SentenceTransformer as _ST

    KB_IN_DIR  = str(pdf_directory)  # reuse the same directory processed above
    KB_OUT_DIR = str((_Path("./data_marker")).resolve())

    # ---- helpers (namespaced with kb_ to avoid collisions) ----
    def kb_file_hash_key(p: _Path) -> str:
        try:
            s = p.stat()
            return hashlib.md5(f"{p.resolve()}|{s.st_size}|{int(s.st_mtime)}".encode()).hexdigest()
        except FileNotFoundError:
            return ""

    def kb_safe_read(path: _Path) -> str:
        for enc in ("utf-8", "utf-8-sig", "latin-1"):
            try:
                return path.read_text(encoding=enc, errors="ignore")
            except Exception:
                continue
        return ""

    def kb_strip_md_basic(md: str) -> str:
        md = re.sub(r"```.*?```", " ", md, flags=re.DOTALL)
        md = re.sub(r"!\[[^\]]*\]\([^\)]*\)", " ", md)
        md = re.sub(r"\[([^\]]+)\]\([^\)]*\)", r"\1", md)
        md = re.sub(r"<[^>]+>", " ", md)
        md = re.sub(r"\s+", " ", md)
        return md.strip()

    def kb_coerce_numbers_df(df: _pd.DataFrame) -> _pd.DataFrame:
        df = df.copy()
        for c in df.columns:
            if df[c].dtype == object:
                s = df[c].astype(str).str.replace(",", "", regex=False)
                num = _pd.to_numeric(s, errors="coerce")
                df[c] = _np.where(num.notna(), num, s)
        return df

    def kb_extract_tables_from_marker_json_blocks(jtxt: str):
        try:
            data = json.loads(jtxt)
        except Exception:
            return []
        out = []
        def _page_from_id(node: dict, fallback):
            node_id = node.get("id") if isinstance(node.get("id"), str) else ""
            m = re.search(r"/page/(\d+)/", node_id or "")
            if m:
                try:
                    return int(m.group(1))
                except Exception:
                    pass
            return fallback
        def walk(node, current_page=None):
            if isinstance(node, dict):
                current_page = _page_from_id(node, current_page)
                if node.get("block_type") == "Table" and isinstance(node.get("html"), str):
                    html = node["html"]
                    try:
                        dfs = _pd.read_html(_StringIO(html))
                        for df in dfs:
                            out.append({"df": kb_coerce_numbers_df(df), "page": current_page})
                    except Exception:
                        pass
                for v in node.values():
                    walk(v, current_page)
            elif isinstance(node, list):
                for v in node:
                    walk(v, current_page)
        walk(data)
        return out

    def kb_extract_text_spans_with_pages(jtxt: str):
        try:
            data = json.loads(jtxt)
        except Exception:
            return []
        spans = []
        def _page_from_id(node: dict, fallback):
            node_id = node.get("id") if isinstance(node.get("id"), str) else ""
            m = re.search(r"/page/(\d+)/", node_id or "")
            if m:
                try:
                    return int(m.group(1))
                except Exception:
                    pass
            return fallback
        def _strip_html(s: str) -> str:
            s = re.sub(r"<[^>]+>", " ", s)
            s = re.sub(r"\s+", " ", s).strip()
            return s
        TEXT_BLOCKS = {"Text", "SectionHeader", "Paragraph", "Heading", "ListItem", "Caption", "Footer", "Header"}
        def walk(node, current_page=None):
            if isinstance(node, dict):
                current_page = _page_from_id(node, current_page)
                bt = node.get("block_type")
                if isinstance(bt, str) and bt in TEXT_BLOCKS:
                    html = node.get("html")
                    if isinstance(html, str) and html.strip():
                        txt = _strip_html(html)
                        if txt:
                            spans.append({"page": current_page, "text": txt})
                for v in node.values():
                    walk(v, current_page)
            elif isinstance(node, list):
                for v in node:
                    walk(v, current_page)
        walk(data)
        return spans

    def kb_markdown_tables_find(md_text: str):
        lines = md_text.splitlines()
        i, n = 0, len(lines)
        while i < n:
            if '|' in lines[i]:
                j = i + 1
                if j < n and re.search(r'^\s*\|?\s*:?-{3,}', lines[j]):
                    k = j + 1
                    while k < n and '|' in lines[k] and lines[k].strip():
                        k += 1
                    yield "\n".join(lines[i:k])
                    i = k; continue
            i += 1

    def kb_markdown_table_to_df(table_md: str):
        rows = [r.strip() for r in table_md.strip().splitlines() if r.strip()]
        if len(rows) < 2: return None
        def split_row(r: str):
            r = r.strip()
            if r.startswith('|'): r = r[1:]
            if r.endswith('|'): r = r[:-1]
            return [c.strip() for c in r.split('|')]
        cols = split_row(rows[0])
        if len(split_row(rows[1])) != len(cols): return None
        data = []
        for r in rows[2:]:
            cells = split_row(r)
            if len(cells) < len(cols): cells += [""] * (len(cols) - len(cells))
            if len(cells) > len(cols): cells = cells[:len(cols)]
            data.append(cells)
        try:
            df = _pd.DataFrame(data, columns=cols)
            return kb_coerce_numbers_df(df)
        except Exception:
            return None

    def kb_table_rows_to_sentences(df: _pd.DataFrame, doc_name: str, table_id: int):
        sents = []
        if df.shape[1] == 0: return sents
        label = df.columns[0]
        for ridx, row in df.reset_index(drop=True).iterrows():
            parts = [str(row[label])]
            for c in df.columns[1:]:
                parts.append(f"{c}: {row[c]}")
            sents.append(f"[{doc_name}] table#{table_id} row#{ridx} :: " + " | ".join(parts))
        return sents

    def kb_table_signature(df: _pd.DataFrame) -> str:
        try:
            cols = [str(c).strip() for c in df.columns]
            first_col = cols[0] if cols else ""
            years = sorted({c for c in cols if re.fullmatch(r"\d{4}", str(c))})
            nums = []
            for c in df.columns:
                s = _pd.to_numeric(_pd.Series(df[c]).astype(str).str.replace(",", "", regex=False), errors="coerce")
                vals = [float(x) for x in s.dropna().tolist()]
                nums.extend(vals)
            nums = [round(x, 3) for x in nums[:8]]
            return "|".join([
                f"first:{first_col.lower()}",
                "years:" + ",".join(years),
                "nums:" + ",".join(map(str, nums))
            ])
        except Exception:
            return ""

    def kb_encode(texts, model_name):
        model = _ST(model_name)
        embs = model.encode(texts, batch_size=64, show_progress_bar=True, normalize_embeddings=True)
        return _np.asarray(embs, dtype="float32")

    def kb_build_faiss(embs):
        d = int(embs.shape[1])
        idx = faiss.IndexFlatIP(d)  # cosine via normalized inner product
        idx.add(embs)
        return idx

    def kb_discover_docs(in_dir: _Path):
        docs = {}
        for f in sorted(in_dir.iterdir()):
            if not f.is_dir():
                continue
            nested = f / f.name
            md = list(f.glob("*.md")) + (list(nested.glob("*.md")) if nested.is_dir() else [])
            js = list(f.glob("*.json")) + (list(nested.glob("*.json")) if nested.is_dir() else [])
            jl = list(f.glob("*.jsonl")) + (list(nested.glob("*.jsonl")) if nested.is_dir() else [])
            if md or js or jl:
                docs[f.name] = {"md": sorted(md), "json": sorted(js), "jsonl": sorted(jl), "root": f}
        return docs

    def kb_load_jsonl(path: _Path) -> list:
        rows = []
        try:
            with open(path, "r", encoding="utf-8") as f:
                for line in f:
                    s = line.strip()
                    if not s:
                        continue
                    try:
                        rows.append(json.loads(s))
                    except Exception:
                        continue
        except Exception:
            return []
        return rows

    def kb_chunk_text(text: str, max_chars: int = 1600, overlap: int = 200):
        if not text: return []
        paras = [p.strip() for p in re.split(r"\n\s*\n", text) if p.strip()]
        chunks, buf, cur = [], [], 0
        def flush():
            nonlocal buf, cur
            if not buf: return
            s = "\n\n".join(buf).strip()
            step = max_chars - overlap
            for i in range(0, len(s), step):
                piece = s[i:i+step].strip()
                if piece: chunks.append(piece)
            buf.clear(); cur = 0
        for p in paras:
            if cur + len(p) + 2 <= max_chars:
                buf.append(p); cur += len(p) + 2
            else:
                flush(); buf.append(p); cur = len(p)
        flush(); return chunks

    def build_kb_with_tables(
        in_dir=KB_IN_DIR,
        out_dir=KB_OUT_DIR,
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        max_chars=1600,
        overlap=200,
    ):
        in_path, out_path = _Path(in_dir), _Path(out_dir)
        out_path.mkdir(parents=True, exist_ok=True)

        kb_parquet     = out_path / "kb_chunks.parquet"
        kb_texts_npy   = out_path / "kb_texts.npy"
        kb_meta_json   = out_path / "kb_meta.json"
        kb_index_path  = out_path / "kb_index.faiss"
        kb_index_meta  = out_path / "kb_index_meta.json"
        kb_tables_parq = out_path / "kb_tables.parquet"
        kb_outline_parq = out_path / "kb_outline.parquet"

        cache = {}
        if kb_meta_json.exists():
            try:
                cache = json.loads(kb_meta_json.read_text(encoding="utf-8"))
            except Exception:
                cache = {}

        docs = kb_discover_docs(in_path)
        if not docs:
            print(f"‚ÑπÔ∏è No Marker artefacts found under: {in_path}")
            return {"docs_processed": 0, "chunks_total": 0, "tables_long_rows": 0, "paths": {}}
        print(f"üîé Found {len(docs)} docs under {in_path}")

        # outlines (optional)
        outline_rows = []
        for doc_name, art in docs.items():
            root = art.get("root", in_path / doc_name)
            candidates = list(root.glob("*_meta.json"))
            nested_same = root / doc_name
            if nested_same.is_dir():
                candidates += list(nested_same.glob("*_meta.json"))
            for meta_path in candidates:
                try:
                    data = json.loads(kb_safe_read(meta_path))
                    toc = data.get("table_of_contents") or data.get("toc") or []
                    for i, item in enumerate(toc):
                        outline_rows.append({
                            "doc_name": doc_name,
                            "source_path": str(meta_path),
                            "order": int(i),
                            "title": item.get("title"),
                            "page_id": item.get("page_id"),
                            "polygon": item.get("polygon"),
                        })
                except Exception:
                    pass
        if outline_rows:
            _pd.DataFrame(outline_rows).to_parquet(kb_outline_parq, engine="pyarrow", index=False)
            print(f"üìë Saved outline ‚Üí {kb_outline_parq} (rows={len(outline_rows)})")
        else:
            print("‚ÑπÔ∏è No *_meta.json outlines found.")

        rows_meta, chunk_texts = [], []
        tables_long = []
        json_sig_to_page = {}
        changed_any = False

        for name, art in _tqdm(docs.items(), desc="Processing docs"):
            md_files, json_files = art["md"], art["json"]
            jsonl_files = art.get("jsonl", [])
            keys = [kb_file_hash_key(p) for p in (md_files + json_files + jsonl_files)]
            doc_key = hashlib.md5("|".join(keys).encode()).hexdigest()

            if cache.get(name, {}).get("cache_key") == doc_key:
                continue
            changed_any = True

            # 1) JSON ‚Üí tables + page-text
            table_id = 0
            for jp in json_files:
                jtxt = kb_safe_read(jp)
                # tables with page capture
                for tb in kb_extract_tables_from_marker_json_blocks(jtxt):
                    df = tb["df"]; page_no = tb.get("page")
                    try:
                        sig = kb_table_signature(df)
                        if page_no is not None and sig:
                            json_sig_to_page[sig] = int(page_no)
                    except Exception:
                        pass
                    for sent in kb_table_rows_to_sentences(df, name, table_id):
                        if page_no is not None:
                            sent = f"[page {page_no}] " + sent
                        rows_meta.append({
                            "doc": name, "path": str(jp), "modality": "table_row",
                            "chunk": len(chunk_texts), "cache_key": doc_key,
                            "page": int(page_no) if page_no is not None else None,
                        })
                        chunk_texts.append(sent)
                    for ridx, row in df.reset_index(drop=True).iterrows():
                        for col in df.columns:
                            _val = row[col]
                            _val_str = "" if _pd.isna(_val) else str(_val)
                            try:
                                _val_num = _pd.to_numeric(_val_str.replace(",", ""), errors="coerce")
                            except Exception:
                                _val_num = _np.nan
                            tables_long.append({
                                "doc_name": name, "source_path": str(jp), "table_id": table_id,
                                "row_id": int(ridx), "column": str(col),
                                "value_str": _val_str,
                                "value_num": float(_val_num) if _pd.notna(_val_num) else None,
                                "page": int(page_no) if page_no is not None else None,
                            })
                    table_id += 1
                # page narrative
                spans = kb_extract_text_spans_with_pages(jtxt)
                by_page = {}
                for sp in spans:
                    by_page.setdefault(sp.get("page"), []).append(sp["text"])
                for page_no, texts in by_page.items():
                    page_text = kb_strip_md_basic("\n\n".join(texts))
                    for ch in kb_chunk_text(page_text, max_chars, overlap):
                        rows_meta.append({
                            "doc": name, "path": str(jp), "modality": "json",
                            "chunk": len(chunk_texts), "cache_key": doc_key,
                            "page": int(page_no) if page_no is not None else None,
                        })
                        chunk_texts.append(ch)

            # 1b) JSONL (extractor outputs)
            for jlp in jsonl_files:
                records = kb_load_jsonl(jlp)
                if not records:
                    continue
                ctx, data_recs = None, []
                for r in records:
                    if isinstance(r, dict) and "_context" in r:
                        ctx = r.get("_context")
                    elif isinstance(r, dict):
                        data_recs.append(r)
                page_no = None
                if isinstance(ctx, dict):
                    p = ctx.get("page")
                    if isinstance(p, int):
                        page_no = p
                df_jl = None
                if data_recs:
                    try:
                        df_jl = _pd.DataFrame(data_recs)
                        if "_meta" in df_jl.columns:
                            try:
                                df_jl = df_jl.drop(columns=["_meta"])
                            except Exception:
                                pass
                        df_jl = kb_coerce_numbers_df(df_jl)
                    except Exception:
                        df_jl = None
                if df_jl is not None and not df_jl.empty:
                    for sent in kb_table_rows_to_sentences(df_jl, name, table_id):
                        if page_no is not None:
                            sent = f"[page {page_no}] " + sent
                        rows_meta.append({
                            "doc": name, "path": str(jlp), "modality": "jsonl_row",
                            "chunk": len(chunk_texts), "cache_key": doc_key, "page": page_no,
                        })
                        chunk_texts.append(sent)
                    for ridx, row in df_jl.reset_index(drop=True).iterrows():
                        for col in df_jl.columns:
                            _val = row[col]
                            _val_str = "" if _pd.isna(_val) else str(_val)
                            try:
                                _val_num = _pd.to_numeric(_val_str.replace(",", ""), errors="coerce")
                            except Exception:
                                _val_num = _np.nan
                            tables_long.append({
                                "doc_name": name, "source_path": str(jlp), "table_id": table_id,
                                "row_id": int(ridx), "column": str(col),
                                "value_str": _val_str,
                                "value_num": float(_val_num) if _pd.notna(_val_num) else None,
                                "page": page_no,
                            })
                    table_id += 1
                if isinstance(ctx, dict) and isinstance(ctx.get("summary"), str) and ctx["summary"].strip():
                    rows_meta.append({
                        "doc": name, "path": str(jlp), "modality": "jsonl_summary",
                        "chunk": len(chunk_texts), "cache_key": doc_key, "page": page_no,
                    })
                    chunk_texts.append(f"[{name}] {ctx['summary'].strip()}")

            # 2) Markdown ‚Üí tables + non-table text
            for mp in md_files:
                md = kb_safe_read(mp)
                for tblock in kb_markdown_tables_find(md):
                    df = kb_markdown_table_to_df(tblock)
                    if df is None: 
                        continue
                    md_page = None
                    try:
                        md_sig = kb_table_signature(df)
                        if md_sig and md_sig in json_sig_to_page:
                            md_page = int(json_sig_to_page[md_sig])
                    except Exception:
                        md_page = None
                    for sent in kb_table_rows_to_sentences(df, name, table_id):
                        rows_meta.append({
                            "doc": name, "path": str(mp), "modality": "table_row",
                            "chunk": len(chunk_texts), "cache_key": doc_key, "page": md_page
                        })
                        chunk_texts.append(sent)
                    for ridx, row in df.reset_index(drop=True).iterrows():
                        for col in df.columns:
                            _val = row[col]
                            _val_str = "" if _pd.isna(_val) else str(_val)
                            try:
                                _val_num = _pd.to_numeric(_val_str.replace(",", ""), errors="coerce")
                            except Exception:
                                _val_num = _np.nan
                            tables_long.append({
                                "doc_name": name, "source_path": str(mp), "table_id": table_id,
                                "row_id": int(ridx), "column": str(col),
                                "value_str": _val_str,
                                "value_num": float(_val_num) if _pd.notna(_val_num) else None,
                                "page": md_page,
                            })
                    table_id += 1

                md_no_tables = md
                for tblock in kb_markdown_tables_find(md):
                    md_no_tables = md_no_tables.replace(tblock, "")
                for ch in kb_chunk_text(kb_strip_md_basic(md_no_tables), max_chars, overlap):
                    rows_meta.append({"doc": name, "path": str(mp), "modality": "md",
                                      "chunk": len(chunk_texts), "cache_key": doc_key, "page": None})
                    chunk_texts.append(ch)

            added_for_doc = sum(1 for r in rows_meta if r["cache_key"] == doc_key)
            cache[name] = {"cache_key": doc_key, "chunk_count": added_for_doc, "updated_at": int(time.time())}

        # If nothing changed and KB exists ‚Üí keep existing artifacts
        if (not changed_any) and ((out_path/"kb_chunks.parquet").exists()):
            print("‚úÖ No changes detected. Keeping existing KB and FAISS index.")
            texts_existing = _np.load(out_path/"kb_texts.npy", allow_pickle=True)
            return {
                "docs_processed": len(docs),
                "chunks_total": int(len(texts_existing)),
                "tables_long_rows": (_pd.read_parquet(out_path/"kb_tables.parquet").shape[0] if (out_path/"kb_tables.parquet").exists() else 0),
                "paths": {
                    "kb_chunks_parquet": str(out_path/"kb_chunks.parquet"),
                    "kb_texts_npy": str(out_path/"kb_texts.npy"),
                    "kb_meta_json": str(out_path/"kb_meta.json"),
                    "kb_tables_parquet": str(out_path/"kb_tables.parquet") if (out_path/"kb_tables.parquet").exists() else None,
                    "kb_index_faiss": str(out_path/"kb_index.faiss") if (out_path/"kb_index.faiss").exists() else None,
                    "kb_index_meta_json": str(out_path/"kb_index_meta.json") if (out_path/"kb_index_meta.json").exists() else None,
                }
            }

        # Persist KB + tables
        total = len(chunk_texts)
        print(f"üßæ Total new/updated text chunks (incl. table rows): {total}")
        _pd.DataFrame(rows_meta).to_parquet(out_path/"kb_chunks.parquet", engine="pyarrow", index=False)
        _np.save(out_path/"kb_texts.npy", _np.array(chunk_texts, dtype=object))
        if tables_long:
            _pd.DataFrame(tables_long).to_parquet(out_path/"kb_tables.parquet", engine="pyarrow", index=False)
            print(f"üìë Saved structured tables ‚Üí {out_path / 'kb_tables.parquet'} (rows={len(tables_long)})")
        else:
            print("üìë No structured tables detected this run.")
        (out_path/"kb_meta.json").write_text(json.dumps(cache, indent=2), encoding="utf-8")

        if total == 0:
            print("‚ö†Ô∏è No new chunks produced. Skipping embedding/index rebuild.")
            return {"docs_processed": len(docs), "chunks_total": 0, "tables_long_rows": len(tables_long), "paths": {}}

        # Embeddings + FAISS
        print("üß† Encoding embeddings ‚Ä¶")
        embs = kb_encode(chunk_texts, model_name)
        print(f"‚úÖ Embeddings shape: {embs.shape}")
        print("üì¶ Building FAISS index ‚Ä¶")
        idx = kb_build_faiss(embs)
        faiss.write_index(idx, str(out_path/"kb_index.faiss"))
        (out_path/"kb_index_meta.json").write_text(json.dumps({
            "model": model_name, "dim": int(embs.shape[1]), "total_vectors": int(embs.shape[0]),
            "metric": "cosine (via inner product on normalized vectors)",
        }, indent=2), encoding="utf-8")
        print(f"üéâ KB + index saved to: {out_path}")
        return {"docs_processed": len(docs), "chunks_total": int(total), "tables_long_rows": len(tables_long)}

    # ---- execute inline build ----
    print("\nüöÄ Building KB/index from extracted artifacts (JSON/MD/JSONL)‚Ä¶")
    _summary = build_kb_with_tables()
    print(_summary)
    print("‚úÖ KB build completed.")
except Exception as _e:
    print(f"‚ùå Inline KB build failed: {_e}")

--- Processing file: 2Q24_performance_summary.pdf ---
‚è≠Ô∏è  Skipping 2Q24_performance_summary.pdf: up-to-date (md5 match). ‚Üí /Users/marcusfoo/Documents/GitHub/PTO_ICT3113_Grp1/All/2Q24_performance_summary
--- Processing file: 3Q24_CEO_presentation.pdf ---
‚è≠Ô∏è  Skipping 3Q24_CEO_presentation.pdf: up-to-date (md5 match). ‚Üí /Users/marcusfoo/Documents/GitHub/PTO_ICT3113_Grp1/All/3Q24_CEO_presentation
--- Processing file: 4Q24_CFO_presentation.pdf ---
‚è≠Ô∏è  Skipping 4Q24_CFO_presentation.pdf: up-to-date (md5 match). ‚Üí /Users/marcusfoo/Documents/GitHub/PTO_ICT3113_Grp1/All/4Q24_CFO_presentation
--- Processing file: 4Q24_performance_summary.pdf ---
‚è≠Ô∏è  Skipping 4Q24_performance_summary.pdf: up-to-date (md5 match). ‚Üí /Users/marcusfoo/Documents/GitHub/PTO_ICT3113_Grp1/All/4Q24_performance_summary
--- Processing file: 4Q24_CEO_presentation.pdf ---
‚è≠Ô∏è  Skipping 4Q24_CEO_presentation.pdf: up-to-date (md5 match). ‚Üí /Users/marcusfoo/Documents/GitHub/PTO_ICT3113_Grp1/All/4Q24


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.2[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m



üöÄ Building KB/index from extracted artifacts (JSON/MD/JSONL)‚Ä¶
üîé Found 24 docs under /Users/marcusfoo/Documents/GitHub/PTO_ICT3113_Grp1/All
üìë Saved outline ‚Üí /Users/marcusfoo/Documents/GitHub/PTO_ICT3113_Grp1/data_marker/kb_outline.parquet (rows=3325)


Processing docs: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 24/24 [00:06<00:00,  3.55it/s]


üßæ Total new/updated text chunks (incl. table rows): 13587
üìë Saved structured tables ‚Üí /Users/marcusfoo/Documents/GitHub/PTO_ICT3113_Grp1/data_marker/kb_tables.parquet (rows=52949)
üß† Encoding embeddings ‚Ä¶


Batches: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 213/213 [00:30<00:00,  6.99it/s]


‚úÖ Embeddings shape: (13587, 384)
üì¶ Building FAISS index ‚Ä¶
üéâ KB + index saved to: /Users/marcusfoo/Documents/GitHub/PTO_ICT3113_Grp1/data_marker
{'docs_processed': 24, 'chunks_total': 13587, 'tables_long_rows': 52949}
‚úÖ KB build completed.


### Check Data 

In [13]:
import pandas as pd
import re
import numpy as np

# ==========================================
# 1. LOAD & PREPARE DATA
# ==========================================
try:
    df = pd.read_parquet("./data_marker/kb_tables.parquet")
    print(f"üì¶ Loaded {len(df)} rows from kb_tables.parquet")
except FileNotFoundError:
    print("‚ùå kb_tables.parquet not found.")
    df = pd.DataFrame()

if not df.empty:
    print("üîÑ Pivoting data to reconstruct table rows...")
    wide_df = df.pivot_table(
        index=['doc_name', 'table_id', 'row_id'], 
        columns='column', 
        values='value_str', 
        aggfunc='first'
    ).reset_index()
    print(f"‚úÖ Reconstructed {len(wide_df)} unique table rows.")
else:
    wide_df = pd.DataFrame()

# ==========================================
# 2. EXTRACTION LOGIC (V3)
# ==========================================
def _norm(s):
    return str(s).lower().strip()

def get_row_text(row):
    return " ".join([str(x) for x in row.values if x is not None]).lower()

def extract_metric_series(target_label_keywords, year_cols=True):
    matches = []
    if wide_df.empty: return pd.DataFrame()

    all_cols = wide_df.columns.astype(str)
    year_columns = [c for c in all_cols if re.match(r'^\d{4}$', c)]
    
    for idx, row in wide_df.iterrows():
        doc_name = row['doc_name']
        row_text = get_row_text(row)
        
        # --- STRATEGY A: Financials (Years) ---
        if year_cols:
            if any(k in row_text for k in target_label_keywords):
                data = {}
                for y in year_columns:
                    if y in row and pd.notna(row[y]):
                        val_str = str(row[y]).replace(',', '')
                        try:
                            data[int(y)] = float(val_str)
                        except ValueError:
                            pass
                if data:
                    label_guess = "Unknown"
                    for val in row.values:
                        if pd.notna(val) and any(k in _norm(val) for k in target_label_keywords):
                            label_guess = val
                            break
                    matches.append({"Doc": doc_name, "Label": label_guess, "Data": data})

        # --- STRATEGY B: NIM/Metrics (Quarters) ---
        else:
            valid_items = row.dropna().to_dict()
            q_candidates = [k for k in valid_items.keys() if 'quarter' in str(k).lower()]
            q_candidates.sort(key=len)
            
            if q_candidates:
                q_col = q_candidates[0]
                quarter_val = valid_items[q_col]
                for col_name, cell_val in valid_items.items():
                    if col_name == q_col: continue
                    # Check column header for keyword
                    if any(k in _norm(col_name) for k in target_label_keywords):
                        matches.append({
                            "Doc": doc_name, 
                            "Label": col_name, 
                            "Data": {quarter_val: cell_val}
                        })

    return pd.DataFrame(matches)

# ==========================================
# 3. RUN EXTRACTORS (RAW DATA DUMP)
# ==========================================
print("\n" + "="*60)
print("üîé RAW DATA VERIFICATION (NO CALCULATIONS)")
print("="*60)

# --- 1. OPERATING EXPENSES ---
print("\nüîπ RAW PULL: Operating Expenses")
df_opex = extract_metric_series(["total expenses", "operating expenses"], year_cols=True)

if not df_opex.empty:
    # Filter to show only meaningful rows (>= 2 data points)
    df_opex_clean = df_opex[df_opex['Data'].apply(len) >= 2]
    for _, r in df_opex_clean.iterrows():
        # Sort data keys for cleaner viewing
        sorted_data = dict(sorted(r['Data'].items()))
        print(f"üìÑ {r['Doc']} | Label: {r['Label']}")
        print(f"   üìä {sorted_data}")
        print("-" * 20)
else:
    print("‚ùå No Opex data found.")

# --- 2. OPERATING INCOME ---
print("\nüîπ RAW PULL: Operating Income")
df_inc = extract_metric_series(["total income", "operating income"], year_cols=True)

if not df_inc.empty:
    df_inc_clean = df_inc[df_inc['Data'].apply(len) >= 2]
    for _, r in df_inc_clean.iterrows():
        sorted_data = dict(sorted(r['Data'].items()))
        print(f"üìÑ {r['Doc']} | Label: {r['Label']}")
        print(f"   üìä {sorted_data}")
        print("-" * 20)
else:
    print("‚ùå No Income data found.")

# --- 3. NET INTEREST MARGIN ---
print("\nüîπ RAW PULL: Net Interest Margin (Quarterly)")
df_nim = extract_metric_series(["nim", "net interest margin"], year_cols=False)

if not df_nim.empty:
    # Group by Doc to show the timeline extracted per document
    for doc, group in df_nim.groupby("Doc"):
        timeline = {}
        for _, r in group.iterrows():
            # Only keep keys that look like quarters (contain 'q')
            valid_q = {k:v for k,v in r['Data'].items() if 'q' in str(k).lower()}
            timeline.update(valid_q)
        
        if timeline:
            sorted_timeline = dict(sorted(timeline.items()))
            print(f"üìÑ {doc}")
            print(f"   üìà Raw Series: {sorted_timeline}")
            print("-" * 20)
else:
    print("‚ùå No NIM data found.")

üì¶ Loaded 52949 rows from kb_tables.parquet
üîÑ Pivoting data to reconstruct table rows...
‚úÖ Reconstructed 10332 unique table rows.

üîé RAW DATA VERIFICATION (NO CALCULATIONS)

üîπ RAW PULL: Operating Expenses
üìÑ dbs-annual-report-2022 | Label: Total expenses
   üìä {2021: 6569.0, 2022: 7090.0}
--------------------
üìÑ dbs-annual-report-2022 | Label: Total expenses
   üìä {2021: 6569.0, 2022: 7090.0}
--------------------
üìÑ dbs-annual-report-2023 | Label: Total expenses
   üìä {2022: 7090.0, 2023: 8291.0}
--------------------
üìÑ dbs-annual-report-2023 | Label: Total expenses
   üìä {2022: 7090.0, 2023: 8291.0}
--------------------
üìÑ dbs-annual-report-2024 | Label: Total expenses
   üìä {2023: 8291.0, 2024: 9018.0}
--------------------
üìÑ dbs-annual-report-2024 | Label: Total expenses
   üìä {2023: 8291.0, 2024: 9018.0}
--------------------

üîπ RAW PULL: Operating Income
üìÑ dbs-annual-report-2022 | Label: Total income
   üìä {2021: 14188.0, 2022: 16502.0}


### Check Data 2

In [12]:
import pandas as pd
import re
import numpy as np

# ==========================================
# 1. LOAD & PREPARE DATA
# ==========================================
try:
    df = pd.read_parquet("./data_marker/kb_tables.parquet")
    print(f"üì¶ Loaded {len(df)} rows from kb_tables.parquet")
except FileNotFoundError:
    print("‚ùå kb_tables.parquet not found.")
    df = pd.DataFrame()

if not df.empty:
    print("üîÑ Pivoting data to reconstruct table rows...")
    wide_df = df.pivot_table(
        index=['doc_name', 'table_id', 'row_id'], 
        columns='column', 
        values='value_str', 
        aggfunc='first'
    ).reset_index()
    print(f"‚úÖ Reconstructed {len(wide_df)} unique table rows.")
else:
    wide_df = pd.DataFrame()

# ==========================================
# 2. EXTRACTION LOGIC (V3)
# ==========================================
def _norm(s):
    return str(s).lower().strip()

def get_row_text(row):
    return " ".join([str(x) for x in row.values if x is not None]).lower()

def extract_metric_series(target_label_keywords, year_cols=True):
    matches = []
    if wide_df.empty: return pd.DataFrame()

    all_cols = wide_df.columns.astype(str)
    year_columns = [c for c in all_cols if re.match(r'^\d{4}$', c)]
    
    for idx, row in wide_df.iterrows():
        doc_name = row['doc_name']
        row_text = get_row_text(row)
        
        # --- STRATEGY A: Financials (Years) ---
        if year_cols:
            if any(k in row_text for k in target_label_keywords):
                data = {}
                for y in year_columns:
                    if y in row and pd.notna(row[y]):
                        val_str = str(row[y]).replace(',', '')
                        try:
                            data[int(y)] = float(val_str)
                        except ValueError:
                            pass
                if data:
                    label_guess = "Unknown"
                    for val in row.values:
                        if pd.notna(val) and any(k in _norm(val) for k in target_label_keywords):
                            label_guess = val
                            break
                    matches.append({"Doc": doc_name, "Label": label_guess, "Data": data})

        # --- STRATEGY B: NIM/Metrics (Quarters) ---
        else:
            valid_items = row.dropna().to_dict()
            q_candidates = [k for k in valid_items.keys() if 'quarter' in str(k).lower()]
            q_candidates.sort(key=len)
            
            if q_candidates:
                q_col = q_candidates[0]
                quarter_val = valid_items[q_col]
                for col_name, cell_val in valid_items.items():
                    if col_name == q_col: continue
                    if any(k in _norm(col_name) for k in target_label_keywords):
                        matches.append({
                            "Doc": doc_name, 
                            "Label": col_name, 
                            "Data": {quarter_val: cell_val}
                        })

    return pd.DataFrame(matches)

# ==========================================
# 3. RUN EXTRACTORS
# ==========================================
print("\nüîé Extracting raw data...")
df_opex = extract_metric_series(["total expenses", "operating expenses"], year_cols=True)
df_inc = extract_metric_series(["total income", "operating income"], year_cols=True)
df_nim = extract_metric_series(["nim", "net interest margin"], year_cols=False)

print(f"   found {len(df_opex)} Opex rows")
print(f"   found {len(df_inc)} Income rows")
print(f"   found {len(df_nim)} NIM rows")

# ==========================================
# 4. GENERATE BENCHMARK ANSWERS (FIXED)
# ==========================================
print("\n" + "="*60)
print("üìù FINAL BENCHMARK OUTPUTS")
print("="*60)

# Helper to merge multiple rows into one timeline (taking MAX value to find Group Total)
def merge_financial_series(df_source):
    master_timeline = {}
    if df_source.empty: return master_timeline
    
    for _, row in df_source.iterrows():
        for year, val in row['Data'].items():
            # If we have multiple values for 2023 (e.g. Segment vs Group), take the larger one
            current_max = master_timeline.get(year, 0)
            if val > current_max:
                master_timeline[year] = val
    return master_timeline

# --- ANSWER 1: NIM TREND ---
print("\nüîπ 1. Gross Margin / NIM Trend (Last 5 Quarters)")
if not df_nim.empty:
    timeline = {}
    for _, r in df_nim.iterrows():
        # Clean keys
        clean_q = {k:v for k,v in r['Data'].items() if 'q' in str(k).lower()}
        timeline.update(clean_q)
    
    def q_sorter(q):
        m = re.match(r'([1-4])q(\d{2})', str(q).lower())
        if m: return int(f"20{m.group(2)}{m.group(1)}")
        return 0

    if timeline:
        sorted_qs = sorted(timeline.keys(), key=q_sorter)
        display_qs = sorted_qs[-5:] if len(sorted_qs) > 5 else sorted_qs
        
        print(f"| Quarter | Group NIM (%) |")
        print(f"|:-------:|:-------------:|")
        for q in display_qs:
            print(f"| {q:<7} | {timeline[q]:<13} |")
    else:
        print("‚ùå No valid quarterly keys found.")
else:
    print("‚ùå No NIM data found.")


# --- ANSWER 2: OPEX YoY ---
print("\nüîπ 2. Operating Expenses (3 Years YoY)")
opex_data = merge_financial_series(df_opex)

if opex_data:
    print(f"| Year | Opex ($m) | YoY Change (%) |")
    print(f"|:----:|:---------:|:--------------:|")
    
    years = sorted(opex_data.keys())[-3:] # Last 3 years available
    for i, y in enumerate(years):
        val = opex_data[y]
        yoy_str = "-"
        
        if (y - 1) in opex_data:
            prev_val = opex_data[y-1]
            if prev_val != 0:
                pct = ((val - prev_val) / prev_val) * 100
                yoy_str = f"{pct:+.1f}%"
        
        print(f"| {y:<4} | {val:<9,.1f} | {yoy_str:<14} |")
else:
    print("‚ùå No Opex data found.")


# --- ANSWER 3: EFFICIENCY RATIO ---
print("\nüîπ 3. Operating Efficiency Ratio (Opex √∑ Income)")
inc_data = merge_financial_series(df_inc)

if opex_data and inc_data:
    common_years = sorted(set(opex_data.keys()) & set(inc_data.keys()))[-3:]
    
    print(f"| Year | Opex ($m) | Income ($m) | Efficiency Ratio (%) |")
    print(f"|:----:|:---------:|:-----------:|:--------------------:|")
    
    for y in common_years:
        opex = opex_data[y]
        inc = inc_data[y]
        ratio = (opex / inc) * 100 if inc != 0 else 0
        print(f"| {y:<4} | {opex:<9,.1f} | {inc:<11,.1f} | {ratio:<20.1f} |")
    
    print(f"\n*Calculation: (Operating Expenses / Total Income) * 100")
else:
    print("‚ùå Insufficient data (need both Opex and Income).")

üì¶ Loaded 52949 rows from kb_tables.parquet
üîÑ Pivoting data to reconstruct table rows...
‚úÖ Reconstructed 10332 unique table rows.

üîé Extracting raw data...
   found 6 Opex rows
   found 19 Income rows
   found 64 NIM rows

üìù FINAL BENCHMARK OUTPUTS

üîπ 1. Gross Margin / NIM Trend (Last 5 Quarters)
| Quarter | Group NIM (%) |
|:-------:|:-------------:|
| 2Q24    | 2.14          |
| 3Q24    | 2.11          |
| 4Q24    | 2.15          |
| 1Q25    | 2.12          |
| 2Q25    | 2.05          |

üîπ 2. Operating Expenses (3 Years YoY)
| Year | Opex ($m) | YoY Change (%) |
|:----:|:---------:|:--------------:|
| 2022 | 7,090.0   | +7.9%          |
| 2023 | 8,291.0   | +16.9%         |
| 2024 | 9,018.0   | +8.8%          |

üîπ 3. Operating Efficiency Ratio (Opex √∑ Income)
| Year | Opex ($m) | Income ($m) | Efficiency Ratio (%) |
|:----:|:---------:|:-----------:|:--------------------:|
| 2022 | 7,090.0   | 16,502.0    | 43.0                 |
| 2023 | 8,291.0   | 20,180.0  

## 4. Baseline Pipeline

**Baseline (starting point)**
*   Naive chunking.
*   Single-pass vector search.
*   One LLM call, no caching.

## Async I/O 


In [8]:
pip install aiohttp


Note: you may need to restart the kernel to use updated packages.


You should consider upgrading via the 'c:\Users\PC\AppData\Local\Programs\Python\Python310\python.exe -m pip install --upgrade pip' command.


In [None]:
import asyncio
try:
    import nest_asyncio
except ImportError:
    nest_asyncio = None

def ensure_loop():
    try:
        loop = asyncio.get_event_loop()
        if loop.is_running() and nest_asyncio:
            nest_asyncio.apply()
        return loop
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        return loop

# 2) Async retrieval shim (FAISS/BM25 via thread pool)
from concurrent.futures import ThreadPoolExecutor
_search_pool = ThreadPoolExecutor(max_workers=8)

class AsyncRetrieval:
    @staticmethod
    async def execute_parallel_async(kb, sub_queries, k_ctx: int, max_concurrency: int = 8):
        print(f"Running async parallel retrieval for {len(sub_queries)} sub-queries with concurrency={max_concurrency}")
        if not sub_queries:
            return []
        loop = asyncio.get_event_loop()
        sem = asyncio.Semaphore(max_concurrency)

        async def one(q: str):
            async with sem:
                return await loop.run_in_executor(_search_pool, kb.search, q, k_ctx)

        results = await asyncio.gather(*(one(q) for q in sub_queries), return_exceptions=True)
        import pandas as pd
        safe = []
        for r in results:
            safe.append(pd.DataFrame() if isinstance(r, Exception) else r)
        return safe

    @staticmethod
    def execute_parallel(kb, sub_queries, k_ctx: int, max_concurrency: int = 8):
        loop = ensure_loop()
        return loop.run_until_complete(
            AsyncRetrieval.execute_parallel_async(kb, sub_queries, k_ctx, max_concurrency)
        )

# 3) Wire into your decomposer if present; else, provide a tiny adapter
def execute_parallel_subqueries(kb, sub_queries, k_ctx, max_concurrency=8):
    return AsyncRetrieval.execute_parallel(kb, sub_queries, k_ctx, max_concurrency)

# 4) Async HTTP clients (LLM + Embeddings) with sync adapters
import aiohttp
from typing import List, Dict, Any

class AsyncLLMClient:
    def __init__(self, base_url: str, api_key: str, max_concurrent: int = 8, timeout_s: int = 60):
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.sem = asyncio.Semaphore(max_concurrent)
        self.timeout_s = timeout_s

    async def chat(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        print("Async LLM chat API call started")
        headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key else {}
        async with self.sem:
            async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=self.timeout_s)) as sess:
                async with sess.post(f"{self.base_url}/chat/completions", json=payload, headers=headers) as r:
                    r.raise_for_status()
                    print("Async LLM chat API call completed")
                    return await r.json()

class AsyncEmbeddingsClient:
    def __init__(self, base_url: str, api_key: str, batch_size: int = 64, max_concurrent: int = 4, timeout_s: int = 60):
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.batch_size = batch_size
        self.sem = asyncio.Semaphore(max_concurrent)
        self.timeout_s = timeout_s

    async def embed(self, texts: List[str]) -> List[List[float]]:
        headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key else {}
        out: List[List[float]] = []
        async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=self.timeout_s)) as sess:
            tasks = []
            for i in range(0, len(texts), self.batch_size):
                chunk = texts[i:i+self.batch_size]
                async def one(ch=chunk):
                    async with self.sem:
                        async with sess.post(f"{self.base_url}/embeddings", json={"input": ch}, headers=headers) as r:
                            r.raise_for_status()
                            data = await r.json()
                            return [v["embedding"] for v in data["data"]]
                tasks.append(one())
            for res in await asyncio.gather(*tasks, return_exceptions=True):
                if isinstance(res, Exception):
                    continue
                out.extend(res)
        return out

# 5) Provide sync adapters so the rest of the notebook doesn‚Äôt break
import os
LLM_ASYNC = AsyncLLMClient(base_url=os.environ.get("OPENAI_BASE_URL","https://api.openai.com/v1"),
                           api_key=os.environ.get("OPENAI_API_KEY",""),
                           max_concurrent=8)

EMB_ASYNC = AsyncEmbeddingsClient(base_url=os.environ.get("OPENAI_BASE_URL","https://api.openai.com/v1"),
                                  api_key=os.environ.get("OPENAI_API_KEY",""),
                                  batch_size=64, max_concurrent=4)

def llm_chat_sync(payload: dict) -> dict:
    loop = ensure_loop()
    return loop.run_until_complete(LLM_ASYNC.chat(payload))

def embed_sync(texts: List[str]) -> List[List[float]]:
    loop = ensure_loop()
    return loop.run_until_complete(EMB_ASYNC.embed(texts))




### Gemini Version 2

In [20]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
g2x.py ‚Äî Agentic RAG with tools on top of data_marker/ (FAISS + Marker outputs)
       - BM25, Reciprocal Rank Fusion, and Cross-Encoder Reranking

Artifacts required in ./data_marker:
  - kb_index.faiss
  - kb_index_meta.json
  - kb_texts.npy
  - kb_chunks.parquet
  - kb_tables.parquet        (recommended for table tools)
  - kb_outline.parquet       (optional, for section hints)

Tools exposed:
  1) CalculatorTool           -> safe arithmetic, deltas, YoY
  2) TableExtractionTool      -> pull metric rows; extract {year -> value}
  3) MultiDocCompareTool      -> compare a metric across multiple docs
Also:
  - Vector search (FAISS) for grounding

Agent runtime: Plan -> Act -> Observe -> (optional) Refine -> Final
"""

from pathlib import Path
from dataclasses import dataclass
from rank_bm25 import BM25Okapi
from sentence_transformers import CrossEncoder, SentenceTransformer
from typing import List, Dict, Any, Optional, Tuple
from openai import OpenAI
from concurrent.futures import ThreadPoolExecutor

import re, json, math, ast
import numpy as np
import pandas as pd
import asyncio
import faiss
import os

# ----------------------------- LLM (single-call baseline) -----------------------------

def _make_llm_client():
    """Minimal provider selection for LLM"""
    groq_key = os.environ.get("GROQ_API_KEY")
    if groq_key:
        client = OpenAI(api_key=groq_key, base_url="https://api.groq.com/openai/v1")
        model = os.getenv("GROQ_MODEL", "openai/gpt-oss-20b")
        return ("groq", client, model)
    
    gem_key = os.environ.get("GEMINI_API_KEY")
    if gem_key:
        return ("gemini", None, os.getenv("GEMINI_MODEL_NAME", "models/gemini-2.5-flash"))
    
    raise RuntimeError("No LLM credentials found. Set GROQ_API_KEY or GEMINI_API_KEY.")

def _llm_provider_info() -> str:
    try:
        prov, _, model = _make_llm_client()
        return f"{prov}:{model}"
    except Exception as e:
        return f"unconfigured ({e})"

def _llm_single_call(prompt: str, system: str = "You are a precise finance analyst.") -> str:
    prov, client, model = _make_llm_client()
    print(f"[LLM] provider={prov} model={model}")
    if prov == "groq":
        try:
            chat = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": system},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.1,
            )
            return chat.choices[0].message.content.strip()
        except Exception as e:
            return f"LLM error: {e}"
    
    try:
        from google import generativeai as genai
        genai.configure(api_key=os.environ["GEMINI_API_KEY"])
        model_obj = genai.GenerativeModel(model)
        out = model_obj.generate_content(prompt)
        return getattr(out, "text", "") or "LLM returned empty response."
    except Exception as e:
        return f"LLM error (Gemini): {e}"


def _page_or_none(x):
    try:
        import math
        import pandas as pd
        if x is None:
            return None
        if (hasattr(pd, 'isna') and pd.isna(x)) or (isinstance(x, float) and math.isnan(x)):
            return None
        return int(x)
    except Exception:
        return None


# ----------------------------- KB loader with BM25 + Reranker -----------------------------

class KBEnv:
    def __init__(self, base="./data_marker", enable_bm25=True, enable_reranker=True):
        self.base = Path(base)
        self.faiss_path = self.base / "kb_index.faiss"
        self.meta_path = self.base / "kb_index_meta.json"
        self.texts_path = self.base / "kb_texts.npy"
        self.chunks_path = self.base / "kb_chunks.parquet"
        self.tables_path = self.base / "kb_tables.parquet"
        self.outline_path = self.base / "kb_outline.parquet"

        if not self.faiss_path.exists():
            raise FileNotFoundError(self.faiss_path)
        if not self.meta_path.exists():
            raise FileNotFoundError(self.meta_path)
        if not self.texts_path.exists():
            raise FileNotFoundError(self.texts_path)
        if not self.chunks_path.exists():
            raise FileNotFoundError(self.chunks_path)

        self.texts: List[str] = np.load(self.texts_path, allow_pickle=True).tolist()
        self.meta_df: pd.DataFrame = pd.read_parquet(self.chunks_path)
        
        if 'page' in self.meta_df.columns:
            self.meta_df['page'] = pd.to_numeric(self.meta_df['page'], errors='coerce').astype('Int64')
            
        if len(self.texts) != len(self.meta_df):
            raise ValueError(f"texts ({len(self.texts)}) and meta ({len(self.meta_df)}) mismatch")

        self.tables_df: Optional[pd.DataFrame] = (
            pd.read_parquet(self.tables_path) if self.tables_path.exists() else None
        )
        self.outline_df: Optional[pd.DataFrame] = (
            pd.read_parquet(self.outline_path) if self.outline_path.exists() else None
        )

        # FAISS index
        self.index = faiss.read_index(str(self.faiss_path))
        idx_meta = json.loads(self.meta_path.read_text(encoding="utf-8"))
        self.model_name = idx_meta.get("model", "sentence-transformers/all-MiniLM-L6-v2")
        self.embed_dim = int(idx_meta.get("dim", 384))
        self.model = SentenceTransformer(self.model_name)

        # ========== NEW: BM25 Index ==========
        self.bm25 = None
        if enable_bm25:
            # print("[BM25] Building BM25 index...")
            tokenized_corpus = [text.lower().split() for text in self.texts]
            self.bm25 = BM25Okapi(tokenized_corpus)
            print(f"[BM25] ‚úì Indexed {len(self.texts)} documents")
        elif enable_bm25:
            print("[BM25] ‚úó rank_bm25 not installed, skipping BM25")

        # ========== NEW: Reranker ==========
        self.reranker = None
        if enable_reranker:
            # print("[Reranker] Loading cross-encoder...")
            self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
            print("[Reranker] ‚úì Loaded cross-encoder/ms-marco-MiniLM-L-6-v2")
        elif enable_reranker:
            print("[Reranker] ‚úó CrossEncoder unavailable")

    def _embed(self, texts: List[str]) -> np.ndarray:
        v = self.model.encode(texts, normalize_embeddings=True)
        return np.asarray(v, dtype="float32")

    # ========== NEW: Hybrid Search with BM25 + Vector + RRF ==========
    def search(
        self, 
        query: str, 
        k: int = 12,
        alpha: float = 0.6,  # Weight for vector vs BM25 (0.0=pure BM25, 1.0=pure vector)
        rerank_top_k: int = None  # Rerank top candidates (default: 2*k)
    ) -> pd.DataFrame:
        """
        Hybrid search with BM25 + Vector + optional RRF + optional Reranking
        
        Pipeline:
        1. BM25 search ‚Üí get scores
        2. Vector search ‚Üí get scores
        3. Fusion: RRF (reciprocal rank) or weighted score fusion
        4. Rerank: Cross-encoder on top candidates
        5. Return top-k
        """
        if rerank_top_k is None:
            rerank_top_k = k * 2  # Get 2x candidates for reranking

        # ========== Step 1: Vector Search ==========
        qv = self._embed([query])
        vec_scores, vec_idxs = self.index.search(qv, min(rerank_top_k * 2, len(self.texts)))
        vec_idxs, vec_scores = vec_idxs[0], vec_scores[0]
        
        # Filter valid indices
        vec_results = {int(i): float(s) for i, s in zip(vec_idxs, vec_scores) if i >= 0 and i < len(self.texts)}

        # ========== Step 2: BM25 Search ==========
        bm25_results = {}
        if self.bm25 is not None:
            query_tokens = query.lower().split()
            bm25_scores = self.bm25.get_scores(query_tokens)
            
            # Normalize BM25 scores to [0, 1]
            max_bm25 = max(bm25_scores) if len(bm25_scores) > 0 else 1.0
            if max_bm25 > 0:
                bm25_scores = bm25_scores / max_bm25
            
            # Get top candidates
            top_bm25_idx = np.argsort(bm25_scores)[-rerank_top_k * 2:][::-1]
            bm25_results = {int(i): float(bm25_scores[i]) for i in top_bm25_idx if bm25_scores[i] > 0}

        # ========== Step 3: Fusion (RRF or Weighted Score) ==========
        all_indices = set(vec_results.keys()) | set(bm25_results.keys())
        
        if self.bm25 is not None:
            # Reciprocal Rank Fusion
            vec_ranks = {idx: rank for rank, idx in enumerate(sorted(vec_results, key=vec_results.get, reverse=True), 1)}
            bm25_ranks = {idx: rank for rank, idx in enumerate(sorted(bm25_results, key=bm25_results.get, reverse=True), 1)}
            
            k_rrf = 60  # RRF constant
            fused_scores = {}
            for idx in all_indices:
                vec_rank = vec_ranks.get(idx, len(self.texts))
                bm25_rank = bm25_ranks.get(idx, len(self.texts))
                fused_scores[idx] = (1 / (k_rrf + vec_rank)) + (1 / (k_rrf + bm25_rank))
            
            print(f"[Search] RRF fusion: {len(all_indices)} candidates")
        else:
            # Weighted score fusion (fallback if BM25 disabled or RRF=False)
            fused_scores = {}
            for idx in all_indices:
                vec_score = vec_results.get(idx, 0.0)
                bm25_score = bm25_results.get(idx, 0.0)
                fused_scores[idx] = alpha * vec_score + (1 - alpha) * bm25_score
            
            print(f"[Search] Weighted fusion (Œ±={alpha}): {len(all_indices)} candidates")

        # Sort by fused score
        sorted_indices = sorted(fused_scores.keys(), key=fused_scores.get, reverse=True)[:rerank_top_k]

        # ========== Step 4: Reranking (Optional) ==========
        if self.reranker is not None and len(sorted_indices) > k:
            print(f"[Rerank] Reranking top-{len(sorted_indices)} candidates...")
            
            # Prepare query-document pairs
            pairs = [[query, self.texts[idx]] for idx in sorted_indices]
            
            # Get rerank scores
            rerank_scores = self.reranker.predict(pairs)
            
            # Update fused scores with rerank scores
            for idx, score in zip(sorted_indices, rerank_scores):
                fused_scores[idx] = float(score)
            
            # Re-sort by rerank scores
            sorted_indices = sorted(sorted_indices, key=fused_scores.get, reverse=True)
            
            print(f"[Rerank] ‚úì Reranked to top-{k}")

        # ========== Step 5: Build Results DataFrame ==========
        final_indices = sorted_indices[:k]
        rows = []
        for rank, idx in enumerate(final_indices, start=1):
            md = self.meta_df.iloc[idx]
            item = {
                "rank": rank,
                "score": fused_scores[idx],
                "text": self.texts[idx],
                "doc": md.get("doc"),
                "path": md.get("path"),
                "modality": md.get("modality"),
                "chunk": int(md.get("chunk", 0)),
                "page": _page_or_none(md.get("page")),
            }
            
            # Section hint
            if self.outline_df is not None:
                toc = self.outline_df[self.outline_df["doc_name"] == item["doc"]]
                if not toc.empty:
                    item["section_hint"] = toc.iloc[0]["title"]
            
            rows.append(item)
        
        return pd.DataFrame(rows)
    
def baseline_answer_one_call(
    kb: KBEnv,
    query: str,
    k_ctx: int = 8,
    table_rows: Optional[List[Dict[str, Any]]] = None
) -> dict:
    """
    Baseline (Stage 4) requirements:
      - Naive chunking (we use existing kb_texts)
      - Single-pass vector search (FAISS only)
      - One LLM call, no caching
    """
    # 1) Retrieve top-k chunks
    ctx_df = kb.search(query, k=k_ctx)
    if ctx_df is None or ctx_df.empty:
        answer = "I couldn't find any relevant context in the KB for this query."
        print(answer)
        return {"answer": answer, "contexts": []}

    # 2) Build context and simple citations
    ctx_lines = []
    for _, row in ctx_df.iterrows():
        text = str(row["text"]).replace("\\n", " ").strip()
        if len(text) > 800:
            text = text[:800] + "..."
        ctx_lines.append(f"- {text}")

    # We will build citations later; prefer table-row provenance if provided
    cits = []

    # Build citations: prefer structured table rows with pages
    if table_rows:
        for r in table_rows[:5]:
            doc = str(r.get("doc") or "")
            page = r.get("page")
            if page is not None:
                cits.append(f"{doc}, page {int(page)}")
            else:
                cits.append(f"{doc}, table {r.get('table_id')} row {r.get('row_id')} (no page)")
    else:
        for _, row in ctx_df.iterrows():
            doc = str(row.get("doc") or "")
            mod = str(row.get("modality") or "")
            page = row.get("page")
            if page is not None:
                cits.append(f"{doc}, page {page}")
            else:
                ch = int(row.get("chunk") or 0)
                if mod in ("md", "table_row"):
                    cits.append(f"{doc}, chunk {ch} (no page; {mod})")
                else:
                    cits.append(f"{doc}, chunk {ch} (no page)")

    # Optional: include structured table rows so the LLM doesn't deny available data
    table_lines = []
    if table_rows:
        table_lines.append("STRUCTURED TABLE ROWS (authoritative):")
        for r in table_rows[:6]:
            ser_q = r.get("series_q") or {}
            ser_y = r.get("series") or {}
            if ser_q:
                def _qkey(k: str):
                    m = re.match(r"([1-4])Q(20\\d{2})$", k)
                    return (int(m.group(2)), int(m.group(1))) if m else (0, 0)
                qkeys = sorted(ser_q.keys(), key=_qkey)[-5:]
                table_lines.append(f"- {r.get('doc')} | {r.get('label')} | " + ", ".join(f"{k}: {ser_q[k]}" for k in qkeys))
            elif ser_y:
                ys = sorted(ser_y.keys())[-3:]
                table_lines.append(f"- {r.get('doc')} | {r.get('label')} | " + ", ".join(f"{y}: {ser_y[y]}" for y in ys))

    # 3) Compose strict prompt
    if table_lines:
        # When we have structured rows, exclude noisy text snippets to avoid conflicting numbers.
        prompt = f"""USER QUESTION: 
            {query}
            {chr(10).join(table_lines)} 
            INSTRUCTIONS:
            1. **Data Source Priority**: Use ONLY the numbers from STRUCTURED TABLE ROWS above. These are authoritative financial data extracted from official reports.

            2. **Metric Substitution**: If the exact metric requested isn't available but a closely related metric exists (e.g., "Total Income" instead of "Operating Income"), use the available metric and clearly state the substitution in your answer.

            3. **Calculations**: 
            - Show your work for any calculations (e.g., ratios, year-over-year growth)
            - Use the format: Operating Efficiency Ratio = Opex √∑ Operating Income = X √∑ Y = Z%
            - Calculate year-over-year changes as: ((New - Old) / Old) √ó 100%

            4. **Missing Data**: If requested periods or metrics are not present in the structured rows:
            - Explicitly state which periods/metrics are missing
            - Provide what IS available
            - Do NOT refuse to answer if partial data exists

            5. **Output Format**:
            - Start with a direct 1-2 sentence answer
            - Present numerical results in a clear Markdown table with columns: Period/Year | Metric | Value
            - Add brief notes if clarifications are needed

            6. **Accuracy**: Do NOT invent, extrapolate, or estimate numbers. Only use values explicitly shown in the structured rows.

            Example table format:
            | Year | Operating Expenses | Total Income | Efficiency Ratio |
            |------|-------------------|--------------|------------------|
            | 2022 |        $X         |      $Y      |        Z%        |
            | 2023 |        $X         |      $Y      |        Z%        |"""
    else:
        prompt = f"""USER QUESTION:
        {query}

        CONTEXT (verbatim excerpts from financial reports):
        {chr(10).join(ctx_lines)}

        INSTRUCTIONS:
        1. **Data Source**: Extract information ONLY from the CONTEXT above. These are direct quotes from official reports.

        2. **Explicit Data Gaps**: If the exact values for requested periods are not present in the context:
        - State which specific periods/metrics are missing
        - Provide what IS available from the context
        - Do NOT make up or estimate missing values

        3. **Calculations**: If calculations are requested:
        - Show your working step-by-step
        - Only calculate if all required values are present in the context
        - Use the format: Ratio = A √∑ B = X √∑ Y = Z%

        4. **Output Format**:
        - Start with a direct answer summarizing what you found
        - Present data in a clear Markdown table when applicable
        - Add a "Missing data" section if any requested information is unavailable

        5. **Citations**: Reference specific excerpts when stating values (e.g., "according to excerpt 2...")

        6. **Accuracy**: Precision is critical. Only use numbers explicitly stated in the context.

        Example output structure:
        **Answer**
        [Direct 1-2 sentence response]

        | Period | Metric | Value |
        |--------|--------|-------|
        |Q4 2024 | NIM    | 2.05% |

        **Missing data**
        - Q1-Q3 2024: No quarterly data available in context"""

    # 4) One LLM call
    print(f"[LLM] single-call baseline using {_llm_provider_info()}")
    answer = _llm_single_call(prompt)

    # 5) Print nicely in notebooks
    # print("""\nBASELINE (Single LLM Call)\n--------------------------------""")
    # print(answer)
    # print("\nCitations:")
    # for c in cits[:5]:
    #     print(f"- {c}")

    return {"answer": answer, "contexts": ctx_df.head(5)}
    

# ----------------------------- Tool: Calculator -----------------------------

class CalculatorTool:
    """
    Safe arithmetic eval (supports +,-,*,/,**, parentheses) and helpers for deltas/YoY.
    """

    ALLOWED = {
        ast.Expression, ast.BinOp, ast.UnaryOp, ast.Num, ast.Load,
        ast.Add, ast.Sub, ast.Mult, ast.Div, ast.Pow, ast.USub, ast.UAdd,
        ast.Mod, ast.FloorDiv, ast.Constant, ast.Call, ast.Name
    }
    SAFE_FUNCS = {"round": round, "abs": abs}

    @classmethod
    def safe_eval(cls, expr: str) -> float:
        node = ast.parse(expr, mode="eval")
        for n in ast.walk(node):
            if type(n) not in cls.ALLOWED:
                raise ValueError(f"Disallowed expression: {type(n).__name__}")
            if isinstance(n, ast.Call) and not (isinstance(n.func, ast.Name) and n.func.id in cls.SAFE_FUNCS):
                raise ValueError("Only round(...) and abs(...) calls are allowed")
        code = compile(node, "<expr>", "eval")
        return float(eval(code, {"__builtins__": {}}, cls.SAFE_FUNCS))

    @staticmethod
    def delta(a: float, b: float) -> float:
        return float(a) - float(b)

    @staticmethod
    def yoy(a: float, b: float) -> Optional[float]:
        b = float(b)
        if b == 0: return None
        return (float(a) - b) / b * 100.0


# ----------------------------- Tool: Table Extraction -----------------------------
class TableExtractionTool:
    """
    Look up a metric row in kb_tables.parquet and extract {year/quarter -> value}.
    Uses pivoting to handle both Wide (Financials) and Long (NIM) table formats.
    """

    def __init__(self, tables_df: Optional[pd.DataFrame]):
        self.wide_df = pd.DataFrame()
        if tables_df is not None and not tables_df.empty:
            # PIVOT: reconstruct rows from long-format storage
            # Group by doc+table+row to flatten 'Entity-Attribute-Value' back to a row
            try:
                self.wide_df = tables_df.pivot_table(
                    index=['doc_name', 'table_id', 'row_id'], 
                    columns='column', 
                    values='value_str', 
                    aggfunc='first'
                ).reset_index()
                
                # Pre-compute a text representation of each row for fuzzy matching
                self.wide_df['_row_text'] = self.wide_df.apply(
                    lambda r: " ".join([str(x) for x in r.values if pd.notna(x)]).lower(), axis=1
                )
            except Exception as e:
                print(f"[TableTool] Pivot error: {e}")
                self.wide_df = pd.DataFrame()

    @staticmethod
    def _norm(s: str) -> str:
        return str(s).lower().strip()

    def get_metric_rows(self, metric: str, limit: int = 5) -> List[Dict[str, Any]]:
        if self.wide_df.empty:
            return []

        metric_norm = self._norm(metric)
        # Basic synonyms handling
        keywords = [metric_norm]
        if "nim" in metric_norm or "margin" in metric_norm:
            keywords += ["nim", "net interest margin", "group nim"]
        elif "expense" in metric_norm:
            keywords += ["total expenses", "operating expenses", "expenses"]
        elif "income" in metric_norm:
            keywords += ["total income", "operating income", "total operating income"]
            
        results = []
        
        # Identify Year Columns (e.g., "2023", "2024")
        all_cols = self.wide_df.columns.astype(str)
        year_cols = [c for c in all_cols if re.match(r'^\d{4}$', c)]
        
        # Filter: Only rows containing one of our keywords
        # (We use the pre-computed _row_text for speed)
        mask = self.wide_df['_row_text'].apply(lambda x: any(k in x for k in keywords))
        candidates = self.wide_df[mask]

        for _, row in candidates.iterrows():
            data = {}
            row_dict = row.dropna().to_dict()
            
            # --- STRATEGY A: Year Columns (for Opex/Income) ---
            # If the row has values in columns like "2023", grab them.
            for y in year_cols:
                if y in row_dict:
                    # Clean number: "9,018.0" -> 9018.0
                    val = str(row_dict[y]).replace(',', '')
                    try:
                        data[int(y)] = float(val)
                    except ValueError:
                        pass
            
            # --- STRATEGY B: Quarter Columns (for NIM) ---
            # If Strategy A didn't find much, check for "Quarter" columns
            if len(data) < 2: 
                # Find which column acts as the "Quarter" label
                q_col = next((k for k in row_dict.keys() if 'quarter' in str(k).lower()), None)
                
                if q_col:
                    quarter_key = row_dict[q_col] # e.g. "1Q24"
                    # Look for the value in a column matching our keyword (e.g. "Group NIM")
                    for col_name, val in row_dict.items():
                        if col_name == q_col or col_name.startswith('_'): continue
                        
                        if any(k in self._norm(col_name) for k in keywords):
                            try:
                                # Clean number: "2.14%" -> 2.14
                                val_clean = str(val).replace('%', '').strip()
                                data[str(quarter_key)] = float(val_clean)
                            except ValueError:
                                pass

            # If we found valid data points, add to results
            if data:
                # Find the best label for this row (cell text or column header)
                label = "Unknown"
                # Check column headers first (Strategy B)
                for k in data.keys():
                    if isinstance(k, str) and not k.isdigit(): # likely quarterly data
                         # Find the column that produced this data
                         for col_name in row_dict:
                             if any(kw in self._norm(col_name) for kw in keywords):
                                 label = col_name
                                 break
                
                # If still unknown, check row values (Strategy A)
                if label == "Unknown":
                    for val in row_dict.values():
                        if any(kw in self._norm(val) for kw in keywords):
                            label = val
                            break

                results.append({
                    "doc": row['doc_name'],
                    "label": label,
                    "series": data if any(isinstance(k, int) for k in data) else {},
                    "series_q": data if any(isinstance(k, str) for k in data) else {}
                })
        
        return results[:limit]

#
# ----------------------------- Tool: Text Extraction (fallback for quarters) -----------------------------
class TextExtractionTool:
    """
    Regex-based fallback when Marker tables don't carry the quarter series.
    Currently focuses on percentage metrics like Net Interest Margin (NIM).
    It scans the KB text chunks and tries to pair quarter tokens with the nearest % value.
    """
    QPAT = re.compile(r"(?i)(?:\b([1-4])\s*q\s*((?:20)?\d{2})\b|\bq\s*([1-4])\s*((?:20)?\d{2})\b|\b([1-4])q((?:20)?\d{2})\b)")
    PCT = re.compile(r"(?i)(\d{1,2}(?:\.\d{1,2})?)\s*%")

    def __init__(self, kb: 'KBEnv'):
        self.kb = kb

    @staticmethod
    def _norm(s: str) -> str:
        return TableExtractionTool._norm(s)

    @staticmethod
    def _mk_qdisp(q: int, y: int) -> str:
        if y < 100: y += 2000
        return f"{q}Q{y}"

    def extract_quarter_pct(self, metric: str, top_k_text: int = 200) -> Dict[str, float]:
        metric_n = self._norm(metric)
        hits = self.kb.search(metric, k=top_k_text)
        if hits is None or hits.empty:
            return {}
        series_q: Dict[str, float] = {}
        for _, row in hits.iterrows():
            txt = str(row["text"])
            # Quick filter: only consider chunks that mention the metric name
            if metric_n not in self._norm(txt):
                continue
            # Find all quarter tokens in this chunk
            quarts = []
            for m in self.QPAT.finditer(txt):
                # groups: (q1,y1) or (q2,y2) or (q3,y3)
                if m.group(1):   q, y = int(m.group(1)), int(m.group(2))
                elif m.group(3): q, y = int(m.group(3)), int(m.group(4))
                else:            q, y = int(m.group(5)), int(m.group(6))
                if y < 100: y += 2000
                quarts.append((q, y, m.start(), m.end()))
            if not quarts:
                continue
            # Find % values; take the nearest % to each quarter mention
            pcts = [(pm.group(1), pm.start(), pm.end()) for pm in self.PCT.finditer(txt)]
            if not pcts:
                continue
            MAX_CHARS = 48  # require proximity
            for (q, y, qs, qe) in quarts:
                best = None; best_d = 1e9
                for (val, ps, pe) in pcts:
                    d = min(abs(ps - qe), abs(pe - qs))
                    if d < best_d and d <= MAX_CHARS:
                        try:
                            num = float(val)
                        except Exception:
                            continue
                        # sanity for NIM-like percentages
                        if 0.0 <= num <= 6.0:
                            best_d = d; best = num
                if best is not None:
                    disp = self._mk_qdisp(q, y)
                    series_q[disp] = float(best)
        return series_q

# ----------------------------- Tool: Multi-Doc Compare -----------------------------

class MultiDocCompareTool:
    """
    Compare the same metric across multiple docs by pulling each doc's row
    and extracting aligned year/value pairs.
    """

    def __init__(self, table_tool: TableExtractionTool):
        self.table_tool = table_tool

    def compare(self, metric: str, years: Optional[List[int]] = None, top_docs: int = 6):
        # get top rows across all docs
        rows = self.table_tool.get_metric_rows(metric, limit=50)
        if not rows:
            return []
        # take first occurrence per doc
        seen = set()
        picked = []
        for r in rows:
            if r["doc"] in seen: 
                continue
            seen.add(r["doc"])
            picked.append(r)
            if len(picked) >= top_docs:
                break
        # align years
        if years is None:
            all_years = set()
            for r in picked:
                all_years.update(r["series"].keys())
            years = sorted(all_years)[-3:]  # default: last 3 years available
        out = []
        for r in picked:
            values = {y: r["series"].get(y) for y in years}
            out.append({"doc": r["doc"], "label": r["label"], "years": years, "values": values})
        return out

# ----------------------------- Agent Mode: plan ‚Üí act ‚Üí observe -----------------------------

@dataclass
class AgentResult:
    plan: List[str]
    actions: List[str]
    observations: List[str]
    final: Dict[str, Any]

# ----------------------------- QUERY ANALYSIS UTILITIES-----------------------------
class QueryAnalyzer:
    """Utility methods for parsing financial queries"""
    
    @staticmethod
    def extract_metric(query: str) -> Optional[str]:
        """
        Extract metric name from query
        Priority: quoted phrase > regex patterns > capitalized words
        """
        # 1. Quoted phrase (highest priority)
        quoted = re.findall(r'"([^"]+)"', query)
        if quoted:
            return quoted[0]
        
        # 2. Common finance metrics (regex patterns)
        candidates = [
            r"net interest margin", r"nim", r"gross margin",
            r"operating expenses?(?: &| and)?(?: income)?",
            r"operating income", r"operating profit",
            r"total income", r"cost-to-income", r"allowances", 
            r"profit before tax", r"efficiency ratio",
            r"return on equity", r"roe", r"return on assets", r"roa"
        ]
        ql = query.lower()
        for pat in candidates:
            m = re.search(pat, ql)
            if m:
                return m.group(0)
        
        # 3. Fallback: capitalized phrase
        m2 = re.findall(r'\b([A-Z][A-Za-z&% ]{3,})\b', query)
        return m2[0] if m2 else None
    
    @staticmethod
    def want_compare(query: str) -> bool:
        """Check if query requests comparison across documents"""
        return bool(re.search(
            r"\b(compare|vs\.?|versus|across docs?|between|multi-?doc)\b", 
            query, re.I
        ))
    
    @staticmethod
    def want_yoy(query: str) -> bool:
        """Check if query requests year-over-year analysis"""
        return bool(re.search(
            r"\b(yoy|year[- ]over[- ]year|growth|change|%|delta|annual growth)\b", 
            query, re.I
        ))
    
    @staticmethod
    def want_quarters(query: str) -> bool:
        """Check if query requests quarterly data"""
        return bool(re.search(
            r"\b(quarter|quarters|\bq[1-4]\b|quarterly|half[- ]?year)\b", 
            query, re.I
        ))
    
    @staticmethod
    def extract_years(query: str) -> List[int]:
        """Extract year numbers from query"""
        years = [int(y) for y in re.findall(r"\b(20\d{2})\b", query)]
        # Deduplicate and sort
        return sorted(set(years))
    
    @staticmethod
    def extract_num_periods(query: str) -> Optional[int]:
        """Extract number of periods (e.g., 'last 5 quarters', 'last 3 years')"""
        # Pattern: "last N quarters/years"
        m = re.search(r"\blast\s+(\d+)\s+(quarters?|years?|periods?)", query, re.I)
        if m:
            return int(m.group(1))
        
        # Pattern: "N quarters/years"
        m2 = re.search(r"\b(\d+)\s+(quarters?|years?|periods?)", query, re.I)
        if m2:
            return int(m2.group(1))
        
        return None
    
    @staticmethod
    def needs_calculation(query: str) -> bool:
        """Check if query requires calculation"""
        return bool(re.search(
            r"\b(calculate|compute|derive|ratio|√∑|divided by|/|percentage of)\b", 
            query, re.I
        ))


# ----------------------------- PARALLEL QUERY DECOMPOSER -----------------------------

class ParallelQueryDecomposer:
    """Decomposes complex queries using QueryAnalyzer"""
    
    @staticmethod
    def decompose(query: str) -> List[str]:
        """
        Intelligent query decomposition using query analysis
        """
        analyzer = QueryAnalyzer
        
        # Extract intent
        needs_calc = analyzer.needs_calculation(query)
        metric = analyzer.extract_metric(query)
        years = analyzer.extract_years(query)
        num_periods = analyzer.extract_num_periods(query)
        
        # Q3: Efficiency Ratio (Opex √∑ Income)
        if needs_calc and metric and "efficiency" in metric.lower():
            return [
                f"Extract Operating Expenses for the last {num_periods or 3} fiscal years",
                f"Extract Total Income for the last {num_periods or 3} fiscal years"
            ]
        
        # Q3: Any ratio calculation (A √∑ B)
        if needs_calc and ("ratio" in query.lower() or "√∑" in query or "/" in query):
            # Try to extract both metrics
            parts = re.split(r'[√∑/]|\bdivided by\b', query, flags=re.I)
            if len(parts) == 2:
                metric_a = analyzer.extract_metric(parts[0])
                metric_b = analyzer.extract_metric(parts[1])
                if metric_a and metric_b:
                    return [
                        f"Extract {metric_a} for the last {num_periods or 3} fiscal years",
                        f"Extract {metric_b} for the last {num_periods or 3} fiscal years"
                    ]
        
        # Multi-metric comparison
        if analyzer.want_compare(query) and metric:
            # Decompose by year if multiple years specified
            if len(years) > 2:
                return [f"Extract {metric} for FY{y}" for y in years]
        
        # Single metric query (no decomposition)
        return [query]

    @staticmethod
    async def execute_parallel_async(kb: KBEnv, sub_queries: List[str], k_ctx: int) -> List[pd.DataFrame]:
        """Execute sub-queries in parallel"""
        loop = asyncio.get_event_loop()
        
        def search_sync(query):
            return kb.search(query, k=k_ctx)
        
        with ThreadPoolExecutor(max_workers=min(len(sub_queries), 4)) as executor:
            tasks = [loop.run_in_executor(executor, search_sync, sq) for sq in sub_queries]
            results = await asyncio.gather(*tasks)
        
        return results
    
    @staticmethod
    def execute_parallel(kb: KBEnv, sub_queries: List[str], k_ctx: int) -> List[pd.DataFrame]:
        """Blocking wrapper for async parallel execution"""
        try:
            loop = asyncio.get_event_loop()
            if loop.is_running():
                # Already in event loop (e.g., Jupyter), use nest_asyncio
                try:
                    import nest_asyncio
                    nest_asyncio.apply()
                except ImportError:
                    pass
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        
        return loop.run_until_complete(
            ParallelQueryDecomposer.execute_parallel_async(kb, sub_queries, k_ctx)
        )
    
    @staticmethod
    def merge_results(results: List[pd.DataFrame], k_ctx: int) -> pd.DataFrame:
        """Merge and deduplicate parallel results"""
        if not results:
            return pd.DataFrame()
        
        # Concatenate
        merged = pd.concat([r for r in results if not r.empty], ignore_index=True)
        if merged.empty:
            return merged
        
        # Deduplicate by text (keep highest score)
        merged = merged.sort_values('score', ascending=False)
        merged = merged.drop_duplicates(subset=['text'], keep='first')
        
        # Take top-k
        merged = merged.head(k_ctx)
        
        # Re-rank
        merged = merged.sort_values('score', ascending=False).reset_index(drop=True)
        merged['rank'] = range(1, len(merged) + 1)
        
        return merged

# ----------------------------- Agent: plan ‚Üí act ‚Üí observe -----------------------------

class Agent:
    """
    Unified Agent with all tools:
    - CalculatorTool
    - TableExtractionTool
    - TextExtractionTool
    - MultiDocCompareTool (NEW)
    """
    
    def __init__(
        self, 
        kb: KBEnv, 
        use_parallel_subqueries: bool = False,
        verbose: bool = True
    ):
        self.kb = kb
        self.use_parallel_subqueries = use_parallel_subqueries
        self.verbose = verbose
        
        # Initialize all tools
        self.calc_tool = CalculatorTool()
        
        # Table tool (required for multi-doc compare)
        self.table_tool = TableExtractionTool(kb.tables_df) if kb.tables_df is not None else None
        
        # Text extraction fallback
        self.text_tool = TextExtractionTool(kb)
        
        # ‚úÖ Multi-doc compare tool (NEW)
        self.multidoc_tool = MultiDocCompareTool(self.table_tool) if self.table_tool else None
        
        # Analyzers
        self.analyzer = QueryAnalyzer()
        self.decomposer = ParallelQueryDecomposer() if use_parallel_subqueries else None
        
        if self.verbose:
            tools_status = []
            tools_status.append(f"Calculator: ‚úì")
            tools_status.append(f"Table: {'‚úì' if self.table_tool else '‚úó'}")
            tools_status.append(f"Text: ‚úì")
            tools_status.append(f"MultiDoc: {'‚úì' if self.multidoc_tool else '‚úó'}")
            print(f"[Agent] Tools: {' | '.join(tools_status)}")
    
    def run(self, query: str, k_ctx: int = 12) -> 'AgentResult':
        """Execute query with all available tools"""
        
        if self.verbose:
            print(f"\n[Agent] Query: {query[:60]}...")
        
        # Step 1: Analyze query
        metric = self.analyzer.extract_metric(query)
        wants_yoy = self.analyzer.want_yoy(query)
        wants_quarters = self.analyzer.want_quarters(query)
        wants_compare = self.analyzer.want_compare(query)  
        needs_calc = self.analyzer.needs_calculation(query)
        years = self.analyzer.extract_years(query)
        num_periods = self.analyzer.extract_num_periods(query)
        
        if self.verbose:
            print(f"[Agent] Analysis:")
            print(f"  Metric: {metric}")
            print(f"  YoY: {wants_yoy}, Quarterly: {wants_quarters}")
            print(f"  Compare: {wants_compare}, Calc: {needs_calc}")  
            print(f"  Years: {years}, Periods: {num_periods}")
        
        # Step 2: Retrieve contexts (parallel or standard)
        if self.use_parallel_subqueries and self.decomposer:
            sub_queries = self.decomposer.decompose(query)
            
            if len(sub_queries) > 1:
                if self.verbose:
                    print(f"[Agent] Decomposed into {len(sub_queries)} sub-queries")
                
                sub_results = self.decomposer.execute_parallel(self.kb, sub_queries, k_ctx)
                contexts = self.decomposer.merge_results(sub_results, k_ctx)
                
                if self.verbose:
                    print(f"[Agent] Merged ‚Üí {len(contexts)} contexts")
            else:
                contexts = self.kb.search(query, k=k_ctx)
        else:
            contexts = self.kb.search(query, k=k_ctx)
        
        # Step 3: Tool selection and execution
        actions = []
        table_rows = []
        comparison_results = []
        
        # ========== NEW: Multi-doc compare tool ==========
        if wants_compare and metric and self.multidoc_tool:
            actions.append("multi_doc_compare")
            comparison_results = self.multidoc_tool.compare(
                metric=metric,
                years=years if years else None,
                top_docs=6
            )
            if self.verbose:
                print(f"[Agent] MultiDoc compare: {len(comparison_results)} documents")
        
        # Table extraction (standard)
        elif metric and self.table_tool:
            actions.append("table_extraction")
            table_rows = self.table_tool.get_metric_rows(metric, limit=10)
            if self.verbose:
                print(f"[Agent] Extracted {len(table_rows)} table rows")
        
        # Text extraction fallback
        if not table_rows and not comparison_results and self.text_tool:
            actions.append("text_extraction")
            if wants_quarters and metric:
                quarter_data = self.text_tool.extract_quarter_pct(metric, top_k_text=50)
                if self.verbose:
                    print(f"[Agent] Text extraction: {len(quarter_data)} quarters")
        
        # Calculator for YoY or ratios
        if needs_calc and self.calc_tool:
            actions.append("calculation")
        
        # Step 4: LLM Synthesis
        answer = self._synthesize_with_llm(
            query, 
            contexts, 
            table_rows, 
            comparison_results, 
            metric, 
            wants_yoy,
            wants_compare  
        )
        
        # Step 5: Build result
        observations = [
            f"Retrieved {len(contexts)} contexts",
            f"Metric: {metric}",
            f"YoY: {wants_yoy}, Quarterly: {wants_quarters}, Compare: {wants_compare}",
            f"Tools used: {', '.join(actions)}"
        ]
        
        result = AgentResult(
            plan=f"Analyze ‚Üí {'Compare' if wants_compare else 'Extract'} {metric} ‚Üí Synthesize",
            actions=actions,
            observations=observations,
            final={
                "contexts": contexts,
                "table_rows": table_rows,
                "comparison_results": comparison_results, 
                "answer": answer,
                "metric": metric,
                "wants_yoy": wants_yoy,
                "wants_quarters": wants_quarters,
                "wants_compare": wants_compare
            }
        )
        
        return result
    
    def _synthesize_with_llm(
        self, 
        query: str, 
        contexts: pd.DataFrame, 
        table_rows: List[Dict],
        comparison_results: List[Dict],  
        metric: str,
        wants_yoy: bool,
        wants_compare: bool 
    ) -> str:
        """Synthesize final answer using LLM"""
        
        prompt_parts = [f"USER QUESTION:\n{query}\n"]
        
        # ========== NEW: Add multi-doc comparison results ==========
        if comparison_results:
            prompt_parts.append("\nMULTI-DOCUMENT COMPARISON:")
            for comp in comparison_results:
                doc = comp.get("doc", "Unknown")
                label = comp.get("label", metric)
                years = comp.get("years", [])
                values = comp.get("values", [])
                
                if years and values:
                    year_val_pairs = ", ".join(f"{y}: {v}" for y, v in zip(years, values))
                    prompt_parts.append(f"- {doc} | {label}: {year_val_pairs}")
        
        # Add table rows if available
        elif table_rows:
            prompt_parts.append("\nSTRUCTURED DATA:")
            for r in table_rows[:5]:
                if r.get("series_q"):
                    qkeys = sorted(r["series_q"].keys())[-5:]
                    ser = ", ".join(f"{k}: {r['series_q'][k]}" for k in qkeys)
                    prompt_parts.append(f"- {r['doc']} | {r['label']}: {ser}")
                else:
                    ys = sorted(r["series"].keys())[-3:]
                    ser = ", ".join(f"{y}: {r['series'][y]}" for y in ys)
                    prompt_parts.append(f"- {r['doc']} | {r['label']}: {ser}")
        
        # Add retrieved contexts
        if not contexts.empty:
            prompt_parts.append("\nCONTEXT:")
            for _, row in contexts.head(5).iterrows():
                text = str(row["text"])[:500]
                doc = row.get("doc", "Unknown")
                prompt_parts.append(f"- [{doc}] {text}")
        
        # Add instructions
        prompt_parts.append("\nINSTRUCTIONS:")
        prompt_parts.append("- Use ONLY the data provided above")
        
        if wants_compare:
            prompt_parts.append("- Compare the metric across different documents")
            prompt_parts.append("- Highlight similarities and differences")
        
        if wants_yoy:
            prompt_parts.append("- Calculate year-over-year growth percentages")
        
        prompt_parts.append("- Provide a concise answer with specific numbers and document names")
        prompt_parts.append("- If data is incomplete, state what's missing explicitly")
        
        prompt = "\n".join(prompt_parts)
        
        # Call LLM
        if self.verbose:
            print(f"[Agent] Synthesizing with LLM...")
        
        answer = _llm_single_call(prompt)
        
        return answer


# ----------------------------- Pretty print helpers -----------------------------

def _fmt_series(series: Dict[int, float], n: int = 3) -> str:
    if not series: return "‚Äî"
    ys = sorted(series.keys())[-n:]
    return ", ".join(f"{y}: {series[y]}" for y in ys)

def show_agent_result(res: AgentResult, show_ctx: int = 3):
    print("PLAN:")
    for step in res.plan:
        print("  -", step)
    print("\nACTIONS:")
    for a in res.actions:
        print("  -", a)
    print("\nOBSERVATIONS:")
    for o in res.observations:
        print("  -", o)

    fin = res.final

    # TABLE ROWS block
    if not fin.get("table_rows"):
        msg = fin.get("notice") or "No matching table rows were found for your request."
        print(f"\n‚ö†Ô∏è {msg}")
    elif "table_rows" in fin and fin["table_rows"]:
        print("\nTABLE ROWS (first few):")
        shown = 0
        for r in fin["table_rows"]:
            if shown >= 3:
                break
            sq = (r.get("series_q") or {})
            if sq:
                # sort quarters chronologically by (year, quarter)
                def _qkey(k):
                    m = re.match(r"([1-4])Q(20\\d{2})$", k)
                    if m:
                        return (int(m.group(2)), int(m.group(1)))
                    return (0, 0)
                qkeys = sorted(sq.keys(), key=_qkey)
                last5 = qkeys[-5:]
                ser = ", ".join(f"{k}: {sq[k]}" for k in last5)
                print(f"  doc={r['doc']} | label={r['label']} | quarters(last5)={ser}")
                shown += 1
            else:
                ys = sorted(r["series"].keys())
                ser = ", ".join(f"{y}: {r['series'][y]}" for y in ys[-3:]) if ys else "‚Äî"
                print(f"  doc={r['doc']} | label={r['label']} | years(last3)={ser}")
                shown += 1
    if "compare" in fin and fin["compare"]:
        print("\nCOMPARE (first few):")
        for r in fin["compare"][:3]:
            row = ", ".join(f"{y}: {r['values'].get(y)}" for y in r["years"])
            print(f"  doc={r['doc']} | label={r['label']} | {row}")
    if "calc" in fin and fin["calc"]:
        print("\nCALC (YoY):")
        for c in fin["calc"]:
            print(f"  {c['from']}‚Üí{c['to']}: {c['value_from']} ‚Üí {c['value_to']} | YoY={c['yoy_pct']}%")

    # Contexts
    ctx = fin.get("contexts")
    if ctx is not None and not ctx.empty:
        print("\nCONTEXTS:")
        for _, row in ctx.head(show_ctx).iterrows():
            t = str(row["text"]).replace("\n", " ")
            if len(t) > 240: t = t[:237] + "..."
            hint = f" ‚Äî {row.get('section_hint')}" if "section_hint" in row else ""
            print(f"  [{row['rank']}] {row['doc']} | {row['modality']}{hint}")
            print("     ", t)


# ----------------------------- CLI / Notebook ------------------------------------

# ----------------------------- Notebook Runtime ------------------------------------

# This section is safe for direct use inside a Jupyter/Colab/VSCode notebook cell.
# It avoids argparse/sys parsing and simply runs a default demo or accepts a variable `query`.

# Example usage in a notebook:
# from g2x import KBEnv, Agent, show_agent_result
# kb = KBEnv(base="./data_marker")
# agent = Agent(kb)
# res = agent.run("Compare Net Interest Margin across docs for 2022‚Äì2024")
# show_agent_result(res)

if __name__ == "__main__" or "__file__" not in globals():
    kb = KBEnv(base="./data_marker")
    agent = Agent(kb)  # Initialize the agent with tools

    try:
        query = globals().get("query", None)
    except Exception:
        query = None

    if not query:
        query = "What is the Net Interest Margin over the last 5 quarters?"
        print("‚ÑπÔ∏è Running notebook demo query:")
        print(f"   ‚Üí {query}\n")

    # --- OPTION 1: Run Full Agent (Recommended) ---
    print("üöÄ Running Agent...")
    result = agent.run(query)
    show_agent_result(result)
    
    # --- OPTION 2: Run Baseline (No Tools) ---
    # out = baseline_answer_one_call(kb, query, k_ctx=8)
    # print("\nBaseline Answer:", out["answer"])

[BM25] ‚úì Indexed 13587 documents
[Reranker] ‚úì Loaded cross-encoder/ms-marco-MiniLM-L-6-v2
[Agent] Tools: Calculator: ‚úì | Table: ‚úì | Text: ‚úì | MultiDoc: ‚úì
üöÄ Running Agent...

[Agent] Query: What is the Net Interest Margin over the last 5 quarters?...
[Agent] Analysis:
  Metric: net interest margin
  YoY: False, Quarterly: True
  Compare: False, Calc: False
  Years: [], Periods: 5
[Search] RRF fusion: 78 candidates
[Rerank] Reranking top-24 candidates...
[Rerank] ‚úì Reranked to top-12
[Agent] Extracted 7 table rows
[Agent] Synthesizing with LLM...
[LLM] provider=groq model=openai/gpt-oss-20b
PLAN:
  - A
  - n
  - a
  - l
  - y
  - z
  - e
  -  
  - ‚Üí
  -  
  - E
  - x
  - t
  - r
  - a
  - c
  - t
  -  
  - n
  - e
  - t
  -  
  - i
  - n
  - t
  - e
  - r
  - e
  - s
  - t
  -  
  - m
  - a
  - r
  - g
  - i
  - n
  -  
  - ‚Üí
  -  
  - S
  - y
  - n
  - t
  - h
  - e
  - s
  - i
  - z
  - e

ACTIONS:
  - table_extraction

OBSERVATIONS:
  - Retrieved 12 contexts
  - M

---
## Metadata Filtering/Boosting Optimization

Run below cell to enable the Metadata Filtering/boosting optimization, then run the Benchmark Runner to see the change in results.

In [11]:
# ============================================================================
# METADATA ENHANCEMENT SETUP (Run this cell once)
# ============================================================================

# Step 0: Reload modules to pick up latest changes
import sys
import importlib

# Remove old modules from cache
for mod in list(sys.modules.keys()):
    if 'metadata_enhancer' in mod or 'kb_metadata_extension' in mod:
        del sys.modules[mod]

print("=" * 80)
print("METADATA ENHANCEMENT SETUP - FINE-TUNED VERSION")
print("=" * 80 + "\n")

# Step 1: Enhance KB with metadata (ONE TIME - only run if not already done)
try:
    from metadata_enhancer import enhance_kb_with_metadata
    
    print("üìä Enhancing KB with metadata (year, quarter, doc_type, section)...")
    enhance_kb_with_metadata("./data_marker")
    print("‚úì KB enhancement complete!\n")
except FileNotFoundError as e:
    print(f"‚ö†Ô∏è  KB files not found: {e}")
    print("   Make sure ./data_marker/kb_chunks.parquet exists\n")
except Exception as e:
    print(f"‚ÑπÔ∏è  Enhancement skipped (may already be done): {e}\n")

# Step 2: Add metadata search capability to KBEnv
try:
    from kb_metadata_extension import add_metadata_search_to_kbenv
    add_metadata_search_to_kbenv()
    print("‚úì Metadata search added to KBEnv\n")
except Exception as e:
    print(f"‚ö†Ô∏è  Could not add metadata search: {e}\n")

# Step 3: Replace KBEnv.search() with metadata-enhanced version
print("=" * 80)
print("APPLYING BALANCED METADATA OPTIMIZATION")
print("=" * 80 + "\n")

from g2x import KBEnv

# Save original search method if not already saved
if not hasattr(KBEnv, '_original_search'):
    KBEnv._original_search = KBEnv.search
    print("‚úì Original search method saved\n")

# Define metadata-enhanced search wrapper with balanced boosting
def _metadata_optimized_search(self, query, k=50, alpha=0.6, rerank_top_k=100):
    """
    Enhanced search with balanced metadata boosting, recency decay, and adaptive weights
    
    FINE-TUNED SETTINGS:
    - Reduced boost weights (quarter: 4.0x‚Üí6.0x vs old 5.0x‚Üí8.0x)
    - Less aggressive recency decay (5% per quarter vs 7%)
    - Smaller initial pool (k*12 vs k*15) for better speed
    - Improved "last N quarters" detection
    """
    if hasattr(self, 'search_with_metadata'):
        # Temporarily restore original to avoid recursion
        original_method = KBEnv.search
        KBEnv.search = KBEnv._original_search
        try:
            result = self.search_with_metadata(
                query, 
                k=k, 
                alpha=alpha, 
                rerank_top_k=rerank_top_k,
                enable_metadata_boost=True,
                enable_metadata_filter=True,  # Soft filter enabled
                boost_weights=None,  # Use adaptive weights (None = auto-detect)
                apply_recency_decay=True  # Apply time-based decay (5% per quarter)
            )
        finally:
            # Restore the enhanced search
            KBEnv.search = original_method
        return result
    else:
        # Fallback to original if metadata not available
        return KBEnv._original_search(self, query, k=k, alpha=alpha, rerank_top_k=rerank_top_k)

# Apply the optimization
KBEnv.search = _metadata_optimized_search

print("‚úì BALANCED OPTIMIZATION ENABLED")
print("   All kb.search() calls will now use:")
print("   ‚Ä¢ Adaptive metadata boosting (balanced weights)")
print("   ‚Ä¢ Recency decay (5% per quarter age)")
print("   ‚Ä¢ Soft filtering (¬±2 year window when year detected)")
print("   ‚Ä¢ Improved 'last N quarters' detection")

print("\n" + "=" * 80)
print("SETUP COMPLETE - Balanced Metadata Optimization Active!")
print("=" * 80)
print("""
üéØ FINE-TUNED IMPROVEMENTS:

1. BALANCED BOOST WEIGHTS (Less Aggressive):
   ‚Ä¢ Quarterly queries: quarter=6.0x (was 8.0x), doc_type=1.4x, year=1.6x
   ‚Ä¢ YoY comparisons: year=3.0x (was 3.5x), quarter=1.8x, doc_type=2.2x
   ‚Ä¢ Annual queries: doc_type=3.5x (was 4.0x), year=2.2x, quarter=1.0x
   ‚Ä¢ Latest/recent: quarter=5.5x (was 7.0x), year=2.2x
   ‚Ä¢ Defaults: quarter=4.0x, doc_type=1.8x, year=1.6x, section=1.2x

2. GENTLER RECENCY DECAY:
   ‚Ä¢ Documents decay 5% in relevance per quarter of age (was 7%)
   ‚Ä¢ 1Q ago: 0.95x (was 0.93x)
   ‚Ä¢ 4Q ago: 0.81x (was 0.75x)
   ‚Ä¢ 8Q ago: 0.66x (was 0.56x)

3. FASTER INITIAL POOL:
   ‚Ä¢ k*12 candidates (600 for k=50) instead of k*15 (750)
   ‚Ä¢ ~15-20% faster while maintaining quality

4. IMPROVED PATTERN DETECTION:
   ‚Ä¢ Better "over the last N quarters" detection
   ‚Ä¢ "for the last N quarters" now works
   ‚Ä¢ "in the past N quarters" now works

‚úì Balanced precision vs speed
‚úì Less aggressive boosting = more diverse results
‚úì Falls back to regular search for generic queries

TO DISABLE THIS OPTIMIZATION:
  - Restart the kernel, OR
  - Run: KBEnv.search = KBEnv._original_search
""")

METADATA ENHANCEMENT SETUP - FINE-TUNED VERSION

üìä Enhancing KB with metadata (year, quarter, doc_type, section)...
Loading chunks from data_marker\kb_chunks.parquet...
Loading texts from data_marker\kb_texts.npy...
 Enhancing metadata...

Metadata Enhancement Summary:
   Total chunks: 13548
   Years found: {2022: np.int64(3071), 2023: np.int64(3021), 2024: np.int64(6013), 2025: np.int64(1443)}
   Quarters found: 6 unique quarters
   Doc types: {'annual_report': np.int64(9200), 'quarterly_results': np.int64(3055), 'cfo_presentation': np.int64(1002), 'trading_update': np.int64(188), 'press_statement': np.int64(60), 'ceo_presentation': np.int64(43)}

Saving enhanced chunks to data_marker\kb_chunks.parquet...
‚úì KB enhancement complete!

Added 'search_with_metadata' method to KBEnv class
‚úì Metadata search added to KBEnv

APPLYING BALANCED METADATA OPTIMIZATION

‚úì Original search method saved

‚úì BALANCED OPTIMIZATION ENABLED
   All kb.search() calls will now use:
   ‚Ä¢ Adaptive 

---

### Just to check available models

In [1]:
import google.generativeai as genai
import os

# Best practice: store your key as an environment variable
# Or replace "YOUR_API_KEY" with your actual key string for a quick test
genai.configure(api_key=os.environ.get("GEMINI_API_KEY", "YOUR_API_KEY"))

print("Available Models:\n")

# List all models and check which ones support the 'generateContent' method
for model in genai.list_models():
  if 'generateContent' in model.supported_generation_methods:
    print(f"- {model.name}")

Available Models:

- models/gemini-2.5-pro-preview-03-25
- models/gemini-2.5-flash
- models/gemini-2.5-pro-preview-05-06
- models/gemini-2.5-pro-preview-06-05
- models/gemini-2.5-pro
- models/gemini-2.0-flash-exp
- models/gemini-2.0-flash
- models/gemini-2.0-flash-001
- models/gemini-2.0-flash-exp-image-generation
- models/gemini-2.0-flash-lite-001
- models/gemini-2.0-flash-lite
- models/gemini-2.0-flash-lite-preview-02-05
- models/gemini-2.0-flash-lite-preview
- models/gemini-2.0-pro-exp
- models/gemini-2.0-pro-exp-02-05
- models/gemini-exp-1206
- models/gemini-2.0-flash-thinking-exp-01-21
- models/gemini-2.0-flash-thinking-exp
- models/gemini-2.0-flash-thinking-exp-1219
- models/gemini-2.5-flash-preview-tts
- models/gemini-2.5-pro-preview-tts
- models/learnlm-2.0-flash-experimental
- models/gemma-3-1b-it
- models/gemma-3-4b-it
- models/gemma-3-12b-it
- models/gemma-3-27b-it
- models/gemma-3n-e4b-it
- models/gemma-3n-e2b-it
- models/gemini-flash-latest
- models/gemini-flash-lite-lates

  from .autonotebook import tqdm as notebook_tqdm
E0000 00:00:1763887317.584774 9479662 alts_credentials.cc:93] ALTS creds ignored. Not running on GCP and untrusted ALTS is not enabled.


## 5. Benchmark Runner

Run these 3 standardized queries. Produce JSON then prose answers with citations. These are the standardized queries.

*   Gross Margin Trend (or NIM if Bank)
    *   Query: "Report the Gross Margin (or Net Interest Margin, if a bank) over the last 5 quarters, with values."
    *   Expected Output: A quarterly table of Gross Margin % (or NIM % if bank).

*   Operating Expenses (Opex) YoY for 3 Years
    *   Query: "Show Operating Expenses for the last 3 fiscal years, year-on-year comparison."
    *   Expected Output: A 3-year Opex table (absolute numbers and % change).

*   Operating Efficiency Ratio
    *   Query: "Calculate the Operating Efficiency Ratio (Opex √∑ Operating Income) for the last 3 fiscal years, showing the working."
    *   Expected Output: Table with Opex, Operating Income, and calculated ratio for 3 years.

### Gemini Version 3

In [None]:
# """
# Stage2.py ‚Äî DEFINITIVE FINAL VERSION
# Gemini vision plus pdfplumber
# """

# from __future__ import annotations
# import os, re, json, math, traceback
# from typing import List, Dict, Any, Optional

# import numpy as np
# import pandas as pd
# import time, contextlib

# # --- Logging Setup ---
# @contextlib.contextmanager
# def timeblock(row: dict, key: str):
#     t0 = time.perf_counter()
#     try:
#         yield
#     finally:
#         row[key] = round((time.perf_counter() - t0) * 1000.0, 2)

# class _Instr:
#     def __init__(self):
#         self.rows = []
#     def log(self, row):
#         self.rows.append(row)
#     def df(self):
#         cols = ['Query','T_retrieve','T_rerank','T_reason','T_generate','T_total','Tokens','Tools']
#         df = pd.DataFrame(self.rows)
#         for c in cols:
#             if c not in df:
#                 df[c] = None
#         return df[cols]

# instr = _Instr()

# os.environ["TOKENIZERS_PARALLELISM"] = "false"

# # --- Configuration ---
# VERBOSE = bool(int(os.environ.get("AGENT_CFO_VERBOSE", "1")))
# LLM_BACKEND = "gemini"
# GEMINI_MODEL_NAME = "models/gemini-2.5-flash"

# # --- Global Variables ---
# kb: Optional[pd.DataFrame] = None
# texts: Optional[np.ndarray] = None
# index, bm25, EMB = None, None, None
# _HAVE_FAISS, _HAVE_BM25, _INITIALIZED = False, False, False


# # === Groq / OpenAI LLM config ===
# import os
# from openai import OpenAI

# LLM_PROVIDER = os.getenv("LLM_PROVIDER", "groq").lower()  # "groq" | "openai"
# # Good fast defaults on Groq:
# #   - "openai/gpt-oss-20b" (supports Responses API + built-in tools)
# #   - "llama-3.3-70b-versatile" (chat.completions)
# GROQ_MODEL   = os.getenv("GROQ_MODEL", "openai/gpt-oss-20b")
# OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini")  # if you switch back to OpenAI

# def _make_llm_client():
#     if LLM_PROVIDER == "groq":
#         api_key = os.environ.get("GROQ_API_KEY")
#         if not api_key:
#             raise RuntimeError("Missing GROQ_API_KEY")
#         return OpenAI(api_key=api_key, base_url="https://api.groq.com/openai/v1"), GROQ_MODEL
#     else:
#         api_key = os.environ.get("OPENAI_API_KEY")
#         if not api_key:
#             raise RuntimeError("Missing OPENAI_API_KEY")
#         return OpenAI(api_key=api_key), OPENAI_MODEL

# def _llm_respond(prompt: str, system: str = "You are a helpful finance analyst.") -> str:
#     """
#     Unified LLM call:
#       - If LLM_PROVIDER is 'groq' or 'openai', use the OpenAI SDK (Groq-compatible base_url when set).
#       - Else, caller should fall back to Gemini via _call_llm.
#     """
#     try:
#         client, model = _make_llm_client()
#     except Exception as e:
#         raise RuntimeError(f"LLM client init failed: {e}")

#     # Prefer chat.completions for generality (works on Groq + OpenAI)
#     try:
#         chat = client.chat.completions.create(
#             model=model,
#             messages=[
#                 {"role": "system", "content": system},
#                 {"role": "user", "content": prompt},
#             ],
#             temperature=0.2,
#         )
#         return chat.choices[0].message.content.strip()
#     except Exception:
#         # Fallback: Responses API (useful for Groq GPT-OSS models)
#         resp = client.responses.create(
#             model=model,
#             input=f"System: {system}\n\nUser: {prompt}"
#         )
#         text = getattr(resp, "output_text", "") or ""
#         return str(text).strip()
        
        
# # --- Core Logic Functions ---
# def _classify_query(q: str) -> Optional[str]:
#     ql = q.lower()
#     if re.search(r"\boperating\s+efficiency\s+ratio\b|\boer\b", ql) or ("√∑" in ql and "operating" in ql and "income" in ql):
#         return "oer"
#     if "nim" in ql or "net interest margin" in ql: 
#         return "nim"
#     if "opex" in ql or "operating expense" in ql or re.search(r"\bexpenses\b", ql): 
#         return "opex"
#     if re.search(r"\b(total\s+income|operating\s+income)\b", ql):
#         return "income"
#     if re.search(r"\bcti\b|cost[\s\-_\/]*to?\s*[\s\-_\/]*income", ql): 
#         return "cti"
#     return None

# class _EmbedLoader:
#     def __init__(self):
#         self.impl, self.dim, self.name, self.fn = None, None, None, None
#     def embed(self, texts: List[str]) -> np.ndarray:
#         if self.impl is None:
#             try:
#                 from sentence_transformers import SentenceTransformer
#                 model_name = "sentence-transformers/all-MiniLM-L6-v2"
#                 st = SentenceTransformer(model_name)
#                 self.impl, self.dim = ("st", model_name), st.get_sentence_embedding_dimension()
#                 self.fn = lambda b: st.encode(b, normalize_embeddings=True).astype(np.float32)
#             except ImportError: raise RuntimeError("sentence-transformers not installed.")
#         return self.fn(texts)

# def init_stage2(out_dir: str = "data"):
#     global kb, texts, index, bm25, _HAVE_FAISS, _HAVE_BM25, _INITIALIZED, EMB
#     os.environ["AGENT_CFO_OUT_DIR"] = out_dir
#     paths = [os.path.join(out_dir, f) for f in ["kb_chunks.parquet", "kb_texts.npy", "kb_index.faiss"]]
#     if not all(os.path.exists(p) for p in paths): raise RuntimeError(f"KB artifacts not found in '{out_dir}'.")
#     kb, texts = pd.read_parquet(paths[0]), np.load(paths[1], allow_pickle=True)
#     try:
#         import faiss
#         _HAVE_FAISS, index = True, faiss.read_index(paths[2])
#     except ImportError: _HAVE_FAISS, index = False, None
#     try:
#         from rank_bm25 import BM25Okapi
#         _HAVE_BM25, bm25 = True, BM25Okapi([str(t).lower().split() for t in texts])
#     except ImportError: _HAVE_BM25, bm25 = False, None
#     EMB = _EmbedLoader()
#     _INITIALIZED = True
#     if VERBOSE: print(f"[Stage2] Initialized successfully from '{out_dir}'.")

# def _ensure_init():
#     if not _INITIALIZED: raise RuntimeError("Stage2 not initialized. Call init_stage2() first.")

# def _detect_last_n_years(q: str) -> Optional[int]:
#     m = re.search(r"last\s+(\d+|three|five)\s+(fiscal\s+)?years?", q, re.I)
#     if m:
#         try:
#             val = m.group(1).lower();
#             if val == 'three': return 3
#             if val == 'five': return 5
#             return int(val)
#         except: return None
#     return None

# def _detect_last_n_quarters(q: str) -> Optional[int]:
#     m = re.search(r"last\s+(\d+|five)\s+quarters", q, re.I)
#     if m:
#         try:
#             val = m.group(1).lower();
#             if val == 'five': return 5
#             return int(val)
#         except: return None
#     return None

# def hybrid_search(query: str, top_k=12, alpha=0.6) -> List[Dict[str, Any]]:
#     _ensure_init()
#     vec_scores, bm25_scores = {}, {}
#     if _HAVE_FAISS and index and EMB:
#         qv = EMB.embed([query]); qv /= np.linalg.norm(qv, axis=1, keepdims=True)
#         sims, ids = index.search(qv.astype(np.float32), top_k * 4)
#         vec_scores = {int(i): float(s) for i, s in zip(ids[0], sims[0]) if i != -1}
#     if _HAVE_BM25 and bm25:
#         scores = bm25.get_scores(query.lower().split())
#         top_idx = np.argsort(scores)[-top_k*4:]
#         bm25_scores = {int(i): float(scores[i]) for i in top_idx}
    
#     fused = {k: (alpha * vec_scores.get(k, 0)) + ((1 - alpha) * (bm25_scores.get(k, 0) / (max(bm25_scores.values()) or 1.0))) for k in set(vec_scores) | set(bm25_scores)}
    
#     is_annual_query = bool(re.search(r"\bfy\b|fiscal\s+year|last\s+\d+\s+years", query, re.I))
#     year_match = re.search(r'\b(20\d{2})\b', query)
#     desired_year = int(year_match.group(1)) if year_match else None

#     qtype = _classify_query(query)
#     for i in fused:
#         meta = kb.iloc[i]
#         boost = 0.0
#         text_l = str(texts[i]).lower()
#         # --- Extended domain-aware features ---
#         file_l = str(meta.file).lower()
#         section_l = (str(meta.section_hint).lower() if isinstance(meta.section_hint, str) else "")
#         mentions_nim = ("net interest margin" in text_l) or re.search(r"\bnim\b", text_l)
#         mentions_percent_nim = bool(re.search(r"net\s+interest\s+margin[^%]{0,200}%|([0-9]+(?:\.[0-9]+)?)\s*%\s*(?:p|pts|percentage\s*points)?", text_l, flags=re.I))
#         mentions_expenses = ("operating expenses" in text_l) or re.search(r"\bexpenses\b", text_l)
#         has_money_units = bool(re.search(r"\(\$?\s*m\)|s\$\s*m|\(\$m\)|\bmillion\b|\bmn\b|\bbn\b|\bbillion\b", text_l, flags=re.I))
#         is_tableish = section_l.startswith("table_p")
#         is_vision = "vision_summary" in section_l
#         is_quarterly_doc = pd.notna(meta.quarter)
#         is_press_or_trading = bool(re.search(r"press[_\s-]?statement|trading[_\s-]?update", file_l))
#         is_corp_gov = "corporate governance" in text_l or "board of directors" in text_l
#         is_cfo_or_perf = bool(re.search(r"cfo[_\s-]?presentation|performance[_\s-]?summary", file_l))

#         # Year/annual vs quarterly alignment
#         if desired_year and pd.notna(meta.year):
#             if int(meta.year) == desired_year:
#                 boost += 5.0
#             else:
#                 boost -= 5.0

#         is_annual_doc = pd.isna(meta.quarter)
#         if is_annual_query:
#             boost += 5.0 if is_annual_doc else -5.0
#         else:
#             boost += 2.0 if not is_annual_doc else 0.0

#         # --- Domain-aware boosts ---
#         if qtype == "nim":
#             # Prefer quarterly docs and chunks explicitly mentioning NIM with a %
#             if is_quarterly_doc:
#                 boost += 4.0
#             if mentions_nim:
#                 boost += 4.0
#             if mentions_nim and mentions_percent_nim:
#                 boost += 6.0
#             # Strongly favour structured sources
#             if is_tableish and mentions_nim:
#                 boost += 5.0
#             if is_vision and (mentions_nim or "net interest margin" in text_l):
#                 boost += 5.0
#             # Penalise generic prose that often lacks explicit % values
#             if is_press_or_trading and not mentions_percent_nim:
#                 boost -= 10.0

#         if qtype == "opex" or qtype == "oer" or qtype == "cti":
#             # Prefer chunks that talk about (operating) expenses with monetary units
#             if mentions_expenses and has_money_units:
#                 boost += 6.0
#             # Extra rewards for structured/table/vision sources
#             if is_tableish and mentions_expenses:
#                 boost += 3.0
#             if is_vision and mentions_expenses:
#                 boost += 4.0
#             # Vision summary pages tend to have "For FYXXXX, Opex were NNNN million."
#             if is_vision and (mentions_expenses):
#                 boost += 5.0
#             # For Opex/CTI/OER annual asks, prefer annual docs
#             if is_annual_query and is_annual_doc:
#                 boost += 3.0
                
#         if qtype == "income":
#             if "total income" in text_l:
#                 boost += 6.0
#             if is_tableish:
#                 boost += 3.0
#             if is_vision:
#                 boost += 4.0
#             if is_annual_query and is_annual_doc:
#                 boost += 3.0

#         # Global penalties for off-topic governance prose
#         if is_corp_gov:
#             boost -= 8.0
#         # Light reward for CFO/performance decks (usually contain crisp metrics)
#         if is_cfo_or_perf:
#             boost += 2.0

#         fused[i] += boost
        
#     hits = [{"doc_id": kb.iloc[i].doc_id, "file": kb.iloc[i].file, "page": int(kb.iloc[i].page), "year": int(kb.iloc[i].year) if pd.notna(kb.iloc[i].year) else None, "quarter": int(kb.iloc[i].quarter) if pd.notna(kb.iloc[i].quarter) else None, "section_hint": kb.iloc[i].section_hint, "score": float(score)} for i, score in sorted(fused.items(), key=lambda x: x[1], reverse=True)[:top_k]]
#     return hits

# def format_citation(hit: dict) -> str:
#     parts = [hit.get("file", "?")]
#     y = hit.get("year"); q = hit.get("quarter")
#     if y is not None and q is not None: parts.append(f"{int(q)}Q{str(int(y))[-2:]}")
#     elif y is not None: parts.append(str(int(y)))
#     if hit.get("page") is not None: parts.append(f"p.{int(hit['page'])}")
#     sec = str(hit.get("section_hint") or "").strip()
#     if sec: parts.append(sec)
#     tab = hit.get("table_id")
#     if tab: parts.append(f"table {tab}")
#     return ", ".join(parts)

# def _latest_fys(kb: pd.DataFrame, n=3):
#     df = kb.copy()
#     df["y"] = pd.to_numeric(df["year"], errors="coerce")
#     ydf = df[df["quarter"].isna()].dropna(subset=["y"]).sort_values("y", ascending=False)
#     if ydf.empty:
#         ydf = df.dropna(subset=["y"]).sort_values("y", ascending=False)
#     years = [int(y) for y in ydf["y"].drop_duplicates().head(n)]
#     return years

# def _latest_quarters(kb: pd.DataFrame, n=5):
#     df = kb.copy()
#     df["y"] = pd.to_numeric(df["year"], errors="coerce")
#     df["q"] = pd.to_numeric(df["quarter"], errors="coerce")
#     qdf = df.dropna(subset=["y","q"]).sort_values(["y","q"], ascending=[False, False])
#     pairs = qdf[["y","q"]].drop_duplicates().head(20).values.tolist()
#     # return unique up to n, ordered newest‚Üíoldest
#     out, seen = [], set()
#     for y,q in pairs:
#         k = (int(y), int(q))
#         if k not in seen:
#             seen.add(k); out.append(k)
#         if len(out) == n: break
#     return out

# def _parse_tool_kv(s: str):
#     # Parses "Value: 8895, Source: file.pdf, 2024, p.15"
#     m = re.search(r"Value:\s*([^\n,]+)\s*,\s*Source:\s*(.*)", s, flags=re.S)
#     if not m: return None, None
#     val = m.group(1).strip()
#     src = m.group(2).strip()
#     return val, src

# def _fmt_num(x):
#     try: return f"{float(x):,.2f}"
#     except: return x

# def _unique_list(xs, cap=5):
#     out, seen = [], set()
#     for s in xs:
#         if not s: continue
#         if s not in seen:
#             seen.add(s); out.append(s)
#         if len(out) >= cap: break
#     return out

# def baseline_nim_5q() -> dict:
#     """
#     NIM for the last 5 quarters (Group):
#       - Use the dedicated NIM series parser (tool_nim_series) which aggregates across docs.
#       - Parse its result into a table.
#       - Add lightweight citations by retrieving a top hit per quarter.
#     """
#     _ensure_init()

#     # 1) Get the consolidated series (Group) from structured/vision + table text
#     series_str = tool_nim_series(last_n=5, variant="group")

#     # Expect format: "NIM (Group) last 5 quarters ‚Üí 2Q25: 2.05%, 1Q25: 2.12%, ..."
#     items = re.findall(r"([1-4]Q\d{2})\s*:\s*([0-9]+(?:\.[0-9]+)?)%", series_str)
#     if not items:
#         # Fall back to the original per-quarter extraction if parsing failed
#         pairs = _latest_quarters(kb, n=5)
#         rows, cites = [], []
#         for (y, q) in pairs:
#             r = tool_table_extraction(f"Net interest margin (%) for {int(q)}Q{int(y)}")
#             val, src = _parse_tool_kv(r)
#             rows.append((f"{q}Q{str(y)[-2:]}", val or "‚Äî"))
#             cites.append(src or r)
#         lines = ["NIM (%) ‚Äî last 5 quarters:", "Quarter | NIM (%)", "--------|--------"]
#         for qlab, v in rows:
#             lines.append(f"{qlab} | {v}")
#         lines.append("\nCitations:")
#         for c in _unique_list(cites, cap=5):
#             lines.append(f"- {c}")
#         return {"answer": "\n".join(lines), "hits": [], "execution_log": {"fallback": True}}

#     # 2) Build table from parsed items (already newest‚Üíoldest in tool_nim_series)
#     rows = [(q.upper(), v) for (q, v) in items]

#     # 3) Lightweight citations: take the top hit per quarter
#     def _cite_for_quarter(q_label: str) -> Optional[str]:
#         hits = hybrid_search(f"Net interest margin (%) {q_label}", top_k=1)
#         if not hits:
#             return None
#         return f"Source: {format_citation(hits[0])}"

#     cites = []
#     for qlab, _ in rows:
#         c = _cite_for_quarter(qlab)
#         if c:
#             cites.append(c)
#     cites = _unique_list(cites, cap=5)

#     # 4) Render output
#     out = ["NIM (%) ‚Äî last 5 quarters (Group):", "Quarter | NIM (%)", "--------|--------"]
#     for qlab, v in rows:
#         out.append(f"{qlab} | {v}")

#     if cites:
#         out.append("\nCitations:")
#         for c in cites:
#             out.append(f"- {c}")

#     return {"answer": "\n".join(out), "hits": [], "execution_log": {"built_from": "tool_nim_series"}}

# # def baseline_run_integrated_precachinopex_3y() -> dict:
# #     """
# #     Operating Expenses for last 3 fiscal years; deterministic extractor + YoY%.
# #     """
# #     _ensure_init()
# #     years = _latest_fys(kb, n=3)
# #     rows, cites = [], []
# #     for y in years:
# #         r = tool_table_extraction(f"Operating expenses for fiscal year {y}")
# #         val, src = _parse_tool_kv(r)
# #         rows.append((y, val or "‚Äî"))
# #         cites.append(src or r)

# #     # sort newest‚Üíoldest
# #     rows.sort(key=lambda t: t[0], reverse=True)
# #     out = ["Opex (S$ m) ‚Äî last 3 fiscal years:", "Year | Opex (S$ m) | YoY %", "-----|-------------|------"]
# #     for i,(yy,vv) in enumerate(rows):
# #         yoy = ""
# #         if i>0 and vv not in ("‚Äî","",None) and rows[i-1][1] not in ("‚Äî","",None):
# #             try:
# #                 cur = float(vv); prev = float(rows[i-1][1])
# #                 yoy = f"{((cur-prev)/prev)*100:,.1f}%"
# #             except: pass
# #         out.append(f"{yy} | { _fmt_num(vv) if vv!='‚Äî' else vv } | {yoy}")

# #     out.append("\nCitations:")
# #     for c in _unique_list(cites, cap=5):
# #         out.append(f"- {c}")

# #     return {"answer":"\n".join(out), "hits":[], "execution_log":{"years": years}}

# # def baseline_efficiency_ratio_3y() -> dict:
# #     """
# #     Operating Efficiency Ratio = Opex / Operating Income, last 3 fiscal years.
# #     """
# #     _ensure_init()
# #     years = _latest_fys(kb, n=3)
# #     rows, cits = [], []
# #     for y in years:
# #         r1 = tool_table_extraction(f"Operating expenses for fiscal year {y}")
# #         v_opex, c1 = _parse_tool_kv(r1)
# #         r2 = tool_table_extraction(f"Operating income for fiscal year {y}")
# #         v_oinc, c2 = _parse_tool_kv(r2)
# #         rows.append((y, v_opex or "‚Äî", v_oinc or "‚Äî"))
# #         cits.extend([c1 or r1, c2 or r2])

# #     rows.sort(key=lambda t: t[0], reverse=True)
# #     out = ["Operating Efficiency Ratio (Opex √∑ Operating Income):",
# #            "Year | Opex (S$ m) | Operating Income (S$ m) | Ratio",
# #            "-----|-------------|-------------------------|------"]
# #     for (yy, o, inc) in rows:
# #         ratio = "‚Äî"
# #         try:
# #             if o not in ("‚Äî","",None) and inc not in ("‚Äî","",None) and float(inc)!=0.0:
# #                 ratio = f"{(float(o)/float(inc))*100:,.1f}%"
# #         except: pass
# #         out.append(f"{yy} | {_fmt_num(o) if o!='‚Äî' else o} | {_fmt_num(inc) if inc!='‚Äî' else inc} | {ratio}")

# #     out.append("\nCitations:")
# #     for c in _unique_list(cits, cap=5):
# #         out.append(f"- {c}")

# #     return {"answer":"\n".join(out), "hits":[], "execution_log":{"years": years}}


# def answer_with_llm(query: str, topk: int = 5) -> Dict[str, Any]:
#     """
#     Baseline pipeline: single-pass retrieval + single LLM call (no planning, no tools).
#       - Uses hybrid_search() for retrieval (vector + BM25).
#       - Builds a compact CONTEXT from top-k chunks.
#       - Calls the LLM once to synthesize an answer.
#       - Ensures citations include report, year/quarter, and page.
#     """
#     _ensure_init()
    
#     ql = query.lower()

#     # Intent router for the 3 standardized prompts
#     if "net interest margin" in ql or "gross margin" in ql:
#         return baseline_nim_5q()

#     # if "operating expenses" in ql and ("last 3 fiscal years" in ql or "year-on-year" in ql or "yoy" in ql):
#     #     return baseline_opex_3y()

#     # if ("operating efficiency ratio" in ql) or ("opex √∑ operating income" in ql) or ("opex / operating income" in ql):
#     #     return baseline_efficiency_ratio_3y()

#     def _pos_of_docid(did: str) -> Optional[int]:
#         mask = (kb["doc_id"] == did).to_numpy()
#         idxs = np.flatnonzero(mask)
#         return int(idxs[0]) if idxs.size else None

#     # Opex-aware retrieval expansion (more table/vision leaning)
#     ql = query.lower()
#     is_opex = ("opex" in ql) or ("operating expense" in ql) or re.search(r"\bexpenses\b", ql)

#     if is_opex:
#         expanded = query + " | Operating expenses Opex ($m) fiscal year table vision_summary"
#         hits = hybrid_search(expanded, top_k=max(1, int(topk) * 2))  # e.g., 10 if topk=5
#     else:
#         hits = hybrid_search(query, top_k=max(1, int(topk)))

#     if not hits:
#         return "No relevant material found."

#     # Build context and citations
#     ctx_lines, cits = [], []
#     for h in hits[:topk]:
#         pos = _pos_of_docid(h.get("doc_id", ""))
#         snippet = (str(texts[pos]) if pos is not None else "")
#         snippet = re.sub(r"\s+", " ", snippet).strip()
#         if snippet:
#             ctx_lines.append(f"- {snippet[:800]}")
#         cits.append(format_citation(h))

#     # Strict prompt: stick to retrieved text; include citations at the end
#     prompt = (
#         "You are a finance analyst.\n"
#         "Using ONLY the CONTEXT below, answer the USER QUERY. Quote numbers exactly as reported.\n"
#         "If the numbers are not present in CONTEXT, say you cannot find them.\n"
#         "End with a bulleted list of citations (report name, year/quarter, page, section if present).\n\n"
#         f"USER QUERY:\n{query}\n\nCONTEXT:\n" + "\n".join(ctx_lines) +
#         "\n\nFORMAT:\nAnswer text.\n\nCitations:\n- <report (year/quarter), p.X, section>\n"
#     )

#     answer = _call_llm(prompt, dry_run=False)

#     # Ensure at least some citations if the model forgets
#     if "Citations:" not in answer:
#         answer += "\n\nCitations:\n" + "\n".join(f"- {c}" for c in cits[:3])

#     return {"answer": answer, "hits": hits[:min(5, len(hits))].to_dict("records") if hasattr(hits, "to_dict") else [], "execution_log": None}


# def _call_llm(prompt: str, dry_run: bool = False) -> str:
#     if dry_run:
#         return '{"plan": []}'

#     # Prefer Groq/OpenAI if configured
#     if os.getenv("LLM_PROVIDER", "").lower() in ("groq", "openai"):
#         try:
#             return _llm_respond(
#                 prompt,
#                 system="You are a precise finance analyst. Be concise and cite sources provided by the tools."
#             )
#         except Exception as e:
#             return f"LLM Generation Failed (Groq/OpenAI path): {e}"

#     # Fallback to Gemini
#     try:
#         from google import generativeai as genai
#         genai.configure(api_key=os.environ['GEMINI_API_KEY'])
#         model = genai.GenerativeModel(GEMINI_MODEL_NAME)
#         safety_settings = [
#             {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
#             {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
#             {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
#             {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
#         ]
#         out = model.generate_content(prompt, safety_settings=safety_settings)
#         return getattr(out, "text", "") or "LLM returned empty response."
#     except Exception as e:
#         return f"LLM Generation Failed (Gemini path): {e}"

# def tool_calculator(expression: str) -> str:
#     try:
#         s = str(expression)

#         # Guard: unresolved placeholders like ${var}
#         placeholders = re.findall(r"\$\{([^}]+)\}", s)
#         if placeholders:
#             return f"Error: unresolved placeholders: {', '.join(placeholders)}"

#         # Normalizations
#         s = re.sub(r'(?<=\d),(?=\d{3}\b)', '', s)               # 12,345 -> 12345
#         s = re.sub(r'(\d+(?:\.\d+)?)\s*%', r'(\1/100)', s)       # 12% -> (12/100)
#         s = re.sub(r'(?i)[s]?\$\s*', '', s)                      # S$ / $ -> strip
#         s = re.sub(r'(?i)\b(bn|billion|b)\b', 'e9', s)           # bn -> e9
#         s = re.sub(r'(?i)\b(mn|million|m)\b', 'e6', s)           # mn -> e6

#         # Safety: allow only digits, + - * / ( ) . e E and spaces
#         safe = re.sub(r'[^0-9eE\+\-*/(). ]', '', s)

#         result = eval(safe)
#         return f"Result: {result}"
#     except Exception as e:
#         return f"Error: {e}"

# def _desired_periods_from_query(query: str) -> list[tuple[int|None, int|None]]:
#     out = []
#     # Quarters like 1Q25
#     for m in re.finditer(r"\b([1-4])Q(\d{2})\b", query, re.I):
#         out.append((2000 + int(m.group(2)), int(m.group(1))))

#     # FY2024 / FY 2024
#     for m in re.finditer(r"\bFY\s?(20\d{2})\b", query, re.I):
#         out.append((int(m.group(1)), None))

#     # "fiscal year 2024"
#     for m in re.finditer(r"\bfiscal\s+year\s+(20\d{2})\b", query, re.I):
#         out.append((int(m.group(1)), None))

#     # bare year (only if nothing else found)
#     if not out:
#         m = re.search(r"\b(20\d{2})\b", query)
#         if m:
#             out.append((int(m.group(1)), None))

#     return out

# def tool_table_extraction(query: str) -> str:
#     """
#     Finds a single reported data point from the knowledge base using hybrid search,
#     then extracts and cleans the most likely numerical value from the retrieved text.

#     Improvements vs. previous version:
#       ‚Ä¢ Robust row-to-text mapping using positional index (not label).
#       ‚Ä¢ Query-aware extraction (Opex ‚Üí 'million' values; NIM ‚Üí percentages).
#       ‚Ä¢ Period-aware filtering (prefer sentences containing requested FY/quarter).
#       ‚Ä¢ Avoids 4-digit years being misread as values.
#       ‚Ä¢ Falls back through multiple heuristics and multiple hits if needed.
#     """
#     if VERBOSE:
#         print(f"  [Tool Call: table_extraction] with query: '{query}'")

#     hits = hybrid_search(query, top_k=12)
#     # --- Vision-first rescue: ensure year-matched vision_summary candidates are in the pool ---
#     try:
#         desired_periods = _desired_periods_from_query(query)
#         desired_years = [y for (y, q) in desired_periods if y]
#         sh_series = kb["section_hint"].astype(str).str.contains("vision_summary", case=False, na=False)
#         mask = sh_series
#         if desired_years:
#             mask = mask & kb["year"].isin(desired_years)
#         vis_idxs = np.flatnonzero(mask.to_numpy())
#         base_score = (min([float(h.get("score") or 0.0) for h in hits]) - 1.0) if hits else 0.0
#         extra_hits = []
#         for idx in vis_idxs[:6]:
#             row = kb.iloc[idx]
#             extra_hits.append({
#                 "doc_id": row.doc_id,
#                 "file": row.file,
#                 "page": int(row.page) if pd.notna(row.page) else None,
#                 "year": int(row.year) if pd.notna(row.year) else None,
#                 "quarter": int(row.quarter) if pd.notna(row.quarter) else None,
#                 "section_hint": row.section_hint,
#                 "score": base_score
#             })
#         if extra_hits:
#             hits = hits + extra_hits

#         # Deduplicate by doc_id
#         seen = set()
#         deduped = []
#         for h in hits:
#             did = h.get("doc_id")
#             if did in seen:
#                 continue
#             seen.add(did)
#             deduped.append(h)
#         hits = deduped

#         # --- Priority ordering of hits: vision first, then tables, then others
#         vision_hits = [h for h in hits if "vision_summary" in str(h.get("section_hint") or "").lower()]
#         table_hits  = [h for h in hits if str(h.get("section_hint") or "").lower().startswith("table_p")]
#         other_hits  = [h for h in hits if h not in vision_hits and h not in table_hits]
#     except Exception:
#         # Fail open; rely on original hits if rescue logic errors out
#         vision_hits, table_hits, other_hits = [], [], []
#         pass

#     if not hits:
#         return "Error: No relevant documents found."

#     # Helper: map a doc_id to the correct position in `texts` using a boolean mask.
#     def _pos_of_docid(did: str) -> Optional[int]:
#         mask = (kb["doc_id"] == did).to_numpy()
#         idxs = np.flatnonzero(mask)
#         return int(idxs[0]) if idxs.size else None

#     # Helper: safer float parsing (strip commas etc.)
#     def _clean_number(s: str) -> Optional[str]:
#         t = s.strip()
#         t = re.sub(r"[,\s]", "", t)
#         # Reject years (e.g., 2024) and obviously huge integers without unit context
#         if re.fullmatch(r"\d{4}", t):
#             return None
#         try:
#             float(t)
#             return t
#         except Exception:
#             return None

#     # Helper: plausibility check for NIM
#     def _plausible_nim_value(x: float) -> bool:
#         # DBS group NIM is realistically ~0.5%‚Äì3.5%
#         try:
#             return 0.5 <= float(x) <= 3.5
#         except Exception:
#             return False
        
#     # Helper: choose the best number from text given the query intent
#     def _extract_value(text: str, query: str) -> Optional[str]:
#         ql = query.lower()
#         is_nim = ("nim" in ql) or ("net interest margin" in ql)
#         is_opex = ("opex" in ql) or ("operating expense" in ql) or re.search(r"\bexpenses\b", ql)
#         is_income = re.search(r"\b(total\s+income|operating\s+income)\b", ql) is not None
#         # Detect if this is an annual ask (not a specific quarter)
#         annual_ask = not re.search(r"\b[1-4]Q\d{2}\b", query, re.I)

#         # If the query mentions a specific period, try to narrow the search window.
#         desired_periods = _desired_periods_from_query(query)
#         windows = []
#         if desired_periods:
#             for (yy, qq) in desired_periods:
#                 if yy and qq:
#                     tag = fr"{qq}q{str(yy)[-2:]}"
#                 elif yy:
#                     tag = fr"fy{yy}"
#                 else:
#                     tag = None
#                 if tag:
#                     m = re.search(tag, text, flags=re.I)
#                     if m:
#                         # take a sentence-sized window around the tag
#                         start = max(0, text.rfind(".", 0, m.start()))
#                         end = text.find(".", m.end())
#                         if end == -1:
#                             end = len(text)
#                         windows.append(text[start:end])
#         if not windows:
#             # fallback: whole text
#             windows = [text]

#         # Query-aware patterns
#         # 1) NIM ‚Üí percentages, prioritizing text near "net interest margin"
#         if is_nim:
#             # 1) Strongly anchored: look for "‚Ä¶margin was/to/at/of N.NN%"
#             for win in windows:
#                 m = re.search(
#                     r"net\s+interest\s+margin[^%]{0,120}?(?:was|to|at|of)\s*([0-9]+(?:\.[0-9]+)?)\s*%",
#                     win, flags=re.I | re.S
#                 )
#                 if m:
#                     v = m.group(1)
#                     if _plausible_nim_value(v):
#                         return _clean_number(v)

#             # 2) Vision-summary phrasing: "Group/Commercial Book Net Interest Margin was 2.13%."
#             for win in windows:
#                 m = re.search(
#                     r"(?:group|commercial(?:\s*book)?)\s*net\s+interest\s+margin.*?(?:was|to|at|of)\s*([0-9]+(?:\.[0-9]+)?)\s*%",
#                     win, flags=re.I | re.S
#                 )
#                 if m:
#                     v = m.group(1)
#                     if _plausible_nim_value(v):
#                         return _clean_number(v)

#             # 3) Anchored fallback: only if NIM is explicitly mentioned; pick the nearest plausible %
#             for win in windows:
#                 m_phrase = re.search(r"net\s+interest\s+margin|\bnim\b", win, flags=re.I)
#                 if not m_phrase:
#                     continue
#                 best = None
#                 best_dist = 1e9
#                 for p in re.finditer(r"([0-9]+(?:\.[0-9]+)?)\s*%", win):
#                     try:
#                         val = float(p.group(1))
#                     except Exception:
#                         continue
#                     if not _plausible_nim_value(val):
#                         continue
#                     dist = abs(p.start() - m_phrase.start())
#                     if dist < best_dist:
#                         best_dist = dist
#                         best = p.group(1)
#                 if best:
#                     return _clean_number(best)

#             # Do NOT fall back to non-% numbers for NIM; better to return None than a wrong value
#             return None

#         # 2) Opex / Operating Expenses ‚Üí numbers followed by a 'million/bn' unit
#         if is_opex:
#             # --- FAST PATH (Vision summary exact sentence for annual Opex) ---
#             # Prefer the Vision-summary wording:
#             # "For FY2024, total Operating Expenses (Opex) were 8895 million."
#             try:
#                 desired_periods_fp = _desired_periods_from_query(query)
#             except Exception:
#                 desired_periods_fp = []
#             target_years_fp = [yy for (yy, qq) in desired_periods_fp if yy and (qq is None)]
#             if target_years_fp:
#                 for yy in target_years_fp:
#                     m_fp = re.search(
#                         rf"For\s*FY{yy}\s*,?\s*total\s+Operating\s+Expenses\s*\(Opex\)\s*were\s*([0-9][\d,]*(?:\.[0-9]+)?)\s*(million|mn|m|bn|billion)\b",
#                         text,
#                         flags=re.I
#                     )
#                     if m_fp:
#                         val_fp = _clean_number(m_fp.group(1)) or None
#                         unit_fp = (m_fp.group(2) or "").lower()
#                         if val_fp:
#                             try:
#                                 v_fp = float(val_fp)
#                                 if unit_fp in ("bn", "billion", "b"):
#                                     v_fp *= 1000.0
#                                 # Annual Opex sanity range in $m for DBS scale
#                                 if 2000.0 <= v_fp <= 15000.0:
#                                     return ("%g" % v_fp)
#                             except Exception:
#                                 pass
#             # Vision-summary phrasing: "For FY2024, total Operating Expenses (Opex) were 8895 million."
#             for win in windows:
#                 m = re.search(
#                     r"operating\s+expenses.*?(?:were|:)?\s*([0-9][\d,]*(?:\.[0-9]+)?)\s*(million|mn|m|bn|billion)\b",
#                     win,
#                     flags=re.I | re.S,
#                 )
#                 if m:
#                     val = _clean_number(m.group(1))
#                     unit = (m.group(2) or "").lower()
#                     if val:
#                         try:
#                             v = float(val)
#                             # Normalise units to millions
#                             if unit in ("bn", "billion", "b"):
#                                 v *= 1000.0
#                             # Annual asks must be a sensible magnitude in $m (reject too-small or absurdly large)
#                             if annual_ask and unit in ("million", "mn", "m", "bn", "billion", "b") and not (2000 <= v <= 15000):
#                                 val = None
#                             else:
#                                 val = ("%g" % v)
#                         except Exception:
#                             pass
#                     if val:
#                         return val

#             # Generic '... expenses ... 8,895 million' even without "operating"
#             for win in windows:
#                 m = re.search(
#                     r"\bexpenses\b.*?(?:were|:)?\s*([0-9][\d,]*(?:\.[0-9]+)?)\s*(million|mn|m|bn|billion)\b",
#                     win,
#                     flags=re.I | re.S,
#                 )
#                 if m:
#                     val = _clean_number(m.group(1))
#                     unit = (m.group(2) or "").lower()
#                     if val:
#                         try:
#                             v = float(val)
#                             if unit in ("bn", "billion", "b"):
#                                 v *= 1000.0
#                             # Annual asks must be a sensible magnitude in $m (reject too-small or absurdly large)
#                             if annual_ask and unit in ("million", "mn", "m", "bn", "billion", "b") and not (2000 <= v <= 15000):
#                                 val = None
#                             else:
#                                 val = ("%g" % v)
#                         except Exception:
#                             pass
#                     if val:
#                         return val

#             # Table/markdown style: headers carry units like "($m)" or "S$ m", and the value is a 4+ digit number
#             for win in windows:
#                 # e.g., "| Operating expenses | 8,895 |" or "Operating expenses 8,895"
#                 m = re.search(
#                     r"(?:operating\s+expenses|^\s*\|\s*operating\s+expenses.*?)\D([0-9][\d,]{3,})\b",
#                     win, flags=re.I | re.S | re.M
#                 )
#                 if m:
#                     val = _clean_number(m.group(1))
#                     if val:
#                         return val
#             # If the surrounding text mentions monetary units like '($m)' or 'S$ m', prefer 4+ digit numbers anywhere in the window
#             for win in windows:
#                 if re.search(r"\(\$?\s*m\)|s\$\s*m|\(\$m\)|\(\$ million\)", win, flags=re.I):
#                     m = re.search(r"\b([0-9][\d,]{3,})\b", win)
#                     if m:
#                         val = _clean_number(m.group(1))
#                         if val:
#                             return val

#             # As a last resort, only if the window itself mentions expenses/opex AND a money unit cue is present.
#             # This avoids accidentally picking unrelated large numbers from generic prose (e.g., CFO narrative pages).
#             for win in windows:
#                 if re.search(r"\b(operating\s+)?expenses?\b|\bopex\b", win, flags=re.I):
#                     # Require a nearby money unit cue to reduce false positives.
#                     if not re.search(r"\(\$?\s*m\)|s\$\s*m|\(\$m\)|\bmillion\b|\bmn\b|\bbn\b|\bbillion\b", win, flags=re.I):
#                         continue
#                     m = re.search(r"\b([0-9][\d,]{3,})\b", win)
#                     if m:
#                         val = _clean_number(m.group(1))
#                         if val:
#                             return val
                        
#         # 3) Total/Operating Income ‚Üí require the phrase and a plausible 4+ digit value
#         if is_income:
#             # Prefer explicit "Total income ... NNNN"
#             for win in windows:
#                 if re.search(r"\btotal\s+income\b", win, flags=re.I):
#                     m = re.search(r"\btotal\s+income\b[^0-9]{0,60}([0-9][\d,]{3,})", win, flags=re.I)
#                     if m:
#                         val = _clean_number(m.group(1))
#                         if val:
#                             try:
#                                 v = float(val)
#                                 if 1000.0 <= v <= 50000.0:  # DBS scale in $m
#                                     return val
#                             except Exception:
#                                 pass
#             # Vision-summary phrasing: "... Total income was 22297."
#             for win in windows:
#                 m = re.search(r"\btotal\s+income\b\s*(?:was|:)?\s*([0-9][\d,]{3,})", win, flags=re.I)
#                 if m:
#                     val = _clean_number(m.group(1))
#                     if val:
#                         try:
#                             v = float(val)
#                             if 1000.0 <= v <= 50000.0:
#                                 return val
#                         except Exception:
#                             pass
#             # Markdown/table row style
#             for win in windows:
#                 m = re.search(r"(?:^\s*\|\s*)?total\s+income(?:\s*\|)?\s*([0-9][\d,]{3,})\b", win, flags=re.I | re.M)
#                 if m:
#                     val = _clean_number(m.group(1))
#                     if val:
#                         return val
#             # If the window says "$m" / "In $ millions", allow a nearby 4+ digit number
#             for win in windows:
#                 if re.search(r"\(\$?\s*m\)|in\s*\$?\s*millions", win, flags=re.I):
#                     m = re.search(r"\b([0-9][\d,]{3,})\b", win)
#                     if m:
#                         val = _clean_number(m.group(1))
#                         if val:
#                             try:
#                                 v = float(val)
#                                 if 1000.0 <= v <= 50000.0:
#                                     return val
#                             except Exception:
#                                 pass
#             # Avoid grabbing random numbers (like '31' from dates)
#             return None

#         # 4) Generic fallback: only for non-domain queries. For NIM/Opex, avoid bogus picks.
#         if not (is_nim or is_opex or is_income):
#             for win in windows:
#                 m = re.search(r"(-?\$?S?\s*[0-9][\d,]*(?:\.[0-9]+)?)", win)
#                 if m:
#                     val = re.sub(r"[S$\s]", "", m.group(1))
#                     val = _clean_number(val)
#                     if val:
#                         return val

#         return None

#     # --- Hard preference for Vision hits when Opex asks for a specific FY ---
#     try:
#         ql_pref = query.lower()
#         is_opex_pref = ("opex" in ql_pref) or ("operating expense" in ql_pref) or re.search(r"\bexpenses\b", ql_pref)
#         desired_periods_pref = _desired_periods_from_query(query)
#         explicit_fy_years = [yy for (yy, qq) in desired_periods_pref if yy and (qq is None)]
#         if is_opex_pref and explicit_fy_years:
#             yy = explicit_fy_years[0]
#             vision_for_year = [h for h in hits if "vision_summary" in str(h.get("section_hint") or "").lower() and h.get("year") == yy]
#             if vision_for_year:
#                 # Put those Vision hits first to be tried before any prose/table chunks
#                 rest = [h for h in hits if h not in vision_for_year]
#                 hits = vision_for_year + rest
#     except Exception:
#         pass

#     # Local rerank of hits to prefer structured/vision chunks for domain queries
#     ql = query.lower()
#     is_nim = ("nim" in ql) or ("net interest margin" in ql)
#     is_opex = ("opex" in ql) or ("operating expense" in ql) or re.search(r"\bexpenses\b", ql)
#     is_income = re.search(r"\b(total\s+income|operating\s+income)\b", ql) is not None

#     def _local_hit_score(h: dict) -> float:
#         sh = str(h.get("section_hint") or "").lower()
#         file_l = str(h.get("file") or "").lower()
#         s = 0.0

#         # Pull the actual text for content checks
#         pos = _pos_of_docid(h.get("doc_id", ""))
#         text_l = str(texts[pos]).lower() if pos is not None else ""

#         mentions_nim = ("net interest margin" in text_l) or re.search(r"\bnim\b", text_l) is not None
#         mentions_expenses = ("operating expenses" in text_l) or re.search(r"\bexpenses\b", text_l) is not None
#         mentions_total_income = re.search(r"\btotal\s+income\b", text_l) is not None
#         has_money_units = re.search(r"\(\$?\s*m\)|s\$\s*m|\(\$m\)|\bmillion\b|\bmn\b|\bbn\b|\bbillion\b", text_l, flags=re.I) is not None
#         mentions_percent = "%" in text_l

#         if "vision_summary" in sh:
#             s += 500.0
#         if sh.startswith("table_p"):
#             s += 30.0

#         # For NIM, demand the NIM phrase be present; otherwise heavily penalize
#         if is_nim:
#             if h.get("quarter") is not None:
#                 s += 20.0
#             if mentions_nim:
#                 s += 20.0
#                 if mentions_percent:
#                     s += 10.0
#             else:
#                 s -= 80.0  # do not allow non-NIM tables to outrank true NIM chunks

#         # For Opex-like asks, require expenses to be mentioned; favor money units
#         if is_opex:
#             if mentions_expenses:
#                 s += 20.0
#                 if has_money_units:
#                     s += 8.0
#             else:
#                 s -= 60.0  # push away tables/pages without expenses language
#             # Prefer structured sources over plain prose when scores tie
#             if sh == "prose":
#                 s -= 5.0
                
#         if is_income:
#             if "vision_summary" in sh:
#                 s += 60.0
#             if sh.startswith("table_p"):
#                 s += 25.0
#             if mentions_total_income:
#                 s += 20.0
#             else:
#                 s -= 40.0
#             if re.search(r"\(\$?\s*m\)|in\s*\$?\s*millions", text_l, flags=re.I):
#                 s += 6.0

#         # Deprioritize press/trading noise for numeric extractions
#         if re.search(r"press[_\s-]?statement|trading[_\s-]?update", file_l):
#             s -= 30.0

#         # fall back to hybrid score to break ties
#         s += float(h.get("score") or 0.0) * 0.01
#         return s

#     if is_nim or is_opex:
#         # Order: vision ‚Üí tables ‚Üí other, each block locally reranked
#         hits = (
#             sorted(vision_hits, key=_local_hit_score, reverse=True) +
#             sorted(table_hits,  key=_local_hit_score, reverse=True) +
#             sorted(other_hits,  key=_local_hit_score, reverse=True)
#         )

#     # Snapshot of the current hit ordering (useful for debugging/reuse in nested helpers)
#     _hits_snapshot = hits[:]

#     # Try the top-k hits in order until we successfully extract a plausible value
#     last_citation = None
#     for hit in hits:
#         pos = _pos_of_docid(hit["doc_id"])
#         if pos is None:
#             continue

#         text_content = str(texts[pos])
#         citation = f"Source: {format_citation(hit)}"
#         last_citation = citation

#         value = _extract_value(text_content, query)
#         if value is not None:
#             return f"Value: {value}, {citation}"

#     # If we got here, extraction failed for all hits
#     return f"Error: No numerical value found in the relevant document chunk. {last_citation or ''}"
  

# # --- Helper: Deterministic Opex 3-year baseline extractor ---

# # def answer_opex_3y_baseline() -> str:
# #     """
# #     Deterministic simple baseline for:
# #     'Show Operating Expenses for the last 3 fiscal years.'
# #     Uses the KB to pick the latest 3 FYs present, then calls table_extraction per FY.
# #     """
# #     # 1) find latest 3 FYs available in KB (prefer annual docs)
# #     df = kb.copy()
# #     df["y"] = pd.to_numeric(df["year"], errors="coerce")
# #     ydf = df[df["quarter"].isna()].dropna(subset=["y"]).sort_values("y", ascending=False)
# #     if ydf.empty:
# #         ydf = df.dropna(subset=["y"]).sort_values("y", ascending=False)
# #     years = [int(y) for y in ydf["y"].drop_duplicates().head(3)]
# #     if not years:
# #         return "No fiscal years found in KB."

# #     # 2) extract Opex per FY using the robust extractor
# #     rows, cites = [], []
# #     for y in years:
# #         r = tool_table_extraction(f"Operating expenses for fiscal year {y}")
# #         # Expected: "Value: 8895, Source: <citation>" or "Error: ..."
# #         m = re.search(r"Value:\s*([0-9][\d\.]*)\s*,\s*Source:\s*(.*)", r)
# #         if m:
# #             val = m.group(1)
# #             src = m.group(2)
# #             rows.append((y, val))
# #             cites.append(src)
# #         else:
# #             rows.append((y, "‚Äî"))
# #             cites.append(r)

# #     # 3) render a tiny table with YoY% and citations
# #     # rows is a list of tuples: [(year, value_str_or_dash), ...]
# #     rows.sort(key=lambda t: t[0], reverse=True)  # ensure FY2024, FY2023, FY2022 order

# #     def _fmt_m(x: str) -> str:
# #         try:
# #             return f"{float(x):,.0f}"
# #         except Exception:
# #             return x  # return as-is if not a number (e.g., "‚Äî")

# #     out = [
# #         "Opex (S$ m) ‚Äî last 3 fiscal years:",
# #         "Year   | Opex (S$ m) | YoY %",
# #         "-------|-------------|------",
# #     ]

# #     for i, (yy, vv) in enumerate(rows):
# #         yoy = ""
# #         if i > 0 and rows[i-1][1] not in ("‚Äî", "", None) and vv not in ("‚Äî", "", None):
# #             try:
# #                 cur = float(vv)
# #                 prev = float(rows[i-1][1])
# #                 yoy = f"{((cur - prev) / prev) * 100:,.1f}%"
# #             except Exception:
# #                 yoy = ""
# #         out.append(f"{yy} | {_fmt_m(vv) if vv != '‚Äî' else vv} | {yoy}")

# #     out.append("\nCitations:")
# #     seen = set()
# #     for c in cites:
# #         if c not in seen:
# #             seen.add(c)
# #             out.append(f"- {c}")
# #         if len(seen) >= 3:
# #             break
# #     return "\n".join(out)
# def tool_nim_series(last_n: int = 5, variant: str = "group") -> str:
#     """
#     Extract the last N quarters of Net Interest Margin (Group or Commercial Book).
#     Retrieval: FAISS (semantic) + BM25 (keyword) hybrid via hybrid_search().
#     Parsing priority: Vision summaries (nim_analysis-style lines), then structured tables,
#     then generic 'quarter ‚Üí %' mentions anchored to NIM.
#     """
#     # --- 1) Gather a broader candidate pool (multiple queries) ---
#     queries = [
#         "Net interest margin (%)",
#         "NIM (%)",
#         "Group Net Interest Margin quarterly",
#         "Commercial book Net Interest Margin (%)",
#         "Net interest margin group commercial"
#     ]
#     hits: List[Dict[str, Any]] = []
#     seen_doc_ids = set()
#     for q in queries:
#         for h in hybrid_search(q, top_k=40):
#             did = h.get("doc_id")
#             if did not in seen_doc_ids:
#                 seen_doc_ids.add(did)
#                 hits.append(h)

#     # Always include any vision_summary chunks (often hold clean 'For 2Q24, Group NIM was 2.13%' lines)
#     try:
#         sh_series = kb["section_hint"].astype(str).str.contains("vision_summary", case=False, na=False)
#         vis_idxs = np.flatnonzero(sh_series.to_numpy())
#         base_score = (min([float(h.get("score") or 0.0) for h in hits]) - 1.0) if hits else 0.0
#         for idx in vis_idxs[:20]:
#             row = kb.iloc[idx]
#             did = row.doc_id
#             if did in seen_doc_ids:
#                 continue
#             seen_doc_ids.add(did)
#             hits.append({
#                 "doc_id": row.doc_id,
#                 "file": row.file,
#                 "page": int(row.page) if pd.notna(row.page) else None,
#                 "year": int(row.year) if pd.notna(row.year) else None,
#                 "quarter": int(row.quarter) if pd.notna(row.quarter) else None,
#                 "section_hint": row.section_hint,
#                 "score": base_score
#             })
#     except Exception:
#         pass

#     # --- Helper: fetch raw text for a hit ---
#     def _pos_of_docid(did: str) -> Optional[int]:
#         mask = (kb["doc_id"] == did).to_numpy()
#         idxs = np.flatnonzero(mask)
#         return int(idxs[0]) if idxs.size else None

#     # --- Helper: plausibility filter for NIM values (in %) ---
#     def _nim_ok(x: float) -> bool:
#         try:
#             xf = float(x)
#         except Exception:
#             return False
#         return 0.5 <= xf <= 3.5

#     # --- 2) Parse points: map ("2Q25","group|commercial") ‚Üí value ---
#     from typing import Tuple
#     points: Dict[Tuple[str, str], float] = {}

#     # Order candidates: vision ‚Üí tables ‚Üí other
#     vision_hits = [h for h in hits if "vision_summary" in str(h.get("section_hint") or "").lower()]
#     table_hits  = [h for h in hits if str(h.get("section_hint") or "").lower().startswith("table_p")]
#     other_hits  = [h for h in hits if h not in vision_hits and h not in table_hits]
#     ordered = vision_hits + table_hits + other_hits

#     # --- 3) Parsing routines ---
#     re_qtr  = re.compile(r"\b([1-4]Q\d{2})\b", flags=re.I)
#     re_pct  = re.compile(r"([0-9]+(?:\.[0-9]+)?)\s*%")
#     re_num  = re.compile(r"([0-9]+(?:\.[0-9]+)?)")  # for tables where % sign is omitted
#     re_nim_phrase = re.compile(r"net\s*interest\s*margin|\bnim\b", flags=re.I)

#     def _maybe_add(qlabel: str, who: str, val: float):
#         who_norm = "commercial" if "commercial" in who.lower() else "group"
#         key = (qlabel.upper(), who_norm)
#         if _nim_ok(val) and key not in points:
#             points[key] = float(val)

#     for h in ordered:
#         pos = _pos_of_docid(h.get("doc_id", ""))
#         if pos is None:
#             continue
#         text = str(texts[pos])

#         # Skip chunks that don't obviously mention NIM to avoid 5% from unrelated places
#         if not re_nim_phrase.search(text):
#             continue

#         # (A) Vision-style lines from g1.format_vision_json_to_text
#         for m in re.finditer(
#             r"For\s+([1-4]Q\d{2}),\s+the\s+(Group|Commercial(?:\s*book)?)\s+Net\s+Interest\s+Margin.*?([0-9]+(?:\.[0-9]+)?)\s*%",
#             text, flags=re.I
#         ):
#             qlabel, who, val = m.group(1), m.group(2), float(m.group(3))
#             _maybe_add(qlabel, who, val)

#         # (B) Markdown table row like: "| Net interest margin (%) | 2Q25 | 1Q25 | ...\n| ... | 2.61 | 2.70 | ..."
#         lines = text.splitlines()
#         header_quarters: Optional[List[str]] = None
#         for li, line in enumerate(lines):
#             # Update current header_quarters if this line looks like a quarter header row
#             q_in_line = re_qtr.findall(line.upper())
#             if len(q_in_line) >= 2:
#                 header_quarters = q_in_line

#             if re.search(r"net\s*interest\s*margin|\bnim\b", line, flags=re.I):
#                 # 1) Same-line values (e.g., '| Net interest margin (%) | 2.61 | 2.70 | ...')
#                 vals_inline = [float(x) for x in re_num.findall(line) if _nim_ok(x)]
#                 if header_quarters and len(vals_inline) >= len(header_quarters):
#                     for ql, v in zip(header_quarters, vals_inline[:len(header_quarters)]):
#                         _maybe_add(ql, "group", float(v))

#                 # 2) Next-line values (common in markdown tables: headers then a metrics row on the next line)
#                 if li + 1 < len(lines):
#                     nxt = lines[li + 1]
#                     vals_next = [float(x) for x in re_num.findall(nxt) if _nim_ok(x)]
#                     if header_quarters and len(vals_next) >= len(header_quarters):
#                         for ql, v in zip(header_quarters, vals_next[:len(header_quarters)]):
#                             _maybe_add(ql, "group", float(v))

#         # (C) Generic anchored fallback:
#         # For each quarter mention, search a short window to the right for a plausible % or number.
#         # Expand the window to 160 chars to capture "‚Ä¶ 2Q25 ‚Ä¶ NIM ‚Ä¶ 2.61%".
#         for m in re.finditer(r"([1-4]Q\d{2})", text, flags=re.I):
#             span_end = min(len(text), m.end() + 160)
#             window = text[m.start():span_end]
#             if not re_nim_phrase.search(window):
#                 continue
#             m_pct = re_pct.search(window)
#             if m_pct:
#                 val = float(m_pct.group(1))
#                 if _nim_ok(val):
#                     _maybe_add(m.group(1), "group", val)
#                     continue
#             # If % sign omitted in tables, allow a plain number in plausible range
#             m_num = re_num.search(window)
#             if m_num:
#                 try:
#                     val = float(m_num.group(1))
#                 except Exception:
#                     val = None
#                 if val is not None and _nim_ok(val):
#                     _maybe_add(m.group(1), "group", val)

#     # --- 4) Keep only the requested variant & take most recent N points ---
#     series = []
#     for (qlabel, who), val in points.items():
#         if (variant == "group" and who == "group") or (variant != "group" and who != "group"):
#             qnum = int(qlabel[0])
#             yy = int(qlabel[2:])
#             year = 2000 + yy
#             series.append((year, qnum, qlabel.upper(), float(val)))

#     if not series:
#         return "Error: No NIM values found."

#     series.sort(key=lambda t: (t[0], t[1]), reverse=True)
#     take = max(1, int(last_n or 5))
#     series = series[:take]

#     formatted = ", ".join(f"{ql}: {v:.2f}%" for (_, _, ql, v) in series)
#     who_title = "Group" if variant == "group" else "Commercial Book"
#     return f"NIM ({who_title}) last {len(series)} quarters ‚Üí {formatted}"
  
# def tool_multi_document_compare(topic: str, files: list[str]) -> str:
#     results = []
#     for file_name in files:
#         hits = hybrid_search(f"{topic} in file {file_name}", top_k=2)
#         file_hits = [h for h in hits if h.get('file') == file_name]
#         if file_hits:
#             top_hit = file_hits[0]
#             citation = format_citation(top_hit)
#             text_content = texts[kb.index[kb['doc_id'] == top_hit['doc_id']][0]]
#             results.append(f"Source: [{citation}]\nContent: {text_content[:800]}")
#         else:
#             results.append(f"Source: {file_name}\nContent: No relevant information found.")
#     return "\n---\n".join(results)

# def _compile_or_repair_plan(query: str, plan: list[dict]) -> list[dict]:
#     def _has_params(step: dict) -> bool:
#         params = step.get("parameters")
#         return isinstance(params, dict) and any(v not in (None, "", []) for v in params.values())

#     if plan and all(_has_params(s) for s in plan):
#         return plan

#     qtype = _classify_query(query)
#     want_years  = _detect_last_n_years(query)
#     want_quarts = _detect_last_n_quarters(query)
    
#     df = kb.copy()
#     df["y"] = pd.to_numeric(df["year"], errors="coerce")
#     df["q"] = pd.to_numeric(df["quarter"], errors="coerce")
#     steps: list[dict] = []

#     if qtype == "nim":
#         n = want_quarts or 5
#         steps.append({
#             "step": f"Extract last {n} quarters of NIM (group)",
#             "tool": "nim_series",
#             "parameters": {"last_n": n, "variant": "group"},
#             "store_as": f"nim_series_last_{n}"
#         })
#         return steps

#     if qtype == "opex":
#         # If the user asked for a specific fiscal year (e.g., "FY2024" or "fiscal year 2024"),
#         # do a single extraction for that year and STOP. Do not add YoY steps.
#         periods = _desired_periods_from_query(query)
#         explicit_fy = [y for (y, q) in periods if y and (q is None)]
#         if explicit_fy:
#             y = int(explicit_fy[0])
#             steps.append({
#                 "step": f"Extract Opex for FY{y}",
#                 "tool": "table_extraction",
#                 "parameters": {"query": f"Operating expenses for fiscal year {y}"},
#                 "store_as": f"opex_fy{y}"
#             })
#             return steps

#         # Otherwise, assume a multi‚Äëyear ask. Default to the last 3 fiscal years and include a YoY calc.
#         n = want_years or 3
#         df_local = kb.copy()
#         df_local["y"] = pd.to_numeric(df_local["year"], errors="coerce")
#         df_local["q"] = pd.to_numeric(df_local["quarter"], errors="coerce")

#         ydf = df_local[df_local["q"].isna()].dropna(subset=["y"]).sort_values("y", ascending=False)
#         if ydf.empty:
#             ydf = df_local.dropna(subset=["y"]).sort_values("y", ascending=False)

#         years = [int(y) for y in ydf["y"].drop_duplicates().head(n)]
#         for y in years:
#             steps.append({
#                 "step": f"Extract Opex for FY{y}",
#                 "tool": "table_extraction",
#                 "parameters": {"query": f"Operating expenses for fiscal year {y}"},
#                 "store_as": f"opex_fy{y}"
#             })
#         if len(years) >= 2:
#             y0, y1 = years[0], years[1]
#             steps.append({
#                 "step": f"Compute YoY % change FY{y0} vs FY{y1}",
#                 "tool": "calculator",
#                 "parameters": {"expression": f"((${{opex_fy{y0}}} - ${{opex_fy{y1}}}) / ${{opex_fy{y1}}}) * 100"},
#                 "store_as": f"opex_yoy_{y0}_{y1}"
#             })
#         return steps
    
#     if qtype == "oer":
#         n = want_years or 3
#         ydf = df[df["q"].isna()].dropna(subset=["y"]).sort_values("y", ascending=False)
#         if ydf.empty: ydf = df.dropna(subset=["y"]).sort_values("y", ascending=False)
#         years = [int(y) for y in ydf["y"].drop_duplicates().head(n)]
#         for y in years:
#             steps.append({ "step": f"Extract Opex for FY{y}", "tool": "table_extraction", "parameters": {"query": f"Operating expenses for fiscal year {y}"}, "store_as": f"opex_fy{y}"})
#             steps.append({ "step": f"Extract Operating Income for FY{y}", "tool": "table_extraction", "parameters": {"query": f"Total income for fiscal year {y}"}, "store_as": f"income_fy{y}"})
#             steps.append({ "step": f"Compute OER for FY{y}", "tool": "calculator", "parameters": {"expression": f"(${{opex_fy{y}}} / ${{income_fy{y}}}) * 100"}, "store_as": f"oer_fy{y}"})
#         return steps
    
#     return [{"step": "Extract relevant figure", "tool": "table_extraction", "parameters": {"query": query}, "store_as": "value_1"}]

# def answer_with_agent(query: str, dry_run: bool = False) -> Dict[str, Any]:
#     _ensure_init()
#     execution_log = []
    
#     planning_prompt = f"""You are a financial analyst agent. Create a JSON plan to answer the user's query.
# Tools Available:
# - `table_extraction(query: str)`: Finds a single reported data point.
# - `calculator(expression: str)`: Calculates a math expression.
# User Query: "{query}"
# Return ONLY a valid JSON object with a "plan" key."""
#     if VERBOSE: print("[Agent] Step 1: Generating execution plan...")
    
#     plan_response = _call_llm(planning_prompt, dry_run)
#     plan = []
    
#     if dry_run:
#         plan = _compile_or_repair_plan(query, [])
#         answer = f"DRY RUN MODE: The agent generated the following plan and stopped before execution.\n\n{json.dumps(plan, indent=2)}"
#         return {"answer": answer, "hits": [], "execution_log": [{"step": "Planning", "plan": plan}]}

#     try:
#         json_match = re.search(r'```json\s*(\{.*?\})\s*```', plan_response, re.DOTALL)
#         plan_str = json_match.group(1) if json_match else plan_response
#         plan = json.loads(plan_str)["plan"]
#         execution_log.append({"step": "Planning", "plan": plan})
#         if VERBOSE: print("[Agent] Plan generated successfully.")
#     except Exception:
#         if VERBOSE: print("[Agent] LLM failed to generate valid plan. Using deterministic repair.")
#         plan = []

#     plan = _compile_or_repair_plan(query, plan)
#     if not execution_log or "repaired_plan" not in execution_log[0]:
#         execution_log.insert(0, {"step": "PlanRepair", "repaired_plan": plan})
    
#     if VERBOSE: print("[Agent] Step 2: Executing plan...")
#     tool_mapping = {
#         "calculator": tool_calculator,
#         "table_extraction": tool_table_extraction,
#         "multi_document_compare": tool_multi_document_compare,
#         "nim_series": tool_nim_series
#     }
#     execution_state = {}
    
#     for i, step in enumerate(plan):
#         tool = step.get("tool")
#         params = step.get("parameters", {}).copy() # Use copy to avoid modifying plan dict
#         store_as = step.get("store_as")

#         for p_name, p_value in params.items():
#             if isinstance(p_value, str):
#                 for var_name, var_value in execution_state.items():
#                     p_value = p_value.replace(f"${{{var_name}}}", str(var_value))
#             params[p_name] = p_value
        
#         try:
#             if tool not in tool_mapping:
#                 raise ValueError(f"Tool '{tool}' not found.")
            
#             result = tool_mapping[tool](**params)
#             execution_log.append({"step": f"Execution {i+1}", "tool_call": f"{tool}({params})", "result": result})
            
#             if store_as:
#                 val_for_state = result # Default to full result
#                 m_calc = re.search(r'Result:\s*([-\d\.]+e?[-\d]*)', result, re.I)
#                 if m_calc: val_for_state = m_calc.group(1)
                
#                 m_val = re.search(r'Value:\s*([^,]+)', result, re.I)
#                 if m_val: val_for_state = m_val.group(1).strip()

#                 execution_state[store_as] = val_for_state

#         except Exception as e:
#             execution_log.append({"step": f"Execution {i+1}", "tool_call": f"{tool}({params})", "error": traceback.format_exc()})

#     if VERBOSE: print("[Agent] Step 3: Synthesizing final answer...")
#     synthesis_prompt = f"""You are Agent CFO. Provide a final answer to the user's query based ONLY on the provided Tool Execution Log.
# User Query: "{query}"
# Tool Execution Log:
# {json.dumps(execution_log, indent=2)}
# Final Answer:"""
#     final_answer = _call_llm(synthesis_prompt)
    
#     return {"answer": final_answer, "hits": [], "execution_log": execution_log}

# def get_logs():
#     return instr.df()

# # if __name__ == "__main__":
# #     import sys, subprocess, importlib, os
# #     os.environ["TOKENIZERS_PARALLELISM"] = "false"

# #     # Auto-install missing deps
# #     def _pip(pkg):
# #         try:
# #             importlib.import_module(pkg)
# #         except Exception:
# #             subprocess.check_call([sys.executable, "-m", "pip", "install", pkg])

# #     for p in ["openai", "rank_bm25", "faiss-cpu"]:
# #         _pip(p)

# #     # Groq config (read from env; do NOT hardcode secrets)
# #     os.environ.setdefault("LLM_PROVIDER", "groq")
# #     os.environ.setdefault("GROQ_MODEL", "openai/gpt-oss-20b")
# #     if not os.getenv("GROQ_API_KEY"):
# #         print("‚ö†Ô∏è  GROQ_API_KEY not set. Please set it in your environment before running.")
    
# #     # Initialize Stage-2 and run the deterministic Opex baseline
# #     init_stage2("data")
# #     query = "Show Operating Expenses for the last 3 fiscal years"
# #     print(f"‚Üí Query: {query}\n")
# #     print(answer_opex_3y_baseline())

#     # from __future__ import annotations

# """
# Stage3.py ‚Äî Benchmark Runner (Stage 3)

# Runs the 3 standardized queries for both the baseline and agentic pipelines,
# times them, saves JSON/Markdown reports, and prints prose answers with citations.

# Artifacts written to OUT_DIR (default: data/):
#   - bench_results_baseline.json / bench_results_agent.json
#   - bench_report_baseline.md / bench_report_agent.md
# """
# import os, json, time, inspect
# from typing import List, Dict, Any

# import pandas as pd

# # Explicitly import Stage-2 entrypoints so we don't rely on globals
# # from g2 import init_stage2, answer_with_llm_baseline as answer_with_llm, answer_with_agent

# OUT_DIR = os.environ.get("AGENT_CFO_OUT_DIR", "data")

# # --- Standardized queries (exact spec) ---
# QUERIES: List[str] = [
#     # 1) NIM trend over last 5 quarters
#     "Report the Gross Margin (or Net Interest Margin, if a bank) over the last 5 quarters, with values."
#     # # 2) Opex YoY table only (absolute & % change)
#     # "Show Operating Expenses for the last 3 fiscal years, year-on-year comparison.",
#     # # 3) Operating Efficiency Ratio (Opex √∑ Operating Income) with working
#     # "Calculate the Operating Efficiency Ratio (Opex √∑ Operating Income) for the last 3 fiscal years, showing the working."
# ]


# # --- Helper functions for answer call and output normalization ---
# def _call_answer(func, query: str, dry_run: bool):
#     """Call answer function with optional dry_run if supported."""
#     try:
#         params = inspect.signature(func).parameters
#     except Exception:
#         params = {}
#     kwargs = {}
#     if 'dry_run' in params:
#         kwargs['dry_run'] = dry_run
#     return func(query, **kwargs)

# def _normalize_out(res) -> Dict[str, Any]:
#     """Coerce answer result to a dict with keys: answer, hits, execution_log."""
#     if isinstance(res, str):
#         return {"answer": res, "hits": [], "execution_log": None}
#     if isinstance(res, dict):
#         ans = res.get("answer") or res.get("Answer") or str(res)
#         hits = res.get("hits") or res.get("Hits") or []
#         log  = res.get("execution_log") or res.get("ExecutionLog")
#         return {"answer": ans, "hits": hits, "execution_log": log}
#     return {"answer": str(res), "hits": [], "execution_log": None}


# def _format_hits(hits: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
#     """Helper to format citation hits for JSON output."""
#     out = []
#     if not hits: return out
#     for h in hits:
#         out.append({
#             "file": h.get("file"),
#             "year": h.get("year"),
#             "quarter": h.get("quarter"),
#             "page": h.get("page"),
#             "section_hint": h.get("section_hint"),
#         })
#     return out



# def run_benchmark(
#     print_prose: bool = True,
#     use_agent: bool = False,
#     out_dir: str = OUT_DIR,
#     dry_run: bool = False  # <-- NEW TOGGLE
# ) -> Dict[str, Any]:
#     """
#     Runs the benchmark for either the baseline RAG or the agentic pipeline.
    
#     Args:
#         print_prose: Whether to print results to the console.
#         use_agent: If True, uses answer_with_agent. If False, uses answer_with_llm.
#         out_dir: The directory to save report files.
#         dry_run: If True, prints prompts instead of calling the LLM API.
#     """
#     # Guard: this module is intentionally NOT importing Stage 2.
#     # The caller/notebook must `import g2` first so that the following names
#     # are available in the global namespace.
#     if use_agent and 'answer_with_agent' not in globals():
#         raise RuntimeError("answer_with_agent is not defined. Import Stage 2 (g2) in the caller before running Stage 3.")
#     if not use_agent and 'answer_with_llm' not in globals():
#         raise RuntimeError("answer_with_llm is not defined. Import Stage 2 (g2) in the caller before running Stage 3.")

#     os.makedirs(out_dir, exist_ok=True)
    
#     if use_agent:
#         mode_name = "agent"
#         answer_func = answer_with_agent
#         print("\n" + "="*25 + f" RUNNING AGENT BENCHMARK " + "="*25)
#     else:
#         mode_name = "baseline"
#         answer_func = answer_with_llm
#         print("\n" + "="*24 + f" RUNNING BASELINE BENCHMARK " + "="*24)
    
#     if dry_run:
#         print("--- üî¨ DRY RUN MODE IS ON ---")

#     json_path = os.path.join(out_dir, f"bench_results_{mode_name}.json")
#     md_path = os.path.join(out_dir, f"bench_report_{mode_name}.md")

#     results: List[Dict[str, Any]] = []
#     latency_rows = []

#     for q in QUERIES:
#         t0 = time.perf_counter()
#         raw = _call_answer(answer_func, q, dry_run=dry_run)
#         out = _normalize_out(raw)
#         lat_ms = round((time.perf_counter() - t0) * 1000.0, 2)

#         if print_prose:
#             print(f"\n=== Question ===\n{q}")
#             print("\n--- Answer ---\n")
#             print(str(out["answer"]).strip())
#             if out.get("hits"):
#                 print("\n--- Citations (top ctx) ---")
#                 for h in _format_hits(out.get("hits", [])):
#                     y = f" {int(h['year'])}" if h.get('year') is not None else ""
#                     qtr_val = h.get('quarter')
#                     qtr = f" {int(qtr_val)}Q{str(y).strip()[2:]}" if qtr_val else ""
#                     sec = f" ‚Äî {h['section_hint']}" if h.get('section_hint') else ""
#                     print(f"- {h['file']}{y}{qtr} ‚Äî p.{h['page']}{sec}")
#             print(f"\n(latency: {lat_ms} ms)")

#         results.append({
#             "query": q,
#             "answer": out.get("answer"),
#             "hits": _format_hits(out.get("hits", [])),
#             "execution_log": out.get("execution_log"),
#             "latency_ms": lat_ms,
#         })
#         latency_rows.append({"Query": q, "Latency_ms": lat_ms})

#     # Saving logic remains the same...
#     with open(json_path, "w") as f:
#         json.dump({"results": results}, f, indent=2)

#     md_lines = [f"# Agent CFO ‚Äî {mode_name.title()} Benchmark Report\n"]
#     for i, r in enumerate(results, start=1):
#         md_lines.append(f"\n---\n\n## Q{i}. {r['query']}")
#         md_lines.append("\n**Answer**\n\n" + r["answer"].strip())
#         if r.get("hits"):
#             md_lines.append("\n**Citations (top ctx)**")
#             for h in r["hits"]:
#                 y = f" {int(h['year'])}" if h.get('year') is not None else ""
#                 qtr_val = h.get('quarter')
#                 qtr = f" {int(qtr_val)}Q{str(y).strip()[2:]}" if qtr_val else ""
#                 sec = f" ‚Äî {h['section_hint']}" if h.get('section_hint') else ""
#                 md_lines.append(f"- {h['file']}{y}{qtr} ‚Äî p.{h['page']}{sec}")
#         if r.get("execution_log"):
#             md_lines.append("\n**Execution Log**\n")
#             md_lines.append("```json")
#             md_lines.append(json.dumps(r["execution_log"], indent=2))
#             md_lines.append("```")

#     with open(md_path, "w") as f:
#         f.write("\n".join(md_lines) + "\n")

#     df = pd.DataFrame(latency_rows)
#     if print_prose and not df.empty:
#         p50 = float(df['Latency_ms'].quantile(0.5))
#         p95 = float(df['Latency_ms'].quantile(0.95))
#         print(f"\n=== {mode_name.upper()} Benchmark Summary ===")
#         print(f"Saved JSON: {json_path}")
#         print(f"Saved report: {md_path}")
#         print(f"Latency p50: {p50:.1f} ms, p95: {p95:.1f} ms")

#     return {"json_path": json_path, "md_path": md_path, "summary": df}


# #########################################################################333

# #!/usr/bin/env python3
# # -*- coding: utf-8 -*-

# """
# g3x.py ‚Äî Task runner over your FAISS/Marker KB (agentic tools) + optional ONLINE LLM answers

# This runs 3 specific analyses using the tools/agent from g2x.py:

#   1) NIM trend over last 5 quarters
#      -> "Report the Gross Margin (or Net Interest Margin, if a bank) over the last 5 quarters, with values."
#   2) Operating Expenses YoY table (absolute & % change) for last 3 fiscal years
#      -> "Show Operating Expenses for the last 3 fiscal years, year-on-year comparison."
#   3) Operating Efficiency Ratio (Opex √∑ Operating Income) with working
#      -> "Calculate the Operating Efficiency Ratio (Opex √∑ Operating Income) for the last 3 fiscal years, showing the working."

# All offline. Import and run from a notebook cell:
#     from g3x import run_all
#     run_all(base="./data_marker")
# """

# from typing import Dict, List, Optional, Tuple
# import math
# import re
# import os
# from g2x import KBEnv, Agent, show_agent_result, _llm_single_call, baseline_answer_one_call, _llm_provider_info
# # Feature flag for LLM summaries (set USE_LLM_SUMMARY=0/false in env to disable)
# USE_LLM_SUMMARY = os.getenv("USE_LLM_SUMMARY", "1") not in ("0", "false", "False")
# # ONLINE flag for baseline LLM calls (set ONLINE=0/false in env to disable)
# ONLINE = os.getenv("ONLINE", "1") not in ("0", "false", "False")

# # ---------- helpers ----------

# def _llm_summary(
#     question: str,
#     agent: Agent,
#     kb: KBEnv,
#     res=None,
#     k_ctx: int = 8,
#     rows_override: Optional[List[dict]] = None
# ) -> str:
#     """One LLM call to summarize/answer using extracted tables if present, else vector contexts."""
#     lines = []
#     # Prefer table rows from override if provided, else from the result
#     rows = rows_override if rows_override is not None else []
#     if not rows and res and getattr(res, 'final', None):
#         rows = res.final.get("table_rows") or []
#     if rows:
#         lines.append("TABLE EXTRACTS:")
#         for r in rows[:2]:
#             # prefer quarters if any
#             sq = r.get("series_q") or {}
#             if sq:
#                 # sort quarters
#                 def _qkey(k):
#                     m = re.match(r"([1-4])Q(20\d{2})$", k)
#                     return (int(m.group(2)), int(m.group(1))) if m else (0,0)
#                 qkeys = sorted(sq.keys(), key=_qkey)[-5:]
#                 ser = ", ".join(f"{k}: {sq[k]}" for k in qkeys)
#                 lines.append(f"- {r['doc']} | {r['label']} | quarters(last5)={ser}")
#             else:
#                 ys = sorted((r.get("series") or {}).keys())[-3:]
#                 ser = ", ".join(f"{y}: {r['series'][y]}" for y in ys)
#                 lines.append(f"- {r['doc']} | {r['label']} | years(last3)={ser}")
#     # If nothing extracted, fall back to vector contexts
#     if not lines:
#         ctx = kb.search(question, k=k_ctx)
#         if ctx is not None and not ctx.empty:
#             lines.append("CONTEXT SNIPPETS:")
#             for _, row in ctx.head(5).iterrows():
#                 text = str(row["text"]).replace("\n", " ").strip()
#                 if len(text) > 600:
#                     text = text[:600] + "..."
#                 lines.append("- " + text)
#     # Provide page-level hints for better citations
#     if rows:
#         hint_lines = []
#         for r in rows[:4]:
#             p = r.get('page')
#             if p is not None:
#                 hint_lines.append(f"- {r.get('doc')}, page {int(p)}")
#             else:
#                 hint_lines.append(f"- {r.get('doc')}, table {r.get('table_id')} row {r.get('row_id')} (no page)")
#         if hint_lines:
#             lines.append("CITATION HINTS:")
#             lines.extend(hint_lines)
#     # Build prompt
#     context_block = "\n".join(lines) if lines else "(no structured context found)"
#     prompt = (
#         "USER QUESTION:\n" + question + "\n\n" +
#         context_block +
#         "\n\nINSTRUCTIONS:\n"
#         "- You are given STRUCTURED TABLE ROWS and/or CONTEXT SNIPPETS above.\n"
#         "- If STRUCTURED TABLE ROWS are present, you MUST use ONLY those numbers for your answer and calculations.\n"
#         "- Do NOT claim data is missing if the numbers are present in the structured rows.\n"
#         "- If the task asks for 'Operating Income' but the rows contain 'Total income' only, TREAT 'Total income' as the denominator for Operating Efficiency Ratio.\n"
#         "- If a requested period truly does not appear in the structured rows, say so explicitly and do not infer.\n"
#         "- Return a concise answer, followed by a tiny table if applicable."
#     )
#     print(f"[LLM] summary using {_llm_provider_info()}")
#     return _llm_single_call(prompt)

# # ---------- helpers ----------

# def _last_n_quarters(series_q: Dict[str, float], n: int = 5) -> List[Tuple[str, float]]:
#     if not series_q:
#         return []
#     def _qkey(k: str):
#         m = re.match(r"([1-4])Q(20\d{2})$", k)
#         if m:
#             return (int(m.group(2)), int(m.group(1)))
#         return (0, 0)
#     keys = sorted(series_q.keys(), key=_qkey)
#     last = keys[-n:]
#     return [(k, series_q[k]) for k in last]

# def _last_n_years(series: Dict[int, float], n: int = 3) -> List[Tuple[int, float]]:
#     if not series:
#         return []
#     ys = sorted(series.keys())
#     sel = ys[-n:]
#     return [(y, series[y]) for y in sel]

# def _pct(a: float, b: float) -> Optional[float]:
#     b = float(b)
#     if b == 0:
#         return None
#     return (float(a) - b) / b * 100.0

# def _union_series(rows):
#     """
#     Merge {year->value} across many table rows from different docs and
#     return (values, provenance) where provenance maps each year to a list
#     of sources that contributed that year's value:
#         provenance[year] = [{"doc":..., "table_id":..., "row_id":..., "page": ...}, ...]
#     The first non-null value encountered for a year is kept as the value.
#     """
#     values = {}
#     prov = {}
#     for r in rows or []:
#         doc = r.get("doc")
#         tid = r.get("table_id")
#         rid = r.get("row_id")
#         page = r.get("page")
#         series = r.get("series") or {}
#         for y, v in series.items():
#             if v is None:
#                 continue
#             # record provenance regardless
#             prov.setdefault(y, []).append({
#                 "doc": doc, "table_id": tid, "row_id": rid, "page": page
#             })
#             # keep the first seen value for this year
#             if y not in values:
#                 values[y] = v
#     return values, prov

# def _last_n_years_map(series_map, n: int = 3):
#     ys = sorted(series_map.keys())
#     sel = ys[-n:]
#     return [(y, series_map[y]) for y in sel]

# # Helper to pick a representative source for a year
# def _pick_source_for_year(prov_map, y):
#     """
#     Choose one representative source dict for a given year
#     from the provenance map, preferring entries with a page number.
#     """
#     items = prov_map.get(y) or []
#     if not items:
#         return None
#     with_page = [s for s in items if s.get("page") is not None]
#     return (with_page[0] if with_page else items[0])

# # ---------- Q1: NIM last 5 quarters ----------

# def run_q1_nim_last5q(agent: Agent, kb: KBEnv):
#     q = "Net Interest Margin over the last 5 quarters"
#     res = agent.run(q, k_ctx=6)
#     print("\n=== Q1) Net Interest Margin ‚Äî last 5 quarters ===")
#     # Try table rows with quarters
#     rows = res.final.get("table_rows") or []
#     picked = None
#     for r in rows:
#         if r.get("series_q"):
#             picked = r
#             break
#     if not picked:
#         print("‚ö†Ô∏è No quarterly NIM found in indexed tables.")
#         # fall back to annual if available
#         for r in rows:
#             if r.get("series"):
#                 years = _last_n_years(r["series"], n=3)
#                 print("Fallback (years):", ", ".join(f"{y}: {v}" for y, v in years))
#                 break
#         # LLM summary even if not found
#         if USE_LLM_SUMMARY:
#             print("\nLLM Summary (baseline, single call):")
#             print(_llm_summary(q, agent, kb, res=res, k_ctx=8, rows_override=([picked] if picked else rows)))
#         if ONLINE:
#             print("\nLLM Answer (online, single call):")
#             print(f"[LLM] baseline using {_llm_provider_info()}")
#             tr = ([picked] if picked else rows)
#             baseline_answer_one_call(kb, q, k_ctx=8, table_rows=tr)
#         return res
#     last5 = _last_n_quarters(picked["series_q"], n=5)
#     if not last5:
#         print("‚ö†Ô∏è No quarterly NIM found in indexed tables.")
#         if USE_LLM_SUMMARY:
#             print("\nLLM Summary (baseline, single call):")
#             print(_llm_summary(q, agent, kb, res=res, k_ctx=8))
#         if ONLINE:
#             print("\nLLM Answer (online, single call):")
#             print(f"[LLM] baseline using {_llm_provider_info()}")
#             tr = ([picked] if picked else rows)
#             baseline_answer_one_call(kb, q, k_ctx=8, table_rows=tr)
#         return res
#     print(f"Source: {picked['doc']} | label: {picked['label']}")
#     print("Values (last 5): " + ", ".join(f"{k}: {v}" for k, v in last5))
#     if USE_LLM_SUMMARY:
#         print("\nLLM Summary (baseline, single call):")
#         print(_llm_summary(q, agent, kb, res=res, k_ctx=8, rows_override=([picked] if picked else rows)))
#     if ONLINE:
#         print("\nLLM Answer (online, single call):")
#         print(f"[LLM] baseline using {_llm_provider_info()}")
#         tr = ([picked] if picked else rows)
#         baseline_answer_one_call(kb, q, k_ctx=8, table_rows=tr)
#     return res

# # ---------- Q2: Opex last 3 fiscal years with YoY ----------

# def run_q2_opex_yoy(agent: Agent, kb: KBEnv):
#     q = "Operating Expenses last 3 fiscal years YoY"
#     res = agent.run(q, k_ctx=6)
#     print("\n=== Q2) Operating Expenses ‚Äî last 3 fiscal years (YoY) ===")

#     # Pull MANY rows then union across docs/tables to recover a continuous series
#     rows = agent.table.get_metric_rows("operating expenses", limit=50)
#     if not rows:
#         rows = agent.table.get_metric_rows("total expenses", limit=50)

#     combo, prov = _union_series(rows)
#     # Build per-year rows with real provenance so citations show actual docs/pages
#     years_for_report = sorted(combo.keys())[-3:] if combo else []
#     rows_yearwise = []
#     for y in years_for_report:
#         src = _pick_source_for_year(prov, y)
#         rows_yearwise.append({
#             "doc": (src.get("doc") if src else "(unknown)"),
#             "table_id": (src.get("table_id") if src else -1),
#             "row_id": (src.get("row_id") if src else -1),
#             "label": "Operating expenses",
#             "series": {y: combo.get(y)},
#             "series_q": {},
#             "page": (src.get("page") if src and src.get("page") is not None else None),
#         })
#     # Fallback: if something went wrong, still provide a single combined row
#     if not rows_yearwise:
#         rows_yearwise = [{
#             "doc": "(union)",
#             "table_id": -1,
#             "row_id": -1,
#             "label": "Operating expenses",
#             "series": combo,
#             "series_q": {},
#             "page": None
#         }]
#     if not combo:
#         print("‚ö†Ô∏è No expenses series found across docs.")
#         if USE_LLM_SUMMARY:
#             print("\nLLM Summary (baseline, single call):")
#             print(_llm_summary(q, agent, kb, res=res, k_ctx=8, rows_override=rows_yearwise))
#         if ONLINE:
#             print("\nLLM Answer (online, single call):")
#             print(f"[LLM] baseline using {_llm_provider_info()}")
#             baseline_answer_one_call(kb, q, k_ctx=8, table_rows=rows_yearwise)
#         return res

#     last3 = [(y, combo[y]) for y in years_for_report]
#     if len(last3) < 2:
#         print("‚ö†Ô∏è Not enough annual values to compute YoY.")
#         if USE_LLM_SUMMARY:
#             print("\nLLM Summary (baseline, single call):")
#             print(_llm_summary(q, agent, kb, res=res, k_ctx=8, rows_override=rows_yearwise))
#         if ONLINE:
#             print("\nLLM Answer (online, single call):")
#             print(f"[LLM] baseline using {_llm_provider_info()}")
#             baseline_answer_one_call(kb, q, k_ctx=8, table_rows=rows_yearwise)
#         return res

#     print("Year | Opex | YoY %")
#     print("-----|------|------")
#     prev_val = None
#     for y, v in last3:
#         yoy = ((v - prev_val) / prev_val * 100.0) if prev_val not in (None, 0) else None
#         yoy_s = f"{yoy:.2f}%" if yoy is not None else "‚Äî"
#         print(f"{y} | {v} | {yoy_s}")
#         prev_val = v

#     # Show sources (doc & page) used for each year printed
#     print("\nSources:")
#     for y, _ in last3:
#         src = _pick_source_for_year(prov, y)
#         if src:
#             p = src.get("page")
#             ptxt = f"page {int(p)}" if p is not None else "no page"
#             print(f"  {y}: {src.get('doc')} ({ptxt})")

#     if USE_LLM_SUMMARY:
#         print("\nLLM Summary (baseline, single call):")
#         print(_llm_summary(q, agent, kb, res=res, k_ctx=8, rows_override=rows_yearwise))
#     if ONLINE:
#         print("\nLLM Answer (online, single call):")
#         print(f"[LLM] baseline using {_llm_provider_info()}")
#         baseline_answer_one_call(kb, q, k_ctx=8, table_rows=rows_yearwise)

#     return res

# # ---------- Q3: Operating Efficiency Ratio (Opex √∑ Operating Income) ----------

# def run_q3_efficiency_ratio(agent: Agent, kb: KBEnv):
#     print("\n=== Q3) Operating Efficiency Ratio ‚Äî last 3 fiscal years ===")

#     # Union Opex across docs/tables
#     opex_rows = agent.table.get_metric_rows("operating expenses", limit=50) \
#         or agent.table.get_metric_rows("total expenses", limit=50)
#     opex, opex_prov = _union_series(opex_rows)

#     # Union Income across docs/tables (prefer 'total income', else 'operating income')
#     income_rows = agent.table.get_metric_rows("total income", limit=50) \
#         or agent.table.get_metric_rows("operating income", limit=50)
#     income, income_prov = _union_series(income_rows)

#     # Build per-year rows for both Opex and Income so citations show real docs/pages
#     rows_for_llm = []
#     years_overlap = sorted(set(opex.keys()).intersection(income.keys()))[-3:]
#     for y in years_overlap:
#         s_ox = _pick_source_for_year(opex_prov, y)
#         s_in = _pick_source_for_year(income_prov, y)
#         rows_for_llm.append({
#             "doc": (s_ox.get("doc") if s_ox else "(unknown)"),
#             "table_id": (s_ox.get("table_id") if s_ox else -1),
#             "row_id": (s_ox.get("row_id") if s_ox else -1),
#             "label": "Operating expenses",
#             "series": {y: opex.get(y)},
#             "series_q": {},
#             "page": (s_ox.get("page") if s_ox and s_ox.get("page") is not None else None)
#         })
#         rows_for_llm.append({
#             "doc": (s_in.get("doc") if s_in else "(unknown)"),
#             "table_id": (s_in.get("table_id") if s_in else -1),
#             "row_id": (s_in.get("row_id") if s_in else -1),
#             "label": "Total income",
#             "series": {y: income.get(y)},
#             "series_q": {},
#             "page": (s_in.get("page") if s_in and s_in.get("page") is not None else None)
#         })
#     # Fallback to union-style rows if needed
#     if not rows_for_llm:
#         rep_year = max(opex.keys() & income.keys()) if (opex and income) else None
#         rep_opex = _pick_source_for_year(opex_prov, rep_year) if rep_year else None
#         rep_income = _pick_source_for_year(income_prov, rep_year) if rep_year else None
#         rows_for_llm = [
#             {
#                 "doc": (rep_opex.get("doc") if rep_opex else "(union)"),
#                 "table_id": (rep_opex.get("table_id") if rep_opex else -1),
#                 "row_id": (rep_opex.get("row_id") if rep_opex else -1),
#                 "label": "Operating expenses",
#                 "series": opex or {},
#                 "series_q": {},
#                 "page": (rep_opex.get("page") if rep_opex else None)
#             },
#             {
#                 "doc": (rep_income.get("doc") if rep_income else "(union)"),
#                 "table_id": (rep_income.get("table_id") if rep_income else -1),
#                 "row_id": (rep_income.get("row_id") if rep_income else -1),
#                 "label": "Total income",
#                 "series": income or {},
#                 "series_q": {},
#                 "page": (rep_income.get("page") if rep_income else None)
#             },
#         ]

#     if not opex or not income:
#         print("‚ö†Ô∏è Missing Opex or Income series across docs.")
#         if USE_LLM_SUMMARY:
#             print("\nLLM Summary (baseline, single call):")
#             q = "Operating Efficiency Ratio (Opex / Operating Income) for the last 3 fiscal years"
#             print(_llm_summary(q, agent, kb, res=None, k_ctx=8, rows_override=rows_for_llm))
#         if ONLINE:
#             print("\nLLM Answer (online, single call):")
#             print(f"[LLM] baseline using {_llm_provider_info()}")
#             q_llm = "Operating Efficiency Ratio (Opex / Operating Income) for the last 3 fiscal years"
#             baseline_answer_one_call(kb, q_llm, k_ctx=8, table_rows=rows_for_llm)
#         return None

#     years = years_overlap
#     if not years:
#         print("‚ö†Ô∏è No overlapping years between Opex and Income.")
#         if USE_LLM_SUMMARY:
#             print("\nLLM Summary (baseline, single call):")
#             q = "Operating Efficiency Ratio (Opex / Operating Income) for the last 3 fiscal years"
#             print(_llm_summary(q, agent, kb, res=None, k_ctx=8, rows_override=rows_for_llm))
#         if ONLINE:
#             print("\nLLM Answer (online, single call):")
#             print(f"[LLM] baseline using {_llm_provider_info()}")
#             q_llm = "Operating Efficiency Ratio (Opex / Operating Income) for the last 3 fiscal years"
#             baseline_answer_one_call(kb, q_llm, k_ctx=8, table_rows=rows_for_llm)
#         return None

#     print("Year | Opex | Income | Opex/Income %")
#     print("-----|------|--------|---------------")
#     for y in years:
#         ov = opex.get(y)
#         iv = income.get(y)
#         ratio = (ov / iv * 100.0) if (iv not in (None, 0)) else None
#         ratio_s = f"{ratio:.2f}%" if ratio is not None else "‚Äî"
#         print(f"{y} | {ov} | {iv} | {ratio_s}")

#     print("\nSources:")
#     for y in years:
#         s1 = _pick_source_for_year(opex_prov, y)
#         s2 = _pick_source_for_year(income_prov, y)
#         if s1:
#             p1 = s1.get("page"); p1t = f"page {int(p1)}" if p1 is not None else "no page"
#             print(f"  Opex {y}: {s1.get('doc')} ({p1t})")
#         if s2:
#             p2 = s2.get("page"); p2t = f"page {int(p2)}" if p2 is not None else "no page"
#             print(f"  Income {y}: {s2.get('doc')} ({p2t})")

#     if USE_LLM_SUMMARY:
#         print("\nLLM Summary (baseline, single call):")
#         q = "Operating Efficiency Ratio (Opex / Operating Income) for the last 3 fiscal years"
#         print(_llm_summary(q, agent, kb, res=None, k_ctx=8, rows_override=rows_for_llm))
#     if ONLINE:
#         print("\nLLM Answer (online, single call):")
#         q_llm = "Operating Efficiency Ratio (Opex / Operating Income) for the last 3 fiscal years"
#         baseline_answer_one_call(kb, q_llm, k_ctx=8, table_rows=rows_for_llm)

#     return {"years": years, "opex": opex, "income": income}

# # ---------- Runner ----------

# def run_all(base: str = "./data_marker"):
#     kb = KBEnv(base=base)
#     agent = Agent(kb)

#     # Q1
#     # res1 = run_q1_nim_last5q(agent, kb)

#     # Q2
#     res2 = run_q2_opex_yoy(agent, kb)

#     # Q3
#     _ = run_q3_efficiency_ratio(agent, kb)

# # # Auto-run when executed directly (safe in notebooks too)
# # if __name__ == "__main__" or "__file__" not in globals():
# #     run_all(base="./data_marker")

# if __name__ == "__main__":
#     # Ensure Stage 2 is initialized, then run baseline with prose printing
#     try:
#         init_stage2(out_dir=OUT_DIR)
#         print("[Stage3] init_stage2() called successfully.")
#     except Exception as e:
#         print(f"[Stage3] init_stage2() failed: {e}")
    
#     bench = run_benchmark(print_prose=True, use_agent=False, out_dir=OUT_DIR, dry_run=False)
#     # Also echo the summary table at the end
#     if isinstance(bench.get("summary"), pd.DataFrame) and not bench["summary"].empty:
#         df = bench["summary"]
#         p50 = float(df['Latency_ms'].quantile(0.5))
#         p95 = float(df['Latency_ms'].quantile(0.95))
#         print(f"\n=== BASELINE Benchmark Summary ===")
#         print(f"Latency p50: {p50:.1f} ms, p95: {p95:.1f} ms")

#     run_all(base="./data_marker")

Two-Mode RAG System with Marker + PDFPlumber Fallback

In [1]:
"""
Two-Mode RAG System with Parallel Sub-Query Support
"""

from __future__ import annotations
import os, json, time
from typing import List, Dict, Any, Optional
from pathlib import Path

import numpy as np
import pandas as pd

# Import g2x components
from g2x import KBEnv, Agent, baseline_answer_one_call


# ============================================================================
# CONFIGURATION
# ============================================================================

class Config:
    """Centralized configuration"""
    VERBOSE = bool(int(os.environ.get("AGENT_CFO_VERBOSE", "1")))
    
    # Paths
    MARKER_INDEX = "./data_marker"
    PDFPLUMBER_INDEX = "./data"
    
    # Search params
    TOP_K = 12
    HYBRID_ALPHA = 0.6
    RERANK_TOP_K = 24
    
    # Agentic mode
    USE_PARALLEL_SUBQUERIES = True  # Enable parallel sub-query decomposition


# ============================================================================
# UTILITIES
# ============================================================================

def page_or_none(x) -> Optional[int]:
    """Safely convert page numbers"""
    try:
        if x is None or pd.isna(x):
            return None
        return int(x)
    except:
        return None


# ============================================================================
# MULTI-INDEX SEARCH (Marker + PDFPlumber Fallback)
# ============================================================================

class MultiIndexSearch:
    """
    Dual-index search with automatic fallback
    """
    
    def __init__(self, marker_path: str, pdfplumber_path: str):
        self.marker_kb = self._load_kb(marker_path, "Marker")
        self.pdfplumber_kb = self._load_kb(pdfplumber_path, "PDFPlumber")
        
        if not self.marker_kb and not self.pdfplumber_kb:
            raise RuntimeError("No valid indexes loaded")
    
    def _load_kb(self, path: str, name: str) -> Optional[KBEnv]:
        """Load KB with BM25 + Reranker enabled"""
        if not Path(path).exists():
            return None
        
        try:
            kb = KBEnv(base=path, enable_bm25=True, enable_reranker=True)
            if Config.VERBOSE:
                print(f"[{name}] ‚úì Loaded {len(kb.texts)} chunks")
            return kb
        except Exception as e:
            if Config.VERBOSE:
                print(f"[{name}] ‚úó Failed: {e}")
            return None
    
    def search(self, query: str, top_k: int = None) -> List[Dict[str, Any]]:
        """Hybrid search with fallback"""
        top_k = top_k or Config.TOP_K
        
        # Primary: Marker
        results = []
        if self.marker_kb:
            df = self.marker_kb.search(query, k=top_k, alpha=Config.HYBRID_ALPHA, rerank_top_k=Config.RERANK_TOP_K)
            results = self._df_to_dict(df, "marker")
        
        # Fallback: PDFPlumber
        if len(results) < top_k // 2 and self.pdfplumber_kb:
            df = self.pdfplumber_kb.search(query, k=top_k - len(results))
            results.extend(self._df_to_dict(df, "pdfplumber"))
        
        results.sort(key=lambda x: x.get("score", 0), reverse=True)
        return results[:top_k]
    
    def _df_to_dict(self, df: pd.DataFrame, source: str) -> List[Dict]:
        """Convert DataFrame to dict list"""
        if df is None or df.empty:
            return []
        
        return [
            {
                "file": str(row.get("doc")),
                "page": page_or_none(row.get("page")),
                "text": str(row.get("text")),
                "score": float(row.get("score", 0)),
                "year": int(row["year"]) if pd.notna(row.get("year")) else None,
                "quarter": str(row["quarter"]) if pd.notna(row.get("quarter")) else None,  # FIX: Quarter is a string, not int
                "section_hint": row.get("section_hint"),
                "index_source": source
            }
            for _, row in df.iterrows()
        ]


# ============================================================================
# ANSWERING ENGINE
# ============================================================================

class AnsweringEngine:
    """Unified baseline + agentic answering with parallel sub-queries"""
    
    def __init__(self):
        print("AnsweringEngine initialized")
        self.search = MultiIndexSearch(Config.MARKER_INDEX, Config.PDFPLUMBER_INDEX)
        
        # Agent with parallel sub-queries
        primary_kb = self.search.marker_kb or self.search.pdfplumber_kb
        self.agent = Agent(
            kb=primary_kb, 
            use_parallel_subqueries=True,
            verbose=Config.VERBOSE
        )
        print(f"[AnsweringEngine] Parallel sub-queries enabled: {self.agent.use_parallel_subqueries}")
    
    def answer(self, query: str, mode: str = "baseline") -> Dict[str, Any]:
        """Execute query in baseline or agentic mode"""
        
        if mode == "agentic":
            # Agentic mode with parallel sub-queries
            agent_result = self.agent.run(query, k_ctx=Config.TOP_K)
            
            return {
                "answer": self._format_agent_answer(agent_result),
                "hits": self._extract_agent_hits(agent_result),
                "execution_log": {
                    "plan": agent_result.plan,
                    "actions": agent_result.actions,
                    "observations": agent_result.observations
                }
            }
        
        else:  # baseline
            # Standard RAG: Retrieve ‚Üí Single LLM call
            results = self.search.search(query, top_k=Config.TOP_K)
            
            answer_result = baseline_answer_one_call(
                self.search.marker_kb or self.search.pdfplumber_kb,
                query,
                k_ctx=Config.TOP_K
            )
            
            return {
                "answer": answer_result.get("answer", ""),
                "hits": results[:5],  # Top-5 citations
                "execution_log": None
            }
    
    def _format_agent_answer(self, agent_result) -> str:
        """
        Format agentic answer with proper fallback
        """
        lines = []
        
        # Add execution summary (optional)
        if agent_result.observations and Config.VERBOSE:
            lines.append("**Execution Summary**")
            for obs in agent_result.observations:
                lines.append(f"- {obs}")
            lines.append("")
        
        fin = agent_result.final
        
        # Priority 1: Return LLM-synthesized answer if available
        if "answer" in fin and fin["answer"]:
            return fin["answer"]
        
        # Priority 2: Format tool outputs (FALLBACK)
        if fin.get("comparison_results"):
            lines.append("**Multi-Document Comparison**")
            for comp in fin["comparison_results"][:5]:
                doc = comp.get("doc", "Unknown")
                years = comp.get("years", [])
                values = comp.get("values", [])
                if years and values:
                    year_val = ", ".join(f"{y}: {v}" for y, v in zip(years, values))
                    lines.append(f"- {doc}: {year_val}")
        
        elif fin.get("table_rows"):
            lines.append("**Extracted Data**")
            for r in fin["table_rows"][:5]:
                doc = r.get("doc", "Unknown")
                label = r.get("label", "")
                
                if r.get("series_q"):
                    qkeys = sorted(r["series_q"].keys())[-5:]
                    ser = ", ".join(f"{k}: {r['series_q'][k]}" for k in qkeys)
                    lines.append(f"- {doc} | {label}: {ser}")
                elif r.get("series"):
                    ys = sorted(r["series"].keys())[-3:]
                    ser = ", ".join(f"{y}: {r['series'][y]}" for y in ys)
                    lines.append(f"- {doc} | {label}: {ser}")
                else:
                    lines.append(f"- {doc} | {label}: (no data)")
        
        # Priority 3: Final fallback message
        if not lines:
            lines.append("Analysis complete. No structured data extracted.")
        
        return "\n".join(lines)
    
    def _extract_agent_hits(self, agent_result) -> List[Dict]:
        """Extract citations from agent result"""
        contexts = agent_result.final.get("contexts")
        if contexts is None or contexts.empty:
            return []
        
        return [
            {
                "file": row.get("doc"),
                "page": row.get("page"),
                "section_hint": row.get("section_hint"),
                "index_source": "marker",
                "score": row.get("score")
            }
            for _, row in contexts.head(5).iterrows()
        ]


# ============================================================================
# BENCHMARK RUNNER
# ============================================================================

class BenchmarkRunner:
    """Standardized benchmark execution"""
    
    QUERIES = [
        "Report the Gross Margin (or Net Interest Margin, if a bank) over the last 5 quarters, with values.",
        "Show Operating Expenses for the last 3 fiscal years, year-on-year comparison.",
        "Calculate the Operating Efficiency Ratio (Opex √∑ Operating Income) for the last 3 fiscal years, showing the working."
    ]
    
    def __init__(self, engine: AnsweringEngine):
        self.engine = engine
    
    def run(self, mode: str = "baseline") -> Dict[str, Any]:
        """Run benchmark and save results"""
        out_dir = "data_marker" if mode == "agentic" else "data"
        os.makedirs(out_dir, exist_ok=True)
        
        print(f"\n{'='*60}")
        print(f"  {mode.upper()} BENCHMARK")
        print(f"{'='*60}\n")
        
        results = []
        for i, query in enumerate(self.QUERIES, 1):
            print(f"\nQ{i}. {query}\n")
            
            t0 = time.perf_counter()
            result = self.engine.answer(query, mode=mode)
            latency_ms = round((time.perf_counter() - t0) * 1000, 2)
            
            print(result["answer"])
            if result.get("hits"):
                print("\n--- Citations ---")
                for hit in result["hits"][:5]:
                    pg = f"p.{hit.get('page')}" if hit.get('page') else ""
                    print(f"- {hit['file']} {pg}")
            
            print(f"\n(Latency: {latency_ms} ms)")
            
            results.append({
                "query": query,
                "answer": result["answer"],
                "citations": result.get("hits", []),
                "execution_log": result.get("execution_log"),
                "latency_ms": latency_ms
            })
        
        # Save JSON with UTF-8 encoding
        json_path = f"{out_dir}/bench_{mode}.json"
        with open(json_path, "w", encoding="utf-8") as f:  # FIX: Add encoding
            json.dump({"results": results}, f, indent=2, ensure_ascii=False)  # FIX: Add ensure_ascii=False
        
        # Save Markdown
        md_path = f"{out_dir}/bench_{mode}.md"
        self._write_markdown(md_path, results, mode)
        
        # Summary
        latencies = [r["latency_ms"] for r in results]
        print(f"\n{'='*60}")
        print(f"  SUMMARY")
        print(f"{'='*60}")
        print(f"P50: {np.percentile(latencies, 50):.1f} ms")
        print(f"P95: {np.percentile(latencies, 95):.1f} ms\n")
        
        return {"json_path": json_path, "md_path": md_path, "results": results}
    
    def _write_markdown(self, path: str, results: List[Dict], mode: str):
        """Generate markdown report"""
        lines = [f"# {mode.title()} Benchmark Report\n"]
        
        if mode == "baseline":
            lines.append("**Pipeline**: Hybrid Search (BM25 + Vector + RRF + Rerank) -> Single LLM\n")  # Changed ‚Üí to ->
        else:
            lines.append("**Pipeline**: Parallel Sub-Queries -> Tool Execution -> Multi-step Reasoning\n")  # Changed ‚Üí to ->
        
        for i, r in enumerate(results, 1):
            lines.append(f"\n---\n\n## Q{i}. {r['query']}\n")
            lines.append(f"**Answer**\n\n{r['answer']}\n")
            
            if r.get("citations"):
                lines.append("\n**Citations**\n")
                for hit in r["citations"]:
                    pg = f"p.{hit.get('page')}" if hit.get('page') else ""
                    lines.append(f"- {hit['file']} {pg}")
            
            if r.get("execution_log"):
                lines.append("\n**Execution Log**\n```")
                lines.append(json.dumps(r["execution_log"], indent=2))
                lines.append("```")
            
            lines.append(f"\n**Latency**: {r['latency_ms']} ms")
        
        # Summary
        latencies = [r["latency_ms"] for r in results]
        lines.append("\n---\n\n## Summary\n")
        lines.append(f"- P50: {np.percentile(latencies, 50):.1f} ms")
        lines.append(f"- P95: {np.percentile(latencies, 95):.1f} ms")
        
        # FIX: Add encoding="utf-8"
        with open(path, "w", encoding="utf-8") as f:
            f.write("\n".join(lines))


# ============================================================================
# MAIN
# ============================================================================

def main():
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    
    engine = AnsweringEngine()
    benchmark = BenchmarkRunner(engine)
    
    print("\n" + "="*60)
    print("  TWO-MODE RAG SYSTEM")
    print("="*60)
    
    # Baseline
    baseline_results = benchmark.run(mode="baseline")
    
    # Agentic with parallel sub-queries
    agentic_results = benchmark.run(mode="agentic")
    
    print("\n" + "="*60)
    print("  COMPLETE")
    print("="*60)
    print(f"Baseline: {baseline_results['json_path']}")
    print(f"Agentic:  {agentic_results['json_path']}")


if __name__ == "__main__":
    main()

  from .autonotebook import tqdm as notebook_tqdm


AnsweringEngine initialized
[BM25] ‚úì Indexed 13548 documents
[Reranker] ‚úì Loaded cross-encoder/ms-marco-MiniLM-L-6-v2
[Marker] ‚úì Loaded 13548 chunks
[BM25] ‚úì Indexed 1623 documents
[Reranker] ‚úì Loaded cross-encoder/ms-marco-MiniLM-L-6-v2
[PDFPlumber] ‚úì Loaded 1623 chunks
[Agent] Tools: Calculator: ‚úì | Table: ‚úì | Text: ‚úì | MultiDoc: ‚úì
[AnsweringEngine] Parallel sub-queries enabled: True

  TWO-MODE RAG SYSTEM

  BASELINE BENCHMARK


Q1. Report the Gross Margin (or Net Interest Margin, if a bank) over the last 5 quarters, with values.

[Search] RRF fusion: 96 candidates
[Rerank] Reranking top-24 candidates...
[Rerank] ‚úì Reranked to top-12
[Search] RRF fusion: 96 candidates
[Rerank] Reranking top-24 candidates...
[Rerank] ‚úì Reranked to top-12
[LLM] single-call baseline using groq:openai/gpt-oss-20b
[LLM] provider=groq model=openai/gpt-oss-20b
**Answer**  
The context supplies explicit net‚Äëinterest‚Äëmargin (NIM) figures only for two quarters: Q4‚ÄØ2024 (2.05‚ÄØ%)

## 6. Instrumentation

Log timings: T_ingest, T_retrieve, T_rerank, T_reason, T_generate, T_total. Log tokens, cache hits, tools.

In [23]:
# Example instrumentation schema
import pandas as pd
logs = pd.DataFrame(columns=['Query','T_ingest','T_retrieve','T_rerank','T_reason','T_generate','T_total','Tokens','CacheHits','Tools'])
logs

Unnamed: 0,Query,T_ingest,T_retrieve,T_rerank,T_reason,T_generate,T_total,Tokens,CacheHits,Tools


## 7. Optimizations

**Required Optimizations**

Each team must implement at least:
*   2 retrieval optimizations (e.g., hybrid BM25+vector, smaller embeddings, dynamic k).
*   1 caching optimization (query cache or ratio cache).
*   1 agentic optimization (plan pruning, parallel sub-queries).
*   1 system optimization (async I/O, batch embedding, memory-mapped vectors).

## 8. Results & Plots

Show baseline vs optimized. Include latency plots (p50/p95) and accuracy tables.

In [None]:
# TODO: Generate plots with matplotlib
