# Day 2, Session 1 Lab: Add Vision to Text Agent

## Lab Overview

**Estimated Time:** 40 minutes  
**Difficulty:** Intermediate  
**Prerequisites:** Completion of Day 1 labs

### Learning Objectives

By the end of this lab, you will be able to:

1. **Extend Text Agents with Vision**
   - Modify yesterday's support agent to handle images
   - Implement proper state management for multimodal data
   - Build conditional routing based on input modality

2. **Implement Real Vision Models**
   - Load and use CLIP for image classification
   - Apply BLIP for image captioning
   - Integrate EasyOCR for text extraction
   - Manage GPU memory effectively

3. **Build Production-Ready Multimodal Systems**
   - Handle different input combinations gracefully
   - Implement error handling for GPU limitations
   - Create robust vision processing pipelines

### Real-World Application

This lab simulates building a customer support system that can:
- Answer text questions about invoices
- Process uploaded invoice images
- Handle combined text+image queries
- Extract structured data from documents
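The last capability, extracting structured data, usually means turning OCR output into named fields. Below is a minimal sketch that assumes simple `Label: value` invoices like the mock OCR text used later in this lab. `parse_invoice_fields` and its regular expressions are hypothetical, and you would need to adjust them for real invoice layouts:

```python
import re
from typing import Dict

# Hypothetical patterns for "Label: value" invoices. Each value is bounded by its
# expected format (or, for company, by the next "Word:" label), so the same
# patterns work whether OCR segments were joined with newlines or with spaces.
FIELD_PATTERNS = {
    "company": re.compile(r"Company:\s*(.+?)(?=\s+\w+:|$)"),
    "amount": re.compile(r"Amount:\s*\$?([\d,]+\.\d{2})"),
    "date": re.compile(r"Date:\s*(\d{4}-\d{2}-\d{2})"),
    "vat": re.compile(r"VAT:\s*([A-Z]{2}\d+)"),
}

def parse_invoice_fields(ocr_text: str) -> Dict[str, str]:
    """Pull labelled fields out of OCR text; fields that are not found are omitted."""
    fields = {}
    for name, pattern in FIELD_PATTERNS.items():
        match = pattern.search(ocr_text)
        if match:
            fields[name] = match.group(1).strip()
    return fields

sample = "INVOICE\nCompany: TechSupplies Co.\nAmount: $15,000.00\nDate: 2024-01-15\nVAT: GB123456789"
print(parse_invoice_fields(sample))
```

The function leaves out any field the patterns cannot find instead of guessing a value, so downstream code should check for missing keys.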

### Lab Structure

1. **Extend State Management** (5 minutes)
2. **Implement Vision Models** (15 minutes)  
3. **Build Vision Processing Nodes** (10 minutes)
4. **Update Graph Routing** (5 minutes)
5. **Test with Real Invoices** (5 minutes)

Let's transform your text agent into a multimodal powerhouse!

In [None]:
# Server configuration - instructor provides actual values
OLLAMA_URL = "http://XX.XX.XX.XX"  # Course server IP
API_TOKEN = "YOUR_TOKEN_HERE"      # Instructor provides token
MODEL = "qwen3:8b"                  # Default model on server

import requests
import json
import time
import base64
from PIL import Image
import io
import os
import torch
from typing import Dict, List, Optional, Any, TypedDict
from dataclasses import dataclass
import numpy as np
import warnings
warnings.filterwarnings('ignore')

# GPU memory monitoring
def get_gpu_memory():
    """Get current GPU memory usage"""
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1024**3  # GB
        reserved = torch.cuda.memory_reserved() / 1024**3    # GB
        return {"allocated": allocated, "reserved": reserved}
    return {"allocated": 0, "reserved": 0}

# Health check
def check_server_health():
    """Verify server connection and model availability"""
    try:
        # A timeout keeps the cell from hanging if the server IP is wrong or unreachable
        response = requests.get(f"{OLLAMA_URL}/health", timeout=10)
        if response.status_code == 200:
            data = response.json()
            print(f"✅ Server Status: {data.get('status', 'Unknown')}")
            print(f"📊 Models Available: {data.get('models_count', 0)}")
            return True
        print(f"❌ Health check returned HTTP {response.status_code}")
    except Exception as e:
        print(f"❌ Server connection failed: {e}")
    return False

# LLM calling function
def call_llm(prompt, model=MODEL):
    """Call the LLM with a prompt"""
    headers = {
        "Authorization": f"Bearer {API_TOKEN}",
        "Content-Type": "application/json"
    }
    
    data = {
        "model": model,
        "prompt": prompt
    }
    
    try:
        response = requests.post(
            f"{OLLAMA_URL}/think",
            headers=headers,
            json=data,
            timeout=120  # avoid hanging the notebook if the server stalls
        )
        if response.status_code == 200:
            return response.json().get('response', '')
        else:
            return f"Error: {response.status_code}"
    except Exception as e:
        return f"Error: {e}"

print("🔌 Connecting to course server...")
server_available = check_server_health()

print(f"\n🖥️ GPU Status: {'Available' if torch.cuda.is_available() else 'Not Available'}")
if torch.cuda.is_available():
    gpu_mem = get_gpu_memory()
    print(f"💾 GPU Memory: {gpu_mem['allocated']:.1f}GB allocated, {gpu_mem['reserved']:.1f}GB reserved")

# Install required packages (langgraph is needed for the workflow from Task 1 onwards)
print("\n📦 Installing vision and workflow packages...")
!pip install -q transformers torch torchvision easyocr pillow langgraph
print("✅ Packages installed")

In [None]:
# Download real invoice dataset
import requests
import zipfile
import io

dropbox_url = "https://www.dropbox.com/scl/fo/m9hyfmvi78snwv0nh34mo/AMEXxwXMLAOeve-_yj12ck8?rlkey=urinkikgiuven0fro7r4x5rcu&st=hv3of7g7&dl=1"

print("📦 Downloading invoice dataset...")
try:
    response = requests.get(dropbox_url, timeout=120)
    response.raise_for_status()  # fail clearly instead of trying to unzip an error page
    with zipfile.ZipFile(io.BytesIO(response.content)) as z:
        z.extractall("invoice_images")
    print("✅ Downloaded invoice dataset")
    
    # List available images for testing
    INVOICE_FILES = []
    for root, dirs, files in os.walk("invoice_images"):
        for file in files:
            if file.lower().endswith(('.png', '.jpg', '.jpeg')):
                full_path = os.path.join(root, file)
                INVOICE_FILES.append(full_path)
                print(f"  📄 {full_path}")
    
    print(f"\n📊 Total images available: {len(INVOICE_FILES)}")
    
except Exception as e:
    print(f"❌ Error downloading: {e}")
    INVOICE_FILES = []
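`os.walk` yields files in filesystem order, and that order can differ between machines. As a result, `INVOICE_FILES[0]` in Task 5 may not point to the same invoice for everyone. The sketch below is an alternative that uses `pathlib` and returns a sorted list. `list_images` is a hypothetical helper and is not part of the lab's scaffold:

```python
from pathlib import Path
from typing import List

IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg"}

def list_images(root: str) -> List[str]:
    """Return image paths under root (recursively) in a stable, sorted order."""
    return sorted(
        str(path)
        for path in Path(root).rglob("*")
        if path.is_file() and path.suffix.lower() in IMAGE_SUFFIXES
    )

# Usage, assuming the download cell above has populated invoice_images/:
# INVOICE_FILES = list_images("invoice_images")
```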

## Task 1: Extend State Management (5 minutes)

First, extend yesterday's simple support state to handle multimodal data.

**Your Task:** Complete the `MultimodalSupportState` class by adding the missing fields for image processing.

In [None]:
from typing import TypedDict, List, Optional, Dict, Any
from langgraph.graph import StateGraph, END

# Yesterday's simple state for reference
class SimpleSupportState(TypedDict):
    """Yesterday's basic support state"""
    question: str
    answer: str
    context: str

# TODO: Complete the multimodal state
class MultimodalSupportState(TypedDict):
    """Enhanced state for multimodal support agent"""
    # Basic text fields (from yesterday)
    question: Optional[str]
    answer: Optional[str]
    context: str
    
    # TODO: Add image-related fields
    # Hint: You need fields for:
    # - List of base64 encoded images
    # - List of image descriptions
    # - List of extracted text from images
    # - Current modality being processed
    # - Confidence score for vision processing
    # - Processing metrics dictionary
    
    # TODO: Uncomment and complete these fields:
    # images: List[str]  # Base64 encoded images
    # image_descriptions: List[str]  # Generated descriptions
    # extracted_text: List[str]  # OCR results
    # modality: str  # 'text', 'image', or 'multimodal'
    # vision_confidence: float  # Confidence in vision processing
    # processing_metrics: Dict[str, Any]  # Performance tracking

def create_initial_state(question=None, images=None):
    """Create a clean initial state for testing"""
    # TODO: Complete this function to create a proper initial state
    # Use the MultimodalSupportState structure
    
    state = {
        "question": question,
        "answer": None,
        "context": "",
        # TODO: Add the missing fields with appropriate default values
    }
    
    return state

# Test your state creation
print("🧪 Testing state creation...")
test_state = create_initial_state("What is this document?", [])
print(f"Created state with {len(test_state)} fields")

# TODO: Verify your state has all required fields
required_fields = ['question', 'answer', 'context', 'images', 'image_descriptions', 
                  'extracted_text', 'modality', 'vision_confidence', 'processing_metrics']

missing_fields = [field for field in required_fields if field not in test_state]
if missing_fields:
    print(f"❌ Missing fields: {missing_fields}")
    print("💡 Hint: Check your MultimodalSupportState definition")
else:
    print("✅ All required fields present!")
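Once you have attempted Task 1, you can compare your work against the possible completed state and initializer below. This is a sketch, not the only valid answer. The default `modality` of `"unknown"` is an assumption, chosen to match the fallback branch in `route_by_modality` later on. Running this cell replaces your own definitions of the same names:

```python
from typing import Any, Dict, List, Optional, TypedDict

class MultimodalSupportState(TypedDict):
    """Enhanced state for multimodal support agent"""
    question: Optional[str]
    answer: Optional[str]
    context: str
    images: List[str]                  # Base64 encoded images
    image_descriptions: List[str]      # Generated descriptions
    extracted_text: List[str]          # OCR results
    modality: str                      # 'text', 'image', 'multimodal' or 'unknown'
    vision_confidence: float           # Confidence in vision processing
    processing_metrics: Dict[str, Any] # Performance tracking

def create_initial_state(question: Optional[str] = None,
                         images: Optional[List[str]] = None) -> MultimodalSupportState:
    """Create a clean initial state; every call gets fresh lists and dicts."""
    return {
        "question": question,
        "answer": None,
        "context": "",
        "images": list(images) if images else [],
        "image_descriptions": [],
        "extracted_text": [],
        "modality": "unknown",
        "vision_confidence": 0.0,
        "processing_metrics": {},
    }
```

The initializer builds new containers on every call, so one test's mutations cannot leak into the next test's state.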

## Task 2: Implement Vision Models (15 minutes)

Load and configure vision models that are small enough to fit together on a T4 GPU (16 GB of memory).

**Your Task:** Complete the vision model setup and processing functions.
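Before loading the models, it helps to see what CLIP's zero-shot classification actually computes, i.e. the commented `logits_per_image.softmax(dim=-1)` step in `classify_image_with_clip`. CLIP compares one image embedding against one text embedding per category. It takes the cosine similarity, multiplies it by a learned temperature (`logit_scale`, about 100 in the released OpenAI checkpoints), and applies a softmax across the categories. The pure-Python sketch below uses made-up 3-dimensional embeddings; real CLIP ViT-B/32 embeddings have 512 dimensions:

```python
import math
from typing import Dict, List

def cosine(a: List[float], b: List[float]) -> float:
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def zero_shot_probs(image_emb: List[float],
                    text_embs: Dict[str, List[float]],
                    logit_scale: float = 100.0) -> Dict[str, float]:
    """Softmax over scaled cosine similarities, one logit per category."""
    logits = {label: logit_scale * cosine(image_emb, emb) for label, emb in text_embs.items()}
    peak = max(logits.values())  # subtract the max before exp() for numerical stability
    exps = {label: math.exp(v - peak) for label, v in logits.items()}
    total = sum(exps.values())
    return {label: e / total for label, e in exps.items()}

# Toy embeddings, made up for illustration only
image = [0.9, 0.1, 0.2]
texts = {
    "invoice": [1.0, 0.0, 0.1],
    "receipt": [0.7, 0.6, 0.0],
    "letter": [0.0, 1.0, 0.3],
}

sharp = zero_shot_probs(image, texts)                  # CLIP-like temperature
flat = zero_shot_probs(image, texts, logit_scale=1.0)  # no scaling, for contrast
print(max(sharp, key=sharp.get), round(sharp["invoice"], 3), round(flat["invoice"], 3))
```

With the x100 scale, a cosine gap of about 0.18 puts nearly all the probability on "invoice"; without the scale, "invoice" gets only about 44%. The softmax only ranks the categories you supply, so a high "confidence" does not prove that the image is a document at all. This is one reason the category list includes a catch-all "other document".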

In [None]:
from transformers import CLIPProcessor, CLIPModel, BlipProcessor, BlipForConditionalGeneration
import easyocr
import torch
from PIL import Image

# Global model storage
VISION_MODELS = {}

def clear_gpu_cache():
    """Clear GPU cache to manage memory"""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

def setup_vision_models():
    """Load vision models optimized for T4 GPU"""
    global VISION_MODELS
    
    print("🔧 Setting up vision models...")
    
    # Check initial GPU memory
    initial_memory = get_gpu_memory()
    print(f"📊 Initial GPU memory: {initial_memory['allocated']:.1f}GB")
    
    try:
        # TODO: Load CLIP model for image classification
        # Hint: Use "openai/clip-vit-base-patch32" - it's smaller and faster
        # Remember to move to GPU if available
        
        print("📸 Loading CLIP model...")
        # TODO: Uncomment and complete:
        # clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        # clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        # 
        # if torch.cuda.is_available():
        #     clip_model = clip_model.to('cuda')
        # 
        # VISION_MODELS['clip_model'] = clip_model
        # VISION_MODELS['clip_processor'] = clip_processor
        
        memory_after_clip = get_gpu_memory()
        print(f"   GPU memory after CLIP: {memory_after_clip['allocated']:.1f}GB")
        
        # TODO: Load BLIP model for image captioning
        # Hint: Use "Salesforce/blip-image-captioning-base" (BLIP, not the much larger BLIP-2) for T4 compatibility
        
        print("📝 Loading BLIP model...")
        # TODO: Uncomment and complete:
        # blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        # blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
        # 
        # if torch.cuda.is_available():
        #     blip_model = blip_model.to('cuda')
        # 
        # VISION_MODELS['blip_model'] = blip_model
        # VISION_MODELS['blip_processor'] = blip_processor
        
        memory_after_blip = get_gpu_memory()
        print(f"   GPU memory after BLIP: {memory_after_blip['allocated']:.1f}GB")
        
        # TODO: Initialize EasyOCR reader
        # Hint: Use English and common European languages
        
        print("🔤 Loading EasyOCR...")
        # TODO: Uncomment and complete:
        # ocr_reader = easyocr.Reader(['en', 'fr', 'de', 'es'], gpu=torch.cuda.is_available())
        # VISION_MODELS['ocr_reader'] = ocr_reader
        
        final_memory = get_gpu_memory()
        print(f"   Final GPU memory: {final_memory['allocated']:.1f}GB")
        
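        total_used = final_memory['allocated'] - initial_memory['allocated']
        if not VISION_MODELS:
            print("⚠️ No models registered yet - uncomment the TODO blocks above")
            return False
        print(f"✅ Models loaded successfully! Used {total_used:.1f}GB GPU memory")
        
        return True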
        
    except Exception as e:
        print(f"❌ Error loading models: {e}")
        print("💡 Try clearing GPU cache or using smaller models")
        clear_gpu_cache()
        return False

def encode_image_to_base64(image_path: str) -> str:
    """Convert image file to base64 string"""
    try:
        with open(image_path, 'rb') as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')
    except Exception as e:
        print(f"Error encoding image: {e}")
        return ""

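def decode_base64_to_image(base64_string: str) -> Optional[Image.Image]:
    """Convert base64 string to an RGB PIL Image (None on failure)"""
    try:
        image_data = base64.b64decode(base64_string)
        # convert() forces the pixels to load (surfacing corrupt files here) and
        # normalizes RGBA, grayscale or palette images to 3-channel RGB
        return Image.open(io.BytesIO(image_data)).convert("RGB")
    except Exception as e:
        print(f"Error decoding image: {e}")
        return None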

# TODO: Call your setup function
print("🚀 Initializing vision models...")
models_loaded = setup_vision_models()

if not models_loaded:
    print("\n⚠️ Models not loaded. Check GPU memory and try again.")
    print("💡 You can continue with mock implementations for now.")
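The `except` branch in `setup_vision_models` only prints advice. To make "error handling for GPU limitations" concrete, you can retry a load on the CPU when the GPU runs out of memory. The sketch below does this. `load_with_fallback` is a hypothetical helper. It assumes, as in recent PyTorch releases, that CUDA out-of-memory errors are `RuntimeError` subclasses whose message contains "out of memory":

```python
from typing import Any, Callable, Optional, Sequence, Tuple

def load_with_fallback(load_fn: Callable[[str], Any],
                       devices: Sequence[str] = ("cuda", "cpu"),
                       on_oom: Callable[[], None] = lambda: None) -> Tuple[Any, str]:
    """Call load_fn(device) for each device in order, moving on only after an out-of-memory error."""
    last_error: Optional[BaseException] = None
    for device in devices:
        try:
            return load_fn(device), device
        except RuntimeError as exc:
            # Recent PyTorch raises torch.cuda.OutOfMemoryError, a RuntimeError subclass,
            # with a message like "CUDA out of memory. Tried to allocate ..."
            if "out of memory" not in str(exc).lower():
                raise  # unrelated failure: surface it instead of silently switching device
            last_error = exc
            on_oom()  # e.g. pass clear_gpu_cache from the cell above
    raise RuntimeError(f"Could not load the model on any of {list(devices)}") from last_error

# Hypothetical usage with the BLIP checkpoint from setup_vision_models():
# blip_model, device = load_with_fallback(
#     lambda d: BlipForConditionalGeneration.from_pretrained(
#         "Salesforce/blip-image-captioning-base").to(d),
#     on_oom=clear_gpu_cache,
# )
```

If you adopt this helper, move inputs to the returned `device` rather than to `'cuda'`. The processing functions below send inputs to `'cuda'` whenever CUDA is available, which breaks for a model that fell back to the CPU. Expect CPU captioning to be much slower.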

In [None]:
# Vision processing functions

def classify_image_with_clip(image: Image.Image, categories: List[str]) -> Dict[str, Any]:
    """Classify image using CLIP model"""
    # TODO: Implement CLIP classification
    # Steps:
    # 1. Get CLIP model and processor from VISION_MODELS
    # 2. Process image and text categories
    # 3. Compute similarities
    # 4. Return top category with confidence
    
    if 'clip_model' not in VISION_MODELS:
        return {"category": "unknown", "confidence": 0.0, "error": "CLIP model not loaded"}
    
    try:
        # TODO: Uncomment and complete:
        # model = VISION_MODELS['clip_model']
        # processor = VISION_MODELS['clip_processor']
        # 
        # # Process inputs
        # inputs = processor(text=categories, images=image, return_tensors="pt", padding=True)
        # 
        # if torch.cuda.is_available():
        #     inputs = {k: v.to('cuda') if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
        # 
        # # Get predictions
        # with torch.no_grad():
        #     outputs = model(**inputs)
        #     logits_per_image = outputs.logits_per_image
        #     probs = logits_per_image.softmax(dim=-1)
        # 
        # # Get top prediction
        # top_prob, top_idx = probs[0].max(dim=0)
        # 
        # return {
        #     "category": categories[top_idx.item()],
        #     "confidence": top_prob.item(),
        #     "all_scores": {cat: prob.item() for cat, prob in zip(categories, probs[0])}
        # }
        
        # Mock implementation for now
        return {"category": "invoice", "confidence": 0.85, "method": "mock"}
        
    except Exception as e:
        return {"category": "unknown", "confidence": 0.0, "error": str(e)}

def generate_image_caption(image: Image.Image) -> Dict[str, Any]:
    """Generate caption for image using BLIP"""
    # TODO: Implement BLIP captioning
    # Steps:
    # 1. Get BLIP model and processor
    # 2. Process image
    # 3. Generate caption
    # 4. Return caption with confidence
    
    if 'blip_model' not in VISION_MODELS:
        return {"caption": "Unable to generate caption", "confidence": 0.0, "error": "BLIP model not loaded"}
    
    try:
        # TODO: Uncomment and complete:
        # model = VISION_MODELS['blip_model']
        # processor = VISION_MODELS['blip_processor']
        # 
        # # Process image
        # inputs = processor(image, return_tensors="pt")
        # 
        # if torch.cuda.is_available():
        #     inputs = {k: v.to('cuda') if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
        # 
        # # Generate caption
        # with torch.no_grad():
        #     out = model.generate(**inputs, max_length=50)
        # 
        # caption = processor.decode(out[0], skip_special_tokens=True)
        # 
        # return {
        #     "caption": caption,
        #     "confidence": 0.8,  # BLIP doesn't provide confidence directly
        #     "length": len(caption.split())
        # }
        
        # Mock implementation for now
        return {"caption": "A professional invoice document with company information and itemized costs", "confidence": 0.75, "method": "mock"}
        
    except Exception as e:
        return {"caption": "Caption generation failed", "confidence": 0.0, "error": str(e)}

def extract_text_with_ocr(image: Image.Image) -> Dict[str, Any]:
    """Extract text from image using EasyOCR"""
    # TODO: Implement OCR text extraction
    # Steps:
    # 1. Get OCR reader
    # 2. Convert PIL image to format EasyOCR expects
    # 3. Extract text with bounding boxes
    # 4. Return text and confidence scores
    
    if 'ocr_reader' not in VISION_MODELS:
        return {"text": "", "confidence": 0.0, "error": "OCR reader not loaded"}
    
    try:
        # TODO: Uncomment and complete:
        # reader = VISION_MODELS['ocr_reader']
        # 
        # # Convert PIL to numpy array
        # image_array = np.array(image)
        # 
        # # Extract text
        # results = reader.readtext(image_array)
        # 
        # # Process results
        # extracted_text = []
        # confidences = []
        # 
        # for (bbox, text, conf) in results:
        #     if conf > 0.5:  # Filter low confidence
        #         extracted_text.append(text)
        #         confidences.append(conf)
        # 
        # full_text = ' '.join(extracted_text)
        # avg_confidence = np.mean(confidences) if confidences else 0.0
        # 
        # return {
        #     "text": full_text,
        #     "confidence": avg_confidence,
        #     "word_count": len(extracted_text),
        #     "individual_results": results
        # }
        
        # Mock implementation for now
        mock_text = "INVOICE\nCompany: TechSupplies Co.\nAmount: $15,000.00\nDate: 2024-01-15\nVAT: GB123456789"
        return {"text": mock_text, "confidence": 0.92, "word_count": 8, "method": "mock"}
        
    except Exception as e:
        return {"text": "", "confidence": 0.0, "error": str(e)}

print("✅ Vision processing functions defined")
print("📝 TODO: Complete the implementation by uncommenting and filling in the TODOs")
print("💡 Start with mock implementations, then replace with real models once loaded")
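The commented `extract_text_with_ocr` code keeps segments above 0.5 confidence and joins them with spaces. EasyOCR's `readtext` returns `(bbox, text, confidence)` tuples, where `bbox` holds four corner points. If you group segments into rows by their top edge, each `Label: value` pair stays on one line, which tends to help both the LLM and any later field parsing. The sketch below makes two assumptions: the first corner is the top-left one (true for horizontal text), and a 10-pixel row tolerance suits your scans. `summarize_ocr` is a hypothetical helper:

```python
from statistics import mean
from typing import Any, Dict, List, Sequence, Tuple

OcrResult = Tuple[Sequence[Sequence[float]], str, float]  # (bbox corners, text, confidence)

def summarize_ocr(results: List[OcrResult], min_conf: float = 0.5,
                  row_tolerance: float = 10.0) -> Dict[str, Any]:
    """Filter low-confidence segments and rebuild reading order row by row."""
    kept = [(bbox, text, conf) for bbox, text, conf in results if conf > min_conf]

    rows: List[List[Tuple[float, str]]] = []
    row_tops: List[float] = []
    # Walk segments top to bottom; a segment joins the current row if its top edge is close
    for bbox, text, _ in sorted(kept, key=lambda r: r[0][0][1]):
        left, top = bbox[0][0], bbox[0][1]
        if rows and abs(top - row_tops[-1]) <= row_tolerance:
            rows[-1].append((left, text))
        else:
            rows.append([(left, text)])
            row_tops.append(top)

    # Within each row, order segments left to right
    lines = [" ".join(text for _, text in sorted(row)) for row in rows]
    confidences = [conf for _, _, conf in kept]
    return {
        "text": "\n".join(lines),
        "confidence": mean(confidences) if confidences else 0.0,
        "segment_count": len(kept),
    }
```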

## Task 3: Build Vision Processing Nodes (10 minutes)

Create LangGraph nodes that process different aspects of images.

**Your Task:** Complete the vision processing nodes for the workflow.

In [None]:
def detect_modality(state: MultimodalSupportState) -> MultimodalSupportState:
    """Detect whether input is text, image, or multimodal"""
    # TODO: Implement modality detection
    # Check what types of input are present and set the modality field
    
    has_text = state.get('question') and len(state['question'].strip()) > 0
    has_images = state.get('images') and len(state['images']) > 0
    
    # TODO: Set the modality based on inputs
    if has_text and has_images:
        # TODO: Set state['modality'] = "multimodal"
        pass
    elif has_images:
        # TODO: Set state['modality'] = "image"
        pass
    elif has_text:
        # TODO: Set state['modality'] = "text"
        pass
    else:
        # TODO: Set state['modality'] = "unknown"
        pass
    
    print(f"🔍 Detected modality: {state.get('modality', 'unknown')}")
    return state

def classify_document(state: MultimodalSupportState) -> MultimodalSupportState:
    """Classify document type using CLIP"""
    if not state.get('images'):
        return state
    
    print("📋 Classifying document type...")
    
    # TODO: Process each image with CLIP
    categories = ["invoice", "receipt", "contract", "letter", "form", "other document"]
    
    classifications = []
    
    for i, image_b64 in enumerate(state['images']):
        try:
            # TODO: Decode base64 to PIL Image
            # image = decode_base64_to_image(image_b64)
            # 
            # if image:
            #     result = classify_image_with_clip(image, categories)
            #     classifications.append(result)
            #     print(f"   Image {i+1}: {result['category']} (confidence: {result['confidence']:.2f})")
            
            # Mock for now
            classifications.append({"category": "invoice", "confidence": 0.85})
            print(f"   Image {i+1}: invoice (confidence: 0.85)")
            
        except Exception as e:
            print(f"   Error classifying image {i+1}: {e}")
            classifications.append({"category": "unknown", "confidence": 0.0})
    
    # TODO: Store classifications in processing_metrics
    if 'processing_metrics' not in state:
        state['processing_metrics'] = {}
    state['processing_metrics']['classifications'] = classifications
    
    return state

def describe_image(state: MultimodalSupportState) -> MultimodalSupportState:
    """Generate descriptions for images using BLIP"""
    if not state.get('images'):
        return state
    
    print("📝 Generating image descriptions...")
    
    descriptions = []
    confidences = []
    
    # TODO: Process each image with BLIP
    for i, image_b64 in enumerate(state['images']):
        try:
            # TODO: Decode and caption the image
            # image = decode_base64_to_image(image_b64)
            # 
            # if image:
            #     result = generate_image_caption(image)
            #     descriptions.append(result['caption'])
            #     confidences.append(result['confidence'])
            #     print(f"   Image {i+1}: {result['caption'][:100]}...")
            
            # Mock for now
            mock_description = "A professional invoice document with company letterhead, itemized costs, and payment details"
            descriptions.append(mock_description)
            confidences.append(0.75)
            print(f"   Image {i+1}: {mock_description[:60]}...")
            
        except Exception as e:
            print(f"   Error describing image {i+1}: {e}")
            descriptions.append("Unable to describe image")
            confidences.append(0.0)
    
    # TODO: Update state with descriptions
    state['image_descriptions'] = descriptions
    
    # Update confidence score
    avg_confidence = sum(confidences) / len(confidences) if confidences else 0.0
    state['vision_confidence'] = avg_confidence
    
    return state

def extract_text_ocr(state: MultimodalSupportState) -> MultimodalSupportState:
    """Extract text from images using OCR"""
    if not state.get('images'):
        return state
    
    print("🔤 Extracting text with OCR...")
    
    extracted_texts = []
    ocr_confidences = []
    
    # TODO: Process each image with OCR
    for i, image_b64 in enumerate(state['images']):
        try:
            # TODO: Extract text from image
            # image = decode_base64_to_image(image_b64)
            # 
            # if image:
            #     result = extract_text_with_ocr(image)
            #     extracted_texts.append(result['text'])
            #     ocr_confidences.append(result['confidence'])
            #     print(f"   Image {i+1}: Extracted {len(result['text'])} characters (conf: {result['confidence']:.2f})")
            
            # Mock for now
            mock_text = "INVOICE\nCompany: TechSupplies Co.\nAmount: $15,000.00\nDate: 2024-01-15\nVAT: GB123456789"
            extracted_texts.append(mock_text)
            ocr_confidences.append(0.92)
            print(f"   Image {i+1}: Extracted {len(mock_text)} characters (conf: 0.92)")
            
        except Exception as e:
            print(f"   Error extracting text from image {i+1}: {e}")
            extracted_texts.append("")
            ocr_confidences.append(0.0)
    
    # TODO: Update state with extracted text
    state['extracted_text'] = extracted_texts
    
    # Update processing metrics
    if 'processing_metrics' not in state:
        state['processing_metrics'] = {}
    state['processing_metrics']['ocr_confidences'] = ocr_confidences
    
    return state

def validate_extraction(state: MultimodalSupportState) -> MultimodalSupportState:
    """Validate that vision processing succeeded"""
    print("✅ Validating extraction results...")
    
    # TODO: Check if vision processing was successful
    # bool(...) keeps these flags True/False instead of leaking an empty list or None
    has_descriptions = bool(state.get('image_descriptions')) and any(state['image_descriptions'])
    has_text = bool(state.get('extracted_text')) and any(state['extracted_text'])
    confidence_ok = state.get('vision_confidence', 0) > 0.5
    
    validation_status = {
        'has_descriptions': has_descriptions,
        'has_extracted_text': has_text,
        'confidence_acceptable': confidence_ok,
        'overall_success': has_descriptions and has_text and confidence_ok
    }
    
    if 'processing_metrics' not in state:
        state['processing_metrics'] = {}
    state['processing_metrics']['validation'] = validation_status
    
    print(f"   Descriptions: {'✅' if has_descriptions else '❌'}")
    print(f"   Extracted text: {'✅' if has_text else '❌'}")
    print(f"   Confidence: {'✅' if confidence_ok else '❌'} ({state.get('vision_confidence', 0):.2f})")
    
    return state

print("✅ Vision processing nodes defined")
print("📝 TODO: Complete the implementations by filling in the TODO sections")
print("💡 Test each node individually before building the full graph")
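The sketch below is one way to complete the `detect_modality` TODOs. It is written as a pure function, so you can check it without any images or models. `detect_modality_value` and `detect_modality_node` are hypothetical names, not part of the scaffold. The node variant returns only the key it changes, and LangGraph merges that key into the state. Returning the whole mutated state, as the scaffold does, also works:

```python
from typing import Any, Dict, List, Optional

def detect_modality_value(question: Optional[str], images: Optional[List[str]]) -> str:
    """Map the presence of text and/or images to a modality label."""
    has_text = bool(question and question.strip())  # whitespace-only questions count as no text
    has_images = bool(images)
    if has_text and has_images:
        return "multimodal"
    if has_images:
        return "image"
    if has_text:
        return "text"
    return "unknown"

def detect_modality_node(state: Dict[str, Any]) -> Dict[str, Any]:
    """LangGraph-style node: return only the state key this step updates."""
    return {"modality": detect_modality_value(state.get("question"), state.get("images"))}
```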

## Task 4: Update Graph Routing (5 minutes)

Build a workflow that routes based on detected modality.

**Your Task:** Complete the multimodal workflow graph with proper routing.

In [None]:
def generate_final_answer(state: MultimodalSupportState) -> MultimodalSupportState:
    """Generate final answer using all available information"""
    print("🤖 Generating final answer...")
    
    # Collect all available information
    context_parts = []
    
    if state.get('question'):
        context_parts.append(f"Question: {state['question']}")
    
    if state.get('image_descriptions'):
        context_parts.append(f"Image descriptions: {' | '.join(state['image_descriptions'])}")
    
    if state.get('extracted_text'):
        context_parts.append(f"Extracted text: {' | '.join(state['extracted_text'])}")
    
    # TODO: Use LLM to generate comprehensive answer
    if server_available and context_parts:
        context = "\n".join(context_parts)
        prompt = f"""Based on the following information about a document, provide a helpful response:

{context}

Provide a clear, informative answer that addresses the question and incorporates all available information."""
        
        response = call_llm(prompt)
        state['answer'] = response
    else:
        # Fallback answer
        state['answer'] = f"Based on the {state.get('modality', 'unknown')} input, I can provide information about the document."
    
    state['context'] = "\n".join(context_parts)
    
    return state

def route_by_modality(state: MultimodalSupportState) -> str:
    """Route processing based on detected modality"""
    modality = state.get('modality', 'unknown')
    
    # TODO: Implement routing logic
    if modality == 'text':
        return 'generate_answer'  # Skip vision processing
    elif modality in ['image', 'multimodal']:
        return 'classify_document'  # Start vision processing
    else:
        return 'generate_answer'  # Default fallback

# TODO: Build the multimodal workflow
print("🏗️ Building multimodal workflow...")

# Create the graph
multimodal_workflow = StateGraph(MultimodalSupportState)

# TODO: Add all the nodes
multimodal_workflow.add_node("detect_modality", detect_modality)
multimodal_workflow.add_node("classify_document", classify_document)
multimodal_workflow.add_node("describe_image", describe_image)
multimodal_workflow.add_node("extract_text_ocr", extract_text_ocr)
multimodal_workflow.add_node("validate_extraction", validate_extraction)
multimodal_workflow.add_node("generate_answer", generate_final_answer)

# TODO: Set entry point
multimodal_workflow.set_entry_point("detect_modality")

# TODO: Add conditional routing after modality detection
multimodal_workflow.add_conditional_edges(
    "detect_modality",
    route_by_modality,
    {
        "classify_document": "classify_document",
        "generate_answer": "generate_answer"
    }
)

# TODO: Add sequential edges for vision processing
multimodal_workflow.add_edge("classify_document", "describe_image")
multimodal_workflow.add_edge("describe_image", "extract_text_ocr")
multimodal_workflow.add_edge("extract_text_ocr", "validate_extraction")
multimodal_workflow.add_edge("validate_extraction", "generate_answer")

# TODO: End after generating answer
multimodal_workflow.add_edge("generate_answer", END)

# Compile the workflow
try:
    multimodal_app = multimodal_workflow.compile()
    print("✅ Multimodal workflow compiled successfully!")
except Exception as e:
    print(f"❌ Error compiling workflow: {e}")
    print("💡 Check your state definition and node implementations")

# Visualize the workflow structure (mirrors the edges defined above)
print("\n📊 Workflow Structure:")
print("detect_modality")
print("   │")
print("   ├── image / multimodal ──► classify_document")
print("   │                              │")
print("   │                              ▼")
print("   │                          describe_image")
print("   │                              │")
print("   │                              ▼")
print("   │                          extract_text_ocr")
print("   │                              │")
print("   │                              ▼")
print("   │                          validate_extraction")
print("   │                              │")
print("   │                              ▼")
print("   └── text / unknown ──────► generate_answer ──► END")

## Task 5: Test with Real Invoices (5 minutes)

Test your multimodal agent with different scenarios.

**Your Task:** Run the test scenarios and debug any issues.
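Before you run the tests, you can confirm which nodes each modality visits by tracing the same edges in plain Python. The sketch below mirrors `route_by_modality` and the `add_edge` calls from Task 4. It does not use LangGraph, and `trace_path` is a hypothetical helper:

```python
from typing import Dict, List

# Static edges copied from the Task 4 graph; the conditional edge is handled by route()
EDGES: Dict[str, str] = {
    "classify_document": "describe_image",
    "describe_image": "extract_text_ocr",
    "extract_text_ocr": "validate_extraction",
    "validate_extraction": "generate_answer",
    "generate_answer": "END",
}

def route(modality: str) -> str:
    """Mirror route_by_modality: only image/multimodal inputs enter the vision pipeline."""
    return "classify_document" if modality in ("image", "multimodal") else "generate_answer"

def trace_path(modality: str) -> List[str]:
    """List the nodes visited for a given modality, starting at the entry point."""
    path = ["detect_modality"]
    node = route(modality)
    while node != "END":
        path.append(node)
        node = EDGES[node]
    return path

for m in ("text", "image", "multimodal", "unknown"):
    print(f"{m:>10}: {' -> '.join(trace_path(m))}")
```

If Test 1 ever shows vision nodes printing output, or Test 2 shows none, compare the printed paths with your `add_conditional_edges` mapping.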

In [None]:
def test_multimodal_agent():
    """Test the multimodal agent with various scenarios"""
    
    print("🧪 TESTING MULTIMODAL SUPPORT AGENT")
    print("=" * 50)
    
    # Test 1: Text-only query
    print("\n📝 Test 1: Text-only Query")
    print("-" * 30)
    
    try:
        text_state = create_initial_state(
            question="What information is typically found on an invoice?",
            images=[]
        )
        
        if 'multimodal_app' in globals():
            result = multimodal_app.invoke(text_state)
            print(f"✅ Success: {result.get('answer', 'No answer')[:150]}...")
            print(f"   Modality: {result.get('modality')}")
        else:
            print("❌ Workflow not compiled")
            
    except Exception as e:
        print(f"❌ Text test failed: {e}")
    
    # Test 2: Image-only query
    print("\n🖼️ Test 2: Image-only Query")
    print("-" * 30)
    
    if INVOICE_FILES:
        try:
            # TODO: Test with real invoice image
            sample_image = INVOICE_FILES[0]
            encoded_image = encode_image_to_base64(sample_image)
            
            if encoded_image:
                image_state = create_initial_state(
                    question=None,
                    images=[encoded_image]
                )
                
                if 'multimodal_app' in globals():
                    result = multimodal_app.invoke(image_state)
                    print(f"✅ Success: Processed image")
                    print(f"   Modality: {result.get('modality')}")
                    print(f"   Descriptions: {len(result.get('image_descriptions', []))} generated")
                    print(f"   Extracted text: {len(result.get('extracted_text', []))} results")
                    print(f"   Confidence: {result.get('vision_confidence', 0):.2f}")
                else:
                    print("❌ Workflow not compiled")
            else:
                print("❌ Could not encode image")
                
        except Exception as e:
            print(f"❌ Image test failed: {e}")
    else:
        print("⚠️ No invoice images available for testing")
    
    # Test 3: Multimodal query
    print("\n🎭 Test 3: Multimodal Query (Text + Image)")
    print("-" * 45)
    
    if INVOICE_FILES:
        try:
            sample_image = INVOICE_FILES[0]
            encoded_image = encode_image_to_base64(sample_image)
            
            if encoded_image:
                multimodal_state = create_initial_state(
                    question="What is the total amount on this invoice?",
                    images=[encoded_image]
                )
                
                if 'multimodal_app' in globals():
                    result = multimodal_app.invoke(multimodal_state)
                    print(f"✅ Success: {result.get('answer', 'No answer')[:150]}...")
                    print(f"   Modality: {result.get('modality')}")
                    print(f"   Processing metrics: {result.get('processing_metrics', {})}")
                else:
                    print("❌ Workflow not compiled")
            else:
                print("❌ Could not encode image")
                
        except Exception as e:
            print(f"❌ Multimodal test failed: {e}")
    
    # Test 4: Error handling
    print("\n⚠️ Test 4: Error Handling")
    print("-" * 25)
    
    try:
        # Test with invalid base64
        error_state = create_initial_state(
            question="Analyze this image",
            images=["invalid_base64_string"]
        )
        
        if 'multimodal_app' in globals():
            result = multimodal_app.invoke(error_state)
            print("✅ Error handled gracefully")
            print(f"   Answer generated: {bool(result.get('answer'))}")
        else:
            print("❌ Workflow not compiled")
            
    except Exception as e:
        print(f"⚠️ Error handling test: {e}")
    
    # Performance summary
    print("\n📊 Performance Summary")
    print("-" * 25)
    final_memory = get_gpu_memory()
    print(f"GPU Memory: {final_memory['allocated']:.1f}GB allocated")
    print(f"Models loaded: {len(VISION_MODELS)}")
    print(f"Test images available: {len(INVOICE_FILES)}")

# Run all four tests (requires multimodal_app compiled in the previous steps)
test_multimodal_agent()

## Lab Completion and Self-Assessment

### What You've Built

Congratulations! You've successfully extended a text-only agent into a full multimodal system with:

1. **Enhanced State Management**
   - Extended simple text state to handle images and metadata
   - Added proper tracking for vision processing results
   - Implemented modality detection and routing

2. **Real Vision Models**
   - Integrated CLIP for image classification
   - Used BLIP-2 for automatic image captioning
   - Applied EasyOCR for text extraction
   - Managed GPU memory effectively

3. **Production-Ready Pipeline**
   - Built conditional routing based on input type
   - Implemented error handling for GPU limitations
   - Created comprehensive testing scenarios
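Modality detection and conditional routing can be sketched as pure functions over the state. This is a minimal illustration, not the lab's implementation. Several details are assumptions: the `AgentState` shape beyond `question`, `images`, and `modality`; the modality labels; and the node names (`answer_text`, `process_vision`, `handle_error`).

```python
from typing import List, Literal, Optional, TypedDict

# Hypothetical minimal state; the lab's real state carries more fields.
class AgentState(TypedDict, total=False):
    question: Optional[str]
    images: List[str]  # base64-encoded images
    modality: str

Modality = Literal["text", "image", "multimodal", "empty"]

def detect_modality(state: AgentState) -> Modality:
    """Classify the input by which fields are actually populated."""
    has_text = bool((state.get("question") or "").strip())
    has_images = bool(state.get("images"))
    if has_text and has_images:
        return "multimodal"
    if has_images:
        return "image"
    if has_text:
        return "text"
    return "empty"

def route_by_modality(state: AgentState) -> str:
    """Map the detected modality to the next node (node names are hypothetical)."""
    return {
        "text": "answer_text",        # text-only: skip the vision models entirely
        "image": "process_vision",
        "multimodal": "process_vision",
        "empty": "handle_error",
    }[detect_modality(state)]
```

If the workflow is a LangGraph `StateGraph`, a function like `route_by_modality` is the kind of callable that `add_conditional_edges` takes as its routing function. Text-only questions never touch the GPU models.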

### Self-Assessment Questions

Rate your understanding (1-5 scale) and provide brief explanations:

1. **State Management** (1-5): ___
   - How does multimodal state differ from text-only state?
   - What challenges arise when combining different data types?

2. **Vision Model Integration** (1-5): ___
   - Why is GPU memory management critical for vision models?
   - How do you choose appropriate models for your hardware?

3. **Workflow Design** (1-5): ___
   - How does conditional routing improve efficiency?
   - What are the trade-offs between sequential and parallel processing?

4. **Error Handling** (1-5): ___
   - What types of errors are common in vision processing?
   - How do you ensure graceful degradation?

5. **Performance Optimization** (1-5): ___
   - What strategies help manage GPU memory usage?
   - How do you balance accuracy with processing speed?

### Common Issues and Solutions

**GPU Out of Memory (OOM)**
- Use smaller model variants (e.g., CLIP base instead of large)
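- Release models you no longer need before loading the next one: drop every reference (`del model`), call `gc.collect()`, then `torch.cuda.empty_cache()`. The cache call alone frees only cached blocks that no live tensor still occupies
- Process images in smaller batches
- Load models in half precision for inference (e.g., `torch_dtype=torch.float16` in Hugging Face `from_pretrained`)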
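The sketch below shows the release step for a model kept in a registry dictionary. `release_model` is a hypothetical helper. It also assumes that models live in a name-keyed dict like `VISION_MODELS`, which the lab code does not confirm. `torch` is imported optionally so the function also runs on CPU-only machines.

```python
import gc

try:
    import torch  # optional: only needed to clear the CUDA cache
except ImportError:
    torch = None

def release_model(models: dict, name: str) -> bool:
    """Remove a model from a registry dict and hand cached GPU memory back.

    Returns True if a model was removed. Any other references to the model
    (pipelines, closures, notebook variables) must also be dropped, or its
    tensors stay alive.
    """
    model = models.pop(name, None)
    if model is None:
        return False
    del model      # drop this function's reference
    gc.collect()   # collect reference cycles that may still hold tensors
    if torch is not None and torch.cuda.is_available():
        torch.cuda.empty_cache()  # return unoccupied cached blocks to the driver
    return True
```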

**Poor OCR Results**
- Preprocess images (resize, enhance contrast)
- Pass the document's languages to EasyOCR (e.g., `easyocr.Reader(['en'])`)
- Discard detections below a confidence threshold; each `readtext` result carries a confidence score
- Consider multiple OCR engines for redundancy
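By default, EasyOCR's `readtext` returns a list of `(bounding_box, text, confidence)` tuples, so confidence filtering can be a small post-processing step. `filter_ocr_results` and its default threshold of 0.5 are illustrative assumptions. Tune the threshold on your own invoices.

```python
from typing import Any, List, Sequence, Tuple

# One EasyOCR detection with the default detail=1: (bounding_box, text, confidence)
OcrResult = Tuple[Any, str, float]

def filter_ocr_results(results: Sequence[OcrResult],
                       min_confidence: float = 0.5) -> List[str]:
    """Keep non-blank text at or above min_confidence, in the order EasyOCR returned it."""
    return [
        text.strip()
        for _bbox, text, conf in results
        if conf >= min_confidence and text.strip()
    ]
```

The output of `reader.readtext(...)` can be passed in directly. Raising `min_confidence` trades recall for fewer garbled fragments.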

**Slow Processing**
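- Cache model outputs for images that are processed repeatedly
- Run every model on the GPU when one is available (EasyOCR takes `gpu=True`; PyTorch models need `.to("cuda")`)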
- Implement batch processing for multiple images
- Consider model quantization for speed
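One way to cache outputs is to key them on a hash of the image bytes, so a re-uploaded invoice skips the vision models. `make_cached` is a hypothetical wrapper. Its cache is an unbounded in-memory dict, so a long-running service would need a size limit or an external store. The lab passes images as base64 strings, so decode them with `base64.b64decode` before calling the wrapped function.

```python
import hashlib
from typing import Callable, Dict

def make_cached(fn: Callable[[bytes], str]) -> Callable[[bytes], str]:
    """Wrap an image->text function so identical image bytes are processed once."""
    cache: Dict[str, str] = {}

    def cached(image_bytes: bytes) -> str:
        key = hashlib.sha256(image_bytes).hexdigest()  # content hash as cache key
        if key not in cache:
            cache[key] = fn(image_bytes)  # only call the model on a cache miss
        return cache[key]

    return cached
```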

### Extensions for Advanced Students

If you completed the lab early, try these enhancements:

1. **Confidence Thresholds**
   - Implement dynamic confidence thresholds
   - Route to human review for low-confidence results

2. **Ensemble Approach**
   - Combine multiple vision models for better accuracy
   - Weight results based on model strengths

3. **Image Preprocessing**
   - Add image enhancement pipeline
   - Implement automatic rotation detection

4. **Batch Processing**
   - Process multiple invoices simultaneously
   - Implement progress tracking for large batches
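For the confidence-threshold extension, a routing function can compare the state's `vision_confidence` (the field printed by Test 2) against a threshold that depends on the question. The rest is assumption: the threshold values, the keyword heuristic that makes monetary questions stricter, and the node names `generate_answer` and `human_review`.

```python
from typing import Any, Dict

# Hypothetical thresholds; calibrate them on validation invoices.
BASE_THRESHOLD = 0.6
STRICT_THRESHOLD = 0.8
CRITICAL_KEYWORDS = ("total", "amount", "due", "tax")

def confidence_threshold(question: str) -> float:
    """Demand more confidence when the question touches monetary fields."""
    q = (question or "").lower()
    if any(keyword in q for keyword in CRITICAL_KEYWORDS):
        return STRICT_THRESHOLD
    return BASE_THRESHOLD

def route_on_confidence(state: Dict[str, Any]) -> str:
    """Send low-confidence results to human review (node names are hypothetical)."""
    confidence = state.get("vision_confidence", 0.0)
    threshold = confidence_threshold(state.get("question", ""))
    return "generate_answer" if confidence >= threshold else "human_review"
```

A missing confidence defaults to 0.0, so incomplete vision results fail safe into review rather than producing an unchecked answer.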

### Next Steps

To further your multimodal AI expertise:

1. **Experiment with Different Models**
   - Try LayoutLM for document understanding
   - Explore newer vision-language models

2. **Optimize for Production**
   - Implement model quantization
   - Add comprehensive monitoring

3. **Build Domain-Specific Models**
   - Fine-tune models on your specific document types
   - Create custom classification categories
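For monitoring, the workflow results already expose a `processing_metrics` field. A context manager is one minimal way to collect per-step wall-clock timings into such a dict. `timed` and the step names are hypothetical, and repeated runs of the same step accumulate into one total.

```python
import time
from contextlib import contextmanager
from typing import Dict, Iterator

@contextmanager
def timed(metrics: Dict[str, float], step: str) -> Iterator[None]:
    """Add the seconds spent inside the block to metrics[step], even if it raises."""
    start = time.perf_counter()
    try:
        yield
    finally:
        metrics[step] = metrics.get(step, 0.0) + (time.perf_counter() - start)
```

Wrapping each node's body in `with timed(metrics, "ocr"):` and storing `metrics` in the state gives a per-request latency breakdown.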

**Congratulations!** You've built a multimodal AI system that answers text questions about invoices, processes invoice images, and handles combined text+image queries!