In [16]:
from argparse import Namespace
import json
import random

from ase import visualize
import ase.visualize.ngl
import ipywidgets
import matplotlib.pyplot as plt
import numpy as np
import pandas
from scipy.spatial.transform import Rotation as R

import milad
from milad.play import asetools
from milad import invariants
from milad import reconstruct


import qm9_utils

In [2]:
# Reproducibility: fix the seeds of both the NumPy and the stdlib global RNGs.
np.random.seed(1234)
random.seed(1234)

# Colours used for plotting (hex strings).
cmap = ('#845ec2', '#FF9D47', '#ff9671')

# Keyword arguments intended for reconstruct.get_best_rms.
rmsd_settings = {
    'max_retries': 100,
    'threshold': 1e-3,
    'use_hungarian': False,
}

In [3]:
from schnetpack import datasets

# QM9 dataset backed by the file data/qm9.db. download=True is passed so
# schnetpack may fetch it (presumably only when the file is missing — confirm).
qm9data = datasets.QM9(
    'data/qm9.db',
    download=True,
)

In [4]:
# Reconstruction results saved by earlier experiment runs, presumably pickled
# DataFrames (they are later indexed by a 'Result' column via qm9_utils).
# NOTE(review): pandas.read_pickle can execute arbitrary code — only load
# files you produced yourself.
no_species, with_species, with_species_two_fingerprints = (
    pandas.read_pickle(pickle_path)
    for pickle_path in (
        'structure_recovery_iterative_no_species.pickle',
        'structure_recovery_iterative_with_species.pickle',
        'species_recovery_from_decoded_positions.pickle',
    )
)

In [81]:
def create_atoms_widget(atoms: ase.Atoms, label: str):
    """Build a labelled 3D view of a structure.

    Args:
        atoms: the structure to render with ASE's NGL viewer.
        label: caption text shown above the viewer.

    Returns:
        An ``ipywidgets.VBox`` containing a ``Label`` followed by the NGL view.
    """
    viewer = visualize.ngl.NGLDisplay(atoms)
    children = [ipywidgets.Label(label), viewer.view]
    return ipywidgets.VBox(children)


def create_reconstructed_atoms_widget(idx: int, dataset, label: str = ''):
    """Render the best reconstruction of molecule ``idx`` from ``dataset``.

    Takes the first entry of the 'Result' column returned by
    ``qm9_utils.get_best_reconstruction``. That entry's ``value`` is converted to
    ASE atoms, and its ``rmsd`` is appended to the caption.

    Args:
        idx: index of the molecule whose reconstruction should be shown.
        dataset: reconstruction results passed through to qm9_utils.
        label: caption prefix placed before the RMSD figure.

    Returns:
        The labelled viewer widget built by ``create_atoms_widget``.
    """
    top_result = qm9_utils.get_best_reconstruction(idx, dataset)['Result'].iloc[0]
    reconstructed_atoms = asetools.milad2ase(top_result.value)
    caption = f'{label} (RMSD {top_result.rmsd:.2e})'
    return create_atoms_widget(reconstructed_atoms, caption)

def get_visualisation(
        idx, no_species, with_species, with_species_two_fingerprints, qm9data):
    """Lay out the original QM9 molecule next to three reconstructions of it.

    Args:
        idx: index of the molecule in ``qm9data`` and in the result sets.
        no_species: reconstruction results using positions only.
        with_species: reconstruction results using positions and species.
        with_species_two_fingerprints: reconstruction results using positions
            and species from two fingerprints.
        qm9data: dataset providing ``get_atoms(idx=...)`` for the original.

    Returns:
        An ``ipywidgets.GridBox`` with two 40%-wide columns holding four panels:
        the original followed by the three reconstructions.
    """
    # The original molecule. prepare_molecule's return value is ignored, so it
    # presumably modifies orig_atoms in place.
    orig_atoms = qm9data.get_atoms(idx=idx)
    asetools.prepare_molecule(orig_atoms)
    panels = [create_atoms_widget(orig_atoms, 'Original')]

    # Use the *arguments* here. The previous version read the global
    # `with_species` because the parameter was misspelled `with_speices`.
    reconstructions = (
        (no_species, 'Positions only'),
        (with_species, 'Positions and species'),
        (with_species_two_fingerprints, 'Positions and species (two fingerprints)'),
    )
    panels.extend(
        create_reconstructed_atoms_widget(idx, results, label=label)
        for results, label in reconstructions
    )

    return ipywidgets.GridBox(
        panels,
        layout=ipywidgets.Layout(
            grid_template_columns='40% 40%',
            grid_template_rows='40% 40%'
        ),
    )

In [85]:
# Show QM9 molecule 64 alongside its reconstructions from each result set.
get_visualisation(
    64,
    no_species,
    with_species,
    with_species_two_fingerprints,
    qm9data,
)

GridBox(children=(VBox(children=(Label(value='Original'), NGLWidget())), VBox(children=(Label(value='Positions…