In [1]:
import fitz  # PyMuPDF
from langchain_core.documents import Document
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch 
import numpy as np
from langchain.prompts import PromptTemplate
from langchain.schema.messages import HumanMessage
from sklearn.metrics.pairwise import cosine_similarity
import os
import base64
import io
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
import ollama_python as ollama
import requests
import json

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
###Clip Model
import os
from dotenv import load_dotenv
load_dotenv()

### initialize the Clip Model for unified embeddings
clip_model=CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor=CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
clip_model.eval()


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


CLIPModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 512)
      (position_embedding): Embedding(77, 512)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=512, out_features=2048, bias=True)
            (fc2): Linear(in_features=2048, out_features=512, bias=True)
          )
          (layer_norm2): LayerNorm((512,), eps=1e-05,

In [3]:
### Embedding functions
def embed_image(image_data):
    """Embed image using CLIP"""
    if isinstance(image_data, str):  # If path
        image = Image.open(image_data).convert("RGB")
    else:  # If PIL Image
        image = image_data
    
    inputs=clip_processor(images=image,return_tensors="pt")
    with torch.no_grad():
        features = clip_model.get_image_features(**inputs)
        # Normalize embeddings to unit vector
        features = features / features.norm(dim=-1, keepdim=True)
        return features.squeeze().numpy()
    
def embed_text(text):
    """Embed text using CLIP."""
    inputs = clip_processor(
        text=text, 
        return_tensors="pt", 
        padding=True,
        truncation=True,
        max_length=77  # CLIP's max token length
    )
    with torch.no_grad():
        features = clip_model.get_text_features(**inputs)
        # Normalize embeddings
        features = features / features.norm(dim=-1, keepdim=True)
        return features.squeeze().numpy()

In [4]:
## Process PDF
pdf_path="multimodal_sample.pdf"
doc=fitz.open(pdf_path)
# Storage for all documents and embeddings
all_docs = []
all_embeddings = []
image_data_store = {}  # Store actual image data for LLM

# Text splitter
splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)




In [5]:
doc

Document('multimodal_sample.pdf')

In [6]:
for i,page in enumerate(doc):
    ## process text
    text=page.get_text()
    if text.strip():
        ##create temporary document for splitting
        temp_doc = Document(page_content=text, metadata={"page": i, "type": "text"})
        text_chunks = splitter.split_documents([temp_doc])

        #Embed each chunk using CLIP
        for chunk in text_chunks:
            embedding = embed_text(chunk.page_content)
            all_embeddings.append(embedding)
            all_docs.append(chunk)



    ## process images
    ##Three Important Actions:

    ##Convert PDF image to PIL format
    ##Store as base64 for GPT-4V (which needs base64 images)
    ##Create CLIP embedding for retrieval

    for img_index, img in enumerate(page.get_images(full=True)):
        try:
            xref = img[0]
            base_image = doc.extract_image(xref)
            image_bytes = base_image["image"]
            
            # Convert to PIL Image
            pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            
            # Create unique identifier
            image_id = f"page_{i}_img_{img_index}"
            
            # Store image as base64 for later use with GPT-4V
            buffered = io.BytesIO()
            pil_image.save(buffered, format="PNG")
            img_base64 = base64.b64encode(buffered.getvalue()).decode()
            image_data_store[image_id] = img_base64
            
            # Embed image using CLIP
            embedding = embed_image(pil_image)
            all_embeddings.append(embedding)
            
            # Create document for image
            image_doc = Document(
                page_content=f"[Image: {image_id}]",
                metadata={"page": i, "type": "image", "image_id": image_id}
            )
            all_docs.append(image_doc)
            
        except Exception as e:
            print(f"Error processing image {img_index} on page {i}: {e}")
            continue

doc.close()


In [7]:
all_docs

[Document(metadata={'page': 0, 'type': 'text'}, page_content='Annual Revenue Overview\nThis document summarizes the revenue trends across Q1, Q2, and Q3. As illustrated in the chart\nbelow, revenue grew steadily with the highest growth recorded in Q3.\nQ1 showed a moderate increase in revenue as new product lines were introduced. Q2 outperformed\nQ1 due to marketing campaigns. Q3 had exponential growth due to global expansion.'),
 Document(metadata={'page': 0, 'type': 'image', 'image_id': 'page_0_img_0'}, page_content='[Image: page_0_img_0]')]

In [8]:
# Create unified FAISS vector store with CLIP embeddings
embeddings_array = np.array(all_embeddings)
embeddings_array

array([[-0.00267244,  0.01282998, -0.05183141, ..., -0.00385081,
         0.02977718, -0.00010685],
       [ 0.01732335, -0.01327693, -0.02427032, ...,  0.0899405 ,
        -0.00272154,  0.03253041]], shape=(2, 512), dtype=float32)

In [9]:
(all_docs,embeddings_array)

([Document(metadata={'page': 0, 'type': 'text'}, page_content='Annual Revenue Overview\nThis document summarizes the revenue trends across Q1, Q2, and Q3. As illustrated in the chart\nbelow, revenue grew steadily with the highest growth recorded in Q3.\nQ1 showed a moderate increase in revenue as new product lines were introduced. Q2 outperformed\nQ1 due to marketing campaigns. Q3 had exponential growth due to global expansion.'),
  Document(metadata={'page': 0, 'type': 'image', 'image_id': 'page_0_img_0'}, page_content='[Image: page_0_img_0]')],
 array([[-0.00267244,  0.01282998, -0.05183141, ..., -0.00385081,
          0.02977718, -0.00010685],
        [ 0.01732335, -0.01327693, -0.02427032, ...,  0.0899405 ,
         -0.00272154,  0.03253041]], shape=(2, 512), dtype=float32))

In [10]:


# Create custom FAISS index since we have precomputed embeddings
embeddings_array = np.vstack(all_embeddings)  
vector_store = FAISS.from_embeddings(
    text_embeddings=[(doc.page_content, emb) for doc, emb in zip(all_docs, embeddings_array)],
    embedding=None,  # We're using precomputed embeddings
    metadatas=[doc.metadata for doc in all_docs]
)
vector_store

`embedding_function` is expected to be an Embeddings object, support for passing in a function will soon be removed.


<langchain_community.vectorstores.faiss.FAISS at 0x7c5325a7f170>

In [11]:

# 🆕 NEW: Ollama Client Class
class OllamaLLaVA:
    def __init__(self, model="llava:7b", base_url="http://localhost:11434"):
        self.model = model
        self.base_url = base_url
        
    def generate_with_vision(self, prompt, images=None):
        """Generate response using Ollama's LLaVA model with vision capabilities."""
        url = f"{self.base_url}/api/generate"
        
        payload = {
            "model": self.model,
            "prompt": prompt,
            "stream": False,
            "options": {
                "temperature": 0.1,
                "top_p": 0.9,
                "num_predict": 1000
            }
        }
        
        # Add images if provided (base64 format)
        if images:
            payload["images"] = images
        
        try:
            print(f"🤖 Querying {self.model}...")
            response = requests.post(url, json=payload, timeout=180)
            response.raise_for_status()
            
            result = response.json()
            return result.get("response", "No response generated")
            
        except requests.exceptions.RequestException as e:
            return f"❌ Ollama Error: {e}"
        except Exception as e:
            return f"❌ Unexpected Error: {e}"


In [12]:
# Initialize GPT-4 Vision model
llm = OllamaLLaVA(model="llava:7b")

In [13]:
def retrieve_multimodal(query, k=5):
    """Unified retrieval using CLIP embeddings for both text and images."""
    # Embed query using CLIP
    query_embedding = embed_text(query)
    
    # Search in unified vector store
    results = vector_store.similarity_search_by_vector(
        embedding=query_embedding,
        k=k
    )
    
    return results

In [14]:

# 🆕 NEW: Updated message creation for Ollama LLaVA
def create_ollama_multimodal_message(query, retrieved_docs):
    """Create a message with both text and images for Ollama LLaVA."""
    
    # Separate text and image documents
    text_docs = [doc for doc in retrieved_docs if doc.metadata.get("type") == "text"]
    image_docs = [doc for doc in retrieved_docs if doc.metadata.get("type") == "image"]
    
    # Build text context
    context_parts = []
    
    # Add text context
    if text_docs:
        text_context = "\n\n".join([
            f"📄 [Page {doc.metadata['page']}]: {doc.page_content}"
            for doc in text_docs
        ])
        context_parts.append(f"📝 **TEXT CONTENT:**\n{text_context}")
    
    # Collect images for Ollama
    images_to_analyze = []
    image_descriptions = []
    
    for doc in image_docs:
        image_id = doc.metadata.get("image_id")
        page = doc.metadata.get('page', '?')
        
        if image_id and image_id in image_data_store:
            # Add base64 image for Ollama
            images_to_analyze.append(image_data_store[image_id])
            image_descriptions.append(f"🖼️ Image from page {page} (analyzing below)")
    
    if image_descriptions:
        context_parts.append(f"📊 **VISUAL CONTENT:**\n" + "\n".join(image_descriptions))
    
    # Create comprehensive prompt for LLaVA
    full_context = "\n\n".join(context_parts)
    
    prompt = f"""🔍 **MULTIMODAL PDF ANALYSIS**

❓ **QUESTION:** {query}

📚 **DOCUMENT CONTEXT:**
{full_context}

🎯 **INSTRUCTIONS:**
1. Carefully read and understand the text content from the PDF
2. Analyze any images/charts/diagrams provided below
3. Provide a comprehensive answer combining insights from BOTH text and visual elements
4. Reference specific page numbers when mentioning information
5. If you see charts, graphs, or tables in images, describe what data they show
6. Be specific and detailed in your analysis

💡 **Your comprehensive analysis:**"""
    
    return {
        "prompt": prompt,
        "images": images_to_analyze
    }

In [15]:

# 🔄 UPDATED: Main pipeline function for Ollama
def multimodal_pdf_rag_pipeline(query):
    """Main pipeline for multimodal RAG using Ollama LLaVA."""
    
    # Retrieve relevant documents (SAME AS BEFORE ✅)
    context_docs = retrieve_multimodal(query, k=5)
    
    # Create Ollama-compatible message (NEW!)
    message_data = create_ollama_multimodal_message(query, context_docs)
    
    # Get response from Ollama LLaVA (NEW!)
    response = llm.generate_with_vision(
        prompt=message_data["prompt"],
        images=message_data["images"]
    )
    
    # Print retrieved context info (SAME AS BEFORE ✅)
    print(f"\n📋 Retrieved {len(context_docs)} documents:")
    for doc in context_docs:
        doc_type = doc.metadata.get("type", "unknown")
        page = doc.metadata.get("page", "?")
        if doc_type == "text":
            preview = doc.page_content[:100] + "..." if len(doc.page_content) > 100 else doc.page_content
            print(f"  - 📝 Text from page {page}: {preview}")
        else:
            print(f"  - 🖼️ Image from page {page}")
    print("\n")
    
    return response

In [16]:

# 🧪 Testing with enhanced queries
if __name__ == "__main__":
    # 🔧 First, check if Ollama is running
    try:
        response = requests.get("http://localhost:11434/api/tags", timeout=5)
        models = [m["name"] for m in response.json()["models"]]
        print(f"✅ Ollama is running! Available models: {models}")
        
        if "llava:7b" not in models:
            print("❌ LLaVA:7b not found! Run: ollama pull llava:7b")
            exit(1)
        else:
            print("🎉 LLaVA:7b is ready!")
            
    except Exception as e:
        print(f"❌ Ollama not running! Start with: ollama serve\nError: {e}")
        exit(1)
    
    # Example queries optimized for vision capabilities
    queries = [
        "What does the chart on page 1 show about revenue trends?",
        "Summarize the main findings from the document including visual data",
        "What visual elements are present in the document and what do they tell us?",
        "Analyze any graphs, charts, or diagrams and explain the key insights"
    ]
    
    for query in queries:
        print(f"\n🔥 **Query:** {query}")
        print("-" * 50)
        answer = multimodal_pdf_rag_pipeline(query)
        print(f"**Answer:** {answer}")
        print("=" * 70)

✅ Ollama is running! Available models: ['llava:7b']
🎉 LLaVA:7b is ready!

🔥 **Query:** What does the chart on page 1 show about revenue trends?
--------------------------------------------------
🤖 Querying llava:7b...

📋 Retrieved 2 documents:
  - 📝 Text from page 0: Annual Revenue Overview
This document summarizes the revenue trends across Q1, Q2, and Q3. As illust...
  - 🖼️ Image from page 0


**Answer:** ❌ Ollama Error: 500 Server Error: Internal Server Error for url: http://localhost:11434/api/generate

🔥 **Query:** Summarize the main findings from the document including visual data
--------------------------------------------------
🤖 Querying llava:7b...

📋 Retrieved 2 documents:
  - 📝 Text from page 0: Annual Revenue Overview
This document summarizes the revenue trends across Q1, Q2, and Q3. As illust...
  - 🖼️ Image from page 0


**Answer:** ❌ Ollama Error: 500 Server Error: Internal Server Error for url: http://localhost:11434/api/generate

🔥 **Query:** What visual elements are