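# HiPPO Operator Minimal Test

Runs the LTI HiPPO operator (`HiPPO` wrapping `HiPPOLTICell`) on a small batch of white-noise input signals and reports the wall-clock time of a single forward pass.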
---

## Load Packages

In [1]:
import os
import sys

# Make the repository root (three directories above the notebook's working
# directory) importable so the `src.*` imports further down resolve.
# `os.path.join` with a single argument is a no-op, so `abspath` suffices.
module_path = os.path.abspath("../../../")
print(f"module_path: {module_path}")
already_importable = module_path in sys.path
if not already_importable:
    print(f"Adding {module_path} to sys.path")
    sys.path.append(module_path)

module_path: /home/beegass/Documents/Coding/HiPPO-Jax


In [2]:
# GPU memory configuration. These variables are read when JAX initialises its
# backend, so they must be set before the `import jax` cell below runs.
# NOTE(review): TF_FORCE_UNIFIED_MEMORY presumably lets XLA use CUDA unified
# memory (spilling to host RAM) — confirm against the JAX/XLA GPU docs.
# Disabled alternative: os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False"
os.environ.update({"TF_FORCE_UNIFIED_MEMORY": "1"})

In [3]:
## import packages
import jax
import jax.numpy as jnp
import einops
import numpy as np
import torch
import time
from jax import random
import flax.linen as nn
from jaxtyping import Array, Float
from scipy import special as ss
from typing import Any, Callable, List, Optional, Tuple, Union

  jax.tree_util.register_keypaths(data_clz, keypaths)
  jax.tree_util.register_keypaths(data_clz, keypaths)


In [4]:
from src.models.hippo.hr_hippo import HRHiPPO_LSI, HRHiPPO_LTI
from src.models.hippo.hippo import HiPPOLTI, HiPPOLSI
from src.models.hippo.cells import HiPPOLSICell, HiPPOLTICell, HiPPO
from src.models.hippo.transition import (
    legs,
    legs_initializer,
    legt,
    legt_initializer,
    lmu,
    lmu_initializer,
    lagt,
    lagt_initializer,
    fru,
    fru_initializer,
    fout,
    fout_initializer,
    foud,
    foud_initializer,
    chebt,
    chebt_initializer,
)
from src.data.process import whitesignal

In [5]:
print(jax.devices())
# `jax.default_backend()` is the public way to get the active platform name
# ("cpu", "gpu", "tpu"); `jax.lib.xla_bridge` is an internal module that newer
# JAX releases deprecate.
print(f"The Device: {jax.default_backend()}")

[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]
The Device: gpu


In [6]:
# MPS is PyTorch's Apple-Metal backend; reported only for completeness.
mps_available = torch.backends.mps.is_available()
print(f"MPS enabled: {mps_available}")

MPS enabled: False


In [7]:
# Use a 150-column line width when printing tensors/arrays from each library
# (same order as before: torch, numpy, jax.numpy).
for _array_lib in (torch, np, jnp):
    _array_lib.set_printoptions(linewidth=150)
del _array_lib

In [8]:
# One fixed seed for the whole notebook, so every rerun draws the same keys.
seed = 1701
key = random.PRNGKey(seed)

In [9]:
# Pre-split a pool of PRNG keys from the notebook seed; later cells index into
# this pool directly to get reproducible, independent randomness.
num_copies = 10
subkeys = jax.random.split(key, num_copies)
# NOTE: after this line `key` is the same key as subkeys[0]; avoid using both.
key = subkeys[0]

In [10]:
def test_hippo_operator(key, hippo, random_input, hidden_size, batch_size):
    x_jnp = jnp.asarray(random_input, dtype=jnp.float32)
    x_jnp = einops.rearrange(x_jnp, "batch seq_len -> batch seq_len 1")

    c_t_1 = hippo.initialize_state(
        subkeys[7], batch_size=batch_size, hidden_size=hidden_size
    )
    params = hippo.init(key, f=x_jnp, c_t_1=c_t_1)

    start = time.time()
    c, y = hippo.apply(params, f=x_jnp, c_t_1=c_t_1)
    # c, y = hippo.apply({'params': params}, f=x_jnp, c_t_1=c_t_1)
    end = time.time()

    duration = end - start
    print(f"Duration: {duration}")

In [11]:
def test_operators(
    the_measure="legs",
    alpha=0.5,
    T=1,
    step=1e-3,
    freq=1,
    batch_size=2,
    N=64,
):
    """Build an LTI HiPPO model for one measure and time a forward pass on white noise.

    The keyword parameters after ``alpha`` default to the values that were
    previously hard-coded, so existing calls behave as before.

    Args:
        the_measure: HiPPO measure name ("legs", "legt", "lmu", "lagt", "fru",
            "fout", "foud" or "chebt"). It is passed to the cell as ``measure``
            and also selects the transition function used for
            ``A_init_fn``/``B_init_fn``.
        alpha: forwarded to the cell's ``alpha`` argument (printed as the "transform").
        T: duration passed to ``whitesignal``; also used as the cell's ``basis_size``.
        step: sampling step for ``whitesignal`` and the cell's ``step_size``.
        freq: frequency argument forwarded to ``whitesignal``.
        batch_size: number of independent input signals.
        N: number of HiPPO features (``features`` and ``hidden_size``).

    Raises:
        ValueError: if ``the_measure`` has no matching transition function.

    Note:
        Randomness comes from the notebook-level ``subkeys`` pool
        (``subkeys[4]`` for the input signal, ``subkeys[5]`` for the model).
    """
    # Transition functions imported from src.models.hippo.transition, keyed by
    # measure. Previously `legs` was passed for every measure, so any non-LegS
    # run was labelled with one measure but initialised with another.
    # NOTE(review): assumes every entry shares the call signature of `legs`.
    transition_fns = {
        "legs": legs,
        "legt": legt,
        "lmu": lmu,
        "lagt": lagt,
        "fru": fru,
        "fout": fout,
        "foud": foud,
        "chebt": chebt,
    }
    try:
        transition_fn = transition_fns[the_measure]
    except KeyError:
        raise ValueError(
            f"Unsupported measure {the_measure!r}; "
            f"expected one of {sorted(transition_fns)}"
        ) from None

    u = whitesignal(subkeys[4], T, step, freq, batch_shape=(batch_size,))
    x_np = np.asarray(u)

    # ----------------------------------------------------------------------------------
    # ------------------------------ Instantiate My HiPPOs -----------------------------
    # ----------------------------------------------------------------------------------
    print(f"Creating HiPPO-{the_measure} LTI model with {alpha} transform")
    h_args = {
        "step_size": step,
        "basis_size": T,
        "alpha": alpha,
        "recon": False,
        "A_init_fn": transition_fn,
        "B_init_fn": transition_fn,
        "measure": the_measure,
    }
    hippo_lti = HiPPO(
        features=N,
        hippo_cell=HiPPOLTICell,
        hippo_args=h_args,
        init_t=0,
        unroll=False,
    )

    print(f"Testing Coeffiecients for {alpha} LTI HiPPO-{the_measure}")

    test_hippo_operator(
        key=subkeys[5],
        hippo=hippo_lti,
        random_input=x_np,
        hidden_size=N,
        batch_size=batch_size,
    )

    print(f"end of test for HiPPO-{the_measure} model")

### LegS

In [12]:
# Run the LTI smoke test for the LegS measure with alpha = 0.0.
test_operators(alpha=0.0, the_measure="legs")

Creating HiPPO-legs LTI model with 0.0 transform
Testing Coeffiecients for 0.0 LTI HiPPO-legs


  scopes, treedef = jax.tree_flatten(scope_tree)
  leaves = jax.tree_leaves(x)
  lengths = set(jax.tree_leaves(lengths))
  in_avals, in_tree = jax.tree_flatten(input_avals)
  leaves = jax.tree_leaves(x)
  axis_sizes = set(jax.tree_leaves(axis_sizes))
  jax.tree_leaves(tree)))
  broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)


Duration: 0.23715925216674805
end of test for HiPPO-legs model
