# Lab 2.4.1: Mamba Inference - SOLUTIONS

This notebook contains complete solutions to the exercises from Lab 2.4.1.

---

## Setup (Same as Lab)

In [None]:
import torch
import time
import gc
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

# Prefer the GPU when one is visible; later cells reuse this `device` object.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Seed PyTorch once, up front: later cells draw random token ids and sample
# during generation, so a fixed seed makes Restart & Run All reproducible.
SEED = 42
torch.manual_seed(SEED)

# Load Mamba model
MODEL_NAME = "state-spaces/mamba-2.8b-hf"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# generate() is later called with pad_token_id=tokenizer.pad_token_id, so make
# sure one exists by falling back to the EOS token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
print("Model loaded!")

---

## Exercise 1 Solution: Long Document Processing

In [None]:
# Solution, step 1: fetch one long book and find out how many tokens it is.

print("Loading PG19 dataset...")
# split="test[:1]" asks for only the first record of the test split.
pg19 = load_dataset("pg19", split="test[:1]")

# Work with just the opening 100K characters of that book.
first_book = pg19[0]["text"]
long_text = first_book[:100000]
print(f"\nDocument length: {len(long_text):,} characters")

# Encode without special tokens so the count covers the document text only.
tokens = tokenizer.encode(long_text, add_special_tokens=False)
print(f"Token count: {len(tokens):,} tokens")
# Calculate theoretical memory for transformer
def transformer_memory(context_length, hidden_size=4096, num_layers=32,
                       num_heads=32, bytes_per_element=2,
                       num_kv_heads=None, batch_size=1):
    """Estimate the KV-cache size of a transformer in GB (1 GB = 1e9 bytes).

    Only the key/value cache is counted; weights and activations are not.

    Args:
        context_length: Number of cached token positions per sequence.
        hidden_size: Model width; split evenly across ``num_heads``.
        num_layers: Number of layers that keep a K and a V tensor.
        num_heads: Number of attention (query) heads, used to derive the
            per-head dimension.
        bytes_per_element: Storage size of one cached value (2 for
            fp16/bf16, 4 for fp32).
        num_kv_heads: Number of key/value heads. ``None`` (default) means
            one K/V head per attention head, i.e. standard multi-head
            attention; pass a smaller value for grouped- or multi-query
            attention.
        batch_size: Number of sequences cached at the same time.

    Returns:
        float: Estimated cache size in GB.
    """
    head_dim = hidden_size // num_heads
    if num_kv_heads is None:
        num_kv_heads = num_heads
    # 2 tensors (K and V) per layer, one head_dim vector per KV head per token.
    kv_cache_bytes = (
        2 * num_layers * batch_size * context_length
        * num_kv_heads * head_dim * bytes_per_element
    )
    return kv_cache_bytes / 1e9  # GB

# Run Mamba inference
print(f"\nRunning Mamba inference on {len(tokens):,} tokens...")

# Truncate if needed for memory
max_tokens = min(len(tokens), 32768)  # 32K limit for demo
input_ids = torch.tensor([tokens[:max_tokens]], device=device)

# CUDA memory statistics only exist on a GPU, and `device` may be the CPU.
cuda_stats = device.type == "cuda"
if cuda_stats:
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    # Memory already resident before the forward pass (model weights etc.),
    # recorded so the cost of the pass itself can be reported separately.
    baseline_bytes = torch.cuda.memory_allocated()

# Measure memory during forward pass
with torch.no_grad():
    forward_output = model(input_ids)
# Drop the output immediately: a lingering module-level reference would keep
# it allocated and inflate the peak-memory readings of later cells.
del forward_output

if cuda_stats:
    torch.cuda.synchronize()
    peak_bytes = torch.cuda.max_memory_allocated()
    mamba_memory = peak_bytes / 1e9
    mamba_overhead = (peak_bytes - baseline_bytes) / 1e9
else:
    mamba_memory = mamba_overhead = float("nan")
    print("   CUDA is not available - GPU memory statistics are reported as nan")

transformer_theoretical = transformer_memory(max_tokens)

# The transformer figure is a KV cache only, so the savings are computed
# against Mamba's forward-pass overhead rather than against a total that
# also contains the model weights.
print(f"\n📊 Memory Comparison at {max_tokens:,} tokens:")
print(f"   Mamba actual: {mamba_memory:.2f} GB")
print(f"   Mamba forward-pass overhead (peak minus pre-existing): {mamba_overhead:.2f} GB")
print(f"   Transformer KV cache (theoretical): {transformer_theoretical:.2f} GB")
print(f"   Memory savings (KV cache vs. Mamba overhead): {(transformer_theoretical - mamba_overhead):.2f} GB")

---

## Exercise 2 Solution: Context Scaling Analysis

In [None]:
# Solution: Detailed benchmark with TTFT, throughput, and memory

def comprehensive_benchmark(model, tokenizer, context_lengths, generation_length=50):
    """
    Benchmark generation at several prompt lengths.

    For every context length a random prompt is generated and three things
    are measured:
    - Time-to-first-token (TTFT): wall time of a 1-token ``generate`` call
    - Throughput (tokens/second): new tokens actually produced divided by
      the wall time of the full ``generate`` call (prompt processing
      included)
    - Peak memory usage: CUDA peak allocation across both calls

    Args:
        model: Object exposing ``generate(...)`` and a ``device`` attribute.
        tokenizer: Object exposing ``vocab_size`` and ``pad_token_id``.
        context_lengths: Iterable of prompt lengths, in tokens.
        generation_length: Maximum number of new tokens for the throughput run.

    Returns:
        list[dict]: One entry per context length that completed, with keys
        ``context_length``, ``ttft_seconds``, ``throughput``,
        ``peak_memory_gb`` (``nan`` when not running on CUDA) and
        ``new_tokens``. Lengths that hit a CUDA out-of-memory error are
        reported on stdout and skipped.

    Raises:
        RuntimeError: Any runtime error other than out-of-memory.
    """
    run_device = model.device
    on_cuda = run_device.type == "cuda"

    def _sync():
        # CUDA kernels run asynchronously; wait for them so timings are real.
        if on_cuda:
            torch.cuda.synchronize(run_device)

    def _timed_generate(input_ids, max_new_tokens):
        """Run greedy generation and return (output_ids, elapsed_seconds)."""
        _sync()
        started = time.perf_counter()
        with torch.no_grad():
            output = model.generate(
                input_ids,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
            )
        _sync()
        return output, time.perf_counter() - started

    def _random_prompt(length):
        # Stay 100 ids away from both ends of the vocabulary, as the
        # original exercise did.
        return torch.randint(
            100, tokenizer.vocab_size - 100,
            (1, length), device=run_device
        )

    # One short untimed call so one-time start-up cost does not end up in
    # the first measured TTFT.
    if context_lengths:
        _timed_generate(_random_prompt(8), 1)

    results = []

    for ctx_len in context_lengths:
        print(f"\nBenchmarking {ctx_len:,} tokens...")

        try:
            # Create input
            input_ids = _random_prompt(ctx_len)

            # Clear memory
            if on_cuda:
                torch.cuda.empty_cache()
                torch.cuda.reset_peak_memory_stats(run_device)

            # Measure TTFT (time to generate first token)
            _, ttft = _timed_generate(input_ids, 1)

            # Measure throughput (full generation)
            full_output, gen_time = _timed_generate(input_ids, generation_length)

            # generate() may stop early (e.g. on EOS), so count what was
            # really produced instead of assuming generation_length.
            new_tokens = full_output.shape[1] - ctx_len
            throughput = new_tokens / gen_time if gen_time > 0 else float("nan")

            # Get memory
            if on_cuda:
                peak_memory = torch.cuda.max_memory_allocated(run_device) / 1e9
            else:
                peak_memory = float("nan")

            results.append({
                'context_length': ctx_len,
                'ttft_seconds': ttft,
                'throughput': throughput,
                'peak_memory_gb': peak_memory,
                'new_tokens': new_tokens,
            })

            print(f"  TTFT: {ttft*1000:.1f}ms | Throughput: {throughput:.1f} tok/s | Memory: {peak_memory:.2f}GB")

        except RuntimeError as e:
            if "out of memory" in str(e).lower():
                print(f"  OOM at {ctx_len:,} tokens")
                if on_cuda:
                    torch.cuda.empty_cache()
            else:
                raise

    return results

# Run comprehensive benchmark
test_contexts = [1024, 2048, 4096, 8192, 16384]
benchmark_results = comprehensive_benchmark(model, tokenizer, test_contexts)

# Plot results
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

contexts = [r['context_length'] for r in benchmark_results]
ttfts = [r['ttft_seconds'] * 1000 for r in benchmark_results]  # Convert to ms
throughputs = [r['throughput'] for r in benchmark_results]
memories = [r['peak_memory_gb'] for r in benchmark_results]

# One row per panel: (y values, line/marker style, colour, y label, title).
# The three panels differ only in these five fields, so they share one loop.
panels = [
    (ttfts, 'o-', '#E74C3C', 'Time to First Token (ms)', 'TTFT vs Context Length'),
    (throughputs, 's-', '#27AE60', 'Tokens per Second', 'Throughput vs Context Length'),
    (memories, '^-', '#3498DB', 'Peak Memory (GB)', 'Memory vs Context Length'),
]

for ax, (values, line_style, color, y_label, title) in zip(axes, panels):
    ax.plot(contexts, values, line_style, linewidth=2, markersize=8, color=color)
    ax.set_xlabel('Context Length (tokens)')
    ax.set_ylabel(y_label)
    ax.set_title(title, fontweight='bold')
    # Context lengths double from one test to the next, so use a log2 x-axis.
    ax.set_xscale('log', base=2)
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Key insight
print("\n🔍 Key Observations:")
print("   TTFT scales approximately linearly with context (O(n))")
print("   Throughput remains relatively stable")
print("   Memory increases minimally - this is Mamba's advantage!")

---

## Challenge Solution: Long-Context Summarizer

In [None]:
# Solution: Long-context summarizer comparison

def mamba_summarize(model, tokenizer, text, max_summary_tokens=200,
                    max_input_tokens=32768):
    """
    Summarize a long document using Mamba (processes full context).

    The prompt is ``"Summarize the following text:\\n\\n<text>\\n\\nSummary:"``.
    If it would exceed ``max_input_tokens``, only the document part is
    shortened, so the trailing ``Summary:`` cue always reaches the model.

    Args:
        model: Object exposing ``generate(...)`` and a ``device`` attribute.
        tokenizer: Object exposing ``encode``, ``decode`` and ``pad_token_id``.
        text: Document to summarize.
        max_summary_tokens: Maximum number of tokens to generate.
        max_input_tokens: Upper bound on the prompt length in tokens.

    Returns:
        tuple: ``(summary, elapsed_seconds, prompt_token_count)``.
    """
    prefix = "Summarize the following text:\n\n"
    suffix = "\n\nSummary:"
    run_device = model.device

    prompt_ids = tokenizer.encode(prefix + text + suffix)
    if len(prompt_ids) > max_input_tokens:
        # Cutting the assembled prompt on the right would drop the suffix.
        # Encode the pieces separately (without special tokens, so none end
        # up mid-prompt) and shorten only the document.
        prefix_ids = tokenizer.encode(prefix, add_special_tokens=False)
        suffix_ids = tokenizer.encode(suffix, add_special_tokens=False)
        text_ids = tokenizer.encode(text, add_special_tokens=False)
        budget = max(max_input_tokens - len(prefix_ids) - len(suffix_ids), 0)
        prompt_ids = prefix_ids + text_ids[:budget] + suffix_ids

    input_ids = torch.tensor([prompt_ids], device=run_device)
    # Single unpadded sequence, so every position is a real token.
    attention_mask = torch.ones_like(input_ids)
    prompt_len = input_ids.shape[1]

    # CUDA work is asynchronous; synchronise so the timer covers all of it.
    if run_device.type == "cuda":
        torch.cuda.synchronize(run_device)
    start = time.perf_counter()
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_summary_tokens,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
        )
    if run_device.type == "cuda":
        torch.cuda.synchronize(run_device)
    elapsed = time.perf_counter() - start

    # generate() returns prompt + continuation; keep only the continuation.
    summary = tokenizer.decode(outputs[0][prompt_len:],
                               skip_special_tokens=True)

    return summary, elapsed, prompt_len

def chunked_summarize(model, tokenizer, text, chunk_size=2048, max_summary_tokens=100,
                      max_chunks=5):
    """
    Summarize using chunking approach (for comparison with transformer-style).

    The text is tokenized, split into ``chunk_size``-token pieces, and each
    of the first ``max_chunks`` pieces is summarized independently. The
    partial summaries are joined with single spaces.

    Args:
        model: Object exposing ``generate(...)`` and a ``device`` attribute.
        tokenizer: Callable tokenizer that also exposes ``encode``,
            ``decode`` and ``pad_token_id``.
        text: Document to summarize.
        chunk_size: Number of tokens per chunk.
        max_summary_tokens: Generation budget shared across the processed
            chunks (each chunk gets at least one token).
        max_chunks: Maximum number of chunks to process; ``None`` processes
            all of them.

    Returns:
        tuple: ``(combined_summary, elapsed_seconds, chunks_processed)``,
        where ``chunks_processed`` counts the chunks actually summarized.
    """
    # Tokenize full text
    tokens = tokenizer.encode(text, add_special_tokens=False)

    # Split into chunks
    chunks = [tokens[i:i+chunk_size] for i in range(0, len(tokens), chunk_size)]
    selected = chunks[:max_chunks]  # Limit chunks for demo

    # Split the generation budget evenly; never ask for zero new tokens,
    # which generate() cannot satisfy.
    tokens_per_chunk = max(1, max_summary_tokens // len(selected)) if selected else 0
    run_device = model.device

    start = time.perf_counter()
    chunk_summaries = []

    for chunk_tokens in selected:
        chunk_text = tokenizer.decode(chunk_tokens)
        prompt = f"Summarize briefly:\n{chunk_text}\n\nSummary:"

        # The +100 leaves room for the instruction text around the chunk.
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
                          max_length=chunk_size + 100).to(run_device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=tokens_per_chunk,
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
            )

        # generate() returns prompt + continuation; keep only the continuation.
        chunk_summary = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:],
                                         skip_special_tokens=True)
        chunk_summaries.append(chunk_summary)

    # Combine chunk summaries
    combined = " ".join(chunk_summaries)
    elapsed = time.perf_counter() - start

    return combined, elapsed, len(selected)

# Test on a long document
# Reuse the book loaded for Exercise 1 when that cell has been run; otherwise
# fall back to a synthetic document so this cell still works on its own.
if 'pg19' in globals():
    test_text = pg19[0]["text"][:20000]  # 20K characters
else:
    test_text = "This is a test document. " * 500

print(f"Document length: {len(test_text):,} characters")
print(f"Document tokens: {len(tokenizer.encode(test_text)):,}")

# Mamba full-context approach
print("\n🦎 Mamba Full-Context Summarization:")
mamba_summary, mamba_time, mamba_tokens = mamba_summarize(model, tokenizer, test_text)
print(f"   Time: {mamba_time:.2f}s")
print(f"   Input tokens: {mamba_tokens:,}")
print(f"   Summary: {mamba_summary[:300]}...")

# Chunked approach
print("\n📑 Chunked Summarization:")
chunked_summary, chunked_time, num_chunks = chunked_summarize(model, tokenizer, test_text)
print(f"   Time: {chunked_time:.2f}s")
print(f"   Chunks processed: {num_chunks}")
print(f"   Summary: {chunked_summary[:300]}...")

# Ratio > 1 means the chunked pipeline took longer than the single pass.
print("\n📊 Comparison:")
print(f"   Mamba full-context: {mamba_time:.2f}s (sees entire document at once)")
print(f"   Chunked approach: {chunked_time:.2f}s (loses cross-chunk context)")
print(f"   Speedup: {chunked_time/mamba_time:.2f}x (but Mamba has better context!)")

---

## Cleanup

In [None]:
# Drop the notebook-level references. pop() with a default keeps this cell
# re-runnable, whereas `del` raises NameError on the second run.
for _name in ("model", "tokenizer"):
    globals().pop(_name, None)

# Collect first so objects caught in reference cycles release their tensors,
# then hand the now-unused cached blocks back to the GPU driver.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
print("✅ Cleanup complete!")