In [1]:
from typing import *

import chex
import jax
import jax.numpy as jnp
import ase
import numpy as np
import jraph
import plotly.graph_objects as go
import plotly.subplots

import sys

sys.path.append("..")

from symphony.data import input_pipeline_tf
import configs.qm9.nequip as nequip
from symphony import models, datatypes

In [2]:
%load_ext autoreload

In [5]:
%autoreload 2
import analyses.analysis as analysis

In [None]:
# Start from the default NequIP configuration shipped under configs/qm9.
config = nequip.get_config()
# NOTE(review): `root_dir` is not imported or defined in any visible cell, so
# this line raises NameError on a fresh kernel. `get_root` is also referenced
# without being called, so even with the name in scope an attribute (not its
# result) would be stored. The recorded "No files found in None" error in the
# dataset-loading cell suggests the data root was unset when that cell ran.
# TODO confirm the intended module and call.
config.root_dir = root_dir.get_root

In [None]:
# Derive all randomness from the configured seed, keeping a dedicated key for
# the data pipeline so model sampling and data loading do not share a key.
rng = jax.random.PRNGKey(config.rng_seed)
rng, dataset_rng = jax.random.split(rng)

# Turn shuffling off *before* the datasets are built. The config is handed to
# get_datasets() at call time, so assigning the flag afterwards (as this cell
# previously did) could not influence the datasets that were already created.
# NOTE(review): presumably get_datasets() reads `config.shuffle_datasets` -
# confirm against symphony.data.input_pipeline_tf.
config.shuffle_datasets = False
datasets = input_pipeline_tf.get_datasets(dataset_rng, config)

# Pull exactly one batch from the training split, converted to NumPy arrays.
fragments = next(datasets["train"].take(1).as_numpy_iterator())

ValueError: No files found in None.

In [None]:
import tensorflow as tf

# Every generated element is a float32 tensor with three columns and an
# arbitrary number of rows.
element_spec = tf.TensorSpec(shape=(None, 3), dtype=tf.float32)


def fragment_yielder():
    """Yields, per unbatched fragment, a tensor filled with the fragment's index.

    The tensor has the same shape as that fragment's node positions, and every
    entry equals the position of the fragment within the batch.
    """
    graphs = jraph.unbatch(fragments)
    for index in range(len(graphs)):
        node_positions = graphs[index].nodes.positions
        yield index * tf.ones_like(node_positions)


# Wrap the generator in a tf.data pipeline and print each element.
ds = tf.data.Dataset.from_generator(
    fragment_yielder,
    output_signature=element_spec,
)
for x in ds:
    print(x)

In [None]:
# Split the batch into individual graphs and print the node positions of the
# first three of them on a single line.
unbatched = jraph.unbatch(fragments)
print(*(unbatched[graph_index].nodes.positions for graph_index in range(3)))

In [None]:
# Checkpoint to analyse and the sampling temperatures passed to the model.
workdir = "../potato_workdirs/tetris/nequip/interactions=2/l=4/channels=64/"
step = "best"
focus_and_atom_type_inverse_temperature = 1
position_inverse_temperature = 10

# NOTE: this rebinds `config` to the configuration returned alongside the
# loaded model, replacing the one set up in the earlier cells.
model, params, config = analysis.load_model_at_step(
    workdir, step, run_in_evaluation_mode=True
)

# Compile the forward pass once and reuse that compiled callable below,
# instead of wrapping model.apply in jax.jit a second time.
apply_fn = jax.jit(model.apply)

preds = apply_fn(
    params,
    rng,
    fragments,
    focus_and_atom_type_inverse_temperature,
    position_inverse_temperature,
)
# Convert every leaf to a NumPy array for the plotting code further down.
# jax.tree_util.tree_map is used because the top-level jax.tree_map alias is
# deprecated (and no longer present in newer JAX releases).
preds = jax.tree_util.tree_map(np.asarray, preds)
preds

In [None]:
def visualize_predictions(fragment: datatypes.Fragments, pred: Optional[datatypes.Predictions] = None) -> go.Figure:
    """Builds a three-panel interactive figure for one fragment and its predictions.

    Panel 1 ("Input Fragment") is a 3D scene with the fragment's atoms and
    bonds. Panel 2 ("Output Predictions") repeats those atoms and bonds and adds
    the prediction traces for the molecule. Panel 3 ("Output Predictions") is a
    2D plot holding the focus / atom-type traces. The cameras of the two 3D
    scenes are kept in sync.

    Args:
        fragment: The fragment to draw; its node positions, node species,
            senders and receivers are read.
        pred: Predictions for the fragment. It is forwarded unchanged to
            analysis.get_prediction_plotly_traces.
            NOTE(review): the default of None is passed through as-is; whether
            that helper accepts None is not visible here - confirm.

    Returns:
        A go.FigureWidget wrapping the assembled figure.
        NOTE(review): the annotation says go.Figure, but a widget is returned.
    """
    positions = fragment.nodes.positions
    atomic_numbers = [
        int(number) for number in models.get_atomic_numbers(fragment.nodes.species)
    ]

    fig = plotly.subplots.make_subplots(
        rows=1,
        cols=3,
        specs=[[{"type": "scene"}, {"type": "scene"}, {"type": "xy"}]],
        subplot_titles=("Input Fragment", "Output Predictions", "Output Predictions"),
    )

    # One marker per atom, sized and coloured by element.
    atoms_trace = go.Scatter3d(
        x=positions[:, 0],
        y=positions[:, 1],
        z=positions[:, 2],
        mode="markers",
        marker={
            "size": [analysis.ATOMIC_SIZES[number] for number in atomic_numbers],
            "color": [analysis.ATOMIC_COLORS[number] for number in atomic_numbers],
        },
        hovertext=[
            f"Element: {ase.data.chemical_symbols[number]}"
            for number in atomic_numbers
        ],
        opacity=1.0,
        name="Molecule Atoms",
        legendrank=1,
    )

    # One black line segment per (sender, receiver) edge.
    bond_traces = [
        go.Scatter3d(
            x=positions[[sender, receiver], 0],
            y=positions[[sender, receiver], 1],
            z=positions[[sender, receiver], 2],
            line={"color": "black"},
            mode="lines",
            showlegend=False,
        )
        for sender, receiver in zip(fragment.senders, fragment.receivers)
    ]
    common_traces = [atoms_trace, *bond_traces]

    # Traces derived from the predictions.
    prediction_traces = analysis.get_prediction_plotly_traces(pred, fragment)
    prediction_molecule_traces, focus_and_atom_type_traces = prediction_traces

    # The shared traces go into both 3D scenes; the legend entry is switched off
    # before the second insertion so it is only listed for the first scene.
    for trace in common_traces:
        fig.add_trace(trace, row=1, col=1)
        trace.showlegend = False
        fig.add_trace(trace, row=1, col=2)

    for column, traces in (
        (2, prediction_molecule_traces),
        (3, focus_and_atom_type_traces),
    ):
        for trace in traces:
            fig.add_trace(trace, row=1, col=column)

    # Strip all axis decoration from the 3D scenes.
    hidden_axis = {
        "showbackground": False,
        "showticklabels": False,
        "showgrid": False,
        "zeroline": False,
        "title": "",
        "nticks": 3,
    }

    def make_scene():
        # Each axis gets its own copy of the shared axis settings.
        return {
            "xaxis": dict(hidden_axis),
            "yaxis": dict(hidden_axis),
            "zaxis": dict(hidden_axis),
            "aspectmode": "data",
        }

    opaque_white = "rgba(255,255,255,1)"
    fig.update_layout(
        width=800,
        height=500,
        scene1=make_scene(),
        scene2=make_scene(),
        paper_bgcolor=opaque_white,
        plot_bgcolor=opaque_white,
        legend={
            "yanchor": "top",
            "y": 0.99,
            "xanchor": "right",
            "x": 0.1,
        },
    )

    # Keep the two 3D cameras in sync through widget change callbacks.
    fig_widget = go.FigureWidget(fig)

    def copy_camera_to_scene2(layout, camera):
        fig_widget.layout.scene2.camera = camera

    def copy_camera_to_scene1(layout, camera):
        # Only write back when the camera actually differs.
        if fig_widget.layout.scene1.camera != camera:
            fig_widget.layout.scene1.camera = camera

    fig_widget.layout.scene1.on_change(copy_camera_to_scene2, 'camera')
    fig_widget.layout.scene2.on_change(copy_camera_to_scene1, 'camera')

    return fig_widget

In [None]:
# Pick one graph out of the batch, together with its predictions.
index = 3
fragment = jraph.unbatch(fragments)[index]
pred = jraph.unbatch(preds)[index]


def _drop_leading_axis(x):
    # np.squeeze(axis=0) removes the first axis and requires it to have size 1.
    return np.squeeze(x, axis=0)


# Remove the leading axis from every leaf of the globals.
# jax.tree_util.tree_map replaces the deprecated top-level jax.tree_map alias.
fragment = fragment._replace(
    globals=jax.tree_util.tree_map(_drop_leading_axis, fragment.globals)
)
pred = pred._replace(
    globals=jax.tree_util.tree_map(_drop_leading_axis, pred.globals)
)

# Shift the focus indices by the number of nodes in the graphs that precede
# this one in the batch.
# NOTE(review): presumably focus_indices index into the batched node array and
# this makes them local to the selected graph - confirm.
corrected_focus_indices = pred.globals.focus_indices - preds.n_node[:index].sum()
pred = pred._replace(
    globals=pred.globals._replace(focus_indices=corrected_focus_indices)
)

# NOTE(review): this calls analysis.visualize_predictions(pred, fragment), not
# the notebook-level visualize_predictions(fragment, pred), whose arguments are
# in the opposite order - confirm which one is intended.
fig1 = analysis.visualize_predictions(pred, fragment)

In [None]:
# Pick one graph out of the batch, together with its predictions.
index = 2
fragment = jraph.unbatch(fragments)[index]
pred = jraph.unbatch(preds)[index]


def _drop_leading_axis(x):
    # np.squeeze(axis=0) removes the first axis and requires it to have size 1.
    return np.squeeze(x, axis=0)


# Remove the leading axis from every leaf of the globals.
# jax.tree_util.tree_map replaces the deprecated top-level jax.tree_map alias.
fragment = fragment._replace(
    globals=jax.tree_util.tree_map(_drop_leading_axis, fragment.globals)
)
pred = pred._replace(
    globals=jax.tree_util.tree_map(_drop_leading_axis, pred.globals)
)

# Shift the focus indices by the number of nodes in the graphs that precede
# this one in the batch.
# NOTE(review): presumably focus_indices index into the batched node array and
# this makes them local to the selected graph - confirm.
corrected_focus_indices = pred.globals.focus_indices - preds.n_node[:index].sum()
pred = pred._replace(
    globals=pred.globals._replace(focus_indices=corrected_focus_indices)
)

# NOTE(review): this calls analysis.visualize_predictions(pred, fragment), not
# the notebook-level visualize_predictions(fragment, pred), whose arguments are
# in the opposite order - confirm which one is intended.
fig2 = analysis.visualize_predictions(pred, fragment)

In [None]:
# Two stand-alone figures to switch between with a slider.
figure1 = go.Figure(data=[go.Scatter(x=[1, 2, 3], y=[1, 3, 2])])
figure2 = go.Figure(data=[go.Scatter(x=[1, 2, 3], y=[2, 1, 3])])
figures = [figure1, figure2]

# For every trace of the combined figure, remember which source figure it
# came from, in the same order the traces are added below.
trace_owner = [
    figure_index
    for figure_index, figure in enumerate(figures)
    for _ in figure.data
]

# A slider step cannot swap a figure's `data` wholesale. With the "update"
# method, `args[0]` is a dict of trace-attribute updates, so each step shows
# the traces of one source figure and hides all the others via `visible`.
slider_steps = [
    {
        'args': [{'visible': [owner == figure_index for owner in trace_owner]}],
        'label': f'Figure {figure_index + 1}',
        'method': 'update',
    }
    for figure_index in range(len(figures))
]
slider = {'active': 0, 'currentvalue': {'prefix': 'Figure: '}, 'steps': slider_steps}

layout = go.Layout(sliders=[slider])
# The combined figure must actually contain every trace; previously it was
# created with a layout only, so there was nothing for the slider to show.
fig = go.Figure(
    data=[trace for figure in figures for trace in figure.data],
    layout=layout,
)
# Initial visibility matches the slider's active step.
for trace, owner in zip(fig.data, trace_owner):
    trace.visible = owner == slider['active']
fig