In [1]:
from typing import Optional

import jax
import jax.numpy as jnp
import e3nn_jax as e3nn
import plotly.graph_objects as go

from freedom.tensor_products.vector_spherical_harmonics import VSHCoeffs

In [2]:
# Two random VSH coefficient sets that share jmax and parity and differ only in
# their PRNG seed (0 for the first, 1 for the second).
vsh_coeffs_1, vsh_coeffs_2 = (
    VSHCoeffs.normal(jmax=3, parity=-1, key=jax.random.PRNGKey(seed))
    for seed in (0, 1)
)
vsh_coeffs_1, vsh_coeffs_2

(VSH Coefficients
  1x0e: [[-1.4134908]]
  2x1o: [[ 1.2222526  -1.1022004   0.49160266]
  [ 0.08187129  0.81290984  0.21408351]]
  1x1e: [[-1.70266    -0.96234554  1.8774409 ]]
  2x2e: [[-0.03871189  0.11073402  1.0138445   1.509118   -1.9148264 ]
  [-1.4499495  -0.59010714 -0.6845358  -0.35398695  0.3026237 ]]
  1x2o: [[-0.61889184 -0.11229677  1.4682506  -0.00166103  0.2519778 ]]
  1x3o: [[-0.17617968  0.06397271  1.4722321   1.0775082  -0.9310467  -1.4745097
    1.5775032 ]]
  1x3e: [[ 0.1934331   0.7253706   0.6590064   1.7567376   0.05851391 -0.8722071
    0.15762478]]
  1x4e: [[-0.83682615  0.19154705 -0.5562989  -1.3756475  -0.78326577  0.6363154
    0.2544184  -0.90707594  0.1912001 ]],
 VSH Coefficients
  1x0e: [[2.292836]]
  2x1o: [[ 0.34930295  0.51804686 -1.5153265 ]
  [-1.6210418  -0.83535033  1.8372124 ]]
  1x1e: [[ 1.8336425  -1.6708398   0.15234356]]
  2x2e: [[-0.3966134  -1.9082907   1.1256224  -0.11723752 -0.82958674]
  [-0.90526086 -0.20008144  1.5660174  -0.6564076 

In [3]:
# Grid settings shared by every coefficient -> vector-signal conversion and by
# the pointwise cross-product reduction in this notebook.
grid_kwargs = {"res_beta": 100, "res_alpha": 99, "quadrature": "soft"}

In [4]:
# Convert both coefficient sets with `to_vector_signal`, using the shared grid settings.
sig_1, sig_2 = (
    coeffs.to_vector_signal(**grid_kwargs)
    for coeffs in (vsh_coeffs_1, vsh_coeffs_2)
)

In [5]:
def plot_vector_signal(
    sig: e3nn.SphericalSignal,
    scale_vec: float = 0.001,
    title: Optional[str] = None,
    sizeref: float = 5,
) -> go.Figure:
    """Plot a vector spherical signal as a 3D plotly cone field.

    One cone is drawn at every grid point of ``sig``. Its direction and length
    come from the (scaled) vector value of the signal at that point.

    Args:
        sig: The vector spherical signal to plot.
            ``sig.grid_vectors`` must be rank 3 with the xyz coordinates on its
            last axis: it is transposed with ``(2, 0, 1)`` and flattened to
            ``(3, N)``.
            ``sig.grid_values`` is flattened to ``(3, N)``. This assumes it is
            laid out as ``(3, res_beta, res_alpha)`` so that its points line up
            with the flattened grid. TODO: confirm against
            ``VSHCoeffs.to_vector_signal``.
        scale_vec: Factor multiplied into every vector component before plotting.
        title: Optional figure title. No title is set when ``None``.
        sizeref: Forwarded to ``go.Cone(sizeref=...)``. With
            ``sizemode="absolute"`` it controls the cone size. The default of 5
            matches the previously hard-coded value.

    Returns:
        The ``plotly.graph_objects.Figure`` containing a single ``Cone`` trace.
    """
    # Row k of `grid` holds coordinate k (x, y, z) of every grid point.
    grid = sig.grid_vectors.transpose((2, 0, 1)).reshape((3, -1))
    # Row k of `values` holds vector component k at every grid point.
    # NOTE(review): this only matches `grid` if grid_values is (3, res_beta, res_alpha).
    values = sig.grid_values.reshape((3, -1))

    fig = go.Figure()
    fig.add_trace(
        go.Cone(
            x=grid[0, :],
            y=grid[1, :],
            z=grid[2, :],
            u=scale_vec * values[0, :],
            v=scale_vec * values[1, :],
            w=scale_vec * values[2, :],
            colorscale="Viridis",
            sizemode="absolute",
            sizeref=sizeref,
            showscale=True,
            hoverinfo="skip",
        )
    )
    if title is not None:
        fig.update_layout(title=title)
    return fig

In [6]:
# Visualise the first random signal.
plot_vector_signal(sig=sig_1, title="Signal 1")

In [7]:
# Visualise the second random signal.
plot_vector_signal(sig=sig_2, title="Signal 2")

In [8]:
# Reduce the pointwise cross product of the two coefficient sets.
# The grid settings are passed the same way as in the signal conversions above.
cross_coeffs = vsh_coeffs_1.reduce_pointwise_cross_product(
    vsh_coeffs_2,
    **grid_kwargs,
)
cross_coeffs

PVSH Coefficients
 1x0o: [[11.5233]]
 1x1o: [[-3.282505   2.1873672  3.5948331]]
 2x1e: [[-2.0042558  2.4800336 13.664726 ]
 [ 6.4210157  3.865445   3.3846214]]
 1x2e: [[ -1.7245905  -2.2699032 -11.30559    -7.3121905   2.0841808]]
 2x2o: [[-4.975041   -1.8101524   4.0745497  -1.123744    1.9915977 ]
 [-4.93644    -0.07588302 -0.26268974 -1.2078125   3.131565  ]]
 1x3o: [[ 2.1096225   0.7251361  -9.505152   -0.02353078  0.590553    2.4604669
   2.5813022 ]]
 2x3e: [[-4.1096396  6.2000246  4.5171714  4.810617  -0.4946092 -1.7111585
   5.9859715]
 [-2.4308183 -1.897345   5.5419908 -3.496013  -5.955441   1.690079
   4.7399926]]
 1x4e: [[ 1.1237953  -1.2570901   1.1700044   3.9101608   0.15236849 -6.349345
   0.83392704 -1.9390092  -1.9559363 ]]
 2x4o: [[ -0.15837082  -0.15826863   4.378097     1.4067827    5.984352
   -1.1065944  -10.233195    -1.6938962   -0.42502373]
 [ -3.6378138   -0.91345054   1.8374575    2.1341555   -0.6339636
   -1.1281043   -4.7645607   -2.3132722    1.8804944 ]]

In [9]:
# Convert the cross-product coefficients with the shared grid settings, then plot them.
cross_sig = cross_coeffs.to_vector_signal(
    **grid_kwargs
)
plot_vector_signal(
    sig=cross_sig,
    title="Cross Product",
)