In [1]:
!pip install torch transformers pillow PyMuPDF anthropic streamlit pyngrok
!pip install colpali-engine==0.1.1

from pyngrok import ngrok, conf
import getpass

ngrok_token = getpass.getpass("Enter your ngrok auth token: ")
ngrok.set_auth_token(ngrok_token)

streamlit_app_code = '''
import streamlit as st
import torch
import fitz  # PyMuPDF
from PIL import Image
from torch.utils.data import DataLoader
import anthropic
import base64
import io
from transformers import AutoProcessor
from colpali_engine.models.paligemma_colbert_architecture import ColPali
from colpali_engine.utils.colpali_processing_utils import process_images, process_queries
import os
import numpy as np

# Simple custom evaluator optimized for high GPU memory
class SimpleEvaluator:
    def __init__(self, is_multi_vector=True, device="cuda"):
        self.is_multi_vector = is_multi_vector
        self.device = device

    def evaluate(self, query_embeddings, doc_embeddings):
        """ColPali multi-vector similarity evaluation - GPU optimized"""
        # All operations on GPU, convert to float32 to avoid BFloat16 issues
        query_emb = torch.stack(query_embeddings).float().to(self.device)
        doc_emb = torch.stack(doc_embeddings).float().to(self.device)

        # ColPali multi-vector retrieval on GPU
        with torch.no_grad():
            scores = []
            for q_idx in range(query_emb.shape[0]):
                q = query_emb[q_idx]  # [seq_len, dim]

                # Normalize query
                q_norm = torch.nn.functional.normalize(q, p=2, dim=-1)

                # Batch process all documents at once
                d_norm = torch.nn.functional.normalize(doc_emb, p=2, dim=-1)

                # Compute similarity: [num_docs, query_seq_len, doc_seq_len]
                sim_matrix = torch.matmul(
                    q_norm.unsqueeze(0),  # [1, query_seq_len, dim]
                    d_norm.transpose(-2, -1)  # [num_docs, dim, doc_seq_len]
                )

                # Max pooling: [num_docs, query_seq_len]
                max_sim_per_query_token = torch.max(sim_matrix, dim=-1)[0]

                # Average over query tokens: [num_docs]
                doc_scores = torch.mean(max_sim_per_query_token, dim=-1)

                # Convert to float32 before numpy conversion
                scores.append(doc_scores.float().cpu().numpy())

        return np.array(scores)

class ColPaliRAGChain:
    def __init__(self, anthropic_api_key, use_gpu_storage=True):
        self.anthropic_client = anthropic.Anthropic(api_key=anthropic_api_key)
        self.model = None
        self.processor = None
        self.retriever_evaluator = None
        self.document_embeddings = []
        self.document_metadata = []
        self.is_initialized = False
        self.use_gpu_storage = use_gpu_storage
        self.device = "cuda"

    def initialize_colpali(self):
        """Initialize ColPali model"""
        if self.is_initialized:
            return True

        with st.spinner("Loading ColPali model..."):
            model_name = "vidore/colpali"
            self.model = ColPali.from_pretrained(
                "vidore/colpaligemma-3b-mix-448-base",
                torch_dtype=torch.bfloat16,
                device_map="cuda"
            ).eval()
            self.model.load_adapter(model_name)
            self.processor = AutoProcessor.from_pretrained(model_name)
            self.retriever_evaluator = SimpleEvaluator(is_multi_vector=True, device=self.device)
            self.is_initialized = True
        return True

    def pdf_to_patches(self, pdf_file, patch_size=(448, 448)):
        """Convert PDF file to patches"""
        pdf_file.seek(0)
        pdf_bytes = pdf_file.read()
        doc = fitz.open(stream=pdf_bytes, filetype="pdf")

        all_patches = []
        patch_metadata = []

        progress_bar = st.progress(0)
        total_pages = len(doc)

        for page_num in range(total_pages):
            page = doc[page_num]

            # Convert page to image
            mat = fitz.Matrix(2, 2)
            pix = page.get_pixmap(matrix=mat)
            img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)

            # Create exactly 1024 patches (32x32 grid)
            patches = self.create_patches(img, patch_size)

            for i, patch in enumerate(patches):
                all_patches.append(patch)
                patch_metadata.append({
                    'page_num': page_num,
                    'patch_id': i,
                    'original_image': img,
                    'grid_row': i // 32,
                    'grid_col': i % 32
                })

            progress_bar.progress((page_num + 1) / total_pages)

        doc.close()
        return all_patches, patch_metadata

    def create_patches(self, image, patch_size=(448, 448), grid_size=32):
        """Create exactly 1024 patches (32x32 grid) from image"""
        patches = []
        width, height = image.size

        # Calculate patch dimensions based on image size
        patch_width = width // grid_size
        patch_height = height // grid_size

        for row in range(grid_size):
            for col in range(grid_size):
                # Calculate patch coordinates
                left = col * patch_width
                top = row * patch_height
                right = min(left + patch_width, width)
                bottom = min(top + patch_height, height)

                # Extract patch
                patch = image.crop((left, top, right, bottom))

                # Resize to standard patch size for ColPali
                patch_resized = patch.resize(patch_size, Image.Resampling.LANCZOS)
                patches.append(patch_resized)

        return patches

    def embed_documents(self, patches, patch_metadata):
        """Embed document patches using ColPali"""
        if not self.is_initialized:
            self.initialize_colpali()

        with st.spinner("Creating embeddings..."):
            dataloader = DataLoader(
                patches,
                batch_size=4,
                shuffle=False,
                collate_fn=lambda x: process_images(self.processor, x),
            )

            embeddings = []
            progress_bar = st.progress(0)
            total_batches = len(dataloader)

            for i, batch_doc in enumerate(dataloader):
                with torch.no_grad():
                    batch_doc = {k: v.to(self.model.device) for k, v in batch_doc.items()}
                    embeddings_doc = self.model(**batch_doc)

                # Store on GPU
                if self.use_gpu_storage:
                    embeddings.extend(list(torch.unbind(embeddings_doc)))  # Keep on GPU
                else:
                    embeddings.extend(list(torch.unbind(embeddings_doc.to("cpu"))))

                progress_bar.progress((i + 1) / total_batches)

            self.document_embeddings = embeddings
            self.document_metadata = patch_metadata

            storage_location = "GPU" if self.use_gpu_storage else "CPU"
            st.success(f"Created {len(embeddings)} embeddings stored in {storage_location} memory")

        return embeddings

    def embed_query(self, query_text):
        """Embed a single query"""
        if not self.is_initialized:
            self.initialize_colpali()

        dataloader = DataLoader(
            [query_text],
            batch_size=1,
            shuffle=False,
            collate_fn=lambda x: process_queries(
                self.processor,
                x,
                Image.new("RGB", (448, 448), (255, 255, 255))
            ),
        )

        for batch_query in dataloader:
            with torch.no_grad():
                batch_query = {k: v.to(self.model.device) for k, v in batch_query.items()}
                embeddings_query = self.model(**batch_query)
            return list(torch.unbind(embeddings_query.to("cpu")))[0]

    def retrieve_diverse_pages(self, query_text, top_k_pages=3):
        """Retrieve top patches ensuring page diversity"""
        query_embedding = self.embed_query(query_text)
        scores = self.retriever_evaluator.evaluate([query_embedding], self.document_embeddings)

        # Get all patch scores with metadata
        all_results = []
        for idx, score in enumerate(scores[0]):
            all_results.append({
                'patch_idx': int(idx),
                'score': float(score),
                'metadata': self.document_metadata[idx]
            })

        # Sort by score (best first)
        all_results.sort(key=lambda x: x['score'], reverse=True)

        # Select top patches ensuring page diversity
        selected_pages = {}
        page_results = []

        for result in all_results:
            page_num = result['metadata']['page_num']

            # Add this page if we haven't seen it yet
            if page_num not in selected_pages:
                selected_pages[page_num] = result
                page_results.append(result)

                # Stop when we have enough diverse pages
                if len(page_results) >= top_k_pages:
                    break
            else:
                # Update if this patch has a better score for this page
                if result['score'] > selected_pages[page_num]['score']:
                    selected_pages[page_num] = result
                    # Update in page_results too
                    for i, page_result in enumerate(page_results):
                        if page_result['metadata']['page_num'] == page_num:
                            page_results[i] = result
                            break

        return page_results

    def ask_claude_multiple(self, images, query, page_numbers):
        """Send multiple images + query to Claude"""
        content = []

        # Add all images
        for i, image in enumerate(images):
            buffer = io.BytesIO()
            if image.mode != 'RGB':
                image = image.convert('RGB')
            image.save(buffer, format='PNG')
            image_b64 = base64.b64encode(buffer.getvalue()).decode()

            content.append({
                "type": "image",
                "source": {
                    "type": "base64",
                    "media_type": "image/png",
                    "data": image_b64
                }
            })

        # Add text query
        page_list = ", ".join([f"Page {p}" for p in page_numbers])
        content.append({
            "type": "text",
            "text": f"Based on these document images from {page_list}, please answer: {query} also Please review all the provided pages and give a comprehensive answer based on the most relevant information found."
        })

        response = self.anthropic_client.messages.create(
            model="claude-3-5-sonnet-20241022",
            max_tokens=2000,  # Increased for multiple page context
            messages=[{
                "role": "user",
                "content": content
            }]
        )

        return response.content[0].text

    def rag_query(self, query_text, top_k_pages=3):
        """Complete RAG pipeline: Retrieve + Generate with diverse pages"""
        # Retrieve top pages ensuring diversity
        retrieval_results = self.retrieve_diverse_pages(query_text, top_k_pages)

        # Extract images and page numbers (already diverse)
        images = [result['metadata']['original_image'] for result in retrieval_results]
        page_numbers = [result['metadata']['page_num'] + 1 for result in retrieval_results]
        scores = [result['score'] for result in retrieval_results]

        # Generate answer with Claude using multiple pages
        answer = self.ask_claude_multiple(images, query_text, page_numbers)

        return {
            'query': query_text,
            'answer': answer,
            'source_pages': page_numbers,
            'confidence_scores': scores,
            'num_pages_used': len(images)
        }

def main():
    st.set_page_config(
        page_title="ColPali + Claude RAG",
        page_icon="📄",
        layout="wide"
    )

    st.title("ColPali + Claude RAG System")
    st.markdown("Upload a PDF and ask questions about its content!")

    # Sidebar for API key
    with st.sidebar:
        st.header("Configuration")
        api_key = st.text_input(
            "Anthropic API Key",
            type="password",
            help="Enter your Anthropic API key"
        )

        # GPU storage option
        use_gpu_storage = st.checkbox(
            "Store embeddings in GPU memory",
            value=True,
            help="Faster retrieval with high GPU memory"
        )

        # Retrieval configuration
        st.subheader("Retrieval Settings")
        top_k_pages = st.slider(
            "Number of pages to retrieve",
            min_value=1,
            max_value=10,
            value=3,
            help="More pages = better accuracy but slower processing"
        )

        if not api_key:
            st.warning("Please enter your Anthropic API key to continue.")
            st.stop()

    # Initialize RAG chain
    if 'rag_chain' not in st.session_state:
        st.session_state.rag_chain = ColPaliRAGChain(api_key, use_gpu_storage)

    # Main interface
    col1, col2 = st.columns([1, 1])

    with col1:
        st.header("Document Upload")

        uploaded_file = st.file_uploader(
            "Choose a PDF file",
            type="pdf",
            help="Upload a PDF document to analyze"
        )

        if uploaded_file is not None:
            if st.button("Process Document", type="primary"):
                # Process PDF
                patches, metadata = st.session_state.rag_chain.pdf_to_patches(uploaded_file)
                st.success(f"Extracted {len(patches)} patches from {len(set(m['page_num'] for m in metadata))} pages")

                # Embed documents
                embeddings = st.session_state.rag_chain.embed_documents(patches, metadata)
                st.success(f"Created {len(embeddings)} embeddings")
                st.session_state.document_processed = True

    with col2:
        st.header("Ask Questions")

        if 'document_processed' not in st.session_state:
            st.info("Please upload and process a document first")
        else:
            query = st.text_input(
                "Enter your question:",
                placeholder="What is the main topic of this document?"
            )

            if st.button("Get Answer", type="primary") and query:
                with st.spinner("Searching and generating answer..."):
                    result = st.session_state.rag_chain.rag_query(query, top_k_pages)

                # Display results
                st.success("Answer generated!")

                st.subheader("Results")
                st.write(f"**Question:** {result['query']}")
                st.write(f"**Answer:** {result['answer']}")

                # Metadata
                with st.expander("Source Information"):
                    st.write(f"**Source Pages:** {', '.join(map(str, result['source_pages']))}")
                    st.write(f"**Pages Used:** {result['num_pages_used']}")
                    st.write(f"**Confidence Scores:** {[f'{score:.4f}' for score in result['confidence_scores']]}")

                    # Show individual page scores
                    for i, (page, score) in enumerate(zip(result['source_pages'], result['confidence_scores'])):
                        st.write(f"  - Page {page}: {score:.4f}")

                # Store in history
                if 'query_history' not in st.session_state:
                    st.session_state.query_history = []
                st.session_state.query_history.append(result)

            # Query history
            if 'query_history' in st.session_state and st.session_state.query_history:
                st.subheader("Recent Questions")
                for i, hist_result in enumerate(reversed(st.session_state.query_history[-5:])):
                    with st.expander(f"Q: {hist_result['query'][:50]}..."):
                        st.write(f"**A:** {hist_result['answer']}")
                        if 'source_pages' in hist_result:  # New multi-page format
                            st.write(f"**Source:** Pages {', '.join(map(str, hist_result['source_pages']))}")
                        else:  # Old single-page format (backwards compatibility)
                            st.write(f"**Source:** Page {hist_result.get('source_page', 'Unknown')}")

if __name__ == "__main__":
    main()
'''

with open('/content/colpali_app.py', 'w') as f:
    f.write(streamlit_app_code)

print("Streamlit app created successfully!")

import subprocess
import threading
import time

def run_streamlit():
    """Run Streamlit app"""
    subprocess.run(['streamlit', 'run', '/content/colpali_app.py', '--server.port=8501'])

streamlit_thread = threading.Thread(target=run_streamlit)
streamlit_thread.daemon = True
streamlit_thread.start()

time.sleep(5)

public_url = ngrok.connect(8501)
print(f"\\n{'='*50}")
print(f"Public URL: {public_url}")
print(f"{'='*50}")

while True:
    time.sleep(60)

Enter your ngrok auth token: ··········
Streamlit app created successfully!
Streamlit app is running!
Public URL: NgrokTunnel: "https://6219-34-67-22-190.ngrok-free.app" -> "http://localhost:8501"
\nCopy the URL above and open it in your browser to access the app.
The app will remain active as long as this Colab session is running.


KeyboardInterrupt: 