# HiPPO Matrices
---

## Table of Contents
* [Load Packages](#load-packages)
* [Instantiate The HiPPO Matrix](#instantiate-the-hippo-matrix)
    * [Translated Legendre (LegT)](#translated-legendre-legt)
        * [LegT](#legt)
        * [LMU](#lmu)
    * [Translated Laguerre (LagT)](#translated-laguerre-lagt)
    * [Scaled Legendre (LegS)](#scaled-legendre-legs)
    * [Fourier Basis](#fourier-basis)
        * [Fourier Recurrent Unit (FRU)](#fourier-recurrent-unit-fru)
        * [Truncated Fourier (FouT)](#truncated-fourier-fout)
        * [Fourier With Decay (FourD)](#fourier-with-decay-fourd)
* [Gu's HiPPO LegT Operator](#gus-hippo-legt-operator)
* [Gu's Scale invariant HiPPO LegS Operator](#gus-scale-invariant-hippo-legs-operator)
* [Implementation Of General HiPPO Operator](#implementation-of-general-hippo-operator)
* [Output](#output)
---


## Load Packages

In [1]:
import os
import sys
module_path = os.path.abspath(os.path.join('../../../'))
if module_path not in sys.path:
    sys.path.append(module_path)

In [2]:
## import packages
import jax
import jax.numpy as jnp

from flax import linen as jnn

from jax.nn.initializers import lecun_normal, uniform
from jax.numpy.linalg import eig, inv, matrix_power
from jax.scipy.signal import convolve

# import modules 
from src.models.hippo.gu_transition import GuTransMatrix
from src.models.hippo.unroll import (
    measure,
    basis,
    variable_unroll_matrix,
    variable_unroll_matrix_sequential,
)
from src.data.process import moving_window, rolling_window


import requests

from scipy import linalg as la
from scipy import signal
from scipy import special as ss

import math

print(jax.devices())
print(f"The Device: {jax.default_backend()}")

[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]
The Device: gpu


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data

import numpy as np
import matplotlib.pyplot as plt
from einops import rearrange, repeat, reduce

from typing import Any
from functools import partial

print(f"MPS enabled: {torch.backends.mps.is_available()}")


MPS enabled: False


In [4]:
torch.set_printoptions(linewidth=150)
np.set_printoptions(linewidth=150)
jnp.set_printoptions(linewidth=150)

In [5]:
# N = 8


In [6]:
seed = 1701
key = jax.random.PRNGKey(seed)


In [7]:
num_copies = 5
rng, key2, key3, key4, key5 = jax.random.split(key, num=num_copies)


## Instantiate The HiPPO Matrix

In [8]:
class TransMatrix:
    def __init__(
        self,
        N: int,
        measure: str = "legs",
        lambda_n: float = 1.0,
        alpha: float = 0.0,
        beta: float = 1.0,
        dtype: Any = jnp.float32,
    ):
        """
        Instantiates the HiPPO matrices of a given order using a particular measure.

        Args:
            N (int): Order of coefficients to describe the orthogonal polynomial that is the HiPPO projection.
            measure (str):
                choose between
                    - HiPPO w/ Translated Legendre (LegT) - legt
                      ('lmu' is accepted as an alias; the LMU tilt itself is selected with lambda_n=2.0)
                    - HiPPO w/ Translated Laguerre (LagT) - lagt
                    - HiPPO w/ Scaled Legendre (LegS) - legs
                    - HiPPO w/ Fourier basis
                        - FRU: Fourier Recurrent Unit - fru
                        - FouT: Translated Fourier - fout
                        - FouD: Fourier Decay - foud
                    - random matrices - random
                    - random diagonal matrix - diagonal
            lambda_n (float): The tilt applied to the HiPPO-LegT basis; 1.0 gives LegT and 2.0 gives LMU.
            alpha (float): The order of the Laguerre basis.
            beta (float): The scale of the Laguerre basis.
            dtype (Any): The data type of the resulting matrices.

        Attributes:
            A (jnp.ndarray): The HiPPO matrix multiplied by -1.
            B (jnp.ndarray): The other corresponding state space matrix.

        """
        A = None
        B = None
        if measure in ["legt", "lmu"]:
            A, B = self.build_LegT(N=N, lambda_n=lambda_n, dtype=dtype)

        elif measure == "lagt":
            A, B = self.build_LagT(alpha=alpha, beta=beta, N=N, dtype=dtype)

        elif measure == "legs":
            A, B = self.build_LegS(N=N, dtype=dtype)

        elif measure in ["fout", "fru", "foud"]:
            A, B = self.build_Fourier(N=N, fourier_type=measure, dtype=dtype)

        elif measure == "random":
            # jax.numpy has no `random` submodule, so NumPy's global RNG
            # (numpy imported as np in cell [3]) is used here
            A = np.random.randn(N, N) / N
            B = np.random.randn(N, 1)

        elif measure == "diagonal":
            A = -np.diag(np.exp(np.random.randn(N)))
            B = np.random.randn(N, 1)

        else:
            raise ValueError(f"Invalid HiPPO measure: {measure!r}")

        self.A = jnp.asarray(A, dtype=dtype)
        self.B = jnp.asarray(B, dtype=dtype)

    # Translated Legendre (LegT) - vectorized
    @staticmethod
    def build_LegT(N, lambda_n=1, dtype=jnp.float32):
        """
        Vectorized implementation of the measure derived from the translated Legendre basis.

        Args:
            N (int): Order of coefficients to describe the orthogonal polynomial that is the HiPPO projection.
            lambda_n (float): Choice between the two different tilts of the basis.
                - 1: translated Legendre (LegT)
                - 2: Legendre Memory Unit (LMU)
            dtype (Any): The data type of the resulting matrices.

        Returns:
            A (jnp.ndarray): The A HiPPO matrix.
            B (jnp.ndarray): The B HiPPO matrix.

        """
        q = jnp.arange(N, dtype=dtype)
        k, n = jnp.meshgrid(q, q)
        case = jnp.power(-1.0, (n - k))
        A = None
        B = None

        if lambda_n == 1:
            A_base = jnp.sqrt(2 * n + 1) * jnp.sqrt(2 * k + 1)
            pre_D = jnp.sqrt(jnp.diag(2 * q + 1))
            B = D = jnp.diag(pre_D)[:, None]
            A = jnp.where(
                k <= n, A_base, A_base * case
            )  # if n >= k, A_base is used, otherwise A_base * (-1)^(n - k)

        elif lambda_n == 2:  # LMU tilt
            A_base = 2 * n + 1
            B = jnp.diag((2 * q + 1) * jnp.power(-1, n))[:, None]
            A = jnp.where(
                k <= n, A_base * case, A_base
            )  # if n >= k, A_base * (-1)^(n - k) is used, otherwise A_base

        else:
            raise ValueError(f"lambda_n must be 1 (LegT) or 2 (LMU), got {lambda_n}")

        return -A.astype(dtype), B.astype(dtype)

    # Translated Laguerre (LagT) - vectorized
    @staticmethod
    def build_LagT(alpha, beta, N, dtype=jnp.float32):
        """
        Vectorized implementation of the measure derived from the translated Laguerre basis.

        Args:
            alpha (float): The order of the Laguerre basis.
            beta (float): The scale of the Laguerre basis.
            N (int): Order of coefficients to describe the orthogonal polynomial that is the HiPPO projection.

        Returns:
            A (jnp.ndarray): The A HiPPO matrix.
            B (jnp.ndarray): The B HiPPO matrix.

        """
        L = jnp.exp(
            0.5
            * (ss.gammaln(jnp.arange(N) + alpha + 1) - ss.gammaln(jnp.arange(N) + 1))
        )
        inv_L = 1.0 / L[:, None]
        pre_A = (jnp.eye(N) * ((1 + beta) / 2)) + jnp.tril(jnp.ones((N, N)), -1)
        pre_B = ss.binom(alpha + jnp.arange(N), jnp.arange(N))[:, None]

        A = -inv_L * pre_A * L[None, :]
        B = (
            jnp.exp(-0.5 * ss.gammaln(1 - alpha))
            * jnp.power(beta, (1 - alpha) / 2)
            * inv_L
            * pre_B
        )

        return A.astype(dtype), B.astype(dtype)

    # Scaled Legendre (LegS) vectorized
    @staticmethod
    def build_LegS(N, dtype=jnp.float32):
        """
        Vectorized implementation of the measure derived from the Scaled Legendre basis.

        Args:
            N (int): Order of coefficients to describe the orthogonal polynomial that is the HiPPO projection.

        Returns:
            A (jnp.ndarray): The A HiPPO matrix.
            B (jnp.ndarray): The B HiPPO matrix.

        """
        q = jnp.arange(N, dtype=dtype)
        k, n = jnp.meshgrid(q, q)
        pre_D = jnp.sqrt(jnp.diag(2 * q + 1))
        B = D = jnp.diag(pre_D)[:, None]

        A_base = jnp.sqrt(2 * n + 1) * jnp.sqrt(2 * k + 1)

        A = jnp.where(n > k, A_base, jnp.where(n == k, n + 1, 0.0))

        return -A.astype(dtype), B.astype(dtype)

    # Fourier Basis OPs and functions - vectorized
    @staticmethod
    def build_Fourier(N, fourier_type="fru", dtype=jnp.float32):
        """
        Vectorized measure implementations derived from the Fourier basis.

        Args:
            N (int): Order of coefficients to describe the orthogonal polynomial that is the HiPPO projection.
            fourier_type (str): The type of Fourier measure.
                - FRU: Fourier Recurrent Unit - fru
                - FouT: truncated Fourier - fout
                - FouD: decayed Fourier - foud

        Returns:
            A (jnp.ndarray): The A HiPPO matrix.
            B (jnp.ndarray): The B HiPPO matrix.

        """
        # (N, N) zero matrix for even N (N - 1 for odd N); only its shape is used below
        A = jnp.diag(
            jnp.stack([jnp.zeros(N // 2), jnp.zeros(N // 2)], axis=-1).reshape(-1)[1:],
            1,
        )
        B = jnp.zeros(A.shape[1], dtype=dtype)

        B = B.at[0::2].set(jnp.sqrt(2))
        B = B.at[0].set(1)

        q = jnp.arange(A.shape[1], dtype=dtype)
        k, n = jnp.meshgrid(q, q)

        # flags for even (0-based) row and column indices
        n_even = n % 2 == 0
        k_even = k % 2 == 0

        case_1 = (n == k) & (n == 0)
        case_2_3 = ((k == 0) & (n_even)) | ((n == 0) & (k_even))
        case_4 = (n_even) & (k_even)
        case_5 = (n - k == 1) & (k_even)
        case_6 = (k - n == 1) & (n_even)

        if fourier_type == "fru":  # Fourier Recurrent Unit (FRU) - vectorized
            A = jnp.where(
                case_1,
                -1.0,
                jnp.where(
                    case_2_3,
                    -jnp.sqrt(2),
                    jnp.where(
                        case_4,
                        -2,
                        jnp.where(
                            case_5,
                            jnp.pi * (n // 2),
                            jnp.where(case_6, -jnp.pi * (k // 2), 0.0),
                        ),
                    ),
                ),
            )

        elif fourier_type == "fout":  # truncated Fourier (FouT) - vectorized
            A = jnp.where(
                case_1,
                -1.0,
                jnp.where(
                    case_2_3,
                    -jnp.sqrt(2),
                    jnp.where(
                        case_4,
                        -2,
                        jnp.where(
                            case_5,
                            jnp.pi * (n // 2),
                            jnp.where(case_6, -jnp.pi * (k // 2), 0.0),
                        ),
                    ),
                ),
            )

            A = 2 * A
            B = 2 * B

        elif fourier_type == "foud":
            A = jnp.where(
                case_1,
                -1.0,
                jnp.where(
                    case_2_3,
                    -jnp.sqrt(2),
                    jnp.where(
                        case_4,
                        -2,
                        jnp.where(
                            case_5,
                            2 * jnp.pi * (n // 2),
                            jnp.where(case_6, 2 * -jnp.pi * (k // 2), 0.0),
                        ),
                    ),
                ),
            )

            A = 0.5 * A
            B = 0.5 * B

        B = B[:, None]

        return A.astype(dtype), B.astype(dtype)

## Translated Legendre (LegT)

### LegT

In [9]:
def test_LegT():
    legt_matrices = TransMatrix(N=8, measure="legt", lambda_n=1.0)
    A, B = legt_matrices.A, legt_matrices.B
    gu_legt_matrices = GuTransMatrix(N=8, measure="legt", lambda_n=1.0)
    gu_A, gu_B = gu_legt_matrices.A, gu_legt_matrices.B
    print(f"A:\n", A)
    print(f"Gu's A:\n", gu_A)
    print(f"B:\n", B)
    print(f"Gu's B:\n", gu_B)
    assert jnp.allclose(A, gu_A)
    assert jnp.allclose(B, gu_B)
    

In [10]:
test_LegT()

A:
 [[ -1.          1.7320508  -2.2360678   2.6457512  -3.          3.3166246  -3.6055512   3.8729832]
 [ -1.7320508  -3.          3.872983   -4.5825753   5.196152   -5.744562    6.244998   -6.708204 ]
 [ -2.2360678  -3.872983   -4.999999    5.916079   -6.7082033   7.4161973  -8.062257    8.660253 ]
 [ -2.6457512  -4.5825753  -5.916079   -6.9999995   7.937254   -8.774963    9.5393915 -10.24695  ]
 [ -3.         -5.196152   -6.7082033  -7.937254   -9.          9.949874  -10.816654   11.61895  ]
 [ -3.3166246  -5.744562   -7.4161973  -8.774963   -9.949874  -10.999999   11.958261  -12.845232 ]
 [ -3.6055512  -6.244998   -8.062257   -9.5393915 -10.816654  -11.958261  -13.         13.964239 ]
 [ -3.8729832  -6.708204   -8.660253  -10.24695   -11.61895   -12.845232  -13.964239  -14.999999 ]]
Gu's A:
 [[ -1.          1.7320508  -2.2360678   2.6457512  -3.          3.3166246  -3.6055512   3.8729832]
 [ -1.7320508  -3.          3.872983   -4.5825753   5.196152   -5.744562    6.244998   -6.70820
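The vectorized `jnp.where` construction in `build_LegT` can be cross-checked against an explicit double loop that writes out the piecewise definition entry by entry. The sketch below is a NumPy reference under stated assumptions: `legt_reference` is a hypothetical helper (not part of the repo), it covers only `lambda_n = 1`, works in float64, and returns the already-negated matrix, matching the sign convention of `TransMatrix`.

```python
import numpy as np


def legt_reference(N: int):
    """Hypothetical loop-based LegT reference (lambda_n = 1), negated as in TransMatrix.

    A[n, k] = -sqrt(2n+1) sqrt(2k+1)               if n >= k
    A[n, k] = -(-1)^(n-k) sqrt(2n+1) sqrt(2k+1)    if n <  k
    B[n]    =  sqrt(2n+1)
    """
    A = np.zeros((N, N))
    B = np.zeros((N, 1))
    for n in range(N):
        B[n, 0] = np.sqrt(2 * n + 1)
        for k in range(N):
            base = np.sqrt(2 * n + 1) * np.sqrt(2 * k + 1)
            sign = 1.0 if n >= k else (-1.0) ** (n - k)
            A[n, k] = -sign * base
    return A, B


A_ref, B_ref = legt_reference(8)
print(np.round(A_ref[:2, :4], 4))  # compare with the first rows printed above
```

If `TransMatrix` is in scope, `np.allclose(A_ref, np.asarray(TransMatrix(N=8, measure="legt").A), atol=1e-5)` is expected to hold; the tolerance absorbs the float32 rounding visible in entries such as `-4.999999`.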

### LMU

In [11]:
def test_LMU():
    lmu_matrices = TransMatrix(
        N=8, measure="legt", lambda_n=2.0
    )  # change lambda so resulting matrix is in the form of LMU
    A, B = lmu_matrices.A, lmu_matrices.B
    gu_lmu_matrices = GuTransMatrix(
        N=8, measure="legt", lambda_n=2.0
    )  # change lambda so resulting matrix is in the form of LMU
    gu_A, gu_B = gu_lmu_matrices.A, gu_lmu_matrices.B
    print(f"A:\n", A)
    print(f"Gu's A:\n", gu_A)
    print(f"B:\n", B)
    print(f"Gu's B:\n", gu_B)
    assert jnp.allclose(A, gu_A)
    assert jnp.allclose(B, gu_B)
    

In [12]:
test_LMU()

A:
 [[ -1.  -1.  -1.  -1.  -1.  -1.  -1.  -1.]
 [  3.  -3.  -3.  -3.  -3.  -3.  -3.  -3.]
 [ -5.   5.  -5.  -5.  -5.  -5.  -5.  -5.]
 [  7.  -7.   7.  -7.  -7.  -7.  -7.  -7.]
 [ -9.   9.  -9.   9.  -9.  -9.  -9.  -9.]
 [ 11. -11.  11. -11.  11. -11. -11. -11.]
 [-13.  13. -13.  13. -13.  13. -13. -13.]
 [ 15. -15.  15. -15.  15. -15.  15. -15.]]
Gu's A:
 [[ -1.  -1.  -1.  -1.  -1.  -1.  -1.  -1.]
 [  3.  -3.  -3.  -3.  -3.  -3.  -3.  -3.]
 [ -5.   5.  -5.  -5.  -5.  -5.  -5.  -5.]
 [  7.  -7.   7.  -7.  -7.  -7.  -7.  -7.]
 [ -9.   9.  -9.   9.  -9.  -9.  -9.  -9.]
 [ 11. -11.  11. -11.  11. -11. -11. -11.]
 [-13.  13. -13.  13. -13.  13. -13. -13.]
 [ 15. -15.  15. -15.  15. -15.  15. -15.]]
B:
 [[  1.]
 [ -3.]
 [  5.]
 [ -7.]
 [  9.]
 [-11.]
 [ 13.]
 [-15.]]
Gu's B:
 [[  1.]
 [ -3.]
 [  5.]
 [ -7.]
 [  9.]
 [-11.]
 [ 13.]
 [-15.]]
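The LMU matrices printed above are not an unrelated system. Working from the closed forms in `build_LegT`, one can check that they are a change of basis of the LegT system: with `T = diag((-1)^n sqrt(2n+1))`, `A_lmu = T A_legt T^{-1}` and `B_lmu = T B_legt`. The NumPy sketch below verifies this numerically; `legt_lmu` is a hypothetical helper that re-derives both pairs from those closed forms (float64, already negated), not a function from the repo.

```python
import numpy as np


def legt_lmu(N: int):
    """Hypothetical helper: closed-form LegT (lambda_n=1) and LMU (lambda_n=2) pairs, negated."""
    n = np.arange(N)[:, None]  # row index
    k = np.arange(N)[None, :]  # column index
    case = (-1.0) ** (n - k)   # (-1)^(n - k)

    legt_base = np.sqrt(2 * n + 1) * np.sqrt(2 * k + 1)
    A_legt = -np.where(n >= k, legt_base, legt_base * case)
    B_legt = np.sqrt(2 * np.arange(N) + 1)[:, None]

    lmu_base = (2 * n + 1).astype(float)
    A_lmu = -np.where(n >= k, lmu_base * case, lmu_base)
    B_lmu = ((2 * np.arange(N) + 1) * (-1.0) ** np.arange(N))[:, None]
    return A_legt, B_legt, A_lmu, B_lmu


N = 8
A_legt, B_legt, A_lmu, B_lmu = legt_lmu(N)

# Diagonal change of basis T = diag((-1)^n * sqrt(2n + 1))
t = (-1.0) ** np.arange(N) * np.sqrt(2 * np.arange(N) + 1)
T, T_inv = np.diag(t), np.diag(1.0 / t)

print(np.allclose(T @ A_legt @ T_inv, A_lmu))  # True: similarity for A
print(np.allclose(T @ B_legt, B_lmu))          # True: same map applied to B
```

Since a similarity transform preserves eigenvalues, LegT and LMU share the same spectrum, and a GBT discretization with the same step size maps `T A T^{-1}` to `T dA T^{-1}`, so both variants have the same discrete-time stability.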


## Translated Laguerre (LagT)

In [13]:
def test_LagT():
    lagt_matrices = TransMatrix(
        N=8,
        measure="lagt",
        alpha=0.0,  # change the resulting tilt through alpha and beta
        beta=1.0,
    )
    A, B = lagt_matrices.A, lagt_matrices.B
    gu_lagt_matrices = GuTransMatrix(
        N=8,
        measure="lagt",
        alpha=0.0,  # change the resulting tilt through alpha and beta
        beta=1.0,
    )
    gu_A, gu_B = gu_lagt_matrices.A, gu_lagt_matrices.B
    print(f"A:\n", A)
    print(f"Gu's A:\n", gu_A)
    print(f"B:\n", B)
    print(f"Gu's B:\n", gu_B)
    assert jnp.allclose(A, gu_A)
    assert jnp.allclose(B, gu_B)

In [14]:
test_LagT()

A:
 [[-1.         -0.         -0.         -0.         -0.         -0.         -0.         -0.        ]
 [-1.         -1.         -0.         -0.         -0.         -0.         -0.         -0.        ]
 [-1.         -1.         -1.         -0.         -0.         -0.         -0.         -0.        ]
 [-1.         -1.         -1.         -1.         -0.         -0.         -0.         -0.        ]
 [-1.         -1.         -1.         -1.         -1.         -0.         -0.         -0.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.         -0.         -0.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.         -1.         -0.        ]
 [-0.99999976 -0.99999976 -0.99999976 -0.99999976 -0.99999976 -0.99999976 -0.99999976 -1.        ]]
Gu's A:
 [[-1.         -0.         -0.         -0.         -0.         -0.         -0.         -0.        ]
 [-1.         -1.         -0.         -0.         -0.         -0.         -0.         -0.      
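For `alpha = 0` the Laguerre normalizer `L` should be identically 1, so `build_LagT` reduces to `A = -((1 + beta) / 2) I - tril(1, -1)` and `B = sqrt(beta)`; with `beta = 1` that is exactly `-tril(ones)` and a vector of ones. The `-0.99999976` entries in the last row printed above are therefore most likely float32 rounding in the `gammaln`-based normalizer rather than a different matrix. The sketch below recomputes the same formula in float64, swapping `scipy.special` for the standard library's `math.lgamma` (an assumption made to keep it dependency-light; it also assumes `-1 < alpha < 1` so every Gamma argument is positive). `lagt_reference` is a hypothetical name.

```python
import math

import numpy as np


def lagt_reference(N: int, alpha: float = 0.0, beta: float = 1.0):
    """Hypothetical float64 version of the build_LagT formula (assumes -1 < alpha < 1)."""
    idx = range(N)
    # L_n = sqrt(Gamma(n + alpha + 1) / Gamma(n + 1)), computed in log space
    L = np.array([math.exp(0.5 * (math.lgamma(i + alpha + 1) - math.lgamma(i + 1))) for i in idx])
    # binom(alpha + n, n) = Gamma(alpha + n + 1) / (Gamma(n + 1) * Gamma(alpha + 1))
    binom = np.array(
        [math.exp(math.lgamma(alpha + i + 1) - math.lgamma(i + 1) - math.lgamma(alpha + 1)) for i in idx]
    )
    pre_A = np.eye(N) * ((1 + beta) / 2) + np.tril(np.ones((N, N)), -1)
    A = -(1.0 / L[:, None]) * pre_A * L[None, :]
    scale = math.exp(-0.5 * math.lgamma(1 - alpha)) * beta ** ((1 - alpha) / 2)
    B = scale * (binom / L)[:, None]
    return A, B


A, B = lagt_reference(8, alpha=0.0, beta=1.0)
print(np.abs(A + np.tril(np.ones((8, 8)))).max())  # 0.0: exactly -tril(ones) in float64
```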

## Scaled Legendre (LegS)

In [15]:
def test_LegS():
    legs_matrices = TransMatrix(N=8, measure="legs")
    A, B = legs_matrices.A, legs_matrices.B
    gu_legs_matrices = GuTransMatrix(N=8, measure="legs")
    gu_A, gu_B = gu_legs_matrices.A, gu_legs_matrices.B
    print(f"A:\n", A)
    print(f"Gu's A:\n", gu_A)
    print(f"B:\n", B)
    print(f"Gu's B:\n", gu_B)
    assert jnp.allclose(A, gu_A)
    assert jnp.allclose(B, gu_B)

In [16]:
test_LegS()

A:
 [[ -1.         -0.         -0.         -0.         -0.         -0.         -0.         -0.       ]
 [ -1.7320508  -2.         -0.         -0.         -0.         -0.         -0.         -0.       ]
 [ -2.2360678  -3.872983   -3.         -0.         -0.         -0.         -0.         -0.       ]
 [ -2.6457512  -4.5825753  -5.916079   -4.         -0.         -0.         -0.         -0.       ]
 [ -3.         -5.196152   -6.7082033  -7.937254   -5.         -0.         -0.         -0.       ]
 [ -3.3166246  -5.744562   -7.4161973  -8.774963   -9.949874   -6.         -0.         -0.       ]
 [ -3.6055512  -6.244998   -8.062257   -9.5393915 -10.816654  -11.958261   -7.         -0.       ]
 [ -3.8729832  -6.708204   -8.660253  -10.24695   -11.61895   -12.845232  -13.964239   -8.       ]]
Gu's A:
 [[ -1.          0.          0.          0.          0.          0.          0.          0.       ]
 [ -1.7320508  -1.9999999   0.          0.          0.          0.          0.          0.     
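Because the (negated) LegS matrix is lower triangular with diagonal entries `-(n + 1)`, its eigenvalues are simply `-1, -2, ..., -N`: all real and negative, so the continuous-time system is stable. The NumPy sketch below rebuilds the LegS pair from the closed form used in `build_LegS` (the helper name `legs_reference` is hypothetical) and checks this numerically.

```python
import numpy as np


def legs_reference(N: int):
    """Hypothetical helper: LegS matrices from the closed form in build_LegS (already negated)."""
    n = np.arange(N)[:, None]
    k = np.arange(N)[None, :]
    base = np.sqrt(2 * n + 1) * np.sqrt(2 * k + 1)
    A = -np.where(n > k, base, np.where(n == k, n + 1, 0.0))
    B = np.sqrt(2 * np.arange(N) + 1)[:, None]
    return A, B


A, B = legs_reference(8)
eigvals = np.linalg.eigvals(A)
print(np.allclose(np.triu(A, 1), 0.0))  # lower triangular
print(np.sort(eigvals.real))            # -8, -7, ..., -1 up to rounding
```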

## Fourier Basis

### Fourier Recurrent Unit (FRU)

In [17]:
def test_FRU():
    fru_matrices = TransMatrix(N=8, measure="fru")
    A, B = fru_matrices.A, fru_matrices.B
    gu_fru_matrices = GuTransMatrix(N=8, measure="fru")
    gu_A, gu_B = gu_fru_matrices.A, gu_fru_matrices.B
    print(f"A:\n", A)
    print(f"Gu's A:\n", gu_A)
    print(f"B:\n", B)
    print(f"Gu's B:\n", gu_B)
    assert jnp.allclose(A, gu_A)
    assert jnp.allclose(B, gu_B)

In [18]:
test_FRU()

A:
 [[-1.        -0.        -1.4142135  0.        -1.4142135  0.        -1.4142135  0.       ]
 [ 0.         0.         0.         0.         0.         0.         0.         0.       ]
 [-1.4142135  0.        -2.        -3.1415927 -2.         0.        -2.         0.       ]
 [ 0.         0.         3.1415927  0.         0.         0.         0.         0.       ]
 [-1.4142135  0.        -2.         0.        -2.        -6.2831855 -2.         0.       ]
 [ 0.         0.         0.         0.         6.2831855  0.         0.         0.       ]
 [-1.4142135  0.        -2.         0.        -2.         0.        -2.        -9.424778 ]
 [ 0.         0.         0.         0.         0.         0.         9.424778   0.       ]]
Gu's A:
 [[-1.         0.        -1.4142135  0.        -1.4142135  0.        -1.4142135  0.       ]
 [ 0.         0.         0.         0.         0.         0.         0.         0.       ]
 [-1.4142135  0.        -1.9999999 -3.1415927 -1.9999999  0.        -1.99999
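The nested `jnp.where` chain in `build_Fourier` encodes six cases on the row index `n` and column index `k`, where the first matching case wins. Written as an explicit loop in the same priority order, the FRU matrix reads as follows; `fru_reference` is a hypothetical NumPy helper that mirrors those cases (it assumes an even `N`, as in the tests above) and is meant only as a readable cross-check.

```python
import numpy as np


def fru_reference(N: int):
    """Hypothetical case-by-case FRU matrices mirroring build_Fourier (assumes even N)."""
    A = np.zeros((N, N))
    for n in range(N):
        for k in range(N):
            if n == 0 and k == 0:
                A[n, k] = -1.0                        # case 1
            elif (k == 0 and n % 2 == 0) or (n == 0 and k % 2 == 0):
                A[n, k] = -np.sqrt(2)                 # cases 2 and 3
            elif n % 2 == 0 and k % 2 == 0:
                A[n, k] = -2.0                        # case 4: both indices even
            elif n - k == 1 and k % 2 == 0:
                A[n, k] = np.pi * (n // 2)            # case 5: sub-diagonal rotation
            elif k - n == 1 and n % 2 == 0:
                A[n, k] = -np.pi * (k // 2)           # case 6: super-diagonal rotation
    B = np.zeros((N, 1))
    B[0::2, 0] = np.sqrt(2)
    B[0, 0] = 1.0
    return A, B


A, B = fru_reference(8)
print(np.round(A[:4, :4], 4))
```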

### Truncated Fourier (FouT)

In [19]:
def test_FouT():
    fout_matrices = TransMatrix(N=8, measure="fout")
    A, B = fout_matrices.A, fout_matrices.B
    gu_fout_matrices = GuTransMatrix(N=8, measure="fout")
    gu_A, gu_B = gu_fout_matrices.A, gu_fout_matrices.B
    print(f"A:\n", A)
    print(f"Gu's A:\n", gu_A)
    print(f"B:\n", B)
    print(f"Gu's B:\n", gu_B)
    assert jnp.allclose(A, gu_A)
    assert jnp.allclose(B, gu_B)

In [20]:
test_FouT()

A:
 [[ -2.         -0.         -2.828427    0.         -2.828427    0.         -2.828427    0.       ]
 [  0.          0.          0.          0.          0.          0.          0.          0.       ]
 [ -2.828427    0.         -4.         -6.2831855  -4.          0.         -4.          0.       ]
 [  0.          0.          6.2831855   0.          0.          0.          0.          0.       ]
 [ -2.828427    0.         -4.          0.         -4.        -12.566371   -4.          0.       ]
 [  0.          0.          0.          0.         12.566371    0.          0.          0.       ]
 [ -2.828427    0.         -4.          0.         -4.          0.         -4.        -18.849556 ]
 [  0.          0.          0.          0.          0.          0.         18.849556    0.       ]]
Gu's A:
 [[ -2.          0.         -2.828427    0.         -2.828427    0.         -2.828427    0.       ]
 [  0.          0.          0.          0.          0.          0.          0.          0.     

### Fourier With Decay (FourD)

In [21]:
def test_FouD():
    foud_matrices = TransMatrix(N=8, measure="foud")
    A, B = foud_matrices.A, foud_matrices.B
    gu_foud_matrices = GuTransMatrix(N=8, measure="foud")
    gu_A, gu_B = gu_foud_matrices.A, gu_foud_matrices.B
    print(f"A:\n", A)
    print(f"Gu's A:\n", gu_A)
    print(f"B:\n", B)
    print(f"Gu's B:\n", gu_B)
    assert jnp.allclose(A, gu_A)
    assert jnp.allclose(B, gu_B)

In [22]:
test_FouD()

A:
 [[-0.5        -0.         -0.70710677  0.         -0.70710677  0.         -0.70710677  0.        ]
 [ 0.          0.          0.          0.          0.          0.          0.          0.        ]
 [-0.70710677  0.         -1.         -3.1415927  -1.          0.         -1.          0.        ]
 [ 0.          0.          3.1415927   0.          0.          0.          0.          0.        ]
 [-0.70710677  0.         -1.          0.         -1.         -6.2831855  -1.          0.        ]
 [ 0.          0.          0.          0.          6.2831855   0.          0.          0.        ]
 [-0.70710677  0.         -1.          0.         -1.          0.         -1.         -9.424778  ]
 [ 0.          0.          0.          0.          0.          0.          9.424778    0.        ]]
Gu's A:
 [[-0.5         0.         -0.70710677  0.         -0.70710677  0.         -0.70710677  0.        ]
 [ 0.          0.          0.          0.          0.          0.          0.          0.      
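Comparing the three Fourier variants printed above: FouT is exactly twice FRU (both `A` and `B`), while FouD halves only the damping entries (cases 1 to 4) and `B`, and keeps the rotation entries (cases 5 and 6) at the FRU values, because `build_Fourier` doubles those frequencies before scaling by 0.5. Since the damping entries form a symmetric matrix and the rotation entries a skew-symmetric one on disjoint positions, they are simply the symmetric and skew-symmetric parts of the FRU matrix. The sketch below uses that split to rebuild the other two variants; `fru_matrix` is a hypothetical vectorized NumPy helper and assumes even `N`.

```python
import numpy as np


def fru_matrix(N: int):
    """Hypothetical vectorized FRU matrices following the cases of build_Fourier (even N)."""
    n = np.arange(N)[:, None]
    k = np.arange(N)[None, :]
    n_even, k_even = n % 2 == 0, k % 2 == 0
    ones = np.ones((N, N))
    # np.select picks the first matching condition, like the nested jnp.where chain
    A = np.select(
        [
            (n == 0) & (k == 0),
            ((k == 0) & n_even) | ((n == 0) & k_even),
            n_even & k_even,
            (n - k == 1) & k_even,
            (k - n == 1) & n_even,
        ],
        [-1.0 * ones, -np.sqrt(2) * ones, -2.0 * ones, np.pi * (n // 2) * ones, -np.pi * (k // 2) * ones],
        default=0.0,
    )
    B = np.where(np.arange(N) % 2 == 0, np.sqrt(2), 0.0)
    B[0] = 1.0
    return A, B[:, None]


A_fru, B_fru = fru_matrix(8)
damping = 0.5 * (A_fru + A_fru.T)   # cases 1-4 (symmetric part)
rotation = 0.5 * (A_fru - A_fru.T)  # cases 5-6 (skew-symmetric part)

A_fout, B_fout = 2 * A_fru, 2 * B_fru                      # FouT: everything doubled
A_foud, B_foud = 0.5 * damping + rotation, 0.5 * B_fru     # FouD: only damping and B halved
print(A_foud[:4, :4])
```

Read this way, the three variants differ only in how strongly the state is damped relative to how fast each sine/cosine pair rotates.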

## Gu's HiPPO LegT Operator

In [23]:
class HiPPO_LTI(nn.Module):
    """Linear time invariant x' = Ax + Bu"""

    def __init__(
        self,
        N,
        method="legt",
        dt=1.0,
        T=1.0,
        discretization=0.5,
        lambda_n=1.0,
        alpha=0.0,
        beta=1.0,
        c=0.0,
    ):
        """
        N: the order of the HiPPO projection
        dt: discretization step size - should be roughly inverse to the length of the sequence
        """
        super().__init__()

        self.method = method
        self.N = N
        self.dt = dt
        self.T = T
        self.c = c

        matrices = GuTransMatrix(
            N=N, measure=method, lambda_n=lambda_n, alpha=alpha, beta=beta
        )
        A = np.asarray(matrices.A, dtype=np.float32)
        B = np.asarray(matrices.B, dtype=np.float32)
        # A, B = transition(method, N)
        A = A + np.eye(N) * c
        self.A = A
        self.B = B.squeeze(-1)
        self.measure_fn = measure(method)

        C = np.ones((1, N))
        D = np.zeros((1,))
        if type(discretization) in [float, int]:
            dA, dB, _, _, _ = signal.cont2discrete(
                (A, B, C, D), dt=dt, method="gbt", alpha=discretization
            )
        else:
            dA, dB, _, _, _ = signal.cont2discrete((A, B, C, D), dt=dt, method="zoh")

        dB = dB.squeeze(-1)

        self.dA = torch.Tensor(dA.copy())  # (N, N)
        self.dB = torch.Tensor(dB.copy())  # (N, )

        self.vals = np.arange(0.0, T, dt)
        self.eval_matrix = basis(self.method, self.N, self.vals, c=self.c)  # (T/dt, N)
        self.measure = measure(self.method)(self.vals)

    def forward(self, inputs, fast=True):
        """
        inputs : (length, ...)
        output : (length, ..., N) where N is the order of the HiPPO projection
        """

        inputs = inputs.unsqueeze(-1)
        u = inputs * self.dB  # (length, ..., N)

        if fast:
            dA = repeat(self.dA, "m n -> l m n", l=u.size(0))
            return variable_unroll_matrix(dA, u)

        # Slow path: explicit recurrence c_k = dA c_{k-1} + dB f_k
        c = torch.zeros(u.shape[1:]).to(inputs)
        cs = []
        for f in inputs:
            part1 = F.linear(c, self.dA)  # dA @ c
            part2 = self.dB * f  # dB * f_k
            c = part1 + part2
            cs.append(c)
        return torch.stack(cs, dim=0)

    def reconstruct(
        self, c, evals=None
    ):  # TODO take in a times array for reconstruction
        """
        c: (..., N,) HiPPO coefficients (same as x(t) in S4 notation)
        output: (..., L,)
        """
        if evals is not None:
            eval_matrix = basis(self.method, self.N, evals, c=self.c)
        else:
            eval_matrix = self.eval_matrix

        m = self.measure[self.measure != 0.0]

        c = c.unsqueeze(-1)
        y = eval_matrix.to(c) @ c
        return y.squeeze(-1).flip(-1)
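`HiPPO_LTI` discretizes `x' = Ax + Bu` with `scipy.signal.cont2discrete(..., method="gbt", alpha=discretization)` and then unrolls `c_k = dA c_{k-1} + dB f_k`. The NumPy sketch below writes the generalized bilinear transform out explicitly (the update SciPy documents for `gbt`: `dA = (I - alpha dt A)^{-1} (I + (1 - alpha) dt A)`, `dB = (I - alpha dt A)^{-1} dt B`), runs the same recurrence, and checks that, because `dA` does not change with time, the states also equal a causal convolution of the input with the kernel `K_j = dA^j dB`. The names `legt`, `gbt_discretize`, `unroll` and `unroll_by_kernel` are hypothetical helpers, and the LegT matrices are rebuilt from their closed form so the block is self-contained.

```python
import numpy as np


def legt(N: int):
    """Hypothetical helper: LegT (lambda_n = 1) matrices, negated as in build_LegT."""
    n = np.arange(N)[:, None]
    k = np.arange(N)[None, :]
    base = np.sqrt(2 * n + 1) * np.sqrt(2 * k + 1)
    A = -np.where(n >= k, base, base * (-1.0) ** (n - k))
    B = np.sqrt(2 * np.arange(N) + 1)
    return A, B


def gbt_discretize(A, B, dt, alpha=0.5):
    """GBT: alpha = 0 forward Euler, 0.5 bilinear, 1 backward Euler."""
    I = np.eye(A.shape[0])
    lhs = I - alpha * dt * A
    dA = np.linalg.solve(lhs, I + (1.0 - alpha) * dt * A)
    dB = np.linalg.solve(lhs, dt * B)
    return dA, dB


def unroll(dA, dB, f):
    """Sequential recurrence c_k = dA c_{k-1} + dB f_k, starting from a zero state."""
    c = np.zeros(dA.shape[0])
    cs = []
    for f_k in f:
        c = dA @ c + dB * f_k
        cs.append(c)
    return np.stack(cs)


def unroll_by_kernel(dA, dB, f):
    """Same states via the kernel K_j = dA^j dB: c_k = sum_{j<=k} K_j f_{k-j}."""
    L = len(f)
    K = np.empty((L, dA.shape[0]))
    K[0] = dB
    for j in range(1, L):
        K[j] = dA @ K[j - 1]
    # K[:k+1][::-1] lists K_k, ..., K_0, pairing each K_{k-i} with f_i
    return np.stack([K[: k + 1][::-1].T @ f[: k + 1] for k in range(L)])


N, L = 8, 100
A, B = legt(N)
dA, dB = gbt_discretize(A, B, dt=1.0 / L, alpha=0.5)
f = np.sin(np.linspace(0.0, 4.0 * np.pi, L))
cs = unroll(dA, dB, f)
print(cs.shape, np.allclose(cs, unroll_by_kernel(dA, dB, f)))
```

The kernel form only exists because the system is time invariant; in the scale-invariant `HiPPO_LSI` the step matrices change with `t`, so only the recurrence applies there.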


## Gu's Scale invariant HiPPO LegS Operator

In [24]:
class HiPPO_LSI(nn.Module):
    """Vanilla HiPPO-LegS model (scale invariant instead of time invariant)"""

    def __init__(
        self,
        N,
        method="legs",
        max_length=1024,
        discretization=0.5,
        lambda_n=1.0,
        alpha=0.0,
        beta=1.0,
    ):
        """
        max_length: maximum sequence length
        """
        super().__init__()
        self.N = N
        matrices = GuTransMatrix(
            N=N, measure=method, lambda_n=lambda_n, alpha=alpha, beta=beta
        )
        A = np.asarray(matrices.A, dtype=np.float32)
        B = np.asarray(matrices.B, dtype=np.float32)
        # A, B = transition(method, N)
        B = B.squeeze(-1)
        A_stacked = np.empty((max_length, N, N), dtype=A.dtype)
        B_stacked = np.empty((max_length, N), dtype=B.dtype)
        for t in range(1, max_length + 1):
            At = A / t
            Bt = B / t
            if discretization == 0.0:  # forward
                A_stacked[t - 1] = np.eye(N) + At
                B_stacked[t - 1] = Bt
            elif discretization == 1.0:  # backward
                A_stacked[t - 1] = la.solve_triangular(
                    np.eye(N) - At, np.eye(N), lower=True
                )
                B_stacked[t - 1] = la.solve_triangular(np.eye(N) - At, Bt, lower=True)
            elif discretization == 0.5:  # bilinear
                # TODO: np.linalg.lstsq is used here instead of la.solve_triangular, referencing:
                # https://stackoverflow.com/questions/64527098/numpy-linalg-linalgerror-singular-matrix-error-when-trying-to-solve
                gbt_alpha = 0.5  # renamed so it does not shadow the Laguerre `alpha` argument
                A_stacked[t - 1] = np.linalg.lstsq(
                    np.eye(N) - (At * gbt_alpha), np.eye(N) + (At * gbt_alpha), rcond=None
                )[0]
                B_stacked[t - 1] = np.linalg.lstsq(
                    np.eye(N) - (At * gbt_alpha), Bt, rcond=None
                )[0]
            else:  # ZOH
                A_stacked[t - 1] = la.expm(A * (math.log(t + 1) - math.log(t)))
                # A_stacked[t - 1] = la.expm(At)
                B_stacked[t - 1] = la.solve_triangular(
                    A, A_stacked[t - 1] @ B - B, lower=True
                )
                
                # A_stacked[t - 1] = la.expm(At)
                # B_stacked[t - 1] = la.inv(A) @ (la.expm(At) - np.eye(A.shape[0])) @ B
                
                
        # self.register_buffer('A_stacked', torch.Tensor(A_stacked)) # (max_length, N, N)
        # self.register_buffer('B_stacked', torch.Tensor(B_stacked)) # (max_length, N)
        
        self.A_stacked = torch.Tensor(A_stacked.copy())  # (max_length, N, N)
        self.B_stacked = torch.Tensor(B_stacked.copy())  # (max_length, N)

        vals = np.linspace(0.0, 1.0, max_length)
        self.eval_matrix = torch.from_numpy(
            np.asarray(
                ((B[:, None] * ss.eval_legendre(np.arange(N)[:, None], 2 * vals - 1)).T)
            )
        )

    def forward(self, inputs, fast=True):
        """
        inputs : (length, ...)
        output : (length, ..., N) where N is the order of the HiPPO projection
        """

        L = inputs.shape[0]

        inputs = inputs.unsqueeze(-1)
        u = torch.transpose(inputs, 0, -2)
        u = u * self.B_stacked[:L]
        u = torch.transpose(u, 0, -2)  # (length, ..., N)

        if fast:
            result = variable_unroll_matrix(self.A_stacked[:L], u)
            return result

        # Slow path: explicit time-varying recurrence c_t = A_t c_{t-1} + B_t f_t
        c = torch.zeros(u.shape[1:]).to(inputs)
        cs = []
        for t, f in enumerate(inputs):
            part1 = F.linear(c, self.A_stacked[t])  # A_t @ c
            part2 = self.B_stacked[t] * f  # B_t * f_t
            c = part1 + part2
            cs.append(c)
        return torch.stack(cs, dim=0)

    def reconstruct(self, c):
        a = self.eval_matrix.to(c) @ c.unsqueeze(-1)
        return a
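For the forward-Euler branch (`discretization=0.0`), `HiPPO_LSI` uses step matrices `I + A/t` and `B/t`. Since the LegS `A` is lower triangular with `A[0, 0] = -1` and `B[0] = 1`, the first coefficient follows `c_t[0] = (1 - 1/t) c_{t-1}[0] + f_t / t`, which is exactly the running mean of the input: a small worked example of scale invariance, where the whole history is weighted uniformly however long it is. The NumPy sketch below builds those step matrices the same way `HiPPO_LSI.__init__` does for that branch and checks the claim; `legs` and `legs_forward_steps` are hypothetical helpers.

```python
import numpy as np


def legs(N: int):
    """Hypothetical helper: LegS matrices (already negated), as in build_LegS."""
    n = np.arange(N)[:, None]
    k = np.arange(N)[None, :]
    base = np.sqrt(2 * n + 1) * np.sqrt(2 * k + 1)
    A = -np.where(n > k, base, np.where(n == k, n + 1, 0.0))
    B = np.sqrt(2 * np.arange(N) + 1)
    return A, B


def legs_forward_steps(A, B, max_length):
    """Forward-Euler step matrices A_t = I + A/t and B_t = B/t for t = 1..max_length."""
    N = A.shape[0]
    A_stacked = np.stack([np.eye(N) + A / t for t in range(1, max_length + 1)])
    B_stacked = np.stack([B / t for t in range(1, max_length + 1)])
    return A_stacked, B_stacked


N, L = 4, 50
A, B = legs(N)
A_stacked, B_stacked = legs_forward_steps(A, B, L)

f = np.random.default_rng(0).normal(size=L)
c = np.zeros(N)
cs = []
for t in range(L):
    c = A_stacked[t] @ c + B_stacked[t] * f[t]  # c_t = A_t c_{t-1} + B_t f_t
    cs.append(c)
cs = np.stack(cs)

running_mean = np.cumsum(f) / np.arange(1, L + 1)
print(np.allclose(cs[:, 0], running_mean))  # True
```

The higher coefficients `c_t[1:]` carry the higher-order Legendre components of the same uniformly weighted history.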


## Implementation Of General HiPPO Operator

In [25]:
class HiPPO(jnn.Module):
    """
    Class that constructs a HiPPO model using the defined measure.

    Args:
        max_length (int): maximum sequence length to be input
        step_size (float): step size used for discretization; when it equals 1.0, setup precomputes
            the per-step GBT matrices used by the 'lsi' recurrence
        N (int): order of the HiPPO projection, i.e. the number of coefficients used in the projection
        lambda_n (float): value associated with the tilt of LegT
            - 1.0: tilt associated with LegT
            - 2.0: tilt associated with the Legendre Memory Unit (LMU)
        alpha (float): The order of the Laguerre basis.
        beta (float): The scale of the Laguerre basis.
        GBT_alpha (float): selects the discretization of the generalized bilinear transform (GBT)
            - 0.0: forward Euler
            - 0.5: bilinear
            - 1.0: backward Euler
        measure (str): the measure used to instantiate the HiPPO matrix, one of
            'legt', 'lmu', 'lagt', 'legs', 'fru', 'fout', 'foud', 'random', 'diagonal'
        s_t (str): choice of invariance
            - 'lti': linear time invariant
            - 'lsi': linear scale invariant
        dtype (Any): data type of the matrices
        verbose (bool): verbosity flag
    """

    max_length: int
    step_size: float = 1.0  # < 1.0 if you want to use LTI discretization
    N: int = 100
    lambda_n: float = 1.0
    alpha: float = 0.0
    beta: float = 1.0
    GBT_alpha: float = 0.5
    measure: str = "legs"
    s_t: str = "lti"
    dtype: Any = jnp.float32
    verbose: bool = False

    def setup(self):
        matrices = TransMatrix(
            N=self.N,
            measure=self.measure,
            lambda_n=self.lambda_n,
            alpha=self.alpha,
            beta=self.beta,
            dtype=self.dtype,
        )

        self.A = matrices.A
        self.B = matrices.B

        self.C = jnp.ones((self.N, 1))
        self.D = jnp.zeros((1,))

        # The per-step (scale-invariant) GBT matrices are only needed by the LSI recurrence.
        if self.s_t == "lsi":
            self.GBT_A_list, self.GBT_B_list = self.make_GBT_list(
                matrices.A, matrices.B, dtype=self.dtype
            )

        self.eval_matrix = self.create_eval_matrix(matrices.A, matrices.B)

    def __call__(self, f, init_state=None, kernel=False):
        if not kernel:
            if init_state is None:
                # init_state = jnp.zeros((f.shape[0], self.N, 1))
                init_state = jnp.zeros((f.shape[0], 1, self.N))

            if self.s_t == "lsi":
                c_k, y_k = self.lsi_recurrence(
                    A=self.GBT_A_list,
                    B=self.GBT_B_list,
                    C=self.C,
                    D=self.D,
                    c_0=init_state,
                    f=f,
                    alpha=self.GBT_alpha,
                    dtype=self.dtype,
                )
                c_k = jnp.stack(c_k, axis=0)
                y_k = jnp.stack(y_k, axis=0)

            elif self.s_t == "lti":
                c_k, y_k = self.lti_recurrence(
                    A=self.A,
                    B=self.B,
                    C=self.C,
                    D=self.D,
                    c_0=init_state,
                    f=f,
                    alpha=self.GBT_alpha,
                    step_size=self.step_size,
                    dtype=self.dtype,
                )
            else:
                raise ValueError(f"Invalid s_t={self.s_t!r}; pick either 'lsi' or 'lti'.")

        else:
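            # Convolution mode: discretize once, then apply the SSM as a causal convolution.
            # NOTE: `causal_convolution` and `K_conv` are not defined on this class yet,
            # so this branch fails until they are added.
            Ab, Bb = self.discretize(
                self.A,
                self.B,
                step=self.step_size,
                alpha=self.GBT_alpha,
                dtype=self.dtype,
            )
            c_k, y_k = self.causal_convolution(
                f, self.K_conv(Ab, Bb, self.C, self.D, L=self.max_length)
            )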

        return c_k, y_k

    def reconstruct(self, c):
        """
        Uses the coefficients to reconstruct the signal on the grid of `eval_matrix`

        Args:
            c (jnp.ndarray): coefficients of the HiPPO projection

        Returns:
            reconstructed signal
        """
        return (self.eval_matrix @ jnp.expand_dims(c, -1)).squeeze(-1)

    def make_GBT_list(self, A, B, dtype=jnp.float32):
        """
        Creates the discretized GBT matrices for every LSI time step t = 1, ..., max_length
        """
        GBT_a_list = []
        GBT_b_list = []
        for i in range(1, self.max_length + 1):
            # TODO: make this scale invariant optional
            GBT_A, GBT_B = self.discretize(
                A, B, step=i, alpha=self.GBT_alpha, dtype=dtype
            )
            GBT_a_list.append(GBT_A)
            GBT_b_list.append(GBT_B)

        return GBT_a_list, GBT_b_list

    def create_eval_matrix(self, A, B):
        """
        Creates the evaluation matrix used for reconstructing the signal
        """
        eval_matrix = None
        if self.measure == "legt":
            L = self.max_length
            vals = jnp.arange(0.0, 1.0, 1.0 / L)  # L grid points in [0, 1)
            zero_N = self.N - 1
            x = 1 - 2 * vals
            # lpmn_values returns shape (m + 1, n + 1, len(z)); order m = 0 holds the
            # ordinary Legendre polynomials P_0..P_{N-1}, i.e. ss.eval_legendre(n, x).
            eval_matrix = jax.scipy.special.lpmn_values(
                m=zero_N, n=zero_N, z=x, is_normalized=False
            )[0].T  # (L, N)

        elif self.measure == "legs":
            L = self.max_length
            vals = jnp.linspace(0.0, 1.0, L)
            zero_N = self.N - 1
            x = 2 * vals - 1
            legendre = jax.scipy.special.lpmn_values(
                m=zero_N, n=zero_N, z=x, is_normalized=False
            )[0]  # (N, L), P_n(x)
            eval_matrix = (B.reshape(-1)[:, None] * legendre).T  # (L, N), B_n * ss.eval_legendre(n, x)

        elif self.measure == "lagt":
            raise NotImplementedError("Translated Laguerre measure not implemented yet")

        elif self.measure == "fourier":
            raise NotImplementedError("Fourier measures are not implemented yet")

        else:
            raise ValueError("invalid measure")

        return eval_matrix
    
    def discretize(self, A, B, step, alpha=0.5, dtype=jnp.float32):
        """
        Discretizes the continuous HiPPO system (A, B) with a generalized bilinear transform (GBT) or a zero-order hold.

        Args:
            A (jnp.ndarray): state matrix to be discretized
            B (jnp.ndarray): input matrix to be discretized
            step (float): step index; for the GBT the effective step size is 1 / step, so step = t gives
                the scale-invariant (LegS) step at time t and step = 1.0 gives a unit step
            alpha (float, optional): selects the transformation
                - forward Euler corresponds to α = 0,
                - backward Euler corresponds to α = 1,
                - bilinear corresponds to α = 0.5,
                - zero-order hold (over log-time, from step to step + self.step_size) is used for α > 1
            dtype (optional): dtype of the returned matrices
        """
        I = jnp.eye(A.shape[0])
        step_size = 1 / step
        part1 = I - (step_size * alpha * A)
        part2 = I + (step_size * (1 - alpha) * A)

        GBT_A = jnp.linalg.lstsq(part1, part2, rcond=None)[0]
        GBT_B = jnp.linalg.lstsq(part1, (step_size * B), rcond=None)[0]

        if alpha > 1:  # Zero-order Hold
            GBT_A = jax.scipy.linalg.expm(
                A * (math.log(step + self.step_size) - math.log(step))
            )
            GBT_B = jnp.linalg.inv(A) @ (GBT_A - I) @ B

        return GBT_A.astype(dtype), GBT_B.astype(dtype)

    def lsi_recurrence(self, A, B, C, D, c_0, f, alpha=0.5, dtype=jnp.float32):
        """
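        Runs the linear scale-invariant (LSI) recurrence, using a separate discretized (A, B) pair at every time step.
        Args:
            A (list[jnp.ndarray]): discretized A matrices, one per time step
            B (list[jnp.ndarray]): discretized B matrices, one per time step
            C (jnp.ndarray): output matrix
            D (jnp.ndarray): feedthrough matrix
            c_0 (jnp.ndarray): initial hidden state, shape (batch, 1, N)
            f (jnp.ndarray): input sequence, shape (batch, L, 1)
        Returns:
            the hidden states (the coefficients representing f(t)) and outputs at every step if verbose, otherwise only the last ones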
        """

        c_k_list = []
        y_k_list = []

        c_k = c_0.copy()
        for i in range(f.shape[1]):
            # print(f"--------------step {i}----------------")
            c_k, y_k = jax.vmap(self.lsi_step, in_axes=(None, None, None, None, 0, 0))(
                A[i], B[i], C, D, c_k, f[:, i, :]
            )
            c_k_list.append((c_k.copy()).astype(dtype))
            y_k_list.append((y_k.copy()).astype(dtype))

        if self.verbose:
            return c_k_list, y_k_list
        else:
            return c_k_list[-1], y_k_list[-1]

    def lti_recurrence(self, A, B, C, D, c_0, f, alpha=0.5, step_size=1.0, dtype=jnp.float32):
        """
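        Runs the linear time-invariant (LTI) recurrence: discretizes (A, B) once, then scans over the input.
        Args:
            A (jnp.ndarray): continuous-time state matrix
            B (jnp.ndarray): continuous-time input matrix
            C (jnp.ndarray): output matrix
            D (jnp.ndarray): feedthrough matrix
            c_0 (jnp.ndarray): initial hidden state, shape (batch, 1, N)
            f (jnp.ndarray): input sequence, shape (batch, L, 1)
            alpha (float): GBT parameter passed to `discretize`
            step_size (float): step passed to `discretize`
        Returns:
            all hidden states and outputs if verbose, otherwise the final hidden state and all outputs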
        """
        Ad, Bd = self.discretize(
            A=A, B=B, step=step_size, alpha=alpha, dtype=dtype
        )

        def lti_step(c_k_i, f_k):
            """
            One LTI step: applies the discretized HiPPO matrices to the previous hidden state, c_k_i, and the current input, f_k.
            Args:
                c_k_i: previous hidden state (row vector of shape (1, N))
                f_k: value of the input f at discrete time step k

            Returns:
                c_k: current hidden state
                y_k: output read out from c_k through C, plus the feedthrough D * f_k
            """

            # part1 = jnp.dot(Ad, c_k_i)
            part1 = jnp.dot(c_k_i, Ad.T)
            part2 = Bd.T * f_k
            c_k = part1 + part2

            part3 = jnp.dot(c_k, C)  # (1, N) @ (N, 1) -> (1, 1)
            part4 = D * f_k
            y_k = part3 + part4
            
            # jax.debug.print("Ad shape:\n{x1}", x1=Ad.shape)
            # jax.debug.print("Ad:\n{x2}", x2=Ad)
            
            # jax.debug.print("c_k_i shape:\n{x3}", x3=c_k_i.shape)
            # jax.debug.print("c_k_i:\n{x4}", x4=c_k_i)
            
            # jax.debug.print("Bd shape:\n{x5}", x5=Bd.shape)
            # jax.debug.print("Bd:\n{x6}", x6=Bd)
            
            # jax.debug.print("f_k shape:\n{x7}", x7=f_k.shape)
            # jax.debug.print("f_k:\n{x8}", x8=f_k)
            
            # jax.debug.print("part1 shape:\n{x9}", x9=part1.shape)
            # jax.debug.print("part1:\n{x10}", x10=part1)
            
            # jax.debug.print("part2 shape:\n{x11}", x11=part2.shape)
            # jax.debug.print("part2:\n{x12}", x12=part2)
            
            # jax.debug.print("part3 shape:\n{x13}", x13=part3.shape)
            # jax.debug.print("part3:\n{x14}", x14=part3)
            
            # jax.debug.print("part4 shape:\n{x15}", x15=part4.shape)
            # jax.debug.print("part4:\n{x16}", x16=part4)

            return c_k, (c_k, y_k)

        # Map over the batch; each batch element runs its own scan over the time steps.
        c_k, (c_s, y_s) = jax.vmap(lambda c0, fs: jax.lax.scan(lti_step, c0, fs))(c_0, f)

        if self.verbose:
            return c_s, y_s
        else:
            return c_k, y_s

    def lsi_step(self, Ad, Bd, Cd, Dd, c_k_i, f_k):
        """
        One LSI step: applies this step's discretized HiPPO matrices to the previous hidden state, c_k_i, and the current input, f_k.
        Args:
            Ad, Bd: discretized A and B matrices for this time step
            Cd, Dd: output and feedthrough matrices
            c_k_i: previous hidden state (row vector of shape (1, N))
            f_k: value of the input f at discrete time step k

        Returns:
            c_k: current hidden state
            y_k: output read out from c_k through Cd, plus the feedthrough Dd * f_k
        """

        # part1 = jnp.dot(Ad, c_k_i)
        part1 = jnp.dot(c_k_i, Ad.T)
        part2 = Bd.T * f_k
        c_k = part1 + part2

        part3 = jnp.dot(c_k, Cd)  # (1, N) @ (N, 1) -> (1, 1)
        part4 = Dd * f_k
        y_k = part3 + part4
        
        # jax.debug.print("Ad shape:\n{x1}", x1=Ad.shape)
        # jax.debug.print("Ad:\n{x2}", x2=Ad)
        
        # jax.debug.print("c_k_i shape:\n{x3}", x3=c_k_i.shape)
        # jax.debug.print("c_k_i:\n{x4}", x4=c_k_i)
        
        # jax.debug.print("Bd shape:\n{x5}", x5=Bd.shape)
        # jax.debug.print("Bd:\n{x6}", x6=Bd)
        
        # jax.debug.print("f_k shape:\n{x7}", x7=f_k.shape)
        # jax.debug.print("f_k:\n{x8}", x8=f_k)
        
        # jax.debug.print("part1 shape:\n{x9}", x9=part1.shape)
        # jax.debug.print("part1:\n{x10}", x10=part1)
        
        # jax.debug.print("part2 shape:\n{x11}", x11=part2.shape)
        # jax.debug.print("part2:\n{x12}", x12=part2)
        
        # jax.debug.print("part3 shape:\n{x13}", x13=part3.shape)
        # jax.debug.print("part3:\n{x14}", x14=part3)
        
        # jax.debug.print("part4 shape:\n{x15}", x15=part4.shape)
        # jax.debug.print("part4:\n{x16}", x16=part4)

        return c_k, y_k
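For the `legs` measure, `create_eval_matrix` evaluates $B_n P_n(2t - 1)$ on a grid of `max_length` points in $[0, 1]$. This is the quantity that the inline comments attribute to `ss.eval_legendre`, so `reconstruct` returns $\hat f(t) \approx \sum_n c_n B_n P_n(2t - 1)$.

The sketch below builds that matrix with the Bonnet recurrence instead of `jax.scipy.special.lpmn_values`. The helper name `legendre_eval_matrix` is hypothetical, and it assumes $B_n = \sqrt{2n+1}$ from the LegS definition.

```python
import jax
import jax.numpy as jnp


def legendre_eval_matrix(N, L):
    """Hypothetical helper: (L, N) matrix of sqrt(2n+1) * P_n(2t - 1) on L grid points t in [0, 1]."""
    t = jnp.linspace(0.0, 1.0, L)
    x = 2.0 * t - 1.0

    # Bonnet recurrence: (n + 1) P_{n+1}(x) = (2n + 1) x P_n(x) - n P_{n-1}(x)
    def step(carry, n):
        p_prev, p_curr = carry
        p_next = ((2 * n + 1) * x * p_curr - n * p_prev) / (n + 1)
        return (p_curr, p_next), p_next

    p0 = jnp.ones_like(x)
    p1 = x
    _, rest = jax.lax.scan(step, (p0, p1), jnp.arange(1, N - 1, dtype=x.dtype))
    P = jnp.concatenate([p0[None], p1[None], rest], axis=0)[:N]  # (N, L), rows P_0..P_{N-1}

    B = jnp.sqrt(2.0 * jnp.arange(N, dtype=x.dtype) + 1.0)  # assumed LegS B_n = sqrt(2n + 1)
    return (B[:, None] * P).T  # (L, N)


M = legendre_eval_matrix(8, 5)
print(M.shape)
```

For a coefficient vector `c` of shape `(N,)`, `legendre_eval_matrix(N, L) @ c` then gives the reconstruction on the grid.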

### LegS Bilinear Transform Tests

In [26]:
def test_hippo_legs_lsi_bi_operator(hippo_legs, gu_hippo_legs, random_input, legs_key):
    x_tensor = torch.tensor(random_input, dtype=torch.float32)
    x_jnp = jnp.asarray(x_tensor, dtype=jnp.float32)  # convert torch array to jax array
    
    # My Implementation
    print(f"-------------------------------------------------------------------------------------")
    print(f"----------------------------My LSI Implementation Outputs----------------------------")
    print(f"-------------------------------------------------------------------------------------")
    params = hippo_legs.init(legs_key, f=x_jnp)
    c_k, y_k_list = hippo_legs.apply(params, f=x_jnp)
    c_k = jnp.moveaxis(c_k, 0, 1)
    
    # Gu's HiPPO LegS
    print(f"-------------------------------------------------------------------------------------")
    print(f"---------------------------Gu's LSI Implementation Outputs---------------------------")
    print(f"-------------------------------------------------------------------------------------")
    x_tensor = torch.moveaxis(x_tensor, 0, 1)
    GU_c_k = gu_hippo_legs(x_tensor, fast=False)
    gu_c = jnp.asarray(GU_c_k, dtype=jnp.float32)  # convert torch array to jax array
    gu_c = jnp.moveaxis(gu_c, 0, 1)
    
    print(f"-------------------------------------------------------------------------")
    print(f"---------------------------Testing LSI Outputs---------------------------")
    print(f"-------------------------------------------------------------------------")
    print(f"inputted jnp-data shape: {x_jnp.shape}")
    print(f"inputted tensor-data shape: {x_tensor.shape}")
    print(f"c_k shape: {c_k.shape}")
    print(f"gu_c shape: {gu_c.shape}")
    
    for i in range(c_k.shape[0]):
        for j in range(c_k.shape[1]):
            print(
                f"batch {i} on trajectory {j} compare : {jnp.allclose(c_k[i,j,:,:], gu_c[i,j,:,:], rtol=1e-03, atol=1e-03)}"
            )
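The nested loop above prints one `jnp.allclose` result per (batch, step) pair. An alternative is to reduce `jnp.isclose` over the trailing axes and get the whole table in one call. This assumes both arrays are laid out as `(batch, L, 1, N)`, as in this test. `allclose_per_step` is a hypothetical helper.

```python
import jax.numpy as jnp


def allclose_per_step(a, b, rtol=1e-3, atol=1e-3):
    """Hypothetical helper: allclose for every (batch, step) pair of arrays shaped (batch, L, ...).

    Returns a boolean array of shape (batch, L).
    """
    close = jnp.isclose(a, b, rtol=rtol, atol=atol)
    # Reduce over every axis after the (batch, step) axes.
    return jnp.all(close, axis=tuple(range(2, close.ndim)))


# Toy check: a single mismatching entry flips exactly one (batch, step) flag.
a = jnp.zeros((2, 3, 1, 4))
b = a.at[1, 2, 0, 0].set(1.0)
table = allclose_per_step(a, b)
print(table)
```

`jnp.all(allclose_per_step(c_k, gu_c))` then gives a single pass/fail flag for the whole comparison.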


In [27]:
def test_hippo_legs_lti_bi_operator(hippo_legs, gu_hippo_legs, random_input, legs_key):
    x_tensor = torch.tensor(random_input, dtype=torch.float32)
    x_jnp = jnp.asarray(x_tensor, dtype=jnp.float32)  # convert torch array to jax array
    
    # My Implementation
    print(f"-------------------------------------------------------------------------------------")
    print(f"----------------------------My LTI Implementation Outputs----------------------------")
    print(f"-------------------------------------------------------------------------------------")
    params = hippo_legs.init(legs_key, f=x_jnp)
    c_k, y_k_list = hippo_legs.apply(params, f=x_jnp)
    # jax.debug.print(f"c_k: {c_k}")
    
    # Gu's HiPPO LegS
    print(f"-------------------------------------------------------------------------------------")
    print(f"---------------------------Gu's LTI Implementation Outputs---------------------------")
    print(f"-------------------------------------------------------------------------------------")
    x_tensor = torch.moveaxis(x_tensor, 0, 1)
    GU_c_k = gu_hippo_legs(x_tensor, fast=False)
    gu_c = jnp.asarray(GU_c_k, dtype=jnp.float32)  # convert torch array to jax array
    gu_c = jnp.moveaxis(gu_c, 0, 1)
    # jax.debug.print(f"gu_c: {gu_c}")
    
    print(f"-------------------------------------------------------------------------")
    print(f"---------------------------Testing LTI Outputs---------------------------")
    print(f"-------------------------------------------------------------------------")
    print(f"inputted jnp-data shape: {x_jnp.shape}")
    print(f"inputted tensor-data shape: {x_tensor.shape}")
    print(f"c_k shape: {c_k.shape}")
    print(f"gu_c shape: {gu_c.shape}")
    
    for i in range(c_k.shape[0]):
        for j in range(c_k.shape[1]):
            print(
                f"batch {i} on trajectory {j} compare : {jnp.allclose(c_k[i,j,:,:], gu_c[i,j,:,:], rtol=1e-03, atol=1e-03)}"
            )
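`lti_recurrence` discretizes once and then runs `jax.lax.scan` over the time axis, mapped over the batch with `jax.vmap`. Below is a minimal self-contained sketch of that pattern. It assumes the row-vector state convention of `lti_step` (state `(1, N)`, `C` of shape `(N, 1)`), and the helper name `lti_scan` is hypothetical.

```python
import jax
import jax.numpy as jnp


def lti_scan(Ad, Bd, C, D, c0, f):
    """Hypothetical helper: batched LTI recurrence with row-vector states.

    Shapes: Ad (N, N), Bd (N, 1), C (N, 1), D (1,), c0 (batch, 1, N), f (batch, L, 1).
    Returns the final states (batch, 1, N) and, per step, states (batch, L, 1, N) and outputs (batch, L, 1, 1).
    """

    def step(c_prev, f_k):
        c_k = c_prev @ Ad.T + Bd.T * f_k  # c_k = Ad c_{k-1} + Bd f_k, written for row vectors
        y_k = c_k @ C + D * f_k           # read-out through C plus feedthrough D
        return c_k, (c_k, y_k)

    # Each batch element runs its own scan over the L time steps.
    return jax.vmap(lambda c, fs: jax.lax.scan(step, c, fs))(c0, f)


# Toy system: Ad = 0.5 I, Bd = 1, constant input 1, so c_k = 2 - 2^(1 - k).
N, L, batch = 4, 6, 2
c_last, (c_all, y_all) = lti_scan(
    0.5 * jnp.eye(N), jnp.ones((N, 1)), jnp.ones((N, 1)), jnp.zeros((1,)),
    jnp.zeros((batch, 1, N)), jnp.ones((batch, L, 1)),
)
print(c_last.shape, c_all.shape, y_all.shape)
```

Keeping the state as a row vector makes every step a plain `c @ Ad.T` product. This is also why the stacked states come out shaped `(batch, L, 1, N)`.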


### LegS ZOH Tests
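For `GBT_alpha > 1`, `discretize` switches to a zero-order hold of the scale-invariant system between `step` and `step + self.step_size`: $\bar{A}_t = \exp\big(A(\log(t + \Delta) - \log t)\big)$ and $\bar{B}_t = A^{-1}(\bar{A}_t - I)B$. Below is a minimal sketch of one such step. The helper `zoh_lsi_step` is hypothetical; it assumes $A$ is invertible and uses `jnp.linalg.solve` in place of an explicit inverse.

```python
import math

import jax.numpy as jnp
from jax.scipy.linalg import expm


def zoh_lsi_step(A, B, t, dt=1.0):
    """Hypothetical helper: zero-order hold of the scale-invariant system from time t to t + dt.

    Assumes A is invertible (LegS has diagonal -(n + 1), so it is).
    """
    I = jnp.eye(A.shape[0], dtype=A.dtype)
    # Exponentiate over the elapsed log-time, as in the alpha > 1 branch of `discretize`.
    Ad = expm(A * (math.log(t + dt) - math.log(t)))
    Bd = jnp.linalg.solve(A, (Ad - I) @ B)  # A^{-1} (Ad - I) B without forming the inverse
    return Ad, Bd


# Scalar sanity check with a = -1, b = 1.
A = jnp.array([[-1.0]])
B = jnp.array([[1.0]])
Ad1, Bd1 = zoh_lsi_step(A, B, t=1)
Ad2, Bd2 = zoh_lsi_step(A, B, t=2)
print(Ad1, Bd1, Ad2, Bd2)
```

For $a = -1$ and $t = 1$ this gives $\bar{A}_1 = e^{-\log 2} = 0.5$, which is easy to check by hand.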

In [28]:
def test_hippo_legs_lsi_zoh_operator(hippo_legs, gu_hippo_legs, random_input, legs_key):
    x_tensor = torch.tensor(random_input, dtype=torch.float32)
    x_jnp = jnp.asarray(x_tensor, dtype=jnp.float32)  # convert torch array to jax array
    
    # My Implementation
    print(f"-------------------------------------------------------------------------------------")
    print(f"----------------------------My LSI Implementation Outputs----------------------------")
    print(f"-------------------------------------------------------------------------------------")
    params = hippo_legs.init(legs_key, f=x_jnp)
    c_k, y_k_list = hippo_legs.apply(params, f=x_jnp)
    c_k = jnp.moveaxis(c_k, 0, 1)
    
    # Gu's HiPPO LegS
    print(f"-------------------------------------------------------------------------------------")
    print(f"---------------------------Gu's LSI Implementation Outputs---------------------------")
    print(f"-------------------------------------------------------------------------------------")
    x_tensor = torch.moveaxis(x_tensor, 0, 1)
    GU_c_k = gu_hippo_legs(x_tensor, fast=False)
    gu_c = jnp.asarray(GU_c_k, dtype=jnp.float32)  # convert torch array to jax array
    gu_c = jnp.moveaxis(gu_c, 0, 1)
    # jax.debug.print(f"gu_c: {gu_c}")
    
    print(f"-------------------------------------------------------------------------")
    print(f"---------------------------Testing LSI Outputs---------------------------")
    print(f"-------------------------------------------------------------------------")
    print(f"inputted jnp-data shape: {x_jnp.shape}")
    print(f"inputted tensor-data shape: {x_tensor.shape}")
    print(f"c_k shape: {c_k.shape}")
    print(f"gu_c shape: {gu_c.shape}")
    
    for i in range(c_k.shape[0]):
        for j in range(c_k.shape[1]):
            print(
                f"batch {i} on trajectory {j} compare : {jnp.allclose(c_k[i,j,:,:], gu_c[i,j,:,:], rtol=1e-03, atol=1e-03)}"
            )


In [29]:
def test_hippo_legs_lti_zoh_operator(hippo_legs, gu_hippo_legs, random_input, legs_key):
    x_tensor = torch.tensor(random_input, dtype=torch.float32)
    x_jnp = jnp.asarray(x_tensor, dtype=jnp.float32)  # convert torch array to jax array
    
    # My Implementation
    print(f"-------------------------------------------------------------------------------------")
    print(f"----------------------------My LTI Implementation Outputs----------------------------")
    print(f"-------------------------------------------------------------------------------------")
    params = hippo_legs.init(legs_key, f=x_jnp)
    c_k, y_k_list = hippo_legs.apply(params, f=x_jnp)
    # jax.debug.print(f"c_k: {c_k}")
    
    # Gu's HiPPO LegS
    print(f"-------------------------------------------------------------------------------------")
    print(f"---------------------------Gu's LTI Implementation Outputs---------------------------")
    print(f"-------------------------------------------------------------------------------------")
    x_tensor = torch.moveaxis(x_tensor, 0, 1)
    GU_c_k = gu_hippo_legs(x_tensor, fast=False)
    gu_c = jnp.asarray(GU_c_k, dtype=jnp.float32)  # convert torch array to jax array
    gu_c = jnp.moveaxis(gu_c, 0, 1)
    # jax.debug.print(f"gu_c: {gu_c}")
    
    print(f"-------------------------------------------------------------------------")
    print(f"---------------------------Testing LTI Outputs---------------------------")
    print(f"-------------------------------------------------------------------------")
    print(f"inputted jnp-data shape: {x_jnp.shape}")
    print(f"inputted tensor-data shape: {x_tensor.shape}")
    print(f"c_k shape: {c_k.shape}")
    print(f"gu_c shape: {gu_c.shape}")
    
    for i in range(c_k.shape[0]):
        for j in range(c_k.shape[1]):
            print(
                f"batch {i} on trajectory {j} compare : {jnp.allclose(c_k[i,j,:,:], gu_c[i,j,:,:], rtol=1e-03, atol=1e-03)}"
            )

In [30]:
def random_16_input(key_generator, batch_size=16, data_size=784, input_size=28):
    # x = jax.random.randint(key_generator, (batch_size, data_size), 0, 255)
    x = jax.random.uniform(key_generator, (batch_size, data_size))
    return jax.vmap(moving_window, in_axes=(0, None))(x, input_size)


In [31]:
def test_LSI_GBT(hippo, gu_hippo, A, B, random_input, alpha=0.5):
    L = random_input.shape[1]
    for i in range(1, L+1):
        GBT_A, GBT_B = hippo.discretize(A, B, step=i, alpha=alpha, dtype=jnp.float32)
        gu_GBT_A, gu_GBT_B = (
            jnp.asarray(gu_hippo.A_stacked[i-1], dtype=jnp.float32),
            jnp.expand_dims(jnp.asarray(gu_hippo.B_stacked[i-1], dtype=jnp.float32), axis=1),
        )
        
        print(f"GBT_A: {jnp.allclose(GBT_A, gu_GBT_A, rtol=1e-05, atol=1e-05)}")
        print(f"GBT_B: {jnp.allclose(GBT_B, gu_GBT_B, rtol=1e-05, atol=1e-05)}\n")
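`test_LSI_GBT` compares one `discretize` call per time step against Gu's `A_stacked`/`B_stacked`. For $\alpha \le 1$, the step-$t$ matrices are simply the GBT with effective step $1/t$, so the whole stack can also be built in one call with `jax.vmap`. Below is a sketch under that assumption. The helper `stacked_lsi_gbt` is hypothetical and does not cover the ZOH branch.

```python
import jax
import jax.numpy as jnp


def stacked_lsi_gbt(A, B, L, alpha=0.5):
    """Hypothetical helper: GBT matrices for t = 1..L with effective step 1/t.

    Returns arrays of shape (L, N, N) and (L, N, 1). Only valid for 0 <= alpha <= 1 (no ZOH branch).
    """
    I = jnp.eye(A.shape[0], dtype=A.dtype)

    def one_step(t):
        dt = 1.0 / t
        lhs = I - dt * alpha * A
        Ad = jnp.linalg.solve(lhs, I + dt * (1.0 - alpha) * A)
        Bd = jnp.linalg.solve(lhs, dt * B)
        return Ad, Bd

    ts = jnp.arange(1, L + 1, dtype=A.dtype)
    return jax.vmap(one_step)(ts)


A = jnp.array([[-1.0]])
B = jnp.array([[1.0]])
Ads, Bds = stacked_lsi_gbt(A, B, L=3)
print(Ads[:, 0, 0], Bds[:, 0, 0])
```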

In [32]:
def test_LTI_GBT(hippo, gu_hippo, A, B, random_input, alpha=0.5):
    L = random_input.shape[1]
    GBT_A, GBT_B = hippo.discretize(A, B, step=1.0, alpha=alpha, dtype=jnp.float32)
    gu_GBT_A, gu_GBT_B = (
        jnp.asarray(gu_hippo.dA, dtype=jnp.float32),
        jnp.expand_dims(jnp.asarray(gu_hippo.dB, dtype=jnp.float32), axis=1),
    )
    print(f"gu_GBT_A shape:{gu_GBT_A.shape}\n")
    print(f"GBT_A shape: {GBT_A.shape}\n")
    print(f"gu_GBT_B shape: {gu_GBT_B.shape}\n")
    print(f"GBT_B shape: {GBT_B.shape}")
    
    print(f"gu_GBT_A:\n{gu_GBT_A}\n")
    print(f"GBT_A:\n{GBT_A}\n")
    print(f"gu_GBT_B:\n{gu_GBT_B}\n")
    print(f"GBT_B:\n{GBT_B}")
    
    print(f"GBT_A: {jnp.allclose(GBT_A, gu_GBT_A, rtol=1e-05, atol=1e-05)}")
    print(f"GBT_B: {jnp.allclose(GBT_B, gu_GBT_B, rtol=1e-05, atol=1e-05)}\n")
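The ZOH reference matrices printed in the Output section are consistent with the standard LTI zero-order hold $\bar{A} = e^{A\Delta}$, $\bar{B} = A^{-1}(e^{A\Delta} - I)B$ with $\Delta = 1$. Because the LegS $A$ is lower triangular with diagonal $-(n+1)$, the leading diagonal entries are $e^{-1} \approx 0.3679$ and $e^{-2} \approx 0.1353$.

The sketch below rebuilds the LegS matrices from their closed form and applies that ZOH. The closed form is $A_{nk} = -\sqrt{2n+1}\sqrt{2k+1}$ for $n > k$, $A_{nn} = -(n+1)$, and $B_n = \sqrt{2n+1}$. `legs_matrices` and `zoh_lti` are hypothetical helpers, not the repository's `TransMatrix`.

```python
import jax.numpy as jnp
from jax.scipy.linalg import expm, solve_triangular


def legs_matrices(N):
    """Hypothetical helper: LegS transition matrices built from their closed form."""
    q = jnp.arange(N, dtype=jnp.float32)
    r = jnp.sqrt(2.0 * q + 1.0)
    # Strictly lower part: -sqrt(2n+1) sqrt(2k+1); diagonal: -(n+1); upper part: 0.
    A = jnp.tril(-jnp.outer(r, r), k=-1) - jnp.diag(q + 1.0)
    B = r[:, None]  # shape (N, 1)
    return A, B


def zoh_lti(A, B, dt=1.0):
    """Zero-order hold for a lower-triangular, invertible A: Ad = exp(A dt), Bd = A^{-1} (Ad - I) B."""
    I = jnp.eye(A.shape[0], dtype=A.dtype)
    Ad = expm(A * dt)
    Bd = solve_triangular(A, (Ad - I) @ B, lower=True)
    return Ad, Bd


A, B = legs_matrices(32)  # N = 32, as in `test()`
Ad, Bd = zoh_lti(A, B, dt=1.0)
print(Ad[:2, :2])
print(Bd[:2])
```

`solve_triangular` exploits the lower-triangular structure of the LegS $A$, so it avoids the explicit `jnp.linalg.inv(A)` used in `discretize`.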

In [33]:
def test_hippos(matrices, gu_matrices):
    A = matrices.A
    B = matrices.B
    gu_A = gu_matrices.A
    gu_B = gu_matrices.B
    
    # print(f"A shape: {A.shape}")
    # print(f"B shape: {B.shape}")
    # print(f"gu_A shape: {gu_A.shape}")
    # print(f"gu_B shape: {gu_B.shape}")
    
    # print(f"gu_A:\n{gu_A}\n")
    # print(f"A:\n{A}\n")
    # print(f"gu_B:\n{gu_B}\n")
    # print(f"B:\n{B}")
    
    print(f"A: {jnp.allclose(A, gu_A, rtol=1e-05, atol=1e-05)}")
    print(f"B: {jnp.allclose(B, gu_B, rtol=1e-05, atol=1e-05)}\n")

In [34]:
def test():
    # N = 256
    # L = 128
    the_measure = "legs"
    
    batch_size = 2
    data_size = 16
    input_size = 1
    
    N = 32
    L = data_size
    
    x_jnp = random_16_input(
        key_generator=key3, 
        batch_size=batch_size, 
        data_size=data_size, 
        input_size=input_size
    )
    x_np = np.asarray(x_jnp)

    x = torch.tensor(x_np, dtype=torch.float32)

    print(f"X shape: {x.shape}")

    # ----------------------------------------------------------------------------------
    loss = nn.MSELoss()

    # ----------------------------------------------------------------------------------
    # ------------------------------ Test HiPPO LegT model -----------------------------
    # ----------------------------------------------------------------------------------
    print("\nTesting HiPPO LegT model")
    # hippo_legt = HiPPO_LTI(N, dt=1.0 / L)

    # c_k = hippo_legt(x)

    # print(f"Gu's Coefficients for LegT:\n{c_k}")
    # print(f"Gu's Coefficient shapes for LegT:\n{c_k.shape}")

    # z = hippo_legt.reconstruct(c_k)
    # print(f"Gu's Reconstruction for LegT:\n{z}")
    # print(f"Gu's Reconstruction shape for LegT:\n{z.shape}")

    # mse = loss(z[-1, 0, :L], x.squeeze(-1))
    # print(f"h-MSE shape:\n{mse}")
    # print(f"end of test for HiPPO LegT model")

    # ----------------------------------------------------------------------------------
    # ------------------------------ Test HiPPO LegS model -----------------------------
    # ----------------------------------------------------------------------------------
    print("\nTesting HiPPO LegS model")
    print(f"Creating Gu's HiPPO-{the_measure} LTI model with bidirectional transform")
    gu_hippo_legs_lti_bi = HiPPO_LTI(
        N=N,
        method="legs",
        dt=1.0,
        T=1.0,
        discretization=0.5,
        lambda_n=1.0,
        alpha=0.0,
        beta=1.0,
        c=0.0,
    )  # The Gu's
    
    print(f"Creating Gu's HiPPO-{the_measure} LSI model with bidirectional transform")
    gu_hippo_legs_lsi_bi = HiPPO_LSI(
        N=N,
        method="legs",
        max_length=L,
        discretization=0.5,
        lambda_n=1.0,
        alpha=0.0,
        beta=1.0
    )  # The Gu's
    
    print(f"Creating Gu's HiPPO-{the_measure} LTI model with ZOH transform")
    gu_hippo_legs_lti_zoh = HiPPO_LTI(
        N=N,
        method="legs",
        dt=1.0,
        T=1.0,
        discretization="zoh",
        lambda_n=1.0,
        alpha=0.0,
        beta=1.0,
        c=0.0,
    )  # The Gu's
    
    print(f"Creating Gu's HiPPO-{the_measure} LSI model with ZOH transform")
    gu_hippo_legs_lsi_zoh = HiPPO_LSI(
        N=N,
        method="legs",
        max_length=L,
        discretization="zoh",
        lambda_n=1.0,
        alpha=0.0,
        beta=1.0
    )  # The Gu's
    
    # print(f"gu_hippo_legs A_stacked: {gu_hippo_legs.A_stacked}")
    # c_k = gu_hippo_legs(x)

    # print(f"Gu's Coefficients for LegS:\n{c_k}")
    # print(f"Gu's Coefficient shapes for LegS:\n{c_k.shape}")

    # z = hippo_legs.reconstruct(c_k)

    # print(f"Gu's Reconstruction for LegS:\n{z}")
    # print(f"Gu's Reconstruction shape for LegS:\n{z.shape}")

    # print(y-z)
    print(f"end of test for HiPPO LegS model")

    # ----------------------------------------------------------------------------------
    # ------------------------------ Test Generic HiPPO model --------------------------
    # ----------------------------------------------------------------------------------
    print(f"\nTesting BRYANS HiPPO-{the_measure} model")
    matrices = TransMatrix(
            N=N,
            measure="legs",
            lambda_n=1.0,
            alpha=0.0,
            beta=1.0,
            dtype=jnp.float32,
        )
    
    gu_matrices = GuTransMatrix(
            N=N,
            measure="legs",
            lambda_n=1.0,
            alpha=0.0,
            beta=1.0,
        )
    
    print(f"\nTesting BRYANS HiPPO-{the_measure} model")
    test_hippos(matrices, gu_matrices)

    A = matrices.A
    B = matrices.B
    
    print(f"Creating HiPPO-{the_measure} LTI model with bidirectional transform")    
    hippo_legs_lti_bi = HiPPO(
        max_length=L,
        step_size=1.0,
        N=N,
        lambda_n=1.0,
        alpha=0.0,
        beta=1.0,
        GBT_alpha=0.5,
        measure="legs",
        s_t="lti",
        dtype=jnp.float32,
        verbose=True,
    )  # Bryan's
    
    print(f"Creating HiPPO-{the_measure} LSI model with bidirectional transform")
    hippo_legs_lsi_bi = HiPPO(
        max_length=L,
        step_size=1.0,
        N=N,
        lambda_n=1.0,
        alpha=0.0,
        beta=1.0,
        GBT_alpha=0.5,
        measure="legs",
        s_t="lsi",
        dtype=jnp.float32,
        verbose=True,
    )  # Bryan's
    
    print(f"Creating HiPPO-{the_measure} LTI model with ZOH transform")
    hippo_legs_lti_zoh = HiPPO(
        max_length=L,
        step_size=1.0,
        N=N,
        lambda_n=1.0,
        alpha=0.0,
        beta=1.0,
        GBT_alpha=2.0,
        measure="legs",
        s_t="lti",
        dtype=jnp.float32,
        verbose=True,
    )  # Bryan's
    
    print(f"Creating HiPPO-{the_measure} LSI model with ZOH transform")
    hippo_legs_lsi_zoh = HiPPO(
        max_length=L,
        step_size=1.0,
        N=N,
        lambda_n=1.0,
        alpha=0.0,
        beta=1.0,
        GBT_alpha=2.0,
        measure="legs",
        s_t="lsi",
        dtype=jnp.float32,
        verbose=True,
    )  # Bryan's
    
    # print(f"Testing for correct LTI GBT matrices for HiPPO-{the_measure}")
    # test_LTI_GBT(
    #     hippo=hippo_legs_lti_bi, 
    #     gu_hippo=gu_hippo_legs_lti_bi, 
    #     A=A, 
    #     B=B, 
    #     random_input=x_np,
    #     alpha=0.5
    # )
    
    # print(f"Testing for correct LSI GBT matrices for HiPPO-{the_measure}")
    # test_LSI_GBT(
    #     hippo=hippo_legs_lsi_bi, 
    #     gu_hippo=gu_hippo_legs_lsi_bi, 
    #     A=A, 
    #     B=B, 
    #     random_input=x_np,
    #     alpha=0.5
    # )
    
    # print(f"Bryan's Coefficients for Bilinear LSI HiPPO-{the_measure}")
    
    # test_hippo_legs_lsi_bi_operator(
    #     hippo_legs=hippo_legs_lsi_bi, 
    #     gu_hippo_legs=gu_hippo_legs_lsi_bi, 
    #     random_input=x_np, 
    #     legs_key=key2
    # )
    
    # print(f"\n\nBryan's Coefficients for Bilinear LTI HiPPO-{the_measure}")
    # test_hippo_legs_lti_bi_operator(
    #     hippo_legs=hippo_legs_lti_bi, 
    #     gu_hippo_legs=gu_hippo_legs_lti_bi, 
    #     random_input=x_np, 
    #     legs_key=key2
    # )
    
    print(f"Testing for correct LTI GBT matrices for HiPPO-{the_measure} w/ ZOH")
    test_LTI_GBT(
        hippo=hippo_legs_lti_zoh, 
        gu_hippo=gu_hippo_legs_lti_zoh, 
        A=A, 
        B=B, 
        random_input=x_np,
        alpha=2.0
    )
    
    print(f"Testing for correct LSI GBT matrices for HiPPO-{the_measure} w/ ZOH")
    test_LSI_GBT(
        hippo=hippo_legs_lsi_zoh, 
        gu_hippo=gu_hippo_legs_lsi_zoh, 
        A=A, 
        B=B, 
        random_input=x_np,
        alpha=2.0
    )
    
    print(f"Bryan's Coefficients for ZOH LSI HiPPO-{the_measure}")
    
    test_hippo_legs_lsi_zoh_operator(
        hippo_legs=hippo_legs_lsi_zoh, 
        gu_hippo_legs=gu_hippo_legs_lsi_zoh, 
        random_input=x_np, 
        legs_key=key2
    )
    
    print(f"\n\nBryan's Coefficients for ZOH LTI HiPPO-{the_measure}")
    test_hippo_legs_lti_zoh_operator(
        hippo_legs=hippo_legs_lti_zoh, 
        gu_hippo_legs=gu_hippo_legs_lti_zoh, 
        random_input=x_np, 
        legs_key=key2
    )
    
    # y_legs = hippo_LegS_B.apply(
    #     {"params": params}, c_k, method=hippo_LegS_B.reconstruct
    # )

    # print(f"Bryan's Reconstruction for HiPPO-{the_measure}:\n{y_legs}")
    # print(f"Bryan's Reconstruction shape for HiPPO-{the_measure}:\n{y_legs.shape}")

    print(f"end of test for HiPPO-{the_measure} model")

## Output

In [35]:
test()


X shape: torch.Size([2, 16, 1])

Testing HiPPO LegT model

Testing HiPPO LegS model
Creating Gu's HiPPO-legs LTI model with bidirectional transform
Creating Gu's HiPPO-legs LSI model with bidirectional transform
Creating Gu's HiPPO-legs LTI model with ZOH transform
Creating Gu's HiPPO-legs LSI model with ZOH transform
end of test for HiPPO LegS model

Testing BRYANS HiPPO-legs model

Testing BRYANS HiPPO-legs model
A: True
B: True

Creating HiPPO-legs LTI model with bidirectional transform
Creating HiPPO-legs LSI model with bidirectional transform
Creating HiPPO-legs LTI model with ZOH transform
Creating HiPPO-legs LSI model with ZOH transform
Testing for correct LTI GBT matrices for HiPPO-legs w/ ZOH
gu_GBT_A shape:(32, 32)

GBT_A shape: (32, 32)

gu_GBT_B shape: (32, 1)

GBT_B shape: (32, 1)
gu_GBT_A:
[[ 3.6787945e-01  0.0000000e+00  0.0000000e+00 ...  0.0000000e+00  0.0000000e+00  0.0000000e+00]
 [-4.0277830e-01  1.3533530e-01  0.0000000e+00 ...  0.0000000e+00  0.0000000e+00  0.0000