In [1]:
## Imports
import numpy as np
import jax
import matplotlib
from matplotlib import animation
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import os
from sklearn import decomposition
import scipy.interpolate
import sys
sys.path.append("..")

from action_angle_networks.simulation import md17_simulation
matplotlib.rc("animation", html="jshtml")
PLT_STYLE_CONTEXT = ['science', 'ieee', 'grid']


In [2]:
%load_ext autoreload

In [3]:
# @title Source Imports
%autoreload 2
from action_angle_networks.simulation import md17_simulation

In [4]:
def visualize_trajectory(molecule: str, num_structures: int = 100, resample: bool = False):
    """Visualizes a molecular dynamics trajectory."""

    # Load trajectory
    positions, _, nuclear_charges = md17_simulation.load_trajectory(molecule, 10000000, resample)
    # Project down to two dimensions.
    num_atoms = positions.shape[1]
    positions_flat = positions.reshape(-1, 3)
    projected_positions_flat = decomposition.PCA(n_components=2).fit_transform(positions_flat)
    projected_positions = np.reshape(projected_positions_flat, (-1, num_atoms, 2))

    # Plot.
    with plt.style.context(PLT_STYLE_CONTEXT):
        fig, ax = plt.subplots()
        scatter = ax.scatter(projected_positions[0, :, 0], projected_positions[0, :, 1], c=plt.cm.tab20(nuclear_charges))
        handles = []
        for nuclear_charge in set(nuclear_charges):
            element = md17_simulation.charge_to_element(nuclear_charge)
            handle = mlines.Line2D([], [], color=plt.cm.tab20(nuclear_charge), marker='o', ls='', label=element)
            handles.append(handle)
        ax.legend(handles=handles, loc="upper right")
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        ax.set_xlim(1.5 * v for v in ax.get_xlim())
        ax.set_ylim(1.5 * v for v in ax.get_ylim())
        ax.set_title(f"MD17: {molecule.capitalize()}")
        ax.grid(False)
        fig.tight_layout()

        def init():
            return scatter,
            
        def plot_structure(index):
            scatter.set_offsets(projected_positions[index, :])
            return scatter,

        anim = animation.FuncAnimation(fig, plot_structure, init_func=init, frames=num_structures, interval=100, blit=True)

    plt.close()
    return anim

In [8]:
for molecule in ["benzene", "ethanol", "aspirin", "toluene"]:
    md17_simulation.load_trajectory(molecule, 100000000, resample=False)

indices [128702 469127 576375 ... 455811 214108 178805] 100000
indices [390138 533956 496754 ... 133353 171562 425308] 100000
indices [161596 194785  68409 ... 169192 210850  19313] 100000
indices [320642 145570 331765 ... 290689 160278  39010] 100000


In [5]:
for molecule in ["benzene", "ethanol", "aspirin", "toluene"]:
    anim = visualize_trajectory(molecule, resample=True)
    # anim.save(f"../notebook_outputs/{molecule}.gif")
    # anim

indices 100000


TypeError: cannot unpack non-iterable NoneType object