# HiPPO Matrices
---

## Load Packages

In [17]:
## import packages
import jax
import jax.numpy as jnp

from flax import linen as jnn

from jax.nn.initializers import lecun_normal, uniform
from jax.numpy.linalg import eig, inv, matrix_power
from jax.scipy.signal import convolve

import os
import requests

from scipy import linalg as la
from scipy import signal
from scipy import special as ss

import math

## setup JAX to use TPUs if available
try:
    url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
    resp = requests.post(url)
    jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
    jax.config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
except:
    pass

jax.devices()

[GpuDevice(id=0, process_index=0)]

In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import numpy as np
import matplotlib.pyplot as plt

In [19]:
seed = 1701
key = jax.random.PRNGKey(seed)

In [20]:
num_copies = 5
rng, key2, key3, key4, key5 = jax.random.split(key, num=num_copies)

In [21]:
def random_SSM(rng, N):
    a_r, b_r, c_r = jax.random.split(rng, 3)
    A = jax.random.uniform(a_r, (N, N))
    B = jax.random.uniform(b_r, (N, 1))
    C = jax.random.uniform(c_r, (1, N))
    return A, B, C

In [22]:
N=5

## Instantiate The HiPPO Matrix

## Translated Legendre (LegT)

In [23]:
# Translated Legendre (LegT) - vectorized
def build_LegT_V(N, lambda_n=1):
    """
        The, vectorized implementation of the, measure derived from the translated Legendre basis.
        
        Args:
            N (int): Order of coefficients to describe the orthogonal polynomial that is the HiPPO projection.
            legt_type (str): Choice between the two different tilts of basis.
                - legt: translated Legendre - 'legt'
                - lmu: Legendre Memory Unit - 'lmu'
            
        Returns:
            A (jnp.ndarray): The A HiPPO matrix.
            B (jnp.ndarray): The B HiPPO matrix.
            
    """
    q = jnp.arange(N, dtype=jnp.float64)
    k, n = jnp.meshgrid(q, q)
    case = jnp.power(-1.0, (n-k))
    A = None
    B = None
    
    if lambda_n == 1:
        A_base = -jnp.sqrt(2*n+1) * jnp.sqrt(2*k+1)
        pre_D = jnp.sqrt(jnp.diag(2*q+1))
        B = D = jnp.diag(pre_D)[:, None]
        A = jnp.where(k <= n, A_base, A_base * case) # if n >= k, then case_2 * A_base is used, otherwise A_base
        
    elif lambda_n == 2: #(jnp.sqrt(2*n+1) * jnp.power(-1, n)):
        A_base = -(2*n+1)
        B = jnp.diag((2*q+1) * jnp.power(-1, n))[:, None]
        A = jnp.where(k <= n, A_base * case, A_base) # if n >= k, then case_2 * A_base is used, otherwise A_base

    return A, B

In [24]:
# Translated Legendre (LegT) - non-vectorized
def build_LegT(N, legt_type="legt"):
    """
        The, non-vectorized implementation of the, measure derived from the translated Legendre basis
        
        Args:
            N (int): Order of coefficients to describe the orthogonal polynomial that is the HiPPO projection.
            legt_type (str): Choice between the two different tilts of basis.
                - legt: translated Legendre - 'legt'
                - lmu: Legendre Memory Unit - 'lmu'
            
        Returns:
            A (jnp.ndarray): The A HiPPO matrix.
            B (jnp.ndarray): The B HiPPO matrix.
            
    """
    Q = jnp.arange(N, dtype=jnp.float64)
    pre_R = (2*Q + 1)
    k, n = jnp.meshgrid(Q, Q)
        
    if legt_type == "legt":
        R = jnp.sqrt(pre_R)
        A = R[:, None] * jnp.where(n < k, (-1.)**(n-k), 1) * R[None, :]
        B = R[:, None]
        A = -A

        # Halve again for timescale correctness
        # A, B = A/2, B/2
        #A *= 0.5
        #B *= 0.5
        
    elif legt_type == "lmu":
        R = pre_R[:, None]
        A = jnp.where(n < k, -1, (-1.)**(n-k+1)) * R
        B = (-1.)**Q[:, None] * R
        
    return A, B

In [25]:
nv_LegT_A, nv_LegT_B = build_LegT(N=N, legt_type="lmu")
LegT_A, LegT_B = build_LegT_V(N=N, lambda_n=2)
print(f"nv:\n", nv_LegT_A)
print(f"v:\n", LegT_A)
# print(f"nv:\n", nv_LegT_B)
# print(f"v:\n", LegT_B)
print(f"A Comparison:\n ", jnp.allclose(nv_LegT_A, LegT_A))
print(f"B Comparison:\n ", jnp.allclose(nv_LegT_B, LegT_B))

nv:
 [[-1. -1. -1. -1. -1.]
 [ 3. -3. -3. -3. -3.]
 [-5.  5. -5. -5. -5.]
 [ 7. -7.  7. -7. -7.]
 [-9.  9. -9.  9. -9.]]
v:
 [[-1. -1. -1. -1. -1.]
 [ 3. -3. -3. -3. -3.]
 [-5.  5. -5. -5. -5.]
 [ 7. -7.  7. -7. -7.]
 [-9.  9. -9.  9. -9.]]
A Comparison:
  True
B Comparison:
  True


## Translated Laguerre (LagT)

In [26]:
# Translated Laguerre (LagT) - non-vectorized
def build_LagT_V(alpha, beta, N):
    """
        The, vectorized implementation of the, measure derived from the translated Laguerre basis. 
        
        Args:
            alpha (float): The order of the Laguerre basis.
            beta (float): The scale of the Laguerre basis.
            N (int): Order of coefficients to describe the orthogonal polynomial that is the HiPPO projection.
            
        Returns:
            A (jnp.ndarray): The A HiPPO matrix.
            B (jnp.ndarray): The B HiPPO matrix.
            
    """
    L = jnp.exp(.5 * (ss.gammaln(jnp.arange(N)+alpha+1) - ss.gammaln(jnp.arange(N)+1)))
    inv_L = 1./L[:, None]
    pre_A = (jnp.eye(N) * ((1 + beta) / 2)) + jnp.tril(jnp.ones((N, N)), -1)
    pre_B = ss.binom(alpha + jnp.arange(N), jnp.arange(N))[:, None]
    
    A = -inv_L * pre_A * L[None, :]
    B =  jnp.exp(-.5 * ss.gammaln(1-alpha)) * jnp.power(beta, (1-alpha)/2) * inv_L * pre_B 
    
    return A, B

In [27]:
# Translated Laguerre (LagT) - non-vectorized
def build_LagT(alpha, beta, N):
    """
        The, non-vectorized implementation of the, measure derived from the translated Laguerre basis. 
        
        Args:
            alpha (float): The order of the Laguerre basis.
            beta (float): The scale of the Laguerre basis.
            N (int): Order of coefficients to describe the orthogonal polynomial that is the HiPPO projection.
            
        Returns:
            A (jnp.ndarray): The A HiPPO matrix.
            B (jnp.ndarray): The B HiPPO matrix.
            
    """
    A = -jnp.eye(N) * (1 + beta) / 2 - jnp.tril(jnp.ones((N, N)), -1)
    B = ss.binom(alpha + jnp.arange(N), jnp.arange(N))[:, None]

    L = jnp.exp(.5 * (ss.gammaln(jnp.arange(N)+alpha+1) - ss.gammaln(jnp.arange(N)+1)))
    A = (1./L[:, None]) * A * L[None, :]
    B = (1./L[:, None]) * B * jnp.exp(-.5 * ss.gammaln(1-alpha)) * beta**((1-alpha)/2)
    
    return A, B

In [28]:
nv_LagT_A, nv_LagT_B = build_LagT(alpha=0, beta=1, N=N)
LagT_A, LagT_B = build_LagT_V(alpha=0, beta=1, N=N)
print(f"nv:\n{nv_LagT_A}")
print(f"v:\n{LagT_A}")
# print(f"nv:\n{nv_LagT_B}")
# print(f"v:\n{LagT_B}")
print(f"A Comparison:\n ", jnp.allclose(nv_LagT_A, LagT_A))
print(f"B Comparison:\n ", jnp.allclose(nv_LagT_B, LagT_B))

nv:
[[-1. -0. -0. -0. -0.]
 [-1. -1. -0. -0. -0.]
 [-1. -1. -1. -0. -0.]
 [-1. -1. -1. -1. -0.]
 [-1. -1. -1. -1. -1.]]
v:
[[-1. -0. -0. -0. -0.]
 [-1. -1. -0. -0. -0.]
 [-1. -1. -1. -0. -0.]
 [-1. -1. -1. -1. -0.]
 [-1. -1. -1. -1. -1.]]
A Comparison:
  True
B Comparison:
  True


## Scaled Legendre (LegS)

In [29]:
#Scaled Legendre (LegS) vectorized
def build_LegS_V(N):
    """
        The, vectorized implementation of the, measure derived from the Scaled Legendre basis. 
        
        Args:
            N (int): Order of coefficients to describe the orthogonal polynomial that is the HiPPO projection.
            
        Returns:
            A (jnp.ndarray): The A HiPPO matrix.
            B (jnp.ndarray): The B HiPPO matrix.
            
    """
    q = jnp.arange(N, dtype=jnp.float64)
    k, n = jnp.meshgrid(q, q)
    pre_D = jnp.sqrt(jnp.diag(2*q+1))
    B = D = jnp.diag(pre_D)[:, None]
    
    A_base = (-jnp.sqrt(2*n+1)) * jnp.sqrt(2*k+1)
    case_2 = (n+1)/(2*n+1)
    
    A = jnp.where(n > k, A_base, 0.0) # if n > k, then A_base is used, otherwise 0
    A = jnp.where(n == k, (A_base * case_2), A) # if n == k, then A_base is used, otherwise A

    return A, B

In [30]:
#Scaled Legendre (LegS), non-vectorized
def build_LegS(N):
    """
        The, non-vectorized implementation of the, measure derived from the Scaled Legendre basis. 
        
        Args:
            N (int): Order of coefficients to describe the orthogonal polynomial that is the HiPPO projection.
            
        Returns:
            A (jnp.ndarray): The A HiPPO matrix.
            B (jnp.ndarray): The B HiPPO matrix.
            
    """
    q = jnp.arange(N, dtype=jnp.float64)  # q represents the values 1, 2, ..., N each column has
    k, n = jnp.meshgrid(q, q)
    r = 2 * q + 1
    M = -(jnp.where(n >= k, r, 0) - jnp.diag(q)) # represents the state matrix M 
    D = jnp.sqrt(jnp.diag(2 * q + 1)) # represents the diagonal matrix D $D := \text{diag}[(2n+1)^{\frac{1}{2}}]^{N-1}_{n=0}$
    A = D @ M @ jnp.linalg.inv(D)
    B = jnp.diag(D)[:, None]
    B = B.copy() # Otherwise "UserWarning: given NumPY array is not writeable..." after torch.as_tensor(B)
    
    return A, B

In [31]:
nv_LegS_A, nv_LegS_B = build_LegS(N=N)
LegS_A, LegS_B = build_LegS_V(N=N)
print(f"nv:\n{nv_LegS_A}")
print(f"v:\n{LegS_A}")
# print(f"nv:\n{nv_LegS_B}")
# print(f"v:\n{LegS_B}")
print(f"A Comparison:\n ", jnp.allclose(nv_LegS_A, LegS_A))
print(f"B Comparison:\n ", jnp.allclose(nv_LegS_B, LegS_B))

nv:
[[-1.         0.         0.         0.         0.       ]
 [-1.7324219 -1.9997292  0.         0.         0.       ]
 [-2.2363281 -3.873207  -3.0015717  0.         0.       ]
 [-2.6464844 -4.58337   -5.919281  -4.00074    0.       ]
 [-3.        -5.194336  -6.7089844 -7.9365234 -4.9987793]]
v:
[[-1.         0.         0.         0.         0.       ]
 [-1.7320508 -2.         0.         0.         0.       ]
 [-2.2360678 -3.872983  -2.9999995  0.         0.       ]
 [-2.6457512 -4.5825753 -5.916079  -4.         0.       ]
 [-3.        -5.196152  -6.7082033 -7.937254  -5.       ]]
A Comparison:
  False
B Comparison:
  True


## Fourier Basis

In [32]:
# Fourier Basis OPs and functions - vectorized
def build_Fourier_V(N, fourier_type='fru'): 
    """
        Vectorized measure implementations derived from fourier basis. 
        
        Args:
            N (int): Order of coefficients to describe the orthogonal polynomial that is the HiPPO projection.
            fourier_type (str): The type of Fourier measure.
                - FRU: Fourier Recurrent Unit - fru
                - FouT: truncated Fourier - fout
                - fouD: decayed Fourier - foud
            
        Returns:
            A (jnp.ndarray): The A HiPPO matrix.
            B (jnp.ndarray): The B HiPPO matrix.
            
    """   
    q = jnp.arange((N//2)*2, dtype=jnp.float64)
    k, n = jnp.meshgrid(q, q)
    
    n_odd = n % 2 == 0
    k_odd = k % 2 == 0
    
    case_1 = (n==k) & (n==0)
    case_2_3 = ((k==0) & (n_odd)) | ((n==0) & (k_odd))
    case_4 = (n_odd) & (k_odd)
    case_5 = (n-k==1) & (k_odd)
    case_6 = (k-n==1) & (n_odd)
    
    A = None
    B = None
    
    if fourier_type == "fru": # Fourier Recurrent Unit (FRU) - vectorized
        A = jnp.diag(jnp.stack([jnp.zeros(N//2), jnp.zeros(N//2)], axis=-1).reshape(-1))
        B = jnp.zeros(A.shape[1], dtype=jnp.float64)
        q = jnp.arange((N//2)*2, dtype=jnp.float64)
        
        A = jnp.where(case_1, -1.0, 
                    jnp.where(case_2_3, -jnp.sqrt(2),
                                jnp.where(case_4, -2, 
                                        jnp.where(case_5, jnp.pi * (n//2), 
                                                    jnp.where(case_6, -jnp.pi * (k//2), 0.0)))))
        
        B = B.at[::2].set(jnp.sqrt(2))
        B = B.at[0].set(1)
        
    elif fourier_type == "fout": # truncated Fourier (FouT) - vectorized
        A = jnp.diag(jnp.stack([jnp.zeros(N//2), jnp.zeros(N//2)], axis=-1).reshape(-1))
        B = jnp.zeros(A.shape[1], dtype=jnp.float64)
        k, n = jnp.meshgrid(q, q)
        n_odd = n % 2 == 0
        k_odd = k % 2 == 0
        
        A = jnp.where(case_1, -1.0, 
                    jnp.where(case_2_3, -jnp.sqrt(2),
                                jnp.where(case_4, -2, 
                                        jnp.where(case_5, jnp.pi * (n//2), 
                                                    jnp.where(case_6, -jnp.pi * (k//2), 0.0)))))
        
        B = B.at[::2].set(jnp.sqrt(2))
        B = B.at[0].set(1)
        
        A = 2 * A
        B = 2 * B
        
    elif fourier_type == "fourd":
        A = jnp.diag(jnp.stack([jnp.zeros(N//2), jnp.zeros(N//2)], axis=-1).reshape(-1))
        B = jnp.zeros(A.shape[1], dtype=jnp.float64)
        
        A = jnp.where(case_1, -1.0, 
                    jnp.where(case_2_3, -jnp.sqrt(2),
                                jnp.where(case_4, -2, 
                                        jnp.where(case_5, 2 * jnp.pi * (n//2), 
                                                    jnp.where(case_6, 2 * -jnp.pi * (k//2), 0.0)))))
        
        B = B.at[::2].set(jnp.sqrt(2))
        B = B.at[0].set(1)
        
        A = 0.5 * A
        B = 0.5 * B
        
    
    
    B = B[:, None]
        
    return A, B

In [33]:
def build_Fourier(N, fourier_type='fru'):
    """
        Non-vectorized measure implementations derived from fourier basis. 
        
        Args:
            N (int): Order of coefficients to describe the orthogonal polynomial that is the HiPPO projection.
            fourier_type (str): The type of Fourier measure.
                - FRU: Fourier Recurrent Unit - fru
                - FouT: truncated Fourier - fout
                - fouD: decayed Fourier - foud
            
        Returns:
            A (jnp.ndarray): The A HiPPO matrix.
            B (jnp.ndarray): The B HiPPO matrix.
            
    """
    freqs = jnp.arange(N//2)
    
    if fourier_type == "fru": # Fourier Recurrent Unit (FRU) - non-vectorized
        d = jnp.stack([jnp.zeros(N//2), freqs], axis=-1).reshape(-1)[1:]
        A = jnp.pi*(-jnp.diag(d, 1) + jnp.diag(d, -1))
        
        B = jnp.zeros(A.shape[1])
        B = B.at[0::2].set(jnp.sqrt(2))
        B = B.at[0].set(1)
        
        A = A - B[:, None] * B[None, :]
        B = B[:, None]

    elif fourier_type == "fout": # truncated Fourier (FouT) - non-vectorized
        freqs *= 2
        d = jnp.stack([jnp.zeros(N//2), freqs], axis=-1).reshape(-1)[1:]
        A = jnp.pi*(-jnp.diag(d, 1) + jnp.diag(d, -1))
        
        B = jnp.zeros(A.shape[1])
        B = B.at[0::2].set(jnp.sqrt(2))
        B = B.at[0].set(1)

        # Subtract off rank correction - this corresponds to the other endpoint u(t-1) in this case
        A = A - B[:, None] * B[None, :] * 2
        B = B[:, None] * 2
        
    elif fourier_type == "fourd":
        d = jnp.stack([jnp.zeros(N//2), freqs], axis=-1).reshape(-1)[1:]
        A = jnp.pi*(-jnp.diag(d, 1) + jnp.diag(d, -1))
        
        B = jnp.zeros(A.shape[1])
        B = B.at[0::2].set(jnp.sqrt(2))
        B = B.at[0].set(1)

        # Subtract off rank correction - this corresponds to the other endpoint u(t-1) in this case
        A = A - 0.5 * B[:, None] * B[None, :]
        B = 0.5 * B[:, None]
        
            
    return A, B

In [34]:
nv_Fourier_A, nv_Fourier_B = build_Fourier(N=N, fourier_type='fru')
Fourier_A, Fourier_B = build_Fourier_V(N=N, fourier_type='fru')
print(f"nv:\n{nv_Fourier_A}")
print(f"v:\n{Fourier_A}")
# print(f"nv:\n{nv_Fourier_B}")
# print(f"v:\n{Fourier_B}")
print(f"A Comparison:\n {jnp.allclose(nv_Fourier_A, Fourier_A)}")
print(f"B Comparison:\n {jnp.allclose(nv_Fourier_B, Fourier_B)}")

  lax_internal._check_user_dtype_supported(dtype, "zeros")


nv:
[[-1.         0.        -1.4142135  0.       ]
 [ 0.         0.         0.         0.       ]
 [-1.4142135  0.        -1.9999999 -3.1415927]
 [ 0.         0.         3.1415927  0.       ]]
v:
[[-1.        -0.        -1.4142135  0.       ]
 [ 0.         0.         0.         0.       ]
 [-1.4142135  0.        -2.        -3.1415927]
 [ 0.         0.         3.1415927  0.       ]]
A Comparison:
 True
B Comparison:
 True


In [35]:
def make_HiPPO(N, v='nv', measure="legs", lambda_n=1, fourier_type="fru", alpha=0, beta=1):
    """
        Instantiates the HiPPO matrix of a given order using a particular measure. 
        Args:
            N (int): Order of coefficients to describe the orthogonal polynomial that is the HiPPO projection.
            v (str): choose between this repo's implementation or hazy research's implementation.
            measure (str): 
                choose between 
                    - HiPPO w/ Translated Legendre (LegT) - legt
                    - HiPPO w/ Translated Laguerre (LagT) - lagt
                    - HiPPO w/ Scaled Legendre (LegS) - legs
                    - HiPPO w/ Fourier basis - fourier
                        - FRU: Fourier Recurrent Unit 
                        - FouT: Translated Fourier 
            lambda_n (int): The amount of tilt applied to the HiPPO-LegS basis, determines between LegS and LMU. 
            fourier_type (str): chooses between the following:
                - FRU: Fourier Recurrent Unit - fru
                - FouT: Translated Fourier - fout
                - FourD: Fourier Decay - fourd
            alpha (float): The order of the Laguerre basis.
            beta (float): The scale of the Laguerre basis.
            
        Returns:
            A (jnp.ndarray): The HiPPO matrix multiplied by -1.
            B (jnp.ndarray): The other corresponding state space matrix. 
            
    """
    A = None
    B = None
    if measure == "legt":
        if v == 'nv':
            A, B = build_LegT(N=N, lambda_n=lambda_n)
        else:
            A, B = build_LegT_V(N=N, lambda_n=lambda_n) 
        
    elif measure == "lagt":
        if v == 'nv':
            A, B = build_LagT(alpha=alpha, beta=beta, N=N)
        else:
            A, B = build_LagT_V(alpha=alpha, beta=beta, N=N)
        
    elif measure == "legs":
        if v == 'nv':
            A, B = build_LegS(N=N)
        else:
            A, B = build_LegS_V(N=N)
        
    elif measure == "fourier":
        if v == 'nv':
            A, B = build_Fourier(N=N, fourier_type=fourier_type)
        else:
            A, B = build_Fourier_V(N=N, fourier_type=fourier_type)
        
    elif measure == "random":
        A = jnp.random.randn(N, N) / N
        B = jnp.random.randn(N, 1)
        
    elif measure == "diagonal":
        A = -jnp.diag(jnp.exp(jnp.random.randn(N)))
        B = jnp.random.randn(N, 1)
        
    else:
        raise ValueError("Invalid HiPPO type")
    
    A_copy = A.copy()
    B_copy = B.copy()
    
    return jnp.array(A_copy), B_copy

In [36]:
nv_A, nv_B = make_HiPPO(N=N, v='nv', measure="legs", lambda_n=1, fourier_type="fru", alpha=0, beta=1)
v_A, v_B = make_HiPPO(N=N, v='v', measure="legs", lambda_n=1, fourier_type="fru", alpha=0, beta=1)
print(f"nv:\n", nv_A)
print(f"v:\n", v_A)
# print(f"nv:\n", nv_B)
# print(f"v:\n", v_B)
print(f"A Comparison:\n ", jnp.allclose(nv_A, v_A))
print(f"B Comparison:\n ", jnp.allclose(nv_B, v_B))

nv:
 [[-1.         0.         0.         0.         0.       ]
 [-1.7324219 -1.9997292  0.         0.         0.       ]
 [-2.2363281 -3.873207  -3.0015717  0.         0.       ]
 [-2.6464844 -4.58337   -5.919281  -4.00074    0.       ]
 [-3.        -5.194336  -6.7089844 -7.9365234 -4.9987793]]
v:
 [[-1.         0.         0.         0.         0.       ]
 [-1.7320508 -2.         0.         0.         0.       ]
 [-2.2360678 -3.872983  -2.9999995  0.         0.       ]
 [-2.6457512 -4.5825753 -5.916079  -4.         0.       ]
 [-3.        -5.196152  -6.7082033 -7.937254  -5.       ]]
A Comparison:
  False
B Comparison:
  True


  lax_internal._check_user_dtype_supported(dtype, "arange")


In [37]:
""" Old utilities for parallel scan implementation of Linear RNNs. """

import torch.nn.functional as F

### Utilities


def shift_up(a, s=None, drop=True, dim=0):
    assert dim == 0
    if s is None:
        s = torch.zeros_like(a[0, ...])
    s = s.unsqueeze(dim)
    if drop:
        a = a[:-1, ...]
    return torch.cat((s, a), dim=dim)

def interleave(a, b, uneven=False, dim=0):
    """ Interleave two tensors of same shape """
    # assert(a.shape == b.shape)
    assert dim == 0 # TODO temporary to make handling uneven case easier
    if dim < 0:
        dim = N + dim
    if uneven:
        a_ = a[-1:, ...]
        a = a[:-1, ...]
    c = torch.stack((a, b), dim+1)
    out_shape = list(a.shape)
    out_shape[dim] *= 2
    c = c.view(out_shape)
    if uneven:
        c = torch.cat((c, a_), dim=dim)
    return c

def batch_mult(A, u, has_batch=None):
    """ Matrix mult A @ u with special case to save memory if u has additional batch dim

    The batch dimension is assumed to be the second dimension
    A : (L, ..., N, N)
    u : (L, [B], ..., N)
    has_batch: True, False, or None. If None, determined automatically

    Output:
    x : (L, [B], ..., N)
      A @ u broadcasted appropriately
    """

    if has_batch is None:
        has_batch = len(u.shape) >= len(A.shape)

    if has_batch:
        u = u.permute([0] + list(range(2, len(u.shape))) + [1])
    else:
        u = u.unsqueeze(-1)
    v = (A @ u)
    if has_batch:
        v = v.permute([0] + [len(u.shape)-1] + list(range(1, len(u.shape)-1)))
    else:
        v = v[..., 0]
    return v



### Main unrolling functions

def unroll(A, u):
    """
    A : (..., N, N) # TODO I think this can't take batch dimension?
    u : (L, ..., N)
    output : x (..., N) # TODO a lot of these shapes are wrong
    x[i, ...] = A^{i} @ u[0, ...] + ... + A @ u[i-1, ...] + u[i, ...]
    """

    m = u.new_zeros(u.shape[1:])
    outputs = []
    for u_ in torch.unbind(u, dim=0):
        m = F.linear(m, A) + u_
        outputs.append(m)

    output = torch.stack(outputs, dim=0)
    return output


def parallel_unroll_recursive(A, u):
    """ Bottom-up divide-and-conquer version of unroll. """

    # Main recursive function
    def parallel_unroll_recursive_(A, u):
        if u.shape[0] == 1:
            return u

        u_evens = u[0::2, ...]
        u_odds = u[1::2, ...]

        # u2 = F.linear(u_evens, A) + u_odds
        u2 = (A @ u_evens.unsqueeze(-1)).squeeze(-1) + u_odds
        A2 = A @ A

        x_odds = parallel_unroll_recursive_(A2, u2)
        # x_evens = F.linear(shift_up(x_odds), A) + u_evens
        x_evens = (A @ shift_up(x_odds).unsqueeze(-1)).squeeze(-1) + u_evens

        x = interleave(x_evens, x_odds, dim=0)
        return x

    # Pad u to power of 2
    n = u.shape[0]
    m = int(math.ceil(math.log(n)/math.log(2)))
    N = 1 << m
    u = torch.cat((u, u.new_zeros((N-u.shape[0],) + u.shape[1:] )), dim=0)

    return parallel_unroll_recursive_(A, u)[:n, ...]



def parallel_unroll_recursive_br(A, u):
    """ Same as parallel_unroll_recursive but uses bit reversal for locality. """

    # Main recursive function
    def parallel_unroll_recursive_br_(A, u):
        n = u.shape[0]
        if n == 1:
            return u

        m = n//2
        u_0 = u[:m, ...]
        u_1 = u[m:, ...]

        u2 = F.linear(u_0, A) + u_1
        A2 = A @ A

        x_1 = parallel_unroll_recursive_br_(A2, u2)
        x_0 = F.linear(shift_up(x_1), A) + u_0

        # x = torch.cat((x_0, x_1), dim=0) # is there a way to do this with cat?
        x = interleave(x_0, x_1, dim=0)
        return x

    # Pad u to power of 2
    n = u.shape[0]
    m = int(math.ceil(math.log(n)/math.log(2)))
    N = 1 << m
    u = torch.cat((u, u.new_zeros((N-u.shape[0],) + u.shape[1:] )), dim=0)

    # Apply bit reversal
    br = bitreversal_po2(N)
    u = u[br, ...]

    x = parallel_unroll_recursive_br_(A, u)
    return x[:n, ...]

def parallel_unroll_iterative(A, u):
    """ Bottom-up divide-and-conquer version of unroll, implemented iteratively """

    # Pad u to power of 2
    n = u.shape[0]
    m = int(math.ceil(math.log(n)/math.log(2)))
    N = 1 << m
    u = torch.cat((u, u.new_zeros((N-u.shape[0],) + u.shape[1:] )), dim=0)

    # Apply bit reversal
    br = bitreversal_po2(N)
    u = u[br, ...]

    # Main recursive loop, flattened
    us = [] # stores the u_0 terms in the recursive version
    N_ = N
    As = [] # stores the A matrices
    for l in range(m):
        N_ = N_ // 2
        As.append(A)
        u_0 = u[:N_, ...]
        us.append(u_0)
        u = F.linear(u_0, A) + u[N_:, ...]
        A = A @ A
    x_0 = []
    x = u # x_1
    for l in range(m-1, -1, -1):
        x_0 = F.linear(shift_up(x), As[l]) + us[l]
        x = interleave(x_0, x, dim=0)

    return x[:n, ...]


def variable_unroll_sequential(A, u, s=None, variable=True):
    """ Unroll with variable (in time/length) transitions A.

    A : ([L], ..., N, N) dimension L should exist iff variable is True
    u : (L, [B], ..., N) updates
    s : ([B], ..., N) start state
    output : x (..., N)
    x[i, ...] = A[i]..A[0] @ s + A[i..1] @ u[0] + ... + A[i] @ u[i-1] + u[i]
    """

    if s is None:
        s = torch.zeros_like(u[0])

    if not variable:
        A = A.expand((u.shape[0],) + A.shape)
    has_batch = len(u.shape) >= len(A.shape)

    outputs = []
    for (A_, u_) in zip(torch.unbind(A, dim=0), torch.unbind(u, dim=0)):
        s = batch_mult(A_.unsqueeze(0), s.unsqueeze(0), has_batch)[0]
        s = s + u_
        outputs.append(s)

    output = torch.stack(outputs, dim=0)
    return output



def variable_unroll(A, u, s=None, variable=True, recurse_limit=16):
    """ Bottom-up divide-and-conquer version of variable_unroll. """

    if u.shape[0] <= recurse_limit:
        return variable_unroll_sequential(A, u, s, variable)

    if s is None:
        s = torch.zeros_like(u[0])

    uneven = u.shape[0] % 2 == 1
    has_batch = len(u.shape) >= len(A.shape)

    u_0 = u[0::2, ...]
    u_1  = u[1::2, ...]

    if variable:
        A_0 = A[0::2, ...]
        A_1  = A[1::2, ...]
    else:
        A_0 = A
        A_1 = A

    u_0_ = u_0
    A_0_ = A_0
    if uneven:
        u_0_ = u_0[:-1, ...]
        if variable:
            A_0_ = A_0[:-1, ...]

    u_10 = batch_mult(A_1, u_0_, has_batch)
    u_10 = u_10 + u_1
    A_10 = A_1 @ A_0_

    # Recursive call
    x_1 = variable_unroll(A_10, u_10, s, variable, recurse_limit)

    x_0 = shift_up(x_1, s, drop=not uneven)
    x_0 = batch_mult(A_0, x_0, has_batch)
    x_0 = x_0 + u_0


    x = interleave(x_0, x_1, uneven, dim=0) # For some reason this interleave is slower than in the (non-multi) unroll_recursive
    return x

def variable_unroll_general_sequential(A, u, s, op, variable=True):
    """ Unroll with variable (in time/length) transitions A with general associative operation

    A : ([L], ..., N, N) dimension L should exist iff variable is True
    u : (L, [B], ..., N) updates
    s : ([B], ..., N) start state
    output : x (..., N)
    x[i, ...] = A[i]..A[0] s + A[i..1] u[0] + ... + A[i] u[i-1] + u[i]
    """

    if not variable:
        A = A.expand((u.shape[0],) + A.shape)

    outputs = []
    for (A_, u_) in zip(torch.unbind(A, dim=0), torch.unbind(u, dim=0)):
        s = op(A_, s)
        s = s + u_
        outputs.append(s)

    output = torch.stack(outputs, dim=0)
    return output

def variable_unroll_matrix_sequential(A, u, s=None, variable=True):
    if s is None:
        s = torch.zeros_like(u[0])

    if not variable:
        A = A.expand((u.shape[0],) + A.shape)
    # has_batch = len(u.shape) >= len(A.shape)

    # op = lambda x, y: batch_mult(x.unsqueeze(0), y.unsqueeze(0), has_batch)[0]
    op = lambda x, y: batch_mult(x.unsqueeze(0), y.unsqueeze(0))[0]

    return variable_unroll_general_sequential(A, u, s, op, variable=True)

def variable_unroll_toeplitz_sequential(A, u, s=None, variable=True, pad=False):
    if s is None:
        s = torch.zeros_like(u[0])

    if not variable:
        A = A.expand((u.shape[0],) + A.shape)
    # has_batch = len(u.shape) >= len(A.shape)

    # op = lambda x, y: batch_mult(x.unsqueeze(0), y.unsqueeze(0), has_batch)[0]
    # op = lambda x, y: batch_mult(x.unsqueeze(0), y.unsqueeze(0))[0]

    if pad:
        n = A.shape[-1]
        A = F.pad(A, (0, n))
        u = F.pad(u, (0, n))
        s = F.pad(s, (0, n))
        ret = variable_unroll_general_sequential(A, u, s, triangular_toeplitz_multiply_padded, variable=True)
        ret = ret[..., :n]
        return ret

    return variable_unroll_general_sequential(A, u, s, triangular_toeplitz_multiply, variable=True)



### General parallel scan functions with generic binary composition operators

def variable_unroll_general(A, u, s, op, compose_op=None, sequential_op=None, variable=True, recurse_limit=16):
    """ Bottom-up divide-and-conquer version of variable_unroll.

    compose is an optional function that defines how to compose A without multiplying by a leaf u
    """

    if u.shape[0] <= recurse_limit:
        if sequential_op is None:
            sequential_op = op
        return variable_unroll_general_sequential(A, u, s, sequential_op, variable)

    if compose_op is None:
        compose_op = op

    uneven = u.shape[0] % 2 == 1
    has_batch = len(u.shape) >= len(A.shape)

    u_0 = u[0::2, ...]
    u_1 = u[1::2, ...]

    if variable:
        A_0 = A[0::2, ...]
        A_1 = A[1::2, ...]
    else:
        A_0 = A
        A_1 = A

    u_0_ = u_0
    A_0_ = A_0
    if uneven:
        u_0_ = u_0[:-1, ...]
        if variable:
            A_0_ = A_0[:-1, ...]

    u_10 = op(A_1, u_0_) # batch_mult(A_1, u_0_, has_batch)
    u_10 = u_10 + u_1
    A_10 = compose_op(A_1, A_0_)

    # Recursive call
    x_1 = variable_unroll_general(A_10, u_10, s, op, compose_op, sequential_op, variable=variable, recurse_limit=recurse_limit)

    x_0 = shift_up(x_1, s, drop=not uneven)
    x_0 = op(A_0, x_0) # batch_mult(A_0, x_0, has_batch)
    x_0 = x_0 + u_0


    x = interleave(x_0, x_1, uneven, dim=0) # For some reason this interleave is slower than in the (non-multi) unroll_recursive
    return x

def variable_unroll_matrix(A, u, s=None, variable=True, recurse_limit=16):
    if s is None:
        s = torch.zeros_like(u[0])
    has_batch = len(u.shape) >= len(A.shape)
    op = lambda x, y: batch_mult(x, y, has_batch)
    sequential_op = lambda x, y: batch_mult(x.unsqueeze(0), y.unsqueeze(0), has_batch)[0]
    matmul = lambda x, y: x @ y
    return variable_unroll_general(A, u, s, op, compose_op=matmul, sequential_op=sequential_op, variable=variable, recurse_limit=recurse_limit)

def variable_unroll_toeplitz(A, u, s=None, variable=True, recurse_limit=8, pad=False):
    """ Unroll with variable (in time/length) transitions A with general associative operation

    A : ([L], ..., N) dimension L should exist iff variable is True
    u : (L, [B], ..., N) updates
    s : ([B], ..., N) start state
    output : x (L, [B], ..., N) same shape as u
    x[i, ...] = A[i]..A[0] s + A[i..1] u[0] + ... + A[i] u[i-1] + u[i]
    """
    # Add the batch dimension to A if necessary
    A_batch_dims = len(A.shape) - int(variable)
    u_batch_dims = len(u.shape)-1
    if u_batch_dims > A_batch_dims:
        # assert u_batch_dims == A_batch_dims + 1
        if variable:
            while len(A.shape) < len(u.shape):
                A = A.unsqueeze(1)
        # else:
        #     A = A.unsqueeze(0)

    if s is None:
        s = torch.zeros_like(u[0])

    if pad:
        n = A.shape[-1]
        A = F.pad(A, (0, n))
        u = F.pad(u, (0, n))
        s = F.pad(s, (0, n))
        op = triangular_toeplitz_multiply_padded()
        ret = variable_unroll_general(A, u, s, op, compose_op=op, variable=variable, recurse_limit=recurse_limit)
        ret = ret[..., :n]
        return ret

    op = triangular_toeplitz_multiply
    ret = variable_unroll_general(A, u, s, op, compose_op=op, variable=variable, recurse_limit=recurse_limit)
    return ret



### Testing

def test_correctness():
    print("Testing Correctness\n====================")

    # Test sequential unroll
    L = 3
    A = torch.Tensor([[1, 1], [1, 0]])
    u = torch.ones((L, 2))
    x = unroll(A, u)
    assert torch.isclose(x, torch.Tensor([[1., 1.], [3., 2.], [6., 4.]])).all()

    # Test utilities
    assert torch.isclose(shift_up(x), torch.Tensor([[0., 0.], [1., 1.], [3., 2.]])).all()
    assert torch.isclose(interleave(x, x), torch.Tensor([[1., 1.], [1., 1.], [3., 2.], [3., 2.], [6., 4.], [6., 4.]])).all()

    # Test parallel unroll
    x = parallel_unroll_recursive(A, u)
    assert torch.isclose(x, torch.Tensor([[1., 1.], [3., 2.], [6., 4.]])).all()

    # Powers
    L = 12
    A = torch.Tensor([[1, 0, 0], [2, 1, 0], [3, 3, 1]])
    u = torch.ones((L, 3))
    x = parallel_unroll_recursive(A, u)
    print("recursive", x)
    x = parallel_unroll_recursive_br(A, u)
    print("recursive_br", x)
    x = parallel_unroll_iterative(A, u)
    print("iterative_br", x)


    A = A.repeat((L, 1, 1))
    s = torch.zeros(3)
    print("A shape", A.shape)
    x = variable_unroll_sequential(A, u, s)
    print("variable_unroll", x)
    x = variable_unroll(A, u, s)
    print("parallel_variable_unroll", x)


def generate_data(L, N, B=None, cuda=True):
    A = torch.eye(N) + torch.normal(0, 1, size=(N, N)) / (N**.5) / L
    u = torch.normal(0, 1, size=(L, B, N))


    # device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    device = torch.device('cuda:0') if cuda else torch.device('cpu')
    A = A.to(device)
    u = u.to(device)
    return A, u

def test_stability():
    print("Testing Stability\n====================")
    L = 256
    N = L // 2
    B = 100
    A, u = generate_data(L, N, B)

    x = unroll(A, u)
    x1 = parallel_unroll_recursive(A, u)
    x2 = parallel_unroll_recursive_br(A, u)
    x3 = parallel_unroll_iterative(A, u)
    print("norm error", torch.norm(x-x1))
    print("norm error", torch.norm(x-x2))
    print("norm error", torch.norm(x-x3))
    
    print("max error", torch.max(torch.abs(x-x1)))
    print("max error", torch.max(torch.abs(x-x2)))
    print("max error", torch.max(torch.abs(x-x3)))

    A = A.repeat((L, 1, 1))
    x = variable_unroll_sequential(A, u)
    x_ = variable_unroll(A, u)
    # x_ = variable_unroll_matrix_sequential(A, u)
    x_ = variable_unroll_matrix(A, u)
    print(x-x_)
    abserr = torch.abs(x-x_)
    relerr = abserr/(torch.abs(x)+1e-8)
    print("norm abs error", torch.norm(abserr))
    print("max abs error", torch.max(abserr))
    print("norm rel error", torch.norm(relerr))
    print("max rel error", torch.max(relerr))

def test_toeplitz():
    from model.toeplitz import construct_toeplitz
    def summarize(name, x, x_, showdiff=False):
        print(name, "stats")
        if showdiff:
            print(x-x_)
        abserr = torch.abs(x-x_)
        relerr = abserr/(torch.abs(x)+1e-8)
        print("  norm abs error", torch.norm(abserr))
        print("  max abs error", torch.max(abserr))
        print("  norm rel error", torch.norm(relerr))
        print("  max rel error", torch.max(relerr))

    print("Testing Toeplitz\n====================")
    L = 512
    N = L // 2
    B = 100
    A, u = generate_data(L, N, B)

    A = A[..., 0]
    A = construct_toeplitz(A)

    # print("SHAPES", A.shape, u.shape)

    # Static A
    x = unroll(A, u)
    x_ = variable_unroll(A, u, variable=False)
    summarize("nonvariable matrix original", x, x_, showdiff=False)
    x_ = variable_unroll_matrix(A, u, variable=False)
    summarize("nonvariable matrix general", x, x_, showdiff=False)
    x_ = variable_unroll_toeplitz(A[..., 0], u, variable=False)
    summarize("nonvariable toeplitz", x, x_, showdiff=False)

    # Sequential
    A = A.repeat((L, 1, 1))
    for _ in range(1):
        x_ = variable_unroll_sequential(A, u)
        summarize("variable unroll sequential", x, x_, showdiff=False)
        x_ = variable_unroll_matrix_sequential(A, u)
        summarize("variable matrix sequential", x, x_, showdiff=False)
        x_ = variable_unroll_toeplitz_sequential(A[..., 0], u, pad=True)
        summarize("variable toeplitz sequential", x, x_, showdiff=False)

    # Parallel
    for _ in range(1):
        x_ = variable_unroll(A, u)
        summarize("variable matrix original", x, x_, showdiff=False)
        x_ = variable_unroll_matrix(A, u)
        summarize("variable matrix general", x, x_, showdiff=False)
        x_ = variable_unroll_toeplitz(A[..., 0], u, pad=True, recurse_limit=8)
        summarize("variable toeplitz", x, x_, showdiff=False)

def test_speed(variable=False, it=1):
    print("Testing Speed\n====================")
    N = 256
    L = 1024
    B = 100
    A, u = generate_data(L, N, B)
    As = A.repeat((L, 1, 1))

    u.requires_grad=True
    As.requires_grad=True
    for _ in range(it):
        x = unroll(A, u)
        x = torch.sum(x)
        x.backward()

        x = parallel_unroll_recursive(A, u)
        x = torch.sum(x)
        x.backward()

        # parallel_unroll_recursive_br(A, u)
        # parallel_unroll_iterative(A, u)

    for _ in range(it):
        if variable:
            x = variable_unroll_sequential(As, u, variable=True, recurse_limit=16)
            x = torch.sum(x)
            x.backward()
            x = variable_unroll(As, u, variable=True, recurse_limit=16)
            x = torch.sum(x)
            x.backward()
        else:
            variable_unroll_sequential(A, u, variable=False, recurse_limit=16)
            variable_unroll(A, u, variable=False, recurse_limit=16)


In [38]:
class HiPPO_LegT(nn.Module):
    def __init__(self, N, dt=1.0, discretization='bilinear'):
        """
        N: the order of the HiPPO projection
        dt: discretization step size - should be roughly inverse to the length of the sequence
        """
        super().__init__()
        self.N = N
        #A, B = transition('lmu', N)
        A, B = make_HiPPO(N=self.N, v='v', measure='legt', lambda_n=1, fourier_type="fru", alpha=0, beta=1)
        C = np.ones((1, N))
        D = np.zeros((1,))
        # dt, discretization options
        A, B, _, _, _ = signal.cont2discrete((A, B, C, D), dt=dt, method=discretization)

        B = B.squeeze(-1)

        self.register_buffer('A', torch.Tensor(A)) # (N, N)
        self.register_buffer('B', torch.Tensor(B)) # (N,)

        # vals = np.linspace(0.0, 1.0, 1./dt)
        vals = np.arange(0.0, 1.0, dt)
        self.eval_matrix = torch.Tensor(ss.eval_legendre(np.arange(N)[:, None], 1 - 2 * vals).T)

    def forward(self, inputs):
        """
        inputs : (length, ...)
        output : (length, ..., N) where N is the order of the HiPPO projection
        """

        inputs = inputs.unsqueeze(-1)
        u = inputs * self.B # (length, ..., N)

        c = torch.zeros(u.shape[1:])
        cs = []
        for f in inputs:
            c = F.linear(c, self.A) + self.B * f
            cs.append(c)
        return torch.stack(cs, dim=0)

    def reconstruct(self, c):
        return (self.eval_matrix @ c.unsqueeze(-1)).squeeze(-1)

In [39]:
class HiPPO_LegS(nn.Module):
    """ Vanilla HiPPO-LegS model (scale invariant instead of time invariant) """
    def __init__(self, N, max_length=1024, measure='legs', discretization='bilinear'):
        """
        max_length: maximum sequence length
        """
        super().__init__()
        self.N = N
        A, B = make_HiPPO(N=self.N, v='v', measure=measure, lambda_n=1, fourier_type="fru", alpha=0, beta=1)
        #A, B = transition(measure, N)
        B = B.squeeze(-1)
        A_stacked = np.empty((max_length, N, N), dtype=A.dtype)
        B_stacked = np.empty((max_length, N), dtype=B.dtype)
        for t in range(1, max_length + 1):
            At = A / t
            Bt = B / t
            if discretization == 'forward':
                A_stacked[t - 1] = np.eye(N) + At
                B_stacked[t - 1] = Bt
            elif discretization == 'backward':
                A_stacked[t - 1] = la.solve_triangular(np.eye(N) - At, np.eye(N), lower=True)
                B_stacked[t - 1] = la.solve_triangular(np.eye(N) - At, Bt, lower=True)
            elif discretization == 'bilinear':
                A_stacked[t - 1] = np.linalg.lstsq(np.eye(N) - At / 2,  np.eye(N) + At / 2, rcond=None)[0] # TODO: Referencing this: https://stackoverflow.com/questions/64527098/numpy-linalg-linalgerror-singular-matrix-error-when-trying-to-solve 
                B_stacked[t - 1] = np.linalg.lstsq(np.eye(N) - At / 2,  Bt, rcond=None)[0]
            else: # ZOH
                A_stacked[t - 1] = la.expm(A * (math.log(t + 1) - math.log(t)))
                B_stacked[t - 1] = la.solve_triangular(A, A_stacked[t - 1] @ B - B, lower=True)
        self.A_stacked = torch.Tensor(A_stacked.copy()) # (max_length, N, N)
        self.B_stacked = torch.Tensor(B_stacked.copy()) # (max_length, N)
        vals = np.linspace(0.0, 1.0, max_length)
        self.eval_matrix = torch.from_numpy(np.asarray(((B[:, None] * ss.eval_legendre(np.arange(N)[:, None], 2 * vals - 1)).T)))

    def forward(self, inputs, fast=False):
        """
        inputs : (length, ...)
        output : (length, ..., N) where N is the order of the HiPPO projection
        """
        result = None
        
        L = inputs.shape[0]

        inputs = inputs.unsqueeze(-1)
        u = torch.transpose(inputs, 0, -2)
        u = u * self.B_stacked[:L]
        u = torch.transpose(u, 0, -2) # (length, ..., N)

        if fast:
            result = variable_unroll_matrix(self.A_stacked[:L], u)
            
        else:
            result = variable_unroll_matrix_sequential(self.A_stacked[:L], u)
            
        return result

    def reconstruct(self, c):
        a = self.eval_matrix @ c.unsqueeze(-1)
        return a.squeeze(-1)

In [40]:
class HiPPO(jnn.Module):
    '''
    class that constructs HiPPO model using the defined measure. 
    
    Args:
        N (int): order of the HiPPO projection, aka the number of coefficients to describe the matrix
        max_length (int): maximum sequence length to be input
        measure (str): the measure used to define which way to instantiate the HiPPO matrix
        step (float): step size used for descretization
        GBT_alpha (float): represents which descretization transformation to use based off the alpha value
        seq_L (int): length of the sequence to be used for training
        v (str): choice of vectorized or non-vectorized function instantiation 
            - 'v': vectorized
            - 'nv': non-vectorized
        lambda_n (float): value associated with the tilt of legt
            - 1: tilt on legt
            - \sqrt(2n+1)(-1)^{N}: tilt associated with the legendre memory unit (LMU)
        fourier_type (str): choice of fourier measures
            - fru: fourier recurrent unit measure (FRU) - 'fru'
            - fout: truncated Fourier (FouT) - 'fout'
            - fourd: decaying fourier transform - 'fourd'
        alpha (float): The order of the Laguerre basis.
        beta (float): The scale of the Laguerre basis.
    '''
    N: int 
    max_length: int 
    measure: str 
    step: float 
    GBT_alpha: float 
    seq_L: int 
    v: str 
    lambda_n: float
    fourier_type: str
    alpha: float
    beta: float
    
    def setup(self):
        A, B = make_HiPPO(N=self.N,
                          v=self.v, 
                          measure=self.measure, 
                          lambda_n=self.lambda_n, 
                          fourier_type=self.fourier_type,
                          alpha=self.alpha,
                          beta=self.beta)
        
        self.A = A
        self.B = B #.squeeze(-1)
        self.C = jnp.ones((1, self.N)).squeeze(0)
        self.D = jnp.zeros((1,))
        
        if self.measure == "legt":
            L = self.seq_L
            vals = jnp.arange(0.0, 1.0, L)
            n = jnp.arange(self.N)[:, None]
            x = 1 - 2 * vals
            self.eval_matrix = ss.eval_legendre(n, x).T
            
        elif self.measure == "legs":
            L = self.max_length
            vals = jnp.linspace(0.0, 1.0, L)
            n = jnp.arange(self.N)[:, None]
            x =  2 * vals - 1
            self.eval_matrix = (B[:, None] * ss.eval_legendre(n, x)).T
        else:
            raise ValueError("invalid measure")
        
    def __call__(self, u, kernel=False):
        if not kernel:
            Ab, Bb, Cb, Db = self.collect_SSM_vars(self.A, self.B, self.C, self.D, u, alpha=self.GBT_alpha)
            c_k = self.scan_SSM(Ab, Bb, Cb, Db, u, x0=jnp.zeros((self.N, )))[1]
        else:
            Ab, Bb, Cb, Db = self.discretize(self.A, self.B, self.C, self.D, step=self.step, alpha=self.GBT_alpha)
            c_k = self.causal_convolution(u, self.K_conv(Ab, Bb, Cb, Db, L=self.max_length))
            
        return c_k
    
    def reconstruct(self, c):
        '''
        Uses coeffecients to reconstruct the signal
        
        Args: 
            c (jnp.ndarray): coefficients of the HiPPO projection
            
        Returns:
            reconstructed signal
        '''
        return (self.eval_matrix @ jnp.expand_dims(c, -1)).squeeze(-1)
    
    def discretize(self, A, B, C, D, step, alpha=0.5):
        '''
        function used for discretizing the HiPPO matrix
        
        Args:
            A (jnp.ndarray): matrix to be discretized
            B (jnp.ndarray): matrix to be discretized
            C (jnp.ndarray): matrix to be discretized
            D (jnp.ndarray): matrix to be discretized
            step (float): step size used for discretization
            alpha (float, optional): used for determining which generalized bilinear transformation to use
                - forward Euler corresponds to α = 0,
                - backward Euler corresponds to α = 1,
                - bilinear corresponds to α = 0.5,
                - Zero-order Hold corresponds to α > 1
        '''
        I = jnp.eye(A.shape[0])
        GBT = jnp.linalg.inv(I - (step * alpha * A))
        GBT_A = GBT @ (I + (step * (1-alpha) * A))
        GBT_B = (step * GBT) @ B
        
        if alpha > 1: # Zero-order Hold
            GBT_A = jax.scipy.linalg.expm(step * A)
            GBT_B = (jnp.linalg.inv(A) @ (jax.scipy.linalg.expm(step * A) - I)) @ B 
        
        return GBT_A, GBT_B, C, D
    
    def collect_SSM_vars(self, A, B, C, D, u, alpha=0.5):
        '''
        turns the continuos HiPPO matrix components into discrete ones
        
        Args:
            A (jnp.ndarray): matrix to be discretized
            B (jnp.ndarray): matrix to be discretized
            C (jnp.ndarray): matrix to be discretized
            D (jnp.ndarray): matrix to be discretized
            u (jnp.ndarray): input signal
            alpha (float, optional): used for determining which generalized bilinear transformation to use
            
        Returns:
            Ab (jnp.ndarray): discrete form of the HiPPO matrix
            Bb (jnp.ndarray): discrete form of the HiPPO matrix
            Cb (jnp.ndarray): discrete form of the HiPPO matrix
            Db (jnp.ndarray): discrete form of the HiPPO matrix
        '''
        L = u.shape[0]
        N = A.shape[0]
        assert L == self.seq_L
        assert N == self.N

        Ab, Bb, Cb, Db = self.discretize(A, B, C, D, step=1.0/L, alpha=alpha)
    
        return Ab, Bb, Cb, Db
    
    def scan_SSM(self, Ab, Bb, Cb, Db, u, x0):
        '''
        This is for returning the discretized hidden state often needed for an RNN. 
        Args:
            Ab (jnp.ndarray): the discretized A matrix
            Bb (jnp.ndarray): the discretized B matrix
            Cb (jnp.ndarray): the discretized C matrix
            u (jnp.ndarray): the input sequence
            x0 (jnp.ndarray): the initial hidden state
        Returns:
            the next hidden state (aka coefficients representing the function, f(t))
        '''
        def step(x_k_1, u_k):
            '''
            Get descretized coefficients of the hidden state by applying HiPPO matrix to input sequence, u_k, and previous hidden state, x_k_1.
            Args:
                x_k_1: previous hidden state
                u_k: output from function f at, descritized, time step, k.
            
            Returns: 
                x_k: current hidden state
                y_k: current output of hidden state applied to Cb (sorry for being vague, I just dont know yet)
            '''

            x_k = (Ab @ x_k_1) + (Bb @ u_k)
            y_k = (Cb @ x_k) + (Db @ u_k)
            
            return x_k, y_k

        return jax.lax.scan(step, x0, u)


In [41]:
def test():
    # N = 256
    # L = 128
    
    N = 8
    L = 8
    
    x = torch.randn(L, 1)
    
    # ----------------------------------------------------------------------------------
    loss = nn.MSELoss()
    
    # ----------------------------------------------------------------------------------
    # ------------------------------ Test HiPPO LegT model -----------------------------
    # ----------------------------------------------------------------------------------
    print('\nTesting HiPPO LegT model')
    hippo_legt = HiPPO_LegT(N, dt=1./L)
    
    y = hippo_legt(x)
    
    print(f"h-y shape for LegT:\n{y.shape}")
    z = hippo_legt.reconstruct(y)
    print(f"h-z shape for LegT:\n{z.shape}")


    mse = loss(z[-1,0,:L], x.squeeze(-1))
    print(f"h-MSE shape:\n{mse}")
    print(f"end of test for HiPPO LegT model")
    
    # ----------------------------------------------------------------------------------
    # ------------------------------ Test HiPPO LegS model -----------------------------
    # ----------------------------------------------------------------------------------
    print('\nTesting HiPPO LegS model')
    hippo_legs = HiPPO_LegS(N, max_length=L) #The Gu's
    
    y = hippo_legs(x)
    
    print(f"h-y shape for LegS:\n{y.shape}")
    print(f"h-y for LegS:\n{y}")
    
    z = hippo_legs(x, fast=True)
    
    print(f"h-reconstruction shape for LegS:\n{hippo_legs.reconstruct(z).shape}")
    print(f"h-reconstruction for LegS:\n{hippo_legs.reconstruct(z)}")
    
    # print(y-z)
    print(f"end of test for HiPPO LegT model")
    
    # ----------------------------------------------------------------------------------
    # ------------------------------ Test Generic HiPPO model --------------------------
    # ----------------------------------------------------------------------------------
    print('\nTesting BRYANS HiPPO LegS model')
    hippo_LegS_B = HiPPO(N=N, 
                         max_length=L,
                         measure='legs', 
                         step=1.0/L, 
                         GBT_alpha=0.5, 
                         seq_L=L, 
                         v='v', 
                         lambda_n=1.0,
                         fourier_type='fru',
                         alpha=0.0,
                         beta=1.0) # Bryan's
    
    x = jnp.asarray(x) # convert torch array to jax array
    print(f"input:\n{x}")
    print(f"input type:\n{type(x)}")
    
    params = hippo_LegS_B.init(key2, x)
    
    c_legs = hippo_LegS_B.apply(params, x)
    
    y_legs = hippo_LegS_B.apply({'params': params}, c_legs, method=hippo_LegS_B.reconstruct)
    
    print(f"U-c shape for LegS:\n{c_legs.shape}")
    print(f"U-c for LegS:\n{c_legs}")
    
    print(f"U-reconstruction shape for LegS:\n{y_legs.shape}")
    print(f"U-reconstruction for LegS:\n{y_legs}")
    
    print(f"end of test for HiPPO LegT model")

In [42]:
test()


Testing HiPPO LegT model
h-y shape for LegT:
torch.Size([8, 1, 8])
h-z shape for LegT:
torch.Size([8, 1, 8])
h-MSE shape:
1.9479849338531494
end of test for HiPPO LegT model

Testing HiPPO LegS model


  self.eval_matrix = torch.from_numpy(np.asarray(((B[:, None] * ss.eval_legendre(np.arange(N)[:, None], 2 * vals - 1)).T)))


h-y shape for LegS:
torch.Size([8, 1, 8])
h-y for LegS:
tensor([[[ 5.7822e-01,  5.0075e-01,  1.2929e-01,  3.1453e-08, -1.0582e-08,
           5.9981e-09, -1.9635e-08,  9.4720e-09]],

        [[ 3.8025e-01, -6.1671e-02, -4.7741e-01, -2.5596e-01, -3.2248e-02,
          -1.3894e-08, -9.5905e-10,  8.4431e-09]],

        [[ 1.8778e-01, -2.8086e-01, -2.9865e-01,  2.5937e-01,  3.1473e-01,
           8.7701e-02,  7.3339e-03,  7.3536e-09]],

        [[-3.3568e-01, -8.9385e-01, -5.7342e-01,  4.1692e-02, -1.5803e-01,
          -3.3563e-01, -1.5511e-01, -2.7721e-02]],

        [[-5.7133e-01, -9.3603e-01, -1.6893e-01,  4.8520e-01,  1.6309e-01,
           1.3658e-01,  3.3824e-01,  2.2577e-01]],

        [[-5.4450e-01, -6.2876e-01,  3.5066e-01,  6.7681e-01, -1.1994e-02,
          -1.2064e-01, -9.2157e-02, -3.1154e-01]],

        [[-3.0017e-01, -1.0127e-01,  8.4314e-01,  6.7228e-01, -2.5100e-01,
          -1.7507e-01,  6.3745e-03,  7.5165e-03]],

        [[-4.6900e-01, -3.3871e-01,  3.4866e-01, -2.057