# HiPPO Matrices
---

## Table of Contents
* [Loading In Necessary Packages](#load-packages)
* [Instantiate The HiPPO Matrix](#instantiate-the-hippo-matrix)
    * [Translated Legendre (LegT)](#translated-legendre-legt)
        * [LegT](#legt)
        * [LMU](#lmu)
    * [Translated Laguerre (LagT)](#translated-laguerre-lagt)
    * [Scaled Legendre (LegS)](#scaled-legendre-legs)
    * [Fourier Basis](#fourier-basis)
        * [Fourier Recurrent Unit (FRU)](#fourier-recurrent-unit-fru)
        * [Truncated Fourier (FouT)](#truncated-fourier-fout)
        * [Fourier With Decay (FouD)](#fourier-with-decay-fourd)
* [Gu's Linear Time Invariant (LTI) HiPPO Operator](#gus-hippo-legt-operator)
* [Gu's Linear Scale Invariant (LSI) HiPPO Operator](#gus-scale-invariant-hippo-legs-operator)
* [Implementation Of General HiPPO Operator](#implementation-of-general-hippo-operator)
* [Test Generalized Bilinear Transform and Zero Order Hold Matrices](#test-generalized-bilinear-transform-and-zero-order-hold-matrices)
    * [Testing Forward Euler on GBT matrices](#testing-forward-euler-transform-for-lti-and-lsi)
    * [Testing Backward Euler on GBT matrices](#testing-backward-euler-transform-for-lti-and-lsi-on-legs-matrices)
    * [Testing Bidirectional on GBT matrices](#testing-lti-and-lsi-operators-with-bidirectional-transform)
    * [Testing ZOH on GBT matrices](#testing-zoh-transform-for-lti-and-lsi-on-legs-matrices)
* [Testing HiPPO Operators](#test-hippo-operators)
    * [Testing Forward Euler on HiPPO Operators](#testing-lti-and-lsi-operators-with-forward-euler-transform)
    * [Testing Backward Euler on HiPPO Operators](#testing-lti-and-lsi-operators-with-backward-euler-transform)
    * [Testing Bidirectional on HiPPO Operators](#testing-lti-and-lsi-operators-with-bidirectional-transform)
    * [Testing ZOH on HiPPO Operators](#testing-lti-and-lsi-operators-with-zoh-transform)
---


## Load Packages

In [1]:
import os
import sys
import warnings

module_path = os.path.abspath(os.path.join("../../../"))
print(f"module_path: {module_path}")
if module_path not in sys.path:
    print(f"Adding {module_path} to sys.path")
    sys.path.append(module_path)

module_path: /home/beegass/Documents/Coding/HiPPO-Jax


In [2]:
warnings.filterwarnings("ignore")

In [3]:
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False"
os.environ["TF_FORCE_UNIFIED_MEMORY"] = "1"

In [4]:
## import packages
import jax
import jax.numpy as jnp
from src.models.hippo.hr_hippo import HRHiPPO_LSI, HRHiPPO_LTI
from src.models.hippo.hippo import HiPPOLTI, HiPPOLSI
from src.models.hippo.cells import HiPPOLSICell, HiPPOLTICell, HiPPO
from src.models.hippo.transition import (
    legs,
    legs_initializer,
    legt,
    legt_initializer,
    lmu,
    lmu_initializer,
    lagt,
    lagt_initializer,
    fru,
    fru_initializer,
    fout,
    fout_initializer,
    foud,
    foud_initializer,
    chebt,
    chebt_initializer,
)
from src.data.process import whitesignal
import einops
import numpy as np
import torch
import time

In [5]:
print(jax.devices())
print(f"The Device: {jax.default_backend()}")

[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]
The Device: gpu


In [6]:
print(f"MPS enabled: {torch.backends.mps.is_available()}")

MPS enabled: False


In [7]:
torch.set_printoptions(linewidth=150)
np.set_printoptions(linewidth=150)
jnp.set_printoptions(linewidth=150)

In [8]:
seed = 1701
key = jax.random.PRNGKey(seed)

In [9]:
num_copies = 10
subkeys = jax.random.split(key, num=num_copies)
key = subkeys[0]

In [10]:
# Parameters For Generating Data
T = 1
freq = 1
step = 1e-3
L = int(T / step)

# Parameters For Training
batch_size = 2
data_size = L
input_size = 1

# Parameters For HiPPO
N = 64

## Test HiPPO Operators

In [11]:
def test_hippo_operator(
    key, hippo, hr_hippo, random_input, s_or_t="lti", print_all=False
):
    x_tensor = torch.tensor(random_input, dtype=torch.float32)
    x_jnp = jnp.asarray(x_tensor, dtype=jnp.float32)  # convert torch array to jax array

    # My Implementation
    if print_all:
        print(
            f"------------------------------------------------------------------------------------------"
        )
        print(
            f"----------------------------My {s_or_t} Implementation Outputs----------------------------"
        )
        print(
            f"------------------------------------------------------------------------------------------"
        )

    x_jnp = einops.rearrange(x_jnp, "batch seq_len -> batch seq_len 1")

    # params = hippo.init(key, f=x_jnp)
    # Initial hidden state; hidden_size matches the global HiPPO order N (64)
    c_t_1 = hippo.initialize_state(
        subkeys[7], batch_size=random_input.shape[0], hidden_size=N
    )
    params = hippo.init(key, f=x_jnp, c_t_1=c_t_1)

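    start = time.time()
    c, y = hippo.apply(params, f=x_jnp, c_t_1=c_t_1)
    # JAX dispatches asynchronously; wait for the results so the timing is meaningful
    jax.block_until_ready((c, y))
    end = time.time()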

    duration = end - start
    print(f"Jax Duration: {duration}")

    # Gu's reference HiPPO implementation
    if print_all:
        print(
            f"------------------------------------------------------------------------------------------"
        )
        print(
            f"---------------------------Gu's {s_or_t} Implementation Outputs---------------------------"
        )
        print(
            f"------------------------------------------------------------------------------------------"
        )
    x_tensor = torch.moveaxis(x_tensor, 0, 1)

    start_t = time.time()
    hr_cs, hr_c_k = hr_hippo(x_tensor, fast=False)
    end_t = time.time()

    duration_t = end_t - start_t
    print(f"PyTorch Duration: {duration_t}")

    hr_c = None
    if s_or_t == "lti":
        hr_c = einops.rearrange(hr_c_k, "batch N -> batch 1 N")
    elif s_or_t == "lsi":
        hr_c = einops.rearrange(hr_cs, "seq_len batch N -> batch seq_len 1 N")
    hr_c = jnp.asarray(hr_c, dtype=jnp.float32)  # convert torch array to jax array

    if print_all:
        print(
            f"------------------------------------------------------------------------------"
        )
        print(
            f"---------------------------Testing {s_or_t} Outputs---------------------------"
        )
        print(
            f"------------------------------------------------------------------------------"
        )

    co_flag = True

    def check_close(a, b):
        if print_all:
            jax.debug.print("c:\n{x}", x=a)
            jax.debug.print("hr_c:\n{x}", x=b)
        return jnp.allclose(a, b, rtol=1e-03, atol=1e-03)

    bool_arr = None
    if s_or_t == "lsi":
        batch_check_close = jax.vmap(check_close, in_axes=(0, 0))
        time_check_close = jax.vmap(batch_check_close, in_axes=(0, 0))
        bool_arr = time_check_close(c, hr_c)
    elif s_or_t == "lti":
        batch_check_close = jax.vmap(check_close, in_axes=(0, 0))
        bool_arr = batch_check_close(c, hr_c)

    if not bool_arr.all():
        co_flag = False

    if not print_all:
        print(f"\n------------------------------------------------------------")
        print(f"---------- The Coefficients Test Passed: {co_flag} ----------")
        print(f"-------------------------------------------------------------\n\n")

    if print_all:
        print(f"bool_arr:\n{bool_arr}\n\n")
        print(f"c:\n {c}")
        print(f"c shape:\n {c.shape}\n\n")

        print(f"hr_c:\n {hr_c}")
        print(f"hr_c shape:\n {hr_c.shape}\n\n")
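`test_hippo_operator` reduces the comparison to one boolean per batch element by vmapping `jnp.allclose`. For LSI it adds a second vmap over time steps. The standalone sketch below applies the same idea to synthetic arrays, so the tolerance behaviour can be checked without either HiPPO model. `per_example_close` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp

def per_example_close(a, b, rtol=1e-3, atol=1e-3):
    """Return one boolean per leading (batch) index, mirroring the vmapped check above."""
    close = lambda x, y: jnp.allclose(x, y, rtol=rtol, atol=atol)
    return jax.vmap(close, in_axes=(0, 0))(a, b)

key = jax.random.PRNGKey(0)
c = jax.random.normal(key, (2, 1, 64))   # (batch, 1, N): the LTI layout used above
hr_c = c.at[1, 0, 0].add(1.0)           # perturb only the second batch element
flags = per_example_close(c, hr_c)
print(flags)                            # one flag per batch element
print(bool(flags.all()))                # the overall pass/fail, like co_flag
```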

In [12]:
def test_lsi_operators(
    the_measure="legs", alpha=0.5, discretization=0.5, print_all=False
):
    assert (
        the_measure == "legs"
    ), f"the_measure must be 'legs' for this test, got {the_measure}"

    u = whitesignal(subkeys[4], T, step, freq, batch_shape=(batch_size,))

    x_np = np.asarray(u)

    # ----------------------------------------------------------------------------------
    # ------------------------------ Instantiate Hazy Research's HiPPOs ---------------------------
    # ----------------------------------------------------------------------------------
    print(f"Creating Gu's HiPPO-{the_measure} LSI model with {alpha} transform")
    hr_hippo_lsi = HRHiPPO_LSI(
        N=N,
        method="legs",
        max_length=L,
        discretization=discretization,
        lambda_n=1.0,
        alpha=0.0,
        beta=1.0,
    )

    # ----------------------------------------------------------------------------------
    # ------------------------------ Instantiate My HiPPOs -----------------------------
    # ----------------------------------------------------------------------------------
    print(f"Creating HiPPO-{the_measure} LSI model with {alpha} transform")

    hippo_lsi_cell = HiPPOLSICell
    h_args = {
        "max_length": L,
        "alpha": alpha,
        "init_t": 0,
        "recon": True,
        "A_init_fn": legs,
        "B_init_fn": legs,
    }
    hippo_lsi = HiPPO(
        features=N,
        hippo_cell=hippo_lsi_cell,
        hippo_args=h_args,
        init_t=0,
        unroll=True,
        st=True,
    )

    # ----------------------------------------------------------------------------------
    # ------------------------------ Test HiPPO Operators ------------------------------
    # ----------------------------------------------------------------------------------
    print(f"\n\nTesting Coefficients for {alpha} LSI HiPPO-{the_measure}")

    test_hippo_operator(
        key=subkeys[6],
        hippo=hippo_lsi,
        hr_hippo=hr_hippo_lsi,
        random_input=x_np,
        s_or_t="lsi",
        print_all=print_all,
    )

    print(f"end of test for HiPPO-{the_measure} model")
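The LSI comparison keeps the full coefficient trajectory, which is why `hr_cs` is rearranged with a `seq_len` axis. In the scale-invariant LegS update, the effective step size shrinks with time. The HiPPO paper writes the forward-Euler version as `c_{k+1} = (1 - A/k) c_k + (1/k) B f_k`, with `A` positive in that convention. The sketch below assumes that every step uses the GBT with step size `1/k`. `lsi_coefficients` is hypothetical, and the repo's `HiPPOLSICell` may differ in indexing or initial time.

```python
import jax
import jax.numpy as jnp

def legs_matrices(N):
    # Closed-form HiPPO-LegS matrices (paper convention, minus sign folded into A)
    q = jnp.sqrt(2.0 * jnp.arange(N) + 1.0)
    A = -jnp.tril(jnp.outer(q, q), k=-1) - jnp.diag(jnp.arange(N) + 1.0)
    return A, q[:, None]

def lsi_coefficients(f, N, alpha=0.5):
    """Sketch of a scale-invariant LegS recurrence; f has shape (L,), result (L, N).

    Assumption: step k uses the GBT with a time-varying step size 1/k.
    """
    A, B = legs_matrices(N)
    I = jnp.eye(N, dtype=f.dtype)

    def step(c, inputs):
        k, f_k = inputs
        dt = 1.0 / k                                  # step shrinks as the history grows
        lhs = I - alpha * dt * A
        Ad = jnp.linalg.solve(lhs, I + (1.0 - alpha) * dt * A)
        Bd = jnp.linalg.solve(lhs, dt * B)
        c = Ad @ c + Bd[:, 0] * f_k
        return c, c                                   # carry the state and also emit it

    ks = jnp.arange(1, f.shape[0] + 1).astype(f.dtype)
    _, cs = jax.lax.scan(step, jnp.zeros(N, dtype=f.dtype), (ks, f))
    return cs

cs = lsi_coefficients(jnp.ones(100), N=4, alpha=0.0)
print(cs[-1])
```

For a constant input, the LegS fixed point is `e_0`, because `A e_0 = -B`. With `alpha=0.0`, the final row should therefore be close to `[1, 0, 0, 0]`.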

In [13]:
def test_lti_operators(
    the_measure="legs", lambda_n=1.0, alpha=0.5, discretization=0.5, print_all=False
):
    assert the_measure in ["legs", "legt", "lmu", "lagt", "fru", "fout", "foud"]
    init_fn = {
        "legs": legs,
        "legt": legt,
        "lmu": lmu,
        "lagt": lagt,
        "fru": fru,
        "fout": fout,
        "foud": foud,
        "chebt": chebt,
    }[the_measure]

    u = whitesignal(subkeys[4], T, step, freq, batch_shape=(batch_size,))

    x_np = np.asarray(u)

    # ----------------------------------------------------------------------------------
    # ------------------------------ Instantiate Gu's HiPPOs ---------------------------
    # ----------------------------------------------------------------------------------
    print(f"Creating Gu's HiPPO-{the_measure} LTI model with {alpha} transform")
    hr_hippo_lti = HRHiPPO_LTI(
        N=N,
        method=the_measure,
        dt=step,
        T=T,
        discretization=discretization,
        lambda_n=lambda_n,
        alpha=0.0,
        beta=1.0,
        c=0.0,
    )  # Gu's reference implementation

    # ----------------------------------------------------------------------------------
    # ------------------------------ Instantiate My HiPPOs -----------------------------
    # ----------------------------------------------------------------------------------
    print(f"Creating HiPPO-{the_measure} LTI model with {alpha} transform")
    hippo_lti_cell = HiPPOLTICell
    h_args = {
        "step_size": step,
        "basis_size": T,
        "alpha": alpha,
        "recon": True,
        "A_init_fn": init_fn,
        "B_init_fn": init_fn,
    }
    hippo_lti = HiPPO(
        features=N,
        hippo_cell=hippo_lti_cell,
        hippo_args=h_args,
        init_t=0,
        unroll=False,
        st=False,
    )

    # ----------------------------------------------------------------------------------
    # ------------------------------ Test HiPPO Operators ------------------------------
    # ----------------------------------------------------------------------------------
    print(f"Testing Coefficients for {alpha} LTI HiPPO-{the_measure}")

    test_hippo_operator(
        hippo=hippo_lti,
        hr_hippo=hr_hippo_lti,
        random_input=x_np,
        key=subkeys[5],
        s_or_t="lti",
        print_all=print_all,
    )

    print(f"end of test for HiPPO-{the_measure} model")
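Conceptually, an LTI HiPPO operator discretizes `(A, B)` once and then runs the same linear recurrence `c_k = Ad c_{k-1} + Bd f_k` over the whole signal. The sketch below uses the GBT parameterized by `alpha` (0 is forward Euler, 1 is backward Euler, 0.5 is bilinear) together with `jax.lax.scan`. `gbt` and `lti_coefficients` are hypothetical helpers, not the `HiPPOLTICell` API. A diagonal `A` is used because a unit step input then has the closed-form answer `(1 - e^{-λt}) / λ`.

```python
import jax
import jax.numpy as jnp

def gbt(A, B, dt, alpha):
    """Generalized bilinear transform of dc/dt = A c + B f."""
    I = jnp.eye(A.shape[0], dtype=A.dtype)
    lhs = I - alpha * dt * A
    Ad = jnp.linalg.solve(lhs, I + (1.0 - alpha) * dt * A)
    Bd = jnp.linalg.solve(lhs, dt * B)
    return Ad, Bd

def lti_coefficients(A, B, f, dt, alpha=0.5):
    """Run c_k = Ad c_{k-1} + Bd f_k over a 1-D signal f; returns shape (L, N)."""
    Ad, Bd = gbt(A, B, dt, alpha)   # computed once: the operator is time invariant

    def step(c, f_k):
        c = Ad @ c + Bd[:, 0] * f_k
        return c, c

    _, cs = jax.lax.scan(step, jnp.zeros(A.shape[0], dtype=A.dtype), f)
    return cs

lam = jnp.array([1.0, 2.0, 3.0])
A = -jnp.diag(lam)
B = jnp.ones((3, 1))
dt = 1e-3
f = jnp.ones(1000)                  # L = T / dt steps with T = 1
cs = lti_coefficients(A, B, f, dt)
exact = (1.0 - jnp.exp(-lam * 1.0)) / lam
print(jnp.max(jnp.abs(cs[-1] - exact)))
```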

## Navigation To Table Of Contents
---
* [Table Of Contents](#table-of-contents)
* [Loading In Necessary Packages](#load-packages)
* [Instantiate The HiPPO Matrix](#instantiate-the-hippo-matrix)
* [Gu's Linear Time Invariant (LTI) HiPPO Operator](#gus-hippo-legt-operator)
* [Gu's Linear Scale Invariant (LSI) HiPPO Operator](#gus-scale-invariant-hippo-legs-operator)
* [Implementation Of General HiPPO Operator](#implementation-of-general-hippo-operator)
* [Test Generalized Bilinear Transform and Zero Order Hold Matrices](#test-generalized-bilinear-transform-and-zero-order-hold-matrices)
* [Testing HiPPO Operators](#test-hippo-operators)
---

In [14]:
print_all = False

### Testing (LTI and LSI) Operators With Forward Euler Transform
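In these cells the JAX model receives `alpha=0.0` and Gu's model receives `discretization=0.0`, which is the forward-Euler case of the GBT: `Ad = I + dt A` and `Bd = dt B`. The update is explicit and cheap, but it is stable only when every eigenvalue of `Ad` has magnitude below 1. The LegS `A` is lower triangular with diagonal `-(n+1)`, so those eigenvalues are `1 - (n+1) dt`. Stability therefore requires `dt < 2/N`, which is 0.03125 for `N = 64`. A quick sketch of that check, where `forward_euler_eigs` is a hypothetical helper:

```python
import jax.numpy as jnp

def forward_euler_eigs(N, dt):
    """Eigenvalues of I + dt*A for LegS A (lower triangular, diagonal -(n+1))."""
    return 1.0 - dt * jnp.arange(1, N + 1, dtype=jnp.float32)

for dt in (1e-3, 5e-2):
    rho = jnp.max(jnp.abs(forward_euler_eigs(64, dt)))
    print(f"dt={dt}: spectral radius {float(rho):.3f}")
```

The notebook's LTI step (`step = 1e-3`) sits well inside the stable region. The coarser step does not.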

#### LegS

In [15]:
test_lsi_operators(
    the_measure="legs", alpha=0.0, discretization=0.0, print_all=print_all
)

Creating Gu's HiPPO-legs LSI model with 0.0 transform
Creating HiPPO-legs LSI model with 0.0 transform


Testing Coeffiecients for 0.0 LSI HiPPO-legs
Jax Duration: 0.3518373966217041
PyTorch Duration: 0.1345679759979248

------------------------------------------------------------
---------- The Coefficients Test Passed: False ----------
-------------------------------------------------------------


end of test for HiPPO-legs model


In [16]:
test_lti_operators(
    the_measure="legs", lambda_n=1.0, alpha=0.0, discretization=0.0, print_all=print_all
)

Creating Gu's HiPPO-legs LTI model with 0.0 transform
Creating HiPPO-legs LTI model with 0.0 transform
Testing Coeffiecients for 0.0 LTI HiPPO-legs
Jax Duration: 0.30309629440307617
PyTorch Duration: 0.12659096717834473

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-legs model


#### LegT

In [17]:
test_lti_operators(
    the_measure="legt", lambda_n=1.0, alpha=0.0, discretization=0.0, print_all=print_all
)

Creating Gu's HiPPO-legt LTI model with 0.0 transform
Creating HiPPO-legt LTI model with 0.0 transform
Testing Coeffiecients for 0.0 LTI HiPPO-legt
Jax Duration: 0.21370410919189453
PyTorch Duration: 0.12743449211120605

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-legt model


#### LMU

In [18]:
test_lti_operators(
    the_measure="lmu", lambda_n=2.0, alpha=0.0, discretization=0.0, print_all=print_all
)

Creating Gu's HiPPO-lmu LTI model with 0.0 transform
Creating HiPPO-lmu LTI model with 0.0 transform
Testing Coeffiecients for 0.0 LTI HiPPO-lmu
Jax Duration: 0.20223116874694824
PyTorch Duration: 0.12044692039489746

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-lmu model


#### LagT

In [19]:
test_lti_operators(
    the_measure="lagt", lambda_n=1.0, alpha=0.0, discretization=0.0, print_all=print_all
)

Creating Gu's HiPPO-lagt LTI model with 0.0 transform
Creating HiPPO-lagt LTI model with 0.0 transform
Testing Coeffiecients for 0.0 LTI HiPPO-lagt
Jax Duration: 0.20954418182373047
PyTorch Duration: 0.12122273445129395

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-lagt model


#### FRU

In [20]:
test_lti_operators(
    the_measure="fru", lambda_n=1.0, alpha=0.0, discretization=0.0, print_all=print_all
)

Creating Gu's HiPPO-fru LTI model with 0.0 transform
Creating HiPPO-fru LTI model with 0.0 transform
Testing Coeffiecients for 0.0 LTI HiPPO-fru
Jax Duration: 0.23551368713378906
PyTorch Duration: 0.12724924087524414

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-fru model


#### FouT

In [21]:
test_lti_operators(
    the_measure="fout", lambda_n=1.0, alpha=0.0, discretization=0.0, print_all=print_all
)

Creating Gu's HiPPO-fout LTI model with 0.0 transform
Creating HiPPO-fout LTI model with 0.0 transform
Testing Coeffiecients for 0.0 LTI HiPPO-fout
Jax Duration: 0.26561474800109863
PyTorch Duration: 0.13262104988098145

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-fout model


#### FouD

In [22]:
test_lti_operators(
    the_measure="foud", lambda_n=1.0, alpha=0.0, discretization=0.0, print_all=print_all
)

Creating Gu's HiPPO-foud LTI model with 0.0 transform
Creating HiPPO-foud LTI model with 0.0 transform
Testing Coeffiecients for 0.0 LTI HiPPO-foud
Jax Duration: 0.23000240325927734
PyTorch Duration: 0.12862062454223633

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-foud model


### Testing (LTI and LSI) Operators With Backward Euler Transform
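With `alpha=1.0` the GBT becomes backward Euler: `Ad = (I - dt A)^{-1}` and `Bd = (I - dt A)^{-1} dt B`. For LegS the eigenvalues are `1 / (1 + (n+1) dt)`, which lie in (0, 1) for any `dt > 0`, so the step is unconditionally stable. Each step is an implicit linear solve. Because the LegS `A` is lower triangular, a triangular solve suffices. A minimal sketch under those assumptions, where `backward_euler_step` is a hypothetical helper:

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

def legs_matrices(N):
    # Closed-form HiPPO-LegS matrices (paper convention, minus sign folded into A)
    q = jnp.sqrt(2.0 * jnp.arange(N) + 1.0)
    A = -jnp.tril(jnp.outer(q, q), k=-1) - jnp.diag(jnp.arange(N) + 1.0)
    return A, q[:, None]

def backward_euler_step(A, B, c, f_k, dt):
    """One implicit step: solve (I - dt*A) c_next = c + dt*B*f_k for c_next."""
    lhs = jnp.eye(A.shape[0], dtype=A.dtype) - dt * A   # lower triangular when A is
    rhs = c + dt * B[:, 0] * f_k
    return solve_triangular(lhs, rhs, lower=True)

N, dt = 64, 5e-2            # a step at which forward Euler is unstable for N = 64
A, B = legs_matrices(N)
c = backward_euler_step(A, B, jnp.zeros(N), 1.0, dt)
residual = (jnp.eye(N) - dt * A) @ c - dt * B[:, 0]
print(float(jnp.max(jnp.abs(residual))))
```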

#### LegS

In [23]:
test_lsi_operators(
    the_measure="legs", alpha=1.0, discretization=1.0, print_all=print_all
)

Creating Gu's HiPPO-legs LSI model with 1.0 transform
Creating HiPPO-legs LSI model with 1.0 transform


Testing Coeffiecients for 1.0 LSI HiPPO-legs
Jax Duration: 0.3319535255432129
PyTorch Duration: 0.131638765335083

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-legs model


In [24]:
test_lti_operators(
    the_measure="legs", lambda_n=1.0, alpha=1.0, discretization=1.0, print_all=print_all
)

Creating Gu's HiPPO-legs LTI model with 1.0 transform
Creating HiPPO-legs LTI model with 1.0 transform
Testing Coeffiecients for 1.0 LTI HiPPO-legs
Jax Duration: 0.20882225036621094
PyTorch Duration: 0.12144041061401367

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-legs model


#### LegT

In [25]:
test_lti_operators(
    the_measure="legt", lambda_n=1.0, alpha=1.0, discretization=1.0, print_all=print_all
)

Creating Gu's HiPPO-legt LTI model with 1.0 transform
Creating HiPPO-legt LTI model with 1.0 transform
Testing Coeffiecients for 1.0 LTI HiPPO-legt
Jax Duration: 0.2201390266418457
PyTorch Duration: 0.12486433982849121

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-legt model


#### LMU

In [26]:
test_lti_operators(
    the_measure="lmu", lambda_n=2.0, alpha=1.0, discretization=1.0, print_all=print_all
)

Creating Gu's HiPPO-lmu LTI model with 1.0 transform
Creating HiPPO-lmu LTI model with 1.0 transform
Testing Coeffiecients for 1.0 LTI HiPPO-lmu
Jax Duration: 0.2026219367980957
PyTorch Duration: 0.12890625

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-lmu model


#### LagT

In [27]:
test_lti_operators(
    the_measure="lagt", lambda_n=1.0, alpha=1.0, discretization=1.0, print_all=print_all
)

Creating Gu's HiPPO-lagt LTI model with 1.0 transform
Creating HiPPO-lagt LTI model with 1.0 transform
Testing Coeffiecients for 1.0 LTI HiPPO-lagt
Jax Duration: 0.21983551979064941
PyTorch Duration: 0.12653017044067383

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-lagt model


#### FRU

In [28]:
test_lti_operators(
    the_measure="fru", lambda_n=1.0, alpha=1.0, discretization=1.0, print_all=print_all
)

Creating Gu's HiPPO-fru LTI model with 1.0 transform
Creating HiPPO-fru LTI model with 1.0 transform
Testing Coeffiecients for 1.0 LTI HiPPO-fru
Jax Duration: 0.21619820594787598
PyTorch Duration: 0.12327194213867188

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-fru model


#### FouT

In [29]:
test_lti_operators(
    the_measure="fout", lambda_n=1.0, alpha=1.0, discretization=1.0, print_all=print_all
)

Creating Gu's HiPPO-fout LTI model with 1.0 transform
Creating HiPPO-fout LTI model with 1.0 transform
Testing Coeffiecients for 1.0 LTI HiPPO-fout
Jax Duration: 0.28851747512817383
PyTorch Duration: 0.12918448448181152

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-fout model


#### FouD

In [30]:
test_lti_operators(
    the_measure="foud", lambda_n=1.0, alpha=1.0, discretization=1.0, print_all=print_all
)

Creating Gu's HiPPO-foud LTI model with 1.0 transform
Creating HiPPO-foud LTI model with 1.0 transform
Testing Coeffiecients for 1.0 LTI HiPPO-foud
Jax Duration: 0.25328493118286133
PyTorch Duration: 0.12967562675476074

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-foud model


### Testing (LTI and LSI) Operators With Bidirectional Transform
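The "bidirectional" cells pass `alpha=0.5` to the JAX model and `discretization=0.5` to Gu's model, which is the bilinear (Tustin) case of the GBT. Each continuous eigenvalue `λ` maps to `(1 + dtλ/2) / (1 - dtλ/2)`. That value lies inside the unit circle whenever `Re λ < 0`, for any `dt > 0`. A small check of this map on a diagonal `A`, where `bilinear` is a hypothetical helper:

```python
import jax.numpy as jnp

def bilinear(A, B, dt):
    """GBT with alpha = 0.5 (bilinear / Tustin)."""
    I = jnp.eye(A.shape[0], dtype=A.dtype)
    lhs = I - 0.5 * dt * A
    Ad = jnp.linalg.solve(lhs, I + 0.5 * dt * A)
    Bd = jnp.linalg.solve(lhs, dt * B)
    return Ad, Bd

lam = jnp.array([-1.0, -10.0, -100.0])   # stable continuous-time eigenvalues
A = jnp.diag(lam)
B = jnp.ones((3, 1))
dt = 0.05
Ad, Bd = bilinear(A, B, dt)
predicted = (1.0 + 0.5 * dt * lam) / (1.0 - 0.5 * dt * lam)
print(jnp.diag(Ad))
print(predicted)
```

At this `dt`, forward Euler would map `λ = -100` to `1 - 5 = -4`, which is unstable. The bilinear map keeps it inside the unit circle.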

#### LegS

In [31]:
test_lsi_operators(
    the_measure="legs", alpha=0.5, discretization=0.5, print_all=print_all
)

Creating Gu's HiPPO-legs LSI model with 0.5 transform
Creating HiPPO-legs LSI model with 0.5 transform


Testing Coeffiecients for 0.5 LSI HiPPO-legs
Jax Duration: 0.3303253650665283
PyTorch Duration: 0.13050270080566406

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-legs model


In [32]:
test_lti_operators(
    the_measure="legs", lambda_n=1.0, alpha=0.5, discretization=0.5, print_all=print_all
)

Creating Gu's HiPPO-legs LTI model with 0.5 transform
Creating HiPPO-legs LTI model with 0.5 transform
Testing Coeffiecients for 0.5 LTI HiPPO-legs
Jax Duration: 0.22737574577331543
PyTorch Duration: 0.12870502471923828

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-legs model


#### LegT

In [33]:
test_lti_operators(
    the_measure="legt", lambda_n=1.0, alpha=0.5, discretization=0.5, print_all=print_all
)

Creating Gu's HiPPO-legt LTI model with 0.5 transform
Creating HiPPO-legt LTI model with 0.5 transform
Testing Coeffiecients for 0.5 LTI HiPPO-legt
Jax Duration: 0.2138075828552246
PyTorch Duration: 0.13062429428100586

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-legt model


#### LMU

In [34]:
test_lti_operators(
    the_measure="lmu", lambda_n=2.0, alpha=0.5, discretization=0.5, print_all=print_all
)

Creating Gu's HiPPO-lmu LTI model with 0.5 transform
Creating HiPPO-lmu LTI model with 0.5 transform
Testing Coeffiecients for 0.5 LTI HiPPO-lmu
Jax Duration: 0.203108549118042
PyTorch Duration: 0.12275242805480957

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-lmu model


#### LagT

In [35]:
test_lti_operators(
    the_measure="lagt", lambda_n=1.0, alpha=0.5, discretization=0.5, print_all=print_all
)

Creating Gu's HiPPO-lagt LTI model with 0.5 transform
Creating HiPPO-lagt LTI model with 0.5 transform
Testing Coeffiecients for 0.5 LTI HiPPO-lagt
Jax Duration: 0.2168130874633789
PyTorch Duration: 0.1261742115020752

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-lagt model


#### FRU

In [36]:
test_lti_operators(
    the_measure="fru", lambda_n=1.0, alpha=0.5, discretization=0.5, print_all=print_all
)

Creating Gu's HiPPO-fru LTI model with 0.5 transform
Creating HiPPO-fru LTI model with 0.5 transform
Testing Coeffiecients for 0.5 LTI HiPPO-fru
Jax Duration: 0.23087668418884277
PyTorch Duration: 0.12711548805236816

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-fru model


#### FouT

In [37]:
test_lti_operators(
    the_measure="fout", lambda_n=1.0, alpha=0.5, discretization=0.5, print_all=print_all
)

Creating Gu's HiPPO-fout LTI model with 0.5 transform
Creating HiPPO-fout LTI model with 0.5 transform
Testing Coeffiecients for 0.5 LTI HiPPO-fout
Jax Duration: 0.25879812240600586
PyTorch Duration: 0.1258854866027832

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-fout model


#### FouD

In [38]:
test_lti_operators(
    the_measure="foud", lambda_n=1.0, alpha=0.5, discretization=0.5, print_all=print_all
)

Creating Gu's HiPPO-foud LTI model with 0.5 transform
Creating HiPPO-foud LTI model with 0.5 transform
Testing Coeffiecients for 0.5 LTI HiPPO-foud
Jax Duration: 0.23973417282104492
PyTorch Duration: 0.12212300300598145

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-foud model


### Testing (LTI and LSI) Operators With ZOH Transform
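Here Gu's model receives `discretization="zoh"` while the JAX model receives `alpha=2.0`. This suggests that `alpha=2.0` selects zero-order hold in the repo's cells, although the notebook does not say so explicitly. ZOH assumes the input is constant over each step and gives `Ad = e^{dt A}` and `Bd = A^{-1}(e^{dt A} - I) B`, which requires an invertible `A`. The LegS `A` qualifies, since its diagonal is `-(n+1)`. A minimal sketch with a diagonal `A`, where `zoh` is a hypothetical helper. For a unit step input, ZOH reproduces the continuous solution exactly at the grid points.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm

def zoh(A, B, dt):
    """Zero-order hold: treat the input as constant over each step of length dt."""
    I = jnp.eye(A.shape[0], dtype=A.dtype)
    Ad = expm(dt * A)
    Bd = jnp.linalg.solve(A, (Ad - I) @ B)     # A^{-1} (e^{dt A} - I) B
    return Ad, Bd

lam = jnp.array([1.0, 2.0, 3.0])
A = -jnp.diag(lam)
B = jnp.ones((3, 1))
dt = 0.1
Ad, Bd = zoh(A, B, dt)

def step(c, f_k):
    c = Ad @ c + Bd[:, 0] * f_k
    return c, c

# Ten steps of a unit step input reach t = 1.0
_, cs = jax.lax.scan(step, jnp.zeros(3), jnp.ones(10))
print(cs[-1])
print((1.0 - jnp.exp(-lam * 1.0)) / lam)
```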

#### LegS

In [39]:
test_lsi_operators(
    the_measure="legs",
    alpha=2.0,
    discretization="zoh",
    print_all=print_all,
)

Creating Gu's HiPPO-legs LSI model with 2.0 transform
Creating HiPPO-legs LSI model with 2.0 transform


Testing Coeffiecients for 2.0 LSI HiPPO-legs
Jax Duration: 0.3464968204498291
PyTorch Duration: 0.13298988342285156

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-legs model


In [40]:
test_lti_operators(
    the_measure="legs",
    lambda_n=1.0,
    alpha=2.0,
    discretization="zoh",
    print_all=print_all,
)

Creating Gu's HiPPO-legs LTI model with 2.0 transform
Creating HiPPO-legs LTI model with 2.0 transform
Testing Coeffiecients for 2.0 LTI HiPPO-legs
Jax Duration: 0.22150254249572754
PyTorch Duration: 0.12716913223266602

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-legs model


#### LegT

In [41]:
test_lti_operators(
    the_measure="legt",
    lambda_n=1.0,
    alpha=2.0,
    discretization="zoh",
    print_all=print_all,
)

Creating Gu's HiPPO-legt LTI model with 2.0 transform
Creating HiPPO-legt LTI model with 2.0 transform
Testing Coeffiecients for 2.0 LTI HiPPO-legt
Jax Duration: 0.22566962242126465
PyTorch Duration: 0.12636852264404297

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-legt model


#### LMU

In [42]:
test_lti_operators(
    the_measure="lmu",
    lambda_n=2.0,
    alpha=2.0,
    discretization="zoh",
    print_all=print_all,
)

Creating Gu's HiPPO-lmu LTI model with 2.0 transform
Creating HiPPO-lmu LTI model with 2.0 transform
Testing Coeffiecients for 2.0 LTI HiPPO-lmu
Jax Duration: 0.2460160255432129
PyTorch Duration: 0.12760305404663086

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-lmu model


#### LagT

In [43]:
test_lti_operators(
    the_measure="lagt",
    lambda_n=1.0,
    alpha=2.0,
    discretization="zoh",
    print_all=print_all,
)

Creating Gu's HiPPO-lagt LTI model with 2.0 transform
Creating HiPPO-lagt LTI model with 2.0 transform
Testing Coeffiecients for 2.0 LTI HiPPO-lagt
Jax Duration: 0.23780131340026855
PyTorch Duration: 0.1227273941040039

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-lagt model


#### FRU

In [44]:
test_lti_operators(
    the_measure="fru",
    lambda_n=1.0,
    alpha=2.0,
    discretization="zoh",
    print_all=print_all,
)

Creating Gu's HiPPO-fru LTI model with 2.0 transform
Creating HiPPO-fru LTI model with 2.0 transform
Testing Coeffiecients for 2.0 LTI HiPPO-fru
Jax Duration: 0.24944567680358887
PyTorch Duration: 0.12150359153747559

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-fru model


#### FouT

In [45]:
test_lti_operators(
    the_measure="fout",
    lambda_n=1.0,
    alpha=2.0,
    discretization="zoh",
    print_all=print_all,
)

Creating Gu's HiPPO-fout LTI model with 2.0 transform
Creating HiPPO-fout LTI model with 2.0 transform
Testing Coeffiecients for 2.0 LTI HiPPO-fout
Jax Duration: 0.28164148330688477
PyTorch Duration: 0.1184089183807373

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-fout model


#### FouD

In [46]:
test_lti_operators(
    the_measure="foud",
    lambda_n=1.0,
    alpha=2.0,
    discretization="zoh",
    print_all=print_all,
)

Creating Gu's HiPPO-foud LTI model with 2.0 transform
Creating HiPPO-foud LTI model with 2.0 transform
Testing Coeffiecients for 2.0 LTI HiPPO-foud
Jax Duration: 0.2612276077270508
PyTorch Duration: 0.12821745872497559

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-foud model
