In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import e3nn_jax as e3nn
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

import sys
sys.path.append('../..')

from src.tensor_products import functional
from src.tensor_products import gaunt_tensor_product_utils as gtp_utils

In [None]:
# Two random input signals with irreps e3nn.s2_irreps(5) (i.e. up to lmax=5);
# distinct PRNG keys (0 and 1) make the two inputs independent.
x1, x2 = (
    e3nn.normal(e3nn.s2_irreps(5), jax.random.PRNGKey(seed))
    for seed in (0, 1)
)

In [None]:
# Consistency check between grids of different lmax: the coefficient grid for
# degree l1 should equal the grid for lmax restricted to its first
# (l1 + 1)**2 coefficients and cropped to the index window [lo, hi) on both
# trailing axes.
# NOTE(review): the window assumes the trailing (frequency) axes of the lmax
# grid are centred at index 2 * lmax — TODO confirm against
# gtp_utils.compute_y_grid.
l1 = 1
lmax = 3
grid_res = dict(res_theta=10, res_phi=10)  # same resolution for both grids

y_grid_lmax = gtp_utils.compute_y_grid(lmax=lmax, **grid_res)
y_grid_l1 = gtp_utils.compute_y_grid(lmax=l1, **grid_res)

lo, hi = 2 * (lmax - l1), 2 * (lmax + l1) + 1
y_grid_lmax_cropped = y_grid_lmax[:(l1 + 1) ** 2, lo:hi, lo:hi]

# Assert instead of merely displaying the boolean, so a mismatch stops
# Restart & Run All rather than passing unnoticed.
assert jnp.allclose(y_grid_l1, y_grid_lmax_cropped), (
    "y_grid for l1 does not match the cropped y_grid for lmax; "
    f"max abs diff = {jnp.max(jnp.abs(y_grid_l1 - y_grid_lmax_cropped))}"
)

In [None]:
# Test equivariance of the 2D-Fourier Gaunt tensor product: rotating both
# inputs by R should give the same result as rotating the output by R.
@jax.jit
def gaunt_tensor_product_fourier_2D_fn(x1, x2):
    return functional.gaunt_tensor_product_fourier_2D(
        x1, x2, res_theta=6, res_phi=6, convolution_type="direct"
    )


R = e3nn.rand_matrix(jax.random.PRNGKey(3))
tp_original = gaunt_tensor_product_fourier_2D_fn(x1, x2)
tp_rotated = gaunt_tensor_product_fourier_2D_fn(
    x1.transform_by_matrix(R), x2.transform_by_matrix(R)
)

# Shown side by side: (output rotated afterwards, output of rotated inputs).
tp_original.transform_by_matrix(R), tp_rotated

In [None]:
# Gaunt tensor product evaluated via an S2 grid with Gauss-Legendre quadrature.
s2grid_kwargs = dict(
    res_beta=100,
    res_alpha=99,
    quadrature="gausslegendre",
    p_val1=1,
    p_val2=1,
    s2grid_fft=False,
)
tp_s2grid = functional.gaunt_tensor_product_s2grid(x1, x2, **s2grid_kwargs)
tp_s2grid

In [None]:
# Round-trip check: project x1 onto the 2D Fourier grid with y_grid, then map
# it back to coefficients with z_grid and compare against the original.
# The grids must be built for the same lmax as x1: their first axis has
# (lmax + 1)**2 entries, so a smaller lmax (previously hard-coded to 2, i.e.
# 9 entries vs. x1's 36 components) makes the einsum fail on axis "a".
roundtrip_lmax = x1.irreps.lmax
y1_grid = gtp_utils.compute_y_grid(roundtrip_lmax, res_theta=100, res_phi=200)
z_grid = gtp_utils.compute_z_grid(roundtrip_lmax, res_theta=100, res_phi=200)
x1_uv = jnp.einsum("a,auv->uv", x1.array, y1_grid)
x1_restored = jnp.einsum("uv,auv->a", x1_uv.conj(), z_grid)
x1.array, x1_restored.real, jnp.isclose(x1.array, x1_restored.real, atol=5e-3)

In [None]:
# The direct and FFT convolution paths should yield the same tensor product.
fourier_2D_kwargs = dict(res_theta=100, res_phi=99)
tp_fourier_2D_direct, tp_fourier_2D_fft = (
    functional.gaunt_tensor_product_fourier_2D(
        x1, x2, convolution_type=conv, **fourier_2D_kwargs
    )
    for conv in ("direct", "fft")
)
jnp.isclose(tp_fourier_2D_direct.array, tp_fourier_2D_fft.array)

In [None]:
# Test equivariance of the S2-grid Gaunt tensor product: rotating both inputs
# by R should give the same result as rotating the output by R.
s2grid_kwargs = dict(
    res_beta=100,
    res_alpha=99,
    quadrature="gausslegendre",
    p_val1=1,
    p_val2=1,
    s2grid_fft=False,
)
tp_original = functional.gaunt_tensor_product_s2grid(x1, x2, **s2grid_kwargs)

# Use a PRNG key not already consumed (x1 used key 0, x2 key 1): JAX keys must
# not be reused, otherwise R is drawn from the same randomness as x1.
R = e3nn.rand_matrix(jax.random.PRNGKey(4))
tp_rotated = functional.gaunt_tensor_product_s2grid(
    x1.transform_by_matrix(R), x2.transform_by_matrix(R), **s2grid_kwargs
)

# Expected vs. actual, plus the max absolute discrepancy as a single number so
# equivariance can be judged without eyeballing both arrays.
tp_expected = tp_original.transform_by_matrix(R)
tp_expected, tp_rotated, jnp.max(jnp.abs(tp_expected.array - tp_rotated.array))