## Assign parameters to molecules in the dataset

In [1]:
from openff.toolkit import Molecule, ForceField
import tqdm
import smee.converters
from pydantic import Field

In [2]:
# uncomment if reloading data
import datasets

dataset = datasets.Dataset.load_from_disk("test-smee-data")

In [3]:
# reformat dataset lists to torch tensors
dataset.set_format('torch', columns=['energy', 'coords','forces'], output_all_columns=True)

In [4]:
# this is what a single entry looks like
dataset[0]

{'coords': tensor([-0.5399,  2.1674, -3.7936,  ...,  0.1233,  3.3214,  2.1324]),
 'energy': tensor([-322860.4375, -322860.5000, -322860.6562, -322860.7812, -322860.8125,
         -322860.7500, -322860.6250, -322860.5000, -322860.6250, -322860.7500,
         -322860.8125, -322860.7812, -322860.6562, -322860.5000, -322860.4375,
         -322860.5000, -322860.6562, -322860.7812, -322860.8125, -322860.5000,
         -322860.4375, -322860.5000, -322860.6250, -322860.7500]),
 'forces': tensor([ 0.0021,  0.0088, -0.0046,  ..., -0.0018,  0.0060, -0.0051]),
 'smiles': '[H:14][C:7]([H:15])([H:16])[C:6]1=[C:8]([C:3](=[N:4][O:5]1)[O:2][C:1]([H:11])([H:12])[H:13])[C:9]([H:17])([H:18])[O:10][H:19]'}

Below we specify a starting force field.
Normally we would initialize parameters using the Modified Seminario method,
[example here](https://github.com/openforcefield/sage-2.2.1/blob/main/03_generate-initial-ff/create-msm-ff.py),
but here we just start from Sage 2.2.1.

[Josh's repo](https://github.com/jthorton/SPICE-SMEE/blob/main/fit-v1/training/001-expand_torsions.py)
has examples on expanding torsions too.

In [5]:
import torch
import smee

@smee.potentials.potential_energy_fn("TwoMinima", "twominima")
def compute_twominima_energy(
    system: smee.TensorSystem,
    potential: smee.TensorPotential,
    conformer: torch.Tensor,
) -> torch.Tensor:
    is_batched = conformer.ndim == 3

    if not is_batched:
        conformer = torch.unsqueeze(conformer, 0)

    parameters = potential.parameters
    attributes = potential.attributes

    k1 = parameters[:, 0]
    k2 = parameters[:, 1]
    periodicity = parameters[:, 2]
    phase = parameters[:, 3]

    central_atom = conformer[:, 0, :]
    bonded_atoms = conformer[:, 1:, :]
    normal_vector = torch.cross(bonded_atoms[:, 1] - bonded_atoms[:, 0], 
                                bonded_atoms[:, 2] - bonded_atoms[:, 0])
    normal_vector = normal_vector / torch.norm(normal_vector, dim=-1, keepdim=True)

    oop_vector = central_atom - bonded_atoms[:, 0]
    theta = torch.acos(torch.sum(oop_vector * normal_vector, dim=-1) / torch.norm(oop_vector, dim=-1))

    energy_1 = k1 * (1 + torch.cos(periodicity * theta - phase))
    energy_2 = k2 * (1 + torch.cos(2 * periodicity * theta + phase))
    
    energy = (energy_1 - energy_2).sum(-1)

    if not is_batched:
        energy = torch.squeeze(energy, 0)

    return energy


In [6]:
from smee.converters.openff.valence import convert_valence_handlers
from smirnoff_plugins.collections.valence import SMIRNOFFTwoMinimaCollection
import openff

def convert_twominima(
    handlers: list[SMIRNOFFTwoMinimaCollection],
    topologies: list[openff.toolkit.Topology],
) -> tuple[smee.TensorPotential, list[smee.ValenceParameterMap]]:
    return convert_valence_handlers(handlers, topologies)

In [7]:
from openff.toolkit import ForceField
import openff.interchange as interchange

twominima_force_field = ForceField("two_minima.offxml", load_plugins = True)

molecule = openff.toolkit.Molecule.from_smiles("c1n(CCO)c(C(F)(F)(F))cc1CNCCl")
molecule.generate_conformers(n_conformers=1)

conformer = torch.tensor(molecule.conformers[0].m_as(openff.units.unit.angstrom))

twominima_interchange = interchange.Interchange.from_smirnoff(
    twominima_force_field, molecule.to_topology()
)

twominima_interchange

Interchange with 8 collections, non-periodic topology with 28 atoms.

In [8]:
handler_types = {
        handler_type
        for interchange in [twominima_interchange]
        for handler_type in interchange.collections
    }

In [9]:
handler_types

{'Angles',
 'Bonds',
 'Constraints',
 'Electrostatics',
 'ImproperTorsions',
 'ProperTorsions',
 'TwoMinima',
 'vdW'}

In [10]:
from openff.toolkit import ForceField, Molecule, Topology

topology = Topology.from_molecules([molecule])

molecule_force_list = twominima_force_field.label_molecules(topology)

for mol_idx, mol_forces in enumerate(molecule_force_list):
    print(f"Forces for molecule {mol_idx}")
    for force_tag, force_dict in mol_forces.items():
        print(f"\n{force_tag}:")
        for atom_indices, parameter in force_dict.items():
            atomstr = ""
            for idx in atom_indices:
                atomstr += f"{idx:>3}"
            print(
                f"atoms: {atomstr}  parameter_id: {parameter.id}  smirks {parameter.smirks}"
            )



Forces for molecule 0

Constraints:

Bonds:
atoms:   0  1  parameter_id: b8  smirks [#6X3:1]-[#7X3:2]
atoms:   0 11  parameter_id: b6  smirks [#6X3:1]=[#6X3:2]
atoms:   0 16  parameter_id: b85  smirks [#6X3:1]-[#1:2]
atoms:   1  2  parameter_id: b7  smirks [#6:1]-[#7:2]
atoms:   1  5  parameter_id: b8  smirks [#6X3:1]-[#7X3:2]
atoms:   2  3  parameter_id: b1  smirks [#6X4:1]-[#6X4:2]
atoms:   2 17  parameter_id: b84  smirks [#6X4:1]-[#1:2]
atoms:   2 18  parameter_id: b84  smirks [#6X4:1]-[#1:2]
atoms:   3  4  parameter_id: b14  smirks [#6:1]-[#8:2]
atoms:   3 19  parameter_id: b84  smirks [#6X4:1]-[#1:2]
atoms:   3 20  parameter_id: b84  smirks [#6X4:1]-[#1:2]
atoms:   4 21  parameter_id: b88  smirks [#8:1]-[#1:2]
atoms:   5  6  parameter_id: b2  smirks [#6X4:1]-[#6X3:2]
atoms:   5 10  parameter_id: b6  smirks [#6X3:1]=[#6X3:2]
atoms:   6  7  parameter_id: b69  smirks [#6X4:1]-[#9:2]
atoms:   6  8  parameter_id: b69  smirks [#6X4:1]-[#9:2]
atoms:   6  9  parameter_id: b69  smirks [#6X

In [15]:
twominima_tensor_ff, [twominima_topology] = smee.converters.convert_interchange(twominima_interchange)


KeyError: 'TwoMinima'

In [None]:
all_smiles = []
interchanges = []
for entry in tqdm.tqdm(dataset):
    mol = Molecule.from_mapped_smiles(
        entry["smiles"],
        allow_undefined_stereo=True
    )
    all_smiles.append(entry["smiles"])
    interchange = twominima_force_field.create_interchange(mol.to_topology())
    interchanges.append(interchange)
    
smee_force_field, smee_topologies = smee.converters.convert_interchange(interchanges)
topologies = dict(zip(all_smiles, smee_topologies))

## Fit

Now we can set up and run the fit.

In [None]:
import descent.train
import descent.targets.energy

import math
import pathlib
import tensorboardX
import more_itertools


In [None]:
# specify which parameters to train
# and some details about them
# they're scaled so they're roughly on the same order of magnitude

parameters = {
    "Bonds": descent.train.ParameterConfig(
        cols=["k", "length"],
        scales={"k": 1e-2, "length": 1.0}, # normalize so roughly equal
        limits={"k":[0.0, None], "length": [0.0, None]}
        # the include/exclude types are Interchange PotentialKey.id's -- typically SMIRKS
        # include=[], <-- bonds to train. Not specifying trains all
        # exclude=[], <-- bonds NOT to train
    ),
    "Angles": descent.train.ParameterConfig(
        cols=["k", "angle"],
        scales={"k": 1e-2, "angle": 1.0},
        limits={"k": [0.0, None], "angle": [0.0, math.pi]}
    ),
    "ProperTorsions": descent.train.ParameterConfig(
        # fit ks
        cols=["k"],
        scales={"k": 1.0},
    ),
    "TwoMinima": descent.train.ParameterConfig(
        cols=["k"],
        scales={"k": 1.0},
        
    )

}

In [None]:
trainable = descent.train.Trainable(
    force_field=smee_force_field,
    parameters=parameters,
    attributes={}
)

In [None]:
print(twominima_force_field.registered_parameter_handlers)

In [None]:
# optional below if you want cool tensorboard logging
def write_metrics(
        epoch: int,
        loss: torch.Tensor,
        loss_energy: torch.Tensor,
        loss_forces: torch.Tensor,
        writer: tensorboardX.SummaryWriter
):
    print(f"epoch={epoch} loss={loss.detach().item():.6f}", flush=True)

    writer.add_scalar("loss", loss.detach().item(), epoch)
    writer.add_scalar("loss_energy", loss_energy.detach().item(), epoch)
    writer.add_scalar("loss_forces", loss_forces.detach().item(), epoch)

    writer.add_scalar("rmse_energy", math.sqrt(loss_energy.detach().item()), epoch)
    writer.add_scalar("rmse_forces", math.sqrt(loss_forces.detach().item()), epoch)
    writer.flush()

Specify some hyperparameters, n_epochs is intentionally very low to guarantee fast execution.

In [None]:
N_EPOCHS = 10
LEARNING_RATE = 0.01
BATCH_SIZE = 500

In [None]:
# make directory to save files in
directory = pathlib.Path("my-smee-fit")
directory.mkdir(exist_ok=True, parents=True)


Run fit below.

In [None]:
# load tensorboard extension so we can view in notebook
%load_ext tensorboard

In [None]:
trainable_parameters = trainable.to_values()
device = trainable_parameters.device.type

torch.autograd.set_detect_anomaly(True)

with tensorboardX.SummaryWriter(str(directory)) as writer:
    optimizer = torch.optim.Adam([trainable_parameters], lr=LEARNING_RATE, amsgrad=True)
    dataset_indices = list(range(len(dataset)))

    for i in range(N_EPOCHS):
        ff = trainable.to_force_field(trainable_parameters)
        total_loss = torch.zeros(size=(1,), device=device)
        energy_loss = torch.zeros(size=(1,), device=device)
        force_loss = torch.zeros(size=(1,), device=device)
        grad = None
    
        for batch_ids in tqdm.tqdm(
            more_itertools.batched(dataset_indices, BATCH_SIZE),
            desc='Calculating energies',
            ncols=80, total=math.ceil(len(dataset) / BATCH_SIZE)
        ):
            batch = dataset.select(indices=batch_ids)
            true_batch_size = len(dataset)
            batch_configs = sum([len(d["energy"]) for d in batch])

            e_ref, e_pred, f_ref, f_pred = descent.targets.energy.predict(
                batch, ff, topologies, "mean"
            )   
            # L2 loss
            batch_loss_energy = ((e_pred - e_ref) ** 2).sum() / true_batch_size
            batch_loss_force = ((f_pred - f_ref) ** 2).sum() / true_batch_size

            # Equal sum of L2 loss on energies and forces
            batch_loss = batch_loss_energy + batch_loss_force

            (batch_grad, ) = torch.autograd.grad(batch_loss, trainable_parameters, create_graph=True)
            batch_grad = batch_grad.detach()
            if grad is None:
                grad = batch_grad
            else:
                grad += batch_grad
            
            # keep sum of squares to report MSE at the end
            total_loss += batch_loss.detach()
            energy_loss += batch_loss_energy.detach()
            force_loss += batch_loss_force.detach()
        
        trainable_parameters.grad = grad
        
        write_metrics(
            epoch=i, loss=total_loss, loss_energy=energy_loss,
            loss_forces=force_loss, writer=writer
        )

        optimizer.step()
        optimizer.zero_grad()

        if i % 10 == 0:
            torch.save(
                trainable.to_force_field(trainable_parameters),
                directory / f"force-field-epoch-{i}.pt"
            )

    torch.save(
        trainable.to_force_field(trainable_parameters),
        directory / "final-force-field.pt"
    )
    

Metrics can be viewed in tensorboard below.

`tensorboard --logdir my-smee-fit` can also be run on command line instead of in the notebook.

In [None]:
%tensorboard --logdir my-smee-fit

## Convert back to OFFXML

In [None]:
from collections import defaultdict

for potential in smee_force_field.potentials:
    handler_name = potential.parameter_keys[0].associated_handler

    parameter_attrs = potential.parameter_cols
    parameter_units = potential.parameter_units

    if handler_name in ["Bonds", "Angles"]:
        handler = twominima_force_field.get_parameter_handler(handler_name)
        for i, opt_parameters in enumerate(potential.parameters):
            smirks = potential.parameter_keys[i].id
            ff_parameter = handler[smirks]
            opt_parameters = opt_parameters.detach().cpu().numpy()
            for j, (p, unit) in enumerate(zip(parameter_attrs, parameter_units)):
                setattr(ff_parameter, p, opt_parameters[j] * unit)

    elif handler_name in ["ProperTorsions"]:
        handler = twominima_force_field.get_parameter_handler(handler_name)
        k_index = parameter_attrs.index('k')
        p_index = parameter_attrs.index('periodicity')
        # we need to collect the k values into a list across the entries
        collection_data = defaultdict(dict)
        for i, opt_parameters in enumerate(potential.parameters):
            smirks = potential.parameter_keys[i].id
            ff_parameter = handler[smirks]
            opt_parameters = opt_parameters.detach().cpu().numpy()
            # find k and the periodicity
            k = opt_parameters[k_index] * parameter_units[k_index]
            p = int(opt_parameters[p_index])
            collection_data[smirks][p] = k
        # now update the force field
        for smirks, k_s in collection_data.items():
            ff_parameter = handler[smirks]
            k_mapped_to_p = [k_s[p] for p in ff_parameter.periodicity]
            ff_parameter.k = k_mapped_to_p

    elif handler_name in ["ImproperTorsions"]:
        k_index = parameter_attrs.index('k')
        handler = twominima_force_field.get_parameter_handler(handler_name)
        # we only fit the v2 terms for improper torsions so convert to list and set
        for i, opt_parameters in enumerate(potential.parameters):
            smirks = potential.parameter_keys[i].id
            ff_parameter = handler[smirks]
            opt_parameters = opt_parameters.detach().cpu().numpy()
            ff_parameter.k = [opt_parameters[k_index] * parameter_units[k_index]]

    elif handler_name == "TwoMinima":
        handler = twominima_force_field.get_parameter_handler(handler_name)
        param_indices = {
            name: i for i, name in enumerate(parameter_attrs)
        }

        for i, opt_parameters in enumerate(potential.parameters):
            smirks = potential.parameter_keys[i].id
            ff_parameter = handler[smirks]
            opt_parameters = opt_parameters.detach().cpu().numpy()

            for param_name, unit in zip(parameter_attrs, parameter_units):
                value = opt_parameters[param_indices[param_name]] * unit
                setattr(ff_parameter, param_name, value)


twominima_force_field.to_file("two_minima_final-force-field.offxml")