In [1]:
import os
import queue
import shutil

import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import jraph
import ase

import sys
sys.path.append('../')

In [2]:
%load_ext autoreload

In [3]:
%autoreload 2

from analyses import analysis
from symphony.data import input_pipeline_tf, input_pipeline
from symphony import datatypes
from symphony.models import utils

  jax.tree_util.register_keypaths(data_clz, keypaths)
  jax.tree_util.register_keypaths(data_clz, keypaths)
  jax.tree_util.register_keypaths(data_clz, keypaths)
  jax.tree_util.register_keypaths(data_clz, keypaths)


In [4]:
piece = 3
workdir = f"/Users/ameyad/Documents/spherical-harmonic-net/potato_workdirs/platonic_solids_by_piece_multifocus/e3schnet_and_nequip/piece={piece}/interactions=3/l=5/position_channels=2/channels=64/"
model, params, config = analysis.load_model_at_step(workdir, step="best", run_in_evaluation_mode=True)
apply_fn = jax.jit(model.apply)

In [5]:
def append_predictions(fragments: datatypes.Fragments, preds: datatypes.Predictions, merge_cutoff: float) -> datatypes.Fragments:
    """Appends the predictions to the fragments."""
    # Bring back to CPU.
    fragments = jax.tree_map(np.asarray, fragments)
    preds = jax.tree_map(np.asarray, preds)
    valids = jraph.get_graph_padding_mask(fragments)

    # Process each fragment.
    valid_count = 0
    for valid, fragment, pred in zip(valids, jraph.unbatch(fragments), jraph.unbatch(preds)):
        if not valid:
            continue
        valid_count += 1
        # print("Valid count:", valid_count)
        yield append_predictions_to_fragment(fragment, pred, merge_cutoff)


def append_predictions_to_fragment(fragment: datatypes.Fragments, pred: datatypes.Predictions, merge_cutoff: float) -> datatypes.Fragments:
    """Appends the predictions to a single fragment."""
    focus_mask = pred.nodes.focus_mask
    target_relative_positions = pred.nodes.position_vectors[focus_mask]
    extra_positions = target_relative_positions + fragment.nodes.positions[focus_mask]
    extra_species = pred.nodes.target_species[focus_mask]
    print("focus probs", pred.nodes.focus_probs)
    print("focus_mask:", focus_mask)
    stop = pred.globals.stop

    # Filter out positions too close to each other.
    new_positions = fragment.nodes.positions
    new_species = fragment.nodes.species
    for extra_position, extra_specie in zip(extra_positions, extra_species):
        if np.min(np.linalg.norm(new_positions - extra_position, axis=1)) < merge_cutoff:
            continue

        new_positions = np.concatenate([new_positions, [extra_position]], axis=0)
        new_species = np.concatenate([new_species, [extra_specie]], axis=0)

    atomic_numbers = np.asarray([1, 6, 7, 8, 9])
    return fragment, input_pipeline.ase_atoms_to_jraph_graph(
        atoms=ase.Atoms(
            numbers=atomic_numbers[new_species],
            positions=new_positions
        ),
        atomic_numbers=atomic_numbers,
        nn_cutoff=config.nn_cutoff
    ), stop

In [6]:
def create_initial_fragment():
    return input_pipeline.ase_atoms_to_jraph_graph(
        atoms=ase.Atoms(
            positions=np.zeros((1, 3)),
            numbers=np.asarray([1]),
        ),
        atomic_numbers=np.asarray([1, 6, 7, 8, 9]),
        nn_cutoff=config.nn_cutoff
    )

In [7]:
num_fragments = 10
completed_fragments = []
fragment_pool = queue.SimpleQueue()
for _ in range(2 * num_fragments):
    fragment_pool.put(create_initial_fragment())

rng = jax.random.PRNGKey(0)
n_graph_for_padding = 10

def make_queue_iterator(q):
    """Makes a non-blocking iterator from a queue."""
    while q.qsize() > 0:
        yield q.get(block=False)


while len(completed_fragments) < num_fragments and fragment_pool.qsize() > 0:
    n_graph_for_padding = 1 + min(n_graph_for_padding - 1, fragment_pool.qsize())
    fragments = next(jraph.dynamically_batch(make_queue_iterator(fragment_pool),
                                            n_node=n_graph_for_padding * 20,
                                            n_edge=n_graph_for_padding * 100,
                                            n_graph=n_graph_for_padding))


    step_rng, rng = jax.random.split(rng)
    preds = apply_fn(params, step_rng, fragments, focus_and_atom_type_inverse_temperature=1.0, position_inverse_temperature=10.0)
    for fragment, new_fragment, stop in append_predictions(fragments, preds, merge_cutoff=0.5):
        if stop or new_fragment.nodes.species.shape[0] > 20:
            completed_fragments.append(new_fragment)
        else:
            print(fragment.nodes.species.shape[0], new_fragment.nodes.species.shape[0], stop)
            fragment_pool.put(new_fragment)

print(completed_fragments)

focus probs [0.99999213]
focus_mask: [ True]
1 2 [False]
focus probs [0.99999213]
focus_mask: [ True]
1 2 [False]
focus probs [0.99999213]
focus_mask: [ True]
1 2 [False]
focus probs [0.99999213]
focus_mask: [ True]
1 2 [False]
focus probs [0.99999213]
focus_mask: [ True]
1 2 [False]
focus probs [0.99999213]
focus_mask: [ True]
1 2 [False]
focus probs [0.99999213]
focus_mask: [ True]
1 2 [False]
focus probs [0.99999213]
focus_mask: [ True]
1 2 [False]
focus probs [0.99999213]
focus_mask: [ True]
1 2 [False]
focus probs [0.99999213]
focus_mask: [ True]
1 2 [False]
focus probs [0.99999213]
focus_mask: [ True]
1 2 [False]
focus probs [0.99999213]
focus_mask: [ True]
1 2 [False]
focus probs [0.99999213]
focus_mask: [ True]
1 2 [False]
focus probs [0.99999213]
focus_mask: [ True]
1 2 [False]
focus probs [0.99999213]
focus_mask: [ True]
1 2 [False]
focus probs [0.99999213]
focus_mask: [ True]
1 2 [False]
focus probs [0.99999213]
focus_mask: [ True]
1 2 [False]
focus probs [0.99999213]
focus_

In [8]:
output_dir = f"ps_debug/piece={piece}"
if os.path.exists(output_dir):
    shutil.rmtree(output_dir)
os.makedirs(output_dir, exist_ok=True)
for index, completed_fragment in enumerate(completed_fragments):
    completed_fragment_ase = ase.Atoms(
        symbols=utils.get_atomic_numbers(completed_fragment.nodes.species),
        positions=completed_fragment.nodes.positions
    )
    completed_fragment_ase.write(f"{output_dir}/fragment_{index}.xyz", append=True)