In [None]:
from aux_functions import *

import jax
from jax.scipy.special import sph_harm
import pandas as pd
import os
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm, colors
import matplotlib.pyplot as plt
from functools import partial

from scipy.special import sph_harm

In [None]:
# Input CSVs live in `blender_folder`; `out_folder` is the output directory for this notebook.
blender_folder = "blender-data"
out_folder = os.path.join("blender-data", "outputs")

# Each CSV has no header row; its columns are read as x, y, z.
mean_inputs, more_mean_inputs, std_inputs = (
    pd.read_csv(os.path.join(blender_folder, csv_name), names=["x", "y", "z"]).to_numpy()
    for csv_name in ("mean_inputs.csv", "input_locations.csv", "std_inputs.csv")
)

mean_inputs.shape, more_mean_inputs.shape

In [None]:
# Spherical coordinates of the `mean_inputs` points (`car_to_sph` is not defined in this notebook; presumably from `aux_functions`).
# NOTE(review): no later cell reads this global `sph` — the functions below build their own local `sph`.
sph = car_to_sph(mean_inputs)

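### Some spherical harmonics

Real spherical harmonics $Y_{\ell m}$ written in Cartesian coordinates and restricted to the unit sphere (so $r = 1$ in the tabulated formulas, e.g. $7z^2 - r^2 \to 7z^2 - 1$).

Taken from https://en.wikipedia.org/wiki/Table_of_spherical_harmonics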

In [None]:
@jax.jit
def Y11(car):
    """Real spherical harmonic Y_{1,1} at a point ``car = (x, y, z)``: sqrt(3 / (4 pi)) * x."""
    x = car[0]
    return jnp.sqrt(3 / (4 * jnp.pi)) * x

@jax.jit
def Y32(car):
    """Real spherical harmonic Y_{3,2} on the unit sphere: (1/4) sqrt(105 / pi) (x^2 - y^2) z."""
    x, y, z = car[0], car[1], car[2]
    prefactor = .25 * jnp.sqrt(105 / jnp.pi)
    return prefactor * ((x**2 - y**2) * z)

@jax.jit
def Y42(car):
    """Real spherical harmonic Y_{4,2} on the unit sphere: (3/8) sqrt(5 / pi) (x^2 - y^2) (7 z^2 - 1)."""
    x, y, z = car[0], car[1], car[2]
    prefactor = 3 / 8 * jnp.sqrt(5 / jnp.pi)
    return prefactor * ((x**2 - y**2) * (7 * z**2 - 1))

In [None]:
@jax.jit
def _single_proj_to_sphere(v, tb):
    return tb @ tb.T @ v

@jax.jit
def _star(v, car):
    return jnp.cross(car, v)

@jax.jit
def d_Y11(car, tangent_basis):
    fun = Y11
    # gradient in R3
    vs = jax.vmap(jax.jacfwd(fun, argnums=0))(car)
    # projection to sphere
    vs = jax.vmap(_single_proj_to_sphere)(vs, tangent_basis)
    return vs

@jax.jit
def sd_Y11(car, tangent_basis):
    fun = Y11
    # gradient in R3
    vs = jax.vmap(jax.jacfwd(fun, argnums=0))(car)
    # projection to sphere
    vs = jax.vmap(_single_proj_to_sphere)(vs, tangent_basis)
    vs = jax.vmap(_star)(vs, car)
    return vs

In [None]:
@jax.jit
def d_Y32(car, tangent_basis):
    """Surface gradient of Y32 at each row of ``car``.

    Forward-mode autodiff gives the R^3 gradient, which is then projected per
    point with ``_single_proj_to_sphere`` and the matching ``tangent_basis``.
    """
    euclidean_grad = jax.vmap(jax.jacfwd(Y32))(car)
    return jax.vmap(_single_proj_to_sphere)(euclidean_grad, tangent_basis)

@jax.jit
def sd_Y32(car, tangent_basis):
    """``car x (surface gradient of Y32)`` at each row of ``car``."""
    return jax.vmap(_star)(d_Y32(car, tangent_basis), car)

In [None]:
@jax.jit
def d_Y42(car, tangent_basis):
    """Surface gradient of Y42 at each row of ``car``.

    Forward-mode autodiff gives the R^3 gradient, which is then projected per
    point with ``_single_proj_to_sphere`` and the matching ``tangent_basis``.
    """
    euclidean_grad = jax.vmap(jax.jacfwd(Y42))(car)
    return jax.vmap(_single_proj_to_sphere)(euclidean_grad, tangent_basis)

@jax.jit
def sd_Y42(car, tangent_basis):
    """``car x (surface gradient of Y42)`` at each row of ``car``."""
    return jax.vmap(_star)(d_Y42(car, tangent_basis), car)

In [None]:
@jax.jit
def Y73(car):
    """Real spherical harmonic Y_{7,3} at a point ``car = (x, y, z)`` on the unit sphere.

        Y_{7,3} = (3/64) sqrt(35/pi) (x^3 - 3 x y^2) (143 z^4 - 66 z^2 + 3)

    This is sqrt(2) * (-1)^3 * Re(Y_7^3) (Condon-Shortley phase), rewritten with
    sin^3(theta) cos(3 phi) = x^3 - 3 x y^2 and cos(theta) = z, i.e. with r = 1.
    It is L2-normalised on the sphere and has Laplace-Beltrami eigenvalue
    7 * 8, matching the ``ell=7, m=3`` labels used when exporting it.
    """
    # NOTE: the earlier formula used the m = 4 prefactor and angular part and also
    # multiplied by (1 - z^2)^2, counting sin^4(theta) twice, so it was not an
    # eigenfunction at all.
    x, y, z = car
    return 3 / 64 * jnp.sqrt(35 / jnp.pi) * (
        (x**3 - 3 * x * y**2) * (143 * z**4 - 66 * z**2 + 3)
    )

@jax.jit
def d_Y73(car, tangent_basis):
    """Surface gradient of Y73 at each row of ``car``.

    Forward-mode autodiff gives the R^3 gradient, which is then projected per
    point with ``_single_proj_to_sphere`` and the matching ``tangent_basis``.
    """
    euclidean_grad = jax.vmap(jax.jacfwd(Y73))(car)
    return jax.vmap(_single_proj_to_sphere)(euclidean_grad, tangent_basis)

@jax.jit
def sd_Y73(car, tangent_basis):
    """``car x (surface gradient of Y73)`` at each row of ``car``."""
    return jax.vmap(_star)(d_Y73(car, tangent_basis), car)

In [None]:
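# Unused cross-check idea: a real harmonic (m > 0) from SciPy's complex one:
# jnp.sqrt(2) * (-1)**m * sph_harm(jnp.array([m]), jnp.array([ell]), phi, theta).real
# NOTE(review): scipy.special.sph_harm takes (m, n, azimuthal angle, polar angle) — confirm phi/theta follow that order here.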

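### Plotting to check

For each harmonic, three panels are plotted:
- the scalar field on the sphere;
- its tangential gradient (the curl-free field);
- the gradient rotated by the normal, $\hat n \times \nabla Y$ (the divergence-free field).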

In [None]:
def _sphere_grid(n_points):
    """Return ``(u, sph)`` for an ``n_points x n_points`` grid of two angles.

    ``u`` spans [-pi/2, pi/2] and the second angle spans [0, 2*pi], endpoints
    included (complex step in ``np.mgrid``). ``sph`` stacks the two flattened
    angle arrays as columns, shape ``(n_points**2, 2)``; ``u`` is returned so
    callers can recover the 2D grid shape.
    """
    step = n_points * 1j
    u, v = np.mgrid[(-np.pi / 2):(np.pi / 2):step, 0:2 * np.pi:step]
    return u, np.stack([u.flatten(), v.flatten()]).T


def plot_harmonic(fun, d_fun, sd_fun, v_scale=1., n_surface=100, n_arrows=25, arrow_offset=1.01):
    """Plot a scalar field on the sphere together with two derived vector fields.

    The figure has three stacked 3D panels. Each panel shows the sphere coloured
    by ``fun`` (coolwarm colormap). Panel 2 adds ``d_fun`` as arrows and panel 3
    adds ``sd_fun``. Arrows are only drawn for points with y < 0.

    Args:
        fun: scalar function of one point ``(x, y, z)``; mapped over points with ``jax.vmap``.
        d_fun, sd_fun: vector-field functions called as ``f(car, tangent_basis)``
            on a batch of points, returning one 3-vector per point.
        v_scale: factor applied to every arrow.
        n_surface: grid points per angle for the coloured surface.
        n_arrows: grid points per angle for the (coarser) arrow grid.
        arrow_offset: radial factor applied to arrow tails so they sit just
            outside the surface.

    Returns:
        The matplotlib ``Figure``.
    """
    # Coloured surface: the same scalar field is drawn on all three panels.
    u, sph = _sphere_grid(n_surface)
    car = sph_to_car(sph)
    x, y, z = (c.reshape(u.shape) for c in car.T)
    vals = jax.vmap(fun)(car).reshape(u.shape)

    fig = plt.figure(figsize=(8, 24))
    axes = [fig.add_subplot(311 + i, projection='3d') for i in range(3)]
    norm = colors.Normalize(vmin=np.min(vals), vmax=np.max(vals), clip=False)
    face_colors = cm.coolwarm(norm(vals))
    titles = (
        "scalar field",
        "tangential gradient (curl-free)",
        "rotated gradient (divergence-free)",
    )
    for ax, title in zip(axes, titles):
        ax.plot_surface(
            x, y, z, rstride=1, cstride=1, cmap=cm.coolwarm,
            linewidth=0, antialiased=False, facecolors=face_colors,
        )
        ax.set_title(title)

    # Arrows use a coarser grid so the vector fields stay readable.
    _, sph = _sphere_grid(n_arrows)
    tangent_basis = sphere_tangent_basis(sph)
    car = sph_to_car(sph)
    x, y, z = (c.flatten() for c in car.T * arrow_offset)
    mask = y < 0  # keep only the y < 0 half; the other half is presumably hidden "behind" the sphere
    for ax, field_fun in zip(axes[1:], (d_fun, sd_fun)):
        vf = field_fun(car, tangent_basis)[mask] * v_scale
        ax.quiver(x[mask], y[mask], z[mask], vf[:, 0], vf[:, 1], vf[:, 2], color="k")
    return fig

In [None]:
fig = plot_harmonic(Y11, d_Y11, sd_Y11, v_scale=0.5)

In [None]:
fig = plot_harmonic(Y32, d_Y32, sd_Y32, v_scale=0.2)

In [None]:
fig = plot_harmonic(Y42, d_Y42, sd_Y42, v_scale=0.1)

In [None]:
fig = plot_harmonic(Y73, d_Y73, sd_Y73, v_scale=0.1)

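### Blender data generation

For each harmonic, write CSVs to `blender-data/outputs`:
- the curl-free (gradient) and divergence-free (rotated gradient) vector fields, each divided by $\sqrt{\ell(\ell+1)}$;
- the scalar field values.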

In [None]:
def blender_eigenfunctions(Y, d_Y, sd_Y, ell, m, inputs=mean_inputs,
                           std_locations=None, out_folder=None):
    """Evaluate a spherical harmonic and its two vector fields and write them as CSVs.

    Four files are written to ``out_folder``, which is created if missing:

    * ``eigenfield_ell=<ell>_m=<m>_curl_free__mean.csv``: each point of ``inputs``
      followed by ``d_Y`` at that point (rows ``[x, y, z, vx, vy, vz]``);
    * ``eigenfield_ell=<ell>_m=<m>_div_free__mean.csv``: same layout for ``sd_Y``;
    * ``..._curl_free__std.csv`` and ``..._div_free__std.csv``: both hold the
      scalar ``Y`` evaluated at each point of ``std_locations``.

    Both vector fields are divided by ``sqrt(ell * (ell + 1))``. For an
    L2-normalised eigenfunction with eigenvalue ``ell * (ell + 1)``, this
    presumably yields unit-norm fields.

    Args:
        Y: scalar function of one point; mapped over points with ``jax.vmap``.
        d_Y, sd_Y: vector-field functions called as ``f(points, tangent_basis)``.
        ell, m: degree and order; ``ell`` sets the normalisation, both appear in file names.
        inputs: points at which the vector fields are evaluated (default ``mean_inputs``).
        std_locations: points at which ``Y`` is evaluated; ``None`` uses the
            global ``std_inputs``, as before.
        out_folder: output directory; ``None`` uses ``blender-data/outputs``.

    Raises:
        ValueError: if ``ell < 1``, where the normalisation would divide by zero.
    """
    if ell < 1:
        raise ValueError(f"ell must be >= 1 to normalise by sqrt(ell * (ell + 1)), got {ell}")
    if std_locations is None:
        std_locations = std_inputs
    if out_folder is None:
        out_folder = os.path.join("blender-data", "outputs")
    # np.savetxt does not create missing directories, so make sure the target exists.
    os.makedirs(out_folder, exist_ok=True)

    # Scalar field, written to both "__std" files.
    f = jax.vmap(Y)(std_locations)

    # Vector fields at `inputs`, normalised by sqrt(ell * (ell + 1)).
    tangent_basis = sphere_tangent_basis(car_to_sph(inputs))
    scale = np.sqrt(ell * (ell + 1))
    curl_free = d_Y(inputs, tangent_basis) / scale
    div_free = sd_Y(inputs, tangent_basis) / scale

    np.savetxt(os.path.join(out_folder, f"eigenfield_{ell=}_{m=}_curl_free__mean.csv"), np.hstack([inputs, curl_free]), delimiter=",")
    np.savetxt(os.path.join(out_folder, f"eigenfield_{ell=}_{m=}_curl_free__std.csv"), f, delimiter=",")
    np.savetxt(os.path.join(out_folder, f"eigenfield_{ell=}_{m=}_div_free__mean.csv"), np.hstack([inputs, div_free]), delimiter=",")
    np.savetxt(os.path.join(out_folder, f"eigenfield_{ell=}_{m=}_div_free__std.csv"), f, delimiter=",")

In [None]:
blender_eigenfunctions(Y=Y11, d_Y=d_Y11, sd_Y=sd_Y11, ell=1, m=1)

In [None]:
blender_eigenfunctions(Y=Y32, d_Y=d_Y32, sd_Y=sd_Y32, ell=3, m=2)

In [None]:
blender_eigenfunctions(Y=Y42, d_Y=d_Y42, sd_Y=sd_Y42, ell=4, m=2)

In [None]:
blender_eigenfunctions(Y=Y73, d_Y=d_Y73, sd_Y=sd_Y73, ell=7, m=3, inputs=more_mean_inputs)