# HiPPO Matrices
---

## Table of Contents
* [Loading In Necessary Packages](#load-packages)
* [Instantiate The HiPPO Matrix](#instantiate-the-hippo-matrix)
    * [Translated Legendre (LegT)](#translated-legendre-legt)
        * [LegT](#legt)
        * [LMU](#lmu)
    * [Translated Laguerre (LagT)](#translated-laguerre-lagt)
    * [Scaled Legendre (LegS)](#scaled-legendre-legs)
    * [Fourier Basis](#fourier-basis)
        * [Fourier Recurrent Unit (FRU)](#fourier-recurrent-unit-fru)
        * [Truncated Fourier (FouT)](#truncated-fourier-fout)
        * [Fourier With Decay (FourD)](#fourier-with-decay-fourd)
* [Gu's Linear Time Invariant (LTI) HiPPO Operator](#gus-hippo-legt-operator)
* [Gu's Scale invariant (LSI) HiPPO Operator](#gus-scale-invariant-hippo-legs-operator)
* [Implementation Of General HiPPO Operator](#implementation-of-general-hippo-operator)
* [Test Generalized Bilinear Transform and Zero Order Hold Matrices](#test-generalized-bilinear-transform-and-zero-order-hold-matrices)
    * [Testing Forward Euler on GBT matrices](#testing-forward-euler-transform-for-lti-and-lsi)
    * [Testing Backward Euler on GBT matrices](#testing-backward-euler-transform-for-lti-and-lsi-on-legs-matrices)
    * [Testing Bidirectional on GBT matrices](#testing-lti-and-lsi-operators-with-bidirectional-transform)
    * [Testing ZOH on GBT matrices](#testing-zoh-transform-for-lti-and-lsi-on-legs-matrices)
* [Testing HiPPO Operators](#test-hippo-operators)
    * [Testing Forward Euler on HiPPO Operators](#testing-lti-and-lsi-operators-with-forward-euler-transform)
    * [Testing Backward Euler on HiPPO Operators](#testing-lti-and-lsi-operators-with-backward-euler-transform)
    * [Testing Bidirectional on HiPPO Operators](#testing-lti-and-lsi-operators-with-bidirectional-transform)
    * [Testing ZOH on HiPPO Operators](#testing-lti-and-lsi-operators-with-zoh-transform)
---


## Load Packages

In [1]:
import os
import sys

module_path = os.path.abspath(os.path.join("../../../"))
if module_path not in sys.path:
    sys.path.append(module_path)

In [2]:
## import packages
import math
from typing import Any, Callable, List, Optional, Tuple, Union

import jax
import jax.numpy as jnp
import requests
from flax import linen as jnn
from jax.nn.initializers import lecun_normal, uniform
from jax.numpy.linalg import eig, inv, matrix_power
from jax.scipy.signal import convolve
from jaxtyping import Array, Float, Float16, Float32, Float64, UInt
from scipy import linalg as la
from scipy import signal
from scipy import special as ss

from src.data.process import moving_window, rolling_window

# import modules
from src.models.hippo.gu_transition import GuTransMatrix
from src.models.hippo.unroll import (
    basis,
    measure,
    variable_unroll_matrix,
    variable_unroll_matrix_sequential,
)

print(jax.devices())
print(f"The Device: {jax.default_backend()}")

2023-01-07 23:02:25.550775: W external/org_tensorflow/tensorflow/compiler/xla/service/platform_util.cc:198] unable to create StreamExecutor for CUDA:0: failed initializing StreamExecutor for CUDA device ordinal 0: INTERNAL: failed call to cuDevicePrimaryCtxRetain: CUDA_ERROR_OUT_OF_MEMORY: out of memory; total memory reported: 10492772352
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


[CpuDevice(id=0)]
The Device: cpu


In [3]:
from functools import partial
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from einops import rearrange, reduce, repeat

print(f"MPS enabled: {torch.backends.mps.is_available()}")

MPS enabled: False


In [4]:
torch.set_printoptions(linewidth=150)
np.set_printoptions(linewidth=150)
jnp.set_printoptions(linewidth=150)

In [5]:
seed = 1701
key = jax.random.PRNGKey(seed)

In [6]:
num_copies = 5
subkeys = jax.random.split(key, num=num_copies)
key = subkeys[0]

## Instantiate The HiPPO Matrix

In [7]:
class TransMatrix:
    def __init__(
        self,
        N: int,
        measure: str = "legs",
        lambda_n: float = 1.0,
        alpha: float = 0.0,
        beta: float = 1.0,
        dtype: Any = jnp.float32,
    ):
        """
        Instantiates the HiPPO matrices of a given order using a particular measure.

        Args:
            N (int): Order of coefficients to describe the orthogonal polynomial that is the HiPPO projection.
            measure (str):
                choose between
                    - HiPPO w/ Translated Legendre (LegT) - legt
                    - HiPPO w/ Legendre Memory Unit (LMU) - lmu
                    - HiPPO w/ Translated Laguerre (LagT) - lagt
                    - HiPPO w/ Scaled Legendre (LegS) - legs
                    - HiPPO w/ Fourier basis
                        - FRU: Fourier Recurrent Unit - fru
                        - FouT: Translated Fourier - fout
                        - FouD: Fourier Decay - foud
                    - random: dense Gaussian A (scaled by 1/N) and Gaussian B - random
                    - diagonal: negative diagonal A with log-normal magnitudes and Gaussian B - diagonal
            lambda_n (float): The tilt applied to the translated Legendre basis; 1.0 selects LegT and 2.0 selects LMU.
            alpha (float): The order of the Laguerre basis.
            beta (float): The scale of the Laguerre basis.
            dtype: The dtype of the stored matrices.

        Attributes:
            A (jnp.ndarray): The HiPPO matrix multiplied by -1.
            B (jnp.ndarray): The other corresponding state space matrix.

        """
        A = None
        B = None
        if measure in ["legt", "lmu"]:
            # LegT is the lambda_n == 1 tilt and LMU is the lambda_n == 2 tilt
            expected_lambda = 1.0 if measure == "legt" else 2.0
            if lambda_n != expected_lambda:
                raise ValueError(
                    f"HiPPO type '{measure}' requires lambda_n == {expected_lambda}, got {lambda_n}"
                )

            A, B = self.build_LegT(N=N, lambda_n=lambda_n, dtype=dtype)

        elif measure == "lagt":
            A, B = self.build_LagT(alpha=alpha, beta=beta, N=N, dtype=dtype)

        elif measure == "legs":
            A, B = self.build_LegS(N=N, dtype=dtype)

        elif measure in ["fout", "fru", "foud"]:
            A, B = self.build_Fourier(N=N, fourier_type=measure, dtype=dtype)

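        elif measure == "random":
            # jax.numpy has no `random` submodule; draw with NumPy (imported above as np)
            A = jnp.asarray(np.random.randn(N, N)) / N
            B = jnp.asarray(np.random.randn(N, 1))

        elif measure == "diagonal":
            A = -jnp.diag(jnp.exp(jnp.asarray(np.random.randn(N))))
            B = jnp.asarray(np.random.randn(N, 1))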

        else:
            raise ValueError("Invalid HiPPO type")

        self.A = (A.copy()).astype(dtype)
        self.B = (B.copy()).astype(dtype)

    # Translated Legendre (LegT) - vectorized
    @staticmethod
    def build_LegT(
        N: int, lambda_n: int = 1, dtype=jnp.float32
    ) -> Tuple[Float[Array, "N N"], Float[Array, "N 1"]]:
        """
        Vectorized construction of the HiPPO matrices derived from the translated Legendre measure.

        Args:
            N (int): Order of coefficients to describe the orthogonal polynomial that is the HiPPO projection.
            lambda_n (int): Choice between the two different tilts of the basis.
                - 1: translated Legendre (LegT) - 'legt'
                - 2: Legendre Memory Unit (LMU) - 'lmu'
            dtype: The dtype of the returned matrices.

        Returns:
            A (jnp.ndarray): The A HiPPO matrix.
            B (jnp.ndarray): The B HiPPO matrix.

        """
        q = jnp.arange(N, dtype=dtype)
        # meshgrid's default "xy" indexing: n[i, j] = i (row index), k[i, j] = j (column index)
        k, n = jnp.meshgrid(q, q)
        case = jnp.power(-1.0, (n - k))

        if lambda_n == 1:
            A_base = jnp.sqrt(2 * n + 1) * jnp.sqrt(2 * k + 1)
            B = jnp.sqrt(2 * q + 1)[:, None]
            # on or below the diagonal (n >= k) use A_base, above it use case * A_base
            A = jnp.where(k <= n, A_base, A_base * case)

        elif lambda_n == 2:
            A_base = 2 * n + 1
            B = ((2 * q + 1) * jnp.power(-1.0, q))[:, None]
            # on or below the diagonal (n >= k) use case * A_base, above it use A_base
            A = jnp.where(k <= n, A_base * case, A_base)

        else:
            raise ValueError(f"lambda_n must be 1 (LegT) or 2 (LMU), got {lambda_n}")

        return -A.astype(dtype), B.astype(dtype)

    # Translated Laguerre (LagT) - vectorized
    @staticmethod
    def build_LagT(
        alpha: float, beta: float, N: int, dtype=jnp.float32
    ) -> Tuple[Float[Array, "N N"], Float[Array, "N 1"]]:
        """
        Vectorized construction of the HiPPO matrices derived from the translated Laguerre measure.

        Args:
            alpha (float): The order of the Laguerre basis.
            beta (float): The scale of the Laguerre basis.
            N (int): Order of coefficients to describe the orthogonal polynomial that is the HiPPO projection.

        Returns:
            A (jnp.ndarray): The A HiPPO matrix.
            B (jnp.ndarray): The B HiPPO matrix.

        """
        L = jnp.exp(
            0.5
            * (ss.gammaln(jnp.arange(N) + alpha + 1) - ss.gammaln(jnp.arange(N) + 1))
        )
        inv_L = 1.0 / L[:, None]
        pre_A = (jnp.eye(N) * ((1 + beta) / 2)) + jnp.tril(jnp.ones((N, N)), -1)
        pre_B = ss.binom(alpha + jnp.arange(N), jnp.arange(N))[:, None]

        A = -inv_L * pre_A * L[None, :]
        B = (
            jnp.exp(-0.5 * ss.gammaln(1 - alpha))
            * jnp.power(beta, (1 - alpha) / 2)
            * inv_L
            * pre_B
        )

        return A.astype(dtype), B.astype(dtype)

    # Scaled Legendre (LegS) vectorized
    @staticmethod
    def build_LegS(
        N: int, dtype=jnp.float32
    ) -> Tuple[Float[Array, "N N"], Float[Array, "N 1"]]:
        """
        Vectorized construction of the HiPPO matrices derived from the scaled Legendre measure.

        Args:
            N (int): Order of coefficients to describe the orthogonal polynomial that is the HiPPO projection.

        Returns:
            A (jnp.ndarray): The A HiPPO matrix.
            B (jnp.ndarray): The B HiPPO matrix.

        """
        q = jnp.arange(N, dtype=dtype)
        k, n = jnp.meshgrid(q, q)
        B = jnp.sqrt(2 * q + 1)[:, None]

        A_base = jnp.sqrt(2 * n + 1) * jnp.sqrt(2 * k + 1)

        A = jnp.where(n > k, A_base, jnp.where(n == k, n + 1, 0.0))

        return -A.astype(dtype), B.astype(dtype)

    # Fourier Basis OPs and functions - vectorized
    @staticmethod
    def build_Fourier(
        N: int, fourier_type: str = "fru", dtype=jnp.float32
    ) -> Tuple[Float[Array, "N N"], Float[Array, "N 1"]]:
        """
        Vectorized measure implementations derived from the Fourier basis.

        Args:
            N (int): Order of coefficients to describe the orthogonal polynomial that is the HiPPO projection.
            fourier_type (str): The type of Fourier measure.
                - FRU: Fourier Recurrent Unit - fru
                - FouT: truncated Fourier - fout
                - FouD: decayed Fourier - foud

        Returns:
            A (jnp.ndarray): The A HiPPO matrix.
            B (jnp.ndarray): The B HiPPO matrix.

        """
        A = jnp.diag(
            jnp.stack([jnp.zeros(N // 2), jnp.zeros(N // 2)], axis=-1).reshape(-1)[1:],
            1,
        )
        B = jnp.zeros(A.shape[1], dtype=dtype)

        B = B.at[0::2].set(jnp.sqrt(2))
        B = B.at[0].set(1)

        q = jnp.arange(A.shape[1], dtype=dtype)
        k, n = jnp.meshgrid(q, q)

        # boolean masks for even row (n) and even column (k) indices
        n_even = n % 2 == 0
        k_even = k % 2 == 0

        case_1 = (n == k) & (n == 0)
        case_2_3 = ((k == 0) & n_even) | ((n == 0) & k_even)
        case_4 = n_even & k_even
        case_5 = (n - k == 1) & k_even
        case_6 = (k - n == 1) & n_even

        if fourier_type == "fru":  # Fourier Recurrent Unit (FRU) - vectorized
            A = jnp.where(
                case_1,
                -1.0,
                jnp.where(
                    case_2_3,
                    -jnp.sqrt(2),
                    jnp.where(
                        case_4,
                        -2,
                        jnp.where(
                            case_5,
                            jnp.pi * (n // 2),
                            jnp.where(case_6, -jnp.pi * (k // 2), 0.0),
                        ),
                    ),
                ),
            )

        elif fourier_type == "fout":  # truncated Fourier (FouT) - vectorized
            A = jnp.where(
                case_1,
                -1.0,
                jnp.where(
                    case_2_3,
                    -jnp.sqrt(2),
                    jnp.where(
                        case_4,
                        -2,
                        jnp.where(
                            case_5,
                            jnp.pi * (n // 2),
                            jnp.where(case_6, -jnp.pi * (k // 2), 0.0),
                        ),
                    ),
                ),
            )

            A = 2 * A
            B = 2 * B

        elif fourier_type == "foud":
            A = jnp.where(
                case_1,
                -1.0,
                jnp.where(
                    case_2_3,
                    -jnp.sqrt(2),
                    jnp.where(
                        case_4,
                        -2,
                        jnp.where(
                            case_5,
                            2 * jnp.pi * (n // 2),
                            jnp.where(case_6, 2 * -jnp.pi * (k // 2), 0.0),
                        ),
                    ),
                ),
            )

            A = 0.5 * A
            B = 0.5 * B

        B = B[:, None]

        return A.astype(dtype), B.astype(dtype)
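Every builder above relies on the same trick: `jnp.meshgrid(q, q)` turns the index vector into two grids, and `jnp.where` then picks entries by comparing them. With the default `"xy"` indexing the first output (`k`) varies along columns and the second (`n`) along rows, so `n` is the row index and `k` the column index. The short sketch below is a standalone illustration with `N = 4` (not part of the repo) that checks this convention and builds the strictly lower-triangular mask that `build_LegS` uses for its `n > k` branch.

```python
import jax.numpy as jnp

# With meshgrid's default "xy" indexing, the first output varies along
# columns and the second along rows.
q = jnp.arange(4, dtype=jnp.float32)
k, n = jnp.meshgrid(q, q)  # k[i, j] == j (column index), n[i, j] == i (row index)

# jnp.where over these grids selects entries by (row, column) position,
# e.g. the strictly lower-triangular mask (n > k) used in build_LegS.
strict_lower = jnp.where(n > k, 1.0, 0.0)

print(k)
print(n)
print(strict_lower)
```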


In [8]:
num_of_coef = 8

## Translated Legendre (LegT)

### LegT

In [9]:
def test_LegT(N):
    legt_matrices = TransMatrix(N=N, measure="legt", lambda_n=1.0)
    A, B = legt_matrices.A, legt_matrices.B
    gu_legt_matrices = GuTransMatrix(N=N, measure="legt", lambda_n=1.0)
    gu_A, gu_B = gu_legt_matrices.A, gu_legt_matrices.B
    print(f"A:\n", A)
    print(f"Gu's A:\n", gu_A)
    print(f"B:\n", B)
    print(f"Gu's B:\n", gu_B)
    assert jnp.allclose(A, gu_A)
    assert jnp.allclose(B, gu_B)

In [10]:
test_LegT(N=num_of_coef)

A:
 [[ -1.          1.7320508  -2.236068    2.6457512  -3.          3.3166249  -3.6055512   3.8729835]
 [ -1.7320508  -3.          3.8729832  -4.5825753   5.196152   -5.7445626   6.244998   -6.708204 ]
 [ -2.236068   -3.8729832  -5.          5.9160795  -6.7082043   7.4161987  -8.062258    8.6602545]
 [ -2.6457512  -4.5825753  -5.9160795  -6.9999995   7.937254   -8.774964    9.5393915 -10.246951 ]
 [ -3.         -5.196152   -6.7082043  -7.937254   -9.          9.949875  -10.816654   11.61895  ]
 [ -3.3166249  -5.7445626  -7.4161987  -8.774964   -9.949875  -11.000001   11.958261  -12.845233 ]
 [ -3.6055512  -6.244998   -8.062258   -9.5393915 -10.816654  -11.958261  -13.         13.96424  ]
 [ -3.8729835  -6.708204   -8.6602545 -10.246951  -11.61895   -12.845233  -13.96424   -15.000001 ]]
Gu's A:
 [[ -1.          1.7320508  -2.236068    2.6457512  -3.          3.3166249  -3.6055512   3.8729835]
 [ -1.7320508  -3.          3.8729832  -4.5825753   5.196152   -5.7445626   6.244998   -6.70820
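Because the `jnp.where` rule is easy to get backwards, a plain-loop version of the same LegT formula can serve as an independent cross-check. `legt_reference` below is a hypothetical helper, written under the assumption that the `lambda_n == 1` branch of `build_LegT` is the intended rule; its first rows and diagonal agree with the printed `A` above (for example `A[0, 1] = sqrt(3)` and diagonal entries `-(2n+1)`).

```python
import math

import jax.numpy as jnp


def legt_reference(N: int):
    """Hypothetical entry-by-entry LegT construction (lambda_n = 1) for cross-checking.

    Encodes A[n, k] = -sqrt(2n+1) sqrt(2k+1) when n >= k, and
    -sqrt(2n+1) sqrt(2k+1) (-1)^(n-k) when n < k; B[n] = sqrt(2n+1).
    """
    A = [[0.0] * N for _ in range(N)]
    B = [[0.0] for _ in range(N)]
    for n in range(N):  # n: row index
        B[n][0] = math.sqrt(2 * n + 1)
        for k in range(N):  # k: column index
            base = math.sqrt(2 * n + 1) * math.sqrt(2 * k + 1)
            sign = 1.0 if n >= k else (-1.0) ** (n - k)
            A[n][k] = -base * sign
    return jnp.array(A, dtype=jnp.float32), jnp.array(B, dtype=jnp.float32)


A_ref, B_ref = legt_reference(8)
print(A_ref[:2])
```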

### LMU

In [11]:
def test_LMU(N):
    lmu_matrices = TransMatrix(
        N=N, measure="lmu", lambda_n=2.0
    )  # change lambda so resulting matrix is in the form of LMU
    A, B = lmu_matrices.A, lmu_matrices.B
    gu_lmu_matrices = GuTransMatrix(
        N=N, measure="lmu", lambda_n=2.0
    )  # change lambda so resulting matrix is in the form of LMU
    gu_A, gu_B = gu_lmu_matrices.A, gu_lmu_matrices.B
    print(f"A:\n", A)
    print(f"Gu's A:\n", gu_A)
    print(f"B:\n", B)
    print(f"Gu's B:\n", gu_B)
    assert jnp.allclose(A, gu_A)
    assert jnp.allclose(B, gu_B)

In [12]:
test_LMU(N=num_of_coef)

A:
 [[ -1.  -1.  -1.  -1.  -1.  -1.  -1.  -1.]
 [  3.  -3.  -3.  -3.  -3.  -3.  -3.  -3.]
 [ -5.   5.  -5.  -5.  -5.  -5.  -5.  -5.]
 [  7.  -7.   7.  -7.  -7.  -7.  -7.  -7.]
 [ -9.   9.  -9.   9.  -9.  -9.  -9.  -9.]
 [ 11. -11.  11. -11.  11. -11. -11. -11.]
 [-13.  13. -13.  13. -13.  13. -13. -13.]
 [ 15. -15.  15. -15.  15. -15.  15. -15.]]
Gu's A:
 [[ -1.  -1.  -1.  -1.  -1.  -1.  -1.  -1.]
 [  3.  -3.  -3.  -3.  -3.  -3.  -3.  -3.]
 [ -5.   5.  -5.  -5.  -5.  -5.  -5.  -5.]
 [  7.  -7.   7.  -7.  -7.  -7.  -7.  -7.]
 [ -9.   9.  -9.   9.  -9.  -9.  -9.  -9.]
 [ 11. -11.  11. -11.  11. -11. -11. -11.]
 [-13.  13. -13.  13. -13.  13. -13. -13.]
 [ 15. -15.  15. -15.  15. -15.  15. -15.]]
B:
 [[  1.]
 [ -3.]
 [  5.]
 [ -7.]
 [  9.]
 [-11.]
 [ 13.]
 [-15.]]
Gu's B:
 [[  1.]
 [ -3.]
 [  5.]
 [ -7.]
 [  9.]
 [-11.]
 [ 13.]
 [-15.]]
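The printed LMU matrix follows a closed form that can be written with broadcasting alone: every entry in row `n` has magnitude `2n+1`, entries above the diagonal are `-(2n+1)`, entries on or below it alternate as `(2n+1)(-1)^(n-k+1)`, and `B[n] = (2n+1)(-1)^n`. The sketch below (`lmu_closed_form` is a hypothetical name) restates that rule and reproduces the rows shown above; it uses parity checks instead of `jnp.power` so the signs stay exact.

```python
import jax.numpy as jnp


def lmu_closed_form(N: int):
    """Hypothetical broadcasting-based LMU construction (lambda_n = 2).

    A[n, k] = -(2n+1) for n < k and (2n+1)(-1)^(n-k+1) for n >= k;
    B[n] = (2n+1)(-1)^n.
    """
    idx = jnp.arange(N)
    n = idx[:, None]  # row index, shape (N, 1)
    k = idx[None, :]  # column index, shape (1, N)
    scale = (2 * n + 1).astype(jnp.float32)
    # (-1)^(n-k+1) from the parity of n - k + 1 (only used where n >= k)
    alt_sign = jnp.where((n - k + 1) % 2 == 0, 1.0, -1.0)
    A = jnp.where(n < k, -scale, scale * alt_sign)
    B = ((2 * idx + 1) * jnp.where(idx % 2 == 0, 1, -1)).astype(jnp.float32)[:, None]
    return A, B


A, B = lmu_closed_form(8)
print(A)
print(B)
```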


## Translated Laguerre (LagT)

In [13]:
def test_LagT(N):
    lagt_matrices = TransMatrix(
        N=N,
        measure="lagt",
        alpha=0.0,  # order of the generalized Laguerre basis
        beta=1.0,  # scale of the Laguerre basis
    )
    A, B = lagt_matrices.A, lagt_matrices.B
    gu_lagt_matrices = GuTransMatrix(
        N=N,
        measure="lagt",
        alpha=0.0,
        beta=1.0,
    )
    gu_A, gu_B = gu_lagt_matrices.A, gu_lagt_matrices.B
    print(f"A:\n", A)
    print(f"Gu's A:\n", gu_A)
    print(f"B:\n", B)
    print(f"Gu's B:\n", gu_B)
    assert jnp.allclose(A, gu_A)
    assert jnp.allclose(B, gu_B)

In [14]:
test_LagT(N=num_of_coef)

A:
 [[-1.         -0.         -0.         -0.         -0.         -0.         -0.         -0.        ]
 [-1.         -1.         -0.         -0.         -0.         -0.         -0.         -0.        ]
 [-1.         -1.         -1.         -0.         -0.         -0.         -0.         -0.        ]
 [-1.         -1.         -1.         -1.         -0.         -0.         -0.         -0.        ]
 [-1.         -1.         -1.         -1.         -1.         -0.         -0.         -0.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.         -0.         -0.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.         -1.         -0.        ]
 [-0.99999976 -0.99999976 -0.99999976 -0.99999976 -0.99999976 -0.99999976 -0.99999976 -1.        ]]
Gu's A:
 [[-1.         -0.         -0.         -0.         -0.         -0.         -0.         -0.        ]
 [-1.         -1.         -0.         -0.         -0.         -0.         -0.         -0.      
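The LagT formula simplifies considerably for the parameters used in this test. With `alpha = 0` the normalizer `L` is all ones, so `A` collapses to `-((1 + beta)/2 * I + strictly lower ones)`, which for `beta = 1` is `-tril(ones)`, the pattern printed above (up to float32 rounding in the last row), and `B` becomes a vector of ones. The sketch below is a hypothetical JAX-only restatement of `build_LagT`: it swaps `scipy.special` for `jax.scipy.special.gammaln` and writes `binom(alpha + n, n)` through log-gamma, assuming `0 <= alpha < 1`.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln


def lagt_sketch(N: int, alpha: float = 0.0, beta: float = 1.0):
    """Hypothetical JAX-only version of the LagT formula used in build_LagT.

    Assumes 0 <= alpha < 1 so that Gamma(1 - alpha) is finite and positive.
    """
    n = jnp.arange(N, dtype=jnp.float32)
    # normalizer L[n] = sqrt(Gamma(n + alpha + 1) / Gamma(n + 1))
    L = jnp.exp(0.5 * (gammaln(n + alpha + 1) - gammaln(n + 1)))
    pre_A = jnp.eye(N) * (1 + beta) / 2 + jnp.tril(jnp.ones((N, N)), -1)
    # binom(alpha + n, n) written via log-gamma
    binom = jnp.exp(gammaln(alpha + n + 1) - gammaln(n + 1) - gammaln(alpha + 1.0))
    A = -(1.0 / L[:, None]) * pre_A * L[None, :]
    B = (
        jnp.exp(-0.5 * gammaln(1.0 - alpha))
        * beta ** ((1 - alpha) / 2)
        * (binom / L)[:, None]
    )
    return A, B


A, B = lagt_sketch(8, alpha=0.0, beta=1.0)
print(A)
```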

## Scaled Legendre (LegS)

In [15]:
def test_LegS(N):
    legs_matrices = TransMatrix(N=N, measure="legs")
    A, B = legs_matrices.A, legs_matrices.B
    gu_legs_matrices = GuTransMatrix(N=N, measure="legs")
    gu_A, gu_B = gu_legs_matrices.A, gu_legs_matrices.B
    print(f"A:\n", A)
    print(f"Gu's A:\n", gu_A)
    print(f"B:\n", B)
    print(f"Gu's B:\n", gu_B)
    assert jnp.allclose(A, gu_A)
    assert jnp.allclose(B, gu_B)

In [16]:
test_LegS(N=num_of_coef)

A:
 [[ -1.         -0.         -0.         -0.         -0.         -0.         -0.         -0.       ]
 [ -1.7320508  -2.         -0.         -0.         -0.         -0.         -0.         -0.       ]
 [ -2.236068   -3.8729832  -3.         -0.         -0.         -0.         -0.         -0.       ]
 [ -2.6457512  -4.5825753  -5.9160795  -4.         -0.         -0.         -0.         -0.       ]
 [ -3.         -5.196152   -6.7082043  -7.937254   -5.         -0.         -0.         -0.       ]
 [ -3.3166249  -5.7445626  -7.4161987  -8.774964   -9.949875   -6.         -0.         -0.       ]
 [ -3.6055512  -6.244998   -8.062258   -9.5393915 -10.816654  -11.958261   -7.         -0.       ]
 [ -3.8729835  -6.708204   -8.6602545 -10.246951  -11.61895   -12.845233  -13.96424    -8.       ]]
Gu's A:
 [[ -1.          0.          0.          0.          0.          0.          0.          0.       ]
 [ -1.7320508  -1.9999999   0.          0.          0.          0.          0.          0.     
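Unlike LegT, the LegS matrix is lower triangular (the `n > k`, `n == k` and zero branches of `build_LegS`), so its eigenvalues are simply its diagonal entries `-1, -2, ..., -N`. The hypothetical `legs_sketch` below restates the same rule with broadcasting in place of `meshgrid` and checks that structure directly, which avoids calling a general eigenvalue solver.

```python
import jax.numpy as jnp


def legs_sketch(N: int):
    """Hypothetical broadcasting restatement of the LegS rule in build_LegS.

    A[n, k] = -sqrt(2n+1) sqrt(2k+1) for n > k, -(n+1) for n == k, 0 for n < k;
    B[n] = sqrt(2n+1).
    """
    q = jnp.arange(N, dtype=jnp.float32)
    n, k = q[:, None], q[None, :]  # row and column indices
    off_diag = jnp.sqrt(2 * n + 1) * jnp.sqrt(2 * k + 1)
    A = -jnp.where(n > k, off_diag, jnp.where(n == k, n + 1, 0.0))
    B = jnp.sqrt(2 * q + 1)[:, None]
    return A, B


A, B = legs_sketch(8)
# A is lower triangular, so its eigenvalues are exactly its diagonal entries.
print(jnp.diag(A))
```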

## Fourier Basis

### Fourier Recurrent Unit (FRU)

In [17]:
def test_FRU(N):
    fru_matrices = TransMatrix(N=N, measure="fru")
    A, B = fru_matrices.A, fru_matrices.B
    gu_fru_matrices = GuTransMatrix(N=N, measure="fru")
    gu_A, gu_B = gu_fru_matrices.A, gu_fru_matrices.B
    print(f"A:\n", A)
    print(f"Gu's A:\n", gu_A)
    print(f"B:\n", B)
    print(f"Gu's B:\n", gu_B)
    assert jnp.allclose(A, gu_A)
    assert jnp.allclose(B, gu_B)

In [18]:
test_FRU(N=num_of_coef)

A:
 [[-1.        -0.        -1.4142135  0.        -1.4142135  0.        -1.4142135  0.       ]
 [ 0.         0.         0.         0.         0.         0.         0.         0.       ]
 [-1.4142135  0.        -2.        -3.1415927 -2.         0.        -2.         0.       ]
 [ 0.         0.         3.1415927  0.         0.         0.         0.         0.       ]
 [-1.4142135  0.        -2.         0.        -2.        -6.2831855 -2.         0.       ]
 [ 0.         0.         0.         0.         6.2831855  0.         0.         0.       ]
 [-1.4142135  0.        -2.         0.        -2.         0.        -2.        -9.424778 ]
 [ 0.         0.         0.         0.         0.         0.         9.424778   0.       ]]
Gu's A:
 [[-1.         0.        -1.4142135  0.        -1.4142135  0.        -1.4142135  0.       ]
 [ 0.         0.         0.         0.         0.         0.         0.         0.       ]
 [-1.4142135  0.        -1.9999999 -3.1415927 -1.9999999  0.        -1.99999
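The nested `jnp.where` for FRU hides a simpler structure: a rotation part with `+/- pi * j` linking each pair of coefficients `(2j, 2j+1)`, minus the rank-one term `B B^T` with `B = [1, 0, sqrt(2), 0, sqrt(2), ...]`. Reading the cases that way (`case_1` to `case_4` reproduce `-B B^T`, `case_5` and `case_6` the rotation) gives the hypothetical `fru_sketch` below, assuming an even `N`; its entries match the printed FRU matrix (for example `A[2, 3] = -pi` and `A[7, 6] = 3 * pi`).

```python
import jax.numpy as jnp


def fru_sketch(N: int):
    """Hypothetical FRU construction as rotation blocks minus a rank-one term.

    Assumes N is even. A = pi * (sub-diagonal - super-diagonal pattern) - B B^T.
    """
    freqs = jnp.arange(N // 2, dtype=jnp.float32)
    # interleave [0, f0, 0, f1, ...] and drop the first entry, so d[2j] = j
    d = jnp.stack([jnp.zeros(N // 2), freqs], axis=-1).reshape(-1)[1:]
    rotation = jnp.pi * (jnp.diag(d, -1) - jnp.diag(d, 1))
    B = jnp.zeros(N).at[0::2].set(jnp.sqrt(2.0)).at[0].set(1.0)
    A = rotation - B[:, None] * B[None, :]
    return A, B[:, None]


A, B = fru_sketch(8)
print(A)
```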

### Truncated Fourier (FouT)

In [19]:
def test_FouT(N):
    fout_matrices = TransMatrix(N=N, measure="fout")
    A, B = fout_matrices.A, fout_matrices.B
    gu_fout_matrices = GuTransMatrix(N=N, measure="fout")
    gu_A, gu_B = gu_fout_matrices.A, gu_fout_matrices.B
    print(f"A:\n", A)
    print(f"Gu's A:\n", gu_A)
    print(f"B:\n", B)
    print(f"Gu's B:\n", gu_B)
    assert jnp.allclose(A, gu_A)
    assert jnp.allclose(B, gu_B)

In [20]:
test_FouT(N=num_of_coef)

A:
 [[ -2.         -0.         -2.828427    0.         -2.828427    0.         -2.828427    0.       ]
 [  0.          0.          0.          0.          0.          0.          0.          0.       ]
 [ -2.828427    0.         -4.         -6.2831855  -4.          0.         -4.          0.       ]
 [  0.          0.          6.2831855   0.          0.          0.          0.          0.       ]
 [ -2.828427    0.         -4.          0.         -4.        -12.566371   -4.          0.       ]
 [  0.          0.          0.          0.         12.566371    0.          0.          0.       ]
 [ -2.828427    0.         -4.          0.         -4.          0.         -4.        -18.849556 ]
 [  0.          0.          0.          0.          0.          0.         18.849556    0.       ]]
Gu's A:
 [[ -2.          0.         -2.828427    0.         -2.828427    0.         -2.828427    0.       ]
 [  0.          0.          0.          0.          0.          0.          0.          0.     
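The FouT branch reuses the FRU pattern and doubles both matrices (`A = 2 * A`, `B = 2 * B`). A loop-based transcription of the `case_1` through `case_6` rules is a handy way to confirm both the pattern and the scaling; `fourier_reference` below is hypothetical, handles only `"fru"` and `"fout"`, and its FouT output matches the printed entries above (for example `A[0, 0] = -2` and `A[5, 4] = 4 * pi`).

```python
import math

import jax.numpy as jnp


def fourier_reference(N: int, fourier_type: str = "fru"):
    """Hypothetical loop transcription of the case_1 ... case_6 rules in build_Fourier.

    Only "fru" and "fout" are handled; FouT is the FRU pattern with A and B doubled.
    """
    A = [[0.0] * N for _ in range(N)]
    for n in range(N):  # row index
        for k in range(N):  # column index
            if n == k == 0:  # case_1
                A[n][k] = -1.0
            elif (k == 0 and n % 2 == 0) or (n == 0 and k % 2 == 0):  # case_2_3
                A[n][k] = -math.sqrt(2)
            elif n % 2 == 0 and k % 2 == 0:  # case_4
                A[n][k] = -2.0
            elif n - k == 1 and k % 2 == 0:  # case_5
                A[n][k] = math.pi * (n // 2)
            elif k - n == 1 and n % 2 == 0:  # case_6
                A[n][k] = -math.pi * (k // 2)
    B = [[1.0 if n == 0 else (math.sqrt(2) if n % 2 == 0 else 0.0)] for n in range(N)]
    scale = {"fru": 1.0, "fout": 2.0}[fourier_type]
    return scale * jnp.array(A, dtype=jnp.float32), scale * jnp.array(B, dtype=jnp.float32)


A_fru, B_fru = fourier_reference(8, "fru")
A_fout, B_fout = fourier_reference(8, "fout")
print(A_fout)
```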

### Fourier With Decay (FourD)

In [21]:
def test_FouD(N):
    foud_matrices = TransMatrix(N=N, measure="foud")
    A, B = foud_matrices.A, foud_matrices.B
    gu_foud_matrices = GuTransMatrix(N=N, measure="foud")
    gu_A, gu_B = gu_foud_matrices.A, gu_foud_matrices.B
    print(f"A:\n", A)
    print(f"Gu's A:\n", gu_A)
    print(f"B:\n", B)
    print(f"Gu's B:\n", gu_B)
    assert jnp.allclose(A, gu_A)
    assert jnp.allclose(B, gu_B)

In [22]:
test_FouD(N=num_of_coef)

A:
 [[-0.5        -0.         -0.70710677  0.         -0.70710677  0.         -0.70710677  0.        ]
 [ 0.          0.          0.          0.          0.          0.          0.          0.        ]
 [-0.70710677  0.         -1.         -3.1415927  -1.          0.         -1.          0.        ]
 [ 0.          0.          3.1415927   0.          0.          0.          0.          0.        ]
 [-0.70710677  0.         -1.          0.         -1.         -6.2831855  -1.          0.        ]
 [ 0.          0.          0.          0.          6.2831855   0.          0.          0.        ]
 [-0.70710677  0.         -1.          0.         -1.          0.         -1.         -9.424778  ]
 [ 0.          0.          0.          0.          0.          0.          9.424778    0.        ]]
Gu's A:
 [[-0.5         0.         -0.70710677  0.         -0.70710677  0.         -0.70710677  0.        ]
 [ 0.          0.          0.          0.          0.          0.          0.          0.      
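FouD goes the other way: its `jnp.where` chain doubles the rotation frequencies to `2 * pi * j` and then multiplies everything by `0.5`, so the net effect is FRU's rotation part unchanged while the rank-one damping term and `B` are halved. The hypothetical `foud_sketch` below (even `N` assumed) builds FouD directly in that form and reproduces the printed values, such as `A[0, 0] = -0.5` and `A[2, 3] = -pi`.

```python
import jax.numpy as jnp


def foud_sketch(N: int):
    """Hypothetical FouD construction, assuming N is even.

    Same +/- pi * j rotation pattern as FRU, but the rank-one damping and B are halved:
    A = rotation - 0.5 * B0 B0^T, B = 0.5 * B0.
    """
    freqs = jnp.arange(N // 2, dtype=jnp.float32)
    d = jnp.stack([jnp.zeros(N // 2), freqs], axis=-1).reshape(-1)[1:]
    rotation = jnp.pi * (jnp.diag(d, -1) - jnp.diag(d, 1))
    B0 = jnp.zeros(N).at[0::2].set(jnp.sqrt(2.0)).at[0].set(1.0)
    A = rotation - 0.5 * jnp.outer(B0, B0)
    return A, (0.5 * B0)[:, None]


A, B = foud_sketch(8)
print(A)
```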