In [8]:
import numpy as np
import scipy
from python.ADMPForce import read_mpid_inputs, setup_ewald_parameters
from scipy.stats import special_ortho_group
from python.utils import convert_cart2harm
import mpidplugin

mScales = np.array([0.0, 0.0, 0.0, 1.0])
pScales = np.array([0.0, 0.0, 0.0, 1.0])
dScales = np.array([0.0, 0.0, 0.0, 1.0])
rc = 8 # in Angstrom
ethresh = 1e-4

pdb = 'tests/samples/waterbox_31ang.pdb'
xml = 'tests/samples/mpidwater.xml'
# get a random geometry for testing

positions, box, list_elems, params = read_mpid_inputs(pdb, xml)

scipy.random.seed(1000)
R1 = special_ortho_group.rvs(3)
R2 = special_ortho_group.rvs(3)

positions[0:3] = positions[0:3].dot(R1)
positions[3:6] = positions[3:6].dot(R2)
positions[3:] += np.array([3.0, 0.0, 0.0])


kappa, K1, K2, K3 = setup_ewald_parameters(rc, ethresh, box)
kappa = 0.328532611
print(kappa, K1, K2, K3)

lmax = 2
Q_lh = params['multipoles_lh']
axis_types = params['axis_types']
axis_indices = params['axis_indices']


0.328532611 49 49 49


In [15]:
from jax import jit, grad, value_and_grad, partial
from python.pme import gen_pme_reciprocal

pme_reciprocal_energy = value_and_grad(gen_pme_reciprocal(axis_types, axis_indices))

pme_reci = jit(pme_reciprocal_energy, static_argnums=(4,5,6,7))

In [16]:
ene0, f0 = pme_reci(positions, box,  Q_lh, kappa, lmax, K1, K2, K3)
print(ene0, f0)

%timeit ene0, f0 = pme_reci(positions, box,  Q_lh, kappa, lmax, K1, K2, K3)

169.62799040544323
1.17 s ± 178 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [17]:
from python.pme import pme_self

pme_self_eandf = jit(value_and_grad(pme_self), static_argnums=2)

eneself, fself = pme_self_eandf(Q_lh, kappa)

print(eneself, fself)
%timeit eneself, fself = pme_self_eandf(Q_lh, kappa)

-434474.77274520777 [[ 5.46670600e+02  8.77286610e+00 -0.00000000e+00 ... -0.00000000e+00
  -1.77069278e-02 -0.00000000e+00]
 [-2.73335300e+02 -0.00000000e+00 -0.00000000e+00 ... -0.00000000e+00
  -0.00000000e+00 -0.00000000e+00]
 [-2.73335300e+02 -0.00000000e+00 -0.00000000e+00 ... -0.00000000e+00
  -0.00000000e+00 -0.00000000e+00]
 ...
 [ 5.46670600e+02  8.77286610e+00 -0.00000000e+00 ... -0.00000000e+00
  -1.77069278e-02 -0.00000000e+00]
 [-2.73335300e+02 -0.00000000e+00 -0.00000000e+00 ... -0.00000000e+00
  -0.00000000e+00 -0.00000000e+00]
 [-2.73335300e+02 -0.00000000e+00 -0.00000000e+00 ... -0.00000000e+00
  -0.00000000e+00 -0.00000000e+00]]
63.5 µs ± 4.79 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [None]:
N_a = 2988
findiff = np.empty(N_a, 3))
delta = np.zeros((N_a,3))
for i in range(N_a):
    for j in range(3):
        delta[i][j] = 1
        findiff[i][j] = pme_reci(positions +delta, box,  Q_lh, kappa, lmax, K1, K2, K3)
        delta[i][j] = 0
print(findiff)

In [5]:
print("Testing behaviour of the same function, under different jax wrappers")
print("====================================================================")
ene1, f = value_and_grad(pme_reciprocal_energy)(positions, box,  Q_lh, kappa, lmax, K1, K2, K3)
print("value_and_grad:      ", ene1)
ene2 = pme_reciprocal_energy(positions, box,  Q_lh, kappa, lmax, K1, K2, K3)
print("only value:          ", ene2)
ene3, f = jit(value_and_grad(pme_reciprocal_energy), static_argnums=(4,5,6,7))(positions, box,  Q_lh, kappa, lmax, K1, K2, K3)
print("jit, value_and_grad: ", ene3)
ene4 = jit(pme_reciprocal_energy, static_argnums=(4,5,6,7))(positions, box,  Q_lh, kappa, lmax, K1, K2, K3)
print("jit, only value:     ", ene4)

Testing behaviour of the same function, under different jax wrappers
value_and_grad:       3.8504010162466096
only value:           3.850401016246607
jit, value_and_grad:  3.8504010162464084
jit, only value:      3.8504010162402884


In [7]:
import jax.numpy as jnp

def f(x):
    tmp = []
    for i in range(x[3].astype(int)):
        xx = x[i]
        tmp.append(xx**2)
    return jnp.sum(jnp.array(tmp))

x = jnp.array([1.,2.,3.,4.])

jitf = jit(f)

print(jitf(x))

TracerIntegerConversionError: The __index__() method was called on the JAX Tracer object Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=0/1)>
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError