## 4. Symphony on QM9

In this section, we load a Symphony model pre-trained on the QM9 dataset and explore its predictions on fragments of a QM9 molecule.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
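import os
# Select the GPU that JAX will use. This must be set before JAX is imported; change "2" to match your machine.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"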

In [None]:
# Imports
from typing import List
import jax
import jax.numpy as jnp
import jraph
import pickle
import numpy as np
import plotly.graph_objects as go

import sys
sys.path.append("../")

import tutorial.visualizer as visualizer
import tutorial.tutorial_utils as tutorial_utils

In [None]:
model, params, config = tutorial_utils.load_model_at_step(
    workdir="./workdir",
    step="best",
    run_in_evaluation_mode=True,
)

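Next, we load some fragments of a QM9 molecule. We extracted these fragments ahead of time with the code below, which caches them to `./data/qm9_fragments_list.pkl` and regenerates them if that file is missing: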

In [None]:
saved_fragments_path = "./data/qm9_fragments_list.pkl"

if os.path.exists(saved_fragments_path):
    with open(saved_fragments_path, "rb") as f:
        molecule_fragments = pickle.load(f)

else:
    from symphony.data import qm9
    from symphony.data import fragments

    # Load the QM9 dataset.
    molecules = qm9.load_qm9("./data/qm9", use_edm_splits=True, check_molecule_sanity=False)

    # We pick the first molecule in the dataset.
    molecule = molecules[0]
    molecule_graph = tutorial_utils.ase_atoms_to_jraph_graph(
        atoms=molecule,  # A single molecule, not the whole dataset.
        atomic_numbers=np.asarray([1, 6, 7, 8, 9]),  # H, C, N, O, F
        nn_cutoff=config.nn_cutoff,
    )
    molecule_fragments = fragments.generate_fragments(
        jax.random.PRNGKey(0),
        molecule_graph,
        n_species=5,
        nn_tolerance=config.nn_tolerance,
        mode="nn",
    )
    molecule_fragments = list(molecule_fragments)

    with open(saved_fragments_path, "wb") as f:
        pickle.dump(molecule_fragments, f)
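
Conceptually, each cached fragment can be thought of as a prefix of an atom ordering: it holds the atoms placed so far, and the next atom in the ordering is what the model should predict. The sketch below is a simplified, deterministic illustration of that idea, not Symphony's `fragments.generate_fragments`. It assumes the molecular graph connects atoms closer than a cutoff (our reading of `nn_cutoff`) and that fragments grow by adding whichever remaining atom is nearest to the current fragment (our reading of `mode="nn"`). The real function also takes an RNG key and `nn_tolerance`, so its exact rule may differ. The helpers `radius_graph`, `nn_fragment_order`, and `fragments_from_order` are hypothetical, and the toy coordinates are made up.

```python
import numpy as np


def pairwise_distances(positions: np.ndarray) -> np.ndarray:
    """Euclidean distance between every pair of atoms, shape [n_atoms, n_atoms]."""
    diffs = positions[:, None, :] - positions[None, :, :]
    return np.linalg.norm(diffs, axis=-1)


def radius_graph(positions: np.ndarray, cutoff: float):
    """Hypothetical helper: directed edges (senders, receivers) between distinct atoms closer than `cutoff`."""
    dists = pairwise_distances(positions)
    mask = (dists < cutoff) & ~np.eye(len(positions), dtype=bool)
    senders, receivers = np.nonzero(mask)
    return senders, receivers


def nn_fragment_order(positions: np.ndarray, start: int = 0) -> list:
    """Hypothetical helper: greedily add the remaining atom nearest to any atom already placed."""
    dists = pairwise_distances(positions)
    placed = [start]
    remaining = sorted(set(range(len(positions))) - {start})
    while remaining:
        # Distance from each remaining atom to its closest already-placed atom.
        nearest = dists[np.ix_(placed, remaining)].min(axis=0)
        next_atom = remaining.pop(int(np.argmin(nearest)))
        placed.append(next_atom)
    return placed


def fragments_from_order(order: list) -> list:
    """Each fragment is a prefix of `order`; its target is the next atom, or None (stop) once complete."""
    return [
        (order[:k], order[k] if k < len(order) else None)
        for k in range(1, len(order) + 1)
    ]


# Toy 4-atom "molecule" (coordinates in angstroms, made up for illustration).
positions = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.5, 0.0, 0.0], [0.0, 1.2, 0.0]])
senders, receivers = radius_graph(positions, cutoff=1.3)
order = nn_fragment_order(positions)  # [0, 1, 3, 2]
fragment_list = fragments_from_order(order)
```

Treating the final, complete fragment as having a "stop" target mirrors how the sequence above ends with the full molecule; whether Symphony encodes the stop signal this way is an assumption.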

We start with a fragment containing a single atom.

In [None]:
visualizer.visualize_fragment(molecule_fragments[0])

... which grows, one atom at a time, into larger fragments:

In [None]:
visualizer.visualize_fragment(molecule_fragments[1])

In [None]:
visualizer.visualize_fragment(molecule_fragments[5])

... and eventually becomes the full molecule!

In [None]:
visualizer.visualize_fragment(molecule_fragments[-1])

We can now query Symphony for the next atom, given each fragment.
Applying this step repeatedly would generate whole molecules, but here we only predict a single atom for each fragment.
Note that the model was trained on many different orderings of atoms, so its predictions need not match the next atom in the fragment sequence shown here!
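
Repeating this query is what turns single-atom predictions into whole molecules. A minimal sketch of such a loop is below. It assumes a hypothetical callback `predict_next_atom(fragment, rng)` that returns either the next atom as an `(atomic_number, position)` pair or `None` when the model signals that the molecule is complete. In Symphony, this step would wrap `model.apply` (roughly: choose a focus atom, an element, and a position), which we do not reproduce here; the `toy_predictor` below only exercises the loop.

```python
from typing import Callable, List, Optional, Tuple

import numpy as np

# (atomic number, xyz position in angstroms)
Atom = Tuple[int, np.ndarray]
# Hypothetical signature: returns the next atom, or None to stop.
Predictor = Callable[[List[Atom], np.random.Generator], Optional[Atom]]


def generate_molecule(
    predict_next_atom: Predictor,
    seed_atom: Atom,
    max_atoms: int = 30,
    seed: int = 0,
) -> List[Atom]:
    """Grow a molecule one atom at a time until the predictor stops or `max_atoms` is reached."""
    rng = np.random.default_rng(seed)
    fragment: List[Atom] = [seed_atom]
    while len(fragment) < max_atoms:
        next_atom = predict_next_atom(fragment, rng)
        if next_atom is None:  # The predictor signalled that the molecule is complete.
            break
        fragment.append(next_atom)
    return fragment


def toy_predictor(fragment: List[Atom], rng: np.random.Generator) -> Optional[Atom]:
    """Stand-in for a real model: place carbons 1.5 angstroms apart along x, then stop at 4 atoms."""
    if len(fragment) >= 4:
        return None
    _, last_position = fragment[-1]
    return (6, last_position + np.array([1.5, 0.0, 0.0]))


molecule_atoms = generate_molecule(toy_predictor, seed_atom=(6, np.zeros(3)))
```

Capping the loop with `max_atoms` guards against a model that never predicts the stop signal.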

In [None]:
apply_rng = jax.random.PRNGKey(0)
preds = jax.jit(model.apply)(
    params,
    apply_rng,
    jraph.batch(molecule_fragments),
    focus_and_atom_type_inverse_temperature=1.0,
    position_inverse_temperature=1.0
)

fixed_preds = []
for index, (fragment, pred) in enumerate(
    zip(molecule_fragments, jraph.unbatch(preds))
):
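    # After unbatching, each prediction's globals keep a leading axis of size 1; squeeze it out.
    pred = pred._replace(
        globals=jax.tree_util.tree_map(lambda x: np.squeeze(x, axis=0), pred.globals)
    )
    # The focus indices refer to nodes of the batched graph, so subtract the number
    # of nodes in all preceding graphs to get indices local to this fragment.
    corrected_focus_indices = (
        pred.globals.focus_indices - preds.n_node[:index].sum()
    )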
    pred = pred._replace(
        globals=pred.globals._replace(focus_indices=corrected_focus_indices)
    )

    fixed_preds.append(pred)
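
The focus-index correction above relies on how `jraph.batch` works: it concatenates the nodes of all graphs into one array, so an index into the batch equals a local index plus the number of nodes in all earlier graphs. The sketch below reproduces that arithmetic in NumPy with made-up node counts; `to_local_indices` is a hypothetical helper that vectorizes the per-graph `preds.n_node[:index].sum()` subtraction.

```python
import numpy as np


def to_local_indices(global_indices, n_node):
    """Convert indices into a batched (concatenated) node array back to indices within each graph."""
    n_node = np.asarray(n_node)
    # Offset of graph i = total number of nodes in graphs 0..i-1.
    offsets = np.concatenate([[0], np.cumsum(n_node)[:-1]])
    return np.asarray(global_indices) - offsets


n_node = np.array([1, 2, 4])  # nodes per graph in a toy batch
global_focus = np.array([0, 2, 5])  # one focus per graph, indexed into the concatenated nodes
local_focus = to_local_indices(global_focus, n_node)

# Same result as the loop in the notebook, which subtracts n_node[:index].sum().
loop_focus = [global_focus[i] - n_node[:i].sum() for i in range(len(n_node))]
```

The vectorized form computes all offsets with one cumulative sum, while the notebook's loop recomputes a prefix sum per fragment; both give the same indices.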

In [None]:
visualizer.visualize_predictions(fixed_preds[0], molecule_fragments[0], showlegend=True)

In [None]:
visualizer.visualize_predictions(fixed_preds[1], molecule_fragments[1], showlegend=True)

In [None]:
visualizer.visualize_predictions(fixed_preds[5], molecule_fragments[5], showlegend=True)

In [None]:
visualizer.visualize_predictions(fixed_preds[-1], molecule_fragments[-1], showlegend=True)
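
The predictions above were computed with `focus_and_atom_type_inverse_temperature=1.0` and `position_inverse_temperature=1.0`; rerunning the prediction cell with other values is an easy way to explore the model. For intuition, the sketch below assumes the common convention that an inverse temperature β multiplies the logits before a softmax. We have not checked how Symphony applies it internally, particularly for the position distribution. `tempered_softmax` is a hypothetical helper.

```python
import numpy as np


def tempered_softmax(logits, inverse_temperature: float = 1.0) -> np.ndarray:
    """Softmax of `inverse_temperature * logits` (assumed convention, see above)."""
    scaled = inverse_temperature * np.asarray(logits, dtype=float)
    scaled = scaled - scaled.max()  # Subtract the max for numerical stability.
    weights = np.exp(scaled)
    return weights / weights.sum()


logits = [2.0, 1.0, 0.0]
for beta in (0.0, 1.0, 10.0):
    print(beta, np.round(tempered_softmax(logits, beta), 3))
```

Under this convention, β = 1 leaves the distribution unchanged, β near 0 approaches uniform sampling, and large β approaches always picking the most likely option.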