# Lab 4.1.3 Solutions: Multimodal RAG

This notebook contains solutions to the exercises in the Multimodal RAG notebook.

---

## Challenge Solution: Hybrid Search System

The challenge was to create a system that accepts queries with both text AND an example image.

### Approach
We use a weighted combination of CLIP embeddings:
1. **Get both embeddings**: Compute CLIP embeddings for both text query and image query
2. **Weighted average**: `combined = text_weight * text_emb + (1-text_weight) * image_emb`
3. **Re-normalize**: Normalize the combined embedding for cosine similarity search
4. **Search**: Use the combined embedding to query ChromaDB

### Why This Works
CLIP embeds both images and text into the same vector space where similar concepts are close together. By combining embeddings, we get:
- `text_weight=1.0`: Pure text search (finds images matching the description)
- `text_weight=0.0`: Pure image search (finds visually similar images)
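- `text_weight=0.5`: Balanced search (aims to find images matching both criteria). Note that in CLIP, text–image similarities are typically much lower than image–image similarities, so equal weights do not guarantee equal influence; tune the weight empirically.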

### Key Insight
The weight parameter lets users control the search behavior dynamically, making the system flexible for different use cases.

In [None]:
import torch
import numpy as np
import gc
from PIL import Image
from typing import List, Dict, Optional
import chromadb
from chromadb.config import Settings
import base64
from io import BytesIO

def clear_gpu_memory():
    """Run the garbage collector, then release cached CUDA memory if a GPU is present."""
    gc.collect()
    cuda_present = torch.cuda.is_available()
    if not cuda_present:
        return
    torch.cuda.empty_cache()

In [None]:
# Load CLIP
from transformers import CLIPProcessor, CLIPModel

# Prefer the GPU but fall back to CPU so the notebook still runs without CUDA.
# bfloat16 halves weight memory on the GPU; on CPU keep float32 as the safe default.
CLIP_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
CLIP_DTYPE = torch.bfloat16 if CLIP_DEVICE == "cuda" else torch.float32

print("Loading CLIP...")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
clip_model = CLIPModel.from_pretrained(
    "openai/clip-vit-large-patch14",
    torch_dtype=CLIP_DTYPE
).to(CLIP_DEVICE)
clip_model.eval()  # inference only: switch layers to eval-mode behavior
print("Loaded!")

In [None]:
def get_image_embedding(image: Image.Image) -> np.ndarray:
    """Get the L2-normalized CLIP embedding for a single image.

    Args:
        image: PIL image to embed.

    Returns:
        1-D float32 NumPy array holding the unit-norm CLIP image feature.
    """
    inputs = clip_processor(images=image, return_tensors="pt")
    inputs = {k: v.to(clip_model.device) for k, v in inputs.items()}

    with torch.inference_mode():
        features = clip_model.get_image_features(**inputs)
        # Upcast before normalizing. The model may run in bfloat16, which loses
        # precision in the norm. NumPy also has no bfloat16 dtype, so calling
        # `.numpy()` on a bfloat16 tensor raises TypeError.
        features = features.float()
        features = features / features.norm(dim=-1, keepdim=True)

    return features.cpu().numpy()[0]


def get_text_embedding(text: str) -> np.ndarray:
    """Get the L2-normalized CLIP embedding for a text string.

    The text is padded/truncated by the CLIP processor to the tokenizer's
    maximum length before encoding.

    Args:
        text: Query or caption to embed.

    Returns:
        1-D float32 NumPy array holding the unit-norm CLIP text feature.
    """
    inputs = clip_processor(text=[text], return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(clip_model.device) for k, v in inputs.items()}

    with torch.inference_mode():
        features = clip_model.get_text_features(**inputs)
        # Upcast before normalizing. The model may run in bfloat16, which NumPy
        # cannot represent, so `.numpy()` on it would raise TypeError.
        features = features.float()
        features = features / features.norm(dim=-1, keepdim=True)

    return features.cpu().numpy()[0]


def image_to_base64(image: Image.Image) -> str:
    """Encode a PIL image as PNG bytes and return them as a base64 text string."""
    with BytesIO() as png_bytes:
        image.save(png_bytes, format="PNG")
        raw = png_bytes.getvalue()
    encoded = base64.b64encode(raw)
    return encoded.decode('utf-8')


def base64_to_image(b64: str) -> Image.Image:
    """Decode a base64 string back into a PIL Image (inverse of image_to_base64)."""
    raw_bytes = base64.b64decode(b64)
    # The buffer is handed to Image.open and must stay open: PIL reads lazily.
    return Image.open(BytesIO(raw_bytes))

In [None]:
def hybrid_search(
    text_query: str,
    image_query: Image.Image,
    collection,
    text_weight: float = 0.7,
    n_results: int = 5
) -> List[Dict]:
    """
    Search using both a text query and an image query with a weighted combination.

    The two CLIP embeddings are blended as
    ``text_weight * text_emb + (1 - text_weight) * image_emb``. The blend is
    re-normalized to unit length and used to query the collection.

    Args:
        text_query: Natural language query
        image_query: Image to match
        collection: ChromaDB collection. It is expected to use cosine distance,
            so that ``1 - distance`` is a cosine similarity.
        text_weight: Weight for the text query, in [0, 1].
            image_weight = 1 - text_weight.
        n_results: Number of results to return

    Returns:
        List of dicts with these keys:
        - 'id'
        - 'similarity' (1 - distance)
        - 'metadata' (stored metadata without the 'image_b64' field)
        - 'image' (decoded PIL image), only when 'image_b64' was stored

    Raises:
        ValueError: if text_weight is outside [0, 1], or if the blended
            embedding has zero norm and cannot be normalized.
    """
    # Weights outside [0, 1] extrapolate past either query and are meaningless here;
    # this comparison also rejects NaN.
    if not 0.0 <= text_weight <= 1.0:
        raise ValueError(f"text_weight must be in [0, 1], got {text_weight}")

    # Get embeddings
    text_emb = get_text_embedding(text_query)
    image_emb = get_image_embedding(image_query)

    # Weighted combination
    image_weight = 1.0 - text_weight
    combined_emb = text_weight * text_emb + image_weight * image_emb

    # Normalize the combined embedding. A zero vector would otherwise turn into
    # NaNs and yield a meaningless query.
    norm = np.linalg.norm(combined_emb)
    if norm == 0:
        raise ValueError("Combined query embedding has zero norm; cannot normalize")
    combined_emb = combined_emb / norm

    # Search
    results = collection.query(
        query_embeddings=[combined_emb.tolist()],
        n_results=n_results,
        include=['metadatas', 'distances']
    )

    # Chroma returns one list per query embedding; we sent exactly one.
    ids = results['ids'][0]
    metadatas = results['metadatas'][0]
    distances = results['distances'][0]

    formatted = []
    for doc_id, meta, distance in zip(ids, metadatas, distances):
        # Records stored without metadata may come back as None. Copy so that
        # pop() below does not mutate Chroma's result object.
        meta = dict(meta or {})
        img_b64 = meta.pop('image_b64', None)

        result = {
            'id': doc_id,
            'similarity': 1 - distance,
            'metadata': meta
        }

        if img_b64:
            result['image'] = base64_to_image(img_b64)

        formatted.append(result)

    return formatted

print("hybrid_search() function ready!")

In [None]:
# Create a test dataset
from PIL import ImageDraw

def create_shape(shape: str, color: str) -> Image.Image:
    """Draw one filled shape on a white 224x224 RGB canvas.

    Args:
        shape: One of 'circle', 'square', or 'triangle'.
        color: Fill color in any form PIL accepts (e.g. 'red', '#ff0000').

    Returns:
        The rendered RGB image.

    Raises:
        ValueError: if ``shape`` is not one of the supported names. Raising here
            avoids silently producing a blank canvas that would be indexed as data.
    """
    if shape not in ('circle', 'square', 'triangle'):
        raise ValueError(
            f"Unknown shape {shape!r}; expected 'circle', 'square' or 'triangle'"
        )

    img = Image.new('RGB', (224, 224), 'white')
    draw = ImageDraw.Draw(img)

    # All shapes share the same bounding box, so only shape and color vary.
    if shape == 'circle':
        draw.ellipse([50, 50, 174, 174], fill=color)
    elif shape == 'square':
        draw.rectangle([50, 50, 174, 174], fill=color)
    else:  # 'triangle'
        draw.polygon([(112, 50), (50, 174), (174, 174)], fill=color)

    return img

# Create dataset: one image per (shape, color) pair, shapes varying slowest
shapes = ['circle', 'square', 'triangle']
colors = ['red', 'blue', 'green', 'yellow', 'purple', 'orange']

dataset = [
    {'image': create_shape(s, c), 'shape': s, 'color': c}
    for s in shapes
    for c in colors
]

print(f"Created {len(dataset)} images")

In [None]:
# Index the dataset
client = chromadb.Client(Settings(anonymized_telemetry=False))

# Start from a fresh collection so re-running this cell does not hit duplicate IDs.
# On the first run the collection does not exist. The exception chromadb raises
# for that varies by version, so catch Exception. A bare `except:` would also
# swallow KeyboardInterrupt/SystemExit.
try:
    client.delete_collection("hybrid_demo")
except Exception:
    pass

collection = client.create_collection(
    name="hybrid_demo",
    metadata={"hnsw:space": "cosine"}  # cosine distance, so similarity = 1 - distance
)

print("Indexing images...")
# Embed everything first, then insert with a single batched `add` call
# instead of one round-trip per image.
index_ids = [f"img_{i}" for i in range(len(dataset))]
index_embeddings = [get_image_embedding(item['image']).tolist() for item in dataset]
index_metadatas = [
    {
        'shape': item['shape'],
        'color': item['color'],
        'image_b64': image_to_base64(item['image'])
    }
    for item in dataset
]
if index_ids:  # chromadb may reject an add() with empty lists
    collection.add(
        ids=index_ids,
        embeddings=index_embeddings,
        metadatas=index_metadatas
    )

print(f"Indexed {collection.count()} images")

In [None]:
# Test hybrid search!
import matplotlib.pyplot as plt

# Create a query image (a pink circle - not in our dataset)
query_image = create_shape('circle', 'pink')

# Text query for something red
text_query = "something red"

print("Query image: pink circle")
print(f"Text query: '{text_query}'")
print("\nThis should find red circles (matching both text 'red' and image 'circle')")

# Test with different weights
weights = [0.0, 0.3, 0.5, 0.7, 1.0]
N_RESULTS = 5  # column 0 shows the query; columns 1..N_RESULTS show the hits

fig, axes = plt.subplots(len(weights), N_RESULTS + 1, figsize=(12, 10), squeeze=False)

for row, w in enumerate(weights):
    results = hybrid_search(text_query, query_image, collection, text_weight=w, n_results=N_RESULTS)

    # Hide every panel up front, so slots left empty (fewer hits than requested)
    # render blank instead of as empty framed axes.
    for ax in axes[row]:
        ax.axis('off')

    # Put the weights in the title. axis('off') also hides y-labels, so a
    # set_ylabel call here would never be visible.
    axes[row, 0].imshow(query_image)
    axes[row, 0].set_title(f"Query\ntext={w:.1f} img={1 - w:.1f}", fontsize=8)

    for col, r in enumerate(results[:N_RESULTS], start=1):
        axes[row, col].imshow(r['image'])
        axes[row, col].set_title(f"{r['metadata']['color']}\n{r['metadata']['shape']}", fontsize=8)

plt.suptitle(f"Hybrid Search: '{text_query}' + pink circle image", fontsize=12, fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
# Analyze the results: expected behavior for each weight (compare with the figure above)
analysis_lines = [
    "\nAnalysis of weight effects:",
    "-" * 50,
    "\n- text_weight=0.0 (100% image): Finds circles of any color",
    "- text_weight=0.5 (balanced): Finds red + circles",
    "- text_weight=1.0 (100% text): Finds red shapes of any type",
    "\nThe red circle should rank highest with balanced weights!",
]
for line in analysis_lines:
    print(line)

In [None]:
# Cleanup: drop the notebook's references to the CLIP processor and model,
# then collect garbage and release cached GPU memory.
del clip_processor
del clip_model
clear_gpu_memory()
print("Solutions notebook complete!")