In [52]:
import os
import sys

import ase
import chex
import jax
import jax.numpy as jnp
import jraph
import numpy as np
import plotly.graph_objects as go
import tqdm
from ase.db import connect

# Make the parent directory and ../analyses importable for the local modules below.
sys.path.append('..')
sys.path.append('../analyses')

import analysis
import datatypes
import generate_molecules
import input_pipeline
import models
import train

In [2]:
# Model checkpoint directory. Set the MODEL_PATH environment variable to use a
# different run; the default is the original author's local workdir.
model_path = os.environ.get(
    'MODEL_PATH',
    '/Users/songk/atomicarchitects/spherical_harmonic_net/workdirs/v6/nequip/interactions=4/l=5/channels=32',
)

# Load the checkpoint at step -1 (presumably the latest — confirm in
# analysis.load_model_at_step) in evaluation mode, and jit the forward pass once
# so later cells reuse the compiled function.
model, params, config = analysis.load_model_at_step(
    model_path, -1, run_in_evaluation_mode=True
)
apply_fn = jax.jit(model.apply)

In [64]:
# Starting point for step-by-step generation: the CH3 fragment read from disk.
# `molecule` is the working copy that is extended one atom at a time, while
# `init_molecule` stays unchanged so generation can be restarted from it.
init_molecule, init_molecule_name = analysis.construct_molecule(
    '../analyses/molecules/downloaded/CH3.xyz'
)
molecule = init_molecule.copy()
rng = jax.random.PRNGKey(28)
# NOTE(review): not used by any visible cell.
nan_found = False

In [65]:
def get_predictions(
    fragment: jraph.GraphsTuple,
    rng: chex.PRNGKey,
    beta: float = 10,
    min_n_node: int = 80,
    min_n_edge: int = 4096,
) -> datatypes.Predictions:
    """Runs the jitted model on one fragment and returns its unbatched predictions.

    The fragment is padded to a fixed size so the jitted ``apply_fn`` sees the
    same shapes across calls; the padding grows only when the fragment would
    not otherwise fit (``jraph.pad_with_graphs`` raises in that case).

    Uses the notebook-level ``apply_fn`` and ``params``.

    Args:
        fragment: GraphsTuple holding a single graph (assumed — the squeeze of
            the leading globals axis below relies on it).
        rng: PRNG key passed to the model.
        beta: Fourth positional argument of ``model.apply`` (named ``beta`` in
            ``visualize_atom_removals``). Defaults to the previously hard-coded 10.
        min_n_node: Minimum padded node count.
        min_n_edge: Minimum padded edge count.

    Returns:
        Predictions for the fragment, with the padding graph removed and the
        leading size-1 axis squeezed out of ``globals``.
    """
    # pad_with_graphs needs strictly more nodes than the input (the padding
    # graph gets at least one) and at least as many edges.
    n_node = max(min_n_node, int(np.asarray(fragment.n_node).sum()) + 1)
    n_edge = max(min_n_edge, int(np.asarray(fragment.n_edge).sum()))
    padded = jraph.pad_with_graphs(fragment, n_node=n_node, n_edge=n_edge, n_graph=2)
    preds = apply_fn(params, rng, padded, beta)

    # Drop the padding graph, then remove the batch dimension from the globals.
    pred = jraph.unpad_with_graphs(preds)
    return pred._replace(
        globals=jax.tree_util.tree_map(lambda x: np.squeeze(x, axis=0), pred.globals)
    )


def append_predictions(
    molecule: ase.Atoms, pred: datatypes.Predictions
) -> ase.Atoms:
    """Returns a new molecule equal to ``molecule`` plus the model's predicted atom.

    The new atom is placed at ``molecule.positions[focus] + position_vectors``
    and has atomic number ``models.ATOMIC_NUMBERS[target_species]``. The input
    molecule is not modified.

    NOTE(review): this function is defined twice in this notebook; keep the two
    definitions in sync or delete one.

    Args:
        molecule: Current molecule.
        pred: Predictions for ``molecule`` with the leading graph axis already
            removed from ``globals`` (as ``get_predictions`` does).

    Returns:
        A new ``ase.Atoms`` with one more atom than ``molecule``.
    """
    focus = pred.globals.focus_indices
    pos_focus = molecule.positions[focus]
    # Convert to NumPy so the sum stays float64. Both np.squeeze on a JAX array
    # and jnp.concatenate would otherwise yield float32 under JAX's default
    # (non-x64) mode, losing precision on every generation step.
    pos_rel = np.asarray(pred.globals.position_vectors)
    new_position = pos_focus + pos_rel

    new_species = np.asarray(
        models.ATOMIC_NUMBERS[pred.globals.target_species.item()]
    )

    return ase.Atoms(
        positions=np.concatenate(
            [molecule.positions, new_position[None, :]], axis=0
        ),
        numbers=np.concatenate([molecule.numbers, new_species[None]], axis=0),
    )

In [70]:
# Add one atom to the molecule. Every run splits off a fresh subkey and replaces
# `molecule` with a copy that has one extra atom, so re-running this cell keeps
# growing the molecule.
fragment = input_pipeline.ase_atoms_to_jraph_graph(
    molecule, models.ATOMIC_NUMBERS, config.nn_cutoff
)
step_rng, rng = jax.random.split(rng)
pred = get_predictions(fragment, step_rng)  # model forward pass on the current molecule
molecule = append_predictions(molecule, pred)

In [71]:
# Show the molecule after the latest generation step.
molecule

Atoms(symbols='CH3NCH', pbc=False)

In [82]:
# Compute the focus distribution here so this cell also works on a fresh kernel
# (Restart & Run All). It is a softmax over each atom's focus logit plus a
# trailing 0 logit for the STOP option.
focus_probs = jax.nn.softmax(
    jnp.concatenate(
        [pred.nodes.focus_logits, jnp.zeros(1, dtype=pred.nodes.focus_logits.dtype)]
    )
)
focus_probs

Array([1.9059124e-19, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
       2.2385720e-02, 9.7761095e-01, 3.2381479e-06], dtype=float32)

In [80]:
# Focus distribution: softmax over each atom's focus logit plus a trailing 0
# logit for STOP, drawn as one bar per atom and a separately coloured STOP bar.
focus_logits = pred.nodes.focus_logits
focus_probs = jax.nn.softmax(
    jnp.concatenate([focus_logits, jnp.zeros(1, dtype=focus_logits.dtype)])
)
atom_probs = np.asarray(focus_probs[:-1])
stop_prob = float(focus_probs[-1])

# One colour per atom bar; the STOP bar sets its own colour below.
bar_colors = ['#636efa'] * len(atom_probs)
focus_fig = go.Figure([
    go.Bar(x=np.arange(len(atom_probs)), y=atom_probs,
           marker_color=bar_colors, showlegend=False),
    # STOP sits at x=len(focus_probs), one slot past the last atom index.
    go.Bar(x=[len(focus_probs)], y=[stop_prob], showlegend=False,
           text='STOP', marker_color='#ef63f9'),
])
focus_fig.update_layout(title='Predicted focus distribution')
focus_fig.update_xaxes(title='Atom index')
focus_fig.update_yaxes(title='Predicted Probability')
focus_fig.show()

In [15]:
# Sum of the sphere-integrated position probabilities, i.e. the total mass of
# the predicted position distribution.
# NOTE(review): the saved output (~0.0013) is far from 1. If position_probs is
# meant to be normalized, check its normalization / integration grid.
pred.globals.position_probs.integrate().array.sum()

Array(0.00128827, dtype=float32)

In [84]:
# Predicted probability of each target species (presumably indexed like
# models.ATOMIC_NUMBERS, which append_predictions uses to map the sampled
# species to an atomic number — confirm).
pred.globals.target_species_probs

Array([9.9994481e-01, 8.3941208e-19, 3.5305285e-13, 5.5147695e-05,
       0.0000000e+00], dtype=float32)

In [61]:
RADII = models.RADII
# Radial marginal: integrate the position distribution over the sphere, drop
# the trailing singleton axis, and renormalize so the bars sum to 1
# (presumably one value per entry of RADII — confirm).
radii_probs = pred.globals.position_probs.integrate().array.squeeze(axis=-1)
radii_probs = radii_probs / radii_probs.sum()

radius_fig = go.Figure([go.Bar(x=np.asarray(RADII), y=np.asarray(radii_probs))])
radius_fig.update_layout(title='Predicted radial distribution')
radius_fig.update_xaxes(title='Radius (A)')
radius_fig.update_yaxes(title='Probability')
radius_fig.show()

In [62]:
# Unnormalized position "probabilities" for plotting: subtract the max logit
# over the last three axes (so the largest value in each slice becomes exp(0) = 1
# and nothing overflows), then exponentiate at inverse temperature 10.
max_logit = jnp.max(
    pred.globals.position_logits.grid_values,
    axis=(-3, -2, -1),
    keepdims=True,
)
position_probs = pred.globals.position_logits.apply(
    lambda logit: jnp.exp((logit - max_logit) * 10)
)

In [85]:
# Take index 22 along the leading axis of every leaf of `position_probs`
# (presumably one radial shell — confirm against models.RADII), leaving an
# angular distribution that can be drawn as a surface.
angular_probs = jax.tree_util.tree_map(lambda leaf: leaf[22], position_probs)

In [86]:
# Render `angular_probs` as a plotly 3-D surface.
go.Figure(data=[go.Surface(angular_probs.plotly_surface())])

In [22]:
def visualize_atom_removals(
    workdir: str,
    beta: float,
    step: int,
    molecule: ase.Atoms,
    seed: int,
    return_figure: bool = False,
):
    """Generates visualizations of the predictions when removing each atom from a molecule.

    For each atom index ``target`` of ``molecule``, a fragment with that atom
    removed is built. All fragments are run through the model in a single batch.
    Each prediction is rendered with ``analysis.visualize_predictions``, and the
    per-target figures are combined into one titled figure.

    Args:
        workdir: Model work directory passed to ``analysis.load_model_at_step``.
        beta: Fourth positional argument passed to ``model.apply``.
        step: Checkpoint step passed to ``analysis.load_model_at_step``.
        molecule: Complete molecule whose atoms are removed one at a time.
        seed: Seed for the PRNG key passed to ``model.apply``.
        return_figure: If True, also return the combined figure. Defaults to
            False so callers that unpack two values keep working.

    Returns:
        ``(molecules_with_target_removed, preds_list)``, both indexed by target
        atom. With ``return_figure=True``, the combined figure is appended as a
        third element.
    """
    name = analysis.name_from_workdir(workdir)
    model, params, config = analysis.load_model_at_step(
        workdir, step, run_in_evaluation_mode=True
    )

    # Remove the target atoms from the molecule.
    molecules_with_target_removed = []
    fragments = []
    for target in range(len(molecule)):
        molecule_with_target_removed = ase.Atoms(
            positions=np.concatenate(
                [molecule.positions[:target], molecule.positions[target + 1 :]]
            ),
            numbers=np.concatenate(
                [molecule.numbers[:target], molecule.numbers[target + 1 :]]
            ),
        )
        # NOTE(review): other cells use models.ATOMIC_NUMBERS here; confirm
        # analysis.ATOMIC_NUMBERS is the same table.
        fragment = input_pipeline.ase_atoms_to_jraph_graph(
            molecule_with_target_removed,
            analysis.ATOMIC_NUMBERS,
            config.nn_cutoff,
        )

        molecules_with_target_removed.append(molecule_with_target_removed)
        fragments.append(fragment)

    print("Computing predictions...")

    # model.apply takes a PRNG key (and beta); seeding it makes runs reproducible.
    rng = jax.random.PRNGKey(seed)
    preds = jax.jit(model.apply)(params, rng, jraph.batch(fragments), beta)
    preds = jax.tree_util.tree_map(np.asarray, preds)
    preds = jraph.unbatch(preds)
    print("Predictions computed.")

    # focus_indices index into the batched node array, so each graph's indices
    # must be shifted back by the node count of all earlier graphs. Compute the
    # offsets once instead of re-summing the prefix for every target.
    node_offsets = np.cumsum([0] + [p.n_node.item() for p in preds])

    # Loop over all possible targets.
    print("Visualizing predictions...")
    figs = []
    preds_list = []
    for target in tqdm.tqdm(range(len(molecule)), desc="Targets"):
        # Remove the batch dimension from the globals...
        pred = preds[target]._replace(
            globals=jax.tree_util.tree_map(
                lambda x: np.squeeze(x, axis=0), preds[target].globals
            )
        )
        # ...and undo the batching shift of the focus indices. A Python int
        # keeps the dtype of focus_indices unchanged.
        corrected_focus_indices = pred.globals.focus_indices - int(
            node_offsets[target]
        )
        pred = pred._replace(
            globals=pred.globals._replace(focus_indices=corrected_focus_indices)
        )

        # Visualize predictions for this target.
        fig = analysis.visualize_predictions(
            pred, molecules_with_target_removed[target], molecule, target
        )

        figs.append(fig)
        preds_list.append(pred)

    # Combine all figures into one.
    fig_all = analysis.combine_visualizations(figs)

    # Add title.
    model_name = analysis.get_title_for_name(name)
    fig_all.update_layout(
        title=f"{model_name}: Predictions for {str(molecule.symbols)}",
        title_x=0.5,
    )

    if return_figure:
        return molecules_with_target_removed, preds_list, fig_all
    return molecules_with_target_removed, preds_list

In [21]:
def append_predictions(
    molecule: ase.Atoms, pred: datatypes.Predictions
) -> ase.Atoms:
    """Returns a new molecule equal to ``molecule`` plus the model's predicted atom.

    The new atom is placed at ``molecule.positions[focus] + position_vectors``
    and has atomic number ``models.ATOMIC_NUMBERS[target_species]``. The input
    molecule is not modified.

    NOTE(review): this function is defined twice in this notebook; keep the two
    definitions in sync or delete one.

    Args:
        molecule: Current molecule.
        pred: Predictions for ``molecule`` with the leading graph axis already
            removed from ``globals``.

    Returns:
        A new ``ase.Atoms`` with one more atom than ``molecule``.
    """
    focus = pred.globals.focus_indices
    pos_focus = molecule.positions[focus]
    # Convert to NumPy so the sum stays float64. JAX arrays and jnp.concatenate
    # would otherwise yield float32 under JAX's default (non-x64) mode, losing
    # precision on every generation step.
    pos_rel = np.asarray(pred.globals.position_vectors)
    new_position = pos_focus + pos_rel

    new_species = np.asarray(
        models.ATOMIC_NUMBERS[pred.globals.target_species.item()]
    )

    return ase.Atoms(
        positions=np.concatenate(
            [molecule.positions, new_position[None, :]], axis=0
        ),
        numbers=np.concatenate([molecule.numbers, new_species[None]], axis=0),
    )

In [14]:
# Load a small fixed slice of QM9 for inspection: the three rows at positions
# 53568-53570 of the default select() order. offset/limit let the database skip
# the preceding rows, instead of building ~53k row objects in Python only to
# discard them.
molecules = []
with connect('../qm9_data/qm9-all.db') as conn:
    for row in conn.select(offset=53568, limit=3):
        molecules.append(row.toatoms())

In [23]:
# Model checkpoint directory. Set the MODEL_PATH environment variable to use a
# different run; the default is the original author's local workdir.
model_path = os.environ.get(
    'MODEL_PATH',
    '/Users/songk/atomicarchitects/spherical_harmonic_net/workdirs/v6/nequip/interactions=4/l=5/channels=32',
)

# Predictions for the first loaded QM9 molecule, with each atom removed in turn.
molecules_with_target_removed, preds = visualize_atom_removals(
    workdir=model_path,
    beta=1.0,
    step=-1,
    molecule=molecules[0],
    seed=0,
)

Computing predictions...
Predictions computed.
Visualizing predictions...


Targets: 100%|██████████| 19/19 [00:02<00:00,  6.93it/s]


Atoms(symbols='C6OCH10C', pbc=False)

In [25]:
import check_valence