# Multipolar polarizable force field with fluctuating charges

In this demo, we show how to implement a **multipolar polarizable potential with fluctuating charges** with the DMFF API.

In conventional models, atomic charges are predefined and remain fixed throughout the simulation. Here, we implement a model in which atomic charges are *conformer-dependent*. They are computed from the instantaneous molecular geometry, so they vary during a molecular dynamics simulation. This lets the electrostatics respond to intramolecular distortions, which a fixed-charge model cannot capture.

## System preparation
Load the coordinates and the box, and compute the neighbor list. Note that multipolar polarizable models conventionally use **angstrom** as the length unit.

In [None]:
import jax
import jax.numpy as jnp
import openmm.app as app
import openmm.unit as unit
from dmff import Hamiltonian, NeighborList

app.Topology.loadBondDefinitions("residues.xml")
pdb = app.PDBFile("water_dimer.pdb")
rc = 4 # cutoff, in angstrom
positions = jnp.array(pdb.getPositions(asNumpy=True).value_in_unit(unit.angstrom))
box = jnp.array(
    [vec.value_in_unit(unit.angstrom) for vec in pdb.topology.getPeriodicBoxVectors()]
)
nbList = NeighborList(box, rc=rc)
nbList.allocate(positions)
pairs = nbList.pairs
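Conceptually, the neighbor list collects the atom pairs that lie within the cutoff `rc` under periodic boundary conditions. The brute-force sketch below computes such a list for an orthorhombic box using the minimum-image convention. `brute_force_pairs` is a hypothetical helper, not part of DMFF. It scales as O(N²) and is only a mental model. DMFF's `NeighborList` is a separate implementation, and its `pairs` array may use a different layout (for example, extra columns or padding).

```python
import jax.numpy as jnp


def brute_force_pairs(positions, box, rc):
    """Return all (i, j) with i < j whose minimum-image distance is below rc.

    Hypothetical helper; assumes an orthorhombic box given as a 3x3 matrix
    with the box lengths on the diagonal (same length unit as positions).
    """
    lengths = jnp.diag(box)
    # Displacement vectors between every pair of atoms, shape (N, N, 3)
    dr = positions[:, None, :] - positions[None, :, :]
    # Minimum-image convention: wrap each component into roughly [-L/2, L/2]
    dr = dr - lengths * jnp.round(dr / lengths)
    dist = jnp.linalg.norm(dr, axis=-1)
    # Upper-triangle indices so that each pair appears once
    i, j = jnp.triu_indices(positions.shape[0], k=1)
    mask = dist[i, j] < rc
    # Boolean masking gives a data-dependent shape, so call this outside jax.jit
    return jnp.stack([i[mask], j[mask]], axis=1)
```

For example, `brute_force_pairs(positions, box, rc)` could be run alongside the cell above to inspect which pairs fall inside the 4 Å cutoff.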

## Generate auto-differentiable multipolar polarizable (ADMP) forces

First, we use `dmff` to create a multipolar polarizable potential with **fixed** atomic charges.

The force field defines two types of force:

- Dispersion force
- Multipolar polarizable PME force

We will focus on the PME force.

In [None]:
H = Hamiltonian('forcefield.xml')
disp_pot, pme_pot = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom, step_pol=5)
disp_generator, pme_generator = H.getGenerators()
print(pme_generator)
print(pme_pot)
pme_generator.params

The function `pme_pot` performs two steps:

- Expand the **force field parameters** (oxygen and hydrogen charges, polarizabilities, etc.) predefined in `forcefield.xml` to each atom. We call the results **atomic parameters**.
- Call the PME kernel function to evaluate the energy.

The force field parameters are stored in `pme_generator.params`. The expansion indexes each per-type parameter array with `pme_generator.map_atomtype`, the atom-type index of every atom. This integer-array indexing of `jax.numpy` arrays gathers one row of parameters per atom.

In [None]:
params = pme_generator.params["Q_local"]
params[pme_generator.map_atomtype]
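The expansion can be reproduced on a toy table. In the sketch below, `Q_types` and `map_atomtype` are made-up stand-ins for `pme_generator.params["Q_local"]` and `pme_generator.map_atomtype`, and their values are illustrative only. The real `Q_local` has more columns. We only assume that column 0 holds the charge, which is also what the calculator later in this notebook relies on. Indexing with the type array gathers one row per atom. `.at[...].set(...)` returns a modified copy because JAX arrays are immutable.

```python
import jax.numpy as jnp

# Hypothetical per-type table: 2 atom types, 4 components each
# (column 0 = charge, remaining columns = other multipole components)
Q_types = jnp.array([
    [-0.8, 0.0, 0.0, 0.1],  # type 0: oxygen (illustrative values)
    [0.4, 0.0, 0.0, 0.0],   # type 1: hydrogen (illustrative values)
])
# Hypothetical type index of each atom in one water molecule (O, H, H)
map_atomtype = jnp.array([0, 1, 1])

# Integer-array indexing gathers one row per atom -> shape (3, 4)
Q_atoms = Q_types[map_atomtype]

# Overwrite column 0 with atom-specific charges; Q_types itself is untouched
new_charges = jnp.array([-0.7, 0.35, 0.35])
Q_atoms = Q_atoms.at[:, 0].set(new_charges)
print(Q_atoms)
```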

## Implement fluctuating charges

Since this expansion happens internally within `pme_pot`, it is **not flexible** enough to let us specify atom-specific charges, i.e., **fluctuating charges**.

As a result, we must rewrite `pme_pot` so that the atomic charges can be modified after the force field parameters are expanded.

Thanks to the flexible APIs in DMFF, we can reuse most of the functions and variables in `pme_generator` and only modify the charges in the input parameters, i.e., the `Q_local` argument of the `pme_generator.pme_force.get_energy` function.

In [None]:
from dmff.utils import jit_condition
from dmff.admp.pme import trim_val_0
from dmff.admp.spatial import v_pbc_shift


@jit_condition(static_argnums=())
def compute_leading_terms(positions, box):
    # Assumes the atoms are ordered O, H, H within each water molecule
    n_atoms = len(positions)
    c0 = jnp.zeros(n_atoms)
    c6_list = jnp.zeros(n_atoms)
    box_inv = jnp.linalg.inv(box)
    O = positions[::3]
    H1 = positions[1::3]
    H2 = positions[2::3]
    # O-H bond vectors, shifted under periodic boundary conditions
    ROH1 = H1 - O
    ROH2 = H2 - O
    ROH1 = v_pbc_shift(ROH1, box, box_inv)
    ROH2 = v_pbc_shift(ROH2, box, box_inv)
    dROH1 = jnp.linalg.norm(ROH1, axis=1)
    dROH2 = jnp.linalg.norm(ROH2, axis=1)
    # H-O-H angle, in degrees
    costh = jnp.sum(ROH1 * ROH2, axis=1) / (dROH1 * dROH2)
    angle = jnp.arccos(costh) * 180 / jnp.pi
    # Fitted linear model for the dipole, from which the charges are derived
    dipole = -0.016858755 + 0.002287251 * angle + 0.239667591 * dROH1 + (-0.070483437) * dROH2
    charge_H = dipole / dROH1
    charge_O = charge_H * (-2)  # keeps each molecule neutral
    # Fitted C6 coefficients; 0.529**6 * 2625.5 converts hartree*bohr^6 to kJ/mol*angstrom^6
    C6_H = (-2.36066199 + (-0.007049238) * angle + 1.949429648 * dROH1 + 2.097120784 * dROH2) * 0.529**6 * 2625.5
    C6_O = (-8.641301261 + 0.093247893 * angle + 11.90395358 * (dROH1 + dROH2)) * 0.529**6 * 2625.5
    C6_H = trim_val_0(C6_H)  # keep C6_H positive before taking the square root
    # Scatter the per-molecule values into per-atom arrays
    c0 = c0.at[::3].set(charge_O)
    c0 = c0.at[1::3].set(charge_H)
    c0 = c0.at[2::3].set(charge_H)
    # Per-atom sqrt(C6), e.g. for a dispersion term (not used by the PME calculator below)
    c6_list = c6_list.at[::3].set(jnp.sqrt(C6_O))
    c6_list = c6_list.at[1::3].set(jnp.sqrt(C6_H))
    c6_list = c6_list.at[2::3].set(jnp.sqrt(C6_H))
    return c0, c6_list


def generate_calculator(pme_generator):
    def admp_calculator(positions, box, pairs):
        # Geometry-dependent (fluctuating) charges; c6_list is not used here
        c0, c6_list = compute_leading_terms(positions, box)
        # Expand per-type parameters to per-atom parameters, as pme_pot does internally
        Q_local = pme_generator.params["Q_local"][pme_generator.map_atomtype]
        # Replace the fixed charges (column 0 of Q_local) with the fluctuating ones
        Q_local = Q_local.at[:, 0].set(c0)
        pol = pme_generator.params["pol"][pme_generator.map_atomtype]
        tholes = pme_generator.params["tholes"][pme_generator.map_atomtype]
        # Scaling factors are not per-type parameters, so they are passed through unchanged
        mScales = pme_generator.params["mScales"]
        pScales = pme_generator.params["pScales"]
        dScales = pme_generator.params["dScales"]
        E_pme = pme_generator.pme_force.get_energy(
            positions, 
            box, 
            pairs, 
            Q_local, 
            pol, 
            tholes, 
            mScales, 
            pScales, 
            dScales
        )
        return E_pme 
    return jax.jit(admp_calculator)
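For a single molecule, the charge model in `compute_leading_terms` reduces to a few lines. The name `water_charges` in the stand-alone restatement below is hypothetical. The function skips the periodic wrapping, so it assumes an intact molecule with coordinates in angstrom. The sketch makes two properties easy to check. Both hydrogens receive the same charge, `dipole / dROH1`. The oxygen charge is `-2` times that, so every molecule stays neutral whatever its geometry.

```python
import jax.numpy as jnp


def water_charges(O, H1, H2):
    """Geometry-dependent charges (O, H, H) for one water molecule.

    Hypothetical restatement of the fitted formula in compute_leading_terms,
    without periodic wrapping (assumes an intact molecule, angstrom units).
    """
    r1_vec, r2_vec = H1 - O, H2 - O
    r1, r2 = jnp.linalg.norm(r1_vec), jnp.linalg.norm(r2_vec)
    # H-O-H angle in degrees, as in the notebook code
    angle = jnp.degrees(jnp.arccos(jnp.dot(r1_vec, r2_vec) / (r1 * r2)))
    dipole = -0.016858755 + 0.002287251 * angle + 0.239667591 * r1 - 0.070483437 * r2
    q_H = dipole / r1  # both hydrogens get the same charge
    return jnp.array([-2.0 * q_H, q_H, q_H])  # neutral by construction
```

For a near-equilibrium geometry (O-H = 0.9572 Å, H-O-H = 104.52°), this sketch gives a hydrogen charge of roughly 0.40.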


**Finally, compute the energy and force!**

In [None]:
potential_fn = generate_calculator(pme_generator)
ene = potential_fn(positions, box, pairs)
print(ene)

In [None]:
force_fn = jax.grad(potential_fn, argnums=0)  # gradient of the energy w.r.t. positions
force = -force_fn(positions, box, pairs)  # force = -dE/dr
print(force)
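`jax.grad` differentiates the energy with respect to the argument selected by `argnums`, which here is the positions. The force is the negative of that gradient. The sketch below checks this pattern against a central finite difference on a toy harmonic bond. `toy_energy` is a hypothetical stand-in for `potential_fn`, and the `toy_` prefixes avoid overwriting the notebook's variables.

```python
import jax
import jax.numpy as jnp


def toy_energy(positions, k=100.0, r0=1.0):
    """Hypothetical harmonic bond between atoms 0 and 1 (stand-in for potential_fn)."""
    r = jnp.linalg.norm(positions[1] - positions[0])
    return 0.5 * k * (r - r0) ** 2


# Same pattern as above: gradient w.r.t. argument 0, force = -gradient
toy_grad_fn = jax.grad(toy_energy, argnums=0)
toy_positions = jnp.array([[0.0, 0.0, 0.0], [1.2, 0.0, 0.0]])
toy_force = -toy_grad_fn(toy_positions)

# Central finite difference for the x-component of the force on atom 1
h = 1e-3
shift = jnp.zeros_like(toy_positions).at[1, 0].set(h)
fd_force = -(toy_energy(toy_positions + shift) - toy_energy(toy_positions - shift)) / (2 * h)
print(toy_force[1, 0], fd_force)
```

The bond is stretched to 1.2 with `r0 = 1.0` and `k = 100`, so the force on atom 1 is about -20 along x. The two atomic forces sum to zero.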

The first call is slow because JAX traces the function and compiles it. Subsequent calls with inputs of the same shapes reuse the compiled code and run much faster. A change in the input shapes triggers a recompilation.

In [None]:
print(-force_fn(positions, box, pairs))
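The recompilation rule can be observed directly. Python side effects inside a `jax.jit`-decorated function run only while JAX traces it, so the hypothetical counter `trace_count` in the sketch below counts compilations. For `potential_fn`, the same rule means that rebuilding the neighbor list would trigger a new compilation if it changes the number of pairs.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def pair_sum(pairs):
    global trace_count
    # This Python statement runs only during tracing, not on cached calls
    trace_count += 1
    return jnp.sum(pairs)


pair_sum(jnp.ones((10, 2)))   # first call: traced and compiled
pair_sum(jnp.zeros((10, 2)))  # same shape and dtype: compiled code reused
pair_sum(jnp.ones((12, 2)))   # new shape: traced and compiled again
print(trace_count)  # 2
```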