In [None]:
# Load IPython's autoreload extension; it is configured with `%autoreload 2` in the next cell.
%load_ext autoreload

In [None]:
# Reload all modules (e.g. edits to vector_spherical_harmonics.py) before executing each cell.
%autoreload 2
import itertools
import sys
from typing import Optional

import e3nn_jax as e3nn
import jax
import jax.numpy as jnp
import plotly.graph_objects as go

sys.path.append(".")
from vector_spherical_harmonics import (
    VSHCoeffs,
    cross_product,
    dot_product,
)

# Compact array printing: 3 decimals and no scientific notation for tiny values.
jnp.set_printoptions(precision=3, suppress=True)

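`VSHCoeffs` stores the coefficients of the vector spherical harmonics $Y_{j, l, m_j}$ in a dictionary keyed by `(j, l)`; each entry holds the $2j + 1$ coefficients (one per $m_j$). The cell below lists the irreps for `jmax = 3` with both parities.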

In [None]:
# Irreps of the pseudo-vector (parity=1) and vector (parity=-1) VSH up to jmax = 3.
pvsh_irreps = VSHCoeffs.get_vsh_irreps(3, parity=1)
vsh_irreps = VSHCoeffs.get_vsh_irreps(3, parity=-1)
pvsh_irreps, vsh_irreps

The vector spherical harmonics $Y_{j, l, m_j}$ are obtained by decomposing the tensor product $1 \otimes l$ into the irreducible representations $j = l - 1$, $l$ and $l + 1$ (for $l = 0$ only $j = 1$ appears).

In [None]:
# A single VSH entry: (j, l) = (1, 1) with only its first coefficient set.
single_entry_values = jnp.asarray([1.0, 0.0, 0.0])
coeffs_dict1 = VSHCoeffs(parity=-1)
coeffs_dict1[(1, 1)] = e3nn.IrrepsArray("1e", single_entry_values)
print(coeffs_dict1.to_vector_coeffs())

# Random coefficients for all (j, l) entries up to jmax = 3.
random_key = jax.random.PRNGKey(0)
coeffs_dict1 = VSHCoeffs.normal(jmax=3, parity=-1, key=random_key)
print(coeffs_dict1.to_vector_coeffs())

In [None]:
def plot_vector_signal(
    sig: e3nn.SphericalSignal,
    scale_vec: float = 0.1,
    title: Optional[str] = None,
    sizeref: float = 5,
) -> go.Figure:
    """Plots a vector spherical signal as a field of cones on the sphere.

    Args:
        sig: Vector signal on a spherical grid. It must hold exactly three values
            (x, y, z components) per grid point, laid out so that
            ``sig.grid_values.reshape((3, -1))`` gives one row per component.
        scale_vec: Factor applied to the vectors before plotting.
        title: Optional figure title.
        sizeref: Plotly cone size reference (used with ``sizemode="absolute"``);
            the default is the value previously hard-coded here.

    Returns:
        A ``plotly.graph_objects.Figure`` with a single ``Cone`` trace.

    Raises:
        ValueError: If the signal does not have three values per grid point.
    """
    # grid_vectors is indexed (beta, alpha, xyz); move xyz to the front and flatten
    # the grid so that points line up with the flattened values below.
    grid = sig.grid_vectors.transpose((2, 0, 1)).reshape((3, -1))
    n_points = grid.shape[1]
    # Without this check, reshape((3, -1)) would silently mispair vectors and points
    # whenever the total size merely happens to be divisible by 3.
    if sig.grid_values.size != 3 * n_points:
        raise ValueError(
            f"Expected 3 values per grid point ({3 * n_points} in total), "
            f"got grid_values of shape {sig.grid_values.shape}."
        )
    # NOTE(review): assumes the 3 components form the leading axis of grid_values,
    # i.e. shape (3, res_beta, res_alpha) — confirm against to_vector_signal().
    vectors = scale_vec * sig.grid_values.reshape((3, -1))

    fig = go.Figure()
    fig.add_trace(
        go.Cone(
            x=grid[0, :],
            y=grid[1, :],
            z=grid[2, :],
            u=vectors[0, :],
            v=vectors[1, :],
            w=vectors[2, :],
            colorscale="Viridis",
            sizemode="absolute",
            sizeref=sizeref,
            showscale=True,
            hoverinfo="skip",
        )
    )
    if title is not None:
        fig.update_layout(title=title)
    return fig

In [None]:
# Visualize a single VSH basis function.
j, l, mj = 2, 1, 0
vsh_signal = VSHCoeffs.vector_spherical_harmonics(j, l, mj).to_vector_signal(
    res_beta=40, res_alpha=39, quadrature="soft"
)
plot_vector_signal(vsh_signal, title=f"VSH j={j}, l={l}, mj={mj}")

In [None]:
def get_change_of_basis_matrix(jmax: int, parity: int) -> e3nn.IrrepsArray:
    """Returns the basis reducing (vector irrep) x (scalar SH up to l = jmax - 1).

    The vector irrep is ``1e`` for ``parity=1`` and ``1o`` for ``parity=-1``; the
    scalar factor is ``e3nn.s2_irreps(jmax - 1)``. Coupling l <= jmax - 1 with 1
    yields irreps with j <= jmax.

    Args:
        jmax: Largest j produced; the scalar harmonics go up to l = jmax - 1.
        parity: Parity of the vector irrep, either 1 or -1.

    Returns:
        An ``e3nn.IrrepsArray`` whose ``.array`` is indexed
        ``(vector component, scalar SH component, reduced irreps component)`` and
        whose ``.irreps`` are the irreps of the reduced product.

    Raises:
        ValueError: If ``parity`` is not 1 or -1, or if ``jmax < 1``.
    """
    if parity not in (1, -1):
        raise ValueError(f"Invalid parity {parity}.")
    if jmax < 1:
        raise ValueError(f"jmax must be at least 1, got {jmax}.")

    vector_irrep = e3nn.Irrep(1, parity)
    scalar_irreps = e3nn.s2_irreps(jmax - 1)
    return e3nn.reduced_tensor_product_basis("ij", i=vector_irrep, j=scalar_irreps)

rtp = get_change_of_basis_matrix(3, parity=-1)
print(rtp.shape)

# Random VSH coefficients living in the irreps produced by the change of basis.
basis_key = jax.random.PRNGKey(0)
vsh_coeffs = e3nn.normal(rtp.irreps, key=basis_key)
vsh_coeffs

In [None]:
# Map VSH coefficients to (vector component, scalar SH component) coefficients by
# contracting the last (irreps) index of the basis.
basis_tensor = rtp.array
xyz_coeffs = jnp.einsum("ijk,k->ij", basis_tensor, vsh_coeffs.array)
xyz_coeffs

In [None]:
# Map back by contracting the first two basis indices; the assertion checks that
# this round trip recovers the original VSH coefficients.
recovered_values = jnp.einsum("ijk,ij->k", rtp.array, xyz_coeffs)
vsh_coeffs_new = e3nn.IrrepsArray(rtp.irreps, recovered_values)
assert jnp.allclose(vsh_coeffs.array, vsh_coeffs_new.array)
vsh_coeffs_new

In [None]:
# Scratch: Irreps built from the single irrep with l=0 and parity p=1.
e3nn.Irreps(e3nn.Irrep(0, 1))

In [None]:
# Scratch: a single even scalar (0e) wrapped in an IrrepsArray.
e3nn.IrrepsArray(e3nn.Irreps("0e"), jnp.ones((1,)))

In [None]:
# Scratch: parity of the first l=0 irrep in 1o x s2_irreps(2) (i.e. 1o x (0e + 1o + 2e)).
e3nn.tensor_product("1o", e3nn.s2_irreps(2, p_val=1)).filter(lmax=0)[0].ir.p

In [None]:
# Scratch: the l=0 irreps among the parity=1 VSH irreps for jmax=1.
VSHCoeffs.get_vsh_irreps(jmax=1, parity=1).filter(lmax=0)

In [None]:
# Scratch: sample a random IrrepsArray with mixed irreps and call rechunk() on it.
# NOTE(review): no PRNG key is passed to e3nn.normal — confirm its default-key
# behavior is acceptable for reproducibility.
e3nn.normal("1x0e+2x1o+1x1e+1x2e+1x2o+1x3o").rechunk()

## Reconstruction

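Check that converting VSH coefficients to a vector signal on the sphere and back recovers the original coefficients, i.e. that `to_vector_signal()` / `from_vector_signal()` play the role of `e3nn.to_s2grid()` / `e3nn.from_s2grid()` for vector fields.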

In [None]:
# Round trip: VSH coefficients -> vector signal on the grid -> VSH coefficients.
lmax = 3
rng_key = jax.random.PRNGKey(0)
coeffs_dict1 = VSHCoeffs.normal(jmax=lmax, parity=-1, key=rng_key)
sig1 = coeffs_dict1.to_vector_signal(res_beta=40, res_alpha=39, quadrature="soft")
reconstructed_coeffs_dict1 = VSHCoeffs.from_vector_signal(sig1, lmax, parity=-1)
print(reconstructed_coeffs_dict1)
print(coeffs_dict1)

# Every reconstructed (j, l) entry must match the original in irreps and values.
for j, l in reconstructed_coeffs_dict1.keys():
    reconstructed = reconstructed_coeffs_dict1[(j, l)]
    original = coeffs_dict1[(j, l)]
    assert reconstructed.irreps == original.irreps
    assert jnp.allclose(reconstructed.array, original.array)

## Cross Product

In [None]:
def cross_gaunt_tensor_product(
    coeffs_dict1: VSHCoeffs,
    coeffs_dict2: VSHCoeffs,
    output_lmax: int,
    *,
    res_beta: int = 40,
    res_alpha: int = 39,
    quadrature: str = "soft",
) -> VSHCoeffs:
    """Computes VSH coefficients of the pointwise cross product of two vector fields.

    Both inputs are sampled on the same spherical grid, combined point by point
    with ``cross_product``, and the result is projected back onto VSH coefficients
    up to ``j = output_lmax``.

    Args:
        coeffs_dict1: VSH coefficients of the first field.
        coeffs_dict2: VSH coefficients of the second field.
        output_lmax: Maximum j of the returned coefficients.
        res_beta: Number of grid points in beta used to sample both fields.
        res_alpha: Number of grid points in alpha used to sample both fields.
        quadrature: Grid quadrature passed to ``to_vector_signal``.
            The grid defaults match the grid used throughout this notebook.

    Returns:
        VSH coefficients of the cross product, with parity equal to the product of
        the input parities (e.g. vector x vector -> pseudo-vector, parity 1).
    """
    grid_kwargs = dict(res_beta=res_beta, res_alpha=res_alpha, quadrature=quadrature)
    sig1 = coeffs_dict1.to_vector_signal(**grid_kwargs)
    sig2 = coeffs_dict2.to_vector_signal(**grid_kwargs)
    cross_sig = cross_product(sig1, sig2)
    return VSHCoeffs.from_vector_signal(
        cross_sig, jmax=output_lmax, parity=coeffs_dict1.parity * coeffs_dict2.parity
    )

In [None]:
# Visualizing the cross product of two VSH.
# First field: a single (j, l) = (1, 2) entry with only its first coefficient set.
coeffs_dict1 = VSHCoeffs(parity=-1)
coeffs_dict1[(1, 2)] = e3nn.IrrepsArray("1o", jnp.asarray([1.0, 0.0, 0.0]))
signal1 = coeffs_dict1.to_vector_signal(res_beta=40, res_alpha=39, quadrature="soft")
plot_vector_signal(signal1, title="Signal 1")

In [None]:
# Second field: same (j, l) = (1, 2) entry, with only its second coefficient set.
coeffs_dict2 = VSHCoeffs(parity=-1)
coeffs_dict2[(1, 2)] = e3nn.IrrepsArray("1o", jnp.asarray([0.0, 1.0, 0.0]))
signal2 = coeffs_dict2.to_vector_signal(res_beta=40, res_alpha=39, quadrature="soft")
plot_vector_signal(signal2, title="Signal 2")

In [None]:
# Cross product of the two fields above, projected onto VSH with j <= 2.
cross_product_dict = cross_gaunt_tensor_product(coeffs_dict1, coeffs_dict2, output_lmax=2)
print(cross_product_dict)
cross_signal = cross_product_dict.to_vector_signal(
    res_beta=40, res_alpha=39, quadrature="soft"
)
plot_vector_signal(cross_signal, title="Cross Product")

In [None]:
# Rich display of the cross-product VSH coefficients.
cross_product_dict

In [None]:
# Same cross product with the arguments swapped; since a x b = -(b x a), this is
# expected to be the negative of the previous field.
flipped_cross_product = cross_gaunt_tensor_product(coeffs_dict2, coeffs_dict1, output_lmax=2)
flipped_signal = flipped_cross_product.to_vector_signal(
    res_beta=40, res_alpha=39, quadrature="soft"
)
plot_vector_signal(flipped_signal, title="Flipped Cross Product")

In [None]:
# Rich display of the swapped-order cross-product VSH coefficients.
flipped_cross_product

In [None]:
# PVSH x VSH: a pseudo-vector field (parity=1) crossed with a vector field (parity=-1).
coeffs_dict1 = VSHCoeffs(parity=1)
coeffs_dict1[(2, 2)] = e3nn.normal("2e", jax.random.PRNGKey(0))

coeffs_dict2 = VSHCoeffs(parity=-1)
coeffs_dict2[(2, 3)] = e3nn.normal("2e", jax.random.PRNGKey(1))

pvsh_cross_vsh = cross_gaunt_tensor_product(coeffs_dict1, coeffs_dict2, output_lmax=4)
pvsh_cross_vsh.filter("4e")

In [None]:
# VSH x VSH: both fields are vectors (parity=-1), so the cross product has parity 1.
# NOTE(review): if "VSH x PVSH" was intended, coeffs_dict2 should use parity=1 with a
# "2e" entry for (2, 2) — confirm which case this cell is meant to test.
coeffs_dict1 = VSHCoeffs(parity=-1)
coeffs_dict1[(2, 2)] = e3nn.normal("2o", jax.random.PRNGKey(0))

coeffs_dict2 = VSHCoeffs(parity=-1)
coeffs_dict2[(2, 2)] = e3nn.normal("2o", jax.random.PRNGKey(1))

(cross_gaunt_tensor_product(coeffs_dict1, coeffs_dict2, output_lmax=4).filter("4e"))

## Dot Product

In [None]:
def get_ssh_coeffs(sig: e3nn.SphericalSignal, lmax: int) -> e3nn.IrrepsArray:
    """Returns the scalar spherical harmonic coefficients of a scalar signal.

    Coefficients are returned for every l in [0, lmax] and m in [-l, ..., l].
    The irreps' parities are taken from the signal itself (``sig.p_val`` and
    ``sig.p_arg``). As a result, a pseudo-scalar signal (e.g. the dot product of a
    vector and a pseudo-vector field) is projected onto the matching irreps instead
    of the default even ``0e + 1o + 2e + ...``. For default-parity signals the
    result is unchanged.
    """
    irreps = e3nn.s2_irreps(lmax, p_val=sig.p_val, p_arg=sig.p_arg)
    return e3nn.from_s2grid(sig, irreps=irreps)


# Dot product of vector fields
def dot_gaunt_tensor_product(
    coeffs_dict1: VSHCoeffs,
    coeffs_dict2: VSHCoeffs,
    output_lmax: int,
    *,
    res_beta: int = 40,
    res_alpha: int = 39,
    quadrature: str = "soft",
) -> e3nn.IrrepsArray:
    """Computes scalar-SH coefficients of the pointwise dot product of two vector fields.

    Both inputs are sampled on the same spherical grid, combined point by point
    with ``dot_product``, and the resulting scalar signal is projected onto scalar
    spherical harmonics up to ``l = output_lmax`` via ``get_ssh_coeffs``.

    Args:
        coeffs_dict1: VSH coefficients of the first field.
        coeffs_dict2: VSH coefficients of the second field.
        output_lmax: Maximum l of the returned coefficients.
        res_beta: Number of grid points in beta used to sample both fields.
        res_alpha: Number of grid points in alpha used to sample both fields.
        quadrature: Grid quadrature passed to ``to_vector_signal``.
            The grid defaults match the grid used throughout this notebook.

    Returns:
        Scalar spherical harmonic coefficients of the dot-product signal.
    """
    grid_kwargs = dict(res_beta=res_beta, res_alpha=res_alpha, quadrature=quadrature)
    sig1 = coeffs_dict1.to_vector_signal(**grid_kwargs)
    sig2 = coeffs_dict2.to_vector_signal(**grid_kwargs)
    dot_sig = dot_product(sig1, sig2)
    return get_ssh_coeffs(dot_sig, lmax=output_lmax)

In [None]:
c1 = e3nn.IrrepsArray("1o", jnp.asarray([1.0, 0.0, 0.0]))
c2 = e3nn.IrrepsArray("1o", jnp.asarray([1.0, 1.0, 1.0]))

# Two vector fields, each with a single (j, l) = (1, 2) entry.
coeffs_dict1, coeffs_dict2 = VSHCoeffs(parity=-1), VSHCoeffs(parity=-1)
coeffs_dict1[(1, 2)] = c1
coeffs_dict2[(1, 2)] = c2

# The dot product is symmetric, so both argument orders are expected to agree.
dot_12 = dot_gaunt_tensor_product(coeffs_dict1, coeffs_dict2, output_lmax=4)
dot_21 = dot_gaunt_tensor_product(coeffs_dict2, coeffs_dict1, output_lmax=4)
# NOTE(review): e3nn.cross(c1, c2) is shown for comparison; for a dot-product check
# e3nn.dot(c1, c2) may have been intended — confirm.
dot_12, dot_21, e3nn.cross(c1, c2)

## Full Vector Gaunt Tensor Product

- Map irreps1 -> VSHCoeffs1 with a Linear layer, and similarly irreps2 -> VSHCoeffs2.

- Take the cross product of the two VSH fields to get VSHCoeffs3.

- Finally, map VSHCoeffs3 -> irreps3 with a Linear layer.

In [None]:
import haiku as hk


@hk.without_apply_rng
@hk.transform
def full_gaunt_tensor_product(
    coeffs1: e3nn.IrrepsArray,
    coeffs2: e3nn.IrrepsArray,
    parity1: int,
    parity2: int,
    output_lmax: int,
) -> e3nn.IrrepsArray:
    """Learnable vector Gaunt tensor product of two IrrepsArrays (Haiku-transformed).

    Each input is mapped by an ``e3nn.haiku.Linear`` onto the VSH irreps for its own
    ``irreps.lmax`` and the given parity. The two VSH fields are combined with
    ``cross_gaunt_tensor_product``, and a final ``e3nn.haiku.Linear`` maps the
    result onto the VSH irreps for ``output_lmax`` with parity ``parity1 * parity2``.

    Because of ``hk.transform``/``hk.without_apply_rng``, use ``.init(rng, ...)``
    and ``.apply(params, ...)``. ``parity1``, ``parity2`` and ``output_lmax``
    determine irreps, so they must be static under ``jax.jit``.
    """
    # Haiku names modules by creation order, so reordering the three Linear calls
    # below would rename the parameters in `params`.
    # force_irreps_out=True keeps every VSH irrep in the output, even those not
    # reachable from the input irreps.
    coeffs1 = e3nn.haiku.Linear(
        VSHCoeffs.get_vsh_irreps(coeffs1.irreps.lmax, parity=parity1), force_irreps_out=True
    )(coeffs1)
    coeffs2 = e3nn.haiku.Linear(
        VSHCoeffs.get_vsh_irreps(coeffs2.irreps.lmax, parity=parity2), force_irreps_out=True
    )(coeffs2)

    coeffs_dict1 = VSHCoeffs.from_irreps_array(coeffs1)
    coeffs_dict2 = VSHCoeffs.from_irreps_array(coeffs2)
    cross_product_dict = cross_gaunt_tensor_product(
        coeffs_dict1, coeffs_dict2, output_lmax=output_lmax
    )

    cross_product_coeffs = cross_product_dict.to_irreps_array()
    # NOTE(review): unlike the input projections, this Linear does not pass
    # force_irreps_out=True, so its output may contain only the subset of VSH irreps
    # reachable from its input — confirm this is intended.
    cross_product_coeffs = e3nn.haiku.Linear(
        VSHCoeffs.get_vsh_irreps(output_lmax, parity=parity1 * parity2)
    )(cross_product_coeffs)
    return cross_product_coeffs


# Random scalar-SH inputs (l <= 3) and parameter initialization.
input_irreps = e3nn.s2_irreps(3)
coeffs1 = e3nn.normal(input_irreps, jax.random.PRNGKey(0))
coeffs2 = e3nn.normal(input_irreps, jax.random.PRNGKey(1))

init_rng = jax.random.PRNGKey(0)
params = full_gaunt_tensor_product.init(
    init_rng, coeffs1, coeffs2, parity1=-1, parity2=-1, output_lmax=4
)
params

In [None]:
# parity1, parity2 and output_lmax select irreps (plain Python ints), so they are
# marked static for jit.
jitted_apply = jax.jit(
    full_gaunt_tensor_product.apply,
    static_argnames=["output_lmax", "parity1", "parity2"],
)
jitted_apply(params, coeffs1, coeffs2, parity1=-1, parity2=-1, output_lmax=4)

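## Some optimizations

The cells below inspect the unregrouped tensor-product irreps and how sparse a change-of-basis tensor is, as a starting point for optimizing the VSH conversions.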

In [None]:
# Irreps of 1o x s2_irreps(3), without regrouping equal irreps in the output.
e3nn.tensor_product("1o", e3nn.s2_irreps(3), regroup_output=False)

In [None]:
# Change-of-basis tensor reducing 1o x 4e into irreps.
matrix = e3nn.reduced_tensor_product_basis("ij", i="1o", j="4e")

In [None]:
# Shape of the raw basis array, alongside the IrrepsArray itself.
matrix.array.shape, matrix

In [None]:
# Fraction of (numerically) non-zero entries in the basis, i.e. its density.
nonzero_mask = jnp.logical_not(jnp.isclose(matrix.array, 0))
nonzero_mask.sum() / matrix.array.size