In [19]:
import e3nn_jax as e3nn
import jax
import optax
import jax.numpy as jnp

In [20]:
def powerspectrum(x: e3nn.IrrepsArray) -> jax.Array:
    """
    Compute the power spectrum (degree-2 invariants) of an array of irreps.

    ``x`` is contracted twice against the reduced symmetric tensor-product
    basis of ``x.irreps``, keeping only the rotation-invariant ``l = 0``
    output components (``0e`` scalars and ``0o`` pseudoscalars).

    Parameters:
        x (e3nn.IrrepsArray): Input array of irreducible representations.
            Any leading axes of ``x.array`` are treated as batch axes.

    Returns:
        jax.Array: A plain array (not an ``IrrepsArray``) of shape
        ``x.array.shape[:-1] + (n,)``, where ``n`` is the number of kept
        invariant components.
    """
    basis = e3nn.reduced_symmetric_tensor_product_basis(x.irreps, 2, keep_ir=["0o", "0e"])
    # `...` applies the same bilinear contraction independently to every
    # leading (batch) index of x; for a 1-D x it is the plain contraction.
    return jnp.einsum("...i,...j,ijz->...z", x.array, x.array, basis.array)


def bispectrum(x: e3nn.IrrepsArray) -> jax.Array:
    """
    Compute the bispectrum (degree-3 invariants) of an array of irreps.

    ``x`` is contracted three times against the reduced symmetric
    tensor-product basis of ``x.irreps``, keeping only the rotation-invariant
    ``l = 0`` output components (``0e`` scalars and ``0o`` pseudoscalars).

    Parameters:
        x (e3nn.IrrepsArray): Input array of irreps.
            Any leading axes of ``x.array`` are treated as batch axes.

    Returns:
        jax.Array: A plain array (not an ``IrrepsArray``) of shape
        ``x.array.shape[:-1] + (n,)``, where ``n`` is the number of kept
        invariant components.
    """
    basis = e3nn.reduced_symmetric_tensor_product_basis(x.irreps, 3, keep_ir=["0o", "0e"])
    # `...` applies the same trilinear contraction independently to every
    # leading (batch) index of x; for a 1-D x it is the plain contraction.
    return jnp.einsum("...i,...j,...k,ijkz->...z", x.array, x.array, x.array, basis.array)


def trispectrum(x: e3nn.IrrepsArray) -> jax.Array:
    """
    Compute the trispectrum (degree-4 invariants) of an array of irreps.

    ``x`` is contracted four times against the reduced symmetric
    tensor-product basis of ``x.irreps``, keeping only the rotation-invariant
    ``l = 0`` output components (``0e`` scalars and ``0o`` pseudoscalars).
    Note that the basis has ``x.irreps.dim ** 4 * n`` entries, so it grows
    quickly with the input dimension.

    Parameters:
        x (e3nn.IrrepsArray): Input array of irreps.
            Any leading axes of ``x.array`` are treated as batch axes.

    Returns:
        jax.Array: A plain array (not an ``IrrepsArray``) of shape
        ``x.array.shape[:-1] + (n,)``, where ``n`` is the number of kept
        invariant components.
    """
    basis = e3nn.reduced_symmetric_tensor_product_basis(x.irreps, 4, keep_ir=["0o", "0e"])
    # `...` applies the same quadrilinear contraction independently to every
    # leading (batch) index of x; for a 1-D x it is the plain contraction.
    return jnp.einsum(
        "...i,...j,...k,...l,ijklz->...z",
        x.array,
        x.array,
        x.array,
        x.array,
        basis.array,
    )


def with_peaks_at(vectors, lmax, *, rtol=1e-5):
    """
    Compute a spherical harmonics expansion with peaks at the given vectors.

    This is a band-limited stand-in for a sum of Dirac delta functions on the
    sphere. The spherical harmonics of degrees ``0..lmax`` are sampled at the
    direction of each vector. The expansion is the linear combination of those
    samples whose dot product with the samples at each direction reproduces
    the norm of the corresponding vector. The combination weights are found
    by least squares on the Gram matrix of the samples.

    Parameters:
        vectors (jnp.ndarray): Array of shape ``(N, 3)``; each row is a vector.
            Its direction sets a peak position and its norm the peak height.
        lmax (int): The maximum degree of the spherical harmonics expansion.
        rtol (float): Relative tolerance, measured against the largest peak
            height, for the check that the fitted expansion reproduces the
            peak heights. Default ``1e-5``.

    Returns:
        e3nn.IrrepsArray: The expansion coefficients, with irreps
        ``e3nn.s2_irreps(lmax)``.

    Raises:
        ValueError: If the fit misses some peak height by more than
            ``rtol * max(norms)`` or the residual is NaN. This can happen,
            e.g., for distinct peak heights in the same direction, or more
            vectors than ``lmax`` can resolve.
    """
    # Peak heights: the Euclidean norm of each row (a zero row gets height 0).
    values = jnp.linalg.norm(vectors, axis=1)

    # Harmonics of each degree 0..lmax evaluated at every direction,
    # concatenated along the last axis -> shape (N, (lmax + 1) ** 2).
    directions = e3nn.IrrepsArray("1o", vectors)
    coeff = jnp.concatenate(
        [
            e3nn.spherical_harmonics(degree, directions, normalize=True).array
            for degree in range(lmax + 1)
        ],
        axis=1,
    )

    # Gram matrix of the samples: gram[a, b] = coeff[a] . coeff[b].
    gram = coeff @ coeff.T
    solution = jnp.linalg.lstsq(gram, values)[0]

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    # `not (r <= tol)` also rejects a NaN residual. `<=` rather than `<` means a
    # zero residual against a zero tolerance (all heights 0) is not rejected.
    residual = jnp.max(jnp.abs(values - gram @ solution))
    tolerance = rtol * jnp.max(jnp.abs(values))
    if not (residual <= tolerance):
        raise ValueError(
            "with_peaks_at: expansion does not reproduce the peak heights "
            f"(max residual {float(residual):.3g}, tolerance {float(tolerance):.3g}); "
            "try a larger lmax or fewer / distinct vectors."
        )

    # Expansion = sum_a solution[a] * coeff[a]; by construction
    # coeff[b] . sh_expansion == (gram @ solution)[b] ~= values[b].
    sh_expansion = solution @ coeff

    return e3nn.IrrepsArray(e3nn.s2_irreps(lmax), sh_expansion)

In [21]:
# Three unit vectors in the xy-plane spaced 120 degrees apart (the vertices
# of an equilateral triangle centred on the origin): (1, 0, 0) followed by
# (-1/2, +sqrt(3)/2, 0) and (-1/2, -sqrt(3)/2, 0).
true_geometry = jnp.array(
    [[1, 0, 0]]
    + [[-0.5, sign * jnp.sqrt(3) / 2, 0] for sign in (1, -1)]
)

In [22]:
# Band-limited (lmax = 4) signal peaked at the three triangle directions;
# the cell's last expression, its bispectrum, is echoed as the cell output.
test_signal = with_peaks_at(vectors=true_geometry, lmax=4)
bispectrum(x=test_signal)

Array([ 2.0354760e-03,  5.7810184e-17,  1.9708409e-03,  5.8298265e-03,
        1.4873402e-03,  2.0438986e-17,  1.1122715e-16,  1.6390467e-16,
       -6.8000465e-04,  1.1851409e-03,  4.2073145e-03, -8.4749062e-04,
        1.3982885e-03,  3.8863564e-04,  0.0000000e+00], dtype=float32)