## setup

In [1]:
# If running the first time on a new machine, uncomment these:
# %pip install -q sentence-transformers rank-bm25 requests numpy

import os, json, time
from typing import List, Dict, Any, Tuple, Optional
import numpy as np
from pathlib import Path

from sentence_transformers import SentenceTransformer
from rank_bm25 import BM25Okapi
import requests

# ---- Config ----
DATA_ROOT = Path("data")   # expects subfolders like gita_arnold/, upanishads_sbe/, rigveda_griffith/
EMB_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"  # 384-d
GROQ_API_KEY = os.getenv("GROQ_API_KEY", "")               # set env or paste here
GROQ_MODEL = "llama-3.1-8b-instant"                        # or "llama-3.1-70b-versatile"
GROQ_URL = "https://api.groq.com/openai/v1/chat/completions"

print("Data root exists:", DATA_ROOT.exists())
print("Groq key present?", bool(GROQ_API_KEY))

Data root exists: True
Groq key present? True


In [2]:
# Choose which shards to include in chatting:
SELECTED_SHARDS = ["gita_arnold", "upanishads_sbe", "rigveda_griffith"]
# For quick tests you can do: SELECTED_SHARDS = ["gita_arnold"]

def load_shard(folder: Path):
    manifest = json.load((folder / "manifest.json").open("r", encoding="utf-8"))
    chunks_path = folder / manifest["combined_chunks"]
    emb_path = folder / manifest["embeddings_bin"]

    # Load text records
    records = []
    with chunks_path.open("r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                records.append(json.loads(line))

    # Load embeddings (float16) → view as float32
    raw = np.fromfile(emb_path, dtype=np.float16)
    dim = manifest["dim"]
    emb = raw.reshape(len(records), dim).astype(np.float32, copy=False)

    return records, emb

all_chunks: List[Dict[str, Any]] = []
all_embs = []

for name in SELECTED_SHARDS:
    folder = DATA_ROOT / name
    if not folder.exists():
        print("! Missing shard folder:", folder)
        continue
    recs, embs = load_shard(folder)
    print(f"Loaded shard {name}: {len(recs)} units, dim={embs.shape[1]}")
    for r in recs:
        r.setdefault("shard", name)  # track origin shard
    all_chunks.extend(recs)
    all_embs.append(embs)

if not all_chunks:
    raise RuntimeError("No chunks loaded. Check SELECTED_SHARDS and that shards were built.")

EMB = np.vstack(all_embs)  # (N, D)
print("TOTAL units:", EMB.shape[0], "Embedding dim:", EMB.shape[1])


Loaded shard gita_arnold: 2436 units, dim=384
Loaded shard upanishads_sbe: 8412 units, dim=384
Loaded shard rigveda_griffith: 10562 units, dim=384
TOTAL units: 21410 Embedding dim: 384


## retriever creation

In [3]:
def bm25_tokenize(s: str):
    return s.lower().split()

BM25 = BM25Okapi([bm25_tokenize(r.get("text","")) for r in all_chunks])
print("BM25 index built. Total documents:", len(all_chunks))


BM25 index built. Total documents: 21410


In [4]:
_embedder = None
def get_embedder():
    global _embedder
    if _embedder is None:
        _embedder = SentenceTransformer(EMB_MODEL_NAME)
    return _embedder

def embed_query(q: str) -> np.ndarray:
    return get_embedder().encode([q], normalize_embeddings=True)[0].astype(np.float32)


## helper to filter sources

In [5]:
def get_ref(r: Dict[str, Any]) -> str:
    return r.get("canonical_ref") or r.get("canon_id") or r.get("id") or r.get("work") or "?"

def get_score(r: Dict[str, Any]) -> str:
    try:
        return f"{float(r.get('_score')):.3f}"
    except Exception:
        return "n/a"

def list_available_sources():
    works = sorted({r.get("work","?") for r in all_chunks})
    shards = sorted({r.get("shard","?") for r in all_chunks})
    colls = sorted({r.get("collection","?") for r in all_chunks})
    print("Works:", works)
    print("Shards:", shards)
    print("Collections:", colls)

# Run this once if you want to see labels to filter on:
# list_available_sources()


In [6]:
def _matches_source(rec: Dict[str, Any], source_filter: Optional[dict]) -> bool:
    if not source_filter:
        return True
    w_ok = s_ok = c_ok = True
    if "works" in source_filter and source_filter["works"]:
        w_ok = rec.get("work") in set(source_filter["works"])
    if "shards" in source_filter and source_filter["shards"]:
        s_ok = rec.get("shard") in set(source_filter["shards"])
    if "collections" in source_filter and source_filter["collections"]:
        c_ok = rec.get("collection") in set(source_filter["collections"])
    return w_ok and s_ok and c_ok


## retriever

In [7]:
def hybrid_retrieve(query: str, k_vec=12, k_bm=12, top_final=8, source_filter: Optional[dict] = None):
    # Optional mask for selected subset
    if source_filter:
        mask_arr = np.array([_matches_source(r, source_filter) for r in all_chunks])
        idx_map = np.flatnonzero(mask_arr)
        if idx_map.size == 0:
            return []
        mat = EMB[mask_arr]
    else:
        mask_arr = None
        idx_map = None
        mat = EMB

    # Vector scores
    q_emb = embed_query(query)
    scores = mat @ q_emb
    k_vec = min(k_vec, scores.shape[0])
    vec_idx = np.argpartition(-scores, kth=k_vec-1)[:k_vec]
    vec_idx = vec_idx[np.argsort(scores[vec_idx])[::-1]]
    vec = [(int(i), float(scores[i])) for i in vec_idx]

    # BM25 on full corpus; normalize and merge (apply filter during merge)
    bm_scores = BM25.get_scores(bm25_tokenize(query))
    k_bm = min(k_bm, len(bm_scores))
    bm_idx = np.argpartition(-bm_scores, kth=k_bm-1)[:k_bm]
    bm_idx = bm_idx[np.argsort(bm_scores[bm_idx])[::-1]]
    bm = [(int(i), float(bm_scores[i])) for i in bm_idx]

    # Map local vec indices back to global if masked
    if idx_map is not None:
        vec = [(int(idx_map[i]), s) for (i, s) in vec]

    # Merge
    combined = {}
    for i, s in vec:
        combined[i] = max(combined.get(i, -1e9), s)
    if bm:
        max_bm = max(s for _, s in bm) or 1.0
        for i, s in bm:
            if source_filter and not _matches_source(all_chunks[i], source_filter):
                continue
            combined[i] = max(combined.get(i, -1e9), (s / max_bm) * 0.9)

    # Final top_k
    final = sorted(combined.items(), key=lambda kv: kv[1], reverse=True)[:top_final]
    results = []
    for i, sc in final:
        r = dict(all_chunks[i])
        r["_idx"] = int(i)
        try:
            r["_score"] = float(sc)
        except Exception:
            r["_score"] = None
        results.append(r)
    return results


## answer generation helper

In [9]:
SYSTEM_PROMPT = '''You are a Hindu scripture assistant.
Rules:
1) Answer ONLY using the provided passages.
2) Include citations inline like `BG 2.47`, `Kena 1.3`, or `RV 1.1.1`.
3) If not supported by the passages, say: "Not found in current corpus."
4) Be concise and neutral; when schools differ, note that briefly.
'''

def build_context(nodes: List[Dict[str, Any]]) -> str:
    blocks = []
    for r in nodes:
        ref = get_ref(r)
        work = r.get("work","")
        trn  = r.get("translator","")
        txt  = (r.get("text") or "").strip()
        blocks.append(f"[{ref}] ({work}) {txt} (Translator: {trn})")
    return "\n\n".join(blocks)

def groq_answer(query: str, nodes: List[Dict[str, Any]], temperature=0.2, max_tokens=600) -> str:
    if not GROQ_API_KEY:
        return ""
    payload = {
        "model": GROQ_MODEL,
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": f"Query: {query}\n\nPassages:\n{build_context(nodes)}\n\nAnswer with citations."}
        ],
        "temperature": temperature,
        "max_tokens": max_tokens,
        "stream": False
    }
    try:
        r = requests.post(GROQ_URL,
                          headers={"Authorization": f"Bearer {GROQ_API_KEY}",
                                   "Content-Type": "application/json"},
                          json=payload, timeout=60)
        r.raise_for_status()
        data = r.json()
        return data["choices"][0]["message"]["content"].strip()
    except Exception as e:
        print("Groq error → falling back to extractive:", e)
        return ""

def extractive_answer(query: str, retrieved: List[Dict[str, Any]]) -> Dict[str, Any]:
    if not retrieved:
        return {"answer":"Not found in current corpus.", "citations":[]}
    cites, bullets = [], []
    for r in retrieved[:4]:
        ref = get_ref(r)
        bullets.append(f"- {r.get('text','').strip()} ({ref})")
        cites.append(ref)
    ans = "Here are relevant passages:\n" + "\n".join(bullets) + \
          "\n\nSummary: themes include duty without attachment, devotion, and knowledge depending on context. " + \
          f"(citations: {', '.join(cites)})"
    return {"answer": ans, "citations": cites}

def answer_query(query, k_vec=8, k_bm=8, top_final=6, temperature=0.2, source_filter: dict | None = None):
    retrieved = hybrid_retrieve(query, k_vec=k_vec, k_bm=k_bm, top_final=top_final, source_filter=source_filter)
    context = build_context(retrieved)

    if not GROQ_API_KEY:
        fallback = extractive_answer(query, retrieved)
        fallback["retrieved"] = retrieved
        return fallback

    try:
        text = groq_answer(query, retrieved, temperature=temperature, max_tokens=600)
        refs_in_text = [get_ref(r) for r in retrieved if get_ref(r) in (text or "")]
        return {"answer": text or "Not found in current corpus.", "citations": refs_in_text, "retrieved": retrieved}
    except Exception as e:
        print("Groq call failed; using extractive fallback. Error:\n", e)
        fb = extractive_answer(query, retrieved)
        fb["retrieved"] = retrieved
        return fb



## demos

In [10]:
def pretty_sources(rows):
    for r in rows:
        print(f"  - {get_ref(r):>12}  [{r.get('work')}]  score={get_score(r)}")

tests = [
    "What does Gita say about parents?",
    "Explain nishkama karma with references.",
    "Is it ok to be attached?",
]

print("=== Unfiltered ===")
for q in tests:
    print("="*80)
    print("Q:", q)
    out = answer_query(q)
    print("\nAnswer:\n", out.get("answer","<no answer>")[:1200])
    print("\nTop sources:")
    pretty_sources(out.get("retrieved", []))

print("\n\n=== Filter: Only Bhagavad Gita ===")
for q in tests:
    print("="*80)
    print("Q:", q)
    out = answer_query(q, source_filter={"works": ["Bhagavad Gita"]})
    print("\nAnswer:\n", out.get("answer","<no answer>")[:1200])
    print("\nTop sources:")
    pretty_sources(out.get("retrieved", []))

print("\n\n=== Filter: Only Upanishads collection ===")
print("="*80)
out = answer_query("What is Atman and Brahman relationship?", source_filter={"collections": ["Upanishads"]})
print("\nAnswer:\n", out.get("answer","<no answer>")[:1200])
print("\nTop sources:")
pretty_sources(out.get("retrieved", []))

print("\n\n=== Filter: Only Rig Veda ===")
print("="*80)
out = answer_query("A hymn to Agni", source_filter={"works": ["Rig Veda"]})
print("\nAnswer:\n", out.get("answer","<no answer>")[:1200])
print("\nTop sources:")
pretty_sources(out.get("retrieved", []))


=== Unfiltered ===
Q: What does Gita say about parents?

Answer:
 Not found in current corpus.

Top sources:
  -   SBE01158 3  [sbe01158]  score=0.900
  - SBE15072 133  [sbe15072]  score=0.878
  -      BG 4.78  [Bhagavad Gita]  score=0.878
  - SBE15072 111  [sbe15072]  score=0.857
  - SBE15072 113  [sbe15072]  score=0.857
  - SBE15072 104  [sbe15072]  score=0.857
Q: Explain nishkama karma with references.

Answer:
 Nishkama karma refers to selfless action or action without attachment to the outcome. This concept is described in the Bhagavad Gita as:

"Thy work, the KARMA? Tell me what it is" (BG 8.3)

Krishna explains that karma is the cause of all life, implying that it is a fundamental aspect of existence:

"Causing all life to live, is KARMA called" (BG 8.12)

In the context of yoga, nishkama karma is the performance of actions without attachment to their fruits. This is described in the Bhagavad Gita as:

"Yasmin pu n ye 'nukûle 'hni karma k ikîrshati tata h prâk pu n yâham evârabh