In [1]:
%load_ext snakeviz
%config InlineBackend.figure_format = 'svg'
%config InlineBackend.figure_format = 'png'

In [2]:
import numpy as np
import scipy
from python.ADMPForce import ADMPGenerator
from scipy.stats import special_ortho_group
from python.utils import convert_cart2harm
import mpidplugin

mScales = np.array([0.0, 0.0, 0.0, 1.0])
pScales = np.array([0.0, 0.0, 0.0, 1.0])
dScales = np.array([0.0, 0.0, 0.0, 1.0])
rc = 8 # in Angstrom
ethresh = 1e-4


pdb = 'tests/samples/waterdimer_aligned.pdb'
xml = 'tests/samples/mpidwater.xml'
generator = ADMPGenerator(pdb, xml, rc, ethresh, mScales, pScales, dScales, )
# get a random geometry for testing
scipy.random.seed(1000)
R1 = special_ortho_group.rvs(3)
R2 = special_ortho_group.rvs(3)

positions = generator.positions
positions[0:3] = positions[0:3].dot(R1)
positions[3:6] = positions[3:6].dot(R2)
positions[3:] += np.array([3.0, 0.0, 0.0])


force = generator.create_force()
force.update()
force.kappa = 0.328532611

multipoles_lc = np.concatenate((np.expand_dims(force.mpid_params['charges'], axis=1), force.mpid_params['dipoles'], force.mpid_params['quadrupoles']), axis=1)
Q_lh = convert_cart2harm(multipoles_lc, lmax=2)
axis_types = force.mpid_params['axis_types']
axis_indices = force.mpid_params['axis_indices']


In [3]:
print(positions)

[[ 0.          0.          0.        ]
 [-0.72004306 -0.01869945  0.63058173]
 [ 0.68557522  0.50825211  0.43371812]
 [ 3.          0.          0.        ]
 [ 3.02161682 -0.52203431 -0.80215827]
 [ 2.08174622 -0.01347945  0.27032073]]


In [30]:
from python.pme import pme_reciprocal_energy_from_lh, pme_reciprocal_force_alt
%timeit ene0 = pme_reciprocal_energy_from_lh(positions, force.box,  Q_lh, force.kappa, force.lmax, force.K1, force.K2, force.K3, axis_types, axis_indices)

In [29]:
%timeit f = pme_reciprocal_force_alt(positions, force.box,  Q_lh, force.kappa, force.lmax, force.K1, force.K2, force.K3, axis_types, axis_indices)

157 ms ± 1.59 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [27]:
delt = 1e-5
findiff = np.empty((6,3))
for i in range(6):
    for j in range(3):
        delta = np.zeros((6,3))
        delta[i,j] = delt
        findiff[i, j] = (pme_reciprocal_energy_from_lh(positions+delta, force.box,  Q_lh, force.kappa, force.lmax, force.K1, force.K2, force.K3, axis_types, axis_indices) - ene0)/delt

In [28]:
print(f/findiff)

[[0.99996795 1.00004606 1.00001679]
 [1.00001857 0.99997733 0.9999921 ]
 [1.00001007 0.99997431 0.99998791]
 [0.99996379 0.99992058 1.00068487]
 [1.0000633  1.00002242 1.00002565]
 [1.00001415 0.99970273 0.99997981]]


In [5]:
%timeit ene0 = force.calc_reci_space_energy()
%timeit force0 = force.calc_reci_space_force()
ene0 = force.calc_reci_space_energy()
force0 = force.calc_reci_space_force()
print(ene0)
print(force0)

5.79 ms ± 167 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
9 ms ± 394 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
3.8504010162402884
[[  5.58156365  -2.58004665 -10.82644139]
 [ -2.49424205   1.30562368   5.67417203]
 [ -3.48460866   1.52578869   5.09100699]
 [  5.03836956   1.55864688  -0.65421062]
 [ -0.98473687  -2.01505676  -1.87536137]
 [ -3.65628997   0.20508466   2.59076499]]


In [6]:
from python.pme import pme_reciprocal, pme_reciprocal_force, dielectric
import numpy as np
print("====================Testing jax performance on CPU====================")
print("Calculating energy and force for: {N_a} atoms, {grid_dim}*{grid_dim}*{grid_dim} grid points".format(N_a  =4000 // 6 * 6, grid_dim = 30))
%timeit pme_reciprocal(np.repeat(force.positions, 4000 //6, axis = 0), force.box, np.repeat(force.Q, 4000 //6, axis = 0), force.kappa, force.lmax, 30, 30, 30)

Calculating energy and force for: 3996 atoms, 30*30*30 grid points


KeyboardInterrupt: 

In [None]:
import numpy as np

from tqdm import trange
def pme_reciprocal_numpy(positions, box, Q, kappa, lmax, K1, K2, K3):
    N = np.array([K1,K2,K3])
    padder = np.arange(-3, 3)
    shifts = np.array(np.meshgrid(padder, padder, padder)).T.reshape((1, 216, 3))
    
    def get_recip_vectors(N, box):
        """
        Computes reciprocal lattice vectors of the grid
        
        Input:
            N:
                (3,)-shaped array
            box:
                3 x 3 matrix, box parallelepiped vectors arranged in TODO rows or columns?
                
        Output: 
            Nj_Aji_star:
                3 x 3 matrix, the first index denotes reciprocal lattice vector, the second index is the component xyz.
                (lattice vectors arranged in rows)
        """
        Nj_Aji_star = (N.reshape((1, 3)) * np.linalg.inv(box)).T
        return Nj_Aji_star

    def u_reference(R_a, Nj_Aji_star):
        """
        Each atom is meshed to PME_ORDER**3 points on the m-meshgrid. This function computes the xyz-index of the reference point, which is the point on the meshgrid just above atomic coordinates, and the corresponding values of xyz fractional displacements from real coordinate to the reference point. 
        
        Inputs:
            R_a:
                N_a * 3 matrix containing positions of sites
            Nj_Aji_star:
                3 x 3 matrix, the first index denotes reciprocal lattice vector, the second index is the component xyz.
                (lattice vectors arranged in rows)
                
        Outputs:
            m_u0: 
                N_a * 3 matrix, positions of the reference points of R_a on the m-meshgrid
            u0: 
                N_a * 3 matrix, (R_a - R_m)*a_star values
        """
        
        R_in_m_basis =  np.einsum("ij,kj->ki", Nj_Aji_star, R_a)
        
        m_u0 = np.ceil(R_in_m_basis).astype(int)
        
        u0 = (m_u0 - R_in_m_basis) + 6/2
        return m_u0, u0

    def bspline6(u):
        """
        Computes the cardinal B-spline function
        """
        return np.piecewise(u, 
                            [np.logical_and(u>=0, u<1.), 
                            np.logical_and(u>=1, u<2.), 
                            np.logical_and(u>=2, u<3.), 
                            np.logical_and(u>=3, u<4.), 
                            np.logical_and(u>=4, u<5.), 
                            np.logical_and(u>=5, u<6.)],
                            [lambda u: u**5/120,
                            lambda u: u**5/120 - (u - 1)**5/20,
                            lambda u: u**5/120 + (u - 2)**5/8 - (u - 1)**5/20,
                            lambda u: u**5/120 - (u - 3)**5/6 + (u - 2)**5/8 - (u - 1)**5/20,
                            lambda u: u**5/24 - u**4 + 19*u**3/2 - 89*u**2/2 + 409*u/4 - 1829/20,
                            lambda u: -u**5/120 + u**4/4 - 3*u**3 + 18*u**2 - 54*u + 324/5] )

    def bspline6prime(u):
        """
        Computes first derivative of the cardinal B-spline function
        """
        return np.piecewise(u, 
                            [np.logical_and(u>=0., u<1.), 
                            np.logical_and(u>=1., u<2.), 
                            np.logical_and(u>=2., u<3.), 
                            np.logical_and(u>=3., u<4.), 
                            np.logical_and(u>=4., u<5.), 
                            np.logical_and(u>=5., u<6.)],
                            [lambda u: u**4/24,
                            lambda u: u**4/24 - (u - 1)**4/4,
                            lambda u: u**4/24 + 5*(u - 2)**4/8 - (u - 1)**4/4,
                            lambda u: -5*u**4/12 + 6*u**3 - 63*u**2/2 + 71*u - 231/4,
                            lambda u: 5*u**4/24 - 4*u**3 + 57*u**2/2 - 89*u + 409/4,
                            lambda u: -u**4/24 + u**3 - 9*u**2 + 36*u - 54] )

    def bspline6prime2(u):
        """
        Computes second derivate of the cardinal B-spline function
        """
        return np.piecewise(u, 
                            [np.logical_and(u>=0., u<1.), 
                            np.logical_and(u>=1., u<2.), 
                            np.logical_and(u>=2., u<3.), 
                            np.logical_and(u>=3., u<4.), 
                            np.logical_and(u>=4., u<5.), 
                            np.logical_and(u>=5., u<6.)],
                            [lambda u: u**3/6,
                            lambda u: u**3/6 - (u - 1)**3,
                            lambda u: 5*u**3/3 - 12*u**2 + 27*u - 19,
                            lambda u: -5*u**3/3 + 18*u**2 - 63*u + 71,
                            lambda u: 5*u**3/6 - 12*u**2 + 57*u - 89,
                            lambda u: -u**3/6 + 3*u**2 - 18*u + 36,] )


    def theta_eval(u):
        """
        Evaluates the value of theta given 3D u values at ... points 
        
        Input:
            u:
                ... x 3 matrix

        Output:
            theta:
                ... matrix
        """
        theta = np.prod(bspline6(u), axis = -1)
        return theta

    def thetaprime_eval(u, Nj_Aji_star, theta):
        """
        First derivative of theta with respect to x,y,z directions
        
        Input:
            u
            Nj_Aji_star:
                reciprocal lattice vectors
            theta
        
        Output:
            N_a * 3 matrix
        """

        M_u = bspline6(u)
        Mprime_u = bspline6prime(u)
        div = np.array([
            Mprime_u[:, 0] * M_u[:, 1] * M_u[:, 2],
            Mprime_u[:, 1] * M_u[:, 2] * M_u[:, 0],
            Mprime_u[:, 2] * M_u[:, 0] * M_u[:, 1],
        ]).T
        
        # Notice that u = m_u0 - R_in_m_basis + 6/2
        # therefore the Jacobian du_j/dx_i = - Nj_Aji_star
        return np.einsum("ij,kj->ki", -Nj_Aji_star, div)

    def theta2prime_eval(u, Nj_Aji_star):
        """
        compute the 3 x 3 second derivatives of theta with respect to xyz
        
        Input:
            u
            Nj_Aji_star
        
        Output:
            N_A * 3 * 3
        """
        M_u = bspline6(u)
        Mprime_u = bspline6prime(u)
        M2prime_u = bspline6prime2(u)

        div_00 = M2prime_u[:, 0] * M_u[:, 1] * M_u[:, 2]
        div_11 = M2prime_u[:, 1] * M_u[:, 0] * M_u[:, 2]
        div_22 = M2prime_u[:, 2] * M_u[:, 0] * M_u[:, 1]
        
        div_01 = Mprime_u[:, 0] * Mprime_u[:, 1] * M_u[:, 2]
        div_02 = Mprime_u[:, 0] * Mprime_u[:, 2] * M_u[:, 1]
        div_12 = Mprime_u[:, 1] * Mprime_u[:, 2] * M_u[:, 0]

        div_10 = div_01
        div_20 = div_02
        div_21 = div_12
        
        div = np.array([
            [div_00, div_01, div_02],
            [div_10, div_11, div_12],
            [div_20, div_21, div_22],
        ]).swapaxes(0, 2)
        
        # Notice that u = m_u0 - R_in_m_basis + 6/2
        # therefore the Jacobian du_j/dx_i = - Nj_Aji_star
        return np.einsum("im,jn,kmn->kij", -Nj_Aji_star, -Nj_Aji_star, div)

    def sph_harmonics_GO(u0, Nj_Aji_star, lmax):
        '''
        Find out the value of spherical harmonics GRADIENT OPERATORS, assume the order is:
        00, 10, 11c, 11s, 20, 21c, 21s, 22c, 22s, ...
        Currently supports lmax <= 2

        Inputs:
            u0: 
                a N_a * 3 matrix containing all positions
            Nj_Aji_star:
                reciprocal lattice vectors in the m-grid
            lmax: 
                the maximum l value

        Output: 
            harmonics: 
                a Na * (6**3) * (l+1)^2 matrix, STGO operated on theta,
                evaluated at 6*6*6 integer points about reference points m_u0 
        '''
        
        n_harm = (lmax + 1)**2

        N_a = u0.shape[0]
        u = (u0[:, np.newaxis, :] + shifts).reshape((N_a*216, 3)) 

        theta = theta_eval(u)
                
        if lmax == 0:
            return theta.reshape(N_a, 216, n_harm)
        
        # dipole
        thetaprime = thetaprime_eval(u, Nj_Aji_star, theta)
        harmonics_1 = np.stack(
            [theta,
            thetaprime[:, 2],
            thetaprime[:, 0],
            thetaprime[:, 1]],
            axis = -1
        )
        
        if lmax == 1:
            return harmonics_1.reshape(N_a, 216, n_harm)

        # quadrapole
        theta2prime = theta2prime_eval(u, Nj_Aji_star)
        rt3 = np.sqrt(3)
        harmonics_2 = np.hstack(
            [harmonics_1,
            np.stack([(3*theta2prime[:,2,2] - np.trace(theta2prime, axis1=1, axis2=2)) / 2,
            rt3 * theta2prime[:, 0, 2],
            rt3 * theta2prime[:, 1, 2],
            rt3/2 * (theta2prime[:, 0, 0] - theta2prime[:, 1, 1]),
            rt3 * theta2prime[:, 0, 1]], axis = 1)]
        )
        
        if lmax == 2:
            return harmonics_2.reshape(N_a, 216, n_harm)
        else:
            raise NotImplementedError('l > 2 (beyond quadrupole) not supported')
        
    def Q_m_peratom(Q, sph_harms):
        """
        Computes <R_t|Q>. See eq. (49) of https://doi.org/10.1021/ct5007983
        
        Inputs:
            Q: 
                N_a * (l+1)**2 matrix containing global frame multipole moments up to lmax,
            sph_harms:
                N_a, 216, (l+1)**2
        
        Output:
            Q_m_pera:
                N_a * 216 matrix, values of theta evaluated on a 6 * 6 block about the atoms
        """
        
        N_a = sph_harms.shape[0]
        
        if lmax > 2:
            raise NotImplementedError('l > 2 (beyond quadrupole) not supported')

        Q_dbf = Q[:, 0]
        if lmax >= 1:
            Q_dbf = np.hstack([Q_dbf[:,np.newaxis], Q[:,1:4]])
        if lmax >= 2:
            Q_dbf = np.hstack([Q_dbf, Q[:,4:9]/3])
        
        Q_m_pera = np.sum( Q_dbf[:,np.newaxis,:]* sph_harms, axis=2)

        assert Q_m_pera.shape == (N_a, 216)
        return Q_m_pera
    
    def Q_mesh_on_m(Q_mesh_pera, m_u0, N):
        """
        spreads the particle mesh onto the grid
        
        Input:
            Q_mesh_pera, m_u0, N
            
        Output:
            Q_mesh: 
                Nx * Ny * Nz matrix
        """

        indices_arr = np.mod(m_u0[:,np.newaxis,:]+shifts, N[np.newaxis, np.newaxis, :])
        
        Q_mesh = np.zeros((N[0], N[1], N[2]))
        
        def acc_mesh(ai, Q_mesh):
            Q_mesh[indices_arr[ai, :, 0], indices_arr[ai, :, 1], indices_arr[ai, :, 2]] += Q_mesh_pera[ai, :]
            return Q_mesh

        for ai in trange(indices_arr.shape[0]):
            Q_mesh = acc_mesh(ai, Q_mesh)
        
        return Q_mesh

    def setup_kpts_integer(N):
        """
        Outputs:
            kpts_int:
                n_k * 3 matrix, n_k = N[0] * N[1] * N[2]
        """
        N_half = N.reshape(3)

        kx, ky, kz = [np.roll(np.arange(- (N_half[i] - 1) // 2, (N_half[i] + 1) // 2 ), - (N_half[i] - 1) // 2) for i in range(3)]

        kpts_int = np.array([ki.flatten() for ki in np.meshgrid(kx, ky, kz)]).T

        kpts_int = kpts_int @ np.array([[0, 1, 0], [1, 0, 0], [0, 0, 1]])

        return kpts_int 

    def setup_kpts(box, kpts_int):
        '''
        This function sets up the k-points used for reciprocal space calculations
        
        Input:
            box:
                3 * 3, three axis arranged in rows
            kpts_int:
                n_k * 3 matrix

        Output:
            kpts:
                4 * K, K=K1*K2*K3, contains kx, ky, kz, k^2 for each kpoint
        '''
        # in this array, a*, b*, c* (without 2*pi) are arranged in column
        box_inv = np.linalg.inv(box)

        # K * 3, coordinate in reciprocal space
        kpts = 2 * np.pi * kpts_int.dot(box_inv)

        ksr = np.sum(kpts**2, axis=1)

        # 4 * K
        kpts = np.hstack((kpts, ksr[:, np.newaxis])).T

        return kpts

    def E_recip_on_grid(Q_mesh, box, N, kappa):
        """
        Computes the reciprocal part energy
        """
        
        N = N.reshape(1,1,3)
        kpts_int = setup_kpts_integer(N)
        kpts = setup_kpts(box, kpts_int)

        m = np.linspace(-2,2,5).reshape(5, 1, 1)
        # theta_k : array of shape n_k
        theta_k = np.prod(
            np.sum(
                bspline6(m + 6/2) * np.cos(2*np.pi*m*kpts_int[np.newaxis] / N),
                axis = 0
            ),
            axis = 1
        )

        S_k = np.fft.fftn(Q_mesh)

        S_k = S_k.flatten()

        E_k = 2*np.pi/kpts[3,1:]/np.linalg.det(box) * np.exp( - kpts[3, 1:] /4 /kappa**2) * np.abs(S_k[1:]/theta_k[1:])**2
        return np.sum(E_k)

    Nj_Aji_star = get_recip_vectors(N, box)
    m_u0, u0    = u_reference(positions, Nj_Aji_star)
    sph_harms   = sph_harmonics_GO(u0, Nj_Aji_star, lmax)
    Q_mesh_pera = Q_m_peratom(Q, sph_harms)
    Q_mesh      = Q_mesh_on_m(Q_mesh_pera, m_u0, N)
    E_recip     = E_recip_on_grid(Q_mesh, box, N, kappa)
    
    # Outputs energy in OPENMM units    
    return E_recip*dielectric


In [None]:
%timeit pme_reciprocal_numpy(np.repeat(force.positions, 4000 //6, axis = 0), force.box, np.repeat(force.Q, 4000 //6, axis = 0), force.kappa, force.lmax, 30, 30, 30)

100%|██████████| 3996/3996 [00:00<00:00, 136178.48it/s]
100%|██████████| 3996/3996 [00:00<00:00, 126197.67it/s]
100%|██████████| 3996/3996 [00:00<00:00, 136171.84it/s]
100%|██████████| 3996/3996 [00:00<00:00, 145644.16it/s]
100%|██████████| 3996/3996 [00:00<00:00, 129151.97it/s]
100%|██████████| 3996/3996 [00:00<00:00, 122523.20it/s]
100%|██████████| 3996/3996 [00:00<00:00, 141363.15it/s]
100%|██████████| 3996/3996 [00:00<00:00, 129074.39it/s]

2.13 s ± 142 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)



