In [1]:
import functools

import optax
import jax.numpy as jnp
import jax
import e3nn_jax as e3nn
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.colors
import plotly.graph_objects as go

import sys
sys.path.append('../')

In [2]:
from symphony import loss
from symphony import models

In [3]:
target_positions = jnp.asarray([[[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]])
# target_positions = jnp.asarray([[1., 0., 0.]])
target_position_inverse_temperature = 10000.
lmax = 3
res_beta = 40
res_alpha = 39

In [4]:
log_true_angular_coeffs = jax.vmap(
    lambda pos: loss.target_position_to_log_angular_coeffs(
        pos, target_position_inverse_temperature=target_position_inverse_temperature, lmax=lmax,
    )
)(target_positions)

compute_grid_of_joint_distribution_fn = functools.partial(
        models.compute_grid_of_joint_distribution,
        res_beta=res_beta,
        res_alpha=res_alpha,
        quadrature="soft",
    )
true_angular_dist = jax.vmap(
    compute_grid_of_joint_distribution_fn,
)(jnp.ones((target_positions.shape[0], target_positions.shape[1], 1)), log_true_angular_coeffs)
true_angular_dist.grid_values = true_angular_dist.grid_values[:, 0, :, :]
mean_true_angular_dist = e3nn.SphericalSignal(
    grid_values=true_angular_dist.grid_values.mean(axis=1),
    quadrature=true_angular_dist.quadrature
)
mean_true_angular_dist


SphericalSignal(shape=(1, 40, 39), res_beta=40, res_alpha=39, quadrature=soft, p_val=1, p_arg=-1)
[[[ 0.        0.        0.       ...  0.        0.        0.      ]
  [ 0.        0.        0.       ...  0.        0.        0.      ]
  [ 0.        0.        0.       ...  0.        0.        0.      ]
  ...
  [ 0.        0.        0.       ...  0.        0.        0.      ]
  [ 0.        0.        0.       ...  0.        0.        0.      ]
  [19.714487 19.714487 19.714487 ... 19.714487 19.714487 19.714487]]]

In [5]:
def coeffs_to_dist(log_predicted_angular_coeffs):
    predicted_angular_dist = compute_grid_of_joint_distribution_fn(jnp.ones((1,)), log_predicted_angular_coeffs)
    predicted_angular_dist = predicted_angular_dist[0, :, :]
    return predicted_angular_dist


def coeffs_to_logits(log_predicted_angular_coeffs):
    predicted_angular_dist = coeffs_to_dist(log_predicted_angular_coeffs)
    predicted_angular_logits = predicted_angular_dist.apply(models.safe_log)
    return predicted_angular_logits

In [6]:
def factorized_position_loss(coeffs, dist) -> jnp.ndarray:
    """Computes the loss over position probabilities using separate losses for the radial and the angular components."""
    # Radial loss is simply the negative log-likelihood loss.
    # loss_radial = -preds.globals.radial_logits.sum(axis=-1)
    loss_radial = 0.

    predicted_angular_logits = coeffs_to_logits(coeffs)
    # The angular loss is the KL divergence between the predicted and the true angular distributions.
    res_beta, res_alpha = (
        predicted_angular_logits.res_beta,
        predicted_angular_logits.res_alpha,
    )

    # jax.debug.print("max={x}", x=true_angular_dist.grid_values.max())
    # jax.debug.print("min={x}", x=true_angular_dist.grid_values.min())
    assert predicted_angular_logits.shape == (
        res_beta,
        res_alpha,
    ), (predicted_angular_logits.shape, dist.shape)
    dist.grid_values = dist.grid_values.reshape(res_beta, res_alpha)

    loss_angular = loss.kl_divergence_on_spheres(
        dist, predicted_angular_logits
    )

    loss_position = loss_radial + loss_angular
    return loss_position


In [7]:
@jax.jit
def loss_fn(coeffs, target_dist):
    return factorized_position_loss(coeffs, target_dist).mean()

@functools.partial(jax.jit, static_argnames=("tx", "use_mean_dist",))
def step_fn(rng, coeffs, opt_state, tx, use_mean_dist):
    if use_mean_dist:
        target_dist = mean_true_angular_dist
    else:
        step_rng, rng = jax.random.split(rng)
        target_index = jax.random.choice(step_rng, a=target_positions.shape[0])
        target_dist = true_angular_dist[target_index:target_index+1]

    grads = jax.grad(loss_fn)(coeffs, target_dist)
    loss_val = loss_fn(coeffs, mean_true_angular_dist)
    updates, opt_state = tx.update(grads, opt_state)
    coeffs = optax.apply_updates(coeffs, updates)
    return rng, coeffs, opt_state, loss_val

In [8]:
all_losses_by_hparams = {}
all_steps_by_hparams = {}
all_coeffs_by_hparams = {}

for use_mean_dist in [True, False]:
    for learning_rate in [1e1, 1e0, 1e-1, 1e-2, 1e-3, 1e-4]:
        
        rng = jax.random.PRNGKey(0)
        init_coeffs = e3nn.normal(e3nn.s2_irreps(lmax=lmax), rng)
        tx = optax.adam(learning_rate)
        opt_state = tx.init(init_coeffs)

        coeffs = init_coeffs
        loss_val = float(loss_fn(coeffs, mean_true_angular_dist))

        all_coeffs = []
        all_steps = []
        all_losses = []
        for step in range(10000):
            if step % 1000 == 0:
                all_coeffs.append(coeffs)
                all_steps.append(step)
                all_losses.append(float(loss_val))

            rng, coeffs, opt_state, loss_val = step_fn(rng, coeffs, opt_state, tx, use_mean_dist)
            if step == 10000 - 1:
                print(f"step={step}: loss={loss_val}")

        all_losses_by_hparams[(use_mean_dist, learning_rate)] = all_losses
        all_steps_by_hparams[(use_mean_dist, learning_rate)] = all_steps
        all_coeffs_by_hparams[(use_mean_dist, learning_rate)] = all_coeffs

step=9999: loss=2.534858226776123
step=9999: loss=0.022061586380004883
step=9999: loss=0.007269382476806641
step=9999: loss=0.04304194450378418
step=9999: loss=0.8999682664871216
step=9999: loss=2.9829752445220947


IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).