# Lab 2.4.1: Mamba Inference

**Module:** 2.4 - Efficient Architectures  
**Time:** 2 hours  
**Difficulty:** ⭐⭐⭐ (Intermediate)

---

## 🎯 Learning Objectives

By the end of this lab, you will:
- [ ] Understand why Mamba's O(n) complexity matters for long sequences
- [ ] Load and run Mamba models using HuggingFace
- [ ] Benchmark Mamba against transformers on speed and memory
- [ ] See Mamba's advantage scale with context length

---

## 📚 Prerequisites

- Completed: Module 2.3 (NLP & Transformers)
- Knowledge of: Basic transformer architecture, attention mechanism
- Hardware: DGX Spark with 128GB unified memory (or GPU with 16GB+)

---

## 🌍 Real-World Context

**The Long Context Problem**

Imagine you're building an AI assistant that needs to:
- Analyze entire codebases (100K+ tokens)
- Read and summarize legal documents (50+ pages)
- Process hour-long meeting transcripts

Traditional transformers struggle here because their attention mechanism is O(n²): doubling the context length quadruples the attention computation! A 32K context produces over **1 billion** (32,768² ≈ 1.07 billion) attention scores per head in every layer.

**Enter Mamba**: A new architecture with O(n) complexity. Process twice as much text with only twice the compute. On DGX Spark's 128GB, this means processing 100K+ token contexts that would crash consumer GPUs.
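
A quick sanity check of the arithmetic behind these claims. This is a minimal sketch: it counts query-key score pairs for a single head in a single layer and ignores causal masking, which roughly halves the count.

```python
def attention_pairs(n: int) -> int:
    """Query-key scores for one head in one layer: every token scores every token."""
    return n * n

def linear_scan_steps(n: int) -> int:
    """A linear-time recurrence does one state update per token."""
    return n

for n in (8_192, 16_384, 32_768):
    print(f"{n:>6,} tokens: {attention_pairs(n):>13,} attention pairs, "
          f"{linear_scan_steps(n):>6,} scan steps")
```

Each doubling of `n` multiplies the attention pairs by four but the scan steps only by two.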

**Companies using long-context models:**
- Google (Gemini 1M context)
- Anthropic (Claude 200K context)
- AI21 Labs (Jamba hybrid architecture)

---

## 🧒 ELI5: Understanding Mamba

> **Imagine you're reading a very long book...**
>
> **Transformer approach**: At each word, you flip back through ALL previous pages to understand context. Page 1, page 2, page 3... For a 500-page book, you'd flip through 500 pages at EVERY word. That's exhausting!
>
> **Mamba approach**: You read like a human, one word at a time, keeping a mental "summary" of what came before. You don't flip back; you just update your summary as you go. Reading page 500 is just as easy as reading page 5!
>
> **In AI terms**: 
> - Transformers use "attention", which looks at all previous tokens (O(n²) compute, plus a KV cache that grows linearly with context)
> - Mamba uses a "state space" that compresses history into a fixed-size state (O(1) memory!)
> - This means Mamba can carry a 100,000-word document in the same state memory as a 1,000-word document

### The Key Insight: Selective State Spaces

Mamba doesn't just blindly compress; it **selects** what's important:

```
Traditional RNN: state = fixed_function(state, input)
Mamba:          state = learned_function(state, input, context)
                        ^^^^^^^^^^^^^^^^
                        The "selective" part!
```

Think of it like taking notes: Mamba learns WHAT to write down based on what it's reading.
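
The selective recurrence can be sketched in a few lines of NumPy. This is a toy, sequential illustration under simplifying assumptions: the projection names `W_delta`, `W_B`, and `W_C` are hypothetical, the input term uses a simple Δ·B discretization, and there is no convolution, gating, or hardware-aware parallel scan as in the real model. The point is that Δ, B, and C are computed from each input token (the "selective" part), while the state `h` keeps a fixed shape no matter how long the sequence is.

```python
import numpy as np

def selective_scan(x, A, W_delta, W_B, W_C):
    """Toy sequential selective SSM (illustration only, not the optimized kernel).

    x:        (seq_len, d) input sequence
    A:        (d, n) negative decay rates (learned, input-independent)
    W_delta:  (d, d) hypothetical projection giving per-channel step sizes
    W_B, W_C: (d, n) hypothetical projections giving input-dependent B_t, C_t
    Returns outputs y of shape (seq_len, d) and the final state h of shape (d, n).
    """
    seq_len, d = x.shape
    h = np.zeros((d, A.shape[1]))              # fixed-size state: no seq_len dimension
    ys = []
    for t in range(seq_len):
        xt = x[t]                                # (d,)
        delta = np.logaddexp(0.0, xt @ W_delta)  # softplus -> positive step sizes, (d,)
        B = xt @ W_B                             # (n,) input decides what to write
        C = xt @ W_C                             # (n,) input decides what to read
        A_bar = np.exp(delta[:, None] * A)       # (d, n) per-step decay in (0, 1)
        h = A_bar * h + (delta[:, None] * B[None, :]) * xt[:, None]
        ys.append(h @ C)                         # (d,) readout
    return np.stack(ys), h

rng = np.random.default_rng(0)
d, n = 4, 8
A = -np.exp(rng.normal(size=(d, n)))             # negative -> state decays over time
W_delta, W_B, W_C = (rng.normal(size=s) * 0.1 for s in [(d, d), (d, n), (d, n)])
for L in (16, 1024):
    y, h = selective_scan(rng.normal(size=(L, d)), A, W_delta, W_B, W_C)
    print(L, y.shape, h.shape)                   # h stays (4, 8) for both lengths
```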

---

## Part 1: Environment Setup

Let's verify our environment and understand our hardware.

In [None]:
# Check environment
import sys
print(f"Python: {sys.version}")

# Check transformers version (this lab assumes >= 4.46.0)
import transformers
from packaging import version  # installed alongside transformers
print(f"Transformers: {transformers.__version__}")

# Verify minimum version (packaging also handles tags like "4.47.0.dev0")
min_version = "4.46.0"
if version.parse(transformers.__version__) < version.parse(min_version):
    print(f"⚠️  This lab requires transformers >= {min_version}")
    print('   Run: pip install --upgrade "transformers>=4.46.0"')
else:
    print("✅ Transformers version OK for Mamba")

import torch
print(f"\nPyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name()}")
    total_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU Memory: {total_mem:.1f} GB")
    
    # DGX Spark detection
    if total_mem > 100:
        print(f"\nüöÄ DGX Spark detected! 128GB unified memory available.")
        print(f"   You can run very long context experiments!")
    else:
        print(f"\nüí° Tip: Reduce context lengths if you run out of memory.")

In [None]:
# Upgrade transformers if needed (uncomment to run; the quotes stop the shell from treating >= as a redirect)
# !pip install --upgrade "transformers>=4.46.0"

In [None]:
# Import required libraries
import gc
import time
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, List, Tuple

from transformers import AutoModelForCausalLM, AutoTokenizer

# Set default device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Helper function to clear GPU memory
def clear_gpu_memory():
    """Run garbage collection, then release cached GPU memory (no-op on CPU)."""
    gc.collect()  # drop unreachable Python objects first so their tensors can be freed
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        print(f"GPU memory cleared. Allocated: {torch.cuda.memory_allocated()/1e9:.2f}GB")

---

## Part 2: Loading Mamba Models

### Available Mamba Models

| Model | Parameters | Memory (BF16) | DGX Spark Fit? |
|-------|------------|---------------|----------------|
| mamba-130m | 130M | ~260 MB | ✅ Easily |
| mamba-370m | 370M | ~740 MB | ✅ Easily |
| mamba-790m | 790M | ~1.6 GB | ✅ Easily |
| mamba-1.4b | 1.4B | ~2.8 GB | ✅ Easily |
| mamba-2.8b | 2.8B | ~5.6 GB | ✅ Easily |
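
The BF16 column is simply parameter count × 2 bytes. A minimal sketch of that arithmetic, assuming the nominal parameter counts from the table; it covers weights only and ignores activations, the CUDA context, and framework overhead, so real usage is somewhat higher:

```python
# Bytes per parameter for common dtypes
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2, "float16": 2, "int8": 1}

def weight_memory_gb(num_params: float, dtype: str = "bfloat16") -> float:
    """Approximate weight memory in GB (1 GB = 1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9

for name, params in [("mamba-130m", 130e6), ("mamba-1.4b", 1.4e9), ("mamba-2.8b", 2.8e9)]:
    print(f"{name}: {weight_memory_gb(params):.2f} GB in bf16, "
          f"{weight_memory_gb(params, 'float32'):.2f} GB in fp32")
```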

Let's load Mamba-2.8B, the largest checkpoint in the original `state-spaces` Mamba release.

In [None]:
# Load Mamba model
# Using the HuggingFace version (state-spaces/mamba-2.8b-hf)

MODEL_NAME = "state-spaces/mamba-2.8b-hf"
# For faster testing, try: "state-spaces/mamba-130m-hf"

print(f"Loading {MODEL_NAME}...")
print("This may take a minute on first run (downloading weights)...\n")

# Clear memory first
clear_gpu_memory()

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
print(f"✅ Tokenizer loaded. Vocab size: {tokenizer.vocab_size:,}")

# Load model in bfloat16 (natively supported by DGX Spark's Blackwell GPU)
mamba_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,  # 2 bytes per parameter
    device_map="auto",
)

# Report stats
num_params = sum(p.numel() for p in mamba_model.parameters())
memory_used = torch.cuda.memory_allocated() / 1e9

total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9

print(f"\n✅ Model loaded!")
print(f"   Parameters: {num_params/1e9:.2f}B")
print(f"   GPU Memory: {memory_used:.2f} GB")
print(f"   Headroom: {total_gb - memory_used:.1f} GB of {total_gb:.0f} GB remaining")

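### 🔍 What Just Happened?

We loaded a 2.8-billion-parameter Mamba model in ~5.6 GB (BF16, 2 bytes per parameter). Notice:
- **No attention layers** = no KV cache that grows with context
- **Fixed-size recurrent state** = the per-layer state carried between generated tokens doesn't grow with context (prefill activations still scale with prompt length)
- **Massive headroom** = on a 128GB DGX Spark, roughly 128 - 5.6 ≈ 122 GB remains for inference (shared with the CPU in unified memory)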

This headroom is crucial for long-context inference.

In [None]:
# Quick test: Generate some text
prompt = "The key insight of Mamba over transformers is"

inputs = tokenizer(prompt, return_tensors="pt").to(device)

print(f"Prompt: {prompt}")
print(f"Input tokens: {inputs['input_ids'].shape[1]}")
print("\nGenerating...")

with torch.no_grad():
    outputs = mamba_model.generate(
        **inputs,
        max_new_tokens=100,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
    )

generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"\nü§ñ Generated text:\n{generated}")

---

## Part 3: Understanding Mamba's Memory Advantage

Let's visualize why Mamba's constant memory matters.

### Theoretical Memory Comparison

```
Transformer Memory = Model + KV Cache
                   = Model + 2 × layers × heads × context × head_dim × bytes_per_elem

Mamba Memory       = Model + State
                   = Model + layers × d_inner × state_dim × bytes_per_elem
                     (d_inner = expand × d_model; a small convolution buffer is ignored)

Key difference: the KV cache scales with context; Mamba's state doesn't!
```
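
Plugging numbers into these formulas makes the difference concrete. A minimal sketch, assuming the same transformer defaults the next cell uses (32 layers, 32 heads, 128-dim heads, 2-byte BF16 elements) and a Mamba-2.8B-like configuration (64 layers, `d_model` 2560, state dimension 16). The expansion factor of 2 and the 2-byte state are assumptions, and the small convolution buffer is ignored:

```python
def kv_bytes_per_token(num_layers=32, num_heads=32, head_dim=128, bytes_per_elem=2):
    # K and V tensors in every layer, each num_heads x head_dim per token
    return 2 * num_layers * num_heads * head_dim * bytes_per_elem

def mamba_state_bytes(num_layers=64, d_model=2560, expand=2, state_dim=16, bytes_per_elem=2):
    # One (d_inner x state_dim) SSM state per layer, independent of context length
    d_inner = expand * d_model
    return num_layers * d_inner * state_dim * bytes_per_elem

per_token = kv_bytes_per_token()
print(f"KV cache:    {per_token / 1e6:.3f} MB per token")               # 0.524 MB
print(f"  at 32K:    {per_token * 32_768 / 1e9:.1f} GB")                # 17.2 GB
print(f"Mamba state: {mamba_state_bytes() / 1e6:.1f} MB at any length")  # 10.5 MB
```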

In [None]:
# Visualize memory scaling (theoretical)

def calculate_transformer_memory(context_length: int, 
                                  model_memory_gb: float = 6.0,
                                  num_layers: int = 32,
                                  num_heads: int = 32,
                                  head_dim: int = 128,
                                  precision_bytes: int = 2) -> float:
    """
    Calculate approximate transformer memory including KV cache.
    
    KV cache = 2 (K and V) × layers × heads × context × head_dim × precision
    """
    kv_cache_bytes = 2 * num_layers * num_heads * context_length * head_dim * precision_bytes
    kv_cache_gb = kv_cache_bytes / 1e9
    return model_memory_gb + kv_cache_gb

def calculate_mamba_memory(context_length: int,
                           model_memory_gb: float = 5.6,
                           num_layers: int = 64,
                           state_dim: int = 16,
                           d_model: int = 2560,
                           expand: int = 2,
                           precision_bytes: int = 2) -> float:
    """
    Calculate Mamba memory (constant regardless of context!).
    
    State memory = layers × d_inner × state_dim × precision, with d_inner = expand × d_model.
    context_length is accepted only for symmetry with the transformer function;
    it does not affect the result. The small convolution buffer is ignored.
    """
    d_inner = expand * d_model
    state_bytes = num_layers * d_inner * state_dim * precision_bytes
    state_gb = state_bytes / 1e9
    return model_memory_gb + state_gb  # State is ~10 MB: tiny!

# Generate comparison data
context_lengths = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]

transformer_memory = [calculate_transformer_memory(ctx) for ctx in context_lengths]
mamba_memory = [calculate_mamba_memory(ctx) for ctx in context_lengths]

# Plot
fig, ax = plt.subplots(figsize=(10, 6))

ax.plot(context_lengths, transformer_memory, 'o-', label='Transformer (6 GB weights + KV cache)', 
        linewidth=2, markersize=8, color='#E74C3C')
ax.plot(context_lengths, mamba_memory, 's-', label='Mamba (2.8B)', 
        linewidth=2, markersize=8, color='#27AE60')

# DGX Spark limit
ax.axhline(y=128, color='#3498DB', linestyle='--', linewidth=2, label='DGX Spark (128GB)')
ax.axhline(y=24, color='#9B59B6', linestyle=':', linewidth=2, label='Consumer GPU (24GB)')

ax.set_xscale('log', base=2)
ax.set_yscale('log')
ax.set_xlabel('Context Length (tokens)', fontsize=12)
ax.set_ylabel('GPU Memory (GB)', fontsize=12)
ax.set_title('Memory Scaling: Transformer vs Mamba', fontsize=14, fontweight='bold')
ax.legend(loc='upper left', fontsize=10)
ax.grid(True, alpha=0.3)

# Annotations
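# Point at 64K, where the computed transformer memory (~40 GB) exceeds a 24GB card
ax.annotate('Transformer OOM\non consumer GPU!', 
            xy=(65536, calculate_transformer_memory(65536)), fontsize=10, color='#E74C3C',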
            arrowprops=dict(arrowstyle='->', color='#E74C3C'),
            xytext=(8192, 60))

ax.annotate('Mamba: constant memory!', 
            xy=(65536, 5.7), fontsize=10, color='#27AE60',
            arrowprops=dict(arrowstyle='->', color='#27AE60'),
            xytext=(32768, 12))

plt.tight_layout()
plt.show()

print("\nüìä Memory at 64K context:")
print(f"   Transformer: {calculate_transformer_memory(65536):.1f} GB")
print(f"   Mamba:       {calculate_mamba_memory(65536):.1f} GB")
print(f"   Savings:     {calculate_transformer_memory(65536) - calculate_mamba_memory(65536):.1f} GB!")

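### 🔍 What This Means

Under the assumptions in the cell above (6 GB of weights; 32 layers × 32 heads × 128-dim heads; BF16 KV cache; batch size 1), the graph shows Mamba's killer advantage:
- **At 8K tokens**: Transformer uses ~10 GB, Mamba uses ~5.6 GB
- **At 32K tokens**: Transformer uses ~23 GB, Mamba STILL uses ~5.6 GB
- **At 128K tokens**: Transformer needs ~75 GB, Mamba STILL uses ~5.6 GB!

In memory terms:
- The transformer's KV cache outgrows a 24GB consumer GPU at roughly 34K tokens and would fill 128GB at roughly 230K tokens; real limits are lower once activations and attention workspace are counted
- Mamba's recurrent state stays flat, so very long prompts are bounded by compute time and prefill activations rather than by a growing cache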
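
Token limits like these come from inverting the KV-cache formula: subtract the weights from the memory budget and divide by the KV bytes per token. A minimal sketch using the same assumptions as the plot (6 GB of weights, 524,288 KV bytes per token for 32 layers × 32 heads × 128-dim heads in BF16, batch size 1); it ignores activations and attention workspace, so it gives an upper bound:

```python
def max_context_tokens(budget_gb: float, weights_gb: float = 6.0,
                       kv_bytes_per_token: int = 2 * 32 * 32 * 128 * 2) -> int:
    """Largest context whose KV cache fits next to the weights (upper bound)."""
    free_bytes = (budget_gb - weights_gb) * 1e9
    return max(0, int(free_bytes // kv_bytes_per_token))

for budget_gb in (24, 128):
    print(f"{budget_gb:>3} GB -> ~{max_context_tokens(budget_gb):,} tokens")
# 24 GB -> ~34,332 tokens; 128 GB -> ~232,696 tokens
```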

---

## Part 4: Benchmarking Mamba Inference

Let's measure actual performance across different context lengths.

In [None]:
def benchmark_generation(
    model,
    tokenizer,
    context_length: int,
    generation_length: int = 50,
    warmup_runs: int = 2,
    benchmark_runs: int = 3,
) -> Dict:
    """
    Benchmark model generation at a specific context length.
    
    Returns dict with timing and memory stats.
    """
    # Create input of specified length
    # Use a repeating sentence pattern for realistic tokens
    base_text = "The quick brown fox jumps over the lazy dog. " * 50
    tokens = tokenizer.encode(base_text, add_special_tokens=False)
    while len(tokens) < context_length:
        tokens = tokens + tokens
    tokens = tokens[:context_length]
    
    input_ids = torch.tensor([tokens], device=device)
    
    # Warmup
    for _ in range(warmup_runs):
        with torch.no_grad():
            _ = model.generate(
                input_ids,
                max_new_tokens=5,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
            )
    
    # Reset memory tracking
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    
    # Benchmark runs
    times = []
    for _ in range(benchmark_runs):
        torch.cuda.synchronize()
        start = time.perf_counter()
        
        with torch.no_grad():
            outputs = model.generate(
                input_ids,
                max_new_tokens=generation_length,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
            )
        
        torch.cuda.synchronize()
        times.append(time.perf_counter() - start)
    
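    avg_time = np.mean(times)
    peak_memory = torch.cuda.max_memory_allocated() / 1e9
    # Greedy decoding can stop early at EOS, so count the tokens actually produced
    new_tokens = outputs.shape[1] - input_ids.shape[1]
    
    return {
        "context_length": context_length,
        "generation_length": new_tokens,
        "avg_time_seconds": avg_time,
        # End-to-end rate: includes prefill of the full context, not just decoding
        "tokens_per_second": new_tokens / avg_time,
        "peak_memory_gb": peak_memory,
        "times": times,
    }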

In [None]:
# Run benchmarks across context lengths
# Adjust these based on your available memory

# For DGX Spark (128GB): can go up to 100K+
# For 24GB GPU: limit to 16K
# For 8GB GPU: limit to 4K

total_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9

if total_memory_gb > 100:  # DGX Spark
    test_contexts = [1024, 4096, 8192, 16384, 32768, 65536]
elif total_memory_gb > 20:  # RTX 3090/4090
    test_contexts = [1024, 4096, 8192, 16384]
else:  # Smaller GPU
    test_contexts = [1024, 2048, 4096]

print(f"Benchmarking Mamba at context lengths: {test_contexts}")
print(f"GPU Memory: {total_memory_gb:.1f} GB\n")

mamba_results = []

for ctx_len in test_contexts:
    print(f"Testing context length: {ctx_len:,} tokens...")
    try:
        result = benchmark_generation(mamba_model, tokenizer, ctx_len)
        mamba_results.append(result)
        print(f"  ✅ {result['tokens_per_second']:.1f} tokens/sec, {result['peak_memory_gb']:.2f} GB peak")
    except RuntimeError as e:
        if "out of memory" in str(e).lower():
            print(f"  ❌ Out of memory at {ctx_len:,} tokens")
            break
        else:
            raise e

print("\n✅ Benchmarks complete!")

In [None]:
# Visualize Mamba benchmark results

if mamba_results:
    contexts = [r['context_length'] for r in mamba_results]
    speeds = [r['tokens_per_second'] for r in mamba_results]
    memories = [r['peak_memory_gb'] for r in mamba_results]
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    
    # Speed plot
    ax1.bar(range(len(contexts)), speeds, color='#27AE60', alpha=0.8)
    ax1.set_xticks(range(len(contexts)))
    ax1.set_xticklabels([f'{c//1024}K' for c in contexts], fontsize=10)
    ax1.set_xlabel('Context Length', fontsize=12)
    ax1.set_ylabel('Tokens per Second', fontsize=12)
    ax1.set_title('Mamba Generation Speed', fontsize=14, fontweight='bold')
    ax1.grid(True, axis='y', alpha=0.3)
    
    # Add value labels on bars
    for i, v in enumerate(speeds):
        ax1.text(i, v + 1, f'{v:.1f}', ha='center', fontsize=9)
    
    # Memory plot
    ax2.bar(range(len(contexts)), memories, color='#3498DB', alpha=0.8)
    ax2.set_xticks(range(len(contexts)))
    ax2.set_xticklabels([f'{c//1024}K' for c in contexts], fontsize=10)
    ax2.set_xlabel('Context Length', fontsize=12)
    ax2.set_ylabel('Peak Memory (GB)', fontsize=12)
    ax2.set_title('Mamba Memory Usage', fontsize=14, fontweight='bold')
    ax2.grid(True, axis='y', alpha=0.3)
    
    # Highlight the key insight: memory barely changes!
    if len(memories) > 1:
        memory_increase = memories[-1] - memories[0]
        ax2.annotate(f'Only +{memory_increase:.2f}GB\nfrom {contexts[0]//1024}K to {contexts[-1]//1024}K!',
                    xy=(len(contexts)-1, memories[-1]),
                    xytext=(len(contexts)-2, memories[-1] + 2),
                    fontsize=10, color='#E74C3C',
                    arrowprops=dict(arrowstyle='->', color='#E74C3C'))
    
    plt.tight_layout()
    plt.show()
    
    # Print summary
    print("\nüìä Mamba Benchmark Summary:")
    print("-" * 50)
    print(f"{'Context':<12} {'Speed (tok/s)':<15} {'Memory (GB)':<12}")
    print("-" * 50)
    for r in mamba_results:
        print(f"{r['context_length']:,} tokens  {r['tokens_per_second']:<15.1f} {r['peak_memory_gb']:<12.2f}")
else:
    print("No benchmark results to display")

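### 🔍 Key Observations

Notice in the benchmarks:

1. **Memory grows only slowly** - Whether processing 1K or 64K tokens, Mamba's peak memory stays close to its weight footprint. Its recurrent state is O(1) in context length, whereas a transformer's KV cache is O(n) (and its attention compute is O(n²)). The growth you do see comes from prefill activations, which scale with prompt length.

2. **Per-token decoding cost is flat** - Each new token costs the same state update regardless of context length (unlike transformers, where a longer KV cache means slower generation). The tokens/sec figures above are end-to-end and include prefill, so they can still fall as the prompt gets longer.

3. **DGX Spark advantage** - With 128GB, you can process contexts that would exhaust the memory of consumer GPUs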
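
Because the tokens/sec numbers above fold prefill into the total, a second timing with `max_new_tokens=1` lets you separate the two. A minimal sketch of the arithmetic; the function name is hypothetical and the timings in the usage line are made-up inputs, not measurements:

```python
def split_prefill_decode(ttft_s: float, total_s: float, new_tokens: int) -> dict:
    """Separate prefill from decoding using two timings of the same prompt.

    ttft_s:   seconds for generate(..., max_new_tokens=1)  (prefill + first token)
    total_s:  seconds for generate(..., max_new_tokens=new_tokens)
    """
    if new_tokens < 2 or total_s <= ttft_s:
        raise ValueError("need >= 2 new tokens and total_s > ttft_s")
    decode_s = total_s - ttft_s  # time spent producing tokens 2..new_tokens
    return {
        "ttft_s": ttft_s,
        "decode_tokens_per_s": (new_tokens - 1) / decode_s,
        "end_to_end_tokens_per_s": new_tokens / total_s,  # includes prefill
    }

print(split_prefill_decode(ttft_s=2.0, total_s=4.5, new_tokens=51))
```

For Mamba, `decode_tokens_per_s` should stay roughly flat as context grows, while `ttft_s` rises with prompt length.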

---

## Part 5: Loading a Transformer for Comparison

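To truly appreciate Mamba, let's compare with a similar-sized transformer. Note that `mamba_model` stays loaded, so the transformer's peak-memory numbers below also include Mamba's ~5.6 GB of weights: compare how memory *grows* with context rather than the absolute values. Also, phi-2 was trained with a 2,048-token context, so runs beyond that measure speed and memory, not output quality.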

In [None]:
# Load a comparable transformer model
# Using a model of similar size to Mamba-2.8B for a like-for-like comparison

TRANSFORMER_MODEL = "microsoft/phi-2"  # 2.7B parameters
# Alternatives: "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "facebook/opt-2.7b"

print(f"Loading transformer: {TRANSFORMER_MODEL}...")
print("(This is for comparison purposes)\n")

transformer_tokenizer = AutoTokenizer.from_pretrained(TRANSFORMER_MODEL, trust_remote_code=True)
if transformer_tokenizer.pad_token is None:
    transformer_tokenizer.pad_token = transformer_tokenizer.eos_token

transformer_model = AutoModelForCausalLM.from_pretrained(
    TRANSFORMER_MODEL,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

t_params = sum(p.numel() for p in transformer_model.parameters())
t_memory = torch.cuda.memory_allocated() / 1e9

print(f"✅ Transformer loaded!")
print(f"   Parameters: {t_params/1e9:.2f}B")
print(f"   Transformer weights: {transformer_model.get_memory_footprint()/1e9:.2f} GB")
print(f"   GPU Memory (total, incl. Mamba): {t_memory:.2f} GB")

In [None]:
# Benchmark transformer at same context lengths

# Transformers may OOM at high context lengths
# Limit based on available memory
if total_memory_gb > 100:  # DGX Spark
    transformer_contexts = [1024, 4096, 8192, 16384, 32768]
elif total_memory_gb > 20:
    transformer_contexts = [1024, 4096, 8192]
else:
    transformer_contexts = [1024, 2048]

print(f"Benchmarking Transformer at context lengths: {transformer_contexts}\n")

transformer_results = []

for ctx_len in transformer_contexts:
    print(f"Testing context length: {ctx_len:,} tokens...")
    try:
        # Clear before each test
        torch.cuda.empty_cache()
        
        result = benchmark_generation(
            transformer_model, 
            transformer_tokenizer, 
            ctx_len,
            warmup_runs=1,  # Fewer warmups for transformer (slower)
            benchmark_runs=2,
        )
        transformer_results.append(result)
        print(f"  ✅ {result['tokens_per_second']:.1f} tokens/sec, {result['peak_memory_gb']:.2f} GB peak")
    except RuntimeError as e:
        if "out of memory" in str(e).lower():
            print(f"  ❌ Out of memory at {ctx_len:,} tokens")
            print(f"     (This is expected - transformer KV cache grows with context!)")
            break
        else:
            raise e

print("\n✅ Transformer benchmarks complete!")

In [None]:
# Side-by-side comparison visualization

if mamba_results and transformer_results:
    # Index results by context length (distinct names so the
    # transformer_contexts list from the benchmark cell is not overwritten)
    mamba_by_ctx = {r['context_length']: r for r in mamba_results}
    transformer_by_ctx = {r['context_length']: r for r in transformer_results}
    common_contexts = sorted(set(mamba_by_ctx) & set(transformer_by_ctx))
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    
    # Speed comparison
    x = np.arange(len(common_contexts))
    width = 0.35
    
    mamba_speeds = [mamba_by_ctx[c]['tokens_per_second'] for c in common_contexts]
    transformer_speeds = [transformer_by_ctx[c]['tokens_per_second'] for c in common_contexts]
    
    bars1 = ax1.bar(x - width/2, mamba_speeds, width, label='Mamba', color='#27AE60')
    bars2 = ax1.bar(x + width/2, transformer_speeds, width, label='Transformer', color='#E74C3C')
    
    ax1.set_xlabel('Context Length', fontsize=12)
    ax1.set_ylabel('Tokens per Second', fontsize=12)
    ax1.set_title('Generation Speed Comparison', fontsize=14, fontweight='bold')
    ax1.set_xticks(x)
    ax1.set_xticklabels([f'{c//1024}K' for c in common_contexts])
    ax1.legend()
    ax1.grid(True, axis='y', alpha=0.3)
    
    # Memory comparison
    mamba_memories = [mamba_by_ctx[c]['peak_memory_gb'] for c in common_contexts]
    transformer_memories = [transformer_by_ctx[c]['peak_memory_gb'] for c in common_contexts]
    
    ax2.plot(common_contexts, mamba_memories, 'o-', label='Mamba', 
             linewidth=2, markersize=10, color='#27AE60')
    ax2.plot(common_contexts, transformer_memories, 's-', label='Transformer', 
             linewidth=2, markersize=10, color='#E74C3C')
    
    ax2.set_xscale('log', base=2)
    ax2.set_xlabel('Context Length (tokens)', fontsize=12)
    ax2.set_ylabel('Peak Memory (GB)', fontsize=12)
    ax2.set_title('Memory Usage Comparison', fontsize=14, fontweight='bold')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Print comparison table
    print("\n📊 Architecture Comparison:")
    print("=" * 70)
    print(f"{'Context':<12} {'Mamba Speed':<15} {'Trans. Speed':<15} {'Mamba Mem':<12} {'Trans. Mem':<12}")
    print("-" * 70)
    for c in common_contexts:
        print(f"{c//1024}K tokens    "
              f"{mamba_by_ctx[c]['tokens_per_second']:<15.1f} "
              f"{transformer_by_ctx[c]['tokens_per_second']:<15.1f} "
              f"{mamba_by_ctx[c]['peak_memory_gb']:<12.2f} "
              f"{transformer_by_ctx[c]['peak_memory_gb']:<12.2f}")
    print("=" * 70)
else:
    print("Need both Mamba and Transformer results for comparison")

---

## ⚠️ Common Mistakes

### Mistake 1: Wrong Transformers Version
```python
# ❌ Error: Unknown model type 'mamba'
model = AutoModelForCausalLM.from_pretrained("state-spaces/mamba-2.8b-hf")

# ✅ Fix: Upgrade transformers (quote the spec so the shell doesn't treat >= as a redirect)
# pip install --upgrade "transformers>=4.46.0"
```
**Why:** Older transformers releases don't recognize the `mamba` model type; this lab assumes 4.46.0 or newer

### Mistake 2: Using Float32 Instead of BFloat16
```python
# ❌ Wastes memory (2x more than needed)
model = AutoModelForCausalLM.from_pretrained(
    "state-spaces/mamba-2.8b-hf",
    torch_dtype=torch.float32,  # 4 bytes per param
)

# ✅ Use bfloat16 (native Blackwell support)
model = AutoModelForCausalLM.from_pretrained(
    "state-spaces/mamba-2.8b-hf",
    torch_dtype=torch.bfloat16,  # 2 bytes per param
)
```
**Why:** BF16 halves weight memory compared with FP32, and DGX Spark's Blackwell GPU supports it natively

### Mistake 3: Expecting Mamba to Match Transformer Quality Everywhere
```python
# ⚠️ Mamba may not match transformers on all tasks
# Mamba excels at:
#   - Long document processing
#   - Audio/time-series
#   - Streaming inference
# Transformers still win on:
#   - Complex reasoning requiring precise attention
#   - Tasks with well-established transformer benchmarks
```
**Why:** Different architectures have different strengths

---

## 📚 Working with Long Documents: The Datasets Library

Before we practice with long documents, let's learn how to load them efficiently using the HuggingFace `datasets` library.

### The `datasets` Library

The `datasets` library provides easy access to thousands of text datasets, perfect for testing long-context models:

```python
from datasets import load_dataset

# Load a dataset with optional split specification
dataset = load_dataset("dataset_name", split="train[:10]")  # First 10 examples

# Access data like a list
text = dataset[0]["text"]  # First document's text field
```

**Key Functions:**
- `load_dataset(name, split)` - Load a dataset by name
- `split="train[:N]"` - Load first N examples from train split
- `split="test[:1]"` - Load first example from test split

**Useful Long-Document Datasets:**
- `pg19` - Project Gutenberg books (very long texts)
- `scientific_papers` - arXiv/PubMed papers
- `bookcorpus` - Book excerpts

In [None]:
# Install datasets if needed (uncomment to run)
# !pip install datasets

# Import and demonstrate the datasets library
from datasets import load_dataset

# Load a sample from Project Gutenberg (PG19) - contains full books
print("Loading a sample from PG19 (Project Gutenberg books)...")
pg19_sample = load_dataset("pg19", split="test[:1]", trust_remote_code=True)

# Access the text
sample_text = pg19_sample[0]["text"]

print(f"✅ Loaded 1 book from PG19")
print(f"   Total characters: {len(sample_text):,}")
print(f"   First 200 chars: {sample_text[:200]}...")

# Tokenize to see token count
sample_tokens = tokenizer.encode(sample_text[:50000])  # First 50K chars
print(f"\n   Tokens in first 50K chars: {len(sample_tokens):,}")

---

## 🎉 Checkpoint

You've learned:
- ✅ Why transformers have O(n²) complexity (attention looks at all pairs)
- ✅ How Mamba achieves O(n) with selective state spaces
- ✅ Loading Mamba models using HuggingFace transformers
- ✅ Benchmarking and comparing architectures
- ✅ Mamba's constant memory advantage for long contexts

---

## ✋ Try It Yourself

### Exercise 1: Long Document Processing
Load a long text document (e.g., a book chapter from Project Gutenberg) and:
1. Tokenize it and measure the token count
2. Run Mamba inference on the full document
3. Compare memory usage with what a transformer would theoretically need

<details>
<summary>üí° Hint</summary>

```python
# Load a long document
from datasets import load_dataset
pg19 = load_dataset("pg19", split="test[:1]", trust_remote_code=True)
long_text = pg19[0]["text"][:50000]  # First 50K characters

# Tokenize and check length
tokens = tokenizer.encode(long_text)
print(f"Token count: {len(tokens)}")
```
</details>

In [None]:
# Your code for Exercise 1 here
# Try loading a long document and processing it with Mamba



### Exercise 2: Context Scaling Analysis
Create a more detailed benchmark that measures:
1. Time-to-first-token (TTFT) at different context lengths
2. How throughput (tokens/sec) changes with context
3. Memory usage at each context length

Plot all three metrics.

<details>
<summary>üí° Hint</summary>

For TTFT measurement, you can use streaming generation or measure time to generate just 1 token:
```python
# Measure time to first token
start = time.perf_counter()
with torch.no_grad():
    outputs = model.generate(inputs, max_new_tokens=1, do_sample=False)
ttft = time.perf_counter() - start
```
</details>

In [None]:
# Your code for Exercise 2 here
# Create a detailed scaling analysis



---

## 🚀 Challenge (Optional)

### Advanced Challenge: Build a Long-Context Summarizer

Build a function that:
1. Takes a very long document (100K+ tokens if on DGX Spark)
2. Uses Mamba to generate a summary
3. Compares performance with a chunked-transformer approach

For the chunked approach:
- Split document into 4K chunks
- Summarize each chunk
- Combine summaries and summarize again

Compare:
- Total processing time
- Memory usage
- Quality of final summary
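
For the chunked baseline, the splitting step might look like this minimal sketch. The helper name and the optional `overlap` parameter are illustrative choices, not part of the lab: it cuts a token list into windows of at most `chunk_size` tokens, which you would then decode with `tokenizer.decode` and summarize one by one.

```python
from typing import List

def chunk_tokens(tokens: List[int], chunk_size: int = 4096, overlap: int = 0) -> List[List[int]]:
    """Split a token list into windows of at most chunk_size tokens.

    overlap > 0 repeats the last `overlap` tokens of each chunk at the start of the next,
    which can help a summarizer keep context across chunk boundaries.
    """
    if chunk_size <= 0 or not 0 <= overlap < chunk_size:
        raise ValueError("require chunk_size > 0 and 0 <= overlap < chunk_size")
    if not tokens:
        return []
    step = chunk_size - overlap
    return [tokens[i:i + chunk_size] for i in range(0, max(len(tokens) - overlap, 1), step)]

doc = list(range(10_000))                  # stand-in for tokenizer.encode(long_text)
chunks = chunk_tokens(doc, chunk_size=4096)
print([len(c) for c in chunks])            # [4096, 4096, 1808]
```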

In [None]:
# Your advanced challenge code here



---

## 📖 Further Reading

- [Mamba Paper](https://arxiv.org/abs/2312.00752) - "Mamba: Linear-Time Sequence Modeling with Selective State Spaces"
- [Mamba-2 Paper](https://arxiv.org/abs/2405.21060) - Structured state space duality, with a core layer 2-8× faster than Mamba's selective scan
- [HuggingFace Mamba Guide](https://huggingface.co/docs/transformers/model_doc/mamba)
- [State Spaces Explained](https://srush.github.io/annotated-s4/) - Annotated S4 (Mamba's predecessor)
- [The Mamba Repository](https://github.com/state-spaces/mamba) - Official implementation

---

## 🧹 Cleanup

In [None]:
# Clean up GPU memory
print("Cleaning up...")

# Delete models
if 'mamba_model' in dir():
    del mamba_model
if 'transformer_model' in dir():
    del transformer_model
if 'tokenizer' in dir():
    del tokenizer
if 'transformer_tokenizer' in dir():
    del transformer_tokenizer

# Free Python references first, then release cached GPU memory
import gc
gc.collect()
torch.cuda.empty_cache()

print(f"GPU memory after cleanup: {torch.cuda.memory_allocated()/1e9:.2f} GB")
print("\n✅ Cleanup complete!")

---

## 🏁 Summary

In this lab, you learned:

| Concept | Key Takeaway |
|---------|-------------|
| Mamba Architecture | Selective State Space = O(n) complexity |
| Memory Advantage | Constant-size recurrent state regardless of context length |
| Speed | Constant per-token decoding cost; prefill time grows linearly with context |
| DGX Spark | 128GB enables 100K+ token contexts |
| Use Cases | Long documents, streaming, audio/time-series |

**Next:** In Lab 2.4.2, we'll dive deeper into Mamba's architecture and visualize the selective scan mechanism!