In [10]:
from ase import Atoms
from ase.visualize import view
import jax
import jax.numpy as jnp
from matplotlib import pyplot as plt
import numpy as np
import os
import pandas as pd
import pickle
import re
import sys
import tensorflow as tf
import tqdm
from typing import List, Sequence
sys.path.append("../")

from symphony.data import input_pipeline_tf

In [3]:
# Root of the saved QM9 fragment datasets. Defaults to the original host path;
# set SYMPHONY_FRAGMENTS_ROOT to point the notebook at a local copy instead.
FRAGMENTS_ROOT = os.environ.get(
    "SYMPHONY_FRAGMENTS_ROOT", "/radish/qm9_fragments_fixed_mad"
)

# One directory per (method, max_targets) pair, keyed "<method>_<max_targets>"
# (e.g. "nn_edm_1"); these keys are what `get_dataset(method)` expects.
# Built with os.path.join so every entry is formatted the same way.
workdirs = {
    f"{method}_{max_targets}": os.path.join(
        FRAGMENTS_ROOT, method, f"max_targets_{max_targets}"
    )
    for method in ("nn_edm", "radius")
    for max_targets in (1, 4)
}

In [11]:
# Lookup tables indexed by species index 0..4 (H, C, N, O, F).
atomic_numbers = jnp.array([1, 6, 7, 8, 9])
numbers_to_symbols = {1: 'H', 6: 'C', 7: 'N', 8: 'O', 9: 'F'}
elements = list(numbers_to_symbols.values())

# covalent bond radii, in angstroms, in the same species order as above
element_radii = [0.32, 0.75, 0.71, 0.63, 0.64]

# The tables above are indexed in parallel by species index; keep them aligned.
assert len(elements) == len(element_radii) == atomic_numbers.shape[0]
assert elements == [numbers_to_symbols[int(z)] for z in atomic_numbers]


def get_numbers(species: jnp.ndarray) -> jnp.ndarray:
    """Map species indices to atomic numbers.

    Args:
        species: Integer species indices into `atomic_numbers` (any shape;
            lists and numpy arrays are accepted).

    Returns:
        A jnp array of atomic numbers with the same shape as `species`.
        Empty input yields an empty integer array.
    """
    # One vectorised gather instead of a per-element Python loop. The explicit
    # int32 dtype keeps empty input (which would otherwise default to float32)
    # usable as an index.
    return atomic_numbers[jnp.asarray(species, dtype=jnp.int32)]

In [4]:
def get_dataset(method, seed=0, max_molecule_number=135000):
    """Build a shuffled stream of fragment graphs from the saved data for `method`.

    Reads every `fragments_*` entry under `workdirs[method]` whose molecule
    range lies within [0, max_molecule_number], loads each entry with
    `tf.data.Dataset.load`, interleaves them, shuffles with a 1000-element
    buffer, and converts every element with
    `input_pipeline_tf._convert_to_graphstuple`.

    Args:
        method: Key into `workdirs` (e.g. "nn_edm_1").
        seed: Seed for TensorFlow's global RNG and for the shuffle.
        max_molecule_number: Inclusive upper bound on the molecule range a
            file may cover to be included.

    Returns:
        A `tf.data.Dataset` of converted graphs.

    Raises:
        KeyError: If `method` is not a key of `workdirs`.
        ValueError: If the directory has no `fragments_*` entries, or none of
            them lie within the requested molecule range.
    """
    # Set the seed for reproducibility.
    tf.random.set_seed(seed)

    # Directory holding the saved fragment datasets for this method.
    workdir = workdirs[method]
    filenames = [
        os.path.join(workdir, f)
        for f in sorted(os.listdir(workdir))
        if f.startswith("fragments_")
    ]
    if not filenames:
        raise ValueError(f"No files found in {workdir}.")

    def filter_by_molecule_number(
        filenames: Sequence[str], start: int, end: int
    ) -> List[str]:
        """Keep the files whose [file_start, file_end] range lies in [start, end]."""

        def filter_file(filename: str, start: int, end: int) -> bool:
            filename = os.path.basename(filename)
            # Expects exactly three integers in the name; the last two are read
            # as the molecule range. NOTE(review): the first is ignored —
            # confirm what it encodes.
            _, file_start, file_end = [int(val) for val in re.findall(r"\d+", filename)]
            return start <= file_start and file_end <= end

        return [f for f in filenames if filter_file(f, start, end)]

    # Plain list (no trailing comma): `x = f(...),` would wrap it in a 1-tuple.
    all_files = filter_by_molecule_number(filenames, 0, max_molecule_number)
    if not all_files:
        raise ValueError(
            f"No files in {workdir} cover molecules within [0, {max_molecule_number}]."
        )

    # Inside `interleave` the path is a symbolic tensor (graph mode), so
    # `Dataset.load` needs the element_spec supplied explicitly. Read it once
    # from a file in the selection; all files are assumed to share one spec.
    element_spec = tf.data.Dataset.load(all_files[0]).element_spec
    dataset = tf.data.Dataset.from_tensor_slices(all_files)
    dataset = dataset.interleave(
        lambda x: tf.data.Dataset.load(x, element_spec=element_spec),
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True,
    )

    # Shuffle the dataset.
    dataset = dataset.shuffle(1000, seed=seed)

    # Convert to jraph.GraphsTuple.
    dataset = dataset.map(
        input_pipeline_tf._convert_to_graphstuple,
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True,
    )

    return dataset

In [78]:
# Fragment graphs from the "nn_edm_1" working directory.
dataset = get_dataset("nn_edm_1")

# Numpy iterator over the dataset; each `next()` yields one element.
dataset_iter = dataset.as_numpy_iterator()

[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:1: Invalid control characters encountered in text.
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:3: Expected identifier, got: 12778021976030521269


In [83]:
# Advance the iterator: re-running this cell draws the next fragment graph.
graph = next(dataset_iter)

In [84]:
# Unpack the per-node fields of the current fragment graph.
positions_list, species_list, focus_and_target_species_probs = (
    graph.nodes.positions,
    graph.nodes.species,
    graph.nodes.focus_and_target_species_probs,
)
# `.item()` only succeeds when the batch holds exactly one graph.
num_nodes = graph.n_node.item()

# Sum over the last axis (presumably target species) to get one focus
# probability per node.
focus_probs = focus_and_target_species_probs.sum(axis=-1)

# NOTE(review): assumes the true focus atom is stored as node 0 — confirm.
focus_position = positions_list[0]

# Targets of the first graph in the batch; `.item()` requires a single
# target species value.
target_positions = graph.globals.target_positions[0]
target_species = graph.globals.target_species.item()

In [85]:
# Length of each target vector. The visualisation adds focus_position back, so
# these appear to be focus-relative offsets, i.e. bond lengths.
# Reference values: C-C 1.54 angstroms, C-H 1.09 angstroms.
np.linalg.norm(target_positions, ord=None, axis=-1)

array([1.0862628], dtype=float32)

In [90]:
# Number of distinct (sender, receiver) pairs, i.e. unique directed edges.
len({(sender, receiver) for sender, receiver in zip(graph.senders, graph.receivers)})

86

In [87]:
# Node (atom) count for each graph in the batch.
graph.n_node

array([10], dtype=int32)

In [76]:
# Render the fragment with nglview. The focus atom gets a translucent green
# sphere and each target atom a cyan sphere.
mol_atoms = Atoms(positions=positions_list, numbers=get_numbers(species_list))
v = view(mol_atoms, viewer='ngl')

# Add the true focus highlight. Positions are passed as plain Python lists,
# matching the target spheres below; the probability is read as one scalar
# rather than converting the whole array.
v.view.shape.add_sphere(
    focus_position.tolist(),
    [0, 1, 0],
    element_radii[species_list[0]] * 0.6,
    f"{elements[species_list[0]]}: true focus (probability {float(focus_probs[0]):.3f})",
)
# Component 0 is the molecule and component 1 the focus sphere just added
# (assumes components are numbered in insertion order).
v.view.update_representation(component=1, opacity=0.4)

for target_position in target_positions:
    # Targets are offsets from the focus, so shift them back to absolute coordinates.
    v.view.shape.add_sphere(
        (target_position + focus_position).tolist(),
        [0, 1, 1],
        element_radii[target_species] * 0.5,
        f"target atom: {elements[target_species]}",
    )

# Set to True to draw every sender -> receiver edge as an arrow labelled with
# its length. This replaces the previously commented-out code.
SHOW_EDGES = False
if SHOW_EDGES:
    for s, r in zip(graph.senders, graph.receivers):
        pos_s = positions_list[s]
        pos_r = positions_list[r]
        v.view.shape.add_arrow(
            pos_s.tolist(),
            pos_r.tolist(),
            [1, 0.85, 0],
            0.1,
            f'distance: {np.linalg.norm(pos_s - pos_r):.3f} A',
        )

v

HBox(children=(NGLWidget(), VBox(children=(Dropdown(description='Show', options=('All', 'O', 'H', 'C'), value=…