# Lab 2.4.2: Mamba Architecture Study

**Module:** 2.4 - Efficient Architectures  
**Time:** 2 hours  
**Difficulty:** ⭐⭐⭐⭐ (Advanced)

---

## 🎯 Learning Objectives

By the end of this lab, you will:
- [ ] Understand the mathematical foundation of State Space Models
- [ ] Implement a simplified selective scan algorithm in PyTorch
- [ ] Visualize how Mamba's state evolves across a sequence
- [ ] Compare Mamba's "attention" with transformer attention patterns

---

## 📚 Prerequisites

- Completed: Lab 2.4.1 (Mamba Inference)
- Knowledge of: Linear algebra basics, RNNs
- Helpful: Understanding of differential equations (but not required)

---

## 🌍 Real-World Context

Understanding Mamba's internals helps you:
- **Debug** why a model behaves unexpectedly on certain inputs
- **Optimize** inference by understanding bottlenecks
- **Choose** the right architecture for your use case
- **Research** improvements and hybrid architectures (like Jamba)

SSM research is active across academia and industry: S4 came out of Stanford, and companies including Microsoft, NVIDIA, and AI21 Labs (Jamba) are investing in SSM and hybrid architectures.

---

## 🧒 ELI5: State Space Models

> **Imagine you're a weather forecaster...**
>
> You have a "state" that represents your understanding of the weather:
> - Temperature trends
> - Humidity patterns  
> - Pressure systems
>
> Each day, you update this state based on new observations:
> ```
> new_state = A × old_state + B × today's_observation
> prediction = C × new_state
> ```
>
> Where A, B, C are learned "rules" for how weather evolves.
>
> **The "Selective" part in Mamba:**
> Normal forecasters use the SAME rules every day. But imagine if the rules CHANGED based on what you observe:
> - Sunny day → update temperature state strongly
> - Rainy day → update humidity state strongly
>
> This is what makes Mamba special: A, B, C change based on input!

### The Mathematics (Simplified)

**Classical State Space Model:**
```
h(t) = A·h(t-1) + B·x(t)    # State update
y(t) = C·h(t) + D·x(t)       # Output
```

**Selective State Space (Mamba):**
```
A(t), B(t), C(t) = f(x(t))         # Parameters depend on input!
                                   # (In Mamba, A varies through a learned step size Δ)
h(t) = A(t)·h(t-1) + B(t)·x(t)     # State update (with dynamic A, B)
y(t) = C(t)·h(t)                   # Output (with dynamic C)
```
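
To make the difference concrete, here is a minimal sketch of both recurrences with scalar parameters, written in plain Python so every number can be checked by hand. The gate used for the selective case (reset the memory when the input is large) is a hypothetical toy rule chosen for illustration, not Mamba's actual parameterization.

```python
def run_ssm(xs, a_fn, b_fn, c_fn):
    """Scalar SSM: h = a*h + b*x, y = c*h, where a, b, c may depend on x."""
    h = 0.0
    ys = []
    for x in xs:
        a, b, c = a_fn(x), b_fn(x), c_fn(x)
        h = a * h + b * x   # state update
        ys.append(c * h)    # output
    return ys

xs = [1.0, 0.0, 0.0, 5.0, 0.0]

# Classical SSM: the same rules at every step.
fixed = run_ssm(xs, lambda x: 0.5, lambda x: 1.0, lambda x: 1.0)

# Selective SSM (toy gate, an assumption): a large input wipes the old
# state (a = 0); otherwise the state decays by half each step (a = 0.5).
selective = run_ssm(
    xs,
    lambda x: 0.0 if abs(x) > 2 else 0.5,
    lambda x: 1.0,
    lambda x: 1.0,
)

print(fixed)      # [1.0, 0.5, 0.25, 5.125, 2.5625]
print(selective)  # [1.0, 0.5, 0.25, 5.0, 2.5]
```

The two runs agree until the large input arrives at step 4: the fixed model blends it with the leftover state, while the selective model discards the old state first.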

---

## Part 1: Setup and Prerequisites

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, Optional
import math

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Visualization settings
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)
plt.rcParams['font.size'] = 11

---

## Part 2: Classical State Space Model

Before understanding Mamba's selective scan, let's implement a simple state space model.

### The Continuous-Time State Space

The classical formulation (from control theory):
```
dh/dt = A·h + B·x    (state h evolves continuously, driven by input x)
y = C·h + D·x        (output)
```

For sequence modeling, we discretize this:
```
h[k] = Ā·h[k-1] + B̄·x[k]
y[k] = C¬∑h[k]
```

Where Ā and B̄ are discretized versions of A and B.
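
As a quick numerical check of what discretization does, the sketch below applies the zero-order hold formulas to a single scalar `A` (equivalently, one diagonal entry). The function name `discretize_zoh` and the sample values are assumptions for illustration. It compares the exact B̄ with the common shortcut B̄ ≈ Δ·B.

```python
import math

def discretize_zoh(a, b, delta):
    """Zero-order hold for a scalar A (requires a != 0)."""
    a_bar = math.exp(delta * a)
    b_bar = (a_bar - 1.0) / a * b   # exact ZOH: (ΔA)^-1 (exp(ΔA) - 1) ΔB
    return a_bar, b_bar

a, b = -1.0, 1.0
for delta in (0.001, 0.1, 1.0):
    a_bar, b_bar = discretize_zoh(a, b, delta)
    print(f"delta={delta}: A_bar={a_bar:.4f}  B_bar={b_bar:.4f}  delta*B={delta * b:.4f}")
```

For small Δ the shortcut is nearly exact; by Δ = 1 it overestimates B̄ noticeably, which is worth remembering when reading the simplified code in this lab.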

In [None]:
class SimpleSSM(nn.Module):
    """
    A simple (non-selective) State Space Model.
    
    This is the building block that Mamba extends with selectivity.
    
    Parameters:
        d_model: Input/output dimension
        d_state: Hidden state dimension (compression factor)
    """
    
    def __init__(self, d_model: int = 64, d_state: int = 16):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        
        # State space parameters
        # A: State transition matrix, kept diagonal for stability and
        # stored as a vector of its d_state diagonal entries
        self.A = nn.Parameter(torch.randn(d_state) * 0.1)
        
        # B: Input projection (d_state x d_model)
        self.B = nn.Parameter(torch.randn(d_state, d_model) * 0.1)
        
        # C: Output projection (d_model x d_state)
        self.C = nn.Parameter(torch.randn(d_model, d_state) * 0.1)
        
        # Delta: Discretization step (learnable, one per channel; passed
        # through softplus in discretize() so the step stays positive)
        self.log_delta = nn.Parameter(torch.zeros(d_model))
        
    def discretize(self):
        """
        Discretize continuous parameters using Zero-Order Hold (ZOH).
        
        Ā = exp(Δ·A)
        B̄ = (Δ·A)^(-1) · (Ā - I) · Δ·B ≈ Δ·B for small Δ
        """
        delta = F.softplus(self.log_delta)  # Ensure positive, [d_model]
        
        # Keep the continuous-time A negative so Ā = exp(Δ·A) stays in (0, 1)
        A = -torch.exp(self.A)  # [d_state]
        
        # Simplified discretization (one independent SSM per channel)
        A_discrete = torch.exp(delta.unsqueeze(-1) * A)  # [d_model, d_state]
        B_discrete = delta.unsqueeze(-1) * self.B.T  # [d_model, d_state]
        
        return A_discrete, B_discrete
    
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Process a sequence through the SSM.
        
        Args:
            x: Input tensor of shape [batch, seq_len, d_model]
            
        Returns:
            output: Output tensor of shape [batch, seq_len, d_model]
            states: Hidden states of shape [batch, seq_len, d_model, d_state]
        """
        batch, seq_len, _ = x.shape
        
        # Discretize parameters
        A_bar, B_bar = self.discretize()
        
        # Initialize hidden state: one d_state-sized state per channel
        h = torch.zeros(batch, self.d_model, self.d_state, device=x.device)
        
        outputs = []
        states = []
        
        # Sequential scan (this is what Mamba parallelizes!)
        for t in range(seq_len):
            # Input at time t
            x_t = x[:, t, :]  # [batch, d_model]
            
            # State update: h = Ā·h + B̄·x
            # A is diagonal, so the update is elementwise
            h = A_bar * h + B_bar * x_t.unsqueeze(-1)  # [batch, d_model, d_state]
            
            # Output: y = C·h (sum over the state dimension)
            y_t = (h * self.C).sum(dim=-1)  # [batch, d_model]
            
            outputs.append(y_t)
            states.append(h.clone())
        
        output = torch.stack(outputs, dim=1)
        all_states = torch.stack(states, dim=1)
        
        return output, all_states

# Test the simple SSM
simple_ssm = SimpleSSM(d_model=64, d_state=16).to(device)
test_input = torch.randn(2, 100, 64, device=device)  # [batch=2, seq=100, dim=64]

with torch.no_grad():
    output, states = simple_ssm(test_input)

print(f"Input shape:  {test_input.shape}")
print(f"Output shape: {output.shape}")
print(f"States shape: {states.shape}")
print(f"\n✅ Simple SSM working!")

### 🔍 Key Insight: The Sequential Bottleneck

Notice the `for t in range(seq_len)` loop? This is the **sequential bottleneck** of RNNs!

- We can't parallelize because h[t] depends on h[t-1]
- This makes training slow on GPUs (which love parallelism)

**Mamba's trick**: Because the update is linear in h, a parallel scan can compute all states in O(log n) parallel steps, given enough processors!

---

## Part 3: The Selective Scan

Now let's implement Mamba's key innovation: **input-dependent parameters**.

In [None]:
class SelectiveSSM(nn.Module):
    """
    Selective State Space Model (simplified Mamba block).
    
    The key difference from SimpleSSM:
    - B, C, and delta are computed FROM THE INPUT (A stays a learned
      parameter, but its discretized form depends on delta)
    - This makes the model "selective" about what to remember
    
    Parameters:
        d_model: Input/output dimension
        d_state: Hidden state dimension
        d_conv: Local convolution width (for local context)
    """
    
    def __init__(self, d_model: int = 64, d_state: int = 16, d_conv: int = 4):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        
        # Expand dimension for internal processing
        self.d_inner = d_model * 2
        
        # Input projection
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
        
        # Local convolution (like in Mamba)
        self.conv1d = nn.Conv1d(
            self.d_inner, self.d_inner,
            kernel_size=d_conv,
            padding=d_conv - 1,
            groups=self.d_inner,  # Depthwise
        )
        
        # Selective parameters - computed from input!
        self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False)  # B, C, delta
        
        # A is still a base parameter (log for stability)
        self.A_log = nn.Parameter(torch.log(torch.arange(1, d_state + 1, dtype=torch.float32)))
        
        # Output projection
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
        
    def forward(self, x: torch.Tensor, return_states: bool = False) -> torch.Tensor:
        """
        Selective scan forward pass.
        
        Args:
            x: Input [batch, seq_len, d_model]
            return_states: If True, also return hidden states
            
        Returns:
            output: [batch, seq_len, d_model]
            states (optional): [batch, seq_len, d_inner, d_state]
        """
        batch, seq_len, _ = x.shape
        
        # Input projection and split
        xz = self.in_proj(x)  # [batch, seq_len, d_inner * 2]
        x_branch, z = xz.chunk(2, dim=-1)  # Each [batch, seq_len, d_inner]
        
        # Local causal convolution (short-range context); trim the extra right padding
        x_conv = self.conv1d(x_branch.transpose(1, 2))[:, :, :seq_len].transpose(1, 2)
        x_conv = F.silu(x_conv)  # Activation
        
        # Generate selective parameters FROM INPUT
        x_params = self.x_proj(x_conv)  # [batch, seq_len, d_state*2 + 1]
        
        # Split into B, C, delta
        B = x_params[:, :, :self.d_state]  # [batch, seq_len, d_state]
        C = x_params[:, :, self.d_state:2*self.d_state]  # [batch, seq_len, d_state]
        # One step size per token (a simplification: Mamba uses one per channel)
        delta = F.softplus(x_params[:, :, -1])  # [batch, seq_len]
        
        # Get A (negative for stability)
        A = -torch.exp(self.A_log)  # [d_state]
        
        # Discretize: A_bar = exp(delta * A), B_bar ≈ delta * B
        A_bar = torch.exp(delta.unsqueeze(-1) * A)  # [batch, seq_len, d_state]
        B_bar = delta.unsqueeze(-1) * B  # [batch, seq_len, d_state]
        
        # Run selective scan with one d_state-sized state per inner channel
        h = torch.zeros(batch, self.d_inner, self.d_state, device=x.device)
        outputs = []
        states = []
        
        for t in range(seq_len):
            # Input-dependent state update (Selective!)
            h = (A_bar[:, t].unsqueeze(1) * h
                 + B_bar[:, t].unsqueeze(1) * x_conv[:, t].unsqueeze(-1))
            
            # Input-dependent output (Selective!)
            y_t = (h * C[:, t].unsqueeze(1)).sum(dim=-1)  # [batch, d_inner]
            
            outputs.append(y_t)
            states.append(h.clone())
        
        y = torch.stack(outputs, dim=1)  # [batch, seq_len, d_inner]
        
        # Gated output (like in Mamba)
        y = y * F.silu(z)
        
        # Output projection
        output = self.out_proj(y)
        
        if return_states:
            all_states = torch.stack(states, dim=1)
            return output, all_states
        return output

# Test selective SSM
selective_ssm = SelectiveSSM(d_model=64, d_state=16).to(device)
test_input = torch.randn(2, 50, 64, device=device)

with torch.no_grad():
    output, states = selective_ssm(test_input, return_states=True)

print(f"Input shape:  {test_input.shape}")
print(f"Output shape: {output.shape}")
print(f"States shape: {states.shape}")
print(f"\n✅ Selective SSM working!")

### 🔍 The Key Difference: Selectivity

Notice these lines in the code:

```python
# Parameters computed FROM input!
x_params = self.x_proj(x_conv)  
B = x_params[:, :, :self.d_state]
C = x_params[:, :, self.d_state:2*self.d_state]
delta = F.softplus(x_params[:, :, -1])
```

**This is the magic of Mamba!**
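- When processing "important" tokens, B and delta can be large (write the token into the state; a large delta also shrinks exp(delta·A), so older content fades)
- When processing "unimportant" tokens, B and delta can be small (the state passes through almost unchanged)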
- C controls what parts of state to output

The model LEARNS what's important!
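
The effect of delta can be seen in a scalar sketch. The step sizes below are hand-picked assumptions standing in for what `x_proj` would produce; the discretization follows the simplified form used in this lab (Ā = exp(Δ·A), B̄ = Δ·B).

```python
import math

def selective_scan_scalar(xs, deltas, a=-1.0, b=1.0):
    """Scalar selective recurrence with a per-token step size."""
    h = 0.0
    history = []
    for x, delta in zip(xs, deltas):
        a_bar = math.exp(delta * a)   # near 1 for small delta, near 0 for large
        b_bar = delta * b             # simplified discretization
        h = a_bar * h + b_bar * x
        history.append(h)
    return history

xs = [1.0, 1.0, 1.0, 1.0]

# Every token gets a tiny step: the state barely moves.
ignore_all = selective_scan_scalar(xs, [0.01, 0.01, 0.01, 0.01])

# The third token gets a large step: it overwrites the state, and the
# small step on the fourth token preserves what was written.
focus_on_third = selective_scan_scalar(xs, [0.01, 0.01, 3.0, 0.01])

print([round(h, 3) for h in ignore_all])
print([round(h, 3) for h in focus_on_third])
```

With identical inputs, only the step sizes decide which token ends up stored in the state.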

---

## Part 4: Visualizing State Evolution

Let's see how the hidden state evolves as it processes a sequence.

In [None]:
def visualize_state_evolution(model, tokenizer, text: str, max_tokens: int = 50):
    """
    Visualize how the model's hidden state changes across tokens.
    """
    # Tokenize
    tokens = tokenizer.encode(text, add_special_tokens=False)[:max_tokens]
    token_strs = [tokenizer.decode([t]) for t in tokens]
    
    # Create input
    input_ids = torch.tensor([tokens], device=device)
    
    # Request per-layer activations; passing the flag to the call avoids
    # mutating model.config
    with torch.no_grad():
        outputs = model(input_ids, output_hidden_states=True)
    
    # Extract layer activations (the layer outputs, not the internal SSM state h)
    # For Mamba, hidden_states[0] is embeddings, hidden_states[1] is after first layer
    # Cast to float32 first: NumPy has no bfloat16 dtype
    hidden_states = outputs.hidden_states[1][0].float().cpu().numpy()  # [seq_len, hidden_dim]
    
    # Take first 32 dimensions for visualization
    states_viz = hidden_states[:, :32]
    
    # Plot
    fig, axes = plt.subplots(2, 1, figsize=(14, 8))
    
    # Heatmap of state evolution
    im = axes[0].imshow(states_viz.T, aspect='auto', cmap='RdBu_r', 
                        interpolation='nearest')
    axes[0].set_xlabel('Token Position')
    axes[0].set_ylabel('State Dimension')
    axes[0].set_title('Hidden State Evolution (First 32 Dimensions)', fontweight='bold')
    plt.colorbar(im, ax=axes[0], label='Activation')
    
    # Add token labels if not too many
    if len(token_strs) <= 30:
        axes[0].set_xticks(range(len(token_strs)))
        axes[0].set_xticklabels(token_strs, rotation=45, ha='right', fontsize=8)
    
    # State magnitude over time
    state_norms = np.linalg.norm(hidden_states, axis=1)
    axes[1].plot(state_norms, 'b-', linewidth=2)
    axes[1].fill_between(range(len(state_norms)), state_norms, alpha=0.3)
    axes[1].set_xlabel('Token Position')
    axes[1].set_ylabel('State Magnitude (L2 norm)')
    axes[1].set_title('State Magnitude Over Sequence', fontweight='bold')
    
    plt.tight_layout()
    plt.show()
    
    return hidden_states, token_strs

In [None]:
# Load a Mamba model for visualization
from transformers import AutoModelForCausalLM, AutoTokenizer

# Use smaller model for faster loading
MODEL_NAME = "state-spaces/mamba-130m-hf"

print(f"Loading {MODEL_NAME} for state visualization...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

print("✅ Model loaded!")

In [None]:
# Visualize state evolution on a sample text
sample_text = "The quick brown fox jumps over the lazy dog. This is a test."

print(f"Analyzing: '{sample_text}'\n")
states, tokens = visualize_state_evolution(model, tokenizer, sample_text)

In [None]:
# Compare state evolution for different content types

texts = {
    "Code": "def fibonacci(n):\n    if n <= 1:\n        return n\n    return fibonacci(n-1) + fibonacci(n-2)",
    "Math": "The derivative of x squared plus three x equals two x plus three",
    "Story": "Once upon a time in a faraway kingdom there lived a brave knight",
}

fig, axes = plt.subplots(len(texts), 1, figsize=(14, 4*len(texts)))

for idx, (label, text) in enumerate(texts.items()):
    # Get hidden states
    tokens = tokenizer.encode(text, add_special_tokens=False)[:40]
    input_ids = torch.tensor([tokens], device=device)
    
    with torch.no_grad():
        outputs = model(input_ids, output_hidden_states=True)
    
    # Cast to float32 before converting: NumPy has no bfloat16 dtype
    hidden = outputs.hidden_states[1][0].float().cpu().numpy()[:, :32]
    
    # Plot
    im = axes[idx].imshow(hidden.T, aspect='auto', cmap='RdBu_r')
    axes[idx].set_title(f'{label}: "{text[:50]}..."', fontweight='bold')
    axes[idx].set_ylabel('State Dim')
    plt.colorbar(im, ax=axes[idx])

axes[-1].set_xlabel('Token Position')
plt.tight_layout()
plt.show()

print("\n🔍 Notice how different content types create different activation patterns!")
print("   These layer outputs are shaped by selectivity: the model adapts to what it's processing.")

---

## Part 5: Comparing with Attention Patterns

Transformers have explicit attention patterns we can visualize. What does Mamba's "implicit attention" look like?
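
One way to make "implicit attention" concrete is to unroll the recurrence. For a scalar selective SSM, the output is `y[i] = Σ_j M[i][j]·x[j]` with `M[i][j] = C[i]·(Ā[j+1]·…·Ā[i])·B̄[j]` for `j ≤ i`, so `M` plays the role of a causal attention matrix. The sketch below builds `M` for a toy sequence and checks it against the step-by-step scan; the per-step values are made up for illustration.

```python
def implicit_attention(a_bar, b_bar, c):
    """M[i][j] = c[i] * prod(a_bar[j+1..i]) * b_bar[j] for j <= i, else 0."""
    n = len(a_bar)
    M = [[0.0] * n for _ in range(n)]
    for i in range(n):
        decay = 1.0                       # running product a_bar[j+1] ... a_bar[i]
        for j in range(i, -1, -1):
            M[i][j] = c[i] * decay * b_bar[j]
            decay *= a_bar[j]
    return M

def scan_outputs(a_bar, b_bar, c, x):
    """Reference: run the recurrence one step at a time."""
    h, ys = 0.0, []
    for t in range(len(x)):
        h = a_bar[t] * h + b_bar[t] * x[t]
        ys.append(c[t] * h)
    return ys

# Toy per-step parameters (assumed values, not taken from a real model)
a_bar = [0.5, 0.5, 0.25, 0.5]
b_bar = [1.0, 2.0, 1.0, 0.5]
c = [1.0, 1.0, 2.0, 1.0]
x = [1.0, 2.0, 3.0, 4.0]

M = implicit_attention(a_bar, b_bar, c)
y_matrix = [sum(M[i][j] * x[j] for j in range(len(x))) for i in range(len(x))]
y_scan = scan_outputs(a_bar, b_bar, c, x)

print(y_matrix)
print(y_scan)
```

The pretrained model does not expose its per-step Ā, B̄, and C through `output_hidden_states`, so the next cell works from layer activations with a cheaper heuristic instead.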

In [None]:
def compute_effective_attention(states: np.ndarray) -> np.ndarray:
    """
    Compute a heuristic "effective attention" matrix from layer activations.
    
    Entry [i, j] is the cosine similarity between the activations at
    positions i and j (j <= i), weighted by a fixed recency decay. It is a
    rough proxy for how related position i is to earlier position j, not
    the model's exact input-output dependence.
    """
    seq_len, hidden_dim = states.shape
    
    attention_like = np.zeros((seq_len, seq_len))
    
    for i in range(seq_len):
        for j in range(i + 1):  # Causal: can only attend to past
            # Measure similarity between current state and state at j
            similarity = np.dot(states[i], states[j]) / (
                np.linalg.norm(states[i]) * np.linalg.norm(states[j]) + 1e-8
            )
            # Clip negative similarity so each row stays a valid distribution
            similarity = max(similarity, 0.0)
            # Weight by recency (newer = more influence)
            decay = np.exp(-0.1 * (i - j))
            attention_like[i, j] = similarity * decay
    
    # Normalize rows to sum to 1
    row_sums = attention_like.sum(axis=1, keepdims=True) + 1e-8
    attention_like = attention_like / row_sums
    
    return attention_like

# Compute and visualize effective attention
text = "The cat sat on the mat because it was tired."
tokens = tokenizer.encode(text, add_special_tokens=False)
token_strs = [tokenizer.decode([t]) for t in tokens]

input_ids = torch.tensor([tokens], device=device)
with torch.no_grad():
    outputs = model(input_ids, output_hidden_states=True)
# Cast to float32 before converting: NumPy has no bfloat16 dtype
states = outputs.hidden_states[1][0].float().cpu().numpy()

# Compute effective attention
eff_attention = compute_effective_attention(states)

# Plot
fig, ax = plt.subplots(figsize=(10, 8))
im = ax.imshow(eff_attention, cmap='Blues')
ax.set_xticks(range(len(token_strs)))
ax.set_yticks(range(len(token_strs)))
ax.set_xticklabels(token_strs, rotation=45, ha='right')
ax.set_yticklabels(token_strs)
ax.set_xlabel('Earlier Token (attended to)')
ax.set_ylabel('Current Token (attending from)')
ax.set_title('Mamba "Effective Attention" Pattern\n(derived from state evolution)', 
             fontweight='bold')
plt.colorbar(im, ax=ax, label='Attention Weight')
plt.tight_layout()
plt.show()

print("\n🔍 Compare this to transformer attention:")
print("   - This map is a similarity-based proxy, so it is smoother than a true attention matrix")
print("   - The strong diagonal partly reflects the fixed recency decay built into the heuristic")
print("   - Check the row for 'it': does 'cat' get more weight than its neighbors?")

---

## Part 6: The Parallel Scan Algorithm (Conceptual)

How does Mamba avoid the sequential bottleneck? The **parallel scan** algorithm!

### The Key Insight

The recurrence `h[t] = A·h[t-1] + B·x[t]` can be rewritten with an associative operator on pairs `(a, b)`, where `a` is the decay factor and `b` is the input term `B·x`:

```
⊕: (a₁, b₁) ⊕ (a₂, b₂) = (a₁·a₂, a₂·b₁ + b₂)
```

Because ⊕ is associative, partial results can be combined in any grouping, which allows us to compute all states in O(log n) parallel steps!
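
The associativity claim is easy to check numerically. In this small sketch (the helper name `combine` and the sample pairs are illustrative assumptions), folding the pairs left to right, combining them tree-style, and running the plain recurrence all give the same final state.

```python
def combine(left, right):
    """(a1, b1) ⊕ (a2, b2) = (a1*a2, a2*b1 + b2); `left` is the earlier step."""
    a1, b1 = left
    a2, b2 = right
    return (a1 * a2, a2 * b1 + b2)

# One pair per time step: (A[t], B[t] * x[t])
steps = [(0.5, 1.0), (0.25, 2.0), (2.0, 3.0), (0.5, 4.0)]

# Left-to-right fold: the ordinary sequential order.
acc = steps[0]
for s in steps[1:]:
    acc = combine(acc, s)

# Tree-style grouping: the two halves could be computed in parallel.
left_half = combine(steps[0], steps[1])
right_half = combine(steps[2], steps[3])
tree = combine(left_half, right_half)

# Reference: h[t] = A[t]*h[t-1] + B[t]*x[t], starting from h = 0.
h = 0.0
for a, bx in steps:
    h = a * h + bx

print(acc, tree, h)  # (0.125, 7.75) (0.125, 7.75) 7.75
```

The second element of the combined pair is the final state `h`, and the first is the total decay applied to whatever state existed before the window.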

In [None]:
def sequential_scan(A: torch.Tensor, B: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """
    Sequential scan: h[t] = A[t]·h[t-1] + B[t]·x[t]
    
    Time: O(n) sequential
    """
    seq_len = x.shape[0]
    h = torch.zeros_like(x[0])
    outputs = []
    
    for t in range(seq_len):
        h = A[t] * h + B[t] * x[t]
        outputs.append(h)
    
    return torch.stack(outputs)

def parallel_scan(A: torch.Tensor, B: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """
    Parallel (Hillis-Steele style) scan using the associative operation.
    
    Time: O(log n) parallel steps
    
    This is a simplified illustration - real implementations use fused GPU kernels.
    """
    seq_len = x.shape[0]
    
    # Pack into pairs: (A_cumulative, B·x_cumulative)
    # Start: [(A[0], B[0]·x[0]), (A[1], B[1]·x[1]), ...]
    coeffs = A.clone()
    values = B * x
    
    offset = 1
    while offset < seq_len:
        # Combine every position i >= offset with position i - offset.
        # Each step must read the values from BEFORE the step, so write into
        # fresh tensors instead of updating in place; the slices stand in
        # for the parallel threads.
        # Associative operation: (a1, b1) ⊕ (a2, b2) = (a1·a2, a2·b1 + b2)
        new_values = values.clone()
        new_coeffs = coeffs.clone()
        new_values[offset:] = coeffs[offset:] * values[:-offset] + values[offset:]
        new_coeffs[offset:] = coeffs[offset:] * coeffs[:-offset]
        values, coeffs = new_values, new_coeffs
        offset *= 2
    
    return values

# Compare sequential vs parallel
seq_len = 8
A = torch.rand(seq_len) * 0.9  # Decay factor < 1 for stability
B = torch.rand(seq_len)
x = torch.rand(seq_len)

h_seq = sequential_scan(A, B, x)
h_par = parallel_scan(A.clone(), B.clone(), x.clone())

print("Sequential scan result:")
print(h_seq.numpy())
print("\nParallel scan result:")
print(h_par.numpy())
match = torch.allclose(h_seq, h_par, atol=1e-5)
print(f"\nResults match: {match}")
if match:
    print("\n✅ Parallel scan produces identical results!")
else:
    print("\n❌ Results differ: check the scan implementation")
print("\n📊 Complexity comparison:")
print(f"   Sequential: {seq_len} steps")
print(f"   Parallel:   {int(np.ceil(np.log2(seq_len)))} parallel steps")

### Visualizing the Parallel Scan

In [None]:
# Visualize how parallel scan works
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Sequential scan visualization
ax = axes[0]
seq_len = 8
for i in range(seq_len):
    # Draw node
    ax.scatter(i, 0, s=300, c='#3498DB', zorder=5)
    ax.text(i, 0, str(i), ha='center', va='center', fontweight='bold', color='white')
    
    # Draw arrow from previous
    if i > 0:
        ax.annotate('', xy=(i-0.15, 0), xytext=(i-0.85, 0),
                   arrowprops=dict(arrowstyle='->', color='#E74C3C', lw=2))

ax.set_xlim(-0.5, seq_len - 0.5)
ax.set_ylim(-1, 1)
ax.set_title('Sequential Scan\n(O(n) sequential steps)', fontweight='bold', fontsize=12)
ax.axis('off')
ax.text(seq_len/2, -0.7, f'Total: {seq_len} sequential operations', 
        ha='center', fontsize=11, color='#E74C3C')

# Parallel scan visualization
ax = axes[1]
levels = int(np.ceil(np.log2(seq_len)))

# Draw nodes
for i in range(seq_len):
    ax.scatter(i, 0, s=300, c='#3498DB', zorder=5)
    ax.text(i, 0, str(i), ha='center', va='center', fontweight='bold', color='white')

# Draw parallel operations
colors = ['#27AE60', '#E74C3C', '#9B59B6']
for level in range(levels):
    offset = 2 ** level
    y_pos = -(level + 1) * 0.5
    
    for i in range(offset, seq_len):
        # Draw arc showing combination
        ax.annotate('', xy=(i, y_pos + 0.1), xytext=(i - offset, y_pos + 0.1),
                   arrowprops=dict(arrowstyle='->', color=colors[level % len(colors)], 
                                  lw=2, connectionstyle='arc3,rad=-0.2'))
    
    ax.text(-0.7, y_pos, f'Step {level+1}', fontsize=10, va='center')

ax.set_xlim(-1.5, seq_len - 0.5)
ax.set_ylim(-levels * 0.5 - 0.5, 0.5)
ax.set_title('Parallel Scan\n(O(log n) parallel steps)', fontweight='bold', fontsize=12)
ax.axis('off')
ax.text(seq_len/2, -levels * 0.5 - 0.3, f'Total: {levels} parallel steps', 
        ha='center', fontsize=11, color='#27AE60')

plt.tight_layout()
plt.show()

print("🔍 Key insight: With enough parallel processors, the scan completes in O(log n) time!")
print(f"   For seq_len={seq_len}: {seq_len} sequential ops → {levels} parallel ops")
print(f"   For seq_len=65536: 65536 sequential ops → {int(np.log2(65536))} parallel ops")

---

## ⚠️ Common Mistakes

### Mistake 1: Thinking Mamba is "Just an RNN"
```python
# ❌ Wrong mental model
# "Mamba is just another RNN, like an LSTM or GRU"

# ✅ Correct understanding
# Mamba = Selective State Space + Parallel Scan + Hardware-aware design
# - Selectivity makes it content-aware (like attention)
# - Parallel scan makes it trainable at scale
# - Hardware design makes it fast on GPUs
```

### Mistake 2: Expecting Identical Behavior to Transformers
```python
# ❌ Mamba won't perfectly copy what transformers do
# - Different inductive bias
# - Different handling of long-range dependencies

# ✅ Understand the tradeoffs
# - Mamba: Fixed-size state (low memory), may miss precise long-range patterns
# - Transformer: Precise attention, but O(n²) attention cost in sequence length
# - Hybrid (Jamba): Interleaves both to balance the two
```

### Mistake 3: Ignoring State Initialization
```python
# ❌ Starting with random state
h = torch.randn(batch, d_state)  # Bad!

# ✅ Start with zeros for deterministic behavior
h = torch.zeros(batch, d_state)
```

---

## 🎉 Checkpoint

You've learned:
- ✅ The mathematical foundation of State Space Models
- ✅ How selectivity makes Mamba content-aware
- ✅ The parallel scan algorithm for efficient training
- ✅ How to visualize Mamba's "implicit attention"
- ✅ The tradeoffs vs transformer attention

---

## ✋ Try It Yourself

### Exercise: State Evolution Analysis

Analyze how Mamba's state evolves differently for:
1. Repetitive text ("the the the the the")
2. Structured text (code with consistent patterns)
3. Diverse text (random words)

Visualize and compare the state evolution patterns.

In [None]:
# Your code here



---

## 📖 Further Reading

- [Mamba Paper](https://arxiv.org/abs/2312.00752) - Original paper with full mathematical details
- [Annotated S4](https://srush.github.io/annotated-s4/) - Excellent walkthrough of the predecessor
- [Parallel Scan Tutorial](https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda) - GPU Gems chapter on parallel scan
- [Jamba Paper](https://arxiv.org/abs/2403.19887) - Hybrid Mamba-Attention architecture

---

## 🧹 Cleanup

In [None]:
# Cleanup
import gc

if 'model' in dir():
    del model
if 'simple_ssm' in dir():
    del simple_ssm
if 'selective_ssm' in dir():
    del selective_ssm

# Collect Python garbage first so the freed tensors can be released
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

print("✅ Cleanup complete!")