# Lab 4.1.3: Multimodal RAG - SOLUTIONS

This notebook contains the complete solution for the Visual Question Answering RAG challenge.

---

In [None]:
# Setup
import gc
import time
import hashlib
from io import BytesIO
from typing import Optional, List, Dict, Any
from dataclasses import dataclass

import requests
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import chromadb
from chromadb.config import Settings
from transformers import CLIPModel, CLIPProcessor, AutoProcessor, LlavaForConditionalGeneration

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Checkpoint ids are named once so a model and its processor always come
# from the same checkpoint.
CLIP_MODEL_ID = "openai/clip-vit-large-patch14"
LLAVA_MODEL_ID = "llava-hf/llava-1.5-7b-hf"

# Load CLIP
print("Loading CLIP...")
clip_model = CLIPModel.from_pretrained(CLIP_MODEL_ID).to(device)
clip_processor = CLIPProcessor.from_pretrained(CLIP_MODEL_ID)
clip_model.eval()

# Load LLaVA
# The weights are loaded in bfloat16; placement is delegated to device_map="auto",
# so this model is NOT moved with .to(device) like CLIP above.
print("Loading LLaVA...")
vlm_processor = AutoProcessor.from_pretrained(LLAVA_MODEL_ID)
vlm_model = LlavaForConditionalGeneration.from_pretrained(
    LLAVA_MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

print("✅ Models loaded!")

---

## Challenge Solution: Visual Question Answering RAG

In [None]:
class VisualQARAG:
    """
    A Visual Question Answering system that uses RAG to enhance answers.

    This system:
    1. Takes a question about an image
    2. Retrieves relevant context from a knowledge base
    3. Uses a VLM to answer the question with the retrieved context

    The embedding and generation helpers read the notebook-level globals
    ``clip_model``, ``clip_processor``, ``vlm_model``, ``vlm_processor`` and
    ``device``; those must exist before the helpers are called.
    """

    def __init__(self):
        """Create the Chroma client and open (or create) the knowledge collection."""
        # Initialize ChromaDB (telemetry off, cosine distance for the index)
        self.client = chromadb.Client(Settings(anonymized_telemetry=False))
        self.collection = self.client.get_or_create_collection(
            name="visual_qa_kb",
            metadata={"hnsw:space": "cosine"}
        )

    def _embed_image(self, image: Image.Image) -> np.ndarray:
        """
        Get CLIP embedding for image.

        Args:
            image: Image to embed

        Returns:
            Unit-length CLIP image feature vector for the single input image
        """
        # No `padding` argument here: padding is a tokenizer option and has
        # no meaning for a single image.
        inputs = clip_processor(images=image, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            features = clip_model.get_image_features(**inputs)
            # L2-normalize so text and image vectors are directly comparable
            features = features / features.norm(dim=-1, keepdim=True)

        # Batch of one -> return the single row
        return features.cpu().numpy()[0]

    def _embed_text(self, text: str) -> np.ndarray:
        """
        Get CLIP embedding for text.

        Args:
            text: Text to embed (truncated by the processor if too long)

        Returns:
            Unit-length CLIP text feature vector for the single input string
        """
        inputs = clip_processor(text=[text], return_tensors="pt", padding=True, truncation=True)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            features = clip_model.get_text_features(**inputs)
            # L2-normalize so text and image vectors are directly comparable
            features = features / features.norm(dim=-1, keepdim=True)

        # Batch of one -> return the single row
        return features.cpu().numpy()[0]

    def _analyze_image(self, image: Image.Image, question: str) -> str:
        """
        Get VLM analysis of image.

        Args:
            image: Image shown to the VLM
            question: Text placed in the USER turn of the prompt

        Returns:
            The decoded text that follows the last "ASSISTANT:" marker
        """
        prompt = f"USER: <image>\n{question}\nASSISTANT:"

        inputs = vlm_processor(text=prompt, images=image, return_tensors="pt")
        inputs = {k: v.to(vlm_model.device) for k, v in inputs.items()}

        with torch.no_grad():
            # Greedy decoding (no sampling), capped at 256 new tokens
            output_ids = vlm_model.generate(
                **inputs,
                max_new_tokens=256,
                do_sample=False,
            )

        response = vlm_processor.decode(output_ids[0], skip_special_tokens=True)
        # The decoded sequence includes the prompt; keep only the reply part
        if "ASSISTANT:" in response:
            response = response.split("ASSISTANT:")[-1].strip()

        return response

    def add_knowledge(self, text: str, metadata: Optional[Dict] = None) -> str:
        """
        Add knowledge to the RAG database.

        The item id is the MD5 hex digest of ``text``, so adding the same text
        again replaces the stored entry instead of creating a duplicate.

        Args:
            text: Knowledge text to add
            metadata: Optional metadata, merged over the default {"type": "knowledge"}

        Returns:
            ID of added item
        """
        item_id = hashlib.md5(text.encode()).hexdigest()
        embedding = self._embed_text(text)

        meta = {"type": "knowledge"}
        if metadata:
            meta.update(metadata)

        # upsert (not add): the id is content-derived, so re-running a cell
        # that adds the same text must be a harmless overwrite.
        self.collection.upsert(
            ids=[item_id],
            embeddings=[embedding.tolist()],
            metadatas=[meta],
            documents=[text],
        )

        return item_id

    def retrieve_context(
        self,
        query: str,
        image: Optional[Image.Image] = None,
        top_k: int = 3,
    ) -> List[str]:
        """
        Retrieve relevant context for a query.

        Args:
            query: Text query
            image: Optional image to include in search
            top_k: Maximum number of results; capped at the number of stored items

        Returns:
            List of relevant context strings (empty if the knowledge base is
            empty or ``top_k`` is not positive)
        """
        # Never request more neighbours than the collection holds
        n_results = min(top_k, self.collection.count())
        if n_results <= 0:
            return []

        # Get text embedding
        query_emb = self._embed_text(query)

        # If image provided, combine with image embedding
        if image is not None:
            image_emb = self._embed_image(image)
            # Average the embeddings, then re-normalize to unit length
            blended = (query_emb + image_emb) / 2
            norm = np.linalg.norm(blended)
            # Exactly opposite vectors average to zero; keep the text
            # embedding rather than dividing by zero.
            if norm > 0:
                query_emb = blended / norm

        # Search
        results = self.collection.query(
            query_embeddings=[query_emb.tolist()],
            n_results=n_results,
            include=["documents"],
        )

        documents = results["documents"]
        if documents:
            # One query embedding was sent, so the first entry holds its hits
            return list(documents[0])
        return []

    def answer_question(
        self,
        image: Image.Image,
        question: str,
        use_rag: bool = True,
        top_k: int = 3,
    ) -> Dict[str, Any]:
        """
        Answer a question about an image using RAG.

        Args:
            image: Query image
            question: Question about the image
            use_rag: Whether to use RAG context
            top_k: Number of context items to retrieve

        Returns:
            Dictionary with keys "question", "answer", "context_used" and
            "rag_enabled"
        """
        context_items: List[str] = []

        if use_rag:
            # Retrieve relevant context (returns [] for an empty knowledge base)
            context_items = self.retrieve_context(question, image, top_k)

        # Build enhanced question with context
        if context_items:
            context_str = "\n".join(f"- {item}" for item in context_items)
            enhanced_question = (
                "Based on the image and the following relevant information:\n"
                "\n"
                f"{context_str}\n"
                "\n"
                f"Please answer: {question}"
            )
        else:
            enhanced_question = question

        # Get VLM answer
        answer = self._analyze_image(image, enhanced_question)

        return {
            "question": question,
            "answer": answer,
            "context_used": context_items,
            "rag_enabled": use_rag,
        }

print("✅ VisualQARAG class ready!")

In [None]:
# Create and populate the RAG system
rag = VisualQARAG()

# Add knowledge about animals
# Short factual statements; each string is stored as one retrievable item.
knowledge = [
    "Cats are obligate carnivores and primarily eat meat. They need taurine in their diet which is found in animal tissue.",
    "Domestic cats (Felis catus) typically live 12-18 years. Indoor cats often live longer than outdoor cats.",
    "Cats sleep an average of 12-16 hours per day. They are crepuscular, meaning most active at dawn and dusk.",
    "Dogs are omnivores and can eat a variety of foods including meat, vegetables, and grains.",
    "Dogs have been domesticated for over 15,000 years and are known for their loyalty and social nature.",
    "Golden Retrievers are known for being friendly, reliable, and excellent family pets. They were originally bred for hunting.",
    "Orange tabby cats are not a specific breed but a color pattern. About 80% of orange tabby cats are male.",
    "Cats have over 20 vocalizations, including the meow, purr, hiss, and chirp. Each has different meanings.",
]

print("Adding knowledge to RAG...")
for fact in knowledge:
    rag.add_knowledge(fact)

print(f"✅ Added {len(knowledge)} knowledge items")

In [None]:
# Test with a cat image
def load_image_from_url(url: str) -> Image.Image:
    """
    Download an image over HTTP(S) and return it as an RGB PIL image.

    Args:
        url: Address of the image to fetch

    Returns:
        The decoded image, converted to RGB mode

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status
        requests.RequestException: On connection problems or when the
            10-second timeout expires
    """
    # Identify the client: some hosts (Wikimedia's policy asks for this)
    # may reject requests that carry only a generic library User-Agent.
    headers = {"User-Agent": "multimodal-rag-lab/1.0 (educational notebook; python-requests)"}
    response = requests.get(url, headers=headers, timeout=10)
    # Surface HTTP errors here rather than letting PIL fail on an error page
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")

cat_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Cat03.jpg/1200px-Cat03.jpg"
cat_image = load_image_from_url(cat_url)

# Display image
# (matplotlib is imported in the setup cell with the other dependencies;
# the explicit Axes interface avoids relying on pyplot's implicit state.)
fig, ax = plt.subplots(figsize=(8, 6))
ax.imshow(cat_image)
ax.axis('off')
ax.set_title("Query Image")
plt.show()

# Test questions
questions = [
    "What type of animal is this and what do they eat?",
    "How long do these animals typically live?",
    "What interesting facts can you tell me about this animal's behavior?",
]

print("\n📝 Visual QA with RAG")
print("=" * 60)

# Ask each question about the same image with retrieval switched on, and
# show both the retrieved snippets and the final answer.
for q in questions:
    print(f"\n❓ Question: {q}")

    result = rag.answer_question(cat_image, q, use_rag=True)

    print("\n📚 Context Retrieved:")
    for ctx in result['context_used']:
        # Add the ellipsis only when the 80-character preview is truncated
        preview = ctx if len(ctx) <= 80 else f"{ctx[:80]}..."
        print(f"   • {preview}")

    print(f"\n💬 Answer: {result['answer']}")
    print("-" * 60)

In [None]:
# Compare with and without RAG
print("\n📊 Comparison: With vs Without RAG")
print("=" * 60)

# The same question is asked twice so the only difference is use_rag.
question = "What type of animal is this and what special facts do you know about orange cats?"

# Without RAG
result_no_rag = rag.answer_question(cat_image, question, use_rag=False)
print("\n❌ Without RAG:")
print(f"   {result_no_rag['answer']}")

# With RAG
result_with_rag = rag.answer_question(cat_image, question, use_rag=True)
print("\n✅ With RAG:")
print(f"   Context: {result_with_rag['context_used']}")
print(f"   Answer: {result_with_rag['answer']}")

---

## Cleanup

In [None]:
# Drop the model/processor globals. pop() with a default is used instead of
# `del` so that running this cell a second time does not raise NameError.
for _name in ("clip_model", "clip_processor", "vlm_model", "vlm_processor"):
    globals().pop(_name, None)
del _name

# Collect garbage first so memory held by unreachable objects is freed,
# then ask PyTorch to release its cached (now unused) CUDA blocks.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
print("✅ Cleanup complete!")