In [4]:
## Imports
import numpy as np
import jax
import matplotlib
from matplotlib import animation
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import os
from sklearn import decomposition
import scipy.interpolate
import trajectory

# Render animations inline as interactive JavaScript widgets
# (same effect as matplotlib.rc("animation", html="jshtml")).
matplotlib.rcParams["animation.html"] = "jshtml"
# Style sheets applied to every figure. NOTE(review): 'science', 'ieee' and 'grid'
# are not built-in Matplotlib styles; presumably they come from the separately
# installed SciencePlots package — confirm it is listed as a dependency.
PLT_STYLE_CONTEXT = ['science', 'ieee', 'grid']


In [2]:
def visualize_trajectory(molecule: str, num_structures: int = 100, resample: bool = False):
    """Animate an MD17 trajectory projected onto its two leading principal components.

    All atom positions from every loaded frame are pooled into one (N, 3) point
    cloud and fitted with a single 2-component PCA. Every frame therefore shares
    the same 2-D projection, and the animation shows motion in a fixed frame.

    Args:
        molecule: Molecule name forwarded to ``trajectory.load_trajectory``
            (e.g. ``"benzene"``); also used in the plot title.
        num_structures: Maximum number of frames to animate.
        resample: Forwarded unchanged to ``trajectory.load_trajectory``.

    Returns:
        A ``matplotlib.animation.FuncAnimation``. Its figure has already been
        closed, so the notebook does not also render a static copy. The
        animation can still be displayed or saved.
    """
    # NOTE(review): one more structure than is animated is requested here;
    # confirm whether the extra frame is intentional.
    positions, _, nuclear_charges = trajectory.load_trajectory(molecule, num_structures + 1, resample)
    # Work with plain NumPy arrays in case the loader hands back JAX arrays.
    positions = np.asarray(positions)
    nuclear_charges = np.asarray(nuclear_charges)

    # Project down to two dimensions. positions is indexed as (frame, atom, xyz),
    # so frames x atoms are flattened into one point cloud for a shared PCA fit.
    num_atoms = positions.shape[1]
    projected_flat = decomposition.PCA(n_components=2).fit_transform(positions.reshape(-1, 3))
    projected_positions = projected_flat.reshape(-1, num_atoms, 2)
    # Never animate more frames than were actually loaded (avoids IndexError).
    num_frames = min(num_structures, projected_positions.shape[0])

    # Fix the view to the extent of the *whole* trajectory, not just frame 0, so
    # atoms never leave the viewport. A quarter of the span is added on each side,
    # matching the original 1.5x zoom-out and leaving room for the legend.
    lower = projected_positions.min(axis=(0, 1))
    upper = projected_positions.max(axis=(0, 1))
    span = upper - lower
    pad = 0.25 * np.where(span > 0, span, 1.0)  # avoid degenerate zero-width limits

    # Integer inputs index tab20's 20-entry colour table directly.
    # NOTE(review): assumes charges are small integers (< 20); larger values would
    # all map to the same colour.
    atom_colors = plt.cm.tab20(nuclear_charges)

    with plt.style.context(PLT_STYLE_CONTEXT):
        fig, ax = plt.subplots()
        scatter = ax.scatter(projected_positions[0, :, 0], projected_positions[0, :, 1], c=atom_colors)
        # np.unique yields a sorted, deterministic legend order; iterating a set
        # gives an arbitrary order.
        handles = [
            mlines.Line2D([], [], color=plt.cm.tab20(charge), marker='o', ls='',
                          label=trajectory.charge_to_element(charge))
            for charge in np.unique(nuclear_charges)
        ]
        ax.legend(handles=handles, loc="upper right")
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        ax.set_xlim(lower[0] - pad[0], upper[0] + pad[0])
        ax.set_ylim(lower[1] - pad[1], upper[1] + pad[1])
        ax.set_title(f"MD17: {molecule.capitalize()}")
        ax.grid(False)
        fig.tight_layout()

        def init():
            # Blitting needs the artists to redraw; frame 0 is already plotted.
            return scatter,

        def plot_structure(index):
            # Move the existing scatter points to the positions of frame `index`.
            scatter.set_offsets(projected_positions[index])
            return scatter,

        anim = animation.FuncAnimation(fig, plot_structure, init_func=init, frames=num_frames, interval=100, blit=True)

    # Close this specific figure rather than whatever pyplot considers "current".
    plt.close(fig)
    return anim

In [3]:
for molecule in ["benzene", "ethanol", "aspirin", "toluene"]:
    anim = visualize_trajectory(molecule, resample=True)
    # Pillow is a required Matplotlib dependency, so GIF export does not depend
    # on ffmpeg being installed.
    anim.save(f"{molecule}.gif", writer="pillow")
    # A bare `anim` expression inside a loop body is never rendered; `display`
    # (a builtin in IPython kernels) shows each animation inline.
    display(anim)

[12 12 12 12 12 12  1  1  1  1  1  1] (1, 12, 1)
[12 12 16  1  1  1  1  1  1] (1, 9, 1)
[12 12 12 12 12 12 12 16 16 16 12 12 16  1  1  1  1  1  1  1  1] (1, 21, 1)
[12 12 12 12 12 12 12  1  1  1  1  1  1  1  1] (1, 15, 1)
