# HiPPO Matrices
---

## Load Packages

In [121]:
## import packages
import jax
import jax.numpy as jnp

from jax.nn.initializers import lecun_normal, uniform
from jax.numpy.linalg import eig, inv, matrix_power
from jax.scipy.signal import convolve

import os
import requests

from scipy import linalg as la
from scipy import signal
from scipy import special as ss

## setup JAX to use TPUs if available
# Best-effort TPU hookup: every step below may fail on a machine without a TPU
# (missing TPU_NAME env var, unreachable driver endpoint, JAX build without
# these config flags). In that case JAX simply keeps its default backend.
try:
    url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
    resp = requests.post(url)
    jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
    jax.config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
except Exception:
    # Deliberately ignored (best effort). `Exception` rather than a bare
    # `except` so KeyboardInterrupt / SystemExit still propagate.
    pass

# Show which devices JAX ended up with.
jax.devices()

[CpuDevice(id=0)]

In [122]:
# Matrix dimension passed as `N` to the builder calls in the comparison cells.
N=5

## Instantiate The HiPPO Matrix

## Translated Legendre (LegT)

In [123]:
# Translated Legendre (LegT) - vectorized
def build_LegT_V(N, lambda_n=1):
    """Build the translated-Legendre (LegT) A and B matrices with array ops.

    Args:
        N: size of the square matrix A (B has N rows and one column).
        lambda_n: selects the normalisation.
            1 -> entries scaled by sqrt(2n+1) * sqrt(2k+1);
            2 -> entries scaled by (2n+1) (the variant labelled "lmu" in
                 build_LegT).

    Returns:
        Tuple ``(A, B)`` with ``A`` of shape (N, N) and ``B`` of shape (N, 1).

    Raises:
        ValueError: if ``lambda_n`` is neither 1 nor 2 (previously this
            silently returned ``(None, None)``).
    """
    q = jnp.arange(N, dtype=jnp.float64)
    # meshgrid default indexing: k varies along columns, n along rows.
    k, n = jnp.meshgrid(q, q)
    # Alternating sign (-1)^(n-k).
    sign = jnp.power(-1.0, (n - k))

    if lambda_n == 1:
        A_base = -jnp.sqrt(2 * n + 1) * jnp.sqrt(2 * k + 1)
        # On/below the diagonal (k <= n) use A_base as is; above the diagonal
        # apply the alternating sign.
        A = jnp.where(k <= n, A_base, A_base * sign)
        B = jnp.sqrt(2 * q + 1)[:, None]

    elif lambda_n == 2:
        A_base = -(2 * n + 1)
        # On/below the diagonal (k <= n) apply the alternating sign; above the
        # diagonal use A_base as is.
        A = jnp.where(k <= n, A_base * sign, A_base)
        B = ((2 * q + 1) * jnp.power(-1.0, q))[:, None]

    else:
        raise ValueError(f"Unsupported lambda_n: {lambda_n!r} (expected 1 or 2)")

    return A, B

In [124]:
# Translated Legendre (LegT) - non-vectorized
def build_LegT(N, legt_type="legt"):
    """Build the translated-Legendre (LegT) A and B matrices.

    Args:
        N: size of the square matrix A (B has N rows and one column).
        legt_type: "legt" for the sqrt(2n+1)-normalised variant, "lmu" for
            the (2n+1)-scaled variant.

    Returns:
        Tuple ``(A, B)`` with ``A`` of shape (N, N) and ``B`` of shape (N, 1).

    Raises:
        ValueError: if ``legt_type`` is not "legt" or "lmu" (previously this
            surfaced as an UnboundLocalError on ``A``).
    """
    Q = jnp.arange(N, dtype=jnp.float64)
    pre_R = (2 * Q + 1)
    # meshgrid default indexing: k varies along columns, n along rows.
    k, n = jnp.meshgrid(Q, Q)

    if legt_type == "legt":
        R = jnp.sqrt(pre_R)
        # Alternating sign only above the diagonal (n < k); whole matrix negated.
        A = -(R[:, None] * jnp.where(n < k, (-1.)**(n - k), 1) * R[None, :])
        B = R[:, None]
        # NOTE: an additional 0.5 scaling of A and B ("timescale correctness")
        # was left disabled in the original code and is not applied here.

    elif legt_type == "lmu":
        R = pre_R[:, None]
        # Above the diagonal every entry is -(2n+1); elsewhere the sign
        # alternates as (-1)^(n-k+1).
        A = jnp.where(n < k, -1, (-1.)**(n - k + 1)) * R
        B = (-1.)**Q[:, None] * R

    else:
        raise ValueError(
            f"Unsupported legt_type: {legt_type!r} (expected 'legt' or 'lmu')"
        )

    return A, B

In [125]:
# Sanity check: build the "lmu" LegT variant with the non-vectorized builder
# and the lambda_n=2 variant with the vectorized builder, then compare them.
nv_LegT_A, nv_LegT_B = build_LegT(N=N, legt_type="lmu")
LegT_A, LegT_B = build_LegT_V(N=N, lambda_n=2)
print(f"nv:\n", nv_LegT_A)
print(f"v:\n", LegT_A)
# print(f"nv:\n", nv_LegT_B)
# print(f"v:\n", LegT_B)
# Both comparisons are expected to print True when the two builders agree.
print(f"A Comparison:\n ", jnp.allclose(nv_LegT_A, LegT_A))
print(f"B Comparison:\n ", jnp.allclose(nv_LegT_B, LegT_B))

nv:
 [[-1. -1. -1. -1. -1.]
 [ 3. -3. -3. -3. -3.]
 [-5.  5. -5. -5. -5.]
 [ 7. -7.  7. -7. -7.]
 [-9.  9. -9.  9. -9.]]
v:
 [[-1. -1. -1. -1. -1.]
 [ 3. -3. -3. -3. -3.]
 [-5.  5. -5. -5. -5.]
 [ 7. -7.  7. -7. -7.]
 [-9.  9. -9.  9. -9.]]
A Comparison:
  True
B Comparison:
  True


  lax_internal._check_user_dtype_supported(dtype, "arange")


## Translated Laguerre (LagT)

In [126]:
# Translated Laguerre (LagT) - non-vectorized
def build_LagT(alpha, beta, N):
    """Build the translated-Laguerre (LagT) A and B matrices.

    Args:
        alpha: Laguerre parameter used in the binomial and log-gamma terms.
        beta: parameter setting the diagonal of A, (1 + beta) / 2, and the
            scaling of B.
        N: size of the square matrix A (B has N rows and one column).

    Returns:
        Tuple ``(A, B)`` with ``A`` of shape (N, N) and ``B`` of shape (N, 1).
    """
    idx = jnp.arange(N)

    # Un-normalised operators: scaled identity plus strictly-lower ones (negated).
    strict_lower = jnp.tril(jnp.ones((N, N)), -1)
    A = -jnp.eye(N) * (1 + beta) / 2 - strict_lower
    B = ss.binom(alpha + idx, idx)[:, None]

    # Per-row normaliser L_n = exp(0.5 * (lgamma(n + alpha + 1) - lgamma(n + 1))).
    log_ratio = ss.gammaln(idx + alpha + 1) - ss.gammaln(idx + 1)
    L = jnp.exp(.5 * log_ratio)
    row_scale = 1. / L[:, None]

    # Similarity-scale A by L; rescale B by 1/L and the alpha/beta factors.
    A = row_scale * A * L[None, :]
    B = row_scale * B * jnp.exp(-.5 * ss.gammaln(1 - alpha)) * beta**((1 - alpha) / 2)

    return A, B

In [127]:
# Translated Laguerre (LagT) - vectorized
def build_LagT_V(alpha, beta, N):
    """Build the translated-Laguerre (LagT) A and B matrices.

    Args:
        alpha: Laguerre parameter used in the binomial and log-gamma terms.
        beta: parameter setting the diagonal of A, (1 + beta) / 2, and the
            scaling of B.
        N: size of the square matrix A (B has N rows and one column).

    Returns:
        Tuple ``(A, B)`` with ``A`` of shape (N, N) and ``B`` of shape (N, 1).
    """
    idx = jnp.arange(N)

    # Per-row normaliser L_n = exp(0.5 * (lgamma(n + alpha + 1) - lgamma(n + 1))).
    log_ratio = ss.gammaln(idx + alpha + 1) - ss.gammaln(idx + 1)
    L = jnp.exp(.5 * log_ratio)
    inv_L = 1. / L[:, None]

    # Positive "pre" operators; the sign of A is applied when it is assembled.
    diag_part = jnp.eye(N) * ((1 + beta) / 2)
    lower_part = jnp.tril(jnp.ones((N, N)), -1)
    pre_A = diag_part + lower_part
    pre_B = ss.binom(alpha + idx, idx)[:, None]

    A = -inv_L * pre_A * L[None, :]

    gamma_factor = jnp.exp(-.5 * ss.gammaln(1 - alpha))
    beta_factor = jnp.power(beta, (1 - alpha) / 2)
    B = gamma_factor * beta_factor * inv_L * pre_B

    return A, B

In [128]:
# Sanity check: build LagT (alpha=0, beta=1) with both implementations and
# compare the resulting A and B matrices.
nv_LagT_A, nv_LagT_B = build_LagT(alpha=0, beta=1, N=N)
LagT_A, LagT_B = build_LagT_V(alpha=0, beta=1, N=N)
print(f"nv:\n{nv_LagT_A}")
print(f"v:\n{LagT_A}")
# print(f"nv:\n{nv_LagT_B}")
# print(f"v:\n{LagT_B}")
# Both comparisons are expected to print True when the two builders agree.
print(f"A Comparison:\n ", jnp.allclose(nv_LagT_A, LagT_A))
print(f"B Comparison:\n ", jnp.allclose(nv_LagT_B, LagT_B))

nv:
[[-1. -0. -0. -0. -0.]
 [-1. -1. -0. -0. -0.]
 [-1. -1. -1. -0. -0.]
 [-1. -1. -1. -1. -0.]
 [-1. -1. -1. -1. -1.]]
v:
[[-1. -0. -0. -0. -0.]
 [-1. -1. -0. -0. -0.]
 [-1. -1. -1. -0. -0.]
 [-1. -1. -1. -1. -0.]
 [-1. -1. -1. -1. -1.]]
A Comparison:
  True
B Comparison:
  True


## Scaled Legendre (LegS)

In [129]:
#Scaled Legendre (LegS) vectorized
def build_LegS_V(N):
    """Build the scaled-Legendre (LegS) A and B matrices with array ops.

    Args:
        N: size of the square matrix A (B has N rows and one column).

    Returns:
        Tuple ``(A, B)``: ``A`` is (N, N) lower-triangular with
        -sqrt(2n+1) * sqrt(2k+1) below the diagonal and -(n+1) on it;
        ``B`` is (N, 1) holding sqrt(2n+1).
    """
    idx = jnp.arange(N, dtype=jnp.float64)
    # meshgrid default indexing: col varies along columns, row along rows.
    col, row = jnp.meshgrid(idx, idx)

    B = jnp.sqrt(2 * idx + 1)[:, None]

    full = (-jnp.sqrt(2 * row + 1)) * jnp.sqrt(2 * col + 1)
    # On the diagonal, full == -(2n+1); this factor turns it into -(n+1).
    diag_factor = (row + 1) / (2 * row + 1)

    # Zero above the diagonal, `full` strictly below, rescaled `full` on it.
    below_or_zero = jnp.where(row > col, full, 0.0)
    A = jnp.where(row == col, full * diag_factor, below_or_zero)

    return A, B

In [130]:
#Scaled Legendre (LegS), non-vectorized
def build_LegS(N):
    """Build the scaled-Legendre (LegS) A and B matrices via A = D M D^-1.

    Args:
        N: size of the square matrix A (B has N rows and one column).

    Returns:
        Tuple ``(A, B)`` with ``A`` of shape (N, N) and ``B`` of shape (N, 1)
        holding the diagonal of D, i.e. sqrt(2n+1).
    """
    idx = jnp.arange(N, dtype=jnp.float64)  # indices 0, 1, ..., N-1
    col, row = jnp.meshgrid(idx, idx)
    weights = 2 * idx + 1

    # State matrix M: -(2k+1) strictly below the diagonal, -(n+1) on it, 0 above.
    lower_weights = jnp.where(row >= col, weights, 0)
    M = -(lower_weights - jnp.diag(idx))

    # Diagonal matrix D := diag[(2n+1)^(1/2)], n = 0..N-1
    D = jnp.sqrt(jnp.diag(weights))
    A = D @ M @ jnp.linalg.inv(D)

    # Copy kept from the original: avoids a "given NumPY array is not
    # writeable" UserWarning after torch.as_tensor(B).
    B = jnp.diag(D)[:, None].copy()

    return A, B

In [131]:
# Sanity check: build LegS with both implementations and compare the results.
nv_LegS_A, nv_LegS_B = build_LegS(N=N)
LegS_A, LegS_B = build_LegS_V(N=N)
print(f"nv:\n{nv_LegS_A}")
print(f"v:\n{LegS_A}")
# print(f"nv:\n{nv_LegS_B}")
# print(f"v:\n{LegS_B}")
# Both comparisons are expected to print True when the two builders agree
# (up to floating-point tolerance).
print(f"A Comparison:\n ", jnp.allclose(nv_LegS_A, LegS_A))
print(f"B Comparison:\n ", jnp.allclose(nv_LegS_B, LegS_B))

nv:
[[-1.         0.         0.         0.         0.       ]
 [-1.7320508 -1.9999999  0.         0.         0.       ]
 [-2.236068  -3.8729835 -3.         0.         0.       ]
 [-2.6457512 -4.582576  -5.9160795 -4.         0.       ]
 [-3.        -5.196152  -6.708204  -7.9372544 -5.       ]]
v:
[[-1.         0.         0.         0.         0.       ]
 [-1.7320508 -2.         0.         0.         0.       ]
 [-2.236068  -3.8729832 -3.         0.         0.       ]
 [-2.6457512 -4.5825753 -5.9160795 -4.         0.       ]
 [-3.        -5.196152  -6.7082043 -7.937254  -5.       ]]
A Comparison:
  True
B Comparison:
  True


## Fourier Basis

In [132]:
# Fourier Basis OPs and functions - vectorized
def build_Fourier_V(N, fourier_type='FRU'):
    """Build the Fourier-basis A and B matrices with array ops.

    The matrices have size ``M = 2 * (N // 2)`` (an odd N is rounded down).

    Args:
        N: requested state size.
        fourier_type: one of
            "FRU"           - Fourier Recurrent Unit,
            "FouT"          - truncated Fourier (FRU matrices scaled by 2),
            "fourier_decay" - off-diagonal frequencies as in "FouT", with the
                              rank-one part and B scaled by 0.5.

    Returns:
        Tuple ``(A, B)`` with ``A`` of shape (M, M) and ``B`` of shape (M, 1).

    Raises:
        ValueError: if ``fourier_type`` is not one of the names above
            (previously this failed with a TypeError on ``None[:, None]``).
    """
    # (multiplier on the +-pi*frequency entries, overall scale on A and B)
    variants = {
        "FRU": (1.0, 1.0),
        "FouT": (1.0, 2.0),
        "fourier_decay": (2.0, 0.5),
    }
    if fourier_type not in variants:
        raise ValueError(
            f"Unsupported fourier_type: {fourier_type!r} "
            f"(expected one of {sorted(variants)})"
        )
    freq_mult, scale = variants[fourier_type]

    q = jnp.arange((N // 2) * 2, dtype=jnp.float64)
    # meshgrid default indexing: k varies along columns, n along rows.
    k, n = jnp.meshgrid(q, q)

    n_even = n % 2 == 0
    k_even = k % 2 == 0

    case_1 = (n == k) & (n == 0)                               # top-left entry
    case_2_3 = ((k == 0) & n_even) | ((n == 0) & k_even)       # first row/column, even index
    case_4 = n_even & k_even                                   # remaining even/even entries
    case_5 = (n - k == 1) & k_even                             # sub-diagonal, even column
    case_6 = (k - n == 1) & n_even                             # super-diagonal, even row

    # Apply the cases from lowest to highest priority so that an earlier case
    # (case_1 first) wins wherever several masks overlap.
    A = jnp.where(case_6, -freq_mult * jnp.pi * (k // 2), 0.0)
    A = jnp.where(case_5, freq_mult * jnp.pi * (n // 2), A)
    A = jnp.where(case_4, -2, A)
    A = jnp.where(case_2_3, -jnp.sqrt(2), A)
    A = jnp.where(case_1, -1.0, A)

    # B is sqrt(2) at even indices (1 at index 0) and 0 at odd indices.
    B = jnp.zeros(q.shape[0], dtype=jnp.float64)
    B = B.at[::2].set(jnp.sqrt(2))
    B = B.at[0].set(1)

    A = scale * A
    B = scale * B

    B = B[:, None]

    return A, B

In [133]:
def build_Fourier(N, fourier_type='FRU'):
    """Build the Fourier-basis A and B matrices (non-vectorized reference).

    A is a skew-symmetric band of +-pi*frequency entries minus a scaled
    rank-one term B B^T; B is sqrt(2) at even indices (1 at index 0) and 0 at
    odd indices, times the same scale.

    Args:
        N: requested state size; frequencies 0 .. N//2 - 1 are used.
        fourier_type: one of
            "FRU"           - Fourier Recurrent Unit (scale 1),
            "FouT"          - truncated Fourier (frequencies doubled, scale 2),
            "fourier_decay" - scale 0.5.

    Returns:
        Tuple ``(A, B)``; ``B`` has one column and as many rows as ``A``.

    Raises:
        ValueError: if ``fourier_type`` is not one of the names above
            (previously this surfaced as an UnboundLocalError).
    """
    # (multiplier on the frequencies, scale of the rank-one correction and B)
    variants = {
        "FRU": (1, 1.0),
        "FouT": (2, 2.0),
        "fourier_decay": (1, 0.5),
    }
    if fourier_type not in variants:
        raise ValueError(
            f"Unsupported fourier_type: {fourier_type!r} "
            f"(expected one of {sorted(variants)})"
        )
    freq_mult, scale = variants[fourier_type]

    freqs = freq_mult * jnp.arange(N // 2)

    # Interleave zeros with the frequencies and drop the leading zero so the
    # frequencies land on every other position of the first off-diagonals.
    d = jnp.stack([jnp.zeros(N // 2), freqs], axis=-1).reshape(-1)[1:]
    A = jnp.pi * (-jnp.diag(d, 1) + jnp.diag(d, -1))

    B = jnp.zeros(A.shape[1])
    B = B.at[0::2].set(jnp.sqrt(2))
    B = B.at[0].set(1)

    # Subtract off the (scaled) rank-one correction B B^T.
    A = A - scale * B[:, None] * B[None, :]
    B = scale * B[:, None]

    return A, B

In [134]:
# Sanity check: build the FRU Fourier matrices with both implementations and
# compare the results.
nv_Fourier_A, nv_Fourier_B = build_Fourier(N=N, fourier_type='FRU')
Fourier_A, Fourier_B = build_Fourier_V(N=N, fourier_type='FRU')
print(f"nv:\n{nv_Fourier_A}")
print(f"v:\n{Fourier_A}")
# print(f"nv:\n{nv_Fourier_B}")
# print(f"v:\n{Fourier_B}")
# Both comparisons are expected to print True when the two builders agree.
print(f"A Comparison:\n {jnp.allclose(nv_Fourier_A, Fourier_A)}")
print(f"B Comparison:\n {jnp.allclose(nv_Fourier_B, Fourier_B)}")

nv:
[[-1.         0.        -1.4142135  0.       ]
 [ 0.         0.         0.         0.       ]
 [-1.4142135  0.        -1.9999999 -3.1415927]
 [ 0.         0.         3.1415927  0.       ]]
v:
[[-1.        -0.        -1.4142135  0.       ]
 [ 0.         0.         0.         0.       ]
 [-1.4142135  0.        -2.        -3.1415927]
 [ 0.         0.         3.1415927  0.       ]]
A Comparison:
 True
B Comparison:
 True


  lax_internal._check_user_dtype_supported(dtype, "zeros")


In [135]:
def make_HiPPO(N, v='nv', HiPPO_type="legs", lambda_n=1, fourier_type="FRU", alpha=0, beta=1, key=None):
    """Dispatch to one of the HiPPO matrix builders and return ``(-A, B)``.

    Args:
        N: state size passed to the selected builder.
        v: 'nv' selects the non-vectorized builder; any other value selects
            the vectorized (``*_V``) builder.
        HiPPO_type: "legt", "lagt", "legs", "fourier", "random" or "diagonal".
        lambda_n: LegT normalisation selector (1 or 2); only used for "legt".
        fourier_type: variant name forwarded to the Fourier builders.
        alpha, beta: parameters forwarded to the LagT builders.
        key: optional ``jax.random`` PRNG key used by "random" and "diagonal";
            defaults to ``jax.random.PRNGKey(0)``.

    Returns:
        Tuple ``(-A, B)`` where ``A`` and ``B`` come from the selected builder.
        NOTE(review): the sign of A is flipped on return, exactly as in the
        original code; since the builders in this file already produce a
        negative diagonal, confirm this is the sign convention callers expect.

    Raises:
        ValueError: for an unknown ``HiPPO_type``, or for ``lambda_n`` other
            than 1 or 2 when ``HiPPO_type == "legt"``.
    """
    use_vectorized = v != 'nv'

    if HiPPO_type == "legt":
        # build_LegT selects its variant by name, build_LegT_V by lambda_n;
        # 1 <-> "legt" and 2 <-> "lmu" produce the same matrices.
        legt_names = {1: "legt", 2: "lmu"}
        if lambda_n not in legt_names:
            raise ValueError(f"Unsupported lambda_n: {lambda_n!r} (expected 1 or 2)")
        if use_vectorized:
            A, B = build_LegT_V(N=N, lambda_n=lambda_n)
        else:
            A, B = build_LegT(N=N, legt_type=legt_names[lambda_n])

    elif HiPPO_type == "lagt":
        if use_vectorized:
            A, B = build_LagT_V(alpha=alpha, beta=beta, N=N)
        else:
            A, B = build_LagT(alpha=alpha, beta=beta, N=N)

    elif HiPPO_type == "legs":
        if use_vectorized:
            A, B = build_LegS_V(N=N)
        else:
            A, B = build_LegS(N=N)

    elif HiPPO_type == "fourier":
        if use_vectorized:
            A, B = build_Fourier_V(N=N, fourier_type=fourier_type)
        else:
            A, B = build_Fourier(N=N, fourier_type=fourier_type)

    elif HiPPO_type in ("random", "diagonal"):
        # jax.numpy has no `random` submodule; sampling goes through
        # jax.random with an explicit key, split so A and B are independent.
        if key is None:
            key = jax.random.PRNGKey(0)
        key_A, key_B = jax.random.split(key)
        if HiPPO_type == "random":
            A = jax.random.normal(key_A, (N, N)) / N
        else:
            A = -jnp.diag(jnp.exp(jax.random.normal(key_A, (N,))))
        B = jax.random.normal(key_B, (N, 1))

    else:
        raise ValueError("Invalid HiPPO type")

    return -jnp.array(A), B

In [136]:
# Sanity check: request the LegS matrices through make_HiPPO with the
# non-vectorized ('nv') and vectorized ('v') builders and compare them.
nv_A, nv_B = make_HiPPO(N=N, v='nv', HiPPO_type="legs", lambda_n=1, fourier_type="FRU", alpha=0, beta=1)
v_A, v_B = make_HiPPO(N=N, v='v', HiPPO_type="legs", lambda_n=1, fourier_type="FRU", alpha=0, beta=1)
#print(f"nv:\n", nv_A)
#print(f"v:\n", v_A)
# print(f"nv:\n", nv_B)
print(f"v:\n", v_B)
# Both comparisons are expected to print True when the two builders agree.
print(f"A Comparison:\n ", jnp.allclose(nv_A, v_A))
print(f"B Comparison:\n ", jnp.allclose(nv_B, v_B))

# Print the shape of a (max_length, N, N) buffer created with A's dtype.
max_length=1024
print(jnp.empty((max_length, N, N), dtype=v_A.dtype).shape)

v:
 [[1.       ]
 [1.7320508]
 [2.236068 ]
 [2.6457512]
 [3.       ]]
A Comparison:
  True
B Comparison:
  True
(1024, 5, 5)
