In [None]:
# Colab only: mount Google Drive at /content/drive so files on Drive are readable from this runtime.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# --- Install deps (quiet) ---
# NOTE(review): apart from pymilvus (>=2.4.5) nothing is pinned, and colpali is
# installed from the GitHub repo's default branch, so a fresh runtime may pull
# different code than the last successful run. Consider pinning exact versions
# (and a commit hash for colpali) for reproducibility.
!pip -q install fastapi uvicorn nest_asyncio pyngrok google-cloud-storage pillow \
  "pymilvus>=2.4.5" git+https://github.com/illuin-tech/colpali.git psycopg2-binary

# Gemini + LangChain integration
!pip -q install google-generativeai langchain-google-genai langsmith

# Optional: LangChain text splitters for better OCR chunking
!pip -q install langchain-text-splitters

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone


In [None]:
import os

# ==== EDIT THESE (fill in your real keys) ====
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = ""  # path to your uploaded SA key
os.environ["GCS_BUCKET"] = ""  # e.g. mercerchat-dev
os.environ["ZILLIZ_URI"] = ""
os.environ["ZILLIZ_TOKEN"] = ""
os.environ["GPU_BEARER_TOKEN"] = ""
os.environ["NGROK_AUTHTOKEN"] = ""
os.environ["PG_CONN_STR"] = ""

# --- Gemini (VLM) ---
# Set your Google API key to use Gemini as the VLM
os.environ["GOOGLE_API_KEY"] = ""  # <-- put your actual API key here

# Choose VLM provider & model.
# Options you'll wire in Cell 3+: VLM_PROVIDER in {"qwen", "gemini"}.
# We'll default to Gemini to try reasoning improvements; switch back to "qwen" if needed.
os.environ["VLM_PROVIDER"] = "gemini"
os.environ["GEMINI_MODEL"] = "gemini-2.5-pro"  # override here if your project uses a different name
os.environ["LANGCHAIN_TRACING"] = "true"
os.environ["LANGCHAIN_API_KEY"] = ""
# Optional but recommended:
os.environ["LANGCHAIN_PROJECT"] = ""

# ===== Optional knobs =====
# Retrieval / rendering
os.environ["TOP_K_DEFAULT"] = "8"
os.environ["MAX_PAGES_FOR_VLM_DEFAULT"] = "5"
os.environ["PAGE_LIMIT"] = "0"            # 0 = no page cap at GCS listing
os.environ["PER_SUBVEC"] = "150"
os.environ["HNSW_EF"] = "512"
os.environ["MAX_QTOK"] = "32"

# VLM generation
os.environ["MAX_NEW_TOKENS"] = "1028"     # bump token budget as requested
os.environ["VLM_TEMPERATURE"] = "0"       # default greedy; set >0 for sampling

# GPU hygiene (still useful for ColNomic embeddings on CUDA)
os.environ["CLEAR_CUDA_EACH"] = "1"

# History inclusion (used when building the prompt for the VLM)
os.environ["HISTORY_TURNS"] = "8"         # how many recent rows to fetch from Postgres
os.environ["HISTORY_CHARS"] = "6000"      # cap history text length

In [None]:
# =======================
# MercerChat GPU Backend (Cell 1/2)
# Setup only: env, clients, models, helpers, and LangChain-style Runnable (Gemini 2.5 Pro)
# =======================

# --- Imports ---
import os, io, re, time, traceback, base64
from typing import List, Dict, Tuple
from collections import defaultdict

import torch
import numpy as np
from PIL import Image
from google.cloud import storage
from pymilvus import MilvusClient, DataType
import psycopg2, psycopg2.extras

# ColNomic embeddings (retriever)
from colpali_engine.models import ColQwen2_5, ColQwen2_5_Processor

# LangChain / Gemini / LangSmith
from langchain_core.runnables import RunnableMap
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from langsmith.run_helpers import traceable

# Optional splitter (installed in Cell 1)
try:
    from langchain_text_splitters import RecursiveCharacterTextSplitter
except Exception:
    RecursiveCharacterTextSplitter = None

# ===== Env =====
# Connection settings and secrets; each becomes "" when its variable is unset.
(
    GPU_BEARER_TOKEN,
    GCS_BUCKET,
    ZILLIZ_URI,
    ZILLIZ_TOKEN,
    NGROK_AUTHTOKEN,
    PG_CONN_STR,
) = (
    os.environ.get(_var, "")
    for _var in (
        "GPU_BEARER_TOKEN",
        "GCS_BUCKET",
        "ZILLIZ_URI",
        "ZILLIZ_TOKEN",
        "NGROK_AUTHTOKEN",
        "PG_CONN_STR",
    )
)


def _env_int(name: str, default: str) -> int:
    """Read an integer tunable from the environment, falling back to ``default`` when unset."""
    return int(os.environ.get(name, default))


# Tunables
PAGE_LIMIT      = _env_int("PAGE_LIMIT", "0")        # 0 = no limit
PER_SUBVEC      = _env_int("PER_SUBVEC", "150")      # ANN hits per query subvector
HNSW_EF         = _env_int("HNSW_EF", "512")         # ef at search time
TOP_PAGES       = _env_int("TOP_PAGES", "3")         # pages to show VLM
MAX_QTOK        = _env_int("MAX_QTOK", "32")         # truncate query subvectors
DIM             = 128                                # ColNomic projection dim (checked against model output)
# Per-request GPU cleanup flag; not referenced elsewhere in this cell (presumably read by the server cell — confirm).
CLEAR_CUDA_EACH = os.environ.get("CLEAR_CUDA_EACH", "1") == "1"
HISTORY_TURNS   = _env_int("HISTORY_TURNS", "8")     # how many past messages to fetch
HISTORY_CHARS   = _env_int("HISTORY_CHARS", "1500")  # max chars of history injected

# ===== GCS =====
# Storage client; credentials are resolved by google-cloud-storage (the env cell may set GOOGLE_APPLICATION_CREDENTIALS).
gcs = storage.Client()
# Matches ".../page_<N>.png" at the end of a key and captures N (case-insensitive).
page_num_re = re.compile(r"/page_(\d+)\.png$", flags=re.IGNORECASE)

def list_page_keys(conversation_id: str, doc_id: str) -> List[str]:
    """
    List page-image blobs under ``pages/<conversation_id>/<doc_id>/``.

    Only ``.png`` blobs whose name matches ``page_<N>.png`` are kept. They are
    ordered by N numerically, so page_10 follows page_9. When PAGE_LIMIT is
    non-zero, the list is cut to that many entries.
    """
    prefix = f"pages/{conversation_id}/{doc_id}/"
    page_keys = [
        blob.name
        for blob in gcs.list_blobs(GCS_BUCKET, prefix=prefix)
        if blob.name.endswith(".png") and page_num_re.search(blob.name)
    ]

    def page_index(key: str) -> int:
        match = page_num_re.search(key)
        # Defensive: every kept key matches, but unparsable names would sort last.
        return int(match.group(1)) if match else 10**9

    page_keys.sort(key=page_index)
    return page_keys[:PAGE_LIMIT] if PAGE_LIMIT else page_keys

def list_text_keys(conversation_id: str, doc_id: str) -> List[str]:
    """
    List OCR text page files saved by the first backend.

    Looks under ``text/<conversation_id>/<doc_id>/`` and keeps only blobs named
    ``page_<N>.txt``, mirroring ``list_page_keys`` for images. Other ``.txt`` files
    in the folder would otherwise be treated as pages and mapped to page images
    that do not exist. Keys are ordered by N numerically. When PAGE_LIMIT is
    non-zero, at most that many keys are returned.
    """
    prefix = f"text/{conversation_id}/{doc_id}/"
    page_txt_re = re.compile(r"/page_(\d+)\.txt$", re.IGNORECASE)
    keys = [
        b.name
        for b in gcs.list_blobs(GCS_BUCKET, prefix=prefix)
        if b.name.endswith(".txt") and page_txt_re.search(b.name)
    ]
    # Every kept key matches page_txt_re, so the page number is always present.
    keys.sort(key=lambda k: int(page_txt_re.search(k).group(1)))
    if PAGE_LIMIT and len(keys) > PAGE_LIMIT:
        keys = keys[:PAGE_LIMIT]
    return keys

def page_text_key_for_image_key(image_key: str) -> str:
    """
    Map a page-image key to its OCR text key.

    ``pages/<conv>/<doc>/page_X.png  ->  text/<conv>/<doc>/page_X.txt``

    Only the leading ``pages/`` prefix and the trailing ``.png`` suffix are
    rewritten. Conversation/doc ids that happen to contain ``pages/`` or ``.png``
    (e.g. ``webpages`` or a doc id like ``scan.png``) are preserved, whereas a
    global ``str.replace`` would corrupt them.
    """
    key = image_key
    if key.startswith("pages/"):
        key = "text/" + key[len("pages/"):]
    if key.endswith(".png"):
        key = key[: -len(".png")] + ".txt"
    return key

def page_image_key_for_text_key(text_key: str) -> str:
    """
    Map an OCR text key to its page-image key.

    ``text/<conv>/<doc>/page_X.txt  ->  pages/<conv>/<doc>/page_X.png``

    Only the leading ``text/`` prefix and the trailing ``.txt`` suffix are
    rewritten. Ids containing ``text/`` (e.g. a conversation id ``context``) or
    ``.txt`` are preserved, whereas a global ``str.replace`` would corrupt them.
    """
    key = text_key
    if key.startswith("text/"):
        key = "pages/" + key[len("text/"):]
    if key.endswith(".txt"):
        key = key[: -len(".txt")] + ".png"
    return key

def download_images_for_keys(keys: List[str]) -> List[Image.Image]:
    """
    Download each GCS key and decode it as an RGB PIL image.

    Keys that fail (download or decode) are skipped after printing a warning, so
    the result can be shorter than ``keys`` and is not index-aligned with it.
    """
    def _fetch(key: str):
        try:
            payload = gcs.bucket(GCS_BUCKET).blob(key).download_as_bytes()
            return Image.open(io.BytesIO(payload)).convert("RGB")
        except Exception as e:
            print(f"[WARN] could not download {key}: {e}")
            return None

    return [image for image in map(_fetch, keys) if image is not None]

def download_text_for_keys(image_keys: List[str]) -> List[str]:
    """
    Given image page keys, fetch the matching OCR text files.

    The result is index-aligned with ``image_keys``: entry i is the text stored at
    ``page_text_key_for_image_key(image_keys[i])``, or "" if that download fails
    (a warning is printed). Bytes that are not valid UTF-8 are dropped rather than
    discarding the whole page. This matches how ``embed_doc_ocr_text_to_milvus``
    decodes the same files at ingestion time.
    """
    texts = []
    for k in image_keys:
        try:
            txt_key = page_text_key_for_image_key(k)
            data = gcs.bucket(GCS_BUCKET).blob(txt_key).download_as_bytes()
            texts.append(data.decode("utf-8", errors="ignore"))
        except Exception as e:
            print(f"[WARN] could not download OCR text for {k}: {e}")
            texts.append("")
    return texts

# ===== Zilliz/Milvus (vector store) =====
# Shared Milvus/Zilliz client (URI + token from the env section); every collection/search helper below uses it.
milvus = MilvusClient(uri=ZILLIZ_URI, token=ZILLIZ_TOKEN)

def coll_name_for(conversation_id: str) -> str:
    """
    Build the image-collection name for a conversation.

    Runs of characters outside ``[A-Za-z0-9_]`` collapse to a single ``_``. A
    ``c_`` marker is inserted when the cleaned id does not start with a letter.
    The result is ``conv_<...>`` truncated to 255 characters.
    """
    cleaned = re.sub(r'[^A-Za-z0-9_]+', '_', str(conversation_id))
    starts_with_letter = re.match(r'^[A-Za-z]', cleaned) is not None
    marker = '' if starts_with_letter else 'c_'
    return ('conv_' + marker + cleaned)[:255]

def ensure_collection(conv_id: str) -> str:
    """
    Return the image (page sub-vector) collection name for ``conv_id``, creating it on first use.

    A new collection gets:
    - an auto-id INT64 primary key ``pk``;
    - a DIM-wide FLOAT_VECTOR ``vector`` with an HNSW/IP index;
    - ``doc_id`` (VARCHAR, INVERTED index), ``page``, ``seq_id`` and ``gcs_key`` fields.

    The new collection is loaded before returning. If the collection already
    exists, its name is returned without re-creating or re-loading it.
    """
    collection = coll_name_for(conv_id)
    if milvus.has_collection(collection):
        return collection
    # The schema flag is `enable_dynamic_field` (singular). The earlier
    # `enable_dynamic_fields` spelling is not a recognized keyword and had no effect.
    schema = milvus.create_schema(auto_id=True, enable_dynamic_field=True)
    schema.add_field("pk",       DataType.INT64,   is_primary=True)
    schema.add_field("vector",   DataType.FLOAT_VECTOR, dim=DIM)
    schema.add_field("doc_id",   DataType.VARCHAR, max_length=64)
    schema.add_field("page",     DataType.INT32)
    schema.add_field("seq_id",   DataType.INT32)
    schema.add_field("gcs_key",  DataType.VARCHAR, max_length=512)
    milvus.create_collection(collection, schema=schema)
    idx = milvus.prepare_index_params()
    idx.add_index(field_name="vector", index_name="vec", index_type="HNSW",
                  metric_type="IP", params={"M": 16, "efConstruction": 200})
    idx.add_index(field_name="doc_id", index_name="doc", index_type="INVERTED")
    milvus.create_index(collection, index_params=idx, sync=True)
    milvus.load_collection(collection)
    return collection

# ===== Extra TEXT collection for OCR chunk vectors =====
def coll_name_text_for(conversation_id: str) -> str:
    """
    Name of the OCR-text collection paired with ``coll_name_for(conversation_id)``.

    The base name (e.g. ``conv_abc123``) is cut to 251 characters before adding
    ``_txt``, so the result never exceeds the 255-character cap that
    ``coll_name_for`` itself applies. For ordinary ids this is simply ``<base>_txt``.
    """
    base = coll_name_for(conversation_id)
    return f"{base[:255 - len('_txt')]}_txt"

def ensure_text_collection(conv_id: str) -> str:
    """
    Return the OCR-text chunk collection name for ``conv_id``, creating it on first use.

    The layout matches the image collection: ``pk``, ``vector`` (HNSW/IP),
    ``doc_id`` (INVERTED index), ``page`` and ``seq_id`` (chunk index).
    ``gcs_key`` stores the IMAGE page key, so text hits map directly to page images.
    The new collection is loaded before returning. If the collection already
    exists, its name is returned unchanged.
    """
    collection = coll_name_text_for(conv_id)
    if milvus.has_collection(collection):
        return collection
    # The schema flag is `enable_dynamic_field` (singular). The earlier
    # `enable_dynamic_fields` spelling is not a recognized keyword and had no effect.
    schema = milvus.create_schema(auto_id=True, enable_dynamic_field=True)
    schema.add_field("pk",       DataType.INT64,   is_primary=True)
    schema.add_field("vector",   DataType.FLOAT_VECTOR, dim=DIM)
    schema.add_field("doc_id",   DataType.VARCHAR, max_length=64)
    schema.add_field("page",     DataType.INT32)
    schema.add_field("seq_id",   DataType.INT32)  # text-chunk index
    schema.add_field("gcs_key",  DataType.VARCHAR, max_length=512)  # store IMAGE page key here
    milvus.create_collection(collection, schema=schema)
    idx = milvus.prepare_index_params()
    idx.add_index(field_name="vector", index_name="vec", index_type="HNSW",
                  metric_type="IP", params={"M": 16, "efConstruction": 200})
    idx.add_index(field_name="doc_id", index_name="doc", index_type="INVERTED")
    milvus.create_index(collection, index_params=idx, sync=True)
    milvus.load_collection(collection)
    return collection

# ===== Embedding Model (ColNomic) for retrieval =====
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# You switched to 7B — keep DIM = 128 (ColPali fixed projection)
MODEL_NAME = "nomic-ai/colnomic-embed-multimodal-7b"

# bf16 weights placed on the GPU when CUDA is available; fp32 with no device_map otherwise.
_USE_CUDA = DEVICE == "cuda"
embed_model = ColQwen2_5.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16 if _USE_CUDA else torch.float32,
    device_map=DEVICE if _USE_CUDA else None,
)
embed_model.eval()  # inference mode (disables dropout etc.)
embed_proc = ColQwen2_5_Processor.from_pretrained(MODEL_NAME)

def _first_tensor(t):
    if torch.is_tensor(t):
        return t
    if isinstance(t, (list, tuple)):
        for x in t:
            if torch.is_tensor(x):
                return x
    return None

def _to_numpy_2d(t):
    """
    Convert model output to a 2-D float32 array plus its width.

    The first tensor found in ``t`` is moved to CPU as float32. All leading axes
    are flattened so every sub-vector is kept (late interaction).

    Returns:
        (vecs, D): ``vecs`` has shape (N, D), where D is the size of the last axis.

    Raises:
        ValueError: if ``t`` contains no tensor.
    """
    tensor = _first_tensor(t)
    if tensor is None:
        raise ValueError("No tensor found in model output")
    flat = tensor.detach().to("cpu", dtype=torch.float32).numpy()
    width = flat.shape[-1]
    return flat.reshape(-1, width), width

@torch.no_grad()
def embed_image_to_vectors(img: Image.Image):
    """
    Embed one page image into its late-interaction sub-vectors.

    Returns:
        (vectors, D): ``vectors`` is a list of Python-float lists, one per
        sub-vector, and D is their width.

    Raises:
        ValueError: if the model's output width differs from DIM.
    """
    device = next(embed_model.parameters()).device
    inputs = embed_proc.process_images([img]).to(device)
    vectors, width = _to_numpy_2d(embed_model(**inputs))
    if width != DIM:
        raise ValueError(f"Embedding dim {width} != expected {DIM}")
    return vectors.tolist(), width

# Text embeddings for OCR chunks: mean-pool to single 128-d per chunk
@torch.no_grad()
def embed_text_to_vectors(text: str) -> List[List[float]]:
    """
    Text -> one mean-pooled DIM-wide vector, wrapped in a list for API symmetry.

    Uses the same ColNomic text pathway as queries (``process_queries`` when the
    processor has it). The output width is checked against DIM, as in
    ``embed_image_to_vectors``. A blind ``reshape(-1, DIM)`` would otherwise
    silently scramble output of the wrong width.

    Raises:
        ValueError: if the model output has no tensor or its width is not DIM.
    """
    dev = next(embed_model.parameters()).device
    if hasattr(embed_proc, "process_queries"):
        batch = embed_proc.process_queries([text]).to(dev)
    else:
        batch = embed_proc(text=[text]).to(dev)
    token_vecs, D = _to_numpy_2d(embed_model(**batch))  # (S, D)
    if D != DIM:
        raise ValueError(f"Embedding dim {D} != expected {DIM}")
    return [token_vecs.mean(axis=0).tolist()]

def upsert_vectors_batched(collection: str, doc_id: str, page_key: str,
                           vectors: List[List[float]], batch_size: int = 4096) -> int:
    """
    Insert all sub-vectors of one page PNG into the image collection.

    The page number is parsed from ``page_key``; 0 is used when it does not match
    ``page_num_re``. Each vector's position becomes its ``seq_id``. Rows go in
    chunks of ``batch_size``, and the collection is flushed afterwards so the
    rows are searchable.

    Returns:
        Number of rows inserted (0 when ``vectors`` is empty, in which case nothing is flushed).
    """
    match = page_num_re.search(page_key)
    page_num = int(match.group(1)) if match else 0
    count = len(vectors)
    if count == 0:
        return 0

    doc = str(doc_id)
    rows = [
        {"vector": vec, "doc_id": doc, "page": page_num, "seq_id": seq, "gcs_key": page_key}
        for seq, vec in enumerate(vectors)
    ]
    inserted = 0
    for start in range(0, count, batch_size):
        chunk = rows[start:start + batch_size]
        milvus.insert(collection, chunk)
        inserted += len(chunk)
    # Make inserts visible to search immediately
    milvus.flush(collection)
    return inserted

def upsert_text_vectors_batched(collection: str, doc_id: str, page_key: str,
                                vectors: List[List[float]], chunk_idx: int,
                                batch_size: int = 4096) -> int:
    """
    Insert the pooled text vector(s) of one OCR chunk into the TEXT collection.

    'page_key' should be the **IMAGE** page key, so later searches can fetch the
    image directly. The page number is parsed from it; 0 is used when it doesn't
    match ``page_num_re``.

    Returns:
        Number of rows inserted. When ``vectors`` is empty, returns 0 without
        inserting or flushing, consistent with ``upsert_vectors_batched``.
    """
    if len(vectors) == 0:
        return 0
    m = page_num_re.search(page_key)
    page_num = int(m.group(1)) if m else 0
    rows = [
        {
            "vector": v,
            "doc_id": str(doc_id),
            "page": page_num,
            # chunk index spacing: chunk_idx * 10000 + position within this chunk's vectors
            "seq_id": chunk_idx * 10000 + vi,
            "gcs_key": page_key,
        }
        for vi, v in enumerate(vectors)
    ]
    total = 0
    for s in range(0, len(rows), batch_size):
        batch = rows[s:s + batch_size]
        milvus.insert(collection, batch)
        total += len(batch)
    # Flush to make searchable
    milvus.flush(collection)
    return total

# --- OCR text chunking helper ---
def chunk_text(text: str, chunk_size: int = 1000, chunk_overlap: int = 150) -> List[str]:
    """
    Split OCR text into overlapping chunks, dropping whitespace-only chunks.

    Uses LangChain's RecursiveCharacterTextSplitter when it is importable
    (separators: blank line, newline, space, character). Otherwise it falls back
    to fixed character windows of ``chunk_size`` that advance by
    ``chunk_size - chunk_overlap``. The step is clamped to at least 1, so a
    misconfigured overlap cannot loop forever. The defaults reproduce the
    previous hard-coded 1000/150 behavior.

    Returns:
        List of non-blank chunks ([] for empty input).
    """
    if not text:
        return []
    if RecursiveCharacterTextSplitter:
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap,
            separators=["\n\n", "\n", " ", ""],
        )
        return [c for c in splitter.split_text(text) if c.strip()]
    # fallback naive sliding window
    step = max(1, chunk_size - chunk_overlap)
    chunks = [text[i:i + chunk_size] for i in range(0, len(text), step)]
    return [c for c in chunks if c.strip()]

# --- Bulk embed OCR text for a given doc (to be called by server during ingestion) ---
def embed_doc_ocr_text_to_milvus(conversation_id: str, doc_id: str) -> dict:
    """
    Embed every OCR text page of a document into the TEXT collection.

    Loads OCR text files from GCS (``text/<conv>/<doc>/page_X.txt``), then chunks,
    embeds and upserts them. Each row's ``gcs_key`` is the matching IMAGE key.
    Files that cannot be read are skipped with a warning.

    Returns:
        ``{"pages": <number of text files listed>, "chunks": <chunks embedded>}``.
        Returns ``{"pages": 0, "chunks": 0}`` when no text files exist; the
        collection is not created in that case.
    """
    txt_keys = list_text_keys(conversation_id, doc_id)
    if not txt_keys:
        return {"pages": 0, "chunks": 0}

    txt_coll = ensure_text_collection(conversation_id)
    total_chunks = 0
    for tk in txt_keys:
        try:
            data = gcs.bucket(GCS_BUCKET).blob(tk).download_as_bytes()
            text = data.decode("utf-8", errors="ignore")
        except Exception as e:
            print(f"[WARN] failed to read {tk}: {e}")
            continue

        # Store the image key in Milvus so text hits resolve straight to page images.
        img_key = page_image_key_for_text_key(tk)
        for ci, ch in enumerate(chunk_text(text)):
            vecs = embed_text_to_vectors(ch)  # one pooled vector
            upsert_text_vectors_batched(txt_coll, doc_id, img_key, vecs, chunk_idx=ci)
            total_chunks += 1

    return {"pages": len(txt_keys), "chunks": total_chunks}

# ===== Query encoder (ColNomic text -> (S_q, 128)) =====
@torch.no_grad()
def encode_query_text(text: str) -> np.ndarray:
    """
    Encode a query string into its per-token embeddings.

    Uses ``process_queries`` when the processor provides it; otherwise it calls
    the processor directly. Returns a float32 array of shape (S_q, D), with all
    leading axes flattened.
    """
    device = next(embed_model.parameters()).device
    has_query_api = hasattr(embed_proc, "process_queries")
    raw_batch = embed_proc.process_queries([text]) if has_query_api else embed_proc(text=[text])
    outputs = embed_model(**raw_batch.to(device))
    tensor = _first_tensor(outputs)
    flat = tensor.detach().to("cpu", dtype=torch.float32).numpy()
    return flat.reshape(-1, flat.shape[-1])

# ===== ANN search and late-interaction aggregation (images) =====
def search_per_token(collection: str, qvec: np.ndarray, topk: int):
    """
    ANN search for a single query sub-vector in the image collection.

    Uses inner-product (IP) HNSW search. Milvus requires the HNSW ``ef`` search
    parameter to be at least the requested ``limit``, so ef is raised to ``topk``
    whenever PER_SUBVEC/topk exceeds HNSW_EF. Otherwise such a search would fail.

    Returns:
        List of ``{"score", "doc_id", "page", "gcs_key"}`` dicts in Milvus result order.
    """
    res = milvus.search(
        collection_name=collection,
        data=[qvec.tolist()],
        limit=topk,
        output_fields=["doc_id", "page", "seq_id", "gcs_key"],
        search_params={"metric_type": "IP", "params": {"ef": max(HNSW_EF, topk)}},
    )
    hits = res[0] if res else []
    return [
        {
            "score": float(h["distance"]),
            "doc_id": h["entity"]["doc_id"],
            "page": int(h["entity"]["page"]),
            "gcs_key": h["entity"]["gcs_key"],
        }
        for h in hits
    ]

def late_interaction_aggregate(query_vecs: np.ndarray, collection: str) -> list:
    """
    Rank pages by approximate late-interaction scoring over ANN hits.

    For each of the first MAX_QTOK query sub-vectors, PER_SUBVEC nearest
    neighbours are fetched. For every (doc_id, page), the best hit score is
    kept. Those per-token maxima are then summed across tokens. The gcs_key
    attached to a page is the one seen with its single highest per-token score.

    Returns:
        List of ``(score, {"doc_id", "page", "gcs_key"})`` sorted by score, descending.
    """
    totals: Dict[Tuple[str, int], float] = defaultdict(float)
    best_key: Dict[Tuple[str, int], Tuple[float, str]] = {}

    for token_vec in query_vecs[:MAX_QTOK]:
        token_best: Dict[Tuple[str, int], Tuple[float, str]] = {}
        for hit in search_per_token(collection, token_vec, topk=PER_SUBVEC):
            page_id = (hit["doc_id"], hit["page"])
            if hit["score"] > token_best.get(page_id, (-1e9, ""))[0]:
                token_best[page_id] = (hit["score"], hit["gcs_key"])
        for page_id, (score, key) in token_best.items():
            totals[page_id] += score
            if score > best_key.get(page_id, (-1e9, ""))[0]:
                best_key[page_id] = (score, key)

    ranked = [
        (
            total,
            {
                "doc_id": doc_id,
                "page": page_no,
                "gcs_key": best_key.get((doc_id, page_no), (-1e9, None))[1],
            },
        )
        for (doc_id, page_no), total in totals.items()
    ]
    ranked.sort(key=lambda item: item[0], reverse=True)
    return ranked

# ===== TEXT ANN search + aggregation to pages =====
def search_text(collection: str, qvec_1d: np.ndarray, limit: int = 200):
    """
    ANN search of a single pooled query vector against the TEXT collection.

    Uses inner-product (IP) HNSW search. The HNSW ``ef`` search parameter must be
    at least ``limit`` in Milvus, so it is raised to ``limit`` when HNSW_EF is
    smaller. Otherwise the search would fail.

    Returns:
        List of ``{"score", "doc_id", "page", "gcs_key"}`` dicts in Milvus result
        order. Here ``gcs_key`` is the IMAGE page key stored at ingestion.
    """
    res = milvus.search(
        collection_name=collection,
        data=[qvec_1d.tolist()],
        limit=limit,
        output_fields=["doc_id", "page", "gcs_key", "seq_id"],
        search_params={"metric_type": "IP", "params": {"ef": max(HNSW_EF, limit)}},
    )
    hits = res[0] if res else []
    return [
        {
            "score": float(h["distance"]),
            "doc_id": h["entity"]["doc_id"],
            "page": int(h["entity"]["page"]),
            "gcs_key": h["entity"]["gcs_key"],
        }
        for h in hits
    ]

def aggregate_text_to_pages(hits: List[dict]) -> List[dict]:
    """
    Sum text-hit scores per ``(doc_id, page, gcs_key)``.

    Returns:
        ``[{"key": (doc_id, page, gcs_key), "score": total}, ...]`` sorted by total,
        descending. Ties keep first-seen order.
    """
    totals: Dict[tuple, float] = {}
    for hit in hits:
        page_key = (hit["doc_id"], hit["page"], hit["gcs_key"])
        totals[page_key] = totals.get(page_key, 0.0) + hit["score"]
    return sorted(
        ({"key": page_key, "score": total} for page_key, total in totals.items()),
        key=lambda row: row["score"],
        reverse=True,
    )

# ===== Reciprocal Rank Fusion =====
def rrf(runs: List[List[tuple]], k: int = 60, c: int = 60) -> List[tuple]:
    """
    Reciprocal Rank Fusion of several ranked key lists.

    Each key scores ``sum(1 / (c + rank))`` over every run it appears in (rank is 1-based).

    Args:
        runs: ranked lists of hashable keys (e.g. ``(doc_id, page, gcs_key)``), best first.
        k: maximum number of fused results to return (NOT the RRF constant).
        c: the RRF smoothing constant.

    Returns:
        List of ``(key, fused_score)``, sorted descending. Ties keep first-seen order.
    """
    fused_scores: Dict = {}
    for ranking in runs:
        for position, key in enumerate(ranking):
            fused_scores[key] = fused_scores.get(key, 0.0) + 1.0 / (c + position + 1)
    ordered = sorted(fused_scores.items(), key=lambda item: -item[1])
    return ordered[:k]

# ===== Hybrid retrieval (image late-interaction + text ANN + RRF) =====
def hybrid_retrieval(conv_id: str, query: str, top_pages: int):
    """
    Retrieve the best page images for ``query`` by fusing two rankings with RRF.

    1. IMAGE: the per-token query embeddings are scored by late interaction
       against the image collection.
    2. TEXT: the mean-pooled query embedding is searched in the OCR-text
       collection, and the hits are aggregated per page.

    Both collections are created on demand.

    Returns:
        ``(gcs_keys, debug)``:
        - ``gcs_keys`` holds the IMAGE GCS keys of the top ``top_pages`` fused pages.
        - ``debug`` holds the truncated ``img_ranked``/``txt_ranked`` lists and the
          fused ``(key, score)`` pairs under ``fused``.
    """
    query_tokens = encode_query_text(query)   # (S,128)
    pooled_query = query_tokens.mean(axis=0)  # pooled (128,) for text ANN

    image_collection = ensure_collection(conv_id)
    text_collection = ensure_text_collection(conv_id)

    image_ranked = late_interaction_aggregate(query_tokens, image_collection)
    # Ordered, de-duplicated page keys for RRF (dict preserves first-seen order).
    image_keys = list(dict.fromkeys(
        (meta["doc_id"], meta["page"], meta["gcs_key"]) for _, meta in image_ranked
    ))

    text_hits = search_text(text_collection, pooled_query, limit=PER_SUBVEC * 2)
    text_ranked = aggregate_text_to_pages(text_hits)
    text_keys = [entry["key"] for entry in text_ranked]

    fused = rrf([image_keys, text_keys], k=max(top_pages * 4, 40))
    top = fused[:top_pages]
    gcs_keys = [page_key[2] for page_key, _score in top]

    debug = {
        "img_ranked": image_ranked[:top_pages],
        "txt_ranked": text_ranked[:top_pages],
        "fused": top,
    }
    return gcs_keys, debug

# ===== Postgres: fetch short chat history =====
def _pg_conn():
    """
    Open a new psycopg2 connection using PG_CONN_STR (10-second connect timeout).

    Raises:
        RuntimeError: if PG_CONN_STR is empty.

    The caller is responsible for closing the returned connection.
    """
    if PG_CONN_STR:
        return psycopg2.connect(PG_CONN_STR, connect_timeout=10)
    raise RuntimeError("PG_CONN_STR not set in env")

def fetch_chat_history(conv_id: str, limit: int = HISTORY_TURNS) -> list[dict]:
    """
    Return the most recent ``limit`` messages of a conversation in chronological order.

    Each item is ``{"sender", "content", "created_at"}``. Table schema assumed:
      chat_messages(id, conversation_id, sender, content, metadata, created_at)
    sender is either 'human' or 'AI'.

    Returns [] for an empty conv_id without touching the database. Errors from
    ``_pg_conn``/psycopg2 propagate (e.g. RuntimeError when PG_CONN_STR is unset).
    """
    if not conv_id:
        return []
    q = """
        SELECT sender, content, created_at
        FROM chat_messages
        WHERE conversation_id = %s
        ORDER BY created_at DESC
        LIMIT %s
    """
    conn = _pg_conn()
    try:
        # `with conn` only scopes the transaction (commit/rollback). psycopg2 does
        # not close the connection on exit, so it is closed explicitly below to
        # avoid leaking one connection per request.
        with conn, conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
            cur.execute(q, (conv_id, limit))
            rows = cur.fetchall() or []
    finally:
        conn.close()
    rows = list(reversed(rows))  # newest-first query -> chronological order
    return [
        {
            "sender": (r["sender"] or "").strip(),
            "content": r.get("content") or "",
            "created_at": r.get("created_at"),
        }
        for r in rows
    ]

# ===== Gemini 2.5 Pro (LangChain) =====
# Requires env GOOGLE_API_KEY
# Model, sampling temperature and output budget are read from the env-setup cell
# (GEMINI_MODEL, VLM_TEMPERATURE, MAX_NEW_TOKENS), so there is one source of truth.
# The fallbacks are the values previously hard-coded here.
# NOTE(review): the env cell sets GEMINI_MODEL="gemini-2.5-pro". If that model spends
# part of max_output_tokens on internal reasoning, 1028 may be too small — confirm.
gemini_llm = ChatGoogleGenerativeAI(
    model=os.environ.get("GEMINI_MODEL", "gemini-2.0-flash"),
    temperature=float(os.environ.get("VLM_TEMPERATURE", "0")),
    max_output_tokens=int(os.environ.get("MAX_NEW_TOKENS", "1028")),
)

def _pil_to_data_url(img: Image.Image) -> str:
    """Encode a PIL image as a ``data:image/png;base64,...`` URL for Gemini vision inputs."""
    with io.BytesIO() as buffer:
        img.save(buffer, format="PNG")
        encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return "data:image/png;base64," + encoded

def build_messages(state: dict) -> list:
    """
    Build the Gemini multi-modal message list for one turn.

    Layout:
    - one SystemMessage with the evidence-bound instructions;
    - prior turns from ``state["history"]`` as text-only messages. A sender of
      "human" (case-insensitive) becomes a HumanMessage, anything else an
      AIMessage, and blank contents are skipped;
    - a final HumanMessage whose parts are the question, then for each image:
      its OCR text (when the matching ``state["texts"]`` entry is non-empty)
      followed by the image as a data URL.

    Raises:
        KeyError: if ``state`` has no "question".
    """
    system_prompt = (
        "You are an expert assistant for analyzing insurance policy documents from page images.\n"
        "\n"
        "Core rules:\n"
        "1) Evidence-bound: Use ONLY the provided page images and their OCR text.\n"
        "   • The OCR text is an automated transcript of the images — use it as assistance to read and verify text.\n"
        "   • If there is any difference between the OCR text and the image, make the determination which one takes precedence.\n"
        "2) If the answer is not visible, say you cannot find it. Never guess.\n"
        "3) Output style: provide a VERY DETAILED, comprehensive response in clear prose and bullet lists.\n"
        "   • Do NOT be brief. Avoid short summaries.\n"
        "4) Quote numbers/identifiers exactly as printed (benefit limits, premiums, deductibles, co-insurance %, policy/endorsement numbers,\n"
        "   dates, plan codes). Preserve units/currency and capitalization as shown.\n"
        "5) Robust terminology: If the requested term is not found verbatim, search for close synonyms or equivalents on the images (e.g.,\n"
        "   'policy no', 'policy number', 'policy #', 'plan code', 'certificate no', 'endorsement', 'benefit limit', 'co-insurance', 'coinsurance').\n"
        "   Return the closest explicit match from the images. If still unclear, state that explicitly.\n"
        "6) Multiple candidates: If several plausible values appear, report each with its nearby label/heading and explain the ambiguity instead of choosing arbitrarily.\n"
        "7) Page awareness: When helpful, reference where you found information (e.g., 'Page 2 – Schedule of Benefits – Room & Board').\n"
        "8) Conversation history: You may use prior conversation ONLY to understand intent. Do NOT treat prior messages as evidence.\n"
        "9) Fidelity/uncertainty: If characters are ambiguous (e.g., 0 vs O, 1 vs I), cross-check nearby labels or repeated occurrences on the provided pages.\n"
        "   If uncertainty remains, say so (e.g., 'appears to be 75105, but could be 7S105').\n"
        "10) Coverage depth: Go deep. When answering about a benefit/plan/endorsement, include relevant conditions, limits, sub-limits, co-insurance,\n"
        "    deductibles, waiting periods, exclusions, and any rider/endorsement interactions that are visible on the pages.\n"
        "11) Formatting:\n"
        "    • Use short headings and bullets to organize a long, detailed answer.\n"
        "    • Keep labels from the document when they aid clarity (e.g., 'Co-insurance by Insured Member: 20%').\n"
        "    • Ensure **clear spacing** between sections and details:\n"
        "        - Always add a blank line between different headings.\n"
        "        - Always add a blank line between bullet points.\n"
        "        - Always add a blank line between paragraphs.\n"
        "    • Avoid tables unless explicitly requested; expand details in prose/bullets instead.\n"
        "12) Scope control: If the question spans multiple plans or sections, provide a thorough breakdown per plan/section with all visible figures and constraints.\n"
        "13) No preambles or system scaffolding; answer directly with detailed content.\n"
        "\n"
        "Task-specific guidance:\n"
        "• Numbers/IDs: Return exact strings as printed (including hyphens/spaces). If the user asks for a number (e.g., an endorsement or policy number),\n"
        "  provide the exact value and indicate where it was read.\n"
        "• Definitions: When terms are defined on the page (e.g., 'Any One Disability', 'Deductible'), include those definitions verbatim or closely paraphrased.\n"
        "• Comparisons: When comparing plans/benefits, include all relevant differences (limits, coinsurance, eligibility, notes) in detailed prose; do not condense.\n"
        "• Missing/Not covered: If the document states a benefit is excluded or not applicable, say that clearly and point to the wording.\n"
        "• Dates/periods: Preserve visible formats (e.g., DD/MM/YYYY vs MM/DD/YYYY) exactly.\n"
        "\n"
        "If a direct answer is unavailable on the provided images, say: 'I can't find this in the provided pages.' Optionally suggest likely sections to check\n"
        "(e.g., 'Schedule of Benefits', 'Endorsement', 'Policy Schedule', 'General Provisions').\n"
    )

    messages = [SystemMessage(content=system_prompt)]

    # Prior turns are context only (the prompt forbids using them as evidence).
    for turn in state.get("history", []):
        body = (turn.get("content") or "").strip()
        if not body:
            continue
        speaker = (turn.get("sender") or "").strip().lower()
        message_cls = HumanMessage if speaker == "human" else AIMessage
        messages.append(message_cls(content=body))

    question = state["question"]
    images = state.get("images", [])
    ocr_texts = state.get("texts", [])

    content = [{"type": "text", "text": question}]
    for idx, image in enumerate(images):
        ocr = ocr_texts[idx] if idx < len(ocr_texts) else ""
        if ocr:
            content.append({"type": "text", "text": f"OCR Page {idx+1}:\n{ocr}"})
        content.append({"type": "image_url", "image_url": _pil_to_data_url(image)})

    messages.append(HumanMessage(content=content))
    return messages

# ===== Retrieval step for RunnableMap (now hybrid & sends OCR text too) =====
def _trim_history(rows: list, max_chars: int) -> list:
    """Keep the newest chronological history rows whose combined ``content`` fits in ``max_chars``.

    Walks backwards from the newest row. If even the newest row is longer than
    the budget, only its last ``max_chars`` characters are kept. ``max_chars <= 0``
    disables the cap, following the same "0 = no limit" convention as PAGE_LIMIT.
    """
    if max_chars <= 0:
        return list(rows)
    kept, used = [], 0
    for row in reversed(rows):
        content = row.get("content") or ""
        remaining = max_chars - used
        if len(content) > remaining:
            if not kept and remaining > 0:
                kept.append({**row, "content": content[-remaining:]})
            break
        kept.append(row)
        used += len(content)
    kept.reverse()
    return kept


def retrieval_step(state: dict) -> dict:
    """
    Gather everything the VLM needs for one question.

    Input:  {"question": str, "conversation_id": str}
    Output: {"question", "conversation_id", "history", "images", "texts", "pages_used", "retrieval"}

    - history: recent chat rows, capped to HISTORY_CHARS characters of content.
      A failed fetch is logged and treated as no history.
    - images / texts: index-aligned page images and their OCR text. Pages whose
      image cannot be downloaded are dropped from both lists and from
      ``pages_used``, so ``texts[i]`` always belongs to ``images[i]``.
    - pages_used: IMAGE GCS keys of the pages actually shown to the VLM.
    - retrieval: debug info from ``hybrid_retrieval`` (img_ranked / txt_ranked / fused).
    """
    q = state["question"]
    conv_id = state["conversation_id"]

    # 1) History for the VLM (text-only)
    try:
        history_rows = fetch_chat_history(conv_id, limit=HISTORY_TURNS)
    except Exception as e:
        print(f"[WARN] history fetch failed: {e}")
        history_rows = []
    history_rows = _trim_history(history_rows, HISTORY_CHARS)

    # 2) Hybrid retrieval (image + text via RRF)
    gcs_keys, debug = hybrid_retrieval(conv_id, q, top_pages=TOP_PAGES)

    # 3) Download page-by-page so a failed image does not shift later images
    #    against their OCR text (download_images_for_keys skips failures).
    imgs, used_keys = [], []
    for key in gcs_keys:
        downloaded = download_images_for_keys([key])
        if downloaded:
            imgs.append(downloaded[0])
            used_keys.append(key)
    txts = download_text_for_keys(used_keys)

    return {
        "question": q,
        "conversation_id": conv_id,
        "history": history_rows,
        "images": imgs,
        "texts": txts,
        "retrieval": debug,          # includes img_ranked / txt_ranked / fused
        "pages_used": used_keys,     # list of IMAGE GCS keys actually sent to the VLM
    }

@traceable(name="gemini_invoke")
def call_gemini_with_messages(msgs: list):
    """Send a prepared message list to ``gemini_llm`` and return only the reply's ``content``.

    Wrapped with LangSmith ``traceable`` under the run name "gemini_invoke".
    """
    response = gemini_llm.invoke(msgs)
    return response.content

# ===== RunnableMap pipeline =====
# End-to-end chain; invoke with {"question": str, "conversation_id": str}.
#   1) The first RunnableMap copies just those two keys from the input.
#   2) retrieval_step (a plain function, coerced into a runnable by `|`) adds
#      history, page images, OCR texts and retrieval debug info.
#   3) The final RunnableMap returns {"answer", "pages_used", "retrieval"}, where
#      "answer" is Gemini's reply to the prompt built by build_messages.
pipeline = (
    RunnableMap({
        "question":        lambda inp: inp["question"],
        "conversation_id": lambda inp: inp["conversation_id"],
    })
    | retrieval_step
    | RunnableMap({
        "answer":     lambda s: call_gemini_with_messages(build_messages(s)),
        "pages_used": lambda s: s["pages_used"],
        "retrieval":  lambda s: s["retrieval"],
    })
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Fetching 7 files:   0%|          | 0/7 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

Some weights of the model checkpoint at nomic-ai/colqwen2.5-7B-base were not used when initializing ColQwen2_5: ['lm_head.weight']
- This IS expected if you are initializing ColQwen2_5 from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ColQwen2_5 from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
You have video processor config saved in `preprocessor.json` file which is deprecated. Vid

In [None]:
# =======================
# MercerChat GPU Backend (Cell 1/2) (IMAGE-BASED RETRIEVAL)
# Setup only: env, clients, models, helpers, and LangChain-style Runnable (Gemini 2.5 Pro)
# =======================

# --- Imports ---
import os, io, re, time, traceback, base64
from typing import List, Dict, Tuple
from collections import defaultdict

import torch
import numpy as np
from PIL import Image
from google.cloud import storage
from pymilvus import MilvusClient, DataType
import psycopg2, psycopg2.extras

# ColNomic embeddings (retriever)
from colpali_engine.models import ColQwen2_5, ColQwen2_5_Processor

# LangChain / Gemini / LangSmith
from langchain_core.runnables import RunnableMap
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from langsmith.run_helpers import traceable

# Optional splitter (installed in Cell 1)
try:
    from langchain_text_splitters import RecursiveCharacterTextSplitter
except Exception:
    RecursiveCharacterTextSplitter = None

# ===== Env =====
# Connection settings from the environment (set in the setup cell); an empty
# string means "not configured" — nothing here validates them.
GPU_BEARER_TOKEN = os.environ.get("GPU_BEARER_TOKEN", "")
GCS_BUCKET       = os.environ.get("GCS_BUCKET", "")
ZILLIZ_URI       = os.environ.get("ZILLIZ_URI", "")
ZILLIZ_TOKEN     = os.environ.get("ZILLIZ_TOKEN", "")
NGROK_AUTHTOKEN  = os.environ.get("NGROK_AUTHTOKEN", "")
PG_CONN_STR      = os.environ.get("PG_CONN_STR", "")

# Tunables (int() raises ValueError when this cell runs if a value is non-numeric)
# NOTE(review): the setup cell exports TOP_K_DEFAULT / MAX_PAGES_FOR_VLM_DEFAULT,
# which this cell never reads; TOP_PAGES below falls back to 16 unless set.
PAGE_LIMIT       = int(os.environ.get("PAGE_LIMIT", "0"))        # 0 = no limit
PER_SUBVEC       = int(os.environ.get("PER_SUBVEC", "150"))      # ANN hits per query subvector
HNSW_EF          = int(os.environ.get("HNSW_EF", "512"))         # ef at search time
TOP_PAGES        = int(os.environ.get("TOP_PAGES", "16"))         # pages to show VLM
MAX_QTOK         = int(os.environ.get("MAX_QTOK", "32"))         # truncate query subvectors
DIM              = 128                                           # ColNomic projection dim
CLEAR_CUDA_EACH  = os.environ.get("CLEAR_CUDA_EACH", "1") == "1" # per-request GPU cleanup
HISTORY_TURNS    = int(os.environ.get("HISTORY_TURNS", "2"))     # how many past messages to fetch
# NOTE(review): HISTORY_CHARS is not referenced anywhere else in this cell, so
# history is currently injected untruncated — confirm whether it should be wired in.
HISTORY_CHARS    = int(os.environ.get("HISTORY_CHARS", "1500"))  # max chars of history injected

# ===== GCS =====
# Client built with default credentials (presumably the service-account key
# pointed to by GOOGLE_APPLICATION_CREDENTIALS in the setup cell).
gcs = storage.Client()
# Matches keys ending in "/page_<N>.png" (case-insensitive) and captures N.
page_num_re = re.compile(r"/page_(\d+)\.png$", re.IGNORECASE)

def list_page_keys(conversation_id: str, doc_id: str) -> List[str]:
    """Return the GCS keys of a document's page PNGs, ordered by page number.

    Lists ``pages/<conversation_id>/<doc_id>/`` in GCS_BUCKET, keeps only
    blobs named ``.../page_<N>.png`` and sorts them numerically by N.  A
    non-zero PAGE_LIMIT truncates the list (0 = no cap).
    """
    listing_prefix = f"pages/{conversation_id}/{doc_id}/"
    page_keys = [
        blob.name
        for blob in gcs.list_blobs(GCS_BUCKET, prefix=listing_prefix)
        if blob.name.endswith(".png") and page_num_re.search(blob.name)
    ]

    def page_index(key):
        # Unparseable names would sort last; the filter above makes this a guard only.
        match = page_num_re.search(key)
        return int(match.group(1)) if match is not None else 10**9

    page_keys.sort(key=page_index)
    if PAGE_LIMIT and len(page_keys) > PAGE_LIMIT:
        page_keys = page_keys[:PAGE_LIMIT]
    return page_keys

def list_text_keys(conversation_id: str, doc_id: str) -> List[str]:
    """List OCR text page files saved by the first backend.

    Returns every ``.txt`` key under ``text/<conversation_id>/<doc_id>/`` in
    GCS_BUCKET, sorted by the N in ``page_<N>.txt``; keys without that
    pattern sort after all numbered pages.  A non-zero PAGE_LIMIT truncates
    the list.
    """
    listing_prefix = f"text/{conversation_id}/{doc_id}/"
    text_keys = [
        blob.name
        for blob in gcs.list_blobs(GCS_BUCKET, prefix=listing_prefix)
        if blob.name.endswith(".txt")
    ]
    text_page_re = re.compile(r"/page_(\d+)\.txt$", re.IGNORECASE)

    def page_index(key):
        match = text_page_re.search(key)
        return int(match.group(1)) if match is not None else 10**9

    text_keys.sort(key=page_index)
    if PAGE_LIMIT and len(text_keys) > PAGE_LIMIT:
        text_keys = text_keys[:PAGE_LIMIT]
    return text_keys

def page_text_key_for_image_key(image_key: str) -> str:
    """Map an image page key to its OCR text key.

    ``pages/<conv>/<doc>/page_X.png  ->  text/<conv>/<doc>/page_X.txt``

    Only the leading ``pages/`` prefix and the trailing ``.png`` suffix are
    rewritten, so conversation/doc ids that happen to contain ``pages/`` or
    ``.png`` (e.g. ``pages/mypages/...``) are preserved.  A key that does not
    follow the layout has only the matching part(s) rewritten.
    """
    key = image_key
    if key.startswith("pages/"):
        key = "text/" + key[len("pages/"):]
    if key.endswith(".png"):
        key = key[: -len(".png")] + ".txt"
    return key

def page_image_key_for_text_key(text_key: str) -> str:
    """Map an OCR text key to its image page key.

    ``text/<conv>/<doc>/page_X.txt  ->  pages/<conv>/<doc>/page_X.png``

    Only the leading ``text/`` prefix and the trailing ``.txt`` suffix are
    rewritten.  A plain ``str.replace`` would also hit ids containing those
    substrings (``text/context/...`` would become ``pages/conpages/...``).
    """
    key = text_key
    if key.startswith("text/"):
        key = "pages/" + key[len("text/"):]
    if key.endswith(".txt"):
        key = key[: -len(".txt")] + ".png"
    return key

def download_images_for_keys(keys: List[str]) -> List[Image.Image]:
    """Download page PNGs from GCS_BUCKET and decode them as RGB PIL images.

    Keys that fail to download or decode are logged and skipped, so the
    result can be shorter than ``keys`` and is then not index-aligned with it.
    """
    images = []
    for key in keys:
        try:
            raw = gcs.bucket(GCS_BUCKET).blob(key).download_as_bytes()
            decoded = Image.open(io.BytesIO(raw))
            images.append(decoded.convert("RGB"))
        except Exception as exc:
            print(f"[WARN] could not download {key}: {exc}")
    return images

def download_text_for_keys(image_keys: List[str]) -> List[str]:
    """
    Given image page keys, fetch matching OCR text files.

    Each image key is mapped to its ``text/.../page_X.txt`` counterpart and
    downloaded from GCS_BUCKET.  The result is index-aligned with
    ``image_keys``: a page whose text cannot be fetched contributes "" (and a
    warning is printed).

    Bytes are decoded as UTF-8 with ``errors="ignore"``, the same policy the
    ingestion path (``embed_doc_ocr_text_to_milvus``) uses, so one stray
    invalid byte no longer discards the whole page's OCR text.
    """
    texts: List[str] = []
    for k in image_keys:
        try:
            txt_key = page_text_key_for_image_key(k)
            data = gcs.bucket(GCS_BUCKET).blob(txt_key).download_as_bytes()
        except Exception as e:
            print(f"[WARN] could not download OCR text for {k}: {e}")
            texts.append("")
            continue
        texts.append(data.decode("utf-8", errors="ignore"))
    return texts

# ===== Zilliz/Milvus (vector store) =====
# One shared client, created when this cell runs, used for both the image and
# the OCR-text collections.
milvus = MilvusClient(uri=ZILLIZ_URI, token=ZILLIZ_TOKEN)

def coll_name_for(conversation_id: str) -> str:
    """Derive the Milvus image-collection name for a conversation.

    Every run of characters outside ``[A-Za-z0-9_]`` becomes one underscore;
    if the result does not start with a letter it gets a ``c_`` prefix; then
    ``conv_`` is prepended and the name is cut to 255 characters.  Keep this
    mapping stable — existing collections are looked up by it.
    """
    sanitized = re.sub(r'[^A-Za-z0-9_]+', '_', str(conversation_id))
    starts_with_letter = re.match(r'^[A-Za-z]', sanitized) is not None
    if not starts_with_letter:
        sanitized = 'c_' + sanitized
    full_name = 'conv_' + sanitized
    return full_name[:255]

def ensure_collection(conv_id: str) -> str:
    """Return the image-vector collection for ``conv_id``, creating it if absent.

    A new collection holds one row per page sub-vector: an auto-generated
    INT64 ``pk``, a DIM-wide ``vector`` and the ``doc_id`` / ``page`` /
    ``seq_id`` / ``gcs_key`` scalars.  ``vector`` gets an HNSW index with
    inner-product metric, ``doc_id`` an INVERTED index, and the collection is
    loaded before returning.  An existing collection is returned as-is (its
    schema and load state are not checked).
    """
    collection = coll_name_for(conv_id)
    if milvus.has_collection(collection):
        return collection
    # pymilvus reads `enable_dynamic_field` (singular); the plural spelling
    # used previously was silently ignored, leaving the dynamic field off.
    schema = milvus.create_schema(auto_id=True, enable_dynamic_field=True)
    schema.add_field("pk",       DataType.INT64,   is_primary=True)
    schema.add_field("vector",   DataType.FLOAT_VECTOR, dim=DIM)
    schema.add_field("doc_id",   DataType.VARCHAR, max_length=64)
    schema.add_field("page",     DataType.INT32)
    schema.add_field("seq_id",   DataType.INT32)   # sub-vector index within the page
    schema.add_field("gcs_key",  DataType.VARCHAR, max_length=512)
    milvus.create_collection(collection, schema=schema)
    idx = milvus.prepare_index_params()
    idx.add_index(field_name="vector", index_name="vec", index_type="HNSW",
                  metric_type="IP", params={"M": 16, "efConstruction": 200})
    idx.add_index(field_name="doc_id", index_name="doc", index_type="INVERTED")
    milvus.create_index(collection, index_params=idx, sync=True)
    milvus.load_collection(collection)
    return collection

# ===== Extra TEXT collection for OCR chunk vectors =====
def coll_name_text_for(conversation_id: str) -> str:
    """Derive the OCR-text collection name: the image collection name + ``_txt``.

    ``coll_name_for`` caps its result at 255 characters (Milvus' default
    maximum name length), so the base is trimmed to leave room for the
    suffix; otherwise very long conversation ids would yield an invalid
    256-259 character name.  Shorter names are unchanged (``conv_abc123_txt``).
    """
    suffix = "_txt"
    base = coll_name_for(conversation_id)  # e.g., conv_abc123
    return base[: 255 - len(suffix)] + suffix

def ensure_text_collection(conv_id: str) -> str:
    """Return the OCR-text collection for ``conv_id``, creating it if absent.

    Same layout as the image collection: auto-generated INT64 ``pk``, a
    DIM-wide ``vector`` (one pooled vector per OCR chunk), ``doc_id``,
    ``page``, ``seq_id`` (text-chunk index) and ``gcs_key`` — which stores the
    IMAGE page key so a text hit maps straight to a page PNG.  HNSW/IP index
    on ``vector``, INVERTED on ``doc_id``; loaded before returning.  An
    existing collection is returned as-is.
    """
    collection = coll_name_text_for(conv_id)
    if milvus.has_collection(collection):
        return collection
    # pymilvus reads `enable_dynamic_field` (singular); the plural spelling
    # used previously was silently ignored, leaving the dynamic field off.
    schema = milvus.create_schema(auto_id=True, enable_dynamic_field=True)
    schema.add_field("pk",       DataType.INT64,   is_primary=True)
    schema.add_field("vector",   DataType.FLOAT_VECTOR, dim=DIM)
    schema.add_field("doc_id",   DataType.VARCHAR, max_length=64)
    schema.add_field("page",     DataType.INT32)
    schema.add_field("seq_id",   DataType.INT32)  # text-chunk index
    schema.add_field("gcs_key",  DataType.VARCHAR, max_length=512)  # store IMAGE page key here
    milvus.create_collection(collection, schema=schema)
    idx = milvus.prepare_index_params()
    idx.add_index(field_name="vector", index_name="vec", index_type="HNSW",
                  metric_type="IP", params={"M": 16, "efConstruction": 200})
    idx.add_index(field_name="doc_id", index_name="doc", index_type="INVERTED")
    milvus.create_index(collection, index_params=idx, sync=True)
    milvus.load_collection(collection)
    return collection

# ===== Embedding Model (ColNomic) for retrieval =====
# Loaded once when this cell runs: bfloat16 weights on CUDA, float32 on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# 7B checkpoint. DIM (128) must match its output width; embed_image_to_vectors
# raises ValueError on a mismatch.
MODEL_NAME = "nomic-ai/colnomic-embed-multimodal-7b"

embed_model = ColQwen2_5.from_pretrained(
    MODEL_NAME,
    torch_dtype=(torch.bfloat16 if DEVICE == "cuda" else torch.float32),
    device_map=DEVICE if DEVICE == "cuda" else None,  # CPU: no device_map
).eval()  # inference mode (e.g. disables dropout)
# Processor shared by the image path (process_images) and the text/query path.
embed_proc = ColQwen2_5_Processor.from_pretrained(MODEL_NAME)

def _first_tensor(t):
    if torch.is_tensor(t):
        return t
    if isinstance(t, (list, tuple)):
        for x in t:
            if torch.is_tensor(x):
                return x
    return None

def _to_numpy_2d(t):
    """
    Flatten a model output into a 2-D float32 NumPy array of sub-vectors.

    Returns ``(vecs, D)``: D is the size of the tensor's last axis and
    ``vecs`` has shape (N, D), where N is the product of all leading axes
    (batch and sequence), so every late-interaction sub-vector is kept.

    Raises ValueError when ``_first_tensor`` finds no tensor in ``t``.
    """
    tensor = _first_tensor(t)
    if tensor is None:
        raise ValueError("No tensor found in model output")
    host = tensor.detach().to("cpu", dtype=torch.float32).numpy()
    width = host.shape[-1]
    return host.reshape(-1, width), width

@torch.no_grad()
def embed_image_to_vectors(img: Image.Image):
    """
    Embed one page image into its late-interaction sub-vectors.

    Returns ``(vectors, D)``: ``vectors`` is a list with one list of Python
    floats per sub-vector, each of length D.  Raises ValueError if D != DIM.
    """
    device = next(embed_model.parameters()).device
    inputs = embed_proc.process_images([img]).to(device)
    vecs_2d, dim = _to_numpy_2d(embed_model(**inputs))  # (S, D)
    if dim != DIM:
        raise ValueError(f"Embedding dim {dim} != expected {DIM}")
    return vecs_2d.tolist(), dim

# Text embeddings for OCR chunks: mean-pool to single 128-d per chunk
@torch.no_grad()
def embed_text_to_vectors(text: str) -> List[List[float]]:
    """
    Embed a text chunk as one mean-pooled DIM-wide vector.

    The text goes through the same processor path as queries
    (``process_queries`` when the processor has it); the per-token outputs
    are averaged and wrapped in a one-element list so callers can treat the
    result like the multi-vector image output.

    NOTE(review): the pooled mean is not re-normalized, so inner-product
    scores in the text collection also depend on its norm.
    """
    device = next(embed_model.parameters()).device
    if hasattr(embed_proc, "process_queries"):
        inputs = embed_proc.process_queries([text])
    else:
        inputs = embed_proc(text=[text])
    output = embed_model(**inputs.to(device))
    token_vecs = (
        _first_tensor(output)
        .detach()
        .to("cpu", dtype=torch.float32)
        .numpy()
        .reshape(-1, DIM)
    )  # (S, DIM)
    return [token_vecs.mean(axis=0).tolist()]

def upsert_vectors_batched(collection: str, doc_id: str, page_key: str,
                           vectors: List[List[float]], batch_size: int = 4096) -> int:
    """
    Insert every sub-vector of one page PNG into the image collection.

    Each row carries the page number parsed from ``page_key`` (0 if it does
    not match ``page_<N>.png``), the sub-vector's position as ``seq_id`` and
    ``page_key`` itself.  Rows are sent in ``batch_size`` slices, then the
    collection is flushed so they are searchable.  Despite the name this is a
    plain insert: re-ingesting a page adds duplicate rows.

    Returns the number of rows sent (0, without touching Milvus, for no vectors).
    """
    page_match = page_num_re.search(page_key)
    page_num = int(page_match.group(1)) if page_match else 0
    if len(vectors) == 0:
        return 0

    rows = [
        {
            "vector": vec,
            "doc_id": str(doc_id),
            "page": page_num,
            "seq_id": idx,
            "gcs_key": page_key,
        }
        for idx, vec in enumerate(vectors)
    ]
    sent = 0
    for start in range(0, len(rows), batch_size):
        chunk = rows[start:start + batch_size]
        milvus.insert(collection, chunk)
        sent += len(chunk)
    # Make inserts visible to search immediately
    milvus.flush(collection)
    return sent

def upsert_text_vectors_batched(collection: str, doc_id: str, page_key: str,
                                vectors: List[List[float]], chunk_idx: int,
                                batch_size: int = 4096) -> int:
    """
    Insert the pooled vector(s) of one OCR chunk into the TEXT collection.

    'page_key' should be the **IMAGE** page key so later we can fetch the
    image directly; its ``page_<N>.png`` number (0 if absent) becomes the
    row's ``page``.  The i-th vector gets ``seq_id = chunk_idx * 10000 + i``.
    Rows are sent in ``batch_size`` slices and the collection is flushed
    afterwards, even when ``vectors`` is empty.  Returns the rows sent.
    """
    match = page_num_re.search(page_key)
    page_num = int(match.group(1)) if match else 0
    base_seq = chunk_idx * 10000  # chunk index spacing
    rows = [
        {
            "vector": vec,
            "doc_id": str(doc_id),
            "page": page_num,
            "seq_id": base_seq + offset,
            "gcs_key": page_key,
        }
        for offset, vec in enumerate(vectors)
    ]
    sent = 0
    for start in range(0, len(rows), batch_size):
        batch = rows[start:start + batch_size]
        milvus.insert(collection, batch)
        sent += len(batch)
    # Flush to make searchable
    milvus.flush(collection)
    return sent

# --- OCR text chunking helper ---
def chunk_text(text: str, chunk_size: int = 1000, chunk_overlap: int = 150) -> List[str]:
    """
    Split OCR text into overlapping chunks for embedding.

    Uses LangChain's RecursiveCharacterTextSplitter when it imported
    successfully, otherwise a fixed-width character window.  Whitespace-only
    chunks are dropped.

    Args:
        text: OCR text of one page; empty or None gives [].
        chunk_size: maximum characters per chunk (default 1000).
        chunk_overlap: characters shared by consecutive chunks (default 150);
            must satisfy 0 <= chunk_overlap < chunk_size.

    Raises:
        ValueError: for an invalid size/overlap pair.  With
            ``chunk_overlap >= chunk_size`` the fallback window would never
            advance and would loop forever.
    """
    if chunk_size <= 0 or not 0 <= chunk_overlap < chunk_size:
        raise ValueError(
            "need chunk_size > 0 and 0 <= chunk_overlap < chunk_size, "
            f"got chunk_size={chunk_size}, chunk_overlap={chunk_overlap}"
        )
    if not text:
        return []
    if RecursiveCharacterTextSplitter:
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap, separators=["\n\n", "\n", " ", ""]
        )
        return [c for c in splitter.split_text(text) if c.strip()]
    # Fallback: fixed window of chunk_size characters, advancing by
    # chunk_size - chunk_overlap so neighbours share chunk_overlap characters.
    step = chunk_size - chunk_overlap
    windows = (text[i:i + chunk_size] for i in range(0, len(text), step))
    return [c for c in windows if c.strip()]

# --- Bulk embed OCR text for a given doc (to be called by server during ingestion) ---
def embed_doc_ocr_text_to_milvus(conversation_id: str, doc_id: str) -> dict:
    """
    Index a document's OCR text into the conversation's TEXT collection.

    Loads OCR text files from GCS (text/<conv>/<doc>/page_X.txt), splits each
    page with ``chunk_text``, embeds every chunk as one pooled vector and
    inserts it with ``gcs_key`` set to the matching IMAGE key, so text hits
    map straight back to page images.  Pages that cannot be read are logged
    and skipped.

    Returns ``{"pages": <text files listed>, "chunks": <chunks inserted>}``.
    The collection is only created when at least one text file exists.

    NOTE(review): upsert_text_vectors_batched flushes on every call, i.e.
    once per chunk here, which may be slow for long documents.
    """
    txt_keys = list_text_keys(conversation_id, doc_id)
    if not txt_keys:
        return {"pages": 0, "chunks": 0}

    txt_coll = ensure_text_collection(conversation_id)
    bucket = gcs.bucket(GCS_BUCKET)
    total_chunks = 0
    for tk in txt_keys:
        try:
            text = bucket.blob(tk).download_as_bytes().decode("utf-8", errors="ignore")
        except Exception as e:
            print(f"[WARN] failed to read {tk}: {e}")
            continue

        # Store the IMAGE key so retrieval can fetch the page PNG directly.
        img_key = page_image_key_for_text_key(tk)
        for ci, ch in enumerate(chunk_text(text)):
            vecs = embed_text_to_vectors(ch)  # [[DIM floats]]
            upsert_text_vectors_batched(txt_coll, doc_id, img_key, vecs, chunk_idx=ci)
            total_chunks += 1

    return {"pages": len(txt_keys), "chunks": total_chunks}

# ===== Query encoder (ColNomic text -> (S_q, 128)) =====
@torch.no_grad()
def encode_query_text(text: str) -> np.ndarray:
    """
    Encode a query into its late-interaction token vectors, shape (S_q, D).

    Uses the processor's ``process_queries`` when available, else its generic
    text call.  Conversion goes through ``_to_numpy_2d`` (float32 on CPU), so
    a model output without a tensor raises the same ValueError as the image
    path instead of an AttributeError on None.
    """
    dev = next(embed_model.parameters()).device
    if hasattr(embed_proc, "process_queries"):
        batch = embed_proc.process_queries([text]).to(dev)
    else:
        batch = embed_proc(text=[text]).to(dev)
    out = embed_model(**batch)  # (1, S_q, 128) typically
    vecs, _dim = _to_numpy_2d(out)
    return vecs  # (S_q, D)

# ===== ANN search and late-interaction aggregation (images) =====
def search_per_token(collection: str, qvec: np.ndarray, topk: int):
    """
    ANN-search one query sub-vector against a page-vector collection.

    Returns up to ``topk`` hits as dicts with ``score`` (inner product),
    ``doc_id``, ``page`` and ``gcs_key``, in the order Milvus returns them.

    HNSW requires the search-time ``ef`` to be at least the result limit, so
    ``ef`` is raised to ``topk`` when ``topk`` (PER_SUBVEC) exceeds HNSW_EF,
    instead of letting Milvus reject the request.
    """
    ef = max(HNSW_EF, topk)
    res = milvus.search(
        collection_name=collection,
        data=[qvec.tolist()],
        limit=topk,
        output_fields=["doc_id", "page", "seq_id", "gcs_key"],
        search_params={"metric_type": "IP", "params": {"ef": ef}},
    )
    hits = res[0] if res else []
    return [
        {
            "score": float(h["distance"]),
            "doc_id": h["entity"]["doc_id"],
            "page": int(h["entity"]["page"]),
            "gcs_key": h["entity"]["gcs_key"],
        }
        for h in hits
    ]

def late_interaction_aggregate(query_vecs: np.ndarray, collection: str) -> list:
    """
    Rank pages with an ANN approximation of late-interaction (MaxSim) scoring.

    For each of the first MAX_QTOK query sub-vectors the top PER_SUBVEC hits
    are fetched.  Each page keeps only its best hit for that sub-vector, and
    those per-sub-vector maxima are summed over all sub-vectors.  A page
    missing from a sub-vector's hits contributes nothing for it.

    Returns ``[(score, {"doc_id", "page", "gcs_key"}), ...]`` sorted by score
    descending.  ``gcs_key`` comes from the highest-scoring hit seen for that
    page.
    """
    totals: Dict[Tuple[str, int], float] = defaultdict(float)
    best_hit: Dict[Tuple[str, int], Tuple[float, str]] = {}

    for token_vec in query_vecs[:MAX_QTOK]:
        # best (score, gcs_key) per page for this single query sub-vector
        token_max: Dict[Tuple[str, int], Tuple[float, str]] = {}
        for hit in search_per_token(collection, token_vec, topk=PER_SUBVEC):
            page_id = (hit["doc_id"], hit["page"])
            if hit["score"] > token_max.get(page_id, (-1e9, ""))[0]:
                token_max[page_id] = (hit["score"], hit["gcs_key"])
        for page_id, (score, key) in token_max.items():
            totals[page_id] += score
            if score > best_hit.get(page_id, (-1e9, ""))[0]:
                best_hit[page_id] = (score, key)

    ranked = [
        (
            score,
            {
                "doc_id": doc,
                "page": pg,
                "gcs_key": best_hit.get((doc, pg), (-1e9, None))[1],
            },
        )
        for (doc, pg), score in totals.items()
    ]
    ranked.sort(key=lambda entry: entry[0], reverse=True)
    return ranked

# ===== TEXT ANN search + aggregation to pages =====
def search_text(collection: str, qvec_1d: np.ndarray, limit: int = 200):
    """
    ANN-search one pooled query vector against a TEXT (OCR chunk) collection.

    Returns up to ``limit`` hits as dicts with ``score`` (inner product),
    ``doc_id``, ``page`` and ``gcs_key`` (the IMAGE page key stored at
    ingest).  ``seq_id`` is requested from Milvus but not returned.

    HNSW requires the search-time ``ef`` to be at least the result limit, so
    ``ef`` is raised to ``limit`` when ``limit`` exceeds HNSW_EF.
    """
    ef = max(HNSW_EF, limit)
    res = milvus.search(
        collection_name=collection,
        data=[qvec_1d.tolist()],
        limit=limit,
        output_fields=["doc_id", "page", "gcs_key", "seq_id"],
        search_params={"metric_type": "IP", "params": {"ef": ef}},
    )
    hits = res[0] if res else []
    return [
        {
            "score": float(h["distance"]),
            "doc_id": h["entity"]["doc_id"],
            "page": int(h["entity"]["page"]),
            "gcs_key": h["entity"]["gcs_key"],
        }
        for h in hits
    ]

def aggregate_text_to_pages(hits: List[dict]) -> List[dict]:
    """
    Collapse text-chunk hits into per-page scores.

    Scores of hits sharing ``(doc_id, page, gcs_key)`` are summed.  Returns
    ``[{"key": (doc_id, page, gcs_key), "score": total}, ...]`` sorted by
    total descending; ties keep first-seen order.
    """
    totals = defaultdict(float)
    for hit in hits:
        totals[(hit["doc_id"], hit["page"], hit["gcs_key"])] += hit["score"]
    by_score = sorted(totals.items(), key=lambda kv: kv[1], reverse=True)
    return [{"key": key, "score": total} for key, total in by_score]

# ===== Reciprocal Rank Fusion =====
def rrf(runs: List[List[tuple]], k: int = 60, c: int = 60) -> List[tuple]:
    """
    Reciprocal Rank Fusion over several ranked lists.

    runs: ranked lists of hashable keys (e.g. ``(doc_id, page, gcs_key)``),
        best first.
    k: how many fused results to return.  This is a cut-off, not the RRF
        constant.
    c: the RRF smoothing constant; an item at 1-based rank r in a run adds
        ``1 / (c + r)`` to its fused score.

    Returns ``[(key, fused_score), ...]`` sorted by fused score descending
    (ties keep first-seen order), truncated to ``k`` entries.
    """
    fused_scores = defaultdict(float)
    for run in runs:
        for rank, key in enumerate(run, 1):
            fused_scores[key] += 1.0 / (c + rank)
    ordered = sorted(fused_scores.items(), key=lambda item: item[1], reverse=True)
    return ordered[:k]


# ===== Postgres: fetch short chat history =====
def _pg_conn():
    """Open a new psycopg2 connection to PG_CONN_STR (10 s connect timeout).

    Raises RuntimeError when PG_CONN_STR is empty.  The caller owns the
    returned connection and is responsible for closing it.
    """
    if PG_CONN_STR:
        return psycopg2.connect(PG_CONN_STR, connect_timeout=10)
    raise RuntimeError("PG_CONN_STR not set in env")

def fetch_chat_history(conv_id: str, limit: int = HISTORY_TURNS) -> list[dict]:
    """
    Returns a list of dicts [{sender, content, created_at}] in chronological order.

    Fetches the newest ``limit`` messages of the conversation and returns
    them oldest first; ``sender`` is stripped and a NULL ``content`` becomes "".
    An empty ``conv_id`` returns [] without touching the database.
    Table schema assumed:
      chat_messages(id, conversation_id, sender, content, metadata, created_at)
    sender is either 'human' or 'AI'.

    psycopg2's ``with connection`` block only ends the transaction and does
    not close the connection, so the connection is closed explicitly here.
    Previously one connection leaked per request.
    """
    if not conv_id:
        return []
    q = """
        SELECT sender, content, created_at
        FROM chat_messages
        WHERE conversation_id = %s
        ORDER BY created_at DESC
        LIMIT %s
    """
    conn = _pg_conn()
    try:
        with conn:  # transaction scope only — does NOT close conn
            with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
                cur.execute(q, (conv_id, limit))
                rows = cur.fetchall() or []
    finally:
        conn.close()
    # SQL returns newest first; the prompt wants chronological order.
    rows = list(reversed(rows))
    return [
        {
            "sender": (r["sender"] or "").strip(),
            "content": r.get("content") or "",
            "created_at": r.get("created_at"),
        }
        for r in rows
    ]

def image_only_retrieval(conv_id: str, query: str, top_pages: int):
    """
    Rank pages for ``query`` using ONLY the image-vector collection.

    The query is encoded into ColNomic token vectors and scored against the
    conversation's image collection (created if it does not exist yet) via
    ``late_interaction_aggregate``.  OCR text plays no part in the ranking,
    though callers can still fetch it for the selected pages.

    Returns ``(gcs_keys, debug)``.  ``gcs_keys`` holds the IMAGE keys of the
    first ``top_pages`` distinct ``(doc_id, page, gcs_key)`` entries, in rank
    order.  ``debug`` carries ``mode``, the top ``img_ranked`` entries, and
    ``fused`` / ``txt_ranked`` placeholders kept for parity with a fused
    result.
    """
    query_vecs = encode_query_text(query)  # (S, 128)
    ranked = late_interaction_aggregate(query_vecs, ensure_collection(conv_id))

    # dict.fromkeys de-duplicates while preserving first-seen (rank) order.
    unique_pages = list(dict.fromkeys(
        (meta["doc_id"], meta["page"], meta["gcs_key"]) for _, meta in ranked
    ))
    selected = unique_pages[:top_pages]
    gcs_keys = [page_key for _, _, page_key in selected]

    debug = {
        "mode": "image_only",
        "img_ranked": ranked[:top_pages],
        "fused": [(entry, None) for entry in selected],  # not actually fused
        "txt_ranked": [],                                 # text ANN not used
    }
    return gcs_keys, debug

# ===== Gemini chat model (LangChain) =====
# Requires env GOOGLE_API_KEY
# NOTE(review): the model is hard-coded to "gemini-2.0-flash", while this
# cell's header and the setup cell's GEMINI_MODEL env var ("gemini-2.5-pro")
# name Gemini 2.5 Pro — GEMINI_MODEL is not read here; confirm which is intended.
# NOTE(review): max_output_tokens=1028 (possibly a typo for 1024) caps answer
# length, while build_messages' system prompt demands very detailed answers;
# long replies may be cut off.
gemini_llm = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash",
    temperature=0,  # least random decoding, for evidence-bound answers
    max_output_tokens=1028,
)

def _pil_to_data_url(img: Image.Image) -> str:
    """Encode a PIL image as a ``data:image/png;base64,...`` URL for Gemini vision inputs.

    The image is always re-encoded as PNG in memory, whatever its source format.
    """
    with io.BytesIO() as png_buffer:
        img.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue()).decode("utf-8")
    return "data:image/png;base64," + encoded

def build_messages(state: dict) -> list:
    """
    Compose a Gemini multi-modal prompt with:
    - strict system instructions
    - short structured chat history (text only)
    - current user turn = question + OCR text + retrieved page images

    ``state`` keys read: ``question`` (required), ``history``, ``images`` and
    ``texts`` (optional).  ``texts[i]`` is emitted just before ``images[i]``,
    so the two lists must be index-aligned; texts without a matching image
    are ignored.  History senders other than 'human' (case-insensitive)
    become AIMessage; empty history turns are skipped.
    """
    system_inst = (
        "You are an expert assistant for analyzing insurance policy documents from page images.\n"
        "\n"
        "Core rules:\n"
        "1) Evidence-bound: Use ONLY the provided page images and their OCR text.\n"
        "   • The OCR text is an automated transcript of the images — use it as assistance to read and verify text.\n"
        "   • If there is any difference between the OCR text and the image, the OCR text takes precedence.\n"
        "2) If the answer is not visible, say you cannot find it. Never guess.\n"
        "3) Output style: provide a VERY DETAILED, comprehensive response in clear prose and bullet lists.\n"
        "   • Do NOT be brief. Avoid short summaries.\n"
        "4) Quote numbers/identifiers exactly as printed (benefit limits, premiums, deductibles, co-insurance %, policy/endorsement numbers,\n"
        "   dates, plan codes). Preserve units/currency and capitalization as shown.\n"
        "5) Robust terminology: If the requested term is not found verbatim, search for close synonyms or equivalents on the images (e.g.,\n"
        "   'policy no', 'policy number', 'policy #', 'plan code', 'certificate no', 'endorsement', 'benefit limit', 'co-insurance', 'coinsurance').\n"
        "   Return the closest explicit match from the images. If still unclear, state that explicitly.\n"
        "6) Multiple candidates: If several plausible values appear, report each with its nearby label/heading and explain the ambiguity instead of choosing arbitrarily.\n"
        "7) Page awareness: When helpful, reference where you found information (e.g., 'Page 2 – Schedule of Benefits – Room & Board').\n"
        "8) Conversation history: You may use prior conversation ONLY to understand intent. Do NOT treat prior messages as evidence.\n"
        "9) Fidelity/uncertainty: If characters are ambiguous (e.g., 0 vs O, 1 vs I), cross-check nearby labels or repeated occurrences on the provided pages.\n"
        "   If uncertainty remains, say so (e.g., 'appears to be 75105, but could be 7S105').\n"
        "10) Coverage depth: Go deep. When answering about a benefit/plan/endorsement, include relevant conditions, limits, sub-limits, co-insurance,\n"
        "    deductibles, waiting periods, exclusions, and any rider/endorsement interactions that are visible on the pages.\n"
        "11) Formatting:\n"
        "    • Use short headings and bullets to organize a long, detailed answer.\n"
        "    • Keep labels from the document when they aid clarity (e.g., 'Co-insurance by Insured Member: 20%').\n"
        "    • Ensure **clear spacing** between sections and details:\n"
        "        - Always add a blank line between different headings.\n"
        "        - Always add a blank line between bullet points.\n"
        "        - Always add a blank line between paragraphs.\n"
        "    • Avoid tables unless explicitly requested; expand details in prose/bullets instead.\n"
        "12) Scope control: If the question spans multiple plans or sections, provide a thorough breakdown per plan/section with all visible figures and constraints.\n"
        "13) No preambles or system scaffolding; answer directly with detailed content.\n"
        "\n"
        "Task-specific guidance:\n"
        "• Numbers/IDs: Return exact strings as printed (including hyphens/spaces). If the user asks for a number (e.g., an endorsement or policy number),\n"
        "  provide the exact value and indicate where it was read.\n"
        "• Definitions: When terms are defined on the page (e.g., 'Any One Disability', 'Deductible'), include those definitions verbatim or closely paraphrased.\n"
        "• Comparisons: When comparing plans/benefits, include all relevant differences (limits, coinsurance, eligibility, notes) in detailed prose; do not condense.\n"
        "• Missing/Not covered: If the document states a benefit is excluded or not applicable, say that clearly and point to the wording.\n"
        "• Dates/periods: Preserve visible formats (e.g., DD/MM/YYYY vs MM/DD/YYYY) exactly.\n"
        "\n"
        "If a direct answer is unavailable on the provided images, say: 'I can't find this in the provided pages.' Optionally suggest likely sections to check\n"
        "(e.g., 'Schedule of Benefits', 'Endorsement', 'Policy Schedule', 'General Provisions').\n"
    )

    msgs = [SystemMessage(content=system_inst)]

    # Add short history as alternating human/assistant text turns
    for turn in state.get("history", []):
        sender = (turn.get("sender") or "").strip().lower()
        txt = (turn.get("content") or "").strip()
        if not txt:
            continue
        if sender == "human":
            msgs.append(HumanMessage(content=txt))
        else:
            msgs.append(AIMessage(content=txt))

    # Current user: question + OCR text + images as data URLs
    q    = state["question"]
    imgs = state.get("images", [])
    txts = state.get("texts", [])

    parts = [{"type": "text", "text": q}]
    for i, im in enumerate(imgs):
        # "OCR Page N" is the 1-based retrieval position, not the document page number.
        if i < len(txts) and txts[i]:
            parts.append({"type": "text", "text": f"OCR Page {i+1}:\n{txts[i]}"})
        parts.append({"type": "image_url", "image_url": _pil_to_data_url(im)})

    msgs.append(HumanMessage(content=parts))
    return msgs

# ===== Retrieval step for RunnableMap (now hybrid & sends OCR text too) =====
def retrieval_step(state: dict) -> dict:
    """
    Fetch chat history plus the top pages (images + OCR text) for a question.

    Input:  {"question": str, "conversation_id": str}
    Output: {"question", "conversation_id", "history", "images", "texts", "pages_used", "retrieval"}

    Pages are ranked by ``image_only_retrieval``; the OCR text collection is
    not searched here.  Each ranked page's image and OCR text are fetched
    together.  A page whose image cannot be downloaded is dropped entirely,
    so ``images[i]``, ``texts[i]`` and ``pages_used[i]`` always describe the
    same page.  ``build_messages`` pairs images and texts by index.
    A history lookup failure is logged and treated as empty history.
    """
    q = state["question"]
    conv_id = state["conversation_id"]

    # 1) History for the VLM (text-only)
    try:
        history_rows = fetch_chat_history(conv_id, limit=HISTORY_TURNS)
    except Exception as e:
        print(f"[WARN] history fetch failed: {e}")
        history_rows = []

    # 2) Rank pages with image embeddings only
    gcs_keys, debug = image_only_retrieval(conv_id, q, top_pages=TOP_PAGES)

    # 3) Fetch image + OCR text per page, keeping the three lists aligned.
    imgs, txts, pages_used = [], [], []
    for key in gcs_keys:
        page_imgs = download_images_for_keys([key])
        if not page_imgs:
            continue  # image failed (already logged): skip its text too
        imgs.append(page_imgs[0])
        txts.extend(download_text_for_keys([key]))  # always one entry ("" if missing)
        pages_used.append(key)

    return {
        "question": q,
        "conversation_id": conv_id,
        "history": history_rows,
        "images": imgs,
        "texts": txts,
        "retrieval": debug,          # img_ranked / txt_ranked / fused
        "pages_used": pages_used,    # IMAGE GCS keys actually sent to the VLM
    }

@traceable(name="gemini_invoke")
def call_gemini_with_messages(msgs: list):
    """Send a prepared message list to ``gemini_llm`` and return only the reply's ``content``.

    Decorated with LangSmith ``traceable``, so when tracing is enabled each
    call is recorded as a "gemini_invoke" run.
    """
    reply = gemini_llm.invoke(msgs)
    return reply.content

# ===== RunnableMap pipeline =====
# Input:  {"question": str, "conversation_id": str}
# Output: {"answer": <Gemini reply content>, "pages_used": [...], "retrieval": {...}}
# Flow: keep the two input fields -> retrieval_step (history + page images/OCR)
#       -> build the multi-modal prompt and call Gemini, passing retrieval
#          metadata through alongside the answer.
pipeline = (
    RunnableMap({
        "question":        lambda inp: inp["question"],
        "conversation_id": lambda inp: inp["conversation_id"],
    })
    | retrieval_step  # plain function; LCEL's `|` wraps it as a runnable
    | RunnableMap({
        "answer":     lambda s: call_gemini_with_messages(build_messages(s)),
        "pages_used": lambda s: s["pages_used"],
        "retrieval":  lambda s: s["retrieval"],
    })
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


adapter_config.json:   0%|          | 0.00/805 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 7 files:   0%|          | 0/7 [00:00<?, ?it/s]

model-00006-of-00007.safetensors:   0%|          | 0.00/4.93G [00:00<?, ?B/s]

model-00004-of-00007.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00007-of-00007.safetensors:   0%|          | 0.00/3.39G [00:00<?, ?B/s]

model-00001-of-00007.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00003-of-00007.safetensors:   0%|          | 0.00/4.93G [00:00<?, ?B/s]

model-00002-of-00007.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00005-of-00007.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

Some weights of the model checkpoint at nomic-ai/colqwen2.5-7B-base were not used when initializing ColQwen2_5: ['lm_head.weight']
- This IS expected if you are initializing ColQwen2_5 from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ColQwen2_5 from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


adapter_model.safetensors:   0%|          | 0.00/323M [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/574 [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/605 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/613 [00:00<?, ?B/s]

You have video processor config saved in `preprocessor.json` file which is deprecated. Video processor configs should be saved in their own `video_preprocessor.json` file. You can rename the file or load and save the processor back which renames it automatically. Loading from `preprocessor.json` will be removed in v5.0.


processor_config.json:   0%|          | 0.00/81.0 [00:00<?, ?B/s]

chat_template.json: 0.00B [00:00, ?B/s]

In [None]:
# =======================
# MercerChat GPU Backend (Cell 2/2)
# FastAPI server with hybrid ingest:
# - If doc_id present → ingest (embed image + embed OCR text from GCS) THEN answer
# - Else → answer-only
# Logs conversation to Postgres AFTER answering.
# =======================

import os, io, time, threading, gc, json, re
import uvicorn
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse
from PIL import Image  # used during ingest
from pyngrok import ngrok

# --- Postgres / psycopg2 ---
import psycopg2
import psycopg2.extras

PG_CONN_STR     = os.environ.get("PG_CONN_STR", "")
# No hard-coded fallback secret: a default token written in the notebook lets
# anyone who has seen this code call the API.  When GPU_BEARER_TOKEN is unset
# OR empty (the setup cell may export ""), use a random unguessable token
# instead, so authenticated endpoints reject every request until configured.
import secrets  # local to this block: only needed for the fail-closed fallback
BEARER_TOKEN    = os.environ.get("GPU_BEARER_TOKEN", "") or secrets.token_urlsafe(32)
if not os.environ.get("GPU_BEARER_TOKEN"):
    print("[WARN] GPU_BEARER_TOKEN is not set; using a random token, so authenticated requests will be rejected")
NGROK_AUTHTOKEN = os.environ.get("NGROK_AUTHTOKEN", "")

def pg_connect():
    """Open a new psycopg2 connection to PG_CONN_STR (10 s connect timeout).

    Raises RuntimeError when PG_CONN_STR is empty.  The caller must close
    the returned connection.
    """
    if PG_CONN_STR:
        return psycopg2.connect(PG_CONN_STR, connect_timeout=10)
    raise RuntimeError("PG_CONN_STR not set in env")

def db_insert_chat(conversation_id: str, sender: str, content: str, metadata: dict | None):
    """
    Best-effort insert of one chat message into ``chat_messages``.

    ``sender`` must be exactly 'human' or 'AI'.  ``metadata`` is stored as
    JSON, or SQL NULL when None.  Any failure, including a missing
    PG_CONN_STR, is printed and swallowed, so callers are never interrupted
    by logging problems.

    psycopg2's ``with connection`` block commits (or rolls back) but does
    not close the connection, so it is closed explicitly.  Previously one
    connection leaked per logged message.
    """
    sql = """
        INSERT INTO chat_messages (conversation_id, sender, content, metadata)
        VALUES (%s, %s, %s, %s)
    """
    try:
        conn = pg_connect()
        try:
            with conn:  # commits on success, rolls back on error
                with conn.cursor() as cur:
                    cur.execute(
                        sql,
                        (
                            conversation_id,
                            sender,  # exactly 'human' or 'AI'
                            content,
                            psycopg2.extras.Json(metadata) if metadata is not None else None,
                        ),
                    )
        finally:
            conn.close()
    except Exception as e:
        print(f"[DB][WARN] insert failed: {e}")

def maybe_free_gpu():
    """Best-effort memory cleanup after a request.

    Empties the CUDA cache only when CUDA is available AND a truthy module-level
    ``CLEAR_CUDA_EACH`` flag is defined; the Python garbage collector always runs.
    Any error (e.g. torch not importable) is printed, never raised.
    """
    try:
        import torch as _torch
        cuda_ready = _torch.cuda.is_available()
        if cuda_ready and globals().get("CLEAR_CUDA_EACH"):
            _torch.cuda.empty_cache()
    except Exception as _e:
        print(f"[WARN] GPU cleanup skipped: {_e}")
    finally:
        gc.collect()

# ===== Helpers reused from Cell 1/2 =====
# gcs, GCS_BUCKET
# ensure_collection, ensure_text_collection
# list_page_keys
# embed_image_to_vectors, upsert_vectors_batched
# embed_doc_ocr_text_to_milvus   <-- call this to read OCR .txt from GCS and index into text collection
# pipeline (RunnableMap→Gemini; hybrid retrieval is defined alongside it in Cell 1/2)
# milvus

# ===== FastAPI =====
app = FastAPI(title="MercerChat GPU Backend")

def _auth_or_401(req: Request):
    """Reject the request unless it carries the shared bearer token.

    Expects an ``Authorization: Bearer <token>`` header (scheme matched
    case-insensitively) whose token equals ``BEARER_TOKEN``.

    Raises:
        HTTPException(401): header missing or not a Bearer header.
        HTTPException(503): the server has no bearer token configured. This fails
            closed, because an empty secret would otherwise match an empty presented
            token.
        HTTPException(403): a token was presented but it is wrong.
    """
    import hmac

    auth = req.headers.get("authorization") or ""
    if not auth.lower().startswith("bearer "):
        raise HTTPException(status_code=401, detail="Missing or invalid auth header")
    if not BEARER_TOKEN:
        raise HTTPException(status_code=503, detail="Server bearer token not configured")
    token = auth.split(" ", 1)[1].strip()
    # Constant-time comparison; compare bytes because compare_digest rejects non-ASCII str.
    if not hmac.compare_digest(token.encode("utf-8"), BEARER_TOKEN.encode("utf-8")):
        raise HTTPException(status_code=403, detail="Invalid Bearer token")

@app.get("/healthz")
def healthz():
    """Liveness probe: always reports ok, plus the server's current epoch time."""
    now = time.time()
    return dict(ok=True, time=now)

def _log_exchange(conversation_id, human, doc_id, ai_content, ai_meta):
    """Log one exchange to Postgres: the human turn first, then the AI turn.

    The human turn's metadata is {"mode": "ingest_and_answer", "doc_id": ...} when a
    doc_id was supplied, otherwise {"mode": "answer_only"}.
    """
    human_meta = {"mode": "ingest_and_answer" if doc_id else "answer_only"}
    if doc_id:
        human_meta["doc_id"] = doc_id
    db_insert_chat(conversation_id, "human", human, human_meta)
    db_insert_chat(conversation_id, "AI", ai_content, ai_meta)


@app.post("/ingest-and-answer")
async def ingest_and_answer(req: Request):
    """Optionally ingest a document, then answer the question.

    JSON object body:
      - conversation_id (required)
      - human (required): the user's question
      - doc_id (optional): when present, the document's page images and its OCR text
        (already stored in GCS) are embedded into this conversation's image and text
        collections before answering.

    Both turns (human, then AI) are logged only after a response has been formed.

    Returns JSON with "ai", "conversation_id", "pages_used" and, when applicable,
    "ingested" stats and the collection names used. Ingest/answer failures return
    status 500 with an "(error) ..." message in "ai".

    Raises:
        HTTPException: from `_auth_or_401` on auth failure; 415 for a non-JSON
            content type; 400 for a malformed/non-object body or missing fields.

    NOTE(review): the ingest/answer work is synchronous (GCS downloads, embedding,
    pipeline.invoke, DB writes) and runs directly in this async handler, so it
    blocks the event loop (including /healthz) until it finishes.
    """
    _auth_or_401(req)

    # Require JSON so the frontend knows this is the ingest+answer endpoint.
    # Media types are case-insensitive, so compare in lower case.
    ct = req.headers.get("content-type", "")
    if not ct.lower().startswith("application/json"):
        raise HTTPException(status_code=415, detail="Use application/json for /ingest-and-answer")

    try:
        payload = await req.json()
    except Exception:
        raise HTTPException(status_code=400, detail="Malformed JSON body") from None
    # Valid JSON that is not an object (list, string, number) has no .get().
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="JSON body must be an object")

    conversation_id = payload.get("conversation_id")
    human           = payload.get("human")
    doc_id          = payload.get("doc_id")  # optional

    if not conversation_id or not human:
        raise HTTPException(status_code=400, detail="conversation_id and human are required")

    ingested = None
    collection_used = None
    text_collection_used = None

    # NOTE: nothing is logged until a response (success or error) has been formed.

    # ===== INGEST (only when a doc_id was supplied) =====
    if doc_id:
        try:
            # Ensure both per-conversation collections exist.
            collection_used      = ensure_collection(conversation_id)
            text_collection_used = ensure_text_collection(conversation_id)

            # Page PNG keys for this document in GCS.
            keys = list_page_keys(conversation_id, doc_id)
            if not keys:
                ai_msg = f"(gcs) No pages found for doc_id={doc_id} in conv={conversation_id}."
                _log_exchange(
                    conversation_id, human, doc_id, ai_msg,
                    {"pages_used": [], "doc_id": doc_id,
                     "collection_used": collection_used,
                     "text_collection_used": text_collection_used},
                )
                return JSONResponse({
                    "ai": ai_msg,
                    "conversation_id": conversation_id,
                    "pages_used": [],
                    "collection_used": collection_used,
                    "text_collection_used": text_collection_used
                })

            # --- Embed page IMAGES into the image collection (per-page best effort) ---
            total_img_vecs = 0
            failed_pages = []
            bucket = gcs.bucket(GCS_BUCKET)
            for k in keys:
                try:
                    data = bucket.blob(k).download_as_bytes()
                    img = Image.open(io.BytesIO(data)).convert("RGB")

                    img_vectors, _ = embed_image_to_vectors(img)
                    total_img_vecs += upsert_vectors_batched(
                        collection_used, doc_id, k, img_vectors, batch_size=4096
                    )
                except Exception as e:
                    failed_pages.append(k)
                    print(f"[INGEST][WARN] Failed page {k}: {e}")

            # --- Embed OCR TEXT (already saved by the first backend) into the text collection ---
            # Reads text/<conv>/<doc>/page_X.txt, chunks → embeds → upserts with gcs_key set
            # to the IMAGE page key.
            try:
                text_embed_stats = embed_doc_ocr_text_to_milvus(conversation_id, doc_id)
            except Exception as e:
                print(f"[INGEST][WARN] text embedding failed for doc {doc_id}: {e}")
                text_embed_stats = {"pages": 0, "chunks": 0}
            if not isinstance(text_embed_stats, dict):
                # NOTE(review): the helper's return shape isn't visible here; don't let an
                # unexpected value (e.g. None) fail an otherwise successful ingest.
                text_embed_stats = {}

            # Flush so the new vectors are visible to the search below (best effort).
            for coll in (collection_used, text_collection_used):
                if not coll:
                    continue
                try:
                    milvus.flush(coll)
                except Exception as e:
                    print(f"[INGEST][WARN] flush failed for {coll}: {e}")

            ingested = {
                "image_vectors": total_img_vecs,
                "text_pages": text_embed_stats.get("pages", 0),
                "text_chunks": text_embed_stats.get("chunks", 0),
                "pages": len(keys),
                "failed_pages": len(failed_pages),
            }

        except Exception as e:
            maybe_free_gpu()
            err = f"(error) ingest failed: {type(e).__name__}: {str(e)}"
            _log_exchange(
                conversation_id, human, doc_id, err,
                {"pages_used": [], "doc_id": doc_id,
                 "collection_used": collection_used,
                 "text_collection_used": text_collection_used},
            )
            return JSONResponse({"ai": err, "conversation_id": conversation_id}, status_code=500)

    # ===== ANSWER (always run, whether or not we ingested) =====
    try:
        result = pipeline.invoke({
            "question": human,
            "conversation_id": conversation_id,
            "doc_id": doc_id or ""
        })

        if isinstance(result, dict):
            answer     = result.get("answer", "")
            pages_used = result.get("pages_used", []) or []
            retrieval  = result.get("retrieval", None)
        else:
            answer, pages_used, retrieval = str(result), [], None

        ai_meta = {"pages_used": pages_used}
        if ingested is not None:
            ai_meta["ingested"] = ingested
        if collection_used:
            ai_meta["collection_used"] = collection_used
        if text_collection_used:
            ai_meta["text_collection_used"] = text_collection_used
        if doc_id:
            ai_meta["doc_id"] = doc_id
        if retrieval is not None:
            ai_meta["retrieval"] = retrieval

        _log_exchange(conversation_id, human, doc_id, answer, ai_meta)

        # Respond WITHOUT signed URLs (the first backend enriches them).
        resp = {
            "ai": answer,
            "conversation_id": conversation_id,
            "pages_used": pages_used,
        }
        if ingested is not None:
            resp["ingested"] = ingested
        if collection_used:
            resp["collection_used"] = collection_used
        if text_collection_used:
            resp["text_collection_used"] = text_collection_used

        return JSONResponse(resp)

    except Exception as e:
        err = f"(error) answer failed: {type(e).__name__}: {str(e)}"
        _log_exchange(
            conversation_id, human, doc_id, err,
            {"pages_used": [], "doc_id": doc_id,
             "collection_used": collection_used,
             "text_collection_used": text_collection_used},
        )
        return JSONResponse({"ai": err, "conversation_id": conversation_id}, status_code=500)
    finally:
        maybe_free_gpu()

# --- Run server (for Colab/manual start) ---
# Re-running a server cell in the same kernel must replace the previous server. A bare
# uvicorn.run() in a new thread would fail to bind the port while the OLD `app` (with
# the old endpoint code) keeps answering through the tunnel.
_SERVER_PORT = 8001

_prev_server = globals().get("_uvicorn_server")
if _prev_server is not None:
    _prev_server.should_exit = True  # ask the previous uvicorn loop to shut down
    _prev_thread = globals().get("server_thread")
    if _prev_thread is not None:
        _prev_thread.join(timeout=10)

_uvicorn_server = uvicorn.Server(
    uvicorn.Config(app, host="0.0.0.0", port=_SERVER_PORT, log_level="info")
)

def _run(server=_uvicorn_server):
    """Thread target: serve `app` until `server.should_exit` is set."""
    server.run()

server_thread = threading.Thread(target=_run, daemon=True)
server_thread.start()

# Wait briefly for uvicorn to bind before exposing it, and warn if it never came up
# (e.g. the port is still held by a server started some other way).
for _ in range(50):
    if _uvicorn_server.started or not server_thread.is_alive():
        break
    time.sleep(0.2)
if not _uvicorn_server.started:
    print(f"[WARN] uvicorn did not report startup on port {_SERVER_PORT}; "
          "an older server may still hold it (restart the runtime if so).")

if NGROK_AUTHTOKEN:
    try:
        ngrok.set_auth_token(NGROK_AUTHTOKEN)
    except Exception as e:
        print(f"[WARN] ngrok auth failed: {e}")

# Close the tunnel from a previous run so only one public URL points at this server.
_prev_url = globals().get("public_url")
if _prev_url:
    try:
        ngrok.disconnect(_prev_url)
    except Exception as e:
        print(f"[WARN] could not close previous ngrok tunnel {_prev_url}: {e}")

public_url = ngrok.connect(_SERVER_PORT, "http").public_url
print("Public URL:", public_url)

Downloading ngrok ...

INFO:     Started server process [517]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8001 (Press CTRL+C to quit)


Public URL: https://2fa68a2e1dc2.ngrok-free.app


In [None]:
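# =======================
# MercerChat GPU Backend (Cell 2/2) — IMAGE-ONLY INGEST
# Alternative to the hybrid-ingest cell above: it redefines the same `app` and serves on
# the same port (8001), so run one or the other in a given runtime, not both.
# - If doc_id present → ingest (embed image pages only) THEN answer
# - Else → answer-only
# Logs conversation to Postgres AFTER answering.
# =======================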

import os, io, time, threading, gc, json
import uvicorn
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse
from PIL import Image  # used during ingest
from pyngrok import ngrok

# --- Postgres / psycopg2 ---
import psycopg2
import psycopg2.extras

PG_CONN_STR     = os.environ.get("PG_CONN_STR", "")
# Shared secret expected in "Authorization: Bearer <token>". There is deliberately no
# hard-coded fallback: a default baked into the notebook is effectively public.
# Stripped because the token parsed from the request header is stripped too, so stray
# whitespace in the env value would otherwise make it impossible to match.
BEARER_TOKEN    = os.environ.get("GPU_BEARER_TOKEN", "").strip()
NGROK_AUTHTOKEN = os.environ.get("NGROK_AUTHTOKEN", "")

if not BEARER_TOKEN:
    print("[WARN] GPU_BEARER_TOKEN is empty; set it to a long random secret before serving.")

def pg_connect():
    """Open a new psycopg2 connection from ``PG_CONN_STR`` (10 s connect timeout).

    Raises:
        RuntimeError: when ``PG_CONN_STR`` is empty or unset.
    """
    conn_str = PG_CONN_STR
    if conn_str:
        return psycopg2.connect(conn_str, connect_timeout=10)
    raise RuntimeError("PG_CONN_STR not set in env")

def db_insert_chat(conversation_id: str, sender: str, content: str, metadata: dict | None):
    """Insert one row into ``chat_messages`` (best effort).

    Args:
        conversation_id: conversation the message belongs to.
        sender: exactly 'human' or 'AI'.
        content: message text.
        metadata: optional dict stored through psycopg2's ``Json`` adapter; ``None``
            stores SQL NULL. Values that are not JSON-serializable are stored as their
            ``str()`` instead of failing the whole insert.

    Any failure (connect, SQL, serialization, close) is printed as a warning and
    swallowed, so chat logging never breaks the request that triggered it.
    """
    from contextlib import closing

    sql = """
        INSERT INTO chat_messages (conversation_id, sender, content, metadata)
        VALUES (%s, %s, %s, %s)
    """
    try:
        meta_param = (
            psycopg2.extras.Json(metadata, dumps=lambda obj: json.dumps(obj, default=str))
            if metadata is not None else None
        )
        # psycopg2's `with conn:` only commits (or rolls back on error); it does NOT
        # close the connection. closing() guarantees the connection is released, which
        # the previous version leaked on every call.
        with closing(pg_connect()) as conn, conn, conn.cursor() as cur:
            cur.execute(sql, (conversation_id, sender, content, meta_param))
    except Exception as e:
        print(f"[DB][WARN] insert failed: {e}")

def maybe_free_gpu():
    """Best-effort memory cleanup after a request.

    Empties the CUDA cache only when CUDA is available AND a truthy module-level
    ``CLEAR_CUDA_EACH`` flag is defined; the Python garbage collector always runs.
    Any error (e.g. torch not importable) is printed, never raised.
    """
    try:
        import torch as _torch
        cuda_ready = _torch.cuda.is_available()
        if cuda_ready and globals().get("CLEAR_CUDA_EACH"):
            _torch.cuda.empty_cache()
    except Exception as _e:
        print(f"[WARN] GPU cleanup skipped: {_e}")
    finally:
        gc.collect()

# ===== Helpers reused from Cell 1/2 =====
# gcs, GCS_BUCKET
# ensure_collection
# list_page_keys
# embed_image_to_vectors, upsert_vectors_batched
# pipeline (RunnableMap→Gemini; retrieval lives in Cell 1 and will also load OCR text)
# milvus

# ===== FastAPI =====
app = FastAPI(title="MercerChat GPU Backend")

def _auth_or_401(req: Request):
    """Reject the request unless it carries the shared bearer token.

    Expects an ``Authorization: Bearer <token>`` header (scheme matched
    case-insensitively) whose token equals ``BEARER_TOKEN``.

    Raises:
        HTTPException(401): header missing or not a Bearer header.
        HTTPException(503): the server has no bearer token configured. This fails
            closed, because an empty secret would otherwise match an empty presented
            token.
        HTTPException(403): a token was presented but it is wrong.
    """
    import hmac

    auth = req.headers.get("authorization") or ""
    if not auth.lower().startswith("bearer "):
        raise HTTPException(status_code=401, detail="Missing or invalid auth header")
    if not BEARER_TOKEN:
        raise HTTPException(status_code=503, detail="Server bearer token not configured")
    token = auth.split(" ", 1)[1].strip()
    # Constant-time comparison; compare bytes because compare_digest rejects non-ASCII str.
    if not hmac.compare_digest(token.encode("utf-8"), BEARER_TOKEN.encode("utf-8")):
        raise HTTPException(status_code=403, detail="Invalid Bearer token")

@app.get("/healthz")
def healthz():
    """Liveness probe: always reports ok, plus the server's current epoch time."""
    now = time.time()
    return dict(ok=True, time=now)

def _log_exchange(conversation_id, human, doc_id, ai_content, ai_meta):
    """Log one exchange to Postgres: the human turn first, then the AI turn.

    The human turn's metadata is {"mode": "ingest_and_answer", "doc_id": ...} when a
    doc_id was supplied, otherwise {"mode": "answer_only"}.
    """
    human_meta = {"mode": "ingest_and_answer" if doc_id else "answer_only"}
    if doc_id:
        human_meta["doc_id"] = doc_id
    db_insert_chat(conversation_id, "human", human, human_meta)
    db_insert_chat(conversation_id, "AI", ai_content, ai_meta)


@app.post("/ingest-and-answer")
async def ingest_and_answer(req: Request):
    """Optionally ingest a document's page images, then answer the question.

    JSON object body:
      - conversation_id (required)
      - human (required): the user's question
      - doc_id (optional): when present, the document's page images in GCS are
        embedded into this conversation's image collection before answering.

    Both turns (human, then AI) are logged only after a response has been formed.

    Returns JSON with "ai", "conversation_id", "pages_used" and, when applicable,
    "ingested" stats and "collection_used". Ingest/answer failures return status 500
    with an "(error) ..." message in "ai".

    Raises:
        HTTPException: from `_auth_or_401` on auth failure; 415 for a non-JSON
            content type; 400 for a malformed/non-object body or missing fields.

    NOTE(review): the ingest/answer work is synchronous (GCS downloads, embedding,
    pipeline.invoke, DB writes) and runs directly in this async handler, so it
    blocks the event loop (including /healthz) until it finishes.
    """
    _auth_or_401(req)

    # Require JSON so the frontend knows this is the ingest+answer endpoint.
    # Media types are case-insensitive, so compare in lower case.
    ct = req.headers.get("content-type", "")
    if not ct.lower().startswith("application/json"):
        raise HTTPException(status_code=415, detail="Use application/json for /ingest-and-answer")

    try:
        payload = await req.json()
    except Exception:
        raise HTTPException(status_code=400, detail="Malformed JSON body") from None
    # Valid JSON that is not an object (list, string, number) has no .get().
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="JSON body must be an object")

    conversation_id = payload.get("conversation_id")
    human           = payload.get("human")
    doc_id          = payload.get("doc_id")  # optional

    if not conversation_id or not human:
        raise HTTPException(status_code=400, detail="conversation_id and human are required")

    ingested = None
    collection_used = None

    # NOTE: nothing is logged until a response (success or error) has been formed.

    # ===== INGEST (images only; only when a doc_id was supplied) =====
    if doc_id:
        try:
            collection_used = ensure_collection(conversation_id)

            # Page PNG keys for this document in GCS.
            keys = list_page_keys(conversation_id, doc_id)
            if not keys:
                ai_msg = f"(gcs) No pages found for doc_id={doc_id} in conv={conversation_id}."
                _log_exchange(
                    conversation_id, human, doc_id, ai_msg,
                    {"pages_used": [], "doc_id": doc_id,
                     "collection_used": collection_used},
                )
                return JSONResponse({
                    "ai": ai_msg,
                    "conversation_id": conversation_id,
                    "pages_used": [],
                    "collection_used": collection_used
                })

            # --- Embed page IMAGES into the image collection (per-page best effort) ---
            total_img_vecs = 0
            failed_pages = []
            bucket = gcs.bucket(GCS_BUCKET)
            for k in keys:
                try:
                    data = bucket.blob(k).download_as_bytes()
                    img = Image.open(io.BytesIO(data)).convert("RGB")

                    img_vectors, _ = embed_image_to_vectors(img)
                    total_img_vecs += upsert_vectors_batched(
                        collection_used, doc_id, k, img_vectors, batch_size=4096
                    )
                except Exception as e:
                    failed_pages.append(k)
                    print(f"[INGEST][WARN] Failed page {k}: {e}")

            # Flush so the new vectors are visible to the search below (best effort).
            try:
                milvus.flush(collection_used)
            except Exception as e:
                print(f"[INGEST][WARN] flush failed for {collection_used}: {e}")

            ingested = {
                "image_vectors": total_img_vecs,
                "pages": len(keys),
                "failed_pages": len(failed_pages),
            }

        except Exception as e:
            maybe_free_gpu()
            err = f"(error) ingest failed: {type(e).__name__}: {str(e)}"
            _log_exchange(
                conversation_id, human, doc_id, err,
                {"pages_used": [], "doc_id": doc_id,
                 "collection_used": collection_used},
            )
            return JSONResponse({"ai": err, "conversation_id": conversation_id}, status_code=500)

    # ===== ANSWER (always run, whether or not we ingested) =====
    try:
        result = pipeline.invoke({
            "question": human,
            "conversation_id": conversation_id,
            "doc_id": doc_id or ""
        })

        if isinstance(result, dict):
            answer     = result.get("answer", "")
            pages_used = result.get("pages_used", []) or []
            retrieval  = result.get("retrieval", None)
        else:
            answer, pages_used, retrieval = str(result), [], None

        ai_meta = {"pages_used": pages_used}
        if ingested is not None:
            ai_meta["ingested"] = ingested
        if collection_used:
            ai_meta["collection_used"] = collection_used
        if doc_id:
            ai_meta["doc_id"] = doc_id
        if retrieval is not None:
            ai_meta["retrieval"] = retrieval

        _log_exchange(conversation_id, human, doc_id, answer, ai_meta)

        # Respond WITHOUT signed URLs (the first backend enriches them).
        resp = {
            "ai": answer,
            "conversation_id": conversation_id,
            "pages_used": pages_used,
        }
        if ingested is not None:
            resp["ingested"] = ingested
        if collection_used:
            resp["collection_used"] = collection_used

        return JSONResponse(resp)

    except Exception as e:
        err = f"(error) answer failed: {type(e).__name__}: {str(e)}"
        _log_exchange(
            conversation_id, human, doc_id, err,
            {"pages_used": [], "doc_id": doc_id,
             "collection_used": collection_used},
        )
        return JSONResponse({"ai": err, "conversation_id": conversation_id}, status_code=500)
    finally:
        maybe_free_gpu()

# --- Run server (for Colab/manual start) ---
# Re-running a server cell in the same kernel must replace the previous server. A bare
# uvicorn.run() in a new thread would fail to bind the port while the OLD `app` (with
# the old endpoint code) keeps answering through the tunnel.
_SERVER_PORT = 8001

_prev_server = globals().get("_uvicorn_server")
if _prev_server is not None:
    _prev_server.should_exit = True  # ask the previous uvicorn loop to shut down
    _prev_thread = globals().get("server_thread")
    if _prev_thread is not None:
        _prev_thread.join(timeout=10)

_uvicorn_server = uvicorn.Server(
    uvicorn.Config(app, host="0.0.0.0", port=_SERVER_PORT, log_level="info")
)

def _run(server=_uvicorn_server):
    """Thread target: serve `app` until `server.should_exit` is set."""
    server.run()

server_thread = threading.Thread(target=_run, daemon=True)
server_thread.start()

# Wait briefly for uvicorn to bind before exposing it, and warn if it never came up
# (e.g. the port is still held by a server started some other way).
for _ in range(50):
    if _uvicorn_server.started or not server_thread.is_alive():
        break
    time.sleep(0.2)
if not _uvicorn_server.started:
    print(f"[WARN] uvicorn did not report startup on port {_SERVER_PORT}; "
          "an older server may still hold it (restart the runtime if so).")

if NGROK_AUTHTOKEN:
    try:
        ngrok.set_auth_token(NGROK_AUTHTOKEN)
    except Exception as e:
        print(f"[WARN] ngrok auth failed: {e}")

# Close the tunnel from a previous run so only one public URL points at this server.
_prev_url = globals().get("public_url")
if _prev_url:
    try:
        ngrok.disconnect(_prev_url)
    except Exception as e:
        print(f"[WARN] could not close previous ngrok tunnel {_prev_url}: {e}")

public_url = ngrok.connect(_SERVER_PORT, "http").public_url
print("Public URL:", public_url)

INFO:     Started server process [2519]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8001 (Press CTRL+C to quit)


Public URL: https://bd1c616d22c0.ngrok-free.app
