In [1]:
#SBATCH -p gpu --exclusive --time 24:00:00
import sys
sys.path.append(".")
DEVICE = "cpu"
#DEVICE = "cuda:0"

# Learn the All-Atom Force Field Parameters from the Trajectory

We try to learn force field parameters that reproduce the trajectory. Therefore, short chunks of the trajectory are simulated. The simulation engine serves as a propagator $P(x,\Delta t,\theta)$ with the force field parameters $\theta.$ For consecutive snapshots of the trajectory $x(t_i)$ and $x(t_i + \Delta t),$ we minimize the loss function

$$ 
F(\theta) = \frac{1}{N} \sum_{i} \| P(x(t_i),\Delta t, \theta) - x(t_i + \Delta t) \|^{2},
$$

using a stochastic optimizer. $N$ is the number of snapshots in the trajectory (minus 1).

In [2]:
import datetime
import copy
import numpy as np
import torch
from tqdm import tqdm

from waterbox import WaterBox

Set seeds to ensure reproducibility.

In [3]:
torch.manual_seed(1)
np.random.seed(1)

In [4]:
batch_size = 1 
n_epochs = 4
n_iter = 10  ## steps between frames
waterbox = WaterBox(batch_size, device=DEVICE)

## Load Trajectory

We load the trajectory, remove equilibration, split it into pairs, and permute the pairs randomly.

In [5]:
traj = np.load("xyz_vel.npy")
permutation = torch.randperm(len(traj)-1)

xyz = torch.tensor(traj[:,0,...], device=waterbox.device)
xyz = torch.stack([xyz[:-1], xyz[1:]], axis=1)[permutation]
vel = torch.tensor(traj[:,1,...], device=waterbox.device)
vel = torch.stack([vel[:-1], vel[1:]], axis=1)[permutation]

boxtensor = torch.tensor(waterbox.mol.box, device=waterbox.device)

## Default Force Field Parameters

In [6]:
defaults = {
    "bond_length": waterbox.ff.get_bond("OT", "HT")[1],
    "bond_k": waterbox.ff.get_bond("OT", "HT")[0],
    "angle": waterbox.ff.get_angle("HT", "OT", "HT")[1],
    "angle_k": waterbox.ff.get_angle("HT", "OT", "HT")[0],
    "charges": np.array([
        waterbox.ff.get_charge(at) 
        for at in ["OT", "CLA"] # ["OT", "HT", "CLA", "SOD"]
    ]),
    "sigma": np.array([
        waterbox.ff.get_LJ(at)[0] 
        for at in ["CLA", "HT", "OT", "SOD"]
    ]),
    "epsilon": np.array([
        waterbox.ff.get_LJ(at)[1] 
        for at in ["CLA", "HT", "OT", "SOD"]
    ]),
}

## Propagator

The propagator runs a few iterations of Langevin dynamics with modified force field parameters.
It is defined as a torch module with the force field parameters as the parameters.

In [7]:
from integrator import Integrator 

class WaterBoxPropagator(torch.nn.Module):
    """Propagator. By default, the parameters are 
    initialized with the true force field parameters.
    The initial parameters can be modified by 
    passing non-defaults to the constructor arguments.
    """
    def __init__(
        self, 
        waterbox, 
        timestep=1.0,
        langevin_gamma=0.,
        temperature=None,
        bond_length=defaults["bond_length"],
        bond_k=defaults["bond_k"],
        angle=defaults["angle"],
        angle_k=defaults["angle_k"],
        charges=defaults["charges"],
        sigma=defaults["sigma"],
        epsilon=defaults["epsilon"]
    ):
        super(WaterBoxPropagator, self).__init__()
        self.temperature = temperature 
        self.device = waterbox.device
        self.timestep = timestep
        self.waterbox = waterbox
        self.langevin_gamma = langevin_gamma
        self.bond_length = torch.nn.Parameter(
            torch.tensor(
                bond_length, 
                dtype=waterbox.dtype, 
                device=waterbox.device
            )
        )
        self.bond_k = torch.nn.Parameter(
            torch.tensor(
                bond_k, 
                dtype=waterbox.dtype, 
                device=waterbox.device
            )
        )
        self.angle = torch.nn.Parameter(
            torch.tensor(
                angle, 
                dtype=waterbox.dtype, 
                device=waterbox.device
            )
        )
        self.angle_k = torch.nn.Parameter(
            torch.tensor(
                angle_k, 
                dtype=waterbox.dtype, 
                device=waterbox.device
            )
        )
        self.charges = torch.nn.Parameter(
            torch.tensor(
                charges, 
                dtype=waterbox.dtype, 
                device=waterbox.device
            )
        )
        self.sigma = torch.nn.Parameter(
            torch.tensor(
                sigma, 
                dtype=waterbox.dtype, 
                device=waterbox.device
            )
        )
        self.epsilon = torch.nn.Parameter(
            torch.tensor(
                epsilon, 
                dtype=waterbox.dtype, 
                device=waterbox.device
            )
        )
 
    def forward(self, pos, vel, niter):
        #self._set_integrator_temperature(self.temperature)
        waterbox = copy.deepcopy(self.waterbox)
        #waterbox = WaterBox(batch_size, device="cpu")
        self._apply_ff_parameters(waterbox)
        integrator = Integrator(
            waterbox.system, 
            waterbox.forces, 
            timestep=self.timestep,
            device=waterbox.device,
            gamma=self.langevin_gamma, 
            T=self.temperature
        )
        waterbox.system.pos[:] = pos
        waterbox.system.vel[:] = vel
        integrator.step(niter=niter)
        return waterbox.system.pos, waterbox.system.vel
        
    def _apply_ff_parameters(self, waterbox):
        p = waterbox.forces.par
        p.bond_params[:] = self._make_bond_params(self.bond_length, self.bond_k)
        p.angle_params[:] = self._make_angle_params(self.angle, self.angle_k)
        p.charges[:] = self._make_charges(self.charges)
        p.A[:], p.B[:] = self._make_lj(self.sigma, self.epsilon)
    
    @staticmethod
    def _make_bond_params(length, k):
        params1 = torch.stack([k, length])
        params2 = torch.tensor([0.000, 1.5139], dtype=params1.dtype, device=params1.device)
        three_bonds = torch.stack([params1, params1, params2])
        return three_bonds.repeat((291//3,1))

    @staticmethod
    def _make_angle_params(angle, k):
        params = torch.stack([k, angle])
        return params.repeat((97,1))

    @staticmethod
    def _make_charges(charges): #q_oxygen, q_hydrogen, q_anion, q_cation):
        q_oxygen = charges[0]
        q_hydrogen = -charges[0]*0.5
        q_anion = charges[1]
        q_cation = -charges[1]
        params1 = torch.stack([q_oxygen, q_hydrogen, q_hydrogen]).repeat((97,))
        params2 = torch.stack([q_anion, q_cation], 0)
        return torch.cat([params1, params2])

    @staticmethod
    def _make_lj(sigma, epsilon):
        # Lorentz - Berthelot combination rule
        sigma_table = 0.5 * (sigma + sigma[:, None])
        eps_table = torch.sqrt(epsilon * epsilon[:, None])
        sigma_table_6 = sigma_table ** 6
        sigma_table_12 = sigma_table_6 * sigma_table_6
        A = eps_table * 4 * sigma_table_12
        B = eps_table * 4 * sigma_table_6
        return A, B        

In [8]:
class ParameterLogger:
    """Write parameters to a npz file during optimization."""
    def __init__(self, filename=None, defaults=defaults, flush_interval=10):
        default_filename = (
            datetime.datetime.now()
            .strftime("learn_%Y-%m-%d_%Hh%Mm%Ss.npz")
        )
        self.filename = default_filename if filename is None else filename
        self.defaults = defaults
        self.data = {
            key: [] for key in self.defaults
        }
        self.data["epoch"] = []
        self.data["it"] = []
        self.data["loss"] = []
        self.flush_interval = flush_interval
        self.i = 0
    
    def __call__(self, epoch, it, loss, propagator):
        self.i += 1
        for key in defaults:
            assert hasattr(propagator, key)
        for key in self.defaults:
            self.data[key].append(getattr(propagator, key).clone().detach().cpu().numpy())
        self.data["epoch"].append(epoch)
        self.data["it"].append(it)
        self.data["loss"].append(loss.item())
        if self.i % self.flush_interval == 0:
            self.flush()
    
    def flush(self):
        np.savez(self.filename, **self.data)
        

## Initialize Propagator with Modified Parameters

Modify charges.

In [9]:
modified = {key: defaults[key] for key in defaults}
modified["charges"] *= 0.01

Create the propagator with the modified parameters and the optimizer.

In [10]:
propagator = WaterBoxPropagator(waterbox, **modified)
optim = torch.optim.Adam([propagator.charges], lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=2, gamma=0.1)

Create the Logger.

In [11]:
logger = ParameterLogger()

Define the loss function as the mean squared distance between the propagated positions from the trajectory and the modified parameters.

In [12]:
from torchmd.forces import wrap_dist

def periodic_mse_loss(pos1, pos2, box):
    return (wrap_dist(pos1-pos2, box)**2).mean()

# Test:
periodic_mse_loss(xyz[0,0], xyz[0,1], boxtensor)

tensor(0.0088, dtype=torch.float64)

In [None]:
for epoch in range(n_epochs):
    print(f"Epoch {epoch}/{n_epochs}")
    for ibatch in tqdm(range(len(xyz) // batch_size - 1)):
        optim.zero_grad()
        start_xyz = xyz[batch_size*ibatch:batch_size*(ibatch+1), 0, ...].clone()
        end_xyz = xyz[batch_size*ibatch:batch_size*(ibatch+1), 1, ...].clone()
        start_vel = vel[batch_size*ibatch:batch_size*(ibatch+1), 0, ...].clone()
        new_xyz, new_vel = propagator(start_xyz, start_vel, niter=n_iter)
        loss = periodic_mse_loss(
            new_xyz, 
            end_xyz, 
            boxtensor
        )
        loss.backward()
        logger(epoch, ibatch, loss, propagator)
        optim.step()
    scheduler.step()

  0%|          | 0/188 [00:00<?, ?it/s]

Epoch 0/4


  2%|▏         | 3/188 [00:02<02:56,  1.05it/s]