# S4 code pipleline

## 1. SSM in Convolutional Representation
本阶段主要是证明对于一个SSM模型来说，在使用卷积形式得到的结果和RNN Representation得到的结果一致

### 1. 定义连续SSM与离散化后的SSM的参数

In [1]:
import numpy as np
np.random.seed(1)

# Continuous SSM
def Random_SSM(N):
    # A: (N, N)
    # B: (N, 1)
    # C: (1, N)
    
    A = np.random.rand(N, N)
    B = np.random.rand(N, 1)
    C = np.random.rand(1, N)
    return A, B, C

# Discrete SSM
def discretize(A, B, C, step):
    # A_bar = (I - step/2 * A)^(-1) * (I + step/2 * A)
    # B_bar = (I - step/2 * A)^(-1) * B * step
    
    I = np.eye(A.shape[0])
    BL = np.linalg.inv(I - (step / 2.0) * A)
    Ab = BL @ (I + (step / 2.0) * A)
    Bb = (BL * step) @ B
    return Ab, Bb, C

### 2. SSM RNN Representation

In [2]:
def scan_SSM(Ab, Bb, Cb, u, x0): 
    # x = A_bar * x + B_bar * u
    # y = C_bar * x
    
    x0 = Ab @ x0 + Bb * u
    y  = Cb @ x0
    return x0, y

# Demo: Run SSM
def run_SSM(A, B, C, u):
    L = u.shape[0]
    N = A.shape[0]
    Ab, Bb, Cb = discretize(A, B, C, step=1.0 / L)
    x0 = np.zeros((N, 1))
    
    # Run recurrence
    for i in range(L):
        x0, y = scan_SSM(Ab, Bb, Cb, u[i], x0)
    return y

### 3. utils

In [3]:
# Mat Power
def matmul_n_times(A, n_times):
    raw_data = A
    if n_times > 0:
        for i in range(n_times - 1):
            A = np.matmul(A, raw_data)
    elif n_times == 0:
        A = np.eye(A.shape[0])
    return A

# Get Conv Kernel
def K_conv(Ab, Bb, Cb, L):
    # K = [C_bar * A_bar^(i) * B_bar for i in range(L)]
    return np.array([(Cb @ matmul_n_times(Ab, i) @ Bb) for i in range(L)]).squeeze()

# Convolution
def causal_convolution(u, K, nofft=False):
    if nofft:
        # 不使用FFT
        return np.dot(K[::-1], np.transpose(u))
    else:
        # 使用FFT
        assert K.shape[0] == u.shape[0]
        ud = np.fft.rfft(np.pad(u, (0, K.shape[0])))
        Kd = np.fft.rfft(np.pad(K, (0, u.shape[0])))
        out = ud * Kd
        return np.fft.irfft(out)[u.shape[0] - 1]

In [5]:
# # Demo in this stage
# L = 4
# step = 1.0 / L
# A, B, C = Random_SSM(3)
# Ab, Bb, Cb = discretize(A, B, C, step)
# Origin_Kernel = K_conv(Ab, Bb, Cb, L)
# print("Origin Version Kernel is: {}".format(Origin_Kernel))

Origin Version Kernel is: [0.13734084 0.16658424 0.20268662 0.24721982]


In [6]:
def test_cnn_is_rnn(N=3, L=5, step=1.0 / 5):
    
    ssm = Random_SSM(N)
    u = np.array([-1,-2,-3,-4,-5])
    
    # RNN Results
    rec = run_SSM(*ssm, u)

    # CNN Results
    ssmb = discretize(*ssm, step=step)
    # Get Conv Kernel K
    K = K_conv(*ssmb, L)
    # Calculate K * u
    conv = causal_convolution(u, K, True)  
    conv2 = causal_convolution(u, K, False)
    
    # Check results
    print()
    print("RNN result is:", rec.ravel()[0])
    print("CNN(w\o FFT) result is:", conv.ravel()[0])
    print("CNN(w\ FFT) result is:", conv2.ravel()[0])

    return (np.abs((rec.ravel()[0] - conv2.ravel()[0])) < 1e-6)

print("RNN presentation and CNN(w\ FFT) presentation has the same output :", test_cnn_is_rnn())


RNN result is: -2.9878612423736812
CNN(w\o FFT) result is: -2.9878612423736812
CNN(w\ FFT) result is: -2.987861242373682
RNN presentation and CNN(w\ FFT) presentation has the same output : True


## 2. S4 Improvments

In [7]:
# Define Hippo Matrix
def make_HiPPO(N):
    P = np.sqrt(1 + 2 * np.arange(N))
    A = P[:, np.newaxis] * P[np.newaxis, :]  
    A = np.tril(A) - np.diag(np.arange(N))
    return -A

In [17]:
# function K*z naive method
def K_gen_simple(Ab, Bb, Cb, L):
    K = K_conv(Ab, Bb, Cb, L)

    def gen(z):
        return np.sum(K * (z ** np.arange(L)))

    return gen

# function K*z (generation function method)
def K_gen_inverse(Ab, Bb, Cb, L):
    I = np.eye(Ab.shape[0])
    Ab_L = matmul_n_times(Ab, L)
    # C 波浪 = C_bar * (I - A_bar^(L))
    Ct = Cb @ (I - Ab_L)

    return lambda z: (Ct @ np.linalg.inv(I - Ab * z) @ Bb)

def conv_from_gen(gen, L):
    # 创建长度为 L 的单位根
    Omega_L = np.exp((-2j * np.pi) * (np.arange(L) / L))
    
    # 在单位根上计算生成函数的值
    # M^(*) * Y = X
    atRoots = np.vectorize(gen)(Omega_L)

    # 计算逆 FFT 并求得系数
    out = np.fft.ifft(atRoots, L)
    # 返回实部，忽略虚部误差
    return out.real

def test_gen_inverse(L=16, N=4):
    ssm = Random_SSM(N)
    ssm = discretize(*ssm, 1.0 / L)
    Kernel = K_conv(*ssm, L=L)
    
    Kernel_new = conv_from_gen(K_gen_inverse(*ssm, L=L), L)
    print(Kernel)
    print(Kernel_new)
    return np.allclose(Kernel, Kernel_new, atol=1e-3)

In [18]:
test_gen_inverse()

[0.0057784  0.00720001 0.00885825 0.01079064 0.01304066 0.01565868
 0.01870304 0.02224131 0.02635178 0.03112513 0.03666643 0.04309739
 0.05055902 0.05921468 0.06925366 0.08089523]
[0.0057784  0.00720001 0.00885825 0.01079064 0.01304066 0.01565868
 0.01870304 0.02224131 0.02635178 0.03112513 0.03666643 0.04309739
 0.05055902 0.05921468 0.06925366 0.08089523]


True

In [10]:
def make_NPLR_HiPPO(N):
    # Make -HiPPO
    nhippo = make_HiPPO(N)

    # Add in a rank 1 term. Makes it Normal.
    P = np.sqrt(np.arange(N) + 0.5)

    # HiPPO also specifies the B matrix (LegS, A, B)
    B = np.sqrt(2 * np.arange(N) + 1.0)
    return nhippo, P, B

In [11]:
make_NPLR_HiPPO(3)

(array([[-1.        , -0.        , -0.        ],
        [-1.73205081, -2.        , -0.        ],
        [-2.23606798, -3.87298335, -3.        ]]),
 array([0.70710678, 1.22474487, 1.58113883]),
 array([1.        , 1.73205081, 2.23606798]))

In [21]:
def make_DPLR_HiPPO(N):
    """Diagonalize NPLR representation"""
    A, P, B = make_NPLR_HiPPO(N)

    S = A + P[:, np.newaxis] * P[np.newaxis, :]

    # Check skew symmetry
    S_diag = np.diagonal(S)
    Lambda_real = np.mean(S_diag) * np.ones_like(S_diag)
    Eye_Lambda_real = np.diag(Lambda_real)

#     print(S - Eye_Lambda_real)
#     print()
#     print((-1)*(S-Eye_Lambda_real).conj().T)

    print("S w\out diag is a skew-symmetry matrix: {}".format(np.allclose((S-Eye_Lambda_real), (-1)*(S-Eye_Lambda_real).conj().T, atol=1e-3)))

    # Diagonalize S to V \Lambda V^*
    # (1) S = skew + Diag(real)
    # (2) eigs for Skew should be 0 or imaginary number
    # (3) eigs for Diag(real) should be real number
    # (4) eigs for S: real + img; real comes from Diag part, eigs comes from skew part
    
    # eigh: Hermitian Matrix
    Lambda_imag, V = np.linalg.eigh(S)
    P = V.conj().T @ P
    B = V.conj().T @ B
    return Lambda_real + 1j * Lambda_imag, P, B, V

In [22]:
make_DPLR_HiPPO(3)

S w\out diag is a skew-symmetry matrix: True


(array([-0.5-3.17436055j, -0.5+0.21601465j, -0.5+1.4583459j ]),
 array([-2.08934509, -0.26782347,  0.25081405]),
 array([-2.95478016, -0.37875958,  0.35470464]),
 array([[-0.46541345,  0.87669884, -0.12161197],
        [-0.61300759, -0.41839483, -0.67019957],
        [-0.63844501, -0.23737083,  0.73214962]]))

### follow the pipeline to get k matrix

In [14]:
def cauchy(v, omega, lambd):
    """Cauchy matrix multiplication: (n), (l), (n) -> (l)"""
    cauchy_dot = lambda _omega: (v / (_omega - lambd)).sum()
    return np.vectorize(cauchy_dot)(omega)

# [1, 2, 3]
# [4, 5, 6]
# [4, 10, 18] sum
# [4, 10, 18]
def kernel_DPLR(Lambda, P, Q, B, C, step, L):
    # Evaluate at roots of unity
    # Generating function is (-)z-transform, so we evaluate at (-)root
    Omega_L = np.exp((-2j * np.pi) * (np.arange(L) / L))

    aterm = (C.conj(), Q.conj())
    bterm = (B, P)

    g = (2.0 / step) * ((1.0 - Omega_L) / (1.0 + Omega_L))
    c = 2.0 / (1.0 + Omega_L)

    # Reduction to core Cauchy kernel
    k00 = cauchy(aterm[0] * bterm[0], g, Lambda)
    k01 = cauchy(aterm[0] * bterm[1], g, Lambda)
    k10 = cauchy(aterm[1] * bterm[0], g, Lambda)
    k11 = cauchy(aterm[1] * bterm[1], g, Lambda)
    atRoots = c * (k00 - k01 * (1.0 / (1.0 + k11)) * k10)
    out = np.fft.ifft(atRoots, L).reshape(L)
    return out.real

In [24]:
def test_gen_dplr(L=16, N=4):
    I = np.eye(N)

    # Create a DPLR A matrix and discretize
    Lambda, P, B, _ = make_DPLR_HiPPO(N)
    A = np.diag(Lambda) - P[:, np.newaxis] @ P[:, np.newaxis].conj().T
    _, _, C = Random_SSM(N)

    Ab, Bb, Cb = discretize(A, B, C, 1.0 / L)
    # Naive Method to get Kernel
    a = K_conv(Ab, Bb, Cb.conj(), L=L)

    # Compare to the DPLR generating function approach.
    C_t = (I - matmul_n_times(Ab, L)).conj().T @ Cb.ravel()
    b = kernel_DPLR(Lambda, P, P, B, C_t, step=1.0 / L, L=L)
    print(a)
    print(b)
    return np.allclose(a.real, b.real)
test_gen_dplr()

S w\out diag is a skew-symmetry matrix: True
[ 0.0434862 -0.00966923j  0.02054288-0.02152032j  0.00523044-0.02266117j
 -0.00344699-0.01969137j -0.00752933-0.01600404j -0.00890294-0.01299339j
 -0.00891709-0.01099022j -0.00838588-0.00986299j -0.00772568-0.00934778j
 -0.00710559-0.0091985j  -0.00656401-0.00923537j -0.00608363-0.00934525j
 -0.00563252-0.00946493j -0.00518281-0.0095627j  -0.00471653-0.00962414j
 -0.00422552-0.00964291j]
[ 0.0434862   0.02054288  0.00523044 -0.00344699 -0.00752933 -0.00890294
 -0.00891709 -0.00838588 -0.00772568 -0.00710559 -0.00656401 -0.00608363
 -0.00563252 -0.00518281 -0.00471653 -0.00422552]


True

In [None]:
# A(Lambda_real, Lambda_img), B*V, C, step (SSM:(V^-1 * A * V, B * V, C))