In [1]:
# Enable IPython's autoreload extension. Mode 2 reloads imported modules
# before each cell is executed, so edits to a module's source are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import jax
import jax.numpy as jnp

from moldex.descriptors import dihe_matrix, batched_dihe_matrix
from moldex.mdtraj import dihe_indices_from_top

# we use mdtraj here to load the trajectory
import mdtraj as md

In [3]:
# Load the trajectory from data/cla.nc, taking the topology from
# data/cla.prmtop (both paths are relative to the working directory).
traj = md.load('data/cla.nc', top='data/cla.prmtop')

In [5]:
# Utility function that extracts the list of dihedrals from an mdtraj topology.
# Presumably each row of `indices` holds the four atom indices defining one
# dihedral, which would explain the (n_dihedrals, 4) shape printed below.
indices = dihe_indices_from_top(traj.top)

print(indices.shape)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


(231, 4)


In [6]:
# Get the coordinates of every frame as a JAX array.
# The printed shape is (n_frames, n_atoms, 3).
coords = jnp.array(traj.xyz)

print(coords.shape)

(240, 73, 3)


In [7]:
# compute the dihedral matrix for a single frame
# coords[0] is the first frame of the trajectory; the result is a 1-D array
# (values look like angles in radians, within [-pi, pi] — TODO confirm).
dihe_matrix(coords[0], indices)

Array([-3.1000271e+00,  4.9931936e-02, -3.5542760e-02,  3.1088681e+00,
       -5.1245540e-01,  2.6751435e+00, -2.3126724e+00,  1.3517982e+00,
        4.8054699e-03, -3.1287675e+00,  3.1406932e+00,  2.2443105e-02,
       -2.0993857e-02,  3.1324239e+00, -2.4256693e-02,  3.1222885e+00,
       -3.1308036e+00,  1.3546157e-02, -7.3527605e-03, -3.1159475e+00,
       -2.6732388e+00,  4.8372468e-01, -7.2764283e-01,  2.6245043e+00,
        2.2925136e-03, -3.1349676e+00, -5.1495638e-03,  3.1330192e+00,
       -3.1137078e+00,  6.2137842e-02, -4.3589056e-02,  3.1218448e+00,
       -3.7106520e-03, -3.1316640e+00, -3.1780437e-02,  3.1314421e+00,
        1.2867751e-02, -3.1157668e+00, -2.0133192e-02,  3.1278589e+00,
        2.3542924e-02, -3.1325624e+00, -6.5250153e-04,  3.1346073e+00,
       -3.1053214e+00,  2.1344382e-02, -3.1376095e+00,  2.7029073e-01,
       -4.1364488e-01,  3.1157587e+00, -3.0848322e+00,  4.4649044e-01,
       -1.7812593e-01,  3.1309438e+00, -1.7040938e-02,  1.5709065e-02,
      

In [8]:
# Compute the dihedral matrix for all frames of the trajectory in one call.
# The result is 2-D and its first row starts with the same values as the
# single-frame result for coords[0]; presumably shaped
# (n_frames, n_dihedrals) — TODO confirm.
batched_dihe_matrix(coords, indices)

Array([[-3.1000271e+00,  4.9931936e-02, -3.5542760e-02, ...,
         3.1353030e+00,  2.0490567e-01, -2.5971714e-01],
       [ 3.1396198e+00,  1.6139045e-04,  1.4311945e-02, ...,
        -3.6459407e-01,  3.1013265e+00,  1.8986583e-01],
       [-3.0829978e+00,  2.1339286e-02, -5.0531890e-02, ...,
        -1.8132296e-01,  3.1244881e+00,  2.3630585e-01],
       ...,
       [-3.1210308e+00,  4.8143055e-02, -6.3534386e-02, ...,
         1.6997379e-01, -3.4190464e-01,  3.1237035e+00],
       [-3.1250246e+00,  3.4486275e-02, -3.4065552e-02, ...,
         3.1036062e+00,  2.2845075e-01, -3.9643139e-01],
       [-3.1001079e+00,  2.5410874e-02, -1.8174116e-02, ...,
         2.2556624e-01, -3.6422220e-01,  3.1413410e+00]], dtype=float32)

In [9]:
# Reverse-mode derivatives work: build the Jacobian of the dihedrals with
# respect to the coordinates of a single frame (jax.jacrev differentiates
# with respect to the first positional argument by default).
try:
    # Only the differentiation itself is guarded, so an unrelated failure in
    # the diagnostics below is not misreported as a differentiability problem.
    jac = jax.jacrev(dihe_matrix)(coords[0], indices)
except Exception as err:
    print('not reverse-mode differentiable')
    # Show the underlying cause instead of silently discarding it.
    print(f'{type(err).__name__}: {err}')
else:
    # A finite Jacobian of shape (n_dihedrals, n_atoms, 3) is the expected outcome.
    print('NaN in jac:', jnp.any(jnp.isnan(jac)))
    print('jac.shape:', jac.shape)

NaN in jac: False
jac.shape: (231, 73, 3)


In [10]:
# Forward-mode derivatives work: build the Jacobian of the dihedrals with
# respect to the coordinates of a single frame (jax.jacfwd differentiates
# with respect to the first positional argument by default).
try:
    # Only the differentiation itself is guarded, so an unrelated failure in
    # the diagnostics below is not misreported as a differentiability problem.
    jac = jax.jacfwd(dihe_matrix)(coords[0], indices)
except Exception as err:
    print('not forward-mode differentiable')
    # Show the underlying cause instead of silently discarding it.
    print(f'{type(err).__name__}: {err}')
else:
    # A finite Jacobian of shape (n_dihedrals, n_atoms, 3) is the expected outcome.
    print('NaN in jac:', jnp.any(jnp.isnan(jac)))
    print('jac.shape:', jac.shape)

NaN in jac: False
jac.shape: (231, 73, 3)
