In [1]:
# pick up edits to imported modules (e.g. moldex) without restarting the kernel
%load_ext autoreload
%autoreload 2

In [2]:
import jax
import jax.numpy as jnp

from moldex.descriptors import dihe_matrix, batched_dihe_matrix
from moldex.mdtraj import dihe_indices_from_top

# we use mdtraj here to load the trajectory
import mdtraj as md

In [3]:
# load the trajectory; the separate topology file is what `traj.top` is
# built from, which is used below to enumerate the dihedrals
traj = md.load(
    'data/cla.nc',
    top='data/cla.prmtop',
)

In [4]:
# utility function that extracts the dihedral atom-index quadruplets (and the
# matching atom names) from an mdtraj *Topology* -- note we pass `traj.top`
indices, dihe_names = dihe_indices_from_top(traj.top)

# sanity checks: each row of `indices` holds the 4 atoms defining one dihedral,
# and `dihe_names` provides one label entry per dihedral
assert len(indices.shape) == 2 and indices.shape[1] == 4, indices.shape
assert len(dihe_names) == indices.shape[0], (len(dihe_names), indices.shape)

print(indices.shape)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


(231, 4)


In [5]:
# `dihe_names` gives human-readable atom labels for each dihedral,
# which is handy for identifying them; show the labels of the first one
print(dihe_names[0])

['CLA1@NB', 'CLA1@MG', 'CLA1@NA', 'CLA1@C1A']


In [6]:
# copy the trajectory coordinates into a JAX array
# (shape (n_frames, n_atoms, 3) -- (240, 73, 3) here)
coords = jnp.array(traj.xyz, copy=True)

print(coords.shape)

(240, 73, 3)


In [7]:
# compute the dihedral angles for a single frame
# first frame only: one angle per row of `indices`; the printed values
# lie within ±π (presumably radians)
dihe_matrix(coords[0, :, :], indices)

Array([-2.94677329e+00,  2.32801154e-01, -1.68454364e-01,  2.98628235e+00,
       -1.21233821e+00,  1.96723640e+00, -1.76939392e+00,  1.52196443e+00,
        2.27958765e-02, -3.08081508e+00,  3.13707352e+00,  1.12295039e-01,
       -1.54367149e-01,  3.07374716e+00, -1.72036007e-01,  3.00419593e+00,
       -3.09003448e+00,  6.47022799e-02, -3.34740095e-02, -3.02530098e+00,
       -1.96270645e+00,  1.19203019e+00, -1.35092747e+00,  1.90747941e+00,
        1.63732469e-02, -3.09430647e+00, -3.65233980e-02,  3.08083272e+00,
       -3.01524305e+00,  2.76115268e-01, -2.15713903e-01,  3.04269290e+00,
       -2.78411862e-02, -3.06721377e+00, -2.25434989e-01,  3.06849670e+00,
        9.06992704e-02, -2.96101522e+00, -1.49505764e-01,  3.03920770e+00,
        1.68850675e-01, -3.07630706e+00, -4.78879502e-03,  3.09037018e+00,
       -2.88469148e+00,  1.53335959e-01, -3.11616349e+00,  1.05633664e+00,
       -1.22810984e+00,  2.95238924e+00, -2.78823924e+00,  1.25945961e+00,
       -8.62994671e-01,  

In [8]:
# same descriptor for every frame at once: one row per frame
# (row 0 matches the single-frame result above)
batched_dihe_matrix(coords, indices)

Array([[-2.9467733e+00,  2.3280115e-01, -1.6845436e-01, ...,
         3.0993659e+00,  9.4922340e-01, -1.0601422e+00],
       [ 3.1320305e+00,  7.8233518e-04,  6.8332851e-02, ...,
        -1.2063485e+00,  2.8716404e+00,  9.2244256e-01],
       [-2.8778992e+00,  9.7910859e-02, -2.3373741e-01, ...,
        -8.8215351e-01,  3.0287113e+00,  1.0110229e+00],
       ...,
       [-3.0477479e+00,  2.1704023e-01, -2.9153711e-01, ...,
         8.8143826e-01, -1.1924365e+00,  3.0157919e+00],
       [-3.0620704e+00,  1.6442676e-01, -1.6026269e-01, ...,
         2.8853321e+00,  1.0130458e+00, -1.2372390e+00],
       [-2.9556723e+00,  1.1466689e-01, -8.7674893e-02, ...,
         9.9890798e-01, -1.2011826e+00,  3.1398885e+00]], dtype=float32)

In [9]:
# check that reverse-mode derivatives of the dihedrals w.r.t. the coordinates
# of one frame can be computed (jac has shape (n_dihedrals, n_atoms, 3))
try:
    jac = jax.jacrev(dihe_matrix)(coords[0], indices)
except Exception as err:
    # report the actual error: a typo, shape mismatch, etc. would also end up
    # here and must not be silently reported as a differentiability problem
    print(f'not reverse-mode differentiable: {type(err).__name__}: {err}')
else:
    print('NaN in jac:', jnp.any(jnp.isnan(jac)))
    print('jac.shape:', jac.shape)

NaN in jac: False
jac.shape: (231, 73, 3)


In [10]:
# check that forward-mode derivatives of the dihedrals w.r.t. the coordinates
# of one frame can be computed (jac has shape (n_dihedrals, n_atoms, 3))
try:
    jac = jax.jacfwd(dihe_matrix)(coords[0], indices)
except Exception as err:
    # report the actual error: a typo, shape mismatch, etc. would also end up
    # here and must not be silently reported as a differentiability problem
    print(f'not forward-mode differentiable: {type(err).__name__}: {err}')
else:
    print('NaN in jac:', jnp.any(jnp.isnan(jac)))
    print('jac.shape:', jac.shape)

NaN in jac: False
jac.shape: (231, 73, 3)
