## **Overview**

- Simple multimodal retrieval-augmented generation (MRAG) system that lets a user upload a PDF and query it for responses.
- The model used for this project is [google/gemma-3-4b-it](https://huggingface.co/google/gemma-3-4b-it), a multimodal model, loaded with the Hugging Face **Gemma3ForConditionalGeneration** and **AutoProcessor** classes for the model and its processor, respectively.
- When a user uploads a PDF, the text is extracted separately from the images, and the images are then passed to the multimodal model for captioning.
- After captioning, each page's captions are merged with that page's extracted text, then chunked, embedded, indexed, and stored for retrieval.
- Text and image extraction is carried out by `PyMuPDF`, and chunking is done by LangChain's `RecursiveCharacterTextSplitter`.
- `FAISS` is used for indexing and retrieval, enabling semantic search over normalized `all-MiniLM-L6-v2` sentence embeddings.

## Installing and importing the required libraries

In [1]:
# !pip install -q PyMUPDF bitsandbytes accelerate faiss-cpu gradio
!pip install -q PyMUPDF gradio faiss-cpu
!pip install -qU transformers accelerate bitsandbytes

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.1/24.1 MB[0m [31m51.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m31.4/31.4 MB[0m [31m17.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m375.8/375.8 kB[0m [31m9.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.1/60.1 MB[0m [31m13.1 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
# !pip install -qU transformers accelerate bitsandbytes
# Re-run the upgrade above if the model fails again with 'Max cache length is not consistent across layers'.
# That error points to an internal architectural detail of Gemma 3 and how the Hugging Face transformers library interacts with it,
# specifically its attention mechanism and KV cache management.

In [3]:
# import os
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
# os.environ["TORCH_USE_CUDA_DSA"] = "1"

In [4]:
import os
import json
import torch
import fitz
import faiss
import gc
import re
import shutil
import numpy as np
import gradio as gr

from time import time
from typing import List, Dict, Tuple
from threading import Lock
from PIL import Image
from threading import Thread

from langchain.text_splitter import RecursiveCharacterTextSplitter
# from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, BitsAndBytesConfig, TextIteratorStreamer

In [5]:
# Cap PyTorch's intra-op CPU thread pool at 4 threads.
torch.set_num_threads(4) # Using 4 threads for intra-op parallelism
# NOTE(review): `device` does not appear to be referenced later in this notebook;
# weight placement uses device_map="auto" and inputs are moved with `.to(model.device)`.
device = "cuda" if torch.cuda.is_available() else "cpu" # Initialize the device to use

## Setting up helper functions for the system

In [6]:
# Defining some helper functions
def save_cache(data: List[Dict], filepath: str) -> None:
  """Write ``data`` to ``filepath`` as indented UTF-8 JSON (non-ASCII kept as-is).

  Errors are printed rather than raised, so a failed cache write never
  interrupts the caller.
  """
  try:
    with open(filepath, 'w', encoding='utf-8') as handle:
      json.dump(obj=data, fp=handle, indent=2, ensure_ascii=False)
  except Exception as err:
    print(f"Failed to save cache to {filepath}: {err}")

def load_cache(filepath: str) -> List[Dict]:
  """Read a JSON cache file; return ``[]`` when it is missing or unreadable."""
  if not os.path.exists(filepath):
    return []
  try:
    with open(filepath, 'r', encoding='utf-8') as handle:
      return json.load(handle)
  except Exception as err:
    print(f"Failed to load cache from {filepath}: {err}")
  return []

def clear_gpu_cache():
  """Release unreferenced memory: run garbage collection, then empty the CUDA cache.

  Collecting first matters: tensors kept alive only by reference cycles are
  freed by ``gc.collect()``, and only afterwards can ``torch.cuda.empty_cache()``
  hand their now-unused cached blocks back to the device.
  """
  gc.collect()
  if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [7]:
# # Lazy loading the model
# model_cache = {}
# cache_lock = Lock()

# def get_model(model_name:str, model_class, processor_class=None, quantization_config=None):
#   """Lazy load the model to save on memory"""
#   cache_key = f"{model_name}_{model_class.__name__}"
#   with cache_lock:
#     try:
#       if cache_key in model_cache:
#         return model_cache[cache_key]

#       # Load processor first if provided
#       if processor_class:
#         processor = processor_class.from_pretrained(model_name, use_fast=True)
#       else:
#         processor = None # No processor if not specified

#       if quantization_config:
#         model = model_class.from_pretrained(pretrained_model_name_or_path=model_name, device_dtype=torch.bfloat16, device_map="auto", low_cpu_mem_usage=True, quantization_config=quantization_config, attn_implementation="eager")
#       else:
#         model = model_class.from_pretrained(pretrained_model_name_or_path=model_name, torch_dtype=torch.bfloat16, device_map="auto", low_cpu_mem_usage=True)

#       if processor_class:
#         model_cache[cache_key] = (model, processor)
#       else:
#         model_cache[cache_key] = model

#       return model_cache[cache_key]
#     except Exception as e:
#       print(f"[INFO] Could not load {model_name}: {str(e)}")
#       print("Please ensure you have access to the model and are authenticated with Hugging Face.")
#       return None # Return None or raise an exception to indicate failure

# quantization_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_compute_dtype=torch.bfloat16,
#     bnb_4bit_use_double_quant=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_quant_storage=torch.bfloat16,
# )

# model_and_processor = get_model(
#     model_name="google/gemma-3-4b-it",
#     model_class=Gemma3ForConditionalGeneration,
#     processor_class=AutoProcessor,
#     quantization_config=quantization_config,
# )

# if model_and_processor is not None:
#     model, processor = model_and_processor
#     model.eval()
# else:
#     print("[INFO!!] Model and tokenizer could not be loaded. Please check the error messages above.")
#     model = None
#     processor = None

## Setting up the model
- The model is loaded with 4-bit (NF4) `bitsandbytes` quantization to reduce memory usage in the environment it runs in.

In [8]:
# Loading the quantized model and its processor

model_name = "google/gemma-3-4b-it"

# 4-bit NF4 weights with double quantization; computation runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_quant_storage=torch.bfloat16,
)

# device_map="auto" lets accelerate place the weights; .eval() puts the model in evaluation mode.
# NOTE(review): the cell output shows `torch_dtype` is deprecated in the installed
# transformers version in favour of `dtype`; rename once the minimum version is pinned.
model = Gemma3ForConditionalGeneration.from_pretrained(
    pretrained_model_name_or_path=model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    low_cpu_mem_usage=True,
    quantization_config=bnb_config,
    attn_implementation="sdpa" #"eager"
).eval()

# The processor provides the tokenizer (`processor.tokenizer`), the chat template and image preprocessing.
processor = AutoProcessor.from_pretrained(model_name, use_fast=True)

`torch_dtype` is deprecated! Use `dtype` instead!


config.json:   0%|          | 0.00/855 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/90.6k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.64G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/215 [00:00<?, ?B/s]

processor_config.json:   0%|          | 0.00/70.0 [00:00<?, ?B/s]

chat_template.json:   0%|          | 0.00/1.61k [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

## Text extraction per page

In [9]:
def extract_pages_text_and_images(pdf_path, image_dir):
  """Extract each page's text and embedded images from a PDF.

  Args:
      pdf_path (str): Path to the PDF file.
      image_dir (str): Directory the embedded images are written to (created if missing).

  Returns:
      Tuple[List[str], List[List[str]]]: ``page_texts[i]`` is the text of page
      ``i + 1`` and ``page_images[i]`` the paths of the images saved from it,
      named ``page_<n>_img_<k>.<ext>``.
  """
  page_texts = []
  page_images = []

  # The context manager closes the document even if extraction raises part-way.
  with fitz.open(pdf_path) as doc:
    os.makedirs(image_dir, exist_ok=True)

    for page_num, page in enumerate(doc):
      images = []
      for img_index, img in enumerate(page.get_images(full=True)):
        xref = img[0]
        base_image = doc.extract_image(xref)
        if not base_image:  # defensive: skip if PyMuPDF returns no image data
          continue
        image_filename = f"page_{page_num + 1}_img_{img_index}.{base_image['ext']}"
        image_path = os.path.join(image_dir, image_filename)
        with open(image_path, "wb") as img_file:
          img_file.write(base_image["image"])
        images.append(image_path)

      page_texts.append(page.get_text())
      page_images.append(images)

  return page_texts, page_images

## Image Captioning

In [10]:
# processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it") # for captioning
def generate_image_descriptions(image_paths):
  """Caption images with the multimodal model so they can be indexed as text.

  Args:
      image_paths (List[str]): Paths of the images extracted from one page.

  Returns:
      List[dict]: ``{"image_path": ..., "caption": ...}`` per image, in input
      order. Images smaller than 32x32 pixels are skipped (no entry); images
      that cannot be opened or captioned get a placeholder caption.
      Returns ``[]`` if the model or processor is not loaded.
  """
  clear_gpu_cache()

  captions = []
  if processor is None or model is None:
    print("[ERROR] Model or Processor not loaded. Cannot generate image descriptions.")
    return []

  for image_path in image_paths:
    try:
      # Opened inside `try` so one unreadable/unsupported image does not abort
      # the whole document. `with` closes the file handle; convert() returns an
      # independent RGB copy that stays usable after the handle is closed.
      with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    except Exception as e:
      print(f"[ERROR]: Failed to open image {image_path}: {e}")
      captions.append({"image_path": image_path, "caption": "<---image---> (Captioning failed)"})
      continue

    width, height = image.size
    if width < 32 or height < 32: # Filtering out smaller images that may disrupt the process
      continue

    messages = [
        {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
        {"role": "user", "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": "Describe the factual content visible in the image. Be concise and accurate as the descriptions will be used for retrieval."}
        ]}
    ]
    try:
      inputs = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True,
                                             return_dict=True, return_tensors="pt").to(model.device, dtype=torch.bfloat16)
      input_len = inputs["input_ids"].shape[-1]
      with torch.inference_mode():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=False,
            cache_implementation="offloaded_static"
            )

      # Decode only the newly generated tokens. Decoding the full sequence
      # re-includes the prompt (system text, instruction, role names), and any
      # part of it a heuristic cleanup misses ends up in the indexed captions.
      caption = processor.decode(generated_ids[0][input_len:], skip_special_tokens=True).strip()
      captions.append({"image_path": image_path, "caption": caption})
    except Exception as e:
      print(f"[ERROR]: Failed to generate caption for image {image_path}: {e}")
      captions.append({"image_path": image_path, "caption": "<---image---> (Captioning failed)"}) # Add a placeholder caption
    finally:
      # Release memory left over from this image's generation before the next one
      # (clear_gpu_cache already runs gc.collect()).
      clear_gpu_cache()
  return captions

In [11]:
# Cleaning the captions from the extracted images
# Regex: match everything from "model\n" up through first double-newline after "sent:"
prefix_re = re.compile(
    r"model\s*\n.*?\bsent:\s*\n\n",
    flags=re.IGNORECASE | re.DOTALL,
)

def clean_caption(raw: str) -> str:
  """Remove prompt echo / preamble text from a decoded caption.

  Strategies, tried in order:
    1. Drop everything up to and including a header that starts at a "model"
       line and ends with "sent:" followed by a blank line.
    2. Otherwise start at the first "**" (a bold markdown heading), if any.
    3. Otherwise drop the first paragraph when there is more than one.
  """
  header_split = prefix_re.split(raw.strip(), maxsplit=1)
  if len(header_split) == 2:
    return header_split[1].strip()

  if "**" in raw:
    return raw[raw.index("**"):].strip()

  head, separator, tail = raw.strip().partition('\n\n')
  # With no blank line there is only one paragraph, so keep it.
  return (tail if separator else head).strip()  # may still include some leading noise

In [12]:
def generate_captions_per_page(page_image_paths_list):
  """Generate image captions page by page.

  Args:
      page_image_paths_list (List[List[str]]): For each page, the paths of its extracted images.

  Returns:
      List[List[str]]: For each page, the caption strings of its images, in order.
  """
  page_captions = []
  for image_paths in page_image_paths_list:
    if not image_paths:
      # Text-only page: skip the captioning call and its GPU-cache flush.
      page_captions.append([])
      continue
    page_captions.append([cap['caption'] for cap in generate_image_descriptions(image_paths)])
  return page_captions

## Merging extracted texts and the image descriptions, and chunking helper functions

In [13]:
def merge_text_and_captions(page_texts, page_captions):
  """Build one string per page: the stripped page text, then one tagged entry per image caption.

  Pages are paired positionally; ``zip`` stops at the shorter of the two inputs.
  """
  return [
      text.strip() + "\n\n" + "".join(f"[Image Description]: {cap}\n\n" for cap in captions)
      for text, captions in zip(page_texts, page_captions)
  ]

def chunk_text_with_metadata(merged_pages):
  """
  Split every page's merged content (text + image captions) into chunks with metadata.

  Each page is split on its own with a recursive splitter (paragraph, line,
  sentence, word, then character boundaries), using chunks of up to 1000
  characters with 200 characters of overlap, so no chunk spans two pages.

  Args:
      merged_pages (List[str]): One string per page, in page order.

  Returns:
      List[dict]: Chunks in document order, each with keys ``content``,
      ``page`` (1-based), ``chunk_id`` (0-based, global), ``chunk_number_on_page``
      (1-based) and ``type``.
  """
  splitter = RecursiveCharacterTextSplitter(
      separators=["\n\n", "\n", ".", " ", ""],
      chunk_size=1000,
      chunk_overlap=200,
      add_start_index=True,
  )

  chunks = []
  for page_number, page_content in enumerate(merged_pages, start=1):
    for position, piece in enumerate(splitter.split_text(page_content), start=1):
      chunks.append({
          "content": piece,
          "page": page_number,
          # Global id == this chunk's position in the returned list.
          "chunk_id": len(chunks),
          "chunk_number_on_page": position,
          "type": "extracted_texts_and_captions_descriptions",
      })

  return chunks

## Indexing and Vector Store Helper Functions

In [14]:
# Vector Store Helper Functions using IndexFlatIP (for semantic search)
def init_faiss_indexflatip(embedding_dim:int=768) -> faiss.IndexFlatIP:
  """Create an empty exact inner-product index of width ``embedding_dim``.

  For L2-normalised vectors the inner-product scores equal cosine similarities.
  """
  return faiss.IndexFlatIP(embedding_dim)

def add_embeddings_to_index(index, embeddings: np.ndarray):
  """Append ``embeddings`` to ``index`` as float32; does nothing for an empty array."""
  if embeddings.size == 0:
    return
  index.add(embeddings.astype(np.float32))

def search_faiss_index(index, query_embedding: np.ndarray, k: int = 5):
  """Return ``(distances, indices)`` for the ``k`` nearest neighbours of the query.

  A single 1-D query vector is promoted to a batch of one row; the batch is
  cast to float32 before searching.
  """
  queries = query_embedding.reshape(1, -1) if query_embedding.ndim == 1 else query_embedding
  distances, indices = index.search(queries.astype(np.float32), k)
  return distances, indices

def save_faiss_index(index, filepath: str) -> None:
  """Serialise ``index`` to ``filepath`` with FAISS's own on-disk format."""
  faiss.write_index(index, filepath)

def load_faiss_index(filepath: str):
  """Deserialise and return an index previously written with ``save_faiss_index``."""
  loaded_index = faiss.read_index(filepath)
  return loaded_index

## Embedding and Semantic Search Functions

In [15]:
# Initialize the embedding model once
# The same sentence-embedding model embeds both the document chunks and the user
# queries, so their vectors live in the same space for FAISS search.
embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

# Preprocess the uploaded PDF
def preprocess_pdf(file_path: str, image_dir: str, embedding_model,
                   index_file: str = "index.faiss",
                   chunks_file: str = "chunks.json",
                   use_cache: bool = True) -> Tuple[faiss.IndexFlatIP, List[Dict]]:
  """Extract, caption, chunk, embed and index a PDF for semantic search.

  Args:
      file_path: Path to the PDF.
      image_dir: Scratch directory for extracted images; removed after captioning
          (also when extraction or captioning fails).
      embedding_model: Sentence-embedding model providing ``encode`` (and
          ``get_sentence_embedding_dimension`` for documents with no text).
      index_file: Cache path for the FAISS index.
      chunks_file: Cache path for the chunk dicts (JSON).
      use_cache: Reuse/write the cache files. A cache is reused only if it was
          built from a file with identical bytes (SHA-256 stored next to the
          index as ``<index_file>.sha256``) and index/chunks agree in size.

  Returns:
      (index, chunks): the inner-product FAISS index and the chunk dicts, where
      row ``i`` of the index is the embedding of ``chunks[i]``.
  """
  fingerprint_file = f"{index_file}.sha256"
  fingerprint = _pdf_fingerprint(file_path) if use_cache else None

  # The default cache file names are shared by every PDF, so only reuse the
  # cache when it was built from this exact document; otherwise a new upload
  # would silently be answered from the previous document's index.
  if use_cache and os.path.exists(index_file) and os.path.exists(chunks_file):
    cached_fingerprint = None
    if os.path.exists(fingerprint_file):
      with open(fingerprint_file, "r", encoding="utf-8") as f:
        cached_fingerprint = f.read().strip()
    if cached_fingerprint == fingerprint:
      print("[INFO] Loading cached FAISS index and chunks...")
      index = load_faiss_index(index_file)
      chunks = load_cache(chunks_file)
      if index.ntotal == len(chunks):
        return index, chunks
      print("[INFO] Cached index and chunks are out of sync; rebuilding...")
    else:
      print("[INFO] Cache was built from a different document; rebuilding...")

  # Full processing. The extracted images are only needed for captioning, so
  # remove them even if extraction or captioning fails part-way.
  try:
    page_texts, page_images = extract_pages_text_and_images(file_path, image_dir)
    page_captions = generate_captions_per_page(page_images)
  finally:
    cleanup_images(image_dir)
  merged_pages = merge_text_and_captions(page_texts, page_captions)

  # Chunk the merged pages
  chunks = chunk_text_with_metadata(merged_pages)
  texts = [chunk['content'] for chunk in chunks]

  # Generate embeddings and initialize a FAISS index of matching width
  if texts:
    embeddings = np.asarray(embedding_model.encode(texts, normalize_embeddings=True), dtype=np.float32)
    embedding_dim = embeddings.shape[1]
  else:
    # A document with no text yields no 2-D embedding array to read the width from.
    embeddings = None
    embedding_dim = embedding_model.get_sentence_embedding_dimension()
  index = init_faiss_indexflatip(embedding_dim=embedding_dim)
  if embeddings is not None:
    add_embeddings_to_index(index=index, embeddings=embeddings)

  # Save index, chunks and the fingerprint of the document they came from
  if use_cache:
    save_faiss_index(index, index_file)
    save_cache(chunks, chunks_file)
    try:
      with open(fingerprint_file, "w", encoding="utf-8") as f:
        f.write(fingerprint)
    except OSError as e:
      print(f"Failed to save cache fingerprint to {fingerprint_file}: {e}")

  return index, chunks

def _pdf_fingerprint(file_path: str) -> str:
  """Return the SHA-256 hex digest of the file's bytes (identifies the document)."""
  import hashlib

  digest = hashlib.sha256()
  with open(file_path, "rb") as pdf_file:
    for block in iter(lambda: pdf_file.read(1 << 20), b""):
      digest.update(block)
  return digest.hexdigest()

# Semantic search function that uses preprocessed data
def semantic_search(query, embedding_model, index, chunks, top_k=10):
  """Return up to ``top_k`` chunks most similar to ``query``, best match first.

  Args:
      query (str): The user's question.
      embedding_model: The model used to embed the chunks (must match the index).
      index: FAISS index whose row ``i`` embeds ``chunks[i]``.
      chunks (List[dict]): Chunk dicts aligned with the index rows.
      top_k (int): Maximum number of chunks to return.

  Returns:
      List[dict]: The matched chunks; fewer than ``top_k`` when the index is smaller.
  """
  # Never ask FAISS for more neighbours than the index holds: missing slots come
  # back as label -1, and chunks[-1] would silently return the last chunk.
  k = min(top_k, index.ntotal)
  if k <= 0:
    return []

  # Embed user query
  query_embedding = embedding_model.encode([query], normalize_embeddings=True)

  # Retrieve top matches from FAISS
  _, indices = search_faiss_index(index, query_embedding, k=k)

  # Keep only valid labels (also guards against an index/chunk-list mismatch)
  return [chunks[i] for i in indices[0] if 0 <= i < len(chunks)]

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

In [16]:
# Delete the image directory after captioning
def cleanup_images(image_dir: str):
  """Delete ``image_dir`` and everything in it; failures are reported, not raised."""
  try:
    shutil.rmtree(image_dir)
  except Exception as err:
    print(f"[WARNING] Failed to remove directory {image_dir}: {err}")
  else:
    print(f"[INFO] Removed entire directory: {image_dir}")

## **Test Run**

In [17]:
# # Test run to see the relevant chunks retrieved
# path_to_pdf = "Maasai_Mara.pdf"
# image_dir = "extracted_images"

# #**************** Test for sanity check ******************
# def test_chunking(file_path, image_dir):
#   texts, images = extract_pages_text_and_images(file_path, image_dir)
#   image_descriptions = generate_captions_per_page(images)
#   merged_pages = merge_text_and_captions(texts, image_descriptions)
#   chunks = chunk_text_with_metadata(merged_pages)
#   return chunks, [chunk['content'] for chunk in chunks]

# #******* An example semantic search (for sanity check) ***********
# def semantic_search_example(query_text, k=3):
#   extracted_chunks, texts = test_chunking(path_to_pdf, image_dir)
#   embeddings = embedding_model.encode(texts, normalize_embeddings=True)
#   index = init_faiss_indexflatip(embedding_dim=embeddings.shape[1])
#   add_embeddings_to_index(index, embeddings)
#   query_embedding = embedding_model.encode([query_text], normalize_embeddings=True)
#   distances, indices = search_faiss_index(index, query_embedding, k)
#   results = [extracted_chunks[i] for i in indices[0]]
#   return distances[0], results

# # Example query
# query = "What type of animals are found in Maasai Mara?"
# distances, results = semantic_search_example(query)
# for dist, res in zip(distances, results):
#   print(f"Score: {dist*100:.4f}%, Page: {res['page']}, Chunk ID: {res['chunk_id']}")
#   print(f"Content:\n{res['content']}\n")
#   print("-----")

In [18]:
# import torch
# torch.cuda.get_device_capability(0)[0]

In [19]:
# # Test run to see the prompt format
# query = "Where is Maasai Mara located and what is the size of Maasai Mara?"

# context_texts = [
#     "The Maasai Mara is a large national game reserve located in Narok County, Kenya. It is contiguous with the Serengeti National Park in Tanzania and is named after the Maasai people.",
#     "[Image Description]: A map showing the boundary of the Maasai Mara National Reserve and nearby conservancies.",
#     "The reserve covers approximately 1,510 square kilometers and is renowned for its annual wildebeest migration involving over 1.5 million animals.",
#     "[Image Description]: A photo showing herds of wildebeest crossing the Mara River during migration.",
#     "The area is primarily open grassland with seasonal riverlets and is home to lions, leopards, elephants, and more than 470 bird species."
# ]

# system_instruction = """You are a helpful and precise assistant for question-answering tasks.
#                       Use only the following pieces of retrieved context to answer the question.
#                       If the answer is not found in the provided context, state that the information is not available in the document. Do not use any external knowledge or make assumptions.
#                       """
# rag_prompt_content = ""
# if system_instruction:
#   rag_prompt_content += f"{system_instruction.strip()}\n\n"
# if context_texts:
#   rag_prompt_content += "Context:\n" +"-"+ "\n-".join(context_texts).strip() + "\n\n"
# rag_prompt_content += f"Question: {query.strip()}\nAnswer:"

# print(f"Input text:\n{rag_prompt_content}\n")
# # *** CRITICAL CHANGE: Use messages list and apply_chat_template ***
# messages = [
#     {"role": "user", "content": rag_prompt_content}
#   ]
# prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, return_dict=True)
# print(f"Formatted prompt:\n{prompt}")

In [20]:
# # Corrected messages format for multimodal processor
# messages = [
#     {"role": "user", "content": [{"type": "text", "text": rag_prompt_content}]}
# ]
# print(f"Corrected messages format:\n{messages}")

# input_ids = processor.apply_chat_template(
#     messages,
#     tokenize=True,
#     add_generation_prompt=True,
#     return_dict=True,
#     return_tensors="pt"
# ).to(model.device)
# print(f"Tokenized inputs for the model:\n{input_ids}")

# outputs = model.generate(**input_ids, max_new_tokens=256)
# print(f"Model outputs (tokens):\n{outputs[0]}")

In [21]:
# _len = input_ids["input_ids"].shape[-1]
# outputs_decoded = processor.decode(outputs[0][_len:], skip_special_tokens=True)
# print(f"Model outputs (decoded):\n{outputs_decoded}")

## **Continuation with the main code pipeline**

## Response Generation

In [22]:
def generate_answer(query, retrieved_chunks, model, processor, max_input_tokens: int = 4096):
  """Answer ``query`` with the model, grounded only in the retrieved chunks.

  Args:
      query (str): The user's question.
      retrieved_chunks (List[dict]): Chunks from ``semantic_search``, most relevant
          first; only their ``"content"`` is used.
      model: Generation model exposing ``device`` and ``generate``.
      processor: Matching processor exposing ``apply_chat_template`` and ``decode``.
      max_input_tokens (int): Prompt token budget. If the prompt is longer, the
          least relevant chunks are dropped until it fits (or none are left).

  Returns:
      str: The generated answer, without a leading "Answer:" label or any
      follow-up "Question:" turn the model invents.
  """
  context_texts = [chunk['content'] for chunk in retrieved_chunks]

  # Built from plain literals so the prompt does not carry the source-code
  # indentation a triple-quoted string would embed.
  system_instruction = (
      "You are a helpful and precise assistant for question-answering tasks.\n"
      "Use only the following pieces of retrieved context to answer the question.\n"
      "You may provide the response in a structured markdown format if necessary.\n"
      "If the answer is not found in the provided context, state that the information is not available in the document. Do not use any external knowledge or make assumptions."
  )

  # Fit the prompt into the token budget by dropping the least relevant chunks.
  # Token-level truncation could instead cut off the end of the prompt: the
  # question itself and the model-turn marker.
  while True:
    rag_prompt_content = f"{system_instruction}\n\n"
    if context_texts:
      rag_prompt_content += "Context:\n-" + "\n-".join(context_texts).strip() + "\n\n"
    rag_prompt_content += f"Question: {query.strip()}\nAnswer:"

    messages = [{"role": "user", "content": [{"type": "text", "text": rag_prompt_content}]}]
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True, # Tell the model it is the start of its turn
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    if inputs["input_ids"].shape[-1] <= max_input_tokens or not context_texts:
      break
    context_texts = context_texts[:-1]
  inputs = inputs.to(model.device)

  # Greedy decoding for factual RAG; sampling knobs (temperature/top_p) would be ignored.
  with torch.inference_mode():
    generated_ids = model.generate(**inputs, max_new_tokens=512, do_sample=False)

  # Decode only the newly generated tokens so the prompt is never echoed.
  input_len = inputs["input_ids"].shape[-1]
  answer = processor.decode(generated_ids[0][input_len:], skip_special_tokens=True).strip()

  if answer.startswith("Answer:"):
    answer = answer[len("Answer:"):]
  # Cut off a fabricated follow-up turn rather than keeping it.
  answer = answer.split("\nQuestion:", 1)[0]

  return answer.strip()

## Combining all the steps for a simple response demo

In [23]:
def ask(query: str, pdf_path: str, image_dir: str, embedding_model, model, processor, top_k: int=10) -> str:
  """Answer one question about a PDF end to end: preprocess/index, retrieve, generate.

  Returns an error string instead of answering when the model or processor is missing.
  """
  if processor is None or model is None:
    return "[ERROR] Model or processor is not loaded. Cannot generate answer."

  index, chunks = preprocess_pdf(pdf_path, image_dir, embedding_model)
  best_chunks = semantic_search(query, embedding_model, index, chunks, top_k)
  return generate_answer(query, best_chunks, model, processor)

In [25]:
from pprint import pprint

# End-to-end smoke test: preprocess an example PDF (or load its cached index)
# and answer one question about it.
pdf_path = "Maasai_Mara.pdf" # an example pdf
image_dir = "extracted_images"

query = "Which country is Maasai Mara located in?"
start = time()
response = ask(
    query=query,
    pdf_path=pdf_path,
    image_dir=image_dir,
    embedding_model=embedding_model,
    model=model,
    processor=processor
)
# The measured time covers everything `ask` does, including PDF preprocessing
# (extraction, captioning, embedding) when no usable cache exists.
print(f"\n[INFO!!] Time taken for inference: {(time() - start):.4f} seconds")
print(f"\nLLM Response:\n{'=='*20}\n")
pprint(response)

The following generation flags are not valid and may be ignored: ['top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


[INFO] Removed entire directory: extracted_images


The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.



[INFO!!] Time taken for inference: 189.6682 seconds

LLM Response:

'Maasai Mara is located in Kenya, Rift Valley Province.'


## **Gradio Implementation**


In [26]:
def generate_answer_stream(query, retrieved_chunks, model, processor, max_input_tokens: int = 4096):
  """Stream an answer to ``query`` grounded only in ``retrieved_chunks``.

  Generation runs in a background thread while this generator yields text as
  the streamer produces it, so a UI can render the answer progressively.

  Args:
      query (str): The user's question.
      retrieved_chunks (List[dict]): Chunks from ``semantic_search``, most relevant first.
      model: Generation model exposing ``device`` and ``generate``.
      processor: Matching processor exposing ``apply_chat_template`` and ``tokenizer``.
      max_input_tokens (int): Prompt token budget; the least relevant chunks are
          dropped until the prompt fits (or none are left).

  Yields:
      str: The full answer generated so far (not just the newest fragment).

  Raises:
      Exception: Whatever ``model.generate`` raised in the worker thread,
      re-raised after the stream ends.
  """
  context_texts = [chunk['content'] for chunk in retrieved_chunks]

  # Built from plain literals so the prompt does not carry source-code indentation.
  system_instruction = (
      "You are a helpful and precise assistant for question-answering tasks.\n"
      "Use only the following pieces of retrieved context to answer the question.\n"
      "You may provide the response in a structured markdown response if necessary.\n"
      "If the answer is not found in the provided context, state that the information is not available in the document. Do not use any external knowledge or make assumptions."
  )

  # Fit the prompt into the token budget by dropping the least relevant chunks.
  # Token-level truncation could instead cut off the end of the prompt: the
  # question itself and the model-turn marker.
  while True:
    rag_prompt_content = f"{system_instruction}\n\n"
    if context_texts:
      rag_prompt_content += "Context:\n-" + "\n-".join(context_texts).strip() + "\n\n"
    rag_prompt_content += f"Question: {query.strip()}\nAnswer:"

    messages = [{"role": "user", "content": [{"type": "text", "text": rag_prompt_content}]}]
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True, # Tell the model it is the start of its turn
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    if inputs["input_ids"].shape[-1] <= max_input_tokens or not context_texts:
      break
    context_texts = context_texts[:-1]
  inputs = inputs.to(model.device)

  # Decode options are plain keyword arguments of the streamer; wrapping them in
  # `decode_kwargs={...}` is not unpacked, so special tokens (e.g. the
  # end-of-turn marker) would leak into the streamed text.
  streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)

  errors = []

  def _run_generation():
    try:
      # inference_mode is thread-local, so it must be entered in the worker thread.
      with torch.inference_mode():
        model.generate(**inputs, streamer=streamer, use_cache=True, max_new_tokens=512)
    except Exception as exc:
      errors.append(exc)
      streamer.end()  # unblock the consumer loop below

  # generate() must run concurrently with the consumer loop; called inline it
  # would finish the whole answer before the first fragment is yielded.
  worker = Thread(target=_run_generation, daemon=True)
  worker.start()

  accumulated = ""
  try:
    for new_text in streamer:
      accumulated += new_text
      yield accumulated
  finally:
    worker.join()
    # Free memory after streaming is complete
    clear_gpu_cache()

  if errors:
    raise errors[0]

### Setting up the Gradio Interface

In [27]:
# Global state shared across chats
state = {
    "index": None,
    "chunks": None,
    "pdf_path": None,
    "image_dir": None,
}

def handle_pdf_upload(file):
    """Process an uploaded PDF and store its index/chunks in ``state`` for chatting.

    Args:
        file: The Gradio upload value, either a path string or an object whose
            ``.name`` is the path; ``None`` when nothing was uploaded.

    Returns:
        str: A status message for the UI (errors are reported, not raised).
    """
    if file is None:
        return "[ERROR] No file uploaded."

    pdf_path = getattr(file, "name", file)
    image_dir = "extracted_images"  # You can make this dynamic if needed

    # Drop the previous document first so a failed upload cannot leave the chat
    # answering from a stale index.
    state["index"] = None
    state["chunks"] = None
    state["pdf_path"] = pdf_path
    state["image_dir"] = image_dir

    try:
        index, chunks = preprocess_pdf(pdf_path, image_dir, embedding_model)
    except Exception as e:
        return f"[ERROR] Failed to process document: {e}"

    state["index"] = index
    state["chunks"] = chunks
    return "✅ Document processed and ready for Q&A!"

def chat_streaming(message, history):
    """Gradio chat handler: retrieve chunks for ``message`` and stream the answer.

    ``history`` is required by the ChatInterface callback signature but unused;
    each question is answered from the processed PDF alone.
    """
    if state["index"] is None or state["chunks"] is None:
        yield "[ERROR] Please upload and process a PDF first."
        # Stop here: without an index the search below would fail.
        return

    # Perform semantic search
    retrieved_chunks = semantic_search(message, embedding_model, state["index"], state["chunks"], top_k=10)

    # Stream the answer
    yield from generate_answer_stream(message, retrieved_chunks, model, processor)

# Help text shown under the chat title.
description = """
 Remember to be specific when querying for better response.
 📖🧐
"""

# Layout: an upload row (file picker + process button) with a status box,
# followed by a streaming chat interface backed by `chat_streaming`.
with gr.Blocks() as demo:
    gr.Markdown("## 📚Multimodal RAG System\nUpload a PDF (≤50 pages recommended) and ask questions about it.")

    with gr.Row():
      file_input = gr.File(label="📂Upload PDF")
      upload_button = gr.Button("🔁Process PDF")

    # Clicking the button runs the PDF pipeline; its returned message is shown in the status box.
    upload_status = gr.Textbox(label="Upload Status", interactive=False)
    upload_button.click(handle_pdf_upload, inputs=file_input, outputs=upload_status)

    chat = gr.ChatInterface(
            fn=chat_streaming,
            type="messages",
            title="📄Ask Questions from PDF",
            description=description,
            examples=[["What is this document about?"]]
        )
    chat.queue()
# NOTE(review): on hosted notebooks Gradio enables share=True automatically (see the
# output below), which exposes the app on a public URL.
demo.launch()

It looks like you are running Gradio on a hosted Jupyter notebook, which requires `share=True`. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://7832b107fae909c28e.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)




**Suggested Improvements**
- Make the multimodal RAG system agentic by integrating LangChain and LangGraph agentic features.
- Enable internet search in case the provided information is not enough for a coherent or proper response.