In [1]:
from typing import *
import e3nn_jax as e3nn
import jax
import jax.numpy as jnp
import plotly.graph_objects as go
import plotly.subplots

  jax.tree_util.register_keypaths(data_clz, keypaths)
  jax.tree_util.register_keypaths(data_clz, keypaths)


In [2]:
def plot_signal(signal: e3nn.SphericalSignal, cmin: Optional[float] = None, cmax: Optional[float] = None) -> go.Figure:
    """Draw ``signal`` as a 3D plotly surface whose radius is scaled by the signal's amplitude.

    Args:
        signal: the spherical signal to draw.
        cmin: lower bound of the colour scale; when omitted, the smallest grid value minus 0.1.
        cmax: upper bound of the colour scale; when omitted, the largest grid value plus 0.1.

    Returns:
        A ``go.Figure`` holding a single surface trace with its colour bar hidden, all three
        scene axes hidden, and the camera placed on the +z axis at distance 2.5.
    """
    # The 0.1 margin keeps the extreme grid values strictly inside the colour range.
    lower = cmin if cmin is not None else float(signal.grid_values.min()) - 0.1
    upper = cmax if cmax is not None else float(signal.grid_values.max()) + 0.1

    surface = go.Surface(
        signal.plotly_surface(scale_radius_by_amplitude=True),
        cmin=lower,
        cmax=upper,
        showscale=False,
    )
    figure = go.Figure(data=[surface])
    figure.update_layout(
        scene_camera=dict(
            up=dict(x=0, y=0, z=1),
            center=dict(x=0, y=0, z=0),
            eye=dict(x=0., y=0., z=2.5),
        ),
        # For a transparent background, additionally pass
        # paper_bgcolor="rgba(0,0,0,0)" and plot_bgcolor="rgba(0,0,0,0)".
        scene={axis_name: dict(visible=False) for axis_name in ("xaxis", "yaxis", "zaxis")},
    )
    return figure


In [3]:
# Two Dirac deltas on the unit circle in the xy-plane. The angle t maps to the point
# (sin t, -cos t, 0); the second delta sits theta_diff further along than the first.
theta_offset = -jnp.pi / 3
theta_diff = jnp.pi / 3
first_angle = theta_offset + jnp.pi / 3
second_angle = first_angle + theta_diff
dirac_points = jnp.asarray([[jnp.sin(angle), -jnp.cos(angle), 0.] for angle in (first_angle, second_angle)])

# Sum the two band-limited (lmax=8) deltas, scale by 20, then sample on a Gauss-Legendre grid.
f_coeffs = 20. * e3nn.sum(e3nn.s2_dirac(dirac_points, lmax=8, p_arg=-1, p_val=1), axis=0)
f_signal = e3nn.to_s2grid(f_coeffs, res_beta=100, res_alpha=99, quadrature="gausslegendre")
# Uncomment to work with exp(f) on the grid instead of f itself.
# f_signal = f_signal.apply(jnp.exp)

In [4]:
# Export the logits signal as a static PDF (plotly's write_image needs the kaleido engine installed).
fig = plot_signal(f_signal)
fig.write_image("spherical_harmonic_logits.pdf")

In [5]:
def get_spherical_harmonic_projections(
    coeffs: e3nn.IrrepsArray,
    *,
    res_beta: int = 100,
    res_alpha: int = 359,
) -> Iterator[Tuple["e3nn.MulIrrep", e3nn.SphericalSignal]]:
    """Split ``coeffs`` by irrep and sample each piece on a spherical grid.

    Args:
        coeffs: spherical-harmonic coefficients. Each per-irrep chunk is flattened with
            ``reshape(-1)``, so this presumably expects a single (unbatched) signal — TODO confirm
            before using it with batched input.
        res_beta: number of grid points along beta (default 100).
        res_alpha: number of grid points along alpha (default 359).

    Yields:
        ``(irrep, signal)`` pairs, one per entry of ``coeffs.regroup().irreps``. ``signal`` is that
        irrep's contribution sampled on a Gauss-Legendre grid with ``p_val=1, p_arg=-1``.
    """
    coeffs = coeffs.regroup()
    for irrep, chunk in zip(coeffs.irreps, coeffs.chunks):
        if chunk is None:
            # e3nn_jax may store chunks that are known to be zero as None; materialise them so
            # the irrep still yields a (flat) projection instead of failing on ``None.reshape``.
            chunk = jnp.zeros((irrep.dim,), dtype=coeffs.array.dtype)
        signal_coeffs = e3nn.IrrepsArray(irrep, chunk.reshape(-1))
        yield irrep, e3nn.to_s2grid(
            signal_coeffs,
            res_beta=res_beta,
            res_alpha=res_alpha,
            quadrature="gausslegendre",
            p_val=1,
            p_arg=-1,
        )
        
# Plot each irrep's share of f_coeffs on its own sphere, print the irrep, and export a PDF per degree l.
for irrep, signal in get_spherical_harmonic_projections(f_coeffs):
    fig = plot_signal(signal)
    print(irrep)
    pdf_path = f"spherical_harmonic_projection_l={irrep.ir.l}.pdf"
    fig.write_image(pdf_path)

1x0e
1x1o
1x2e
1x3o
1x4e
1x5o
1x6e
1x7o
1x8e


In [6]:
def spherical_harmonics_as_signals(l: int, res_beta: int = 50, res_alpha: int = 49) -> Iterator[e3nn.SphericalSignal]:
    """Yield each spherical harmonic of degree ``l`` as an ``e3nn.SphericalSignal``.

    Signals come out in order m = -l, ..., l. Each one is the degree-``l`` irrep with a one-hot
    coefficient vector, sampled on a SOFT grid with ``p_val=1, p_arg=-1``.

    Args:
        l: degree of the harmonics; must be non-negative.
        res_beta: grid resolution along beta (default 50).
        res_alpha: grid resolution along alpha (default 49).

    Raises:
        ValueError: if ``l`` is negative (raised when iteration starts, since this is a generator).
    """
    if l < 0:
        raise ValueError(f"degree l must be non-negative, got {l}")
    # The last irrep of s2_irreps(l) is the degree-l one; it does not depend on m, so build it once.
    irrep = e3nn.s2_irreps(l)[-1]
    # Row (m + l) of the identity matrix is the one-hot coefficient vector selecting Y_l^m.
    for one_hot in jnp.eye(2 * l + 1):
        coeffs = e3nn.IrrepsArray(irrep, one_hot)
        yield e3nn.to_s2grid(coeffs, res_beta, res_alpha, quadrature="soft", p_val=1, p_arg=-1)

def plot_spherical_harmonics(l: int, cmin: float = -2., cmax: float = 2.) -> None:
    """Show every spherical harmonic of degree ``l`` side by side in one plotly figure.

    The figure has one row of 3D surface subplots, with one column per m in -l..l titled "m = <m>".
    Each surface's radius is scaled by the harmonic's amplitude. All subplots share the same colour
    range and camera, and only the right-most subplot shows a colour bar.

    Args:
        l: degree of the harmonics.
        cmin: lower end of the shared colour scale (default -2).
        cmax: upper end of the shared colour scale (default 2).
    """
    n_cols = 2 * l + 1
    fig = plotly.subplots.make_subplots(
        rows=1,
        cols=n_cols,
        specs=[[{'type': 'surface'} for _ in range(n_cols)]],
        subplot_titles=[f"m = {m}" for m in range(-l, l + 1)],
    )

    for col, sig in enumerate(spherical_harmonics_as_signals(l), start=1):
        fig.add_trace(
            go.Surface(
                sig.plotly_surface(scale_radius_by_amplitude=True),
                cmax=cmax,
                cmin=cmin,
                # A single colour bar suffices because every subplot uses the same colour range.
                showscale=(col == n_cols),
            ),
            row=1,
            col=col,
        )

    # Apply the same diagonal viewpoint to every 3D subplot at once.
    fig.update_scenes(camera=dict(
        up=dict(x=0, y=0, z=1),
        center=dict(x=0, y=0, z=0),
        eye=dict(x=2.25, y=2.25, z=2.25),
    ))
    fig.update_layout(title=f"Spherical Harmonics of Degree l = {l}", title_x=0.5)
    fig.show()

# Render the five degree-2 harmonics (m = -2, ..., 2).
plot_spherical_harmonics(2)

In [7]:
# Interactive view of f_signal: bind the figure to `fig` and display it as the cell's output.
(fig := plot_signal(f_signal))