In [33]:
# Imports
from typing import *
import ase
import ase.db
import ase.io
import chex
import e3nn_jax as e3nn
import jax
import jax.numpy as jnp
import jraph
import logging
import numpy as np
import optax
import os
import plotly
import plotly.graph_objects as go
import tqdm

logging.getLogger().setLevel(logging.INFO)

import sys
sys.path.append("..")

import analyses.analysis as analysis
import analyses.generate_molecules as generate_molecules
import analyses.visualize_atom_removals as visualize_atom_removals
from symphony import datatypes
import symphony.data.input_pipeline as input_pipeline
import symphony.models as models
import symphony.train as train


In [None]:
with ase.db.connect('../qm9_data/qm9gen.db') as conn:
    for row in conn.select(id=10):
        mol = row.toatoms()

In [15]:
workdir = "/home/ameyad/spherical-harmonic-net/workdirs/qm9_bessel_embedding_attempt6_edm_splits/e3schnet_and_nequip/interactions=3/l=5/position_channels=2/channels=64"
outputdir = '../analyses/outputs/figures'
beta_species = 1.
beta_position = 1.
step = '4950000'
num_seeds = 1
num_seeds_per_chunk = 1
init_molecule = '../analyses/molecules/downloaded/CH3.xyz'  # file name is fine
max_num_atoms = 35
visualize = True

In [16]:
# Create initial molecule, if provided.
if type(init_molecule) == str:
    init_molecule, init_molecule_name = analysis.construct_molecule(init_molecule)
    logging.info(
        f"Initial molecule: {init_molecule.get_chemical_formula()} with numbers {init_molecule.numbers} and positions {init_molecule.positions}"
    )
else:
    init_molecule_name = str(init_molecule.symbols)

# Load model.
name = analysis.name_from_workdir(workdir)
model, params, config = analysis.load_model_at_step(
    workdir, step, run_in_evaluation_mode=True
)
config = config.unlock()
if "position_updater" in config:
    del config["position_updater"]
logging.info(config.to_dict())

# Prepare initial fragment.
init_fragment = input_pipeline.ase_atoms_to_jraph_graph(
    init_molecule, models.ATOMIC_NUMBERS, config.nn_cutoff
)
init_fragment = jraph.pad_with_graphs(
    init_fragment,
    n_node=(max_num_atoms + 1),
    n_edge=(max_num_atoms + 1) ** 2,
    n_graph=2,
)
init_fragment = jax.tree_map(jnp.asarray, init_fragment)

@jax.jit
def chunk_and_apply(
    params: optax.Params, rngs: chex.PRNGKey
) -> Tuple[datatypes.Fragments, datatypes.Predictions]:
    """Chunks the seeds and applies the model sequentially over all chunks."""

    def apply_on_chunk(
        rngs: chex.PRNGKey,
    ) -> Tuple[datatypes.Fragments, datatypes.Predictions]:
        """Applies the model on a single chunk."""
        apply_fn = lambda padded_fragment, rng: model.apply(
            params,
            rng,
            padded_fragment,
            beta_species,
            beta_position,
        )
        generate_for_one_seed_fn = lambda rng: generate_molecules.generate_for_one_seed(
            apply_fn,
            init_fragment,
            max_num_atoms,
            config.nn_cutoff,
            rng,
            return_intermediates=visualize,
        )
        return jax.vmap(generate_for_one_seed_fn)(rngs)

    rngs = rngs.reshape((num_seeds // num_seeds_per_chunk, num_seeds_per_chunk, -1))
    results = jax.lax.map(apply_on_chunk, rngs)
    return jax.tree_map(lambda arr: arr.reshape((-1, *arr.shape[2:])), results)

INFO:root:Initial molecule: CH3 with numbers [6 1 1 1] and positions [[ 0.        0.        0.      ]
 [ 0.629118  0.629118  0.629118]
 [-0.629118 -0.629118  0.629118]
 [ 0.629118 -0.629118 -0.629118]]
INFO:root:{'add_noise_to_positions': True, 'compute_padding_dynamically': False, 'dataset': 'qm9', 'eval_every_steps': 30000, 'focus_and_target_species_predictor': {'activation': 'softplus', 'compute_global_embedding': False, 'embedder_config': {'activation': 'shifted_softplus', 'cutoff': 5.0, 'max_ell': 2, 'model': 'E3SchNet', 'num_channels': 64, 'num_filters': 16, 'num_interactions': 3, 'num_radial_basis_functions': 8}, 'latent_size': 128, 'num_layers': 3}, 'fragment_logic': 'nn_edm', 'freeze_node_embedders': False, 'learning_rate': 0.0005, 'learning_rate_schedule': 'constant', 'learning_rate_schedule_kwargs': {'decay_steps': 50000, 'init_value': 0.0005, 'peak_value': 0.001, 'warmup_steps': 2000}, 'log_every_steps': 1000, 'loss_kwargs': {'ignore_position_loss_for_small_fragments': Fals

In [17]:
# Generate molecules for all seeds.
seed = 0
rng = jax.random.PRNGKey(0)
rngs = jnp.asarray([rng])

# Compute compilation time.
chunk_and_apply.lower(params, rngs).compile()

# Generate molecules (and intermediate steps, if visualizing).
final_padded_fragments, stops = chunk_and_apply(params, rngs)

# We already have the final padded fragment.
final_padded_fragments = jax.tree_map(
    lambda x: x[seed], final_padded_fragments
)

In [27]:
# Get the padded fragment and predictions for this seed.
preds_for_seed = jax.tree_map(lambda x: x[seed], stops)

frags = []
for step in range(max_num_atoms):
    if step == 0:
        padded_fragment = init_fragment
    else:
        padded_fragment = jax.tree_map(
            lambda x: x[step - 1], final_padded_fragments
        )
    pred = jax.tree_map(lambda x: x[step], preds_for_seed)

    # Save visualization of generation process.
    fragment = jraph.unpad_with_graphs(padded_fragment)
    pred = jraph.unpad_with_graphs(pred)
    fragment = fragment._replace(
        globals=jax.tree_map(
            lambda x: np.squeeze(x, axis=0), fragment.globals
        )
    )
    pred = pred._replace(
        globals=jax.tree_map(lambda x: np.squeeze(x, axis=0), pred.globals)
    )
    frags.append((fragment, pred))

In [30]:
fragment, pred = frags[3]

In [77]:
def spherical_harmonics_as_signals(l: int) -> Iterable[e3nn.SphericalSignal]:
    """Yields the spherical harmonics of degree l as a sequence of e3nn.SphericalSignal objects for each m such that -l <= m <= l."""
    res = (50, 49)
    for m in range(-l, l + 1):
        coeffs = e3nn.IrrepsArray(e3nn.s2_irreps(l)[-1], jnp.asarray([1. if md == m else 0. for md in range(-l, l + 1)]))
        yield e3nn.to_s2grid(coeffs, *res, quadrature="soft", p_val=1, p_arg=-1)

def plot_spherical_harmonics(lmax: int) -> go.Figure:
    """Plots the spherical harmonics of degree l on a single row of subplots with one column for each m such that -l <= m <= l."""
    fig = plotly.subplots.make_subplots(rows=lmax + 1, cols=2*lmax + 1, specs=[[{'type': 'surface'} for _ in range(2*lmax + 1)] for _ in range(lmax + 1)])
    for l in range(lmax + 1):
        for m, sig in enumerate(spherical_harmonics_as_signals(l), start=-l):
            index = lmax + m + 1
            fig.add_trace(go.Surface(sig.plotly_surface(scale_radius_by_amplitude=True), colorscale='plasma', cmax=2, cmin=-2, showscale=False, colorbar=dict(lenmode='fraction', len=0.5, thickness=20)), row=l + 1, col=index)

    fig.update_layout(margin=dict(l=10,r=10,b=10,t=10))
    camera = dict(
        up=dict(x=0, y=0, z=1),
        center=dict(x=0, y=0, z=0),
        eye=dict(x=1.25, y=1.25, z=1.25)
    )
    axis_props = dict(title="", showticklabels=False, showgrid=False, zeroline=False, showbackground=False)
    fig.update_scenes(camera=camera, xaxis=axis_props, yaxis=axis_props, zaxis=axis_props) # hide all the xticks
    return fig


In [44]:
def id_to_element(specie_id):
    element_dict = [1, 6, 7, 8, 9]
    return element_dict[specie_id]

In [46]:
init_mol = ase.Atoms(positions=fragment.nodes.positions, numbers=list(map(id_to_element, fragment.nodes.species)))

In [43]:
fragment.nodes.species

Array([1, 0, 0, 0, 1, 3, 1], dtype=int32)

In [89]:
def plot_decomposition(f_signal: e3nn.SphericalSignal, lmax: int):
    fig = plotly.subplots.make_subplots(rows=1, cols=2*lmax + 1, specs=[[{'type': 'surface'} for _ in range(2*lmax + 1)] for _ in range(1)])
    fig.add_trace(
        go.Surface(
            f_signal.plotly_surface(scale_radius_by_amplitude=True),
            cmin=-2,
            cmax=2
    ), row=1, col=1)

    num_non_zero_coeffs = 0
    f_coeffs = e3nn.from_s2grid(f_signal, e3nn.s2_irreps(lmax + 1), normalization="integral")
    for l in range(lmax + 1):
        c_l = f_coeffs.filter(e3nn.Irreps([(1, (l, (-1) ** (l)))]))
        print(c_l)
        for index, sig in enumerate(spherical_harmonics_as_signals(l)):
            if np.abs(c_l.array[index]) > 1e-2:
                try:
                    fig.add_trace(go.Surface(sig.plotly_surface(scale_radius_by_amplitude=True), colorscale='plasma', cmax=2, cmin=-2, showscale=False, colorbar=dict(lenmode='fraction', len=0.5, thickness=20)), row=1, col=num_non_zero_coeffs + 2)
                    num_non_zero_coeffs += 1
                except:
                    print(1, num_non_zero_coeffs + 2)
                    raise Exception
    fig.update_traces(showscale=False)
    fig.update_layout(margin=dict(l=10,r=10,b=10,t=10))
    camera=dict(
            up=dict(x=0, y=0, z=1),
            center=dict(x=0, y=0, z=0),
            eye=dict(x=4.25, y=-4.25, z=4.25)
    )
    axis_props = dict(title="", showticklabels=False, showgrid=False, zeroline=False, showbackground=False)
    fig.update_scenes(camera=camera, xaxis=axis_props, yaxis=axis_props, zaxis=axis_props) # hide all the xticks

    return fig


In [85]:
radius = jnp.sqrt(jnp.sum((fragment.nodes.positions[6] - frags[4][0].nodes.positions[-1])**2))
rad_index = 12

In [90]:
plot_decomposition(pred.globals.position_logits[rad_index], lmax=2)

1x0e [-19.213167]
1x1o [ 0.23197708  0.10139324 -0.25141567]
1x2e [ 0.4163386   0.24635902  0.44629362  1.0692436  -0.1214186 ]
1 7


Exception: 