# In [None]:  (Jupyter cell prompt left over from the notebook export)
"""
Multi-Modal RAG QA System (single-file)
- Ingestion: PDFs (DOCX files are discovered, but the extractors are PDF-based), images (OCR), tables, chart metadata
- Hybrid IR: TF-IDF + Word2Vec + SBERT(query) + Cross-Encoder rerank
- Vector DB: FAISS (faiss-cpu) for dense embeddings (SBERT)
- UI: Streamlit chatbot + retrieval debugging

Author: ChatGPT (adapted to your assignment)
Assignment doc path (local): /mnt/data/multi-modal_rag_qa_assignment.docx
"""

import os
import sys
import json
import math
import shutil
import tempfile
import warnings
from pathlib import Path
from typing import List, Dict, Any, Tuple

import numpy as np
import pandas as pd
from tqdm import tqdm

# LangChain loaders & splitters
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_document_loaders import UnstructuredPDFLoader  # optional
from langchain_text_splitters import RecursiveCharacterTextSplitter

# Ingestion libs
import pdfplumber
import fitz  # PyMuPDF
import pytesseract
from PIL import Image
import cv2
import camelot

# Embeddings & IR
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer, CrossEncoder
import gensim.downloader as gensim_api

# FAISS vector store
import faiss
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings

# LangChain LLM chain (for answer generation)
from langchain.llms import OpenAI  # or use a local LLM wrapper
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain import OpenAI as LCOpenAI

# UI
import streamlit as st

# -----------------------
# CONFIG
# -----------------------
# Input / output locations. PROCESSED_DIR is created at import time so every
# later save step can assume it exists.
RAW_DOC_DIR = "./data/raw"  # folder scanned for user PDFs/docs
PROCESSED_DIR = "./data/processed"
os.makedirs(PROCESSED_DIR, exist_ok=True)

# Local path of the assignment brief; only displayed in the Streamlit header.
ASSIGNMENT_DOC_LOCAL_PATH = "/mnt/data/multi-modal_rag_qa_assignment.docx"

# Artefact locations inside PROCESSED_DIR.
PROCESSED_JSON = os.path.join(PROCESSED_DIR, "processed_chunks.json")
TFIDF_MODEL_PATH = os.path.join(PROCESSED_DIR, "tfidf_vectorizer.pkl")
FAISS_INDEX_PATH = os.path.join(PROCESSED_DIR, "faiss_index")  # used as a directory
SBERT_EMB_PATH = os.path.join(PROCESSED_DIR, "sbert_doc_embeddings.npy")
WORD2VEC_NAME = "word2vec-google-news-300"  # gensim-downloader id; large, optional

# Chunking and retrieval sizes.
CHUNK_SIZE, CHUNK_OVERLAP = 600, 100
TOP_K_CANDIDATES, TOP_K_FINAL = 50, 5  # hybrid candidates -> survivors after rerank

# Weights applied to the min-max-normalised TF-IDF, Word2Vec and SBERT scores
# in the hybrid fusion (weighted sum; tune these).
W_TFIDF, W_W2V, W_SQ = 0.6, 0.2, 0.2

# LLM settings, read from the environment.
LLM_PROVIDER = os.environ.get("LLM_PROVIDER", "openai")  # or "local"
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")  # None when unset

# sentence-transformers model ids.
SBERT_MODEL_NAME = "all-mpnet-base-v2"  # dense encoder for documents and queries
CROSS_ENCODER_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2"  # pairwise reranker

# Use the GPU whenever torch can see one.
import torch
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# -----------------------
# UTIL: Save / Load helpers
# -----------------------
import pickle

def save_pickle(obj, path):
    """Serialise *obj* to the file at *path* with pickle, overwriting it if present."""
    with open(path, "wb") as sink:
        pickle.Pickler(sink).dump(obj)

def load_pickle(path):
    """Return the object pickled in the file at *path*.

    Only use on artefacts this pipeline wrote itself: unpickling untrusted data
    can execute arbitrary code.
    """
    with open(path, "rb") as source:
        restored = pickle.Unpickler(source).load()
    return restored

# -----------------------
# INGESTION: text, tables, images, OCR, chart metadata
# -----------------------
def extract_text_with_langchain(pdf_path: str):
    """Extract per-page text from *pdf_path* with LangChain's ``PyMuPDFLoader``.

    Returns:
        A list of ``{"page": int | None, "text": str}`` dicts, one per loaded
        Document. Page numbers are 1-based, so they agree with the pdfplumber
        fallback, the camelot table extractor and the image extractor used
        elsewhere in this pipeline.
    """
    loader = PyMuPDFLoader(pdf_path)
    page_texts = []
    for doc in loader.load():
        page_num = doc.metadata.get("page")
        # PyMuPDFLoader stores PyMuPDF's 0-based page index under "page";
        # shift it so citations ("page: X") match the human page number.
        if isinstance(page_num, int):
            page_num += 1
        page_texts.append({"page": page_num, "text": doc.page_content})
    return page_texts

def extract_tables_pdf(pdf_path: str):
    """Extract tables from *pdf_path* with camelot and return them as CSV text.

    camelot's "lattice" flavour (ruled tables) is tried first; "stream" is used
    when lattice raises or finds no table. Each result is
    ``{"page": int, "table_text": str}``. Tables whose page number or DataFrame
    cannot be read are skipped silently.
    """
    found = []
    for flavor in ("lattice", "stream"):
        try:
            found = camelot.read_pdf(pdf_path, pages='all', flavor=flavor)
        except Exception:
            found = []
        if len(found) > 0:
            break

    records = []
    for table in found:
        try:
            records.append({"page": int(table.page), "table_text": table.df.to_csv(index=False)})
        except Exception:
            pass
    return records

def extract_images_and_ocr(pdf_path: str, output_dir: str) -> List[Dict[str,Any]]:
    """Save every embedded image of *pdf_path* to *output_dir* and OCR it.

    Images are written as ``<pdf stem>_p<page>_img<n>.<ext>`` with 1-based page
    and image numbers. OCR uses pytesseract on an RGB copy of the image.

    Returns:
        One ``{"page": int, "image_path": str, "ocr_text": str}`` dict per saved
        image (1-based ``page``). ``ocr_text`` is ``""`` when OCR fails.

    Raises:
        Whatever ``fitz.open`` raises when *pdf_path* cannot be opened.
    """
    os.makedirs(output_dir, exist_ok=True)
    stem = Path(pdf_path).stem
    images = []
    # The context manager closes the document (and its file handle) even if a
    # page or image fails part-way through.
    with fitz.open(pdf_path) as pdf:
        for pno, page in enumerate(pdf):
            for img_index, img in enumerate(page.get_images(full=True)):
                xref = img[0]
                try:
                    base_image = pdf.extract_image(xref)
                except (RuntimeError, ValueError):
                    base_image = None
                # Some xrefs (masks, unsupported filters) yield nothing extractable.
                if not base_image or "image" not in base_image:
                    continue
                img_ext = base_image.get("ext", "png")
                img_name = f"{stem}_p{pno+1}_img{img_index+1}.{img_ext}"
                img_path = os.path.join(output_dir, img_name)
                with open(img_path, "wb") as imf:
                    imf.write(base_image["image"])
                # OCR is best-effort: unreadable images or a missing tesseract
                # binary leave the text empty instead of aborting ingestion.
                try:
                    with Image.open(img_path) as raw_img:
                        ocr_text = pytesseract.image_to_string(raw_img.convert("RGB"))
                except Exception:
                    ocr_text = ""
                images.append({"page": pno + 1, "image_path": img_path, "ocr_text": ocr_text})
    return images

def extract_chart_metadata_from_page_text(text: str, max_caption_lines: int = 50) -> Dict[str,str]:
    """Collect likely chart/figure captions and axis metadata from page text.

    Heuristics, applied to the stripped, non-empty lines of *text*:
      * among the first *max_caption_lines* lines, lines starting with
        "figure", "fig.", "chart" or "table" (case-insensitive) are captions;
      * any line mentioning "axis" (covers "x-axis"/"y-axis"), "legend",
        "units" or "scale" is treated as chart metadata.
    Each line is reported at most once, captions first, then metadata lines,
    each group in document order.

    Returns:
        ``{"captions": "<line> | <line> | ..."}``; the value is ``""`` when
        nothing matches or *text* is empty/None.
    """
    if not text:
        return {"captions": ""}
    lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
    caption_prefixes = ("figure", "fig.", "chart", "table")
    metadata_keywords = ("axis", "legend", "units", "scale")

    selected = [i for i, ln in enumerate(lines[:max_caption_lines])
                if ln.lower().startswith(caption_prefixes)]
    already = set(selected)
    # A caption line that also mentions e.g. "axis" must not be listed twice.
    selected.extend(i for i, ln in enumerate(lines)
                    if i not in already and any(k in ln.lower() for k in metadata_keywords))
    return {"captions": " | ".join(lines[i] for i in selected)}

# -----------------------
# PREPROCESS & CHUNKING
# -----------------------
def clean_text(text: str) -> str:
    """Collapse each whitespace run (spaces, tabs, newlines) into one space and trim.

    Falsy input (``None`` or ``""``) yields ``""``.
    """
    return " ".join(text.split()) if text else ""

def chunk_documents(page_records: List[Dict[str,Any]], table_records: List[Dict[str,Any]],
                    image_records: List[Dict[str,Any]], chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP):
    """
    Build unified, citable chunks from page text, tables and image OCR.

    Every record's text is whitespace-normalised with ``clean_text`` and split
    independently with a RecursiveCharacterTextSplitter; records whose cleaned
    text is empty are skipped.

    Args:
        page_records: dicts with "page", "text" and optionally "source".
        table_records: dicts with "page", "table_text" and optionally "source".
        image_records: dicts with optional "page", "ocr_text", "image_path", "source".
        chunk_size, chunk_overlap: passed to the splitter.

    Returns:
        A list of ``{"id", "type", "page", "content", "metadata"}`` dicts with
        sequential ids ``chunk_0, chunk_1, ...``. ``metadata`` always carries
        "type", "page" and "source" (so table and image chunks can be cited
        too); image chunks also carry "image_path".
    """
    splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap,
                                              separators=["\n\n", "\n", ".", " ", ""])
    # NOTE: clean_text collapses newlines, so in practice the "." / " " separators do the splitting.

    # (text, metadata) pairs, in page -> table -> image order.
    sources = []
    for rec in page_records:
        txt = clean_text(rec["text"])
        if txt:
            sources.append((txt, {"type": "page_text", "page": rec["page"], "source": rec.get("source", None)}))
    for trec in table_records:
        ttxt = clean_text(trec["table_text"])
        if ttxt:
            sources.append((ttxt, {"type": "table", "page": trec["page"], "source": trec.get("source")}))
    for irec in image_records:
        ocr = clean_text(irec.get("ocr_text", ""))
        if ocr:
            sources.append((ocr, {"type": "image_ocr", "page": irec.get("page"),
                                  "image_path": irec.get("image_path"), "source": irec.get("source")}))

    all_chunks = []
    for full_text, meta in sources:
        for piece in splitter.split_text(full_text):
            all_chunks.append({
                "id": f"chunk_{len(all_chunks)}",
                "type": meta["type"],
                "page": meta["page"],
                "content": piece,
                "metadata": dict(meta),  # per-chunk copy so chunks never share a dict
            })
    return all_chunks

# -----------------------
# EMBEDDINGS & INDEX BUILDING
# -----------------------
def build_tfidf(doc_texts: List[str]) -> Tuple[TfidfVectorizer, Any]:
    """Fit a TF-IDF model on *doc_texts* and return ``(vectorizer, matrix)``.

    Unigrams and bigrams are used, English stop words are removed and the
    vocabulary is capped at 50,000 features. ``matrix`` is the document-term
    matrix returned by ``fit_transform`` (one row per input text).
    scikit-learn raises ValueError if no terms survive (e.g. empty input).
    """
    vectorizer = TfidfVectorizer(
        ngram_range=(1, 2),
        stop_words="english",
        max_features=50000,
    )
    return vectorizer, vectorizer.fit_transform(doc_texts)

def load_word2vec_model(name=WORD2VEC_NAME):
    """Load the gensim-downloader model *name*, or return None if that fails.

    Word2Vec is an optional retrieval signal, so any failure is printed and
    degraded to ``None`` instead of aborting the pipeline.
    """
    try:
        print("Loading Word2Vec (this can take time & memory)...")
        model = gensim_api.load(name)
        print("Word2Vec loaded.")
    except Exception as err:
        print("Word2Vec not available:", err)
        return None
    return model

def get_avg_word2vec_embeddings(w2v_model, texts: List[str]) -> np.ndarray:
    """Embed each text as the mean Word2Vec vector of its in-vocabulary words.

    Words are whitespace tokens looked up lower-cased; out-of-vocabulary
    tokens are ignored. A text with no known word maps to a zero vector.

    Args:
        w2v_model: a mapping-like word-vector model supporting ``word in model``,
            ``model[word]`` and ``vector_size`` (e.g. gensim KeyedVectors), or None.
        texts: texts to embed.

    Returns:
        An array of shape ``(len(texts), dim)``. When *w2v_model* is None, all
        zeros with dim 300 (the size of the default google-news model). An
        empty *texts* gives shape ``(0, dim)`` rather than failing in ``np.vstack``.
    """
    if w2v_model is None:
        return np.zeros((len(texts), 300))
    dim = w2v_model.vector_size
    if not texts:
        return np.zeros((0, dim))
    embs = []
    for t in texts:
        # A membership test avoids raising and catching KeyError for every unknown word.
        vecs = [w2v_model[key] for key in (w.lower() for w in t.split()) if key in w2v_model]
        embs.append(np.mean(vecs, axis=0) if vecs else np.zeros(dim))
    return np.vstack(embs)

def build_sbert_and_faiss(texts: List[str], model_name=SBERT_MODEL_NAME, faiss_index_path=FAISS_INDEX_PATH):
    """Encode *texts* with SentenceTransformer and index them in a flat FAISS IP index.

    Embeddings are L2-normalised in place before indexing, so inner product
    equals cosine similarity. The index is written to
    ``<faiss_index_path>/index.faiss`` and the normalised embeddings to
    ``<faiss_index_path>/doc_embs.npy``.

    Returns:
        ``(sbert_model, faiss_index, normalised_doc_embeddings)``.
    """
    encoder = SentenceTransformer(model_name, device=DEVICE)
    embeddings = encoder.encode(texts, convert_to_numpy=True, show_progress_bar=True)
    dim = embeddings.shape[1]
    print(f"Building FAISS index (dim={dim})")
    flat_index = faiss.IndexFlatIP(dim)
    faiss.normalize_L2(embeddings)  # in place
    flat_index.add(embeddings)

    if not os.path.exists(faiss_index_path):
        os.makedirs(faiss_index_path, exist_ok=True)
    faiss.write_index(flat_index, os.path.join(faiss_index_path, "index.faiss"))
    np.save(os.path.join(faiss_index_path, "doc_embs.npy"), embeddings)
    return encoder, flat_index, embeddings

# -----------------------
# HYBRID RETRIEVAL
# -----------------------
def hybrid_retrieval(query: str,
                     tfidf_vectorizer: TfidfVectorizer, tfidf_matrix,
                     w2v_model, w2v_doc_embs,
                     sbert_model: SentenceTransformer, sbert_doc_embs,
                     doc_texts: List[str], top_k=TOP_K_CANDIDATES):
    """
    Score every document against *query* with three retrievers and fuse the scores.

    Steps:
      - TF-IDF cosine similarity (all zeros if the vectorizer or matrix is None)
      - cosine similarity of averaged Word2Vec vectors (all zeros if unavailable
        or if no query word is in the vocabulary)
      - SBERT dot product between the unit-normalised query and *sbert_doc_embs*
        (build_sbert_and_faiss stores those already L2-normalised)
      - min-max normalise each score vector to [0, 1] (constant vectors -> 0)
        and combine as W_TFIDF*tfidf + W_W2V*w2v + W_SQ*sbert

    Returns:
        ``(top_idxs, top_scores, components)``: indices of the ``top_k`` highest
        fused scores (descending), those fused scores, and a dict of the full
        per-document "tfidf", "w2v", "sbert" and "fused" vectors.
    """
    n_docs = len(doc_texts)
    if n_docs == 0:
        empty = np.zeros(0)
        return np.zeros(0, dtype=int), empty, {"tfidf": empty, "w2v": empty, "sbert": empty, "fused": empty}

    def norm01(x):
        span = x.max() - x.min()
        if span <= 1e-9:
            return np.zeros_like(x)
        return (x - x.min()) / span

    # TF-IDF (the cached-load path may hand us None for vectorizer/matrix)
    if tfidf_vectorizer is not None and tfidf_matrix is not None:
        q_tfidf = tfidf_vectorizer.transform([query])
        tfidf_sims = cosine_similarity(q_tfidf, tfidf_matrix).reshape(-1)  # shape (n_docs,)
    else:
        tfidf_sims = np.zeros(n_docs)

    # Word2Vec
    if w2v_model is not None and w2v_doc_embs is not None:
        q_w2v = get_avg_word2vec_embeddings(w2v_model, [query])[0].reshape(1, -1)
        if np.linalg.norm(q_w2v) == 0:
            # no in-vocabulary query word -> no Word2Vec signal
            w2v_sims = np.zeros(len(w2v_doc_embs))
        else:
            w2v_sims = cosine_similarity(q_w2v, w2v_doc_embs).reshape(-1)
    else:
        w2v_sims = np.zeros(n_docs)

    # SBERT: normalise the query once (guarding a zero vector), then dot products.
    q_sbert = np.asarray(sbert_model.encode([query], convert_to_numpy=True), dtype=np.float32).reshape(-1)
    q_len = np.linalg.norm(q_sbert)
    if q_len > 0:
        q_sbert = q_sbert / q_len
    sbert_sims = np.asarray(sbert_doc_embs) @ q_sbert

    tfidf_norm = norm01(tfidf_sims)
    w2v_norm = norm01(w2v_sims)
    sbert_norm = norm01(sbert_sims)

    fused = W_TFIDF * tfidf_norm + W_W2V * w2v_norm + W_SQ * sbert_norm

    top_idxs = np.argsort(fused)[::-1][:int(top_k)]
    top_scores = fused[top_idxs]
    return top_idxs, top_scores, {"tfidf": tfidf_norm, "w2v": w2v_norm, "sbert": sbert_norm, "fused": fused}

# -----------------------
# CROSS-ENCODER RERANK
# -----------------------
def cross_encoder_rerank(query: str, candidate_texts: List[str], cross_encoder_model: CrossEncoder, top_k=TOP_K_FINAL):
    """
    Rerank *candidate_texts* for *query* with a cross-encoder.

    Each (query, passage) pair is scored by ``cross_encoder_model.predict``
    (higher is better).

    Returns:
        ``(order, scores)``: positions into *candidate_texts* of the ``top_k``
        best passages (best first) and their scores as a NumPy array. Both are
        empty when there are no candidates, since ``predict`` cannot handle an
        empty batch.
    """
    if not candidate_texts:
        return np.zeros(0, dtype=int), np.zeros(0)
    pairs = [[query, t] for t in candidate_texts]
    scores = np.asarray(cross_encoder_model.predict(pairs, show_progress_bar=False))
    order = np.argsort(scores)[::-1][:top_k]
    return order, scores[order]

# -----------------------
# QA GENERATION (LLM)
# -----------------------
def make_retrieval_qa_chain(llm_model_name_or_instance, embedding_retriever):
    """
    Build a LangChain ``RetrievalQA`` ("stuff" chain) that answers with citations.

    Args:
        llm_model_name_or_instance: an already-constructed LangChain LLM, used
            as-is; or a model-name string, passed to the OpenAI wrapper as
            ``model_name``. ``None`` or ``""`` keeps the wrapper's default model.
        embedding_retriever: a LangChain retriever implementing
            ``get_relevant_documents(query) -> list[Document]``.

    Returns:
        The chain, with ``return_source_documents=True`` and the
        citation-enforcing prompt below wired in.
    """
    if llm_model_name_or_instance is not None and not isinstance(llm_model_name_or_instance, str):
        llm = llm_model_name_or_instance
    else:
        llm_kwargs = {"temperature": 0}
        if llm_model_name_or_instance:
            llm_kwargs["model_name"] = llm_model_name_or_instance
        if LLM_PROVIDER == "openai" and OPENAI_API_KEY:
            llm = OpenAI(openai_api_key=OPENAI_API_KEY, **llm_kwargs)
        else:
            # fallback to LangChain OpenAI wrapper (requires env vars)
            llm = LCOpenAI(**llm_kwargs)

    prompt = PromptTemplate(
        input_variables=["question", "context"],
        template=(
            "You are a helpful assistant. Use ONLY the provided context to answer the question. "
            "Cite sources inline by specifying (page: X) or (table: page X). If the answer is not in the context, say 'I don't know'.\n\n"
            "Context:\n{context}\n\nQuestion: {question}\nAnswer:"
        )
    )
    # The "stuff" chain fills {context} with the retrieved documents and
    # {question} with the query; the prompt must be passed via chain_type_kwargs
    # or the chain falls back to its default prompt.
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=embedding_retriever,
        return_source_documents=True,
        chain_type_kwargs={"prompt": prompt},
    )

# -----------------------
# EVALUATION METRICS (MRR, NDCG, P/R/F)
# -----------------------
def _binary_dcg(ranked_docs, relevant):
    """DCG with 0/1 gains over *ranked_docs*, accumulated from rank 1 downwards."""
    total = 0.0
    for rank, doc in enumerate(ranked_docs, 1):
        gain = 1 if doc in relevant else 0
        total += gain / math.log2(rank + 1)
    return total


def _ideal_binary_dcg(n_relevant, depth):
    """Best achievable binary DCG when min(n_relevant, depth) relevant docs fill the top ranks."""
    return sum(1.0 / math.log2(rank + 1) for rank in range(1, min(n_relevant, depth) + 1))


def _f1(precision, recall):
    """Harmonic mean of precision and recall; 0.0 when both are zero."""
    return 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0


def calculate_ir_metrics_for_query(similarities: np.ndarray, query_label_indices: np.ndarray, labels_mat: np.ndarray, k=5):
    """
    Compute ranking metrics for one query.

    Args:
        similarities: ``(n_docs,)`` scores; documents are ranked by descending score.
        query_label_indices: document index(es) representing the query. Negative
            entries are ignored.
        labels_mat: ``(n_docs, n_labels)`` 0/1 matrix. A document is relevant if
            it shares any label with a query document; the query documents
            themselves are excluded from the relevant set (they stay in the ranking).
        k: cutoff for the ``*_at_k`` metrics, MRR and ``ndcg_at_k``.

    Returns:
        Dict with "mrr" (reciprocal rank of the first relevant doc within the
        top k, else 0), "ndcg" (binary NDCG over the top 100), "ndcg_at_k",
        "precision"/"recall"/"f1" (over the top min(n_docs, max(1, 2*|relevant|))),
        and "precision_at_k"/"recall_at_k"/"f1_at_k". All values are 0.0 when
        there is no relevant document.
    """
    ranking = np.argsort(similarities)[::-1]

    relevant = set()
    for q_doc in query_label_indices:
        if q_doc < 0:
            continue
        for label in np.flatnonzero(labels_mat[q_doc] == 1):
            relevant.update(np.flatnonzero(labels_mat[:, label] == 1).tolist())
    for q_doc in query_label_indices:
        relevant.discard(q_doc)

    metric_names = ("mrr", "ndcg", "ndcg_at_k", "precision", "recall", "f1",
                    "precision_at_k", "recall_at_k", "f1_at_k")
    if not relevant:
        return dict.fromkeys(metric_names, 0.0)

    n_relevant = len(relevant)

    # top-k set metrics
    head = ranking[:k]
    hits_at_k = len(relevant.intersection(head))
    precision_k = hits_at_k / k
    recall_k = hits_at_k / n_relevant

    # "overall" set metrics at twice the number of relevant docs
    cutoff = min(len(ranking), max(1, n_relevant * 2))
    hits = len(relevant.intersection(ranking[:cutoff]))
    precision = hits / cutoff if cutoff > 0 else 0.0
    recall = hits / n_relevant

    mrr = next((1.0 / rank for rank, doc in enumerate(head, 1) if doc in relevant), 0.0)

    ideal_k = _ideal_binary_dcg(n_relevant, k)
    ndcg_k = _binary_dcg(head, relevant) / ideal_k if ideal_k > 0 else 0.0

    depth = min(len(ranking), 100)
    ideal_all = _ideal_binary_dcg(n_relevant, depth)
    ndcg_all = _binary_dcg(ranking[:depth], relevant) / ideal_all if ideal_all > 0 else 0.0

    values = (mrr, ndcg_all, ndcg_k, precision, recall, _f1(precision, recall),
              precision_k, recall_k, _f1(precision_k, recall_k))
    return dict(zip(metric_names, values))

# -----------------------
# MAIN: Build pipeline, index, and launch Streamlit UI
# -----------------------
def _load_cached_pipeline():
    """Load the artefacts of a previous build from PROCESSED_DIR / FAISS_INDEX_PATH.

    The TF-IDF vectorizer is refit on the cached chunk texts if its pickle
    cannot be loaded. The Word2Vec model itself is not reloaded (only its
    cached document embeddings), so the model slot of the result is None.
    """
    print("Loading existing processed chunks and FAISS index...")
    with open(PROCESSED_JSON, "r", encoding="utf-8") as f:
        chunks = json.load(f)
    index = faiss.read_index(os.path.join(FAISS_INDEX_PATH, "index.faiss"))
    doc_embs = np.load(os.path.join(FAISS_INDEX_PATH, "doc_embs.npy"))
    doc_texts = [c["content"] for c in chunks]

    try:
        tfidf = load_pickle(TFIDF_MODEL_PATH)
        tfidf_matrix = tfidf.transform(doc_texts)
    except Exception as exc:
        # Refitting on the same texts reproduces the saved vectorizer, so
        # retrieval keeps its TF-IDF signal instead of receiving None.
        print(f"Could not load TF-IDF vectorizer ({exc}); refitting on cached chunks.")
        tfidf, tfidf_matrix = build_tfidf(doc_texts)
        save_pickle(tfidf, TFIDF_MODEL_PATH)

    try:
        w2v_doc_embs = np.load(os.path.join(PROCESSED_DIR, "w2v_doc_embs.npy"))
    except (OSError, ValueError):
        w2v_doc_embs = None

    sbert = SentenceTransformer(SBERT_MODEL_NAME, device=DEVICE)
    return chunks, tfidf, tfidf_matrix, None, w2v_doc_embs, sbert, doc_embs, index


def _extract_page_texts(fpath: str) -> List[Dict[str, Any]]:
    """Return per-page text of *fpath*: LangChain PyMuPDF first, pdfplumber as fallback, else []."""
    try:
        return extract_text_with_langchain(fpath)
    except Exception:
        pass
    try:
        with pdfplumber.open(fpath) as pdf:
            return [{"page": i + 1, "text": p.extract_text()} for i, p in enumerate(pdf.pages)]
    except Exception:
        return []


def _ingest_raw_documents(raw_dir: str = RAW_DOC_DIR):
    """Extract page texts, tables and image-OCR records from every .pdf/.docx in *raw_dir*.

    Every record is tagged with ``"source"`` (its file path). Extraction is
    best-effort per file and per extractor: a file an extractor cannot handle
    contributes nothing from that extractor instead of aborting the build.

    Returns:
        ``(page_texts, table_records, image_records)`` lists.
    """
    files = sorted(str(p) for p in Path(raw_dir).glob("*") if p.suffix.lower() in [".pdf", ".docx"])
    print(f"Found {len(files)} raw documents in {raw_dir}")
    img_out_dir = os.path.join(PROCESSED_DIR, "images")
    page_texts_all, table_records_all, image_records_all = [], [], []
    for fpath in files:
        page_texts = _extract_page_texts(fpath)
        table_records = extract_tables_pdf(fpath)
        try:
            images = extract_images_and_ocr(fpath, img_out_dir)
        except Exception as exc:
            # fitz cannot open every accepted suffix (e.g. .docx); one bad file
            # must not abort ingestion of the rest.
            warnings.warn(f"Skipping image extraction for {fpath}: {exc}")
            images = []
        for rec in (*page_texts, *table_records, *images):
            rec["source"] = fpath
        page_texts_all.extend(page_texts)
        table_records_all.extend(table_records)
        image_records_all.extend(images)
    return page_texts_all, table_records_all, image_records_all


def build_pipeline_and_index(rebuild=False):
    """
    Run ingestion -> chunking -> embeddings -> indices, or load a previous build.

    A previous build is reused when ``rebuild`` is False and the chunk JSON,
    the FAISS index and the SBERT document embeddings it loads all exist.

    Returns:
        ``(chunks, tfidf_vectorizer, tfidf_matrix, w2v_model, w2v_doc_embs,
        sbert_model, sbert_doc_embs, faiss_index)``. ``w2v_model`` is None when
        loading from cache or when Word2Vec is unavailable.

    Raises:
        ValueError: if no text could be extracted from RAW_DOC_DIR.
    """
    cached_files = (
        PROCESSED_JSON,
        os.path.join(FAISS_INDEX_PATH, "index.faiss"),
        # the embeddings file actually loaded from cache (written by build_sbert_and_faiss)
        os.path.join(FAISS_INDEX_PATH, "doc_embs.npy"),
    )
    if not rebuild and all(os.path.exists(p) for p in cached_files):
        return _load_cached_pipeline()

    # 1) Ingest documents from RAW_DOC_DIR
    page_texts_all, table_records_all, image_records_all = _ingest_raw_documents(RAW_DOC_DIR)

    # 2) chunking
    chunks = chunk_documents(page_texts_all, table_records_all, image_records_all,
                             chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
    print(f"Created {len(chunks)} chunks.")
    if not chunks:
        raise ValueError(f"No extractable text found in {RAW_DOC_DIR}; add PDF documents before building the index.")

    with open(PROCESSED_JSON, "w", encoding="utf-8") as f:
        json.dump(chunks, f, indent=2)

    doc_texts = [c["content"] for c in chunks]

    # 3) TF-IDF
    tfidf_vectorizer, tfidf_matrix = build_tfidf(doc_texts)
    save_pickle(tfidf_vectorizer, TFIDF_MODEL_PATH)
    import scipy.sparse
    scipy.sparse.save_npz(os.path.join(PROCESSED_DIR, "tfidf_matrix.npz"), tfidf_matrix)

    # 4) Word2Vec embeddings
    w2v_path = os.path.join(PROCESSED_DIR, "w2v_doc_embs.npy")
    w2v = load_word2vec_model(WORD2VEC_NAME)
    if w2v is not None:
        w2v_doc_embs = get_avg_word2vec_embeddings(w2v, doc_texts)
        np.save(w2v_path, w2v_doc_embs)
    else:
        w2v_doc_embs = None
        # Remove embeddings from an earlier build so a later cached load
        # cannot pick up rows that no longer align with these chunks.
        try:
            os.remove(w2v_path)
        except FileNotFoundError:
            pass

    # 5) SBERT & FAISS
    sbert_model, faiss_index, sbert_doc_embs = build_sbert_and_faiss(doc_texts, model_name=SBERT_MODEL_NAME, faiss_index_path=FAISS_INDEX_PATH)

    with open(os.path.join(PROCESSED_DIR, "doc_texts.json"), "w", encoding="utf-8") as f:
        json.dump(doc_texts, f, indent=2)

    return chunks, tfidf_vectorizer, tfidf_matrix, w2v, w2v_doc_embs, sbert_model, sbert_doc_embs, faiss_index

# -----------------------
# STREAMLIT APP
# -----------------------
def _source_name(chunk) -> str:
    """File name of a chunk's source for citations; "unknown" when missing or None."""
    return Path(chunk.get("metadata", {}).get("source") or "unknown").name


def _build_pipeline_into_session(rebuild: bool) -> None:
    """Run build_pipeline_and_index and store every artefact in st.session_state."""
    with st.spinner("Building pipeline (ingestion, chunking, embeddings, FAISS). This may take some minutes..."):
        artefacts = build_pipeline_and_index(rebuild=rebuild)
        keys = ("chunks", "tfidf_vec", "tfidf_mat", "w2v_model", "w2v_doc_embs",
                "sbert_model", "sbert_doc_embs", "faiss_index")
        for key, value in zip(keys, artefacts):
            st.session_state[key] = value
        st.session_state["pipeline_built"] = True
        st.success("Pipeline ready.")


def _answer_query(query: str, topk_candidates: int, topk_final: int) -> None:
    """Retrieve, rerank and answer *query*; render answer, passages and fusion scores."""
    ss = st.session_state
    chunks = ss["chunks"]

    # 1) Hybrid retrieval
    candidate_idxs, candidate_scores, comp_scores = hybrid_retrieval(
        query, ss["tfidf_vec"], ss["tfidf_mat"],
        ss["w2v_model"], ss["w2v_doc_embs"],
        ss["sbert_model"], ss["sbert_doc_embs"],
        [c["content"] for c in chunks],
        top_k=topk_candidates)
    candidates = [chunks[i] for i in candidate_idxs]

    # 2) Cross-encoder rerank (model loaded on first use, cached in the session)
    cross_enc = ss.get("cross_encoder", None)
    if cross_enc is None:
        with st.spinner("Loading cross-encoder for reranking..."):
            cross_enc = CrossEncoder(CROSS_ENCODER_NAME, device=DEVICE)
            ss["cross_encoder"] = cross_enc
    order, _ = cross_encoder_rerank(query, [c["content"] for c in candidates], cross_enc, top_k=topk_final)
    final_docs = [candidates[i] for i in order]

    # 3) Context for the LLM: top-K passages, each followed by its citation
    context = "\n\n---\n\n".join(
        f"{d['content']}\n\n(source: {_source_name(d)}, page: {d.get('page')})" for d in final_docs)

    # 4) LLM generation (plain OpenAI wrapper, no RetrievalQA chain)
    if OPENAI_API_KEY is None:
        st.warning("OpenAI API key not set in OPENAI_API_KEY env var. LLM step may fail or use default LangChain provider.")
    try:
        llm = OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY)
        prompt = f"Answer this question using only the context below. If answer not present, say 'I don't know'.\n\nContext:\n{context}\n\nQuestion: {query}\nAnswer:"
        response = llm(prompt)
    except Exception as e:
        response = f"LLM call failed: {e}"

    # Answer and sources
    st.subheader("Answer")
    st.write(response)
    st.subheader("Top Retrieved Passages (after rerank)")
    for i, d in enumerate(final_docs):
        st.markdown(f"**Rank {i+1} — source:** {_source_name(d)} | page: {d.get('page')}")
        st.write(d["content"][:800] + ("..." if len(d["content"])>800 else ""))
        st.write("---")

    # Per-component fusion scores for the candidates
    st.subheader("Candidate Scores (fusion components)")
    df_scores = pd.DataFrame({
        "doc_id": candidate_idxs,
        "fused_score": candidate_scores,
        "tfidf_score": comp_scores["tfidf"][candidate_idxs],
        "w2v_score": comp_scores["w2v"][candidate_idxs],
        "sbert_score": comp_scores["sbert"][candidate_idxs]
    })
    st.dataframe(df_scores.head(20))


def _run_ir_evaluation() -> None:
    """Evaluate hybrid retrieval with one query per page (first sentence of its longest chunk)."""
    ss = st.session_state
    chunks = ss["chunks"]

    page_to_indices = {}
    for idx, c in enumerate(chunks):
        pg = c.get("page")
        if pg is not None:
            page_to_indices.setdefault(pg, []).append(idx)

    queries, query_label_indices = [], []
    for idxs in page_to_indices.values():
        idx_long = max(idxs, key=lambda i: len(chunks[i]["content"]))
        queries.append(chunks[idx_long]["content"].split(".")[0][:300])  # first sentence, capped
        query_label_indices.append([idx_long])

    if not queries:
        st.warning("No page-labelled chunks available; nothing to evaluate.")
        return

    # Identical for every query, so build them once instead of per iteration.
    doc_texts = [c["content"] for c in chunks]
    labels_mat = build_label_matrix(chunks)

    all_metrics = []
    for q, q_idxs in zip(tqdm(queries), query_label_indices):
        fused = hybrid_retrieval(q, ss["tfidf_vec"], ss["tfidf_mat"],
                                 ss["w2v_model"], ss["w2v_doc_embs"],
                                 ss["sbert_model"], ss["sbert_doc_embs"],
                                 doc_texts, top_k=TOP_K_CANDIDATES)[2]["fused"]
        all_metrics.append(calculate_ir_metrics_for_query(fused, q_idxs, labels_mat, k=TOP_K_FINAL))

    avg = {k: np.mean([m[k] for m in all_metrics]) for k in all_metrics[0]}
    st.success("Evaluation finished")
    st.write(avg)


def run_streamlit_app():
    """Render the Streamlit UI: config sidebar, pipeline build, Q&A and IR evaluation."""
    st.set_page_config(page_title="Multi-Modal RAG QA", layout="wide")
    st.title("Multi-Modal RAG QA — Chatbot (Hybrid IR + Cross-Encoder Rerank)")
    st.markdown("**Assignment doc (local path):** " + ASSIGNMENT_DOC_LOCAL_PATH)

    # sidebar controls
    st.sidebar.header("Configuration")
    rebuild = st.sidebar.checkbox("Rebuild index & processed data", value=False)
    topk_candidates = st.sidebar.number_input("Candidate retrieval K", value=TOP_K_CANDIDATES, min_value=10, max_value=500, step=10)
    topk_final = st.sidebar.number_input("Final top-K after rerank", value=TOP_K_FINAL, min_value=1, max_value=20)
    run_build = st.sidebar.button("Build / Load pipeline")

    # Build on the first run of the session, or whenever the button is pressed.
    if run_build or "pipeline_built" not in st.session_state:
        _build_pipeline_into_session(rebuild)

    if "pipeline_built" not in st.session_state:
        st.info("Click 'Build / Load pipeline' in the sidebar to begin.")
        return

    # Chat UI
    st.subheader("Ask a question")
    query = st.text_input("Enter your question here:", "")
    if st.button("Search & Answer"):
        if not query.strip():
            st.warning("Please enter a question first.")
        else:
            _answer_query(query, int(topk_candidates), int(topk_final))

    # Optional: evaluation panel
    st.sidebar.header("Evaluation")
    if st.sidebar.button("Run IR evaluation (build queries per page)"):
        with st.spinner("Running evaluation..."):
            _run_ir_evaluation()

# helper to build labels matrix from chunks (page-based)
def build_label_matrix(chunks):
    """Return an int 0/1 matrix of shape (n_chunks, n_pages) marking each chunk's page.

    Columns follow the sorted distinct non-None "page" values. Chunks without a
    page get an all-zero row.
    """
    page_values = sorted({chunk.get("page") for chunk in chunks} - {None})
    column_of = dict(zip(page_values, range(len(page_values))))
    matrix = np.zeros((len(chunks), len(page_values)), dtype=int)
    for row, chunk in enumerate(chunks):
        page = chunk.get("page")
        if page is not None:
            matrix[row, column_of[page]] = 1
    return matrix

# -----------------------
# Entrypoint: run streamlit
# -----------------------
if __name__ == "__main__":
    # Launch with `streamlit run <this file>`: Streamlit executes the script as
    # __main__, so this call renders the UI. A plain `python <this file>` run
    # executes the same code without starting any Streamlit server.
    run_streamlit_app()
