### Multimodal RAG Agent

In [1]:
import pymupdf
from langchain_core.documents import Document
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import numpy as np
from langchain.chat_models import init_chat_model
from langchain.prompts import PromptTemplate
from langchain.schema.messages import HumanMessage
from sklearn.metrics.pairwise import cosine_similarity
import os
import base64
import io
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
## HUgging face setup (CLIP model)

from dotenv import load_dotenv
load_dotenv()

os.environ["HUGGING_FACE_API_KEY"] = os.getenv("HUGGING_FACE_API_KEY")

# init clip model for unified embeddings
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") # used to create text + image into embeddings
clip_processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32') # to structure the input format for CLIP model

clip_model.eval()


To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


CLIPModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 512)
      (position_embedding): Embedding(77, 512)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPSdpaAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=512, out_features=2048, bias=True)
            (fc2): Linear(in_features=2048, out_features=512, bias=True)
          )
          (layer_norm2): LayerNorm((512,), eps=1e

In [10]:
import torch

## Helper functions to create embeddings
def embed_image(image_data):
    # if path -> fetch the image and convert
    if isinstance(image_data, str):
        image = Image.open(image_data).convert("RGB")
    else: image = image_data

    inputs = clip_processor(images=image, return_tensors='pt')
    with torch.no_grad():
        features = clip_model.get_image_features(**inputs)
        # normalize embeddings to unit vectors
        features = features / features.norm(dim=-1, keepdim=True)
        return features.squeeze().numpy() # return in  numpy array
        
def embed_text(data):
    inputs = clip_processor(
        text=data,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length = 77 # CLIP's max token length
    )
    with torch.no_grad():
        features = clip_model.get_text_features(**inputs)
        # normalize embeddings
        features = features / features.norm(dim= -1, keepdim=True)
        return features.squeeze().numpy()


In [11]:
## process the pdf
pdf_path = './multimodal_sample.pdf'
doc = pymupdf.open(pdf_path)

all_docs = [] # docs will be stored here
all_embeddings = [] # embeddings will be stored here
image_data_store = {} # store actual image data

text_splitter = RecursiveCharacterTextSplitter(chunk_size = 500, chunk_overlap = 100)


In [12]:
for idx, page in enumerate(doc):
    ## process text
    text = page.get_text()
    if text.strip():
        temp_doc_for_splitting = Document(page_content=text, metadata={'page': idx, 'type': 'text'})
        text_chunks = text_splitter.split_documents([temp_doc_for_splitting])

        # embed chunk using clip
        for chunk in text_chunks:
            embeddings = embed_text(chunk.page_content)
            all_embeddings.append(embeddings)
            all_docs.append(chunk)

    ## process image
    for img_idx, img in enumerate(page.get_images(full=True)):
        try:
            xref = img[0]
            base_image = doc.extract_image(xref)
            image_bytes = base_image['image']
            
            # s1: convert pdf iamge to PIL format
            PIL_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            
            image_id = f"page_{idx}_image_{img_idx}"

            # s2: store as base64 for LLM
            buffered = io.BytesIO()
            PIL_image.save(buffered, format='PNG')
            img_base64 = base64.b64encode(buffered.getvalue()).decode()
            image_data_store[image_id] = img_base64

            # s3: create clip embeddings for retrieval
            embeddings = embed_image(PIL_image)
            all_embeddings.append(embeddings)
            
            image_doc = Document(
                page_content = f'[Image: {image_id}]',
                metadata = {'page': idx, "type": "image", "image_id": image_id}
            )
            all_docs.append(image_doc)

        except Exception as e:
            print(f"Error processing image {img_idx} on page {idx}: {e}")
            continue

doc.close()


In [13]:
all_embeddings


[array([-2.67243991e-03,  1.28299948e-02, -5.18313870e-02,  4.14879471e-02,
        -2.33941860e-02, -7.55864428e-03, -3.67659405e-02,  1.19710669e-01,
         8.52081329e-02,  2.05423264e-03, -1.11534819e-02, -1.29592139e-02,
         5.25014736e-02, -3.65396030e-03,  4.76078615e-02,  1.58372764e-02,
         2.03388259e-02,  4.35362570e-02, -3.29167186e-03,  2.03181785e-02,
         1.88023411e-03, -4.23493721e-02,  5.44102443e-03,  3.70935947e-02,
        -1.65622663e-02,  6.48645638e-03, -4.78012040e-02,  8.67484324e-03,
         5.88859580e-02, -3.21394317e-02,  4.32439968e-02,  9.65300482e-03,
        -4.47920570e-03, -1.94858182e-02, -3.63502875e-02, -1.23471767e-02,
        -2.17929184e-02, -1.99016239e-02,  8.09619799e-02, -3.32986489e-02,
        -2.38901116e-02, -3.96138951e-02, -1.27279749e-02,  3.50380726e-02,
        -2.52217241e-02,  2.00030743e-03,  1.49660828e-02, -2.31976565e-02,
        -6.86791465e-02, -5.25778159e-04, -2.22545806e-02, -1.04103768e-02,
        -1.9

In [14]:
all_docs


[Document(metadata={'page': 0, 'type': 'text'}, page_content='Annual Revenue Overview\nThis document summarizes the revenue trends across Q1, Q2, and Q3. As illustrated in the chart\nbelow, revenue grew steadily with the highest growth recorded in Q3.\nQ1 showed a moderate increase in revenue as new product lines were introduced. Q2 outperformed\nQ1 due to marketing campaigns. Q3 had exponential growth due to global expansion.'),
 Document(metadata={'page': 0, 'type': 'image', 'image_id': 'page_0_image_0'}, page_content='[Image: page_0_image_0]')]

In [15]:
## creating vector store
embeddings_array = np.array(all_embeddings)

vector_store = FAISS.from_embeddings(
    text_embeddings=[(doc.page_content, emb) for doc, emb in zip(all_docs, embeddings_array)],
    embedding=None, # using pre-computed embeddings
    metadatas=[doc.metadata for doc in all_docs]
)

vector_store


`embedding_function` is expected to be an Embeddings object, support for passing in a function will soon be removed.


<langchain_community.vectorstores.faiss.FAISS at 0x1fc0a2fe410>

In [27]:
os.environ["COHERE_API_KEY"] = os.getenv("COHERE_API_KEY")

llm = init_chat_model('cohere: command-r') # it's not a multimodal so this llm won't work

llm


ChatCohere(client=<cohere.client.Client object at 0x000001FC0CE66010>, async_client=<cohere.client.AsyncClient object at 0x000001FC0C8B2050>, model=' command-r', cohere_api_key=SecretStr('**********'))

In [None]:
# Replace the Cohere model with OpenAI's multimodal model
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")

llm = init_chat_model('openai:gpt-4-vision-preview')
# or use: llm = init_chat_model('openai:gpt-4o')


In [28]:
# helper function to retrieve relevant info from vector store
def retrieve(query, k=5):
    query_embeddings = embed_text(query)
    
    # search vector store
    results = vector_store.similarity_search_by_vector(
        embedding=query_embeddings,
        k = k # count of vectors to fetch
    )
    
    return results


In [29]:
def prompt_template(query, retrieved_docs):
    content = []
    
    content.append({
        'type': 'text',
        'text': f"Question: {query} \n\n context: \n"
    })
    
    # differentiate the text and image docs
    text_docs = [doc for doc in retrieved_docs if doc.metadata.get('type') == 'text']
    image_docs = [doc for doc in retrieved_docs if doc.metadata.get('type') == 'image']
    
    # define context
    if text_docs:
        text_context = "\n\n".join([
            f"[Page {doc.metadata['page']}]: {doc.page_content}"
            for doc in text_docs
        ])
        content.append({
            'type': "text",
            'text': f"text excerpts: \n{text_context}\n"
        })
        
    # provide images
    for doc in image_docs:
        image_id = doc.metadata.get('image_id')
        if image_id and image_id in image_data_store:
            content.append({
                'type': 'text',
                'text': f"\n[Image from page {doc.metadata['page']}]: \n"
            })
            content.append({
                'type': 'image_url',
                'image_url': {
                    'url': f'data: image/png; base64, {image_data_store[image_id]}'
                }
            })
            
    content.append({
        'type': 'text',
        'text': '\n\n Please answer the question based on the provided text and images.'
    })
    
    return HumanMessage(content=content)


In [30]:
def prompt_template_for_cohere(query, retrieved_docs):
    # Build a text-only prompt since Cohere doesn't support multimodal input
    prompt_parts = [f"Question: {query}\n\nContext:\n"]
    
    # differentiate the text and image docs
    text_docs = [doc for doc in retrieved_docs if doc.metadata.get('type') == 'text']
    image_docs = [doc for doc in retrieved_docs if doc.metadata.get('type') == 'image']
    
    # Add text context
    if text_docs:
        text_context = "\n\n".join([
            f"[Page {doc.metadata['page']}]: {doc.page_content}"
            for doc in text_docs
        ])
        prompt_parts.append(f"Text excerpts:\n{text_context}\n")
    
    # Add image descriptions (since we can't send actual images to Cohere)
    if image_docs:
        image_descriptions = []
        for doc in image_docs:
            image_descriptions.append(f"[Image on page {doc.metadata['page']}]")
        prompt_parts.append(f"Images found:\n" + "\n".join(image_descriptions) + "\n")
    
    prompt_parts.append("\nPlease answer the question based on the provided text context and knowledge of any images mentioned.")
    
    # Return a simple HumanMessage with text content only
    return HumanMessage(content="".join(prompt_parts))


In [31]:
def rag_pipeline(query):
    # retrieve relevant docs
    context_docs = retrieve(query, k = 5)
    
    # create message
    # message = prompt_template(query, context_docs)
    message = prompt_template_for_cohere(query, context_docs)
    
    # get response
    response = llm.invoke([message])
    
    print(f'\n REtrieved {len(context_docs)} documents: ')
    for doc in context_docs:
        doc_type = doc.metadata.get("type", 'unknown')
        page = doc.metadata.get('page', '?')
        if doc_type == 'text':
            preview = doc.page_content[:100] + "..." if len(doc.page_content) > 100 else doc.page_content
            print(f" - text from page {page}: {preview}")
        else:
            print(f" - image from page {page}")
    print('\n')
    
    return response.content


In [32]:
if __name__ == "__main__":
    # Example queries
    queries = [
        "What does the chart on page 1 show about revenue trends?",
        "Summarize the main findings from the document",
        "What visual elements are present in the document?"
    ]
    
    for query in queries:
        print(f"\nQuery: {query}")
        print("-" * 50)
        answer = rag_pipeline(query)
        print(f"Answer: {answer}")
        print("=" * 70)



Query: What does the chart on page 1 show about revenue trends?
--------------------------------------------------


NotFoundError: status_code: 404, body: {'id': '5292f0fa-7e36-4eb4-9da2-19a0d7fc8f05', 'message': "model ' command-r' not found, make sure the correct model ID was used and that you have access to the model."}