In [103]:
import functools
import jax
import jax.numpy as jnp
import jraph
import optax
import sys
import ml_collections
import e3nn_jax as e3nn
import plotly.graph_objects as go
from flax.training import train_state

sys.path.append("..")

from jax import config
config.update("jax_debug_nans", True)

In [None]:
%load_ext autoreload

In [None]:
%autoreload 2
import analyses.analysis as analysis
from symphony.data import input_pipeline_tf, input_pipeline
from symphony import models, train, loss

In [None]:
workdir = '/home/songk/symphony/workdirs/e3schnet_and_nequip/interactions=3/l=5/position_channels=5/channels=64/piece=0'

In [104]:
name = analysis.name_from_workdir(workdir)
model, params, config_platonic = analysis.load_model_at_step(
    workdir, "1", run_in_evaluation_mode=True
)
config_platonic = ml_collections.FrozenConfigDict(config_platonic)
apply_fn = jax.jit(model.apply)

In [None]:
# Load the dataset.
datasets = input_pipeline_tf.get_datasets(jax.random.PRNGKey(0), config)
for step, graphs in enumerate(datasets["train"].as_numpy_iterator()):
    graphs = jax.tree_map(jnp.asarray, graphs)
    for graph in jraph.unbatch(graphs):
        if jnp.sum(graph.globals.target_position_mask) > 1:
            fragment = graph
            break
    break

In [None]:
fragment

In [None]:
preds = apply_fn(params, jax.random.PRNGKey(0), graphs, 1.0)

In [None]:
num_radii = 20
radial_bins = jnp.linspace(0.5, 1.5, num_radii)
def target_position_to_joint_distribution(
    target_positions,  # (max_targets_per_graph, 3)
    res_beta,
    res_alpha,
    quadrature,
    lmax=5,
    radius_rbf_variance = 1e-3,
    target_position_inverse_temperature = 1.0,
    ):
    true_radial_weights = jax.vmap(
        lambda pos: loss.target_position_to_radius_weights(
            pos, radius_rbf_variance, radial_bins
        )
    )(target_positions)
    log_true_angular_coeffs = jax.vmap(
        lambda pos: loss.target_position_to_log_angular_coeffs(
            pos, target_position_inverse_temperature, lmax
        )
    )(target_positions)

    assert true_radial_weights.shape[-1] == num_radii, true_radial_weights.shape
    assert log_true_angular_coeffs.shape[-1] == log_true_angular_coeffs.irreps.dim, log_true_angular_coeffs.shape
    assert true_radial_weights.shape[0] == log_true_angular_coeffs.shape[0]

    compute_joint_distribution_fn = functools.partial(
        models.compute_grid_of_joint_distribution,
        res_beta=res_beta,
        res_alpha=res_alpha,
        quadrature=quadrature,
    )
    joint_dist = jax.vmap(compute_joint_distribution_fn)(
        true_radial_weights, log_true_angular_coeffs
    )  # (max_targets_per_graph, num_radii, res_beta, res_alpha)
    return joint_dist

In [None]:
res_beta, res_alpha, quadrature = (90, 179, 'gausslegendre')

target_positions = graphs.globals.target_positions
target_position_mask = graphs.globals.target_position_mask
num_graphs = target_positions.shape[0]
compute_joint_distribution_fn = functools.partial(
    target_position_to_joint_distribution,
    res_beta=res_beta,
    res_alpha=res_alpha,
    quadrature=quadrature,
)
true_dist = jax.vmap(compute_joint_distribution_fn)(
    target_positions
)  # (num_graphs, max_targets_per_graph, num_radii, res_beta, res_alpha)
# true_dist /= true_dist.integrate().array.sum()
num_target_positions = jnp.sum(target_position_mask, axis=1).reshape(-1, 1, 1, 1)
num_target_positions = jnp.maximum(num_target_positions, 1)  # in case there are zeros (though there shouldn't be)
dist_sum = jnp.sum(true_dist.grid_values * target_position_mask.reshape(num_graphs, -1, 1, 1, 1), axis=1)
dist_mean = dist_sum / num_target_positions
mean_true_dist = e3nn.SphericalSignal(
    grid_values=dist_mean,
    quadrature=true_dist.quadrature
)

In [74]:
def loss(preds, dist):
    position_logits = preds.globals.position_logits
    log_predicted_dist = position_logits

    assert log_predicted_dist.grid_values.shape == (
        num_graphs,
        num_radii,
        res_beta,
        res_alpha,
    )
    loss_position = jax.vmap(loss.kl_divergence_on_spheres)(
        dist, log_predicted_dist
    )
    assert loss_position.shape == (num_graphs,)
    return loss_position

copy from train.py

In [109]:
loss_kwargs = config_platonic.loss_kwargs

In [115]:
# @jax.jit
def train_step(
    graphs,
    state,
    rng,
    noise_std: float,
):
    """Performs one update step over the current batch of graphs."""

    loss_rng, rng = jax.random.split(rng)
    def loss_fn(params, graphs,) -> float:
        curr_state = state.replace(params=params)
        preds = train.get_predictions(curr_state, graphs, rng=loss_rng)
        total_loss, (
            focus_and_atom_type_loss,
            position_loss,
        ) = loss.generation_loss(preds=preds, graphs=graphs, **loss_kwargs)
        mask = jraph.get_graph_padding_mask(graphs)
        mean_loss = jnp.sum(jnp.where(mask, total_loss, 0.0)) / jnp.sum(mask)
        return mean_loss, (
            total_loss,
            focus_and_atom_type_loss,
            position_loss,
            mask,
        )

    # # Add noise to positions, if required.
    # if add_noise_to_positions:
    noise_rng, rng = jax.random.split(rng)
    position_noise = (
        jax.random.normal(noise_rng, graphs.nodes.positions.shape) * noise_std
    )
    # else:
    # position_noise = jnp.zeros_like(graphs.nodes.positions)

    noisy_positions = graphs.nodes.positions + position_noise
    graphs = graphs._replace(nodes=graphs.nodes._replace(positions=noisy_positions))

    # Compute gradients.
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (
        _,
        (total_loss, focus_and_atom_type_loss, position_loss, mask),
    ), grads = grad_fn(state.params, graphs)

    # Average gradients across devices.
    # grads = jax.lax.pmean(grads, axis_name="device")
    state = state.apply_gradients(grads=grads)

    return state, total_loss, focus_and_atom_type_loss, position_loss

In [117]:
all_losses_by_hparams = {}
all_steps_by_hparams = {}
all_coeffs_by_hparams = {}

num_steps = 10000
report_every = num_steps // 50

rng = jax.random.PRNGKey(0)

for learning_rate in [1e-2]:
# for learning_rate in [1e1, 1e0, 1e-1, 1e-2, 1e-3, 1e-4]:
    
    init_rng, rng = jax.random.split(rng)
    model = models.create_model(config_platonic, run_in_evaluation_mode=False)
    params = model.init(init_rng, graphs)
    tx = optax.adam(learning_rate)
    state = train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=tx
    )

    for step in range(num_steps):
        step_rng, rng = jax.random.split(rng)
        state, total_loss, focus_and_atom_type_loss, position_loss = train_step(
            graphs,
            state,
            step_rng,
            0.0,
        )
        if step % report_every == 0 or step == num_steps - 1:
            print(f"step={step}: mean position loss={position_loss.mean()}")

    # all_losses_by_hparams[(use_mean_dist, learning_rate)] = all_losses
    # all_steps_by_hparams[(use_mean_dist, learning_rate)] = all_steps
    # all_coeffs_by_hparams[(use_mean_dist, learning_rate)] = all_coeffs

step=0: mean position loss=14.62774658203125


KeyboardInterrupt: 