In [2]:
# Reload imported modules automatically before each cell runs, so edits to the
# project sources are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import e3nn_jax as e3nn
import jax
# Switch JAX to 64-bit mode up front; the arrays displayed below are float64.
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

import sys
# Extend the import path two directories up so the `src` package resolves
# (presumably the repository root; this depends on the kernel's working directory).
sys.path.append('../..')

# Project-local tensor product implementations and their grid helpers.
from src.tensor_products import functional
from src.tensor_products import gaunt_tensor_product_utils as gtp_utils

In [4]:
# Disabled alternative inputs: s2_irreps(2) arrays built by hand from four random
# normal coefficients followed by five zeros.
# x1 = e3nn.IrrepsArray(
#     e3nn.s2_irreps(2),
#     jnp.concatenate([jax.random.normal(jax.random.PRNGKey(0), (4,)), jnp.zeros(5)]))
#
# x2 = e3nn.IrrepsArray(
#     e3nn.s2_irreps(2),
#     jnp.concatenate([jax.random.normal(jax.random.PRNGKey(1), (4,)), jnp.zeros(5)]))

# Active inputs: two random IrrepsArrays over s2_irreps(1), seeded with PRNG keys
# 0 and 1 respectively so the notebook is reproducible.
x1, x2 = (
    e3nn.normal(e3nn.s2_irreps(1), jax.random.PRNGKey(seed))
    for seed in (0, 1)
)

In [5]:
# Gaunt tensor product of x1 and x2 via the S2-grid implementation, using
# Gauss-Legendre quadrature on a 100 (beta) x 99 (alpha) grid with the FFT path
# switched off. The result is displayed as the cell output.
tp_s2grid = functional.gaunt_tensor_product_s2grid(
    x1,
    x2,
    res_beta=100,
    res_alpha=99,
    quadrature="gausslegendre",
    p_val1=1,
    p_val2=1,
    s2grid_fft=False,
)
tp_s2grid

1x0e+1x1o+1x2e
[ 2.69984163  0.46553642  0.18777085 -1.25098149  1.78401518 -0.0545106
 -0.45766027 -2.50128504  2.75731263]

In [7]:
# Round-trip check: expand the coefficients of x1 onto a (u, v) grid with the
# y-grid, then contract the conjugated grid signal with the z-grid to get
# coefficients back. Both grids use res_theta=100, res_phi=200.
y1_grid = gtp_utils.compute_y_grid(
    1,
    res_theta=100,
    res_phi=200,
)
z_grid = gtp_utils.compute_z_grid(
    1,
    res_theta=100,
    res_phi=200,
)

# Coefficients -> grid signal.
x1_uv = jnp.einsum("a,auv->uv", x1.array, y1_grid)
# Grid signal -> coefficients (only the real part is compared below).
x1_restored = jnp.einsum("uv,auv->a", jnp.conj(x1_uv), z_grid)

# Show original, restored, and an element-wise closeness mask (atol=5e-3).
(
    x1.array,
    jnp.real(x1_restored),
    jnp.isclose(x1.array, jnp.real(x1_restored), atol=5e-3),
)

(Array([ 0.08086788, -0.38624702, -0.37565558,  1.66897423], dtype=float64),
 Array([ 0.08065686, -0.38624506, -0.37495903,  1.66896573], dtype=float64),
 Array([ True,  True,  True,  True], dtype=bool))

In [8]:
# Compute the 2D-Fourier Gaunt tensor product twice on the same 100 x 99 grid,
# once per convolution backend, and compare the outputs element-wise.
tp_fourier_2D_direct = functional.gaunt_tensor_product_fourier_2D(
    x1,
    x2,
    res_theta=100,
    res_phi=99,
    convolution_type="direct",
)
tp_fourier_2D_fft = functional.gaunt_tensor_product_fourier_2D(
    x1,
    x2,
    res_theta=100,
    res_phi=99,
    convolution_type="fft",
)

# Element-wise agreement mask between the "direct" and "fft" results.
jnp.isclose(
    tp_fourier_2D_direct.array,
    tp_fourier_2D_fft.array,
)



Array([ True,  True,  True,  True,  True,  True,  True,  True,  True],      dtype=bool)

In [17]:
# Equivariance check for the S2-grid tensor product: rotating the output of the
# product should give the same result as taking the product of rotated inputs.
tp_original = functional.gaunt_tensor_product_s2grid(
    x1,
    x2,
    res_beta=100,
    res_alpha=99,
    quadrature="gausslegendre",
    p_val1=1,
    p_val2=1,
    s2grid_fft=False,
)

# Random rotation matrix, seeded for reproducibility.
R = e3nn.rand_matrix(jax.random.PRNGKey(0))

tp_rotated = functional.gaunt_tensor_product_s2grid(
    x1.transform_by_matrix(R),
    x2.transform_by_matrix(R),
    res_beta=100,
    res_alpha=99,
    quadrature="gausslegendre",
    p_val1=1,
    p_val2=1,
    s2grid_fft=False,
)

# Display rotate-after-product next to product-after-rotate for comparison.
(tp_original.transform_by_matrix(R), tp_rotated)

(1x0e+1x1o+1x2e
 [ 2.69984163 -1.34457227 -0.01454009 -0.09407826 -2.58911482  2.65111986
  -0.76403058 -0.37703174 -1.67226989],
 1x0e+1x1o+1x2e
 [ 2.69984163 -1.34457227 -0.01454009 -0.09407826 -2.58911482  2.65111986
  -0.76403058 -0.37703174 -1.67226989])

In [19]:
# Equivariance check for the 2D-Fourier tensor product ("direct" convolution):
# rotating the output of the product should give the same result as taking the
# product of rotated inputs.
# NOTE(review): in the recorded output below, the two results differ by up to
# about 0.3 (largest gaps in the last five components), whereas the s2grid
# check matches to all printed digits. Confirm whether this is discretisation
# error at res_theta=100 / res_phi=99 or a genuine equivariance problem.
tp_original = functional.gaunt_tensor_product_fourier_2D(
    x1,
    x2,
    res_theta=100,
    res_phi=99,
    convolution_type="direct",
)

# Random rotation matrix, seeded for reproducibility.
R = e3nn.rand_matrix(jax.random.PRNGKey(1))

tp_rotated = functional.gaunt_tensor_product_fourier_2D(
    x1.transform_by_matrix(R),
    x2.transform_by_matrix(R),
    res_theta=100,
    res_phi=99,
    convolution_type="direct",
)

# Display rotate-after-product next to product-after-rotate for comparison.
(tp_original.transform_by_matrix(R), tp_rotated)

(1x0e+1x1o+1x2e
 [ 4.79991172  2.19338789  0.70070788 -0.63398606 -6.64275998  0.1874885
  -2.32374008 -1.61593553  1.62083508],
 1x0e+1x1o+1x2e
 [ 4.8060221   2.19399914  0.69835547 -0.634226   -6.72951817  0.34716344
  -2.55758929 -1.80203014  1.90055239])