In [30]:
from os import path
from typing import Tuple

from openmmforcefields.generators import SMIRNOFFTemplateGenerator
import mdtraj
from openmm import app, unit, Platform, LangevinIntegrator
import torch
import pandas as pd

In [31]:
# --- Run configuration ------------------------------------------------------
# Identifier of the molecule to analyse (matched against the "compound_id"
# column of the FreeSolv table and used as the endpoint sub-directory name).
mol_name = "mobley_6456034"
# Part of the system to load: "compound", "solvated" (combined system) or
# "solvent" (isolated water box).
component = "compound"
# Temperature in K.
temperature = 298.15
# Location of the FreeSolv database text file.
freesolv_db_path = (
    "/rxrx/data/valence/shared_psu_unit/free_energy/solvation/"
    "free_solv/database.txt"
)

In [32]:
def load_md_inputs(
    mol_name: str,
    component: str,
    endpoint_dir: str = ("/rxrx/data/valence/shared_psu_unit/free_energy/solvation/"
                         "experiments/mlip_compatible/trajectories/endpoints"),
) -> Tuple[str, str, str]:
    """Builds the paths of the MD files extracted from an openFE alchemical
    free energy simulation.

    No file is opened and no existence check is performed; the function only
    assembles path strings.

    Args:
        mol_name (str): Name of the molecule; used as the sub-directory of
          ``endpoint_dir`` that holds the files.
        component (str): The system component to load. Either "compound",
          "solvated" for the combined system or "solvent" for the water box.
        endpoint_dir (str): Directory containing one sub-directory per
          molecule. Defaults to the location used for the FreeSolv dataset.

    Returns:
        Tuple: Paths to the trajectory, the topology and the forcefield DB
          (in that order).
    """
    endpoint_path = path.join(endpoint_dir, mol_name)
    top_path = path.join(endpoint_path, f"{component}_top.pdb")
    traj_path = path.join(endpoint_path, f"{component}_traj.dcd")
    # solvent and solvated use the same FF DB, so the water box is pointed at
    # the "solvated" DB file. A separate local name keeps `component` intact.
    db_component = "solvated" if component == "solvent" else component
    db_path = path.join(endpoint_path, f"{db_component}_db.json")
    return traj_path, top_path, db_path


# The first two lines of the file are skipped, so the third line supplies the
# column names.
freesolv_df = pd.read_csv(freesolv_db_path, delimiter=';', skiprows=[0,1])
# Assumption: the atom ordering, bond connections etc. are the same for the
#  SMIRNOFFTemplateGenerator used to build the topology and for the ETFlow
#  topology generator.

# TODO: The ETFlow SMILES MolFeaturizer would most likely need to be adjusted to
#  properly encode water molecules
smiles_matches = freesolv_df.loc[freesolv_df["compound_id"] == mol_name, "SMILES"]
if smiles_matches.empty:
    # Fail with a message that names the missing compound instead of the bare
    # IndexError that `.iloc[0]` would raise on an empty selection.
    raise ValueError(
        f"Compound {mol_name!r} not found in FreeSolv table {freesolv_db_path!r}")
# The raw field comes back with leading whitespace (e.g. ' CC(C)...'); strip it
# so the SMILES string is clean for any downstream consumer.
smiles = smiles_matches.iloc[0].strip()
smiles

' CC(C)COC(=O)C(C)C'

In [33]:
# FF constants
forcefields = ['amber/ff14SB.xml', 'amber/tip3p_standard.xml',
               'amber/tip3p_HFE_multivalent.xml', 'amber/phosaa10.xml']
small_mol_ff = "openff-2.1.1"
nonbonded_cutoff = 1. * unit.nanometers

In [34]:
class Energy:
    """Base interface for objects that assign an energy to a conformation."""

    def get_energy(self, pos):
        """Return the energy of a single conformation ``pos``.

        Must be overridden by subclasses; the base implementation always
        raises ``NotImplementedError``.
        """
        raise NotImplementedError

    def compute_many(self, confs):
        """Evaluate ``get_energy`` on every element of ``confs``.

        Returns:
            A torch tensor holding one energy per conformation, converted
            with ``.float()``.
        """
        energies = list(map(self.get_energy, confs))
        return torch.tensor(energies).float()

class openMMEnergy(Energy):
    """Reduced potential energy (U / kT) of a system evaluated with OpenMM.

    On construction the system is built from the module-level force-field
    constants (`forcefields`, `small_mol_ff`, `nonbonded_cutoff`), the first
    trajectory frame is loaded as initial positions, the energy is minimized
    and the minimized potential energy is printed.
    """

    def __init__(self, trajectory, topology, db_path, temperature,
                 with_pbcs=False, platform_name="CUDA"):
        """Builds the OpenMM simulation used for energy evaluation.

        Args:
            trajectory: Path to the trajectory file; frame 0 supplies the
              initial positions (and the box vectors if `with_pbcs` is set).
            topology: Path to the PDB file describing the system.
            db_path: Cache file handed to the SMIRNOFF template generator.
            temperature: Temperature in K (plain number, no units attached).
            with_pbcs: If True, use the box vectors of trajectory frame 0 and
              PME electrostatics; otherwise no cutoff and no periodic box.
            platform_name: Name of the OpenMM platform to run on.
        """
        kb = 0.0019872041  # in kcal / mol K
        self.kbt = kb * temperature
        forcefield = app.ForceField(*forcefields)
        template_generator = SMIRNOFFTemplateGenerator(
            cache=db_path, forcefield=small_mol_ff)
        forcefield.registerTemplateGenerator(template_generator.generator)

        mm_traj = mdtraj.load(trajectory, top=topology)
        self.init_positions = mm_traj.openmm_positions(0)

        pdb = app.PDBFile(topology)
        modeller = app.Modeller(pdb.topology, pdb.positions)
        platform = Platform.getPlatformByName(platform_name)
        # Friction coefficient given with explicit units (1 / ps).
        integrator = LangevinIntegrator(temperature * unit.kelvin,
                                        1. / unit.picosecond,
                                        1. * unit.femtosecond)

        if with_pbcs:
            box_vectors = mm_traj.openmm_boxes(0)
            # Must be the `app.PME` constant: the string "PME" is not a valid
            # nonbonded method for `createSystem`.
            nonbonded_method = app.PME
            # Set the box on the topology that is passed to createSystem.
            modeller.topology.setPeriodicBoxVectors(box_vectors)
        else:
            nonbonded_method = app.NoCutoff

        system = forcefield.createSystem(
            modeller.topology, nonbondedMethod=nonbonded_method,
            nonbondedCutoff=nonbonded_cutoff,
            rigidWater=False, removeCMMotion=False
        )
        self.sim = app.Simulation(modeller.topology, system, integrator,
                                  platform)
        self.sim.context.setPositions(self.init_positions)

        # NOTE: minimization changes the positions stored in the context;
        # `self.init_positions` keeps the un-minimized frame 0.
        self.sim.minimizeEnergy()
        min_energy = self.sim.context.getState(
            getEnergy=True).getPotentialEnergy()
        print(f"Minimum energy: "
            f"{min_energy.value_in_unit(unit.kilocalorie_per_mole)} kcal/mol")

    def get_energy(self, positions, box_lengths=None):
        """Returns the reduced potential energy of the system for the given
        atomic positions.

        Args:
            positions: Atomic positions, passed to the context as given.
            box_lengths: Optional periodic box. Despite the name it is
              unpacked into `setPeriodicBoxVectors`, so it must be the three
              box vectors (a, b, c). If given, the box of the context is
              updated before the energy is evaluated.

        Returns:
            The potential energy in kcal/mol divided by kT, i.e. the unitless
            exponent of the Boltzmann distribution of the NVT ensemble.
        """
        if box_lengths is not None:
            self.sim.context.setPeriodicBoxVectors(*box_lengths)
        self.sim.context.setPositions(positions)
        u = self.sim.context.getState(getEnergy=True).getPotentialEnergy()
        beta_u = u.value_in_unit(unit.kilocalorie_per_mole) / self.kbt
        return beta_u  # unitless exponent Boltzmann dist. of NVT ensemble


In [35]:
# Resolve the input files for the configured molecule/component, build the
# OpenMM energy evaluator and display the reduced energy of trajectory frame 0.
traj_path, top_path, db_path = load_md_inputs(mol_name=mol_name,
                                              component=component)
openmm_molecule = openMMEnergy(
    trajectory=traj_path,
    topology=top_path,
    db_path=db_path,
    temperature=temperature,
)
openmm_molecule.get_energy(positions=openmm_molecule.init_positions)

Minimum energy: 2.0936634071014577 kcal/mol


46.82150167395687