In [1]:
!pip install --upgrade pip



In [2]:
!pip install --upgrade jax jaxlib
!pip install --upgrade flax



# Mamba Architecture
<img src="assets/mamba_architecture.png"/>

as we can see from the architecture the new is 
**SSM (Structured State Space Model)** : 
are at the core of Mamba, so it is important to note how they work. You can think of them as the replacement for the self attention mechanism in a transformer,<br/><br/>

A state space model takes in a 1D input sequence, maps to a N-D latent space and then projects back to a 1D output sequence.<br/><br/>
<img src="assets/statespace.png" width="50%"/>
<br/><br/>
<h1>SSMs</h1>
like S4 can be defined as through these equations: 

\begin{align*}
h'(t) &= Ah(t) + Bx(t) \\
y(t) &= Ch(t)
\end{align*}

<p>the first equation : </p>
<img src="assets/ht.png"/>

<p>the second equation : </p>
<img src="assets/yt.png" width="80%"/>


## Lets Implement SSM here

In [6]:
def random_SSM(N):
    a_r, b_r, c_r = jax.random.split(key, 3)
    A = jax.random.uniform(a_r, (N, N))
    B = jax.random.uniform(b_r, (N, 1))
    C = jax.random.uniform(c_r, (1, N))
    return A, B, C

In [24]:
random_SSM(2)

(Array([[0.6013386 , 0.33307862],
        [0.08670855, 0.5839832 ]], dtype=float32),
 Array([[0.02978921],
        [0.03150713]], dtype=float32),
 Array([[0.14587104, 0.45987797]], dtype=float32))

## Discretize

\begin{align*}
    A & = \exp(\Delta A) \\
    B & = (\Delta A)^{-1} (\exp(\Delta A) - I) \cdot \Delta B
\end{align*}

In [13]:
import jax
import jax.numpy as np
import flax
import flax.linen as nn
import numpy as npy

seed = 1701
key = jax.random.PRNGKey(seed)

In [5]:
def discretize(A, B, delta):
    A_delta = np.exp(A @ delta)
    
    inverse_delta_A = np.linalg.inv(Delta @ A)
    B_delta = (A_delta - np.eye(A.shape[0])) @ inverse_delta_A @ Delta @ B
     
    return A_delta, B_delta

In [14]:
# Example usage
D = 3  # Replace with actual dimensions
N = 3  # Replace with actual dimensions

A = npy.random.rand(D, N)  # Replace with actual values
Delta = npy.random.rand(D, D)  # Replace with actual values
B = npy.random.rand(D, N)  # Replace with actual values

A_delta, B_delta = discretize(A,B,Delta)

In [15]:
A_delta

Array([[2.362534 , 2.028057 , 1.9282295],
       [2.0856898, 1.8411039, 1.7657193],
       [1.4537508, 1.3298508, 1.2956617]], dtype=float32)

In [16]:
A

array([[0.64168341, 0.65988684, 0.67452471],
       [0.5218893 , 0.58767281, 0.55691602],
       [0.37728629, 0.20435282, 0.35816538]])

In [17]:
def scan_SSM(Ab, Bb, Cb, u, x0):
    def step(x_k_1, u_k):
        x_k = Ab @ x_k_1 + Bb @ u_k
        y_k = Cb @ x_k
        return x_k, y_k

    return jax.lax.scan(step, x0, u)

In [18]:
def run_SSM(A, B, C, u):
    L = u.shape[0]
    N = A.shape[0]
    Ab, Bb, Cb = discretize(A, B, C, step=1.0 / L)

    # Run recurrence
    return scan_SSM(Ab, Bb, Cb, u[:, np.newaxis], np.zeros((N,)))[1]

In [19]:
def K_conv(Ab, Bb, Cb, L):
    return np.array(
        [(Cb @ matrix_power(Ab, l) @ Bb).reshape() for l in range(L)]
    )

In [20]:
def causal_convolution(u, K, nofft=False):
    if nofft:
        return convolve(u, K, mode="full")[: u.shape[0]]
    else:
        assert K.shape[0] == u.shape[0]
        ud = np.fft.rfft(np.pad(u, (0, K.shape[0])))
        Kd = np.fft.rfft(np.pad(K, (0, u.shape[0])))
        out = ud * Kd
        return np.fft.irfft(out)[: u.shape[0]]

In [22]:
class SSMLayer(nn.Module):
    N: int
    l_max: int
    decode: bool = False

    def setup(self):
        # SSM parameters
        self.A = self.param("A", lecun_normal(), (self.N, self.N))
        self.B = self.param("B", lecun_normal(), (self.N, 1))
        self.C = self.param("C", lecun_normal(), (1, self.N))
        self.D = self.param("D", nn.initializers.ones, (1,))

        # Step parameter
        self.log_step = self.param("log_step", log_step_initializer(), (1,))

        step = np.exp(self.log_step)
        self.ssm = discretize(self.A, self.B, self.C, step=step)
        self.K = K_conv(*self.ssm, self.l_max)

        # RNN cache for long sequences
        self.x_k_1 = self.variable("cache", "cache_x_k", np.zeros, (self.N,))

    def __call__(self, u):
        if not self.decode:
            # CNN Mode
            return causal_convolution(u, self.K) + self.D * u
        else:
            # RNN Mode
            x_k, y_s = scan_SSM(*self.ssm, u[:, np.newaxis], self.x_k_1.value)
            if self.is_mutable_collection("cache"):
                self.x_k_1.value = x_k
            return y_s.reshape(-1).real + self.D * u