## 4. Creating Symphony

$$
\def\gR{\mathcal{R}}
\def\gS{\mathcal{S}}
\def\r{\vec{\mathbf{r}}}
\def\R{{\mathbf{R}}}
\def\T{{\mathbf{T}}}
\def\pT{p_\Theta}
\def\fT{f_\Theta}
\def\p{p}
$$

We have all the pieces we need to build our autoregressive generative model for molecules, Symphony.

Given the current fragment $\gS$, we need to predict the focus atom $f$, the next atom's type $Z$, and its position $\r$ relative to the focus atom.

For this, Symphony employs two $E(3)$-equivariant neural networks:
- The first network predicts the focus node index in $\gS$ and the next atom type $Z$. Reordering the atoms in $\gS$ should only relabel the predicted focus node and leave the predicted atom type unchanged, so we use a message-passing network, specifically the $E(3)$-equivariant [NequIP](https://www.nature.com/articles/s41467-022-29939-5). Note that these quantities are invariant to rotations and translations of the molecule, so we only need to extract the scalar features from each node to predict them. Essentially, this is a glorified classification problem, and we don't need to do anything fancy here.
- The more interesting part is predicting the relative position of the next atom. Here, we can use the 'multi-channel spherical harmonics' idea that we explored in the second notebook. In particular, the second network computes an embedding for the focus node via message-passing and then predicts the coefficients $c^l_{\text{ch}}(r)$ for each channel $\text{ch}$, each degree $l \le L$, and each discretized distance $r$. This gives rise to the following distribution over 3D space:
$$
\begin{aligned}
p(r, \theta, \phi; \gS) &= \frac{\exp f(r, \theta, \phi; \gS)}{\int_{\mathbb{R}^3} \exp f(r', \theta', \phi'; \gS) \, dV'} \\
f(r, \theta, \phi; \gS) &= \log \sum_{\text{ch}} \exp\left(\sum_{l = 0}^{L} \left(c^l_{\text{ch}}(r)\right)^T Y^l(\theta, \phi)\right)
\end{aligned}
$$
which satisfies:
$$
\begin{aligned}
\int_{\mathbb{R}^3} p(r, \theta, \phi; \gS) \, dV &= 1 \\
p(r, \theta, \phi; \gS) &\geq 0
\end{aligned}
$$
where $dV = r^2 \sin \theta \ dr \ d\theta \ d\phi$ is the volume element in spherical coordinates.
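To make the construction concrete, here is a minimal NumPy sketch that evaluates this density on a grid. It rests on simplifying assumptions that are not Symphony's: only degrees $l \le 1$, orthonormal real spherical harmonics with the $l = 1$ components ordered $(x, y, z)$, uniformly spaced radial bins, and a Riemann-sum normalizer on a midpoint grid. The function names and array layouts are hypothetical; the trained model uses a larger $L$ and obtains the coefficients from its network.

```python
import numpy as np


def real_sph_harm_l01(theta, phi):
    """Orthonormal real spherical harmonics for l = 0 and l = 1 (assumed convention).

    Returns Y0 with shape theta.shape + (1,) and Y1 with shape theta.shape + (3,),
    with the l = 1 components ordered as (x, y, z).
    """
    x = np.sin(theta) * np.cos(phi)
    y = np.sin(theta) * np.sin(phi)
    z = np.cos(theta)
    y0 = np.full(theta.shape + (1,), 0.5 / np.sqrt(np.pi))
    y1 = np.sqrt(3.0 / (4.0 * np.pi)) * np.stack([x, y, z], axis=-1)
    return y0, y1


def position_density(c0, c1, radii, n_theta=48, n_phi=96):
    """Evaluate p(r, theta, phi) on a grid from per-channel coefficients.

    c0: [n_ch, n_r, 1] coefficients for l = 0; c1: [n_ch, n_r, 3] for l = 1.
    radii: [n_r] uniformly spaced radial bins (an assumption of this sketch).
    Returns (p, w): density values and volume weights dV, both [n_r, n_theta, n_phi].
    """
    d_theta = np.pi / n_theta
    d_phi = 2.0 * np.pi / n_phi
    theta = (np.arange(n_theta) + 0.5) * d_theta  # midpoint rule in theta
    phi = (np.arange(n_phi) + 0.5) * d_phi        # midpoint rule in phi
    th, ph = np.meshgrid(theta, phi, indexing="ij")
    y0, y1 = real_sph_harm_l01(th, ph)

    # Per-channel logits sum_l c^l_ch(r)^T Y^l(theta, phi): [n_ch, n_r, n_theta, n_phi].
    logits = (np.einsum("crk,tpk->crtp", c0, y0)
              + np.einsum("crk,tpk->crtp", c1, y1))

    # f = log sum_ch exp(logits), computed stably over the channel axis.
    m = logits.max(axis=0)
    f = m + np.log(np.exp(logits - m).sum(axis=0))  # [n_r, n_theta, n_phi]

    # Volume element dV = r^2 sin(theta) dr dtheta dphi evaluated on the grid.
    dr = radii[1] - radii[0]
    w = radii[:, None, None] ** 2 * np.sin(th)[None] * dr * d_theta * d_phi

    # p = exp(f) / sum(exp(f) dV); subtracting max(f) avoids overflow.
    unnorm = np.exp(f - f.max())
    p = unnorm / (unnorm * w).sum()
    return p, w


rng = np.random.default_rng(0)
radii = np.linspace(0.9, 2.0, 16)
c0 = rng.normal(size=(4, 16, 1))
c1 = rng.normal(size=(4, 16, 3))
p, w = position_density(c0, c1, radii)
total = (p * w).sum()  # equals 1 up to floating-point error, by construction
```

Setting every $l = 1$ coefficient to zero leaves only the $l = 0$ term, so the density no longer varies with $(\theta, \phi)$ at any fixed radius.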

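Further, because both networks are $E(3)$-equivariant, the distribution is equivariant under rotations $\R$ and translations $\T$ of the fragment $\gS$. Write $\r$ for the relative position with spherical coordinates $(r, \theta, \phi)$. Since $\r$ is measured relative to the focus atom, translating the fragment leaves it unchanged, while rotating the fragment rotates it along:
$$
p(\R \r; \R \gS + \T) = p(\r; \gS)
$$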
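This property can be checked numerically. The sketch below works at a single radius and uses a hypothetical toy stand-in for the position network (`toy_coefficients` is not Symphony's network): its $l = 0$ coefficients are rotation-invariant scalars and its $l = 1$ coefficients are vectors built from positions relative to the focus atom. Because the normalizing integral runs over all directions with a rotation-invariant volume element, comparing the unnormalized $f$ before and after the transformation is enough.

```python
import numpy as np

Y0 = 0.5 / np.sqrt(np.pi)             # l = 0 real spherical harmonic
Y1_NORM = np.sqrt(3.0 / (4.0 * np.pi))  # l = 1 harmonics are Y1_NORM * unit vector


def toy_coefficients(positions, focus, weights):
    """Hypothetical toy 'network' giving per-channel l = 0 and l = 1 coefficients.

    Uses only vectors relative to the focus atom, so the output is
    translation-invariant, and the l = 1 part rotates with the fragment.
    """
    rel = positions - positions[focus]   # unchanged by translations
    v = rel.sum(axis=0)                  # rotates with the fragment
    c0 = weights * np.linalg.norm(v)     # [n_ch] rotation-invariant scalars
    c1 = weights[:, None] * v[None, :]   # [n_ch, 3] vectors
    return c0, c1


def log_density_unnormalized(r_vec, c0, c1):
    """f = log sum_ch exp(c0_ch * Y0 + c1_ch . Y1(r_hat)) for one direction."""
    r_hat = r_vec / np.linalg.norm(r_vec)
    logits = c0 * Y0 + c1 @ (Y1_NORM * r_hat)
    m = logits.max()
    return m + np.log(np.exp(logits - m).sum())


def random_rotation(rng):
    """Random proper rotation matrix (det = +1) from a QR decomposition."""
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))  # fix column signs so the factorization is unique
    if np.linalg.det(q) < 0:      # turn a reflection into a rotation
        q[:, 0] = -q[:, 0]
    return q


rng = np.random.default_rng(0)
positions = rng.normal(size=(5, 3))
focus, weights = 2, rng.normal(size=4)
r_vec = rng.normal(size=3)
R, T = random_rotation(rng), rng.normal(size=3)

f_orig = log_density_unnormalized(r_vec, *toy_coefficients(positions, focus, weights))
# Rotate and translate the fragment (rows x -> R x + T) and rotate the query vector.
f_moved = log_density_unnormalized(
    R @ r_vec, *toy_coefficients(positions @ R.T + T, focus, weights)
)
```

The two values agree because $(\R \mathbf{c})^T (\R \hat{\r}) = \mathbf{c}^T \hat{\r}$ for any rotation $\R$.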

## Symphony on QM9

Here, you can play around with a Symphony model pre-trained on the QM9 dataset.

In [None]:
# Setup.
!pip install git+https://github.com/atomicarchitects/symphony@tutorial
!curl -L -H 'Cache-Control: no-cache' https://github.com/atomicarchitects/symphony/archive/refs/heads/tutorial.zip -o symphony-tutorial.zip
# -o: overwrite existing files without prompting.
!unzip -o symphony-tutorial.zip
!mv symphony-tutorial-tutorial/tutorial tutorial
!rm -r symphony-tutorial-tutorial symphony-tutorial.zip
# `!cd` runs in a throwaway subshell; the `%cd` magic changes the notebook's working directory.
%cd tutorial


In [None]:
# Imports.
from typing import List
import os
import pickle

import jax
import jax.numpy as jnp
import jraph
import numpy as np
import plotly.graph_objects as go

import sys
sys.path.append("../")

import tutorial.visualizer as visualizer
import tutorial.tutorial_utils as tutorial_utils

In [None]:
model, params, config = tutorial_utils.load_model_at_step(
    workdir="./workdir",
    step="best",
    run_in_evaluation_mode=True,
)

Next, we load fragments of the first QM9 molecule. We have already extracted them and cached them to `./data/qm9_fragments_list.pkl`; if that file is missing, the code below regenerates them from QM9:


In [None]:
saved_fragments_path = "./data/qm9_fragments_list.pkl"

if os.path.exists(saved_fragments_path):
    with open(saved_fragments_path, "rb") as f:
        molecule_fragments = pickle.load(f)

else:
    from symphony.data import qm9
    from symphony.data import fragments

    # Load the QM9 dataset.    
    molecules = qm9.load_qm9("./data/qm9", use_edm_splits=True, check_molecule_sanity=False)
    
    # We pick the first molecule in the dataset.
    molecule = molecules[0]
    molecule_graph = tutorial_utils.ase_atoms_to_jraph_graph(
        atoms=molecule,
        atomic_numbers=np.asarray([1, 6, 7, 8, 9]),
        nn_cutoff=config.nn_cutoff
    )
    molecule_fragments = fragments.generate_fragments(
        jax.random.PRNGKey(0),
        molecule_graph,
        n_species=5,
        nn_tolerance=config.nn_tolerance,
        mode="nn",
    )
    molecule_fragments = list(molecule_fragments)

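    # Cache the fragments so that later runs can skip this step.
    os.makedirs(os.path.dirname(saved_fragments_path), exist_ok=True)
    with open(saved_fragments_path, "wb") as f:
        pickle.dump(molecule_fragments, f)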
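`fragments.generate_fragments` with `mode="nn"` produced the sequence of fragments loaded above; we do not reproduce its implementation here. As a hypothetical illustration of the general idea only, the sketch below grows a molecule by repeatedly attaching the closest not-yet-added atom that lies within a distance cutoff of some atom already in the fragment, recording the (fragment, focus, next atom) triple at each step. Unlike the library function, which takes a PRNG key, this sketch is deterministic, and `grow_fragments` is a made-up name.

```python
import numpy as np


def grow_fragments(positions, cutoff, start=0):
    """Hypothetical nearest-neighbour growth order for a molecule.

    At each step, among all (focus, target) pairs with the focus already in the
    fragment, the target not yet added, and |x_target - x_focus| <= cutoff,
    pick the closest pair. Returns a list of (fragment_indices, focus, target).
    """
    positions = np.asarray(positions, dtype=float)
    n = len(positions)
    dist = np.linalg.norm(positions[:, None, :] - positions[None, :, :], axis=-1)
    fragment = [start]
    steps = []
    while len(fragment) < n:
        remaining = [j for j in range(n) if j not in fragment]
        sub = dist[np.ix_(fragment, remaining)]
        sub = np.where(sub <= cutoff, sub, np.inf)  # ignore pairs beyond the cutoff
        if np.isinf(sub).all():
            break  # no remaining atom is reachable: stop growing
        i, j = np.unravel_index(np.argmin(sub), sub.shape)
        focus, target = fragment[i], remaining[j]
        steps.append((list(fragment), focus, target))
        fragment.append(target)
    return steps


# A linear chain of four atoms spaced 1 unit apart, with a cutoff of 1.5.
chain = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0], [3.0, 0.0, 0.0]]
steps = grow_fragments(chain, cutoff=1.5)
```

Each step adds exactly one atom, so a molecule with $n$ atoms yields up to $n - 1$ training examples of the form "given this fragment, predict this focus and this next atom".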

We start off with a fragment with a single atom.

In [None]:
visualizer.visualize_fragment(molecule_fragments[0])

... which grows into a larger fragment:

In [None]:
visualizer.visualize_fragment(molecule_fragments[1])

In [None]:
visualizer.visualize_fragment(molecule_fragments[5])

and end up with a molecule!

In [None]:
visualizer.visualize_fragment(molecule_fragments[-1])

We can query Symphony for the next atom in the sequence.
We could apply this autoregressively, feeding each prediction back in as the new fragment, to generate whole molecules; here, we just generate one atom starting from each fragment.
Note that the model has seen many different orderings of atoms during training, so its predictions may not match the order of the fragments shown here!
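For reference, the autoregressive loop itself can be sketched in a few lines of NumPy. The `Sampler` interface below is hypothetical and stands in for a call to the model plus sampling; in particular, returning `None` to stop is an assumption of this sketch, not Symphony's API. The one detail taken from the model description is that the position is predicted relative to the focus atom.

```python
from typing import Callable, Optional, Tuple

import numpy as np

# Hypothetical sampler interface: given (positions, species) of the current
# fragment, return (focus index, atomic number Z, position relative to the
# focus atom), or None to stop growing.
Sampler = Callable[[np.ndarray, np.ndarray], Optional[Tuple[int, int, np.ndarray]]]


def generate(positions, species, sample_next: Sampler, max_atoms: int = 30):
    """Grow a fragment one atom at a time until the sampler stops or max_atoms is hit."""
    positions = np.asarray(positions, dtype=float).reshape(-1, 3)
    species = np.asarray(species, dtype=int).reshape(-1)
    while len(species) < max_atoms:
        step = sample_next(positions, species)
        if step is None:
            break
        focus, z, r_rel = step
        # The position is predicted relative to the focus atom, so the
        # absolute position is the focus position plus the relative vector.
        new_position = positions[focus] + np.asarray(r_rel, dtype=float)
        positions = np.vstack([positions, new_position])
        species = np.append(species, z)
    return positions, species


# Toy sampler for illustration: attach a hydrogen 1 unit along +x from the
# most recently added atom, and stop once there are 4 atoms.
def toy_sampler(positions, species):
    if len(species) >= 4:
        return None
    return len(species) - 1, 1, np.array([1.0, 0.0, 0.0])


final_positions, final_species = generate([[0.0, 0.0, 0.0]], [6], toy_sampler)
```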

In [None]:
apply_rng = jax.random.PRNGKey(0)
preds = jax.jit(model.apply)(
    params,
    apply_rng,
    jraph.batch(molecule_fragments),
    focus_and_atom_type_inverse_temperature=1.0,
    position_inverse_temperature=1.0
)

fixed_preds = []
for index, (fragment, pred) in enumerate(
    zip(molecule_fragments, jraph.unbatch(preds))
):
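    # jraph.unbatch leaves a leading axis of size 1 on each graph's globals; squeeze it away.
    pred = pred._replace(
        globals=jax.tree_util.tree_map(lambda x: np.squeeze(x, axis=0), pred.globals)
    )
    # Focus indices refer to nodes of the batched graph; subtract the number of
    # nodes in all preceding fragments to make them local to this fragment.
    corrected_focus_indices = (
        pred.globals.focus_indices - preds.n_node[:index].sum()
    )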
    pred = pred._replace(
        globals=pred.globals._replace(focus_indices=corrected_focus_indices)
    )

    fixed_preds.append(pred)
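The call to `model.apply` above passes two inverse temperatures. As an illustration only, and assuming the names mean what they usually do (logits are multiplied by an inverse temperature $\beta$ before the softmax), here is a NumPy sketch for a joint focus/atom-type distribution. The `[n_nodes, n_species]` logit layout and the function name are hypothetical; we have not checked how Symphony implements this.

```python
import numpy as np


def sample_focus_and_species(rng, logits, inverse_temperature):
    """Sample a (focus node, species) pair from joint logits of shape [n_nodes, n_species].

    Multiplying the logits by beta before the softmax gives a uniform distribution
    at beta = 0 and concentrates the mass on the argmax as beta grows.
    """
    logits = np.asarray(logits, dtype=float)
    scaled = inverse_temperature * logits.reshape(-1)
    scaled -= scaled.max()  # numerical stability
    probs = np.exp(scaled)
    probs /= probs.sum()
    flat = rng.choice(probs.size, p=probs)
    focus, species = divmod(int(flat), logits.shape[1])
    return focus, species, probs.reshape(logits.shape)


rng = np.random.default_rng(0)
toy_logits = np.array([[0.0, 1.0], [3.0, 0.5], [0.0, 0.0]])
_, _, uniform_probs = sample_focus_and_species(rng, toy_logits, 0.0)
focus, species, sharp_probs = sample_focus_and_species(rng, toy_logits, 100.0)
```

The same idea would apply to the position distribution, with $\beta$ multiplying $f(r, \theta, \phi; \gS)$ before normalization; values of $\beta$ above 1 make sampling greedier, values below 1 make it more diverse.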

Note how Symphony respects all symmetries of the input fragment $\gS$.
If there is only one atom in the fragment, the model predicts a rotationally symmetric distribution over the next atom's position.
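This follows from the symmetry property alone. Rotating a single-atom fragment $\gS = \{\mathbf{x}_0\}$ about its atom leaves it unchanged, and the next position $\r$ is measured relative to $\mathbf{x}_0$, so for every rotation $\R$:

$$
\R \gS + (\mathbf{x}_0 - \R \mathbf{x}_0) = \gS
\quad \Longrightarrow \quad
p(\R \r; \gS) = p(\r; \gS)
$$

Hence $p(\r; \gS)$ depends on $\r$ only through the distance $\lVert \r \rVert$ from the focus atom, and is the same in every direction.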

In [None]:
visualizer.visualize_predictions(fixed_preds[0], molecule_fragments[0], showlegend=True)

In [None]:
visualizer.visualize_predictions(fixed_preds[1], molecule_fragments[1], showlegend=True)

In [None]:
visualizer.visualize_predictions(fixed_preds[5], molecule_fragments[5], showlegend=True)

In [None]:
visualizer.visualize_predictions(fixed_preds[-1], molecule_fragments[-1], showlegend=True)