In [None]:
from typing import Tuple, Dict, Any, Sequence
import itertools

import e3nn_jax as e3nn
import chex
import jax
import jax.numpy as jnp
import plotly.graph_objects as go

import sys

sys.path.append(".")
from vector_spherical_harmonics import *

# Show arrays with 3 decimal places and suppress scientific notation for small values.
jnp.set_printoptions(precision=3, suppress=True)

We store the coefficients of each Y(j, l, mj) in a dictionary keyed by (j, l).

In [None]:
# Display the irreps returned by get_vsh_irreps for the same first argument (3)
# under both parity values, side by side as a tuple.
get_vsh_irreps(3, parity=1), get_vsh_irreps(3, parity=-1)

This creates the Y(j, l, mj) by reducing the 1 x l representation into the (l - 1), l and (l + 1) representations.

In [None]:
# Display the change-of-basis matrices for jmax=2 and parity=1.
get_change_of_basis_matrices(jmax=2, parity=1)

In [None]:
def vector_spherical_harmonics(
    j: int, l: int, mj: int, parity: int = -1
) -> e3nn.SphericalSignal:
    """Returns a (pseudo)-vector spherical harmonic for a given (j, l, mj).

    The harmonic is built by placing a one-hot coefficient vector (1 at
    position mj, 0 elsewhere) under the key (j, l) of an otherwise empty
    VSHCoeffs and converting that to a vector signal.

    Args:
        j: Index of the output representation; must satisfy
            |l - 1| <= j <= l + 1 (the rule for reducing 1 x l).
        l: Index of the representation that is combined with the 1
            representation.
        mj: Component index, one of -j, ..., j.
        parity: Forwarded to get_vsh_irrep and VSHCoeffs.

    Returns:
        The signal produced by VSHCoeffs.to_vector_signal().

    Raises:
        ValueError: If (j, l) violates |l - 1| <= j <= l + 1, or if mj is not
            in [-j, ..., j].
    """
    # Use the full triangle rule rather than membership in [l - 1, l, l + 1]:
    # for l = 0 only j = 1 is valid, and negative j or l is never valid.
    if not abs(l - 1) <= j <= l + 1:
        raise ValueError(f"Invalid j={j} for l={l}.")

    if mj not in range(-j, j + 1):
        raise ValueError(f"Invalid mj={mj} for j={j}.")

    # One-hot coefficients over mj' = -j, ..., j selecting the requested mj.
    coeffs = e3nn.IrrepsArray(
        get_vsh_irrep(j, l, parity),
        jnp.asarray([1.0 if i == mj else 0.0 for i in range(-j, j + 1)]),
    )
    coeffs_dict = VSHCoeffs(parity=parity)
    coeffs_dict[(j, l)] = coeffs
    return coeffs_dict.to_vector_signal()

In [None]:
# Build a VSHCoeffs by hand: a single (j, l) = (1, 1) entry holding a one-hot
# "1e" array, then print its vector-coefficient form.
coeffs_dict1 = VSHCoeffs(parity=-1)
coeffs_dict1[(1, 1)] = e3nn.IrrepsArray("1e", jnp.asarray([1.0, 0.0, 0.0]))
print(coeffs_dict1.to_vector_coeffs())

# Rebind coeffs_dict1 to coefficients generated by VSHCoeffs.normal from a fixed
# PRNG key (presumably normally distributed -- confirm in VSHCoeffs.normal).
coeffs_dict1 = VSHCoeffs.normal(jmax=3, parity=-1, key=jax.random.PRNGKey(0))
print(coeffs_dict1.to_vector_coeffs())

In [None]:
def plot_vector_signal(
    sig: e3nn.SphericalSignal, scale_vec: float = 0.1, title: str = None
):
    """Renders a vector-valued spherical signal as a plotly cone plot.

    Args:
        sig: Signal whose grid_vectors give the cone positions and whose
            grid_values give the cone directions.
        scale_vec: Factor applied to the vector components before plotting.
        title: Optional figure title; no title is set when None.

    Returns:
        A plotly Figure holding a single Cone trace.
    """
    # Positions: move the trailing xyz axis to the front, then flatten the grid.
    pos_x, pos_y, pos_z = sig.grid_vectors.transpose((2, 0, 1)).reshape((3, -1))
    # Directions: flatten the grid and apply the display scale.
    dir_u, dir_v, dir_w = scale_vec * sig.grid_values.reshape((3, -1))

    cones = go.Cone(
        x=pos_x,
        y=pos_y,
        z=pos_z,
        u=dir_u,
        v=dir_v,
        w=dir_w,
        colorscale="Viridis",
        sizemode="absolute",
        sizeref=5,
        showscale=True,
        hoverinfo="skip",
    )

    fig = go.Figure()
    fig.add_trace(cones)
    if title is None:
        return fig
    fig.update_layout(title=title)
    return fig

In [None]:
# Pick the harmonic to visualise; (j, l) = (2, 1) is the j = l + 1 case.
# vector_spherical_harmonics is called with its default parity.
j, l, mj = 2, 1, 0
plot_vector_signal(
    vector_spherical_harmonics(j, l, mj), title=f"VSH j={j}, l={l}, mj={mj}"
)

## Reconstruction

Check that we can emulate from_s2grid() and to_s2grid() with VSH.

In [None]:
def wrap_fn_for_vector_signal(fn):
    """Lifts fn so that it is applied pointwise over the two trailing axes.

    Callers in this notebook pass grid values whose two trailing axes are the
    res_beta and res_alpha grid axes. The batched axes reappear as the two
    trailing axes of the output.
    """
    # Each pass batches over the current last axis of every argument.
    for _ in range(2):
        fn = jax.vmap(fn, in_axes=-1, out_axes=-1)
    return fn


def get_vsh_coeffs_at_mj(
    sig: e3nn.SphericalSignal, j_out: int, l_out: int, mj_out: int
) -> float:
    """Projects sig onto the single basis function Y_{j_out, l_out, mj_out}.

    The basis function is built with the default parity of
    vector_spherical_harmonics, whatever the parity of sig.

    Returns:
        The first entry of the integrated pointwise dot product of sig and the
        basis function, divided by 4 * pi.
    """
    basis_values = vector_spherical_harmonics(j_out, l_out, mj_out).grid_values
    pointwise_dot = wrap_fn_for_vector_signal(jnp.dot)
    # Reuse sig as the container so the scalar overlap lives on the same grid.
    overlap = sig.replace_values(pointwise_dot(sig.grid_values, basis_values))
    integral = overlap.integrate()
    return integral.array[0] / (4 * jnp.pi)


def get_vsh_coeffs_at_j(
    sig: e3nn.SphericalSignal,
    j_out: int,
    l_out: int,
    parity_out: int,
) -> e3nn.IrrepsArray:
    """Collects the coefficients of sig for a fixed (j_out, l_out).

    One coefficient is computed per mj_out in -j_out, ..., j_out (in that
    order) and the stacked result is labelled with the irrep returned by
    get_vsh_irrep(j_out, l_out, parity_out).
    """
    per_mj = []
    for mj_out in range(-j_out, j_out + 1):
        per_mj.append(get_vsh_coeffs_at_mj(sig, j_out, l_out, mj_out))

    irrep = get_vsh_irrep(j_out, l_out, parity_out)
    return e3nn.IrrepsArray(irrep, jnp.stack(per_mj))


def get_vsh_coeffs(sig: e3nn.SphericalSignal, lmax: int, parity: int) -> VSHCoeffs:
    """Decomposes a vector signal into coefficients for each (j_out, l_out).

    The (j_out, l_out) pairs are those yielded by vsh_iterator(lmax); for each
    pair the coefficients for all mj_out in -j_out, ..., j_out are computed.

    Args:
        sig: Signal whose shape has 3 on its third-from-last axis.
        lmax: Upper bound handed to vsh_iterator.
        parity: Parity stored on the result and used to label each irrep.

    Raises:
        ValueError: If the third-from-last axis of sig.shape is not 3.
    """
    num_components = sig.shape[-3]
    if num_components != 3:
        raise ValueError(f"Invalid shape {sig.shape} for signal.")

    coeffs_by_jl = VSHCoeffs(parity=parity)
    for j_out, l_out in vsh_iterator(lmax):
        block = get_vsh_coeffs_at_j(sig, j_out, l_out, parity)
        coeffs_by_jl[j_out, l_out] = block
    return coeffs_by_jl


def get_ssh_coeffs(sig: e3nn.SphericalSignal, lmax: int) -> e3nn.IrrepsArray:
    """Decomposes a signal into scalar spherical harmonic coefficients.

    The output irreps are e3nn.s2_irreps(lmax), i.e. one block per l up to
    lmax with components m in -l, ..., l.
    """
    scalar_irreps = e3nn.s2_irreps(lmax)
    return e3nn.from_s2grid(sig, irreps=scalar_irreps)

In [None]:
# Reconstruction example: coefficients -> vector signal -> coefficients should
# round-trip to the values we started from.
lmax = 3
coeffs_dict1 = VSHCoeffs.normal(jmax=lmax, parity=-1, key=jax.random.PRNGKey(0))
sig1 = coeffs_dict1.to_vector_signal()
reconstructed_coeffs_dict1 = get_vsh_coeffs(sig1, lmax, parity=-1)
print(reconstructed_coeffs_dict1)
print(coeffs_dict1)
for j, l in reconstructed_coeffs_dict1.keys():
    reconstructed_block = reconstructed_coeffs_dict1[(j, l)]
    original_block = coeffs_dict1[(j, l)]
    # Both the irreps label and the numerical values must agree.
    assert reconstructed_block.irreps == original_block.irreps
    assert jnp.allclose(reconstructed_block.array, original_block.array)

## Cross Product

In [None]:
def cross_gaunt_tensor_product(
    coeffs_dict1: VSHCoeffs, coeffs_dict2: VSHCoeffs, output_lmax: int
) -> VSHCoeffs:
    """Pointwise cross product of two vector fields, computed on the grid.

    Both inputs are converted to vector signals, crossed at every grid point
    (first input x second input), and the result is decomposed again up to
    output_lmax. The output parity is the product of the two input parities.
    """
    output_parity = coeffs_dict1.parity * coeffs_dict2.parity

    sig1 = coeffs_dict1.to_vector_signal()
    sig2 = coeffs_dict2.to_vector_signal()

    pointwise_cross = wrap_fn_for_vector_signal(jnp.cross)
    # sig1 is reused as the container for the product's grid values.
    cross_sig = sig1.replace_values(
        pointwise_cross(sig1.grid_values, sig2.grid_values)
    )
    return get_vsh_coeffs(cross_sig, lmax=output_lmax, parity=output_parity)

In [None]:
# Visualizing the cross product of two VSH
# First operand: a single (j, l) = (1, 2) entry with a one-hot "1o" array
# (first component set). coeffs_dict1 is read again by later cells.
coeffs_dict1 = VSHCoeffs(parity=-1)
coeffs_dict1[(1, 2)] = e3nn.IrrepsArray("1o", jnp.asarray([1.0, 0.0, 0.0]))
plot_vector_signal(coeffs_dict1.to_vector_signal(), title="Signal 1")

In [None]:
# Second operand: same (j, l) = (1, 2) key, but with the second component set.
coeffs_dict2 = VSHCoeffs(parity=-1)
coeffs_dict2[(1, 2)] = e3nn.IrrepsArray("1o", jnp.asarray([0.0, 1.0, 0.0]))
plot_vector_signal(coeffs_dict2.to_vector_signal(), title="Signal 2")

In [None]:
# Signal 1 x Signal 2, decomposed up to output_lmax=2, then printed and plotted.
cross_product_dict = cross_gaunt_tensor_product(
    coeffs_dict1, coeffs_dict2, output_lmax=2
)
print(cross_product_dict)
plot_vector_signal(cross_product_dict.to_vector_signal(), title="Cross Product")

In [None]:
# Display the coefficients of Signal 1 x Signal 2.
cross_product_dict

In [None]:
# Same product with the operands swapped (Signal 2 x Signal 1), for comparison
# with cross_product_dict.
flipped_cross_product = cross_gaunt_tensor_product(
    coeffs_dict2, coeffs_dict1, output_lmax=2
)
plot_vector_signal(
    flipped_cross_product.to_vector_signal(), title="Flipped Cross Product"
)

In [None]:
# Display the coefficients of Signal 2 x Signal 1.
flipped_cross_product

In [None]:
# PVSH x VSH
# First operand has parity=1, second has parity=-1; both hold a random "2e"
# array drawn from fixed PRNG keys. Only the "4e" part of the product is shown.
coeffs_dict1 = VSHCoeffs(parity=1)
coeffs_dict1[(2, 2)] = e3nn.normal("2e", jax.random.PRNGKey(0))

coeffs_dict2 = VSHCoeffs(parity=-1)
coeffs_dict2[(2, 3)] = e3nn.normal("2e", jax.random.PRNGKey(1))

(cross_gaunt_tensor_product(coeffs_dict1, coeffs_dict2, output_lmax=4).filter("4e"))

In [None]:
# VSH x VSH
# Both operands have parity=-1 and use the key (j, l) = (2, 3), with random "2e"
# arrays drawn from different fixed PRNG keys. Only the "4e" part is shown.
coeffs_dict1 = VSHCoeffs(parity=-1)
coeffs_dict1[(2, 3)] = e3nn.normal("2e", jax.random.PRNGKey(0))

coeffs_dict2 = VSHCoeffs(parity=-1)
coeffs_dict2[(2, 3)] = e3nn.normal("2e", jax.random.PRNGKey(1))

(cross_gaunt_tensor_product(coeffs_dict1, coeffs_dict2, output_lmax=4).filter("4e"))

## Check YuQing's selection rules

YuQing:
1. |l_i - 1|   <=  j_i  <= l_i + 1 for all i (usual selection rule for l_i times 1)
2. |l_1 - l_2| <=  l_3  <= l_1 + l_2 (usual selection rule for l_1 times l_2)
3. |j_1 - j_2| <=  j_3  <= j_1 + j_2 (usual selection rule for j_1 times j_2)
4. l_1 + l_2 + l_3 is even
5. There is no choice of distinct a, b, c (a permutation of 1, 2, 3) where l_a = j_a and (l_b, j_b) = (l_c, j_c)

The first rule is guaranteed by our construction.

In [None]:
# Inputs for the selection-rule check: two parity=1 operands, each holding a
# single (j, l) = (2, 2) entry with a random "2e" array from a fixed PRNG key.
coeffs_dict1 = VSHCoeffs(parity=1)
coeffs_dict1[(2, 2)] = e3nn.normal("2e", jax.random.PRNGKey(0))

coeffs_dict2 = VSHCoeffs(parity=1)
coeffs_dict2[(2, 2)] = e3nn.normal("2e", jax.random.PRNGKey(1))

# Their cross product, decomposed up to output_lmax=4; the next cell inspects it.
cross_product_dict = cross_gaunt_tensor_product(
    coeffs_dict1, coeffs_dict2, output_lmax=4
)
cross_product_dict

In [None]:
# For every numerically non-zero output (j3, l3), assert that the five selection
# rules hold against every (j1, l1) in coeffs_dict1 and (j2, l2) in coeffs_dict2.
for (j3, l3), coeffs in cross_product_dict.items():
    # Outputs that vanish (up to atol) are not constrained by the rules.
    if jnp.allclose(coeffs.array, 0, atol=1e-3):
        continue

    print(f"Checking (j3, l3) = ({j3}, {l3})")
    for j2, l2 in coeffs_dict2.keys():
        for j1, l1 in coeffs_dict1.keys():
            ls = [l1, l2, l3]
            js = [j1, j2, j3]

            # Rule 1: |l_i - 1| <= j_i <= l_i + 1 for all i.
            # The absolute value matters for l_i = 0, where only j_i = 1 is allowed.
            for l_i, j_i in zip(ls, js):
                assert abs(l_i - 1) <= j_i <= l_i + 1, (l_i, j_i)

            # Rule 2: |l_1 - l_2| <= l_3 <= l_1 + l_2.
            assert abs(l1 - l2) <= l3 <= l1 + l2, (l1, l2, l3)

            # Rule 3: |j_1 - j_2| <= j_3 <= j_1 + j_2.
            assert abs(j1 - j2) <= j3 <= j1 + j2, (j1, j2, j3)

            # Rule 4: l_1 + l_2 + l_3 is even.
            assert (l1 + l2 + l3) % 2 == 0, (l1, l2, l3)

            # Rule 5: no permutation (a, b, c) of the three indices has
            # l_a = j_a together with (l_b, j_b) = (l_c, j_c).
            for a, b, c in itertools.permutations(range(3)):
                assert not (
                    ls[a] == js[a] and (ls[b], js[b]) == (ls[c], js[c])
                ), (ls, js, (a, b, c))

## Dot Product

In [None]:
# Dot product of vector fields
def dot_gaunt_tensor_product(
    coeffs_dict1: VSHCoeffs, coeffs_dict2: VSHCoeffs, output_lmax: int
) -> e3nn.IrrepsArray:
    """Pointwise dot product of two vector fields, computed on the grid.

    Both inputs are converted to vector signals, dotted at every grid point,
    and the resulting scalar signal is decomposed into scalar spherical
    harmonic coefficients up to output_lmax.
    """
    sig1 = coeffs_dict1.to_vector_signal()
    sig2 = coeffs_dict2.to_vector_signal()

    pointwise_dot = wrap_fn_for_vector_signal(jnp.dot)
    # sig1 is reused as the container for the product's grid values.
    dot_sig = sig1.replace_values(pointwise_dot(sig1.grid_values, sig2.grid_values))
    return get_ssh_coeffs(dot_sig, lmax=output_lmax)

In [None]:
# Two parity=-1 operands sharing the key (j, l) = (1, 2), with different "1o"
# component vectors.
coeffs_dict1 = VSHCoeffs(parity=-1)
coeffs_dict1[(1, 2)] = e3nn.IrrepsArray("1o", jnp.asarray([1.0, 0.0, 0.0]))

coeffs_dict2 = VSHCoeffs(parity=-1)
coeffs_dict2[(1, 2)] = e3nn.IrrepsArray("1o", jnp.asarray([0.0, 1.0, 1.0]))

# Display the dot product in both operand orders so the two can be compared.
(
    dot_gaunt_tensor_product(coeffs_dict1, coeffs_dict2, output_lmax=4),
    dot_gaunt_tensor_product(coeffs_dict2, coeffs_dict1, output_lmax=4),
)

## Full Vector Gaunt Tensor Product

- Go from irreps1 -> VSHCoeffs1 using a Linear layer,
and similarly from irreps2 -> VSHCoeffs2.

- Take the cross product of VSHCoeffs1 and VSHCoeffs2 to get VSHCoeffs3.

- Finally, go from VSHCoeffs3 -> irreps3 using a Linear layer.

In [None]:
import haiku as hk


@hk.without_apply_rng
@hk.transform
def full_gaunt_tensor_product(
    coeffs1: e3nn.IrrepsArray,
    coeffs2: e3nn.IrrepsArray,
    parity1: int,
    parity2: int,
    output_lmax: int,
) -> e3nn.IrrepsArray:
    """Linear -> cross product of vector signals -> Linear.

    Each input is mapped by a Linear layer onto the irreps given by
    get_vsh_irreps(<input lmax>, parity=<input parity>), the two results are
    crossed via cross_gaunt_tensor_product, and a final Linear layer maps the
    product onto get_vsh_irreps(output_lmax, parity=parity1 * parity2).

    NOTE(review): VSHCoeffs.from_irreps_array is not given parity1/parity2
    here -- confirm it recovers the parity from the irreps.
    """
    # Input projections; the layers are instantiated in the order
    # input 1, input 2, output.
    vsh_irreps1 = get_vsh_irreps(coeffs1.irreps.lmax, parity=parity1)
    vsh_array1 = e3nn.haiku.Linear(vsh_irreps1, force_irreps_out=True)(coeffs1)

    vsh_irreps2 = get_vsh_irreps(coeffs2.irreps.lmax, parity=parity2)
    vsh_array2 = e3nn.haiku.Linear(vsh_irreps2, force_irreps_out=True)(coeffs2)

    product_dict = cross_gaunt_tensor_product(
        VSHCoeffs.from_irreps_array(vsh_array1),
        VSHCoeffs.from_irreps_array(vsh_array2),
        output_lmax=output_lmax,
    )
    product_array = product_dict.to_irreps_array()

    # Output projection.
    output_irreps = get_vsh_irreps(output_lmax, parity=parity1 * parity2)
    return e3nn.haiku.Linear(output_irreps)(product_array)


# Random inputs with irreps e3nn.s2_irreps(3), drawn from fixed PRNG keys.
coeffs1 = e3nn.normal(e3nn.s2_irreps(3), jax.random.PRNGKey(0))
coeffs2 = e3nn.normal(e3nn.s2_irreps(3), jax.random.PRNGKey(1))

# Initialise the parameters of the transformed function; the parities and
# output_lmax are passed by keyword and must match the later apply call.
params = full_gaunt_tensor_product.init(
    jax.random.PRNGKey(0),
    coeffs1,
    coeffs2,
    parity1=-1,
    parity2=-1,
    output_lmax=4,
)
params

In [None]:
# Run the transformed function under jit. output_lmax and the parities are
# declared static, so they are treated as compile-time constants rather than
# traced values.
jax.jit(
    full_gaunt_tensor_product.apply,
    static_argnames=["output_lmax", "parity1", "parity2"],
)(
    params,
    coeffs1,
    coeffs2,
    parity1=-1,
    parity2=-1,
    output_lmax=4,
)