In [1]:
%load_ext autoreload
%autoreload 2

import os
# Restrict this process to GPU 0; set here, before the cell that imports jax.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [2]:
import jax
import jax.numpy as jnp
from jax import vmap, jacrev, jacfwd

from tqdm import tqdm

from moldex.descriptors import coulomb_matrix

In [3]:
# simulate the coordinates and nuclear charges of a molecule of 100 atoms

def make_molecule(key, n_samples=100):
    """Generate a random toy molecule.

    Args:
        key: JAX PRNG key used to draw the atomic coordinates.
        n_samples: number of atoms in the molecule.

    Returns:
        A tuple ``(coords, atnums)``. ``coords`` has shape ``(n_samples, 3)``
        and is drawn from a standard normal distribution. ``atnums`` has
        shape ``(n_samples,)`` and holds the nuclear charges: 8 for the first
        ``n_samples // 50`` atoms, 6 up to index ``n_samples // 5`` and 1 for
        all remaining atoms.
    """
    # Slice boundaries separating the three groups of nuclear charges.
    first_cut = n_samples // 50
    second_cut = n_samples // 5

    charges = jnp.zeros((n_samples,))
    charges = charges.at[:first_cut].set(8)
    charges = charges.at[first_cut:second_cut].set(6)
    charges = charges.at[second_cut:].set(1)

    positions = jax.random.normal(key, shape=(n_samples, 3))
    return positions, charges

In [4]:
# Fixed seed so the example molecule is the same on every run.
coords, atnums = make_molecule(key=jax.random.PRNGKey(2023))

Compute the descriptor

In [5]:
# Build the Coulomb-matrix descriptor of the simulated molecule; as the last
# expression of the cell, the resulting array is displayed below.
coulomb_matrix(coords, atnums)

Array([[73.51671   , 55.66189   , 20.142666  , ...,  2.3199205 ,
         8.145426  ,  4.8092623 ],
       [55.66189   , 73.51671   , 17.500927  , ...,  2.4632356 ,
         4.4929805 ,  3.8446524 ],
       [20.142666  , 17.500927  , 36.85811   , ...,  3.5266035 ,
         2.2628584 ,  2.7372956 ],
       ...,
       [ 2.3199205 ,  2.4632356 ,  3.5266035 , ...,  0.5       ,
         0.27161607,  0.34341043],
       [ 8.145426  ,  4.4929805 ,  2.2628584 , ...,  0.27161607,
         0.5       ,  0.91935885],
       [ 4.8092623 ,  3.8446524 ,  2.7372956 , ...,  0.34341043,
         0.91935885,  0.5       ]], dtype=float32)

In [6]:
%timeit coulomb_matrix(coords, atnums)

963 µs ± 42.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


You can compute the Jacobian of the transformation. In the timings below, building it with forward-mode differentiation (`jacfwd`) is faster than with reverse mode (`jacrev`).

In [7]:
%timeit jacfwd(coulomb_matrix)(coords, atnums)

7.61 ms ± 1.78 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [8]:
%timeit jacrev(coulomb_matrix)(coords, atnums)

15.2 ms ± 4.51 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [9]:
jac_rev = jacrev(coulomb_matrix)(coords, atnums)
jac_fwd = jacfwd(coulomb_matrix)(coords, atnums)

# Element-wise difference, kept available for inspection.
diff = jac_rev - jac_fwd

# Compare with a floating-point tolerance: forward and reverse mode evaluate
# the derivatives in a different order, so bit-for-bit equality is not
# guaranteed in general.
jnp.allclose(jac_rev, jac_fwd)

Array(True, dtype=bool)

You can vectorize over a trajectory

In [10]:
# Map coulomb_matrix over the leading axis of every input and stack the
# per-item results along the leading axis of the output.
batched_coulomb_matrix = vmap(coulomb_matrix, in_axes=0, out_axes=0)

In [11]:
# Draw an independent PRNG key for every frame: reusing PRNGKey(2023) for each
# frame would make all 500 geometries identical.
n_frames = 500
frame_keys = jax.random.split(jax.random.PRNGKey(2023), n_frames)

# A single vmapped call builds every frame and stacks coordinates and charges
# along a leading frame axis, instead of calling make_molecule twice per frame.
traj_coords, traj_charges = vmap(make_molecule)(frame_keys)

traj_coords.shape, traj_charges.shape

((500, 100, 3), (500, 100))

In [12]:
# Evaluate the vmapped descriptor on the whole trajectory in a single call.
cm_matrices = batched_coulomb_matrix(traj_coords, traj_charges)
# Display the shape of the batched result.
cm_matrices.shape

(500, 100, 100)

In [13]:
%timeit batched_coulomb_matrix(traj_coords, traj_charges)

1.63 ms ± 25.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


The Hessian of the transformation can only be computed in forward mode, because reverse mode suffers from the NaN propagation problem.

In [14]:
# Second derivative of the descriptor, built forward-over-forward.
func = jacfwd(jacfwd(coulomb_matrix))

# Evaluate on a one-atom slice of the molecule.
func(coords[:1], atnums[:1])

Array([[[[[[0., 0., 0.]],

          [[0., 0., 0.]],

          [[0., 0., 0.]]]]]], dtype=float32)

In [15]:
# Same second derivative, built reverse-over-reverse; on this input the
# result displayed below is all NaN.
func = jacrev(jacrev(coulomb_matrix))

# Evaluate on a one-atom slice of the molecule.
func(coords[:1], atnums[:1])

Array([[[[[[nan, nan, nan]],

          [[nan, nan, nan]],

          [[nan, nan, nan]]]]]], dtype=float32)