In [2]:
from functools import partial
import jax
import jax.numpy as np
from flax import linen as nn
from jax.nn.initializers import lecun_normal, normal
from jax.numpy.linalg import eigh, inv, matrix_power
from jax.scipy.signal import convolve

In [3]:
# Global PRNG key with a fixed seed (1) so the random SSMs and inputs generated below are reproducible.
rng = jax.random.PRNGKey(1)

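We map a 1-D input signal $u(t)$ to an output $y(t)$ through the continuous-time state-space model (SSM)
$$x'(t) = A x(t) + B u(t), \qquad y(t) = C x(t),$$
where:
- $N$ is the dimension of the latent state $x \in \mathbb{R}^N$,
- $A \in \mathbb{R}^{N \times N}$ is the state transition matrix,
- $B \in \mathbb{R}^{N \times 1}$ is the input matrix,
- $C \in \mathbb{R}^{1 \times N}$ is the output matrix.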

In [9]:
def random_SSM(rng: jax.random.PRNGKey, N: int):
    '''
    Sample a random continuous-time state-space model whose entries are drawn
    uniformly from [0, 1) (jax.random.uniform defaults).
    :param rng: jax.random.PRNGKey, split into one sub-key per matrix
    :param N: int, size of the latent state x
    :return:
        A: np.ndarray (N, N), state transition matrix
        B: np.ndarray (N, 1), input matrix
        C: np.ndarray (1, N), output matrix
    '''
    keys = jax.random.split(rng, 3)
    shapes = ((N, N), (N, 1), (1, N))
    A, B, C = (jax.random.uniform(key, shape) for key, shape in zip(keys, shapes))
    return A, B, C

In [10]:
def discretize(A: np.ndarray, B: np.ndarray, C: np.ndarray, step: float):
    '''
    Turn a continuous-time SSM into a discrete recurrence with the bilinear
    (trapezoidal) rule:
        Ab = (I - step/2 * A)^{-1} (I + step/2 * A)
        Bb = (I - step/2 * A)^{-1} * step * B
    The output matrix C is passed through unchanged.
    :param A: np.ndarray (N, N), state transition matrix
    :param B: np.ndarray (N, 1), input matrix
    :param C: np.ndarray (1, N), output matrix
    :param step: float, discretization time step
    :return:
        Ab: np.ndarray (N, N), discretized state transition matrix
        Bb: np.ndarray (N, 1), discretized input matrix
        C: np.ndarray (1, N), output matrix (same object as the input)
    '''
    half = step / 2.0
    identity = np.eye(A.shape[0])
    BL = inv(identity - half * A)
    Ab = BL @ (identity + half * A)
    Bb = (BL * step) @ B
    return Ab, Bb, C

In [11]:
def scan_SSM(Ab, Bb, Cb, u, x0):
    '''
    Run the discrete SSM as a recurrence with jax.lax.scan:
        x_k = Ab @ x_{k-1} + Bb @ u_k
        y_k = Cb @ x_k
    :param Ab: discretized state transition matrix, (N, N)
    :param Bb: discretized input matrix, (N, 1)
    :param Cb: output matrix, (1, N)
    :param u: input sequence scanned along its leading axis; each u_k must make
        Bb @ u_k valid (run_SSM passes shape (L, 1))
    :param x0: initial state, (N,)
    :return: (final_state, ys) where ys stacks every y_k along the first axis
    '''
    def step(x_prev, u_k):
        # No debug printing here: the body is traced by scan, so a print would
        # fire at trace time and clutter notebook output.
        x_k = Ab @ x_prev + Bb @ u_k
        y_k = Cb @ x_k
        return x_k, y_k

    return jax.lax.scan(step, x0, u)

In [12]:
def run_SSM(A, B, C, u, step=None):
    '''
    Discretize a continuous-time SSM and run it over a 1-D input signal as a recurrence.
    :param A: (N, N) state transition matrix
    :param B: (N, 1) input matrix
    :param C: (1, N) output matrix
    :param u: input samples u_0 .. u_{L-1}, shape (L,)
    :param step: discretization step; None (default) keeps the original choice of 1 / L.
        Pass an explicit value to match a kernel discretized with the same step.
    :return: outputs y_k stacked along the first axis, shape (L, 1)
    '''
    L = u.shape[0]
    N = A.shape[0]
    if step is None:
        step = 1.0 / L
    Ab, Bb, Cb = discretize(A, B, C, step=step)

    # Run recurrence: each sample becomes a length-1 vector so Bb @ u_k is defined,
    # starting from the zero state; keep only the outputs.
    return scan_SSM(Ab, Bb, Cb, u[:, np.newaxis], np.zeros((N,)))[1]


In [13]:
def example_mass(k, b, m):
    '''
    Continuous-time SSM of a mass on a damped spring driven by an external force.
    State x = (position, velocity): position' = velocity and
    velocity' = (-k * position - b * velocity + u) / m. The output is the position.
    :param k: spring constant
    :param b: damping coefficient
    :param m: mass
    :return: A (2, 2), B (2, 1), C (1, 2)
    '''
    C = np.array([[1.0, 0]])
    B = np.array([[0], [1.0 / m]])
    A = np.array([[0, 1], [-k / m, -b / m]])
    return A, B, C

In [14]:
@partial(np.vectorize, signature="()->()")
def example_force(t):
    '''
    Clipped sinusoidal driving force: sin(10 t) wherever it exceeds 0.5, else 0.
    Vectorized so it applies element-wise to an array of sample times.
    '''
    wave = np.sin(10 * t)
    above_threshold = wave > 0.5
    return wave * above_threshold

In [15]:
def example_ssm(out_path="images/line.gif"):
    '''
    Simulate the spring-mass SSM (example_mass with k=40, b=5, m=1) driven by
    example_force over L = 100 samples, and save an animation of the force, the
    position and the moving object as a GIF.
    :param out_path: file the GIF is written to (default "images/line.gif"); its
        parent directory is created if it does not exist yet.
    :return: None
    '''
    # SSM
    ssm = example_mass(k=40, b=5, m=1)
    print(f'A:\n{ssm[0]}')
    print(f'B:\n{ssm[1]}')
    print(f'C:\n{ssm[2]}')

    # L samples of u(t).
    L = 100
    step = 1.0 / L
    ks = np.arange(L)
    u = example_force(ks * step)
    print(f'u:{u.shape}')
    print(f'u[0]:{u[0, np.newaxis]}')
    print(f'B @ u[0] {ssm[1] @ u[0, np.newaxis]}')

    # Approximation of y(t), shape (L, 1).
    y = run_SSM(*ssm, u)

    # Plotting --- these packages are imported here so they are only needed
    # when the animation is actually produced.
    import os
    import matplotlib.pyplot as plt
    import seaborn
    from celluloid import Camera

    seaborn.set_context("paper")
    fig, (ax1, ax2, ax3) = plt.subplots(3)
    camera = Camera(fig)
    ax1.set_title("Force $u_k$")
    ax2.set_title("Position $y_k$")
    ax3.set_title("Object")
    ax1.set_xticks([], [])
    ax2.set_xticks([], [])

    # Animate plot over time: one frame every 2 samples.
    for k in range(0, L, 2):
        ax1.plot(ks[:k], u[:k], color="red")
        ax2.plot(ks[:k], y[:k], color="blue")
        # Draw the object as a thin box centred on the current position.
        ax3.boxplot(
            [[y[k, 0] - 0.04, y[k, 0], y[k, 0] + 0.04]],
            showcaps=False,
            whis=False,
            vert=False,
            widths=10,
        )
        camera.snap()
    anim = camera.animate()

    # Pillow writes GIFs natively, so no ImageMagick binary (and no hard-coded
    # absolute path to it) is required.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    anim.save(out_path, dpi=150, writer="pillow")
    # Close this specific figure so it is not also displayed inline.
    plt.close(fig)

In [36]:
# Toggle for the animation: example_ssm() imports matplotlib, seaborn and celluloid
# and writes a GIF under images/. Set to False to skip it.
plot = True
if plot:
    example_ssm()

A:
[[  0.   1.]
 [-40.  -5.]]
B:
[[0.]
 [1.]]
C:
[[1. 0.]]
u:(100,)
u[0]:[0.]
B @ u[0] [0. 0.]
u_k:(1,)


MovieWriter imagemagick unavailable; using Pillow instead.


<img src="images/line.gif" width="750" align="center">

In [16]:
def K_conv(Ab: np.array, Bb: np.array, Cb: np.array, L: int):
    '''
    Materialize the SSM convolution kernel
        K = (Cb Bb, Cb Ab Bb, Cb Ab^2 Bb, ..., Cb Ab^{L-1} Bb)
    by evaluating each matrix power explicitly.
    WARNING: this is a naive implementation and should be avoided for large L (stability issues).
    :param Ab: discretized state transition matrix, (N, N)
    :param Bb: discretized input matrix, (N, 1)
    :param Cb: output matrix, (1, N)
    :param L: kernel length (sequence length)
    :return: 1-D array of the L scalar kernel taps
    '''
    taps = []
    for power in range(L):
        tap = Cb @ matrix_power(Ab, power) @ Bb
        taps.append(tap.reshape())
    return np.array(taps)

In [17]:
def causal_convolution(u: np.array, K: np.array, nofft: bool = False):
    '''
    Causal convolution y_k = sum_{j <= k} K_j u_{k-j}, truncated to len(u) outputs.
    By default it is computed in the frequency domain: both signals are zero-padded
    to len(u) + len(K), so the FFT product equals the linear convolution with no
    wrap-around.
    :param u: 1-D input signal of length L
    :param K: 1-D SSM kernel of any length
    :param nofft: if True, use direct convolution (jax.scipy.signal.convolve) instead of FFT
    :return: 1-D output signal of length len(u)
    '''
    L = u.shape[0]
    if nofft:
        return convolve(u, K, mode="full")[:L]  # from jax.scipy.signal
    # Padding to len(u) + len(K) covers the full linear-convolution length
    # len(u) + len(K) - 1, so no aliasing occurs.
    n = L + K.shape[0]
    ud = np.fft.rfft(u, n=n)
    Kd = np.fft.rfft(K, n=n)
    # Pass n explicitly: irfft's default output length 2 * (m - 1) is off by one
    # when n is odd (i.e. when len(u) and len(K) have different parity).
    return np.fft.irfft(ud * Kd, n=n)[:L]

In [18]:
def test_cnn_is_rnn(N=4, L=16, step=1.0 / 16):
    '''
    Check that the convolutional (CNN) view of a random SSM matches its recurrent
    (RNN) view. Both views use the same discretization with `step`, so the check
    is meaningful for any step, not only step == 1 / L.
    :param N: State dimension
    :param L: Sequence length
    :param step: Discretization step shared by both views
    :return: None; prints whether the two outputs agree (np.allclose)
    '''
    # NOTE(review): the global `rng` feeds both random_SSM (which splits it) and the
    # input; JAX discourages reusing a key, but this keeps results identical to earlier runs.
    ssm = random_SSM(rng, N)
    u = jax.random.uniform(rng, (L,))

    # Discretize once so the recurrence and the kernel use identical Ab, Bb, Cb.
    ssmb = discretize(*ssm, step=step)

    # RNN: the same recurrence run_SSM performs, but with the caller-supplied step.
    rec = scan_SSM(*ssmb, u[:, np.newaxis], np.zeros((N,)))[1]

    # CNN
    conv = causal_convolution(u, K_conv(*ssmb, L))

    # Check
    # np.allclose: Returns True if two arrays are element-wise equal within a tolerance.
    print(f'CNN = RNN:  {np.allclose(rec.ravel(), conv.ravel())}')

In [19]:
# Random 4-state SSM over 16 samples, discretized with step 1/16.
test_cnn_is_rnn(N=4, L=16, step=1.0 / 16)

u_k:(1,)
CNN = RNN:  True
