In [12]:
import e3nn_jax as e3nn
from e3nn_jax._src.s2grid import s2_grid, _quadrature_weights_soft
import jax
from jax import numpy as jnp
from matplotlib import pyplot as plt
import numpy as np
import typing

## sample from s2grid

In [37]:
# Grid resolution along beta and along alpha.
res_beta = 30
res_alpha = 51

# Quadrature scheme used by the sampling functions; the alternative is kept for reference.
# quadrature = "gausslegendre"
quadrature = "soft"

# Root PRNG key for the notebook.
key = jax.random.PRNGKey(0)

In [None]:
def sample_s2grid_uniform(res_beta: int, res_alpha: int, num_samples: int, *, quadrature: str, key: jax.random.PRNGKey):
    """
    Take samples from the uniform distribution on the S2 grid.

    The z coordinate is drawn with probability proportional to the quadrature
    weight of its grid row, and the alpha coordinate is drawn uniformly. The two
    draws use independent subkeys derived from ``key``.

    Args:
        res_beta (int): number of grid points along beta (length of the z axis)
        res_alpha (int): number of grid points along alpha
        num_samples (int): number of points to draw
        quadrature (str): "soft" or "gausslegendre"
        key (jax.random.PRNGKey): PRNG key, split internally
    Returns:
        sampled_z (jax.numpy.ndarray): shape ``(num_samples,)``
        sampled_alpha (jax.numpy.ndarray): shape ``(num_samples,)``
    Raises:
        ValueError: if ``quadrature`` is not "soft" or "gausslegendre"
    """
    # Resolve the quadrature weights first so an unsupported name fails with a clear error.
    if quadrature == "soft":
        qw = _quadrature_weights_soft(res_beta // 2) * res_beta**2  # [b]
    elif quadrature == "gausslegendre":
        _, qw = np.polynomial.legendre.leggauss(res_beta)
        qw = qw / 2
    else:
        raise ValueError(f'quadrature must be "soft" or "gausslegendre", got {quadrature!r}')

    zs, alphas = s2_grid(res_beta, res_alpha, quadrature=quadrature)
    # zs has shape res_beta, alphas has shape res_alpha

    # Never feed the same key to two random calls: the draws would not be independent.
    key_z, key_alpha = jax.random.split(key)
    sampled_z = jax.random.choice(key_z, zs, shape=(num_samples,), p=qw)
    sampled_alpha = jax.random.choice(key_alpha, alphas, shape=(num_samples,))

    return sampled_z, sampled_alpha

In [33]:
def sample_s2grid(x: jnp.ndarray, num_samples: int, *, quadrature: str, key: jax.random.PRNGKey):
    r"""
    Take samples from a signal on the S2 grid.

    Grid positions are drawn once and shared by all leading (batch) entries of ``x``:
    the beta index is drawn with probability proportional to its quadrature weight,
    the alpha index uniformly. The two draws use independent subkeys derived from ``key``.

    Args:
        x (`jax.numpy.ndarray`): signal on the sphere of shape ``(..., res_beta, res_alpha)``
        num_samples (int): the number of samples to take from x
        quadrature (str): "soft" or "gausslegendre"
        key (jax.random.PRNGKey): PRNG key, split internally
    Returns:
        x_samples (`jax.numpy.ndarray`): samples taken from x, shape ``(..., num_samples)``
    Raises:
        ValueError: if ``quadrature`` is not "soft" or "gausslegendre"
    """
    res_beta, res_alpha = x.shape[-2:]
    if quadrature == "soft":
        qw = _quadrature_weights_soft(res_beta // 2) * res_beta**2  # [b]
    elif quadrature == "gausslegendre":
        _, qw = np.polynomial.legendre.leggauss(res_beta)
        qw = qw / 2
    else:
        raise ValueError(f'quadrature must be "soft" or "gausslegendre", got {quadrature!r}')

    # Never feed the same key to two random calls: the index draws would not be independent.
    key_beta, key_alpha = jax.random.split(key)
    sampled_beta_i = jax.random.choice(key_beta, jnp.arange(res_beta), shape=(num_samples,), p=qw)
    sampled_alpha_i = jax.random.choice(key_alpha, jnp.arange(res_alpha), shape=(num_samples,))

    # Paired advanced indexing on the last two axes picks one value per (beta, alpha) draw.
    return x[..., sampled_beta_i, sampled_alpha_i]

In [None]:
# Your function needs to take as argument a probability distribution sampled on the grid on the sphere

def sample_s2grid(proba, num_samples: int, *, quadrature: str, key: jax.random.PRNGKey):
    """
    Draw grid points of the S2 grid according to a probability density given on that grid.

    Sampling is done in two stages: first a z row is drawn from the marginal
    p_z(z) = integral_alpha proba(z, alpha) dalpha (a sum over alpha weighted by the
    quadrature weights), then an alpha column is drawn from the conditional
    distribution proportional to ``proba[z, :]``.

    Args:
        proba: non-negative values of the density on the grid, a 2D array of
            shape ``(res_beta, res_alpha)``; it does not need to be normalized
        num_samples (int): number of points to draw
        quadrature (str): "soft" or "gausslegendre"
        key (jax.random.PRNGKey): PRNG key, split internally
    Returns:
        sampled_z (jax.numpy.ndarray): shape ``(num_samples,)``
        sampled_alpha (jax.numpy.ndarray): shape ``(num_samples,)``
    Raises:
        ValueError: if ``proba`` is not 2D or ``quadrature`` is not supported
    """
    # integral_s2 proba(x) dx = 1
    # integral_z integral_alpha proba(z, alpha) dz dalpha = 1
    proba = jnp.asarray(proba)
    if proba.ndim != 2:
        raise ValueError(f"proba must have shape (res_beta, res_alpha), got {proba.shape}")
    # Take the resolution from the input itself rather than from notebook globals.
    res_beta, res_alpha = proba.shape

    if quadrature == "soft":
        qw = _quadrature_weights_soft(res_beta // 2) * res_beta**2  # [b]
    elif quadrature == "gausslegendre":
        _, qw = np.polynomial.legendre.leggauss(res_beta)
        qw = qw / 2
    else:
        raise ValueError(f'quadrature must be "soft" or "gausslegendre", got {quadrature!r}')

    zs, alphas = s2_grid(res_beta, res_alpha, quadrature=quadrature)
    key_z, key_alpha = jax.random.split(key)

    # p_z(z) = integral_alpha proba(z, alpha) dalpha   # implemented with a sum and qw
    p_z = jnp.sum(proba, axis=-1) / res_alpha * qw  # [res_beta]
    # Sample row *indices* (not z values) so they can be used to index proba below.
    z_i = jax.random.choice(key_z, res_beta, shape=(num_samples,), p=p_z)

    # Conditional draw of alpha given z by inverse-CDF sampling, one row per sample.
    # Scaling the uniform draw by the last cumulative value means the rows do not
    # need to be normalized.
    p_cuml = jnp.cumsum(proba[z_i], axis=-1)  # [num_samples, res_alpha]
    u = jax.random.uniform(key_alpha, shape=(num_samples, 1), dtype=p_cuml.dtype)
    r = p_cuml[:, -1:] * (1 - u)
    # Number of cumulative entries strictly below r == index of the first entry >= r.
    alpha_i = jnp.minimum(jnp.sum(p_cuml < r, axis=-1), res_alpha - 1)

    sampled_z = jnp.asarray(zs)[z_i]
    sampled_alpha = jnp.asarray(alphas)[alpha_i]
    return sampled_z, sampled_alpha

## sanity check

In [38]:
# Use separate subkeys for building the test signal and for sampling from it,
# instead of feeding the same key to both random calls.
signal_key, sample_key = jax.random.split(key)
signal = jax.random.uniform(signal_key, (10, res_beta, res_alpha))
x_samples = sample_s2grid(signal, 5, quadrature=quadrature, key=sample_key)

In [39]:
# The signal was built with shape (10, res_beta, res_alpha).
signal.shape

(10, 30, 51)

In [40]:
# 5 samples were requested from a signal with leading dimension 10.
x_samples.shape

(10, 5)