In [1]:
import os

os.chdir("/root/dev/playground/modules/mamba")
os.getcwd()

'/root/dev/playground/modules/mamba'

In [2]:
import torch

torch.cuda.set_device(5)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
!gpustat

[1m[37m794894ff7784              [m  Tue Oct 22 00:24:39 2024  [1m[30m535.154.05[m
[36m[0][m [34mNVIDIA GeForce RTX 4090[m |[31m 29°C[m, [32m  0 %[m | [36m[1m[33m 2536[m / [33m24564[m MB |
[36m[1][m [34mNVIDIA GeForce RTX 4090[m |[31m 32°C[m, [32m  0 %[m | [36m[1m[33m    3[m / [33m24564[m MB |
[36m[2][m [34mNVIDIA GeForce RTX 4090[m |[31m 28°C[m, [32m  0 %[m | [36m[1m[33m 2234[m / [33m24564[m MB |
[36m[3][m [34mNVIDIA GeForce RTX 4090[m |[31m 31°C[m, [32m  0 %[m | [36m[1m[33m    3[m / [33m24564[m MB |
[36m[4][m [34mNVIDIA GeForce RTX 4090[m |[31m 33°C[m, [32m  0 %[m | [36m[1m[33m    3[m / [33m24564[m MB |
[36m[5][m [34mNVIDIA GeForce RTX 4090[m |[31m 35°C[m, [32m  0 %[m | [36m[1m[33m  390[m / [33m24564[m MB |


# Selective scan

- https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/selective_state_update.py

In [8]:
import triton
import triton.language as tl

In [11]:
# https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/softplus.py
@triton.jit
def softplus(dt):
    return tl.math.log1p(tl.exp(dt)) # log(1 + exp(dt))

- The derivative of softplus is sigmoid

### _selective_scan_update_kernel function

Terminologies
- Stride: The step size (in memory) between consecutive elements along a specific dimension of a multi-dimensional tensor.
    - stride_x_batch: Tells how far to move in memory to get to the **next batch** of the `x` tensor
    - stride_x_head: Tells how far to move in memory to get to the **next head** of the `x` tensor
    - stride_x_dim: Tells how far to move in memory to get to the **next dim** of the `x` tensor
- Representations of dimensions
    $$ \tilde{A} \in \mathbb{R}^{(D, N)}, \; \tilde{B}, \tilde{C} \in \mathbb{R}^{(B, L, N)}, \; \Delta \in \mathbb{R}^{(B, L, D)}, \; A, B \in \mathbb{R}^{(B, L, D, N)}  $$
    - **B**: batch
    - **L**: head
    - **D**: dim
    - **N**: dstate

Triton-based GPU kernel designed for efficient updates of a recurrent state using selective state-space models (SSMs).

- Inputs: pointers to different matrices representing states, inputs, weights (A, B, C, D), and other meta-parameters, which correspond to various components of SSMs.
    - state_ptr: points to the recurrent state that is updated
    - x_ptr: Input matrix for the current time step
    - dt_ptr: Delta time for the update (can vary across steps)
    - A_ptr, B_ptr, C_ptr: SSM parameter matrices
    - (Optional) dt_bias_ptr, D_ptr, z_ptr: bias, weights for direct input-to-output connection, and auxiliary gating term
    - Meta-parameters
        - TIE-HDIM: whether it simplifies calculations by assuming head dimensions are tied, reproducing complexity
        - BLOCK_SIZE_M: determines the size of blocks processed in parallel (based on GPU memory constraints)
        - HAS_DT_BIAS, HAS_D, HAS_Z: Flags indicating whether certain inputs are present

So, the corresponding equation is:
$$ \begin{align*} h_t &= A_t h_{t-1} + B_t x_t \\ y_t &= C_t^\top h_t \end{align*} $$

- $\Delta t = \tau(\operatorname{Linear}(x_t)) = \operatorname{softplus}(\operatorname{Linear}(x_t))$ (by the proof of Theorem 1 in the paper)
- $A = \operatorname{exp}(\tilde{A} \cdot \Delta t)$
- $B = \sigma(\operatorname{Linear}(x_t))$

In [14]:
# perform the discretization and recurrence in SRAM
@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
@triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
@triton.jit
def _selective_scan_update_kernel(
    # Pointers to matrices
    state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,
    # Matrix dimensions
    batch, nheads, dim, dstate, nheads_ngroups_ratio,
    # Strides
    stride_state_batch, stride_state_head, stride_state_dim, stride_state_dstate,
    stride_x_batch, stride_x_head, stride_x_dim,
    stride_dt_batch, stride_dt_head, stride_dt_dim,
    stride_dt_bias_head, stride_dt_bias_dim,
    stride_A_head, stride_A_dim, stride_A_dstate,
    stride_B_batch, stride_B_group, stride_B_dstate,
    stride_C_batch, stride_C_group, stride_C_dstate,
    stride_D_head, stride_D_dim,
    stride_z_batch, stride_z_head, stride_z_dim,
    stride_out_batch, stride_out_head, stride_out_dim,
    # Meta-parameters
    DT_SOFTPLUS: tl.constexpr,
    TIE_HDIM: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    HAS_DT_BIAS: tl.constexpr, HAS_D: tl.constexpr, HAS_Z: tl.constexpr,
    BLOCK_SIZE_DSTATE: tl.constexpr,
    ):
    # program id and pointer adjustments
    pid_m, pid_b, pid_h = tl.program_id(axis=0), tl.program_id(axis=1), tl.program_id(axis=2)
    state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
    x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
    dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
    if HAS_DT_BIAS:
        dt_bias_ptr += pid_h * stride_dt_bias_head
    A_ptr += pid_h * stride_A_head
    B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
    C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
    if HAS_Z:
        z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
    out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
    
    # pointers for loading blocks of data
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) # offset along the first dimension (dim)
    offs_n = tl.arange(0, BLOCK_SIZE_DSTATE) # offset along the second dimension (dstate)
    state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate) # (stride_state_dim, stride_state_dstate)
    x_ptrs, dt_ptrs = x_ptr + offs_m * stride_x_dim, dt_ptr + offs_m * stride_dt_dim
    if HAS_DT_BIAS:
        dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
    if HAS_D:
        D_ptr += pid_h * stride_D_head
    A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)
    B_ptrs = B_ptr + offs_n * stride_B_dstate
    C_ptrs = C_ptr + offs_n * stride_C_dstate
    if HAS_D:
        D_ptrs = D_ptr + offs_m * stride_D_dim
    if HAS_Z:
        z_ptrs = z_ptr + offs_m * stride_z_dim
    out_ptrs = out_ptr + offs_m * stride_out_dim
    
    # loading and discretization
    state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.)
    x = tl.load(x_ptrs, mask=offs_m < dim, other=0.).to(tl.float32)
    if not TIE_HDIM:
        dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.).to(tl.float32)
        if HAS_DT_BIAS:
            dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.).to(tl.float32)
        if DT_SOFTPLUS:
            dt = tl.where(dt <= 20., softplus(dt), dt) # dt = softplus(dt) = log(1 + exp(dt))
        A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.).to(tl.float32)
        dA = tl.exp(A * dt[:, None]) # A_t = exp(\Delta A) = exp(A \Delta t)
    else:
        dt = tl.load(dt_ptr).to(tl.float32)
        if HAS_DT_BIAS:
            dt += tl.load(dt_bias_ptr).to(tl.float32)
        if DT_SOFTPLUS:
            dt = tl.where(dt <= 20., softplus(dt), dt) # dt = softplus(dt) = log(1 + exp(dt))
        A = tl.load(A_ptr).to(tl.float32)
        dA = tl.exp(A * dt) # scalar, not a matrix
    
    B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.).to(tl.float32)
    C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.).to(tl.float32)
    if HAS_D:
        D = tl.load(D_ptrs, mask=offs_m < dim, other=0.).to(tl.float32)
    if HAS_Z:
        z = tl.load(z_ptrs, mask=offs_m < dim, other=0.).to(tl.float32)
    
    if not TIE_HDIM:
        dB = B[None, :] * dt[:, None] # B_t = sigmoid(Linear(x_t)) → not explicitly reprsented here
    else:
        dB = B * dt # vector of size (dstate,)
    state = state * dA + dB * x[:, None] # Linear(x_t)
    tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))
    
    # output calculation
    out = tl.sum(state * C[None, :], axis=1)
    if HAS_D:
        out += x * D
    if HAS_Z:
        out *= z * tl.sigmoid(z)
    tl.store(out_ptrs, out, mask=offs_m < dim)

In [15]:
def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
    has_heads = state.dim() > 3
    if state.dim() == 3:
        state = state.unsqueeze(1) # (B, L, D, N)
    if x.dim() == 2:
        x = x.unsqueeze(1) # (B, L, D)
    if dt.dim() == 2:
        dt = dt.unsqueeze(1) # (B, L, D)
    if A.dim() == 2: # (D, N)
        A = A.unsqueeze(0) # (L, D, N)
    if B.dim() == 2:
        B = B.unsqueeze(1) # (B, L, N)
    if C.dim() == 2:
        C = C.unsqueeze(1) # (B, L, N)
    if D is not None and D.dim() == 1:
        D = D.unsqueeze(0) # (L, D)
    if z is not None and z.dim() == 2:
        z = z.unsqueeze(1) # (B, L, D)
    if dt_bias is not None and dt_bias.dim() == 1:
        dt_bias = dt_bias.unsqueeze(0) # (L, D)
    
    batch, nheads, dim, dstate = state.shape # B L D N
    assert x.shape == (batch, nheads, dim)
    assert dt.shape == x.shape
    assert A.shape == (nheads, dim, dstate)
    ngroups = B.shape[1]
    assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
    assert B.shape == (batch, ngroups, dstate)
    assert C.shape == B.shape
    if D is not None:
        assert D.shape == (nheads, dim)
    if z is not None:
        assert z.shape == x.shape
    if dt_bias is not None:
        assert dt_bias.shape == (nheads, dim)
    
    out = torch.empty_like(x)
    grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)
    z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0))
    
    # We don't want to autotune since it will overwrite the state
    # Instead, we tune by hand.
    BLOCK_SIZE_M = num_warps = ((32, 4) if dstate <= 16
                                else ((16, 4) if dstate <= 32 else
                                      ((8, 4) if dstate <= 64 else
                                      ((4, 4) if dstate <= 128 else
                                      ((4, 8))))))
    tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(-1) == 0 and dt_bias.stride(-1) == 0
    with torch.cuda.device(x.device.index):
        _selective_scan_update_kernel[grid](
            state, x, dt, dt_bias, A, B, C, D, z, out,
            batch, nheads, dim, dstate, nheads // ngroups,
            state.stride(0), state.stride(1), state.stride(2), state.stride(3),
            x.stride(0), x.stride(1), x.stride(2),
            dt.stride(0), dt.stride(1), dt.stride(2),
            *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0,
            A.stride(0), A.stride(1), A.stride(2),
            B.stride(0), B.stride(1), B.stride(2),
            C.stride(0), C.stride(1), C.stride(2),
            *(D.stride(0), D.stride(1)) if D is not None else 0,
            z_strides[0], z_strides[1], z_strides[2],
            out.stride(0), out.stride(1), out.stride(2),
            dt_softplus, tie_hdim, BLOCK_SIZE_M, num_warps=num_warps)
        if not has_heads:
            out = out.squeeze(1)
        return out

## Mamba Model

In [16]:
class CausalConv1dFn(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x,
        weight,
        bias=None,
        seq_idx=None,
        initial_states=None,
        return_final_states=False,
        final_states_out=None,
        activation=None,
    ):
        if activation not in [None, "silu", "swish"]:
            raise NotImplementedError("activation must be None, silu, or swish")
        if x.stride(2) != 1 and x.stride(1) != 1:
            x = x.contiguous()
        bias = bias.contiguous() if bias is not None else None
        if seq_idx is not None:
            assert (
                initial_states is None
            ), "initial_states must be None if seq_idx is not None"
            assert (
                not return_final_states
            ), "If seq_idx is not None, we don't return final_states_out"
        seq_idx = seq_idx.contiguous() if seq_idx is not None else None
        if initial_states is not None and (
            initial_states.stride(2) != 1 and initial_states.stride(1) != 1
        ):
            initial_states = initial_states.contiguous()
        if return_final_states:
            assert (
                x.stride(1) == 1
            ), "Only channel-last layout support returning final_states_out"
            if final_states_out is not None:
                assert (
                    final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
                )
            else:
                batch, dim, seqlen = x.shape
                width = weight.shape[1]
                final_states_out = torch.empty(
                    batch, width - 1, dim, device=x.device, dtype=x.dtype
                ).transpose(1, 2)
        else:
            final_states_out = None
        ctx.activation = activation in ["silu", "swish"]
        out = causal_conv1d_cuda.causal_conv1d_fwd(
            x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
        )
        ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
        ctx.return_final_states = return_final_states
        ctx.return_dinitial_states = (
            initial_states is not None and initial_states.requires_grad
        )
        return out if not return_final_states else (out, final_states_out)

    @staticmethod
    def backward(ctx, dout, *args):
        x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
        dfinal_states = args[0] if ctx.return_final_states else None
        if dout.stride(2) != 1 and dout.stride(1) != 1:
            dout = dout.contiguous()
        # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
        # backward of conv1d with the backward of chunk).
        # Here we just pass in None and dx will be allocated in the C++ code.
        dx, dweight, dbias, dinitial_states = causal_conv1d_cuda.causal_conv1d_bwd(
            x,
            weight,
            bias,
            dout,
            seq_idx,
            initial_states,
            dfinal_states,
            None,
            ctx.return_dinitial_states,
            ctx.activation,
        )
        return (
            dx,
            dweight,
            dbias if bias is not None else None,
            None,
            dinitial_states if initial_states is not None else None,
            None,
            None,
            None,
        )


def causal_conv1d_fn(
    x,
    weight,
    bias=None,
    seq_idx=None,
    initial_states=None,
    return_final_states=False,
    final_states_out=None,
    activation=None,
):
    """
    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    seq_idx: (batch, seqlen)
    initial_states: (batch, dim, width - 1)
    final_states_out: (batch, dim, width - 1), to be written to
    activation: either None or "silu" or "swish"

    out: (batch, dim, seqlen)
    """
    return CausalConv1dFn.apply(
        x,
        weight,
        bias,
        seq_idx,
        initial_states,
        return_final_states,
        final_states_out,
        activation,
    )

In [None]:
from torch.cuda.amp import custom_bwd, custom_fwd

class MambaInnerFn(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                out_proj_weight, out_proj_bias,
                A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
                C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
        """ xz: (batch, dim, seqlen) """
        assert checkpoint_lvl in [0, 1]
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        if xz.stride(-1) != 1:
            xz = xz.contiguous()
        conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
        x, z = xz.chunk(2, dim=1)
        conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, None, None, True)
        
        # We're being very careful here about the layout, to avoid extra transposes.
        # We want delta to have d as the slowest moving dimension
        # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
        x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)
        delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L)
        ctx.is_variable_B = B is None
        ctx.is_variable_C = C is None
        ctx.B_proj_bias_is_None = B_proj_bias is None
        ctx.C_proj_bias_is_None = C_proj_bias is None
        if B is None:  # variable B
            B = x_dbl[:, delta_rank:delta_rank + d_state]  # (bl dstate)
            if B_proj_bias is not None:
                B = B + B_proj_bias.to(dtype=B.dtype)
            if not A.is_complex():
                # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
                B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
        else:
            if B.stride(-1) != 1:
                B = B.contiguous()
        if C is None:  # variable C
            C = x_dbl[:, -d_state:]  # (bl dstate)
            if C_proj_bias is not None:
                C = C + C_proj_bias.to(dtype=C.dtype)
            if not A.is_complex():
                # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
                C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
        else:
            if C.stride(-1) != 1:
                C = C.contiguous()
        if D is not None:
            D = D.contiguous()
        out, scan_intermediates, out_z = selective_scan_cuda.fwd(
            conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
        )
        ctx.delta_softplus = delta_softplus
        ctx.out_proj_bias_is_None = out_proj_bias is None
        ctx.checkpoint_lvl = checkpoint_lvl
        if checkpoint_lvl >= 1:  # Will recompute conv1d_out and delta in the backward pass
            conv1d_out, delta = None, None
        ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
                              delta_proj_weight, out_proj_weight, conv1d_out, delta,
                              A, B, C, D, delta_bias, scan_intermediates, out)
        return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
    
    
    @staticmethod
    @custom_bwd
    def backward(ctx, dout):
        # dout: (batch, seqlen, dim)
        assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
        (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
         conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        x, z = xz.chunk(2, dim=1)
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        if ctx.checkpoint_lvl == 1:
            conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
                x, conv1d_weight, conv1d_bias, None, None, None, True
            )
            delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
                              "d (b l) -> b d l", l = L)
        # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
        # backward of selective_scan_cuda with the backward of chunk).
        dxz = torch.empty_like(xz)  # (batch, dim, seqlen)
        dx, dz = dxz.chunk(2, dim=1)
        dout = rearrange(dout, "b l e -> e (b l)")
        dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
        dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
            conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz,
            ctx.delta_softplus,
            True  # option to recompute out_z
        )
        dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
        dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
        dD = dD if D is not None else None
        dx_dbl = torch.empty_like(x_dbl)
        dB_proj_bias = None
        if ctx.is_variable_B:
            if not A.is_complex():
                dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
            dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
            dx_dbl[:, delta_rank:delta_rank + d_state] = dB  # (bl d)
            dB = None
        dC_proj_bias = None
        if ctx.is_variable_C:
            if not A.is_complex():
                dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
            dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
            dx_dbl[:, -d_state:] = dC  # (bl d)
            dC = None
        ddelta = rearrange(ddelta, "b d l -> d (b l)")
        ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
        dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
        dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
        dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
        dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
        dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
        # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
        # backward of conv1d with the backward of chunk).
        dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
            x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
        )
        dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
        dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
        return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
                dout_proj_weight, dout_proj_bias,
                dA, dB, dC, dD,
                ddelta_bias if delta_bias is not None else None,
                dB_proj_bias, dC_proj_bias, None)


def mamba_inner_fn(
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    out_proj_weight, out_proj_bias,
    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
    C_proj_bias=None, delta_softplus=True
):
    return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                              out_proj_weight, out_proj_bias,
                              A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)

### Mamba

- Mamba in Transformers (https://github.com/huggingface/transformers/blob/main/src/transformers/models/mamba/modeling_mamba.py)

In [23]:
from torch import nn

class Mamba(nn.Module):
    """ Compute \Delta, A, B, C, and D the state space parameters and compute the `contextualized_states`.
        A, D are input independent, and \Delta, B, C are input-dependent. This is why Mamba is called
        selective state spaces
    """
    
    def __init__(self, layer_idx: int):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.ssm_state_size = config.state_size
        self.conv_kernel_size = config.conv_kernel
        self.intermediate_size = config.intermediate_size
        self.time_step_rank = int(config.time_step_rank)
        self.layer_idx = layer_idx
        self.use_conv_bias = config.use_conv_bias
        self.conv1d = nn.Conv1d(
            in_channels=self.intermediate_size,
            out_channels=self.intermediate_size,
            bias=config.use_conv_bias,
            kernel_size=config.conv_kernel,
            groups=self.intermediate_size,
            padding=config.conv_kernel - 1,
        )
        
        self.activation = confgi.hidden_act
        self.act = ACT2FN[config.hidden_act]
        
        self.use_mambapy = config.use_mambapy
        
        # projection of the input hidden states
        self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias)
        # selective projection used to make dt, B and C input dependant
        self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
        # time step projection (discretization)
        self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)
        
        # S4D real initialization. There are not discretized!
        # The core is to load them, compute the discrete states, then write the update state. Keeps the memory bounded
        A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :]
        A = A.expand(self.intermediate_size, -1).contiguous()
        
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.intermediate_size))
        self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
        self.use_bias = config.use_bias

    def cuda_kernels_forward(self, hidden_states: torch.Tensor):
        # 1. Gated MLP's linear projection
        projected_states = self.in_proj(hidden_states).transpose(1, 2)
        
        contextualized_states = mamba_inner_fn(
            projected_states,
            self.conv1d.weight,
            self.conv1d.bias if self.use_conv_bias else None,
            self.x_proj.weight,
            self.dt_proj.weight,
            self.out_proj.weight,
            self.out_proj.bias.float() if self.use_bias else None,
            -torch.exp(self.a_log.float()),
            None, # input-dependent B
            None, # input-dependent C
            self.D.float(),
            delta_bias=self.dt_proj.bias.float(),
            delta_softplus=True,
        )
        
        return contextualized_states

## Mamba2 Model

In [25]:
class MambaSplitConv1dScanCombinedFn(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu",
                rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None,
                ngroups=1, norm_before_gate=True):
        assert activation in [None, "silu", "swish"]
        if D.dim() == 1:
            assert headdim is not None
            nheads, = D.shape
        else:
            nheads, headdim = D.shape
        batch, seqlen, _ = zxbcdt.shape
        dim = nheads * headdim
        assert nheads % ngroups == 0
        dstate = (conv1d_weight.shape[0] - dim) // ngroups // 2
        d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ngroups * dstate - nheads) // 2
        assert d_nonssm >= 0
        assert zxbcdt.shape == (batch, seqlen, 2 * d_nonssm + 2 * dim + 2 * ngroups * dstate + nheads)
        assert dt_bias.shape == (nheads,)
        assert A.shape == (nheads,)
        zx0, z, xBC, dt = torch.split(zxbcdt, [2 * d_nonssm, dim, dim + ngroups * dstate * 2, nheads], dim=-1)
        seq_idx = seq_idx.contiguous() if seq_idx is not None else None
        xBC_conv = rearrange(
            causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"),
                                                 conv1d_weight, conv1d_bias, seq_idx, None, None, activation in ["silu", "swish"]),
            "b d s -> b s d"
        )
        x, B, C = torch.split(xBC_conv, [dim, ngroups * dstate, ngroups * dstate], dim=-1)
        x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
        B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
        C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
        z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None
        if rmsnorm_weight is None:
            out, out_x, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size=chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=dt_limit)
            out = rearrange(out, "b s h p -> b s (h p)")
            rstd = None
            if d_nonssm > 0:
                out = torch.cat([_swiglu_fwd(zx0), out], dim=-1)
        else:
            out_x, _, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size=chunk_size, D=D, z=None, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=dt_limit)
            # reshape input data into 2D tensor
            x_rms = rearrange(out_x, "b s h p -> (b s) (h p)")
            z_rms = rearrange(z, "b s h p -> (b s) (h p)")
            rmsnorm_weight = rmsnorm_weight.contiguous()
            if d_nonssm == 0:
                out = None
            else:
                out01 = torch.empty((batch, seqlen, d_nonssm + dim), dtype=x_rms.dtype, device=x_rms.device)
                out = rearrange(out01[..., d_nonssm:], "b s d -> (b s) d")
                _swiglu_fwd(zx0, out=out01[..., :d_nonssm])
            out, _, rstd = _layer_norm_fwd(x_rms, rmsnorm_weight, None, rmsnorm_eps, z_rms, out=out,
                                           group_size=dim // ngroups,
                                           norm_before_gate=norm_before_gate, is_rms_norm=True)
            if d_nonssm == 0:
                out = rearrange(out, "(b s) d -> b s d", b=batch)
            else:
                out = out01
        ctx.outproj_weight_dtype = outproj_weight.dtype if outproj_weight is not None else None
        if outproj_weight is not None:
            if torch.is_autocast_enabled():
                dtype = torch.get_autocast_gpu_dtype()
                out, outproj_weight = out.to(dtype), outproj_weight.to(dtype)
                outproj_bias = outproj_bias.to(dtype) if outproj_bias is not None else None
            out = F.linear(out, outproj_weight, outproj_bias)
        else:
            assert outproj_bias is None
        ctx.save_for_backward(zxbcdt, conv1d_weight, conv1d_bias,
                              out_x, A, D, dt_bias, initial_states, seq_idx, rmsnorm_weight, rstd, outproj_weight, outproj_bias)
        ctx.dt_limit = dt_limit
        ctx.return_final_states = return_final_states
        ctx.activation = activation
        ctx.rmsnorm_eps = rmsnorm_eps
        ctx.norm_before_gate = norm_before_gate
        ctx.chunk_size = chunk_size
        ctx.headdim = headdim
        ctx.ngroups = ngroups
        return out if not return_final_states else (out, final_states)

    @staticmethod
    @custom_bwd
    def backward(ctx, dout, *args):
        zxbcdt, conv1d_weight, conv1d_bias, out, A, D, dt_bias, initial_states, seq_idx, rmsnorm_weight, rstd, outproj_weight, outproj_bias = ctx.saved_tensors
        dfinal_states = args[0] if ctx.return_final_states else None
        headdim = ctx.headdim
        nheads = D.shape[0]
        dim = nheads * headdim
        assert nheads % ctx.ngroups == 0
        dstate = (conv1d_weight.shape[0] - dim) // ctx.ngroups // 2
        d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ctx.ngroups * dstate - nheads) // 2
        assert d_nonssm >= 0
        recompute_output = outproj_weight is not None
        if recompute_output:
            out_recompute = torch.empty(*out.shape[:2], d_nonssm + dim, device=out.device, dtype=out.dtype)
            out0_recompute, out1_recompute = out_recompute.split([d_nonssm, dim], dim=-1)
        zx0, z, xBC, dt = torch.split(zxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1)
        # Recompute x, B, C
        xBC_conv = rearrange(
            causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"),
                                                 conv1d_weight, conv1d_bias, seq_idx, None, None, ctx.activation in ["silu", "swish"]),
            "b d s -> b s d"
        )
        x, B, C = torch.split(xBC_conv, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1)
        x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
        B = rearrange(B, "b l (g n) -> b l g n", g=ctx.ngroups)
        C = rearrange(C, "b l (g n) -> b l g n", g=ctx.ngroups)
        dzxbcdt = torch.empty_like(zxbcdt)
        dzx0, dz, dxBC_given, ddt_given = torch.split(dzxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1)
        dxBC = torch.empty_like(xBC)
        dx, dB, dC = torch.split(dxBC, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1)
        z = rearrange(z, "b l (h p) -> b l h p", h=nheads)
        dx = rearrange(dx, "b l (h p) -> b l h p", h=nheads)
        dB = rearrange(dB, "b l (g n) -> b l g n", g=ctx.ngroups)
        dC = rearrange(dC, "b l (g n) -> b l g n", g=ctx.ngroups)
        if outproj_weight is not None:
            dout_og = dout
            dout = F.linear(dout, outproj_weight.t())
        if d_nonssm > 0:
            dout0, dout = dout.split([d_nonssm, dim], dim=-1)
            _swiglu_bwd(zx0, dout0, dxy=dzx0, recompute_output=True, out=out0_recompute)
        dout = rearrange(dout, "b s (h p) -> b s h p", p=headdim)
        if rmsnorm_weight is None:
            dz = rearrange(dz, "b l (h p) -> b l h p", h=nheads)
            dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states, *rest = _mamba_chunk_scan_combined_bwd(
                dout, x, dt, A, B, C, out, ctx.chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=dfinal_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=ctx.dt_limit, dx=dx, ddt=ddt_given, dB=dB, dC=dC, dz=dz, recompute_output=recompute_output
            )
            out_for_linear = rearrange(rest[0], "b s h p -> b s (h p)") if recompute_output else None
            drmsnorm_weight = None
        else:
            batch = dout.shape[0]
            dy_rms = rearrange(dout, "b s h p -> (b s) (h p)")
            dz = rearrange(dz, "b l d -> (b l) d")
            x_rms = rearrange(out, "b s h p -> (b s) (h p)")
            z_rms = rearrange(z, "b s h p -> (b s) (h p)")
            out1_recompute = rearrange(out1_recompute, "b s d -> (b s) d") if recompute_output else None
            dout, drmsnorm_weight, _, dz, *rest = _layer_norm_bwd(dy_rms, x_rms, rmsnorm_weight, None, ctx.rmsnorm_eps, None, rstd, z_rms, group_size=dim//ctx.ngroups, norm_before_gate=ctx.norm_before_gate, is_rms_norm=True, recompute_output=recompute_output, dz=dz, out=out1_recompute if recompute_output else None)
            out_for_linear = out_recompute if recompute_output else None
            dout = rearrange(dout, "(b s) (h p) -> b s h p", b=batch, p=headdim)
            dx, ddt, dA, dB, dC, dD, _, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd(
                dout, x, dt, A, B, C, out, ctx.chunk_size, D=D, z=None, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=dfinal_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=ctx.dt_limit, dx=dx, ddt=ddt_given, dB=dB, dC=dC
            )

        if outproj_weight is not None:
            doutproj_weight = torch.einsum("bso,bsd->od", dout_og, out_for_linear)
            doutproj_bias = dout_og.sum(dim=(0, 1)) if outproj_bias is not None else None
        else:
            doutproj_weight, doutproj_bias = None, None
        dxBC_given = rearrange(dxBC_given, "b s d -> b d s")
        dxBC_given, dweight, dbias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
            rearrange(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias,
            rearrange(dxBC, "b s d -> b d s"), seq_idx, None, None, dxBC_given, False, ctx.activation in ["silu", "swish"]
        )
        dxBC_given = rearrange(dxBC_given, "b d s -> b s d")
        return dzxbcdt, dweight, dbias, ddt_bias, dA, dD, None, dinitial_states, None, None, None, None, drmsnorm_weight, None, doutproj_weight, doutproj_bias, None, None, None


def mamba_split_conv1d_scan_combined(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu", rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None, ngroups=1, norm_before_gate=True):
    """
    Argument:
        zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim
        conv1d_weight: (dim + 2 * ngroups * dstate, width)
        conv1d_bias: (dim + 2 * ngroups * dstate,)
        dt_bias: (nheads,)
        A: (nheads)
        D: (nheads, headdim) or (nheads,)
        initial_states: (batch, nheads, headdim, dstate)
        seq_idx: (batch, seqlen), int32
        rmsnorm_weight: (dim,)
        outproj_weight: (out_dim, dim)
        outproj_bias: (out_dim,)
        headdim: if D is 1D, headdim must be passed in
        norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z))
    Return:
        out: (batch, seqlen, dim)
    """
    return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)

In [27]:
class Mamba2(nn.Module):
    """ Compute \Delta, A, B, C, and D the state space parameters and compute the `contextualized_states`.
        A, D are input independent, and \Delta, B, C are input-dependent.
    """
    
    def __init__(self, layer_idx: int):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.ssm_state_size = config.state_size
        self.conv_kernel_size = config.conv_kernel
        self.intermediate_size = config.intermediate_size
        self.time_step_rank = int(config.time_step_rank)
        self.layer_idx = layer_idx
        self.use_conv_bias = config.use_conv_bias
        
        self.activation = config.hidden_act
        self.act = ACT2FN[config.hidden_act]
        
        self.layer_norm_epsilon = config.layer_norm_epsilon
        self.rms_norm = config.rms_norm
        
        self.n_groups = config.n_groups
        self.head_dim = config.head_dim
        self.chunk_size = config.chunk_size
        
        self.time_step_limit = config.time_step_limit
        self.time_step_min = config.time_step_min
        self.time_step_max = config.time_step_max
        
        self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
        self.conv1d = nn.Conv1d(
            in_channels=self.conv_dim, # intermediate_size in Mamba1
            out_channels=self.conv_dim, # intermediate_size in Mamba1
            bias=config.use_conv_bias,
            kernel_size=config.conv_kernel,
            groups=self.conv_dim, # intermediate_size in Mamba1
            padding=config.conv_kernel - 1,
        )
        
        # projection of the input hidden states
        projection_size = self.intermediate_size + self.conv_dim + self.num_heads
        self.in_proj = nn.Linear(self.hidden_size, projection_size, bias=config.use_bias) # selective projection
        
        # time step projection (discretization) - instantiate once and copy inv_dt in init_weights of PretrainedModel
        self.dt_bias = nn.Parameter(torch.ones(self.num_heads)) # no x_proj and dt_proj, compared to Mamba1
        
        # S4D real initialization. These are not discretized!
        # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
        A = torch.arange(1, self.num_heads + 1)
        self.A_log = nn.Parameter(torch.log(A))
        self.A_log._no_weight_decay = True
        self.norm = MambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon) # added, compared to Mamba1
        self.D = nn.Parameter(torch.ones(self.num_heads))
        self.D._no_weight_decay = True
        
        self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
        self.use_bias = config.use_bias
        
    def cuda_kernels_forward(self, hidden_states: torch.Tensor):
        batch_size, seq_len, _ = hidden_states.shape
        groups_time_state_size = self.n_groups * self.ssm_state_size
        d_to_remove = 2 * self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.num_heads
        
        projected_states = self.in_proj(hidden_states)
        A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size)
        dt_limit_kwargs = {} if self.time_step_limit == (0., float("inf")) else {"dt_limit": self.time_step_limit}
        out, ssm_state = mamba_split_conv1d_scan_combined(
            projected_states,
            self.conv1d.weight.squeeze(1),
            self.conv1d.bias,
            self.dt_bias,
            A,
            D=self.D,
            chunk_size=self.chunk_size,
            seq_idx=None, # was seq_idx
            activation=self.activation,
            rmsnorm_weight=self.norm.weight,
            rmsnorm_eps=self.norm.variance_epsilon,
            outproj_weight=self.out_proj.weight,
            outproj_bias=self.out_proj.bias,
            headdim=self.head_dim,
            ngroups=self.n_groups,
            norm_before_gate=False,
            return_final_states=True,
            **dt_limit_kwargs,
        )
        return out

# Mamba vs Transformers

In [5]:
# for Mamba1
from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
from mamba_ssm.ops.triton.selective_state_update import selective_state_update

# for Mamba2
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined

In [26]:
!ls /root/dataset_sj/hf_cache

accelerate   models--stabilityai--stable-diffusion-2-base  token
hub	     models--state-spaces--mamba-130m-hf
huggingface  stable-diffusion-2-base


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [11]:
cache_dir = "/root/dataset_sj/hf_cache"

In [13]:
from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf", cache_dir=cache_dir)
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf", cache_dir=cache_dir)
input_ids = tokenizer("Hey, how are you doing?", return_tensors="pt")["input_ids"]

out = model.generate(input_ids, max_new_tokens=10)
print(tokenizer.batch_decode(out))

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


["Hey, how are you doing?\n\nI'm so glad you're here."]


In [17]:
from transformers.models.mamba import MambaConfig, MambaModel
from transformers.models.mamba2 import Mamba2Config, Mamba2Model

Configuration of Mamba1
- vocab_size (int, optional, defaults to 50280) — Vocabulary size of the MAMBA model. Defines the number of different tokens that can be represented by the inputs_ids passed when calling MambaModel.
- hidden_size (int, optional, defaults to 768) — Dimensionality of the embeddings and hidden states.
- state_size (int, optional, defaults to 16) — shape of the state space latents.
- num_hidden_layers (int, optional, defaults to 32) — Number of hidden layers in the model.
- layer_norm_epsilon (float, optional, defaults to 1e-05) — The epsilon to use in the layer normalization layers.
- expand (int, optional, defaults to 2) — Expanding factor used to determine the intermediate size.
- conv_kernel (int, optional, defaults to 4) — Size of the convolution kernel.
- use_bias (bool, optional, defaults to False) — Whether or not to use bias in [“in_proj”, “out_proj”] of the mixer block
- use_conv_bias (bool, optional, defaults to True) — Whether or not to use bias in the convolution layer of the mixer block.
- hidden_act (str, optional, defaults to "silu") — The non-linear activation function (function or string) in the decoder.
- initializer_range (float, optional, defaults to 0.1) — The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
- residual_in_fp32 (bool, optional, defaults to True) — Whether or not residuals should be in float32. If set to False residuals will keep the same dtype as the rest of the model
- time_step_rank (Union[int,str], optional, defaults to "auto") — Rank of the discretization projection matrix. "auto" means that it will default to math.ceil(self.hidden_size / 16)
- time_step_scale (float, optional, defaults to 1.0) — Scale used used to scale dt_proj.bias.
- time_step_min (float, optional, defaults to 0.001) — Minimum time_step used to bound dt_proj.bias.
- time_step_max (float, optional, defaults to 0.1) — Maximum time_step used to bound dt_proj.bias.
- time_step_init_scheme (float, optional, defaults to "random") — Init scheme used for dt_proj.weight. Should be one of ["random","uniform"]
- time_step_floor (float, optional, defaults to 0.0001) — Minimum clamping value of the dt_proj.bias layer initialization.
- rescale_prenorm_residual (bool, optional, defaults to False) — Whether or not to rescale out_proj weights when initializing.
- use_mambapy (bool, optional, defaults to False) — Determines the fallback strategy during training if the CUDA-based official implementation of Mamba is not avaiable. If True, the mamba.py implementation is used. If False, the naive and slower implementation is used. Consider switching to the naive version if memory is limited.

In [32]:
mamba1 = MambaModel(MambaConfig())

In [40]:
model_size = 0
for param in mamba1.parameters():
    model_size += param.data.nelement()
print(model_size)

159308544


Configuration of Mamba2
- num_heads (int, optional, defaults to 128) — Number of heads for the evolution matrices of mamba 2.
- head_dim (int, optional, defaults to 64) — Dimension of each head.
- vocab_size (int, optional, defaults to 32768) — Vocabulary size of the MAMBA2 model. Defines the number of different tokens that can be represented by the inputs_ids passed when calling Mamba2Model.
- hidden_size (int, optional, defaults to 4096) — Dimensionality of the embeddings and hidden states.
- state_size (int, optional, defaults to 128) — shape of the state space latents.
- num_hidden_layers (int, optional, defaults to 64) — Number of hidden layers in the model.
- layer_norm_epsilon (float, optional, defaults to 1e-05) — The epsilon to use in the layer normalization layers.
- pad_token_id (int, optional, defaults to 1) — Padding token id.
- bos_token_id (int, optional, defaults to 0) — The id of the beginning of sentence token in the vocabulary.
- eos_token_id (int, optional, defaults to 2) — The id of the end of sentence token in the vocabulary.
- expand (int, optional, defaults to 2) — Expanding factor used to determine the intermediate size.
- conv_kernel (int, optional, defaults to 4) — Size of the convolution kernel.
- n_groups (int, optional, defaults to 8) — Number of groups for the evolution matrices of mamba 2.
- use_bias (bool, optional, defaults to False) — Whether or not to use bias in [“in_proj”, “out_proj”] of the mixer block
- use_conv_bias (bool, optional, defaults to True) — Whether or not to use bias in the convolution layer of the mixer block.
- hidden_act (str, optional, defaults to "silu") — The non-linear activation function (function or string) in the decoder.
- initializer_range (float, optional, defaults to 0.1) — The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
- residual_in_fp32 (bool, optional, defaults to True) — Whether or not residuals should be in float32. If set to False residuals will keep the same dtype as the rest of the model
- time_step_rank (Union[int,str], optional, defaults to "auto") — Rank of the discretization projection matrix. "auto" means that it will default to math.ceil(self.hidden_size / 16)
- time_step_min (float, optional, defaults to 0.001) — Minimum time_step used to bound dt_proj.bias.
- time_step_max (float, optional, defaults to 0.1) — Maximum time_step used to bound dt_proj.bias.
- time_step_floor (float, optional, defaults to 0.0001) — Minimum clamping value of the dt_proj.bias layer initialization.
- time_step_limit (tuple, optional, defaults to (0.0, inf)) — Accepted range of time step values.
- rescale_prenorm_residual (bool, optional, defaults to False) — Whether or not to rescale out_proj weights when initializing.
- use_cache (bool, optional, defaults to True) — Whether or not the cache should be used.
- rms_norm (bool, optional, defaults to True) — Whether to use RMS norm or not.
- chunk_size (int, optional, defaults to 256) — Size of the chunks that will comprise the sequence.
- tie_word_embeddings (bool, optional, defaults to False) — Whether to tie word embeddings or not.

In [39]:
mamba2 = Mamba2Model(Mamba2Config())

In [41]:
model_size = 0
for param in mamba2.parameters():
    model_size += param.data.nelement()
print(model_size)

7151185920
