In [1]:
!pip install -q langchain langchain-community langchain-text-splitters chromadb sentence-transformers transformers accelerate pypdf faiss-cpu

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/67.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m67.3/67.3 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.5/2.5 MB[0m [31m78.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m21.7/21.7 MB[0m [31m66.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m328.3/328.3 kB[0m [31m22.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.7/23.7 MB[0m [31m53.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m278.2/278.2 kB[0m [31m17.0 MB/s[0m eta [36m0:00

In [2]:
import os
from typing import List, Dict, Any, Tuple
from math import ceil

try:
    from langchain_community.document_loaders import PyPDFLoader
    from langchain_text_splitters import RecursiveCharacterTextSplitter
    from langchain_community.embeddings import HuggingFaceEmbeddings
    from langchain_community.vectorstores import FAISS
    from langchain_core.documents import Document
    from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
    import torch
    import gc # Garbage Collection
except ImportError:
    print("Dependencies not installed. Please run the installation cell first: !pip install -q langchain langchain-community langchain-text-splitters chromadb sentence-transformers transformers accelerate pypdf faiss-cpu")

In [3]:


# Put your PDF(s) in this folder in Colab
DOCS_FOLDER = "/content/docs"
PDF_FILENAME = "ShaastraContextDoc.pdf"   # <-- Ensure this file is uploaded
PDF_PATH = os.path.join(DOCS_FOLDER, PDF_FILENAME)

# Models
EMBED_MODEL = "sentence-transformers/all-mpnet-base-v2"

# Specific models for their tasks
SUMMARY_MODEL_NAME = "google/gemma-2-2b-it" # Smaller model for fast summarization
RAG_MODEL_NAME = "google/gemma-7b-it"      # Larger model for high-quality answers
RAG_CONTEXT_WINDOW = 8192

# Chunking params
DETAILED_CHUNK_SIZE = 600
DETAILED_CHUNK_OVERLAP = 50
SUMMARY_GROUP_SIZE = 5

# Retrieval params
TOP_K_SUMMARIES = 5
TOP_K_FINAL_CHUNKS = 5
DETAILED_K_SEARCH = 50

# Memory management
CONTEXT_RESERVATION = 2000

if not os.path.exists(DOCS_FOLDER):
    os.makedirs(DOCS_FOLDER)

print(f"[SETUP] Ensure your PDF ({PDF_FILENAME}) is uploaded to {DOCS_FOLDER}")

## 1. Load Data and Chunking

def load_pdf_as_docs(pdf_path: str) -> List[Document]:
    if not os.path.exists(pdf_path):
        print(f"[ERROR] PDF not found at: {pdf_path}. Please upload it and run again.")
        return []
    loader = PyPDFLoader(pdf_path)
    docs = loader.load()
    for i, d in enumerate(docs):
        d.metadata["page_number"] = d.metadata.get("page", i + 1)
        d.metadata["source"] = os.path.basename(pdf_path)
    print(f"[INFO] Loaded {len(docs)} page-level docs from {pdf_path}")
    return docs

def make_detailed_chunks(
    docs: List[Document],
    chunk_size: int = DETAILED_CHUNK_SIZE,
    chunk_overlap: int = DETAILED_CHUNK_OVERLAP,
) -> List[Document]:
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        separators=["\n\n", "\n", ". ", " ", ""],
    )
    chunks = splitter.split_documents(docs)
    for idx, c in enumerate(chunks):
        c.metadata["chunk_id"] = idx
    print(f"[INFO] Created {len(chunks)} detailed chunks.")
    return chunks

docs = load_pdf_as_docs(PDF_PATH)
if not docs:
    raise FileNotFoundError("Document loading failed. Please check file path.")

detailed_chunks = make_detailed_chunks(docs)

# Load Embedding Model
print("[INFO] Loading embedding model...")
embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL)

## 2. Load and Unload Summarization Model (Gemma-2B)

def get_summarizer_pipe():
    print(f"[INFO] Loading Summarization Model: {SUMMARY_MODEL_NAME}...")
    tokenizer = AutoTokenizer.from_pretrained(SUMMARY_MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        SUMMARY_MODEL_NAME,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=300, # Max output tokens for summary
        do_sample=False,
    )
    return pipe, tokenizer, model

def summarize_with_pipe(pipe, tokenizer, text: str, max_new_tokens: int = 300) -> str:
    """Uses the provided pipeline to summarize text."""
    prompt = (
        f"<start_of_turn>user\nSummarize the following section of a document into a concise paragraph, capturing the key subjects and specific details (e.g., dates, names, locations) to aid in information retrieval.\n\nText:\n{text}\n\nSummary:<end_of_turn>\n<start_of_turn>model\n"
    )

    outputs = pipe(
        prompt,
        max_new_tokens=max_new_tokens,
        do_sample=False,
    )
    generated = outputs[0]["generated_text"]

    # Post-processing
    summary = generated.split("<start_of_turn>model\n", 1)[-1].strip()
    if "<end_of_turn>" in summary:
        summary = summary.split("<end_of_turn>", 1)[0].strip()

    return summary

def build_hierarchical_indices(
    detailed_chunks: List[Document],
    group_size: int = SUMMARY_GROUP_SIZE,
) -> Tuple[FAISS, FAISS]:

    # 2a. Load Summarizer (Gemma-2B)
    summarizer_pipe, summarizer_tokenizer, summarizer_model = get_summarizer_pipe()

    summary_texts = []
    summary_metadatas = []

    print(f"[INFO] Building hierarchical indices with group_size={group_size}...")

    num_chunks = len(detailed_chunks)
    num_groups = ceil(num_chunks / group_size)

    for g in range(num_groups):
        start = g * group_size
        end = min((g + 1) * group_size, num_chunks)
        group = detailed_chunks[start:end]

        combined_text = "\n\n".join([c.page_content for c in group])

        print(f"[INFO] Summarizing group {g+1}/{num_groups} (chunks {start}..{end-1})...")
        summary = summarize_with_pipe(summarizer_pipe, summarizer_tokenizer, combined_text)

        chunk_ids = list(range(start, end))
        meta = {
            "summary_id": g,
            "source": group[0].metadata.get("source", "unknown"),
            "chunk_ids": chunk_ids,
            "detailed_pages": sorted(list(set(c.metadata['page_number'] for c in group)))
        }

        summary_texts.append(summary)
        summary_metadatas.append(meta)

    # 2b. Unload Summarizer (Gemma-2B)
    del summarizer_pipe
    del summarizer_tokenizer
    del summarizer_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
    print(f"[INFO] Summarization Model ({SUMMARY_MODEL_NAME}) unloaded.")

    # 2c. Build FAISS stores
    print("[INFO] Building FAISS summary store...")
    summary_store = FAISS.from_texts(
        texts=summary_texts,
        embedding=embeddings,
        metadatas=summary_metadatas,
    )

    print("[INFO] Building FAISS detailed store...")
    detailed_texts = [c.page_content for c in detailed_chunks]
    detailed_metadatas = [c.metadata for c in detailed_chunks]

    detailed_store = FAISS.from_texts(
        texts=detailed_texts,
        embedding=embeddings,
        metadatas=detailed_metadatas,
    )

    print("[INFO] Hierarchical indices built.")
    return summary_store, detailed_store

# Execute index building and unload Gemma-2B
summary_store, detailed_store = build_hierarchical_indices(detailed_chunks)

## 3. Load RAG Model (Gemma-7B)

print(f"[INFO] Loading RAG Model: {RAG_MODEL_NAME} for chat/Q&A...")
rag_tokenizer = AutoTokenizer.from_pretrained(RAG_MODEL_NAME)
rag_model = AutoModelForCausalLM.from_pretrained(
    RAG_MODEL_NAME,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

rag_pipe = pipeline(
    "text-generation",
    model=rag_model,
    tokenizer=rag_tokenizer,
    max_new_tokens=512, # Increase for better chat responses
    do_sample=False,
)

print(f"[INFO] RAG Model ({RAG_MODEL_NAME}) pipeline ready. Proceed to chat.")


## 4. Hierarchical Retrieval Function

def hierarchical_retrieve(
    query: str,
    summary_store: FAISS,
    detailed_store: FAISS,
    top_k_summaries: int = TOP_K_SUMMARIES,
    top_k_final_chunks: int = TOP_K_FINAL_CHUNKS,
    detailed_k_search: int = DETAILED_K_SEARCH,
) -> List[Document]:
    """Hierarchical retrieval: Coarse search on summaries, then filtered fine-grained search."""
    print(f"[RETRIEVAL] Performing hierarchical retrieval for query: '{query[:50]}...'")

    summary_docs = summary_store.similarity_search(query, k=top_k_summaries)

    allowed_chunk_ids = set()
    for sdoc in summary_docs:
        chunk_ids = sdoc.metadata.get("chunk_ids", [])
        allowed_chunk_ids.update(chunk_ids)

    if not allowed_chunk_ids:
        print("[WARN] No relevant summaries found. Falling back to flat search.")
        return detailed_store.similarity_search(query, k=top_k_final_chunks)

    candidate_detailed = detailed_store.similarity_search(query, k=detailed_k_search)
    filtered = []
    for d in candidate_detailed:
        cid = d.metadata.get("chunk_id")
        if cid in allowed_chunk_ids:
            filtered.append(d)

    final_docs = filtered[:top_k_final_chunks]

    print(f"[RETRIEVAL] {len(filtered)} chunks matched the summaries. Returning top {len(final_docs)}.")
    return final_docs


## 5. Chatbot and Memory Management

class RAGChatbot:
    def __init__(
        self,
        pipe,
        tokenizer,
        summary_store,
        detailed_store,
        context_window,
        context_reservation
    ):
        self.pipe = pipe
        self.tokenizer = tokenizer
        self.summary_store = summary_store
        self.detailed_store = detailed_store
        self.context_window = context_window
        self.context_reservation = context_reservation

        self.memory: List[str] = []
        self.memory_summary: str = ""

    def _get_current_history(self) -> str:
        """Combines the memory summary and the current memory buffer."""
        history_parts = []
        if self.memory_summary:
            history_parts.append(f"**PREVIOUS CONVERSATION SUMMARY:**\n{self.memory_summary}\n")

        history_parts.extend(self.memory)
        return "\n".join(history_parts)

    def _condense_memory(self):
        """Condenses the current memory buffer into a summary using the RAG model."""
        full_history = self._get_current_history()

        if len(self.tokenizer.encode(full_history, return_tensors='pt')[0]) < self.context_reservation:
            print("[MEMORY] History is short, skipping condensation.")
            return

        print("[MEMORY] Condensing conversation history...")

        prompt = (
            f"<start_of_turn>user\nCondense the following conversation history into a single, comprehensive paragraph that preserves all factual details and context. This summary will be used to answer future questions.\n\nCONVERSATION HISTORY:\n{full_history}\n\nCONDENSED SUMMARY:<end_of_turn>\n<start_of_turn>model\n"
        )

        prompt_len = len(self.tokenizer.encode(prompt, return_tensors='pt')[0])
        max_new_tokens = min(256, self.context_window - prompt_len - 10)

        if max_new_tokens <= 0:
            print("[WARN] Conversation history too long for condensation. Resetting memory.")
            self.memory_summary = ""
        else:
            outputs = self.pipe(
                prompt,
                max_new_tokens=max_new_tokens,
                do_sample=False
            )
            generated = outputs[0]["generated_text"]
            condensed = generated.split("<start_of_turn>model\n", 1)[-1].strip()
            if "<end_of_turn>" in condensed:
                condensed = condensed.split("<end_of_turn>", 1)[0].strip()

            self.memory_summary = condensed
            self.memory = []
            print("[MEMORY] History condensed successfully.")

    def _build_final_prompt(
        self,
        query: str,
        retrieved_docs: List[Document],
        chat_history_str: str,
    ) -> str:
        """Formats the final prompt using chat history and RAG context."""
        context = []
        for doc in retrieved_docs:
            meta = doc.metadata
            context.append(
                f"--Source (Page {meta.get('page_number', 'N/A')}, Chunk {meta.get('chunk_id', 'N/A')}): \n{doc.page_content.strip()}\n"
            )
        context_str = "\n".join(context)

        system_prompt = (
            "You are an expert RAG assistant. Your task is to answer the user's question \n"
            "based ONLY on the provided context and the conversation history. \n"
            "If the answer cannot be found in the provided context or history, state that you don't know. \n"
            "Always cite the source pages from the context, e.g., [Page X], at the end of the sentence where the fact is mentioned. \n"
        )

        prompt_template = (
            f"<start_of_turn>user\n{system_prompt}\n\n"
            f"**CONVERSATION HISTORY:**\n{chat_history_str}\n\n"
            f"**RETRIEVED CONTEXT:**\n{context_str}\n\n"
            f"**USER QUESTION:** {query}\n\n"
            f"Answer:<end_of_turn>\n<start_of_turn>model\n"
        )

        return prompt_template

    def chat(self, query: str):
        # 1. Retrieve context chunks for the query
        retrieved_docs = hierarchical_retrieve(
            query=query,
            summary_store=self.summary_store,
            detailed_store=self.detailed_store
        )

        # 2. Memory Management Loop
        while True:
            chat_history_str = self._get_current_history()
            final_prompt = self._build_final_prompt(query, retrieved_docs, chat_history_str)

            prompt_tokens = len(self.tokenizer.encode(final_prompt, return_tensors='pt')[0])
            response_max_tokens = 512
            total_length = prompt_tokens + response_max_tokens

            if total_length < self.context_window - 50:
                break
            else:
                if self.memory:
                    self._condense_memory()
                else:
                    if self.memory_summary:
                        print("[FATAL WARNING] Condensed summary is too large. Clearing summary.")
                        self.memory_summary = ""
                    else:
                        print("[ERROR] Final prompt exceeds context window even with no history. Reduce TOP_K or CHUNK_SIZE.")
                        return "Error: Context too large for model. Please re-run with smaller TOP_K or CHUNK_SIZE."

        # 3. Generate response
        print(f"[LLM] Total prompt tokens: {prompt_tokens}. Generating response...")
        outputs = self.pipe(
            final_prompt,
            max_new_tokens=response_max_tokens,
            do_sample=False,
        )
        generated = outputs[0]["generated_text"]

        response = generated.split("<start_of_turn>model\n", 1)[-1].strip()
        if "<end_of_turn>" in response:
            response = response.split("<end_of_turn>", 1)[0].strip()

        # 4. Update memory
        self.memory.append(f"**USER:** {query}")
        self.memory.append(f"**ASSISTANT:** {response}")

        print("\n" + "=" * 70 + "\n")
        print(f"[FINAL RESPONSE]\n{response}")
        print("\n" + "=" * 70 + "\n")

        if self.memory_summary:
             print(f"[MEMORY STATUS] Condensed Summary is active: '{self.memory_summary[:100]}...'")
        print(f"[MEMORY STATUS] Current Buffer Turns: {len(self.memory) // 2}\n")

        return response

# Initialize the Chatbot
chatbot = RAGChatbot(
    pipe=rag_pipe,
    tokenizer=rag_tokenizer,
    summary_store=summary_store,
    detailed_store=detailed_store,
    context_window=RAG_CONTEXT_WINDOW,
    context_reservation=CONTEXT_RESERVATION
)


[SETUP] Ensure your PDF (ShaastraContextDoc.pdf) is uploaded to /content/docs
[INFO] Loaded 45 page-level docs from /content/docs/ShaastraContextDoc.pdf
[INFO] Created 183 detailed chunks.
[INFO] Loading embedding model...


  embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL)


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

[INFO] Loading Summarization Model: google/gemma-2-2b-it...


tokenizer_config.json:   0%|          | 0.00/47.0k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/838 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json:   0%|          | 0.00/24.2k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/241M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/187 [00:00<?, ?B/s]

Device set to use cuda:0
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


[INFO] Building hierarchical indices with group_size=5...
[INFO] Summarizing group 1/37 (chunks 0..4)...
[INFO] Summarizing group 2/37 (chunks 5..9)...
[INFO] Summarizing group 3/37 (chunks 10..14)...
[INFO] Summarizing group 4/37 (chunks 15..19)...
[INFO] Summarizing group 5/37 (chunks 20..24)...
[INFO] Summarizing group 6/37 (chunks 25..29)...
[INFO] Summarizing group 7/37 (chunks 30..34)...
[INFO] Summarizing group 8/37 (chunks 35..39)...
[INFO] Summarizing group 9/37 (chunks 40..44)...
[INFO] Summarizing group 10/37 (chunks 45..49)...


You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset


[INFO] Summarizing group 11/37 (chunks 50..54)...
[INFO] Summarizing group 12/37 (chunks 55..59)...
[INFO] Summarizing group 13/37 (chunks 60..64)...
[INFO] Summarizing group 14/37 (chunks 65..69)...
[INFO] Summarizing group 15/37 (chunks 70..74)...
[INFO] Summarizing group 16/37 (chunks 75..79)...
[INFO] Summarizing group 17/37 (chunks 80..84)...
[INFO] Summarizing group 18/37 (chunks 85..89)...
[INFO] Summarizing group 19/37 (chunks 90..94)...
[INFO] Summarizing group 20/37 (chunks 95..99)...
[INFO] Summarizing group 21/37 (chunks 100..104)...
[INFO] Summarizing group 22/37 (chunks 105..109)...
[INFO] Summarizing group 23/37 (chunks 110..114)...
[INFO] Summarizing group 24/37 (chunks 115..119)...
[INFO] Summarizing group 25/37 (chunks 120..124)...
[INFO] Summarizing group 26/37 (chunks 125..129)...
[INFO] Summarizing group 27/37 (chunks 130..134)...
[INFO] Summarizing group 28/37 (chunks 135..139)...
[INFO] Summarizing group 29/37 (chunks 140..144)...
[INFO] Summarizing group 30/37 (

tokenizer_config.json:   0%|          | 0.00/34.2k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/694 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/20.9k [00:00<?, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/2.11G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

Device set to use cuda:0


[INFO] RAG Model (google/gemma-7b-it) pipeline ready. Proceed to chat.


In [5]:
print("\n--- Starting Chat Session ---\n")
print("--- User 1 ---")
chatbot.chat("Why should anyone attend shaastra?")
print("\n--- User 2 (Follow-up) ---")
chatbot.chat("What are the fun activities to do in IIT Madras campus?")
print("\n--- User 3 (New Topic) ---")
chatbot.chat("Why is shaastra popular?")
print("\n--- User 4 (New Topic) ---")
chatbot.chat("What are the best events to attend in shaastra?")


--- Starting Chat Session ---

--- User 1 ---
[RETRIEVAL] Performing hierarchical retrieval for query: 'Why should anyone attend shaastra?...'
[RETRIEVAL] 15 chunks matched the summaries. Returning top 5.
[LLM] Total prompt tokens: 1135. Generating response...


[FINAL RESPONSE]
The text does not provide information about why anyone should attend Shaastra, therefore I cannot answer the user's question.


[MEMORY STATUS] Current Buffer Turns: 4


--- User 2 (Follow-up) ---
[RETRIEVAL] Performing hierarchical retrieval for query: 'What are the fun activities to do in IIT Madras ca...'
[RETRIEVAL] 16 chunks matched the summaries. Returning top 5.
[LLM] Total prompt tokens: 1163. Generating response...


[FINAL RESPONSE]
The text does not mention any fun activities to do in the IIT Madras campus, therefore I cannot answer the user's question.


[MEMORY STATUS] Current Buffer Turns: 5


--- User 3 (New Topic) ---
[RETRIEVAL] Performing hierarchical retrieval for query: 'Why is shaastra pop

"The text does not provide information about the best events to attend in Shaastra, therefore I cannot answer the user's question."