In [None]:
# --- Install dependencies ---
# The leading `!` is an IPython/Jupyter shell escape, so this line only works when
# executed as a notebook cell, not as plain Python.
# colpali-engine is constrained to the 0.3.x series and transformers is pinned to
# 4.51.1; pymupdf, pillow, openai and gradio are installed unpinned.
!pip install "colpali-engine>=0.3.0,<0.4.0" pymupdf pillow transformers==4.51.1 openai gradio --quiet

# --- Imports ---
import fitz
import torch
import requests
import gradio as gr
from PIL import Image
from colpali_engine.models import ColPali, ColPaliProcessor
from transformers import BlipProcessor, BlipForConditionalGeneration
from io import BytesIO
import traceback
import os
import numpy as np
import json

# --- Setup ---
# Prefer the GPU whenever PyTorch can see one; otherwise everything runs on CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# --- Load Models ---
# ColPali retriever: bfloat16 weights with an explicit device_map on CUDA,
# float32 weights and no device_map on CPU. eval() switches off training mode.
model_name = "vidore/colpali-v1.2"
model = ColPali.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
    device_map={"": device} if device == "cuda" else None,
).eval()
processor = ColPaliProcessor.from_pretrained(model_name)

# BLIP captioner, used by generate_image_captions to describe rendered pages.
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)

# --- Helper Functions ---
def normalize(v):
    """Scale `v` to unit L2 length along its last axis.

    Norms smaller than 1e-6 are clamped to 1e-6 before dividing, so an
    all-zero vector comes back as zeros instead of NaN.
    """
    lengths = np.linalg.norm(v, axis=-1, keepdims=True)
    safe_lengths = np.clip(lengths, 1e-6, None)
    return v / safe_lengths

def pad_embedding(emb, target_len=32):
    """Return `emb` as a matrix with exactly `target_len` rows.

    Rows beyond `target_len` are dropped, the remaining rows are L2-normalized
    with `normalize`, and any shortfall is filled with float32 zero rows.
    Assumes `emb` is a 2-D array of shape (tokens, dim) -- TODO confirm with callers.
    """
    clipped = normalize(emb[:target_len])
    missing_rows = target_len - clipped.shape[0]
    if missing_rows <= 0:
        return clipped
    filler = np.zeros((missing_rows, clipped.shape[1]), dtype=np.float32)
    return np.concatenate([clipped, filler], axis=0)

def extract_chunks_and_images(pdf_path, chunk_size=500):
    """Split a PDF into fixed-size text chunks and render every page to an image.

    Args:
        pdf_path: Path of the PDF file to open with PyMuPDF.
        chunk_size: Maximum number of characters per text chunk; must be positive.

    Returns:
        A tuple ``(text_chunks, page_images)``. ``text_chunks`` is a list of
        ``{"page": <0-based page index>, "text": <stripped chunk>}`` dicts
        (whitespace-only chunks are skipped); ``page_images`` holds one RGB
        PIL image per page, rendered at 150 dpi.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    # A zero step would make range() raise below, and a negative one would
    # silently produce no text at all; reject both before touching the file.
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")

    text_chunks = []
    page_images = []
    # The context manager closes the document (and its file handle) even if a
    # page fails to render.
    with fitz.open(pdf_path) as doc:
        for page_num, page in enumerate(doc):
            text = page.get_text("text")
            for start in range(0, len(text), chunk_size):
                chunk = text[start:start + chunk_size].strip()
                if chunk:
                    text_chunks.append({"page": page_num, "text": chunk})
            pix = page.get_pixmap(dpi=150)
            # frombytes copies the pixel buffer, so the image stays valid after
            # the document is closed.
            img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
            page_images.append(img)
    return text_chunks, page_images

def generate_image_captions(images):
    """Caption each image with the module-level BLIP model.

    Returns a list of ``{"page": <index in images>, "caption": <text>}`` dicts,
    in the same order as ``images``.
    """
    def _describe(image):
        # Preprocess on the active device, generate without tracking gradients,
        # then decode the first returned sequence to plain text.
        encoded = blip_processor(image, return_tensors="pt").to(device)
        with torch.no_grad():
            token_ids = blip_model.generate(**encoded)
        return blip_processor.decode(token_ids[0], skip_special_tokens=True)

    return [
        {"page": index, "caption": _describe(image)}
        for index, image in enumerate(images)
    ]

def _push_document(endpoint, doc_id, fields, label, timeout=30):
    """POST one document to Vespa's /document/v1 API, logging instead of raising on network errors."""
    url = f"{endpoint}/document/v1/multimodal-content/multimodal_doc/docid/{doc_id}"
    payload = {
        "put": f"id:multimodal-content:multimodal_doc::{doc_id}",
        "fields": fields,
    }
    try:
        # Without a timeout an unreachable endpoint would block the caller forever.
        r = requests.post(url, json=payload, timeout=timeout)
    except requests.RequestException as exc:
        print(f"[ERROR] {label} doc {doc_id} push failed: {exc}")
        return
    print(f"{label} doc {doc_id} push: {r.status_code}")


def _embed_batch(batch):
    """Move a processor batch onto `device` and run it through ColPali without gradients."""
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        return model(**batch)


def embed_and_push_to_vespa(text_chunks, images, captions, endpoint):
    """Embed text chunks and page images with ColPali and push them to Vespa.

    Every embedding is reduced to a 32-row matrix with `pad_embedding`. Each
    document is recorded in the returned list before it is pushed, so a failed
    push does not remove it from the in-memory index.

    Args:
        text_chunks: List of ``{"page": ..., "text": ...}`` dicts.
        images: List of page images, embedded via ``processor.process_images``.
        captions: List of ``{"page": ..., "caption": ...}`` dicts, paired
            positionally with the image embeddings.
        endpoint: Base URL of the Vespa instance, without a trailing slash.

    Returns:
        List of ``{"embedding", "desc", "page", "modality"}`` dicts, text
        documents first and image documents after. Unexpected errors are
        printed with a traceback and whatever was collected so far is returned.
    """
    print(f"[INFO] Embedding and pushing {len(text_chunks)} text chunks and {len(images)} images to Vespa...")
    docs = []
    try:
        # Skip empty inputs: handing the processor an empty batch could fail and,
        # via the broad handler below, prevent the images from being indexed.
        if text_chunks:
            texts = [c["text"] for c in text_chunks]
            text_embs = _embed_batch(processor.process_queries(texts))
            for i, (chunk, emb) in enumerate(zip(text_chunks, text_embs)):
                # Cast to float32 before converting to NumPy.
                multi_vec = pad_embedding(emb.to(torch.float32).cpu().numpy(), 32)
                description = chunk["text"][:200]
                docs.append({"embedding": multi_vec, "desc": description, "page": chunk["page"], "modality": "text"})
                _push_document(
                    endpoint,
                    i,
                    {
                        "title": f"Text chunk page {chunk['page']}",
                        "description": description,
                        "page": chunk["page"],
                        "modality": "text",
                        "embedding": {"values": multi_vec.tolist()},
                    },
                    "Text",
                )

        if images:
            image_embs = _embed_batch(processor.process_images(images))
            for i, (cap, emb) in enumerate(zip(captions, image_embs)):
                multi_vec = pad_embedding(emb.to(torch.float32).cpu().numpy(), 32)
                docs.append({"embedding": multi_vec, "desc": cap["caption"], "page": cap["page"], "modality": "image"})
                _push_document(
                    endpoint,
                    f"img_{i}",
                    {
                        "title": f"Image page {cap['page']}",
                        "description": cap["caption"],
                        "page": cap["page"],
                        "modality": "image",
                        "embedding": {"values": multi_vec.tolist()},
                    },
                    "Image",
                )

    except Exception:
        # Best-effort boundary: report the failure and return what was collected.
        print("[ERROR] Vespa push failed:")
        traceback.print_exc()

    return docs

def encode_query(text):
    """Embed a single query string with ColPali and return it as a 32-row matrix.

    The query is processed as a one-element batch; the first (only) output is
    cast to float32, moved to the CPU, converted to NumPy and passed through
    `pad_embedding`.
    """
    batch = processor.process_queries([text])
    on_device = {}
    for name, tensor in batch.items():
        on_device[name] = tensor.to(device)
    with torch.no_grad():
        outputs = model(**on_device)
        query_vectors = outputs[0].to(torch.float32).cpu().numpy()
    return pad_embedding(query_vectors, 32)

def maxsim_score(query_emb, doc_emb):
    """Return the MaxSim score between a query and a document embedding.

    For each query row, take the largest dot product with any document row,
    then sum those maxima. Assumes both inputs are 2-D (tokens, dim) arrays.
    """
    similarities = np.matmul(query_emb, doc_emb.T)
    best_per_query_row = np.max(similarities, axis=1)
    total = np.sum(best_per_query_row)
    return float(total)

def maxsim_search(query_text, docs, top_k_text=3, top_k_img=3):
    """Rank in-memory documents against a query using MaxSim.

    Args:
        query_text: Query string, embedded with `encode_query`.
        docs: Dicts with "embedding", "desc", "page" and "modality" keys.
        top_k_text: Number of text hits to keep.
        top_k_img: Number of image hits to keep.

    Returns:
        ``(text_hits, image_hits)``; each hit is a ``(score, desc, page)``
        tuple, sorted by descending score. Documents whose modality is neither
        "text" nor "image" are scored but not returned.
    """
    query_vectors = encode_query(query_text)
    hits_by_modality = {"text": [], "image": []}
    for entry in docs:
        similarity = maxsim_score(query_vectors, entry["embedding"])
        bucket = hits_by_modality.get(entry["modality"])
        if bucket is not None:
            bucket.append((similarity, entry["desc"], entry["page"]))

    # A stable descending sort keeps equal-scored hits in insertion order.
    ranked_text = sorted(hits_by_modality["text"], key=lambda hit: hit[0], reverse=True)
    ranked_images = sorted(hits_by_modality["image"], key=lambda hit: hit[0], reverse=True)
    return ranked_text[:top_k_text], ranked_images[:top_k_img]

def save_for_generation(query, text_hits, image_hits):
    """Persist the retrieval results for a later generation step.

    Writes ``retrieved_context.json`` (query, raw text hits, and page/caption
    pairs for the image hits) to the current directory, then saves the cached
    page image of each of the first three image hits as ``page_<number>.png``.
    Pages missing from the module-level ``image_cache`` are skipped.
    """
    image_entries = []
    for _score, caption, page_no in image_hits:
        image_entries.append({"page": page_no, "caption": caption})

    payload = {
        "query": query,
        "top_texts": text_hits,
        "top_images": image_entries,
    }
    with open("retrieved_context.json", "w", encoding="utf-8") as out_file:
        json.dump(payload, out_file, ensure_ascii=False, indent=2)

    for _score, _caption, page_no in image_hits[:3]:
        if page_no not in image_cache:
            continue
        image_cache[page_no].save(f"page_{page_no}.png")

    print("[INFO] Saved retrieved_context.json and corresponding images as page_<number>.png")

# Base URL of the Vespa instance. The default is a temporary ngrok tunnel, so it
# can be overridden with the VESPA_ENDPOINT environment variable instead of
# editing the code. A trailing slash is stripped because the push URLs are
# built as f"{endpoint}/document/v1/...".
vespa_endpoint = os.environ.get(
    "VESPA_ENDPOINT",
    "https://146e-2a02-ff0-c06-1e42-c845-7837-93e2-cfab.ngrok-free.app",
).rstrip("/")
# Rendered page images of the most recently processed PDF, keyed by page index.
image_cache = {}
# In-memory index built by embed_and_push_to_vespa for the most recent PDF.
embedded_docs = []

# Path of the PDF currently held in `image_cache` / `embedded_docs`. Repeated
# queries against the same upload reuse that index instead of re-extracting,
# re-captioning, re-embedding and re-pushing the whole document.
_indexed_pdf_path = None


def retrieve_pipeline(pdf_file, query):
    """Index the uploaded PDF if needed, then retrieve the best matches for `query`.

    Args:
        pdf_file: Value of the Gradio file input; either a path string or an
            object exposing the path as ``.name``.
        query: Query text from the textbox.

    Returns:
        ``(texts, images)``: the top text descriptions joined with "---"
        separator lines, and a list of ``(image, "Page <n>")`` pairs for the
        gallery. Returns ``("Missing input", [])`` when either input is empty.
    """
    global embedded_docs, image_cache, _indexed_pdf_path
    if not pdf_file or not query:
        return "Missing input", []

    # Accept both a plain path string and a file-like wrapper with `.name`.
    pdf_path = getattr(pdf_file, "name", pdf_file)

    # Rebuild the index only for a new file, or when the previous attempt
    # produced no documents.
    if pdf_path != _indexed_pdf_path or not embedded_docs:
        text_chunks, page_images = extract_chunks_and_images(pdf_path)
        image_cache = dict(enumerate(page_images))
        captions = generate_image_captions(page_images)
        embedded_docs = embed_and_push_to_vespa(text_chunks, page_images, captions, vespa_endpoint)
        _indexed_pdf_path = pdf_path

    text_hits, image_hits = maxsim_search(query, embedded_docs, top_k_text=3, top_k_img=3)
    save_for_generation(query, text_hits, image_hits)
    texts = [desc for _, desc, _ in text_hits]
    imgs = [(image_cache[page], f"Page {page}") for _, _, page in image_hits if page in image_cache]
    return "\n---\n".join(texts), imgs


with gr.Blocks(title="ColPali RAG Pipeline") as demo:
    gr.Markdown("## ColPali RAG Pipeline")

    with gr.Row():
        pdf_input = gr.File(label="Upload PDF")
        query_input = gr.Textbox(label="Enter your query")
        retrieve_btn = gr.Button("Retrieve")

    retrieved_texts = gr.Textbox(label="Top 3 Relevant Texts", lines=8)
    retrieved_images = gr.Gallery(label="Top 3 Relevant Pages", show_label=True)

    retrieve_btn.click(
        retrieve_pipeline,
        inputs=[pdf_input, query_input],
        outputs=[retrieved_texts, retrieved_images]
    )

# Launch once the Blocks context has closed, i.e. after the layout and event
# wiring are complete.
demo.launch(share=True, debug=True)