# HiPPO Matrices
---

## Load Packages

In [1]:
## import packages
import jax
import jax.numpy as jnp

from flax import linen as jnn

from jax.nn.initializers import lecun_normal, uniform
from jax.numpy.linalg import eig, inv, matrix_power
from jax.scipy.signal import convolve

import os
import requests

from scipy import linalg as la
from scipy import signal
from scipy import special as ss

import math

print(jax.devices())
print(f"The Device: {jax.lib.xla_bridge.get_backend().platform}")

[GpuDevice(id=0, process_index=0)]
The Device: gpu


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import numpy as np
import matplotlib.pyplot as plt
print(f"MPS enabled: {torch.backends.mps.is_available()}")


MPS enabled: False


In [3]:
seed = 1701
key = jax.random.PRNGKey(seed)


In [4]:
num_copies = 5
rng, key2, key3, key4, key5 = jax.random.split(key, num=num_copies)


In [5]:
def random_SSM(rng, N):
    a_r, b_r, c_r = jax.random.split(rng, 3)
    A = jax.random.uniform(a_r, (N, N))
    B = jax.random.uniform(b_r, (N, 1))
    C = jax.random.uniform(c_r, (1, N))

    return A, B, C


In [6]:
N = 8


## Instantiate The HiPPO Matrix

## Translated Legendre (LegT)

In [7]:
# Translated Legendre (LegT) - vectorized
def build_LegT_V(N, lambda_n=1):
    """
    The, vectorized implementation of the, measure derived from the translated Legendre basis.

    Args:
        N (int): Order of coefficients to describe the orthogonal polynomial that is the HiPPO projection.
        legt_type (str): Choice between the two different tilts of basis.
            - legt: translated Legendre - 'legt'
            - lmu: Legendre Memory Unit - 'lmu'

    Returns:
        A (jnp.ndarray): The A HiPPO matrix.
        B (jnp.ndarray): The B HiPPO matrix.

    """
    q = jnp.arange(N, dtype=jnp.float64)
    k, n = jnp.meshgrid(q, q)
    case = jnp.power(-1.0, (n - k))
    A = None
    B = None

    if lambda_n == 1:
        A_base = -jnp.sqrt(2 * n + 1) * jnp.sqrt(2 * k + 1)
        pre_D = jnp.sqrt(jnp.diag(2 * q + 1))
        B = D = jnp.diag(pre_D)[:, None]
        A = jnp.where(
            k <= n, A_base, A_base * case
        )  # if n >= k, then case_2 * A_base is used, otherwise A_base

    elif lambda_n == 2:  # (jnp.sqrt(2*n+1) * jnp.power(-1, n)):
        A_base = -(2 * n + 1)
        B = jnp.diag((2 * q + 1) * jnp.power(-1, n))[:, None]
        A = jnp.where(
            k <= n, A_base * case, A_base
        )  # if n >= k, then case_2 * A_base is used, otherwise A_base

    return A, B


In [8]:
# Translated Legendre (LegT) - non-vectorized
def build_LegT(N, legt_type="legt"):
    """
    The, non-vectorized implementation of the, measure derived from the translated Legendre basis

    Args:
        N (int): Order of coefficients to describe the orthogonal polynomial that is the HiPPO projection.
        legt_type (str): Choice between the two different tilts of basis.
            - legt: translated Legendre - 'legt'
            - lmu: Legendre Memory Unit - 'lmu'

    Returns:
        A (jnp.ndarray): The A HiPPO matrix.
        B (jnp.ndarray): The B HiPPO matrix.

    """
    Q = jnp.arange(N, dtype=jnp.float64)
    pre_R = 2 * Q + 1
    k, n = jnp.meshgrid(Q, Q)

    if legt_type == "legt":
        R = jnp.sqrt(pre_R)
        A = R[:, None] * jnp.where(n < k, (-1.0) ** (n - k), 1) * R[None, :]
        B = R[:, None]
        A = -A

        # Halve again for timescale correctness
        # A, B = A/2, B/2
        # A *= 0.5
        # B *= 0.5

    elif legt_type == "lmu":
        R = pre_R[:, None]
        A = jnp.where(n < k, -1, (-1.0) ** (n - k + 1)) * R
        B = (-1.0) ** Q[:, None] * R

    return A, B


In [9]:
nv_LegT_A, nv_LegT_B = build_LegT(N=N, legt_type="lmu")
LegT_A, LegT_B = build_LegT_V(N=N, lambda_n=2)
print(f"A nv:\n", nv_LegT_A)
print(f"A v:\n", LegT_A)
print(f"B nv:\n", nv_LegT_B)
print(f"B v:\n", LegT_B)
print(f"A Comparison:\n ", jnp.allclose(nv_LegT_A, LegT_A))
print(f"B Comparison:\n ", jnp.allclose(nv_LegT_B, LegT_B))


  Q = jnp.arange(N, dtype=jnp.float64)


A nv:
 [[ -1.  -1.  -1.  -1.  -1.  -1.  -1.  -1.]
 [  3.  -3.  -3.  -3.  -3.  -3.  -3.  -3.]
 [ -5.   5.  -5.  -5.  -5.  -5.  -5.  -5.]
 [  7.  -7.   7.  -7.  -7.  -7.  -7.  -7.]
 [ -9.   9.  -9.   9.  -9.  -9.  -9.  -9.]
 [ 11. -11.  11. -11.  11. -11. -11. -11.]
 [-13.  13. -13.  13. -13.  13. -13. -13.]
 [ 15. -15.  15. -15.  15. -15.  15. -15.]]
A v:
 [[ -1.  -1.  -1.  -1.  -1.  -1.  -1.  -1.]
 [  3.  -3.  -3.  -3.  -3.  -3.  -3.  -3.]
 [ -5.   5.  -5.  -5.  -5.  -5.  -5.  -5.]
 [  7.  -7.   7.  -7.  -7.  -7.  -7.  -7.]
 [ -9.   9.  -9.   9.  -9.  -9.  -9.  -9.]
 [ 11. -11.  11. -11.  11. -11. -11. -11.]
 [-13.  13. -13.  13. -13.  13. -13. -13.]
 [ 15. -15.  15. -15.  15. -15.  15. -15.]]
B nv:
 [[  1.]
 [ -3.]
 [  5.]
 [ -7.]
 [  9.]
 [-11.]
 [ 13.]
 [-15.]]
B v:
 [[  1.]
 [ -3.]
 [  5.]
 [ -7.]
 [  9.]
 [-11.]
 [ 13.]
 [-15.]]
A Comparison:
  True
B Comparison:
  True


## Translated Laguerre (LagT)

In [10]:
# Translated Laguerre (LagT) - non-vectorized
def build_LagT_V(alpha, beta, N):
    """
    The, vectorized implementation of the, measure derived from the translated Laguerre basis.

    Args:
        alpha (float): The order of the Laguerre basis.
        beta (float): The scale of the Laguerre basis.
        N (int): Order of coefficients to describe the orthogonal polynomial that is the HiPPO projection.

    Returns:
        A (jnp.ndarray): The A HiPPO matrix.
        B (jnp.ndarray): The B HiPPO matrix.

    """
    L = jnp.exp(
        0.5 * (ss.gammaln(jnp.arange(N) + alpha + 1) - ss.gammaln(jnp.arange(N) + 1))
    )
    inv_L = 1.0 / L[:, None]
    pre_A = (jnp.eye(N) * ((1 + beta) / 2)) + jnp.tril(jnp.ones((N, N)), -1)
    pre_B = ss.binom(alpha + jnp.arange(N), jnp.arange(N))[:, None]

    A = -inv_L * pre_A * L[None, :]
    B = (
        jnp.exp(-0.5 * ss.gammaln(1 - alpha))
        * jnp.power(beta, (1 - alpha) / 2)
        * inv_L
        * pre_B
    )

    return A, B


In [11]:
# Translated Laguerre (LagT) - non-vectorized
def build_LagT(alpha, beta, N):
    """
    The, non-vectorized implementation of the, measure derived from the translated Laguerre basis.

    Args:
        alpha (float): The order of the Laguerre basis.
        beta (float): The scale of the Laguerre basis.
        N (int): Order of coefficients to describe the orthogonal polynomial that is the HiPPO projection.

    Returns:
        A (jnp.ndarray): The A HiPPO matrix.
        B (jnp.ndarray): The B HiPPO matrix.

    """
    A = -jnp.eye(N) * (1 + beta) / 2 - jnp.tril(jnp.ones((N, N)), -1)
    B = ss.binom(alpha + jnp.arange(N), jnp.arange(N))[:, None]

    L = jnp.exp(
        0.5 * (ss.gammaln(jnp.arange(N) + alpha + 1) - ss.gammaln(jnp.arange(N) + 1))
    )
    A = (1.0 / L[:, None]) * A * L[None, :]
    B = (
        (1.0 / L[:, None])
        * B
        * jnp.exp(-0.5 * ss.gammaln(1 - alpha))
        * beta ** ((1 - alpha) / 2)
    )

    return A, B


In [12]:
nv_LagT_A, nv_LagT_B = build_LagT(alpha=0, beta=1, N=N)
LagT_A, LagT_B = build_LagT_V(alpha=0, beta=1, N=N)
print(f"A nv:\n{nv_LagT_A}")
print(f"A v:\n{LagT_A}")
print(f"B nv:\n{nv_LagT_B}")
print(f"B v:\n{LagT_B}")
print(f"A Comparison:\n ", jnp.allclose(nv_LagT_A, LagT_A))
print(f"B Comparison:\n ", jnp.allclose(nv_LagT_B, LagT_B))


A nv:
[[-1. -0. -0. -0. -0. -0. -0. -0.]
 [-1. -1. -0. -0. -0. -0. -0. -0.]
 [-1. -1. -1. -0. -0. -0. -0. -0.]
 [-1. -1. -1. -1. -0. -0. -0. -0.]
 [-1. -1. -1. -1. -1. -0. -0. -0.]
 [-1. -1. -1. -1. -1. -1. -0. -0.]
 [-1. -1. -1. -1. -1. -1. -1. -0.]
 [-1. -1. -1. -1. -1. -1. -1. -1.]]
A v:
[[-1. -0. -0. -0. -0. -0. -0. -0.]
 [-1. -1. -0. -0. -0. -0. -0. -0.]
 [-1. -1. -1. -0. -0. -0. -0. -0.]
 [-1. -1. -1. -1. -0. -0. -0. -0.]
 [-1. -1. -1. -1. -1. -0. -0. -0.]
 [-1. -1. -1. -1. -1. -1. -0. -0.]
 [-1. -1. -1. -1. -1. -1. -1. -0.]
 [-1. -1. -1. -1. -1. -1. -1. -1.]]
B nv:
[[1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]]
B v:
[[1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]]
A Comparison:
  True
B Comparison:
  True


## Scaled Legendre (LegS)

In [13]:
# Scaled Legendre (LegS) vectorized
def build_LegS_V(N):
    """
    The, vectorized implementation of the, measure derived from the Scaled Legendre basis.

    Args:
        N (int): Order of coefficients to describe the orthogonal polynomial that is the HiPPO projection.

    Returns:
        A (jnp.ndarray): The A HiPPO matrix.
        B (jnp.ndarray): The B HiPPO matrix.

    """
    q = jnp.arange(N, dtype=jnp.float64)
    k, n = jnp.meshgrid(q, q)
    pre_D = jnp.sqrt(jnp.diag(2 * q + 1))
    B = D = jnp.diag(pre_D)[:, None]

    A_base = (-jnp.sqrt(2 * n + 1)) * jnp.sqrt(2 * k + 1)
    case_2 = (n + 1) / (2 * n + 1)

    A = jnp.where(n > k, A_base, 0.0)  # if n > k, then A_base is used, otherwise 0
    A = jnp.where(
        n == k, (A_base * case_2), A
    )  # if n == k, then A_base is used, otherwise A

    return A, B


In [14]:
# Scaled Legendre (LegS), non-vectorized
def build_LegS(N):
    """
    The, non-vectorized implementation of the, measure derived from the Scaled Legendre basis.

    Args:
        N (int): Order of coefficients to describe the orthogonal polynomial that is the HiPPO projection.

    Returns:
        A (jnp.ndarray): The A HiPPO matrix.
        B (jnp.ndarray): The B HiPPO matrix.

    """
    q = jnp.arange(
        N, dtype=jnp.float64
    )  # q represents the values 1, 2, ..., N each column has
    k, n = jnp.meshgrid(q, q)
    r = 2 * q + 1
    M = -(jnp.where(n >= k, r, 0) - jnp.diag(q))  # represents the state matrix M
    D = jnp.sqrt(
        jnp.diag(2 * q + 1)
    )  # represents the diagonal matrix D $D := \text{diag}[(2n+1)^{\frac{1}{2}}]^{N-1}_{n=0}$
    A = D @ M @ jnp.linalg.inv(D)
    B = jnp.diag(D)[:, None]
    B = (
        B.copy()
    )  # Otherwise "UserWarning: given NumPY array is not writeable..." after torch.as_tensor(B)

    return A, B


In [15]:
nv_LegS_A, nv_LegS_B = build_LegS(N=N)
LegS_A, LegS_B = build_LegS_V(N=N)
print(f"A nv:\n{nv_LegS_A}")
print(f"A v:\n{LegS_A}")
print(f"B nv:\n{nv_LegS_B}")
print(f"B v:\n{LegS_B}")
print(f"A Comparison:\n ", jnp.allclose(nv_LegS_A, LegS_A))
print(f"B Comparison:\n ", jnp.allclose(nv_LegS_B, LegS_B))


  q = jnp.arange(


A nv:
[[ -1.          0.          0.          0.          0.          0.
    0.          0.       ]
 [ -1.7320508  -1.9999999   0.          0.          0.          0.
    0.          0.       ]
 [ -2.2360678  -3.872983   -3.          0.          0.          0.
    0.          0.       ]
 [ -2.6457512  -4.582576   -5.91608    -4.          0.          0.
    0.          0.       ]
 [ -3.         -5.196152   -6.7082047  -7.9372544  -5.          0.
    0.          0.       ]
 [ -3.3166246  -5.744562   -7.4161987  -8.774965   -9.949874   -6.
    0.          0.       ]
 [ -3.6055512  -6.244998   -8.062259   -9.539392  -10.816654  -11.958261
   -7.          0.       ]
 [ -3.8729832  -6.708204   -8.6602545 -10.246951  -11.61895   -12.845232
  -13.964239   -7.9999995]]
A v:
[[ -1.          0.          0.          0.          0.          0.
    0.          0.       ]
 [ -1.7320508  -2.          0.          0.          0.          0.
    0.          0.       ]
 [ -2.2360678  -3.872983   -2.999999

## Fourier Basis

In [16]:

# Fourier Basis OPs and functions - vectorized
def build_Fourier_V(N, fourier_type="fru"):
    """
    Vectorized measure implementations derived from fourier basis.

    Args:
        N (int): Order of coefficients to describe the orthogonal polynomial that is the HiPPO projection.
        fourier_type (str): The type of Fourier measure.
            - FRU: Fourier Recurrent Unit - fru
            - FouT: truncated Fourier - fout
            - fouD: decayed Fourier - foud

    Returns:
        A (jnp.ndarray): The A HiPPO matrix.
        B (jnp.ndarray): The B HiPPO matrix.

    """
    A = jnp.diag(
        jnp.stack([jnp.zeros(N // 2), jnp.zeros(N // 2)], axis=-1).reshape(-1)[1:], 1
    )
    B = jnp.zeros(A.shape[1], dtype=jnp.float64)

    B = B.at[0::2].set(jnp.sqrt(2))
    B = B.at[0].set(1)
    
    q = jnp.arange(A.shape[1], dtype=jnp.float64)
    k, n = jnp.meshgrid(q, q)

    n_odd = n % 2 == 0
    k_odd = k % 2 == 0

    case_1 = (n == k) & (n == 0)
    case_2_3 = ((k == 0) & (n_odd)) | ((n == 0) & (k_odd))
    case_4 = (n_odd) & (k_odd)
    case_5 = (n - k == 1) & (k_odd)
    case_6 = (k - n == 1) & (n_odd)

    if fourier_type == "fru":  # Fourier Recurrent Unit (FRU) - vectorized
        A = jnp.where(
            case_1,
            -1.0,
            jnp.where(
                case_2_3,
                -jnp.sqrt(2),
                jnp.where(
                    case_4,
                    -2,
                    jnp.where(
                        case_5,
                        jnp.pi * (n // 2),
                        jnp.where(case_6, -jnp.pi * (k // 2), 0.0),
                    ),
                ),
            ),
        )

    elif fourier_type == "fout":  # truncated Fourier (FouT) - vectorized
        A = jnp.where(
            case_1,
            -1.0,
            jnp.where(
                case_2_3,
                -jnp.sqrt(2),
                jnp.where(
                    case_4,
                    -2,
                    jnp.where(
                        case_5,
                        jnp.pi * (n // 2),
                        jnp.where(case_6, -jnp.pi * (k // 2), 0.0),
                    ),
                ),
            ),
        )

        A = 2 * A
        B = 2 * B

    elif fourier_type == "foud":
        A = jnp.where(
            case_1,
            -1.0,
            jnp.where(
                case_2_3,
                -jnp.sqrt(2),
                jnp.where(
                    case_4,
                    -2,
                    jnp.where(
                        case_5,
                        2 * jnp.pi * (n // 2),
                        jnp.where(case_6, 2 * -jnp.pi * (k // 2), 0.0),
                    ),
                ),
            ),
        )

        A = 0.5 * A
        B = 0.5 * B

    B = B[:, None]

    return A, B


In [17]:
def build_Fourier(N, fourier_type="fru"):
    """
    Non-vectorized measure implementations derived from fourier basis.

    Args:
        N (int): Order of coefficients to describe the orthogonal polynomial that is the HiPPO projection.
        fourier_type (str): The type of Fourier measure.
            - FRU: Fourier Recurrent Unit - fru
            - FouT: truncated Fourier - fout
            - fouD: decayed Fourier - foud

    Returns:
        A (jnp.ndarray): The A HiPPO matrix.
        B (jnp.ndarray): The B HiPPO matrix.

    """
    freqs = jnp.arange(N // 2)

    if fourier_type == "fru":  # Fourier Recurrent Unit (FRU) - non-vectorized
        d = jnp.stack([jnp.zeros(N // 2), freqs], axis=-1).reshape(-1)[1:]
        A = jnp.pi * (-jnp.diag(d, 1) + jnp.diag(d, -1))

        B = jnp.zeros(A.shape[1])
        B = B.at[0::2].set(jnp.sqrt(2))
        B = B.at[0].set(1)

        A = A - B[:, None] * B[None, :]
        B = B[:, None]

    elif fourier_type == "fout":  # truncated Fourier (FouT) - non-vectorized
        freqs *= 2
        d = jnp.stack([jnp.zeros(N // 2), freqs], axis=-1).reshape(-1)[1:]
        A = jnp.pi * (-jnp.diag(d, 1) + jnp.diag(d, -1))

        B = jnp.zeros(A.shape[1])
        B = B.at[0::2].set(jnp.sqrt(2))
        B = B.at[0].set(1)

        # Subtract off rank correction - this corresponds to the other endpoint u(t-1) in this case
        A = A - B[:, None] * B[None, :] * 2
        B = B[:, None] * 2

    elif fourier_type == "foud":
        d = jnp.stack([jnp.zeros(N // 2), freqs], axis=-1).reshape(-1)[1:]
        A = jnp.pi * (-jnp.diag(d, 1) + jnp.diag(d, -1))

        B = jnp.zeros(A.shape[1])
        B = B.at[0::2].set(jnp.sqrt(2))
        B = B.at[0].set(1)

        # Subtract off rank correction - this corresponds to the other endpoint u(t-1) in this case
        A = A - 0.5 * B[:, None] * B[None, :]
        B = 0.5 * B[:, None]

    return A, B


In [18]:
nv_Fourier_A, nv_Fourier_B = build_Fourier(N=N, fourier_type="fru")
Fourier_A, Fourier_B = build_Fourier_V(N=N, fourier_type="fru")
print(f"nv for A:\n{nv_Fourier_A}")
print(f"v for A:\n{Fourier_A}")
print(f"nv for B:\n{nv_Fourier_B}")
print(f"v for B:\n{Fourier_B}")
print(f"A Comparison:\n {jnp.allclose(nv_Fourier_A, Fourier_A)}")
print(f"B Comparison:\n {jnp.allclose(nv_Fourier_B, Fourier_B)}")


  B = jnp.zeros(A.shape[1], dtype=jnp.float64)
  q = jnp.arange(A.shape[1], dtype=jnp.float64)


nv for A:
[[-1.         0.        -1.4142135  0.        -1.4142135  0.
  -1.4142135  0.       ]
 [ 0.         0.         0.         0.         0.         0.
   0.         0.       ]
 [-1.4142135  0.        -1.9999999 -3.1415927 -1.9999999  0.
  -1.9999999  0.       ]
 [ 0.         0.         3.1415927  0.         0.         0.
   0.         0.       ]
 [-1.4142135  0.        -1.9999999  0.        -1.9999999 -6.2831855
  -1.9999999  0.       ]
 [ 0.         0.         0.         0.         6.2831855  0.
   0.         0.       ]
 [-1.4142135  0.        -1.9999999  0.        -1.9999999  0.
  -1.9999999 -9.424778 ]
 [ 0.         0.         0.         0.         0.         0.
   9.424778   0.       ]]
v for A:
[[-1.        -0.        -1.4142135  0.        -1.4142135  0.
  -1.4142135  0.       ]
 [ 0.         0.         0.         0.         0.         0.
   0.         0.       ]
 [-1.4142135  0.        -2.        -3.1415927 -2.         0.
  -2.         0.       ]
 [ 0.         0.         3.

In [19]:
def make_HiPPO(
    N, v="nv", measure="legs", lambda_n=1, fourier_type="fru", alpha=0, beta=1
):
    """
    Instantiates the HiPPO matrix of a given order using a particular measure.
    Args:
        N (int): Order of coefficients to describe the orthogonal polynomial that is the HiPPO projection.
        v (str): choose between this repo's implementation or hazy research's implementation.
        measure (str):
            choose between
                - HiPPO w/ Translated Legendre (LegT) - legt
                - HiPPO w/ Translated Laguerre (LagT) - lagt
                - HiPPO w/ Scaled Legendre (LegS) - legs
                - HiPPO w/ Fourier basis - fourier
                    - FRU: Fourier Recurrent Unit
                    - FouT: Translated Fourier
        lambda_n (int): The amount of tilt applied to the HiPPO-LegS basis, determines between LegS and LMU.
        fourier_type (str): chooses between the following:
            - FRU: Fourier Recurrent Unit - fru
            - FouT: Translated Fourier - fout
            - FourD: Fourier Decay - fourd
        alpha (float): The order of the Laguerre basis.
        beta (float): The scale of the Laguerre basis.

    Returns:
        A (jnp.ndarray): The HiPPO matrix multiplied by -1.
        B (jnp.ndarray): The other corresponding state space matrix.

    """
    A = None
    B = None
    if measure == "legt":
        if v == "nv":
            A, B = build_LegT(N=N, lambda_n=lambda_n)
        else:
            A, B = build_LegT_V(N=N, lambda_n=lambda_n)

    elif measure == "lagt":
        if v == "nv":
            A, B = build_LagT(alpha=alpha, beta=beta, N=N)
        else:
            A, B = build_LagT_V(alpha=alpha, beta=beta, N=N)

    elif measure == "legs":
        if v == "nv":
            A, B = build_LegS(N=N)
        else:
            A, B = build_LegS_V(N=N)

    elif measure == "fourier":
        if v == "nv":
            A, B = build_Fourier(N=N, fourier_type=fourier_type)
        else:
            A, B = build_Fourier_V(N=N, fourier_type=fourier_type)

    elif measure == "random":
        A = jnp.random.randn(N, N) / N
        B = jnp.random.randn(N, 1)

    elif measure == "diagonal":
        A = -jnp.diag(jnp.exp(jnp.random.randn(N)))
        B = jnp.random.randn(N, 1)

    else:
        raise ValueError("Invalid HiPPO type")

    A_copy = A.copy()
    B_copy = B.copy()

    return jnp.array(A_copy), B_copy


In [20]:
nv_A, nv_B = make_HiPPO(
    N=N, v="nv", measure="legs", lambda_n=1, fourier_type="fru", alpha=0, beta=1
)
v_A, v_B = make_HiPPO(
    N=N, v="v", measure="legs", lambda_n=1, fourier_type="fru", alpha=0, beta=1
)
print(f"nv:\n", nv_A)
print(f"v:\n", v_A)
# print(f"nv:\n", nv_B)
# print(f"v:\n", v_B)
print(f"A Comparison:\n ", jnp.allclose(nv_A, v_A))
print(f"B Comparison:\n ", jnp.allclose(nv_B, v_B))


nv:
 [[ -1.          0.          0.          0.          0.          0.
    0.          0.       ]
 [ -1.7320508  -1.9999999   0.          0.          0.          0.
    0.          0.       ]
 [ -2.2360678  -3.872983   -3.          0.          0.          0.
    0.          0.       ]
 [ -2.6457512  -4.582576   -5.91608    -4.          0.          0.
    0.          0.       ]
 [ -3.         -5.196152   -6.7082047  -7.9372544  -5.          0.
    0.          0.       ]
 [ -3.3166246  -5.744562   -7.4161987  -8.774965   -9.949874   -6.
    0.          0.       ]
 [ -3.6055512  -6.244998   -8.062259   -9.539392  -10.816654  -11.958261
   -7.          0.       ]
 [ -3.8729832  -6.708204   -8.6602545 -10.246951  -11.61895   -12.845232
  -13.964239   -7.9999995]]
v:
 [[ -1.          0.          0.          0.          0.          0.
    0.          0.       ]
 [ -1.7320508  -2.          0.          0.          0.          0.
    0.          0.       ]
 [ -2.2360678  -3.872983   -2.9999995 

  q = jnp.arange(


In [21]:
### Utilities


def shift_up(a, s=None, drop=True, dim=0):
    assert dim == 0
    if s is None:
        s = torch.zeros_like(a[0, ...])
    s = s.unsqueeze(dim)
    if drop:
        a = a[:-1, ...]
    return torch.cat((s, a), dim=dim)

def interleave(a, b, uneven=False, dim=0):
    """ Interleave two tensors of same shape """
    # assert(a.shape == b.shape)
    assert dim == 0 # TODO temporary to make handling uneven case easier
    if dim < 0:
        dim = N + dim
    if uneven:
        a_ = a[-1:, ...]
        a = a[:-1, ...]
    c = torch.stack((a, b), dim+1)
    out_shape = list(a.shape)
    out_shape[dim] *= 2
    c = c.view(out_shape)
    if uneven:
        c = torch.cat((c, a_), dim=dim)
    return c

def batch_mult(A, u, has_batch=None):
    """ Matrix mult A @ u with special case to save memory if u has additional batch dim

    The batch dimension is assumed to be the second dimension
    A : (L, ..., N, N)
    u : (L, [B], ..., N)
    has_batch: True, False, or None. If None, determined automatically

    Output:
    x : (L, [B], ..., N)
      A @ u broadcasted appropriately
    """

    if has_batch is None:
        has_batch = len(u.shape) >= len(A.shape)

    if has_batch:
        u = u.permute([0] + list(range(2, len(u.shape))) + [1])
    else:
        u = u.unsqueeze(-1)
    v = (A @ u)
    if has_batch:
        v = v.permute([0] + [len(u.shape)-1] + list(range(1, len(u.shape)-1)))
    else:
        v = v[..., 0]
    return v



### Main unrolling functions

def unroll(A, u):
    """
    A : (..., N, N) # TODO I think this can't take batch dimension?
    u : (L, ..., N)
    output : x (..., N) # TODO a lot of these shapes are wrong
    x[i, ...] = A^{i} @ u[0, ...] + ... + A @ u[i-1, ...] + u[i, ...]
    """

    m = u.new_zeros(u.shape[1:])
    outputs = []
    for u_ in torch.unbind(u, dim=0):
        m = F.linear(m, A) + u_
        outputs.append(m)

    output = torch.stack(outputs, dim=0)
    return output


def parallel_unroll_recursive(A, u):
    """ Bottom-up divide-and-conquer version of unroll. """

    # Main recursive function
    def parallel_unroll_recursive_(A, u):
        if u.shape[0] == 1:
            return u

        u_evens = u[0::2, ...]
        u_odds = u[1::2, ...]

        # u2 = F.linear(u_evens, A) + u_odds
        u2 = (A @ u_evens.unsqueeze(-1)).squeeze(-1) + u_odds
        A2 = A @ A

        x_odds = parallel_unroll_recursive_(A2, u2)
        # x_evens = F.linear(shift_up(x_odds), A) + u_evens
        x_evens = (A @ shift_up(x_odds).unsqueeze(-1)).squeeze(-1) + u_evens

        x = interleave(x_evens, x_odds, dim=0)
        return x

    # Pad u to power of 2
    n = u.shape[0]
    m = int(math.ceil(math.log(n)/math.log(2)))
    N = 1 << m
    u = torch.cat((u, u.new_zeros((N-u.shape[0],) + u.shape[1:] )), dim=0)

    return parallel_unroll_recursive_(A, u)[:n, ...]



def parallel_unroll_recursive_br(A, u):
    """ Same as parallel_unroll_recursive but uses bit reversal for locality. """

    # Main recursive function
    def parallel_unroll_recursive_br_(A, u):
        n = u.shape[0]
        if n == 1:
            return u

        m = n//2
        u_0 = u[:m, ...]
        u_1 = u[m:, ...]

        u2 = F.linear(u_0, A) + u_1
        A2 = A @ A

        x_1 = parallel_unroll_recursive_br_(A2, u2)
        x_0 = F.linear(shift_up(x_1), A) + u_0

        # x = torch.cat((x_0, x_1), dim=0) # is there a way to do this with cat?
        x = interleave(x_0, x_1, dim=0)
        return x

    # Pad u to power of 2
    n = u.shape[0]
    m = int(math.ceil(math.log(n)/math.log(2)))
    N = 1 << m
    u = torch.cat((u, u.new_zeros((N-u.shape[0],) + u.shape[1:] )), dim=0)

    # Apply bit reversal
    br = bitreversal_po2(N)
    u = u[br, ...]

    x = parallel_unroll_recursive_br_(A, u)
    return x[:n, ...]

def parallel_unroll_iterative(A, u):
    """ Bottom-up divide-and-conquer version of unroll, implemented iteratively """

    # Pad u to power of 2
    n = u.shape[0]
    m = int(math.ceil(math.log(n)/math.log(2)))
    N = 1 << m
    u = torch.cat((u, u.new_zeros((N-u.shape[0],) + u.shape[1:] )), dim=0)

    # Apply bit reversal
    br = bitreversal_po2(N)
    u = u[br, ...]

    # Main recursive loop, flattened
    us = [] # stores the u_0 terms in the recursive version
    N_ = N
    As = [] # stores the A matrices
    for l in range(m):
        N_ = N_ // 2
        As.append(A)
        u_0 = u[:N_, ...]
        us.append(u_0)
        u = F.linear(u_0, A) + u[N_:, ...]
        A = A @ A
    x_0 = []
    x = u # x_1
    for l in range(m-1, -1, -1):
        x_0 = F.linear(shift_up(x), As[l]) + us[l]
        x = interleave(x_0, x, dim=0)

    return x[:n, ...]


def variable_unroll_sequential(A, u, s=None, variable=True):
    """ Unroll with variable (in time/length) transitions A.

    A : ([L], ..., N, N) dimension L should exist iff variable is True
    u : (L, [B], ..., N) updates
    s : ([B], ..., N) start state
    output : x (..., N)
    x[i, ...] = A[i]..A[0] @ s + A[i..1] @ u[0] + ... + A[i] @ u[i-1] + u[i]
    """

    if s is None:
        s = torch.zeros_like(u[0])

    if not variable:
        A = A.expand((u.shape[0],) + A.shape)
    has_batch = len(u.shape) >= len(A.shape)

    outputs = []
    for (A_, u_) in zip(torch.unbind(A, dim=0), torch.unbind(u, dim=0)):
        # s = F.linear(s, A_) + u_
        s = batch_mult(A_.unsqueeze(0), s.unsqueeze(0), has_batch)[0]
        s = s + u_
        outputs.append(s)

    output = torch.stack(outputs, dim=0)
    return output



def variable_unroll(A, u, s=None, variable=True, recurse_limit=16):
    """ Bottom-up divide-and-conquer version of variable_unroll. """

    if u.shape[0] <= recurse_limit:
        return variable_unroll_sequential(A, u, s, variable)

    if s is None:
        s = torch.zeros_like(u[0])

    uneven = u.shape[0] % 2 == 1
    has_batch = len(u.shape) >= len(A.shape)

    u_0 = u[0::2, ...]
    u_1  = u[1::2, ...]

    if variable:
        A_0 = A[0::2, ...]
        A_1  = A[1::2, ...]
    else:
        A_0 = A
        A_1 = A

    u_0_ = u_0
    A_0_ = A_0
    if uneven:
        u_0_ = u_0[:-1, ...]
        if variable:
            A_0_ = A_0[:-1, ...]

    u_10 = batch_mult(A_1, u_0_, has_batch)
    u_10 = u_10 + u_1
    A_10 = A_1 @ A_0_

    # Recursive call
    x_1 = variable_unroll(A_10, u_10, s, variable, recurse_limit)

    x_0 = shift_up(x_1, s, drop=not uneven)
    x_0 = batch_mult(A_0, x_0, has_batch)
    x_0 = x_0 + u_0


    x = interleave(x_0, x_1, uneven, dim=0) # For some reason this interleave is slower than in the (non-multi) unroll_recursive
    return x

def variable_unroll_general_sequential(A, u, s, op, variable=True):
    """ Unroll with variable (in time/length) transitions A with general associative operation

    A : ([L], ..., N, N) dimension L should exist iff variable is True
    u : (L, [B], ..., N) updates
    s : ([B], ..., N) start state
    output : x (..., N)
    x[i, ...] = A[i]..A[0] s + A[i..1] u[0] + ... + A[i] u[i-1] + u[i]
    """

    if not variable:
        A = A.expand((u.shape[0],) + A.shape)

    outputs = []
    for (A_, u_) in zip(torch.unbind(A, dim=0), torch.unbind(u, dim=0)):
        s = op(A_, s)
        s = s + u_
        outputs.append(s)

    output = torch.stack(outputs, dim=0)
    return output

def variable_unroll_matrix_sequential(A, u, s=None, variable=True):
    if s is None:
        s = torch.zeros_like(u[0])

    if not variable:
        A = A.expand((u.shape[0],) + A.shape)
    # has_batch = len(u.shape) >= len(A.shape)

    # op = lambda x, y: batch_mult(x.unsqueeze(0), y.unsqueeze(0), has_batch)[0]
    op = lambda x, y: batch_mult(x.unsqueeze(0), y.unsqueeze(0))[0]

    return variable_unroll_general_sequential(A, u, s, op, variable=True)

def variable_unroll_toeplitz_sequential(A, u, s=None, variable=True, pad=False):
    if s is None:
        s = torch.zeros_like(u[0])

    if not variable:
        A = A.expand((u.shape[0],) + A.shape)
    # has_batch = len(u.shape) >= len(A.shape)

    # op = lambda x, y: batch_mult(x.unsqueeze(0), y.unsqueeze(0), has_batch)[0]
    # op = lambda x, y: batch_mult(x.unsqueeze(0), y.unsqueeze(0))[0]

    if pad:
        n = A.shape[-1]
        A = F.pad(A, (0, n))
        u = F.pad(u, (0, n))
        s = F.pad(s, (0, n))
        ret = variable_unroll_general_sequential(A, u, s, triangular_toeplitz_multiply_padded, variable=True)
        ret = ret[..., :n]
        return ret

    return variable_unroll_general_sequential(A, u, s, triangular_toeplitz_multiply, variable=True)



### General parallel scan functions with generic binary composition operators

def variable_unroll_general(A, u, s, op, compose_op=None, sequential_op=None, variable=True, recurse_limit=16):
    """ Bottom-up divide-and-conquer version of variable_unroll.

    compose is an optional function that defines how to compose A without multiplying by a leaf u
    """

    if u.shape[0] <= recurse_limit:
        if sequential_op is None:
            sequential_op = op
        return variable_unroll_general_sequential(A, u, s, sequential_op, variable)

    if compose_op is None:
        compose_op = op

    uneven = u.shape[0] % 2 == 1
    # has_batch = len(u.shape) >= len(A.shape)

    u_0 = u[0::2, ...]
    u_1 = u[1::2, ...]

    if variable:
        A_0 = A[0::2, ...]
        A_1 = A[1::2, ...]
    else:
        A_0 = A
        A_1 = A

    u_0_ = u_0
    A_0_ = A_0
    if uneven:
        u_0_ = u_0[:-1, ...]
        if variable:
            A_0_ = A_0[:-1, ...]

    u_10 = op(A_1, u_0_) # batch_mult(A_1, u_0_, has_batch)
    u_10 = u_10 + u_1
    A_10 = compose_op(A_1, A_0_)

    # Recursive call
    x_1 = variable_unroll_general(A_10, u_10, s, op, compose_op, sequential_op, variable=variable, recurse_limit=recurse_limit)

    x_0 = shift_up(x_1, s, drop=not uneven)
    x_0 = op(A_0, x_0) # batch_mult(A_0, x_0, has_batch)
    x_0 = x_0 + u_0


    x = interleave(x_0, x_1, uneven, dim=0) # For some reason this interleave is slower than in the (non-multi) unroll_recursive
    return x

def variable_unroll_matrix(A, u, s=None, variable=True, recurse_limit=16):
    if s is None:
        s = torch.zeros_like(u[0])
    has_batch = len(u.shape) >= len(A.shape)
    op = lambda x, y: batch_mult(x, y, has_batch)
    sequential_op = lambda x, y: batch_mult(x.unsqueeze(0), y.unsqueeze(0), has_batch)[0]
    matmul = lambda x, y: x @ y
    return variable_unroll_general(A, u, s, op, compose_op=matmul, sequential_op=sequential_op, variable=variable, recurse_limit=recurse_limit)

def variable_unroll_toeplitz(A, u, s=None, variable=True, recurse_limit=8, pad=False):
    """ Unroll with variable (in time/length) transitions A with general associative operation

    A : ([L], ..., N) dimension L should exist iff variable is True
    u : (L, [B], ..., N) updates
    s : ([B], ..., N) start state
    output : x (L, [B], ..., N) same shape as u
    x[i, ...] = A[i]..A[0] s + A[i..1] u[0] + ... + A[i] u[i-1] + u[i]
    """
    # Add the batch dimension to A if necessary
    A_batch_dims = len(A.shape) - int(variable)
    u_batch_dims = len(u.shape)-1
    if u_batch_dims > A_batch_dims:
        # assert u_batch_dims == A_batch_dims + 1
        if variable:
            while len(A.shape) < len(u.shape):
                A = A.unsqueeze(1)
        # else:
        #     A = A.unsqueeze(0)

    if s is None:
        s = torch.zeros_like(u[0])

    if pad:
        n = A.shape[-1]
        A = F.pad(A, (0, n))
        u = F.pad(u, (0, n))
        s = F.pad(s, (0, n))
        op = triangular_toeplitz_multiply_padded
        ret = variable_unroll_general(A, u, s, op, compose_op=op, variable=variable, recurse_limit=recurse_limit)
        ret = ret[..., :n]
        return ret

    op = triangular_toeplitz_multiply
    ret = variable_unroll_general(A, u, s, op, compose_op=op, variable=variable, recurse_limit=recurse_limit)
    return ret

In [22]:
class HiPPO_LegT(nn.Module):
    def __init__(self, N, dt=1.0, discretization="bilinear"):
        """
        N: the order of the HiPPO projection
        dt: discretization step size - should be roughly inverse to the length of the sequence
        """
        super().__init__()
        self.N = N
        # A, B = transition('lmu', N)
        A, B = make_HiPPO(
            N=self.N,
            v="v",
            measure="legt",
            lambda_n=1,
            fourier_type="fru",
            alpha=0,
            beta=1,
        )
        C = np.ones((1, N))
        D = np.zeros((1,))
        # dt, discretization options
        A, B, _, _, _ = signal.cont2discrete((A, B, C, D), dt=dt, method=discretization)

        B = B.squeeze(-1)

        self.register_buffer("A", torch.Tensor(A))  # (N, N)
        self.register_buffer("B", torch.Tensor(B))  # (N,)

        # vals = np.linspace(0.0, 1.0, 1./dt)
        vals = np.arange(0.0, 1.0, dt)
        self.eval_matrix = torch.Tensor(
            ss.eval_legendre(np.arange(N)[:, None], 1 - 2 * vals).T
        )

    def forward(self, inputs):
        """
        inputs : (length, ...)
        output : (length, ..., N) where N is the order of the HiPPO projection
        """

        inputs = inputs.unsqueeze(-1)
        u = inputs * self.B  # (length, ..., N)

        c = torch.zeros(u.shape[1:])
        cs = []
        for f in inputs:
            c = F.linear(c, self.A) + self.B * f
            #print(f"f:\n{f}")
            cs.append(c)
        return torch.stack(cs, dim=0)

    def reconstruct(self, c):
        return (self.eval_matrix @ c.unsqueeze(-1)).squeeze(-1)


In [23]:
class HiPPO_LegS(nn.Module):
    """Vanilla HiPPO-LegS model (scale invariant instead of time invariant)"""

    def __init__(self, N, max_length=1024, measure="legs", discretization="bilinear"):
        """
        max_length: maximum sequence length
        """
        super().__init__()
        self.N = N
        A, B = make_HiPPO(
            N=self.N,
            v="nv",
            measure=measure,
            lambda_n=1,
            fourier_type="fru",
            alpha=0,
            beta=1,
        )
        # A, B = transition(measure, N)
        B = B.squeeze(-1)
        A_stacked = np.empty((max_length, N, N), dtype=A.dtype)
        B_stacked = np.empty((max_length, N), dtype=B.dtype)
        for t in range(1, max_length + 1):
            At = A / t
            Bt = B / t
            if discretization == "forward":
                A_stacked[t - 1] = np.eye(N) + At
                B_stacked[t - 1] = Bt
            elif discretization == "backward":
                A_stacked[t - 1] = la.solve_triangular(
                    np.eye(N) - At, np.eye(N), lower=True
                )
                B_stacked[t - 1] = la.solve_triangular(np.eye(N) - At, Bt, lower=True)
            elif discretization == "bilinear":
                alpha = 0.5
                A_stacked[t - 1] = np.linalg.lstsq(
                    np.eye(N) - (At * alpha), np.eye(N) + (At * alpha), rcond=None
                )[
                    0
                ]  # TODO: Referencing this: https://stackoverflow.com/questions/64527098/numpy-linalg-linalgerror-singular-matrix-error-when-trying-to-solve
                B_stacked[t - 1] = np.linalg.lstsq(np.eye(N) - (At * alpha), Bt, rcond=None)[
                    0
                ]
            else:  # ZOH
                A_stacked[t - 1] = la.expm(A * (math.log(t + 1) - math.log(t)))
                B_stacked[t - 1] = la.solve_triangular(
                    A, A_stacked[t - 1] @ B - B, lower=True
                )
        self.A_stacked = torch.Tensor(A_stacked.copy())  # (max_length, N, N)
        self.B_stacked = torch.Tensor(B_stacked.copy())  # (max_length, N)
        vals = np.linspace(0.0, 1.0, max_length)
        self.eval_matrix = torch.from_numpy(
            np.asarray(
                ((B[:, None] * ss.eval_legendre(np.arange(N)[:, None], 2 * vals - 1)).T)
            )
        )

    def forward(self, inputs, fast=False):
        """
        inputs : (length, ...)
        output : (length, ..., N) where N is the order of the HiPPO projection
        """
        result = None

        L = inputs.shape[0]

        u = inputs.unsqueeze(-1)
        u = torch.transpose(u, 0, -2)
        u = u * self.B_stacked[:L]  # c_k = A @ c_{k-1} + B @ f_k
        print(f"u - Gu: {u}")
        my_b = torch.Tensor(
            [
                [6.6666657e-01],
                [5.7735050e-01],
                [1.4907140e-01],
                [-2.3096800e-07],
                [-2.7939677e-09],
                [2.9616058e-07],
                [-2.2817403e-08],
                [-8.1490725e-08],
            ]
        )
        u = torch.transpose(u, 0, -2)  # (length, ..., N)

        #print(f"A_stacked: {self.A_stacked[:L]}")
        # print(f"B_stacked: {self.B_stacked[:L]}")

        if fast:
            result = variable_unroll_matrix(self.A_stacked[:L], u)

        else:
            result = variable_unroll_matrix_sequential(self.A_stacked[:L], u)

        return result

    def reconstruct(self, c):
        a = self.eval_matrix @ c.unsqueeze(-1)
        return a.squeeze(-1)

In [24]:
class HiPPO(jnn.Module):
    """
    class that constructs HiPPO model using the defined measure.

    Args:
        N (int): order of the HiPPO projection, aka the number of coefficients to describe the matrix
        max_length (int): maximum sequence length to be input
        measure (str): the measure used to define which way to instantiate the HiPPO matrix
        step (float): step size used for descretization
        GBT_alpha (float): represents which descretization transformation to use based off the alpha value
        seq_L (int): length of the sequence to be used for training
        v (str): choice of vectorized or non-vectorized function instantiation
            - 'v': vectorized
            - 'nv': non-vectorized
        lambda_n (float): value associated with the tilt of legt
            - 1: tilt on legt
            - \sqrt(2n+1)(-1)^{N}: tilt associated with the legendre memory unit (LMU)
        fourier_type (str): choice of fourier measures
            - fru: fourier recurrent unit measure (FRU) - 'fru'
            - fout: truncated Fourier (FouT) - 'fout'
            - fourd: decaying fourier transform - 'fourd'
        alpha (float): The order of the Laguerre basis.
        beta (float): The scale of the Laguerre basis.
    """

    N: int
    max_length: int
    measure: str
    step: float
    GBT_alpha: float
    seq_L: int
    v: str
    lambda_n: float
    fourier_type: str
    alpha: float
    beta: float

    def setup(self):
        A, B = make_HiPPO(
            N=self.N,
            v=self.v,
            measure=self.measure,
            lambda_n=self.lambda_n,
            fourier_type=self.fourier_type,
            alpha=self.alpha,
            beta=self.beta,
        )

        self.A = A
        self.B = B  # .squeeze(-1)
        self.C = jnp.ones((self.N, ))
        self.D = jnp.zeros((1,))

        if self.measure == "legt":
            L = self.seq_L
            vals = jnp.arange(0.0, 1.0, L)
            # n = jnp.arange(self.N)[:, None]
            zero_N = self.N - 1
            x = 1 - 2 * vals
            self.eval_matrix = jax.scipy.special.lpmn_values(
                m=zero_N, n=zero_N, z=x, is_normalized=False
            ).T  # ss.eval_legendre(n, x).T

        elif self.measure == "legs":
            L = self.max_length
            vals = jnp.linspace(0.0, 1.0, L)
            # n = jnp.arange(self.N)[:, None]
            zero_N = self.N - 1
            x = 2 * vals - 1
            self.eval_matrix = (
                B[:, None]
                * jax.scipy.special.lpmn_values(
                    m=zero_N, n=zero_N, z=x, is_normalized=False
                )
            ).T  # ss.eval_legendre(n, x)).T
        else:
            raise ValueError("invalid measure")

    def __call__(self, f, init_state=None, t_step=0, kernel=False):
        # print(f"u shape:\n{f.shape}")
        # print(f"u:\n{f}")
        if not kernel:
            if init_state is None:
                init_state = jnp.zeros((self.N, 1))

            Ab, Bb, Cb, Db = self.collect_SSM_vars(
                self.A, self.B, self.C, self.D, f, t_step=t_step, alpha=self.GBT_alpha
            )
            #c_k = self.loop_SSM(Ab=Ab, Bb=Bb, Cb=Cb, Db=Db, c_0=init_state, f=f)
            c_k, y_k = self.scan_SSM(Ab=Ab, Bb=Bb, Cb=Cb, Db=Db, c_0=init_state, f=f)

        else:
            Ab, Bb, Cb, Db = self.discretize(
                self.A, self.B, self.C, self.D, step=self.step, alpha=self.GBT_alpha
            )
            c_k, y_k = self.causal_convolution(
                f, self.K_conv(Ab, Bb, Cb, Db, L=self.max_length)
            )

        return c_k

    def reconstruct(self, c):
        """
        Uses coeffecients to reconstruct the signal

        Args:
            c (jnp.ndarray): coefficients of the HiPPO projection

        Returns:
            reconstructed signal
        """
        return (self.eval_matrix @ jnp.expand_dims(c, -1)).squeeze(-1)

    def discretize(self, A, B, C, D, step, alpha=0.5):
        """
        function used for discretizing the HiPPO matrix

        Args:
            A (jnp.ndarray): matrix to be discretized
            B (jnp.ndarray): matrix to be discretized
            C (jnp.ndarray): matrix to be discretized
            D (jnp.ndarray): matrix to be discretized
            step (float): step size used for discretization
            alpha (float, optional): used for determining which generalized bilinear transformation to use
                - forward Euler corresponds to α = 0,
                - backward Euler corresponds to α = 1,
                - bilinear corresponds to α = 0.5,
                - Zero-order Hold corresponds to α > 1
        """
        I = jnp.eye(A.shape[0])
        step_size = 1 / step
        part1 = I - (A * step_size * alpha)
        part2 = I + (A * step_size * (1 - alpha))
        
        
        GBT_A = jnp.linalg.lstsq(part1, part2, rcond=None)[0]
        GBT_A_2 = inv(part1) @ part2
        
        base_GBT_B = jnp.linalg.lstsq(part1, B, rcond=None)[0]
        GBT_B = step_size * base_GBT_B
        GBT_B_2 = step_size * inv(part1) @ B
        
        assert jnp.allclose(GBT_A, GBT_A_2, rtol=1e-04, atol=1e-06), f"\nGBT_A:\n{GBT_A}\n\nnot equal to\n\nGBT_A_2\n{GBT_A_2}"
        assert jnp.allclose(GBT_B, GBT_B_2, rtol=1e-04, atol=1e-06), f"\nGBT_B:\n{GBT_B}\n\nnot equal to\n\GBT_B_2\n{GBT_B_2}"

        # print(f"GBT_A:\n{GBT_A}")
        # print(f"GBT_B:\n{GBT_B}")

        if alpha > 1:  # Zero-order Hold
            GBT_A = jax.scipy.linalg.expm(step_size * A)
            GBT_B = (jnp.linalg.inv(A) @ (jax.scipy.linalg.expm(step_size * A) - I)) @ B
            
        GBT_A_matrix = jnp.array(
            [
                [
                    [
                        3.3333e-01,
                        -1.4883e-15,
                        -4.7901e-16,
                        -9.7330e-16,
                        -6.9630e-16,
                        -1.2915e-15,
                        -6.7919e-16,
                        3.5877e-17,
                    ],
                    [
                        -5.7735e-01,
                        2.9802e-08,
                        7.7716e-16,
                        1.0547e-15,
                        9.9920e-16,
                        7.2164e-16,
                        -1.6653e-16,
                        -1.5266e-16,
                    ],
                    [
                        -1.4907e-01,
                        -7.7460e-01,
                        -2.0000e-01,
                        -5.5511e-16,
                        -6.3838e-16,
                        -8.1879e-16,
                        1.5266e-16,
                        1.1449e-16,
                    ],
                    [
                        1.9296e-08,
                        -6.1884e-08,
                        -7.8881e-01,
                        -3.3333e-01,
                        -2.2204e-16,
                        -1.6653e-16,
                        -4.1633e-17,
                        -1.3531e-16,
                    ],
                    [
                        -3.3959e-08,
                        1.2015e-07,
                        1.2778e-01,
                        -7.5593e-01,
                        -4.2857e-01,
                        -8.3267e-17,
                        -2.0817e-16,
                        -6.9389e-17,
                    ],
                    [
                        3.0288e-09,
                        -6.3927e-08,
                        -3.5315e-02,
                        2.0893e-01,
                        -7.1071e-01,
                        -5.0000e-01,
                        -1.3878e-17,
                        4.5103e-17,
                    ],
                    [
                        2.4869e-08,
                        2.6210e-08,
                        1.2797e-02,
                        -7.5709e-02,
                        2.5754e-01,
                        -6.6435e-01,
                        -5.5556e-01,
                        1.5959e-16,
                    ],
                    [
                        -1.3988e-08,
                        -4.6192e-08,
                        -5.4985e-03,
                        3.2530e-02,
                        -1.1066e-01,
                        2.8545e-01,
                        -6.2063e-01,
                        -6.0000e-01,
                    ],
                ],
                [
                    [
                        6.0000e-01,
                        -1.3763e-15,
                        -2.0851e-16,
                        1.8474e-16,
                        -3.9911e-16,
                        3.8276e-16,
                        -2.5718e-16,
                        -5.8887e-17,
                    ],
                    [
                        -4.6188e-01,
                        3.3333e-01,
                        -4.9960e-16,
                        -6.6613e-16,
                        5.5511e-17,
                        3.3307e-16,
                        -6.1062e-16,
                        -2.0817e-16,
                    ],
                    [
                        -2.5555e-01,
                        -7.3771e-01,
                        1.4286e-01,
                        4.1633e-16,
                        4.4409e-16,
                        -1.9429e-16,
                        1.1102e-16,
                        1.2317e-16,
                    ],
                    [
                        -7.5593e-02,
                        -2.1822e-01,
                        -8.4515e-01,
                        6.9389e-16,
                        1.9429e-16,
                        -6.9389e-17,
                        -2.7756e-17,
                        4.8572e-17,
                    ],
                    [
                        -9.5238e-03,
                        -2.7493e-02,
                        -1.0648e-01,
                        -8.8192e-01,
                        -1.1111e-01,
                        2.9143e-16,
                        6.2450e-16,
                        6.2450e-17,
                    ],
                    [
                        -2.8427e-09,
                        -6.9719e-09,
                        9.9898e-08,
                        -7.4493e-08,
                        -8.8443e-01,
                        -2.0000e-01,
                        -1.1796e-16,
                        3.4694e-18,
                    ],
                    [
                        1.6951e-08,
                        -2.3124e-08,
                        -1.1497e-07,
                        1.2396e-07,
                        8.7407e-02,
                        -8.6969e-01,
                        -2.7273e-01,
                        -2.6021e-17,
                    ],
                    [
                        -9.3280e-09,
                        -2.3898e-08,
                        6.2306e-08,
                        -5.0203e-08,
                        -1.5648e-02,
                        1.5570e-01,
                        -8.4632e-01,
                        -3.3333e-01,
                    ],
                ],
                [
                    [
                        7.1429e-01,
                        1.4154e-16,
                        -7.5657e-16,
                        -6.4569e-16,
                        -3.3881e-16,
                        -3.7441e-16,
                        -2.4573e-16,
                        6.8165e-18,
                    ],
                    [
                        -3.7115e-01,
                        5.0000e-01,
                        -4.4409e-16,
                        -4.9960e-16,
                        -2.7756e-16,
                        -4.4409e-16,
                        -1.6653e-16,
                        1.3878e-17,
                    ],
                    [
                        -2.6620e-01,
                        -6.4550e-01,
                        3.3333e-01,
                        -1.6653e-16,
                        -9.7145e-17,
                        -4.1633e-17,
                        -1.3184e-16,
                        -3.2960e-17,
                    ],
                    [
                        -1.2599e-01,
                        -3.0551e-01,
                        -7.8881e-01,
                        2.0000e-01,
                        -1.1102e-16,
                        4.3021e-16,
                        -1.6653e-16,
                        -1.3878e-17,
                    ],
                    [
                        -3.8961e-02,
                        -9.4475e-02,
                        -2.4393e-01,
                        -8.6588e-01,
                        9.0909e-02,
                        8.3267e-17,
                        -8.3267e-17,
                        -1.9949e-17,
                    ],
                    [
                        -7.1788e-03,
                        -1.7408e-02,
                        -4.4947e-02,
                        -1.5954e-01,
                        -9.0453e-01,
                        -2.9143e-16,
                        1.5959e-16,
                        -8.6736e-19,
                    ],
                    [
                        -6.0030e-04,
                        -1.4557e-03,
                        -3.7587e-03,
                        -1.3342e-02,
                        -7.5641e-02,
                        -9.1987e-01,
                        -7.6923e-02,
                        -1.1276e-17,
                    ],
                    [
                        2.6170e-10,
                        -4.7852e-08,
                        -8.2657e-09,
                        3.1517e-08,
                        -1.4809e-09,
                        -3.8236e-08,
                        -9.2072e-01,
                        -1.4286e-01,
                    ],
                ],
                [
                    [
                        7.7778e-01,
                        6.8154e-16,
                        -1.1268e-16,
                        5.3622e-17,
                        4.3422e-16,
                        -1.0521e-16,
                        -2.2829e-16,
                        6.1580e-25,
                    ],
                    [
                        -3.0792e-01,
                        6.0000e-01,
                        1.6653e-16,
                        2.7756e-16,
                        -1.6653e-16,
                        -5.8287e-16,
                        -1.1102e-16,
                        1.1581e-23,
                    ],
                    [
                        -2.5297e-01,
                        -5.6334e-01,
                        4.5455e-01,
                        -1.6653e-16,
                        -2.0817e-16,
                        -5.5511e-17,
                        -2.7756e-16,
                        7.6514e-24,
                    ],
                    [
                        -1.4966e-01,
                        -3.3328e-01,
                        -7.1710e-01,
                        3.3333e-01,
                        1.9429e-16,
                        2.7756e-16,
                        -3.4694e-17,
                        3.3087e-24,
                    ],
                    [
                        -6.5268e-02,
                        -1.4535e-01,
                        -3.1274e-01,
                        -8.1408e-01,
                        2.3077e-01,
                        -3.0531e-16,
                        1.9429e-16,
                        -5.3767e-24,
                    ],
                    [
                        -2.0616e-02,
                        -4.5911e-02,
                        -9.8784e-02,
                        -2.5714e-01,
                        -8.7471e-01,
                        1.4286e-01,
                        5.6205e-16,
                        -8.0650e-24,
                    ],
                    [
                        -4.4824e-03,
                        -9.9820e-03,
                        -2.1478e-02,
                        -5.5908e-02,
                        -1.9018e-01,
                        -9.1111e-01,
                        6.6667e-02,
                        5.1699e-24,
                    ],
                    [
                        -6.0187e-04,
                        -1.3403e-03,
                        -2.8838e-03,
                        -7.5069e-03,
                        -2.5536e-02,
                        -1.2234e-01,
                        -9.3095e-01,
                        2.9802e-08,
                    ],
                ],
                [
                    [
                        8.1818e-01,
                        3.7236e-16,
                        1.0195e-16,
                        2.5556e-16,
                        1.8434e-16,
                        2.4138e-17,
                        9.4042e-17,
                        -1.1605e-17,
                    ],
                    [
                        -2.6243e-01,
                        6.6667e-01,
                        -5.5511e-17,
                        -2.2204e-16,
                        0.0000e00,
                        -2.2204e-16,
                        2.7756e-16,
                        -3.4694e-17,
                    ],
                    [
                        -2.3455e-01,
                        -4.9654e-01,
                        5.3846e-01,
                        -1.5266e-16,
                        -4.0246e-16,
                        -1.6653e-16,
                        -1.3878e-16,
                        5.4644e-17,
                    ],
                    [
                        -1.5859e-01,
                        -3.3572e-01,
                        -6.5012e-01,
                        4.2857e-01,
                        1.5266e-16,
                        6.9389e-17,
                        -2.5674e-16,
                        3.8164e-17,
                    ],
                    [
                        -8.3916e-02,
                        -1.7765e-01,
                        -3.4401e-01,
                        -7.5593e-01,
                        3.3333e-01,
                        -3.1919e-16,
                        -3.3307e-16,
                        -2.5153e-17,
                    ],
                    [
                        -3.4790e-02,
                        -7.3648e-02,
                        -1.4262e-01,
                        -3.1339e-01,
                        -8.2916e-01,
                        2.5000e-01,
                        2.8449e-16,
                        -4.3368e-17,
                    ],
                    [
                        -1.1124e-02,
                        -2.3548e-02,
                        -4.5601e-02,
                        -1.0020e-01,
                        -2.6511e-01,
                        -8.7928e-01,
                        1.7647e-01,
                        -1.3878e-17,
                    ],
                    [
                        -2.6552e-03,
                        -5.6211e-03,
                        -1.0885e-02,
                        -2.3919e-02,
                        -6.3284e-02,
                        -2.0989e-01,
                        -9.1270e-01,
                        1.1111e-01,
                    ],
                ],
                [
                    [
                        8.4615e-01,
                        -1.4020e-17,
                        -5.3729e-16,
                        5.1558e-17,
                        -5.9539e-17,
                        4.8426e-17,
                        1.2853e-16,
                        2.7276e-17,
                    ],
                    [
                        -2.2840e-01,
                        7.1429e-01,
                        0.0000e00,
                        -1.6653e-16,
                        1.6653e-16,
                        1.1102e-16,
                        2.7756e-16,
                        -2.0817e-17,
                    ],
                    [
                        -2.1624e-01,
                        -4.4263e-01,
                        6.0000e-01,
                        -3.1919e-16,
                        -2.3592e-16,
                        3.1919e-16,
                        -5.8981e-17,
                        -5.2042e-18,
                    ],
                    [
                        -1.5991e-01,
                        -3.2733e-01,
                        -5.9161e-01,
                        5.0000e-01,
                        -2.2204e-16,
                        -4.1633e-17,
                        1.4572e-16,
                        -2.6021e-17,
                    ],
                    [
                        -9.5992e-02,
                        -1.9649e-01,
                        -3.5514e-01,
                        -7.0035e-01,
                        4.1176e-01,
                        -2.7756e-17,
                        1.8041e-16,
                        -6.4185e-17,
                    ],
                    [
                        -4.7166e-02,
                        -9.6547e-02,
                        -1.7450e-01,
                        -3.4412e-01,
                        -7.8038e-01,
                        3.3333e-01,
                        -4.8919e-16,
                        9.5410e-17,
                    ],
                    [
                        -1.8891e-02,
                        -3.8669e-02,
                        -6.9890e-02,
                        -1.3782e-01,
                        -3.1256e-01,
                        -8.3918e-01,
                        2.6316e-01,
                        5.7246e-17,
                    ],
                    [
                        -6.0876e-03,
                        -1.2461e-02,
                        -2.2522e-02,
                        -4.4414e-02,
                        -1.0072e-01,
                        -2.7043e-01,
                        -8.8195e-01,
                        2.0000e-01,
                    ],
                ],
                [
                    [
                        8.6667e-01,
                        -6.8770e-17,
                        -2.7995e-16,
                        -6.0583e-16,
                        6.0395e-17,
                        1.0834e-16,
                        -1.2672e-16,
                        3.5785e-17,
                    ],
                    [
                        -2.0207e-01,
                        7.5000e-01,
                        1.1102e-16,
                        -5.5511e-17,
                        0.0000e00,
                        1.6653e-16,
                        -2.7756e-17,
                        3.4694e-17,
                    ],
                    [
                        -1.9949e-01,
                        -3.9869e-01,
                        6.4706e-01,
                        2.0817e-16,
                        -3.3307e-16,
                        1.6653e-16,
                        3.0878e-16,
                        -5.8981e-17,
                    ],
                    [
                        -1.5736e-01,
                        -3.1449e-01,
                        -5.4134e-01,
                        5.5556e-01,
                        5.5511e-17,
                        -1.8041e-16,
                        -1.4919e-16,
                        9.0206e-17,
                    ],
                    [
                        -1.0330e-01,
                        -2.0645e-01,
                        -3.5537e-01,
                        -6.4983e-01,
                        4.7368e-01,
                        4.8572e-17,
                        -3.4694e-17,
                        -3.1225e-17,
                    ],
                    [
                        -5.7103e-02,
                        -1.1412e-01,
                        -1.9644e-01,
                        -3.5921e-01,
                        -7.3315e-01,
                        4.0000e-01,
                        -1.7347e-16,
                        6.5919e-17,
                    ],
                    [
                        -2.6604e-02,
                        -5.3170e-02,
                        -9.1522e-02,
                        -1.6736e-01,
                        -3.4158e-01,
                        -7.9722e-01,
                        3.3333e-01,
                        1.5613e-17,
                    ],
                    [
                        -1.0392e-02,
                        -2.0768e-02,
                        -3.5749e-02,
                        -6.5371e-02,
                        -1.3342e-01,
                        -3.1140e-01,
                        -8.4632e-01,
                        2.7273e-01,
                    ],
                ],
                [
                    [
                        8.8235e-01,
                        -1.9373e-16,
                        1.3060e-17,
                        2.0996e-16,
                        4.1722e-18,
                        2.7963e-17,
                        3.6134e-16,
                        -6.0316e-17,
                    ],
                    [
                        -1.8113e-01,
                        7.7778e-01,
                        -3.0531e-16,
                        6.1062e-16,
                        1.9429e-16,
                        5.5511e-17,
                        6.9389e-17,
                        -1.3878e-17,
                    ],
                    [
                        -1.8461e-01,
                        -3.6238e-01,
                        6.8421e-01,
                        -6.1062e-16,
                        8.3267e-17,
                        2.4980e-16,
                        -2.1164e-16,
                        -4.5103e-17,
                    ],
                    [
                        -1.5290e-01,
                        -3.0015e-01,
                        -4.9820e-01,
                        6.0000e-01,
                        -5.5511e-17,
                        2.4286e-16,
                        -6.5919e-17,
                        -1.7347e-17,
                    ],
                    [
                        -1.0733e-01,
                        -2.1068e-01,
                        -3.4970e-01,
                        -6.0474e-01,
                        5.2381e-01,
                        7.6328e-17,
                        4.3021e-16,
                        -1.6306e-16,
                    ],
                    [
                        -6.4721e-02,
                        -1.2705e-01,
                        -2.1088e-01,
                        -3.6467e-01,
                        -6.8917e-01,
                        4.5455e-01,
                        -2.7062e-16,
                        6.2450e-17,
                    ],
                    [
                        -3.3650e-02,
                        -6.6054e-02,
                        -1.0964e-01,
                        -1.8960e-01,
                        -3.5832e-01,
                        -7.5625e-01,
                        3.9130e-01,
                        1.0408e-16,
                    ],
                    [
                        -1.5061e-02,
                        -2.9564e-02,
                        -4.9072e-02,
                        -8.4861e-02,
                        -1.6037e-01,
                        -3.3848e-01,
                        -8.0952e-01,
                        3.3333e-01,
                    ],
                ],
            ]
        )


        GBT_B_matrix = jnp.array(
            [
                [
                    6.6667e-01,
                    5.7735e-01,
                    1.4907e-01,
                    -1.9296e-08,
                    3.3959e-08,
                    -3.0288e-09,
                    -2.4869e-08,
                    1.3988e-08,
                ],
                [
                    4.0000e-01,
                    4.6188e-01,
                    2.5555e-01,
                    7.5593e-02,
                    9.5238e-03,
                    2.8427e-09,
                    -1.6951e-08,
                    9.3280e-09,
                ],
                [
                    2.8571e-01,
                    3.7115e-01,
                    2.6620e-01,
                    1.2599e-01,
                    3.8961e-02,
                    7.1788e-03,
                    6.0030e-04,
                    -2.6170e-10,
                ],
                [
                    2.2222e-01,
                    3.0792e-01,
                    2.5297e-01,
                    1.4966e-01,
                    6.5268e-02,
                    2.0616e-02,
                    4.4824e-03,
                    6.0187e-04,
                ],
                [
                    1.8182e-01,
                    2.6243e-01,
                    2.3455e-01,
                    1.5859e-01,
                    8.3916e-02,
                    3.4790e-02,
                    1.1124e-02,
                    2.6552e-03,
                ],
                [
                    1.5385e-01,
                    2.2840e-01,
                    2.1624e-01,
                    1.5991e-01,
                    9.5992e-02,
                    4.7166e-02,
                    1.8891e-02,
                    6.0876e-03,
                ],
                [
                    1.3333e-01,
                    2.0207e-01,
                    1.9949e-01,
                    1.5736e-01,
                    1.0330e-01,
                    5.7103e-02,
                    2.6604e-02,
                    1.0392e-02,
                ],
                [
                    1.1765e-01,
                    1.8113e-01,
                    1.8461e-01,
                    1.5290e-01,
                    1.0733e-01,
                    6.4721e-02,
                    3.3650e-02,
                    1.5061e-02,
                ],
            ]
        )

        i = int(step) - 1
        b = GBT_B_matrix[i]
        
        GBT_b = jnp.expand_dims(b, -1)
        GBT_a = GBT_A_matrix[i]
        
        assert jnp.allclose(GBT_B, GBT_b, rtol=1e-04, atol=1e-06), f"\nGBT_B\n{GBT_B}\n\nnot equal to\n\nGBT_b\n{GBT_b}"
        assert jnp.allclose(GBT_A, GBT_a, rtol=1e-04, atol=1e-06), f"\nGBT_A\n{GBT_A}\n\nnot equal to\n\nGBT_a\n{GBT_a}"
        
        # return GBT_a, GBT_b, C, D
        return GBT_A, GBT_B, C, D

    def collect_SSM_vars(self, A, B, C, D, f, t_step=0, alpha=0.5):
        """
        turns the continuos HiPPO matrix components into discrete ones

        Args:
            A (jnp.ndarray): matrix to be discretized
            B (jnp.ndarray): matrix to be discretized
            C (jnp.ndarray): matrix to be discretized
            D (jnp.ndarray): matrix to be discretized
            f (jnp.ndarray): input signal
            alpha (float, optional): used for determining which generalized bilinear transformation to use

        Returns:
            Ab (jnp.ndarray): discrete form of the HiPPO matrix
            Bb (jnp.ndarray): discrete form of the HiPPO matrix
            Cb (jnp.ndarray): discrete form of the HiPPO matrix
            Db (jnp.ndarray): discrete form of the HiPPO matrix
        """
        N = A.shape[0]

        if t_step == 0:
            L = f.shape[0]  # seq_L, 1
            assert (
                L == self.seq_L
            ), f"sequence length must match, currently {L} != {self.seq_L}"
            assert N == self.N, f"Order number must match, currently {N} != {self.N}"
        else:
            L = t_step
            assert t_step >= 1, f"time step must be greater than 0, currently {t_step}"
            assert N == self.N, f"Order number must match, currently {N} != {self.N}"
            
        Ab, Bb, Cb, Db = self.discretize(A, B, C, D, step=L, alpha=alpha)

        return Ab, Bb, Cb, Db

    def scan_SSM(self, Ab, Bb, Cb, Db, c_0, f):
        """
        This is for returning the discretized hidden state often needed for an RNN.
        Args:
            Ab (jnp.ndarray): the discretized A matrix
            Bb (jnp.ndarray): the discretized B matrix
            Cb (jnp.ndarray): the discretized C matrix
            f (jnp.ndarray): the input sequence
            c_0 (jnp.ndarray): the initial hidden state
        Returns:
            the next hidden state (aka coefficients representing the function, f(t))
        """

        def step(c_k_1, f_k):
            """
            Get descretized coefficients of the hidden state by applying HiPPO matrix to input sequence, u_k, and previous hidden state, x_k_1.
            Args:
                c_k_1: previous hidden state
                f_k: output from function f at, descritized, time step, k.

            Returns:
                c_k: current hidden state
                y_k: current output of hidden state applied to Cb (sorry for being vague, I just dont know yet)
            """
            print(f"Aa.shape: {Ab.shape}")
            print(f"c_k_1.shape: {c_k_1.shape}")
            print(f"Bb.shape: {Bb.shape}")
            print(f"f_k.shape: {f_k.shape}")
            part1 = (Ab @ c_k_1)
            print(f"part1.shape: {part1.shape}")
            part2 = jnp.expand_dims((Bb @ f_k), -1)
            print(f"part2.shape: {part2.shape}")

            c_k = part1 + part2
            y_k = Cb @ c_k  # + (Db.T @ f_k)

            return c_k, y_k
        
        return jax.lax.scan(step, c_0, f)

    def loop_SSM(self, Ab, Bb, Cb, Db, c_0, f):
        """
        This is for returning the discretized hidden state often needed for an RNN.
        Args:
            Ab (jnp.ndarray): the discretized A matrix
            Bb (jnp.ndarray): the discretized B matrix
            Cb (jnp.ndarray): the discretized C matrix
            f (jnp.ndarray): the input sequence
            c_0 (jnp.ndarray): the initial hidden state
        Returns:
            the next hidden state (aka coefficients representing the function, f(t))
        """

        def step(c_k_1, f_k):
            """
            Get descretized coefficients of the hidden state by applying HiPPO matrix to input sequence, u_k, and previous hidden state, x_k_1.
            Args:
                c_k_1: previous hidden state
                f_k: output from function f at, descritized, time step, k.

            Returns:
                c_k: current hidden state
                y_k: current output of hidden state applied to Cb (sorry for being vague, I just dont know yet)
            """

            part1 = Ab @ c_k_1
            # print(f"Ab @ c_k_1 shape:\n{(part1).shape}")
            # print(f"Ab @ c_k_1:\n{part1}")
            part2 = f_k.T * Bb.T
            # print(f"Bb * f_k shape:\n{part2.shape}")
            # print(f"Bb * f_k:\n{part2}")
            c_k = part1 + part2.T

            return c_k

        return step(c_0, f)

In [25]:
def test():
    # N = 256
    # L = 128

    N = 8
    L = 8

    np.random.seed(1701)
    x = np.array(
        [
            [0.3527],
            [0.6617],
            [0.2434],
            [0.6674],
            [1.2293],
            [0.0964],
            [-2.2756],
            [0.5618],
        ],
        dtype=np.float32,
    )
    # x = np.array(
    #     [
    #         [1.00000],
    #         [1.00000],
    #         [1.00000],
    #         [1.00000],
    #         [1.00000],
    #         [1.00000],
    #         [1.00000],
    #         [1.00000],
    #     ],
    #     dtype=np.float32,
    # )

    # x = torch.randn(L, 1)
    x = torch.tensor(x, dtype=torch.float32)

    print(f"THIS IS X:\n{x}")

    # ----------------------------------------------------------------------------------
    loss = nn.MSELoss()

    # ----------------------------------------------------------------------------------
    # ------------------------------ Test HiPPO LegT model -----------------------------
    # ----------------------------------------------------------------------------------
    print("\nTesting HiPPO LegT model")
    hippo_legt = HiPPO_LegT(N, dt=1.0 / L)

    c_k = hippo_legt(x)

    # print(f"Gu's Coeffiecients for LegT:\n{c_k}")
    # print(f"Gu's Coeffiecient shapes for LegT:\n{c_k.shape}")

    # z = hippo_legt.reconstruct(c_k)
    # print(f"Gu's Reconstruction for LegT:\n{z}")
    # print(f"Gu's Reconstruction shape for LegT:\n{z.shape}")

    # mse = loss(z[-1, 0, :L], x.squeeze(-1))
    # print(f"h-MSE shape:\n{mse}")
    # print(f"end of test for HiPPO LegT model")

    # ----------------------------------------------------------------------------------
    # ------------------------------ Test HiPPO LegS model -----------------------------
    # ----------------------------------------------------------------------------------
    print("\nTesting HiPPO LegS model")
    hippo_legs = HiPPO_LegS(N, max_length=L)  # The Gu's

    c_k = hippo_legs(x, fast=True)

    print(f"Gu's Coeffiecients  for LegS:\n{c_k}")
    print(f"Gu's Coeffiecient shapes for LegS:\n{c_k.shape}")

    # z = hippo_legs.reconstruct(c_k)

    # print(f"Gu's Reconstruction for LegS:\n{z}")
    # print(f"Gu's Reconstruction shape for LegS:\n{z.shape}")

    # print(y-z)
    print(f"end of test for HiPPO LegT model")

    # ----------------------------------------------------------------------------------
    # ------------------------------ Test Generic HiPPO model --------------------------
    # ----------------------------------------------------------------------------------
    the_measure = "legs"
    print(f"\nTesting BRYANS HiPPO-{the_measure} model")

    hippo_LegS_B = HiPPO(
        N=N,
        max_length=L,
        measure=the_measure,
        step=1.0 / L,
        GBT_alpha=0.5,
        seq_L=L,
        v="nv",
        lambda_n=1.0,
        fourier_type="fru",
        alpha=0.0,
        beta=1.0,
    )  # Bryan's

    x = jnp.asarray(x, dtype=jnp.float32)  # convert torch array to jax array
    print(f"input:\n{x}")
    print(f"input type:\n{type(x)}")

    print(f"Bryan's Coeffiecients for HiPPO-{the_measure}")

    LTI_bool = True
    c_k = None
    i = 0
    params = hippo_LegS_B.init(key2, f=x, init_state=c_k, t_step=i)
    if LTI_bool:
        for i in range(1, (x.shape[0] + 1)):
            # c_k = hippo_LegS_B.apply(params, f=x, init_state=c_k, t_step=i)
            c_k = hippo_LegS_B.apply(params, f=x, init_state=c_k, t_step=i)
            gu_c_k = jnp.array(
                        [
                            [
                                [
                                    2.3513e-01,
                                    2.0363e-01,
                                    5.2577e-02,
                                    -6.8057e-09,
                                    1.1977e-08,
                                    -1.0683e-09,
                                    -8.7712e-09,
                                    4.9337e-09,
                                ]
                            ],
                            [
                                [
                                    4.0576e-01,
                                    2.6490e-01,
                                    -3.3701e-02,
                                    -5.6627e-02,
                                    -7.1343e-03,
                                    -5.3342e-09,
                                    -1.3616e-08,
                                    7.8134e-09,
                                ]
                            ],
                            [
                                [
                                    3.5937e-01,
                                    7.2189e-02,
                                    -2.2545e-01,
                                    -8.6126e-02,
                                    2.5252e-02,
                                    1.1226e-02,
                                    9.3872e-04,
                                    -2.7088e-09,
                                ]
                            ],
                            [
                                [
                                    4.2782e-01,
                                    1.3816e-01,
                                    -6.5221e-02,
                                    1.5500e-01,
                                    1.5606e-01,
                                    2.6968e-02,
                                    -4.6502e-03,
                                    -1.5067e-03,
                                ]
                            ],
                            [
                                [
                                    5.7355e-01,
                                    3.0244e-01,
                                    8.4267e-02,
                                    1.8955e-01,
                                    7.2271e-07,
                                    -1.4422e-01,
                                    -7.2802e-02,
                                    -1.3106e-02,
                                ]
                            ],
                            [
                                [
                                    5.0014e-01,
                                    1.0705e-01,
                                    -1.8648e-01,
                                    -1.3038e-01,
                                    -2.6791e-01,
                                    -1.7971e-01,
                                    4.9144e-02,
                                    8.3597e-02,
                                ]
                            ],
                            [
                                [
                                    1.3004e-01,
                                    -4.8061e-01,
                                    -7.1708e-01,
                                    -4.4194e-01,
                                    -2.8475e-01,
                                    3.7278e-02,
                                    2.1051e-01,
                                    5.7035e-02,
                                ]
                            ],
                            [
                                [
                                    1.8084e-01,
                                    -2.9561e-01,
                                    -2.3676e-01,
                                    3.0236e-01,
                                    5.1647e-01,
                                    6.1457e-01,
                                    3.6490e-01,
                                    -2.4947e-02,
                                ]
                            ],
                        ]
                    )
            idx = i - 1
            g_c_k = gu_c_k[0][idx]
            gu = jnp.expand_dims(g_c_k, -1)
            #assert jnp.allclose(c_k[:3], gu[:3], rtol=1e-04, atol=1e-06), f"\nc_k\n{c_k}\n\n!=\n\ngu_c_k[0][{idx}]\n{gu}\n"
            print(f"c_k output:\n{c_k}\n\n")
    else:
        for _ in range(x.shape[0]):
            c_k = hippo_LegS_B.apply(params, f=x, init_state=c_k)
            #print(c_k)
            # print(f"Bryan's Coeffiecient shape for HiPPO-{the_measure}:\n{c_k.shape}")

    # y_legs = hippo_LegS_B.apply(
    #     {"params": params}, c_k, method=hippo_LegS_B.reconstruct
    # )

    # print(f"Bryan's Reconstruction for HiPPO-{the_measure}:\n{y_legs}")
    # print(f"Bryan's Reconstruction shape for HiPPO-{the_measure}:\n{y_legs.shape}")

    print(f"end of test for HiPPO-{the_measure} model")

In [26]:
test()


THIS IS X:
tensor([[ 0.3527],
        [ 0.6617],
        [ 0.2434],
        [ 0.6674],
        [ 1.2293],
        [ 0.0964],
        [-2.2756],
        [ 0.5618]])

Testing HiPPO LegT model

Testing HiPPO LegS model
u - Gu: tensor([[[ 2.3513e-01,  2.0363e-01,  5.2577e-02, -6.8057e-09,  1.1977e-08,
          -1.0683e-09, -8.7712e-09,  4.9337e-09],
         [ 2.6468e-01,  3.0563e-01,  1.6910e-01,  5.0020e-02,  6.3019e-03,
           1.8810e-09, -1.1216e-08,  6.1723e-09],
         [ 6.9543e-02,  9.0339e-02,  6.4793e-02,  3.0666e-02,  9.4831e-03,
           1.7473e-03,  1.4611e-04, -6.3698e-11],
         [ 1.4831e-01,  2.0551e-01,  1.6883e-01,  9.9882e-02,  4.3560e-02,
           1.3759e-02,  2.9916e-03,  4.0169e-04],
         [ 2.2351e-01,  3.2261e-01,  2.8834e-01,  1.9495e-01,  1.0316e-01,
           4.2767e-02,  1.3674e-02,  3.2641e-03],
         [ 1.4831e-02,  2.2018e-02,  2.0845e-02,  1.5415e-02,  9.2537e-03,
           4.5468e-03,  1.8211e-03,  5.8684e-04],
         [-3.0341e-01, -4.

  q = jnp.arange(N, dtype=jnp.float64)
  q = jnp.arange(
  self.eval_matrix = torch.from_numpy(


Aa.shape: (8, 8)
c_k_1.shape: (8, 1)
Bb.shape: (8, 1)
f_k.shape: (1,)
part1.shape: (8, 1)
part2.shape: (8, 1)
Aa.shape: (8, 8)
c_k_1.shape: (8, 1)
Bb.shape: (8, 1)
f_k.shape: (1,)
part1.shape: (8, 1)
part2.shape: (8, 1)


  q = jnp.arange(


c_k output:
[[-0.08678842]
 [ 1.123388  ]
 [ 1.4680651 ]
 [-0.16593416]
 [-0.10218585]
 [ 0.23969904]
 [-0.09888125]
 [-0.00468131]]


Aa.shape: (8, 8)
c_k_1.shape: (8, 1)
Bb.shape: (8, 1)
f_k.shape: (1,)
part1.shape: (8, 1)
part2.shape: (8, 1)
c_k output:
[[-0.14432177]
 [ 0.09979893]
 [ 1.1887181 ]
 [ 0.93654966]
 [-0.11057909]
 [-0.16303433]
 [ 0.11659439]
 [ 0.4009422 ]]


Aa.shape: (8, 8)
c_k_1.shape: (8, 1)
Bb.shape: (8, 1)
f_k.shape: (1,)
part1.shape: (8, 1)
part2.shape: (8, 1)
c_k output:
[[-0.07438603]
 [-0.16257854]
 [ 0.5743408 ]
 [ 1.1077945 ]
 [ 0.42409083]
 [-0.07744831]
 [-0.1627025 ]
 [ 0.01196916]]


Aa.shape: (8, 8)
c_k_1.shape: (8, 1)
Bb.shape: (8, 1)
f_k.shape: (1,)
part1.shape: (8, 1)
part2.shape: (8, 1)
c_k output:
[[-0.0212223 ]
 [-0.215584  ]
 [ 0.20531823]
 [ 0.88195187]
 [ 0.74549735]
 [ 0.1178245 ]
 [-0.01753642]
 [-0.22679192]]


Aa.shape: (8, 8)
c_k_1.shape: (8, 1)
Bb.shape: (8, 1)
f_k.shape: (1,)
part1.shape: (8, 1)
part2.shape: (8, 1)
c_k output:
[[ 0.015