# The Weighted Ensemble Method

## Sodium Chloride Association Kinetics, with OpenMM

The Weighted Ensemble (WE) method provides a route to estimating kinetic and thermodynamic parameters for many different types of biomolecular simulation problem. For a good introduction, see this [2017 review from Zuckerman and Chong](https://pubmed.ncbi.nlm.nih.gov/28301772/).

The aim of this notebook is to illustrate the key aspects of "steady state" type WE simulations (walkers, progress coordinates, binning, splitting and merging, recycling) with a version of the sodium chloride association kinetics example that also features in the [WESTPA](https://pubmed.ncbi.nlm.nih.gov/26392815/) tutorials, but using `WElib` instead.

There is a barrier to the formation of an Na+/Cl- ion pair in solution because the hydration shells that surround each ion when they are separated must be disrupted. In this Weighted Ensemble simulation, we begin with a box of water containing one sodium and one chloride ion, about 11 Angstroms (1.1 nm) apart. We measure the rate at which they form an ion pair (defined as a separation of 2.6 Angstroms or less).


---------

### Part 1: Building the OpenMM system
We begin by importing the packages required to build the simulation system in OpenMM: 

In [None]:
import openmm.app as omm_app
import openmm as omm
import openmm.unit as unit

Now we create the `system`, and then a `simulation` object:

In [None]:
prmtop = omm_app.AmberPrmtopFile('nacl.parm7')
inpcrd = omm_app.AmberInpcrdFile('nacl_unbound.ncrst')
system = prmtop.createSystem(nonbondedMethod=omm_app.PME, nonbondedCutoff=10.0*unit.angstrom,
        constraints=omm_app.HBonds)

T = 300.0 * unit.kelvin  ## temperature
fricCoef = 1.0 / unit.picoseconds ## friction coefficient 
stepsize = 0.002 * unit.picoseconds ## integration step size
integrator = omm.LangevinIntegrator(T, fricCoef, stepsize)

simulation = omm_app.Simulation(prmtop.topology, system, integrator)
simulation.context.setPositions(inpcrd.positions)
if inpcrd.boxVectors is not None:
    simulation.context.setPeriodicBoxVectors(*inpcrd.boxVectors)
    
print(f'OpenMM will use the {simulation.context.getPlatform().getName()} platform')

### Part 2: Building the WE workflow
Now we import WElib and some other useful utilities. Many of the WElib components are the same as those used for the simple double well potential (DWP) example, but here we need an OpenMM-compatible `ProgressCoordinator`, and the `FunctionStepper` will wrap an OpenMM simulation:

In [None]:
import mdtraj as mdt
import numpy as np
import time
from WElib import Walker, FunctionStepper, OMMSimpleDistanceProgressCoordinator, Recycler, StaticBinner, SplitMerger

Create some walkers, each beginning in the initial, dissociated state:

In [None]:
initial_state = simulation.context.getState(getPositions=True, enforcePeriodicBox=True)

n_reps = 5
walkers = [Walker(initial_state, 1.0/n_reps) for i in range(n_reps)]
for w in walkers:
    print(w)

The progress coordinate will be the distance between the sodium and chloride ions:

In [None]:
na_atom = 0 # index of the sodium atom in the system
cl_atom = 1 # index of the chloride ion in the system
progress_coordinator = OMMSimpleDistanceProgressCoordinator(a1=na_atom, a2=cl_atom)
walkers = progress_coordinator.run(walkers)
for w in walkers:
    print(w)
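The distance used as the progress coordinate has to respect the periodic boundaries: with `enforcePeriodicBox=True`, OpenMM wraps each molecule (here, each ion) into the box independently, so the two ions can end up on opposite faces of the box. We have not checked how `OMMSimpleDistanceProgressCoordinator` handles this internally; the sketch below just shows the minimum-image convention for a rectangular box, using plain NumPy arrays in nm. The function `min_image_distance` is hypothetical, not part of WElib.

```python
import numpy as np


def min_image_distance(r1, r2, box_lengths):
    """Minimum-image distance between two points in a rectangular periodic box.

    r1, r2: positions in nm, shape (3,)
    box_lengths: edge lengths of an orthorhombic box in nm, shape (3,)
    """
    d = np.asarray(r1, dtype=float) - np.asarray(r2, dtype=float)
    box = np.asarray(box_lengths, dtype=float)
    # Shift each component by a whole number of box lengths into [-L/2, L/2]
    d -= box * np.round(d / box)
    return float(np.linalg.norm(d))


# Two ions near opposite faces of a 3 nm cubic box are really only 0.4 nm apart
print(round(min_image_distance([0.2, 1.5, 1.5], [2.8, 1.5, 1.5], [3.0, 3.0, 3.0]), 3))  # → 0.4
```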

We use the same bin boundaries as in the WESTPA tutorials. Notice that these are more closely spaced at shorter distances, where the solvation shells get "stiffer":

In [None]:
binner = StaticBinner([0, 0.26, 0.28, 0.3, 0.32, 0.34, 0.36, 0.38, 0.4, 0.45, 0.5, 
                 0.55, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5])
walkers = binner.run(walkers)
for w in walkers:
    print(w)
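One way to picture what a static binner does is to look up which interval of the edge list each progress-coordinate value falls into. The sketch below does this with `np.digitize` and the same edges. It assumes bin 0 is the interval [0, 0.26) and that values beyond the last edge go into an extra overflow bin; WElib's own numbering convention may differ.

```python
import numpy as np

# Same bin edges (nm) as passed to StaticBinner above
edges = np.array([0, 0.26, 0.28, 0.3, 0.32, 0.34, 0.36, 0.38, 0.4, 0.45, 0.5,
                  0.55, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5])


def assign_bins(pcs):
    """Return a bin index for each progress-coordinate value.

    Assumed convention: bin i covers [edges[i], edges[i+1]), and any value
    >= edges[-1] lands in an overflow bin with index len(edges) - 1.
    """
    return np.digitize(pcs, edges) - 1


print(assign_bins([0.1, 0.27, 1.05, 1.6]))  # → [ 0  1 16 21]
```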

We will recycle walkers when the Na-Cl distance falls below 0.26 nm. As the progress coordinate is something that gets smaller as we move towards the target state, this is a "retrograde" coordinate:

In [None]:
recycler = Recycler(initial_state, 0.26, retrograde=True)
walkers = recycler.run(walkers)
for w in walkers:
    print(w)
print('recycled flux = ',recycler.flux)
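Conceptually, recycling with a retrograde coordinate works like this: any walker whose progress coordinate has dropped below the target value contributes its weight to this cycle's flux and is restarted from the initial state, keeping its weight so that the total stays at 1. The sketch below shows that bookkeeping with a minimal, hypothetical `SketchWalker` dataclass (not WElib's `Walker`); the details of WElib's `Recycler` may differ.

```python
from dataclasses import dataclass


@dataclass
class SketchWalker:
    """Hypothetical stand-in for a walker: a state label, a weight and a PC value."""
    state: str
    weight: float
    pc: float


def recycle(walkers, initial_state, target, retrograde=True):
    """Reset walkers that reached the target; return (walkers, recycled flux)."""
    flux = 0.0
    for w in walkers:
        # Retrograde: the target is reached when the PC falls below the threshold
        reached = w.pc < target if retrograde else w.pc > target
        if reached:
            flux += w.weight          # weight that arrived in the target state
            w.state = initial_state   # restart from the initial (unbound) state
            w.pc = None               # PC must be recomputed for the new state
    return walkers, flux


walkers = [SketchWalker('a', 0.5, 0.25), SketchWalker('b', 0.3, 0.9), SketchWalker('c', 0.2, 0.24)]
walkers, flux = recycle(walkers, 'unbound', 0.26)
print(f'recycled flux = {flux:.2f}')  # → recycled flux = 0.70
print([w.state for w in walkers])     # → ['unbound', 'b', 'unbound']
```

This is why the main loop further down re-runs the progress coordinator and binner whenever the recycled flux is non-zero: recycled walkers no longer have a valid progress coordinate or bin.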

The SplitMerger is the same as the one used for the DWP example. We create it and run it, even though at this point it has nothing to do:

In [None]:
splitmerger = SplitMerger(n_reps)
walkers = splitmerger.run(walkers)
for w in walkers:
    print(w)
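The splitting and merging step can be sketched on a single bin. Walkers are represented here as hypothetical `(weight, label)` pairs, with the label standing in for the MD state. The two lightest walkers are merged (one state survives, chosen with probability proportional to weight, and it takes the combined weight) until the bin is no longer over-full, and the heaviest walker is split into two half-weight copies until the bin holds the target number. This is a common scheme in WE codes; whether `SplitMerger` follows exactly these rules is an assumption.

```python
import random
from operator import itemgetter

by_weight = itemgetter(0)


def split_merge_bin(walkers, target, rng=random):
    """Split/merge the walkers in ONE bin so that it holds `target` walkers.

    walkers: list of (weight, label) pairs; the label stands in for the MD state
    target: desired number of walkers per bin (n_reps in the notebook)
    Returns a new list of (weight, label) pairs with the same total weight.
    """
    walkers = sorted(walkers, key=by_weight)  # lightest first
    if not walkers:
        return walkers  # empty bins stay empty
    # Merge the two lightest walkers until the bin is no longer over-full
    while len(walkers) > target:
        (w1, s1), (w2, s2) = walkers.pop(0), walkers.pop(0)
        # Keep one state, chosen with probability proportional to its weight
        keep = s1 if rng.random() < w1 / (w1 + w2) else s2
        walkers.append((w1 + w2, keep))
        walkers.sort(key=by_weight)
    # Split the heaviest walker into two half-weight copies until the bin is full
    while len(walkers) < target:
        w, s = walkers.pop()
        walkers.extend([(w / 2, s), (w / 2, s)])
        walkers.sort(key=by_weight)
    return walkers


print(split_merge_bin([(0.4, 'x')], 5))
# → [(0.05, 'x'), (0.05, 'x'), (0.1, 'x'), (0.1, 'x'), (0.1, 'x')]
merged = split_merge_bin([(0.01, 'a'), (0.02, 'b'), (0.03, 'c'), (0.1, 'd'),
                          (0.2, 'e'), (0.3, 'f'), (0.34, 'g')], 5, rng=random.Random(0))
print(len(merged), round(sum(w for w, _ in merged), 10))  # → 5 1.0
```

Both operations conserve the total weight in the bin, which is what keeps the ensemble statistically unbiased.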

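Create a function that will run an OpenMM simulation. The function needs to take the current state of the system as its first argument, and return the final state at the end of the MD. Then use this function to initialise a `FunctionStepper`, as was done for the DWP example. Note that each walker's state stores only positions and box vectors, not velocities, so each MD segment starts from whatever velocities the shared `simulation` context holds from the previous segment.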

In [None]:
def OMMSim(state, simulation, nsteps):
    simulation.context.setPositions(state.getPositions())
    simulation.context.setPeriodicBoxVectors(*state.getPeriodicBoxVectors())
    simulation.step(nsteps)
    return simulation.context.getState(getPositions=True, enforcePeriodicBox=True)

stepper = FunctionStepper(OMMSim, simulation, 500)
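Each call of the stepper advances every walker by 500 steps of 0.002 ps, i.e. 1 ps of simulated time per WE cycle, so the MD cost of one cycle grows linearly with the number of walkers. The trivial helper below (hypothetical, not part of WElib) makes that arithmetic explicit; `nsteps=1000` corresponds to the 2 ps cycles used in the longer run analysed later.

```python
def md_time_per_cycle(n_walkers, nsteps=500, dt_ps=0.002):
    """Return (ps simulated per walker per cycle, total ps of MD per cycle)."""
    per_walker_ps = nsteps * dt_ps
    return per_walker_ps, n_walkers * per_walker_ps


print(md_time_per_cycle(5))               # → (1.0, 5.0)
print(md_time_per_cycle(5, nsteps=1000))  # → (2.0, 10.0)
```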

Now we apply the stepper. Note that this will take considerably longer to run than the DWP example; exactly how long depends on the power of your laptop or workstation:

In [None]:
start_time = time.time()
new_walkers = stepper.run(walkers) # this is where the MD happens
end_time = time.time()
print(f'{len(walkers)} simulations completed in {end_time-start_time:6.1f} seconds')

Let's see where those MD steps have moved each walker to:

In [None]:
new_walkers = progress_coordinator.run(new_walkers)
new_walkers = binner.run(new_walkers)
new_walkers = recycler.run(new_walkers)
print('recycled flux = ', recycler.flux)
for w in new_walkers:
    print(w)

Apply the SplitMerger to the list of walkers:

In [None]:
new_walkers = splitmerger.run(new_walkers)
for w in new_walkers:
    print(w)

### Part 3: Iterating the WE workflow
OK, all the components are in place; they have been tested individually and seem to be behaving. Time to run a few cycles:

In [None]:
n_cycles=10
print(' cycle    n_walkers   left-most bin  right-most bin   flux')
for i in range(n_cycles):
    new_walkers = stepper.run(new_walkers)
    new_walkers = progress_coordinator.run(new_walkers)
    new_walkers = binner.run(new_walkers)
    new_walkers = recycler.run(new_walkers)
    if recycler.flux > 0.0:
        new_walkers = progress_coordinator.run(new_walkers)
        new_walkers = binner.run(new_walkers)
    new_walkers = splitmerger.run(new_walkers)
    occupied_bins = list(binner.bin_weights.keys())
    print(f' {i:3d} {len(new_walkers):10d} {min(occupied_bins):12d} {max(occupied_bins):14d} {recycler.flux:20.8f}')

The take-home message should be fairly obvious: even on a system as small as this, you can't really run WE simulations interactively, because they need time and considerable compute resources. We'll come to how you can move from Jupyter Notebooks to HPC systems in a later part of the workshop. For now, let's see how we can get some useful and interesting data from the stepper's `Recorder`.

#### Generating trajectory files
You can use MDTraj to write out the path taken so far by any of the walkers in the form of a trajectory that could be visualised. 

The `replay` method of the recorder outputs a list of the states visited by the given walker. Being OpenMM states, these have a `getPositions` method that can be used to extract atom coordinates, which can then be used to generate an MDTraj `trajectory`:

In [None]:
chosen_walker = 0
top = mdt.load_topology('nacl.parm7')
xyz = []
for s in stepper.recorder.replay(new_walkers[chosen_walker]):
    xyz.append(s.getPositions(asNumpy=True) / unit.nanometer)
t = mdt.Trajectory(xyz, top)
print(t)
t.save(f'walker_trajectory_{chosen_walker}.nc')

#### Plotting progress coordinates
You can plot the history of the progress coordinate of a walker. Create a list of fresh walkers, each initialised with one of the states visited by the chosen walker, pass this list through the `ProgressCoordinator` to add progress-coordinate data, and then plot it:

In [None]:
from matplotlib import pyplot as plt
%matplotlib inline
chosen_walker = -1
walker_list = [Walker(state, 1.0) for state in stepper.recorder.replay(new_walkers[chosen_walker])]
walker_list = progress_coordinator.run(walker_list)
plt.plot([w.pc for w in walker_list])
plt.xlabel('WE cycle #')
plt.ylabel('progress coordinate')

### Analysis of a longer simulation

We have provided you with the log file `nacl.log`, obtained when this simulation was run for 500 cycles (each cycle being 2 ps, rather than 1 ps as above).

In [None]:
# Extract data from the log file. Get:
#
# n_walkers: the number of walkers each cycle
# flux: the recycled flux, each cycle
# bin_weights: a dictionary with the cumulative weight of simulation in each bin
#
with open('nacl.log') as f:
    data = f.readlines()

n_walkers = []
flux = []
for d in data[1:-1]:
    w = d.split()
    n_walkers.append(int(w[1]))
    flux.append(float(w[4]))

n_walkers = np.array(n_walkers)
flux = np.array(flux)
bin_weights = eval(data[-1])

# normalise bin weights, ordered by bin index rather than dictionary insertion order:
mean_weights = np.array([bin_weights[k] for k in sorted(bin_weights)], dtype=float)
mean_weights /= mean_weights.sum()
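If you want something more defensive than `eval`, the same quantities can be read with `ast.literal_eval`, which only accepts Python literals, with the bin weights explicitly ordered by bin index. The sketch below assumes the log layout implied by the cell above: one header line, one row per cycle with the walker count in column 1 and the flux in column 4, and a final line holding a dictionary literal of bin weights keyed by integer bin index. `literal_eval` will reject that last line if it contains anything other than plain literals (for example NumPy scalar reprs such as `np.float64(...)`). `parse_we_log` is a hypothetical helper, not part of WElib.

```python
import ast

import numpy as np


def parse_we_log(lines):
    """Parse WE log lines into (n_walkers, flux, normalised bin weights).

    Assumed layout: header line, one whitespace-separated row per cycle
    (column 1 = number of walkers, column 4 = recycled flux), and a final
    line containing a dict literal {bin_index: cumulative_weight}.
    """
    rows = [line.split() for line in lines[1:-1] if line.strip()]
    n_walkers = np.array([int(r[1]) for r in rows])
    flux = np.array([float(r[4]) for r in rows])
    bin_weights = ast.literal_eval(lines[-1].strip())
    # Order weights by bin index, not by the order in which bins were first visited
    weights = np.array([bin_weights[k] for k in sorted(bin_weights)], dtype=float)
    return n_walkers, flux, weights / weights.sum()


# Usage with the provided log file:
# with open('nacl.log') as f:
#     n_walkers, flux, mean_weights = parse_we_log(f.readlines())
```

Bins that were never visited are simply absent from the dictionary, so array position and bin index only coincide if every bin up to the last one was visited at some point.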

Plot key data:

In [None]:
from matplotlib import pyplot as plt
%matplotlib inline

plt.figure(figsize=(10,8))
plt.subplot(221)
plt.plot(flux)
plt.xlabel('cycle #')
plt.ylabel('flux')
plt.subplot(222)
plt.plot(n_walkers)
plt.xlabel('cycle #')
plt.ylabel('n_walkers')
plt.subplot(223)
plt.plot(mean_weights)
plt.xlabel('bin #')
plt.ylabel('relative weight')
print(f'mean flux = {flux[30:].mean():6.4g}')
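The cell above averages the flux only from cycle 30 onwards, discarding the initial transient before any walkers reach the bound state. Because consecutive cycles are correlated, a naive standard error of that mean would be too optimistic; block averaging is a common remedy. The sketch below is a minimal version: the 30-cycle cut-off is taken from the cell above, while the default of 10 blocks is an arbitrary illustrative choice, and `mean_flux_with_error` is a hypothetical helper.

```python
import numpy as np


def mean_flux_with_error(flux, discard=30, n_blocks=10):
    """Mean flux after discarding `discard` cycles, with a block-averaged standard error."""
    data = np.asarray(flux, dtype=float)[discard:]
    block_len = len(data) // n_blocks
    if block_len == 0:
        raise ValueError('not enough cycles for the requested number of blocks')
    # Drop any remainder so that every block has the same length
    blocks = data[:block_len * n_blocks].reshape(n_blocks, block_len).mean(axis=1)
    sem = blocks.std(ddof=1) / np.sqrt(n_blocks)
    return data.mean(), sem


# Usage with the arrays read from nacl.log:
# mean, err = mean_flux_with_error(flux)
# print(f'mean flux = {mean:6.4g} +/- {err:6.4g}')
```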

The erratic pattern of flux recycling, and the rapid increase and then plateau in the number of walkers each cycle, are apparent. The majority of the simulation weight remains in the last bin (Na-Cl separation > 1.5 nm). To calculate the association rate from the flux, we need to decide where the boundary between the unassociated and associated states lies and, because this is an association rate constant with units of 1/(time\*concentration), apply a volume correction.

Zooming in a bit on the weights data reveals a 'kink' in the profile that is a fair guide to where the transition state probably is (in a case like this there is no need to be super-exact). It suggests we can regard the first 10 bins as being on the associated side of the barrier, so the rest count towards the unassociated concentration:

In [None]:
plt.plot(mean_weights[:18])
plt.xlabel('bin #')
plt.ylabel('relative weight')
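Another way to look for the barrier, swapped in here and not part of the original workflow, is to turn the bin weights into a free-energy profile, F_i = -kT ln(p_i / Δr_i), where Δr_i is the width of bin i; dividing by the width matters because these bins are unequal. The sketch omits the 4πr² (Jacobian) correction that a proper potential of mean force along a distance would include, and it assumes the weights passed in correspond, in order, to the finite intervals of the bin-edge list. `free_energy_profile` and `KT_300K` are hypothetical names.

```python
import numpy as np

KT_300K = 2.494  # kJ/mol, approximately R * 300 K


def free_energy_profile(weights, edges, kT=KT_300K):
    """Free energy (kJ/mol) per bin from bin weights, relative to the lowest bin.

    weights: probability in each finite bin, len(edges) - 1 entries
    edges: bin edges in nm; bin i spans [edges[i], edges[i+1])
    """
    weights = np.asarray(weights, dtype=float)
    widths = np.diff(np.asarray(edges, dtype=float))
    density = weights / widths  # probability per nm
    with np.errstate(divide='ignore'):
        f = -kT * np.log(density)  # empty bins give +inf
    return f - f[np.isfinite(f)].min()


# Usage, assuming the first 21 entries of mean_weights map onto the 21 finite
# intervals defined by the 22 edges passed to StaticBinner:
# plt.plot(free_energy_profile(mean_weights[:21], edges))
```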

Now the volume correction. The maths below calculates this for a triclinic periodic cell:

In [None]:
# Box vectors as a 3x3 array in nm; rows are the a, b and c cell vectors
bv = np.array(inpcrd.boxVectors.value_in_unit(unit.nanometer))
a, b, c = np.linalg.norm(bv, axis=1)  # unit cell vector lengths
cosalpha = np.dot(bv[1], bv[2]) / (b * c)  # angle between b and c
cosbeta = np.dot(bv[0], bv[2]) / (a * c)   # angle between a and c
cosgamma = np.dot(bv[0], bv[1]) / (a * b)  # angle between a and b
volume = a * b * c * np.sqrt(1 - cosalpha**2 - cosbeta**2 - cosgamma**2
                             + 2 * cosalpha * cosbeta * cosgamma)
print(f'unit cell volume = {volume:6.4g} nm**3')
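For any cell, triclinic or not, the volume is also the absolute value of the determinant of the 3x3 matrix whose rows are the box vectors, which gives an independent check on the angle-based formula. The sketch below compares the two for an arbitrary, made-up triclinic cell written in OpenMM's reduced (lower-triangular) form; both function names are hypothetical.

```python
import numpy as np


def volume_from_lengths_angles(box):
    """Cell volume from the vector lengths and the angles between them."""
    box = np.asarray(box, dtype=float)
    a, b, c = np.linalg.norm(box, axis=1)
    cos_alpha = box[1] @ box[2] / (b * c)
    cos_beta = box[0] @ box[2] / (a * c)
    cos_gamma = box[0] @ box[1] / (a * b)
    return a * b * c * np.sqrt(1 - cos_alpha**2 - cos_beta**2 - cos_gamma**2
                               + 2 * cos_alpha * cos_beta * cos_gamma)


def volume_from_determinant(box):
    """Cell volume as |det| of the matrix whose rows are the box vectors."""
    return abs(np.linalg.det(np.asarray(box, dtype=float)))


# An arbitrary triclinic cell (nm) in reduced, lower-triangular form
box = [[3.0, 0.0, 0.0],
       [1.0, 2.8, 0.0],
       [-0.5, 0.9, 2.6]]
print(volume_from_lengths_angles(box), volume_from_determinant(box))
```

For a lower-triangular box the determinant is just the product of the diagonal, here 3.0 × 2.8 × 2.6 nm³.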

In [None]:
boundary_bin = 10 # boundary between what's considered "associated" and "dissociated"
w_u = mean_weights[boundary_bin:].sum() / mean_weights.sum()
print(f'unbound weight = {w_u:6.4g}')

NA = 6.022e+23
nm3_to_dm3 = 1e-24
time_step_to_seconds = 1 / 5e11 # the WE simulations are 2 ps per cycle
concentration = w_u / (volume * NA * nm3_to_dm3)
print(f'concentration of unassociated ion = {concentration:6.4g} M')

# use the steady-state flux, skipping the first 30 cycles as for the mean flux printed above
k_assoc = flux[30:].mean() / (concentration * time_step_to_seconds)
print(f'Association rate constant = {k_assoc:6.4g} / M.second')
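Written out, the cell above computes the following, where $w_u$ is the unbound weight, $V$ the cell volume, $N_A$ Avogadro's number, $\langle F \rangle$ the mean recycled flux per cycle and $\tau$ the length of one WE cycle in seconds:

$$
c_u = \frac{w_u}{V N_A}, \qquad
k_{\mathrm{assoc}} = \frac{\langle F \rangle}{\tau \, c_u}
$$

With $V$ converted to dm³ (the `nm3_to_dm3` factor), $c_u$ is in mol dm⁻³ (M), and dividing the per-cycle flux by $\tau$ gives $k_{\mathrm{assoc}}$ in M⁻¹ s⁻¹.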

The result is quite close to the diffusion limit for bimolecular association in water (about 7e9 /M.second; see [here](https://en.wikipedia.org/wiki/Diffusion-controlled_reaction)).

### Experiments to try:

What happens to the predicted association rate constant if you move the division between bound and unbound states to a different bin boundary?

You will find a restart file for the "bound" state of the NaCl system in this directory. Try to construct a WE workflow to predict the unbinding rate.