In [None]:
#!/usr/bin/env python3
"""
Retrieval-only RAG (no LLM):
- Ingests .txt/.md from ./docs
- Chunks, embeds (sentence-transformers)
- Stores in FAISS (cosine via inner product on L2-normalized vectors)
- Answers by returning top-K chunks and top sentences (extractive)
"""

import os, sys, glob, pickle
from dataclasses import dataclass
from typing import List, Tuple
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer

# ---------------- Config ----------------
# Directory scanned recursively for .txt/.md source documents.
DOCS_DIR = "./docs"
# On-disk location of the serialized FAISS index.
INDEX_PATH = "./faiss.index"
# On-disk location of the pickled chunk metadata list that accompanies the index.
META_PATH  = "./faiss_meta.pkl"
# Model name handed to SentenceTransformer(...) for every embedding call.
EMBED_MODEL = "all-MiniLM-L6-v2"
# Chunk length in characters (chunks are plain string slices).
CHUNK_SIZE = 900
# Number of characters shared between consecutive chunks.
CHUNK_OVERLAP = 200
# Default number of passages returned per query.
TOP_K_CHUNKS = 5
# Default number of extracted sentences returned per query.
TOP_K_SENTENCES = 6
MMR_LAMBDA = 0.7       # 1.0 = purely relevance, 0 = purely diversity

# --------------- Data types -------------
@dataclass
class ChunkMeta:
    """Metadata kept for each indexed chunk.

    Instances are pickled to META_PATH when the index is saved, so the field
    names and order are part of the on-disk format.
    """
    # Path of the file the chunk was cut from.
    source: str
    # Zero-based position of the chunk within its source file.
    chunk_id: int
    # The chunk's text, exactly as it was embedded.
    text: str

# --------------- Utils ------------------
def read_texts(dir_path: str) -> List[Tuple[str, str]]:
    """Collect the non-empty .txt/.md files found anywhere under *dir_path*.

    Files are decoded as UTF-8 with undecodable bytes dropped, and surrounding
    whitespace is stripped. Files that are empty after stripping are skipped.
    Files that cannot be opened or read are skipped with a warning on stderr
    rather than silently ignored.

    Args:
        dir_path: Root directory to scan recursively.

    Returns:
        A list of ``(path, text)`` pairs sorted by path. A missing or empty
        directory yields an empty list.
    """
    wanted = {".txt", ".md"}
    out = []
    # glob order follows the filesystem; sort so the chunk order (and therefore
    # the positions stored in the index) is reproducible across machines.
    for p in sorted(glob.glob(os.path.join(dir_path, "**", "*"), recursive=True)):
        if os.path.isdir(p):
            continue
        if os.path.splitext(p)[1].lower() not in wanted:
            continue
        try:
            with open(p, "r", encoding="utf-8", errors="ignore") as f:
                t = f.read().strip()
        except OSError as exc:
            # Best effort: one unreadable file must not abort ingestion, but
            # the user should learn that it was left out of the index.
            print(f"[read_texts] skipping {p}: {exc}", file=sys.stderr)
            continue
        if t:
            out.append((p, t))
    return out

def chunk_text(t: str, size: int, overlap: int) -> List[str]:
    """Split *t* into character windows of at most *size* with shared overlap.

    Args:
        t: Text to split.
        size: Maximum chunk length in characters. A non-positive size
            disables chunking and returns ``[t]``.
        overlap: Number of characters each chunk shares with the previous
            one. Clamped to ``[0, size - 1]`` so the window always advances
            by at least one character and never skips text.

    Returns:
        The chunks in document order; an empty list for empty text (when
        ``size`` is positive).
    """
    if size <= 0:
        return [t]
    # An overlap >= size would leave the start index unchanged forever, and a
    # negative one would jump over text; clamp to guarantee forward progress.
    step = size - min(max(overlap, 0), size - 1)
    chunks = []
    n = len(t)
    i = 0
    while i < n:
        j = min(n, i + size)
        chunks.append(t[i:j])
        if j == n:
            break
        i += step
    return chunks

def l2_normalize(a: np.ndarray) -> np.ndarray:
    """Scale every row of a 2-D array to unit L2 length and return float32.

    A tiny epsilon is added to each row norm so all-zero rows divide cleanly
    (they stay all-zero) instead of producing NaNs.
    """
    eps = 1e-12
    row_norms = np.linalg.norm(a, axis=1, keepdims=True)
    unit_rows = np.divide(a, row_norms + eps)
    return unit_rows.astype("float32")

def embed(model: SentenceTransformer, texts: List[str]) -> np.ndarray:
    """Encode *texts* with *model* and return L2-normalized float32 rows.

    Unit-length rows let inner products between embeddings be used as cosine
    similarities.
    """
    encode_options = {
        "batch_size": 64,
        "convert_to_numpy": True,
        "show_progress_bar": False,
    }
    raw_vectors = model.encode(texts, **encode_options)
    return l2_normalize(raw_vectors)

# --------------- Index ------------------
class VectorIndex:
    """Flat inner-product FAISS index paired with per-vector chunk metadata.

    ``self.meta[i]`` describes the vector stored at FAISS position ``i``, so
    the two containers must always have the same length.
    """

    def __init__(self, dim: int):
        self.faiss = faiss.IndexFlatIP(dim)   # inner product on normalized = cosine
        self.meta: List["ChunkMeta"] = []

    def add(self, vecs: np.ndarray, metas: List["ChunkMeta"]):
        """Append vectors and their metadata, keeping positions aligned.

        Raises:
            ValueError: If the number of vectors and metadata entries differ;
                adding them would silently misalign every later lookup.
        """
        if vecs.shape[0] != len(metas):
            raise ValueError(
                f"got {vecs.shape[0]} vectors but {len(metas)} metadata entries"
            )
        self.faiss.add(vecs)
        self.meta.extend(metas)

    def search(self, q: np.ndarray, k: int) -> Tuple[np.ndarray, List["ChunkMeta"]]:
        """Return the scores and metadata of the top-*k* matches for *q*.

        Args:
            q: Query matrix of shape (1, dim); only the first row is used.
            k: Maximum number of hits to return.

        Returns:
            ``(scores, hits)`` where ``scores[i]`` belongs to ``hits[i]``.
            Fewer than *k* entries are returned when the index holds fewer
            vectors; a non-positive *k* yields empty results.
        """
        if k <= 0:
            return np.empty(0, dtype="float32"), []
        D, I = self.faiss.search(q, k)
        # FAISS pads missing results with id -1; drop those slots from both
        # arrays so the returned scores stay aligned with the returned hits.
        found = I[0] >= 0
        hits = [self.meta[int(idx)] for idx in I[0][found]]
        return D[0][found], hits

    def save(self):
        """Write the FAISS index to INDEX_PATH and the metadata to META_PATH."""
        faiss.write_index(self.faiss, INDEX_PATH)
        with open(META_PATH, "wb") as f:
            pickle.dump(self.meta, f)

    @classmethod
    def load(cls):
        """Load a previously saved index.

        Returns:
            The restored index, or ``None`` when either file is missing or
            the two files disagree on the number of entries (callers treat
            ``None`` as "rebuild").
        """
        if not (os.path.exists(INDEX_PATH) and os.path.exists(META_PATH)):
            return None
        idx = faiss.read_index(INDEX_PATH)
        # SECURITY: unpickling executes arbitrary code. META_PATH must only
        # ever be a file written by save() on a trusted machine.
        with open(META_PATH, "rb") as f:
            meta = pickle.load(f)
        if idx.ntotal != len(meta):
            # E.g. an interrupted save(); positions can no longer be trusted.
            print(
                f"[index] ignoring inconsistent cache: {idx.ntotal} vectors "
                f"vs {len(meta)} metadata entries",
                file=sys.stderr,
            )
            return None
        vi = cls(idx.d)
        vi.faiss = idx
        vi.meta = meta
        return vi

# ------------- Build / Load -------------
def build_or_load_index() -> Tuple[VectorIndex, SentenceTransformer]:
    """Return a ready-to-query index together with the embedding model.

    A saved index is reused when one can be loaded. Otherwise every document
    under DOCS_DIR is chunked, embedded, indexed and saved to disk.

    Raises:
        SystemExit: If no index can be loaded and DOCS_DIR has no usable files.
    """
    emb_model = SentenceTransformer(EMBED_MODEL)

    cached = VectorIndex.load()
    if cached:
        return cached, emb_model

    docs = read_texts(DOCS_DIR)
    if not docs:
        raise SystemExit(f"No .txt/.md files found under {DOCS_DIR}")

    fresh = VectorIndex(dim=emb_model.get_sentence_embedding_dimension())

    # One metadata record per chunk; chunk ids restart at 0 for each file.
    records = [
        ChunkMeta(source=src, chunk_id=pos, text=piece)
        for src, body in docs
        for pos, piece in enumerate(chunk_text(body, CHUNK_SIZE, CHUNK_OVERLAP))
    ]
    all_chunks = [rec.text for rec in records]

    fresh.add(embed(emb_model, all_chunks), records)
    fresh.save()
    print(f"[index] built: {len(all_chunks)} chunks from {len(docs)} files")
    return fresh, emb_model

# ------------- Retrieval core -----------
def mmr_select(doc_vecs: np.ndarray, query_vec: np.ndarray, k: int, lam: float=0.7) -> List[int]:
    """Pick up to *k* row indices of *doc_vecs* by Maximal Marginal Relevance.

    The most query-similar row is taken first. Each later pick maximizes
    ``lam * sim(doc, query) - (1 - lam) * max_sim(doc, already_selected)``.
    Similarities are inner products, i.e. cosine when rows are unit-length.

    Args:
        doc_vecs: Candidate matrix of shape (n, d).
        query_vec: Query of shape (1, d) (or (d,)).
        k: Maximum number of indices to return.
        lam: Trade-off; 1.0 is pure relevance, 0.0 is pure diversity.

    Returns:
        Selected indices in pick order; empty when ``k <= 0`` or there are
        no candidates.
    """
    n = doc_vecs.shape[0]
    if k <= 0 or n == 0:
        # argmax over an empty array raises, so bail out before touching it.
        return []

    sim_to_q = (doc_vecs @ query_vec.T).ravel()  # cosine (since normalized)

    first = int(np.argmax(sim_to_q))
    selected = [first]
    available = np.ones(n, dtype=bool)
    available[first] = False

    # Running "closest already-selected neighbour" similarity per document;
    # updated incrementally instead of rebuilding a (c, m) matrix each round.
    max_sim_to_sel = doc_vecs @ doc_vecs[first]

    target = min(k, n)
    while len(selected) < target:
        remaining = np.flatnonzero(available)
        mmr = lam * sim_to_q[remaining] - (1 - lam) * max_sim_to_sel[remaining]
        pick = int(remaining[int(np.argmax(mmr))])
        selected.append(pick)
        available[pick] = False
        max_sim_to_sel = np.maximum(max_sim_to_sel, doc_vecs @ doc_vecs[pick])
    return selected

def split_sentences(text: str) -> List[str]:
    """Break *text* into sentences with a lightweight regex heuristic.

    A boundary is any whitespace run that follows '.', '?' or '!'. The input
    is stripped first and empty pieces are discarded. For better results,
    swap in a real tokenizer such as nltk or syntok.
    """
    import re

    boundary = re.compile(r'(?<=[\.\?\!])\s+')
    pieces = boundary.split(text.strip())
    return list(filter(None, pieces))

def top_sentences(emb_model: SentenceTransformer, query: str, passages: List[str], k: int) -> List[Tuple[float, str]]:
    """Rank the sentences of *passages* by cosine similarity to *query*.

    Returns:
        Up to *k* ``(score, sentence)`` pairs, best first; an empty list when
        the passages contain no sentences.
    """
    sentences = [s for passage in passages for s in split_sentences(passage)]
    if len(sentences) == 0:
        return []

    query_vec = embed(emb_model, [query])          # (1, d)
    sentence_vecs = embed(emb_model, sentences)    # (n, d)
    relevance = (sentence_vecs @ query_vec.T).ravel()

    ranked = []
    # Negate so that ascending argsort yields highest similarity first.
    for pos in np.argsort(-relevance)[:k]:
        ranked.append((float(relevance[pos]), sentences[pos]))
    return ranked

# ------------- Public API ----------------
def retrieve_only(query: str, top_k_chunks=TOP_K_CHUNKS, top_k_sents=TOP_K_SENTENCES):
    """Answer *query* extractively: diverse top passages plus top sentences.

    The index is asked for a wide candidate set (5x the requested passage
    count), which MMR then reduces to ``top_k_chunks`` relevant but
    non-redundant passages. The best-matching sentences are extracted from
    those passages.

    Args:
        query: Free-text query.
        top_k_chunks: Number of passages to return.
        top_k_sents: Number of sentences to return.

    Returns:
        ``(chosen, sent_hits)`` where ``chosen`` is a list of
        ``(score, ChunkMeta)`` in MMR pick order and ``sent_hits`` is a list
        of ``(score, sentence)`` pairs, best first. Both are empty when the
        index is empty or ``top_k_chunks`` is not positive.
    """
    vi, emb_model = build_or_load_index()
    qv = embed(emb_model, [query])          # (1,d)

    # Never ask FAISS for more neighbours than it holds (it would pad the
    # result with id -1), and never for k <= 0 (it rejects that outright).
    wide_k = min(top_k_chunks * 5, vi.faiss.ntotal)

    chosen = []
    if wide_k > 0:
        D, I = vi.faiss.search(qv, wide_k)
        # Keep (score, faiss position) pairs for the real hits only.
        hits = [(float(s), int(i)) for s, i in zip(D[0], I[0]) if i != -1]
        if hits:
            # MMR needs the candidate vectors; pull the stored ones back out
            # of the flat index instead of re-embedding any text.
            cand_vecs = np.vstack([vi.faiss.reconstruct(i) for _, i in hits])
            sel_idx = mmr_select(cand_vecs, qv, k=top_k_chunks, lam=MMR_LAMBDA)
            chosen = [(hits[p][0], vi.meta[hits[p][1]]) for p in sel_idx]

    # Top sentences (extractive)
    passages = [m.text for _, m in chosen]
    sent_hits = top_sentences(emb_model, query, passages, k=top_k_sents)

    return chosen, sent_hits

# ------------- CLI ----------------------
def main():
    """Command-line entry point: run one query and print the results.

    All command-line arguments are joined into a single query, so both
    ``prog 'what is rag'`` and ``prog what is rag`` work. A missing or blank
    query prints the usage line and exits with status 1.
    """
    query = " ".join(sys.argv[1:]).strip()
    if not query:
        print(f"Usage: {sys.argv[0]} 'your query'")
        sys.exit(1)

    chunks, sents = retrieve_only(query)

    print("\n=== Top Passages ===")
    for rank, (score, meta) in enumerate(chunks, 1):
        flat = meta.text.replace("\n", " ")
        # Only mark the preview as truncated when something was cut off.
        short = flat[:240] + ("..." if len(flat) > 240 else "")
        print(f"{rank:>2}. score={score:.3f}  {meta.source}  [chunk {meta.chunk_id}]")
        print(f"    {short}")
    if not chunks:
        print("No passages found.")

    print("\n=== Top Sentences (extractive) ===")
    for rank, (score, sent) in enumerate(sents, 1):
        print(f"{rank:>2}. {score:.3f}  {sent}")

if __name__ == "__main__":
    main()
