### Web App

In [1]:
!pip install git+https://github.com/illuin-tech/colpali.git -q
!pip install qwen-vl-utils -q
!pip install pinecone -q

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone


In [2]:
!pip install streamlit pdf2image pyngrok requests -q

In [3]:
!sudo apt-get install --quiet -y poppler-utils

Reading package lists...
Building dependency tree...
Reading state information...
poppler-utils is already the newest version (22.02.0-2ubuntu0.7).
0 upgraded, 0 newly installed, 0 to remove and 34 not upgraded.


In [4]:
!pip install chromadb pillow pydantic -q

In [5]:
# Register the ngrok authtoken without embedding it in the notebook.
# The token is taken from the NGROK_AUTHTOKEN environment variable when set,
# otherwise it is prompted for interactively (the input is not echoed or saved
# in the cell). Any token that was previously committed in this cell should be
# treated as compromised and rotated in the ngrok dashboard.
import os
from getpass import getpass

from pyngrok import ngrok

ngrok.set_auth_token(os.environ.get("NGROK_AUTHTOKEN") or getpass("ngrok authtoken: "))

Authtoken saved to configuration file: /root/.config/ngrok/ngrok.yml


In [6]:
%%writefile MultiModal-RAG.py

from transformers.utils.import_utils import is_flash_attn_2_available
from colpali_engine.models import BiQwen2_5, BiQwen2_5_Processor
from transformers import AutoProcessor, AutoModelForVision2Seq
from PIL import Image

import requests
import matplotlib.pyplot as plt
from pdf2image import convert_from_path, convert_from_bytes
import streamlit as st
import os

import torch
from torch.utils.data import Dataset, DataLoader
from pinecone import Pinecone, ServerlessSpec

import textwrap
from matplotlib import gridspec
import numpy as np

st.set_page_config(layout="wide")
st.title("PDF -- RAG")

@st.cache_resource
def load_models():
    """Load the embedding model and the answer-generating VLM.

    Decorated with ``st.cache_resource`` so the (expensive) loading is not
    repeated on every Streamlit rerun.

    Returns:
        tuple: ``(model, processor, vlm_model, vlm_processor)`` where
        ``model``/``processor`` are the BiQwen2.5 embedder and its processor,
        and ``vlm_model``/``vlm_processor`` are SmolVLM-256M-Instruct and its
        processor.
    """
    # One device decision shared by both models. Previously the embedder was
    # pinned to "cuda:0" unconditionally, which failed on CPU-only hosts even
    # though the VLM below already fell back to CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model_name = "nomic-ai/nomic-embed-multimodal-3b"

    model = BiQwen2_5.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map=str(device),
        attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
    ).eval()

    processor = BiQwen2_5_Processor.from_pretrained(model_name)

    vlm_name = "HuggingFaceTB/SmolVLM-256M-Instruct"
    vlm_processor = AutoProcessor.from_pretrained(vlm_name)
    vlm_model = AutoModelForVision2Seq.from_pretrained(vlm_name).to(device)

    return model, processor, vlm_model, vlm_processor


def load_pdf():
    """Let the user pick a PDF and convert its pages to images.

    Three buttons are rendered: two preloaded papers and "BYOP", which asks
    for a local path or a direct URL. The selection is kept in
    ``st.session_state.choice`` so it survives Streamlit reruns.

    Returns:
        dict | None: ``{"title", "file", "images"}``. All three values are
        ``None`` when nothing has been selected yet. ``None`` is returned when
        a selection was made but no PDF could be loaded (no BYOP input yet, or
        loading failed; the failure is reported in the UI).
    """
    PRELOADED_PDFS = {
        "1": {"title": "Attention Is All You Need", "file": "Attention Is All You Need.pdf"},
        "2": {"title": "Deep Residual Learning", "file": "Deep Residual Learning.pdf"}
    }

    left, middle, right = st.columns(3)

    if left.button("Attention Is All You Need", use_container_width=True):
        st.session_state.choice = "Attention Is All You Need"

    if middle.button("Deep Residual Learning", use_container_width=True):
        st.session_state.choice = "Deep Residual Learning"

    if right.button("BYOP", use_container_width=True):
        st.session_state.choice = "BYOP"

    choice = st.session_state.get("choice", None)
    st.write("You selected:", choice)

    selected = next(
        (pdf for pdf in PRELOADED_PDFS.values() if pdf["title"] == choice),
        None,
    )

    if selected:
        # A missing/corrupt preloaded file used to raise straight through the
        # app; report it the same way BYOP failures are reported instead.
        try:
            images = convert_from_path(selected["file"])
        except Exception as e:
            st.text(f"----- Failed to load PDF: {e} -----")
            return None
        return {
            "title": selected["title"],
            "file": selected["file"],
            "images": images,
        }

    if choice == 'BYOP':
        path_or_url = st.text_input("Enter local PDF path or direct PDF URL: ").strip()
        pdf_data = None
        if path_or_url:
            try:
                if path_or_url.lower().startswith(("http://", "https://")):
                    # Bounded wait: without a timeout a stalled server would
                    # block the Streamlit script indefinitely.
                    response = requests.get(path_or_url, timeout=30)
                    response.raise_for_status()
                    pdf_bytes = response.content
                else:
                    # Extension check is case-insensitive so "paper.PDF" works.
                    if not os.path.isfile(path_or_url) or not path_or_url.lower().endswith(".pdf"):
                        raise ValueError("Invalid file path.")
                    with open(path_or_url, "rb") as f:
                        pdf_bytes = f.read()
                title = os.path.splitext(os.path.basename(path_or_url))[0]

                images = convert_from_bytes(pdf_bytes)
                st.write(f"\nYou loaded: {title}")
                pdf_data = {
                    "title": title,
                    "file": path_or_url,
                    "images": images,
                }
            except Exception as e:
                st.text(f"----- Failed to load PDF: {e} -----")
        return pdf_data

    # Nothing selected yet.
    return {
        "title": None,
        "file": None,
        "images": None,
    }


def display_pdf_images(images_list):
    """Show page images in a grid with five columns inside the Streamlit app.

    Args:
        images_list: Sequence of page images accepted by ``Axes.imshow``.
            An empty sequence renders nothing.
    """
    num_images = len(images_list)
    if num_images == 0:
        # plt.subplots() cannot build a grid with zero rows.
        return

    num_cols = 5
    num_rows = (num_images + num_cols - 1) // num_cols

    # squeeze=False always yields a 2-D array of axes, so a single flatten()
    # covers every grid shape. The previous special case wrapped the 5-axes
    # array in a list when there was exactly one image, which made
    # axes[0].imshow fail.
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(20, 4 * num_rows), squeeze=False)
    axes = axes.flatten()

    for i, img in enumerate(images_list):
        ax = axes[i]
        ax.imshow(img)
        ax.set_title(f"Page {i+1}")
        ax.axis('off')

    # Hide the unused cells in the last row.
    for ax in axes[num_images:]:
        ax.axis('off')

    plt.tight_layout()
    st.pyplot(fig)
    # Streamlit reruns the script on every interaction; release the figure so
    # open figures do not accumulate.
    plt.close(fig)


def create_pine_index():
    """Connect to Pinecone and make sure the "rag-multimodal" index exists.

    The API key is read from the ``PINECONE_API_KEY`` environment variable.
    It must never be committed to the notebook; a key that was previously
    hard-coded here should be treated as compromised and rotated.

    Returns:
        tuple: ``(pc, index, index_name)`` - the Pinecone client, a handle to
        the index, and the index name.

    Raises:
        RuntimeError: if ``PINECONE_API_KEY`` is not set or empty.
    """
    key = os.environ.get("PINECONE_API_KEY")
    if not key:
        raise RuntimeError(
            "PINECONE_API_KEY is not set; export it before starting the app."
        )

    pc = Pinecone(api_key=key)
    index_name = "rag-multimodal"

    if index_name not in pc.list_indexes().names():
        pc.create_index(
            name=index_name,
            dimension=2048,
            metric="cosine",
            spec=ServerlessSpec(cloud="aws", region="us-east-1")
        )

    index = pc.Index(index_name)
    return pc, index, index_name


def generate_pdf_embeddings(PDFs, processor, model, index, image_save_dir="saved_images"):
    """Embed every page image of every PDF and upsert it into the index.

    For each page: the image is run through ``processor``/``model``, the
    resulting embedding is scaled to unit norm, the image is saved under
    ``image_save_dir`` and one vector (id, values, metadata) is upserted.

    Args:
        PDFs: List of dicts with "title", "file" and "images" keys. Each dict
            also gets an (empty) "page_embeddings" list assigned.
        processor: Object exposing ``process_images``.
        model: Callable embedding model exposing ``.device``.
        index: Vector index exposing ``upsert(vectors=...)``.
        image_save_dir: Directory the page images are written to; created
            when it does not exist.
    """
    if not os.path.exists(image_save_dir):
        os.makedirs(image_save_dir)

    pages_done = 0
    for doc in PDFs:
        doc['page_embeddings'] = []

        for page_number, page_image in enumerate(doc["images"], start=1):
            batch = processor.process_images([page_image])
            batch = {name: tensor.to(model.device) for name, tensor in batch.items()}
            with torch.no_grad():
                vectors = model(**batch)
            vectors = vectors.cpu()
            del batch
            # Scale each row to unit length.
            row_norms = torch.norm(vectors, dim=1, keepdim=True)
            vectors = vectors / row_norms
            torch.cuda.empty_cache()

            # The vector id and the image file share the same stem.
            page_stem = f"{doc['title']}_page_{page_number}"
            saved_path = os.path.join(image_save_dir, f"{doc['title']}_page_{page_number}.jpg")
            page_image.save(saved_path)

            page_metadata = {
                "title": doc['title'],
                "page_number": page_number,
                "file": doc['file'],
                "image": saved_path
            }

            record = (page_stem, vectors[0].tolist(), page_metadata)
            index.upsert(vectors=[record])
            pages_done += 1

    print(f"Generated embeddings for {pages_done} pages and saved images locally.")


def retrieve(query, index, model, processor, k=3):
    """Embed a text query and fetch the ``k`` closest pages from the index.

    Args:
        query: The user's question.
        index: Vector index exposing ``query(vector=..., top_k=...,
            include_metadata=...)``.
        model: Callable embedding model exposing ``.device``.
        processor: Object exposing ``process_queries``.
        k: Number of matches to request.

    Returns:
        list[dict]: One dict per match with "title", "file", "page_number",
        "score" and "image". "image" is ``None`` when the file named in the
        match metadata could not be opened.
    """
    q_inputs = processor.process_queries([query])
    q_inputs = {name: tensor.to(model.device) for name, tensor in q_inputs.items()}
    with torch.no_grad():
        q_emb = model(**q_inputs).float().cpu().numpy()

    # The model output has one row per query; exactly one query was sent.
    q_vec = q_emb[0]
    q_vec = q_vec / np.linalg.norm(q_vec)

    # The index expects a flat list of floats. Wrapping the already 2-D
    # embedding in another list (as before) produced a triple-nested list.
    results = index.query(vector=q_vec.tolist(), top_k=k, include_metadata=True)

    retrieved = []
    for match in results["matches"]:
        md = match["metadata"]
        try:
            image = Image.open(md["image"])
        except Exception as e:
            # Best effort: keep the match, just without its image.
            print(f"⚠️ could not load image {md['image']}: {e}")
            image = None

        retrieved.append({
            "title": md["title"],
            "file": md["file"],
            "page_number": md["page_number"],
            "score": match["score"],
            "image": image
        })

    return retrieved


def query_vlm(query, vlm_model, vlm_processor, images):
    """Ask the vision-language model a question about the given page images.

    Args:
        query: The user's question.
        vlm_model: Generative model exposing ``.device`` and ``generate``.
        vlm_processor: Matching processor exposing ``apply_chat_template``,
            ``__call__`` and ``batch_decode``.
        images: List of page images passed to the processor.

    Returns:
        str: The newly generated answer text only (the prompt is not echoed).
    """
    # One image placeholder per page, followed by the question itself.
    message_content = [{"type": "image"} for _ in images]
    message_content.append({"type": "text", "text": query})

    # Built from adjacent literals: the earlier backslash continuation leaked
    # the source indentation (a run of spaces) into the prompt.
    system_prompt = (
        "You are an expert professional PDF analyst who gives rigorous in-depth answers. "
        "Analyse the context you've been given and give answer accordingly like an expert "
        "with relevant information."
    )

    messages = [
        {
            "role": "system",
            # Same list-of-parts shape as the user turn, so the chat template
            # handles both messages uniformly.
            "content": [{"type": "text", "text": system_prompt}]
        },
        {
            "role": "user",
            "content": message_content
        }
    ]

    prompt = vlm_processor.apply_chat_template(messages, add_generation_prompt=True)
    # Follow the model's own device instead of re-deriving it here.
    inputs = vlm_processor(text=prompt, images=images, return_tensors="pt").to(vlm_model.device)

    with torch.no_grad():
        generated_ids = vlm_model.generate(**inputs, max_new_tokens=1000)

    # generate() returns prompt tokens followed by the continuation; decode
    # only the continuation so the answer does not repeat the prompt.
    prompt_length = inputs["input_ids"].shape[1]
    new_token_ids = generated_ids[:, prompt_length:]

    response = vlm_processor.batch_decode(new_token_ids, skip_special_tokens=True)[0]
    return response.strip()


def plot_rag_result(query, answer, images):
    """Render the retrieved page(s) next to the generated answer.

    The left half of the figure shows the source image(s); the right half
    shows the answer in a rounded text box. The query is the figure title.

    Args:
        query: The user's question.
        answer: The generated answer text.
        images: Retrieved page images. ``None`` entries (pages whose image
            could not be loaded) are skipped.
    """
    wrapped_query = '\n'.join(textwrap.wrap(query, width=70))
    # imshow cannot draw None, so drop pages without a loadable image.
    images = [image for image in images if image is not None]
    num_images = len(images)
    fig = plt.figure(figsize=(14, 10))
    outer = gridspec.GridSpec(1, 2, width_ratios=[1, 1], wspace=0.1)

    if num_images == 0:
        # The grid maths below would divide by zero; show a placeholder.
        ax1 = fig.add_subplot(outer[0])
        ax1.axis('off')
        ax1.text(0.5, 0.5, "No source image available", ha='center', va='center', fontsize=12)
    elif num_images == 1:
        ax1 = fig.add_subplot(outer[0])
        ax1.imshow(images[0])
        ax1.axis('off')
        ax1.set_title("Source Document\nretrieved by Nomic Embed Multimodal", fontsize=12, fontweight='bold', loc='left', pad=0)
    else:
        left = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=outer[0])
        fig.text(0.1, 0.9, "Source Documents\nretrieved by: Nomic Embed Multimodal", fontsize=12, fontweight='bold', va='top', ha='left')
        # Near-square grid, at most three columns wide.
        cols = min(int(np.ceil(np.sqrt(num_images))), 3)
        rows = int(np.ceil(num_images / cols))
        inner = gridspec.GridSpecFromSubplotSpec(rows, cols, subplot_spec=left[0], wspace=0.05, hspace=0.05)
        for i, image in enumerate(images):
            ax_sub = fig.add_subplot(inner[i])
            ax_sub.imshow(image)
            ax_sub.axis('off')

    ax2 = fig.add_subplot(outer[1])
    ax2.axis('off')
    ax2.set_title("Answer generated by SMOLVLM-256M-Instruct", fontsize=12, fontweight='bold', loc='left')
    # Wrap each original line separately so existing line breaks are kept.
    wrapped_answer = '\n'.join(['\n'.join(textwrap.wrap(line, width=80)) for line in answer.split('\n')])
    # Shrink the font for long answers, clamped to the range 4..9.
    fontsize = min(9, max(4, 9 - ((len(wrapped_answer) - 500) // 1000)))
    ax2.text(0.02, 0.97, wrapped_answer, transform=ax2.transAxes, fontsize=fontsize, verticalalignment='top',
             bbox=dict(boxstyle='round', facecolor='white', alpha=0.9, edgecolor='#2C3E50', linewidth=2, pad=1.0))
    fig.suptitle(f"Query: {wrapped_query}", fontsize=14, fontweight='bold', y=0.96)
    st.pyplot(fig)
    # Every query creates a new figure; close it so figures do not pile up
    # across Streamlit reruns.
    plt.close(fig)


def delete_index(index_name, pc):
    """Delete the named Pinecone index and report it in the app.

    Used as the ``on_click`` callback of the "Delete Index" button.

    Args:
        index_name: Name of the index to delete.
        pc: Pinecone client exposing ``delete_index``.
    """
    pc.delete_index(index_name)
    # The stored vectors are gone with the index. Clear the flag so the next
    # run re-embeds the current PDF instead of querying a freshly recreated,
    # empty index.
    st.session_state.embeddings_created = False
    st.write(f"Index {index_name} deleted.")

###################################################################################################
# ✅ State initialization for control
# Seed the session-state keys the app relies on (first run only).
for _state_key, _state_default in (
    ("choice", None),
    ("embeddings_created", False),
    ("last_choice", None),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default

model, processor, vlm_model, vlm_processor = load_models()
st.subheader("Model Loaded")

pdf_data = load_pdf()
PDFs = [pdf_data]

# Detect a document change AFTER load_pdf(): the buttons inside it are what
# update `choice`, so checking earlier only noticed the change one rerun late.
# The file path is part of the key so that a different BYOP path/URL (where
# `choice` stays "BYOP") also triggers re-embedding.
current_doc = (st.session_state.choice, pdf_data["file"] if pdf_data else None)
if st.session_state.last_choice != current_doc:
    st.session_state.embeddings_created = False
    st.session_state.last_choice = current_doc

if pdf_data and pdf_data['images']:
    preview_col, delete_col = st.columns([1, 0.2])
    pc, index, index_name = create_pine_index()

    with preview_col:
        with st.expander("First 5 Pages"):
            display_pdf_images(PDFs[0]["images"][:5])

    with delete_col:
        st.button("Delete Index", on_click=delete_index, args=(index_name, pc))

    # Generate embeddings once per loaded document.
    if not st.session_state.embeddings_created:
        st.write("Creating Embeddings")
        generate_pdf_embeddings(PDFs, processor, model, index)
        st.session_state.embeddings_created = True

    doc_query = st.text_input("Enter your query")

    if doc_query:
        doc_rag_results = retrieve(doc_query, index, model, processor, k=1)
        # The index may return no match, or a match whose image file could not
        # be opened; neither can be passed to the VLM.
        doc_images = [hit["image"] for hit in doc_rag_results if hit["image"] is not None]
        if not doc_images:
            st.warning("No matching page could be retrieved for this query.")
        else:
            doc_image = doc_images[0]
            doc_answer = query_vlm(doc_query, vlm_model, vlm_processor, [doc_image])
            plot_rag_result(doc_query, doc_answer, [doc_image])

Overwriting MultiModal-RAG.py


In [7]:
# Launch the Streamlit app written by the previous cell and expose it via ngrok.
from pyngrok import ngrok

# Stop any ngrok process left over from an earlier run of this cell.
ngrok.kill()

# Start the app in the background; its stdout/stderr go to streamlit_log.txt.
!streamlit run MultiModal-RAG.py &> streamlit_log.txt &

# Open a tunnel to local port 8501 (where the app is expected to listen) and
# print the public URL.
public_url = ngrok.connect(8501)
print("Streamlit app is live at:", public_url)

Streamlit app is live at: NgrokTunnel: "https://8a73-34-126-147-217.ngrok-free.app" -> "http://localhost:8501"


### Debugging

In [19]:
import os

from pinecone import Pinecone, ServerlessSpec

# Read the API key from the environment instead of embedding it in the
# notebook. A key that was previously committed here should be rotated.
key = os.environ["PINECONE_API_KEY"]
pc = Pinecone(api_key=key)

# Create the index on first use, then open a handle to it.
index_name = "rag-multimodal"
if index_name not in pc.list_indexes().names():
    pc.create_index(
      name=index_name,
      dimension=2048,
      metric="cosine",
      spec=ServerlessSpec(cloud="aws", region="us-east-1")
      )

index = pc.Index(index_name)

In [59]:
index.describe_index_stats()

{'dimension': 2048,
 'index_fullness': 0.0,
 'metric': 'cosine',
 'namespaces': {'': {'vector_count': 12}},
 'total_vector_count': 12,
 'vector_type': 'dense'}

In [60]:
pc.list_indexes()

[
    {
        "name": "rag-multimodal",
        "metric": "cosine",
        "host": "rag-multimodal-ejzuolb.svc.aped-4627-b74a.pinecone.io",
        "spec": {
            "serverless": {
                "cloud": "aws",
                "region": "us-east-1"
            }
        },
        "status": {
            "ready": true,
            "state": "Ready"
        },
        "vector_type": "dense",
        "dimension": 2048,
        "deletion_protection": "disabled",
        "tags": null
    }
]

In [61]:
# Destructive: deletes the whole index, then lists the remaining indexes to
# confirm it is gone.
index_name = "rag-multimodal"
pc.delete_index(index_name)
pc.list_indexes()

[]