<a href="https://colab.research.google.com/github/Sanchit9587/Shaastra_Chatbot_26/blob/main/RAG_00.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q sentence-transformers chromadb PyPDF2 transformers accelerate huggingface_hub


In [None]:
import os
import re
from typing import List, Tuple, Dict, Any

import numpy as np
import chromadb
from chromadb import PersistentClient
from PyPDF2 import PdfReader
from sentence_transformers import SentenceTransformer

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

In [None]:
#Setup the API keys here

# The placeholder below must stay identical to the string the "token not set"
# checks in this notebook look for; otherwise those warnings can never fire.
HF_TOKEN = "YOUR_HF_TOKEN_HERE"  # <<-- put your HF access token here

if "YOUR_HF_TOKEN_HERE" in HF_TOKEN:
    print("[WARN] Please set HF_TOKEN at the top of this cell if the model is gated.")

# Cache location and token are exported through the environment as well as
# being available as the HF_TOKEN variable.
os.environ["HF_HOME"] = "/content/.cache/huggingface"
os.environ["HF_TOKEN"] = HF_TOKEN

In [None]:
#Models and storage locations shared by the rest of the notebook

# Where the PDFs are read from and where the Chroma index is persisted
DOCS_FOLDER = "/content/Documents"   # Put your PDFs here
CHROMA_PATH = "/content/chroma_db"
COLLECTION_NAME = "pdf_docs_collection"

# Sentence-embedding model (free & local)
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

# Gemma model (run locally in Colab)
GEMMA_MODEL_NAME = "google/gemma-7b-it"


In [None]:
#Loading the embedding model
# `embedder` is the module-level SentenceTransformer that embed_texts() uses to encode both document chunks and queries.

print(f"[INFO] Loading embedding model: {EMBEDDING_MODEL_NAME}")
embedder = SentenceTransformer(EMBEDDING_MODEL_NAME)
print("[INFO] Embedding model loaded.")


In [None]:
#Loading Gemma locally because the HF API just does not work

print(f"[INFO] Loading Gemma model: {GEMMA_MODEL_NAME}")
# Any value that still starts with "YOUR_HF_" is a template placeholder rather
# than a real access token.
if HF_TOKEN.startswith("YOUR_HF_"):
    print("[WARN] HF_TOKEN not set. Model download may fail if the model is gated.")

# `token=` replaces the deprecated `use_auth_token=` keyword of from_pretrained.
tokenizer = AutoTokenizer.from_pretrained(
    GEMMA_MODEL_NAME,
    token=HF_TOKEN,
)

# Use bfloat16 for efficiency; device_map="auto" to spread across GPU/CPU if needed
model = AutoModelForCausalLM.from_pretrained(
    GEMMA_MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    token=HF_TOKEN,
)

# Inference only: switch off training-time behaviour such as dropout.
model.eval()
print("[INFO] Gemma model loaded.")

In [None]:
#Setting up a few helper functions

def split_into_sentences(text: str) -> List[str]:
    """Break text into sentences at whitespace that follows '.', '!' or '?'.

    Blank input yields an empty list; every returned sentence is stripped
    and non-empty.
    """
    trimmed = text.strip()
    if trimmed == "":
        return []

    sentences: List[str] = []
    for fragment in re.split(r'(?<=[.!?])\s+', trimmed):
        fragment = fragment.strip()
        if fragment:
            sentences.append(fragment)
    return sentences


def embed_texts(texts: List[str]) -> List[List[float]]:
    """Return one embedding per input text as plain Python lists.

    Uses the module-level SentenceTransformer `embedder`; an empty input
    returns an empty list without calling the model.
    """
    if texts:
        vectors = embedder.encode(texts, convert_to_numpy=True)
        return vectors.tolist()
    return []


def cosine_similarity(vec_a: List[float], vec_b: List[float]) -> float:
    """Return the cosine of the angle between two vectors as a Python float.

    Both inputs are converted to float32; a tiny constant is added to the
    denominator so zero-length vectors give 0.0 instead of dividing by zero.
    """
    lhs = np.array(vec_a, dtype=np.float32)
    rhs = np.array(vec_b, dtype=np.float32)
    norm_product = np.linalg.norm(lhs) * np.linalg.norm(rhs)
    return float(np.dot(lhs, rhs) / (norm_product + 1e-10))

In [None]:
#Loading in the PDFs

def load_pdfs_from_folder(folder_path: str) -> List[Tuple[str, str]]:
    """
    Load all .pdf files (case-insensitive extension) from the given folder.

    Args:
        folder_path: Directory to scan (not recursive).

    Returns:
        List of (filename, extracted_text) in sorted filename order. The text
        of a document is its pages' extracted text joined with newlines; a
        page with no extractable text contributes an empty string. A missing
        folder, or a path that is not a directory, yields an empty list.
        Files that fail to parse are reported and skipped.
    """
    docs: List[Tuple[str, str]] = []

    # isdir rather than exists: a path pointing at a regular file would
    # otherwise reach os.listdir and raise.
    if not os.path.isdir(folder_path):
        print(f"[WARN] Folder does not exist or is not a directory: {folder_path}")
        return docs

    # os.listdir order is arbitrary; sorting keeps document order (and the
    # chunk ids derived from it) stable between runs.
    for fname in sorted(os.listdir(folder_path)):
        if not fname.lower().endswith(".pdf"):
            continue

        full_path = os.path.join(folder_path, fname)
        try:
            reader = PdfReader(full_path)
            text = "\n".join(page.extract_text() or "" for page in reader.pages)
            page_count = len(reader.pages)
        except Exception as e:
            # Best effort: one unreadable PDF must not abort the whole load.
            print(f"[ERROR] Failed to read {fname}: {e}")
            continue

        docs.append((fname, text))
        print(f"[INFO] Loaded PDF: {fname} ({page_count} pages)")
    return docs

In [None]:
#Chunker

def recursive_character_split(
    text: str,
    chunk_size: int = 2000,
    chunk_overlap: int = 300,
    separators: List[str] = None,
) -> List[str]:
    """
    Split `text` into chunks of at most `chunk_size` characters.

    Loosely modelled on LangChain's RecursiveCharacterTextSplitter.

    Strategy:
    - Split on the first separator; any segment still longer than
      `chunk_size` is split again with the next separator in the list.
    - Default separators: paragraphs "\\n\\n", then lines "\\n", then
      sentence-ish ". ".
    - Once the separators are exhausted (or an empty-string separator is
      reached) the segment is hard-cut every `chunk_size` characters.
    - The resulting pieces are stripped and greedily merged, joined by a
      single space, into chunks no longer than `chunk_size`.
    - When the previous chunk is longer than `chunk_overlap`, the next chunk
      is prefixed with up to `chunk_overlap` of its trailing characters; the
      prefix is shortened where necessary so the chunk still fits.

    Args:
        text: Text to split.
        chunk_size: Maximum chunk length in characters; must be positive.
        chunk_overlap: Characters carried over from the end of one chunk to
            the start of the next. Values <= 0 disable overlap.
        separators: Separators to try, coarsest first. None selects the
            defaults listed above.

    Returns:
        Non-empty, stripped chunks in document order.

    Raises:
        ValueError: If `chunk_size` is not positive.
    """
    # A zero step breaks the hard-cut range() and a negative one makes it
    # empty, which would silently drop text.
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    if separators is None:
        separators = ["\n\n", "\n", ". "]

    def _split(segment: str, level: int) -> List[str]:
        """Return pieces of `segment`, each at most `chunk_size` characters."""
        if len(segment) <= chunk_size:
            return [segment]

        # Out of separators, or an explicit "" separator: cut by raw characters.
        sep = separators[level] if level < len(separators) else ""
        if sep == "":
            return [
                segment[start:start + chunk_size]
                for start in range(0, len(segment), chunk_size)
            ]

        parts = segment.split(sep)
        last = len(parts) - 1
        pieces: List[str] = []
        for i, part in enumerate(parts):
            # Keep the separator attached to the part it terminated.
            candidate = part + sep if i < last else part
            # _split returns [candidate] unchanged when it already fits.
            pieces.extend(_split(candidate, level + 1))
        return pieces

    chunks: List[str] = []
    current = ""

    for piece in _split(text, 0):
        piece = piece.strip()
        if not piece:
            continue

        if not current:
            current = piece
            continue

        if len(current) + 1 + len(piece) <= chunk_size:
            current = current + " " + piece
            continue

        # The piece does not fit: close the current chunk and start a new one.
        chunks.append(current.strip())

        # Characters left for the overlap prefix once the piece and the
        # joining space are accounted for.
        room = chunk_size - len(piece) - 1
        tail_len = min(chunk_overlap, room)
        if tail_len > 0 and len(current) > chunk_overlap:
            current = current[-tail_len:] + " " + piece
        else:
            current = piece

    if current:
        chunks.append(current.strip())

    return chunks

#Just a helper; despite the name no semantics are involved. To change the chunk size, do it here and not in the function above
def semantic_chunk_document(
    text: str,
    chunk_size: int = 500,
    chunk_overlap: int = 0,
) -> List[str]:
    """
    Chunk a document with the recursive character-based splitter.

    - Splits on paragraphs ("\\n\\n") first, then on single lines ("\\n")
    - `chunk_size` and `chunk_overlap` are measured in characters and are
      passed straight through to `recursive_character_split`
    - No embeddings / similarity involved in chunking.
    """
    paragraph_then_line = ["\n\n", "\n"]
    return recursive_character_split(
        text,
        chunk_size,
        chunk_overlap,
        paragraph_then_line,
    )

In [None]:
#Indexing the context doc and storing it using Chroma DB

def build_chroma_index(docs_folder: str = DOCS_FOLDER) -> None:
    """
    Rebuild the Chroma collection from the PDFs in `docs_folder`.

    Steps:
    - Drop the existing collection (if any) and create a fresh one
    - Load PDFs from the folder
    - Chunk each document with the recursive character-based chunker
    - Embed the chunks
    - Store (id, document, embedding, metadata) in Chroma, in batches

    Args:
        docs_folder: Folder containing the PDFs to index.

    Returns:
        None. Progress and warnings are printed. When no PDFs or no chunks
        are found the (already reset) collection is left empty.
    """
    print(f"[INFO] Building Chroma index from PDFs in: {docs_folder}")

    chroma_client = chromadb.PersistentClient(path=CHROMA_PATH)

    # Fresh rebuild: remove the old collection. Failure here is expected when
    # the collection does not exist yet, so it is reported rather than raised.
    try:
        chroma_client.delete_collection(COLLECTION_NAME)
    except Exception as exc:
        print(f"[INFO] No existing collection '{COLLECTION_NAME}' was deleted ({exc}).")
    else:
        print(f"[INFO] Deleted existing collection '{COLLECTION_NAME}'.")

    collection = chroma_client.get_or_create_collection(name=COLLECTION_NAME)

    pdf_docs = load_pdfs_from_folder(docs_folder)
    if not pdf_docs:
        print("[WARN] No PDFs found in docs folder.")
        return

    all_ids: List[str] = []
    all_texts: List[str] = []
    all_embeddings: List[List[float]] = []
    all_metadatas: List[Dict[str, Any]] = []

    for doc_idx, (filename, text) in enumerate(pdf_docs):
        print(f"[INFO] Chunking document: {filename}")
        chunks = semantic_chunk_document(text)

        if not chunks:
            print(f"[WARN] No chunks created for: {filename}")
            continue

        print(f"       -> {len(chunks)} chunks")

        # Embed final chunks
        chunk_embeddings = embed_texts(chunks)

        for chunk_idx, (chunk_text, emb) in enumerate(zip(chunks, chunk_embeddings)):
            # doc index + chunk index make the id unique within one build.
            all_ids.append(f"{doc_idx}-{chunk_idx}-{filename}")
            all_texts.append(chunk_text)
            all_embeddings.append(emb)
            all_metadatas.append(
                {
                    "source": filename,
                    "chunk_index": chunk_idx,
                }
            )

    if not all_texts:
        print("[WARN] No chunks produced; nothing to index.")
        return

    # Chroma rejects a single add() larger than its maximum batch size, so
    # large corpora are written in modest slices.
    batch_size = 1000
    for start in range(0, len(all_texts), batch_size):
        end = start + batch_size
        collection.add(
            ids=all_ids[start:end],
            documents=all_texts[start:end],
            embeddings=all_embeddings[start:end],
            metadatas=all_metadatas[start:end],
        )

    print(
        f"[INFO] Indexed {len(all_texts)} chunks "
        f"from {len(pdf_docs)} PDF(s) into collection '{COLLECTION_NAME}'."
    )


def get_chroma_collection():
    """Open the persistent Chroma store at CHROMA_PATH and return the chunk collection, creating it if absent."""
    return chromadb.PersistentClient(path=CHROMA_PATH).get_or_create_collection(
        name=COLLECTION_NAME
    )


In [None]:
#Retriever that returns the top k chunks

def retrieve_top_k_chunks(query: str, k: int = 3) -> List[Dict[str, Any]]:
    """
    Return the `k` stored chunks most similar to `query`.

    - Embed query using SentenceTransformer
    - Fetch all chunks + embeddings from Chroma
    - Compute cosine similarity manually
    - Return top-k chunks

    Args:
        query: Natural-language question.
        k: Maximum number of chunks to return. Values <= 0 return an empty
            list without touching the database.

    Returns:
        Up to `k` dicts with keys "id", "document", "metadata" and
        "similarity", ordered from most to least similar. Empty when the
        collection holds no documents.
    """
    # A negative k would otherwise slice "all but the last |k|" results.
    if k <= 0:
        return []

    collection = get_chroma_collection()

    data = collection.get(include=["documents", "embeddings", "metadatas"])

    docs = data.get("documents", [])
    embeddings = data.get("embeddings", [])
    metadatas = data.get("metadatas", [])
    ids = data.get("ids", [])

    if not docs:
        print("[WARN] No documents in collection. Build the index first.")
        return []

    query_emb = embed_texts([query])[0]

    sims = [cosine_similarity(query_emb, emb) for emb in embeddings]

    top_indices = sorted(range(len(sims)), key=lambda i: sims[i], reverse=True)[:k]

    return [
        {
            "id": ids[idx] if idx < len(ids) else str(idx),
            "document": docs[idx],
            # A record stored without metadata comes back as None; hand callers
            # a dict so their .get() lookups keep working.
            "metadata": metadatas[idx] or {},
            "similarity": sims[idx],
        }
        for idx in top_indices
    ]


In [None]:
#Combining the context retrieved with a helpful prompt for gemma

def build_prompt(query: str, retrieved_chunks: List[Dict[str, Any]]) -> str:
    """
    Build the context + question prompt for Gemma.

    Args:
        query: The user's question.
        retrieved_chunks: Retrieved items, each with a "document" string and
            a "metadata" dict (its "source" labels the chunk; a missing or
            None metadata is labelled "unknown").

    Returns:
        The full prompt: fixed instructions, the numbered context blocks
        (or "No context." when none were given), the question, and a
        trailing "Answer:" cue.
    """
    context_blocks = []
    for i, item in enumerate(retrieved_chunks):
        source = (item.get("metadata") or {}).get("source", "unknown")
        context_blocks.append(
            f"[Chunk {i+1} | Source: {source}]\n{item['document']}"
        )

    context_str = "\n\n".join(context_blocks) if context_blocks else "No context."

    # The dash and apostrophes below were previously mis-encoded (UTF-8 read
    # as cp1252), which put garbage characters into the text sent to the model.
    prompt = f"""You are ShaastraBot — an enthusiastic, helpful assistant for Shaastra IIT Madras!

    Your job is to answer user questions **using ONLY the information found in the provided context**.
    Do NOT use outside knowledge. If something is missing from the context, clearly say that it is not available.

    When the user asks about an event:
    - Try to include **date, time, venue, on-spot registration info, viewership status, prize details**, or any other specifics IF they appear anywhere in the context.
    - Present the information in a friendly, energetic tone — you're excited to help participants learn about Shaastra!
    - Keep the answer concise but informative.

    General rules:
    - NEVER hallucinate or invent details.
    - If the context contains conflicting information, state that clearly.
    - If the answer cannot be determined from the context, say so politely.
    - Structure answers clearly using bullet points or short paragraphs when appropriate.
    - If the user’s query is broad or unclear, summarize relevant info from the context to help them.

    Now use the following context to answer the user’s question as accurately and enthusiastically as possible.


Context:
{context_str}

User question: {query}

Answer:"""
    return prompt


In [None]:
#Using gemma locally to generate the answer
def call_gemma_local(
    prompt: str,
    max_new_tokens: int = 512,
    temperature: float = 0.2,
) -> str:
    """
    Generate an answer using the local Gemma model.

    Args:
        prompt: Full prompt text; it is tokenized as-is.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature. Values <= 0 switch to greedy
            (deterministic) decoding, because sampling requires a strictly
            positive temperature.

    Returns:
        The newly generated text only (prompt excluded), stripped of
        surrounding whitespace and special tokens.
    """
    # NOTE(review): the prompt is sent raw, without the tokenizer's chat
    # template; confirm whether that is intended for this model.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    use_sampling = temperature > 0
    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": use_sampling,
        "pad_token_id": tokenizer.eos_token_id,
    }
    if use_sampling:
        # Only meaningful (and only valid) when sampling is enabled.
        generation_kwargs["temperature"] = temperature

    with torch.no_grad():
        output_ids = model.generate(**inputs, **generation_kwargs)

    # Only decode the newly generated tokens, not the whole prompt
    prompt_length = inputs["input_ids"].shape[1]
    generated_ids = output_ids[0][prompt_length:]
    text = tokenizer.decode(generated_ids, skip_special_tokens=True)
    return text.strip()


In [None]:
#The helper function that uses everything sequentially to generate the answer

def rag_answer(
    query: str,
    k: int = 5,
    max_new_tokens: int = 512,
    temperature: float = 0.2,
) -> str:
    """
    Full RAG pipeline:
    - Embed query
    - Retrieve top-k chunks from Chroma via cosine similarity
    - Print chunks clearly
    - Build prompt with context
    - Call local Gemma
    - Print & return answer

    Args:
        query: The user's question.
        k: Number of chunks to retrieve as context.
        max_new_tokens: Generation length limit passed to Gemma.
        temperature: Sampling temperature passed to Gemma.

    Returns:
        The generated answer, or an "[ERROR] ..." message string when no
        chunks could be retrieved.
    """
    chunks = retrieve_top_k_chunks(query, k=k)
    if not chunks:
        return "[ERROR] No chunks retrieved. Did you build the index?"

    # Printing the Top K chunks for a better understanding while making the RAG pipeline
    print("\n================ RETRIEVED TOP CHUNKS ================")
    for i, c in enumerate(chunks, start=1):
        # The marker was previously mis-encoded and printed as garbage characters.
        print(f"\n🔹 Chunk {i}  (score={c['similarity']:.3f})")
        print(f"Source: {c['metadata'].get('source')}")
        print("-" * 60)
        # Show at most 500 characters, with an ellipsis when truncated.
        text_preview = c["document"][:500].strip()
        print(text_preview + ("..." if len(c["document"]) > 500 else ""))
    print("\n========================================================")

    prompt = build_prompt(query, chunks)

    print("\n[INFO] Generating answer using local Gemma...\n")
    answer = call_gemma_local(
        prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )

    #Printing the final answer
    print("\n===================== ANSWER ==========================\n")
    print(answer)
    print("\n========================================================\n")

    return answer


# Usage reminder. The folder is taken from DOCS_FOLDER so the message cannot
# drift from the path the indexer actually reads.
print("[INFO] RAG setup complete. Steps:")
print(f"1) Upload PDFs into {DOCS_FOLDER} (or change DOCS_FOLDER).")
print("2) Run: build_chroma_index().")
print("3) Run: rag_answer('your question').")

In [None]:
#Calling the function that chunks the context doc
# Drops any existing collection, then chunks, embeds and stores every PDF found in DOCS_FOLDER.

build_chroma_index()

[INFO] Building Chroma index from PDFs in: /content/Documents
[INFO] Deleted existing collection 'pdf_docs_collection'.
[INFO] Loaded PDF: ShaastraContextDoc.pdf (45 pages)
[INFO] Chunking document: ShaastraContextDoc.pdf
       -> 191 chunks
[INFO] Indexed 191 chunks from 1 PDF(s) into collection 'pdf_docs_collection'.


In [None]:
#Ask your questions here, feel free to mess around with the k value to get a better answer
# rag_answer() already prints the retrieved chunks and the answer, so the print below shows the answer a second time.

answer = rag_answer("What is AT-Makeathon?", k=3)
print(answer)