In [None]:
from typing import Tuple, Dict, Any
import itertools

import e3nn_jax as e3nn
import jax
import jax.numpy as jnp
import plotly.graph_objects as go

# Display arrays with 3 decimals, using fixed-point notation for small values.
jnp.set_printoptions(precision=3, suppress=True)

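We store the coefficients of each vector spherical harmonic (VSH) $Y_{j, l, m_j}$ in a dictionary keyed by $(j, l)$; each value is an `IrrepsArray` holding the $2j + 1$ coefficients for $m_j = -j, \dots, j$.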

In [None]:
# Maps (j, l) -> IrrepsArray holding a single irrep of degree j with parity (-1) ** (l + 1).
CoeffsDict = Dict[Tuple[int, int], e3nn.IrrepsArray]


def vsh_iterator(lmax: int):
    """Yields the (j, l) label of every VSH with 0 <= l <= lmax.

    For each l the candidates are j = l - 1, l, l + 1 (in that order); for
    l = 0 only j = 1 is produced.
    """
    for l in range(lmax + 1):
        candidates = (1,) if l == 0 else (l - 1, l, l + 1)
        for j in candidates:
            yield j, l


def create_zeros_coeffs_dict(lmax: int) -> CoeffsDict:
    """Builds a {(j, l): zero IrrepsArray} mapping for every VSH with l <= lmax.

    Entries are inserted in `vsh_iterator` order; each holds the irrep j with
    parity (-1) ** (l + 1).
    """
    zeros = {}
    for j, l in vsh_iterator(lmax):
        parity = (-1) ** (l + 1)
        zeros[(j, l)] = e3nn.zeros(e3nn.Irrep(j, parity))
    return zeros


def create_random_coeffs_dict(lmax: int, key: jax.random.PRNGKey) -> CoeffsDict:
    """Creates a dictionary of random (normal) coefficients for each VSH.

    Args:
        lmax: maximum orbital degree l (see `vsh_iterator`).
        key: JAX PRNG key. It is split once per VSH, and each entry is drawn
            from its own fresh subkey.

    Returns:
        A CoeffsDict with the same keys and irreps as
        `create_zeros_coeffs_dict(lmax)`.
    """
    coeffs_dict = create_zeros_coeffs_dict(lmax)
    for j, l in coeffs_dict.keys():
        # Split *before* sampling so that no key is used both to draw values
        # and to derive further keys (JAX key-reuse rule).
        key, subkey = jax.random.split(key)
        coeffs_dict[(j, l)] = e3nn.normal(coeffs_dict[(j, l)].irreps, subkey)
    return coeffs_dict


def check_coeffs_dict(coeffs_dict: CoeffsDict):
    """Checks that the coefficients dictionary is well-formed.

    Every entry (j, l) -> v must hold exactly one irrep, of degree j and
    parity (-1) ** (l + 1), and (j, l) must be a valid VSH label, i.e.
    |l - 1| <= j <= l + 1 (the same set `vsh_iterator` produces).

    Raises:
        AssertionError: describing the first malformed entry.
    """
    for (j, l), v in coeffs_dict.items():
        # Report the actual count (`Irreps.count` is a method, not a number).
        assert (
            v.irreps.num_irreps == 1
        ), f"Invalid count {v.irreps.num_irreps} for VSH {j, l}."
        mul, ir = v.irreps[0]
        # |l - 1| rather than l - 1: for l = 0 only j = 1 is a valid VSH.
        assert abs(l - 1) <= j <= l + 1, f"Invalid j={j} for VSH {j, l}."
        assert mul == 1, f"Invalid multiplicity {mul} for VSH {j, l}."
        assert ir.l == j, f"Invalid l={ir.l} for VSH {j, l}."
        assert ir.p == (-1) ** (l + 1), f"Invalid p={ir.p} for VSH {j, l}."


def get_lmax(coeffs_dict: CoeffsDict) -> int:
    """Largest orbital degree l among the (j, l) keys of coeffs_dict."""
    degrees = [l for (_j, l) in coeffs_dict]
    return max(degrees)


def get_vsh_irreps(lmax: int) -> e3nn.Irreps:
    """Irreps of all VSH with l <= lmax, in `vsh_iterator` order.

    The entry for (j, l) is the irrep j with parity (-1) ** (l + 1).
    """
    irreps_list = []
    for j, l in vsh_iterator(lmax):
        irreps_list.append(e3nn.Irrep(j, (-1) ** (l + 1)))
    return e3nn.Irreps(irreps_list)


# Conversion to standard IrrepsArrays, laid out in `vsh_iterator` order.
def coeffs_dict_to_irreps_array(coeffs_dict: CoeffsDict):
    """Converts a dictionary of VSH coefficients to an IrrepsArray.

    Entries are concatenated in `vsh_iterator` order up to the dictionary's
    lmax, regardless of insertion order. Any (j, l) missing from the
    dictionary is filled with zeros. The output irreps therefore always equal
    `get_vsh_irreps(lmax)`, which is the layout `irreps_array_to_coeffs_dict`
    expects.

    Raises:
        ValueError: if `coeffs_dict` is empty or contains a key that is not
            produced by `vsh_iterator(lmax)`.
    """
    if not coeffs_dict:
        raise ValueError("Cannot convert an empty VSH coefficients dictionary.")

    order = list(vsh_iterator(get_lmax(coeffs_dict)))
    unknown = set(coeffs_dict) - set(order)
    if unknown:
        raise ValueError(f"Invalid VSH keys {sorted(unknown)}.")

    # Missing entries get zeros with the same leading (batch) shape and dtype.
    reference = next(iter(coeffs_dict.values()))
    chunks = [
        coeffs_dict[key]
        if key in coeffs_dict
        else e3nn.zeros(
            e3nn.Irrep(key[0], (-1) ** (key[1] + 1)),
            reference.shape[:-1],
            dtype=reference.dtype,
        )
        for key in order
    ]
    return e3nn.concatenate(chunks)


def irreps_array_to_coeffs_dict(irreps_array: e3nn.IrrepsArray) -> CoeffsDict:
    """Converts an IrrepsArray to a dictionary of VSH coefficients.

    `irreps_array.irreps` must equal `get_vsh_irreps(lmax)` for some lmax,
    i.e. one irrep per VSH in `vsh_iterator` order. Any leading (batch) axes
    of `irreps_array` are kept on every output entry.

    Raises:
        AssertionError: if the irreps do not match the VSH layout.
        ValueError: if an individual irrep has the wrong multiplicity,
            degree or parity.
    """
    # The largest j among the VSH up to lmax is lmax + 1.
    jmax = irreps_array.irreps.lmax
    assert irreps_array.irreps == get_vsh_irreps(
        jmax - 1
    ), f"Invalid irreps {irreps_array.irreps} for VSH."

    coeffs_dict = create_zeros_coeffs_dict(jmax - 1)
    for (j, l), (ir_mul, ir), ir_slice in zip(
        coeffs_dict.keys(), irreps_array.irreps, irreps_array.irreps.slices()
    ):
        if ir_mul != 1:
            raise ValueError(f"Invalid multiplicity {ir_mul} for VSH. Expected 1.")
        if ir.l != j:
            raise ValueError(f"Invalid irrep {ir} for VSH. Expected {j}.")
        if ir.p != (-1) ** (l + 1):
            raise ValueError(
                f"Invalid parity {ir.p} for VSH. Expected {(-1) ** (l + 1)}."
            )

        # Slice the flat array instead of using `.chunks`: chunks have shape
        # (..., mul, ir.dim), which would add a spurious size-1 axis here.
        coeffs_dict[(j, l)] = e3nn.IrrepsArray(ir, irreps_array.array[..., ir_slice])
    return coeffs_dict

In [None]:
# Quick look at `Irreps.num_irreps`, which `check_coeffs_dict` relies on.
example_irreps = e3nn.Irreps("2x0e + 2x1o + 2x2e")
example_irreps.num_irreps

In [None]:
# Irreps of all VSH up to lmax = 3, in `vsh_iterator` order.
get_vsh_irreps(3)

Each $Y_{j, l, m_j}$ is built by decomposing the tensor product $1 \otimes l$ (a vector times a degree-$l$ irrep) into its $j = l - 1$, $l$, and $l + 1$ components.

In [None]:
def get_change_of_basis_matrices(
    lmax: int,
) -> Dict[Tuple[int, int], e3nn.IrrepsArray]:
    """Returns the change of basis for each (j, l) pair up to lmax.

    For every (j, l) produced by `vsh_iterator(lmax)`, this is the reduced
    tensor product basis of 1o x l (with l carrying parity (-1) ** l),
    restricted to the single output irrep j with parity (-1) ** (l + 1).

    Returns:
        A dict keyed by (j, l). Each value is the IrrepsArray returned by
        `e3nn.reduced_tensor_product_basis`, which `to_vector_coeffs`
        contracts against the VSH coefficients.
    """
    return {
        (j, l): e3nn.reduced_tensor_product_basis(
            "ij",
            i="1o",
            j=e3nn.Irrep(l, (-1) ** l),
            keep_ir=e3nn.Irrep(j, (-1) ** (l + 1)),
        )
        for j, l in vsh_iterator(lmax)
    }

In [None]:
# Change-of-basis tensors for every VSH (j, l) with l <= 2.
get_change_of_basis_matrices(lmax=2)

In [None]:
def to_vector_coeffs(coeffs_dict: CoeffsDict) -> e3nn.IrrepsArray:
    """Converts a dictionary of VSH coefficients to a 3D IrrepsArray.

    Each (j, l) entry is contracted with its change-of-basis tensor (see
    `get_change_of_basis_matrices`), which gives degree-l scalar-SH
    coefficients for each of the 3 Cartesian components. One block per entry
    (in dict order) is concatenated along the last axis. Leading batch axes
    of the coefficients are preserved in front of the component axis.

    Raises:
        AssertionError: if `coeffs_dict` is malformed (see `check_coeffs_dict`).
        ValueError: if an entry's last axis does not match its irrep dimension.
    """
    check_coeffs_dict(coeffs_dict)

    rtps = get_change_of_basis_matrices(lmax=get_lmax(coeffs_dict))
    all_vector_coeffs = []
    for (j, l), coeffs in coeffs_dict.items():
        rtp = rtps[(j, l)]
        # Compare against the *last* axis so batched coefficients are accepted.
        if rtp.array.shape[-1] != coeffs.array.shape[-1]:
            raise ValueError(
                f"Invalid shape {coeffs.shape} for coefficients with j={j}, l={l}."
            )

        # Contract the j index (k); `...` carries any leading batch axes.
        vector_coeffs = jnp.einsum("ijk,...k->...ij", rtp.array, coeffs.array)
        all_vector_coeffs.append(
            e3nn.IrrepsArray(e3nn.s2_irreps(l)[-1], vector_coeffs)
        )
    return e3nn.concatenate(all_vector_coeffs)


def to_vector_signal(
    coeffs_dict: CoeffsDict, res_beta: int = 90, res_alpha: int = 89, quadrature="soft"
) -> e3nn.SphericalSignal:
    """Samples the vector field described by `coeffs_dict` on a spherical grid.

    The VSH coefficients are first expanded into per-component scalar-SH
    blocks (`to_vector_coeffs`). Blocks with the same irrep are then merged by
    `regroup()` and summed over the last axis. Finally, the result is
    evaluated on a res_beta x res_alpha grid with the given quadrature.
    """
    per_block = to_vector_coeffs(coeffs_dict)
    per_degree = e3nn.sum(per_block.regroup(), axis=-1)
    return e3nn.to_s2grid(
        per_degree,
        res_beta=res_beta,
        res_alpha=res_alpha,
        quadrature=quadrature,
        p_val=1,
        p_arg=-1,
    )


def vector_spherical_harmonics(
    j: int,
    l: int,
    mj: int,
    res_beta: int = 90,
    res_alpha: int = 89,
    quadrature: str = "soft",
) -> e3nn.SphericalSignal:
    """Returns a vector spherical harmonic Y_{j, l, mj} sampled on a grid.

    Args:
        j: degree of the output irrep; must satisfy |l - 1| <= j <= l + 1.
        l: orbital degree (l >= 0).
        mj: index in [-j, ..., j].
        res_beta, res_alpha, quadrature: grid parameters forwarded to
            `to_vector_signal` (the defaults match it).

    Raises:
        ValueError: if (j, l) is not a valid VSH label or mj is out of range.
    """
    # |l - 1| excludes (j, l) = (0, 0), which `vsh_iterator` never produces
    # and which would otherwise fail later with a KeyError.
    if l < 0 or not abs(l - 1) <= j <= l + 1:
        raise ValueError(f"Invalid j={j} for l={l}.")

    if mj not in range(-j, j + 1):
        raise ValueError(f"Invalid mj={mj} for j={j}.")

    # One-hot coefficient vector in the (-j, ..., j) ordering.
    coeffs = e3nn.IrrepsArray(
        e3nn.Irrep(j, (-1) ** (l + 1)),
        jnp.asarray([1.0 if i == mj else 0.0 for i in range(-j, j + 1)]),
    )
    return to_vector_signal(
        {(j, l): coeffs},
        res_beta=res_beta,
        res_alpha=res_alpha,
        quadrature=quadrature,
    )

In [None]:
# A single VSH, Y_{j=1, l=1}, with only the first coefficient set.
single_vsh_coeffs = {
    (1, 1): e3nn.IrrepsArray("1e", jnp.asarray([1.0, 0.0, 0.0])),
}
print(to_vector_coeffs(single_vsh_coeffs))

# Random coefficients for every VSH up to lmax = 3.
coeffs_dict1 = create_random_coeffs_dict(lmax=3, key=jax.random.PRNGKey(0))
print(to_vector_coeffs(coeffs_dict1))

In [None]:
def plot_vector_signal(
    sig: e3nn.SphericalSignal, scale_vec: float = 0.1, title: str = None
):
    """Plots a vector spherical signal as a field of plotly cones.

    Args:
        sig: signal whose grid values have shape (3, res_beta, res_alpha),
            i.e. one Cartesian component per leading index.
        scale_vec: factor applied to the vectors before plotting.
        title: optional figure title.

    Returns:
        The plotly Figure.

    Raises:
        ValueError: if `sig` is not a single (unbatched) 3-component signal.
    """
    grid_values = sig.grid_values
    # Without this check, reshape((3, -1)) below would silently mix batch
    # and component axes.
    if grid_values.ndim != 3 or grid_values.shape[0] != 3:
        raise ValueError(
            f"Expected grid values of shape (3, res_beta, res_alpha), got {grid_values.shape}."
        )

    # grid_vectors: (res_beta, res_alpha, 3) -> (3, N), aligned with the values.
    grid = sig.grid_vectors.transpose((2, 0, 1)).reshape((3, -1))
    values = grid_values.reshape((3, -1))

    fig = go.Figure()
    fig.add_trace(
        go.Cone(
            x=grid[0, :],
            y=grid[1, :],
            z=grid[2, :],
            u=scale_vec * values[0, :],
            v=scale_vec * values[1, :],
            w=scale_vec * values[2, :],
            colorscale="Viridis",
            sizemode="absolute",
            sizeref=5,
            showscale=True,
            hoverinfo="skip",
        )
    )
    if title is not None:
        fig.update_layout(title=title)
    return fig

In [None]:
# Plot a single VSH.
j, l, mj = 2, 1, 0
vsh_signal = vector_spherical_harmonics(j, l, mj)
plot_vector_signal(vsh_signal, title=f"VSH j={j}, l={l}, mj={mj}")

## Reconstruction

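Check that projecting a vector signal onto the VSH recovers the coefficients it was built from. In other words, the VSH analogues of `to_s2grid()` (`to_vector_signal`) and `from_s2grid()` (`get_vsh_coeffs`) should invert each other.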

In [None]:
def wrap_fn_for_vector_signal(fn):
    """Lifts a per-point fn to whole grids by vmapping over the last two axes.

    Intended for arrays shaped (..., res_beta, res_alpha): `fn` receives the
    remaining leading axes at each grid point, and its outputs are stacked
    back along the last two axes.
    """
    lifted = fn
    for _ in range(2):  # once for res_alpha, once for res_beta
        lifted = jax.vmap(lifted, in_axes=-1, out_axes=-1)
    return lifted


def get_vsh_coeffs_at_mj(
    sig: e3nn.SphericalSignal, j_out: int, l_out: int, mj_out: int
) -> float:
    """Returns the component of Y_{j_out, l_out, mj_out} in the signal sig.

    The VSH is sampled on the same grid as `sig` (its res_beta, res_alpha and
    quadrature), so signals that are not on the default 90 x 89 "soft" grid
    can be projected as well. The pointwise dot product of `sig` with the VSH
    is integrated over the sphere and divided by 4 * pi.

    Raises:
        ValueError: if (j_out, l_out, mj_out) does not label a VSH.
    """
    if j_out not in [l_out - 1, l_out, l_out + 1]:
        raise ValueError(f"Invalid j={j_out} for l={l_out}.")
    if mj_out not in range(-j_out, j_out + 1):
        raise ValueError(f"Invalid mj={mj_out} for j={j_out}.")

    # One-hot coefficients selecting the single component mj_out.
    one_hot = e3nn.IrrepsArray(
        e3nn.Irrep(j_out, (-1) ** (l_out + 1)),
        jnp.asarray([1.0 if m == mj_out else 0.0 for m in range(-j_out, j_out + 1)]),
    )
    vsh_signal = to_vector_signal(
        {(j_out, l_out): one_hot},
        res_beta=sig.res_beta,
        res_alpha=sig.res_alpha,
        quadrature=sig.quadrature,
    )
    dot_product = sig.replace_values(
        wrap_fn_for_vector_signal(jnp.dot)(sig.grid_values, vsh_signal.grid_values)
    )
    return dot_product.integrate().array.item() / (4 * jnp.pi)


def get_vsh_coeffs_at_j(
    sig: e3nn.SphericalSignal,
    j_out: int,
    l_out: int,
) -> e3nn.IrrepsArray:
    """Projects sig onto Y_{j_out, l_out, mj_out} for every mj_out in [-j_out, ..., j_out].

    Returns the 2 * j_out + 1 projections as an IrrepsArray of the irrep
    j_out with parity (-1) ** (l_out + 1).
    """
    projections = [
        get_vsh_coeffs_at_mj(sig, j_out, l_out, mj)
        for mj in range(-j_out, j_out + 1)
    ]
    target_irrep = e3nn.Irrep(j_out, (-1) ** (l_out + 1))
    return e3nn.IrrepsArray(target_irrep, jnp.asarray(projections))


def get_vsh_coeffs(sig: e3nn.SphericalSignal, lmax: int) -> CoeffsDict:
    """Projects a vector signal onto every VSH up to lmax.

    Returns a dict mapping (j_out, l_out) to an IrrepsArray with the
    components of Y_{j_out, l_out, mj_out} for mj_out in [-j_out, ..., j_out].
    Here j_out ranges over {l_out - 1, l_out, l_out + 1} (see `vsh_iterator`)
    and l_out over [0, ..., lmax].

    Raises:
        ValueError: if axis -3 of the signal does not have size 3.
    """
    if sig.shape[-3] != 3:
        raise ValueError(f"Invalid shape {sig.shape} for signal.")

    return {
        (j_out, l_out): get_vsh_coeffs_at_j(sig, j_out, l_out)
        for j_out, l_out in vsh_iterator(lmax)
    }


def get_ssh_coeffs(sig: e3nn.SphericalSignal, lmax: int) -> e3nn.IrrepsArray:
    """Projects a scalar signal onto the scalar spherical harmonics.

    The output irreps are `e3nn.s2_irreps(lmax)`, i.e. every l up to lmax
    with all m in [-l, ..., l].
    """
    scalar_irreps = e3nn.s2_irreps(lmax)
    return e3nn.from_s2grid(sig, irreps=scalar_irreps)

In [None]:
# Reconstruction example: build a signal from random VSH coefficients, project
# it back onto the VSH basis, and compare with the original coefficients.
# (`jax.tree_map` / `jax.tree_leaves` are deprecated; use `jax.tree_util`.)
lmax = 3
coeffs_dict1 = create_random_coeffs_dict(lmax, key=jax.random.PRNGKey(0))
sig1 = to_vector_signal(coeffs_dict1)
reconstructed_coeffs_dict1 = get_vsh_coeffs(sig1, lmax)
print(reconstructed_coeffs_dict1)
print(coeffs_dict1)
print(
    jax.tree_util.tree_map(
        lambda x, y: jnp.isclose(x, y, atol=1e-5),
        coeffs_dict1,
        reconstructed_coeffs_dict1,
    )
)
print(
    all(
        jax.tree_util.tree_leaves(
            jax.tree_util.tree_map(
                lambda x, y: jnp.allclose(x, y, atol=1e-5),
                coeffs_dict1,
                reconstructed_coeffs_dict1,
            )
        )
    )
)

## Cross Product

In [None]:
# Cross product of vector fields
def cross_gaunt_tensor_product(
    coeffs_dict1: CoeffsDict, coeffs_dict2: CoeffsDict, output_lmax: int
) -> CoeffsDict:
    """Pointwise cross product of two VSH-expanded vector fields.

    Both fields are sampled on the default grid of `to_vector_signal`, and
    their cross product is taken at every grid point. The resulting vector
    field is projected back onto the VSH with l <= output_lmax.

    Returns:
        A CoeffsDict keyed by (j, l), as produced by `get_vsh_coeffs`.
    """
    sig1 = to_vector_signal(coeffs_dict1)
    sig2 = to_vector_signal(coeffs_dict2)
    cross_sig = sig1.replace_values(
        wrap_fn_for_vector_signal(jnp.cross)(sig1.grid_values, sig2.grid_values)
    )
    return get_vsh_coeffs(cross_sig, lmax=output_lmax)

In [None]:
# Visualizing the cross product of two VSH.
# Signal 1: Y_{j=1, l=2} with all three coefficients set to 1.
coeffs_dict1 = {(1, 2): e3nn.IrrepsArray("1o", jnp.asarray([1.0, 1.0, 1.0]))}
signal1 = to_vector_signal(coeffs_dict1)
plot_vector_signal(signal1, title="Signal 1")

In [None]:
# Signal 2: Y_{j=1, l=1} with only the middle coefficient set.
coeffs_dict2 = {(1, 1): e3nn.IrrepsArray("1e", jnp.asarray([0.0, 1.0, 0.0]))}
signal2 = to_vector_signal(coeffs_dict2)
plot_vector_signal(signal2, title="Signal 2")

In [None]:
# Pointwise cross product of signal 1 and signal 2, projected onto VSH with l <= 2.
cross_product_dict = cross_gaunt_tensor_product(
    coeffs_dict1, coeffs_dict2, output_lmax=2
)
cross_signal = to_vector_signal(cross_product_dict)
plot_vector_signal(cross_signal, title="Cross Product")

In [None]:
# Raw VSH coefficients of the cross product computed above.
cross_product_dict

In [None]:
# Swap the operands; jnp.cross is antisymmetric, so expect the negated field.
flipped_cross_product = cross_gaunt_tensor_product(
    coeffs_dict2, coeffs_dict1, output_lmax=2
)
flipped_signal = to_vector_signal(flipped_cross_product)
plot_vector_signal(flipped_signal, title="Flipped Cross Product")

In [None]:
# Compare a x b with b x a for another pair of inputs (output l <= 3).
coeffs_dict1 = {(1, 2): e3nn.IrrepsArray("1o", jnp.asarray([1.0, 0.0, 1.0]))}
coeffs_dict2 = {(1, 1): e3nn.IrrepsArray("1e", jnp.asarray([1.0, 1.0, 1.0]))}

cross_12 = cross_gaunt_tensor_product(coeffs_dict1, coeffs_dict2, output_lmax=3)
cross_21 = cross_gaunt_tensor_product(coeffs_dict2, coeffs_dict1, output_lmax=3)
cross_12, cross_21

## Check YuQing's selection rules

YuQing's selection rules for a nonzero coupling of $Y_{j_1, l_1}$ and $Y_{j_2, l_2}$ into $Y_{j_3, l_3}$:
1. |l_i - 1|   <=  j_i  <= l_i + 1 for all i (usual selection rule for l_i times 1)
2. |l_1 - l_2| <=  l_3  <= l_1 + l_2 (usual selection rule for l_1 times l_2)
3. |j_1 - j_2| <=  j_3  <= j_1 + j_2 (usual selection rule for j_1 times j_2)
4. j_1 + j_2 + j_3 is even
5. There is no choice of a, b, c where l_a = j_a and (l_b, j_b) = (l_c, j_c)

The first rule is guaranteed by our construction (see `vsh_iterator`).

In [None]:
# Test case for the selection rules: Y_{3, 4} x Y_{1, 1}.
# NOTE(review): output_lmax=4 < l1 + l2 = 5, so any l3 = 5 components are
# not computed here.
coeffs_dict1 = {
    (3, 4): e3nn.IrrepsArray(
        "3o", jnp.asarray([1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0])
    ),
}
coeffs_dict2 = {(1, 1): e3nn.IrrepsArray("1e", jnp.asarray([1.0, 1.0, 1.0]))}

cross_product_dict = cross_gaunt_tensor_product(coeffs_dict1, coeffs_dict2, output_lmax=4)
cross_product_dict

In [None]:
def _assert_selection_rules(j1, l1, j2, l2, j3, l3):
    """Asserts YuQing's selection rules for (j1, l1) x (j2, l2) -> (j3, l3)."""
    # Rule 1: each j couples l with a vector.
    assert l1 - 1 <= j1 <= l1 + 1
    assert l2 - 1 <= j2 <= l2 + 1
    assert l3 - 1 <= j3 <= l3 + 1

    # Rule 2: triangle rule on the l's.
    assert abs(l1 - l2) <= l3 <= l1 + l2, (l1, l2, l3)

    # Rule 3: triangle rule on the j's.
    assert abs(j1 - j2) <= j3 <= j1 + j2, (j1, j2, j3)

    # Rule 4. NOTE(review): the markdown states j1 + j2 + j3 is even, but this
    # checks the l's -- confirm which rule is intended.
    assert (l1 + l2 + l3) % 2 == 0, (l1, l2, l3)

    # Rule 5: no a, b, c with l_a == j_a and (l_b, j_b) == (l_c, j_c).
    ls, js = (l1, l2, l3), (j1, j2, j3)
    for a, b, c in itertools.permutations(range(3)):
        assert not ((ls[a], ls[b], js[b]) == (js[a], ls[c], js[c]))


for (j3, l3), coeffs in cross_product_dict.items():
    # Skip output components that are numerically zero.
    if jnp.allclose(coeffs.array, 0, atol=1e-3):
        continue

    print(f"Checking (j3, l3) = ({j3}, {l3})")
    for j2, l2 in coeffs_dict2.keys():
        for j1, l1 in coeffs_dict1.keys():
            _assert_selection_rules(j1, l1, j2, l2, j3, l3)

## Dot Product

In [None]:
# Dot product of vector fields
def dot_gaunt_tensor_product(
    coeffs_dict1: CoeffsDict, coeffs_dict2: CoeffsDict, output_lmax: int
) -> e3nn.IrrepsArray:
    """Pointwise dot product of two VSH-expanded vector fields.

    Both fields are sampled on the default grid of `to_vector_signal`, and
    their dot product is taken at every grid point. The resulting scalar field
    is projected onto the scalar spherical harmonics up to `output_lmax`.
    """
    signal_a = to_vector_signal(coeffs_dict1)
    signal_b = to_vector_signal(coeffs_dict2)
    pointwise_dot = wrap_fn_for_vector_signal(jnp.dot)
    dot_values = pointwise_dot(signal_a.grid_values, signal_b.grid_values)
    return get_ssh_coeffs(signal_a.replace_values(dot_values), lmax=output_lmax)

In [None]:
# Dot product of the same pair of fields in both orders (output l <= 4).
coeffs_dict1 = {(1, 2): e3nn.IrrepsArray("1o", jnp.asarray([1.0, 0.0, 1.0]))}
coeffs_dict2 = {(1, 1): e3nn.IrrepsArray("1e", jnp.asarray([1.0, 1.0, 1.0]))}

dot_12 = dot_gaunt_tensor_product(coeffs_dict1, coeffs_dict2, output_lmax=4)
dot_21 = dot_gaunt_tensor_product(coeffs_dict2, coeffs_dict1, output_lmax=4)
dot_12, dot_21