- This notebook demonstrates how to programmatically upload the multimodal RAG system to a Hugging Face Space from within a notebook.

In [None]:
# !pip install -q PyMuPDF gradio faiss-cpu
# !pip install -qU transformers accelerate bitsandbytes

File structure of the RAG system that will be published to Hugging Face Space:
```
  setup/
    |-- multimodal_rag/
        |-- app.py
        |-- main.py
        |-- model_setup.py
        |-- utils.py
        |-- README.md
        |-- requirements.txt
        |-- cache_dir/        (created at runtime by app.py)
```

- `app.py` - contains the gradio setup and interface
- `main.py` - for PDF processing, chunking, embeddings, semantic search
- `model_setup.py` - loads Gemma3, processor and embedding model.
- `utils.py` - helper functions for FAISS, caching and cleanup.
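- `cache_dir/` - created at runtime by `app.py`. It holds the cached FAISS index (`index.faiss`) and chunks (`chunks.json`) of the processed PDF, plus a scratch folder for the images extracted during processing.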
- `README.md` - brief documentation of the system. It begins with the YAML block of configuration details that Hugging Face Spaces requires when publishing.
- `requirements.txt` - lists the Python packages that must be installed to run the system.

***NOTE:*** The version uploaded to Hugging Face does not use a GPU. The free-tier Hugging Face Space provides only a CPU, so the model is loaded on the CPU.

## Create directory path for the files

In [None]:
from pathlib import Path

# Target folder for the %%writefile cells below; parents=True also creates "setup/".
dir_name = Path("setup") / "multimodal_rag"
dir_name.mkdir(parents=True, exist_ok=True)

## `utils.py`

In [None]:
%%writefile setup/multimodal_rag/utils.py
"""Contains helper functions that are used in the RAG pipeline."""

import os
import gc
import json
import torch
import shutil
from typing import List, Dict
import faiss
import numpy as np

def save_cache(data: List[Dict], filepath: str) -> None:
  """Write `data` (e.g. the chunk dicts) to `filepath` as indented UTF-8 JSON.

  Best effort: any error is printed and swallowed, so a failed cache write
  never aborts the pipeline.
  """
  try:
    with open(filepath, mode='w', encoding='utf-8') as out_file:
      json.dump(data, out_file, indent=2, ensure_ascii=False)
  except Exception as err:
    print(f"Failed to save cache to {filepath}: {err}")

def load_cache(filepath: str) -> List[Dict]:
  """Read a JSON cache file written by `save_cache`.

  Returns the decoded data, or an empty list when the file does not exist or
  cannot be read/parsed (the failure is printed, not raised).
  """
  if not os.path.exists(filepath):
    return []
  try:
    with open(filepath, encoding='utf-8') as in_file:
      return json.load(in_file)
  except Exception as err:
    print(f"Failed to load cache from {filepath}: {err}")
    return []


# Vector Store Helper Functions using IndexFlatIP (for semantic search)
def init_faiss_indexflatip(embedding_dim:int) -> faiss.IndexFlatIP:
  """Create an empty exact (brute-force) inner-product index of width `embedding_dim`.

  When the stored and query vectors are L2-normalised, inner product equals
  cosine similarity.
  """
  return faiss.IndexFlatIP(embedding_dim)

def add_embeddings_to_index(index, embeddings: np.ndarray):
  """Append `embeddings` (cast to float32) to `index`; an empty array is a no-op."""
  if embeddings.size == 0:
    return
  index.add(embeddings.astype(np.float32))

def search_faiss_index(index, query_embedding: np.ndarray, k: int = 5):
  """Search `index` for the `k` nearest stored vectors.

  A single 1-D query vector is promoted to a batch of one; the query is cast to
  float32 before searching. Returns the `(distances, indices)` pair produced by
  `index.search`. Per the FAISS docs, result slots beyond the number of stored
  vectors are filled with index -1, so callers should filter those out.
  """
  queries = query_embedding.reshape(1, -1) if query_embedding.ndim == 1 else query_embedding
  distances, indices = index.search(queries.astype(np.float32), k)
  return distances, indices

def save_faiss_index(index, filepath: str):
  """Persist `index` to `filepath` via faiss.write_index."""
  faiss.write_index(index, filepath)

def load_faiss_index(filepath: str):
  """Read back and return an index previously written by `save_faiss_index`."""
  loaded_index = faiss.read_index(filepath)
  return loaded_index


# Deleting extracted images directory after captioning
def cleanup_images(image_dir: str):
  """Recursively delete `image_dir` (the extracted PDF images).

  Best effort: any error, including a missing directory, is printed as a
  warning instead of being raised.
  """
  try:
    shutil.rmtree(image_dir)
  except Exception as err:
    print(f"[WARNING] Failed to delete some images in {image_dir}: {err}")
  else:
    print(f"[INFO] Cleaned up extracted images directory: {image_dir}")

# Just being agnostic because my space may only be using CPU but why not?
def clear_gpu_cache():
  """Release PyTorch's cached CUDA memory when a GPU exists, then run Python's garbage collector."""
  cuda_present = torch.cuda.is_available()
  if cuda_present:
    torch.cuda.empty_cache()  # returns cached, currently-unused CUDA blocks
  gc.collect()

Overwriting setup/multimodal_rag/utils.py


## `model_setup.py`

In [None]:
# Authenticate this notebook session with the Hugging Face Hub; the upload
# cells at the end of the notebook rely on this login.
from huggingface_hub import login
from google.colab import userdata

HF_TOKEN = userdata.get("HF_TOKEN_2")  # token stored as a Colab secret named "HF_TOKEN_2"
login(token=HF_TOKEN)

In [None]:
%%writefile setup/multimodal_rag/model_setup.py
"""loading the models to be used by the Mulltimodal RAG system."""

import torch
import gc

from sentence_transformers import SentenceTransformer
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, BitsAndBytesConfig
# from accelerate import disk_offload
from utils import clear_gpu_cache

# device = "cuda" if torch.cuda.is_available() else "cpu"

# Embedding model
embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# Gemma3 quantization config
model_name = "google/gemma-3-4b-it"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    # llm_int8_enable_fp32_cpu_offload=True  # Allow offloading
)

# Load Gemma3
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="cpu", # Explicitly avoid meta tensors
    # quantization_config=bnb_config,
    # low_cpu_mem_usage=False, # To avoid lazy meta tensors
    # attn_implementation="sdpa"
)
# disk_offload(model=model, offload_dir="offload")
model.to("cpu") # Explicitly load to CPU
model.eval()

# Processor
processor = AutoProcessor.from_pretrained(model_name, use_fast=True)

clear_gpu_cache()
gc.collect()

Overwriting setup/multimodal_rag/model_setup.py


## `main.py`

In [None]:
%%writefile setup/multimodal_rag/main.py
"""Main Mulitmodal-RAG pipeline script."""

import os
import torch
import fitz #PyMuPDF
import faiss
import re
import gc
import numpy as np


from typing import List, Dict, Tuple
from PIL import Image

from langchain.text_splitter import RecursiveCharacterTextSplitter
from transformers import TextIteratorStreamer

from utils import (
    save_cache, load_cache,
    init_faiss_indexflatip, add_embeddings_to_index,
    search_faiss_index, save_faiss_index, load_faiss_index, cleanup_images, clear_gpu_cache
)

from model_setup import embedding_model, model, processor

# Cap PyTorch's intra-op CPU thread pool.
torch.set_num_threads(4)

# Resolved for completeness; the calls below move inputs to "cpu" explicitly.
device = "cpu" if not torch.cuda.is_available() else "cuda"

# Function to extract text and images from each page of the PDF
# This function uses PyMuPDF (fitz) to extract text and images from each page
def extract_pages_text_and_images(pdf_path, image_dir):
  """Extract the text and the embedded images of every page of a PDF.

  Each embedded image is written to `image_dir` (created if missing) as
  ``page_<n>_img_<i>.<ext>``. Here ``n`` is the 1-based page number and
  ``i`` is the image's position on that page.

  Args:
    pdf_path: Path of the PDF to open with PyMuPDF.
    image_dir: Directory that receives the extracted image files.

  Returns:
    ``(page_texts, page_images)``. ``page_texts[p]`` is the text of page
    ``p``, and ``page_images[p]`` is the list of image paths written for it.
  """
  doc = fitz.open(pdf_path)
  page_texts = []
  page_images = []
  try:
    os.makedirs(image_dir, exist_ok=True)

    for page_num in range(len(doc)):
      page = doc.load_page(page_num)
      text = page.get_text()

      # Write every image on this page to disk and remember its path
      images = []
      for img_index, img in enumerate(page.get_images(full=True)):
        xref = img[0]  # first field of a get_images() entry is the image's xref
        base_image = doc.extract_image(xref)
        image_filename = f"page_{page_num + 1}_img_{img_index}.{base_image['ext']}"
        image_path = os.path.join(image_dir, image_filename)
        with open(image_path, "wb") as img_file:
          img_file.write(base_image["image"])
        images.append(image_path)

      page_texts.append(text)
      page_images.append(images)
  finally:
    # Always release the document: on error, and also for a zero-page document
    # (a truthiness check such as `if doc:` would be False for it).
    doc.close()

  return page_texts, page_images

# Generate image descriptions using the Gemma3 model.
# Called once per page, sequentially, by generate_captions_per_page.
def generate_image_descriptions(image_paths):
  """Generate factual text descriptions of images (figures, tables, ...) with Gemma3.

  Images smaller than 32x32 pixels are skipped. If an image cannot be opened
  or its captioning fails, a placeholder caption is recorded for it instead of
  aborting the whole batch.

  Args:
    image_paths: Image file paths to describe.

  Returns:
    A list of ``{"image_path": ..., "caption": ...}`` dicts, one per image
    that was not skipped, in input order. Returns ``[]`` if the model or
    processor is not loaded.
  """
  placeholder = "<---image---> (Captioning failed)"
  captions = []
  if not processor or not model:
    print("[ERROR] Model or Processor not loaded. Cannot generate image descriptions.")
    return []
  for image_path in image_paths:
    try:
      # Read the pixels and close the file right away; convert() returns an
      # in-memory RGB copy that stays valid after the handle is closed.
      with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    except Exception as e:
      print(f"[⚠️ ERROR]: Failed to open image {image_path}: {e}")
      captions.append({"image_path": image_path, "caption": placeholder})
      continue

    width, height = image.size
    if width < 32 or height < 32: # Filtering out smaller images that may disrupt the process
      continue
    messages = [
        {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
        {"role": "user", "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": "Describe the factual content visible in the image. Be concise and accurate as the descriptions will be used for retrieval."}
        ]}
    ]
    try:
      inputs = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True,
                                             return_dict=True, return_tensors="pt").to("cpu", dtype=torch.bfloat16)
      input_len = inputs["input_ids"].shape[-1]
      with torch.inference_mode():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=False,
            # NOTE(review): "offloaded_static" offloads an accelerator KV cache to CPU;
            # confirm it works on a CPU-only Space, otherwise every image gets the placeholder.
            cache_implementation="offloaded_static"
            )
      # generate() returns prompt + new tokens. Decode only the new tokens so the chat
      # prompt is not echoed into the caption (and from there into the search index).
      caption = processor.decode(generated_ids[0][input_len:], skip_special_tokens=True).strip()
      captions.append({"image_path": image_path, "caption": caption})
    except Exception as e:
      print(f"[⚠️ ERROR]: Failed to generate caption for image {image_path}: {e}")
      captions.append({"image_path": image_path, "caption": placeholder})
    finally:
      gc.collect()
      clear_gpu_cache()
  return captions

# Cleaning the captions from the extracted images.
# prefix_re matches from "model\n" up to and including the first blank line after
# "sent:", i.e. the echoed prompt/header that precedes the actual caption text.
prefix_re = re.compile(
    r"model\s*\n.*?\bsent:\s*\n\n",
    flags=re.IGNORECASE | re.DOTALL,
)

def clean_caption(raw: str) -> str:
  """Strip echoed prompt text / preamble from a raw model caption.

  Strategies, tried in order on the whitespace-stripped text:
    1. If `prefix_re` matches, keep only what follows the first match.
    2. Otherwise, if a "**" (bold markdown header) occurs, keep from the first one.
    3. Otherwise drop the first paragraph (text before the first blank line);
       a single-paragraph caption is returned whole.
  The result is always stripped of surrounding whitespace.
  """
  text = raw.strip()

  header_match = prefix_re.search(text)
  if header_match is not None:
    return text[header_match.end():].strip()

  if "**" in text:
    return text[text.index("**"):].strip()

  _first_paragraph, separator, remainder = text.partition("\n\n")
  return (remainder if separator else text).strip()  # may still carry some leading noise

# Generate captions for all images on each page
def generate_captions_per_page(page_image_paths_list):
  """Caption the images of every page.

  Args:
    page_image_paths_list: One list of image paths per page.

  Returns:
    One list of caption strings per page, aligned with the input pages.
  """
  return [
      [entry['caption'] for entry in generate_image_descriptions(paths_on_page)]
      for paths_on_page in page_image_paths_list
  ]

# Merge text and captions for each page into a single string per page
def merge_text_and_captions(page_texts, page_captions):
  """Combine each page's text with its image captions.

  Each page becomes its stripped text, followed by one
  ``[Image Description]: <caption>`` paragraph per caption. Pages are paired
  positionally; extra items in the longer list are ignored (zip semantics).
  """
  merged = []
  for text, captions in zip(page_texts, page_captions):
    parts = [text.strip() + "\n\n"]
    parts.extend(f"[Image Description]: {cap}\n\n" for cap in captions)
    merged.append("".join(parts))
  return merged

# Split each merged page into overlapping text chunks tagged with metadata
def chunk_text_with_metadata(merged_pages):
  """
  Split every page's merged content (text + image captions) into chunks and tag each chunk.

  Splitting is per page, so a chunk never spans two pages. Chunks are at most
  ~1000 characters with 200 characters of overlap. The splitter tries
  paragraph, line, sentence, word and finally character boundaries.

  Args:
      merged_pages (List[str]): One string per page.

  Returns:
      List[dict]: One dict per chunk with keys ``content``, ``page`` (1-based),
      ``chunk_id`` (0-based, global across the document),
      ``chunk_number_on_page`` (1-based) and ``type``.
  """
  # NOTE(review): add_start_index appears to only affect Document outputs
  # (create_documents/split_documents), not split_text — confirm before relying on it.
  splitter = RecursiveCharacterTextSplitter(
      separators=["\n\n", "\n", ".", " ", ""],
      chunk_size=1000,
      chunk_overlap=200,
      add_start_index=True,
  )

  all_chunks = []
  for page_num, page_content in enumerate(merged_pages, start=1):
    for chunk_num, chunk_text in enumerate(splitter.split_text(page_content), start=1):
      all_chunks.append({
          "content": chunk_text,
          "page": page_num,
          "chunk_id": len(all_chunks),  # running count == global 0-based id
          "chunk_number_on_page": chunk_num,
          "type": "extracted_texts_and_captions_descriptions",
      })
  return all_chunks

# Preprocess the uploaded PDF
def preprocess_pdf(file_path: str, image_dir: str, embedding_model,
                   index_file: str = "index.faiss",
                   chunks_file: str = "chunks.json",
                   use_cache: bool = True) -> Tuple[faiss.IndexFlatIP, List[Dict]]:
  """Build (or load from cache) the FAISS index and chunk list for a PDF.

  Pipeline: extract page text and images -> caption the images -> merge the
  captions into the page text -> chunk -> embed (L2-normalised) -> index.

  Args:
    file_path: Path of the PDF to process.
    image_dir: Scratch directory for extracted images; removed after captioning.
    embedding_model: Model exposing ``encode(texts, normalize_embeddings=True)``.
    index_file: Cache path of the FAISS index.
    chunks_file: Cache path of the chunk dicts (JSON).
    use_cache: If True, load both cache files when they exist and save them
      after a fresh build. If False, existing cache files are deleted and
      nothing is saved.

  Returns:
    ``(index, chunks)``, where ``chunks[i]`` is the chunk whose embedding is
    vector ``i`` of ``index``.

  Raises:
    FileNotFoundError: If `file_path` does not exist.
    ValueError: If the PDF yields no text or captions to index.
  """
  if not os.path.exists(file_path):
    raise FileNotFoundError(f"PDF not found at {file_path}")

  # Reuse the cache to save compute time and resources
  cache_present = os.path.exists(index_file) and os.path.exists(chunks_file)
  if use_cache and cache_present:
    print("[INFO] Loading cached FAISS index and chunks...")
    return load_faiss_index(index_file), load_cache(chunks_file)

  # The cache is not being used: remove any stale or partial cache files
  for stale_file in (index_file, chunks_file):
    if os.path.exists(stale_file):
      os.remove(stale_file)

  try:
    try:
      page_texts, page_images = extract_pages_text_and_images(file_path, image_dir)
    except Exception as e:
      print(f"Error reading PDF: {e}")
      raise
    page_captions = generate_captions_per_page(page_images)
  finally:
    # Extracted images are only needed for captioning; delete them even if a step failed
    if os.path.isdir(image_dir):
      cleanup_images(image_dir)

  merged_pages = merge_text_and_captions(page_texts, page_captions)
  chunks = chunk_text_with_metadata(merged_pages)
  if not chunks:
    raise ValueError(f"No text or image captions could be extracted from {file_path}; nothing to index.")
  texts = [chunk['content'] for chunk in chunks]

  # Normalised embeddings make inner-product search equivalent to cosine similarity
  embeddings = embedding_model.encode(texts, normalize_embeddings=True)
  embeddings = embeddings.astype(np.float32)  # FAISS expects float32

  index = init_faiss_indexflatip(embedding_dim=embeddings.shape[1])
  add_embeddings_to_index(index=index, embeddings=embeddings)

  if use_cache:
    save_faiss_index(index, index_file)
    save_cache(chunks, chunks_file)

  return index, chunks

# Semantic search function that uses the preprocessed index and chunks
def semantic_search(query, embedding_model, index, chunks, top_k=10):
  """Return the chunks most similar to `query`, best match first.

  Args:
    query: The user's question.
    embedding_model: Model exposing ``encode(texts, normalize_embeddings=True)``.
    index: FAISS index whose i-th vector is the embedding of ``chunks[i]``.
    chunks: Chunk dicts, positionally aligned with the vectors in `index`.
    top_k: Maximum number of chunks to return.

  Returns:
    Up to `top_k` chunk dicts. Fewer are returned when the index holds fewer
    than `top_k` vectors.
  """
  query_embedding = embedding_model.encode([query], normalize_embeddings=True)

  _, indices = search_faiss_index(index, query_embedding, k=top_k)

  # FAISS pads missing results with -1 when the index holds fewer than k vectors;
  # chunks[-1] would silently return the last chunk again, so drop those slots.
  return [chunks[i] for i in indices[0] if 0 <= i < len(chunks)]


# Generate answer for Gradio interface
def generate_answer_stream(query, retrieved_chunks, model, processor):
  """Stream an answer to `query` that is grounded in `retrieved_chunks`.

  Builds a single user-turn prompt (instructions + bulleted context + question)
  and runs ``model.generate`` in a background thread. Each time the streamer
  produces new text, the answer accumulated so far is yielded, so the UI can
  render tokens as they arrive.

  Args:
    query: The user's question.
    retrieved_chunks: Chunk dicts (each with a ``'content'`` key) used as context.
    model: Generation model exposing ``generate(**inputs, streamer=...)``.
    processor: Processor providing ``apply_chat_template`` and ``tokenizer``.

  Yields:
    The cumulative answer text.

  Raises:
    Any exception raised by ``model.generate``, re-raised once streaming stops.
  """
  from threading import Thread

  context_texts = [chunk['content'] for chunk in retrieved_chunks]

  # Combine system instruction, context, and query into a single string for the user role
  system_instruction = """You are a helpful and precise assistant for question-answering tasks.
                      Use only the following pieces of retrieved context to answer the question.
                      You may provide the response in a structured markdown response if necessary.
                      If the answer is not found in the provided context, state that the information is not available in the document. Do not use any external knowledge or make assumptions.
                      """

  # Build the core prompt string; apply_chat_template adds the turn markers
  rag_prompt_content = ""
  if system_instruction:
      rag_prompt_content += f"{system_instruction.strip()}\n\n"
  if context_texts:
      rag_prompt_content += "Context:\n" +"-"+ "\n-".join(context_texts).strip() + "\n\n"
  rag_prompt_content += f"Question: {query.strip()}\nAnswer:"

  messages = [
      {"role": "user", "content": [{"type": "text", "text": rag_prompt_content}]}
  ]

  inputs = processor.apply_chat_template(
      messages,
      add_generation_prompt=True, # Tell the model it is the start of its turn
      tokenize=True,
      return_dict=True,
      return_tensors="pt",
      truncation=True,
      max_length=4096  # NOTE(review): truncation removes tokens from one end; confirm the question survives long contexts
  ).to("cpu")

  # Decode options are plain keyword arguments of TextIteratorStreamer (**decode_kwargs);
  # passing them as decode_kwargs={...} would leave skip_special_tokens unset.
  streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
  generation_errors = []

  def _run_generation():
    # Producer side: runs in a worker thread so the loop below can consume tokens as they arrive.
    try:
      # Inference mode is thread-local, so it must be entered inside the worker thread.
      with torch.inference_mode():
        model.generate(**inputs, streamer=streamer, use_cache=True, max_new_tokens=512)
    except Exception as exc:
      generation_errors.append(exc)
      streamer.end()  # unblock the consumer loop, which would otherwise wait forever

  worker = Thread(target=_run_generation, daemon=True)
  worker.start()

  accumulated = ""
  try:
    for new_text in streamer:
      accumulated += new_text
      yield accumulated
    worker.join()
    if generation_errors:
      raise generation_errors[0]
  finally:
    # Free memory once streaming has finished (or the consumer stopped early)
    clear_gpu_cache()
    gc.collect()

Overwriting setup/multimodal_rag/main.py


## `app.py`

In [None]:
%%writefile setup/multimodal_rag/app.py
"""Gradio setup for the Multimodal RAG system."""
import os
import torch
import shutil
import gradio as gr
# import gc

from utils import save_cache, load_cache, save_faiss_index, load_faiss_index
from model_setup import embedding_model, model, processor
from main import preprocess_pdf, semantic_search, generate_answer_stream

torch.set_num_threads(4)  # limit PyTorch's CPU worker threads

# Directory holding the cached FAISS index and chunk list of the processed PDF
CACHE_DIR = "cache_dir"
os.makedirs(CACHE_DIR, exist_ok=True)

INDEX_FILE, CHUNKS_FILE = (os.path.join(CACHE_DIR, name) for name in ("index.faiss", "chunks.json"))

# Module-level state shared by every chat session: one active document at a time
state = dict(index=None, chunks=None, pdf_path=None)

def handle_pdf_upload(file):
  """Index an uploaded PDF, or reuse its cache, and make it the active document.

  The cache in CACHE_DIR is tied to the PDF it was built from by a SHA-256
  fingerprint of the file's bytes. Uploading a different PDF therefore
  rebuilds the index instead of silently answering from the previous one.

  Args:
    file: The ``gr.File`` value. This is either a path string or an object
      whose ``.name`` is the path; ``None`` when nothing was uploaded.

  Returns:
    A status message for the upload-status textbox.
  """
  import hashlib

  if file is None:
      return "[ERROR] No file uploaded."

  pdf_path = getattr(file, "name", file)
  state["pdf_path"] = pdf_path
  state["image_dir"] = os.path.join(CACHE_DIR, "extracted_images")
  fingerprint_file = os.path.join(CACHE_DIR, "source.sha256")

  try:
    # "Clear Cache" deletes the whole directory, so make sure it exists again
    os.makedirs(CACHE_DIR, exist_ok=True)

    digest = hashlib.sha256()
    with open(pdf_path, "rb") as pdf_file:
      for block in iter(lambda: pdf_file.read(1 << 20), b""):
        digest.update(block)
    pdf_hash = digest.hexdigest()

    cached_hash = None
    if os.path.exists(fingerprint_file):
      with open(fingerprint_file, encoding="utf-8") as fh:
        cached_hash = fh.read().strip()

    cache_ready = os.path.exists(INDEX_FILE) and os.path.exists(CHUNKS_FILE)
    if cache_ready and cached_hash == pdf_hash:
      # Cache was built from this exact PDF: load it
      state["index"] = load_faiss_index(INDEX_FILE)
      state["chunks"] = load_cache(CHUNKS_FILE)
      return "✅ Loaded from cache and ready for Q&A!"

    # Cache is missing or belongs to another PDF: drop it so preprocess_pdf rebuilds it,
    # and forget the previous document so chat cannot answer from it meanwhile.
    for stale in (INDEX_FILE, CHUNKS_FILE, fingerprint_file):
      if os.path.exists(stale):
        os.remove(stale)
    state["index"], state["chunks"] = None, None

    # With use_cache=True, preprocess_pdf saves the index and chunks files itself
    index, chunks = preprocess_pdf(
        pdf_path,
        state["image_dir"],
        embedding_model=embedding_model,
        index_file=INDEX_FILE,
        chunks_file=CHUNKS_FILE,
        use_cache=True)
    state["index"] = index
    state["chunks"] = chunks

    with open(fingerprint_file, "w", encoding="utf-8") as fh:
      fh.write(pdf_hash)

    return "✅ Document processed and ready for Q&A!"
  except Exception as e:
    return f"[⚠️ ERROR] Failed to process document: {e}"

def chat_streaming(message, history):
  """Answer `message` from the active document, yielding partial answers.

  `history` is supplied by gr.ChatInterface and is not used; every question is
  answered from freshly retrieved context only.
  """
  # Both the index and a non-empty chunk list are needed to retrieve anything
  if state["index"] is None or not state["chunks"]:
    yield "[ERROR] Please upload and process a PDF first."
    return

  # Perform semantic search
  retrieved_chunks = semantic_search(message, embedding_model, state["index"], state["chunks"], top_k=10)

  # Stream the answer
  for partial in generate_answer_stream(message, retrieved_chunks, model, processor):
    yield partial

# Clear the cache before uploading another document, to prevent stale cache retrieval
def manual_clear_cache():
  """Delete CACHE_DIR and everything in it, and forget the active document.

  The directory is recreated empty afterwards, so later writes into it do not
  depend on another step re-creating it.

  Returns:
    A status message for the cache-status textbox.
  """
  if not os.path.exists(INDEX_FILE) or not os.path.exists(CHUNKS_FILE):
    return "⚠️No cache files exists to clear."

  if os.path.exists(CACHE_DIR):
    shutil.rmtree(CACHE_DIR)
  os.makedirs(CACHE_DIR, exist_ok=True)

  state["index"], state["chunks"] = None, None
  state["pdf_path"] = None
  return "✅ Cache cleared! You can upload a new document now."

description = """
 Remember to be specific when querying for better response.
 📖🧐
"""

with gr.Blocks() as demo:
    gr.Markdown("## 📚Multimodal RAG System\nUpload a PDF (≤50 pages recommended) and ask questions about it.")

    with gr.Row():
      file_input = gr.File(label="📂Upload PDF")
      upload_button = gr.Button("🔁Process PDF")

    with gr.Row():
      clear_cache_button = gr.Button("🧹 Clear Cache")
      clear_cache_status = gr.Textbox(label="Cache Clear Status", interactive=False)

    upload_status = gr.Textbox(label="Upload Status", interactive=False)
    upload_button.click(handle_pdf_upload, inputs=file_input, outputs=upload_status)
    clear_cache_button.click(manual_clear_cache, outputs=clear_cache_status)

    chat = gr.ChatInterface(
            fn=chat_streaming,
            type="messages",
            title="📄Ask Questions from PDF",
            description=description,
            examples=[["What is this document about?"]]
        )
    chat.queue()
demo.launch()

Overwriting setup/multimodal_rag/app.py


## `README.md`

In [None]:
%%writefile setup/multimodal_rag/README.md
---
title: Multimodal RAG System 📄
emoji: 📚
colorFrom: purple
colorTo: pink
sdk: gradio
sdk_version: 5.43.1
app_file: app.py
pinned: false
license: mit
---

# Multimodal RAG System 📖

A **Multimodal Retrieval-Augmented Generation (RAG) system** that allows users to upload PDFs and ask questions based on the text, images, and tables in the document. Uses **Gemma3** for image captioning and multimodal text generation and **SentenceTransformer** + **FAISS** for semantic search.

## Features

- Extracts text and images from PDF documents.
- Generates factual captions for images and tables.
- Chunks the combined text + captions for efficient retrieval.
- Stores embeddings in a FAISS index for fast semantic search.
- Streams answers from the LLM using Gradio interface.
- Runs Gemma3 on CPU in bfloat16. A bitsandbytes 4-bit quantization config is defined but currently disabled, because the free CPU Space tier has no GPU.

The **[google/gemma-3-4b-it](https://huggingface.co/google/gemma-3-4b-it)** model is used both to generate descriptions of the extracted images and to generate the RAG system's answers.

## `requirements.txt`

In [None]:
%%writefile setup/multimodal_rag/requirements.txt
torch
transformers
numpy
gradio
pillow
PyMuPDF
faiss-cpu
sentence-transformers
langchain
bitsandbytes
accelerate

Overwriting setup/multimodal_rag/requirements.txt


In [None]:
!ls setup/multimodal_rag/

app.py	main.py  model_setup.py  README.md  requirements.txt  utils.py


## Uploading to Hugging Face Hub Space

In [None]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="huggingface_hub")

In [None]:
from huggingface_hub import create_repo, get_full_repo_name, upload_file, upload_folder

# Upload parameters
LOCAL_FOLDER_TO_UPLOAD = "setup/multimodal_rag"
HF_TARGET_SPACE_NAME = "multimodal_rag_system"
HF_REPO_TYPE = "space"
HF_SPACE_SDK = "gradio"
# HF_TOKEN = ""  # not passed below; authentication comes from the earlier login() cell

# 1) Create the public Gradio Space, or reuse it if it already exists.
print(f"ℹ️ Creating repo on HF Hub with name: {HF_TARGET_SPACE_NAME}")
create_repo(
    repo_id=HF_TARGET_SPACE_NAME,
    repo_type=HF_REPO_TYPE,
    space_sdk=HF_SPACE_SDK,
    private=False,
    exist_ok=True,  # re-running the cell must not fail when the Space already exists
)

# 2) Resolve the "<username>/<space>" id for the authenticated user.
full_hf_repo_name = get_full_repo_name(model_id=HF_TARGET_SPACE_NAME)
print(f"ℹ️ Full Hugging Face Hub repo name: {full_hf_repo_name}")

# 3) Push every file of the local folder to the root of the Space.
print(f"ℹ️ Uploading {LOCAL_FOLDER_TO_UPLOAD} to repo name: {full_hf_repo_name}")
folder_upload_url = upload_folder(
    folder_path=LOCAL_FOLDER_TO_UPLOAD,
    repo_id=full_hf_repo_name,
    repo_type=HF_REPO_TYPE,
    path_in_repo=".",
    commit_message="Uploading Mulitimodal Retrieval Augmented Generation System.",
)

print(f"✅ Folder succesfully uploaded with commit URL: {folder_upload_url}")

ℹ️ Creating repo on HF Hub with name: multimodal_rag_system
ℹ️ Full Hugging Face Hub repo name: Saint5/multimodal_rag_system
ℹ️ Uploading setup/multimodal_rag to repo name: Saint5/multimodal_rag_system
✅ Folder succesfully uploaded with commit URL: https://huggingface.co/spaces/Saint5/multimodal_rag_system/tree/main/.


In [None]:
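# NOTE: earlier draft of app.py, kept for reference only. Every line is commented out, so running this cell does nothing.
# # %%writefile setup/multimodal_rag/app.py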
# """Gradio setup for the Multimodal RAG system."""
# import os
# import torch
# import gradio as gr
# # import gc

# # from utils import load_faiss_index, load_cache
# # from model_setup import embedding_model, model, processor
# # from main import preprocess_pdf, semantic_search, generate_answer_stream

# # torch.set_num_threads(4)  # cpu thread limit

# CACHE_DIR = "cache"
# os.makedirs(CACHE_DIR, exist_ok=True)

# INDEX_FILE = os.path.join(CACHE_DIR, "index.faiss")
# CHUNKS_FILE = os.path.join(CACHE_DIR, "chunks.json")

# # Global state shared across chats
# state = {
#     "index": None,
#     "chunks": None,
#     "pdf_path": None,
#     "image_dir": "extracted_images",
# }

# # Function to clear cache to prevent stale cache retrieval if new document is uploaded
# def clear_cache_files():
#   if os.path.exists(INDEX_FILE):
#       os.remove(INDEX_FILE)
#   if os.path.exists(CHUNKS_FILE):
#       os.remove(CHUNKS_FILE)
#   state["index"], state["chunks"] = None, None

# def handle_pdf_upload(file):
#   if file is None:
#       return "[ERROR ⚠️] No file uploaded."

#   # Save uploaded file to cache directory to ensure accessibility
#   pdf_path = os.path.join(CACHE_DIR, os.path.basename(file.name))
#   with open(pdf_path, "wb") as f_out:
#       f_out.write(file.file.read())

#   if state["pdf_path"] != pdf_path:
#       clear_cache_files()

#   state["pdf_path"] = pdf_path

#   index, chunks = preprocess_pdf(
#       file_path=state["pdf_path"],
#       image_dir=state["image_dir"],
#       embedding_model=embedding_model,
#       index_file=INDEX_FILE,
#       chunks_file=CHUNKS_FILE,
#       use_cache=True
#   )
#   state["index"], state["chunks"] = index, chunks
#   return "✅ Document processed and ready for Q&A!"

# def chat_streaming(message, history):
#   if state["index"] is None or state["chunks"] is None:
#     yield "[ERROR ⚠️] Please upload and process a PDF first."
#     return
#   retrieved_chunks = semantic_search(message, embedding_model, state["index"], state["chunks"], top_k=10)
#   for partial in generate_answer_stream(message, retrieved_chunks, model, processor):
#     yield partial


# description = """
# Remember to be specific when querying for better response.
# 📖🧐
# """
# # Gradio setup
# with gr.Blocks() as demo:
#   gr.Markdown("""## 📚Multimodal RAG System
#                 Upload a PDF (≤50 pages recommended) and ask questions about it.""")
#   with gr.Row():
#     file_input = gr.File(label="📂Upload PDF")
#     upload_button = gr.Button("🔁Process PDF")
#   upload_status = gr.Textbox(label="Upload Status", interactive=False)

#   upload_button.click(handle_pdf_upload, inputs=file_input, outputs=upload_status)

#   chat = gr.ChatInterface(
#       fn=chat_streaming,
#       type="messages",
#       title="📄😃 Ask Questions on your PDF!",
#       description=description,
#       examples=[["What is this document about?"]]
#   )
#   chat.queue()

# demo.launch()

In [None]:
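# NOTE: alternative draft of app.py (per-PDF cache files + PDF selector dropdown), kept for reference only. Every line is commented out, so running this cell does nothing.
# # %%writefile setup/multimodal_rag/app.py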

# """Gradio setup for the Multimodal RAG system."""

# import os
# import hashlib
# import torch
# import gradio as gr
# # import gc

# from utils import load_faiss_index, load_cache
# from model_setup import embedding_model, model, processor
# from main import preprocess_pdf, semantic_search, generate_answer_stream

# torch.set_num_threads(4) # Limits to 4 threads for better performance

# device = "cuda" if torch.cuda.is_available() else "cpu"

# # Ensure cache directory exists
# CACHE_DIR = "cache"
# os.makedirs(CACHE_DIR, exist_ok=True)

# # Global state shared across chats
# state = {
#     "index": None,
#     "chunks": None,
#     "pdf_path": None,
#     "image_dir": "extracted_images",  # Default image directory
#     "index_file": None,
#     "chunks_file": None,
#     "processed_pdfs": {}, # pdf_name -> (index_file, chunks_file)
# }

# def _make_cache_names(pdf_path: str) -> tuple[str, str]:
#     """Generate unique cache file names per PDF based on hash of filename."""
#     pdf_hash = hashlib.md5(pdf_path.encode()).hexdigest()[:8]  # Shorten for readability
#     base_name = os.path.splitext(os.path.basename(pdf_path))[0]
#     index_file = os.path.join(CACHE_DIR, f"{base_name}_{pdf_hash}_index.faiss")
#     chunks_file = os.path.join(CACHE_DIR, f"{base_name}_{pdf_hash}_chunks.json")
#     return index_file, chunks_file

# def handle_pdf_upload(file):
#     if file is None:
#         return "[ERROR ⚠️] No file uploaded.", gr.update()

#     # Save uploaded file to cache directory to ensure accessibility
#     new_pdf_path = os.path.join(CACHE_DIR, os.path.basename(file.name))
#     with open(new_pdf_path, "wb") as f_out:
#       f_out.write(file.file.read())
#     state["pdf_path"] = new_pdf_path

#     # Create unique cache file names for this PDF
#     state["index_file"], state["chunks_file"] = _make_cache_names(new_pdf_path)

#     # Run preprocessing (reuse cache if exists)
#     index, chunks = preprocess_pdf(
#         file_path=state["pdf_path"],
#         image_dir=state["image_dir"],
#         embedding_model=embedding_model,
#         index_file=state["index_file"],
#         chunks_file=state["chunks_file"],
#         use_cache=True # allow cache for the PDF
#     )
#     state["index"], state["chunks"] = index, chunks
#     # gc.collect() # Free memeory after PDF processing

#     # Store in processed_pdfs for later selection
#     pdf_key = os.path.basename(new_pdf_path) # Use PDF basename as dropdown key and store full cache paths as value
#     state["processed_pdfs"][pdf_key] = (state["index_file"], state["chunks_file"])

#     return (
#         f"✅ Document '{pdf_key}' processed and ready for Q&A!",
#         gr.update(choices=list(state["processed_pdfs"].keys()), value=pdf_key)
#     )

# def handle_pdf_selection(pdf_name):
#     """Switch active PDF from dropdown."""
#     if pdf_name not in state["processed_pdfs"]:
#         return "[ERROR] Selected PDF not found in cache."

#     # state["pdf_path"] = os.path.join(CACHE_DIR, pdf_name)
#     # state["index_file"], state["chunks_file"] = state["processed_pdfs"][pdf_name]

#     # Retrieve cached full paths
#     state["index_file"], state["chunks_file"] = state["processed_pdfs"][pdf_name]

#     # Optionally reset pdf_path or keep as None if not needed
#     state["pdf_path"] = None  # Or remove if not used after upload

#     # Reload index + chunks from cache directly
#     index = load_faiss_index(state["index_file"])
#     chunks = load_cache(state["chunks_file"])

#     state["index"], state["chunks"] = index, chunks
#     return f"📂 Switched to cached PDF: {pdf_name}"

# def chat_streaming(message, history):
#     if state["index"] is None or state["chunks"] is None:
#         yield "[ERROR ⚠️] Please upload and process a PDF first."
#         return

#     # Perform semantic search
#     retrieved_chunks = semantic_search(message, embedding_model, state["index"], state["chunks"], top_k=10)
#     # gc.collect() # Free memory after semantic search

#     # Stream the answer
#     for partial in generate_answer_stream(message, retrieved_chunks, model, processor):
#         yield partial

# description = """
#  Remember to be specific when querying for better response.
#  📖🧐
# """

# with gr.Blocks() as demo:
#     gr.Markdown("""
#                 ## 📚Simple Multimodal RAG System
#                 Upload a PDF (≤50 pages recommended) and ask questions about it.
#                 Supports multiple PDFs, just upload and select from the dropdown.""")

#     with gr.Row():
#         file_input = gr.File(label="📂Upload PDF")
#         # upload_button = gr.Button("Process PDF")

#     upload_status = gr.Textbox(label="Upload Status", interactive=False)
#     pdf_selector = gr.Dropdown(label="📄 Select a Processed PDF", choices=[], interactive=True)

#     upload_button = gr.Button("Process PDF")
#     upload_button.click(handle_pdf_upload, inputs=file_input, outputs=[upload_status, pdf_selector])

#     # Switch active PDF from dropdown
#     pdf_selector.change(handle_pdf_selection, inputs=pdf_selector, outputs=upload_status)

#     chat = gr.ChatInterface(
#             fn=chat_streaming,
#             type="messages",
#             title="📄😃 Ask Questions from PDF",
#             description=description,
#             examples=[["What is this document about?"]]
#         )
#     chat.queue()

# demo.launch()