## Symphony on QM9

Here, you can play around with a pre-trained Symphony model on the QM9 dataset.

In [None]:
# Reload edited modules automatically before running each cell (IPython autoreload).
%load_ext autoreload
%autoreload 2

In [None]:
# Imports
import jax
import jax.numpy as jnp
import jraph
import nequip_jax
import haiku as hk
import numpy as np
import plotly.graph_objects as go

import sys
sys.path.append('../')

from symphony import models
from symphony.data import input_pipeline
from analyses import visualizer 
import configs.qm9.nequip as tutorial_config


Load the QM9 tutorial config and build the message-passing network in evaluation mode.

In [None]:
# Start from the QM9 NequIP config; root_dir presumably points at the tutorial's
# data/checkpoint directory — confirm against the config definition.
config = tutorial_config.get_config()
config.root_dir = "./tutorial/qm9"

# Build the model with run_in_evaluation_mode=True.
# NOTE(review): this cell only constructs the model; no pretrained parameters
# are restored here — confirm where the pre-trained weights get loaded.
model = models.create_model(
    config,
    run_in_evaluation_mode=True,
)

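Load a batch of training fragments from QM9, then strip the padding and split the batch into individual fragments for visualization.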

In [None]:
# Key handed to the data pipeline (presumably seeds shuffling/fragment sampling).
rng = jax.random.PRNGKey(0)
ds = input_pipeline.get_datasets(rng, config)

# ds["train"] is consumed with next(), so it is an iterator of batches;
# take the first batch of training fragments.
fragments = next(ds["train"])

def unbatch(fragments: jraph.GraphsTuple) -> list[jraph.GraphsTuple]:
    """Split a padded batch of fragments into a list of individual fragments.

    The original annotation used ``List`` without importing it from ``typing``,
    which raised ``NameError`` when this ``def`` executed; the builtin generic
    ``list[...]`` needs no import.

    Args:
        fragments: A batched ``jraph.GraphsTuple``. Padding is removed with
            ``jraph.unpad_with_graphs`` before splitting, so this presumably
            expects a batch padded by the input pipeline — confirm.

    Returns:
        One ``GraphsTuple`` per fragment, with axis 0 squeezed out of every
        leaf of ``globals``.
    """
    # Drop the padding graph(s) so they are not visualized as fragments.
    fragments = jraph.unpad_with_graphs(fragments)

    def _squeeze_batch_axis(x):
        # Each unbatched fragment's globals keep a leading batch axis (size 1
        # is required by np.squeeze); remove it so globals are per-fragment.
        return np.squeeze(x, axis=0)

    return [
        fragment._replace(
            globals=jax.tree_util.tree_map(_squeeze_batch_axis, fragment.globals)
        )
        for fragment in jraph.unbatch(fragments)
    ]


unbatched_fragments = unbatch(fragments)

In [None]:
# Show the shape of the batched target positions stored in the globals.
np.shape(fragments.globals.target_positions)

In [None]:
# Render the first fragment of the unpadded batch.
visualizer.visualize_fragment(unbatched_fragments[0])

In [None]:
# Render the second fragment of the unpadded batch.
visualizer.visualize_fragment(unbatched_fragments[1])

In [None]:
# Render fragment 10 of the unpadded batch.
# NOTE(review): assumes the batch holds at least 11 real fragments — confirm
# against the batch size configured for the input pipeline.
visualizer.visualize_fragment(unbatched_fragments[10])

In [None]:
# Render fragment 11 of the unpadded batch.
# NOTE(review): assumes the batch holds at least 12 real fragments — confirm
# against the batch size configured for the input pipeline.
visualizer.visualize_fragment(unbatched_fragments[11])

In [None]:
# Render fragment 8 of the unpadded batch.
# NOTE(review): assumes the batch holds at least 9 real fragments — confirm.
visualizer.visualize_fragment(unbatched_fragments[8])