In [21]:
import e3nn_jax as e3nn
import jax
import jax.numpy as jnp
import plotly.graph_objects as go

import itertools
import sys

sys.path.append(".")
from vector_spherical_harmonics import *

# Print arrays with three decimals and without scientific notation.
jnp.set_printoptions(precision=3, suppress=True)

We store the coefficients for each vector spherical harmonic Y(j, l, mj) in a dictionary keyed by (j, l).

In [22]:
# Display the irreps returned by get_vsh_irreps for parity=+1 and parity=-1.
get_vsh_irreps(3, parity=1), get_vsh_irreps(3, parity=-1)

(1x0o+1x1e+1x1o+1x1e+1x2o+1x2e+1x2o+1x3e+1x3o+1x3e,
 1x0e+1x1o+1x1e+1x1o+1x2e+1x2o+1x2e+1x3o+1x3e+1x3o)

This creates the Y(j, l, mj) by reducing the 1 x l representation into the (l - 1), l and (l + 1) representations.

In [23]:
# Display the change-of-basis matrices for jmax=2 and parity=+1.
# NOTE(review): the keys in the output appear to be (j, l) pairs -- confirm.
get_change_of_basis_matrices(jmax=2, parity=1)

{(0,
  1): 1x0o
 [[[0.577]
   [0.   ]
   [0.   ]]
 
  [[0.   ]
   [0.577]
   [0.   ]]
 
  [[0.   ]
   [0.   ]
   [0.577]]],
 (1,
  0): 1x1e
 [[[1. 0. 0.]]
 
  [[0. 1. 0.]]
 
  [[0. 0. 1.]]],
 (1,
  1): 1x1o
 [[[ 0.     0.     0.   ]
   [ 0.     0.     0.707]
   [ 0.    -0.707  0.   ]]
 
  [[ 0.     0.    -0.707]
   [ 0.     0.     0.   ]
   [ 0.707  0.     0.   ]]
 
  [[ 0.     0.707  0.   ]
   [-0.707  0.     0.   ]
   [ 0.     0.     0.   ]]],
 (1,
  2): 1x1e
 [[[ 0.     0.     0.548]
   [ 0.     0.548  0.   ]
   [-0.316  0.     0.   ]
   [ 0.     0.     0.   ]
   [-0.548  0.     0.   ]]
 
  [[ 0.     0.     0.   ]
   [ 0.548  0.     0.   ]
   [ 0.     0.632  0.   ]
   [ 0.     0.     0.548]
   [ 0.     0.     0.   ]]
 
  [[ 0.548  0.     0.   ]
   [ 0.     0.     0.   ]
   [ 0.     0.    -0.316]
   [ 0.     0.548  0.   ]
   [ 0.     0.     0.548]]],
 (2,
  1): 1x2o
 [[[ 0.     0.    -0.408  0.    -0.707]
   [ 0.     0.707  0.     0.     0.   ]
   [ 0.707  0.     0.     0.     0.   ]

In [25]:
# Single component: set only the (j=1, l=1) coefficient and print the result
# of to_vector_coeffs().
coeffs_dict1 = VSHCoeffs(parity=-1)
coeffs_dict1[(1, 1)] = e3nn.IrrepsArray("1e", jnp.asarray([1.0, 0.0, 0.0]))
print(coeffs_dict1.to_vector_coeffs())

# Random coefficients drawn with a fixed PRNG key. Note that this rebinds
# coeffs_dict1, replacing the single-component example above.
coeffs_dict1 = VSHCoeffs.normal(jmax=3, parity=-1, key=jax.random.PRNGKey(0))
print(coeffs_dict1.to_vector_coeffs())

1x1o
[[ 0.     0.     0.   ]
 [ 0.     0.     0.707]
 [ 0.    -0.707  0.   ]]
1x1o+1x0e+1x1o+1x2e+1x1o+1x2e+1x3o+1x2e+1x3o+1x4e
[[-0.119 -0.    -0.     0.139  0.     1.204  1.434 -0.322  0.109  0.179
   0.     0.31  -0.374 -0.613  0.147  0.383 -0.13   0.585  0.615 -0.338
  -0.092  0.293 -1.204  0.25  -0.014  0.322 -0.055 -0.684  0.055  0.524
  -0.37   1.032  0.451 -0.642 -0.582  0.572  0.248 -0.396 -0.421 -0.84
   0.557 -0.606 -0.912 -0.22  -0.176 -0.27  -0.466 -0.129]
 [-0.    -0.119 -0.     0.509 -1.204  0.     0.564  0.    -0.31   0.126
  -0.322  0.    -0.613  0.878 -0.251 -1.446 -0.338  0.    -0.383 -0.261
   0.     0.045 -0.407 -1.445  0.371 -0.075  0.    -0.37   0.605 -0.139
  -0.164  0.026  0.066 -0.687 -0.234  0.     0.529 -0.736  0.477  0.
   0.091  0.499  0.348 -1.42  -0.632  0.596 -0.594  0.   ]
 [-0.    -0.    -0.119 -0.531 -1.434 -0.564  0.    -0.31   0.     0.186
   0.109 -0.322  0.147 -0.251 -0.503  0.338  0.83   0.664  0.13   0.383
   0.055 -0.322 -0.014 -0.227 -1.156  

In [26]:
def plot_vector_signal(
    sig: e3nn.SphericalSignal, scale_vec: float = 0.1, title: str = None
):
    """Draw a vector-valued spherical signal as a plotly cone plot.

    Args:
        sig: Signal to plot. ``sig.grid_vectors`` gives the cone positions and
            ``sig.grid_values`` the cone directions.
            NOTE(review): assumes ``grid_vectors`` is 3-D with the xyz axis
            last, and that ``grid_values`` has the xyz axis first -- confirm.
        scale_vec: Factor applied to every vector component before plotting.
        title: Figure title; the layout is left untouched when ``None``.

    Returns:
        A ``go.Figure`` containing a single ``go.Cone`` trace.
    """
    # Move the last axis of grid_vectors to the front, then flatten the rest so
    # that each row holds one coordinate of every grid point.
    positions = sig.grid_vectors.transpose((2, 0, 1)).reshape((3, -1))
    # Flatten the values the same way and apply the display scaling once.
    arrows = scale_vec * sig.grid_values.reshape((3, -1))

    cone = go.Cone(
        x=positions[0, :],
        y=positions[1, :],
        z=positions[2, :],
        u=arrows[0, :],
        v=arrows[1, :],
        w=arrows[2, :],
        colorscale="Viridis",
        sizemode="absolute",
        sizeref=5,
        showscale=True,
        hoverinfo="skip",
    )

    fig = go.Figure()
    fig.add_trace(cone)
    if title is None:
        return fig
    fig.update_layout(title=title)
    return fig

In [27]:
# Pick one harmonic by its (j, l, mj) labels and plot it.
j, l, mj = 2, 1, 0
plot_vector_signal(
    vector_spherical_harmonics(j, l, mj), title=f"VSH j={j}, l={l}, mj={mj}"
)

## Reconstruction

Check that we can emulate from_s2grid() and to_s2grid() with VSH.

In [28]:


def get_ssh_coeffs(sig: e3nn.SphericalSignal, lmax: int) -> e3nn.IrrepsArray:
    """Project a spherical signal onto the scalar spherical harmonics.

    Args:
        sig: Signal sampled on the sphere.
        lmax: Passed to ``e3nn.s2_irreps`` to choose the output irreps.

    Returns:
        The result of ``e3nn.from_s2grid`` for those irreps, i.e. the scalar
        spherical harmonic components for each l and m in [-l, ..., l].
    """
    scalar_irreps = e3nn.s2_irreps(lmax)
    return e3nn.from_s2grid(sig, irreps=scalar_irreps)

In [29]:
# Reconstruction example: draw random VSH coefficients, evaluate them as a
# vector signal, project the signal back onto the VSH basis, and check that the
# round trip reproduces the original coefficients.
lmax = 3
coeffs_dict1 = VSHCoeffs.normal(jmax=lmax, parity=-1, key=jax.random.PRNGKey(0))
sig1 = coeffs_dict1.to_vector_signal()
reconstructed_coeffs_dict1 = get_vsh_coeffs(sig1, lmax, parity=-1)
print(reconstructed_coeffs_dict1)
print(coeffs_dict1)

# The per-key loop below only visits keys present in the reconstruction, so
# first make sure the round trip neither dropped nor added any (j, l) component.
assert set(reconstructed_coeffs_dict1.keys()) == set(coeffs_dict1.keys()), (
    "reconstructed and original coefficients have different (j, l) keys"
)
for j, l in reconstructed_coeffs_dict1.keys():
    assert (
        reconstructed_coeffs_dict1[(j, l)].irreps == coeffs_dict1[(j, l)].irreps
    ), (j, l)
    assert jnp.allclose(
        reconstructed_coeffs_dict1[(j, l)].array, coeffs_dict1[(j, l)].array
    ), (j, l)

VSHCoeffs(parity=-1)
 (0, 1): 1x0e [-0.206]
 (1, 0): 1x1o [ 0.139  0.509 -0.531]
 (1, 1): 1x1e [ 0.798 -2.029  1.702]
 (1, 2): 1x1o [-0.567  0.2   -0.588]
 (2, 1): 1x2e [ 0.208 -0.868  1.075 -0.354 -0.091]
 (2, 2): 1x2o [-0.319 -0.938  0.152  0.828  1.77 ]
 (2, 3): 1x2e [ 0.092 -0.66  -2.207  0.601 -0.154]
 (3, 2): 1x3o [ 1.246 -0.641  0.829 -0.18  -0.225  0.044  1.025]
 (3, 3): 1x3e [ 0.55  -1.275  1.834  0.417  0.809  1.19  -0.077]
 (3, 4): 1x3o [ 0.207  0.864  0.539 -2.131 -0.979  1.032 -1.347]
VSHCoeffs(parity=-1)
 (0, 1): 1x0e [-0.206]
 (1, 0): 1x1o [ 0.139  0.509 -0.531]
 (1, 1): 1x1e [ 0.798 -2.029  1.702]
 (1, 2): 1x1o [-0.567  0.2   -0.588]
 (2, 1): 1x2e [ 0.208 -0.868  1.075 -0.354 -0.091]
 (2, 2): 1x2o [-0.319 -0.938  0.152  0.828  1.77 ]
 (2, 3): 1x2e [ 0.092 -0.66  -2.207  0.601 -0.154]
 (3, 2): 1x3o [ 1.246 -0.641  0.829 -0.18  -0.225  0.044  1.025]
 (3, 3): 1x3e [ 0.55  -1.275  1.834  0.417  0.809  1.19  -0.077]
 (3, 4): 1x3o [ 0.207  0.864  0.539 -2.131 -0.979  1.032 -1

## Cross Product

In [30]:
def cross_gaunt_tensor_product(
    coeffs_dict1: VSHCoeffs, coeffs_dict2: VSHCoeffs, output_lmax: int
) -> VSHCoeffs:
    """Cross product of two vector fields given by their VSH coefficients.

    Both inputs are converted to vector signals, ``jnp.cross`` (wrapped by
    ``wrap_fn_for_vector_signal``) is applied to their grid values, and the
    resulting signal is projected back with ``get_vsh_coeffs``.

    Args:
        coeffs_dict1: Coefficients of the left operand.
        coeffs_dict2: Coefficients of the right operand.
            NOTE(review): assumes both signals are sampled on the same grid,
            since the result reuses the first signal's grid -- confirm.
        output_lmax: ``lmax`` passed to ``get_vsh_coeffs`` for the output.

    Returns:
        Coefficients of the cross product, with parity equal to the product of
        the two input parities.
    """
    left_signal = coeffs_dict1.to_vector_signal()
    right_signal = coeffs_dict2.to_vector_signal()

    pointwise_cross = wrap_fn_for_vector_signal(jnp.cross)
    crossed_values = pointwise_cross(left_signal.grid_values, right_signal.grid_values)
    cross_sig = left_signal.replace_values(crossed_values)

    output_parity = coeffs_dict1.parity * coeffs_dict2.parity
    return get_vsh_coeffs(cross_sig, lmax=output_lmax, parity=output_parity)

In [31]:
# Visualizing the cross product of two VSH
# First operand: only the (j=1, l=2) coefficient is set, to [1, 0, 0].
coeffs_dict1 = VSHCoeffs(parity=-1)
coeffs_dict1[(1, 2)] = e3nn.IrrepsArray("1o", jnp.asarray([1.0, 0.0, 0.0]))
plot_vector_signal(coeffs_dict1.to_vector_signal(), title="Signal 1")

In [32]:
# Second operand: only the (j=1, l=2) coefficient is set, to [0, 1, 0].
coeffs_dict2 = VSHCoeffs(parity=-1)
coeffs_dict2[(1, 2)] = e3nn.IrrepsArray("1o", jnp.asarray([0.0, 1.0, 0.0]))
plot_vector_signal(coeffs_dict2.to_vector_signal(), title="Signal 2")

In [33]:
# Cross product of coeffs_dict1 and coeffs_dict2, projected back with
# output_lmax=2; print the coefficients and plot the resulting vector field.
cross_product_dict = cross_gaunt_tensor_product(
    coeffs_dict1, coeffs_dict2, output_lmax=2
)
print(cross_product_dict)
plot_vector_signal(cross_product_dict.to_vector_signal(), title="Cross Product")

VSHCoeffs(parity=1)
 (0, 1): 1x0o [-0.]
 (1, 0): 1x1e [ 0.  -0.  -0.5]
 (1, 1): 1x1o [ 0. -0.  0.]
 (1, 2): 1x1e [ 0.    -0.     0.707]
 (2, 1): 1x2o [ 0.  0.  0.  0. -0.]
 (2, 2): 1x2e [ 0.  0.  0. -0.  0.]
 (2, 3): 1x2o [ 0.  0.  0. -0.  0.]


In [34]:
# Show the cross-product coefficients as the cell output.
cross_product_dict

VSHCoeffs(parity=1)
 (0, 1): 1x0o [-0.]
 (1, 0): 1x1e [ 0.  -0.  -0.5]
 (1, 1): 1x1o [ 0. -0.  0.]
 (1, 2): 1x1e [ 0.    -0.     0.707]
 (2, 1): 1x2o [ 0.  0.  0.  0. -0.]
 (2, 2): 1x2e [ 0.  0.  0. -0.  0.]
 (2, 3): 1x2o [ 0.  0.  0. -0.  0.]

In [35]:
# Same cross product with the operands swapped. The cross product is
# antisymmetric, so the coefficients should be the negation of cross_product_dict.
flipped_cross_product = cross_gaunt_tensor_product(
    coeffs_dict2, coeffs_dict1, output_lmax=2
)
plot_vector_signal(
    flipped_cross_product.to_vector_signal(), title="Flipped Cross Product"
)

In [36]:
# Show the swapped-operand coefficients for comparison with cross_product_dict.
flipped_cross_product

VSHCoeffs(parity=1)
 (0, 1): 1x0o [-0.]
 (1, 0): 1x1e [-0.  -0.   0.5]
 (1, 1): 1x1o [-0.  0.  0.]
 (1, 2): 1x1e [ 0.    -0.    -0.707]
 (2, 1): 1x2o [ 0.  0. -0. -0.  0.]
 (2, 2): 1x2e [ 0. -0.  0.  0.  0.]
 (2, 3): 1x2o [-0.  0. -0.  0.  0.]

In [37]:
# PVSH x VSH
# The operands have opposite parities (+1 and -1); each has a single random
# "2e" component. Only the "4e" part of the cross product is displayed.
coeffs_dict1 = VSHCoeffs(parity=1)
coeffs_dict1[(2, 2)] = e3nn.normal("2e", jax.random.PRNGKey(0))

coeffs_dict2 = VSHCoeffs(parity=-1)
coeffs_dict2[(2, 3)] = e3nn.normal("2e", jax.random.PRNGKey(1))

(cross_gaunt_tensor_product(coeffs_dict1, coeffs_dict2, output_lmax=4).filter("4e"))

VSHCoeffs(parity=-1)
 (4, 3): 1x4e [ 0.055  0.118 -0.302 -0.018 -0.225  0.016 -0.134  0.252 -0.029]
 (4, 5): 1x4e [-0.177 -0.378  0.966  0.057  0.719 -0.05   0.427 -0.806  0.091]

In [38]:
# VSH x VSH
# Both operands have parity=-1 and a single random "2e" component at (j=2, l=3).
# Only the "4e" part of the cross product is displayed.
coeffs_dict1 = VSHCoeffs(parity=-1)
coeffs_dict1[(2, 3)] = e3nn.normal("2e", jax.random.PRNGKey(0))

coeffs_dict2 = VSHCoeffs(parity=-1)
coeffs_dict2[(2, 3)] = e3nn.normal("2e", jax.random.PRNGKey(1))

(cross_gaunt_tensor_product(coeffs_dict1, coeffs_dict2, output_lmax=4).filter("4e"))

VSHCoeffs(parity=1)
 (4, 4): 1x4e [ 0.  0.  0. -0. -0.  0. -0. -0.  0.]

## Check YuQing's selection rules

YuQing:
1. |l_i - 1|   <=  j_i  <= l_i + 1 for all i (usual selection rule for l_i times 1)
2. |l_1 - l_2| <=  l_3  <= l_1 + l_2 (usual selection rule for l_1 times l_2)
3. |j_1 - j_2| <=  j_3  <= j_1 + j_2 (usual selection rule for j_1 times j_2)
4. l_1 + l_2 + l_3 is even
5. There is no permutation (a, b, c) of (1, 2, 3) where l_a = j_a and (l_b, j_b) = (l_c, j_c)

The first one is guaranteed by our construction.

In [39]:
# Inputs for the selection-rule check: both operands have parity=+1 and a
# single random "2e" component at (j=2, l=2).
coeffs_dict1 = VSHCoeffs(parity=1)
coeffs_dict1[(2, 2)] = e3nn.normal("2e", jax.random.PRNGKey(0))

coeffs_dict2 = VSHCoeffs(parity=1)
coeffs_dict2[(2, 2)] = e3nn.normal("2e", jax.random.PRNGKey(1))

# Cross product up to output_lmax=4; the trailing expression displays it.
cross_product_dict = cross_gaunt_tensor_product(
    coeffs_dict1, coeffs_dict2, output_lmax=4
)
cross_product_dict

VSHCoeffs(parity=1)
 (0, 1): 1x0o [0.]
 (1, 0): 1x1e [-0.142 -0.056 -0.22 ]
 (1, 1): 1x1o [-0.  0. -0.]
 (1, 2): 1x1e [-0.201 -0.08  -0.312]
 (2, 1): 1x2o [ 0.  0.  0.  0. -0.]
 (2, 2): 1x2e [ 0.  0.  0.  0. -0.]
 (2, 3): 1x2o [ 0.  0.  0.  0. -0.]
 (3, 2): 1x3e [ 0.459 -0.011  0.292 -0.417  0.294  0.132 -0.447]
 (3, 3): 1x3o [ 0. -0. -0.  0.  0. -0. -0.]
 (3, 4): 1x3e [ 0.53  -0.013  0.337 -0.481  0.34   0.153 -0.516]
 (4, 3): 1x4o [ 0.  0. -0.  0.  0. -0.  0. -0.  0.]
 (4, 4): 1x4e [ 0. -0.  0.  0.  0.  0. -0. -0.  0.]
 (4, 5): 1x4o [ 0.  0.  0.  0. -0.  0.  0.  0. -0.]

In [40]:
# Verify that every non-vanishing output component (j3, l3) satisfies the
# selection rules against every input pair (j1, l1), (j2, l2).
for (j3, l3), coeffs in cross_product_dict.items():
    # Components that are numerically zero are not constrained by the rules.
    if jnp.allclose(coeffs.array, 0, atol=1e-3):
        continue

    print(f"Checking (j3, l3) = ({j3}, {l3})")

    # First condition for the output; it does not depend on the inputs.
    # The lower bound is |l - 1|, not l - 1: for l = 0 only j = 1 is allowed.
    assert abs(l3 - 1) <= j3 <= l3 + 1, (j3, l3)

    for j2, l2 in coeffs_dict2.keys():
        for j1, l1 in coeffs_dict1.keys():
            # Check first condition: |l_i - 1| <= j_i <= l_i + 1
            assert abs(l1 - 1) <= j1 <= l1 + 1, (j1, l1)
            assert abs(l2 - 1) <= j2 <= l2 + 1, (j2, l2)

            # Check second condition: |l_1 - l_2| <= l_3 <= l_1 + l_2
            assert abs(l1 - l2) <= l3 <= l1 + l2, (l1, l2, l3)

            # Check third condition: |j_1 - j_2| <= j_3 <= j_1 + j_2
            assert abs(j1 - j2) <= j3 <= j1 + j2, (j1, j2, j3)

            # Check fourth condition: l_1 + l_2 + l_3 is even
            assert (l1 + l2 + l3) % 2 == 0, (l1, l2, l3)

            # Check fifth condition: no permutation (a, b, c) with
            # l_a = j_a and (l_b, j_b) = (l_c, j_c)
            ls = [l1, l2, l3]
            js = [j1, j2, j3]
            for a, b, c in itertools.permutations(range(3)):
                assert not (
                    ls[a] == js[a] and (ls[b], js[b]) == (ls[c], js[c])
                ), (ls, js, (a, b, c))

Checking (j3, l3) = (1, 0)
Checking (j3, l3) = (1, 2)
Checking (j3, l3) = (3, 2)
Checking (j3, l3) = (3, 4)


## Dot Product

In [41]:
# Dot product of vector fields
def dot_gaunt_tensor_product(
    coeffs_dict1: VSHCoeffs, coeffs_dict2: VSHCoeffs, output_lmax: int
) -> e3nn.IrrepsArray:
    """Dot product of two vector fields given by their VSH coefficients.

    Both inputs are converted to vector signals, ``jnp.dot`` (wrapped by
    ``wrap_fn_for_vector_signal``) is applied to their grid values, and the
    resulting signal is projected with ``get_ssh_coeffs``.

    Args:
        coeffs_dict1: Coefficients of the first vector field.
        coeffs_dict2: Coefficients of the second vector field.
            NOTE(review): assumes both signals are sampled on the same grid,
            since the result reuses the first signal's grid -- confirm.
        output_lmax: ``lmax`` passed to ``get_ssh_coeffs`` for the output.

    Returns:
        The scalar spherical harmonic coefficients returned by ``get_ssh_coeffs``.
    """
    first_signal = coeffs_dict1.to_vector_signal()
    second_signal = coeffs_dict2.to_vector_signal()

    pointwise_dot = wrap_fn_for_vector_signal(jnp.dot)
    dotted_values = pointwise_dot(first_signal.grid_values, second_signal.grid_values)
    dot_sig = first_signal.replace_values(dotted_values)

    return get_ssh_coeffs(dot_sig, lmax=output_lmax)

In [42]:
# Two operands, each with only the (j=1, l=2) coefficient set.
coeffs_dict1 = VSHCoeffs(parity=-1)
coeffs_dict1[(1, 2)] = e3nn.IrrepsArray("1o", jnp.asarray([1.0, 0.0, 0.0]))

coeffs_dict2 = VSHCoeffs(parity=-1)
coeffs_dict2[(1, 2)] = e3nn.IrrepsArray("1o", jnp.asarray([0.0, 1.0, 1.0]))

# Evaluate the dot product in both operand orders; the dot product is
# symmetric, so the two results should agree.
(
    dot_gaunt_tensor_product(coeffs_dict1, coeffs_dict2, output_lmax=4),
    dot_gaunt_tensor_product(coeffs_dict2, coeffs_dict1, output_lmax=4),
)

(1x0e+1x1o+1x2e+1x3o+1x4e
 [ 0.     0.     0.     0.     0.387  0.387  0.     0.     0.    -0.
  -0.     0.     0.     0.     0.     0.    -0.     0.     0.     0.
   0.     0.     0.     0.     0.   ],
 1x0e+1x1o+1x2e+1x3o+1x4e
 [ 0.     0.     0.     0.     0.387  0.387  0.     0.     0.    -0.
  -0.     0.     0.     0.     0.     0.    -0.     0.     0.     0.
   0.     0.     0.     0.     0.   ])

## Full Vector Gaunt Tensor Product

- Go from irreps1 -> VSHCoeffs1 using a Linear layer,
and similarly from irreps2 -> VSHCoeffs2.

- Take the cross product of VSHCoeffs1 and VSHCoeffs2 to get VSHCoeffs3.

- Finally, go from VSHCoeffs3 -> irreps3 using a Linear layer.

In [43]:
import haiku as hk

@hk.without_apply_rng
@hk.transform
def full_gaunt_tensor_product(
    coeffs1: e3nn.IrrepsArray,
    coeffs2: e3nn.IrrepsArray,
    parity1: int,
    parity2: int,
    output_lmax: int,
) -> e3nn.IrrepsArray:
    """Learnable cross product: Linear -> VSH cross product -> Linear.

    Each input is mapped by an ``e3nn.haiku.Linear`` layer onto the VSH irreps
    for its own ``lmax`` and the given parity. The two results are read as
    ``VSHCoeffs`` and combined with ``cross_gaunt_tensor_product``. A final
    ``Linear`` layer maps the product onto the VSH irreps for ``output_lmax``
    and parity ``parity1 * parity2``.

    The function is wrapped by ``hk.transform`` and ``hk.without_apply_rng``,
    so it is used through its ``init`` and ``apply`` attributes.

    Args:
        coeffs1: First input irreps array.
        coeffs2: Second input irreps array.
        parity1: Parity used for the VSH irreps of the first input.
        parity2: Parity used for the VSH irreps of the second input.
        output_lmax: ``lmax`` of the cross product and of the output irreps.

    Returns:
        The output of the final ``Linear`` layer.
    """
    # The three Linear modules are created in a fixed order (input 1, input 2,
    # output) so that their haiku parameter names stay stable.
    vsh_irreps1 = get_vsh_irreps(coeffs1.irreps.lmax, parity=parity1)
    vsh_array1 = e3nn.haiku.Linear(vsh_irreps1, force_irreps_out=True)(coeffs1)

    vsh_irreps2 = get_vsh_irreps(coeffs2.irreps.lmax, parity=parity2)
    vsh_array2 = e3nn.haiku.Linear(vsh_irreps2, force_irreps_out=True)(coeffs2)

    crossed = cross_gaunt_tensor_product(
        VSHCoeffs.from_irreps_array(vsh_array1),
        VSHCoeffs.from_irreps_array(vsh_array2),
        output_lmax=output_lmax,
    )
    crossed_array = crossed.to_irreps_array()

    output_irreps = get_vsh_irreps(output_lmax, parity=parity1 * parity2)
    return e3nn.haiku.Linear(output_irreps)(crossed_array)


# Two random inputs with irreps e3nn.s2_irreps(3), drawn with fixed PRNG keys.
coeffs1 = e3nn.normal(e3nn.s2_irreps(3), jax.random.PRNGKey(0))
coeffs2 = e3nn.normal(e3nn.s2_irreps(3), jax.random.PRNGKey(1))

# Initialise the parameters of the transformed function; the trailing `params`
# expression displays them as the cell output.
params = full_gaunt_tensor_product.init(
    jax.random.PRNGKey(0),
    coeffs1,
    coeffs2,
    parity1=-1,
    parity2=-1,
    output_lmax=4,
)
params

{'linear': {'w[0,0] 1x0e,1x0e': Array([[-1.252]], dtype=float32),
  'w[1,1] 1x1o,1x1o': Array([[-0.587]], dtype=float32),
  'w[1,3] 1x1o,1x1o': Array([[0.486]], dtype=float32),
  'w[2,4] 1x2e,1x2e': Array([[0.217]], dtype=float32),
  'w[2,6] 1x2e,1x2e': Array([[-0.649]], dtype=float32),
  'w[3,7] 1x3o,1x3o': Array([[1.34]], dtype=float32),
  'w[3,9] 1x3o,1x3o': Array([[1.038]], dtype=float32)},
 'linear_1': {'w[0,0] 1x0e,1x0e': Array([[-0.754]], dtype=float32),
  'w[1,1] 1x1o,1x1o': Array([[-0.605]], dtype=float32),
  'w[1,3] 1x1o,1x1o': Array([[-2.35]], dtype=float32),
  'w[2,4] 1x2e,1x2e': Array([[-2.088]], dtype=float32),
  'w[2,6] 1x2e,1x2e': Array([[1.633]], dtype=float32),
  'w[3,7] 1x3o,1x3o': Array([[2.049]], dtype=float32),
  'w[3,9] 1x3o,1x3o': Array([[-2.519]], dtype=float32)},
 'linear_2': {'w[0,0] 1x0o,1x0o': Array([[0.817]], dtype=float32),
  'w[1,2] 1x1o,1x1o': Array([[0.34]], dtype=float32),
  'w[2,1] 2x1e,1x1e': Array([[1.711],
         [0.812]], dtype=float32),
  'w[2

In [44]:
# Apply the transformed function under jit. output_lmax, parity1 and parity2
# are marked static because they are plain Python ints used to build the
# irreps inside the function, not traced arrays.
jax.jit(
    full_gaunt_tensor_product.apply,
    static_argnames=["output_lmax", "parity1", "parity2"],
)(
    params,
    coeffs1,
    coeffs2,
    parity1=-1,
    parity2=-1,
    output_lmax=4,
)

1x0o+1x1e+1x1o+1x1e+1x2o+1x2e+1x2o+1x3e+1x3o+1x3e+1x4o+1x4e+1x4o
[ -0.      0.52    9.672 -11.638  -1.048   2.593   2.809   0.042   3.431
  -3.498   4.807   2.227  -5.238   4.971  -0.887  -0.389   0.988  -1.442
  -1.003  -0.483  -6.412  -6.856   3.041  -2.452   1.257  -1.146  -1.367
  -1.122  -6.539   4.073  -5.817  -6.398  -2.123   0.425  -2.39   -0.913
   1.696   0.652  -0.786  -0.666   0.782  -2.988  -3.692  -1.31    0.754
  -1.246   6.326   4.59   -6.358  -0.49   -0.943   3.148  -3.826   3.64
   1.645  -6.652   3.781   5.317   1.501  -6.108   6.904  -2.084   5.649
   9.381  -7.089  -5.032   3.902   0.96   -1.924  -4.059   3.824  -2.302
   0.619]