In [105]:
import e3nn_jax as e3nn
import jax
import jax.numpy as jnp
import plotly.graph_objects as go

jnp.set_printoptions(precision=3)
jnp.set_printoptions(suppress=True)


In [106]:
def get_change_of_basis_matrix(lmax: int) -> jnp.ndarray:
    rtp = e3nn.reduced_tensor_product_basis(
        "ij", i="1o", j=e3nn.s2_irreps(lmax)[1:]
    )
    new_irreps = e3nn.Irreps([(1, ir) for mul, ir in rtp.irreps for _ in range(mul)])
    rtp = rtp.rechunk(new_irreps)
    return rtp


def get_vsh_irreps(lmax: int) -> e3nn.Irreps:
    return get_change_of_basis_matrix(lmax).irreps

In [107]:
get_change_of_basis_matrix(lmax=3)

1x0e+1x1o+1x1e+1x2e+1x2e+1x2o+1x3o+1x3e+1x4e
[[[ 0.577  0.     0.    ...  0.     0.     0.   ]
  [ 0.     0.     0.    ...  0.     0.     0.   ]
  [ 0.     0.     0.    ...  0.     0.     0.   ]
  ...
  [ 0.     0.     0.    ...  0.     0.     0.   ]
  [ 0.     0.     0.    ...  0.     0.     0.   ]
  [ 0.     0.     0.    ...  0.     0.     0.   ]]

 [[ 0.     0.     0.    ...  0.     0.     0.   ]
  [ 0.577  0.     0.    ...  0.     0.     0.   ]
  [ 0.     0.     0.    ...  0.     0.     0.   ]
  ...
  [ 0.     0.     0.    ...  0.     0.     0.   ]
  [ 0.     0.     0.    ...  0.655  0.     0.   ]
  [ 0.     0.     0.    ...  0.     0.5    0.   ]]

 [[ 0.     0.     0.    ...  0.     0.     0.   ]
  [ 0.     0.     0.    ...  0.     0.     0.   ]
  [ 0.577  0.     0.    ...  0.     0.     0.   ]
  ...
  [ 0.     0.     0.    ...  0.518  0.     0.   ]
  [ 0.     0.     0.    ...  0.     0.612  0.   ]
  [ 0.     0.     0.    ... -0.134  0.     0.707]]]

In [108]:
def to_vector_coeffs(coeffs: e3nn.IrrepsArray) -> e3nn.IrrepsArray:
    l = coeffs.irreps.lmax + 1
    rtp = get_change_of_basis_matrix(l)
    coeffs = coeffs.extend_with_zeros(rtp.irreps)
    vector_coeffs = jnp.einsum("ijk,k->ij", rtp.array, coeffs.array)
    vector_coeffs = e3nn.IrrepsArray(e3nn.s2_irreps(l)[1:], vector_coeffs) # [3, (lmax + 1) ** 2]
    return vector_coeffs


def to_vector_signal(coeffs: e3nn.IrrepsArray, res_beta: int = 90, res_alpha: int = 89, quadrature="soft") -> e3nn.SphericalSignal:
    vector_coeffs = to_vector_coeffs(coeffs)
    vector_sig = e3nn.to_s2grid(
        vector_coeffs, res_beta=res_beta, res_alpha=res_alpha, quadrature=quadrature, p_val=1, p_arg=-1
    )
    return vector_sig


def vector_spherical_harmonics(l: int, j: int, mj: int) -> e3nn.SphericalSignal:
    if j not in [l - 1, l, l + 1]:
        raise ValueError(f"Invalid j={j} for l={l}.")

    if mj not in range(-j, j + 1):
        raise ValueError(f"Invalid mj={mj} for j={j}.")

    coeffs = e3nn.IrrepsArray(
        e3nn.Irrep(j, (-1) ** (l + 1)),
        jnp.asarray([1.0 if i == mj else 0.0 for i in range(-j, j + 1)]),
    )
    return to_vector_signal(coeffs)

In [109]:
l, j, mj = 2, 1, 0
rtp = e3nn.reduced_tensor_product_basis(
    "ij", i="1o", j=e3nn.s2_irreps(l)[1:],
)
new_irreps = e3nn.Irreps([(1, ir) for mul, ir in rtp.irreps for _ in range(mul)])
rtp = rtp.rechunk(new_irreps)
# print(rtp.shape)
coeffs = e3nn.IrrepsArray(e3nn.Irrep(j, (-1) ** (l + 1)), jnp.asarray([1.0, 0.0, 0.0]))
coeffs = coeffs.extend_with_zeros(rtp.irreps)
# print(coeffs)
# print(rtp.shape)
# for irrep, chunk in zip(rtp.irreps, rtp.chunks):
#     print(irrep, chunk.shape)
    
vector_coeffs = jnp.einsum("ijk,k->ij", rtp.array, coeffs.array)
vector_coeffs = e3nn.IrrepsArray(e3nn.s2_irreps(l)[1:], vector_coeffs) # [3, (2l + 1)]
for irrep, chunk in zip(vector_coeffs.irreps, vector_coeffs.chunks):
    print(irrep, chunk.squeeze())

1x1o [[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
1x2e [[ 0.     0.    -0.316  0.    -0.548]
 [ 0.     0.548  0.     0.     0.   ]
 [ 0.548  0.     0.     0.     0.   ]]


In [110]:
l, j, mj = 1, 1, 0
vector_spherical_harmonics(l, j, mj)

SphericalSignal(shape=(3, 90, 89), res_beta=90, res_alpha=89, quadrature=soft, p_val=1, p_arg=-1)
[[[-0.021 -0.021 -0.021 ... -0.021 -0.021 -0.021]
  [-0.064 -0.064 -0.063 ... -0.063 -0.063 -0.064]
  [-0.107 -0.106 -0.106 ... -0.104 -0.106 -0.106]
  ...
  [-0.107 -0.106 -0.106 ... -0.104 -0.106 -0.106]
  [-0.064 -0.064 -0.063 ... -0.063 -0.063 -0.064]
  [-0.021 -0.021 -0.021 ... -0.021 -0.021 -0.021]]

 [[ 0.     0.     0.    ...  0.     0.     0.   ]
  [ 0.     0.     0.    ...  0.     0.     0.   ]
  [ 0.     0.     0.    ...  0.     0.     0.   ]
  ...
  [ 0.     0.     0.    ...  0.     0.     0.   ]
  [ 0.     0.     0.    ...  0.     0.     0.   ]
  [ 0.     0.     0.    ...  0.     0.     0.   ]]

 [[ 0.     0.002  0.003 ... -0.004 -0.003 -0.002]
  [ 0.     0.005  0.009 ... -0.013 -0.009 -0.005]
  [ 0.     0.008  0.015 ... -0.022 -0.015 -0.008]
  ...
  [ 0.     0.008  0.015 ... -0.022 -0.015 -0.008]
  [ 0.     0.005  0.009 ... -0.013 -0.009 -0.005]
  [ 0.     0.002  0.003 ... -0

In [111]:
def plot_vector_signal(sig: e3nn.SphericalSignal, scale_vec: float = 0.1, title: str = None):
    grid = sig.grid_vectors.transpose((2, 0, 1)).reshape((3, -1))
    values = sig.grid_values.reshape((3, -1))

    fig = go.Figure()
    fig.add_trace(
        go.Cone(
            x=grid[0, :],
            y=grid[1, :],
            z=grid[2, :],
            u=scale_vec * values[0, :],
            v=scale_vec * values[1, :],
            w=scale_vec * values[2, :],
            colorscale="Viridis",
            sizemode="absolute",
            sizeref=5,
            showscale=True,
            hoverinfo="skip",
        )
    )
    if title is not None:
        fig.update_layout(title=title)
    return fig

In [112]:
l, j, mj = 2, 1, 0
plot_vector_signal(vector_spherical_harmonics(l, j, mj), title=f"l={l}, j={j}, mj={mj}")

## Reconstruction

Check that we can emulate from_s2grid() and to_s2grid() with VSH.

In [113]:
def wrap_fn_for_vector_signal(fn):
    """vmaps a fn over res_beta and res_alpha axes."""
    fn = jax.vmap(fn, in_axes=-1, out_axes=-1)
    fn = jax.vmap(fn, in_axes=-1, out_axes=-1)
    return fn


def get_vsh_coeffs_at_mj(sig, l_out, j_out, mj_out):
    """Returns the component of Y_{l_out, j_out, mj_out} in the signal sig."""
    vsh_signal = vector_spherical_harmonics(l_out, j_out, mj_out)
    dot_product = sig.replace_values(
        wrap_fn_for_vector_signal(jnp.dot)(sig.grid_values, vsh_signal.grid_values)
    )
    return dot_product.integrate().array.item() / (4 * jnp.pi)


def get_vsh_coeffs_at_j(sig, l_out, j_out):
    """Returns the components of Y_{l_out, j_out, mj_out} in the signal sig for all mj_out in [-j_out, ..., j_out]."""
    computed_coeffs = jnp.asarray([get_vsh_coeffs_at_mj(sig, l_out, j_out, mj_out) for mj_out in range(-j_out, j_out + 1)])
    computed_coeffs = e3nn.IrrepsArray(e3nn.Irrep(j_out, (-1) ** (l_out + 1)), computed_coeffs)
    return computed_coeffs


def get_vsh_coeffs(sig, lmax: int):
    """Returns the components of Y_{l_out, j_out, mj_out} in the signal sig for all mj_out in [-j_out, ..., j_out] and j_out in [-l_out, ..., l_out] and l_out upto lmax."""
    result = []
    for l_out in range(1, lmax + 1):
        for j_out in [l_out - 1, l_out, l_out + 1]:
            result.append(get_vsh_coeffs_at_j(sig, l_out, j_out))
    return e3nn.concatenate(result).regroup()

In [122]:
# Reconstruction example
lmax = 3
coeffs1 = e3nn.normal(get_vsh_irreps(lmax), key=jax.random.PRNGKey(1))
sig1 = to_vector_signal(coeffs1)
reconstructed_coeffs1 = get_vsh_coeffs(sig1, lmax)
print(jnp.isclose(reconstructed_coeffs1.array, coeffs1.array, atol=1e-3))
print(jnp.isclose(e3nn.sum(reconstructed_coeffs1.regroup()).array, e3nn.sum(coeffs1.regroup()).array, atol=1e-3))
# reconstructed_coeffs1, coeffs1
e3nn.sum(reconstructed_coeffs1.regroup()).array, e3nn.sum(coeffs1.regroup()).array

[ True  True  True  True  True  True  True  True  True  True  True  True
 False False False False False  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True]
[ True  True  True  True  True  True  True False False False False False
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True]


(Array([ 0.448, -0.108,  0.36 , -1.132,  1.144,  0.894, -2.207,  2.938,
         1.051,  1.661, -1.253,  0.622, -1.441,  0.732,  0.111, -3.316,
         0.507,  0.854, -1.496, -0.322,  0.552, -0.974,  0.781,  0.421,
         0.473,  0.125,  0.007, -2.845, -0.421, -1.949,  1.853,  0.542,
         1.282,  0.75 , -0.602, -1.307,  0.859,  0.798,  0.702, -0.679],      dtype=float32),
 Array([ 0.448, -0.108,  0.36 , -1.132,  1.144,  0.894, -2.207,  0.396,
        -0.975,  0.429,  0.151,  0.562, -1.441,  0.732,  0.111, -3.316,
         0.507,  0.854, -1.496, -0.322,  0.552, -0.974,  0.781,  0.421,
         0.473,  0.125,  0.007, -2.845, -0.421, -1.949,  1.853,  0.542,
         1.282,  0.75 , -0.602, -1.307,  0.859,  0.798,  0.702, -0.679],      dtype=float32))

## Cross-Product

In [115]:
l, j = 1, 1
coeffs1 = e3nn.normal(e3nn.s2_irreps(j)[-1], key=jax.random.PRNGKey(0))
sig1 = to_vector_signal(coeffs1)
plot_vector_signal(sig1, title="Signal 1")

In [116]:
l, j = 1, 1
coeffs2 = e3nn.normal(e3nn.s2_irreps(j)[-1], key=jax.random.PRNGKey(1))
sig2 = to_vector_signal(coeffs2)
plot_vector_signal(sig2, title="Signal 2")

In [117]:
cross_sig = sig1.replace_values(
    wrap_fn_for_vector_signal(jnp.cross)(
        sig1.grid_values, sig2.grid_values
    )
)
plot_vector_signal(cross_sig, title="Cross Product")

In [118]:
cross_sig_coeffs = get_vsh_coeffs(cross_sig, lmax=5)
print(cross_sig_coeffs.irreps, cross_sig_coeffs.array.round(2))

1x0e+1x1o+1x1e+2x2e+1x2o+2x3o+1x3e+2x4e+1x4o+1x5o+1x5e+1x6e [ 0.   -0.26 -1.53 -0.77  0.    0.   -0.    0.   -0.   -0.    0.   -0.
  0.   -0.   -0.    0.   -0.    0.   -0.   -0.    0.    0.    0.    0.
  0.    0.   -0.    0.    0.    0.    0.    0.    0.   -0.    0.    0.
 -0.   -0.   -0.    0.   -0.    0.   -0.   -0.   -0.    0.   -0.    0.
  0.   -0.   -0.   -0.   -0.   -0.    0.   -0.    0.    0.   -0.   -0.
 -0.   -0.    0.    0.    0.   -0.   -0.   -0.    0.   -0.    0.   -0.
  0.   -0.    0.    0.    0.   -0.   -0.   -0.   -0.    0.   -0.   -0.
  0.   -0.    0.    0.   -0.   -0.   -0.    0.    0.    0.   -0.   -0.
  0.   -0.    0.   -0.   -0.    0.    0.    0.    0.  ]


## Vector Gaunt Tensor Product

In [119]:
# Cross product of vector fields
def vector_gaunt_tensor_product(coeffs1: e3nn.IrrepsArray, coeffs2: e3nn.IrrepsArray, l: int) -> e3nn.IrrepsArray:
    sig1 = to_vector_signal(coeffs1, l)
    sig2 = to_vector_signal(coeffs2, l)
    cross_sig = sig1.replace_values(
        wrap_fn_for_vector_signal(jnp.cross)(
            sig1.grid_values, sig2.grid_values
        )
    )
    return get_vsh_coeffs(cross_sig, lmax=2 * l)

j = 1
coeffs1 = e3nn.normal(e3nn.s2_irreps(j)[-1], key=jax.random.PRNGKey(0))
coeffs2 = e3nn.normal(e3nn.s2_irreps(j)[-1], key=jax.random.PRNGKey(1))
vector_gaunt_tensor_product(coeffs1, coeffs2, 1).array.round(2)

AssertionError: res_beta needs to be even for soft quadrature weights to be computed properly