In [None]:
import e3nn_jax as e3nn
import haiku as hk
import jax
import jax.numpy as jnp
import optax
import plotly.graph_objects as go
import plotly.subplots

In [None]:
import sys
sys.path.append("..")

import loss as loss_py

In [None]:
# Number of concentric spheres (radial shells); this is the size of the second
# leading axis expected of coefficient arrays and of the resulting signals.
num_radii = 1
# Resolution of the spherical grid, passed to e3nn.to_s2grid as res_beta and
# res_alpha respectively.
res_beta = 180
res_alpha = 359

In [None]:
def log_coeffs_to_probability_distribution(log_coeffs: e3nn.IrrepsArray) -> e3nn.SphericalSignal:
    """Converts irreps defining the logits to a probability distribution.

    The coefficients are projected onto a (res_beta, res_alpha) grid, shifted by
    their global maximum, exponentiated, normalized so that the integral summed
    over all channels and radii is one, and finally summed over channels.

    Args:
        log_coeffs: Either an array of shape (num_channels, num_radii, dim), or a
            bare coefficient vector of shape (dim,), which is treated as a
            single channel on a single sphere.

    Returns:
        A SphericalSignal whose grid values have shape
        (num_radii, res_beta, res_alpha) for rank-3 input, or
        (res_beta, res_alpha) for a bare coefficient vector.

    Raises:
        AssertionError: If the input has any other shape.
    """
    # A bare (dim,) vector is lifted to (1, 1, dim) so the same code path
    # applies; the radius axis is dropped again before returning.
    is_bare_vector = len(log_coeffs.shape) == 1
    if is_bare_vector:
        log_coeffs = e3nn.IrrepsArray(log_coeffs.irreps, log_coeffs.array[None, None, :])
    expected_radii = 1 if is_bare_vector else num_radii

    num_channels = log_coeffs.shape[0]
    assert log_coeffs.shape == (num_channels, expected_radii, log_coeffs.irreps.dim), log_coeffs.shape

    log_dist = e3nn.to_s2grid(log_coeffs, res_beta, res_alpha, quadrature="soft", p_val=1, p_arg=-1)
    assert log_dist.shape == (num_channels, expected_radii, res_beta, res_alpha)

    # Subtract the global maximum (excluded from differentiation) before
    # exponentiating, for numerical stability; the shift cancels on normalization.
    log_dist_max = jnp.max(log_dist.grid_values, axis=(-4, -3, -2, -1), keepdims=True)
    log_dist_max = jax.lax.stop_gradient(log_dist_max)
    log_dist = log_dist.apply(
        lambda x: x - log_dist_max
    )

    dist = log_dist.apply(jnp.exp)
    # Normalize jointly over channels and radii, then mix the channels by summing.
    dist = dist / dist.integrate().array.sum()
    dist.grid_values = dist.grid_values.sum(axis=-4)
    if is_bare_vector:
        # Drop the singleton radius axis that was added above.
        dist.grid_values = dist.grid_values[0]
    return dist

In [None]:
def plot_coeffs(coeffs: e3nn.IrrepsArray) -> go.Figure:
    """Plot the probability distribution defined by log-coefficients.

    The coefficients are converted with log_coeffs_to_probability_distribution;
    the resulting signal must have a leading axis of size one (a single radius),
    which is removed before handing the signal to plot_signal.
    """
    dist = log_coeffs_to_probability_distribution(coeffs)
    # Only a single sphere can be drawn, so the leading (radius) axis must be 1.
    leading_size = dist.grid_values.shape[0]
    assert leading_size == 1
    dist.grid_values = dist.grid_values[0]
    return plot_signal(dist)
 

def plot_signal(sig: e3nn.SphericalSignal) -> go.Figure:
    """Draw a spherical signal as two side-by-side 3D surfaces.

    The left panel uses scale_radius_by_amplitude=False and the right panel uses
    scale_radius_by_amplitude=True.
    """
    panel_specs = [[{'type': 'surface'} for _ in range(2)]]
    fig = plotly.subplots.make_subplots(rows=1, cols=2, specs=panel_specs)

    # Column 1: radius not scaled by amplitude; column 2: radius scaled by amplitude.
    for column, scale_by_amplitude in enumerate((False, True), start=1):
        surface = go.Surface(sig.plotly_surface(scale_radius_by_amplitude=scale_by_amplitude))
        fig.add_trace(surface, row=1, col=column)

    return fig

In [None]:
# Target distribution: a single channel of "4e" log-coefficients (9 components),
# scaled per radius by true_radius_weights.
true_radius_weights = jnp.asarray([1.])
log_true_angular_coeffs = e3nn.IrrepsArray(
    "4e",
    jnp.array([[[1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0]]]),
)

# Broadcast the radius weights over the (channel, radius, coefficient) axes.
log_true_coeffs = true_radius_weights[None, :, None] * log_true_angular_coeffs
true_dist = log_coeffs_to_probability_distribution(log_true_coeffs)

plot_coeffs(log_true_coeffs)

In [None]:
def kl_divergence_on_spheres(
    true_dist: e3nn.SphericalSignal,
    log_predicted_dist: e3nn.SphericalSignal,
) -> jnp.ndarray:
    """Compute the KL divergence between two distributions on the spheres.

    Args:
        true_dist: Target distribution with grid values of shape
            (num_radii, res_beta, res_alpha).
        log_predicted_dist: Log of the predicted distribution on the same grid.
            It does not need to be normalized; a log-normalizer is added below.

    Returns:
        cross-entropy + log-normalizer - self-entropy of the target.
    """
    grid_shape = (num_radii, res_beta, res_alpha)
    assert true_dist.grid_values.shape == grid_shape
    assert log_predicted_dist.grid_values.shape == grid_shape

    def _total(signal):
        # Integrate the signal and add up everything that remains.
        return signal.integrate().array.sum()

    # Shift the predicted logits by their maximum for numerical stability.
    # The shift is excluded from differentiation.
    shift = jax.lax.stop_gradient(jnp.max(log_predicted_dist.grid_values))
    shifted_log_predicted = log_predicted_dist.apply(lambda logits: logits - shift)

    # Cross-entropy against the shifted logits, plus the log-normalizer that
    # accounts for the predicted distribution not being normalized.
    cross_entropy = -_total(true_dist * shifted_log_predicted)
    normalizing_factor = jnp.log(_total(shifted_log_predicted.apply(jnp.exp)))

    # Self-entropy of the target distribution.
    self_entropy = -_total(true_dist * true_dist.apply(loss_py.safe_log))

    # This should be non-negative, up to numerical precision.
    return cross_entropy + normalizing_factor - self_entropy


In [None]:
def alignment_loss(predicted_coeffs: e3nn.IrrepsArray) -> e3nn.IrrepsArray:
    """KL divergence between the global true_dist and the predicted distribution.

    The scalar loss is wrapped in a one-element "0e" IrrepsArray, the form in
    which it is passed to e3nn.grad in the training step.
    """
    # NOTE(review): jnp.log gives -inf wherever the predicted density is exactly
    # zero; confirm this cannot happen for the coefficient ranges used here.
    log_predicted_dist = log_coeffs_to_probability_distribution(predicted_coeffs).apply(jnp.log)
    kl = kl_divergence_on_spheres(true_dist, log_predicted_dist)
    return e3nn.IrrepsArray("0e", jnp.asarray([kl]))

In [None]:
# Random starting point for the fit: one set of lmax=1 log-coefficients for
# every (channel, radius) pair, drawn with a fixed seed for reproducibility.
rng = jax.random.PRNGKey(0)
lmax = 1
num_channels = 10
irreps = e3nn.s2_irreps(lmax)
init_coeffs = e3nn.normal(
    irreps,
    rng,
    leading_shape=(num_channels, num_radii),
)
plot_coeffs(init_coeffs)

In [None]:
# Fit the coefficients by minimizing the KL divergence with Adam, starting
# from the random initialization.
coeffs = init_coeffs
tx = optax.adam(learning_rate=1e-2)
opt_state = tx.init(coeffs)

@jax.jit
def train_step(coeffs, opt_state):
    """Run one optimizer update on the coefficients.

    Uses the module-level optimizer `tx`. Returns the updated coefficients, the
    updated optimizer state, and the loss evaluated at the coefficients that
    were passed in (i.e. before the update).
    """
    loss = alignment_loss(coeffs)
    grads = e3nn.grad(alignment_loss)(coeffs)
    updates, new_opt_state = tx.update(grads, opt_state)
    new_coeffs = optax.apply_updates(coeffs, updates)
    return new_coeffs, new_opt_state, loss

for step in range(5000):
    coeffs, opt_state, loss = train_step(coeffs, opt_state)

    # Converting the loss to a Python float blocks until the device has finished
    # the step, so only do it on the iterations that are actually logged.
    if step % 100 == 0:
        print("Loss at step {step} is {loss}".format(step=step, loss=loss.array.item()))

# Leave `loss` as a plain Python float holding the last step's loss.
loss = loss.array.item()

In [None]:
# Show the fitted coefficients next to the target ones, then plot the fitted
# distribution using all channels. The fitted coefficients use
# e3nn.s2_irreps(lmax) while the target uses "4e", so the raw numbers are not
# directly comparable; the plot is the meaningful comparison.
print(coeffs, "vs", log_true_angular_coeffs)
plot_coeffs(coeffs)

In [None]:
# Distribution defined by the first channel alone. The slice keeps the channel
# axis, and the conversion normalizes over only the channels it is given.
plot_coeffs(coeffs[:1])

In [None]:
# Distribution defined by the second channel alone (channel axis kept by the slice).
plot_coeffs(coeffs[1:2])

In [None]:
# Distribution defined by the third channel alone (channel axis kept by the slice).
plot_coeffs(coeffs[2:3])

# Linearity of Projection

In [None]:
# T scales the log-coefficients before they are exponentiated.
T = 100
# "1o" log-coefficients with only the first component non-zero.
# NOTE(review): this is a rank-1 coefficient vector of shape (3,), so the call
# below relies on log_coeffs_to_probability_distribution accepting inputs
# without channel/radius axes.
sig1_irreps = e3nn.IrrepsArray("1o", T * jnp.array([1.0, 0.0, 0.0]))
sig1 = log_coeffs_to_probability_distribution(sig1_irreps)
plot_signal(sig1)

In [None]:
# Same construction with only the third "1o" component non-zero.
# NOTE(review): rank-1 coefficients of shape (3,); relies on
# log_coeffs_to_probability_distribution accepting inputs without
# channel/radius axes.
sig2_irreps = e3nn.IrrepsArray("1o", T * jnp.array([0.0, 0.0, 1.0]))
sig2 = log_coeffs_to_probability_distribution(sig2_irreps)
plot_signal(sig2)

In [None]:
# Average taken in probability space: the equal-weight mixture of the two
# distributions.
plot_signal((sig1 + sig2) / 2)

In [None]:
# Average taken in log-coefficient space instead: the coefficients are averaged
# first and only then converted to a distribution. Compare with the mixture of
# the two distributions, (sig1 + sig2) / 2.
sig12_combined = log_coeffs_to_probability_distribution((sig1_irreps + sig2_irreps) / 2)
plot_signal(sig12_combined)