In [13]:
import os
import pickle
import sys

import ase
import ase.data
import ase.io
import ase.visualize
import jax
import jax.numpy as jnp
import jraph
import ml_collections
import numpy as np
import tqdm
import yaml

from typing import List

# Pin the notebook to one GPU. CUDA only honours this if it is set before JAX
# initializes its GPU backend (on first use), so keep it in this first cell and
# run it on a fresh kernel. A value exported in the environment takes precedence.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "4")

# Make the project's modules (datatypes, input_pipeline, models, ...) importable.
# Guarded so re-running this cell does not keep appending to sys.path.
if ".." not in sys.path:
    sys.path.append("..")

In [6]:
import datatypes
import input_pipeline
import qm9
from analyses import analysis
from models import ATOMIC_NUMBERS, RADII, create_model

  from .autonotebook import tqdm as notebook_tqdm
  jax.tree_util.register_keypaths(data_clz, keypaths)


In [7]:
# Working directory of the trained model to sample from. Export MODEL_WORKDIR to
# point at a different run without editing the notebook. Later cells derive the
# run name from the last five "/"-separated components, so use a POSIX-style
# path without a trailing slash.
path = os.environ.get(
    "MODEL_WORKDIR",
    "/Users/songk/atomicarchitects/spherical_harmonic_net/workdirs/v3/mace/interactions=4/l=5/channels=32",
)

# NOTE(security): pickle.load and yaml.unsafe_load can execute arbitrary code;
# only load checkpoints you produced yourself. unsafe_load is required because
# the saved config contains Python-specific tags (e.g. !!python/tuple).
params_path = os.path.join(path, "checkpoints", "params.pkl")
config_path = os.path.join(path, "config.yml")

with open(params_path, "rb") as f:
    params = pickle.load(f)
with open(config_path, "rt") as config_file:
    config = yaml.unsafe_load(config_file)

assert config is not None, f"config file is empty: {config_path}"
config = ml_collections.ConfigDict(config)

In [8]:
# Override the target-position predictor's resolution before the model is built
# (create_model reads `config` in a later cell). Presumably these are the
# alpha/beta resolutions of its spherical grid -- TODO confirm in models.
position_cfg = config.target_position_predictor
position_cfg.res_alpha = 359
position_cfg.res_beta = 180
del position_cfg

config

activation: softplus
avg_num_neighbors: 15.0
checkpoint_every_steps: 1000
eval_every_steps: 1000
focus_predictor:
  latent_size: 128
  num_layers: 2
learning_rate: 0.001
log_every_steps: 1000
loss_kwargs:
  radius_rbf_variance: 0.001
max_ell: 5
max_n_edges: 1024
max_n_graphs: 64
max_n_nodes: 512
model: MACE
nn_cutoff: 5.0
nn_tolerance: 0.5
num_basis_fns: 8
num_channels: 32
num_eval_steps: 100
num_eval_steps_at_end_of_training: 5000
num_interactions: 4
num_species: 5
num_train_steps: 20000
optimizer: adam
r_max: 5
rng_seed: 0
root_dir: /Users/songk/atomicarchitects/spherical_harmonic_net/qm9_data_tf/data_tf2
target_position_predictor:
  res_alpha: 359
  res_beta: 180
target_species_predictor:
  latent_size: 128
  num_layers: 2
test_molecules: !!python/tuple
- 53568
- 133920
train_molecules: !!python/tuple
- 0
- 47616
val_molecules: !!python/tuple
- 47616
- 53568

In [9]:
model = create_model(config, run_in_evaluation_mode=True)
# Compiled once. `apply` pads every fragment to the same size, so the input
# shapes never change and JAX does not retrace on each generation step.
apply_fn = jax.jit(model.apply)


def apply(frag, seed, beta, n_node=32, n_edge=1024):
    """Run the model on one fragment and return its unpadded predictions.

    Args:
        frag: fragment graph, as built by
            `input_pipeline.ase_atoms_to_jraph_graph`.
        seed: JAX PRNG key forwarded to the model (the generation loop passes
            a key obtained from `jax.random.split`, not an integer seed).
        beta: forwarded unchanged to the model's `beta` argument.
        n_node: total node budget after padding. It must be strictly greater
            than the fragment's atom count, because jraph needs at least one
            node for the padding graph. The default 32 fits fragments of up to
            31 atoms, which is what a 31-step generation loop produces.
        n_edge: total edge budget after padding; must be at least the
            fragment's edge count.

    Returns:
        The model predictions with the padding graph removed.
    """
    # n_graph=2: the real fragment plus the single padding graph jraph adds.
    padded = jraph.pad_with_graphs(frag, n_node, n_edge, 2)
    preds = apply_fn(params, seed, padded, beta)
    return jraph.unpad_with_graphs(preds)

In [10]:
def append_pred_to_ase_atoms(molecule: ase.Atoms, pred: datatypes.Predictions) -> ase.Atoms:
    """Return a new molecule with the atom described by `pred` appended.

    The new atom sits at the focus atom's position plus the predicted relative
    position vector. Its atomic number is looked up in ATOMIC_NUMBERS using the
    predicted species index. `molecule` itself is not modified.

    Args:
        molecule: the current fragment.
        pred: unpadded predictions for a single graph; its global fields carry
            a leading axis of size 1, which is removed with `squeeze(0)`.

    Returns:
        A new ase.Atoms with one more atom than `molecule`.
    """
    # Work in host-side NumPy. ase keeps float64 arrays, whereas jnp would
    # round coordinates to float32 (JAX's default unless jax_enable_x64 is set)
    # and shuttle data to the device and back on every generation step.
    focus = int(np.asarray(pred.globals.focus_indices).squeeze(0))
    pos_rel = np.asarray(pred.globals.position_vectors, dtype=np.float64).squeeze(0)
    species_index = int(np.asarray(pred.globals.target_species).squeeze(0))

    new_position = molecule.positions[focus] + pos_rel
    new_number = ATOMIC_NUMBERS[species_index]

    return ase.Atoms(
        positions=np.concatenate([molecule.positions, new_position[None, :]], axis=0),
        numbers=np.append(molecule.numbers, new_number),
    )


In [11]:
# Passed to `apply` as `beta` at every generation step and recorded in the run
# label (presumably an inverse temperature for sampling -- TODO confirm).
beta = 1

# Run label: the last five "/"-separated components of the checkpoint path,
# joined with "_", followed by the beta value. It names the output folder.
name = "_".join(path.split("/")[-5:]) + f"_beta={beta}"

print(name)

v3_mace_interactions=4_l=5_channels=32_beta=1


In [12]:
# Autoregressive sampling. Each sample starts from a single carbon atom at the
# origin and gains one atom per step until the model predicts "stop".
NUM_SAMPLES = 64
# A fragment passed to `apply` has at most MAX_STEPS atoms (step s sees s + 1
# atoms), so MAX_STEPS must stay below `apply`'s node padding (32 by default).
MAX_STEPS = 31

out_dir = os.path.join("gen", name)
# makedirs also creates the parent "gen/" directory, which os.mkdir would not.
os.makedirs(out_dir, exist_ok=True)

molecules = []
unfinished_seeds = []

for seed in tqdm.tqdm(range(NUM_SAMPLES)):
    molecule = ase.Atoms(
        positions=np.array([[0, 0, 0.0]]),
        numbers=np.array([6]),
    )

    rng = jax.random.PRNGKey(seed)
    for step in range(MAX_STEPS):
        k, rng = jax.random.split(rng)
        frag = input_pipeline.ase_atoms_to_jraph_graph(
            molecule, ATOMIC_NUMBERS, config.nn_cutoff
        )
        pred = apply(frag, k, beta)

        stop = pred.globals.stop.squeeze(0).item()
        if stop:
            molecules.append(molecule)
            ase.io.write(os.path.join(out_dir, f"molecule_{seed}.xyz"), molecule)
            # tqdm.write prints above the progress bar instead of splitting it.
            tqdm.tqdm.write(f"Generated molecule {seed} of {len(molecule)} atoms")
            break

        molecule = append_pred_to_ase_atoms(molecule, pred)
    else:
        # The model never predicted "stop" within MAX_STEPS. The sample is
        # discarded, but record it rather than dropping it silently.
        unfinished_seeds.append(seed)

print(
    f"{len(molecules)} / {NUM_SAMPLES} samples terminated; "
    f"discarded seeds: {unfinished_seeds}"
)


  2%|▏         | 1/64 [02:34<2:42:35, 154.86s/it]

Generated molecule 0 of 15 atoms


  3%|▎         | 2/64 [02:56<1:19:00, 76.45s/it] 

Generated molecule 1 of 14 atoms


  5%|▍         | 3/64 [03:10<48:40, 47.88s/it]  

Generated molecule 2 of 9 atoms


  6%|▋         | 4/64 [04:00<48:36, 48.60s/it]

Generated molecule 3 of 30 atoms


  8%|▊         | 5/64 [04:38<44:15, 45.00s/it]

Generated molecule 4 of 24 atoms


  9%|▉         | 6/64 [04:52<33:22, 34.53s/it]

Generated molecule 5 of 9 atoms


 11%|█         | 7/64 [05:11<27:56, 29.41s/it]

Generated molecule 6 of 12 atoms


 12%|█▎        | 8/64 [05:31<24:34, 26.33s/it]

Generated molecule 7 of 12 atoms


 14%|█▍        | 9/64 [05:49<21:39, 23.62s/it]

Generated molecule 8 of 11 atoms


 16%|█▌        | 10/64 [06:28<25:36, 28.46s/it]

Generated molecule 9 of 25 atoms


 19%|█▉        | 12/64 [07:52<29:57, 34.58s/it]

Generated molecule 11 of 21 atoms


 20%|██        | 13/64 [08:06<24:05, 28.35s/it]

Generated molecule 12 of 9 atoms


 22%|██▏       | 14/64 [08:32<23:06, 27.74s/it]

Generated molecule 13 of 17 atoms


 27%|██▋       | 17/64 [10:27<25:14, 32.22s/it]

Generated molecule 16 of 12 atoms


 30%|██▉       | 19/64 [11:21<20:45, 27.67s/it]

Generated molecule 18 of 4 atoms


 31%|███▏      | 20/64 [11:27<15:34, 21.23s/it]

Generated molecule 19 of 4 atoms


 33%|███▎      | 21/64 [11:45<14:38, 20.44s/it]

Generated molecule 20 of 12 atoms


 34%|███▍      | 22/64 [12:08<14:52, 21.25s/it]

Generated molecule 21 of 15 atoms


 36%|███▌      | 23/64 [12:25<13:38, 19.97s/it]

Generated molecule 22 of 11 atoms


 39%|███▉      | 25/64 [13:30<16:12, 24.94s/it]

Generated molecule 24 of 11 atoms


 41%|████      | 26/64 [13:43<13:24, 21.16s/it]

Generated molecule 25 of 8 atoms


 42%|████▏     | 27/64 [14:03<12:51, 20.85s/it]

Generated molecule 26 of 13 atoms


 44%|████▍     | 28/64 [14:20<11:53, 19.83s/it]

Generated molecule 27 of 11 atoms


 45%|████▌     | 29/64 [14:42<11:54, 20.42s/it]

Generated molecule 28 of 14 atoms


 48%|████▊     | 31/64 [15:39<12:22, 22.49s/it]

Generated molecule 30 of 4 atoms


 52%|█████▏    | 33/64 [16:59<15:47, 30.57s/it]

Generated molecule 32 of 20 atoms


 53%|█████▎    | 34/64 [17:27<14:53, 29.78s/it]

Generated molecule 33 of 18 atoms


 55%|█████▍    | 35/64 [17:40<12:05, 25.03s/it]

Generated molecule 34 of 9 atoms


 56%|█████▋    | 36/64 [18:22<14:01, 30.06s/it]

Generated molecule 35 of 27 atoms


 58%|█████▊    | 37/64 [18:58<14:17, 31.77s/it]

Generated molecule 36 of 23 atoms


 59%|█████▉    | 38/64 [19:22<12:46, 29.48s/it]

Generated molecule 37 of 15 atoms


 61%|██████    | 39/64 [19:38<10:34, 25.36s/it]

Generated molecule 38 of 10 atoms


 62%|██████▎   | 40/64 [19:55<09:10, 22.94s/it]

Generated molecule 39 of 11 atoms


 64%|██████▍   | 41/64 [20:19<08:50, 23.05s/it]

Generated molecule 40 of 15 atoms


 66%|██████▌   | 42/64 [20:48<09:10, 25.03s/it]

Generated molecule 41 of 19 atoms


 67%|██████▋   | 43/64 [21:15<08:55, 25.51s/it]

Generated molecule 42 of 17 atoms


 69%|██████▉   | 44/64 [21:26<07:03, 21.15s/it]

Generated molecule 43 of 7 atoms


 70%|███████   | 45/64 [21:45<06:28, 20.44s/it]

Generated molecule 44 of 12 atoms


 72%|███████▏  | 46/64 [22:25<07:54, 26.36s/it]

Generated molecule 45 of 26 atoms


 75%|███████▌  | 48/64 [23:25<07:06, 26.67s/it]

Generated molecule 47 of 8 atoms


 77%|███████▋  | 49/64 [24:43<10:29, 41.94s/it]

Generated molecule 48 of 21 atoms


 80%|███████▉  | 51/64 [37:37<41:12, 190.22s/it]

Generated molecule 50 of 13 atoms


 83%|████████▎ | 53/64 [41:27<26:41, 145.59s/it]

Generated molecule 52 of 29 atoms


 84%|████████▍ | 54/64 [41:56<18:23, 110.39s/it]

Generated molecule 53 of 18 atoms


 86%|████████▌ | 55/64 [42:18<12:36, 84.10s/it] 

Generated molecule 54 of 14 atoms


 88%|████████▊ | 56/64 [42:45<08:55, 66.97s/it]

Generated molecule 55 of 17 atoms


 89%|████████▉ | 57/64 [43:10<06:20, 54.39s/it]

Generated molecule 56 of 16 atoms


 92%|█████████▏| 59/64 [44:36<04:00, 48.05s/it]

Generated molecule 58 of 24 atoms


 94%|█████████▍| 60/64 [44:57<02:38, 39.72s/it]

Generated molecule 59 of 13 atoms


 95%|█████████▌| 61/64 [45:25<01:49, 36.36s/it]

Generated molecule 60 of 18 atoms


 97%|█████████▋| 62/64 [45:58<01:10, 35.37s/it]

Generated molecule 61 of 21 atoms


 98%|█████████▊| 63/64 [46:18<00:30, 30.82s/it]

Generated molecule 62 of 13 atoms


100%|██████████| 64/64 [46:47<00:00, 43.87s/it]

Generated molecule 63 of 18 atoms





In [14]:
def ase_to_mol_dict(molecules: List["ase.Atoms"], save=True, model_path=None, file_name=None):
    """Group molecules by atom count into G-SchNet's ``.mol_dict`` layout.

    Adapted from G-SchNet: https://github.com/atomistic-machine-learning/G-SchNet

    Args:
        molecules: ase.Atoms objects. Any object with ``get_positions()`` and
            ``get_atomic_numbers()`` works.
        save, model_path, file_name: accepted for signature compatibility with
            the G-SchNet original but currently ignored. Nothing is written to
            disk here; the caller pickles the returned dict itself.

    Returns:
        Dict mapping an atom count ``n`` to
        ``{"_positions": array (k, n, 3), "_atomic_numbers": array (k, n)}``,
        where ``k`` is the number of molecules with ``n`` atoms. Keys appear in
        order of first occurrence and rows follow the input order.
    """
    positions_by_size = {}
    numbers_by_size = {}
    for mol in molecules:
        numbers = mol.get_atomic_numbers()
        n_atoms = numbers.shape[0]
        # setdefault keeps first-occurrence key order.
        positions_by_size.setdefault(n_atoms, []).append(mol.get_positions())
        numbers_by_size.setdefault(n_atoms, []).append(numbers)

    # Stack each group once. Calling np.append per molecule re-copies the whole
    # group each time, which is quadratic in group size.
    return {
        n_atoms: {
            "_positions": np.stack(positions_by_size[n_atoms]),
            "_atomic_numbers": np.stack(numbers_by_size[n_atoms]),
        }
        for n_atoms in positions_by_size
    }

In [15]:
# Show the run label; it names the gen/<name>/ output directory used above and below.
name

'v3_mace_interactions=4_l=5_channels=32_beta=1'

In [16]:
# Persist all terminated samples (a list of ase.Atoms) next to their .xyz files.
with open(f'gen/{name}/molecules.pkl', 'wb') as handle:
    handle.write(pickle.dumps(molecules))

In [17]:
# Regroup the samples by atom count into the G-SchNet-style mol_dict layout.
mol_dict = ase_to_mol_dict(molecules)

In [20]:
# Export the grouped samples in the .mol_dict format built by ase_to_mol_dict.
# NOTE: this path does not include the run name, so each export overwrites the
# previous one; copy it aside if several runs must be kept.
mol_dict_path = "../workdirs/generated/generated_molecules.mol_dict"
# open() raises FileNotFoundError if the directory is missing; create it first.
os.makedirs(os.path.dirname(mol_dict_path), exist_ok=True)
with open(mol_dict_path, "wb") as f:
    pickle.dump(mol_dict, f)