In [None]:
import functools
import os
import sys

import chex
import e3nn_jax as e3nn
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
import pandas as pd
import plotly.graph_objects as go
import seaborn as sns

sys.path.append('../..')

In [None]:
%load_ext autoreload

In [None]:
%autoreload 2

from symphony import loss
from symphony import models
import helpers

In [None]:
# Grid resolutions swept below: each entry of `res_alphas` is passed as
# `res_alpha` and each entry of `res_betas` as `res_beta` to `optimize_coeffs`.
res_alphas = [
    9,
    179,
    359,
    719,
]
res_betas = [
    10,
    180,
    360,
]

In [None]:
# Create a random target signal on the sphere: draw N_points random unit
# vectors and place a scaled Dirac delta (up to degree lmax) at each of them.
N_points = 5
lmax = 5
rng = jax.random.PRNGKey(0)

# Dividing Gaussian samples by their norm puts them on the unit sphere.
gaussian_samples = jax.random.normal(rng, (N_points, 3))
sample_norms = jnp.linalg.norm(gaussian_samples, axis=-1, keepdims=True)
random_points = gaussian_samples / sample_norms

signal_scale = 20
random_signal = signal_scale * e3nn.s2_dirac(random_points, lmax=lmax, p_val=1, p_arg=-1)

In [None]:
# Baseline: the distance metric obtained when sampling from the target
# distribution itself (presumably the best value a learned model can reach).
# The two resolutions are passed positionally as (res_alpha, res_beta).
reference_res_alpha, reference_res_beta = 159, 80
random_dist = helpers.average_target_distributions(random_signal, reference_res_alpha, reference_res_beta)
best_mean_dist, best_std_dist = helpers.rmse_of_samples(random_dist, random_points, rng)
print(best_mean_dist, best_std_dist)

In [None]:
# Plot the target distribution as a surface with the target points overlaid.
# Index the leading axis of the grid values so a single signal is plotted.
random_dist_copy = e3nn.SphericalSignal(
    grid_values=random_dist.grid_values[0],
    quadrature=random_dist.quadrature,
)
surface_trace = go.Surface(
    random_dist_copy.plotly_surface(
        scale_radius_by_amplitude=False,
        radius=0.8,
        normalize_radius_by_max_amplitude=True,
    )
)
target_points_trace = go.Scatter3d(
    x=random_points[:, 0],
    y=random_points[:, 1],
    z=random_points[:, 2],
    mode='markers',
)
fig = go.Figure([surface_trace, target_points_trace])

# Hide all three axes so only the sphere and the points are visible.
hidden_axes = {
    axis_name: dict(showticklabels=False, visible=False)
    for axis_name in ("xaxis", "yaxis", "zaxis")
}
fig.update_layout(scene=dict(**hidden_axes))
fig.show()

In [None]:
# Try to learn random signal via gradient descent on the KL divergence.
def optimize_coeffs(true_signal, lmax, position_channels, res_alpha, res_beta, num_training_steps,
                    target_points=None, learning_rate=1e-3, seed=0, log_every=10, print_every=5000):
    """Fit spherical-harmonic coefficients to `true_signal` by minimizing a KL divergence with Adam.

    Args:
        true_signal: signal on the sphere; its averaged distribution (via
            `helpers.average_target_distributions`) is the training target.
        lmax: maximum degree of the learned coefficients.
        position_channels: number of channels of learned coefficients.
        res_alpha, res_beta: resolution of the spherical grid used for the loss and the metrics.
        num_training_steps: number of optimizer steps.
        target_points: points passed to `helpers.rmse_of_samples` when logging metrics.
            Defaults to the notebook-level `random_points`.
        learning_rate: Adam learning rate.
        seed: seed for the coefficient initialization and the per-step metric keys.
        log_every: record metrics every `log_every` steps; the final step is always recorded.
        print_every: print the loss every `print_every` steps.

    Returns:
        dict mapping step -> {"coeffs", "loss_value", "grad_norms", "mean_dist", "std_dist"}.
        "loss_value" and "grad_norms" are evaluated at the coefficients before that step's
        update; "coeffs", "mean_dist" and "std_dist" are evaluated after it.
    """
    # Local import: a later notebook cell may rebind the global name `loss`
    # (e.g. to an array of loss values), which would otherwise break `loss_fn`.
    from symphony import loss as symphony_loss

    if target_points is None:
        target_points = random_points

    # Compute the target distribution
    true_dist = helpers.average_target_distributions(true_signal, res_alpha, res_beta)

    rng = jax.random.PRNGKey(seed)
    irreps = e3nn.s2_irreps(lmax, p_val=1, p_arg=-1)
    coeffs = e3nn.normal(irreps, rng, (position_channels, 1))

    tx = optax.adam(learning_rate)
    opt_state = tx.init(coeffs)

    def loss_fn(coeffs):
        log_predicted_dist = models.log_coeffs_to_logits(coeffs, res_beta=res_beta, res_alpha=res_alpha, num_radii=1)
        return symphony_loss.kl_divergence_on_spheres(true_dist, log_predicted_dist)

    @jax.jit
    def train_step(coeffs, opt_state):
        loss_value, grads = jax.value_and_grad(loss_fn)(coeffs)
        grad_norms = jnp.linalg.norm(grads.array)
        updates, opt_state = tx.update(grads, opt_state, coeffs)
        coeffs = optax.apply_updates(coeffs, updates)
        return coeffs, opt_state, loss_value, grad_norms

    training_dict = {}
    last_step = num_training_steps - 1
    for step in range(num_training_steps):
        coeffs, opt_state, loss_value, grad_norms = train_step(coeffs, opt_state)
        if step % print_every == 0:
            print(f"Step {step}, Loss: {loss_value}")

        # Always record the last step so the stored final coeffs are the fully trained ones.
        if step % log_every == 0 or step == last_step:
            step_rng = jax.random.fold_in(rng, step)
            dist = helpers.coeffs_to_distribution(coeffs, res_alpha, res_beta)
            mean_dist, std_dist = helpers.rmse_of_samples(dist, target_points, step_rng, num_samples=1000)

            training_dict[step] = {
                "coeffs": coeffs.array,
                "loss_value": float(loss_value),
                "grad_norms": float(grad_norms),
                "mean_dist": float(mean_dist),
                "std_dist": float(std_dist),
            }

    return training_dict


In [None]:
# We vary the number of position channels and lmax.
# See how well we can learn the signal.
# Rows are collected as plain dicts and turned into a DataFrame once at the end:
# `DataFrame.append` was removed in pandas 2.0 (and growing a frame row by row is quadratic).
results_columns = ["res_alpha", "res_beta", "lmax", "position_channels", "loss", "coeffs", "grad_norms", "mean_dist", "std_dist", "steps"]
results_records = []
for res_alpha in res_alphas:
    for res_beta in res_betas:
        for lmax in [5]:
            for position_channels in [2]:
                training_dict = optimize_coeffs(random_signal, lmax, position_channels, res_alpha, res_beta, num_training_steps=10000)

                logged_steps = list(training_dict.keys())
                assert logged_steps and logged_steps[0] == 0, "expected metrics to be logged at step 0"
                last_step = logged_steps[-1]

                results_records.append({
                    "res_alpha": res_alpha,
                    "res_beta": res_beta,
                    "lmax": lmax,
                    "position_channels": position_channels,
                    "steps": logged_steps,
                    "loss": [training_dict[step]["loss_value"] for step in logged_steps],
                    "grad_norms": [training_dict[step]["grad_norms"] for step in logged_steps],
                    # Only the coefficients at the last logged step are kept.
                    "coeffs": training_dict[last_step]["coeffs"],
                    "mean_dist": [training_dict[step]["mean_dist"] for step in logged_steps],
                    "std_dist": [training_dict[step]["std_dist"] for step in logged_steps],
                })

results_df = pd.DataFrame(results_records, columns=results_columns)
        

In [None]:
# Show one compact row per run. The per-step metric lists and the coefficient
# arrays are too large to read, so only their final logged values are displayed;
# the full data stays available in `results_df`.
results_df[["res_alpha", "res_beta", "lmax", "position_channels"]].assign(
    final_loss=results_df["loss"].map(lambda values: values[-1]),
    final_grad_norm=results_df["grad_norms"].map(lambda values: values[-1]),
    final_mean_dist=results_df["mean_dist"].map(lambda values: values[-1]),
    final_std_dist=results_df["std_dist"].map(lambda values: values[-1]),
)

In [None]:
def get_color_value(res_alpha, res_beta, max_res_alpha=None, max_res_beta=None):
    """Map a grid resolution to a colormap value.

    Returns the cube root of the number of grid points (res_alpha * res_beta)
    relative to the largest grid (max_res_alpha * max_res_beta). The result lies in
    [0, 1] whenever the grid is no larger than the largest one. The cube root spreads
    small grids out instead of bunching them near 0.

    Args:
        res_alpha, res_beta: resolution of the grid being plotted.
        max_res_alpha, max_res_beta: resolution of the largest grid; default to the
            maxima of the notebook-level `res_alphas` / `res_betas` lists.
    """
    if max_res_alpha is None:
        max_res_alpha = np.max(res_alphas)
    if max_res_beta is None:
        max_res_beta = np.max(res_betas)
    return np.cbrt((res_alpha * res_beta) / (max_res_alpha * max_res_beta))

In [None]:
# Make a line plot of loss vs. number of training steps, for each choice of res_alpha and res_beta.
# Set the style before creating the figure: seaborn styles are rcParams, which only
# affect axes created after they are set.
sns.set_style("darkgrid")
fig, axs = plt.subplots(ncols=3, figsize=(12, 5))

# Only a subset of the swept resolutions is plotted (presumably to keep the figure readable).
allowed_res_alphas_and_betas = [(9, 10), (9, 180), (179, 180), (359, 180), (719, 360)]
# Set to True to shade +/- one std of the distance metric around its mean.
show_std_band = False

for res_alpha in res_alphas:
    for res_beta in res_betas:
        if (res_alpha, res_beta) not in allowed_res_alphas_and_betas:
            continue
        color = plt.cm.viridis(get_color_value(res_alpha, res_beta))
        label = fr"$(r_\theta, r_\phi) = ({res_beta},{res_alpha})$"

        run = results_df[(results_df["res_alpha"] == res_alpha) & (results_df["res_beta"] == res_beta)].iloc[0]
        steps = run["steps"]
        # Named `loss_values` (not `loss`) so the imported `symphony.loss` module is not shadowed.
        loss_values = jnp.asarray(run["loss"])
        grad_norms = jnp.asarray(run["grad_norms"])
        mean_dist = jnp.asarray(run["mean_dist"])

        axs[0].plot(steps, loss_values, label=label, color=color)
        axs[1].plot(steps, grad_norms, label=label, color=color)
        axs[2].plot(steps, mean_dist, label=label, color=color)

        if show_std_band:
            std_dist = jnp.asarray(run["std_dist"])
            axs[2].fill_between(steps, mean_dist - std_dist, mean_dist + std_dist, alpha=0.2, color=color)


def extract_resolution_from_label(args):
    r"""Sort key for a ``(label, handle)`` legend pair: the product of the two resolutions.

    Labels have the form ``$(r_\theta, r_\phi) = (res_beta,res_alpha)$``. After
    stripping ``$``, ``(`` and ``)``, the text after the first ``=`` is split on ``,``.
    The first two integers are multiplied, giving the number of grid points.
    """
    label, _ = args
    for removed_char in ("$", "(", ")"):
        label = label.replace(removed_char, "")
    resolution_parts = label.split("=")[1].split(",")
    # The label lists res_beta first, then res_alpha.
    res_beta_part, res_alpha_part = resolution_parts[0], resolution_parts[1]
    return int(res_beta_part) * int(res_alpha_part)

fig.suptitle("Learning a Random Distribution on the Sphere", fontsize=16)
axs[0].set_yscale("log")
axs[0].set_title("KL Divergence Loss")
axs[0].set_ylabel("KL Divergence Loss")
axs[0].set_xlabel("Number of Training Steps")

axs[1].set_yscale("log")
axs[1].set_title("Gradient Norms")
axs[1].set_ylabel("Gradient Norms")
axs[1].set_xlabel("Number of Training Steps")

axs[2].set_title("Mean Distance to Closest Target Point")
axs[2].set_ylabel("Mean Distance to Closest Target Point")
axs[2].set_xlabel("Number of Training Steps")

# Sort legend entries by resolution (number of grid points). All panels use the
# same labels, so the entries are taken from the first panel.
handles, labels = axs[0].get_legend_handles_labels()
if handles:
    labels, handles = zip(*sorted(zip(labels, handles), key=extract_resolution_from_label))
    # Place the legend on the right-most panel, outside the axes on its right.
    axs[2].legend(handles, labels, bbox_to_anchor=(1.05, 0.5), loc="center left", borderaxespad=0., title=r"Resolution ($r_\theta, r_\phi$)")

# savefig does not create missing directories.
os.makedirs("pdfs", exist_ok=True)
fig.savefig("pdfs/resolution_random_points.pdf", dpi=500, bbox_inches="tight")
plt.show()

In [None]:
# Visualize a learned distribution: the predicted density surface, samples drawn
# from it, and the target points.
lmax = 5
position_channels = 2  # the sweep above only trains runs with position_channels=2

# Select the run by its hyperparameters rather than by row position.
matching_runs = results_df[(results_df["lmax"] == lmax) & (results_df["position_channels"] == position_channels)]
assert len(matching_runs) > 0, f"no run with lmax={lmax}, position_channels={position_channels}"
# NOTE(review): this takes the first matching run, i.e. the first (res_alpha, res_beta)
# in sweep order (9, 10) — confirm that is the run intended for this figure.
coeffs_array = matching_runs["coeffs"].iloc[0]
coeffs = e3nn.IrrepsArray(e3nn.s2_irreps(lmax, p_val=1, p_arg=-1), coeffs_array)
predicted_logits = models.log_coeffs_to_logits(coeffs, res_beta=90, res_alpha=179, num_radii=1)
predicted_dist = models.position_logits_to_position_distribution(predicted_logits)

# Derive the sampling key without rebinding the global `rng`, so re-running this
# cell draws the same samples.
samples_rng, _ = jax.random.split(rng)
samples = helpers.sample_from_dist(predicted_dist, samples_rng, num_samples=1000)

# Build a separate signal for plotting instead of overwriting `predicted_dist.grid_values`;
# index the leading axis so a single signal is plotted.
predicted_dist_to_plot = e3nn.SphericalSignal(grid_values=predicted_dist.grid_values[0], quadrature=predicted_dist.quadrature)
fig = go.Figure([go.Surface(predicted_dist_to_plot.plotly_surface(scale_radius_by_amplitude=False, radius=0.8, normalize_radius_by_max_amplitude=True)),
                 go.Scatter3d(x=samples[:, 0], y=samples[:, 1], z=samples[:, 2], mode='markers'),
                 go.Scatter3d(x=random_points[:, 0], y=random_points[:, 1], z=random_points[:, 2], mode='markers')])

# Hide all three axes so only the sphere and the points are visible.
fig.update_layout(scene=dict(
    xaxis=dict(showticklabels=False, visible=False),
    yaxis=dict(showticklabels=False, visible=False),
    zaxis=dict(showticklabels=False, visible=False)))
fig.show()