In [13]:
!pip install --upgrade -q jax jaxlib

^C
[31mERROR: Operation cancelled by user[0m[31m
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m23.3.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3.11 -m pip install --upgrade pip[0m


# Mamba Architecture
<img src="assets/mamba_architecture.png"/>

as we can see from the architecture the new is 
**SSM (Structured State Space Model)** : 
are at the core of Mamba, so it is important to note how they work. You can think of them as the replacement for the self attention mechanism in a transformer,<br/><br/>

A state space model takes in a 1D input sequence, maps to a N-D latent space and then projects back to a 1D output sequence.<br/><br/>
<img src="assets/statespace.png" width="50%"/>
<br/><br/>
<h1>SSMs</h1>
like S4 can be defined as through these equations: 

\begin{align*}
h'(t) &= Ah(t) + Bx(t) \\
y(t) &= Ch(t)
\end{align*}

<p>the first equation : </p>
<img src="assets/ht.png"/>

<p>the second equation : </p>
<img src="assets/yt.png" width="80%"/>


## Lets Implement SSM here

In [23]:
def random_SSM(N):
    a_r, b_r, c_r = jax.random.split(key, 3)
    A = jax.random.uniform(a_r, (N, N))
    B = jax.random.uniform(b_r, (N, 1))
    C = jax.random.uniform(c_r, (1, N))
    return A, B, C

In [24]:
random_SSM(2)

(Array([[0.6013386 , 0.33307862],
        [0.08670855, 0.5839832 ]], dtype=float32),
 Array([[0.02978921],
        [0.03150713]], dtype=float32),
 Array([[0.14587104, 0.45987797]], dtype=float32))

## Discretize

\begin{align*}
    A & = \exp(\Delta A) \\
    B & = (\Delta A)^{-1} (\exp(\Delta A) - I) \cdot \Delta B
\end{align*}

In [13]:
import jax
import numpy as np

seed = 1701
key = jax.random.PRNGKey(seed)

In [8]:
def discretize(A, B, delta):
    A_delta = np.exp(A @ delta)
    
    inverse_delta_A = np.linalg.inv(Delta @ A)
    B_delta = (A_delta - np.eye(A.shape[0])) @ inverse_delta_A @ Delta @ B
     
    return A_delta, B_delta

In [9]:
# Example usage
D = 3  # Replace with actual dimensions
N = 3  # Replace with actual dimensions

A = np.random.rand(D, N)  # Replace with actual values
Delta = np.random.rand(D, D)  # Replace with actual values
B = np.random.rand(D, N)  # Replace with actual values

A_delta, B_delta = discretize(A,B,Delta)

In [10]:
A_delta

array([[1.39413233, 3.13766258, 2.30211207],
       [1.44236549, 2.82693333, 2.15917945],
       [1.33927042, 1.99111211, 1.71609092]])

In [11]:
A

array([[0.48978847, 0.36320954, 0.64564645],
       [0.68602276, 0.34511721, 0.34106767],
       [0.64623685, 0.28620197, 0.01209012]])