In [1]:
# enable IPython's autoreload extension so that edits to imported modules
# are picked up live, without restarting the kernel
%load_ext autoreload
%autoreload 2

In [3]:
import jax
import jax.numpy as jnp

from moldex.descriptors import angle_matrix, batched_angle_matrix
from moldex.mdtraj import angle_indices_from_traj

# we use mdtraj here to load the trajectory, but any other trajectory
# loader would be fine as well.
# Note that pytraj is stuck at Python 3.7, while JAX has dropped support
# for Python 3.7, so the two are no longer compatible.
import mdtraj as md

In [5]:
# load the trajectory: coordinates are read from 'data/cla.nc', with the
# topology supplied by 'data/cla.prmtop' (both paths are relative to the
# notebook's working directory)
traj = md.load(
    'data/cla.nc',
    top='data/cla.prmtop',
)

In [7]:
# utility function that extracts the angle indices (one atom triplet per angle) from an mdtraj trajectory
indices = angle_indices_from_traj(traj)

# sanity check: number of index rows and number of entries per row
print(indices.shape)

(151, 3)


In [8]:
# convert the trajectory coordinates (traj.xyz) into a JAX array
coords = jnp.array(traj.xyz)

# sanity check on the array shape — presumably (n_frames, n_atoms, 3); TODO confirm
print(coords.shape)

(240, 73, 3)


In [9]:
# compute the angle matrix for a single frame (the first one, coords[0])
# NOTE(review): the displayed values look like radians (all below pi) — confirm
# against the moldex documentation
angle_matrix(coords[0], indices)

Array([1.9818683, 1.967356 , 1.8203231, 1.806127 , 1.9542804, 1.9381222,
       1.9757938, 1.7614589, 1.9640675, 2.035383 , 1.8506817, 2.0349817,
       1.6063879, 1.8441117, 2.0546854, 2.0039263, 1.9081216, 1.8203055,
       1.8438263, 1.9641558, 1.9241865, 1.8520817, 1.7264885, 1.9949795,
       1.8909944, 1.9656956, 1.9616767, 1.6485453, 1.9598796, 1.8866559,
       2.0817668, 1.8696939, 2.031957 , 2.150674 , 2.0949423, 2.0463932,
       2.040798 , 1.8333824, 2.0170655, 1.9610023, 1.8450758, 1.9095137,
       1.891455 , 1.819781 , 1.883445 , 1.9576367, 1.8832703, 1.7804285,
       2.0004005, 1.8650947, 1.8121316, 1.8146857, 2.0245125, 1.9442816,
       1.9955633, 2.0816944, 2.2960792, 1.8961813, 1.9374988, 1.9938146,
       2.1371596, 2.2647326, 1.881166 , 1.8091801, 1.9053597, 1.8192413,
       1.9913548, 1.9569604, 2.5455775, 2.2123084, 1.8734305, 2.1973152,
       1.8577409, 2.2366302, 1.8758085, 2.1847842, 2.2986271, 2.3340795,
       2.104153 , 2.0130725, 2.210964 , 1.8345889, 

In [10]:
# compute the angle matrix for every frame of the trajectory in one call;
# the first row of the result should agree with angle_matrix(coords[0], indices)
batched_angle_matrix(coords, indices)

Array([[1.9818683, 1.967356 , 1.8203231, ..., 1.5850295, 2.9185662,
        1.4839582],
       [1.9811524, 1.9161453, 2.0085797, ..., 1.5680245, 2.9959164,
        1.5105282],
       [2.0956373, 1.7886938, 1.9702884, ..., 1.5768598, 2.9594774,
        1.5177646],
       ...,
       [1.7077527, 1.9410682, 2.1428626, ..., 1.5278413, 2.9514728,
        1.5618576],
       [1.7797318, 1.93668  , 1.8562516, ..., 1.5510342, 2.8746982,
        1.5047835],
       [1.9449124, 1.8768766, 2.0697618, ..., 1.600144 , 3.0523016,
        1.5409843]], dtype=float32)

In [11]:
# reverse-mode derivatives: Jacobian of the angles with respect to the
# coordinates of the first frame, computed with jax.jacrev (which
# differentiates with respect to the first argument, coords[0])
try:
    jac = jax.jacrev(angle_matrix)(coords[0], indices)
    # a NaN anywhere in the Jacobian would make the gradients unusable
    print('NaN in jac:', jnp.any(jnp.isnan(jac)))
    print('jac.shape:', jac.shape)
except Exception as err:
    # best-effort check: keep the notebook running, but report the actual
    # error so an unrelated failure (typo, shape mismatch, missing name) is
    # not mistaken for a differentiability problem
    print('not reverse-mode differentiable:', repr(err))

NaN in jac: False
jac.shape: (151, 73, 3)


In [12]:
# forward-mode derivatives: same Jacobian as the reverse-mode check, but
# computed with jax.jacfwd (differentiating with respect to the first
# argument, coords[0])
try:
    jac = jax.jacfwd(angle_matrix)(coords[0], indices)
    # a NaN anywhere in the Jacobian would make the gradients unusable
    print('NaN in jac:', jnp.any(jnp.isnan(jac)))
    print('jac.shape:', jac.shape)
except Exception as err:
    # best-effort check: keep the notebook running, but report the actual
    # error so an unrelated failure (typo, shape mismatch, missing name) is
    # not mistaken for a differentiability problem
    print('not forward-mode differentiable:', repr(err))

NaN in jac: False
jac.shape: (151, 73, 3)
