In [2]:
import e3nn_jax as e3nn
import jax
import jax.numpy as jnp

import sys
sys.path.append('../')
from symphony import models

In [12]:
def kl_divergence_on_sphere(
    log_true_angular_coeffs: e3nn.IrrepsArray,
    log_predicted_dist: e3nn.SphericalSignal,
) -> jnp.ndarray:
    """Compute the KL divergence KL(true || predicted) on a single sphere.

    Args:
        log_true_angular_coeffs: Coefficients that are projected onto the sphere
            grid and treated as the (unnormalized) log-density of the true
            distribution.
        log_predicted_dist: Unnormalized log-density of the predicted
            distribution, given as a signal on a sphere grid. Its ``res_beta``,
            ``res_alpha`` and ``quadrature`` define the grid on which the true
            distribution is evaluated.

    Returns:
        ``cross_entropy + log(normalizer of predicted) - entropy(true)``, which
        should be non-negative up to numerical precision.

    Raises:
        AssertionError: If the true distribution's grid values do not have shape
            ``(res_beta, res_alpha)``, e.g. when the coefficients carry leading
            batch dimensions.
    """
    res_beta = log_predicted_dist.res_beta
    res_alpha = log_predicted_dist.res_alpha

    # Project the true coefficients onto the same grid as the prediction.
    true_log_signal = e3nn.to_s2grid(
        log_true_angular_coeffs,
        res_beta,
        res_alpha,
        quadrature=log_predicted_dist.quadrature,
        p_val=1,
        p_arg=-1,
    )

    # Shift by the per-grid maximum before exponentiating (numerical stability).
    # The shift is held constant under differentiation.
    true_shift = jax.lax.stop_gradient(
        jnp.max(true_log_signal.grid_values, axis=(-2, -1), keepdims=True)
    )
    true_unnormalized = true_log_signal.apply(lambda v: v - true_shift).apply(jnp.exp)

    # Normalize by the integral over the sphere to obtain the true distribution.
    true_dist = true_unnormalized / true_unnormalized.integrate()

    # Only a single (unbatched) sphere grid is supported here.
    assert true_dist.grid_values.shape == (
        res_beta,
        res_alpha,
    ), true_dist.grid_values.shape

    # Shift the predicted log-density by its global maximum (numerical stability);
    # again the shift carries no gradient.
    predicted_shift = jax.lax.stop_gradient(jnp.max(log_predicted_dist.grid_values))
    shifted_log_predicted = log_predicted_dist.apply(lambda v: v - predicted_shift)

    # Cross-entropy against the *unnormalized* prediction ...
    cross_entropy = -(true_dist * shifted_log_predicted).integrate().array.sum()
    # ... plus the log-normalizer that accounts for the prediction not being normalized.
    normalizing_factor = jnp.log(
        shifted_log_predicted.apply(jnp.exp).integrate().array.sum()
    )

    # Entropy of the true distribution.
    true_log_dist = true_dist.apply(models.safe_log)
    self_entropy = -(true_dist * true_log_dist).integrate().array.sum()

    # KL(true || predicted); should be non-negative, up to numerical precision.
    return cross_entropy + normalizing_factor - self_entropy



In [None]:
# NOTE(review): unfinished scratch cell (it has no execution count). The first
# assignment has no right-hand side, so the cell does not parse as written, and
# `position_probs` / `true_radius_weights` are not defined in the visible notebook.
# TODO: supply the missing inputs or remove this cell.
position_logits = 
# Plain name binding, not a copy: both names refer to the same object, so the
# attribute assignments below would also be visible through `position_logits`.
log_predicted_angular_dist = position_logits
# For each leading (vmapped) entry, weight `probs` along its first axis by
# `radial_weights` and sum that axis away.
# Presumably the first axis indexes radii and the trailing two axes are the
# sphere grid -- TODO confirm against where `position_probs` is produced.
log_predicted_angular_dist.grid_values = jax.vmap(
    lambda probs, radial_weights: (probs * radial_weights[:, None, None]).sum(
        axis=0
    )
)(position_probs.grid_values, true_radius_weights)
# Take the log of the values computed above.
log_predicted_angular_dist.grid_values = jnp.log(log_predicted_angular_dist.grid_values)



In [23]:
# Sanity check: two "1o" coefficient vectors that differ in which component is
# non-zero, so the two distributions are not identical.
log_true_angular_coeffs = e3nn.IrrepsArray(
    "1o",
    jnp.asarray([1.0, 0.0, 0.0]),
)
log_predicted_angular_coeffs = e3nn.IrrepsArray(
    "1o",
    jnp.asarray([0.0, 1.0, 0.0]),
)

# Put the predicted log-density on an explicit sphere grid; the KL function
# evaluates the true distribution on this same grid.
log_predicted_dist = e3nn.to_s2grid(
    log_predicted_angular_coeffs,
    res_beta=100,
    res_alpha=99,
    quadrature="soft",
    p_val=1,
    p_arg=-1,
)

# Last expression of the cell: displays the KL divergence value.
kl_divergence_on_sphere(log_true_angular_coeffs, log_predicted_dist)

Array(0.84398484, dtype=float32)