
# Retrieval‑Augmented Generation (RAG) — Policy Assistant (Gemini 2.5 Flash Lite)

**Scenario:** A policy assistant answers employee questions using HR policies, benefits PDFs, and recent announcement emails.  
**Goal:** Short, grounded answers with exact citations and a structured JSON for audit — or **abstain** when sources don’t suffice.



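**What this notebook covers:**

- Ingesting documents with metadata (source, date, owner, **ACL**)
- Splitting documents into page/section passages and indexing them: **BM25**, a simple TF-IDF cosine standing in for **semantic** search, then **hybrid** fusion with **MMR** diversity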
- Building a strict, source-bound prompt that **requires citations** and allows abstention
- Calling **Gemini 2.5 Flash Lite** for low-temperature JSON outputs
- Validating a JSON **contract** (with `None` fallbacks) and computing offline **RAG metrics**
- Logging retrieval hits and feeding failures back to improve data/prompts


## Setup — SDK & Model

In [None]:
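import os

# Prefer Colab secrets; fall back to the environment variable outside Colab.
try:
    from google.colab import userdata
    _API_KEY = userdata.get('GOOGLE_API_KEY')
except Exception:  # not running in Colab, or the secret is missing / not shared with this notebook
    _API_KEY = os.getenv('GOOGLE_API_KEY')

try:
    import google.generativeai as genai
    if _API_KEY:
        genai.configure(api_key=_API_KEY)
    _GEMINI_READY = bool(_API_KEY)
except ImportError:
    print("Install google-generativeai to enable live calls.")
    _GEMINI_READY = False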

In [None]:
import os, re, math, time, json
from datetime import datetime
from collections import defaultdict, Counter

MODEL_NAME = os.getenv("GEMINI_MODEL", "gemini-2.5-flash-lite")

if not _GEMINI_READY:
    print("WARNING: Gemini is not configured (GOOGLE_API_KEY missing or google-generativeai not installed). LLM cells will be skipped.")
else:
    print("GOOGLE_API_KEY detected. Model:", MODEL_NAME)


## Ingest — Mock Corpus with Metadata & ACL


In [None]:
CORPUS = [
    {"doc_id":"HR-Handbook-2025.pdf","source_type":"pdf","date":"2025-06-15","owner":"HR",
     "acl":{"departments":["all"],"locations":["all"]},
     "pages":{"11":"Benefits overview. Health, dental, vision. Parental leave policy updated in 2025.",
              "12":"Parental leave: 16 weeks paid for primary caregivers; 8 weeks paid for secondary caregivers. Eligibility: 6 months tenure.",
              "28":"Sabbatical program: 8 weeks unpaid after 4 years of service; manager approval required."}},
    {"doc_id":"Benefits-Site","url":"https://intranet/hr/benefits#leave","source_type":"web","date":"2025-07-20","owner":"HR",
     "acl":{"departments":["all"],"locations":["US","CA","UK"]},
     "sections":{"leave":"Paid time off, sick leave, parental leave details; links to eligibility tables and forms.",
                 "health":"Medical plans with HSA and FSA options; contact benefits@company.example"}},
    {"doc_id":"Announcement-Email-2025-08-05","source_type":"email","date":"2025-08-05","owner":"HR",
     "acl":{"departments":["all"],"locations":["all"]},
     "body":"Reminder: parental leave remains 16 weeks for primary caregivers; remote employees in UK now receive an additional 2 weeks statutory-compatible leave."},
    {"doc_id":"Policy-Wiki-Travel","source_type":"wiki","date":"2025-05-10","owner":"Finance",
     "acl":{"departments":["Engineering","Finance"],"locations":["US","CA"]},
     "sections":{"per_diem":"US/CA per diem allowances by city tier. Submit receipts via Concur.",
                 "intl_travel":"Visa and insurance checklist. Book via corporate portal."}},
    {"doc_id":"Ticket-History-1234","source_type":"ticket","date":"2025-03-11","owner":"HR",
     "acl":{"departments":["all"],"locations":["all"]},
     "body":"Resolved: clarification on PTO accrual rates. Linked to HR-Handbook-2025.pdf p.10."}
]

def flatten_passages(corpus):
    passages = []
    for d in corpus:
        meta = {k:d.get(k) for k in ["doc_id","source_type","date","owner","acl"]}
        if d.get("pages"):
            for p, text in d["pages"].items():
                passages.append({**meta, "section_or_page": str(p), "text": text, "url": None})
        if d.get("sections"):
            for s, text in d["sections"].items():
                passages.append({**meta, "section_or_page": str(s), "text": text, "url": d.get("url")})
        if d.get("body"):
            passages.append({**meta, "section_or_page": "body", "text": d["body"], "url": d.get("url")})
    return passages

PASSAGES = flatten_passages(CORPUS)
len(PASSAGES)
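The mock corpus is already split at page and section boundaries, so each page or section becomes one passage. Real handbook pages are often long enough to dilute retrieval scores. Below is a minimal sketch of a word-window chunker. The helper names `chunk_text` and `chunk_passages`, the window sizes, and the assumption that whitespace-separated words approximate model tokens are illustrative only and are not part of the notebook.

```python
def chunk_text(text, max_words=200, overlap=40):
    """Split text into overlapping windows of at most `max_words` whitespace-separated words.

    Assumption: words are a rough proxy for model tokens.
    """
    if not 0 <= overlap < max_words:
        raise ValueError("need 0 <= overlap < max_words")
    words = text.split()
    step = max_words - overlap
    chunks = []
    for start in range(0, len(words), step):
        chunks.append(" ".join(words[start:start + max_words]))
        if start + max_words >= len(words):  # this window already reached the end
            break
    return chunks


def chunk_passages(passages, max_words=200, overlap=40):
    """Expand each passage dict into chunk dicts that keep all of its metadata.

    `section_or_page` is left unchanged so citations still point at the page or section.
    The added `chunk` field (hypothetical) records the window's position in that passage.
    """
    out = []
    for p in passages:
        for n, chunk in enumerate(chunk_text(p["text"], max_words, overlap)):
            out.append({**p, "text": chunk, "chunk": n})
    return out
```

If adopted, it would run before indexing, e.g. `PASSAGES = chunk_passages(flatten_passages(CORPUS))`. Keeping citations at page/section level means several chunks can map to the same citation, so citations should be de-duplicated downstream.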


## Index — Tokenize, BM25, TF‑IDF (Simple Semantic), and Metadata


In [None]:
WORD = re.compile(r"[A-Za-z0-9_#@]+")
def tokenize(text): return [w.lower() for w in WORD.findall(text or "")]


DOC_TOKENS = [tokenize(p["text"]) for p in PASSAGES]
N_DOCS = len(DOC_TOKENS)

df = Counter()
for toks in DOC_TOKENS:
    for t in set(toks):
        df[t] += 1
idf = {t: math.log((N_DOCS - c + 0.5)/(c + 0.5) + 1) for t, c in df.items()}

avgdl = sum(len(toks) for toks in DOC_TOKENS)/N_DOCS
k1, b = 1.5, 0.75

def bm25(query, k=50):
    q = tokenize(query)
    scores = []
    for i, toks in enumerate(DOC_TOKENS):
        dl = len(toks); tf = Counter(toks); s = 0.0
        for t in q:
            if t in tf and t in idf:
                num = tf[t]*(k1+1); den = tf[t] + k1*(1 - b + b*dl/avgdl)
                s += idf[t]*(num/den)
        scores.append((s,i))
    scores.sort(reverse=True)
    return scores[:k]

def tfidf_vec(tokens):
    tf = Counter(tokens)
    return {t:(1+math.log(1+c))*idf[t] for t,c in tf.items() if t in idf}

def cosine(v1, v2):
    if not v1 or not v2: return 0.0
    num = sum(v1.get(t,0)*v2.get(t,0) for t in set(v1)|set(v2))
    den = math.sqrt(sum(x*x for x in v1.values())) * math.sqrt(sum(x*x for x in v2.values()))
    return num/den if den else 0.0

DOC_TFIDF = [tfidf_vec(toks) for toks in DOC_TOKENS]

def semantic_scores(query, k=50):
    qv = tfidf_vec(tokenize(query))
    scores = [(cosine(qv, dv), i) for i, dv in enumerate(DOC_TFIDF)]
    scores.sort(reverse=True)
    return scores[:k]
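For reference, here are the scores this cell computes, written out. The formulas restate `bm25`, `tfidf_vec`, and `cosine`, where $N$ = `N_DOCS`, $n_t$ = `df[t]`, $f_{t,d}$ = the count of token $t$ in passage $d$, and $|d|$ = passage length in tokens. The IDF puts the $+1$ inside the logarithm, which keeps every weight positive. The TF-IDF vectors reuse the same IDF. The BM25 sum runs over query tokens, so a repeated query term counts once per occurrence.

$$
\mathrm{IDF}(t) = \ln\!\left(\frac{N - n_t + 0.5}{n_t + 0.5} + 1\right)
$$

$$
\mathrm{BM25}(q, d) = \sum_{t \in q} \mathrm{IDF}(t)\,\frac{f_{t,d}\,(k_1 + 1)}{f_{t,d} + k_1\left(1 - b + b\,\dfrac{|d|}{\mathrm{avgdl}}\right)}, \qquad k_1 = 1.5,\; b = 0.75
$$

$$
w_{t,d} = \bigl(1 + \ln(1 + f_{t,d})\bigr)\,\mathrm{IDF}(t), \qquad
\cos(q, d) = \frac{\sum_t w_{t,q}\,w_{t,d}}{\lVert w_q \rVert\,\lVert w_d \rVert}
$$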


## Retrieve — Hybrid → ACL Filter → Rerank → MMR Diversity


In [None]:
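from collections import defaultdict
from datetime import date

def normalize_scores(pairs):
    """Scale (score, idx) pairs into [0, 1] by the max score; all-zero lists stay zero."""
    if not pairs: return []
    smax = max(s for s,_ in pairs) or 1.0
    return [(s/smax, i) for s,i in pairs]

def acl_ok(p, dept, loc):
    """A passage is visible only if both its department and location ACLs admit the user."""
    acl = p["acl"]
    dep_ok = "all" in acl["departments"] or dept in acl["departments"]
    loc_ok = "all" in acl["locations"] or loc in acl["locations"]
    return dep_ok and loc_ok

# Recency is measured against the newest document in the corpus, so offline runs are reproducible.
REF_DATE = max(date.fromisoformat(p["date"]) for p in PASSAGES if p.get("date"))

def recency_weight(p, half_life_days=180, max_boost=0.2):
    """Multiply relevance by up to 1 + max_boost for the newest docs; the boost halves every half_life_days."""
    try:
        age_days = (REF_DATE - date.fromisoformat(p["date"])).days
    except (KeyError, TypeError, ValueError):
        return 1.0
    return 1.0 + max_boost * 0.5 ** (max(age_days, 0) / half_life_days)

def hybrid_retrieve(query, user_department="all", user_location="US",
                    k_cand=50, top_k=8, lambda_sem=0.5, lambda_mmr=0.7):
    # 1) Hybrid: weighted sum of max-normalized BM25 and TF-IDF cosine scores.
    bm = normalize_scores(bm25(query, k=k_cand))
    se = normalize_scores(semantic_scores(query, k=k_cand))
    combined = defaultdict(float)
    for s,i in bm: combined[i] += (1 - lambda_sem)*s
    for s,i in se: combined[i] += lambda_sem*s

    # 2) ACL filter and drop passages with no match at all; 3) rerank with the recency boost.
    rel = {i: sc*recency_weight(PASSAGES[i]) for i, sc in combined.items()
           if sc > 0 and acl_ok(PASSAGES[i], user_department, user_location)}
    cand = sorted(rel, key=rel.get, reverse=True)[:k_cand]

    # 4) MMR: greedily pick the passage with the best trade-off between hybrid relevance
    #    and redundancy (max TF-IDF cosine) with the passages already selected.
    def mmr_score(i, selected):
        redundancy = max((cosine(DOC_TFIDF[i], DOC_TFIDF[j]) for j in selected), default=0.0)
        return lambda_mmr*rel[i] - (1 - lambda_mmr)*redundancy

    selected = []
    while cand and len(selected) < top_k:
        best = max(cand, key=lambda i: mmr_score(i, selected))
        selected.append(best)
        cand.remove(best)
    return selected

def render_sources(indices):
    """Number each passage and label it doc_id:page_or_section (date, URL) so the model can cite it."""
    lines = []
    for r, i in enumerate(indices, 1):
        p = PASSAGES[i]
        meta = (p.get("date") or "undated") + (f", {p['url']}" if p.get("url") else "")
        lines.append(f"{r}) {p['doc_id']}:{p['section_or_page']} ({meta}): {p['text']}")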
    return "\n".join(lines)
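The MMR step in `hybrid_retrieve` is easier to reason about once the corpus globals are stripped away. At each step, it picks the candidate that maximizes λ·relevance - (1 - λ)·(highest similarity to anything already picked). The sketch below restates that loop standalone. `mmr_select` and the toy similarity function are hypothetical illustrations, not notebook code.

```python
def mmr_select(relevance, similarity, k, lam=0.7):
    """Greedy Maximal Marginal Relevance.

    relevance:  dict mapping candidate id -> relevance to the query
    similarity: function (id_a, id_b) -> similarity between two candidates
    Returns up to k ids in selection order.
    """
    candidates = list(relevance)
    selected = []
    while candidates and len(selected) < k:
        def mmr(i):
            # Penalize a candidate by its closest match among already-selected items.
            redundancy = max((similarity(i, j) for j in selected), default=0.0)
            return lam * relevance[i] - (1 - lam) * redundancy
        best = max(candidates, key=mmr)
        selected.append(best)
        candidates.remove(best)
    return selected


# Toy example: "a_dup" is a near-copy of "a"; "b" is less relevant but new.
rel = {"a": 1.0, "a_dup": 0.95, "b": 0.6}
sim = lambda x, y: 1.0 if {x, y} == {"a", "a_dup"} else 0.0
print(mmr_select(rel, sim, k=2))           # ['a', 'b']
print(mmr_select(rel, sim, k=2, lam=1.0))  # ['a', 'a_dup']
```

With λ = 0.7, the duplicate's score drops from 0.665 to 0.365 once "a" is selected, so the less relevant but novel "b" (0.42) wins the second slot. At λ = 1.0, MMR degenerates to plain relevance ranking.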


## Contract — Answer JSON (with `None` fallbacks)


In [None]:
ANSWER_SCHEMA = {
  "type": "object",
  "properties": {
    "answer": {"type":"string"},
    "citations": {"type":"array","items":{"type":"object","properties":{
        "doc_id":{"type":"string"}, "page":{"type":"string"}, "section":{"type":"string"}, "url":{"type":"string"}
    }}},
    "confidence": {"type":"string","enum":["low","medium","high"]},
    "rationale_internal": {"type":"string"}
  },
  "required": ["answer","citations","confidence"],
  "additionalProperties": False
}

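def validate_with_fallbacks(obj, schema=ANSWER_SCHEMA):
    """Check a model response against the flat ANSWER_SCHEMA (not a full JSON Schema validator).

    None values for string/array fields become "" / [] before checking.
    Returns (errors, normalized_copy)."""
    errors = []
    if not isinstance(obj, dict):
        return [f"Response must be a JSON object, got {type(obj).__name__}"], {}
    props = schema.get("properties", {})
    required = schema.get("required", [])
    norm = {}
    for k, v in obj.items():
        spec = props.get(k) or {}
        if v is None and spec.get("type") == "string": norm[k] = ""
        elif v is None and spec.get("type") == "array": norm[k] = []
        else: norm[k] = v
    for key in required:
        if key not in norm: errors.append(f"Missing required field: {key}")
    if schema.get("additionalProperties") is False:
        for k in norm:
            if k not in props: errors.append(f"Unexpected field: {k}")
    for k, spec in props.items():
        if k not in norm: continue
        val = norm[k]
        if spec.get("type") == "string":
            if not isinstance(val, str): errors.append(f"Field {k} must be string")
            elif "enum" in spec and val and val not in spec["enum"]:
                errors.append(f"Field {k} not in enum {spec['enum']} (got: {val})")
        elif spec.get("type") == "array":
            if not isinstance(val, list):
                errors.append(f"Field {k} must be array"); continue
            item_props = spec.get("items", {}).get("properties", {})
            cleaned = []
            for n, item in enumerate(val):
                if not isinstance(item, dict):
                    errors.append(f"Field {k}[{n}] must be object"); continue
                # Apply the same None -> "" fallback to nested string fields (e.g. citation url).
                item = {ik: ("" if iv is None and item_props.get(ik, {}).get("type") == "string" else iv)
                        for ik, iv in item.items()}
                for ik, iv in item.items():
                    if item_props.get(ik, {}).get("type") == "string" and not isinstance(iv, str):
                        errors.append(f"Field {k}[{n}].{ik} must be string")
                cleaned.append(item)
            norm[k] = cleaned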
    return errors, norm
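The contract checks field types but not whether a citation points at a passage the model was actually shown. A stricter check is sketched below as a hypothetical helper, `check_citations`. It compares each citation against the set of `(doc_id, page_or_section)` pairs that went into the prompt.

```python
def check_citations(citations, shown):
    """Return a list of problems with model citations (empty list = all good).

    shown: set of (doc_id, page_or_section) pairs that were actually in the prompt.
    A citation passes if its doc_id was shown and, when it names a page or section,
    that exact (doc_id, page_or_section) pair was shown too.
    """
    shown_docs = {doc for doc, _ in shown}
    problems = []
    for n, c in enumerate(citations):
        if not isinstance(c, dict):
            problems.append(f"citation {n}: not an object")
            continue
        doc = c.get("doc_id") or ""
        loc = c.get("page") or c.get("section") or ""
        if doc not in shown_docs:
            problems.append(f"citation {n}: {doc!r} was not among the sources")
        elif loc and (doc, loc) not in shown:
            problems.append(f"citation {n}: {doc!r} has no page/section {loc!r} in the sources")
    return problems


shown = {("HR-Handbook-2025.pdf", "12"), ("Announcement-Email-2025-08-05", "body")}
print(check_citations([{"doc_id": "HR-Handbook-2025.pdf", "page": "12"}], shown))  # []
```

In the notebook, `shown` would come from the retrieved indices, e.g. `{(PASSAGES[i]["doc_id"], PASSAGES[i]["section_or_page"]) for i in idxs}`.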


## Prompt Construction — Strict, Source‑Bound, with Abstention


In [None]:
ABSTAIN_TEXT = "Insufficient information in approved sources."

def build_prompt(question, sources_text, schema):
    contract = json.dumps(schema, ensure_ascii=False, indent=2)
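    return f"""SYSTEM: You answer using only the provided sources. Cite each claim. If the sources do not contain the answer, set "answer" to exactly "{ABSTAIN_TEXT}".

USER: {question}

SOURCES:
{sources_text}

RESPONSE FORMAT: reply with one JSON object (no prose, no code fences) that conforms to the JSON Schema below. Return an instance, not the schema itself.
{contract}

Rules:
- Use only claims that appear in SOURCES.
- Cite each claim with its doc_id, plus the page (PDF sources) or section (all other sources) listed with that source.
- If the information is insufficient, set "answer" to "{ABSTAIN_TEXT}", return an empty "citations" list, and set "confidence" to "low".
- Keep "rationale_internal" to one or two sentences.
- Tone: concise, neutral, actionable.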
"""
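A frequent failure with schema-in-prompt contracts is that the model echoes the schema instead of filling it in. One possible mitigation, which has not been tested in this notebook, is to append a filled-in example. The values in `EXAMPLE_RESPONSE` are hand-written from `HR-Handbook-2025.pdf` p.12 in the mock corpus and are not real model output. `EXAMPLE_RESPONSE` and `ONE_SHOT_EXAMPLE` are hypothetical names.

```python
import json

# Illustrative instance of ANSWER_SCHEMA for the parental-leave question.
# Hand-written from the mock corpus (handbook p.12), NOT real model output.
EXAMPLE_RESPONSE = {
    "answer": "Primary caregivers get 16 weeks of paid parental leave; eligibility requires 6 months of tenure.",
    "citations": [{"doc_id": "HR-Handbook-2025.pdf", "page": "12", "section": "", "url": ""}],
    "confidence": "high",
    "rationale_internal": "Stated directly on handbook p.12.",
}

# Serialized once so it can be appended to the prompt as a one-shot example.
ONE_SHOT_EXAMPLE = "EXAMPLE RESPONSE (for a different question):\n" + json.dumps(EXAMPLE_RESPONSE, indent=2)
```

Usage would be along the lines of `build_prompt(q, sources_text, ANSWER_SCHEMA) + "\n" + ONE_SHOT_EXAMPLE`. A one-shot example can also anchor the model toward the example's wording or confidence level, so its effect should be checked with the offline evaluation below.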

## Generate — Call Gemini 2.5 Flash Lite (JSON only)

In [None]:
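# Low temperature keeps answers stable and close to the sources.
GENERATION_CONFIG = {"temperature": 0.1}

def _response_text(resp):
    """Return the response text, or None if the response was blocked or empty."""
    try:
        return resp.text  # raises ValueError when the candidate has no text parts
    except (ValueError, AttributeError):
        pass
    for cand in getattr(resp, "candidates", None) or []:
        parts = getattr(getattr(cand, "content", None), "parts", None) or []
        text = "".join(getattr(p, "text", "") or "" for p in parts)
        if text: return text
    return None

def call_gemini_json(prompt):
    if not _GEMINI_READY:
        raise RuntimeError("GOOGLE_API_KEY not configured or SDK not available.")
    model = genai.GenerativeModel(MODEL_NAME, generation_config=GENERATION_CONFIG)
    t0 = time.perf_counter()
    resp = model.generate_content(prompt)
    latency = time.perf_counter() - t0
    text = _response_text(resp)
    if not text: raise RuntimeError("Empty or blocked response from model")
    # Tolerate markdown code fences or stray prose: keep the span from the first "{" to the last "}".
    s = text.strip()
    start, end = s.find("{"), s.rfind("}")
    if start >= 0 and end > start: s = s[start:end+1]
    try:
        obj = json.loads(s)
    except json.JSONDecodeError as e:
        raise RuntimeError(f"Model returned invalid JSON ({e}): {text[:200]!r}") from e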
    return obj, latency
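`call_gemini_json` keeps the span between the first `{` and the last `}`. That span breaks if trailing prose also contains a brace. An alternative is sketched below as a hypothetical helper, `extract_json_object`. It uses the standard library's `json.JSONDecoder.raw_decode` to parse the first complete object and ignore whatever follows it.

```python
import json

def extract_json_object(text):
    """Return the first complete JSON object embedded in `text`.

    Tries json.JSONDecoder.raw_decode at each '{' in turn, so code fences,
    leading prose, and trailing text (even text containing braces) are ignored.
    Raises ValueError if no JSON object is found.
    """
    decoder = json.JSONDecoder()
    pos = text.find("{")
    while pos != -1:
        try:
            obj, _end = decoder.raw_decode(text, pos)
            return obj  # starts at "{", so a successful parse is always a dict
        except json.JSONDecodeError:
            pos = text.find("{", pos + 1)  # not valid JSON here; try the next brace
    raise ValueError("no JSON object found in model output")


print(extract_json_object('Sure: {"a": 1} (see {note})'))  # {'a': 1}
```

Because it is a pure function, the parsing step can be unit-tested without API calls.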


## Orchestrator — Retrieve → Prompt → Generate → Validate → Post‑process


In [None]:
AUDIT = []

def within_acl(p, dept, loc):
    acl = p["acl"]
    return ("all" in acl["departments"] or dept in acl["departments"]) and ("all" in acl["locations"] or loc in acl["locations"])

def generate_answer(question, user_department="all", user_location="US", top_k=6):
    idxs = hybrid_retrieve(question, user_department=user_department, user_location=user_location, top_k=top_k)
    idxs = [i for i in idxs if within_acl(PASSAGES[i], user_department, user_location)]
    sources_text = render_sources(idxs)
    prompt = build_prompt(question, sources_text, ANSWER_SCHEMA)
    raw, latency = call_gemini_json(prompt)
    errors, norm = validate_with_fallbacks(raw)

    if (norm.get("answer") or "").strip() == ABSTAIN_TEXT:
        norm["citations"] = []
        norm["confidence"] = "low"
    if not norm.get("citations"):
        norm["citations"] = []
        for i in idxs[:2]:
            p = PASSAGES[i]
            norm["citations"].append({
                "doc_id": p["doc_id"],
                "page": p["section_or_page"] if p["source_type"]=="pdf" else "",
                "section": p["section_or_page"] if p["source_type"]!="pdf" else "",
                "url": p.get("url") or ""
            })

    errors2, norm2 = validate_with_fallbacks(norm)
    AUDIT.append({
        "ts": datetime.utcnow().isoformat()+"Z",
        "q": question, "dept": user_department, "loc": user_location,
        "retrieved": [f"{PASSAGES[i]['doc_id']}:{PASSAGES[i]['section_or_page']}" for i in idxs],
        "latency_ms": int(latency*1000), "errors": errors+errors2
    })
    return norm2, idxs, latency, errors+errors2
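`AUDIT` lives in memory and disappears when the runtime restarts. One way to keep the trail for later review is an append-only JSON Lines file, which assumes a local file is acceptable audit storage in your environment. `append_audit_jsonl` and `read_audit_jsonl` are hypothetical helpers, sketched here with the standard library only.

```python
import json
from pathlib import Path

def append_audit_jsonl(path, record):
    """Append one audit record as a single JSON line (append-only, one record per line)."""
    with Path(path).open("a", encoding="utf-8") as f:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")

def read_audit_jsonl(path):
    """Load all audit records back, skipping blank lines."""
    with Path(path).open(encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
```

For example, `for row in AUDIT: append_audit_jsonl("rag_audit.jsonl", row)` after an evaluation run. The file can then be filtered for rows with non-empty `errors` to find retrieval misses worth feeding back into data or prompts.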


## Offline Evaluation — Golden Set & RAG Metrics


In [None]:
EVAL_SET = [
    {"id":"e1","q":"How many weeks of paid parental leave do primary caregivers get?","dept":"all","loc":"US",
     "expect":{"must_contain":["16 weeks"],"should_cite":["HR-Handbook-2025.pdf","Benefits-Site"]}},
    {"id":"e2","q":"Do UK remote employees get extra parental leave?","dept":"all","loc":"UK",
     "expect":{"must_contain":["additional 2 weeks"],"should_cite":["Announcement-Email-2025-08-05"]}},
    {"id":"e3","q":"What are the per diem rules for Canada engineering travel?","dept":"Engineering","loc":"CA",
     "expect":{"must_contain":["per diem"],"should_cite":["Policy-Wiki-Travel"]}},
    {"id":"e4","q":"What is the 401(k) employer match policy?","dept":"all","loc":"US",
     "expect":{"abstain": True}}
]

def grounded(pred, idxs):
    """Abstentions count as grounded; otherwise every cited doc_id must be among the retrieved passages."""
    if (pred.get("answer") or "") == ABSTAIN_TEXT: return True
    cited_ids = {c.get("doc_id") for c in (pred.get("citations") or []) if isinstance(c, dict)}
    retrieved_ids = {PASSAGES[i]["doc_id"] for i in idxs}
    return bool(cited_ids) and cited_ids.issubset(retrieved_ids)

def contains_required_text(pred, must_list):
    ans = (pred.get("answer") or "").lower()
    return all(m.lower() in ans for m in must_list)

def abstention_correct(pred, expect):
    wants = expect.get("abstain", False)
    is_abstain = (pred.get("answer") or "") == ABSTAIN_TEXT
    return wants == is_abstain

def cites_expected(pred, expect):
    """True if any cited doc_id is listed in `should_cite` (vacuously true when nothing is listed)."""
    wanted = set(expect.get("should_cite", []))
    if not wanted: return True
    cited = {c.get("doc_id") for c in (pred.get("citations") or []) if isinstance(c, dict)}
    return bool(cited & wanted)

def p95(values):
    """Nearest-rank 95th percentile: the ceil(0.95*n)-th smallest value."""
    ordered = sorted(values)
    return ordered[max(math.ceil(0.95*len(ordered)) - 1, 0)]

def evaluate(dataset):
    rows, latencies = [], []
    for ex in dataset:
        pred, idxs, latency, errs = generate_answer(ex["q"], ex["dept"], ex["loc"])
        lat_ms = round(latency*1000.0, 1); latencies.append(lat_ms)
        rows.append({
            "id": ex["id"],
            "latency_ms": lat_ms,
            "grounded": grounded(pred, idxs),
            "abstention_ok": abstention_correct(pred, ex["expect"]),
            "must_contain_ok": contains_required_text(pred, ex["expect"].get("must_contain", [])),
            "cites_expected_ok": cites_expected(pred, ex["expect"]),
            "errors": errs,
            "pred": pred,
            "retrieved": [f"{PASSAGES[i]['doc_id']}:{PASSAGES[i]['section_or_page']}" for i in idxs]
        })
    if not rows: return rows, {}
    def rate(key): return sum(1 for r in rows if r[key]) / len(rows)
    agg = {
        "grounded_rate": rate("grounded"),
        "abstention_correct_rate": rate("abstention_ok"),
        "must_contain_rate": rate("must_contain_ok"),
        "cites_expected_rate": rate("cites_expected_ok"),
        "latency_p95_ms": p95(latencies),
    }
    return rows, agg

if _GEMINI_READY:
    preview_pred, preview_idxs, preview_latency, preview_errs = generate_answer("How long is parental leave for primary caregivers?", "all", "US")
    print("Preview citations:", preview_pred.get("citations"))
    print("Preview answer:", (preview_pred.get("answer") or "")[:120], "...")
else:
    print("Set GOOGLE_API_KEY to run generation.")
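The `latency_p95_ms` aggregate depends on the percentile convention. A standalone nearest-rank helper makes the convention explicit: it returns the ceil(q·n)-th smallest value. The name `percentile_nearest_rank` is hypothetical. With the four-question golden set, ceil(0.95 × 4) = 4, so the p95 is simply the slowest call. Treat it as a smoke test rather than a latency SLO.

```python
import math

def percentile_nearest_rank(values, q):
    """Nearest-rank percentile: the smallest value with at least q*100% of the data at or below it."""
    if not values:
        raise ValueError("values must be non-empty")
    if not 0 < q <= 1:
        raise ValueError("q must be in (0, 1]")
    ordered = sorted(values)
    return ordered[math.ceil(q * len(ordered)) - 1]


print(percentile_nearest_rank([400, 100, 300, 200], 0.95))  # 400
print(percentile_nearest_rank([400, 100, 300, 200], 0.5))   # 200
```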

## Run Offline Evaluation

In [None]:
if _GEMINI_READY:
    rows, agg = evaluate(EVAL_SET)
    print("Aggregate metrics:")
    print(json.dumps(agg, indent=2))
    print("\nFirst result:")
    print(json.dumps(rows[0], indent=2))
else:
    print("Set GOOGLE_API_KEY to evaluate.")


## Retrieval Strategy Cheatsheet
- **Semantic (vectors)** → open‑ended questions on long text; rerank for precision  
- **Keyword (BM25)** → short queries, exact terms/codes  
- **Hybrid** → robust default (BM25 + semantic)  
- **Reranking (cross‑encoder)** → use on small candidate set (e.g., top‑50)  
- **MMR/diversity** → avoid duplicates; improve coverage
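`hybrid_retrieve` fuses BM25 and TF-IDF with a weighted sum of max-normalized scores. A common alternative, shown here only as an illustration, is reciprocal rank fusion (RRF). RRF combines ranks instead of raw scores, so it needs no score normalization. `reciprocal_rank_fusion` is a hypothetical helper, and k = 60 is a commonly used default, not a value tuned for this corpus.

```python
def reciprocal_rank_fusion(rankings, k=60):
    """Fuse ranked lists of ids with RRF: score(d) = sum over lists of 1 / (k + rank of d).

    rankings: iterable of lists, each ordered best-first (rank 1 = first element).
    Ids missing from a list simply get no contribution from it.
    Returns ids sorted by fused score, best first (ties keep first-appearance order).
    """
    scores = {}
    for ranking in rankings:
        for rank, doc in enumerate(ranking, start=1):
            scores[doc] = scores.get(doc, 0.0) + 1.0 / (k + rank)
    return sorted(scores, key=scores.get, reverse=True)


# Passage 1 is 2nd for BM25 and 1st for semantic; passage 3 is 1st and 3rd.
print(reciprocal_rank_fusion([[3, 1, 2], [1, 4, 3]]))  # [1, 3, 4, 2]
```

In this notebook, the inputs would be the index lists from `bm25` and `semantic_scores` after dropping zero scores, e.g. `[i for s, i in bm25(q) if s > 0]`. Otherwise, passages with no match still collect rank-based credit.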



## Next Steps
- Replace TF‑IDF with true **embeddings**; add cross‑encoder reranker  
- Enrich corpus with **fresh** announcements; prefer recency in reranks  
- Add **JSON mode** or function calling; stricter citation validators  
- Hook into real **ACL**/identity and doc stores; log misses to refine chunks/prompts


## Audit Log (last 10)

In [None]:
for row in AUDIT[-10:]:
    print(json.dumps(row, indent=2))