
# RAG Chatbot (Colab, Private) — PDF → Preprocess → Index → Ask
**Purpose:** Quickly experience a Retrieval‑Augmented Generation (RAG) flow entirely inside a Colab/Jupyter notebook.  
**Stack:** `ipywidgets` UI, `pypdf` for extraction, `scikit-learn` for vector search, Google **Gemini** API for embeddings & generation.

> This notebook reads your Gemini API key from a Colab secret named `GOOGLE_API_KEY` (Colab sidebar → Secrets).


## Imports & Installs

In [None]:
!pip -q install pypdf ipywidgets reportlab

import os, io, requests, textwrap, tempfile
from pathlib import Path
from dataclasses import dataclass
from typing import List, Dict, Any, Tuple

from pypdf import PdfReader
from reportlab.pdfgen import canvas
from reportlab.lib.pagesizes import letter

import numpy as np
from sklearn.neighbors import NearestNeighbors

import ipywidgets as widgets
from IPython.display import display, Markdown, clear_output

## Gemini API Setup

In [None]:
# Resolve the Gemini API key. Colab secrets are tried first; outside Colab (or
# when the secret is missing / notebook access was not granted) fall back to
# the GOOGLE_API_KEY environment variable so the notebook also runs in Jupyter.
_api_key = None
try:
    from google.colab import userdata
    _api_key = userdata.get('GOOGLE_API_KEY')
except Exception:
    # Deliberately broad: the import fails outside Colab, and the secret lookup
    # can fail for Colab-specific reasons. Either way, try the environment next.
    _api_key = None
_api_key = _api_key or os.environ.get('GOOGLE_API_KEY')

# Configure the SDK only when both the package and a key are available.
# _GEMINI_READY records whether live API calls can be expected to work.
_GEMINI_READY = False
genai = None  # stays None when the SDK cannot be imported
try:
    import google.generativeai as genai
    if _api_key:
        genai.configure(api_key=_api_key)
        _GEMINI_READY = True
    else:
        print("GOOGLE_API_KEY not found; add it as a Colab secret or environment variable.")
except ImportError:
    print("Install google-generativeai to enable live calls.")
except Exception as exc:
    # Best-effort setup: report the real cause instead of blaming the install.
    print(f"Gemini setup failed: {exc}")

GENERATION_MODEL = "gemini-2.5-flash"
EMBED_MODEL      = "gemini-embedding-001"

## Default PDF Download & Sample Generation

In [None]:
# Download locations for the default sample document; ensure_pdf_bytes tries
# them in list order after any user-supplied URL.
DEFAULT_POLICY_PDF_URLS = [
    # FERPA (Family Educational Rights and Privacy Act) documents used as default samples
    "https://studentprivacy.ed.gov/sites/default/files/resource_document/file/An%20Eligible%20Student%20Guide%20to%20FERPA_0.pdf",
    "https://studentprivacy.ed.gov/sites/default/files/resource_document/file/A%20parent%20guide%20to%20ferpa_508.pdf"
]

def download_pdf(url: str) -> bytes:
    """Fetch ``url`` with an HTTP GET and return the raw response body.

    The request uses a 30-second timeout and ``raise_for_status()`` is called
    on the response before the body is returned. The body is returned as-is;
    nothing here checks that it is actually a PDF.
    """
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    # The response headers are intentionally not inspected: some servers label
    # PDFs incorrectly, so the bytes are handed back regardless.
    payload = response.content
    return payload

def create_local_sample_policy_pdf(path: Path):
    """Write a small FERPA-themed demo PDF to ``path``.

    This is the offline fallback document, so the pipeline has something to
    process even when no upload or download is available. Each line of the
    embedded sample text is drawn on a letter-sized page, top to bottom; a new
    page is started whenever the cursor drops below the bottom margin.
    """
    margin = 72       # page margin, also used as the left x-offset for text
    line_step = 18    # vertical distance between consecutive lines

    demo_text = textwrap.dedent('''
        Sample FERPA Policy (Demo)
        --------------------------
        1. Rights of Parents and Eligible Students:
           Parents or eligible students have the right to inspect and review education records
           maintained by the school.
        2. Amendment of Records:
           Parents or eligible students may request corrections to records they believe are inaccurate
           or misleading.
        3. Consent for Disclosure:
           Schools generally must have written permission before releasing information from a student's
           education record.
        4. Exceptions:
           FERPA allows disclosure without consent to certain parties such as school officials with
           legitimate educational interests, or in cases of health and safety emergencies.
        5. Directory Information:
           Schools may disclose "directory" information (such as name, address, phone number, dates of
           attendance) without consent, but must inform parents and eligible students and allow them
           time to opt out.
        6. Compliance Note:
           This demo text is adapted for instructional purposes only and is not a full legal document.
    ''').strip()

    page_height = letter[1]
    top_of_page = page_height - margin

    pdf = canvas.Canvas(str(path), pagesize=letter)
    cursor_y = top_of_page
    for text_line in demo_text.splitlines():
        pdf.drawString(margin, cursor_y, text_line)
        cursor_y -= line_step
        if cursor_y < margin:
            # Ran past the bottom margin: start a fresh page at the top.
            pdf.showPage()
            cursor_y = top_of_page
    pdf.save()

def ensure_pdf_bytes(uploaded: bytes = None, fallback_url: str = None):
    """Return ``(pdf_bytes, source_label)`` from the first usable source.

    Sources are tried in this order:
      1. ``uploaded`` bytes, returned untouched with the label "uploaded_file.pdf";
      2. ``fallback_url``, when given;
      3. each URL in ``DEFAULT_POLICY_PDF_URLS``;
      4. a locally generated sample PDF (label "local_sample_policy.pdf").

    A download is accepted only if it looks like a PDF. Fetch failures and
    non-PDF responses are reported with a printed warning and the next source
    is tried, so this function does not raise for network problems.
    """
    if uploaded:
        return uploaded, "uploaded_file.pdf"

    candidates = ([fallback_url] if fallback_url else []) + list(DEFAULT_POLICY_PDF_URLS)
    for url in candidates:
        try:
            data = download_pdf(url)
        except Exception as e:
            # Best-effort chain: report and move on to the next candidate.
            print(f"⚠️ Could not fetch {url}: {e}")
            continue
        # A successful HTTP response can still be an HTML error or login page.
        # Look for the PDF signature near the start of the payload (not only at
        # byte 0, since some files carry a few leading junk bytes) so that a
        # bad download falls through instead of failing later during parsing.
        if b"%PDF-" not in data[:1024]:
            print(f"⚠️ Could not fetch {url}: response is not a PDF")
            continue
        return data, url

    # Last resort: generate a local PDF with sample text.
    tmp = Path(tempfile.gettempdir()) / "sample_policy.pdf"
    create_local_sample_policy_pdf(tmp)
    return tmp.read_bytes(), "local_sample_policy.pdf"

## PDF to Pages & Chunking

In [None]:
@dataclass
class PageText:
    """Text extracted from one page of a PDF."""
    # Page number; pdf_to_pages assigns these starting at 1.
    page_num: int
    # Page text; pdf_to_pages stores it with whitespace collapsed, and as ""
    # when extraction returns nothing or raises.
    text: str

def pdf_to_pages(pdf_bytes: bytes):
    """Parse ``pdf_bytes`` and return one ``PageText`` per page, in page order.

    Page numbers start at 1. Every run of whitespace in the extracted text is
    collapsed to a single space. A page whose text extraction raises, or
    yields nothing, is kept with an empty string so page numbering is stable.
    """
    document = PdfReader(io.BytesIO(pdf_bytes))
    extracted = []
    for number, pdf_page in enumerate(document.pages, start=1):
        try:
            raw = pdf_page.extract_text() or ""
        except Exception:
            # Best effort: an unreadable page becomes an empty page.
            raw = ""
        normalized = " ".join(raw.split())
        extracted.append(PageText(page_num=number, text=normalized))
    return extracted

def chunk_text(text: str, chunk_chars: int = 1200, overlap: int = 200) -> List[str]:
    """Split ``text`` into consecutive character windows.

    Args:
        text: The string to split. An empty string yields an empty list.
        chunk_chars: Maximum length of each chunk; must be positive.
        overlap: Number of trailing characters of one chunk repeated at the
            start of the next; must be non-negative. When ``overlap`` is
            greater than or equal to ``chunk_chars`` the chunks are emitted
            back-to-back with no overlap, because the window could not
            otherwise advance.

    Returns:
        The chunks in order. The final chunk may be shorter than ``chunk_chars``.

    Raises:
        ValueError: If ``chunk_chars`` is not positive (the window would never
            advance) or ``overlap`` is negative (text between chunks would be
            skipped silently).
    """
    if chunk_chars <= 0:
        raise ValueError(f"chunk_chars must be positive, got {chunk_chars}")
    if overlap < 0:
        raise ValueError(f"overlap must be non-negative, got {overlap}")

    # Distance between the starts of consecutive chunks.
    step = chunk_chars - overlap if overlap < chunk_chars else chunk_chars

    total = len(text)
    chunks = []
    for start in range(0, total, step):
        chunks.append(text[start:start + chunk_chars])
        if start + chunk_chars >= total:
            # This window reached the end of the text; any further window
            # would only repeat its tail.
            break
    return chunks

## Embedding & Vector Indexing

In [None]:
def embed_texts(texts, model=EMBED_MODEL):
    """Embed one string or a sequence of strings, one API call per text.

    A single string is treated as a one-element batch. For each text the
    response's "embedding" entry is used (or the response itself when that key
    is absent); if that value is a dict containing "values", the "values"
    entry is taken as the vector.

    Returns:
        A 2D float array with one row per input text.

    Raises:
        RuntimeError: If the stacked result is not two-dimensional.
    """
    batch = [texts] if isinstance(texts, str) else texts

    vectors = []
    for item in batch:
        response = genai.embed_content(model=model, content=item)
        payload = response.get("embedding", response)
        wrapped = isinstance(payload, dict) and "values" in payload
        vector = payload["values"] if wrapped else payload
        vectors.append(np.asarray(vector, dtype=float))

    matrix = np.vstack(vectors)
    if matrix.ndim != 2:
        raise RuntimeError(f"Expected 2D embeddings, got {matrix.shape}")
    return matrix

class SimpleVectorIndex:
    """Cosine nearest-neighbour index over a fixed matrix of embeddings.

    Attributes:
        embeddings: 2D float array, one row per indexed item.
        meta: Caller-supplied metadata, stored as given.
        n: Number of indexed rows.
        nn: ``NearestNeighbors`` instance (cosine metric) fitted on ``embeddings``.
    """

    def __init__(self, embeddings, meta):
        """Build the index.

        Args:
            embeddings: Array-like convertible to a 2D float array.
            meta: Metadata to keep alongside the embeddings.

        Raises:
            ValueError: If ``embeddings`` is not 2D or has no rows.
        """
        self.embeddings = np.asarray(embeddings, dtype=float)
        if self.embeddings.ndim != 2:
            raise ValueError(f"embeddings must be 2D, got {self.embeddings.shape}")
        if self.embeddings.shape[0] == 0:
            # Fail with a clear message rather than deep inside the fit call.
            raise ValueError("embeddings must contain at least one row")
        self.meta = meta
        self.n = self.embeddings.shape[0]

        n_neighbors = max(1, min(10, self.n))
        self.nn = NearestNeighbors(n_neighbors=n_neighbors, metric="cosine")
        self.nn.fit(self.embeddings)

    def search(self, query_vec, top_k=5):
        """Return up to ``top_k`` ``(row_index, cosine_similarity)`` pairs.

        ``top_k`` is clamped to the range ``1..n``. Pairs are returned in the
        order produced by the neighbour search.
        """
        q = np.asarray(query_vec, dtype=float).reshape(1, -1)
        k = max(1, min(int(top_k), self.n))

        # Fast-path for single-item indexes: there is only one possible hit,
        # but report its real similarity instead of a hard-coded 1.0.
        if self.n == 1:
            doc = np.asarray(self.embeddings[0], dtype=float)
            denom = float(np.linalg.norm(q[0]) * np.linalg.norm(doc))
            # A zero-length vector has no direction; treat it as similarity 0.
            sim = float(np.dot(q[0], doc)) / denom if denom > 0.0 else 0.0
            # Guard against floating-point drift just outside [-1, 1].
            sim = max(-1.0, min(1.0, sim))
            return [(0, sim)]

        distances, indices = self.nn.kneighbors(q, n_neighbors=k)
        results = []
        for d, idx in zip(distances[0], indices[0]):
            sim = 1.0 - float(d)  # cosine distance -> similarity
            results.append((int(idx), sim))
        return results

## RAG Prompt & Answer

In [None]:
# Instruction text that answer_with_gemini places at the start of every prompt,
# ahead of the retrieved context and the user's question.
RAG_SYSTEM_PROMPT = '''You are a careful assistant. Answer ONLY using the provided context from the approved document(s).
If the answer is not present in the context, say "I don't know based on the approved documents."
Always cite page numbers like (p. X) when relevant.
Keep answers concise, accurate, and policy-aligned.'''

def build_context(snippets, max_tokens: int = 1500):
    """Concatenate snippets into one context string within a size budget.

    Each snippet (a mapping with "source", "page" and "text") is rendered as a
    "[Source: ... p.N]" header line followed by its text. The budget is
    ``max_tokens * 4`` characters; the first snippet that would push the
    running total past it stops the loop, and later snippets are not
    considered. Rendered snippets are joined with a newline.
    """
    char_budget = max_tokens * 4
    rendered = []
    running_total = 0
    for snippet in snippets:
        title_line = f"[Source: {snippet['source']} p.{snippet['page']}]"
        block = "\n".join((title_line, snippet["text"], ""))
        if running_total + len(block) > char_budget:
            break
        running_total += len(block)
        rendered.append(block)
    return "\n".join(rendered)

def answer_with_gemini(question: str, context: str) -> str:
    """Ask the generation model to answer ``question`` from ``context``.

    The prompt is the RAG system prompt, then the context, then the question.

    Returns:
        The model's text, or "(no response)" when the response carries no text.
        Errors raised by the API call itself propagate to the caller.
    """
    prompt = f"{RAG_SYSTEM_PROMPT}\n\nContext:\n{context}\n\nQuestion: {question}\nAnswer:"
    resp = genai.GenerativeModel(GENERATION_MODEL).generate_content(prompt)
    try:
        text = resp.text
    except ValueError:
        # The SDK's `.text` accessor raises ValueError when the response has no
        # usable text part (for example a blocked or empty candidate). Treat
        # that the same as an empty answer so the fallback below applies.
        text = None
    return text or "(no response)"


## Demo workflow
1. **Load a PDF** by uploading or pasting a URL (or rely on a default FERPA PDF).  
2. Click **Preprocess & Chunk**, then the **Build** button in the *Build Index* section.  
3. Ask: *“What is FERPA?”* (or any relevant question).  

> Tweak chunk size, overlap, and Top‑K to see how retrieval quality changes.


In [None]:
# Widgets for the three pipeline steps: load a PDF, preprocess it, build the index.

# Step 1 inputs: a single-file PDF upload and an optional URL.
upload = widgets.FileUpload(accept=".pdf", multiple=False)
url_input = widgets.Text(
    value="",
    placeholder="Optional: paste a PDF URL (used if no file uploaded)",
    description="PDF URL:"
)
# Step 2/3 tuning knobs. Chunk size and overlap are read by on_prep_clicked;
# Top-K is read by on_ask.
chunk_slider = widgets.IntSlider(value=500, min=200, max=1000, step=100, description="Chunk size")
overlap_slider = widgets.IntSlider(value=100, min=0, max=400, step=50, description="Overlap")
topk_slider = widgets.IntSlider(value=5, min=1, max=10, step=1, description="Top-K")
# One button per pipeline step.
load_btn = widgets.Button(description="Load PDF", button_style="primary")
prep_btn = widgets.Button(description="Preprocess & Chunk", button_style="warning")
index_btn = widgets.Button(description="Build", button_style="info")

# One bordered output area per step; each handler prints into its own area.
status_out = widgets.Output(layout={'border': '1px solid #ddd'})
prep_out = widgets.Output(layout={'border': '1px solid #ddd'})
search_out = widgets.Output(layout={'border': '1px solid #ddd'})

# Lay the controls out vertically, grouped under a heading per step.
display(widgets.VBox([
    widgets.HTML("<h3>1) Load a PDF</h3>"),
    widgets.HTML("Upload a PDF file or provide a URL to a publicly accessible PDF. If both are provided, the uploaded file is used."),
    widgets.HTML("Tip: If no PDF is provided, a sample FERPA policy will be used."),
    widgets.HBox([upload]),
    url_input,
    load_btn,
    status_out,
    widgets.HTML("<h3>2) Preprocess PDF</h3>"),
    chunk_slider, overlap_slider, prep_btn, prep_out,
    widgets.HTML("<h3>3) Build Index</h3>"),
    topk_slider, index_btn, search_out,
]))

# Shared mutable state passed between the button handlers:
#   source_label - label/URL of the loaded PDF
#   pages        - list of PageText for the loaded PDF
#   chunks       - chunk strings produced by preprocessing
#   chunk_meta   - one dict per chunk with "source", "page" and "text"
#   emb          - embedding matrix for the chunks
#   index        - SimpleVectorIndex built from emb
STATE = {"source_label": None, "pages": [], "chunks": [], "chunk_meta": [], "emb": None, "index": None}

def on_load_clicked(_):
    """Handle the "Load PDF" button: obtain a PDF and extract its pages.

    The uploaded file is preferred; otherwise the URL field (if any) and the
    built-in fallbacks are used. On success the source label and pages are
    stored in ``STATE`` and everything derived from a previously loaded
    document (chunks, metadata, embeddings, index) is cleared. Progress and
    errors are printed into ``status_out``.
    """
    status_out.clear_output()
    with status_out:
        try:
            uploaded_bytes = None
            value = upload.value
            if value:
                # ipywidgets 7 exposes a dict keyed by filename, ipywidgets 8 a
                # tuple of per-file records; support both shapes.
                files = list(value.values()) if isinstance(value, dict) else list(value)
                # The content may be a memoryview; normalize to bytes.
                uploaded_bytes = bytes(files[0]["content"])
            url = url_input.value.strip() or None
            pdf_bytes, src = ensure_pdf_bytes(uploaded=uploaded_bytes, fallback_url=url)
            pages = pdf_to_pages(pdf_bytes)
            STATE["source_label"] = src
            STATE["pages"] = pages
            # A new document invalidates everything computed from the old one,
            # so questions cannot be answered from a stale index.
            STATE["chunks"] = []
            STATE["chunk_meta"] = []
            STATE["emb"] = None
            STATE["index"] = None
            print(f"Loaded PDF from: {src}")
            print(f"   Pages extracted: {len(pages)}")
        except Exception as e:
            # Top-level UI boundary: show the failure instead of raising.
            print(f"Load failed: {e}")

def on_prep_clicked(_):
    """Handle the "Preprocess & Chunk" button: split loaded pages into chunks.

    Reads chunk size and overlap from the sliders, chunks every non-empty page,
    and stores the chunks plus per-chunk metadata ("source", "page", "text")
    in ``STATE``. Any previously built embeddings and index are cleared because
    they no longer match the new chunks. Output is printed into ``prep_out``.
    """
    prep_out.clear_output()
    with prep_out:
        if not STATE["pages"]:
            print("No PDF loaded yet.")
            return
        chunk_chars = int(chunk_slider.value)
        overlap = int(overlap_slider.value)
        if overlap >= chunk_chars:
            # chunk_text cannot advance with this much overlap and emits
            # back-to-back chunks instead; tell the user rather than stay silent.
            print("Note: overlap is not smaller than the chunk size, so chunks will not overlap.")
        chunks, meta = [], []
        src = STATE["source_label"]
        for p in STATE["pages"]:
            if not p.text:
                continue
            for c in chunk_text(p.text, chunk_chars=chunk_chars, overlap=overlap):
                chunks.append(c)
                meta.append({"source": src, "page": p.page_num, "text": c})
        STATE["chunks"] = chunks
        STATE["chunk_meta"] = meta
        # Row numbers in an older index refer to the previous chunk list.
        STATE["emb"] = None
        STATE["index"] = None
        print(f"Created {len(chunks)} chunks across {len(STATE['pages'])} pages.")

def on_index_clicked(_):
    """Handle the "Build" button: embed the current chunks and index them.

    Prints into ``search_out``. When there are no chunks a hint is printed and
    nothing else happens. Otherwise the chunks are embedded, a
    ``SimpleVectorIndex`` is built over them, and both are stored in ``STATE``.
    Any failure is printed rather than raised.
    """
    search_out.clear_output()
    with search_out:
        pending = STATE["chunks"]
        if not pending:
            print("No chunks to index. Run preprocessing first.")
            return
        try:
            print("Embedding chunks... (this may take a moment)")
            vectors = embed_texts(pending)
            built = SimpleVectorIndex(vectors, STATE["chunk_meta"])
            # Store results only after both steps succeeded.
            STATE.update({"emb": vectors, "index": built})
            print(f"Index built with shape {vectors.shape}.")
            print("Ready to answer questions. Next cell will display the chat interface.")
        except Exception as e:
            print(f"Index build failed: {e}")

def on_ask(q_input):
    """Answer ``q_input`` from the indexed document and return display text.

    The question is embedded, the Top-K most similar chunks are retrieved,
    joined into a context string, and passed to the generation model.

    Returns:
        The answer followed by the first 1000 characters of the context used,
        or a short message when no index exists, the question is empty, or
        any step fails. This function does not raise.
    """
    index = STATE["index"]
    if index is None:
        return "Error - Build the index first."
    if not q_input:
        return "Enter a question."

    try:
        qvec = embed_texts([q_input])[0]
        results = index.search(qvec, top_k=int(topk_slider.value))
        snippets = []
        for idx, sim in results:
            # Use the metadata stored in the index itself: its rows are
            # guaranteed to line up with the indexed embeddings, whereas
            # STATE["chunk_meta"] may have been rebuilt since the index was.
            m = index.meta[idx]
            snippets.append({"source": m["source"], "page": m["page"], "text": m["text"], "score": sim})
        context = build_context(snippets)
        ans = answer_with_gemini(q_input, context)
        return f"{ans}\n\nContext used (truncated):\n{context[:1000]}...\n"
    except Exception as e:
        # Top-level UI boundary: surface the failure as chat text.
        return f"Q&A failed: {e}"

# Wire each pipeline button to its handler.
load_btn.on_click(on_load_clicked)
prep_btn.on_click(on_prep_clicked)
index_btn.on_click(on_index_clicked)

## Chat Interface

In [None]:
# Chat UI pieces: an output area for the transcript, a one-line input and a send button.
chat_box = widgets.Output()
input_box = widgets.Text(placeholder="Type a message and press Enter…")
send_btn = widgets.Button(description="Send")

# Conversation so far as (user_message, bot_message) tuples, oldest first.
history = []

def render_chat():
    """Redraw the whole transcript in ``chat_box`` from ``history``.

    The output area is cleared first; every turn is then shown as a "You"
    line followed by a "Bot" line, both rendered as Markdown.
    """
    chat_box.clear_output()
    with chat_box:
        for turn in history:
            user_text, bot_text = turn
            user_line = Markdown(f"**You:** {user_text}")
            display(user_line)
            bot_line = Markdown(f"**Bot:** {bot_text}")
            display(bot_line)

def send_message(_=None):
    """Send the text in ``input_box`` to the bot and refresh the transcript.

    Blank input is ignored. Otherwise the input box is cleared, the previous
    turns plus the new message are formatted as a "User:/Assistant:" transcript
    and passed to ``on_ask``, and the resulting turn is appended to ``history``
    before the chat is re-rendered.
    """
    typed = input_box.value.strip()
    if typed == "":
        return
    input_box.value = ""

    # Build the transcript that is sent as the question.
    past_turns = []
    for earlier_user, earlier_bot in history:
        past_turns.append(f"User: {earlier_user}\nAssistant: {earlier_bot}")
    transcript = "\n".join(past_turns)
    prompt = (transcript + f"\nUser: {typed}\nAssistant:").strip()

    try:
        reply = on_ask(prompt) or "(no response)"
    except Exception as e:
        reply = f"(error) {e}"

    history.append((typed, reply))
    render_chat()

# Send a message from either the button or the text box's submit event.
send_btn.on_click(send_message)
# NOTE(review): Text.on_submit is reported as deprecated in newer ipywidgets
# releases — confirm against the installed version.
input_box.on_submit(send_message)

# Show the transcript above the input row, then draw the initial (empty) chat.
display(widgets.VBox([chat_box, widgets.HBox([input_box, send_btn])]))
render_chat()