In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import scipy.special as ss


In [None]:
def make_HiPPO(N, HiPPO_type="legs", alpha=0.0, beta=1.0, rng=None):
    """Build the continuous-time HiPPO state-space pair ``(A, B)``.

    Dispatches ONCE on ``HiPPO_type`` to the matching builder defined in the
    cells below and returns its matrices. The builders already carry the
    leading minus sign of the transition matrix (e.g. LegS has diagonal
    -(n+1)), so the result is not negated again.

    Args:
        N: state size (number of basis coefficients).
        HiPPO_type: one of "legt", "lagt", "legs", "fourier", "random",
            "diagonal".
        alpha, beta: parameters forwarded to ``build_LagT``; only used for
            "lagt". With the defaults (0, 1) build_LagT's A reduces to minus
            the lower-triangular all-ones matrix.
        rng: optional ``np.random.Generator`` used by the "random" and
            "diagonal" baselines for reproducibility. When ``None`` the global
            ``np.random`` state is used.

    Returns:
        ``(A, B)`` as NumPy arrays: A is square and B is a column vector
        (for "fourier" the size is 2 * (N // 2), as produced by build_FouT).

    Raises:
        ValueError: if ``HiPPO_type`` is not recognised.
    """
    def normal(*shape):
        # Standard-normal draws from the caller's Generator, or the global state.
        if rng is None:
            return np.random.randn(*shape)
        return rng.standard_normal(shape)

    if HiPPO_type == "legt":
        A, B = build_LegT(N)
    elif HiPPO_type == "lagt":
        A, B = build_LagT(alpha, beta, N)
    elif HiPPO_type == "legs":
        A, B = build_LegS(N)
    elif HiPPO_type == "fourier":
        # There is no build_FRU; the truncated-Fourier builder is build_FouT.
        A, B = build_FouT(N)
    elif HiPPO_type == "random":
        A = normal(N, N) / N
        B = normal(N, 1)
    elif HiPPO_type == "diagonal":
        # Negative exponentials keep every diagonal entry strictly negative.
        A = -np.diag(np.exp(normal(N)))
        B = normal(N, 1)
    else:
        raise ValueError("Invalid HiPPO type")

    # Builders return jax arrays; normalise so every branch yields NumPy.
    return np.asarray(A), np.asarray(B)

## Translated Legendre (LegT)

In [None]:
# Translated Legendre (LegT) - vectorized
def build_LegT(N, lambda_n=1):
    """Vectorized translated-Legendre (LegT) transition pair ``(A, B)``.

    n is the row index and k the column index (0-based). ``lambda_n`` is a
    scalar selector for the basis normalisation:

    * ``lambda_n == 1`` (lambda_n = 1, the LMU form)::

          A[n, k] = -(2n+1) * ((-1)**(n-k) if k <= n else 1)
          B[n]    = (2n+1) * (-1)**n

    * any other scalar (lambda_n = sqrt(2n+1) * (-1)**n)::

          A[n, k] = -sqrt(2n+1) * sqrt(2k+1) * (1 if k <= n else (-1)**(n-k))
          B[n]    = sqrt(2n+1)

    The two forms are related by the diagonal similarity
    diag((-1)**n / sqrt(2n+1)).

    NOTE(review): the next notebook cell redefines ``build_LegT(N)``; whichever
    cell runs last is the one callers see.

    Returns:
        A: (N, N) array.  B: (N, 1) array.
    """
    q = jnp.arange(N, dtype=jnp.float64)
    # indexing="ij" -> n runs down the rows and k across the columns, matching
    # the A[n, k] notation; the default "xy" order silently transposes them.
    n, k = jnp.meshgrid(q, q, indexing="ij")
    lower = k <= n                                   # lower triangle incl. diagonal
    sign = jnp.where((n - k) % 2 == 0, 1.0, -1.0)    # (-1)**(n-k) without float pow

    if lambda_n == 1:
        A = -(2 * n + 1) * jnp.where(lower, sign, 1.0)
        B = ((2 * q + 1) * jnp.where(q % 2 == 0, 1.0, -1.0))[:, None]
    else:
        R = jnp.sqrt(2 * q + 1)
        A = -(R[:, None] * R[None, :]) * jnp.where(lower, 1.0, sign)
        B = R[:, None]

    return A, B

In [None]:
# Translated Legendre (LegT) - non-vectorized
def build_LegT(N):
    """Translated-Legendre (LegT) pair in the LMU normalisation (lambda_n = 1).

    With n the row index and k the column index (0-based)::

        A[n, k] = (2n+1) * (-1 if n < k else (-1)**(n-k+1))
        B[n]    = (2n+1) * (-1)**n

    Returns:
        A: (N, N) float64 array.  B: (N, 1) float64 array.
    """
    Q = np.arange(N, dtype=np.float64)
    # Row scale (2n+1); the commented-out "/ theta" window scaling is omitted (theta = 1).
    R = (2*Q + 1)[:, None]
    # indexing="ij": n indexes rows, k indexes columns. The default "xy" order
    # swapped them and put the sign pattern on the wrong triangle.
    n, k = np.meshgrid(Q, Q, indexing="ij")
    A = np.where(n < k, -1, (-1.)**(n-k+1)) * R
    B = (-1.)**Q[:, None] * R
    return A, B

## Translated Laguerre (LagT)

In [None]:
# Translated Laguerre (LagT) - generalized-Laguerre form (single, vectorized implementation)
def build_LagT(alpha, beta, N):
    """Generalized translated-Laguerre (LagT) transition pair ``(A, B)``.

    With L_n = exp(0.5 * (gammaln(n + alpha + 1) - gammaln(n + 1))):

        A = -diag(1/L) @ ((1 + beta)/2 * I + strictly_lower_ones) @ diag(L)
        B[n] = exp(-0.5 * gammaln(1 - alpha)) * beta**((1 - alpha)/2)
               * binom(alpha + n, n) / L_n

    Args:
        alpha: generalized-Laguerre parameter (enters the Gamma/binomial terms).
        beta: sets the diagonal (1 + beta)/2 and the beta**((1-alpha)/2) gain on B.
        N: state size.

    Returns:
        A: (N, N) array.  B: (N, 1) array.
    """
    idx = jnp.arange(N)
    # L_n via a log-Gamma difference, then exponentiated.
    lam = jnp.exp(.5 * (ss.gammaln(idx + alpha + 1) - ss.gammaln(idx + 1)))
    inv_lam = 1. / lam[:, None]

    core = (jnp.eye(N) * ((1 + beta) / 2)) + jnp.tril(np.ones((N, N)), -1)
    A = -inv_lam * core * lam[None, :]

    gain = jnp.exp(-.5 * ss.gammaln(1 - alpha)) * jnp.power(beta, (1 - alpha) / 2)
    B = gain * inv_lam * ss.binom(alpha + idx, idx)[:, None]
    return A, B

## Scaled Legendre (LegS)

In [None]:
# Scaled Legendre (LegS) vectorized 
def build_LegS(N):
    """Vectorized scaled-Legendre (LegS) transition pair ``(A, B)``.

    With n the row index and k the column index (0-based)::

        A[n, k] = -sqrt(2n+1) * sqrt(2k+1)   if n > k
                  -(n+1)                     if n == k
                  0                          if n < k
        B[n]    = sqrt(2n+1)

    so A is lower triangular.

    Returns:
        A: (N, N) array.  B: (N, 1) array.
    """
    q = jnp.arange(N, dtype=jnp.float64)
    # indexing="ij" -> n runs down rows, k across columns. The default "xy"
    # order swapped them and produced the transpose (upper-triangular A).
    n, k = jnp.meshgrid(q, q, indexing="ij")
    B = jnp.sqrt(2*q + 1)[:, None]

    A_base = -jnp.sqrt(2*n + 1) * jnp.sqrt(2*k + 1)
    A = jnp.where(n > k, A_base, 0.0)      # strictly lower triangle
    A = jnp.where(n == k, -(n + 1), A)     # diagonal: -(n+1)
    return A, B

In [None]:
# Scaled Legendre (LegS), non-vectorized
def build_LegS(N):
    """Scaled-Legendre (LegS) pair built as the similarity transform A = D M D^{-1}.

    With n the row index and k the column index (0-based)::

        M[n, k] = -(2k+1)  if n > k,   -(n+1)  if n == k,   0 otherwise
        D       = diag(sqrt(2n+1))

    giving A[n, k] = -sqrt(2n+1) * sqrt(2k+1) below the diagonal, -(n+1) on it
    (lower triangular), and B = diag(D) as an (N, 1) column.
    """
    q = jnp.arange(N, dtype=jnp.float64)  # q = 0, 1, ..., N-1
    # indexing="ij": n varies down rows, k across columns. The default "xy"
    # order swapped them and built the transpose (upper-triangular) matrix.
    n, k = jnp.meshgrid(q, q, indexing="ij")
    M = -(jnp.where(n >= k, 2*k + 1, 0) - jnp.diag(q))
    d = jnp.sqrt(2*q + 1)
    D = jnp.diag(d)
    # D is diagonal, so its inverse is the reciprocal diagonal (no general inv needed).
    A = D @ M @ jnp.diag(1.0 / d)
    B = d[:, None]
    return A, B

## Truncated Fourier (FouT)

In [None]:
# truncated Fourier (FouT) - vectorized
def build_FouT_V(N):
    """Truncated-Fourier (FouT) pair ``(A, B)`` assembled entry-wise from masks.

    The state size is M = 2 * (N // 2). With r the row and c the column index,
    the rules below are applied in priority order (first match wins):

        r == c == 0                          -> -1
        (r == 0, c even) or (c == 0, r even) -> -sqrt(2)
        r even and c even                    -> -2
        c == r + 1 and r even                -> -pi * (c // 2)
        r == c + 1 and c even                ->  pi * (r // 2)
        otherwise                            ->  0

    B holds sqrt(2) at even indices except B[0] = 1, returned as an (M, 1) column.
    """
    size = (N // 2) * 2
    idx = jnp.arange(size, dtype=jnp.float64)
    col, row = jnp.meshgrid(idx, idx)  # default "xy": first output varies across columns
    col_even = col % 2 == 0
    row_even = row % 2 == 0

    # (mask, value) pairs, highest priority first.
    rules = [
        ((col == row) & (col == 0), -1.0),
        (((row == 0) & col_even) | ((col == 0) & row_even), -jnp.sqrt(2)),
        (col_even & row_even, -2),
        (((col - row) == 1) & row_even, -jnp.pi * (col // 2)),
        (((row - col) == 1) & col_even, jnp.pi * (row // 2)),
    ]
    # Fold from lowest to highest priority so earlier rules overwrite later ones.
    A = 0.0
    for mask, value in reversed(rules):
        A = jnp.where(mask, value, A)

    B = jnp.zeros(size, dtype=jnp.float64).at[::2].set(jnp.sqrt(2)).at[0].set(1)
    return A, B[:, None]

In [None]:
# truncated Fourier (FouT) - non-vectorized
def build_FouT(N):
    """Truncated-Fourier (FouT) pair: a banded skew part minus a rank-one term.

    For N >= 2 the state size is 2 * (N // 2). For frequency f = 1, 2, ...
    the skew band sets A[2f, 2f+1] = -pi*f and A[2f+1, 2f] = +pi*f. B has
    sqrt(2) at even indices with B[0] = 1, and B B^T is subtracted from A.

    Returns:
        A: square array.  B: column vector with A.shape[1] rows.
    """
    half = N // 2
    # Interleave (0, f) for f = 0..half-1 and drop the first entry: [0, 0, 1, 0, 2, ...]
    band = jnp.column_stack([jnp.zeros(half), jnp.arange(half)]).ravel()[1:]
    rotation = jnp.pi * (jnp.diag(band, -1) - jnp.diag(band, 1))

    b = jnp.zeros(rotation.shape[1]).at[0::2].set(jnp.sqrt(2)).at[0].set(1)
    A = rotation - jnp.outer(b, b)
    return A, b[:, None]