In [1]:
import e3nn_jax as e3nn
from e3nn_jax._src.s2grid import s2_grid, _quadrature_weights_soft, to_s2grid
import jax
from jax import numpy as jnp
from matplotlib import pyplot as plt
import numpy as np
import typing
import plotly
import plotly.graph_objects as go

from sampling import sample_on_s2grid

## Sample points on the sphere from a density defined on an S2 grid

In [2]:
# Grid resolution: res_beta rings of constant y and res_alpha azimuthal
# samples per ring; both are passed to s2_grid / to_s2grid below.
res_beta = 30
res_alpha = 51
# Quadrature rule shared by grid construction, the normalisation of exp(f)
# and the per-ring sampling weights, so that all of them agree.
quadrature = "gausslegendre"
# Single root PRNG key; the sampling cell splits it into one key per sample.
key = jax.random.PRNGKey(0)



In [3]:
# 1-D grid coordinates: y_orig for the rings, alpha_orig for the azimuths.
y_orig, alpha_orig = s2_grid(res_beta, res_alpha, quadrature=quadrature)

# 2-D mesh (rows follow y, columns follow alpha) embedded on the unit sphere,
# with y as the polar axis and alpha the azimuth around it.
y, alpha = np.meshgrid(y_orig, alpha_orig, indexing='ij')
x = jnp.sqrt(1 - y**2) * jnp.cos(alpha)
z = jnp.sqrt(1 - y**2) * jnp.sin(alpha)

# Coefficients of the log-density f, listed irrep by irrep
# (1 + 3 + 5 + 7 = 16 entries); only the middle 1o component is non-zero.
# Earlier experiment noted as producing a ring:
#   e3nn.IrrepsArray("0e+1o+2e", jnp.array([0, 0, 100, 0, 0, 0, -100.0, 0, 0]))
c = e3nn.IrrepsArray(
    "0e+1o+2e+3o",
    jnp.array(
        [0.0]  # 0e
        + [0.0, 10.0, 0.0]  # 1o
        + [0.0] * 5  # 2e
        + [0.0] * 7  # 3o
    ),
)

# Evaluate the signal on the (res_beta, res_alpha) grid.
f = e3nn.to_s2grid(c, (res_beta, res_alpha), quadrature=quadrature)


In [4]:
# Unnormalised weights exp(f) on the grid.
p = jnp.exp(f)
# Scalar (presumably l = 0) coefficient of exp(f), used as the normaliser.
# NOTE(review): the original comment asked whether this is the integral of
# exp(f) over S2; depending on from_s2grid's normalization convention it may
# instead be the spherical mean — TODO confirm.
Z = e3nn.from_s2grid(p, 0, p_val=1, p_arg=1, quadrature=quadrature).array[0]
# Normalised density on the grid.
p = p / Z

In [6]:
# Per-ring quadrature weights qw, chosen to match the rule used to build the grid.
if quadrature == "soft":
    # NOTE(review): the res_beta**2 factor presumably undoes the scaling inside
    # _quadrature_weights_soft so these weights sum to 1 like the
    # Gauss-Legendre branch — TODO confirm.
    qw = _quadrature_weights_soft(res_beta // 2) * res_beta**2  # [b]
elif quadrature == "gausslegendre":
    # Gauss-Legendre weights on [-1, 1] sum to 2; halve them so they sum to 1.
    _, gl_weights = np.polynomial.legendre.leggauss(res_beta)
    qw = gl_weights / 2
else:
    # Fail loudly here instead of leaving qw undefined (NameError later) or,
    # after a re-run, silently stale from a previous configuration.
    raise ValueError(
        f"unsupported quadrature {quadrature!r}; expected 'soft' or 'gausslegendre'"
    )

In [7]:
# Draw 500 independent samples: one PRNG key per sample, vmapped over the keys.
keys = jax.random.split(key, num=500)
sampled_y_i, sampled_alpha_i = jax.vmap(
    lambda sample_key: sample_on_s2grid(sample_key, p, y_orig, alpha_orig, qw)
)(keys)

# The sampler yields grid indices; map them back to grid coordinates.
sampled_y = y_orig[sampled_y_i]
sampled_alpha = alpha_orig[sampled_alpha_i]

# Embed the samples on the unit sphere with the same convention as the grid.
sampled_x = jnp.sqrt(1 - sampled_y**2) * jnp.cos(sampled_alpha)
sampled_z = jnp.sqrt(1 - sampled_y**2) * jnp.sin(sampled_alpha)

In [8]:
# Log-density f (p is proportional to exp(f)) coloured on the unit sphere.
fig = go.Figure(data=[go.Surface(x=x, y=y, z=z, surfacecolor=f, colorbar=dict(title="f"))])
fig.update_layout(title="Log-density f on the S2 grid")
fig.show()

In [9]:
# Samples overlaid on the normalised density p. The density surface is shrunk
# to this radius so that the samples, which lie on the unit sphere, sit
# outside it.
surface_radius = 0.8
fig = go.Figure(
    data=[
        go.Surface(
            x=x * surface_radius,
            y=y * surface_radius,
            z=z * surface_radius,
            surfacecolor=p,
            colorbar=dict(title="p"),
            name="density p",
        ),
        go.Scatter3d(
            x=sampled_x,
            y=sampled_y,
            z=sampled_z,
            mode="markers",
            marker_size=1,
            name="samples",
        ),
    ]
)
fig.update_layout(title=f"{len(sampled_x)} samples drawn from the density p")
fig.show()