In [9]:
import numpy as np
import scipy
from python.ADMPForce import read_mpid_inputs, setup_ewald_parameters
from scipy.stats import special_ortho_group
from python.utils import convert_cart2harm
import mpidplugin

mScales = np.array([0.0, 0.0, 0.0, 1.0])
pScales = np.array([0.0, 0.0, 0.0, 1.0])
dScales = np.array([0.0, 0.0, 0.0, 1.0])
rc = 8 # in Angstrom
ethresh = 1e-4

pdb = 'tests/samples/waterdimer_aligned.pdb'
xml = 'tests/samples/mpidwater.xml'
# get a random geometry for testing

positions, box, list_elems, params = read_mpid_inputs(pdb, xml)

scipy.random.seed(1000)
R1 = special_ortho_group.rvs(3)
R2 = special_ortho_group.rvs(3)

positions[0:3] = positions[0:3].dot(R1)
positions[3:6] = positions[3:6].dot(R2)
positions[3:] += np.array([3.0, 0.0, 0.0])


kappa, K1, K2, K3 = setup_ewald_parameters(rc, ethresh, box)
kappa = 0.328532611
print(kappa, K1, K2, K3)

lmax = 2
Q_lh = params['multipoles_lh']
axis_types = params['axis_types']
axis_indices = params['axis_indices']


0.328532611 31 31 31


In [63]:
from jax import jit, grad, value_and_grad, partial

In [33]:
from python.pme import gen_pme_reciprocal

pme_reciprocal_energy = gen_pme_reciprocal(axis_types, axis_indices)

pme_reci = jit(pme_reciprocal_energy, static_argnums=(4,5,6,7))

In [34]:
ene0 = pme_reci(positions, box,  Q_lh, kappa, lmax, K1, K2, K3)
print(ene0)
findiff = np.empty((6, 3))
delta = np.zeros((6,3))
for i in range(6):
    for j in range(3):
        delta[i][j] = 1
        findiff[i][j] = pme_reci(positions +delta, box,  Q_lh, kappa, lmax, K1, K2, K3)
        delta[i][j] = 0
print(findiff)

3.8504010162402884
[[23.93929547 13.56566346  4.28706653]
 [ 4.92445529  9.44444811 14.24634398]
 [ 4.50575493  9.68616694 13.49301647]
 [26.24020777 20.50990233 18.39138227]
 [ 7.39934245  4.85842477  4.89947731]
 [ 3.27954222  7.88590133 10.72833253]]


In [35]:
print("Testing behaviour of the same function, under different jax wrappers")
print("====================================================================")
ene1, f = value_and_grad(pme_reciprocal_energy)(positions, box,  Q_lh, kappa, lmax, K1, K2, K3)
print("value_and_grad:      ", ene1)
ene2 = pme_reciprocal_energy(positions, box,  Q_lh, kappa, lmax, K1, K2, K3)
print("only value:          ", ene2)
ene3, f = jit(value_and_grad(pme_reciprocal_energy), static_argnums=(4,5,6,7))(positions, box,  Q_lh, kappa, lmax, K1, K2, K3)
print("jit, value_and_grad: ", ene3)
ene4 = jit(pme_reciprocal_energy, static_argnums=(4,5,6,7))(positions, box,  Q_lh, kappa, lmax, K1, K2, K3)
print("jit, only value:     ", ene4)

Testing behaviour of the same function, under different jax wrappers
value_and_grad:       3.8504010162466096
only value:           3.850401016246607
jit, value_and_grad:  3.8504010162464084
jit, only value:      3.8504010162402884


In [76]:
def f(x):
    tmp = []
    for i in range(x[3].astype(int)):
        xx = x[i]
        tmp.append(xx**2)
    return jnp.sum(jnp.array(tmp))

x = jnp.array([1.,2.,3.,4.])

jitf = jit(f)

print(jitf(x))

TracerIntegerConversionError: The __index__() method was called on the JAX Tracer object Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=0/1)>
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError