# Gumbel on the Sphere

Implementing the Gumbel-Softmax trick for distributions defined on the sphere!

In [None]:
from typing import Tuple, Optional, Callable, Union, List

import functools

import chex
import e3nn_jax as e3nn
import jax
import jax.numpy as jnp
import plotly.graph_objects as go
import optax

In [None]:
# Flip to True to turn the diagnostic prints back on while debugging.
_DEBUG_PRINT_ENABLED = False


def debug_print(*args):
    """Prints ``args`` only when debug output is enabled.

    The function used to be defined twice in a row, with the second (no-op)
    definition silently shadowing the first. A single definition gated by
    ``_DEBUG_PRINT_ENABLED`` keeps the same default behaviour (silent) without
    leaving dead code behind.

    Args:
        *args: Values forwarded unchanged to ``print`` when enabled.
    """
    if _DEBUG_PRINT_ENABLED:
        print(*args)

In [None]:
def softmax_on_sphere(logits: e3nn.SphericalSignal) -> e3nn.SphericalSignal:
    """Applies softmax on the sphere.

    Exponentiates the signal pointwise and divides it by its own integral.

    Args:
        logits: Unnormalized log-values sampled on a spherical grid.

    Returns:
        The signal ``exp(logits)`` divided by its integral.
    """
    # Softmax is invariant to a constant shift of the logits, so subtracting the
    # largest value changes nothing mathematically but keeps ``exp`` from
    # overflowing for large logits (e.g. after dividing by a small temperature).
    # The reduction is over the last two axes only.
    # NOTE(review): assumes the trailing two axes of grid_values are the grid and
    # any leading axes are batch axes -- confirm against e3nn_jax.
    # The shift is a constant offset, so it is excluded from differentiation.
    max_logit = jax.lax.stop_gradient(
        jnp.max(logits.grid_values, axis=(-2, -1), keepdims=True)
    )
    dist = logits.apply(lambda values: jnp.exp(values - max_logit))
    return dist / dist.integrate()


def coeffs_to_distribution(
    logits_coeffs: e3nn.IrrepsArray, grid_resolution: Tuple[int, int]
) -> e3nn.SphericalSignal:
    """Converts coefficients to a valid probability distribution.

    The coefficients are evaluated on a Gauss-Legendre spherical grid and the
    resulting values are passed through ``softmax_on_sphere``.

    Args:
        logits_coeffs: Coefficients of the unnormalized logits.
        grid_resolution: The two grid-size arguments forwarded positionally to
            ``e3nn.to_s2grid``.

    Returns:
        The result of ``softmax_on_sphere`` applied to the shifted logits.
    """
    unnormalized_logits = e3nn.to_s2grid(
        logits_coeffs, *grid_resolution, quadrature="gausslegendre", p_val=1, p_arg=-1
    )
    # Shift by the maximum for numerical stability before exponentiating. The
    # maximum is taken over the last two axes only, rather than over the whole
    # array as before, which leaves the single-signal case unchanged.
    # NOTE(review): assumes the trailing two axes are the grid and leading axes
    # are batch axes, so each signal gets its own shift -- confirm.
    grid_values = unnormalized_logits.grid_values
    unnormalized_logits.grid_values = grid_values - jnp.max(
        grid_values, axis=(-2, -1), keepdims=True
    )
    return softmax_on_sphere(unnormalized_logits)


def coeffs_to_logits(
    logits_coeffs: e3nn.IrrepsArray, grid_resolution: Tuple[int, int]
) -> e3nn.SphericalSignal:
    """Returns the pointwise log of the distribution defined by the coefficients.

    Args:
        logits_coeffs: Coefficients of the unnormalized logits.
        grid_resolution: Grid size forwarded to ``coeffs_to_distribution``.

    Returns:
        The distribution from ``coeffs_to_distribution`` with ``jnp.log``
        applied to it.
    """
    distribution = coeffs_to_distribution(logits_coeffs, grid_resolution)
    normalized_logits = distribution.apply(jnp.log)
    return normalized_logits


#@functools.partial(jax.jit, static_argnames=["num_samples", "hard_sampling", "grid_resolution"])
def gumbel_softmax_on_sphere(
    rng: chex.PRNGKey,
    logits_coeffs: e3nn.IrrepsArray,
    gumbel_temperature: float,
    num_samples: int,
    hard_sampling: bool,
    grid_resolution: Tuple[int, int],
) -> e3nn.SphericalSignal:
    """Draws differentiable samples from a distribution on the sphere using the Gumbel-Softmax trick.

    Args:
        rng: PRNG key; it is split into ``num_samples`` per-sample keys.
        logits_coeffs: Coefficients of the unnormalized logits of the distribution.
        gumbel_temperature: Divisor applied to the Gumbel-perturbed logits
            before the softmax.
        num_samples: Number of samples to draw.
        hard_sampling: If False, return the soft (relaxed) samples. If True,
            return one-hot signals, scaled to integrate to one, whose gradient
            is that of the soft samples (straight-through estimator).
        grid_resolution: Grid size forwarded to ``coeffs_to_logits``.

    Returns:
        The per-sample signals produced by mapping the sampler over the
        ``num_samples`` keys with ``jax.vmap``.
    """
    logits = coeffs_to_logits(logits_coeffs, grid_resolution)

    def single_sample_from_logits(gumbel_rng: chex.PRNGKey) -> e3nn.SphericalSignal:
        """Returns a single (soft or hard) sample from the logits."""
        # The key is always split, even for soft sampling, so the Gumbel noise
        # does not depend on ``hard_sampling``.
        gumbel_rng, sample_rng = jax.random.split(gumbel_rng)
        gumbels = jax.random.gumbel(gumbel_rng, logits.shape)
        gumbels = e3nn.SphericalSignal(gumbels, logits.quadrature)
        debug_print("shapes:", logits.grid_values.shape, gumbels.shape)
        noisy_logits = (logits + gumbels) / gumbel_temperature
        debug_print("noisy", noisy_logits)
        soft_samples = softmax_on_sphere(noisy_logits)
        if not hard_sampling:
            return soft_samples

        debug_print("soft", soft_samples)
        # We return the hard samples, but we want to propagate the gradient of the soft samples.
        # TODO: Replace with actual argmax not sample.
        argmax_indices_2D = soft_samples.sample(sample_rng)
        debug_print("2d argmax", argmax_indices_2D)
        # One-hot signal at the chosen grid point, rescaled so it integrates to one.
        hard_samples = jnp.zeros_like(soft_samples.grid_values).at[argmax_indices_2D].set(1.0)
        hard_samples = e3nn.SphericalSignal(hard_samples, logits.quadrature)
        hard_samples = hard_samples / hard_samples.integrate()
        debug_print("hard", hard_samples)
        # Straight-through estimator: the forward value is the hard sample,
        # while gradients flow only through the soft sample.
        hard_samples = jax.lax.stop_gradient(hard_samples - soft_samples) + soft_samples
        return hard_samples

    gumbel_rngs = jax.random.split(rng, num=num_samples)
    return jax.vmap(single_sample_from_logits)(gumbel_rngs)

In [None]:
# Resolution of the spherical grid, and the initial "1o" coefficients.
grid_resolution = (50, 39)
logits_coeffs = e3nn.IrrepsArray("1o", jnp.asarray([10.0, 0.0, 0.0]))

# Plot the distribution these coefficients define.
go.Figure(
    data=[
        go.Surface(
            coeffs_to_distribution(logits_coeffs, grid_resolution).plotly_surface()
        )
    ]
)

In [None]:
# Fixed key and temperature; both are reused by the loss defined further down.
rng = jax.random.PRNGKey(0)
gumbel_temperature = 0.5

samples = gumbel_softmax_on_sphere(
    rng=rng,
    logits_coeffs=logits_coeffs,
    gumbel_temperature=gumbel_temperature,
    num_samples=1000,
    hard_sampling=True,
    grid_resolution=grid_resolution,
)

# Average the samples over the leading (sample) axis and plot the result.
empirical_distribution = e3nn.SphericalSignal(
    samples.grid_values.mean(axis=0), samples.quadrature
)
go.Figure(data=[go.Surface(empirical_distribution.plotly_surface())])

# Optimization

We can optimize the coefficients so that our samples minimize some loss.
For example, we can minimize the average value of a function defined on the sphere!

In [None]:
def cost(value: jnp.ndarray, position: jnp.ndarray) -> float:
    """Weights the objective ``x^2 - y^2 + 3z`` at ``position`` by ``value``.

    Args:
        value: Weight multiplying the objective.
        position: A length-3 sequence unpacked as ``(x, y, z)``.

    Returns:
        ``value * (x ** 2 - y ** 2 + 3 * z)``.
    """
    x_coord, y_coord, z_coord = position
    objective = x_coord ** 2 - y_coord ** 2 + 3 * z_coord
    return value * objective


def loss(coeffs: e3nn.IrrepsArray) -> float:
    """Mean cost of hard Gumbel-Softmax samples drawn from ``coeffs``.

    Reads the module-level ``rng``, ``gumbel_temperature`` and
    ``grid_resolution``; the same ``rng`` is used on every call.

    Args:
        coeffs: Coefficients of the unnormalized logits.

    Returns:
        The mean cost wrapped in a one-element "0e" ``IrrepsArray``.
    """
    draws = gumbel_softmax_on_sphere(
        rng,
        coeffs,
        gumbel_temperature,
        num_samples=1000,
        hard_sampling=True,
        grid_resolution=grid_resolution,
    )
    # Move the sample axis last so the two nested vmaps run over the two
    # leading axes shared with ``grid_vectors``.
    values_per_point = jnp.transpose(draws.grid_values, (1, 2, 0))
    cost_over_grid = jax.vmap(jax.vmap(cost))
    # NOTE(review): this is a plain mean over all entries, not a
    # quadrature-weighted integral -- confirm that is intended.
    mean_cost = jnp.mean(cost_over_grid(values_per_point, draws.grid_vectors))
    return e3nn.IrrepsArray("0e", jnp.asarray([mean_cost]))

In [None]:
# Minimize the loss over the coefficients with Adam.
loss_fn = jax.jit(loss)
grad_fn = jax.jit(e3nn.grad(loss))

tx = optax.adam(learning_rate=1e-2)
opt_state = tx.init(logits_coeffs)
for step in range(1000):
    # The loss value is only used for logging; the gradients drive the update.
    loss_value = loss_fn(logits_coeffs)
    grads = grad_fn(logits_coeffs)
    updates, opt_state = tx.update(grads, opt_state)
    logits_coeffs = optax.apply_updates(logits_coeffs, updates)

    if not step % 100:
        print(f"Step {step}, loss {loss_value}")

In [None]:
# Plot the distribution defined by the optimized coefficients.
go.Figure(
    data=[
        go.Surface(
            coeffs_to_distribution(logits_coeffs, grid_resolution).plotly_surface()
        )
    ]
)