<a href="https://colab.research.google.com/github/SreeTatikonda/RAG_projects/blob/main/Multi_Modal_RAG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!apt-get update
!apt-get install -y openjdk-21-jdk
!update-alternatives --install /usr/bin/java java /usr/lib/jvm/java-21-openjdk-amd64/bin/java 1
!update-alternatives --config java
!java -version


In [None]:
# 0) Basic setup
!pip -q install pymupdf pdfplumber pytesseract sentence-transformers faiss-cpu pyserini pillow transformers accelerate bitsandbytes
import os, json, math, re, io
import fitz  # PyMuPDF
import pdfplumber
from PIL import Image
from sentence_transformers import SentenceTransformer
import numpy as np
import faiss

# Choose small, free models
# TEXT_EMB embeds the text chunks and the queries searched against them
# (default model of build_dense_index, used directly in retrieve).
TEXT_EMB = SentenceTransformer("BAAI/bge-small-en-v1.5")
# CAPTION_EMB embeds figure captions and the queries searched against them
# (default model of build_caption_index, used directly in retrieve).
# NOTE(review): E5 checkpoints are usually used with "query: " / "passage: "
# input prefixes; none are added anywhere in this notebook -- confirm whether
# that is intended.
CAPTION_EMB = SentenceTransformer("intfloat/e5-small")
# For generation, we'll select later based on GPU


In [None]:
# 1) Parse PDF into text sections and figure crops with captions
def parse_pdf(pdf_path, dpi=200):
    """Open a PDF with PyMuPDF and collect the layout blocks of each page.

    Args:
        pdf_path: Path passed to ``fitz.open``.
        dpi: Accepted for interface compatibility; not used by this function.

    Returns:
        A ``(doc, pages)`` tuple. ``doc`` is the document object, returned
        still open. ``pages`` is a list of ``{"index": int, "blocks": list}``
        dicts, one per page in document order, where ``blocks`` is the result
        of ``page.get_text("blocks")``.
    """
    doc = fitz.open(pdf_path)
    pages = [
        # "blocks" mode preserves the page's layout blocks.
        {"index": page_no, "blocks": page.get_text("blocks")}
        for page_no, page in enumerate(doc)
    ]
    return doc, pages

def extract_figures_with_captions(pdf_path):
    """Crop every embedded image out of a PDF and pair it with a caption line.

    Caption detection is a heuristic: a caption is any text line on the page
    that starts with ``Fig.`` or ``Figure`` followed by a number. Every image
    on a page receives the FIRST such line on that page (not the spatially
    nearest one), or ``""`` when the page has none.

    Args:
        pdf_path: Path passed to ``pdfplumber.open``.

    Returns:
        A list of dicts with keys ``"page"`` (0-based page index), ``"bbox"``
        (the ``(x0, top, x1, bottom)`` box that was actually cropped),
        ``"caption"`` (str) and ``"image_bytes"`` (PNG-encoded crop rendered
        at 200 dpi).
    """
    caption_re = re.compile(r"^\s*(Fig\.|Figure)\s*\d+")
    figures = []
    with pdfplumber.open(pdf_path) as pdf:
        for pi, page in enumerate(pdf.pages):
            text = page.extract_text() or ""
            caption_candidates = [line for line in text.split("\n") if caption_re.match(line)]
            # The caption choice does not depend on the image, so pick it once per page.
            cap = caption_candidates[0] if caption_candidates else ""

            page_x0, page_top, page_x1, page_bottom = page.bbox
            # Rough image detection via page.images (bbox)
            for im in page.images:
                # Image boxes can extend past the page; page.crop() rejects
                # such boxes, so clamp to the page bounds first.
                bbox = (
                    max(im["x0"], page_x0),
                    max(im["top"], page_top),
                    min(im["x1"], page_x1),
                    min(im["bottom"], page_bottom),
                )
                if bbox[2] <= bbox[0] or bbox[3] <= bbox[1]:
                    # Nothing of the image lies on the page; there is nothing to crop.
                    continue
                crop = page.crop(bbox).to_image(resolution=200)
                img_bytes = io.BytesIO()
                crop.save(img_bytes, format="PNG")
                figures.append({
                    "page": pi, "bbox": bbox, "caption": cap, "image_bytes": img_bytes.getvalue()
                })
    return figures


In [None]:
# 2) Build text chunks with overlap and link figure callouts
def blocks_to_paragraphs(pages):
    paras = []
    pid = 0
    for p in pages:
        # Join block text, keep page index
        page_text = "\n".join([b[4] for b in p["blocks"] if isinstance(b[4], str) and b[4].strip()])
        # Split by double newline as a crude paragraph boundary
        for para in re.split(r"\n{2,}", page_text):
            para = para.strip()
            if len(para) > 0:
                paras.append({"paragraph_id": pid, "page": p["index"], "text": para})
                pid += 1
    return paras

def _chunk_from_paragraphs(buf):
    """Merge a buffer of paragraph records into a single chunk record."""
    return {
        "text": " ".join(p["text"] for p in buf),
        # Sorted so the page list is deterministic; a bare list(set) has
        # arbitrary order.
        "pages": sorted({p["page"] for p in buf}),
        "paragraph_ids": [p["paragraph_id"] for p in buf],
    }


def make_chunks(paras, min_tokens=180, max_tokens=600):
    """Group consecutive paragraphs into chunks with one paragraph of overlap.

    Size is measured in whitespace-separated words as a proxy for tokens.
    Paragraphs accumulate in a buffer; when adding the next paragraph would
    push the buffer past ``max_tokens`` words, the buffer is emitted as a
    chunk and the next buffer starts with the last emitted paragraph
    (the overlap) followed by the new paragraph. Paragraphs are never split,
    so a chunk can exceed ``max_tokens``.

    Args:
        paras: List of ``{"paragraph_id", "page", "text"}`` dicts.
        min_tokens: Accepted for interface compatibility; currently unused.
        max_tokens: Word budget that triggers emitting the current buffer.

    Returns:
        A list of ``{"text", "pages", "paragraph_ids"}`` dicts, where
        ``pages`` is the sorted list of distinct pages in the chunk.
    """
    chunks = []
    buf = []
    wcount = 0
    for para in paras:
        n_words = len(para["text"].split())
        if buf and wcount + n_words > max_tokens:
            chunks.append(_chunk_from_paragraphs(buf))
            # Start the next buffer with the previous paragraph as overlap.
            buf = [buf[-1], para]
            wcount = sum(len(p["text"].split()) for p in buf)
        else:
            buf.append(para)
            wcount += n_words
    if buf:
        chunks.append(_chunk_from_paragraphs(buf))
    return chunks


In [None]:
!apt-get update
!apt-get install -y openjdk-21-jdk
!update-alternatives --install /usr/bin/java java /usr/lib/jvm/java-21-openjdk-amd64/bin/java 1
!update-alternatives --config java
!java -version


In [None]:
!pip install rank_bm25


In [None]:
# 3) Embeddings + FAISS index (dense) and BM25 (sparse)
import faiss, numpy as np
from rank_bm25 import BM25Okapi

def build_dense_index(chunks, model=TEXT_EMB):
    """Embed chunk texts and store them in a FAISS inner-product index.

    Embeddings are requested with ``normalize_embeddings=True``, so the
    inner product used by ``IndexFlatIP`` equals cosine similarity.

    Args:
        chunks: List of dicts with a ``"text"`` key.
        model: Encoder exposing ``encode(texts, normalize_embeddings=True)``.

    Returns:
        ``(index, vecs, texts)``: the FAISS index, the float32 embedding
        matrix that was added to it, and the list of embedded texts
        (row ``i`` of ``vecs`` corresponds to ``texts[i]``).
    """
    texts = [c["text"] for c in chunks]
    if texts:
        vecs = model.encode(texts, normalize_embeddings=True)
    else:
        # With nothing to embed there is no matrix to read the dimension
        # from; probe the encoder once so an empty index can still be built.
        dim = model.encode([""], normalize_embeddings=True).shape[1]
        vecs = np.empty((0, dim), dtype=np.float32)
    # FAISS expects a C-contiguous float32 matrix.
    vecs = np.ascontiguousarray(vecs, dtype=np.float32)
    index = faiss.IndexFlatIP(vecs.shape[1])
    if len(vecs):
        index.add(vecs)
    return index, vecs, texts

def build_caption_index(figures, model=CAPTION_EMB):
    """Embed figure captions and store them in a FAISS inner-product index.

    Missing or empty captions are embedded as the empty string so that row
    ``i`` of the index always corresponds to ``figures[i]``. A document
    without figures yields an empty index instead of raising.

    Args:
        figures: List of dicts with a ``"caption"`` key.
        model: Encoder exposing ``encode(texts, normalize_embeddings=True)``.

    Returns:
        ``(index, vecs, caps)``: the FAISS index, the float32 embedding
        matrix that was added to it, and the list of caption strings.
    """
    caps = [f["caption"] or "" for f in figures]
    if caps:
        vecs = model.encode(caps, normalize_embeddings=True)
    else:
        # With no figures there is no matrix to read the dimension from;
        # probe the encoder once so an empty index can still be built.
        dim = model.encode([""], normalize_embeddings=True).shape[1]
        vecs = np.empty((0, dim), dtype=np.float32)
    # FAISS expects a C-contiguous float32 matrix.
    vecs = np.ascontiguousarray(vecs, dtype=np.float32)
    index = faiss.IndexFlatIP(vecs.shape[1])
    if len(vecs):
        index.add(vecs)
    return index, vecs, caps

def build_bm25(chunks):
    """Build a BM25Okapi sparse index over whitespace-tokenized chunk texts.

    Args:
        chunks: List of dicts with a ``"text"`` key.

    Returns:
        ``(bm25, tokenized)``: the ``BM25Okapi`` instance and the list of
        token lists it was built from (one per chunk, in input order).
    """
    tokenized = []
    for chunk in chunks:
        # Plain whitespace split; no lowercasing or punctuation handling.
        tokenized.append(chunk["text"].split())
    return BM25Okapi(tokenized), tokenized

# Correct usage (unpack the tuples properly)
# NOTE: `chunks` and `figs` are not defined in this cell. They are assigned in
# the upload/parse cell near the end of the notebook, so that cell must have
# been run before these lines.
text_index, text_vecs, text_texts = build_dense_index(chunks)
cap_index, cap_vecs, cap_texts   = build_caption_index(figs)
bm25, tokenized                  = build_bm25(chunks)


In [None]:
# Build indices (requires `chunks` and `figs` from the upload/parse cell)
text_index, text_vecs, text_texts = build_dense_index(chunks)
cap_index, cap_vecs, cap_texts = build_caption_index(figs)
bm25, tokenized = build_bm25(chunks)

# Example query
query = "cloud computing adoption trends"
q_vec = TEXT_EMB.encode([query], normalize_embeddings=True).astype(np.float32)

# Dense retrieval
D, I = text_index.search(q_vec, k=5)
# FAISS pads the result with index -1 when the index holds fewer than k
# vectors; without the filter, -1 would silently select the last text.
dense_hits = [text_texts[i] for i in I[0] if i >= 0]

# BM25 retrieval (same whitespace tokenization as build_bm25)
bm25_scores = bm25.get_scores(query.split())
top_bm25 = np.argsort(bm25_scores)[::-1][:5]
bm25_hits = [chunks[i]["text"] for i in top_bm25]


In [None]:
# 4) Retrieval: text + figure captions + simple fusion
def retrieve(query, chunks, figures, text_index, cap_index, topk=6, figk=4):
    """Retrieve the best-matching text chunks and figures for a query.

    The query is embedded twice: with ``TEXT_EMB`` for the text index and
    with ``CAPTION_EMB`` for the caption index. Each index is searched
    independently; the two result lists are returned side by side.

    Args:
        query: Query string.
        chunks: Chunk list whose order matches the rows of ``text_index``.
        figures: Figure list whose order matches the rows of ``cap_index``.
        text_index: FAISS index over chunk embeddings.
        cap_index: FAISS index over caption embeddings.
        topk: Maximum number of text hits.
        figk: Maximum number of figure hits.

    Returns:
        ``(text_hits, fig_hits)``. Text hits are ``{"idx", "score", "chunk"}``
        dicts and figure hits are ``{"idx", "score", "figure"}`` dicts, in the
        order returned by FAISS. Either list may be shorter than requested
        (or empty) when the index holds fewer vectors.
    """
    q_text_vec = TEXT_EMB.encode([query], normalize_embeddings=True).astype(np.float32)
    D, I = text_index.search(q_text_vec, topk)
    # FAISS pads missing results with index -1; without the filter, -1 would
    # silently select the last chunk (or raise on an empty list).
    text_hits = [
        {"idx": int(i), "score": float(d), "chunk": chunks[int(i)]}
        for d, i in zip(D[0], I[0])
        if i >= 0
    ]

    q_cap_vec = CAPTION_EMB.encode([query], normalize_embeddings=True).astype(np.float32)
    D2, I2 = cap_index.search(q_cap_vec, figk)
    fig_hits = [
        {"idx": int(i), "score": float(d), "figure": figures[int(i)]}
        for d, i in zip(D2[0], I2[0])
        if i >= 0
    ]

    return text_hits, fig_hits


In [None]:
# 5) Visual QA (optional): extract key facts from figures
# Using Qwen-VL or LLaVA for small QA over images
from transformers import AutoProcessor, AutoModelForVision2Seq, AutoTokenizer, AutoModelForCausalLM

def load_vlm(name="Qwen/Qwen2-VL-2B-Instruct"):
    """Load a vision-language model together with its processor.

    Args:
        name: Model identifier passed to both ``from_pretrained`` calls.

    Returns:
        ``(processor, model)``. The model is loaded with
        ``device_map="auto"`` and ``torch_dtype="auto"``.
    """
    vlm_processor = AutoProcessor.from_pretrained(name)
    load_options = {"device_map": "auto", "torch_dtype": "auto"}
    vlm = AutoModelForVision2Seq.from_pretrained(name, **load_options)
    return vlm_processor, vlm

def vqa_on_figures(fig_hits, question, processor, model):
    """Ask a vision-language model the same question about each retrieved figure.

    Args:
        fig_hits: Figure hits as produced by ``retrieve``; each has a
            ``"figure"`` dict with ``"image_bytes"``, ``"page"`` and
            ``"caption"``.
        question: Question asked about every figure.
        processor: Processor returned by ``load_vlm``.
        model: Model returned by ``load_vlm``.

    Returns:
        A list of ``{"page", "caption", "answer"}`` dicts, one per figure
        hit, where ``answer`` is the model's generated text only.
    """
    # The instruction is the same for every figure, so build it once.
    prompt = f"Question: {question}\nAnswer concisely with key values, trends, axes and units."
    if getattr(processor, "chat_template", None):
        # Chat-tuned VLMs expect the image placeholder tokens that the chat
        # template inserts; a bare text prompt carries none.
        messages = [{
            "role": "user",
            "content": [{"type": "image"}, {"type": "text", "text": prompt}],
        }]
        model_text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    else:
        model_text = prompt

    # Decoder-only models return prompt + completion from generate();
    # encoder-decoder models return the completion alone.
    echoes_prompt = not getattr(model.config, "is_encoder_decoder", False)

    facts = []
    for fh in fig_hits:
        figure = fh["figure"]
        img = Image.open(io.BytesIO(figure["image_bytes"])).convert("RGB")
        inputs = processor(text=[model_text], images=[img], return_tensors="pt").to(model.device)
        out = model.generate(**inputs, max_new_tokens=128)
        if echoes_prompt:
            # Drop the echoed prompt so only the answer is decoded.
            out = out[:, inputs["input_ids"].shape[1]:]
        ans = processor.batch_decode(out, skip_special_tokens=True)[0].strip()
        facts.append({"page": figure["page"], "caption": figure["caption"], "answer": ans})
    return facts


In [None]:
# 6) Generation: summary without chain-of-thought, with citations and figures
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_llm(name="mistralai/Mistral-7B-Instruct-v0.3"):
    """Load a causal language model together with its tokenizer.

    Args:
        name: Model identifier passed to both ``from_pretrained`` calls.

    Returns:
        ``(tok, model)``. The fast tokenizer is requested, and the model is
        loaded with ``device_map="auto"`` and ``torch_dtype="auto"``.
    """
    tokenizer = AutoTokenizer.from_pretrained(name, use_fast=True)
    load_options = {"device_map": "auto", "torch_dtype": "auto"}
    llm = AutoModelForCausalLM.from_pretrained(name, **load_options)
    return tokenizer, llm

def build_prompt(query, text_hits, fig_hits, vqa_facts):
    """Assemble the system and user messages for grounded answer generation.

    The user message contains the question, then a context section made of
    up to six text snippets (each truncated to 1200 characters), up to four
    figure captions and all figure facts, each tagged with its page, and
    finally the answering instructions.

    Args:
        query: The user's question.
        text_hits: Hits with a ``"chunk"`` dict holding ``"pages"`` and ``"text"``.
        fig_hits: Hits with a ``"figure"`` dict holding ``"page"`` and ``"caption"``.
        vqa_facts: Dicts with ``"page"`` and ``"answer"``; may be empty.

    Returns:
        ``(system, user)`` strings.
    """
    # Tag every piece of context with where it came from.
    text_lines = []
    for hit in text_hits:
        chunk = hit["chunk"]
        page_list = ", ".join(map(str, chunk["pages"]))
        text_lines.append(f"[TEXT p.{page_list}] {chunk['text'][:1200]}")

    figure_lines = [
        f"[FIGURE p.{hit['figure']['page']}] {hit['figure']['caption']}"
        for hit in fig_hits
    ]
    fact_lines = [
        f"[FIGURE-FACT p.{fact['page']}] {fact['answer']}"
        for fact in vqa_facts
    ]

    system = (
        "You are a scientific assistant. Provide a concise summary answering the question. "
        "Do not reveal your reasoning steps. Only use provided context. "
        "Cite using page markers like [p.X] after each claim."
    )

    instructions = (
        "\n\nInstructions:\n"
        "- Summarize without chain-of-thought.\n"
        "- Attribute claims with [p.PAGE].\n"
        "- If a figure supports a claim, cite it with [p.PAGE, figure]."
    )
    user = "".join([
        f"Question: {query}\n\nContext:\n",
        "\n".join(text_lines[:6]),
        "\n",
        "\n".join(figure_lines[:4]),
        "\n",
        "\n".join(fact_lines),
        instructions,
    ])
    return system, user

def generate_answer(tok, model, system, user, max_new_tokens=400):
    """Generate a greedy-decoded answer from system and user messages.

    When the tokenizer ships a chat template, the prompt is rendered with it
    so the model sees the format it was tuned on. The system text is folded
    into the single user turn because some chat templates do not accept a
    separate system role. Tokenizers without a chat template fall back to
    the hand-written ``[SYSTEM]/[USER]`` layout.

    Args:
        tok: Tokenizer returned by ``load_llm``.
        model: Causal LM returned by ``load_llm``.
        system: System instructions.
        user: User message (question plus context).
        max_new_tokens: Upper bound on generated tokens.

    Returns:
        The generated completion only (the prompt is not echoed), with
        special tokens removed and surrounding whitespace stripped.
    """
    if getattr(tok, "chat_template", None):
        messages = [{"role": "user", "content": f"{system}\n\n{user}"}]
        prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        # The rendered template already carries the special tokens (e.g. BOS);
        # adding them again would duplicate them.
        inputs = tok(prompt, return_tensors="pt", add_special_tokens=False)
    else:
        prompt = f"<s>[SYSTEM]\n{system}\n[/SYSTEM]\n[USER]\n{user}\n[/USER]"
        inputs = tok(prompt, return_tensors="pt")
    inputs = inputs.to(model.device)

    out = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    # generate() returns prompt + completion for causal LMs; keep the completion.
    prompt_len = inputs["input_ids"].shape[1]
    return tok.decode(out[0][prompt_len:], skip_special_tokens=True).strip()


In [None]:
from google.colab import files

# Upload a PDF through the Colab file picker; the result is keyed by filename.
uploaded = files.upload()
if not uploaded:
    # Fail with a clear message instead of an IndexError on the lookup below.
    raise ValueError("No file was uploaded; upload a PDF to continue.")

pdf_path = next(iter(uploaded))  # use the actual uploaded filename

# Parse -> paragraphs -> overlapping chunks, plus figure crops with captions.
doc, pages = parse_pdf(pdf_path)
paras = blocks_to_paragraphs(pages)
chunks = make_chunks(paras)
figs = extract_figures_with_captions(pdf_path)

# Dense indices over chunk texts and figure captions.
text_index, text_vecs, text_texts = build_dense_index(chunks)
cap_index, cap_vecs, cap_texts = build_caption_index(figs)


query = "What are the main findings and their quantitative results?"
text_hits, fig_hits = retrieve(query, chunks, figs, text_index, cap_index)

# Optional VQA (skip if low GPU)
# vlp, vlm = load_vlm()
# vqa_facts = vqa_on_figures(fig_hits, query, vlp, vlm)
vqa_facts = []

tok, llm = load_llm()  # choose a small model if limited GPU
system, user = build_prompt(query, text_hits, fig_hits, vqa_facts)
answer = generate_answer(tok, llm, system, user)
print(answer)
