In [None]:
import jax
import jax.numpy as jnp
import jraph
import sys
import ase
import e3nn_jax as e3nn

# Add the parent directory to the import search path.
# NOTE(review): presumably the project modules imported later in this notebook
# (analyses, input_pipeline, models, train) live there — confirm.
sys.path.append("..")

In [None]:
# Load IPython's autoreload extension so the %autoreload magic is available.
%load_ext autoreload

In [None]:
# Autoreload mode 2: reload modules before each cell runs, so edits to the
# project modules imported below are picked up without restarting the kernel.
%autoreload 2
import analyses.analysis as analysis
import input_pipeline_tf
import input_pipeline
import models
import train

In [None]:
# Working directory of the run to inspect; it is passed to
# analysis.name_from_workdir and analysis.load_model_at_step below.
# NOTE(review): this is an absolute path on one machine — point it at a local
# copy of the run before executing the notebook elsewhere.
workdir = '/Users/ameyad/Documents/spherical-harmonic-net/potato_workdirs/debug/nequip/num_molecules=1/scale_position_logits=True/interactions=4/l=5/channels=32'

In [None]:
# Derive a name for this run from its working directory.
name = analysis.name_from_workdir(workdir)
# Load the model, its parameters and its config from `workdir` at step "5000",
# requesting evaluation mode.
model, params, config = analysis.load_model_at_step(
    workdir, "5000", run_in_evaluation_mode=True
)
# JIT-compile the forward pass once; later cells call `apply_fn`.
apply_fn = jax.jit(model.apply)
print(config)

In [None]:
# Load the dataset.
datasets = input_pipeline_tf.get_datasets(None, config, shuffle=False)

# Species pattern of the 4-node fragment we want to debug.
target_species = jnp.asarray([1, 0, 0, 0])

# Start from None so that a failed search cannot silently reuse a `fragment`
# left over from an earlier execution of this cell.
fragment = None
for graphs in datasets["train"].as_numpy_iterator():
    # `jax.tree_map` is deprecated/removed in recent JAX releases;
    # `jax.tree_util.tree_map` is the stable spelling in old and new versions.
    graphs = jax.tree_util.tree_map(jnp.asarray, graphs)
    for graph in jraph.unbatch(graphs):
        species = graph.nodes.species
        # Check the length first so that `allclose` only ever compares
        # arrays of matching shape.
        if len(species) == 4 and jnp.allclose(species, target_species):
            fragment = graph
            break
    # Stop scanning batches as soon as a matching fragment has been found.
    if fragment is not None:
        break

if fragment is None:
    raise ValueError(
        "No 4-node fragment with species [1, 0, 0, 0] found in the training set."
    )

In [None]:
# Display the fragment selected in the previous cell.
fragment

In [None]:
# Run the jitted forward pass on the selected fragment with a fixed PRNG key.
# NOTE(review): the trailing 1.0 is presumably a sampling temperature —
# confirm against the model's call signature.
preds = apply_fn(params, jax.random.PRNGKey(0), fragment, 1.0)

In [None]:
# Evaluate the training loss on this single fragment, using the loss settings
# stored in the loaded config. The call returns the total loss together with
# three components (focus, atom type, position).
total_loss, (
    loss_focus,
    loss_atom_type,
    loss_position,
) = train.generation_loss(preds, fragment, **config.loss_kwargs)

In [None]:
# Display the total loss computed above.
total_loss

In [None]:
def safe_log(x):
    """Elementwise log that maps exact zeros to 0 (i.e. log(1)) instead of -inf.

    Defined at the top of the cell because it is first needed when the
    predicted distribution is converted back to logits below.
    """
    return jnp.log(jnp.where(x == 0, 1.0, x))


# Toy setup: one graph, two candidate radii, one target position.
num_graphs = 1
num_radii = 2
target_positions = jnp.asarray([[1., 1., 1.]])

# Disabled alternative: Gaussian radial weights centred on the target's norm.
# RADII = jnp.asarray([1., 2.])

# true_radius_weights = jax.vmap(
#     lambda target_position: jax.vmap(
#         lambda radius: jnp.exp(
#             -((radius - jnp.linalg.norm(target_position)) ** 2)
#             / (2 * 1)
#         )
#     )(RADII)
# )(target_positions)
true_radius_weights = jnp.asarray([[1., 0.]])

# Hand-written "predicted" l=1 coefficients, one row per radius
# (the second radius has all-zero coefficients).
position_coeffs = e3nn.IrrepsArray("1o", jnp.asarray([[[1., 1., 1.], [0., 0., 0.]]]))
position_logits = e3nn.to_s2grid(
    position_coeffs,
    res_beta=50,
    res_alpha=39,
    quadrature="gausslegendre",
    normalization="integral",
    p_val=1,
    p_arg=-1,
)
print(position_logits)

# Exponentiate, normalise by the integral (guarding against division by
# zero), and convert back to logits.
position_dist = position_logits.apply(jnp.exp)
integrals = position_dist.integrate().array
print(integrals)
# NOTE(review): `integrals` keeps the trailing singleton axis of `.array`;
# confirm it broadcasts against the grid values as intended.
position_dist /= jnp.where(integrals == 0, 1, integrals)
position_logits = position_dist.apply(safe_log)
print(jnp.isnan(position_logits.grid_values).sum())

# Unit vectors pointing at the targets (zero-norm targets are left as zeros).
norms = jnp.linalg.norm(target_positions, axis=-1, keepdims=True)
target_positions_unit_vectors = target_positions / jnp.where(
    norms == 0, 1, norms
)
target_positions_unit_vectors = e3nn.IrrepsArray(
    "1o", target_positions_unit_vectors
)

# Evaluate the true angular log-density on the same grid as the prediction.
res_beta, res_alpha, quadrature = (
    position_logits.res_beta,
    position_logits.res_alpha,
    position_logits.quadrature,
)
log_true_angular_dist = e3nn.to_s2grid(
    target_positions_unit_vectors,
    res_beta,
    res_alpha,
    quadrature=quadrature,
    p_val=1,
    p_arg=-1,
)
assert log_true_angular_dist.grid_values.shape == (
    num_graphs,
    res_beta,
    res_alpha,
), log_true_angular_dist.grid_values.shape

# Subtract the per-graph maximum exactly once before exponentiating, so the
# largest grid value becomes exp(0) = 1. The subsequent normalisation removes
# the constant shift.
log_true_angular_dist_max = jnp.max(
    log_true_angular_dist.grid_values, axis=(-2, -1), keepdims=True
)
log_true_angular_dist = log_true_angular_dist.apply(
    lambda x: x - log_true_angular_dist_max
)
true_angular_dist = log_true_angular_dist.apply(jnp.exp)
true_angular_dist = true_angular_dist / true_angular_dist.integrate()
assert true_angular_dist.grid_values.shape == (num_graphs, res_beta, res_alpha)
print(true_angular_dist.grid_values, position_dist.grid_values)

# Integrate the true angular distribution with the predicted logits.
cross_entropy_at_radius = (
    (true_angular_dist[:, None, :, :] * position_logits)
    .integrate()
    .array.squeeze(axis=-1)
)
assert cross_entropy_at_radius.shape == (num_graphs, num_radii)

# Per-radius integral of exp(logits).
radius_normalizing_factors = position_logits.apply(jnp.exp).integrate()
radius_normalizing_factors = radius_normalizing_factors.array.squeeze(axis=-1)
assert radius_normalizing_factors.shape == (
    num_graphs,
    num_radii,
)

# Entropy of the true distribution (angular part + radial part); subtracted
# from the loss below.
lower_bounds = (
    -(true_angular_dist * true_angular_dist.apply(safe_log)).integrate().array.squeeze(axis=-1)
)
lower_bounds += (
    -(true_radius_weights * safe_log(true_radius_weights)).sum(axis=-1)
)

# Per-graph loss: -sum_r q_r * f_r + log(sum_r Z_r) - lower_bound.
loss_position = jax.vmap(
    lambda qr, fr, Zr, lb: -jnp.sum(qr * fr) + jnp.log(jnp.sum(Zr)) - lb
)(
    true_radius_weights,
    cross_entropy_at_radius,
    radius_normalizing_factors,
    lower_bounds,
)
loss_position

In [None]:
def kl_on_sphere(true_radial, log_true_angular_coeffs, log_predicted_coeffs):
    """Compare a true radial-times-angular distribution with predicted logits.

    Relies on notebook globals defined in an earlier cell: `res_beta`,
    `res_alpha` and `quadrature` (grid settings) and `safe_log`.

    Args:
        true_radial: Weights multiplied against the true angular distribution
            to form the joint true distribution.
        log_true_angular_coeffs: Coefficients (an `e3nn.IrrepsArray`) whose
            grid evaluation is treated as the unnormalised true angular
            log-density.
        log_predicted_coeffs: Coefficients (an `e3nn.IrrepsArray`) whose grid
            evaluation is treated as the unnormalised predicted log-density.

    Returns:
        `cross_entropy + log(normaliser of the prediction) - self_entropy`,
        each term obtained by integrating over the grid and summing.
    """
    log_true_angular_dist = e3nn.to_s2grid(
        log_true_angular_coeffs,
        res_beta,
        res_alpha,
        quadrature=quadrature,
        p_val=1,
        p_arg=-1,
    )
    # Subtract the maximum exactly once before exponentiating, so the largest
    # grid value becomes exp(0) = 1. The normalisation below removes the shift.
    log_true_angular_dist_max = jnp.max(
        log_true_angular_dist.grid_values, axis=(-2, -1), keepdims=True
    )
    log_true_angular_dist = log_true_angular_dist.apply(
        lambda x: x - log_true_angular_dist_max
    )
    true_angular_dist = log_true_angular_dist.apply(jnp.exp)
    true_angular_dist = true_angular_dist / true_angular_dist.integrate()

    # Joint true distribution and its entropy.
    true_dist = true_radial * true_angular_dist[None, :, :]
    self_entropy = -(true_dist * true_dist.apply(safe_log)).integrate().array.sum()

    # Debug output: log of the true distribution projected onto "1o + 2e",
    # printed next to the predicted coefficients.
    print(e3nn.from_s2grid(true_dist.apply(safe_log), "1o + 2e"), log_predicted_coeffs)
    log_predicted_dist = e3nn.to_s2grid(
        log_predicted_coeffs,
        res_beta,
        res_alpha,
        quadrature=quadrature,
        p_val=1,
        p_arg=-1,
    )
    # Shift by the global maximum; the log-normaliser below is computed from
    # the same shifted values, so the shift cancels in the returned sum.
    log_predicted_dist_max = jnp.max(log_predicted_dist.grid_values)
    log_predicted_dist = log_predicted_dist.apply(lambda x: x - log_predicted_dist_max)
    cross_entropy = -(true_dist * log_predicted_dist).integrate().array.sum()
    normalizing_factor = jnp.log(log_predicted_dist.apply(jnp.exp).integrate().array.sum())

    return cross_entropy + normalizing_factor - self_entropy

# Example: radial weights (0.9, 0.1), one set of true "1o" coefficients, and
# one row of predicted "1o" coefficients per radius.
kl_on_sphere(
    true_radial=jnp.asarray([0.9, 0.1]),
    log_true_angular_coeffs=e3nn.IrrepsArray("1o", jnp.asarray([2., 1., 1.])),
    log_predicted_coeffs=e3nn.IrrepsArray(
        "1o", jnp.asarray([[1., 5., 1.], [1., 1., 1.]])
    ),
)



In [None]:
# Import before first use: the cell previously referenced `go` ahead of this
# import, which fails with NameError on a fresh kernel.
import plotly.graph_objects as go

# Evaluate a single set of "1o" coefficients on a 50 x 69 "soft" grid.
coeffs = e3nn.IrrepsArray("1o", jnp.asarray([1., 1., 1.]))
sig = e3nn.to_s2grid(coeffs, 50, 69, quadrature="soft", p_val=1, p_arg=-1)

# Render the signal as a 3D surface; the figure is the cell's last expression
# and is therefore displayed.
go.Figure([go.Surface(sig.plotly_surface(scale_radius_by_amplitude=True))])