## Load Packages

In [1]:
import os
import sys

module_path = os.path.abspath(os.path.join("../../../"))
print(f"module_path: {module_path}")
if module_path not in sys.path:
    print(f"Adding {module_path} to sys.path")
    sys.path.append(module_path)

module_path: /home/beegass/Documents/Coding/HiPPO-Jax


In [2]:
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False"
os.environ["TF_FORCE_UNIFIED_MEMORY"] = "1"

In [3]:
## import packages
import jax
import jax.numpy as jnp
import einops
from scipy import special as ss
import numpy as np
import torch
import sys
from functools import partial

In [4]:
from src.utils.util import normalize

  jax.tree_util.register_keypaths(data_clz, keypaths)
  jax.tree_util.register_keypaths(data_clz, keypaths)


In [5]:
sys.setrecursionlimit(10**6)

In [6]:
print(jax.devices())
print(f"The Device: {jax.lib.xla_bridge.get_backend().platform}")

[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]
The Device: gpu


In [7]:
print(f"MPS enabled: {torch.backends.mps.is_available()}")

MPS enabled: False


In [8]:
torch.set_printoptions(linewidth=150)
np.set_printoptions(linewidth=150)
jnp.set_printoptions(linewidth=150)

In [9]:
seed = 1701
key = jax.random.PRNGKey(seed)

In [10]:
np.random.seed(seed)

In [11]:
num_copies = 10
subkeys = jax.random.split(key, num=num_copies)
key = subkeys[0]

In [12]:
# @jax.jit
def legendre_recurrence_old(n, x, max_n):
    """
    Computes the Legendre polynomial of degree n at point x using the recurrence relation.

    Args:
        n: int, the degree of the Legendre polynomial.
        x: float, the point at which to evaluate the polynomial.
        max_n: int, the maximum degree of n in the batch.

    Returns:
        The value of the Legendre polynomial of degree n at point x.
    """
    # Initialize the array to store the Legendre polynomials for all degrees from 0 to max_n
    p = jnp.zeros((max_n + 1,) + x.shape)
    p = p.at[0].set(1.0)  # Set the 0th degree Legendre polynomial
    p = p.at[1].set(x)  # Set the 1st degree Legendre polynomial

    # Compute the Legendre polynomials for degrees 2 to max_n using the recurrence relation
    def body_fun(i, p):
        p_i = ((2 * i - 1) * x * p[i - 1] - (i - 1) * p[i - 2]) / i
        return p.at[i].set(p_i)

    p = jax.lax.fori_loop(2, max_n + 1, body_fun, p)

    return p[n]

In [13]:
# @jax.jit
def eval_legendre_old(n, x, out=None):
    """
    Evaluates the Legendre polynomials of degrees specified in the input array n at the points specified in the input array x.

    Args:
        n: array-like, the degrees of the Legendre polynomials.
        x: array-like, the points at which to evaluate the polynomials.
        out: optional, an output array to store the results.

    Returns:
        An array containing the Legendre polynomial values of the specified degrees at the specified points.
    """
    jax.debug.print("\nStarting eval_legendre_old")

    n = jnp.asarray(n)
    x = jnp.asarray(x)
    n_max = n.max()

    if n.ndim == 1 and x.ndim == 1:
        p = jax.vmap(
            lambda ni: jax.vmap(lambda xi: legendre_recurrence_old(ni, xi, n_max))(x)
        )(n)
        p = jnp.diagonal(
            p
        )  # Get the diagonal elements to match the scipy.special.eval_legendre output
    else:
        p = jax.vmap(
            lambda ni: jax.vmap(lambda xi: legendre_recurrence_old(ni, xi, n_max))(x)
        )(n)

    if out is not None:
        out = jnp.asarray(out)
        out = jnp.copy(p, out=out)
        return out
    else:
        return jnp.squeeze(p)

In [14]:
def clenshaw_legendre(n, x, n_max):
    """
    Evaluate the Legendre polynomial of degree n at point x using Clenshaw's algorithm with JAX's scan function.

    This function makes use of a stable method for evaluating high-degree polynomials, known as Clenshaw's algorithm.
    The method is based on a recurrence relation that is less prone to numerical instability.

    Parameters:
        n (jnp.ndarray): The degree of the Legendre polynomial to be evaluated. Must be a non-negative integer.
        x (jnp.ndarray): The point(s) at which the Legendre polynomial is to be evaluated. Can be a single
                        point (float) or an array of points.

    Returns:
        jnp.ndarray: The value(s) of the Legendre polynomial of degree n, evaluated at point(s) x.
    """

    # def body_fun(carry, coeff_i):
    #     i, (p_im1, p_im2) = carry
    #     p_i = (
    #         2.0 * x * p_im1 - p_im2 + coeff_i
    #     )  # p_i = 2 * x * p_{k-1} - p_{k-2} + c_{N-k+1}
    #     return (i + 1, (p_im1, p_i)), p_i
    
    # p_init = jnp.zeros((2,) + x.shape)
    # p_init = p_init.at[0].set(0)
    # p_init = p_init.at[1].set(0)

    # (_, (p_nim1, p_ni)), p_n = jax.lax.scan(
    #     f=body_fun, init=(1, (p_init[0], p_init[1])), xs=n, reverse=True
    # )

    # return n[0] * 0.5 - p_nim1 + x * p_ni  # \frac{c_{0}}{2} - p_{N-1} + x * p_{N}
    
    # def body_fun(carry, coeff_i):
    #     i, (p_im1, p_im2) = carry
    #     # p_i = (
    #     #     (2.0 * i + 1)/(i + 1) * x * p_im1 - i /(i+1) * p_im2 + coeff_i
    #     # )
    #     p_i = (
    #         ((2*i + 1) * x * p_im2 - i * p_im1) / (i + 1) + coeff_i
    #     )
    #     return (i + 1, (p_im1, p_i)), p_i
    
    # p_init = jnp.zeros((2,) + x.shape)
    # p_init = p_init.at[0].set(0)
    # p_init = p_init.at[1].set(0)

    # (_, (p_nim1, p_ni)), p_n = jax.lax.scan(
    #     f=body_fun, init=(1, (p_init[0], p_init[1])), xs=n, reverse=True
    # )

    # return x * p_ni - p_nim1 + 0.5 * n[0]
    
    # def body_fun(carry, _):
    #     i, (b_i, b_im1) = carry
    #     b_ip1 = (2 * i + 1) * x * b_i / (i + 1) - i * b_im1 / (i + 1)
    #     return (i + 1, (b_ip1, b_i)), None

    # b_init = jnp.ones((2,) + x.shape)
    # b_init = b_init.at[0].set(1)
    # b_init = b_init.at[1].set(x)

    # (_, (b_n, b_nm1)), _ = jax.lax.scan(
    #     f=body_fun, init=(1, (b_init[0], b_init[1])), xs=None, length=(n_max-1)
    # )

    # return jax.lax.cond(n_max == 0, lambda _: b_n, lambda _: b_nm1, None)
    
    
    def body_fun(carry, i):
        b2, b1 = carry
        b = (2.0 * i + 1.0) / (i + 1.0) * x * b1 - i / (i + 1.0) * b2
        return (b1, b), None

    (p_nim1, p_ni), _ = jax.lax.scan(body_fun, (jnp.zeros_like(x), jnp.ones_like(x)), jnp.arange(n_max))

    return p_ni

```python
def clenshaw_legendre(n, x, n_max):
    """
    Evaluate the Legendre polynomial of degree n at point x using Clenshaw's algorithm with JAX's scan function.

    This function makes use of a stable method for evaluating high-degree polynomials, known as Clenshaw's algorithm.
    The method is based on a recurrence relation that is less prone to numerical instability.

    Parameters:
        n (jnp.ndarray): The degree of the Legendre polynomial to be evaluated. Must be a non-negative integer.
        x (jnp.ndarray): The point(s) at which the Legendre polynomial is to be evaluated. Can be a single
                        point (float) or an array of points.

    Returns:
        jnp.ndarray: The value(s) of the Legendre polynomial of degree n, evaluated at point(s) x.
    """
    def body_fun(carry, coeff_i):
        i, (p_im1, p_im2) = carry
        p_i = (
            ((2*i + 1) * x * p_im2 - i * p_im1) / (i + 1) + coeff_i
        )
        return (i + 1, (p_im1, p_i)), p_i
    
    p_init = jnp.zeros((2,) + x.shape)
    p_init = p_init.at[0].set(0)
    p_init = p_init.at[1].set(0)

    (_, (p_nim1, p_ni)), p_n = jax.lax.scan(
        f=body_fun, init=(1, (p_init[0], p_init[1])), xs=n, reverse=True
    )

    return x * p_ni - p_nim1 + 0.5 * n[0]

def legendre_recurrence(n, x, n_max):
    """
    Compute the Legendre polynomials up to degree n_max at a given point or array of points x.

    The function employs the recurrence relation for Legendre polynomials. The Legendre polynomials
    are orthogonal on the interval [-1,1] and are used in a wide array of scientific and mathematical applications.
    This function returns a series of Legendre polynomials evaluated at the point(s) x, up to the degree n_max.

    Parameters:
        n_max (int): The highest degree of Legendre polynomial to compute. Must be a non-negative integer.
        x (jnp.ndarray): The point(s) at which the Legendre polynomials are to be evaluated. Can be a single
                        point (float) or an array of points.

    Returns:
        jnp.ndarray: A sequence of Legendre polynomial values of shape (n_max+1,) + x.shape, evaluated at point(s) x.
                    The i-th entry of the output array corresponds to the Legendre polynomial of degree i.

    Notes:
        The first two Legendre polynomials are initialized as P_0(x) = 1 and P_1(x) = x. The subsequent polynomials
        are computed using the recurrence relation:
        P_{n+1}(x) = ((2n + 1) * x * P_n(x) - n * P_{n-1}(x)) / (n + 1).
    """

    p_init = jnp.zeros((2,) + x.shape)
    p_init = p_init.at[0].set(1.0)  # Set the 0th degree Legendre polynomial
    p_init = p_init.at[1].set(x)  # Set the 1st degree Legendre polynomial

    def body_fun(carry, _):
        i, (p_im1, p_i) = carry
        p_ip1 = ((2 * i + 1) * x * p_i - i * p_im1) / (i + 1)

        return ((i + 1).astype(int), (p_i, p_ip1)), p_ip1

    (_, (_, _)), p_n = jax.lax.scan(
        f=body_fun, init=(1, (p_init[0], p_init[1])), xs=(None), length=(n_max - 1)
    )
    p_n = jnp.concatenate((p_init, p_n), axis=0)

    return p_n[n]

def eval_legendre(n, x, stable=False):
    """
    Evaluate Legendre polynomials of specified degrees at provided point(s).

    This function makes use of a vectorized version of the Legendre polynomial recurrence relation to
    compute the necessary polynomials up to the maximum degree found in 'n'. It then selects and returns
    the values of the polynomials at the degrees specified in 'n' and evaluated at the points in 'x'.

    Parameters:
        n (jnp.ndarray): An array of integer degrees for which the Legendre polynomials are to be evaluated.
                        Each element must be a non-negative integer and the array can be of any shape.
        x (jnp.ndarray): The point(s) at which the Legendre polynomials are to be evaluated. Can be a single
                        point (float) or an array of points. The shape must be broadcastable to the shape of 'n'.

    Returns:
        jnp.ndarray: An array of Legendre polynomial values. The output has the same shape as 'n' and 'x' after broadcasting.
                    The i-th entry corresponds to the Legendre polynomial of degree 'n[i]' evaluated at point 'x[i]'.

    Notes:
        This function makes use of the vectorized map (vmap) functionality in JAX to efficiently compute and select
        the necessary Legendre polynomial values.
    """

    jax.debug.print("\nStarting eval_legendre")

    n = jnp.asarray(n)
    x = jnp.asarray(x)
    n_max = n.max()

    if n.ndim == 1 and x.ndim == 1:
        p = jax.vmap(
            lambda ni: jax.vmap(lambda xi: legendre_recurrence(ni, xi, n_max))(x)
        )(n)
        p = jnp.diagonal(
            p
        )  # Get the diagonal elements to match the scipy.special.eval_legendre output
    else:
        if stable:
            jax.debug.print("Using clenshaw")
            p = jax.vmap(
                lambda ni: jax.vmap(lambda xi: clenshaw_legendre(ni, xi, n_max))(x)
            )(n)
            # p = jax.vmap(clenshaw_legendre, in_axes=(None, 0, None))(n, x, n_max)
        else:
            jax.debug.print("Using recurrence")
            p = jax.vmap(
                lambda ni: jax.vmap(lambda xi: legendre_recurrence(ni, xi, n_max))(x)
            )(n)

    return jnp.squeeze(p)

def test_eval_legendre():
    n = 64
    n = np.arange(n)[:, None]
    
    x = np.arange(0.0, 1, 0.1)
    x = np.linspace(-1, 1, 3000)
    x = x[None, :]
    
    print(f"n: {n}")
    print(f"x: {x}")

    for stable in [False, True]:
        p_jax = eval_legendre(n, x, stable=stable)
        p_scipy = ss.eval_legendre(n, x)
        
        print(f"p_scipy = {p_scipy}")
        print(f"p_scipy shape = {p_scipy.shape}\n")
        
        print(f"p_jax-{stable} = {p_jax}")
        print(f"p_jax-{stable} shape = {p_jax.shape}\n")

        assert np.allclose(p_jax, p_scipy, rtol=1e-4, atol=1e-4), f"Mismatch in Legendre polynomial values for stable={stable}"
        print(f"Test for stable: {stable} - PASSED")

print("Testing eval_legendre...")
test_eval_legendre()
print("All tests passed!")
```

In [15]:
def legendre_recurrence(n, x, n_max):
    """
    Compute the Legendre polynomials up to degree n_max at a given point or array of points x.

    The function employs the recurrence relation for Legendre polynomials. The Legendre polynomials
    are orthogonal on the interval [-1,1] and are used in a wide array of scientific and mathematical applications.
    This function returns a series of Legendre polynomials evaluated at the point(s) x, up to the degree n_max.

    Parameters:
        n_max (int): The highest degree of Legendre polynomial to compute. Must be a non-negative integer.
        x (jnp.ndarray): The point(s) at which the Legendre polynomials are to be evaluated. Can be a single
                        point (float) or an array of points.

    Returns:
        jnp.ndarray: A sequence of Legendre polynomial values of shape (n_max+1,) + x.shape, evaluated at point(s) x.
                    The i-th entry of the output array corresponds to the Legendre polynomial of degree i.

    Notes:
        The first two Legendre polynomials are initialized as P_0(x) = 1 and P_1(x) = x. The subsequent polynomials
        are computed using the recurrence relation:
        P_{n+1}(x) = ((2n + 1) * x * P_n(x) - n * P_{n-1}(x)) / (n + 1).
    """

    p_init = jnp.zeros((2,) + x.shape)
    p_init = p_init.at[0].set(1.0)  # Set the 0th degree Legendre polynomial
    p_init = p_init.at[1].set(x)  # Set the 1st degree Legendre polynomial

    def body_fun(carry, _):
        i, (p_im1, p_i) = carry
        p_ip1 = ((2 * i + 1) * x * p_i - i * p_im1) / (i + 1)

        return ((i + 1).astype(int), (p_i, p_ip1)), p_ip1

    (_, (_, _)), p_n = jax.lax.scan(
        f=body_fun, init=(1, (p_init[0], p_init[1])), xs=(None), length=(n_max - 1)
    )
    p_n = jnp.concatenate((p_init, p_n), axis=0)

    return p_n[n]

In [16]:
def eval_legendre(n, x, stable=False):
    """
    Evaluate Legendre polynomials of specified degrees at provided point(s).

    This function makes use of a vectorized version of the Legendre polynomial recurrence relation to
    compute the necessary polynomials up to the maximum degree found in 'n'. It then selects and returns
    the values of the polynomials at the degrees specified in 'n' and evaluated at the points in 'x'.

    Parameters:
        n (jnp.ndarray): An array of integer degrees for which the Legendre polynomials are to be evaluated.
                        Each element must be a non-negative integer and the array can be of any shape.
        x (jnp.ndarray): The point(s) at which the Legendre polynomials are to be evaluated. Can be a single
                        point (float) or an array of points. The shape must be broadcastable to the shape of 'n'.

    Returns:
        jnp.ndarray: An array of Legendre polynomial values. The output has the same shape as 'n' and 'x' after broadcasting.
                    The i-th entry corresponds to the Legendre polynomial of degree 'n[i]' evaluated at point 'x[i]'.

    Notes:
        This function makes use of the vectorized map (vmap) functionality in JAX to efficiently compute and select
        the necessary Legendre polynomial values.
    """

    jax.debug.print("\nStarting eval_legendre")

    n = jnp.asarray(n)
    x = jnp.asarray(x)
    n_max = n.max()

    if n.ndim == 1 and x.ndim == 1:
        p = jax.vmap(
            lambda ni: jax.vmap(lambda xi: legendre_recurrence(ni, xi, n_max))(x)
        )(n)
        p = jnp.diagonal(
            p
        )  # Get the diagonal elements to match the scipy.special.eval_legendre output
    else:
        if stable:
            jax.debug.print("Using clenshaw")
            p = jax.vmap(
                lambda ni: jax.vmap(lambda xi: clenshaw_legendre(ni, xi, n_max))(x)
            )(n)
        else:
            jax.debug.print("Using recurrence")
            p = jax.vmap(
                lambda ni: jax.vmap(lambda xi: legendre_recurrence(ni, xi, n_max))(x)
            )(n)

    return jnp.squeeze(p)

In [17]:
def test_eval_legendre():
    n = 64
    n = np.arange(n)[:, None]
    
    x = np.arange(0.0, 2, 0.000001)
    x = normalize(x=x)
    # x = np.linspace(-1, 1, 3000)
    x = x[None, :]
    
    # print(f"n: {n}")
    # print(f"x: {x}")

    for stable in [False, True]:
        p_jax = eval_legendre(n, x, stable=stable)
        p_scipy = ss.eval_legendre(n, x)
        
        print(f"p_scipy = {p_scipy}")
        print(f"p_scipy shape = {p_scipy.shape}\n")
        
        print(f"p_jax-{stable} = {p_jax}")
        print(f"p_jax-{stable} shape = {p_jax.shape}\n")

        assert np.allclose(p_jax, p_scipy, rtol=1e-4, atol=1e-4), f"Mismatch in Legendre polynomial values for stable={stable}"
        print(f"Test for stable: {stable} - PASSED")

print("Testing eval_legendre...")
test_eval_legendre()
print("All tests passed!")

Testing eval_legendre...

Starting eval_legendre
Using recurrence
p_scipy = [[ 1.          1.          1.         ...  1.          1.          1.        ]
 [-1.         -0.99999899 -0.99999797 ...  0.99999797  0.99999893  1.        ]
 [ 1.          0.99999696  0.99999392 ...  0.99999392  0.99999678  1.        ]
 ...
 [-1.         -0.99808481 -0.99617145 ...  0.99617145  0.99797221  1.        ]
 [ 1.          0.99802204  0.99604604 ...  0.99604604  0.99790576  1.        ]
 [-1.         -0.99795827 -0.99591863 ...  0.99591863  0.99783824  1.        ]]
p_scipy shape = (64, 2000000)

p_jax-False = [[ 1.          1.          1.         ...  1.          1.          1.        ]
 [-1.         -0.999999   -0.999998   ...  0.999998    0.9999989   1.        ]
 [ 1.          0.9999969   0.9999939  ...  0.9999939   0.9999968   1.        ]
 ...
 [-0.999996   -0.99810416 -0.99620205 ...  0.99620205  0.9979836   0.999996  ]
 [ 0.99999595  0.99804157  0.9960774  ...  0.9960774   0.9979177   0.99999595]

AssertionError: Mismatch in Legendre polynomial values for stable=True