# HiPPO Matrices
---

## Table of Contents
* [Loading In Necessary Packages](#load-packages)
* [Instantiate The HiPPO Matrix](#instantiate-the-hippo-matrix)
    * [Translated Legendre (LegT)](#translated-legendre-legt)
        * [LegT](#legt)
        * [LMU](#lmu)
    * [Translated Laguerre (LagT)](#translated-laguerre-lagt)
    * [Scaled Legendre (LegS)](#scaled-legendre-legs)
    * [Fourier Basis](#fourier-basis)
        * [Fourier Recurrent Unit (FRU)](#fourier-recurrent-unit-fru)
        * [Truncated Fourier (FouT)](#truncated-fourier-fout)
        * [Fourier With Decay (FourD)](#fourier-with-decay-fourd)
* [Gu's Linear Time-Invariant (LTI) HiPPO Operator](#gus-hippo-legt-operator)
* [Gu's Linear Scale-Invariant (LSI) HiPPO Operator](#gus-scale-invariant-hippo-legs-operator)
* [Implementation Of General HiPPO Operator](#implementation-of-general-hippo-operator)
* [Test Generalized Bilinear Transform and Zero Order Hold Matrices](#test-generalized-bilinear-transform-and-zero-order-hold-matrices)
    * [Testing Forward Euler on GBT matrices](#testing-forward-euler-transform-for-lti-and-lsi)
    * [Testing Backward Euler on GBT matrices](#testing-backward-euler-transform-for-lti-and-lsi-on-legs-matrices)
    * [Testing Bidirectional on GBT matrices](#testing-lti-and-lsi-operators-with-bidirectional-transform)
    * [Testing ZOH on GBT matrices](#testing-zoh-transform-for-lti-and-lsi-on-legs-matrices)
* [Testing HiPPO Operators](#test-hippo-operators)
    * [Testing Forward Euler on HiPPO Operators](#testing-lti-and-lsi-operators-with-forward-euler-transform)
    * [Testing Backward Euler on HiPPO Operators](#testing-lti-and-lsi-operators-with-backward-euler-transform)
    * [Testing Bidirectional on HiPPO Operators](#testing-lti-and-lsi-operators-with-bidirectional-transform)
    * [Testing ZOH on HiPPO Operators](#testing-lti-and-lsi-operators-with-zoh-transform)
---


## Load Packages

In [1]:
import os
import sys

# Put the repository root (three levels above the current working directory)
# on sys.path so project modules can be imported from this notebook.
module_path = os.path.abspath(os.path.join(os.pardir, os.pardir, os.pardir))
print(f"module_path: {module_path}")
needs_registration = module_path not in sys.path
if needs_registration:
    print(f"Adding {module_path} to sys.path")
    sys.path.append(module_path)

module_path: /home/beegass/Documents/Coding/HiPPO-Jax


In [2]:
# XLA/JAX runtime configuration via environment variables. These are presumably
# read when the JAX backend initializes, so set them before any device is used.
# Disabled alternative (kept for reference): turn off GPU memory preallocation.
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False"
os.environ.update({"TF_FORCE_UNIFIED_MEMORY": "1"})

In [3]:
## import packages
import jax
import jax.numpy as jnp
import einops
from scipy import special as ss
import numpy as np
import torch
import sys
from functools import partial

In [4]:
# Raise CPython's recursion ceiling far above its default.
# NOTE(review): nothing visible in this notebook recurses deeply, and a limit this
# high can crash the interpreter (C-stack overflow) instead of raising
# RecursionError -- confirm it is still needed.
RECURSION_LIMIT = 10**6
sys.setrecursionlimit(RECURSION_LIMIT)

In [5]:
# Report the devices JAX can see and the platform of the default backend.
print(jax.devices())
# `jax.lib.xla_bridge.get_backend()` is deprecated in recent JAX releases;
# `jax.default_backend()` returns the same platform string ("cpu", "gpu", "tpu").
print(f"The Device: {jax.default_backend()}")

[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]
The Device: gpu


In [6]:
# Report whether PyTorch's MPS (Apple Metal) backend is usable on this machine.
mps_available = torch.backends.mps.is_available()
print(f"MPS enabled: {mps_available}")

MPS enabled: False


In [7]:
# Widen printed arrays/tensors so large matrices wrap less in notebook output.
PRINT_LINEWIDTH = 150
np.set_printoptions(linewidth=PRINT_LINEWIDTH)
jnp.set_printoptions(linewidth=PRINT_LINEWIDTH)
torch.set_printoptions(linewidth=PRINT_LINEWIDTH)

In [8]:
# A single fixed seed for reproducibility; it also seeds NumPy's global RNG later on.
seed = 1701
key = jax.random.PRNGKey(seed=seed)

In [9]:
# Seed NumPy's legacy global RNG with the same seed used for the JAX key.
np.random.seed(seed=seed)

In [10]:
# Split the root key into `num_copies` keys and continue with the first as the root.
# NOTE(review): `key` is now identical to `subkeys[0]`; if later cells also consume
# `subkeys[0]`, the same randomness is reused -- confirm against later usage.
num_copies = 10
subkeys = jax.random.split(key, num_copies)
key = subkeys[0]

In [11]:
# @jax.jit  # would need static_argnums=2: `max_n` determines an array shape
def legendre_recurrence_old(n, x, max_n):
    """
    Evaluate the degree-``n`` Legendre polynomial at ``x`` via Bonnet's recurrence.

    A table holding P_0(x), ..., P_{max_n}(x) is filled row by row with
    ``jax.lax.fori_loop`` and row ``n`` of that table is returned.

    Args:
    n: int (or integer array), the degree(s) to select from the table.
    x: array, the point(s) at which to evaluate; every table row has ``x.shape``.
    max_n: int, the largest degree stored in the table. It must be concrete because
        it fixes the table's shape.

    Returns:
    The table row(s) ``P_n(x)``.
    """
    n_rows = max_n + 1
    table = jnp.zeros((n_rows,) + x.shape)
    table = table.at[0].set(1.0).at[1].set(x)  # P_0(x) = 1, P_1(x) = x

    def fill_row(k, tbl):
        # Bonnet: k * P_k(x) = (2k - 1) * x * P_{k-1}(x) - (k - 1) * P_{k-2}(x)
        numerator = (2 * k - 1) * x * tbl[k - 1] - (k - 1) * tbl[k - 2]
        return tbl.at[k].set(numerator / k)

    table = jax.lax.fori_loop(2, n_rows, fill_row, table)
    return table[n]

In [12]:
# @jax.jit  # not jittable as written: `n.max()` must be concrete to size the table
def eval_legendre_old(n, x, out=None):
    """
    Evaluate Legendre polynomials of degrees ``n`` at points ``x`` (reference version).

    Every degree in ``n`` is evaluated at every point in ``x`` with
    ``legendre_recurrence_old`` (nested ``jax.vmap``). For 1-D ``n`` and 1-D ``x`` the
    diagonal of that grid is kept, pairing ``n[i]`` with ``x[i]`` like
    ``scipy.special.eval_legendre``. Singleton axes are then squeezed away.

    Args:
    n: array-like of non-negative ints, the degrees of the Legendre polynomials.
    x: array-like, the points at which to evaluate the polynomials.
    out: optional array-like. JAX arrays are immutable, so ``out`` cannot be filled in
        place. When given, the result is returned cast to ``out``'s dtype and reshaped
        to ``out``'s shape (sizes must match).

    Returns:
    A JAX array containing the Legendre polynomial values.
    """
    n = jnp.asarray(n)
    x = jnp.asarray(x)
    n_max = n.max()

    # Evaluate every degree at every point; both input layouts share this grid.
    p = jax.vmap(
        lambda ni: jax.vmap(lambda xi: legendre_recurrence_old(ni, xi, n_max))(x)
    )(n)
    if n.ndim == 1 and x.ndim == 1:
        # Keep only the (n[i], x[i]) pairs to match scipy.special.eval_legendre.
        p = jnp.diagonal(p)
    result = jnp.squeeze(p)

    if out is not None:
        # `jnp.copy` has no `out=` argument and JAX arrays cannot be written in place,
        # so return a new array with the dtype/shape that `out` describes instead.
        out = jnp.asarray(out)
        return result.astype(out.dtype).reshape(out.shape)
    return result

In [13]:
def legendre_recurrence(n, x, n_max):
    """
    Evaluate the degree-``n`` Legendre polynomial at ``x`` using Bonnet's recurrence.

    A table of P_0(x), ..., P_{n_max}(x) is built with ``jax.lax.scan`` and then
    indexed with ``n``. Its shape is ``(n_max + 1,) + x.shape``, or 2 rows when
    ``n_max < 2``.

    Parameters:
    n (int or integer jnp.ndarray): Degree(s) to select from the table. Each should
        satisfy 0 <= n <= max(n_max, 1).
    x (array-like): The point(s) at which the polynomials are evaluated. Can be a
        scalar or an array of any shape.
    n_max (int): The highest degree to tabulate. It must be a non-negative, concrete
        (non-traced) integer because it fixes the scan length.

    Returns:
    jnp.ndarray: ``table[n]``. For a scalar ``n`` this is P_n(x), with shape ``x.shape``.

    Notes:
    P_0(x) = 1, P_1(x) = x, and
    P_{k+1}(x) = ((2k + 1) * x * P_k(x) - k * P_{k-1}(x)) / (k + 1).
    """
    x = jnp.asarray(x)
    n_max = int(n_max)  # the scan length must be a concrete Python int

    p_init = jnp.zeros((2,) + x.shape)
    p_init = p_init.at[0].set(1.0)  # P_0(x) = 1
    p_init = p_init.at[1].set(x)  # P_1(x) = x

    if n_max < 2:
        # P_0 and P_1 are already tabulated; lax.scan rejects the negative
        # length (n_max - 1) that n_max == 0 would otherwise produce.
        table = p_init
    else:

        def body_fun(carry, k):
            p_km1, p_k = carry
            p_kp1 = ((2 * k + 1) * x * p_k - k * p_km1) / (k + 1)
            return (p_k, p_kp1), p_kp1

        # Feed the degree k = 1, ..., n_max - 1 in as the scan input instead of
        # carrying a counter, so the carry only holds the two previous polynomials.
        _, p_higher = jax.lax.scan(
            body_fun, (p_init[0], p_init[1]), jnp.arange(1, n_max)
        )
        table = jnp.concatenate((p_init, p_higher), axis=0)

    return table[n]

In [14]:
# def eval_legendre(n, x):
#     """
#     Evaluate Legendre polynomials of specified degrees at provided point(s).

#     This function makes use of a vectorized version of the Legendre polynomial recurrence relation to
#     compute the necessary polynomials up to the maximum degree found in 'n'. It then selects and returns
#     the values of the polynomials at the degrees specified in 'n' and evaluated at the points in 'x'.

#     Parameters:
#         n (jnp.ndarray): An array of integer degrees for which the Legendre polynomials are to be evaluated.
#                         Each element must be a non-negative integer and the array can be of any shape.
#         x (jnp.ndarray): The point(s) at which the Legendre polynomials are to be evaluated. Can be a single
#                         point (float) or an array of points. The shape must be broadcastable to the shape of 'n'.

#     Returns:
#         jnp.ndarray: An array of Legendre polynomial values. The output has the same shape as 'n' and 'x' after broadcasting.
#                     The i-th entry corresponds to the Legendre polynomial of degree 'n[i]' evaluated at point 'x[i]'.

#     Notes:
#         This function makes use of the vectorized map (vmap) functionality in JAX to efficiently compute and select
#         the necessary Legendre polynomial values.
#     """

#     n_max = n.max()  # + 1
#     p = jax.vmap(legendre_recurrence, in_axes=(None, 0))(n_max, x)
#     p = jnp.squeeze(p)

#     p_selected = jax.vmap(lambda pi, ni: pi[ni], in_axes=(None, 1))(p, n)
#     return p

In [15]:
def eval_legendre(n, x):
    """
    Evaluate Legendre polynomials of specified degrees at provided point(s), elementwise.

    ``n`` and ``x`` are broadcast against each other with NumPy rules, mirroring
    ``scipy.special.eval_legendre``: each output element is P_{n[i]}(x[i]) for the
    corresponding broadcast elements. Each element is evaluated with
    ``legendre_recurrence``, vectorized with ``jax.vmap`` over the flattened
    broadcast arrays.

    Parameters:
        n (array-like): Non-negative integer degrees, of any shape broadcastable with ``x``.
        x (array-like): The evaluation point(s). A scalar or any shape broadcastable
                        with ``n``.

    Returns:
        jnp.ndarray: Legendre polynomial values with the broadcast shape of ``n`` and ``x``.

    Notes:
        ``max(n)`` must be concrete (this function is not jittable as written) because it
        fixes the recurrence length inside ``legendre_recurrence``.
    """
    n, x = jnp.broadcast_arrays(jnp.asarray(n), jnp.asarray(x))
    out_shape = n.shape
    if n.size == 0:
        return jnp.zeros(out_shape)

    n_max = int(n.max())
    # One scalar recurrence per (degree, point) pair: O(n.size * n_max) work, instead
    # of evaluating every degree at every point and keeping only the diagonal.
    values = jax.vmap(lambda ni, xi: legendre_recurrence(ni, xi, n_max))(
        n.ravel(), x.ravel()
    )
    return values.reshape(out_shape)

In [16]:
def test_eval_legendre():
    """Compare both JAX Legendre evaluators against ``scipy.special.eval_legendre``.

    Degrees 0..9 (a column) are broadcast against points 0.0..1.9 in steps of 0.1
    (a row), giving a (10, 20) table. Each pairwise comparison has its own assertion
    message, so a failure names the pair that disagreed.
    """
    n = 10
    x = np.arange(0.0, 2, 0.1)

    print(f"x = {x}")
    print(f"x shape = {x.shape}")

    y = ss.eval_legendre(np.arange(n)[:, None], x[None, :])

    print(f"y = {y}")
    print(f"y shape = {y.shape}")

    y_pred = eval_legendre(jnp.arange(n)[:, None], jnp.array(x)[None, :])

    print(f"y_pred = {y_pred}")
    print(f"y_pred shape = {y_pred.shape}")

    y_pred_old = eval_legendre_old(jnp.arange(n)[:, None], jnp.array(x)[None, :])

    print(f"y_pred_old = {y_pred_old}")
    print(f"y_pred_old shape = {y_pred_old.shape}")

    comparisons = (
        ("eval_legendre vs scipy", y_pred, y),
        ("eval_legendre_old vs scipy", y_pred_old, y),
        ("eval_legendre vs eval_legendre_old", y_pred, y_pred_old),
    )
    for label, actual, expected in comparisons:
        assert np.allclose(actual, expected, rtol=1e-5, atol=1e-5), f"Results do not match: {label}"
        print(f"Results match: {label}")

To be transparent, I used ChatGPT to help with this. It seems to be a good workaround for the time being:

```python
def legendre_recurrence(n, x, max_n):
    """
    Computes the Legendre polynomial of degree n at point x using the recurrence relation.

    Args:
    n: int, the degree of the Legendre polynomial.
    x: float, the point at which to evaluate the polynomial.
    max_n: int, the maximum degree of n in the batch.

    Returns:
    The value of the Legendre polynomial of degree n at point x.
    """
    # Initialize the array to store the Legendre polynomials for all degrees from 0 to max_n
    p = jnp.zeros((max_n + 1,) + x.shape)
    p = p.at[0].set(1.0)  # Set the 0th degree Legendre polynomial
    p = p.at[1].set(x)  # Set the 1st degree Legendre polynomial

    # Compute the Legendre polynomials for degrees 2 to max_n using the recurrence relation
    def body_fun(i, p):
        p_i = ((2 * i - 1) * x * p[i - 1] - (i - 1) * p[i - 2]) / i
        return p.at[i].set(p_i)

    p = jax.lax.fori_loop(2, max_n + 1, body_fun, p)

    return p[n]
```

```python
def eval_legendre(n, x, out=None):
    """
    Evaluates the Legendre polynomials of degrees specified in the input array n at the points specified in the input array x.

    Args:
    n: array-like, the degrees of the Legendre polynomials.
    x: array-like, the points at which to evaluate the polynomials.
    out: optional, an output array to store the results.

    Returns:
    An array containing the Legendre polynomial values of the specified degrees at the specified points.
    """
    n = jnp.asarray(n)
    x = jnp.asarray(x)
    max_n = n.max()

    if n.ndim == 1 and x.ndim == 1:
        p = jax.vmap(
            lambda ni: jax.vmap(lambda xi: legendre_recurrence(ni, xi, max_n))(x)
        )(n)
        p = jnp.diagonal(
            p
        )  # Get the diagonal elements to match the scipy.special.eval_legendre output
    else:
        p = jax.vmap(
            lambda ni: jax.vmap(lambda xi: legendre_recurrence(ni, xi, max_n))(x)
        )(n)

    if out is not None:
        out = jnp.asarray(out)
        out = jnp.copy(p, out=out)
        return out
    else:
        return p
```

```python
def test_eval_legendre():
    n = np.array([0, 1, 2, 3])

    print(f"n = {n}")
    print(f"n shape = {n.shape}")

    x = np.linspace(-1, 1, n.shape[0])

    print(f"x = {x}")
    print(f"x shape = {x.shape}")

    y_pred = eval_legendre(n, x)
    y = ss.eval_legendre(n, x)

    print(f"y_pred = {y_pred}")
    print(f"y_pred shape = {y_pred.shape}")
    print(f"y = {y}")
    print(f"y shape = {y.shape}")

    assert np.allclose(y_pred, y, rtol=1e-5, atol=1e-8), "Results do not match"
    print("Results match")
```

output:
```bash
n = [0 1 2 3]
n shape = (4,)
x = [-1.         -0.33333333  0.33333333  1.        ]
x shape = (4,)
y_pred = [ 1.         -0.33333334 -0.3333333   1.        ]
y_pred shape = (4,)
y = [ 1.         -0.33333333 -0.33333333  1.        ]
y shape = (4,)
Results match
```

In [17]:
# Run the comparison defined above; it prints inputs/outputs and raises AssertionError on a mismatch.
test_eval_legendre()

x = [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.  1.1 1.2 1.3 1.4 1.5 1.6 1.7 1.8 1.9]
x shape = (20,)
y = [[ 1.00000000e+00  1.00000000e+00  1.00000000e+00  1.00000000e+00  1.00000000e+00  1.00000000e+00  1.00000000e+00  1.00000000e+00  1.00000000e+00
   1.00000000e+00  1.00000000e+00  1.00000000e+00  1.00000000e+00  1.00000000e+00  1.00000000e+00  1.00000000e+00  1.00000000e+00  1.00000000e+00
   1.00000000e+00  1.00000000e+00]
 [ 0.00000000e+00  1.00000000e-01  2.00000000e-01  3.00000000e-01  4.00000000e-01  5.00000000e-01  6.00000000e-01  7.00000000e-01  8.00000000e-01
   9.00000000e-01  1.00000000e+00  1.10000000e+00  1.20000000e+00  1.30000000e+00  1.40000000e+00  1.50000000e+00  1.60000000e+00  1.70000000e+00
   1.80000000e+00  1.90000000e+00]
 [-5.00000000e-01 -4.85000000e-01 -4.40000000e-01 -3.65000000e-01 -2.60000000e-01 -1.25000000e-01  4.00000000e-02  2.35000000e-01  4.60000000e-01
   7.15000000e-01  1.00000000e+00  1.31500000e+00  1.66000000e+00  2.03500000e+00  2.44000000e