In [2]:
from functools import partial
import jax
import jax.numpy as np
from flax import linen as nn
from jax.nn.initializers import lecun_normal, normal
from jax.numpy.linalg import eigh, inv, matrix_power
from jax.scipy.signal import convolve


In [6]:
if __name__ == "__main__":
    rng = jax.random.PRNGKey(1)

In [14]:
# Create a random state space model
# Note that A, B and C are learnable parameters
def random_SSM(rng, N):
    a_r, b_r, c_r = jax.random.split(rng, 3)
    print(a_r, b_r, c_r)
    A = jax.random.uniform(a_r, (N,N))
    B = jax.random.uniform(b_r, (N,1))
    C = jax.random.uniform(c_r, (1,N))
    return A, B, C

In [15]:
N = 1
print(random_SSM(rng, N))

[3568232559  713140391] [3620866055 2761185182] [2442350529 2639019713]
(Array([[0.44898927]], dtype=float32), Array([[0.4015354]], dtype=float32), Array([[0.86350095]], dtype=float32))


In [None]:
# The recurrent representation is represented using a discrete-time SSM
# To change our formulation from a continuous to discrete representation, we use the bi-linear method
# to generate our recurrent representation
# Ab = (A-(delta/2)*A)(A+(delta/2)*A)
# where, delta is the step size
def discrete_SSM(A, B, C, delta):
    I = np.eye(A.shape(0))
    AI = inv(I-(delta/2.0)*A)
    Ab = AI @ (I+(delta/2.0)*A)
    Bb = (AI*delta) @ B
    Cb = C
    return Ab, Bb, Cb

# By discretizing the state space we allow for a recurrent relationship

In [18]:
# the scan function is useful as it provides an optimized JAX specific recurrent update of input states
def scan_SSM(Ab, Bb, Cb, u, x0):
     def step(x_k_1, u_k):
         x_k = Ab @ x_k_1 + Bb @ u_k
         y_k = Cb @ x_k
         return x_k, y_k?
     return jax.lax.scan(step, u, x0)

In [None]:
# Combine everything and run
# 1) discretize 
# 2) run recurrent step

In [22]:
def run(A, B, C, u):
    L = u.shape[0]
    N = A.shape[0]
    Ab, Bb, Cb = discrete_SSM(A, B, C, delta=1.0/L)
    return scan_SSM(Ab, Bb, Cb, u, np.zeros((N,)))

In [23]:
A, B, C = random_SSM(rng, N)
run(A, B, C, u)

[3568232559  713140391] [3620866055 2761185182] [2442350529 2639019713]


NameError: name 'u' is not defined