# HiPPO Matrices
---

## Load Packages

In [141]:
## import packages
import jax
import jax.numpy as jnp

from jax.nn.initializers import lecun_normal, uniform
from jax.numpy.linalg import eig, inv, matrix_power
from jax.scipy.signal import convolve

import os
import requests

from scipy import linalg as la
from scipy import signal
from scipy import special as ss

import math

## setup JAX to use TPUs if available
try:
    url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
    resp = requests.post(url)
    jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
    jax.config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
except:
    pass

jax.devices()

[CpuDevice(id=0)]

In [None]:
rng = jax.random.PRNGKey(1)

In [None]:
def random_SSM(rng, N):
    a_r, b_r, c_r = jax.random.split(rng, 3)
    A = jax.random.uniform(a_r, (N, N))
    B = jax.random.uniform(b_r, (N, 1))
    C = jax.random.uniform(c_r, (1, N))
    return A, B, C

In [142]:
N=5

## Instantiate The HiPPO Matrix

## Translated Legendre (LegT)

In [143]:
# Translated Legendre (LegT) - vectorized
def build_LegT_V(N, lambda_n=1):
    q = jnp.arange(N, dtype=jnp.float64)
    k, n = jnp.meshgrid(q, q)
    case = jnp.power(-1.0, (n-k))
    A = None
    B = None
    
    if lambda_n == 1:
        A_base = -jnp.sqrt(2*n+1) * jnp.sqrt(2*k+1)
        pre_D = jnp.sqrt(jnp.diag(2*q+1))
        B = D = jnp.diag(pre_D)[:, None]
        A = jnp.where(k <= n, A_base, A_base * case) # if n >= k, then case_2 * A_base is used, otherwise A_base
        
    elif lambda_n == 2: #(jnp.sqrt(2*n+1) * jnp.power(-1, n)):
        A_base = -(2*n+1)
        B = jnp.diag((2*q+1) * jnp.power(-1, n))[:, None]
        A = jnp.where(k <= n, A_base * case, A_base) # if n >= k, then case_2 * A_base is used, otherwise A_base

    return A, B

In [144]:
# Translated Legendre (LegT) - non-vectorized
def build_LegT(N, legt_type="legt"):
    Q = jnp.arange(N, dtype=jnp.float64)
    pre_R = (2*Q + 1)
    k, n = jnp.meshgrid(Q, Q)
        
    if legt_type == "legt":
        R = jnp.sqrt(pre_R)
        A = R[:, None] * jnp.where(n < k, (-1.)**(n-k), 1) * R[None, :]
        B = R[:, None]
        A = -A

        # Halve again for timescale correctness
        # A, B = A/2, B/2
        #A *= 0.5
        #B *= 0.5
        
    elif legt_type == "lmu":
        R = pre_R[:, None]
        A = jnp.where(n < k, -1, (-1.)**(n-k+1)) * R
        B = (-1.)**Q[:, None] * R
        
    return A, B

In [145]:
nv_LegT_A, nv_LegT_B = build_LegT(N=N, legt_type="lmu")
LegT_A, LegT_B = build_LegT_V(N=N, lambda_n=2)
print(f"nv:\n", nv_LegT_A)
print(f"v:\n", LegT_A)
# print(f"nv:\n", nv_LegT_B)
# print(f"v:\n", LegT_B)
print(f"A Comparison:\n ", jnp.allclose(nv_LegT_A, LegT_A))
print(f"B Comparison:\n ", jnp.allclose(nv_LegT_B, LegT_B))

nv:
 [[-1. -1. -1. -1. -1.]
 [ 3. -3. -3. -3. -3.]
 [-5.  5. -5. -5. -5.]
 [ 7. -7.  7. -7. -7.]
 [-9.  9. -9.  9. -9.]]
v:
 [[-1. -1. -1. -1. -1.]
 [ 3. -3. -3. -3. -3.]
 [-5.  5. -5. -5. -5.]
 [ 7. -7.  7. -7. -7.]
 [-9.  9. -9.  9. -9.]]
A Comparison:
  True
B Comparison:
  True


## Translated Laguerre (LagT)

In [146]:
# Translated Laguerre (LagT) - non-vectorized
def build_LagT(alpha, beta, N):
    A = -jnp.eye(N) * (1 + beta) / 2 - jnp.tril(jnp.ones((N, N)), -1)
    B = ss.binom(alpha + jnp.arange(N), jnp.arange(N))[:, None]

    L = jnp.exp(.5 * (ss.gammaln(jnp.arange(N)+alpha+1) - ss.gammaln(jnp.arange(N)+1)))
    A = (1./L[:, None]) * A * L[None, :]
    B = (1./L[:, None]) * B * jnp.exp(-.5 * ss.gammaln(1-alpha)) * beta**((1-alpha)/2)
    
    return A, B

In [147]:
# Translated Laguerre (LagT) - non-vectorized
def build_LagT_V(alpha, beta, N):
    L = jnp.exp(.5 * (ss.gammaln(jnp.arange(N)+alpha+1) - ss.gammaln(jnp.arange(N)+1)))
    inv_L = 1./L[:, None]
    pre_A = (jnp.eye(N) * ((1 + beta) / 2)) + jnp.tril(jnp.ones((N, N)), -1)
    pre_B = ss.binom(alpha + jnp.arange(N), jnp.arange(N))[:, None]
    
    A = -inv_L * pre_A * L[None, :]
    B =  jnp.exp(-.5 * ss.gammaln(1-alpha)) * jnp.power(beta, (1-alpha)/2) * inv_L * pre_B 
    
    return A, B

In [148]:
nv_LagT_A, nv_LagT_B = build_LagT(alpha=0, beta=1, N=N)
LagT_A, LagT_B = build_LagT_V(alpha=0, beta=1, N=N)
print(f"nv:\n{nv_LagT_A}")
print(f"v:\n{LagT_A}")
# print(f"nv:\n{nv_LagT_B}")
# print(f"v:\n{LagT_B}")
print(f"A Comparison:\n ", jnp.allclose(nv_LagT_A, LagT_A))
print(f"B Comparison:\n ", jnp.allclose(nv_LagT_B, LagT_B))

nv:
[[-1. -0. -0. -0. -0.]
 [-1. -1. -0. -0. -0.]
 [-1. -1. -1. -0. -0.]
 [-1. -1. -1. -1. -0.]
 [-1. -1. -1. -1. -1.]]
v:
[[-1. -0. -0. -0. -0.]
 [-1. -1. -0. -0. -0.]
 [-1. -1. -1. -0. -0.]
 [-1. -1. -1. -1. -0.]
 [-1. -1. -1. -1. -1.]]
A Comparison:
  True
B Comparison:
  True


## Scaled Legendre (LegS)

In [149]:
#Scaled Legendre (LegS) vectorized
def build_LegS_V(N):
    q = jnp.arange(N, dtype=jnp.float64)
    k, n = jnp.meshgrid(q, q)
    pre_D = jnp.sqrt(jnp.diag(2*q+1))
    B = D = jnp.diag(pre_D)[:, None]
    
    A_base = (-jnp.sqrt(2*n+1)) * jnp.sqrt(2*k+1)
    case_2 = (n+1)/(2*n+1)
    
    A = jnp.where(n > k, A_base, 0.0) # if n > k, then A_base is used, otherwise 0
    A = jnp.where(n == k, (A_base * case_2), A) # if n == k, then A_base is used, otherwise A
    
    return A, B

In [150]:
#Scaled Legendre (LegS), non-vectorized
def build_LegS(N):
    q = jnp.arange(N, dtype=jnp.float64)  # q represents the values 1, 2, ..., N each column has
    k, n = jnp.meshgrid(q, q)
    r = 2 * q + 1
    M = -(jnp.where(n >= k, r, 0) - jnp.diag(q)) # represents the state matrix M 
    D = jnp.sqrt(jnp.diag(2 * q + 1)) # represents the diagonal matrix D $D := \text{diag}[(2n+1)^{\frac{1}{2}}]^{N-1}_{n=0}$
    A = D @ M @ jnp.linalg.inv(D)
    B = jnp.diag(D)[:, None]
    B = B.copy() # Otherwise "UserWarning: given NumPY array is not writeable..." after torch.as_tensor(B)
    
    return A, B

In [151]:
nv_LegS_A, nv_LegS_B = build_LegS(N=N)
LegS_A, LegS_B = build_LegS_V(N=N)
print(f"nv:\n{nv_LegS_A}")
print(f"v:\n{LegS_A}")
# print(f"nv:\n{nv_LegS_B}")
# print(f"v:\n{LegS_B}")
print(f"A Comparison:\n ", jnp.allclose(nv_LegS_A, LegS_A))
print(f"B Comparison:\n ", jnp.allclose(nv_LegS_B, LegS_B))

nv:
[[-1.         0.         0.         0.         0.       ]
 [-1.7320508 -1.9999999  0.         0.         0.       ]
 [-2.236068  -3.8729835 -3.         0.         0.       ]
 [-2.6457512 -4.582576  -5.9160795 -4.         0.       ]
 [-3.        -5.196152  -6.708204  -7.9372544 -5.       ]]
v:
[[-1.         0.         0.         0.         0.       ]
 [-1.7320508 -2.         0.         0.         0.       ]
 [-2.236068  -3.8729832 -3.         0.         0.       ]
 [-2.6457512 -4.5825753 -5.9160795 -4.         0.       ]
 [-3.        -5.196152  -6.7082043 -7.937254  -5.       ]]
A Comparison:
  True
B Comparison:
  True


## Fourier Basis

In [152]:
# Fourier Basis OPs and functions - vectorized
def build_Fourier_V(N, fourier_type='FRU'):    
    q = jnp.arange((N//2)*2, dtype=jnp.float64)
    k, n = jnp.meshgrid(q, q)
    
    n_odd = n % 2 == 0
    k_odd = k % 2 == 0
    
    case_1 = (n==k) & (n==0)
    case_2_3 = ((k==0) & (n_odd)) | ((n==0) & (k_odd))
    case_4 = (n_odd) & (k_odd)
    case_5 = (n-k==1) & (k_odd)
    case_6 = (k-n==1) & (n_odd)
    
    A = None
    B = None
    
    if fourier_type == "FRU": # Fourier Recurrent Unit (FRU) - vectorized
        A = jnp.diag(jnp.stack([jnp.zeros(N//2), jnp.zeros(N//2)], axis=-1).reshape(-1))
        B = jnp.zeros(A.shape[1], dtype=jnp.float64)
        q = jnp.arange((N//2)*2, dtype=jnp.float64)
        
        A = jnp.where(case_1, -1.0, 
                    jnp.where(case_2_3, -jnp.sqrt(2),
                                jnp.where(case_4, -2, 
                                        jnp.where(case_5, jnp.pi * (n//2), 
                                                    jnp.where(case_6, -jnp.pi * (k//2), 0.0)))))
        
        B = B.at[::2].set(jnp.sqrt(2))
        B = B.at[0].set(1)
        
    elif fourier_type == "FouT": # truncated Fourier (FouT) - vectorized
        A = jnp.diag(jnp.stack([jnp.zeros(N//2), jnp.zeros(N//2)], axis=-1).reshape(-1))
        B = jnp.zeros(A.shape[1], dtype=jnp.float64)
        k, n = jnp.meshgrid(q, q)
        n_odd = n % 2 == 0
        k_odd = k % 2 == 0
        
        A = jnp.where(case_1, -1.0, 
                    jnp.where(case_2_3, -jnp.sqrt(2),
                                jnp.where(case_4, -2, 
                                        jnp.where(case_5, jnp.pi * (n//2), 
                                                    jnp.where(case_6, -jnp.pi * (k//2), 0.0)))))
        
        B = B.at[::2].set(jnp.sqrt(2))
        B = B.at[0].set(1)
        
        A = 2 * A
        B = 2 * B
        
    elif fourier_type == "fourier_decay":
        A = jnp.diag(jnp.stack([jnp.zeros(N//2), jnp.zeros(N//2)], axis=-1).reshape(-1))
        B = jnp.zeros(A.shape[1], dtype=jnp.float64)
        
        A = jnp.where(case_1, -1.0, 
                    jnp.where(case_2_3, -jnp.sqrt(2),
                                jnp.where(case_4, -2, 
                                        jnp.where(case_5, 2 * jnp.pi * (n//2), 
                                                    jnp.where(case_6, 2 * -jnp.pi * (k//2), 0.0)))))
        
        B = B.at[::2].set(jnp.sqrt(2))
        B = B.at[0].set(1)
        
        A = 0.5 * A
        B = 0.5 * B
        
    
    
    B = B[:, None]
        
    return A, B

In [153]:
def build_Fourier(N, fourier_type='FRU'):
    freqs = jnp.arange(N//2)
    
    if fourier_type == "FRU": # Fourier Recurrent Unit (FRU) - non-vectorized
        d = jnp.stack([jnp.zeros(N//2), freqs], axis=-1).reshape(-1)[1:]
        A = jnp.pi*(-jnp.diag(d, 1) + jnp.diag(d, -1))
        
        B = jnp.zeros(A.shape[1])
        B = B.at[0::2].set(jnp.sqrt(2))
        B = B.at[0].set(1)
        
        A = A - B[:, None] * B[None, :]
        B = B[:, None]

    elif fourier_type == "FouT": # truncated Fourier (FouT) - non-vectorized
        freqs *= 2
        d = jnp.stack([jnp.zeros(N//2), freqs], axis=-1).reshape(-1)[1:]
        A = jnp.pi*(-jnp.diag(d, 1) + jnp.diag(d, -1))
        
        B = jnp.zeros(A.shape[1])
        B = B.at[0::2].set(jnp.sqrt(2))
        B = B.at[0].set(1)

        # Subtract off rank correction - this corresponds to the other endpoint u(t-1) in this case
        A = A - B[:, None] * B[None, :] * 2
        B = B[:, None] * 2
        
    elif fourier_type == "fourier_decay":
        d = jnp.stack([jnp.zeros(N//2), freqs], axis=-1).reshape(-1)[1:]
        A = jnp.pi*(-jnp.diag(d, 1) + jnp.diag(d, -1))
        
        B = jnp.zeros(A.shape[1])
        B = B.at[0::2].set(jnp.sqrt(2))
        B = B.at[0].set(1)

        # Subtract off rank correction - this corresponds to the other endpoint u(t-1) in this case
        A = A - 0.5 * B[:, None] * B[None, :]
        B = 0.5 * B[:, None]
        
            
    return A, B

In [154]:
nv_Fourier_A, nv_Fourier_B = build_Fourier(N=N, fourier_type='FRU')
Fourier_A, Fourier_B = build_Fourier_V(N=N, fourier_type='FRU')
print(f"nv:\n{nv_Fourier_A}")
print(f"v:\n{Fourier_A}")
# print(f"nv:\n{nv_Fourier_B}")
# print(f"v:\n{Fourier_B}")
print(f"A Comparison:\n {jnp.allclose(nv_Fourier_A, Fourier_A)}")
print(f"B Comparison:\n {jnp.allclose(nv_Fourier_B, Fourier_B)}")

nv:
[[-1.         0.        -1.4142135  0.       ]
 [ 0.         0.         0.         0.       ]
 [-1.4142135  0.        -1.9999999 -3.1415927]
 [ 0.         0.         3.1415927  0.       ]]
v:
[[-1.        -0.        -1.4142135  0.       ]
 [ 0.         0.         0.         0.       ]
 [-1.4142135  0.        -2.        -3.1415927]
 [ 0.         0.         3.1415927  0.       ]]
A Comparison:
 True
B Comparison:
 True


In [155]:
def make_HiPPO(N, v='nv', HiPPO_type="legs", lambda_n=1, fourier_type="FRU", alpha=0, beta=1):
    A = None
    B = None
    for k in range(1, N + 1):
        for n in range(1, N + 1):
            if HiPPO_type == "legt":
                if v == 'nv':
                    A, B = build_LegT(N=N, lambda_n=lambda_n)
                else:
                 A, B = build_LegT_V(N=N, lambda_n=lambda_n) 
                
            elif HiPPO_type == "lagt":
                if v == 'nv':
                    A, B = build_LagT(alpha=alpha, beta=beta, N=N)
                else:
                    A, B = build_LagT_V(alpha=alpha, beta=beta, N=N)
                
            elif HiPPO_type == "legs":
                if v == 'nv':
                    A, B = build_LegS(N=N)
                else:
                    A, B = build_LegS_V(N=N)
                
            elif HiPPO_type == "fourier":
                if v == 'nv':
                    A, B = build_Fourier(N=N, fourier_type=fourier_type)
                else:
                    A, B = build_Fourier_V(N=N, fourier_type=fourier_type)
                
            elif HiPPO_type == "random":
                A = jnp.random.randn(N, N) / N
                B = jnp.random.randn(N, 1)
                
            elif HiPPO_type == "diagonal":
                A = -jnp.diag(jnp.exp(jnp.random.randn(N)))
                B = jnp.random.randn(N, 1)
                
            else:
                raise ValueError("Invalid HiPPO type")
    
    return -jnp.array(A), B

In [156]:
nv_A, nv_B = make_HiPPO(N=N, v='nv', HiPPO_type="legs", lambda_n=1, fourier_type="FRU", alpha=0, beta=1)
v_A, v_B = make_HiPPO(N=N, v='v', HiPPO_type="legs", lambda_n=1, fourier_type="FRU", alpha=0, beta=1)
print(f"nv:\n", nv_A)
print(f"v:\n", v_A)
# print(f"nv:\n", nv_B)
# print(f"v:\n", v_B)
print(f"A Comparison:\n ", jnp.allclose(nv_A, v_A))
print(f"B Comparison:\n ", jnp.allclose(nv_B, v_B))

nv:
 [[ 1.        -0.        -0.        -0.        -0.       ]
 [ 1.7320508  1.9999999 -0.        -0.        -0.       ]
 [ 2.236068   3.8729835  3.        -0.        -0.       ]
 [ 2.6457512  4.582576   5.9160795  4.        -0.       ]
 [ 3.         5.196152   6.708204   7.9372544  5.       ]]
v:
 [[ 1.        -0.        -0.        -0.        -0.       ]
 [ 1.7320508  2.        -0.        -0.        -0.       ]
 [ 2.236068   3.8729832  3.        -0.        -0.       ]
 [ 2.6457512  4.5825753  5.9160795  4.        -0.       ]
 [ 3.         5.196152   6.7082043  7.937254   5.       ]]
A Comparison:
  True
B Comparison:
  True


In [157]:
def discretize(A, B, C, step, alpha=0.5):
    '''
    - forward Euler corresponds to α = 0,
    - backward Euler corresponds to α = 1,
    - bilinear corresponds to α = 0.5,
    - Zero-order Hold corresponds to α > 1
    '''
    I = jnp.eye(A.shape[0])
    GBT = jnp.linalg.inv(I - ((step * alpha) * A))
    GBT_A = GBT @ (I + ((step * (1-alpha)) * A))
    GBT_B = (step * GBT) @ B
    
    if alpha > 1: # Zero-order Hold
        GBT_A = jax.scipy.linalg.expm(step * A)
        GBT_B = (jnp.linalg.inv(A) @ (jax.scipy.linalg.expm(step * A) - I)) @ B 
    
    return GBT_A, GBT_B, C

In [None]:
def discretize_b(A, B, A_stacked, B_stacked, N, max_length=1024, discretization='bilinear'):
    for t in range(1, max_length + 1):
        At = A / t
        Bt = B / t
        if discretization == 'forward':
            A_stacked[t - 1] = jnp.eye(N) + At
            B_stacked[t - 1] = Bt
        elif discretization == 'backward':
            A_stacked[t - 1] = la.solve_triangular(jnp.eye(N) - At, jnp.eye(N), lower=True)
            B_stacked[t - 1] = la.solve_triangular(jnp.eye(N) - At, Bt, lower=True)
        elif discretization == 'bilinear':
            
            A_stacked[t - 1] = la.solve_triangular(jnp.eye(N) - At / 2, jnp.eye(N) + At / 2, lower=True)
            B_stacked[t - 1] = la.solve_triangular(jnp.eye(N) - At / 2, Bt, lower=True)
        else: # ZOH
            A_stacked[t - 1] = la.expm(A * (math.log(t + 1) - math.log(t)))
            B_stacked[t - 1] = la.solve_triangular(A, A_stacked[t - 1] @ B - B, lower=True)
            
    return A_stacked, B_stacked
        

In [None]:
def collect_SSM_vars(A, B, C, u, alpha=0.5):
    L = u.shape[0]
    N = A.shape[0]
    Ab, Bb, Cb = discretize(A, B, C, step=1.0 / L, alpha=alpha)
    
    return Ab, Bb, Cb

In [None]:
def scan_SSM(Ab, Bb, Cb, u, x0):
    '''
    This is for returning the discretized hidden state often needed for an RNN. 
    '''
    def step(x_k_1, u_k):
        x_k = Ab @ x_k_1 + Bb @ u_k
        y_k = Cb @ x_k
        return x_k, y_k

    return jax.lax.scan(step, x0, u)

In [None]:
class HiPPO_LegT(nn.Module):
    def __init__(self, N, dt=1.0, discretization='bilinear'):
        """
        N: the order of the HiPPO projection
        dt: discretization step size - should be roughly inverse to the length of the sequence
        """
        super().__init__()
        self.N = N
        A, B = transition('lmu', N)
        C = np.ones((1, N))
        D = np.zeros((1,))
        # dt, discretization options
        A, B, _, _, _ = signal.cont2discrete((A, B, C, D), dt=dt, method=discretization)

        B = B.squeeze(-1)

        self.register_buffer('A', torch.Tensor(A)) # (N, N)
        self.register_buffer('B', torch.Tensor(B)) # (N,)

        # vals = np.linspace(0.0, 1.0, 1./dt)
        vals = np.arange(0.0, 1.0, dt)
        self.eval_matrix = torch.Tensor(ss.eval_legendre(np.arange(N)[:, None], 1 - 2 * vals).T)

    def forward(self, inputs):
        """
        inputs : (length, ...)
        output : (length, ..., N) where N is the order of the HiPPO projection
        """

        inputs = inputs.unsqueeze(-1)
        u = inputs * self.B # (length, ..., N)

        c = torch.zeros(u.shape[1:])
        cs = []
        for f in inputs:
            c = F.linear(c, self.A) + self.B * f
            cs.append(c)
        return torch.stack(cs, dim=0)

    def reconstruct(self, c):
        return (self.eval_matrix @ c.unsqueeze(-1)).squeeze(-1)

In [158]:
class HiPPO_LegS(nn.Module):
    """ Vanilla HiPPO-LegS model (scale invariant instead of time invariant) """
    def __init__(self, N, max_length=1024, measure='legs', discretization='bilinear'):
        """
        max_length: maximum sequence length
        """
        super().__init__()
        self.N = N
        A, B = transition(measure, N)
        B = B.squeeze(-1)
        A_stacked = np.empty((max_length, N, N), dtype=A.dtype)
        B_stacked = np.empty((max_length, N), dtype=B.dtype)
        for t in range(1, max_length + 1):
            At = A / t
            Bt = B / t
            if discretization == 'forward':
                A_stacked[t - 1] = np.eye(N) + At
                B_stacked[t - 1] = Bt
            elif discretization == 'backward':
                A_stacked[t - 1] = la.solve_triangular(np.eye(N) - At, np.eye(N), lower=True)
                B_stacked[t - 1] = la.solve_triangular(np.eye(N) - At, Bt, lower=True)
            elif discretization == 'bilinear':
                A_stacked[t - 1] = la.solve_triangular(np.eye(N) - At / 2, np.eye(N) + At / 2, lower=True)
                B_stacked[t - 1] = la.solve_triangular(np.eye(N) - At / 2, Bt, lower=True)
            else: # ZOH
                A_stacked[t - 1] = la.expm(A * (math.log(t + 1) - math.log(t)))
                B_stacked[t - 1] = la.solve_triangular(A, A_stacked[t - 1] @ B - B, lower=True)
        self.A_stacked = torch.Tensor(A_stacked) # (max_length, N, N)
        self.B_stacked = torch.Tensor(B_stacked) # (max_length, N)
        # print("B_stacked shape", B_stacked.shape)

        vals = np.linspace(0.0, 1.0, max_length)
        self.eval_matrix = torch.Tensor((B[:, None] * ss.eval_legendre(np.arange(N)[:, None], 2 * vals - 1)).T)

    def forward(self, inputs, fast=False):
        """
        inputs : (length, ...)
        output : (length, ..., N) where N is the order of the HiPPO projection
        """

        L = inputs.shape[0]

        inputs = inputs.unsqueeze(-1)
        u = torch.transpose(inputs, 0, -2)
        u = u * self.B_stacked[:L]
        u = torch.transpose(u, 0, -2) # (length, ..., N)

        if fast:
            result = unroll.variable_unroll_matrix(self.A_stacked[:L], u)
        else:
            result = unroll.variable_unroll_matrix_sequential(self.A_stacked[:L], u)
        return result

    def reconstruct(self, c):
        a = self.eval_matrix @ c.unsqueeze(-1)
        return a.squeeze(-1)

In [159]:
class HiPPO(nn.Module):
    """HiPPO model"""
    N: int # order of the HiPPO projection, aka the number of coefficients to describe the matrix
    max_length: int # maximum sequence length to be input
    measure: str # the measure used to define which way to instantiate the HiPPO matrix
    step: float # step size used for discretization
    GBT_alpha: float # represents which descretization transformation to use based off the alpha value
    seq_L: int # length of the sequence to be used for training
    
    def setup(self):
        A, B = make_HiPPO(N=self.N, v='v', HiPPO_type="legs", lambda_n=1, fourier_type="FRU", alpha=0, beta=1)
        self.A = A
        self.B = B.squeeze(-1)
        self.C = jnp.ones((1, self.N))
        self.D = jnp.zeros((1,))
        
        if self.measure == "legt":
            L = self.seq_L
            vals = jnp.arange(0.0, 1.0, L)
            self.eval_matrix = jnp.ndarray(ss.eval_legendre(jnp.arange(self.N)[:, None], 1 - 2 * vals).T)
            
        elif self.measure == "legs":
            L = self.max_length
            vals = jnp.linspace(0.0, 1.0, L)
            self.eval_matrix = jnp.ndarray((B[:, None] * ss.eval_legendre(jnp.arange(self.N)[:, None], 2 * vals - 1)).T)
        
    def __call__(self, inputs, kernel=True):
        
        if not kernel:
            Ab, Bb, Cb, Db = self.collect_SSM_vars(self.A, self.B, self.C, u, alpha=self.GBT_alpha)
            c_k = self.scan_SSM(Ab, Bb, Cb, Db, u[:, jnp.newaxis], jnp.zeros((self.N,)))[1]
        else:
            Ab, Bb, Cb, Db = self.discretize(self.A, self.B, self.C, self.D, step=self.step, alpha=self.GBT_alpha)
            c_k = self.causal_convolution(u, self.K_conv(Ab, Bb, Cb, Db, L=self.max_length))
            
        return c_k
    
    def reconstruct(self, c):
        a = self.eval_matrix @ jnp.expand_dims(c, -1)
        return a.squeeze(-1)
    
    def discretize(self, A, B, C, D, step, alpha=0.5):
        '''
        function used for descretizing the HiPPO matrix
        - forward Euler corresponds to α = 0,
        - backward Euler corresponds to α = 1,
        - bilinear corresponds to α = 0.5,
        - Zero-order Hold corresponds to α > 1
        '''
        I = jnp.eye(A.shape[0])
        GBT = jnp.linalg.inv(I - ((step * alpha) * A))
        GBT_A = GBT @ (I + ((step * (1-alpha)) * A))
        GBT_B = (step * GBT) @ B
        
        if alpha > 1: # Zero-order Hold
            GBT_A = jax.scipy.linalg.expm(step * A)
            GBT_B = (jnp.linalg.inv(A) @ (jax.scipy.linalg.expm(step * A) - I)) @ B 
        
        return GBT_A, GBT_B, C, D
    
    def collect_SSM_vars(self, A, B, C, D, u, alpha=0.5):
        L = u.shape[0]
        assert L == self.seq_L
        N = A.shape[0]
        Ab, Bb, Cb, Db = self.discretize(A, B, C, D, step=1.0 / L, alpha=alpha)
    
        return Ab, Bb, Cb, Db
    
    def scan_SSM(self, Ab, Bb, Cb, Db, u, x0):
        '''
        This is for returning the discretized hidden state often needed for an RNN. 
        params:
            Ab: the discretized A matrix
            Bb: the discretized B matrix
            Cb: the discretized C matrix
            u: the input sequence
            x0: the initial hidden state
        returns:
            the next hidden state (aka coefficients representing the function, f(t))
        '''
        def step(x_k_1, u_k):
            '''
            x_k_1: previous hidden state
            u_k: output from function f at, descritized, time step, k.
            returns: 
                x_k: current hidden state
                y_k: current output of hidden state applied to Cb (sorry for being vague, I just dont know yet)
            '''
            x_k = (Ab @ x_k_1) + (Bb @ u_k)
            y_k = (Cb @ x_k) + (Db @ u_k)
            return x_k, y_k

        return jax.lax.scan(step, x0, u)

In [None]:
def test():
    N = 256
    L = 128
    
    x = torch.randn(L, 1)
    # ----------------------------------------------------------------------------------
    # ------------------------------ Test HiPPO LegT model -----------------------------
    # ----------------------------------------------------------------------------------
    hippo = HiPPO_LegT(N, dt=1./L)
    
    y = hippo(x)
    
    print(y.shape)
    z = hippo.reconstruct(y)
    print(z.shape)

    # mse = torch.mean((z[-1,0,:L].flip(-1) - x.squeeze(-1))**2)
    mse = torch.mean((z[-1,0,:L] - x.squeeze(-1))**2)
    print(mse)
    
    # ----------------------------------------------------------------------------------
    # ------------------------------ Test HiPPO LegS model -----------------------------
    # ----------------------------------------------------------------------------------
    # print(y.shape)
    hippo_legs = HiPPO_LegS(N, max_length=L)
    
    y = hippo_legs(x)
    
    print(y.shape)
    
    z = hippo_legs(x, fast=True)
    
    print(hippo_legs.reconstruct(z).shape)
    
    # print(y-z)
    
    # ----------------------------------------------------------------------------------
    # ------------------------------ Test Generic HiPPO model --------------------------
    # ----------------------------------------------------------------------------------
    hippo_LegS = HiPPO(N=N,
                       max_length=L, 
                       measure='legs', 
                       step=1.0/L,, 
                       GBT_alpha=0.5, 
                       seq_L=L)
    
    y_legs = hippo_LegS(x)
    
    print(y_legs.shape)
    
    z_legs = hippo_LegS(x, fast=True)
    
    print(hippo_LegS.reconstruct(z_legs).shape)
    
    
    
    
    

In [160]:
#Aa_1, Bb_1, C_1 = a_func(N, max_length=1024, measure='legs', discretization='bilinear')
Aa_2, Bb_2, C_2 = another_func(N, max_length=1024, measure='legs', discretization='bilinear')
#print(f"A-1:\n{Aa_1}")
# print(f"B-1:\n{Bb_1}")
# print(f"C-1:\n{C_1}")
print(f"A-2:\n{Aa_2}")
# print(f"B-2:\n{Bb_2}")
# print(f"C-2:\n{C_2}")
# print(f"A Comparison:\n ", jnp.allclose(Aa_1, Aa_2))
# print(f"B Comparison:\n ", jnp.allclose(Bb_1, Bb_2))
# print(f"C Comparison:\n ", jnp.allclose(C_1, C_2))

A-2:
[[1.0100503  0.         0.         0.         0.        ]
 [0.01758338 1.0202019  0.         0.         0.        ]
 [0.02316096 0.03971679 1.0304568  0.         0.        ]
 [0.02824333 0.04843212 0.06128747 1.0408163  0.        ]
 [0.03333876 0.05716986 0.07234447 0.08306911 1.0512819 ]]
