In [None]:
import jax
import jax.numpy as jnp
import e3nn_jax as e3nn
import plotly.graph_objects as go
import pandas as pd
import optax
import chex
import seaborn as sns
import matplotlib.pyplot as plt
import sys
import functools
sys.path.append('../..')

In [None]:
%load_ext autoreload

In [None]:
%autoreload 2

from symphony import loss
from symphony import models
import helpers

In [None]:
# Build the target: a sum of Dirac deltas placed at a few random unit vectors,
# expanded in spherical harmonics up to degree `lmax`.
N_points = 5
lmax = 5
rng = jax.random.PRNGKey(0)

# Draw Gaussian vectors, then project each one onto the unit sphere.
random_points = jax.random.normal(rng, shape=(N_points, 3))
point_norms = jnp.linalg.norm(random_points, axis=-1, keepdims=True)
random_points = random_points / point_norms

# The factor of 20 scales the amplitude of the resulting signal.
random_signal = 20 * e3nn.s2_dirac(random_points, lmax=lmax, p_val=1, p_arg=-1)

In [None]:
# Target distribution derived from the random signal (arguments kept positional: 159, 80).
random_dist = helpers.average_target_distributions(random_signal, 159, 80)

# Keep only the first slice of the grid values for plotting; reuse the same quadrature.
random_dist_copy = e3nn.SphericalSignal(
    grid_values=random_dist.grid_values[0],
    quadrature=random_dist.quadrature,
)

# One surface for the distribution, one scatter trace for the points that generated it.
target_surface = go.Surface(
    random_dist_copy.plotly_surface(
        scale_radius_by_amplitude=False,
        radius=0.8,
        normalize_radius_by_max_amplitude=True,
    )
)
target_points = go.Scatter3d(
    x=random_points[:, 0],
    y=random_points[:, 1],
    z=random_points[:, 2],
    mode='markers',
)
fig = go.Figure([target_surface, target_points])

# Do not show the axis
fig.update_layout(
    scene={
        axis_name: dict(showticklabels=False, visible=False)
        for axis_name in ("xaxis", "yaxis", "zaxis")
    }
)
fig.show()

In [None]:
# Try to learn random signal via gradient descent on the KL divergence.
def optimize_coeffs(true_signal, lmax, position_channels, res_alpha, res_beta, use_simm_et_al, regularize_coeffs, num_training_steps,
                    *, learning_rate=1e-3, seed=0, log_every=5000, record_every=10):
    """Fit spherical-harmonic coefficients to `true_signal` with Adam on a KL divergence.

    The target distribution is `helpers.average_target_distributions(true_signal, ...)`;
    the prediction is `models.log_coeffs_to_logits(coeffs, ...)`, optionally transformed
    (see `use_simm_et_al`), and the loss is `loss.kl_divergence_on_spheres(target, prediction)`.

    Args:
        true_signal: Signal passed to `helpers.average_target_distributions`.
        lmax: Maximum degree passed to `e3nn.s2_irreps` for the learned coefficients.
        position_channels: Leading dimension of the coefficient array, which is
            initialised with shape `(position_channels, 1)` over the irreps.
        res_alpha: Grid resolution forwarded as `res_alpha` to both helpers.
        res_beta: Grid resolution forwarded as `res_beta` to both helpers.
        use_simm_et_al: If True, the predicted grid values are squared before the loss.
        regularize_coeffs: Only used when `use_simm_et_al` is True. The squared grid
            values are additionally divided by the sum of the squared per-irrep norms
            of the coefficients.
        num_training_steps: Number of optimizer steps to run.
        learning_rate: Adam learning rate (default matches the previous hard-coded 1e-3).
        seed: Seed of the PRNG key used to initialise the coefficients (default 0).
        log_every: Print the loss every `log_every` steps and on the final step.
            Must be a positive integer.
        record_every: Store a snapshot every `record_every` steps. Must be a positive integer.

    Returns:
        Dict mapping each recorded step to a dict with keys "coeffs" (the raw array of
        the coefficients *after* that step's update), "loss_value" and "grad_norms"
        (both evaluated at the coefficients *before* that step's update).
        Note that the final step is only recorded if it is a multiple of `record_every`.
    """

    # Compute the target distribution
    true_dist = helpers.average_target_distributions(true_signal, res_alpha=res_alpha, res_beta=res_beta)

    rng = jax.random.PRNGKey(seed)
    irreps = e3nn.s2_irreps(lmax, p_val=1, p_arg=-1)
    coeffs = e3nn.normal(irreps, rng, (position_channels, 1))

    tx = optax.adam(learning_rate)
    opt_state = tx.init(coeffs)

    def loss_fn(coeffs):
        log_predicted_dist = models.log_coeffs_to_logits(coeffs, res_beta=res_beta, res_alpha=res_alpha, num_radii=1)
        if use_simm_et_al:
            # Square the grid values, optionally normalising by the total squared norm of the coefficients.
            log_predicted_dist.grid_values = log_predicted_dist.grid_values ** 2
            if regularize_coeffs:
                log_predicted_dist.grid_values = log_predicted_dist.grid_values / e3nn.norm(coeffs, per_irrep=True, squared=True).array.sum()

        return loss.kl_divergence_on_spheres(true_dist, log_predicted_dist)

    @jax.jit
    def train_step(coeffs, opt_state):
        # The loss and gradient are evaluated at the incoming coefficients, i.e. before the update.
        loss_value, grads = jax.value_and_grad(loss_fn)(coeffs)
        grad_norms = jnp.linalg.norm(grads.array)
        updates, opt_state = tx.update(grads, opt_state, coeffs)
        coeffs = optax.apply_updates(coeffs, updates)
        return coeffs, opt_state, loss_value, grad_norms

    training_dict = {}
    for step in range(num_training_steps):
        coeffs, opt_state, loss_value, grad_norms = train_step(coeffs, opt_state)
        if step % log_every == 0 or step == num_training_steps - 1:
            print(f"Step {step}, Loss: {loss_value}")

        if step % record_every == 0:
            training_dict[step] = {
                "coeffs": coeffs.array,
                "loss_value": float(loss_value.item()),
                "grad_norms": float(grad_norms.item()),
            }

    return training_dict


In [None]:
# We vary the number of position channels and lmax.
# See how well we can learn the signal.
#
# Rows are collected in a plain list and the DataFrame is built once at the end:
# `DataFrame.append` was removed in pandas 2.0 and copied the whole frame on every call.
result_columns = ["res_alpha", "res_beta", "lmax", "position_channels", "loss", "loss_diff", "coeffs", "grad_norms", "use_simm_et_al", "regularize_coeffs"]
records = []
for res_alpha in [179]:
    for res_beta in [90]:
        for lmax in range(1, 6):
            for position_channels in range(1, 10):
                for use_simm_et_al in [True, False]:
                    for regularize_coeffs in [True, False]:
                        # Regularization is only applied inside the use_simm_et_al branch,
                        # so this combination would duplicate (False, False).
                        if regularize_coeffs and not use_simm_et_al:
                            continue

                        training_dict = optimize_coeffs(random_signal, lmax, position_channels, res_alpha, res_beta, use_simm_et_al, regularize_coeffs, num_training_steps=10000)

                        # Steps are inserted in increasing order, so the key order is chronological.
                        recorded_steps = list(training_dict.keys())
                        first_step = recorded_steps[0]
                        last_step = recorded_steps[-1]
                        second_last_step = recorded_steps[-2]
                        assert first_step == 0

                        # The denominator is negative (second_last_step < last_step), so this is the
                        # average per-step *decrease* of the loss between the last two recorded
                        # steps: positive while the loss is still going down.
                        loss_diff = (training_dict[last_step]["loss_value"] - training_dict[second_last_step]["loss_value"]) / (second_last_step - last_step)

                        records.append({
                            "res_alpha": res_alpha,
                            "res_beta": res_beta,
                            "lmax": lmax,
                            "position_channels": position_channels,
                            "loss": training_dict[last_step]["loss_value"],
                            "loss_diff": loss_diff,
                            "grad_norms": training_dict[last_step]["grad_norms"],
                            "coeffs": training_dict[last_step]["coeffs"],
                            "use_simm_et_al": use_simm_et_al,
                            "regularize_coeffs": regularize_coeffs,
                        })

results_df = pd.DataFrame(records, columns=result_columns)


In [None]:
# Display the sweep results: one row per (lmax, position_channels, use_simm_et_al, regularize_coeffs) run.
results_df

In [None]:
# Set the seaborn theme *before* creating the figure: the theme works by updating
# matplotlib's rcParams, which only affect axes created afterwards. With the original
# order, this first figure was drawn without the theme on a fresh kernel.
sns.set_theme(style="darkgrid")
fig, ax = plt.subplots(ncols=1, figsize=(5, 5), sharey=True, sharex=True)

# Single plot, for the runs where use_simm_et_al = False
# (the sweep only produces regularize_coeffs = False for these).
sns.barplot(data=results_df[results_df["use_simm_et_al"] == False], x="lmax", y="loss", hue="position_channels", palette="Blues", ax=ax)

fig.suptitle("KL Divergence on Random Signal")
ax.set_yscale("log")
ax.set_ylim([1e-4, 1e1])

# Remove the legend seaborn placed inside the axes; it is re-created outside below.
ax.get_legend().remove()

# Set x-axis label as "Max L"
ax.set_xlabel("Max L")

# Set y-axis label as "KL Divergence"
ax.set_ylabel("KL Divergence")

# Place legend outside the figure
ax.legend(bbox_to_anchor=(1.05, 0.5), loc="center left", borderaxespad=0., title="Position Channels")
plt.savefig("pdfs/kl_divergence_random_signal.pdf", dpi=500, bbox_inches='tight')
plt.show()

In [None]:
# Set the seaborn theme before creating the figure so the axes pick it up even when
# this cell is the first plotting cell to run (rcParams only affect axes created later).
sns.set_theme(style="darkgrid")
fig, axs = plt.subplots(ncols=2, figsize=(10, 5), sharey=True, sharex=True)

# Make 2 plots, one for results where:
# * use_simm_et_al = True, regularize_coeffs = True   (left)
# * use_simm_et_al = True, regularize_coeffs = False  (right)
sns.barplot(data=results_df[(results_df["use_simm_et_al"] == True) & (results_df["regularize_coeffs"] == True)], x="lmax", y="loss", hue="position_channels", palette="Blues", ax=axs[0])
sns.barplot(data=results_df[(results_df["use_simm_et_al"] == True) & (results_df["regularize_coeffs"] == False)], x="lmax", y="loss", hue="position_channels", palette="Blues", ax=axs[1])

fig.suptitle("KL Divergence on Random Signal")

# Remove the left legend; the y-axis is shared, so scale and limits set here apply to both panels.
axs[0].get_legend().remove()
axs[0].set_yscale("log")
axs[0].set_ylim([1e-4, 1e1])

# Set x-axis label as "Max L"
axs[0].set_xlabel("Max L")
axs[1].set_xlabel("Max L")
axs[0].set_title("Simm et al. (2021) Parametrization")
axs[1].set_title("Simm et al. (2021) Parametrization without Regularization")

# Set y-axis label as "KL Divergence"
axs[0].set_ylabel("KL Divergence")
axs[1].set_ylabel("KL Divergence")

# Place a single legend outside the figure, anchored to the right-hand panel
# (the axes that the implicit plt.legend call targeted).
axs[1].legend(bbox_to_anchor=(1.05, 0.5), loc="center left", borderaxespad=0., title="Position Channels")
plt.savefig("pdfs/kl_divergence_random_signal_simm.pdf", dpi=500, bbox_inches='tight')
plt.show()

In [None]:
# Sweep configuration to inspect.
lmax = 5
position_channels = 1
use_simm_et_al = True
regularize_coeffs = True

# Pick the sweep row that matches all four settings.
row_mask = (
    (results_df["lmax"] == lmax)
    & (results_df["use_simm_et_al"] == use_simm_et_al)
    & (results_df["regularize_coeffs"] == regularize_coeffs)
    & (results_df["position_channels"] == position_channels)
)
results_df_subset = results_df[row_mask]
print(results_df_subset)

# `.item()` requires exactly one matching row; wrap the stored raw array back into an IrrepsArray.
coeffs = e3nn.IrrepsArray(
    e3nn.s2_irreps(lmax, p_val=1, p_arg=-1),
    results_df_subset["coeffs"].values.item(),
)

# Rebuild the prediction with the same transform that was used during training.
# The resolution (res_beta=90, res_alpha=179) matches the values used in the sweep.
predicted_dist = models.log_coeffs_to_logits(coeffs, res_beta=90, res_alpha=179, num_radii=1)
if use_simm_et_al:
    transformed_values = predicted_dist.grid_values ** 2
    if regularize_coeffs:
        transformed_values = transformed_values / e3nn.norm(coeffs, per_irrep=True, squared=True).array.sum()
    predicted_dist.grid_values = transformed_values

predicted_dist = models.position_logits_to_position_distribution(predicted_dist)

# Note: this rebinds the notebook-level `rng`, so re-running the cell draws new samples.
samples_rng, rng = jax.random.split(rng)
samples = helpers.sample_from_dist(predicted_dist, samples_rng, num_samples=10)

# Keep only the first slice of the grid values for plotting.
predicted_dist.grid_values = predicted_dist.grid_values[0]

# Learned distribution, samples drawn from it, and the original target points.
traces = [
    go.Surface(
        predicted_dist.plotly_surface(
            scale_radius_by_amplitude=False,
            radius=0.8,
            normalize_radius_by_max_amplitude=True,
        )
    ),
    go.Scatter3d(x=samples[:, 0], y=samples[:, 1], z=samples[:, 2], mode='markers'),
    go.Scatter3d(x=random_points[:, 0], y=random_points[:, 1], z=random_points[:, 2], mode='markers'),
]
fig = go.Figure(traces)

# Hide all three axes.
fig.update_layout(
    scene={
        axis_name: dict(showticklabels=False, visible=False)
        for axis_name in ("xaxis", "yaxis", "zaxis")
    }
)
fig.show()