In [None]:
!pip install -q langchain langchain-community langchain-text-splitters chromadb sentence-transformers transformers accelerate pypdf faiss-cpu


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 67.3/67.3 kB 5.1 MB/s eta 0:00:00
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.5/2.5 MB 62.4 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 21.4/21.4 MB 41.7 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 328.3/328.3 kB 30.6 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 23.7/23.7 MB 99.2 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 278.2/278.2 kB 26.5 MB/s eta 0:00

In [None]:
import os
from typing import List, Dict, Any

from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document

from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from sentence_transformers import CrossEncoder  # optional later if you want reranking

# ==== CONFIG ====
# Every knob a reader may want to tune lives in this cell.

# --- Input PDF (Colab: upload your PDF(s) into DOCS_FOLDER) ---
DOCS_FOLDER = "/content/docs"
PDF_FILENAME = "ShaastraContextDoc.pdf"  # <-- CHANGE THIS to your actual file name
PDF_PATH = os.path.join(DOCS_FOLDER, PDF_FILENAME)

# --- Models ---
# Sentence-transformers embedder for the FAISS indices
# (picked over MiniLM for better retrieval quality).
EMBED_MODEL = "sentence-transformers/all-mpnet-base-v2"
# Summarization LLM: instruction-tuned Gemma-2 2B.
GEMMA_MODEL_NAME = "google/gemma-2-2b-it"

# --- Chunking (values handed to RecursiveCharacterTextSplitter) ---
DETAILED_CHUNK_SIZE = 600    # max chunk length, measured by the splitter's default length function
DETAILED_CHUNK_OVERLAP = 50  # overlap carried between consecutive chunks

# --- Hierarchy: consecutive detailed chunks summarized together ---
SUMMARY_GROUP_SIZE = 5

In [None]:
def load_pdf_as_docs(pdf_path: str) -> List[Document]:
    """
    Load a PDF with PyPDFLoader into one Document per page and normalize metadata.

    Each returned Document is tagged with:
      * ``page_number`` -- copied from the loader's ``page`` metadata, or the
        page's position in the loaded list when ``page`` is missing.
        NOTE(review): PyPDFLoader's ``page`` is believed to be 0-based --
        confirm before presenting it as a human page number.
      * ``source`` -- the PDF's file name only (directory stripped).

    Raises:
        FileNotFoundError: if ``pdf_path`` does not exist.
    """
    if not os.path.exists(pdf_path):
        raise FileNotFoundError(f"PDF not found at: {pdf_path}")

    source_name = os.path.basename(pdf_path)
    pages = PyPDFLoader(pdf_path).load()
    for position, page_doc in enumerate(pages):
        meta = page_doc.metadata
        meta["page_number"] = meta.get("page", position)
        meta["source"] = source_name
    print(f"[INFO] Loaded {len(pages)} page-level docs from {pdf_path}")
    return pages


docs = load_pdf_as_docs(PDF_PATH)


[INFO] Loaded 45 page-level docs from /content/docs/ShaastraContextDoc.pdf


In [None]:
def make_detailed_chunks(
    docs: List[Document],
    chunk_size: int = DETAILED_CHUNK_SIZE,
    chunk_overlap: int = DETAILED_CHUNK_OVERLAP,
) -> List[Document]:
    """
    Split page-level Documents into small, overlapping "detailed" chunks.

    Splitting is delegated to RecursiveCharacterTextSplitter with the
    separator list: paragraph break, line break, sentence end (". "),
    space, and finally "" (individual characters).

    Every returned chunk is tagged with:
      * ``chunk_id`` -- its 0-based position in the returned list
        (``hierarchical_retrieve`` filters candidates on this key).
      * ``source`` -- inherited from the page, or "unknown" if missing.
      * ``page_number`` -- inherited from the page; if absent, taken from the
        ``page`` key (None when that is missing as well).
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        separators=["\n\n", "\n", ". ", " ", ""],
    )
    chunks = splitter.split_documents(docs)

    for chunk_id, chunk in enumerate(chunks):
        meta = chunk.metadata
        meta["chunk_id"] = chunk_id
        meta["source"] = meta.get("source", "unknown")
        meta.setdefault("page_number", meta.get("page", None))
    print(f"[INFO] Created {len(chunks)} detailed chunks.")
    return chunks


detailed_chunks = make_detailed_chunks(docs)


[INFO] Created 183 detailed chunks.


In [None]:
# A single embedding model is reused for both FAISS indices (summary + detailed),
# so queries and stored vectors always live in the same embedding space.
embeddings = HuggingFaceEmbeddings(
    model_name=EMBED_MODEL,
)
print("[INFO] Embedding model loaded:", EMBED_MODEL)


[INFO] Embedding model loaded: sentence-transformers/all-mpnet-base-v2


In [None]:
print("[INFO] Loading Gemma-2 2B Instruct for summarization...")

gemma_tokenizer = AutoTokenizer.from_pretrained(GEMMA_MODEL_NAME)
gemma_model = AutoModelForCausalLM.from_pretrained(
    GEMMA_MODEL_NAME,
    device_map="auto",  # let accelerate place the weights (GPU when available)
    # `dtype` replaces the deprecated `torch_dtype` kwarg; the installed
    # transformers warned "`torch_dtype` is deprecated! Use `dtype` instead!".
    # "auto" uses the dtype recorded in the checkpoint's config.
    dtype="auto",
)

# do_sample=False -> deterministic decoding, so summaries are reproducible
# across Restart & Run All.
gemma_pipe = pipeline(
    "text-generation",
    model=gemma_model,
    tokenizer=gemma_tokenizer,
    max_new_tokens=256,
    do_sample=False,
)

print("[INFO] Gemma-2 summarization pipeline ready.")


[INFO] Loading Gemma-2 2B Instruct for summarization...


tokenizer_config.json:   0%|          | 0.00/47.0k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/838 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json:   0%|          | 0.00/24.2k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/241M [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/187 [00:00<?, ?B/s]

Device set to use cuda:0
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


[INFO] Gemma-2 summarization pipeline ready.


In [None]:
def summarize_with_gemma(text: str, max_new_tokens: int = 256) -> str:
    """
    Summarize a text segment with the global Gemma-2 ``gemma_pipe``.

    The segment is wrapped in a fixed instruction prompt that ends in
    ``"Summary:"`` and is decoded deterministically (``do_sample=False``).
    The result is intended as a short, informative node for hierarchical
    retrieval.

    Args:
        text: The section to summarize (e.g. several concatenated chunks).
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        Only the model's continuation, with surrounding whitespace stripped.
        The echoed prompt is never part of the result, even when ``text``
        itself contains the string ``"Summary:"``.
    """
    prompt = (
        "You are a helpful assistant. Summarize the following section of a larger document "
        "into a concise paragraph capturing the key points that would help answer questions about it.\n\n"
        f"Text:\n{text}\n\n"
        "Summary:"
    )

    # Ask the text-generation pipeline for the newly generated text only.
    # Splitting the full output on a "Summary:" marker is unreliable: the
    # source text (or the model's own answer) may contain that marker too.
    outputs = gemma_pipe(
        prompt,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        return_full_text=False,
    )
    # pipeline returns a list of {'generated_text': ...}
    generated = outputs[0]["generated_text"]
    # Defensive: if a pipeline still echoes the prompt, drop it by exact
    # prefix rather than by searching for a marker string.
    if generated.startswith(prompt):
        generated = generated[len(prompt):]
    return generated.strip()


In [None]:
from math import ceil

def build_hierarchical_indices(
    detailed_chunks: List[Document],
    group_size: int = SUMMARY_GROUP_SIZE,
):
    """
    Build the two FAISS indices used by hierarchical retrieval.

    Consecutive runs of ``group_size`` detailed chunks are concatenated and
    summarized with ``summarize_with_gemma``. Each summary becomes one entry
    in the *summary* store, and its metadata records which detailed chunks it
    covers. Every detailed chunk is also indexed on its own in the *detailed*
    store. Both stores are embedded with the global ``embeddings`` model.

    Args:
        detailed_chunks: Chunks from ``make_detailed_chunks``; each is
            expected to carry a ``chunk_id`` in its metadata.
        group_size: Number of consecutive detailed chunks per summary.

    Returns:
        ``(summary_store, detailed_store)``, two FAISS vector stores.
        Summary metadata keys: ``summary_id`` (group index), ``source``
        (taken from the group's first chunk) and ``chunk_ids`` (list of the
        covered chunks' ``chunk_id`` values).

    Raises:
        ValueError: if ``group_size`` < 1 or ``detailed_chunks`` is empty.
    """
    if group_size < 1:
        raise ValueError(f"group_size must be >= 1, got {group_size}")
    if not detailed_chunks:
        raise ValueError("detailed_chunks is empty; nothing to index.")

    summary_texts = []
    summary_metadatas = []

    print(f"[INFO] Building hierarchical indices with group_size={group_size}...")

    num_chunks = len(detailed_chunks)
    num_groups = ceil(num_chunks / group_size)

    for g, start in enumerate(range(0, num_chunks, group_size)):
        end = min(start + group_size, num_chunks)
        group = detailed_chunks[start:end]

        combined_text = "\n\n".join(c.page_content for c in group)

        print(f"[INFO] Summarizing group {g+1}/{num_groups} (chunks {start}..{end-1})...")
        summary = summarize_with_gemma(combined_text)

        # Record the same ids that hierarchical_retrieve filters on (metadata
        # "chunk_id"). Fall back to list position only when a chunk lacks one,
        # so the summary->chunk mapping stays correct even if the list was
        # reordered or subset after the chunks were numbered.
        chunk_ids = [
            c.metadata.get("chunk_id", start + offset)
            for offset, c in enumerate(group)
        ]
        summary_texts.append(summary)
        summary_metadatas.append({
            "summary_id": g,
            "source": group[0].metadata.get("source", "unknown"),
            "chunk_ids": chunk_ids,
        })

    # Build FAISS stores
    print("[INFO] Building FAISS summary store...")
    summary_store = FAISS.from_texts(
        texts=summary_texts,
        embedding=embeddings,
        metadatas=summary_metadatas,
    )

    print("[INFO] Building FAISS detailed store...")
    detailed_texts = [c.page_content for c in detailed_chunks]
    # Copy each metadata dict so later edits to detailed_chunks' metadata
    # cannot leak into the indexed entries (or vice versa).
    detailed_metadatas = [dict(c.metadata) for c in detailed_chunks]

    detailed_store = FAISS.from_texts(
        texts=detailed_texts,
        embedding=embeddings,
        metadatas=detailed_metadatas,
    )

    print("[INFO] Hierarchical indices built.")
    return summary_store, detailed_store


summary_store, detailed_store = build_hierarchical_indices(detailed_chunks)


[INFO] Building hierarchical indices with group_size=5...
[INFO] Summarizing group 1/37 (chunks 0..4)...
[INFO] Summarizing group 2/37 (chunks 5..9)...
[INFO] Summarizing group 3/37 (chunks 10..14)...
[INFO] Summarizing group 4/37 (chunks 15..19)...
[INFO] Summarizing group 5/37 (chunks 20..24)...
[INFO] Summarizing group 6/37 (chunks 25..29)...
[INFO] Summarizing group 7/37 (chunks 30..34)...
[INFO] Summarizing group 8/37 (chunks 35..39)...
[INFO] Summarizing group 9/37 (chunks 40..44)...
[INFO] Summarizing group 10/37 (chunks 45..49)...
[INFO] Summarizing group 11/37 (chunks 50..54)...
[INFO] Summarizing group 12/37 (chunks 55..59)...
[INFO] Summarizing group 13/37 (chunks 60..64)...
[INFO] Summarizing group 14/37 (chunks 65..69)...
[INFO] Summarizing group 15/37 (chunks 70..74)...
[INFO] Summarizing group 16/37 (chunks 75..79)...
[INFO] Summarizing group 17/37 (chunks 80..84)...
[INFO] Summarizing group 18/37 (chunks 85..89)...
[INFO] Summarizing group 19/37 (chunks 90..94)...
[INFO

In [None]:
def hierarchical_retrieve(
    query: str,
    summary_store: FAISS,
    detailed_store: FAISS,
    top_k_summaries: int = 5,
    top_k_final_chunks: int = 5,
    detailed_k_search: int = 50,
) -> List[Document]:
    """
    Coarse-to-fine retrieval over the summary and detailed FAISS stores.

    Stages:
      1. Search ``summary_store`` for ``top_k_summaries`` summaries and
         collect every ``chunk_ids`` entry they cover.
      2. Search ``detailed_store`` for ``detailed_k_search`` candidates.
      3. Keep the candidates whose ``chunk_id`` is among the covered ids
         (preserving the order returned by ``similarity_search``) and take
         the first ``top_k_final_chunks``.
      4. If fewer than ``top_k_final_chunks`` survive, top up with the
         remaining unfiltered candidates, in search order.

    If the retrieved summaries carry no ``chunk_ids`` at all, a flat search
    of ``detailed_store`` is returned instead. Progress is printed to stdout.
    """
    print(f"\n[INFO] Query: {query}\n")

    # --- Stage 1: coarse search over the summaries ---
    summary_hits = summary_store.similarity_search(query, k=top_k_summaries)
    print(f"[INFO] Retrieved {len(summary_hits)} summary nodes.\n")

    covered_ids = set()
    for rank, hit in enumerate(summary_hits, start=1):
        info = hit.metadata
        ids = info.get("chunk_ids", [])
        covered_ids.update(ids)
        print(f"  Summary {rank}: summary_id={info.get('summary_id')}, covers chunks {ids[:5]}{'...' if len(ids) > 5 else ''}")

    if not covered_ids:
        print("[WARN] No chunk_ids found in summary metadata; falling back to flat detailed search.")
        return detailed_store.similarity_search(query, k=top_k_final_chunks)

    # --- Stage 2: fine search across every detailed chunk ---
    candidates = detailed_store.similarity_search(query, k=detailed_k_search)
    print(f"\n[INFO] Retrieved {len(candidates)} candidate detailed chunks before filtering.")

    # --- Stage 3: keep only candidates inside the selected regions ---
    in_region = [
        cand for cand in candidates
        if cand.metadata.get("chunk_id") in covered_ids
    ]
    print(f"[INFO] {len(in_region)} detailed chunks remain after hierarchical filtering.")

    results = in_region[:top_k_final_chunks]
    if len(results) >= top_k_final_chunks:
        return results

    # --- Stage 4: back-fill from the unfiltered ranking ---
    print("[WARN] Not enough filtered chunks; filling from unfiltered candidates.")
    for cand in candidates:
        if len(results) >= top_k_final_chunks:
            break
        if cand not in results:
            results.append(cand)
    return results


In [None]:
# You can directly set the query here
# Interactive test query: edit the string and re-run this cell.
query = "Tell me about all the events on day 1"  # <-- change this to test different questions

retrieved_chunks = hierarchical_retrieve(
    query,
    summary_store,
    detailed_store,
    top_k_summaries=5,
    top_k_final_chunks=5,
    detailed_k_search=50,
)

# Show each retrieved chunk with its provenance (source file, page, chunk id).
print("\n================ RETRIEVED CHUNKS ================\n")
for i, doc in enumerate(retrieved_chunks, start=1):
    meta = doc.metadata
    print(
        f"🔹 Chunk {i}",
        f"Source: {meta.get('source', 'unknown')} | page: {meta.get('page_number')} | chunk_id: {meta.get('chunk_id')}",
        "-" * 70,
        doc.page_content.strip(),
        "\n" + "=" * 70 + "\n",
        sep="\n",
    )



[INFO] Query: Tell me about all the events on day 1

[INFO] Retrieved 5 summary nodes.

  Summary 1: summary_id=0, covers chunks [0, 1, 2, 3, 4]
  Summary 2: summary_id=1, covers chunks [5, 6, 7, 8, 9]
  Summary 3: summary_id=23, covers chunks [115, 116, 117, 118, 119]
  Summary 4: summary_id=3, covers chunks [15, 16, 17, 18, 19]
  Summary 5: summary_id=4, covers chunks [20, 21, 22, 23, 24]

[INFO] Retrieved 50 candidate detailed chunks before filtering.
[INFO] 15 detailed chunks remain after hierarchical filtering.


🔹 Chunk 1
Source: ShaastraContextDoc.pdf | page: 2 | chunk_id: 5
----------------------------------------------------------------------
JMT (Build-a-thon) NAC Hall 7:00-21:00 
Caterpillar Autonomy Challenge KV Grounds 6:00-21:00 
Flipkart Grid Robotics 6.0 Newton Hall 7:00-22:00 
Shaastra Moot Court Finals Kalam Hall (ED Dept) 9:30-13:00 
Brain Maze KV 9:00-18:00 
Product Management Workshop CRC 201 9:00-17:00 
Shaastra Main Quiz CRC 202 9:00-14:00 
Finvent CRC 204 9:00-