# Multimodal Production Systems

This notebook covers production-grade multimodal ML systems, aimed at ML engineers who build and serve models that combine vision, text, audio, and video.

## Topics Covered
1. **Vision-Language Models** - CLIP-style architectures
2. **Multimodal Embeddings** - Cross-modal representations
3. **Image-Text Retrieval** - Production search systems
4. **Audio Processing** - Speech and audio ML
5. **Video Understanding** - Temporal multimodal models
6. **Production Deployment** - Serving multimodal systems

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from dataclasses import dataclass, field
from enum import Enum
from abc import ABC, abstractmethod
import json
import time
from datetime import datetime
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

## 1. Vision-Language Models (CLIP-style)

Contrastive Language-Image Pre-training for unified vision-text understanding.
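
As a quick numerical illustration of the contrastive objective, the sketch below computes the symmetric cross-entropy over an image-text similarity matrix with NumPy. The helper names `log_softmax` and `symmetric_contrastive_loss` and the toy embeddings are assumptions for illustration, not part of the model defined in the next cell. One useful sanity check: when every logit is equal, the loss is ln(N) for a batch of N pairs.

```python
import numpy as np


def log_softmax(logits: np.ndarray, axis: int) -> np.ndarray:
    """Numerically stable log-softmax along the given axis."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


def symmetric_contrastive_loss(
    image_embeds: np.ndarray,
    text_embeds: np.ndarray,
    temperature: float = 0.07,
) -> float:
    """Average of image-to-text and text-to-image cross-entropy.

    Row i of each matrix is assumed to be a matching image-text pair.
    """
    # L2-normalize so the dot product is cosine similarity
    image_embeds = image_embeds / np.linalg.norm(image_embeds, axis=1, keepdims=True)
    text_embeds = text_embeds / np.linalg.norm(text_embeds, axis=1, keepdims=True)

    logits = image_embeds @ text_embeds.T / temperature  # (N, N)
    diag = np.arange(logits.shape[0])

    # Matching pairs sit on the diagonal in both directions
    loss_i2t = -log_softmax(logits, axis=1)[diag, diag].mean()
    loss_t2i = -log_softmax(logits, axis=0)[diag, diag].mean()
    return float((loss_i2t + loss_t2i) / 2)


# Perfectly aligned, mutually orthogonal pairs give a loss near zero
aligned = np.eye(4)
loss_aligned = symmetric_contrastive_loss(aligned, aligned)

# Identical embeddings for every item give uniform logits, so loss = ln(4)
constant = np.ones((4, 8))
loss_uniform = symmetric_contrastive_loss(constant, constant)

print(f"aligned: {loss_aligned:.6f}, uniform: {loss_uniform:.6f}, ln(4): {np.log(4):.6f}")
```

A loss that stays close to ln(N) during training usually means the two encoders are not yet separating matched from unmatched pairs.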

In [None]:
class VisionEncoder(nn.Module):
    """
    Vision Transformer (ViT) encoder for images.
    Simplified implementation for demonstration.
    """
    
    def __init__(
        self,
        image_size: int = 224,
        patch_size: int = 16,
        in_channels: int = 3,
        embed_dim: int = 768,
        num_heads: int = 12,
        num_layers: int = 12,
        projection_dim: int = 512
    ):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        num_patches = (image_size // patch_size) ** 2
        
        # Patch embedding
        self.patch_embed = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )
        
        # Class token and position embedding
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        
        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim * 4,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        
        # Layer norm and projection
        self.ln = nn.LayerNorm(embed_dim)
        self.projection = nn.Linear(embed_dim, projection_dim)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass: image -> embedding"""
        batch_size = x.shape[0]
        
        # Patch embedding: (B, C, H, W) -> (B, num_patches, embed_dim)
        x = self.patch_embed(x)
        x = x.flatten(2).transpose(1, 2)
        
        # Add class token
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        
        # Add position embedding
        x = x + self.pos_embed
        
        # Transformer
        x = self.transformer(x)
        
        # Get CLS token output and project
        x = self.ln(x[:, 0])
        x = self.projection(x)
        
        return x


class TextEncoder(nn.Module):
    """
    Transformer encoder for text.
    """
    
    def __init__(
        self,
        vocab_size: int = 50000,
        max_seq_len: int = 77,
        embed_dim: int = 512,
        num_heads: int = 8,
        num_layers: int = 12,
        projection_dim: int = 512
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len
        
        # Token and position embeddings
        self.token_embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Parameter(torch.zeros(1, max_seq_len, embed_dim))
        
        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim * 4,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        
        # Layer norm and projection
        self.ln = nn.LayerNorm(embed_dim)
        self.projection = nn.Linear(embed_dim, projection_dim)
    
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """Forward pass: tokens -> embedding"""
        seq_len = input_ids.shape[1]
        
        # Token + position embedding
        x = self.token_embed(input_ids)
        x = x + self.pos_embed[:, :seq_len, :]
        
        # Create attention mask for transformer
        if attention_mask is not None:
            # Convert to transformer format (True = ignore)
            src_key_padding_mask = (attention_mask == 0)
        else:
            src_key_padding_mask = None
        
        # Transformer
        x = self.transformer(x, src_key_padding_mask=src_key_padding_mask)
        
        # Pool: take EOS token (last non-padded token)
        if attention_mask is not None:
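            # Get position of last real token; cast to long because a float
            # mask (e.g. torch.ones) cannot be used as an index
            seq_lengths = attention_mask.sum(dim=1).long() - 1
            pooled = x[torch.arange(x.shape[0], device=x.device), seq_lengths]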
        else:
            pooled = x[:, -1]
        
        # Layer norm and project
        pooled = self.ln(pooled)
        pooled = self.projection(pooled)
        
        return pooled


class CLIPModel(nn.Module):
    """
    CLIP-style vision-language model.
    """
    
    def __init__(
        self,
        vision_config: Dict[str, Any] = None,
        text_config: Dict[str, Any] = None,
        projection_dim: int = 512,
        temperature: float = 0.07
    ):
        super().__init__()
        
        vision_config = vision_config or {}
        text_config = text_config or {}
        
        self.vision_encoder = VisionEncoder(
            projection_dim=projection_dim,
            **vision_config
        )
        self.text_encoder = TextEncoder(
            projection_dim=projection_dim,
            **text_config
        )
        
        # Learnable temperature
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / temperature))
    
    def encode_image(self, images: torch.Tensor) -> torch.Tensor:
        """Encode images to embeddings"""
        return F.normalize(self.vision_encoder(images), dim=-1)
    
    def encode_text(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """Encode text to embeddings"""
        return F.normalize(
            self.text_encoder(input_ids, attention_mask),
            dim=-1
        )
    
    def forward(
        self,
        images: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor = None
    ) -> Dict[str, torch.Tensor]:
        """Forward pass with contrastive loss"""
        # Get embeddings
        image_embeds = self.encode_image(images)
        text_embeds = self.encode_text(input_ids, attention_mask)
        
        # Compute similarity scores
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_embeds @ text_embeds.T
        logits_per_text = logits_per_image.T
        
        return {
            "image_embeds": image_embeds,
            "text_embeds": text_embeds,
            "logits_per_image": logits_per_image,
            "logits_per_text": logits_per_text,
            "logit_scale": logit_scale
        }
    
    def compute_loss(self, outputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Compute contrastive loss"""
        batch_size = outputs["logits_per_image"].shape[0]
        labels = torch.arange(batch_size, device=outputs["logits_per_image"].device)
        
        loss_i2t = F.cross_entropy(outputs["logits_per_image"], labels)
        loss_t2i = F.cross_entropy(outputs["logits_per_text"], labels)
        
        return (loss_i2t + loss_t2i) / 2


# Example: CLIP Model
clip_model = CLIPModel(
    vision_config={"image_size": 224, "patch_size": 16, "num_layers": 4},
    text_config={"vocab_size": 10000, "num_layers": 4},
    projection_dim=256
)

# Simulate batch
batch_size = 4
images = torch.randn(batch_size, 3, 224, 224)
input_ids = torch.randint(0, 10000, (batch_size, 32))
attention_mask = torch.ones(batch_size, 32)

outputs = clip_model(images, input_ids, attention_mask)
loss = clip_model.compute_loss(outputs)

print(f"Image embeddings shape: {outputs['image_embeds'].shape}")
print(f"Text embeddings shape: {outputs['text_embeds'].shape}")
print(f"Contrastive loss: {loss.item():.4f}")

## 2. Multimodal Embeddings & Fusion

Combining information from multiple modalities.
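
One practical wrinkle in fusion is that not every request carries every modality. As a minimal sketch of one way to handle that, the NumPy snippet below averages only the modalities that are present, assuming all embeddings have already been projected to a shared dimension. The function name `masked_mean_fusion` is hypothetical and is not one of the fusion strategies implemented in the next cell.

```python
from typing import Dict, Optional

import numpy as np


def masked_mean_fusion(
    modality_embeds: Dict[str, Optional[np.ndarray]],
) -> np.ndarray:
    """Average the embeddings of the modalities that are present.

    Each present value has shape (batch, dim); missing modalities are None.
    """
    present = [emb for emb in modality_embeds.values() if emb is not None]
    if not present:
        raise ValueError("At least one modality must be provided")
    # Stack to (num_present, batch, dim) and average over the modality axis
    return np.stack(present, axis=0).mean(axis=0)


image = np.full((2, 4), 1.0)
text = np.full((2, 4), 3.0)

fused_all = masked_mean_fusion({"image": image, "text": text, "audio": None})
fused_image_only = masked_mean_fusion({"image": image, "text": None, "audio": None})

print(fused_all[0])         # mean of the image and text embeddings
print(fused_image_only[0])  # image embedding passed through unchanged
```

Averaging has no learned parameters, so it is best read as a baseline to compare learned fusion against rather than a replacement for it.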

In [None]:
class ModalityFusion(nn.Module):
    """
    Multi-modal fusion strategies.
    """
    
    def __init__(
        self,
        embed_dims: Dict[str, int],
        fusion_dim: int = 512,
        fusion_type: str = "concat"
    ):
        super().__init__()
        self.fusion_type = fusion_type
        self.modalities = list(embed_dims.keys())
        
        # Projection layers to common dimension
        self.projections = nn.ModuleDict({
            mod: nn.Linear(dim, fusion_dim)
            for mod, dim in embed_dims.items()
        })
        
        if fusion_type == "concat":
            self.fusion_layer = nn.Linear(
                fusion_dim * len(embed_dims),
                fusion_dim
            )
        elif fusion_type == "attention":
            self.cross_attention = nn.MultiheadAttention(
                embed_dim=fusion_dim,
                num_heads=8,
                batch_first=True
            )
            self.fusion_layer = nn.Linear(fusion_dim, fusion_dim)
        elif fusion_type == "gated":
            self.gates = nn.ModuleDict({
                mod: nn.Sequential(
                    nn.Linear(fusion_dim, fusion_dim),
                    nn.Sigmoid()
                )
                for mod in embed_dims.keys()
            })
            self.fusion_layer = nn.Linear(fusion_dim, fusion_dim)
    
    def forward(
        self,
        modality_embeds: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        Fuse embeddings from multiple modalities.
        
        Args:
            modality_embeds: Dict of modality name -> embedding tensor
        """
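        # Project the provided modalities to the common dimension
        projected = {
            mod: self.projections[mod](embed)
            for mod, embed in modality_embeds.items()
        }
        # Zero-fill configured modalities that were not provided, so every
        # fusion branch below can index all of self.modalities safely
        reference = next(iter(projected.values()))
        for mod in self.modalities:
            projected.setdefault(mod, torch.zeros_like(reference))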
        
        if self.fusion_type == "concat":
            # Simple concatenation
            fused = torch.cat(
                [projected[mod] for mod in self.modalities],
                dim=-1
            )
            return self.fusion_layer(fused)
        
        elif self.fusion_type == "attention":
            # Cross-modal attention
            # Stack modalities as sequence
            stacked = torch.stack(
                [projected[mod] for mod in self.modalities],
                dim=1
            )  # (B, num_modalities, fusion_dim)
            
            # Self-attention across modalities
            attended, _ = self.cross_attention(stacked, stacked, stacked)
            
            # Pool and project
            pooled = attended.mean(dim=1)
            return self.fusion_layer(pooled)
        
        elif self.fusion_type == "gated":
            # Gated fusion
            gated_embeds = []
            for mod in self.modalities:
                gate = self.gates[mod](projected[mod])
                gated_embeds.append(gate * projected[mod])
            
            # Sum gated embeddings
            fused = sum(gated_embeds)
            return self.fusion_layer(fused)
        
        else:
            raise ValueError(f"Unknown fusion type: {self.fusion_type}")


class MultimodalEmbeddingModel(nn.Module):
    """
    End-to-end multimodal embedding model.
    """
    
    def __init__(
        self,
        fusion_dim: int = 512,
        fusion_type: str = "attention"
    ):
        super().__init__()
        
        # Modality encoders (simplified)
        self.image_encoder = nn.Sequential(
            nn.Conv2d(3, 64, 7, stride=2, padding=3),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(64, 256)
        )
        
        self.text_encoder = nn.Sequential(
            nn.Embedding(10000, 128),
            nn.LSTM(128, 128, batch_first=True),
        )
        self.text_proj = nn.Linear(128, 256)
        
        self.audio_encoder = nn.Sequential(
            nn.Conv1d(1, 32, 7, stride=2, padding=3),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Linear(32, 256)
        )
        
        # Fusion
        self.fusion = ModalityFusion(
            embed_dims={"image": 256, "text": 256, "audio": 256},
            fusion_dim=fusion_dim,
            fusion_type=fusion_type
        )
    
    def forward(
        self,
        image: torch.Tensor = None,
        text: torch.Tensor = None,
        audio: torch.Tensor = None
    ) -> torch.Tensor:
        """Forward pass with available modalities"""
        modality_embeds = {}
        
        if image is not None:
            modality_embeds["image"] = self.image_encoder(image)
        
        if text is not None:
            text_output, _ = self.text_encoder(text)
            # Take last hidden state
            modality_embeds["text"] = self.text_proj(text_output[:, -1, :])
        
        if audio is not None:
            modality_embeds["audio"] = self.audio_encoder(audio)
        
        if len(modality_embeds) == 0:
            raise ValueError("At least one modality must be provided")
        
        return self.fusion(modality_embeds)


# Example: Multimodal Fusion
fusion_model = ModalityFusion(
    embed_dims={"image": 512, "text": 256, "audio": 128},
    fusion_dim=512,
    fusion_type="attention"
)

modality_embeds = {
    "image": torch.randn(4, 512),
    "text": torch.randn(4, 256),
    "audio": torch.randn(4, 128)
}

fused = fusion_model(modality_embeds)
print(f"Fused embedding shape: {fused.shape}")

## 3. Image-Text Retrieval System

Production-ready cross-modal search system.
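
For brute-force search over a matrix of embeddings, sorting every score is unnecessary when only the top k are needed. The sketch below uses `np.argpartition` to select the k best candidates and then sorts just those. The function name `top_k_cosine` and the random test data are assumptions for illustration.

```python
from typing import Tuple

import numpy as np


def top_k_cosine(
    matrix: np.ndarray, query: np.ndarray, k: int
) -> Tuple[np.ndarray, np.ndarray]:
    """Return (indices, scores) of the k rows most similar to the query.

    matrix: (num_items, dim), query: (dim,). Both are normalized here.
    """
    matrix = matrix / (np.linalg.norm(matrix, axis=1, keepdims=True) + 1e-8)
    query = query / (np.linalg.norm(query) + 1e-8)
    scores = matrix @ query

    k = min(k, scores.shape[0])
    if k <= 0:
        return np.array([], dtype=int), np.array([], dtype=scores.dtype)

    # argpartition places the k largest scores in the last k slots, unordered
    candidates = np.argpartition(scores, -k)[-k:]
    # Sort only the k candidates, highest score first
    order = candidates[np.argsort(scores[candidates])[::-1]]
    return order, scores[order]


rng = np.random.default_rng(0)
items = rng.normal(size=(1000, 64))
query = rng.normal(size=64)

indices, scores = top_k_cosine(items, query, k=5)
print(indices, scores)
```

Partial selection avoids the full sort over all items, which matters once the index holds far more items than are returned per query.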

In [None]:
@dataclass
class RetrievalResult:
    """Result from multimodal retrieval"""
    item_id: str
    score: float
    modality: str
    metadata: Dict[str, Any] = field(default_factory=dict)


class MultimodalIndex:
    """
    In-memory multimodal index for fast similarity search.
    In production, use FAISS, Milvus, or Pinecone.
    """
    
    def __init__(self, embedding_dim: int):
        self.embedding_dim = embedding_dim
        self.embeddings: Dict[str, np.ndarray] = {}  # id -> embedding
        self.metadata: Dict[str, Dict[str, Any]] = {}  # id -> metadata
        self.modalities: Dict[str, str] = {}  # id -> modality
        
        # Pre-computed normalized embeddings for fast search
        self._normalized_matrix = None
        self._id_list = None
        self._needs_rebuild = True
    
    def add(
        self,
        item_id: str,
        embedding: np.ndarray,
        modality: str,
        metadata: Dict[str, Any] = None
    ) -> None:
        """Add item to index"""
        self.embeddings[item_id] = embedding
        self.modalities[item_id] = modality
        self.metadata[item_id] = metadata or {}
        self._needs_rebuild = True
    
    def add_batch(
        self,
        item_ids: List[str],
        embeddings: np.ndarray,
        modality: str,
        metadata_list: List[Dict[str, Any]] = None
    ) -> None:
        """Add batch of items"""
        metadata_list = metadata_list or [{} for _ in item_ids]
        
        for i, item_id in enumerate(item_ids):
            self.embeddings[item_id] = embeddings[i]
            self.modalities[item_id] = modality
            self.metadata[item_id] = metadata_list[i]
        
        self._needs_rebuild = True
    
    def _rebuild_index(self) -> None:
        """Rebuild search index"""
        if not self._needs_rebuild:
            return
        
        self._id_list = list(self.embeddings.keys())
        if self._id_list:
            matrix = np.stack([self.embeddings[item_id] for item_id in self._id_list])
            # Normalize for cosine similarity
            norms = np.linalg.norm(matrix, axis=1, keepdims=True) + 1e-8
            self._normalized_matrix = matrix / norms
        else:
            self._normalized_matrix = None
        
        self._needs_rebuild = False
    
    def search(
        self,
        query_embedding: np.ndarray,
        top_k: int = 10,
        modality_filter: str = None
    ) -> List[RetrievalResult]:
        """Search for similar items"""
        self._rebuild_index()
        
        if self._normalized_matrix is None:
            return []
        
        # Normalize query
        query_norm = query_embedding / (np.linalg.norm(query_embedding) + 1e-8)
        
        # Compute similarities
        similarities = self._normalized_matrix @ query_norm
        
        # Create results
        results = []
        indices = np.argsort(similarities)[::-1]
        
        for idx in indices:
            item_id = self._id_list[idx]
            
            # Apply modality filter
            if modality_filter and self.modalities[item_id] != modality_filter:
                continue
            
            results.append(RetrievalResult(
                item_id=item_id,
                score=float(similarities[idx]),
                modality=self.modalities[item_id],
                metadata=self.metadata[item_id]
            ))
            
            if len(results) >= top_k:
                break
        
        return results


class CrossModalRetriever:
    """
    Cross-modal retrieval system (e.g., text-to-image search).
    """
    
    def __init__(
        self,
        clip_model: CLIPModel,
        embedding_dim: int = 512
    ):
        self.model = clip_model
        self.index = MultimodalIndex(embedding_dim)
        self.model.train(False)
    
    @torch.no_grad()
    def index_images(
        self,
        images: torch.Tensor,
        image_ids: List[str],
        metadata_list: List[Dict[str, Any]] = None
    ) -> None:
        """Index images for retrieval"""
        embeddings = self.model.encode_image(images).cpu().numpy()
        self.index.add_batch(image_ids, embeddings, "image", metadata_list)
    
    @torch.no_grad()
    def search_by_text(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor = None,
        top_k: int = 10
    ) -> List[RetrievalResult]:
        """Search images by text query"""
        text_embedding = self.model.encode_text(
            input_ids, attention_mask
        ).cpu().numpy()[0]
        
        return self.index.search(
            text_embedding,
            top_k=top_k,
            modality_filter="image"
        )
    
    @torch.no_grad()
    def search_by_image(
        self,
        image: torch.Tensor,
        top_k: int = 10
    ) -> List[RetrievalResult]:
        """Search images by image query"""
        image_embedding = self.model.encode_image(image).cpu().numpy()[0]
        return self.index.search(image_embedding, top_k=top_k)


# Example: Cross-modal retrieval
retriever = CrossModalRetriever(clip_model, embedding_dim=256)

# Index some images
images = torch.randn(10, 3, 224, 224)
image_ids = [f"img_{i}" for i in range(10)]
metadata = [{"category": f"cat_{i % 3}"} for i in range(10)]

retriever.index_images(images, image_ids, metadata)

# Search by text
query_ids = torch.randint(0, 10000, (1, 16))
results = retriever.search_by_text(query_ids, top_k=5)

print("Text-to-Image Search Results:")
for r in results:
    print(f"  {r.item_id}: score={r.score:.4f}, metadata={r.metadata}")

## 4. Audio Processing for ML

Audio feature extraction and speech processing.
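
To make the feature-extraction parameters concrete, the arithmetic below relates a window and hop size in samples to milliseconds and to the number of spectrogram frames, and converts between Hz and the mel scale. It assumes a 16 kHz sample rate with a 400-sample window and 160-sample hop, a centered STFT (which yields `1 + num_samples // hop_length` frames), and the HTK-style mel formula. The helper names are illustrative.

```python
import math


def hz_to_mel(freq_hz: float) -> float:
    """HTK-style mel scale: m = 2595 * log10(1 + f / 700)."""
    return 2595.0 * math.log10(1.0 + freq_hz / 700.0)


def mel_to_hz(mel: float) -> float:
    """Inverse of hz_to_mel."""
    return 700.0 * (10.0 ** (mel / 2595.0) - 1.0)


def num_frames_centered(num_samples: int, hop_length: int) -> int:
    """Frame count for an STFT that pads the signal so frames are centered."""
    return 1 + num_samples // hop_length


sample_rate = 16000
n_fft = 400
hop_length = 160

window_ms = 1000.0 * n_fft / sample_rate    # 25.0 ms analysis window
hop_ms = 1000.0 * hop_length / sample_rate  # 10.0 ms between frames
frames = num_frames_centered(sample_rate, hop_length)  # 101 frames for 1 second

# Upper end of the mel axis when filters span 0 Hz up to the Nyquist frequency
nyquist_mel = hz_to_mel(sample_rate / 2)

print(window_ms, hop_ms, frames, round(nyquist_mel, 1))
```

The 25 ms window with a 10 ms hop is a common choice for speech features, which is why these defaults appear in many ASR front ends.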

In [None]:
class AudioFeatureExtractor(nn.Module):
    """
    Audio feature extraction for ML models.
    Extracts mel-spectrograms and learned features.
    """
    
    def __init__(
        self,
        sample_rate: int = 16000,
        n_mels: int = 80,
        n_fft: int = 400,
        hop_length: int = 160,
        embed_dim: int = 512
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_length = hop_length
        
        # Learned filterbanks (simplified - in production use torchaudio)
        self.mel_filters = nn.Parameter(
            torch.randn(n_mels, n_fft // 2 + 1) * 0.1
        )
        
        # Convolutional feature extractor
        self.conv_layers = nn.Sequential(
            nn.Conv1d(n_mels, 128, 3, padding=1),
            nn.ReLU(),
            nn.Conv1d(128, 256, 3, padding=1),
            nn.ReLU(),
            nn.Conv1d(256, 512, 3, padding=1),
            nn.ReLU(),
        )
        
        # Transformer for temporal modeling
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=512,
            nhead=8,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=4)
        
        # Projection
        self.projection = nn.Linear(512, embed_dim)
    
    def compute_spectrogram(
        self,
        waveform: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute mel-spectrogram from waveform.
        Simplified - in production use torchaudio.transforms.MelSpectrogram
        """
        # waveform: (batch, samples)
        # Short-time Fourier transform with a Hann window
        window = torch.hann_window(self.n_fft, device=waveform.device)
        stft = torch.stft(
            waveform,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=window,
            return_complex=True
        )  # (B, n_fft // 2 + 1, n_frames)
        
        # Power spectrum projected through the learned filterbank;
        # abs() keeps the filter weights non-negative
        power = stft.abs() ** 2
        mel_spec = torch.matmul(self.mel_filters.abs(), power)  # (B, n_mels, n_frames)
        
        # Log compression
        mel_spec = torch.log(mel_spec + 1e-6)
        
        return mel_spec
    
    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Extract audio embeddings from waveform"""
        # Compute mel spectrogram
        mel_spec = self.compute_spectrogram(waveform)
        
        # Convolutional features
        features = self.conv_layers(mel_spec)  # (B, 512, T)
        
        # Transpose for transformer
        features = features.transpose(1, 2)  # (B, T, 512)
        
        # Temporal modeling
        features = self.transformer(features)
        
        # Global average pooling
        pooled = features.mean(dim=1)
        
        # Project to embedding dimension
        embedding = self.projection(pooled)
        
        return embedding


class SpeechRecognitionModel(nn.Module):
    """
    CTC-based speech recognition model.
    """
    
    def __init__(
        self,
        vocab_size: int = 1000,
        n_mels: int = 80,
        hidden_dim: int = 512
    ):
        super().__init__()
        
        # Feature extraction
        self.encoder = nn.Sequential(
            nn.Conv1d(n_mels, 256, 3, padding=1),
            nn.ReLU(),
            nn.Conv1d(256, hidden_dim, 3, padding=1),
            nn.ReLU(),
        )
        
        # Bidirectional LSTM
        self.lstm = nn.LSTM(
            hidden_dim, hidden_dim // 2,
            num_layers=4,
            batch_first=True,
            bidirectional=True
        )
        
        # Output projection
        self.output = nn.Linear(hidden_dim, vocab_size)
    
    def forward(
        self,
        mel_spec: torch.Tensor,
        input_lengths: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Forward pass.
        
        Args:
            mel_spec: (batch, n_mels, time)
            input_lengths: Length of each sequence
        
        Returns:
            Log probabilities: (batch, time, vocab_size)
        """
        # Encode
        features = self.encoder(mel_spec)  # (B, hidden, T)
        features = features.transpose(1, 2)  # (B, T, hidden)
        
        # LSTM
        lstm_out, _ = self.lstm(features)
        
        # Output logits
        logits = self.output(lstm_out)
        
        # Log softmax for CTC
        log_probs = F.log_softmax(logits, dim=-1)
        
        return log_probs
    
    def decode_greedy(self, log_probs: torch.Tensor) -> List[List[int]]:
        """Greedy CTC decoding"""
        # Get most likely tokens
        predictions = torch.argmax(log_probs, dim=-1)  # (B, T)
        
        decoded = []
        for seq in predictions:
            # Remove consecutive duplicates
            tokens = []
            prev = -1
            for token in seq.tolist():
                if token != prev and token != 0:  # 0 = blank
                    tokens.append(token)
                prev = token
            decoded.append(tokens)
        
        return decoded


# Example: Audio Processing
audio_encoder = AudioFeatureExtractor(embed_dim=256)

# Simulate waveform
waveform = torch.randn(4, 16000)  # 1 second at 16kHz
audio_embedding = audio_encoder(waveform)
print(f"Audio embedding shape: {audio_embedding.shape}")

# Speech recognition example
asr_model = SpeechRecognitionModel(vocab_size=100)
mel_spec = torch.randn(4, 80, 100)  # Simulated mel spectrogram
log_probs = asr_model(mel_spec)
decoded = asr_model.decode_greedy(log_probs)
print(f"ASR output shape: {log_probs.shape}")
print(f"Decoded tokens (first sample): {decoded[0][:10]}...")

## 5. Video Understanding

Temporal modeling for video content.
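
Video models typically consume a fixed number of frames per clip, so longer videos must be subsampled. A minimal sketch of uniform sampling follows: the clip is split into equal segments and the middle frame of each is taken. The function name `uniform_frame_indices` is hypothetical, and the 16-frame default simply mirrors the `num_frames` default used in the next cell.

```python
from typing import List


def uniform_frame_indices(total_frames: int, num_frames: int = 16) -> List[int]:
    """Pick num_frames indices spread evenly across a clip.

    The clip is divided into num_frames equal segments and the center of
    each segment is selected. Short clips repeat frames.
    """
    if total_frames <= 0 or num_frames <= 0:
        raise ValueError("total_frames and num_frames must be positive")
    segment = total_frames / num_frames
    return [
        min(total_frames - 1, int(segment * (i + 0.5)))
        for i in range(num_frames)
    ]


# A 10-second clip at 30 fps reduced to 16 frames
indices = uniform_frame_indices(total_frames=300, num_frames=16)
print(indices)

# A clip shorter than the target repeats some frames
short_clip = uniform_frame_indices(total_frames=4, num_frames=8)
print(short_clip)
```

Sampling segment centers rather than segment starts keeps the selected frames away from the very first and last frames, which are often fades or title cards.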

In [None]:
class VideoEncoder(nn.Module):
    """
    Video encoder with temporal modeling.
    Uses a frame encoder + temporal transformer.
    """
    
    def __init__(
        self,
        frame_encoder: nn.Module = None,
        frame_dim: int = 512,
        num_frames: int = 16,
        temporal_dim: int = 512,
        num_temporal_layers: int = 4
    ):
        super().__init__()
        
        # Frame encoder: pass a pretrained vision model, or fall back to a small conv stack
        if frame_encoder is None:
            self.frame_encoder = nn.Sequential(
                nn.Conv2d(3, 64, 7, stride=2, padding=3),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten(),
                nn.Linear(64, frame_dim)
            )
        else:
            self.frame_encoder = frame_encoder
        
        # Temporal position embedding
        self.temporal_pos_embed = nn.Parameter(
            torch.zeros(1, num_frames, temporal_dim)
        )
        
        # Projection to temporal dim
        self.frame_proj = nn.Linear(frame_dim, temporal_dim)
        
        # Temporal transformer
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=temporal_dim,
            nhead=8,
            batch_first=True
        )
        self.temporal_transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_temporal_layers
        )
        
        # CLS token for video-level representation
        self.cls_token = nn.Parameter(torch.zeros(1, 1, temporal_dim))
        
        # Output projection
        self.output_proj = nn.Linear(temporal_dim, temporal_dim)
    
    def forward(
        self,
        video: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """
        Encode video.
        
        Args:
            video: (batch, num_frames, C, H, W)
        
        Returns:
            dict with video_embedding and frame_embeddings
        """
        batch_size, num_frames = video.shape[:2]
        
        # Flatten batch and frames
        frames = video.reshape(-1, *video.shape[2:])  # (B*T, C, H, W); reshape also handles non-contiguous input
        
        # Encode frames
        frame_features = self.frame_encoder(frames)  # (B*T, frame_dim)
        frame_features = frame_features.view(batch_size, num_frames, -1)
        
        # Project to temporal dimension
        frame_features = self.frame_proj(frame_features)
        
        # Add temporal position embedding
        frame_features = frame_features + self.temporal_pos_embed[:, :num_frames, :]
        
        # Add CLS token
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        features = torch.cat([cls_tokens, frame_features], dim=1)
        
        # Temporal transformer
        features = self.temporal_transformer(features)
        
        # Get video-level embedding from CLS token
        video_embedding = self.output_proj(features[:, 0])
        frame_embeddings = features[:, 1:]
        
        return {
            "video_embedding": video_embedding,
            "frame_embeddings": frame_embeddings
        }


class VideoTextModel(nn.Module):
    """
    Video-text multimodal model for video understanding.
    """
    
    def __init__(
        self,
        video_dim: int = 512,
        text_dim: int = 512,
        projection_dim: int = 256,
        num_classes: int = None
    ):
        super().__init__()
        
        self.video_encoder = VideoEncoder(temporal_dim=video_dim)
        
        self.text_encoder = nn.Sequential(
            nn.Embedding(10000, 256),
            nn.LSTM(256, text_dim // 2, batch_first=True, bidirectional=True)
        )
        self.text_proj = nn.Linear(text_dim, projection_dim)
        
        # Video projection
        self.video_proj = nn.Linear(video_dim, projection_dim)
        
        # Classification head (if needed)
        if num_classes:
            self.classifier = nn.Linear(projection_dim, num_classes)
        else:
            self.classifier = None
        
        # Temperature for contrastive learning
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
    
    def encode_video(self, video: torch.Tensor) -> torch.Tensor:
        """Encode video to embedding"""
        video_output = self.video_encoder(video)
        video_embedding = self.video_proj(video_output["video_embedding"])
        return F.normalize(video_embedding, dim=-1)
    
    def encode_text(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Encode text to embedding"""
        embedded = self.text_encoder[0](input_ids)
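        _, (h_n, _) = self.text_encoder[1](embedded)
        # Concatenate the final forward and backward hidden states; with a
        # bidirectional LSTM, lstm_out[:, -1] holds a backward state that has
        # seen only the last token
        text_features = torch.cat([h_n[-2], h_n[-1]], dim=-1)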
        text_embedding = self.text_proj(text_features)
        return F.normalize(text_embedding, dim=-1)
    
    def forward(
        self,
        video: torch.Tensor,
        input_ids: torch.Tensor = None
    ) -> Dict[str, torch.Tensor]:
        """Forward pass"""
        video_embedding = self.encode_video(video)
        
        outputs = {"video_embedding": video_embedding}
        
        if input_ids is not None:
            text_embedding = self.encode_text(input_ids)
            outputs["text_embedding"] = text_embedding
            
            # Compute similarity
            logit_scale = self.logit_scale.exp()
            outputs["logits"] = logit_scale * video_embedding @ text_embedding.T
        
        if self.classifier is not None:
            outputs["class_logits"] = self.classifier(video_embedding)
        
        return outputs


# Example: Video Understanding
video_model = VideoTextModel(num_classes=10)

# Simulate video input (4 videos, 16 frames each)
video = torch.randn(4, 16, 3, 224, 224)
text = torch.randint(0, 10000, (4, 32))

outputs = video_model(video, text)
print(f"Video embedding shape: {outputs['video_embedding'].shape}")
print(f"Text embedding shape: {outputs['text_embedding'].shape}")
print(f"Classification logits shape: {outputs['class_logits'].shape}")

## 6. Production Deployment Patterns

Deploying multimodal systems at scale.
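
A common serving pattern is dynamic batching: requests accumulate in a queue and are flushed either when the batch is full or when the oldest request has waited too long. The sketch below illustrates that policy under simplifying assumptions (single-threaded, caller-supplied timestamps). The class `DynamicBatcher` and its methods are hypothetical names; the `batch_size` and `max_wait_ms` parameters mirror those of the pipeline in the next cell.

```python
from dataclasses import dataclass, field
from typing import Any, List, Tuple


@dataclass
class DynamicBatcher:
    """Collect requests and decide when a batch should be flushed."""
    batch_size: int = 16
    max_wait_ms: float = 50.0
    # Each entry is (arrival_time_ms, request)
    _queue: List[Tuple[float, Any]] = field(default_factory=list)

    def submit(self, request: Any, now_ms: float) -> None:
        """Enqueue a request with its arrival time."""
        self._queue.append((now_ms, request))

    def should_flush(self, now_ms: float) -> bool:
        """Flush when the batch is full or the oldest request has waited too long."""
        if not self._queue:
            return False
        if len(self._queue) >= self.batch_size:
            return True
        oldest_arrival = self._queue[0][0]
        return now_ms - oldest_arrival >= self.max_wait_ms

    def flush(self) -> List[Any]:
        """Remove and return up to batch_size requests in arrival order."""
        batch = [req for _, req in self._queue[: self.batch_size]]
        self._queue = self._queue[self.batch_size:]
        return batch


batcher = DynamicBatcher(batch_size=3, max_wait_ms=50.0)

batcher.submit("req-a", now_ms=0.0)
batcher.submit("req-b", now_ms=10.0)
flush_early = batcher.should_flush(now_ms=20.0)    # not full, oldest waited 20 ms
flush_timeout = batcher.should_flush(now_ms=55.0)  # oldest waited 55 ms
timeout_batch = batcher.flush()

for i in range(4):
    batcher.submit(f"req-{i}", now_ms=60.0)
flush_full = batcher.should_flush(now_ms=60.0)     # queue holds 4 >= batch_size
full_batch = batcher.flush()

print(flush_early, flush_timeout, timeout_batch, flush_full, full_batch)
```

In a real service the timestamps would come from a monotonic clock such as `time.monotonic()`, and the queue would need locking or an async event loop; both are left out here to keep the policy itself visible.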

In [None]:
@dataclass
class MultimodalRequest:
    """Request for multimodal inference"""
    request_id: str
    modalities: Dict[str, Any]  # modality_name -> data
    task: str  # 'embedding', 'retrieval', 'classification'
    parameters: Dict[str, Any] = field(default_factory=dict)


@dataclass
class MultimodalResponse:
    """Response from multimodal inference"""
    request_id: str
    results: Dict[str, Any]
    latency_ms: float
    modalities_used: List[str]


class MultimodalServingPipeline:
    """
    Production serving pipeline for multimodal models.
    """
    
    def __init__(
        self,
        model: nn.Module,
        preprocessors: Optional[Dict[str, Callable]] = None,
        batch_size: int = 16,
        max_wait_ms: float = 50
    ):
        self.model = model
        self.model.train(False)
        self.preprocessors = preprocessors or {}
        self.batch_size = batch_size
        self.max_wait_ms = max_wait_ms
        
        # Request queue for batching
        self.request_queue: List[MultimodalRequest] = []
        self.metrics = {
            "total_requests": 0,
            "total_latency_ms": 0,
            "batch_sizes": []
        }
    
    def preprocess(
        self,
        request: MultimodalRequest
    ) -> Dict[str, torch.Tensor]:
        """Preprocess request data"""
        processed = {}
        
        for modality, data in request.modalities.items():
            if modality in self.preprocessors:
                processed[modality] = self.preprocessors[modality](data)
            else:
                # Default: assume already tensor
                if isinstance(data, np.ndarray):
                    processed[modality] = torch.from_numpy(data)
                elif isinstance(data, torch.Tensor):
                    processed[modality] = data
                else:
                    processed[modality] = torch.tensor(data)
        
        return processed
    
    @torch.no_grad()
    def process_single(
        self,
        request: MultimodalRequest
    ) -> MultimodalResponse:
        """Process a single request"""
        start_time = time.perf_counter()
        
        # Preprocess
        processed = self.preprocess(request)
        
        # Add batch dimension
        batched = {k: v.unsqueeze(0) for k, v in processed.items()}
        
        # Inference
        outputs = self._run_inference(batched, request.task)
        
        # Remove batch dimension
        results = {k: v[0] if isinstance(v, torch.Tensor) else v 
                   for k, v in outputs.items()}
        
        # perf_counter is monotonic, so the interval is unaffected by clock changes
        latency_ms = (time.perf_counter() - start_time) * 1000
        
        self.metrics["total_requests"] += 1
        self.metrics["total_latency_ms"] += latency_ms
        self.metrics["batch_sizes"].append(1)  # single-request path
        
        return MultimodalResponse(
            request_id=request.request_id,
            results=self._postprocess_results(results),
            latency_ms=latency_ms,
            modalities_used=list(request.modalities.keys())
        )
    
    def _run_inference(
        self,
        inputs: Dict[str, torch.Tensor],
        task: str
    ) -> Dict[str, Any]:
        """Run model inference"""
        # Route to appropriate model method based on task
        if task == "embedding":
            if "image" in inputs:
                return {"embedding": self.model.encode_image(inputs["image"])}
            elif "text" in inputs:
                return {"embedding": self.model.encode_text(inputs["text"])}
        
        # Default: full forward pass
        return self.model(**inputs)
    
    def _postprocess_results(
        self,
        results: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Convert results to serializable format"""
        processed = {}
        for key, value in results.items():
            if isinstance(value, torch.Tensor):
                processed[key] = value.detach().cpu().tolist()
            else:
                processed[key] = value
        return processed
    
    def get_metrics(self) -> Dict[str, Any]:
        """Get serving metrics"""
        total = self.metrics["total_requests"]
        return {
            "total_requests": total,
            "avg_latency_ms": (
                self.metrics["total_latency_ms"] / total if total > 0 else 0
            ),
            "avg_batch_size": (
                np.mean(self.metrics["batch_sizes"]) 
                if self.metrics["batch_sizes"] else 0
            )
        }


class MultimodalModelOptimizer:
    """
    Optimization techniques for multimodal models.
    """
    
    @staticmethod
    def quantize_model(
        model: nn.Module,
        quantization_type: str = "dynamic"
    ) -> nn.Module:
        """
        Quantize model for faster inference.
        """
        if quantization_type == "dynamic":
            return torch.quantization.quantize_dynamic(
                model,
                {nn.Linear, nn.LSTM},
                dtype=torch.qint8
            )
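        else:
            # Static quantization needs calibration data, which this demo
            # does not have, so return the model unchanged and say so.
            logger.warning(
                "Quantization type %r is not implemented; returning the "
                "original model.", quantization_type
            )
            return model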
    
    @staticmethod
    def export_onnx(
        model: nn.Module,
        sample_inputs: Dict[str, torch.Tensor],
        output_path: str
    ) -> None:
        """
        Export model to ONNX format.
        """
        # For demo purposes, just print
        print(f"Exporting model to ONNX: {output_path}")
        print(f"Input shapes: {[(k, v.shape) for k, v in sample_inputs.items()]}")
        
        # In production:
        # torch.onnx.export(
        #     model,
        #     tuple(sample_inputs.values()),
        #     output_path,
        #     input_names=list(sample_inputs.keys()),
        #     dynamic_axes={k: {0: 'batch'} for k in sample_inputs.keys()}
        # )
    
    @staticmethod
    def cache_embeddings(
        model: nn.Module,
        data_loader,
        modality: str,
        cache_path: str
    ) -> None:
        """
        Pre-compute and cache embeddings for static content.
        """
        model.train(False)
        all_embeddings = []
        all_ids = []
        
        # Simulate caching
        print(f"Caching {modality} embeddings to {cache_path}")
        print("Would iterate through data_loader and compute embeddings...")


# Example: Production Serving
serving_pipeline = MultimodalServingPipeline(
    model=clip_model,
    batch_size=8
)

# Simulate request
request = MultimodalRequest(
    request_id="req_001",
    modalities={
        "image": torch.randn(3, 224, 224)
    },
    task="embedding"
)

response = serving_pipeline.process_single(request)
print("\nServing Response:")
print(f"  Request ID: {response.request_id}")
print(f"  Latency: {response.latency_ms:.2f}ms")
print(f"  Modalities: {response.modalities_used}")
print(f"  Embedding dims: {len(response.results['embedding'])}")

print(f"\nServing Metrics: {serving_pipeline.get_metrics()}")
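
The pipeline above stores `batch_size`, `max_wait_ms`, and `request_queue`, but only the single-request path is exercised. One way dynamic batching could use those settings is sketched below: wait for the first request, then keep collecting until the batch is full or the wait budget runs out. The function `collect_batch` and the use of `queue.Queue` are assumptions made for illustration and are not part of the class above.

```python
import queue
import time
from typing import Any, List


def collect_batch(
    request_queue: "queue.Queue[Any]",
    batch_size: int = 16,
    max_wait_ms: float = 50.0,
) -> List[Any]:
    """Block for the first request, then gather more until the batch is
    full or max_wait_ms has elapsed since the first request arrived."""
    batch = [request_queue.get()]  # waits indefinitely for the first item
    deadline = time.monotonic() + max_wait_ms / 1000.0
    while len(batch) < batch_size:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        try:
            batch.append(request_queue.get(timeout=remaining))
        except queue.Empty:
            break  # wait budget exhausted with no new request
    return batch


# Usage: five queued requests with batch_size=4 yield two batches.
pending: "queue.Queue[str]" = queue.Queue()
for i in range(5):
    pending.put(f"req_{i:03d}")

first = collect_batch(pending, batch_size=4, max_wait_ms=10)
second = collect_batch(pending, batch_size=4, max_wait_ms=10)
print(first)
print(second)
```

In a serving loop, the tensors of each collected batch would then be stacked and passed through the model once, and `len(batch)` recorded in `metrics["batch_sizes"]`. The trade-off is explicit: a larger `max_wait_ms` improves throughput at the cost of tail latency.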

## FAANG Interview Questions

### Q1: How would you design a multimodal search system (e.g., search images with text)?

**Answer:**
I would design a CLIP-style contrastive learning architecture:

1. **Encoders**: Separate encoders for each modality (ViT for images, Transformer for text)
2. **Shared Embedding Space**: Project every modality to the same dimension and L2-normalize, so cosine similarity reduces to a dot product
3. **Training**: Contrastive loss (InfoNCE) on image-text pairs
4. **Indexing**: Pre-compute image embeddings, store in vector database (FAISS/Milvus)
5. **Retrieval**: Encode query text, find nearest neighbors in embedding space
6. **Optimization**: 
   - Quantize embeddings (int8) for memory efficiency, about 4x smaller than float32
   - Use approximate nearest neighbor (HNSW) for speed
   - Cache hot embeddings in Redis
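
The indexing, retrieval, and quantization steps can be sketched with NumPy alone. This is a simplified illustration under stated assumptions: random vectors stand in for real image and text embeddings, brute-force scoring stands in for an approximate nearest neighbor index such as HNSW, and the helper names (`l2_normalize`, `quantize_int8`, `search`) are hypothetical.

```python
import numpy as np


def l2_normalize(x: np.ndarray) -> np.ndarray:
    """Scale vectors to unit length along the last axis."""
    return x / np.linalg.norm(x, axis=-1, keepdims=True)


def quantize_int8(embeddings: np.ndarray):
    """Map float embeddings to int8 using one global scale factor."""
    scale = 127.0 / np.abs(embeddings).max()
    return np.round(embeddings * scale).astype(np.int8), scale


def search(query: np.ndarray, index_int8: np.ndarray, k: int = 5) -> np.ndarray:
    """Return the indices of the k highest-scoring items, best first."""
    scores = index_int8.astype(np.float32) @ query.astype(np.float32)
    top_k = np.argpartition(-scores, k - 1)[:k]   # unordered top-k
    return top_k[np.argsort(-scores[top_k])]      # sort only those k


rng = np.random.default_rng(0)

# Offline: pre-compute and quantize the image embeddings.
image_embeddings = l2_normalize(
    rng.standard_normal((1000, 512)).astype(np.float32)
)
index_int8, scale = quantize_int8(image_embeddings)

# Online: a stand-in "text query" that is a noisy copy of image 42.
query = l2_normalize(image_embeddings[42] + 0.01 * rng.standard_normal(512))
results = search(query, index_int8, k=5)

print(f"Top-5 image ids: {results.tolist()}")
print(f"int8 index: {index_int8.nbytes} bytes, "
      f"float32: {image_embeddings.nbytes} bytes")
```

Because a single positive scale is applied to every vector, the ranking by dot product is preserved up to rounding error; a production system would typically delegate both quantization and search to the vector database.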

### Q2: What are the challenges of deploying multimodal models in production?

**Answer:**
Key challenges:

1. **Latency**: Multiple modality encoders increase inference time
   - Solution: Parallel encoding, model distillation, caching

2. **Memory**: Large models and embeddings
   - Solution: Quantization, model pruning, reduced-precision (FP16/BF16) inference

3. **Missing Modalities**: Not all inputs have all modalities
   - Solution: Modality-agnostic training, graceful degradation

4. **Synchronization**: Aligning temporal modalities (audio-video)
   - Solution: Cross-modal attention, contrastive pre-training

5. **Data Pipeline**: Complex preprocessing for multiple formats
   - Solution: Modular preprocessors, streaming pipelines
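
As one illustration of graceful degradation for missing modalities, the sketch below fuses whichever embeddings are present and skips the rest. It assumes every modality has already been projected into the same embedding dimension; `fuse_available` is a hypothetical helper, and simple averaging is only one of several possible fusion choices.

```python
from typing import Dict, Optional

import torch
import torch.nn.functional as F


def fuse_available(
    embeddings: Dict[str, Optional[torch.Tensor]]
) -> torch.Tensor:
    """Average the L2-normalized embeddings of the modalities that are
    present (late fusion), skipping any whose value is None."""
    present = [
        F.normalize(e, dim=-1) for e in embeddings.values() if e is not None
    ]
    if not present:
        raise ValueError("At least one modality is required")
    fused = torch.stack(present, dim=0).mean(dim=0)
    return F.normalize(fused, dim=-1)


torch.manual_seed(0)
image_emb = torch.randn(2, 512)
text_emb = torch.randn(2, 512)

# Audio is missing in both calls; text is also missing in the second.
full = fuse_available({"image": image_emb, "text": text_emb, "audio": None})
image_only = fuse_available({"image": image_emb, "text": None, "audio": None})
print(full.shape, image_only.shape)
```

With a single modality present, the result is just that modality's normalized embedding, so downstream consumers see the same shape and scale regardless of which inputs arrived.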

### Q3: How do you handle temporal alignment in video-audio models?

**Answer:**
Temporal alignment strategies:

1. **Synchronized Sampling**: Sample frames and audio at fixed intervals
2. **Cross-Modal Attention**: Let modalities attend to each other's features
3. **Contrastive Learning**: Train with synchronized pairs vs. misaligned negatives
4. **Temporal Transformers**: Model temporal dependencies across modalities
5. **Feature Interpolation**: Resample features to common temporal resolution

Key considerations:
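- Audio is sampled far more densely than video (e.g., 16 kHz audio vs. 30 fps video), so audio is usually framed into features before alignment
- Mel-spectrograms give audio a frame-based time axis (e.g., 100 frames per second with a 10 ms hop) that can be matched to video frames
- Consider temporal jittering (small random time offsets during training) for robustness to imperfect synchronization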
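
Feature interpolation, the fifth strategy above, can be sketched with `torch.nn.functional.interpolate`. The frame rates (100 audio frames per second, 30 video frames per second), the feature dimension, and the helper name `resample_to_length` are illustrative assumptions.

```python
import torch
import torch.nn.functional as F


def resample_to_length(features: torch.Tensor, target_len: int) -> torch.Tensor:
    """Linearly interpolate (batch, time, dim) features along the time axis."""
    # F.interpolate expects (batch, channels, length), so move time last.
    resampled = F.interpolate(
        features.transpose(1, 2),
        size=target_len,
        mode="linear",
        align_corners=False,
    )
    return resampled.transpose(1, 2)


# Two seconds of content at assumed rates: 100 audio frames/s, 30 video frames/s.
audio_feats = torch.randn(4, 200, 256)
video_feats = torch.randn(4, 60, 256)

# Bring audio onto the video time grid, then fuse per time step.
audio_aligned = resample_to_length(audio_feats, video_feats.size(1))
fused = torch.cat([audio_aligned, video_feats], dim=-1)
print(audio_aligned.shape, fused.shape)
```

Linear interpolation does not low-pass filter when downsampling, so average pooling over each video frame's audio window may be a better fit when the rate ratio is large.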

## Summary

This notebook covered:

1. **Vision-Language Models**: CLIP-style contrastive learning
2. **Multimodal Fusion**: Concat, attention, and gated fusion strategies
3. **Cross-Modal Retrieval**: Production search systems with vector indexing
4. **Audio Processing**: Feature extraction and speech recognition
5. **Video Understanding**: Temporal modeling with transformers
6. **Production Deployment**: Serving pipelines and optimization

### Key Takeaways for FAANG Interviews:
- Multimodal models require unified embedding spaces for cross-modal tasks
- Contrastive learning is the dominant paradigm for vision-language models
- Fusion strategies depend on task: early fusion for dense prediction, late fusion for classification
- Production systems need efficient indexing, caching, and batching
- Handle missing modalities gracefully in real-world deployments