# HiPPO Operator Minimal Test
---

## Load Packages

In [1]:
import os
import sys

module_path = os.path.abspath(os.path.join("../../../"))
print(f"module_path: {module_path}")
if module_path not in sys.path:
    print(f"Adding {module_path} to sys.path")
    sys.path.append(module_path)

module_path: /home/beegass/Documents/Coding/HiPPO-Jax


In [2]:
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False"
os.environ["TF_FORCE_UNIFIED_MEMORY"] = "1"

In [3]:
## import packages
import jax
import jax.numpy as jnp
import einops
import numpy as np
import torch
import time
from jax import random
import flax.linen as nn
from jaxtyping import Array, Float
from scipy import special as ss
from typing import Any, Callable, List, Optional, Tuple, Union

KeyArray = random.KeyArray

  jax.tree_util.register_keypaths(data_clz, keypaths)
  jax.tree_util.register_keypaths(data_clz, keypaths)


In [4]:
print(jax.devices())
print(f"The Device: {jax.lib.xla_bridge.get_backend().platform}")

[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]
The Device: gpu


In [5]:
print(f"MPS enabled: {torch.backends.mps.is_available()}")

MPS enabled: False


In [6]:
torch.set_printoptions(linewidth=150)
np.set_printoptions(linewidth=150)
jnp.set_printoptions(linewidth=150)

In [7]:
seed = 1701
key = jax.random.PRNGKey(seed)

In [8]:
num_copies = 10
subkeys = jax.random.split(key, num=num_copies)
key = subkeys[0]

## Data Generator

In [9]:
def whitesignal(key, period, dt, freq, rms=0.5, batch_shape=()):
    """
    Produces output signal of length period / dt, band-limited to frequency freq
    Output shape (*batch_shape, period/dt)
    Adapted from the nengo library
    """

    if freq is not None and freq < 1.0 / period:
        raise ValueError(
            f"Make ``{freq=} >= 1. / {period=}`` to produce a non-zero signal",
        )

    nyquist_cutoff = 0.5 / dt
    if freq > nyquist_cutoff:
        raise ValueError(
            f"{freq} must not exceed the Nyquist frequency for the given dt ({nyquist_cutoff:0.3f})"
        )

    n_coefficients = int(jnp.ceil(period / dt / 2.0))
    shape = batch_shape + (n_coefficients + 1,)
    sigma = rms * jnp.sqrt(0.5)
    coefficients = 1j * jax.random.normal(key, shape) * sigma
    coefficients = jnp.array(coefficients)
    coefficients = coefficients.at[..., -1].set(0.0)
    coefficients += jax.random.normal(key, shape) * sigma
    coefficients = jnp.array(coefficients)
    coefficients = coefficients.at[..., 0].set(0.0)

    set_to_zero = jnp.fft.rfftfreq(2 * n_coefficients, d=dt) > freq
    coefficients *= 1 - set_to_zero
    power_correction = jnp.sqrt(
        1.0 - jnp.sum(set_to_zero, dtype=jnp.float32) / n_coefficients
    )
    if power_correction > 0.0:
        coefficients /= power_correction
    coefficients *= jnp.sqrt(2 * n_coefficients)
    signal = jnp.fft.irfft(coefficients, axis=-1)
    signal = signal - signal[..., :1]  # Start from 0
    return signal

## HiPPO Intializers 

In [10]:
class HiPPOInitializer(jax.nn.initializers.Initializer):
    """Base class for HiPPO initializers."""

    def __call__(self, key: jnp.ndarray, shape: tuple, dtype: Any = jnp.float32):
        raise NotImplementedError("HiPPOInitializer must implement __call__")

In [11]:
class LegSInitializer(HiPPOInitializer):
    """Initializer for the Scaled Legendre basis."""

    def __name__(self) -> str:
        return "legs"

    def __call__(
        self, key: KeyArray, shape: tuple, dtype: Any = jnp.float32
    ) -> Tuple[Float[Array, "N N"], Float[Array, "N 1"]]:
        assert (
            shape[0] == shape[1] or shape[1] == 1
        ), "LegSInitializer: shape mismatch for square matrix. Got {shape}"

        N = shape[0]
        q = jnp.arange(N, dtype=dtype)
        k, n = jnp.meshgrid(q, q)
        pre_D = jnp.sqrt(jnp.diag(2 * q + 1))
        B = D = jnp.diag(pre_D)[:, None]

        A_base = jnp.sqrt(2 * n + 1) * jnp.sqrt(2 * k + 1)

        A = jnp.where(n > k, A_base, jnp.where(n == k, n + 1, 0.0))

        if shape[0] >= 1 and shape[0] == shape[1]:
            return -A.astype(dtype)
        elif shape[1] == 1:
            return B.astype(dtype)
        else:
            raise ValueError(
                f"LegSInitializer: shape mismatch for square matrix. Got {shape}"
            )

In [12]:
def legs_initializer() -> Callable:
    return LegSInitializer()

## HiPPO Shell

In [13]:
class HiPPOCell(nn.Module):
    @staticmethod
    def initialize_state(
        rng,
        batch_size: int,
        hidden_size: int,
        init_fn: Callable,
    ):
        raise NotImplementedError

In [14]:
class HiPPO(nn.Module):
    """

    Examples:
            The matrix_args in the format of:
                {N: int,
                 measure: str,
                 lambda_n: Optional[float],
                 alpha: Optional[float], # rotation for lagt
                 beta: Optional[float], # rotation for lagt
                 dtype: Optional[jnp.dtype]
                }

            >>> {N:64, measure:"legs", lambda_n:1.0, dtype:jnp.float16}
            >>> {N:64, measure:"legt", lambda_n:2.0, dtype:jnp.float32} # produces LMU
            >>> {N:64, measure:"legt", lambda_n:1.0, dtype:jnp.float32} # produces LegT
            >>> {N:64, measure:"lagt", alpha:0.0, beta:1.0, dtype:jnp.float6} # produces LagT
            >>> {N:64, measure:"lagt", alpha:0.7, beta:1.4, dtype:jnp.float64} # produces a version of a slightly "rotated" LagT


            HiPPOLSICell, in the format of:
                (max_length: int
                 alpha: Optional[float], # alpha value for discretization
                 measure: Optional[str],
                 recon: Optional[bool],
                 dtype: Optional[jnp.dtype]
                )
            >>> {max_length=1024, alpha=0.0, measure="legs", recon=True, dtype=jnp.float16} # produces HiPPOLSICell w/ forward euler discretization, and reconstruction
            >>> {max_length=512, alpha=1.0, measure="legt", recon=False, dtype=jnp.float32} # produces HiPPOLSICell w/ backward euler discretization, and no reconstruction
            >>> {max_length=256, alpha=0.5, measure="fru", recon=True, dtype=jnp.float32} # produces HiPPOLSICell w/ bilinear transform discretization, and reconstruction
            >>> {max_length=512, alpha=2.0, measure="fout", recon=True, dtype=jnp.float32} # produces HiPPOLSICell w/ zero-order hold discretization, and reconstruction

            HiPPOLTICell, in the format of:
                (step_size: float, # 1 / sequence length
                 basis_size: float, # The intended maximum value of the basis function for the coefficients to be projected onto
                 alpha: Optional[float], # alpha value for discretization
                 recon: Optional[bool],
                 measure: Optional[str],
                 dtype: Optional[jnp.dtype]
                )
            >>> {step_size:1e-3, basis_size:1.0, alpha:0.0, recon:True, measure:"legs", dtype:jnp.float16} # produces HiPPOLTICell w/ forward euler discretization, and reconstruction, discretized every 1/1000 assuming a sequence length is 1000
            >>> {step_size:1e-4, basis_size:1.0, alpha:1.0, recon:True, measure:"lagt", dtype:jnp.float32} # produces HiPPOLTICell w/ backward euler discretization, and reconstruction, discretized every 1/10000 assuming a sequence length is 10000
            >>> {step_size:1e-2, basis_size:1.0, alpha:2.0, recon:True, measure:"foud", dtype:jnp.float64} # produces HiPPOLTICell w/ zero-order hold discretization, and reconstruction, discretized every 1/10000 assuming a sequence length is 100
            >>> {step_size:1e-2, basis_size:1.0, alpha:0.5, recon:True, measure:"fru", dtype:jnp.float16} # produces HiPPOLTICell w/ bilinear transform discretization, and reconstruction, discretized every 1/10000 assuming a sequence length is 100



    Args:
        features (int):
            The size of the hidden state of the HiPPO model

        hippo_cell (HiPPOCell):
            The HiPPOCell class to be used for the HiPPO model

        hippo_args (dict):
            The dict associated with the input parameters into the HiPPOCell class.

        matrix_args (dict):
            The dict associated with the input parameters into the TransMatrix class.

        unroll (bool):
            Determines if you wanted the full history (all time steps) of coefficients, and potentially reconstructions. Defaults to False

    Raises:
        ValueError: Enforces that the inputted cell is a HiPPOCell
    """

    features: int
    hippo_cell: HiPPOCell
    hippo_args: dict
    init_t: int = 0
    unroll: bool = False

    def setup(self) -> None:

        self._hippo = self.hippo_cell(
            features=self.features, init_t=self.init_t, **self.hippo_args
        )

    def __call__(
        self,
        f: Float[Array, "#batch seq_len input_size"],
        c_t_1: Float[Array, "#batch input_size N"],
    ) -> Tuple[
        Union[
            Float[Array, "#batch seq_len input_size N"],
            Float[Array, "#batch input_size N"],
        ],
        Union[
            Float[Array, "#batch input_size N"],
            Float[Array, "#batch seq_len input_size N"],
        ],
    ]:

        if isinstance(self._hippo, HiPPOLTICell):

            def lti_scan_fn(carry, i):
                c_tm1, y_t_1 = carry
                c_t, y = jax.vmap(self._hippo, in_axes=(0, 0, None))(f, c_tm1, i)
                return (c_t, y), (c_t, y)

            (c_n, y_n), (c_s, y_s) = jax.lax.scan(
                f=lti_scan_fn,
                init=(c_t_1, jnp.ones(f.shape)),
                xs=(jnp.arange(f.shape[1] - self.init_t) + 1),
            )

            if self.unroll:
                return c_s, y_s

            else:
                return c_n, y_n

        # elif isinstance(self._hippo, HiPPOLSICell):

        #     def lsi_scan_fn(carry, i):
        #         c_tm1, y_t_1 = carry
        #         c_t, y = jax.vmap(self._hippo, in_axes=(0, 0, None))(f, c_tm1, i)
        #         return (c_t, y), (c_t, y)

        #     (c_n, y_n), (c_s, y_s) = jax.lax.scan(
        #         f=lsi_scan_fn,
        #         init=(c_t_1, jnp.ones(f.shape)),
        #         xs=(jnp.arange(f.shape[1] - self.init_t) + 1),
        #     )

        #     if self.unroll:
        #         return c_s, y_s

        #     else:
        #         return c_n, y_n

        else:
            raise ValueError("hippo must be of type HiPPOLSICell or HiPPOLTICell")

    @staticmethod
    def initialize_state(
        rng,
        batch_size: int,
        hidden_size: int,
        init_fn=nn.initializers.zeros,
    ):
        mem_shape = (batch_size, hidden_size)
        return init_fn(rng, mem_shape)

## HiPPOLTI Cell

In [15]:
class HiPPOLTICell(HiPPOCell):

    features: int
    step_size: float
    basis_size: float
    alpha: float = 0.5
    init_t: int = 0
    recon: bool = True
    measure: str = "legs"
    A_init_fn: Callable = legs_initializer()
    B_init_fn: Callable = legs_initializer()
    dtype: Any = jnp.float32

    def setup(self) -> None:
        A = self.param(
            "A",
            self.A_init_fn,
            (jax.random.PRNGKey(0), (self.features, self.features)),
        )
        B = self.param("B", self.B_init_fn, (jax.random.PRNGKey(0), (self.features, 1)))

        A_d_, B_d_ = jax.lax.stop_gradient(
            self.discretize(
                A=A,
                B=B,
                step=self.step_size,
                alpha=self.alpha,
                dtype=self.dtype,
            )
        )

        self.A_d = A_d_
        self.B_d = B_d_

        if self.measure in ["legs", "legt", "lmu", "lagt", "fout"] and self.recon:
            self.vals = jnp.arange(0.0, self.basis_size, self.step_size)
            self.eval_matrix = self.basis(
                method=self.measure,
                N=self.A_d.shape[0],
                vals=self.vals,
                c=0.0,
                dtype=self.dtype,
            )  # (T/dt, N)

    def __call__(
        self,
        f: Float[Array, "#batch seq_len input_size"],
        c_t_1: Float[Array, "#batch input_size N"],
        t_step: int,
    ) -> Tuple[
        Float[Array, "#batch input_size N"], Float[Array, "#batch input_size N"]
    ]:
        t = t_step - 1 + self.init_t
        c_t = (jnp.dot(c_t_1, (self.A_d).T) + ((self.B_d).T * f[t, :])).astype(
            self.dtype
        )

        if self.measure in ["legs", "legt", "lmu", "lagt", "fout"] and self.recon:
            y = self.reconstruct(c_t).astype(self.dtype)
            return (c_t, y), (c_t, y)
        else:
            return (c_t, c_t), (c_t, c_t)

    def discretize(
        self,
        A: Float[Array, "N N"],
        B: Float[Array, "N input_size"],
        step: float,
        alpha: Union[float, str] = 0.5,
        dtype: Any = jnp.float32,
    ) -> Tuple[Float[Array, "N N"], Float[Array, "N input_size"]]:
        """
        Function used for discretizing the HiPPO A and B matrices

        Args:
            A (jnp.ndarray):
                shape: (N, N)
                matrix to be discretized

            B (jnp.ndarray):
                shape: (N, 1)
                matrix to be discretized

            step (float):
                step size used for discretization

            alpha (float, optional):
                used for determining which generalized bilinear transformation to use
                - forward Euler corresponds to α = 0,
                - backward Euler corresponds to α = 1,
                - bilinear corresponds to α = 0.5,
                - Zero-order Hold corresponds to α > 1

            dtype (jnp.float):
                type of float precision to be used

        Returns:
            GBT_A (jnp.ndarray):
                shape: (N, N)
                discretized A matrix based on the given step size and alpha value

            GBT_B (jnp.ndarray):
                shape: (N, 1)
                discretized B matrix based on the given step size and alpha value
        """
        if alpha <= 1:
            assert alpha in [0, 0.5, 1], "alpha must be 0, 0.5, or 1"
        else:
            assert (
                alpha > 1 or type(alpha) == str
            ), "alpha must be greater than 1 for zero-order hold"
            if type(alpha) == str:
                assert (
                    alpha == "zoh"
                ), "if alpha is a string, it must be defined as 'zoh' for zero-order hold"

        I = jnp.eye(A.shape[0])

        if alpha <= 1:  # Generalized Bilinear Transformation
            step_size = step
            part1 = I - (step_size * alpha * A)
            part2 = I + (step_size * (1 - alpha) * A)

            GBT_A = jnp.linalg.lstsq(part1, part2, rcond=None)[0]
            GBT_B = jnp.linalg.lstsq(part1, (step_size * B), rcond=None)[0]

        else:  # Zero-order Hold
            # refer to this for why this works
            # https://en.wikipedia.org/wiki/Discretization#:~:text=A%20clever%20trick%20to%20compute%20Ad%20and%20Bd%20in%20one%20step%20is%20by%20utilizing%20the%20following%20property

            n = A.shape[0]
            b_n = B.shape[1]
            A_B_square = jnp.block(
                [[A, B], [jnp.zeros((b_n, n)), jnp.zeros((b_n, b_n))]]
            )
            A_B = jax.scipy.linalg.expm(A_B_square * self.step_size)

            GBT_A = A_B[0:n, 0:n]
            GBT_B = A_B[0:-b_n, -b_n:]

        return GBT_A.astype(dtype), GBT_B.astype(dtype)

    def measure_fn(self, method: str, c: float = 0.0) -> Callable:
        """
        Returns a function that is used to measure the distance between the input sequence and the estimated coefficients

        Args:
            method (str):
                The method used to measure the distance between the input sequence and the estimated coefficients

            c (float):
                The tilt of the function used to measure the distance between the input sequence and the estimated coefficients

        Returns:
            fn_tilted (Callable):
                The function used to measure the distance between the input sequence and the estimated coefficients

        """

        if method == "legs":
            fn = lambda x: jnp.heaviside(x, 1.0) * jnp.exp(-x)
        else:
            raise NotImplementedError

        fn_tilted = lambda x: jnp.exp(c * x) * fn(x)

        return fn_tilted

    def basis(
        self,
        method: str,
        N: int,
        vals: Float[Array, "1"],
        c: float = 0.0,
        truncate_measure: bool = True,
        dtype: Any = jnp.float32,
    ) -> Float[Array, "seq_len N"]:
        """
        Creates the basis matrix (eval matrix) for the appropriate HiPPO method.

        Args:
            B (jnp.ndarray):
                shape: (N, 1)
                The HiPPO B matrix

            method (str):
                The HiPPO method to use

            N (int):
                The number of basis functions to use

            vals (jnp.ndarray):
                shape: (seq_len, )
                The values to evaluate the basis functions at

            c (float):
                The constant to use for the tilted measure

            truncate_measure (bool):
                Whether or not to truncate the measure to the interval [0, 1]

            dtype (Any):
                The dtype to use for the basis matrix

        Returns:
            eval_matrix (jnp.ndarray):
                shape: (seq_len, N)
                The basis matrix
        """

        if method == "legs":
            _vals = jnp.exp(-vals)
            base = (2 * jnp.arange(N) + 1) ** 0.5 * (-1) ** jnp.arange(
                N
            )  # unscaled, untranslated legendre polynomial matrix
            base = einops.rearrange(base, "N -> N 1")
            eval_matrix = (
                jax.lax.stop_gradient(
                    ss.eval_legendre(jnp.expand_dims(jnp.arange(N), -1), 1 - 2 * _vals)
                )
                * base
            ).T  # (L, N)
        else:
            raise NotImplementedError(f"method {method} not implemented")

        if truncate_measure:
            tilting_fn = self.measure_fn(method, c=c)
            val = tilting_fn(vals)
            eval_matrix = eval_matrix.at[val == 0.0].set(0.0)

        p = eval_matrix * jnp.exp(-c * vals)[:, None]  # [::-1, None]

        return p.astype(dtype)

    def reconstruct(
        self, c: Float[Array, "#batch input_size N"], evals=None
    ) -> Float[Array, "#batch seq_len input_size"]:
        """reconstructs the input sequence from the estimated coefficients and the evaluation matrix

        Args:
            c (jnp.ndarray):
                shape: (batch size, input length, N)
                Vector of the estimated coefficients, given the history of the function/sequence

            evals (jnp.ndarray, optional):
                shape: ()
                Vector of the evaluation points. Defaults to None.

        Returns:
            y (jnp.ndarray):
                shape: (batch size, input length, input size)
                The reconstructed input sequence
        """
        if evals is not None:
            eval_matrix = self.basis(method=self.measure, N=self.N, vals=evals)
        else:
            eval_matrix = self.eval_matrix

        y = None
        if len(c.shape) == 3:
            c = einops.rearrange(c, "batch input_size N -> batch N input_size")
            y = jax.vmap(jnp.dot, in_axes=(None, 0))(eval_matrix, c)
            y = einops.rearrange(y, "batch seq_len 1 -> batch seq_len")
            y = jax.vmap(jnp.flip, in_axes=(0, None))(y, 0)
        elif len(c.shape) == 4:
            c = einops.rearrange(
                c, "batch seq_len input_size N -> batch seq_len N input_size"
            )
            time_dot = jax.vmap(jnp.dot, in_axes=(None, 0))
            batch_time_dot = jax.vmap(time_dot, in_axes=(None, 0))
            y = batch_time_dot(eval_matrix, c)
            y = einops.rearrange(
                y, "batch seq_len 1 seq_len2 -> batch seq_len seq_len2"
            )
            y = jax.vmap(jax.vmap(jnp.flip, in_axes=(0, None)), in_axes=(0, None))(y, 0)
        else:
            raise ValueError(
                "c must be of shape (batch size, input length, N) or (batch seq_len input_size N)"
            )

        return y

    @staticmethod
    def initialize_state(
        rng,
        batch_size: int,
        hidden_size: int,
        init_fn=nn.initializers.zeros,
    ):
        mem_shape = (batch_size, hidden_size)
        return init_fn(rng, mem_shape)

## Test HiPPO Operators

In [16]:
def test_hippo_operator(key, hippo, random_input, hidden_size, batch_size):
    x_jnp = jnp.asarray(random_input, dtype=jnp.float32)
    x_jnp = einops.rearrange(x_jnp, "batch seq_len -> batch seq_len 1")

    c_t_1 = hippo.initialize_state(
        subkeys[7], batch_size=batch_size, hidden_size=hidden_size
    )
    params = hippo.init(key, f=x_jnp, c_t_1=c_t_1)

    start = time.time()
    c, y = hippo.apply(params, f=x_jnp)
    end = time.time()

    duration = end - start
    print(f"Duration: {duration}")

In [17]:
def test_operators(the_measure="legs", alpha=0.5):
    T = 1
    freq = 1
    step = 1e-3
    L = int(T / step)

    batch_size = 1
    data_size = L
    input_size = 1

    N = 64

    u = whitesignal(subkeys[4], T, step, freq, batch_shape=(batch_size,))
    x_np = np.asarray(u)

    # ----------------------------------------------------------------------------------
    # ------------------------------ Instantiate My HiPPOs -----------------------------
    # ----------------------------------------------------------------------------------
    print(f"Creating HiPPO-{the_measure} LTI model with {alpha} transform")
    hippo_lti_cell = HiPPOLTICell
    h_args = {
        "step_size": step,
        "basis_size": T,
        "alpha": alpha,
        "recon": True,
        "A_init_fn": legs_initializer(),
        "B_init_fn": legs_initializer(),
        "measure": the_measure,
    }
    hippo_lti = HiPPO(
        features=N,
        hippo_cell=hippo_lti_cell,
        hippo_args=h_args,
        init_t=0,
        unroll=False,
    )

    print(f"Testing Coeffiecients for {alpha} LTI HiPPO-{the_measure}")

    test_hippo_operator(
        key=subkeys[5],
        hippo=hippo_lti,
        random_input=x_np,
        hidden_size=N,
        batch_size=batch_size,
    )

    print(f"end of test for HiPPO-{the_measure} model")

#### LegS

In [18]:
test_operators(the_measure="legs", alpha=0.0)

Creating HiPPO-legs LTI model with 0.0 transform
Testing Coeffiecients for 0.0 LTI HiPPO-legs


AttributeError: "HiPPOLTICell" object has no attribute "A_d"