In [1]:
# Pick up live edits to imported project modules (e.g. moldex) without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import jax
import jax.numpy as jnp

from moldex.descriptors import re_matrix, batched_re_matrix
from moldex.mdtraj import bond_indices_from_traj

# we use mdtraj here to load the trajectory
import mdtraj as md

In [3]:
# Read the trajectory file from data/ together with the topology file that describes its atoms.
traj = md.load(
    'data/cla.nc',
    top='data/cla.prmtop',
)

In [4]:
# Index array consumed by re_matrix / batched_re_matrix below.
# NOTE(review): presumably pairs of bonded atom indices taken from the topology — confirm in moldex.mdtraj.
indices = bond_indices_from_traj(traj)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [5]:
# Move the trajectory coordinates onto a JAX array; each frame coords[i] is (n_atoms, 3).
coords = jnp.asarray(traj.xyz)

In [6]:
# RE matrix of the second frame, measured against the first frame as reference.
frame, reference = coords[1], coords[0]
re_matrix(frame, reference, indices)

Array([1.0000091 , 0.99999774, 0.9999927 , 1.0000039 , 0.9999924 ,
       1.0000058 , 1.0000026 , 1.0000029 , 1.0000043 , 1.0000099 ,
       1.0000068 , 1.0000111 , 0.99999726, 1.000002  , 1.0000017 ,
       1.0000004 , 1.0000029 , 1.0000075 , 1.0000037 , 0.9999928 ,
       1.0000075 , 1.0000033 , 1.0000046 , 1.0000033 , 1.0000033 ,
       0.9999939 , 1.0224061 , 0.98417264, 0.99282277, 1.0199834 ,
       1.0300363 , 0.9641044 , 0.98808366, 0.9625573 , 0.99943346,
       0.99649835, 1.0334166 , 1.0041109 , 1.0177194 , 1.0078245 ,
       0.968264  , 0.9707567 , 1.011831  , 1.0014951 , 0.9391475 ,
       0.9955072 , 0.9553912 , 1.0105664 , 0.97211826, 0.9718572 ,
       0.99293023, 1.0117513 , 0.9490691 , 1.0237103 , 1.039378  ,
       0.92029387, 1.0354193 , 1.000646  , 0.999973  , 0.994055  ,
       1.0081117 , 1.0404235 , 0.9996572 , 0.95415306, 0.99744385,
       0.9907497 , 0.9957021 , 1.0351065 , 0.9358092 , 1.0027026 ,
       1.0043268 , 1.0150027 , 0.9875117 , 0.9946092 , 0.97621

In [7]:
# RE matrix of every frame against frame 0 (one row per frame; row 0 is the reference itself).
reference_frame = coords[0]
batched_re_matrix(coords, reference_frame, indices)

Array([[1.        , 1.        , 1.        , ..., 1.        , 1.        ,
        1.        ],
       [1.0000091 , 0.99999774, 0.9999927 , ..., 1.0216061 , 0.9544634 ,
        1.008654  ],
       [0.9999963 , 0.9999946 , 1.0000056 , ..., 0.9700271 , 1.001336  ,
        1.0246861 ],
       ...,
       [0.9999941 , 1.0000026 , 1.000002  , ..., 0.96465945, 0.95976794,
        1.0358901 ],
       [0.99999607, 0.9999918 , 1.0000027 , ..., 1.0136793 , 0.9676384 ,
        1.0064753 ],
       [0.99999756, 1.000001  , 1.0000057 , ..., 0.95508546, 0.9996796 ,
        1.0524362 ]], dtype=float32)

In [8]:
# Check reverse-mode differentiability of re_matrix w.r.t. its first argument
# (the frame coords[1]; jax.jacrev differentiates argnum 0 by default).
# The Jacobian has shape re_matrix-output shape + frame shape.
try:
    jac = jax.jacrev(re_matrix)(coords[1], coords[0], indices)
    print('NaN in jac:', jnp.any(jnp.isnan(jac)))
    print('jac.shape:', jac.shape)
except Exception as err:
    # Surface the cause instead of silently hiding why differentiation failed.
    print('not reverse-mode differentiable:', repr(err))

NaN in jac: False
jac.shape: (81, 73, 3)


In [9]:
# Check forward-mode differentiability of re_matrix w.r.t. its first argument
# (the frame coords[1]; jax.jacfwd differentiates argnum 0 by default).
# The Jacobian should match the reverse-mode one in shape.
try:
    jac = jax.jacfwd(re_matrix)(coords[1], coords[0], indices)
    print('NaN in jac:', jnp.any(jnp.isnan(jac)))
    print('jac.shape:', jac.shape)
except Exception as err:
    # Surface the cause instead of silently hiding why differentiation failed.
    print('not forward-mode differentiable:', repr(err))

NaN in jac: False
jac.shape: (81, 73, 3)
