# Mamba: Linear-Time Sequence Modeling with Selective State Spaces

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adiel2012/deep-learning-abc/blob/main/mamba_ssm.ipynb)

Mamba is a selective state space model (SSM) architecture that aims for Transformer-quality performance while scaling linearly, $O(L)$, in the sequence length $L$.

Key Innovations:
1. **Selection Mechanism:** The SSM parameters ($B, C, \Delta$) are functions of the input $x_t$, allowing the model to "select" what to remember/forget.
2. **Hardware-aware Algorithm:** Uses a parallel associative scan (prefix sum) to compute the recurrence efficiently on GPU.

$$ h_t = \bar{A}_t h_{t-1} + \bar{B}_t x_t $$ 
$$ y_t = C_t h_t $$
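
The discretized quantities $\bar{A}_t$ and $\bar{B}_t$ are not spelled out above. As a sketch, assuming the simplified discretization that this notebook's reference code uses (an exponential for $A$, a plain Euler step for $B$, with $A$ diagonal so all products are elementwise), they can be written as:

$$ \bar{A}_t = \exp(\Delta_t A), \qquad \bar{B}_t = \Delta_t B_t $$

Because $\Delta_t$, $B_t$ and $C_t$ are computed from $x_t$, both $\bar{A}_t$ and $\bar{B}_t$ vary per timestep, which is what the subscript $t$ indicates.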

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

## 1. Selective Scan (Simulation)

The core of Mamba is the "Selective Scan". It's a recurrence where the transition matrices change at every timestep based on the input.
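
To make the "select what to remember/forget" idea concrete, here is a minimal scalar sketch. It assumes a single channel, a single state dimension, a hypothetical fixed `A = -1` and `B = 1`, and the simplified discretization `A_bar = exp(delta * A)`, `B_bar = delta * B`. The values of `delta` are illustrative, not taken from a trained model.

```python
import torch

# Hypothetical scalar setup: one channel, one state dimension.
A = torch.tensor(-1.0)                    # negative A gives a decaying state
delta = torch.tensor([0.01, 1.0, 10.0])   # small, medium, large step sizes

A_bar = torch.exp(delta * A)  # fraction of the previous state h_{t-1} that is kept
B_bar = delta * 1.0           # weight with which the current input x_t is written

for d, keep, write in zip(delta.tolist(), A_bar.tolist(), B_bar.tolist()):
    print(f"delta={d:>5.2f}  keep={keep:.4f}  write={write:.2f}")
```

A small `delta` keeps almost all of the previous state and barely writes the input, while a large `delta` nearly erases the state and writes the input strongly. Since `delta` depends on the input, the model can make this choice per token.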

In [None]:
def selective_scan_ref(u, delta, A, B, C, D=None):
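    """
    Reference (sequential) implementation of the selective scan.

    Shapes use b = batch, L = sequence length, D_in = channels, N = state size.
    u:     (b, L, D_in) - input sequence
    delta: (b, L, D_in) - per-channel step size (expected to be positive)
    A:     (D_in, N)    - diagonal state matrix, one row per channel
    B:     (b, L, N)    - input projection (input-dependent)
    C:     (b, L, N)    - output projection (input-dependent)
    D:     (D_in,)      - optional skip-connection weight
    Returns y: (b, L, D_in).
    """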
    batch_size, seq_len, d_in = u.shape
    n = A.shape[1]
    
    # Discretize A: A_bar = exp(delta * A), shape (b, L, D_in, N)
    deltaA = torch.exp(torch.einsum('bld,dn->bldn', delta, A))
    
    # Discretize B with a simplified Euler step: B_bar = delta * B, shape (b, L, D_in, N)
    deltaB = torch.einsum('bld,bln->bldn', delta, B)
    
    # Recurrence: h_t = A_bar_t * h_{t-1} + B_bar_t * u_t, starting from h = 0
    h = torch.zeros((batch_size, d_in, n), device=u.device, dtype=u.dtype)
    ys = []
    
    for t in range(seq_len):
        h = deltaA[:, t] * h + deltaB[:, t] * u[:, t].unsqueeze(-1)
        y = torch.einsum('bdn,bn->bd', h, C[:, t])  # y_t = C_t h_t, shape (b, D_in)
        ys.append(y)
        
    y = torch.stack(ys, dim=1)  # (b, L, D_in)
    
    if D is not None:
        y = y + u * D
        
    return y
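
The reference loop is sequential in time, whereas the introduction mentions a parallel associative scan. The following is a minimal sketch of why that is possible: each step `h -> a_t * h + b_t` is an affine map, and composing affine maps is associative. The helper names `sequential_scan` and `parallel_scan` are hypothetical, and this log-step (Hillis-Steele style) version only illustrates the idea; it is not the fused GPU kernel an optimized implementation would use.

```python
import torch

def sequential_scan(a, b):
    """h_t = a_t * h_{t-1} + b_t with h_{-1} = 0; a and b have shape (L, ...)."""
    h = torch.zeros_like(b[0])
    out = []
    for t in range(a.shape[0]):
        h = a[t] * h + b[t]
        out.append(h)
    return torch.stack(out)

def parallel_scan(a, b):
    """Inclusive scan in O(log L) rounds.

    combine((a1, b1), (a2, b2)) = (a1 * a2, a2 * b1 + b2),
    where (a1, b1) is the earlier element.
    """
    a, b = a.clone(), b.clone()
    length = a.shape[0]
    step = 1
    while step < length:
        new_a, new_b = a.clone(), b.clone()
        # Element t absorbs the partial result that ends at position t - step.
        new_b[step:] = a[step:] * b[:-step] + b[step:]
        new_a[step:] = a[step:] * a[:-step]
        a, b = new_a, new_b
        step *= 2
    return b  # b[t] now equals h_t

torch.manual_seed(0)
a = torch.rand(13, 4)    # stands in for deltaA along the time axis
b = torch.randn(13, 4)   # stands in for deltaB * u along the time axis
print(torch.allclose(sequential_scan(a, b), parallel_scan(a, b), atol=1e-5))
```

Every round updates all positions at once, so the time dimension is processed in a logarithmic number of vectorized steps rather than `L` sequential ones, at the cost of more total arithmetic.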

## 2. Mamba Block

A Mamba block combines an input projection, a causal depthwise 1D convolution (for local context), the Selective Scan, and a gated output projection.
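
The convolution in the block should not let a position see future tokens. A minimal sketch of one way to get that with `nn.Conv1d`, assuming a depthwise convolution with `padding = kernel_size - 1` whose trailing outputs are trimmed (the sizes below are arbitrary illustration values):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
channels, kernel_size, seq_len = 3, 4, 10

# Depthwise conv: groups == channels, so each channel is filtered independently.
conv = nn.Conv1d(channels, channels, kernel_size,
                 groups=channels, padding=kernel_size - 1)

x = torch.randn(1, channels, seq_len)
y = conv(x)[:, :, :seq_len]  # drop the kernel_size - 1 extra outputs on the right

# Perturb only the last timestep; all earlier outputs should stay the same.
x_future = x.clone()
x_future[:, :, -1] += 1.0
y_future = conv(x_future)[:, :, :seq_len]

print(torch.allclose(y[:, :, :-1], y_future[:, :, :-1]))
```

`Conv1d` pads both sides, so output position `t` of the padded result covers inputs `t - 3` to `t`; keeping only the first `seq_len` outputs therefore leaves each position dependent on the current and previous three inputs only.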

In [None]:
class MambaBlock(nn.Module):
    def __init__(self, d_model, d_state=16, expand=2):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_inner = int(expand * d_model)
        
        # 1. Input Projections
        self.in_proj = nn.Linear(d_model, self.d_inner * 2)
        
        # 2. Depthwise convolution (1D); padding = kernel_size - 1, trimmed in forward() to stay causal
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            kernel_size=4,
            groups=self.d_inner,
            padding=3
        )
        
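        # 3. Parameters for Selective Scan (Input Dependent!)
        # Projects the input to [delta, B, C], concatenated in that order:
        # delta (d_inner) + B (d_state) + C (d_state).
        # Simplification: delta is projected directly to d_inner, so no separate
        # low-rank dt projection layer is needed (forward() never used one).
        self.x_proj = nn.Linear(self.d_inner, self.d_inner + d_state + d_state)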
        
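        # Input-independent (but learnable) parameters A and D.
        # A is initialized to 1..d_state for every channel and stored in log space.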
        A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.d_inner))
        
        # 4. Output Projection
        self.out_proj = nn.Linear(self.d_inner, d_model)
        
    def forward(self, x):
        # x: (B, L, D_model)
        batch_size, seq_len, _ = x.shape
        
        x_and_res = self.in_proj(x)
        x, res = x_and_res.chunk(2, dim=-1)
        
        x = x.transpose(1, 2)               # (B, D_inner, L), the layout Conv1d expects
        x = self.conv1d(x)[:, :, :seq_len]  # trim the extra right-side outputs to keep it causal
        x = F.silu(x)
        x = x.transpose(1, 2) # (B, L, D_inner)
        
        # Compute the input-dependent parameters.
        # x_proj maps x to the concatenation [delta, B, C]. This is a simplified
        # layout; the official implementation handles these dimensions differently.
        delta_B_C = self.x_proj(x)
        
        # Split along the last dimension:
        # delta: (B, L, D_inner)
        # B: (B, L, D_state)
        # C: (B, L, D_state)
        
        delta = delta_B_C[:, :, :self.d_inner]
        B = delta_B_C[:, :, self.d_inner : self.d_inner + self.d_state]
        C = delta_B_C[:, :, self.d_inner + self.d_state :]
        
        delta = F.softplus(delta) # Ensure positive time step
        A = -torch.exp(self.A_log) # Ensure A is negative for stability
        
        # Run SSM
        y = selective_scan_ref(x, delta, A, B, C, self.D)
        
        y = y * F.silu(res) # Multiplicative gate from the second in_proj branch (not a residual add)
        return self.out_proj(y)

# Test
block = MambaBlock(d_model=64, d_state=16).to(device)
x = torch.randn(2, 32, 64).to(device)
out = block(x)
print("Input:", x.shape)
print("Output:", out.shape)