# HiPPO Matrices
---

## Table of Contents
* [Loading In Necessary Packages](#load-packages)
* [Instantiate The HiPPO Matrix](#instantiate-the-hippo-matrix)
    * [Translated Legendre (LegT)](#translated-legendre-legt)
        * [LegT](#legt)
        * [LMU](#lmu)
    * [Translated Laguerre (LagT)](#translated-laguerre-lagt)
    * [Scaled Legendre (LegS)](#scaled-legendre-legs)
    * [Fourier Basis](#fourier-basis)
        * [Fourier Recurrent Unit (FRU)](#fourier-recurrent-unit-fru)
        * [Truncated Fourier (FouT)](#truncated-fourier-fout)
        * [Fourier With Decay (FourD)](#fourier-with-decay-fourd)
* [Gu's Linear Time Invariant (LTI) HiPPO Operator](#gus-hippo-legt-operator)
* [Gu's Scale invariant (LSI) HiPPO Operator](#gus-scale-invariant-hippo-legs-operator)
* [Implementation Of General HiPPO Operator](#implementation-of-general-hippo-operator)
* [Test Generalized Bilinear Transform and Zero Order Hold Matrices](#test-generalized-bilinear-transform-and-zero-order-hold-matrices)
    * [Testing Forward Euler on GBT matrices](#testing-forward-euler-transform-for-lti-and-lsi)
    * [Testing Backward Euler on GBT matrices](#testing-backward-euler-transform-for-lti-and-lsi-on-legs-matrices)
    * [Testing Bidirectional on GBT matrices](#testing-lti-and-lsi-operators-with-bidirectional-transform)
    * [Testing ZOH on GBT matrices](#testing-zoh-transform-for-lti-and-lsi-on-legs-matrices)
* [Testing HiPPO Operators](#test-hippo-operators)
    * [Testing Forward Euler on HiPPO Operators](#testing-lti-and-lsi-operators-with-forward-euler-transform)
    * [Testing Backward Euler on HiPPO Operators](#testing-lti-and-lsi-operators-with-backward-euler-transform)
    * [Testing Bidirectional on HiPPO Operators](#testing-lti-and-lsi-operators-with-bidirectional-transform)
    * [Testing ZOH on HiPPO Operators](#testing-lti-and-lsi-operators-with-zoh-transform)
---


## Load Packages

In [1]:
import os
import sys

module_path = os.path.abspath(os.path.join("../../../"))
print(f"module_path: {module_path}")
if module_path not in sys.path:
    print(f"Adding {module_path} to sys.path")
    sys.path.append(module_path)

module_path: /Users/beegass/Documents/Coding/s4mer


In [2]:
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False"
os.environ["TF_FORCE_UNIFIED_MEMORY"] = "1"

In [3]:
## import packages
import jax
import jax.numpy as jnp
from jax._src.scipy.special import _gen_associated_legendre
import einops
from scipy import special as ss
import numpy as np
import torch
import sys

In [4]:
sys.setrecursionlimit(10**6)

In [5]:
print(jax.devices())
print(f"The Device: {jax.lib.xla_bridge.get_backend().platform}")

[CpuDevice(id=0)]
The Device: cpu


In [6]:
print(f"MPS enabled: {torch.backends.mps.is_available()}")

MPS enabled: True


In [7]:
torch.set_printoptions(linewidth=150)
np.set_printoptions(linewidth=150)
jnp.set_printoptions(linewidth=150)

In [8]:
seed = 1701
key = jax.random.PRNGKey(seed)

In [9]:
np.random.seed(seed)

In [10]:
num_copies = 10
subkeys = jax.random.split(key, num=num_copies)
key = subkeys[0]

In [11]:
# def eval_legendre(x, coeffs, tensor=False):
#     """
#     Evaluate Legendre series at points x.

#     Parameters
#     ----------
#     x : array_like
#         Points at which to evaluate the series.
#     coeffs : array_like
#         Coefficients of the Legendre series. The series is assumed to be of
#         the form ``c[0]*P_0(x) + c[1]*P_1(x) + ... + c[N]*P_N(x)``.
#     tensor : bool, optional
#         If True, return a tensor. If False, return an array. Default is False.

#     Returns
#     -------
#     ndarray
#         The value of the Legendre series at the points `x`.

#     References
#     ----------
#     .. [1] Wikipedia, "Legendre polynomials",
#            https://en.wikipedia.org/wiki/Legendre_polynomials

#     """
#     x = jnp.asarray(x)
#     if tensor:
#         result = jnp.zeros(x.shape + coeffs.shape[0:1], dtype=coeffs.dtype)
#     else:
#         result = jnp.zeros(x.shape, dtype=coeffs.dtype)

#     def body(i, val):
#         # Recurrence relation for Legendre polynomials
#         P0 = jnp.ones_like(x)
#         P1 = x
#         P = jax.lax.cond(i > 1, (2 * i - 1) * x * P1 - (i - 1) * P0 / i, P1)

#         # Add current term to result
#         val = val + coeffs[i] * P

#         return i + 1, val

#     _, result = jax.lax.scan(body, 0, result)

#     return result

In [12]:
# def eval_legendre(x, coeffs, tensor=None):
#     """
#     Evaluate Legendre polynomials at points x using the provided coefficients.

#     Args:
#         x: Input points at which to evaluate the Legendre polynomials.
#         coeffs: Coefficients of the Legendre polynomials.
#         tensor: Optional tensor for use with JAX's JIT compilation.

#     Returns:
#         The value of the Legendre polynomials evaluated at the points x.
#     """
#     if tensor is None:
#         tensor = jnp.zeros_like(x)

#     def body(i, val):
#         true_fun = lambda i: (jnp.sqrt((2 * i - 1) / i), None)
#         false_fun = lambda _: 0.0
#         P1 = jax.lax.cond(i > 1, true_fun, false_fun)
#         for j in range(2, jnp.floor(i) + 1):
#             P0, P1 = P1, P
#             P = ((2 * j - 1) * x * P1 - (j - 1) * P0) / j
#         val = val + coeffs[i] * P
#         return i + 1, val

#     _, result = jax.lax.scan(body, 0, tensor)
#     return result

In [13]:
# def eval_legendre(x, coeffs):
#     n = coeffs.shape[0]
#     tensor = jnp.stack([jnp.ones_like(x), x], axis=-1)
#     for i in range(2, n):
#         P0, P1 = tensor[..., i-1], tensor[..., i-2]
#         true_fun = lambda i, val: (jnp.sqrt((2 * i - 1) / i), val)
#         false_fun = lambda i, val: (0.0, val)
#         P1 = jax.lax.cond(i > 1, (i, P1), true_fun, false_fun)[0]
#         tensor = jnp.concatenate([tensor, ((x * P1 - ((i - 1) / i) * P0)[..., None])], axis=-1)
#     val = jnp.zeros_like(x)
#     body_fun = lambda val, p: (val + coeffs[p[0]] * p[1], (p[0]+1, p[2], p[3]))
#     final_val = jax.lax.scan(body_fun, val, (jnp.arange(n), tensor[..., -1], x, P1))[0]
#     return final_val

In [14]:
# def eval_legendre(x, coeffs):
#     # Initialize P0 and P1
#     P0 = jnp.ones_like(x)
#     P1 = x

#     # Initialize tensor for storing intermediate results
#     tensor = P0[..., None]

#     # Loop through recursion relation and compute Legendre polynomials
#     def body(i, val):
#         true_fun = lambda i, P1: (jnp.sqrt((2 * i - 1) / i) * P1, val + coeffs[i] * P1)
#         false_fun = lambda i, P1: (P1, val)
#         P1, val = jax.lax.cond(i > 1, (i, P1), true_fun, false_fun)
#         tensor = jnp.concatenate([tensor, ((x * P1 - ((i - 1) / i) * P0)[..., None])], axis=-1)
#         P0 = P1
#         return i + 1, val

#     _, result = jax.lax.scan(body, 0, jnp.zeros_like(x))
#     return result

In [15]:
# def eval_legendre(x, coeffs):
#     def body(i_val, x_val):
#         i, val, P1, P0 = i_val
#         true_fun = lambda i, P1: (jnp.sqrt((2 * i - 1) / i) * P1, val + coeffs[i] * P1)
#         false_fun = lambda i, P1: (P1, val)
#         P1, val = jax.lax.cond(i > 1, (i, P1), true_fun, false_fun)
#         tensor = jnp.concatenate([x_val, ((x * P1 - ((i - 1) / i) * P0)[..., None])], axis=-1)
#         P0 = P1
#         return (i + 1, val, P1, P0), tensor

#     _, result = jax.lax.scan(body, (0, jnp.zeros_like(x), jnp.zeros_like(x), jnp.zeros_like(x)), jnp.zeros_like(x))
#     return result

In [16]:
# @jax.jit
# def eval_legendre(x, coeffs):
#     def body(carry, i_val):
#         i, val, P1, P0 = carry
#         x_val = i_val[0]

#         true_fun = lambda i_P1: (jnp.sqrt((2 * i - 1) / i) * i_P1, val + coeffs[i] * i_P1)
#         false_fun = lambda i_P1: (i_P1, val)

#         P1, val = jax.lax.cond(i > 1, P1, true_fun, false_fun)

#         tensor = jnp.concatenate([i_val[1], ((x_val * P1 - ((i - 1) / i) * P0)[..., None])], axis=-1)
#         P0 = P1

#         return (i + 1, val, P1, P0), tensor

#     _, result = jax.lax.scan(body, (0, jnp.zeros_like(x), jnp.zeros_like(x), jnp.zeros_like(x)), (x,))
#     return result

In [17]:
# def eval_legendre(x, coeffs):
#     def body(carry, i_val):
#         i, val, P1, P0 = carry
#         x_val = i_val[0]
#         true_fun = lambda i_P1: (jnp.sqrt((2 * i - 1) / i) * i_P1, val + coeffs[i] * i_P1)
#         false_fun = lambda i_P1: (i_P1, val)
#         P1, new_val = jax.lax.cond(i > 0, P1, true_fun, P0, false_fun)
#         return (i + 1, new_val, P1, P0), P1

#     _, result = jax.lax.scan(body, (jnp.array([0]), jnp.zeros_like(x), jnp.zeros_like(x), jnp.zeros_like(x)), x[:, None])
#     return result.ravel()

In [18]:
# def eval_legendre(x, coeffs):
#     def body(carry, i_val):
#         i, P1, P0, val = carry
#         true_fun = lambda i_P1: (jnp.sqrt((2 * i - 1) / i) * i_P1, val + coeffs[i] * i_P1)
#         false_fun = lambda i_P1: (i_P1, val)
#         P1, new_val = jax.lax.cond(i_val[0] > 0, P1, true_fun, P0, false_fun)
#         return (i + 1, new_val, P1, P0), P1

#     _, result = jax.lax.scan(body, (jnp.array([0]), jnp.zeros_like(x), jnp.zeros_like(x), jnp.zeros_like(x)), x[:, None])
#     return result.ravel()

In [19]:
# def eval_legendre(x, coeffs):
#     """Evaluate Legendre polynomials with coefficients `coeffs` at points `x`."""
#     coeffs = jnp.flip(coeffs)

#     def body(carry, i_val):
#         i, val, P1, P0 = carry
#         true_fun = lambda i_P1: (jnp.sqrt((2 * i - 1) / i) * i_P1, val + coeffs[i] * i_P1)
#         false_fun = lambda i_P1: (i_P1, val)
#         P1, new_val = jax.lax.cond(bool(i > 0), P1, lambda i_P1: (jnp.sqrt((2 * i[0] - 1) / i[0]) * i_P1, val + coeffs[i[0]] * i_P1), P0, lambda i_P1: (i_P1, val))
#         return (i + 1, new_val, P1, P0), P1

#     _, result = jax.lax.scan(body, (jnp.array([0]), jnp.zeros_like(x), jnp.zeros_like(x), jnp.zeros_like(x)), x[:, None])
#     return result.ravel()

In [20]:
# def eval_legendre(x, coeffs):
#     def body(carry, i_val):
#         i, val, P1, P0 = carry
#         P2 = jnp.where(jnp.less(i, 2), x, ((2 * i - 1) * x * P1 - (i - 1) * P0) / i)
#         P1, new_val = jax.lax.cond(i > 0, (P1, val), lambda args: (jnp.sqrt((2 * i - 1) / i) * args[0], args[1] + coeffs[i] * args[0]), (P2, val), lambda args: (args[0], args[1]))
#         return (i + 1, new_val, P1, P2), P1

#     _, result = jax.lax.scan(body, (jnp.array([0]), jnp.zeros_like(x), jnp.zeros_like(x), jnp.zeros_like(x)), x[:, None])
#     return result.ravel()

In [21]:
# def eval_legendre(x, coeffs):
#     def body(carry, i_val):
#         i, val, P1, P0 = carry
#         P2 = jnp.where(jnp.less(i, 2), x, ((2 * i - 1) * x * P1 - (i - 1) * P0) / i)
#         P1 = jax.lax.cond(i > 0, (P1, val), lambda args: (jnp.sqrt((2 * i - 1) / i) * args[0], args[1] + coeffs[i] * args[0]), (P2, val), lambda args: (args[0], args[1]))[0]
#         new_val = P1
#         return (i + 1, new_val, P1, P2), P1
#     _, result = jax.lax.scan(body, (jnp.array([0]), jnp.zeros_like(x), jnp.zeros_like(x), jnp.zeros_like(x)), x[:, None])
#     return result.ravel()

In [22]:
# def eval_legendre(x, coeffs):
#     """Evaluate Legendre polynomials with coefficients `coeffs` at `x`."""
#     def body(carry, i_val):
#         i, val, P1, P0 = carry
#         P2 = jnp.where(jnp.less(i, 2), x, ((2 * i - 1) * x * P1 - (i - 1) * P0) / i)
#         P1 = jax.lax.cond(jnp.squeeze(i > 0), (P1, val), lambda args: (jnp.sqrt((2 * i - 1) / i) * args[0], args[1] + coeffs[i] * args[0]), (P2, val), lambda args: (args[0], args[1]))[0]
#         new_val = P1
#         return (i + 1, new_val, P1, P2), P1

#     _, result = jax.lax.scan(body, (jnp.array([0]), jnp.zeros_like(x), jnp.zeros_like(x), jnp.zeros_like(x)), x[:, None])
#     return result.ravel()

In [23]:
# @jax.jit
# def eval_legendre(x, coeffs):
#     """
#     Evaluate the Legendre polynomials with coefficients `coeffs` at the values in `x`.

#     Args:
#     x: a 1D array of input values
#     coeffs: a 1D array of polynomial coefficients

#     Returns:
#     a 1D array of the same length as `x`, representing the polynomial values
#     """
#     p0 = jnp.ones_like(x)
#     p1 = x
#     if coeffs.size == 0:
#         return p0
#     elif coeffs.size == 1:
#         return coeffs[0] * p1
#     else:
#         for n in range(1, coeffs.size):
#             pn = ((2 * n + 1) * x * p1 - n * p0) / (n + 1)
#             p0 = p1
#             p1 = pn
#         return jnp.dot(jnp.vander(x, coeffs.size), coeffs)

In [24]:
# @jax.jit
# def eval_legendre(x, coeffs):
#     """
#     Evaluate the Legendre polynomials with coefficients `coeffs` at the values in `x`.

#     Args:
#     x: a 1D array of input values
#     coeffs: a 1D array of polynomial coefficients

#     Returns:
#     a 1D array of the same length as `x`, representing the polynomial values
#     """
#     n = coeffs.size
#     p = jnp.zeros_like(x)
#     for i in range(n):
#         # Compute the ith Legendre polynomial using Rodrigues' formula
#         dPdx = jax.grad(lambda x: jnp.power(x*x - 1, i))(x)
#         p += coeffs[i] * dPdx / (2**i * jnp.math.factorial(i))
#     return p

In [25]:
# def eval_legendre(x, coeffs):
#     n = coeffs.size
#     p = jnp.zeros_like(x)
#     for i in range(n):
#         # Define a scalar function that takes x as input and returns the ith Legendre polynomial
#         def legendre_scalar(xx):
#             return jnp.power(xx*xx - 1, i)
#         # Compute the gradient of the scalar function at x
#         dPdx = jax.grad(legendre_scalar)(x)
#         p += coeffs[i] * dPdx / (2**i * jnp.math.factorial(i))
#     return p

In [26]:
# def eval_legendre(x, coeffs, degree=None):
#     # If degree is not specified, use the number of coefficients - 1
#     degree = coeffs.size - 1 if degree is None else degree
#     # Broadcast the inputs to the same shape
#     x = jnp.broadcast_to(x, (degree+1, x.size))
#     coeffs = jnp.broadcast_to(coeffs, (degree+1, coeffs.size))
#     # Define the scalar function for the ith Legendre polynomial using Rodrigues' formula
#     def legendre_scalar(i, xx):
#         return jnp.power(xx*xx - 1, i)
#     # Vectorize the scalar function and compute the gradients for all inputs at once
#     dPdx = jax.vmap(jax.grad(legendre_scalar), in_axes=(0, None))(jnp.arange(degree+1), x)
#     # Compute the polynomial as a linear combination of the gradients
#     p = jnp.sum(coeffs * dPdx, axis=0) / 2**jnp.arange(degree+1) / jnp.array([jnp.math.factorial(i) for i in range(degree+1)])
#     return p

In [27]:
# def eval_legendre(x, coeffs, degree=None):
#     if degree is None:
#         degree = len(coeffs) - 1
#     xx = jnp.asarray(x)
#     p = jnp.zeros_like(xx)
#     for i in range(degree+1):
#         # Define a scalar function of a scalar variable
#         def legendre_scalar(x_i):
#             if i == 0:
#                 return jnp.ones_like(x_i)
#             elif i == 1:
#                 return x_i
#             else:
#                 return ((2*i-1)*x_i*legendre_scalar(x_i) - (i-1)*legendre_scalar(x_i-1)) / i
#         # Vectorize the scalar function and compute the gradients for all inputs at once
#         dPdx = jax.vmap(jax.grad(legendre_scalar), in_axes=(0, None))(jnp.arange(degree+1).astype(float), xx)
#         # Compute the polynomial as a linear combination of the gradients
#         p += coeffs[i] * dPdx[i] / (2**i * jnp.math.factorial(i))
#     return p

In [28]:
# def eval_legendre(x, coeffs, degree=None):
#     if degree is None:
#         degree = len(coeffs) - 1
#     xx = jnp.array(x).reshape((-1, 1))
#     p = jnp.zeros_like(xx)
#     for i in range(degree+1):
#         def legendre_scalar(x_i):
#             if i == 0:
#                 return jnp.ones_like(x_i)
#             elif i == 1:
#                 return x_i
#             else:
#                 return ((2*i-1)*x_i*legendre_scalar(x_i) - (i-1)*legendre_scalar(x_i-1)) / i
#         # Vectorize the scalar function and compute the gradients for all inputs at once
#         dPdx = jax.vmap(jax.grad(legendre_scalar))(xx)
#         # Compute the polynomial as a linear combination of the gradients
#         p += coeffs[i] * dPdx[i] / (2**i * jnp.math.factorial(i))
#     return p

In [29]:
# def legendre_scalar(x):
#     if x < 2:
#         return x
#     p_i_minus_2 = 1
#     p_i_minus_1 = x
#     for i in range(2, x+1):
#         p_i = ((2*i-1)*x*p_i_minus_1 - (i-1)*p_i_minus_2) / i
#         p_i_minus_2 = p_i_minus_1
#         p_i_minus_1 = p_i
#     return p_i

# def eval_legendre(x, coeffs, degree=None):
#     if degree is None:
#         degree = len(coeffs) - 1
#     xx = jnp.array(x)
#     p = jnp.zeros_like(xx)
#     for i in range(degree+1):
#         x_i = xx if i == 0 else ((2 * xx * x_i) - x_im1)
#         x_im1 = xx if i == 1 else x_i
#         # Compute the scalar function and its gradient
#         scalar_val = legendre_scalar(x_i).item()
#         dPdx_scalar = jax.grad(legendre_scalar)(x_i).item()
#         # Compute the polynomial as a linear combination of the gradients
#         p += coeffs[i] * dPdx_scalar / (2**i * jnp.math.factorial(i))
#     return p

In [30]:
# def eval_legendre(x, coeffs, degree=None):
#     """
#     Evaluate a Legendre polynomial of a given degree at a given set of points.

#     Parameters:
#         x (float or ndarray): Points at which to evaluate the polynomial, shape (n_points,).
#         coeffs (ndarray): Coefficients of the polynomial, shape (degree+1,).
#         degree (int): Degree of the polynomial to evaluate (optional).
#             If not provided, the degree is assumed to be len(coeffs) - 1.

#     Returns:
#         ndarray: Values of the Legendre polynomial at the given points, shape (n_points,).
#     """
#     x = jnp.atleast_1d(x)
#     degree = degree or len(coeffs) - 1

#     # Define a function to compute the scalar value of the Legendre polynomial at a single point
#     def legendre_scalar(x_i):
#         p_i_minus_2 = 1
#         p_i_minus_1 = x_i
#         p_i = 0.5 * (3 * x_i**2 - 1)  # default value if degree is less than 2
#         for i in range(2, degree + 1):
#             p_i = ((2*i-1)*x_i*p_i_minus_1 - (i-1)*p_i_minus_2) / i
#             p_i_minus_2 = p_i_minus_1
#             p_i_minus_1 = p_i
#         return p_i.item()

#     # Evaluate the polynomial at each point in x
#     p = jnp.zeros_like(x)
#     for i in range(degree + 1):
#         x_i = x if i == 0 else (2 * x * x_i - x_im1)
#         x_im1 = x if i == 1 else x_i
#         # Compute the scalar function and its gradient
#         scalar_val = legendre_scalar(x_i)
#         dPdx_scalar = jax.grad(legendre_scalar)(x_i).item()
#         # Compute the polynomial as a linear combination of the gradients
#         p += coeffs[i] * dPdx_scalar / (2**i * jnp.math.factorial(i))

#     return p

In [31]:
# @jax.jit
# def legendre_scalar(x_i):
#     p_i_minus_2 = 1.0
#     p_i_minus_1 = x_i
#     p_i = x_i
#     for _ in range(degree - 1):
#         p_i = ((2 * _ + 1) * x_i * p_i_minus_1 - _ * p_i_minus_2) / (_ + 1)
#         p_i_minus_2 = p_i_minus_1
#         p_i_minus_1 = p_i
#     return p_i

# @jax.jit
# def eval_legendre(x, coeffs, degree):
#     x = jnp.asarray(x)
#     coeffs = jnp.asarray(coeffs)
#     p = jnp.zeros_like(x)
#     for i in range(degree):
#         x_i = x ** i
#         # Compute the scalar function and its gradient
#         scalar_val = legendre_scalar(x_i)
#         dPdx_scalar = jax.grad(legendre_scalar)(x_i)
#         # Compute the polynomial as a linear combination of the gradients
#         p += coeffs[i] * dPdx_scalar / (2**i * jnp.math.factorial(i))
#     return p.sum(axis=-1)

In [32]:
# def eval_legendre(x, coeffs, degree=None):
#     x = jnp.asarray(x)
#     coeffs = jnp.asarray(coeffs)
#     degree = degree or len(coeffs) - 1
#     p = jnp.zeros_like(x)
#     for i in range(degree+1):
#         x_i = x ** i
#         # Compute the scalar function and its gradient
#         p += coeffs[i] * x_i * (i % 2 + 1)
#     return p

In [33]:
# def eval_legendre(x, coeffs, degree):
#     coeffs = jnp.asarray(coeffs)
#     p = jnp.zeros_like(x)
#     for i in range(degree+1):
#         x_i = x ** i
#         # Compute the scalar function and its gradient
#         p_i, _ = jax.vmap(jax.value_and_grad(legendre))(i, x)
#         p += coeffs[i] * p_i * x_i
#     return p

# def legendre(i, x):
#     if i == 0:
#         return jnp.ones_like(x)
#     elif i == 1:
#         return x
#     else:
#         return ((2 * i - 1) * x * legendre(i - 1, x) -
#                 (i - 1) * legendre(i - 2, x)) / i

In [34]:
# @jax.jit
# def eval_legendre(x, coeffs, degree):
#     coeffs = jnp.asarray(coeffs)
#     p = jnp.zeros_like(x)
#     for i in range(degree):
#         x_i = x ** i
#         # Compute the scalar function and its gradient
#         p += coeffs[i] * x_i * (i + 0.5) / jnp.sqrt(1 - x ** 2)
#     return p

In [35]:
# def eval_legendre(x, coeffs, degree):
#     """
#     Evaluate the Legendre series with given coefficients and degree at the points `x`.

#     Args:
#     x: array_like, shape (n_samples,), the input values to evaluate the Legendre series.
#     coeffs: array_like, shape (degree+1,), the coefficients of the Legendre series.
#     degree: int, the degree of the Legendre series.

#     Returns:
#     An array of shape (n_samples,) with the evaluation of the Legendre series at the points `x`.
#     """
#     p = jnp.zeros((degree+1, len(x)))
#     p = jnp.expand_dims(p[0], axis=0) + coeffs[1] * jnp.expand_dims(p[1], axis=0)
#     for i in range(2, degree+1):
#         p_next = ((2*i-1)/i)*coeffs[i]*p[-1] - ((i-1)/i)*p[-2]
#         p = jnp.concatenate([p, jnp.expand_dims(p_next, axis=0)])
#     return p[-1]

In [36]:
# @jax.jit
# def eval_legendre(x, coeffs, degree):
#     n = degree
#     p0 = jnp.ones_like(x)
#     p1 = x
#     if n == 0:
#         return coeffs[0] * p0
#     elif n == 1:
#         return coeffs[:2] @ jnp.stack((p0, p1))
#     else:
#         coeffs = coeffs[:n+1]
#         pn = 0
#         pn_minus_1 = p1
#         for k in range(1, n):
#             pn = ((2 * k + 1) * x * pn_minus_1 - k * p0) / (k + 1)
#             p0 = pn_minus_1
#             pn_minus_1 = pn
#         pn = ((2 * n + 1) * x * pn_minus_1 - n * p0) / (n + 1)
#         return jnp.dot(coeffs, jnp.vstack([p0, pn_minus_1, pn]))

In [37]:
# def eval_legendre(x, coeffs, degree):
#     """Evaluate Legendre series with coefficients `coeffs` up to order `degree` at locations `x`."""
#     def legendre_poly(n):
#         if n == 0:
#             return jnp.ones_like(x)
#         elif n == 1:
#             return x
#         else:
#             return ((2 * n - 1) * x * legendre_poly(n - 1) - (n - 1) * legendre_poly(n - 2)) / n
#     legendre_vec = jax.vmap(legendre_poly, (None,), 0)
#     return jnp.sum(coeffs[:degree+1] * legendre_vec(jnp.arange(degree+1)), axis=0)

In [38]:
# def eval_legendre(p, x):
#     n = p.shape[0]
#     x = x[..., None]
#     ap = _gen_associated_legendre(p, p, x)
#     bp = jax.lax.select(jax.numpy.mod(p, 2) == 0, 1.0, x)
#     cp = jax.numpy.sqrt(jax.lax.select(x < -0.5, -1.0, 1.0))
#     return ap[..., n-1] * cp * bp

In [39]:
# def eval_legendre(p, x):
#     p = jnp.asarray(p)
#     n = p.shape[0]
#     x = x[..., None]
#     ap = _gen_associated_legendre(p, p, x)
#     bp = jnp.where(jnp.mod(p, 2) == 0, 1.0, x)
#     cp = jnp.sqrt(jnp.where(x < -0.5, -1.0, 1.0))
#     return ap[..., n-1] * cp * bp

In [40]:
# def eval_legendre(n, x):
#     """
#     Evaluate Legendre polynomial of degree n at x using the SciPy implementation.

#     Args:
#     - n: integer degree of the polynomial
#     - x: float point(s) at which to evaluate the polynomial

#     Returns:
#     - y: float the value(s) of the Legendre polynomial
#     """
#     # ensure x is an array
#     x = jnp.array(x)

#     # set l_max equal to n
#     l_max = n

#     # evaluate the associated Legendre functions
#     is_normalized = False
#     Plm = _gen_associated_legendre(l_max, x, is_normalized)

#     # extract the column corresponding to degree n
#     Pn = Plm[n, :, jnp.newaxis]

#     # scale the result by the normalization factor and remove the first
#     # coefficient, which is zero
#     Nn = jnp.sqrt((2 * n + 1) / 2)
#     Pn = Nn * Pn[1:]

#     return Pn

In [41]:
def eval_legendre_x(n, x):
    """
    Evaluate Legendre polynomial of degree n at x using the SciPy implementation.

    Args:
    - n: integer degree of the polynomial
    - x: float point(s) at which to evaluate the polynomial

    Returns:
    - y: float the value(s) of the Legendre polynomial
    """
    # ensure x is an array
    x = jnp.array(x)

    # set l_max equal to n
    l_max = n

    # evaluate the associated Legendre functions
    is_normalized = False
    Plm = _gen_associated_legendre(l_max, x, is_normalized)

    # extract the column corresponding to degree n
    Pn = Plm[n, :]

    # scale the result by the normalization factor and remove the first
    # coefficient, which is zero
    Nn = jnp.sqrt((2 * n + 1) / 2)
    Pn = Nn * Pn[1:]

    return Pn

In [42]:
# def eval_legendre(n, x):
#     """
#     Evaluate the Legendre polynomials up to order n at the points x.

#     Parameters
#     ----------
#     n : int
#         The highest order of the Legendre polynomial to evaluate.
#     x : numpy.ndarray, shape (m,)
#         The points at which to evaluate the Legendre polynomials.

#     Returns
#     -------
#     y : numpy.ndarray, shape (n+1, m)
#         The values of the Legendre polynomials up to order n at the points x.
#     """
#     # Allocate space for output array
#     y = np.zeros((n+1, len(x)))

#     # Evaluate 0-th order Legendre polynomial
#     y[0, :] = 1.0

#     # Evaluate 1-st order Legendre polynomial
#     if n >= 1:
#         y[1, :] = x

#     # Evaluate higher order Legendre polynomials using recurrence relation
#     for i in range(2, n+1):
#         y[i, :] = ((2*i - 1) * x * y[i-1, :] - (i - 1) * y[i-2, :]) / i

#     return y

In [43]:
# def eval_legendre(n, x):
#     if n == 0:
#         return np.ones_like(x)
#     elif n == 1:
#         return np.array([np.ones_like(x), x])
#     else:
#         # recurrence relation
#         L_n_minus_2 = np.ones_like(x)
#         L_n_minus_1 = x
#         L_n = 0.0
#         for i in range(2, n+1):
#             L_n = ((2*i-1)/i) * x * L_n_minus_1 - ((i-1)/i) * L_n_minus_2
#             L_n_minus_2 = L_n_minus_1
#             L_n_minus_1 = L_n
#         return L_n

In [44]:
def eval_legendre_x(n, x):
    if n == 0:
        return jnp.ones_like(x)
    elif n == 1:
        return jnp.vstack((jnp.ones_like(x), x))
    else:
        L = jnp.zeros((n + 1, len(x)))
        L = L.at[0].set(jnp.ones_like(x))
        L = L.at[1].set(x)
        for i in range(2, n + 1):
            L = L.at[i].set(((2 * i - 1) * x * L[i - 1] - (i - 1) * L[i - 2]) / i)
        return L

In [45]:
def hypergeometric(a, b, c, z, eps=1e-10):
    res = term = jnp.ones_like(z)
    n = 0
    while jnp.any(jnp.abs(term) > eps):
        term *= (a + n) * (b + n) * z / ((c + n) * (n + 1))
        res += term
        n += 1
    return res

In [46]:
def legendre_polynomial(n):
    def inner(x):
        return hypergeometric(-n, n + 1, 1, (1 - x) / 2)

    return jax.vmap(inner)

In [47]:
# def eval_legendre(n, x):
#     n = jnp.asarray(n)
#     x = jnp.asarray(x)
#     result = jnp.zeros((n.size, x.size))
#     result = result.at[0].set(jnp.ones_like(x))
#     if n.size > 1:
#         result = result.at[1].set(x)
#         for i in range(2, n.max() + 1):
#             result = result.at[i].set(((2 * i - 1) * x * result[i - 1] - (i - 1) * result[i - 2]) / i)
#     return result[n]

In [48]:
def eval_legendre_x(n, x):
    n = np.asarray(n)
    x = np.asarray(x)
    if np.isscalar(n):
        n = np.array([n])
    result = np.zeros((n.size, x.size))
    for i, ni in enumerate(n):
        if ni == 0:
            result[i] = np.ones_like(x)
        elif ni == 1:
            result[i] = np.vstack((np.ones_like(x), x))
        else:
            L = np.zeros((ni + 1, len(x)))
            L[0] = np.ones_like(x)
            L[1] = x
            for j in range(2, ni + 1):
                L[j] = ((2 * j - 1) * x * L[j - 1] - (j - 1) * L[j - 2]) / j
            result[i] = L[-1]
    return result

In [49]:
# def eval_legendre(n, x):
#     n = jnp.asarray(n)
#     x = jnp.asarray(x)
#     if jnp.isscalar(n):
#         n = jnp.array([n])
#     result = jnp.zeros((n.size, x.size))
#     for i, ni in enumerate(n):
#         if ni == 0:
#             result = result.at[i].set(jnp.ones_like(x))
#         elif ni == 1:
#             result = result.at[i].set(jnp.vstack((jnp.ones_like(x), x)))
#         else:
#             L = jnp.zeros((ni + 1, len(x)))
#             L = jax.ops.index_update(L, 0, jnp.ones_like(x))
#             L = jax.ops.index_update(L, 1, x)
#             for j in range(2, ni + 1):
#                 L = jax.ops.index_update(L, j, ((2 * j - 1) * x * L[j - 1] - (j - 1) * L[j - 2]) / j)
#             result = jax.ops.index_update(result, i, L[-1])
#     return result

In [50]:
# def eval_legendre(n, x):
#     result = jnp.zeros((len(n), len(x)))
#     print(f"result {result}")
#     print(f"result.shape {result.shape}")

#     for i, ni in enumerate(n):
#         ni_val = int(ni)
#         print(f"ni_val {ni_val}")
#         if ni_val == 0:
#             result = result.at[i].set(jnp.ones_like(x))
#             print(f"when ni_val == 0, result {result}")
#             print(f"when ni_val == 0, result.shape {result.shape}")
#         elif ni_val == 1:
#             ones = jnp.ones_like(x)[jnp.newaxis, :]
#             print(f"when ni_val == 0, ones {ones}")
#             print(f"when ni_val == 0, ones.shape {ones.shape}")
#             xs = x.reshape((1, -1))
#             print(f"when ni_val == 0, xs {xs}")
#             print(f"when ni_val == 0, xs.shape {xs.shape}")
#             result = result.at[i].set(jnp.vstack([ones, xs]))
#             print(f"when ni_val == 1, result {result}")
#             print(f"when ni_val == 1, result.shape {result.shape}")
#         else:
#             L = jnp.zeros((ni_val + 1, len(x)))
#             L = L.at[0].set(jnp.ones_like(x))
#             L = L.at[1].set(x)
#             print(f"else, before for-loop, L {L}")
#             print(f"else, before for-loop, L shape {L.shape}")
#             for j in range(1, ni_val):
#                 L = L.at[j + 1].set(((2*j + 1) * x * L[j] - j * L[j-1]) / (j+1))
#                 print(f"else, inside for-loop, L {L}")
#                 print(f"else, inside for-loop, L shape {L.shape}")
#             result = result.at[i].set(L[ni_val])
#             print(f"else, result {result}")
#             print(f"else, result.shape {result.shape}")
#     return result

In [51]:
# def eval_legendre(n, x):
#     result = jnp.zeros((len(n), len(x)))
#     print(f"result {result}")
#     print(f"result.shape {result.shape}")

#     for i, ni in enumerate(n):
#         ni_val = int(ni)
#         print(f"ni_val {ni_val}")
#         if ni_val == 0:
#             result = result.at[i].set(jnp.ones_like(x))
#             print(f"when ni_val == 0, result {result}")
#             print(f"when ni_val == 0, result.shape {result.shape}")
#         elif ni_val == 1:
#             ones = jnp.ones_like(x)[jnp.newaxis, :]
#             print(f"when ni_val == 1, ones {ones}")
#             print(f"when ni_val == 1, ones.shape {ones.shape}")
#             xs = x.reshape((1, -1))
#             print(f"when ni_val == 1, xs {xs}")
#             print(f"when ni_val == 1, xs.shape {xs.shape}")
#             result = result.at[i].set(jnp.vstack([ones, xs]))
#             print(f"when ni_val == 1, result {result}")
#             print(f"when ni_val == 1, result.shape {result.shape}")
#         else:
#             L = jnp.zeros((ni_val + 1, len(x)))
#             L = L.at[0].set(jnp.ones_like(x))
#             L = L.at[1].set(x)
#             print(f"else, before for-loop, L {L}")
#             print(f"else, before for-loop, L shape {L.shape}")
#             for j in range(1, ni_val):
#                 L = L.at[j + 1].set(((2 * j + 1) * x * L[j] - j * L[j - 1]) / (j + 1))
#                 print(f"else, inside for-loop, L {L}")
#                 print(f"else, inside for-loop, L shape {L.shape}")
#             result = result.at[i].set(jnp.expand_dims(L[ni_val], axis=0))
#             print(f"else, result {result}")
#             print(f"else, result.shape {result.shape}")
#     return result

In [52]:
# def eval_legendre(n, x):
#     L0 = jnp.ones_like(x)
#     L1 = x
#     L = jnp.where(n==0, L0, L1)
#     L = jnp.where(n==1, jnp.vstack((L0, L1)), L)

#     for i in range(2, n+1):
#         L = jnp.where(i==n, ((2*i-1)*x*L - (i-1)*L[:-2]) / i, L)

#     return L

In [53]:
# def test_eval_legendre():
#     # Define test inputs
#     coeffs = jnp.array([1, 2, 3], dtype=jnp.float64)
#     x = jnp.linspace(-1, 1, 11, dtype=jnp.float64)

#     # Compute outputs using JAX
#     jax_result = eval_legendre(x, coeffs)

#     # Compute expected outputs using SciPy
#     scipy_result = ss.eval_legendre(x, coeffs)

#     # Compare results
#     assert jnp.allclose(jax_result, scipy_result)

In [54]:
# def test_eval_legendre():
#     coeffs = jnp.array([1, 2, 3], dtype=jnp.float64)
#     x = jnp.linspace(-1, 1, 11, dtype=jnp.float64)

#     # Compute outputs using JAX
#     jax_result = eval_legendre(x, coeffs)

#     # Compute expected outputs using SciPy
#     scipy_result = ss.eval_legendre(x, coeffs)

#     # Compare results
#     assert jnp.allclose(jax_result, scipy_result)

In [55]:
# # Test the function
# def test_eval_legendre():
#     coeffs = jnp.array([1, 0, -1], dtype=jnp.float64)
#     x = jnp.linspace(-1, 1, 11, dtype=jnp.float64)

#     # Compute outputs using JAX
#     jax_result = eval_legendre(x, coeffs)

#     # Compute expected outputs using SciPy
#     scipy_result = ss.eval_legendre(x, coeffs)

#     # Compare results
#     assert jnp.allclose(jax_result, scipy_result)

In [56]:
# def test_eval_legendre():
#     # Generate input values
#     x = jnp.linspace(-1, 1, 11, dtype=jnp.float32)

#     # Generate random Legendre polynomial coefficients
#     coeffs = jax.random.normal(subkeys[1], shape=(len(x),))

#     # Compute outputs using JAX
#     jax_result = eval_legendre(x, coeffs)
#     # Compute expected outputs using SciPy
#     scipy_result = ss.eval_legendre(x, coeffs)
#     # Compare JAX and SciPy results
#     assert jnp.allclose(jax_result, scipy_result)

In [57]:
# def test_eval_legendre():

#     # Generate random number of Legendre polynomial coefficients and input values
#     num_coeffs = jax.random.randint(subkeys[1], shape=(), minval=1, maxval=10)
#     num_vals = jax.random.randint(subkeys[2], shape=(), minval=1, maxval=20)

#     # Generate random Legendre polynomial coefficients
#     coeffs = jax.random.normal(subkeys[3], shape=(num_coeffs,))

#     # Generate input values
#     x = jnp.linspace(-1, 1, num_vals, dtype=jnp.float64)

#     # Compute JAX outputs
#     jax_result = eval_legendre(x, coeffs)

#     # Compute expected outputs using SciPy
#     scipy_result = ss.eval_legendre(jnp.arange(num_coeffs), x) @ coeffs

#     print(jax_result.shape, scipy_result.shape)

#     # Compare JAX and SciPy results
#     assert jnp.allclose(jax_result, scipy_result)

In [58]:
# def test_eval_legendre():
#     key = jax.random.PRNGKey(0)

#     # Generate random number of Legendre polynomial coefficients and input values
#     num_coeffs = jax.random.randint(key, shape=(), minval=1, maxval=10)
#     num_vals = jax.random.randint(key, shape=(), minval=1, maxval=20)

#     # Generate random Legendre polynomial coefficients
#     coeffs = jax.random.normal(key, shape=(num_coeffs,))

#     # Generate input values
#     x = jnp.linspace(-1, 1, num_vals, dtype=jnp.float64)

#     # Compute expected outputs using NumPy/SciPy
#     expected_result = ss.eval_legendre(jnp.arange(num_coeffs), x) @ coeffs

#     # Compute actual outputs using your function
#     actual_result = eval_legendre(x, coeffs)

#     # Compare expected and actual results using allclose with a tolerance of 1e-12
#     assert jnp.allclose(actual_result, expected_result, rtol=1e-12)

In [59]:
# def test_eval_legendre():

#     # Generate random maximum degree of the polynomial and number of input values
#     degree = jax.random.randint(subkeys[1], shape=(), minval=1, maxval=10)
#     num_vals = jax.random.randint(subkeys[2], shape=(), minval=1, maxval=20)

#     # Generate random Legendre polynomial coefficients
#     coeffs = jax.random.normal(subkeys[3], shape=(degree+1,))

#     # Generate input values
#     x = jnp.linspace(-1, 1, num_vals, dtype=jnp.float64)

#     # Compute expected outputs using NumPy/SciPy
#     expected_result = ss.eval_legendre(jnp.arange(degree+1), x) @ coeffs

#     # Compute actual outputs using your function
#     actual_result = eval_legendre(x, coeffs)

#     # Compare actual and expected results
#     assert jnp.allclose(actual_result, expected_result)

In [60]:
# def test_eval_legendre():

#     # Generate random maximum degree of the polynomial and number of input values
#     degree = jax.random.randint(subkeys[1], shape=(), minval=1, maxval=10)
#     num_vals = jax.random.randint(subkeys[2], shape=(), minval=1, maxval=20)

#     # Generate random Legendre polynomial coefficients
#     coeffs = jax.random.normal(subkeys[3], shape=(degree+1,))

#     # Generate input values
#     x = jnp.linspace(-1, 1, num_vals, dtype=jnp.float64)

#     # Compute expected outputs using NumPy/SciPy
#     expected_result = jnp.zeros_like(x)
#     for i in range(x.size):
#         p = ss.eval_legendre(jnp.arange(coeffs.size), x[i])
#         expected_result[i] = jnp.dot(p, coeffs)

#     # Compute actual outputs using your function
#     actual_result = eval_legendre(x, coeffs)

#     # Compare actual and expected results
#     assert jnp.allclose(actual_result, expected_result)

In [61]:
# def test_eval_legendre():

#     # Generate random maximum degree of the polynomial and number of input values
#     degree = jax.random.randint(subkeys[1], shape=(), minval=1, maxval=10)
#     num_vals = jax.random.randint(subkeys[2], shape=(), minval=1, maxval=20)

#     # Generate random Legendre polynomial coefficients
#     coeffs = jax.random.normal(subkeys[3], shape=(degree+1,))

#     # Generate input values
#     x = jnp.linspace(-1, 1, num_vals, dtype=jnp.float64)

#     # Compute expected outputs using NumPy/SciPy
#     expected_result = jnp.zeros(num_vals)
#     for i in range(x.size):
#         p = ss.eval_legendre(jnp.arange(coeffs.size), x[i])
#         expected_result = expected_result.at[i].set(jnp.dot(p, coeffs))

#     # Compute actual outputs using your function
#     actual_result = eval_legendre(x, coeffs)

#     # Compare actual and expected results
#     assert jnp.allclose(actual_result, expected_result)

In [62]:
# def test_eval_legendre():

#     # Generate random maximum degree of the polynomial and number of input values
#     degree = jax.random.randint(subkeys[1], shape=(), minval=1, maxval=10)
#     num_vals = jax.random.randint(subkeys[2], shape=(), minval=1, maxval=20)

#     # Generate random Legendre polynomial coefficients
#     coeffs = jax.random.normal(subkeys[3], shape=(degree+1,))

#     # Generate input values
#     x = jnp.linspace(-1, 1, num_vals, dtype=jnp.float64)

#     # Compute expected outputs using NumPy/SciPy
#     expected_result = ss.eval_legendre(jnp.arange(degree+1), x) @ coeffs

#     # Compute actual outputs using your function
#     actual_result = eval_legendre(x, coeffs)

#     # Compare actual and expected results
#     assert jnp.allclose(actual_result, expected_result)

In [63]:
# def test_eval_legendre():
#     x = jnp.array([-0.5, 0.0, 0.5, 1.0])
#     degree = 3
#     coeffs = jnp.array([1.0, 0.0, -0.5, 0.0])

#     expected_result = ss.eval_legendre(jnp.arange(degree+1), x) @ coeffs

#     actual_result = eval_legendre(x, coeffs, degree)

#     assert jnp.allclose(actual_result, expected_result)

In [64]:
# def test_eval_legendre():
#     degree = 3
#     x = jnp.linspace(-1, 1, 100)
#     coeffs = jnp.array([1.0, 0.0, -0.5, 0.0])
#     expected_result = ss.eval_legendre(jnp.arange(degree+1), x).T @ coeffs
#     actual_result = eval_legendre(x, coeffs, degree)
#     assert jnp.allclose(actual_result, expected_result)

In [65]:
# def test_eval_legendre():
#     x = jnp.linspace(-1, 1, 100)
#     degree = 3
#     coeffs = jnp.array([1.0, 0.0, -0.5, 0.0])
#     expected_result = ss.eval_legendre(jnp.arange(degree+1), x) @ coeffs.T
#     actual_result = eval_legendre(x, coeffs, degree)
#     assert jnp.allclose(actual_result, expected_result)

In [66]:
# def test_eval_legendre():
#     degree = 3
#     x = jnp.linspace(-1, 1, 100)
#     coeffs = jnp.array([1.0, 0.0, -0.5, 0.0])
#     expected_result = ss.eval_legendre(jnp.arange(degree+1), x) @ coeffs
#     actual_result = eval_legendre(x, coeffs, degree)
#     assert jnp.allclose(actual_result, expected_result)

In [67]:
# def test_eval_legendre():
#     degree = 3
#     x = jnp.linspace(-1, 1, 100)
#     coeffs = jnp.array([1.0, 0.0, -0.5, 0.0])
#     expected_result = ss.eval_legendre(jnp.arange(degree+1), x).T @ coeffs[:degree+1]
#     actual_result = eval_legendre(x, coeffs, degree)
#     assert jnp.allclose(actual_result, expected_result)
#     print("Test passed")

In [68]:
# def test_eval_legendre():
#     x = jnp.linspace(-1, 1, 100)
#     degree = 3
#     coeffs = jnp.array([1.0, 0.0, -0.5, 0.0])
#     expected_result = ss.eval_legendre(jnp.arange(degree+1), x)
#     actual_result = eval_legendre(x, coeffs, degree)
#     assert jnp.allclose(actual_result, expected_result)

In [69]:
# def test_eval_legendre():
#     # Test case
#     p = np.array([0, 1, 2, 3])
#     x = np.array([-1, -0.5, 0, 0.5, 1])
#     actual_output = eval_legendre(p, x)
#     expected_output = ss.eval_legendre(p, x)
#     assert np.allclose(expected_output, actual_output)

In [70]:
# def test_eval_legendre():
#     # Test case
#     x = np.linspace(-1, 1, 10)
#     for n in range(5):
#         y_pred = eval_legendre(n, x)
#         y = ss.eval_legendre(n, x)
#         print(f"n = {n},\ny_pred = {y_pred},\ny = {y}")
#         assert np.allclose(y_pred, y)

In [71]:
# def test_eval_legendre():
#     # Test case
#     x = np.linspace(-1, 1, 10)
#     for n in range(5):
#         y_pred = eval_legendre(n, x)
#         y = ss.eval_legendre(n, x)
#         print(f"n = {n},\ny_pred = {y_pred},\ny = {y}")
#         assert np.allclose(y_pred, y)

In [72]:
# def test_eval_legendre():
#     x = np.linspace(-1, 1, 10)
#     for n in range(5):
#         y = jnp.array(ss.eval_legendre(n, x))
#         y_pred = eval_legendre(n, jnp.array(x))[n]
#         print(f"n = {n}")
#         print(f"y_pred = {y_pred}")
#         print(f"y_pred shape = {y_pred.shape}")
#         print(f"y = {y}")
#         print(f"y shape = {y.shape}")
#         assert jnp.allclose(y_pred, y)

In [73]:
# def test_eval_legendre():
#     x = jax.random.uniform(subkeys[1], shape=(10,))
#     n = jax.random.randint(subkeys[2], minval=0, maxval=10, shape=(10,))
#     y = ss.eval_legendre(n, x)
#     y_pred = eval_legendre(n, x)

#     print(f"y_pred = {y_pred}")
#     print(f"y_pred shape = {y_pred.shape}")
#     print(f"y = {y}")
#     print(f"y shape = {y.shape}")

#     assert jnp.allclose(y_pred, y, rtol=1e-4, atol=1e-7)

In [74]:
# def test_eval_legendre():
#     subkeys = jax.random.split(jax.random.PRNGKey(0), 3)
#     x = jax.random.uniform(subkeys[1], shape=(10,))
#     n = jax.random.randint(subkeys[2], minval=0, maxval=10, shape=(10,))
#     y = ss.eval_legendre(n, x)
#     y_pred = eval_legendre(n, x)
#     print(f"y_pred = {y_pred}")
#     print(f"y_pred shape = {y_pred.shape}")
#     print(f"y = {y}")
#     print(f"y shape = {y.shape}")
#     assert jnp.allclose(y, y_pred)

In [75]:
def eval_legendre_real(n, x):
    jax.debug.print("n: {x1}", x1=n)
    jax.debug.print("x: {x2}", x2=x)

    def cond_one(x):
        return jnp.ones_like(x)

    def cond_two(x):
        return jnp.vstack((jnp.ones_like(x), x))

    def cond_three(n, x):
        L = jnp.zeros((n + 1, len(x)))
        L = L.at[0].set(jnp.ones_like(x))
        L = L.at[1].set(x)
        for i in range(2, n + 1):
            L = L.at[i].set(((2 * i - 1) * x * L[i - 1] - (i - 1) * L[i - 2]) / i)
        return L

    return jax.lax.cond(
        n == 0,
        cond_one,
        jax.lax.cond(n == 1, cond_two, cond_three, operand=(n, x)),
        operand=x,
    )
    # val = None
    # if n == 0:
    #     val = jnp.ones_like(x)
    # elif n == 1:
    #     val = jnp.vstack((jnp.ones_like(x), x))
    # else:
    #     L = jnp.zeros((n + 1, len(x)))
    #     L = L.at[0].set(jnp.ones_like(x))
    #     L = L.at[1].set(x)
    #     for i in range(2, n + 1):
    #         L = L.at[i].set(((2 * i - 1) * x * L[i - 1] - (i - 1) * L[i - 2]) / i)
    #     val = L
    # return val

In [76]:
def test_eval_legendre():
    n = np.array([0, 1, 2, 3])
    x = np.linspace(-1, 1, n.shape[0])

    print(f"x = {x}")
    print(f"x shape = {x.shape}")
    print(f"n = {n}")
    print(f"n shape = {n.shape}")

    y = ss.eval_legendre(n, x)

    print(f"y = {y}")
    print(f"y shape = {y.shape}")

    y_pred = jax.vmap(eval_legendre_real, in_axes=(0, None))(n, x)

    print(f"y_pred = {y_pred}")
    print(f"y_pred shape = {y_pred.shape}")

    assert np.allclose(y, y_pred)

In [77]:
# def test_eval_legendre():
#     x = np.linspace(-1, 1, 10)
#     for n in range(5):
#         y = ss.eval_legendre(n, x)
#         y_pred = eval_legendre(n, x)
#         print(f"n = {n}")
#         print(f"y_pred = {y_pred}")
#         print(f"y_pred shape = {y_pred.shape}")
#         print(f"y = {y}")
#         print(f"y shape = {y.shape}")
#         assert np.allclose(y_pred, y.reshape((-1, y_pred.shape[1])))

In [78]:
test_eval_legendre()

x = [-1.         -0.33333333  0.33333333  1.        ]
x shape = (4,)
n = [0 1 2 3]
n shape = (4,)
y = [ 1.         -0.33333333 -0.33333333  1.        ]
y shape = (4,)
n: 0
n: 1
n: 2
n: 3
x: [-1.         -0.33333333  0.33333333  1.        ]


TypeError: lax.cond: true_fun and false_fun arguments should be callable.