# MM-RAG: Multimodal Retrieval-Augmented Generation Demo

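This notebook walks through the MM-RAG retrieval pipeline for multimodal question answering. It indexes a small text knowledge base, retrieves documents for an image-plus-text query, and fuses the image, query, and document embeddings. Text generation is not included.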

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/abhijoysarkar/mmrag/blob/main/examples/MM_RAG_Demo.ipynb)

## Installation

In [None]:
# Install MM-RAG
!pip install -q git+https://github.com/abhijoysarkar/mmrag.git

## Setup and Imports

In [None]:
import torch
import numpy as np
from PIL import Image, ImageDraw
import requests
from io import BytesIO

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")

## Step 1: Create Sample Knowledge Base

In [None]:
# Sample documents about famous landmarks
documents = [
    "The Eiffel Tower is a wrought-iron lattice tower in Paris, France. Built in 1889, it stands 330 meters tall.",
    "The Statue of Liberty is located on Liberty Island in New York Harbor. It was a gift from France in 1886.",
    "The Great Wall of China stretches over 13,000 miles across northern China. Most was built during the Ming Dynasty.",
    "The Taj Mahal is an ivory-white marble mausoleum in Agra, India. Shah Jahan began building it in 1632.",
    "The Colosseum is an ancient amphitheater in Rome, Italy. Completed in 80 AD, it could hold about 50,000 spectators.",
    "Machu Picchu is a 15th-century Inca citadel in Peru. It sits at 2,430 meters above sea level.",
    "The Golden Gate Bridge spans the Golden Gate strait in San Francisco. Opened in 1937, it's painted International Orange.",
    "The Sydney Opera House is a performing arts center in Sydney, Australia. Its unique sail-like design was completed in 1973.",
]

print(f"Knowledge base: {len(documents)} documents")
for i, doc in enumerate(documents[:3], 1):
    print(f"{i}. {doc[:80]}...")

## Step 2: Initialize MM-RAG Components

In [None]:
from mmrag.models.vision_encoder import VisionEncoder
from mmrag.models.retriever import FaissRetriever
from mmrag.models.fusion import CrossModalFusionBlock

# Initialize components. device="auto" is expected to pick CUDA when available and fall back to CPU.
print("Loading vision encoder...")
vision_encoder = VisionEncoder(
    model_name="openai/clip-vit-base-patch16",
    device="auto"
)

print("Loading retriever...")
retriever = FaissRetriever(
    dim=512,  # Index dimension (CLIP ViT-B/16 embedding size); 384-d MiniLM text vectors are zero-padded to fit
    text_encoder_name='sentence-transformers/all-MiniLM-L6-v2',
    device='auto'
)

print("Loading fusion module...")
fusion = CrossModalFusionBlock(
    embed_dim=512,
    num_heads=8,
    dropout=0.1,
    fusion_type='attention'
)

print("✓ All components loaded!")

## Step 3: Index Documents

In [None]:
# Encode and index documents
print("Encoding documents...")
doc_embeddings_384 = retriever.encode_text(documents)

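# all-MiniLM-L6-v2 produces 384-d vectors; zero-pad them to the index dimension (512).
# Padding makes the shapes match but does NOT place the text vectors in CLIP's image
# embedding space, so comparing a CLIP image vector against them is not semantically aligned.
doc_embeddings = np.pad(
    doc_embeddings_384,
    ((0, 0), (0, retriever.dim - doc_embeddings_384.shape[1])),
    mode='constant'
)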

doc_ids = [f"doc_{i}" for i in range(len(documents))]
retriever.add(doc_embeddings, doc_ids, documents)

print(f"✓ Indexed {len(documents)} documents")
print(f"  Index size: {retriever.index.ntotal} vectors")
print(f"  Embedding dimension: {retriever.dim}")
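Zero-padding only appends zeros, so dot products and cosine similarities between two padded MiniLM vectors are unchanged. It does not, however, make them comparable with CLIP image embeddings, because the first 384 coordinates mean different things in the two models. A small NumPy check of both points, using random vectors as stand-ins for real embeddings (`pad_to_dim` and `cosine` are illustrative helpers, not part of MM-RAG):

```python
import numpy as np

def pad_to_dim(x: np.ndarray, dim: int) -> np.ndarray:
    """Right-pad the last axis of a 2-D array with zeros up to `dim` columns."""
    return np.pad(x, ((0, 0), (0, dim - x.shape[1])), mode="constant")

def cosine(a: np.ndarray, b: np.ndarray) -> float:
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

rng = np.random.default_rng(0)
text_vecs = rng.standard_normal((2, 384)).astype(np.float32)  # stand-ins for MiniLM outputs
padded = pad_to_dim(text_vecs, 512)

# Text-to-text similarity is preserved: the 128 extra coordinates are all zero.
before = cosine(text_vecs[0], text_vecs[1])
after = cosine(padded[0], padded[1])
print(f"cosine before padding: {before:.6f}, after: {after:.6f}")

# A CLIP-style image vector interacts only with the first 384 coordinates, which
# come from a different model, so this score is not meaningful unless the spaces are aligned.
image_vec = rng.standard_normal(512).astype(np.float32)  # stand-in for a CLIP embedding
print(f"image-text score: {cosine(image_vec, padded[0]):.6f}")
print(f"padded tail is all zeros: {bool(np.all(padded[:, 384:] == 0))}")
```

CLIP's own text encoder shares an embedding space with its image encoder, which is why CLIP-encoded documents are the usual choice when images are matched directly against text.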

## Step 4: Create Test Image

In [None]:
from IPython.display import display  # Jupyter/Colab already provide display(); imported explicitly for clarity


def create_tower_image():
    """Draw a simple synthetic tower (base, three stacked tiers, antenna) on a 512x512 canvas."""
    img = Image.new('RGB', (512, 512), color='lightblue')
    draw = ImageDraw.Draw(img)
    
    # Base platform
    draw.rectangle([(200, 450), (312, 470)], fill='darkgray', outline='black', width=2)
    # Three stacked triangular tiers, narrowing toward the top
    points = [(180, 450), (256, 300), (332, 450)]
    draw.polygon(points, fill='gray', outline='black')
    points = [(220, 300), (256, 200), (292, 300)]
    draw.polygon(points, fill='darkgray', outline='black')
    points = [(240, 200), (256, 100), (272, 200)]
    draw.polygon(points, fill='gray', outline='black')
    # Antenna on top
    draw.line([(256, 100), (256, 50)], fill='black', width=3)
    
    return img

test_image = create_tower_image()
display(test_image)
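`VisionEncoder.encode` presumably handles image preprocessing internally; the notebook does not show those steps. For reference, the standard CLIP recipe resizes the shorter side to 224 pixels with bicubic resampling, center-crops to 224x224, scales pixels to [0, 1], and normalizes each channel with CLIP's mean and standard deviation. A PIL/NumPy sketch of that recipe (`clip_preprocess` is a hypothetical helper, not part of MM-RAG):

```python
import numpy as np
from PIL import Image

# Per-channel statistics used by OpenAI's CLIP image preprocessing.
CLIP_MEAN = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)
CLIP_STD = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)

def clip_preprocess(img: Image.Image, size: int = 224) -> np.ndarray:
    """Return a (3, size, size) float32 array in CLIP's input format."""
    img = img.convert("RGB")
    # Resize so the shorter side equals `size`, keeping the aspect ratio.
    w, h = img.size
    scale = size / min(w, h)
    new_w, new_h = max(size, round(w * scale)), max(size, round(h * scale))
    img = img.resize((new_w, new_h), Image.BICUBIC)
    # Center-crop to size x size.
    left, top = (new_w - size) // 2, (new_h - size) // 2
    img = img.crop((left, top, left + size, top + size))
    # Scale to [0, 1], normalize per channel, then move channels first (HWC -> CHW).
    arr = np.asarray(img, dtype=np.float32) / 255.0
    arr = (arr - CLIP_MEAN) / CLIP_STD
    return arr.transpose(2, 0, 1)

demo = Image.new("RGB", (512, 384), color="lightblue")
pixels = clip_preprocess(demo)
print(pixels.shape)  # (3, 224, 224)
```

Applied to the 512x512 test image, the resize step simply scales it down to 224x224 and the crop is a no-op.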

## Step 5: Run Multimodal Retrieval

In [None]:
# Query
query = "What famous tower structure is this and where is it located?"

print(f"Query: {query}\n")

# Encode image
print("Encoding image...")
image_emb = vision_encoder.encode(test_image)
print(f"✓ Image embedding shape: {image_emb.shape}")

# Hybrid search (text + image)
print("\nPerforming hybrid search...")
results = retriever.hybrid_search(
    query,
    image_emb.cpu().numpy(),
    alpha=0.5,  # Weight between text and image similarity (0.5 = equal weight)
    top_k=3
)

print(f"✓ Retrieved {len(results)} documents\n")
print("=" * 80)
print("TOP RETRIEVED DOCUMENTS:")
print("=" * 80)
for i, (doc_id, score, text) in enumerate(results, 1):
    print(f"\n{i}. [{doc_id}] Score: {score:.4f}")
    print(f"   {text}")
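The notebook does not show how `hybrid_search` combines the two modalities. A common approach, and only an assumption here, is a weighted sum of cosine similarities with `alpha` weighting the text side. A NumPy sketch of that idea, scoring both query vectors against the same document matrix as this notebook's single index does (`hybrid_scores` and `top_k_indices` are hypothetical names):

```python
import numpy as np

def l2_normalize(x: np.ndarray) -> np.ndarray:
    """Scale vectors along the last axis to unit length."""
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

def hybrid_scores(text_q: np.ndarray, image_q: np.ndarray, doc_embs: np.ndarray, alpha: float = 0.5) -> np.ndarray:
    """Score each document as alpha * cos(text_q, doc) + (1 - alpha) * cos(image_q, doc)."""
    docs = l2_normalize(doc_embs)
    text_sim = docs @ l2_normalize(text_q)
    image_sim = docs @ l2_normalize(image_q)
    return alpha * text_sim + (1.0 - alpha) * image_sim

def top_k_indices(scores: np.ndarray, k: int) -> np.ndarray:
    """Indices of the k highest scores, best first."""
    return np.argsort(-scores)[:k]

# Toy 3-d example: the text query matches doc 0, the image query matches doc 1.
docs = np.eye(3)
text_q = np.array([1.0, 0.0, 0.0])
image_q = np.array([0.0, 1.0, 0.0])

for alpha in (0.8, 0.2):
    scores = hybrid_scores(text_q, image_q, docs, alpha=alpha)
    print(alpha, top_k_indices(scores, k=2))
```

Under this convention, `alpha=1.0` ranks by the text query alone and `alpha=0.0` by the image alone; the library's actual convention may differ.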

## Step 6: Cross-Modal Fusion

In [None]:
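# Prepare embeddings for fusion. The fusion block expects 512-d inputs, so the
# 384-d MiniLM query embedding is zero-padded the same way as the documents.
device = vision_encoder.device
query_emb_raw = retriever.encode_text([query])  # (1, 384)
text_emb = np.pad(
    query_emb_raw,
    ((0, 0), (0, 512 - query_emb_raw.shape[1])),
    mode='constant'
)
text_emb = torch.tensor(text_emb, dtype=torch.float32, device=device)

# Get embeddings for the retrieved documents
doc_texts = [r[2] for r in results]
doc_embs_raw = retriever.encode_text(doc_texts)
doc_embs = np.pad(
    doc_embs_raw,
    ((0, 0), (0, 512 - doc_embs_raw.shape[1])),
    mode='constant'
)
doc_embs = torch.tensor(doc_embs, dtype=torch.float32, device=device)

# Reshape to (batch, seq_len, dim) for the fusion module
image_emb = image_emb.float()
if image_emb.dim() == 2:
    image_emb = image_emb.unsqueeze(1)  # (1, 512) -> (1, 1, 512)
if text_emb.dim() == 2:
    text_emb = text_emb.unsqueeze(1)  # (1, 512) -> (1, 1, 512)
if doc_embs.dim() == 2:
    doc_embs = doc_embs.unsqueeze(0)  # (K, 512) -> (1, K, 512)

print(f"Image embedding: {image_emb.shape}")
print(f"Text embedding: {text_emb.shape}")
print(f"Document embeddings: {doc_embs.shape}")

# Put the fusion block on the same device as its inputs and switch to eval mode,
# so dropout is disabled during inference.
fusion = fusion.to(device).eval()

# Fuse modalities
print("\nFusing modalities with attention...")
with torch.no_grad():
    fused_emb = fusion(image_emb, text_emb, doc_embs)

print(f"✓ Fused embedding shape: {fused_emb.shape}")
print(f"✓ Fused embedding norm: {fused_emb.norm().item():.4f}")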
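`CrossModalFusionBlock` is constructed fresh in Step 2 with no checkpoint loaded, so its output here is best read as a shape check. To illustrate the core idea behind attention-based fusion, here is a single-head scaled dot-product attention sketch in NumPy: one query token per modality (image and text) attends over the retrieved document embeddings. It has no learned projections and is not MM-RAG's implementation; `attention_fuse` is a hypothetical name.

```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    e = np.exp(x - x.max(axis=-1, keepdims=True))  # subtract the max for numerical stability
    return e / e.sum(axis=-1, keepdims=True)

def attention_fuse(image_emb: np.ndarray, text_emb: np.ndarray, doc_embs: np.ndarray):
    """
    image_emb, text_emb: (d,) query-side embeddings
    doc_embs: (k, d) retrieved document embeddings
    Returns the fused (d,) vector and the (2, k) attention weights.
    """
    d = doc_embs.shape[-1]
    query = np.stack([image_emb, text_emb])   # (2, d): one query token per modality
    scores = query @ doc_embs.T / np.sqrt(d)  # (2, k) scaled dot products
    weights = softmax(scores)                 # each row sums to 1 over the documents
    attended = weights @ doc_embs             # (2, d) document context per query token
    fused = (query + attended).mean(axis=0)   # residual add, then average the two tokens
    return fused, weights

rng = np.random.default_rng(0)
d, k = 512, 3
fused, weights = attention_fuse(rng.standard_normal(d), rng.standard_normal(d), rng.standard_normal((k, d)))
print(fused.shape, weights.shape)  # (512,) (2, 3)
```

A trained fusion block would add learned query/key/value projections and typically multiple heads (the notebook configures `num_heads=8`).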

## Results Summary

In [None]:
print("=" * 80)
print("MM-RAG PIPELINE SUMMARY")
print("=" * 80)
print(f"\n📊 Query: {query}")
print("\n🖼️  Image: Tower structure (512x512)")
print(f"\n📚 Knowledge Base: {len(documents)} documents")
print(f"\n🔍 Retrieved Documents: {len(results)}")
print(f"\n🎯 Top Match: {results[0][2][:100]}...")
print(f"   Relevance Score: {results[0][1]:.4f}")
print("\n🔀 Fusion Type: Attention-based")
print(f"   Output Dimension: {fused_emb.shape[-1]}")
print("\n✅ Pipeline executed successfully!")
print("\nℹ️  Note: This demo shows retrieval + fusion only.")
print("   For text generation, connect a language model (LLaMA, GPT, etc.)")
print("=" * 80)
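As the summary notes, generation needs a language model. One simple way to hand retrieved passages to a text-only LLM is to pack them into a prompt ahead of the question. A minimal sketch, assuming `results` holds `(doc_id, score, text)` tuples as in Step 5; `build_rag_prompt` and its template are hypothetical, not part of MM-RAG:

```python
from typing import List, Tuple

def build_rag_prompt(question: str, results: List[Tuple[str, float, str]], image_caption: str = "") -> str:
    """Format retrieved passages as numbered context followed by the question."""
    lines = ["Answer the question using only the context below.", "", "Context:"]
    for i, (doc_id, _score, text) in enumerate(results, 1):
        lines.append(f"[{i}] ({doc_id}) {text}")
    if image_caption:
        # A text-only model cannot see the image, so describe it in words.
        lines += ["", f"Image description: {image_caption}"]
    lines += ["", f"Question: {question}", "Answer:"]
    return "\n".join(lines)

# Placeholder scores for illustration, not real retrieval output.
example_results = [
    ("doc_0", 0.9, "The Eiffel Tower is a wrought-iron lattice tower in Paris, France."),
    ("doc_6", 0.5, "The Golden Gate Bridge spans the Golden Gate strait in San Francisco."),
]
prompt = build_rag_prompt("Which tower is in Paris?", example_results)
print(prompt)
```

A vision-language model could take the image directly instead of a caption; the text-only route shown here keeps the sketch model-agnostic.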

## Try Your Own Image!

In [None]:
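# Upload your own image (Colab only) or load one from a URL
from google.colab import files

custom_image = None

# Option 1: Upload a file from your machine
uploaded = files.upload()
if uploaded:
    filename = next(iter(uploaded))
    custom_image = Image.open(filename).convert("RGB")  # ensure 3-channel RGB input

# Option 2: Load from a URL (used only if nothing was uploaded)
if custom_image is None:
    url = input("Enter an image URL (or leave blank to skip): ").strip()
    if url:
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        custom_image = Image.open(BytesIO(response.content)).convert("RGB")

if custom_image is not None:
    display(custom_image)

    # Run hybrid retrieval with your own query
    custom_query = input("Enter your query: ")
    custom_emb = vision_encoder.encode(custom_image)
    custom_results = retriever.hybrid_search(custom_query, custom_emb.cpu().numpy(), top_k=3)

    print("\nResults:")
    for i, (_, score, text) in enumerate(custom_results, 1):
        print(f"{i}. [{score:.4f}] {text}")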

## Next Steps

- **Add Generation**: Connect a language model for full MM-RAG
- **Fine-tune**: Train LoRA adapters on your domain
- **Deploy**: Use the FastAPI server for production
- **Extend**: Add more fusion strategies or encoders
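For the fine-tuning step, LoRA keeps a pretrained weight matrix W frozen and learns a low-rank update, so a layer computes W x + (alpha / r) B A x, where A is r x d_in, B is d_out x r, and r is much smaller than either dimension. B starts at zero, so training begins from the pretrained behavior. A NumPy sketch of one such layer; `LoRALinear` is illustrative and not MM-RAG's training code:

```python
import numpy as np

class LoRALinear:
    """y = W x + (alpha / r) * B (A x), with W frozen and only A, B trainable."""

    def __init__(self, W: np.ndarray, r: int = 8, alpha: float = 16.0, seed: int = 0):
        d_out, d_in = W.shape
        rng = np.random.default_rng(seed)
        self.W = W                                      # frozen pretrained weight, (d_out, d_in)
        self.A = rng.standard_normal((r, d_in)) * 0.01  # down-projection, small random init
        self.B = np.zeros((d_out, r))                   # up-projection, zero init: no change at start
        self.scale = alpha / r

    def __call__(self, x: np.ndarray) -> np.ndarray:
        # x: (batch, d_in) -> (batch, d_out)
        return x @ self.W.T + self.scale * (x @ self.A.T) @ self.B.T

    def trainable_params(self) -> int:
        return self.A.size + self.B.size

rng = np.random.default_rng(1)
W = rng.standard_normal((512, 512))
layer = LoRALinear(W, r=8)
x = rng.standard_normal((4, 512))
print(np.allclose(layer(x), x @ W.T))    # True: B is zero, so output equals the frozen layer
print(layer.trainable_params(), W.size)  # 8192 262144
```

For a 512x512 layer with r=8, the adapter trains 8,192 parameters instead of 262,144.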

### Resources
- [GitHub Repository](https://github.com/abhijoysarkar/mmrag)
- [Documentation](https://github.com/abhijoysarkar/mmrag#readme)
- [Paper: Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks (Lewis et al., 2020)](https://arxiv.org/abs/2005.11401)