In [1]:
import os
import pickle
import sys

import ase
import ase.data
import ase.io
import ase.visualize
import jax
import jax.numpy as jnp
import jraph
import ml_collections
import tqdm
import yaml

os.environ["CUDA_VISIBLE_DEVICES"] = "4"

sys.path.append('..')

In [2]:
import datatypes
import input_pipeline
import qm9
from models import ATOMIC_NUMBERS, RADII, create_model

In [12]:
# path = "/home/ameyad/spherical-harmonic-net/workdirs/v3/mace/interactions=3/l=2/channels=32"
path = "/home/ameyad/spherical-harmonic-net/workdirs/v3/e3schnet/interactions=4/l=3/channels=32"
path = "/home/ameyad/spherical-harmonic-net/workdirs/v3/mace/interactions=4/l=3/channels=32"
path = "/home/ameyad/spherical-harmonic-net/workdirs/v3/mace/interactions=4/l=5/channels=32"
path = "/home/ameyad/spherical-harmonic-net/workdirs/v4/mace/interactions=6/l=4/channels=32"
path = "/home/ameyad/spherical-harmonic-net/workdirs/v4/mace/interactions=2/l=4/channels=32"
path = "/home/ameyad/spherical-harmonic-net/workdirs/v4/mace/interactions=5/l=5/channels=32"
path = "/home/ameyad/spherical-harmonic-net/workdirs/extras/num_layers/mace/interactions=4/l=5/channels=32/num_layers=4"
path = "/home/ameyad/spherical-harmonic-net/workdirs/v4/mace/interactions=4/l=4/channels=32"

with open(path + "/checkpoints/params.pkl", 'rb') as f:
    params = pickle.load(f)
with open(path + "/config.yml", "rt") as config_file:
    config = yaml.unsafe_load(config_file)

assert config is not None
config = ml_collections.ConfigDict(config)

FileNotFoundError: [Errno 2] No such file or directory: '/home/ameyad/spherical-harmonic-net/workdirs/v3/mace/interactions=4/l=3/channels=32/checkpoints/params.pkl'

In [13]:
config.target_position_predictor.res_alpha = 359
config.target_position_predictor.res_beta = 180

config

activation: softplus
avg_num_neighbors: 15.0
checkpoint_every_steps: 1000
eval_every_steps: 1000
focus_predictor:
  latent_size: 128
  num_layers: 2
learning_rate: 0.001
log_every_steps: 1000
loss_kwargs:
  radius_rbf_variance: 0.001
max_ell: 4
max_n_graphs: 32
model: MACE
nn_cutoff: 5.0
nn_tolerance: 0.5
num_basis_fns: 8
num_channels: 32
num_eval_steps: 100
num_eval_steps_at_end_of_training: 5000
num_interactions: 8
num_species: 5
num_train_steps: 20000
optimizer: adam
r_max: 5
rng_seed: 0
root_dir: /home/ameyad/qm9_data_tf/data_tf2
target_position_predictor:
  res_alpha: 359
  res_beta: 180
target_species_predictor:
  latent_size: 128
  num_layers: 2
test_molecules: !!python/tuple
- 53568
- 133920
train_molecules: !!python/tuple
- 0
- 47616
val_molecules: !!python/tuple
- 47616
- 53568

In [14]:
model = create_model(config, run_in_evaluation_mode=True)
apply_fn = jax.jit(model.apply)

def apply(frag, seed, beta):
    frags = jraph.pad_with_graphs(frag, 32, 1024, 2)
    preds = apply_fn(params, seed, frags, beta)
    pred = jraph.unpad_with_graphs(preds)
    return pred

In [15]:
def append_pred_to_ase_atoms(molecule: ase.Atoms, pred: datatypes.Predictions) -> ase.Atoms:
    focus = pred.globals.focus_indices.squeeze(0)
    pos_focus = molecule.positions[focus]
    pos_rel = pred.globals.position_vectors.squeeze(0)

    new_specie = jnp.array(
        ATOMIC_NUMBERS[pred.globals.target_species.squeeze(0).item()]
    )
    new_position = pos_focus + pos_rel

    return ase.Atoms(
        positions=jnp.concatenate([molecule.positions, new_position[None, :]], axis=0),
        numbers=jnp.concatenate([molecule.numbers, new_specie[None]], axis=0),
    )


In [16]:
name = path.split("/")[-5:]
name = "_".join(name)

beta = 1

name = f"{name}_beta={beta}"

print(name)

v4_mace_interactions=8_l=4_channels=32_beta=1


In [17]:
molecules = []

if not os.path.exists(f"gen/{name}"):
    os.mkdir(f"gen/{name}")

for seed in tqdm.tqdm(range(64)):
    molecule = ase.Atoms(
        positions=jnp.array([[0, 0, 0.0]]),
        numbers=jnp.array([6]),
    )

    rng = jax.random.PRNGKey(seed)
    for step in range(31):
        k, rng = jax.random.split(rng)
        frag = input_pipeline.ase_atoms_to_jraph_graph(
            molecule, ATOMIC_NUMBERS, config.nn_cutoff
        )
        pred = apply(frag, k, beta)

        stop = pred.globals.stop.squeeze(0).item()
        if stop:
            molecules.append(molecule)

            ase.io.write(f"gen/{name}/molecule_{seed}.xyz", molecule)

            print(f"Generated molecule {seed} of {len(molecule)} atoms")
            break

        molecule = append_pred_to_ase_atoms(molecule, pred)


  5%|▍         | 3/64 [00:41<09:36,  9.45s/it]

Generated molecule 2 of 5 atoms


  6%|▋         | 4/64 [00:41<05:51,  5.85s/it]

Generated molecule 3 of 6 atoms


  9%|▉         | 6/64 [00:44<03:06,  3.21s/it]

Generated molecule 5 of 13 atoms


 12%|█▎        | 8/64 [00:46<01:59,  2.14s/it]

Generated molecule 7 of 14 atoms


 16%|█▌        | 10/64 [00:49<01:29,  1.67s/it]

Generated molecule 9 of 13 atoms


 27%|██▋       | 17/64 [01:01<01:10,  1.49s/it]

Generated molecule 16 of 13 atoms


 39%|███▉      | 25/64 [01:14<00:53,  1.37s/it]

Generated molecule 24 of 8 atoms


 41%|████      | 26/64 [01:14<00:40,  1.07s/it]

Generated molecule 25 of 7 atoms


 45%|████▌     | 29/64 [01:19<00:51,  1.47s/it]

Generated molecule 28 of 27 atoms


 48%|████▊     | 31/64 [01:22<00:46,  1.41s/it]

Generated molecule 30 of 19 atoms


 53%|█████▎    | 34/64 [01:27<00:41,  1.39s/it]

Generated molecule 33 of 16 atoms


 58%|█████▊    | 37/64 [01:31<00:34,  1.29s/it]

Generated molecule 36 of 9 atoms


 62%|██████▎   | 40/64 [01:36<00:34,  1.46s/it]

Generated molecule 39 of 21 atoms


 66%|██████▌   | 42/64 [01:39<00:31,  1.44s/it]

Generated molecule 41 of 20 atoms


 70%|███████   | 45/64 [01:43<00:27,  1.44s/it]

Generated molecule 44 of 18 atoms


 89%|████████▉ | 57/64 [02:04<00:10,  1.45s/it]

Generated molecule 56 of 10 atoms


 91%|█████████ | 58/64 [02:05<00:07,  1.25s/it]

Generated molecule 57 of 14 atoms


 92%|█████████▏| 59/64 [02:06<00:05,  1.19s/it]

Generated molecule 58 of 18 atoms


 94%|█████████▍| 60/64 [02:07<00:05,  1.32s/it]

Generated molecule 59 of 28 atoms


100%|██████████| 64/64 [02:14<00:00,  2.10s/it]

Generated molecule 63 of 21 atoms





In [9]:
name

'v4_mace_interactions=2_l=4_channels=32_beta=10'