In [1]:
import os
import pickle
import sys

import ase
import ase.data
import ase.io
import ase.visualize
import jax
import jax.numpy as jnp
import jraph
import ml_collections
import tqdm
import yaml

sys.path.append('..')

In [2]:
import datatypes
import input_pipeline
import qm9
from models import ATOMIC_NUMBERS, RADII, create_model

In [3]:
# Working directory of the trained model to sample from. Exactly one `path`
# assignment should be active; the alternatives are kept commented out so that
# no assignment is silently overwritten by a later one.
# path = "/home/ameyad/spherical-harmonic-net/workdirs/v3/mace/interactions=3/l=2/channels=32"
# path = "/home/ameyad/spherical-harmonic-net/workdirs/v3/e3schnet/interactions=4/l=3/channels=32"
path = "/home/ameyad/spherical-harmonic-net/workdirs/v3/mace/interactions=4/l=3/channels=32"

# NOTE(review): both pickle.load and yaml.unsafe_load can execute arbitrary
# code while deserializing. Only point `path` at workdirs you trust.
with open(os.path.join(path, "checkpoints", "params.pkl"), "rb") as f:
    params = pickle.load(f)
with open(os.path.join(path, "config.yml"), "rt") as config_file:
    config = yaml.unsafe_load(config_file)

# An empty YAML document deserializes to None; fail with a readable message.
assert config is not None, f"Empty config file: {os.path.join(path, 'config.yml')}"
config = ml_collections.ConfigDict(config)

FileNotFoundError: [Errno 2] No such file or directory: '/home/ameyad/spherical-harmonic-net/workdirs/v3/mace/interactions=4/l=3/channels=32/checkpoints/params.pkl'

In [4]:
# Short run identifier: the last five components of the workdir path joined
# with underscores. Used below as the output sub-directory name.
name = "_".join(path.split("/")[-5:])
print(name)

v3_mace_interactions=4_l=3_channels=32


In [13]:
import ase

# Toy linear H-H-O-H arrangement (atoms listed in H, H, O, H order) used to
# compare the two chemical-formula conventions ASE offers.
m = ase.Atoms(
    positions=[[0, 0, 0], [0, 0, 1], [0, 0, 3], [0, 0, 2]],
    numbers=[1, 1, 8, 1],
)
# Show the 'reduce' and 'hill' formulas side by side as the cell output.
tuple(m.get_chemical_formula(mode=mode) for mode in ('reduce', 'hill'))

('H2OH', 'H3O')

In [14]:
import numpy as np
from collections import Counter

qm9_dataset = qm9.load_qm9("qm9_data")

# Count how many dataset entries share each Hill-notation formula.
# Only the last expression of a cell is displayed, so a separate per-molecule
# formula list computed before it would be thrown away; the formulas are
# streamed directly into the Counter instead.
Counter(m.get_chemical_formula(mode='hill') for m in qm9_dataset)

Counter({'CH4': 1,
         'H3N': 1,
         'H2O': 1,
         'C2H2': 1,
         'CHN': 1,
         'CH2O': 1,
         'C2H6': 1,
         'CH4O': 1,
         'C3H4': 1,
         'C2H3N': 1,
         'C2H4O': 2,
         'CH3NO': 1,
         'C3H8': 1,
         'C2H6O': 2,
         'C3H6': 1,
         'C3H6O': 5,
         'C2H5NO': 2,
         'CH4N2O': 1,
         'C4H10': 2,
         'C3H8O': 3,
         'C4H2': 1,
         'C3HN': 1,
         'C2N2': 1,
         'C3H2O': 1,
         'C2HNO': 1,
         'C2H2O2': 1,
         'C4H6': 2,
         'C3H5N': 1,
         'C2H4N2': 1,
         'C3H4O': 1,
         'C2H3NO': 1,
         'C2H4O2': 2,
         'C2H6O2': 1,
         'C4H8': 2,
         'C3H7N': 1,
         'C3H7NO': 8,
         'C4H5N': 3,
         'C3H4N2': 3,
         'C4H4O': 5,
         'C3H3NO': 6,
         'C5H12': 3,
         'C4H10O': 7,
         'C2H3N3': 5,
         'C3H4O2': 4,
         'C2H4N2O': 2,
         'C2H3NO2': 3,
         'C5H8': 6,
         'C4H7N':

In [24]:
# Build the model in evaluation mode and JIT-compile its apply function once.
model = create_model(config, run_in_evaluation_mode=True)
apply_fn = jax.jit(model.apply)


def apply(frag, seed, beta=1.0):
    """Run the jitted model on a single fragment and return the unpadded result.

    The fragment is padded to a fixed size (32 nodes, 1024 edges, 2 graphs)
    before the call, so every invocation presents the same shapes to the
    jitted function, and the padding is stripped from the output afterwards.

    Args:
        frag: graph accepted by `jraph.pad_with_graphs`; it must have fewer
            than 32 nodes so that at least one padding node remains.
        seed: forwarded as the second argument of `model.apply`. The caller
            in this notebook passes a `jax.random` key.
        beta: forwarded as the last argument of `model.apply`.
            NOTE(review): the default of 1.0 is assumed to be the neutral
            (unscaled) setting; it exists so that two-argument calls
            `apply(frag, key)` keep working. Confirm against the model.

    Returns:
        The model output with padding removed by `jraph.unpad_with_graphs`.
    """
    frags = jraph.pad_with_graphs(frag, n_node=32, n_edge=1024, n_graph=2)
    preds = apply_fn(params, seed, frags, beta)
    return jraph.unpad_with_graphs(preds)

In [25]:
def append_pred_to_ase_atoms(molecule: ase.Atoms, pred: datatypes.Predictions) -> ase.Atoms:
    """Return a new `ase.Atoms` equal to `molecule` plus the predicted atom.

    The new atom is placed at the position of the focus atom plus the
    predicted relative position vector, and its atomic number is looked up in
    `ATOMIC_NUMBERS` using the predicted target-species index. `molecule`
    itself is not modified.

    Args:
        molecule: the current fragment.
        pred: prediction whose `globals` fields each carry a leading axis of
            size 1 (it is squeezed away here).

    Returns:
        A fresh `ase.Atoms` with one more atom than `molecule`.
    """
    predicted = pred.globals
    focus_index = predicted.focus_indices.squeeze(0)
    species_index = predicted.target_species.squeeze(0).item()

    # Absolute position = focus atom position + predicted offset from it.
    added_position = molecule.positions[focus_index] + predicted.position_vectors.squeeze(0)
    added_number = jnp.array(ATOMIC_NUMBERS[species_index])

    # Give the new entries a leading axis so they stack under the existing rows.
    positions = jnp.concatenate([molecule.positions, added_position[None, :]], axis=0)
    numbers = jnp.concatenate([molecule.numbers, added_number[None]], axis=0)
    return ase.Atoms(numbers=numbers, positions=positions)


In [26]:
# Grow one molecule per seed, atom by atom, starting from a single carbon at
# the origin. Only molecules that end with fewer than 32 atoms are kept.
molecules = []

for seed in tqdm.tqdm(range(64)):
    # Initial fragment: one carbon atom (atomic number 6) at the origin.
    molecule = ase.Atoms(
        numbers=jnp.array([6]),
        positions=jnp.array([[0, 0, 0.0]]),
    )
    rng = jax.random.PRNGKey(seed)

    # At most 31 additions, so a molecule never exceeds 32 atoms.
    for step in range(31):
        # Fresh sub-key for this step; `rng` carries on to the next one.
        k, rng = jax.random.split(rng)
        frag = input_pipeline.ase_atoms_to_jraph_graph(
            molecule, ATOMIC_NUMBERS, config.nn_cutoff
        )
        pred = apply(frag, k)

        # The model signals that the molecule is complete.
        if pred.globals.stop.squeeze(0).item():
            break

        molecule = append_pred_to_ase_atoms(molecule, pred)

    # Reaching 32 atoms means the step budget ran out without a stop signal.
    if len(molecule.numbers) < 32:
        molecules.append(molecule)


100%|██████████| 64/64 [00:52<00:00,  1.22it/s]


In [27]:
# Write every kept molecule to gen/<name>/molecule_<i>.xyz.
# makedirs(exist_ok=True) also creates the parent "gen" directory when it is
# missing and avoids the check-then-create race of exists() + mkdir().
output_dir = f"gen/{name}"
os.makedirs(output_dir, exist_ok=True)

# Note: the index is the position in `molecules`, which can differ from the
# PRNG seed used for generation because some molecules were filtered out.
for seed, molecule in enumerate(molecules):
    ase.io.write(f"{output_dir}/molecule_{seed}.xyz", molecule)