In [3]:
from ase.atoms import Atoms
import ase.io
from ase.visualize import view
import e3nn_jax as e3nn
import ipywidgets as widgets
import jax
import jax.numpy as jnp
import jraph
import numpy as np
import plotly.express as px
import plotly.graph_objects as go

import sys

sys.path.append("../analyses")
sys.path.append("..")
import analysis
import datatypes
import input_pipeline_tf
import train

In [4]:
def to_spherical_coordinates(coords):
    """Returns the polar and azimuthal angles ``(theta, phi)`` of a point in 3D space.

    Args:
        coords: A point ``(x, y, z)``. It is unpacked along its first axis, so any
            sequence/array whose leading dimension has length 3 works.

    Returns:
        A tuple ``(theta, phi)``. ``theta`` in [0, pi] is the angle from the +z axis.
        ``phi`` in [-pi, pi] is the angle of the xy-projection, measured from +x
        towards +y. The radius is not returned.
    """
    x, y, z = coords
    # atan2(rho, z) equals arccos(z / r) mathematically, but stays finite at the
    # origin (where z / r is 0/0 -> NaN). It also cannot fall outside arccos's
    # [-1, 1] domain through floating-point rounding, which would also yield NaN.
    theta = jnp.arctan2(jnp.hypot(x, y), z)
    phi = jnp.arctan2(y, x)
    return theta, phi

def p_unnormalized(theta: float, phi: float) -> float:
    """Evaluates the unnormalized test function ``sin(theta) + cos(phi)``.

    Args:
        theta: Polar angle in radians.
        phi: Azimuthal angle in radians.

    Returns:
        ``sin(theta) + cos(phi)``. For theta in [0, pi] the value lies in [-1, 2],
        so it can be negative and is not a probability density as-is.
    """
    polar_term = jnp.sin(theta)
    azimuthal_term = jnp.cos(phi)
    return polar_term + azimuthal_term

In [5]:
# Sample p(theta, phi) = sin(theta) + cos(phi) on a SOFT quadrature grid with
# `resolution = (res_beta, res_alpha)` points.
resolution = (100, 359)
# The zero signal is used only to obtain the grid's direction vectors.
grid = e3nn.SphericalSignal.zeros(*resolution, quadrature="soft")
thetas, phis = jax.vmap(jax.vmap(to_spherical_coordinates))(grid.grid_vectors)
# Build a new signal rather than assigning to `.grid_values`, so the values go
# through the SphericalSignal constructor instead of bypassing it. Quadrature
# and parity are copied from the grid the values were sampled on.
p_signal = e3nn.SphericalSignal(
    jax.vmap(jax.vmap(p_unnormalized))(thetas, phis),
    quadrature=grid.quadrature,
    p_val=grid.p_val,
    p_arg=grid.p_arg,
)

In [16]:
# Colour limits are pinned to the analytic range of sin(theta) + cos(phi),
# i.e. [-1, 2], matching the colour limits used for the SH reconstruction plot.
go.Figure(
    [
        go.Surface(
            p_signal.plotly_surface(scale_radius_by_amplitude=True),
            cmin=-1,
            cmax=2,
        )
    ],
    layout=dict(title="Sampled signal: sin(theta) + cos(phi)"),
)

In [21]:
# Project the sampled signal onto spherical harmonics up to l = 4, then evaluate
# the truncated expansion back on the same grid. The irreps' parities alternate
# (e, o, e, ...) as (-1)^l, the parity pattern of Y_l for p_val=1, p_arg=-1.
# sin(theta) + cos(phi) is not band-limited, so the reconstruction only
# approximates the original signal.
coeffs = e3nn.from_s2grid(p_signal, '0e + 1o + 2e + 3o + 4e')
# Reuse the source signal's quadrature and parity instead of repeating literals,
# so the reconstruction always lives on the same kind of grid.
signal_sh = e3nn.to_s2grid(
    coeffs,
    *resolution,
    quadrature=p_signal.quadrature,
    p_val=p_signal.p_val,
    p_arg=p_signal.p_arg,
)

# Same colour limits as the sampled-signal plot, [-1, 2], for direct comparison.
go.Figure(
    [
        go.Surface(
            signal_sh.plotly_surface(scale_radius_by_amplitude=True),
            cmin=-1,
            cmax=2,
        )
    ],
    layout=dict(title="Spherical-harmonic reconstruction (lmax = 4)"),
)