In [1]:
# enable IPython's autoreload extension (mode 2) so that edits to imported
# modules are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2

In [2]:
import jax
import jax.numpy as jnp

from moldex.descriptors import angle_matrix, batched_angle_matrix
from moldex.mdtraj import angle_indices_from_top

# we use mdtraj here to load the trajectory
import mdtraj as md

In [3]:
# load the trajectory from data/cla.nc, using data/cla.prmtop as its topology
traj = md.load('data/cla.nc', top='data/cla.prmtop')

In [4]:
# utility function that extracts the list of angles (index triplets) from an mdtraj topology
indices, angle_names = angle_indices_from_top(traj.top)

# sanity check: expect one row of three indices per angle
print(indices.shape)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


(151, 3)


In [5]:
# note that you can use `angle_names` to identify the angles: each entry
# lists the labels of the three atoms forming that angle (see output below),
# e.g., for the first angle
print(angle_names[0])

['CLA1@MG', 'CLA1@NA', 'CLA1@C1A']


In [6]:
# get the coordinates as a JAX array; the printed shape is presumably
# (n_frames, n_atoms, 3) — TODO confirm against the mdtraj docs
coords = jnp.array(traj.xyz)

print(coords.shape)

(240, 73, 3)


In [7]:
# compute the angle matrix for a single frame (the first one);
# the result holds one value per row of `indices`
# NOTE(review): values look like radians — confirm against the moldex docs
angle_matrix(coords[0], indices)

Array([2.22651  , 2.1156006, 2.2046494, 2.2061594, 2.134764 , 2.2044322,
       2.2986271, 2.104153 , 2.1025178, 2.244131 , 2.3340795, 1.9569604,
       1.8091801, 1.9938146, 1.9640675, 2.229071 , 2.1065707, 2.1284547,
       2.2383604, 2.121804 , 2.2687442, 2.2003345, 2.1870918, 2.197832 ,
       2.210595 , 2.1847842, 2.2366302, 1.9326735, 1.9470304, 1.8172264,
       2.0481417, 1.9335165, 1.7357104, 1.9870574, 1.9403062, 2.1132276,
       1.9576367, 1.7804285, 1.8121316, 2.0245125, 1.9442816, 1.9405506,
       1.9979782, 2.0034063, 1.595839 , 1.9158843, 1.8923274, 2.2604284,
       1.8701202, 2.0960517, 1.8727707, 2.2744286, 1.9610023, 1.9095137,
       1.891455 , 2.1914768, 2.040798 , 1.8723179, 2.3164625, 2.1357198,
       2.150674 , 2.0949423, 2.9185662, 1.5823983, 1.8908424, 1.863518 ,
       2.3170362, 1.8979928, 2.251046 , 1.8345889, 2.2364557, 1.9598796,
       2.0817668, 1.8696939, 2.0130725, 1.9949795, 1.9656956, 1.9358145,
       2.1340866, 2.210964 , 1.8203055, 1.9641558, 

In [8]:
# compute the angle matrix for the whole trajectory in one call;
# presumably one row per frame — the first row of the displayed result
# matches the single-frame values computed for coords[0]
batched_angle_matrix(coords, indices)

Array([[2.22651  , 2.1156006, 2.2046494, ..., 1.9818683, 1.967356 ,
        1.806127 ],
       [2.222229 , 2.1551707, 2.186135 , ..., 1.9811524, 1.9161453,
        1.835352 ],
       [2.2192702, 2.1346827, 2.1523967, ..., 2.0956373, 1.7886938,
        1.9214677],
       ...,
       [2.2591925, 2.1234214, 2.1567874, ..., 1.7077527, 1.9410682,
        1.9472398],
       [2.323841 , 2.1166172, 2.2478771, ..., 1.7797318, 1.93668  ,
        2.0286064],
       [2.1776063, 2.1574183, 2.2021227, ..., 1.9449124, 1.8768766,
        1.8192736]], dtype=float32)

In [9]:
# reverse-mode differentiation: Jacobian of the angles with respect to the
# coordinates of the first frame (jacrev differentiates w.r.t. argument 0)
try:
    jac = jax.jacrev(angle_matrix)(coords[0], indices)
    print('NaN in jac:', jnp.any(jnp.isnan(jac)))
    print('jac.shape:', jac.shape)
except Exception as err:
    # the broad catch is deliberate so the notebook keeps running, but the
    # underlying cause is shown so unrelated failures are not misreported
    print('not reverse-mode differentiable')
    print(f'{type(err).__name__}: {err}')

NaN in jac: False
jac.shape: (151, 73, 3)


In [10]:
# forward-mode differentiation: Jacobian of the angles with respect to the
# coordinates of the first frame (jacfwd differentiates w.r.t. argument 0)
try:
    jac = jax.jacfwd(angle_matrix)(coords[0], indices)
    print('NaN in jac:', jnp.any(jnp.isnan(jac)))
    print('jac.shape:', jac.shape)
except Exception as err:
    # the broad catch is deliberate so the notebook keeps running, but the
    # underlying cause is shown so unrelated failures are not misreported
    print('not forward-mode differentiable')
    print(f'{type(err).__name__}: {err}')

NaN in jac: False
jac.shape: (151, 73, 3)
