# HiPPO Matrices
---

## Table of Contents
* [Load Packages](#load-packages)
    * [Instantiate The HiPPO Matrix](#instantiate-the-hippo-matrix)
        * [Translated Legendre (LegT)](#translated-legendre-legt)
            * [LegT](#legt)
            * [LMU](#lmu)
        * [Translated Laguerre (LagT)](#translated-laguerre-lagt)
        * [Scaled Legendre (LegS)](#scaled-legendre-legs)
        * [Fourier Basis](#fourier-basis)
            * [Fourier Recurrent Unit (FRU)](#fourier-recurrent-unit-fru)
            * [Truncated Fourier (FouT)](#truncated-fourier-fout)
            * [Fourier With Decay (FourD)](#fourier-with-decay-fourd)
    * [Utilities For Gu HiPPO Operator](#utilities-for-gu-hippo-operator)
    * [Gu's HiPPO LegT Operator](#gus-hippo-legt-operator)
    * [Gu's Scale invariant HiPPO LegS Operator](#gus-scale-invariant-hippo-legs-operator)
    * [Implementation Of General HiPPO Operator](#implementation-of-general-hippo-operator)
    * [Output](#output)
---


## Load Packages

In [1]:
import os
import sys
module_path = os.path.abspath(os.path.join('../../../'))
if module_path not in sys.path:
    sys.path.append(module_path)

In [2]:
## import packages
import jax
import jax.numpy as jnp

from flax import linen as jnn

from jax.nn.initializers import lecun_normal, uniform
from jax.numpy.linalg import eig, inv, matrix_power
from jax.scipy.signal import convolve

# import modules 
from src.models.hippo.gu_transition import GuTransMatrix
from src.data.process import moving_window, rolling_window


import requests

from scipy import linalg as la
from scipy import signal
from scipy import special as ss

import math

print(jax.devices())
print(f"The Device: {jax.default_backend()}")

[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]
The Device: gpu


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import numpy as np
import matplotlib.pyplot as plt
print(f"MPS enabled: {torch.backends.mps.is_available()}")


MPS enabled: False


In [4]:
torch.set_printoptions(linewidth=150)
np.set_printoptions(linewidth=150)
jnp.set_printoptions(linewidth=150)

In [5]:
# N = 8


In [6]:
seed = 1701
key = jax.random.PRNGKey(seed)


In [7]:
num_copies = 5
rng, key2, key3, key4, key5 = jax.random.split(key, num=num_copies)


## Instantiate The HiPPO Matrix

In [8]:
class TransMatrix:
    def __init__(
        self, N, measure="legs", lambda_n=1, fourier_type="fru", alpha=0, beta=1
    ):
        """
        Instantiates the HiPPO matrices of a given order using a particular measure.

        Args:
            N (int): Order of coefficients to describe the orthogonal polynomial that is the HiPPO projection.
            measure (str):
                choose between
                    - HiPPO w/ Translated Legendre (LegT) - legt
                    - HiPPO w/ Translated Laguerre (LagT) - lagt
                    - HiPPO w/ Scaled Legendre (LegS) - legs
                    - HiPPO w/ Fourier basis - fourier
                    - Random dense matrices - random
                    - Random diagonal A matrix - diagonal
            lambda_n (int): The tilt applied to the translated Legendre (LegT) basis; 1 gives LegT and 2 gives LMU.
            fourier_type (str): chooses between the following:
                - FRU: Fourier Recurrent Unit - fru
                - FouT: Truncated Fourier - fout
                - FouD: Fourier with decay - foud
            alpha (float): The order of the Laguerre basis.
            beta (float): The scale of the Laguerre basis.

        Attributes:
            A_matrix (jnp.ndarray): The HiPPO matrix multiplied by -1, stored as float32.
            B_matrix (jnp.ndarray): The other corresponding state space matrix, stored as float32.

        """
        A = None
        B = None
        if measure == "legt":
            A, B = self.build_LegT(N=N, lambda_n=lambda_n)

        elif measure == "lagt":
            A, B = self.build_LagT(alpha=alpha, beta=beta, N=N)

        elif measure == "legs":
            A, B = self.build_LegS(N=N)

        elif measure == "fourier":
            A, B = self.build_Fourier(N=N, fourier_type=fourier_type)

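        elif measure == "random":
            # jnp has no `random` submodule; sample with jax.random and an explicit key
            key_A, key_B = jax.random.split(jax.random.PRNGKey(0))  # fixed seed, chosen here for reproducibility
            A = jax.random.normal(key_A, (N, N)) / N
            B = jax.random.normal(key_B, (N, 1))

        elif measure == "diagonal":
            key_A, key_B = jax.random.split(jax.random.PRNGKey(0))  # fixed seed, chosen here for reproducibility
            A = -jnp.diag(jnp.exp(jax.random.normal(key_A, (N,))))
            B = jax.random.normal(key_B, (N, 1))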

        else:
            raise ValueError("Invalid HiPPO type")

        self.A_matrix = (A.copy()).astype(jnp.float32)
        self.B_matrix = (B.copy()).astype(jnp.float32)

    # Translated Legendre (LegT) - vectorized
    @staticmethod
    def build_LegT(N, lambda_n=1):
        """
        Vectorized implementation of the measure derived from the translated Legendre basis.

        Args:
            N (int): Order of coefficients to describe the orthogonal polynomial that is the HiPPO projection.
            lambda_n (int): Choice between the two different tilts of the basis.
                - 1: translated Legendre (LegT)
                - 2: Legendre Memory Unit (LMU)

        Returns:
            A (jnp.ndarray): The A HiPPO matrix.
            B (jnp.ndarray): The B HiPPO matrix.

        """
        q = jnp.arange(N, dtype=jnp.float32)
        k, n = jnp.meshgrid(q, q)
        case = jnp.power(-1.0, (n - k))
        A = None
        B = None

        if lambda_n == 1:  # translated Legendre (LegT)
            A_base = jnp.sqrt(2 * n + 1) * jnp.sqrt(2 * k + 1)
            pre_D = jnp.sqrt(jnp.diag(2 * q + 1))
            B = D = jnp.diag(pre_D)[:, None]
            A = jnp.where(
                k <= n, A_base, A_base * case
            )  # if n >= k, A_base is used, otherwise (-1)^(n-k) * A_base

        elif lambda_n == 2:  # Legendre Memory Unit (LMU)
            A_base = 2 * n + 1
            B = jnp.diag((2 * q + 1) * jnp.power(-1, n))[:, None]
            A = jnp.where(
                k <= n, A_base * case, A_base
            )  # if n >= k, (-1)^(n-k) * A_base is used, otherwise A_base

        else:
            # without this branch A stays None and the negation below fails with a TypeError
            raise ValueError("lambda_n must be 1 (LegT) or 2 (LMU)")

        return -A, B

    # Translated Laguerre (LagT) - vectorized
    @staticmethod
    def build_LagT(alpha, beta, N):
        """
        Vectorized implementation of the measure derived from the translated Laguerre basis.

        Args:
            alpha (float): The order of the Laguerre basis.
            beta (float): The scale of the Laguerre basis.
            N (int): Order of coefficients to describe the orthogonal polynomial that is the HiPPO projection.

        Returns:
            A (jnp.ndarray): The A HiPPO matrix.
            B (jnp.ndarray): The B HiPPO matrix.

        """
        L = jnp.exp(
            0.5
            * (ss.gammaln(jnp.arange(N) + alpha + 1) - ss.gammaln(jnp.arange(N) + 1))
        )
        inv_L = 1.0 / L[:, None]
        pre_A = (jnp.eye(N) * ((1 + beta) / 2)) + jnp.tril(jnp.ones((N, N)), -1)
        pre_B = ss.binom(alpha + jnp.arange(N), jnp.arange(N))[:, None]

        A = -inv_L * pre_A * L[None, :]
        B = (
            jnp.exp(-0.5 * ss.gammaln(1 - alpha))
            * jnp.power(beta, (1 - alpha) / 2)
            * inv_L
            * pre_B
        )

        return A, B

    # Scaled Legendre (LegS) vectorized
    @staticmethod
    def build_LegS(N):
        """
        Vectorized implementation of the measure derived from the scaled Legendre basis.

        Args:
            N (int): Order of coefficients to describe the orthogonal polynomial that is the HiPPO projection.

        Returns:
            A (jnp.ndarray): The A HiPPO matrix.
            B (jnp.ndarray): The B HiPPO matrix.

        """
        q = jnp.arange(N, dtype=jnp.float32)
        k, n = jnp.meshgrid(q, q)
        pre_D = jnp.sqrt(jnp.diag(2 * q + 1))
        B = D = jnp.diag(pre_D)[:, None]

        A_base = jnp.sqrt(2 * n + 1) * jnp.sqrt(2 * k + 1)

        A = jnp.where(n > k, A_base, jnp.where(n == k, n + 1, 0.0))

        return -A.astype(jnp.float32), B.astype(jnp.float32)

    # Fourier Basis OPs and functions - vectorized
    @staticmethod
    def build_Fourier(N, fourier_type="fru"):
        """
        Vectorized implementation of the measures derived from the Fourier basis.

        Args:
            N (int): Order of coefficients to describe the orthogonal polynomial that is the HiPPO projection.
            fourier_type (str): The type of Fourier measure.
                - FRU: Fourier Recurrent Unit - fru
                - FouT: truncated Fourier - fout
                - FouD: Fourier with decay - foud

        Returns:
            A (jnp.ndarray): The A HiPPO matrix.
            B (jnp.ndarray): The B HiPPO matrix.

        """
        A = jnp.diag(
            jnp.stack([jnp.zeros(N // 2), jnp.zeros(N // 2)], axis=-1).reshape(-1)[1:],
            1,
        )
        B = jnp.zeros(A.shape[1], dtype=jnp.float32)

        B = B.at[0::2].set(jnp.sqrt(2))
        B = B.at[0].set(1)

        q = jnp.arange(A.shape[1], dtype=jnp.float32)
        k, n = jnp.meshgrid(q, q)

        # Parity masks: True where the row index (n) or column index (k) is even
        n_even = n % 2 == 0
        k_even = k % 2 == 0

        case_1 = (n == k) & (n == 0)
        case_2_3 = ((k == 0) & (n_even)) | ((n == 0) & (k_even))
        case_4 = (n_even) & (k_even)
        case_5 = (n - k == 1) & (k_even)
        case_6 = (k - n == 1) & (n_even)

        if fourier_type == "fru":  # Fourier Recurrent Unit (FRU) - vectorized
            A = jnp.where(
                case_1,
                -1.0,
                jnp.where(
                    case_2_3,
                    -jnp.sqrt(2),
                    jnp.where(
                        case_4,
                        -2,
                        jnp.where(
                            case_5,
                            jnp.pi * (n // 2),
                            jnp.where(case_6, -jnp.pi * (k // 2), 0.0),
                        ),
                    ),
                ),
            )

        elif fourier_type == "fout":  # truncated Fourier (FouT) - vectorized
            A = jnp.where(
                case_1,
                -1.0,
                jnp.where(
                    case_2_3,
                    -jnp.sqrt(2),
                    jnp.where(
                        case_4,
                        -2,
                        jnp.where(
                            case_5,
                            jnp.pi * (n // 2),
                            jnp.where(case_6, -jnp.pi * (k // 2), 0.0),
                        ),
                    ),
                ),
            )

            A = 2 * A
            B = 2 * B

        elif fourier_type == "foud":  # Fourier with decay (FouD) - vectorized
            A = jnp.where(
                case_1,
                -1.0,
                jnp.where(
                    case_2_3,
                    -jnp.sqrt(2),
                    jnp.where(
                        case_4,
                        -2,
                        jnp.where(
                            case_5,
                            2 * jnp.pi * (n // 2),
                            jnp.where(case_6, 2 * -jnp.pi * (k // 2), 0.0),
                        ),
                    ),
                ),
            )

            A = 0.5 * A
            B = 0.5 * B

        B = B[:, None]

        return A.astype(jnp.float32), B.astype(jnp.float32)

## Translated Legendre (LegT)

### LegT

In [9]:
def test_LegT():
    legt_matrices = TransMatrix(N=8, measure="legt", lambda_n=1.0)
    A, B = legt_matrices.A_matrix, legt_matrices.B_matrix
    gu_legt_matrices = GuTransMatrix(N=8, measure="legt", lambda_n=1.0)
    gu_A, gu_B = gu_legt_matrices.A_matrix, gu_legt_matrices.B_matrix
    print(f"A:\n", A)
    print(f"Gu's A:\n", gu_A)
    print(f"B:\n", B)
    print(f"Gu's B:\n", gu_B)
    assert jnp.allclose(A, gu_A)
    assert jnp.allclose(B, gu_B)
    

In [10]:
test_LegT()

A:
 [[ -1.          1.7320508  -2.2360678   2.6457512  -3.          3.3166246  -3.6055512   3.8729832]
 [ -1.7320508  -3.          3.872983   -4.5825753   5.196152   -5.744562    6.244998   -6.708204 ]
 [ -2.2360678  -3.872983   -4.999999    5.916079   -6.7082033   7.4161973  -8.062257    8.660253 ]
 [ -2.6457512  -4.5825753  -5.916079   -6.9999995   7.937254   -8.774963    9.5393915 -10.24695  ]
 [ -3.         -5.196152   -6.7082033  -7.937254   -9.          9.949874  -10.816654   11.61895  ]
 [ -3.3166246  -5.744562   -7.4161973  -8.774963   -9.949874  -10.999999   11.958261  -12.845232 ]
 [ -3.6055512  -6.244998   -8.062257   -9.5393915 -10.816654  -11.958261  -13.         13.964239 ]
 [ -3.8729832  -6.708204   -8.660253  -10.24695   -11.61895   -12.845232  -13.964239  -14.999999 ]]
Gu's A:
 [[ -1.          1.7320508  -2.2360678   2.6457512  -3.          3.3166246  -3.6055512   3.8729832]
 [ -1.7320508  -3.          3.872983   -4.5825753   5.196152   -5.744562    6.244998   -6.70820

### LMU

In [11]:
def test_LMU():
    lmu_matrices = TransMatrix(
        N=8, measure="legt", lambda_n=2.0
    )  # change lambda so resulting matrix is in the form of LMU
    A, B = lmu_matrices.A_matrix, lmu_matrices.B_matrix
    gu_lmu_matrices = GuTransMatrix(
        N=8, measure="legt", lambda_n=2.0
    )  # change lambda so resulting matrix is in the form of LMU
    gu_A, gu_B = gu_lmu_matrices.A_matrix, gu_lmu_matrices.B_matrix
    print(f"A:\n", A)
    print(f"Gu's A:\n", gu_A)
    print(f"B:\n", B)
    print(f"Gu's B:\n", gu_B)
    assert jnp.allclose(A, gu_A)
    assert jnp.allclose(B, gu_B)
    

In [12]:
test_LMU()

A:
 [[ -1.  -1.  -1.  -1.  -1.  -1.  -1.  -1.]
 [  3.  -3.  -3.  -3.  -3.  -3.  -3.  -3.]
 [ -5.   5.  -5.  -5.  -5.  -5.  -5.  -5.]
 [  7.  -7.   7.  -7.  -7.  -7.  -7.  -7.]
 [ -9.   9.  -9.   9.  -9.  -9.  -9.  -9.]
 [ 11. -11.  11. -11.  11. -11. -11. -11.]
 [-13.  13. -13.  13. -13.  13. -13. -13.]
 [ 15. -15.  15. -15.  15. -15.  15. -15.]]
Gu's A:
 [[ -1.  -1.  -1.  -1.  -1.  -1.  -1.  -1.]
 [  3.  -3.  -3.  -3.  -3.  -3.  -3.  -3.]
 [ -5.   5.  -5.  -5.  -5.  -5.  -5.  -5.]
 [  7.  -7.   7.  -7.  -7.  -7.  -7.  -7.]
 [ -9.   9.  -9.   9.  -9.  -9.  -9.  -9.]
 [ 11. -11.  11. -11.  11. -11. -11. -11.]
 [-13.  13. -13.  13. -13.  13. -13. -13.]
 [ 15. -15.  15. -15.  15. -15.  15. -15.]]
B:
 [[  1.]
 [ -3.]
 [  5.]
 [ -7.]
 [  9.]
 [-11.]
 [ 13.]
 [-15.]]
Gu's B:
 [[  1.]
 [ -3.]
 [  5.]
 [ -7.]
 [  9.]
 [-11.]
 [ 13.]
 [-15.]]
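
As a cross-check on the `jnp.where` masks in `build_LegT`, the same entries can be written out as explicit loops. The sketch below is a minimal NumPy reference, not part of the original notebook: the name `legt_reference` is hypothetical, and the formulas are read directly off the vectorized code above (lower triangle `k <= n`, sign factor `(-1)^(n-k)`, final negation of `A`).

```python
import numpy as np


def legt_reference(N, lambda_n=1):
    """Hypothetical loop-based reference for the LegT (lambda_n=1) and LMU (lambda_n=2) matrices."""
    A = np.zeros((N, N))
    B = np.zeros((N, 1))
    for n in range(N):
        for k in range(N):
            sign = (-1.0) ** (n - k)
            if lambda_n == 1:  # LegT: sign flips only above the diagonal
                base = np.sqrt(2 * n + 1) * np.sqrt(2 * k + 1)
                A[n, k] = -base if k <= n else -base * sign
            else:  # LMU: sign flips on and below the diagonal
                base = 2 * n + 1
                A[n, k] = -base * sign if k <= n else -base
        if lambda_n == 1:
            B[n, 0] = np.sqrt(2 * n + 1)
        else:
            B[n, 0] = (2 * n + 1) * (-1.0) ** n
    return A, B


A_legt, B_legt = legt_reference(8, lambda_n=1)
A_lmu, B_lmu = legt_reference(8, lambda_n=2)
print(A_lmu[:3])
print(B_lmu.ravel())
```

The first rows of `A_lmu` and the alternating signs in `B_lmu` can be compared by eye with the LMU output printed above. The loop form is slow, but it makes the two triangular cases easier to audit than the nested masks.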


## Translated Laguerre (LagT)

In [13]:
def test_LagT():
    lagt_matrices = TransMatrix(
        N=8,
        measure="lagt",
        alpha=0.0,  # change resulting tilt through alpha and beta
        beta=1.0,
    )  # change resulting tilt through alpha and beta
    A, B = lagt_matrices.A_matrix, lagt_matrices.B_matrix
    gu_lagt_matrices = GuTransMatrix(
        N=8,
        measure="lagt",
        alpha=0.0,  # change resulting tilt through alpha and beta
        beta=1.0,
    )  # change resulting tilt through alpha and beta
    gu_A, gu_B = gu_lagt_matrices.A_matrix, gu_lagt_matrices.B_matrix
    print(f"A:\n", A)
    print(f"Gu's A:\n", gu_A)
    print(f"B:\n", B)
    print(f"Gu's B:\n", gu_B)
    assert jnp.allclose(A, gu_A)
    assert jnp.allclose(B, gu_B)

In [14]:
test_LagT()

A:
 [[-1.         -0.         -0.         -0.         -0.         -0.         -0.         -0.        ]
 [-1.         -1.         -0.         -0.         -0.         -0.         -0.         -0.        ]
 [-1.         -1.         -1.         -0.         -0.         -0.         -0.         -0.        ]
 [-1.         -1.         -1.         -1.         -0.         -0.         -0.         -0.        ]
 [-1.         -1.         -1.         -1.         -1.         -0.         -0.         -0.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.         -0.         -0.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.         -1.         -0.        ]
 [-0.99999976 -0.99999976 -0.99999976 -0.99999976 -0.99999976 -0.99999976 -0.99999976 -1.        ]]
Gu's A:
 [[-1.         -0.         -0.         -0.         -0.         -0.         -0.         -0.        ]
 [-1.         -1.         -0.         -0.         -0.         -0.         -0.         -0.      
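
The LagT construction in `build_LagT` can also be spelled out entry by entry. The following is a minimal sketch, assuming `0 <= alpha < 1` so that every log-gamma argument is positive; the name `lagt_reference` is hypothetical, and `math.lgamma` stands in for `scipy.special.gammaln`. For `alpha=0`, `beta=1` every scaling factor reduces to 1, so `A` is lower triangular with all entries equal to -1.

```python
import math

import numpy as np


def lagt_reference(N, alpha=0.0, beta=1.0):
    """Hypothetical loop-based reference for the LagT matrices, assuming 0 <= alpha < 1."""
    # L[n] = sqrt(Gamma(n + alpha + 1) / Gamma(n + 1)), computed in log space
    L = [math.exp(0.5 * (math.lgamma(n + alpha + 1) - math.lgamma(n + 1))) for n in range(N)]
    scale = math.exp(-0.5 * math.lgamma(1 - alpha)) * beta ** ((1 - alpha) / 2)
    A = np.zeros((N, N))
    B = np.zeros((N, 1))
    for n in range(N):
        for k in range(n + 1):
            # (1 + beta) / 2 on the diagonal, 1 strictly below it
            pre_A = (1 + beta) / 2 if n == k else 1.0
            A[n, k] = -pre_A * L[k] / L[n]
        # generalized binomial coefficient binom(alpha + n, n) via log-gamma
        binom = math.exp(
            math.lgamma(alpha + n + 1) - math.lgamma(n + 1) - math.lgamma(alpha + 1)
        )
        B[n, 0] = scale * binom / L[n]
    return A, B


A_lagt, B_lagt = lagt_reference(8, alpha=0.0, beta=1.0)
print(A_lagt)
```

In double precision the lower triangle comes out as -1 throughout, which suggests that the `-0.99999976` entries in the last printed row above are presumably float32 rounding rather than part of the matrix structure.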

## Scaled Legendre (LegS)

In [15]:
def test_LegS():
    legs_matrices = TransMatrix(N=8, measure="legs")
    A, B = legs_matrices.A_matrix, legs_matrices.B_matrix
    gu_legs_matrices = GuTransMatrix(N=8, measure="legs")
    gu_A, gu_B = gu_legs_matrices.A_matrix, gu_legs_matrices.B_matrix
    print(f"A:\n", A)
    print(f"Gu's A:\n", gu_A)
    print(f"B:\n", B)
    print(f"Gu's B:\n", gu_B)
    assert jnp.allclose(A, gu_A)
    assert jnp.allclose(B, gu_B)

In [16]:
test_LegS()

A:
 [[ -1.         -0.         -0.         -0.         -0.         -0.         -0.         -0.       ]
 [ -1.7320508  -2.         -0.         -0.         -0.         -0.         -0.         -0.       ]
 [ -2.2360678  -3.872983   -3.         -0.         -0.         -0.         -0.         -0.       ]
 [ -2.6457512  -4.5825753  -5.916079   -4.         -0.         -0.         -0.         -0.       ]
 [ -3.         -5.196152   -6.7082033  -7.937254   -5.         -0.         -0.         -0.       ]
 [ -3.3166246  -5.744562   -7.4161973  -8.774963   -9.949874   -6.         -0.         -0.       ]
 [ -3.6055512  -6.244998   -8.062257   -9.5393915 -10.816654  -11.958261   -7.         -0.       ]
 [ -3.8729832  -6.708204   -8.660253  -10.24695   -11.61895   -12.845232  -13.964239   -8.       ]]
Gu's A:
 [[ -1.          0.          0.          0.          0.          0.          0.          0.       ]
 [ -1.7320508  -1.9999999   0.          0.          0.          0.          0.          0.     
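
The LegS matrix has a simple entrywise description that can be checked without JAX. This is a minimal sketch with a hypothetical helper name (`legs_reference`), following the three cases in `build_LegS`: a product of square roots below the diagonal, `n + 1` on the diagonal, and zero above it, with the whole matrix negated.

```python
import numpy as np


def legs_reference(N):
    """Hypothetical loop-based reference for the LegS matrices."""
    A = np.zeros((N, N))
    B = np.zeros((N, 1))
    for n in range(N):
        for k in range(N):
            if n > k:
                A[n, k] = -np.sqrt(2 * n + 1) * np.sqrt(2 * k + 1)
            elif n == k:
                A[n, k] = -(n + 1)
            # entries above the diagonal stay zero
        B[n, 0] = np.sqrt(2 * n + 1)
    return A, B


A_legs, B_legs = legs_reference(8)
print(np.diag(A_legs))
```

Because `A` is lower triangular, its eigenvalues are exactly the diagonal entries `-(n + 1)`, so all of them are real and negative.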

## Fourier Basis

### Fourier Recurrent Unit (FRU)

In [17]:
def test_FRU():
    fru_matrices = TransMatrix(N=8, measure="fourier", fourier_type="fru")
    A, B = fru_matrices.A_matrix, fru_matrices.B_matrix
    gu_fru_matrices = GuTransMatrix(N=8, measure="fourier", fourier_type="fru")
    gu_A, gu_B = gu_fru_matrices.A_matrix, gu_fru_matrices.B_matrix
    print(f"A:\n", A)
    print(f"Gu's A:\n", gu_A)
    print(f"B:\n", B)
    print(f"Gu's B:\n", gu_B)
    assert jnp.allclose(A, gu_A)
    assert jnp.allclose(B, gu_B)

In [18]:
test_FRU()

A:
 [[-1.        -0.        -1.4142135  0.        -1.4142135  0.        -1.4142135  0.       ]
 [ 0.         0.         0.         0.         0.         0.         0.         0.       ]
 [-1.4142135  0.        -2.        -3.1415927 -2.         0.        -2.         0.       ]
 [ 0.         0.         3.1415927  0.         0.         0.         0.         0.       ]
 [-1.4142135  0.        -2.         0.        -2.        -6.2831855 -2.         0.       ]
 [ 0.         0.         0.         0.         6.2831855  0.         0.         0.       ]
 [-1.4142135  0.        -2.         0.        -2.         0.        -2.        -9.424778 ]
 [ 0.         0.         0.         0.         0.         0.         9.424778   0.       ]]
Gu's A:
 [[-1.         0.        -1.4142135  0.        -1.4142135  0.        -1.4142135  0.       ]
 [ 0.         0.         0.         0.         0.         0.         0.         0.       ]
 [-1.4142135  0.        -1.9999999 -3.1415927 -1.9999999  0.        -1.99999

### Truncated Fourier (FouT)

In [19]:
def test_FouT():
    fout_matrices = TransMatrix(N=8, measure="fourier", fourier_type="fout")
    A, B = fout_matrices.A_matrix, fout_matrices.B_matrix
    gu_fout_matrices = GuTransMatrix(N=8, measure="fourier", fourier_type="fout")
    gu_A, gu_B = gu_fout_matrices.A_matrix, gu_fout_matrices.B_matrix
    print(f"A:\n", A)
    print(f"Gu's A:\n", gu_A)
    print(f"B:\n", B)
    print(f"Gu's B:\n", gu_B)
    assert jnp.allclose(A, gu_A)
    assert jnp.allclose(B, gu_B)

In [20]:
test_FouT()

A:
 [[ -2.         -0.         -2.828427    0.         -2.828427    0.         -2.828427    0.       ]
 [  0.          0.          0.          0.          0.          0.          0.          0.       ]
 [ -2.828427    0.         -4.         -6.2831855  -4.          0.         -4.          0.       ]
 [  0.          0.          6.2831855   0.          0.          0.          0.          0.       ]
 [ -2.828427    0.         -4.          0.         -4.        -12.566371   -4.          0.       ]
 [  0.          0.          0.          0.         12.566371    0.          0.          0.       ]
 [ -2.828427    0.         -4.          0.         -4.          0.         -4.        -18.849556 ]
 [  0.          0.          0.          0.          0.          0.         18.849556    0.       ]]
Gu's A:
 [[ -2.          0.         -2.828427    0.         -2.828427    0.         -2.828427    0.       ]
 [  0.          0.          0.          0.          0.          0.          0.          0.     

### Fourier With Decay (FourD)

In [21]:
def test_FouD():
    the_measure = "fourier"
    fourier_type = "foud"
    foud_matrices = TransMatrix(N=8, measure=the_measure, fourier_type=fourier_type)
    A, B = foud_matrices.A_matrix, foud_matrices.B_matrix
    gu_foud_matrices = GuTransMatrix(N=8, measure="fourier", fourier_type="foud")
    gu_A, gu_B = gu_foud_matrices.A_matrix, gu_foud_matrices.B_matrix
    print(f"A:\n", A)
    print(f"Gu's A:\n", gu_A)
    print(f"B:\n", B)
    print(f"Gu's B:\n", gu_B)
    assert jnp.allclose(A, gu_A)
    assert jnp.allclose(B, gu_B)

In [22]:
test_FouD()

A:
 [[-0.5        -0.         -0.70710677  0.         -0.70710677  0.         -0.70710677  0.        ]
 [ 0.          0.          0.          0.          0.          0.          0.          0.        ]
 [-0.70710677  0.         -1.         -3.1415927  -1.          0.         -1.          0.        ]
 [ 0.          0.          3.1415927   0.          0.          0.          0.          0.        ]
 [-0.70710677  0.         -1.          0.         -1.         -6.2831855  -1.          0.        ]
 [ 0.          0.          0.          0.          6.2831855   0.          0.          0.        ]
 [-0.70710677  0.         -1.          0.         -1.          0.         -1.         -9.424778  ]
 [ 0.          0.          0.          0.          0.          0.          9.424778    0.        ]]
Gu's A:
 [[-0.5         0.         -0.70710677  0.         -0.70710677  0.         -0.70710677  0.        ]
 [ 0.          0.          0.          0.          0.          0.          0.          0.      
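
The three Fourier variants in `build_Fourier` share the same case structure and differ only in two scale factors. The sketch below makes that explicit with a hypothetical loop-based helper (`fourier_reference`, assuming even `N`): `freq_scale` multiplies only the `±pi` off-diagonal terms, and `out_scale` multiplies the whole of `A` and `B`.

```python
import numpy as np


def fourier_reference(N, fourier_type="fru"):
    """Hypothetical loop-based reference for the FRU / FouT / FouD matrices (N assumed even)."""
    freq_scale = 2.0 if fourier_type == "foud" else 1.0  # applies to the +/- pi terms only
    out_scale = {"fru": 1.0, "fout": 2.0, "foud": 0.5}[fourier_type]
    A = np.zeros((N, N))
    B = np.zeros((N, 1))
    for n in range(N):
        for k in range(N):
            n_even, k_even = n % 2 == 0, k % 2 == 0
            # the branch order mirrors the nested jnp.where calls
            if n == 0 and k == 0:
                A[n, k] = -1.0
            elif (k == 0 and n_even) or (n == 0 and k_even):
                A[n, k] = -np.sqrt(2)
            elif n_even and k_even:
                A[n, k] = -2.0
            elif n - k == 1 and k_even:
                A[n, k] = freq_scale * np.pi * (n // 2)
            elif k - n == 1 and n_even:
                A[n, k] = -freq_scale * np.pi * (k // 2)
        if n % 2 == 0:
            B[n, 0] = 1.0 if n == 0 else np.sqrt(2)
    return out_scale * A, out_scale * B


A_fru, B_fru = fourier_reference(8, "fru")
A_fout, B_fout = fourier_reference(8, "fout")
A_foud, B_foud = fourier_reference(8, "foud")
print(A_foud[2, :4])
```

Read this way, FouT is exactly twice FRU, while FouD halves the even-index block but leaves the `±pi` couplings at their FRU magnitude, which agrees with the printed outputs above.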

## Utilities For Gu HiPPO Operator
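
The utilities in this section all evaluate the linear recurrence `x[i] = A @ x[i-1] + u[i]`, either step by step or by a divide-and-conquer scheme that pairs even and odd time steps and recurses with `A @ A`. The following is a minimal NumPy sketch of that idea, not the PyTorch code below; the function names are hypothetical, `A` is a single fixed `(N, N)` matrix, `u` has shape `(L, N)`, and the start state is zero.

```python
import numpy as np


def unroll_sequential(A, u):
    """x[i] = A @ x[i-1] + u[i], with x[-1] = 0."""
    x = np.zeros_like(u[0])
    out = []
    for u_i in u:
        x = A @ x + u_i
        out.append(x)
    return np.stack(out)


def unroll_divide_and_conquer(A, u):
    """Same recurrence, solving the odd time steps first with step matrix A @ A."""
    L = u.shape[0]
    if L == 1:
        return u.copy()
    if L % 2 == 1:  # pad to an even length, then drop the extra step
        padded = np.concatenate([u, np.zeros_like(u[:1])])
        return unroll_divide_and_conquer(A, padded)[:L]
    u_even, u_odd = u[0::2], u[1::2]
    u2 = u_even @ A.T + u_odd  # combined input for two steps at once
    x_odd = unroll_divide_and_conquer(A @ A, u2)
    # x[2j] = A @ x[2j-1] + u[2j], where x[-1] = 0
    shifted = np.concatenate([np.zeros_like(x_odd[:1]), x_odd[:-1]])
    x_even = shifted @ A.T + u_even
    x = np.empty_like(u)
    x[0::2] = x_even
    x[1::2] = x_odd
    return x


rng = np.random.default_rng(0)
A_demo = 0.5 * rng.standard_normal((4, 4))
u_demo = rng.standard_normal((7, 4))
x_seq = unroll_sequential(A_demo, u_demo)
x_dc = unroll_divide_and_conquer(A_demo, u_demo)
print(np.allclose(x_seq, x_dc))
```

The recursion halves the sequence length at each level, so it needs about `log2(L)` levels, at the cost of one extra matrix square per level.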

In [23]:
def shift_up(a, s=None, drop=True, dim=0):
    assert dim == 0
    if s is None:
        s = torch.zeros_like(a[0, ...])
    s = s.unsqueeze(dim)
    if drop:
        a = a[:-1, ...]
    return torch.cat((s, a), dim=dim)

def interleave(a, b, uneven=False, dim=0):
    """ Interleave two tensors of same shape """
    # assert(a.shape == b.shape)
    assert dim == 0 # TODO temporary to make handling uneven case easier
    if dim < 0:
        dim = a.dim() + dim  # normalize a negative dim (unreachable while the assert above holds)
    if uneven:
        a_ = a[-1:, ...]
        a = a[:-1, ...]
    c = torch.stack((a, b), dim+1)
    out_shape = list(a.shape)
    out_shape[dim] *= 2
    c = c.view(out_shape)
    if uneven:
        c = torch.cat((c, a_), dim=dim)
    return c

def batch_mult(A, u, has_batch=None):
    """ Matrix mult A @ u with special case to save memory if u has additional batch dim

    The batch dimension is assumed to be the second dimension
    A : (L, ..., N, N)
    u : (L, [B], ..., N)
    has_batch: True, False, or None. If None, determined automatically

    Output:
    x : (L, [B], ..., N)
      A @ u broadcasted appropriately
    """

    if has_batch is None:
        has_batch = len(u.shape) >= len(A.shape)

    if has_batch:
        u = u.permute([0] + list(range(2, len(u.shape))) + [1])
    else:
        u = u.unsqueeze(-1)
    v = (A @ u)
    if has_batch:
        v = v.permute([0] + [len(u.shape)-1] + list(range(1, len(u.shape)-1)))
    else:
        v = v[..., 0]
    return v



### Main unrolling functions

def unroll(A, u):
    """
    A : (..., N, N) # TODO I think this can't take batch dimension?
    u : (L, ..., N)
    output : x (L, ..., N) same shape as u # TODO a lot of these shapes are wrong
    x[i, ...] = A^{i} @ u[0, ...] + ... + A @ u[i-1, ...] + u[i, ...]
    """

    m = u.new_zeros(u.shape[1:])
    outputs = []
    for u_ in torch.unbind(u, dim=0):
        m = F.linear(m, A) + u_
        outputs.append(m)

    output = torch.stack(outputs, dim=0)
    return output


def parallel_unroll_recursive(A, u):
    """ Bottom-up divide-and-conquer version of unroll. """

    # Main recursive function
    def parallel_unroll_recursive_(A, u):
        if u.shape[0] == 1:
            return u

        u_evens = u[0::2, ...]
        u_odds = u[1::2, ...]

        # u2 = F.linear(u_evens, A) + u_odds
        u2 = (A @ u_evens.unsqueeze(-1)).squeeze(-1) + u_odds
        A2 = A @ A

        x_odds = parallel_unroll_recursive_(A2, u2)
        # x_evens = F.linear(shift_up(x_odds), A) + u_evens
        x_evens = (A @ shift_up(x_odds).unsqueeze(-1)).squeeze(-1) + u_evens

        x = interleave(x_evens, x_odds, dim=0)
        return x

    # Pad u to power of 2
    n = u.shape[0]
    m = int(math.ceil(math.log(n)/math.log(2)))
    N = 1 << m
    u = torch.cat((u, u.new_zeros((N-u.shape[0],) + u.shape[1:] )), dim=0)

    return parallel_unroll_recursive_(A, u)[:n, ...]



def parallel_unroll_recursive_br(A, u):
    """ Same as parallel_unroll_recursive but uses bit reversal for locality. """

    # Main recursive function
    def parallel_unroll_recursive_br_(A, u):
        n = u.shape[0]
        if n == 1:
            return u

        m = n//2
        u_0 = u[:m, ...]
        u_1 = u[m:, ...]

        u2 = F.linear(u_0, A) + u_1
        A2 = A @ A

        x_1 = parallel_unroll_recursive_br_(A2, u2)
        x_0 = F.linear(shift_up(x_1), A) + u_0

        # x = torch.cat((x_0, x_1), dim=0) # is there a way to do this with cat?
        x = interleave(x_0, x_1, dim=0)
        return x

    # Pad u to power of 2
    n = u.shape[0]
    m = int(math.ceil(math.log(n)/math.log(2)))
    N = 1 << m
    u = torch.cat((u, u.new_zeros((N-u.shape[0],) + u.shape[1:] )), dim=0)

    # Apply bit reversal
    br = bitreversal_po2(N)
    u = u[br, ...]

    x = parallel_unroll_recursive_br_(A, u)
    return x[:n, ...]

def parallel_unroll_iterative(A, u):
    """ Bottom-up divide-and-conquer version of unroll, implemented iteratively """

    # Pad u to power of 2
    n = u.shape[0]
    m = int(math.ceil(math.log(n)/math.log(2)))
    N = 1 << m
    u = torch.cat((u, u.new_zeros((N-u.shape[0],) + u.shape[1:] )), dim=0)

    # Apply bit reversal
    br = bitreversal_po2(N)
    u = u[br, ...]

    # Main recursive loop, flattened
    us = [] # stores the u_0 terms in the recursive version
    N_ = N
    As = [] # stores the A matrices
    for l in range(m):
        N_ = N_ // 2
        As.append(A)
        u_0 = u[:N_, ...]
        us.append(u_0)
        u = F.linear(u_0, A) + u[N_:, ...]
        A = A @ A
    x_0 = []
    x = u # x_1
    for l in range(m-1, -1, -1):
        x_0 = F.linear(shift_up(x), As[l]) + us[l]
        x = interleave(x_0, x, dim=0)

    return x[:n, ...]


def variable_unroll_sequential(A, u, s=None, variable=True):
    """ Unroll with variable (in time/length) transitions A.

    A : ([L], ..., N, N) dimension L should exist iff variable is True
    u : (L, [B], ..., N) updates
    s : ([B], ..., N) start state
    output : x (L, [B], ..., N) same shape as u
    x[i, ...] = A[i]..A[0] @ s + A[i..1] @ u[0] + ... + A[i] @ u[i-1] + u[i]
    """

    if s is None:
        s = torch.zeros_like(u[0])

    if not variable:
        A = A.expand((u.shape[0],) + A.shape)
    has_batch = len(u.shape) >= len(A.shape)

    outputs = []
    for (A_, u_) in zip(torch.unbind(A, dim=0), torch.unbind(u, dim=0)):
        # s = F.linear(s, A_) + u_
        s = batch_mult(A_.unsqueeze(0), s.unsqueeze(0), has_batch)[0]
        s = s + u_
        outputs.append(s)

    output = torch.stack(outputs, dim=0)
    return output



def variable_unroll(A, u, s=None, variable=True, recurse_limit=16):
    """ Bottom-up divide-and-conquer version of variable_unroll. """

    if u.shape[0] <= recurse_limit:
        return variable_unroll_sequential(A, u, s, variable)

    if s is None:
        s = torch.zeros_like(u[0])

    uneven = u.shape[0] % 2 == 1
    has_batch = len(u.shape) >= len(A.shape)

    u_0 = u[0::2, ...]
    u_1  = u[1::2, ...]

    if variable:
        A_0 = A[0::2, ...]
        A_1  = A[1::2, ...]
    else:
        A_0 = A
        A_1 = A

    u_0_ = u_0
    A_0_ = A_0
    if uneven:
        u_0_ = u_0[:-1, ...]
        if variable:
            A_0_ = A_0[:-1, ...]

    u_10 = batch_mult(A_1, u_0_, has_batch)
    u_10 = u_10 + u_1
    A_10 = A_1 @ A_0_

    # Recursive call
    x_1 = variable_unroll(A_10, u_10, s, variable, recurse_limit)

    x_0 = shift_up(x_1, s, drop=not uneven)
    x_0 = batch_mult(A_0, x_0, has_batch)
    x_0 = x_0 + u_0


    x = interleave(x_0, x_1, uneven, dim=0) # For some reason this interleave is slower than in the (non-multi) unroll_recursive
    return x

def variable_unroll_general_sequential(A, u, s, op, variable=True):
    """ Unroll with variable (in time/length) transitions A with general associative operation

    A : ([L], ..., N, N) dimension L should exist iff variable is True
    u : (L, [B], ..., N) updates
    s : ([B], ..., N) start state
    output : x (L, [B], ..., N) same shape as u
    x[i, ...] = A[i]..A[0] s + A[i..1] u[0] + ... + A[i] u[i-1] + u[i]
    """

    if not variable:
        A = A.expand((u.shape[0],) + A.shape)

    outputs = []
    for (A_, u_) in zip(torch.unbind(A, dim=0), torch.unbind(u, dim=0)):
        s = op(A_, s)
        s = s + u_
        outputs.append(s)

    output = torch.stack(outputs, dim=0)
    return output

def variable_unroll_matrix_sequential(A, u, s=None, variable=True):
    if s is None:
        s = torch.zeros_like(u[0])

    if not variable:
        A = A.expand((u.shape[0],) + A.shape)
    # has_batch = len(u.shape) >= len(A.shape)

    # op = lambda x, y: batch_mult(x.unsqueeze(0), y.unsqueeze(0), has_batch)[0]
    op = lambda x, y: batch_mult(x.unsqueeze(0), y.unsqueeze(0))[0]

    return variable_unroll_general_sequential(A, u, s, op, variable=True)

def variable_unroll_toeplitz_sequential(A, u, s=None, variable=True, pad=False):
    if s is None:
        s = torch.zeros_like(u[0])

    if not variable:
        A = A.expand((u.shape[0],) + A.shape)
    # has_batch = len(u.shape) >= len(A.shape)

    # op = lambda x, y: batch_mult(x.unsqueeze(0), y.unsqueeze(0), has_batch)[0]
    # op = lambda x, y: batch_mult(x.unsqueeze(0), y.unsqueeze(0))[0]

    if pad:
        n = A.shape[-1]
        A = F.pad(A, (0, n))
        u = F.pad(u, (0, n))
        s = F.pad(s, (0, n))
        ret = variable_unroll_general_sequential(A, u, s, triangular_toeplitz_multiply_padded, variable=True)
        ret = ret[..., :n]
        return ret

    return variable_unroll_general_sequential(A, u, s, triangular_toeplitz_multiply, variable=True)



### General parallel scan functions with generic binary composition operators

def variable_unroll_general(A, u, s, op, compose_op=None, sequential_op=None, variable=True, recurse_limit=16):
    """ Bottom-up divide-and-conquer version of variable_unroll.

    compose is an optional function that defines how to compose A without multiplying by a leaf u
    """

    if u.shape[0] <= recurse_limit:
        if sequential_op is None:
            sequential_op = op
        return variable_unroll_general_sequential(A, u, s, sequential_op, variable)

    if compose_op is None:
        compose_op = op

    uneven = u.shape[0] % 2 == 1
    # has_batch = len(u.shape) >= len(A.shape)

    u_0 = u[0::2, ...]
    u_1 = u[1::2, ...]

    if variable:
        A_0 = A[0::2, ...]
        A_1 = A[1::2, ...]
    else:
        A_0 = A
        A_1 = A

    u_0_ = u_0
    A_0_ = A_0
    if uneven:
        u_0_ = u_0[:-1, ...]
        if variable:
            A_0_ = A_0[:-1, ...]

    u_10 = op(A_1, u_0_) # batch_mult(A_1, u_0_, has_batch)
    u_10 = u_10 + u_1
    A_10 = compose_op(A_1, A_0_)

    # Recursive call
    x_1 = variable_unroll_general(A_10, u_10, s, op, compose_op, sequential_op, variable=variable, recurse_limit=recurse_limit)

    x_0 = shift_up(x_1, s, drop=not uneven)
    x_0 = op(A_0, x_0) # batch_mult(A_0, x_0, has_batch)
    x_0 = x_0 + u_0


    x = interleave(x_0, x_1, uneven, dim=0) # For some reason this interleave is slower than in the (non-multi) unroll_recursive
    return x

def variable_unroll_matrix(A, u, s=None, variable=True, recurse_limit=16):
    if s is None:
        s = torch.zeros_like(u[0])
    has_batch = len(u.shape) >= len(A.shape)
    op = lambda x, y: batch_mult(x, y, has_batch)
    sequential_op = lambda x, y: batch_mult(x.unsqueeze(0), y.unsqueeze(0), has_batch)[0]
    matmul = lambda x, y: x @ y
    return variable_unroll_general(A, u, s, op, compose_op=matmul, sequential_op=sequential_op, variable=variable, recurse_limit=recurse_limit)

def variable_unroll_toeplitz(A, u, s=None, variable=True, recurse_limit=8, pad=False):
    """ Unroll with variable (in time/length) transitions A with general associative operation

    A : ([L], ..., N) dimension L should exist iff variable is True
    u : (L, [B], ..., N) updates
    s : ([B], ..., N) start state
    output : x (L, [B], ..., N) same shape as u
    x[i, ...] = A[i]..A[0] s + A[i..1] u[0] + ... + A[i] u[i-1] + u[i]
    """
    # Add the batch dimension to A if necessary
    A_batch_dims = len(A.shape) - int(variable)
    u_batch_dims = len(u.shape)-1
    if u_batch_dims > A_batch_dims:
        # assert u_batch_dims == A_batch_dims + 1
        if variable:
            while len(A.shape) < len(u.shape):
                A = A.unsqueeze(1)
        # else:
        #     A = A.unsqueeze(0)

    if s is None:
        s = torch.zeros_like(u[0])

    if pad:
        n = A.shape[-1]
        A = F.pad(A, (0, n))
        u = F.pad(u, (0, n))
        s = F.pad(s, (0, n))
        op = triangular_toeplitz_multiply_padded
        ret = variable_unroll_general(A, u, s, op, compose_op=op, variable=variable, recurse_limit=recurse_limit)
        ret = ret[..., :n]
        return ret

    op = triangular_toeplitz_multiply
    ret = variable_unroll_general(A, u, s, op, compose_op=op, variable=variable, recurse_limit=recurse_limit)
    return ret
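
The even/odd pairing in `variable_unroll_general` can be checked on a small dense example. The sketch below is a NumPy stand-in, not the code used in this notebook: `unroll_sequential` and `unroll_pairwise` are hypothetical names, `A` is assumed to be a stack of `(L, N, N)` matrices with `L` even, and only a single level of pairing is applied rather than the full recursion.

```python
import numpy as np

def unroll_sequential(A, u, s):
    """Reference recurrence: x[i] = A[i] @ x[i-1] + u[i], with x[-1] = s."""
    x, out = s, []
    for A_i, u_i in zip(A, u):
        x = A_i @ x + u_i
        out.append(x)
    return np.stack(out)

def unroll_pairwise(A, u, s):
    """One level of even/odd pairing (assumes an even sequence length)."""
    A_0, A_1 = A[0::2], A[1::2]
    u_0, u_1 = u[0::2], u[1::2]
    # Fold each (even, odd) pair into one step of a half-length recurrence.
    u_10 = np.einsum("lij,lj->li", A_1, u_0) + u_1
    A_10 = A_1 @ A_0
    x_1 = unroll_sequential(A_10, u_10, s)  # states at odd positions
    # Even positions: the previous odd state (or s) pushed through A_0.
    prev = np.concatenate([s[None], x_1[:-1]], axis=0)
    x_0 = np.einsum("lij,lj->li", A_0, prev) + u_0
    x = np.empty_like(u)
    x[0::2], x[1::2] = x_0, x_1
    return x

rng = np.random.default_rng(0)
L, N = 8, 3
A = rng.normal(size=(L, N, N))
u = rng.normal(size=(L, N))
s = rng.normal(size=N)
print(np.allclose(unroll_sequential(A, u, s), unroll_pairwise(A, u, s)))
```

Both routes should agree up to floating-point error, which is the property the recursive version relies on at every level.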

## Gu's HiPPO LegT Operator

In [24]:
class HiPPO_LegT(nn.Module):
    def __init__(self, N, dt=1.0, discretization="bilinear", lambda_n=1.0):
        """
        N: the order of the HiPPO projection
        dt: discretization step size - should be roughly inverse to the length of the sequence
        """
        super().__init__()
        self.N = N
        # A, B = transition('lmu', N)
        legt_matrices = GuTransMatrix(N=N, measure="legt", lambda_n=lambda_n)
        A = legt_matrices.A_matrix
        B = legt_matrices.B_matrix
        C = np.ones((1, N))
        D = np.zeros((1,))
        # dt, discretization options
        A, B, _, _, _ = signal.cont2discrete((A, B, C, D), dt=dt, method=discretization)

        B = B.squeeze(-1)

        self.register_buffer("A", torch.Tensor(A))  # (N, N)
        self.register_buffer("B", torch.Tensor(B))  # (N,)

        # vals = np.linspace(0.0, 1.0, 1./dt)
        vals = np.arange(0.0, 1.0, dt)
        self.eval_matrix = torch.Tensor(
            ss.eval_legendre(np.arange(N)[:, None], 1 - 2 * vals).T
        )

    def forward(self, inputs):
        """
        inputs : (length, ...)
        output : (length, ..., N) where N is the order of the HiPPO projection
        """

        inputs = inputs.unsqueeze(-1)
        u = inputs * self.B  # (length, ..., N)

        c = torch.zeros(u.shape[1:], device=u.device)  # start state on the same device as the input
        cs = []
        for f in inputs:
            c = F.linear(c, self.A) + self.B * f
            # print(f"f:\n{f}")
            cs.append(c)
        return torch.stack(cs, dim=0)

    def reconstruct(self, c):
        return (self.eval_matrix @ c.unsqueeze(-1)).squeeze(-1)
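
`reconstruct` multiplies the coefficients by a matrix of Legendre polynomials evaluated on a grid. As a rough NumPy-only sketch of how such a matrix can be built (an illustration under assumptions: `legendre_eval_matrix` is a hypothetical helper, and `numpy.polynomial.legendre.legvander` is swapped in for `scipy.special.eval_legendre`):

```python
import numpy as np
from numpy.polynomial import legendre

def legendre_eval_matrix(N, dt):
    """Rows index grid points, columns index polynomial degree 0..N-1."""
    vals = np.arange(0.0, 1.0, dt)       # grid on [0, 1)
    x = 1 - 2 * vals                     # mapped into (-1, 1], as in HiPPO_LegT
    return legendre.legvander(x, N - 1)  # shape (len(vals), N)

E = legendre_eval_matrix(N=4, dt=0.25)
c = np.array([1.0, 0.0, 0.0, 0.0])       # only the degree-0 coefficient is set
print(E.shape)                           # (4, 4)
print(E @ c)                             # [1. 1. 1. 1.]
```

With only the degree-0 coefficient set, the reconstruction is the constant function, since P_0(x) = 1 everywhere on the grid.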


## Gu's Scale invariant HiPPO LegS Operator

In [25]:
class HiPPO_LegS(nn.Module):
    """Vanilla HiPPO-LegS model (scale invariant instead of time invariant)"""

    def __init__(self, N, max_length=1024, measure="legs", discretization="bilinear"):
        """
        max_length: maximum sequence length
        """
        super().__init__()
        self.N = N
        legs_matrices = GuTransMatrix(N=self.N, measure=measure)
        A = legs_matrices.A_matrix
        B = legs_matrices.B_matrix
        # A, B = transition(measure, N)
        B = B.squeeze(-1)
        A_stacked = np.empty((max_length, N, N), dtype=A.dtype)
        B_stacked = np.empty((max_length, N), dtype=B.dtype)
        for t in range(1, max_length + 1):
            At = A / t
            Bt = B / t
            if discretization == "forward":
                A_stacked[t - 1] = np.eye(N) + At
                B_stacked[t - 1] = Bt
            elif discretization == "backward":
                A_stacked[t - 1] = la.solve_triangular(
                    np.eye(N) - At, np.eye(N), lower=True
                )
                B_stacked[t - 1] = la.solve_triangular(np.eye(N) - At, Bt, lower=True)
            elif discretization == "bilinear":
                alpha = 0.5
                # TODO: lstsq is used here, referencing this:
                # https://stackoverflow.com/questions/64527098/numpy-linalg-linalgerror-singular-matrix-error-when-trying-to-solve
                A_stacked[t - 1] = np.linalg.lstsq(
                    np.eye(N) - (At * alpha), np.eye(N) + (At * alpha), rcond=None
                )[0]
                B_stacked[t - 1] = np.linalg.lstsq(
                    np.eye(N) - (At * alpha), Bt, rcond=None
                )[0]
            else:  # ZOH
                A_stacked[t - 1] = la.expm(A * (math.log(t + 1) - math.log(t)))
                B_stacked[t - 1] = la.solve_triangular(
                    A, A_stacked[t - 1] @ B - B, lower=True
                )
        self.A_stacked = torch.Tensor(A_stacked.copy())  # (max_length, N, N)
        self.B_stacked = torch.Tensor(B_stacked.copy())  # (max_length, N)
        vals = np.linspace(0.0, 1.0, max_length)
        self.eval_matrix = torch.from_numpy(
            np.asarray(
                ((B[:, None] * ss.eval_legendre(np.arange(N)[:, None], 2 * vals - 1)).T)
            )
        )

    def forward(self, inputs, fast=False):
        """
        inputs : (length, ...)
        output : (length, ..., N) where N is the order of the HiPPO projection
        """
        result = None

        L = inputs.shape[0]

        u = inputs.unsqueeze(-1)
        u = torch.transpose(u, 0, -2)
        u = u * self.B_stacked[:L]  # c_k = A @ c_{k-1} + B @ f_k
        # print(f"u - Gu: {u}")
        u = torch.transpose(u, 0, -2)  # (length, ..., N)

        # print(f"A_stacked: {self.A_stacked[:L]}")
        # print(f"B_stacked: {self.B_stacked[:L]}")

        if fast:
            result = variable_unroll_matrix(self.A_stacked[:L], u)

        else:
            result = variable_unroll_matrix_sequential(self.A_stacked[:L], u)

        return result

    def reconstruct(self, c):
        a = self.eval_matrix @ c.unsqueeze(-1)
        return a.squeeze(-1)
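
A quick sanity check on the LegS recurrence: with a constant input of 1, the coefficient vector `[1, 0, ..., 0]` should be left unchanged by every bilinear step, because it already represents the constant function. The sketch below assumes the LegS matrices from the HiPPO paper (written out by hand in the hypothetical helper `legs_matrices`, which may differ in convention from `GuTransMatrix`) and uses `np.linalg.solve` in place of `lstsq`.

```python
import numpy as np

def legs_matrices(N):
    """LegS A (lower triangular) and B, written from the formula in the HiPPO paper."""
    q = np.sqrt(2 * np.arange(N) + 1)
    A = -np.tril(np.outer(q, q), k=-1) - np.diag(np.arange(N) + 1.0)
    return A, q.copy()

def legs_bilinear_step(A, B, c_prev, f_t, t):
    """One bilinear step of c' = (A c + B f) / t at integer time t >= 1."""
    I = np.eye(A.shape[0])
    At, Bt = A / t, B / t
    lhs = I - 0.5 * At
    return np.linalg.solve(lhs, (I + 0.5 * At) @ c_prev + Bt * f_t)

N = 4
A, B = legs_matrices(N)
c = np.zeros(N)
c[0] = 1.0                      # coefficients of the constant function 1
for t in range(1, 6):
    c = legs_bilinear_step(A, B, c, f_t=1.0, t=t)
print(np.round(c, 6))           # should stay (numerically) at [1, 0, 0, 0]
```

The fixed point follows from `A @ e_0 == -B` for these matrices, so the `A` and `B` contributions cancel at every step regardless of `t`.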

## Implementation Of General HiPPO Operator

In [26]:
class nb_HiPPO(jnn.Module):
    """
    Constructs a non-batched HiPPO model for the given measure.

    Args:
        N (int): order of the HiPPO projection, i.e. the number of coefficients used to represent the input
        max_length (int): maximum sequence length to be input
        step (float): step size used for discretization
        GBT_alpha (float): selects which discretization the generalized bilinear transform uses
            - 0: forward Euler
            - 0.5: bilinear
            - 1: backward Euler
            - > 1: zero-order hold
        seq_L (int): length of the sequence to be used for training
        A (jnp.ndarray): continuous-time HiPPO A matrix
        B (jnp.ndarray): continuous-time HiPPO B matrix
        measure (str): the measure the HiPPO matrices were instantiated with
            - 'legt': translated Legendre
            - 'legs': scaled Legendre
            - 'lagt', 'fourier': recognized, but not implemented in this class
    """

    N: int
    max_length: int
    step: float
    GBT_alpha: float
    seq_L: int
    A: jnp.ndarray
    B: jnp.ndarray
    measure: str

    def setup(self):
        A = self.A
        B = self.B
        self.C = jnp.ones((self.N,))
        self.D = jnp.zeros((1,))

        if self.measure == "legt":
            L = self.seq_L
            vals = jnp.arange(0.0, 1.0, 1.0 / L)  # grid spacing is dt = 1 / L, not L
            # n = jnp.arange(self.N)[:, None]
            zero_N = self.N - 1
            x = 1 - 2 * vals
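            # lpmn_values returns shape (m + 1, n + 1, len(x)); index 0 selects order m = 0,
            # i.e. the Legendre polynomials, matching ss.eval_legendre(n, x).T
            self.eval_matrix = jax.scipy.special.lpmn_values(
                m=zero_N, n=zero_N, z=x, is_normalized=False
            )[0].T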

        elif self.measure == "legs":
            L = self.max_length
            vals = jnp.linspace(0.0, 1.0, L)
            # n = jnp.arange(self.N)[:, None]
            zero_N = self.N - 1
            x = 2 * vals - 1
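            # lpmn_values returns shape (m + 1, n + 1, len(x)); index 0 selects order m = 0,
            # i.e. the Legendre polynomials, matching ss.eval_legendre(n, x)
            self.eval_matrix = (
                B[:, None]
                * jax.scipy.special.lpmn_values(
                    m=zero_N, n=zero_N, z=x, is_normalized=False
                )[0]
            ).T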

        elif self.measure == "lagt":
            raise NotImplementedError("Translated Laguerre measure not implemented yet")

        elif self.measure == "fourier":
            raise NotImplementedError("Fourier measures are not implemented yet")

        else:
            raise ValueError("invalid measure")

    def __call__(self, f, init_state=None, t_step=0, kernel=False):
        # print(f"u shape:\n{f.shape}")
        # print(f"u:\n{f}")
        if not kernel:
            if init_state is None:
                init_state = jnp.zeros((self.N, 1))

            # Ab, Bb, Cb, Db = self.collect_SSM_vars(
            #     self.A, self.B, self.C, self.D, f, t_step=t_step, alpha=self.GBT_alpha
            # )
            c_k, y_k, GBT_A, GBT_B = self.loop_SSM(
                A=self.A,
                B=self.B,
                C=self.C,
                D=self.D,
                c_0=init_state,
                f=f,
                alpha=self.GBT_alpha,
            )
            # c_k, y_k = self.scan_SSM(Ab=Ab, Bb=Bb, Cb=Cb, Db=Db, c_0=init_state, f=f)

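        else:
            # NOTE: K_conv and causal_convolution are not defined in this class,
            # so this branch only works once they are provided.
            Ab, Bb, Cb, Db = self.discretize(
                self.A, self.B, self.C, self.D, step=self.step, alpha=self.GBT_alpha
            )
            c_k, y_k = self.causal_convolution(
                f, self.K_conv(Ab, Bb, Cb, Db, L=self.max_length)
            )
            # Return the single discretization so both branches yield four values.
            GBT_A, GBT_B = Ab, Bb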

        return c_k, y_k, GBT_A, GBT_B

    def reconstruct(self, c):
        """
        Uses coefficients to reconstruct the signal

        Args:
            c (jnp.ndarray): coefficients of the HiPPO projection

        Returns:
            reconstructed signal
        """
        return (self.eval_matrix @ jnp.expand_dims(c, -1)).squeeze(-1)

    def discretize(self, A, B, C, D, step, alpha=0.5):
        """
        function used for discretizing the HiPPO matrix

        Args:
            A (jnp.ndarray): matrix to be discretized
            B (jnp.ndarray): matrix to be discretized
            C (jnp.ndarray): matrix to be discretized
            D (jnp.ndarray): matrix to be discretized
            step (float): step size used for discretization
            alpha (float, optional): used for determining which generalized bilinear transformation to use
                - forward Euler corresponds to α = 0,
                - backward Euler corresponds to α = 1,
                - bilinear corresponds to α = 0.5,
                - Zero-order Hold corresponds to α > 1
        """
        I = jnp.eye(A.shape[0])
        step_size = 1 / step
        part1 = I - (step_size * alpha * A)
        part2 = I + (step_size * (1 - alpha) * A)

        GBT_A = jnp.linalg.lstsq(part1, part2, rcond=None)[0]

        base_GBT_B = jnp.linalg.lstsq(part1, B, rcond=None)[0]
        GBT_B = step_size * base_GBT_B

        if alpha > 1:  # Zero-order Hold
            GBT_A = jax.scipy.linalg.expm(step_size * A)
            GBT_B = (jnp.linalg.inv(A) @ (jax.scipy.linalg.expm(step_size * A) - I)) @ B

        return (
            GBT_A.astype(jnp.float32),
            GBT_B.astype(jnp.float32),
            C.astype(jnp.float32),
            D.astype(jnp.float32),
        )

    def collect_SSM_vars(self, A, B, C, D, f, t_step=0, alpha=0.5):
        """
        turns the continuous HiPPO matrix components into discrete ones

        Args:
            A (jnp.ndarray): matrix to be discretized
            B (jnp.ndarray): matrix to be discretized
            C (jnp.ndarray): matrix to be discretized
            D (jnp.ndarray): matrix to be discretized
            f (jnp.ndarray): input signal
            t_step (int, optional): if nonzero, used as the discretization step in place of the sequence length
            alpha (float, optional): used for determining which generalized bilinear transformation to use

        Returns:
            Ab (jnp.ndarray): discrete form of the HiPPO matrix
            Bb (jnp.ndarray): discrete form of the HiPPO matrix
            Cb (jnp.ndarray): discrete form of the HiPPO matrix
            Db (jnp.ndarray): discrete form of the HiPPO matrix
        """
        N = A.shape[0]

        if t_step == 0:
            L = f.shape[0]  # seq_L, 1
            assert (
                L == self.seq_L
            ), f"sequence length must match, currently {L} != {self.seq_L}"
            assert N == self.N, f"Order number must match, currently {N} != {self.N}"
        else:
            L = t_step
            assert t_step >= 1, f"time step must be greater than 0, currently {t_step}"
            assert N == self.N, f"Order number must match, currently {N} != {self.N}"

        Ab, Bb, Cb, Db = self.discretize(A, B, C, D, step=L, alpha=alpha)

        return (
            Ab.astype(jnp.float32),
            Bb.astype(jnp.float32),
            Cb.astype(jnp.float32),
            Db.astype(jnp.float32),
        )

    def scan_SSM(self, Ad, Bd, Cd, Dd, c_0, f):
        """
        Runs the discretized recurrence over the input with jax.lax.scan, giving the hidden state often needed for an RNN.
        Args:
            Ad (jnp.ndarray): the discretized A matrix
            Bd (jnp.ndarray): the discretized B matrix
            Cd (jnp.ndarray): the discretized C matrix
            Dd (jnp.ndarray): the discretized D matrix (currently unused)
            c_0 (jnp.ndarray): the initial hidden state
            f (jnp.ndarray): the input sequence
        Returns:
            the final hidden state (the coefficients representing the function f(t)) and the stacked outputs y_k
        """

        def step(c_k_1, f_k):
            """
            Get the discretized coefficients of the hidden state by applying the HiPPO matrices to the input, f_k, and the previous hidden state, c_k_1.
            Args:
                c_k_1: previous hidden state
                f_k: value of the function f at discretized time step k

            Returns:
                c_k: current hidden state
                y_k: output at step k, i.e. Cd applied to the current hidden state
            """
            part1 = Ad @ c_k_1
            part2 = jnp.expand_dims((Bd @ f_k), -1)

            c_k = part1 + part2
            y_k = Cd @ c_k  # + (Db.T @ f_k)

            return c_k, y_k

        return jax.lax.scan(step, c_0, f)

    def loop_SSM(self, A, B, C, D, c_0, f, alpha=0.5):
        """
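        Runs the recurrence with an explicit Python loop, re-discretizing A and B at every time step (t_step = 1, 2, ...).
        Args:
            A (jnp.ndarray): the continuous A matrix
            B (jnp.ndarray): the continuous B matrix
            C (jnp.ndarray): the continuous C matrix
            D (jnp.ndarray): the continuous D matrix
            c_0 (jnp.ndarray): the initial hidden state
            f (jnp.ndarray): the input sequence
            alpha (float, optional): generalized bilinear transform parameter passed on to discretize
        Returns:
            four lists with one entry per time step: the hidden states c_k (the coefficients representing f(t)),
            the outputs y_k, and the discretized A and B matrices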
        """
        GBT_A_lst = []
        GBT_B_lst = []
        c_k_list = []
        y_k_list = []

        c_k = c_0.copy()
        for i in range(1, f.shape[0] + 1):
            Ad_i, Bd_i, Cd_i, Dd_i = self.collect_SSM_vars(
                A=A, B=B, C=C, D=D, f=f, t_step=i, alpha=alpha
            )
            c_k, y_k = self.loop_step(
                Ad=Ad_i, Bd=Bd_i, Cd=Cd_i, Dd=Dd_i, c_k_i=c_k, f_k=f[i - 1][0]
            )
            c_k_list.append(c_k.copy())
            y_k_list.append(y_k.copy())
            GBT_A_lst.append(Ad_i.copy())
            GBT_B_lst.append(Bd_i.copy())

        return c_k_list, y_k_list, GBT_A_lst, GBT_B_lst

    def loop_step(self, Ad, Bd, Cd, Dd, c_k_i, f_k):
        """
        Get the discretized coefficients of the hidden state by applying the HiPPO matrices to the input, f_k, and the previous hidden state, c_k_i.
        Args:
            Ad, Bd, Cd, Dd: discretized matrices for this time step (Dd is currently unused)
            c_k_i: previous hidden state
            f_k: value of the function f at discretized time step k

        Returns:
            c_k: current hidden state
            y_k: output at step k, i.e. Cd applied to the current hidden state
        """
        # print(f"c_k_i:\n{c_k_i}")
        # print(f"f_k:\n{f_k}")

        part1 = Ad @ c_k_i
        part2 = Bd * f_k
        c_k = part1 + part2
        y_k = Cd @ c_k  # + (Db.T @ f_k)

        return c_k.astype(jnp.float32), y_k.astype(jnp.float32)
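
The role of `GBT_alpha` in `discretize` is easiest to see on a scalar system. The following is a minimal NumPy sketch of the generalized bilinear transform, not the class method itself: `gbt` is a hypothetical helper, `dt` plays the role of `step_size = 1 / step`, and `np.linalg.solve` stands in for `lstsq`.

```python
import numpy as np

def gbt(A, B, dt, alpha=0.5):
    """Generalized bilinear transform of c' = A c + B f with step dt."""
    I = np.eye(A.shape[0])
    lhs = I - dt * alpha * A
    Ad = np.linalg.solve(lhs, I + dt * (1 - alpha) * A)
    Bd = np.linalg.solve(lhs, dt * B)
    return Ad, Bd

A = np.array([[-1.0]])
B = np.array([[1.0]])
for name, alpha in [("forward Euler", 0.0), ("bilinear", 0.5), ("backward Euler", 1.0)]:
    Ad, Bd = gbt(A, B, dt=0.1, alpha=alpha)
    print(f"{name}: Ad={Ad[0, 0]:.4f}, Bd={Bd[0, 0]:.4f}")
```

For this scalar case the three choices give `Ad` of 0.9, about 0.9048, and about 0.9091, all approximations of the exact decay `exp(-0.1)`, with the bilinear value the closest.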
    

In [27]:
class HiPPO(jnn.Module):
    """
    Constructs a batched HiPPO model for the given measure.

    Args:
        N (int): order of the HiPPO projection, i.e. the number of coefficients used to represent the input
        max_length (int): maximum sequence length to be input
        step (float): step size used for discretization
        GBT_alpha (float): selects which discretization the generalized bilinear transform uses
            - 0: forward Euler
            - 0.5: bilinear
            - 1: backward Euler
            - > 1: zero-order hold
        seq_L (int): length of the sequence to be used for training
        A (jnp.ndarray): continuous-time HiPPO A matrix
        B (jnp.ndarray): continuous-time HiPPO B matrix
        measure (str): the measure the HiPPO matrices were instantiated with
            - 'legt': translated Legendre
            - 'legs': scaled Legendre
            - 'lagt', 'fourier': recognized, but not implemented in this class
    """

    N: int
    max_length: int
    step: float
    GBT_alpha: float
    seq_L: int
    A: jnp.ndarray
    B: jnp.ndarray
    measure: str

    def setup(self):
        A = self.A
        B = self.B
        self.C = jnp.ones((self.N,))
        self.D = jnp.zeros((1,))

        if self.measure == "legt":
            L = self.seq_L
            vals = jnp.arange(0.0, 1.0, 1.0 / L)  # grid spacing is dt = 1 / L, not L
            # n = jnp.arange(self.N)[:, None]
            zero_N = self.N - 1
            x = 1 - 2 * vals
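            # lpmn_values returns shape (m + 1, n + 1, len(x)); index 0 selects order m = 0,
            # i.e. the Legendre polynomials, matching ss.eval_legendre(n, x).T
            self.eval_matrix = jax.scipy.special.lpmn_values(
                m=zero_N, n=zero_N, z=x, is_normalized=False
            )[0].T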

        elif self.measure == "legs":
            L = self.max_length
            vals = jnp.linspace(0.0, 1.0, L)
            # n = jnp.arange(self.N)[:, None]
            zero_N = self.N - 1
            x = 2 * vals - 1
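            # lpmn_values returns shape (m + 1, n + 1, len(x)); index 0 selects order m = 0,
            # i.e. the Legendre polynomials, matching ss.eval_legendre(n, x)
            self.eval_matrix = (
                B[:, None]
                * jax.scipy.special.lpmn_values(
                    m=zero_N, n=zero_N, z=x, is_normalized=False
                )[0]
            ).T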

        elif self.measure == "lagt":
            raise NotImplementedError("Translated Laguerre measure not implemented yet")

        elif self.measure == "fourier":
            raise NotImplementedError("Fourier measures are not implemented yet")

        else:
            raise ValueError("invalid measure")

    def __call__(self, f, init_state=None, t_step=0, kernel=False):
        # print(f"u shape:\n{f.shape}")
        # print(f"u:\n{f}")
        if not kernel:
            if init_state is None:
                init_state = jnp.zeros((f.shape[0], self.N, 1))

            # Ab, Bb, Cb, Db = self.collect_SSM_vars(
            #     self.A, self.B, self.C, self.D, f, t_step=t_step, alpha=self.GBT_alpha
            # )
            c_k, y_k, GBT_A, GBT_B = self.loop_SSM(
                A=self.A,
                B=self.B,
                C=self.C,
                D=self.D,
                c_0=init_state,
                f=f,
                alpha=self.GBT_alpha,
            )
            # c_k, y_k = self.scan_SSM(Ab=Ab, Bb=Bb, Cb=Cb, Db=Db, c_0=init_state, f=f)

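        else:
            # NOTE: K_conv and causal_convolution are not defined in this class,
            # so this branch only works once they are provided.
            Ab, Bb, Cb, Db = self.discretize(
                self.A, self.B, self.C, self.D, step=self.step, alpha=self.GBT_alpha
            )
            c_k, y_k = self.causal_convolution(
                f, self.K_conv(Ab, Bb, Cb, Db, L=self.max_length)
            )
            # Return the single discretization so both branches yield four values.
            GBT_A, GBT_B = Ab, Bb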

        return c_k, y_k, GBT_A, GBT_B

    def reconstruct(self, c):
        """
        Uses coefficients to reconstruct the signal

        Args:
            c (jnp.ndarray): coefficients of the HiPPO projection

        Returns:
            reconstructed signal
        """
        return (self.eval_matrix @ jnp.expand_dims(c, -1)).squeeze(-1)

    def discretize(self, A, B, C, D, step, alpha=0.5):
        """
        function used for discretizing the HiPPO matrix

        Args:
            A (jnp.ndarray): matrix to be discretized
            B (jnp.ndarray): matrix to be discretized
            C (jnp.ndarray): matrix to be discretized
            D (jnp.ndarray): matrix to be discretized
            step (float): step size used for discretization
            alpha (float, optional): used for determining which generalized bilinear transformation to use
                - forward Euler corresponds to α = 0,
                - backward Euler corresponds to α = 1,
                - bilinear corresponds to α = 0.5,
                - Zero-order Hold corresponds to α > 1
        """
        I = jnp.eye(A.shape[0])
        step_size = 1 / step
        part1 = I - (step_size * alpha * A)
        part2 = I + (step_size * (1 - alpha) * A)

        GBT_A = jnp.linalg.lstsq(part1, part2, rcond=None)[0]

        base_GBT_B = jnp.linalg.lstsq(part1, B, rcond=None)[0]
        GBT_B = step_size * base_GBT_B

        if alpha > 1:  # Zero-order Hold
            GBT_A = jax.scipy.linalg.expm(step_size * A)
            GBT_B = (jnp.linalg.inv(A) @ (jax.scipy.linalg.expm(step_size * A) - I)) @ B

        return (
            GBT_A.astype(jnp.float32),
            GBT_B.astype(jnp.float32),
            C.astype(jnp.float32),
            D.astype(jnp.float32),
        )

    def collect_SSM_vars(self, A, B, C, D, f, t_step=0, alpha=0.5):
        """
        turns the continuous HiPPO matrix components into discrete ones

        Args:
            A (jnp.ndarray): matrix to be discretized
            B (jnp.ndarray): matrix to be discretized
            C (jnp.ndarray): matrix to be discretized
            D (jnp.ndarray): matrix to be discretized
            f (jnp.ndarray): input signal, with the batch on axis 0 and the sequence on axis 1
            t_step (int, optional): if nonzero, used as the discretization step in place of the sequence length
            alpha (float, optional): used for determining which generalized bilinear transformation to use

        Returns:
            Ab (jnp.ndarray): discrete form of the HiPPO matrix
            Bb (jnp.ndarray): discrete form of the HiPPO matrix
            Cb (jnp.ndarray): discrete form of the HiPPO matrix
            Db (jnp.ndarray): discrete form of the HiPPO matrix
        """
        N = A.shape[0]

        if t_step == 0:
            L = f.shape[1]  # seq_L, 1
            assert (
                L == self.seq_L
            ), f"sequence length must match, currently {L} != {self.seq_L}"
            assert N == self.N, f"Order number must match, currently {N} != {self.N}"
        else:
            L = t_step
            assert t_step >= 1, f"time step must be greater than 0, currently {t_step}"
            assert N == self.N, f"Order number must match, currently {N} != {self.N}"

        Ab, Bb, Cb, Db = self.discretize(A, B, C, D, step=L, alpha=alpha)

        return (
            Ab.astype(jnp.float32),
            Bb.astype(jnp.float32),
            Cb.astype(jnp.float32),
            Db.astype(jnp.float32),
        )

    def scan_SSM(self, Ad, Bd, Cd, Dd, c_0, f):
        """
        Runs the discretized recurrence over the input with jax.lax.scan, giving the hidden state often needed for an RNN.
        Args:
            Ad (jnp.ndarray): the discretized A matrix
            Bd (jnp.ndarray): the discretized B matrix
            Cd (jnp.ndarray): the discretized C matrix
            Dd (jnp.ndarray): the discretized D matrix (currently unused)
            c_0 (jnp.ndarray): the initial hidden state
            f (jnp.ndarray): the input sequence
        Returns:
            the final hidden state (the coefficients representing the function f(t)) and the stacked outputs y_k
        """

        def step(c_k_1, f_k):
            """
            Get the discretized coefficients of the hidden state by applying the HiPPO matrices to the input, f_k, and the previous hidden state, c_k_1.
            Args:
                c_k_1: previous hidden state
                f_k: value of the function f at discretized time step k

            Returns:
                c_k: current hidden state
                y_k: output at step k, i.e. Cd applied to the current hidden state
            """
            part1 = Ad @ c_k_1
            part2 = jnp.expand_dims((Bd @ f_k), -1)

            c_k = part1 + part2
            y_k = Cd @ c_k  # + (Db.T @ f_k)

            return c_k, y_k

        return jax.lax.scan(step, c_0, f)

    def loop_SSM(self, A, B, C, D, c_0, f, alpha=0.5):
        """
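        Runs the recurrence with an explicit Python loop over axis 1 of f, re-discretizing A and B at every time step
        (t_step = 1, 2, ...) and applying loop_step to every batch element with jax.vmap.
        Args:
            A (jnp.ndarray): the continuous A matrix
            B (jnp.ndarray): the continuous B matrix
            C (jnp.ndarray): the continuous C matrix
            D (jnp.ndarray): the continuous D matrix
            c_0 (jnp.ndarray): the initial hidden state, with the batch on axis 0
            f (jnp.ndarray): the input sequence, with the batch on axis 0 and the sequence on axis 1
            alpha (float, optional): generalized bilinear transform parameter passed on to discretize
        Returns:
            four lists with one entry per time step: the hidden states c_k (the coefficients representing f(t)),
            the outputs y_k, and the discretized A and B matrices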
        """
        GBT_A_lst = []
        GBT_B_lst = []
        c_k_list = []
        y_k_list = []

        c_k = c_0.copy()
        # print(f"f shape:\n{f.shape}")
        # jax.debug.print(f"f:\n{f}")
        for i in range(1, f.shape[1] + 1):
            Ad_i, Bd_i, Cd_i, Dd_i = self.collect_SSM_vars(
                A=A, B=B, C=C, D=D, f=f, t_step=i, alpha=alpha
            )
            # jax.debug.print(f"f[:,i-1,:] shape: {f[:,i-1,:].shape}")
            # jax.debug.print(f"f[:,i-1,:]: {f[:,i-1,:]}")
            # print(f"f[i - 1][0] shape: {f[i - 1].shape}")
            # print(f"f[i - 1][0]: {f[i - 1]}")
            # print(f"c_k shape: {c_k.shape}")
            # c_k, y_k = self.loop_step(
            #     Ad=Ad_i, Bd=Bd_i, Cd=Cd_i, Dd=Dd_i, c_k_i=c_k, f_k=f[i-1,:][0]
            # )
            c_k, y_k = jax.vmap(self.loop_step, in_axes=(None, None, None, None, 0, 0))(
                Ad_i, Bd_i, Cd_i, Dd_i, c_k, f[:,i-1,:]
            )
            c_k_list.append(c_k.copy())
            y_k_list.append(y_k.copy())
            GBT_A_lst.append(Ad_i.copy())
            GBT_B_lst.append(Bd_i.copy())

        return c_k_list, y_k_list, GBT_A_lst, GBT_B_lst

    def loop_step(self, Ad, Bd, Cd, Dd, c_k_i, f_k):
        """
        Get the discretized coefficients of the hidden state by applying the HiPPO matrices to the input, f_k, and the previous hidden state, c_k_i.
        Args:
            Ad, Bd, Cd, Dd: discretized matrices for this time step (Dd is currently unused)
            c_k_i: previous hidden state
            f_k: value of the function f at discretized time step k

        Returns:
            c_k: current hidden state
            y_k: output at step k, i.e. Cd applied to the current hidden state
        """
        # jax.debug.print(f"c_k_i:\n{c_k_i}")
        # jax.debug.print(f"f_k:\n{f_k}")
        
        part1 = Ad @ c_k_i
        part2 = Bd * f_k
        c_k = part1 + part2
        y_k = Cd @ c_k  # + (Db.T @ f_k)

        return c_k.astype(jnp.float32), y_k.astype(jnp.float32)
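
The batched class maps `loop_step` over the leading axis of the state and the input while sharing the discretized matrices. What that mapping is expected to produce can be sketched without JAX. This is a NumPy stand-in under assumed shapes (`Ad`: `(N, N)`, `Bd`: `(N, 1)`, `Cd`: `(N,)`, state: `(batch, N, 1)`, input: `(batch, 1)`), and `batched_step` is a hypothetical name.

```python
import numpy as np

def loop_step(Ad, Bd, Cd, c_prev, f_k):
    """Single-example update: c_k = Ad c_{k-1} + Bd f_k, y_k = Cd c_k."""
    c_k = Ad @ c_prev + Bd * f_k
    return c_k, Cd @ c_k

def batched_step(Ad, Bd, Cd, c_prev, f_k):
    """Same update over a leading batch axis, written with broadcasting."""
    c_k = Ad @ c_prev + Bd * f_k[:, None, :]
    return c_k, Cd @ c_k

rng = np.random.default_rng(0)
batch, N = 3, 4
Ad = rng.normal(size=(N, N))
Bd = rng.normal(size=(N, 1))
Cd = np.ones(N)
c_prev = rng.normal(size=(batch, N, 1))
f_k = rng.normal(size=(batch, 1))

c_b, y_b = batched_step(Ad, Bd, Cd, c_prev, f_k)
per_example = [loop_step(Ad, Bd, Cd, c_prev[b], f_k[b]) for b in range(batch)]
c_ref = np.stack([c for c, _ in per_example])
y_ref = np.stack([y for _, y in per_example])
print(np.allclose(c_b, c_ref), np.allclose(y_b, y_ref))
```

Stacking the per-example results is the behaviour `in_axes=(None, None, None, None, 0, 0)` asks for: the four matrices are shared, and only the state and the input carry a batch axis.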
    

In [28]:
def vmap_compare(gu_c, c_k):
    # jax.debug.print(f"c_k shape: {c_k.shape}")
    # jax.debug.print(f"gu_c shape: {gu_c.shape}")
    for i in range(c_k.shape[0]):
        # jax.debug.print(f"c_k[i,:,:] shape: {c_k[i,:,:].shape}")
        # jax.debug.print(f"gu_c[i,:,:] shape: {gu_c[i,:,:].shape}")
        # jax.debug.print(f"c_k[{i},:,:]:\n{c_k[i,:,:]}\n")
        # jax.debug.print(f"gu_c[{i},:,:]:\n{gu_c[i,:,:]}\n")
        
        jax.debug.print(f"HiPPO LegS Test: {jnp.allclose(c_k[i,:,:], gu_c[i,:,:], rtol=1e-03, atol=1e-03)}")
        #print(f"c_k:\n{c_k}\n")

In [29]:
def test_hippo_legs_operator(hippo_legs, nb_hippo_legs, gu_hippo_legs, nb_gu_hippo_legs, random_input, legs_key, nb_legs_key):
    i = 0
    x_tensor = torch.tensor(random_input, dtype=torch.float32)
    x_jnp = jnp.asarray(x_tensor, dtype=jnp.float32)  # convert torch array to jax array
    nb_x_jnp = jnp.squeeze(x_jnp, axis=0)
    print(f"nb_x_jnp shape: {nb_x_jnp.shape}")
    
    #NOT BATCHED
    nb_params = nb_hippo_legs.init(nb_legs_key, f=nb_x_jnp, t_step=(nb_x_jnp.shape[0]))
    nb_c_k_list, nb_y_k_list, nb_GBT_A_list, nb_GBT_B_list = nb_hippo_legs.apply(
        nb_params, f=nb_x_jnp, t_step=(nb_x_jnp.shape[0])
    )
    nb_c_k = jnp.stack(nb_c_k_list, axis=0)
    print(f"nb_c_k shape: {nb_c_k.shape}")
    
    # BATCHED
    params = hippo_legs.init(legs_key, f=x_jnp, t_step=(x_jnp.shape[1]))
    c_k_list, y_k_list, GBT_A_list, GBT_B_list = hippo_legs.apply(
        params, f=x_jnp, t_step=(x_jnp.shape[1])  # sequence length is axis 1 for batched input, as in init
    )
    c_k = jnp.stack(c_k_list, axis=0)
    c_k = jnp.moveaxis(c_k, 0, 1)
    print(f"c_k shape: {c_k.shape}")
    
    # Gu's HiPPO LegS
    GU_c_k = gu_hippo_legs(x_tensor)
    gu_c = jnp.asarray(GU_c_k, dtype=jnp.float32)  # convert torch array to jax array
    gu_c = jnp.moveaxis(gu_c, -1, -2)
    print(f"gu_c shape: {gu_c.shape}")
    
    # NOT BATCHED Gu's HiPPO LegS
    nb_x_tensor = torch.squeeze(torch.tensor(random_input, dtype=torch.float32), dim=0)
    nb_GU_c_k = nb_gu_hippo_legs(nb_x_tensor)
    nb_gu_c = jnp.asarray(nb_GU_c_k, dtype=jnp.float32)  # convert torch array to jax array
    nb_gu_c = jnp.moveaxis(nb_gu_c, -1, -2)
    print(f"nb_gu_c shape: {nb_gu_c.shape}")
    
    
    # print(f"c_k shape before vmap: {c_k.shape}")
    # print(f"gu_c shape before vmap: {gu_c.shape}")
    
    # print(f"c_k before vmap:\n{c_k}")
    # print(f"nb_c_k before vmap:\n{nb_c_k}")
    # print(f"gu_c before vmap:\n{gu_c}")
    for i in range(c_k.shape[0]):
        for j in range(c_k.shape[1]):
            jax.debug.print(f"c_k @ b{i} t{j} - before vmap:\n{c_k[i,j,:,:]}\n")
            jax.debug.print(f"nb_c_k @ t{j} - before vmap:\n{nb_c_k[j,:,:]}\n")
            jax.debug.print(f"gu_c @ b{i} t{j} - before vmap:\n{gu_c[i,j,:,:]}\n")
            jax.debug.print(f"nb_gu_c @ t{j} - before vmap:\n{nb_gu_c[j,:,:]}\n")
            jax.debug.print(f"batched HiPPO vs batched Gu (batch {i}, step {j}): {jnp.allclose(c_k[i,j,:,:], gu_c[i,j,:,:], rtol=1e-03, atol=1e-03)}")
            jax.debug.print(f"unbatched HiPPO vs batched Gu (batch {i}, step {j}): {jnp.allclose(nb_c_k[j,:,:], gu_c[i,j,:,:], rtol=1e-03, atol=1e-03)}")
            jax.debug.print(f"unbatched HiPPO vs unbatched Gu (step {j}): {jnp.allclose(nb_c_k[j,:,:], nb_gu_c[j,:,:], rtol=1e-03, atol=1e-03)}")
            jax.debug.print(f"batched HiPPO vs unbatched Gu (batch {i}, step {j}): {jnp.allclose(c_k[i,j,:,:], nb_gu_c[j,:,:], rtol=1e-03, atol=1e-03)}")
            jax.debug.print(f"unbatched Gu vs batched Gu (batch {i}, step {j}): {jnp.allclose(nb_gu_c[j,:,:], gu_c[i,j,:,:], rtol=1e-03, atol=1e-03)}\n")
        
    
    #jax.vmap(vmap_compare, in_axes=(1, 1))(gu_c, c_k)
    # print(f"GU_c_k shape: {GU_c_k.shape}")
    # for i, c_k in enumerate(c_k_list):
    #     g_c_k = GU_c_k[i,:,:,:]
    #     # g_c_k = GU_c_k[i,:,:]
    #     gu = torch.unsqueeze(g_c_k, -1)
    #     gu_c = jnp.asarray(gu, dtype=jnp.float32)  # convert torch array to jax array
    #     print(f"HiPPO LegS Test: {jnp.allclose(c_k, gu_c, rtol=1e-04, atol=1e-06)}")
    #     #print(f"c_k:\n{c_k}\n")
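
The nested loop above checks one (batch, time) pair at a time. As a rough sketch, the same agreement test can be run over the whole coefficient array at once. The helper name `compare_coefficients` is hypothetical and not part of the notebook, and NumPy is used here in place of `jnp` so the snippet runs standalone; `np.isclose` applies the same criterion as `jnp.allclose`, namely `|a - b| <= atol + rtol * |b|`. The sample values are the `b0`/`t0` coefficients printed in this notebook's output.

```python
import numpy as np


def compare_coefficients(ours, reference, rtol=1e-3, atol=1e-3):
    """Compare two coefficient arrays of shape (batch, time, N, 1).

    Hypothetical helper (not in the notebook). Returns a boolean array of
    shape (batch, time) that is True where all N coefficients of that step
    agree within tolerance, and the largest absolute difference per step.
    """
    ours = np.asarray(ours, dtype=np.float32)
    reference = np.asarray(reference, dtype=np.float32)
    # Elementwise test: |ours - reference| <= atol + rtol * |reference|
    close = np.isclose(ours, reference, rtol=rtol, atol=atol)
    matches = close.all(axis=(-2, -1))
    max_abs_err = np.abs(ours - reference).max(axis=(-2, -1))
    return matches, max_abs_err


# Coefficients at batch 0, step 0, copied from this notebook's printed output.
c_k_t0 = np.array(
    [2.1333340e+01, 1.8475212e+01, 4.7702761e+00, 9.0003014e-06,
     -3.0100346e-06, 2.2947788e-06, -4.6342611e-06, 7.8976154e-07],
    dtype=np.float32,
).reshape(1, 1, 8, 1)
gu_c_t0 = np.array(
    [2.1333334e+01, 1.8475208e+01, 4.7702780e+00, -6.1747636e-07,
     1.0866888e-06, -9.6922243e-08, -7.9579695e-07, 4.4762382e-07],
    dtype=np.float32,
).reshape(1, 1, 8, 1)

# Tolerance used by the active checks in the loop.
loose_match, max_err = compare_coefficients(c_k_t0, gu_c_t0)
# Tolerance used by the commented-out check (rtol=1e-04, atol=1e-06).
tight_match, _ = compare_coefficients(c_k_t0, gu_c_t0, rtol=1e-4, atol=1e-6)

print(f"loose tolerance match: {loose_match}")
print(f"tight tolerance match: {tight_match}")
print(f"max abs error: {max_err}")
```

On these values the loose tolerance passes and the tight one fails. The higher-order coefficients are close to zero, so the relative term contributes almost nothing and the absolute difference of about 1e-5 exceeds `atol=1e-06`.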

In [30]:
def random_16_input(key_generator, batch_size=16, data_size=784, input_size=28):
    x = jax.random.randint(key_generator, (batch_size, data_size), 0, 255)
    return jax.vmap(moving_window, in_axes=(0, None))(x, input_size)
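
`moving_window` is imported from `src.data.process` and its body is not shown in this notebook. The sketch below is only an assumption about its behavior: a hypothetical stand-in, `moving_window_sketch`, that returns every overlapping window of length `size` from a 1-D array. It is written with NumPy and a plain loop over the batch rather than `jax.vmap` so it runs without the project code. With `input_size=1` it reproduces the `(1, 32, 1)` input shape reported by `test()`; for larger window sizes the real function may behave differently.

```python
import numpy as np


def moving_window_sketch(a, size):
    """Hypothetical stand-in for `moving_window` (an assumption, not the
    project's code): all overlapping length-`size` windows of a 1-D array."""
    a = np.asarray(a)
    starts = np.arange(a.shape[0] - size + 1)
    # Row i holds a[i : i + size]; result shape is (len(a) - size + 1, size).
    return np.stack([a[s:s + size] for s in starts])


rng = np.random.default_rng(0)
# Same layout as in test(): (batch_size, data_size) integers drawn from [0, 255).
x = rng.integers(0, 255, size=(1, 32))

# Apply the window function to each batch row, as vmap with in_axes=(0, None) would.
windows = np.stack([moving_window_sketch(row, 1) for row in x])
print(windows.shape)
```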


In [31]:
def test():
    # N = 256
    # L = 128
    
    batch_size = 1
    data_size = 32
    input_size = 1
    
    N = 8
    L = data_size
    
    x_jnp = random_16_input(
        key_generator=key3, 
        batch_size=batch_size, 
        data_size=data_size, 
        input_size=input_size
    )
    x_np = np.asarray(x_jnp)
    
    # N = 16
    # L = 8
    
    # x_np = np.array(
    #     [
    #         [0.3527],
    #         [0.6617],
    #         [0.2434],
    #         [0.6674],
    #         [1.2293],
    #         [0.0964],
    #         [-2.2756],
    #         [0.5618],
    #     ],
    #     dtype=np.float32,
    # )

    # x = torch.randn(L, 1)
    x = torch.tensor(x_np, dtype=torch.float32)

    print(f"X shape: {x.shape}")
    # print(f"X:\n{x}")

    # ----------------------------------------------------------------------------------
    loss = nn.MSELoss()

    # ----------------------------------------------------------------------------------
    # ------------------------------ Test HiPPO LegT model -----------------------------
    # ----------------------------------------------------------------------------------
    print("\nTesting HiPPO LegT model")
    hippo_legt = HiPPO_LegT(N, dt=1.0 / L)

    c_k = hippo_legt(x)

    # print(f"Gu's Coefficients for LegT:\n{c_k}")
    # print(f"Gu's Coefficient shapes for LegT:\n{c_k.shape}")

    # z = hippo_legt.reconstruct(c_k)
    # print(f"Gu's Reconstruction for LegT:\n{z}")
    # print(f"Gu's Reconstruction shape for LegT:\n{z.shape}")

    # mse = loss(z[-1, 0, :L], x.squeeze(-1))
    # print(f"h-MSE shape:\n{mse}")
    # print(f"end of test for HiPPO LegT model")

    # ----------------------------------------------------------------------------------
    # ------------------------------ Test HiPPO LegS model -----------------------------
    # ----------------------------------------------------------------------------------
    print("\nTesting HiPPO LegS model")
    gu_hippo_legs = HiPPO_LegS(N, max_length=L)  # Gu's reference implementation

    c_k = gu_hippo_legs(x, fast=True)

    print(f"Gu's Coefficients for LegS:\n{c_k}")
    print(f"Gu's Coefficient shapes for LegS:\n{c_k.shape}")

    # z = hippo_legs.reconstruct(c_k)

    # print(f"Gu's Reconstruction for LegS:\n{z}")
    # print(f"Gu's Reconstruction shape for LegS:\n{z.shape}")

    # print(y-z)
    print(f"end of test for HiPPO LegS model")

    # ----------------------------------------------------------------------------------
    # ------------------------------ Test Generic HiPPO model --------------------------
    # ----------------------------------------------------------------------------------
    the_measure = "legs"
    print(f"\nTesting Bryan's HiPPO-{the_measure} model")
    legs_matrices = TransMatrix(N=N, measure=the_measure)
    A = legs_matrices.A_matrix
    B = legs_matrices.B_matrix
    nb_hippo_LegS_B = nb_HiPPO(
        N=N,
        max_length=L,
        step=1.0 / L,
        GBT_alpha=0.5,
        seq_L=L,
        A=A,
        B=B,
        measure=the_measure,
    )  # Bryan's
    
    hippo_LegS_B = HiPPO(
        N=N,
        max_length=L,
        step=1.0 / L,
        GBT_alpha=0.5,
        seq_L=L,
        A=A,
        B=B,
        measure=the_measure,
    )  # Bryan's
    
    print(f"Bryan's Coefficients for HiPPO-{the_measure}")
    nb_gu_hippo_legs = HiPPO_LegS(N, max_length=L)  # Gu's reference implementation (unbatched input)
    test_hippo_legs_operator(hippo_legs=hippo_LegS_B, 
                             nb_hippo_legs=nb_hippo_LegS_B,
                             gu_hippo_legs=gu_hippo_legs, 
                             nb_gu_hippo_legs=nb_gu_hippo_legs,
                             random_input=x_np, 
                             legs_key=key2,
                             nb_legs_key=key4)
    
    # y_legs = hippo_LegS_B.apply(
    #     {"params": params}, c_k, method=hippo_LegS_B.reconstruct
    # )

    # print(f"Bryan's Reconstruction for HiPPO-{the_measure}:\n{y_legs}")
    # print(f"Bryan's Reconstruction shape for HiPPO-{the_measure}:\n{y_legs.shape}")

    print(f"end of test for HiPPO-{the_measure} model")

## Output

In [32]:
test()


X shape: torch.Size([1, 32, 1])

Testing HiPPO LegT model

Testing HiPPO LegS model
Gu's Coefficients for LegS:
tensor([[[[ 2.1333e+01,  1.8475e+01,  4.7703e+00, -6.1748e-07,  1.0867e-06, -9.6922e-08, -7.9580e-07,  4.4762e-07]],

         [[ 1.6200e+02,  1.4030e+02,  3.6224e+01, -4.6890e-06,  8.2520e-06, -7.3600e-07, -6.0431e-06,  3.3991e-06]],

         [[ 1.6067e+02,  1.3914e+02,  3.5926e+01, -4.6504e-06,  8.1841e-06, -7.2995e-07, -5.9933e-06,  3.3712e-06]],

         [[ 1.3000e+02,  1.1258e+02,  2.9069e+01, -3.7627e-06,  6.6220e-06, -5.9062e-07, -4.8494e-06,  2.7277e-06]],

         [[ 4.5333e+01,  3.9260e+01,  1.0137e+01, -1.3121e-06,  2.3092e-06, -2.0596e-07, -1.6911e-06,  9.5120e-07]],

         [[ 1.3333e+00,  1.1547e+00,  2.9814e-01, -3.8592e-08,  6.7918e-08, -6.0576e-09, -4.9737e-08,  2.7976e-08]],

         [[ 8.9333e+01,  7.7365e+01,  1.9976e+01, -2.5857e-06,  4.5505e-06, -4.0586e-07, -3.3324e-06,  1.8744e-06]],

         [[ 1.3733e+02,  1.1893e+02,  3.0709e+01, -3.9750e-0


nb_x_jnp shape: (32, 1)
nb_c_k shape: (32, 8, 1)
c_k shape: (1, 32, 8, 1)
gu_c shape: (1, 32, 8, 1)
nb_gu_c shape: (32, 8, 1)
c_k @ b0 t0 - before vmap:
[[ 2.1333340e+01]
 [ 1.8475212e+01]
 [ 4.7702761e+00]
 [ 9.0003014e-06]
 [-3.0100346e-06]
 [ 2.2947788e-06]
 [-4.6342611e-06]
 [ 7.8976154e-07]]

nb_c_k @ t0 - before vmap:
[[ 2.1333340e+01]
 [ 1.8475212e+01]
 [ 4.7702761e+00]
 [ 9.0003014e-06]
 [-3.0100346e-06]
 [ 2.2947788e-06]
 [-4.6342611e-06]
 [ 7.8976154e-07]]

gu_c @ b0 t0 - before vmap:
[[ 2.1333334e+01]
 [ 1.8475208e+01]
 [ 4.7702780e+00]
 [-6.1747636e-07]
 [ 1.0866888e-06]
 [-9.6922243e-08]
 [-7.9579695e-07]
 [ 4.4762382e-07]]

nb_gu_c @ t0 - before vmap:
[[ 2.1333334e+01]
 [ 1.8475208e+01]
 [ 4.7702780e+00]
 [-6.1747636e-07]
 [ 1.0866888e-06]
 [-9.6922243e-08]
 [-7.9579695e-07]
 [ 4.4762382e-07]]

batch 0 on trajectory 0 compare : True
no batch on trajectory 0 compare : True
no batch on trajectory 0 compare : True
no batch on trajectory 0 compare : True
no batch on trajector