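# HiPPO Matrices

Checks that our JAX HiPPO cells discretize the HiPPO transition matrices $(A, B)$ to the same $(\bar{A}, \bar{B})$ as Gu's (Hazy Research) reference implementation. The checks use the forward Euler, backward Euler, bilinear and zero-order-hold transforms. They cover the scale-invariant (LSI) LegS operator and the time-invariant (LTI) LegS, LegT, LMU, LagT, FRU, FouT and FouD operators.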
---

## Load Packages

In [1]:
import os
import sys
import warnings

# Resolve the directory three levels above the current working directory (the
# repository root when run from this notebook's folder) and make it importable
# so the `src.*` imports below resolve.
module_path = os.path.abspath("../../../")
print(f"module_path: {module_path}")
needs_path = module_path not in sys.path
if needs_path:
    print(f"Adding {module_path} to sys.path")
    sys.path.append(module_path)

module_path: /home/beegass/Documents/Coding/HiPPO-Jax


In [2]:
warnings.filterwarnings("ignore")

In [3]:
# Opt XLA into CUDA unified memory. This is set before `import jax` below so the
# backend sees it when it starts.
# Alternative (disabled): os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False"
os.environ.update(TF_FORCE_UNIFIED_MEMORY="1")

In [4]:
## import packages
import jax
import jax.numpy as jnp
from src.models.hippo.hr_hippo import HRHiPPO_LSI, HRHiPPO_LTI
from src.models.hippo.hippo import HiPPOLTI, HiPPOLSI
from src.models.hippo.cells import HiPPOLSICell, HiPPOLTICell, HiPPO
from src.models.hippo.transition import (
    legs,
    legs_initializer,
    legt,
    legt_initializer,
    lmu,
    lmu_initializer,
    lagt,
    lagt_initializer,
    fru,
    fru_initializer,
    fout,
    fout_initializer,
    foud,
    foud_initializer,
    chebt,
    chebt_initializer,
)
from src.data.process import whitesignal
import einops
import numpy as np
import torch
import time

In [5]:
print(jax.devices())
# `jax.lib.xla_bridge` is deprecated; `jax.default_backend()` returns the same
# platform name ("cpu", "gpu", ...) through the public API.
print(f"The Device: {jax.default_backend()}")

[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]
The Device: gpu


In [6]:
# Report whether PyTorch's Apple-silicon (MPS) backend is usable on this machine.
mps_available = torch.backends.mps.is_available()
print(f"MPS enabled: {mps_available}")

MPS enabled: False


In [7]:
# Widen printed tensors/arrays to 150 columns in torch, numpy and jax.numpy.
for _set_printoptions in (torch.set_printoptions, np.set_printoptions, jnp.set_printoptions):
    _set_printoptions(linewidth=150)

In [8]:
# Fixed seed: every subkey below (and hence the generated white-noise input)
# is derived from this key, so reruns are reproducible.
seed = 1701
key = jax.random.PRNGKey(seed=seed)

In [9]:
# Split the root key into independent subkeys; the tests below use
# subkeys[2], subkeys[4] and subkeys[7]. subkeys[0] becomes the new root key.
num_copies = 10
subkeys = jax.random.split(key, num_copies)
key = subkeys[0]

In [10]:
# Parameters For Generating Data
T = 1  # signal length in time units (passed to whitesignal and the LTI models)
freq = 1  # passed to whitesignal (presumably its frequency parameter)
step = 1e-3  # sampling interval, also the LTI discretization step
# Number of samples. round() guards against float division landing just below
# an integer (e.g. 0.3 / 0.1 == 2.9999999999999996), which int() would truncate.
L = int(round(T / step))

# Parameters For Training
batch_size = 2
data_size = L
input_size = 1

# Parameters For HiPPO
N = 64  # state size: `features` / `hidden_size` of the HiPPO cells

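## Test HiPPO Operators

Each test discretizes the continuous-time HiPPO matrices $(A, B)$ with step size $\Delta$ using the generalized bilinear transform (GBT)

$$\bar{A} = (I - \alpha \Delta A)^{-1}\,\bigl(I + (1 - \alpha)\Delta A\bigr), \qquad \bar{B} = (I - \alpha \Delta A)^{-1}\,\Delta B,$$

where $\alpha = 0$ is forward Euler, $\alpha = 1$ is backward Euler and $\alpha = 1/2$ is the bilinear transform. Zero-order hold (ZOH) is tested separately. The result is compared against Gu's reference implementation: a single $(\bar{A}, \bar{B})$ pair for the time-invariant (LTI) operators, and one pair per time step for the scale-invariant (LSI) LegS operator.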

In [11]:
def LSI_GBT(hippo, hr_hippo, A, B, random_input, alpha=0.5, print_all=False):
    """Compare our LSI GBT discretization with Gu's reference, step by step.

    For every time index ``i`` in ``1..L`` (``L = random_input.shape[1]``),
    discretize ``(A, B)`` with ``hippo.discretize`` and compare the result with
    ``hr_hippo.A_stacked[i - 1]`` / ``hr_hippo.B_stacked[i - 1]``.

    Args:
        hippo: object exposing ``discretize(A, B, step=..., alpha=..., dtype=...)``
            returning ``(A_bar, B_bar)`` (here: a bound ``HiPPOLSICell``).
        hr_hippo: reference object exposing ``A_stacked`` and ``B_stacked``.
            ``B_stacked[i - 1]`` is expanded to a column before comparison.
        A, B: continuous-time transition matrices to discretize.
        random_input: array whose axis 1 is the sequence length ``L``.
        alpha: GBT parameter forwarded to ``hippo.discretize``.
        print_all: if True, print per-step comparisons, shapes and matrices
            instead of the final pass/fail banner.

    Returns:
        bool: True only if BOTH the A and B matrices match
        (``rtol=atol=1e-4``) at every time index.
    """
    L = random_input.shape[1]

    flag = True
    for i in range(1, L + 1):
        # NOTE(review): the 1-based time index is passed as `step`; presumably
        # the LSI cell scales A, B by 1/step like Gu's A / t — confirm.
        GBT_A, GBT_B = hippo.discretize(A, B, step=i, alpha=alpha, dtype=jnp.float32)
        hr_GBT_A = jnp.asarray(hr_hippo.A_stacked[i - 1], dtype=jnp.float32)
        hr_GBT_B = jnp.expand_dims(
            jnp.asarray(hr_hippo.B_stacked[i - 1], dtype=jnp.float32), axis=1
        )

        check_A = jnp.allclose(GBT_A, hr_GBT_A, rtol=1e-04, atol=1e-04)
        check_B = jnp.allclose(GBT_B, hr_GBT_B, rtol=1e-04, atol=1e-04)

        if print_all:
            print(f"GBT_A: {check_A}")
            print(f"GBT_B: {check_B}\n")

            print(f"hr_GBT_A shape:{hr_GBT_A.shape}\n")
            print(f"GBT_A shape: {GBT_A.shape}\n")
            print(f"hr_GBT_B shape: {hr_GBT_B.shape}\n")
            print(f"GBT_B shape: {GBT_B.shape}")

            print(f"hr_GBT_A:\n{hr_GBT_A}\n")
            print(f"GBT_A:\n{GBT_A}\n")
            print(f"hr_GBT_B:\n{hr_GBT_B}\n")
            print(f"GBT_B:\n{GBT_B}")

        # The test fails as soon as EITHER matrix disagrees at any step.
        if not (check_A and check_B):
            flag = False

    if not print_all:
        print(f"\n---------------------------------------------")
        print(f"---------- The Test Passed: {flag} ----------")
        print(f"---------------------------------------------\n")

    return flag

In [12]:
def LTI_GBT(hippo, hr_hippo, A, B, random_input, step, alpha=0.5, print_all=False):
    """Compare our LTI GBT discretization with Gu's reference for one step size.

    Args:
        hippo: object exposing ``discretize(A, B, step=..., alpha=..., dtype=...)``
            returning ``(A_bar, B_bar)`` (here: a bound ``HiPPOLTICell``).
        hr_hippo: reference object exposing ``dA`` and ``dB``; ``dB`` is
            expanded to a column before comparison.
        A, B: continuous-time transition matrices to discretize.
        random_input: unused; kept so the signature parallels ``LSI_GBT``.
        step: discretization step forwarded to ``hippo.discretize``.
        alpha: GBT parameter forwarded to ``hippo.discretize``.
        print_all: if True, print comparisons, shapes and matrices instead of
            the pass/fail banner.

    Returns:
        bool: True only if BOTH the A and B matrices match (``rtol=atol=1e-4``).
    """
    GBT_A, GBT_B = hippo.discretize(A, B, step=step, alpha=alpha, dtype=jnp.float32)
    hr_GBT_A = jnp.asarray(hr_hippo.dA, dtype=jnp.float32)
    hr_GBT_B = jnp.expand_dims(jnp.asarray(hr_hippo.dB, dtype=jnp.float32), axis=1)

    check_A = jnp.allclose(GBT_A, hr_GBT_A, rtol=1e-04, atol=1e-04)
    check_B = jnp.allclose(GBT_B, hr_GBT_B, rtol=1e-04, atol=1e-04)

    # Fail if EITHER matrix disagrees.
    flag = bool(check_A) and bool(check_B)

    if print_all:
        print(f"GBT_A: {check_A}")
        print(f"GBT_B: {check_B}\n")

        print(f"hr_GBT_A shape:{hr_GBT_A.shape}\n")
        print(f"GBT_A shape: {GBT_A.shape}\n")
        print(f"hr_GBT_B shape: {hr_GBT_B.shape}\n")
        print(f"GBT_B shape: {GBT_B.shape}")

        print(f"hr_GBT_A:\n{hr_GBT_A}\n")
        print(f"GBT_A:\n{GBT_A}\n")
        print(f"hr_GBT_B:\n{hr_GBT_B}\n")
        print(f"GBT_B:\n{GBT_B}")
    else:
        print(f"\n---------------------------------------------")
        print(f"---------- The Test Passed: {flag} ----------")
        print(f"---------------------------------------------\n")

    return flag

In [13]:
def test_LSI_GBT(
    the_measure="legs", lambda_n=1.0, alpha=0.5, discretization=0.5, print_all=False
):
    """Compare our HiPPO-LegS LSI discretization with Gu's reference model.

    Builds Gu's ``HRHiPPO_LSI`` reference and our ``HiPPOLSICell`` for the same
    white-noise input. It then checks every per-step discretized ``(A, B)`` pair
    with ``LSI_GBT``.

    Args:
        the_measure: must be ``"legs"`` (the only measure this LSI test supports).
        lambda_n: forwarded to the reference model.
        alpha: GBT parameter for our cell and for ``LSI_GBT``.
        discretization: discretization setting forwarded to the reference model.
        print_all: forwarded to ``LSI_GBT``.

    Returns:
        Whatever ``LSI_GBT`` returns (its pass/fail result), so a cell or a
        caller can assert on it.

    Raises:
        AssertionError: if ``the_measure`` is not ``"legs"``.
    """
    assert (
        the_measure == "legs"
    ), f"the_measure must be 'legs' for this test, got {the_measure}"

    # White-noise batch reshaped to (batch, seq_len, 1).
    u = whitesignal(subkeys[4], T, step, freq, batch_shape=(batch_size,))
    x_np = einops.rearrange(np.asarray(u), "batch seq_len -> batch seq_len 1")
    x_jnp = jnp.asarray(x_np, dtype=jnp.float32)

    # ----------------------------------------------------------------------------------
    # ------------------- Instantiate Hazy Research's HiPPOs ---------------------------
    # ----------------------------------------------------------------------------------

    print(f"Creating Gu's HiPPO-{the_measure} LSI model with {alpha} transform")
    # NOTE(review): alpha/beta here are reference-model arguments (presumably
    # measure parameters), not the GBT alpha — confirm.
    hr_hippo_lsi = HRHiPPO_LSI(
        N=N,
        method=the_measure,
        max_length=L,
        discretization=discretization,
        lambda_n=lambda_n,
        alpha=0.0,
        beta=1.0,
    )  # The Gu's

    # ----------------------------------------------------------------------------------
    # ------------------------------ Instantiate HiPPO Matrices ------------------------
    # ----------------------------------------------------------------------------------

    A, B, _ = legs(N=N, dtype=jnp.float32)

    # ----------------------------------------------------------------------------------
    # ------------------------------ Instantiate Our HiPPOs ---------------------------
    # ----------------------------------------------------------------------------------

    print(f"Creating HiPPO-{the_measure} LSI model with {alpha} transform")
    hippo_lsi_cell = HiPPOLSICell(
        features=N,
        max_length=L,
        alpha=alpha,
        init_t=0,
        recon=True,
        A_init_fn=legs,
        B_init_fn=legs,
    )
    c_t_1 = hippo_lsi_cell.initialize_state(
        subkeys[7], batch_size=x_jnp.shape[0], hidden_size=N
    )
    t = jnp.arange(0, x_jnp.shape[1])
    lsi_variables = hippo_lsi_cell.init(subkeys[2], c_t_1=c_t_1, f=x_jnp, t_step=t)
    hippo_lsi_cell = hippo_lsi_cell.bind(lsi_variables)

    print(f"Testing for correct LSI GBT matrices for HiPPO-{the_measure}")
    return LSI_GBT(
        hippo=hippo_lsi_cell,
        hr_hippo=hr_hippo_lsi,
        A=A,
        B=B,
        random_input=x_np,
        alpha=alpha,
        print_all=print_all,
    )

In [14]:
def test_LTI_GBT(
    the_measure="legs", lambda_n=1.0, alpha=0.5, discretization=0.5, print_all=False
):
    """Compare our HiPPO LTI discretization with Gu's reference model.

    Builds Gu's ``HRHiPPO_LTI`` reference and our ``HiPPOLTICell`` for
    ``the_measure``. It then checks the discretized ``(A, B)`` pair with ``LTI_GBT``.

    Args:
        the_measure: one of "legs", "legt", "lmu", "lagt", "fru", "fout", "foud".
        lambda_n: forwarded to the reference model only.
        alpha: GBT parameter for our cell and for ``LTI_GBT``.
        discretization: discretization setting forwarded to the reference model.
        print_all: forwarded to ``LTI_GBT``.

    Returns:
        Whatever ``LTI_GBT`` returns (its pass/fail result), so a cell or a
        caller can assert on it.

    Raises:
        AssertionError: if ``the_measure`` is not one of the supported measures.
    """
    assert the_measure in ["legs", "legt", "lmu", "lagt", "fru", "fout", "foud"]
    # NOTE(review): "chebt" has an entry here but is rejected by the assert
    # above, so it is currently unreachable.
    init_fn = {
        "legs": legs,
        "legt": legt,
        "lmu": lmu,
        "lagt": lagt,
        "fru": fru,
        "fout": fout,
        "foud": foud,
        "chebt": chebt,
    }[the_measure]

    # White-noise batch reshaped to (batch, seq_len, 1).
    u = whitesignal(subkeys[4], T, step, freq, batch_shape=(batch_size,))
    x_np = einops.rearrange(np.asarray(u), "batch seq_len -> batch seq_len 1")
    x_jnp = jnp.asarray(x_np, dtype=jnp.float32)

    # ----------------------------------------------------------------------------------
    # ------------------- Instantiate Hazy Research's HiPPOs ---------------------------
    # ----------------------------------------------------------------------------------

    print(f"Creating Gu's HiPPO-{the_measure} LTI model with {alpha} transform")
    hr_hippo_lti = HRHiPPO_LTI(
        N=N,
        method=the_measure,
        dt=step,
        T=T,
        discretization=discretization,
        lambda_n=lambda_n,
        alpha=0.0,
        beta=1.0,
        c=0.0,
    )  # The Gu's

    # ----------------------------------------------------------------------------------
    # ------------------------------ Instantiate HiPPO Matrices ------------------------
    # ----------------------------------------------------------------------------------

    A, B, _ = init_fn(N=N, dtype=jnp.float32)

    # ----------------------------------------------------------------------------------
    # ------------------------------ Instantiate Our HiPPOs ---------------------------
    # ----------------------------------------------------------------------------------
    print(f"Creating HiPPO-{the_measure} LTI model with {alpha} transform")
    # Build the cell from the measure under test. Previously this was hard-coded
    # to `legs` for every measure.
    hippo_lti_cell = HiPPOLTICell(
        features=N,
        step_size=step,
        basis_size=T,
        alpha=alpha,
        recon=True,
        A_init_fn=init_fn,
        B_init_fn=init_fn,
    )
    c_t_1 = hippo_lti_cell.initialize_state(
        subkeys[7], batch_size=x_jnp.shape[0], hidden_size=N
    )
    lti_variables = hippo_lti_cell.init(subkeys[2], c_t_1=c_t_1, f=x_jnp)
    hippo_lti_cell = hippo_lti_cell.bind(lti_variables)

    print(f"Testing for correct LTI GBT matrices for HiPPO-{the_measure}")
    return LTI_GBT(
        hippo=hippo_lti_cell,
        hr_hippo=hr_hippo_lti,
        A=A,
        B=B,
        random_input=x_np,
        step=step,
        alpha=alpha,
        print_all=print_all,
    )

In [15]:
# Set to True to have LSI_GBT / LTI_GBT print the per-matrix comparisons, shapes
# and matrices instead of the pass/fail banner.
print_all = False

### Testing (LTI and LSI) Operators With Forward Euler Transform

#### LegS

In [16]:
# Forward Euler (alpha=0.0): LegS scale-invariant (LSI) operators.
# NOTE(review): the last recorded run of this cell raised a broadcasting
# TypeError inside the cell's init (see output) — investigate.
test_LSI_GBT(
    "legs",
    alpha=0.0,
    discretization=0.0,
    lambda_n=1.0,
    print_all=print_all,
)

Creating Gu's HiPPO-legs LSI model with 0.0 transform
Creating HiPPO-legs LSI model with 0.0 transform


TypeError: mul got incompatible shapes for broadcasting: (1, 64, 1000), (2, 1000, 1).

In [None]:
# Forward Euler (alpha=0.0): LegS time-invariant (LTI) operators.
test_LTI_GBT(
    "legs",
    alpha=0.0,
    discretization=0.0,
    lambda_n=1.0,
    print_all=print_all,
)

#### LegT

In [None]:
# Forward Euler (alpha=0.0): LegT LTI operators.
test_LTI_GBT(
    "legt",
    alpha=0.0,
    discretization=0.0,
    lambda_n=1.0,
    print_all=print_all,
)

Creating Gu's HiPPO-legt LTI model with 0.0 transform
Creating HiPPO-legt LTI model with 0.0 transform
Testing Coeffiecients for 0.0 LTI HiPPO-legt
Jax Duration: 0.21370410919189453
PyTorch Duration: 0.12743449211120605

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-legt model


#### LMU

In [None]:
# Forward Euler (alpha=0.0): LMU LTI operators (reference uses lambda_n=2.0).
test_LTI_GBT(
    "lmu",
    alpha=0.0,
    discretization=0.0,
    lambda_n=2.0,
    print_all=print_all,
)

Creating Gu's HiPPO-lmu LTI model with 0.0 transform
Creating HiPPO-lmu LTI model with 0.0 transform
Testing Coeffiecients for 0.0 LTI HiPPO-lmu
Jax Duration: 0.20223116874694824
PyTorch Duration: 0.12044692039489746

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-lmu model


#### LagT

In [None]:
# Forward Euler (alpha=0.0): LagT LTI operators.
test_LTI_GBT(
    "lagt",
    alpha=0.0,
    discretization=0.0,
    lambda_n=1.0,
    print_all=print_all,
)

Creating Gu's HiPPO-lagt LTI model with 0.0 transform
Creating HiPPO-lagt LTI model with 0.0 transform
Testing Coeffiecients for 0.0 LTI HiPPO-lagt
Jax Duration: 0.20954418182373047
PyTorch Duration: 0.12122273445129395

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-lagt model


#### FRU

In [None]:
# Forward Euler (alpha=0.0): FRU LTI operators.
test_LTI_GBT(
    "fru",
    alpha=0.0,
    discretization=0.0,
    lambda_n=1.0,
    print_all=print_all,
)

Creating Gu's HiPPO-fru LTI model with 0.0 transform
Creating HiPPO-fru LTI model with 0.0 transform
Testing Coeffiecients for 0.0 LTI HiPPO-fru
Jax Duration: 0.23551368713378906
PyTorch Duration: 0.12724924087524414

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-fru model


#### FouT

In [None]:
# Forward Euler (alpha=0.0): FouT LTI operators.
test_LTI_GBT(
    "fout",
    alpha=0.0,
    discretization=0.0,
    lambda_n=1.0,
    print_all=print_all,
)

Creating Gu's HiPPO-fout LTI model with 0.0 transform
Creating HiPPO-fout LTI model with 0.0 transform
Testing Coeffiecients for 0.0 LTI HiPPO-fout
Jax Duration: 0.26561474800109863
PyTorch Duration: 0.13262104988098145

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-fout model


#### FouD

In [None]:
# Forward Euler (alpha=0.0): FouD LTI operators.
test_LTI_GBT(
    "foud",
    alpha=0.0,
    discretization=0.0,
    lambda_n=1.0,
    print_all=print_all,
)

Creating Gu's HiPPO-foud LTI model with 0.0 transform
Creating HiPPO-foud LTI model with 0.0 transform
Testing Coeffiecients for 0.0 LTI HiPPO-foud
Jax Duration: 0.23000240325927734
PyTorch Duration: 0.12862062454223633

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-foud model


### Testing (LTI and LSI) Operators With Backward Euler Transform

#### LegS

In [None]:
# Backward Euler (alpha=1.0): LegS scale-invariant (LSI) operators.
test_LSI_GBT(
    "legs",
    alpha=1.0,
    discretization=1.0,
    lambda_n=1.0,
    print_all=print_all,
)

In [None]:
# Backward Euler (alpha=1.0): LegS LTI operators.
test_LTI_GBT(
    "legs",
    alpha=1.0,
    discretization=1.0,
    lambda_n=1.0,
    print_all=print_all,
)

#### LegT

In [None]:
# Backward Euler (alpha=1.0): LegT LTI operators.
test_LTI_GBT(
    "legt",
    alpha=1.0,
    discretization=1.0,
    lambda_n=1.0,
    print_all=print_all,
)

Creating Gu's HiPPO-legt LTI model with 1.0 transform
Creating HiPPO-legt LTI model with 1.0 transform
Testing Coeffiecients for 1.0 LTI HiPPO-legt
Jax Duration: 0.2201390266418457
PyTorch Duration: 0.12486433982849121

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-legt model


#### LMU

In [None]:
# Backward Euler (alpha=1.0): LMU LTI operators (reference uses lambda_n=2.0).
test_LTI_GBT(
    "lmu",
    alpha=1.0,
    discretization=1.0,
    lambda_n=2.0,
    print_all=print_all,
)

Creating Gu's HiPPO-lmu LTI model with 1.0 transform
Creating HiPPO-lmu LTI model with 1.0 transform
Testing Coeffiecients for 1.0 LTI HiPPO-lmu
Jax Duration: 0.2026219367980957
PyTorch Duration: 0.12890625

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-lmu model


#### LagT

In [None]:
# Backward Euler (alpha=1.0): LagT LTI operators.
test_LTI_GBT(
    "lagt",
    alpha=1.0,
    discretization=1.0,
    lambda_n=1.0,
    print_all=print_all,
)

Creating Gu's HiPPO-lagt LTI model with 1.0 transform
Creating HiPPO-lagt LTI model with 1.0 transform
Testing Coeffiecients for 1.0 LTI HiPPO-lagt
Jax Duration: 0.21983551979064941
PyTorch Duration: 0.12653017044067383

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-lagt model


#### FRU

In [None]:
# Backward Euler (alpha=1.0): FRU LTI operators.
test_LTI_GBT(
    "fru",
    alpha=1.0,
    discretization=1.0,
    lambda_n=1.0,
    print_all=print_all,
)

Creating Gu's HiPPO-fru LTI model with 1.0 transform
Creating HiPPO-fru LTI model with 1.0 transform
Testing Coeffiecients for 1.0 LTI HiPPO-fru
Jax Duration: 0.21619820594787598
PyTorch Duration: 0.12327194213867188

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-fru model


#### FouT

In [None]:
# Backward Euler (alpha=1.0): FouT LTI operators.
test_LTI_GBT(
    "fout",
    alpha=1.0,
    discretization=1.0,
    lambda_n=1.0,
    print_all=print_all,
)

Creating Gu's HiPPO-fout LTI model with 1.0 transform
Creating HiPPO-fout LTI model with 1.0 transform
Testing Coeffiecients for 1.0 LTI HiPPO-fout
Jax Duration: 0.28851747512817383
PyTorch Duration: 0.12918448448181152

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-fout model


#### FouD

In [None]:
# Backward Euler (alpha=1.0): FouD LTI operators.
test_LTI_GBT(
    "foud",
    alpha=1.0,
    discretization=1.0,
    lambda_n=1.0,
    print_all=print_all,
)

Creating Gu's HiPPO-foud LTI model with 1.0 transform
Creating HiPPO-foud LTI model with 1.0 transform
Testing Coeffiecients for 1.0 LTI HiPPO-foud
Jax Duration: 0.25328493118286133
PyTorch Duration: 0.12967562675476074

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-foud model


### Testing (LTI and LSI) Operators With Bilinear (Tustin) Transform

#### LegS

In [None]:
# Bilinear (alpha=0.5): LegS scale-invariant (LSI) operators.
test_LSI_GBT(
    "legs",
    alpha=0.5,
    discretization=0.5,
    lambda_n=1.0,
    print_all=print_all,
)

Creating Gu's HiPPO-legs LSI model with 0.5 transform
Creating HiPPO-legs LSI model with 0.5 transform


Testing Coeffiecients for 0.5 LSI HiPPO-legs
Jax Duration: 0.3303253650665283
PyTorch Duration: 0.13050270080566406

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-legs model


In [None]:
# Bilinear (alpha=0.5): LegS LTI operators.
test_LTI_GBT(
    "legs",
    alpha=0.5,
    discretization=0.5,
    lambda_n=1.0,
    print_all=print_all,
)

Creating Gu's HiPPO-legs LTI model with 0.5 transform
Creating HiPPO-legs LTI model with 0.5 transform
Testing Coeffiecients for 0.5 LTI HiPPO-legs
Jax Duration: 0.22737574577331543
PyTorch Duration: 0.12870502471923828

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-legs model


#### LegT

In [None]:
# Bilinear (alpha=0.5): LegT LTI operators.
test_LTI_GBT(
    "legt",
    alpha=0.5,
    discretization=0.5,
    lambda_n=1.0,
    print_all=print_all,
)

Creating Gu's HiPPO-legt LTI model with 0.5 transform
Creating HiPPO-legt LTI model with 0.5 transform
Testing Coeffiecients for 0.5 LTI HiPPO-legt
Jax Duration: 0.2138075828552246
PyTorch Duration: 0.13062429428100586

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-legt model


#### LMU

In [None]:
# Bilinear (alpha=0.5): LMU LTI operators (reference uses lambda_n=2.0).
test_LTI_GBT(
    "lmu",
    alpha=0.5,
    discretization=0.5,
    lambda_n=2.0,
    print_all=print_all,
)

Creating Gu's HiPPO-lmu LTI model with 0.5 transform
Creating HiPPO-lmu LTI model with 0.5 transform
Testing Coeffiecients for 0.5 LTI HiPPO-lmu
Jax Duration: 0.203108549118042
PyTorch Duration: 0.12275242805480957

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-lmu model


#### LagT

In [None]:
# Bilinear (alpha=0.5): LagT LTI operators.
test_LTI_GBT(
    "lagt",
    alpha=0.5,
    discretization=0.5,
    lambda_n=1.0,
    print_all=print_all,
)

Creating Gu's HiPPO-lagt LTI model with 0.5 transform
Creating HiPPO-lagt LTI model with 0.5 transform
Testing Coeffiecients for 0.5 LTI HiPPO-lagt
Jax Duration: 0.2168130874633789
PyTorch Duration: 0.1261742115020752

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-lagt model


#### FRU

In [None]:
# Bilinear (alpha=0.5): FRU LTI operators.
test_LTI_GBT(
    "fru",
    alpha=0.5,
    discretization=0.5,
    lambda_n=1.0,
    print_all=print_all,
)

Creating Gu's HiPPO-fru LTI model with 0.5 transform
Creating HiPPO-fru LTI model with 0.5 transform
Testing Coeffiecients for 0.5 LTI HiPPO-fru
Jax Duration: 0.23087668418884277
PyTorch Duration: 0.12711548805236816

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-fru model


#### FouT

In [None]:
# Bilinear (alpha=0.5): FouT LTI operators.
test_LTI_GBT(
    "fout",
    alpha=0.5,
    discretization=0.5,
    lambda_n=1.0,
    print_all=print_all,
)

Creating Gu's HiPPO-fout LTI model with 0.5 transform
Creating HiPPO-fout LTI model with 0.5 transform
Testing Coeffiecients for 0.5 LTI HiPPO-fout
Jax Duration: 0.25879812240600586
PyTorch Duration: 0.1258854866027832

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-fout model


#### FouD

In [None]:
# Bilinear (alpha=0.5): FouD LTI operators.
test_LTI_GBT(
    "foud",
    alpha=0.5,
    discretization=0.5,
    lambda_n=1.0,
    print_all=print_all,
)

Creating Gu's HiPPO-foud LTI model with 0.5 transform
Creating HiPPO-foud LTI model with 0.5 transform
Testing Coeffiecients for 0.5 LTI HiPPO-foud
Jax Duration: 0.23973417282104492
PyTorch Duration: 0.12212300300598145

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-foud model


### Testing (LTI and LSI) Operators With Zero-Order Hold (ZOH) Discretization

#### LegS

In [None]:
# Zero-order hold: LegS LSI operators (our alpha=2.0, reference discretization="zoh").
test_LSI_GBT(
    "legs",
    alpha=2.0,
    discretization="zoh",
    lambda_n=1.0,
    print_all=print_all,
)

Creating Gu's HiPPO-legs LSI model with 2.0 transform
Creating HiPPO-legs LSI model with 2.0 transform


Testing Coeffiecients for 2.0 LSI HiPPO-legs
Jax Duration: 0.3464968204498291
PyTorch Duration: 0.13298988342285156

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-legs model


In [None]:
# Zero-order hold: LegS LTI operators (our alpha=2.0, reference discretization="zoh").
test_LTI_GBT(
    "legs",
    alpha=2.0,
    discretization="zoh",
    lambda_n=1.0,
    print_all=print_all,
)

Creating Gu's HiPPO-legs LTI model with 2.0 transform
Creating HiPPO-legs LTI model with 2.0 transform
Testing Coeffiecients for 2.0 LTI HiPPO-legs
Jax Duration: 0.22150254249572754
PyTorch Duration: 0.12716913223266602

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-legs model


#### LegT

In [None]:
# Zero-order hold: LegT LTI operators (our alpha=2.0, reference discretization="zoh").
test_LTI_GBT(
    "legt",
    alpha=2.0,
    discretization="zoh",
    lambda_n=1.0,
    print_all=print_all,
)

Creating Gu's HiPPO-legt LTI model with 2.0 transform
Creating HiPPO-legt LTI model with 2.0 transform
Testing Coeffiecients for 2.0 LTI HiPPO-legt
Jax Duration: 0.22566962242126465
PyTorch Duration: 0.12636852264404297

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-legt model


#### LMU

In [None]:
# Zero-order hold: LMU LTI operators (reference uses lambda_n=2.0, discretization="zoh").
test_LTI_GBT(
    "lmu",
    alpha=2.0,
    discretization="zoh",
    lambda_n=2.0,
    print_all=print_all,
)

Creating Gu's HiPPO-lmu LTI model with 2.0 transform
Creating HiPPO-lmu LTI model with 2.0 transform
Testing Coeffiecients for 2.0 LTI HiPPO-lmu
Jax Duration: 0.2460160255432129
PyTorch Duration: 0.12760305404663086

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-lmu model


#### LagT

In [None]:
# Zero-order hold: LagT LTI operators (our alpha=2.0, reference discretization="zoh").
test_LTI_GBT(
    "lagt",
    alpha=2.0,
    discretization="zoh",
    lambda_n=1.0,
    print_all=print_all,
)

Creating Gu's HiPPO-lagt LTI model with 2.0 transform
Creating HiPPO-lagt LTI model with 2.0 transform
Testing Coeffiecients for 2.0 LTI HiPPO-lagt
Jax Duration: 0.23780131340026855
PyTorch Duration: 0.1227273941040039

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-lagt model


#### FRU

In [None]:
# Zero-order hold: FRU LTI operators (our alpha=2.0, reference discretization="zoh").
test_LTI_GBT(
    "fru",
    alpha=2.0,
    discretization="zoh",
    lambda_n=1.0,
    print_all=print_all,
)

Creating Gu's HiPPO-fru LTI model with 2.0 transform
Creating HiPPO-fru LTI model with 2.0 transform
Testing Coeffiecients for 2.0 LTI HiPPO-fru
Jax Duration: 0.24944567680358887
PyTorch Duration: 0.12150359153747559

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-fru model


#### FouT

In [None]:
# Zero-order hold: FouT LTI operators (our alpha=2.0, reference discretization="zoh").
test_LTI_GBT(
    "fout",
    alpha=2.0,
    discretization="zoh",
    lambda_n=1.0,
    print_all=print_all,
)

Creating Gu's HiPPO-fout LTI model with 2.0 transform
Creating HiPPO-fout LTI model with 2.0 transform
Testing Coeffiecients for 2.0 LTI HiPPO-fout
Jax Duration: 0.28164148330688477
PyTorch Duration: 0.1184089183807373

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-fout model


#### FouD

In [None]:
# Zero-order hold: FouD LTI operators (our alpha=2.0, reference discretization="zoh").
test_LTI_GBT(
    "foud",
    alpha=2.0,
    discretization="zoh",
    lambda_n=1.0,
    print_all=print_all,
)

Creating Gu's HiPPO-foud LTI model with 2.0 transform
Creating HiPPO-foud LTI model with 2.0 transform
Testing Coeffiecients for 2.0 LTI HiPPO-foud
Jax Duration: 0.2612276077270508
PyTorch Duration: 0.12821745872497559

------------------------------------------------------------
---------- The Coefficients Test Passed: True ----------
-------------------------------------------------------------


end of test for HiPPO-foud model
