In [1]:
# Reload imported modules (e.g. the local `moldex` package) automatically before
# each cell runs, so edits to the library are picked up without restarting.
%load_ext autoreload
%autoreload 2

In [2]:
import jax
import jax.numpy as jnp
from jax import vmap, jacrev, jacfwd

from moldex.descriptors import inverse_distance_matrix

In [3]:
# Synthetic molecule: 100 atoms, one (x, y, z) row each, drawn from a standard
# normal with a fixed PRNG seed so every run sees the same geometry.
coords = jax.random.normal(
    key=jax.random.PRNGKey(2023),
    shape=(100, 3),
)

Compute the inverse-distance-matrix descriptor for the simulated molecule

In [5]:
# Evaluate the descriptor on the simulated molecule and show the resulting matrix.
descriptor = inverse_distance_matrix(coords)
descriptor

Array([[0.        , 0.869717  , 0.41963887, ..., 0.28999007, 1.0181782 ,
        0.6011578 ],
       [0.869717  , 0.        , 0.36460266, ..., 0.30790445, 0.56162256,
        0.48058155],
       [0.41963887, 0.36460266, 0.        , ..., 0.58776724, 0.37714309,
        0.45621592],
       ...,
       [0.28999007, 0.30790445, 0.58776724, ..., 0.        , 0.27161607,
        0.34341043],
       [1.0181782 , 0.56162256, 0.37714309, ..., 0.27161607, 0.        ,
        0.91935885],
       [0.6011578 , 0.48058155, 0.45621592, ..., 0.34341043, 0.91935885,
        0.        ]], dtype=float32)

In [10]:
%timeit inverse_distance_matrix(coords)

547 µs ± 3.88 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


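You can also compute the Jacobian of the transformation. Forward-mode differentiation (`jacfwd`) is faster here than reverse mode (`jacrev`). The output (100×100) is much larger than the input (100×3), so the Jacobian is "tall", which is the case where forward mode is more efficient.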

In [11]:
%timeit jacfwd(inverse_distance_matrix)(coords)

7.05 ms ± 819 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [12]:
%timeit jacrev(inverse_distance_matrix)(coords)

14.4 ms ± 2.41 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [17]:
# Forward- and reverse-mode autodiff compute the same Jacobian but accumulate
# floating-point operations in a different order, so compare them with a
# tolerance instead of requiring bit-for-bit equality.
jac_fwd = jacfwd(inverse_distance_matrix)(coords)
jac_rev = jacrev(inverse_distance_matrix)(coords)
jnp.allclose(jac_rev, jac_fwd)

Array(True, dtype=bool)

With `vmap`, the descriptor can be vectorized over a whole trajectory (a batch of frames) without writing a Python loop

In [18]:
# vmap maps the single-frame descriptor over a leading "frame" axis.
batched_inverse_distance_matrix = vmap(inverse_distance_matrix)

# Synthetic trajectory: the same frame repeated N_FRAMES times. jnp.tile builds
# the (N_FRAMES, n_atoms, 3) array in one device op instead of stacking a
# Python list of N_FRAMES separate arrays.
N_FRAMES = 500
traj = jnp.tile(coords, (N_FRAMES, 1, 1))
traj.shape

(500, 100, 3)

In [19]:
# One inverse-distance matrix per frame, stacked along the leading axis.
matrices = batched_inverse_distance_matrix(traj)
jnp.shape(matrices)

(500, 100, 100)

In [20]:
%timeit batched_inverse_distance_matrix(traj)

1.49 ms ± 65.2 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
