In [1]:
import os
import pickle
import sys

import ase
import ase.data
import ase.io
import ase.visualize
import jax
import jax.numpy as jnp
import jraph
import ml_collections
import tqdm
import yaml
import plotly.graph_objects as go
import e3nn_jax as e3nn
import numpy as np
import matplotlib.pyplot as plt

# Pin this notebook to one physical GPU. This runs after `import jax`; it
# presumably still takes effect because JAX initializes its backends lazily
# (on first device use) — TODO confirm, or move it above the jax import.
os.environ.update({"CUDA_VISIBLE_DEVICES": "7"})
# Let the project's top-level modules (datatypes, input_pipeline, qm9, models)
# be imported from this notebook's subdirectory.
sys.path.append("..")


In [2]:
import datatypes
import input_pipeline
import qm9
from models import ATOMIC_NUMBERS, RADII, create_model

In [3]:
# Training run to inspect. The MACE run with the same hyperparameters lives next
# to this one: set MODEL_NAME = "mace" to load it instead.
WORKDIRS_ROOT = "/home/ameyad/spherical-harmonic-net/workdirs/v5"
MODEL_NAME = "e3schnet"
CHECKPOINT_STEP = 220_000
path = os.path.join(WORKDIRS_ROOT, MODEL_NAME, "interactions=4", "l=5", "channels=32")

# NOTE(review): pickle.load and yaml.unsafe_load can execute arbitrary code while
# loading; only point these at checkpoints/configs you produced yourself.
with open(os.path.join(path, "checkpoints", f"params_{CHECKPOINT_STEP}.pkl"), "rb") as f:
    params = pickle.load(f)
with open(os.path.join(path, "config.yml"), "rt") as config_file:
    config = yaml.unsafe_load(config_file)

assert config is not None, f"empty config file under {path}"
config = ml_collections.ConfigDict(config)


In [4]:
model = create_model(config, run_in_evaluation_mode=True)
# Inputs are padded to fixed sizes in `apply` below, so the jitted function sees
# the same shapes on every call and is not retraced per molecule.
apply_fn = jax.jit(model.apply)


def apply(frag, seed, beta, n_node: int = 32, n_edge: int = 1024, n_graph: int = 2):
    """Run the model on one fragment and strip the padding from its predictions.

    Args:
        frag: input graph for a single fragment (assumed unpadded — TODO confirm).
        seed: JAX PRNG key forwarded to the model's apply function.
        beta: forwarded unchanged to the model (presumably a sampling inverse
            temperature — TODO confirm against `create_model`).
        n_node, n_edge, n_graph: padded sizes handed to `jraph.pad_with_graphs`.
            They must leave room for padding or jraph raises; changing them
            changes input shapes and therefore triggers a recompilation.

    Returns:
        The model predictions with the padding graphs removed.
    """
    frags = jraph.pad_with_graphs(frag, n_node, n_edge, n_graph)
    preds = apply_fn(params, seed, frags, beta)
    return jraph.unpad_with_graphs(preds)


In [5]:
# Local directory holding the QM9 data; the result is indexed below as a
# sequence of molecules (ase.Atoms).
QM9_DATA_DIR = "qm9_data"
dataset = qm9.load_qm9(QM9_DATA_DIR)


In [34]:
# Marker colors (RGB) and sizes for the elements in QM9: H, C, N, O, F.
# Any other element raises KeyError when plotted.
_ATOM_RGB = {
    1: (150, 150, 150),  # H
    6: (50, 50, 50),  # C
    7: (0, 100, 255),  # N
    8: (255, 0, 0),  # O
    9: (255, 0, 255),  # F
}
_ATOM_MARKER_SIZE = {
    1: 10,  # H
    6: 30,  # C
    7: 30,  # N
    8: 30,  # O
    9: 30,  # F
}


def _rgb_components(atomic_number: int) -> str:
    """Comma-separated "r, g, b" string for an element's color."""
    return ", ".join(str(c) for c in _ATOM_RGB[atomic_number])


def _ablate_atom(molecule: ase.Atoms, index: int) -> ase.Atoms:
    """New Atoms (positions and atomic numbers only) with atom `index` removed."""
    return ase.Atoms(
        positions=np.delete(molecule.positions, index, axis=0),
        numbers=np.delete(molecule.numbers, index),
    )


def _molecule_trace(positions, numbers) -> go.Scatter3d:
    """Opaque markers for every atom of the full molecule, styled by element."""
    return go.Scatter3d(
        x=positions[:, 0],
        y=positions[:, 1],
        z=positions[:, 2],
        mode="markers",
        marker=dict(
            size=[_ATOM_MARKER_SIZE[n] for n in numbers],
            color=[f"rgb({_rgb_components(n)})" for n in numbers],
        ),
        hovertext=[ase.data.chemical_symbols[n] for n in numbers],
        opacity=1.0,
        showlegend=False,
    )


def _target_focus_trace(target_pos, target_number, focus_pos, focus_number) -> go.Scatter3d:
    """Enlarged translucent highlights: removed atom in yellow, focus atom in green."""
    return go.Scatter3d(
        x=[target_pos[0], focus_pos[0]],
        y=[target_pos[1], focus_pos[1]],
        z=[target_pos[2], focus_pos[2]],
        mode="markers",
        marker=dict(
            size=[
                1.2 * _ATOM_MARKER_SIZE[target_number],
                1.2 * _ATOM_MARKER_SIZE[focus_number],
            ],
            color=["yellow", "green"],
        ),
        opacity=0.1,
        name="Target",
    )


def _prediction_surfaces(position_logits, focus_pos, target_number, min_relative_prob):
    """One shell per radius in RADII, centered on the focus atom.

    Each shell is shaded by exp(logit - max logit), i.e. the probability relative
    to the most likely grid point. Shells whose peak falls below
    `min_relative_prob` times the global peak are skipped.
    """
    resampled = position_logits.resample(50, 99, 6)
    max_logit = resampled.grid_values.max()
    probs = resampled.apply(lambda x: jnp.exp(x - max_logit))
    cmax = probs.grid_values.max().item()

    components = _rgb_components(target_number)
    colorscale = [
        [0, f"rgba({components}, 0.0)"],
        [1, f"rgba({components}, 1.0)"],
    ]
    name = f"Prediction: {ase.data.chemical_symbols[target_number]}"

    surfaces = []
    for i, radius in enumerate(RADII):
        shell = probs[0, i]  # first (only) graph, i-th radius
        if shell.grid_values.max() < cmax * min_relative_prob:
            continue
        surfaces.append(
            go.Surface(
                **shell.plotly_surface(radius=radius, translation=focus_pos),
                colorscale=colorscale,
                showscale=False,
                cmin=0,
                cmax=cmax,
                name=name,
            )
        )
    return surfaces


def _scene_layout() -> go.Layout:
    """Bare orthographic 3D scene (no axes, grid or background) looking down z."""
    axis = dict(
        showbackground=False,
        showticklabels=False,
        showgrid=False,
        zeroline=False,
        title="",
        nticks=3,
    )
    return go.Layout(
        width=1200,
        height=800,
        scene=dict(
            xaxis=dict(**axis),
            yaxis=dict(**axis),
            zaxis=dict(**axis),
            aspectmode="data",
            camera=dict(
                up=dict(x=0, y=1, z=0),
                center=dict(x=0, y=0, z=0),
                eye=dict(x=0, y=0, z=5),
                projection=dict(type="orthographic"),
            ),
        ),
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(0,0,0,0)",
        margin=dict(l=0, r=0, t=0, b=0),
    )


def vizualize_prediction(
    molecule: ase.Atoms,
    target: int,
    seed: int = 0,
    beta: float = 1.0,
    min_relative_prob: float = 0.01,
) -> tuple:
    """Hide one atom of `molecule`, run the model on the rest, and plot the prediction.

    The atom at index `target` is removed. The index is taken modulo the number
    of atoms, so negative indices count from the end. The remaining atoms are
    converted to a graph using `config.nn_cutoff` and passed through the model.

    The figure shows:
      * every atom of the original molecule, colored and sized by element;
      * the removed atom (yellow) and the model's chosen focus atom (green);
      * the predicted position distribution as shells around the focus atom,
        colored by the predicted element.

    Args:
        molecule: the full molecule.
        target: index of the atom to remove.
        seed: integer seed for the PRNG key passed to the model.
        beta: forwarded unchanged to `apply`.
        min_relative_prob: shells whose peak relative probability is below this
            value are not drawn.

    Returns:
        A 3-tuple `(figure, frag, pred)`:
          * the plotly figure;
          * the ablated input graph;
          * the unpadded model predictions.

    Raises:
        ValueError: if `molecule` has no atoms.
    """
    if len(molecule) == 0:
        raise ValueError("molecule has no atoms to remove")
    target = target % len(molecule)

    frag = input_pipeline.ase_atoms_to_jraph_graph(
        _ablate_atom(molecule, target),
        ATOMIC_NUMBERS,
        config.nn_cutoff,
    )
    pred = apply(frag, jax.random.PRNGKey(seed), beta)

    # The focus index refers to nodes of the ablated fragment, not of `molecule`.
    focus = pred.globals.focus_indices[0]
    focus_pos = frag.nodes.positions[focus]
    focus_number = ATOMIC_NUMBERS[frag.nodes.species[focus].item()]
    target_number = ATOMIC_NUMBERS[pred.globals.target_species.item()]

    data = [
        _molecule_trace(molecule.positions, molecule.numbers),
        _target_focus_trace(
            molecule.positions[target], molecule.numbers[target], focus_pos, focus_number
        ),
        *_prediction_surfaces(
            pred.globals.position_logits, focus_pos, target_number, min_relative_prob
        ),
    ]
    return go.Figure(data=data, layout=_scene_layout()), frag, pred


# Correctly spelled alias; the original name is kept for existing callers.
visualize_prediction = vizualize_prediction

In [32]:
# Remove atom 3 from the last QM9 molecule and show where the model places it.
target = 3
mol = dataset[-1]
fig, frag, pred = vizualize_prediction(mol, target=target)
fig


plot sphere 34
plot sphere 35
plot sphere 36
plot sphere 37
plot sphere 38
plot sphere 39
plot sphere 40
plot sphere 41
plot sphere 42


In [35]:
mol = dataset[-1]

# Output directory is tagged with the checkpoint step loaded above
# (params_220000). fig.write_html does not create missing directories.
out_dir = "v5_220k"
os.makedirs(out_dir, exist_ok=True)
formula = mol.get_chemical_formula()

for target in tqdm.tqdm([0, 1, 2, 3, 4, 15]):
    fig, _, _ = vizualize_prediction(mol, target=target)
    fig.write_html(os.path.join(out_dir, f"v5_{config.model}_{formula}_{target}.html"))

100%|██████████| 6/6 [00:03<00:00,  1.87it/s]


In [None]:
# Number of atoms in the molecule visualized above.
mol.positions.shape[0]
