In [19]:
import os
import pickle
import sys

import ase
import ase.data
import ase.io
import ase.visualize
import jax
import jax.numpy as jnp
import jraph
import ml_collections
import tqdm
import yaml

# Make the project-local modules imported in the next cell (datatypes,
# input_pipeline, qm9, models) importable from the parent directory.
# Assumes the notebook is launched from a subdirectory of the repository root — TODO confirm.
sys.path.append('..')

In [20]:
import datatypes
import input_pipeline
import qm9
from models import ATOMIC_NUMBERS, RADII, create_model

In [21]:
# Training run to analyse. Sibling runs under WORKDIR include e.g.
#   v3/e3schnet/interactions=4/l=3/channels=32
#   v3/mace/interactions=3/l=2/channels=32
# A later cell derives the run name from the last five components of `path`.
WORKDIR = "/home/ameyad/spherical-harmonic-net/workdirs"
path = os.path.join(WORKDIR, "v3/mace/interactions=4/l=3/channels=32")

# Both files are locally produced training artifacts. pickle.load and
# yaml.unsafe_load can execute arbitrary code, so only point `path` at trusted runs.
with open(os.path.join(path, "checkpoints", "params.pkl"), "rb") as f:
    params = pickle.load(f)
with open(os.path.join(path, "config.yml"), "rt") as config_file:
    config = yaml.unsafe_load(config_file)

assert config is not None, f"Empty config file: {os.path.join(path, 'config.yml')}"
config = ml_collections.ConfigDict(config)

In [22]:
# Run name = last five path components joined by "_",
# e.g. "v3_mace_interactions=4_l=3_channels=32".
# Strip trailing slashes first so "…/channels=32/" does not yield an empty component.
run_components = path.rstrip("/").split("/")[-5:]
name = "_".join(run_components)
print(name)

v3_mace_interactions=4_l=3_channels=32


In [24]:
model = create_model(config, run_in_evaluation_mode=True)
apply_fn = jax.jit(model.apply)


def apply(frag, seed, n_node=32, n_edge=1024, n_graph=2):
    """Run the model on a single fragment and return its unpadded predictions.

    The fragment is padded to a fixed size before calling the jitted
    `apply_fn`, so every call sees the same input shapes.

    Args:
        frag: a `jraph.GraphsTuple` for one fragment (as built by
            `input_pipeline.ase_atoms_to_jraph_graph`).
        seed: the PRNG key forwarded to `model.apply`. Callers pass a key
            from `jax.random.split`, not an integer seed.
        n_node: total node count after padding. jraph requires at least one
            padding node, so `frag` may hold at most `n_node - 1` atoms.
        n_edge: total edge count after padding. It must be >= the fragment's
            edge count.
        n_graph: total graph count after padding. It must be >= 2: the
            fragment plus at least one padding graph.

    Returns:
        The model predictions with the padding graph removed.
    """
    padded = jraph.pad_with_graphs(frag, n_node, n_edge, n_graph)
    padded_preds = apply_fn(params, seed, padded)
    return jraph.unpad_with_graphs(padded_preds)

In [25]:
def append_pred_to_ase_atoms(molecule: ase.Atoms, pred: datatypes.Predictions) -> ase.Atoms:
    """Return a new `ase.Atoms` equal to `molecule` plus one predicted atom.

    The fields read from `pred.globals` (focus index, relative position
    vector, species index) are each squeezed along their leading axis,
    presumably the single-graph batch axis. The new atom is placed at
    `molecule.positions[focus] + position_vector`. Its atomic number is
    looked up in `ATOMIC_NUMBERS` by the predicted species index.
    Only positions and atomic numbers are carried into the result.
    """
    g = pred.globals

    species_index = g.target_species.squeeze(0).item()
    atomic_number = jnp.array(ATOMIC_NUMBERS[species_index])

    anchor = molecule.positions[g.focus_indices.squeeze(0)]
    added_position = anchor + g.position_vectors.squeeze(0)

    return ase.Atoms(
        positions=jnp.vstack([molecule.positions, added_position]),
        numbers=jnp.append(molecule.numbers, atomic_number),
    )


In [26]:
NUM_SEEDS = 64      # independent generation attempts, one per PRNG seed
MAX_NUM_ATOMS = 32  # hard cap on molecule size. Fragments passed to `apply`
                    # have at most MAX_NUM_ATOMS - 1 atoms, which must fit the
                    # node padding used inside `apply`.

molecules = []

for seed in tqdm.tqdm(range(NUM_SEEDS)):
    # Every attempt starts from a single carbon atom at the origin.
    molecule = ase.Atoms(
        positions=jnp.array([[0, 0, 0.0]]),
        numbers=jnp.array([6]),
    )

    rng = jax.random.PRNGKey(seed)
    stopped = False
    # Each step adds exactly one atom unless the model predicts "stop".
    for _ in range(MAX_NUM_ATOMS - 1):
        k, rng = jax.random.split(rng)
        frag = input_pipeline.ase_atoms_to_jraph_graph(
            molecule, ATOMIC_NUMBERS, config.nn_cutoff
        )
        pred = apply(frag, k)

        if pred.globals.stop.squeeze(0).item():
            stopped = True
            break

        molecule = append_pred_to_ase_atoms(molecule, pred)

    # Keep only molecules the model chose to finish. Ones that used every step
    # (and so reached MAX_NUM_ATOMS) are treated as truncated and dropped.
    if stopped:
        molecules.append(molecule)


100%|██████████| 64/64 [00:52<00:00,  1.22it/s]


In [27]:
# Write each kept molecule to gen/<run name>/molecule_<i>.xyz.
# NOTE: <i> is the position in the filtered `molecules` list, not the PRNG seed
# that produced the molecule, because the generation cell filters some seeds out.
out_dir = os.path.join("gen", name)
os.makedirs(out_dir, exist_ok=True)  # also creates "gen/" itself on a fresh checkout

for index, molecule in enumerate(molecules):
    ase.io.write(os.path.join(out_dir, f"molecule_{index}.xyz"), molecule)