In [5]:
from functools import partial
import jax
import jax.numpy as np
from flax import linen as nn
from jax.nn.initializers import lecun_normal, normal
from jax.numpy.linalg import eigh, inv, matrix_power
from jax.scipy.signal import convolve
import time

In [6]:
# Random SSM

def random_SSM(rng, N):
    """Sample a random continuous-time SSM ``(A, B, C)`` with one input and one output.

    Args:
        rng: JAX PRNG key; it is split into three sub-keys, one per matrix.
        N: State size.

    Returns:
        Tuple ``(A, B, C)`` with shapes ``(N, N)``, ``(N, 1)`` and ``(1, N)``,
        entries drawn uniformly from [0, 1).
    """
    key_A, key_B, key_C = jax.random.split(rng, 3)
    shapes = ((N, N), (N, 1), (1, N))
    return tuple(
        jax.random.uniform(key, shape)
        for key, shape in zip((key_A, key_B, key_C), shapes)
    )

In [7]:
def discretize(A, B, C, step):
    """Discretize a continuous-time SSM with the bilinear transform.

    Computes
        Ab = (I - step/2 * A)^{-1} (I + step/2 * A)
        Bb = ((I - step/2 * A)^{-1} * step) @ B
    and passes ``C`` through unchanged.

    Args:
        A, B, C: Continuous-time SSM matrices (``A`` square).
        step: Discretization step size.

    Returns:
        Tuple ``(Ab, Bb, C)``.
    """
    half_step = step / 2.0
    identity = np.eye(A.shape[0])
    scaled_A = half_step * A
    left_inverse = inv(identity - scaled_A)
    Ab = left_inverse @ (identity + scaled_A)
    Bb = (left_inverse * step) @ B
    return Ab, Bb, C

## Recurrent Representation

In [8]:
def scan_SSM(Ab, Bb, Cb, u, x0):
    """Run the discrete SSM recurrence with ``jax.lax.scan``.

    For each input ``u_k`` the state is updated as ``x_k = Ab @ x_{k-1} + Bb @ u_k``
    and the emitted output is ``y_k = Cb @ x_k``.

    Args:
        Ab, Bb, Cb: Discretized SSM matrices.
        u: Input sequence; scanned over its leading axis.
        x0: Initial state.

    Returns:
        ``(x_last, ys)`` as returned by ``jax.lax.scan``: the final state and the
        per-step outputs stacked along a new leading axis.
    """
    def recurrence(state, u_k):
        new_state = Ab @ state + Bb @ u_k
        return new_state, Cb @ new_state

    return jax.lax.scan(recurrence, x0, u)

In [9]:
def run_SSM(A, B, C, u, step):
    """Discretize a continuous-time SSM and run it recurrently over an input sequence.

    Args:
        A, B, C: Continuous-time SSM matrices; ``A`` is square with size N.
        u: Input sequence, assumed 1-D of length L (a trailing axis is added so
            each scan step receives a length-1 input vector) -- TODO confirm.
        step: Discretization step size forwarded to ``discretize``.

    Returns:
        The stacked per-step outputs from ``scan_SSM`` (the final state is dropped).
    """
    N = A.shape[0]
    Ab, Bb, Cb = discretize(A, B, C, step=step)

    # Run the recurrence from a zero initial state.
    _, ys = scan_SSM(Ab, Bb, Cb, u[:, np.newaxis], np.zeros((N,)))
    return ys

## Convolutional Representation

In [75]:
# Builds the convolution kernel K by materializing powers of Ab. This is the naive approach; S4 will replace it later.
def K_conv(Ab, Bb, Cb, L):
    """Materialize the length-L SSM convolution kernel.

    ``K[l] = Cb @ Ab**l @ Bb`` for ``l = 0, ..., L-1``.

    Args:
        Ab: Discrete state matrix, square (N, N).
        Bb, Cb: Discrete input/output matrices; ``Cb @ Ab @ Bb`` must contain a
            single element (e.g. ``Bb`` of shape (N, 1), ``Cb`` of shape (1, N)).
        L: Kernel length (a Python int).

    Returns:
        Array of shape ``(L,)``.
    """
    def next_power(A_pow, _):
        # Carry Ab**l forward so each step costs one matmul instead of a fresh
        # matrix_power call per index.
        return A_pow @ Ab, (Cb @ A_pow @ Bb).reshape()

    identity = np.eye(Ab.shape[0], dtype=Ab.dtype)
    _, K = jax.lax.scan(next_power, identity, None, length=L)
    return K

In [76]:
def causal_convolution(u, K, nofft=False):
    """Causal 1-D convolution ``y[k] = sum_{j<=k} K[j] * u[k-j]``, truncated to ``len(u)``.

    Args:
        u: Input sequence of length L.
        K: Convolution kernel; on the FFT path it must also have length L.
        nofft: If True, use ``jax.scipy.signal.convolve`` instead of FFTs.

    Returns:
        Array with the same length as ``u``.
    """
    L = u.shape[0]
    if nofft:
        return convolve(u, K, mode="full")[:L]

    assert K.shape[0] == u.shape[0]
    # An FFT product is a *circular* convolution. Zero-padding both signals to
    # len(u) + len(K) (via ``n=``) makes the wrap-around region all zeros, so the
    # first L entries equal the linear convolution; the tail is then dropped.
    n_fft = L + K.shape[0]
    u_hat = np.fft.rfft(u, n=n_fft)
    K_hat = np.fft.rfft(K, n=n_fft)
    return np.fft.irfft(u_hat * K_hat)[:L]

## Parallel Scan Representation

In [77]:
def associative_scan(Ab, Bb, Cb, u):
    """Evaluate the SSM recurrence in parallel with ``jax.lax.associative_scan``.

    Step k is encoded as the pair ``(Ab, Bb * u_k)``. Pairs compose as
    ``(A_next @ A_prev, A_next @ b_prev + b_next)``, which is associative, so a
    prefix scan yields the states of the recurrence started from a zero state.

    Args:
        Ab, Bb, Cb: Discretized SSM matrices (``Ab`` square).
        u: Input sequence, scanned along its leading axis.

    Returns:
        ``Cb @ x_k`` for every step, stacked along a new leading axis.
    """
    seq_len = u.shape[0]

    def combine(earlier, later):
        A_prev, b_prev = earlier
        A_next, b_next = later
        return A_next @ A_prev, A_next @ b_prev + b_next

    inputs = jax.vmap(lambda u_k: Bb * u_k)(u)
    transitions = Ab * np.ones((seq_len, Ab.shape[0], Ab.shape[1]))
    _, states = jax.lax.associative_scan(combine, (transitions, inputs))
    return jax.vmap(lambda x_k: Cb @ x_k)(states)

**TODO:** finish `associative_scan_LR` below — it is an unfinished draft and is not usable yet.

In [None]:
def associative_scan_LR(Bb, Cb, u):
    """Work-in-progress variant of ``associative_scan`` (not implemented yet).

    The draft of this function unpacked scan elements as ``(L_i, R_i, b_i)``.
    Its combine step and element construction, however, still referred to
    ``A_i``, ``A_j`` and ``Ab``, none of which are defined in the function or
    passed in. Calling it could therefore only fail with a confusing
    ``NameError``. Until the combine rule for the ``L``/``R`` factors is written,
    fail explicitly instead.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError(
        "associative_scan_LR is unfinished: the (L, R, b) combine rule has not "
        "been written yet; use associative_scan(Ab, Bb, Cb, u) instead."
    )

### Checking that the recurrent and convolutional forms agree

In [78]:
# Fixed PRNG key so the randomized checks below are reproducible.
rng = jax.random.PRNGKey(seed=1)

In [83]:
def test_cnn_is_rnn(N=4, L=16, step=1.0 / 16, key=None):
    """
    Check that the recurrent, convolutional and parallel-scan SSM forms agree.

    Evaluates a random SSM on a random input three ways: RNN via ``run_SSM``,
    CNN via ``K_conv`` + ``causal_convolution``, and ``associative_scan``. It
    prints the elapsed time of each and asserts that the outputs match.

    :param N: Width of the SSM
    :param L: Sequence length
    :param step: Timestep scalar
    :param key: PRNG key; defaults to the notebook-level ``rng``
    :return: ``(rec_time, conv_time, scan_time)`` in nanoseconds
    """
    if key is None:
        key = rng
    # Independent sub-keys for the SSM parameters and the input, rather than
    # reusing one key for both.
    ssm_key, u_key = jax.random.split(key)
    ssm = random_SSM(ssm_key, N)  # This is a tuple containing our SSM parameters
    u = jax.random.uniform(u_key, (L,))

    # JAX dispatches work asynchronously, so each timed result is blocked on
    # before reading the clock; otherwise only dispatch time is measured.
    # NOTE: first-call timings still include tracing/compilation overhead.

    # RNN (timed span includes discretization)
    start = time.perf_counter_ns()
    rec = run_SSM(*ssm, u, step=step).block_until_ready()
    rec_time = time.perf_counter_ns() - start
    print(rec_time)

    # CNN: timed span is discretization + FFT convolution. Building K with
    # K_conv is left outside the timed span, as before (presumably because the
    # naive kernel is slated for replacement) -- confirm this is intended.
    start = time.perf_counter_ns()
    ssmb = discretize(*ssm, step=step)
    for mat in ssmb:
        mat.block_until_ready()
    conv_time_mid = time.perf_counter_ns() - start
    K = K_conv(*ssmb, L).block_until_ready()
    start = time.perf_counter_ns()
    conv = causal_convolution(u, K).block_until_ready()
    conv_time = time.perf_counter_ns() - start + conv_time_mid
    print(conv_time)

    # Associative scan (reuses ssmb; discretization is not in the timed span)
    start = time.perf_counter_ns()
    scan = associative_scan(*ssmb, u).block_until_ready()
    scan_time = time.perf_counter_ns() - start
    print(scan_time)

    # Check. np.allclose's third positional argument is rtol, so the arrays
    # must be compared pairwise. Tolerances allow for float32 FFT round-off.
    rtol, atol = 1e-4, 1e-5
    assert np.allclose(rec.ravel(), conv.ravel(), rtol=rtol, atol=atol)
    assert np.allclose(rec.ravel(), scan.ravel(), rtol=rtol, atol=atol)

    return rec_time, conv_time, scan_time
# The three returned values are elapsed times in nanoseconds, not model outputs.
rec, conv, scan = test_cnn_is_rnn(N=4, L=16, step=1.0 / 16)

218504838
6266132
19221244
