## Symphony on QM9

Here, you can play around with a pre-trained Symphony model on the QM9 dataset.

In [None]:
import os
# Restrict which GPU is visible to this process. This cell runs before `jax`
# is imported in the next cell. The index "2" is only a default for machines
# where nothing was configured: if CUDA_VISIBLE_DEVICES is already set (e.g.
# exported in the shell that launched Jupyter), that value is kept instead of
# being silently overwritten.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "2")

In [None]:
# Imports
from typing import List
import jax
import jax.numpy as jnp
import jraph
import pickle
import numpy as np
import plotly.graph_objects as go

import sys
sys.path.append("../")

from symphony import models
from symphony.data import input_pipeline
from analyses import visualizer
import configs.qm9.nequip as tutorial_config

In [None]:
# Load the configuration object defined in `configs/qm9/nequip.py`
# (imported above as `tutorial_config`). It is reused below for the
# dataset pipeline and for building the model.
config = tutorial_config.get_config()

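Load some fragments from QM9. We have already extracted the fragments of one molecule and cached them in a pickle file; the cell below loads that cache if it exists, and otherwise regenerates it from the dataset with the following code: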


In [None]:
saved_fragments_path = "./data/qm9_fragments_list.pkl"

if os.path.exists(saved_fragments_path):
    # Fast path: reuse the fragments cached by a previous run.
    # NOTE: pickle.load can execute arbitrary code while deserializing, so
    # only load a cache file that you created yourself or otherwise trust.
    with open(saved_fragments_path, "rb") as f:
        unbatched_fragments = pickle.load(f)

else:
    # Slow path: rebuild the fragments from the training split of the dataset.
    config.root_dir = "./data/qm9"
    rng = jax.random.PRNGKey(0)
    ds = input_pipeline.get_datasets(rng, config)
    fragments = next(ds["train"])

    def unbatch(fragments: jraph.GraphsTuple) -> List[jraph.GraphsTuple]:
        """Split a padded batch of fragments into individual fragments.

        Args:
            fragments: A padded, batched `jraph.GraphsTuple`.

        Returns:
            A list with one `jraph.GraphsTuple` per fragment, where every
            leaf of `globals` has had its leading (batch) axis squeezed out.
        """
        # Remove padding from the fragments to visualize them.
        fragments = jraph.unpad_with_graphs(fragments)

        # Unbatch the fragments, and remove the batch dimension from the globals.
        unbatched_fragments = jraph.unbatch(fragments)

        return [
            fragment._replace(
                globals=jax.tree_util.tree_map(
                    lambda x: np.squeeze(x, axis=0), fragment.globals
                )
            )
            for fragment in unbatched_fragments
        ]

    # We keep all fragments for the first molecule: everything up to and
    # including the first fragment whose `stop` flag is set.
    # `.index(True)` raises ValueError if no fragment in this batch has it set.
    unbatched_fragments = unbatch(fragments)
    first_stop = [fragment.globals.stop for fragment in unbatched_fragments].index(True)
    unbatched_fragments = unbatched_fragments[:first_stop + 1]

    # Make sure the cache directory exists before writing; otherwise open()
    # fails with FileNotFoundError on a fresh checkout without `./data`.
    cache_dir = os.path.dirname(saved_fragments_path)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)

    with open(saved_fragments_path, "wb") as f:
        pickle.dump(unbatched_fragments, f)

We start off with a fragment with a single atom.

In [None]:
# First fragment in the list.
visualizer.visualize_fragment(unbatched_fragments[0])

... which grows into a larger fragment:

In [None]:
# Second fragment in the list (index 1).
visualizer.visualize_fragment(unbatched_fragments[1])

In [None]:
# A later fragment (index 5). This raises IndexError if the list holds
# fewer than six fragments.
visualizer.visualize_fragment(unbatched_fragments[5])

... and finally ends up as a complete molecule!

In [None]:
# Last fragment in the list.
visualizer.visualize_fragment(unbatched_fragments[-1])

Next, we add step-by-step predictions. We start by creating the model in evaluation mode.

In [None]:
# Build the model described by `config`, requesting evaluation mode.
# NOTE(review): what evaluation mode changes is defined in
# `symphony.models.create_model` and is not visible here. Confirm there.
model = models.create_model(config, run_in_evaluation_mode=True)