In [1]:
# reload edited modules (e.g. moldex) automatically before each cell runs,
# so changes to the library are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2

In [21]:
import jax
import jax.numpy as jnp

from moldex.descriptors import bond_matrix, batched_bond_matrix
from moldex.mdtraj import bond_indices_from_traj

# We use mdtraj here to load the trajectory; everything else works as usual.
# Note that pytraj is stuck at Python 3.7, while JAX has dropped support for
# Python 3.7, so the two are no longer compatible (hence mdtraj, not pytraj).
import mdtraj as md

In [22]:
# trajectory file and its topology (.prmtop); paths are relative to the
# notebook's working directory, and mdtraj picks the format from the extension
TRAJ_PATH = 'data/cla.nc'
TOP_PATH = 'data/cla.prmtop'
traj = md.load(TRAJ_PATH, top=TOP_PATH)

In [23]:
# extract the bonds (pairs of atom indices) from the mdtraj trajectory with the
# helper imported above; `get_bond_indices_from_traj` is never imported and
# raises NameError on a fresh kernel
indices = bond_indices_from_traj(traj)

# sanity check: one row per bond, two atom indices per row
assert len(indices.shape) == 2 and indices.shape[1] == 2, indices.shape

print(indices.shape)

(81, 2)


In [24]:
# copy the mdtraj coordinates into a JAX array so they can be passed to the
# JAX descriptors; the shape is (frames, atoms, 3) and, per mdtraj's
# convention, the values are in nanometers
xyz_nm = traj.xyz
coords = jnp.array(xyz_nm)

print(coords.shape)

(240, 73, 3)


In [25]:
# bond descriptor for the first frame only: one entry per row of `indices`
first_frame = coords[0]
bond_matrix(first_frame, indices)

Array([0.10900015, 0.10899981, 0.10900021, 0.10900003, 0.10899995,
       0.10900068, 0.10899989, 0.1090004 , 0.10900027, 0.1090014 ,
       0.10900011, 0.10900082, 0.10899982, 0.10899989, 0.10900044,
       0.11039988, 0.1104002 , 0.1090006 , 0.10899998, 0.10900036,
       0.10900018, 0.10900051, 0.10900033, 0.10899986, 0.10899995,
       0.10900011, 0.14886034, 0.12289649, 0.13735062, 0.15188728,
       0.12393261, 0.15120994, 0.14361952, 0.14191373, 0.14137028,
       0.14781298, 0.15006383, 0.14149144, 0.13367821, 0.1587133 ,
       0.14613053, 0.14780726, 0.14026225, 0.14801192, 0.14149864,
       0.13324559, 0.13862422, 0.13501924, 0.14573999, 0.14895985,
       0.14047064, 0.15829708, 0.14376213, 0.14000341, 0.14093228,
       0.1472915 , 0.15638068, 0.15358779, 0.10744502, 0.15916637,
       0.15639038, 0.10956812, 0.15405755, 0.1349374 , 0.13963142,
       0.13848332, 0.14168221, 0.10690728, 0.13859981, 0.14117765,
       0.10654826, 0.13879672, 0.14105293, 0.11009594, 0.13812

In [26]:
# bond descriptor for every frame at once: one row per frame, and row 0
# matches the single-frame result
traj_bonds = batched_bond_matrix(coords, indices)
traj_bonds

Array([[0.10900015, 0.10899981, 0.10900021, ..., 0.21076956, 0.1990509 ,
        0.21957746],
       [0.10899916, 0.10900006, 0.109001  , ..., 0.20631197, 0.20854744,
        0.21769354],
       [0.10900056, 0.1090004 , 0.1089996 , ..., 0.21728213, 0.19878533,
        0.21428753],
       ...,
       [0.1090008 , 0.10899953, 0.10899998, ..., 0.21849117, 0.20739482,
        0.21196984],
       [0.10900059, 0.10900071, 0.10899992, ..., 0.20792529, 0.20570794,
        0.21816477],
       [0.10900042, 0.10899971, 0.10899958, ..., 0.22068137, 0.1991147 ,
        0.20863731]], dtype=float32)

In [27]:
# check that reverse-mode derivatives work: Jacobian of the bond descriptor
# with respect to the coordinates of one frame (jacrev differentiates argument
# 0 only; `indices` is an integer index array and is not differentiated)
try:
    jac = jax.jacrev(bond_matrix)(coords[0], indices)
    print('NaN in jac:', jnp.any(jnp.isnan(jac)))
    print('jac.shape:', jac.shape)
except Exception as err:
    # report why differentiation failed instead of silently hiding the error
    print(f'not reverse-mode differentiable: {type(err).__name__}: {err}')

NaN in jac: False
jac.shape: (81, 73, 3)


In [28]:
# check that forward-mode derivatives work: Jacobian of the bond descriptor
# with respect to the coordinates of one frame (jacfwd differentiates argument
# 0 only; `indices` is an integer index array and is not differentiated)
try:
    jac = jax.jacfwd(bond_matrix)(coords[0], indices)
    print('NaN in jac:', jnp.any(jnp.isnan(jac)))
    print('jac.shape:', jac.shape)
except Exception as err:
    # report why differentiation failed instead of silently hiding the error
    print(f'not forward-mode differentiable: {type(err).__name__}: {err}')

NaN in jac: False
jac.shape: (81, 73, 3)
