# Mamba: Linear-Time Sequence Modeling with Selective State Spaces

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adiel2012/deep-learning-abc/blob/main/mamba_ssm.ipynb)

## 1. Mathematical Foundations: Continuous to Discrete SSM

State Space Models (SSMs) map a 1D function or sequence $x(t) \in \mathbb{R}$ to $y(t) \in \mathbb{R}$ through a latent state $h(t) \in \mathbb{R}^N$.

### Continuous System
$$ \dot{h}(t) = \mathbf{A} h(t) + \mathbf{B} x(t) $$
$$ y(t) = \mathbf{C} h(t) $$

where $\mathbf{A}$ determines the state evolution and $\mathbf{B}, \mathbf{C}$ are projections.

### Discretization (Zero-Order Hold)
To handle discrete data (text), we discretize the system using a step size $\Delta_t$ (which is input-dependent in Mamba). Using the Zero-Order Hold (ZOH) assumption:

$$ \bar{\mathbf{A}}_t = \exp(\Delta_t \mathbf{A}) $$
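$$ \bar{\mathbf{B}}_t = (\Delta_t \mathbf{A})^{-1} (\exp(\Delta_t \mathbf{A}) - \mathbf{I}) \cdot \Delta_t \mathbf{B} \approx \Delta_t \mathbf{B} $$

The approximation follows from the first-order expansion $\exp(\Delta_t \mathbf{A}) \approx \mathbf{I} + \Delta_t \mathbf{A}$; this simplified $\bar{\mathbf{B}}_t$ is the form used in the implementation below.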

The discrete recurrence becomes:
$$ h_t = \bar{\mathbf{A}}_t h_{t-1} + \bar{\mathbf{B}}_t x_t $$
$$ y_t = \mathbf{C}_t h_t $$

### Selectivity
In standard SSMs (S4), $\Delta, \mathbf{A}, \mathbf{B}, \mathbf{C}$ are time-invariant, i.e. the same at every time step. In Mamba, $\Delta_t, \mathbf{B}_t, \mathbf{C}_t$ are functions of the input $x_t$ (while $\mathbf{A}$ remains a learned, input-independent parameter), allowing the model to selectively remember or ignore information.

In [None]:
!pip install torch torchvision matplotlib

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import matplotlib.pyplot as plt

# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

## 2. Selective Scan Implementation (Simulation)

In [None]:
def selective_scan_ref(u, delta, A, B, C, D=None):
    """
    Reference implementation of Selective Scan (Sequential).

    Discretizes the diagonal SSM at every time step and unrolls the recurrence
        h_t = exp(delta_t * A) * h_{t-1} + (delta_t * B_t) * u_t
        y_t = sum_n h_t[..., n] * C_t[..., n]
    one step at a time, starting from a zero state.

    Args:
        u: (B, L, D_in) - Input
        delta: (B, L, D_in) - Time step parameter
        A: (D_in, N) - State transition (diagonal)
        B: (B, L, N) - Input projection (Input dependent)
        C: (B, L, N) - Output projection (Input dependent)
        D: (D_in) - Optional residual weight; when given, u * D is added to y

    Returns:
        Tuple (y, h_trace):
            y: (B, L, D_in) output sequence.
            h_trace: list of L CPU tensors of shape (N,), the detached state
                of batch 0, channel 0 after each step (for visualization).
                Empty when L == 0.
    """
    batch_size, seq_len, d_in = u.shape
    n = A.shape[1]

    # Discretize A: deltaA = exp(delta * A), shape (B, L, D_in, N)
    deltaA = torch.exp(torch.einsum('bld,dn->bldn', delta, A))

    # Discretize B (deltaB = delta * B) and fold in the input once, up front.
    # This is the same elementwise product the recurrence needs at every step,
    # so hoisting it out of the loop does not change the result.
    deltaB_u = torch.einsum('bld,bln->bldn', delta, B) * u.unsqueeze(-1)

    # Recurrent state h, one N-vector per (batch, channel); starts at zero.
    h = torch.zeros((batch_size, d_in, n), device=u.device)
    ys = []

    # For visualization of one channel's state
    h_trace = []

    for t in range(seq_len):
        h = deltaA[:, t] * h + deltaB_u[:, t]
        # Contract the state dimension against C_t -> (B, D_in)
        ys.append(torch.einsum('bdn,bn->bd', h, C[:, t]))
        h_trace.append(h[0, 0, :].detach().cpu())  # state of batch 0, channel 0

    if ys:
        y = torch.stack(ys, dim=1)  # (B, L, D_in)
    else:
        # torch.stack rejects an empty list; a zero-length sequence simply
        # yields a zero-length output instead of raising.
        y = u.new_zeros((batch_size, 0, d_in))

    if D is not None:
        y = y + u * D

    return y, h_trace

## 3. Mamba Block

In [None]:
class MambaBlock(nn.Module):
    """
    Simplified Mamba block: gated causal depthwise convolution + selective scan.

    Data flow in ``forward`` for an input of shape (B, L, d_model):
      1. ``in_proj`` expands to 2 * d_inner and splits into a main branch
         and a gating branch.
      2. The main branch goes through a depthwise ``conv1d`` truncated to the
         first L outputs (so position t only sees inputs <= t), then SiLU.
      3. ``x_proj`` produces the input-dependent delta (d_inner), B (d_state)
         and C (d_state); delta is passed through softplus.
      4. ``selective_scan_ref`` runs the recurrence with A = -exp(A_log) and
         the skip weight ``D``.
      5. The result is gated by SiLU of the gating branch and projected back
         to d_model by ``out_proj``.

    Args:
        d_model: Size of the input/output feature dimension.
        d_state: State size N per channel.
        expand: Expansion factor; d_inner = int(expand * d_model).
        d_conv: Kernel size of the causal depthwise convolution (default 4).
    """

    def __init__(self, d_model, d_state=16, expand=2, d_conv=4):
        super().__init__()
        if d_conv < 1:
            raise ValueError(f"d_conv must be >= 1, got {d_conv}")

        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.d_inner = int(expand * d_model)

        self.in_proj = nn.Linear(d_model, self.d_inner * 2)

        # Depthwise (groups == channels) convolution. Padding of d_conv - 1 on
        # both sides plus truncation to the first L outputs in forward() makes
        # it causal.
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,
        )

        # One projection emits [delta (d_inner) | B (d_state) | C (d_state)].
        self.x_proj = nn.Linear(self.d_inner, d_state + d_state + self.d_inner)
        # NOTE(review): dt_proj is created but never called in forward(); delta
        # comes straight from x_proj. It is kept so the parameter set and
        # state_dict keys stay unchanged - confirm whether it should be wired
        # in or removed.
        self.dt_proj = nn.Linear(d_state, self.d_inner, bias=True)

        # A is stored as log(1..d_state) per channel; forward() uses
        # -exp(A_log), so the effective A is always negative.
        A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.d_inner))
        self.out_proj = nn.Linear(self.d_inner, d_model)

    def forward(self, x, return_trace=False):
        """
        Apply the block to ``x`` of shape (B, L, d_model).

        Args:
            x: Input tensor of shape (B, L, d_model).
            return_trace: When True, also return the hidden-state trace
                produced by ``selective_scan_ref``.

        Returns:
            ``out`` of shape (B, L, d_model), or ``(out, h_trace)`` when
            ``return_trace`` is True.
        """
        seq_len = x.shape[1]
        x, res = self.in_proj(x).chunk(2, dim=-1)

        # Conv1d expects (B, channels, L); drop the trailing padded outputs so
        # the convolution stays causal and the length stays L.
        x = self.conv1d(x.transpose(1, 2))[:, :, :seq_len]
        x = F.silu(x).transpose(1, 2)

        # Split the joint projection into delta | B | C along the last dim.
        delta, B, C = self.x_proj(x).split(
            [self.d_inner, self.d_state, self.d_state], dim=-1
        )

        delta = F.softplus(delta)  # keep the step size positive
        A = -torch.exp(self.A_log)

        y, h_trace = selective_scan_ref(x, delta, A, B, C, self.D)

        y = y * F.silu(res)  # gate with the second branch of in_proj
        out = self.out_proj(y)

        if return_trace:
            return out, h_trace
        return out


## 4. Visualization: Hidden State Dynamics

In [None]:
# Build one block and run a single random sequence through it, asking for
# the per-step hidden-state trace as well as the output.
block = MambaBlock(d_model=64, d_state=16).to(device)
x = torch.randn(1, 50, 64).to(device)  # Single sequence
out, h_trace = block(x, return_trace=True)

# The trace is a list with one state vector per time step;
# stack it into a (Seq_Len, D_state) array for plotting.
h_trace = torch.stack(h_trace).numpy()

# One line per state dimension, plotted against the time step.
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(h_trace)
ax.set_title("Evolution of Hidden State $h_t$ (16 dimensions) over Time")
ax.set_xlabel("Time Step")
ax.set_ylabel("State Value")
ax.grid(True, alpha=0.3)
plt.show()

print("The hidden state evolves based on input selectivity.")