# Diagonal Plus Low Rank (DPLR) HiPPO Matrices
---

## Table of Contents
* [Loading In Necessary Packages](#load-packages)
    * [Instantiate The HiPPO Matrix](#instantiate-the-hippo-matrix)
        * [Translated Legendre (LegT)](#translated-legendre-legt)
            * [LegT](#legt)
            * [LMU](#lmu)
        * [Translated Laguerre (LagT)](#translated-laguerre-lagt)
        * [Scaled Legendre (LegS)](#scaled-legendre-legs)
        * [Fourier Basis](#fourier-basis)
            * [Fourier Recurrent Unit (FRU)](#fourier-recurrent-unit-fru)
            * [Truncated Fourier (FouT)](#truncated-fourier-fout)
            * [Fourier With Decay (FourD)](#fourier-with-decay-fourd)
    * [Make HiPPO Matrices DPLR](#make-hippo-matrices-dplr)
        * [DPLR-LegT](#dplr-legt)
            * [DPLR-LegT](#dplr-legt)
            * [DPLR-LMU](#dplr-lmu)
        * [DPLR-LagT](#dplr-lagt)
        * [DPLR-LegS](#dplr-legs)
        * [DPLR Applied To Fourier Basis](#dplr-applied-to-fourier-basis)
            * [DPLR-FRU](#nplr-fru)
            * [DPLR-FouT](#nplr-fout)
            * [DPLR-FouD](#nplr-foud)
    * [Utilities For Gu HiPPO Operator](#utilities-for-gu-hippo-operator)
    * [Gu's HiPPO LegT Operator](#gus-hippo-legt-operator)
    * [Gu's Scale invariant HiPPO LegS Operator](#gus-scale-invariant-hippo-legs-operator)
    * [Implementation Of General HiPPO Operator](#implementation-of-general-hippo-operator)
    * [Output](#output)
---


## Load Packages

In [1]:
import os
import sys

module_path = os.path.abspath(os.path.join("../../../"))
if module_path not in sys.path:
    sys.path.append(module_path)

In [2]:
## import packages
import jax
import jax.numpy as jnp
from opt_einsum import contract
import numpy as np
import torch

from src.models.hippo.transition import TransMatrix
from src.models.hippo.gu_transition import GuLowRankMatrix

print(jax.devices())
print(f"The Device: {jax.lib.xla_bridge.get_backend().platform}")

[StreamExecutorGpuDevice(id=0, process_index=0)]
The Device: gpu


In [3]:
print(f"MPS enabled: {torch.backends.mps.is_available()}")

MPS enabled: False


In [5]:
torch.set_printoptions(linewidth=150)
np.set_printoptions(linewidth=150)
jnp.set_printoptions(linewidth=150)

In [6]:
N = 8

In [7]:
seed = 1701
key = jax.random.PRNGKey(seed)

In [8]:
num_copies = 5
rng, key2, key3, key4, key5 = jax.random.split(key, num=num_copies)

## Make HiPPO Matrices DPLR

In [24]:
class LowRankMatrix:
    def __init__(
        self,
        N,
        rank,
        measure="legs",
        lambda_n=1,
        fourier_type="fru",
        alpha=0,
        beta=1,
        DPLR=True,
        dtype=jnp.float32,
    ):
        self.N = N
        self.measure = measure
        self.rank = rank
        _trans_matrix = TransMatrix(N, measure, lambda_n, fourier_type, alpha, beta)

        Lambda = None
        B = None
        V = None
        A, B, P, S = self.make_NPLR(trans_matrix=_trans_matrix, dtype=dtype)
        if DPLR:
            Lambda, P, B, V = self.make_DPLR(B=B, P=P, S=S)
            self.Lambda = (Lambda.copy()).astype(dtype)  # real eigenvalues
            self.V = (V.copy()).astype(dtype)  # imaginary (complex) eigenvalues

        self.A = (A.copy()).astype(dtype)  # HiPPO A Matrix (N x N)
        self.B = (B.copy()).astype(dtype)  # HiPPO B Matrix (N x 1)
        self.P = (P.copy()).astype(dtype)  # HiPPO rank correction matrix (N x rank)
        self.S = (S.copy()).astype(
            dtype
        )  # HiPPO normal (skew-symmetric) matrix (N x N)

    def check_skew(self, S):
        """Check if a matrix is skew symmetric

        We require AP to be nearly skew-symmetric. To be clear, AP IS NOT skew-symmetric.
        However, it is skew-symmetric up to a small error. This function checks that error is within an acceptable tolerance.

        refer to:
        - https://www.cuemath.com/algebra/skew-symmetric-matrix/
        - https://en.wikipedia.org/wiki/Skew-symmetric_matrix

        """
        _S = S + S.transpose(
            -1, -2
        )  # ensure matrices are skew symmetric by assuming S is skew symmetric, adding two skew symmetric matrices results in a skew symmetric matrix
        if (
            err := jnp.sum((_S - _S[0, 0] * jnp.eye(self.N)) ** 2) / self.N
        ) > 1e-5:  # if not torch.allclose(_A - _A[0,0]*torch.eye(N), torch.zeros(N, N), atol=1e-5):
            print("WARNING: HiPPO matrix not skew symmetric", err)
            # print(
            #     f"Transposed matrix:\n{_S.transpose(-1, -2)}\n\nUnchanged matrix:\n{-_S}"
            # )  # the transpose of a skew symmetric matrix is equal to the negative of the matrix

        return _S

    def fix_zeroed_eigvals(self, Lambda, V, S):
        # Only keep half of each conjugate pair
        imaginary_eigvals = Lambda.imag
        idx = jnp.argsort(imaginary_eigvals)
        Lambda_sorted = Lambda[idx]
        V_sorted = V[:, idx]

        # There is an edge case when eigenvalues can be 0, which requires some machinery to handle
        # We use a huge hack here: Assume only one pair is 0, and that it is the first row/column of A (only happens in Fourier case)
        V = V_sorted[:, : self.N // 2]
        Lambda = Lambda_sorted[: self.N // 2]
        assert (
            jnp.abs(Lambda[-2]) > 1e-4
        ), "Only 1 zero eigenvalue allowed in diagonal part of A"
        if jnp.abs(Lambda[-1]) < 1e-4:
            # x = x.at[idx].set(y)
            V = V.at[:, -1].set(0.0)  # V[:, -1] = 0.0
            V = V.at[0, -1].set(2**-0.5)  # V[0, -1] = 2**-0.5
            V = V.at[1, -1].set(2**-0.5 * 1j)  # V[1, -1] = 2**-0.5 * 1j

        _AP = V @ jnp.diag(Lambda) @ V.conj().transpose(-1, -2)

        if (err := jnp.sum((2 * _AP.real - S) ** 2) / self.N) > 1e-5:
            print(
                "Warning: Diagonalization of A matrix not numerically precise - error",
                err,
            )

        return Lambda, V

    def make_NPLR(self, trans_matrix, dtype=jnp.float32):
        A = trans_matrix.A_matrix
        B = trans_matrix.B_matrix

        P = self.rank_correction(dtype=dtype)  # (r N)

        S = A + jnp.sum(
            jnp.expand_dims(P, -2) * jnp.expand_dims(P, -1), axis=-3
        )  # rank correct if rank > 1, summation happens in outer most dimension
        # S is nearly skew-symmetric

        return A, B, P, S

    def make_DPLR(self, B, P, S):
        """Diagonalize NPLR representation"""

        _S = self.check_skew(S=S)

        # Check skew symmetry
        S_diag = jnp.diagonal(S)
        Lambda_real = jnp.mean(S_diag, -1, keepdims=True) * jnp.ones_like(
            S_diag
        )  # S itself is not skew-symmetric. It is skew-symmetric by: S + c * I. Extract the value c, c = mean(S_diag)

        # Diagonalize S to V \Lambda V^*
        Lambda_imaginary, V = jnp.linalg.eigh(S * -1j)
        Lambda = Lambda_real + 1j * Lambda_imaginary

        Lambda, V = self.fix_zeroed_eigvals(Lambda=Lambda, V=V, S=_S)
        B = B[:, 0]
        V_inv = V.conj().transpose(-1, -2)
        print(f"Lambda:\n{Lambda}")
        print(f"V_inv:\n{V_inv}")
        print(f"P:\n{P}")
        print(f"B:\n{B}")
        print(f"V_inv  shape:\n{V_inv.shape}")
        print(f"P shape:\n{P.shape}")
        print(f"B shape:\n{B.shape}")

        B = contract("ij, j -> i", V_inv, B)
        # B = contract("ij, j -> i", V.conj().transpose(-1, -2), B.to(V))  # V^* B

        P = contract("ij, ...j -> ...i", V_inv, P)
        # P = contract("ij, ...j -> ...i", V.conj().transpose(-1, -2), P.to(V))  # V^* P

        print(f"B after einsum:\n{B}")
        print(f"P after einsum:\n{P}")

        return Lambda, P, B, V

    def rank_correction(self, dtype=jnp.float32):
        """Return low-rank matrix L such that A + L is normal"""

        if self.measure == "legs":
            assert self.rank >= 1
            P = jnp.expand_dims(
                jnp.sqrt(0.5 + jnp.arange(self.N, dtype=dtype)), 0
            )  # (1 N)

        elif self.measure == "legt":
            assert self.rank >= 2
            P = jnp.sqrt(1 + 2 * jnp.arange(self.N, dtype=dtype))  # (N)
            P0 = P.clone()
            P0 = P0.at[0::2].set(0.0)  # P0[0::2] = 0.0
            P1 = P.clone()
            P1 = P1.at[1::2].set(0.0)  # P1[1::2] = 0.0
            P = jnp.stack([P0, P1], axis=0)  # (2 N)
            P = P * (
                2 ** (-0.5)
            )  # Halve the rank correct just like the original matrix was halved

        elif self.measure == "lagt":
            assert self.rank >= 1
            P = 0.5**0.5 * jnp.ones((1, self.N), dtype=dtype)

        elif self.measure in ["fourier", "fout"]:
            P = jnp.zeros(self.N)
            P = P.at[0::2].set(2**0.5)  # P[0::2] = 2**0.5
            P = P.at[0].set(1)  # P[0] = 1
            P = jnp.expand_dims(P, 0)

        elif self.measure == "fourier_decay":
            P = jnp.zeros(self.N)
            P = P.at[0::2].set(2**0.5)  # P[0::2] = 2**0.5
            P = P.at[0].set(1)  # P[0] = 1
            P = jnp.expand_dims(P, 0)
            P = P / 2**0.5

        elif self.measure == "fourier2":
            P = jnp.zeros(self.N)
            P = P.at[0::2].set(2**0.5)  # P[0::2] = 2**0.5
            P = P.at[0].set(1)  # P[0] = 1
            P = 2**0.5 * jnp.expand_dims(P, 0)

        elif self.measure in ["fourier_diag", "foud", "legsd"]:
            P = jnp.zeros((1, self.N), dtype=dtype)

        else:
            raise NotImplementedError

        d = jnp.size(P, axis=0)
        if self.rank > d:
            P = jnp.concatenate(
                [P, jnp.zeros((self.rank - d, self.N), dtype=dtype)], axis=0
            )  # (rank N)

        return P

- [NPLR](#nplr)
- [DPLR](#dplr)

---
# NPLR
---

## NPLR-LegT

### NPLR-LegT

[Navigate Back To Low Rank Matrix Class](#make-hippo-matrices-dplr)

In [25]:
def test_NPLR_LegT():
    the_measure = "legt"
    rank = 2
    nplr_legt = LowRankMatrix(
        N=8, rank=rank, measure=the_measure, lambda_n=1.0, DPLR=False
    )
    gu_nplr_legt = GuLowRankMatrix(
        N=8, rank=rank, measure=the_measure, lambda_n=1.0, DPLR=False
    )
    A, B, P, S = nplr_legt.A, nplr_legt.B, nplr_legt.P, nplr_legt.S
    gu_A, gu_B, gu_P, gu_S = (
        jnp.asarray(gu_nplr_legt.A, dtype=jnp.float32),
        jnp.asarray(gu_nplr_legt.B, dtype=jnp.float32),
        jnp.asarray(gu_nplr_legt.P, dtype=jnp.float32),
        jnp.asarray(gu_nplr_legt.S, dtype=jnp.float32),
    )
    print("NPLR LEGT")
    print(f"\nA:\n{A}\n")
    print(f"gu_A:\n{gu_A}\n")
    assert jnp.allclose(A, gu_A, rtol=1e-04, atol=1e-06)

    print(f"B:\n{B}\n")
    print(f"gu_B:\n{gu_B}\n")
    assert jnp.allclose(B, gu_B, rtol=1e-04, atol=1e-06)

    print(f"P:\n{P}\n")
    print(f"gu_P:\n{gu_P}\n")
    assert jnp.allclose(P, gu_P, rtol=1e-04, atol=1e-06)

    print(f"S:\n{S}\n")
    print(f"gu_S:\n{gu_S}\n")
    assert jnp.allclose(S, gu_S, rtol=1e-04, atol=1e-06)

In [26]:
test_NPLR_LegT()

  A = torch.from_numpy(np_A)  # (N, N)


NPLR LEGT

A:
[[ -1.          1.7320508  -2.2360678   2.6457512  -3.          3.3166246  -3.6055512   3.8729832]
 [ -1.7320508  -3.          3.872983   -4.5825753   5.196152   -5.744562    6.244998   -6.708204 ]
 [ -2.2360678  -3.872983   -4.999999    5.916079   -6.7082033   7.4161973  -8.062257    8.660253 ]
 [ -2.6457512  -4.5825753  -5.916079   -6.9999995   7.937254   -8.774963    9.5393915 -10.24695  ]
 [ -3.         -5.196152   -6.7082033  -7.937254   -9.          9.949874  -10.816654   11.61895  ]
 [ -3.3166246  -5.744562   -7.4161973  -8.774963   -9.949874  -10.999999   11.958261  -12.845232 ]
 [ -3.6055512  -6.244998   -8.062257   -9.5393915 -10.816654  -11.958261  -13.         13.964239 ]
 [ -3.8729832  -6.708204   -8.660253  -10.24695   -11.61895   -12.845232  -13.964239  -14.999999 ]]

gu_A:
[[ -1.          1.7320508  -2.2360678   2.6457512  -3.          3.3166246  -3.6055512   3.8729832]
 [ -1.7320508  -3.          3.872983   -4.5825753   5.196152   -5.744562    6.244998   

### NPLR-LMU

[Navigate Back To Low Rank Matrix Class](#make-hippo-matrices-dplr)

In [27]:
def test_NPLR_LMU():
    the_measure = "legt"
    rank = 2
    nplr_lmu = LowRankMatrix(
        N=8, rank=rank, measure=the_measure, lambda_n=2.0, DPLR=False
    )
    gu_nplr_lmu = GuLowRankMatrix(
        N=8, rank=rank, measure=the_measure, lambda_n=2.0, DPLR=False
    )  # change lambda so resulting matrix is in the form of LMU
    print("NPLR LMU")
    A, B, P, S = nplr_lmu.A, nplr_lmu.B, nplr_lmu.P, nplr_lmu.S
    gu_A, gu_B, gu_P, gu_S = (
        jnp.asarray(gu_nplr_lmu.A, dtype=jnp.float32),
        jnp.asarray(gu_nplr_lmu.B, dtype=jnp.float32),
        jnp.asarray(gu_nplr_lmu.P, dtype=jnp.float32),
        jnp.asarray(gu_nplr_lmu.S, dtype=jnp.float32),
    )
    print(f"\nA:\n{A}\n")
    print(f"gu_A:\n{gu_A}\n")
    assert jnp.allclose(A, gu_A, rtol=1e-04, atol=1e-06)

    print(f"B:\n{B}\n")
    print(f"gu_B:\n{gu_B}\n")
    assert jnp.allclose(B, gu_B, rtol=1e-04, atol=1e-06)

    print(f"P:\n{P}\n")
    print(f"gu_P:\n{gu_P}\n")
    assert jnp.allclose(P, gu_P, rtol=1e-04, atol=1e-06)

    print(f"S:\n{S}\n")
    print(f"gu_S:\n{gu_S}\n")
    assert jnp.allclose(S, gu_S, rtol=1e-04, atol=1e-06)

In [28]:
test_NPLR_LMU()

NPLR LMU

A:
[[ -1.  -1.  -1.  -1.  -1.  -1.  -1.  -1.]
 [  3.  -3.  -3.  -3.  -3.  -3.  -3.  -3.]
 [ -5.   5.  -5.  -5.  -5.  -5.  -5.  -5.]
 [  7.  -7.   7.  -7.  -7.  -7.  -7.  -7.]
 [ -9.   9.  -9.   9.  -9.  -9.  -9.  -9.]
 [ 11. -11.  11. -11.  11. -11. -11. -11.]
 [-13.  13. -13.  13. -13.  13. -13. -13.]
 [ 15. -15.  15. -15.  15. -15.  15. -15.]]

gu_A:
[[ -1.  -1.  -1.  -1.  -1.  -1.  -1.  -1.]
 [  3.  -3.  -3.  -3.  -3.  -3.  -3.  -3.]
 [ -5.   5.  -5.  -5.  -5.  -5.  -5.  -5.]
 [  7.  -7.   7.  -7.  -7.  -7.  -7.  -7.]
 [ -9.   9.  -9.   9.  -9.  -9.  -9.  -9.]
 [ 11. -11.  11. -11.  11. -11. -11. -11.]
 [-13.  13. -13.  13. -13.  13. -13. -13.]
 [ 15. -15.  15. -15.  15. -15.  15. -15.]]

B:
[[  1.]
 [ -3.]
 [  5.]
 [ -7.]
 [  9.]
 [-11.]
 [ 13.]
 [-15.]]

gu_B:
[[  1.]
 [ -3.]
 [  5.]
 [ -7.]
 [  9.]
 [-11.]
 [ 13.]
 [-15.]]

P:
[[0.         1.2247448  0.         1.8708286  0.         2.3452077  0.         2.7386127 ]
 [0.70710677 0.         1.5811386  0.         2.121320

## NPLR-LagT

[Navigate Back To Low Rank Matrix Class](#make-hippo-matrices-dplr)

In [29]:
def test_NPLR_LagT():
    the_measure = "lagt"
    rank = 1
    nplr_lagt = LowRankMatrix(
        N=8,
        rank=rank,
        measure=the_measure,
        alpha=0.0,  # change resulting tilt through alpha and beta
        beta=1.0,
        DPLR=False,
    )  # change resulting tilt through alpha and beta
    gu_nplr_lagt = GuLowRankMatrix(
        N=8,
        rank=rank,
        measure=the_measure,
        alpha=0.0,
        beta=1.0,
        DPLR=False,
    )
    print("NPLR LAGT")
    A, B, P, S = nplr_lagt.A, nplr_lagt.B, nplr_lagt.P, nplr_lagt.S
    gu_A, gu_B, gu_P, gu_S = (
        jnp.asarray(gu_nplr_lagt.A, dtype=jnp.float32),
        jnp.asarray(gu_nplr_lagt.B, dtype=jnp.float32),
        jnp.asarray(gu_nplr_lagt.P, dtype=jnp.float32),
        jnp.asarray(gu_nplr_lagt.S, dtype=jnp.float32),
    )
    print(f"\nA:\n{A}\n")
    print(f"gu_A:\n{gu_A}\n")
    assert jnp.allclose(A, gu_A, rtol=1e-04, atol=1e-06)

    print(f"B:\n{B}\n")
    print(f"gu_B:\n{gu_B}\n")
    assert jnp.allclose(B, gu_B, rtol=1e-04, atol=1e-06)

    print(f"P:\n{P}\n")
    print(f"gu_P:\n{gu_P}\n")
    assert jnp.allclose(P, gu_P, rtol=1e-04, atol=1e-06)

    print(f"S:\n{S}\n")
    print(f"gu_S:\n{gu_S}\n")
    assert jnp.allclose(S, gu_S, rtol=1e-04, atol=1e-06)

In [30]:
test_NPLR_LagT()

NPLR LAGT

A:
[[-1.         -0.         -0.         -0.         -0.         -0.         -0.         -0.        ]
 [-1.         -1.         -0.         -0.         -0.         -0.         -0.         -0.        ]
 [-1.         -1.         -1.         -0.         -0.         -0.         -0.         -0.        ]
 [-1.         -1.         -1.         -1.         -0.         -0.         -0.         -0.        ]
 [-1.         -1.         -1.         -1.         -1.         -0.         -0.         -0.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.         -0.         -0.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.         -1.         -0.        ]
 [-0.99999976 -0.99999976 -0.99999976 -0.99999976 -0.99999976 -0.99999976 -0.99999976 -1.        ]]

gu_A:
[[-1.         -0.         -0.         -0.         -0.         -0.         -0.         -0.        ]
 [-1.         -1.         -0.         -0.         -0.         -0.         -0.         -

## NPLR-LegS

[Navigate Back To Low Rank Matrix Class](#make-hippo-matrices-dplr)

In [31]:
def test_NPLR_LegS():
    the_measure = "legs"
    rank = 1
    nplr_legs = LowRankMatrix(N=8, rank=rank, measure=the_measure, DPLR=False)
    gu_nplr_legs = GuLowRankMatrix(N=8, rank=rank, measure=the_measure, DPLR=False)
    print("NPLR LEGS")
    A, B, P, S = nplr_legs.A, nplr_legs.B, nplr_legs.P, nplr_legs.S
    gu_A, gu_B, gu_P, gu_S = (
        jnp.asarray(gu_nplr_legs.A, dtype=jnp.float32),
        jnp.asarray(gu_nplr_legs.B, dtype=jnp.float32),
        jnp.asarray(gu_nplr_legs.P, dtype=jnp.float32),
        jnp.asarray(gu_nplr_legs.S, dtype=jnp.float32),
    )
    print(f"\nA:\n{A}\n")
    print(f"gu_A:\n{gu_A}\n")
    assert jnp.allclose(A, gu_A, rtol=1e-04, atol=1e-06)

    print(f"B:\n{B}\n")
    print(f"gu_B:\n{gu_B}\n")
    assert jnp.allclose(B, gu_B, rtol=1e-04, atol=1e-06)

    print(f"P:\n{P}\n")
    print(f"gu_P:\n{gu_P}\n")
    assert jnp.allclose(P, gu_P, rtol=1e-04, atol=1e-06)

    print(f"S:\n{S}\n")
    print(f"gu_S:\n{gu_S}\n")
    assert jnp.allclose(S, gu_S, rtol=1e-04, atol=1e-06)

In [32]:
test_NPLR_LegS()

NPLR LEGS

A:
[[ -1.         -0.         -0.         -0.         -0.         -0.         -0.         -0.       ]
 [ -1.7320508  -2.         -0.         -0.         -0.         -0.         -0.         -0.       ]
 [ -2.2360678  -3.872983   -3.         -0.         -0.         -0.         -0.         -0.       ]
 [ -2.6457512  -4.5825753  -5.916079   -4.         -0.         -0.         -0.         -0.       ]
 [ -3.         -5.196152   -6.7082033  -7.937254   -5.         -0.         -0.         -0.       ]
 [ -3.3166246  -5.744562   -7.4161973  -8.774963   -9.949874   -6.         -0.         -0.       ]
 [ -3.6055512  -6.244998   -8.062257   -9.5393915 -10.816654  -11.958261   -7.         -0.       ]
 [ -3.8729832  -6.708204   -8.660253  -10.24695   -11.61895   -12.845232  -13.964239   -8.       ]]

gu_A:
[[ -1.          0.          0.          0.          0.          0.          0.          0.       ]
 [ -1.7320508  -1.9999999   0.          0.          0.          0.          0.         

## NPLR Applied To Fourier Basis

### NPLR-FRU

[Navigate Back To Low Rank Matrix Class](#make-hippo-matrices-dplr)

In [33]:
def test_NPLR_FRU():
    the_measure = "fourier"
    fourier_type = "fru"
    rank = 1
    nplr_fru = LowRankMatrix(
        N=8, rank=rank, measure=the_measure, fourier_type=fourier_type, DPLR=False
    )
    gu_nplr_fru = GuLowRankMatrix(
        N=8, rank=rank, measure=the_measure, fourier_type=fourier_type, DPLR=False
    )
    print("NPLR FRU")
    A, B, P, S = nplr_fru.A, nplr_fru.B, nplr_fru.P, nplr_fru.S
    gu_A, gu_B, gu_P, gu_S = (
        jnp.asarray(gu_nplr_fru.A, dtype=jnp.float32),
        jnp.asarray(gu_nplr_fru.B, dtype=jnp.float32),
        jnp.asarray(gu_nplr_fru.P, dtype=jnp.float32),
        jnp.asarray(gu_nplr_fru.S, dtype=jnp.float32),
    )
    print(f"\nA:\n{A}\n")
    print(f"gu_A:\n{gu_A}\n")
    assert jnp.allclose(A, gu_A, rtol=1e-04, atol=1e-06)

    print(f"B:\n{B}\n")
    print(f"gu_B:\n{gu_B}\n")
    assert jnp.allclose(B, gu_B, rtol=1e-04, atol=1e-06)

    print(f"P:\n{P}\n")
    print(f"gu_P:\n{gu_P}\n")
    assert jnp.allclose(P, gu_P, rtol=1e-04, atol=1e-06)

    print(f"S:\n{S}\n")
    print(f"gu_S:\n{gu_S}\n")
    assert jnp.allclose(S, gu_S, rtol=1e-04, atol=1e-06)

In [34]:
test_NPLR_FRU()

NPLR FRU

A:
[[-1.        -0.        -1.4142135  0.        -1.4142135  0.        -1.4142135  0.       ]
 [ 0.         0.         0.         0.         0.         0.         0.         0.       ]
 [-1.4142135  0.        -2.        -3.1415927 -2.         0.        -2.         0.       ]
 [ 0.         0.         3.1415927  0.         0.         0.         0.         0.       ]
 [-1.4142135  0.        -2.         0.        -2.        -6.2831855 -2.         0.       ]
 [ 0.         0.         0.         0.         6.2831855  0.         0.         0.       ]
 [-1.4142135  0.        -2.         0.        -2.         0.        -2.        -9.424778 ]
 [ 0.         0.         0.         0.         0.         0.         9.424778   0.       ]]

gu_A:
[[-1.         0.        -1.4142135  0.        -1.4142135  0.        -1.4142135  0.       ]
 [ 0.         0.         0.         0.         0.         0.         0.         0.       ]
 [-1.4142135  0.        -1.9999999 -3.1415927 -1.9999999  0.        -

### NPLR-FouT

[Navigate Back To Low Rank Matrix Class](#make-hippo-matrices-dplr)

In [35]:
def test_NPLR_FouT():
    the_measure = "fourier"
    fourier_type = "fout"
    rank = 1
    nplr_fout = LowRankMatrix(
        N=8, rank=rank, measure=the_measure, fourier_type=fourier_type, DPLR=False
    )
    gu_nplr_fout = GuLowRankMatrix(
        N=8, rank=rank, measure=the_measure, fourier_type=fourier_type, DPLR=False
    )
    print("NPLR FOUT")
    A, B, P, S = nplr_fout.A, nplr_fout.B, nplr_fout.P, nplr_fout.S
    gu_A, gu_B, gu_P, gu_S = (
        jnp.asarray(gu_nplr_fout.A, dtype=jnp.float32),
        jnp.asarray(gu_nplr_fout.B, dtype=jnp.float32),
        jnp.asarray(gu_nplr_fout.P, dtype=jnp.float32),
        jnp.asarray(gu_nplr_fout.S, dtype=jnp.float32),
    )
    print(f"\nA:\n{A}\n")
    print(f"gu_A:\n{gu_A}\n")
    assert jnp.allclose(A, gu_A, rtol=1e-04, atol=1e-06)

    print(f"B:\n{B}\n")
    print(f"gu_B:\n{gu_B}\n")
    assert jnp.allclose(B, gu_B, rtol=1e-04, atol=1e-06)

    print(f"P:\n{P}\n")
    print(f"gu_P:\n{gu_P}\n")
    assert jnp.allclose(P, gu_P, rtol=1e-04, atol=1e-06)

    print(f"S:\n{S}\n")
    print(f"gu_S:\n{gu_S}\n")
    assert jnp.allclose(S, gu_S, rtol=1e-04, atol=1e-06)

In [36]:
test_NPLR_FouT()

NPLR FOUT

A:
[[ -2.         -0.         -2.828427    0.         -2.828427    0.         -2.828427    0.       ]
 [  0.          0.          0.          0.          0.          0.          0.          0.       ]
 [ -2.828427    0.         -4.         -6.2831855  -4.          0.         -4.          0.       ]
 [  0.          0.          6.2831855   0.          0.          0.          0.          0.       ]
 [ -2.828427    0.         -4.          0.         -4.        -12.566371   -4.          0.       ]
 [  0.          0.          0.          0.         12.566371    0.          0.          0.       ]
 [ -2.828427    0.         -4.          0.         -4.          0.         -4.        -18.849556 ]
 [  0.          0.          0.          0.          0.          0.         18.849556    0.       ]]

gu_A:
[[ -2.          0.         -2.828427    0.         -2.828427    0.         -2.828427    0.       ]
 [  0.          0.          0.          0.          0.          0.          0.         

### NPLR-FouD

[Navigate Back To Low Rank Matrix Class](#make-hippo-matrices-dplr)

In [37]:
def test_NPLR_FouD():
    the_measure = "fourier"
    fourier_type = "foud"
    rank = 1
    nplr_foud = LowRankMatrix(
        N=8, rank=rank, measure=the_measure, fourier_type=fourier_type, DPLR=False
    )
    gu_nplr_foud = GuLowRankMatrix(
        N=8, rank=rank, measure=the_measure, fourier_type=fourier_type, DPLR=False
    )
    print("NPLR FOUD")
    A, B, P, S = nplr_foud.A, nplr_foud.B, nplr_foud.P, nplr_foud.S
    gu_A, gu_B, gu_P, gu_S = (
        jnp.asarray(gu_nplr_foud.A, dtype=jnp.float32),
        jnp.asarray(gu_nplr_foud.B, dtype=jnp.float32),
        jnp.asarray(gu_nplr_foud.P, dtype=jnp.float32),
        jnp.asarray(gu_nplr_foud.S, dtype=jnp.float32),
    )
    print(f"\nA:\n{A}\n")
    print(f"gu_A:\n{gu_A}\n")
    assert jnp.allclose(A, gu_A, rtol=1e-04, atol=1e-06)

    print(f"B:\n{B}\n")
    print(f"gu_B:\n{gu_B}\n")
    assert jnp.allclose(B, gu_B, rtol=1e-04, atol=1e-06)

    print(f"P:\n{P}\n")
    print(f"gu_P:\n{gu_P}\n")
    assert jnp.allclose(P, gu_P, rtol=1e-04, atol=1e-06)

    print(f"S:\n{S}\n")
    print(f"gu_S:\n{gu_S}\n")
    assert jnp.allclose(S, gu_S, rtol=1e-04, atol=1e-06)

In [38]:
test_NPLR_FouD()

NPLR FOUD

A:
[[-0.5        -0.         -0.70710677  0.         -0.70710677  0.         -0.70710677  0.        ]
 [ 0.          0.          0.          0.          0.          0.          0.          0.        ]
 [-0.70710677  0.         -1.         -3.1415927  -1.          0.         -1.          0.        ]
 [ 0.          0.          3.1415927   0.          0.          0.          0.          0.        ]
 [-0.70710677  0.         -1.          0.         -1.         -6.2831855  -1.          0.        ]
 [ 0.          0.          0.          0.          6.2831855   0.          0.          0.        ]
 [-0.70710677  0.         -1.          0.         -1.          0.         -1.         -9.424778  ]
 [ 0.          0.          0.          0.          0.          0.          9.424778    0.        ]]

gu_A:
[[-0.5         0.         -0.70710677  0.         -0.70710677  0.         -0.70710677  0.        ]
 [ 0.          0.          0.          0.          0.          0.          0.          

---
# DPLR
---

## DPLR-LegT

### DPLR-LegT

[Navigate Back To Low Rank Matrix Class](#make-hippo-matrices-dplr)

In [39]:
def test_DPLR_LegT():
    the_measure = "legt"
    rank = 2
    DPLR_bool = True
    gu_dplr_legt = GuLowRankMatrix(
        N=8, rank=rank, measure=the_measure, lambda_n=1.0, DPLR=DPLR_bool
    )
    dplr_legt = LowRankMatrix(
        N=8, rank=rank, measure=the_measure, lambda_n=1.0, DPLR=DPLR_bool
    )
    print("DPLR LEGT")
    Lambda, P, B, V = dplr_legt.Lambda, dplr_legt.P, dplr_legt.B, dplr_legt.V
    gu_Lambda, gu_P, gu_B, gu_V = (
        jnp.asarray(gu_dplr_legt.Lambda, dtype=jnp.float32),
        jnp.asarray(gu_dplr_legt.P, dtype=jnp.float32),
        jnp.asarray(gu_dplr_legt.B, dtype=jnp.float32),
        jnp.asarray(gu_dplr_legt.V, dtype=jnp.float32),
    )
    print(f"\nLambda:\n{Lambda}\n")
    print(f"gu_Lambda:\n{gu_Lambda}\n")
    assert jnp.allclose(Lambda, gu_Lambda, rtol=1e-04, atol=1e-06)

    print(f"P:\n{P}\n")
    print(f"gu_P:\n{gu_P}\n")
    assert jnp.allclose(P, gu_P, rtol=1e-04, atol=1e-06)

    print(f"B:\n{B}\n")
    print(f"gu_B:\n{gu_B}\n")
    assert jnp.allclose(B, gu_B, rtol=1e-04, atol=1e-06)

    print(f"V:\n{V}\n")
    print(f"gu_V:\n{gu_V}\n")
    assert jnp.allclose(V, gu_V, rtol=1e-04, atol=1e-06)

In [40]:
# r = 4
# g1 = np.random.randint(10, size=(4, r))
# print(g1)
# g2 = np.random.randint(10, size=(r, 2))
# print(g2)
# g3 = np.random.randint(10, size=(r))
# print(g3)

In [41]:
# g_12 = jnp.einsum("ij, j -> i", g1, g2)
# print(g_12)

In [42]:
# g_13 = jnp.einsum("ij, j -> i", g1, g3)
# print(g_13)

In [43]:
test_DPLR_LegT()

imaginary eigvals: tensor([-30.7341, -12.7962,  -8.2184,  -2.6705,   2.6705,   8.2184,  12.7962,  30.7341])
idx of imaginary eigvals: torch.return_types.sort(
values=tensor([-30.7341, -12.7962,  -8.2184,  -2.6705,   2.6705,   8.2184,  12.7962,  30.7341]),
indices=tensor([0, 1, 2, 3, 4, 5, 6, 7]))
Lambda:
tensor([-4.0000-30.7341j, -4.0000-12.7962j, -4.0000-8.2184j, -4.0000-2.6705j])
V_inv:
tensor([[ 1.1935e-01+0.0000e+00j,  2.2051e-01-1.1108e-02j,  2.6306e-01+8.1315e-02j,  2.9111e-01+1.3390e-01j,  2.2782e-01+3.2152e-01j,
          7.1471e-02+3.6970e-01j, -1.6505e-01+4.4722e-01j, -4.4603e-01+1.8783e-01j],
        [-9.4520e-02-0.0000e+00j,  1.2050e-04+1.7041e-01j, -3.0141e-01-6.7072e-02j,  3.4156e-02-9.0232e-03j, -2.9464e-01-2.3023e-01j,
          5.0705e-01-2.3987e-01j, -6.8414e-02+5.1170e-01j, -2.2463e-01-3.0658e-01j],
        [-2.2495e-01-0.0000e+00j, -9.6499e-02+2.3030e-05j, -4.2658e-01-2.5132e-01j,  4.1259e-01-4.6532e-01j,  2.0975e-01+4.0052e-01j,
         -2.0796e-01+5.6561e-03j, -2

  return _convert_element_type(operand, new_dtype, weak_type=False)
  self.B = (B.copy()).astype(dtype)  # HiPPO B Matrix (N x 1)
  self.P = (P.copy()).astype(dtype)  # HiPPO rank correction matrix (N x rank)
  return self.numpy().astype(dtype, copy=False)


AssertionError: 

### DPLR-LMU

[Navigate Back To Low Rank Matrix Class](#make-hippo-matrices-dplr)

In [None]:
def test_DPLR_LMU():
    the_measure = "legt"
    rank = 2
    DPLR_bool = True
    dplr_lmu = LowRankMatrix(
        N=8, rank=rank, measure=the_measure, lambda_n=2.0, DPLR=DPLR_bool
    )
    gu_dplr_lmu = GuLowRankMatrix(
        N=8, rank=rank, measure=the_measure, lambda_n=2.0, DPLR=DPLR_bool
    )  # change lambda so resulting matrix is in the form of LMU
    print("DPLR LMU")
    Lambda, P, B, V = dplr_lmu.Lambda, dplr_lmu.P, dplr_lmu.B, dplr_lmu.V
    gu_Lambda, gu_P, gu_B, gu_V = (
        jnp.asarray(gu_dplr_lmu.Lambda, dtype=jnp.float32),
        jnp.asarray(gu_dplr_lmu.P, dtype=jnp.float32),
        jnp.asarray(gu_dplr_lmu.B, dtype=jnp.float32),
        jnp.asarray(gu_dplr_lmu.V, dtype=jnp.float32),
    )
    print(f"\nLambda:\n{Lambda}\n")
    print(f"gu_Lambda:\n{gu_Lambda}\n")
    assert jnp.allclose(Lambda, gu_Lambda, rtol=1e-04, atol=1e-06)

    print(f"P:\n{P}\n")
    print(f"gu_P:\n{gu_P}\n")
    assert jnp.allclose(P, gu_P, rtol=1e-04, atol=1e-06)

    print(f"B:\n{B}\n")
    print(f"gu_B:\n{gu_B}\n")
    assert jnp.allclose(B, gu_B, rtol=1e-04, atol=1e-06)

    print(f"V:\n{V}\n")
    print(f"gu_V:\n{gu_V}\n")
    assert jnp.allclose(V, gu_V, rtol=1e-04, atol=1e-06)

In [None]:
test_DPLR_LMU()

## DPLR-LagT

[Navigate Back To Low Rank Matrix Class](#make-hippo-matrices-dplr)

In [None]:
def test_DPLR_LagT():
    the_measure = "lagt"
    rank = 1
    DPLR_bool = True
    dplr_lagt = LowRankMatrix(
        N=8,
        rank=rank,
        measure=the_measure,
        alpha=0.0,  # change resulting tilt through alpha and beta
        beta=1.0,
        DPLR=DPLR_bool,
    )  # change resulting tilt through alpha and beta
    gu_dplr_lagt = GuLowRankMatrix(
        N=8,
        rank=rank,
        measure=the_measure,
        alpha=0.0,
        beta=1.0,
        DPLR=DPLR_bool,
    )
    print("DPLR LAGT")
    Lambda, P, B, V = dplr_lagt.Lambda, dplr_lagt.P, dplr_lagt.B, dplr_lagt.V
    gu_Lambda, gu_P, gu_B, gu_V = (
        jnp.asarray(gu_dplr_lagt.Lambda, dtype=jnp.float32),
        jnp.asarray(gu_dplr_lagt.P, dtype=jnp.float32),
        jnp.asarray(gu_dplr_lagt.B, dtype=jnp.float32),
        jnp.asarray(gu_dplr_lagt.V, dtype=jnp.float32),
    )
    print(f"\nLambda:\n{Lambda}\n")
    print(f"gu_Lambda:\n{gu_Lambda}\n")
    assert jnp.allclose(Lambda, gu_Lambda, rtol=1e-04, atol=1e-06)

    print(f"P:\n{P}\n")
    print(f"gu_P:\n{gu_P}\n")
    assert jnp.allclose(P, gu_P, rtol=1e-04, atol=1e-06)

    print(f"B:\n{B}\n")
    print(f"gu_B:\n{gu_B}\n")
    assert jnp.allclose(B, gu_B, rtol=1e-04, atol=1e-06)

    print(f"V:\n{V}\n")
    print(f"gu_V:\n{gu_V}\n")
    assert jnp.allclose(V, gu_V, rtol=1e-04, atol=1e-06)

In [None]:
test_DPLR_LagT()

## DPLR-LegS

[Navigate Back To Low Rank Matrix Class](#make-hippo-matrices-dplr)

In [None]:
def test_DPLR_LegS():
    the_measure = "legs"
    rank = 1
    DPLR_bool = True
    dplr_legs = LowRankMatrix(N=8, rank=rank, measure=the_measure, DPLR=DPLR_bool)
    gu_dplr_legs = GuLowRankMatrix(N=8, rank=rank, measure=the_measure, DPLR=DPLR_bool)
    print("DPLR LEGS")
    Lambda, P, B, V = dplr_legs.Lambda, dplr_legs.P, dplr_legs.B, dplr_legs.V
    gu_Lambda, gu_P, gu_B, gu_V = (
        jnp.asarray(gu_dplr_legs.Lambda, dtype=jnp.float32),
        jnp.asarray(gu_dplr_legs.P, dtype=jnp.float32),
        jnp.asarray(gu_dplr_legs.B, dtype=jnp.float32),
        jnp.asarray(gu_dplr_legs.V, dtype=jnp.float32),
    )
    print(f"\nLambda:\n{Lambda}\n")
    print(f"gu_Lambda:\n{gu_Lambda}\n")
    assert jnp.allclose(Lambda, gu_Lambda, rtol=1e-04, atol=1e-06)

    print(f"P:\n{P}\n")
    print(f"gu_P:\n{gu_P}\n")
    assert jnp.allclose(P, gu_P, rtol=1e-04, atol=1e-06)

    print(f"B:\n{B}\n")
    print(f"gu_B:\n{gu_B}\n")
    assert jnp.allclose(B, gu_B, rtol=1e-04, atol=1e-06)

    print(f"V:\n{V}\n")
    print(f"gu_V:\n{gu_V}\n")
    assert jnp.allclose(V, gu_V, rtol=1e-04, atol=1e-06)

In [None]:
test_DPLR_LegS()

## DPLR Applied To Fourier Basis

### DPLR-FRU

[Navigate Back To Low Rank Matrix Class](#make-hippo-matrices-dplr)

In [None]:
def test_DPLR_FRU():
    the_measure = "fourier"
    fourier_type = "fru"
    rank = 1
    DPLR_bool = True
    dplr_fru = LowRankMatrix(
        N=8, rank=rank, measure=the_measure, fourier_type=fourier_type, DPLR=DPLR_bool
    )
    gu_dplr_fru = GuLowRankMatrix(
        N=8, rank=rank, measure=the_measure, fourier_type=fourier_type, DPLR=DPLR_bool
    )
    print("DPLR FRU")
    Lambda, P, B, V = dplr_fru.Lambda, dplr_fru.P, dplr_fru.B, dplr_fru.V
    gu_Lambda, gu_P, gu_B, gu_V = (
        jnp.asarray(gu_dplr_fru.Lambda, dtype=jnp.float32),
        jnp.asarray(gu_dplr_fru.P, dtype=jnp.float32),
        jnp.asarray(gu_dplr_fru.B, dtype=jnp.float32),
        jnp.asarray(gu_dplr_fru.V, dtype=jnp.float32),
    )
    print(f"\nLambda:\n{Lambda}\n")
    print(f"gu_Lambda:\n{gu_Lambda}\n")
    assert jnp.allclose(Lambda, gu_Lambda, rtol=1e-04, atol=1e-06)

    print(f"P:\n{P}\n")
    print(f"gu_P:\n{gu_P}\n")
    assert jnp.allclose(P, gu_P, rtol=1e-04, atol=1e-06)

    print(f"B:\n{B}\n")
    print(f"gu_B:\n{gu_B}\n")
    assert jnp.allclose(B, gu_B, rtol=1e-04, atol=1e-06)

    print(f"V:\n{V}\n")
    print(f"gu_V:\n{gu_V}\n")
    assert jnp.allclose(V, gu_V, rtol=1e-04, atol=1e-06)

In [None]:
test_DPLR_FRU()

### DPLR-FouT

[Navigate Back To Low Rank Matrix Class](#make-hippo-matrices-dplr)

In [None]:
def test_DPLR_FouT():
    the_measure = "fourier"
    fourier_type = "fout"
    rank = 1
    DPLR_bool = True
    dplr_fout = LowRankMatrix(
        N=8, rank=rank, measure=the_measure, fourier_type=fourier_type, DPLR=DPLR_bool
    )
    gu_dplr_fout = GuLowRankMatrix(
        N=8, rank=rank, measure=the_measure, fourier_type=fourier_type, DPLR=DPLR_bool
    )
    print("DPLR FOUT")
    Lambda, P, B, V = dplr_fout.Lambda, dplr_fout.P, dplr_fout.B, dplr_fout.V
    gu_Lambda, gu_P, gu_B, gu_V = (
        jnp.asarray(gu_dplr_fout.Lambda, dtype=jnp.float32),
        jnp.asarray(gu_dplr_fout.P, dtype=jnp.float32),
        jnp.asarray(gu_dplr_fout.B, dtype=jnp.float32),
        jnp.asarray(gu_dplr_fout.V, dtype=jnp.float32),
    )
    print(f"\nLambda:\n{Lambda}\n")
    print(f"gu_Lambda:\n{gu_Lambda}\n")
    assert jnp.allclose(Lambda, gu_Lambda, rtol=1e-04, atol=1e-06)

    print(f"P:\n{P}\n")
    print(f"gu_P:\n{gu_P}\n")
    assert jnp.allclose(P, gu_P, rtol=1e-04, atol=1e-06)

    print(f"B:\n{B}\n")
    print(f"gu_B:\n{gu_B}\n")
    assert jnp.allclose(B, gu_B, rtol=1e-04, atol=1e-06)

    print(f"V:\n{V}\n")
    print(f"gu_V:\n{gu_V}\n")
    assert jnp.allclose(V, gu_V, rtol=1e-04, atol=1e-06)

In [None]:
test_DPLR_FouT()

### DPLR-FouD

[Navigate Back To Low Rank Matrix Class](#make-hippo-matrices-dplr)

In [None]:
def test_DPLR_FouD():
    the_measure = "fourier"
    fourier_type = "foud"
    rank = 1
    DPLR_bool = True
    dplr_foud = LowRankMatrix(
        N=8, rank=rank, measure=the_measure, fourier_type=fourier_type, DPLR=DPLR_bool
    )
    gu_dplr_foud = GuLowRankMatrix(
        N=8, rank=rank, measure=the_measure, fourier_type=fourier_type, DPLR=DPLR_bool
    )
    print("DPLR FOUD")
    Lambda, P, B, V = dplr_foud.Lambda, dplr_foud.P, dplr_foud.B, dplr_foud.V
    gu_Lambda, gu_P, gu_B, gu_V = (
        jnp.asarray(gu_dplr_foud.Lambda, dtype=jnp.float32),
        jnp.asarray(gu_dplr_foud.P, dtype=jnp.float32),
        jnp.asarray(gu_dplr_foud.B, dtype=jnp.float32),
        jnp.asarray(gu_dplr_foud.V, dtype=jnp.float32),
    )
    print(f"\nLambda:\n{Lambda}\n")
    print(f"gu_Lambda:\n{gu_Lambda}\n")
    assert jnp.allclose(Lambda, gu_Lambda, rtol=1e-04, atol=1e-06)

    print(f"P:\n{P}\n")
    print(f"gu_P:\n{gu_P}\n")
    assert jnp.allclose(P, gu_P, rtol=1e-04, atol=1e-06)

    print(f"B:\n{B}\n")
    print(f"gu_B:\n{gu_B}\n")
    assert jnp.allclose(B, gu_B, rtol=1e-04, atol=1e-06)

    print(f"V:\n{V}\n")
    print(f"gu_V:\n{gu_V}\n")
    assert jnp.allclose(V, gu_V, rtol=1e-04, atol=1e-06)

In [None]:
test_DPLR_FouD()