In [None]:
# https://github.com/johnma2006/candle/tree/main/candle/models/mamba
# https://github.com/PeaBrane/mamba-tiny
# https://github.com/vedant-jumle/Mamba-tf/tree/main
# https://medium.com/data-science/mamba-ssm-theory-and-implementation-in-keras-and-tensorflow-32d6d4b32546
# https://scholarqa.allen.ai/chat/43312fa8-7f2c-4267-b237-3a07d2fe78e0
# https://www.maartengrootendorst.com/blog/mamba/
# https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba2_simple.py
# https://towardsdatascience.com/mamba-ssm-theory-and-implementation-in-keras-and-tensorflow-32d6d4b32546/
# https://srush.github.io/annotated-s4/

In [25]:
import jax.numpy as jnp
import optax
import math
from flax import nnx
import numpy as np
from einops import einsum, rearrange
import jax


In [22]:
def selective_scan(
        u: jnp.ndarray,
        delta: jnp.ndarray,
        A: jnp.ndarray,
        B: jnp.ndarray,
        C: jnp.ndarray,
        D: jnp.ndarray,
) -> jnp.ndarray:
    """Run the discretized selective state-space recurrence over a sequence.

    For every time step t, starting from the zero state x_{-1} = 0:

        x_t = exp(delta_t * A) * x_{t-1} + delta_t * B_t * u_t
        y_t = <x_t, C_t> + D * u_t

    (see Algorithm 2 of the Mamba paper and scan_SSM() in The Annotated S4).

    Args:
        u: Input sequence, shape (b, l, d_in).
        delta: Per-step discretization step sizes, shape (b, l, d_in).
        A: Continuous state parameters, shape (d_in, n); applied
            elementwise to the state, i.e. one diagonal matrix per channel.
        B: Input-dependent input projection, shape (b, l, n).
        C: Input-dependent output projection, shape (b, l, n).
        D: Skip-connection weights, broadcastable against u (e.g. (d_in,)).

    Returns:
        Output sequence, shape (b, l, d_in). An empty sequence (l == 0)
        yields an empty (b, 0, d_in) result.
    """
    (b, l, d_in) = u.shape
    n = A.shape[1]

    # Discretize continuous parameters (A, B). B is discretized as delta * B
    # (first-order form) and pre-multiplied by the input u.
    deltaA = jnp.exp(jnp.einsum("bld,dn->bldn", delta, A))
    deltaB_u = jnp.einsum("bld,bln,bld->bldn", delta, B, u)

    # lax.scan iterates over the leading axis, so move the sequence axis first.
    deltaA_seq = jnp.swapaxes(deltaA, 0, 1)      # (l, b, d_in, n)
    deltaB_u_seq = jnp.swapaxes(deltaB_u, 0, 1)  # (l, b, d_in, n)
    C_seq = jnp.swapaxes(C, 0, 1)                # (l, b, n)

    def step(x, inputs):
        # One recurrence step: update the hidden state, then read it out via C_t.
        dA, dBu, c = inputs
        x = dA * x + dBu
        y = jnp.einsum("bdn,bn->bd", x, c)
        return x, y

    # The scan carry must keep the same dtype on every iteration, so start the
    # state in the dtype the recurrence produces (not the float32 default).
    x0 = jnp.zeros((b, d_in, n), dtype=jnp.result_type(deltaA, deltaB_u))

    # Sequential scan; the official implementation uses a faster, hardware-aware
    # parallel scan instead. Unlike a Python for-loop, lax.scan traces the step
    # once (no l-fold unrolling under jit) and handles l == 0 gracefully.
    _, ys = jax.lax.scan(step, x0, (deltaA_seq, deltaB_u_seq, C_seq))  # (l, b, d_in)

    y = jnp.swapaxes(ys, 0, 1)  # (b, l, d_in)

    return y + u * D





In [23]:
def softplus(x: jnp.ndarray) -> jnp.ndarray:
    """
    Applies the Softplus activation, softplus(x) = log(1 + exp(x)), elementwise.

    Computed as logaddexp(x, 0) = log(exp(x) + exp(0)), which is mathematically
    identical but numerically stable. The naive form overflows exp(x) to inf
    for large x (returning inf instead of ~x). It also returns exactly 0 for
    moderately negative x, because 1 + exp(x) rounds to 1.

    Args:
        x (jnp.ndarray): Input tensor.

    Returns:
        jnp.ndarray: Tensor of the same shape with Softplus applied.
    """
    return jnp.logaddexp(x, 0)

In [24]:
def ssm(self, x: jnp.ndarray) -> jnp.ndarray:
    """
    Runs the selective SSM. See:
        - Algorithm 2 in Section 3.2 of the Mamba paper [1].
        - run_SSM(A, B, C, u) in The Annotated S4 [2].

    Written to be used as a method: it reads self.A_log, self.D, self.x_proj,
    self.dt_proj and self.args.dt_rank, and delegates the recurrence to
    self.selective_scan.

    Args:
        x (jnp.ndarray): shape (b, l, d_in).

    Returns:
        jnp.ndarray: shape (b, l, d_in).

    Official Implementation:
        mamba_inner_ref(), see https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311

    References:
        [1] Mamba paper: https://arxiv.org/abs/2312.00752
        [2] The Annotated S4: https://srush.github.io/annotated-s4/
    """ # noqa: E501
    n = self.A_log.shape[1]
    dt_rank = self.args.dt_rank

    # Compute ∆, A, B, C, D (state space parameters)
    # A and D are input-independent (see Mamba paper [1], Section 3.5.2 for A's interpretation) # noqa: E501
    # ∆, B, C are input-dependent (a key difference between Mamba and linear time-invariant S4) # noqa: E501

    # A is parameterized in log space; -exp(.) keeps every entry strictly negative.
    A = -jnp.exp(self.A_log.astype(float))  # shape (d_in, n)
    D = self.D.astype(float)

    x_dbl = self.x_proj(x)  # (b, l, dt_rank + 2*n)
    # Carve the projection into its three consecutive pieces: delta | B | C.
    delta = x_dbl[..., :dt_rank]                   # (b, l, dt_rank)
    B = x_dbl[..., dt_rank:dt_rank + n]            # (b, l, n)
    C = x_dbl[..., dt_rank + n:dt_rank + 2 * n]    # (b, l, n)

    # softplus keeps the step size positive after projecting back to d_in.
    delta = softplus(self.dt_proj(delta))  # (b, l, d_in)

    # NOTE(review): relies on the owning class exposing selective_scan as a
    # method; the module-level selective_scan takes no self — confirm.
    y = self.selective_scan(
            x, delta, A, B, C, D
    )  # Similar to run_SSM(A, B, C, u) in The Annotated S4 [2]

    return y

In [None]:
class MambaBlock(nnx.Module):
    def __init__(
            self,
            in_proj,
            conv1d,
            x_proj,
            dt_proj,
            A_log: jnp.ndarray,
            D: jnp.ndarray,
            out_proj,

    ):

