In [None]:
import os

import jax
import jax.numpy as jnp
import jraph
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

import sys

# Make the parent directory importable (presumably where the `analyses` and
# `symphony` packages imported below live). Guard the append so re-running this
# cell does not keep adding duplicate entries to `sys.path`.
if '../' not in sys.path:
    sys.path.append('../')

In [None]:
# Load IPython's autoreload extension; the following cell sets its mode.
%load_ext autoreload

In [None]:
# Reload imported modules automatically so edits to the project packages below
# are picked up without restarting the kernel.
%autoreload 2
from analyses import analysis
from symphony import models  # NOTE(review): not used in the visible cells — confirm it is needed.
from symphony.data import input_pipeline_tf

In [None]:
# Training run to analyze and the checkpoint step to restore from it.
# The default path is specific to one machine; set SYMPHONY_WORKDIR and/or
# SYMPHONY_STEP to point at a different run without editing the notebook.
workdir = os.environ.get(
    'SYMPHONY_WORKDIR',
    '/Users/ameyad/Documents/spherical-harmonic-net/potato_workdirs/platonic_solids_by_piece_radial_diffconditioning_smallerembed_noradial/e3schnet_and_mace/interactions=1/l=2/lfocus=2/position_channels=2/channels=32/apply_gate=True/square_logits=True/piece=0',
)
step = os.environ.get('SYMPHONY_STEP', '1000')

In [None]:
# Restore the model, its parameters, and its config from the chosen checkpoint.
name = analysis.name_from_workdir(workdir)
model, params, config = analysis.load_model_at_step(
    workdir,
    step,
    run_in_evaluation_mode=False,
)

# Keep the input order fixed so fragments can be inspected one step at a time.
config.shuffle_datasets = False

# Derive a dedicated PRNG key for building the datasets.
rng = jax.random.PRNGKey(config.rng_seed)
rng, dataset_rng = jax.random.split(rng)
datasets = input_pipeline_tf.get_datasets(dataset_rng, config)

# Take the first batch of the training split and strip its padding graphs.
fragments = jraph.unpad_with_graphs(
    next(datasets["train"].take(1).as_numpy_iterator())
)

# Run the model once on that batch, with both inverse temperatures set to 1.0.
apply_rng, rng = jax.random.split(rng)
preds = jax.jit(model.apply)(
    params,
    apply_rng,
    fragments,
    focus_and_atom_type_inverse_temperature=1.0,
    position_inverse_temperature=1.0,
)

In [None]:
# Display entry 1 of the position embeddings with the "0e" irreps filtered out
# (presumably an e3nn IrrepsArray indexed by node — confirm).
(
    preds.nodes.embeddings_for_positions[1]
    .filter(drop="0e")
)

In [None]:
# Show the target-position predictor settings; the radius grid in the next cell
# is built from its `radius_predictor_config`.
config.target_position_predictor

In [None]:
# Evenly spaced radii between the configured min and max radius (num_radii points).
radii = jnp.linspace(
    config.target_position_predictor.radius_predictor_config.min_radius,
    config.target_position_predictor.radius_predictor_config.max_radius,
    config.target_position_predictor.radius_predictor_config.num_radii,
)
print(radii)

sns.set_style("darkgrid")
# One figure per fragment: the predicted radial histogram, the predicted radius
# (when present), and the norm of the fragment's target position.
for index, (fragment, pred) in enumerate(zip(jraph.unbatch(fragments), jraph.unbatch(preds))):
    fig, ax = plt.subplots()

    # Counts and bin edges each carry a leading axis of size 1 (one graph per unbatched item).
    counts, edges = pred.globals.radii_pdf
    counts = counts.squeeze(axis=0)
    edges = edges.squeeze(axis=0)
    ax.stairs(counts, edges, fill=True)

    if pred.globals.radii is not None:
        # Separate name so the `radii` grid above is not overwritten by the loop.
        predicted_radius = pred.globals.radii[0]
        ax.axvline(predicted_radius, color='r', linestyle='--', linewidth=2, label='Predicted radius')

    # NOTE(review): presumably a single target per fragment, so this is a
    # one-element array; axvline expects a scalar x — confirm the shape.
    true_radii = jnp.linalg.norm(
        fragment.globals.target_positions,
        axis=-1,
    )
    ax.axvline(true_radii, color='g', linestyle='--', label='True radius')

    ax.set_xlabel('r')
    ax.set_ylabel('p(r)')
    ax.set_title(f'Radial PDF for Fragment {index}')
    ax.legend()
    plt.show()