In [None]:
import jax
import jraph
import ase
import sys
import nglview as nv
# Put the parent directory on the module search path. The path is relative,
# so it resolves against the kernel's current working directory.
# NOTE(review): presumably this is what makes the project-local `analyses`
# and `symphony` packages importable below — confirm.
sys.path.append('../')

In [None]:
%load_ext autoreload

In [None]:
%autoreload 2
import analyses.analysis as analysis
from symphony.data import input_pipeline
from symphony import models
from symphony import datatypes

In [None]:
# NOTE(review): absolute path on the original author's machine. Point this at
# your own copy of the training workdir before running the notebook.
workdir = "/Users/ameyad/Documents/spherical-harmonic-net/potato_workdirs/qm9_8SEP_position_denoiser/position_updater/interactions=3/l=5/position_channels=2/channels=32/"
# Load the (model, params, config) triple stored under `workdir`, requesting
# step="best" with evaluation mode enabled.
# NOTE(review): presumably "best" selects the best checkpoint — confirm in
# analyses.analysis.load_model_at_step.
model, params, config = analysis.load_model_at_step(workdir, step="best", run_in_evaluation_mode=True)

In [None]:
# Build the starting molecule; construct_molecule returns a (molecule, name) pair.
# NOTE(review): the meaning of the "2" identifier is not visible here — confirm
# against analyses.analysis.construct_molecule.
init_molecule, init_molecule_name = analysis.construct_molecule("2")
# Convert the molecule into the graph form consumed by the model, using the
# model's atomic-number table and the cutoff stored in the loaded config.
init_fragment = input_pipeline.ase_atoms_to_jraph_graph(
    init_molecule, models.ATOMIC_NUMBERS, config.nn_cutoff
)

def add_noise_to_positions(rng, fragment: jraph.GraphsTuple, noise_std: float = 0.05):
    """Returns a copy of `fragment` with Gaussian noise added to its node positions.

    Args:
        rng: JAX PRNG key used to draw the noise.
        fragment: Graph whose `nodes.positions` are perturbed.
        noise_std: Scale applied to the standard-normal samples.

    Returns:
        A new fragment built with `_replace`; only `nodes.positions` differs
        from the input.
    """
    clean_positions = fragment.nodes.positions
    # One standard-normal sample per position entry, scaled by noise_std.
    perturbation = noise_std * jax.random.normal(rng, clean_positions.shape)
    noisy_nodes = fragment.nodes._replace(positions=clean_positions + perturbation)
    return fragment._replace(nodes=noisy_nodes)

# Derive 10 independent PRNG keys from a fixed seed so the noise is reproducible.
seed_key = jax.random.PRNGKey(0)
rngs = jax.random.split(seed_key, 10)

# Map over the keys only (in_axes=0); the same initial fragment is shared by
# every call (in_axes=None), giving one noisy copy per key.
add_noise_batched = jax.vmap(add_noise_to_positions, in_axes=(0, None))
noisy_fragments = add_noise_batched(rngs, init_fragment)

In [None]:
def fragments_to_ase_atoms(fragments: datatypes.Fragments):
    """Converts a batch of fragments into a list of `ase.Atoms`, one per fragment.

    The number of fragments is taken from the leading axis of
    `fragments.n_node`; species and positions are indexed along their
    leading axis for each fragment.

    Args:
        fragments: Batched fragments whose node arrays carry a leading batch axis.

    Returns:
        A list of `ase.Atoms` objects in batch order.
    """
    atoms_list = []
    for index in range(fragments.n_node.shape[0]):
        species = fragments.nodes.species[index]
        positions = fragments.nodes.positions[index]
        atoms_list.append(
            ase.Atoms(
                # Species indices are mapped to atomic numbers by the model helper.
                symbols=models.get_atomic_numbers(species),
                positions=positions,
            )
        )
    return atoms_list

In [None]:
# Sanity check: show the shape of every array leaf in the batched fragments.
# Uses jax.tree_util.tree_map because the top-level `jax.tree_map` alias is
# deprecated and unavailable in newer JAX releases; semantics are identical.
print(jax.tree_util.tree_map(lambda x: x.shape, noisy_fragments))

In [None]:
# Show the noisy fragments as frames of a single trajectory in nglview.
noisy_frames = fragments_to_ase_atoms(noisy_fragments)
v = nv.show_asetraj(noisy_frames, gui=True)
v.add_representation("ball+stick")
v

In [None]:
@jax.jit
def denoise_positions(fragment: jraph.GraphsTuple):
    """Applies the model's predicted position updates to a fragment.

    Reads `model` and `params` from the notebook's global scope.

    Args:
        fragment: Graph whose node positions should be updated.

    Returns:
        A new fragment whose `nodes.positions` equal the input positions plus
        the model output; other fields are carried over via `_replace`.
    """
    nodes = fragment.nodes
    # NOTE(review): the `None` argument is presumably the RNG slot of the
    # model's apply function — confirm.
    predicted_updates = model.apply(params, None, fragment)
    updated_nodes = nodes._replace(positions=nodes.positions + predicted_updates)
    return fragment._replace(nodes=updated_nodes)


# Run the denoiser over the leading (batch) axis of the noisy fragments.
denoise_batched = jax.vmap(denoise_positions)
denoised_fragments = denoise_batched(noisy_fragments)



In [None]:
# Show the denoised fragments as frames of a single trajectory in nglview.
denoised_frames = fragments_to_ase_atoms(denoised_fragments)
v = nv.show_asetraj(denoised_frames, gui=True)
v.add_representation("ball+stick")
v