# HiPPO Matrices
---

## Table of Contents
* [Loading In Necessary Packages](#load-packages)
* [Instantiate The HiPPO Matrix](#instantiate-the-hippo-matrix)
    * [Translated Legendre (LegT)](#translated-legendre-legt)
        * [LegT](#legt)
        * [LMU](#lmu)
    * [Translated Laguerre (LagT)](#translated-laguerre-lagt)
    * [Scaled Legendre (LegS)](#scaled-legendre-legs)
    * [Fourier Basis](#fourier-basis)
        * [Fourier Recurrent Unit (FRU)](#fourier-recurrent-unit-fru)
        * [Truncated Fourier (FouT)](#truncated-fourier-fout)
        * [Fourier With Decay (FourD)](#fourier-with-decay-fourd)
* [Gu's Linear Time Invariant (LTI) HiPPO Operator](#gus-hippo-legt-operator)
* [Gu's Scale Invariant (LSI) HiPPO Operator](#gus-scale-invariant-hippo-legs-operator)
* [Implementation Of General HiPPO Operator](#implementation-of-general-hippo-operator)
* [Test Generalized Bilinear Transform and Zero Order Hold Matrices](#test-generalized-bilinear-transform-and-zero-order-hold-matrices)
    * [Testing Forward Euler on GBT matrices](#testing-forward-euler-transform-for-lti-and-lsi)
    * [Testing Backward Euler on GBT matrices](#testing-backward-euler-transform-for-lti-and-lsi-on-legs-matrices)
    * [Testing Bidirectional on GBT matrices](#testing-lti-and-lsi-operators-with-bidirectional-transform)
    * [Testing ZOH on GBT matrices](#testing-zoh-transform-for-lti-and-lsi-on-legs-matrices)
* [Testing HiPPO Operators](#test-hippo-operators)
    * [Testing Forward Euler on HiPPO Operators](#testing-lti-and-lsi-operators-with-forward-euler-transform)
    * [Testing Backward Euler on HiPPO Operators](#testing-lti-and-lsi-operators-with-backward-euler-transform)
    * [Testing Bidirectional on HiPPO Operators](#testing-lti-and-lsi-operators-with-bidirectional-transform)
    * [Testing ZOH on HiPPO Operators](#testing-lti-and-lsi-operators-with-zoh-transform)
---


## Load Packages

In [1]:
import os
import sys

# Make the repository root importable. os.path.abspath resolves the relative
# path against the current working directory, so this is three levels above
# the CWD (not necessarily above this notebook file).
module_path = os.path.abspath(os.path.join(*[os.pardir] * 3))
print(f"module_path: {module_path}")
if module_path not in sys.path:
    print(f"Adding {module_path} to sys.path")
    sys.path.append(module_path)

module_path: /home/beegass/Documents/Coding/hippo-jax (copy)


In [2]:
# Optional, left disabled: JAX's GPU memory docs describe setting
# XLA_PYTHON_CLIENT_PREALLOCATE=false to turn off up-front preallocation.
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False"

# Set before `import jax` in the next cell. NOTE(review): presumably lets XLA's
# GPU allocator use CUDA unified memory — confirm against the XLA/JAX docs.
os.environ.update({"TF_FORCE_UNIFIED_MEMORY": "1"})

In [3]:
## import packages
import jax
import jax.numpy as jnp
from jax._src.scipy.special import _gen_associated_legendre, gammaln
import einops
from scipy import special as ss
import numpy as np
import torch
import sys
from functools import partial

In [4]:
# Raise Python's recursion limit to one million frames. NOTE(review): presumably
# needed by deeply recursive code later in the notebook — confirm. Per the
# sys.setrecursionlimit docs, a limit this high can crash the interpreter
# (C stack overflow) instead of raising RecursionError.
sys.setrecursionlimit(1_000_000)

In [5]:
# Report the devices JAX can see and the platform of its default backend.
print(jax.devices())
# jax.lib.xla_bridge.get_backend() is deprecated in recent JAX releases;
# jax.default_backend() returns the same platform name ("cpu", "gpu", ...).
print(f"The Device: {jax.default_backend()}")

[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]
The Device: gpu


In [6]:
# Report whether PyTorch's Apple Metal (MPS) backend is available here.
print("MPS enabled:", torch.backends.mps.is_available())

MPS enabled: False


In [7]:
# Wrap printed tensors/arrays at 150 characters for PyTorch, NumPy and jax.numpy.
for _set_opts in (torch.set_printoptions, np.set_printoptions, jnp.set_printoptions):
    _set_opts(linewidth=150)
del _set_opts

In [8]:
# Fixed seed so JAX random draws are reproducible across runs.
seed: int = 1701
key = jax.random.PRNGKey(seed)  # legacy uint32 key array derived from `seed`

In [9]:
# Seed NumPy's global (legacy) RNG with the same seed used for JAX.
np.random.seed(seed=seed)

In [10]:
# Split the current key into `num_copies` subkeys and continue with the first.
# NOTE(review): subkeys[0] now doubles as `key`; avoid also consuming
# subkeys[0] directly elsewhere, or the same randomness is reused.
num_copies = 10
subkeys = jax.random.split(key, num_copies)
key = subkeys[0]

In [None]:
# @jax.jit
# def genlaguerre_recurrence(n, alpha, x, max_n):
#     """
#     Computes the generalized Laguerre polynomial of degree n with parameter alpha at point x using the recurrence relation.

#     Args:
#     n: int, the degree of the generalized Laguerre polynomial.
#     alpha: float, the parameter of the generalized Laguerre polynomial.
#     x: float, the point at which to evaluate the polynomial.
#     max_n: int, the maximum degree of n in the batch.

#     Returns:
#     The value of the generalized Laguerre polynomial of degree n with parameter alpha at point x.
#     """
#     # Initialize the array to store the generalized Laguerre polynomials for all degrees from 0 to max_n
#     p = jnp.zeros((max_n + 1,) + x.shape)
#     p = p.at[0].set(1.0)  # Set the 0th degree generalized Laguerre polynomial

#     # Compute the generalized Laguerre polynomials for degrees 1 to max_n using the recurrence relation
#     def body_fun(i, p):
#         p_i = ((2 * i + alpha - 1 - x) * p[i - 1] - (i + alpha - 1) * p[i - 2]) / i
#         return p.at[i].set(p_i)

#     p = jax.lax.fori_loop(1, max_n + 1, body_fun, p)

#     return p[n]

In [None]:
# @jax.jit
# def eval_genlaguerre(n, alpha, x, out=None):
#     """
#     Evaluates the generalized Laguerre polynomials of degrees specified in the input array n with parameter alpha at the points specified in the input array x.

#     Args:
#     n: array-like, the degrees of the generalized Laguerre polynomials.
#     alpha: float, the parameter of the generalized Laguerre polynomials.
#     x: array-like, the points at which to evaluate the polynomials.
#     out: optional, an output array to store the results.

#     Returns:
#     An array containing the generalized Laguerre polynomial values of the specified degrees with parameter alpha at the specified points.
#     """
#     n = jnp.asarray(n)
#     x = jnp.asarray(x)
#     max_n = n.max()

#     if n.ndim == 1 and x.ndim == 1:
#         p = jax.vmap(
#             lambda ni: jax.vmap(
#                 lambda xi: genlaguerre_recurrence(ni, alpha, xi, max_n)
#             )(x)
#         )(n)
#         p = jnp.diagonal(
#             p
#         )  # Get the diagonal elements to match the scipy.special.eval_genlaguerre output
#     else:
#         p = jax.vmap(
#             lambda ni: jax.vmap(
#                 lambda xi: genlaguerre_recurrence(ni, alpha, xi, max_n)
#             )(x)
#         )(n)

#     if out is not None:
#         out = jnp.asarray(out)
#         out = jnp.copy(p, out=out)
#         return out
#     else:
#         return jnp.squeeze(p)

In [None]:
def genlaguerre_recurrence(n, alpha, x):
    """
    Computes the generalized Laguerre polynomial of degree n with parameter alpha at point x using the recurrence relation.

    The three-term recurrence used is

        L_0(x) = 1
        L_1(x) = 1 + alpha - x
        (k + 1) * L_{k+1}(x) = (2k + 1 + alpha - x) * L_k(x) - (k + alpha) * L_{k-1}(x)

    Args:
        n: integer degree. May be a Python int or an integer JAX scalar,
            including a traced value (e.g. inside ``jax.vmap``); the number of
            recurrence steps is then resolved at run time. Negative degrees
            evaluate to 0, matching ``scipy.special.eval_genlaguerre``.
        alpha: parameter of the polynomial.
        x: point(s) at which to evaluate; integer inputs are promoted to float.

    Returns:
        L_n^(alpha)(x), with the same shape as ``x``.
    """
    # One floating dtype for every loop-carried value so the carry types stay
    # consistent between iterations.
    dtype = jnp.result_type(x, alpha, 1.0)
    x = jnp.asarray(x, dtype=dtype)
    l_0 = jnp.ones_like(x)
    l_1 = jnp.asarray(1 + alpha - x, dtype=dtype)

    def body_fun(k, carry):
        # carry holds (L_{k-1}, L_k); advance it to (L_k, L_{k+1}).
        l_km1, l_k = carry
        l_kp1 = ((2 * k + 1 + alpha - x) * l_k - (k + alpha) * l_km1) / (k + 1)
        return l_k, l_kp1.astype(dtype)

    # Steps k = 1 .. n-1 turn (L_0, L_1) into (L_{n-1}, L_n). Clamping the upper
    # bound to 1 makes the loop empty for n <= 1 (never a negative trip count).
    # With a traced n, fori_loop lowers to a while_loop, which batches under vmap.
    _, l_n = jax.lax.fori_loop(1, jnp.maximum(n, 1), body_fun, (l_0, l_1))

    return jnp.where(n > 0, l_n, jnp.where(n == 0, l_0, jnp.zeros_like(l_n)))

In [None]:
def eval_genlaguerre(n, alpha, x):
    """
    Evaluates the generalized Laguerre polynomials of degrees specified in the input array n with parameter alpha at the points specified in the input array x.

    ``n`` and ``x`` are broadcast against each other and evaluated elementwise,
    like ``scipy.special.eval_genlaguerre``, so scalars and mixed shapes are
    accepted.

    Args:
        n: array-like of integer degrees.
        alpha: parameter of the polynomials.
        x: array-like of evaluation points.

    Returns:
        Array with the broadcast shape of ``n`` and ``x``. Length-1 axes are
        kept (they are not squeezed away), matching SciPy.
    """
    n, x = jnp.broadcast_arrays(jnp.asarray(n), jnp.asarray(x))
    shape = n.shape

    if n.size == 0:
        # Nothing to evaluate; avoid mapping over an empty axis.
        return jnp.zeros(shape, dtype=jnp.result_type(x, alpha, 1.0))

    # vmap maps over a single axis, so flatten the broadcast inputs, evaluate
    # each (degree, point) pair, then restore the broadcast shape.
    p = jax.vmap(lambda ni, xi: genlaguerre_recurrence(ni, alpha, xi))(
        n.ravel(), x.ravel()
    )
    return p.reshape(shape)

In [13]:
def test_eval_genlaguerre():
    """Compare ``eval_genlaguerre`` with SciPy for degrees 0-9, alpha = 2, on points in [-1, 1]."""

    def _report(label, value):
        # Echo a value and its shape as "<label> = ..." / "<label> shape = ...".
        print(f"{label} = {value}")
        print(f"{label} shape = {value.shape}")

    alpha = 2.0
    n = np.arange(10)
    _report("n", n)

    # One evaluation point per degree, evenly spaced over [-1, 1].
    x = np.linspace(-1, 1, len(n))
    _report("x", x)

    y_pred = eval_genlaguerre(n, alpha, x)
    y = jnp.array(ss.eval_genlaguerre(n, alpha, x))
    _report("y_pred", y_pred)
    _report("y", y)

    matches = np.allclose(y_pred, y, rtol=1e-5, atol=1e-8)
    assert matches, "Results do not match"
    print("Results match")

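I also implemented the generalized Laguerre evaluation with the earlier approach shown below, kept for reference. It precomputes a table of every degree up to `max_n` with `jax.lax.fori_loop` and then indexes into it: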

```python
def genlaguerre_recurrence(n, alpha, x, max_n):
    """
    Computes the generalized Laguerre polynomial of degree n with parameter alpha at point x using the recurrence relation.

    Every degree from 0 to ``max_n`` is written into a table using
        i * L_i = (2i + alpha - 1 - x) * L_{i-1} - (i + alpha - 1) * L_{i-2},
    and row ``n`` of that table is returned.

    Args:
    n: int, the degree of the generalized Laguerre polynomial. It is used as a
        row index into the table, so it should satisfy 0 <= n <= max_n
        (NOTE(review): a negative n would index from the end of the table
        instead of failing).
    alpha: float, the parameter of the generalized Laguerre polynomial.
    x: float, the point at which to evaluate the polynomial. It must be an
        array (``x.shape`` is read); a JAX scalar, such as an element mapped
        by ``jax.vmap``, works.
    max_n: int, the maximum degree of n in the batch. It sizes the table via
        ``jnp.zeros``, so it must be a concrete (non-traced) value.

    Returns:
    The value of the generalized Laguerre polynomial of degree n with parameter alpha at point x.
    """
    # Initialize the array to store the generalized Laguerre polynomials for all degrees from 0 to max_n
    p = jnp.zeros((max_n + 1,) + x.shape)
    p = p.at[0].set(1.0)  # Set the 0th degree generalized Laguerre polynomial

    # Compute the generalized Laguerre polynomials for degrees 1 to max_n using the recurrence relation
    def body_fun(i, p):
        # At i == 1, p[i - 2] is p[-1], which wraps to the last row p[max_n].
        # Rows are filled in increasing order, so that row is still 0 here and
        # the L_{-1} term drops out, giving p[1] = 1 + alpha - x.
        p_i = ((2 * i + alpha - 1 - x) * p[i - 1] - (i + alpha - 1) * p[i - 2]) / i
        return p.at[i].set(p_i)

    # Fill rows 1..max_n in order; each step reads the two rows filled before it.
    p = jax.lax.fori_loop(1, max_n + 1, body_fun, p)

    return p[n]
```

```python
def eval_genlaguerre(n, alpha, x, out=None):
    """
    Evaluates the generalized Laguerre polynomials of degrees specified in the input array n with parameter alpha at the points specified in the input array x.

    Args:
    n: array-like, the degrees of the generalized Laguerre polynomials.
    alpha: float, the parameter of the generalized Laguerre polynomials.
    x: array-like, the points at which to evaluate the polynomials.
    out: optional array-like. JAX arrays are immutable, so ``out`` is not
        written in place; instead, the result is returned cast to ``out``'s
        dtype and broadcast to ``out``'s shape.

    Returns:
    An array containing the generalized Laguerre polynomial values of the specified degrees with parameter alpha at the specified points.
    For 1-D ``n`` and ``x``, the two are broadcast against each other and
    paired elementwise, matching ``scipy.special.eval_genlaguerre``. Otherwise,
    a nested vmap evaluates every entry along ``n``'s leading axis at every
    entry along ``x``'s leading axis.
    """
    n = jnp.asarray(n)
    x = jnp.asarray(x)
    # Table size for genlaguerre_recurrence; concrete because n is not traced here.
    max_n = n.max()

    def _at(ni, xi):
        # Evaluate one degree (or row of degrees) at one point (or row of points).
        return genlaguerre_recurrence(ni, alpha, xi, max_n)

    if n.ndim == 1 and x.ndim == 1:
        # Pair degrees with points elementwise in a single vmap. This avoids
        # building the full len(n) x len(x) table just to keep its diagonal,
        # and it handles a length-1 operand the way SciPy does.
        n_b, x_b = jnp.broadcast_arrays(n, x)
        p = jax.vmap(_at)(n_b, x_b)
    else:
        p = jax.vmap(lambda ni: jax.vmap(lambda xi: _at(ni, xi))(x))(n)

    if out is not None:
        # jnp.copy has no ``out`` argument and JAX arrays cannot be modified in
        # place, so return the result in ``out``'s dtype and shape instead.
        out = jnp.asarray(out)
        return jnp.broadcast_to(p.astype(out.dtype), out.shape)
    return p
```

```python
def test_eval_genlaguerre():
    """Compare ``eval_genlaguerre`` with SciPy for degrees 0-3, alpha = 2, on points in [-1, 1]."""
    alpha = 2.0
    n = np.array([0, 1, 2, 3])
    x = np.linspace(-1, 1, n.size)

    for label, arr in (("n", n), ("x", x)):
        print(f"{label} = {arr}")
        print(f"{label} shape = {arr.shape}")

    y_pred = eval_genlaguerre(n, alpha, x)
    y = jnp.array(ss.eval_genlaguerre(n, alpha, x))

    for label, arr in (("y_pred", y_pred), ("y", y)):
        print(f"{label} = {arr}")
        print(f"{label} shape = {arr.shape}")

    assert np.allclose(y_pred, y, rtol=1e-5, atol=1e-8), "Results do not match"
    print("Results match")
```

Output of `test_eval_genlaguerre()` with this earlier version:
```bash
n = [0 1 2 3]
n shape = (4,)
x = [-1.         -0.33333333  0.33333333  1.        ]
x shape = (4,)
y_pred = [1.        3.3333333 4.7222223 2.3333335]
y_pred shape = (4,)
y = [1.        3.3333333 4.7222223 2.3333333]
y shape = (4,)
Results match
```

In [14]:
# Run the SciPy comparison defined above. It prints the inputs and outputs and
# asserts that they agree with scipy.special.eval_genlaguerre.
# NOTE(review): the captured output below shows n = [0 1 2 3], which matches an
# earlier version of the test. The current definition uses np.arange(10), so
# re-run this cell to refresh the output.
test_eval_genlaguerre()

n = [0 1 2 3]
n shape = (4,)
x = [-1.         -0.33333333  0.33333333  1.        ]
x shape = (4,)
y_pred = [1.        3.3333333 4.7222223 2.3333335]
y_pred shape = (4,)
y = [1.        3.3333333 4.7222223 2.3333333]
y shape = (4,)
Results match
