# Normal Plus Low Rank (NPLR) HiPPO Matrices
---

## Table of Contents
* [Loading In Necessary Packages](#load-packages)
    * [Instantiate The HiPPO Matrix](#instantiate-the-hippo-matrix)
        * [Translated Legendre (LegT)](#translated-legendre-legt)
            * [LegT](#legt)
            * [LMU](#lmu)
        * [Translated Laguerre (LagT)](#translated-laguerre-lagt)
        * [Scaled Legendre (LegS)](#scaled-legendre-legs)
        * [Fourier Basis](#fourier-basis)
            * [Fourier Recurrent Unit (FRU)](#fourier-recurrent-unit-fru)
            * [Truncated Fourier (FouT)](#truncated-fourier-fout)
            * [Fourier With Decay (FouD)](#fourier-with-decay-foud)
    * [Make HiPPO Matrices NPLR](#make-hippo-matrices-nplr)
        * [NPLR-LegT](#nplr-legt)
            * [NPLR-LegT](#nplr-legt)
            * [NPLR-LMU](#nplr-lmu)
        * [NPLR-LagT](#nplr-lagt)
        * [NPLR-LegS](#nplr-legs)
        * [NPLR Applied To Fourier Basis](#nplr-applied-to-fourier-basis)
            * [NPLR-FRU](#nplr-fru)
            * [NPLR-FouT](#nplr-fout)
            * [NPLR-FouD](#nplr-foud)
    * [Utilities For Gu HiPPO Operator](#utilities-for-gu-hippo-operator)
    * [Gu's HiPPO LegT Operator](#gus-hippo-legt-operator)
    * [Gu's Scale invariant HiPPO LegS Operator](#gus-scale-invariant-hippo-legs-operator)
    * [Implementation Of General HiPPO Operator](#implementation-of-general-hippo-operator)
    * [Output](#output)
---


## Load Packages

In [1]:
import os
import sys
# Put the repository root (three directories up) on sys.path so `src.` imports resolve.
module_path = os.path.abspath(os.path.join("..", "..", ".."))
if module_path not in sys.path:
    sys.path.append(module_path)

In [2]:
## import packages
import jax
import jax.numpy as jnp

from flax import linen as jnn

from jax.nn.initializers import lecun_normal, uniform
from jax.numpy.linalg import eig, inv, matrix_power
from jax.scipy.signal import convolve

import requests

from scipy import linalg as la
from scipy import signal
from scipy import special as ss

import math

print(jax.devices())
# `jax.lib.xla_bridge.get_backend()` is deprecated in newer JAX releases;
# `jax.default_backend()` returns the same platform name ("cpu", "gpu", "tpu").
print(f"The Device: {jax.default_backend()}")

[StreamExecutorGpuDevice(id=0, process_index=0)]
The Device: gpu


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import numpy as np
import matplotlib.pyplot as plt
print("MPS enabled:", torch.backends.mps.is_available())


MPS enabled: False


In [4]:
from src.models.hippo.gu_transition import GuTransMatrix, GuLowRankMatrix

In [5]:
# HiPPO order (state size). NOTE: the test cells below pass N=8 explicitly rather than reading this variable.
N = 8


In [6]:
# Fixed seed so the JAX PRNG key (and anything split from it) is reproducible.
seed = 1701
key = jax.random.PRNGKey(seed)


In [7]:
# Split the root key into `num_copies` subkeys, unpacked as rng and key2..key5.
num_copies = 5
rng, key2, key3, key4, key5 = jax.random.split(key, num=num_copies)


## Instantiate The HiPPO Matrix

In [8]:
class TransMatrix:
    """Build the HiPPO state-space matrices (A, B) for a chosen measure.

    After construction the instance exposes ``A_matrix`` (the HiPPO matrix
    multiplied by -1, shape (N, N)) and ``B_matrix`` (shape (N, 1)), both cast
    to float32.
    """

    def __init__(
        self, N, measure="legs", lambda_n=1, fourier_type="fru", alpha=0, beta=1
    ):
        """
        Instantiates the HiPPO matrix of a given order using a particular measure.

        Args:
            N (int): Order of coefficients to describe the orthogonal polynomial that is the HiPPO projection.
            measure (str): one of
                - "legt": HiPPO w/ Translated Legendre (LegT); ``lambda_n`` selects LegT vs. LMU
                - "lagt": HiPPO w/ Translated Laguerre (LagT)
                - "legs": HiPPO w/ Scaled Legendre (LegS)
                - "fourier": HiPPO w/ Fourier basis; ``fourier_type`` selects the variant
                - "random": Gaussian A (scaled by 1/N) and Gaussian B from NumPy's global RNG
                - "diagonal": A = -diag(exp(z)) with z ~ N(0, 1), Gaussian B
            lambda_n (int): Tilt of the LegT basis: 1 -> LegT, 2 -> LMU.
            fourier_type (str): Fourier variant, used only when ``measure == "fourier"``:
                - "fru": Fourier Recurrent Unit
                - "fout": Truncated Fourier
                - "foud": Fourier with decay
            alpha (float): The order of the Laguerre basis.
            beta (float): The scale of the Laguerre basis.

        Raises:
            ValueError: if ``measure`` is unknown, or ``lambda_n`` (for "legt") /
                ``fourier_type`` (for "fourier") is not a supported value.
        """
        if measure == "legt":
            A, B = self.build_LegT(N=N, lambda_n=lambda_n)

        elif measure == "lagt":
            A, B = self.build_LagT(alpha=alpha, beta=beta, N=N)

        elif measure == "legs":
            A, B = self.build_LegS(N=N)

        elif measure == "fourier":
            A, B = self.build_Fourier(N=N, fourier_type=fourier_type)

        elif measure == "random":
            # jax.numpy has no `random` submodule; draw from NumPy's global RNG
            # (not reproducible unless np.random.seed has been called).
            A = jnp.asarray(np.random.randn(N, N) / N)
            B = jnp.asarray(np.random.randn(N, 1))

        elif measure == "diagonal":
            A = jnp.asarray(-np.diag(np.exp(np.random.randn(N))))
            B = jnp.asarray(np.random.randn(N, 1))

        else:
            raise ValueError("Invalid HiPPO type")

        self.A_matrix = (A.copy()).astype(np.float32)
        self.B_matrix = (B.copy()).astype(np.float32)

    # Translated Legendre (LegT) - vectorized
    @staticmethod
    def build_LegT(N, lambda_n=1):
        """
        Vectorized construction of the measure derived from the translated Legendre basis.

        Args:
            N (int): Order of coefficients to describe the orthogonal polynomial that is the HiPPO projection.
            lambda_n (int): Choice between the two tilts of the basis:
                - 1: translated Legendre (LegT)
                - 2: Legendre Memory Unit (LMU)

        Returns:
            A (jnp.ndarray): The A HiPPO matrix, shape (N, N).
            B (jnp.ndarray): The B HiPPO matrix, shape (N, 1).

        Raises:
            ValueError: if ``lambda_n`` is neither 1 nor 2.
        """
        q = jnp.arange(N, dtype=jnp.float32)
        k, n = jnp.meshgrid(q, q)  # n indexes rows, k indexes columns
        case = jnp.power(-1.0, (n - k))  # (-1)^(n - k)

        if lambda_n == 1:
            A_base = -jnp.sqrt(2 * n + 1) * jnp.sqrt(2 * k + 1)
            B = jnp.sqrt(2 * q + 1)[:, None]
            # on/below the diagonal use A_base, above it apply the alternating sign
            A = jnp.where(k <= n, A_base, A_base * case)

        elif lambda_n == 2:
            A_base = -(2 * n + 1)
            B = ((2 * q + 1) * jnp.power(-1.0, q))[:, None]
            # on/below the diagonal apply the alternating sign, above it use A_base
            A = jnp.where(k <= n, A_base * case, A_base)

        else:
            raise ValueError(
                f"lambda_n must be 1 (LegT) or 2 (LMU), got {lambda_n!r}"
            )

        return A, B

    # Translated Laguerre (LagT)
    @staticmethod
    def build_LagT(alpha, beta, N):
        """
        Construction of the measure derived from the translated Laguerre basis.

        Args:
            alpha (float): The order of the Laguerre basis.
            beta (float): The scale of the Laguerre basis.
            N (int): Order of coefficients to describe the orthogonal polynomial that is the HiPPO projection.

        Returns:
            A (jnp.ndarray): The A HiPPO matrix, shape (N, N).
            B (jnp.ndarray): The B HiPPO matrix, shape (N, 1).
        """
        # L_n = sqrt(Gamma(n + alpha + 1) / Gamma(n + 1)), evaluated via log-gamma
        L = jnp.exp(
            0.5
            * (ss.gammaln(jnp.arange(N) + alpha + 1) - ss.gammaln(jnp.arange(N) + 1))
        )
        inv_L = 1.0 / L[:, None]
        # (1 + beta) / 2 on the diagonal, ones strictly below it
        pre_A = (jnp.eye(N) * ((1 + beta) / 2)) + jnp.tril(jnp.ones((N, N)), -1)
        pre_B = ss.binom(alpha + jnp.arange(N), jnp.arange(N))[:, None]

        A = -inv_L * pre_A * L[None, :]
        B = (
            jnp.exp(-0.5 * ss.gammaln(1 - alpha))
            * jnp.power(beta, (1 - alpha) / 2)
            * inv_L
            * pre_B
        )

        return A, B

    # Scaled Legendre (LegS) vectorized
    @staticmethod
    def build_LegS(N):
        """
        Vectorized construction of the measure derived from the Scaled Legendre basis.

        Args:
            N (int): Order of coefficients to describe the orthogonal polynomial that is the HiPPO projection.

        Returns:
            A (jnp.ndarray): The A HiPPO matrix (already negated), shape (N, N), float32.
            B (jnp.ndarray): The B HiPPO matrix, shape (N, 1), float32.
        """
        q = jnp.arange(N, dtype=jnp.float32)
        k, n = jnp.meshgrid(q, q)  # n indexes rows, k indexes columns
        B = jnp.sqrt(2 * q + 1)[:, None]

        A_base = jnp.sqrt(2 * n + 1) * jnp.sqrt(2 * k + 1)
        # strictly lower triangle: sqrt((2n+1)(2k+1)); diagonal: n + 1; upper: 0
        A = jnp.where(n > k, A_base, jnp.where(n == k, n + 1, 0.0))

        return -A.astype(np.float32), B.astype(np.float32)

    # Fourier Basis OPs and functions - vectorized
    @staticmethod
    def build_Fourier(N, fourier_type="fru"):
        """
        Vectorized measure implementations derived from the Fourier basis.

        Args:
            N (int): Order of coefficients to describe the orthogonal polynomial that is the HiPPO projection.
                The matrices have size 2 * (N // 2) (at least 1), so an odd N yields an
                (N - 1) x (N - 1) system.
            fourier_type (str): The type of Fourier measure.
                - "fru": Fourier Recurrent Unit
                - "fout": truncated Fourier (A and B scaled by 2)
                - "foud": decayed Fourier (frequencies doubled, A and B scaled by 0.5)

        Returns:
            A (jnp.ndarray): The A HiPPO matrix, float32.
            B (jnp.ndarray): The B HiPPO matrix, shape (size, 1), float32.

        Raises:
            ValueError: if ``fourier_type`` is not one of "fru", "fout", "foud".
        """
        # fourier_type -> (multiplier on the pi * frequency couplings, overall scale of A and B)
        scales = {"fru": (1, 1), "fout": (1, 2), "foud": (2, 0.5)}
        if fourier_type not in scales:
            raise ValueError(
                f"fourier_type must be one of {sorted(scales)}, got {fourier_type!r}"
            )
        freq_scale, out_scale = scales[fourier_type]

        size = max(2 * (N // 2), 1)

        # B = [1, 0, sqrt(2), 0, sqrt(2), 0, ...]
        B = jnp.zeros(size, dtype=jnp.float32)
        B = B.at[0::2].set(jnp.sqrt(2))
        B = B.at[0].set(1)

        q = jnp.arange(size, dtype=jnp.float32)
        k, n = jnp.meshgrid(q, q)  # n indexes rows, k indexes columns

        n_even = n % 2 == 0
        k_even = k % 2 == 0

        case_1 = (n == k) & (n == 0)
        case_2_3 = ((k == 0) & n_even) | ((n == 0) & k_even)
        case_4 = n_even & k_even
        case_5 = (n - k == 1) & k_even  # sub-diagonal frequency coupling
        case_6 = (k - n == 1) & n_even  # super-diagonal frequency coupling

        A = jnp.where(
            case_1,
            -1.0,
            jnp.where(
                case_2_3,
                -jnp.sqrt(2),
                jnp.where(
                    case_4,
                    -2,
                    jnp.where(
                        case_5,
                        freq_scale * jnp.pi * (n // 2),
                        jnp.where(case_6, -freq_scale * jnp.pi * (k // 2), 0.0),
                    ),
                ),
            ),
        )

        if out_scale != 1:
            A = out_scale * A
            B = out_scale * B

        return A.astype(jnp.float32), B[:, None].astype(jnp.float32)


## Translated Legendre (LegT)

### LegT

In [9]:
def test_LegT():
    """Compare TransMatrix's LegT matrices (lambda_n=1) against Gu's reference implementation."""
    ours = TransMatrix(N=8, measure="legt", lambda_n=1.0)
    reference = GuTransMatrix(N=8, measure="legt", lambda_n=1.0)
    pairs = (
        ("A", ours.A_matrix, reference.A_matrix),
        ("B", ours.B_matrix, reference.B_matrix),
    )
    # print everything first, then check, so both matrices are shown even on failure
    for label, mine, theirs in pairs:
        print(f"{label}:\n", mine)
        print(f"Gu's {label}:\n", theirs)
    for _, mine, theirs in pairs:
        assert jnp.allclose(mine, theirs)
    

In [10]:
# Run the LegT comparison: prints both implementations' A/B matrices and asserts they match.
test_LegT()

  Q = jnp.arange(N, dtype=jnp.float64)


A:
 [[ -1.          1.7320508  -2.2360678   2.6457512  -3.          3.3166246
   -3.6055512   3.8729832]
 [ -1.7320508  -3.          3.872983   -4.5825753   5.196152   -5.744562
    6.244998   -6.708204 ]
 [ -2.2360678  -3.872983   -4.999999    5.916079   -6.7082033   7.4161973
   -8.062257    8.660253 ]
 [ -2.6457512  -4.5825753  -5.916079   -6.9999995   7.937254   -8.774963
    9.5393915 -10.24695  ]
 [ -3.         -5.196152   -6.7082033  -7.937254   -9.          9.949874
  -10.816654   11.61895  ]
 [ -3.3166246  -5.744562   -7.4161973  -8.774963   -9.949874  -10.999999
   11.958261  -12.845232 ]
 [ -3.6055512  -6.244998   -8.062257   -9.5393915 -10.816654  -11.958261
  -13.         13.964239 ]
 [ -3.8729832  -6.708204   -8.660253  -10.24695   -11.61895   -12.845232
  -13.964239  -14.999999 ]]
Gu's A:
 [[ -1.          1.7320508  -2.2360678   2.6457512  -3.          3.3166246
   -3.6055512   3.8729832]
 [ -1.7320508  -3.          3.872983   -4.5825753   5.196152   -5.744562
    6.2449

### LMU

In [11]:
def test_LMU():
    """Compare TransMatrix's LMU matrices (LegT with lambda_n=2) against Gu's reference implementation."""
    # lambda_n=2 switches the LegT construction to the LMU tilt
    ours = TransMatrix(N=8, measure="legt", lambda_n=2.0)
    reference = GuTransMatrix(N=8, measure="legt", lambda_n=2.0)
    pairs = (
        ("A", ours.A_matrix, reference.A_matrix),
        ("B", ours.B_matrix, reference.B_matrix),
    )
    for label, mine, theirs in pairs:
        print(f"{label}:\n", mine)
        print(f"Gu's {label}:\n", theirs)
    for _, mine, theirs in pairs:
        assert jnp.allclose(mine, theirs)
    

In [12]:
# Run the LMU comparison: prints both implementations' A/B matrices and asserts they match.
test_LMU()

  Q = jnp.arange(N, dtype=jnp.float64)


A:
 [[ -1.  -1.  -1.  -1.  -1.  -1.  -1.  -1.]
 [  3.  -3.  -3.  -3.  -3.  -3.  -3.  -3.]
 [ -5.   5.  -5.  -5.  -5.  -5.  -5.  -5.]
 [  7.  -7.   7.  -7.  -7.  -7.  -7.  -7.]
 [ -9.   9.  -9.   9.  -9.  -9.  -9.  -9.]
 [ 11. -11.  11. -11.  11. -11. -11. -11.]
 [-13.  13. -13.  13. -13.  13. -13. -13.]
 [ 15. -15.  15. -15.  15. -15.  15. -15.]]
Gu's A:
 [[ -1.  -1.  -1.  -1.  -1.  -1.  -1.  -1.]
 [  3.  -3.  -3.  -3.  -3.  -3.  -3.  -3.]
 [ -5.   5.  -5.  -5.  -5.  -5.  -5.  -5.]
 [  7.  -7.   7.  -7.  -7.  -7.  -7.  -7.]
 [ -9.   9.  -9.   9.  -9.  -9.  -9.  -9.]
 [ 11. -11.  11. -11.  11. -11. -11. -11.]
 [-13.  13. -13.  13. -13.  13. -13. -13.]
 [ 15. -15.  15. -15.  15. -15.  15. -15.]]
B:
 [[  1.]
 [ -3.]
 [  5.]
 [ -7.]
 [  9.]
 [-11.]
 [ 13.]
 [-15.]]
Gu's B:
 [[  1.]
 [ -3.]
 [  5.]
 [ -7.]
 [  9.]
 [-11.]
 [ 13.]
 [-15.]]


## Translated Laguerre (LagT)

In [13]:
def test_LagT():
    """Compare TransMatrix's LagT matrices (alpha=0, beta=1) against Gu's reference implementation."""
    # alpha and beta control the tilt of the Laguerre basis
    laguerre_kwargs = dict(N=8, measure="lagt", alpha=0.0, beta=1.0)
    ours = TransMatrix(**laguerre_kwargs)
    reference = GuTransMatrix(**laguerre_kwargs)
    pairs = (
        ("A", ours.A_matrix, reference.A_matrix),
        ("B", ours.B_matrix, reference.B_matrix),
    )
    for label, mine, theirs in pairs:
        print(f"{label}:\n", mine)
        print(f"Gu's {label}:\n", theirs)
    for _, mine, theirs in pairs:
        assert jnp.allclose(mine, theirs)

In [14]:
# Run the LagT comparison: prints both implementations' A/B matrices and asserts they match.
test_LagT()

A:
 [[-1.         -0.         -0.         -0.         -0.         -0.
  -0.         -0.        ]
 [-1.         -1.         -0.         -0.         -0.         -0.
  -0.         -0.        ]
 [-1.         -1.         -1.         -0.         -0.         -0.
  -0.         -0.        ]
 [-1.         -1.         -1.         -1.         -0.         -0.
  -0.         -0.        ]
 [-1.         -1.         -1.         -1.         -1.         -0.
  -0.         -0.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -0.         -0.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -0.        ]
 [-0.99999976 -0.99999976 -0.99999976 -0.99999976 -0.99999976 -0.99999976
  -0.99999976 -1.        ]]
Gu's A:
 [[-1.         -0.         -0.         -0.         -0.         -0.
  -0.         -0.        ]
 [-1.         -1.         -0.         -0.         -0.         -0.
  -0.         -0.        ]
 [-1.         -1.         -1.         -0.       

## Scaled Legendre (LegS)

In [15]:
def test_LegS():
    """Compare TransMatrix's LegS matrices against Gu's reference implementation."""
    ours = TransMatrix(N=8, measure="legs")
    reference = GuTransMatrix(N=8, measure="legs")
    pairs = (
        ("A", ours.A_matrix, reference.A_matrix),
        ("B", ours.B_matrix, reference.B_matrix),
    )
    for label, mine, theirs in pairs:
        print(f"{label}:\n", mine)
        print(f"Gu's {label}:\n", theirs)
    for _, mine, theirs in pairs:
        assert jnp.allclose(mine, theirs)

In [16]:
# Run the LegS comparison: prints both implementations' A/B matrices and asserts they match.
test_LegS()

  q = jnp.arange(


A:
 [[ -1.         -0.         -0.         -0.         -0.         -0.
   -0.         -0.       ]
 [ -1.7320508  -2.         -0.         -0.         -0.         -0.
   -0.         -0.       ]
 [ -2.2360678  -3.872983   -3.         -0.         -0.         -0.
   -0.         -0.       ]
 [ -2.6457512  -4.5825753  -5.916079   -4.         -0.         -0.
   -0.         -0.       ]
 [ -3.         -5.196152   -6.7082033  -7.937254   -5.         -0.
   -0.         -0.       ]
 [ -3.3166246  -5.744562   -7.4161973  -8.774963   -9.949874   -6.
   -0.         -0.       ]
 [ -3.6055512  -6.244998   -8.062257   -9.5393915 -10.816654  -11.958261
   -7.         -0.       ]
 [ -3.8729832  -6.708204   -8.660253  -10.24695   -11.61895   -12.845232
  -13.964239   -8.       ]]
Gu's A:
 [[ -1.          0.          0.          0.          0.          0.
    0.          0.       ]
 [ -1.7320508  -1.9999999   0.          0.          0.          0.
    0.          0.       ]
 [ -2.2360678  -3.872983   -3.    

## Fourier Basis

### Fourier Recurrent Unit (FRU)

In [17]:
def test_FRU():
    """Compare TransMatrix's Fourier Recurrent Unit matrices against Gu's reference implementation."""
    ours = TransMatrix(N=8, measure="fourier", fourier_type="fru")
    reference = GuTransMatrix(N=8, measure="fourier", fourier_type="fru")
    pairs = (
        ("A", ours.A_matrix, reference.A_matrix),
        ("B", ours.B_matrix, reference.B_matrix),
    )
    for label, mine, theirs in pairs:
        print(f"{label}:\n", mine)
        print(f"Gu's {label}:\n", theirs)
    for _, mine, theirs in pairs:
        assert jnp.allclose(mine, theirs)

In [18]:
# Run the FRU comparison: prints both implementations' A/B matrices and asserts they match.
test_FRU()

A:
 [[-1.        -0.        -1.4142135  0.        -1.4142135  0.
  -1.4142135  0.       ]
 [ 0.         0.         0.         0.         0.         0.
   0.         0.       ]
 [-1.4142135  0.        -2.        -3.1415927 -2.         0.
  -2.         0.       ]
 [ 0.         0.         3.1415927  0.         0.         0.
   0.         0.       ]
 [-1.4142135  0.        -2.         0.        -2.        -6.2831855
  -2.         0.       ]
 [ 0.         0.         0.         0.         6.2831855  0.
   0.         0.       ]
 [-1.4142135  0.        -2.         0.        -2.         0.
  -2.        -9.424778 ]
 [ 0.         0.         0.         0.         0.         0.
   9.424778   0.       ]]
Gu's A:
 [[-1.         0.        -1.4142135  0.        -1.4142135  0.
  -1.4142135  0.       ]
 [ 0.         0.         0.         0.         0.         0.
   0.         0.       ]
 [-1.4142135  0.        -1.9999999 -3.1415927 -1.9999999  0.
  -1.9999999  0.       ]
 [ 0.         0.         3.141592

### Truncated Fourier (FouT)

In [19]:
def test_FouT():
    """Compare TransMatrix's truncated Fourier (FouT) matrices against Gu's reference implementation."""
    ours = TransMatrix(N=8, measure="fourier", fourier_type="fout")
    reference = GuTransMatrix(N=8, measure="fourier", fourier_type="fout")
    pairs = (
        ("A", ours.A_matrix, reference.A_matrix),
        ("B", ours.B_matrix, reference.B_matrix),
    )
    for label, mine, theirs in pairs:
        print(f"{label}:\n", mine)
        print(f"Gu's {label}:\n", theirs)
    for _, mine, theirs in pairs:
        assert jnp.allclose(mine, theirs)

In [20]:
# Run the FouT comparison: prints both implementations' A/B matrices and asserts they match.
test_FouT()

A:
 [[ -2.         -0.         -2.828427    0.         -2.828427    0.
   -2.828427    0.       ]
 [  0.          0.          0.          0.          0.          0.
    0.          0.       ]
 [ -2.828427    0.         -4.         -6.2831855  -4.          0.
   -4.          0.       ]
 [  0.          0.          6.2831855   0.          0.          0.
    0.          0.       ]
 [ -2.828427    0.         -4.          0.         -4.        -12.566371
   -4.          0.       ]
 [  0.          0.          0.          0.         12.566371    0.
    0.          0.       ]
 [ -2.828427    0.         -4.          0.         -4.          0.
   -4.        -18.849556 ]
 [  0.          0.          0.          0.          0.          0.
   18.849556    0.       ]]
Gu's A:
 [[ -2.          0.         -2.828427    0.         -2.828427    0.
   -2.828427    0.       ]
 [  0.          0.          0.          0.          0.          0.
    0.          0.       ]
 [ -2.828427    0.         -3.9999998  -

### Fourier With Decay (FouD)

In [21]:
def test_FourD():
    """Compare TransMatrix's decayed Fourier (FouD) matrices against Gu's reference implementation."""
    ours = TransMatrix(N=8, measure="fourier", fourier_type="foud")
    reference = GuTransMatrix(N=8, measure="fourier", fourier_type="foud")
    # unlike the other tests, each matrix is checked immediately after it is printed
    for label, mine, theirs in (
        ("A", ours.A_matrix, reference.A_matrix),
        ("B", ours.B_matrix, reference.B_matrix),
    ):
        print(f"{label}:\n", mine)
        print(f"Gu's {label}:\n", theirs)
        assert jnp.allclose(mine, theirs)

In [22]:
# Run the FouD comparison: prints and asserts A first, then B.
test_FourD()

A:
 [[-0.5        -0.         -0.70710677  0.         -0.70710677  0.
  -0.70710677  0.        ]
 [ 0.          0.          0.          0.          0.          0.
   0.          0.        ]
 [-0.70710677  0.         -1.         -3.1415927  -1.          0.
  -1.          0.        ]
 [ 0.          0.          3.1415927   0.          0.          0.
   0.          0.        ]
 [-0.70710677  0.         -1.          0.         -1.         -6.2831855
  -1.          0.        ]
 [ 0.          0.          0.          0.          6.2831855   0.
   0.          0.        ]
 [-0.70710677  0.         -1.          0.         -1.          0.
  -1.         -9.424778  ]
 [ 0.          0.          0.          0.          0.          0.
   9.424778    0.        ]]
Gu's A:
 [[-0.5         0.         -0.70710677  0.         -0.70710677  0.
  -0.70710677  0.        ]
 [ 0.          0.          0.          0.          0.          0.
   0.          0.        ]
 [-0.70710677  0.         -0.99999994 -3.1415927 

## Make HiPPO Matrices NPLR

In [23]:
class LowRankMatrix:
    """Normal-plus-low-rank (NPLR) form of a HiPPO matrix, optionally diagonalized (DPLR).

    The HiPPO matrices come from ``TransMatrix``. ``rank_correction`` supplies a
    low-rank ``P`` and ``S = A + sum_r outer(P[r], P[r])`` is the nearly normal
    part. With ``DPLR=True``, ``S`` is further decomposed as ``V diag(Lambda) V^*``
    (one eigenvalue kept per conjugate pair) and ``B``/``P`` are projected onto ``V``.
    """

    def __init__(
        self,
        N,
        rank,
        measure="legs",
        lambda_n=1,
        fourier_type="fru",
        alpha=0,
        beta=1,
        DPLR=True,
        dtype=jnp.float32,
    ):
        """
        Args:
            N (int): HiPPO order (state size).
            rank (int): Rank of the low-rank correction ``P``.
            measure, lambda_n, fourier_type, alpha, beta: forwarded to ``TransMatrix``.
            DPLR (bool): Also compute the diagonal-plus-low-rank decomposition.
            dtype: Real dtype of the stored matrices.

        Attributes:
            A, S: (N x N) matrices in ``dtype``.
            B, P: shapes (N, 1) and (rank, N) in ``dtype``. With ``DPLR=True`` they are
                projected onto ``V`` (shapes (N//2, 1) and (rank, N//2)) and stored in the
                complex counterpart of ``dtype`` so their imaginary parts are kept.
            Lambda, V: complex eigenvalues (N//2,) and eigenvectors (N, N//2) of ``S``;
                None unless ``DPLR=True``.
        """
        self.N = N
        self.measure = measure
        self.rank = rank
        _trans_matrix = TransMatrix(N, measure, lambda_n, fourier_type, alpha, beta)

        A, B, P, S = self.make_NPLR(trans_matrix=_trans_matrix, dtype=dtype)

        self.Lambda = None
        self.V = None
        projected_dtype = dtype
        if DPLR:
            Lambda, P, B, V = self.make_DPLR(A=A, B=B, P=P, S=S)
            # The DPLR quantities are complex; a real cast would discard their imaginary parts.
            projected_dtype = jnp.result_type(dtype, jnp.complex64)
            self.Lambda = jnp.asarray(Lambda, dtype=projected_dtype)
            self.V = jnp.asarray(V, dtype=projected_dtype)

        self.A = jnp.asarray(A, dtype=dtype)  # HiPPO A matrix (N x N)
        self.B = jnp.asarray(B, dtype=projected_dtype)  # HiPPO B matrix
        self.P = jnp.asarray(P, dtype=projected_dtype)  # rank correction
        self.S = jnp.asarray(S, dtype=dtype)  # A + P P^T (N x N)

    def make_NPLR(self, trans_matrix, dtype=jnp.float32):
        """Return ``(A, B, P, S)`` with ``S = A + sum_r outer(P[r], P[r])``.

        Args:
            trans_matrix: object exposing ``A_matrix`` and ``B_matrix`` (e.g. a ``TransMatrix``).
            dtype: dtype used to build the rank correction ``P``.
        """
        A = trans_matrix.A_matrix
        B = trans_matrix.B_matrix

        P = self.rank_correction(dtype=dtype)  # (rank, N)

        # (rank, 1, N) * (rank, N, 1) -> (rank, N, N); summing over the rank axis adds every
        # outer product P[r] P[r]^T, which handles corrections of rank > 1.
        S = A + jnp.sum(jnp.expand_dims(P, -2) * jnp.expand_dims(P, -1), axis=-3)

        return A, B, P, S

    def make_DPLR(self, A, B, P, S):
        """Diagonalize the NPLR representation.

        ``S`` is expected to be ``c * I + K`` with ``K`` skew-symmetric. ``c`` is estimated
        as the mean of ``diag(S)`` and ``K`` is diagonalized through the Hermitian matrix
        ``-1j * K``.

        Args:
            A: unused; kept for interface compatibility.
            B: (N, 1) input matrix.
            P: (rank, N) low-rank correction.
            S: (N, N) matrix ``A + P P^T``.

        Returns:
            Lambda (N//2,), P (rank, N//2), B (N//2, 1), V (N, N//2), all complex.

        Raises:
            ValueError: if ``S`` minus its mean diagonal is not (numerically) skew-symmetric.
        """
        S_diag = jnp.diagonal(S)
        c = jnp.mean(S_diag, axis=-1, keepdims=True)
        Lambda_real = c * jnp.ones_like(S_diag)

        S_skew = S - c * jnp.eye(S.shape[-1], dtype=S.dtype)
        if not self.check_skew(S=S_skew):
            raise ValueError("Matrix is not skew symmetric")

        # -1j * K is Hermitian: eigh gives real eigenvalues lam with K v = 1j * lam * v.
        Lambda_imaginary, V = jnp.linalg.eigh(S_skew * -1j)
        Lambda = Lambda_real + 1j * Lambda_imaginary

        Lambda, V = self.fix_zeroed_eigvals(Lambda=Lambda, V=V, S=S)

        V_inv = V.conj().T  # conjugate transpose of the kept eigenvectors
        P = jnp.einsum("ij,...j->...i", V_inv, P)  # project every rank row onto V
        B = V_inv @ B

        return Lambda, P, B, V

    def check_skew(self, S, rtol=1e-4, atol=1e-4):
        """Return True if ``S`` is skew-symmetric (``S^T == -S``) up to float tolerance.

        refer to:
        - https://www.cuemath.com/algebra/skew-symmetric-matrix/
        - https://en.wikipedia.org/wiki/Skew-symmetric_matrix
        """
        S_T = jnp.swapaxes(S, -1, -2)
        return bool(jnp.allclose(S_T, -S, rtol=rtol, atol=atol))

    def fix_zeroed_eigvals(self, Lambda, V, S=None):
        """Keep one eigenvalue of each conjugate pair and patch a single zero eigenvalue.

        Args:
            Lambda: (N,) eigenvalues.
            V: (N, N) eigenvectors stored as columns.
            S: optional (N, N) matrix that ``V diag(Lambda) V^*`` should reconstruct
                (as ``2 * Re(...)``); when given, a warning is printed if the error is large.

        Returns:
            Lambda (N//2,) and V (N, N//2).

        Raises:
            ValueError: if more than one of the kept eigenvalues is (near) zero.
        """
        half = self.N // 2

        # Sort by imaginary part and keep the half with the most negative imaginary parts;
        # the discarded half are their complex conjugates.
        idx = jnp.argsort(Lambda.imag)
        Lambda = Lambda[idx][:half]
        V = V[:, idx][:, :half]

        # Edge case: eigenvalues can be 0, which requires some machinery to handle.
        # Huge hack: assume only one pair is 0 and that it lives in the first two
        # rows/columns of A (only happens in the Fourier case).
        if half >= 2 and jnp.abs(Lambda[-2]) <= 1e-4:
            raise ValueError("Only 1 zero eigenvalue allowed in diagonal part of A")
        if half >= 1 and jnp.abs(Lambda[-1]) < 1e-4:
            V = V.at[:, -1].set(0.0)
            V = V.at[0, -1].set(2**-0.5)
            V = V.at[1, -1].set(2**-0.5 * 1j)

        if S is not None:
            _AP = V @ jnp.diag(Lambda) @ V.conj().T
            err = jnp.sum((2 * _AP.real - S) ** 2) / self.N
            if err > 1e-5:
                print(
                    "Warning: Diagonalization of A matrix not numerically precise - error",
                    err,
                )

        return Lambda, V

    def rank_correction(self, dtype=jnp.float32):
        """Return a low-rank matrix P, shape (max(rank, d), N), such that A + P^T P is normal.

        ``d`` is the number of rows the measure needs (2 for "legt", 1 otherwise); extra
        rows up to ``self.rank`` are zero-padded.
        """
        # NOTE(review): "fourier" uses the same P for every fourier_type, and the
        # "fout"/"fourier_decay"/"fourier2"/"foud"/"legsd" branches are only reached if
        # self.measure has that exact name, which TransMatrix rejects. Confirm the intended
        # mapping from fourier_type to rank correction.
        if self.measure == "legs":
            assert self.rank >= 1
            P = jnp.sqrt(0.5 + jnp.arange(self.N, dtype=dtype))[None, :]  # (1 N)

        elif self.measure == "legt":
            assert self.rank >= 2
            P = jnp.sqrt(1 + 2 * jnp.arange(self.N, dtype=dtype))  # (N)
            # jax arrays are immutable: .at[...].set returns a new array, so no copy is needed
            P0 = P.at[0::2].set(0.0)
            P1 = P.at[1::2].set(0.0)
            # Halve the rank correction just like the original matrix was halved.
            # NOTE(review): TransMatrix.build_LegT does not halve A, so for "legt" S is not
            # skew-symmetric up to a diagonal shift and make_DPLR rejects it. TODO confirm scaling.
            P = jnp.stack([P0, P1], axis=0) * 2 ** (-0.5)  # (2 N)

        elif self.measure == "lagt":
            assert self.rank >= 1
            P = 0.5**0.5 * jnp.ones((1, self.N), dtype=dtype)

        elif self.measure in ["fourier", "fout"]:
            P = jnp.zeros(self.N, dtype=dtype)
            P = P.at[0::2].set(2**0.5)
            P = P.at[0].set(1)
            P = P[None, :]

        elif self.measure == "fourier_decay":
            P = jnp.zeros(self.N, dtype=dtype)
            P = P.at[0::2].set(2**0.5)
            P = P.at[0].set(1)
            P = P[None, :] / 2**0.5

        elif self.measure == "fourier2":
            P = jnp.zeros(self.N, dtype=dtype)
            P = P.at[0::2].set(2**0.5)
            P = P.at[0].set(1)
            P = 2**0.5 * P[None, :]

        elif self.measure in ["fourier_diag", "foud", "legsd"]:
            P = jnp.zeros((1, self.N), dtype=dtype)

        else:
            raise NotImplementedError

        d = P.shape[0]
        if self.rank > d:
            P = jnp.concatenate(
                [P, jnp.zeros((self.rank - d, self.N), dtype=dtype)], axis=0
            )  # (rank N)

        return P


## NPLR-LegT

### NPLR-LegT

In [24]:
def test_NPLR_LegT():
    """Compare NPLR-LegT (lambda_n=1, rank 2) matrices A, B, P, S against Gu's reference, one field at a time."""
    measure, rank = "legt", 2
    ours = LowRankMatrix(N=8, rank=rank, measure=measure, lambda_n=1.0, DPLR=False)
    reference = GuLowRankMatrix(
        N=8, rank=rank, measure=measure, lambda_n=1.0, DPLR=False
    )
    fields = ("A", "B", "P", "S")
    mine = {name: getattr(ours, name) for name in fields}
    theirs = {
        name: jnp.asarray(getattr(reference, name), dtype=jnp.float32)
        for name in fields
    }
    for name in fields:
        lead = "\n" if name == "A" else ""
        print(f"{lead}{name}:\n{mine[name]}\n")
        print(f"gu_{name}:\n{theirs[name]}\n")
        assert jnp.allclose(mine[name], theirs[name], rtol=1e-04, atol=1e-06)

In [25]:
# Run the NPLR-LegT comparison: prints and asserts A, B, P and S in turn.
test_NPLR_LegT()

  A = torch.from_numpy(np_A)  # (N, N)



A:
[[ -1.          1.7320508  -2.2360678   2.6457512  -3.          3.3166246
   -3.6055512   3.8729832]
 [ -1.7320508  -3.          3.872983   -4.5825753   5.196152   -5.744562
    6.244998   -6.708204 ]
 [ -2.2360678  -3.872983   -4.999999    5.916079   -6.7082033   7.4161973
   -8.062257    8.660253 ]
 [ -2.6457512  -4.5825753  -5.916079   -6.9999995   7.937254   -8.774963
    9.5393915 -10.24695  ]
 [ -3.         -5.196152   -6.7082033  -7.937254   -9.          9.949874
  -10.816654   11.61895  ]
 [ -3.3166246  -5.744562   -7.4161973  -8.774963   -9.949874  -10.999999
   11.958261  -12.845232 ]
 [ -3.6055512  -6.244998   -8.062257   -9.5393915 -10.816654  -11.958261
  -13.         13.964239 ]
 [ -3.8729832  -6.708204   -8.660253  -10.24695   -11.61895   -12.845232
  -13.964239  -14.999999 ]]

gu_A:
[[ -1.          1.7320508  -2.2360678   2.6457512  -3.          3.3166246
   -3.6055512   3.8729832]
 [ -1.7320508  -3.          3.872983   -4.5825753   5.196152   -5.744562
    6.244998

### NPLR-LMU

In [26]:
def test_NPLR_LMU():
    """Compare the NPLR form of the LMU-style LegT matrix with Gu's version.

    ``LowRankMatrix`` and ``GuLowRankMatrix`` are built from identical
    arguments (N=8, rank=2, measure="legt", lambda_n=2.0, DPLR=False). For
    each attribute A, B, P and S, both values are printed and then compared
    with ``jnp.allclose(rtol=1e-04, atol=1e-06)``. Reference values are cast
    to float32 before the comparison.

    Raises:
        AssertionError: if any of the four attribute pairs disagree.
    """
    settings = dict(N=8, rank=2, measure="legt", lambda_n=2.0, DPLR=False)
    mine = LowRankMatrix(**settings)
    # lambda_n=2.0 (the LegT test uses 1.0) puts the matrix in LMU form.
    reference = GuLowRankMatrix(**settings)
    print("NPLR LMU")

    labels = ("A", "B", "P", "S")
    ours = [getattr(mine, label) for label in labels]
    theirs = [
        jnp.asarray(getattr(reference, label), dtype=jnp.float32)
        for label in labels
    ]

    for position, (label, got, want) in enumerate(zip(labels, ours, theirs)):
        # Only the very first section is preceded by a blank line.
        prefix = "\n" if position == 0 else ""
        print(f"{prefix}{label}:\n{got}\n")
        print(f"gu_{label}:\n{want}\n")
        assert jnp.allclose(got, want, rtol=1e-04, atol=1e-06)

In [27]:
# Run the NPLR-LMU comparison (LegT measure with lambda_n=2.0); raises
# AssertionError on any mismatch between the two implementations.
test_NPLR_LMU()

NPLR LMU

A:
[[ -1.  -1.  -1.  -1.  -1.  -1.  -1.  -1.]
 [  3.  -3.  -3.  -3.  -3.  -3.  -3.  -3.]
 [ -5.   5.  -5.  -5.  -5.  -5.  -5.  -5.]
 [  7.  -7.   7.  -7.  -7.  -7.  -7.  -7.]
 [ -9.   9.  -9.   9.  -9.  -9.  -9.  -9.]
 [ 11. -11.  11. -11.  11. -11. -11. -11.]
 [-13.  13. -13.  13. -13.  13. -13. -13.]
 [ 15. -15.  15. -15.  15. -15.  15. -15.]]

gu_A:
[[ -1.  -1.  -1.  -1.  -1.  -1.  -1.  -1.]
 [  3.  -3.  -3.  -3.  -3.  -3.  -3.  -3.]
 [ -5.   5.  -5.  -5.  -5.  -5.  -5.  -5.]
 [  7.  -7.   7.  -7.  -7.  -7.  -7.  -7.]
 [ -9.   9.  -9.   9.  -9.  -9.  -9.  -9.]
 [ 11. -11.  11. -11.  11. -11. -11. -11.]
 [-13.  13. -13.  13. -13.  13. -13. -13.]
 [ 15. -15.  15. -15.  15. -15.  15. -15.]]

B:
[[  1.]
 [ -3.]
 [  5.]
 [ -7.]
 [  9.]
 [-11.]
 [ 13.]
 [-15.]]

gu_B:
[[  1.]
 [ -3.]
 [  5.]
 [ -7.]
 [  9.]
 [-11.]
 [ 13.]
 [-15.]]

P:
[[0.         1.2247448  0.         1.8708286  0.         2.3452077
  0.         2.7386127 ]
 [0.70710677 0.         1.5811386  0.         2.12132

  Q = jnp.arange(N, dtype=jnp.float64)


## NPLR-LagT

In [28]:
def test_NPLR_LagT():
    """Compare the NPLR form of the LagT matrix with Gu's version.

    Both implementations are built with N=8, rank=1, measure="lagt",
    alpha=0.0, beta=1.0 and DPLR=False. For each attribute A, B, P and S, the
    two values are printed and then checked with
    ``jnp.allclose(rtol=1e-04, atol=1e-06)``. Reference values are cast to
    float32 first.

    Raises:
        AssertionError: if any of the four attribute pairs disagree.
    """
    # alpha and beta change the resulting tilt of the LagT matrix.
    settings = dict(
        N=8,
        rank=1,
        measure="lagt",
        alpha=0.0,
        beta=1.0,
        DPLR=False,
    )
    mine = LowRankMatrix(**settings)
    reference = GuLowRankMatrix(**settings)
    print("NPLR LAGT")

    labels = ("A", "B", "P", "S")
    ours = [getattr(mine, label) for label in labels]
    theirs = [
        jnp.asarray(getattr(reference, label), dtype=jnp.float32)
        for label in labels
    ]

    for position, (label, got, want) in enumerate(zip(labels, ours, theirs)):
        # Only the very first section is preceded by a blank line.
        prefix = "\n" if position == 0 else ""
        print(f"{prefix}{label}:\n{got}\n")
        print(f"gu_{label}:\n{want}\n")
        assert jnp.allclose(got, want, rtol=1e-04, atol=1e-06)

In [29]:
# Run the NPLR-LagT comparison (alpha=0.0, beta=1.0); raises AssertionError
# on any mismatch between the two implementations.
test_NPLR_LagT()

NPLR LAGT

A:
[[-1.         -0.         -0.         -0.         -0.         -0.
  -0.         -0.        ]
 [-1.         -1.         -0.         -0.         -0.         -0.
  -0.         -0.        ]
 [-1.         -1.         -1.         -0.         -0.         -0.
  -0.         -0.        ]
 [-1.         -1.         -1.         -1.         -0.         -0.
  -0.         -0.        ]
 [-1.         -1.         -1.         -1.         -1.         -0.
  -0.         -0.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -0.         -0.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -0.        ]
 [-0.99999976 -0.99999976 -0.99999976 -0.99999976 -0.99999976 -0.99999976
  -0.99999976 -1.        ]]

gu_A:
[[-1.         -0.         -0.         -0.         -0.         -0.
  -0.         -0.        ]
 [-1.         -1.         -0.         -0.         -0.         -0.
  -0.         -0.        ]
 [-1.         -1.         -1.         -0

## NPLR-LegS

In [30]:
def test_NPLR_LegS():
    """Compare the NPLR form of the LegS matrix with Gu's version.

    Both implementations are built with N=8, rank=1, measure="legs" and
    DPLR=False. For each attribute A, B, P and S, the two values are printed
    and then checked with ``jnp.allclose(rtol=1e-04, atol=1e-06)``. Reference
    values are cast to float32 first.

    Raises:
        AssertionError: if any of the four attribute pairs disagree.
    """
    settings = dict(N=8, rank=1, measure="legs", DPLR=False)
    mine = LowRankMatrix(**settings)
    reference = GuLowRankMatrix(**settings)
    print("NPLR LEGS")

    labels = ("A", "B", "P", "S")
    ours = [getattr(mine, label) for label in labels]
    theirs = [
        jnp.asarray(getattr(reference, label), dtype=jnp.float32)
        for label in labels
    ]

    for position, (label, got, want) in enumerate(zip(labels, ours, theirs)):
        # Only the very first section is preceded by a blank line.
        prefix = "\n" if position == 0 else ""
        print(f"{prefix}{label}:\n{got}\n")
        print(f"gu_{label}:\n{want}\n")
        assert jnp.allclose(got, want, rtol=1e-04, atol=1e-06)

In [31]:
# Run the NPLR-LegS comparison; raises AssertionError on any mismatch
# between the two implementations.
test_NPLR_LegS()

NPLR LEGS

A:
[[ -1.         -0.         -0.         -0.         -0.         -0.
   -0.         -0.       ]
 [ -1.7320508  -2.         -0.         -0.         -0.         -0.
   -0.         -0.       ]
 [ -2.2360678  -3.872983   -3.         -0.         -0.         -0.
   -0.         -0.       ]
 [ -2.6457512  -4.5825753  -5.916079   -4.         -0.         -0.
   -0.         -0.       ]
 [ -3.         -5.196152   -6.7082033  -7.937254   -5.         -0.
   -0.         -0.       ]
 [ -3.3166246  -5.744562   -7.4161973  -8.774963   -9.949874   -6.
   -0.         -0.       ]
 [ -3.6055512  -6.244998   -8.062257   -9.5393915 -10.816654  -11.958261
   -7.         -0.       ]
 [ -3.8729832  -6.708204   -8.660253  -10.24695   -11.61895   -12.845232
  -13.964239   -8.       ]]

gu_A:
[[ -1.          0.          0.          0.          0.          0.
    0.          0.       ]
 [ -1.7320508  -1.9999999   0.          0.          0.          0.
    0.          0.       ]
 [ -2.2360678  -3.872983  

  q = jnp.arange(


## NPLR Applied To Fourier Basis

### NPLR-FRU

In [32]:
def test_NPLR_FRU():
    """Compare the NPLR form of the Fourier (FRU) matrix with Gu's version.

    Both implementations are built with N=8, rank=1, measure="fourier",
    fourier_type="fru" and DPLR=False. For each attribute A, B, P and S, the
    two values are printed and then checked with
    ``jnp.allclose(rtol=1e-04, atol=1e-06)``. Reference values are cast to
    float32 first.

    Raises:
        AssertionError: if any of the four attribute pairs disagree.
    """
    settings = dict(
        N=8, rank=1, measure="fourier", fourier_type="fru", DPLR=False
    )
    mine = LowRankMatrix(**settings)
    reference = GuLowRankMatrix(**settings)
    print("NPLR FRU")

    labels = ("A", "B", "P", "S")
    ours = [getattr(mine, label) for label in labels]
    theirs = [
        jnp.asarray(getattr(reference, label), dtype=jnp.float32)
        for label in labels
    ]

    for position, (label, got, want) in enumerate(zip(labels, ours, theirs)):
        # Only the very first section is preceded by a blank line.
        prefix = "\n" if position == 0 else ""
        print(f"{prefix}{label}:\n{got}\n")
        print(f"gu_{label}:\n{want}\n")
        assert jnp.allclose(got, want, rtol=1e-04, atol=1e-06)

In [33]:
# Run the NPLR-FRU comparison; raises AssertionError on any mismatch
# between the two implementations.
test_NPLR_FRU()

NPLR FRU

A:
[[-1.        -0.        -1.4142135  0.        -1.4142135  0.
  -1.4142135  0.       ]
 [ 0.         0.         0.         0.         0.         0.
   0.         0.       ]
 [-1.4142135  0.        -2.        -3.1415927 -2.         0.
  -2.         0.       ]
 [ 0.         0.         3.1415927  0.         0.         0.
   0.         0.       ]
 [-1.4142135  0.        -2.         0.        -2.        -6.2831855
  -2.         0.       ]
 [ 0.         0.         0.         0.         6.2831855  0.
   0.         0.       ]
 [-1.4142135  0.        -2.         0.        -2.         0.
  -2.        -9.424778 ]
 [ 0.         0.         0.         0.         0.         0.
   9.424778   0.       ]]

gu_A:
[[-1.         0.        -1.4142135  0.        -1.4142135  0.
  -1.4142135  0.       ]
 [ 0.         0.         0.         0.         0.         0.
   0.         0.       ]
 [-1.4142135  0.        -1.9999999 -3.1415927 -1.9999999  0.
  -1.9999999  0.       ]
 [ 0.         0.         3

### NPLR-FouT

In [34]:
def test_NPLR_FouT():
    """Compare the NPLR form of the truncated Fourier (FouT) matrix with Gu's.

    Both implementations are built with N=8, rank=1, measure="fourier",
    fourier_type="fout" and DPLR=False. For each attribute A, B, P and S, the
    two values are printed and then checked with
    ``jnp.allclose(rtol=1e-04, atol=1e-06)``. Reference values are cast to
    float32 first.

    Raises:
        AssertionError: if any of the four attribute pairs disagree.
    """
    settings = dict(
        N=8, rank=1, measure="fourier", fourier_type="fout", DPLR=False
    )
    mine = LowRankMatrix(**settings)
    reference = GuLowRankMatrix(**settings)
    print("NPLR FOUT")

    labels = ("A", "B", "P", "S")
    ours = [getattr(mine, label) for label in labels]
    theirs = [
        jnp.asarray(getattr(reference, label), dtype=jnp.float32)
        for label in labels
    ]

    for position, (label, got, want) in enumerate(zip(labels, ours, theirs)):
        # Only the very first section is preceded by a blank line.
        prefix = "\n" if position == 0 else ""
        print(f"{prefix}{label}:\n{got}\n")
        print(f"gu_{label}:\n{want}\n")
        assert jnp.allclose(got, want, rtol=1e-04, atol=1e-06)

In [35]:
# Run the NPLR-FouT comparison; raises AssertionError on any mismatch
# between the two implementations.
test_NPLR_FouT()

NPLR FOUT

A:
[[ -2.         -0.         -2.828427    0.         -2.828427    0.
   -2.828427    0.       ]
 [  0.          0.          0.          0.          0.          0.
    0.          0.       ]
 [ -2.828427    0.         -4.         -6.2831855  -4.          0.
   -4.          0.       ]
 [  0.          0.          6.2831855   0.          0.          0.
    0.          0.       ]
 [ -2.828427    0.         -4.          0.         -4.        -12.566371
   -4.          0.       ]
 [  0.          0.          0.          0.         12.566371    0.
    0.          0.       ]
 [ -2.828427    0.         -4.          0.         -4.          0.
   -4.        -18.849556 ]
 [  0.          0.          0.          0.          0.          0.
   18.849556    0.       ]]

gu_A:
[[ -2.          0.         -2.828427    0.         -2.828427    0.
   -2.828427    0.       ]
 [  0.          0.          0.          0.          0.          0.
    0.          0.       ]
 [ -2.828427    0.         -3.99

### NPLR-FouD

In [36]:
def test_NPLR_FouD():
    """Compare the NPLR form of the decaying Fourier (FouD) matrix with Gu's.

    Both implementations are built with N=8, rank=1, measure="fourier",
    fourier_type="foud" and DPLR=False. For each attribute A, B, P and S, the
    two values are printed and then checked with
    ``jnp.allclose(rtol=1e-04, atol=1e-06)``. Reference values are cast to
    float32 first.

    Raises:
        AssertionError: if any of the four attribute pairs disagree.
    """
    settings = dict(
        N=8, rank=1, measure="fourier", fourier_type="foud", DPLR=False
    )
    mine = LowRankMatrix(**settings)
    reference = GuLowRankMatrix(**settings)
    print("NPLR FOUD")

    labels = ("A", "B", "P", "S")
    ours = [getattr(mine, label) for label in labels]
    theirs = [
        jnp.asarray(getattr(reference, label), dtype=jnp.float32)
        for label in labels
    ]

    for position, (label, got, want) in enumerate(zip(labels, ours, theirs)):
        # Only the very first section is preceded by a blank line.
        prefix = "\n" if position == 0 else ""
        print(f"{prefix}{label}:\n{got}\n")
        print(f"gu_{label}:\n{want}\n")
        assert jnp.allclose(got, want, rtol=1e-04, atol=1e-06)

In [37]:
# Run the NPLR-FouD comparison; raises AssertionError on any mismatch
# between the two implementations.
test_NPLR_FouD()

NPLR FOUD

A:
[[-0.5        -0.         -0.70710677  0.         -0.70710677  0.
  -0.70710677  0.        ]
 [ 0.          0.          0.          0.          0.          0.
   0.          0.        ]
 [-0.70710677  0.         -1.         -3.1415927  -1.          0.
  -1.          0.        ]
 [ 0.          0.          3.1415927   0.          0.          0.
   0.          0.        ]
 [-0.70710677  0.         -1.          0.         -1.         -6.2831855
  -1.          0.        ]
 [ 0.          0.          0.          0.          6.2831855   0.
   0.          0.        ]
 [-0.70710677  0.         -1.          0.         -1.          0.
  -1.         -9.424778  ]
 [ 0.          0.          0.          0.          0.          0.
   9.424778    0.        ]]

gu_A:
[[-0.5         0.         -0.70710677  0.         -0.70710677  0.
  -0.70710677  0.        ]
 [ 0.          0.          0.          0.          0.          0.
   0.          0.        ]
 [-0.70710677  0.         -0.99999994 -3.