In [2]:
# Load IPython's autoreload extension and switch it to mode 2, which reloads
# imported modules before each cell is executed, so edits to source files on
# disk take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
import jax
import jax.numpy as jnp

from moldex.descriptors import dihe_matrix, batched_dihe_matrix
from moldex.mdtraj import dihe_indices_from_traj

# we use mdtraj here to load the trajectory, everything else is ok
# note that pytraj is stuck at python 3.7, while JAX has dropped
# support for python 3.7, so they are not compatible anymore
import mdtraj as md

In [5]:
# Read the trajectory from data/cla.nc with mdtraj, taking the topology from
# data/cla.prmtop (passed through the `top` keyword).
traj = md.load('data/cla.nc', top='data/cla.prmtop')

In [7]:
# Utility function that extracts the list of dihedrals from an mdtraj trajectory.
# For this system the printed shape is (231, 4): one row per dihedral with four
# entries per row -- presumably the atom indices defining it (TODO confirm).
indices = dihe_indices_from_traj(traj)

# quick check of how many dihedrals were found
print(indices.shape)

(231, 4)


In [8]:
# Get the coordinates: convert the trajectory's xyz array into a JAX array.
# The printed shape is (240, 73, 3); mdtraj documents xyz as
# (n_frames, n_atoms, 3).
coords = jnp.array(traj.xyz)

# quick check of the coordinate array layout
print(coords.shape)

(240, 73, 3)


In [9]:
 # compute the dihedral matrix for a single frame: coords[0] is the first frame, and the result has one value per row of indices (values look like angles in radians -- TODO confirm)
dihe_matrix(coords[0], indices)

Array([ 3.13530302e+00,  2.04905674e-01, -2.59717137e-01, -1.95855908e-02,
        3.09812403e+00,  2.57083356e-01, -2.87540507e+00, -4.13644850e-01,
        2.80197477e+00,  2.55245596e-01, -2.87479377e+00, -4.03175890e-01,
        2.76029396e+00, -3.12025881e+00,  2.43039820e-02,  3.10976934e+00,
       -7.84820139e-01,  1.64547339e-01,  1.49905071e-01,  3.06416965e+00,
       -6.74969494e-01, -5.07807195e-01,  1.16015688e-01,  3.09477687e+00,
        3.06947589e+00, -6.14747293e-02,  9.47540477e-02, -3.03391361e+00,
       -3.10919499e+00,  3.44398357e-02,  2.30579361e-01, -2.90440392e+00,
       -3.11103195e-01,  2.84032393e+00,  3.12583852e+00, -1.43637869e-03,
        2.19805239e-04, -3.12706995e+00,  1.04725465e-01, -3.03054357e+00,
        3.10047746e+00, -3.43141928e-02,  1.81184605e-01, -2.94513774e+00,
       -3.36321175e-01,  2.83658123e+00,  3.09322715e+00,  1.82595849e-01,
       -5.11180580e-01,  1.96055174e-01, -2.78063536e-01,  3.12669611e+00,
       -1.14519723e-01,  

In [10]:
# Compute the dihedral matrix for the whole trajectory in a single call.
# The result is a 2-D float32 array; its first row matches the single-frame
# result obtained for coords[0].
batched_dihe_matrix(coords, indices)

Array([[ 3.1353030e+00,  2.0490567e-01, -2.5971714e-01, ...,
         4.9931936e-02, -3.1287675e+00,  2.6751435e+00],
       [-3.6459407e-01,  3.1013265e+00,  1.8986583e-01, ...,
         1.6139045e-04,  3.1293099e+00,  2.6755793e+00],
       [-1.8132296e-01,  3.1244881e+00,  2.3630585e-01, ...,
         2.1339286e-02,  3.1349351e+00,  2.3983095e+00],
       ...,
       [ 1.6997379e-01, -3.4190464e-01,  3.1237035e+00, ...,
         4.8143055e-02, -3.1328082e+00,  2.5498095e+00],
       [ 3.1036062e+00,  2.2845075e-01, -3.9643139e-01, ...,
         3.4486275e-02, -3.1333880e+00,  2.6467195e+00],
       [ 2.2556624e-01, -3.6422220e-01,  3.1413410e+00, ...,
         2.5410874e-02,  3.1339469e+00,  1.4440511e+00]], dtype=float32)

In [11]:
# Reverse-mode derivatives work: jax.jacrev differentiates dihe_matrix with
# respect to its first argument (the coordinates of the first frame), giving
# the Jacobian of every dihedral with respect to every atomic coordinate.
try:
    jac = jax.jacrev(dihe_matrix)(coords[0], indices)
except Exception as err:
    # Keep the notebook running, but show the actual cause: otherwise an
    # unrelated failure (e.g. stale kernel state) would be indistinguishable
    # from a genuine lack of reverse-mode support.
    print('not reverse-mode differentiable:', repr(err))
else:
    # Only reached when the Jacobian was computed; a NaN-free result means the
    # derivative is usable for this frame.
    print('NaN in jac:', jnp.any(jnp.isnan(jac)))
    print('jac.shape:', jac.shape)

NaN in jac: False
jac.shape: (231, 73, 3)


In [12]:
# Forward-mode derivatives work as well: jax.jacfwd differentiates dihe_matrix
# with respect to its first argument (the coordinates of the first frame),
# giving the Jacobian of every dihedral with respect to every atomic coordinate.
try:
    jac = jax.jacfwd(dihe_matrix)(coords[0], indices)
except Exception as err:
    # Keep the notebook running, but show the actual cause: otherwise an
    # unrelated failure (e.g. stale kernel state) would be indistinguishable
    # from a genuine lack of forward-mode support.
    print('not forward-mode differentiable:', repr(err))
else:
    # Only reached when the Jacobian was computed; a NaN-free result means the
    # derivative is usable for this frame.
    print('NaN in jac:', jnp.any(jnp.isnan(jac)))
    print('jac.shape:', jac.shape)

NaN in jac: False
jac.shape: (231, 73, 3)
