In [14]:
import os
import pickle
import sys

import ase
import ase.data
import ase.io
import ase.visualize
import jax
import jax.numpy as jnp
import jraph
import ml_collections
import yaml

# Make modules in the parent directory (resolved relative to the kernel's
# working directory) importable; the local imports in the next cell need it.
sys.path.extend([".."])

In [2]:
import datatypes
import input_pipeline
import qm9
from models import ATOMIC_NUMBERS, RADII, create_model

In [3]:
# Work directory of the trained run to sample from. It must contain
# `checkpoints/params.pkl` and `config.yml`.
# Alternative run previously used here:
#   /home/ameyad/spherical-harmonic-net/workdirs/v3/mace/interactions=3/l=2/channels=32
path = "/home/ameyad/spherical-harmonic-net/workdirs/v3/e3schnet/interactions=4/l=3/channels=32"

# NOTE(security): both `pickle.load` and `yaml.unsafe_load` can execute
# arbitrary code embedded in the file. Only point `path` at checkpoints you
# produced yourself.
with open(path + "/checkpoints/params.pkl", "rb") as f:
    params = pickle.load(f)
with open(path + "/config.yml", "rt") as config_file:
    config = yaml.unsafe_load(config_file)

# An empty YAML document loads as None; fail loudly before wrapping it.
assert config is not None, f"config.yml under {path} is empty"
config = ml_collections.ConfigDict(config)

In [4]:
# import profile_nn_jax
# import logging
# profile_nn_jax.enable()
# logging.getLogger().setLevel(logging.INFO)  # Important to see the messages!

In [5]:
model = create_model(config, run_in_evaluation_mode=True)
apply_fn = jax.jit(model.apply)


def apply(frag, seed, n_node=32, n_edge=1024, n_graph=2):
    """Run the jitted model on one fragment and return the unpadded prediction.

    The fragment is padded to fixed sizes, so every call to the jitted
    `apply_fn` sees the same input shapes and does not trigger a recompile.

    Args:
      frag: a `jraph.GraphsTuple` holding the current molecular fragment.
      seed: the JAX PRNG key forwarded to `model.apply`.
      n_node: total node count after padding. jraph requires this to be
        strictly greater than the number of nodes in `frag`.
      n_edge: total edge count after padding. It must be at least the number
        of edges in `frag`.
      n_graph: total graph count after padding (the fragment plus padding
        graph(s)).

    Returns:
      The model output with the padding graph(s) removed.
    """
    frags = jraph.pad_with_graphs(frag, n_node, n_edge, n_graph)
    preds = apply_fn(params, seed, frags)
    return jraph.unpad_with_graphs(preds)

In [9]:
# List the chemical symbols of the species the model's predicted
# `target_species` index selects from, in the same order as ATOMIC_NUMBERS.
symbols = [ase.data.chemical_symbols[atomic_number] for atomic_number in ATOMIC_NUMBERS]
for symbol in symbols:
    print(symbol)

H
C
N
O
F


In [10]:
NUM_SEEDS = 16
# Each step adds one atom to a fragment that starts with a single atom, so the
# largest fragment passed to `apply` has MAX_STEPS atoms. That count must stay
# strictly below the node padding `apply` uses (32 by default).
MAX_STEPS = 31


def generate_molecule(seed, max_steps=MAX_STEPS):
    """Grow a molecule one atom at a time, starting from a single carbon atom at the origin.

    At each step the model predicts:
      - whether to stop;
      - which existing atom to use as the focus;
      - the species of the new atom;
      - the new atom's position relative to the focus.

    Args:
      seed: integer used to build the PRNG key. A fresh subkey is split off
        for every model call.
      max_steps: maximum number of atoms to add.

    Returns:
      The generated `ase.Atoms`.
    """
    molecule = ase.Atoms(
        positions=jnp.array([[0, 0, 0.0]]),
        numbers=jnp.array([6]),
    )

    rng = jax.random.PRNGKey(seed)
    for _ in range(max_steps):
        step_rng, rng = jax.random.split(rng)
        frag = input_pipeline.ase_atoms_to_jraph_graph(molecule, ATOMIC_NUMBERS, config.cutoff)
        pred = apply(frag, step_rng)

        # End generation as soon as the model predicts `stop`.
        if pred.globals.stop.squeeze(0):
            break

        focus = pred.globals.focus_indices.squeeze(0)
        pos_focus = frag.nodes.positions[focus]
        pos_rel = pred.globals.position_vectors.squeeze(0)
        species_index = pred.globals.target_species.squeeze(0).item()
        specie = jnp.array(ATOMIC_NUMBERS[species_index])

        # The new atom is placed at an offset from the chosen focus atom.
        position = pos_focus + pos_rel

        molecule = ase.Atoms(
            positions=jnp.concatenate([molecule.positions, position[None, :]], axis=0),
            numbers=jnp.concatenate([molecule.numbers, specie[None]], axis=0),
        )

    return molecule


molecules = [generate_molecule(seed) for seed in range(NUM_SEEDS)]

In [12]:
# Create the output directory first; otherwise writing fails on a fresh
# checkout where `gen/` does not exist yet.
os.makedirs("gen", exist_ok=True)
for seed, molecule in enumerate(molecules):
    ase.io.write(f"gen/molecule_{seed}.xyz", molecule)

In [20]:
# Display the molecule generated from seed 1 with ASE's X3D viewer.
molecule_to_show = molecules[1]
ase.visualize.view(molecule_to_show, viewer="x3d")