# model

> Basic architecture and components of Mamba

In [21]:
#| default_exp model

In [22]:
#| hide
from nbdev.showdoc import *

In [23]:
#| export
import numpy as np
import jax
import jax.numpy as jnp
import jax.numpy.linalg as jla
import jax.scipy.signal as jss
from jax.nn.initializers import lecun_normal, normal
import equinox as eqx

from miniMamba.modelArgs import ModelArgs

In [24]:
#| hide
# A fixed seed keeps every random example in this notebook reproducible.
rng = jax.random.PRNGKey(seed=0)

### State space models

State space models (SSMs) map a 1-D input signal $u(t)$ to an $N$-D latent state $x(t)$, which is then projected down to a 1-D output $y(t)$.

$$
\begin{aligned}
x'(t) &= Ax(t)+ Bu(t) \\
y(t) &= Cx(t) + Du(t)
\end{aligned}
$$
where $A,B,C,$ and $D$ are learnable parameters. 

For simplicity, let's start by omitting the $D$ term: $Du$ can be viewed as an easy-to-compute skip connection.

In [25]:
#| export
def random_SSM(rng, N):
    """Sample a random continuous-time SSM with latent state size ``N``.

    Every entry of A, B and C is drawn independently from
    ``jax.random.uniform``'s default range [0, 1).

    Args:
        rng: JAX PRNG key; it is split into three subkeys, one per matrix.
        N: latent state dimension.

    Returns:
        Tuple ``(A, B, C)`` with shapes ``(N, N)``, ``(N, 1)`` and ``(1, N)``.
    """
    subkeys = jax.random.split(rng, 3)
    shapes = ((N, N), (N, 1), (1, N))
    A, B, C = (jax.random.uniform(key, shape) for key, shape in zip(subkeys, shapes))
    return A, B, C

To apply this SSM to a discrete input sequence $(u_0, u_1, u_2, \dots)$ instead of the continuous function $u(t)$, it must be discretized with a step size $\Delta$ that represents the input's resolution, so that $u_k := u(k\Delta)$.

This discretization can be performed with the bilinear transform, which maps the continuous parameters $(A, B, C)$ to their discrete counterparts $(\bar{A}, \bar{B}, \bar{C})$:
$$
\begin{aligned}
\bar{A} &= \left(I - \tfrac{\Delta}{2} A\right)^{-1}\left(I + \tfrac{\Delta}{2} A\right) \\
\bar{B} &= \left(I - \tfrac{\Delta}{2} A\right)^{-1}\Delta B \\
\bar{C} &= C
\end{aligned}
$$

In [26]:
def discretize(A, B, C, step):
    """Discretize a continuous-time SSM with the bilinear transform.

    Computes
        Ab = (I - step/2 * A)^{-1} (I + step/2 * A)
        Bb = (I - step/2 * A)^{-1} (step * B)
    and passes C through unchanged.

    Args:
        A: square ``(N, N)`` state matrix.
        B: input matrix with leading dimension ``N`` (e.g. ``(N, 1)``).
        C: output matrix; returned as-is.
        step: discretization step size (Delta).

    Returns:
        Tuple ``(Ab, Bb, C)``.
    """
    I = jnp.eye(A.shape[0])
    half_step = step / 2.0
    lhs = I - half_step * A
    # Solve lhs @ X = rhs instead of forming inv(lhs) explicitly: a linear solve
    # is cheaper and numerically more stable than inverting and then multiplying.
    Ab = jla.solve(lhs, I + half_step * A)
    Bb = jla.solve(lhs, step * B)
    return Ab, Bb, C

Given this sequence-to-sequence mapping, the discrete SSM can be computed like an RNN, using a recurrence relation in $x$:
$$
\begin{aligned}
x_{k} &= \bar{A}x_{k-1} + \bar{B}u_k \\
y_k &= \bar{C}x_k
\end{aligned}
$$

In [27]:
def scan_SSM(Ab, Bb, Cb, u, x0):
    """Run the discrete SSM recurrence over ``u`` with ``jax.lax.scan``.

    At each step: ``x_k = Ab @ x_{k-1} + Bb @ u_k`` and ``y_k = Cb @ x_k``.

    Args:
        Ab, Bb, Cb: discretized SSM matrices.
        u: input sequence; ``jax.lax.scan`` iterates over its leading axis.
        x0: initial hidden state.

    Returns:
        Tuple ``(final_state, ys)``, where ``ys`` stacks every ``y_k``.
    """
    def recurrence(prev_state, u_t):
        next_state = Ab @ prev_state + Bb @ u_t
        return next_state, Cb @ next_state

    final_state, ys = jax.lax.scan(recurrence, x0, u)
    return final_state, ys

To run the SSM we first discretize and then iterate.

In [28]:
def run_SSM(A, B, C, u):
    """Discretize a continuous SSM and run it over the input sequence ``u``.

    The step size is ``1 / L`` with ``L = u.shape[0]``. The recurrence starts
    from a zero state of size ``N = A.shape[0]``.

    Args:
        A, B, C: continuous-time SSM matrices (``A`` is square).
        u: input sequence; assumed 1-D of shape ``(L,)`` — TODO confirm for
            other shapes, since a trailing axis is appended below.

    Returns:
        The stacked per-step outputs ``y_k``; the final hidden state is dropped.
    """
    seq_len, state_dim = u.shape[0], A.shape[0]
    # The whole sequence spans one unit of time, so Delta = 1 / L.
    Ab, Bb, Cb = discretize(A, B, C, step=1.0 / seq_len)

    # Give each sample its own trailing axis so ``Bb @ u_k`` is well-formed.
    inputs = u[:, None]
    initial_state = np.zeros(state_dim)
    _final_state, outputs = scan_SSM(Ab, Bb, Cb, inputs, initial_state)
    return outputs

In [29]:
class MambaBlock(eqx.Module):
    """Placeholder for a single Mamba block; not implemented yet.

    Every method below is a stub that does nothing and returns ``None``.
    """

    def __init__(self, args: ModelArgs):
        """Stub constructor; ``args`` is accepted but currently unused.

        TODO: build the block's parameters from ``args``.
        """

    def forward(self, x):
        """Stub forward pass; currently returns ``None``. TODO: implement."""

    def ssm(self, x):
        """Stub; currently returns ``None``.

        TODO: presumably the block's state-space computation — confirm and implement.
        """

    def selective_scan(self, u, delta, A, B, C, D):
        """Stub; currently returns ``None``.

        TODO: presumably scans input ``u`` with step ``delta`` and SSM
        parameters ``A``, ``B``, ``C``, ``D`` — confirm and implement.
        """



In [30]:
class ResidualBlock(eqx.Module):
    """Placeholder residual block; not implemented yet.

    Both methods are stubs that do nothing and return ``None``.
    """

    def __init__(self, args: ModelArgs):
        """Stub constructor; ``args`` is accepted but currently unused.

        TODO: build the block's sub-modules from ``args``.
        """

    def forward(self, x):
        """Stub forward pass; currently returns ``None``. TODO: implement."""

In [31]:
class Mamba(eqx.Module):
    """Placeholder for the full Mamba model; not implemented yet.

    All methods are stubs that do nothing and return ``None``.
    """

    def __init__(self, args: ModelArgs):
        """Stub constructor; ``args`` is accepted but currently unused.

        TODO: build the model from ``args``.
        """

    def forward(self, x):
        """Stub forward pass; currently returns ``None``. TODO: implement."""

    @staticmethod
    def load_pretrained(pretrained_model_name: str):
        """Stub loader; currently returns ``None``.

        TODO: presumably meant to construct a ``Mamba`` from the pretrained
        checkpoint named ``pretrained_model_name`` — confirm and implement.
        """

In [32]:
#| hide
# Regenerate the library module (`model`, per `#| default_exp`) from this notebook's `#| export` cells.
import nbdev; nbdev.nbdev_export()