In [2]:
import os
import sys

import ase
import e3nn_jax as e3nn
import jax
import jax.numpy as jnp
import jraph

# Add the parent directory to the import path (presumably where the local
# modules imported further down live -- confirm against the repo layout).
# The membership check keeps the cell idempotent: re-running it no longer
# appends a duplicate ".." entry to sys.path each time.
if ".." not in sys.path:
    sys.path.append("..")

In [2]:
# Load IPython's autoreload extension so the %autoreload magic is available.
%load_ext autoreload

In [3]:
# Mode 2: re-import modules before running each cell, so edits to the local
# project modules are picked up without restarting the kernel.
%autoreload 2
import analyses.analysis as analysis
import input_pipeline_tf
import input_pipeline
import models
import train

In [4]:
# Working directory of the training run to analyse.
# The default is the original author's local path; set the
# SPHERICAL_HARMONIC_NET_WORKDIR environment variable to point the notebook
# at a different run without editing this cell.
workdir = os.environ.get(
    "SPHERICAL_HARMONIC_NET_WORKDIR",
    '/Users/ameyad/Documents/spherical-harmonic-net/potato_workdirs/debug/nequip/num_molecules=1/scale_position_logits=True/interactions=4/l=5/channels=32',
)

In [18]:
# Derive a short name for this run from its working directory.
name = analysis.name_from_workdir(workdir)

# Restore the model, parameters and config for step "5000", in evaluation mode,
# and show the configuration that was restored.
model, params, config = analysis.load_model_at_step(workdir, "5000", run_in_evaluation_mode=True)
print(config)

# JIT-compiled forward pass, reused by the cells below.
apply_fn = jax.jit(model.apply)

activation: softplus
avg_num_neighbors: 300.0
compute_padding_dynamically: false
dataset: qm9
eval_every_steps: 5000
even_activation: swish
focus_predictor:
  latent_size: 128
  num_layers: 3
learning_rate: 0.001
learning_rate_schedule: constant
learning_rate_schedule_kwargs:
  decay_steps: 50000
  init_value: 0.001
  peak_value: 0.002
  warmup_steps: 2000
log_every_steps: 1000
loss_kwargs:
  radius_rbf_variance: 0.05
  scale_position_logits_by_inverse_temperature: true
  target_position_inverse_temperature: 50.0
max_ell: 5
max_n_edges: 2880
max_n_graphs: 32
max_n_nodes: 960
mlp_activation: swish
mlp_n_layers: 2
model: NequIP
nn_cutoff: 5.0
nn_tolerance: 0.5
num_basis_fns: 8
num_channels: 32
num_eval_steps: 100
num_eval_steps_at_end_of_training: 5000
num_interactions: 4
num_train_steps: 30000
odd_activation: tanh
optimizer: adam
r_max: 5
rng_seed: 0
root_dir: /Users/ameyad/Documents/qm9_data_tf/data_tf2
target_position_predictor:
  res_alpha: 359
  res_beta: 180
target_species_predicto

In [19]:
# Load the dataset.
datasets = input_pipeline_tf.get_datasets(None, config, shuffle=False)

# Pick the first 4-node fragment whose species are [1, 0, 0, 0].
# Reset `fragment` first so a value left over from an earlier run of this cell
# can never be mistaken for a fresh match.
fragment = None
for graphs in datasets["train"].as_numpy_iterator():
    # jax.tree_util.tree_map is the supported spelling; the top-level
    # jax.tree_map alias is deprecated.
    graphs = jax.tree_util.tree_map(jnp.asarray, graphs)
    for graph in jraph.unbatch(graphs):
        if len(graph.nodes.species) == 4 and jnp.allclose(graph.nodes.species, jnp.asarray([1, 0, 0, 0])):
            fragment = graph
            break
    # Only the first training batch is searched (as in the original cell).
    break

if fragment is None:
    raise ValueError(
        "No 4-node fragment with species [1, 0, 0, 0] found in the first training batch."
    )

In [20]:
# Display the selected fragment (nodes, edges, globals) for inspection.
fragment

GraphsTuple(nodes=FragmentsNodes(positions=Array([[-1.2700e-02,  1.0858e+00,  8.0000e-03],
       [-5.2380e-01,  1.4379e+00,  9.0640e-01],
       [ 1.0117e+00,  1.4638e+00,  3.0000e-04],
       [ 2.2000e-03, -6.0000e-03,  2.0000e-03]], dtype=float32), species=Array([1, 0, 0, 0], dtype=int32), focus_probability=Array([1., 0., 0., 0.], dtype=float32)), edges=Array([[1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.]], dtype=float32), receivers=Array([0, 0, 0, 3, 3, 3, 2, 2, 2, 1, 1, 1], dtype=int32), senders=Array([3, 2, 1, 0, 2, 1, 0, 3, 1, 0, 3, 2], dtype=int32), globals=FragmentsGlobals(stop=Array([False], dtype=bool), target_positions=Array([[-0.5281,  0.3617, -0.8846]], dtype=float32), target_species=Array([0], dtype=int32), target_species_probability=Array([[1., 0., 0., 0., 0.]], dtype=float32)), n_node=Array([4], dtype=int32), n_edge=Array([12], dtype=int32))

In [21]:
# Run the jitted model on the single selected fragment with a fixed PRNG key.
preds = apply_fn(
    params,
    jax.random.PRNGKey(0),
    fragment,
    1.0,  # NOTE(review): meaning of this scalar is not visible here -- confirm against model.apply
)

In [22]:
# Evaluate the training loss on the fragment, using the loss settings stored
# in the restored config, then split out its three components.
total_loss, _loss_components = train.generation_loss(preds, fragment, **config.loss_kwargs)
loss_focus, loss_atom_type, loss_position = _loss_components

In [23]:
# Display the total loss returned by train.generation_loss for this fragment.
total_loss

Array([3.7652333], dtype=float32)

In [49]:
def safe_log(x):
    """Elementwise natural log that returns 0 (instead of -inf) where x == 0.

    Exact zeros are replaced by 1 before taking the log, so they contribute
    nothing to products of the form ``p * log(p)``.
    """
    return jnp.log(jnp.where(x == 0, 1.0, x))


# Toy inputs: one graph, two candidate radii, one target position.
num_graphs = 1
num_radii = 2
target_positions = jnp.asarray([[1., 1., 1.]])

# RADII = jnp.asarray([1., 2.])

# true_radius_weights = jax.vmap(
#     lambda target_position: jax.vmap(
#         lambda radius: jnp.exp(
#             -((radius - jnp.linalg.norm(target_position)) ** 2)
#             / (2 * 1)
#         )
#     )(RADII)
# )(target_positions)
true_radius_weights = jnp.asarray([[1., 0.]])

# Predicted logits: one "1o" coefficient vector per radius, evaluated on the sphere.
position_coeffs = e3nn.IrrepsArray("1o", jnp.asarray([[[1., 1., 1.], [0., 0., 0.]]]))
position_logits = e3nn.to_s2grid(
    position_coeffs,
    res_beta=50,
    res_alpha=39,
    quadrature="gausslegendre",
    normalization="integral",
    p_val=1,
    p_arg=-1,
)
print(position_logits)

# Normalise exp(logits) separately for every (graph, radius) sphere.
position_dist = position_logits.apply(jnp.exp)
integrals = position_dist.integrate().array
print(integrals)
safe_integrals = jnp.where(integrals == 0, 1, integrals)
# `integrals` keeps a trailing axis of length 1 (it is squeezed away further
# down), so one extra axis lines it up with the (beta, alpha) grid axes.
# Dividing the grid values directly makes the broadcasting explicit instead of
# relying on how SphericalSignal division treats a raw array.
# NOTE(review): the original `position_dist /= <array>` may be what tripped the
# shape assertion recorded in this cell's output -- confirm by re-running.
position_dist = position_dist.apply(lambda values: values / safe_integrals[..., None])
position_logits = position_dist.apply(safe_log)
print(jnp.isnan(position_logits.grid_values).sum())

# Unit vector towards each target position (zero vectors are left as zeros).
norms = jnp.linalg.norm(target_positions, axis=-1, keepdims=True)
target_positions_unit_vectors = target_positions / jnp.where(
    norms == 0, 1, norms
)
target_positions_unit_vectors = e3nn.IrrepsArray(
    "1o", target_positions_unit_vectors
)
res_beta, res_alpha, quadrature = (
    position_logits.res_beta,
    position_logits.res_alpha,
    position_logits.quadrature,
)
log_true_angular_dist = e3nn.to_s2grid(
    target_positions_unit_vectors,
    res_beta,
    res_alpha,
    quadrature=quadrature,
    p_val=1,
    p_arg=-1,
)
assert log_true_angular_dist.grid_values.shape == (
    num_graphs,
    res_beta,
    res_alpha,
), log_true_angular_dist.grid_values.shape
#print(log_true_angular_dist, position_logits)

# Shift by the per-graph maximum before exponentiating, for numerical stability.
# The shift is applied exactly once: the original subtracted the maximum a
# second time inside the exp, which only rescales the result (the constant is
# removed by the normalisation below) but can underflow for large maxima.
log_true_angular_dist_max = jnp.max(
    log_true_angular_dist.grid_values, axis=(-2, -1), keepdims=True
)
log_true_angular_dist = log_true_angular_dist.apply(lambda x: x - log_true_angular_dist_max)
true_angular_dist = log_true_angular_dist.apply(jnp.exp)
true_angular_dist = true_angular_dist / true_angular_dist.integrate()
assert true_angular_dist.grid_values.shape == (num_graphs, res_beta, res_alpha)
print(true_angular_dist.grid_values, position_dist.grid_values)

# Integrate the true angular distribution with the predicted logits.
cross_entropy_at_radius = (
    (true_angular_dist[:, None, :, :] * position_logits)
    .integrate()
    .array.squeeze(axis=-1)
)
assert cross_entropy_at_radius.shape == (num_graphs, num_radii)


# Integral of exp(logits) over each sphere, one value per (graph, radius).
radius_normalizing_factors = position_logits.apply(jnp.exp).integrate()
radius_normalizing_factors = radius_normalizing_factors.array.squeeze(axis=-1)
assert radius_normalizing_factors.shape == (
    num_graphs,
    num_radii,
)


# Entropy of the target distribution (angular part plus radial part); it is
# subtracted from the loss below.
lower_bounds = (
    -(true_angular_dist * true_angular_dist.apply(safe_log)).integrate().array.squeeze(axis=-1)
)
lower_bounds += (
    -(true_radius_weights * safe_log(true_radius_weights)).sum(axis=-1)
)

# Per graph: -sum_r q_r * f_r + log(sum_r Z_r) - lower_bound.
loss_position = jax.vmap(
    lambda qr, fr, Zr, lb: -jnp.sum(qr * fr) + jnp.log(jnp.sum(Zr)) - lb
)(
    true_radius_weights,
    cross_entropy_at_radius,
    radius_normalizing_factors,
    lower_bounds,
)
loss_position

SphericalSignal(shape=(1, 2, 50, 39), res_beta=50, res_alpha=39, quadrature=gausslegendre, p_val=1, p_arg=-1)
[[[[-1.6476383 -1.6354802 -1.6257724 ... -1.6953983 -1.6779903
    -1.6619316]
   [-1.5327659 -1.5049033 -1.4826561 ... -1.6422174 -1.6023237
    -1.5655222]
   [-1.4113327 -1.3677797 -1.3330044 ... -1.5824196 -1.5200607
    -1.462535 ]
   ...
   [ 2.0020342  2.045587   2.0803626 ...  1.8309472  1.8933063
     1.9508318]
   [ 1.910662   1.9385245  1.9607718 ...  1.8012105  1.8411041
     1.8779057]
   [ 1.8125366  1.8246946  1.8344026 ...  1.7647766  1.7821845
     1.7982433]]

  [[ 0.         0.         0.        ...  0.         0.
     0.       ]
   [ 0.         0.         0.        ...  0.         0.
     0.       ]
   [ 0.         0.         0.        ...  0.         0.
     0.       ]
   ...
   [ 0.         0.         0.        ...  0.         0.
     0.       ]
   [ 0.         0.         0.        ...  0.         0.
     0.       ]
   [ 0.         0.         0.        ...

AssertionError: 

In [84]:
def kl_on_sphere(true_radial, log_true_angular_coeffs, log_predicted_coeffs):
    """KL divergence between a target and a predicted distribution on spheres.

    The target is the product of radial weights and a normalised angular
    distribution built from ``log_true_angular_coeffs``. The prediction is
    defined by ``log_predicted_coeffs`` and normalised jointly over all radii
    through a single log-normaliser.

    Relies on notebook globals defined in an earlier cell: ``res_beta``,
    ``res_alpha``, ``quadrature`` and ``safe_log``.

    Args:
        true_radial: weights over radii (presumably shape ``(num_radii,)`` and
            summing to 1 -- TODO confirm).
        log_true_angular_coeffs: ``e3nn.IrrepsArray`` of spherical-harmonic
            coefficients for the log of the target angular distribution.
        log_predicted_coeffs: ``e3nn.IrrepsArray`` of coefficients for the
            predicted log-density (presumably one row per radius -- TODO confirm).

    Returns:
        Scalar ``cross_entropy + log_normaliser - self_entropy``.
    """
    log_true_angular_dist = e3nn.to_s2grid(
        log_true_angular_coeffs,
        res_beta,
        res_alpha,
        quadrature=quadrature,
        p_val=1,
        p_arg=-1,
    )
    # Shift by the maximum before exponentiating, for numerical stability.
    # The shift is applied exactly once: subtracting the maximum again inside
    # the exp (as the original did) is removed by the normalisation below but
    # can underflow when the maximum is large.
    log_true_angular_dist_max = jnp.max(
        log_true_angular_dist.grid_values, axis=(-2, -1), keepdims=True
    )
    log_true_angular_dist = log_true_angular_dist.apply(lambda x: x - log_true_angular_dist_max)
    true_angular_dist = log_true_angular_dist.apply(jnp.exp)
    true_angular_dist = true_angular_dist / true_angular_dist.integrate()

    # Target distribution: radial weight times the angular distribution.
    true_dist = true_radial * true_angular_dist[None, :, :]
    self_entropy = -(true_dist * true_dist.apply(safe_log)).integrate().array.sum()

    # Sanity check: project log(true_dist) back onto "1o + 2e" coefficients and
    # print them next to the predicted coefficients.
    print(e3nn.from_s2grid(true_dist.apply(safe_log), "1o + 2e"), log_predicted_coeffs)
    log_predicted_dist = e3nn.to_s2grid(
        log_predicted_coeffs,
        res_beta,
        res_alpha,
        quadrature=quadrature,
        p_val=1,
        p_arg=-1,
    )
    # A single global maximum is used because the normaliser is shared by all radii.
    log_predicted_dist_max = jnp.max(log_predicted_dist.grid_values)
    log_predicted_dist = log_predicted_dist.apply(lambda x: x - log_predicted_dist_max)
    cross_entropy = -(true_dist * log_predicted_dist).integrate().array.sum()
    normalizing_factor = jnp.log(log_predicted_dist.apply(jnp.exp).integrate().array.sum())

    return cross_entropy + normalizing_factor - self_entropy

# Example: two radii weighted 0.9 / 0.1, a "1o" target angular signal, and one
# row of predicted "1o" coefficients per radius.
kl_on_sphere(
    jnp.asarray([0.9, 0.1]),
    e3nn.IrrepsArray("1o", jnp.asarray([2., 1., 1.])),
    e3nn.IrrepsArray("1o", jnp.asarray([[1., 5., 1.], [1., 1., 1.]])),
)



1x1o+1x2e
[[ 1.9999998e+00  9.9999988e-01  1.0000001e+00 -1.8529313e-08
   0.0000000e+00  1.0430813e-07 -2.9802322e-08  6.1211871e-08]
 [ 1.9999998e+00  1.0000010e+00  1.0000002e+00  9.9738529e-09
   0.0000000e+00 -4.8428774e-07 -8.9406967e-08  9.0113069e-08]] 1x1o
[[1. 5. 1.]
 [1. 1. 1.]]


Array(2.8226643, dtype=float32)

In [15]:
# Import before first use: the original called `go.Surface` ahead of this
# import, which fails with NameError on a fresh kernel.
import plotly.graph_objects as go

# Evaluate a "1o" signal on a 50 x 69 "soft" quadrature grid and render it as a
# surface whose radius is scaled by the signal amplitude.
coeffs = e3nn.IrrepsArray("1o", jnp.asarray([1., 1., 1.]))
sig = e3nn.to_s2grid(coeffs, 50, 69, quadrature="soft", p_val=1, p_arg=-1)

go.Figure([go.Surface(sig.plotly_surface(scale_radius_by_amplitude=True))])

ValueError: Mime type rendering requires nbformat>=4.2.0 but it is not installed