# HiPPO Matrices
---

## Table of Contents
* [Loading In Necessary Packages](#load-packages)
* [Instantiate The HiPPO Matrix](#instantiate-the-hippo-matrix)
    * [Translated Legendre (LegT)](#translated-legendre-legt)
        * [LegT](#legt)
        * [LMU](#lmu)
    * [Translated Laguerre (LagT)](#translated-laguerre-lagt)
    * [Scaled Legendre (LegS)](#scaled-legendre-legs)
    * [Fourier Basis](#fourier-basis)
        * [Fourier Recurrent Unit (FRU)](#fourier-recurrent-unit-fru)
        * [Truncated Fourier (FouT)](#truncated-fourier-fout)
        * [Fourier With Decay (FourD)](#fourier-with-decay-fourd)
* [Gu's Linear Time Invariant (LTI) HiPPO Operator](#gus-hippo-legt-operator)
* [Gu's Scale invariant (LSI) HiPPO Operator](#gus-scale-invariant-hippo-legs-operator)
* [Implementation Of General HiPPO Operator](#implementation-of-general-hippo-operator)
* [Test Generalized Bilinear Transform and Zero Order Hold Matrices](#test-generalized-bilinear-transform-and-zero-order-hold-matrices)
    * [Testing Forward Euler on GBT matrices](#testing-forward-euler-transform-for-lti-and-lsi)
    * [Testing Backward Euler on GBT matrices](#testing-backward-euler-transform-for-lti-and-lsi-on-legs-matrices)
    * [Testing Bidirectional on GBT matrices](#testing-lti-and-lsi-operators-with-bidirectional-transform)
    * [Testing ZOH on GBT matrices](#testing-zoh-transform-for-lti-and-lsi-on-legs-matrices)
* [Testing HiPPO Operators](#test-hippo-operators)
    * [Testing Forward Euler on HiPPO Operators](#testing-lti-and-lsi-operators-with-forward-euler-transform)
    * [Testing Backward Euler on HiPPO Operators](#testing-lti-and-lsi-operators-with-backward-euler-transform)
    * [Testing Bidirectional on HiPPO Operators](#testing-lti-and-lsi-operators-with-bidirectional-transform)
    * [Testing ZOH on HiPPO Operators](#testing-lti-and-lsi-operators-with-zoh-transform)
---


## Load Packages

In [1]:
import os
import sys

module_path = os.path.abspath(os.path.join("../../../"))
if module_path not in sys.path:
    sys.path.append(module_path)

In [2]:
## import packages
import math

import jax
import jax.numpy as jnp
import requests
from flax import linen as jnn
from jax.nn.initializers import lecun_normal, uniform
from jax.numpy.linalg import eig, inv, matrix_power
from jax.scipy.signal import convolve

from jaxtyping import Array, Float, Float16, Float32, Float64
from typing import Callable, List, Optional, Tuple, Any, Union

from scipy import linalg as la
from scipy import signal
from scipy import special as ss

from src.data.process import moving_window, rolling_window

# import modules
from src.models.hippo.gu_transition import GuTransMatrix
from src.models.hippo.transition import TransMatrix
from src.models.hippo.unroll import (
    basis,
    measure,
    variable_unroll_matrix,
    variable_unroll_matrix_sequential,
)
from src.utils.ops import genlaguerre

print(jax.devices())
print(f"The Device: {jax.lib.xla_bridge.get_backend().platform}")

[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]
The Device: gpu


In [3]:
from functools import partial
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from einops import rearrange, reduce, repeat

print(f"MPS enabled: {torch.backends.mps.is_available()}")

MPS enabled: False


In [4]:
torch.set_printoptions(linewidth=150)
np.set_printoptions(linewidth=150)
jnp.set_printoptions(linewidth=150)

In [5]:
seed = 1701
key = jax.random.PRNGKey(seed)

In [6]:
num_copies = 5
subkeys = jax.random.split(key, num=num_copies)
key = subkeys[0]

## Gu's HiPPO Linear Time Invariant Operator

In [7]:
class gu_HiPPO_LTI(nn.Module):
    """Linear time invariant x' = Ax + Bu"""

    def __init__(
        self,
        N,
        method="legt",
        dt=1.0,
        T=1.0,
        discretization=0.5,
        lambda_n=1.0,
        alpha=0.0,
        beta=1.0,
        c=0.0,
    ):
        """
        N: the order of the HiPPO projection
        dt: discretization step size - should be roughly inverse to the length of the sequence
        """
        super().__init__()

        self.method = method
        self.N = N
        self.dt = dt
        self.T = T
        self.c = c

        matrices = GuTransMatrix(
            N=N, measure=method, lambda_n=lambda_n, alpha=alpha, beta=beta
        )
        A = np.asarray(matrices.A, dtype=np.float32)
        B = np.asarray(matrices.B, dtype=np.float32)
        # A, B = transition(method, N)
        A = A + (np.eye(N) * c)
        self.A = A
        self.B = B.squeeze(-1)
        self.measure_fn = measure(method)

        C = np.ones((1, N))
        D = np.zeros((1,))
        if type(discretization) in [float, int]:
            dA, dB, _, _, _ = signal.cont2discrete(
                (A, B, C, D), dt=dt, method="gbt", alpha=discretization
            )
        else:
            dA, dB, _, _, _ = signal.cont2discrete((A, B, C, D), dt=dt, method="zoh")

        dB = dB.squeeze(-1)

        self.dA = torch.Tensor(dA.copy())  # (N, N)
        self.dB = torch.Tensor(dB.copy())  # (N, )

        self.vals = np.arange(0.0, T, dt)
        self.eval_matrix = basis(self.method, self.N, self.vals, c=self.c)  # (T/dt, N)
        self.measure = measure(self.method)(self.vals)

    def forward(self, inputs, fast=True):
        """
        inputs : (length, ...)
        output : (length, ..., N) where N is the order of the HiPPO projection
        """

        inputs = inputs.unsqueeze(-1)
        u = inputs * self.dB  # (length, ..., N)

        if fast:
            dA = repeat(self.dA, "m n -> l m n", l=u.size(0))
            return variable_unroll_matrix(dA, u)

        c = torch.zeros(u.shape[1:]).to(inputs)
        cs = []
        for f in inputs:

            # print(f"dA shape:\n{self.dA.shape}")
            # print(f"dA:\n{self.dA}")

            # print(f"c shape:\n{c.shape}")
            # print(f"c:\n{c}")

            # print(f"dB shape:\n{self.dB.shape}")
            # print(f"dB:\n{self.dB}")

            # print(f"f shape:\n{f.shape}")
            # print(f"f:\n{f}")

            part1 = F.linear(c, self.dA)
            part2 = self.dB * f

            c = part1 + part2

            # print(f"part1 shape:\n{part1.shape}")
            # print(f"part1 :\n{part1}")

            # print(f"part2 shape:\n{part2.shape}")
            # print(f"part2:\n{part2}")

            cs.append(c)
        return torch.stack(cs, dim=0)

    def reconstruct(
        self, c, evals=None
    ):  # TODO take in a times array for reconstruction
        """
        c: (..., N,) HiPPO coefficients (same as x(t) in S4 notation)
        output: (..., L,)
        """
        if evals is not None:
            eval_matrix = basis(self.method, self.N, evals)
        else:
            eval_matrix = self.eval_matrix

        m = self.measure[self.measure != 0.0]

        c = c.unsqueeze(-1)
        y = eval_matrix.to(c) @ c
        return y.squeeze(-1).flip(-1)

## Navigation To Table Of Contents
---
* [Table Of Contents](#table-of-contents)
* [Loading In Necessary Packages](#load-packages)
* [Instantiate The HiPPO Matrix](#instantiate-the-hippo-matrix)
* [Gu's Linear Time Invariant (LTI) HiPPO Operator](#gus-hippo-legt-operator)
* [Gu's Scale invariant (LSI) HiPPO Operator](#gus-scale-invariant-hippo-legs-operator)
* [Implementation Of General HiPPO Operator](#implementation-of-general-hippo-operator)
* [Test Generalized Bilinear Transform and Zero Order Hold Matrices](#test-generalized-bilinear-transform-and-zero-order-hold-matrices)
* [Testing HiPPO Operators](#test-hippo-operators)
---

## Gu's Scale invariant HiPPO LegS Operator

In [8]:
class gu_HiPPO_LSI(nn.Module):
    """Vanilla HiPPO-LegS model (scale invariant instead of time invariant)"""

    def __init__(
        self,
        N,
        method="legs",
        max_length=1024,
        discretization=0.5,
        lambda_n=1.0,
        alpha=0.0,
        beta=1.0,
    ):
        """
        max_length: maximum sequence length
        """
        super().__init__()
        self.N = N
        matrices = GuTransMatrix(
            N=N, measure=method, lambda_n=lambda_n, alpha=alpha, beta=beta
        )
        A = np.asarray(matrices.A, dtype=np.float32)
        B = np.asarray(matrices.B, dtype=np.float32)
        # A, B = transition(method, N)
        B = B.squeeze(-1)
        A_stacked = np.empty((max_length, N, N), dtype=A.dtype)
        B_stacked = np.empty((max_length, N), dtype=B.dtype)
        for t in range(1, max_length + 1):
            At = A / t
            Bt = B / t
            if discretization == 0.0:  # forward
                A_stacked[t - 1] = np.eye(N) + At
                B_stacked[t - 1] = Bt
            elif discretization == 1.0:  # backward
                A_stacked[t - 1] = la.solve_triangular(
                    np.eye(N) - At, np.eye(N), lower=True
                )
                B_stacked[t - 1] = la.solve_triangular(np.eye(N) - At, Bt, lower=True)
            elif discretization == 0.5:  # bilinear
                # A_stacked[t - 1] = la.solve_triangular(
                #     np.eye(N) - At / 2, np.eye(N) + At / 2, lower=True
                # )
                # B_stacked[t - 1] = la.solve_triangular(
                #     np.eye(N) - At / 2, Bt, lower=True
                # )
                alpha = 0.5
                A_stacked[t - 1] = np.linalg.lstsq(
                    np.eye(N) - (At * alpha), np.eye(N) + (At * alpha), rcond=None
                )[
                    0
                ]  # TODO: Referencing this: https://stackoverflow.com/questions/64527098/numpy-linalg-linalgerror-singular-matrix-error-when-trying-to-solve
                B_stacked[t - 1] = np.linalg.lstsq(
                    np.eye(N) - (At * alpha), Bt, rcond=None
                )[0]
            else:  # ZOH
                A_stacked[t - 1] = la.expm(A * (math.log(t + 1) - math.log(t)))
                # A_stacked[t - 1] = la.expm(At)
                B_stacked[t - 1] = la.solve_triangular(
                    A, A_stacked[t - 1] @ B - B, lower=True
                )

                # A_stacked[t - 1] = la.expm(At)
                # B_stacked[t - 1] = la.inv(A) @ (la.expm(At) - np.eye(A.shape[0])) @ B

        # self.register_buffer('A_stacked', torch.Tensor(A_stacked)) # (max_length, N, N)
        # self.register_buffer('B_stacked', torch.Tensor(B_stacked)) # (max_length, N)

        self.A_stacked = torch.Tensor(A_stacked.copy())  # (max_length, N, N)
        self.B_stacked = torch.Tensor(B_stacked.copy())  # (max_length, N)

        vals = np.linspace(0.0, 1.0, max_length)
        self.eval_matrix = torch.from_numpy(
            np.asarray(
                ((B[:, None] * ss.eval_legendre(np.arange(N)[:, None], 2 * vals - 1)).T)
            )
        )

    def forward(self, inputs, fast=True):
        """
        inputs : (length, ...)
        output : (length, ..., N) where N is the order of the HiPPO projection
        """

        L = inputs.shape[0]

        inputs = inputs.unsqueeze(-1)
        u = torch.transpose(inputs, 0, -2)
        u = u * self.B_stacked[:L]
        # print(f"Gu - u * self.B_stacked[:L]: {u}")
        u = torch.transpose(u, 0, -2)  # (length, ..., N)

        if fast:
            result = variable_unroll_matrix(self.A_stacked[:L], u)
            return result

        c = torch.zeros(u.shape[1:]).to(inputs)
        cs = []
        for t, f in enumerate(inputs):
            # print(f"\n--------------step {t}----------------")
            # print(f"self.A_stacked[{t}] shape:\n{self.A_stacked[t].shape}")
            # print(f"self.A_stacked[{t}]:\n{self.A_stacked[t]}")

            # print(f"c shape:\n{c.shape}")
            # print(f"c:\n{c}")

            # print(f"self.B_stacked[{t}] shape:\n{self.B_stacked[t].shape}")
            # print(f"self.B_stacked[{t}]:\n{self.B_stacked[t]}")

            # print(f"f shape:\n{f.shape}")
            # print(f"f:\n{f}")

            part1 = F.linear(c, self.A_stacked[t])
            part2 = self.B_stacked[t] * f

            c = part1 + part2

            # print(f"part1 - {t} - shape:\n{part1.shape}")
            # print(f"part1 - {t} -:\n{part1}")

            # print(f"part2 - {t} - shape:\n{part2.shape}")
            # print(f"part2 - {t} -:\n{part2}")

            cs.append(c)
        return torch.stack(cs, dim=0)

    def reconstruct(self, c):
        a = self.eval_matrix.to(c) @ c.unsqueeze(-1)
        return a

## Navigation To Table Of Contents
---
* [Table Of Contents](#table-of-contents)
* [Loading In Necessary Packages](#load-packages)
* [Instantiate The HiPPO Matrix](#instantiate-the-hippo-matrix)
* [Gu's Linear Time Invariant (LTI) HiPPO Operator](#gus-hippo-legt-operator)
* [Gu's Scale invariant (LSI) HiPPO Operator](#gus-scale-invariant-hippo-legs-operator)
* [Implementation Of General HiPPO Operator](#implementation-of-general-hippo-operator)
* [Test Generalized Bilinear Transform and Zero Order Hold Matrices](#test-generalized-bilinear-transform-and-zero-order-hold-matrices)
* [Testing HiPPO Operators](#test-hippo-operators)
---

## Implementation Of General HiPPO Operator

In [None]:
class HiPPO(jnn.Module):
    
    def __init__(self):
        pass
    
    def __call__(self, f, init_state=None):
        return NotImplementedError

    def discretize(self, A, B, step, alpha=0.5, dtype=jnp.float32):
        return NotImplementedError
    
    def recurrence(self, A, B, c_0, f, dtype=jnp.float32):
        return NotImplementedError

In [9]:
class HiPPOLSI(jnn.Module):
    """
    class that constructs HiPPO model using the defined measure.

    Args:

        max_length (int):
            maximum sequence length to be input

        step_size (float):
            step size used for descretization

        N (int):
            order of the HiPPO projection, aka the number of coefficients to describe the matrix

        lambda_n (float):
            value associated with the tilt of legt
            - 1: tilt on legt
            - \sqrt(2n+1)(-1)^{N}: tilt associated with the legendre memory unit (LMU)

        alpha (float):
            The order of the Laguerre basis.

        beta (float):
            The scale of the Laguerre basis.

        GBT_alpha (float):
            represents which descretization transformation to use based off the alpha value

        measure (str):
            the measure used to define which way to instantiate the HiPPO matrix

        s_t (str):
            choice between LSI and LTI systems
            - "lsi"
            - "lti"

        dtype (jnp.float):
            represents the float precision of the class

        verbose (bool):
            shows the rolled out coefficients over time/scale

    """

    N: int
    max_length: int = 1024
    step_size: float = 1.0
    lambda_n: float = 1.0
    alpha: float = 0.0
    beta: float = 1.0
    GBT_alpha: float = 0.5
    measure: str = "legs"
    dtype: Any = jnp.float32
    verbose: bool = False

    def setup(self):
        matrices = TransMatrix(
            N=self.N,
            measure=self.measure,
            lambda_n=self.lambda_n,
            alpha=self.alpha,
            beta=self.beta,
            dtype=self.dtype,
        )

        self.GBT_A_list, self.GBT_B_list = self.temporal_GBT(
            matrices.A, matrices.B, dtype=self.dtype
        )

        vals = jnp.linspace(0.0, 1.0, self.max_length)
        self.eval_matrix = (
            (matrices.B)[:, None]
            * ss.eval_legendre(jnp.arange(self.N)[:, None], 2 * vals - 1)
        ).T

    def __call__(
        self,
        f: Float[Array, "batch seq_len input_size"],
        init_state: Optional[Float[Array, "batch input_size N"]] = None,
    ) -> Float[Array, "batch input_size N"]:

        if init_state is None:
            init_state = jnp.zeros((f.shape[0], 1, self.N))

        c_k = self.recurrence(
            A=self.GBT_A_list,
            B=self.GBT_B_list,
            c_0=init_state,
            f=f,
            dtype=self.dtype,
        )
        c_k = jnp.stack(c_k, axis=0)

        return c_k

    def temporal_GBT(
        self, A: Float[Array, "N_a N_b"], B: Float[Array, "N 1"], dtype=jnp.float32
    ) -> Tuple[List[Float[Array, "N_a N_b"]], List[Float[Array, "N 1"]]]:
        """
        Creates the list of discretized GBT matrices for the given step size
        """
        GBT_a_list = []
        GBT_b_list = []
        for i in range(1, self.max_length + 1):
            GBT_A, GBT_B = self.discretize(
                A, B, step=i, alpha=self.GBT_alpha, dtype=dtype
            )
            GBT_a_list.append(GBT_A)
            GBT_b_list.append(GBT_B)

        return GBT_a_list, GBT_b_list

    def discretize(
        self,
        A: Float[Array, "N_a N_b"],
        B: Float[Array, "N 1"],
        step: float,
        alpha: Union[float, str] = 0.5,
        dtype=jnp.float32,
    ) -> Tuple[Float[Array, "N_a N_b"], Float[Array, "N 1"]]:
        """
        function used for discretizing the HiPPO matrix

        Args:
            A (jnp.ndarray):
                shape: (N, N)
                matrix to be discretized

            B (jnp.ndarray):
                shape: (N, 1)
                matrix to be discretized

            step (float):
                step size used for discretization

            alpha (float, optional):
                used for determining which generalized bilinear transformation to use
                - forward Euler corresponds to α = 0,
                - backward Euler corresponds to α = 1,
                - bilinear corresponds to α = 0.5,
                - Zero-order Hold corresponds to α > 1
        """
        if alpha <= 1:
            assert (
                alpha == 0.0 or alpha == 0.5 or alpha == 1.0
            ), "alpha must be 0, 0.5, or 1"
        else:
            assert (
                alpha > 1 or type(alpha) == str
            ), "alpha must be greater than 1 for zero-order hold"
            if type(alpha) == str:
                assert (
                    alpha == "zoh"
                ), "if alpha is a string, it must be defined as 'zoh' for zero-order hold"

        I = jnp.eye(A.shape[0])

        if alpha <= 1:  # Generalized Bilinear Transformation
            step_size = 1 / step
            part1 = I - (step_size * alpha * A)
            part2 = I + (step_size * (1 - alpha) * A)

            GBT_A = jnp.linalg.lstsq(part1, part2, rcond=None)[0]
            GBT_B = jnp.linalg.lstsq(part1, (step_size * B), rcond=None)[0]

        else:  # Zero-order Hold
            # refer to this for why this works
            # https://en.wikipedia.org/wiki/Discretization#:~:text=A%20clever%20trick%20to%20compute%20Ad%20and%20Bd%20in%20one%20step%20is%20by%20utilizing%20the%20following%20property

            n = A.shape[0]
            b_n = B.shape[1]
            A_B_square = jnp.block(
                [[A, B], [jnp.zeros((b_n, n)), jnp.zeros((b_n, b_n))]]
            )
            if self.s_t == "lsi":
                A_B = jax.scipy.linalg.expm(
                    A_B_square * (math.log(step + self.step_size) - math.log(step))
                )
            else:
                A_B = jax.scipy.linalg.expm(A_B_square * self.step_size)

            GBT_A = A_B[0:n, 0:n]
            GBT_B = A_B[0:-b_n, -b_n:]

        return GBT_A.astype(dtype), GBT_B.astype(dtype)

    def recurrence(
        self,
        A: Float[Array, "N_a N_b"],
        B: Float[Array, "N 1"],
        c_0: Float[Array, "batch input_size N"],
        f: Float[Array, "batch seq_len input_size"],
        dtype=jnp.float32,
    ) -> Union[
        List[Float[Array, "batch input_size N"]], Float[Array, "batch input_size N"]
    ]:
        """
        This is for returning the discretized hidden state often needed for an RNN.
        Args:
            A (jnp.ndarray):
                shape: (N, N)
                the discretized A matrix

            B (jnp.ndarray):
                shape: (N, 1)
                the discretized B matrix

            c_0 (jnp.ndarray):
                shape: (batch size, input length, N)
                the initial hidden state

            f (jnp.ndarray):
                shape: (sequence length, 1)
                the input sequence


        Returns:
            the next hidden state (aka coefficients representing the function, f(t))
        """

        c_s = []

        c_k = c_0.copy()
        for i in range(f.shape[1]):
            c_k = jax.vmap(self.step, in_axes=(None, None, 0, 0))(
                A[i], B[i], c_k, f[:, i, :]
            )
            c_s.append((c_k.copy()).astype(dtype))

        if self.verbose:
            return c_s  # list of hidden states
        else:
            return c_s[-1]  # last hidden state

    def step(
        self,
        Ad: Float[Array, "N_a N_b"],
        Bd: Float[Array, "N 1"],
        c_k_i: Float[Array, "batch input_size N"],
        f_k: Float[Array, "batch seq_len input_size"],
    ) -> Float[Array, "batch input_size N"]:
        """
        Get descretized coefficients of the hidden state by applying HiPPO matrix to input sequence, u_k, and previous hidden state, x_k_1.
        Args:
            c_k_i:
                shape: (input length, N)
                previous hidden state

            f_k:
                shape: (1, )
                output from function f at, descritized, time step, k.

        Returns:
            c_k: current hidden state
            y_k: current output of hidden state applied to Cb (sorry for being vague, I just dont know yet)
        """

        c_k = (jnp.dot(c_k_i, Ad.T)) + (Bd.T * f_k)

        return c_k

    def reconstruct(
        self, c: Float[Array, "batch input_size N"]
    ) -> Float[Array, "batch seq_len input_size"]:
        y = self.eval_matrix @ c

        return y

In [10]:
class HiPPOLTI(jnn.Module):
    """
    class that constructs HiPPO model using the defined measure.

    Args:

        max_length (int):
            maximum sequence length to be input

        step_size (float):
            step size used for descretization

        N (int):
            order of the HiPPO projection, aka the number of coefficients to describe the matrix

        lambda_n (float):
            value associated with the tilt of legt
            - 1: tilt on legt
            - \sqrt(2n+1)(-1)^{N}: tilt associated with the legendre memory unit (LMU)

        alpha (float):
            The order of the Laguerre basis.

        beta (float):
            The scale of the Laguerre basis.

        GBT_alpha (float):
            represents which descretization transformation to use based off the alpha value

        measure (str):
            the measure used to define which way to instantiate the HiPPO matrix

        s_t (str):
            choice between LSI and LTI systems
            - "lsi"
            - "lti"

        dtype (jnp.float):
            represents the float precision of the class

        verbose (bool):
            shows the rolled out coefficients over time/scale

    """

    N: int
    step_size: float = 1.0
    lambda_n: float = 1.0
    alpha: float = 0.0
    beta: float = 1.0
    GBT_alpha: float = 0.5
    measure: str = "legs"
    basis_size: float = 1.0
    dtype: Any = jnp.float32
    verbose: bool = False

    def setup(self):
        matrices = TransMatrix(
            N=self.N,
            measure=self.measure,
            lambda_n=self.lambda_n,
            alpha=self.alpha,
            beta=self.beta,
            dtype=self.dtype,
        )

        self.Ad, self.Bd = self.discretize(
            A=matrices.A,
            B=matrices.B,
            step=self.step_size,
            alpha=self.GBT_alpha,
            dtype=self.dtype,
        )

        self.vals = jnp.arange(0.0, self.basis_size, self.step_size)
        jax.debug.print("self.vals shape:\n{x3}", x3=self.vals.shape)
        jax.debug.print("self.vals:\n{x4}", x4=self.vals)
        # self.eval_matrix = jax.vmap(self.basis, in_axes=(None, None, 0, None))(self.measure, self.N, self.vals, 0.0)
        self.eval_matrix = self.basis(
            self.measure, self.N, self.vals, c=0.0
        )  # (T/dt, N)
        # jax.debug.print("eval_matrix shape:\n{x5}", x5=self.eval_matrix.shape)

    def __call__(
        self,
        f: Float[Array, "batch seq_len input_size"],
        init_state: Optional[Float[Array, "batch input_size N"]] = None,
    ) -> Float[Array, "batch input_size N"]:

        if init_state is None:
            init_state = jnp.zeros((f.shape[0], 1, self.N))

        c_k = self.recurrence(
            Ad=self.Ad,
            Bd=self.Bd,
            c_0=init_state,
            f=f,
            dtype=self.dtype,
        )

        return c_k

    def discretize(
        self,
        A: Float[Array, "N_a N_b"],
        B: Float[Array, "N 1"],
        step: float,
        alpha: Union[float, str] = 0.5,
        dtype=jnp.float32,
    ) -> Tuple[Float[Array, "N_a N_b"], Float[Array, "N 1"]]:
        """
        function used for discretizing the HiPPO matrix

        Args:
            A (jnp.ndarray):
                shape: (N, N)
                matrix to be discretized

            B (jnp.ndarray):
                shape: (N, 1)
                matrix to be discretized

            C (jnp.ndarray):
                shape: (N, 1)
                matrix to be discretized

            D (jnp.ndarray):
                shape: (1,)
                matrix to be discretized

            step (float):
                step size used for discretization

            alpha (float, optional):
                used for determining which generalized bilinear transformation to use
                - forward Euler corresponds to α = 0,
                - backward Euler corresponds to α = 1,
                - bilinear corresponds to α = 0.5,
                - Zero-order Hold corresponds to α > 1
        """
        if alpha <= 1:
            assert (
                alpha == 0.0 or alpha == 0.5 or alpha == 1.0
            ), "alpha must be 0, 0.5, or 1"
        else:
            assert (
                alpha > 1 or type(alpha) == str
            ), "alpha must be greater than 1 for zero-order hold"
            if type(alpha) == str:
                assert (
                    alpha == "zoh"
                ), "if alpha is a string, it must be defined as 'zoh' for zero-order hold"

        I = jnp.eye(A.shape[0])

        if alpha <= 1:  # Generalized Bilinear Transformation
            step_size = 1 / step
            part1 = I - (step_size * alpha * A)
            part2 = I + (step_size * (1 - alpha) * A)

            GBT_A = jnp.linalg.lstsq(part1, part2, rcond=None)[0]
            GBT_B = jnp.linalg.lstsq(part1, (step_size * B), rcond=None)[0]

        else:  # Zero-order Hold
            # refer to this for why this works
            # https://en.wikipedia.org/wiki/Discretization#:~:text=A%20clever%20trick%20to%20compute%20Ad%20and%20Bd%20in%20one%20step%20is%20by%20utilizing%20the%20following%20property

            n = A.shape[0]
            b_n = B.shape[1]
            A_B_square = jnp.block(
                [[A, B], [jnp.zeros((b_n, n)), jnp.zeros((b_n, b_n))]]
            )
            if self.s_t == "lsi":
                A_B = jax.scipy.linalg.expm(
                    A_B_square * (math.log(step + self.step_size) - math.log(step))
                )
            else:
                A_B = jax.scipy.linalg.expm(A_B_square * self.step_size)

            GBT_A = A_B[0:n, 0:n]
            GBT_B = A_B[0:-b_n, -b_n:]

        return GBT_A.astype(dtype), GBT_B.astype(dtype)

    def recurrence(
        self,
        A: Float[Array, "N_a N_b"],
        B: Float[Array, "N 1"],
        c_0: Float[Array, "batch input_size N"],
        f: Float[Array, "batch seq_len input_size"],
        dtype=jnp.float32,
    ) -> Union[
        List[Float[Array, "batch input_size N"]], Float[Array, "batch input_size N"]
    ]:
        """
        This is for returning the discretized hidden state often needed for an RNN.
        Args:
            Ad (jnp.ndarray):
                shape: (N, N)
                the discretized A matrix

            Bd (jnp.ndarray):
                shape: (N, 1)
                the discretized B matrix

            f (jnp.ndarray):
                shape: (sequence length, 1)
                the input sequence

            c_0 (jnp.ndarray):
                shape: (batch size, input length, N)
                the initial hidden state

        Returns:
            the next hidden state (aka coefficients representing the function, f(t))
        """

        def step(
            c_k_i: Float32[Array, "batch input_size N"],
            f_k: Float32[Array, "batch seq_len input_size"],
        ):
            """
            Get descretized coefficients of the hidden state by applying HiPPO matrix to input sequence, u_k, and previous hidden state, x_k_1.
            Args:
                c_k_i:
                    shape: (input length, N)
                    previous hidden state

                f_k:
                    shape: (1, )
                    output from function f at, descritized, time step, k.

            Returns:
                c_k: current hidden state
                y_k: current output of hidden state applied to Cb (sorry for being vague, I just dont know yet)
            """

            c_k = (jnp.dot(c_k_i, Ad.T)) + (Bd.T * f_k)

            return c_k, c_k

        c_k, c_s = jax.vmap(jax.lax.scan, in_axes=(None, 0, 0))(step, c_0, f)

        if self.verbose:
            return c_s
        else:
            return c_k

    def measure_fn(self, method, c=0.0):

        if method == "legs":
            fn = lambda x: jnp.heaviside(x, 1.0) * jnp.exp(-x)
        elif method in ["legt", "lmu"]:
            fn = lambda x: jnp.heaviside(x, 0.0) * jnp.heaviside(1.0 - x, 0.0)
        elif method == "lagt":
            fn = lambda x: jnp.heaviside(x, 1.0) * jnp.exp(-x)
        elif method in ["fourier", "fru", "fout", "foud"]:
            fn = lambda x: jnp.heaviside(x, 1.0) * jnp.heaviside(1.0 - x, 1.0)
        else:
            raise NotImplementedError

        fn_tilted = lambda x: jnp.exp(c * x) * fn(x)

        return fn_tilted

    def basis(self, method, N, vals, c=0.0, truncate_measure=True):
        """
        vals: list of times (forward in time)
        returns: shape (T, N) where T is length of vals
        """
        eval_matrix = None
        if method == "legs":
            zero_N = self.N - 1
            jax.debug.print("vals shape:\n{x1}", x1=vals.shape)
            jax.debug.print("vals:\n{x2}", x2=vals)
            x = 2 * vals - 1
            # x = jnp.expand_dims(x, axis=0)
            jax.debug.print("x shape:\n{x1}", x1=x.shape)
            jax.debug.print("x:\n{x2}", x2=x)
            # x = jnp.arccos(x)
            jax.debug.print("x shape:\n{x1}", x1=x.shape)
            jax.debug.print("x:\n{x2}", x2=x)
            eval_matrix = jnp.real(
                jax.scipy.special.sph_harm(
                    m=jnp.array([0]), n=jnp.array([zero_N]), theta=x, phi=x
                )
            )
            # ri_eval_matrix = jax.vmap(jax.scipy.special.sph_harm, in_axes=(None, None, 0, 0))(jnp.array([0]), jnp.array([zero_N]), x, x)
            # eval_matrix = jnp.real(ri_eval_matrix)
            jax.debug.print("eval_matrix shape:\n{x3}", x3=eval_matrix.shape)
            jax.debug.print("eval_matrix:\n{x3}", x3=eval_matrix)
            # eval_matrix = ss.eval_legendre(zero_N, x).T
            # jax.debug.print("eval_matrix shape:\n{x3}", x3=eval_matrix.shape)
            # jax.debug.print("eval_matrix:\n{x3}", x3=eval_matrix)
            eval_matrix *= (2 * jnp.arange(N) + 1) ** 0.5 * (-1) ** jnp.arange(N)
            jax.debug.print("eval_matrix shape:\n{x3}", x3=eval_matrix.shape)
            jax.debug.print("eval_matrix:\n{x3}", x3=eval_matrix)

        elif method in ["legt", "lmu"]:
            zero_N = self.N - 1
            x = 1 - 2 * vals
            # x = jnp.expand_dims(x, axis=0)
            eval_matrix = jnp.real(
                jax.scipy.special.sph_harm(
                    m=jnp.array([0]), n=jnp.array([zero_N]), theta=x, phi=x
                )
            )
            # eval_matrix = jax.scipy.special.lpmn_values(
            #     m=0, n=zero_N, z=x, is_normalized=False
            # ).T # Legendre polynomials are special cases of legendre functions, in this case where m=0
            # eval_matrix = ss.eval_legendre(zero_N, x).T
            eval_matrix *= (2 * jnp.arange(N) + 1) ** 0.5 * (-1) ** jnp.arange(N)

        elif method == "lagt":
            vals = vals[::-1]
            # eval_matrix = ss.eval_genlaguerre(np.arange(N)[:, None], 0, vals)
            zero_N = self.N - 1
            eval_matrix = genlaguerre(zero_N, 0, vals)
            eval_matrix = eval_matrix * jnp.exp(-vals / 2)
            eval_matrix = eval_matrix.T

        elif method in ["fourier", "fru", "fout", "foud"]:
            cos = 2**0.5 * jnp.cos(
                2 * jnp.pi * jnp.arange(N // 2)[:, None] * (vals)
            )  # (N/2, T/dt)
            sin = 2**0.5 * jnp.sin(
                2 * jnp.pi * jnp.arange(N // 2)[:, None] * (vals)
            )  # (N/2, T/dt)
            cos[0] /= 2**0.5
            eval_matrix = jnp.stack([cos.T, sin.T], axis=-1).reshape(-1, N)  # (T/dt, N)
        #     print("eval_matrix shape", eval_matrix.shape)

        if truncate_measure:
            jax.debug.print("eval_matrix shape:\n{x3}", x3=eval_matrix.shape)
            jax.debug.print("eval_matrix:\n{x3}", x3=eval_matrix)
            tilting_fn = self.measure_fn(method, c=c)
            val = tilting_fn(vals)
            jax.debug.print("val shape:\n{x3}", x3=val.shape)
            jax.debug.print("val:\n{x3}", x3=val)
            eval_matrix = jnp.where(val == 0.0, 0.0, eval_matrix)

        p = eval_matrix * jnp.exp(-c * vals)[:, None]  # [::-1, None]

        return p

    def reconstruct(
        self, c: Float[Array, "batch input_size N"], evals=None
    ):  # TODO take in a times array for reconstruction
        """
        c: (..., N,) HiPPO coefficients (same as x(t) in S4 notation)
        output: (..., L,)
        """
        if evals is not None:
            eval_matrix = self.basis(self.measure, self.N, evals)
        else:
            eval_matrix = self.eval_matrix

        y = eval_matrix @ c

        return y

In [11]:
class HiPPO(jnn.Module):

    N: int
    max_length: int = 1024
    step_size: float = 1.0
    basis_size: float = 1.0
    lambda_n: float = 1.0
    alpha: float = 0.0
    beta: float = 1.0
    GBT_alpha: float = 0.5
    measure: str = "legs"
    s_t: str = "lti"
    truncate_measure: bool = True
    dtype: Any = jnp.float32
    verbose: bool = False

    def setup(self) -> None:

        jax.debug.print("In The HiPPO Setup")

        # Define the encoder that performs the polynomial projections with user specified matrix initialization
        if self.s_t == "lsi":
            jax.debug.print("Using LSI encoder")
            self.encoder = HiPPOLSI(
                N=self.N,
                max_length=self.max_length,
                step_size=self.step_size,
                lambda_n=self.lambda_n,
                alpha=self.alpha,
                beta=self.beta,
                GBT_alpha=self.GBT_alpha,
                measure=self.measure,
                dtype=self.dtype,
                verbose=self.verbose,
            )
        elif self.s_t == "lti":
            jax.debug.print("Using LTI encoder")
            self.encoder = HiPPOLTI(
                N=self.N,
                step_size=self.step_size,
                lambda_n=self.lambda_n,
                alpha=self.alpha,
                beta=self.beta,
                GBT_alpha=self.GBT_alpha,
                measure=self.measure,
                basis_size=self.basis_size,
                dtype=self.dtype,
                verbose=self.verbose,
            )
        else:
            raise ValueError(
                f"s_t must be either 'lsi' or 'lti'. s_t is currently set to: {self.s_t}"
            )

    def __call__(self, x, init_state=None):

        # Apply the polynomial projections to the input
        hidden = self.encoder(x, init_state=init_state)

        # Decode the polynomial projections to the output space through applying the coefficients to the basis
        # if self.s_t == "lti":
        #     output = self.encoder.reconstruct(c=hidden, evals=x)
        # else:
        #     output = self.encoder.reconstruct(c=hidden)

        return hidden #, output

In [12]:
import jax
import jax.numpy as jnp
import flax.linen as nn


class A(nn.Module):
    @nn.compact
    def __call__(self, x):
        return x

    def a_func(self, x):
        return x + 1


class B(nn.Module):
    @nn.compact
    def __call__(self, x):
        return x

    def a_func(self, x):
        return x + 1


class C(nn.Module):

    flag: bool = True

    def setup(self):
        if self.flag:
            self.encoder = A()
        else:
            self.encoder = B()

    def __call__(self, x):
        x = self.encoder(x)
        return x

In [13]:
x = jnp.ones((5, 5))
c = C()
variables = c.init(jax.random.PRNGKey(0), x)
model = c.bind(variables)
z = model.encoder.a_func(x)
print(z)

[[2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2.]]


## Navigation To Table Of Contents
---
* [Table Of Contents](#table-of-contents)
* [Loading In Necessary Packages](#load-packages)
* [Instantiate The HiPPO Matrix](#instantiate-the-hippo-matrix)
* [Gu's Linear Time Invariant (LTI) HiPPO Operator](#gus-hippo-legt-operator)
* [Gu's Scale invariant (LSI) HiPPO Operator](#gus-scale-invariant-hippo-legs-operator)
* [Implementation Of General HiPPO Operator](#implementation-of-general-hippo-operator)
* [Test Generalized Bilinear Transform and Zero Order Hold Matrices](#test-generalized-bilinear-transform-and-zero-order-hold-matrices)
* [Testing HiPPO Operators](#test-hippo-operators)
---

In [14]:
def random_16_input(key_generator, batch_size=16, data_size=784, input_size=28):
    # x = jax.random.randint(key_generator, (batch_size, data_size), 0, 255)
    x = jax.random.uniform(key_generator, (batch_size, data_size))
    return np.asarray(jax.vmap(moving_window, in_axes=(0, None))(x, input_size))

## Test Generalized Bilinear Transform and Zero Order Hold Matrices

In [15]:
def test_LSI_GBT(hippo, gu_hippo, A, B, random_input, alpha=0.5):
    L = random_input.shape[1]
    for i in range(1, L + 1):
        GBT_A, GBT_B = hippo.encoder.discretize(
            A, B, step=i, alpha=alpha, dtype=jnp.float32
        )
        gu_GBT_A, gu_GBT_B = (
            jnp.asarray(gu_hippo.A_stacked[i - 1], dtype=jnp.float32),
            jnp.expand_dims(
                jnp.asarray(gu_hippo.B_stacked[i - 1], dtype=jnp.float32), axis=1
            ),
        )

        print(f"GBT_A: {jnp.allclose(GBT_A, gu_GBT_A, rtol=1e-04, atol=1e-04)}")
        print(f"GBT_B: {jnp.allclose(GBT_B, gu_GBT_B, rtol=1e-04, atol=1e-04)}\n")

In [16]:
def test_LTI_GBT(hippo, gu_hippo, A, B, random_input, alpha=0.5, print_all=False):
    L = random_input.shape[1]
    GBT_A, GBT_B = hippo.encoder.discretize(
        A, B, step=1.0, alpha=alpha, dtype=jnp.float32
    )
    gu_GBT_A, gu_GBT_B = (
        jnp.asarray(gu_hippo.dA, dtype=jnp.float32),
        jnp.expand_dims(jnp.asarray(gu_hippo.dB, dtype=jnp.float32), axis=1),
    )
    if print_all:
        print(f"gu_GBT_A shape:{gu_GBT_A.shape}\n")
        print(f"GBT_A shape: {GBT_A.shape}\n")
        print(f"gu_GBT_B shape: {gu_GBT_B.shape}\n")
        print(f"GBT_B shape: {GBT_B.shape}")

        print(f"gu_GBT_A:\n{gu_GBT_A}\n")
        print(f"GBT_A:\n{GBT_A}\n")
        print(f"gu_GBT_B:\n{gu_GBT_B}\n")
        print(f"GBT_B:\n{GBT_B}")

    print(f"GBT_A: {jnp.allclose(GBT_A, gu_GBT_A, rtol=1e-04, atol=1e-04)}")
    print(f"GBT_B: {jnp.allclose(GBT_B, gu_GBT_B, rtol=1e-04, atol=1e-04)}\n")

In [17]:
def test_GBT(
    the_measure="legs", lambda_n=1.0, alpha=0.5, discretization=0.5, print_all=False
):
    batch_size = 16
    data_size = 256
    input_size = 1

    N = 50
    L = data_size

    x_np = random_16_input(
        key_generator=subkeys[1],
        batch_size=batch_size,
        data_size=data_size,
        input_size=input_size,
    )
    print(x_np.shape)

    print(f"Creating Gu's HiPPO-{the_measure} LTI model with {alpha} transform")
    gu_hippo_lti = gu_HiPPO_LTI(
        N=N,
        method=the_measure,
        dt=1.0,
        T=1.0,
        discretization=discretization,
        lambda_n=lambda_n,
        alpha=0.0,
        beta=1.0,
        c=0.0,
    )  # The Gu's

    if the_measure == "legs":
        print(f"Creating Gu's HiPPO-{the_measure} LSI model with {alpha} transform")
        gu_hippo_lsi = gu_HiPPO_LSI(
            N=N,
            method="legs",
            max_length=L,
            discretization=discretization,
            lambda_n=lambda_n,
            alpha=0.0,
            beta=1.0,
        )  # The Gu's

    matrices = TransMatrix(
        N=N,
        measure=the_measure,
        lambda_n=lambda_n,
        alpha=0.0,
        beta=1.0,
        dtype=jnp.float32,
    )

    A = matrices.A
    B = matrices.B

    print(f"Creating HiPPO-{the_measure} LTI model with {alpha} transform")
    hippo_lti = HiPPO(
        N=N,
        max_length=L,
        step_size=1.0,
        basis_size=1.0,
        lambda_n=lambda_n,
        alpha=0.0,
        beta=1.0,
        GBT_alpha=alpha,
        measure=the_measure,
        s_t="lti",
        truncate_measure=True,
        dtype=jnp.float32,
        verbose=True,
    )
    lti_variables = hippo_lti.init(subkeys[2], x_np)
    hippo_lti_model = hippo_lti.bind(lti_variables)

    if the_measure == "legs":
        print(f"Creating HiPPO-{the_measure} LSI model with {alpha} transform")
        hippo_lsi = HiPPO(
            N=N,
            max_length=L,
            step_size=1.0,
            basis_size=1.0,
            lambda_n=lambda_n,
            alpha=0.0,
            beta=1.0,
            GBT_alpha=alpha,
            measure=the_measure,
            s_t="lsi",
            truncate_measure=True,
            dtype=jnp.float32,
            verbose=True,
        )
        lsi_variables = hippo_lsi.init(subkeys[3], x_np)
        hippo_lsi_model = hippo_lsi.bind(lsi_variables)

    print(f"Testing for correct LTI GBT matrices for HiPPO-{the_measure}")
    test_LTI_GBT(
        hippo=hippo_lti_model,
        gu_hippo=gu_hippo_lti,
        A=A,
        B=B,
        random_input=x_np,
        alpha=alpha,
        print_all=print_all,
    )
    if the_measure == "legs":
        print(f"Testing for correct LSI GBT matrices for HiPPO-{the_measure}")
        test_LSI_GBT(
            hippo=hippo_lsi_model,
            gu_hippo=gu_hippo_lsi,
            A=A,
            B=B,
            random_input=x_np,
            alpha=alpha,
        )

## Navigation To Table Of Contents
---
* [Table Of Contents](#table-of-contents)
* [Loading In Necessary Packages](#load-packages)
* [Instantiate The HiPPO Matrix](#instantiate-the-hippo-matrix)
* [Gu's Linear Time Invariant (LTI) HiPPO Operator](#gus-hippo-legt-operator)
* [Gu's Scale invariant (LSI) HiPPO Operator](#gus-scale-invariant-hippo-legs-operator)
* [Implementation Of General HiPPO Operator](#implementation-of-general-hippo-operator)
* [Test Generalized Bilinear Transform and Zero Order Hold Matrices](#test-generalized-bilinear-transform-and-zero-order-hold-matrices)
* [Testing HiPPO Operators](#test-hippo-operators)
---

In [18]:
print_all = False

### Testing Forward Euler Transform for LTI and LSI

#### LegS

In [19]:
test_GBT(
    the_measure="legs", lambda_n=1.0, alpha=0.0, discretization=0.0, print_all=print_all
)

(16, 256, 1)
Creating Gu's HiPPO-legs LTI model with 0.0 transform
Creating Gu's HiPPO-legs LSI model with 0.0 transform
Creating HiPPO-legs LTI model with 0.0 transform
In The HiPPO Setup
Using LTI encoder
self.vals shape:
(1,)
self.vals:
[0.]
vals shape:
(1,)
vals:
[0.]
x shape:
(1,)
x:
[-1.]
x shape:
(1,)
x:
[-1.]
eval_matrix shape:
(1,)
eval_matrix:
[0.00634819]
eval_matrix shape:
(50,)
eval_matrix:
[ 0.00634819 -0.01099539  0.01419499 -0.01679574  0.01904458 -0.02105457  0.02288873 -0.02458644  0.02617427 -0.02767113  0.02909108 -0.03044486
  0.03174096 -0.03298618  0.03418607 -0.03534524  0.03646759 -0.03755641  0.03861455 -0.03964445  0.04064827 -0.04162788  0.04258497 -0.04352102
  0.04443735 -0.04533516  0.04621554 -0.04707946  0.0479278  -0.04876139  0.04958097 -0.05038722  0.05118077 -0.05196219  0.05273205 -0.05349082
  0.05423898 -0.05497696  0.05570517 -0.05642397  0.05713373 -0.05783479  0.05852744 -0.059212    0.05988873 -0.0605579   0.06121975 -0.06187452
  0.06252245 

#### LegT

In [20]:
test_GBT(
    the_measure="legt", lambda_n=1.0, alpha=0.0, discretization=0.0, print_all=print_all
)

(16, 256, 1)
Creating Gu's HiPPO-legt LTI model with 0.0 transform
Creating HiPPO-legt LTI model with 0.0 transform
In The HiPPO Setup
Using LTI encoder
self.vals shape:
(1,)
self.vals:
[0.]
eval_matrix shape:
(50,)
eval_matrix:
[ 0.00634819 -0.01099539  0.01419499 -0.01679574  0.01904458 -0.02105457  0.02288873 -0.02458644  0.02617427 -0.02767113  0.02909108 -0.03044486
  0.03174096 -0.03298618  0.03418607 -0.03534524  0.03646759 -0.03755641  0.03861455 -0.03964445  0.04064827 -0.04162788  0.04258497 -0.04352102
  0.04443735 -0.04533516  0.04621554 -0.04707946  0.0479278  -0.04876139  0.04958097 -0.05038722  0.05118077 -0.05196219  0.05273205 -0.05349082
  0.05423898 -0.05497696  0.05570517 -0.05642397  0.05713373 -0.05783479  0.05852744 -0.059212    0.05988873 -0.0605579   0.06121975 -0.06187452
  0.06252245 -0.06316372]
val shape:
(1,)
val:
[0.]
Testing for correct LTI GBT matrices for HiPPO-legt
In The HiPPO Setup
Using LTI encoder
self.vals shape:
(1,)
self.vals:
[0.]
eval_matrix 

#### LMU

In [21]:
test_GBT(
    the_measure="lmu", lambda_n=2.0, alpha=0.0, discretization=0.0, print_all=print_all
)

(16, 256, 1)
Creating Gu's HiPPO-lmu LTI model with 0.0 transform
Creating HiPPO-lmu LTI model with 0.0 transform
In The HiPPO Setup
Using LTI encoder
self.vals shape:
(1,)
self.vals:
[0.]
eval_matrix shape:
(50,)
eval_matrix:
[ 0.00634819 -0.01099539  0.01419499 -0.01679574  0.01904458 -0.02105457  0.02288873 -0.02458644  0.02617427 -0.02767113  0.02909108 -0.03044486
  0.03174096 -0.03298618  0.03418607 -0.03534524  0.03646759 -0.03755641  0.03861455 -0.03964445  0.04064827 -0.04162788  0.04258497 -0.04352102
  0.04443735 -0.04533516  0.04621554 -0.04707946  0.0479278  -0.04876139  0.04958097 -0.05038722  0.05118077 -0.05196219  0.05273205 -0.05349082
  0.05423898 -0.05497696  0.05570517 -0.05642397  0.05713373 -0.05783479  0.05852744 -0.059212    0.05988873 -0.0605579   0.06121975 -0.06187452
  0.06252245 -0.06316372]
val shape:
(1,)
val:
[0.]
Testing for correct LTI GBT matrices for HiPPO-lmu
In The HiPPO Setup
Using LTI encoder
self.vals shape:
(1,)
self.vals:
[0.]
eval_matrix sha

#### LagT

In [22]:
test_GBT(
    the_measure="lagt", lambda_n=1.0, alpha=0.0, discretization=0.0, print_all=print_all
)

(16, 256, 1)
Creating Gu's HiPPO-lagt LTI model with 0.0 transform
Creating HiPPO-lagt LTI model with 0.0 transform
In The HiPPO Setup
Using LTI encoder
self.vals shape:
(1,)
self.vals:
[0.]


TypeError: scan carry output and input must have same type structure, got PyTreeDef(*) and PyTreeDef((*, *)).

#### FRU

In [None]:
test_GBT(
    the_measure="fru", lambda_n=1.0, alpha=0.0, discretization=0.0, print_all=print_all
)

(16, 256, 1)
Creating Gu's HiPPO-fru LTI model with 0.0 transform
Creating HiPPO-fru LTI model with 0.0 transform
Testing for correct LTI GBT matrices for HiPPO-fru
GBT_A: True
GBT_B: True



#### FouT

In [None]:
test_GBT(
    the_measure="fout", lambda_n=1.0, alpha=0.0, discretization=0.0, print_all=print_all
)

(16, 256, 1)
Creating Gu's HiPPO-fout LTI model with 0.0 transform
Creating HiPPO-fout LTI model with 0.0 transform
Testing for correct LTI GBT matrices for HiPPO-fout
GBT_A: True
GBT_B: True



#### FouD

In [None]:
test_GBT(
    the_measure="foud", lambda_n=1.0, alpha=0.0, discretization=0.0, print_all=print_all
)

(16, 256, 1)
Creating Gu's HiPPO-foud LTI model with 0.0 transform
Creating HiPPO-foud LTI model with 0.0 transform
Testing for correct LTI GBT matrices for HiPPO-foud
GBT_A: True
GBT_B: True



### Testing Backward Euler Transform for LTI and LSI on LegS Matrices

#### LegS

In [None]:
test_GBT(
    the_measure="legs", lambda_n=1.0, alpha=1.0, discretization=1.0, print_all=print_all
)

(16, 256, 1)
Creating Gu's HiPPO-legs LTI model with 1.0 transform
Creating Gu's HiPPO-legs LSI model with 1.0 transform
Creating HiPPO-legs LTI model with 1.0 transform
Creating HiPPO-legs LSI model with 1.0 transform
Testing for correct LTI GBT matrices for HiPPO-legs
GBT_A: True
GBT_B: True

Testing for correct LSI GBT matrices for HiPPO-legs
GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GB

#### LegT

In [None]:
test_GBT(
    the_measure="legt", lambda_n=1.0, alpha=1.0, discretization=1.0, print_all=print_all
)

(16, 256, 1)
Creating Gu's HiPPO-legt LTI model with 1.0 transform
Creating HiPPO-legt LTI model with 1.0 transform
Testing for correct LTI GBT matrices for HiPPO-legt
GBT_A: True
GBT_B: True



#### LMU

In [None]:
test_GBT(
    the_measure="lmu", lambda_n=2.0, alpha=1.0, discretization=1.0, print_all=print_all
)

(16, 256, 1)
Creating Gu's HiPPO-lmu LTI model with 1.0 transform
Creating HiPPO-lmu LTI model with 1.0 transform
Testing for correct LTI GBT matrices for HiPPO-lmu
GBT_A: True
GBT_B: True



#### LagT

In [None]:
test_GBT(
    the_measure="lagt", lambda_n=1.0, alpha=1.0, discretization=1.0, print_all=print_all
)

(16, 256, 1)
Creating Gu's HiPPO-lagt LTI model with 1.0 transform
Creating HiPPO-lagt LTI model with 1.0 transform
Testing for correct LTI GBT matrices for HiPPO-lagt
GBT_A: True
GBT_B: True



#### FRU

In [None]:
test_GBT(
    the_measure="fru", lambda_n=1.0, alpha=1.0, discretization=1.0, print_all=print_all
)

(16, 256, 1)
Creating Gu's HiPPO-fru LTI model with 1.0 transform
Creating HiPPO-fru LTI model with 1.0 transform
Testing for correct LTI GBT matrices for HiPPO-fru
GBT_A: True
GBT_B: True



#### FouT

In [None]:
test_GBT(
    the_measure="fout", lambda_n=1.0, alpha=1.0, discretization=1.0, print_all=print_all
)

(16, 256, 1)
Creating Gu's HiPPO-fout LTI model with 1.0 transform
Creating HiPPO-fout LTI model with 1.0 transform
Testing for correct LTI GBT matrices for HiPPO-fout
GBT_A: True
GBT_B: True



#### FouD

In [None]:
test_GBT(
    the_measure="foud", lambda_n=1.0, alpha=1.0, discretization=1.0, print_all=print_all
)

(16, 256, 1)
Creating Gu's HiPPO-foud LTI model with 1.0 transform
Creating HiPPO-foud LTI model with 1.0 transform
Testing for correct LTI GBT matrices for HiPPO-foud
GBT_A: True
GBT_B: True



### Testing Bidirectional Transform for LTI and LSI on LegS Matrices

#### LegS

In [None]:
test_GBT(
    the_measure="legs", lambda_n=1.0, alpha=0.5, discretization=0.5, print_all=print_all
)

(16, 256, 1)
Creating Gu's HiPPO-legs LTI model with 0.5 transform
Creating Gu's HiPPO-legs LSI model with 0.5 transform
Creating HiPPO-legs LTI model with 0.5 transform
Creating HiPPO-legs LSI model with 0.5 transform
Testing for correct LTI GBT matrices for HiPPO-legs
GBT_A: True
GBT_B: True

Testing for correct LSI GBT matrices for HiPPO-legs
GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GB

#### LegT

In [None]:
test_GBT(
    the_measure="legt", lambda_n=1.0, alpha=0.5, discretization=0.5, print_all=print_all
)

(16, 256, 1)
Creating Gu's HiPPO-legt LTI model with 0.5 transform
Creating HiPPO-legt LTI model with 0.5 transform
Testing for correct LTI GBT matrices for HiPPO-legt
GBT_A: True
GBT_B: True



#### LMU

In [None]:
test_GBT(
    the_measure="lmu", lambda_n=2.0, alpha=0.5, discretization=0.5, print_all=print_all
)

(16, 256, 1)
Creating Gu's HiPPO-lmu LTI model with 0.5 transform
Creating HiPPO-lmu LTI model with 0.5 transform
Testing for correct LTI GBT matrices for HiPPO-lmu
GBT_A: True
GBT_B: True



#### LagT

In [None]:
test_GBT(
    the_measure="lagt", lambda_n=1.0, alpha=0.5, discretization=0.5, print_all=print_all
)

(16, 256, 1)
Creating Gu's HiPPO-lagt LTI model with 0.5 transform
Creating HiPPO-lagt LTI model with 0.5 transform
Testing for correct LTI GBT matrices for HiPPO-lagt
GBT_A: True
GBT_B: True



#### FRU

In [None]:
test_GBT(
    the_measure="fru", lambda_n=1.0, alpha=0.5, discretization=0.5, print_all=print_all
)

(16, 256, 1)
Creating Gu's HiPPO-fru LTI model with 0.5 transform
Creating HiPPO-fru LTI model with 0.5 transform
Testing for correct LTI GBT matrices for HiPPO-fru
GBT_A: True
GBT_B: True



#### FouT

In [None]:
test_GBT(
    the_measure="fout", lambda_n=1.0, alpha=0.5, discretization=0.5, print_all=print_all
)

(16, 256, 1)
Creating Gu's HiPPO-fout LTI model with 0.5 transform
Creating HiPPO-fout LTI model with 0.5 transform
Testing for correct LTI GBT matrices for HiPPO-fout
GBT_A: True
GBT_B: True



#### FouD

In [None]:
test_GBT(
    the_measure="foud", lambda_n=1.0, alpha=0.5, discretization=0.5, print_all=print_all
)

(16, 256, 1)
Creating Gu's HiPPO-foud LTI model with 0.5 transform
Creating HiPPO-foud LTI model with 0.5 transform
Testing for correct LTI GBT matrices for HiPPO-foud
GBT_A: True
GBT_B: True



### Testing ZOH Transform for LTI and LSI on LegS Matrices

#### LegS

In [None]:
test_GBT(
    the_measure="legs",
    lambda_n=1.0,
    alpha=2.0,
    discretization="zoh",
    print_all=print_all,
)

(16, 256, 1)
Creating Gu's HiPPO-legs LTI model with 2.0 transform
Creating Gu's HiPPO-legs LSI model with 2.0 transform
Creating HiPPO-legs LTI model with 2.0 transform
Creating HiPPO-legs LSI model with 2.0 transform
Testing for correct LTI GBT matrices for HiPPO-legs
GBT_A: True
GBT_B: True

Testing for correct LSI GBT matrices for HiPPO-legs
GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GBT_A: True
GBT_B: True

GB

#### LegT

In [None]:
test_GBT(
    the_measure="legt",
    lambda_n=1.0,
    alpha=2.0,
    discretization="zoh",
    print_all=print_all,
)

(16, 256, 1)
Creating Gu's HiPPO-legt LTI model with 2.0 transform
Creating HiPPO-legt LTI model with 2.0 transform
Testing for correct LTI GBT matrices for HiPPO-legt
GBT_A: True
GBT_B: True



#### LMU

In [None]:
test_GBT(
    the_measure="lmu",
    lambda_n=2.0,
    alpha=2.0,
    discretization="zoh",
    print_all=print_all,
)

(16, 256, 1)
Creating Gu's HiPPO-lmu LTI model with 2.0 transform
Creating HiPPO-lmu LTI model with 2.0 transform
Testing for correct LTI GBT matrices for HiPPO-lmu
GBT_A: True
GBT_B: True



#### LagT

In [None]:
test_GBT(
    the_measure="lagt",
    lambda_n=1.0,
    alpha=2.0,
    discretization="zoh",
    print_all=print_all,
)

(16, 256, 1)
Creating Gu's HiPPO-lagt LTI model with 2.0 transform
Creating HiPPO-lagt LTI model with 2.0 transform
Testing for correct LTI GBT matrices for HiPPO-lagt
GBT_A: True
GBT_B: True



#### FRU

In [None]:
test_GBT(
    the_measure="fru",
    lambda_n=1.0,
    alpha=2.0,
    discretization="zoh",
    print_all=print_all,
)

(16, 256, 1)
Creating Gu's HiPPO-fru LTI model with 2.0 transform
Creating HiPPO-fru LTI model with 2.0 transform
Testing for correct LTI GBT matrices for HiPPO-fru
GBT_A: True
GBT_B: True



#### FouT

In [None]:
test_GBT(
    the_measure="fout",
    lambda_n=1.0,
    alpha=2.0,
    discretization="zoh",
    print_all=print_all,
)

(16, 256, 1)
Creating Gu's HiPPO-fout LTI model with 2.0 transform
Creating HiPPO-fout LTI model with 2.0 transform
Testing for correct LTI GBT matrices for HiPPO-fout
GBT_A: True
GBT_B: True



#### FouD

In [None]:
test_GBT(
    the_measure="foud",
    lambda_n=1.0,
    alpha=2.0,
    discretization="zoh",
    print_all=print_all,
)

(16, 256, 1)
Creating Gu's HiPPO-foud LTI model with 2.0 transform
Creating HiPPO-foud LTI model with 2.0 transform
Testing for correct LTI GBT matrices for HiPPO-foud
GBT_A: True
GBT_B: True



## Test HiPPO Operators

In [None]:
def test_hippo_operator(
    hippo_legs, gu_hippo_legs, random_input, key, s_or_t="lti", print_all=False
):
    x_tensor = torch.tensor(random_input, dtype=torch.float32)
    x_jnp = jnp.asarray(x_tensor, dtype=jnp.float32)  # convert torch array to jax array

    # My Implementation
    if print_all:
        print(
            f"------------------------------------------------------------------------------------------"
        )
        print(
            f"----------------------------My {s_or_t} Implementation Outputs----------------------------"
        )
        print(
            f"------------------------------------------------------------------------------------------"
        )
    params = hippo_legs.init(key, f=x_jnp)
    c_k = hippo_legs.apply(params, f=x_jnp)
    if s_or_t == "lsi":
        c_k = jnp.moveaxis(c_k, 0, 1)

    # Gu's HiPPO LegS
    if print_all:
        print(
            f"------------------------------------------------------------------------------------------"
        )
        print(
            f"---------------------------Gu's {s_or_t} Implementation Outputs---------------------------"
        )
        print(
            f"------------------------------------------------------------------------------------------"
        )
    x_tensor = torch.moveaxis(x_tensor, 0, 1)
    GU_c_k = gu_hippo_legs(x_tensor, fast=False)
    gu_c = jnp.asarray(GU_c_k, dtype=jnp.float32)  # convert torch array to jax array
    gu_c = jnp.moveaxis(gu_c, 0, 1)

    if print_all:
        print(
            f"------------------------------------------------------------------------------"
        )
        print(
            f"---------------------------Testing {s_or_t} Outputs---------------------------"
        )
        print(
            f"------------------------------------------------------------------------------"
        )
        jax.debug.print(f"inputted jnp-data shape: {x_jnp.shape}")
        jax.debug.print(f"inputted tensor-data shape: {x_tensor.shape}")
        print(f"c_k shape: {c_k.shape}")
        print(f"gu_c shape: {gu_c.shape}")

    flag = True
    for i in range(c_k.shape[0]):
        for j in range(c_k.shape[1]):
            if print_all:
                print(f"c_k[{i},{j},:,:]:\n{c_k[i,j,:,:]}")
                print(f"gu_c[{i},{j},:,:]:\n{gu_c[i,j,:,:]}")

            check = jnp.allclose(
                c_k[i, j, :, :], gu_c[i, j, :, :], rtol=1e-04, atol=1e-04
            )
            if check == False:
                flag = False
    if not print_all:
        print(f"The Test Passed: {flag}")

In [None]:
def test_operators(
    the_measure="legs", lambda_n=1.0, alpha=0.5, discretization=0.5, print_all=False
):
    # N = 256
    # L = 128

    batch_size = 16
    data_size = 512
    input_size = 1

    N = 50
    L = data_size

    x_jnp = random_16_input(
        key_generator=subkey[4],
        batch_size=batch_size,
        data_size=data_size,
        input_size=input_size,
    )
    x_np = np.asarray(x_jnp)

    x = torch.tensor(x_np, dtype=torch.float32)

    # ----------------------------------------------------------------------------------
    # ------------------------------ Instantiate Gu's HiPPOs ---------------------------
    # ----------------------------------------------------------------------------------

    print(f"Creating Gu's HiPPO-{the_measure} LTI model with {alpha} transform")
    gu_hippo_lti = gu_HiPPO_LTI(
        N=N,
        method=the_measure,
        dt=1.0,
        T=1.0,
        discretization=discretization,
        lambda_n=lambda_n,
        alpha=0.0,
        beta=1.0,
        c=0.0,
    )  # The Gu's

    if the_measure == "legs":
        print(f"Creating Gu's HiPPO-{the_measure} LSI model with {alpha} transform")
        gu_hippo_lsi = HiPPO_LSI(
            N=N,
            method=the_measure,
            max_length=L,
            discretization=discretization,
            lambda_n=lambda_n,
            alpha=0.0,
            beta=1.0,
        )  # The Gu's

    # ----------------------------------------------------------------------------------
    # ------------------------------ Instantiate My HiPPOs -----------------------------
    # ----------------------------------------------------------------------------------
    print(f"\nTesting BRYANS HiPPO-{the_measure} model")

    matrices = TransMatrix(
        N=N,
        measure=the_measure,
        lambda_n=lambda_n,
        alpha=0.0,
        beta=1.0,
        dtype=jnp.float32,
    )

    A = matrices.A
    B = matrices.B

    print(f"Creating HiPPO-{the_measure} LTI model with {alpha} transform")
    hippo_lti = HiPPO(
        max_length=L,
        step_size=1.0,
        N=N,
        lambda_n=lambda_n,
        alpha=0.0,
        beta=1.0,
        GBT_alpha=alpha,
        measure=the_measure,
        s_t="lti",
        dtype=jnp.float32,
        verbose=True,
    )  # Bryan's

    if the_measure == "legs":
        print(f"Creating HiPPO-{the_measure} LSI model with {alpha} transform")
        hippo_lsi = HiPPO(
            max_length=L,
            step_size=1.0,
            N=N,
            lambda_n=lambda_n,
            alpha=0.0,
            beta=1.0,
            GBT_alpha=alpha,
            measure=the_measure,
            s_t="lsi",
            dtype=jnp.float32,
            verbose=True,
        )  # Bryan's

    # ----------------------------------------------------------------------------------
    # ------------------------------ Test HiPPO Operators ------------------------------
    # ----------------------------------------------------------------------------------

    print(f"Bryan's Coeffiecients for {alpha} LTI HiPPO-{the_measure}")

    test_hippo_operator(
        hippo_legs=hippo_lti,
        gu_hippo_legs=gu_hippo_lti,
        random_input=x_np,
        key=subkey[5],
        s_or_t="lti",
        print_all=print_all,
    )

    if the_measure == "legs":
        print(f"\n\nBryan's Coeffiecients for {alpha} LSI HiPPO-{the_measure}")

        test_hippo_operator(
            hippo_legs=hippo_lsi,
            gu_hippo_legs=gu_hippo_lsi,
            random_input=x_np,
            key=subkey[6],
            s_or_t="lsi",
            print_all=print_all,
        )

    print(f"end of test for HiPPO-{the_measure} model")

## Navigation To Table Of Contents
---
* [Table Of Contents](#table-of-contents)
* [Loading In Necessary Packages](#load-packages)
* [Instantiate The HiPPO Matrix](#instantiate-the-hippo-matrix)
* [Gu's Linear Time Invariant (LTI) HiPPO Operator](#gus-hippo-legt-operator)
* [Gu's Scale invariant (LSI) HiPPO Operator](#gus-scale-invariant-hippo-legs-operator)
* [Implementation Of General HiPPO Operator](#implementation-of-general-hippo-operator)
* [Test Generalized Bilinear Transform and Zero Order Hold Matrices](#test-generalized-bilinear-transform-and-zero-order-hold-matrices)
* [Testing HiPPO Operators](#test-hippo-operators)
---

In [None]:
print_all = False

### Testing (LTI and LSI) Operators With Forward Euler Transform

#### LegS

In [None]:
test_operators(
    the_measure="legs", lambda_n=1.0, alpha=0.0, discretization=0.0, print_all=print_all
)

Creating Gu's HiPPO-legs LTI model with 0.0 transform
Creating Gu's HiPPO-legs LSI model with 0.0 transform

Testing BRYANS HiPPO-legs model
Creating HiPPO-legs LTI model with 0.0 transform
Creating HiPPO-legs LSI model with 0.0 transform
Bryan's Coeffiecients for 0.0 LTI HiPPO-legs
The Test Passed: False


Bryan's Coeffiecients for 0.0 LSI HiPPO-legs
The Test Passed: False
end of test for HiPPO-legs model


#### LegT

In [None]:
test_operators(
    the_measure="legt", lambda_n=1.0, alpha=0.0, discretization=0.0, print_all=print_all
)

Creating Gu's HiPPO-legt LTI model with 0.0 transform

Testing BRYANS HiPPO-legt model
Creating HiPPO-legt LTI model with 0.0 transform
Bryan's Coeffiecients for 0.0 LTI HiPPO-legt
The Test Passed: False
end of test for HiPPO-legt model


#### LMU

In [None]:
test_operators(
    the_measure="lmu", lambda_n=2.0, alpha=0.0, discretization=0.0, print_all=print_all
)

Creating Gu's HiPPO-lmu LTI model with 0.0 transform

Testing BRYANS HiPPO-lmu model
Creating HiPPO-lmu LTI model with 0.0 transform
Bryan's Coeffiecients for 0.0 LTI HiPPO-lmu
The Test Passed: False
end of test for HiPPO-lmu model


#### LagT

In [None]:
test_operators(
    the_measure="lagt", lambda_n=1.0, alpha=0.0, discretization=0.0, print_all=print_all
)

Creating Gu's HiPPO-lagt LTI model with 0.0 transform

Testing BRYANS HiPPO-lagt model
Creating HiPPO-lagt LTI model with 0.0 transform
Bryan's Coeffiecients for 0.0 LTI HiPPO-lagt
The Test Passed: False
end of test for HiPPO-lagt model


#### FRU

In [None]:
test_operators(
    the_measure="fru", lambda_n=1.0, alpha=0.0, discretization=0.0, print_all=print_all
)

Creating Gu's HiPPO-fru LTI model with 0.0 transform

Testing BRYANS HiPPO-fru model
Creating HiPPO-fru LTI model with 0.0 transform
Bryan's Coeffiecients for 0.0 LTI HiPPO-fru
The Test Passed: False
end of test for HiPPO-fru model


#### FouT

In [None]:
test_operators(
    the_measure="fout", lambda_n=1.0, alpha=0.0, discretization=0.0, print_all=print_all
)

Creating Gu's HiPPO-fout LTI model with 0.0 transform

Testing BRYANS HiPPO-fout model
Creating HiPPO-fout LTI model with 0.0 transform
Bryan's Coeffiecients for 0.0 LTI HiPPO-fout
The Test Passed: False
end of test for HiPPO-fout model


#### FouD

In [None]:
test_operators(
    the_measure="foud", lambda_n=1.0, alpha=0.0, discretization=0.0, print_all=print_all
)

Creating Gu's HiPPO-foud LTI model with 0.0 transform

Testing BRYANS HiPPO-foud model
Creating HiPPO-foud LTI model with 0.0 transform
Bryan's Coeffiecients for 0.0 LTI HiPPO-foud
The Test Passed: False
end of test for HiPPO-foud model


### Testing (LTI and LSI) Operators With Backward Euler Transform

#### LegS

In [None]:
test_operators(
    the_measure="legs", lambda_n=1.0, alpha=1.0, discretization=1.0, print_all=print_all
)

Creating Gu's HiPPO-legs LTI model with 1.0 transform
Creating Gu's HiPPO-legs LSI model with 1.0 transform

Testing BRYANS HiPPO-legs model
Creating HiPPO-legs LTI model with 1.0 transform
Creating HiPPO-legs LSI model with 1.0 transform
Bryan's Coeffiecients for 1.0 LTI HiPPO-legs
The Test Passed: True


Bryan's Coeffiecients for 1.0 LSI HiPPO-legs
The Test Passed: True
end of test for HiPPO-legs model


#### LegT

In [None]:
test_operators(
    the_measure="legt", lambda_n=1.0, alpha=1.0, discretization=1.0, print_all=print_all
)

Creating Gu's HiPPO-legt LTI model with 1.0 transform

Testing BRYANS HiPPO-legt model
Creating HiPPO-legt LTI model with 1.0 transform
Bryan's Coeffiecients for 1.0 LTI HiPPO-legt
The Test Passed: True
end of test for HiPPO-legt model


#### LMU

In [None]:
test_operators(
    the_measure="lmu", lambda_n=2.0, alpha=1.0, discretization=1.0, print_all=print_all
)

Creating Gu's HiPPO-lmu LTI model with 1.0 transform

Testing BRYANS HiPPO-lmu model
Creating HiPPO-lmu LTI model with 1.0 transform
Bryan's Coeffiecients for 1.0 LTI HiPPO-lmu
The Test Passed: True
end of test for HiPPO-lmu model


#### LagT

In [None]:
test_operators(
    the_measure="lagt", lambda_n=1.0, alpha=1.0, discretization=1.0, print_all=print_all
)

Creating Gu's HiPPO-lagt LTI model with 1.0 transform

Testing BRYANS HiPPO-lagt model
Creating HiPPO-lagt LTI model with 1.0 transform
Bryan's Coeffiecients for 1.0 LTI HiPPO-lagt
The Test Passed: True
end of test for HiPPO-lagt model


#### FRU

In [None]:
test_operators(
    the_measure="fru", lambda_n=1.0, alpha=1.0, discretization=1.0, print_all=print_all
)

Creating Gu's HiPPO-fru LTI model with 1.0 transform

Testing BRYANS HiPPO-fru model
Creating HiPPO-fru LTI model with 1.0 transform
Bryan's Coeffiecients for 1.0 LTI HiPPO-fru
The Test Passed: False
end of test for HiPPO-fru model


#### FouT

In [None]:
test_operators(
    the_measure="fout", lambda_n=1.0, alpha=1.0, discretization=1.0, print_all=print_all
)

Creating Gu's HiPPO-fout LTI model with 1.0 transform

Testing BRYANS HiPPO-fout model
Creating HiPPO-fout LTI model with 1.0 transform
Bryan's Coeffiecients for 1.0 LTI HiPPO-fout
The Test Passed: True
end of test for HiPPO-fout model


#### FouD

In [None]:
test_operators(
    the_measure="foud", lambda_n=1.0, alpha=1.0, discretization=1.0, print_all=print_all
)

Creating Gu's HiPPO-foud LTI model with 1.0 transform

Testing BRYANS HiPPO-foud model
Creating HiPPO-foud LTI model with 1.0 transform
Bryan's Coeffiecients for 1.0 LTI HiPPO-foud
The Test Passed: False
end of test for HiPPO-foud model


### Testing (LTI and LSI) Operators With Bidirectional Transform

#### LegS

In [None]:
test_operators(
    the_measure="legs", lambda_n=1.0, alpha=0.5, discretization=0.5, print_all=print_all
)

Creating Gu's HiPPO-legs LTI model with 0.5 transform
Creating Gu's HiPPO-legs LSI model with 0.5 transform

Testing BRYANS HiPPO-legs model
Creating HiPPO-legs LTI model with 0.5 transform
Creating HiPPO-legs LSI model with 0.5 transform
Bryan's Coeffiecients for 0.5 LTI HiPPO-legs
The Test Passed: True


Bryan's Coeffiecients for 0.5 LSI HiPPO-legs
The Test Passed: True
end of test for HiPPO-legs model


#### LegT

In [None]:
test_operators(
    the_measure="legt", lambda_n=1.0, alpha=0.5, discretization=0.5, print_all=print_all
)

Creating Gu's HiPPO-legt LTI model with 0.5 transform

Testing BRYANS HiPPO-legt model
Creating HiPPO-legt LTI model with 0.5 transform
Bryan's Coeffiecients for 0.5 LTI HiPPO-legt
The Test Passed: True
end of test for HiPPO-legt model


#### LMU

In [None]:
test_operators(
    the_measure="lmu", lambda_n=2.0, alpha=0.5, discretization=0.5, print_all=print_all
)

Creating Gu's HiPPO-lmu LTI model with 0.5 transform

Testing BRYANS HiPPO-lmu model
Creating HiPPO-lmu LTI model with 0.5 transform
Bryan's Coeffiecients for 0.5 LTI HiPPO-lmu
The Test Passed: False
end of test for HiPPO-lmu model


#### LagT

In [None]:
test_operators(
    the_measure="lagt", lambda_n=1.0, alpha=0.5, discretization=0.5, print_all=print_all
)

Creating Gu's HiPPO-lagt LTI model with 0.5 transform

Testing BRYANS HiPPO-lagt model
Creating HiPPO-lagt LTI model with 0.5 transform
Bryan's Coeffiecients for 0.5 LTI HiPPO-lagt
The Test Passed: True
end of test for HiPPO-lagt model


#### FRU

In [None]:
test_operators(
    the_measure="fru", lambda_n=1.0, alpha=0.5, discretization=0.5, print_all=print_all
)

Creating Gu's HiPPO-fru LTI model with 0.5 transform

Testing BRYANS HiPPO-fru model
Creating HiPPO-fru LTI model with 0.5 transform
Bryan's Coeffiecients for 0.5 LTI HiPPO-fru
The Test Passed: True
end of test for HiPPO-fru model


#### FouT

In [None]:
test_operators(
    the_measure="fout", lambda_n=1.0, alpha=0.5, discretization=0.5, print_all=print_all
)

Creating Gu's HiPPO-fout LTI model with 0.5 transform

Testing BRYANS HiPPO-fout model
Creating HiPPO-fout LTI model with 0.5 transform
Bryan's Coeffiecients for 0.5 LTI HiPPO-fout
The Test Passed: False
end of test for HiPPO-fout model


#### FouD

In [None]:
test_operators(
    the_measure="foud", lambda_n=1.0, alpha=0.5, discretization=0.5, print_all=print_all
)

Creating Gu's HiPPO-foud LTI model with 0.5 transform

Testing BRYANS HiPPO-foud model
Creating HiPPO-foud LTI model with 0.5 transform
Bryan's Coeffiecients for 0.5 LTI HiPPO-foud
The Test Passed: True
end of test for HiPPO-foud model


### Testing (LTI and LSI) Operators With ZOH Transform

#### LegS

In [None]:
test_operators(
    the_measure="legs",
    lambda_n=1.0,
    alpha=2.0,
    discretization="zoh",
    print_all=print_all,
)

Creating Gu's HiPPO-legs LTI model with 2.0 transform
Creating Gu's HiPPO-legs LSI model with 2.0 transform

Testing BRYANS HiPPO-legs model
Creating HiPPO-legs LTI model with 2.0 transform
Creating HiPPO-legs LSI model with 2.0 transform
Bryan's Coeffiecients for 2.0 LTI HiPPO-legs
The Test Passed: True


Bryan's Coeffiecients for 2.0 LSI HiPPO-legs
The Test Passed: True
end of test for HiPPO-legs model


#### LegT

In [None]:
test_operators(
    the_measure="legt",
    lambda_n=1.0,
    alpha=2.0,
    discretization="zoh",
    print_all=print_all,
)

Creating Gu's HiPPO-legt LTI model with 2.0 transform

Testing BRYANS HiPPO-legt model
Creating HiPPO-legt LTI model with 2.0 transform
Bryan's Coeffiecients for 2.0 LTI HiPPO-legt
The Test Passed: True
end of test for HiPPO-legt model


#### LMU

In [None]:
test_operators(
    the_measure="lmu",
    lambda_n=2.0,
    alpha=2.0,
    discretization="zoh",
    print_all=print_all,
)

Creating Gu's HiPPO-lmu LTI model with 2.0 transform

Testing BRYANS HiPPO-lmu model
Creating HiPPO-lmu LTI model with 2.0 transform
Bryan's Coeffiecients for 2.0 LTI HiPPO-lmu
The Test Passed: True
end of test for HiPPO-lmu model


#### LagT

In [None]:
test_operators(
    the_measure="lagt",
    lambda_n=1.0,
    alpha=2.0,
    discretization="zoh",
    print_all=print_all,
)

Creating Gu's HiPPO-lagt LTI model with 2.0 transform

Testing BRYANS HiPPO-lagt model
Creating HiPPO-lagt LTI model with 2.0 transform
Bryan's Coeffiecients for 2.0 LTI HiPPO-lagt
The Test Passed: True
end of test for HiPPO-lagt model


#### FRU

In [None]:
test_operators(
    the_measure="fru",
    lambda_n=1.0,
    alpha=2.0,
    discretization="zoh",
    print_all=print_all,
)

Creating Gu's HiPPO-fru LTI model with 2.0 transform

Testing BRYANS HiPPO-fru model
Creating HiPPO-fru LTI model with 2.0 transform
Bryan's Coeffiecients for 2.0 LTI HiPPO-fru
The Test Passed: True
end of test for HiPPO-fru model


#### FouT

In [None]:
test_operators(
    the_measure="fout",
    lambda_n=1.0,
    alpha=2.0,
    discretization="zoh",
    print_all=print_all,
)

Creating Gu's HiPPO-fout LTI model with 2.0 transform

Testing BRYANS HiPPO-fout model
Creating HiPPO-fout LTI model with 2.0 transform
Bryan's Coeffiecients for 2.0 LTI HiPPO-fout
The Test Passed: True
end of test for HiPPO-fout model


#### FouD

In [None]:
test_operators(
    the_measure="foud",
    lambda_n=1.0,
    alpha=2.0,
    discretization="zoh",
    print_all=print_all,
)

Creating Gu's HiPPO-foud LTI model with 2.0 transform

Testing BRYANS HiPPO-foud model
Creating HiPPO-foud LTI model with 2.0 transform
Bryan's Coeffiecients for 2.0 LTI HiPPO-foud
The Test Passed: True
end of test for HiPPO-foud model
