In [None]:
!pip install einops
!pip install tensorflow
!pip install torch

# Mamba 1 Selective Scan Algorithm

In [None]:
import tensorflow as tf
def selective_scan(u, delta, A, B, C, D):
    """Mamba-1 selective scan evaluated with cumulative sums instead of a loop.

    The hidden state x_t = sum_{s<=t} exp(sum_{k=s+1..t} ΔA_k) * Δ_s * B_s * u_s
    is obtained by multiplying every input term with a suffix-decay factor,
    prefix-summing over the token axis and dividing the factor out again.

    Shapes implied by the einsum signatures:
        u, delta: (b, l, d)
        A:        (d, n)
        B, C:     (b, l, n)
        D:        anything that broadcasts against u
    Returns:
        Tensor of shape (b, l, d): the state read out through C, plus u * D.

    Note: the decay factors are exponentials of cumulative sums, so they may
    overflow or underflow for long sequences or large ΔA.
    """
    # ΔA for every token / channel / state entry
    delta_A = tf.einsum('bld,dn->bldn', delta, A)
    # Δ * u * B: the input contribution of every token
    delta_B_u = tf.einsum('bld,bld,bln->bldn', delta, u, B)

    # ΔA shifted left by one token along axis 1, with a zero in the last slot
    shifted = tf.pad(
        delta_A[:, 1:], [[0, 0], [1, 1], [0, 0], [0, 0]])[:, 1:, :, :]

    # flip -> cumulative sum -> exp -> flip back: exp of the suffix sums of ΔA
    suffix_decay = tf.reverse(
        tf.exp(tf.math.cumsum(tf.reverse(shifted, axis=[1]), axis=1)),
        axis=[1])

    weighted_inputs = delta_B_u * suffix_decay
    # 1e-12 guards against division by 0
    hidden = tf.math.cumsum(weighted_inputs, axis=1) / (suffix_decay + 1e-12)

    readout = tf.einsum('bldn,bln->bld', hidden, C)
    return readout + u * D

In [None]:
# random tensors for testing purposes
# (the letters b, l, d, n match the einsum subscripts used by selective_scan)
import tensorflow as tf
b = 5
l = 3
d = 1
n = 2
delta = tf.random.uniform((b,l,d), minval=0, maxval=1, dtype=tf.float32)  # (b,l,d)
A = tf.random.uniform((d,n), minval=0, maxval=1, dtype=tf.float32)  # (d,n)
u = tf.random.uniform((b,l,d), minval=0, maxval=1, dtype=tf.float32)  # (b,l,d)
x_new = tf.zeros((b, l, n))
B = tf.random.uniform((b,l,n), minval=0, maxval=1, dtype=tf.float32)  # (b,l,n)
# C is contracted with 'bldn,bln->bld' in selective_scan, so it needs the same
# per-token shape as B
C = tf.random.uniform((b,l,n), minval=0, maxval=1, dtype=tf.float32)  # (b,l,n)

In [None]:
# first step of A_bar = exp(ΔA), i.e., ΔA
dA = tf.einsum('bld,dn->bldn', delta, A)
# how B changes depending on the input u: Δ * u * B for every token
dB_u = tf.einsum('bld,bld,bln->bldn', delta, u, B)

# Drop the first token's ΔA, zero-pad one step on both sides of axis 1 and drop
# the leading pad again: ΔA shifted left by one token with a zero in the last slot
dA_cumsum = tf.pad(
    dA[:, 1:], [[0, 0], [1, 1], [0, 0], [0, 0]])[:, 1:, :, :]

dA_cumsum = tf.reverse(dA_cumsum, axis=[1])  # Flip along axis 1

$\ dA_{flipped} ~=~ [0 ~~~ A_l\Delta_l ~~~ A_{l-1}\Delta_{l-1} ~~~ A_{l-2}\Delta_{l-2} ~~~ ... ~~~ A_1\Delta_1]$

In [None]:
# Cumulative sum along the token axis (a prefix sum over the flipped sequence),
# which yields the accumulated ΔA for all input tokens in one call
dA_cumsum = tf.math.cumsum(dA_cumsum, axis=1)

$\ dA_{cumsum} ~=~ [0 ~~~~ 0+A_l\Delta_l ~~~~ 0+A_l\Delta_l + A_{l-1}\Delta_{l-1} ~~~~  ... ~~~~ 0+A_l\Delta_l+...+ A_1\Delta_1] $

In [None]:
# second step of A_bar = exp(ΔA), i.e., exp(ΔA), then flip back along axis 1.
# The result gets its own name: the following cells read `dA_cumsum_exp`, and
# not overwriting `dA_cumsum` keeps this cell safe to re-run.
dA_cumsum_exp = tf.reverse(tf.exp(dA_cumsum), axis=[1])

$\ dA_{cumsum} ~=~ [e^{0+A_l\Delta_l+...+ A_1\Delta_1} ~~~~ e^{0+A_l\Delta_l+...+ A_2\Delta_2} ~~~~ ... ~~~~ e^{A_l\Delta_l} ~~~~ e^0] $

Simplification in Paper:
$
\overline{B} = B * \Delta
$

In [None]:
# weight every token's input term Δ*u*B with the decay factor in dA_cumsum_exp
# (expected to hold the exponentiated cumulative ΔA in original token order)
x = dB_u * dA_cumsum_exp
# prefix-sum over the tokens and divide the factor out again;
# 1e-12 to avoid division by 0
x_ref = tf.math.cumsum(x, axis=1)/(dA_cumsum_exp + 1e-12)

without simplification:
$
\overline{B} = \frac{(e^{A * \Delta} - I) * B * \Delta}{A * \Delta}
$

In [None]:
# --without simplification: zero-order-hold discretization of B,
#   B_bar = (exp(ΔA) - I) * ΔB / (ΔA)
# dA holds ΔA elementwise, so the identity I becomes a plain 1 here.

dB = tf.einsum('bld,bln->bldn', delta, B)

# expm1(ΔA) = exp(ΔA) - 1, accurate also for small ΔA; 1e-12 avoids division by 0
B_head = (tf.math.expm1(dA) * dB) / (dA + 1e-12)
B_head_u = tf.einsum('bldn,bld->bldn', B_head, u)
x_pre = B_head_u * dA_cumsum_exp
x_own = tf.math.cumsum(x_pre, axis=1)/(dA_cumsum_exp + 1e-12)

$\ h_0 = \frac{\Delta_0B_0x_0*e^{0+A_l\Delta_l+...+ A_1\Delta_1}}{e^{0+A_l\Delta_l+...+ A_1\Delta_1}}=\Delta_0B_0x_0$

$\ h_1 = \frac{\Delta_0B_0x_0*e^{0+A_l\Delta_l+...+ A_1\Delta_1} + \Delta_1B_1x_1*e^{0+A_l\Delta_l+...+ A_2\Delta_2}}{e^{0+A_l\Delta_l+...+ A_2\Delta_2}} \\ h_1 =\Delta_0B_0x_0*e^{A_1\Delta_1}+ \Delta_1B_1x_1$

$\ h_2 = \frac{\Delta_0B_0x_0*e^{0+A_l\Delta_l+...+ A_1\Delta_1} + \Delta_1B_1x_1*e^{0+A_l\Delta_l+...+ A_2\Delta_2} + \Delta_2B_2x_2*e^{0+A_l\Delta_l+...+ A_3\Delta_3}}{e^{0+A_l\Delta_l+...+ A_3\Delta_3}}\\h_2=\Delta_0B_0x_0*e^{A_2\Delta_2} * e^{A_1\Delta_1} + \Delta_1B_1x_1 * e^{A_2\Delta_2} + \Delta_2 B_2 x_2$

# Mamba 2 SSD Algorithm

In [None]:
!pip install einops
from einops import rearrange
import torch
import torch.nn.functional as F

def segsum(x):
  """Naive segment sum over the last axis.

  out[..., i, j] = x[..., j+1] + ... + x[..., i] for i >= j (0 on the diagonal).
  Entries above the diagonal (i < j) are set to -1e9, so exp(segsum(x)) is a
  lower-triangular 1-SS matrix, which is equivalent to a scalar SSM.
  """
  # length of the axis the segment sums are taken over
  seq_len = x.size(-1)
  prefix = torch.cumsum(x, dim=-1)
  # pairwise differences of prefix sums: prefix[i] - prefix[j]
  pairwise = prefix.unsqueeze(-1) - prefix.unsqueeze(-2)
  lower = torch.ones(seq_len, seq_len, device=x.device, dtype=bool).tril(diagonal=0)
  return pairwise.masked_fill(lower.logical_not(), -1e9)


def ssd(X, A, B, C, block_len=64, initial_states=None):
  """Chunked (block-decomposed) SSD computation of Mamba 2.

  Arguments:
  X: (batch, length, n_heads, d_head)
  A: (batch, length, n_heads)
  B: (batch, length, n_heads, d_state)
  C: (batch, length, n_heads, d_state)
  block_len: chunk length; `length` must be divisible by it
  initial_states: optional state entering the first chunk, concatenated in
    front of the per-chunk states along dim 1, i.e.
    (batch, 1, n_heads, d_head, d_state); zeros when None
  Return:
  Y: (batch, length, n_heads, d_head)
  final_state: (batch, n_heads, d_head, d_state), the state after the last chunk
  """
  assert X.dtype == A.dtype == B.dtype == C.dtype
  assert X.shape[1] % block_len == 0

  def to_chunks(t):
    # split the time axis into c chunks of length l:
    # (b, c*l, ...) -> (b, c, l, ...)
    return rearrange(t, "b (c l) ... -> b c l ...", l=block_len)

  X, B, C = to_chunks(X), to_chunks(B), to_chunks(C)
  # A is one scalar per head and time step; put the head axis in front
  A = rearrange(to_chunks(A), "b c l h -> b h c l")
  # running sum of A inside every chunk
  A_cumsum = torch.cumsum(A, dim=-1)

  # 1. Intra-chunk output (diagonal blocks)
  intra_decay = torch.exp(segsum(A))
  Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, intra_decay, X)

  # 2. State produced by each chunk on its own
  # (right term of the low-rank factorization of off-diagonal blocks; B terms)
  decay_to_chunk_end = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
  chunk_states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_to_chunk_end, X)

  # 3. Inter-chunk SSM recurrence; gives the correct states at chunk boundaries
  # (middle term of the factorization of off-diagonal blocks; A terms)
  if initial_states is None:
    initial_states = torch.zeros_like(chunk_states[:, :1])
  stacked_states = torch.cat([initial_states, chunk_states], dim=1)
  decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
  carried = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, stacked_states)
  entering_states = carried[:, :-1]
  final_state = carried[:, -1]

  # 4. State -> output conversion per chunk
  # (left term of the low-rank factorization of off-diagonal blocks; C terms)
  state_decay_out = torch.exp(A_cumsum)
  Y_off = torch.einsum('bclhn,bchpn,bhcl->bclhp', C, entering_states, state_decay_out)

  # Sum the intra-chunk and inter-chunk terms and merge the chunks again
  Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")
  return Y, final_state


Collecting einops
  Downloading einops-0.8.0-py3-none-any.whl.metadata (12 kB)
Downloading einops-0.8.0-py3-none-any.whl (43 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/43.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.2/43.2 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.8.0


## Computation of L (lookahead-mask) and Y_diag (diagonal blocks output)

In [None]:
# create random tensor for debugging
n_heads = 1
batch = 1
seq_len = 6
d_head = 2
d_state = 3
block_len = 3  # the sequence axis is split into chunks of this length below
# A: (batch, seq_len, n_heads); torch.randint yields integer values in [0, 10)
A = torch.randint(0, 10, (batch, seq_len, n_heads))

$ \begin{bmatrix}
A_1 & A_2 & A_3 & A_4 & A_5 & A_6
\end{bmatrix}  $

In [None]:
# split the sequence axis into c chunks of length block_len: (b, c*l, n) -> (b, c, l, n)
A_rearranged = rearrange(A, "b (c l) n-> b c l n", l=block_len)
# move the head axis in front of the chunk axes: (b, c, l, h) -> (b, h, c, l)
A_rearranged = rearrange(A_rearranged, "b c l h -> b h c l")

(batch, n_heads,sequence_length/block_len, block_len)

$ \begin{bmatrix}
[A_1 & A_2 & A_3 ][A_4 & A_5 & A_6]
\end{bmatrix}  $

In [None]:
# running sum of A inside every chunk
A_cumsum = torch.cumsum(A_rearranged, dim=-1)
# pairwise differences: x_segsum[..., i, j] = A_cumsum[..., i] - A_cumsum[..., j]
x_segsum = A_cumsum[..., :, None] - A_cumsum[..., None, :]
# size of the last axis of A_cumsum, i.e. block_len
T = A_cumsum.size(-1)

cumsum and broadcast to shape: (batch, n_heads, num_blocks, block_len, block_len)

$
\begin{bmatrix}
  \begin{bmatrix}
  A_1-A_1 & A_1-(A_1+A_2) & A_1-(A_1+A_2+A_3) \\ A_1+A_2-A_1 & A_1+A_2-(A_1+A_2) & A_1+A_2-(A_1+A_2+A_3) \\ A_1+A_2+A_3-A_1 & A_1+A_2+A_3-(A_1+A_2) & A_1+A_2+A_3-(A_1+A_2+A_3)
  \end{bmatrix}  
  \begin{bmatrix}
  A_4-A_4 & A_4-(A_4+A_5) & A_4-(A_4+A_5+A_6) \\ A_4+A_5-A_4 & A_4+A_5-(A_4+A_5) & A_4+A_5-(A_4+A_5+A_6) \\ A_4+A_5+A_6-A_4 & A_4+A_5+A_6-(A_4+A_5) & A_4+A_5+A_6-(A_4+A_5+A_6)
  \end{bmatrix}
\end{bmatrix}  
$


=

$ \begin{bmatrix}
    \begin{bmatrix}
    0 & -A_2 & -A_2-A_3 \\
    A_2 & 0 & -A_3 \\ A_2+A_3 & A_3 & 0
    \end{bmatrix}
    \begin{bmatrix}
    0 & -A_5 & -A_5-A_6 \\
    A_5 & 0 & -A_6 \\ A_5+A_6 & A_6 & 0
    \end{bmatrix}
\end{bmatrix}  $

In [None]:
# keep the lower triangle (i >= j); the upper triangle is pushed to -1e9 so
# that exp() maps it to 0. The mask is created on the device of the tensor it
# is applied to.
mask = torch.tril(torch.ones(T, T, device=x_segsum.device, dtype=bool), diagonal=0)
x_segsum = x_segsum.masked_fill(~mask, -1e9)
# x_segsum already is the (masked) segment sum, so only exp() is left to apply
L = torch.exp(x_segsum)

$ \begin{bmatrix}
    \begin{bmatrix}
    1 & 0 & 0 \\
    e^{A_2} & 1 & 0 \\
    e^{A_2+A_3} & e^{A_3} & 1
    \end{bmatrix}
    \begin{bmatrix}
    1 & 0 & 0 \\
    e^{A_5} & 1 & 0 \\
    e^{A_5+A_6} & e^{A_6} & 1
    \end{bmatrix}
\end{bmatrix}  $

In [None]:
n_heads = 1
batch = 2
seq_len = 6
d_head = 2
d_state = 3
block_len = 3

A = torch.randint(0, 10, (batch, seq_len, n_heads))
X = torch.randint(0, 10, (batch, seq_len, n_heads, d_head))
B = torch.randint(0, 10, (batch, seq_len, n_heads, d_state))
C = torch.randint(0, 10, (batch, seq_len, n_heads, d_state))

# split the sequence axis into chunks: (b, c*l, ...) -> (b, c, l, ...)
X, A, B, C = [rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C)]
# A is a scalar for every head
A = rearrange(A, "b c l h -> b h c l")
# cumsum along the time dimension
A_cumsum = torch.cumsum(A, dim=-1)

# 1. Compute the output for each intra-chunk (diagonal blocks)
# L: lower-triangular decay matrix per chunk (zero above the diagonal)
L = torch.exp(segsum(A))

# reference: everything in a single einsum
Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C.float(), B.float(), L.float(), X.float())

# same computation step by step:
# products C[l] * B[s] for every position pair (l, s) of a chunk; the state
# axis n is kept here and summed in the last step (plays the role of Q*K_T)
CB_t = torch.einsum('bclhn,bcshn->bhclsn', C.float(), B.float())
# scale every (l, s) pair with L, which zeroes all pairs with s > l (causal mask)
LCB_t = torch.einsum("bhclsn,bhcls->bhclsn", CB_t, L)
# multiply with X (equivalent to Values V) and sum over s and n
# (L*Q*K_T)*V  <--> (L*C*B_T)*X
Y_diag_test = torch.einsum('bhclsn,bcshp->bclhp', LCB_t, X.float())

torch.allclose(Y_diag, Y_diag_test)

## Computation of states

In [None]:
n_heads = 1
batch = 1
seq_len = 9
d_head = 1
d_state = 1
block_len = 3

A = torch.randint(0, 10, (batch, seq_len, n_heads))
X = torch.randint(0, 10, (batch, seq_len, n_heads, d_head))
B = torch.randint(0, 10, (batch, seq_len, n_heads, d_state))
C = torch.randint(0, 10, (batch, seq_len, n_heads, d_state))

# split the sequence axis into chunks: (b, c*l, ...) -> (b, c, l, ...)
X, A, B, C = [rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C)]
# bring head dimension to the front
A_rearranged = rearrange(A, "b c l h -> b h c l")

# running sum of A inside every chunk
A_cumsum = torch.cumsum(A_rearranged, dim=-1)

# decay from every position to the end of its chunk: exp(chunk total - running sum),
# exactly as in ssd()
decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)

$ \begin{bmatrix}
[e^{A_1+A_2+A_3-A_1} & e^{A_1+A_2+A_3-A_1-A_2} & e^{A_1+A_2+A_3-A_1-A_2-A_3} ][e^{A_4+A_5+A_6-A_4} & e^{A_4+A_5+A_6-A_4-A_5} & e^{A_4+A_5+A_6-A_4-A_5-A_6}]
[e^{A_7+A_8+A_9-A_7} & e^{A_7+A_8+A_9-A_7-A_8} & e^{A_7+A_8+A_9-A_7-A_8-A_9}]
\end{bmatrix}  $

=

$ \begin{bmatrix}
[e^{A_2+A_3} & e^{A_3} & 1 ][e^{A_5+A_6} & e^{A_6} & 1][e^{A_8+A_9} & e^{A_9} & 1]
\end{bmatrix}  $

In [None]:
# reference: per-chunk state in one einsum, summing B * decay * X over the positions l
states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B.float(), decay_states.float(), X.float())
# intermediate for the step-by-step check: B scaled by the decay only (X not applied yet)
BA = torch.einsum("bclhn,bhcl->bhcln", B.float(), decay_states.float())

(batch=1, n_heads=1, n_blocks=3, block_len=3, d_state=2)

$ \begin{bmatrix}
\begin{bmatrix}
e^{A_2+A_3}*B_{11} & e^{A_2+A_3}*B_{12} \\
e^{A_3}*B_{21} & e^{A_3}*B_{22} \\
B_{31} & B_{32}
\end{bmatrix}
\begin{bmatrix}
e^{A_5+A_6}*B_{41} & e^{A_5+A_6}*B_{42} \\
e^{A_6}*B_{51} & e^{A_6}*B_{52} \\
B_{61} & B_{62}
\end{bmatrix}
\begin{bmatrix}
e^{A_8+A_9}*B_{71} & e^{A_8+A_9}*B_{72} \\
e^{A_9}*B_{81} & e^{A_9}*B_{82} \\
B_{91} & B_{92}
\end{bmatrix}
\end{bmatrix}  $

In [None]:
# target layout of the states: (batch, n_blocks, n_heads, d_head, d_state)
# bring X into the same axis order as BA: (b, c, l, h, p) -> (b, h, c, l, p)
X_rearranged = rearrange(X, "b c l h p -> b h c l p")
# contract the position axis l between the decayed B and X
states_test = torch.einsum("bhcln,bhclp->bchnp", BA, X_rearranged.float())
# swap the last two axes to match the layout of `states`
states_test = rearrange(states_test, "b c h n p -> b c h p n")
torch.allclose(states, states_test)

True

(batch=1, n_heads=1, n_blocks=3, d_state=2, block_len=3) * (batch=1, n_head=1, n_blocks=3, block_len=3, d_head=4) = (batch=1, n_blocks=3, n_heads=1, d_head=4, d_state=2)

$ \begin{bmatrix}
\begin{bmatrix}
e^{A_2+A_3}*B_{11} & e^{A_2+A_3}*B_{12} \\
e^{A_3}*B_{21} & e^{A_3}*B_{22} \\
B_{31} & B_{32}
\end{bmatrix}^T
\begin{bmatrix}
e^{A_5+A_6}*B_{41} & e^{A_5+A_6}*B_{42} \\
e^{A_6}*B_{51} & e^{A_6}*B_{52} \\
B_{61} & B_{62}
\end{bmatrix}^T
\begin{bmatrix}
e^{A_8+A_9}*B_{71} & e^{A_8+A_9}*B_{72} \\
e^{A_9}*B_{81} & e^{A_9}*B_{82} \\
B_{91} & B_{92}
\end{bmatrix}^T
\end{bmatrix}  $

*

$
\begin{bmatrix}
  \begin{bmatrix}
  X_{11} & X_{12} & X_{13} & X_{14} \\ X_{21} & X_{22} & X_{23} & X_{24}\\ X_{31} & X_{32} & X_{33} & X_{34}
  \end{bmatrix}
  \begin{bmatrix}
  X_{41} & X_{42} & X_{43} & X_{44} \\ X_{51} & X_{52} & X_{53} & X_{54} \\ X_{61} & X_{62} & X_{63} & X_{64}
  \end{bmatrix}
  \begin{bmatrix}
  X_{71} & X_{72} & X_{73} & X_{74} \\
  X_{81} & X_{82} & X_{83} & X_{84} \\
  X_{91} & X_{92} & X_{93} & X_{94}
  \end{bmatrix}
\end{bmatrix}
$

=



$ \begin{bmatrix}
\begin{bmatrix}
e^{A_2+A_3}*B_{11} & e^{A_3}*B_{21} & B_{31} \\
e^{A_2+A_3}*B_{12} & e^{A_3}*B_{22} & B_{32}
\end{bmatrix}
\begin{bmatrix}
e^{A_5+A_6}*B_{41} & e^{A_6}*B_{51} & B_{61} \\
 e^{A_5+A_6}*B_{42} & e^{A_6}*B_{52} & B_{62}
 \end{bmatrix}
\begin{bmatrix}
e^{A_8+A_9}*B_{71} & e^{A_9}*B_{81} & B_{91} \\
e^{A_8+A_9}*B_{72} & e^{A_9}*B_{82} & B_{92}
\end{bmatrix}
\end{bmatrix}  $

*

$\begin{bmatrix}
  \begin{bmatrix}
  X_{11} & X_{12} & X_{13} & X_{14} \\ X_{21} & X_{22} & X_{23} & X_{24}\\ X_{31} & X_{32} & X_{33} & X_{34}
  \end{bmatrix}
  \begin{bmatrix}
  X_{41} & X_{42} & X_{43} & X_{44} \\ X_{51} & X_{52} & X_{53} & X_{54} \\ X_{61} & X_{62} & X_{63} & X_{64}
  \end{bmatrix}
  \begin{bmatrix}
  X_{71} & X_{72} & X_{73} & X_{74} \\
  X_{81} & X_{82} & X_{83} & X_{84} \\
  X_{91} & X_{92} & X_{93} & X_{94}
  \end{bmatrix}
\end{bmatrix} $

=



$
\begin{bmatrix}
\begin{bmatrix}
B_{11}X_{11}e^{A_2+A_3} + B_{21}X_{21}e^{A_3} + B_{31}X_{31} & B_{11}X_{12}e^{A_2+A_3} + B_{21}X_{22}e^{A_3} + B_{31}X_{32} & B_{11}X_{13}e^{A_2+A_3} + B_{21}X_{23}e^{A_3} + B_{31}X_{33} & B_{11}X_{14}e^{A_2+A_3} + B_{21}X_{24}e^{A_3} + B_{31}X_{34} \\
B_{12}X_{11}e^{A_2+A_3} + B_{22}X_{21}e^{A_3} + B_{32}X_{31} & B_{12}X_{12}e^{A_2+A_3} + B_{22}X_{22}e^{A_3} + B_{32}X_{32} & B_{12}X_{13}e^{A_2+A_3} + B_{22}X_{23}e^{A_3} + B_{32}X_{33} & B_{12}X_{14}e^{A_2+A_3} + B_{22}X_{24}e^{A_3} + B_{32}X_{34}
\end{bmatrix}^T
\begin{bmatrix}
B_{41}X_{41}e^{A_5+A_6} + B_{51}X_{51}e^{A_6} + B_{61}X_{61} & B_{41}X_{42}e^{A_5+A_6} + B_{51}X_{52}e^{A_6} + B_{61}X_{62} & B_{41}X_{43}e^{A_5+A_6} + B_{51}X_{53}e^{A_6} + B_{61}X_{63} & B_{41}X_{44}e^{A_5+A_6} + B_{51}X_{54}e^{A_6} + B_{61}X_{64} \\
B_{42}X_{41}e^{A_5+A_6} + B_{52}X_{51}e^{A_6} + B_{62}X_{61} & B_{42}X_{42}e^{A_5+A_6} + B_{52}X_{52}e^{A_6} + B_{62}X_{62} & B_{42}X_{43}e^{A_5+A_6} + B_{52}X_{53}e^{A_6} + B_{62}X_{63} & B_{42}X_{44}e^{A_5+A_6} + B_{52}X_{54}e^{A_6} + B_{62}X_{64}
\end{bmatrix}^T
\begin{bmatrix}
B_{71}X_{71}e^{A_8+A_9} + B_{81}X_{81}e^{A_9} + B_{91}X_{91} & B_{71}X_{72}e^{A_8+A_9} + B_{81}X_{82}e^{A_9} + B_{91}X_{92} & B_{71}X_{73}e^{A_8+A_9} + B_{81}X_{83}e^{A_9} + B_{91}X_{93} & B_{71}X_{74}e^{A_8+A_9} + B_{81}X_{84}e^{A_9} + B_{91}X_{94} \\
B_{72}X_{71}e^{A_8+A_9} + B_{82}X_{81}e^{A_9} + B_{92}X_{91} & B_{72}X_{72}e^{A_8+A_9} + B_{82}X_{82}e^{A_9} + B_{92}X_{92} & B_{72}X_{73}e^{A_8+A_9} + B_{82}X_{83}e^{A_9} + B_{92}X_{93} & B_{72}X_{74}e^{A_8+A_9} + B_{82}X_{84}e^{A_9} + B_{92}X_{94}
\end{bmatrix}^T
\end{bmatrix}
$

### d_state and d_head = 1

$ \begin{bmatrix}
e^{A_2+A_3} & e^{A_3} & 1 \\
e^{A_5+A_6} & e^{A_6} & 1 \\
e^{A_8+A_9} & e^{A_9} & 1
\end{bmatrix}  $
*
$ \begin{bmatrix}
B_1 & B_2 & B_3 \\
B_4 & B_5 & B_6 \\
B_7 & B_8 & B_9
\end{bmatrix}  $

=

$ \begin{bmatrix}
e^{A_2+A_3}*B_1 & e^{A_3}*B_2 & B_3 \\
e^{A_5+A_6}*B_4 & e^{A_6}*B_5 & B_6 \\
e^{A_8+A_9}*B_7 & e^{A_9}*B_8 & B_9
\end{bmatrix}  $
(batch, n_heads, n_blocks, block_len,1)

(batch, n_heads, n_blocks, 1, block_len) * (batch, n_heads, n_blocks, block_len,1) = (batch, n_heads, n_blocks, 1, 1)

$ \begin{bmatrix}
e^{A_2+A_3}*B_1 & e^{A_3}*B_2 & B_3 \\
e^{A_5+A_6}*B_4 & e^{A_6}*B_5 & B_6 \\
e^{A_8+A_9}*B_7 & e^{A_9}*B_8 & B_9
\end{bmatrix}  $
*
$
\begin{bmatrix}
X_1 & X_2 & X_3 \\
X_4 & X_5 & X_6 \\
X_7 & X_8 & X_9
\end{bmatrix}
$

=

$ \begin{bmatrix}
e^{A_2+A_3}*B_1*X_1 + e^{A_3}*B_2*X_2 + B_3*X_3 \\
e^{A_5+A_6}*B_4*X_4 + e^{A_6}*B_5*X_5 + B_6*X_6 \\
e^{A_8+A_9}*B_7*X_7 + e^{A_9}*B_8*X_8 + B_9*X_9
\end{bmatrix}  $

## Computation of new_states

In [None]:
n_heads = 1
batch = 1
seq_len = 9
d_head = 1
d_state = 1
block_len = 3

A = torch.randint(0, 10, (batch, seq_len, n_heads))
X = torch.randint(0, 10, (batch, seq_len, n_heads, d_head))
B = torch.randint(0, 10, (batch, seq_len, n_heads, d_state))
C = torch.randint(0, 10, (batch, seq_len, n_heads, d_state))

# split the sequence axis into chunks: (b, c*l, ...) -> (b, c, l, ...)
X, A, B, C = [rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C)]
# bring head dimension to the front
A_rearranged = rearrange(A, "b c l h -> b h c l")

# running sum of A inside every chunk
A_cumsum = torch.cumsum(A_rearranged, dim=-1)

# decay from every position to the end of its chunk
decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))

# per-chunk states, laid out as (batch, n_blocks, n_heads, d_head, d_state)
states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B.float(), decay_states.float(), X.float())

In [None]:
initial_states = None
if initial_states is None:
  # zero state entering the first block, same trailing shape as one block's state
  initial_states = torch.zeros_like(states[:, :1])
# (batch, n_blocks, n_heads, d_head, d_state) -> (batch, n_blocks+1, n_heads, d_head, d_state)
# NOTE: `states` is overwritten, so re-running this cell prepends one more zero state each time
states = torch.cat([initial_states, states], dim=1) # dim 1 is block dimension

# the last value of every block (the block total) gets selected and a zero is padded at the front
# (batch, n_heads,n_blocks, block_len) -> (batch, n_heads,n_blocks+1)
A_cumsum_padded = F.pad(A_cumsum[:, :, :, -1], (1, 0))

(batch=1, n_heads=1, n_blocks=3, block_len=3)

$ \begin{bmatrix}
[A_1 & A_1+A_2 & A_1+A_2+A_3 ][A_4 & A_4+A_5 & A_4+A_5+A_6][A_7 & A_7+A_8 & A_7+A_8+A_9]
\end{bmatrix}  $


(batch=1, n_heads=1, n_blocks+1=4)

$ \begin{bmatrix}
[0 & A_1+A_2+A_3 & A_4+A_5+A_6 & A_7+A_8+A_9]
\end{bmatrix}  $

In [None]:
# exp of the segment sums of the padded block totals: a lower-triangular matrix
# whose entry (z, c) is the decay accumulated over the blocks between c and z
decay_chunk = torch.exp(segsum(A_cumsum_padded))

$ \begin{bmatrix}
    1 & 0 & 0 & 0 \\
    e^{A_1+A_2+A_3} & 1 & 0 & 0\\
    e^{A_1+A_2+A_3+A_4+A_5+A_6}  & e^{A_4+A_5+A_6} & 1 & 0 \\
    e^{A_1+A_2+A_3+A_4+A_5+A_6+A_7+A_8+A_9}  & e^{A_4+A_5+A_6+A_7+A_8+A_9} & e^{A_7+A_8+A_9} & 1
\end{bmatrix}  $

In [None]:
# decay_chunk: (batch, n_heads, n_blocks+1, n_blocks+1)
# states:      (batch, n_blocks+1, n_heads, d_head, d_state)
# new_states:  (batch, n_blocks+1, n_heads, d_head, d_state)
new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)


# same product written as a matrix multiplication over the block axis
states_re = rearrange(states.float(), "b c h p n -> b h p c n")
#         decay_chunk                      x                    states_re                  =                 new_states
# (batch, n_heads, n_blocks+1, n_blocks+1) x (batch, n_heads, d_head, n_blocks+1, d_state) = (batch, n_heads, d_head, n_blocks+1, d_state)
#         (n_blocks+1, n_blocks+1)         x                 (n_blocks+1, d_state)          =             (n_blocks+1, d_state)
# NOTE(review): torch.matmul broadcasts the leading dims right-aligned, i.e.
# (batch, n_heads) against (batch, n_heads, d_head); this lines up here only
# because batch = n_heads = d_head = 1
new_states_test = torch.matmul(decay_chunk.float(), states_re.float())
new_states_test = rearrange(new_states_test, "b h p z n -> b z h p n")
#torch.allclose(new_states.float(), new_states_test)
# all but the last entry are the states entering each block; the last one is the final state
states, final_state = new_states[:, :-1], new_states[:, -1]

### d_head, d_state = 1

states:
$ \begin{bmatrix}
0 \\
e^{A_2+A_3}*B_1*X_1 + e^{A_3}*B_2*X_2 + B_3*X_3 \\
e^{A_5+A_6}*B_4*X_4 + e^{A_6}*B_5*X_5 + B_6*X_6 \\
e^{A_8+A_9}*B_7*X_7 + e^{A_9}*B_8*X_8 + B_9*X_9
\end{bmatrix}  $

decay_chunk =
$ \begin{bmatrix}
    1 & 0 & 0 & 0 \\
    e^{A_1+A_2+A_3} & 1 & 0 & 0\\
    e^{A_1+A_2+A_3+A_4+A_5+A_6}  & e^{A_4+A_5+A_6} & 1 & 0 \\
    e^{A_1+A_2+A_3+A_4+A_5+A_6+A_7+A_8+A_9}  & e^{A_4+A_5+A_6+A_7+A_8+A_9} & e^{A_7+A_8+A_9} & 1
\end{bmatrix}  $

new_states =
$ \begin{bmatrix}
0 \\
e^{A_2+A_3}B_1X_1 + e^{A_3}B_2X_2 + B_3X_3 \\
e^{A_2+A_3+A_4+A_5+A_6}B_1X_1 + e^{A_3+A_4+A_5+A_6}B_2X_2 + e^{A_4+A_5+A_6}B_3X_3 + e^{A_5+A_6}B_4X_4 + e^{A_6}B_5X_5 + B_6X_6 \\
e^{A_2+A_3+A_4+A_5+A_6+A_7+A_8+A_9}B_1X_1 + e^{A_3+A_4+A_5+A_6+A_7+A_8+A_9}B_2X_2 + e^{A_4+A_5+A_6+A_7+A_8+A_9}B_3X_3 + e^{A_5+A_6+A_7+A_8+A_9}B_4X_4 + e^{A_6+A_7+A_8+A_9}B_5X_5 + e^{A_7+A_8+A_9}B_6X_6 + e^{A_8+A_9}B_7X_7 + e^{A_9}B_8X_8 + B_9X_9
\end{bmatrix} $

**MamBa Notation**

$ \begin{bmatrix}
    1 & 0 & 0 & 0 \\
    A_{3:0} & 1 & 0 & 0\\
    A_{6:0}  & A_{6:3} & 1 & 0 \\
    A_{9:0}  & A_{9:3} & A_{9:6} & 1
\end{bmatrix}  $

*

$ \begin{bmatrix}
0 \\
A_{3:1}*B_1*X_1 + A_3*B_2*X_2 + B_3*X_3 \\
A_{6:4}*B_4*X_4 + A_6*B_5*X_5 + B_6*X_6 \\
A_{9:7}*B_7*X_7 + A_9*B_8*X_8 + B_9*X_9
\end{bmatrix}  $

=

$ \begin{bmatrix}
0 \\
A_{3:1}B_1X_1 + A_3B_2X_2 + B_3X_3 \\
A_{6:1}B_1X_1 + A_{6:2}B_2X_2 + A_{6:3}B_3X_3 + A_{6:4}B_4X_4 + A_6B_5X_5 + B_6X_6 \\
A_{9:1}B_1X_1 + A_{9:2}B_2X_2 + A_{9:3}B_3X_3 + A_{9:4}B_4X_4 + A_{9:5}B_5X_5 + A_{9:6}B_6X_6 + A_{9:7}B_7X_7 + A_9B_8X_8 + B_9X_9
\end{bmatrix} $

## Computation of Y_off (off diagonal blocks output)

In [None]:
n_heads = 1
batch = 1
seq_len = 9
d_head = 1
d_state = 1
block_len = 3


# All einsum operands must share one dtype and exp(segsum(A)) is floating
# point, so the integer test data is converted to float right away.
A = torch.randint(0, 10, (batch, seq_len, n_heads)).float()
X = torch.randint(0, 10, (batch, seq_len, n_heads, d_head)).float()
B = torch.randint(0, 10, (batch, seq_len, n_heads, d_state)).float()
C = torch.randint(0, 10, (batch, seq_len, n_heads, d_state)).float()

# split the sequence axis into chunks: (b, c*l, ...) -> (b, c, l, ...)
X, A, B, C = [rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C)]
# A is a scalar for every head
A = rearrange(A, "b c l h -> b h c l")
# cumsum along the time dimension
A_cumsum = torch.cumsum(A, dim=-1)

# 1. Compute the output for each intra-chunk (diagonal blocks)
L = torch.exp(segsum(A))
Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)

# 2. Compute the state for each intra-chunk
# (right term of low-rank factorization of off-diagonal blocks; B terms)
decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)

# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
# (middle term of factorization of off-diag blocks; A terms)
# Start from a zero state; set explicitly so that this cell does not depend on
# a value left over from an earlier cell.
initial_states = None
if initial_states is None:
  initial_states = torch.zeros_like(states[:, :1])
states = torch.cat([initial_states, states], dim=1)
decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
states, final_state = new_states[:, :-1], new_states[:, -1]

In [None]:
# decay from the start of each chunk up to every position inside it
state_decay_out = torch.exp(A_cumsum)
# off-diagonal contribution: the chunk index c is shared by all three operands,
# so every chunk reads its own incoming state through its own C and decay
Y_off = torch.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out)
# add intra- and inter-chunk terms and merge the chunks back into one sequence axis
Y = rearrange(Y_diag+Y_off, "b c l h p -> b (c l) h p")

C (batch=1, n_blocks=3, block_len=3, n_heads=1, d_state=1) =
$ \begin{bmatrix}
[C_1 & C_2 & C_3 ][C_4 & C_5 & C_6][C_7 & C_8 & C_9]
\end{bmatrix}  $

states (batch=1, n_blocks=3, n_heads=1, d_head=1, d_state=1) =
$ \begin{bmatrix}
0 &
e^{A_2+A_3}B_1X_1 + e^{A_3}B_2X_2 + B_3X_3 &
e^{A_2+A_3+A_4+A_5+A_6}B_1X_1 + e^{A_3+A_4+A_5+A_6}B_2X_2 + e^{A_4+A_5+A_6}B_3X_3 + e^{A_5+A_6}B_4X_4 + e^{A_6}B_5X_5 + B_6X_6
\end{bmatrix} $

state_decay_out (batch=1, n_heads=1, n_blocks=3, block_len=3) =


$ \begin{bmatrix}
[e^{A_1} & e^{A_1+A_2} & e^{A_1+A_2+A_3} ][e^{A_4} & e^{A_4+A_5} & e^{A_4+A_5+A_6}][e^{A_7} & e^{A_7+A_8} & e^{A_7+A_8+A_9}]
\end{bmatrix}  $

(batch, n_blocks, block_len, n_heads, d_state) * (batch, n_blocks, n_heads, d_state, d_head) = (batch, n_blocks, block_len, n_heads, d_head)

$ \begin{bmatrix}
[C_1 & C_2 & C_3 ][C_4 & C_5 & C_6][C_7 & C_8 & C_9]
\end{bmatrix}  $

*

$ \begin{bmatrix}
0 \\
e^{A_2+A_3}B_1X_1 + e^{A_3}B_2X_2 + B_3X_3 \\
e^{A_2+A_3+A_4+A_5+A_6}B_1X_1 + e^{A_3+A_4+A_5+A_6}B_2X_2 + e^{A_4+A_5+A_6}B_3X_3 + e^{A_5+A_6}B_4X_4 + e^{A_6}B_5X_5 + B_6X_6
\end{bmatrix} $

=
$\begin{bmatrix}
  \begin{bmatrix}
  0*C_1 &
  0*C_2 &
  0*C_3
  \end{bmatrix}
  \begin{bmatrix}
  (e^{A_2+A_3}B_1X_1 + e^{A_3}B_2X_2 + B_3X_3)*C_4 &
  (e^{A_2+A_3}B_1X_1 + e^{A_3}B_2X_2 + B_3X_3)*C_5 &
  (e^{A_2+A_3}B_1X_1 + e^{A_3}B_2X_2 + B_3X_3)*C_6
  \end{bmatrix}
  \begin{bmatrix}
  (e^{A_2+A_3+A_4+A_5+A_6}B_1X_1 + e^{A_3+A_4+A_5+A_6}B_2X_2 + e^{A_4+A_5+A_6}B_3X_3 + e^{A_5+A_6}B_4X_4 + e^{A_6}B_5X_5 + B_6X_6)*C_7 &
  (e^{A_2+A_3+A_4+A_5+A_6}B_1X_1 + e^{A_3+A_4+A_5+A_6}B_2X_2 + e^{A_4+A_5+A_6}B_3X_3 + e^{A_5+A_6}B_4X_4 + e^{A_6}B_5X_5 + B_6X_6)*C_8 &
  (e^{A_2+A_3+A_4+A_5+A_6}B_1X_1 + e^{A_3+A_4+A_5+A_6}B_2X_2 + e^{A_4+A_5+A_6}B_3X_3 + e^{A_5+A_6}B_4X_4 + e^{A_6}B_5X_5 + B_6X_6)*C_9
  \end{bmatrix}
\end{bmatrix} $

(batch, n_blocks, block_len, n_heads, d_head) * (batch, n_blocks, block_len) = (batch, n_blocks, block_len, n_heads, d_head)

$\begin{bmatrix}
  0 &
  0 &
  0 \\
  (e^{A_2+A_3}B_1X_1 + e^{A_3}B_2X_2 + B_3X_3)*C_4*e^{A_4} &
  (e^{A_2+A_3}B_1X_1 + e^{A_3}B_2X_2 + B_3X_3)*C_5*e^{A_4+A_5} &
  (e^{A_2+A_3}B_1X_1 + e^{A_3}B_2X_2 + B_3X_3)*C_6*e^{A_4+A_5+A_6} \\
  (e^{A_2+A_3+A_4+A_5+A_6}B_1X_1 + e^{A_3+A_4+A_5+A_6}B_2X_2 + e^{A_4+A_5+A_6}B_3X_3 + e^{A_5+A_6}B_4X_4 + e^{A_6}B_5X_5 + B_6X_6)*C_7*e^{A_7} &
  (e^{A_2+A_3+A_4+A_5+A_6}B_1X_1 + e^{A_3+A_4+A_5+A_6}B_2X_2 + e^{A_4+A_5+A_6}B_3X_3 + e^{A_5+A_6}B_4X_4 + e^{A_6}B_5X_5 + B_6X_6)*C_8*e^{A_7+A_8} &
  (e^{A_2+A_3+A_4+A_5+A_6}B_1X_1 + e^{A_3+A_4+A_5+A_6}B_2X_2 + e^{A_4+A_5+A_6}B_3X_3 + e^{A_5+A_6}B_4X_4 + e^{A_6}B_5X_5 + B_6X_6)*C_9*e^{A_7+A_8+A_9}
\end{bmatrix}
$