In [1]:
# Automatically reload imported modules (e.g. moldex) before executing each cell,
# so edits to the library are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [12]:
import jax
import jax.numpy as jnp
from jax import vmap, jacrev, jacfwd

from moldex.descriptors import inverse_distance_matrix

In [23]:
# Simulate the coordinates of a molecule: N_ATOMS random 3-D positions drawn
# from a standard normal distribution, with a fixed seed for reproducibility.
SEED = 2023
N_ATOMS = 100
coords = jax.random.normal(jax.random.PRNGKey(SEED), shape=(N_ATOMS, 3))

Compute the inverse-distance-matrix descriptor of the molecule.

In [10]:
inv_dist = inverse_distance_matrix(coords)

# Sanity checks on the descriptor: the matrix is symmetric and its diagonal
# (atom paired with itself) is zero.
assert jnp.allclose(inv_dist, inv_dist.T)
assert jnp.allclose(jnp.diag(inv_dist), 0.0)
inv_dist

Array([[0.        , 0.86971694, 0.41963887, ..., 0.2899901 , 1.0181781 ,
        0.6011578 ],
       [0.86971694, 0.        , 0.36460266, ..., 0.30790445, 0.5616225 ,
        0.48058152],
       [0.41963887, 0.36460266, 0.        , ..., 0.5877673 , 0.37714306,
        0.45621592],
       ...,
       [0.2899901 , 0.30790445, 0.5877673 , ..., 0.        , 0.27161607,
        0.34341046],
       [1.0181781 , 0.5616225 , 0.37714306, ..., 0.27161607, 0.        ,
        0.91935885],
       [0.6011578 , 0.48058152, 0.45621592, ..., 0.34341046, 0.91935885,
        0.        ]], dtype=float32)

In [11]:
%timeit inverse_distance_matrix(coords)

17.6 µs ± 342 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


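You can compute the Jacobian of the transformation. Forward mode (`jacfwd`) is faster than reverse mode (`jacrev`) here. The map goes from 100 × 3 = 300 input coordinates to 100 × 100 = 10,000 output entries, so the Jacobian is "tall". Forward mode needs one pass per input (300), whereas reverse mode needs one pass per output (10,000).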

In [13]:
%timeit jacfwd(inverse_distance_matrix)(coords)

20.1 ms ± 812 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [14]:
%timeit jacrev(inverse_distance_matrix)(coords)

859 ms ± 30.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [16]:
# Forward- and reverse-mode autodiff evaluate different operation sequences, so
# they agree only up to floating-point rounding: compare with a tolerance
# instead of exact equality.
jac_fwd = jacfwd(inverse_distance_matrix)(coords)
jac_rev = jacrev(inverse_distance_matrix)(coords)
jnp.allclose(jac_rev, jac_fwd)

Array(True, dtype=bool)

You can vectorize the descriptor over a trajectory (a stack of frames) with `vmap`.

In [30]:
# vmap maps the single-frame descriptor over the leading (frame) axis.
batched_inverse_distance_matrix = vmap(inverse_distance_matrix)

# Synthetic trajectory for benchmarking: the same frame repeated N_FRAMES times,
# built in one array op instead of stacking a Python list of copies.
N_FRAMES = 500
traj = jnp.tile(coords, (N_FRAMES, 1, 1))
traj.shape

(500, 100, 3)

In [34]:
matrices = batched_inverse_distance_matrix(traj)

# One (n_atoms, n_atoms) descriptor matrix per frame of the trajectory.
n_frames, n_atoms, _ = traj.shape
assert matrices.shape == (n_frames, n_atoms, n_atoms)
matrices.shape

(500, 100, 100)

In [35]:
%timeit batched_inverse_distance_matrix(traj)

12 ms ± 450 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
