In [1]:
from ase.atoms import Atoms
from ase.io import write
from ase.visualize import view
import e3nn_jax as e3nn
import jax
import jax.numpy as jnp
import jraph
import plotly.graph_objects as go
import sys

sys.path.append('..')
import analysis
import datatypes
import models
import train

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [3]:
# Load the run artefacts through analysis.load_from_workdir; the call returns
# five values which are unpacked into notebook-level names used by later cells.
(
    config,
    best_state_train,
    best_state_eval,
    metrics_for_best_state,
    datasets,
) = analysis.load_from_workdir(
    '../workdirs/mace/interactions=2/l=3/channels=32',
    load_pickled_params=False,
)

Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089


Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089


In [4]:
# Show the metrics returned by analysis.load_from_workdir alongside the best state.
metrics_for_best_state

{'val': {'total_loss': array(1.5165603, dtype=float32),
  'focus_loss': array(0.4437537, dtype=float32),
  'atom_type_loss': array(0.70975226, dtype=float32),
  'position_loss': array(0.36305448, dtype=float32)},
 'test': {'total_loss': array(1.2162731, dtype=float32),
  'focus_loss': array(0.4060237, dtype=float32),
  'atom_type_loss': array(0.63641983, dtype=float32),
  'position_loss': array(0.17382975, dtype=float32)}}

In [5]:
# Notebook-level constants. `cutoff` and `epsilon` are not referenced in the
# cells visible in this part of the notebook.
cutoff = 5.0  # presumably a neighbour-distance cutoff — TODO confirm units and usage
rng = jax.random.PRNGKey(0)  # fixed PRNG key (seed 0); handed to train.get_predictions
epsilon = 1e-4  # NOTE(review): purpose not visible here — confirm or remove

In [6]:
# Pull the first element of the test split and wrap it as a Fragment.
example_graph = next(datasets["test"].as_numpy_iterator())
frag = datatypes.Fragment.from_graphstuple(example_graph)
# Convert every leaf of the Fragment with jnp.asarray.
# `jax.tree_map` is deprecated and no longer exists in recent JAX releases;
# `jax.tree_util.tree_map` is the stable spelling and works on old versions too.
frag = jax.tree_util.tree_map(jnp.asarray, frag)

# Strip the padding graphs/nodes, then split the batch into individual graphs.
frag_unpadded = jraph.unpad_with_graphs(frag)
molecules = jraph.unbatch(frag_unpadded)

In [7]:
# Work with the first unbatched graph; keep plain-Python copies of its node
# species and positions for the viewer code further down.
mol = molecules[0]
species_list = mol.nodes.species.tolist()
positions_list = mol.nodes.positions.tolist()

In [8]:
# Run the best eval state on the selected molecule.
preds = train.get_predictions(best_state_eval, mol, rng)
# First entry of the predicted focus indices (index [0] of the returned list).
focus_index = preds.focus_indices.tolist()[0]
# NOTE(review): hard-coded here; presumably the ground-truth focus is always node 0 — confirm.
true_focus_index = 0
# Append a logit of 0 for the "stop" option before the softmax, so focus_probs
# has one more entry than preds.focus_logits (the last one is the stop probability).
focus_probs = jax.nn.softmax(jnp.concatenate([preds.focus_logits, jnp.array([0])]))

# First entry of the predicted target species and of the predicted position vectors.
pred_species = preds.target_species.tolist()[0]
pred_position = preds.position_vectors.tolist()[0]

In [None]:
# List the public attributes of the predictions object.
# (The original cell was the unfinished expression `preds.`, which is a
# SyntaxError and stops "Restart & Run All".)
[name for name in dir(preds) if not name.startswith('_')]

In [35]:
# Lookup tables indexed by species index (0..4 -> H, C, N, O, F).
atomic_numbers = jnp.array([1, 6, 7, 8, 9])
numbers_to_symbols = {1: 'H', 6: 'C', 7: 'N', 8: 'O', 9: 'F'}
# Relies on dict insertion order matching the order of `atomic_numbers`.
elements = list(numbers_to_symbols.values())

# covalent bond radii, in angstroms (same species-index order as above)
element_radii = [0.32, 0.75, 0.71, 0.63, 0.64]

def get_numbers(species: jnp.ndarray):
    """Map species indices to atomic numbers.

    Args:
        species: integer species indices into `atomic_numbers`; any 1-D
            array-like is accepted (a jax/numpy array or a plain list of ints).

    Returns:
        An integer jax array with the same length as `species`, holding
        `atomic_numbers[i]` for each index `i`. Empty input yields an empty
        integer array.
    """
    # One vectorised gather instead of indexing the jax array element by
    # element in a Python loop. The explicit integer dtype keeps an empty list
    # from being interpreted as a float array, which cannot be used as an index.
    indices = jnp.asarray(species, dtype=jnp.int32)
    return atomic_numbers[indices]

In [47]:
# Build an ASE Atoms object for the molecule and open it in the NGL viewer.
mol_atoms = Atoms(positions=positions_list, numbers=get_numbers(species_list))
v = view(mol_atoms, viewer='ngl')

num_nodes = mol.n_node[0].tolist()

# Convert once instead of on every loop iteration.
focus_probs_list = focus_probs.tolist()

for i in range(num_nodes):
    focus_prob = focus_probs_list[i]
    species = species_list[i]

    # add focus probability highlights for each atom
    v.view.shape.add_sphere(
        positions_list[i],
        [1, 0.85, 0],
        element_radii[species] * 0.6,
        f"atom {i} ({elements[species]}): focus probability {focus_prob:.3f}",
    )
    # The molecule is component 0, so the i-th added shape is component i + 1;
    # its opacity encodes the focus probability.
    v.view.update_representation(component=i+1, opacity=focus_prob)

# add true focus highlight
# Uses `true_focus_index` rather than the loop variable `i`, which after the
# loop holds the index of the last atom and produced a wrong label.
true_focus_species = species_list[true_focus_index]
v.view.shape.add_sphere(
    positions_list[true_focus_index],
    [0, 1, 0],
    element_radii[true_focus_species] * 0.6,
    f"atom {true_focus_index} ({elements[true_focus_species]}): true focus (probability {focus_probs_list[true_focus_index]:.3f})",
)
# This sphere is the shape added right after the num_nodes per-atom spheres.
v.view.update_representation(component=num_nodes+1, opacity=0.4)

# add the next atom we're adding to this molecule, predicted species + highlight
v.view.shape.add_sphere(
    pred_position,
    [1, 0, 1],
    element_radii[pred_species] * 0.5,
    f"predicted atom: {elements[pred_species]}",
)

# add an arrow from the predicted focus atom to the predicted new atom position
pred_focus_position = positions_list[focus_index]
focus_to_pred_distance = float(
    jnp.linalg.norm(jnp.array(pred_position) - jnp.array(pred_focus_position))
)
v.view.shape.add_arrow(
    pred_focus_position,
    pred_position,
    [1, 0.85, 0],
    0.1,
    f'distance: {focus_to_pred_distance:.3f} A'
)

v

HBox(children=(NGLWidget(), VBox(children=(Dropdown(description='Show', options=('All', 'C', 'H'), value='All'…

In [None]:
# Render the current NGL scene and save it as a PNG.
# `view` is the ase.visualize.view function and has no `render_image`; the
# widget created above is `v.view`.
# NOTE(review): nglview renders asynchronously, so `img.value` may still be
# empty if the write runs immediately after the request — if the file comes
# out empty, run the `with open(...)` part in a separate, later cell.
img = v.view.render_image()
with open("img.png", "wb") as f:
    f.write(img.value)