# Custom Parameter Handlers

This example shows how to use custom parameter handlers with `smee` by swapping the Lennard-Jones (LJ) potential for a double-exponential (DEXP) potential.

**Note:** This behaviour is currently experimental and is not as thoroughly tested as the built-in parameter handlers.

To use a custom potential in `smee`, we must first register the function that will be used to evaluate the potential energy:

In [1]:
import torch

import smee
import smee.potentials.nonbonded


@smee.potentials.potential_energy_fn("vdW", "dexp")
def compute_dexp_energy(
    conformer: torch.Tensor,
    parameters: torch.Tensor,
    exclusions: torch.Tensor,
    exclusion_scales: torch.Tensor,
    attributes: torch.Tensor,
) -> torch.Tensor:
    """Evaluates the potential energy [kcal / mol] of the vdW interactions using the
    double-exponential potential.

    Args:
        conformer: The conformer [Å] to evaluate the potential at.
        parameters: A tensor containing the epsilon [kcal / mol] and r_min [Å] values
            of each particle, with ``shape=(n_particles, 2)``.
        exclusions: A tensor containing pairs of atom indices whose interaction should
            be scaled by ``exclusion_scales`` with ``shape=(n_exclusions, 2)``.
        exclusion_scales: A tensor containing the scale factor for each exclusion pair
            with ``shape=(n_exclusions, 1)``.
        attributes: A tensor containing the scale_12, scale_13, scale_14, and
            scale_15 values followed by the global alpha and beta values, with
            ``shape=(6,)``.

    Returns:
        The evaluated potential energy [kcal / mol].
    """

    is_batched = conformer.ndim == 3

    if not is_batched:
        conformer = torch.unsqueeze(conformer, 0)

    # pair_idxs will contain the particle idxs of each interacting pair
    # while pair_scales will contain any scaling factor (e.g. 1-2, 1-3)
    pair_idxs, rij_sqr, pair_scales = smee.potentials.nonbonded.compute_pairwise(
        conformer, exclusions, exclusion_scales
    )

    epsilon, r_min = smee.potentials.nonbonded.lorentz_berthelot(
        parameters[pair_idxs[:, 0], 0],
        parameters[pair_idxs[:, 1], 0],
        parameters[pair_idxs[:, 0], 1],
        parameters[pair_idxs[:, 1], 1],
    )
    alpha, beta = attributes[-2], attributes[-1]

    x = torch.sqrt(rij_sqr) / r_min

    energy_repulsion = beta / (alpha - beta) * torch.exp(alpha * (1.0 - x))
    energy_attraction = alpha / (alpha - beta) * torch.exp(beta * (1.0 - x))

    energy = (pair_scales * epsilon * (energy_repulsion - energy_attraction)).sum(-1)

    if not is_batched:
        energy = torch.squeeze(energy, 0)

    return energy
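
As a sanity check on the functional form, the sketch below evaluates the same double-exponential expression for a single pair using plain Python floats rather than tensors. The helper names `mix_lorentz_berthelot` and `dexp_pair_energy` are illustrative and are not part of `smee`. The mixing rule assumes the standard Lorentz-Berthelot convention (geometric mean for epsilon, arithmetic mean for the distance parameter), and the `alpha` and `beta` defaults are the values used later in this example.

```python
import math


def mix_lorentz_berthelot(eps_a, eps_b, r_min_a, r_min_b):
    """Standard Lorentz-Berthelot mixing (assumed convention)."""
    return math.sqrt(eps_a * eps_b), 0.5 * (r_min_a + r_min_b)


def dexp_pair_energy(r, epsilon, r_min, alpha=16.5, beta=5.0):
    """Double-exponential energy of one pair separated by a distance r."""
    x = r / r_min
    repulsion = beta / (alpha - beta) * math.exp(alpha * (1.0 - x))
    attraction = alpha / (alpha - beta) * math.exp(beta * (1.0 - x))
    return epsilon * (repulsion - attraction)


# Illustrative parameters for two identical particles.
epsilon, r_min = mix_lorentz_berthelot(0.1, 0.1, 3.0, 3.0)

# At r == r_min both exponentials equal 1, so the energy reduces to -epsilon.
energy_at_minimum = dexp_pair_energy(r_min, epsilon, r_min)

# Energies slightly either side of r_min should be higher (less negative).
energy_left = dexp_pair_energy(0.99 * r_min, epsilon, r_min)
energy_right = dexp_pair_energy(1.01 * r_min, epsilon, r_min)
```

The well depth is `epsilon` and the minimum sits at `r_min`, which is why the function takes `r_min` rather than the LJ `sigma` as its distance parameter.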



There are then two main routes that can be taken to make use of this function:

1. We can manually convert an interchange object that uses LJ by default to use DEXP, and then simply call `smee.compute_energy`.
2. We can register a function using a `smee.ff.parameter_converter` decorator that converts `SMIRNOFFDoubleExponentialCollection` into tensor form, such that any interchange objects containing such a parameter collection can automatically be converted by `smee.convert_interchange`.

## Converting an existing LJ potential

We will begin with the first option. First we will define our molecule of interest and assign standard parameters:

In [2]:
import openff.interchange
import openff.toolkit
import openff.units

molecule = openff.toolkit.Molecule.from_smiles("CCCC")
molecule.generate_conformers(n_conformers=1)

conformer = torch.tensor(molecule.conformers[0].m_as(openff.units.unit.angstrom))

interchange = openff.interchange.Interchange.from_smirnoff(
    openff.toolkit.ForceField("openff-2.0.0.offxml"), molecule.to_topology()
)

We will then map it into tensor form:

In [3]:
force_field, [topology] = smee.convert_interchange(interchange)

The vdW potential handler can be accessed via the `force_field` object:

In [4]:
vdw_potential = force_field.potentials_by_type["vdW"]

print("ENERGY FN=", vdw_potential.fn)
print("PARAMETER COLUMNS=", vdw_potential.parameter_cols)
print("ATTRIBUTE COLUMNS=", vdw_potential.attribute_cols)

ENERGY FN= 4*epsilon*((sigma/r)**12-(sigma/r)**6)
PARAMETER COLUMNS= ('epsilon', 'sigma')
ATTRIBUTE COLUMNS= ('scale_12', 'scale_13', 'scale_14', 'scale_15')


We set the potential energy function to be `dexp` as we declared above:

In [5]:
vdw_potential.fn = "dexp"

convert the LJ `sigma` parameters into `r_min`:

In [6]:
parameter_columns = [*vdw_potential.parameter_cols]
sigma_col_idx = vdw_potential.parameter_cols.index("sigma")

sigma = vdw_potential.parameters[:, sigma_col_idx]
r_min = 2 ** (1 / 6) * sigma

vdw_potential.parameters[:, sigma_col_idx] = r_min

parameter_columns[sigma_col_idx] = "r_min"
vdw_potential.parameter_cols = tuple(parameter_columns)
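
The factor of `2 ** (1 / 6)` comes from the location of the LJ minimum. The short check below, written with plain floats and illustrative parameter values, confirms that the 12-6 form shown above reaches `-epsilon` at `r_min = 2 ** (1 / 6) * sigma` and crosses zero at `sigma`. The function name `lj_energy` is only used for this illustration.

```python
def lj_energy(r, epsilon, sigma):
    """Lennard-Jones 12-6 energy: 4*epsilon*((sigma/r)**12 - (sigma/r)**6)."""
    return 4.0 * epsilon * ((sigma / r) ** 12 - (sigma / r) ** 6)


epsilon, sigma = 0.1, 3.0  # illustrative values only
r_min = 2 ** (1 / 6) * sigma

energy_at_r_min = lj_energy(r_min, epsilon, sigma)  # well depth, -epsilon
energy_at_sigma = lj_energy(sigma, epsilon, sigma)  # zero crossing

# Moving away from r_min in either direction raises the energy.
energy_left = lj_energy(0.99 * r_min, epsilon, sigma)
energy_right = lj_energy(1.01 * r_min, epsilon, sigma)
```

Converting `sigma` to `r_min` this way keeps the position of the minimum unchanged when switching from LJ to DEXP.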

and add alpha and beta attributes:

In [7]:
vdw_potential.attribute_cols = (*vdw_potential.attribute_cols, "alpha", "beta")
vdw_potential.attributes = torch.cat(
    [vdw_potential.attributes, torch.tensor([16.5, 5.0])]
)
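
Note that `compute_dexp_energy` reads `alpha` and `beta` positionally from the last two entries of `attributes`, so the order in which the columns and values are appended matters. The sketch below mirrors the bookkeeping with a plain tuple and list. The scale values are illustrative and `attribute_by_name` is a hypothetical helper, not a `smee` function.

```python
attribute_cols = ("scale_12", "scale_13", "scale_14", "scale_15")
attributes = [0.0, 0.0, 0.5, 1.0]  # illustrative values only

# Append the new columns and their values in the same order.
attribute_cols = (*attribute_cols, "alpha", "beta")
attributes = [*attributes, 16.5, 5.0]


def attribute_by_name(cols, values, name):
    """Hypothetical helper: look an attribute up by its column name."""
    return values[cols.index(name)]


alpha = attribute_by_name(attribute_cols, attributes, "alpha")
beta = attribute_by_name(attribute_cols, attributes, "beta")

# The positional lookup gives the same answer only because alpha and beta
# were appended last, and in that order.
positional = (attributes[-2], attributes[-1])
```

Looking values up by column name is more robust if the column order might change, at the cost of an extra index search.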

The energy can then be computed by the normal means:

In [8]:
energy = smee.compute_energy(topology.parameters, conformer, force_field)
print(f"Energy = {energy.item():.3f} kcal / mol")

Energy = 25.896 kcal / mol


## Defining a custom parameter converter

A more sustainable way to support custom parameter handlers is to define a converter for the custom interchange collection.

In general, this should be as simple as calling out to the built-in converter and specifying which parameters and handler attributes to expect, and what units they should be converted to.

In [9]:
from smirnoff_plugins.collections.nonbonded import SMIRNOFFDoubleExponentialCollection

import smee.ff.nonbonded

KCAL_PER_MOL = openff.units.unit.kilocalories / openff.units.unit.mole
ANGSTROM = openff.units.unit.angstrom

UNITLESS = openff.units.unit.dimensionless


@smee.ff.parameter_converter(
    "DoubleExponential",
    {
        "epsilon": KCAL_PER_MOL,
        "r_min": ANGSTROM,
        "alpha": UNITLESS,
        "beta": UNITLESS,
        "scale_12": UNITLESS,
        "scale_13": UNITLESS,
        "scale_14": UNITLESS,
        "scale_15": UNITLESS,
    },
)
def convert_dexp(
    handlers: list[SMIRNOFFDoubleExponentialCollection],
    topologies: list[openff.toolkit.Topology],
    v_site_maps: list[smee.ff.VSiteMap | None],
) -> tuple[smee.ff.TensorPotential, list[smee.ff.NonbondedParameterMap]]:
    potential, parameter_maps = smee.ff.nonbonded.convert_nonbonded_handlers(
        handlers,
        "DoubleExponential",
        topologies,
        v_site_maps,
        ("epsilon", "r_min"),
        ("alpha", "beta"),
    )
    potential.fn = "dexp"

    return potential, parameter_maps

You might notice that this function expects a list of double exponential collections rather than just a single one. This is because `smee` can convert multiple interchange objects at once and in doing so aggregate any found parameters into a single force field.
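
To make the idea of aggregation concrete, here is a simplified sketch in plain Python. It is not how `smee` implements the conversion: the function `aggregate_parameters`, the dictionary layout, and the two example molecules are all hypothetical. It only illustrates how parameters shared between several collections can end up as a single row in one table, with a per-molecule map pointing into it.

```python
def aggregate_parameters(collections):
    """Merge per-molecule parameter dicts into one shared table.

    Each collection maps a particle index to a (parameter id, values) pair.
    Returns the unique parameter ids, their values, and, per collection,
    a list mapping each particle to a row of the shared table.
    """
    table_ids = []   # unique parameter ids, in first-seen order
    table_rows = []  # parameter values for each unique id
    assignments = []

    for collection in collections:
        assignment = []
        for particle_idx in sorted(collection):
            param_id, values = collection[particle_idx]
            if param_id not in table_ids:
                table_ids.append(param_id)
                table_rows.append(values)
            assignment.append(table_ids.index(param_id))
        assignments.append(assignment)

    return table_ids, table_rows, assignments


# Two hypothetical molecules that share the "[#6:1]" parameter.
carbon_hydrogen = {0: ("[#6:1]", (0.1, 3.0)), 1: ("[#1:1]", (0.01, 1.0))}
carbon_only = {0: ("[#6:1]", (0.1, 3.0))}

ids, rows, maps = aggregate_parameters([carbon_hydrogen, carbon_only])
```

Here the carbon parameter appears once in the table even though both molecules use it, which is the behaviour the list-based signature makes possible.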

We must also re-register the potential energy function that we defined above, this time setting the type to `'DoubleExponential'`, as this is the name of the custom parameter handler:

In [10]:
smee.potentials.potential_energy_fn("DoubleExponential", "dexp")(compute_dexp_energy);

We can then define a new interchange containing our custom parameter handler:

In [11]:
dexp_force_field = openff.toolkit.ForceField(load_plugins=True)
dexp_force_field.get_parameter_handler("Electrostatics")

charge_handler = dexp_force_field.get_parameter_handler("LibraryCharges")
charge_handler.add_parameter({"smirks": "[*:1]", "charge1": 0.0 * openff.units.unit.e})

dexp_handler = dexp_force_field.get_parameter_handler("DoubleExponential")
dexp_handler.alpha = 16.5 * UNITLESS
dexp_handler.beta = 5.0 * UNITLESS
dexp_handler.add_parameter(
    {"smirks": "[#6:1]", "epsilon": 0.1 * KCAL_PER_MOL, "r_min": 3.0 * ANGSTROM}
)
dexp_handler.add_parameter(
    {"smirks": "[#1:1]", "epsilon": 0.01 * KCAL_PER_MOL, "r_min": 1.0 * ANGSTROM}
)

dexp_interchange = openff.interchange.Interchange.from_smirnoff(
    dexp_force_field, molecule.to_topology()
)
dexp_interchange

Interchange with 3 collections, non-periodic topology with 14 atoms.

The interchange containing the custom DEXP handler can then be converted into tensor form:

In [12]:
dexp_tensor_ff, [dexp_topology] = smee.convert_interchange(dexp_interchange)

The potential energy can then be computed as normal:

In [13]:
energy = smee.compute_energy(dexp_topology.parameters, conformer, dexp_tensor_ff)
print(f"Energy = {energy.item():.3f} kcal / mol")

Energy = -0.054 kcal / mol
