## Add Resampling to FultonMarket

Each Randolph run should become only a 5 ns simulation

Is the skip parameter necessary?

Remove the reduced potentials stuff from Fulton Market Analysis?


In [40]:
from openmm import *
from openmm.app import *
from openmmtools import states, mcmc, multistate, cache
from openmmtools.states import SamplerState, ThermodynamicState
from openmmtools.multistate import ParallelTemperingSampler, ReplicaExchangeSampler, MultiStateReporter
from openmmtools.utils import get_fastest_platform
from openmmtools.utils.utils import TrackedQuantity
import tempfile
import os, sys
sys.path.append('../MotorRow')
import numpy as np
np.seterr(divide='ignore', invalid='ignore')
import netCDF4 as nc
from typing import List
from datetime import datetime
import mdtraj as md
from FultonMarketUtils import *
from analysis.FultonMarketAnalysis import *
from copy import deepcopy
import faulthandler
faulthandler.enable()

#from Randolph import Randolph

#### Unrestrained PT

In [41]:
class FultonMarket_PT():
    """
    Simple Parallel Tempering Class
    Using one structure and one system as input, run parallel tempering
    Temperatures are set as a geometric distribution between tmin and tmax

    Typical usage is: construct, call set_experimental_params(), then run_sim_loop().
    Each pass of the loop builds a Randolph object (defined outside this class) from
    self.Randolph_Params, runs it, and saves its results under <output_dir>/saved_variables/<sim_no>.
    """
    def __init__(self, input_pdb: str, input_system: str, input_state: str=None):
        """
        Initialize a Fulton Market obj. 

        Parameters:
        -----------
            input_pdb (str):
                String path to pdb to run simulation. 

            input_system (str):
                String path to OpenMM system (.xml extension) file that contains parameters for simulation. 

            input_state (str):
                String path to OpenMM state (.xml extension) file that contains state for reference. 


        Returns:
        --------
            FultonMarket obj.
        """
        printf('Welcome to FultonMarket.')
        
        # Unpack .pdb
        self.input_pdb = input_pdb
        self.pdb = PDBFile(input_pdb)
        self.init_positions = self.pdb.getPositions(asNumpy=True)
        printf(f'Found input_pdb: {input_pdb}')


        # Unpack .xml (context manager so the file handle is always closed)
        with open(input_system, 'r') as system_file:
            self.system = XmlSerializer.deserialize(system_file.read())
        self.init_box_vectors = self.system.getDefaultPeriodicBoxVectors()
        printf(f'Found input_system: {input_system}')


        # Build state: self.context is only defined when an input_state is supplied;
        # run_sim_loop() checks for it with hasattr().
        if input_state is not None:
            dummy_integrator = LangevinIntegrator(300, 0.01, 2)
            sim = Simulation(self.pdb.topology, self.system, dummy_integrator)
            sim.loadState(input_state)
            self.context = sim.context
            printf(f'Found input_state: {input_state}')
        
    def set_experimental_params(self, total_sim_time: unit.Quantity, iteration_length: unit.Quantity,
                                dt:unit.Quantity=2.0*unit.femtosecond, T_min:unit.Quantity=300*unit.kelvin,
                                T_max:unit.Quantity=367.447*unit.kelvin, n_replicates: int=12,
                                sim_length=50*unit.nanosecond,
                                output_dir: str=os.path.join(os.getcwd(), 'FultonMarketPT_output/')):
        """
        Set the parameters that will be used for simulations with Randolph
        Parameters:
        -----------
            total_sim_time (unit.Quantity): Aggregate simulation time from all replicates in nanoseconds.
            iteration_length (unit.Quantity): Simulation time between swapping replicates in nanoseconds. 
            dt (unit.Quantity): Timestep for simulation. Default is 2.0 femtoseconds.
            T_min (unit.Quantity): Minimum temperature in Kelvin. This state will serve as the reference state. Default is 300 K.
            T_max (unit.Quantity): Maximum temperature in Kelvin. Default is 367.447 K.
            n_replicates (int): Number of replicates, meaning number of states between T_min and T_max.
                States are automatically built with a geometric distribution towards T_min. Default is 12.
            sim_length (unit.Quantity): Time (in nanoseconds) for separate Randolph runs. Default 50 ns
            output_dir (str): String path to output directory to store files. Default is 'FultonMarketPT_output' in the current working directory. 

        Raises:
        -------
            AssertionError: if any existing directory inside <output_dir>/saved_variables holds fewer than 5 entries.
        """
        
        # Store variables
        self.total_sim_time = total_sim_time
        self.iteration_length = iteration_length
        self.sim_length = sim_length
        self.temperatures = geometric_distribution(T_min, T_max, n_replicates)
        
        # Default initial positions and box vectors (one copy per replicate)
        self.init_positions = [self.init_positions for _ in range(n_replicates)]
        self.init_box_vectors = [self.init_box_vectors for _ in range(n_replicates)]
        
        ref_state = states.ThermodynamicState(system=self.system, temperature=self.temperatures[0], pressure=1.0*unit.bar)
        self.output_ncdf = os.path.join(output_dir, 'output.ncdf')
        # Kept on the instance because run_sim_loop() has to delete this file between runs
        self.checkpoint_ncdf = os.path.join(output_dir, 'output_checkpoint.ncdf')
        self.save_dir = os.path.join(output_dir, 'saved_variables')
        # makedirs also creates output_dir itself when it does not exist yet
        os.makedirs(self.save_dir, exist_ok=True)
        
        # Assert that no empty save directories have been made
        assert all(len(os.listdir(os.path.join(self.save_dir, saved_run))) >= 5
                   for saved_run in os.listdir(self.save_dir)), "You may have an empty save directory, please remove empty or incomplete save directories before continuing :)"
        
        # Configure experiment parameters
        # The number of already-saved runs doubles as the index of the next run
        self.sim_no = len(os.listdir(self.save_dir))
        self.total_n_sims = round(self.total_sim_time / self.sim_length)
        
        printf(f'Found total simulation time of {self.total_sim_time}')
        printf(f'Calculated total_n_sims to be: {self.total_n_sims}')
        printf(f'Found iteration length of {self.iteration_length}')
        printf(f'Found timestep of {dt}')
        printf(f'Found number of replicates {n_replicates}')
        printf(f'Found output_dir: {output_dir}')
        printf(f'Found Temperature Schedule: {[np.round(T._value, 1) for T in self.temperatures]} Kelvin')
        printf(f'Found n_sims_completed to be: {self.sim_no}')
                
        self.Randolph_Params = dict(sim_no=self.sim_no, # which Randolph run
                                    sim_time=self.sim_length, #Length of the Randolph run
                                    system=self.system, #Openmm system
                                    ref_state=ref_state, #Reference Thermodynamic State
                                    temperatures=self.temperatures, #List of temperatures
                                    init_positions=self.init_positions, #Initial Positions
                                    init_box_vectors=self.init_box_vectors, #Initial Box vectors
                                    output_dir=output_dir, #Directory for Output
                                    output_ncdf=self.output_ncdf, # NetCDF filename
                                    checkpoint_ncdf=self.checkpoint_ncdf, #Checkpoint NetCDF filename
                                    iter_length=iteration_length, #Time between swaps
                                    dt=dt) #Step time
    
    
    def run_sim_loop(self, init_overlap_thresh:float=0.5, term_overlap_thresh:float=0.35,):
        """
        Loop through Randolph runs
        
        Parameters:
            init_overlap_thresh (float):
                Acceptance rate threshold during first 50 ns simulation to cause restart. Default is 0.50. 

            term_overlap_thresh (float):
                Terminal acceptance rate. If the minimum acceptance rate ever falls below this threshold simulation with restart. Default is 0.35.
        """
        
        printf(f'Found initial acceptance rate threshold: {init_overlap_thresh}')
        printf(f'Found terminal acceptance rate threshold: {term_overlap_thresh}')
        printf(f'Beginning Randolph Runs...')
        
        while self.sim_no < self.total_n_sims:
            #Adjust Parameters
            if self.sim_no > 0:
                # Continue from the previously saved run
                self._load_initial_args() #sets positions, velocities, box_vecs, and temperatures
                self.Randolph_Params['temperatures'] = self.temperatures
                self.Randolph_Params['init_velocities'] = self.init_velocities
                self.Randolph_Params['init_positions'] = self.init_positions
                self.Randolph_Params['init_box_vectors'] = self.init_box_vectors
                
            elif hasattr(self, 'context'):
                # First run and an input_state was given to __init__
                self.Randolph_Params['context'] = self.context
            
            simulation = Randolph(**self.Randolph_Params)
            
            # Run simulation
            simulation.main(init_overlap_thresh=init_overlap_thresh, term_overlap_thresh=term_overlap_thresh)

            # Save simulation
            self.temperatures = simulation.save_simulation(self.save_dir)
            
            # Delete output.ncdf files if not last simulation 
            if self.sim_no + 1 != self.total_n_sims:
                os.remove(self.output_ncdf)
                # Read from Randolph_Params, which always holds the checkpoint path
                os.remove(self.Randolph_Params['checkpoint_ncdf'])

            # Update counter
            self.sim_no += 1
    
    
    def _load_initial_args(self) -> None:
        """
        Load the starting arguments for the next run from the previous run's save directory.

        Sets self.temperatures, self.init_positions, self.init_box_vectors and
        self.init_velocities. When no velocities file can be read, self.init_velocities
        and self.context are set to None. When the saved positions / box vectors / states
        cannot be read, everything is recovered from self.output_ncdf instead.
        """
        # Get last directory
        load_no = self.sim_no - 1
        load_dir = os.path.join(self.save_dir, str(load_no))
        
        # Load args (not in correct shapes)
        temperatures = np.load(os.path.join(load_dir, 'temperatures.npy'))
        # Attach units to the values just loaded from disk (not to the stale in-memory schedule)
        self.temperatures = [t*unit.kelvin for t in temperatures]
        
        try:
            init_positions = np.load(os.path.join(load_dir, 'positions.npy'))[-1]
            init_box_vectors = np.load(os.path.join(load_dir, 'box_vectors.npy'))[-1]
            state_inds = np.load(os.path.join(load_dir, 'states.npy'))[-1]
        except (OSError, ValueError, EOFError, IndexError):
            # Saved arrays are missing or unreadable: fall back to the raw netCDF output
            init_velocities, init_positions, init_box_vectors, state_inds = self._recover_arguments()
        else:
            try:
                init_velocities = np.load(os.path.join(load_dir, 'velocities.npy'))
            except (OSError, ValueError, EOFError):
                # Velocities are optional; continue without them
                init_velocities = None

        # Reshape: re-order the per-replica arrays so that index i holds the replica in state i
        reshaped_init_positions = np.empty((init_positions.shape))
        reshaped_init_box_vectors = np.empty((init_box_vectors.shape))
        reshaped_init_velocities = None
        if init_velocities is not None:
            reshaped_init_velocities = np.empty((init_velocities.shape))
        for state in range(len(temperatures)):
            rep_ind = np.where(state_inds == state)[0]
            reshaped_init_box_vectors[state] = init_box_vectors[rep_ind]
            reshaped_init_positions[state] = init_positions[rep_ind]
            if reshaped_init_velocities is not None:
                reshaped_init_velocities[state] = init_velocities[rep_ind]

        def _tracked(values, value_unit):
            # Wrap a plain array as an unmasked, unit-bearing TrackedQuantity
            masked = np.ma.masked_array(data=values, mask=False, fill_value=1e+20)
            return TrackedQuantity(unit.Quantity(value=masked, unit=value_unit))

        # Convert to quantities    
        self.init_positions = _tracked(reshaped_init_positions, unit.nanometer)
        self.init_box_vectors = _tracked(reshaped_init_box_vectors, unit.nanometer)
        if reshaped_init_velocities is not None:
            self.init_velocities = _tracked(reshaped_init_velocities, unit.nanometer / unit.picosecond)
        else:
            self.init_velocities = None
            self.context = None
            
    
    def _recover_arguments(self):
        """
        Read the last frame of velocities, positions, box vectors and state indices
        from self.output_ncdf.

        Returns:
        --------
            (velocities, positions, box_vectors, state_inds)
        """
        # Context manager guarantees the dataset is closed even if a read fails
        with nc.Dataset(self.output_ncdf, 'r') as ncfile:
            velocities = ncfile.variables['velocities'][-1].data
            positions = ncfile.variables['positions'][-1].data
            box_vectors = ncfile.variables['box_vectors'][-1].data
            state_inds = ncfile.variables['states'][-1].data
        
        return velocities, positions, box_vectors, state_inds
        

#### PT with Restraints

In [42]:
class FultonMarket_PTwRE(FultonMarket_PT):
    """
    Parallel Tempering with Restraints
    Should be functionally similar to Parallel Tempering, with the following additions:
        Definition of spring constant for strength of restraints
        Definition of spring centers for these restraints
        Restraints in all replicates are to the same spring centers
        
    """
    def __init__(self, input_pdb:str, input_system:str, input_state:str=None):
        """
        Initialize exactly like FultonMarket_PT.

        Parameters:
        -----------
            input_pdb (str): String path to pdb to run simulation.
            input_system (str): String path to OpenMM system (.xml extension) file.
            input_state (str): Optional string path to OpenMM state (.xml extension) file.
        """
        # super() already binds self; forward only the real arguments
        super().__init__(input_pdb, input_system, input_state=input_state)
        printf("Finished Initializing Fulton Market PT with Restraints")
    
    
    def set_experimental_params(self, total_sim_time: float, iteration_length: float,
                                restrained_atoms_dsl:str, K=83.68*spring_constant_unit,
                                dt: float=2.0, T_min: float=300, T_max: float=367.447,
                                n_replicates: int=12, sim_length=50,
                                output_dir: str=os.path.join(os.getcwd(), 'FultonMarket_output/')):
        """
        Set the parameters that will be used for simulations with Randolph
        Parameters:
        -----------
            total_sim_time: Aggregate simulation time from all replicates in nanoseconds.
            iteration_length: Simulation time between swapping replicates in nanoseconds.
            restrained_atoms_dsl (string): A selection string for mdtraj to select atoms to be restrained.
            K (unit.Quantity or list of unit.Quantity): The force constant for the restrained atoms, either one
                value shared by all replicates or one value per replicate. Default is 83.68*spring_constant_unit.
            dt: Timestep for simulation. Default is 2.0.
            T_min: Minimum temperature in Kelvin. This state will serve as the reference state. Default is 300.
            T_max: Maximum temperature in Kelvin. Default is 367.447.
            n_replicates (int): Number of replicates, meaning number of states between T_min and T_max.
                States are automatically built with a geometric distribution towards T_min. Default is 12.
            sim_length: Time (in nanoseconds) for separate Randolph runs. Default 50
            output_dir (str): String path to output directory to store files. Default is 'FultonMarket_output' in the current working directory. 

        NOTE(review): the defaults here are plain numbers while the parent method's defaults are
        unit.Quantity values -- confirm which form geometric_distribution / Randolph expect.

        Raises:
        -------
            TypeError: if K is neither a unit.Quantity nor a list.
            AssertionError: if K is a list containing non-Quantity entries or its length differs from the temperature schedule.
        """
        #Do the same parameter setting as in regular PT
        super().set_experimental_params(total_sim_time, iteration_length, dt=dt,
                                        T_min=T_min, T_max=T_max, n_replicates=n_replicates,
                                        sim_length=sim_length, output_dir=output_dir)
        
        #Set the special cases for restraints
        if isinstance(K, unit.Quantity):
            # One shared force constant for every replicate
            self.spring_constants = [K for _ in range(n_replicates)]
        elif isinstance(K, list):
            assert all(isinstance(k, unit.Quantity) for k in K), "Every spring constant in K must be a unit.Quantity"
            self.spring_constants = K
        else:
            # Previously this fell through and failed later with an AttributeError
            raise TypeError(f'K must be a unit.Quantity or a list of unit.Quantity, got {type(K).__name__}')
        
        assert len(self.spring_constants) == len(self.temperatures), "Mismatch in number of Spring Constants and Temperatures"
        printf('Running with restraints')
        # Make spring centers already selected against the dsl
        traj = md.load(self.input_pdb)
        inds = traj.top.select(restrained_atoms_dsl)
        spring_centers = traj.xyz[0, inds]
        
        #N_replicate array of spring_centers
        self.spring_centers = np.repeat(spring_centers[np.newaxis, :, :], n_replicates, axis=0)
        assert len(self.temperatures) == self.spring_centers.shape[0]
        
        #N_replicate array of atom indices to be restrained
        restrained_atom_indices = np.repeat(inds[np.newaxis, :], n_replicates, axis=0)
        assert len(self.temperatures) == restrained_atom_indices.shape[0]
        printf('Restraining All States to the Initial Positions of provided selection string')
        
        
        self.Randolph_Params['spring_constants'] = self.spring_constants
        self.Randolph_Params['restrained_atom_indices'] = restrained_atom_indices
        self.Randolph_Params['mdtraj_topology'] = md.Topology.from_openmm(self.pdb.topology)
        self.Randolph_Params['spring_centers'] = self.spring_centers
        
        
    def run_sim_loop(self, init_overlap_thresh:float=0.5, term_overlap_thresh:float=0.35,):
        """
        Loop through Randolph runs, carrying temperatures and spring constants between runs.

        Parameters:
            init_overlap_thresh (float): Initial acceptance rate threshold, passed to Randolph.main(). Default is 0.50.
            term_overlap_thresh (float): Terminal acceptance rate threshold, passed to Randolph.main(). Default is 0.35.
        """
        printf(f'Found initial acceptance rate threshold: {init_overlap_thresh}')
        printf(f'Found terminal acceptance rate threshold: {term_overlap_thresh}')
        printf(f'Beginning Randolph Runs...')
        while self.sim_no < self.total_n_sims:
            
            if self.sim_no > 0:
                self._load_initial_args() #sets positions, velocities, box_vecs, temperatures, and spring_constants
                self.Randolph_Params['init_velocities'] = self.init_velocities
                self.Randolph_Params['init_positions'] = self.init_positions
                self.Randolph_Params['init_box_vectors'] = self.init_box_vectors
                self.Randolph_Params['temperatures'] = self.temperatures
                self.Randolph_Params['spring_constants'] = self.spring_constants

            elif hasattr(self, 'context'):
                self.Randolph_Params['context'] = self.context
            
            simulation = Randolph(**self.Randolph_Params)
            
            # Run simulation
            simulation.main(init_overlap_thresh=init_overlap_thresh, term_overlap_thresh=term_overlap_thresh)

            # Save simulation
            self.temperatures, self.spring_constants, self.spring_centers = simulation.save_simulation(self.save_dir)
            
            # Delete output.ncdf files if not last simulation 
            if self.sim_no + 1 != self.total_n_sims:
                os.remove(self.output_ncdf)
                # The checkpoint path is stored in Randolph_Params by the parent's set_experimental_params
                os.remove(self.Randolph_Params['checkpoint_ncdf'])

            # Update counter
            self.sim_no += 1
    
    
    def _load_initial_args(self):
        """
        Load the parent's restart arguments, then the saved spring constants and spring centers
        from the previous run's save directory.
        """
        super()._load_initial_args()
        # Same directory the parent reads from: the most recently saved run
        load_dir = os.path.join(self.save_dir, str(self.sim_no - 1))
        spring_constants = np.load(os.path.join(load_dir, 'spring_constants.npy'))
        self.spring_constants = [s*spring_constant_unit for s in spring_constants]
        self.spring_centers = np.load(os.path.join(load_dir, 'spring_centers.npy'))

#### Umbrella Sampling

In [43]:
class FultonMarket_UmbSam(FultonMarket_PT):
    """
    Umbrella Sampling Replica Exchange
    This is a form of PTwRE where the temperature is held constant and the spring centers are changed
    This class also uses a previously run "trailblazing" to set the initial positions
        Replicates will not necessarily all have the same numbers of atoms
    Randolph Runs are only 5 nanoseconds, with resampling in between each run to obtain
        new initial positions for the next Randolph run
    """
    def __init__(self, input_pdb:str, input_system:str, input_state:str=None):
        """
        Initialize exactly like FultonMarket_PT.

        Parameters:
        -----------
            input_pdb (str): String path to pdb to run simulation.
            input_system (str): String path to OpenMM system (.xml extension) file.
            input_state (str): Optional string path to OpenMM state (.xml extension) file.
        """
        # super() already binds self; forward only the real arguments
        super().__init__(input_pdb, input_system, input_state=input_state)
        printf("Finished Initializing Fulton Market Umbrella Sampling")
        
    def set_experimental_params(self, total_sim_time: float, iteration_length: float,
                                restrained_atoms_dsl, spring_centers2_pdb:str,
                                init_positions_dcd, K=83.68*spring_constant_unit,
                                dt: float=2.0, T_min: float=300, T_max: float=367.447,
                                n_replicates: int=12, sim_length=50,
                                output_dir: str=os.path.join(os.getcwd(), 'FultonMarket_output/')):
        """
        Set the parameters that will be used for simulations with Randolph
        In this form of the simulation, the temperature of all replicates will be T_max
        Parameters:
        -----------
            total_sim_time: Aggregate simulation time from all replicates in nanoseconds.
            iteration_length: Simulation time between swapping replicates in nanoseconds.
            restrained_atoms_dsl (string or list of 2 strings): mdtraj selection string(s) for the atoms to be
                restrained. A 2-list gives one selection for input_pdb and one for spring_centers2_pdb.
            spring_centers2_pdb (string): A second pdb file, the other end state of Umbrella Sampling
            init_positions_dcd (string or list of strings): Dcd with initial positions, determines if trailblazing
                was done unilaterally or bilaterally by the quantity (string for uni, and 2-list of string for bi)
            K (unit.Quantity or list): The force constant for the restrained atoms, either one value shared by all
                replicates or one value per replicate. Default is 83.68*spring_constant_unit.
            dt: Timestep for simulation. Default is 2.0.
            T_min: Minimum temperature in Kelvin. Default is 300.
            T_max: Maximum temperature in Kelvin, used for every replicate. Default is 367.447.
            n_replicates (int): Number of replicates. Default is 12.
            sim_length: Time (in nanoseconds) for separate Randolph runs. Default 50
            output_dir (str): String path to output directory to store files. Default is 'FultonMarket_output' in the current working directory. 

        Raises:
        -------
            TypeError: if K or restrained_atoms_dsl has an unsupported type.
            AssertionError: on inconsistent replicate / frame counts or missing dcd files.
        """
        
        #Do the same parameter setting as in regular PT
        super().set_experimental_params(total_sim_time, iteration_length, dt=dt,
                                        T_min=T_min, T_max=T_max, n_replicates=n_replicates,
                                        sim_length=sim_length, output_dir=output_dir)
        
        #In this Umbrella Sampling mode - the temps should all be the max
        # Only attach kelvin when T_max is a bare number, so a Quantity is not given units twice
        t_max = T_max if isinstance(T_max, unit.Quantity) else T_max * unit.kelvin
        self.temperatures = [t_max for _ in range(n_replicates)]
        # The parent stored its geometric schedule in Randolph_Params; replace it so the
        # first run also uses the constant-temperature schedule.
        # NOTE(review): Randolph_Params['ref_state'] was still built by the parent from its
        # first (lowest) temperature -- confirm whether it should be rebuilt at T_max.
        self.Randolph_Params['temperatures'] = self.temperatures
        
        #Set the special cases for restraints
        if isinstance(K, unit.Quantity):
            self.spring_constants = [K for _ in range(n_replicates)]
        elif isinstance(K, list):
            self.spring_constants = K
        else:
            # Previously this fell through and failed later with an AttributeError
            raise TypeError(f'K must be a unit.Quantity or a list of unit.Quantity, got {type(K).__name__}')
        
        assert len(self.spring_constants) == len(self.temperatures), "Mismatch in number of Spring Constants and Temperatures"
        printf('Running with restraints')
        
        num_per_leg = n_replicates // 2
        # With an odd replicate count the first leg carries the extra (middle) replicate
        n_first_leg = num_per_leg + n_replicates % 2
        
        # Make spring centers already selected against the dsl
        if isinstance(restrained_atoms_dsl, list):
            assert len(restrained_atoms_dsl) == 2
            self.spring_centers, inds1, inds2 = make_interpolated_positions_array_from_selections(self.input_pdb, restrained_atoms_dsl[0],
                                                                                    spring_centers2_pdb, n_replicates,
                                                                                    selection_2=restrained_atoms_dsl[1])
            printf('Found Second Spring Centers and Made the Shifting Center Schedule using unique selection strings.')
            #In this case the restrained atoms are represented by different indices
            restrained_atom_indices = np.concatenate((np.repeat(inds1[np.newaxis, :], n_first_leg, axis=0),
                                                      np.repeat(inds2[np.newaxis, :], num_per_leg, axis=0)))
            
        elif isinstance(restrained_atoms_dsl, str):
            interpolated = make_interpolated_positions_array_from_selections(self.input_pdb, restrained_atoms_dsl,
                                                                             spring_centers2_pdb, n_replicates)
            # NOTE(review): the helper's return shape for a single selection is not visible here.
            # If it returns the same (centers, inds1, ...) tuple as the two-selection call, unpack it;
            # otherwise treat the result as the centers and select the indices directly.
            if isinstance(interpolated, tuple):
                self.spring_centers, inds1 = interpolated[0], interpolated[1]
            else:
                self.spring_centers = interpolated
                inds1 = md.load(self.input_pdb).top.select(restrained_atoms_dsl)
            printf('Found Second Spring Centers and Made the Shifting Center Schedule.')
            #In the case one selection string is used, the indices are the same for all replicates
            restrained_atom_indices = np.repeat(inds1[np.newaxis, :], n_replicates, axis=0)
        
        else:
            # Previously this fell through and failed later with a NameError
            raise TypeError('restrained_atoms_dsl must be a selection string or a list of two selection strings, '
                            f'got {type(restrained_atoms_dsl).__name__}')
        
        #Assign some of these values
        self.Randolph_Params['spring_constants'] = self.spring_constants
        self.Randolph_Params['restrained_atom_indices'] = restrained_atom_indices
        self.Randolph_Params['mdtraj_topology'] = md.Topology.from_openmm(self.pdb.topology)
        self.Randolph_Params['spring_centers'] = self.spring_centers
        
        
        # In the Umbrella Sampling case, initial positions should also be taken from those obtained by
        #   Unilateral or Bilateral trailblazing
        if isinstance(init_positions_dcd, str):
            printf(f"Assigning Initial Positions from Unilateral Trailblazing DCD {init_positions_dcd}")
            assert os.path.exists(init_positions_dcd), init_positions_dcd
            init_traj = md.load(init_positions_dcd, top=self.input_pdb)
            assert n_replicates == init_traj.n_frames
            self.Randolph_Params['init_positions'] = np.array([init_traj.openmm_positions(i) for i in range(n_replicates)])
            self.Randolph_Params['init_box_vectors'] = np.array([init_traj.openmm_boxes(i) for i in range(n_replicates)])
            
        
        elif isinstance(init_positions_dcd, list):
            printf(f"Assigning Initial Positions from Bilateral Trailblazing DCDs {init_positions_dcd}")
            assert all(os.path.exists(dcd_fn) for dcd_fn in init_positions_dcd), init_positions_dcd
            
            init_traj1 = md.load(init_positions_dcd[0], top=self.input_pdb)
            init_traj2 = md.load(init_positions_dcd[1], top=spring_centers2_pdb)
            assert init_traj1.n_frames == init_traj2.n_frames
            top1, top2 = init_traj1.top, init_traj2.top
            
            #In the bilateral case, feed the initial positions for the second half backward (reversed)
            #Also in the bilateral case, there may be two different topologies for these states
            assert n_first_leg == init_traj1.n_frames
            init_positions = np.concatenate((np.array([init_traj1.openmm_positions(i) for i in range(n_first_leg)]),
                                             np.array([init_traj2.openmm_positions(i) for i in range(num_per_leg)][::-1])))
            init_box_vecs = np.concatenate((np.array([init_traj1.openmm_boxes(i) for i in range(n_first_leg)]),
                                            np.array([init_traj2.openmm_boxes(i) for i in range(num_per_leg)][::-1])))
            topologies = [top1 for _ in range(n_first_leg)] + [top2 for _ in range(num_per_leg)]
            
            self.Randolph_Params['init_positions'] = init_positions
            self.Randolph_Params['init_box_vectors'] = init_box_vecs
            self.Randolph_Params['mdtraj_topology'] = topologies
        
    
    def run_sim_loop(self, init_overlap_thresh:float=0.5, term_overlap_thresh:float=0.35,):
        """
        Loop through Randolph runs, carrying temperatures and spring constants between runs.

        Parameters:
            init_overlap_thresh (float): Initial acceptance rate threshold, passed to Randolph.main(). Default is 0.50.
            term_overlap_thresh (float): Terminal acceptance rate threshold, passed to Randolph.main(). Default is 0.35.
        """
        printf(f'Found initial acceptance rate threshold: {init_overlap_thresh}')
        printf(f'Found terminal acceptance rate threshold: {term_overlap_thresh}')
        printf(f'Beginning Randolph Runs...')
        while self.sim_no < self.total_n_sims:
            
            if self.sim_no > 0:
                self._load_initial_args() #sets positions, velocities, box_vecs, temperatures, and spring_constants
                self.Randolph_Params['init_velocities'] = self.init_velocities
                self.Randolph_Params['init_positions'] = self.init_positions
                self.Randolph_Params['init_box_vectors'] = self.init_box_vectors
                self.Randolph_Params['temperatures'] = self.temperatures
                self.Randolph_Params['spring_constants'] = self.spring_constants

            elif hasattr(self, 'context'):
                self.Randolph_Params['context'] = self.context
            
            simulation = Randolph(**self.Randolph_Params)
            
            # Run simulation
            simulation.main(init_overlap_thresh=init_overlap_thresh, term_overlap_thresh=term_overlap_thresh)

            # Save simulation
            self.temperatures, self.spring_constants, self.spring_centers = simulation.save_simulation(self.save_dir)
            
            # Delete output.ncdf files if not last simulation 
            if self.sim_no + 1 != self.total_n_sims:
                os.remove(self.output_ncdf)
                # The checkpoint path is stored in Randolph_Params by the parent's set_experimental_params
                os.remove(self.Randolph_Params['checkpoint_ncdf'])

            # Update counter
            self.sim_no += 1
            
    def _load_initial_args(self):
        """
        Load the parent's restart arguments, then the saved spring constants and spring centers
        from the previous run's save directory.
        """
        super()._load_initial_args()
        # Same directory the parent reads from: the most recently saved run
        load_dir = os.path.join(self.save_dir, str(self.sim_no - 1))
        spring_constants = np.load(os.path.join(load_dir, 'spring_constants.npy'))
        self.spring_constants = [s*spring_constant_unit for s in spring_constants]
        self.spring_centers = np.load(os.path.join(load_dir, 'spring_centers.npy'))

#### Previous Fulton Market

In [44]:
class FultonMarket():
    """
    Replica exchange driver.

    Loads a pdb, an OpenMM system and optionally a saved state, then (through
    ``run``) executes a series of Randolph simulations, saving the results of
    each one so that the next one can start from them.
    """

    def __init__(self, input_pdb: str, input_system: str, input_state: str=None):
        """
        Initialize a Fulton Market obj. 

        Parameters:
        -----------
            input_pdb (str):
                String path to pdb to run simulation. 

            input_system (str):
                String path to OpenMM system (.xml extension) file that contains parameters for simulation. 

            input_state (str):
                String path to OpenMM state (.xml extension) file that contains state for reference. 
                When omitted, no context is built and self.context is not set.


        Returns:
        --------
            FultonMarket obj.
        """
        print(datetime.now().strftime("%m/%d/%Y %H:%M:%S") + '//' + 'Welcome to FultonMarket.', flush=True)
        
        # Unpack .pdb
        self.input_pdb = input_pdb
        self.pdb = PDBFile(input_pdb)
        self.init_positions = self.pdb.getPositions(asNumpy=True)
        print(datetime.now().strftime("%m/%d/%Y %H:%M:%S") + '//' + 'Found input_pdb:', input_pdb, flush=True)


        # Unpack .xml (context manager so the file handle is always closed)
        with open(input_system, 'r') as system_file:
            self.system = XmlSerializer.deserialize(system_file.read())
        self.init_box_vectors = self.system.getDefaultPeriodicBoxVectors()
        print(datetime.now().strftime("%m/%d/%Y %H:%M:%S") + '//' + 'Found input_system:', input_system, flush=True)


        # Build state
        if input_state is not None:
            # The integrator is only needed to construct a Simulation whose
            # context can hold the loaded state.
            integrator = LangevinIntegrator(300, 0.01, 2)
            sim = Simulation(self.pdb.topology, self.system, integrator)
            sim.loadState(input_state)
            self.context = sim.context
            print(datetime.now().strftime("%m/%d/%Y %H:%M:%S") + '//' + 'Found input_state:', input_state, flush=True)

    def run(self, 
            total_sim_time: float, 
            iteration_length: float, 
            dt: float=2.0,
            T_min: float=300, 
            T_max: float=367.447, 
            n_replicates: int=12, 
            sim_length=50,
            init_overlap_thresh: float=0.5, 
            term_overlap_thresh: float=0.35,
            output_dir: str=os.path.join(os.getcwd(), 'FultonMarket_output/'),
            restrained_atoms_dsl:str=None,
            spring_centers2_pdb:str=None,
            init_positions_dcd:str=None,
            K=83.68):
        """
        Run parallel tempering replica exchange.
        if restrained_atoms_dsl is provided, then parallel tempering will also be restrained to the initial positions with spring constants K
        if spring_centers2_pdb is also provided, then the spring centers will shift between the initial positions and spring_centers2, and every state is run at T_max

        Parameters:
        -----------
            total_sim_time (float):
                Aggregate simulation time from all replicates in nanoseconds.

            iteration_length (float):
                Specify the amount of time between swapping replicates in nanoseconds. 

            dt (float):
                Timestep for simulation. Default is 2.0 femtoseconds.

            T_min (float):
                Minimum temperature in Kelvin. This state will serve as the reference state. Default is 300 K.

            T_max (float):
                Maximum temperature in Kelvin. Default is 367.447 K.

            n_replicates (int):
                Number of replicates, meaning number of states between T_min and T_max. States are automatically built with a geometric distribution towards T_min. Default is 12.

            sim_length (float):
                Length in nanoseconds of each individual simulation; total_sim_time is split into simulations of this length. Default is 50.

            init_overlap_thresh (float):
                Acceptance rate threshold during the first simulation to cause restart. Default is 0.50. 

            term_overlap_thresh (float):
                Terminal acceptance rate. If the minimum acceptance rate ever falls below this threshold the simulation will restart. Default is 0.35.

            output_dir (str):
                String path to output directory to store files. Default is 'FultonMarket_output' in the current working directory.

            restrained_atoms_dsl (str):
                MDTraj selection string, selected atoms will be restrained

            spring_centers2_pdb (str):
                String path to a second pdb. When given, the spring centers of the states are interpolated between the input pdb and this pdb.

            init_positions_dcd (str):
                String path to a .dcd whose frames provide the initial positions and box vectors, one frame per state. Only read when spring_centers2_pdb is also given.

            K (float):
                Spring constant value (multiplied by spring_constant_unit) applied to every state when restrained_atoms_dsl is given. Default is 83.68.
        """

        # Store variables
        self.total_sim_time = total_sim_time
        self.restrained_atoms_dsl = restrained_atoms_dsl
        self.init_positions_dcd = init_positions_dcd
        self.temperatures = [temp*unit.kelvin for temp in geometric_distribution(T_min, T_max, n_replicates)]

        # Default initial positions and box vectors
        self.init_positions = [self.init_positions for i in range(n_replicates)]
        self.init_box_vectors = [self.init_box_vectors for i in range(n_replicates)]
        
        
        if restrained_atoms_dsl is not None: # Every state gets the same spring constant K
            self.spring_constants = [K * spring_constant_unit for i in range(n_replicates)]
            assert len(self.spring_constants) == len(self.temperatures)
            print(datetime.now().strftime("%m/%d/%Y %H:%M:%S") + '//' + 'Running with restraints', flush=True)
        else:
            self.spring_constants = None
        
        if spring_centers2_pdb is not None:
            # Unpack .pdb
            self.spring_centers = make_interpolated_positions_array(self.input_pdb, spring_centers2_pdb, n_replicates) #All atoms, not just protein
            print(datetime.now().strftime("%m/%d/%Y %H:%M:%S") + '//' + 'Found Second Spring Centers and Made the Shifting Center Schedule', flush=True)


            # Get the inital positions and box vectors
            if self.init_positions_dcd is not None:
                assert os.path.exists(self.init_positions_dcd), self.init_positions_dcd
                init_traj = md.load(self.init_positions_dcd, top=self.input_pdb)
                # The number of states follows the number of frames supplied
                n_replicates = init_traj.n_frames
                self.init_positions = TrackedQuantity(unit.Quantity(value=np.ma.masked_array(data=init_traj.xyz, mask=False, fill_value=1e+20), unit=unit.nanometer))
                self.init_box_vectors = TrackedQuantity(unit.Quantity(value=np.ma.masked_array(data=init_traj.unitcell_vectors, mask=False, fill_value=1e+20), unit=unit.nanometer))
            
            
            # Additionally, in this Umbrella Sampling mode - the temps should all be the max
            self.temperatures = [T_max * unit.kelvin for temp in geometric_distribution(T_min, T_max, n_replicates)]
            
        elif self.spring_constants is not None:
            self.spring_centers = self.init_positions.copy()
            print(datetime.now().strftime("%m/%d/%Y %H:%M:%S") + '//' + 'Restraining All States to the Initial Positions', flush=True)
        else:
            self.spring_centers = None

        ref_state = states.ThermodynamicState(system=self.system, temperature=self.temperatures[0], pressure=1.0*unit.bar)
        self.output_ncdf = os.path.join(output_dir, 'output.ncdf')
        checkpoint_ncdf = os.path.join(output_dir, 'output_checkpoint.ncdf')
        self.save_dir = os.path.join(output_dir, 'saved_variables')
        # makedirs also creates output_dir itself when it does not exist yet
        os.makedirs(self.save_dir, exist_ok=True)

        print(datetime.now().strftime("%m/%d/%Y %H:%M:%S") + '//' + 'Found total simulation time of', total_sim_time, 'nanoseconds', flush=True)
        print(datetime.now().strftime("%m/%d/%Y %H:%M:%S") + '//' + 'Found iteration length of', iteration_length, 'nanoseconds', flush=True)
        print(datetime.now().strftime("%m/%d/%Y %H:%M:%S") + '//' + 'Found timestep of', dt, 'femtoseconds', flush=True)
        print(datetime.now().strftime("%m/%d/%Y %H:%M:%S") + '//' + 'Found number of replicates', n_replicates, flush=True)
        print(datetime.now().strftime("%m/%d/%Y %H:%M:%S") + '//' + 'Found initial acceptance rate threshold', init_overlap_thresh, flush=True)
        print(datetime.now().strftime("%m/%d/%Y %H:%M:%S") + '//' + 'Found terminal acceptance rate threshold', term_overlap_thresh, flush=True)
        print(datetime.now().strftime("%m/%d/%Y %H:%M:%S") + '//' + 'Found output_dir', output_dir, flush=True)
        print(datetime.now().strftime("%m/%d/%Y %H:%M:%S") + '//' + 'Found Temperature Schedule', [np.round(T._value, 1) for T in self.temperatures], 'Kelvin', flush=True)
        if self.restrained_atoms_dsl is not None:
            print(datetime.now().strftime("%m/%d/%Y %H:%M:%S") + '//' + 'Found Restraint Schedule', [np.round(T._value, 1) for T in self.spring_constants], spring_constant_unit, flush=True)
            print(datetime.now().strftime("%m/%d/%Y %H:%M:%S") + '//' + 'Found Spring Center Schedule', np.round([rmsd(self.spring_centers[i], self.spring_centers[0]) for i in range(n_replicates)], 2), 'nm', flush=True)
            

        # Loop through short 50 ns simulations to allow for .ncdf truncation
        self._configure_experiment_parameters(sim_length=sim_length)
        while self.sim_no < self.total_n_sims:
            
            # NOTE(review): sim_time, iter_length and dt are forwarded exactly as
            # received (plain numbers unless the caller passed quantities) -
            # confirm this matches what Randolph expects.
            params = dict(sim_no=self.sim_no,
                          sim_time=self.sim_time,
                          system=self.system,
                          ref_state=ref_state,
                          temperatures=self.temperatures,
                          spring_constants=self.spring_constants,
                          init_positions=self.init_positions,
                          init_box_vectors=self.init_box_vectors,
                          output_dir=output_dir,
                          output_ncdf=self.output_ncdf,
                          checkpoint_ncdf=checkpoint_ncdf,
                          iter_length=iteration_length,
                          dt=dt)
             
            # Initialize Randolph
            if self.sim_no > 0:
                self._load_initial_args() #sets positions, velocities, box_vecs, temperatures, and spring_constants
                params['init_velocities'] = self.init_velocities
                params['init_positions'] = self.init_positions
                params['init_box_vectors'] = self.init_box_vectors
                params['temperatures'] = self.temperatures
                if restrained_atoms_dsl is not None:
                    params['spring_constants'] = self.spring_constants

            elif hasattr(self, 'context'):
                params['context'] = self.context

            if self.restrained_atoms_dsl is not None:
                params['restrained_atoms_dsl'] = self.restrained_atoms_dsl
                params['mdtraj_topology'] = md.Topology.from_openmm(self.pdb.topology)
                params['spring_centers'] = self.spring_centers

                if self.init_box_vectors is not None and self.init_positions is not None:
                    params['init_positions'] = self.init_positions
                    params['init_box_vectors'] = self.init_box_vectors
            
            simulation = Randolph(**params)
            
            # Run simulation
            simulation.main(init_overlap_thresh=init_overlap_thresh, term_overlap_thresh=term_overlap_thresh)

            # Save simulation. save_simulation hands back a 3-tuple when restraints
            # are active and a bare temperature list otherwise; branch on what was
            # actually returned instead of guessing from our own settings.
            saved = simulation.save_simulation(self.save_dir)
            if isinstance(saved, tuple):
                self.temperatures, self.spring_constants, self.spring_centers = saved
            else:
                self.temperatures = saved
            
            # Delete output.ncdf files if not last simulation 
            if not self.sim_no+1 == self.total_n_sims:
                os.remove(self.output_ncdf)
                os.remove(checkpoint_ncdf)

            # Update counter
            self.sim_no += 1
    

    
    def _load_initial_args(self):
        """
        Load the starting arguments for the next simulation from the previous one.

        Reads temperatures (and, when restraints are in use, spring constants
        and spring centers) from the save directory of simulation
        ``self.sim_no - 1``. Initial positions, box vectors and velocities are
        then either resampled (when self.init_positions_dcd is set) or taken
        from the saved arrays, reordered so that index i holds the replica that
        was in state i. If the saved arrays cannot be read, the values are
        recovered from the output .ncdf file instead.

        Sets self.temperatures, self.init_positions, self.init_box_vectors and
        self.init_velocities (None when no velocities are available, in which
        case self.context is also reset to None).
        """
        # Get last directory
        load_no = self.sim_no - 1
        load_dir = os.path.join(self.save_dir, str(load_no))
        
        # Load args (not in correct shapes)
        self.temperatures = [t*unit.kelvin for t in np.load(os.path.join(load_dir, 'temperatures.npy'))]
        if self.spring_constants is not None:
            self.spring_constants = np.load(os.path.join(load_dir, 'spring_constants.npy'))
            self.spring_constants = [s*spring_constant_unit for s in self.spring_constants]
            self.spring_centers = np.load(os.path.join(load_dir, 'spring_centers.npy'))
        
        if self.init_positions_dcd is not None:
            self._resample_init_positions()
            return

        try:
            init_positions = np.load(os.path.join(load_dir, 'positions.npy'))[-1]
            init_box_vectors = np.load(os.path.join(load_dir, 'box_vectors.npy'))[-1]
            state_inds = np.load(os.path.join(load_dir, 'states.npy'))[-1]
        except Exception:
            # Saved arrays are missing or unreadable: fall back to the .ncdf file
            init_velocities, init_positions, init_box_vectors, state_inds = self._recover_arguments()
        else:
            # Velocities are optional; without them the run continues with none
            try:
                init_velocities = np.load(os.path.join(load_dir, 'velocities.npy'))
            except Exception:
                init_velocities = None
    
        # Reshape: put the replica occupying each state at that state's index
        reshaped_init_positions = np.empty((init_positions.shape))
        reshaped_init_box_vectors = np.empty((init_box_vectors.shape))
        reshaped_init_velocities = None
        if init_velocities is not None:
            reshaped_init_velocities = np.empty((init_velocities.shape))
        for state in range(len(self.temperatures)):
            rep_ind = np.where(state_inds == state)[0]
            reshaped_init_box_vectors[state] = init_box_vectors[rep_ind]
            reshaped_init_positions[state] = init_positions[rep_ind]
            if reshaped_init_velocities is not None:
                reshaped_init_velocities[state] = init_velocities[rep_ind]

        # Convert to quantities    
        self.init_positions = TrackedQuantity(unit.Quantity(value=np.ma.masked_array(data=reshaped_init_positions, mask=False, fill_value=1e+20), unit=unit.nanometer))
        self.init_box_vectors = TrackedQuantity(unit.Quantity(value=np.ma.masked_array(data=reshaped_init_box_vectors, mask=False, fill_value=1e+20), unit=unit.nanometer))
        if reshaped_init_velocities is not None:
            self.init_velocities = TrackedQuantity(unit.Quantity(value=np.ma.masked_array(data=reshaped_init_velocities, mask=False, fill_value=1e+20), unit=(unit.nanometer / unit.picosecond)))
        else:
            self.init_velocities = None
            self.context = None

    def _configure_experiment_parameters(self, sim_length=50):
        # Assert that no empty save directories have been made
        assert all([len(os.listdir(os.path.join(self.save_dir, dir))) >= 5 for dir in os.listdir(self.save_dir)]), "You may have an empty save directory, please remove empty or incomplete save directories before continuing :)"
        
        # Configure experiment parameters
        self.sim_no = len(os.listdir(self.save_dir))
        print(datetime.now().strftime("%m/%d/%Y %H:%M:%S") + '//' + 'Found n_sims_completed to be', self.sim_no, flush=True)
        self.sim_time = sim_length # ns
        self.total_n_sims = np.ceil(self.total_sim_time / self.sim_time)
        print(datetime.now().strftime("%m/%d/%Y %H:%M:%S") + '//' + 'Calculated total_n_sims to be', self.total_n_sims, flush=True)
    
    def _recover_arguments(self):
        """
        Read the last stored frame directly from the output .ncdf file.

        Returns:
        --------
            velocities, positions, box_vectors, state_inds:
                The final entry of the 'velocities', 'positions', 'box_vectors'
                and 'states' variables of self.output_ncdf.
        """
        # The context manager closes the dataset even if one of the reads fails
        with nc.Dataset(self.output_ncdf, 'r') as ncfile:
            velocities = ncfile.variables['velocities'][-1].data
            positions = ncfile.variables['positions'][-1].data
            box_vectors = ncfile.variables['box_vectors'][-1].data
            state_inds = ncfile.variables['states'][-1].data
        
        return velocities, positions, box_vectors, state_inds
    
    def _resample_init_positions(self):
        """
        Pick new initial positions and box vectors for every state by resampling.

        A FultonMarketAnalysis object is built on the parent directory of
        self.save_dir; for each state one frame is drawn with
        importance_resampling and its positions and box vectors are collected.

        Sets self.init_positions and self.init_box_vectors (one entry per
        state) and resets self.init_velocities to None.
        """
        #Load an analyzer
        input_dir = os.path.abspath(os.path.join(self.save_dir, '..'))
        #The analyzer will handle the loading of energies and any backfilling
        analyzer = FultonMarketAnalysis(input_dir, self.input_pdb)
        #Set the new intial positions and box vecs by resampling
        new_init_positions = []
        new_init_box_vecs = []
        # One sample per state. This object has no `simulation` attribute, so the
        # state count comes from the temperature list.
        for i in range(len(self.temperatures)):
            analyzer.importance_resampling(n_samples=1, equilibration_method='None', specify_state=i)
            #sets analyzer.resampled_inds and analyzer.weights
            traj = analyzer.write_resampled_traj('temp.pdb', 'temp.dcd', return_traj=True)
            #clean up from that line
            os.remove('temp.pdb')
            os.remove('temp.dcd')
            #Add positions and box vectors to the list
            new_init_positions.append(traj.openmm_positions(0))
            new_init_box_vecs.append(traj.openmm_boxes(0))
        
        self.init_positions = new_init_positions
        # run() reads init_box_vectors; the old attribute name is kept as an alias
        self.init_box_vectors = new_init_box_vecs
        self.init_box_vecs = new_init_box_vecs
        self.init_velocities = None
                    

### Randolph_Params

#### All Runs
sim_no - Which Randolph run (int)\
sim_time - Length of the Randolph run (unit.Quantity in units of nanoseconds)\
system - OpenMM System\
ref_state - Reference ThermodynamicState\
temperatures - List of temperatures (list of unit.Quantity in units of kelvin)\
init_positions - Initial positions\
init_box_vectors - Initial box vectors\
init_velocities - Initial velocities\
output_dir - Directory for output\
output_ncdf - NetCDF filename\
checkpoint_ncdf - Checkpoint NetCDF filename\
iter_length - Time between swaps\
dt - Step time\
context - OpenMM Context to load from if necessary

#### PTwRE Runs
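spring_constants - Spring constant for each state\
restrained_atoms_dsl - MDTraj selection string choosing the atoms to restrain\
mdtraj_topology - MDTraj topology built from the input pdb\
spring_centers - Restraint center positions for each state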

In [None]:
class Randolph():
    """
    One replica exchange simulation run.

    Builds a ParallelTemperingSampler (or a ReplicaExchangeSampler when
    restraints are in use), advances it cycle by cycle, inserts extra states
    between neighbours whose acceptance rate is below the threshold, and
    saves the results of the run to disk.
    """
    
    def __init__(self, sim_no: int, sim_time: unit.Quantity, system: openmm.System,
                 ref_state: ThermodynamicState, temperatures: np.array,
                 init_positions: np.array, init_box_vectors: np.array, 
                 output_dir: str, output_ncdf: str, checkpoint_ncdf: str,
                 iter_length: unit.Quantity, dt: unit.Quantity,
                 init_velocities=None, sampler_states=None,
                 context=None, 
                 spring_constants=None,
                 restrained_atom_indices=None,
                 spring_centers=None,
                 restrained_atoms_dsl=None,
                 mdtraj_topology=None):
        """
        Store the run settings, work out the step counts and build the sampler.

        sim_no - Integer reference to which run this is
        sim_time - length of the randolph run (aggregate over all replicates)
        system - OpenMM System to simulate
        ref_state - reference ThermodynamicState
        temperatures - one temperature per state
        init_positions / init_box_vectors / init_velocities - starting coordinates, box vectors and (optionally) velocities
        output_dir - directory for output
        output_ncdf / checkpoint_ncdf - paths of the NetCDF output and checkpoint files
        iter_length - time between swap attempts
        dt - timestep
        sampler_states - optional ready-made sampler state(s)
        context - optional OpenMM Context to take the starting state from
        spring_constants / spring_centers - one restraint strength / center per state
        restrained_atom_indices - indices of the atoms to restrain
        restrained_atoms_dsl - MDTraj selection string for the atoms to restrain
        mdtraj_topology - MDTraj topology the selection string refers to
        """
        # Assign attributes
        self.sim_no = sim_no
        self.sim_time = sim_time
        self.system = system
        self.output_dir = output_dir
        self.output_ncdf = output_ncdf
        self.checkpoint_ncdf = checkpoint_ncdf
        self.temperatures = temperatures.copy()
        self.ref_state = ref_state
        self.n_replicates = len(self.temperatures)
        self.init_positions = init_positions
        self.init_box_vectors = init_box_vectors
        self.init_velocities = init_velocities
        self.sampler_states = sampler_states
        self.iter_length = iter_length
        self.dt = dt
        self.context = context
        
        #Restraints if necessary
        # Callers may describe the restrained atoms by selection string instead of
        # by index; the rest of the class keys off restrained_atom_indices, so
        # derive the indices from the selection when only the latter is given.
        if (restrained_atom_indices is None
                and restrained_atoms_dsl is not None
                and mdtraj_topology is not None):
            restrained_atom_indices = mdtraj_topology.select(restrained_atoms_dsl)
        self.restrained_atom_indices = restrained_atom_indices
        self.restrained_atoms_dsl = restrained_atoms_dsl
        self.spring_constants = spring_constants
        self.mdtraj_topology = mdtraj_topology
        self.spring_centers = spring_centers
        
        # Configure simulation parameters
        self._configure_simulation_parameters()
        
        # Build simulation
        self._build_simulation()

    
    def main(self, init_overlap_thresh: float, term_overlap_thresh: float):
        """
        Run this simulation: minimize on the first run, then execute cycles.

        Parameters:
        -----------
            init_overlap_thresh (float):
                Acceptance rate threshold applied while sim_no is 0.

            term_overlap_thresh (float):
                Acceptance rate threshold applied on every later run.
        """
        # Keep both thresholds for the acceptance rate checks made each cycle
        self.init_overlap_thresh, self.term_overlap_thresh = init_overlap_thresh, term_overlap_thresh

        # Minimization only happens at the start of the first Randolph run
        is_first_run = self.sim_no == 0
        if is_first_run:
            print(datetime.now().strftime("%m/%d/%Y %H:%M:%S") + '//' + 'Minimizing...', flush=True)
            self.simulation.minimize()
            print(datetime.now().strftime("%m/%d/%Y %H:%M:%S") + '//' + 'Minimizing finished.', flush=True)

        # _run_cycle advances (or resets) self.current_cycle itself
        self.current_cycle = 0
        while True:
            if self.current_cycle > self.n_cycles:
                break
            self._run_cycle()
            
    
    def save_simulation(self, save_dir):
        """
        Save the important information from a simulation and then truncate the output.ncdf file to preserve disk space.

        Positions, velocities, box vectors, states, energies and temperatures
        are written as .npy files to ``save_dir/<sim_no>``; spring constants
        and spring centers are added when restraints are in use. The output
        and checkpoint .ncdf files are then replaced by their truncated copies.

        Returns:
        --------
            The temperatures as a list of kelvin quantities. When restraints
            are in use, a tuple (temperatures, spring_constants, spring_centers)
            is returned instead.
        """
        import shutil

        restrained = self.restrained_atom_indices is not None

        # Determine save no. 
        save_no_dir = os.path.join(save_dir, str(self.sim_no))
        os.makedirs(save_no_dir, exist_ok=True)

        # Truncate output.ncdf
        ncdf_copy = os.path.join(self.output_dir, 'output_copy.ncdf')
        pos, velos, box_vectors, states, energies, temperatures = truncate_ncdf(self.output_ncdf, ncdf_copy, self.reporter, False)
        np.save(os.path.join(save_no_dir, 'positions.npy'), pos.data)
        np.save(os.path.join(save_no_dir, 'velocities.npy'), velos.data)
        np.save(os.path.join(save_no_dir, 'box_vectors.npy'), box_vectors.data)
        np.save(os.path.join(save_no_dir, 'states.npy'), states.data)
        np.save(os.path.join(save_no_dir, 'energies.npy'), energies.data)
        np.save(os.path.join(save_no_dir, 'temperatures.npy'), temperatures)
        
        if restrained:
            spring_constants = np.array([np.round(t._value,2) for t in self.spring_constants])
            np.save(os.path.join(save_no_dir, 'spring_constants.npy'), spring_constants)
            np.save(os.path.join(save_no_dir, 'spring_centers.npy'), self.spring_centers)

        # Truncate output_checkpoint.ncdf
        checkpoint_copy = os.path.join(self.output_dir, 'output_checkpoint_copy.ncdf')
        truncate_ncdf(self.checkpoint_ncdf, checkpoint_copy, self.reporter, True)

        # Write over previous .ncdf files (no shell involved, so paths with
        # spaces or special characters are handled correctly)
        shutil.move(ncdf_copy, self.output_ncdf)
        shutil.move(checkpoint_copy, self.checkpoint_ncdf)

        # Close reporter object; a failure to close must not lose the saved run
        try:
            self.reporter.close()
        except Exception:
            pass
        
        if restrained:
            # self.spring_constants already holds quantities: take the bare value
            # before attaching the unit so the unit is not applied twice.
            return [t*unit.kelvin for t in temperatures], [k._value*spring_constant_unit for k in self.spring_constants], self.spring_centers
        else:
            return [t*unit.kelvin for t in temperatures]
        

    
    def _configure_simulation_parameters(self):
        """
        Derive step, iteration and cycle counts from the aggregate simulation time.

        The aggregate time is divided evenly over the states; from the time per
        state the number of steps, swap iterations and cycles follows. Sets
        self.n_replicates, self.n_steps_per_iter, self.n_iters, self.n_cycles
        and self.n_iters_per_cycle, and reports all of them with printf.
        """
        # The number of states may have changed since construction
        n_states = len(self.temperatures)
        self.n_replicates = n_states

        # Times -> steps -> iterations -> cycles
        time_per_state = self.sim_time / n_states
        steps_per_state = round(time_per_state / self.dt)
        self.n_steps_per_iter = round(self.iter_length / self.dt)
        self.n_iters = np.ceil(steps_per_state / self.n_steps_per_iter)
        self.n_cycles = np.ceil(self.n_iters / 5)
        self.n_iters_per_cycle = np.ceil(self.n_iters / self.n_cycles)

        # Report temperatures alongside the counts
        temps = [np.round(temperature._value,1) for temperature in self.temperatures]
        summary = (
            f'Calculated simulation per replicate to be: {np.round(time_per_state, 6)}',
            f'Calculated steps per replicate to be {steps_per_state} steps',
            f'Calculated steps per iteration to be {self.n_steps_per_iter} steps',
            f'Calculated number of iterations to be {self.n_iters} iterations',
            f'Calculated number of cycles to be {self.n_cycles} cycles',
            f'Calculated number of iters per cycle to be {self.n_iters_per_cycle} iterations',
            f'Calculated temperature of {self.n_replicates} replicates to be {temps} Kelvin',
        )
        for line in summary:
            printf(line)

        if self.restrained_atom_indices is None:
            return

        # Restraint schedule: strengths, and distance of each center from the first
        springs = [np.round(constant._value,1) for constant in self.spring_constants]
        printf(f'Calculated spring_constants of {self.n_replicates} replicates to be {springs} Joule / (mol*Ang^2)')
        center_offsets = [rmsd(self.spring_centers[state], self.spring_centers[0]) for state in range(self.n_replicates)]
        rmsds = np.round(center_offsets, 2)
        printf(f'Calculated spring_centers of {rmsds} nm')


    
    def _build_simulation(self):
        """
        Create the sampler, its reporter and the sampler states for this run.

        Without restraints a ParallelTemperingSampler is created over
        self.temperatures; with restraints a ReplicaExchangeSampler is created
        over one ThermodynamicState per temperature, each restrained with its
        own spring constant and spring center. An existing output .ncdf file is
        removed first.
        """
        # Set up integrator
        move = mcmc.LangevinDynamicsMove(timestep=self.dt, collision_rate=1.0 / unit.picosecond, n_steps=self.n_steps_per_iter, reassign_velocities=False)
        
        # Set up simulation
        if self.restrained_atom_indices is None:
            self.simulation = ParallelTemperingSampler(mcmc_moves=move, number_of_iterations=self.n_iters)
        else:
            self.simulation = ReplicaExchangeSampler(mcmc_moves=move, number_of_iterations=self.n_iters) #This is the case for PTwR and US
        self.simulation._global_citation_silence = True

        # Remove existing .ncdf files
        if os.path.exists(self.output_ncdf):
            printf(f'Removing {self.output_ncdf}')
            os.remove(self.output_ncdf)
        
        # Setup reporter
        atom_inds = tuple([i for i in range(self.system.getNumParticles())])
        self.reporter = MultiStateReporter(self.output_ncdf, checkpoint_interval=10, analysis_particle_indices=atom_inds)
        
        # Initialize sampler states if starting from scratch, otherwise they should be determined in interpolation or passed through from Fulton Market
        if self.init_velocities is not None:
            printf('Setting initial positions with the "Velocity" method')
            self.sampler_states = [SamplerState(positions=self.init_positions[i], box_vectors=self.init_box_vectors[i], velocities=self.init_velocities[i]) for i in range(self.n_replicates)]
            
        elif self.context is not None:
            printf('Setting initial positions with the "Context" method')
            self.sampler_states = SamplerState(positions=self.init_positions, box_vectors=self.init_box_vectors).from_context(self.context)
            
        else:
            printf('Setting initial positions with the "No Context" method')
            
            if self.sim_no > 0 or len(np.array(self.init_box_vectors).shape) == 3:
                printf(f'Setting initial positions individual to each state')
                self.sampler_states = [SamplerState(positions=self.init_positions[i], box_vectors=self.init_box_vectors[i]) for i in range(self.n_replicates)]
            elif getattr(self, 'sampler_states', None) is None:
                # A single starting structure was supplied: without this branch
                # self.sampler_states would be undefined when create() is called.
                self.sampler_states = SamplerState(positions=self.init_positions, box_vectors=self.init_box_vectors)
        
            
        if self.restrained_atom_indices is None:
            self.simulation.create(thermodynamic_state=self.ref_state, sampler_states=self.sampler_states,
                                   storage=self.reporter, temperatures=self.temperatures, n_temperatures=self.n_replicates)
        else:
            printf(f'Creating {len(self.temperatures)} Thermodynamic States')
            thermodynamic_states = [ThermodynamicState(system=self.system, temperature=T) for T in self.temperatures]
            printf('Done Creating Thermodynamic States')
            printf(f'Assigning {len(self.spring_constants)} Restraints')
            assert len(self.temperatures) == len(self.spring_constants)
            #In this case, iterate over the n_replicate spring_centers and assign different ones to each thermo_state
            for thermo_state, spring_cons, spring_center in zip(thermodynamic_states, self.spring_constants, self.spring_centers):
                restrain_atoms_by_dsl(thermo_state, self.mdtraj_topology, self.restrained_atoms_dsl, spring_cons, spring_center)
            
            self.simulation.create(thermodynamic_states=thermodynamic_states, sampler_states=self.sampler_states, storage=self.reporter)
        

    
    def _run_cycle(self):
        """
        Advance the sampler by one cycle and react to the acceptance rates.

        If any neighbouring pair of states mixes too poorly, new states are
        interpolated, the reporter is closed, the cycle counter is reset and
        the simulation is reconfigured and rebuilt. Otherwise the cycle
        counter is incremented.
        """
        # Take steps
        print(datetime.now().strftime("%m/%d/%Y %H:%M:%S") + '//' + 'CYCLE', self.current_cycle, 'advancing', self.n_iters_per_cycle, 'iterations', flush=True) 
        # A sampler that already reached its iteration limit has to be extended
        advance = self.simulation.extend if self.simulation.is_completed else self.simulation.run
        advance(self.n_iters_per_cycle)

        # Eval acceptance rates: the first run uses the initial threshold
        threshold = self.init_overlap_thresh if self.sim_no == 0 else self.term_overlap_thresh
        insert_inds = self._eval_acc_rates(threshold)

        # Nothing to insert: simply move on to the next cycle
        if not len(insert_inds) > 0:
            self.current_cycle += 1
            return

        # Interpolate and start over with the enlarged set of states
        self._interpolate_states(insert_inds)
        self.reporter.close()
        self.current_cycle = 0
        self._configure_simulation_parameters()
        self._build_simulation()
            
            
    def _eval_acc_rates(self, acceptance_rate_thresh: float=0.40):
        """
        Evaluate acceptance rates between neighbouring states.

        The accepted/proposed swap counts are read from the reporter, the first
        recorded entry is skipped, and the per-entry ratios are averaged.
        Ratios that are not a number (no proposed swaps) count as 0.

        Returns:
        --------
            np.array of indices; an index i means a new state is needed
            between state i-1 and state i.
        """
        # Get temperatures
        thermo_states = self.reporter.read_thermodynamic_states()[0]
        temperatures = [thermo_state.temperature._value for thermo_state in thermo_states]

        # Get mixing statistics
        accepted, proposed = self.reporter.read_mixing_statistics()
        ratios = accepted.data[1:] / proposed.data[1:]
        acc_rates = np.nan_to_num(np.mean(ratios, axis=0)) # Adjust for cases with 0 proposed swaps

        # Flag every neighbouring pair whose acceptance rate is too low
        insert_inds = [] # List of indices to apply new state. Ex: (a "1" means a new state between "0" and the previous "1" indiced state)
        n_pairs = len(acc_rates) - 1
        for lower in range(n_pairs):
            upper = lower + 1
            rate = acc_rates[lower, upper]
            print(datetime.now().strftime("%m/%d/%Y %H:%M:%S") + '//' + 'Mixing between', np.round(temperatures[lower], 2), 'and', np.round(temperatures[upper], 2), ':', rate, flush=True) 
            if rate < acceptance_rate_thresh:
                insert_inds.append(upper)

        return np.array(insert_inds)
        
        
    def _interpolate_states(self, insert_inds: np.array):
        """
        Insert new states between poorly mixing neighbours.

        For every index i in insert_inds a state is added between the previous
        states i-1 and i. Its temperature (and, when restraints are in use, its
        spring constant and spring center) is the mean of its two neighbours;
        its initial positions, box vectors and velocities are copied from the
        lower neighbour.

        Parameters:
        -----------
            insert_inds (np.array):
                Indices, in increasing order, at which to insert a new state.
        """
    
        # Add new states
        prev_temps = [s.temperature._value for s in self.reporter.read_thermodynamic_states()[0]]
        new_temps = list(prev_temps)
        # `displacement` accounts for the entries already inserted before `ind`
        for displacement, ind in enumerate(insert_inds):
            temp_below = prev_temps[ind-1]
            temp_above = prev_temps[ind]
            print(datetime.now().strftime("%m/%d/%Y %H:%M:%S") + '//' + 'Inserting state at', np.mean((temp_below, temp_above)), flush=True) 
            new_temps.insert(ind + displacement, np.mean((temp_below, temp_above)))
        self.temperatures = [temp*unit.kelvin for temp in new_temps]

        # Add new restraints if in PTwRE. The same flag as in the rest of the
        # class is used; the selection-string attribute is not guaranteed to exist.
        if self.restrained_atom_indices is not None:
            prev_spring_cons = [s._value for s in self.spring_constants]
            new_spring_cons = list(prev_spring_cons)
            for displacement, ind in enumerate(insert_inds):
                cons_below = prev_spring_cons[ind-1]
                cons_above = prev_spring_cons[ind]
                print(datetime.now().strftime("%m/%d/%Y %H:%M:%S") + '//' + 'Inserting state with Spring Constant', np.mean((cons_below, cons_above)), flush=True) 
                new_spring_cons.insert(ind + displacement, np.mean((cons_below, cons_above)))
            self.spring_constants = [cons * spring_constant_unit for cons in new_spring_cons]
            assert len(self.spring_constants) == len(self.temperatures)
            
            prev_spring_centers = self.spring_centers
            new_spring_centers = self.spring_centers
            for displacement, ind in enumerate(insert_inds):
                center_below = prev_spring_centers[ind - 1]
                center_above = prev_spring_centers[ind]
                print(datetime.now().strftime("%m/%d/%Y %H:%M:%S") + '//' + 'Inserting state with new Spring Center', flush=True)
                new_center = 0.5*(center_above + center_below)
                new_spring_centers = np.insert(new_spring_centers, ind + displacement, new_center, axis=0)
            self.spring_centers = new_spring_centers
            assert self.spring_centers.shape[0] == len(self.temperatures)
        
        self.n_replicates = len(self.temperatures)
        
        # Add pos, box_vecs, velos for new temperatures
        # (np.insert takes all indices relative to the original array, so no displacement here)
        self.init_positions = np.insert(self.init_positions, insert_inds, [self.init_positions[ind-1] for ind in insert_inds], axis=0)
        self.init_box_vectors = np.insert(self.init_box_vectors, insert_inds, [self.init_box_vectors[ind-1] for ind in insert_inds], axis=0)
        if self.init_velocities is not None:
            self.init_velocities = np.insert(self.init_velocities, insert_inds, [self.init_velocities[ind-1] for ind in insert_inds], axis=0)


        # Convert to quantities    
        if self.sim_no > 0:
            self.init_positions = TrackedQuantity(unit.Quantity(value=np.ma.masked_array(data=self.init_positions, mask=False, fill_value=1e+20), unit=unit.nanometer))
            self.init_box_vectors = TrackedQuantity(unit.Quantity(value=np.ma.masked_array(data=np.array(self.init_box_vectors).reshape(self.n_replicates, 3, 3), mask=False, fill_value=1e+20), unit=unit.nanometer))
            if self.init_velocities is not None:
                self.init_velocities = TrackedQuantity(unit.Quantity(value=np.ma.masked_array(data=self.init_velocities, mask=False, fill_value=1e+20), unit=(unit.nanometer / unit.picosecond)))

