# Multimodal RAG (PDF with Images) — Gemini Version

In [4]:
# === Imports & Config ===
import os, io, base64
from dotenv import load_dotenv
load_dotenv()

import fitz  # PyMuPDF
from PIL import Image
import numpy as np
import torch

from transformers import CLIPProcessor, CLIPModel
from langchain_core.documents import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter

# CONFIG — EDIT OR SET VIA .env
pdf_path = os.getenv("PDF_PATH")  # <-- set your PDF path
device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Using device: {device}")
print(f"PDF path: {pdf_path}")



Using device: cpu
PDF path: M:\job hunt\Harrisburg University Documents\ML Project\Kris_Naik_series_Projects\Multimodal RAG\The 2025 AI Engineering Report _ Amplify Partners.pdf


In [5]:
### CLIP MODEL — Contrastive Language Image PreTraining
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
clip_model.eval()  # IMPORTANT: put model in eval mode

def _normalize_torch(v: torch.Tensor) -> torch.Tensor:
    return v / (v.norm(dim=-1, keepdim=True) + 1e-8)

@torch.no_grad()
def embed_image(image_data) -> np.ndarray:
    """Embed an image with CLIP and return a normalized numpy vector."""
    if isinstance(image_data, str):  # path
        image = Image.open(image_data).convert("RGB")
    elif isinstance(image_data, (bytes, bytearray)):
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
    elif isinstance(image_data, Image.Image):
        image = image_data.convert("RGB")
    else:
        raise ValueError("Unsupported image_data type for embed_image")

    inputs = clip_processor(images=image, return_tensors="pt").to(device)
    feats = clip_model.get_image_features(**inputs)
    feats = _normalize_torch(feats).squeeze(0).detach().cpu().numpy()
    return feats

@torch.no_grad()
def embed_text(text: str) -> np.ndarray:
    """Embed text with CLIP and return a normalized numpy vector."""
    inputs = clip_processor(
        text=text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=77  # CLIP's text limit
    ).to(device)
    feats = clip_model.get_text_features(**inputs)
    feats = _normalize_torch(feats).squeeze(0).detach().cpu().numpy()
    return feats


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [6]:
## Process PDF
assert os.path.exists(pdf_path), f"PDF not found: {pdf_path}"
doc = fitz.open(pdf_path)

## Storage for all documents and embeddings
all_docs = []               # will store langchain Document objects (text + image markers)
all_embeddings = []         # numpy vectors aligned with all_docs
image_data_store = {}       # {image_id: base64_png} for multimodal LLM input

splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)


In [7]:
for i, page in enumerate(doc):
    ## process text
    text = page.get_text() or ""
    if text.strip():
        temp_doc = Document(page_content=text, metadata={"page": i, "type": "text"})
        # Some LangChain versions use split_documents; older used split_document
        try:
            text_chunks = splitter.split_documents([temp_doc])  # standard
        except AttributeError:
            text_chunks = splitter.split_document([temp_doc])   # compat

        # Embed each chunk using CLIP
        for chunk in text_chunks:
            embedding = embed_text(chunk.page_content)
            all_docs.append(chunk)
            all_embeddings.append(embedding)

    ## process images
    for img_index, img in enumerate(page.get_images(full=True)):
        try:
            xref = img[0]
            base_image = doc.extract_image(xref)
            image_bytes = base_image["image"]

            # Convert to PIL and store as base64 PNG for LLM
            pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            buffered = io.BytesIO()
            pil_image.save(buffered, format="PNG")
            img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

            image_id = f"page_{i}_img_{img_index}"
            image_data_store[image_id] = img_base64

            # Embed the image
            embedding = embed_image(pil_image)

            # Create a lightweight doc to pair with the embedding
            image_doc = Document(
                page_content=f"[Image:{image_id}]",
                metadata={"page": i, "type": "image", "image_id": image_id}
            )
            all_docs.append(image_doc)
            all_embeddings.append(embedding)

        except Exception as e:
            print(f"Error processing image {img_index} on page {i}: {e}")

doc.close()

# Build embedding matrix (N x D), D=512 for CLIP ViT-B/32
emb_matrix = np.vstack(all_embeddings) if all_embeddings else np.zeros((0, 512))
print(f"Docs: {len(all_docs)} | Embeddings shape: {emb_matrix.shape}")


Docs: 70 | Embeddings shape: (70, 512)


## Retrivers Helpers (search + quick viewer)

In [8]:
## Retrieval helpers
def _cos_sim(a: np.ndarray, b: np.ndarray) -> float:
    a_n = a / (np.linalg.norm(a) + 1e-8)
    b_n = b / (np.linalg.norm(b) + 1e-8)
    return float(np.dot(a_n, b_n))

def search(query_text=None, query_image=None, top_k=5):
    assert (query_text is not None) ^ (query_image is not None), "Provide either query_text or query_image"
    if emb_matrix.shape[0] == 0:
        return []

    # Build query vector
    if query_text is not None:
        q = embed_text(query_text)
    else:
        if isinstance(query_image, (bytes, bytearray)):
            query_image = Image.open(io.BytesIO(query_image)).convert("RGB")
        q = embed_image(query_image)

    # Cosine similarity via pre-normalized trick
    q_n = q / (np.linalg.norm(q) + 1e-8)
    sims = emb_matrix @ q_n  # (N, D) @ (D,) -> (N,)

    idx = np.argsort(-sims)[:top_k]
    return [(float(sims[i]), all_docs[i]) for i in idx]

def show_matches(results):
    for score, d in results:
        kind = d.metadata.get("type")
        page = d.metadata.get("page")
        if kind == "text":
            snippet = d.page_content.replace("\n", " ")[:320]
            print(f"[{score:.3f}] page {page} :: {snippet}...")
        else:
            print(f"[{score:.3f}] page {page} :: [image] {d.metadata.get('image_id')}")


## Gemini-1.5 Multimodal Answer

In [9]:
## Gemini 1.5 Answer (multimodal)
def _answer_gemini(question: str, context: str, image_payloads_b64):
    """
    image_payloads_b64: list of dicts like {"mime_type": "image/png", "image_data": "<base64str>"}
    We convert them to PIL.Image for the Gemini SDK.
    """
    import google.generativeai as genai
    api_key = os.environ.get("GOOGLE_API_KEY")
    assert api_key, "GOOGLE_API_KEY not set. Put it in your .env or os.environ."

    genai.configure(api_key=api_key)
    model_name = os.getenv("GEMINI_MODEL", "gemini-1.5-flash")
    model = genai.GenerativeModel(model_name)

    # Build prompt (keep it grounded in retrieved context)
    sys_text = (
        "You are a concise, helpful assistant. "
        "Use ONLY the provided context. If the answer is not in the context, say you don't know.\n\n"
        f"Context:\n{context}\n\nQuestion: {question}\nAnswer:"
    )

    # Convert base64 → PIL for Gemini
    from PIL import Image
    imgs = []
    for p in image_payloads_b64[:4]:  # limit to keep request light
        try:
            if p.get("mime_type","").startswith("image/") and p.get("image_data"):
                img_bytes = base64.b64decode(p["image_data"])
                imgs.append(Image.open(io.BytesIO(img_bytes)).convert("RGB"))
        except Exception:
            continue

    # Send as mixed parts: text first, then images
    parts = [sys_text] + imgs
    resp = model.generate_content(parts, safety_settings=None)
    try:
        return resp.text.strip()
    except Exception:
        return str(resp)

def answer(question: str, top_k: int = 5):
    """
    RAG answer flow:
    1) Retrieve top_k relevant chunks/images.
    2) Build text context from chunks + attach some images.
    3) Ask Gemini. If Gemini isn't configured, return fallback with matches.
    """
    results = search(query_text=question, top_k=top_k)

    # 1) Build text context & image payloads from retrieval
    ctx_bits, image_payloads = [], []
    for score, d in results:
        if d.metadata.get("type") == "text":
            snippet = d.page_content.strip().replace("\n", " ")
            ctx_bits.append(f"[p{d.metadata.get('page','?')}] {snippet}")
        elif d.metadata.get("type") == "image":
            image_id = d.metadata.get("image_id")
            if image_id and image_id in image_data_store:
                image_payloads.append({
                    "type": "input_image",
                    "image_data": image_data_store[image_id],
                    "mime_type": "image/png",
                })

    context = "\n".join(ctx_bits[:10])  # trim to keep prompt lean

    # 2) Try Gemini; if not available, return a fallback
    try:
        ans = _answer_gemini(question, context, image_payloads)
        return {"answer": ans, "matches": results, "provider": "gemini"}
    except Exception as e:
        return {
            "answer": f"(LLM unavailable) {e}\nTop matches shown below.",
            "matches": results,
            "provider": None
        }


In [10]:
# Add this under your existing LLM wiring cell
def answer_charts_ok(question: str, top_k: int = 12, max_imgs: int = 6):
    # 1) retrieve (hybrid preferred, else CLIP-only)
    try:
        results = search_hybrid(query_text=question, top_k=top_k)
    except NameError:
        results = search(query_text=question, top_k=top_k)

    # 2) collect text + images
    ctx_bits, image_payloads = [], []
    img_count = 0
    for score, d in results:
        if d.metadata.get("type") == "text":
            snippet = d.page_content.strip().replace("\n", " ")
            ctx_bits.append(f"[p{d.metadata.get('page','?')}] {snippet}")
        elif d.metadata.get("type") == "image" and img_count < max_imgs:
            img_id = d.metadata.get("image_id")
            if img_id and img_id in image_data_store:
                image_payloads.append({
                    "type":"input_image",
                    "image_data": image_data_store[img_id],
                    "mime_type":"image/png"
                })
                img_count += 1
    context = "\n".join(ctx_bits[:10])

    # 3) chart-aware Gemini call
    def _answer_gemini_charts(q, ctx, imgs):
        import google.generativeai as genai, os, base64, io
        from PIL import Image
        genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
        model = genai.GenerativeModel(os.getenv("GEMINI_MODEL","gemini-1.5-flash"))

        sys_text = (
            "You will see numeric tables and charts. Read them carefully.\n"
            "Summarize the key findings and trends using only the provided content.\n"
            "Prioritize charts/tables if narrative text is limited. Cite page numbers in brackets.\n\n"
            f"Context (text snippets):\n{ctx}\n\nQuestion: {q}\nAnswer:"
        )

        images = []
        for p in imgs[:6]:
            try:
                img_bytes = base64.b64decode(p["image_data"])
                images.append(Image.open(io.BytesIO(img_bytes)).convert("RGB"))
            except Exception:
                continue

        parts = [sys_text] + images
        resp = model.generate_content(parts, safety_settings=None)
        try:
            return resp.text.strip()
        except Exception:
            return str(resp)

    try:
        ans = _answer_gemini_charts(question, context, image_payloads)
        return {"answer": ans, "matches": results, "provider": "gemini"}
    except Exception as e:
        return {"answer": f"(LLM unavailable) {e}", "matches": results, "provider": None}


In [11]:
def ask_many(questions, use_charts=False, top_k=8, max_imgs=6, save_csv=None):
    """
    Ask multiple questions, print answers + top matches, and optionally save a CSV.
    - use_charts=True    -> uses answer_charts_ok (better for table/chart-heavy PDFs)
    - use_charts=False   -> uses answer (better for narrative PDFs)
    """
    fn = answer_charts_ok if use_charts else answer
    rows = []
    for q in questions:
        resp = fn(q, top_k=top_k) if not use_charts else fn(q, top_k=top_k, max_imgs=max_imgs)

        print("\n" + "="*80)
        print("Q:", q)
        print("\n--- Answer ---")
        print(resp["answer"])
        print("\n--- Top matches ---")
        show_matches(resp["matches"])

        pages = []
        for score, d in resp["matches"]:
            p = d.metadata.get("page")
            if isinstance(p, int):
                pages.append(p)
        rows.append({"question": q, "answer": resp["answer"], "pages_topk": pages[:top_k]})

    if save_csv:
        try:
            import pandas as pd, csv
            pd.DataFrame(rows).to_csv(save_csv, index=False, quoting=csv.QUOTE_MINIMAL)
            print(f"\nSaved: {save_csv}")
        except Exception as e:
            print("CSV save skipped:", e)
    return rows


## A quick Test

In [12]:
print("GOOGLE_API_KEY set:", bool(os.getenv("GOOGLE_API_KEY")))

# q = "Summarize the main findings of the paper."
# resp = answer(q, top_k=5)

# print("\n--- Answer ---")
# print(resp["answer"])

# print("\n--- Top matches ---")
# show_matches(resp["matches"])

questions = [
  "Use only the provided content. Cite pages like [p#]. List section headings with page numbers.",
  "Use only the provided content. Cite pages like [p#]. Extract 7 key findings with exact % and dates.",
  "Use only the provided content. Cite pages like [p#]. How often do teams update models? Give % weekly/monthly + reasons.",
  "Use only the provided content. Cite pages like [p#]. Compare effectiveness of LLMs vs agents; include definitions if given.",
  "Use only the provided content. Cite pages like [p#]. What share won’t pay more for lower latency/better reasoning? Any cuts by size?",
  "Use only the provided content. Cite pages like [p#]. Most used tools (vector DBs/orchestration/eval) with any %.",
  "Use only the provided content. Cite pages like [p#]. How do teams evaluate models (offline/online, metrics, examples)?",
  "Use only the provided content. Cite pages like [p#]. Deployment targets (cloud/on-prem/edge) and motives with %.",
  "Use only the provided content. Cite pages like [p#]. Grounding data sources; % and challenges.",
  "Use only the provided content. Cite pages like [p#]. Top 5 pains in AI engineering and their frequency.",
  "Use only the provided content. Cite pages like [p#]. Top priorities for the next 6–12 months (5 bullets).",
  "Use only the provided content. Cite pages like [p#]. From charts, summarize 5 trends; include figure titles and pages.",
  "Use only the provided content. Cite pages like [p#]. Methods/survey design: N, MOE, field dates, sampling.",
  "Use only the provided content. Cite pages like [p#]. Podcasts/newsletters people actually learn from.",
  "Use only the provided content. Cite pages like [p#]. 150-word executive summary + 3 citations."
]

# Narrative-first run
ask_many(questions, use_charts=False, top_k=10, save_csv="/mnt/data/amplify_batch_narrative.csv")

# If your hits are mostly charts/tables:
ask_many(questions, use_charts=True, top_k=14, max_imgs=6, save_csv="/mnt/data/amplify_batch_charts.csv")



GOOGLE_API_KEY set: True

Q: Use only the provided content. Cite pages like [p#]. List section headings with page numbers.

--- Answer ---
Here's a summary of the provided text, organized by section headings and page numbers:

**[p2] AI Engineering Experience:** Many seasoned developers are new to AI; 45% of respondents with 10+ years of software experience have 3 years or fewer of AI experience, and 1 in 10 have less than 1 year.

**[p9] Model Updates:**  New models are released frequently, with over 50% of respondents updating their models monthly, and 17% weekly.

**[p13] Other Modalities:** Usage of audio, image, and video lags significantly behind text.

**[p15] Image Generation:** Image generation is the most popular among other modalities (audio, image, video).

**[p18] Agents:**  Agents (LLM-controlled systems) are nascent; 80% of respondents say LLMs work well, but less than 20% say the same about agents. Fewer than 1 in 10 have no plans to use agents.

**[p19] Agent Permissio

[{'question': 'Use only the provided content. Cite pages like [p#]. List section headings with page numbers.',
  'answer': '(LLM unavailable) 429 You exceeded your current quota, please check your plan and billing details. For more information on this error, head to: https://ai.google.dev/gemini-api/docs/rate-limits. [violations {\n  quota_metric: "generativelanguage.googleapis.com/generate_content_free_tier_requests"\n  quota_id: "GenerateRequestsPerMinutePerProjectPerModel-FreeTier"\n  quota_dimensions {\n    key: "model"\n    value: "gemini-1.5-flash"\n  }\n  quota_dimensions {\n    key: "location"\n    value: "global"\n  }\n  quota_value: 15\n}\n, links {\n  description: "Learn more about Gemini API quotas"\n  url: "https://ai.google.dev/gemini-api/docs/rate-limits"\n}\n, retry_delay {\n  seconds: 26\n}\n]',
  'pages_topk': [9, 26, 31, 27, 18, 2, 19, 13, 30, 15, 24, 3, 1, 22]},
 {'question': 'Use only the provided content. Cite pages like [p#]. Extract 7 key findings with exact % a

In [13]:
# === Where to save outputs (MD/JSON/PNG/CSV) ===
import os
OUTPUT_DIR = r"M:\job hunt\Harrisburg University Documents\ML Project\Kris_Naik_series_Projects\Multimodal RAG"
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.environ["OUTPUT_DIR"] = OUTPUT_DIR
print("Saving to:", OUTPUT_DIR)


Saving to: M:\job hunt\Harrisburg University Documents\ML Project\Kris_Naik_series_Projects\Multimodal RAG


## Saving to Location.

In [14]:
# Replace your existing ask_and_save with this version (if you haven't already)
import os, json, datetime, textwrap
import matplotlib.pyplot as plt

def ask_and_save(question: str, use_charts: bool = False, top_k: int = 10, max_imgs: int = 6,
                 out_dir: str | None = None, tag: str = "run"):
    out_dir = out_dir or os.getenv("OUTPUT_DIR", "/mnt/data/outputs")
    os.makedirs(out_dir, exist_ok=True)

    fn = answer_charts_ok if use_charts else answer
    resp = fn(question, top_k=top_k, **({"max_imgs": max_imgs} if use_charts else {}))

    pages = []
    for score, d in resp["matches"]:
        p = d.metadata.get("page")
        if isinstance(p, int):
            pages.append(p)
    top_pages = pages[:top_k]

    ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    base = os.path.join(out_dir, f"{tag}_{ts}")

    # Markdown
    md = [
        "# Multimodal RAG — Q&A",
        f"**PDF:** {os.path.basename(os.getenv('PDF_PATH',''))}",
        f"**Question:** {question}",
        "**Answer:**",
        resp["answer"],
        f"**Top pages (retrieved):** {top_pages}",
    ]
    with open(base + ".md", "w", encoding="utf-8") as f:
        f.write("\n\n".join(md))

    # JSON
    with open(base + ".json", "w", encoding="utf-8") as f:
        json.dump({
            "pdf": os.path.basename(os.getenv("PDF_PATH","")),
            "question": question,
            "answer": resp["answer"],
            "pages_topk": top_pages
        }, f, ensure_ascii=False, indent=2)

    # PNG summary card
    fig = plt.figure(figsize=(12, 7))
    plt.axis("off")
    wrap_q = textwrap.fill(question, 95)
    wrap_a = textwrap.fill(resp["answer"][:1800], 95)
    plt.text(0.02, 0.94, "ReportFindings-RAG — Q&A", fontsize=18, weight="bold")
    plt.text(0.02, 0.88, f"PDF: {os.path.basename(os.getenv('PDF_PATH',''))}", fontsize=11)
    plt.text(0.02, 0.83, "Question:", fontsize=13, weight="bold"); plt.text(0.02, 0.79, wrap_q, fontsize=12, va="top")
    plt.text(0.02, 0.67, "Answer:", fontsize=13, weight="bold");  plt.text(0.02, 0.63, wrap_a, fontsize=12, va="top")
    plt.text(0.02, 0.09, f"Top pages: {top_pages}", fontsize=11)
    plt.text(0.02, 0.05, "Note: Generated from retrieved PDF context only.", fontsize=9)
    fig.savefig(base + ".png", dpi=200, bbox_inches="tight"); plt.close(fig)

    print("Saved:")
    print(" -", base + ".md")
    print(" -", base + ".json")
    print(" -", base + ".png")
    return resp
