# Training Against QM Energies and Gradients

This notebook aims to show how the [`descent`](https://github.com/SimonBoothroyd/descent) framework, in combination with
[`smirnoffee`](https://github.com/SimonBoothroyd/smirnoffee), can be used to train a set of SMIRNOFF force field bond and
angle force constant parameters against the QM-computed energies and associated gradients of a small molecule in
multiple conformers.

For the sake of clarity, all warnings will be disabled:

```python
import warnings
warnings.filterwarnings('ignore')
import logging
logging.getLogger("openff.toolkit").setLevel(logging.ERROR)
```

### Retrieving the QM training set

For this example we will be training against QM energies that have been computed and stored by
[QCArchive](https://qcarchive.molssi.org/), and which can easily be retrieved using the [OpenFF QCSubmit](https://github.com/openforcefield/openff-qcsubmit)
package.

We begin by retrieving the records associated with the `OpenFF Optimization Set 1` optimization data set:

```python
from qcportal import FractalClient

from openff.qcsubmit.results import OptimizationResultCollection

result_collection = OptimizationResultCollection.from_server(
    client=FractalClient(),
    datasets="OpenFF Optimization Set 1"
)
```

which we then filter to retain only a single small molecule that will be quick to train on as a demonstration:

```python
from openff.qcsubmit.results.filters import ConformerRMSDFilter, SMILESFilter

result_collection = result_collection.filter(
    SMILESFilter(smiles_to_include=["CC(=O)NCC1=NC=CN1C"]),
    # Retain at most 10 conformers, treating any two conformers within
    # 0.5 Å RMSD of each other as duplicates.
    ConformerRMSDFilter(max_conformers=10, rmsd_tolerance=0.5)
)

print(f"N Molecules:  {result_collection.n_molecules}")
print(f"N Conformers: {result_collection.n_results}")
```

```
N Molecules:  1
N Conformers: 3
```


You should see that our filtered collection contains 3 results. These correspond to 3 minimized conformers (and
their associated energies computed using the OpenFF default B3LYP-D3BJ spec) of the molecule we filtered for above.
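The `ConformerRMSDFilter` used above keeps up to `max_conformers` conformers per molecule, provided they are mutually distinct by more than `rmsd_tolerance`. The sketch below is a simplified illustration of that idea using a greedy selection. It is not the QCSubmit implementation. It assumes the conformers are already aligned and skips the superposition and symmetry handling that a real RMSD filter needs. The helper `select_distinct_conformers` is hypothetical.

```python
import math
from typing import List, Sequence, Tuple

Coordinates = Sequence[Tuple[float, float, float]]


def rmsd(a: Coordinates, b: Coordinates) -> float:
    """RMSD between two equally sized conformers that are assumed to be pre-aligned."""
    squared = [
        (xa - xb) ** 2 + (ya - yb) ** 2 + (za - zb) ** 2
        for (xa, ya, za), (xb, yb, zb) in zip(a, b)
    ]
    return math.sqrt(sum(squared) / len(squared))


def select_distinct_conformers(
    conformers: Sequence[Coordinates], max_conformers: int, rmsd_tolerance: float
) -> List[int]:
    """Hypothetical greedy filter: keep a conformer only if its RMSD to every
    previously kept conformer exceeds `rmsd_tolerance`, and stop once
    `max_conformers` conformers have been kept. Returns the kept indices."""
    kept: List[int] = []
    for index, conformer in enumerate(conformers):
        if len(kept) >= max_conformers:
            break
        if all(rmsd(conformer, conformers[j]) > rmsd_tolerance for j in kept):
            kept.append(index)
    return kept


# Three made-up diatomic "conformers" (coordinates in Å).
conformers = [
    [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)],
    [(0.0, 0.0, 0.0), (1.1, 0.0, 0.0)],  # ~0.07 Å from the first: a duplicate
    [(0.0, 0.0, 0.0), (0.0, 2.0, 0.0)],  # clearly distinct
]
print(select_distinct_conformers(conformers, max_conformers=10, rmsd_tolerance=0.5))  # [0, 2]
```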

### Defining the objective (loss) function

For this example we will be training our force field parameters against:

* the energy of each conformer relative to the first conformer of the molecule
* the deviations between the QM and MM gradients projected along the redundant internal coordinates (RICs) of
  the molecule.
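The energy term works as follows. Subtracting the first conformer's energy from every conformer removes the arbitrary constant offset between the QM and MM energy scales. A mean squared error then compares the resulting relative energies. Below is a minimal plain-Python sketch of that calculation. It assumes `transforms.relative(index=0)` and `metrics.mse()` behave as their names suggest. The energies are made up, and descent itself operates on PyTorch tensors rather than lists.

```python
from typing import List, Sequence


def relative_to(energies: Sequence[float], index: int = 0) -> List[float]:
    """Express each energy relative to the conformer at position `index`."""
    reference = energies[index]
    return [energy - reference for energy in energies]


def mean_squared_error(predicted: Sequence[float], target: Sequence[float]) -> float:
    """Average of the squared differences between predicted and target values."""
    return sum((p - t) ** 2 for p, t in zip(predicted, target)) / len(target)


# Made-up absolute energies (kcal/mol) for three conformers; the large constant
# offset between the two models disappears once the energies are made relative.
qm_energies = [-1000.0, -998.0, -995.0]
mm_energies = [10.0, 12.5, 15.5]

qm_relative = relative_to(qm_energies)  # [0.0, 2.0, 5.0]
mm_relative = relative_to(mm_energies)  # [0.0, 2.5, 5.5]

energy_loss = mean_squared_error(mm_relative, qm_relative)
print(energy_loss)  # ≈ 0.1667
```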

The construction of such a loss function is made trivial by the built-in ``EnergyObjective`` class, which can be
created directly from the collection of optimization results we retrieved above.

We first load the initial force field parameters ($\theta$) using the [OpenFF Toolkit](https://github.com/openforcefield/openff-toolkit):

```python
from openff.toolkit.typing.engines.smirnoff import ForceField
initial_force_field = ForceField("openff_unconstrained-1.0.0.offxml")
```

which we can then use to construct our contribution objects:

```python
from descent import metrics, transforms
from descent.objectives.energy import EnergyObjective

objective_contributions = EnergyObjective.from_optimization_results(
    result_collection,
    initial_force_field,
    # State that we want to include energies and gradients when computing
    # the contribution to the loss function.
    include_energies=True,
    include_gradients=True,
    # Specify that we want to use energies relative to the first conformer
    # when evaluating the loss function.
    energy_transforms=transforms.relative(index=0),
    # Use the built-in MSE metric when comparing the MM and QM relative
    # energies.
    energy_metric=metrics.mse(),
    # For this example we will use the QM and MM gradients directly when
    # computing the loss function.
    gradient_transforms=transforms.identity(),
    # Use the built-in MSE metric when comparing the MM and QM gradients.
    gradient_metric=metrics.mse(),
    # State that we want to project the gradients along the RICs.
    gradient_coordinate_system="ric"
)
```
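Setting `gradient_coordinate_system="ric"` compares the gradients in internal coordinates instead of Cartesian ones. A common way to convert a Cartesian gradient $g_x$ into an internal-coordinate gradient is $g_q = (B B^{\mathsf{T}})^{+} B g_x$. Here $B$ is the Wilson B matrix of derivatives $\partial q / \partial x$, and $^{+}$ denotes a pseudo-inverse, which is needed because redundant coordinates make $B B^{\mathsf{T}}$ singular. The NumPy sketch below applies this to a single bond stretch. It illustrates the general technique and does not necessarily match descent's implementation. The helper names are hypothetical.

```python
import numpy as np


def bond_b_matrix(coordinates: np.ndarray) -> np.ndarray:
    """Wilson B matrix (1 x 6) for the single bond length of a diatomic."""
    displacement = coordinates[0] - coordinates[1]
    unit = displacement / np.linalg.norm(displacement)
    # dr/dx_1 = +unit and dr/dx_2 = -unit
    return np.concatenate([unit, -unit])[np.newaxis, :]


def project_gradient(b_matrix: np.ndarray, cartesian_gradient: np.ndarray) -> np.ndarray:
    """Convert a flattened Cartesian gradient into internal-coordinate gradients."""
    g_matrix = b_matrix @ b_matrix.T
    return np.linalg.pinv(g_matrix) @ b_matrix @ cartesian_gradient


# Harmonic bond E = 0.5 * k * (r - r0)^2 with made-up parameters.
k, r0 = 500.0, 1.0
coordinates = np.array([[0.0, 0.0, 0.0], [1.1, 0.0, 0.0]])

b_matrix = bond_b_matrix(coordinates)
r = np.linalg.norm(coordinates[0] - coordinates[1])
# Chain rule: dE/dx = dE/dr * dr/dx
cartesian_gradient = k * (r - r0) * b_matrix[0]

internal_gradient = project_gradient(b_matrix, cartesian_gradient)
print(internal_gradient)  # recovers dE/dr = k * (r - r0), i.e. approximately 50
```

For a single bond the projection exactly recovers $\partial E / \partial r$. With many redundant coordinates, the pseudo-inverse distributes the Cartesian gradient over them in a least-squares sense.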

The returned `objective_contributions` will contain one objective object per unique molecule in the
`result_collection`:

```python
len(objective_contributions)
```

```
1
```

as we filtered our initial result collection to contain only a single molecule, we likewise have only a single
contribution.

### Specifying the parameters to train

For this example we will train all of the bond and angle force constants that were assigned to our molecule
of interest:

```python
parameter_delta_ids = sorted(
    {
        (handler_type, potential_key, attribute)
        for contribution in objective_contributions
        for handler_type, potential_key, attribute in contribution.parameter_ids
        if handler_type in ["Bonds", "Angles"] and attribute == "k"
    },
    key=lambda x: x[0],
    reverse=True
)
```

where here we have made use of the ``parameter_ids`` property that contains the unique identifiers of
the parameters that were assigned to the molecule referenced by the object.

```python
parameter_delta_ids[:2]
```

```
[('Bonds',
  PotentialKey(id='[#6X3:1](=[#8X1+0])-[#7X3:2]', mult=None, associated_handler='Bonds'),
  'k'),
 ('Bonds',
  PotentialKey(id='[#6X3:1]=[#6X3:2]', mult=None, associated_handler='Bonds'),
  'k')]
```

Each id is made up of the type of SMIRNOFF parameter handler that the parameter originated from,
a key containing the id (in this case the SMIRKS pattern) associated with the parameter, and the specific
attribute of the parameter (e.g. the force constant ``k``).
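`parameter_delta_ids` is built from a set, and sorting only by `x[0]` with `reverse=True` does no more than place all `Bonds` ids before the `Angles` ids. Within each handler type, the order follows the set's iteration order. For tuples containing strings, that order can differ between Python sessions because of hash randomization. If a reproducible order matters, for example when comparing delta tensors across runs, one option is to sort on a fuller key. The sketch below uses plain `(handler, SMIRKS, attribute)` string tuples as hypothetical stand-ins for the real ids. With the real ids, the SMIRKS would come from `x[1].id`.

```python
from typing import Iterable, List, Tuple

# Hypothetical stand-ins for the real ids: (handler type, SMIRKS, attribute).
ParameterId = Tuple[str, str, str]

parameter_ids = {
    ("Angles", "[*:1]~[#6X4:2]-[*:3]", "k"),
    ("Bonds", "[#6X4:1]-[#1:2]", "k"),
    ("Bonds", "[#6X3:1]=[#6X3:2]", "k"),
    ("Angles", "[#1:1]-[#6X4:2]-[#1:3]", "k"),
}

# Bonds before Angles, as in the notebook's reverse sort on the handler type.
HANDLER_ORDER = {"Bonds": 0, "Angles": 1}


def ordered_ids(ids: Iterable[ParameterId]) -> List[ParameterId]:
    """Sort by handler order, then SMIRKS, then attribute, so the result is
    independent of the iteration order of the input collection."""
    return sorted(ids, key=lambda x: (HANDLER_ORDER[x[0]], x[1], x[2]))


for handler, smirks, attribute in ordered_ids(parameter_ids):
    print(handler, smirks, attribute)
```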

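These keys will allow us to map our tensor of delta values:

```python
import torch

# One trainable delta per parameter id, starting from zero (i.e. from the
# initial force field parameters).
parameter_delta = torch.zeros(len(parameter_delta_ids), requires_grad=True)
```

easily back to more meaningful force field parameters.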
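Each entry of `parameter_delta` lines up with the id at the same position in `parameter_delta_ids`. The perturbed value of a parameter is assumed here to be its initial value plus its delta. That is an assumption, although it is consistent with the small shifts in the initial and final values printed at the end of this notebook. The plain-Python sketch below shows this bookkeeping with hypothetical ids, initial values and deltas. The `apply_deltas` helper is illustrative and is not part of descent.

```python
from typing import Dict, List, Sequence, Tuple

ParameterId = Tuple[str, str, str]

# Hypothetical (handler, SMIRKS, attribute) ids and initial values in kcal/(Å² mol).
parameter_ids: List[ParameterId] = [
    ("Bonds", "[#6X3:1]=[#6X3:2]", "k"),
    ("Bonds", "[#6X4:1]-[#1:2]", "k"),
]
initial_values: Dict[ParameterId, float] = {
    ("Bonds", "[#6X3:1]=[#6X3:2]", "k"): 857.0,
    ("Bonds", "[#6X4:1]-[#1:2]", "k"): 758.0,
}


def apply_deltas(
    initial: Dict[ParameterId, float],
    ids: Sequence[ParameterId],
    deltas: Sequence[float],
) -> Dict[ParameterId, float]:
    """Return a copy of `initial` with deltas[i] added to the parameter ids[i]."""
    if len(ids) != len(deltas):
        raise ValueError("expected one delta per parameter id")
    perturbed = dict(initial)
    for parameter_id, delta in zip(ids, deltas):
        perturbed[parameter_id] = initial[parameter_id] + delta
    return perturbed


# All-zero deltas (as created above) reproduce the initial parameters exactly.
assert apply_deltas(initial_values, parameter_ids, [0.0, 0.0]) == initial_values

final_values = apply_deltas(initial_values, parameter_ids, [-0.5, 0.25])
for parameter_id, value in final_values.items():
    print(parameter_id, value)
```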

### Training the force field parameters

We are finally ready to begin training our force field parameters or, more precisely, the delta values that
we should perturb the force field parameters by to reach better agreement with the training data.

Here we will use a standard 'boilerplate' PyTorch optimization loop:

```python
lr = 0.01
n_epochs = 200

optimizer = torch.optim.Adam([parameter_delta], lr=lr)

for epoch in range(n_epochs):

    loss = torch.zeros(1)

    for objective_contribution in objective_contributions:
        loss += objective_contribution(parameter_delta, parameter_delta_ids)

    loss.backward()

    optimizer.step()
    optimizer.zero_grad()

    if epoch % 20 == 0:
        print(f"Epoch {epoch}: loss={loss.item()}")
```

```
Epoch 0: loss=541.970458984375
Epoch 20: loss=316.5989990234375
Epoch 40: loss=302.3933410644531
Epoch 60: loss=298.25592041015625
Epoch 80: loss=296.67816162109375
Epoch 100: loss=296.056396484375
Epoch 120: loss=295.83831787109375
Epoch 140: loss=295.7755126953125
Epoch 160: loss=295.7544250488281
Epoch 180: loss=295.7417297363281
```

where the only code of note is the loop over our objective contributions, whose values are summed to give the total
loss.
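The loop above repeatedly evaluates the loss, back-propagates to obtain its gradient with respect to the deltas, and then steps the deltas with Adam. The dependency-free toy below mirrors that structure for a single harmonic bond. It fits one force-constant delta to synthetic "reference" energies using plain gradient descent rather than Adam. The gradient that PyTorch's autograd computes in the notebook is written out by hand here. All numbers are made up, and the displacements are in arbitrary units chosen to keep the toy problem well conditioned.

```python
from typing import Sequence, Tuple

# Made-up toy problem: E(r) = 0.5 * k * (r - r0)^2 evaluated at a few displacements.
displacements = [0.5, 1.0, 1.5]          # r - r0, arbitrary units
k_initial, k_reference = 10.0, 12.0      # starting and "true" force constants
reference_energies = [0.5 * k_reference * x**2 for x in displacements]


def loss_and_gradient(
    delta: float, xs: Sequence[float], targets: Sequence[float]
) -> Tuple[float, float]:
    """MSE between model and reference energies, and its derivative w.r.t. delta."""
    n = len(xs)
    loss, gradient = 0.0, 0.0
    for x, target in zip(xs, targets):
        dE_ddelta = 0.5 * x**2           # derivative of the model energy w.r.t. delta
        residual = (k_initial + delta) * dE_ddelta - target
        loss += residual**2 / n
        gradient += 2.0 * residual * dE_ddelta / n
    return loss, gradient


delta = 0.0            # like parameter_delta, start from the initial parameters
learning_rate = 0.5

for epoch in range(200):
    loss, gradient = loss_and_gradient(delta, displacements, reference_energies)
    delta -= learning_rate * gradient    # plays the role of optimizer.step()
    if epoch % 50 == 0:
        print(f"Epoch {epoch}: loss={loss:.6f} delta={delta:.6f}")

print(f"fitted k = {k_initial + delta:.4f}")  # converges towards k_reference = 12.0
```

In this toy the loss can be driven to zero because the model and the reference share the same functional form. That is not the case in the real fit above, where the loss plateaus at a non-zero value.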

We can save our trained parameters back to a SMIRNOFF `.offxml` file for future use:

```python
from descent.utilities.smirnoff import perturb_force_field

final_force_field = perturb_force_field(
    initial_force_field, parameter_delta, parameter_delta_ids
)
final_force_field.to_file("final.offxml")
```

or print out the initial and final values.

```python
for parameter_handler, potential_key, attribute in parameter_delta_ids:

    initial_value = getattr(
        initial_force_field[parameter_handler].parameters[potential_key.id], attribute
    )
    final_value = getattr(
        final_force_field[parameter_handler].parameters[potential_key.id], attribute
    )

    print(
        f"{parameter_handler} SMIRKS={potential_key.id} ATTR={attribute}  INITIAL={initial_value}  FINAL={final_value}"
    )
```

Bonds SMIRKS=[#6X3:1](=[#8X1+0])-[#7X3:2] ATTR=k  INITIAL=1053.970761594 kcal/(A**2 mol)  FINAL=1053.4928565952225 kcal/(A**2 mol)
Bonds SMIRKS=[#6X3:1]=[#6X3:2] ATTR=k  INITIAL=857.1115548611 kcal/(A**2 mol)  FINAL=856.6519475195506 kcal/(A**2 mol)
Bonds SMIRKS=[#6X4:1]-[#1:2] ATTR=k  INITIAL=758.0931772913 kcal/(A**2 mol)  FINAL=757.615259442764 kcal/(A**2 mol)
Bonds SMIRKS=[#6X3:1]-[#7X2:2] ATTR=k  INITIAL=837.2647972972 kcal/(A**2 mol)  FINAL=837.5914970507644 kcal/(A**2 mol)
Bonds SMIRKS=[#6X4:1]-[#6X3:2]=[#8X1+0] ATTR=k  INITIAL=612.0537081219 kcal/(A**2 mol)  FINAL=612.5316650040704 kcal/(A**2 mol)
Bonds SMIRKS=[#7:1]-[#1:2] ATTR=k  INITIAL=997.7547006218 kcal/(A**2 mol)  FINAL=997.2765224160728 kcal/(A**2 mol)
Bonds SMIRKS=[#6X3:1]-[#1:2] ATTR=k  INITIAL=808.1394472833 kcal/(A**2 mol)  FINAL=807.6614951022607 kcal/(A**2 mol)
Bonds SMIRKS=[#6X4:1]-[#6X3:2] ATTR=k  INITIAL=612.5097961064 kcal/(A**2 mol)  FINAL=612.0317996492527 kcal/(A**2 mol)
Bonds SMIRKS=[#6X3:1]=[#7X2,#7X3+1:2