# HiPPO Matrices
---

## Table of Contents
* [Loading In Necessary Packages](#load-packages)
* [Instantiate The HiPPO Matrix](#instantiate-the-hippo-matrix)
    * [Translated Legendre (LegT)](#translated-legendre-legt)
        * [LegT](#legt)
        * [LMU](#lmu)
    * [Translated Laguerre (LagT)](#translated-laguerre-lagt)
    * [Scaled Legendre (LegS)](#scaled-legendre-legs)
    * [Fourier Basis](#fourier-basis)
        * [Fourier Recurrent Unit (FRU)](#fourier-recurrent-unit-fru)
        * [Truncated Fourier (FouT)](#truncated-fourier-fout)
        * [Fourier With Decay (FourD)](#fourier-with-decay-fourd)
* [Gu's Linear Time Invariant (LTI) HiPPO Operator](#gus-hippo-legt-operator)
* [Gu's Scale Invariant (LSI) HiPPO Operator](#gus-scale-invariant-hippo-legs-operator)
* [Implementation Of General HiPPO Operator](#implementation-of-general-hippo-operator)
* [Test Generalized Bilinear Transform and Zero Order Hold Matrices](#test-generalized-bilinear-transform-and-zero-order-hold-matrices)
    * [Testing Forward Euler on GBT matrices](#testing-forward-euler-transform-for-lti-and-lsi)
    * [Testing Backward Euler on GBT matrices](#testing-backward-euler-transform-for-lti-and-lsi-on-legs-matrices)
    * [Testing Bidirectional on GBT matrices](#testing-lti-and-lsi-operators-with-bidirectional-transform)
    * [Testing ZOH on GBT matrices](#testing-zoh-transform-for-lti-and-lsi-on-legs-matrices)
* [Testing HiPPO Operators](#test-hippo-operators)
    * [Testing Forward Euler on HiPPO Operators](#testing-lti-and-lsi-operators-with-forward-euler-transform)
    * [Testing Backward Euler on HiPPO Operators](#testing-lti-and-lsi-operators-with-backward-euler-transform)
    * [Testing Bidirectional on HiPPO Operators](#testing-lti-and-lsi-operators-with-bidirectional-transform)
    * [Testing ZOH on HiPPO Operators](#testing-lti-and-lsi-operators-with-zoh-transform)
---


## Load Packages

In [1]:
import os
import sys

module_path = os.path.abspath(os.path.join("../../../"))
print(f"module_path: {module_path}")
if module_path not in sys.path:
    print(f"Adding {module_path} to sys.path")
    sys.path.append(module_path)

module_path: /home/beegass/Documents/Coding/s4mer


In [2]:
## import packages
import math
import time

import jax
import jax.numpy as jnp
import numpy as np
from jax import vmap
from flax import linen as nn
from flax.training import train_state

import optax

import wandb
import hydra
from omegaconf import DictConfig, OmegaConf
import requests

from jaxtyping import Array, Float, Float16, Float32, Float64
from typing import Callable, List, Optional, Tuple, Any, Union

from src.data.process import moving_window, rolling_window

# import modules
from src.models.hippo.hippo import HiPPOLTI, HiPPOLSI
from src.models.hippo.transition import TransMatrix

import torch
import torch.nn as nn  # NOTE: rebinds `nn`, shadowing `flax.linen as nn` imported above
import torch.nn.functional as F
import torch.utils.data as data
from torchvision import datasets, transforms

[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]
The Device: gpu


In [None]:
print(jax.devices())
print(f"The Device: {jax.default_backend()}")

In [None]:
print(f"MPS enabled: {torch.backends.mps.is_available()}")

In [4]:
torch.set_printoptions(linewidth=150)
np.set_printoptions(linewidth=150)
jnp.set_printoptions(linewidth=150)

In [12]:
def random_16_input(key_generator, batch_size=16, data_size=784, input_size=28):
    """Sample `batch_size` uniform random sequences of length `data_size` and window each with `moving_window`."""
    # x = jax.random.randint(key_generator, (batch_size, data_size), 0, 255)
    x = jax.random.uniform(key_generator, (batch_size, data_size))
    return np.asarray(jax.vmap(moving_window, in_axes=(0, None))(x, input_size))

In [None]:
def get_datasets(cfg):
    # download and transform train dataset
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            "../../datasets/mnist_data",
            download=True,
            train=True,
            transform=transforms.Compose(
                [
                    transforms.ToTensor(),  # first, convert image to PyTorch tensor
                ]
            ),
        ),
        batch_size=cfg.training.params.batch_size,
        shuffle=True,
    )

    # download and transform test dataset
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            "../../datasets/mnist_data",
            download=True,
            train=False,
            transform=transforms.Compose(
                [
                    transforms.ToTensor(),  # first, convert image to PyTorch tensor
                ]
            ),
        ),
        batch_size=cfg.training.params.batch_size,
        shuffle=True,
    )
    return train_loader, test_loader

In [None]:
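def preprocess_data(cfg, data):
    """Convert a batch of MNIST images into windowed sequences for the HiPPO models."""
    x = None
    if cfg.data.dataset.preprocess_data == "flatten":
        x = data.cpu().detach().numpy()  # torch tensor -> NumPy array
        x = jnp.asarray(x, dtype=jnp.float32)  # NumPy array -> JAX array
        x = jnp.squeeze(x, axis=1)  # drop the channel axis: (B, 1, 28, 28) -> (B, 28, 28)
        x = jax.vmap(jnp.ravel, in_axes=0)(x)  # flatten each image: (B, 784)
        x = jax.vmap(moving_window, in_axes=(0, None))(x, cfg["training"]["input_length"])

    return x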

In [None]:
def preprocess_labels(cfg, labels):
    # convert labels to a JAX array; the one-hot encoding itself happens inside apply_model
    y = None
    if cfg.data.dataset.preprocess_labels == "one hot":
        y = labels.cpu().detach().numpy()
        y = jnp.asarray(y, dtype=jnp.int32)  # integer class indices
        # y = jax.nn.one_hot(y, 10, dtype=jnp.float32)

    return y

In [None]:
def pick_optim(cfg, model, params):
    tx = None
    if cfg.training.params.optim == "adam":
        # "adam" is mapped to AdamW (Adam with decoupled weight decay)
        tx = optax.adamw(
            learning_rate=cfg.training.params.lr,
            weight_decay=cfg.training.params.weight_decay,
        )
    elif cfg.training.params.optim == "sgd":
        tx = optax.sgd(learning_rate=cfg.training.params.lr)
    else:
        raise ValueError(f"Unknown optimizer: {cfg.training.params.optim}")

    # TrainState.create initializes the optimizer state from `params`
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

In [None]:
@jax.jit
def apply_model(state, carry, data, labels):
    """Computes gradients, loss and accuracy for a single batch."""

    def loss_fn(params):
        # jax.debug.print("params:\n{params}", params=params)

        logits = state.apply_fn({"params": params}, carry=carry, input=data)

        one_hot = jax.nn.one_hot(labels, 10, dtype=jnp.float32)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))

        # jax.debug.print("logits:\n{logits}", logits=logits)
        # jax.debug.print("labels:\n{labels}", labels=labels)
        # jax.debug.print("one_hot:\n{one_hot}", one_hot=one_hot)
        # jax.debug.print("loss:\n{loss}", loss=loss)

        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.params)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)

    return grads, loss, accuracy

In [None]:
@jax.jit
def update_model(state, grads):
    return state.apply_gradients(grads=grads)

In [None]:
@hydra.main(config_path="../../config", config_name="config")
def recurrent_train(cfg: DictConfig) -> train_state.TrainState:
    """
    Implements a learning loop over epochs.

    Args:
        cfg: Hydra config

    Returns:
        The final TrainState after the last epoch.

    """
    # wandb expects a plain dict, so convert the Hydra DictConfig first
    with wandb.init(
        project="BeeGass-HiPPOs",
        entity="beegass",
        config=OmegaConf.to_container(cfg, resolve=True),
    ):  # initialize wandb project for logging

        # get keys for parameters
        seed = cfg.training.seed
        key = jax.random.PRNGKey(seed)

        num_copies = cfg.training.key_num
        subkeys = jax.random.split(key, num=num_copies)

        # get train and test datasets
        train_loader, test_loader = get_datasets(cfg)
        print("got dataset")

        # pick a model, seeding its initialization with the first subkey
        model, params, carry = pick_model(subkeys[0], cfg)
        print("got model and params")

        # pick an optimizer
        state = pick_optim(cfg, model, params)
        print(f"got optimizer state")

        # pick a scheduler
        # TODO: implement choice of scheduler

        # pick a loss function
        # TODO: implement choice of loss function

        print("starting training loop")
        # Loop over the training epochs
        for epoch in range(cfg.training.params.num_epochs):
            start_time = time.time()
            # reset per-epoch metrics so each epoch reports its own averages
            epoch_loss = []
            epoch_accuracy = []
            for batch_id, (train_data, train_labels) in enumerate(train_loader):
                data = preprocess_data(cfg, train_data)
                labels = preprocess_labels(cfg, train_labels)
                # carry = model.initialize_carry(
                #     rng=subkey,
                #     batch_size=(cfg["training"]["batch_size"],),
                #     hidden_size=cfg["models"]["deep_rnn"]["hidden_size"],
                # )
                # grads, loss, accuracy = apply_model(
                #     state=state, carry=None, data=data, labels=labels
                # )
                grads, loss, accuracy = apply_model(
                    state=state,
                    carry=carry,
                    data=data,
                    labels=labels,
                )
                state = update_model(state, grads)
                epoch_loss.append(loss)
                epoch_accuracy.append(accuracy)

            # train loss and accuracy for current epoch
            train_loss = jnp.mean(jnp.array(epoch_loss))
            train_accuracy = jnp.mean(jnp.array(epoch_accuracy))
            wandb.log(
                {"train_loss": train_loss, "train_accuracy": train_accuracy}, step=epoch
            )

            epoch_test_loss = []
            epoch_test_accuracy = []

            for test_data, test_labels in test_loader:
                data = preprocess_data(cfg, test_data)
                target = preprocess_labels(cfg, test_labels)

                # test loss for current epoch
                # _, test_loss, test_accuracy = apply_model(
                #     state=state, carry=None, data=data, labels=target
                # )
                _, test_loss, test_accuracy = apply_model(
                    state=state,
                    carry=carry,
                    data=data,
                    labels=target,
                )
                epoch_test_loss.append(test_loss)
                epoch_test_accuracy.append(test_accuracy)

            test_epoch_loss = jnp.mean(jnp.array(epoch_test_loss))
            test_epoch_accuracy = jnp.mean(jnp.array(epoch_test_accuracy))
            wandb.log(
                {"test_loss": test_epoch_loss, "test_accuracy": test_epoch_accuracy},
                step=epoch,
            )

            epoch_time = time.time() - start_time
            print(f"Epoch {epoch + 1} in {epoch_time:.2f} sec")
            print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}")
            print(
                f"Test Loss: {test_epoch_loss:.4f}, Test Accuracy: {test_epoch_accuracy:.4f}"
            )

        return state

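## Test HiPPO Reconstruction

`test_hippo_reconstruction` encodes the same random input with this notebook's operator and with Gu's reference operator, reconstructs the signal from each set of coefficients, and prints the per-example mean squared reconstruction error for each. In the outputs below, the first `The Loss:` block after each `Bryan's ...` line comes from this implementation and the second comes from Gu's.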
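Both operators under test turn continuous HiPPO dynamics into a discrete linear recurrence $c_k = A_d c_{k-1} + B_d f_k$. The sketch below shows that recurrence with `jax.lax.scan`, assuming (as in the HiPPO paper) that the LTI operator integrates $c'(t) = A c(t) + B f(t)$ with a fixed step `dt`, while the scale-invariant LSI operator integrates $c'(t) = \frac{1}{t}(A c(t) + B f(t))$, which amounts to a step of $1/k$ at the $k$-th sample. `gbt` and `hippo_scan` are hypothetical helpers, not the `HiPPOLTI`/`HiPPOLSI` API.

```python
import jax
import jax.numpy as jnp


def gbt(A, B, dt, alpha):
    """Generalized bilinear transform of x' = A x + B u (alpha: 0 FE, 0.5 bilinear, 1 BE)."""
    I = jnp.eye(A.shape[0], dtype=A.dtype)
    lhs = I - alpha * dt * A
    Ad = jnp.linalg.solve(lhs, I + (1.0 - alpha) * dt * A)
    Bd = jnp.linalg.solve(lhs, dt * B)
    return Ad, Bd


def hippo_scan(A, B, f, alpha, scale_invariant=False, dt=1.0):
    """Return c_1..c_L of the recurrence c_k = Ad_k c_{k-1} + Bd_k f_k for a 1-D signal f."""
    N = A.shape[0]
    ks = jnp.arange(1, f.shape[0] + 1, dtype=A.dtype)  # 1-based step index

    def step(c, inputs):
        k, f_k = inputs
        # LSI: step size 1/k (scale invariant); LTI: fixed step dt
        step_size = 1.0 / k if scale_invariant else dt
        Ad, Bd = gbt(A, B, step_size, alpha)
        c = Ad @ c + Bd * f_k
        return c, c  # carry the state and also emit it as this step's output

    _, cs = jax.lax.scan(step, jnp.zeros(N, dtype=A.dtype), (ks, f))
    return cs  # shape (L, N)
```

For the LTI case the discretization could be hoisted out of the scan; it is kept inside here so a single step function covers both operators.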

In [13]:
def test_hippo_reconstruction(
    hippo, gu_hippo, random_input, key, s_or_t="lti", print_all=False
):
    x_tensor = torch.tensor(random_input, dtype=torch.float32)
    x_jnp = jnp.asarray(x_tensor, dtype=jnp.float32)  # convert torch array to jax array

    # My Implementation
    if print_all:
        print(
            f"------------------------------------------------------------------------------------------"
        )
        print(
            f"----------------------------My {s_or_t} Implementation Outputs----------------------------"
        )
        print(
            f"------------------------------------------------------------------------------------------"
        )
    params = hippo.init(key, f=x_jnp)
    hippo = hippo.bind(params)
    c = hippo(f=x_jnp)
    y = hippo.reconstruct(c)

    if s_or_t == "lsi":
        c = jnp.moveaxis(c, 0, 1)
        # y = jnp.moveaxis(y, 0, 1)

    mse = lambda y_hat, y: jnp.mean((y_hat - y) ** 2)
    batch_mse = jax.vmap(mse, in_axes=(0, 0))
    print(f"The Loss:\n {batch_mse(y, x_jnp)}\n\n\n")

    # Gu's reference implementation
    if print_all:
        print(
            f"------------------------------------------------------------------------------------------"
        )
        print(
            f"---------------------------Gu's {s_or_t} Implementation Outputs---------------------------"
        )
        print(
            f"------------------------------------------------------------------------------------------"
        )
    x_tensor = torch.moveaxis(x_tensor, 0, 1)
    GU_c_s, GU_c_k = gu_hippo(x_tensor, fast=False)
    # print(f"GU_c_k shape:\n{GU_c_k.shape}")

    if s_or_t not in ("lsi", "lti"):
        raise ValueError(
            f"s_or_t must be either 'lsi' or 'lti'. s_or_t is currently set to: {s_or_t}"
        )
    # both the LTI and LSI operators reconstruct from GU_c_k
    gu_y = gu_hippo.reconstruct(GU_c_k)

    gu_c = jnp.asarray(GU_c_k, dtype=jnp.float32)  # convert torch array to jax array
    gu_y = jnp.asarray(gu_y, dtype=jnp.float32)  # convert torch array to jax array

    # gu_c = jnp.moveaxis(gu_c, 0, 1)
    # gu_y = jnp.moveaxis(gu_y, 0, 1)

    # print(f"gu_y shape:\n {gu_y.shape}")

    x_tensor = jnp.asarray(
        x_tensor, dtype=jnp.float32
    )  # convert torch array to jax array
    x_tensor = jnp.moveaxis(x_tensor, 0, 1)

    print(f"The Loss:\n {batch_mse(gu_y, x_tensor)}\n")
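`test_hippo_reconstruction` scores each implementation by how well `reconstruct(c)` recovers the input. For the LegS measure, reconstruction is a truncated expansion in scaled Legendre polynomials. The sketch below assumes the orthonormal basis $g_n(t) = \sqrt{2n+1}\,P_n(2t - 1)$ on $[0, 1]$ (the usual HiPPO-LegS convention) and computes coefficients by Gauss-Legendre quadrature rather than by the recurrence. `legendre_basis`, `scaled_legendre`, `project` and `reconstruct` are hypothetical helpers, not the `HiPPOLSI.reconstruct` method.

```python
import numpy as np
import jax.numpy as jnp


def legendre_basis(N, x):
    """Evaluate P_0..P_{N-1} at points x in [-1, 1] via Bonnet's recurrence; shape (N, len(x))."""
    P = [jnp.ones_like(x), x]
    for n in range(1, N - 1):
        # (n + 1) P_{n+1}(x) = (2n + 1) x P_n(x) - n P_{n-1}(x)
        P.append(((2 * n + 1) * x * P[n] - n * P[n - 1]) / (n + 1))
    return jnp.stack(P[:N])


def scaled_legendre(N, t):
    """Orthonormal basis sqrt(2n + 1) P_n(2t - 1) on [0, 1] under the uniform measure."""
    scale = jnp.sqrt(2.0 * jnp.arange(N) + 1.0)[:, None]
    return scale * legendre_basis(N, 2.0 * t - 1.0)


def project(f, N, num_nodes=64):
    """Coefficients c_n = integral_0^1 f(t) g_n(t) dt via Gauss-Legendre quadrature."""
    nodes, weights = np.polynomial.legendre.leggauss(num_nodes)
    t = jnp.asarray((nodes + 1.0) / 2.0)  # map nodes from [-1, 1] to [0, 1]
    w = jnp.asarray(weights / 2.0)  # Jacobian of the change of variables
    return scaled_legendre(N, t) @ (w * f(t))


def reconstruct(c, t):
    """f(t) ~= sum_n c_n g_n(t)."""
    return c @ scaled_legendre(c.shape[0], t)
```

A degree-2 polynomial is reproduced up to float32 rounding once `N >= 3`, which makes a convenient sanity check before comparing errors on random inputs.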

In [14]:
def test_reconstruction(
    the_measure="legs", lambda_n=1.0, alpha=0.5, discretization=0.5, print_all=False
):
    # N = 256
    # L = 128

    batch_size = 16
    data_size = 256
    input_size = 1

    N = 50
    L = data_size

    x_jnp = random_16_input(
        key_generator=subkeys[4],
        batch_size=batch_size,
        data_size=data_size,
        input_size=input_size,
    )
    x_np = np.asarray(x_jnp)

    x = torch.tensor(x_np, dtype=torch.float32)

    # ----------------------------------------------------------------------------------
    # ------------------------------ Instantiate Gu's HiPPOs ---------------------------
    # ----------------------------------------------------------------------------------

    print(f"Creating Gu's HiPPO-{the_measure} LTI model with {alpha} transform")
    gu_hippo_lti = gu_HiPPO_LTI(
        N=N,
        method=the_measure,
        dt=1.0,
        T=L,
        discretization=discretization,
        lambda_n=lambda_n,
        alpha=0.0,
        beta=1.0,
        c=0.0,
    )  # The Gu's

    if the_measure == "legs":
        print(f"Creating Gu's HiPPO-{the_measure} LSI model with {alpha} transform")
        gu_hippo_lsi = gu_HiPPO_LSI(
            N=N,
            method=the_measure,
            max_length=L,
            discretization=discretization,
            lambda_n=lambda_n,
            alpha=0.0,
            beta=1.0,
        )  # The Gu's

    # ----------------------------------------------------------------------------------
    # ------------------------------ Instantiate My HiPPOs -----------------------------
    # ----------------------------------------------------------------------------------
    print(f"\nTesting BRYANS HiPPO-{the_measure} model")

    matrices = TransMatrix(
        N=N,
        measure=the_measure,
        lambda_n=lambda_n,
        alpha=0.0,
        beta=1.0,
        dtype=jnp.float32,
    )

    A = matrices.A
    B = matrices.B

    print(f"Creating HiPPO-{the_measure} LTI model with {alpha} transform")
    hippo_lti = HiPPOLTI(
        N=N,
        step_size=1.0,
        lambda_n=lambda_n,
        alpha=0.0,
        beta=1.0,
        GBT_alpha=alpha,
        measure=the_measure,
        basis_size=L,
        dtype=jnp.float32,
        unroll=False,
    )  # Bryan's

    # hippo_lti = HiPPO(
    #     max_length=L,
    #     step_size=1.0,
    #     N=N,
    #     lambda_n=lambda_n,
    #     alpha=0.0,
    #     beta=1.0,
    #     GBT_alpha=alpha,
    #     measure=the_measure,
    #     s_t="lti",
    #     dtype=jnp.float32,
    #     unroll=False,
    # )  # Bryan's

    if the_measure == "legs":
        print(f"Creating HiPPO-{the_measure} LSI model with {alpha} transform")
        hippo_lsi = HiPPOLSI(
            N=N,
            max_length=L,
            step_size=1.0,
            lambda_n=lambda_n,
            alpha=0.0,
            beta=1.0,
            GBT_alpha=alpha,
            measure=the_measure,
            dtype=jnp.float32,
            unroll=False,
        )  # Bryan's
        # hippo_lsi = HiPPO(
        #     max_length=L,
        #     step_size=1.0,
        #     N=N,
        #     lambda_n=lambda_n,
        #     alpha=0.0,
        #     beta=1.0,
        #     GBT_alpha=alpha,
        #     measure=the_measure,
        #     s_t="lsi",
        #     dtype=jnp.float32,
        #     unroll=False,
        # )  # Bryan's

    # ----------------------------------------------------------------------------------
    # ------------------------------ Test HiPPO Operators ------------------------------
    # ----------------------------------------------------------------------------------

    print(f"Bryan's Coefficients for {alpha} LTI HiPPO-{the_measure}")

    test_hippo_reconstruction(
        hippo=hippo_lti,
        gu_hippo=gu_hippo_lti,
        random_input=x_np,
        key=subkeys[5],
        s_or_t="lti",
        print_all=print_all,
    )

    if the_measure == "legs":
        print(f"\n\nBryan's Coefficients for {alpha} LSI HiPPO-{the_measure}")

        test_hippo_reconstruction(
            hippo=hippo_lsi,
            gu_hippo=gu_hippo_lsi,
            random_input=x_np,
            key=subkeys[6],
            s_or_t="lsi",
            print_all=print_all,
        )

    print(f"end of test for HiPPO-{the_measure} model")

## Navigation To Table Of Contents
---
* [Table Of Contents](#table-of-contents)
* [Loading In Necessary Packages](#load-packages)
* [Instantiate The HiPPO Matrix](#instantiate-the-hippo-matrix)
* [Gu's Linear Time Invariant (LTI) HiPPO Operator](#gus-hippo-legt-operator)
* [Gu's Scale Invariant (LSI) HiPPO Operator](#gus-scale-invariant-hippo-legs-operator)
* [Implementation Of General HiPPO Operator](#implementation-of-general-hippo-operator)
* [Test Generalized Bilinear Transform and Zero Order Hold Matrices](#test-generalized-bilinear-transform-and-zero-order-hold-matrices)
* [Testing HiPPO Operators](#test-hippo-operators)
---

In [15]:
print_all = False

### Testing (LTI and LSI) Operators With Forward Euler Transform
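The three transform sections differ only in the GBT parameter $\alpha$ passed through the `alpha`/`discretization` arguments of `test_reconstruction`: $\alpha = 0$ is forward Euler, $\alpha = 1$ is backward Euler, and $\alpha = 0.5$ is the bilinear transform. A minimal sketch of the generalized bilinear transform, assuming the standard form $A_d = (I - \alpha\,\Delta t\,A)^{-1}(I + (1-\alpha)\Delta t\,A)$, $B_d = (I - \alpha\,\Delta t\,A)^{-1}\Delta t\,B$; `gbt` is a hypothetical helper, not the `src.models.hippo` API.

```python
import jax.numpy as jnp


def gbt(A, B, dt, alpha):
    """Discretize x' = A x + B u with the generalized bilinear transform.

    alpha = 0.0 -> forward Euler, 0.5 -> bilinear (Tustin), 1.0 -> backward Euler.
    """
    I = jnp.eye(A.shape[0], dtype=A.dtype)
    lhs = I - alpha * dt * A  # implicit part of the update
    # Solve rather than form an explicit inverse, for numerical stability
    Ad = jnp.linalg.solve(lhs, I + (1.0 - alpha) * dt * A)
    Bd = jnp.linalg.solve(lhs, dt * B)
    return Ad, Bd


# Small example system (arbitrary values, not a HiPPO matrix)
A = jnp.array([[-1.0, 0.0], [-2.0, -3.0]])
B = jnp.array([1.0, 2.0])
dt = 0.1

Ad_fe, Bd_fe = gbt(A, B, dt, alpha=0.0)  # forward Euler: I + dt A, dt B
Ad_be, Bd_be = gbt(A, B, dt, alpha=1.0)  # backward Euler: (I - dt A)^{-1}
Ad_bl, Bd_bl = gbt(A, B, dt, alpha=0.5)  # bilinear
```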

#### LegS

In [16]:
test_reconstruction(
    the_measure="legs", lambda_n=1.0, alpha=0.0, discretization=0.0, print_all=print_all
)

Creating Gu's HiPPO-legs LTI model with 0.0 transform
gu's vals: (256,)
gu's vals: (256,)
Creating Gu's HiPPO-legs LSI model with 0.0 transform

Testing BRYANS HiPPO-legs model
Creating HiPPO-legs LTI model with 0.0 transform
Creating HiPPO-legs LSI model with 0.0 transform
Bryan's Coeffiecients for 0.0 LTI HiPPO-legs
The Loss:
 [nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan]



The Loss:
 [nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan]



Bryan's Coeffiecients for 0.0 LSI HiPPO-legs
The Loss:
 [41.367485  18.531097  14.098469  30.345993  10.6641865  7.9826894  9.042797  63.674404  29.575207  22.450172   4.7379494 25.612314  15.653051
  8.503983  27.31075   44.781677 ]



The Loss:
 [40.91818   17.95709   13.634292  29.817364  10.300377   7.7191753  8.815974  63.2482    28.771494  22.10402    4.4629774 25.125237  14.982275
  8.256765  26.396832  43.420464 ]

end of test for HiPPO-legs model
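The `nan` losses for the LTI LegS operator are consistent with forward Euler being unstable at `step_size=1.0`. Assuming the standard HiPPO-LegS matrices ($A_{nk} = -\sqrt{2n+1}\sqrt{2k+1}$ for $n > k$, $A_{nn} = -(n+1)$, $0$ above the diagonal; $B_n = \sqrt{2n+1}$), $A$ is lower triangular, so $A_d = I + \Delta t\,A$ has eigenvalues $1 - (n+1)\Delta t$. With $\Delta t = 1$ and $N = 50$ the largest magnitude is 49, and 256 steps overflow float32. `legs_matrices` is a hypothetical helper, not `TransMatrix`.

```python
import jax.numpy as jnp


def legs_matrices(N):
    """HiPPO-LegS transition matrices in the standard (Gu et al.) convention."""
    q = jnp.arange(N, dtype=jnp.float32)
    r = jnp.sqrt(2.0 * q + 1.0)
    # -sqrt(2n+1) sqrt(2k+1) strictly below the diagonal, -(n+1) on it, 0 above
    A = -jnp.tril(r[:, None] * r[None, :], k=-1) - jnp.diag(q + 1.0)
    B = r
    return A, B


N, dt, L = 50, 1.0, 256
A, B = legs_matrices(N)
Ad = jnp.eye(N) + dt * A  # forward Euler (GBT with alpha = 0)

# Ad is lower triangular, so its eigenvalues are its diagonal entries 1 - (n + 1) dt
spectral_radius = jnp.max(jnp.abs(jnp.diag(Ad)))
print(spectral_radius)  # 49.0 for N = 50, dt = 1
print(L * jnp.log10(spectral_radius))  # log10 of 49**256, roughly 433

# Drive the recurrence with a constant input to show the blow-up numerically
c = jnp.zeros(N)
for _ in range(L):
    c = Ad @ c + dt * B
print(jnp.all(jnp.isfinite(c)))  # False once the growing modes overflow
```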


#### LegT

In [17]:
test_reconstruction(
    the_measure="legt", lambda_n=1.0, alpha=0.0, discretization=0.0, print_all=print_all
)

Creating Gu's HiPPO-legt LTI model with 0.0 transform

Testing BRYANS HiPPO-legt model
Creating HiPPO-legt LTI model with 0.0 transform
Bryan's Coeffiecients for 0.0 LTI HiPPO-legt


  return binary_op(*args)


The Loss:
 [nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan]



The Loss:
 [nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan]

end of test for HiPPO-legt model


#### LMU

In [18]:
test_reconstruction(
    the_measure="lmu", lambda_n=2.0, alpha=0.0, discretization=0.0, print_all=print_all
)

Creating Gu's HiPPO-lmu LTI model with 0.0 transform

Testing BRYANS HiPPO-lmu model
Creating HiPPO-lmu LTI model with 0.0 transform
Bryan's Coeffiecients for 0.0 LTI HiPPO-lmu
The Loss:
 [nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan]



The Loss:
 [nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan]

end of test for HiPPO-lmu model


#### LagT

In [19]:
test_reconstruction(
    the_measure="lagt", lambda_n=1.0, alpha=0.0, discretization=0.0, print_all=print_all
)

Creating Gu's HiPPO-lagt LTI model with 0.0 transform

Testing BRYANS HiPPO-lagt model
Creating HiPPO-lagt LTI model with 0.0 transform
Bryan's Coeffiecients for 0.0 LTI HiPPO-lagt
The Loss:
 [inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf]



The Loss:
 [2.7180708e+23 2.2506286e+26 4.5019242e+25 3.6529218e+25 1.0892659e+26 5.1358761e+25 2.0433805e+26 1.0698097e+24 2.2920454e+24 2.7865366e+24
 2.1316504e+25 9.0141790e+24 6.8558258e+24 3.8399939e+24 2.2452325e+25 1.1329689e+25]

end of test for HiPPO-lagt model


#### FRU

In [20]:
test_reconstruction(
    the_measure="fru", lambda_n=1.0, alpha=0.0, discretization=0.0, print_all=print_all
)

Creating Gu's HiPPO-fru LTI model with 0.0 transform

Testing BRYANS HiPPO-fru model
Creating HiPPO-fru LTI model with 0.0 transform
Bryan's Coeffiecients for 0.0 LTI HiPPO-fru
The Loss:
 [nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan]



The Loss:
 [nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan]

end of test for HiPPO-fru model


#### FouT

In [44]:
test_reconstruction(
    the_measure="fout", lambda_n=1.0, alpha=0.0, discretization=0.0, print_all=print_all
)

Creating Gu's HiPPO-fout LTI model with 0.0 transform

Testing BRYANS HiPPO-fout model
Creating HiPPO-fout LTI model with 0.0 transform
Bryan's Coeffiecients for 0.0 LTI HiPPO-fout
The Loss:
 [nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan]



The Loss:
 [nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan]

end of test for HiPPO-fout model


#### FouD

In [45]:
test_reconstruction(
    the_measure="foud", lambda_n=1.0, alpha=0.0, discretization=0.0, print_all=print_all
)

Creating Gu's HiPPO-foud LTI model with 0.0 transform

Testing BRYANS HiPPO-foud model
Creating HiPPO-foud LTI model with 0.0 transform
Bryan's Coeffiecients for 0.0 LTI HiPPO-foud
The Loss:
 [nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan]



The Loss:
 [nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan]

end of test for HiPPO-foud model


### Testing (LTI and LSI) Operators With Backward Euler Transform
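For LegS, backward Euler ($\alpha = 1$) should not blow up at `step_size=1.0`. Under the same standard-matrix assumption, $A_d = (I - \Delta t\,A)^{-1}$ is lower triangular with eigenvalues $1 / (1 + (n+1)\Delta t)$, all strictly inside the unit circle for any $\Delta t > 0$. That makes the recurrence asymptotically stable, although the non-normal $A$ can still cause transient growth. A quick check, using a hypothetical `legs_matrices` helper and `jax.scipy.linalg.solve_triangular` for the triangular solves:

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def legs_matrices(N):
    """HiPPO-LegS transition matrices in the standard (Gu et al.) convention."""
    q = jnp.arange(N, dtype=jnp.float32)
    r = jnp.sqrt(2.0 * q + 1.0)
    A = -jnp.tril(r[:, None] * r[None, :], k=-1) - jnp.diag(q + 1.0)
    return A, r


N, dt = 50, 1.0
A, B = legs_matrices(N)
I = jnp.eye(N)
lhs = I - dt * A  # lower triangular with diagonal 1 + (n + 1) dt

# Backward Euler (GBT with alpha = 1): Ad = (I - dt A)^{-1}, Bd = (I - dt A)^{-1} dt B
Ad = solve_triangular(lhs, I, lower=True)
Bd = solve_triangular(lhs, dt * B, lower=True)

# Ad is lower triangular, so its eigenvalues are its diagonal entries 1 / (1 + (n + 1) dt)
diag = jnp.diag(Ad)
print(jnp.max(jnp.abs(diag)))  # 0.5 for dt = 1 (the n = 0 mode)
```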

#### LegS

In [23]:
test_reconstruction(
    the_measure="legs", lambda_n=1.0, alpha=1.0, discretization=1.0, print_all=print_all
)

Creating Gu's HiPPO-legs LTI model with 1.0 transform
gu's vals: (256,)
gu's vals: (256,)
Creating Gu's HiPPO-legs LSI model with 1.0 transform

Testing BRYANS HiPPO-legs model
Creating HiPPO-legs LTI model with 1.0 transform
Creating HiPPO-legs LSI model with 1.0 transform
Bryan's Coeffiecients for 1.0 LTI HiPPO-legs
The Loss:
 [0.24920948 0.26120606 0.2225034  0.2562787  0.22515053 0.21028472 0.24650134 0.27222782 0.22471839 0.27827442 0.27449894 0.3005273  0.28741568
 0.24308945 0.31829214 0.28291494]



The Loss:
 [0.10874012 0.08415962 0.0876613  0.09463326 0.08717994 0.08624279 0.12696809 0.09238221 0.08174874 0.08340041 0.08126472 0.09182876 0.09582479
 0.09089237 0.09624433 0.08324574]



Bryan's Coeffiecients for 1.0 LSI HiPPO-legs
The Loss:
 [0.09170212 0.08274928 0.08844691 0.0866254  0.09163544 0.08870362 0.09487514 0.09269952 0.07820393 0.08769968 0.08195163 0.081138   0.097323
 0.09436389 0.10034208 0.09119291]



The Loss:
 [0.07773688 0.07255538 0.07355011 0.08291766 0.

#### LegT

In [24]:
test_reconstruction(
    the_measure="legt", lambda_n=1.0, alpha=1.0, discretization=1.0, print_all=print_all
)

Creating Gu's HiPPO-legt LTI model with 1.0 transform

Testing BRYANS HiPPO-legt model
Creating HiPPO-legt LTI model with 1.0 transform
Bryan's Coeffiecients for 1.0 LTI HiPPO-legt
The Loss:
 [0.35142815 0.3455534  0.3226121  0.37903214 0.31867814 0.30122083 0.3527272  0.34916523 0.31786817 0.35646066 0.3277972  0.3211748  0.3205662
 0.36499843 0.33480752 0.3240801 ]



The Loss:
 [0.35142815 0.3455534  0.3226121  0.37903214 0.31867814 0.30122083 0.3527272  0.34916523 0.31786817 0.35646066 0.3277972  0.3211748  0.3205662
 0.36499843 0.33480752 0.3240801 ]

end of test for HiPPO-legt model


#### LMU

In [46]:
test_reconstruction(
    the_measure="lmu", lambda_n=2.0, alpha=1.0, discretization=1.0, print_all=print_all
)

Creating Gu's HiPPO-lmu LTI model with 1.0 transform

Testing BRYANS HiPPO-lmu model
Creating HiPPO-lmu LTI model with 1.0 transform
Bryan's Coeffiecients for 1.0 LTI HiPPO-lmu
The Loss:
 [0.35142815 0.3455534  0.3226121  0.37903214 0.31867814 0.30122083 0.3527272  0.34916523 0.31786817 0.35646066 0.3277972  0.3211748  0.3205662
 0.36499843 0.33480752 0.3240801 ]



The Loss:
 [0.35142815 0.3455534  0.3226121  0.37903214 0.31867814 0.30122083 0.3527272  0.34916523 0.31786817 0.35646066 0.3277972  0.3211748  0.3205662
 0.36499843 0.33480752 0.3240801 ]

end of test for HiPPO-lmu model


#### LagT

In [47]:
test_reconstruction(
    the_measure="lagt", lambda_n=1.0, alpha=1.0, discretization=1.0, print_all=print_all
)

Creating Gu's HiPPO-lagt LTI model with 1.0 transform

Testing BRYANS HiPPO-lagt model
Creating HiPPO-lagt LTI model with 1.0 transform
Bryan's Coeffiecients for 1.0 LTI HiPPO-lagt
The Loss:
 [2.97866306e+27 7.69596023e+26 2.34778854e+27 2.34798733e+27 1.69587412e+27 1.89873532e+27 3.40461727e+27 7.15103529e+26 4.11259534e+26
 1.05265504e+27 1.34903733e+26 1.87793130e+26 7.50029257e+25 3.75032490e+27 4.09923281e+27 3.00623784e+27]



The Loss:
 [0.34614652 0.3436889  0.32018572 0.37510973 0.31538126 0.29790556 0.34649736 0.34762293 0.31610107 0.35278565 0.32734016 0.31865036 0.3183095
 0.35929114 0.33028978 0.3190096 ]

end of test for HiPPO-lagt model


#### FRU

In [48]:
test_reconstruction(
    the_measure="fru", lambda_n=1.0, alpha=1.0, discretization=1.0, print_all=print_all
)

Creating Gu's HiPPO-fru LTI model with 1.0 transform

Testing BRYANS HiPPO-fru model
Creating HiPPO-fru LTI model with 1.0 transform
Bryan's Coeffiecients for 1.0 LTI HiPPO-fru
The Loss:
 [0.35103828 0.3443796  0.3208844  0.37567043 0.31828076 0.29978725 0.35313404 0.3464237  0.31523266 0.35390437 0.32718375 0.32085764 0.31820312
 0.36389875 0.3335151  0.32449648]



The Loss:
 [0.35103828 0.3443796  0.3208844  0.37567043 0.31828076 0.29978725 0.35313404 0.3464237  0.31523266 0.35390437 0.32718375 0.32085764 0.31820312
 0.36389875 0.3335151  0.3244965 ]

end of test for HiPPO-fru model


#### FouT

In [49]:
test_reconstruction(
    the_measure="fout", lambda_n=1.0, alpha=1.0, discretization=1.0, print_all=print_all
)

Creating Gu's HiPPO-fout LTI model with 1.0 transform

Testing BRYANS HiPPO-fout model
Creating HiPPO-fout LTI model with 1.0 transform
Bryan's Coeffiecients for 1.0 LTI HiPPO-fout
The Loss:
 [0.35106996 0.344378   0.32087767 0.37567848 0.31823134 0.29978877 0.3531564  0.34632966 0.31542528 0.35397243 0.3271905  0.32091826 0.31893167
 0.36397302 0.33368683 0.32445395]



The Loss:
 [0.35106996 0.344378   0.3208777  0.37567848 0.31823134 0.29978877 0.3531564  0.34632966 0.31542528 0.35397243 0.3271905  0.32091826 0.31893167
 0.36397302 0.3336868  0.32445395]

end of test for HiPPO-fout model


#### FouD

In [50]:
test_reconstruction(
    the_measure="foud", lambda_n=1.0, alpha=1.0, discretization=1.0, print_all=print_all
)

Creating Gu's HiPPO-foud LTI model with 1.0 transform

Testing BRYANS HiPPO-foud model
Creating HiPPO-foud LTI model with 1.0 transform
Bryan's Coeffiecients for 1.0 LTI HiPPO-foud
The Loss:
 [0.3510547  0.34435478 0.3209077  0.37571317 0.3182457  0.29978248 0.35311943 0.34622428 0.31510615 0.35376358 0.32711208 0.32100677 0.3177272
 0.36386395 0.3333422  0.32429516]



The Loss:
 [0.3510547  0.34435478 0.3209077  0.37571317 0.3182457  0.29978248 0.35311943 0.34622425 0.31510615 0.35376355 0.32711208 0.32100677 0.3177272
 0.36386395 0.3333422  0.32429516]

end of test for HiPPO-foud model


### Testing (LTI and LSI) Operators With Bidirectional Transform
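With $\alpha = 0.5$ the GBT is the bilinear (Tustin) transform, $A_d = (I - \tfrac{\Delta t}{2}A)^{-1}(I + \tfrac{\Delta t}{2}A)$. Each continuous eigenvalue $\lambda$ maps to $(1 + \lambda\Delta t/2)/(1 - \lambda\Delta t/2)$, which lies strictly inside the unit circle whenever $\mathrm{Re}\,\lambda < 0$. A sketch checking this on the standard HiPPO-LegS matrices (assumed, not read from `TransMatrix`; `legs_matrices` is a hypothetical helper):

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def legs_matrices(N):
    """HiPPO-LegS transition matrices in the standard (Gu et al.) convention."""
    q = jnp.arange(N, dtype=jnp.float32)
    r = jnp.sqrt(2.0 * q + 1.0)
    A = -jnp.tril(r[:, None] * r[None, :], k=-1) - jnp.diag(q + 1.0)
    return A, r


N, dt = 50, 1.0
A, B = legs_matrices(N)
I = jnp.eye(N)
# Bilinear: Ad = (I - dt/2 A)^{-1} (I + dt/2 A); both factors are lower triangular
Ad = solve_triangular(I - 0.5 * dt * A, I + 0.5 * dt * A, lower=True)

# Eigenvalues of a triangular matrix are its diagonal; compare with the Tustin map of lambda = -(n + 1)
lam = -(jnp.arange(N) + 1.0)
tustin = (1.0 + 0.5 * dt * lam) / (1.0 - 0.5 * dt * lam)
print(jnp.max(jnp.abs(jnp.diag(Ad))))
```

As $|\lambda|\Delta t$ grows, the mapped eigenvalue approaches $-1$, so the fastest LegS modes stay bounded but are only weakly damped and flip sign every step.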

#### LegS

In [51]:
test_reconstruction(
    the_measure="legs", lambda_n=1.0, alpha=0.5, discretization=0.5, print_all=print_all
)

Creating Gu's HiPPO-legs LTI model with 0.5 transform
gu's vals: (256,)
gu's vals: (256,)
Creating Gu's HiPPO-legs LSI model with 0.5 transform

Testing BRYANS HiPPO-legs model
Creating HiPPO-legs LTI model with 0.5 transform
Creating HiPPO-legs LSI model with 0.5 transform
Bryan's Coeffiecients for 0.5 LTI HiPPO-legs
The Loss:
 [ 40.985847    0.5634093   2.8889482   5.893869   12.014208   48.11052    36.70481    26.736866    2.832956   96.23154     0.7655915  77.20819
   2.3401718  34.16419   103.69086     9.367006 ]



The Loss:
 [1.0503573  1.2949177  0.14021397 0.3582669  1.0865629  0.64516515 0.50587565 0.39670065 0.8277695  2.0997803  0.21040663 2.4856312  0.22027926
 2.7459843  1.2234805  0.25088328]



Bryan's Coeffiecients for 0.5 LSI HiPPO-legs
The Loss:
 [0.10447852 0.09750983 0.1006453  0.09969212 0.10920918 0.09401403 0.10654827 0.09686814 0.09675886 0.11329925 0.08927324 0.09767537 0.11696188
 0.10941317 0.1298894  0.12789543]



The Loss:
 [0.07221566 0.06683332 0.068980

#### LegT

In [52]:
test_reconstruction(
    the_measure="legt", lambda_n=1.0, alpha=0.5, discretization=0.5, print_all=print_all
)

Creating Gu's HiPPO-legt LTI model with 0.5 transform

Testing BRYANS HiPPO-legt model
Creating HiPPO-legt LTI model with 0.5 transform
Bryan's Coeffiecients for 0.5 LTI HiPPO-legt
The Loss:
 [0.35142815 0.3455534  0.3226121  0.37903214 0.31867814 0.30122083 0.3527272  0.34916523 0.31786817 0.35646066 0.3277972  0.3211748  0.3205662
 0.36499843 0.33480752 0.3240801 ]



The Loss:
 [0.35142815 0.3455534  0.3226121  0.37903214 0.31867814 0.30122083 0.3527272  0.34916523 0.31786817 0.35646066 0.3277972  0.3211748  0.3205662
 0.36499843 0.33480752 0.3240801 ]

end of test for HiPPO-legt model


#### LMU

In [53]:
test_reconstruction(
    the_measure="lmu", lambda_n=2.0, alpha=0.5, discretization=0.5, print_all=print_all
)

Creating Gu's HiPPO-lmu LTI model with 0.5 transform

Testing BRYANS HiPPO-lmu model
Creating HiPPO-lmu LTI model with 0.5 transform
Bryan's Coeffiecients for 0.5 LTI HiPPO-lmu
The Loss:
 [0.35142815 0.3455534  0.3226121  0.37903214 0.31867814 0.30122083 0.3527272  0.34916523 0.31786817 0.35646066 0.3277972  0.3211748  0.3205662
 0.36499843 0.33480752 0.3240801 ]



The Loss:
 [0.35142815 0.3455534  0.3226121  0.37903214 0.31867814 0.30122083 0.3527272  0.34916523 0.31786817 0.35646066 0.3277972  0.3211748  0.3205662
 0.36499843 0.33480752 0.3240801 ]

end of test for HiPPO-lmu model


#### LagT

In [33]:
test_reconstruction(
    the_measure="lagt", lambda_n=1.0, alpha=0.5, discretization=0.5, print_all=print_all
)

Creating Gu's HiPPO-lagt LTI model with 0.5 transform

Testing BRYANS HiPPO-lagt model
Creating HiPPO-lagt LTI model with 0.5 transform
Bryan's Coeffiecients for 0.5 LTI HiPPO-lagt
The Loss:
 [          inf           inf           inf 5.0527295e+34 4.5679888e+34           inf           inf           inf 4.5509465e+35 3.0879517e+35
 3.5372083e+35           inf           inf           inf           inf           inf]



The Loss:
 [0.3595016  0.3442213  0.37153238 0.3943251  0.33519942 0.30110395 0.44138414 0.35445568 0.37508082 0.35947025 0.34108108 0.31860203 0.32243496
 0.362264   0.3312602  0.31949446]

end of test for HiPPO-lagt model


#### FRU

In [34]:
test_reconstruction(
    the_measure="fru", lambda_n=1.0, alpha=0.5, discretization=0.5, print_all=print_all
)

Creating Gu's HiPPO-fru LTI model with 0.5 transform

Testing BRYANS HiPPO-fru model
Creating HiPPO-fru LTI model with 0.5 transform
Bryan's Coeffiecients for 0.5 LTI HiPPO-fru
The Loss:
 [0.35118264 0.38082975 0.32966015 0.37876427 0.32509214 0.31076282 0.35634804 0.3444245  0.32797927 0.35355532 0.3306821  0.3208621  0.3228538
 0.3640939  0.3412339  0.34140104]



The Loss:
 [0.3511826  0.3808303  0.32966012 0.3787643  0.32509232 0.31076315 0.35634816 0.3444245  0.32797945 0.35355532 0.33068204 0.3208621  0.32285362
 0.36409396 0.34123397 0.34140077]

end of test for HiPPO-fru model


#### FouT

In [35]:
test_reconstruction(
    the_measure="fout", lambda_n=1.0, alpha=0.5, discretization=0.5, print_all=print_all
)

Creating Gu's HiPPO-fout LTI model with 0.5 transform

Testing BRYANS HiPPO-fout model
Creating HiPPO-fout LTI model with 0.5 transform
Bryan's Coeffiecients for 0.5 LTI HiPPO-fout
The Loss:
 [0.3535667  0.35146195 0.35392866 0.37577248 0.32537478 0.342341   0.35267413 0.34554935 0.33443934 0.35393754 0.3432759  0.32086352 0.34556466
 0.39044106 0.35431844 0.35615224]



The Loss:
 [0.35356656 0.35146213 0.35392857 0.37577245 0.3253752  0.34234062 0.3526743  0.34554917 0.33443916 0.3539375  0.3432762  0.32086352 0.3455655
 0.39044186 0.35431787 0.35615334]

end of test for HiPPO-fout model


#### FouD

In [36]:
test_reconstruction(
    the_measure="foud", lambda_n=1.0, alpha=0.5, discretization=0.5, print_all=print_all
)

Creating Gu's HiPPO-foud LTI model with 0.5 transform

Testing BRYANS HiPPO-foud model
Creating HiPPO-foud LTI model with 0.5 transform
Bryan's Coeffiecients for 0.5 LTI HiPPO-foud
The Loss:
 [0.35069418 0.35698712 0.32436812 0.3781004  0.31943154 0.30363774 0.35764936 0.34441382 0.3172199  0.3541609  0.32780588 0.32111537 0.3175849
 0.3674386  0.33515513 0.33352256]



The Loss:
 [0.35069418 0.356987   0.32436806 0.37810043 0.31943154 0.30363765 0.35764927 0.34441385 0.31722    0.35416102 0.32780588 0.3211154  0.31758484
 0.36743852 0.3351551  0.33352235]

end of test for HiPPO-foud model
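
Forward Euler, backward Euler and the bilinear transform tested in this notebook are special cases of the generalized bilinear transform (GBT) with α = 0, 1 and 0.5. We assume the `alpha=0.5` cells above correspond to the bilinear case. A minimal NumPy sketch follows, checked on a toy 2-by-2 system rather than a HiPPO matrix; `gbt` is a hypothetical helper.

```python
import numpy as np


def gbt(A, B, dt, alpha):
    """Generalized bilinear transform of dx/dt = A x + B u.

    dA = (I - alpha*dt*A)^{-1} (I + (1 - alpha)*dt*A)
    dB = (I - alpha*dt*A)^{-1} dt*B
    alpha = 0 -> forward Euler, alpha = 1 -> backward Euler, alpha = 0.5 -> bilinear.
    """
    I = np.eye(A.shape[0])
    M = I - alpha * dt * A
    dA = np.linalg.solve(M, I + (1.0 - alpha) * dt * A)
    dB = np.linalg.solve(M, dt * B)
    return dA, dB


# Toy stable 2-state system (not a HiPPO matrix), only to check the special cases.
A = np.array([[-1.0, 0.0], [2.0, -3.0]])
B = np.array([[1.0], [1.0]])
dt = 0.1
I = np.eye(2)

fe_A, fe_B = gbt(A, B, dt, alpha=0.0)  # forward Euler
be_A, be_B = gbt(A, B, dt, alpha=1.0)  # backward Euler
bl_A, bl_B = gbt(A, B, dt, alpha=0.5)  # bilinear (Tustin)

print(np.allclose(fe_A, I + dt * A), np.allclose(fe_B, dt * B))  # True True
print(np.allclose(be_A, np.linalg.inv(I - dt * A)))              # True
```

Solving with `np.linalg.solve` instead of forming `(I - α dt A)^{-1}` explicitly is the usual, numerically safer choice.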


### Testing (LTI and LSI) Operators With ZOH Transform
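
Zero-order hold (ZOH) assumes the input is constant over each step and integrates the ODE exactly: `dA = exp(dt A)` and `dB = A^{-1}(exp(dt A) - I) B`. Below is a sketch using SciPy's `scipy.linalg.expm`. `zoh` is a hypothetical helper; it uses the augmented-matrix form so that `A` never has to be inverted.

```python
import numpy as np
from scipy.linalg import expm


def zoh(A, B, dt):
    """Zero-order hold: dA = exp(dt*A), dB = A^{-1}(exp(dt*A) - I) B.

    exp(dt * [[A, B], [0, 0]]) = [[dA, dB], [0, I]], so one matrix exponential
    of the augmented matrix yields both blocks without inverting A.
    """
    N, M = B.shape
    aug = np.zeros((N + M, N + M))
    aug[:N, :N] = A
    aug[:N, N:] = B
    E = expm(aug * dt)
    return E[:N, :N], E[:N, N:]


# Diagonal check against the closed form for each scalar mode.
a = np.array([-1.0, -2.5, -4.0])
A = np.diag(a)
B = np.ones((3, 1))
dt = 0.3
dA, dB = zoh(A, B, dt)
print(np.allclose(dA, np.diag(np.exp(a * dt))))          # True
print(np.allclose(dB[:, 0], (np.exp(a * dt) - 1.0) / a))  # True
```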

#### LegS

In [37]:
test_reconstruction(
    the_measure="legs",
    lambda_n=1.0,
    alpha=2.0,
    discretization="zoh",
    print_all=print_all,
)

Creating Gu's HiPPO-legs LTI model with 2.0 transform
gu's vals: (256,)
gu's vals: (256,)
Creating Gu's HiPPO-legs LSI model with 2.0 transform

Testing BRYANS HiPPO-legs model
Creating HiPPO-legs LTI model with 2.0 transform
Creating HiPPO-legs LSI model with 2.0 transform
Bryan's Coeffiecients for 2.0 LTI HiPPO-legs
The Loss:
 [0.24887547 0.26386017 0.22381255 0.2577321  0.23089343 0.2123164  0.24520826 0.27531803 0.21867877 0.29000908 0.274537   0.2946152  0.28397858
 0.2402501  0.34515768 0.29902413]



The Loss:
 [0.2927326  0.12845896 0.165967   0.10037367 0.21014486 0.278264   0.3281554  0.09506361 0.12716429 0.08225903 0.11849605 0.090541   0.10785306
 0.17675102 0.12504598 0.09001569]



Bryan's Coeffiecients for 2.0 LSI HiPPO-legs
The Loss:
 [0.10172817 0.09642398 0.09994817 0.0996581  0.10897219 0.09424824 0.1072233  0.09650478 0.092353   0.09804132 0.08860762 0.09561859 0.11546683
 0.10787295 0.11346029 0.11233236]



The Loss:
 [0.06824626 0.06779923 0.0704189  0.07547039 
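
The LSI cells differ from the LTI ones because LegS solves `dc/dt = (1/t) A c + (1/t) B f`, so the effective step shrinks as the history grows. Below is a sketch of the bilinear LSI recurrence with step `dt = 1/k` at step `k`, starting from `c = 0` at `k = 1`. It assumes the LegS construction we recall from Gu et al.'s reference code, in its negative-sign convention. `legs_matrices` and `lsi_scan` are hypothetical helpers.

```python
import numpy as np


def legs_matrices(N):
    """Hypothetical LegS (A, B), negative-sign convention, as we recall Gu's reference code."""
    q = np.arange(N)
    r = np.sqrt(2 * q + 1.0)
    n, k = np.indices((N, N))
    # A[n, k] = -sqrt(2n+1) sqrt(2k+1) below the diagonal, -(n+1) on it, 0 above.
    A = np.where(n > k, -r[:, None] * r[None, :], 0.0) - np.diag(q + 1.0)
    B = r[:, None]
    return A, B


def lsi_scan(A, B, f, alpha=0.5):
    """Scale-invariant GBT recurrence: step k (k = 1, 2, ...) uses dt = 1/k."""
    N = A.shape[0]
    I = np.eye(N)
    c = np.zeros(N)
    for k, f_k in enumerate(f, start=1):
        dt = 1.0 / k
        M = I - alpha * dt * A
        c = np.linalg.solve(M, (I + (1.0 - alpha) * dt * A) @ c + dt * B[:, 0] * f_k)
    return c


N, K = 8, 5000
A, B = legs_matrices(N)
c = lsi_scan(A, B, np.ones(K))
# A constant input projects onto the degree-0 Legendre polynomial only, so c -> e_0.
# In the bilinear case the 0th coefficient's error after K steps is exactly 1/(2K + 1).
print(c[0], 1 - 1 / (2 * K + 1))  # the two values agree
```

With this normalization `A e_0 = -B`, so `e_0` is an exact fixed point of every step, and the 0th coefficient's error shrinks by a factor `(2k - 1)/(2k + 1)` at step `k`.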

#### LegT

In [38]:
test_reconstruction(
    the_measure="legt",
    lambda_n=1.0,
    alpha=2.0,
    discretization="zoh",
    print_all=print_all,
)

Creating Gu's HiPPO-legt LTI model with 2.0 transform

Testing BRYANS HiPPO-legt model
Creating HiPPO-legt LTI model with 2.0 transform
Bryan's Coeffiecients for 2.0 LTI HiPPO-legt
The Loss:
 [0.35142815 0.3455534  0.3226121  0.37903214 0.31867814 0.30122083 0.3527272  0.34916523 0.31786817 0.35646066 0.3277972  0.3211748  0.3205662
 0.36499843 0.33480752 0.3240801 ]



The Loss:
 [0.35142815 0.3455534  0.3226121  0.37903214 0.31867814 0.30122083 0.3527272  0.34916523 0.31786817 0.35646066 0.3277972  0.3211748  0.3205662
 0.36499843 0.33480752 0.3240801 ]

end of test for HiPPO-legt model


#### LMU

In [39]:
test_reconstruction(
    the_measure="lmu",
    lambda_n=2.0,
    alpha=2.0,
    discretization="zoh",
    print_all=print_all,
)

Creating Gu's HiPPO-lmu LTI model with 2.0 transform

Testing BRYANS HiPPO-lmu model
Creating HiPPO-lmu LTI model with 2.0 transform
Bryan's Coeffiecients for 2.0 LTI HiPPO-lmu
The Loss:
 [0.35142815 0.3455534  0.3226121  0.37903214 0.31867814 0.30122083 0.3527272  0.34916523 0.31786817 0.35646066 0.3277972  0.3211748  0.3205662
 0.36499843 0.33480752 0.3240801 ]



The Loss:
 [0.35142815 0.3455534  0.3226121  0.37903214 0.31867814 0.30122083 0.3527272  0.34916523 0.31786817 0.35646066 0.3277972  0.3211748  0.3205662
 0.36499843 0.33480752 0.3240801 ]

end of test for HiPPO-lmu model


#### LagT

In [40]:
test_reconstruction(
    the_measure="lagt",
    lambda_n=1.0,
    alpha=2.0,
    discretization="zoh",
    print_all=print_all,
)

Creating Gu's HiPPO-lagt LTI model with 2.0 transform

Testing BRYANS HiPPO-lagt model
Creating HiPPO-lagt LTI model with 2.0 transform
Bryan's Coeffiecients for 2.0 LTI HiPPO-lagt
The Loss:
 [9.1217859e+34 5.7362997e+33 1.0018595e+35 2.5804589e+35 2.4022569e+35 1.4176100e+34 7.2374545e+34 4.2535694e+34 6.4100862e+34 6.5418139e+34
 5.6368510e+33 9.4797193e+32 2.7605731e+34 2.0170646e+35 5.3546453e+34 2.1092401e+35]



The Loss:
 [0.34603003 0.3435721  0.32015926 0.37499648 0.3152003  0.2978233  0.3463508  0.3474422  0.31605712 0.35251835 0.32731292 0.31842488 0.31799686
 0.3592921  0.3303862  0.31895453]

end of test for HiPPO-lagt model


#### FRU

In [41]:
test_reconstruction(
    the_measure="fru",
    lambda_n=1.0,
    alpha=2.0,
    discretization="zoh",
    print_all=print_all,
)

Creating Gu's HiPPO-fru LTI model with 2.0 transform

Testing BRYANS HiPPO-fru model
Creating HiPPO-fru LTI model with 2.0 transform
Bryan's Coeffiecients for 2.0 LTI HiPPO-fru
The Loss:
 [0.35110626 0.34437096 0.32087862 0.3758601  0.31798816 0.299774   0.35316494 0.3458764  0.3154822  0.35367787 0.32720894 0.32093203 0.31910044
 0.36382112 0.3333275  0.323862  ]



The Loss:
 [0.35110626 0.34437096 0.3208786  0.3758601  0.31798816 0.299774   0.3531649  0.3458764  0.3154822  0.35367787 0.32720894 0.32093203 0.31910044
 0.36382112 0.3333275  0.323862  ]

end of test for HiPPO-fru model


#### FouT

In [42]:
test_reconstruction(
    the_measure="fout",
    lambda_n=1.0,
    alpha=2.0,
    discretization="zoh",
    print_all=print_all,
)

Creating Gu's HiPPO-fout LTI model with 2.0 transform

Testing BRYANS HiPPO-fout model
Creating HiPPO-fout LTI model with 2.0 transform
Bryan's Coeffiecients for 2.0 LTI HiPPO-fout
The Loss:
 [0.35099867 0.34436995 0.3208948  0.375654   0.31843498 0.29979926 0.35301557 0.34657085 0.3154192  0.3542355  0.32717574 0.32098946 0.31938225
 0.36411223 0.33431405 0.3247923 ]



The Loss:
 [0.35099867 0.34436995 0.3208948  0.375654   0.31843498 0.29979926 0.35301557 0.34657085 0.3154192  0.3542355  0.32717574 0.32098946 0.31938225
 0.36411223 0.33431405 0.3247923 ]

end of test for HiPPO-fout model


#### FouD

In [43]:
test_reconstruction(
    the_measure="foud",
    lambda_n=1.0,
    alpha=2.0,
    discretization="zoh",
    print_all=print_all,
)

Creating Gu's HiPPO-foud LTI model with 2.0 transform

Testing BRYANS HiPPO-foud model
Creating HiPPO-foud LTI model with 2.0 transform
Bryan's Coeffiecients for 2.0 LTI HiPPO-foud
The Loss:
 [0.35119998 0.34435833 0.32090923 0.375829   0.31804895 0.2997815  0.35316664 0.34592736 0.31527933 0.35360616 0.32716495 0.32088467 0.31803918
 0.36382782 0.33322173 0.3239077 ]



The Loss:
 [0.35119998 0.34435833 0.32090923 0.375829   0.31804898 0.2997815  0.35316664 0.34592736 0.31527933 0.35360616 0.32716495 0.32088467 0.31803915
 0.36382782 0.33322173 0.32390773]

end of test for HiPPO-foud model
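
Under ZOH a continuous eigenvalue λ maps to `exp(λ dt)`. The bilinear transform maps it to `(1 + λ dt/2)/(1 - λ dt/2)`, the (1, 1) Padé approximant of the exponential. The sketch below measures the gap for a single, arbitrarily chosen eigenvalue `λ = -1`; `bilinear_scalar` is a hypothetical helper.

```python
import numpy as np


def bilinear_scalar(lam, dt):
    """Bilinear (Tustin) image of one continuous-time eigenvalue."""
    return (1.0 + lam * dt / 2.0) / (1.0 - lam * dt / 2.0)


lam = -1.0  # arbitrary stable eigenvalue
dts = [0.02, 0.01, 0.005]
errs = [abs(np.exp(lam * dt) - bilinear_scalar(lam, dt)) for dt in dts]
ratios = [errs[i] / errs[i + 1] for i in range(len(errs) - 1)]

for dt, err in zip(dts, errs):
    print(f"dt={dt:<6} |exp(lam*dt) - bilinear| = {err:.3e}")
print("ratios when halving dt:", [round(r, 2) for r in ratios])
```

Halving `dt` cuts the gap by roughly 8x, so the two maps agree to O(dt³). Both also send any eigenvalue with Re λ < 0 inside the unit circle.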
