In [20]:
import e3nn
from e3nn import io, o3
import torch
import numpy as np

In [21]:
def spectra_functions(lmax):
    """
    Build a spherical-tensor representation together with its power-spectrum
    and bispectrum maps.

    Parameters:
    lmax (int): Highest spherical-harmonic degree included in the representation.

    Returns:
    tuple: ``(sph, powerspectrum, bispectrum)`` where

        sph: the ``io.SphericalTensor(lmax, p_val=1, p_arg=-1)`` instance.
        powerspectrum: one-argument callable ``x -> P(x, x)``, with ``P`` the
            reduced tensor product symmetric under ``ij=ji``.
        bispectrum: one-argument callable ``x -> B(x, x, x)``, with ``B`` the
            reduced tensor product symmetric under ``ijk=jik=ikj``.

        Both products keep only the '0e' and '0o' output irreps.
    """
    # Output filter shared by both products; copied into a fresh list per call
    # so each product receives its own list object, as in a literal-per-call form.
    scalar_outputs = ('0e', '0o')

    sph = io.SphericalTensor(lmax, p_val=1, p_arg=-1)

    # Second-order product, symmetric under exchange of its two inputs.
    pair_product = o3.ReducedTensorProducts(
        'ij=ji',
        i=sph,
        filter_ir_out=list(scalar_outputs),
    )

    # Third-order product, symmetric under the listed index permutations.
    # Intermediate irreps are limited to those yielded by o3.Irrep.iterator(lmax).
    triple_product = o3.ReducedTensorProducts(
        'ijk=jik=ikj',
        i=sph,
        filter_ir_mid=list(o3.Irrep.iterator(lmax)),
        filter_ir_out=list(scalar_outputs),
    )

    def powerspectrum(x):
        """Feed the same signal into both slots of the pair product."""
        return pair_product(x, x)

    def bispectrum(x):
        """Feed the same signal into all three slots of the triple product."""
        return triple_product(x, x, x)

    return sph, powerspectrum, bispectrum

# Instantiate the helpers once for degree 4. The results are bound to
# notebook-level names; `sph` and `bispectrum` are read by later cells.
lmax = 4
sph, powerspectrum, bispectrum = spectra_functions(lmax)



In [22]:
# Test directions: three unit vectors in the z = 0 plane, 120 degrees apart
# (the vertices of an equilateral triangle centred on the origin).
test_geometry = torch.tensor(
    [
        [1, 0, 0],
        [-0.5, np.sqrt(3) / 2, 0],
        [-0.5, -np.sqrt(3) / 2, 0],
    ],
    dtype=torch.float32,
)

In [23]:
# Convert the test directions into a signal in the `sph` representation, then
# evaluate its bispectrum; the bare last expression is what the cell displays.
# NOTE(review): `with_peaks_at` is assumed to place one peak along each row of
# `test_geometry` — confirm against the e3nn documentation.
test_signal = sph.with_peaks_at(test_geometry)
bispectrum(test_signal)

tensor([ 0.0000e+00,  9.0674e-02,  4.9877e-16,  8.7794e-02,  2.5970e-01,
         6.6256e-02,  1.7634e-16,  1.4285e-15,  1.1986e-15, -3.0292e-02,
         5.2794e-02,  1.8742e-01, -3.7753e-02,  6.2289e-02,  1.7312e-02])