In [None]:
from ase.atoms import Atoms
from ase.io import write
from ase.visualize import view
import e3nn_jax as e3nn
import jax
import jax.numpy as jnp
import jraph
import plotly.express as px
import plotly.graph_objects as go
import sys
import tensorflow as tf

sys.path.append('..')
sys.path.append('../analyses')
import analysis
import datatypes
import input_pipeline_tf
import models
import train

In [None]:
# Lookup tables indexed by species index 0..4: atomic_numbers[s] is the atomic
# number, elements[s] the symbol and element_radii[s] the radius of species s.
atomic_numbers = jnp.array([1, 6, 7, 8, 9])
numbers_to_symbols = {1: 'H', 6: 'C', 7: 'N', 8: 'O', 9: 'F'}
# Derived from atomic_numbers (not from dict ordering) so the two cannot drift apart.
elements = [numbers_to_symbols[int(z)] for z in atomic_numbers]

# covalent bond radii, in angstroms, in species-index order (H, C, N, O, F)
element_radii = [0.32, 0.75, 0.71, 0.63, 0.64]
assert len(element_radii) == len(elements) == atomic_numbers.shape[0]


def get_numbers(species) -> jnp.ndarray:
    """Map species indices to atomic numbers.

    Args:
      species: sequence or 1-D array of integer species indices into
        `atomic_numbers` (e.g. a Python list from `mol.nodes.species.tolist()`).

    Returns:
      Integer array of atomic numbers with the same length as `species`
      (an empty int array for empty input).

    Raises:
      IndexError: if any index lies outside [0, len(atomic_numbers)).
        JAX array indexing would otherwise clamp/wrap silently.
    """
    species_idx = jnp.asarray(species, dtype=jnp.int32)
    num_species = atomic_numbers.shape[0]
    if species_idx.size and (
        int(species_idx.min()) < 0 or int(species_idx.max()) >= num_species
    ):
        raise IndexError(
            f"species indices must be in [0, {num_species}), got {species_idx.tolist()}"
        )
    # Single vectorized gather instead of one device lookup per atom.
    return atomic_numbers[species_idx]

In [None]:
# Hide GPUs from TensorFlow (used only for the tf.data input pipeline),
# presumably so TF does not reserve accelerator memory that JAX needs.
tf.config.set_visible_devices([], "GPU")

# Directory of the trained run to inspect; edit to analyse a different run.
WORKDIR = '/home/ameyad/spherical-harmonic-net/workdirs/v5/e3schnet/interactions=4/l=5/channels=32/'

config, best_state_train, best_state_eval, metrics_for_best_state = analysis.load_from_workdir(
    WORKDIR, load_pickled_params=False
)
# Seeded from the run's config so the dataset split/order and the prediction
# RNG are reproducible.
rng = jax.random.PRNGKey(config.rng_seed)
rng, dataset_rng = jax.random.split(rng)
datasets = input_pipeline_tf.get_datasets(dataset_rng, config)

# NOTE(review): cutoff and epsilon are not referenced by any cell visible here.
cutoff = 5.0
epsilon = 1e-4

# Counter of test batches drawn so far; the fragment-loading cell increments it
# (so it is 0 after the first draw) and it is used in output filenames.
frag_num = -1

In [None]:
# Metrics returned by analysis.load_from_workdir alongside the best checkpoint.
metrics_for_best_state

In [None]:
# Draw the next batch of the test split.
# A persistent iterator is kept across runs of this cell, because calling
# `as_numpy_iterator()` afresh restarts the dataset from its beginning: unless
# the split reshuffles per iteration, every run would return the same first
# batch while frag_num kept counting up. The iterator is rebuilt whenever
# frag_num has been reset to -1 (i.e. after the setup cell is re-run).
if frag_num < 0 or 'test_iterator' not in globals():
    test_iterator = datasets["test"].as_numpy_iterator()
example_graph = next(test_iterator)
frag = datatypes.Fragments.from_graphstuple(example_graph)
frag = jax.tree_util.tree_map(jnp.asarray, frag)
frag_num += 1

# Remove padding graphs, then split the batch into one graph per molecule.
frag_unpadded = jraph.unpad_with_graphs(frag)
molecules = jraph.unbatch(frag_unpadded)

In [None]:
# Pick one molecule from the current batch to inspect.
mol_num = 2
mol = molecules[mol_num]

# The notebook treats atom 0 as the ground-truth focus of each fragment.
true_focus_index = 0

species_list = mol.nodes.species.tolist()
positions_list = mol.nodes.positions.tolist()
target_species = mol.globals.target_species.tolist()[0]

# target_positions is relative to the focus atom; shift it into the absolute
# coordinates used by the 3D viewer.
target_position = mol.globals.target_positions.tolist()[0]
target_position = [
    offset + origin
    for offset, origin in zip(target_position, positions_list[true_focus_index])
]

preds = train.get_predictions(best_state_train, mol, rng)

### Train mode (predictions from `best_state_train`)

In [None]:
# Raw train-mode predictions; the plots below read preds.nodes.* and preds.globals.*.
preds

In [None]:
# Focus distribution: softmax over the per-atom focus logits plus a fixed zero
# logit for the STOP action, which is stored as the last entry.
focus_probs = jax.nn.softmax(jnp.concatenate([preds.nodes.focus_logits, jnp.array([0])]))
num_atoms = focus_probs.shape[0] - 1

# Ground truth: STOP when the fragment is complete, otherwise the true focus atom.
true_focus = num_atoms if mol.globals.stop[0] else true_focus_index

# Highlight the ground truth in red. The STOP bar is a separate trace with its
# own colour, so its highlight must be chosen separately from the atom bars
# (one colour per atom bar).
bar_colors = ['#ef553b' if i == true_focus else '#636efa' for i in range(num_atoms)]
stop_color = '#ef553b' if true_focus == num_atoms else '#ef63f9'

focus_fig = go.Figure([
    go.Bar(x=jnp.arange(num_atoms), y=focus_probs[:-1], marker_color=bar_colors, showlegend=False),
    # STOP sits at x = num_atoms + 1, leaving a one-slot gap after the last atom.
    go.Bar(x=[num_atoms + 1], y=[focus_probs[-1]], showlegend=False, text='STOP', marker_color=stop_color),
])
focus_fig.update_xaxes(title='Atom index')
focus_fig.update_yaxes(title='Predicted Probability')
focus_fig.show()

In [None]:
# Distribution over the next atom's element; presumably one logit per entry of
# models.ATOMIC_NUMBERS (the bars pair them up one-to-one).
species_probs = jax.nn.softmax(preds.globals.target_species_logits.squeeze(axis=0))

species_symbols = [numbers_to_symbols[z] for z in models.ATOMIC_NUMBERS]
# Red bar marks the true species index of the target atom.
bar_colors = [
    '#ef553b' if idx == target_species else '#636efa'
    for idx in range(len(models.ATOMIC_NUMBERS))
]

species_fig = go.Figure([
    go.Bar(x=species_symbols, y=species_probs, marker_color=bar_colors)
])
species_fig.update_xaxes(title='Predicted Type')
species_fig.update_yaxes(title='Predicted Probability')
species_fig.show()

In [None]:
# radial distribution
# Predicted probability mass per radius (the position distribution integrated
# over each spherical shell, presumably), with the true distance from the focus
# atom drawn as a red line.

rel_target_position = mol.globals.target_positions[0].tolist()
target_radius = jnp.linalg.norm(mol.globals.target_positions[0]).item()

RADII = models.RADII
# NOTE(review): assumes integrate() yields shape (1, num_radii, 1); the squeeze
# drops the leading graph axis and the trailing scalar channel — confirm.
radial_probs = preds.globals.position_probs.integrate().array.squeeze(
    axis=(0, -1)
)

radius_fig = go.Figure([go.Bar(x=RADII, y=radial_probs)])
radius_fig.add_vline(target_radius, line_color='red')
radius_fig.update_xaxes(title='Radius (A)')
radius_fig.update_yaxes(title='Predicted Probability')
radius_fig.show()


In [None]:
# angular distributions per radius

# Resample the predicted position distribution onto a spherical grid
# (presumably res_beta=30, res_alpha=51, lmax=5) and drop the leading graph axis.
P = jax.tree_util.tree_map(lambda x: x.squeeze(0), preds.globals.position_probs.resample(30, 51, 5))
# One plotly surface dict per radial shell (vmapped over the leading axis).
prob_surfaces = jax.vmap(lambda x: x.plotly_surface())(P)

# NOTE(review): this unit-sphere surface is not used by any cell visible here.
sphere_bound = e3nn.SphericalSignal(jnp.ones((30, 51)), 'soft')
s_surface = sphere_bound.plotly_surface()

# Dark-blue -> yellow colorscale whose opacity ramps up with the value, so
# low-probability regions of each shell are transparent. rgba alpha is in [0, 1].
ANGULAR_COLORSCALE = [
    [0, "rgba(13,8,135,0)"],
    [0.14, "rgba(84,2,163,0.14)"],
    [0.29, "rgba(139,10,165,0.29)"],
    [0.43, "rgba(185,50,137,0.43)"],
    [0.57, "rgba(219,92,104,0.57)"],
    [0.71, "rgba(244,136,73,0.71)"],
    [0.86, "rgba(254,188,43,0.86)"],
    [1, "rgba(240,239,33,1)"],
]

# One surface per radius. The whole surface dict is multiplied by RADII[i],
# placing the shell at that radius; cmax scales with the radius too, presumably
# because the colour values are scaled by the same factor — confirm.
surface_list = [
    go.Surface(
        jax.tree_util.tree_map(lambda x: x[i] * RADII[i], prob_surfaces),
        opacity=1,
        colorscale=ANGULAR_COLORSCALE,
        cmin=0,
        cmax=RADII[i].item() / 2,
        showscale=False,
    )
    # Follow the actual number of radii instead of a hard-coded 64 shells.
    for i in range(len(RADII))
]


In [None]:

# all the radii
# Overlay every radial shell with the ground truth: a line from the focus atom
# (origin) to the target position, plus a small marker at the target itself.
target_ray = go.Scatter3d(
    x=[0, rel_target_position[0]],
    y=[0, rel_target_position[1]],
    z=[0, rel_target_position[2]],
    mode="lines",
)
target_marker = go.Scatter3d(
    x=[rel_target_position[0]],
    y=[rel_target_position[1]],
    z=[rel_target_position[2]],
    marker={"size": 2},
)

# Each shell keeps its own colour range (cmin/cmax set per surface).
# NOTE(review): a single global [0, 1] scale via layout `coloraxis` would also
# require every surface to reference that coloraxis.
fig = go.Figure(surface_list + [target_ray, target_marker])
fig.update_layout(template='simple_white')
fig.show()

In [None]:
# Species indices (into `elements` / `atomic_numbers`) of the selected molecule's atoms.
mol.nodes.species

In [None]:
# Stop flag of the selected fragment; when true, the ground-truth focus is the STOP action.
mol.globals.stop

In [None]:
# Norm of one slice of the predicted position coefficients (graph 0, index 37 on axis 1).
# NOTE(review): the other train-mode cells read predictions through `preds.nodes` /
# `preds.globals`; `preds.position_coeffs` may not exist on this object — confirm.
# The meaning of the hard-coded index 37 (a specific radius?) is not documented here.
e3nn.norm(preds.position_coeffs[0, 37, :])

### Evaluate mode (predictions from `best_state_eval`)

In [None]:
preds = train.get_predictions(best_state_eval, mol, rng)

# NOTE(review): the train-mode cells read the same function's output through
# `preds.nodes.*` / `preds.globals.*`; the flat attributes used here
# (focus_indices, focus_logits, target_species, position_vectors) may belong to
# an older prediction API — confirm against train.get_predictions.
focus_index = preds.focus_indices.tolist()[0]

# Softmax over per-atom focus logits plus a fixed zero logit for the STOP action.
focus_probs = jax.nn.softmax(jnp.concatenate([preds.focus_logits, jnp.array([0])]))

pred_species = preds.target_species.tolist()[0]
# The predicted position is treated as relative to the chosen focus atom;
# shift it into absolute coordinates.
pred_position = [
    offset + origin
    for offset, origin in zip(preds.position_vectors.tolist()[0], positions_list[focus_index])
]

In [None]:
# Raw eval-mode species logits for the selected molecule.
preds.target_species_logits

In [None]:
# 3D view (nglview) of the selected molecule with the eval-mode prediction overlaid:
#   * one yellow sphere per atom, opacity = predicted focus probability,
#   * a green sphere on the ground-truth focus atom,
#   * a magenta sphere at the predicted next-atom position,
#   * a cyan sphere at the true next-atom position,
#   * a yellow arrow from the predicted focus atom to the predicted position.
mol_atoms = Atoms(positions=positions_list, numbers=get_numbers(species_list))
v = view(mol_atoms, viewer='ngl')

num_nodes = mol.n_node[0].tolist()
# Convert once rather than on every loop iteration.
focus_probs_list = focus_probs.tolist()

# The `component=i + 1` offset assumes component 0 is the molecule itself and
# each add_sphere call appends one component.
for i in range(num_nodes):
    focus_prob = focus_probs_list[i]
    species = species_list[i]

    # add focus probability highlights for each atom
    v.view.shape.add_sphere(
        positions_list[i],
        [1, 0.85, 0],
        element_radii[species] * 0.6,
        f"atom {i} ({elements[species]}): focus probability {focus_prob:.3f}",
    )
    v.view.update_representation(component=i+1, opacity=focus_prob)

# add true focus highlight, labelled with the focus atom's own index
v.view.shape.add_sphere(
    positions_list[true_focus_index],
    [0, 1, 0],
    element_radii[species_list[true_focus_index]] * 0.6,
    f"atom {true_focus_index} ({elements[species_list[true_focus_index]]}): true focus (probability {focus_probs_list[true_focus_index]:.3f})",
)
v.view.update_representation(component=num_nodes+1, opacity=0.4)

# add the next atom we're adding to this molecule, predicted specie + highlight
v.view.shape.add_sphere(
    pred_position,
    [1, 0, 1],
    element_radii[pred_species] * 0.5,
    f"predicted atom: {elements[pred_species]}",
)

# add the target atom
v.view.shape.add_sphere(
    target_position,
    [0, 1, 1],
    element_radii[target_species] * 0.5,
    f"target atom: {elements[target_species]}",
)

# add an arrow from selected focus, labelled with its length
pred_focus_position = positions_list[focus_index]
pred_distance = jnp.linalg.norm(jnp.array(pred_position) - jnp.array(pred_focus_position))
v.view.shape.add_arrow(
    pred_focus_position,
    pred_position,
    [1, 0.85, 0],
    0.1,
    f'distance: {pred_distance:.3f} A'
)

v

In [None]:
# Ask the viewer for a PNG snapshot, named after the batch counter and molecule index.
v.view.download_image(f'frag{frag_num}_mol{mol_num}.png')

In [None]:
# write losses to file
# Train-mode losses for every molecule in the current batch, one file per batch.

# Label of the loaded run used in the filename; it must match the workdir loaded
# in the setup cell (interactions=4/l=5/channels=32).
run_label = 'interactions=4_l=5_channels=32'
loss_path = f'{run_label}_frag={frag_num}_loss.txt'

# Dedicated loop names keep this cell from overwriting the `mol` / `preds` that
# the visualization cells operate on. The file is opened once in 'w' mode, so
# re-running the cell rewrites it instead of appending duplicate entries.
with open(loss_path, 'w') as f:
    for i, loss_mol in enumerate(molecules):
        loss_preds = train.get_predictions(best_state_train, loss_mol, rng)
        # Indexing below assumes (total, (focus, species, position)), each with
        # a leading per-graph axis — inferred from usage, confirm in train.py.
        mol_loss = train.generation_loss(loss_preds, loss_mol, config.loss_kwargs.radius_rbf_variance)
        f.write(f'molecule {i}:\n')
        f.write(f'total loss = {mol_loss[0].tolist()[0]}\n')
        f.write(f'focus loss = {mol_loss[1][0].tolist()[0]}\n')
        f.write(f'species loss = {mol_loss[1][1].tolist()[0]}\n')
        f.write(f'position loss = {mol_loss[1][2].tolist()[0]}\n')
        f.write('\n')