In [None]:
import os
import pickle
import sys

import ase
import ase.data
import ase.io
import ase.visualize
import jax
import jax.numpy as jnp
import jraph
import ml_collections
import tqdm
import yaml
import plotly.graph_objects as go
import e3nn_jax as e3nn
import numpy as np

# Restrict this kernel to GPU 7 and make the project modules in the parent
# directory (datatypes, input_pipeline, qm9, models) importable.
# NOTE(review): JAX is already imported above; this only takes effect because
# JAX reads CUDA_VISIBLE_DEVICES lazily when its backend is first initialized.
# Consider moving this line above the imports to be safe.
os.environ.update(CUDA_VISIBLE_DEVICES="7")

sys.path.append(os.pardir)

In [None]:
import datatypes
import input_pipeline
import qm9
from models import ATOMIC_NUMBERS, RADII, create_model


In [None]:
# Which trained model to inspect. Switch MODEL_NAME between "e3schnet" and
# "mace" instead of keeping two competing `path = ...` assignments around.
WORKDIR_ROOT = "/home/ameyad/spherical-harmonic-net/workdirs/v5"
MODEL_NAME = "e3schnet"  # or "mace"
CHECKPOINT_STEP = 30000

path = os.path.join(WORKDIR_ROOT, MODEL_NAME, "interactions=4", "l=5", "channels=32")

# NOTE(review): pickle.load and yaml.unsafe_load can execute arbitrary code
# embedded in the file. Only point `path` at checkpoints you produced yourself.
with open(
    os.path.join(path, "checkpoints", f"params_{CHECKPOINT_STEP}.pkl"), "rb"
) as f:
    params = pickle.load(f)
with open(os.path.join(path, "config.yml"), "rt") as config_file:
    config = yaml.unsafe_load(config_file)

# An empty YAML file loads as None. Fail loudly; unlike `assert`, this check
# is not stripped when Python runs with -O.
if config is None:
    raise ValueError(f"Config file in {path!r} is empty.")
config = ml_collections.ConfigDict(config)

In [None]:
model = create_model(config, run_in_evaluation_mode=True)
# Compiled once per distinct input shape. `apply` pads every fragment to the
# same sizes, so repeated calls reuse the same compiled function.
apply_fn = jax.jit(model.apply)


def apply(frag, seed, beta, *, max_nodes=32, max_edges=1024, max_graphs=2):
    """Run the model on a single fragment graph and return the unpadded output.

    Args:
        frag: jraph.GraphsTuple for one (ablated) molecule.
        seed: PRNG key forwarded to `model.apply`.
        beta: Scalar forwarded to `model.apply` after the graph.
            NOTE(review): presumably a sampling temperature; confirm in `models`.
        max_nodes: Total node count after padding (passed as `n_node`).
        max_edges: Total edge count after padding (passed as `n_edge`).
        max_graphs: Total graph count after padding (passed as `n_graph`).
            The fragment must fit inside these sizes with room left for the
            padding graph, otherwise `jraph.pad_with_graphs` raises. Changing
            them changes the input shape and therefore triggers a recompile.

    Returns:
        The model prediction with the padding graphs removed.
    """
    padded = jraph.pad_with_graphs(
        frag, n_node=max_nodes, n_edge=max_edges, n_graph=max_graphs
    )
    preds = apply_fn(params, seed, padded, beta)
    return jraph.unpad_with_graphs(preds)

In [None]:
# Local directory holding the QM9 data. The loaded items are indexed and
# passed to `vizualize_prediction`, so presumably each one is an ase.Atoms
# (TODO confirm in qm9.py).
QM9_DATA_DIR = "qm9_data"
dataset = qm9.load_qm9(QM9_DATA_DIR)

In [None]:
# Marker styling for the QM9 elements, keyed by atomic number.
_ATOMIC_COLORS = {
    1: "rgb(200, 200, 200)",  # H
    6: "rgb(50, 50, 50)",  # C
    7: "rgb(0, 100, 255)",  # N
    8: "rgb(255, 0, 0)",  # O
    9: "rgb(255, 0, 255)",  # F
}
_ATOMIC_SIZES = {
    1: 10,  # H
    6: 30,  # C
    7: 30,  # N
    8: 30,  # O
    9: 30,  # F
}
# Fallbacks so molecules containing other elements still render instead of
# raising KeyError.
_DEFAULT_ATOM_COLOR = "rgb(255, 192, 203)"
_DEFAULT_ATOM_SIZE = 30


def _ablate_atom(molecule: ase.Atoms, index: int) -> ase.Atoms:
    """Return a new Atoms object with the atom at `index` removed."""
    return ase.Atoms(
        positions=np.delete(molecule.positions, index, axis=0),
        numbers=np.delete(molecule.numbers, index),
    )


def _atoms_trace(positions: np.ndarray, numbers: np.ndarray) -> go.Scatter3d:
    """One marker per atom, colored and sized by element."""
    return go.Scatter3d(
        x=positions[:, 0],
        y=positions[:, 1],
        z=positions[:, 2],
        mode="markers",
        marker=dict(
            size=[_ATOMIC_SIZES.get(n, _DEFAULT_ATOM_SIZE) for n in numbers],
            color=[_ATOMIC_COLORS.get(n, _DEFAULT_ATOM_COLOR) for n in numbers],
        ),
        hovertext=[ase.data.chemical_symbols[n] for n in numbers],
        opacity=1.0,
        showlegend=False,
    )


def _target_trace(position: np.ndarray, atomic_number: int) -> go.Scatter3d:
    """Semi-transparent yellow marker slightly larger than the removed atom."""
    return go.Scatter3d(
        x=[position[0]],
        y=[position[1]],
        z=[position[2]],
        mode="markers",
        marker=dict(
            size=1.05 * _ATOMIC_SIZES.get(atomic_number, _DEFAULT_ATOM_SIZE),
            color="yellow",
        ),
        opacity=0.5,
        name="Target",
    )


def _prediction_surfaces(
    pred, frag, res_beta: int = 50, res_alpha: int = 99, lmax: int = 6
) -> list:
    """One go.Surface per radius in RADII showing the predicted position probs.

    Each surface is centred on the position of the predicted focus node in
    `frag`. Opacity scales with probability, and all shells share one color
    scale whose maximum is the global maximum over the resampled grid.
    """
    focus = pred.globals.focus_indices[0]
    species = pred.globals.target_species.item()
    symbol = ase.data.chemical_symbols[ATOMIC_NUMBERS[species]]
    position_probs = pred.globals.position_probs.resample(res_beta, res_alpha, lmax)
    center = frag.nodes.positions[focus]
    cmax = position_probs.grid_values.max().item()

    surfaces = []
    for radius_index, radius in enumerate(RADII):
        # NOTE(review): assumes grid_values is indexed (graph, radius, beta, alpha)
        # with a single graph after unpadding; confirm against the model output.
        shell = e3nn.SphericalSignal(
            position_probs.grid_values[0, radius_index], position_probs.quadrature
        )
        surfaces.append(
            go.Surface(
                **shell.plotly_surface(radius=radius, translation=center),
                colorscale=[
                    [0, "rgba(0, 0, 0, 0.0)"],
                    [1, "rgba(0, 0, 0, 1.0)"],
                ],
                showscale=False,
                cmin=0.0,
                cmax=cmax,
                name=f"Prediction: {symbol}",
            )
        )
    return surfaces


def _scene_layout() -> go.Layout:
    """Chrome-free, transparent 3D scene viewed orthographically along -z."""
    axis = dict(
        showbackground=False,
        showticklabels=False,
        showgrid=False,
        zeroline=False,
        title="",
        nticks=3,
    )
    return go.Layout(
        width=1200,
        height=800,
        scene=dict(
            xaxis=dict(**axis),
            yaxis=dict(**axis),
            zaxis=dict(**axis),
            aspectmode="data",
            camera=dict(
                up=dict(x=0, y=1, z=0),
                center=dict(x=0, y=0, z=0),
                eye=dict(x=0, y=0, z=5),
                projection=dict(type="orthographic"),
            ),
        ),
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(0,0,0,0)",
        margin=dict(l=0, r=0, t=0, b=0),
    )


def vizualize_prediction(
    molecule: ase.Atoms, target: int, *, seed: int = 0, beta: float = 1.0
) -> go.Figure:
    """Remove one atom, run the model on the rest, and plot its prediction.

    The figure shows the full original molecule, a yellow highlight on the
    removed atom, and the model's predicted position distribution as one
    spherical shell per radius in RADII. The shells are centred on the
    predicted focus atom of the ablated fragment.

    Args:
        molecule: The molecule to ablate.
        target: Index of the atom to remove. Taken modulo len(molecule), so
            negative or out-of-range indices wrap around.
        seed: Integer seed for the PRNG key passed to the model.
        beta: Forwarded to `apply`.

    Returns:
        A plotly Figure.

    Raises:
        ValueError: If the molecule has fewer than two atoms, so that no atoms
            would remain to condition on.
    """
    if len(molecule) < 2:
        raise ValueError(
            f"Need at least 2 atoms to ablate one, got {len(molecule)}."
        )
    target = target % len(molecule)

    frag = input_pipeline.ase_atoms_to_jraph_graph(
        _ablate_atom(molecule, target),
        ATOMIC_NUMBERS,
        config.nn_cutoff,
    )
    pred = apply(frag, jax.random.PRNGKey(seed), beta)

    # Atom positions come from the full molecule. The prediction is centred
    # on the fragment's focus node, which shares the same coordinate frame.
    data = [
        _atoms_trace(molecule.positions, molecule.numbers),
        _target_trace(molecule.positions[target], molecule.numbers[target]),
        *_prediction_surfaces(pred, frag),
    ]
    return go.Figure(data=data, layout=_scene_layout())

In [None]:
# Show the prediction for one ablated atom (index 3) of the last molecule in
# the dataset. The trailing `fig` renders the interactive figure.
mol, target = dataset[-1], 3
fig = vizualize_prediction(mol, target=target)
fig

In [None]:
# Export one interactive HTML figure per ablation target of the same molecule.
OUTPUT_DIR = "v5"
# plotly's write_html does not create missing parent directories.
os.makedirs(OUTPUT_DIR, exist_ok=True)

mol = dataset[-1]
formula = mol.get_chemical_formula()  # loop-invariant

for target in tqdm.tqdm([0, 1, 2, 3, 4, 15]):
    fig = vizualize_prediction(mol, target=target)
    # vizualize_prediction wraps the index modulo len(mol). Name the file
    # after the atom actually removed so the label matches the figure and two
    # aliasing targets cannot silently overwrite or mislabel each other.
    atom_index = target % len(mol)
    fig.write_html(
        os.path.join(OUTPUT_DIR, f"v5_{config.model}_{formula}_{atom_index}.html")
    )

In [None]:
# Number of atoms in `mol` (length of its per-atom positions array), e.g. to
# check which of the requested targets above wrapped around.
len(mol.positions)