# Enhanced: Prompt Engineering + RAG v2 (Gemma-2B)

This notebook preserves your original `rag_vs_promt_engineering-LATEST.ipynb` content unchanged below and adds improved helper functions and an embeddings-based retriever at the top.

- Deterministic decoding with proper token slicing
- Unified refusal policy and prompt template
- Embeddings retriever (all-MiniLM-L6-v2) + FAISS with chunking
- RAG v2 wrapper that conditions strictly on retrieved context


In [None]:
# Setup
!pip install -q transformers accelerate datasets sentence-transformers faiss-cpu scikit-learn huggingface_hub

from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login, whoami
import os, torch

# Hugging Face auth: read the token from the HF_TOKEN environment variable.
# Do not paste tokens into the notebook itself -- they end up in saved copies
# and version control. `huggingface-cli login` is an alternative.
HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
        _u = whoami()
        # Print only the account name (never the email) so outputs carry no PII.
        print("HF login OK:", _u.get("name") or "authenticated")
    except Exception as _e:
        # Best effort: a failed login should not stop the notebook; report why.
        print("HF login failed:", _e)
else:
    print("No HF token provided. Set the HF_TOKEN env var (or run `huggingface-cli login`).")

def load_model_tokenizer(model_id: str = "google/gemma-2b-it", token: str | None = HF_TOKEN):
    """Load a causal language model and its tokenizer from the Hugging Face Hub.

    Args:
        model_id: Hub repository id of the model.
        token: Hub access token. ``None`` or an empty string both mean
            "no explicit token", so huggingface_hub may fall back to cached
            credentials (e.g. from ``huggingface-cli login``).

    Returns:
        ``(model, tokenizer)`` -- note the model comes first.
    """
    # Treat "" the same as "not provided" instead of forwarding an empty token.
    token = token or None
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=token,
        # Half precision only when a GPU is present; full precision on CPU.
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto",
    )
    return model, tokenizer

MODEL_ID = "google/gemma-2b-it"
model, tokenizer = load_model_tokenizer(MODEL_ID, token=HF_TOKEN)

# Gemma instruction-tuned chat format closes each model turn with <end_of_turn>.
# Stop on it as well as on <eos> so greedy decoding cannot continue into a
# fabricated next turn. Added only if the tokenizer actually knows the token
# (unknown tokens map to the unk id).
_stop_ids = [tokenizer.eos_token_id]
_eot_id = tokenizer.convert_tokens_to_ids("<end_of_turn>")
if _eot_id is not None and _eot_id != tokenizer.unk_token_id and _eot_id not in _stop_ids:
    _stop_ids.append(_eot_id)

# Shared generation settings for every call to generate_answer_v2.
GEN_KW = dict(
    max_new_tokens=512,
    do_sample=False,  # greedy decoding -> deterministic answers
    eos_token_id=_stop_ids,
    pad_token_id=tokenizer.eos_token_id,
)

def build_prompt_v2(user_question: str, context: str = "") -> str:
    """Wrap a question (and optional retrieved context) in the tokenizer's chat template.

    The strict answering/refusal rules are always prepended. A delimited context
    section is included only when ``context`` has non-whitespace text. The result
    is a single user turn rendered as text (not token ids), ending with the
    generation prompt for the model's reply.
    """
    rule_lines = (
        "You are a careful assistant. Follow STRICTLY:",
        "1) Only answer using the provided context (if any).",
        "2) If the answer is not in the context or you are uncertain, reply EXACTLY with: Sorry I do not have that information",
        "3) Do not add any explanation, punctuation, or extra words when refusing.",
        "4) When you do know the answer, explain it in detail with at least 3 sentences and examples if possible.",
        "5) Do not rephrase the refusal.",
    )
    instructions = "".join(f"{line}\n" for line in rule_lines)

    if not context.strip():
        body = f"\nQuestion: {user_question}"
    else:
        body = f"\nUse ONLY this context to answer:\n---\n{context}\n---\n\nQuestion: {user_question}"

    chat = [{"role": "user", "content": instructions + body}]
    return tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

@torch.no_grad()
def generate_answer_v2(question_text: str, context: str = "") -> str:
    prompt = build_prompt_v2(question_text, context)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    input_len = inputs.input_ids.shape[-1]
    outputs = model.generate(**inputs, **GEN_KW)
    gen_ids = outputs[0][input_len:]
    answer = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
    if answer.lower().startswith("sorry i do not have that information"):
        return "Sorry I do not have that information"
    return answer


In [None]:
# Embedding retriever with FAISS
from sentence_transformers import SentenceTransformer
import faiss, numpy as np, math, json
from typing import List, Dict, Any

# Sentence-embedding model shared by corpus indexing and query encoding.
EMB_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
emb_model = SentenceTransformer(model_name_or_path=EMB_MODEL_NAME)

class DocumentChunk:
    """One retrievable slice of a knowledge-base record.

    Attributes:
        doc_id: id of the source record.
        chunk_id: position of this chunk within its record.
        text: the chunk's text.
        meta: metadata dict for the chunk (e.g. title, id, version, urls).
    """

    def __init__(self, doc_id: str, chunk_id: int, text: str, meta: Dict[str, Any]):
        self.doc_id, self.chunk_id, self.text, self.meta = doc_id, chunk_id, text, meta


def chunk_text(text: str, max_tokens: int = 180, overlap: int = 30) -> List[str]:
    """Split ``text`` into overlapping windows of whitespace-separated words.

    "Tokens" are approximated by words from ``str.split()``, not model tokens.

    Args:
        text: input text; any run of whitespace separates words.
        max_tokens: maximum number of words per chunk (must be > 0).
        overlap: number of words shared by consecutive chunks
            (must satisfy ``0 <= overlap < max_tokens``).

    Returns:
        Chunk strings (words re-joined with single spaces). Empty list for blank
        input. The final chunk may be shorter than ``max_tokens``.

    Raises:
        ValueError: if the parameters would give a non-positive stride or a
            negative overlap (which would silently return nothing or skip words).
    """
    if max_tokens <= 0:
        raise ValueError(f"max_tokens must be positive, got {max_tokens}")
    if not 0 <= overlap < max_tokens:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < max_tokens, got overlap={overlap}, max_tokens={max_tokens}"
        )
    words = text.split()
    step = max_tokens - overlap
    chunks: List[str] = []
    for start in range(0, len(words), step):
        chunks.append(" ".join(words[start : start + max_tokens]))
        # Once a window reaches the end, the next one would only repeat words
        # already covered by the overlap.
        if start + max_tokens >= len(words):
            break
    return chunks


def build_corpus(
    kb_records: List[Dict[str, Any]], max_tokens: int = 180, overlap: int = 30
) -> List[DocumentChunk]:
    """Turn knowledge-base records into retrievable DocumentChunk objects.

    Each record's title and content are joined with a newline and split with
    ``chunk_text``. Every chunk gets its own metadata dict.

    Args:
        kb_records: parsed JSONL records. Keys read: ``title``, ``content``,
            ``id``, ``version`` and ``answer_card.sources[*].url``; all optional
            and may be JSON null.
        max_tokens: words per chunk, forwarded to ``chunk_text`` (default 180).
        overlap: words shared between neighbouring chunks (default 30).

    Returns:
        Flat list of chunks in record order.
    """
    corpus: List[DocumentChunk] = []
    for rec in kb_records:
        # `or ""` / `or {}` also cover explicit JSON nulls, which .get(key, default) does not.
        title = rec.get("title") or ""
        content = (title + "\n" + (rec.get("content") or "")).strip()
        answer_card = rec.get("answer_card") or {}
        # Keep only real URLs: a None entry would break ", ".join(urls) downstream.
        urls = [s.get("url") for s in (answer_card.get("sources") or []) if s.get("url")]
        doc_id = str(rec.get("id", ""))
        for idx, ch in enumerate(chunk_text(content, max_tokens=max_tokens, overlap=overlap)):
            meta = {
                "title": title,
                "id": rec.get("id", ""),
                "version": rec.get("version", ""),
                "urls": list(urls),  # fresh list per chunk, so chunks never alias
            }
            corpus.append(DocumentChunk(doc_id, idx, ch, meta))
    return corpus


def build_faiss_index(chunks: List[DocumentChunk]):
    """Embed every chunk with ``emb_model`` and store the vectors in a flat inner-product index.

    Vectors are L2-normalised before insertion, so inner product equals cosine
    similarity. Queries must be normalised the same way before searching.

    Returns:
        ``(index, embeddings)``: row *i* of ``embeddings`` (normalised, float32)
        corresponds to ``chunks[i]`` and to FAISS id *i*.

    Raises:
        ValueError: if ``chunks`` is empty (nothing to index, and the embedding
            dimension cannot be inferred).
    """
    if not chunks:
        raise ValueError("build_faiss_index() needs at least one chunk; is the KB empty?")
    texts = [c.text for c in chunks]
    embeddings = emb_model.encode(texts, convert_to_numpy=True, show_progress_bar=False)
    # FAISS expects C-contiguous float32 arrays.
    embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
    faiss.normalize_L2(embeddings)  # normalises in place
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)
    return index, embeddings


def retrieve_top_k(query: str, chunks: List[DocumentChunk], index, top_k: int = 3):
    """Return up to ``top_k`` ``(score, chunk)`` pairs for ``query`` in FAISS result order.

    ``index`` must contain the normalised embeddings of ``chunks`` in the same
    order, so FAISS id *i* refers to ``chunks[i]``.

    FAISS pads the result with id ``-1`` when fewer than ``top_k`` vectors are
    available. Those placeholders are dropped here: indexing ``chunks[-1]`` would
    otherwise silently return the last chunk with a meaningless score.
    """
    if top_k <= 0:
        return []
    q_emb = np.ascontiguousarray(emb_model.encode([query], convert_to_numpy=True), dtype=np.float32)
    faiss.normalize_L2(q_emb)  # match the normalisation used at indexing time
    scores, idxs = index.search(q_emb, top_k)
    return [(float(score), chunks[int(idx)]) for score, idx in zip(scores[0], idxs[0]) if idx >= 0]


def format_context_from_results(results) -> str:
    """Render retrieval results as a Markdown-style context string for the prompt.

    Args:
        results: iterable of ``(score, chunk)`` pairs; each chunk needs ``.text``
            and a ``.meta`` dict (``title`` and ``urls`` keys are optional).

    Returns:
        One block per result, joined by blank lines. Each block has a
        ``### <title> (score=<2 decimals>)`` header, the chunk text, and a
        ``Sources:`` line when the chunk has URLs. Empty string for no results.
    """
    blocks = []
    for score, ch in results:
        # Drop missing/empty URLs: a None entry would make ", ".join() raise TypeError.
        srcs = [u for u in (ch.meta.get("urls") or []) if u]
        title = ch.meta.get("title") or ""
        block = f"### {title} (score={score:.2f})\n{ch.text}"
        if srcs:
            block += "\nSources: " + ", ".join(srcs)
        blocks.append(block)
    return "\n\n".join(blocks)

# Optional: quick load if the KB file is available; otherwise keep the helpers
# ready and build the index later (Section 2).
KB_PATH = "/content/python_release_kb.jsonl"
try:
    with open(KB_PATH, "r", encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing newline) that json.loads would reject.
        kb_records = [json.loads(line) for line in f if line.strip()]
    corpus_chunks = build_corpus(kb_records)
    # Keep only the index; avoid binding `_`, which would shadow IPython's last-output variable.
    faiss_index = build_faiss_index(corpus_chunks)[0]
except FileNotFoundError:
    kb_records, corpus_chunks, faiss_index = None, None, None
    print(f"KB not found at {KB_PATH}; build the index in Section 2.")
except Exception as e:
    # Still best-effort (the notebook keeps running), but no longer silent.
    kb_records, corpus_chunks, faiss_index = None, None, None
    print(f"KB load from {KB_PATH} failed: {type(e).__name__}: {e}")


In [None]:
# RAG v2 answer functions
from IPython.display import display, Markdown


def build_context_v2(query: str, top_k: int = 3) -> str:
    """Retrieve the ``top_k`` best chunks for ``query`` and format them as prompt context.

    Relies on the module-level ``corpus_chunks`` and ``faiss_index``; if either
    is still None, an AssertionError tells the user to build the index first.
    """
    assert not (corpus_chunks is None or faiss_index is None), "Load KB and build index first."
    hits = retrieve_top_k(query, corpus_chunks, faiss_index, top_k=top_k)
    context_text = format_context_from_results(hits)
    return context_text


def rag_answer_v2(query: str, top_k: int = 3, show: bool = True) -> str:
    """Answer ``query`` strictly from retrieved knowledge-base context.

    Args:
        query: user question (also used as the retrieval query).
        top_k: number of chunks to retrieve.
        show: render the answer as Markdown in the notebook. Defaults to True,
            as before; pass False for batch evaluation.

    Returns:
        The model's answer, or exactly "Sorry I do not have that information".
    """
    context = build_context_v2(query, top_k=top_k)
    if context.strip():
        answer = generate_answer_v2(query, context=context)
    else:
        # With empty context, generate_answer_v2 would switch to the no-context
        # prompt and answer from the model's own knowledge. RAG v2 must stay
        # grounded, so refuse instead.
        answer = "Sorry I do not have that information"
    if show:
        display(Markdown(answer))
    return answer

# Smoke test cells (commented):
# q = "What PEP replaced PEP 722 for inline script metadata?"
# rag_answer_v2(q, top_k=3)


---

## Section 1: Prompt Engineering (no RAG)

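Run the setup cells above (model, prompt helpers, and retriever), then run the demo cell below to ask questions without any retrieved context. The generic question should be answered; the domain-specific one should trigger the refusal.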


In [None]:
# Prompt-only demo: no retrieved context is passed, so the model must rely on
# its own knowledge and on the refusal rule in the prompt.
prompt_only_demo_questions = [
    # 1) Generic question (should answer)
    "Explain list comprehensions in Python with a short example.",
    # 2) Domain-specific fact (should refuse without context)
    "What PEP replaced PEP 722 for inline script metadata?",
]
for demo_question in prompt_only_demo_questions:
    print(generate_answer_v2(demo_question))


---

## Section 2: Retrieval-Augmented Generation (RAG)

Point `KB_PATH` at your knowledge-base JSONL file, build the retriever index, and then ask the same domain-specific question again, this time with retrieved context.


In [None]:
# RAG setup: point to your KB and build index
import json

KB_PATH = "data/processed/python_release_kb.jsonl"  # adjust if needed
with open(KB_PATH, "r", encoding="utf-8") as f:
    # Skip blank lines (e.g. a trailing newline) that json.loads would reject.
    kb_records = [json.loads(line) for line in f if line.strip()]

corpus_chunks = build_corpus(kb_records)
assert corpus_chunks, f"No chunks built from {KB_PATH}; is the KB file empty?"
# Keep only the index; avoid binding `_`, which would shadow IPython's last-output variable.
faiss_index = build_faiss_index(corpus_chunks)[0]
len(corpus_chunks)


In [None]:
# RAG demo: ask the same question
# rag_answer_v2 already renders the answer as Markdown; the trailing `;`
# suppresses the duplicate plain-string echo of its return value.
rag_answer_v2("What PEP replaced PEP 722 for inline script metadata?", top_k=3);
