In [58]:
import copy
import openmm as mm
from openmm import app, unit
from openmm.app import PDBFile

In [59]:
with open('system_with_ligand_only.xml', 'r') as f:
    system_1 = mm.XmlSerializer.deserialize(f.read())
pdb = PDBFile('ligand_in_solvent.pdb')
system_1.addForce(mm.MonteCarloBarostat(1*unit.atmosphere, 300*unit.kelvin))
system_1.getForces()

[<openmm.openmm.HarmonicBondForce; proxy of <Swig Object of type 'OpenMM::HarmonicBondForce *' at 0x71041083def0> >,
 <openmm.openmm.PeriodicTorsionForce; proxy of <Swig Object of type 'OpenMM::PeriodicTorsionForce *' at 0x71041083de60> >,
 <openmm.openmm.NonbondedForce; proxy of <Swig Object of type 'OpenMM::NonbondedForce *' at 0x71041083dce0> >,
 <openmm.openmm.CMMotionRemover; proxy of <Swig Object of type 'OpenMM::CMMotionRemover *' at 0x71041083dd70> >,
 <openmm.openmm.HarmonicAngleForce; proxy of <Swig Object of type 'OpenMM::HarmonicAngleForce *' at 0x71041083dc20> >,
 <openmm.openmm.MonteCarloBarostat; proxy of <Swig Object of type 'OpenMM::MonteCarloBarostat *' at 0x71041083db60> >]

In [60]:
ligand = list(pdb.topology.residues())[0]
rest_atoms = [atom.index for atom in ligand.atoms()]
len(rest_atoms)

151

In [61]:
# Create REST system
rest_system = mm.System()
# Create dict of vanilla system forces (for easy retrieval of force objects later)
system_forces = {type(force).__name__ : force for force in system_1.getForces()}

# Add particles
for particle_idx in range(system_1.getNumParticles()):
    particle_mass = system_1.getParticleMass(particle_idx)
    rest_system.addParticle(particle_mass)

# Copy barostat
if "MonteCarloBarostat" in system_forces:
    barostat = copy.deepcopy(system_forces["MonteCarloBarostat"])
    rest_system.addForce(barostat)

# Copy box vectors
box_vectors = system_1.getDefaultPeriodicBoxVectors()
rest_system.setDefaultPeriodicBoxVectors(*box_vectors)

# Copy constraints
for constraint_idx in range(system_1.getNumConstraints()):
    atom1, atom2, length = system_1.getConstraintParameters(constraint_idx)
    rest_system.addConstraint(atom1, atom2, length)

In [62]:
# Define the custom expression
bond_expression = "rest_scale * (K / 2) * (r - length)^2;"
bond_expression += "rest_scale = is_rest * lambda_rest_bonds * lambda_rest_bonds " \
                   "+ is_inter * lambda_rest_bonds " \
                   "+ is_nonrest;"

# Create custom force
rest_bond_force = mm.CustomBondForce(bond_expression)
rest_system.addForce(rest_bond_force)

# Add global parameters
rest_bond_force.addGlobalParameter("lambda_rest_bonds", 1.0)

# Add per-bond parameters for rest scaling
rest_bond_force.addPerBondParameter("is_rest")
rest_bond_force.addPerBondParameter("is_inter")
rest_bond_force.addPerBondParameter("is_nonrest")

# Add per-bond parameters for defining bond energy
rest_bond_force.addPerBondParameter('length')  # equilibrium bond length
rest_bond_force.addPerBondParameter('K')  # force constant

4

In [63]:
def get_rest_identifier(atoms, rest_atoms):
    """
    For a given atom or set of atoms, get the rest_id which is a list of binary ints that defines which
    (mutually exclusive) set the atom(s) belong to.

    If there is a single atom, the sets are: is_rest, is_nonrest
    If there is a set of atoms, the sets are: is_rest, is_inter, is_nonrest

    Example: if there is a single atom that is in the nonrest set, the rest_id is [0, 1]

    Arguments
    ---------
    atoms : set or int
        a set of hybrid atom indices or single atom
    rest_atoms : set or list
        a list (or list-like) of atoms whose interactions will be scaled by REST
    Returns
    -------
    rest_id : list
        list of binaries indicating which set the atom(s) belong to
    """

    if isinstance(atoms, int):
        rest_id = [0, 1] # Set the default rest_id to non-REST
        if atoms in rest_atoms:
            rest_id = [1, 0]
        return rest_id

    elif isinstance(atoms, set):
        rest_id = [0, 0, 1] # Set the default rest_id to non-REST
        if atoms.intersection(rest_atoms) != set(): # At least one of the atoms is REST
            if atoms.issubset(rest_atoms): # All atoms are REST
                rest_id = [1, 0, 0]
            else: # At least one (but not all) of the atoms is are REST
                rest_id = [0, 1, 0]
        return rest_id

    else:
        raise Exception(f"atoms is of type {type(atoms)}, but only `int` and `set` are allowable")

In [64]:
# Get vanilla system bond force
bond_force = system_forces['HarmonicBondForce']

# Set periodicity
if bond_force.usesPeriodicBoundaryConditions():
    rest_bond_force.setUsesPeriodicBoundaryConditions(True)

# Add bonds to rest_system
for term_idx in range(bond_force.getNumBonds()):
    # Get the bond parameters and rest id
    p1, p2, r0, k = bond_force.getBondParameters(term_idx)
    idx_set = set([p1, p2])
    rest_id = get_rest_identifier(idx_set, rest_atoms)

    # Add the bond
    bond_term = (p1, p2, rest_id + [r0, k])
    rest_bond_force.addBond(*bond_term)

In [65]:
# Define the custom expression
angle_expression = "rest_scale * (K / 2) * (theta - theta0)^2;"
angle_expression += "rest_scale = is_rest * lambda_rest_angles * lambda_rest_angles " \
                    "+ is_inter * lambda_rest_angles " \
                    "+ is_nonrest;"

# Create custom force
rest_angle_force = mm.CustomAngleForce(angle_expression)
rest_system.addForce(rest_angle_force)

# Add global parameters
rest_angle_force.addGlobalParameter("lambda_rest_angles", 1.0)

# Add per-angle parameters for rest scaling
rest_angle_force.addPerAngleParameter("is_rest")
rest_angle_force.addPerAngleParameter("is_inter")
rest_angle_force.addPerAngleParameter("is_nonrest")

# Add per-angle parameters for defining angle energy
rest_angle_force.addPerAngleParameter('theta0')  # equilibrium angle
rest_angle_force.addPerAngleParameter('K')  # force constant

# Get vanilla system angle force
angle_force = system_forces['HarmonicAngleForce']

# Set periodicity
if angle_force.usesPeriodicBoundaryConditions():
    rest_angle_force.setUsesPeriodicBoundaryConditions(True)

# Add angles to rest_system
for term_idx in range(angle_force.getNumAngles()):
    # Get the angle parameters and rest id
    p1, p2, p3, theta0, k = angle_force.getAngleParameters(term_idx)
    idx_set = set([p1, p2, p3])
    rest_id = get_rest_identifier(idx_set, rest_atoms)

    # Add the angle
    angle_term = (p1, p2, p3, rest_id + [theta0, k])
    rest_angle_force.addAngle(*angle_term)

In [66]:
# Define the custom expression
torsion_expression = "rest_scale * U;"
torsion_expression += "rest_scale = is_rest * lambda_rest_torsions * lambda_rest_torsions " \
                      "+ is_inter * lambda_rest_torsions " \
                      "+ is_nonrest;"
torsion_expression += "U = (K * (1 + cos(periodicity * theta - phase)));"

# Create custom force
rest_torsion_force = mm.CustomTorsionForce(torsion_expression)
rest_system.addForce(rest_torsion_force)

# Add global parameters
rest_torsion_force.addGlobalParameter("lambda_rest_torsions", 1.0)

# Add per-torsion parameters for rest scaling
rest_torsion_force.addPerTorsionParameter("is_rest")
rest_torsion_force.addPerTorsionParameter("is_inter")
rest_torsion_force.addPerTorsionParameter("is_nonrest")

# Add per-torsion parameters for defining torsion energy
rest_torsion_force.addPerTorsionParameter('periodicity')
rest_torsion_force.addPerTorsionParameter('phase') # phase offset
rest_torsion_force.addPerTorsionParameter('K') # force constant

# Get vanilla system torsion force
torsion_force = system_forces['PeriodicTorsionForce']

# Set periodicity
if torsion_force.usesPeriodicBoundaryConditions():
    rest_torsion_force.setUsesPeriodicBoundaryConditions(True)

# Add torsions to rest_system
for torsion_idx in range(torsion_force.getNumTorsions()):
    # Get the torsion parameters and rest id
    p1, p2, p3, p4, periodicity, phase, K = torsion_force.getTorsionParameters(torsion_idx)
    idx_set = set([p1, p2, p3, p4])
    rest_id = get_rest_identifier(idx_set, rest_atoms)

    # Add torsion
    torsion_term = (p1, p2, p3, p4, rest_id + [periodicity, phase, K])
    rest_torsion_force.addTorsion(*torsion_term)

In [67]:
# Create nonbonded force
rest_nonbonded_force = mm.NonbondedForce()
rest_system.addForce(rest_nonbonded_force)

# Get vanilla system nonbonded force
nonbonded_force = system_forces['NonbondedForce']

# Set the nonbonded method and related parameters
nonbonded_method = nonbonded_force.getNonbondedMethod()
rest_nonbonded_force.setNonbondedMethod(nonbonded_method)
if nonbonded_method != mm.NonbondedForce.NoCutoff:
    epsilon_solvent = nonbonded_force.getReactionFieldDielectric()
    cutoff = nonbonded_force.getCutoffDistance()
    rest_nonbonded_force.setReactionFieldDielectric(epsilon_solvent)
    rest_nonbonded_force.setCutoffDistance(cutoff)
if nonbonded_method in [mm.NonbondedForce.PME, mm.NonbondedForce.Ewald]:
    [alpha_ewald, nx, ny, nz] = nonbonded_force.getPMEParameters()
    delta = nonbonded_force.getEwaldErrorTolerance()
    rest_nonbonded_force.setPMEParameters(alpha_ewald, nx, ny, nz)
    rest_nonbonded_force.setEwaldErrorTolerance(delta)

# Copy switching function from vanilla system
switch_bool = nonbonded_force.getUseSwitchingFunction()
rest_nonbonded_force.setUseSwitchingFunction(switch_bool)
if switch_bool:
    switching_distance = nonbonded_force.getSwitchingDistance()
    rest_nonbonded_force.setSwitchingDistance(switching_distance)

# Copy dispersion correction
dispersion_bool = nonbonded_force.getUseDispersionCorrection()
rest_nonbonded_force.setUseDispersionCorrection(dispersion_bool)

# Add global parameters
rest_nonbonded_force.addGlobalParameter('lambda_rest_electrostatics', 0.)
rest_nonbonded_force.addGlobalParameter('lambda_rest_sterics', 0.)

1

In [68]:
# Add nonbondeds to rest_system
for particle_idx in range(nonbonded_force.getNumParticles()):
    # Get the nonbonded parameters and rest id
    q, sigma, epsilon = nonbonded_force.getParticleParameters(particle_idx)
    rest_id = get_rest_identifier(particle_idx, rest_atoms)

    # Add particles and offsets
    if rest_id == [0, 1]: # nonrest
        rest_nonbonded_force.addParticle(q, sigma, epsilon)

    else: # rest
        rest_nonbonded_force.addParticle(q, sigma, epsilon)
        rest_nonbonded_force.addParticleParameterOffset('lambda_rest_electrostatics', particle_idx, q, 0.0*sigma, epsilon*0.0)
        rest_nonbonded_force.addParticleParameterOffset('lambda_rest_sterics', particle_idx, q*0.0, 0.0*sigma, epsilon)

# Handle exceptions
for exception_idx in range(nonbonded_force.getNumExceptions()):
    # Get exception parameters and rest id
    p1, p2, chargeProd, sigma, epsilon = nonbonded_force.getExceptionParameters(exception_idx)
    idx_set = set([p1, p2])
    rest_id = get_rest_identifier(idx_set, rest_atoms)

    # Add exceptions and offsets
    exc_idx = rest_nonbonded_force.addException(p1, p2, chargeProd, sigma, epsilon)
    if rest_id == [0, 0, 1]: # nonrest
        pass

    elif rest_id == [1, 0, 0]: # rest
        rest_nonbonded_force.addExceptionParameterOffset('lambda_rest_sterics', exc_idx, chargeProd, 0.0*sigma, epsilon)

    elif rest_id == [0, 1, 0]: # inter
        rest_nonbonded_force.addExceptionParameterOffset('lambda_rest_electrostatics', exc_idx, chargeProd, 0.0*sigma, epsilon)


In [69]:
import math
import logging
import numpy as np
from openmmtools.constants import kB
from openmmtools import cache, mcmc, multistate
from openmmtools.multistate import ReplicaExchangeSampler
from openmmtools.states import GlobalParameterState, SamplerState, ThermodynamicState, CompoundThermodynamicState

In [70]:
class RESTState(GlobalParameterState):
    lambda_rest_bonds = GlobalParameterState.GlobalParameter('lambda_rest_bonds', standard_value=1.0)
    lambda_rest_angles = GlobalParameterState.GlobalParameter('lambda_rest_angles', standard_value=1.0)
    lambda_rest_torsions = GlobalParameterState.GlobalParameter('lambda_rest_torsions', standard_value=1.0)
    lambda_rest_electrostatics = GlobalParameterState.GlobalParameter('lambda_rest_electrostatics', standard_value=0.0)
    lambda_rest_sterics = GlobalParameterState.GlobalParameter('lambda_rest_sterics', standard_value=0.0)

    def set_rest_parameters(self, beta_m, beta_0):
        """Set all defined lambda parameters to the given value.

        The undefined parameters (i.e. those being set to None) remain undefined.

        Parameters
        ----------
        new_value : float
            The new value for all defined parameters.
        """
        lambda_functions = {'lambda_rest_bonds': lambda beta_m, beta_0 : np.sqrt(beta_m / beta_0),
                 'lambda_rest_angles' : lambda beta_m, beta_0 : np.sqrt(beta_m / beta_0),
                 'lambda_rest_torsions' : lambda beta_m, beta_0 : np.sqrt(beta_m / beta_0),
                 'lambda_rest_electrostatics' : lambda beta_m, beta_0 : np.sqrt(beta_m / beta_0) - 1,
                 'lambda_rest_sterics' : lambda beta_m, beta_0 : beta_m / beta_0 - 1
                 }

        for parameter_name in self._parameters:
            if self._parameters[parameter_name] is not None:
                new_value = lambda_functions[parameter_name](beta_m, beta_0)
                setattr(self, parameter_name, new_value)

In [71]:
# Set temperatures for each thermodynamic state
n_replicas = 40  # Number of temperature replicas
T_min = 300 * unit.kelvin  # Minimum temperature (i.e., temperature of desired distribution)
T_max = 500 * unit.kelvin  # Maximum temperature
temperatures = [T_min + (T_max - T_min) * (math.exp(float(i) / float(n_replicas-1)) - 1.0) / (math.e - 1.0)
                for i in range(n_replicas)]
temperatures

[Quantity(value=300.0, unit=kelvin),
 Quantity(value=303.02308784039593, unit=kelvin),
 Quantity(value=306.1246930852075, unit=kelvin),
 Quantity(value=309.3068550343565, unit=kelvin),
 Quantity(value=312.571665953654, unit=kelvin),
 Quantity(value=315.92127245046197, unit=kelvin),
 Quantity(value=319.3578768850833, unit=kelvin),
 Quantity(value=322.88373881881006, unit=kelvin),
 Quantity(value=326.50117649958094, unit=kelvin),
 Quantity(value=330.21256838622503, unit=kelvin),
 Quantity(value=334.02035471229453, unit=kelvin),
 Quantity(value=337.9270390905137, unit=kelvin),
 Quantity(value=341.9351901588998, unit=kelvin),
 Quantity(value=346.0474432696381, unit=kelvin),
 Quantity(value=350.2665022218214, unit=kelvin),
 Quantity(value=354.59514103919264, unit=kelvin),
 Quantity(value=359.036205794061, unit=kelvin),
 Quantity(value=363.5926164785894, unit=kelvin),
 Quantity(value=368.26736892468364, unit=kelvin),
 Quantity(value=373.06353677374716, unit=kelvin),
 Quantity(value=377.98427

In [72]:
# Create reference thermodynamic state
rest_state = RESTState.from_system(rest_system)
thermostate = ThermodynamicState(rest_system, temperature=T_min)
compound_thermodynamic_state = CompoundThermodynamicState(thermostate, composable_states=[rest_state])

In [73]:
# Create thermodynamic states
sampler_state =  SamplerState(pdb.positions, box_vectors=rest_system.getDefaultPeriodicBoxVectors())
beta_0 = 1/(kB*T_min)
thermodynamic_state_list = []
sampler_state_list = []
for temperature in temperatures:
    # Create a thermodynamic state with REST interactions scaled to the given temperature
    beta_m = 1/(kB*temperature)
    compound_thermodynamic_state_copy = copy.deepcopy(compound_thermodynamic_state)
    compound_thermodynamic_state_copy.set_rest_parameters(beta_m, beta_0)
    thermodynamic_state_list.append(compound_thermodynamic_state_copy)

    # Generate a sampler_state with minimized positions
    context, integrator = cache.global_context_cache.get_context(compound_thermodynamic_state_copy)
    sampler_state.apply_to_context(context, ignore_velocities=True)
    mm.LocalEnergyMinimizer.minimize(context)
    sampler_state.update_from_context(context)
    sampler_state_list.append(copy.deepcopy(sampler_state))

In [74]:
# Set up sampler
_logger = logging.getLogger()
_logger.setLevel(logging.DEBUG)
move = mcmc.LangevinDynamicsMove(timestep=4*unit.femtoseconds, n_steps=10000)
simulation = ReplicaExchangeSampler(mcmc_moves=move, number_of_iterations=400)

# Run repex
reporter_file = "test_run.nc"
reporter = multistate.MultiStateReporter(reporter_file, checkpoint_interval=100)
simulation.create(thermodynamic_states=thermodynamic_state_list,
                  sampler_states=sampler_state_list,
                  storage=reporter)



Please cite the following:

        Friedrichs MS, Eastman P, Vaidyanathan V, Houston M, LeGrand S, Beberg AL, Ensign DL, Bruns CM, and Pande VS. Accelerating molecular dynamic simulations on graphics processing unit. J. Comput. Chem. 30:864, 2009. DOI: 10.1002/jcc.21209
        Eastman P and Pande VS. OpenMM: A hardware-independent framework for molecular simulations. Comput. Sci. Eng. 12:34, 2010. DOI: 10.1109/MCSE.2010.27
        Eastman P and Pande VS. Efficient nonbonded interactions for molecular dynamics on a graphics processing unit. J. Comput. Chem. 31:1268, 2010. DOI: 10.1002/jcc.21413
        Eastman P and Pande VS. Constant constraint matrix approximation: A robust, parallelizable constraint method for molecular simulations. J. Chem. Theor. Comput. 6:434, 2010. DOI: 10.1021/ct900463w
        Chodera JD and Shirts MR. Replica exchange and expanded ensemble simulations as Gibbs multistate: Simple improvements for enhanced mixing. J. Chem. Phys., 135:194110, 2011. DOI:10.1063/

In [None]:
simulation.run()


******* JAX 64-bit mode is now on! *******
*     JAX is now set to 64-bit mode!     *
*   This MAY cause problems with other   *
*      uses of JAX in the same code.     *
******************************************

DEBUG:2025-09-14 10:03:10,688:jax._src.dispatch:198: Finished tracing + transforming _reduce_min for pjit in 0.001800537 sec
DEBUG:2025-09-14 10:03:10,691:jax._src.dispatch:198: Finished tracing + transforming subtract for pjit in 0.000864506 sec
DEBUG:2025-09-14 10:03:10,692:jax._src.dispatch:198: Finished tracing + transforming subtract for pjit in 0.000325918 sec
DEBUG:2025-09-14 10:03:10,694:jax._src.dispatch:198: Finished tracing + transforming not_equal for pjit in 0.000584126 sec
DEBUG:2025-09-14 10:03:10,695:jax._src.dispatch:198: Finished tracing + transforming _broadcast_arrays for pjit in 0.000298500 sec
DEBUG:2025-09-14 10:03:10,696:jax._src.dispatch:198: Finished tracing + transforming _where for pjit in 0.001558304 sec
DEBUG:2025-09-14 10:03:10,697:jax._src.d

In [None]:
from openmmtools.multistate import MultiStateReporter
import mdtraj as md
import numpy as np


reporter_path = "test_run.nc"
pdb_path = 'ligand_in_solvent.pdb'


reporter = MultiStateReporter(reporter_path, open_mode='r')
pdb = md.load(pdb_path)
topology = pdb.topology

# Read the total number of saved iterations
last_iteration = reporter.read_last_iteration()
    
# Create an empty list to store the positions from each iteration
positions = []
    
# Loop through all iterations to extract the positions
# We add 1 to the last_iteration because the iterations are zero-indexed
print(f"Reading {last_iteration + 1} iterations from {reporter_path}...")
for i in range(last_iteration + 1):
    # Read the sampler states for the current iteration
    sampler_states = reporter.read_sampler_states(iteration=i)
        
    # Check if sampler_states is not None and not empty
    if sampler_states is None or len(sampler_states) == 0:
        print(f"Warning: No sampler states found for iteration {i}. Skipping.")
        continue

    # Extract the positions from the first sampler state.
    # This assumes a single-replica simulation or that you only need
    # the trajectory from the first replica.
    print(f"Reading positions from iteration {i}...")
    pos = sampler_states[0].positions
        
    # Append the positions (in nanometers) to our list
    positions.append(pos)

reporter.close()    

traj = md.Trajectory(np.array(positions), topology)
output_dcd_path = "rest_traj.dcd"
traj.save_dcd(output_dcd_path)



Reading 101 iterations from test_run.nc...
Reading positions from iteration 0...
Reading positions from iteration 100...
Conversion complete! The DCD file has been created at: rest_traj.dcd
