# Gemini Multimodal Example

This notebook demonstrates how to use Google Gemini models for multimodal tasks (Text + Image) using LangChain.

In [25]:
import fitz  # PyMuPDF
from langchain_core.documents import Document
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch
import numpy as np
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.prompts import PromptTemplate
from langchain.schema.messages import HumanMessage
from sklearn.metrics.pairwise import cosine_similarity
import os
import base64
import io
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from dotenv import load_dotenv

In [26]:
load_dotenv()

# Set up the environment for Google
os.environ["GOOGLE_API_KEY"] = os.getenv("GOOGLE_API_KEY")

### Initialize the CLIP Model for unified embeddings
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
clip_model.eval()

Loading weights: 100%|██████████| 398/398 [00:00<00:00, 1295.82it/s, Materializing param=visual_projection.weight]                                
[1mCLIPModel LOAD REPORT[0m from: openai/clip-vit-base-patch32
Key                                  | Status     |  | 
-------------------------------------+------------+--+-
vision_model.embeddings.position_ids | UNEXPECTED |  | 
text_model.embeddings.position_ids   | UNEXPECTED |  | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.[0m


CLIPModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 512)
      (position_embedding): Embedding(77, 512)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=512, out_features=2048, bias=True)
            (fc2): Linear(in_features=2048, out_features=512, bias=True)
          )
          (layer_norm2): LayerNorm((512,), eps=1e-05,

In [73]:
### FIXED Embedding functions
def embed_image(image_data):
    """Embed image using CLIP"""
    if isinstance(image_data, str):  # If path
        image = Image.open(image_data).convert("RGB")
    else:  # If PIL Image
        image = image_data
    
    inputs = clip_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        vision_outputs = clip_model.vision_model(pixel_values=inputs['pixel_values'])
        pooled = vision_outputs.pooler_output
        features = clip_model.visual_projection(pooled)
        features = features / features.norm(dim=-1, keepdim=True)
        return features.squeeze().numpy()

In [74]:
def embed_text(text):
    """Embed text using CLIP."""
    inputs = clip_processor(
        text=text, 
        return_tensors="pt", 
        padding=True,
        truncation=True,
        max_length=77
    )
    with torch.no_grad():
        text_outputs = clip_model.text_model(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask']
        )
        pooled = text_outputs.pooler_output
        features = clip_model.text_projection(pooled)
        features = features / features.norm(dim=-1, keepdim=True)
        return features.squeeze().numpy()

In [75]:
## Process PDF
pdf_path = "multimodal_sample.pdf"
doc = fitz.open(pdf_path)

In [76]:
# Storage for all documents and embeddings
all_docs = []
all_embeddings = []
image_data_store = {}  # Store actual image data for LLM

In [77]:
# Text splitter
splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)

In [78]:
doc

Document('multimodal_sample.pdf')

In [79]:
for i, page in enumerate(doc):
    ## Process text
    text = page.get_text()
    if text.strip():
        # Create temporary document for splitting
        temp_doc = Document(page_content=text, metadata={"page": i, "type": "text"})
        text_chunks = splitter.split_documents([temp_doc])

        # Embed each chunk using CLIP
        for chunk in text_chunks:
            embedding = embed_text(chunk.page_content)
            all_embeddings.append(embedding)
            all_docs.append(chunk)

    ## Process images
    for img_index, img in enumerate(page.get_images(full=True)):
        try:
            xref = img[0]
            base_image = doc.extract_image(xref)
            image_bytes = base_image["image"]
            
            # Convert to PIL Image
            pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            
            # Create unique identifier
            image_id = f"page_{i}_img_{img_index}"
            
            # Store image as base64 for later use with Gemini
            buffered = io.BytesIO()
            pil_image.save(buffered, format="PNG")
            img_base64 = base64.b64encode(buffered.getvalue()).decode()
            image_data_store[image_id] = img_base64
            
            # Embed image using CLIP
            embedding = embed_image(pil_image)
            all_embeddings.append(embedding)
            
            # Create document for image
            image_doc = Document(
                page_content=f"[Image: {image_id}]",
                metadata={"page": i, "type": "image", "image_id": image_id}
            )
            all_docs.append(image_doc)
            
        except Exception as e:
            print(f"Error processing image {img_index} on page {i}: {e}")
            continue

doc.close()

In [80]:
all_docs

[Document(metadata={'page': 0, 'type': 'text'}, page_content='Annual Revenue Overview\nThis document summarizes the revenue trends across Q1, Q2, and Q3. As illustrated in the chart\nbelow, revenue grew steadily with the highest growth recorded in Q3.\nQ1 showed a moderate increase in revenue as new product lines were introduced. Q2 outperformed\nQ1 due to marketing campaigns. Q3 had exponential growth due to global expansion.'),
 Document(metadata={'page': 0, 'type': 'image', 'image_id': 'page_0_img_0'}, page_content='[Image: page_0_img_0]')]

In [81]:
# Create unified FAISS vector store with CLIP embeddings
embeddings_array = np.array(all_embeddings)

In [82]:
embeddings_array

array([[-0.00267238,  0.01282993, -0.05183155, ..., -0.00385087,
         0.02977716, -0.00010691],
       [ 0.01732346, -0.01327705, -0.02427034, ...,  0.08993971,
        -0.00272153,  0.03253055]], shape=(2, 512), dtype=float32)

In [83]:
(all_docs,embeddings_array)

([Document(metadata={'page': 0, 'type': 'text'}, page_content='Annual Revenue Overview\nThis document summarizes the revenue trends across Q1, Q2, and Q3. As illustrated in the chart\nbelow, revenue grew steadily with the highest growth recorded in Q3.\nQ1 showed a moderate increase in revenue as new product lines were introduced. Q2 outperformed\nQ1 due to marketing campaigns. Q3 had exponential growth due to global expansion.'),
  Document(metadata={'page': 0, 'type': 'image', 'image_id': 'page_0_img_0'}, page_content='[Image: page_0_img_0]')],
 array([[-0.00267238,  0.01282993, -0.05183155, ..., -0.00385087,
          0.02977716, -0.00010691],
        [ 0.01732346, -0.01327705, -0.02427034, ...,  0.08993971,
         -0.00272153,  0.03253055]], shape=(2, 512), dtype=float32))

In [84]:
# Create custom FAISS index since we have precomputed embeddings
vector_store = FAISS.from_embeddings(
    text_embeddings=[(doc.page_content, emb) for doc, emb in zip(all_docs, embeddings_array)],
    embedding=None,  # We're using precomputed embeddings
    metadatas=[doc.metadata for doc in all_docs]
)


`embedding_function` is expected to be an Embeddings object, support for passing in a function will soon be removed.


In [85]:
llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash-lite")

In [86]:
def retrieve_multimodal(query, k=5):
    """Unified retrieval using CLIP embeddings for both text and images."""
    # Embed query using CLIP
    query_embedding = embed_text(query)
    
    # Search in unified vector store
    results = vector_store.similarity_search_by_vector(
        embedding=query_embedding,
        k=k
    )
    
    return results

In [87]:
def create_multimodal_message(query, retrieved_docs):
    """Create a message with both text and images for Gemini."""
    content = []
    
    # Add the query
    content.append({
        "type": "text",
        "text": f"Question: {query}\n\nContext:\n"
    })
    
    # Separate text and image documents
    text_docs = [doc for doc in retrieved_docs if doc.metadata.get("type") == "text"]
    image_docs = [doc for doc in retrieved_docs if doc.metadata.get("type") == "image"]
    
    # Add text context
    if text_docs:
        text_context = "\n\n".join([
            f"[Page {doc.metadata['page']}]: {doc.page_content}"
            for doc in text_docs
        ])
        content.append({
            "type": "text",
            "text": f"Text excerpts:\n{text_context}\n"
        })
    
    # Add images
    for doc in image_docs:
        image_id = doc.metadata.get("image_id")
        if image_id and image_id in image_data_store:
            content.append({
                "type": "text",
                "text": f"\n[Image from page {doc.metadata['page']}]:\n"
            })
            content.append({
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/png;base64,{image_data_store[image_id]}"
                }
            })
    
    # Add instruction
    content.append({
        "type": "text",
        "text": "\n\nPlease answer the question based on the provided text and images."
    })
    
    return HumanMessage(content=content)

In [88]:
def multimodal_pdf_rag_pipeline(query):
    """Main pipeline for multimodal RAG with Gemini."""
    # Retrieve relevant documents
    context_docs = retrieve_multimodal(query, k=5)
    
    # Create multimodal message
    message = create_multimodal_message(query, context_docs)
    
    # Get response from Gemini
    response = llm.invoke([message])
    
    # Print retrieved context info
    print(f"\nRetrieved {len(context_docs)} documents:")
    for doc in context_docs:
        doc_type = doc.metadata.get("type", "unknown")
        page = doc.metadata.get("page", "?")
        if doc_type == "text":
            preview = doc.page_content[:100] + "..." if len(doc.page_content) > 100 else doc.page_content
            print(f"  - Text from page {page}: {preview}")
        else:
            print(f"  - Image from page {page}")
    print("\n")
    
    return response.content

from IPython.display import display, Markdown
if __name__ == "__main__":
    # Example queries
    queries = [
        "What does the chart on page 1 show about revenue trends?",
        "Summarize the main findings from the document",
        "What visual elements are present in the document?"
    ]
    
    for query in queries:
        print(f"\nQuery: {query}")
        print("-" * 70)
        answer = multimodal_pdf_rag_pipeline(query)
        
        # Display the answer with Markdown formatting
        display(Markdown(f"**Answer:**\n\n{answer}"))
        print("=" * 70)


Query: What does the chart on page 1 show about revenue trends?
----------------------------------------------------------------------

Retrieved 2 documents:
  - Text from page 0: Annual Revenue Overview
This document summarizes the revenue trends across Q1, Q2, and Q3. As illust...
  - Image from page 0




**Answer:**

The chart shows that revenue grew steadily across Q1, Q2, and Q3, with the highest growth recorded in Q3. The blue bar represents Q1, the green bar represents Q2, and the red bar represents Q3. The height of the bars indicates revenue, with the red bar (Q3) being the tallest, followed by the green bar (Q2), and then the blue bar (Q1). This visual representation confirms the text's description of steady growth and highest growth in Q3.


Query: Summarize the main findings from the document
----------------------------------------------------------------------

Retrieved 2 documents:
  - Text from page 0: Annual Revenue Overview
This document summarizes the revenue trends across Q1, Q2, and Q3. As illust...
  - Image from page 0




**Answer:**

The document summarizes revenue trends across Q1, Q2, and Q3, indicating steady growth with the highest growth in Q3. Q1 saw a moderate increase due to new product introductions. Q2 outperformed Q1, attributed to marketing campaigns. Q3 experienced exponential growth driven by global expansion.


Query: What visual elements are present in the document?
----------------------------------------------------------------------

Retrieved 2 documents:
  - Text from page 0: Annual Revenue Overview
This document summarizes the revenue trends across Q1, Q2, and Q3. As illust...
  - Image from page 0




**Answer:**

The document contains a bar chart. The bar chart has three bars of different heights and colors: blue, green, and red. The text mentions that the document summarizes revenue trends across Q1, Q2, and Q3, and that the chart illustrates this. The text also states that revenue grew steadily, with the highest growth in Q3. This suggests that the bars in the chart represent revenue for Q1, Q2, and Q3, with the blue bar likely representing Q1, the green bar representing Q2, and the red bar representing Q3, given their increasing heights.

