In [78]:
import e3nn_jax as e3nn
import jax
import jax.numpy as jnp
import plotly.graph_objects as go

In [79]:
def get_change_of_basis_matrix(l: int, j: int) -> jnp.ndarray:
    rtp = e3nn.reduced_tensor_product_basis(
        "ij", i="1o", j=e3nn.s2_irreps(l)[-1], keep_ir=e3nn.Irrep(j, p=(-1) ** (l + 1))
    )
    return rtp


l = 2
print(
    get_change_of_basis_matrix(l, l - 1).shape,
    get_change_of_basis_matrix(l, l).shape,
    get_change_of_basis_matrix(l, l + 1).shape
)

(3, 5, 3) (3, 5, 5) (3, 5, 7)


In [80]:
def to_vector_coeffs(coeffs: e3nn.IrrepsArray, l: int) -> e3nn.IrrepsArray:
    j = coeffs.irreps.lmax
    rtp = get_change_of_basis_matrix(l, j)
    vector_coeffs = jnp.einsum("ijk,k->ij", rtp.array, coeffs.array)
    # TODO: Check parity.
    vector_coeffs = e3nn.IrrepsArray(e3nn.s2_irreps(l)[-1], vector_coeffs) # [3, 2l + 1]
    return vector_coeffs


def to_vector_signal(coeffs: e3nn.IrrepsArray, l: int, res_beta: int = 90, res_alpha: int = 89, quadrature="soft") -> e3nn.SphericalSignal:
    vector_coeffs = to_vector_coeffs(coeffs, l)
    # TODO: Check parity.
    vector_sig = e3nn.to_s2grid(
        vector_coeffs, res_beta=res_beta, res_alpha=res_alpha, quadrature=quadrature, p_val=1, p_arg=-1
    )
    return vector_sig

def vector_spherical_harmonics(l: int, j: int, mj: int) -> e3nn.SphericalSignal:
    if j not in [l - 1, l, l + 1]:
        raise ValueError(f"Invalid j={j} for l={l}.")

    if mj not in range(-j, j + 1):
        raise ValueError(f"Invalid mj={mj} for j={j}.")

    coeffs = e3nn.IrrepsArray(
        e3nn.s2_irreps(j)[-1],
        jnp.asarray([1.0 if i == mj else 0.0 for i in range(-j, j + 1)]),
    )
    return to_vector_signal(coeffs, l)

In [81]:
def plot_vector_signal(sig: e3nn.SphericalSignal, scale_vec: float = 0.1, title: str = None):
    grid = sig.grid_vectors.transpose((2, 0, 1)).reshape((3, -1))
    values = sig.grid_values.reshape((3, -1))

    fig = go.Figure()
    fig.add_trace(
        go.Cone(
            x=grid[0, :],
            y=grid[1, :],
            z=grid[2, :],
            u=scale_vec * values[0, :],
            v=scale_vec * values[1, :],
            w=scale_vec * values[2, :],
            colorscale="Viridis",
            sizemode="absolute",
            sizeref=5,
            showscale=True,
            hoverinfo="skip",
        )
    )
    if title is not None:
        fig.update_layout(title=title)
    return fig

In [82]:
l, j, mj = 1, 1, 1
plot_vector_signal(vector_spherical_harmonics(l, j, mj), title=f"l={l}, j={j}, mj={mj}")

## Reconstruction

Check that we can emulate from_s2grid() and to_s2grid() with VSH.

In [83]:
def wrap_fn_for_vector_signal(fn):
    """vmaps a fn over res_beta and res_alpha axes."""
    fn = jax.vmap(fn, in_axes=-1, out_axes=-1)
    fn = jax.vmap(fn, in_axes=-1, out_axes=-1)
    return fn


def get_vsh_coeffs_at_mj(sig, l_out, j_out, mj_out):
    """Returns the component of Y_{l_out, j_out, mj_out} in the signal sig."""
    vsh_signal = vector_spherical_harmonics(l_out, j_out, mj_out)
    dot_product = sig.replace_values(
        wrap_fn_for_vector_signal(jnp.dot)(sig.grid_values, vsh_signal.grid_values)
    )
    return dot_product.integrate().array.item() / (4 * jnp.pi)


def get_vsh_coeffs(sig, l_out, j_out):
    """Returns the components of Y_{l_out, j_out, mj_out} in the signal sig for all mj_out in [-j_out, ..., j_out]."""
    computed_coeffs = jnp.asarray([get_vsh_coeffs_at_mj(sig, l_out, j_out, mj_out) for mj_out in range(-j_out, j_out + 1)])
    computed_coeffs = e3nn.IrrepsArray(e3nn.s2_irreps(j_out)[-1], computed_coeffs)
    return computed_coeffs

In [84]:
# Reconstruction example
l, j = 4, 5
coeffs1 = e3nn.normal(e3nn.s2_irreps(j)[-1], key=jax.random.PRNGKey(1))
sig1 = to_vector_signal(coeffs1, l)
reconstructed_coeffs1 = get_vsh_coeffs(sig1, l, j)
reconstructed_coeffs1, coeffs1

(1x5o
 [ 0.07545024 -1.0032071  -1.1431501  -0.14357094  0.59042287 -0.26625177
  -0.03499872  0.14507414  0.23655684  1.5627992  -1.1358407 ],
 1x5o
 [ 0.07545026 -1.0032071  -1.1431501  -0.14357094  0.5904228  -0.26625177
  -0.03499871  0.14507416  0.23655683  1.5627991  -1.1358407 ])

## Cross-Product

In [85]:
l, j = 1, 1
coeffs1 = e3nn.normal(e3nn.s2_irreps(j)[-1], key=jax.random.PRNGKey(0))
sig1 = to_vector_signal(coeffs1, l)
plot_vector_signal(sig1, title="Signal 1")

In [86]:
l, j = 1, 1
coeffs2 = e3nn.normal(e3nn.s2_irreps(j)[-1], key=jax.random.PRNGKey(1))
sig2 = to_vector_signal(coeffs2, l)
plot_vector_signal(sig2, title="Signal 2")

In [87]:
cross_sig = sig1.replace_values(
    wrap_fn_for_vector_signal(jnp.cross)(
        sig1.grid_values, sig2.grid_values
    )
)
plot_vector_signal(cross_sig, title="Cross Product")

In [88]:
for l in range(1, 5):
    for j in [l - 1, l, l + 1]:
        cross_sig_coeffs = get_vsh_coeffs(cross_sig, l, j)
        print(cross_sig_coeffs.irreps, cross_sig_coeffs.array.round(2))

1x0e [0.]
1x1o [ 0. -0. -0.]
1x2e [-0.  0.  0.  0.  0.]
1x1o [-0.26 -1.53 -0.77]
1x2e [-0.  0. -0. -0.  0.]
1x3o [-0.  0.  0.  0. -0.  0.  0.]
1x2e [ 0. -0. -0.  0. -0.]
1x3o [ 0. -0. -0. -0.  0. -0.  0.]
1x4e [-0. -0. -0. -0.  0.  0. -0. -0.  0.]
1x3o [-0.  0. -0. -0. -0. -0.  0.]
1x4e [ 0. -0.  0.  0.  0. -0.  0.  0.  0.]
1x5o [-0.  0. -0.  0.  0. -0. -0.  0. -0. -0.  0.]


## Vector Gaunt Tensor Product

In [None]:
# Dot product of vector fields
def vector_gaunt_tensor_product(c:)