In [124]:
import e3nn_jax as e3nn
import jax
import jax.numpy as jnp
import plotly.graph_objects as go

In [125]:
l = 3   # degree of each Cartesian component's spherical-harmonic expansion
l2 = 3  # degree of the coefficient irrep; must be l - 1, l or l + 1 (and >= 0)
m2 = 1  # order of the one-hot coefficient; between -l2 and l2

# Enforce the constraints above up front: otherwise an out-of-range m2 silently
# yields an all-zero "one-hot" vector, and a bad l2 fails deep inside e3nn.
assert abs(l - 1) <= l2 <= l + 1, f"l2={l2} must satisfy |l - 1| <= l2 <= l + 1 (l={l})"
assert -l2 <= m2 <= l2, f"m2={m2} must lie in [-l2, l2] (l2={l2})"

def get_change_of_basis_matrix(l, l2):
    """Return the basis coupling ``1o x (degree l)`` into a single degree-``l2`` irrep.

    This is ``e3nn.reduced_tensor_product_basis`` for a rank-2 tensor ``ij``.
    The first index is a Cartesian vector (``1o``). The second index is the
    degree-``l`` irrep ``e3nn.s2_irreps(l)[-1]``. Only the output irrep of
    degree ``l2`` is kept.

    Args:
        l: Degree of the spherical-harmonic index ``j``.
        l2: Degree of the kept output irrep. Coupling degree 1 with degree ``l``
            only produces degrees ``|l - 1| .. l + 1``.

    Returns:
        The basis as an ``e3nn.IrrepsArray`` of shape ``(3, 2l + 1, 2l2 + 1)``,
        e.g. ``(3, 7, 5)`` for ``l = 3, l2 = 2``.

    Raises:
        ValueError: If ``l2`` is outside ``|l - 1| .. l + 1``.
    """
    if not abs(l - 1) <= l2 <= l + 1:
        # Fail with a clear message here instead of deferring to whatever
        # reduced_tensor_product_basis / the later einsum does with an
        # unreachable irrep.
        raise ValueError(
            f"l2={l2} cannot be obtained from 1o x degree {l}; "
            f"need {abs(l - 1)} <= l2 <= {l + 1}")
    # Parity of the kept irrep: -1 from the 1o factor times (-1)**l from the
    # degree-l factor (assumes the default s2_irreps parity convention; verify).
    keep = e3nn.Irrep(l2, p=(-1) ** (l + 1))
    return e3nn.reduced_tensor_product_basis(
        'ij', i='1o', j=e3nn.s2_irreps(l)[-1], keep_ir=keep)

tuple(get_change_of_basis_matrix(l, l + dl).shape for dl in (-1, 0, 1))

((3, 7, 5), (3, 7, 7), (3, 7, 9))

In [126]:
# One-hot coefficients: 1.0 at the position of m2 when enumerating m = -l2..l2,
# 0.0 everywhere else.
coeffs = e3nn.IrrepsArray(
    e3nn.s2_irreps(l2)[-1],
    jnp.asarray([float(m == m2) for m in range(-l2, l2 + 1)]),
)
coeffs.shape

(7,)

In [127]:
def to_vector_coeffs(coeffs, l):
    """Expand degree-l2 coefficients into per-component degree-l coefficients.

    ``l2`` is read from ``coeffs.irreps.lmax``. The change-of-basis tensor for
    ``1o x l -> l2`` is contracted with ``coeffs`` along its last axis.

    Returns:
        An ``e3nn.IrrepsArray`` with irreps ``e3nn.s2_irreps(l)[-1]``. Its
        leading axis of size 3 indexes the vector components.
    """
    basis = get_change_of_basis_matrix(l, coeffs.irreps.lmax)
    # (3, 2l + 1, 2l2 + 1) contracted with (2l2 + 1,) -> (3, 2l + 1)
    components = jnp.einsum('ijk,k->ij', basis.array, coeffs.array)
    # TODO: Check parity.
    return e3nn.IrrepsArray(e3nn.s2_irreps(l)[-1], components)


def to_vector_signal(coeffs, l, res_beta=90, res_alpha=49, quadrature="soft"):
    """Evaluate the vector field encoded by ``coeffs`` on a spherical grid.

    Args:
        coeffs: ``e3nn.IrrepsArray`` holding a single degree-``l2`` irrep
            (passed through ``to_vector_coeffs``).
        l: Spherical-harmonic degree of each Cartesian component.
        res_beta: Number of grid points in beta, forwarded to ``e3nn.to_s2grid``.
        res_alpha: Number of grid points in alpha, forwarded to ``e3nn.to_s2grid``.
        quadrature: Quadrature name forwarded to ``e3nn.to_s2grid``.

    Returns:
        An ``e3nn.SphericalSignal``. The 3 vector components presumably sit on
        the leading axis of ``grid_values`` (inherited from the
        ``(3, 2l + 1)`` coefficients); confirm when changing ``to_vector_coeffs``.
    """
    vector_coeffs = to_vector_coeffs(coeffs, l)
    # p_val=1, p_arg=-1 presumably matches the default parity convention of
    # e3nn.s2_irreps used for the coefficient irreps -- verify.
    # TODO: Check parity.
    return e3nn.to_s2grid(
        vector_coeffs,
        res_beta=res_beta,
        res_alpha=res_alpha,
        quadrature=quadrature,
        p_val=1,
        p_arg=-1,
    )

In [128]:
def plot_vector_signal(sig: e3nn.SphericalSignal, scale_vec: float = 0.1,
                       sizeref: float = 5, title: str = ""):
    """Draw a vector-valued spherical signal as a field of 3D cones.

    One cone is placed at every grid point, pointing along the local vector.

    Args:
        sig: Spherical signal whose ``grid_values`` have shape
            ``(3, res_beta, res_alpha)``. Component 0/1/2 is drawn along x/y/z.
        scale_vec: Factor applied to the vector components before plotting.
        sizeref: Plotly cone ``sizeref`` (used with ``sizemode="absolute"``).
        title: Figure title; the layout title is left untouched when empty.

    Returns:
        The ``plotly.graph_objects.Figure`` containing the cone trace.

    Raises:
        ValueError: If ``grid_values`` is not ``(3,)`` followed by the grid shape.
    """
    # grid_vectors is rank 3 with xyz on the last axis.
    expected_shape = (3,) + tuple(sig.grid_vectors.shape[:-1])
    if tuple(sig.grid_values.shape) != expected_shape:
        # Without this check, reshape((3, -1)) below silently scrambles any
        # signal whose size merely happens to be divisible by 3.
        raise ValueError(
            f"expected grid_values of shape {expected_shape}, "
            f"got {tuple(sig.grid_values.shape)}")

    # Move xyz to the front so points line up with the flattened values.
    grid = sig.grid_vectors.transpose((2, 0, 1)).reshape((3, -1))
    values = sig.grid_values.reshape((3, -1))

    fig = go.Figure()
    fig.add_trace(go.Cone(x=grid[0, :],
                          y=grid[1, :],
                          z=grid[2, :],
                          u=scale_vec * values[0, :],
                          v=scale_vec * values[1, :],
                          w=scale_vec * values[2, :],
                          colorscale='Viridis',
                          sizemode="absolute",
                          sizeref=sizeref,
                          showscale=True,
                          hoverinfo='skip'))
    if title:
        fig.update_layout(title=title)
    return fig

In [129]:
# Vector field generated by the one-hot coefficients defined above.
plot_vector_signal(to_vector_signal(coeffs, l=l))

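# Cross product of two vector signals

Take the pointwise cross product of two random vector fields on the sphere, then project the result back onto spherical-harmonic coefficients.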

In [130]:
# First random field: e3nn.normal coefficients of degree l2, fixed seed 0.
coeffs1 = e3nn.normal(e3nn.s2_irreps(l2)[-1], key=jax.random.PRNGKey(0))
sig1 = to_vector_signal(coeffs1, l)
plot_vector_signal(sig1)

In [131]:
# Second random field: e3nn.normal coefficients of degree l2 + 1, fixed seed 1.
# NOTE(review): l2 + 1 only stays within the allowed degrees l - 1..l + 1
# while l2 <= l; confirm before changing l2 in the config cell.
coeffs2 = e3nn.normal(e3nn.s2_irreps(l2 + 1)[-1], key=jax.random.PRNGKey(1))
sig2 = to_vector_signal(coeffs2, l)
plot_vector_signal(sig2)

In [132]:
# Pointwise cross product on the grid. Axis 0 of grid_values holds the three
# Cartesian components, so jnp.cross's `axis` argument replaces a nested vmap.
cross_sig = sig1.replace_values(
    jnp.cross(sig1.grid_values, sig2.grid_values, axis=0)
)
plot_vector_signal(cross_sig)

In [133]:
# Each Cartesian component of sig1/sig2 is a single degree-l expansion
# (to_vector_coeffs uses s2_irreps(l)[-1]). Products of two degree-l harmonics
# have degree <= 2l, so lmax = 2 * l captures the cross product exactly instead
# of relying on a hard-coded cutoff.
cross_sig_coeffs = e3nn.from_s2grid(cross_sig, e3nn.s2_irreps(2 * l))
cross_sig_coeffs.array.max(), cross_sig_coeffs.array.min()

(Array(1.6106452, dtype=float32), Array(-2.7065215, dtype=float32))