# Beyond Language: Transformers for Vision, Audio, and Multimodal AI

This notebook provides a hands-on exploration of how transformers have expanded beyond text to revolutionize computer vision, audio processing, and multimodal AI. We'll implement and visualize examples from Article 7, demonstrating practical applications across industries.

## What You'll Learn

1. **Vision Transformers**: How ViT, DeiT, and Swin process images
2. **Audio Processing**: Speech recognition with Whisper and audio classification
3. **Generative AI**: Creating images with Stable Diffusion XL
4. **Multimodal Models**: CLIP, BLIP, and cross-modal search
5. **Production Pipelines**: Building scalable systems with SGLang

Let's begin our journey beyond language!

## Setup and Imports

First, let's set up our environment and import the necessary libraries. We'll use the Hugging Face ecosystem for its unified API across modalities.

In [None]:
# Core imports
import sys
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Add src to path
sys.path.append('../src')

# Import our modules
from config import get_device, SAMPLE_IMAGE_URL, OUTPUT_DIR, IMAGES_DIR, AUDIO_DIR

# ML libraries
# NOTE: transformers does not export `AutoModelForSpeechRecognition`; asking
# for it makes this whole cell fail with ImportError. The auto class for
# encoder-decoder speech models such as Whisper is `AutoModelForSpeechSeq2Seq`.
from transformers import (
    AutoImageProcessor, AutoModelForImageClassification,
    AutoProcessor, AutoModel, AutoModelForSpeechSeq2Seq,
    pipeline, BlipProcessor, BlipForConditionalGeneration
)
from diffusers import StableDiffusionXLPipeline
from PIL import Image
import torch
import numpy as np
import matplotlib.pyplot as plt
import requests
from IPython.display import Audio, display
import time

# Keep the name the original cell tried to import available, so any later
# cell that refers to it still resolves.
AutoModelForSpeechRecognition = AutoModelForSpeechSeq2Seq

# Check device
device = get_device()
print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## Part 1: Vision Transformers - Teaching Transformers to See

Vision Transformers (ViT) revolutionized computer vision by applying the transformer architecture to images. Instead of sliding convolutional filters over the pixel grid, ViT divides an image into fixed-size patches and treats each patch like a word (token) in a sentence.

### How Vision Transformers Work

1. **Image → Patches**: Split image into fixed-size patches (e.g., 16×16 pixels)
2. **Patches → Embeddings**: Convert each patch to a vector representation
3. **Add Positions**: Include positional encodings so the model knows patch locations
4. **Self-Attention**: Let patches "communicate" to understand the full image
5. **Classification**: Output final prediction

Let's visualize this process:

In [None]:
# Visualize how ViT processes images
def visualize_vit_patches(image_url, patch_size=16):
    """Show an image next to the same image overlaid with a ViT patch grid.

    Downloads the image, draws it twice side by side (plain, and with red
    grid lines every ``patch_size`` pixels), then prints how many whole
    patches fit into it.

    Args:
        image_url: HTTP(S) URL of the image to download.
        patch_size: Side length in pixels of each square patch; must be > 0.

    Raises:
        ValueError: If ``patch_size`` is not positive.
        requests.HTTPError: If the server responds with an error status.
    """
    from io import BytesIO

    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")

    # Fetch the whole body with a timeout and fail loudly on HTTP errors;
    # reading `response.raw` directly would skip content decoding and would
    # happily hand an error page to PIL.
    response = requests.get(image_url, timeout=30)
    response.raise_for_status()
    # Force RGB so grayscale/palette/RGBA inputs render with true colours.
    img = Image.open(BytesIO(response.content)).convert("RGB")
    img_array = np.array(img)

    # Create figure
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # Original image
    ax1.imshow(img_array)
    ax1.set_title("Original Image", fontsize=14)
    ax1.axis('off')

    # Image with patch grid
    ax2.imshow(img_array)
    ax2.set_title(f"ViT Patches ({patch_size}×{patch_size} pixels)", fontsize=14)

    # Draw one horizontal / vertical line at every patch boundary
    h, w = img_array.shape[:2]
    for y in range(0, h, patch_size):
        ax2.axhline(y=y, color='red', linewidth=0.5, alpha=0.7)
    for x in range(0, w, patch_size):
        ax2.axvline(x=x, color='red', linewidth=0.5, alpha=0.7)

    ax2.axis('off')
    plt.tight_layout()
    plt.show()

    # Only whole patches are counted; a partial strip at the right/bottom
    # edge is dropped by the floor division.
    n_patches = (h // patch_size) * (w // patch_size)
    print(f"\nImage divided into {n_patches} patches")
    print(f"Each patch: {patch_size}×{patch_size} pixels")
    print(f"Total patches for 224×224 image: {(224//patch_size)**2}")

# Visualize patching process
visualize_vit_patches(SAMPLE_IMAGE_URL)

### Basic Image Classification with ViT

Now let's use Vision Transformer to classify an image. Notice how simple the API is—Hugging Face abstracts away the complexity.

In [None]:
def classify_image_with_vit(image_url):
    """Classify an image with ``google/vit-base-patch16-224`` and plot the top 5.

    Downloads the image, runs it through the ViT classifier on the
    module-level ``device``, and shows the input image beside a horizontal
    bar chart of the five most probable labels.

    Args:
        image_url: HTTP(S) URL of the image to classify.

    Returns:
        A ``(label, score)`` tuple for the most probable class, where
        ``score`` is the softmax probability as a NumPy scalar.

    Raises:
        requests.HTTPError: If the image download returns an error status.
    """
    from io import BytesIO

    print("Loading Vision Transformer...")

    # Load model and processor
    processor = AutoImageProcessor.from_pretrained('google/vit-base-patch16-224')
    model = AutoModelForImageClassification.from_pretrained('google/vit-base-patch16-224')

    # `.to()` is a no-op when the model is already on the target device, so
    # there is no need to special-case CPU (and no reliance on `device`
    # being a plain string).
    model = model.to(device)

    # Download with a timeout, surface HTTP errors, and normalise to RGB so
    # RGBA / grayscale / palette images do not trip up the image processor.
    response = requests.get(image_url, timeout=30)
    response.raise_for_status()
    image = Image.open(BytesIO(response.content)).convert("RGB")

    # Create visualization
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # Show image
    ax1.imshow(image)
    ax1.set_title("Input Image", fontsize=14)
    ax1.axis('off')

    # Process and predict
    print("\nProcessing image...")
    inputs = processor(images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    # Get top 5 predictions
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
    top5_probs, top5_indices = torch.topk(probs[0], 5)

    # Plot predictions
    labels = [model.config.id2label[idx.item()] for idx in top5_indices]
    scores = top5_probs.cpu().numpy()

    ax2.barh(labels, scores, color='skyblue')
    ax2.set_xlabel('Confidence Score', fontsize=12)
    ax2.set_title('Top 5 Predictions', fontsize=14)
    ax2.set_xlim(0, 1)
    # barh puts the first item at the bottom; flip so the best guess is on top.
    ax2.invert_yaxis()

    # Add percentage labels
    for i, score in enumerate(scores):
        ax2.text(score + 0.01, i, f'{score*100:.1f}%', va='center')

    plt.tight_layout()
    plt.show()

    print(f"\nTop prediction: {labels[0]} ({scores[0]*100:.1f}% confidence)")
    return labels[0], scores[0]

# Run classification
predicted_class, confidence = classify_image_with_vit(SAMPLE_IMAGE_URL)

### Comparing Modern Vision Transformers

Let's compare different vision transformer architectures to see how they've evolved. Each has unique strengths:

- **ViT**: Original, simple, effective with large datasets
- **DeiT**: Data-efficient, uses distillation, faster training
- **Swin**: Hierarchical, shifted windows, better for dense prediction tasks

In [None]:
def compare_vision_transformers(image_url):
    """Run ViT, DeiT and Swin on one image and compare confidence and speed.

    Each model is loaded, applied to the same image on the module-level
    ``device``, and its top prediction, softmax confidence and forward-pass
    time are recorded. A model that fails to load or run is reported and
    skipped so the remaining models are still compared.

    Args:
        image_url: HTTP(S) URL of the image to classify.

    Returns:
        A list of dicts with keys ``"model"``, ``"prediction"``,
        ``"confidence"`` and ``"time"`` (seconds spent in the forward pass),
        one per model that ran successfully.

    Raises:
        requests.HTTPError: If the image download returns an error status.
    """
    from io import BytesIO

    models = {
        "ViT": "google/vit-base-patch16-224",
        "DeiT": "facebook/deit-base-patch16-224",
        "Swin": "microsoft/swin-tiny-patch4-window7-224"
    }

    # Load image once (timeout + status check + RGB normalisation)
    response = requests.get(image_url, timeout=30)
    response.raise_for_status()
    image = Image.open(BytesIO(response.content)).convert("RGB")

    results = []

    print("Comparing Vision Transformer Architectures")
    print("=" * 50)

    for name, model_id in models.items():
        print(f"\n{name}:")

        # Deliberately broad: one model failing (download error, OOM, ...)
        # should not abort the comparison of the others.
        try:
            # Load model -- kept OUTSIDE the timed region so that checkpoint
            # download / disk I/O does not pollute the "inference time".
            processor = AutoImageProcessor.from_pretrained(model_id)
            model = AutoModelForImageClassification.from_pretrained(model_id)
            model = model.to(device)

            # Process image
            inputs = processor(images=image, return_tensors="pt")
            inputs = {k: v.to(device) for k, v in inputs.items()}

            # Predict. `.item()` copies the result back to the host, which
            # forces any asynchronous GPU work to finish before the clock
            # is stopped.
            start_time = time.perf_counter()
            with torch.no_grad():
                outputs = model(**inputs)
            pred_idx = outputs.logits.argmax(-1).item()
            inference_time = time.perf_counter() - start_time

            # Get prediction
            pred_label = model.config.id2label[pred_idx]
            confidence = torch.nn.functional.softmax(outputs.logits, dim=-1)[0, pred_idx].item()

            results.append({
                "model": name,
                "prediction": pred_label,
                "confidence": confidence,
                "time": inference_time
            })

            print(f"  Prediction: {pred_label}")
            print(f"  Confidence: {confidence:.3f}")
            print(f"  Time: {inference_time:.3f}s")

        except Exception as e:
            print(f"  Error: {e}")

    # Visualize comparison
    if results:
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

        models_names = [r['model'] for r in results]
        confidences = [r['confidence'] for r in results]
        times = [r['time'] for r in results]
        # One colour per bar actually drawn (some models may have failed).
        bar_colors = ['#1f77b4', '#ff7f0e', '#2ca02c'][:len(results)]

        # Confidence comparison
        ax1.bar(models_names, confidences, color=bar_colors)
        ax1.set_ylabel('Confidence Score', fontsize=12)
        ax1.set_title('Model Confidence Comparison', fontsize=14)
        ax1.set_ylim(0, 1)

        # Inference time comparison
        ax2.bar(models_names, times, color=bar_colors)
        ax2.set_ylabel('Inference Time (seconds)', fontsize=12)
        ax2.set_title('Inference Speed Comparison', fontsize=14)

        plt.tight_layout()
        plt.show()

    return results

# Compare models
comparison_results = compare_vision_transformers(SAMPLE_IMAGE_URL)

## Part 2: Audio Processing - Teaching Transformers to Listen

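Transformers have revolutionized audio processing, especially speech recognition. Models like Whisper map audio to text end-to-end: the waveform is converted to a log-mel spectrogram and a single transformer does the rest, eliminating the need for separately engineered acoustic and pronunciation models.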

### How Audio Transformers Work

1. **Audio Waveform**: Raw sound waves captured by microphone
2. **Spectrogram**: Convert to visual representation of frequencies over time
3. **Transformer Processing**: Apply self-attention to understand patterns
4. **Text Generation**: Decode audio features into text tokens

Let's visualize this process:

In [None]:
# Create sample audio for demonstration
def create_sample_audio():
    """Synthesize a 2-second noisy 440 Hz tone and save it as a 16-bit WAV.

    The signal is a 440 Hz sine plus its second harmonic and a little
    Gaussian noise, sampled at 16 kHz. It is written to
    ``OUTPUT_DIR / "sample_audio.wav"`` (the directory is created if needed).

    Returns:
        A tuple ``(audio_path, audio, sample_rate)``: the path of the WAV
        file, the float waveform as a NumPy array limited to [-1, 1], and
        the sample rate in Hz.
    """
    from pathlib import Path

    import numpy as np
    import scipy.io.wavfile as wavfile

    # Generate a simple tone
    sample_rate = 16000
    duration = 2  # seconds
    frequency = 440  # Hz (A4 note)

    # endpoint=False keeps the spacing at exactly 1/sample_rate; including
    # the endpoint would stretch the time axis and detune the tone slightly.
    t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False)
    # Create a tone with some variation
    audio = np.sin(2 * np.pi * frequency * t) * 0.5
    audio += np.sin(2 * np.pi * frequency * 2 * t) * 0.2  # Add harmonic
    audio += np.random.normal(0, 0.05, audio.shape)  # Add noise

    # Guard the int16 conversion below: a sample outside [-1, 1] would wrap
    # around instead of saturating.
    audio = np.clip(audio, -1.0, 1.0)

    # Save audio
    output_dir = Path(OUTPUT_DIR)
    output_dir.mkdir(parents=True, exist_ok=True)
    audio_path = output_dir / "sample_audio.wav"
    wavfile.write(str(audio_path), sample_rate, (audio * 32767).astype(np.int16))

    return audio_path, audio, sample_rate

# Visualize audio processing
def visualize_audio_processing():
    """Plot the stages an audio transformer pipeline goes through.

    Generates the sample tone via ``create_sample_audio`` and draws a 2x2
    figure: waveform, power spectrogram (dB), an illustrative mel-spectrogram
    panel, and a text summary of the pipeline. Finally embeds an audio player
    for the generated tone.

    Returns:
        The path of the WAV file written by ``create_sample_audio``.
    """
    from scipy import signal

    audio_path, audio, sample_rate = create_sample_audio()

    fig, axes = plt.subplots(2, 2, figsize=(12, 8))

    # 1. Waveform
    # (named `timeline` so the imported `time` module is not shadowed)
    timeline = np.arange(len(audio)) / sample_rate
    axes[0, 0].plot(timeline, audio)
    axes[0, 0].set_title('Audio Waveform', fontsize=14)
    axes[0, 0].set_xlabel('Time (s)')
    axes[0, 0].set_ylabel('Amplitude')

    # 2. Spectrogram
    frequencies, times, spectrogram = signal.spectrogram(audio, sample_rate)

    # The tiny offset keeps log10 finite should any bin have zero power.
    power_db = 10 * np.log10(spectrogram + 1e-12)
    im = axes[0, 1].pcolormesh(times, frequencies, power_db)
    axes[0, 1].set_title('Spectrogram', fontsize=14)
    axes[0, 1].set_xlabel('Time (s)')
    axes[0, 1].set_ylabel('Frequency (Hz)')
    plt.colorbar(im, ax=axes[0, 1], label='Power (dB)')

    # 3. Mel-spectrogram (what Whisper uses)
    # Simplified visualization: random values stand in for a real
    # mel-spectrogram; this panel is NOT computed from `audio`.
    mel_spec = np.random.randn(80, 100)  # Placeholder for mel-spectrogram
    axes[1, 0].imshow(mel_spec, aspect='auto', origin='lower')
    axes[1, 0].set_title('Mel-Spectrogram (Whisper Input)', fontsize=14)
    axes[1, 0].set_xlabel('Time Frames')
    axes[1, 0].set_ylabel('Mel Channels')

    # 4. Processing pipeline
    # The encoder depth is stated for whisper-base, the checkpoint used in
    # this notebook (no Whisper size has 30 layers).
    axes[1, 1].text(0.5, 0.7, 'Audio Processing Pipeline:',
                    ha='center', va='center', fontsize=16, weight='bold')
    axes[1, 1].text(0.5, 0.5, '1. Raw Audio → Mel-Spectrogram',
                    ha='center', va='center', fontsize=12)
    axes[1, 1].text(0.5, 0.4, '2. Transformer Encoder (6 layers in whisper-base)',
                    ha='center', va='center', fontsize=12)
    axes[1, 1].text(0.5, 0.3, '3. Transformer Decoder',
                    ha='center', va='center', fontsize=12)
    axes[1, 1].text(0.5, 0.2, '4. Text Token Generation',
                    ha='center', va='center', fontsize=12)
    axes[1, 1].axis('off')

    plt.tight_layout()
    plt.show()

    # Play audio
    print("\nGenerated sample audio:")
    display(Audio(audio, rate=sample_rate))

    return audio_path

# Visualize audio processing
sample_audio_path = visualize_audio_processing()

### Speech Recognition with Whisper

Whisper is OpenAI's robust speech recognition model, trained on 680,000 hours of multilingual data. Let's see it in action:

In [None]:
def demonstrate_whisper():
    """Transcribe a sample WAV file with Whisper, or explain the process.

    Looks for ``*.wav`` files in ``AUDIO_DIR``. If any exist, the first one
    (in sorted order) is transcribed with ``openai/whisper-base`` and an
    audio player is shown. Otherwise a textual walkthrough of the Whisper
    pipeline is printed. A capability summary is printed in both cases.
    """
    # Check for sample audio files first: the model is only downloaded and
    # loaded when there is actually something to transcribe. Sorting makes
    # the choice of file deterministic (glob order is not guaranteed).
    audio_files = sorted(Path(AUDIO_DIR).glob("*.wav"))

    if audio_files:
        print("Loading Whisper model...")

        # Create ASR pipeline
        asr = pipeline(
            "automatic-speech-recognition",
            model="openai/whisper-base",
            device=0 if device == "cuda" else -1
        )

        # Process real audio file
        audio_file = audio_files[0]
        print(f"\nTranscribing: {audio_file.name}")

        result = asr(str(audio_file))
        print(f"Transcription: {result['text']}")

        # Display audio
        display(Audio(str(audio_file)))
    else:
        # Demonstrate with explanation
        print("\nNo audio files found. Here's how Whisper works:")
        print("\n1. Input: Audio file (WAV, MP3, FLAC, etc.)")
        print("2. Processing: Convert to mel-spectrogram")
        # Depth quoted for whisper-base; no Whisper size has 30 layers.
        print("3. Encoding: transformer encoder (6 layers in whisper-base)")
        print("4. Decoding: Generate text tokens")
        print("5. Output: Transcribed text")

        print("\nExample transcription:")
        print("Audio: [Person speaking about AI]")
        print("Transcription: 'Artificial intelligence is transforming how we interact with technology.'")

    print("\nWhisper capabilities:")
    print("✓ 99 languages supported")
    print("✓ Robust to accents and noise")
    print("✓ Automatic language detection")
    print("✓ Timestamp generation")
    print("✓ Translation to English")

# Demonstrate Whisper
demonstrate_whisper()

### Audio Classification

Beyond speech, transformers can classify any type of sound—from music genres to environmental sounds. This has applications in security, healthcare, and smart home devices.

In [None]:
def demonstrate_audio_classification():
    """Classify sample WAV files, or show simulated use-case charts.

    If ``AUDIO_DIR`` contains ``*.wav`` files, up to two of them (in sorted
    order) are classified with ``superb/wav2vec2-base-superb-ks`` and the top
    results are plotted. Otherwise four illustrative bar charts with randomly
    generated scores are shown, one per application domain.
    """
    print("Audio Classification with Transformers")
    print("=" * 50)

    # Check for audio files before loading anything: the fallback branch
    # does not need the model at all. Sorted for a deterministic selection.
    audio_files = sorted(Path(AUDIO_DIR).glob("*.wav"))

    if audio_files:
        # Create audio classifier
        classifier = pipeline(
            "audio-classification",
            model="superb/wav2vec2-base-superb-ks",
            device=0 if device == "cuda" else -1
        )

        # Classify real audio
        for audio_file in audio_files[:2]:
            print(f"\nClassifying: {audio_file.name}")

            # Best effort per file: report the problem and move on.
            try:
                results = classifier(str(audio_file))

                # Visualize results
                top_results = results[:5]
                labels = [r['label'] for r in top_results]
                scores = [r['score'] for r in top_results]

                plt.figure(figsize=(8, 4))
                plt.barh(labels, scores, color='lightcoral')
                plt.xlabel('Confidence Score')
                plt.title(f'Audio Classification: {audio_file.name}')
                plt.xlim(0, 1)

                for i, score in enumerate(scores):
                    plt.text(score + 0.01, i, f'{score:.3f}', va='center')

                plt.tight_layout()
                plt.show()

            except Exception as e:
                print(f"Error: {e}")
    else:
        # Show example use cases
        print("\nAudio Classification Use Cases:")

        use_cases = [
            {"domain": "Security", "sounds": ["glass_breaking", "alarm", "scream"]},
            {"domain": "Healthcare", "sounds": ["cough", "snoring", "heartbeat"]},
            {"domain": "Smart Home", "sounds": ["doorbell", "dog_barking", "baby_crying"]},
            {"domain": "Industrial", "sounds": ["machinery_fault", "leak", "abnormal_vibration"]}
        ]

        fig, axes = plt.subplots(2, 2, figsize=(12, 8))
        axes = axes.ravel()

        for idx, use_case in enumerate(use_cases):
            # Simulate classification results: random scores that sum to 1,
            # not the output of any model.
            sounds = use_case["sounds"]
            scores = np.random.dirichlet(np.ones(len(sounds)) * 2)

            axes[idx].barh(sounds, scores, color=plt.cm.Set3(idx))
            axes[idx].set_xlabel('Detection Confidence')
            axes[idx].set_title(f'{use_case["domain"]} Applications')
            axes[idx].set_xlim(0, 1)

            for i, score in enumerate(scores):
                axes[idx].text(score + 0.01, i, f'{score:.2f}', va='center')

        plt.tight_layout()
        plt.show()

# Demonstrate audio classification
demonstrate_audio_classification()

## Part 3: Generative AI - Creating Images with Diffusion Models

Diffusion models like Stable Diffusion XL represent a breakthrough in generative AI. They start with random noise and progressively refine it into detailed images guided by text prompts.

### How Diffusion Works

1. **Start with Noise**: Pure random pixels
2. **Text Encoding**: Convert prompt to embeddings
3. **Iterative Denoising**: Gradually remove noise
4. **Guidance**: Text embeddings steer the process
5. **Final Image**: High-quality result after ~25-50 steps

In [None]:
def visualize_diffusion_process():
    """Draw five snapshots of a toy denoising run and list the model parts.

    A Gaussian blob plays the role of the target image. For each displayed
    step the blob is blended with random noise whose strength falls linearly
    from 1 at step 0 to 0 at step 25. A short summary of the components of a
    diffusion model is printed afterwards.
    """
    displayed_steps = (0, 5, 10, 20, 25)
    captions = {0: 'Pure Noise', 25: 'Final Image'}

    # The clean target is identical for every panel, so build it once.
    axis_values = np.linspace(-1, 1, 100)
    x, y = np.meshgrid(axis_values, np.linspace(-1, 1, 100))
    clean_image = np.exp(-(x**2 + y**2) / 0.2)  # Simple gaussian

    fig, axes = plt.subplots(1, 5, figsize=(15, 3))

    for ax, step in zip(axes, displayed_steps):
        # Fraction of the picture that is still noise at this step
        noise_level = 1.0 - (step / 25)
        noise = np.random.randn(100, 100) * noise_level

        ax.imshow(clean_image * (1 - noise_level) + noise, cmap='viridis')
        ax.set_title(f'Step {step}', fontsize=12)
        ax.axis('off')

        # Only the first and last panels get a caption underneath
        if step in captions:
            ax.text(50, 110, captions[step], ha='center', fontsize=10)

    plt.suptitle('Diffusion Process: From Noise to Image', fontsize=16)
    plt.tight_layout()
    plt.show()

    summary = (
        "\nDiffusion Model Components:",
        "1. Text Encoder: Converts prompt to embeddings",
        "2. U-Net: Predicts noise to remove at each step",
        "3. Scheduler: Controls the denoising process",
        "4. VAE Decoder: Converts latents to final image",
    )
    for line in summary:
        print(line)

# Visualize diffusion
visualize_diffusion_process()

### Generating Images with Stable Diffusion XL

Let's generate images from text prompts. Note: This requires significant GPU memory. We'll show the process even if generation isn't possible on your hardware.

In [None]:
def demonstrate_image_generation():
    """Generate three images with SDXL, or fall back to prompt-writing tips.

    Generation is attempted only when ``device`` equals ``"cuda"`` and GPU 0
    reports more than 8e9 bytes of memory. If generation is not attempted, or
    raises, a list of prompt-engineering tips and a small "anatomy of a
    prompt" figure are shown instead.
    """
    print("Text-to-Image Generation with Stable Diffusion XL")
    print("=" * 50)

    prompts = [
        {
            "prompt": "A serene Japanese garden with cherry blossoms and a wooden bridge, highly detailed digital art",
            "style": "peaceful_landscape"
        },
        {
            "prompt": "A friendly robot teaching mathematics to children in a colorful classroom, cartoon style",
            "style": "educational_illustration"
        },
        {
            "prompt": "An astronaut riding a horse on Mars, photorealistic, dramatic lighting",
            "style": "surreal_concept"
        }
    ]

    # Check if we can actually generate (the GPU is only queried on CUDA)
    gpu_ready = False
    if device == "cuda":
        gpu_ready = torch.cuda.get_device_properties(0).total_memory > 8e9

    if gpu_ready:
        try:
            print("Loading Stable Diffusion XL...")
            from diffusion_models import generate_image

            fig, axes = plt.subplots(1, 3, figsize=(15, 5))

            for spec, ax in zip(prompts, axes):
                style_name = spec['style']
                print(f"\nGenerating: {style_name}...")

                generated = generate_image(spec["prompt"], num_inference_steps=25)

                ax.imshow(generated)
                ax.set_title(style_name.replace('_', ' ').title(), fontsize=12)
                ax.axis('off')

            plt.tight_layout()
            plt.show()
        except Exception as e:
            # Fall through to the tips below
            print(f"\nGeneration failed: {e}")
        else:
            # Images were produced; nothing more to show
            return

    # Show prompt engineering tips instead
    tip_lines = (
        "\nPrompt Engineering Best Practices:",
        "\n1. Be Specific:",
        "   ❌ 'A dog'",
        "   ✅ 'A golden retriever puppy playing in autumn leaves, soft lighting'",
        "\n2. Include Style Modifiers:",
        "   • 'digital art', 'oil painting', 'photorealistic'",
        "   • 'studio lighting', 'golden hour', 'dramatic shadows'",
        "   • 'highly detailed', '4k', 'award winning'",
        "\n3. Use Negative Prompts:",
        "   • 'blurry, low quality, distorted'",
        "   • 'extra limbs, bad anatomy' (for people)",
        "   • 'watermark, text, logo'",
    )
    for line in tip_lines:
        print(line)

    # Visualize prompt structure
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.text(0.5, 0.9, 'Anatomy of a Good Prompt', ha='center', fontsize=16, weight='bold')

    components = [
        "Subject: 'A majestic eagle'",
        "Action: 'soaring through'",
        "Setting: 'mountain peaks at sunset'",
        "Style: 'photorealistic, National Geographic style'",
        "Quality: 'highly detailed, 8k resolution'"
    ]

    # One line of text per component, stacked downwards
    for i, component in enumerate(components):
        ax.text(0.1, 0.7 - i*0.12, component, fontsize=12)

    ax.axis('off')
    plt.tight_layout()
    plt.show()

# Demonstrate image generation
demonstrate_image_generation()

## Part 4: Multimodal Models - Connecting Vision and Language

Multimodal models like CLIP and BLIP understand both images and text, enabling powerful applications like image search, captioning, and visual question answering.

### How CLIP Works

CLIP (Contrastive Language-Image Pretraining) learns to embed images and text in the same vector space:

1. **Dual Encoders**: Separate networks for images and text
2. **Shared Space**: Both produce 512-dimensional vectors
3. **Contrastive Learning**: Matching pairs pulled together, mismatches pushed apart
4. **Zero-Shot**: Can classify images using any text description

In [None]:
def visualize_clip_embedding_space():
    """Plot a synthetic 2D picture of a shared image/text embedding space.

    Random 50-dimensional points are generated around three concept centres
    (cat, dog, car): five "image" points and three tighter "text" points per
    concept. They are projected to 2D with PCA and drawn with circles for
    images, triangles for texts, and a dashed ring around each image cluster.
    The global NumPy seed is set to 42 so the figure is reproducible.
    """
    from sklearn.decomposition import PCA
    import matplotlib.patches as patches

    # Create synthetic embeddings for visualization
    np.random.seed(42)

    n_images, n_texts = 5, 3
    # (legend name, plot colour, first two coordinates of the cluster centre)
    concepts = [
        ("Cat", 'red', [2, 0]),
        ("Dog", 'blue', [-2, 0]),
        ("Car", 'green', [0, 2]),
    ]
    centres = [np.array(head + [0]*48) for _, _, head in concepts]

    # Image embeddings first, then text embeddings (texts are drawn with
    # half the spread so they sit near their concept's images).
    image_groups = [np.random.randn(n_images, 50) + c for c in centres]
    text_groups = [np.random.randn(n_texts, 50) * 0.5 + c for c in centres]

    # Combine all embeddings and reduce to 2D for visualization
    all_embeddings = np.vstack(image_groups + text_groups)
    embeddings_2d = PCA(n_components=2).fit_transform(all_embeddings)

    # Row ranges of each group inside the stacked array
    image_rows = [slice(k * n_images, (k + 1) * n_images)
                  for k in range(len(concepts))]
    text_start = n_images * len(concepts)
    text_rows = [slice(text_start + k * n_texts, text_start + (k + 1) * n_texts)
                 for k in range(len(concepts))]

    # Plot
    fig, ax = plt.subplots(figsize=(10, 8))

    # Plot image embeddings
    for (name, color, _), rows in zip(concepts, image_rows):
        points = embeddings_2d[rows]
        ax.scatter(points[:, 0], points[:, 1],
                   c=color, s=100, marker='o', label=f'{name} Images', alpha=0.6)

    # Plot text embeddings
    for (name, color, _), rows in zip(concepts, text_rows):
        points = embeddings_2d[rows]
        ax.scatter(points[:, 0], points[:, 1],
                   c=color, s=100, marker='^', label=f'{name} Texts', alpha=0.8)

    # Add cluster circles, centred on the mean of each image group
    for (_, color, _), rows in zip(concepts, image_rows):
        ring = patches.Circle(embeddings_2d[rows].mean(0), 1.5, fill=False,
                              edgecolor=color, linewidth=2, linestyle='--', alpha=0.5)
        ax.add_patch(ring)

    ax.set_xlabel('Dimension 1', fontsize=12)
    ax.set_ylabel('Dimension 2', fontsize=12)
    ax.set_title('CLIP Embedding Space (Simplified 2D Visualization)', fontsize=16)
    ax.legend(loc='best')
    ax.grid(True, alpha=0.3)

    # Add annotations
    ax.annotate('Images and text describing\nthe same concept cluster together',
                xy=(0, -2), xytext=(2, -3.5),
                arrowprops=dict(arrowstyle='->', lw=1.5),
                fontsize=11, ha='center')

    plt.tight_layout()
    plt.show()

    for line in (
        "\nKey Insights:",
        "• Images and matching text descriptions are embedded nearby",
        "• This enables zero-shot classification and search",
        "• The actual CLIP space is 512-dimensional",
        "• Trained on 400 million image-text pairs",
    ):
        print(line)

# Visualize CLIP embeddings
visualize_clip_embedding_space()

### Building a Multimodal Search Engine

Let's build a simple image search engine that finds images based on text descriptions. This demonstrates the power of multimodal embeddings.

In [None]:
def build_multimodal_search_demo():
    """Run a tiny text-to-image search over solid-colour images using CLIP.

    Five single-colour 224x224 images are embedded with
    ``openai/clip-vit-base-patch16``. Four text queries are embedded the same
    way, the images are ranked by the dot product of the L2-normalised
    features, and the top three hits per query are drawn as coloured
    rectangles with their description and score.
    """
    print("Building Multimodal Search Engine with CLIP")
    print("=" * 50)

    # Load CLIP model
    print("\nLoading CLIP model...")
    model = AutoModel.from_pretrained("openai/clip-vit-base-patch16")
    processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch16")
    model.eval()

    on_accelerator = device != "cpu"
    if on_accelerator:
        model = model.to(device)

    def to_model_device(batch):
        """Move every tensor of a processor output to the model's device."""
        if on_accelerator:
            return {k: v.to(device) for k, v in batch.items()}
        return batch

    # Create sample images for demo
    print("\nCreating sample image collection...")

    # Generate synthetic images with descriptions
    sample_data = [
        {"desc": "Red sports car", "color": "red", "object": "car"},
        {"desc": "Blue ocean waves", "color": "blue", "object": "ocean"},
        {"desc": "Green forest", "color": "green", "object": "forest"},
        {"desc": "Yellow sunflower", "color": "yellow", "object": "flower"},
        {"desc": "Purple mountain sunset", "color": "purple", "object": "mountain"}
    ]

    # Each "photo" is just a flat square filled with the item's colour
    images = [Image.new('RGB', (224, 224), color=item['color'])
              for item in sample_data]

    # Embed images
    print("\nEmbedding images...")
    image_batch = to_model_device(
        processor(images=images, return_tensors="pt", padding=True)
    )

    with torch.no_grad():
        image_features = model.get_image_features(**image_batch)
        # Unit-length rows, so the dot product below is a cosine similarity
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

    def search_images(query):
        """Return (scores, image indices), best match first, for a text query."""
        text_batch = to_model_device(
            processor(text=[query], return_tensors="pt", padding=True)
        )

        with torch.no_grad():
            text_features = model.get_text_features(**text_batch)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # One similarity per image, then order from most to least similar
        similarities = (image_features @ text_features.T).squeeze(1)
        scores, indices = torch.sort(similarities, descending=True)

        return scores.cpu().numpy(), indices.cpu().numpy()

    # Test queries
    queries = [
        "something red",
        "nature scene",
        "bright yellow object",
        "water and waves"
    ]

    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    for ax, query in zip(axes.ravel(), queries):
        scores, ranking = search_images(query)

        ax.set_title(f'Query: "{query}"', fontsize=14, pad=10)

        # Show top 3 results, laid out left to right
        y_pos = 0.3
        for place in range(3):
            hit = sample_data[ranking[place]]
            x_pos = place * 0.3 + 0.1

            # Draw image placeholder
            ax.add_patch(plt.Rectangle((x_pos, y_pos), 0.2, 0.3,
                                       facecolor=hit['color'],
                                       edgecolor='black', linewidth=2))

            # Description and score below the rectangle
            ax.text(x_pos + 0.1, y_pos - 0.05, hit['desc'],
                    ha='center', fontsize=10)
            ax.text(x_pos + 0.1, y_pos - 0.1, f'Score: {scores[place]:.3f}',
                    ha='center', fontsize=9, style='italic')

            # Rank label above it
            ax.text(x_pos + 0.1, y_pos + 0.35, f'#{place+1}',
                    ha='center', fontsize=12, weight='bold')

        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.axis('off')

    plt.suptitle('Multimodal Search Results', fontsize=16)
    plt.tight_layout()
    plt.show()

    for line in (
        "\n✨ Search engine ready! The system can:",
        "• Find images matching natural language queries",
        "• Work with any text description (zero-shot)",
        "• Rank results by semantic similarity",
        "• Scale to millions of images with vector databases",
    ):
        print(line)

# Build search engine
build_multimodal_search_demo()

### Advanced Multimodal Models: BLIP and Beyond

While CLIP excels at understanding image-text relationships, newer models like BLIP, BLIP-2, and LLaVA can generate captions and answer questions about images.

In [None]:
def demonstrate_advanced_multimodal():
    """Plot a timeline of multimodal models and print example applications.

    Each model gets its own row: a labelled box placed horizontally by its
    release year (as an offset from 2021), with the model's capabilities
    listed as bullet text to the right of the box. After the figure is
    shown, practical use cases are printed grouped by industry.

    Takes no arguments and returns ``None``; the only effects are the
    matplotlib figure and the console output.
    """
    print("Advanced Multimodal Models")
    print("=" * 50)

    # Origin of the x-axis; keep in sync with the 'Years since 2021' label.
    base_year = 2021

    # Create comparison visualization
    fig, ax = plt.subplots(figsize=(12, 8))

    # Model timeline and capabilities
    models = [
        {"name": "CLIP", "year": 2021, "capabilities": ["Image-Text Matching", "Zero-shot Classification"]},
        {"name": "BLIP", "year": 2022, "capabilities": ["Image Captioning", "VQA", "Image-Text Matching"]},
        {"name": "BLIP-2", "year": 2023, "capabilities": ["Efficient Architecture", "Better Captioning", "Instruction Following"]},
        {"name": "LLaVA", "year": 2023, "capabilities": ["Visual Reasoning", "Detailed Descriptions", "Multi-turn Dialogue"]},
        {"name": "GPT-4V", "year": 2023, "capabilities": ["Complex Reasoning", "OCR", "Multi-image Understanding"]}
    ]

    # One row per model; the row index doubles as the y coordinate.
    for y_pos, model in enumerate(models):
        x_left = model["year"] - base_year

        # Model box
        rect = plt.Rectangle((x_left, y_pos - 0.3), 0.8, 0.6,
                             facecolor='lightblue', edgecolor='navy', linewidth=2)
        ax.add_patch(rect)

        # Model name, centred in the 0.8-wide box
        ax.text(x_left + 0.4, y_pos, model["name"],
                ha='center', va='center', fontsize=12, weight='bold')

        # Capabilities, stacked downwards in 0.2 steps to the right of the box
        for j, cap in enumerate(model["capabilities"]):
            ax.text(x_left + 1.2, y_pos + 0.2 - j*0.2,
                    f"• {cap}", fontsize=9)

    ax.set_xlim(-0.5, 5)
    ax.set_ylim(-0.5, len(models) - 0.5)
    ax.set_xlabel('Years since 2021', fontsize=12)
    ax.set_ylabel('Models', fontsize=12)
    ax.set_title('Evolution of Multimodal Models', fontsize=16)
    ax.grid(True, axis='x', alpha=0.3)

    # Rows are identified by the text inside each box, so y ticks add nothing.
    ax.set_yticks([])

    plt.tight_layout()
    plt.show()

    # Example use cases
    print("\nPractical Applications:")

    use_cases = {
        "E-commerce": [
            "Generate product descriptions from images",
            "Answer customer questions about products",
            "Visual search for similar items"
        ],
        "Healthcare": [
            "Describe medical images for reports",
            "Answer questions about X-rays or scans",
            "Assist in diagnostic workflows"
        ],
        "Accessibility": [
            "Describe images for visually impaired users",
            "Answer questions about visual content",
            "Navigate physical spaces with camera input"
        ],
        "Education": [
            "Explain diagrams and charts",
            "Answer questions about educational images",
            "Create study materials from visuals"
        ]
    }

    for domain, apps in use_cases.items():
        print(f"\n{domain}:")
        for app in apps:
            print(f"  • {app}")

# Demonstrate advanced multimodal
demonstrate_advanced_multimodal()

## Part 5: Production Deployment with SGLang

SGLang (Structured Generation Language) enables efficient deployment of multimodal pipelines. It provides graph-based orchestration, automatic optimization, and scalable serving.

### Building a Customer Support Pipeline

Let's design a production pipeline that processes customer screenshots and voice messages to generate support tickets.

In [None]:
def visualize_sglang_pipeline():
    """Draw the customer-support pipeline diagram and print an SGLang sketch.

    The figure shows seven rounded boxes (two inputs, three models, one
    combiner, one output) joined by curved arrows, three italic edge
    annotations, and three gold "optimization" badges. After the figure is
    shown, an illustrative SGLang pipeline definition and a list of
    production features are printed.

    Takes no arguments and returns ``None``; the only effects are the
    matplotlib figure and the console output.
    """
    import matplotlib.patches as mpatches
    from matplotlib.patches import FancyBboxPatch, FancyArrowPatch

    fig, ax = plt.subplots(figsize=(12, 8))

    # Define nodes ("pos" is the centre of the box in data coordinates)
    nodes = [
        {"name": "Image Input", "pos": (1, 7), "color": "lightblue", "type": "input"},
        {"name": "Audio Input", "pos": (1, 5), "color": "lightblue", "type": "input"},
        {"name": "Image Classifier\n(ViT/CLIP)", "pos": (4, 7), "color": "lightgreen", "type": "model"},
        {"name": "Audio Transcriber\n(Whisper)", "pos": (4, 5), "color": "lightgreen", "type": "model"},
        {"name": "Context Combiner", "pos": (7, 6), "color": "lightyellow", "type": "processor"},
        {"name": "LLM Summarizer\n(Llama 2)", "pos": (10, 6), "color": "lightcoral", "type": "model"},
        {"name": "Ticket Output", "pos": (13, 6), "color": "lightgray", "type": "output"}
    ]

    # Draw nodes, keeping each box patch so the arrows can be clipped to it.
    boxes = []
    for node in nodes:
        cx, cy = node["pos"]
        box = FancyBboxPatch(
            (cx - 0.8, cy - 0.3),
            1.6, 0.6,
            boxstyle="round,pad=0.1",
            facecolor=node["color"],
            edgecolor="black",
            linewidth=2
        )
        ax.add_patch(box)
        boxes.append(box)
        ax.text(cx, cy, node["name"],
                ha='center', va='center', fontsize=10, weight='bold')

    # Draw connections as (source index, target index) pairs into ``nodes``
    connections = [
        (0, 2),  # Image -> Classifier
        (1, 3),  # Audio -> Transcriber
        (2, 4),  # Classifier -> Combiner
        (3, 4),  # Transcriber -> Combiner
        (4, 5),  # Combiner -> Summarizer
        (5, 6)   # Summarizer -> Output
    ]

    for src, dst in connections:
        # patchA/patchB trim the arrow at the box outlines. Without them the
        # arrow runs centre-to-centre and its head ends inside the target
        # box, behind the label.
        arrow = FancyArrowPatch(
            nodes[src]["pos"], nodes[dst]["pos"],
            patchA=boxes[src],
            patchB=boxes[dst],
            connectionstyle="arc3,rad=0.1",
            arrowstyle="->",
            mutation_scale=20,
            linewidth=2,
            color="darkblue"
        )
        ax.add_patch(arrow)

    # Add annotations
    ax.text(2.5, 7.5, "Screenshot", fontsize=9, style='italic')
    ax.text(2.5, 4.5, "Voice Message", fontsize=9, style='italic')
    ax.text(11.5, 6.5, "Support Ticket", fontsize=9, style='italic')

    # Add optimization badges
    optimizations = [
        {"text": "Quantized\n(AWQ)", "pos": (4, 8)},
        {"text": "Batched\nInference", "pos": (7, 7)},
        {"text": "Cached\nEmbeddings", "pos": (10, 7)}
    ]

    for opt in optimizations:
        bx, by = opt["pos"]
        badge = mpatches.Rectangle(
            (bx - 0.5, by - 0.2),
            1, 0.4,
            facecolor="gold",
            edgecolor="orange",
            linewidth=1
        )
        ax.add_patch(badge)
        ax.text(bx, by, opt["text"],
                ha='center', va='center', fontsize=8)

    ax.set_xlim(0, 14)
    ax.set_ylim(4, 8.5)
    ax.set_title('SGLang Multimodal Pipeline for Customer Support', fontsize=16)
    ax.axis('off')

    plt.tight_layout()
    plt.show()

    # Show SGLang code structure. The snippet is display text only and is
    # never executed. It is assembled from explicit lines so that its
    # whitespace-only lines survive editors that strip trailing spaces.
    pipeline_source = "\n".join([
        "",
        "@sgl.function",
        "def support_pipeline(s, image, audio):",
        "    # Parallel processing of inputs",
        "    s_img = classify_image.run(image=image)",
        "    s_audio = transcribe_audio.run(audio=audio)",
        "    ",
        "    # Combine results",
        '    image_class = s_img["classification"]',
        '    transcript = s_audio["text"]',
        "    ",
        "    # Generate summary",
        "    s = summarize_ticket(s, image_class, transcript)",
        "    return s",
        "",
    ])
    print("\nSGLang Pipeline Definition:")
    print(pipeline_source)

    print("\nProduction Features:")
    print("✓ Automatic batching for concurrent requests")
    print("✓ Model quantization (4-bit, 8-bit) for efficiency")
    print("✓ Request caching and deduplication")
    print("✓ Health checks and monitoring endpoints")
    print("✓ Horizontal scaling with load balancing")

# Visualize pipeline
visualize_sglang_pipeline()

### Performance Optimization Strategies

Let's explore key optimization techniques for production multimodal systems:

In [None]:
def demonstrate_optimization_strategies():
    """Chart quantization and throughput trade-offs, then print best practices.

    Left panel: grouped bars of memory use per model at FP32, INT8 and INT4.
    Right panel: throughput against batch size for three serving setups.
    All numbers are hard-coded in this function rather than measured.

    The memory axis is logarithmic because the values span roughly three
    orders of magnitude (0.04 GB to 28 GB); on a linear axis the three
    sub-gigabyte models would be indistinguishable from zero.

    Takes no arguments and returns ``None``; the only effects are the
    matplotlib figure and the console output.
    """

    # Create comparison chart
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

    # Memory usage comparison: (legend label, GB per model, bar colour)
    models = ['ViT-Base', 'CLIP-Base', 'Whisper-Base', 'Llama-7B']
    memory_series = [
        ('FP32', [0.35, 0.43, 0.39, 28.0], '#ff7f0e'),
        ('INT8', [0.09, 0.11, 0.10, 7.0], '#2ca02c'),
        ('INT4', [0.04, 0.05, 0.05, 3.5], '#1f77b4'),
    ]

    x = np.arange(len(models))
    width = 0.25

    # Slots -1, 0, +1 place the three bars left of, on, and right of each tick.
    for slot, (label, gigabytes, colour) in enumerate(memory_series, start=-1):
        ax1.bar(x + slot * width, gigabytes, width, label=label, color=colour)

    ax1.set_yscale('log')  # keeps the sub-GB models visible next to Llama-7B
    ax1.set_xlabel('Models', fontsize=12)
    ax1.set_ylabel('Memory Usage (GB, log scale)', fontsize=12)
    ax1.set_title('Model Quantization Impact', fontsize=14)
    ax1.set_xticks(x)
    ax1.set_xticklabels(models)
    ax1.legend()
    ax1.grid(True, axis='y', alpha=0.3)

    # Throughput comparison: (legend label, requests/sec per batch size, line format)
    batch_sizes = [1, 4, 8, 16, 32]
    throughput_series = [
        ('Single Model', [10, 35, 65, 120, 200], 'o-'),
        ('With Batching', [10, 40, 75, 145, 280], 's-'),
        ('With Caching', [15, 60, 120, 230, 420], '^-'),
    ]

    for label, requests_per_sec, fmt in throughput_series:
        ax2.plot(batch_sizes, requests_per_sec, fmt, label=label, linewidth=2)

    ax2.set_xlabel('Batch Size', fontsize=12)
    ax2.set_ylabel('Throughput (requests/sec)', fontsize=12)
    ax2.set_title('Optimization Techniques Impact', fontsize=14)
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Best practices
    print("\nProduction Deployment Best Practices:")
    print("\n1. Model Optimization:")
    print("   • Quantization: Reduce precision (FP16 → INT8 → INT4)")
    print("   • Pruning: Remove unnecessary weights")
    print("   • Distillation: Train smaller models from larger ones")

    print("\n2. Inference Optimization:")
    print("   • Batching: Process multiple requests together")
    print("   • Caching: Store frequent computations")
    print("   • Streaming: Return partial results as available")

    print("\n3. Infrastructure:")
    print("   • GPU pooling: Share GPUs across services")
    print("   • Auto-scaling: Scale based on load")
    print("   • Load balancing: Distribute requests evenly")

    print("\n4. Monitoring:")
    print("   • Latency tracking: P50, P95, P99")
    print("   • Error rates: Track failures and retries")
    print("   • Resource usage: GPU, memory, network")

# Show optimization strategies
demonstrate_optimization_strategies()

## Summary and Next Steps

Congratulations! You've explored how transformers have expanded beyond language to revolutionize:

1. **Vision**: ViT, DeiT, and Swin process images as sequences of patches
2. **Audio**: Whisper and Wav2Vec enable robust speech recognition
3. **Generation**: Diffusion models create images from text
4. **Multimodal**: CLIP and BLIP connect different modalities
5. **Production**: SGLang enables scalable deployment

### Key Takeaways

- **Unified Architecture**: Transformers work across all modalities
- **Pretrained Models**: Leverage existing models for your tasks
- **Simple APIs**: Hugging Face makes advanced AI accessible
- **Production Ready**: Modern tools enable scalable deployment

### What's Next?

1. **Experiment**: Try different models on your own data
2. **Fine-tune**: Adapt models to your specific domain
3. **Build**: Create multimodal applications
4. **Deploy**: Use SGLang or similar tools for production

The future is multimodal—and you're now equipped to build it!

In [None]:
# Final visualization: The Multimodal AI Landscape
def create_final_summary():
    """Render the closing "Multimodal AI Ecosystem" diagram and print resources.

    The diagram is a hub-and-spoke layout: a gold "Multimodal AI" hub at the
    origin, five modality nodes joined to it by dashed lines, three example
    model names fanned out beside each node, and eight application areas
    spaced evenly on a circle of radius 3.5. After the figure is shown, a
    short list of learning resources is printed.

    Takes no arguments and returns ``None``.
    """
    fig, ax = plt.subplots(figsize=(12, 8))

    # Hub
    ax.scatter([0], [0], s=2000, c='gold', edgecolors='black', linewidth=3, zorder=5)
    ax.text(0, 0, 'Multimodal\nAI', ha='center', va='center', fontsize=16, weight='bold')

    # Spokes: (label, centre, fill colour, example models)
    spokes = [
        ("Vision", (-2, 2), "lightblue", ["ViT", "DeiT", "Swin"]),
        ("Audio", (2, 2), "lightgreen", ["Whisper", "Wav2Vec", "CLAP"]),
        ("Text", (0, -2.5), "lightcoral", ["BERT", "GPT", "T5"]),
        ("Generation", (-2, -2), "plum", ["DALL-E", "Stable Diffusion", "Midjourney"]),
        ("Cross-Modal", (2, -2), "lightyellow", ["CLIP", "BLIP", "LLaVA"]),
    ]
    fan_radius = 0.7  # distance of each example label from its fan origin

    for label, (cx, cy), fill, examples in spokes:
        # Dashed link back to the hub
        ax.plot([0, cx], [0, cy], 'k--', alpha=0.3, linewidth=2)

        # The modality node itself
        ax.scatter(cx, cy, s=1500, c=fill,
                   edgecolors='black', linewidth=2, zorder=3)
        ax.text(cx, cy, label,
                ha='center', va='center', fontsize=14, weight='bold')

        # Example models fan out at -30, 0 and +30 degrees, shifted 0.8 down
        for slot, example in enumerate(examples):
            theta = np.pi/6 * (slot - 1)
            label_x = cx + fan_radius * np.cos(theta)
            label_y = cy + fan_radius * np.sin(theta) - 0.8
            ax.text(label_x, label_y, example, ha='center', va='center', fontsize=9,
                    bbox=dict(boxstyle="round,pad=0.3", facecolor='white', alpha=0.7))

    # Application areas, evenly spaced around the outer ring
    applications = (
        "Healthcare", "E-commerce", "Education", "Entertainment",
        "Security", "Accessibility", "Research", "Creative Arts",
    )
    sector_count = len(applications)

    for k, area in enumerate(applications):
        bearing = 2 * np.pi * k / sector_count
        ax.text(3.5 * np.cos(bearing), 3.5 * np.sin(bearing), area,
                ha='center', va='center', fontsize=11,
                style='italic', color='darkblue')

    ax.set_xlim(-4.5, 4.5)
    ax.set_ylim(-4.5, 4.5)
    ax.set_aspect('equal')
    ax.axis('off')
    ax.set_title('The Multimodal AI Ecosystem', fontsize=18, pad=20)

    plt.tight_layout()
    plt.show()

    closing_lines = (
        "\n🎉 You're now ready to build with multimodal AI!",
        "\nResources for continued learning:",
        "• Hugging Face Model Hub: huggingface.co/models",
        "• Papers With Code: paperswithcode.com",
        "• Course Materials: huggingface.co/course",
        "• Community: discuss.huggingface.co",
    )
    for line in closing_lines:
        print(line)

# Create final summary
create_final_summary()