
---

# **CHAPTER 27: MULTIMODAL AI**

*Unified Perception Across Vision, Language, and Audio*

## **Chapter Overview**

While unimodal models excel within single domains, real-world intelligence requires integrating information across vision, language, audio, and sensor data. This chapter covers the architectures and training paradigms that align heterogeneous modalities into unified representation spaces, enabling capabilities like visual question answering, image captioning, and audio-visual speech recognition.

**Estimated Time:** 40-50 hours (3-4 weeks)  
**Prerequisites:** Chapters 14, 25 (Transformers), Chapter 26 (Generative AI), computer vision fundamentals

---

## **27.0 Learning Objectives**

By the end of this chapter, you will be able to:
1. Implement contrastive learning objectives (InfoNCE) to align vision and language representations
2. Architect vision-language models using projection layers and Q-Formers (BLIP-2 style)
3. Fine-tune large language models for multimodal understanding (LLaVA, InstructBLIP patterns)
4. Process audio spectrograms and implement speech recognition pipelines (Whisper architecture)
5. Design video understanding systems that model temporal dependencies across frames
6. Evaluate multimodal systems for grounding, hallucination, and cross-modal retrieval

---

## **27.1 Vision-Language Foundations**

#### **27.1.1 Contrastive Language-Image Pre-training (CLIP)**

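CLIP aligns images and text in a shared embedding space by training both encoders with a symmetric contrastive loss: matched image-text pairs are pulled together, and mismatched pairs within the batch are pushed apart.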

```python
# clip_model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import CLIPTokenizer, CLIPTextModel
from torchvision import models

class CLIPModel(nn.Module):
    def __init__(self, embed_dim=512, temperature=0.07):
        super().__init__()
        self.temperature = nn.Parameter(torch.ones([]) * temperature)
        
        # Image encoder (Vision Transformer or ResNet)
        # `pretrained=True` is deprecated in recent torchvision; use the weights enum
        self.image_encoder = models.vit_b_16(weights=models.ViT_B_16_Weights.DEFAULT)
        self.image_encoder.heads = nn.Identity()  # Remove classification head
        
        # Project to common space
        self.image_projection = nn.Linear(768, embed_dim)
        
        # Text encoder (Transformer)
        self.text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
        self.text_projection = nn.Linear(512, embed_dim)
        
    def forward(self, images, input_ids, attention_mask):
        # Encode images: (batch, 3, 224, 224) -> (batch, embed_dim)
        image_features = self.image_encoder(images)
        image_features = self.image_projection(image_features)
        image_features = F.normalize(image_features, dim=-1)
        
        # Encode text: (batch, seq_len) -> (batch, embed_dim)
        text_outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        text_features = text_outputs.pooler_output  # Pooled end-of-text (EOS) token state
        text_features = self.text_projection(text_features)
        text_features = F.normalize(text_features, dim=-1)
        
        # Compute similarity matrix: (batch, batch)
        logits = torch.matmul(image_features, text_features.t()) / self.temperature
        
        # Symmetric contrastive loss (InfoNCE)
        batch_size = images.size(0)
        labels = torch.arange(batch_size, device=images.device)
        
        loss_i2t = F.cross_entropy(logits, labels)  # Image-to-text
        loss_t2i = F.cross_entropy(logits.t(), labels)  # Text-to-image
        
        return (loss_i2t + loss_t2i) / 2
    
    def encode_image(self, images):
        features = self.image_encoder(images)
        return F.normalize(self.image_projection(features), dim=-1)
    
    def encode_text(self, input_ids, attention_mask):
        outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        features = self.text_projection(outputs.pooler_output)
        return F.normalize(features, dim=-1)

# Zero-shot classification
def zero_shot_predict(model, images, class_names, tokenizer):
    prompts = [f"a photo of a {c}" for c in class_names]
    # Keep the tokenized prompts on the same device as the images
    text_inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(images.device)
    
    with torch.no_grad():
        image_features = model.encode_image(images)
        text_features = model.encode_text(**text_inputs)
        
        similarity = (image_features @ text_features.T).softmax(dim=-1)
    
    return similarity.argmax(dim=-1)
```

**Training Details:**
- **Dataset:** 400M image-text pairs (web-crawled)
- **Batch Size:** 32,768 (larger batches supply more in-batch negatives, which improves contrastive learning)
- **Data Augmentation:** Random square crop from resized images; the CLIP paper reports this as the only augmentation used during training

#### **27.1.2 ALIGN (Alternative Approach)**

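ALIGN uses a larger but noisier dataset (over one billion image and alt-text pairs with minimal filtering) and a dual-encoder architecture trained with a normalized, temperature-scaled softmax cross-entropy loss.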
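
For reference, the symmetric contrastive objective shared by CLIP and ALIGN can be written out explicitly. In this sketch, $x_i$ and $y_i$ denote the L2-normalized image and text embeddings of the $i$-th pair in a batch of size $N$, and $\tau$ is the temperature. The symbols are notation chosen here for illustration, not taken from either paper.

$$
\mathcal{L}_{\text{i2t}} = -\frac{1}{N}\sum_{i=1}^{N}\log\frac{\exp(x_i^\top y_i/\tau)}{\sum_{j=1}^{N}\exp(x_i^\top y_j/\tau)},
\qquad
\mathcal{L}_{\text{t2i}} = -\frac{1}{N}\sum_{i=1}^{N}\log\frac{\exp(y_i^\top x_i/\tau)}{\sum_{j=1}^{N}\exp(y_i^\top x_j/\tau)}
$$

$$
\mathcal{L} = \tfrac{1}{2}\left(\mathcal{L}_{\text{i2t}} + \mathcal{L}_{\text{t2i}}\right)
$$

Each term is a softmax cross-entropy over one row (or column) of the $N \times N$ similarity matrix, with the matched pair on the diagonal as the correct class.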

---

## **27.2 Multimodal Architectures**

#### **27.2.1 Flamingo-Style: Frozen LLM with Perceiver Resampler**

Flamingo-style models inject visual tokens into a frozen language model through gated cross-attention layers.

```python
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM

class PerceiverResampler(nn.Module):
    """
    Compresses variable-length image features to fixed number of latents
    """
    def __init__(self, dim=768, num_latents=64, num_layers=6):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, dim))
        
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=dim, nhead=8, batch_first=True)
            for _ in range(num_layers)
        ])
        
    def forward(self, image_features):
        # image_features: (batch, num_patches, dim)
        batch_size = image_features.size(0)
        latents = self.latents.unsqueeze(0).expand(batch_size, -1, -1)
        
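        # Simplification: joint self-attention over [latents; image features].
        # Flamingo's resampler instead uses cross-attention with the latents as queries.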
        x = torch.cat([latents, image_features], dim=1)
        for layer in self.layers:
            x = layer(x)
            
        return x[:, :self.latents.size(0), :]  # Return only latent queries

class GatedCrossAttention(nn.Module):
    def __init__(self, dim, num_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.gate = nn.Parameter(torch.zeros(1))
        self.norm = nn.LayerNorm(dim)
        
    def forward(self, x, context):
        # x: language hidden states, context: visual features
        attn_out, _ = self.attn(x, context, context)
        # Gating mechanism: gradually introduce visual info during training
        return self.norm(x + torch.tanh(self.gate) * attn_out)

class MultimodalLLM(nn.Module):
    def __init__(self, llm_model_name, vision_encoder):
        super().__init__()
        # Load frozen LLM
        self.llm = AutoModelForCausalLM.from_pretrained(llm_model_name)
        self.llm.requires_grad_(False)  # Freeze LLM
        
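        # Vision components
        self.vision_encoder = vision_encoder
        self.perceiver = PerceiverResampler(dim=768)
        # Project visual tokens (768) to the LLM hidden size (4096 assumed here)
        self.visual_projection = nn.Linear(768, 4096)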
        
        # Insert cross-attention layers every N transformer blocks
        self.gated_layers = nn.ModuleList([
            GatedCrossAttention(dim=4096, num_heads=32) 
            for _ in range(4)  # 4 visual injection points
        ])
        
    def forward(self, images, input_ids, attention_mask):
        # Extract visual tokens
        visual_features = self.vision_encoder(images)  # (batch, 256, 768)
        visual_tokens = self.perceiver(visual_features)  # (batch, 64, 768)
        
        # Project to LLM dimension if needed
        visual_tokens = self.visual_projection(visual_tokens)
        
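        # Process through LLM with gated cross-attention.
        # Schematic: real Hugging Face decoder layers typically also expect a 4D causal
        # mask and position information, so this loop will not run against them as-is.
        hidden_states = self.llm.get_input_embeddings()(input_ids)
        
        for i, layer in enumerate(self.llm.model.layers):
            hidden_states = layer(hidden_states, attention_mask=attention_mask)[0]
            
            # Inject vision at specific layers
            if i in [3, 7, 11, 15]:
                visual_idx = [3, 7, 11, 15].index(i)
                hidden_states = self.gated_layers[visual_idx](hidden_states, visual_tokens)
        
        # Final norm (LLaMA-style models), then language modeling head
        hidden_states = self.llm.model.norm(hidden_states)
        logits = self.llm.lm_head(hidden_states)
        return logits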
```

#### **27.2.2 LLaVA Architecture (Linear Projection)**

LLaVA takes a simpler approach: a single linear layer (or a small MLP) projects vision features into the LLM embedding space.

```python
IMAGE_TOKEN_INDEX = -200  # Placeholder id for <image> (value used in the LLaVA codebase)

class LLaVAModel(nn.Module):
    def __init__(self, vision_tower, llm, mm_projector_type="linear"):
        super().__init__()
        self.vision_tower = vision_tower  # CLIP ViT
        self.llm = llm  # Vicuna/LLaMA
        
        if mm_projector_type == "linear":
            self.mm_projector = nn.Linear(vision_tower.config.hidden_size, 
                                         llm.config.hidden_size)
        else:
            # MLP projector
            self.mm_projector = nn.Sequential(
                nn.Linear(vision_tower.config.hidden_size, llm.config.hidden_size),
                nn.GELU(),
                nn.Linear(llm.config.hidden_size, llm.config.hidden_size)
            )
        
    def prepare_inputs_labels_for_multimodal(
        self, input_ids, images, attention_mask
    ):
        # Encode images with the frozen vision tower
        with torch.no_grad():
            image_features = self.vision_tower(images).last_hidden_state[:, 1:]  # Remove CLS
        
        # Project to LLM space: (batch, num_patches, llm_hidden)
        image_embeds = self.mm_projector(image_features)
        embed_tokens = self.llm.get_input_embeddings()
        
        # Splice all patch embeddings in place of the <image> placeholder.
        # Assumes exactly one <image> token per sample so every row keeps the same length.
        new_input_embeds, new_attention_masks = [], []
        
        for batch_idx in range(input_ids.size(0)):
            cur_input_ids = input_ids[batch_idx]
            cur_mask = attention_mask[batch_idx]
            pos = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0][0].item()
            
            # Embed the text on either side; the placeholder id itself is never embedded
            before = embed_tokens(cur_input_ids[:pos])
            after = embed_tokens(cur_input_ids[pos + 1:])
            cur_image_features = image_embeds[batch_idx]
            new_input_embeds.append(torch.cat([before, cur_image_features, after], dim=0))
            
            # Image positions are always attended to
            image_mask = cur_mask.new_ones(cur_image_features.size(0))
            new_attention_masks.append(
                torch.cat([cur_mask[:pos], image_mask, cur_mask[pos + 1:]])
            )
        
        # Stack and pass to LLM
        inputs_embeds = torch.stack(new_input_embeds)
        attention_mask = torch.stack(new_attention_masks)
        return self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
```

---

## **27.3 Training Strategies**

#### **27.3.1 Instruction Tuning for Multimodal Models**

Convert diverse tasks to instruction-following format:

```json
{
  "instruction": "What is unusual about this image?",
  "input": "<image>",
  "output": "The cat is sitting on a computer keyboard while code is being written, which is unusual because cats typically obstruct keyboards rather than assist with programming."
}
```

**Data Formatting:**
- **Pre-training:** Alignment (captioning) on image-text pairs
- **Instruction Tuning:** Visual question answering, OCR, reasoning with chain-of-thought
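
To make the instruction format concrete, the sketch below turns one record of the shape shown above into a prompt/target pair and builds labels that exclude the prompt from the loss. The `USER`/`ASSISTANT` template and the helper names `build_training_text` and `label_mask` are illustrative assumptions, not a fixed standard; `-100` is the default `ignore_index` of PyTorch's cross-entropy loss.

```python
import json

def build_training_text(record, user="USER", assistant="ASSISTANT"):
    """Return (prompt, target) strings for one instruction record.

    The chat template here is an assumption for illustration; real models
    each define their own template.
    """
    prompt = f"{user}: {record['input']}\n{record['instruction']}\n{assistant}: "
    return prompt, record["output"]

def label_mask(prompt_tokens, target_tokens, ignore_index=-100):
    """Build labels so that only the answer tokens contribute to the loss."""
    return [ignore_index] * len(prompt_tokens) + list(target_tokens)

record = json.loads(
    '{"instruction": "What is unusual about this image?", '
    '"input": "<image>", "output": "A cat is sitting on the keyboard."}'
)
prompt, target = build_training_text(record)

# Whitespace splitting stands in for a real tokenizer in this sketch
labels = label_mask(prompt.split(), target.split())
print(repr(prompt))
print(labels)
```

Masking the prompt keeps the model from being trained to reproduce the question and the image placeholder, so gradient signal comes only from the answer.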

#### **27.3.2 Parameter-Efficient Fine-Tuning**

Freeze the vision encoder and the LLM, and train only the projection layers and LoRA adapters.

```python
from peft import get_peft_model, LoraConfig, TaskType

# Configure LoRA for LLM attention layers
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=8,  # LoRA rank
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj"]  # Target attention projections
)

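model.llm = get_peft_model(model.llm, peft_config)

# Keep the vision encoder frozen (assumes the LLaVA-style `vision_tower` attribute)
model.vision_tower.requires_grad_(False)

# Train the projection layer in full
for param in model.mm_projector.parameters():
    param.requires_grad = True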
```
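
A quick back-of-the-envelope count shows why this is parameter-efficient. The sketch below assumes a hidden size of 4096 and 32 transformer blocks (typical of a 7B LLaMA-style model, but an assumption here), with rank-8 adapters on `q_proj` and `v_proj` as configured above.

```python
def lora_param_count(d_in, d_out, rank):
    """Extra trainable parameters for one adapted linear layer.

    LoRA adds two low-rank matrices: A (rank x d_in) and B (d_out x rank).
    """
    return rank * (d_in + d_out)

hidden = 4096     # assumed LLM hidden size
num_layers = 32   # assumed number of transformer blocks
targets = 2       # q_proj and v_proj, as in the config above

per_module = lora_param_count(hidden, hidden, rank=8)
lora_total = per_module * targets * num_layers
full_total = hidden * hidden * targets * num_layers  # weights of the adapted layers

print(per_module)   # 65536
print(lora_total)   # 4194304
print(full_total // lora_total)  # 256
```

Under these assumptions the adapters hold about 4.2M parameters, 256 times fewer than the projection weights they modify.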

---

## **27.4 Speech & Audio**

#### **27.4.1 Whisper Architecture (Encoder-Decoder for ASR)**

```python
# Whisper-style model
class WhisperModel(nn.Module):
    def __init__(self, n_mels=80, n_audio_ctx=1500, n_text_ctx=448):
        super().__init__()
        
        # Audio encoder (conv + transformer)
        self.conv1 = nn.Conv1d(n_mels, 384, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(384, 384, kernel_size=3, stride=2, padding=1)
        
        # batch_first=True because forward() feeds (batch, time, features) tensors
        self.audio_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=384, nhead=6, dim_feedforward=1536,
                                       batch_first=True),
            num_layers=6
        )
        
        # Text decoder (causal)
        self.token_embedding = nn.Embedding(51864, 384)  # Vocab size
        self.text_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=384, nhead=6, batch_first=True),
            num_layers=6
        )
        
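        # Multi-task format tokens <|transcribe|>, <|translate|>, <|notimestamps|>
        # IDs below follow the English-only vocabulary (51,864 tokens);
        # the multilingual tokenizer shifts each of them up by one.
        self.special_tokens = {
            "startoftranscript": 50257,
            "translate": 50357,
            "transcribe": 50358,
            "notimestamps": 50362
        }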
        
    def forward(self, mel_spectrogram, decoder_input_ids):
        # Audio encoding
        x = F.gelu(self.conv1(mel_spectrogram))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)  # (batch, time, features)
        
        audio_features = self.audio_encoder(x)
        
        # Text decoding with cross-attention to audio
        text_embeds = self.token_embedding(decoder_input_ids)
        
        # Causal mask for decoder
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(decoder_input_ids.size(1))
        
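        hidden = self.text_decoder(
            text_embeds, 
            memory=audio_features,
            tgt_mask=tgt_mask.to(text_embeds.device)
        )
        
        # Project to vocabulary logits by reusing the token embedding matrix (weight tying)
        return hidden @ self.token_embedding.weight.t()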

# Preprocessing: Log-Mel spectrogram
def log_mel_spectrogram(audio, n_mels=80, n_fft=400, hop_length=160):
    """
    audio: torch tensor of raw waveform (16kHz)
    """
    window = torch.hann_window(n_fft, device=audio.device)
    stft = torch.stft(audio, n_fft, hop_length, window=window, return_complex=True)
    magnitudes = stft[..., :-1].abs() ** 2  # Drop the trailing frame, as Whisper does
    
    # Mel filterbank
    mel_filters = torch.load("mel_filters.pt")  # Pre-computed
    mel_spec = mel_filters @ magnitudes
    
    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0  # Normalize
    
    return log_spec
```
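
The default `n_audio_ctx=1500` follows from the preprocessing parameters. The arithmetic below is a sketch that assumes 30-second input windows (the chunk length Whisper pads or trims to) and that the trailing STFT frame is dropped, as in the reference implementation.

```python
SAMPLE_RATE = 16_000   # Hz, as stated in the docstring above
HOP_LENGTH = 160       # samples between STFT frames
CHUNK_SECONDS = 30     # assumed fixed window length
CONV_STRIDE = 2        # stride of the second convolution

def encoder_positions(seconds=CHUNK_SECONDS):
    """Return (mel_frames, encoder_positions) for one audio chunk."""
    samples = seconds * SAMPLE_RATE
    mel_frames = samples // HOP_LENGTH
    # Conv1d with kernel 3, padding 1, stride 2: out = floor((n + 2*1 - 3) / 2) + 1
    positions = (mel_frames + 2 * 1 - 3) // CONV_STRIDE + 1
    return mel_frames, positions

mel_frames, positions = encoder_positions()
print(mel_frames, positions)  # 3000 1500
```

Each mel frame covers 10 ms of audio, and the strided convolution halves the sequence, so every encoder position corresponds to 20 ms.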

#### **27.4.2 Audio-Language Alignment (AudioCLIP)**

AudioCLIP extends CLIP to audio by adding an audio encoder (AST or a CNN) and training with a tri-modal contrastive loss over image, text, and audio pairs.

---

## **27.5 Video Understanding**

#### **27.5.1 Video Vision Transformers**

Video models must handle the temporal dimension, either with tubelet embeddings (3D patches that span several frames) or with frame sampling followed by temporal attention.

```python
class VideoTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, num_frames=8, embed_dim=768):
        super().__init__()
        self.patch_embed = nn.Conv3d(
            3, embed_dim, 
            kernel_size=(2, patch_size, patch_size),  # Temporal patch size 2
            stride=(2, patch_size, patch_size)
        )
        
        # Positional embedding: spatial + temporal
        num_patches_per_frame = (img_size // patch_size) ** 2
        num_temporal = num_frames // 2
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_temporal * num_patches_per_frame, embed_dim)
        )
        
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(embed_dim, nhead=12, batch_first=True)
            for _ in range(12)
        ])
        
    def forward(self, x):
        # x: (batch, channels, time, height, width)
        x = self.patch_embed(x)  # (batch, embed, T', H', W')
        x = x.flatten(2).transpose(1, 2)  # (batch, T'*H'*W', embed)
        
        x = x + self.pos_embed
        
        for block in self.blocks:
            x = block(x)
            
        return x
```

#### **27.5.2 Video-Text Retrieval**

Video-text retrieval follows the CLIP recipe but adds temporal aggregation: per-frame embeddings are combined by mean pooling or by a small transformer over time.
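
As a minimal sketch of the mean-pooling variant, the helpers below pick evenly spaced frames and average their embeddings into one video vector. Plain Python lists stand in for tensors, and the names `sample_frame_indices` and `mean_pool` are illustrative rather than taken from any library.

```python
import math

def sample_frame_indices(total_frames, num_samples):
    """Pick the centre frame of each of `num_samples` equal-length segments."""
    segment = total_frames / num_samples
    return [int(segment * (i + 0.5)) for i in range(num_samples)]

def mean_pool(frame_embeddings):
    """Average per-frame vectors, then L2-normalize the result."""
    count = len(frame_embeddings)
    dim = len(frame_embeddings[0])
    mean = [sum(frame[d] for frame in frame_embeddings) / count for d in range(dim)]
    norm = math.sqrt(sum(v * v for v in mean)) or 1.0  # avoid division by zero
    return [v / norm for v in mean]

print(sample_frame_indices(total_frames=80, num_samples=8))
# [5, 15, 25, 35, 45, 55, 65, 75]

video_embedding = mean_pool([[1.0, 0.0], [0.0, 1.0]])
print(video_embedding)
```

Mean pooling discards frame order, which is often acceptable for retrieval of static concepts; queries that depend on motion or ordering are where a temporal transformer tends to help.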

---

## **27.6 Advanced: Binding Multiple Modalities**

#### **27.6.1 ImageBind (One Embedding Space for 6 Modalities)**

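ImageBind aligns images, text, audio, depth, thermal, and IMU data in one embedding space, using images as the binding modality: every other modality is paired only with images during training.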

```python
class ImageBindModel(nn.Module):
    def __init__(self, temperature=0.07):
        super().__init__()
        # Separate encoders for each modality.
        # ViTModel, TextTransformer, and AudioSpectrogramTransformer are placeholders
        # for encoders that return one 768-dim vector per input.
        self.image_encoder = ViTModel()
        self.text_encoder = TextTransformer()
        self.audio_encoder = AudioSpectrogramTransformer()
        self.depth_encoder = ViTModel()  # Shares weights with image or separate
        
        # Projection heads to common space
        self.image_proj = nn.Linear(768, 1024)
        self.text_proj = nn.Linear(768, 1024)
        self.audio_proj = nn.Linear(768, 1024)
        self.depth_proj = nn.Linear(768, 1024)
        self.temperature = temperature  # Fixed temperature (assumed value)
        
    def forward(self, image=None, text=None, audio=None, depth=None):
        embeddings = {}
        
        if image is not None:
            embeddings['image'] = F.normalize(self.image_proj(self.image_encoder(image)), dim=-1)
        if text is not None:
            embeddings['text'] = F.normalize(self.text_proj(self.text_encoder(text)), dim=-1)
        if audio is not None:
            embeddings['audio'] = F.normalize(self.audio_proj(self.audio_encoder(audio)), dim=-1)
        if depth is not None:
            embeddings['depth'] = F.normalize(self.depth_proj(self.depth_encoder(depth)), dim=-1)
            
        return embeddings
    
    def contrastive_loss(self, embeddings):
        # Images are the anchor: each other modality in the batch is aligned to them
        loss = 0
        anchor = embeddings['image']
        labels = torch.arange(anchor.size(0), device=anchor.device)
        
        for name, other in embeddings.items():
            if name == 'image':
                continue
            logits = anchor @ other.T / self.temperature
            loss += (F.cross_entropy(logits, labels) + 
                    F.cross_entropy(logits.T, labels)) / 2
                        
        return loss
```

---

## **27.7 Workbook Labs**

### **Lab 1: CLIP from Scratch**
Train a CLIP-style model on a subset of COCO or Conceptual Captions:

1. **Data Pipeline:** Load image-text pairs, apply augmentations
2. **Architecture:** ViT-B/32 + Text Transformer
3. **Training:** Implement symmetric InfoNCE loss, train for 10 epochs
4. **Evaluation:** Zero-shot classification on CIFAR-10, image-text retrieval Recall@K

**Deliverable:** Trained model with retrieval accuracy metrics.
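
For the retrieval metric, Recall@K can be computed directly from the query-by-candidate similarity matrix. The sketch below assumes one correct candidate per query, located at the same index (query `i` matches candidate `i`), which holds for a one-caption-per-image evaluation set but not for datasets with several captions per image.

```python
def recall_at_k(similarity, k):
    """Fraction of queries whose matching candidate is in the top-k results.

    similarity[i][j] is the score of query i against candidate j;
    the correct candidate for query i is assumed to be j == i.
    """
    hits = 0
    for i, row in enumerate(similarity):
        ranked = sorted(range(len(row)), key=lambda j: row[j], reverse=True)
        hits += i in ranked[:k]
    return hits / len(similarity)

sim = [
    [0.9, 0.1, 0.3],
    [0.2, 0.4, 0.8],
    [0.1, 0.7, 0.6],
]
print(recall_at_k(sim, 1))  # one of three queries ranks its match first
print(recall_at_k(sim, 2))  # all three matches appear in the top two
```

Run it once on the image-to-text matrix and once on its transpose to report both retrieval directions.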

### **Lab 2: Visual Instruction Tuning**
Fine-tune a multimodal model on a visual QA dataset:

1. **Base Model:** Load LLaVA-1.5 architecture (pretrained projector + Vicuna)
2. **Dataset:** Convert VQA v2 to instruction format
3. **Fine-tuning:** LoRA on LLM, full tuning on projection layer
4. **Inference:** Generate answers for held-out images, evaluate with GPT-4 judge or exact match

**Deliverable:** Fine-tuned checkpoint and example outputs showing reasoning.

### **Lab 3: Speech Recognition Pipeline**
Implement Whisper-style ASR:

1. **Preprocessing:** Convert audio to log-Mel spectrograms
2. **Model:** Encoder-decoder transformer
3. **Training:** Train on a LibriSpeech subset with cross-entropy on the decoder outputs (teacher forcing); CTC is an option only if you add an encoder-only output head
4. **Decoding:** Implement greedy and beam search decoding

**Deliverable:** WER (Word Error Rate) evaluation on test-clean.
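
WER is the word-level edit distance between reference and hypothesis, divided by the number of reference words. The following is a minimal sketch using whitespace tokenization; published results usually apply text normalization (casing, punctuation) first, which is omitted here.

```python
def word_error_rate(reference, hypothesis):
    """(substitutions + deletions + insertions) / number of reference words."""
    ref, hyp = reference.split(), hypothesis.split()
    # prev[j] holds the edit distance between the reference prefix seen so far
    # and the first j hypothesis words
    prev = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        cur = [i]
        for j, hyp_word in enumerate(hyp, start=1):
            cur.append(min(
                prev[j] + 1,                          # deletion
                cur[j - 1] + 1,                       # insertion
                prev[j - 1] + (ref_word != hyp_word)  # substitution or match
            ))
        prev = cur
    return prev[-1] / max(len(ref), 1)

wer = word_error_rate("the cat sat on the mat", "the cat sit on mat")
print(f"{wer:.3f}")  # 0.333: one substitution and one deletion over six words
```

Because insertions count as errors, WER can exceed 1.0 when the hypothesis is much longer than the reference.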

### **Lab 4: Multimodal Retrieval System**
Build a cross-modal search engine:

1. **Indexing:** Encode image database with CLIP
2. **Search:** Text-to-image and image-to-image retrieval
3. **Ranking:** Re-rank with fine-tuned model on domain data
4. **UI:** Gradio interface for searching photo collection

**Deliverable:** Working retrieval system with mAP (mean Average Precision) evaluation.
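
A minimal sketch of the mAP computation for the deliverable is shown below. Each query is represented by a list of 0/1 relevance flags in ranked order; the sketch assumes every relevant item appears somewhere in the ranked list, so average precision is normalized by the number of relevant items found.

```python
def average_precision(ranked_relevance):
    """Average of precision@rank taken at each rank where a relevant item appears."""
    hits, precision_sum = 0, 0.0
    for rank, relevant in enumerate(ranked_relevance, start=1):
        if relevant:
            hits += 1
            precision_sum += hits / rank
    return precision_sum / hits if hits else 0.0

def mean_average_precision(all_rankings):
    """Mean of per-query average precision."""
    return sum(average_precision(r) for r in all_rankings) / len(all_rankings)

rankings = [
    [1, 0, 1],  # relevant items at ranks 1 and 3: AP = (1/1 + 2/3) / 2
    [0, 1],     # relevant item at rank 2: AP = 1/2
]
print(round(mean_average_precision(rankings), 4))  # 0.6667
```

If results are truncated to a top-K list, divide by the total number of relevant items in the database instead, otherwise missed items are not penalized.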

---

## **27.8 Common Pitfalls**

1. **Modality Collapse:** One modality dominates, others ignored. **Solution:** Gradient masking, balanced sampling, or contrastive loss temperature tuning per modality.

2. **Misalignment in Tokenization:** Different sequence lengths for vision patches vs. text tokens causing fusion issues. **Solution:** Careful position ID management, separate position embeddings per modality.

3. **Hallucination in VLM:** Model generates text not grounded in image. **Solution:** Instruction tuning with negative examples (unanswerable questions), grounding losses (text-image matching during generation).

4. **Audio Preprocessing Errors:** Incorrect spectrogram parameters (window size, hop length) causing information loss. **Solution:** Validate reconstruction from spectrogram, standardize on 16kHz, 80-bin mel.

5. **Temporal Misalignment in Video:** Treating all frames equally ignores temporal structure. **Solution:** Temporal positional encodings, temporal attention pooling, or 3D convolutions.

---

## **27.9 Interview Questions**

**Q1:** How does CLIP enable zero-shot classification, and what are its limitations?
*A: CLIP learns joint embedding space where text descriptions and images are close if semantically similar. Zero-shot works by embedding class names as text prompts ("a photo of a [class]"), computing similarity with image, selecting highest score. Limitations: (1) Requires good text descriptions (poor on rare or fine-grained concepts), (2) Biases from training data (web-crawled pairs often noisy or biased), (3) Struggles with abstract concepts or systematic generalization (e.g., counting), (4) Poor performance on specialized domains not in pre-training (medical, satellite).*

**Q2:** Explain the difference between Flamingo-style and LLaVA-style multimodal architectures.
*A: Flamingo keeps the LLM frozen and inserts gated cross-attention layers at specific transformer depths to inject visual information. It uses a Perceiver Resampler to compress a variable number of image features into a fixed number of tokens. Only the newly added layers are trained, and the design works well for few-shot prompting. LLaVA projects vision features directly into the LLM embedding space via a simple linear layer or MLP, treating image patches as "foreign language" tokens, and then fine-tunes the LLM (fully or with LoRA). It is simpler and easier to train, but fine-tuning may disrupt pre-trained LLM knowledge if done carelessly. Flamingo better preserves LLM capabilities; LLaVA is more scalable for instruction tuning.*

**Q3:** How do you handle different sequence lengths (e.g., 50 vision tokens vs. 512 text tokens) in multimodal transformers?
*A: (1) Projection to same dimension then concatenation along sequence dimension, (2) Separate position embeddings for each modality (vision positions 0-49, text 50-561), (3) Modality-specific segment embeddings added to indicate source, (4) Attention masks ensure causality only within text (vision can attend bidirectionally), (5) Perceiver resampler to compress vision to fixed small number of tokens (e.g., 64) regardless of image resolution, simplifying sequence management.*
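
Points (2) and (4) can be sketched as follows: contiguous position IDs across the concatenated sequence, and a prefix-style attention mask in which vision tokens see each other bidirectionally while text tokens see all vision tokens plus earlier text. The function name and the boolean-matrix representation are illustrative assumptions; `True` means "query may attend to key".

```python
def build_position_ids_and_mask(num_vision, num_text):
    """Return (position_ids, mask) for a [vision; text] sequence."""
    total = num_vision + num_text
    position_ids = list(range(total))
    mask = [
        [
            # Every query sees all vision tokens; otherwise attention is causal
            key < num_vision or key <= query
            for key in range(total)
        ]
        for query in range(total)
    ]
    return position_ids, mask

position_ids, mask = build_position_ids_and_mask(num_vision=2, num_text=3)
for row in mask:
    print("".join("x" if allowed else "." for allowed in row))
# xx...
# xx...
# xxx..
# xxxx.
# xxxxx
```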

**Q4:** What is the "curse of multimodality" in training, and how do you mitigate it?
*A: Different modalities have different scales, convergence rates, and optimal hyperparameters. Vision often converges slower than text; audio has different spectral characteristics. Mitigation: (1) Modality-specific optimizers or learning rates, (2) Gradient clipping per modality, (3) Balanced sampling ensuring equal representation, (4) Frozen encoders for majority of training then gradual unfreezing, (5) Normalization layers per modality before fusion, (6) Contrastive losses balance gradient contributions.*

**Q5:** Design a system to detect inconsistency between video and audio (e.g., lip-sync deepfake detection).
*A: Architecture: Two-stream network processing video frames (mouth region) and audio spectrograms. Options: (1) Late fusion: Embed both with modality-specific encoders, compute cosine similarity (should be high for real, low for fake), (2) Early fusion: Concatenate features at multiple time scales, train binary classifier, (3) SyncNet-style: Contrastive learning where aligned AV pairs are positive, temporally shifted are negative. Key: Temporal alignment is critical—use audio features shifted by video latency (typically 0-200ms). Augmentation: Train with various audio delays to learn robustness.*

---

## **27.10 Further Reading**

**Papers:**
- "Learning Transferable Visual Models From Natural Language Supervision" (CLIP, Radford et al., 2021)
- "Flamingo: A Visual Language Model for Few-Shot Learning" (Alayrac et al., 2022)
- "Visual Instruction Tuning" (LLaVA, Liu et al., 2023)
- "Robust Speech Recognition via Large-Scale Weak Supervision" (Whisper, Radford et al., 2022)
- "ImageBind: One Embedding Space To Bind Them All" (Girdhar et al., 2023)

**Tools:**
- **Hugging Face Transformers:** CLIP, LLaVA, Whisper implementations
- **OpenCLIP:** Open-source CLIP training code
- **Salesforce LAVIS:** Library for vision-language tasks

---

## **27.11 Checkpoint Project: Multimodal Document AI**

Build a system to understand and reason about PDF documents containing text, figures, and tables.

**Requirements:**

1. **Document Parsing:**
   - Extract text (OCR if needed), images, and layout information
   - Use LayoutLM or custom detection model for region classification (text block, figure, table)

2. **Multimodal Encoding:**
   - Text: Standard transformer tokenizer
   - Figures: CLIP or fine-tuned ResNet visual encoder
   - Layout: 2D positional embeddings (bounding box coordinates)

3. **Reasoning Engine:**
   - Fine-tuned multimodal LLM (LLaVA-style) for document QA
   - Support questions requiring joint reasoning over text and figures ("What does Figure 3 indicate about the trend in Table 1?")

4. **Retrieval:**
   - Index documents by multimodal embeddings
   - Cross-modal search: Find documents containing specific visual concepts

5. **Evaluation:**
   - Dataset: DocVQA or custom annotated technical manuals
   - Metrics: ANLS (Average Normalized Levenshtein Similarity) for text answers, accuracy for multiple choice
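
For the evaluation step, ANLS can be sketched as below. This follows the usual definition: for each question, take the best score over the accepted answers, where the score is one minus the normalized edit distance, set to zero when that distance reaches the 0.5 threshold. Lowercasing and whitespace stripping are the only normalization applied here, which is an assumption rather than the official evaluation script.

```python
def levenshtein(a, b):
    """Character-level edit distance between strings a and b."""
    prev = list(range(len(b) + 1))
    for i, char_a in enumerate(a, start=1):
        cur = [i]
        for j, char_b in enumerate(b, start=1):
            cur.append(min(prev[j] + 1, cur[j - 1] + 1, prev[j - 1] + (char_a != char_b)))
        prev = cur
    return prev[-1]

def anls(predictions, ground_truths, threshold=0.5):
    """predictions: list of str; ground_truths: list of lists of accepted answers."""
    total = 0.0
    for pred, answers in zip(predictions, ground_truths):
        best = 0.0
        p = pred.strip().lower()
        for answer in answers:
            a = answer.strip().lower()
            nl = levenshtein(p, a) / max(len(p), len(a), 1)
            best = max(best, 1 - nl if nl < threshold else 0.0)
        total += best
    return total / len(predictions)

score = anls(["2018", "Acme Corp"], [["2019"], ["ACME Corp", "Acme Corporation"]])
print(score)  # (0.75 + 1.0) / 2 = 0.875
```

The threshold means near-misses such as OCR slips earn partial credit, while answers that are mostly wrong score zero.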

**Deliverables:**
- `document_ai/` pipeline with OCR, layout analysis, and QA
- Fine-tuned model checkpoint
- Evaluation report on document understanding benchmark
- Demo: Interactive QA over technical manual PDF

**Success Criteria:**
- Answer questions requiring joint text+figure reasoning with >70% accuracy
- Successfully retrieve documents based on visual content descriptions
- Handle multi-page documents with cross-page references

---

**End of Chapter 27**

*You now master multimodal AI systems. Chapter 28 covers AI Safety, Alignment, and Robustness.*

