In [None]:
import jax 
import jax.numpy as jnp 
from jaxtyping import Float, Key, Array 
import numpy as np 
import pandas as pd 
import plotly.express as px 
from plotly import graph_objects as go
from plotly.subplots import make_subplots
from scipy.special import sph_harm 

In [None]:
def rotate(sph):
    """
    Rotate points on the sphere by +pi/2 about the x-axis (a "roll").

    Each point is mapped to Cartesian coordinates on the unit sphere, rotated
    via (x, y, z) -> (x, -z, y), and mapped back to spherical angles.

    Parameters:
    sph (jnp.ndarray): Array whose last axis holds (colatitude, longitude) in
        radians.

    Returns:
    jnp.ndarray: Array of shape ``sph.shape[:-1] + (2,)`` holding the rotated
    (colatitude, longitude). Colatitude lies in [0, pi]; longitude lies in
    [-pi, pi], the range of ``arctan2``.
    """
    theta = sph[..., 0]
    phi = sph[..., 1]

    # Cartesian components of the unit vector for each (theta, phi).
    sin_theta = jnp.sin(theta)
    cart_x = sin_theta * jnp.cos(phi)
    cart_y = sin_theta * jnp.sin(phi)
    cart_z = jnp.cos(theta)

    # After the rotation the new z is the old y, and the new (x, y) is (x, -z).
    rotated_angles = (jnp.arccos(cart_y), jnp.arctan2(-cart_z, cart_x))
    return jnp.stack(rotated_angles, axis=-1)



def reversed_spherical_harmonic(sph, m: int, n: int):
    """
    Real part of a SciPy spherical harmonic evaluated with its two angles swapped.

    SciPy's legacy ``sph_harm(m, n, theta, phi)`` takes ``theta`` as the
    azimuthal angle and ``phi`` as the polar angle. This helper passes the
    colatitude ``sph[..., 0]`` in the azimuthal slot and the longitude
    ``sph[..., 1]`` in the polar slot, which is the reverse of the usual pairing.
    NOTE(review): presumably intentional given the function name; confirm.

    Parameters:
    sph (array-like): Array whose last axis holds (colatitude, longitude).
    m (int): Order of the harmonic.
    n (int): Degree of the harmonic.

    Returns:
    jnp.ndarray: Real part of the harmonic, shaped like ``sph[..., 0]``.

    The angles are converted with ``np.asarray`` and evaluated by SciPy, so this
    cannot be used inside ``jax.jit``-traced code.
    """
    azimuth_slot = np.asarray(sph[..., 0])
    polar_slot = np.asarray(sph[..., 1])
    harmonic = sph_harm(m, n, azimuth_slot, polar_slot)
    return jnp.asarray(np.real(harmonic))


def target_f__reversed_spherical_harmonic(sph: Float) -> Float:
    """
    Synthetic target: sum of two "reversed" spherical-harmonic terms.

    The first term is the (m=1, n=2) harmonic at the given points. The second is
    the (m=1, n=1) harmonic evaluated at the points after ``rotate``.

    Parameters:
    sph: Array whose last axis holds (colatitude, longitude).

    Returns:
    Elementwise sum of the two terms, shaped like ``sph[..., 0]``.
    """
    degree_two_term = reversed_spherical_harmonic(sph, m=1, n=2)
    rotated_degree_one_term = reversed_spherical_harmonic(rotate(sph), m=1, n=1)
    return degree_two_term + rotated_degree_one_term


@jax.jit
def car_to_sph(car):
    """
    Convert Cartesian points to spherical angles (colatitude, longitude).

    Parameters:
    car (jnp.ndarray): Array whose last axis holds (x, y, z). The colatitude is
        computed as arccos(z), which is only correct for unit-norm points.

    Returns:
    jnp.ndarray: Array of shape ``car.shape[:-1] + (2,)`` holding colatitude in
    [0, pi] and longitude in [-pi, pi] (the range of ``arctan2``).
    """
    x, y, z = car[..., 0], car[..., 1], car[..., 2]
    # Unit vectors read from disk or normalised in floating point can have |z|
    # marginally above 1, where arccos returns NaN; clamp to its valid domain.
    colat = jnp.arccos(jnp.clip(z, -1.0, 1.0))
    lon = jnp.arctan2(y, x)
    return jnp.stack([colat, lon], axis=-1)


def add_noise(f: Float[Array, " N"], noise_std: float = 0.01, *, key: Key) -> Float:
    """
    Return ``f`` perturbed by i.i.d. zero-mean Gaussian noise.

    Parameters:
    f: Clean values; the noise is drawn with the same shape.
    noise_std (float): Standard deviation of the added noise.
    key: JAX PRNG key used to draw the noise (keyword-only).

    Returns:
    ``f`` plus ``noise_std`` times elementwise standard-normal samples.
    """
    standard_normal = jax.random.normal(key=key, shape=f.shape)
    return f + standard_normal * noise_std

In [None]:
# Configuration for the synthetic dataset.
SEED = 0
INPUTS_CSV = '../std_inputs.csv'  # relative to the notebook's working directory
NOISE_STD = 0.01

key = jax.random.key(SEED)

# One (x, y, z) point per row. car_to_sph computes colatitude as arccos(z), which
# assumes unit-norm rows. NOTE(review): confirm the file holds unit vectors.
x = jnp.asarray(pd.read_csv(INPUTS_CSV, header=None, names=['x', 'y', 'z']).to_numpy())
# Rows with fewer than three values are filled with NaN by read_csv; fail loudly
# instead of propagating NaN into the target.
assert not bool(jnp.isnan(x).any()), f"{INPUTS_CSV} contains missing coordinates"

f = target_f__reversed_spherical_harmonic(car_to_sph(x))
y = add_noise(f, noise_std=NOISE_STD, key=key)

In [None]:
# 3-D scatter of the input points, coloured by the negated noisy target (-y).
# NOTE(review): the reason for plotting -y rather than y is not stated; confirm.
fig = go.Figure(
    data=[
        go.Scatter3d(
            x=x[:, 0],
            y=x[:, 1],
            z=x[:, 2],
            mode='markers',
            marker={'color': -y, 'size': 3, 'colorscale': 'magma'},
        )
    ]
)

# Hide all three scene axes so only the point cloud is drawn.
fig.update_layout(
    scene_xaxis_visible=False,
    scene_yaxis_visible=False,
    scene_zaxis_visible=False,
    width=1300,
    height=600,
    showlegend=False,
)
fig.write_image("synthetic_target_function.pdf")
fig.show()

In [None]:
# Save the synthetic dataset as header-less CSV files: the input points and the
# negated noisy targets (the same sign used for the plot colour above).
data = [
    x,
    -y,
]

names = [
    'synthetic_y-inputs',
    'synthetic_y-outputs',
]

# zip() silently drops unmatched entries, so check the pairing explicitly.
assert len(data) == len(names), "each saved array needs exactly one file name"

for datum, name in zip(data, names):
    # Convert JAX arrays to NumPy explicitly rather than relying on pandas'
    # handling of foreign array types.
    pd.DataFrame(np.asarray(datum)).to_csv(f"{name}.csv", header=False, index=False)