In [None]:
# Model directory and sampling hyperparameters.
# Alternative run (per-piece model), kept for quick switching:
# workdir = "../potato_workdirs/platonic_solids_by_piece_07JUL/nequip/interactions=2/l=4/position_channels=5/channels=64/piece=3"
workdir = (
    "../potato_workdirs/platonic_solids_without_extra_irreps_with_noise"
    "/nequip/interactions=3/l=4/position_channels=5/channels=64/piece=0"
)

# Checkpoint to analyze (passed as a string to analysis.load_model_at_step).
step = "20000"

# Inverse temperatures forwarded to model.apply for sampling.
# NOTE(review): presumably 1.0 leaves the model's distributions untempered — confirm.
focus_and_atom_type_inverse_temperature = 1.0
position_inverse_temperature = 1.0

In [None]:
"""Visualize the fragments and corresponding predictions."""

import chex
import jax
import e3nn_jax as e3nn
import jax.numpy as jnp
import ase
import numpy as np
import jraph
import plotly.graph_objects as go
import plotly.subplots
import matplotlib.pyplot as plt
import seaborn as sns
import sys

sys.path.append("..")

from symphony.data import input_pipeline_tf
from symphony import models
from symphony import datatypes
from analyses import analysis

In [None]:
# Gather the parameters of every saved checkpoint, keyed by training step.
# Step 1 is loaded without a try/except, so a missing step-1 checkpoint fails loudly;
# it serves as the reference parameters for the relative-change plot.
_, reference_params, _ = analysis.load_model_at_step(workdir, 1, run_in_evaluation_mode=True)
all_params = {1: reference_params}

# Later checkpoints are probed every 2000 steps; steps without a checkpoint are skipped.
for _step in range(2000, 100000, 2000):
    try:
        _, checkpoint_params, _ = analysis.load_model_at_step(
            workdir, _step, run_in_evaluation_mode=True
        )
    except FileNotFoundError:
        continue  # No checkpoint was saved at this step.
    all_params[_step] = checkpoint_params


In [None]:
# Relative change of every parameter leaf with respect to the step-1 checkpoint,
#     ||theta_step - theta_1|| / ||theta_1||,
# tracked across all loaded checkpoints. The params tree is flattened to its leaves,
# so each curve corresponds to exactly one parameter array.
steps = list(all_params.keys())
init_params = all_params[1]

# Flatten once to obtain a stable leaf order plus a readable path label per leaf.
# tree_flatten_with_path yields leaves in the same order as tree_leaves.
init_leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(init_params)
param_labels = [jax.tree_util.keystr(path) for path, _ in init_leaves_with_paths]
num_params = len(param_labels)
param_norms_by_step = [np.zeros(len(steps)) for _ in range(num_params)]

for step_index, _step in enumerate(steps):
    # A leaf whose step-1 norm is zero yields inf/nan here (division by zero).
    relative_changes = jax.tree_util.tree_map(
        lambda param, init_param: jnp.linalg.norm(param - init_param) / jnp.linalg.norm(init_param),
        all_params[_step],
        init_params,
    )
    for param_index, relative_change in enumerate(jax.tree_util.tree_leaves(relative_changes)):
        param_norms_by_step[param_index][step_index] = relative_change

fig, ax = plt.subplots()
for label, norms in zip(param_labels, param_norms_by_step):
    ax.plot(steps, norms, label=label)

# With many leaves the legend is unreadable; enable it when inspecting a small model.
# ax.legend()
ax.set_xlabel("Step")
ax.set_ylabel("Relative parameter norm")
plt.show();

In [None]:
# Load the model, its parameters and its training config at the selected step.
name = analysis.name_from_workdir(workdir)
model, params, config = analysis.load_model_at_step(workdir, step, run_in_evaluation_mode=True)

# Build the datasets with shuffling turned off, so fragments arrive in a fixed,
# step-by-step order.
config.shuffle_datasets = False
rng, dataset_rng = jax.random.split(jax.random.PRNGKey(config.rng_seed))
datasets = input_pipeline_tf.get_datasets(dataset_rng, config)

# Take the first training batch, drop the padding graphs, and pick one fragment.
fragments = next(datasets["train"].take(1).as_numpy_iterator())
unpadded_fragments = jraph.unbatch(jraph.unpad_with_graphs(fragments))
fragment = unpadded_fragments[2]

# To compare against untrained weights, uncomment:
# random_params = model.init(jax.random.PRNGKey(0), fragments)

In [None]:
# Show the fragment's node positions next to the target position offset by node 0.
# NOTE(review): presumably target_positions is stored relative to node 0 — confirm.
(
    fragment.nodes.positions,
    fragment.nodes.positions[0] + fragment.globals.target_positions,
)

In [None]:
# Batched, jitted model.apply. Only the third positional argument (the fragments) is
# mapped over. Params, rng and both inverse temperatures are shared across the batch.
_batched_axes = (None, None, 0, None, None)
vmapped_apply = jax.jit(
    jax.vmap(model.apply, in_axes=_batched_axes)
)

In [None]:
# Build a batch of copies of `fragment`. In each copy, the last atom is shifted along a
# fixed unit direction by one of `num_displacements` evenly spaced amounts. All other
# atoms stay in place.
num_atoms = fragment.nodes.positions.shape[0]
if num_atoms == 2:
    directions = jnp.asarray([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
elif num_atoms == 3:
    # [2, 1, 1] / sqrt(6) has unit length.
    directions = jnp.asarray([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [2.0, 1.0, 1.0]]) / jnp.sqrt(6)
else:
    raise ValueError(
        "Invalid shape of fragment.nodes.positions: "
        f"{fragment.nodes.positions.shape}; expected a fragment with 2 or 3 atoms."
    )

num_displacements = 100
displacements = jnp.linspace(-0.1, 0.1, num_displacements)
# batched_displacements[i, j] = displacements[i] * directions[j]
# -> shape (num_displacements, num_atoms, 3).
batched_displacements = jnp.einsum("i,jk->ijk", displacements, directions)

# Stack `num_displacements` copies of every leaf of the fragment along a new leading axis.
# jax.tree_util.tree_map replaces the deprecated jax.tree_map alias.
displaced_fragments = jax.tree_util.tree_map(
    lambda x: jnp.repeat(x[None], num_displacements, axis=0), fragment
)
displaced_fragments = displaced_fragments._replace(
    nodes=displaced_fragments.nodes._replace(
        positions=displaced_fragments.nodes.positions + batched_displacements
    )
)

In [None]:
# config

In [None]:
# Predict for every displaced fragment in one batched call. A single `params` and a
# single `rng` are passed for the whole batch.
_inverse_temperatures = (
    focus_and_atom_type_inverse_temperature,
    position_inverse_temperature,
)
displaced_preds = vmapped_apply(params, rng, displaced_fragments, *_inverse_temperatures)

In [None]:
# Focus probability of each atom as a function of the displacement.
# NOTE(review): index 0 on the last axis presumably selects the first target species — confirm.
focus_and_target_species_probs = displaced_preds.nodes.focus_and_target_species_probs[:, :, 0]

sns.set_style("darkgrid")
# Loop over the actual atom count. JAX clamps out-of-range indices, so a hard-coded
# atom index would silently re-plot the last atom for smaller fragments.
for atom_index in range(focus_and_target_species_probs.shape[1]):
    sns.lineplot(
        x=displacements,
        y=focus_and_target_species_probs[:, atom_index],
        label=f"Atom {atom_index}",
    )
plt.legend()
plt.xlabel("Displacements")
plt.ylabel("Focus Probabilities")
plt.show();

In [None]:
# Focus logit of each atom as a function of the displacement.
# NOTE(review): index 0 on the last axis presumably selects the first target species — confirm.
focus_and_target_species_logits = displaced_preds.nodes.focus_and_target_species_logits[:, :, 0]

sns.set_style("darkgrid")
# Loop over the actual atom count. JAX clamps out-of-range indices, so a hard-coded
# atom index would silently re-plot the last atom for smaller fragments.
for atom_index in range(focus_and_target_species_logits.shape[1]):
    sns.lineplot(
        x=displacements,
        y=focus_and_target_species_logits[:, atom_index],
        label=f"Atom {atom_index}",
    )
plt.legend()
plt.xlabel("Displacements")
plt.ylabel("Focus Logits")
plt.show();

In [None]:
# Focus atom index selected by the model for each displaced fragment.
focus_indices = displaced_preds.globals.focus_indices.squeeze(axis=-1)
# Axis 1 of the displaced positions is the atom axis: (num_displacements, num_atoms, 3).
num_fragment_atoms = displaced_fragments.nodes.positions.shape[1]

sns.set_style("darkgrid")
sns.scatterplot(x=displacements, y=focus_indices, label="Focus Index")
# One tick per atom actually present, instead of a fixed [0, 1, 2].
plt.yticks(list(range(num_fragment_atoms)))
plt.xlabel("Displacements")
plt.ylabel("Focus Index")
plt.legend()
plt.show();

In [None]:
# Euclidean norm of the sampled position vector(s) for the first displacement.
first_position_vectors = displaced_preds.globals.position_vectors[0]
jnp.linalg.norm(first_position_vectors)

In [None]:
# Expected target position under the predicted position distribution, per displacement,
# plotted component-wise against the displacement.
# NOTE(review): position_probs is presumably a batched e3nn.SphericalSignal over
# (displacement, channel, radius, sphere grid) — shapes inferred from the einsum below; confirm.
position_probs = displaced_preds.globals.position_probs
# Radial bins taken from the first displacement.
# NOTE(review): presumably identical for every displacement — confirm.
all_radii = displaced_preds.globals.radial_bins[0].squeeze(axis=0)
# Normalize per displacement: integrate over the sphere, then sum over the radial axis.
position_probs /= position_probs.integrate().array.squeeze().sum(axis=-1)[:, None, None]
grid_vectors = position_probs.grid_vectors
# Weight each grid direction by its radius and probability, summing over the r and c indices.
# Result has shape (z, i, a, b) in the einsum's index names (i = vector component).
scaled_grid_vectors = jnp.einsum("abi,r,zcrab->ziab", grid_vectors, all_radii, position_probs.grid_values)
# Integrating these weighted values over the sphere gives the expected offset vector.
scaled_signal = e3nn.SphericalSignal(scaled_grid_vectors, position_probs.quadrature)
print(scaled_signal.shape)
expectation_position = scaled_signal.integrate().array.squeeze(axis=-1)
# Add the position of the focus atom chosen for each displacement.
expectation_position += displaced_fragments.nodes.positions[jnp.arange(num_displacements), displaced_preds.globals.focus_indices[:, 0]] 

print(displaced_fragments.nodes.positions[0])
sns.lineplot(x=displacements, y=expectation_position[:, 0], label="x")
sns.lineplot(x=displacements, y=expectation_position[:, 1], label="y")
sns.lineplot(x=displacements, y=expectation_position[:, 2], label="z")
plt.legend()
plt.xlabel("Displacements")
plt.ylabel("Expectation of Target Position")
plt.show();

In [None]:
# Radial marginal of the predicted position distribution for every displacement,
# drawn as a heatmap (x: displacement, y: radial bin).
position_probs = displaced_preds.globals.position_probs
# Recomputed here so this cell does not depend on another cell having defined it.
all_radii = displaced_preds.globals.radial_bins[0].squeeze(axis=0)

# Integrate over the sphere, then normalize each displacement's radial distribution.
radial_probs = position_probs.integrate().array.squeeze()
radial_probs /= radial_probs.sum(axis=-1)[:, None]
print(radial_probs.shape, radial_probs[0])
radial_argmax = radial_probs.argmax(axis=1)
radii = all_radii[radial_argmax]

sns.set_style("darkgrid")
fig, ax = plt.subplots()
# origin="lower" puts the smallest radius at the bottom. aspect="auto" lets the
# (num_radii x num_displacements) image fill the axes.
image = ax.imshow(radial_probs.T, cmap="viridis", origin="lower", aspect="auto")
ax.grid(False)  # The darkgrid style would otherwise draw lines over the heatmap.
fig.colorbar(image, ax=ax, label="Radial probability")

# Label every 20th displacement and every 5th radius with rounded values
# instead of raw float reprs.
displacement_tick_positions = np.arange(len(displacements))[::20]
ax.set_xticks(displacement_tick_positions)
ax.set_xticklabels([f"{float(d):.3f}" for d in np.asarray(displacements)[::20]])
ax.set_xlabel("Displacements")

radius_tick_positions = np.arange(len(all_radii))[::5]
ax.set_yticks(radius_tick_positions)
ax.set_yticklabels([f"{float(r):.2f}" for r in np.asarray(all_radii)[::5]])
ax.set_ylabel("Radius")
plt.show();

In [None]:
# Sampled target position in absolute coordinates. Each displacement's sampled
# position vector is shifted by the position of the focus atom chosen for it.
focus_positions = displaced_fragments.nodes.positions[
    jnp.arange(num_displacements), displaced_preds.globals.focus_indices[:, 0]
]
position_vectors = displaced_preds.globals.position_vectors[:, 0, :] + focus_positions

sns.set_style("darkgrid")
for component_index, component_name in enumerate("xyz"):
    sns.lineplot(
        x=displacements,
        y=position_vectors[:, component_index],
        label=component_name,
    )
plt.legend()
plt.xlabel("Displacements")
plt.ylabel("Sampled Target Position")
plt.show();

In [None]:
# Auxiliary node embedding of the displaced atom versus displacement. There is one
# figure per irrep block of the embedding, with one curve per component of that block.
node_embeddings = displaced_preds.nodes.auxiliary_node_embeddings
# The displacement directions move only the last atom of the fragment, so inspect that atom.
# Axis 1 of the embedding array is the atom axis.
displaced_atom_index = node_embeddings.array.shape[1] - 1
# Mask over displacements to plot. Edit it to restrict the plotted range (all by default).
valid_displacements = jnp.ones_like(displacements, dtype=bool)

sns.set_style("darkgrid")
# `irrep_slice` (not `slice`) avoids shadowing the builtin in the notebook namespace.
for mul_irrep, irrep_slice in zip(node_embeddings.irreps, node_embeddings.irreps.slices()):
    block_features = node_embeddings.array[valid_displacements, displaced_atom_index, irrep_slice]
    for component in range(mul_irrep.dim):
        sns.lineplot(x=displacements[valid_displacements], y=block_features[:, component])
    plt.xlabel("Displacements")
    plt.ylabel("Node Embedding")
    plt.title(f"mul_irrep {mul_irrep}")
    plt.show();

In [None]:
# The x, y and z components of the weighted direction signal for the first displacement,
# drawn as three spheres placed side by side along the x-axis.
first_scaled_signal = scaled_signal[0]
scaled_signal_x, scaled_signal_y, scaled_signal_z = (
    first_scaled_signal[component] for component in range(3)
)

surfaces = [
    go.Surface(scaled_signal_x.plotly_surface(translation=jnp.asarray([-2, 0, 0])), cmin=-1.0, cmax=1.0),
    go.Surface(scaled_signal_y.plotly_surface(), cmin=-1.0, cmax=1.0),
    go.Surface(scaled_signal_z.plotly_surface(translation=jnp.asarray([2, 0, 0])), cmin=-1.0, cmax=1.0),
]
go.Figure(surfaces)