# Agent CFO: Performance Optimization & Design

---
This is the starter notebook for your project. Follow the required structure below.


You will design and optimize an Agent CFO assistant for a listed company. The assistant should answer finance/operations questions using RAG (Retrieval-Augmented Generation) + agentic reasoning, with response time (latency) as the primary metric.

Your system must:
*   Ingest the company's public filings.
*   Retrieve relevant passages efficiently.
*   Compute ratios/trends via tool calls (calculator, table parsing).
*   Produce answers with valid citations to the correct page/table.


In [1]:
import os

# Best practice: do NOT hardcode API keys in notebook cells.
# If GEMINI_API_KEY is already set in the environment (e.g., via secrets), keep it.
# Otherwise, prompt the user to enter it securely (won't be echoed).
if os.environ.get("GEMINI_API_KEY"):
	print("GEMINI_API_KEY found in environment.")
else:
	try:
		from getpass import getpass
		key = getpass("Enter GEMINI_API_KEY (input hidden): ")
	except Exception:
		# Fallback to input() if getpass is unavailable in this environment
		key = input("Enter GEMINI_API_KEY: ")
	if key:
		os.environ["GEMINI_API_KEY"] = key
		print("GEMINI_API_KEY set for this session (not saved).")
	else:
		raise RuntimeError("GEMINI_API_KEY not provided. Set it via environment variables or re-run this cell.")

GEMINI_API_KEY found in environment.


## 1. Config & Secrets

Fill in your API keys in secrets. **Do not hardcode keys** in cells.

In [2]:
import os

# Example:
# os.environ['GEMINI_API_KEY'] = 'your-key-here'
# os.environ['OPENAI_API_KEY'] = 'your-key-here'

COMPANY_NAME = "DBS Bank"


## 2. Data Download (Dropbox)

*   Annual Reports: last 3–5 years.
*   Quarterly Results Packs & MD&A (Management Discussion & Analysis).
*   Investor Presentations and Press Releases.
*   These files must be submitted later as a deliverable in the Dropbox data pack.
*   Upload them under `/content/data/`.

Scope limit: each team must ingest at least 15 PDF files in total.
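
A quick sanity check for the data pack can be sketched as follows. This is a minimal sketch, assuming the PDFs sit directly or in subfolders under `/content/data/`; the helper name `count_pdfs` and the constant `MIN_PDFS` are illustrative and not part of the starter notebook.

```python
from pathlib import Path

MIN_PDFS = 15  # scope requirement stated above


def count_pdfs(data_dir: str = "/content/data") -> int:
    """Count PDF files under data_dir (recursively); return 0 if the folder is missing."""
    root = Path(data_dir)
    if not root.is_dir():
        return 0
    # Match the extension case-insensitively so "report.PDF" is counted too
    return sum(1 for p in root.rglob("*") if p.is_file() and p.suffix.lower() == ".pdf")


n = count_pdfs()
status = "OK" if n >= MIN_PDFS else "below the minimum"
print(f"Found {n} PDF file(s) under /content/data: {status} (minimum {MIN_PDFS}).")
```

Running this before ingestion catches an incomplete upload early, before any time is spent on parsing.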


## 3. System Requirements

**Retrieval & RAG**
*   Use a vector index (e.g., FAISS, LlamaIndex) + a keyword filter (BM25/ElasticSearch).
*   Citations must include: report name, year, page number, section/table.
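
One way to combine the vector index with the keyword filter is to fuse their two rankings. The sketch below uses reciprocal rank fusion (RRF), which is a choice made here for illustration and is not prescribed by the requirements. The `Citation` fields mirror the four required citation elements. All names (`Citation`, `rrf_fuse`) and all sample values (chunk IDs, report name, page) are hypothetical placeholders.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Citation:
    report: str   # report name
    year: int
    page: int
    section: str  # section heading or table label

    def render(self) -> str:
        return f"{self.report} ({self.year}), p. {self.page}, {self.section}"


def rrf_fuse(rankings: list[list[str]], k: int = 60) -> list[str]:
    """Fuse ranked lists of chunk IDs: score(id) = sum over lists of 1 / (k + rank)."""
    scores: dict[str, float] = {}
    for ranking in rankings:
        for rank, chunk_id in enumerate(ranking, start=1):
            scores[chunk_id] = scores.get(chunk_id, 0.0) + 1.0 / (k + rank)
    # Highest fused score first; ties broken by chunk ID for determinism
    return sorted(scores, key=lambda cid: (-scores[cid], cid))


vector_hits = ["c3", "c1", "c7"]   # placeholder: IDs returned by the vector index
keyword_hits = ["c1", "c9", "c3"]  # placeholder: IDs returned by BM25
fused = rrf_fuse([vector_hits, keyword_hits])
print(fused)

# Placeholder citation values, for illustration only
print(Citation("Annual Report", 2023, 42, "Table 5").render())
```

Chunks that appear in both rankings rise to the top, so a passage that is both semantically close and keyword-matched is preferred over one found by a single retriever.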

**Agentic Reasoning**
*   Support at least 3 tool types: calculator, table extraction, multi-document compare.
*   Reasoning must follow a plan-then-act pattern (not a single unstructured call).
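
The plan-then-act pattern can be sketched as two separate phases: first produce an explicit list of tool steps, then execute them in order. This is a minimal sketch under stated assumptions: the planner is a hard-coded placeholder (a real system would ask the LLM for the plan), only the calculator tool is implemented, and the names `TOOLS`, `make_plan`, `execute` and the figures 110.0 and 100.0 are made up for illustration.

```python
import operator
from typing import Any, Callable


def calculator(op: str, a: float, b: float) -> float:
    """Minimal calculator tool supporting four binary operations."""
    ops = {"add": operator.add, "sub": operator.sub,
           "mul": operator.mul, "div": operator.truediv}
    return ops[op](a, b)


# Tool registry; table extraction and multi-document compare would be added here.
TOOLS: dict[str, Callable[..., Any]] = {"calculator": calculator}


def make_plan(question: str) -> list[dict]:
    """Plan phase (placeholder): a fixed two-step plan for a relative change."""
    return [
        {"id": "delta", "tool": "calculator",
         "args": {"op": "sub", "a": 110.0, "b": 100.0}},
        {"id": "pct", "tool": "calculator",
         "args": {"op": "div", "a": "$delta", "b": 100.0}},
    ]


def resolve(value: Any, results: dict[str, Any]) -> Any:
    """Replace a '$name' reference with the result of the earlier step 'name'."""
    if isinstance(value, str) and value.startswith("$"):
        return results[value[1:]]
    return value


def execute(plan: list[dict]) -> dict[str, Any]:
    """Act phase: run each step in order and keep results keyed by step id."""
    results: dict[str, Any] = {}
    for step in plan:
        args = {k: resolve(v, results) for k, v in step["args"].items()}
        results[step["id"]] = TOOLS[step["tool"]](**args)
    return results


results = execute(make_plan("How did the metric change year on year?"))
print(results)
```

Keeping the plan as plain data makes it easy to log which tools were invoked and to time the reasoning and tool phases separately.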

**Instrumentation**
*   Log timings for: T_ingest, T_retrieve, T_rerank, T_reason, T_generate, T_total.
*   Log: tokens used, cache hits, tools invoked.
*   Record p50/p95 latencies.
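
The timing requirements above can be sketched with a small context manager and a percentile helper. This is a minimal sketch, not a prescribed design: `timed`, `percentile`, and the `timings` store are hypothetical names, the percentile uses the nearest-rank method (other definitions interpolate and give slightly different values), and `time.sleep` stands in for the real retrieval call.

```python
import math
import time
from collections import defaultdict
from contextlib import contextmanager

# stage name (e.g. "T_retrieve") -> list of durations in seconds
timings: dict[str, list[float]] = defaultdict(list)


@contextmanager
def timed(stage: str):
    """Record the wall-clock duration of the enclosed block under `stage`."""
    start = time.perf_counter()
    try:
        yield
    finally:
        timings[stage].append(time.perf_counter() - start)


def percentile(samples: list[float], p: float) -> float:
    """Nearest-rank percentile: smallest sample with at least p% of samples <= it."""
    if not samples:
        raise ValueError("no samples recorded")
    ordered = sorted(samples)
    rank = max(1, math.ceil(p * len(ordered) / 100))
    return ordered[rank - 1]


for _ in range(20):
    with timed("T_retrieve"):
        time.sleep(0.001)  # placeholder for the real retrieval call

p50 = percentile(timings["T_retrieve"], 50)
p95 = percentile(timings["T_retrieve"], 95)
print(f"T_retrieve p50={p50 * 1000:.2f} ms, p95={p95 * 1000:.2f} ms")
```

The same `timed(...)` wrapper can be placed around the rerank, reason, and generate stages, with an outer one for `T_total`.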

### Gemini Version 1

In [None]:
!pip install rank_bm25

In [None]:
pip install openai


In [None]:
pip install opencv-python

In [3]:
# 1. Install the marker library
# This command should be run in your terminal or a Colab cell:
# !pip install marker-pdf -q

# 2. Import necessary components
import subprocess
import shutil
from pathlib import Path
import sys
import hashlib
import re
import cv2
import numpy as np
import pandas as pd


def md5sum(file_path: Path, chunk_size: int = 8192) -> str:
    """Return the hex md5 of a file."""
    h = hashlib.md5()
    with open(file_path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

# === OCR & extraction helpers ===
NUM_PAT = re.compile(r"^[+-]?\d{1,4}(?:[.,]\d+)?%?$")
NIM_KEYWORDS = ["net interest margin", "nim"]

QUARTER_PAT = re.compile(r"\b([1-4Iil|])\s*[QO0]\s*([0-9O]{2,4})\b", re.IGNORECASE)
# Simpler decade-only pattern for quarters, e.g., 2Q24, 1Q25
QUARTER_SIMPLE_PAT = re.compile(r"\b([1-4])Q(2\d)\b", re.IGNORECASE)  # e.g., 2Q24, 1Q25

# --- OCR character normalization for quarter tokens (common OCR mistakes) ---
_CHAR_FIX = str.maketrans({
    "O":"0","o":"0",
    "S":"5","s":"5",
    "I":"1","l":"1","|":"1","!":"1",
    "D":"0",
    "B":"3","8":"3",
    "Z":"2","z":"2"
})
def normalize_token(t: str) -> str:
    t = (t or "").strip()
    return t.translate(_CHAR_FIX).replace(" ", "")

# --- Helper: detect quarter tokens from nearby Markdown file ---
def detect_qlabels_from_md(dest_dir: Path, image_name: str) -> list[str]:
    """
    Scan the figure's markdown file for quarter tokens (e.g., 2Q24, 1Q2025).
    Returns tokens in document order (deduped).
    """
    try:
        md_file = dest_dir / f"{dest_dir.name}.md"
        if not md_file.exists():
            cand = list(dest_dir.glob("*.md"))
            if not cand:
                return []
            md_file = cand[0]
        text = md_file.read_text(encoding="utf-8", errors="ignore")
    except Exception:
        return []
    # Collect all quarter tokens across the document
    tokens = []
    for m in QUARTER_PAT.finditer(text):
        q = f"{m.group(1)}Q{m.group(2)[-2:]}"
        tokens.append(q)
    # Deduplicate preserving order
    seen = set()
    ordered = []
    for q in tokens:
        if q not in seen:
            seen.add(q)
            ordered.append(q)
    return ordered

def load_image(path):
    p = Path(path)
    im = cv2.imread(str(p))
    if im is None:
        raise RuntimeError(f"cv2.imread() failed: {p}")
    return im

def preprocess(img_bgr):
    scale = 2.0
    img = cv2.resize(img_bgr, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    gray = cv2.bilateralFilter(gray, d=7, sigmaColor=50, sigmaSpace=50)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
    gray = clahe.apply(gray)
    thr = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                cv2.THRESH_BINARY, 31, 8)
    return img, gray, thr, scale

def norm_num(s):
    s = s.replace(",", "").strip()
    pct = s.endswith("%")
    if pct:
        s = s[:-1]
    try:
        return float(s), pct
    except ValueError:
        return None, pct

def extract_numbers(ocr_results):
    rows = []
    for r in ocr_results or []:
        txt = str(r.get("text","")).strip()
        if NUM_PAT.match(txt):
            val, is_pct = norm_num(txt)
            if val is None:
                continue
            x1,y1,x2,y2 = r["bbox"]
            rows.append({
                "raw": txt, "value": val, "is_pct": is_pct, "conf": r.get("conf", None),
                "x1": int(x1), "y1": int(y1), "x2": int(x2), "y2": int(y2),
                "cx": int((x1+x2)/2), "cy": int((y1+y2)/2)
            })
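    cols = ["raw", "value", "is_pct", "conf", "x1", "y1", "x2", "y2", "cx", "cy"]
    if not rows:
        # No numeric tokens: return an empty frame with the expected columns
        # (sort_values on a column-less frame would raise KeyError).
        return pd.DataFrame(columns=cols)
    df = pd.DataFrame(rows, columns=cols).sort_values(["cy","cx"]).reset_index(drop=True)
    return df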

def kmeans_1d(values, k=2, iters=20):
    values = np.asarray(values, dtype=float).reshape(-1,1)
    # Spread initial centers evenly between min and max (same as [min, max] for k=2)
    centers = np.linspace(values.min(), values.max(), k).reshape(k,1)
    for _ in range(iters):
        d = ((values - centers.T)**2)
        labels = d.argmin(axis=1)
        new_centers = np.array([values[labels==i].mean() if np.any(labels==i) else centers[i] for i in range(k)]).reshape(k,1)
        if np.allclose(new_centers, centers, atol=1e-3):
            break
        centers = new_centers
    return labels, centers.flatten()

def run_easyocr(img_rgb):
    import easyocr
    global _EASY_OCR_READER
    try:
        _EASY_OCR_READER
    except NameError:
        _EASY_OCR_READER = None
    if _EASY_OCR_READER is None:
        _EASY_OCR_READER = easyocr.Reader(['en'], gpu=False, verbose=False)
    results = _EASY_OCR_READER.readtext(img_rgb, detail=1, paragraph=False)
    out = []
    for quad, text, conf in results:
        (x1,y1),(x2,y2),(x3,y3),(x4,y4) = quad
        out.append({"bbox": (int(x1),int(y1),int(x3),int(y3)), "text": str(text), "conf": float(conf)})
    return out

# --- Focused bottom-axis quarter detection using EasyOCR (robust to OCR confusions) ---
def detect_quarters_easyocr(img_bgr):
    """
    Use EasyOCR to read quarter labels along the bottom axis.
    Returns a list of (x_global, 'nQyy') sorted left→right, with half-year tokens removed.
    """
    H, W = img_bgr.shape[:2]
    y0 = int(H * 0.66)  # bottom ~34%
    crop = img_bgr[y0:H, 0:W]
    gray = cv2.cvtColor(crop, cv2.COLOR_BGR2GRAY)
    gray = cv2.bilateralFilter(gray, d=7, sigmaColor=50, sigmaSpace=50)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
    gray = clahe.apply(gray)
    thr = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                cv2.THRESH_BINARY, 31, 8)
    # kernel = np.ones((3,3), np.uint8)
    # thr = cv2.morphologyEx(thr, cv2.MORPH_CLOSE, kernel, iterations=1)
    up = cv2.resize(thr, None, fx=3.0, fy=3.0, interpolation=cv2.INTER_CUBIC)
    img_rgb = cv2.cvtColor(up, cv2.COLOR_GRAY2RGB)
    ocr = run_easyocr(img_rgb)
    # PASS 1: direct regex on normalized tokens
    tokens = []
    for r in ocr or []:
        raw = str(r.get("text","")).strip()
        x1,y1,x2,y2 = r["bbox"]
        cx_local = (x1 + x2) // 2
        cx_global = int(cx_local / 3.0)  # undo scaling
        tokens.append({"x": cx_global, "raw": raw, "norm": normalize_token(raw)})
    def _is_half_token(t: str) -> bool:
        t = (t or "").lower().replace(" ", "")
        return ("9m" in t) or ("1h" in t) or ("h1" in t) or ("h2" in t) or ("2h" in t)
    quarters = []
    for t in tokens:
        if _is_half_token(t["norm"]):
            continue
        m = QUARTER_PAT.search(t["norm"])
        if m:
            q = f"{m.group(1)}Q{m.group(2)[-2:]}"
            q = normalize_token(q)
            quarters.append((t["x"], q))
    # PASS 2: stitch split tokens if too few quarters were found
    if len(quarters) < 4 and tokens:
        pieces = sorted(tokens, key=lambda d: d["x"])
        digits_1to4 = [p for p in pieces if p["norm"] in ("1","2","3","4")]
        q_only      = [p for p in pieces if p["norm"].upper() == "Q"]
        q_with_year = [p for p in pieces if re.fullmatch(r"Q[0-9O]{2,4}", p["norm"], flags=re.I)]
        years_2d    = [p for p in pieces if re.fullmatch(r"[0-9O]{2,4}", p["norm"])]
        def near(a, b, tol=70):
            return abs(a["x"] - b["x"]) <= tol
        for d in digits_1to4:
            # digit + Qyy
            candidates = [q for q in q_with_year if near(d, q)]
            if candidates:
                qtok = min(candidates, key=lambda q: abs(q["x"]-d["x"]))
                qyy = normalize_token(qtok["norm"])[1:]
                quarters.append(((d["x"]+qtok["x"])//2, f"{d['norm']}Q{qyy[-2:]}"))
                continue
            # digit + Q + yy
            qs = [q for q in q_only if near(d, q)]
            ys = [y for y in years_2d if near(d, y, tol=120)]
            if qs and ys:
                qtok = min(qs, key=lambda q: abs(q["x"]-d["x"]))
                ytok = min(ys, key=lambda y: abs(y["x"]-qtok["x"]))
                yy = normalize_token(ytok["norm"])
                quarters.append(((d["x"]+ytok["x"])//2, f"{d['norm']}Q{yy[-2:]}"))
                continue
    if not quarters:
        return []
    quarters.sort(key=lambda t: t[0])
    deduped, last_x = [], -10**9
    for x,q in quarters:
        if abs(x - last_x) <= 22:
            continue
        deduped.append((x,q))
        last_x = x
    return deduped

# NIM value band (pct) and geometry heuristics for verification
NIM_MIN, NIM_MAX = 1.3, 3.2
TOP_FRACTION = 0.65     # widen band: NIM labels often sit higher than 45%
RIGHT_HALF_ONLY = True  # NIM values appear on the right panel in these decks

def is_strict_nim_image(img_path: Path) -> tuple[bool, str]:
    """
    Heuristic re-check:
      1) Title/text contains NIM keywords (coarse gate)
      2) Percent tokens mostly within NIM_MIN..NIM_MAX
      3) Tokens located in the top region (and right half, if enabled)
    Returns (ok, reason)
    """
    try:
        img_bgr = load_image(img_path)
        H, W = img_bgr.shape[:2]
        # 1) quick-text gate (soft): don't return yet; allow numeric signature to validate
        kw_ok = is_relevant_image(img_path, NIM_KEYWORDS)
        # 2) numeric gate on enhanced image
        img_up, gray, thr, scale = preprocess(img_bgr)
        img_rgb = cv2.cvtColor(thr, cv2.COLOR_GRAY2RGB)
        ocr = run_easyocr(img_rgb)
        # --- Semantic gate: accept classic NIM slides based on stable labels ---
        text_lower = " ".join(str(r.get("text", "")).lower() for r in ocr or [])
        has_nim = "net interest margin" in text_lower
        has_cb  = "commercial book" in text_lower
        has_grp = "group" in text_lower
        if has_nim and (has_cb or has_grp):
            which = [w for w, ok in (("nim", has_nim), ("cb", has_cb), ("grp", has_grp)) if ok]
            return (True, f"ok_semantic({'+'.join(which)})")
        df = extract_numbers(ocr)
        if df.empty:
            return (False, "no_numbers")
        # geometry filters (apply before value checks)
        top_cut = int(img_up.shape[0] * 0.62)
        cond_geom = (df["cy"] < top_cut)
        if RIGHT_HALF_ONLY:
            cond_geom &= (df["cx"] > (img_up.shape[1] // 2))

        # 2a) Preferred path: explicit percentage tokens
        df_pct = df[(df["is_pct"] == True) & cond_geom].copy()
        if not df_pct.empty:
            in_band = df_pct["value"].between(NIM_MIN, NIM_MAX)
            ratio = float(in_band.sum()) / float(len(df_pct))
            if ratio >= 0.6:
                return (True, "ok")
            else:
                return (False, f"non_nim_values_out_of_band({ratio:.2f})")

        # 2b) Fallback: some decks omit the % sign near the series values.
        # Accept plain numbers in the NIM range if units are explicit or implied, or if numeric signature is strong.
        title_text = text_lower  # already computed above
        has_units_pct = "(%)" in title_text or "margin (%)" in title_text or has_nim
        df_nums = df[(df["is_pct"] == False) & cond_geom].copy()
        if not df_nums.empty:
            in_band = df_nums["value"].between(NIM_MIN, NIM_MAX)
            ratio = float(in_band.sum()) / float(len(df_nums))
            # Case A: explicit or implied units in title → accept when enough in-band hits
            if has_units_pct and ratio >= 0.6 and in_band.sum() >= 3:
                return (True, "ok_no_percent_signs")
            # Case B: title OCR may have missed units; if the quick keyword gate succeeded, accept with a stricter ratio
            if kw_ok and ratio >= 0.7 and in_band.sum() >= 3:
                return (True, "ok_numeric_signature")
            # Case C: strong structural evidence (quarters on bottom) + numeric signature in band
            q_xy_fallback = detect_quarters_easyocr(img_bgr)
            if len(q_xy_fallback) >= 4 and ratio >= 0.6 and in_band.sum() >= 3:
                return (True, "ok_structural_numeric_signature")

        # Final decision: if numeric signature still failed, report clearer reason
        if not kw_ok:
            return (False, "irrelevant_non_nim")
        else:
            return (False, "no_percentages_or_units")
    except Exception as e:
        return (False, f"exception:{e}")


# --- Helper: detect and order quarter labels from OCR ---
def detect_qlabels(ocr_results, img_width: int) -> list[str]:
    """
    Extract quarter tokens like 1Q25, 2Q2025 from OCR and return them left→right.
    We keep only tokens on the right half (where the series values live in your layout).
    """
    qtokens = []
    mid_x = img_width // 2
    for r in ocr_results or []:
        txt = str(r.get("text","")).strip()
        m = QUARTER_PAT.search(txt)
        if not m:
            continue
        x1,y1,x2,y2 = r["bbox"]
        cx = (x1 + x2) // 2
        if cx <= mid_x:
            continue  # ignore left panel quarters/titles
        q = f"{m.group(1)}Q{m.group(2)[-2:]}"  # normalize to 1Q25 style
        qtokens.append((cx, q))
    # Sort by visual x-position, then deduplicate by text and proximity (ignore near-duplicates)
    qtokens.sort(key=lambda x: x[0])
    ordered = []
    last_x = -9999
    last_q = None
    for x, q in qtokens:
        if last_q == q and abs(x - last_x) < 30:
            continue
        ordered.append(q)
        last_x, last_q = x, q
    return ordered

# === Focused bottom-of-chart scan for small quarter labels ===
def detect_qlabels_bottom(img_bgr) -> list[str]:
    """
    Focused pass: crop the bottom ~40% (where quarter labels usually sit),
    enhance contrast, OCR, and extract quarter tokens left→right.
    """
    try:
        H, W = img_bgr.shape[:2]
        y0 = int(H * 0.60)  # bottom 40%
        crop = img_bgr[y0:H, 0:W]
        # Enhance: grayscale -> bilateral -> CLAHE -> adaptive threshold
        gray = cv2.cvtColor(crop, cv2.COLOR_BGR2GRAY)
        gray = cv2.bilateralFilter(gray, d=7, sigmaColor=50, sigmaSpace=50)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
        gray = clahe.apply(gray)
        thr = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                    cv2.THRESH_BINARY, 31, 8)
        # Morphological close to strengthen thin glyphs
        kernel = np.ones((3,3), np.uint8)
        thr = cv2.morphologyEx(thr, cv2.MORPH_CLOSE, kernel, iterations=1)
        # Upscale for small text
        up = cv2.resize(thr, None, fx=2.5, fy=2.5, interpolation=cv2.INTER_CUBIC)
        img_rgb = cv2.cvtColor(up, cv2.COLOR_GRAY2RGB)
        ocr = run_easyocr(img_rgb)
        # Map bboxes back to global coords: decide single-panel vs split-panel
        mid_x = W // 2
        left_quarters, right_quarters = [], []
        left_tokens_text, right_tokens_text = [], []
        for r in ocr or []:
            raw = str(r.get("text", "")).strip()
            x1,y1,x2,y2 = r["bbox"]
            cx_local = (x1 + x2) // 2
            cx_global = int(cx_local / 2.5)  # undo scale

            if cx_global <= mid_x:
                left_tokens_text.append(raw.lower())
            else:
                right_tokens_text.append(raw.lower())

            m = QUARTER_PAT.search(raw)
            if not m:
                continue
            q = f"{m.group(1)}Q{m.group(2)[-2:]}"
            if cx_global <= mid_x:
                left_quarters.append((cx_global, q))
            else:
                right_quarters.append((cx_global, q))

        def has_halfyear_or_9m(tokens: list[str]) -> bool:
            s = " ".join(tokens)
            return ("9m" in s) or ("1h" in s) or ("h1" in s) or ("h2" in s) or ("2h" in s)

        left_has_h = has_halfyear_or_9m(left_tokens_text)
        # Panel selection logic: prefer both halves unless left clearly half-year and right has ≥3 quarters
        if (not left_has_h) and (len(left_quarters) + len(right_quarters) >= 2):
            # Likely single panel or weak OCR on one side → use both halves
            qtokens = left_quarters + right_quarters
        elif len(right_quarters) >= 3:
            # Strong right panel signal → use right only
            qtokens = right_quarters
        else:
            # Fallback: use everything we found
            qtokens = left_quarters + right_quarters

        # Sort and dedupe close neighbors (≤18 px)
        qtokens.sort(key=lambda t: t[0])
        deduped = []
        last_x = -10**9
        for x, q in qtokens:
            if abs(x - last_x) <= 18:
                continue
            deduped.append((x, q))
            last_x = x

        return [q for _, q in deduped]
    except Exception:
        return []

# --- Same as detect_qlabels_bottom, but returns (x, label) for alignment ---
def detect_qlabels_bottom_with_xy(img_bgr) -> list[tuple[int, str]]:
    try:
        H, W = img_bgr.shape[:2]
        y0 = int(H * 0.60)
        crop = img_bgr[y0:H, 0:W]
        gray = cv2.cvtColor(crop, cv2.COLOR_BGR2GRAY)
        gray = cv2.bilateralFilter(gray, d=7, sigmaColor=50, sigmaSpace=50)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
        gray = clahe.apply(gray)
        thr = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                    cv2.THRESH_BINARY, 31, 8)
        kernel = np.ones((3,3), np.uint8)
        thr = cv2.morphologyEx(thr, cv2.MORPH_CLOSE, kernel, iterations=1)
        up = cv2.resize(thr, None, fx=2.5, fy=2.5, interpolation=cv2.INTER_CUBIC)
        img_rgb = cv2.cvtColor(up, cv2.COLOR_GRAY2RGB)
        ocr = run_easyocr(img_rgb)

        mid_x = W // 2
        left_quarters, right_quarters = [], []
        left_tokens_text = []
        for r in ocr or []:
            raw = str(r.get("text", "")).strip()
            x1,y1,x2,y2 = r["bbox"]
            cx_local = (x1 + x2) // 2
            cx_global = int(cx_local / 2.5)
            if cx_global <= mid_x:
                left_tokens_text.append(raw.lower())
            m = QUARTER_PAT.search(raw)
            if not m:
                continue
            q = f"{m.group(1)}Q{m.group(2)[-2:]}"
            if cx_global <= mid_x:
                left_quarters.append((cx_global, q))
            else:
                right_quarters.append((cx_global, q))

        def has_halfyear_or_9m(tokens: list[str]) -> bool:
            s = " ".join(tokens)
            return ("9m" in s) or ("1h" in s) or ("h1" in s) or ("h2" in s) or ("2h" in s)

        left_has_h = has_halfyear_or_9m(left_tokens_text)
        if (not left_has_h) and (len(left_quarters) + len(right_quarters) >= 2):
            # Likely single panel or weak OCR on one side → use both halves
            qtokens = left_quarters + right_quarters
        elif len(right_quarters) >= 3:
            # Strong right panel signal → use right only
            qtokens = right_quarters
        else:
            # Fallback: use everything we found
            qtokens = left_quarters + right_quarters

        qtokens.sort(key=lambda t: t[0])
        deduped = []
        last_x = -10**9
        for x, q in qtokens:
            if abs(x - last_x) <= 18:
                continue
            deduped.append((x, q))
            last_x = x
        return deduped
    except Exception:
        return []

# --- Merge two ordered quarter lists ---
def _merge_ordered(primary: list[str], secondary: list[str]) -> list[str]:
    """
    Merge two left→right sequences, keeping 'primary' order and filling with
    any unseen items from 'secondary' in their order.
    """
    out = list(primary)
    seen = set(primary)
    for q in secondary:
        if q not in seen:
            out.append(q)
            seen.add(q)
    return out

# --- Expand a quarter label like '2Q24' forward n quarters ---
def _expand_quarters(start_q: str, n: int) -> list[str]:
    """
    Given a label like '2Q24', produce a forward sequence of n quarters:
    2Q24, 3Q24, 4Q24, 1Q25, 2Q25, ...
    """
    m = QUARTER_PAT.match(start_q) or QUARTER_SIMPLE_PAT.match(start_q)
    if not m:
        return []
    q = int(m.group(1))
    yy = int(m.group(2)[-2:])
    seq = []
    for _ in range(n):
        seq.append(f"{q}Q{yy:02d}")
        q += 1
        if q == 5:
            q = 1
            yy = (yy + 1) % 100
    return seq

# --- Find a plausible anchor quarter like 2Q24 from OCR or markdown tokens ---
def _anchor_quarter_from_texts(ocr_results, md_tokens: list[str]) -> str | None:
    """
    Find any token like 1Q2x..4Q2x from OCR texts or markdown tokens.
    Returns the first plausible anchor (normalized to e.g. 2Q24) or None.
    """
    # prefer bottom/ocr-derived tokens first (already parsed in detect_qlabels_bottom)
    # fallback: scan all OCR texts with simple pattern
    for r in ocr_results or []:
        txt = str(r.get("text","")).strip()
        m = QUARTER_SIMPLE_PAT.search(txt)
        if m:
            return f"{m.group(1)}Q{m.group(2)}"
    # fallback to any markdown token that matches the decade pattern
    for t in md_tokens or []:
        m = QUARTER_SIMPLE_PAT.match(t)
        if m:
            return f"{m.group(1)}Q{m.group(2)}"
    return None

def extract_series_from_df(df, img_up, ocr_results=None, qlabels_hint=None):
    H, W = img_up.shape[:2]
    mid_x = W//2
    top_band_min = int(H * 0.38)
    top_band_max = int(H * 0.58)

    # Detect bottom quarter labels (with x) early to infer layout
    detected_q_bot_xy = detect_quarters_easyocr(img_up)
    left_count  = sum(1 for x, _ in detected_q_bot_xy if x <= mid_x)
    right_count = sum(1 for x, _ in detected_q_bot_xy if x >  mid_x)
    # Heuristic: if we see ≥4 quarter tokens spanning both halves, it's a single-panel timeline
    single_panel = (len(detected_q_bot_xy) >= 4 and left_count >= 1 and right_count >= 1)

    # Filter tokens: keep right-half only for split panels; keep all for single panels
    if single_panel:
        pct = df[(df.is_pct==True)].copy()
        nums = df[(df.is_pct==False)].copy()
    else:
        pct = df[(df.is_pct==True) & (df.cx > mid_x)].copy()
        nums = df[(df.is_pct==False) & (df.cx > mid_x)].copy()

    if pct.empty:
        # Fallback for charts that omit the '%' sign on the value dots.
        # Use a wider top band and avoid forcing right-half on single-panel timelines.
        approx_top = int(H * 0.60)
        if single_panel:
            cx_mask = (df.cx > 0)  # keep all x for single panel
        else:
            cx_mask = (df.cx > mid_x)
        cand_pct = df[cx_mask & df.value.between(NIM_MIN, NIM_MAX) & (df.cy < approx_top)].copy()
        if not cand_pct.empty:
            cand_pct["is_pct"] = True
            pct = cand_pct

    nim_df = pd.DataFrame()
    if not pct.empty:
        # Try to split into two horizontal series by Y even when we have only 3 quarters (→ 6 points)
        # Deduplicate by proximity on Y to stabilize clustering
        y_sorted = pct.sort_values("cy")["cy"].to_numpy()
        uniq_y = []
        last_y = -10**9
        for yy in y_sorted:
            if abs(yy - last_y) >= 6:  # 6px tolerance for duplicates
                uniq_y.append(yy)
                last_y = yy
        # Attempt k-means when we have at least 4 points total (≈ 2 series × 2 quarters)
        if pct.shape[0] >= 4 and len(uniq_y) >= 2:
            labels, centers = kmeans_1d(pct["cy"].values, k=2)
            pct["series"] = labels
            order = np.argsort(centers)  # top (commercial) should have smaller y
            remap = {order[0]: "Commercial NIM (%)", order[1]: "Group NIM (%)"}
            pct["series_name"] = pct["series"].map(remap)
            # Sanity: ensure both series have data; else collapse to one
            counts = pct["series_name"].value_counts()
            if any(counts.get(name, 0) == 0 for name in ["Commercial NIM (%)", "Group NIM (%)"]):
                pct["series_name"] = "NIM (%)"
        else:
            pct["series_name"] = "NIM (%)"

        # Reuse bottom-quarter labels captured above
        detected_q_bot = [q for _, q in detected_q_bot_xy]
        detected_q_ocr = detect_qlabels(ocr_results or [], W) if ocr_results is not None else []
        if len(detected_q_bot) > len(detected_q_ocr):
            detected_q = _merge_ordered(detected_q_bot, detected_q_ocr)
        else:
            detected_q = _merge_ordered(detected_q_ocr, detected_q_bot)
        rows = []
        for name, sub in pct.groupby("series_name"):
            # Sort left→right and collapse near-duplicates (same x within 12px)
            sub_sorted = sub.sort_values("cx")
            uniq_rows = []
            last_x = -10**9
            for r in sub_sorted.itertuples(index=False):
                if abs(r.cx - last_x) < 12:
                    continue
                uniq_rows.append(r)
                last_x = r.cx
            # Keep only the right-panel portion (already ensured by cx>mid_x earlier)
            pick = list(uniq_rows)[-5:]  # cap to 5 most recent positions, but may be <5
            n = len(pick)
            if n == 0:
                continue
            labels = []
            # Robust mapping: map each value x to its nearest bottom quarter label x (right panel).
            # Filter any accidental half-year tokens (1H/2H/H1/H2/9M) just in case OCR returns them.
            def _is_half_token(t: str) -> bool:
                t = (t or "").lower().replace(" ", "")
                return ("9m" in t) or ("1h" in t) or ("h1" in t) or ("h2" in t) or ("2h" in t) or ("h24" in t) or ("h23" in t)

            # detected_q_bot_xy already respects split vs single panel. Keep right-panel positions only here.
            q_xy = []
            for x, q in detected_q_bot_xy:
                if x <= mid_x:
                    continue
                if _is_half_token(q):
                    continue
                q_xy.append((x, q))

            if len(q_xy) < n:
                # Borrow from left panel if they look like quarters (and not half-year)
                for x, q in detected_q_bot_xy:
                    if x > mid_x:
                        continue
                    if _is_half_token(q):
                        continue
                    q_xy.append((x, q))

            if q_xy:
                q_xy.sort(key=lambda t: t[0])  # left→right
                # Map each picked value to nearest quarter label by x-position
                vx = [rr.cx for rr in pick]
                qx = [x for x, _ in q_xy]
                ql = [q for _, q in q_xy]
                mapped = []
                for x in vx:
                    j = int(np.argmin([abs(x - xx) for xx in qx])) if qx else -1
                    mapped.append(ql[j] if j >= 0 else None)
                labels = mapped
            else:
                detected_q_ocr = detect_qlabels(ocr_results or [], W) if ocr_results is not None else []
                if detected_q_ocr:
                    labels = detected_q_ocr[-n:] if len(detected_q_ocr) >= n else detected_q_ocr

            # If still short, use markdown tokens; else expand from an anchor like 2Q24
            if (not labels) or (len(labels) != n):
                if qlabels_hint:
                    labels = qlabels_hint[-n:] if len(qlabels_hint) >= n else qlabels_hint
            if (not labels) or (len(labels) != n):
                anchor = _anchor_quarter_from_texts(ocr_results, qlabels_hint)
                if anchor:
                    labels = _expand_quarters(anchor, n)
            if (not labels) or (len(labels) != n):
                labels = [f"{i+1}Q??" for i in range(n)]
            # Ensure left→right order for consistent mapping to labels
            pick = sorted(pick, key=lambda r: r.cx)
            labels = list(labels)[:n]
            for i, r in enumerate(pick):
                if i >= len(labels):
                    break
                rows.append({"Quarter": labels[i], "series": name, "value": r.value})
        if rows:
            nim_table = pd.DataFrame(rows)
            # Guard: drop rows with missing labels
            nim_table = nim_table.dropna(subset=["Quarter", "series"])
            # If multiple detections map to the same (Quarter, series), average them
            if not nim_table.empty:
                dupe_mask = nim_table.duplicated(subset=["Quarter", "series"], keep=False)
                if dupe_mask.any():
                    # Aggregate duplicates by mean (stable for minor OCR jitter)
                    nim_table = nim_table.groupby(["Quarter", "series"], as_index=False)["value"].mean()
            nim_df = nim_table.pivot(index="Quarter", columns="series", values="value").reset_index()

    # NIM-only mode: skip NII extraction entirely
    nii_df = pd.DataFrame()

    def _sort_q(df_in):
        if df_in is None or df_in.empty or "Quarter" not in df_in.columns:
            return df_in
        # Try to sort by numeric (Q#, year) if labels are like 2Q24; else keep input order
        def _key(q):
            m = QUARTER_PAT.match(str(q))
            if not m:
                return (999, 999)
            qn = int(m.group(1))
            yr = int(m.group(2)[-2:])  # last two digits
            return (yr, qn)
        try:
            return df_in.assign(_k=df_in["Quarter"].map(_key)).sort_values("_k").drop(columns=["_k"]).reset_index(drop=True)
        except Exception:
            return df_in.reset_index(drop=True)

    return _sort_q(nim_df), _sort_q(nii_df)
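
# --- Illustrative sketch (not part of the original pipeline) ---
# The sort key used by _sort_q orders quarter labels by (year, quarter), so
# "4Q23" sorts before "1Q24" and unparseable labels such as "1Q??" go last.
# Assumption: `_EXAMPLE_QUARTER_PAT` is a hypothetical stand-in for the
# notebook's QUARTER_PAT, which is defined elsewhere and may differ.
import re as _re_q_example

_EXAMPLE_QUARTER_PAT = _re_q_example.compile(r"^([1-4])Q(\d{2}|\d{4})$")

def _example_quarter_key(label) -> tuple:
    """Return (two-digit year, quarter number); (999, 999) for unknown labels."""
    m = _EXAMPLE_QUARTER_PAT.match(str(label))
    if not m:
        return (999, 999)
    return (int(m.group(2)[-2:]), int(m.group(1)))

# Usage: sorted(["1Q24", "4Q23", "1Q??", "2Q24"], key=_example_quarter_key)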

def _extract_md_context(dest_dir: Path, image_name: str) -> dict:
    """
    Best-effort: read the <pdf_stem>.md in dest_dir, find the <image_name> reference,
    capture nearby headings and a neighbor paragraph to build context.
    """
    try:
        # Prefer "<pdf_stem>.md", else any .md
        md_file = dest_dir / f"{dest_dir.name}.md"
        if not md_file.exists():
            cands = list(dest_dir.glob("*.md"))
            if not cands:
                return {}
            md_file = cands[0]
        lines = md_file.read_text(encoding="utf-8", errors="ignore").splitlines()
    except Exception:
        return {}

    # Find the image line
    idx = None
    for i, line in enumerate(lines):
        if image_name in line:
            idx = i
            break
    if idx is None:
        return {}

    # Walk upward to find up to two headings and a neighbor paragraph
    figure_title = None
    section_title = None
    neighbor_text = None

    # Find the closest preceding heading(s)
    for j in range(idx - 1, -1, -1):
        s = lines[j].strip()
        if not s:
            continue
        # markdown heading levels
        if s.startswith("#"):
            # Remove leading #'s and whitespace
            heading = s.lstrip("#").strip()
            if figure_title is None:
                figure_title = heading
            elif section_title is None:
                section_title = heading
                break

    # Find the nearest preceding non-empty line that is neither a heading nor an image reference
    for j in range(idx - 1, -1, -1):
        s = lines[j].strip()
        if s and not s.startswith("#") and not s.startswith("![]("):
            neighbor_text = s
            break

    out = {}
    if figure_title: out["figure_title"] = figure_title
    if section_title: out["section_title"] = section_title
    if neighbor_text: out["neighbor_text"] = neighbor_text
    return out

def _parse_page_and_figure_from_name(image_name: str) -> dict:
    """
    Extract page/figure indices from names like '_page_0_Figure_2.jpeg'.
    """
    info = {}
    try:
        # Very loose parse
        if "_page_" in image_name:
            after = image_name.split("_page_", 1)[1]
            num = after.split("_", 1)[0]
            info["page"] = int(num) + 1  # 1-based for human readability
        if "Figure_" in image_name:
            after = image_name.split("Figure_", 1)[1]
            num = ""
            for ch in after:
                if ch.isdigit():
                    num += ch
                else:
                    break
            if num:
                info["figure_index"] = int(num)
    except Exception:
        pass
    return info
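
# --- Illustrative sketch (not part of the original pipeline) ---
# The loose split-based parse above could also be written with two regular
# expressions. `_example_parse_marker_name` is a hypothetical helper shown only
# for comparison; it assumes Marker-style names such as '_page_0_Figure_2.jpeg'.
import re as _re_name_example

def _example_parse_marker_name(image_name: str) -> dict:
    """Return a 1-based 'page' and the raw 'figure_index' when present in the name."""
    info = {}
    m_page = _re_name_example.search(r"_page_(\d+)", image_name)
    if m_page:
        info["page"] = int(m_page.group(1)) + 1  # 1-based for human readability
    m_fig = _re_name_example.search(r"Figure_(\d+)", image_name)
    if m_fig:
        info["figure_index"] = int(m_fig.group(1))
    return info

# Usage: _example_parse_marker_name("_page_0_Figure_2.jpeg")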

def is_relevant_image(img_path, keywords):
    """Robust relevance check for NIM slides.
    - Reuse the singleton EasyOCR reader (run_easyocr)
    - Accept split tokens like "Net" / "interest" / "margin" (not only the exact phrase)
    - Fallback: if we see ≥4 quarter labels on the bottom AND ≥3 top-band percent-like values in NIM range, treat as relevant.
    """
    try:
        img = cv2.imread(str(img_path))
        if img is None:
            return False

        # Pass A: OCR on lightly upscaled original
        view_a = cv2.resize(img, None, fx=1.3, fy=1.3, interpolation=cv2.INTER_CUBIC)
        ocr_a = run_easyocr(cv2.cvtColor(view_a, cv2.COLOR_BGR2RGB))
        tokens_a = [str(r.get("text","")).lower() for r in (ocr_a or [])]
        text_a = " ".join(tokens_a)

        # Quick phrase match (exact keywords like "net interest margin")
        if any(k in text_a for k in keywords):
            return True

        # Pass B: OCR on preprocessed thresholded view (more stable for thin fonts)
        _, _, thr, _ = preprocess(img)
        ocr_b = run_easyocr(cv2.cvtColor(thr, cv2.COLOR_GRAY2RGB))
        tokens_b = [str(r.get("text","")).lower() for r in (ocr_b or [])]
        text_b = " ".join(tokens_b)
        if any(k in text_b for k in keywords):
            return True

        # Token-level split-word check
        tokens = tokens_a + tokens_b
        has_net      = any("net" in t for t in tokens)
        has_interest = any("interest" in t for t in tokens)
        has_margin   = any("margin" in t for t in tokens)
        has_nim_abbr = any(re.search(r"\bnim\b", t) for t in tokens)
        has_cb       = any("commercial book" in t for t in tokens)
        has_grp      = any(re.search(r"\bgroup\b", t) for t in tokens)
        if (has_net and has_interest and has_margin) or has_nim_abbr:
            # Strengthen with context words if available
            if has_cb or has_grp:
                return True

        # Structural fallback: quarters + percent values in the NIM band
        q_xy = detect_quarters_easyocr(img)
        if len(q_xy) >= 4:
            # Look for ≥3 percent-ish values in the top band within NIM_MIN..NIM_MAX
            df = extract_numbers(ocr_b)
            if not df.empty:
                # df comes from OCR on `thr`, so measure the top band in that image's coordinates
                H = thr.shape[0]
                top_cut = int(H * 0.55)
                in_top = df["cy"] < top_cut
                in_band = df["value"].between(NIM_MIN, NIM_MAX)
                pctish = in_band  # allow numbers without % (the series sometimes omit it)
                if int((in_top & pctish).sum()) >= 3:
                    return True

        return False
    except Exception:
        return False


# =============== Pluggable OCR Extractor Framework ===============
class BaseChartExtractor:
    """
    Minimal interface for pluggable chart extractors.
    Implement `is_relevant` and `extract_table`, then call `handle_image(...)`.
    """
    name = "base"
    topic = "Generic Chart"
    units = None
    entity = None
    keywords = []

    def is_relevant(self, img_path: Path) -> bool:
        return is_relevant_image(img_path, self.keywords)

    def extract_table(self, img_path: Path, dest_dir: Path, pdf_name: str):
        """
        Return (df, context_dict) or (None, reason) on failure.
        context_dict will be merged into the _context object.
        """
        raise NotImplementedError

    def _build_context(self, pdf_name: str, img_path: Path, dest_dir: Path, extra: dict | None = None) -> dict:
        ctx = {
            "source_pdf": pdf_name,
            "image": img_path.name,
            "topic": self.topic,
        }
        if self.units:  ctx["units"]  = self.units
        if self.entity: ctx["entity"] = self.entity
        ctx.update(_parse_page_and_figure_from_name(img_path.name))
        md_ctx = _extract_md_context(dest_dir, img_path.name)
        if md_ctx: ctx.update(md_ctx)
        if extra:  ctx.update(extra)
        return ctx

    def _write_jsonl(self, out_path: Path, ctx: dict, df: pd.DataFrame):
        import json
        with open(out_path, "w", encoding="utf-8") as f:
            f.write(json.dumps({"_context": ctx}, ensure_ascii=False) + "\n")
            for rec in df.to_dict(orient="records"):
                rec_out = dict(rec)
                rec_out["_meta"] = {"source_pdf": ctx.get("source_pdf"), "image": ctx.get("image")}
                f.write(json.dumps(rec_out, ensure_ascii=False) + "\n")

    def handle_image(self, img_path: Path, dest_dir: Path, pdf_name: str, *, bypass_relevance: bool = False):
        if not bypass_relevance and not self.is_relevant(img_path):
            return False, "Not relevant"
        df, ctx_extra = self.extract_table(img_path, dest_dir, pdf_name)
        if df is None or df.empty:
            return False, ctx_extra if isinstance(ctx_extra, str) else "No data"
        # Build context and summary if possible
        ctx = self._build_context(pdf_name, img_path, dest_dir, extra=ctx_extra if isinstance(ctx_extra, dict) else {})
        try:
            cols = [c for c in df.columns if c != "Quarter"]
            if len(df) >= 2 and cols:
                def _pick_q(s):
                    return s if QUARTER_PAT.match(str(s) or "") else None
                _fq = str(df.iloc[0]["Quarter"])
                _lq = str(df.iloc[-1]["Quarter"])
                first_q = _pick_q(_fq) or (_fq if "??" not in _fq else "start")
                last_q  = _pick_q(_lq) or (_lq if "??" not in _lq else "end")
                pieces = []
                for col in cols[:2]:
                    a = df.iloc[0][col]
                    b = df.iloc[-1][col]
                    if pd.notna(a) and pd.notna(b):
                        suffix = "%" if "NIM" in col or ctx.get("units") == "percent" else ""
                        pieces.append(f"{col}: {a:.2f}{suffix} → {b:.2f}{suffix}")
                if pieces:
                    ctx["summary"] = f"Figure shows {', '.join(pieces)} from {first_q} to {last_q}."
        except Exception:
            pass
        out_path = img_path.with_suffix(f".{self.name}.jsonl")
        self._write_jsonl(out_path, ctx, df)
        return True, str(out_path)

class NIMExtractor(BaseChartExtractor):
    name = "nim"
    topic = "Net Interest Margin"
    units = "percent"
    entity = "DBS"
    keywords = NIM_KEYWORDS

    def extract_table(self, img_path: Path, dest_dir: Path, pdf_name: str):
        # Reuse the existing pipeline
        img_bgr = load_image(img_path)
        img_up, gray, thr, scale = preprocess(img_bgr)
        img_rgb = cv2.cvtColor(thr, cv2.COLOR_GRAY2RGB)
        ocr = run_easyocr(img_rgb)
        df_tokens = extract_numbers(ocr)
        if df_tokens.empty:
            return None, "No numeric tokens detected"
        md_q = detect_qlabels_from_md(dest_dir, img_path.name)
        nim_df, _nii_df = extract_series_from_df(df_tokens, img_up, ocr_results=ocr, qlabels_hint=md_q)
        if nim_df is None or nim_df.empty:
            return None, "No NIM table detected"
        return nim_df, {"topic": self.topic, "units": self.units, "entity": self.entity}

# Registry of extractors (add more later)
EXTRACTORS: list[BaseChartExtractor] = [
    NIMExtractor(),
]
# ============= End pluggable extractor framework =============

# === Single-image rebuild/verify mode (optional) ===
# Set single_image_mode=True and list one or more extracted images in single_image_paths
# to run the two-stage gate + extraction just for those files, then exit.
single_image_mode = False
single_image_paths: list[Path] = [
   
]
# Optional singular fallback path (legacy): set to a string/Path if you want a single-image override
single_image_path = None

# Legacy fallback (ignored i
# Toggle: if True → normal md5 skip; if False → always reprocess
md5_check = True

# 3. Define the path to the directory containing your PDF files
# pdf_directory = Path("/Users/marcusfoo/Documents/GitHub/PTO_ICT3113_Grp1/All/")
pdf_directory = Path("./All/")
# === Fast path: single/multi-image only ===
if single_image_mode:
    paths: list[Path] = []
    if single_image_paths:
        paths = [Path(p) for p in single_image_paths if p is not None]
    elif single_image_path:
        paths = [Path(single_image_path)]

    if not paths:
        print("❌ single_image_mode=True but no paths were provided.")
        sys.exit(1)

    print("--- Multi-image mode ---")
    successes = 0
    for img_path in paths:
        if not img_path.exists():
            print(f"❌ Missing: {img_path}")
            continue

        dest_dir = img_path.parent
        pdf_name = f"{dest_dir.name}.pdf"
        print(f"\n🖼️  Image: {img_path.name}  |  PDF: {pdf_name}")

        # Quick quarter readout (EasyOCR-only, bottom axis)
        try:
            img_bgr_quarters = load_image(img_path)
            q_xy = detect_quarters_easyocr(img_bgr_quarters)
            if q_xy:
                print("   📎 Quarters (EasyOCR):", ", ".join([q for _,q in q_xy]))
            else:
                print("   📎 Quarters (EasyOCR): <none>")
        except Exception as _qe:
            print(f"   📎 Quarters (EasyOCR): error → {_qe}")

        any_hit = False

        for ex in EXTRACTORS:
            print(f"   · [{ex.name}] quick gate…", end=" ")
            if not ex.is_relevant(img_path):
                print("⏭️  Not relevant")
                continue
            print("✅ ok; strict gate…", end=" ")
            ok_strict, reason = is_strict_nim_image(img_path)
            if not ok_strict:
                print(f"⏭️  Failed strict ({reason})")
                continue
            print("✅ Strict OK; extracting…")

            # Extract directly so we can print the table; still write JSONL
            df, ctx_extra = ex.extract_table(img_path, dest_dir, pdf_name)
            if df is None or df.empty:
                print("   ⚠️ No data extracted.")
                continue

            any_hit = True
            successes += 1

            # Build context + summary and write JSONL
            ctx = ex._build_context(pdf_name, img_path, dest_dir, extra=ctx_extra if isinstance(ctx_extra, dict) else {})
            try:
                cols = [c for c in df.columns if c != "Quarter"]
                if len(df) >= 2 and cols:
                    def _pick_q(s):
                        return s if QUARTER_PAT.match(str(s) or "") else None
                    _fq = str(df.iloc[0]["Quarter"]); _lq = str(df.iloc[-1]["Quarter"])
                    first_q = _pick_q(_fq) or (_fq if "??" not in _fq else "start")
                    last_q  = _pick_q(_lq) or (_lq if "??" not in _lq else "end")
                    pieces = []
                    for col in cols[:2]:
                        a = df.iloc[0][col]; b = df.iloc[-1][col]
                        if pd.notna(a) and pd.notna(b):
                            suffix = "%" if "NIM" in col or ctx.get("units") == "percent" else ""
                            pieces.append(f"{col}: {a:.2f}{suffix} → {b:.2f}{suffix}")
                    if pieces:
                        ctx["summary"] = f"Figure shows {', '.join(pieces)} from {first_q} to {last_q}."
            except Exception:
                pass

            out_path = img_path.with_suffix(f".{ex.name}.jsonl")
            ex._write_jsonl(out_path, ctx, df)
            print(f"   💾 Saved JSONL → {out_path}")

            # Pretty-print the extracted table directly
            try:
                print("\n   📊 Extracted table:")
                print(df.to_string(index=False))
            except Exception:
                print(df)

        if not any_hit:
            print("   ⏭️  No matching extractors for this image.")

    print(f"\n✅ Done. Extracted from {successes} image(s).")
    # Prevent the pipeline (marker/md5) from running if notebook catches SystemExit
    globals()["_STOP_AFTER_SINGLE"] = True
    sys.exit(0)
    
# Check if the directory exists before proceeding
if not pdf_directory.is_dir():
    print(f"❌ ERROR: The directory was not found at '{pdf_directory}'.")
    sys.exit(1) # Exit the script if the directory doesn't exist

# 4. Check if the 'marker_single' command is available
if not shutil.which("marker_single"):
    print("❌ ERROR: The 'marker_single' command was not found.")
    print("Please ensure 'marker-pdf' is installed correctly in your environment's PATH.")
    sys.exit(1)

# Loop through every PDF file in the specified directory
for pdf_path in pdf_directory.glob("*.pdf"):
    print(f"--- Processing file: {pdf_path.name} ---")

    # 5. Let Marker create the <pdf_stem>/ subfolder automatically.
    # Point --output_dir to the *parent* folder so we don't end up with Demo PDF/Demo PDF/.
    output_parent = pdf_path.parent  # e.g., .../Demo/

    # Determine the destination folder Marker will create and a checksum sidecar file
    dest_dir = output_parent / pdf_path.stem
    checksum_file = dest_dir / ".marker_md5"

    # Compute the current md5 of the source PDF
    current_md5 = md5sum(pdf_path)

    # Define the expected main outputs (Marker uses the same stem)
    expected_md = dest_dir / f"{pdf_path.stem}.md"
    expected_json = dest_dir / f"{pdf_path.stem}.json"
    outputs_exist = expected_md.exists() and expected_json.exists()

    # md5 two-mode logic
    if md5_check:
        # Normal: skip if checksum matches and key outputs exist
        if dest_dir.is_dir() and checksum_file.exists() and outputs_exist:
            try:
                saved_md5 = checksum_file.read_text().strip()
            except Exception:
                saved_md5 = ""
            if saved_md5 == current_md5:
                print(f"⏭️  Skipping {pdf_path.name}: up-to-date (md5 match). → {dest_dir}")
                continue
            else:
                print(f"♻️  md5 mismatch → reprocessing {pdf_path.name}")
                print(f"    saved={saved_md5}")
                print(f"    current={current_md5}")
                print(f"    Cleaning old outputs in: {dest_dir}")
                try:
                    shutil.rmtree(dest_dir)
                except Exception as _e:
                    print(f"    ⚠️  Could not fully clean '{dest_dir}': {_e}")
        else:
            print("ℹ️  No prior checksum or outputs → processing normally.")
    else:
        # Force reprocess regardless of checksum
        print("⚙️  md5_check=False → forcing reprocess (marker + OCR).")
        if dest_dir.exists():
            print(f"    Cleaning existing folder: {dest_dir}")
            try:
                shutil.rmtree(dest_dir)
            except Exception as _e:
                print(f"    ⚠️  Could not fully clean '{dest_dir}': {_e}")

    try:
        # ======================================================================
        # 1. Run the CLI command to generate JSON output (with real-time output)
        # ======================================================================
        print(f"Running CLI command for JSON output on {pdf_path.name}...")
        json_command = [
            "marker_single",
            str(pdf_path),
            "--output_format", "json",
            "--output_dir", str(output_parent)
        ]
        # By removing 'capture_output', the subprocess will stream its output directly to the console in real-time.
        result_json = subprocess.run(json_command, check=True)
        print("✅ JSON file generated successfully by CLI.")


        # ======================================================================
        # 2. Run the CLI command to generate Markdown and Image output (with real-time output)
        # ======================================================================
        print(f"\nRunning CLI command for Markdown and Image output on {pdf_path.name}...")
        md_command = [
            "marker_single",
            str(pdf_path),
            # Default format is markdown, so we don't need to specify it
            "--output_dir", str(output_parent)
        ]
        result_md = subprocess.run(md_command, check=True)
        print("✅ Markdown file and images generated successfully by CLI.")

        print(f"\n✨ Files saved under '{output_parent / pdf_path.stem}'.")
        print("Note: Marker creates a subfolder named after the PDF automatically.")

        # === Post-processing: scan Marker images → filter relevant → save JSONL ===
        print("🔎 Scanning extracted images for relevant charts/plots…")
        img_exts = (".png", ".jpg", ".jpeg")
        img_files = [p for p in dest_dir.rglob("*") if p.suffix.lower() in img_exts]
        if not img_files:
            print("   🖼️  No images found in extracted folder.")
        for img_path in sorted(img_files):
            print(f"   • {img_path.name}")
            any_hit = False
            for ex in EXTRACTORS:
                # Stage 1: quick keyword/title skim
                print(f"      · [{ex.name}] quick gate…", end=" ")
                if not ex.is_relevant(img_path):
                    print("⏭️  Not relevant")
                    continue
                print("✅ ok; strict gate…", end=" ")

                # Stage 2: strict verifier (geometry + numeric band + semantic anchors)
                ok_strict, reason = is_strict_nim_image(img_path)
                if not ok_strict:
                    print(f"⏭️  Failed strict ({reason})")
                    continue

                any_hit = True
                print("✅ Strict OK; extracting…", end=" ")
                ok, msg = ex.handle_image(img_path, dest_dir, pdf_path.name, bypass_relevance=True)
                if ok:
                    print(f"💾 Saved → {msg}")
                else:
                    print(f"⚠️ Skipped ({msg})")
            if not any_hit:
                print("      ⏭️  No matching extractors for this image.")

        # After OCR completes, write/update checksum sidecar
        try:
            dest_dir.mkdir(parents=True, exist_ok=True)
            checksum_file.write_text(current_md5)
            print(f"🧾 Recorded checksum in: {checksum_file}")
        except Exception as _e:
            print(f"⚠️  Failed to write checksum file at '{checksum_file}': {_e}")

    except subprocess.CalledProcessError as e:
        print(f"\n❌ An error occurred while processing {pdf_path.name}.")
        print(f"Command: '{' '.join(e.cmd)}'")
        print(f"Return Code: {e.returncode}")
        print("Note: Outputs (if any) may be incomplete; checksum not updated.")
    except Exception as e:
        print(f"\nAn unexpected error occurred while processing {pdf_path.name}: {e}")
    
    print(f"--- Finished processing: {pdf_path.name} ---\n")

print("🎉 All PDF files in the directory have been processed.")


# === Stage-1 continuation: Build KB + FAISS (inline; no external scripts) ===
try:
    import sys, subprocess
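    # 1) Ensure minimal deps (idempotent). Some pip names differ from their import names
    #    ("sentence-transformers" imports as "sentence_transformers", "faiss-cpu" as "faiss"),
    #    so map them explicitly; otherwise the import check fails and pip runs on every execution.
    _kb_import_names = {"sentence-transformers": "sentence_transformers", "faiss-cpu": "faiss"}
    for _pkg in ["sentence-transformers", "faiss-cpu", "pandas", "pyarrow", "numpy", "lxml", "tqdm"]:
        try:
            __import__(_kb_import_names.get(_pkg, _pkg))
        except Exception:
            print(f"📦 Installing {_pkg} …")
            subprocess.check_call([sys.executable, "-m", "pip", "install", _pkg, "-q"])  # noqa: S603,S607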

    import re, json, hashlib, time
    import numpy as _np, pandas as _pd, faiss  # type: ignore
    from io import StringIO as _StringIO
    from pathlib import Path as _Path
    from tqdm import tqdm as _tqdm
    from sentence_transformers import SentenceTransformer as _ST

    KB_IN_DIR  = str(pdf_directory)  # reuse the same directory processed above
    KB_OUT_DIR = str((_Path("./data_marker")).resolve())

    # ---- helpers (namespaced with kb_ to avoid collisions) ----
    def kb_file_hash_key(p: _Path) -> str:
        try:
            s = p.stat()
            return hashlib.md5(f"{p.resolve()}|{s.st_size}|{int(s.st_mtime)}".encode()).hexdigest()
        except FileNotFoundError:
            return ""

    def kb_safe_read(path: _Path) -> str:
        for enc in ("utf-8", "utf-8-sig", "latin-1"):
            try:
                return path.read_text(encoding=enc, errors="ignore")
            except Exception:
                continue
        return ""

    def kb_strip_md_basic(md: str) -> str:
        md = re.sub(r"```.*?```", " ", md, flags=re.DOTALL)
        md = re.sub(r"!\[[^\]]*\]\([^\)]*\)", " ", md)
        md = re.sub(r"\[([^\]]+)\]\([^\)]*\)", r"\1", md)
        md = re.sub(r"<[^>]+>", " ", md)
        md = re.sub(r"\s+", " ", md)
        return md.strip()

    def kb_coerce_numbers_df(df: _pd.DataFrame) -> _pd.DataFrame:
        df = df.copy()
        for c in df.columns:
            if df[c].dtype == object:
                s = df[c].astype(str).str.replace(",", "", regex=False)
                num = _pd.to_numeric(s, errors="coerce")
                df[c] = _np.where(num.notna(), num, s)
        return df

    def kb_extract_tables_from_marker_json_blocks(jtxt: str):
        try:
            data = json.loads(jtxt)
        except Exception:
            return []
        out = []
        def _page_from_id(node: dict, fallback):
            node_id = node.get("id") if isinstance(node.get("id"), str) else ""
            m = re.search(r"/page/(\d+)/", node_id or "")
            if m:
                try:
                    return int(m.group(1))
                except Exception:
                    pass
            return fallback
        def walk(node, current_page=None):
            if isinstance(node, dict):
                current_page = _page_from_id(node, current_page)
                if node.get("block_type") == "Table" and isinstance(node.get("html"), str):
                    html = node["html"]
                    try:
                        dfs = _pd.read_html(_StringIO(html))
                        for df in dfs:
                            out.append({"df": kb_coerce_numbers_df(df), "page": current_page})
                    except Exception:
                        pass
                for v in node.values():
                    walk(v, current_page)
            elif isinstance(node, list):
                for v in node:
                    walk(v, current_page)
        walk(data)
        return out

    def kb_extract_text_spans_with_pages(jtxt: str):
        try:
            data = json.loads(jtxt)
        except Exception:
            return []
        spans = []
        def _page_from_id(node: dict, fallback):
            node_id = node.get("id") if isinstance(node.get("id"), str) else ""
            m = re.search(r"/page/(\d+)/", node_id or "")
            if m:
                try:
                    return int(m.group(1))
                except Exception:
                    pass
            return fallback
        def _strip_html(s: str) -> str:
            s = re.sub(r"<[^>]+>", " ", s)
            s = re.sub(r"\s+", " ", s).strip()
            return s
        TEXT_BLOCKS = {"Text", "SectionHeader", "Paragraph", "Heading", "ListItem", "Caption", "Footer", "Header"}
        def walk(node, current_page=None):
            if isinstance(node, dict):
                current_page = _page_from_id(node, current_page)
                bt = node.get("block_type")
                if isinstance(bt, str) and bt in TEXT_BLOCKS:
                    html = node.get("html")
                    if isinstance(html, str) and html.strip():
                        txt = _strip_html(html)
                        if txt:
                            spans.append({"page": current_page, "text": txt})
                for v in node.values():
                    walk(v, current_page)
            elif isinstance(node, list):
                for v in node:
                    walk(v, current_page)
        walk(data)
        return spans

    def kb_markdown_tables_find(md_text: str):
        lines = md_text.splitlines()
        i, n = 0, len(lines)
        while i < n:
            if '|' in lines[i]:
                j = i + 1
                if j < n and re.search(r'^\s*\|?\s*:?-{3,}', lines[j]):
                    k = j + 1
                    while k < n and '|' in lines[k] and lines[k].strip():
                        k += 1
                    yield "\n".join(lines[i:k])
                    i = k; continue
            i += 1

    def kb_markdown_table_to_df(table_md: str):
        rows = [r.strip() for r in table_md.strip().splitlines() if r.strip()]
        if len(rows) < 2: return None
        def split_row(r: str):
            r = r.strip()
            if r.startswith('|'): r = r[1:]
            if r.endswith('|'): r = r[:-1]
            return [c.strip() for c in r.split('|')]
        cols = split_row(rows[0])
        if len(split_row(rows[1])) != len(cols): return None
        data = []
        for r in rows[2:]:
            cells = split_row(r)
            if len(cells) < len(cols): cells += [""] * (len(cols) - len(cells))
            if len(cells) > len(cols): cells = cells[:len(cols)]
            data.append(cells)
        try:
            df = _pd.DataFrame(data, columns=cols)
            return kb_coerce_numbers_df(df)
        except Exception:
            return None

    def kb_table_rows_to_sentences(df: _pd.DataFrame, doc_name: str, table_id: int):
        sents = []
        if df.shape[1] == 0: return sents
        label = df.columns[0]
        for ridx, row in df.reset_index(drop=True).iterrows():
            parts = [str(row[label])]
            for c in df.columns[1:]:
                parts.append(f"{c}: {row[c]}")
            sents.append(f"[{doc_name}] table#{table_id} row#{ridx} :: " + " | ".join(parts))
        return sents

    def kb_table_signature(df: _pd.DataFrame) -> str:
        try:
            cols = [str(c).strip() for c in df.columns]
            first_col = cols[0] if cols else ""
            years = sorted({c for c in cols if re.fullmatch(r"\d{4}", str(c))})
            nums = []
            for c in df.columns:
                s = _pd.to_numeric(_pd.Series(df[c]).astype(str).str.replace(",", "", regex=False), errors="coerce")
                vals = [float(x) for x in s.dropna().tolist()]
                nums.extend(vals)
            nums = [round(x, 3) for x in nums[:8]]
            return "|".join([
                f"first:{first_col.lower()}",
                "years:" + ",".join(years),
                "nums:" + ",".join(map(str, nums))
            ])
        except Exception:
            return ""

    def kb_encode(texts, model_name):
        model = _ST(model_name)
        embs = model.encode(texts, batch_size=64, show_progress_bar=True, normalize_embeddings=True)
        return _np.asarray(embs, dtype="float32")

    def kb_build_faiss(embs):
        d = int(embs.shape[1])
        idx = faiss.IndexFlatIP(d)  # cosine via normalized inner product
        idx.add(embs)
        return idx

    def kb_discover_docs(in_dir: _Path):
        docs = {}
        for f in sorted(in_dir.iterdir()):
            if not f.is_dir():
                continue
            nested = f / f.name
            md = list(f.glob("*.md")) + (list(nested.glob("*.md")) if nested.is_dir() else [])
            js = list(f.glob("*.json")) + (list(nested.glob("*.json")) if nested.is_dir() else [])
            jl = list(f.glob("*.jsonl")) + (list(nested.glob("*.jsonl")) if nested.is_dir() else [])
            if md or js or jl:
                docs[f.name] = {"md": sorted(md), "json": sorted(js), "jsonl": sorted(jl), "root": f}
        return docs

    def kb_load_jsonl(path: _Path) -> list:
        rows = []
        try:
            with open(path, "r", encoding="utf-8") as f:
                for line in f:
                    s = line.strip()
                    if not s:
                        continue
                    try:
                        rows.append(json.loads(s))
                    except Exception:
                        continue
        except Exception:
            return []
        return rows

    def kb_chunk_text(text: str, max_chars: int = 1600, overlap: int = 200):
        if not text: return []
        paras = [p.strip() for p in re.split(r"\n\s*\n", text) if p.strip()]
        chunks, buf, cur = [], [], 0
        def flush():
            nonlocal buf, cur
            if not buf: return
            s = "\n\n".join(buf).strip()
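            # Slide a max_chars window forward by (max_chars - overlap) so that
            # consecutive pieces actually share `overlap` characters.
            step = max(1, max_chars - overlap)
            for i in range(0, len(s), step):
                piece = s[i:i+max_chars].strip()
                if piece: chunks.append(piece)
                if i + max_chars >= len(s): break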
            buf.clear(); cur = 0
        for p in paras:
            if cur + len(p) + 2 <= max_chars:
                buf.append(p); cur += len(p) + 2
            else:
                flush(); buf.append(p); cur = len(p)
        flush(); return chunks

    def build_kb_with_tables(
        in_dir=KB_IN_DIR,
        out_dir=KB_OUT_DIR,
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        max_chars=1600,
        overlap=200,
    ):
        in_path, out_path = _Path(in_dir), _Path(out_dir)
        out_path.mkdir(parents=True, exist_ok=True)

        kb_parquet     = out_path / "kb_chunks.parquet"
        kb_texts_npy   = out_path / "kb_texts.npy"
        kb_meta_json   = out_path / "kb_meta.json"
        kb_index_path  = out_path / "kb_index.faiss"
        kb_index_meta  = out_path / "kb_index_meta.json"
        kb_tables_parq = out_path / "kb_tables.parquet"
        kb_outline_parq = out_path / "kb_outline.parquet"

        cache = {}
        if kb_meta_json.exists():
            try:
                cache = json.loads(kb_meta_json.read_text(encoding="utf-8"))
            except Exception:
                cache = {}

        docs = kb_discover_docs(in_path)
        if not docs:
            print(f"ℹ️ No Marker artefacts found under: {in_path}")
            return {"docs_processed": 0, "chunks_total": 0, "tables_long_rows": 0, "paths": {}}
        print(f"🔎 Found {len(docs)} docs under {in_path}")

        # outlines (optional)
        outline_rows = []
        for doc_name, art in docs.items():
            root = art.get("root", in_path / doc_name)
            candidates = list(root.glob("*_meta.json"))
            nested_same = root / doc_name
            if nested_same.is_dir():
                candidates += list(nested_same.glob("*_meta.json"))
            for meta_path in candidates:
                try:
                    data = json.loads(kb_safe_read(meta_path))
                    toc = data.get("table_of_contents") or data.get("toc") or []
                    for i, item in enumerate(toc):
                        outline_rows.append({
                            "doc_name": doc_name,
                            "source_path": str(meta_path),
                            "order": int(i),
                            "title": item.get("title"),
                            "page_id": item.get("page_id"),
                            "polygon": item.get("polygon"),
                        })
                except Exception:
                    pass
        if outline_rows:
            _pd.DataFrame(outline_rows).to_parquet(kb_outline_parq, engine="pyarrow", index=False)
            print(f"📑 Saved outline → {kb_outline_parq} (rows={len(outline_rows)})")
        else:
            print("ℹ️ No *_meta.json outlines found.")

        rows_meta, chunk_texts = [], []
        tables_long = []
        json_sig_to_page = {}
        changed_any = False

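        # Pre-pass: fingerprint every doc. The artifacts below are rewritten wholesale,
        # so if any doc changed, all docs must be reprocessed; skipping only the
        # unchanged ones would silently drop their chunks from the rebuilt KB.
        doc_keys = {}
        for name, art in docs.items():
            files = art["md"] + art["json"] + art.get("jsonl", [])
            keys = [kb_file_hash_key(p) for p in files]
            doc_keys[name] = hashlib.md5("|".join(keys).encode()).hexdigest()
        changed_any = any(cache.get(n, {}).get("cache_key") != k for n, k in doc_keys.items())

        for name, art in _tqdm(docs.items(), desc="Processing docs"):
            if not changed_any:
                continue  # nothing changed: fall through to the no-change path below
            md_files, json_files = art["md"], art["json"]
            jsonl_files = art.get("jsonl", [])
            doc_key = doc_keys[name]

            # 1) JSON → tables + page-text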
            table_id = 0
            for jp in json_files:
                jtxt = kb_safe_read(jp)
                # tables with page capture
                for tb in kb_extract_tables_from_marker_json_blocks(jtxt):
                    df = tb["df"]; page_no = tb.get("page")
                    try:
                        sig = kb_table_signature(df)
                        if page_no is not None and sig:
                            json_sig_to_page[sig] = int(page_no)
                    except Exception:
                        pass
                    for sent in kb_table_rows_to_sentences(df, name, table_id):
                        if page_no is not None:
                            sent = f"[page {page_no}] " + sent
                        rows_meta.append({
                            "doc": name, "path": str(jp), "modality": "table_row",
                            "chunk": len(chunk_texts), "cache_key": doc_key,
                            "page": int(page_no) if page_no is not None else None,
                        })
                        chunk_texts.append(sent)
                    for ridx, row in df.reset_index(drop=True).iterrows():
                        for col in df.columns:
                            _val = row[col]
                            _val_str = "" if _pd.isna(_val) else str(_val)
                            try:
                                _val_num = _pd.to_numeric(_val_str.replace(",", ""), errors="coerce")
                            except Exception:
                                _val_num = _np.nan
                            tables_long.append({
                                "doc_name": name, "source_path": str(jp), "table_id": table_id,
                                "row_id": int(ridx), "column": str(col),
                                "value_str": _val_str,
                                "value_num": float(_val_num) if _pd.notna(_val_num) else None,
                                "page": int(page_no) if page_no is not None else None,
                            })
                    table_id += 1
                # page narrative
                spans = kb_extract_text_spans_with_pages(jtxt)
                by_page = {}
                for sp in spans:
                    by_page.setdefault(sp.get("page"), []).append(sp["text"])
                for page_no, texts in by_page.items():
                    page_text = kb_strip_md_basic("\n\n".join(texts))
                    for ch in kb_chunk_text(page_text, max_chars, overlap):
                        rows_meta.append({
                            "doc": name, "path": str(jp), "modality": "json",
                            "chunk": len(chunk_texts), "cache_key": doc_key,
                            "page": int(page_no) if page_no is not None else None,
                        })
                        chunk_texts.append(ch)

            # 1b) JSONL (extractor outputs)
            for jlp in jsonl_files:
                records = kb_load_jsonl(jlp)
                if not records:
                    continue
                ctx, data_recs = None, []
                for r in records:
                    if isinstance(r, dict) and "_context" in r:
                        ctx = r.get("_context")
                    elif isinstance(r, dict):
                        data_recs.append(r)
                page_no = None
                if isinstance(ctx, dict):
                    p = ctx.get("page")
                    if isinstance(p, int):
                        page_no = p
                df_jl = None
                if data_recs:
                    try:
                        df_jl = _pd.DataFrame(data_recs)
                        if "_meta" in df_jl.columns:
                            try:
                                df_jl = df_jl.drop(columns=["_meta"])
                            except Exception:
                                pass
                        df_jl = kb_coerce_numbers_df(df_jl)
                    except Exception:
                        df_jl = None
                if df_jl is not None and not df_jl.empty:
                    for sent in kb_table_rows_to_sentences(df_jl, name, table_id):
                        if page_no is not None:
                            sent = f"[page {page_no}] " + sent
                        rows_meta.append({
                            "doc": name, "path": str(jlp), "modality": "jsonl_row",
                            "chunk": len(chunk_texts), "cache_key": doc_key, "page": page_no,
                        })
                        chunk_texts.append(sent)
                    for ridx, row in df_jl.reset_index(drop=True).iterrows():
                        for col in df_jl.columns:
                            _val = row[col]
                            _val_str = "" if _pd.isna(_val) else str(_val)
                            try:
                                _val_num = _pd.to_numeric(_val_str.replace(",", ""), errors="coerce")
                            except Exception:
                                _val_num = _np.nan
                            tables_long.append({
                                "doc_name": name, "source_path": str(jlp), "table_id": table_id,
                                "row_id": int(ridx), "column": str(col),
                                "value_str": _val_str,
                                "value_num": float(_val_num) if _pd.notna(_val_num) else None,
                                "page": page_no,
                            })
                    table_id += 1
                if isinstance(ctx, dict) and isinstance(ctx.get("summary"), str) and ctx["summary"].strip():
                    rows_meta.append({
                        "doc": name, "path": str(jlp), "modality": "jsonl_summary",
                        "chunk": len(chunk_texts), "cache_key": doc_key, "page": page_no,
                    })
                    chunk_texts.append(f"[{name}] {ctx['summary'].strip()}")

            # 2) Markdown → tables + non-table text
            for mp in md_files:
                md = kb_safe_read(mp)
                for tblock in kb_markdown_tables_find(md):
                    df = kb_markdown_table_to_df(tblock)
                    if df is None: 
                        continue
                    md_page = None
                    try:
                        md_sig = kb_table_signature(df)
                        if md_sig and md_sig in json_sig_to_page:
                            md_page = int(json_sig_to_page[md_sig])
                    except Exception:
                        md_page = None
                    for sent in kb_table_rows_to_sentences(df, name, table_id):
                        rows_meta.append({
                            "doc": name, "path": str(mp), "modality": "table_row",
                            "chunk": len(chunk_texts), "cache_key": doc_key, "page": md_page
                        })
                        chunk_texts.append(sent)
                    for ridx, row in df.reset_index(drop=True).iterrows():
                        for col in df.columns:
                            _val = row[col]
                            _val_str = "" if _pd.isna(_val) else str(_val)
                            try:
                                _val_num = _pd.to_numeric(_val_str.replace(",", ""), errors="coerce")
                            except Exception:
                                _val_num = _np.nan
                            tables_long.append({
                                "doc_name": name, "source_path": str(mp), "table_id": table_id,
                                "row_id": int(ridx), "column": str(col),
                                "value_str": _val_str,
                                "value_num": float(_val_num) if _pd.notna(_val_num) else None,
                                "page": md_page,
                            })
                    table_id += 1

                md_no_tables = md
                for tblock in kb_markdown_tables_find(md):
                    md_no_tables = md_no_tables.replace(tblock, "")
                for ch in kb_chunk_text(kb_strip_md_basic(md_no_tables), max_chars, overlap):
                    rows_meta.append({"doc": name, "path": str(mp), "modality": "md",
                                      "chunk": len(chunk_texts), "cache_key": doc_key, "page": None})
                    chunk_texts.append(ch)

            added_for_doc = sum(1 for r in rows_meta if r["cache_key"] == doc_key)
            cache[name] = {"cache_key": doc_key, "chunk_count": added_for_doc, "updated_at": int(time.time())}

        # If nothing changed and KB exists → keep existing artifacts
        if (not changed_any) and ((out_path/"kb_chunks.parquet").exists()):
            print("✅ No changes detected. Keeping existing KB and FAISS index.")
            texts_existing = _np.load(out_path/"kb_texts.npy", allow_pickle=True)
            return {
                "docs_processed": len(docs),
                "chunks_total": int(len(texts_existing)),
                "tables_long_rows": (_pd.read_parquet(out_path/"kb_tables.parquet").shape[0] if (out_path/"kb_tables.parquet").exists() else 0),
                "paths": {
                    "kb_chunks_parquet": str(out_path/"kb_chunks.parquet"),
                    "kb_texts_npy": str(out_path/"kb_texts.npy"),
                    "kb_meta_json": str(out_path/"kb_meta.json"),
                    "kb_tables_parquet": str(out_path/"kb_tables.parquet") if (out_path/"kb_tables.parquet").exists() else None,
                    "kb_index_faiss": str(out_path/"kb_index.faiss") if (out_path/"kb_index.faiss").exists() else None,
                    "kb_index_meta_json": str(out_path/"kb_index_meta.json") if (out_path/"kb_index_meta.json").exists() else None,
                }
            }

        # Persist KB + tables
        total = len(chunk_texts)
        print(f"🧾 Total new/updated text chunks (incl. table rows): {total}")
        _pd.DataFrame(rows_meta).to_parquet(out_path/"kb_chunks.parquet", engine="pyarrow", index=False)
        _np.save(out_path/"kb_texts.npy", _np.array(chunk_texts, dtype=object))
        if tables_long:
            _pd.DataFrame(tables_long).to_parquet(out_path/"kb_tables.parquet", engine="pyarrow", index=False)
            print(f"📑 Saved structured tables → {out_path / 'kb_tables.parquet'} (rows={len(tables_long)})")
        else:
            print("📑 No structured tables detected this run.")
        (out_path/"kb_meta.json").write_text(json.dumps(cache, indent=2), encoding="utf-8")

        if total == 0:
            print("⚠️ No new chunks produced. Skipping embedding/index rebuild.")
            return {"docs_processed": len(docs), "chunks_total": 0, "tables_long_rows": len(tables_long), "paths": {}}

        # Embeddings + FAISS
        print("🧠 Encoding embeddings …")
        embs = kb_encode(chunk_texts, model_name)
        print(f"✅ Embeddings shape: {embs.shape}")
        print("📦 Building FAISS index …")
        idx = kb_build_faiss(embs)
        faiss.write_index(idx, str(out_path/"kb_index.faiss"))
        (out_path/"kb_index_meta.json").write_text(json.dumps({
            "model": model_name, "dim": int(embs.shape[1]), "total_vectors": int(embs.shape[0]),
            "metric": "cosine (via inner product on normalized vectors)",
        }, indent=2), encoding="utf-8")
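        print(f"🎉 KB + index saved to: {out_path}")
        return {
            "docs_processed": len(docs), "chunks_total": int(total), "tables_long_rows": len(tables_long),
            # Report artifact paths here too, matching the shape of the no-change return above.
            "paths": {
                "kb_chunks_parquet": str(out_path/"kb_chunks.parquet"),
                "kb_texts_npy": str(out_path/"kb_texts.npy"),
                "kb_meta_json": str(out_path/"kb_meta.json"),
                "kb_index_faiss": str(out_path/"kb_index.faiss"),
                "kb_index_meta_json": str(out_path/"kb_index_meta.json"),
            },
        }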

    # ---- execute inline build ----
    print("\n🚀 Building KB/index from extracted artifacts (JSON/MD/JSONL)…")
    _summary = build_kb_with_tables()
    print(_summary)
    print("✅ KB build completed.")
except Exception as _e:
    print(f"❌ Inline KB build failed: {_e}")

--- Processing file: 1Q24_CEO_presentation.pdf ---
⏭️  Skipping 1Q24_CEO_presentation.pdf: up-to-date (md5 match). → All\1Q24_CEO_presentation
--- Processing file: 1Q24_CFO_presentation.pdf ---
⏭️  Skipping 1Q24_CFO_presentation.pdf: up-to-date (md5 match). → All\1Q24_CFO_presentation
--- Processing file: 1Q24_trading_update.pdf ---
⏭️  Skipping 1Q24_trading_update.pdf: up-to-date (md5 match). → All\1Q24_trading_update
--- Processing file: 1Q25_CEO_presentation.pdf ---
⏭️  Skipping 1Q25_CEO_presentation.pdf: up-to-date (md5 match). → All\1Q25_CEO_presentation
--- Processing file: 1Q25_CFO_presentation.pdf ---
⏭️  Skipping 1Q25_CFO_presentation.pdf: up-to-date (md5 match). → All\1Q25_CFO_presentation
--- Processing file: 1Q25_trading_update.pdf ---
⏭️  Skipping 1Q25_trading_update.pdf: up-to-date (md5 match). → All\1Q25_trading_update
--- Processing file: 2Q24_CEO_presentation.pdf ---
⏭️  Skipping 2Q24_CEO_presentation.pdf: up-to-date (md5 match).

🚀 Building KB/index from extracted artifacts (JSON/MD/JSONL)…
🔎 Found 24 docs under All
📑 Saved outline → C:\Users\Aaron\Documents\GitHub\PTO_ICT3113_Grp11\data_marker\kb_outline.parquet (rows=3325)

Processing docs: 100%|██████████| 24/24 [00:00<00:00, 1655.86it/s]

✅ No changes detected. Keeping existing KB and FAISS index.
{'docs_processed': 24, 'chunks_total': 13587, 'tables_long_rows': 52949, 'paths': {'kb_chunks_parquet': 'C:\\Users\\Aaron\\Documents\\GitHub\\PTO_ICT3113_Grp11\\data_marker\\kb_chunks.parquet', 'kb_texts_npy': 'C:\\Users\\Aaron\\Documents\\GitHub\\PTO_ICT3113_Grp11\\data_marker\\kb_texts.npy', 'kb_meta_json': 'C:\\Users\\Aaron\\Documents\\GitHub\\PTO_ICT3113_Grp11\\data_marker\\kb_meta.json', 'kb_tables_parquet': 'C:\\Users\\Aaron\\Documents\\GitHub\\PTO_ICT3113_Grp11\\data_marker\\kb_tables.parquet', 'kb_index_faiss': 'C:\\Users\\Aaron\\Documents\\GitHub\\PTO_ICT3113_Grp11\\data_marker\\kb_index.faiss', 'kb_index_meta_json': 'C:\\Users\\Aaron\\Documents\\GitHub\\PTO_ICT3113_Grp11\\data_marker\\kb_index_meta.json'}}
✅ KB build completed.
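
The `tables_long_rows` count above refers to the long (one row per cell) layout that the build writes to `kb_tables.parquet`. As a rough illustration of that layout, here is a minimal standard-library sketch. The function name `table_to_long_rows` and the sample values are hypothetical and not part of the notebook; the real build works on pandas DataFrames and also records a `source_path`.

```python
def table_to_long_rows(doc_name, table_id, header, rows, page=None):
    """Flatten a wide table into one record per cell (hypothetical helper)."""
    long_rows = []
    for row_id, row in enumerate(rows):
        for column, cell in zip(header, row):
            value_str = "" if cell is None else str(cell)
            try:
                # Strip thousands separators before parsing, as the build does.
                value_num = float(value_str.replace(",", ""))
            except ValueError:
                value_num = None  # labels, blanks, "(35)" style negatives
            long_rows.append({
                "doc_name": doc_name, "table_id": table_id, "row_id": row_id,
                "column": str(column), "value_str": value_str,
                "value_num": value_num, "page": page,
            })
    return long_rows

# Illustrative placeholder values only, not figures from any filing.
header = ["Item", "2024", "2023"]
rows = [["Example metric", "1,200", "1,000"], ["Other metric", "(35)", None]]
long_rows = table_to_long_rows("example-doc", 0, header, rows, page=7)
for rec in long_rows[:3]:
    print(rec["row_id"], rec["column"], repr(rec["value_str"]), rec["value_num"])
```

Cells such as `(35)` stay unparsed in this sketch, which is why the sanity-check cell below re-parses `value_str` whenever `value_num` is missing.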





In [4]:
# --- Sanity check FAISS retrieval vs. table storage ---
from g2x import KBEnv
import pandas as pd, numpy as np, re, math

kb = KBEnv(base="./data_marker")

def show_search(q, k=12):
    print(f"\n🔎 FAISS search → {q}")
    df = kb.search(q, k=k)
    if df is None or df.empty:
        print("  (no hits)")
        return df
    cols = ["rank","score","doc","modality","chunk","path"]
    print(df[cols].to_string(index=False))
    for _, row in df.head(2).iterrows():
        print("\n--- snippet ---")
        print(str(row["text"])[:800])
    return df

# 1) Similarity probes
queries = [
    "Operating expenses 2024 2023 YoY",
    "Expenses 2024 2023 table",
    "Operating expenses and income YoY 2024 2023",
    "Total expenses 2024 2023 DBS annual report",
    "Net interest margin quarter Q1 Q2 Q3 Q4",
]
_ = [show_search(q, k=12) for q in queries]

# 2) Direct read from kb_tables.parquet (bypass FAISS)
tbl = kb.tables_df.copy()
print(f"\n📦 kb_tables rows: {len(tbl)} | cols: {list(tbl.columns)}")

# ---------- helpers ----------
def _norm(s: str) -> str:
    s = "" if s is None else str(s)
    s = s.lower().replace("&"," and ")
    s = re.sub(r"[^a-z0-9 ]+", " ", s)
    s = re.sub(r"\s+", " ", s).strip()
    return s

def is_year(s) -> bool:
    return bool(re.fullmatch(r"\d{4}", str(s or "").strip()))

_qpat = re.compile(r"(?i)(?:\b([1-4])\s*q\s*((?:20)?\d{2})\b|\bq\s*([1-4])\s*(?:fy)?\s*((?:20)?\d{2})\b|\b([1-4])q((?:20)?\d{2})\b)")
def parse_quarter_token(s: str):
    if s is None: return None
    s = str(s)
    m = _qpat.search(s)
    if not m: return None
    if m.group(1):   q, y = int(m.group(1)), int(m.group(2))
    elif m.group(3): q, y = int(m.group(3)), int(m.group(4))
    else:            q, y = int(m.group(5)), int(m.group(6))
    if y < 100: y += 2000
    return f"{q}Q{y}"

def to_num(x):
    """Parse a table cell into a float; parentheses mean negative, '%' means divide by 100."""
    if x is None: return np.nan
    s = str(x).strip()
    # Dash placeholders (em dash, en dash, hyphen) mark empty cells.
    if not s or s in {"\u2014", "\u2013", "-"}: return np.nan
    neg = s.startswith("(") and s.endswith(")")
    s = s.strip("()").replace(",", "")
    s = re.sub(r"[^0-9eE\.\-%]", "", s)
    is_pct = s.endswith("%")
    if is_pct:
        s = s[:-1]
    try:
        v = float(s)
    except ValueError:
        return np.nan
    if is_pct:
        v /= 100.0
    return -v if neg else v

# normalize + fix numbers when value_num is NaN
tbl["val_norm"] = tbl["value_str"].astype(str).map(_norm)
tbl["col_norm"] = tbl["column"].astype(str).map(_norm)
tbl["column_str"] = tbl["column"].astype(str)
tbl["value_num_fix"] = tbl["value_num"]
mask_nan = tbl["value_num_fix"].isna() & tbl["value_str"].notna()
tbl.loc[mask_nan, "value_num_fix"] = tbl.loc[mask_nan, "value_str"].map(to_num)

# ---------- A) NIM by quarter ----------
nim_terms = ["net interest margin", "nim", "net interest margin group", "nim group"]
nim_mask = pd.Series(False, index=tbl.index)
for t in nim_terms:
    tnorm = _norm(t)
    nim_mask |= tbl["val_norm"].str.contains(rf"\b{re.escape(tnorm)}\b", regex=True) \
             |  tbl["col_norm"].str.contains(rf"\b{re.escape(tnorm)}\b", regex=True)

nim_rows = []
if nim_mask.any():
    for doc, tid in (
        tbl[nim_mask][["doc_name","table_id"]]
        .drop_duplicates()
        .itertuples(index=False, name=None)
    ):
        sub = tbl[(tbl["doc_name"]==doc) & (tbl["table_id"]==tid)]
        for rid in sorted(sub["row_id"].unique()):
            r = sub[sub["row_id"]==rid]
            if not (r["val_norm"].str.contains(r"\bnim\b|\bnet interest margin\b", regex=True).any() or
                    r["col_norm"].str.contains(r"\bnim\b|\bnet interest margin\b", regex=True).any()):
                continue
            series_q = {}
            for _, cell in r.iterrows():
                qlab = parse_quarter_token(cell["column_str"]) or parse_quarter_token(cell["value_str"])
                if not qlab: 
                    continue
                v = cell["value_num_fix"]
                if pd.isna(v): 
                    continue
                val = float(v)
                if val < 0.5:  # fractions → %
                    val = round(val*100.0, 2)
                series_q[qlab] = val
            if series_q:
                label_guess = r["value_str"].dropna().astype(str).head(1)
                nim_rows.append({
                    "doc":doc, "table_id":tid, "row_id":rid,
                    "label": (label_guess.iloc[0] if not label_guess.empty else "Net interest margin"),
                    "series_q": series_q
                })

def _qkey(k):
    m = re.match(r"([1-4])Q(20\d{2})$", k)
    return (int(m.group(2)), int(m.group(1))) if m else (0,0)

def _latest_q(r):
    # Compare quarters as (year, quarter); a plain string sort would rank "4Q2023" after "1Q2024".
    return max((_qkey(k) for k in r["series_q"]), default=(0, 0))

nim_rows.sort(key=lambda r: (-len(r["series_q"]), -_latest_q(r)[0], -_latest_q(r)[1]))

print("\n=== NIM (quarters): top 2 candidates ===")
if nim_rows:
    for r in nim_rows[:2]:
        last5 = sorted(r["series_q"].keys(), key=_qkey)[-5:]
        print(f"doc={r['doc']} table={r['table_id']} row={r['row_id']} | label={r['label']}")
        print("  last5:", ", ".join(f"{k}: {r['series_q'][k]}" for k in last5))
else:
    print("⚠️ No quarter NIM extracted. (Likely chart-only or prose-only.)")
# ---------- B) Operating Expenses by year ----------
exp_terms = ["operating expenses", "total expenses", "expenses", "opex"]
exp_mask = pd.Series(False, index=tbl.index)
for t in exp_terms:
    tnorm = _norm(t)
    exp_mask |= tbl["val_norm"].str.contains(rf"\b{re.escape(tnorm)}\b", regex=True) \
             |  tbl["col_norm"].str.contains(rf"\b{re.escape(tnorm)}\b", regex=True)

exp_rows = []
if exp_mask.any():
    for (doc, tid, rid), sub in tbl.groupby(["doc_name","table_id","row_id"]):
        if not exp_mask.loc[sub.index].any():
            continue
        series = {}
        for _, cell in sub.iterrows():
            col = str(cell["column"])
            if is_year(col) and pd.notna(cell["value_num_fix"]):
                series[int(col)] = float(cell["value_num_fix"])
        if len(series) >= 2:
            label_guess = sub[~sub["column"].astype(str).map(is_year)]["value_str"].dropna().astype(str).head(1)
            label = label_guess.iloc[0] if not label_guess.empty else "Expenses"
            exp_rows.append({"doc":doc,"table_id":tid,"row_id":rid,"label":label,"series":dict(sorted(series.items()))})

exp_rows.sort(key=lambda r: (-(len(r["series"])), -max(r["series"].keys()) if r["series"] else 0))

print("\n=== Operating Expenses (years): top 2 candidates ===")
if exp_rows:
    for r in exp_rows[:2]:
        ys = sorted(r["series"].keys())[-3:]
        print(f"doc={r['doc']} table={r['table_id']} row={r['row_id']} | label={r['label']}")
        print("  last years:", ", ".join(f"{y}: {r['series'][y]}" for y in ys))
else:
    print("⚠️ No expense rows with year columns extracted.")

# ---------- C) Operating/Total Income by year ----------
inc_terms = ["operating income", "total operating income", "total income", "income"]
inc_mask = pd.Series(False, index=tbl.index)
for t in inc_terms:
    tnorm = _norm(t)
    inc_mask |= tbl["val_norm"].str.contains(rf"\b{re.escape(tnorm)}\b", regex=True) \
             |  tbl["col_norm"].str.contains(rf"\b{re.escape(tnorm)}\b", regex=True)

inc_rows = []
if inc_mask.any():
    for (doc, tid, rid), sub in tbl.groupby(["doc_name","table_id","row_id"]):
        if not inc_mask.loc[sub.index].any():
            continue
        series = {}
        for _, cell in sub.iterrows():
            col = str(cell["column"])
            if is_year(col) and pd.notna(cell["value_num_fix"]):
                series[int(col)] = float(cell["value_num_fix"])
        if len(series) >= 2:
            label_guess = sub[~sub["column"].astype(str).map(is_year)]["value_str"].dropna().astype(str).head(1)
            label = label_guess.iloc[0] if not label_guess.empty else "Income"
            inc_rows.append({"doc":doc,"table_id":tid,"row_id":rid,"label":label,"series":dict(sorted(series.items()))})

inc_rows.sort(key=lambda r: (-(len(r["series"])), -max(r["series"].keys()) if r["series"] else 0))

print("\n=== Operating/Total Income (years): top 2 candidates ===")
if inc_rows:
    for r in inc_rows[:2]:
        ys = sorted(r["series"].keys())[-3:]
        print(f"doc={r['doc']} table={r['table_id']} row={r['row_id']} | label={r['label']}")
        print("  last years:", ", ".join(f"{y}: {r['series'][y]}" for y in ys))
else:
    print("⚠️ No income rows with year columns extracted.")

# ---------- D) Efficiency Ratio preview (if both present) ----------
if exp_rows and inc_rows:
    ex, inc = exp_rows[0], inc_rows[0]
    years = sorted(set(ex["series"]).intersection(inc["series"]))[-3:]
    print("\n=== Efficiency Ratio preview (Opex ÷ Income, %): aligned last 3 years ===")
    if years:
        print("Year | Opex | Income | Ratio%")
        print("-----|------|--------|-------")
        for y in years:
            ov, iv = ex["series"][y], inc["series"][y]
            ratio = (ov/iv*100.0) if iv else math.nan
            rs = "\u2014" if not iv else f"{ratio:.2f}%"
            print(f"{y} | {ov} | {iv} | {rs}")
    else:
        print("⚠️ No overlapping fiscal years between the chosen Opex and Income rows.")

[BM25] ✓ Indexed 13587 documents
[Reranker] ✓ Loaded cross-encoder/ms-marco-MiniLM-L-6-v2

🔎 FAISS search → Operating expenses 2024 2023 YoY
[Search] RRF fusion: 82 candidates
[Rerank] Reranking top-24 candidates...
[Rerank] ✓ Reranked to top-12
 rank     score                    doc  modality  chunk                                                   path
    1  4.992685 dbs-annual-report-2024 table_row  10563 All\dbs-annual-report-2024\dbs-annual-report-2024.json
    2  4.874355 dbs-annual-report-2023 table_row   7504 All\dbs-annual-report-2023\dbs-annual-report-2023.json
    3  4.852805 dbs-annual-report-2024 table_row  12719   All\dbs-annual-report-2024\dbs-annual-report-2024.md
    4  4.765504 dbs-annual-report-2023 table_row   9552   All\dbs-annual-report-2023\dbs-annual-report-2023.m
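
The log above reports an "RRF fusion" step before reranking. For readers unfamiliar with Reciprocal Rank Fusion, the sketch below shows the general technique: each document scores the sum of `1 / (k + rank)` over every ranked list it appears in. This is only an illustration under assumptions. The function name `rrf_fuse`, the constant `k=60` (a commonly used default), and the sample chunk ids are not taken from `g2x.py`, whose fusion code is not shown in this section.

```python
from collections import defaultdict

def rrf_fuse(rankings, k=60, top_n=None):
    """Fuse ranked lists of ids with Reciprocal Rank Fusion (illustrative sketch)."""
    scores = defaultdict(float)
    for ranking in rankings:
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] += 1.0 / (k + rank)
    # Highest fused score first; ties broken by id for determinism.
    fused = sorted(scores.items(), key=lambda kv: (-kv[1], kv[0]))
    return fused[:top_n] if top_n is not None else fused

# Hypothetical chunk ids from a dense (vector) and a lexical (BM25) search.
dense_hits = [10563, 7504, 12719]
bm25_hits = [7504, 9552, 10563]
fused = rrf_fuse([dense_hits, bm25_hits])
print([doc_id for doc_id, _ in fused])
```

Because RRF uses only ranks, it can combine BM25 scores and cosine similarities without putting them on a common scale.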

## 4. Baseline Pipeline

**Baseline (starting point)**
*   Naive chunking.
*   Single-pass vector search.
*   One LLM call, no caching.
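
To make the three baseline bullets concrete, here is a toy, standard-library sketch of that flow: fixed-size chunking, one similarity pass, and a single model call with nothing cached. It is a stand-in under stated assumptions. Bag-of-words cosine replaces the real embedding search, `llm` is any callable supplied by the caller, and all function names are hypothetical rather than the notebook's own.

```python
import math
import re
from collections import Counter

def naive_chunks(text, size=200):
    """Cut text into fixed-size pieces with no overlap or structure awareness."""
    return [text[i:i + size] for i in range(0, len(text), size)]

def bow(text):
    """Bag-of-words vector; a stand-in for a sentence embedding."""
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))

def cosine(a, b):
    dot = sum(a[t] * b[t] for t in a.keys() & b.keys())
    na = math.sqrt(sum(v * v for v in a.values()))
    nb = math.sqrt(sum(v * v for v in b.values()))
    return dot / (na * nb) if na and nb else 0.0

def retrieve(query, chunks, k=2):
    """Single-pass search: score every chunk once and keep the top k matches."""
    q = bow(query)
    scored = sorted(((cosine(q, bow(c)), i) for i, c in enumerate(chunks)),
                    key=lambda t: (-t[0], t[1]))
    return [chunks[i] for s, i in scored[:k] if s > 0]

def baseline_answer(query, chunks, llm):
    """One LLM call per question; nothing is cached between calls."""
    context = "\n---\n".join(retrieve(query, chunks))
    return llm(f"Context:\n{context}\n\nQuestion: {query}")

chunks = ["Operating expenses rose in 2024.",
          "Net interest margin was stable.",
          "The board met four times."]
echo_llm = lambda prompt: prompt  # placeholder model that returns its prompt
out = baseline_answer("operating expenses 2024", chunks, echo_llm)
print(out)
```

Each later optimization (hybrid retrieval, reranking, caching) can then be timed against this simple flow.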

### Gemini Version 2

In [6]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
g2x.py: Agentic RAG with tools on top of data_marker/ (FAISS + Marker outputs)
       - BM25, Reciprocal Rank Fusion, and Cross-Encoder Reranking

Artifacts required in ./data_marker:
  - kb_index.faiss
  - kb_index_meta.json
  - kb_texts.npy
  - kb_chunks.parquet
  - kb_tables.parquet        (recommended for table tools)
  - kb_outline.parquet       (optional, for section hints)

Tools exposed:
  1) CalculatorTool           -> safe arithmetic, deltas, YoY
  2) TableExtractionTool      -> pull metric rows; extract {year -> value}
  3) MultiDocCompareTool      -> compare a metric across multiple docs
Also:
  - Vector search (FAISS) for grounding

Agent runtime: Plan -> Act -> Observe -> (optional) Refine -> Final
"""

from pathlib import Path
from dataclasses import dataclass
from rank_bm25 import BM25Okapi
from sentence_transformers import CrossEncoder, SentenceTransformer
from typing import List, Dict, Any, Optional, Tuple
from openai import OpenAI
from concurrent.futures import ThreadPoolExecutor

import re, json, math, ast
import numpy as np
import pandas as pd
import asyncio
import faiss
import os

# ----------------------------- LLM (single-call baseline) -----------------------------

def _make_llm_client():
    """Minimal provider selection for LLM"""
    groq_key = os.environ.get("GROQ_API_KEY")
    if groq_key:
        client = OpenAI(api_key=groq_key, base_url="https://api.groq.com/openai/v1")
        model = os.getenv("GROQ_MODEL", "openai/gpt-oss-20b")
        return ("groq", client, model)
    
    gem_key = os.environ.get("GEMINI_API_KEY")
    if gem_key:
        return ("gemini", None, os.getenv("GEMINI_MODEL_NAME", "models/gemini-2.5-flash"))
    
    raise RuntimeError("No LLM credentials found. Set GROQ_API_KEY or GEMINI_API_KEY.")

def _llm_provider_info() -> str:
    try:
        prov, _, model = _make_llm_client()
        return f"{prov}:{model}"
    except Exception as e:
        return f"unconfigured ({e})"

def _llm_single_call(prompt: str, system: str = "You are a precise finance analyst.") -> str:
    prov, client, model = _make_llm_client()
    print(f"[LLM] provider={prov} model={model}")
    if prov == "groq":
        try:
            chat = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": system},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.1,
            )
            return chat.choices[0].message.content.strip()
        except Exception as e:
            return f"LLM error: {e}"
    
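    try:
        from google import generativeai as genai
        genai.configure(api_key=os.environ["GEMINI_API_KEY"])
        # Pass the system prompt and temperature so both providers behave alike.
        model_obj = genai.GenerativeModel(model, system_instruction=system)
        out = model_obj.generate_content(prompt, generation_config={"temperature": 0.1})
        return (getattr(out, "text", "") or "").strip() or "LLM returned empty response."
    except Exception as e:
        return f"LLM error (Gemini): {e}"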


def _page_or_none(x):
    """Return x as an int page number, or None for missing, NaN, or unparseable values."""
    try:
        if x is None or pd.isna(x):
            return None
        return int(x)
    except Exception:
        return None


# ----------------------------- KB loader with BM25 + Reranker -----------------------------

class KBEnv:
    def __init__(self, base="./data_marker", enable_bm25=True, enable_reranker=True):
        self.base = Path(base)
        self.faiss_path = self.base / "kb_index.faiss"
        self.meta_path = self.base / "kb_index_meta.json"
        self.texts_path = self.base / "kb_texts.npy"
        self.chunks_path = self.base / "kb_chunks.parquet"
        self.tables_path = self.base / "kb_tables.parquet"
        self.outline_path = self.base / "kb_outline.parquet"

        if not self.faiss_path.exists():
            raise FileNotFoundError(self.faiss_path)
        if not self.meta_path.exists():
            raise FileNotFoundError(self.meta_path)
        if not self.texts_path.exists():
            raise FileNotFoundError(self.texts_path)
        if not self.chunks_path.exists():
            raise FileNotFoundError(self.chunks_path)

        self.texts: List[str] = np.load(self.texts_path, allow_pickle=True).tolist()
        self.meta_df: pd.DataFrame = pd.read_parquet(self.chunks_path)
        
        if 'page' in self.meta_df.columns:
            self.meta_df['page'] = pd.to_numeric(self.meta_df['page'], errors='coerce').astype('Int64')
            
        if len(self.texts) != len(self.meta_df):
            raise ValueError(f"texts ({len(self.texts)}) and meta ({len(self.meta_df)}) mismatch")

        self.tables_df: Optional[pd.DataFrame] = (
            pd.read_parquet(self.tables_path) if self.tables_path.exists() else None
        )
        self.outline_df: Optional[pd.DataFrame] = (
            pd.read_parquet(self.outline_path) if self.outline_path.exists() else None
        )

        # FAISS index
        self.index = faiss.read_index(str(self.faiss_path))
        idx_meta = json.loads(self.meta_path.read_text(encoding="utf-8"))
        self.model_name = idx_meta.get("model", "sentence-transformers/all-MiniLM-L6-v2")
        self.embed_dim = int(idx_meta.get("dim", 384))
        self.model = SentenceTransformer(self.model_name)

        # ========== NEW: BM25 Index ==========
        self.bm25 = None
        if enable_bm25:
            try:
                # Whitespace tokenization; must match the query tokenization in search()
                tokenized_corpus = [text.lower().split() for text in self.texts]
                self.bm25 = BM25Okapi(tokenized_corpus)
                print(f"[BM25] ✓ Indexed {len(self.texts)} documents")
            except (NameError, TypeError):
                # Assumes BM25Okapi is undefined (or None) when rank_bm25 failed to import
                print("[BM25] ✗ rank_bm25 not installed, skipping BM25")

        # ========== NEW: Reranker ==========
        self.reranker = None
        if enable_reranker:
            try:
                self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
                print("[Reranker] ✓ Loaded cross-encoder/ms-marco-MiniLM-L-6-v2")
            except Exception as e:
                # Covers a missing CrossEncoder import as well as model download failures
                print(f"[Reranker] ✗ CrossEncoder unavailable: {e}")

    def _embed(self, texts: List[str]) -> np.ndarray:
        v = self.model.encode(texts, normalize_embeddings=True)
        return np.asarray(v, dtype="float32")

    # ========== NEW: Hybrid Search with BM25 + Vector + RRF ==========
    def search(
        self, 
        query: str, 
        k: int = 12,
        alpha: float = 0.6,  # Vector-score weight; only used by the weighted fallback when BM25 is disabled
        rerank_top_k: Optional[int] = None  # Number of fused candidates to rerank (default: 2*k)
    ) -> pd.DataFrame:
        """
        Hybrid search: vector + BM25, fused with RRF, then optionally reranked.

        Pipeline:
        1. Vector search (FAISS) → scores
        2. BM25 search → scores normalized to [0, 1]
        3. Fusion: reciprocal rank fusion (RRF) when BM25 is enabled;
           otherwise alpha-weighted vector scores only
        4. Rerank: cross-encoder over the top `rerank_top_k` fused candidates
        5. Return top-k as a DataFrame
        """
        if rerank_top_k is None:
            rerank_top_k = k * 2  # Get 2x candidates for reranking

        # ========== Step 1: Vector Search ==========
        qv = self._embed([query])
        vec_scores, vec_idxs = self.index.search(qv, min(rerank_top_k * 2, len(self.texts)))
        vec_idxs, vec_scores = vec_idxs[0], vec_scores[0]
        
        # Filter valid indices
        vec_results = {int(i): float(s) for i, s in zip(vec_idxs, vec_scores) if i >= 0 and i < len(self.texts)}

        # ========== Step 2: BM25 Search ==========
        bm25_results = {}
        if self.bm25 is not None:
            query_tokens = query.lower().split()
            bm25_scores = self.bm25.get_scores(query_tokens)
            
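            # Normalize BM25 scores to [0, 1] (vectorized max instead of Python's built-in max)
            bm25_scores = np.asarray(bm25_scores, dtype="float64")
            max_bm25 = float(bm25_scores.max()) if bm25_scores.size > 0 else 1.0
            if max_bm25 > 0:
                bm25_scores = bm25_scores / max_bm25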
            
            # Get top candidates
            top_bm25_idx = np.argsort(bm25_scores)[-rerank_top_k * 2:][::-1]
            bm25_results = {int(i): float(bm25_scores[i]) for i in top_bm25_idx if bm25_scores[i] > 0}

        # ========== Step 3: Fusion (RRF or Weighted Score) ==========
        all_indices = set(vec_results.keys()) | set(bm25_results.keys())
        
        if self.bm25 is not None:
            # Reciprocal Rank Fusion
            vec_ranks = {idx: rank for rank, idx in enumerate(sorted(vec_results, key=vec_results.get, reverse=True), 1)}
            bm25_ranks = {idx: rank for rank, idx in enumerate(sorted(bm25_results, key=bm25_results.get, reverse=True), 1)}
            
            k_rrf = 60  # RRF constant
            fused_scores = {}
            for idx in all_indices:
                vec_rank = vec_ranks.get(idx, len(self.texts))
                bm25_rank = bm25_ranks.get(idx, len(self.texts))
                fused_scores[idx] = (1 / (k_rrf + vec_rank)) + (1 / (k_rrf + bm25_rank))
            
            print(f"[Search] RRF fusion: {len(all_indices)} candidates")
        else:
            # Vector-only fallback when BM25 is disabled: bm25_results is empty here,
            # so each score reduces to alpha * vec_score (same ordering as FAISS for alpha > 0).
            fused_scores = {}
            for idx in all_indices:
                vec_score = vec_results.get(idx, 0.0)
                bm25_score = bm25_results.get(idx, 0.0)
                fused_scores[idx] = alpha * vec_score + (1 - alpha) * bm25_score
            
            print(f"[Search] Weighted fusion (α={alpha}): {len(all_indices)} candidates")

        # Sort by fused score
        sorted_indices = sorted(fused_scores.keys(), key=fused_scores.get, reverse=True)[:rerank_top_k]

        # ========== Step 4: Reranking (Optional) ==========
        if self.reranker is not None and len(sorted_indices) > k:
            print(f"[Rerank] Reranking top-{len(sorted_indices)} candidates...")
            
            # Prepare query-document pairs
            pairs = [[query, self.texts[idx]] for idx in sorted_indices]
            
            # Get rerank scores
            rerank_scores = self.reranker.predict(pairs)
            
            # Update fused scores with rerank scores
            for idx, score in zip(sorted_indices, rerank_scores):
                fused_scores[idx] = float(score)
            
            # Re-sort by rerank scores
            sorted_indices = sorted(sorted_indices, key=fused_scores.get, reverse=True)
            
            print(f"[Rerank] ✓ Reranked to top-{k}")

        # ========== Step 5: Build Results DataFrame ==========
        final_indices = sorted_indices[:k]
        rows = []
        for rank, idx in enumerate(final_indices, start=1):
            md = self.meta_df.iloc[idx]
            item = {
                "rank": rank,
                "score": fused_scores[idx],
                "text": self.texts[idx],
                "doc": md.get("doc"),
                "path": md.get("path"),
                "modality": md.get("modality"),
                "chunk": int(md.get("chunk", 0)),
                "page": _page_or_none(md.get("page")),
            }
            
            # Section hint
            if self.outline_df is not None:
                toc = self.outline_df[self.outline_df["doc_name"] == item["doc"]]
                if not toc.empty:
                    item["section_hint"] = toc.iloc[0]["title"]
            
            rows.append(item)
        
        return pd.DataFrame(rows)
    
def baseline_answer_one_call(
    kb: KBEnv,
    query: str,
    k_ctx: int = 8,
    table_rows: Optional[List[Dict[str, Any]]] = None
) -> dict:
    """
    Baseline (Stage 4) requirements:
      - Naive chunking (we use existing kb_texts)
      - Single-pass retrieval via kb.search (FAISS only when the KB is built with
        enable_bm25=False and enable_reranker=False; hybrid otherwise)
      - One LLM call, no caching
    """
    # 1) Retrieve top-k chunks
    ctx_df = kb.search(query, k=k_ctx)
    if ctx_df is None or ctx_df.empty:
        answer = "I couldn't find any relevant context in the KB for this query."
        print(answer)
        return {"answer": answer, "contexts": []}

    # 2) Build context and simple citations
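    ctx_lines = []
    for i, (_, row) in enumerate(ctx_df.iterrows(), start=1):
        # Collapse both real newlines and literal "\n" sequences so each excerpt stays on one line
        text = str(row["text"]).replace("\\n", " ").replace("\n", " ").strip()
        if len(text) > 800:
            text = text[:800] + "..."
        # Number the excerpts so the prompt's "according to excerpt N" citations resolve
        ctx_lines.append(f"[{i}] {text}")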

    # Build citations: prefer table-row provenance (with pages) when provided
    cits = []

    if table_rows:
        for r in table_rows[:5]:
            doc = str(r.get("doc") or "")
            # Normalize NaN/NA pages to None so int conversion cannot fail
            page = _page_or_none(r.get("page"))
            if page is not None:
                cits.append(f"{doc}, page {page}")
            else:
                cits.append(f"{doc}, table {r.get('table_id')} row {r.get('row_id')} (no page)")
    else:
        for _, row in ctx_df.iterrows():
            doc = str(row.get("doc") or "")
            mod = str(row.get("modality") or "")
            # A missing page comes back from the DataFrame as NaN, not None
            page = _page_or_none(row.get("page"))
            if page is not None:
                cits.append(f"{doc}, page {page}")
            else:
                ch = int(row.get("chunk") or 0)
                if mod in ("md", "table_row"):
                    cits.append(f"{doc}, chunk {ch} (no page; {mod})")
                else:
                    cits.append(f"{doc}, chunk {ch} (no page)")

    # Optional: include structured table rows so the LLM doesn't deny available data
    table_lines = []
    if table_rows:
        table_lines.append("STRUCTURED TABLE ROWS (authoritative):")
        for r in table_rows[:6]:
            ser_q = r.get("series_q") or {}
            ser_y = r.get("series") or {}
            if ser_q:
                def _qkey(k: str):
                    m = re.match(r"([1-4])Q(20\d{2})$", k)
                    return (int(m.group(2)), int(m.group(1))) if m else (0, 0)
                qkeys = sorted(ser_q.keys(), key=_qkey)[-5:]
                table_lines.append(f"- {r.get('doc')} | {r.get('label')} | " + ", ".join(f"{k}: {ser_q[k]}" for k in qkeys))
            elif ser_y:
                ys = sorted(ser_y.keys())[-3:]
                table_lines.append(f"- {r.get('doc')} | {r.get('label')} | " + ", ".join(f"{y}: {ser_y[y]}" for y in ys))

    # 3) Compose strict prompt
    if table_lines:
        # When we have structured rows, exclude noisy text snippets to avoid conflicting numbers.
        prompt = f"""USER QUESTION: 
            {query}
            {chr(10).join(table_lines)} 
            INSTRUCTIONS:
            1. **Data Source Priority**: Use ONLY the numbers from STRUCTURED TABLE ROWS above. These are authoritative financial data extracted from official reports.

            2. **Metric Substitution**: If the exact metric requested isn't available but a closely related metric exists (e.g., "Total Income" instead of "Operating Income"), use the available metric and clearly state the substitution in your answer.

            3. **Calculations**: 
            - Show your work for any calculations (e.g., ratios, year-over-year growth)
            - Use the format: Operating Efficiency Ratio = Opex ÷ Operating Income = X ÷ Y = Z%
            - Calculate year-over-year changes as: ((New - Old) / Old) × 100%

            4. **Missing Data**: If requested periods or metrics are not present in the structured rows:
            - Explicitly state which periods/metrics are missing
            - Provide what IS available
            - Do NOT refuse to answer if partial data exists

            5. **Output Format**:
            - Start with a direct 1-2 sentence answer
            - Present numerical results in a clear Markdown table with columns: Period/Year | Metric | Value
            - Add brief notes if clarifications are needed

            6. **Accuracy**: Do NOT invent, extrapolate, or estimate numbers. Only use values explicitly shown in the structured rows.

            Example table format:
            | Year | Operating Expenses | Total Income | Efficiency Ratio |
            |------|-------------------|--------------|------------------|
            | 2022 |        $X         |      $Y      |        Z%        |
            | 2023 |        $X         |      $Y      |        Z%        |"""
    else:
        prompt = f"""USER QUESTION:
        {query}

        CONTEXT (verbatim excerpts from financial reports):
        {chr(10).join(ctx_lines)}

        INSTRUCTIONS:
        1. **Data Source**: Extract information ONLY from the CONTEXT above. These are direct quotes from official reports.

        2. **Explicit Data Gaps**: If the exact values for requested periods are not present in the context:
        - State which specific periods/metrics are missing
        - Provide what IS available from the context
        - Do NOT make up or estimate missing values

        3. **Calculations**: If calculations are requested:
        - Show your working step-by-step
        - Only calculate if all required values are present in the context
        - Use the format: Ratio = A ÷ B = X ÷ Y = Z%

        4. **Output Format**:
        - Start with a direct answer summarizing what you found
        - Present data in a clear Markdown table when applicable
        - Add a "Missing data" section if any requested information is unavailable

        5. **Citations**: Reference specific excerpts when stating values (e.g., "according to excerpt 2...")

        6. **Accuracy**: Precision is critical. Only use numbers explicitly stated in the context.

        Example output structure:
        **Answer**
        [Direct 1-2 sentence response]

        | Period | Metric | Value |
        |--------|--------|-------|
        |Q4 2024 | NIM    | 2.05% |

        **Missing data**
        - Q1-Q3 2024: No quarterly data available in context"""

    # 4) One LLM call
    print(f"[LLM] single-call baseline using {_llm_provider_info()}")
    answer = _llm_single_call(prompt)

    # 5) Print nicely in notebooks
    # print("""\nBASELINE (Single LLM Call)\n--------------------------------""")
    # print(answer)
    # print("\nCitations:")
    # for c in cits[:5]:
    #     print(f"- {c}")

    return {"answer": answer, "contexts": ctx_df.head(5)}
    

# ----------------------------- Tool: Calculator -----------------------------

class CalculatorTool:
    """
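    Safe arithmetic eval (supports +, -, *, /, //, %, **, parentheses, round(), abs())
    and helpers for deltas/YoY.
    """

    # ast.Num is omitted: it is deprecated since Python 3.8 and absent from newer
    # releases; numeric literals parse as ast.Constant.
    ALLOWED = {
        ast.Expression, ast.BinOp, ast.UnaryOp, ast.Load,
        ast.Add, ast.Sub, ast.Mult, ast.Div, ast.Pow, ast.USub, ast.UAdd,
        ast.Mod, ast.FloorDiv, ast.Constant, ast.Call, ast.Name
    }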
    SAFE_FUNCS = {"round": round, "abs": abs}

    @classmethod
    def safe_eval(cls, expr: str) -> float:
        node = ast.parse(expr, mode="eval")
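        for n in ast.walk(node):
            if type(n) not in cls.ALLOWED:
                raise ValueError(f"Disallowed expression: {type(n).__name__}")
            # Only numeric literals: rejects strings, bytes, None, and booleans
            if isinstance(n, ast.Constant) and (isinstance(n.value, bool) or not isinstance(n.value, (int, float))):
                raise ValueError("Only numeric constants are allowed")
            if isinstance(n, ast.Call) and not (isinstance(n.func, ast.Name) and n.func.id in cls.SAFE_FUNCS):
                raise ValueError("Only round(...) and abs(...) calls are allowed")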
        code = compile(node, "<expr>", "eval")
        return float(eval(code, {"__builtins__": {}}, cls.SAFE_FUNCS))

    @staticmethod
    def delta(a: float, b: float) -> float:
        return float(a) - float(b)

    @staticmethod
    def yoy(a: float, b: float) -> Optional[float]:
        b = float(b)
        if b == 0: return None
        return (float(a) - b) / b * 100.0


# ----------------------------- Tool: Table Extraction -----------------------------

class TableExtractionTool:
    """
    Look up a metric row in kb_tables.parquet and extract {year -> value_num}.
    Heuristic: find any row where any cell (value_str) contains the metric term,
    then collect all cells in that row whose column is a 4-digit year.
    """

    # --- normalization helpers & synonyms (for robust matching) ---
    @staticmethod
    def _norm(s: str) -> str:
        """Lowercase, replace '&' with 'and', strip punctuation, collapse spaces."""
        if s is None:
            return ""
        s = str(s).lower()
        s = s.replace("&", " and ")
        s = re.sub(r"[^a-z0-9 ]+", " ", s)
        s = re.sub(r"\s+", " ", s).strip()
        return s

    # Expanded metric synonyms
    SYNONYMS = {
        # NIM
        "nim": ["net interest margin", "nim", "net interest margin group", "nim group"],
        "net interest margin": ["net interest margin", "nim", "net interest margin group", "nim group"],
        # Gross margin (treat as NIM for banks)
        "gross margin": ["net interest margin", "nim", "net interest margin group", "nim group", "gross margin"],
        # Opex
        "operating expenses and income": [
            "operating expenses and income",
            "operating expenses",
            "total expenses",
            "expenses",
        ],
        "operating expenses": [
            "operating expenses",
            "total expenses",
            "expenses",
            "opex",
        ],
        "total expenses": [
            "total expenses",
            "expenses",
            "operating expenses",
            "opex",
        ],
        # Income
        "operating income": [
            "operating income",
            "total operating income",
            "total income",
            "income",
        ],
        "total income": [
            "total income",
            "operating income",
            "total operating income",
            "income",
        ],
    }

    def __init__(self, tables_df: Optional[pd.DataFrame]):
        self.df = tables_df

    @staticmethod
    def _is_year(col: str) -> bool:
        return bool(re.fullmatch(r"\d{4}", str(col).strip()))

    @staticmethod
    def _parse_quarter_token(col: str):
        """
        Parse common quarter column labels like '1Q24', '1Q 2024', 'Q1 2024', '4QFY24'.
        Returns a tuple (year:int, quarter:int, display:str) or None if not a quarter.
        """
        s = str(col).strip()
        # 1) Quarter-first forms: '1Q24', '4Q2024', '1Q 2024', '4QFY24'
        m = re.search(r'(?i)\b([1-4])\s*q\s*(?:fy)?\s*((?:20)?\d{2})\b', s)
        if not m:
            # 2) Q-first forms: 'Q1 2024' or 'Q3 FY24'
            m = re.search(r'(?i)\bq\s*([1-4])\s*(?:fy)?\s*((?:20)?\d{2})\b', s)
        if not m:
            return None
        q = int(m.group(1))
        ytxt = m.group(2)
        y = int(ytxt)
        if y < 100:  # normalize '24' -> 2024
            y += 2000
        display = f"{q}Q{y}"
        return (y, q, display)

    @staticmethod
    def _is_quarter(col: str) -> bool:
        return TableExtractionTool._parse_quarter_token(col) is not None

    def get_metric_rows(self, metric: str, doc: Optional[str] = None, limit: int = 5):
        if self.df is None or self.df.empty:
            return []
        base_df = self.df

        # Build normalized copies for robust matching
        df = base_df.assign(
            _val_norm=base_df["value_str"].astype(str).map(self._norm),
            _col_norm=base_df["column"].astype(str).map(self._norm),
        )

        metric_norm = self._norm(metric)
        cand_terms = self.SYNONYMS.get(metric_norm, [metric_norm])

        mask = pd.Series(False, index=df.index)
        for term in cand_terms:
            term_norm = self._norm(term)
            mask = mask | df["_val_norm"].str.contains(term_norm, na=False) | df["_col_norm"].str.contains(term_norm, na=False)

        if doc:
            mask = mask & (df["doc_name"] == doc)

        if not mask.any():
            return []

        # --- ORIENTATION A: metric appears as a COLUMN header; quarters are in ROW label cells ---
        results: List[Dict[str, Any]] = []
        table_keys = (
            df.loc[mask, ["doc_name", "table_id"]]
              .drop_duplicates()
              .itertuples(index=False, name=None)
        )
        for (d, t) in table_keys:
            tbl = base_df[(base_df["doc_name"] == d) & (base_df["table_id"] == t)].copy()
            if tbl.empty:
                continue
            # normalized copies to detect metric column(s)
            tbln = tbl.assign(
                _val_norm=tbl["value_str"].astype(str).map(self._norm),
                _col_norm=tbl["column"].astype(str).map(self._norm),
            )
            # columns whose header contains the metric term
            metric_cols = sorted(tbln.loc[tbln["_col_norm"].str.contains(metric_norm, na=False), "column"].unique().tolist())
            if metric_cols:
                mcol = str(metric_cols[0])
                # build series_q by iterating all rows in the table and picking the metric cell + a quarter label cell
                series_q: Dict[str, float] = {}
                series_y: Dict[int, float] = {}
                series_pct: Dict[int, float] = {}
                pages_seen: list[int] = []
                for rid in sorted(tbl["row_id"].unique()):
                    row_cells = tbl[tbl["row_id"] == rid]
                    # collect page numbers for this row (if available)
                    try:
                        pser = row_cells.get("page")
                        if pser is not None:
                            pages_seen += [int(p) for p in pser.dropna().astype(int).tolist()]
                    except Exception:
                        pass
                    # find the cell for the metric column in this row
                    mcell = row_cells[row_cells["column"].astype(str) == mcol]
                    if mcell.empty:
                        continue
                    val = mcell.iloc[0].get("value_num")
                    # also try to pick YoY % values when the metric column header is a YoY column
                    # e.g., column header contains 'yoy' or '%'
                    for _, rc in row_cells.iterrows():
                        ctext = str(rc.get("column") or "")
                        if re.search(r"(?i)yoy|%", ctext):
                            try:
                                ylab = (rc.get("value_str") or "").strip()
                                if self._is_year(ylab):
                                    vnum = rc.get("value_num")
                                    if pd.notna(vnum):
                                        series_pct[int(ylab)] = float(vnum)
                            except Exception:
                                pass
                    # find a row label that looks like a quarter or a year in any non-year/quarter column
                    label_text = None
                    for _, rc in row_cells.iterrows():
                        vstr = (rc.get("value_str") or "").strip()
                        if not vstr:
                            continue
                        # prefer quarter tokens
                        qtok = self._parse_quarter_token(vstr)
                        if qtok:
                            disp = qtok[2]
                            label_text = disp
                            break
                        # else maybe pure year row label like "2024"
                        if self._is_year(vstr):
                            label_text = vstr
                            break
                    if pd.notna(val) and label_text:
                        # decide if it's quarter or year
                        qtok2 = self._parse_quarter_token(label_text)
                        if qtok2:
                            series_q[qtok2[2]] = float(val)
                        elif self._is_year(label_text):
                            try:
                                series_y[int(label_text)] = float(val)
                            except Exception:
                                pass
                page_val = None
                if pages_seen:
                    try:
                        page_val = max(set(pages_seen), key=pages_seen.count)
                    except Exception:
                        page_val = pages_seen[-1]
                if series_q or series_y:
                    # label: use the metric column header text
                    label = str(mcol)
                    results.append({
                        "doc": d,
                        "table_id": int(t),
                        "row_id": -1,  # synthetic aggregation over rows
                        "label": label,
                        "series": series_y,
                        "series_q": series_q,
                        "series_pct": series_pct,
                        "page": page_val,
                    })

        # stop early if we already found enough good quarter rows
        if results and len(results) >= limit:
            # rank quarter-first
            def _rank_q(r):
                sq = r.get("series_q", {}) or {}
                def _qkey(k: str):
                    m = re.match(r"([1-4])Q(20\d{2})$", k)
                    if m:
                        return (int(m.group(2)), int(m.group(1)))
                    return (0, 0)
                if sq:
                    qkeys = sorted(sq.keys(), key=_qkey)
                    latest_qy, latest_q = _qkey(qkeys[-1]) if qkeys else (0, 0)
                    return ( -len(sq), -latest_qy, -latest_q, 0, 0 )
                # Rank by this row's own year series (not results[0])
                years = sorted((r.get("series") or {}).keys())
                latest_y = years[-1] if years else 0
                return ( 0, 0, 0, -len(years), -latest_y )
            results.sort(key=_rank_q)
            return results[:limit]

        # --- ORIENTATION B (fallback): metric appears as a ROW label; years/quarters are COLUMNS ---
        key_cols = ["doc_name", "table_id", "row_id"]
        row_keys = (
            df.loc[mask, key_cols]
              .drop_duplicates()
              .itertuples(index=False, name=None)
        )

        for (d, t, r) in row_keys:
            # Load the FULL row from the base dataframe (not the masked slice)
            row_cells = base_df[(base_df["doc_name"] == d) & (base_df["table_id"] == t) & (base_df["row_id"] == r)]
            if row_cells.empty:
                continue

            # choose a representative page for this row
            page_val = None
            try:
                pser = row_cells.get("page")
                if pser is not None:
                    vals = [int(p) for p in pser.dropna().astype(int).tolist()]
                    if vals:
                        page_val = max(set(vals), key=vals.count)
            except Exception:
                pass

            # Determine label
            label = None
            rc_norm = row_cells.assign(
                _val_norm=row_cells["value_str"].astype(str).map(self._norm),
                _col_norm=row_cells["column"].astype(str).map(self._norm),
            )
            metric_hits = rc_norm[~rc_norm["column"].astype(str).map(self._is_year) & rc_norm["_val_norm"].str.contains(metric_norm, na=False)]
            if not metric_hits.empty:
                label = (metric_hits.iloc[0]["value_str"] or "").strip()
            if not label:
                non_year = row_cells[~row_cells["column"].astype(str).map(self._is_year)]
                if not non_year.empty:
                    label = (non_year.iloc[0]["value_str"] or "").strip() or str(non_year.iloc[0]["column"])
            if not label:
                label = f"row {int(r)}"

            # Build year and quarter series from ALL cells in this row
            series: Dict[int, float] = {}
            series_q: Dict[str, float] = {}
            for _, cell in row_cells.iterrows():
                col = str(cell["column"]).strip()
                val = cell.get("value_num")
                if pd.isna(val):
                    continue
                if self._is_year(col):
                    try:
                        y = int(col); series[y] = float(val); continue
                    except Exception:
                        pass
                qtok = self._parse_quarter_token(col)
                if qtok:
                    series_q[qtok[2]] = float(val)

            results.append({
                "doc": d,
                "table_id": int(t),
                "row_id": int(r),
                "label": label,
                "series": series,
                "series_q": series_q,
                "page": page_val
            })

        # Rank results: quarters first by count/recency, then years
        def _row_rank(r):
            sq = r.get("series_q", {}) or {}
            def _qkey(k: str):
                m = re.match(r"([1-4])Q(20\d{2})$", k)
                if m:
                    return (int(m.group(2)), int(m.group(1)))
                return (0, 0)
            if sq:
                qkeys = sorted(sq.keys(), key=_qkey)
                latest_qy, latest_q = _qkey(qkeys[-1]) if qkeys else (0, 0)
                return ( -len(sq), -latest_qy, -latest_q, 0, 0 )
            years = sorted(r["series"].keys())
            latest_y = years[-1] if years else 0
            return ( 0, 0, 0, -len(years), -latest_y )

        results.sort(key=_row_rank)
        return results[:limit]

    @staticmethod
    def last_n_years(series: Dict[int, float], n: int = 3) -> List[Tuple[int, float]]:
        ys = sorted(series.keys())
        return [(y, series[y]) for y in ys[-n:]]


# ----------------------------- Tool: Text Extraction (fallback for quarters) -----------------------------

class TextExtractionTool:
    """
    Regex-based fallback when Marker tables don't carry the quarter series.
    Currently focuses on percentage metrics like Net Interest Margin (NIM).
    It scans the KB text chunks and tries to pair quarter tokens with the nearest % value.
    """
    QPAT = re.compile(r"(?i)(?:\b([1-4])\s*q\s*((?:20)?\d{2})\b|\bq\s*([1-4])\s*((?:20)?\d{2})\b|\b([1-4])q((?:20)?\d{2})\b)")
    PCT = re.compile(r"(?i)(\d{1,2}(?:\.\d{1,2})?)\s*%")

    def __init__(self, kb: 'KBEnv'):
        self.kb = kb

    @staticmethod
    def _norm(s: str) -> str:
        return TableExtractionTool._norm(s)

    @staticmethod
    def _mk_qdisp(q: int, y: int) -> str:
        if y < 100: y += 2000
        return f"{q}Q{y}"

    def extract_quarter_pct(self, metric: str, top_k_text: int = 200) -> Dict[str, float]:
        metric_n = self._norm(metric)
        hits = self.kb.search(metric, k=top_k_text)
        if hits is None or hits.empty:
            return {}
        series_q: Dict[str, float] = {}
        for _, row in hits.iterrows():
            txt = str(row["text"])
            # Quick filter: only consider chunks that mention the metric name
            if metric_n not in self._norm(txt):
                continue
            # Find all quarter tokens in this chunk
            quarts = []
            for m in self.QPAT.finditer(txt):
                # groups: (q1,y1) or (q2,y2) or (q3,y3)
                if m.group(1):   q, y = int(m.group(1)), int(m.group(2))
                elif m.group(3): q, y = int(m.group(3)), int(m.group(4))
                else:            q, y = int(m.group(5)), int(m.group(6))
                if y < 100: y += 2000
                quarts.append((q, y, m.start(), m.end()))
            if not quarts:
                continue
            # Find % values; take the nearest % to each quarter mention
            pcts = [(pm.group(1), pm.start(), pm.end()) for pm in self.PCT.finditer(txt)]
            if not pcts:
                continue
            MAX_CHARS = 48  # require proximity
            for (q, y, qs, qe) in quarts:
                best = None; best_d = 1e9
                for (val, ps, pe) in pcts:
                    d = min(abs(ps - qe), abs(pe - qs))
                    if d < best_d and d <= MAX_CHARS:
                        try:
                            num = float(val)
                        except Exception:
                            continue
                        # sanity for NIM-like percentages
                        if 0.0 <= num <= 6.0:
                            best_d = d; best = num
                if best is not None:
                    disp = self._mk_qdisp(q, y)
                    series_q[disp] = float(best)
        return series_q

# ----------------------------- Tool: Multi-Doc Compare -----------------------------

class MultiDocCompareTool:
    """
    Compare the same metric across multiple docs by pulling each doc's row
    and extracting aligned year/value pairs.
    """

    def __init__(self, table_tool: TableExtractionTool):
        self.table_tool = table_tool

    def compare(self, metric: str, years: Optional[List[int]] = None, top_docs: int = 6):
        # get top rows across all docs
        rows = self.table_tool.get_metric_rows(metric, limit=50)
        if not rows:
            return []
        # take first occurrence per doc
        seen = set()
        picked = []
        for r in rows:
            if r["doc"] in seen: 
                continue
            seen.add(r["doc"])
            picked.append(r)
            if len(picked) >= top_docs:
                break
        # align years
        if years is None:
            all_years = set()
            for r in picked:
                all_years.update(r["series"].keys())
            years = sorted(all_years)[-3:]  # default: last 3 years available
        out = []
        for r in picked:
            values = {y: r["series"].get(y) for y in years}
            out.append({"doc": r["doc"], "label": r["label"], "years": years, "values": values})
        return out

# ----------------------------- Agent Mode: plan → act → observe -----------------------------

@dataclass
class AgentResult:
    plan: List[str]
    actions: List[str]
    observations: List[str]
    final: Dict[str, Any]

# ----------------------------- QUERY ANALYSIS UTILITIES -----------------------------
class QueryAnalyzer:
    """Utility methods for parsing financial queries"""
    
    @staticmethod
    def extract_metric(query: str) -> Optional[str]:
        """
        Extract metric name from query
        Priority: quoted phrase > regex patterns > capitalized words
        """
        # 1. Quoted phrase (highest priority)
        quoted = re.findall(r'"([^"]+)"', query)
        if quoted:
            return quoted[0]
        
        # 2. Common finance metrics (regex patterns)
        candidates = [
            r"net interest margin", r"nim", r"gross margin",
            r"operating expenses?(?: &| and)?(?: income)?",
            r"operating income", r"operating profit",
            r"total income", r"cost-to-income", r"allowances", 
            r"profit before tax", r"efficiency ratio",
            r"return on equity", r"roe", r"return on assets", r"roa"
        ]
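        ql = query.lower()
        for pat in candidates:
            # Word boundaries keep short patterns ("nim", "roe", "roa") from
            # matching inside longer words such as "minimum" or "road".
            m = re.search(rf"\b(?:{pat})\b", ql)
            if m:
                return m.group(0)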
        
        # 3. Fallback: capitalized phrase
        m2 = re.findall(r'\b([A-Z][A-Za-z&% ]{3,})\b', query)
        return m2[0] if m2 else None
    
    @staticmethod
    def want_compare(query: str) -> bool:
        """Check if query requests comparison across documents"""
        return bool(re.search(
            r"\b(compare|vs\.?|versus|across docs?|between|multi-?doc)\b", 
            query, re.I
        ))
    
    @staticmethod
    def want_yoy(query: str) -> bool:
        """Check if query requests year-over-year analysis"""
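        # "%" sits outside the \b group (a word boundary never matches around a
        # standalone symbol); "year-on-year" is accepted alongside "year-over-year".
        return bool(re.search(
            r"\b(?:yoy|year[- ](?:over|on)[- ]year|growth|change|delta|annual growth)\b|%",
            query, re.I
        ))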
    
    @staticmethod
    def want_quarters(query: str) -> bool:
        """Check if query requests quarterly data"""
        return bool(re.search(
            r"\b(quarter|quarters|\bq[1-4]\b|quarterly|half[- ]?year)\b", 
            query, re.I
        ))
    
    @staticmethod
    def extract_years(query: str) -> List[int]:
        """Extract year numbers from query"""
        years = [int(y) for y in re.findall(r"\b(20\d{2})\b", query)]
        # Deduplicate and sort
        return sorted(set(years))
    
    @staticmethod
    def extract_num_periods(query: str) -> Optional[int]:
        """Extract number of periods (e.g., 'last 5 quarters', 'last 3 years')"""
        # Pattern: "last N quarters/years"
        m = re.search(r"\blast\s+(\d+)\s+(quarters?|years?|periods?)", query, re.I)
        if m:
            return int(m.group(1))
        
        # Pattern: "N quarters/years"
        m2 = re.search(r"\b(\d+)\s+(quarters?|years?|periods?)", query, re.I)
        if m2:
            return int(m2.group(1))
        
        return None
    
    @staticmethod
    def needs_calculation(query: str) -> bool:
        """Check if query requires calculation"""
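        # Symbols sit outside the \b group: a word boundary never matches next
        # to a standalone "÷" or "/" surrounded by spaces.
        return bool(re.search(
            r"\b(?:calculate|compute|derive|ratio|divided by|percentage of)\b|[÷/]",
            query, re.I
        ))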


# ----------------------------- PARALLEL QUERY DECOMPOSER -----------------------------

class ParallelQueryDecomposer:
    """Decomposes complex queries using QueryAnalyzer"""
    
    @staticmethod
    def decompose(query: str) -> List[str]:
        """
        Intelligent query decomposition using query analysis
        """
        analyzer = QueryAnalyzer
        
        # Extract intent
        needs_calc = analyzer.needs_calculation(query)
        metric = analyzer.extract_metric(query)
        years = analyzer.extract_years(query)
        num_periods = analyzer.extract_num_periods(query)
        
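        # Q3: Efficiency Ratio (Opex ÷ Income)
        # Test the query text, not `metric`: extract_metric() returns the first
        # matching pattern, which may be "operating income" rather than "efficiency ratio".
        if needs_calc and "efficiency" in query.lower():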
            return [
                f"Extract Operating Expenses for the last {num_periods or 3} fiscal years",
                f"Extract Total Income for the last {num_periods or 3} fiscal years"
            ]
        
        # Q3: Any ratio calculation (A ÷ B)
        if needs_calc and ("ratio" in query.lower() or "÷" in query or "/" in query):
            # Try to extract both metrics
            parts = re.split(r'[÷/]|\bdivided by\b', query, flags=re.I)
            if len(parts) == 2:
                metric_a = analyzer.extract_metric(parts[0])
                metric_b = analyzer.extract_metric(parts[1])
                if metric_a and metric_b:
                    return [
                        f"Extract {metric_a} for the last {num_periods or 3} fiscal years",
                        f"Extract {metric_b} for the last {num_periods or 3} fiscal years"
                    ]
        
        # Multi-metric comparison
        if analyzer.want_compare(query) and metric:
            # Decompose by year if multiple years specified
            if len(years) > 2:
                return [f"Extract {metric} for FY{y}" for y in years]
        
        # Single metric query (no decomposition)
        return [query]

    @staticmethod
    async def execute_parallel_async(kb: KBEnv, sub_queries: List[str], k_ctx: int) -> List[pd.DataFrame]:
        """Execute sub-queries in parallel"""
        # Inside a coroutine, get_running_loop() is the supported way to obtain the loop
        loop = asyncio.get_running_loop()
        
        def search_sync(query):
            return kb.search(query, k=k_ctx)
        
        with ThreadPoolExecutor(max_workers=min(len(sub_queries), 4)) as executor:
            tasks = [loop.run_in_executor(executor, search_sync, sq) for sq in sub_queries]
            results = await asyncio.gather(*tasks)
        
        return results
    
    @staticmethod
    def execute_parallel(kb: KBEnv, sub_queries: List[str], k_ctx: int) -> List[pd.DataFrame]:
        """Blocking wrapper for async parallel execution"""
        try:
            loop = asyncio.get_event_loop()
            if loop.is_running():
                # Already in event loop (e.g., Jupyter), use nest_asyncio
                try:
                    import nest_asyncio
                    nest_asyncio.apply()
                except ImportError:
                    pass
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        
        return loop.run_until_complete(
            ParallelQueryDecomposer.execute_parallel_async(kb, sub_queries, k_ctx)
        )
    
    @staticmethod
    def merge_results(results: List[pd.DataFrame], k_ctx: int) -> pd.DataFrame:
        """Merge and deduplicate parallel results"""
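        # Drop failed (None) or empty frames first; pd.concat raises
        # ValueError when given an empty list.
        frames = [r for r in (results or []) if r is not None and not r.empty]
        if not frames:
            return pd.DataFrame()
        
        # Concatenate
        merged = pd.concat(frames, ignore_index=True)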
        
        # Deduplicate by text (keep highest score)
        merged = merged.sort_values('score', ascending=False)
        merged = merged.drop_duplicates(subset=['text'], keep='first')
        
        # Take top-k
        merged = merged.head(k_ctx)
        
        # Re-rank
        merged = merged.sort_values('score', ascending=False).reset_index(drop=True)
        merged['rank'] = range(1, len(merged) + 1)
        
        return merged

# ----------------------------- Agent: plan → act → observe -----------------------------

class Agent:
    """
    Unified Agent with all tools:
    - CalculatorTool
    - TableExtractionTool
    - TextExtractionTool
    - MultiDocCompareTool (NEW)
    """
    
    def __init__(
        self, 
        kb: KBEnv, 
        use_parallel_subqueries: bool = False,
        verbose: bool = True
    ):
        self.kb = kb
        self.use_parallel_subqueries = use_parallel_subqueries
        self.verbose = verbose
        
        # Initialize all tools
        self.calc_tool = CalculatorTool()
        
        # Table tool (required for multi-doc compare)
        self.table_tool = TableExtractionTool(kb.tables_df) if kb.tables_df is not None else None
        
        # Text extraction fallback
        self.text_tool = TextExtractionTool(kb)
        
        # ✅ Multi-doc compare tool (NEW)
        self.multidoc_tool = MultiDocCompareTool(self.table_tool) if self.table_tool else None
        
        # Analyzers
        self.analyzer = QueryAnalyzer()
        self.decomposer = ParallelQueryDecomposer() if use_parallel_subqueries else None
        
        if self.verbose:
            tools_status = []
            tools_status.append("Calculator: ✓")
            tools_status.append(f"Table: {'✓' if self.table_tool else '✗'}")
            tools_status.append("Text: ✓")
            tools_status.append(f"MultiDoc: {'✓' if self.multidoc_tool else '✗'}")
            print(f"[Agent] Tools: {' | '.join(tools_status)}")
    
    def run(self, query: str, k_ctx: int = 12) -> 'AgentResult':
        """Execute query with all available tools"""
        
        if self.verbose:
            print(f"[Agent] Query: {query[:60]}...")
        
        # Step 1: Analyze query
        metric = self.analyzer.extract_metric(query)
        wants_yoy = self.analyzer.want_yoy(query)
        wants_quarters = self.analyzer.want_quarters(query)
        wants_compare = self.analyzer.want_compare(query)
        needs_calc = self.analyzer.needs_calculation(query)
        years = self.analyzer.extract_years(query)
        num_periods = self.analyzer.extract_num_periods(query)
        
        if self.verbose:
            print(f"[Agent] Analysis:")
            print(f"  Metric: {metric}")
            print(f"  YoY: {wants_yoy}, Quarterly: {wants_quarters}")
            print(f"  Compare: {wants_compare}, Calc: {needs_calc}")
            print(f"  Years: {years}, Periods: {num_periods}")
        
        # Step 2 & 3 COMBINED: Retrieve and extract in PARALLEL
        actions = []
        table_rows = []
        comparison_results = []
        
        # Create thread pool for parallel execution
        executor = ThreadPoolExecutor(max_workers=3)
        futures = {}
        
        # Task 1: Start retrieval (runs in background thread)
        if self.use_parallel_subqueries and self.decomposer:
            sub_queries = self.decomposer.decompose(query)
            
            if len(sub_queries) >= 4:
                if self.verbose:
                    print(f"[Agent] Decomposed into {len(sub_queries)} sub-queries (parallel)")
                
                futures['retrieval'] = executor.submit(
                    self.decomposer.execute_parallel,
                    self.kb, sub_queries, k_ctx
                )
            elif len(sub_queries) > 1:
                # Sequential is faster for 2-3 queries (less overhead)
                if self.verbose:
                    print(f"[Agent] Decomposed into {len(sub_queries)} sub-queries (sequential)")
                
                def sequential_search():
                    results = [self.kb.search(sq, k=k_ctx) for sq in sub_queries]
                    return self.decomposer.merge_results(results, k_ctx)
                
                futures['retrieval'] = executor.submit(sequential_search)
            else:
                futures['retrieval'] = executor.submit(self.kb.search, query, k_ctx)
        else:
            futures['retrieval'] = executor.submit(self.kb.search, query, k_ctx)
        
        # Task 2: Start table extraction (runs SAME TIME as retrieval!)
        if metric and self.table_tool and not wants_compare:
            futures['table'] = executor.submit(
                self.table_tool.get_metric_rows,
                metric,
                10
            )
        
        # Task 3: Start multi-doc comparison (if needed)
        if wants_compare and metric and self.multidoc_tool:
            futures['comparison'] = executor.submit(
                self.multidoc_tool.compare,
                metric=metric,
                years=years if years else None,
                top_docs=6
            )
        
        # Wait for ALL tasks to finish and get results
        results = {}
        for name, future in futures.items():
            try:
                results[name] = future.result()  # Blocks until this task finishes
            except Exception as e:
                if self.verbose:
                    print(f"[Agent] Task {name} failed: {e}")
                results[name] = None
        # Every future is resolved at this point; release the worker threads
        executor.shutdown(wait=False)
        
        # Process retrieval results
        raw_retrieval = results.get('retrieval')
        if isinstance(raw_retrieval, list):
            # execute_parallel() returns one DataFrame per sub-query; merge them
            # so downstream code always receives a single DataFrame
            contexts = ParallelQueryDecomposer.merge_results(raw_retrieval, k_ctx)
        elif raw_retrieval is not None:
            contexts = raw_retrieval
        else:
            # Fallback if retrieval failed
            contexts = self.kb.search(query, k=k_ctx)
        if self.verbose:
            print(f"[Agent] Retrieved {len(contexts)} contexts")
        
        # Process table extraction results
        if 'table' in results and results['table']:
            table_rows = results['table']
            actions.append("table_extraction")
            if self.verbose:
                print(f"[Agent] Extracted {len(table_rows)} table rows")
        
        # Process comparison results
        if 'comparison' in results and results['comparison']:
            comparison_results = results['comparison']
            actions.append("multi_doc_compare")
            if self.verbose:
                print(f"[Agent] MultiDoc compare: {len(comparison_results)} documents")
        
        # Text extraction fallback (still sequential - only runs if table failed)
        if not table_rows and not comparison_results and self.text_tool:
            actions.append("text_extraction")
            if wants_quarters and metric:
                quarter_data = self.text_tool.extract_quarter_pct(metric, top_k_text=50)
                if self.verbose:
                    print(f"[Agent] Text extraction: {len(quarter_data)} quarters")
        
        # Calculator for YoY or ratios
        if needs_calc and self.calc_tool:
            actions.append("calculation")
        
        # Step 4: LLM Synthesis
        answer = self._synthesize_with_llm(
            query, 
            contexts, 
            table_rows, 
            comparison_results,  
            metric, 
            wants_yoy,
            wants_compare  
        )
        
        # Step 5: Build result
        observations = [
            f"Retrieved {len(contexts)} contexts",
            f"Metric: {metric}",
            f"YoY: {wants_yoy}, Quarterly: {wants_quarters}, Compare: {wants_compare}",
            f"Tools used: {', '.join(actions)}"
        ]
        
        result = AgentResult(
            # AgentResult.plan is a List[str]; show_agent_result() iterates over it
            plan=[
                "Analyze query",
                f"{'Compare' if wants_compare else 'Extract'} {metric}",
                "Synthesize answer",
            ],
            actions=actions,
            observations=observations,
            final={
                "contexts": contexts,
                "table_rows": table_rows,
                "comparison_results": comparison_results,
                "answer": answer,
                "metric": metric,
                "wants_yoy": wants_yoy,
                "wants_quarters": wants_quarters,
                "wants_compare": wants_compare
            }
        )
        
        return result
    
    def _synthesize_with_llm(
        self, 
        query: str, 
        contexts: pd.DataFrame, 
        table_rows: List[Dict],
        comparison_results: List[Dict],
        metric: str,
        wants_yoy: bool,
        wants_compare: bool
    ) -> str:
        """Synthesize final answer using LLM"""
        
        prompt_parts = [f"USER QUESTION:\n{query}\n"]
        
        # Add multi-doc comparison results
        if comparison_results:
            prompt_parts.append("\nMULTI-DOCUMENT COMPARISON:")
            for comp in comparison_results:
                doc = comp.get("doc", "Unknown")
                label = comp.get("label", metric)
                years = comp.get("years", [])
                values = comp.get("values", {})  # dict keyed by year
                
                if years and values:
                    # Look up each year; zip(years, values) would pair years with dict keys
                    year_val_pairs = ", ".join(f"{y}: {values.get(y)}" for y in years)
                    prompt_parts.append(f"- {doc} | {label}: {year_val_pairs}")
        
        # Add table rows if available
        elif table_rows:
            prompt_parts.append("\nSTRUCTURED DATA:")
            for r in table_rows[:5]:
                if r.get("series_q"):
                    qkeys = sorted(r["series_q"].keys())[-5:]
                    ser = ", ".join(f"{k}: {r['series_q'][k]}" for k in qkeys)
                    prompt_parts.append(f"- {r['doc']} | {r['label']}: {ser}")
                else:
                    ys = sorted(r["series"].keys())[-3:]
                    ser = ", ".join(f"{y}: {r['series'][y]}" for y in ys)
                    prompt_parts.append(f"- {r['doc']} | {r['label']}: {ser}")
        
        # Add retrieved contexts
        if contexts is not None and not contexts.empty:
            prompt_parts.append("\nCONTEXT:")
            for _, row in contexts.head(5).iterrows():
                text = str(row["text"])[:500]
                doc = row.get("doc", "Unknown")
                prompt_parts.append(f"- [{doc}] {text}")
        
        # Add instructions
        prompt_parts.append("\nINSTRUCTIONS:")
        prompt_parts.append("- Use ONLY the data provided above")
        
        if wants_compare:
            prompt_parts.append("- Compare the metric across different documents")
            prompt_parts.append("- Highlight similarities and differences")
        
        if wants_yoy:
            prompt_parts.append("- Calculate year-over-year growth percentages")
        
        prompt_parts.append("- Provide a concise answer with specific numbers and document names")
        prompt_parts.append("- If data is incomplete, state what's missing explicitly")
        
        prompt = "\n".join(prompt_parts)
        
        # Call LLM
        if self.verbose:
            print(f"[Agent] Synthesizing with LLM...")
        
        answer = _llm_single_call(prompt)
        
        return answer


# ----------------------------- Pretty print helpers -----------------------------

def _fmt_series(series: Dict[int, float], n: int = 3) -> str:
    if not series:
        return "\u2014"  # em dash placeholder for an empty series
    ys = sorted(series.keys())[-n:]
    return ", ".join(f"{y}: {series[y]}" for y in ys)

def show_agent_result(res: AgentResult, show_ctx: int = 3):
    print("PLAN:")
    for step in res.plan:
        print("  -", step)
    print("\nACTIONS:")
    for a in res.actions:
        print("  -", a)
    print("\nOBSERVATIONS:")
    for o in res.observations:
        print("  -", o)

    fin = res.final

    # TABLE ROWS block
    if not fin.get("table_rows"):
        msg = fin.get("notice") or "No matching table rows were found for your request."
        print(f"\n⚠️ {msg}")
    else:
        print("\nTABLE ROWS (first few):")
        shown = 0
        for r in fin["table_rows"]:
            if shown >= 3:
                break
            sq = (r.get("series_q") or {})
            if sq:
                # sort quarters chronologically by (year, quarter)
                def _qkey(k):
                    # Single backslash: in a raw string, "\\d" matches a literal backslash
                    m = re.match(r"([1-4])Q(20\d{2})$", k)
                    if m:
                        return (int(m.group(2)), int(m.group(1)))
                    return (0, 0)
                qkeys = sorted(sq.keys(), key=_qkey)
                last5 = qkeys[-5:]
                ser = ", ".join(f"{k}: {sq[k]}" for k in last5)
                print(f"  doc={r['doc']} | label={r['label']} | quarters(last5)={ser}")
                shown += 1
            else:
                ys = sorted(r["series"].keys())
                ser = ", ".join(f"{y}: {r['series'][y]}" for y in ys[-3:]) if ys else "\u2014"
                print(f"  doc={r['doc']} | label={r['label']} | years(last3)={ser}")
                shown += 1
    # Agent.run() stores multi-doc output under "comparison_results"
    compare_rows = fin.get("comparison_results") or fin.get("compare")
    if compare_rows:
        print("\nCOMPARE (first few):")
        for r in compare_rows[:3]:
            row = ", ".join(f"{y}: {r['values'].get(y)}" for y in r["years"])
            print(f"  doc={r['doc']} | label={r['label']} | {row}")
    if "calc" in fin and fin["calc"]:
        print("\nCALC (YoY):")
        for c in fin["calc"]:
            print(f"  {c['from']}→{c['to']}: {c['value_from']} → {c['value_to']} | YoY={c['yoy_pct']}%")

    # Contexts
    ctx = fin.get("contexts")
    if ctx is not None and not ctx.empty:
        print("\nCONTEXTS:")
        for _, row in ctx.head(show_ctx).iterrows():
            t = str(row["text"]).replace("\n", " ")
            if len(t) > 240: t = t[:237] + "..."
            hint = f" \u2014 {row.get('section_hint')}" if "section_hint" in row else ""
            print(f"  [{row['rank']}] {row['doc']} | {row['modality']}{hint}")
            print("     ", t)


# ----------------------------- CLI / Notebook ------------------------------------

# ----------------------------- Notebook Runtime ------------------------------------

# This section is safe for direct use inside a Jupyter/Colab/VSCode notebook cell.
# It avoids argparse/sys parsing and simply runs a default demo or accepts a variable `query`.

# Example usage in a notebook:
# from g2x import KBEnv, Agent, show_agent_result
# kb = KBEnv(base="./data_marker")
# agent = Agent(kb)
# res = agent.run("Compare Net Interest Margin across docs for 2022–2024")
# show_agent_result(res)

if __name__ == "__main__" or "__file__" not in globals():
    kb = KBEnv(base="./data_marker")
    agent = Agent(kb)

    try:
        query = globals().get("query", None)
    except Exception:
        query = None

    if not query:
        query = "What is the Net Interest Margin over the last 5 quarters?"
        print("ℹ️ Running notebook demo query:")
        print(f"   → {query}\n")

    # BASELINE execution (single LLM, no caching)
    out = baseline_answer_one_call(kb, query, k_ctx=8)

[BM25] ✓ Indexed 13587 documents
[Reranker] ✓ Loaded cross-encoder/ms-marco-MiniLM-L-6-v2
[Agent] Tools: Calculator: ✓ | Table: ✓ | Text: ✓ | MultiDoc: ✓
[Search] RRF fusion: 53 candidates
[Rerank] Reranking top-16 candidates...
[Reranker] ✓ Loaded cross-encoder/ms-marco-MiniLM-L-6-v2
[Agent] Tools: Calculator: ✓ | Table: ✓ | Text: ✓ | MultiDoc: ✓
[Search] RRF fusion: 53 candidates
[Rerank] Reranking top-16 candidates...
[Rerank] ✓ Reranked to top-8
[Rerank] ✓ Reranked to top-8
[LLM] single-call baseline using groq:openai/gpt-oss-20b
[LLM] single-call baseline using groq:openai/gpt-oss-20b
[LLM] provider=groq model=openai/gpt-oss-20b
[LLM] provider=groq model=openai/gpt-oss-20b


---

### Just to check available models

In [None]:
import google.generativeai as genai
import os

# Read the key from the environment; never paste it into the notebook.
api_key = os.environ.get("GEMINI_API_KEY")
if not api_key:
    raise RuntimeError("GEMINI_API_KEY is not set; see section 1 (Config & Secrets).")
genai.configure(api_key=api_key)

print("Available Models:\n")

# List all models that support the 'generateContent' method
for model in genai.list_models():
    if 'generateContent' in model.supported_generation_methods:
        print(f"- {model.name}")

## 5. Benchmark Runner

Run the three standardized queries below. For each, produce a JSON result followed by a prose answer with citations.

*   Gross Margin Trend (or NIM if Bank)
    *   Query: "Report the Gross Margin (or Net Interest Margin, if a bank) over the last 5 quarters, with values."
    *   Expected Output: A quarterly table of Gross Margin % (or NIM % if bank).

*   Operating Expenses (Opex) YoY for 3 Years
    *   Query: "Show Operating Expenses for the last 3 fiscal years, year-on-year comparison."
    *   Expected Output: A 3-year Opex table (absolute numbers and % change).

*   Operating Efficiency Ratio
    *   Query: "Calculate the Operating Efficiency Ratio (Opex ÷ Operating Income) for the last 3 fiscal years, showing the working."
    *   Expected Output: Table with Opex, Operating Income, and calculated ratio for 3 years.
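
The arithmetic behind the second and third queries can be sketched as below. This is a minimal illustration, not the notebook's `CalculatorTool`: the helper names `yoy_table` and `efficiency_ratio` are hypothetical, and the figures in the demo are made-up placeholders, not values from any filing.

```python
from typing import Dict, List, Optional


def yoy_table(series: Dict[int, float]) -> List[dict]:
    """Return one row per year with the value and its % change vs. the prior year."""
    years = sorted(series)
    rows = []
    for i, year in enumerate(years):
        prev: Optional[float] = series[years[i - 1]] if i > 0 else None
        if prev is None or prev == 0:
            pct = None  # first year, or division by zero: no YoY figure
        else:
            pct = round((series[year] - prev) / prev * 100, 2)
        rows.append({"year": year, "value": series[year], "yoy_pct": pct})
    return rows


def efficiency_ratio(opex: Dict[int, float], income: Dict[int, float]) -> Dict[int, float]:
    """Opex divided by income, as a percentage, for years present in both series."""
    out = {}
    for year in sorted(set(opex) & set(income)):
        if income[year]:  # skip years with zero or missing income
            out[year] = round(opex[year] / income[year] * 100, 2)
    return out


if __name__ == "__main__":
    # Placeholder numbers for illustration only
    opex = {2021: 100.0, 2022: 110.0, 2023: 121.0}
    income = {2021: 250.0, 2022: 275.0, 2023: 302.5}
    for row in yoy_table(opex):
        print(row)
    print(efficiency_ratio(opex, income))
```

Keeping the inputs as year-keyed dictionaries matches the `series` shape the table tool returns, so the working (numerator, denominator, ratio) can be reported per year.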

### Gemini Version 3

Two-Mode RAG System with Marker + PDFPlumber Fallback
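
The fallback rule can be illustrated on plain lists before reading the full cell. The sketch below assumes each index is exposed as a callable `(query, k) -> list of hits`; `search_with_fallback`, `primary`, and `fallback` are hypothetical names used only for this example. The secondary index is consulted only when the primary returns fewer than half of the requested hits.

```python
from typing import Callable, Dict, List

Hit = Dict[str, object]
SearchFn = Callable[[str, int], List[Hit]]


def search_with_fallback(primary: SearchFn, fallback: SearchFn,
                         query: str, top_k: int = 12) -> List[Hit]:
    """Query the primary index; top up from the fallback if results are sparse."""
    results = [dict(h, index_source="marker") for h in primary(query, top_k)]

    # Fallback triggers only when fewer than half of the requested hits came back
    if len(results) < top_k // 2:
        extra = fallback(query, top_k - len(results))
        results.extend(dict(h, index_source="pdfplumber") for h in extra)

    # Highest score first, then trim to the requested size
    results.sort(key=lambda h: h.get("score", 0.0), reverse=True)
    return results[:top_k]


if __name__ == "__main__":
    def primary(query: str, k: int) -> List[Hit]:
        return [{"text": "marker hit", "score": 0.9}][:k]

    def fallback(query: str, k: int) -> List[Hit]:
        return [{"text": f"plumber hit {i}", "score": 0.5} for i in range(k)]

    for hit in search_with_fallback(primary, fallback, "net interest margin", top_k=4):
        print(hit["index_source"], hit["score"], hit["text"])
```

One caveat with this design: scores from two different indexes are sorted together, so the ordering is only meaningful if both indexes score on a comparable scale.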

In [7]:
"""
Two-Mode RAG System with Parallel Sub-Query Support
"""

from __future__ import annotations
import os, json, time
from typing import List, Dict, Any, Optional
from pathlib import Path

import numpy as np
import pandas as pd

# Import g2x components
from g2x import KBEnv, Agent, baseline_answer_one_call


# ============================================================================
# CONFIGURATION
# ============================================================================

class Config:
    """Centralized configuration"""
    VERBOSE = bool(int(os.environ.get("AGENT_CFO_VERBOSE", "1")))
    
    # Paths
    MARKER_INDEX = "./data_marker"
    PDFPLUMBER_INDEX = "./data"
    
    # Search params
    TOP_K = 12
    HYBRID_ALPHA = 0.6
    RERANK_TOP_K = 24
    
    # Agentic mode
    USE_PARALLEL_SUBQUERIES = True  # Enable parallel sub-query decomposition


# ============================================================================
# UTILITIES
# ============================================================================

def page_or_none(x) -> Optional[int]:
    """Safely convert page numbers"""
    try:
        if x is None or pd.isna(x):
            return None
        return int(x)
    except (TypeError, ValueError):
        # Non-numeric or non-scalar page values are treated as missing
        return None


# ============================================================================
# MULTI-INDEX SEARCH (Marker + PDFPlumber Fallback)
# ============================================================================

class MultiIndexSearch:
    """
    Dual-index search with automatic fallback
    """
    
    def __init__(self, marker_path: str, pdfplumber_path: str):
        self.marker_kb = self._load_kb(marker_path, "Marker")
        self.pdfplumber_kb = self._load_kb(pdfplumber_path, "PDFPlumber")
        
        if not self.marker_kb and not self.pdfplumber_kb:
            raise RuntimeError("No valid indexes loaded")
    
    def _load_kb(self, path: str, name: str) -> Optional[KBEnv]:
        """Load KB with BM25 + Reranker enabled"""
        if not Path(path).exists():
            return None
        
        try:
            kb = KBEnv(base=path, enable_bm25=True, enable_reranker=True)
            if Config.VERBOSE:
                print(f"[{name}] ✓ Loaded {len(kb.texts)} chunks")
            return kb
        except Exception as e:
            if Config.VERBOSE:
                print(f"[{name}] ✗ Failed: {e}")
            return None
    
    def search(self, query: str, top_k: Optional[int] = None) -> List[Dict[str, Any]]:
        """Hybrid search with fallback"""
        top_k = top_k or Config.TOP_K
        
        # Primary: Marker
        results = []
        if self.marker_kb:
            df = self.marker_kb.search(query, k=top_k, alpha=Config.HYBRID_ALPHA, rerank_top_k=Config.RERANK_TOP_K)
            results = self._df_to_dict(df, "marker")
        
        # Fallback: PDFPlumber
        if len(results) < top_k // 2 and self.pdfplumber_kb:
            df = self.pdfplumber_kb.search(query, k=top_k - len(results))
            results.extend(self._df_to_dict(df, "pdfplumber"))
        
        results.sort(key=lambda x: x.get("score", 0), reverse=True)
        return results[:top_k]
    
    def _df_to_dict(self, df: pd.DataFrame, source: str) -> List[Dict]:
        """Convert DataFrame to dict list"""
        if df is None or df.empty:
            return []
        
        return [
            {
                "file": str(row.get("doc")),
                "page": page_or_none(row.get("page")),
                "text": str(row.get("text")),
                "score": float(row.get("score", 0)),
                "year": int(row["year"]) if pd.notna(row.get("year")) else None,
                "quarter": str(row["quarter"]) if pd.notna(row.get("quarter")) else None,  # FIX: Quarter is a string, not int
                "section_hint": row.get("section_hint"),
                "index_source": source
            }
            for _, row in df.iterrows()
        ]


# ============================================================================
# ANSWERING ENGINE
# ============================================================================

class AnsweringEngine:
    """Unified baseline + agentic answering with parallel sub-queries"""
    
    def __init__(self):
        print("AnsweringEngine initialized")
        self.search = MultiIndexSearch(Config.MARKER_INDEX, Config.PDFPLUMBER_INDEX)
        
        # Agent with parallel sub-queries
        primary_kb = self.search.marker_kb or self.search.pdfplumber_kb
        self.agent = Agent(
            kb=primary_kb, 
            use_parallel_subqueries=Config.USE_PARALLEL_SUBQUERIES,
            verbose=Config.VERBOSE
        )
        print(f"[AnsweringEngine] Parallel sub-queries enabled: {self.agent.use_parallel_subqueries}")
    
    def answer(self, query: str, mode: str = "baseline") -> Dict[str, Any]:
        """Execute query in baseline or agentic mode"""
        
        if mode == "agentic":
            # Agentic mode with parallel sub-queries
            agent_result = self.agent.run(query, k_ctx=Config.TOP_K)
            
            return {
                "answer": self._format_agent_answer(agent_result),
                "hits": self._extract_agent_hits(agent_result),
                "execution_log": {
                    "plan": agent_result.plan,
                    "actions": agent_result.actions,
                    "observations": agent_result.observations
                }
            }
        
        else:  # baseline
            # Standard RAG: Retrieve → Single LLM call
            results = self.search.search(query, top_k=Config.TOP_K)
            
            answer_result = baseline_answer_one_call(
                self.search.marker_kb or self.search.pdfplumber_kb,
                query,
                k_ctx=Config.TOP_K
            )
            
            return {
                "answer": answer_result.get("answer", ""),
                "hits": results[:5],  # Top-5 citations
                "execution_log": None
            }
    
    def _format_agent_answer(self, agent_result) -> str:
        """
        Format agentic answer with proper fallback
        """
        lines = []
        
        # Add execution summary (optional)
        if agent_result.observations and Config.VERBOSE:
            lines.append("**Execution Summary**")
            for obs in agent_result.observations:
                lines.append(f"- {obs}")
            lines.append("")
        
        fin = agent_result.final
        
        # Priority 1: Return LLM-synthesized answer if available
        # (keep the optional execution summary collected above instead of discarding it)
        if fin.get("answer"):
            return "\n".join(lines + [fin["answer"]])
        
        # Priority 2: Format tool outputs (FALLBACK)
        if fin.get("comparison_results"):
            lines.append("**Multi-Document Comparison**")
            for comp in fin["comparison_results"][:5]:
                doc = comp.get("doc", "Unknown")
                years = comp.get("years", [])
                values = comp.get("values", [])
                if years and values:
                    year_val = ", ".join(f"{y}: {v}" for y, v in zip(years, values))
                    lines.append(f"- {doc}: {year_val}")
        
        elif fin.get("table_rows"):
            lines.append("**Extracted Data**")
            for r in fin["table_rows"][:5]:
                doc = r.get("doc", "Unknown")
                label = r.get("label", "")
                
                if r.get("series_q"):
                    qkeys = sorted(r["series_q"].keys())[-5:]
                    ser = ", ".join(f"{k}: {r['series_q'][k]}" for k in qkeys)
                    lines.append(f"- {doc} | {label}: {ser}")
                elif r.get("series"):
                    ys = sorted(r["series"].keys())[-3:]
                    ser = ", ".join(f"{y}: {r['series'][y]}" for y in ys)
                    lines.append(f"- {doc} | {label}: {ser}")
                else:
                    lines.append(f"- {doc} | {label}: (no data)")
        
        # Priority 3: Final fallback message
        if not lines:
            lines.append("Analysis complete. No structured data extracted.")
        
        return "\n".join(lines)
    
    def _extract_agent_hits(self, agent_result) -> List[Dict]:
        """Extract citations from agent result"""
        contexts = agent_result.final.get("contexts")
        if contexts is None or contexts.empty:
            return []
        
        return [
            {
                "file": row.get("doc"),
                "page": row.get("page"),
                "section_hint": row.get("section_hint"),
                "index_source": "marker",
                "score": row.get("score")
            }
            for _, row in contexts.head(5).iterrows()
        ]


# ============================================================================
# BENCHMARK RUNNER
# ============================================================================

class BenchmarkRunner:
    """Standardized benchmark execution"""
    
    QUERIES = [
        "Report the Gross Margin (or Net Interest Margin, if a bank) over the last 5 quarters, with values.",
        "Show Operating Expenses for the last 3 fiscal years, year-on-year comparison.",
        "Calculate the Operating Efficiency Ratio (Opex ÷ Operating Income) for the last 3 fiscal years, showing the working."
    ]
    
    def __init__(self, engine: AnsweringEngine):
        self.engine = engine
    
    def run(self, mode: str = "baseline") -> Dict[str, Any]:
        """Run benchmark and save results"""
        out_dir = "data_marker" if mode == "agentic" else "data"
        os.makedirs(out_dir, exist_ok=True)
        
        print(f"\n{'='*60}")
        print(f"  {mode.upper()} BENCHMARK")
        print(f"{'='*60}\n")
        
        results = []
        for i, query in enumerate(self.QUERIES, 1):
            print(f"\nQ{i}. {query}\n")
            
            t0 = time.perf_counter()
            result = self.engine.answer(query, mode=mode)
            latency_ms = round((time.perf_counter() - t0) * 1000, 2)
            
            print(result["answer"])
            if result.get("hits"):
                print("\n--- Citations ---")
                for hit in result["hits"][:5]:
                    pg = f"p.{hit.get('page')}" if hit.get('page') else ""
                    print(f"- {hit['file']} {pg}")
            
            print(f"\n(Latency: {latency_ms} ms)")
            
            results.append({
                "query": query,
                "answer": result["answer"],
                "citations": result.get("hits", []),
                "execution_log": result.get("execution_log"),
                "latency_ms": latency_ms
            })
        
        # Save JSON with UTF-8 encoding
        json_path = f"{out_dir}/bench_{mode}.json"
        with open(json_path, "w", encoding="utf-8") as f:  # FIX: Add encoding
            json.dump({"results": results}, f, indent=2, ensure_ascii=False)  # FIX: Add ensure_ascii=False
        
        # Save Markdown
        md_path = f"{out_dir}/bench_{mode}.md"
        self._write_markdown(md_path, results, mode)
        
        # Summary
        latencies = [r["latency_ms"] for r in results]
        print(f"\n{'='*60}")
        print(f"  SUMMARY")
        print(f"{'='*60}")
        print(f"P50: {np.percentile(latencies, 50):.1f} ms")
        print(f"P95: {np.percentile(latencies, 95):.1f} ms\n")
        
        return {"json_path": json_path, "md_path": md_path, "results": results}
    
    def _write_markdown(self, path: str, results: List[Dict], mode: str):
        """Generate markdown report"""
        lines = [f"# {mode.title()} Benchmark Report\n"]
        
        if mode == "baseline":
            lines.append("**Pipeline**: Hybrid Search (BM25 + Vector + RRF + Rerank) -> Single LLM\n")  # ASCII arrow keeps the report encoding-safe
        else:
            lines.append("**Pipeline**: Parallel Sub-Queries -> Tool Execution -> Multi-step Reasoning\n")  # ASCII arrow keeps the report encoding-safe
        
        for i, r in enumerate(results, 1):
            lines.append(f"\n---\n\n## Q{i}. {r['query']}\n")
            lines.append(f"**Answer**\n\n{r['answer']}\n")
            
            if r.get("citations"):
                lines.append("\n**Citations**\n")
                for hit in r["citations"]:
                    pg = f"p.{hit.get('page')}" if hit.get('page') else ""
                    lines.append(f"- {hit['file']} {pg}")
            
            if r.get("execution_log"):
                lines.append("\n**Execution Log**\n```")
                lines.append(json.dumps(r["execution_log"], indent=2))
                lines.append("```")
            
            lines.append(f"\n**Latency**: {r['latency_ms']} ms")
        
        # Summary
        latencies = [r["latency_ms"] for r in results]
        lines.append("\n---\n\n## Summary\n")
        lines.append(f"- P50: {np.percentile(latencies, 50):.1f} ms")
        lines.append(f"- P95: {np.percentile(latencies, 95):.1f} ms")
        
        # FIX: Add encoding="utf-8"
        with open(path, "w", encoding="utf-8") as f:
            f.write("\n".join(lines))


# ============================================================================
# MAIN
# ============================================================================

def main():
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    
    engine = AnsweringEngine()
    benchmark = BenchmarkRunner(engine)
    
    print("\n" + "="*60)
    print("  TWO-MODE RAG SYSTEM")
    print("="*60)
    
    # Baseline
    baseline_results = benchmark.run(mode="baseline")
    
    # Agentic with parallel sub-queries
    agentic_results = benchmark.run(mode="agentic")
    
    print("\n" + "="*60)
    print("  COMPLETE")
    print("="*60)
    print(f"Baseline: {baseline_results['json_path']}")
    print(f"Agentic:  {agentic_results['json_path']}")


if __name__ == "__main__":
    main()

AnsweringEngine initialized
[BM25] ✓ Indexed 13587 documents
[BM25] ✓ Indexed 13587 documents
[Reranker] ✓ Loaded cross-encoder/ms-marco-MiniLM-L-6-v2
[Marker] ✓ Loaded 13587 chunks
[Reranker] ✓ Loaded cross-encoder/ms-marco-MiniLM-L-6-v2
[Marker] ✓ Loaded 13587 chunks
[BM25] ✓ Indexed 1623 documents
[BM25] ✓ Indexed 1623 documents
[Reranker] ✓ Loaded cross-encoder/ms-marco-MiniLM-L-6-v2
[PDFPlumber] ✓ Loaded 1623 chunks
[Agent] Tools: Calculator: ✓ | Table: ✓ | Text: ✓ | MultiDoc: ✓
[AnsweringEngine] Parallel sub-queries enabled: True

  TWO-MODE RAG SYSTEM

  BASELINE BENCHMARK


Q1. Report the Gross Margin (or Net Interest Margin, if a bank) over the last 5 quarters, with values.

[Search] RRF fusion: 96 candidates
[Rerank] Reranking top-24 candidates...
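
The baseline log above reports an RRF fusion step ahead of reranking. As a rough illustration only, reciprocal rank fusion can be sketched as below. The function name `rrf_fuse`, the chunk ids, and the constant `k=60` (a commonly used default) are assumptions for this sketch; the notebook's own `MultiIndexSearch` may fuse its BM25 and vector rankings differently.

```python
from collections import defaultdict
from typing import Dict, List


def rrf_fuse(rankings: List[List[str]], k: int = 60) -> List[str]:
    """Fuse several ranked lists: each list adds 1 / (k + rank) to a document's score."""
    scores: Dict[str, float] = defaultdict(float)
    for ranking in rankings:
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] += 1.0 / (k + rank)
    # Highest fused score first; ties are broken by id so the order is deterministic
    return sorted(scores, key=lambda d: (-scores[d], d))


# Hypothetical chunk ids standing in for BM25 and vector search results
bm25_ranking = ["c1", "c2", "c3"]
vector_ranking = ["c3", "c1", "c4"]

fused = rrf_fuse([bm25_ranking, vector_ranking])
print(fused)
```

Chunks that appear near the top of both lists rise to the front, which is why the fused candidate pool can then be cut down before the more expensive cross-encoder rerank.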

## ReAct Agent CFO - True Agentic Reasoning

This implementation uses the ReAct (Reason + Act) pattern where:
- ✅ **LLM decides the plan** (not hardcoded if-else)
- ✅ **LLM selects tools dynamically** (based on query understanding)
- ✅ **LLM reasons at each step** (thought → action → observation loop)
- ✅ **Optimized for latency** (single-turn with few-shot examples, parallel tool calls when possible)
- ✅ **Adaptive parameters** (auto-detects time periods from data)
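
To make the plan-then-execute idea concrete, here is a minimal sketch of how a JSON plan whose steps refer to earlier outputs (for example `"$step1.data"`) could be run. The helpers `resolve` and `execute_plan`, the toy tools, and all figures are hypothetical and exist only for this illustration; the agent defined in the next cell has its own implementation, which may differ.

```python
import re
from typing import Any, Callable, Dict, List

STEP_REF = re.compile(r"^\$step(\d+)\.(\w+)$")


def resolve(value: Any, results: List[Dict[str, Any]]) -> Any:
    """Replace "$stepN.field" strings with that field of the N-th earlier result."""
    if isinstance(value, str):
        m = STEP_REF.match(value)
        if not m:
            return value
        idx, field = int(m.group(1)) - 1, m.group(2)
        if not 0 <= idx < len(results):
            raise KeyError(f"{value} refers to a step that has not run yet")
        return results[idx][field]
    if isinstance(value, dict):
        return {k: resolve(v, results) for k, v in value.items()}
    if isinstance(value, list):
        return [resolve(v, results) for v in value]
    return value


def execute_plan(plan: Dict[str, Any],
                 tools: Dict[str, Callable[..., Dict[str, Any]]]) -> List[Dict[str, Any]]:
    """Run steps in order, feeding resolved inputs to the named tool."""
    results: List[Dict[str, Any]] = []
    for step in plan["steps"]:
        inputs = resolve(step["inputs"], results)
        results.append(tools[step["tool"]](**inputs))
    return results


# Toy stand-ins for the real tools (made-up numbers, illustration only)
def toy_parser(metric: str) -> Dict[str, Any]:
    table = {
        "Operating Expenses": {"2022": 40.0, "2023": 50.0},
        "Operating Income": {"2022": 100.0, "2023": 125.0},
    }
    return {"data": table[metric]}


def toy_calculator(operation: str, data: Dict[str, Any]) -> Dict[str, Any]:
    if operation != "ratio":
        raise ValueError(f"Unknown operation: {operation}")
    num, den = data["numerator"], data["denominator"]
    return {"result": {p: round(num[p] / den[p] * 100, 2) for p in num if den.get(p)}}


TOOLS = {"SmartTableParser": toy_parser, "AdvancedCalculator": toy_calculator}

plan = {
    "reasoning": "Get expenses, get income, then compute the ratio.",
    "steps": [
        {"tool": "SmartTableParser", "inputs": {"metric": "Operating Expenses"}},
        {"tool": "SmartTableParser", "inputs": {"metric": "Operating Income"}},
        {"tool": "AdvancedCalculator",
         "inputs": {"operation": "ratio",
                    "data": {"numerator": "$step1.data", "denominator": "$step2.data"}}},
    ],
}

results = execute_plan(plan, TOOLS)
print(results[-1]["result"])
```

Because the LLM emits the whole plan in one call, only the tool calls run after planning; steps that do not reference each other (the two parser calls here) are the natural candidates for parallel execution.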

In [1]:
"""
ReAct Agent CFO - True Agentic Reasoning with Latency Optimization

Key Features:
1. LLM-driven planning (not hardcoded routing)
2. Dynamic tool selection based on query understanding
3. Auto-detection of time periods and metrics
4. Single-turn optimization with structured output
5. Parallel tool execution when possible
6. Few-shot examples for consistent reasoning
"""

from __future__ import annotations
from typing import List, Dict, Any, Optional, Tuple
import pandas as pd
import json
import time
import numpy as np
import re
from dataclasses import dataclass, asdict
from concurrent.futures import ThreadPoolExecutor, as_completed


# ============================================================================
# AUTO-DETECTION UTILITIES
# ============================================================================

class DataIntrospector:
    """Automatically detect available metrics and time periods from KB"""
    
    def __init__(self, tables_df: pd.DataFrame):
        self.df = tables_df
        self._cache = {}
    
    def detect_quarters(self, n: int = 5) -> List[str]:
        """Auto-detect last N quarters from data"""
        # Cache the full sorted list and slice per call, so a request with a
        # different n is not answered with a stale result.
        if 'quarters' in self._cache:
            return self._cache['quarters'][-n:]
        
        quarter_pattern = r'\b([1-4]Q\d{2})\b'
        all_quarters = set()
        
        for col in self.df['column'].dropna():
            matches = re.findall(quarter_pattern, str(col))
            all_quarters.update(matches)
        
        def sort_key(q):
            match = re.match(r'([1-4])Q(\d{2})', q)
            if match:
                return (int(match.group(2)), int(match.group(1)))
            return (0, 0)
        
        sorted_quarters = sorted(all_quarters, key=sort_key)
        self._cache['quarters'] = sorted_quarters
        return sorted_quarters[-n:]
    
    def detect_years(self, n: int = 3) -> List[int]:
        """Auto-detect last N years from annual reports"""
        # Cache the full sorted list and slice per call, so a request with a
        # different n is not answered with a stale result.
        if 'years' in self._cache:
            return self._cache['years'][-n:]
        
        year_pattern = r'annual-report-(\d{4})'
        all_years = set()
        
        for doc in self.df['doc_name'].unique():
            match = re.search(year_pattern, str(doc))
            if match:
                all_years.add(int(match.group(1)))
        
        sorted_years = sorted(all_years)
        self._cache['years'] = sorted_years
        return sorted_years[-n:]
    
    def detect_document_patterns(self) -> Dict[str, str]:
        """Detect document naming patterns"""
        if 'doc_patterns' in self._cache:
            return self._cache['doc_patterns']
        
        patterns = {
            'cfo_quarterly': None,
            'annual_report': None,
            'company_name': None
        }
        
        sample_docs = self.df['doc_name'].unique()[:50]
        
        for doc in sample_docs:
            doc = str(doc)  # guard against NaN or other non-string document names
            if 'CFO' in doc and 'Q' in doc:
                # Extract pattern: 2Q24_CFO_presentation -> {quarter}_CFO_presentation
                patterns['cfo_quarterly'] = '{period}_CFO_presentation'
            
            if 'annual-report' in doc:
                # Extract: dbs-annual-report-2024 -> {company}-annual-report-{year}
                match = re.match(r'([a-z]+)-annual-report-\d{4}', doc)
                if match:
                    patterns['company_name'] = match.group(1)
                    patterns['annual_report'] = f'{match.group(1)}-annual-report-' + '{year}'
        
        self._cache['doc_patterns'] = patterns
        return patterns
    
    def suggest_metric_keywords(self, metric_name: str) -> List[str]:
        """Suggest keywords for a metric based on data"""
        metric_name_lower = metric_name.lower()
        
        # Common financial metric patterns
        keyword_map = {
            'nim': ['Group NIM (%)', 'Commercial NIM (%)', 'Net Interest Margin', 'NIM'],
            'net interest margin': ['Group NIM (%)', 'Commercial NIM (%)', 'Net Interest Margin', 'NIM'],
            'gross margin': ['Group NIM (%)', 'Gross Margin'],
            'income': ['Total income', 'Operating income', 'Net income'],
            'expense': ['Total expenses', 'Operating expenses', 'Opex'],
            'revenue': ['Total revenue', 'Revenue', 'Total income'],
            'profit': ['Profit', 'Net profit', 'Profit before tax']
        }
        
        for key, keywords in keyword_map.items():
            if key in metric_name_lower:
                return keywords
        
        return [metric_name]


# ============================================================================
# TOOLS WITH AUTO-DETECTION
# ============================================================================

@dataclass
class ToolCall:
    """Records a tool call"""
    tool_name: str
    inputs: Dict[str, Any]
    outputs: Dict[str, Any]
    latency_ms: float
    success: bool = True
    error: Optional[str] = None


class SmartTableParser:
    """
    Intelligent table parser with proven extraction logic from Agent CFO
    """
    
    def __init__(self, tables_df: pd.DataFrame, introspector: DataIntrospector):
        self.df = tables_df
        self.introspector = introspector
        self.name = "SmartTableParser"
    
    def parse(self, metric: str, periods: Optional[List[str]] = None, 
              doc_pattern: Optional[str] = None) -> Dict[str, Any]:
        """
        Parse financial metric using PROVEN extraction logic from Agent CFO
        
        Args:
            metric: Natural language metric name (e.g., "Net Interest Margin", "Operating Expenses")
            periods: Optional list of periods. If None, auto-detects
            doc_pattern: Optional document pattern. If None, uses 'dbs-annual-report'
        """
        start_time = time.time()
        
        try:
            # Auto-detect if not provided
            if periods is None:
                # Detect if quarterly or annual based on metric
                if any(q in metric.lower() for q in ['nim', 'margin', 'quarterly']):
                    periods = self.introspector.detect_quarters()
                else:
                    periods = [str(y) for y in self.introspector.detect_years()]
            
            # Get suggested keywords
            keywords = self.introspector.suggest_metric_keywords(metric)
            
            # Default doc pattern
            if doc_pattern is None:
                doc_pattern = 'dbs-annual-report'
            
            # Extract data using PROVEN Agent CFO logic
            results = {}
            sources = []
            
            for period in periods:
                # For NIM (quarterly data from CFO presentations)
                if 'Q' in period and len(period) <= 4:
                    nim_rows = self.df[
                        (self.df['doc_name'].str.contains(f"{period}_CFO_presentation", na=False)) &
                        (self.df['column'].str.contains('Group NIM', case=False, na=False))
                    ]
                    
                    if not nim_rows.empty:
                        tid = nim_rows['table_id'].iloc[0]
                        table_data = self.df[
                            (self.df['doc_name'].str.contains(f"{period}_CFO_presentation", na=False)) &
                            (self.df['table_id'] == tid)
                        ]
                        
                        for row_id in table_data['row_id'].unique():
                            row = table_data[table_data['row_id'] == row_id]
                            quarter_cells = row[row['column'].str.contains('Quarter', case=False, na=False)]
                            if not quarter_cells.empty:
                                quarter_val = quarter_cells.iloc[0]['value_str']
                                if period in str(quarter_val):
                                    nim_cells = row[row['column'].str.contains('Group NIM', case=False, na=False)]
                                    if not nim_cells.empty and pd.notna(nim_cells.iloc[0]['value_num']):
                                        results[period] = float(nim_cells.iloc[0]['value_num'])
                                        sources.append({
                                            'file': nim_cells.iloc[0]['doc_name'],
                                            'page': int(nim_cells.iloc[0]['page']) if pd.notna(nim_cells.iloc[0]['page']) else None,
                                            'table_id': int(tid)
                                        })
                                        break
                
                # For annual data (years)
                else:
                    metric_rows = self.df[
                        (self.df['doc_name'].str.contains(f'{doc_pattern}-{period}', na=False)) &
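                        # Escape keywords so regex metacharacters (e.g. the parentheses in "Group NIM (%)") match literally
                        (self.df['value_str'].str.contains('|'.join(re.escape(k) for k in keywords), case=False, na=False, regex=True))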
                    ]
                    
                    if not metric_rows.empty:
                        for _, row in metric_rows.iterrows():
                            table_data = self.df[
                                (self.df['doc_name'] == row['doc_name']) &
                                (self.df['table_id'] == row['table_id']) &
                                (self.df['row_id'] == row['row_id'])
                            ]
                            
                            # For income: prioritize columns with year or "Total"
                            if 'income' in '|'.join(keywords).lower():
                                candidates = []
                                for _, cell in table_data.iterrows():
                                    col_name = str(cell['column']).lower()
                                    if pd.notna(cell['value_num']) and cell['value_num'] > 10000:
                                        if period in col_name or 'total' in col_name:
                                            candidates.append((3, cell['value_num'], cell))
                                        else:
                                            candidates.append((1, cell['value_num'], cell))
                                
                                if candidates:
                                    candidates.sort(key=lambda x: (x[0], x[1]), reverse=True)
                                    results[period] = float(candidates[0][1])
                                    sources.append({
                                        'file': candidates[0][2]['doc_name'],
                                        'page': int(candidates[0][2]['page']) if pd.notna(candidates[0][2]['page']) else None,
                                        'table_id': int(row['table_id'])
                                    })
                                    break
                            
                            # For expenses: just take first numeric value > 1000
                            else:
                                nums = table_data[table_data['value_num'].notna() & (table_data['value_num'] > 1000)]
                                if not nums.empty:
                                    results[period] = float(nums.iloc[0]['value_num'])
                                    sources.append({
                                        'file': nums.iloc[0]['doc_name'],
                                        'page': int(nums.iloc[0]['page']) if pd.notna(nums.iloc[0]['page']) else None,
                                        'table_id': int(row['table_id'])
                                    })
                                    break
            
            latency_ms = (time.time() - start_time) * 1000
            
            return {
                'data': results,
                'sources': sources,
                'latency_ms': round(latency_ms, 2),
                'periods': periods,
                'keywords_used': keywords
            }
        
        except Exception as e:
            latency_ms = (time.time() - start_time) * 1000
            return {
                'data': {},
                'sources': [],
                'latency_ms': round(latency_ms, 2),
                'error': str(e)
            }



class AdvancedCalculator:
    """Calculator with common financial computations"""
    
    def __init__(self):
        self.name = "AdvancedCalculator"
    
    def compute(self, operation: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Perform calculation
        
        Args:
            operation: One of ['ratio', 'yoy_change', 'average', 'growth_rate']
            data: Dictionary with required inputs for the operation
        """
        start_time = time.time()
        
        try:
            if operation == 'ratio':
                result = self._compute_ratio(data['numerator'], data['denominator'])
            elif operation == 'yoy_change':
                result = self._compute_yoy(data['values'])
            elif operation == 'average':
                result = self._compute_average(data['values'])
            elif operation == 'growth_rate':
                result = self._compute_growth_rate(data['values'])
            else:
                raise ValueError(f"Unknown operation: {operation}")
            
            latency_ms = (time.time() - start_time) * 1000
            return {
                'result': result,
                'latency_ms': round(latency_ms, 2)
            }
        
        except Exception as e:
            latency_ms = (time.time() - start_time) * 1000
            return {
                'result': None,
                'latency_ms': round(latency_ms, 2),
                'error': str(e)
            }
    
    def _compute_ratio(self, numerator: Dict[str, float], 
                       denominator: Dict[str, float]) -> Dict[str, float]:
        """Compute numerator / denominator per period, as a percentage (zero denominators are skipped)"""
        result = {}
        for period in numerator.keys():
            if period in denominator and denominator[period] != 0:
                result[period] = round((numerator[period] / denominator[period]) * 100, 2)
        return result
    
    def _compute_yoy(self, values: Dict[str, float]) -> Dict[str, float]:
        """Compute year-over-year changes"""
        sorted_periods = sorted(values.keys())
        result = {}
        
        for i in range(1, len(sorted_periods)):
            prev = values[sorted_periods[i-1]]
            curr = values[sorted_periods[i]]
            if prev == 0:
                continue  # percentage change is undefined for a zero base
            change = ((curr - prev) / prev) * 100
            result[f"{sorted_periods[i-1]}→{sorted_periods[i]}"] = round(change, 2)
        
        return result
    
    def _compute_average(self, values: Dict[str, float]) -> float:
        """Compute average"""
        return round(sum(values.values()) / len(values), 2)
    
    def _compute_growth_rate(self, values: Dict[str, float]) -> float:
        """Compute CAGR"""
        sorted_periods = sorted(values.keys())
        start_val = values[sorted_periods[0]]
        end_val = values[sorted_periods[-1]]
        n = len(sorted_periods) - 1
        
        if n > 0 and start_val > 0:
            cagr = ((end_val / start_val) ** (1/n) - 1) * 100
            return round(cagr, 2)
        return 0.0


class SmartTrendAnalyzer:
    """Analyze patterns in financial data"""
    
    def __init__(self):
        self.name = "SmartTrendAnalyzer"
    
    def analyze(self, values: Dict[str, float]) -> Dict[str, Any]:
        """Analyze trend pattern"""
        start_time = time.time()
        
        if len(values) < 2:
            return {
                'pattern': 'Insufficient Data',
                'latency_ms': round((time.time() - start_time) * 1000, 2)
            }
        
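        # Order periods chronologically. Quarter labels such as "4Q24" are keyed by
        # (year, quarter) because plain string sorting would put "1Q25" before "4Q24".
        def period_key(p):
            m = re.fullmatch(r'([1-4])Q(\d{2})', str(p))
            return (0, int(m.group(2)), int(m.group(1)), '') if m else (1, 0, 0, str(p))
        
        sorted_periods = sorted(values.keys(), key=period_key)
        sorted_values = [values[p] for p in sorted_periods]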
        
        # Detect pattern
        increasing = all(sorted_values[i] <= sorted_values[i+1] for i in range(len(sorted_values)-1))
        decreasing = all(sorted_values[i] >= sorted_values[i+1] for i in range(len(sorted_values)-1))
        
        if increasing:
            pattern = "Consistently Increasing"
        elif decreasing:
            pattern = "Consistently Decreasing"
        else:
            pattern = "Fluctuating"
        
        latency_ms = (time.time() - start_time) * 1000
        
        return {
            'pattern': pattern,
            'min': round(min(sorted_values), 2),
            'max': round(max(sorted_values), 2),
            'avg': round(sum(sorted_values) / len(sorted_values), 2),
            'range': round(max(sorted_values) - min(sorted_values), 2),
            'latency_ms': round(latency_ms, 2)
        }


# ============================================================================
# REACT AGENT WITH STRUCTURED OUTPUT OPTIMIZATION
# ============================================================================

class ReActAgentCFO:
    """
    ReAct-based Agent with single-turn optimization
    
    Uses few-shot examples to guide LLM to produce complete plan in one call,
    then executes tools. This balances true agentic reasoning with latency.
    """
    
    def __init__(self, tables_df: pd.DataFrame, llm_client_tuple: Tuple):
        self.provider, self.client, self.model = llm_client_tuple
        
        # Initialize introspection
        self.introspector = DataIntrospector(tables_df)
        
        # Initialize tools
        self.parser = SmartTableParser(tables_df, self.introspector)
        self.calculator = AdvancedCalculator()
        self.analyzer = SmartTrendAnalyzer()
        
        self.tools = {
            'SmartTableParser': self.parser,
            'AdvancedCalculator': self.calculator,
            'SmartTrendAnalyzer': self.analyzer
        }
        
        self.tool_calls = []
    
    def run(self, query: str) -> Dict[str, Any]:
        """Execute query with ReAct reasoning"""
        start_time = time.time()
        
        # Step 1: LLM generates execution plan
        plan = self._generate_plan(query)
        
        if 'error' in plan:
            return {
                'answer': f"Error: {plan['error']}",
                'latency_ms': round((time.time() - start_time) * 1000, 2),
                'tool_calls': []
            }
        
        # Step 2: Execute tools according to plan
        execution_results = self._execute_plan(plan)
        
        # Step 3: LLM generates final answer from results
        answer = self._generate_answer(query, plan, execution_results)
        
        total_latency = (time.time() - start_time) * 1000
        
        return {
            'answer': answer,
            'plan': plan,
            'tool_calls': self.tool_calls,
            'latency_ms': round(total_latency, 2)
        }
    
    def _generate_plan(self, query: str) -> Dict[str, Any]:
        """LLM generates execution plan using few-shot examples"""
        
        # Get available data context
        quarters = self.introspector.detect_quarters()
        years = self.introspector.detect_years()
        doc_patterns = self.introspector.detect_document_patterns()
        
        system_prompt = f"""You are a financial analysis planning agent. Analyze the query and create an execution plan.

        Available Tools:
        1. SmartTableParser: Extract metrics from financial documents
        - Automatically detects time periods and document patterns
        - Input: {{"metric": "metric name"}}
        - Returns: {{"data": {{"period": value}}, "sources": [...]}}

        2. AdvancedCalculator: Perform calculations
        - Operations: ratio, yoy_change, average, growth_rate
        - Input: {{"operation": "type", "data": {{...}}}}
        - Returns: {{"result": {{...}}}}

        3. SmartTrendAnalyzer: Analyze patterns
        - Input: {{"values": {{"period": value}}}}
        - Returns: {{"pattern": "...", "min": x, "max": y, "avg": z}}

        Context:
        - Available quarters: {quarters}
        - Available years: {years}
        - Company: {doc_patterns.get('company_name', 'unknown')}

        Output a JSON plan with:
        {{
        "reasoning": "step-by-step thought process",
        "steps": [
            {{"tool": "ToolName", "inputs": {{...}}, "purpose": "why this step"}},
            ...
        ]
        }}
        """

        few_shot_examples = """
        Examples:

        Query: "What is the Net Interest Margin for the last 5 quarters?"
        Plan:
        {
        "reasoning": "Query asks for NIM over 5 quarters. I need to: (1) Extract NIM data using SmartTableParser (it will auto-detect quarters), (2) Analyze the trend using SmartTrendAnalyzer.",
        "steps": [
            {"tool": "SmartTableParser", "inputs": {"metric": "Net Interest Margin"}, "purpose": "Extract NIM values for recent quarters"},
            {"tool": "SmartTrendAnalyzer", "inputs": {"values": "$step1.data"}, "purpose": "Identify pattern in NIM"}
        ]
        }

        Query: "Calculate Operating Efficiency Ratio for the last 3 years"
        Plan:
        {
        "reasoning": "Efficiency ratio = Operating Expenses ÷ Operating Income. I need to: (1) Get expenses, (2) Get income, (3) Calculate ratio, (4) Analyze trend.",
        "steps": [
            {"tool": "SmartTableParser", "inputs": {"metric": "Operating Expenses"}, "purpose": "Extract expenses for 3 years"},
            {"tool": "SmartTableParser", "inputs": {"metric": "Operating Income"}, "purpose": "Extract income for 3 years"},
            {"tool": "AdvancedCalculator", "inputs": {"operation": "ratio", "data": {"numerator": "$step1.data", "denominator": "$step2.data"}}, "purpose": "Compute efficiency ratio"},
            {"tool": "SmartTrendAnalyzer", "inputs": {"values": "$step3.result"}, "purpose": "Analyze ratio trend"}
        ]
        }

        Query: "Show Operating Expenses year-over-year for 3 years"
        Plan:
        {
        "reasoning": "Need expenses and YoY changes. Steps: (1) Extract expenses, (2) Calculate YoY changes.",
        "steps": [
            {"tool": "SmartTableParser", "inputs": {"metric": "Operating Expenses"}, "purpose": "Get expenses for 3 years"},
            {"tool": "AdvancedCalculator", "inputs": {"operation": "yoy_change", "data": {"values": "$step1.data"}}, "purpose": "Calculate year-over-year changes"}
        ]
        }
        """

        user_message = f"{few_shot_examples}\n\nNow plan for this query:\nQuery: \"{query}\"\nPlan:"
        
        try:
            if self.provider == 'groq':
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": user_message}
                    ],
                    temperature=0.1,
                    max_tokens=1000
                )
                plan_text = response.choices[0].message.content.strip()
            else:  # gemini
                chat = self.client.start_chat(history=[])
                response = chat.send_message(f"{system_prompt}\n\n{user_message}")
                plan_text = response.text.strip()
            
            # Parse JSON from response
            if '```json' in plan_text:
                plan_text = plan_text.split('```json')[1].split('```')[0].strip()
            elif '```' in plan_text:
                plan_text = plan_text.split('```')[1].split('```')[0].strip()
            
            plan = json.loads(plan_text)
            return plan
        
        except Exception as e:
            return {'error': f"Plan generation failed: {str(e)}"}
    
    def _execute_plan(self, plan: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Execute the planned steps"""
        results = []
        self.tool_calls = []
        
        for i, step in enumerate(plan.get('steps', [])):
            tool_name = step['tool']
            inputs = step['inputs']
            
            # Resolve references to previous steps (e.g., "$step1.data")
            inputs = self._resolve_references(inputs, results)
            
            # Execute tool
            tool_result = self._execute_tool(tool_name, inputs)
            
            # Record tool call
            self.tool_calls.append(ToolCall(
                tool_name=tool_name,
                inputs=inputs,
                outputs=tool_result,
                latency_ms=tool_result.get('latency_ms', 0)
            ))
            
            results.append(tool_result)
        
        return results
    
    def _resolve_references(self, inputs: Dict[str, Any], 
                           results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Resolve references like $step1.data to actual values"""
        resolved = {}
        
        for key, value in inputs.items():
            if isinstance(value, str) and value.startswith('$step'):
                # Parse reference: $step1.data
                match = re.match(r'\$step(\d+)\.(\w+)', value)
                if match:
                    step_idx = int(match.group(1)) - 1
                    field = match.group(2)
                    
                    if 0 <= step_idx < len(results):
                        resolved[key] = results[step_idx].get(field, {})
                    else:
                        resolved[key] = {}
                else:
                    resolved[key] = value
            elif isinstance(value, dict):
                resolved[key] = self._resolve_references(value, results)
            else:
                resolved[key] = value
        
        return resolved
    
    def _execute_tool(self, tool_name: str, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Execute a single tool"""
        try:
            if tool_name == 'SmartTableParser':
                return self.parser.parse(**inputs)
            elif tool_name == 'AdvancedCalculator':
                return self.calculator.compute(**inputs)
            elif tool_name == 'SmartTrendAnalyzer':
                return self.analyzer.analyze(**inputs)
            else:
                return {'error': f'Unknown tool: {tool_name}'}
        except Exception as e:
            return {'error': str(e)}
    
    def _generate_answer(self, query: str, plan: Dict[str, Any], 
                        results: List[Dict[str, Any]]) -> str:
        """LLM generates final answer from execution results"""
        
        system_prompt = """You are a financial analyst. Generate a clear, professional answer to the query using the execution results.

                        Include:
                        1. Direct answer to the question
                        2. Data in table format if applicable
                        3. Key insights or trends
                        4. Citations with page numbers

                        Format citations as: [doc_name p.X]
                        """

        # Compile results summary
        results_summary = []
        for i, (step, result) in enumerate(zip(plan['steps'], results)):
            summary = f"Step {i+1} ({step['tool']}): "
            if 'data' in result:
                summary += f"Extracted {len(result['data'])} values"
            elif 'result' in result:
                summary += f"Computed {len(result['result'])} values" if isinstance(result['result'], dict) else "Computed result"
            elif 'pattern' in result:
                summary += f"Pattern: {result['pattern']}"
            
            results_summary.append(summary)
            results_summary.append(f"  Output: {json.dumps(result, indent=2)}")
        
        user_message = f"""Query: {query}

                            Plan: {plan['reasoning']}

                            Execution Results:
                            {chr(10).join(results_summary)}

                            Generate a professional answer with tables and citations."""

        try:
            if self.provider == 'groq':
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": user_message}
                    ],
                    temperature=0.3,
                    max_tokens=2000
                )
                return response.choices[0].message.content.strip()
            else:  # gemini
                chat = self.client.start_chat(history=[])
                response = chat.send_message(f"{system_prompt}\n\n{user_message}")
                return response.text.strip()
        
        except Exception as e:
            # Fallback: generate answer from results directly
            return self._fallback_answer(query, results)
    
    def _fallback_answer(self, query: str, results: List[Dict[str, Any]]) -> str:
        """Fallback answer generation if LLM fails"""
        lines = [f"Query: {query}\n"]
        
        for i, result in enumerate(results):
            if 'data' in result and result['data']:
                lines.append(f"\nData (Step {i+1}):")
                for period, value in result['data'].items():
                    lines.append(f"  {period}: {value}")
                
                if 'sources' in result:
                    lines.append("\nCitations:")
                    for src in result['sources'][:3]:
                        lines.append(f"  [{src['file']} p.{src['page']}]")
        
        return "\n".join(lines)


# ============================================================================
# BENCHMARK RUNNER
# ============================================================================

# Initialize
kb = g2x.KBEnv()
llm_tuple = g2x._make_llm_client()

print("[ReAct Agent CFO] Initializing...")
react_agent = ReActAgentCFO(kb.tables_df, llm_tuple)
print(f"[ReAct Agent CFO] Ready")
print(f"  - LLM: {llm_tuple[0]}")
print(f"  - Auto-detected: {react_agent.introspector.detect_quarters(5)} quarters")
print(f"  - Auto-detected: {react_agent.introspector.detect_years(3)} years")

# Run benchmark
queries = [
    "Report the Gross Margin (or Net Interest Margin, if a bank) over the last 5 quarters, with values.",
    "Show Operating Expenses for the last 3 fiscal years, year-on-year comparison.",
    "Calculate the Operating Efficiency Ratio (Opex ÷ Operating Income) for the last 3 fiscal years, showing the working."
]

print("\n" + "=" * 60)
print("  REACT AGENT CFO BENCHMARK")
print("=" * 60)

results_json = []
latencies = []

for i, query in enumerate(queries, 1):
    print(f"\n{'=' * 60}")
    print(f"Q{i}. {query}")
    print("=" * 60)
    
    result = react_agent.run(query)
    latencies.append(result['latency_ms'])
    
    # Display reasoning
    if 'plan' in result and 'reasoning' in result['plan']:
        print(f"\n[LLM Reasoning] {result['plan']['reasoning']}\n")
    
    # Display tool calls
    if result['tool_calls']:
        print(f"[Tool Execution] {len(result['tool_calls'])} tools called:")
        for tc in result['tool_calls']:
            print(f"  - {tc.tool_name}: {tc.latency_ms:.2f} ms")
        print()
    
    # Display answer
    print(result['answer'])
    print(f"\n(Total Latency: {result['latency_ms']:.2f} ms)")
    
    # Record for JSON
    results_json.append({
        'query_id': f'Q{i}',
        'query': query,
        'answer': result['answer'],
        'plan_reasoning': result.get('plan', {}).get('reasoning', ''),
        'tool_calls': [
            {
                'tool': tc.tool_name,
                'latency_ms': tc.latency_ms
            }
            for tc in result['tool_calls']
        ],
        'latency_ms': result['latency_ms']
    })

# Summary
p50 = np.percentile(latencies, 50)
p95 = np.percentile(latencies, 95)

print("\n" + "=" * 60)
print("  SUMMARY")
print("=" * 60)
print(f"P50: {p50:.1f} ms")
print(f"P95: {p95:.1f} ms")
print(f"Approach: ReAct (single-turn optimized)")
print(f"LLM Calls: 2 per query (plan + answer)")

# Save results
output_path = "./data_marker/bench_react_agent_cfo.json"
with open(output_path, 'w') as f:
    json.dump({
        "system": "ReAct Agent CFO",
        "approach": "LLM-driven planning with auto-detection",
        "latency": {
            "p50_ms": round(p50, 2),
            "p95_ms": round(p95, 2)
        },
        "results": results_json
    }, f, indent=2)

print("\n" + "=" * 60)
print("  COMPLETE")
print("=" * 60)
print(f"ReAct Agent Results: {output_path}")
print("=" * 60)
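
The `$stepN.field` references in the planner's few-shot examples are resolved against earlier tool outputs before each tool runs. The following is a minimal standalone sketch of that lookup, modelled on `_resolve_references` above. It uses `fullmatch`, which is slightly stricter than the `re.match` in the class, and the `results` list is made-up illustration data, not output from the real tools.

```python
import re
from typing import Any, Dict, List

STEP_REF = re.compile(r"\$step(\d+)\.(\w+)")

def resolve_references(inputs: Dict[str, Any], results: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Replace "$stepN.field" strings with results[N-1][field]; recurse into dicts."""
    resolved: Dict[str, Any] = {}
    for key, value in inputs.items():
        if isinstance(value, dict):
            resolved[key] = resolve_references(value, results)
            continue
        match = STEP_REF.fullmatch(value) if isinstance(value, str) else None
        if match is None:
            resolved[key] = value  # literal input, passed through unchanged
            continue
        step_idx = int(match.group(1)) - 1  # plan steps are 1-based
        field = match.group(2)
        in_range = 0 <= step_idx < len(results)
        resolved[key] = results[step_idx].get(field, {}) if in_range else {}
    return resolved

# Hypothetical outputs of two earlier parser steps (illustrative numbers only).
results = [
    {"data": {"2023": 8.0, "2024": 9.0}},
    {"data": {"2023": 20.0, "2024": 22.0}},
]
calc_inputs = {
    "operation": "ratio",
    "data": {"numerator": "$step1.data", "denominator": "$step2.data"},
}
resolved = resolve_references(calc_inputs, results)
print(resolved["data"]["numerator"])
```

A reference to a step that has not run yet (for example `$step9.data`) resolves to an empty dict, so the downstream tool receives no data instead of raising.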


## 6. Instrumentation

Log timings: T_ingest, T_retrieve, T_rerank, T_reason, T_generate, T_total. Log tokens, cache hits, tools.

In [None]:
# Example instrumentation schema
import pandas as pd
logs = pd.DataFrame(columns=['Query','T_ingest','T_retrieve','T_rerank','T_reason','T_generate','T_total','Tokens','CacheHits','Tools'])
logs
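
One way to fill the timing columns is a small context manager wrapped around each pipeline stage. This is a sketch only: the helper name `timed` and the `time.sleep` stand-ins are assumptions for illustration, not part of the starter code.

```python
import time
from contextlib import contextmanager
from typing import Dict, Iterator

@contextmanager
def timed(record: Dict[str, float], stage: str) -> Iterator[None]:
    """Add the wall-clock duration of the enclosed block, in ms, to record[stage]."""
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed_ms = (time.perf_counter() - start) * 1000.0
        record[stage] = record.get(stage, 0.0) + elapsed_ms

row: Dict[str, float] = {}
with timed(row, "T_total"):
    with timed(row, "T_retrieve"):
        time.sleep(0.01)  # stand-in for the retrieval call
    with timed(row, "T_generate"):
        time.sleep(0.02)  # stand-in for the LLM call

print({stage: round(ms, 1) for stage, ms in row.items()})
```

Each finished `row` can then be appended to the `logs` DataFrame together with the query text, token counts, cache hits, and tool names.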

## 7. Optimizations

**Required Optimizations**

Each team must implement at least:
*   2 retrieval optimizations (e.g., hybrid BM25+vector, smaller embeddings, dynamic k).
*   1 caching optimization (query cache or ratio cache).
*   1 agentic optimization (plan pruning, parallel sub-queries).
*   1 system optimization (async I/O, batch embedding, memory-mapped vectors).

###  7.1 Hybrid BM25 + vector + RRF

In [None]:
class KBEnv:
    def __init__(self, base="./data_marker", enable_bm25=True):
        self.base = Path(base)
        self.faiss_path = self.base / "kb_index.faiss"
        self.meta_path = self.base / "kb_index_meta.json"
        self.texts_path = self.base / "kb_texts.npy"
        self.chunks_path = self.base / "kb_chunks.parquet"
        self.tables_path = self.base / "kb_tables.parquet"
        self.outline_path = self.base / "kb_outline.parquet"

        if not self.faiss_path.exists():
            raise FileNotFoundError(self.faiss_path)
        if not self.meta_path.exists():
            raise FileNotFoundError(self.meta_path)
        if not self.texts_path.exists():
            raise FileNotFoundError(self.texts_path)
        if not self.chunks_path.exists():
            raise FileNotFoundError(self.chunks_path)

        self.texts: List[str] = np.load(self.texts_path, allow_pickle=True).tolist()
        self.meta_df: pd.DataFrame = pd.read_parquet(self.chunks_path)
        
        if 'page' in self.meta_df.columns:
            self.meta_df['page'] = pd.to_numeric(self.meta_df['page'], errors='coerce').astype('Int64')
            
        if len(self.texts) != len(self.meta_df):
            raise ValueError(f"texts ({len(self.texts)}) and meta ({len(self.meta_df)}) mismatch")

        self.tables_df: Optional[pd.DataFrame] = (
            pd.read_parquet(self.tables_path) if self.tables_path.exists() else None
        )
        self.outline_df: Optional[pd.DataFrame] = (
            pd.read_parquet(self.outline_path) if self.outline_path.exists() else None
        )

        # FAISS index
        self.index = faiss.read_index(str(self.faiss_path))
        idx_meta = json.loads(self.meta_path.read_text(encoding="utf-8"))
        self.model_name = idx_meta.get("model", "sentence-transformers/all-MiniLM-L6-v2")
        self.embed_dim = int(idx_meta.get("dim", 384))
        self.model = SentenceTransformer(self.model_name)

        # ========== NEW: BM25 Index ==========
        self.bm25 = None
        # BM25Okapi comes from the optional rank_bm25 package; check that the import succeeded.
        bm25_available = globals().get("BM25Okapi") is not None
        if enable_bm25 and bm25_available:
            tokenized_corpus = [text.lower().split() for text in self.texts]
            self.bm25 = BM25Okapi(tokenized_corpus)
            print(f"[BM25] ✓ Indexed {len(self.texts)} documents")
        elif enable_bm25:
            print("[BM25] ✗ rank_bm25 not installed, skipping BM25")

    def _embed(self, texts: List[str]) -> np.ndarray:
        v = self.model.encode(texts, normalize_embeddings=True)
        return np.asarray(v, dtype="float32")

    # ========== NEW: Hybrid Search with BM25 + Vector + RRF ==========
    def search(
        self,
        query: str,
        k: int = 12,
        alpha: float = 0.5,  # Weighted-fusion fallback only: 0.0 = pure BM25, 1.0 = pure vector
    ) -> pd.DataFrame:
        """
        Hybrid search with BM25 + vector + RRF fusion.

        Pipeline:
        1. Vector search → scores for the top 2k candidates
        2. BM25 search → normalized scores for the top 2k candidates
        3. Fusion: RRF (reciprocal rank) when BM25 is available, else weighted score fusion
        4. Return top-k
        """

        # ========== Step 1: Vector Search ==========
        qv = self._embed([query])
        vec_scores, vec_idxs = self.index.search(qv, min(k * 2, len(self.texts)))
        vec_idxs, vec_scores = vec_idxs[0], vec_scores[0]
        
        # Filter valid indices
        vec_results = {int(i): float(s) for i, s in zip(vec_idxs, vec_scores) if i >= 0 and i < len(self.texts)}

        # ========== Step 2: BM25 Search ==========
        bm25_results = {}
        if self.bm25 is not None:
            query_tokens = query.lower().split()
            bm25_scores = self.bm25.get_scores(query_tokens)
            
            # Normalize BM25 scores to [0, 1]
            max_bm25 = max(bm25_scores) if len(bm25_scores) > 0 else 1.0
            if max_bm25 > 0:
                bm25_scores = bm25_scores / max_bm25
            
            # Get top candidates
            top_bm25_idx = np.argsort(bm25_scores)[-k * 2:][::-1]
            bm25_results = {int(i): float(bm25_scores[i]) for i in top_bm25_idx if bm25_scores[i] > 0}

        # ========== Step 3: Fusion (RRF or Weighted Score) ==========
        all_indices = set(vec_results.keys()) | set(bm25_results.keys())
        
        if self.bm25 is not None:
            # Reciprocal Rank Fusion
            vec_ranks = {idx: rank for rank, idx in enumerate(sorted(vec_results, key=vec_results.get, reverse=True), 1)}
            bm25_ranks = {idx: rank for rank, idx in enumerate(sorted(bm25_results, key=bm25_results.get, reverse=True), 1)}
            
            k_rrf = 60  # RRF constant
            fused_scores = {}
            for idx in all_indices:
                vec_rank = vec_ranks.get(idx, len(self.texts))
                bm25_rank = bm25_ranks.get(idx, len(self.texts))
                fused_scores[idx] = (1 / (k_rrf + vec_rank)) + (1 / (k_rrf + bm25_rank))
            
            # print(f"[Search] RRF fusion: {len(all_indices)} candidates")
        else:
            # Weighted score fusion (fallback when BM25 is disabled)
            fused_scores = {}
            for idx in all_indices:
                vec_score = vec_results.get(idx, 0.0)
                bm25_score = bm25_results.get(idx, 0.0)
                fused_scores[idx] = alpha * vec_score + (1 - alpha) * bm25_score
            
            print(f"[Search] Weighted fusion (α={alpha}): {len(all_indices)} candidates")

        # ========== Step 4: Build Results DataFrame ==========
        # Sort by fused score and keep the top-k results (no reranking)
        final_indices = sorted(fused_scores, key=fused_scores.get, reverse=True)[:k]
        rows = []
        for rank, idx in enumerate(final_indices, start=1):
            md = self.meta_df.iloc[idx]
            item = {
                "rank": rank,
                "score": fused_scores[idx],
                "text": self.texts[idx],
                "doc": md.get("doc"),
                "path": md.get("path"),
                "modality": md.get("modality"),
                "chunk": int(md.get("chunk", 0)),
                "page": _page_or_none(md.get("page")),
            }

            # print(item)
            
            # Section hint
            if self.outline_df is not None:
                toc = self.outline_df[self.outline_df["doc_name"] == item["doc"]]
                if not toc.empty:
                    item["section_hint"] = toc.iloc[0]["title"]
            
            rows.append(item)
        
        return pd.DataFrame(rows)
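
The fusion step can be checked by hand on a toy example. The following is a minimal sketch of reciprocal rank fusion in its common form, where a document missing from one ranking contributes nothing for that ranking. That differs slightly from `KBEnv.search` above, which assigns missing documents the rank `len(self.texts)`. The function name `rrf_fuse` and the chunk ids are hypothetical.

```python
from typing import Dict, Hashable, List

def rrf_fuse(rankings: List[List[Hashable]], k_rrf: int = 60) -> Dict[Hashable, float]:
    """Sum 1 / (k_rrf + rank) over every ranking in which a document appears."""
    scores: Dict[Hashable, float] = {}
    for ranking in rankings:
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k_rrf + rank)
    return scores

vector_ranking = ["c7", "c2", "c9"]  # best first, from the vector index
bm25_ranking = ["c2", "c5", "c7"]    # best first, from BM25

fused = rrf_fuse([vector_ranking, bm25_ranking])
order = sorted(fused, key=fused.get, reverse=True)
print(order)  # ['c2', 'c7', 'c5', 'c9']
```

`c2` wins because it is ranked 2nd and 1st (1/62 + 1/61), narrowly ahead of `c7` at 1st and 3rd (1/61 + 1/63). Only ranks matter, so the raw BM25 and vector scores never need to share a scale.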

### 7.2 Metadata Filtering/Boosting

Run the cell below to enable the metadata filtering/boosting optimization, then re-run the Benchmark Runner to see the change in results.
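
The setup cell that follows describes a recency decay of 5% per quarter. The modules that apply it (`metadata_enhancer`, `kb_metadata_extension`) are not shown in this notebook, so the sketch below is only an assumption about what such a decay could look like: a multiplicative penalty of `(1 - 0.05) ** age_in_quarters`. Both function names are hypothetical.

```python
from typing import Tuple

Period = Tuple[int, int]  # (year, quarter), e.g. (2024, 3) for 3Q24

def quarters_between(newer: Period, older: Period) -> int:
    """Number of quarters from `older` to `newer`."""
    return (newer[0] - older[0]) * 4 + (newer[1] - older[1])

def apply_recency_decay(score: float, chunk_period: Period, latest_period: Period,
                        decay_per_quarter: float = 0.05) -> float:
    """Assumed form: multiply the score by (1 - decay) once per quarter of age."""
    age = max(0, quarters_between(latest_period, chunk_period))
    return score * (1.0 - decay_per_quarter) ** age

latest = (2024, 3)
same_quarter = apply_recency_decay(1.0, (2024, 3), latest)
two_quarters_old = apply_recency_decay(1.0, (2024, 1), latest)
one_year_old = apply_recency_decay(1.0, (2023, 3), latest)
print(same_quarter, round(two_quarters_old, 4), round(one_year_old, 4))
```

Under this assumption a chunk from a year earlier keeps roughly 81% of its score, so older filings are demoted without being filtered out.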

In [2]:
# ============================================================================
# METADATA ENHANCEMENT SETUP (Run this cell once)
# ============================================================================

# Step 0: Reload modules to pick up latest changes
import sys
import importlib

# Remove old modules from cache
for mod in list(sys.modules.keys()):
    if 'metadata_enhancer' in mod or 'kb_metadata_extension' in mod:
        del sys.modules[mod]

print("=" * 80)
print("METADATA ENHANCEMENT SETUP - FINE-TUNED VERSION")
print("=" * 80 + "\n")

# Step 1: Enhance KB with metadata (ONE TIME - only run if not already done)
try:
    from metadata_enhancer import enhance_kb_with_metadata
    
    print("📊 Enhancing KB with metadata (year, quarter, doc_type, section)...")
    enhance_kb_with_metadata("./data_marker")
    print("✓ KB enhancement complete!\n")
except FileNotFoundError as e:
    print(f"KB files not found: {e}")
    print("   Make sure ./data_marker/kb_chunks.parquet exists\n")
except Exception as e:
    print(f"ℹ Enhancement skipped (may already be done): {e}\n")

# Step 2: Add metadata search capability to KBEnv
try:
    from kb_metadata_extension import add_metadata_search_to_kbenv
    add_metadata_search_to_kbenv()
    print("✓ Metadata search added to KBEnv\n")
except Exception as e:
    print(f"Could not add metadata search: {e}\n")

# Step 3: Replace KBEnv.search() with metadata-enhanced version
print("=" * 80)
print("APPLYING BALANCED METADATA OPTIMIZATION")
print("=" * 80 + "\n")

from g2x import KBEnv

# Save original search method if not already saved
if not hasattr(KBEnv, '_original_search'):
    KBEnv._original_search = KBEnv.search
    print("✓ Original search method saved\n")

# Define metadata-enhanced search wrapper with balanced boosting
def _metadata_optimized_search(self, query, k=50, alpha=0.6, rerank_top_k=100):
    """
    Enhanced search with balanced metadata boosting, recency decay, and adaptive weights
    
    FINE-TUNED SETTINGS:
    - Reduced boost weights (quarter: 4.0x→6.0x vs old 5.0x→8.0x)
    - Less aggressive recency decay (5% per quarter vs 7%)
    - Smaller initial pool (k*12 vs k*15) for better speed
    - Improved "last N quarters" detection
    """
    if hasattr(self, 'search_with_metadata'):
        # Temporarily restore original to avoid recursion
        original_method = KBEnv.search
        KBEnv.search = KBEnv._original_search
        try:
            result = self.search_with_metadata(
                query, 
                k=k, 
                alpha=alpha, 
                rerank_top_k=rerank_top_k,
                enable_metadata_boost=True,
                enable_metadata_filter=True,  # Soft filter enabled
                boost_weights=None,  # Use adaptive weights (None = auto-detect)
                apply_recency_decay=True  # Apply time-based decay (5% per quarter)
            )
        finally:
            # Restore the enhanced search
            KBEnv.search = original_method
        return result
    else:
        # Fallback to original if metadata not available
        return KBEnv._original_search(self, query, k=k, alpha=alpha, rerank_top_k=rerank_top_k)

# Apply the optimization
KBEnv.search = _metadata_optimized_search

print("✓ BALANCED OPTIMIZATION ENABLED")
print("   All kb.search() calls will now use:")
print("   • Adaptive metadata boosting (balanced weights)")
print("   • Recency decay (5% per quarter age)")
print("   • Soft filtering (±2 year window when year detected)")
print("   • Improved 'last N quarters' detection")

print("\n" + "=" * 80)
print("SETUP COMPLETE - Balanced Metadata Optimization Active!")
print("=" * 80)
print("""
TO DISABLE THIS OPTIMIZATION:
  - Restart the kernel, OR
  - Run: KBEnv.search = KBEnv._original_search
""")


### 7.3 Cache Extracted PDF Data
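
A minimal sketch of one way to cache extracted PDF data, assuming the extraction result is JSON-serializable. The cache key is the SHA-256 of the file contents, so a changed PDF is re-extracted automatically. The names `cached_extract` and `fake_extract` are hypothetical, and the extractor here is a stand-in for whatever parser the ingestion step uses.

```python
import hashlib
import json
import tempfile
from pathlib import Path
from typing import Callable, List

def cached_extract(pdf_path: Path, cache_dir: Path, extract: Callable[[Path], dict]) -> dict:
    """Return extract(pdf_path), reusing a JSON file keyed by the PDF's content hash."""
    digest = hashlib.sha256(pdf_path.read_bytes()).hexdigest()
    cache_file = cache_dir / f"{digest}.json"
    if cache_file.exists():
        return json.loads(cache_file.read_text(encoding="utf-8"))
    data = extract(pdf_path)
    cache_dir.mkdir(parents=True, exist_ok=True)
    cache_file.write_text(json.dumps(data), encoding="utf-8")
    return data

# Demo with a stand-in extractor that records how often it actually runs.
calls: List[str] = []

def fake_extract(path: Path) -> dict:
    calls.append(path.name)
    return {"pages": 1, "source": path.name}

with tempfile.TemporaryDirectory() as tmp:
    pdf = Path(tmp) / "report.pdf"
    pdf.write_bytes(b"%PDF-1.4 placeholder bytes")
    first = cached_extract(pdf, Path(tmp) / "cache", fake_extract)
    second = cached_extract(pdf, Path(tmp) / "cache", fake_extract)

print(len(calls))  # 1: the second call is served from the cache
```

Hashing the content instead of the filename avoids stale entries when a filing is replaced under the same name, at the cost of reading the file once per lookup.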

### 7.4 Parallel Sub-queries

In [None]:
# ----------------------------- PARALLEL QUERY DECOMPOSER -----------------------------

class ParallelQueryDecomposer:
    """Decomposes complex queries using QueryAnalyzer"""
    
    @staticmethod
    def decompose(query: str) -> List[str]:
        """
        Intelligent query decomposition using query analysis
        """
        analyzer = QueryAnalyzer
        
        # Extract intent
        needs_calc = analyzer.needs_calculation(query)
        metric = analyzer.extract_metric(query)
        years = analyzer.extract_years(query)
        num_periods = analyzer.extract_num_periods(query)
        
        # Q3: Efficiency Ratio (Opex ÷ Income)
        if needs_calc and metric and "efficiency" in metric.lower():
            return [
                f"Extract Operating Expenses for the last {num_periods or 3} fiscal years",
                f"Extract Total Income for the last {num_periods or 3} fiscal years"
            ]
        
        # Q3: Any ratio calculation (A ÷ B)
        if needs_calc and ("ratio" in query.lower() or "÷" in query or "/" in query):
            # Try to extract both metrics
            parts = re.split(r'[÷/]|\bdivided by\b', query, flags=re.I)
            if len(parts) == 2:
                metric_a = analyzer.extract_metric(parts[0])
                metric_b = analyzer.extract_metric(parts[1])
                if metric_a and metric_b:
                    return [
                        f"Extract {metric_a} for the last {num_periods or 3} fiscal years",
                        f"Extract {metric_b} for the last {num_periods or 3} fiscal years"
                    ]
        
        # Multi-metric comparison
        if analyzer.want_compare(query) and metric:
            # Decompose by year if multiple years specified
            if len(years) > 2:
                return [f"Extract {metric} for FY{y}" for y in years]
        
        # Single metric query (no decomposition)
        return [query]

    @staticmethod
    async def execute_parallel_async(kb: KBEnv, sub_queries: List[str], k_ctx: int) -> List[pd.DataFrame]:
        """Execute sub-queries in parallel"""
        loop = asyncio.get_running_loop()  # inside a coroutine the loop is always running
        
        def search_sync(query):
            return kb.search(query, k=k_ctx)
        
        with ThreadPoolExecutor(max_workers=min(len(sub_queries), 4)) as executor:
            tasks = [loop.run_in_executor(executor, search_sync, sq) for sq in sub_queries]
            results = await asyncio.gather(*tasks)
        
        return results
    
    @staticmethod
    def execute_parallel(kb: KBEnv, sub_queries: List[str], k_ctx: int) -> List[pd.DataFrame]:
        """Blocking wrapper for async parallel execution"""
        try:
            loop = asyncio.get_event_loop()
            if loop.is_running():
                # Already in event loop (e.g., Jupyter), use nest_asyncio
                try:
                    import nest_asyncio
                    nest_asyncio.apply()
                except ImportError:
                    pass
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        
        return loop.run_until_complete(
            ParallelQueryDecomposer.execute_parallel_async(kb, sub_queries, k_ctx)
        )
    
    @staticmethod
    def merge_results(results: List[pd.DataFrame], k_ctx: int) -> pd.DataFrame:
        """Merge and deduplicate parallel results"""
        if not results:
            return pd.DataFrame()
        
        # Concatenate (pd.concat raises ValueError on an empty list, so guard first)
        non_empty = [r for r in results if not r.empty]
        if not non_empty:
            return pd.DataFrame()
        merged = pd.concat(non_empty, ignore_index=True)
        
        # Deduplicate by text (keep highest score)
        merged = merged.sort_values('score', ascending=False)
        merged = merged.drop_duplicates(subset=['text'], keep='first')
        
        # Take top-k
        merged = merged.head(k_ctx)
        
        # Re-rank
        merged = merged.sort_values('score', ascending=False).reset_index(drop=True)
        merged['rank'] = range(1, len(merged) + 1)
        
        return merged
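
The generic ratio branch of `decompose` depends on splitting the query into exactly two sides. The following standalone sketch shows just that split, using the same regular expression. The helper name `split_ratio_query` is hypothetical, and metric extraction by `QueryAnalyzer` is left out.

```python
import re
from typing import List, Optional

RATIO_SPLIT = re.compile(r"[÷/]|\bdivided by\b", flags=re.I)

def split_ratio_query(query: str) -> Optional[List[str]]:
    """Return [numerator_text, denominator_text] when the query has one ratio separator."""
    parts = [p.strip() for p in RATIO_SPLIT.split(query)]
    return parts if len(parts) == 2 and all(parts) else None

print(split_ratio_query("Operating Expenses ÷ Operating Income"))
print(split_ratio_query("Net profit divided by total equity"))
print(split_ratio_query("Show Operating Expenses for the last 3 fiscal years"))
```

Because `/` is treated as a separator, a query containing something like `FY2023/24` would also split. The `len(parts) == 2` check rejects queries with more than one separator, but a single stray slash still passes, so the extracted metrics should be validated before decomposition is trusted.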

### 7.5 Asynchronous I/O

In [None]:
%pip install aiohttp

In [None]:
import asyncio
try:
    import nest_asyncio
except ImportError:
    nest_asyncio = None

# 1) Event-loop helper (re-entrant inside Jupyter when nest_asyncio is installed)
def ensure_loop():
    try:
        loop = asyncio.get_event_loop()
        if loop.is_running() and nest_asyncio:
            nest_asyncio.apply()
        return loop
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        return loop

# 2) Async retrieval shim (FAISS/BM25 via thread pool)
from concurrent.futures import ThreadPoolExecutor
_search_pool = ThreadPoolExecutor(max_workers=8)

class AsyncRetrieval:
    @staticmethod
    async def execute_parallel_async(kb, sub_queries, k_ctx: int, max_concurrency: int = 8):
        print(f"Running async parallel retrieval for {len(sub_queries)} sub-queries with concurrency={max_concurrency}")
        if not sub_queries:
            return []
        loop = asyncio.get_running_loop()  # inside a coroutine the loop is always running
        sem = asyncio.Semaphore(max_concurrency)

        async def one(q: str):
            async with sem:
                return await loop.run_in_executor(_search_pool, kb.search, q, k_ctx)

        results = await asyncio.gather(*(one(q) for q in sub_queries), return_exceptions=True)
        import pandas as pd
        safe = []
        for r in results:
            safe.append(pd.DataFrame() if isinstance(r, Exception) else r)
        return safe

    @staticmethod
    def execute_parallel(kb, sub_queries, k_ctx: int, max_concurrency: int = 8):
        loop = ensure_loop()
        return loop.run_until_complete(
            AsyncRetrieval.execute_parallel_async(kb, sub_queries, k_ctx, max_concurrency)
        )

# 3) Wire into your decomposer if present; else, provide a tiny adapter
def execute_parallel_subqueries(kb, sub_queries, k_ctx, max_concurrency=8):
    return AsyncRetrieval.execute_parallel(kb, sub_queries, k_ctx, max_concurrency)

# 4) Async HTTP clients (LLM + Embeddings) with sync adapters
import aiohttp
from typing import List, Dict, Any

class AsyncLLMClient:
    def __init__(self, base_url: str, api_key: str, max_concurrent: int = 8, timeout_s: int = 60):
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.sem = asyncio.Semaphore(max_concurrent)
        self.timeout_s = timeout_s

    async def chat(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        print("Async LLM chat API call started")
        headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key else {}
        async with self.sem:
            async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=self.timeout_s)) as sess:
                async with sess.post(f"{self.base_url}/chat/completions", json=payload, headers=headers) as r:
                    r.raise_for_status()
                    print("Async LLM chat API call completed")
                    return await r.json()

class AsyncEmbeddingsClient:
    def __init__(self, base_url: str, api_key: str, batch_size: int = 64, max_concurrent: int = 4, timeout_s: int = 60):
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.batch_size = batch_size
        self.sem = asyncio.Semaphore(max_concurrent)
        self.timeout_s = timeout_s

    async def embed(self, texts: List[str]) -> List[List[float]]:
        headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key else {}
        out: List[List[float]] = []
        async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=self.timeout_s)) as sess:
            tasks = []
            for i in range(0, len(texts), self.batch_size):
                chunk = texts[i:i+self.batch_size]
                async def one(ch=chunk):
                    async with self.sem:
                        async with sess.post(f"{self.base_url}/embeddings", json={"input": ch}, headers=headers) as r:
                            r.raise_for_status()
                            data = await r.json()
                            return [v["embedding"] for v in data["data"]]
                tasks.append(one())
            for res in await asyncio.gather(*tasks, return_exceptions=True):
                if isinstance(res, Exception):
                    # Fail loudly: skipping a batch would misalign embeddings with `texts`
                    raise res
                out.extend(res)
        return out

# 5) Provide sync adapters so the rest of the notebook doesn't break
import os
LLM_ASYNC = AsyncLLMClient(base_url=os.environ.get("OPENAI_BASE_URL","https://api.openai.com/v1"),
                           api_key=os.environ.get("OPENAI_API_KEY",""),
                           max_concurrent=8)

EMB_ASYNC = AsyncEmbeddingsClient(base_url=os.environ.get("OPENAI_BASE_URL","https://api.openai.com/v1"),
                                  api_key=os.environ.get("OPENAI_API_KEY",""),
                                  batch_size=64, max_concurrent=4)

def llm_chat_sync(payload: dict) -> dict:
    loop = ensure_loop()
    return loop.run_until_complete(LLM_ASYNC.chat(payload))

def embed_sync(texts: List[str]) -> List[List[float]]:
    loop = ensure_loop()
    return loop.run_until_complete(EMB_ASYNC.embed(texts))
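
The pattern used by `AsyncRetrieval` (a blocking call pushed onto a thread pool, bounded by a semaphore) can be tried in isolation. The sketch below uses a stand-in `blocking_search` that only sleeps, so the names and timings are illustrative assumptions. It is written as a plain script; inside Jupyter, where a loop is already running, use `await search_all(queries)` instead of `asyncio.run(...)`.

```python
import asyncio
import time
from concurrent.futures import ThreadPoolExecutor
from typing import List

def blocking_search(query: str) -> str:
    """Stand-in for a blocking kb.search call."""
    time.sleep(0.3)
    return f"hits for {query!r}"

async def search_all(queries: List[str], max_concurrency: int = 4) -> List[str]:
    loop = asyncio.get_running_loop()
    sem = asyncio.Semaphore(max_concurrency)
    with ThreadPoolExecutor(max_workers=max_concurrency) as pool:
        async def one(q: str) -> str:
            async with sem:  # cap the number of in-flight searches
                return await loop.run_in_executor(pool, blocking_search, q)
        # gather preserves input order regardless of completion order
        return await asyncio.gather(*(one(q) for q in queries))

queries = ["opex FY2023", "opex FY2024", "total income FY2024"]
start = time.perf_counter()
hits = asyncio.run(search_all(queries))
elapsed = time.perf_counter() - start
print(hits)
print(f"{elapsed:.2f}s for {len(queries)} searches")
```

Run sequentially, the three stand-in searches would take about 0.9 s; with the pool they overlap and finish in roughly the time of one call. Real gains depend on whether the underlying search releases the GIL.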




## 8. Results & Plots

Show baseline vs optimized. Include latency plots (p50/p95) and accuracy tables.
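
When reading the p50/p95 figures, keep the sample size in mind. With only three benchmark queries, `np.percentile` interpolates linearly between the two slowest runs, so p95 is not an observed tail latency. A small illustration with made-up latencies:

```python
import numpy as np

latencies_ms = [100.0, 200.0, 300.0]  # illustrative values, not measured results

p50 = float(np.percentile(latencies_ms, 50))
p95 = float(np.percentile(latencies_ms, 95))
print(round(p50, 1), round(p95, 1))  # 200.0 290.0
```

Here p95 sits 90% of the way from 200 ms to 300 ms. Repeating each query several times before computing percentiles would give a more meaningful tail estimate.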

In [None]:
import json
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import textwrap

sns.set(style="whitegrid")

RESULTS_DIR = Path("results")
RESULTS_DIR.mkdir(exist_ok=True, parents=True)

def pretty_xticks(ax, wrap_width=32, rotation=30, fontsize=9, bottom=0.36):
    """Wrap and rotate x-tick labels for long query strings."""
    labels = [t.get_text() for t in ax.get_xticklabels()]
    wrapped = [textwrap.fill(str(l), wrap_width) for l in labels]
    ax.set_xticklabels(wrapped, rotation=rotation, ha='right', fontsize=fontsize)
    ax.tick_params(axis='x', which='major', labelsize=fontsize)
    plt.subplots_adjust(bottom=bottom)

# Look for saved benchmark JSONs (baseline in data/, agentic in data_marker/)
def find_bench_jsons():
    cand = []
    for root in [Path("data"), Path("data_marker")]:
        if not root.exists(): 
            continue
        for p in root.glob("bench_*.json"):
            cand.append(p)
    return cand

def load_bench_json(path: Path):
    try:
        doc = json.loads(path.read_text(encoding="utf-8"))
    except Exception:
        return pd.DataFrame()
    rows = []
    for r in doc.get("results", []):
        rows.append({
            "mode": path.stem.split("_", 1)[1] if "_" in path.stem else path.stem,
            "query": r.get("query"),
            "latency_ms": float(r.get("latency_ms") or np.nan),
            "answer_len": len(str(r.get("answer") or "")),
            "n_citations": len(r.get("citations") or []),
            "execution_log": r.get("execution_log"),
            "raw_answer": r.get("answer")
        })
    return pd.DataFrame(rows)

files = find_bench_jsons()
if not files:
    print("No bench_*.json files found (expected data/bench_baseline.json and data_marker/bench_agentic.json).")
    print("Run the Benchmark Runner to produce JSON outputs and re-run this cell.")
else:
    dfs = []
    for f in files:
        df = load_bench_json(f)
        if not df.empty:
            dfs.append(df)
    if not dfs:
        print("No valid JSON content found in bench outputs.")
    else:
        bench_df = pd.concat(dfs, ignore_index=True)
        # safety: drop rows that have NaN latency (if any)
        bench_df = bench_df.dropna(subset=["latency_ms"])
        bench_df.to_csv(RESULTS_DIR / "bench_combined.csv", index=False)

        # 1) Latency Comparison: grouped bar per query
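        # pivot_table (mean) tolerates repeated (query, mode) pairs, where pivot() would raise ValueError
        pivot_lat = bench_df.pivot_table(index="query", columns="mode", values="latency_ms", aggfunc="mean")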
        fig, ax = plt.subplots(figsize=(12, 5))
        pivot_lat.plot(kind="bar", ax=ax, rot=0)
        ax.set_ylabel("Latency (ms)")
        ax.set_title("Baseline vs Agentic Latency per Benchmark Query")
        pretty_xticks(ax, wrap_width=36, rotation=30, fontsize=10, bottom=0.37)
        plt.tight_layout()
        plt.savefig(RESULTS_DIR / "latency_per_query.png")
        plt.close()

        # 2) Latency distribution summary (p50/p95)
        # Named aggregation avoids relying on pandas' auto-generated "<lambda_0>" column name
        stats = bench_df.groupby("mode")["latency_ms"].agg(
            p50_ms="median",
            p95_ms=lambda s: s.quantile(0.95),
            mean_ms="mean",
            std_ms="std",
        ).reset_index()
        stats.to_csv(RESULTS_DIR / "latency_summary.csv", index=False)

        # 3) Answer length comparison per query
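        # pivot_table (mean) tolerates repeated (query, mode) pairs, where pivot() would raise ValueError
        pivot_len = bench_df.pivot_table(index="query", columns="mode", values="answer_len", aggfunc="mean").fillna(0)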
        fig, ax = plt.subplots(figsize=(12, 5))
        pivot_len.plot(kind="bar", ax=ax, rot=0)
        ax.set_ylabel("Answer length (chars)")
        ax.set_title("Answer Length (characters): Baseline vs Agentic")
        pretty_xticks(ax, wrap_width=36, rotation=30, fontsize=10, bottom=0.37)
        plt.tight_layout()
        plt.savefig(RESULTS_DIR / "answer_length_per_query.png")
        plt.close()

        # 4) Citation counts per query
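        # pivot_table (mean) tolerates repeated (query, mode) pairs, where pivot() would raise ValueError
        pivot_cit = bench_df.pivot_table(index="query", columns="mode", values="n_citations", aggfunc="mean").fillna(0)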
        fig, ax = plt.subplots(figsize=(12, 4))
        pivot_cit.plot(kind="bar", ax=ax, rot=0)
        ax.set_ylabel("Number of Citations")
        ax.set_title("Number of Citations per Query")
        pretty_xticks(ax, wrap_width=36, rotation=30, fontsize=10, bottom=0.33)
        plt.tight_layout()
        plt.savefig(RESULTS_DIR / "citations_per_query.png")
        plt.close()

        # 5) Tools / actions used in agentic runs (bar chart)
        # Try to extract 'actions' or any 'tool' mentions from execution_log
        def extract_tools_from_log(exec_log):
            if not exec_log:
                return []
            # execution_log might be a dict with 'plan'/'actions' or a list; support both
            tools = []
            if isinstance(exec_log, dict):
                # common keys: 'plan', 'actions', 'observations'
                if isinstance(exec_log.get("actions"), list):
                    tools.extend(exec_log.get("actions"))
                # plan may be a list of step dicts with 'tool'
                plan = exec_log.get("plan")
                if isinstance(plan, list):
                    for step in plan:
                        if isinstance(step, dict):
                            t = step.get("tool") or step.get("tool_call")
                            if isinstance(t, str):
                                tools.append(t)
            elif isinstance(exec_log, list):
                # fallback: scan entries for 'tool_call' text
                for ent in exec_log:
                    if isinstance(ent, dict):
                        tc = ent.get("tool_call") or ent.get("tool")
                        if tc:
                            tools.append(tc)
            return [t for t in tools if t]

        agentic_rows = bench_df[bench_df["mode"].str.contains("agent", case=False, na=False)]
        if not agentic_rows.empty:
            tool_counts = {}
            for _, r in agentic_rows.iterrows():
                tools = extract_tools_from_log(r["execution_log"])
                for t in tools:
                    # normalize names (strip parameters)
                    if isinstance(t, str):
                        name = t.split("(")[0].strip()
                        tool_counts[name] = tool_counts.get(name, 0) + 1
            if tool_counts:
                tool_items = pd.Series(tool_counts).sort_values(ascending=False)
                plt.figure(figsize=(8, 3))
                sns.barplot(x=tool_items.values, y=tool_items.index, palette="viridis")
                plt.xlabel("Call Count")
                plt.ylabel("Tool / Action")
                plt.title("Tools / Actions called (Agentic runs)")
                plt.tight_layout()
                plt.savefig(RESULTS_DIR / "tool_usage_agentic.png")
                plt.close()
                # save JSON form (context manager closes the file handle)
                with open(RESULTS_DIR / "tool_usage_agentic.json", "w", encoding="utf-8") as fh:
                    json.dump(tool_counts, fh, indent=2)
            else:
                print("No explicit 'actions' or tool usage found in agentic run execution logs.")
        else:
            print("No agentic results found to extract tools usage.")

        # 6) Latency summary (console)
        print("\nLatency summary (p50/p95/mean/std) by mode:")
        display_stats = stats.round(2)
        print(display_stats.to_string(index=False))

        # 7) Save a short report
        report = {
            "bench_combined_rows": len(bench_df),
            "modes": bench_df["mode"].unique().tolist(),
            "latency_summary": stats.to_dict(orient="records"),
            "files": [str(x) for x in files]
        }
        with open(RESULTS_DIR / "bench_report_summary.json", "w", encoding="utf-8") as fh:
            json.dump(report, fh, indent=2)

        # 8) Small pretty table (display in notebook)
        try:
            from IPython.display import display, Markdown
            md = ["# Results & Plots: Summary"]
            md.append("## Latency summary (ms)")
            md.append(stats.to_markdown(index=False))
            md.append("## Quick notes")
            md.append(f"- Benchmarks combined rows: {len(bench_df)}")
            md.append(f"- Charts saved to: {RESULTS_DIR}")
            display(Markdown("\n\n".join(md)))
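        except Exception as exc:
            # e.g. DataFrame.to_markdown needs the optional 'tabulate' package
            print(f"Skipping Markdown summary: {exc}")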

        # Report output location only when results were actually produced
        print(f"Plots and CSV/JSON summaries written to: {RESULTS_DIR}")
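
The section brief also asks for accuracy tables, which need reference answers that the benchmark JSONs do not carry. The sketch below shows one crude way to score numeric answers per mode, assuming a hypothetical `expected` mapping from query text to a hand-checked value; the function names, the 1% tolerance, and the sample rows are all illustrative assumptions rather than part of the benchmark runner.

```python
import re
from typing import Dict, List


def numbers_in(text: str) -> List[float]:
    """Extract every numeric value in text, ignoring thousands separators."""
    found = re.findall(r"-?\d[\d,]*(?:\.\d+)?", text or "")
    return [float(tok.replace(",", "")) for tok in found]


def score_answers(rows: List[dict], expected: Dict[str, float],
                  rel_tol: float = 0.01) -> Dict[str, dict]:
    """Count an answer as correct if any number in it is within rel_tol of the expected value."""
    tally: Dict[str, List[int]] = {}
    for r in rows:
        target = expected.get(r["query"])
        if target is None:
            continue  # no reference value for this query, so it is not scored
        hit = any(abs(n - target) <= rel_tol * abs(target)
                  for n in numbers_in(r.get("raw_answer") or ""))
        counts = tally.setdefault(r["mode"], [0, 0])
        counts[0] += int(hit)
        counts[1] += 1
    return {mode: {"correct": c, "total": t, "accuracy": c / t}
            for mode, (c, t) in tally.items()}


# Hypothetical rows shaped like bench_df records, and a hypothetical reference value
rows = [
    {"mode": "baseline", "query": "q1", "raw_answer": "NIM was 2.15% in FY2023."},
    {"mode": "agentic", "query": "q1", "raw_answer": "Not found."},
]
print(score_answers(rows, {"q1": 2.15}))
```

Matching on any number in the answer is deliberately lenient and can produce false positives, so a sample of the scored rows should still be checked by hand against the cited page or table.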