# Boilerplate

In [None]:
import warnings
# Silence every Python warning for the rest of the session. This keeps cell
# output short, but it also hides deprecation/runtime warnings raised by the
# libraries imported below.
warnings.filterwarnings('ignore')

In [None]:
import os

from JupyterJoy.simtools.linker_system import GMXLinkerSystem, mindist
import JupyterJoy.pbash
from JupyterJoy.mdpbuild.mdp import MDP20183
from JupyterJoy.bqploteins import TrajPlotTime, EZFigure

import numpy as np

import mdtraj as md

import bqplot as bq

import pandas as pd

In [None]:
%%pbash 
source /store/opt/gromacs-2018.3-plumed-2.4.2/bin/GMXRC

In [None]:
# Root directory for all simulation output. The value is a placeholder:
# replace it with a real path before running anything below (cd() uses it as
# its default target).
nbdir = 'SET ME TO THE FOLDER YOU WANNA RUN SIMS IN'
def cd(target_dir=nbdir, mkdir=False):
    """CDs both python and the persistent Bash kernel

    Parameters
    ----------
    target_dir
        Directory to change into. Defaults to the value `nbdir` had when this
        function was defined.
    mkdir
        If true, run `mkdir -p` on `target_dir` through the pbash magic before
        changing directory.
    """
    if mkdir:
        %pbash mkdir -p $target_dir
    # Move the Python process and then the pbash shell, so that both end up
    # pointing at the same directory.
    os.chdir(target_dir)
    %pbash cd $target_dir
    
class WorkingDirectory():
    """Context manager that cds in and out on enter/exit

    On entry it first returns to the default directory of cd() and then
    changes into `target_dir`, creating it if needed. On exit it goes back to
    the default directory of cd(), not to wherever the caller was before.
    """
    def __init__(self, target_dir):
        # Only remember the target; nothing changes directory until __enter__.
        self.target_dir = target_dir
    def __enter__(self):
        # Start from the default directory so that a relative target_dir is
        # interpreted relative to it.
        cd()
        cd(self.target_dir, mkdir=True)
        # The value bound by "with ... as x" is Python's current directory.
        return os.getcwd()
    def __exit__(self, *args):
        # Exception details are ignored; returning None lets errors propagate.
        cd()
    
# Start the session in the default directory, in both Python and pbash.
cd()
# Presumably mirrors the Python value of nbdir into a shell variable of the
# same name inside the pbash session -- confirm against the pbash magic.
%pbash nbdir=$nbdir
# Show where the shell is and what is already there.
%pbash pwd
%pbash ls

# Configuration

In [None]:
# Template run parameters for every stage of the workflow. Later cells take a
# copy() of this object and override individual options (run length,
# barostat, temperature, velocity generation) instead of editing the text.
# The output-interval annotations were corrected: 5000 steps x 2 fs = 10 ps.
c22s_prod_mdp = MDP20183("""
; MDP file for the CHARMM22* force field
; Follows the CHARMM27 GROMACS implementation paper: https://dx.doi.org/10.1021/ct900549r

;    INTEGRATION - MD
integrator               = md
dt                       = 0.002 ; 2 fs step
nsteps                   = 100000000 ; 200 ns simulation
nstcomm                  = 1000 ; remove COM motion every 2 ps

;    CONSTRAINTS - MD
; Constraints on all bonds permit 2 fs time step
; LINCS is faster and more stable
; Use SHAKE if you need angle constraints
constraints              = all-bonds
constraint-algorithm     = LINCS
continuation             = no

;    OUTPUT CONTROL
; Strangely, this is where GROMACS sets all its output control
; mdrun switches don't change anything
nstxout                  = 0 ; only last ; Steps between writing coords to uncompressed output trajectory
nstvout                  = 0 ; only last ; Steps between writing velocities to uncompressed output trajectory
nstfout                  = 0 ; never ; Steps between writing forces to uncompressed output trajectory
nstlog                   = 5000 ; 10 ps ; Steps between writing energies to log file
nstenergy                = 5000 ; 10 ps ; Steps between writing energies to energy file
nstxout-compressed       = 5000 ; 10 ps ; Steps between writing coords to compressed output trajectory
compressed-x-precision   = 1000 ; Trajectory compression is lossy; this is the precision of that compression

;    CUTOFF SCHEME - verlet
; Verlet is faster, more parallelisable, more accurate, supports GPUs
; Only use group for the legacy interactions it supports
cutoff-scheme            = Verlet

;    COULOMB INTERACTIONS
coulombtype              = PME
rcoulomb                 = 1.2

;    LJ INTERACTIONS
vdwtype                  = Cut-off
rvdw                     = 1.2
vdw-modifier             = force-switch
rvdw-switch              = 1.0
dispcorr                 = no

;    TEMPERATURE COUPLING
tcoupl                   = V-rescale
tc-grps                  = Protein  non-Protein
tau-t                    = 1.0      1.0
ref-t                    = 300.00   300.00

;    PRESSURE COUPLING - production
; Parrinello-Rahman produces better pressure distribution,
; but is less stable and can oscillate if the box has to dramatically change size
pcoupl                  = Parrinello-Rahman
pcoupltype              = isotropic
tau-p                   = 12.0
compressibility         = 4.5e-5 4.5e-5
ref-p                   = 1.0 1.0

;    VELOCITY GENERATION
gen-vel                  = no
gen-temp                 = 300.00
gen-seed                 = -1 ; -1 uses a random seed
""")

In [None]:
# The peptide is `num_repeats` copies of the `repeat` motif, back to back.
repeat = 'NANP'
num_repeats = 3

# Gather the constructor arguments in one place so the settings of this run
# can be read (and tuned) at a glance.
system_settings = dict(
    sequence=repeat * num_repeats,
    ffpath='../../charmm22star_kcx.ff',
    name=f'{repeat}{num_repeats}',
    min_temp=300.0,
    max_temp=600.0,
    num_reps=8,
    exchange_freq=100,
)
system = GMXLinkerSystem(**system_settings)

# One-line summary of what is about to be simulated.
summary_parts = [
    "Will simulate",
    system.name,
    "over the temperature ladder",
    system.ladder,
    "with the forcefield found at",
    system.ff_path,
]
print(*summary_parts)

In [None]:
# To load a previously constructed system
# system = GMXLinkerSystem.read(f"{system.name}_sys/{system.name}.pickle")

# Prepare Extended Structure and Topology

## Prep Structure

In [None]:
# Set the template parameters to the lowest ladder temperature before any
# simulation is set up from them.
c22s_prod_mdp.set_temperature(system.min_temp)
# Build the system step by step; the order of these calls matters.
# NOTE(review): what each step does is defined in GMXLinkerSystem, not here.
# The 1.2 passed to optimise_rdsq_box is presumably a distance in nm and
# conc=0.15 presumably a molar salt concentration -- confirm.
system.initialise()
system.optimise_rdsq_box(1.2)
system.solvate()
system.salt(conc=0.15)
system.em()
# Persist everything under "<name>_sys"; the commented-out
# GMXLinkerSystem.read(...) call earlier in the notebook reads a pickle from
# that same directory.
system.write_all(f"{system.name}_sys")

# NPT Equilibration

This equilibration needs to do two things. First, it needs to ensure we've got the number of water molecules right so that at 300 K the protein isn't interacting with itself. Second, we'll use it to set the box size of our NVT equilibration run.

In [None]:
# stepnum and deffnm identify this stage: the cells below work in the
# directory f'{stepnum}_{deffnm}' and pass deffnm to the simulation and
# analysis calls.
stepnum = 1
deffnm = 'npt_equil'

In [None]:
with WorkingDirectory(f'{stepnum}_{deffnm}'):
    # Work on a copy so the production template itself is not modified.
    mdp = c22s_prod_mdp.copy()
    # NOTE(review): presumably (time step in fs, run length in ns) -- confirm.
    mdp.set_time(2, 1)
    # Replace the template's Parrinello-Rahman barostat (tau-p = 12) with
    # Berendsen and a much shorter tau-p for this equilibration.
    mdp.pcoupl = 'berendsen'
    mdp.tau_p = 0.5
    # The template spells this option `gen-vel`; as with `tau_p` for `tau-p`,
    # the attribute uses an underscore. `genvel` is not a GROMACS option name,
    # so assigning it would not switch velocity generation on.
    mdp.gen_vel = 'yes'

    system.run_sim(mdp, deffnm)


## Analysis

In [None]:
# Load the trajectory and the energy table produced by this step.
traj, edr_df = system.get_properties(deffnm, cwd=f'{stepnum}_{deffnm}')
# List the columns of the energy table, each wrapped in double quotes so that
# names containing spaces stay readable.
quoted_columns = '"    "'.join(edr_df)
print(f'"{quoted_columns}"')

In [None]:
# Mean box volume; the first frame is dropped to let the box size equilibrate a bit.
volumes = edr_df['Volume']
mean_volume = np.mean(volumes[1:])

# Linear scan over every frame (including the first) for the volume closest
# to that mean. The strict "<" keeps the earliest frame when there is a tie.
best_match = (-1, float('inf'))
smallest_gap = abs(best_match[1] - mean_volume)
for idx, volume in enumerate(volumes):
    gap = abs(volume - mean_volume)
    if gap < smallest_gap:
        smallest_gap = gap
        best_match = (idx, volume)
print(f'The mean volume was {mean_volume:0.8} nm^3, the closest frame was at index {best_match[0]} with {best_match[1]:0.8} nm^3')

In [None]:
# Widget pairing the trajectory with the 'Min. PI dist' column of the energy
# table (presumably the minimum distance between periodic images -- confirm),
# used to check that the peptide never comes close to its own image.
w = TrajPlotTime(traj, edr_df['Min. PI dist'])
# Bare last expression so the notebook renders the widget.
w

# NVT Equilibration

First, we run one long simulation at the maximum ladder temperature (`max_temp`) and use it to generate starting structures for the ladder equilibration. Then we run a series of longish MD simulations along the whole temperature ladder, with exchange, to equilibrate each replica.

In [None]:
# Stage identifiers: the cells below work in f'{stepnum}_{deffnm}' and pass
# deffnm to the simulation and analysis calls.
stepnum = 2
deffnm = 'nvt_equil'

In [None]:
with WorkingDirectory(f'{stepnum}_{deffnm}'):
    # Index of the NPT frame whose volume was closest to the mean volume.
    min_idx = best_match[0]
    # Write that single frame to disk and make it the system's structure.
    reference_pdb = 'npt_mostavevol.pdb'
    traj[min_idx].save(reference_pdb)
    system.load_pdb(reference_pdb)

In [None]:
# Work on a copy so the production template itself is not modified.
mdp = c22s_prod_mdp.copy()
# NOTE(review): presumably (time step in fs, run length in ns) -- confirm.
mdp.set_time(2, 100)
# Constant volume at the top of the temperature ladder.
mdp.remove_pcouple()
mdp.set_temperature(system.max_temp)
# The template spells this option `gen-vel`; as with `tau_p` for `tau-p`, the
# attribute uses an underscore. `genvel` is not a GROMACS option name, so
# assigning it would not switch velocity generation on.
mdp.gen_vel = 'yes'

with WorkingDirectory(f'{stepnum}_{deffnm}'):
    system.run_sim(mdp, deffnm)

## Analysis

In [None]:
# Load the trajectory and the energy table produced by this step.
traj, edr_df = system.get_properties(deffnm, cwd=f'{stepnum}_{deffnm}')
# List the columns of the energy table, each wrapped in double quotes so that
# names containing spaces stay readable.
quoted_columns = '"    "'.join(edr_df)
print(f'"{quoted_columns}"')

In [None]:
# Restrict the RMSD to the alpha-carbon atoms of the topology.
calpha_atom_indices = traj.top.select_atom_indices('alpha')
# RMSD of every frame of this run against the structure held by the system
# object (mdtraj compares against frame 0 of the reference unless told
# otherwise).
rmsd = md.rmsd(traj, system.traj, atom_indices=calpha_atom_indices)
# NOTE(review): stride=100 presumably thins the frames shown by the widget --
# confirm against TrajPlotTime.
w = TrajPlotTime(traj, rmsd, stride=100)
# Bare last expression so the notebook renders the widget.
w

# NPT Ladder equilibration

In [None]:
# Stage identifiers: the cells below work in f'{stepnum}_{deffnm}' and pass
# deffnm to the simulation and analysis calls.
stepnum = 3
deffnm = "npt_ladder_equil"

In [None]:
# Work on a copy so the production template itself is not modified.
mdp = c22s_prod_mdp.copy()
# NOTE(review): presumably (time step in fs, run length in ns) -- confirm.
mdp.set_time(2, 1)
# Replace the template's Parrinello-Rahman barostat (tau-p = 12) with
# Berendsen and a much shorter tau-p for this equilibration.
mdp.pcoupl = 'berendsen'
mdp.tau_p = 0.5
# The template spells this option `gen-vel`; as with `tau_p` for `tau-p`, the
# attribute uses an underscore. `genvel` is not a GROMACS option name, so
# assigning it would not switch velocity generation on.
mdp.gen_vel = 'yes'

# Starting structures are taken from `traj`, i.e. the trajectory loaded by the
# preceding analysis cell, skipping its first 1000 ps.
starting_frames = system.take_starting_strucs(traj, mdp, skiptime_ps=1000)

# Kept in a module-level name because the run and post-processing cells that
# follow iterate over it.
temp_ladder = system.ladder

system.prep_rest2(deffnm, f'{stepnum}_{deffnm}', mdp, starting_frames)

In [None]:
# Run all replicas of the ladder equilibration in one MPI job, from inside the
# step directory.
with WorkingDirectory(f'{stepnum}_{deffnm}'):
    system.call_gmx(
        cmd='mdrun_mpi', 
        stdin='',
        # One MPI rank per replica.
        mpiranks=system.num_reps,
        deffnm=deffnm,
        v=True,
        # One sub-directory per ladder temperature, named with two decimals.
        multidir=(f'{t:.2f}' for t in temp_ladder),
        plumed='plumed.dat',
        hrex=True,
        # NOTE(review): production uses system.exchange_freq here; -1
        # presumably turns exchange attempts off for this stage -- confirm.
        replex=-1
    )

In [None]:
# Post-process each replica's compressed trajectory inside its own directory.
xtc_name = f'{deffnm}.xtc'
for t in temp_ladder:
    path = f'{stepnum}_{deffnm}/{t:.2f}'
    with WorkingDirectory(path):
        system.trajvis(xtc_name)

## Analysis

In [None]:
# Trajectory and energy table of every replica, keyed by the two-decimal
# temperature string that also names the replica's directory.
traj_dict, edr_dict = {}, {}
for t in system.ladder:
    tstr = f'{t:.2f}'
    path = f'{stepnum}_{deffnm}/{tstr}'
    replica_traj, replica_edr = system.get_properties(deffnm, cwd=path)
    traj_dict[tstr] = replica_traj
    edr_dict[tstr] = replica_edr

In [None]:
# Columns of the energy tables to plot against each other.
plot_x = 'Time'
plot_y = 'Pressure'

fig = EZFigure(label_x=plot_x, label_y=plot_y)

# One line per replica, titled with its temperature string. zip() stops at the
# shorter input, so at most len(colourscheme) replicas are drawn.
colourscheme = bq.colorschemes.CATEGORY20
stride = 1
for colour, (label, table) in zip(colourscheme, edr_dict.items()):
    x_values = table[plot_x][::stride]
    y_values = table[plot_y][::stride]
    fig.lines(title=label, x=x_values, y=y_values, colors=[colour])

fig.scale_y.reverse = True
fig

# REST2 Production

In [None]:
# MDTraj doesn't save velocities, so we want to continue with the actual previous file
# This template must be built BEFORE stepnum/deffnm are reassigned below: it
# bakes in the previous step's directory and file stem and leaves a single
# '{}' placeholder (written '{{}}' inside the f-string) for the replica's
# temperature directory.
prev_path_fstring = f'{stepnum}_{deffnm}/{{}}/{deffnm}.gro'

# From here on, stepnum/deffnm refer to the production stage.
stepnum = 4
deffnm = "prod"

In [None]:
# Production parameters: a copy of the template with the run length set and
# flagged as a continuation of the previous stage.
mdp = c22s_prod_mdp.copy()
mdp.set_time(2, 200)
mdp.continuation = 'yes'

# MDTraj doesn't save velocities, so each replica starts from the actual
# structure file written by the previous step.
temp_ladder = system.ladder
starting_frames = []
for t in temp_ladder:
    starting_frames.append(prev_path_fstring.format(f'{t:.2f}'))

system.prep_rest2(deffnm, f'{stepnum}_{deffnm}', mdp, starting_frames)

## Final checks

In [None]:
# Run the first nanosecond and then check everything's OK
with WorkingDirectory(f'{stepnum}_{deffnm}'):
    system.call_gmx(
        cmd='mdrun_mpi', 
        stdin='',
        # One MPI rank per replica.
        mpiranks=system.num_reps,
        deffnm=deffnm,
        v=True,
        # One sub-directory per ladder temperature, named with two decimals.
        multidir=(f'{t:.2f}' for t in temp_ladder),
        plumed='plumed.dat',
        hrex=True,
        replex=system.exchange_freq,
        # Step count for the short test run: 1000 divided by the time step.
        # NOTE(review): assumes mdp.dt is numeric and in ps, so this is
        # 1000 ps = 1 ns -- confirm.
        nsteps=int(1000 / mdp.dt)
    )

In [None]:
# Energy table of every replica after the short test run, keyed by the
# two-decimal temperature string that names the replica's directory.
#
# This cell used to call `edr_to_df`, which is never imported or defined in
# this notebook and therefore raised NameError on a fresh kernel. It now uses
# the same loader as the other analysis cells; get_properties returns
# (trajectory, energy table) and only the table is kept here.
# NOTE(review): if get_properties needs the output of system.trajvis, run
# that for each replica first, as the ladder equilibration stage does.
edr_dict = {}
for t in temp_ladder:
    tstr = f'{t:.2f}'
    path = f'{stepnum}_{deffnm}/{tstr}'
    _replica_traj, edr_dict[tstr] = system.get_properties(deffnm, cwd=path)

In [None]:
# Columns of the energy tables to plot against each other.
plot_x = 'Time'
plot_y = 'Pressure'

fig = EZFigure(label_x=plot_x, label_y=plot_y)

# One line per replica, titled with its temperature string. zip() stops at the
# shorter input, so at most len(colourscheme) replicas are drawn.
colourscheme = bq.colorschemes.CATEGORY20
stride = 1
for colour, (label, table) in zip(colourscheme, edr_dict.items()):
    x_values = table[plot_x][::stride]
    y_values = table[plot_y][::stride]
    fig.lines(title=label, x=x_values, y=y_values, colors=[colour])

fig.scale_y.reverse = True
fig

In [None]:
# Log file of the lowest-temperature replica.
tstr = f'{system.min_temp:.2f}'
logfile = f'{stepnum}_{deffnm}/{tstr}/{deffnm}.log'

# Reminders of what to look for in the excerpt printed below.
for reminder in (
    "Here's some key info from the log file",
    "Replica exchange interval should be a multiple of nstlist",
    "Replica exchange probabilities should be around 0.2-0.3",
    "Check the times are reasonable too",
    "\n--------------------------------------------------------\n",
):
    print(reminder)

# Lines starting with one of these are echoed individually.
single_line_prefixes = (
    'Changing nstlist',
    'Intra-simulation communication',
    'Replica exchange interval',
)
# Exact lines that switch continuous printing on ...
section_starts = (
    'Replica exchange statistics\n',
    '     R E A L   C Y C L E   A N D   T I M E   A C C O U N T I N G\n',
)
# ... and the exact line that switches it off again.
section_end = 'Repl                                Empirical Transition Matrix\n'

with open(logfile) as f:
    print_rest = False
    for line in f:
        # str.startswith accepts a tuple and matches any of its entries.
        if line.startswith(single_line_prefixes):
            print(line)

        if line in section_starts:
            print_rest = True

        if line == section_end:
            print_rest = False

        if print_rest:
            print(line[:-1])

## Production

In [None]:
# Full production run, launched from inside the step directory with the same
# settings as the one-nanosecond test but without an nsteps override.
with WorkingDirectory(f'{stepnum}_{deffnm}'):
    system.call_gmx(
        cmd='mdrun_mpi', 
        stdin='',
        # One MPI rank per replica.
        mpiranks=system.num_reps,
        deffnm=deffnm,
        v=True,
        # One sub-directory per ladder temperature, named with two decimals.
        multidir=(f'{t:.2f}' for t in temp_ladder),
        plumed='plumed.dat',
        hrex=True,
        replex=system.exchange_freq,
        # NOTE(review): presumably resumes from the checkpoint left by the
        # test run above -- confirm how call_gmx translates cpi=True.
        cpi=True
    )

In [None]:
# Post-process each replica's compressed trajectory inside its own directory.
xtc_name = f'{deffnm}.xtc'
for t in temp_ladder:
    path = f'{stepnum}_{deffnm}/{t:.2f}'
    with WorkingDirectory(path):
        system.trajvis(xtc_name)

## Analysis

In [None]:
# Trajectory and energy table of every replica, keyed by the two-decimal
# temperature string that also names the replica's directory.
traj_dict, edr_dict = {}, {}
for t in system.ladder:
    tstr = f'{t:.2f}'
    path = f'{stepnum}_{deffnm}/{tstr}'
    replica_traj, replica_edr = system.get_properties(deffnm, cwd=path)
    traj_dict[tstr] = replica_traj
    edr_dict[tstr] = replica_edr

In [None]:
# Columns of the energy tables to plot against each other.
plot_x = 'Time'
plot_y = 'Potential'

fig = EZFigure(label_x=plot_x, label_y=plot_y)

# One line per replica, titled with its temperature string. zip() stops at the
# shorter input, so at most len(colourscheme) replicas are drawn.
colourscheme = bq.colorschemes.CATEGORY20
stride = 1
for colour, (label, table) in zip(colourscheme, edr_dict.items()):
    x_values = table[plot_x][::stride]
    y_values = table[plot_y][::stride]
    fig.lines(title=label, x=x_values, y=y_values, colors=[colour])

fig.scale_y.reverse = True
fig

In [None]:
# Log file of the lowest-temperature replica.
tstr = f'{system.min_temp:.2f}'
logfile = f'{stepnum}_{deffnm}/{tstr}/{deffnm}.log'
with open(logfile) as f:
    print_rest = False
    for line in f:
        # Latch on at the first line containing the banner, then print every
        # line from there to the end of the file (minus its last character).
        print_rest = print_rest or "<======  ###############  ==>" in line
        if print_rest:
            print(line[:-1])