# RAG pipeline with Chroma, Sentence-Transformers, and FLAN-T5 (Colab-ready)

This Colab notebook implements:
- OCR-aware PDF extraction (handles flattened PDFs) using `pdfplumber` + `pytesseract`.
- Text extraction from DOCX and images.
- Chunking and per-document JSON export.
- Embeddings with `sentence-transformers` (recommended: `all-mpnet-base-v2`).
- Vector storage in Chroma (local persistent folder).
- Retrieval (query → embedding → top-K).
- RAG answer synthesis using `google/flan-t5-base`.
- Optional exact-span extraction using a QA model.

**How to use**
1. Open this notebook in Google Colab.
2. (Optional) Change runtime to GPU for faster generation: Runtime → Change runtime type → GPU.
3. Run cells sequentially. Upload files when prompted (or mount Google Drive and modify paths).
4. Use the `rag_answer` function to ask questions.



In [1]:
# Install required packages (Colab)
!pip install -q sentence-transformers chromadb pdfplumber python-docx pytesseract pillow transformers accelerate sentencepiece
# install tesseract binary (Colab / Debian) and poppler-utils for PDF rasterization
!apt-get update -qq && apt-get install -y -qq tesseract-ocr libtesseract-dev poppler-utils


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m67.3/67.3 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m67.7/67.7 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m21.4/21.4 MB[0m [31m73.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.0/60.0 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.6/5.6 MB[0m [31m67.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m253.0/253.0 kB[0m [31m14.5 MB/s[0m eta [36m

In [2]:
# Imports, paths, and settings
import os, json, uuid
from pathlib import Path
from tqdm import tqdm

# Text extraction libs
import pdfplumber, docx
from PIL import Image, ImageFilter, ImageOps
import pytesseract

# Embeddings
from sentence_transformers import SentenceTransformer

# Chroma
import chromadb
from chromadb.config import Settings

# Generation & QA
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import torch

# Folders
DB_DIR = "/content/chroma_db"
JSON_OUTPUT_DIR = "/content/doc_jsons"
os.makedirs(DB_DIR, exist_ok=True)
os.makedirs(JSON_OUTPUT_DIR, exist_ok=True)

# Embedding model (choose accuracy vs speed)
EMBEDDING_MODEL_NAME = "all-mpnet-base-v2"   # recommended for accuracy
# EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"  # alternative (faster)

# Generation model (RAG fusion / answer synthesis)
GEN_MODEL_NAME = "google/flan-t5-base"  # use GPU for larger models

# QA span-extraction model (optional)
QA_MODEL_NAME = "deepset/roberta-base-squad2"

# Tesseract settings
pytesseract.pytesseract.tesseract_cmd = "/usr/bin/tesseract"  # colab default
OCR_LANG = "eng"
TESSERACT_COMMON_CONFIG = f"--oem 1 --psm 6 -l {OCR_LANG}"

print("Embedding model:", EMBEDDING_MODEL_NAME)
print("Generation model:", GEN_MODEL_NAME)
print("DB_DIR:", DB_DIR)


Embedding model: all-mpnet-base-v2
Generation model: google/flan-t5-base
DB_DIR: /content/chroma_db


In [3]:
# OCR-aware PDF extractor and helpers

def preprocess_pil_image_for_ocr(pil_img):
    img = pil_img.convert("L")
    img = ImageOps.autocontrast(img, cutoff=1)
    img = img.filter(ImageFilter.MedianFilter(size=3))
    return img

def extract_text_from_pdf(pdf_path, ocr_if_needed=True, min_text_len_for_layer=50, dpi=300):
    all_text = []
    with pdfplumber.open(pdf_path) as pdf:
        for page_index, page in enumerate(pdf.pages):
            try:
                text = page.extract_text()
            except Exception as e:
                print(f"page.extract_text() error on page {page_index+1}: {e}")
                text = None

            if (not text or len(text.strip()) < min_text_len_for_layer) and ocr_if_needed:
                print(f"Page {page_index+1}: Running OCR at {dpi} DPI...")
                try:
                    page_image = page.to_image(resolution=dpi).original
                except Exception as e:
                    print(f"page.to_image error fallback for page {page_index+1}: {e}")
                    page_image = page.to_image().original
                page_image = preprocess_pil_image_for_ocr(page_image)
                try:
                    text = pytesseract.image_to_string(page_image, config=TESSERACT_COMMON_CONFIG)
                except Exception as e:
                    print(f"Tesseract OCR failed on page {page_index+1}: {e}")
                    text = ""
            if not text:
                text = ""
            all_text.append(f"\\n--- PAGE {page_index+1} ---\\n{text}")
    return "\\n".join(all_text)

def extract_text_from_docx(path):
    try:
        doc = docx.Document(path)
        paragraphs = [p.text for p in doc.paragraphs if p.text and p.text.strip()]
        return "\\n".join(paragraphs)
    except Exception as e:
        print("DOCX read error:", e)
        return ""

def extract_text_from_image(path):
    try:
        img = Image.open(path)
        img = preprocess_pil_image_for_ocr(img)
        return pytesseract.image_to_string(img, config=TESSERACT_COMMON_CONFIG)
    except Exception as e:
        print("Image OCR error:", e)
        return ""


In [4]:
# Upload files (Colab interactive)
from google.colab import files
uploaded = files.upload()
uploaded_filenames = list(uploaded.keys())
print("Uploaded:", uploaded_filenames)

Saving embbedings.docx to embbedings.docx
Uploaded: ['embbedings.docx']


In [22]:
# Chunking, process files to JSON and prepare records
def chunk_text(text, chunk_size=80, overlap=20):
    tokens = text.split()
    chunks = []
    i = 0
    while i < len(tokens):
        chunk_tokens = tokens[i:i+chunk_size]
        chunks.append(" ".join(chunk_tokens))
        i += chunk_size - overlap
    return chunks

def process_and_export(files_list):
    records = []
    for fname in files_list:
        path = Path(fname)
        ext = path.suffix.lower()
        if ext == ".pdf":
            raw = extract_text_from_pdf(str(path))
        elif ext in [".docx", ".doc"]:
            raw = extract_text_from_docx(str(path))
        elif ext in [".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp"]:
            raw = extract_text_from_image(str(path))
        else:
            print("Skipping unsupported:", fname)
            continue

        if not raw or len(raw.strip()) == 0:
            print("No text extracted for:", fname)
            continue

        metadata = {"source_filename": fname, "id": str(uuid.uuid4()), "n_chars": len(raw)}
        chunks = chunk_text(raw, chunk_size=80, overlap=20)
        doc_json = {"metadata": metadata, "full_text": raw, "chunks": []}
        for idx, chunk in enumerate(chunks):
            chunk_id = f"{metadata['id']}_chunk_{idx}"
            doc_json["chunks"].append({"chunk_id": chunk_id, "text": chunk, "chunk_index": idx})
            records.append({"id": chunk_id, "text": chunk, "metadata": {**metadata, "chunk_index": idx}})
        outpath = Path(JSON_OUTPUT_DIR) / (path.stem + ".json")
        with open(outpath, "w", encoding="utf-8") as f:
            json.dump(doc_json, f, ensure_ascii=False, indent=2)
        print("Wrote JSON:", outpath)
    return records

# Run processing on uploaded files
records = process_and_export(uploaded_filenames)
print("Total chunks prepared:", len(records))

Wrote JSON: /content/doc_jsons/embbedings.json
Total chunks prepared: 6


In [23]:
# Load embedding model and create embeddings
embed_model = SentenceTransformer(EMBEDDING_MODEL_NAME)
texts = [r["text"] for r in records]
ids = [r["id"] for r in records]
metadatas = [r["metadata"] for r in records]

BATCH = 64
embeddings = []
for i in tqdm(range(0, len(texts), BATCH), desc="Embedding"):
    batch_texts = texts[i:i+BATCH]
    embs = embed_model.encode(batch_texts, show_progress_bar=False, convert_to_numpy=True)
    embeddings.extend(embs)
print("Embeddings vectors:", len(embeddings))

Embedding: 100%|██████████| 1/1 [00:00<00:00,  4.69it/s]

Embeddings vectors: 6





In [24]:
# Create Chroma DB and insert vectors
client = chromadb.PersistentClient(path=DB_DIR, settings=Settings())
collection_name = "customer_docs"
try:
    collection = client.get_or_create_collection(collection_name)
except Exception:
    collection = client.create_collection(name=collection_name)

# prepare embeddings to lists of floats (Chroma expects python floats)
vecs = [e.tolist() if hasattr(e, "tolist") else list(map(float, e)) for e in embeddings]

collection.add(ids=ids, metadatas=metadatas, documents=texts, embeddings=vecs)
print("Chroma persisted at", DB_DIR)

Chroma persisted at /content/chroma_db


In [25]:
# Retrieval helper
def search_chroma(query_text, top_k=3):
    q_emb = embed_model.encode([query_text], convert_to_numpy=True)[0].astype(float).tolist()
    results = collection.query(query_embeddings=[q_emb], n_results=top_k, include=["documents","metadatas","distances"])
    hits = []
    for i in range(len(results["ids"][0])):
        hits.append({
            "id": results["ids"][0][i],
            "document": results["documents"][0][i],
            "metadata": results["metadatas"][0][i],
            "distance": results["distances"][0][i]
        })
    return hits

In [26]:
# Load generation model (FLAN-T5) for RAG fusion
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Generation device:", device)

gen_tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL_NAME)
gen_model = AutoModelForSeq2SeqLM.from_pretrained(GEN_MODEL_NAME).to(device)

def tokenize_len(text):
    return len(gen_tokenizer.encode(text, truncation=False))
def truncate_context_by_tokens(chunks_texts, max_input_tokens, sep="\n\n---\n\n"):
    out = []
    tokens = 0
    for t in chunks_texts:
        t_tokens = len(gen_tokenizer.encode(t + sep, truncation=False))
        if tokens + t_tokens > max_input_tokens:
            break
        out.append(t)
        tokens += t_tokens
    return sep.join(out)

Generation device: cuda


In [32]:
# Prompt template and RAG answer function
PROMPT_SYSTEM = """You are a helpful assistant. Answer the user question using ONLY the CONTEXT provided.
If the context does not contain enough information to answer, respond: INSUFFICIENT_CONTEXT.
Be concise (max ~120 words). Provide the answer and then list the sources (filename and chunk index)."""

PROMPT_USER_TEMPLATE = """
CONTEXT:
{context}

QUESTION:
{question}

INSTRUCTIONS:
1) Give a short direct answer (<120 words).
2) After the answer print a "SOURCES:" section listing each source as - filename (chunk_index).
3) If you can't answer from the context, respond exactly: INSUFFICIENT_CONTEXT
4) Do not provide answers exactly in json format of in any strucutred manner, provide them in natural language by first processing it.
"""

def build_prompt(context, question):
    return PROMPT_SYSTEM + "\n\n" + PROMPT_USER_TEMPLATE.format(context=context, question=question)

def rag_answer(question, top_k=3, max_context_tokens=1500, max_answer_tokens=180):
    # 1. Retrieve
    hits = search_chroma(question, top_k=top_k)
    if not hits:
        return {"question": question, "answer": "INSUFFICIENT_CONTEXT", "provenance": [], "used_context": ""}

    # 2. Build entries with provenance
    entries = []
    provenance = []
    for h in hits:
        src = h['metadata'].get('source_filename', 'unknown')
        idx = h['metadata'].get('chunk_index', -1)
        entry_text = f"[{src} | chunk {idx}]\n{h['document']}"
        entries.append(entry_text)
        provenance.append({"source": src, "chunk_index": idx, "distance": h['distance'], "id": h['id']})

    # 3. Truncate context to token budget
    context = truncate_context_by_tokens(entries, max_context_tokens, sep="\n\n---\n\n")
    prompt = build_prompt(context, question)

    # 4. Tokenize & generate
    inputs = gen_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048).to(device)
    outputs = gen_model.generate(
        **inputs,
        max_new_tokens=max_answer_tokens,
        num_beams=4,
        do_sample=False,
        early_stopping=True
    )
    answer = gen_tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

    if not answer:
        answer = "INSUFFICIENT_CONTEXT"

    return {"question": question, "answer": answer, "provenance": provenance, "used_context": context}

In [33]:
# Optional: QA span extraction for exact fields
qa_pipeline = pipeline("question-answering", model=QA_MODEL_NAME, tokenizer=QA_MODEL_NAME, device=0 if torch.cuda.is_available() else -1)

def extract_exact_span(question, top_hit):
    context = top_hit['document']
    res = qa_pipeline(question=question, context=context, top_k=1)
    return res

Device set to use cuda:0


In [36]:
# Examples / Usage

# Retrieval-only example:
q = "What is the borrower’s Social Security Number?"
hits = search_chroma(q, top_k=3)
for i,h in enumerate(hits,1):
    print(f"HIT {i}: {h['metadata']['source_filename']} chunk {h['metadata']['chunk_index']}")
    print(h['document'][:500].replace("\n"," "))
    print("distance:", h['distance'])
    print("-----\n")

# RAG: generate concise, grounded answer
out = rag_answer(q, top_k=3)
print("ANSWER:\n", out["answer"])
print("\nPROVENANCE:\n", out["provenance"])

# Optional: exact span from top hit
if hits:
    span = extract_exact_span(q, hits[0])
    print("\nExact span extracted by QA model:", span)

HIT 1: embbedings.docx chunk 0
{\n "borrower": {\n "name": "JACK MERIDITH SPECTOR",\n "social_security_number": "15-52-556",\n "date_of_birth": "02/02/2000",\n "citizenship": "U.S. Citizen",\n "alternate_names": ["JACK MERIDITH SPECTOR"],\n "marital_status": "Unmarried",\n "dependents": {\n "number": 1,\n "ages": [25]\n },\n "contact_information": {\n "home_phone": "252-252-3564",\n "cell_phone": "252-252-6543",\n "work_phone": "252-262-3654 x5635",\n "email": "JACK@GMAIL.COM"\n },\n "current_address": {\n "street": "BAKER'S 
distance: 1.1984632015228271
-----

HIT 2: embbedings.docx chunk 0
{\n "borrower": {\n "name": "JACK MERIDITH SPECTOR",\n "social_security_number": "15-52-556",\n "date_of_birth": "02/02/2000",\n "citizenship": "U.S. Citizen",\n "alternate_names": ["JACK MERIDITH SPECTOR"],\n "marital_status": "Unmarried",\n "dependents": {\n "number": 1,\n "ages": [25]\n },\n "contact_information": {\n "home_phone": "252-252-3564",\n "cell_phone": "252-252-6543",\n "work_phone": 

In [42]:
# Examples / Usage

# Retrieval-only example:
q = "What is the borrower’s base income?"
hits = search_chroma(q, top_k=3)
for i,h in enumerate(hits,1):
    print(f"HIT {i}: {h['metadata']['source_filename']} chunk {h['metadata']['chunk_index']}")
    print(h['document'][:500].replace("\n"," "))
    print("distance:", h['distance'])
    print("-----\n")

# RAG: generate concise, grounded answer
out = rag_answer(q, top_k=3)
print("ANSWER:\n", out["answer"])
print("\nPROVENANCE:\n", out["provenance"])

# Optional: exact span from top hit
if hits:
    span = extract_exact_span(q, hits[0])
    print("\nExact span extracted by QA model:", span)

HIT 1: embbedings.docx chunk 1
"1 year",\n "employed_by_family_member": false,\n "income": {\n "base": 50000,\n "overtime": 2000,\n "bonus": 1000,\n "commission": 0,\n "military_entitlements": 20000,\n "other": 0,\n "monthly_total_income": 73000\n }\n }\n },\n "assets": {\n "bank_and_investment_accounts": [\n {\n "type": "Individual Development Account",\n "institution": "USA1",\n "value": 10000\n },\n {\n "type": "Cash Value of Life Insurance",\n "institution": "LIC",\n "value": 20000\n }\n ],\n "total_asset_value": 30000\n 
distance: 1.3143130540847778
-----

HIT 2: embbedings.docx chunk 5
"loan_originator_information": {\n "organization_name": "USA BANK",\n "organization_address": "WALL STREET, NEW YORK, USA",\n "organization_nmlsr_id": "25612DE23",\n "organization_state_license": "SDFS5642",\n "originator_name": "MAX VESTAPPEREN",\n "originator_nmlsr_id": "DF23521",\n "originator_state_license": "NY5623",\n "email": "MAX@GMAIL.COM",\n "phone": "231-568-9999",\n "signature_date": "1

In [43]:
# Examples / Usage

# Retrieval-only example:
q = "What is the total monthly income?"
hits = search_chroma(q, top_k=3)
for i,h in enumerate(hits,1):
    print(f"HIT {i}: {h['metadata']['source_filename']} chunk {h['metadata']['chunk_index']}")
    print(h['document'][:500].replace("\n"," "))
    print("distance:", h['distance'])
    print("-----\n")

# RAG: generate concise, grounded answer
out = rag_answer(q, top_k=3)
print("ANSWER:\n", out["answer"])
print("\nPROVENANCE:\n", out["provenance"])

# Optional: exact span from top hit
if hits:
    span = extract_exact_span(q, hits[0])
    print("\nExact span extracted by QA model:", span)

HIT 1: embbedings.docx chunk 2
"position": "MANAGER",\n "start_date": "02/02/2020",\n "experience": "1 year",\n "employed_by_family_member": false,\n "income": {\n "base": 50000,\n "overtime": 2000,\n "bonus": 1000,\n "commission": 0,\n "military_entitlements": 20000,\n "other": 0,\n "monthly_total_income": 73000\n }\n }\n },\n "assets": {\n "bank_and_investment_accounts": [\n {\n "type": "Individual Development Account",\n "institution": "USA1",\n "value": 10000\n },\n {\n "type": "Cash Value of Life Insurance",\n "instituti
distance: 1.245530366897583
-----

HIT 2: embbedings.docx chunk 1
"1 year",\n "employed_by_family_member": false,\n "income": {\n "base": 50000,\n "overtime": 2000,\n "bonus": 1000,\n "commission": 0,\n "military_entitlements": 20000,\n "other": 0,\n "monthly_total_income": 73000\n }\n }\n },\n "assets": {\n "bank_and_investment_accounts": [\n {\n "type": "Individual Development Account",\n "institution": "USA1",\n "value": 10000\n },\n {\n "type": "Cash Value of 

In [44]:
# Examples / Usage

# Retrieval-only example:
q = "What is the total asset value?"
hits = search_chroma(q, top_k=3)
for i,h in enumerate(hits,1):
    print(f"HIT {i}: {h['metadata']['source_filename']} chunk {h['metadata']['chunk_index']}")
    print(h['document'][:500].replace("\n"," "))
    print("distance:", h['distance'])
    print("-----\n")

# RAG: generate concise, grounded answer
out = rag_answer(q, top_k=3)
print("ANSWER:\n", out["answer"])
print("\nPROVENANCE:\n", out["provenance"])

# Optional: exact span from top hit
if hits:
    span = extract_exact_span(q, hits[0])
    print("\nExact span extracted by QA model:", span)

HIT 1: embbedings.docx chunk 2
"position": "MANAGER",\n "start_date": "02/02/2020",\n "experience": "1 year",\n "employed_by_family_member": false,\n "income": {\n "base": 50000,\n "overtime": 2000,\n "bonus": 1000,\n "commission": 0,\n "military_entitlements": 20000,\n "other": 0,\n "monthly_total_income": 73000\n }\n }\n },\n "assets": {\n "bank_and_investment_accounts": [\n {\n "type": "Individual Development Account",\n "institution": "USA1",\n "value": 10000\n },\n {\n "type": "Cash Value of Life Insurance",\n "instituti
distance: 1.196429967880249
-----

HIT 2: embbedings.docx chunk 1
"1 year",\n "employed_by_family_member": false,\n "income": {\n "base": 50000,\n "overtime": 2000,\n "bonus": 1000,\n "commission": 0,\n "military_entitlements": 20000,\n "other": 0,\n "monthly_total_income": 73000\n }\n }\n },\n "assets": {\n "bank_and_investment_accounts": [\n {\n "type": "Individual Development Account",\n "institution": "USA1",\n "value": 10000\n },\n {\n "type": "Cash Value of 

You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset


ANSWER:
 30000

PROVENANCE:
 [{'source': 'embbedings.docx', 'chunk_index': 2, 'distance': 1.196429967880249, 'id': '286074e1-4d7d-4ce0-a72c-a236ae5831e9_chunk_2'}, {'source': 'embbedings.docx', 'chunk_index': 1, 'distance': 1.2140611410140991, 'id': '941308fb-ee47-44a3-8fe4-30cddf524581_chunk_1'}, {'source': 'embbedings.docx', 'chunk_index': 3, 'distance': 1.3899924755096436, 'id': '286074e1-4d7d-4ce0-a72c-a236ae5831e9_chunk_3'}]

Exact span extracted by QA model: {'score': 0.8957427574787289, 'start': 561, 'end': 566, 'answer': '30000'}


In [None]:
# Save DB and JSONs for download (optional)
!zip -r -q /content/chroma_db.zip /content/chroma_db
!zip -r -q /content/doc_jsons.zip /content/doc_jsons
print("Zipped at /content/chroma_db.zip and /content/doc_jsons.zip")

Zipped at /content/chroma_db.zip and /content/doc_jsons.zip
