### Multi-modal RAG
Many documents contain a mixture of content types, including text and images.

Yet most RAG applications index only the text, so the information captured in images is lost.

With the emergence of multimodal LLMs such as GPT-4V, it is worth considering how to use images in RAG. There are three common options:

- Option 1:
  - Use multimodal embeddings (such as CLIP) to embed both images and text.
  - Retrieve both using similarity search.
  - Pass raw images and text chunks to a multimodal LLM for answer synthesis.

- Option 2:
  - Use a multimodal LLM (such as GPT-4V, LLaVA, or Fuyu-8B) to produce text summaries of the images.
  - Embed and retrieve the text.
  - Pass text chunks to an LLM for answer synthesis.

- Option 3:
  - Use a multimodal LLM (such as GPT-4V, LLaVA, or Fuyu-8B) to produce text summaries of the images.
  - Embed and retrieve the image summaries, each with a reference to the raw image.
  - Pass raw images and text chunks to a multimodal LLM for answer synthesis.

The code below implements Option 1: CLIP embeds text chunks and images into a single Qdrant collection, and a multimodal chat model (GPT-4.1) answers from the retrieved text and raw images.
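Option 3 separates what is searched (a text summary) from what is returned (the raw image). The toy, framework-free sketch below illustrates only that indirection, under stated assumptions: the summary vectors are hand-written 3-d stand-ins, and `summary_index`, `raw_store`, and `retrieve_raw_images` are hypothetical names. In a real pipeline a multimodal LLM would write the summaries and an embedding model would produce the vectors.

```python
import math

# Hypothetical store: raw images are kept by ID and never embedded directly.
raw_store = {
    "img_0": b"<raw PNG bytes of a bar chart>",
    "img_1": b"<raw PNG bytes of a logo>",
}

# Each entry pairs a toy summary embedding with the ID of the raw image it describes.
summary_index = [
    ([0.9, 0.1, 0.0], "img_0"),  # summary: "Bar chart of revenue by quarter"
    ([0.0, 0.2, 0.9], "img_1"),  # summary: "Company logo on a white background"
]


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def retrieve_raw_images(query_vec, k=1):
    """Rank summaries by similarity, then return the raw images they reference."""
    ranked = sorted(summary_index, key=lambda item: cosine(query_vec, item[0]), reverse=True)
    return [(image_id, raw_store[image_id]) for _, image_id in ranked[:k]]


hits = retrieve_raw_images([1.0, 0.0, 0.1], k=1)
print(hits[0][0])  # img_0
```

The raw bytes returned here are what Option 3 would pass to the multimodal LLM, while only the summaries ever enter the vector index.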

In [1]:
import fitz  # PyMuPDF
from langchain_core.documents import Document
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch
import numpy as np
from langchain.chat_models import init_chat_model
from langchain_core.messages import HumanMessage
from langchain_text_splitters import RecursiveCharacterTextSplitter
import os
import base64
import io



In [2]:
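### CLIP model
from dotenv import load_dotenv
load_dotenv()  # loads variables from .env into os.environ

## Set up the environment: fail early with a clear message if the OpenAI key is missing
if not os.getenv("OPENAI_API_KEY"):
    raise RuntimeError("OPENAI_API_KEY is not set (add it to .env or the environment)")

### Initialize the CLIP model for unified text/image embeddings
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
clip_model.eval()  # inference mode (disables dropout)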



In [4]:
### Embedding functions
def embed_image(image_data):
    """Embed image using CLIP"""
    if isinstance(image_data, str):  # If path
        image = Image.open(image_data).convert("RGB")
    else:  # If PIL Image
        image = image_data
    
    inputs=clip_processor(images=image,return_tensors="pt")
    with torch.no_grad():
        features = clip_model.get_image_features(**inputs)
        # Normalize embeddings to unit vector
        features = features / features.norm(dim=-1, keepdim=True)
        return features.squeeze().numpy()
    
def embed_text(text):
    """Embed text using CLIP."""
    inputs = clip_processor(
        text=text, 
        return_tensors="pt", 
        padding=True,
        truncation=True,
        max_length=77  # CLIP's max token length
    )
    with torch.no_grad():
        features = clip_model.get_text_features(**inputs)
        # Normalize embeddings
        features = features / features.norm(dim=-1, keepdim=True)
        return features.squeeze().numpy()
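Both functions L2-normalize the CLIP features, so every stored vector has unit length. For unit vectors the dot product equals cosine similarity, which is why text and image vectors can be ranked against each other with `Distance.COSINE` (or a plain dot product). The stdlib-only check below uses random 512-d vectors as stand-ins for real CLIP features (512 is the projection size of `clip-vit-base-patch32`); `l2_normalize`, `norm`, and `dot` are illustrative helpers, not part of the notebook.

```python
import math
import random

rng = random.Random(0)
# Random stand-ins for a CLIP text vector and a CLIP image vector.
text_vec = [rng.gauss(0, 1) for _ in range(512)]
image_vec = [rng.gauss(0, 1) for _ in range(512)]


def norm(v):
    """Euclidean (L2) length of a vector."""
    return math.sqrt(sum(x * x for x in v))


def l2_normalize(v):
    """Scale a vector to unit length, as embed_text/embed_image do."""
    n = norm(v)
    return [x / n for x in v]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


cosine = dot(text_vec, image_vec) / (norm(text_vec) * norm(image_vec))
dot_of_unit = dot(l2_normalize(text_vec), l2_normalize(image_vec))

print(math.isclose(cosine, dot_of_unit, abs_tol=1e-12))  # True
```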

In [39]:
## Process PDF
pdf_path="multimodal_sample.pdf"
doc=fitz.open(pdf_path)
# Storage for all documents and embeddings
all_docs = []
all_embeddings = []
image_data_store = {}  # Store actual image data for LLM

# Text splitter
splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
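`chunk_size=500` and `chunk_overlap=100` are measured in characters (the splitter's default length function is `len`), while `embed_text` truncates at CLIP's 77 tokens. With the common rule of thumb of roughly 4 characters per English token, a 500-character chunk is likely longer than 77 tokens, so the tail of many chunks may not contribute to their embeddings. The fixed-window sketch below only shows how size and overlap interact; it is not `RecursiveCharacterTextSplitter`'s algorithm, which first tries to break on separators such as paragraphs, newlines, and spaces. `sliding_chunks` is a hypothetical helper.

```python
def sliding_chunks(text, chunk_size=500, chunk_overlap=100):
    """Fixed-size character windows; each starts chunk_size - chunk_overlap after the previous one."""
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")
    step = chunk_size - chunk_overlap
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # this window already reaches the end of the text
    return chunks


# 1200 characters with a repeating A-Z pattern, so overlaps are easy to inspect.
text = "".join(chr(65 + i % 26) for i in range(1200))
chunks = sliding_chunks(text)
print([len(c) for c in chunks])  # [500, 500, 400]
```

Consecutive chunks share 100 characters, so a sentence cut at one boundary usually appears whole in a neighboring chunk.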

In [40]:
for i,page in enumerate(doc):
    ## process text
    text=page.get_text()
    if text.strip():
        ## Create a temporary document for splitting (page numbers are 0-based)
        temp_doc = Document(page_content=text, metadata={"page": i, "type": "text"})
        text_chunks = splitter.split_documents([temp_doc])

        #Embed each chunk using CLIP
        for chunk in text_chunks:
            embedding = embed_text(chunk.page_content)
            all_embeddings.append(embedding)
            all_docs.append(chunk)

    ## Process images. For each image embedded in the page:
    ##   1. Convert the extracted bytes to a PIL image
    ##   2. Store it as base64 so it can be sent inline to the multimodal LLM
    ##   3. Create a CLIP embedding for retrieval

    for img_index, img in enumerate(page.get_images(full=True)):
        try:
            xref = img[0]
            base_image = doc.extract_image(xref)
            image_bytes = base_image["image"]
            
            # Convert to PIL Image
            pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            
            # Create unique identifier
            image_id = f"page_{i}_img_{img_index}"
            
            # Store image as base64 for later use with the multimodal LLM
            buffered = io.BytesIO()
            pil_image.save(buffered, format="PNG")
            img_base64 = base64.b64encode(buffered.getvalue()).decode()
            image_data_store[image_id] = img_base64
            
            # Embed image using CLIP
            embedding = embed_image(pil_image)
            all_embeddings.append(embedding)
            
            # Create document for image
            image_doc = Document(
                page_content=f"[Image: {image_id}]",
                metadata={"page": i, "type": "image", "image_id": image_id}
            )
            all_docs.append(image_doc)
            
        except Exception as e:
            print(f"Error processing image {img_index} on page {i}: {e}")
            continue

doc.close()
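Each image is re-encoded as PNG and kept as a base64 string so it can later be placed in an OpenAI-style `data:image/png;base64,...` URL. The stdlib-only sketch below shows that round trip; the byte string is a stand-in for real PNG bytes (only the 8-byte PNG signature is genuine), and `to_data_url`/`from_data_url` are hypothetical helpers.

```python
import base64


def to_data_url(png_bytes, mime="image/png"):
    """Encode raw image bytes as a base64 data URL, the shape used for image_url parts."""
    return f"data:{mime};base64,{base64.b64encode(png_bytes).decode('ascii')}"


def from_data_url(url):
    """Recover the raw bytes from a base64 data URL."""
    header, _, payload = url.partition(",")
    if not header.endswith(";base64"):
        raise ValueError("not a base64 data URL")
    return base64.b64decode(payload)


fake_png = b"\x89PNG\r\n\x1a\n" + bytes(range(30))  # PNG signature + 30 dummy bytes
url = to_data_url(fake_png)

print(from_data_url(url) == fake_png)  # True
# base64 grows the payload by about a third: 4 output characters per 3 input bytes.
print(len(url.partition(",")[2]), len(fake_png))  # 52 38
```

Because of that growth, large page images noticeably increase the request size sent to the LLM.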


In [42]:
all_docs

[Document(metadata={'page': 0, 'type': 'text'}, page_content='Annual Revenue Overview\nThis document summarizes the revenue trends across Q1, Q2, and Q3. As illustrated in the chart\nbelow, revenue grew steadily with the highest growth recorded in Q3.\nQ1 showed a moderate increase in revenue as new product lines were introduced. Q2 outperformed\nQ1 due to marketing campaigns. Q3 had exponential growth due to global expansion.'),
 Document(metadata={'page': 0, 'type': 'image', 'image_id': 'page_0_img_0'}, page_content='[Image: page_0_img_0]')]

In [43]:
import os
import numpy as np
from dotenv import load_dotenv
from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct, VectorParams, Distance

load_dotenv()

# Convert the CLIP embeddings (NumPy arrays) to plain lists of floats for the client
embeddings_list = np.array(all_embeddings).tolist()

# 1. Set up the Qdrant client
qdrant_client = QdrantClient(
    url=os.getenv("QDRANT_URL"),
    api_key=os.getenv("QDRANT_API_KEY"),
)

collection_name = "multimodal_sample"

# 2. Create the collection: define the vector size and the distance metric.
# recreate_collection() is deprecated, so drop any existing collection explicitly.
embedding_size = len(embeddings_list[0])  # vector dimension, read from the data

if qdrant_client.collection_exists(collection_name):
    qdrant_client.delete_collection(collection_name)
qdrant_client.create_collection(
    collection_name=collection_name,
    vectors_config=VectorParams(size=embedding_size, distance=Distance.COSINE),
)

print(f"✅ Collection '{collection_name}' created successfully.")

# 3. Prepare the points: each vector carries the document text and metadata as payload
points_to_upsert = []
for i, (doc, emb) in enumerate(zip(all_docs, embeddings_list)):
    points_to_upsert.append(
        PointStruct(
            id=i,  # position-based ID; a stable ID such as a UUID is more robust
            vector=emb,
            payload={
                "page_content": doc.page_content,
                "metadata": doc.metadata,
            },
        )
    )

# 4. Upsert the points into the new collection
qdrant_client.upsert(
    collection_name=collection_name,
    points=points_to_upsert,
    wait=True,
)

print(f"✅ Successfully populated collection '{collection_name}'.")



✅ Collection 'multimodal_sample' created successfully.
✅ Successfully populated collection 'multimodal_sample'.
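The upsert uses the list position `i` as the point ID, which the inline comment already flags: re-ingesting a changed PDF can shift positions and overwrite unrelated points. Qdrant accepts unsigned integers or UUID strings as point IDs, so one option is a deterministic `uuid5` derived from what identifies each chunk. The key format and namespace string below are assumptions, and `stable_point_id` is a hypothetical helper; identical text on the same page would map to the same ID, which deduplicates it.

```python
import uuid

# Assumed namespace: any fixed UUID works, as long as it never changes between runs.
NAMESPACE = uuid.uuid5(uuid.NAMESPACE_URL, "multimodal_sample")


def stable_point_id(source, metadata, page_content):
    """Derive a repeatable UUID string from the fields that identify a chunk or image."""
    # Images already carry a unique image_id; text chunks fall back to their text.
    identity = metadata.get("image_id") or page_content
    key = f"{source}|{metadata['page']}|{metadata['type']}|{identity}"
    return str(uuid.uuid5(NAMESPACE, key))


text_meta = {"page": 0, "type": "text"}
image_meta = {"page": 0, "type": "image", "image_id": "page_0_img_0"}

a = stable_point_id("multimodal_sample.pdf", text_meta, "Annual Revenue Overview ...")
b = stable_point_id("multimodal_sample.pdf", text_meta, "Annual Revenue Overview ...")
c = stable_point_id("multimodal_sample.pdf", image_meta, "[Image: page_0_img_0]")
print(a == b, a == c)  # True False
```

In the upsert loop this would replace `id=i` with something like `id=stable_point_id(pdf_path, doc.metadata, doc.page_content)`.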


In [44]:
llm = init_chat_model("openai:gpt-4.1")
llm

ChatOpenAI(client=<openai.resources.chat.completions.completions.Completions object at 0x000002862CCAFED0>, async_client=<openai.resources.chat.completions.completions.AsyncCompletions object at 0x000002862D388F50>, root_client=<openai.OpenAI object at 0x000002862CFE7680>, root_async_client=<openai.AsyncOpenAI object at 0x000002862CFE7790>, model_name='gpt-4.1', model_kwargs={}, openai_api_key=SecretStr('**********'))

In [45]:
def retrieve_multimodal(query, k=5):
    """Unified retrieval using CLIP embeddings for both text and images."""
    # Embed the query with CLIP's text encoder (same space as the image embeddings)
    query_embedding = embed_text(query)

    # Search the unified vector store; query_points() replaces the deprecated search()
    response = qdrant_client.query_points(
        collection_name="multimodal_sample",
        query=query_embedding.tolist(),
        limit=k,  # return the top-k most similar points
        with_payload=True,  # include page_content and metadata in the results
    )

    return response.points
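With `Distance.COSINE`, the vector store returns the `limit` points most similar to the query. For a collection this small, a brute-force scan should typically produce the same ranking, which makes it a convenient sanity check on what comes back. `brute_force_top_k` is a hypothetical helper, and the 3-d vectors stand in for 512-d CLIP embeddings.

```python
import heapq
import math


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def brute_force_top_k(query_vec, points, k=5):
    """Return up to k (score, point_id) pairs, highest cosine similarity first."""
    scored = ((cosine(query_vec, vec), point_id) for point_id, vec in points.items())
    return heapq.nlargest(k, scored)


points = {
    0: [0.9, 0.1, 0.0],  # e.g. a text chunk
    1: [0.7, 0.7, 0.0],  # e.g. an image
    2: [0.0, 0.0, 1.0],  # unrelated content
}
top = brute_force_top_k([1.0, 0.2, 0.0], points, k=2)
print([pid for _, pid in top])  # [0, 1]
```

If `k` exceeds the number of stored points, all points are returned, which matches the notebook output where `k=5` yields only 2 documents.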

In [46]:
def create_multimodal_message(query, retrieved_docs):
    """Create a message with both text and images for a multimodal chat model."""
    content = []

    # Add the query
    content.append({
        "type": "text",
        "text": f"Question: {query}\n\nContext:\n"
    })

    # Separate text and image documents
    text_docs = [doc for doc in retrieved_docs if doc.payload.get("metadata", {}).get("type") == "text"]
    image_docs = [doc for doc in retrieved_docs if doc.payload.get("metadata", {}).get("type") == "image"]

    # Add text context
    if text_docs:
        text_context = "\n\n".join([
            f"[Page {doc.payload['metadata']['page']}]: {doc.payload['page_content']}"
            for doc in text_docs
        ])
        content.append({
            "type": "text",
            "text": f"Text excerpts:\n{text_context}\n"
        })

    # Add images
    for doc in image_docs:
        image_id = doc.payload["metadata"].get("image_id")
        if image_id and image_id in image_data_store:
            content.append({
                "type": "text",
                "text": f"\n[Image from page {doc.payload['metadata']['page']}]:\n"
            })
            content.append({
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/png;base64,{image_data_store[image_id]}"
                }
            })

    # Add instruction
    content.append({
        "type": "text",
        "text": "\n\nPlease answer the question based on the provided text and images."
    })

    return HumanMessage(content=content)

def multimodal_pdf_rag_pipeline(query):
    """Main pipeline for multimodal RAG."""
    # Retrieve relevant documents
    context_docs = retrieve_multimodal(query, k=5)

    # Create multimodal message
    message = create_multimodal_message(query, context_docs)

    # Get the response from the multimodal LLM (GPT-4.1)
    response = llm.invoke([message])

    # Print retrieved context info
    print(f"\nRetrieved {len(context_docs)} documents:")
    for doc in context_docs:
        doc_type = doc.payload["metadata"].get("type", "unknown")
        page = doc.payload["metadata"].get("page", "?")
        if doc_type == "text":
            preview = doc.payload["page_content"][:100] + "..." if len(doc.payload["page_content"]) > 100 else doc.payload["page_content"]
            print(f"  - Text from page {page}: {preview}")
        else:
            print(f"  - Image from page {page}")
    print("\n")

    return response.content

In [None]:
if __name__ == "__main__":
    # Example queries
    queries = [
        "What does the chart on page 1 show about revenue trends?",
        "Summarize the main findings from the document",
        "What visual elements are present in the document?"
    ]
    
    for query in queries:
        print(f"\nQuery: {query}")
        print("-" * 50)
        answer = multimodal_pdf_rag_pipeline(query)
        print(f"Answer: {answer}")
        print("=" * 70)


Query: What does the chart on page 1 show about revenue trends?
--------------------------------------------------


Retrieved 2 documents:
  - Text from page 0: Annual Revenue Overview
This document summarizes the revenue trends across Q1, Q2, and Q3. As illust...
  - Image from page 0


Answer: The chart on page 1 shows that revenue increased steadily across the three quarters. Q1 (represented by the shortest blue bar) had the lowest revenue, Q2 (green bar) saw higher revenue, and Q3 (tallest red bar) had the highest revenue and most significant growth. This visual trend confirms the text, which explains that Q1 experienced moderate growth, Q2 performed better due to marketing efforts, and Q3 experienced exponential growth because of global expansion. Overall, the company’s revenue trend is clearly upward, with the largest increase occurring in Q3.

Query: Summarize the main findings from the document
--------------------------------------------------


Retrieved 2 documents:
  - Text from page 0: Annual Revenue Overview
This document summarizes the revenue trends across Q1, Q2, and Q3. As illust...
  - Image from page 0


Answer: **Summary of Main Findings:**

The document analyzes annual revenue trends across the first three quarters (Q1, Q2, Q3):

- **Q1:** Revenue saw a moderate increase, which was attributed to the introduction of new product lines.
- **Q2:** Revenue further increased, outperforming Q1 due to the impact of marketing campaigns.
- **Q3:** Revenue experienced exponential growth, the highest among all quarters, driven by global expansion.

The included bar chart visually confirms this pattern, showing revenue rising from Q1 to Q3, with the highest bar in Q3. Overall, the main finding is a trend of accelerating revenue growth across the quarters, with particularly strong performance in Q3.

Query: What visual elements are present in the document?
--------------------------------------------------


Retrieved 2 documents:
  - Text from page 0: Annual Revenue Overview
This document summarizes the revenue trends across Q1, Q2, and Q3. As illust...
  - Image from page 0


Answer: **Visual elements present in the document:**

1. **Bar Chart**: The image on page 0 contains a bar chart with three vertical bars.
   - The **first bar** is colored blue.
   - The **second bar** is colored green.
   - The **third bar** is colored red.
   - The heights of the bars increase from left to right, visually representing growth across three periods.

2. **Color-Coding**: Each bar uses a distinct, bright color (blue, green, and red) to differentiate the quarters or periods being represented.

3. **White Background**: The chart is set against a plain white background, making the colored bars stand out.

**Summary:**  
The main visual element is a simple, three-bar vertical chart that uses different colors to indicate data for Q1 (blue), Q2 (green), and Q3 (red), with each bar increasing in height to 
