# Basic FEP Processing

## Imports and constants

In [None]:
import safep
#import alchemlyb
from glob import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from alchemlyb.parsing import namd
from IPython.display import display, Markdown
from scipy.stats import gaussian_kde

# Molar gas constant, in J/(mol K).
from scipy.constants import R
# One thermochemical calorie expressed in joules.
from scipy.constants import calorie
# Joules per kilocalorie; dividing R by this gives R in kcal/(mol K).
kcal = calorie*1000

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

## Plotting function definitions (hidden by default)

In [None]:
# Plotting function definitions

def plot_general(cumulative, cumulativeYlim, perWindow, perWindowYlim, RT, width=8, height=4, PDFtype='KDE', fontsize=12):
    '''
    Draw the four-panel summary figure: cumulative free energy, per-window
    free energy, per-window hysteresis, and the distribution of the hysteresis.

    Arguments:
        cumulative: table indexed by lambda; reads cumulative.BAR.f and cumulative.BAR.errors
        cumulativeYlim: y-limits passed to the cumulative panel (None is accepted)
        perWindow: table indexed by lambda; reads perWindow.BAR.df, perWindow.BAR.ddf,
            perWindow.EXP.dG_f, perWindow.EXP.dG_b and perWindow.EXP['difference']
        perWindowYlim: y-limits passed to the per-window panel (None is accepted)
        RT: factor applied to the BAR and EXP values (and their errors) before
            plotting; the axis labels state kcal/mol
        width: figure width
        height: height of one row; the figure is three rows tall
        PDFtype: 'KDE' or 'Histogram', the estimator used for the distribution panel
        fontsize: font size of the axis labels
    Returns: the figure and the list [cumAx, eachAx, hystAx, pdfAx]
    Raises: ValueError if PDFtype is not 'KDE' or 'Histogram'
    '''
    # Fail before any figure is created; raising a bare string is itself a TypeError.
    if PDFtype not in ('KDE', 'Histogram'):
        raise ValueError(f"Error: PDFtype {PDFtype} not recognized")

    fig, axes = plt.subplots(3, 2, sharex='col', sharey='row', gridspec_kw={'width_ratios': [2, 1]})
    ((cumAx, del1), (eachAx, del2), (hystAx, pdfAx)) = axes

    # Only the bottom row uses its right-hand panel.
    fig.delaxes(del1)
    fig.delaxes(del2)

    # Cumulative change in kcal/mol. The error bars get the same RT scaling as the values.
    cumAx.errorbar(cumulative.index, cumulative.BAR.f*RT, yerr=cumulative.BAR.errors*RT, marker=None, linewidth=1)
    cumAx.set(ylabel=r'Cumulative $\mathrm{\Delta} G_{\lambda}$'+'\n(kcal/mol)', ylim=cumulativeYlim)

    # Per-window change in kcal/mol: BAR with error bars, plus forward and
    # (sign-flipped) backward EXP estimates drawn semi-transparent.
    eachAx.errorbar(perWindow.index, perWindow.BAR.df*RT, yerr=perWindow.BAR.ddf*RT, marker=None, linewidth=1)
    eachAx.plot(perWindow.index, perWindow.EXP.dG_f*RT, marker=None, linewidth=1, alpha=0.5)
    eachAx.errorbar(perWindow.index, -perWindow.EXP.dG_b*RT, marker=None, linewidth=1, alpha=0.5)
    eachAx.set(ylabel=r'$\mathrm{\Delta} G_\lambda$'+'\n'+r'$\left(kcal/mol\right)$', ylim=perWindowYlim)

    # Hysteresis plots
    # NOTE(review): 'difference' is plotted unscaled although the label says
    # kcal/mol and the other EXP columns are multiplied by RT above - confirm
    # the units of this column.
    diff = perWindow.EXP['difference']
    hystAx.vlines(perWindow.index, np.zeros(len(perWindow)), diff, label="fwd - bwd", linewidth=2)
    hystAx.set(ylabel=r'$\delta_\lambda$ (kcal/mol)', ylim=(-1,1))
    hystAx.set_xlabel(xlabel=r'$\lambda$', fontsize=fontsize)

    # Distribution of the hysteresis, drawn sideways so it shares the y axis
    # with the hysteresis panel.
    if PDFtype == 'KDE':
        kernel = gaussian_kde(diff)
        pdfX = np.linspace(-1, 1, 1000)
        pdfY = kernel(pdfX)
        pdfAx.plot(pdfY, pdfX, label='KDE')
    else:
        pdfY, binEdges = np.histogram(diff, density=True)
        # Plot each density value at the centre of its bin.
        pdfX = (binEdges[:-1] + binEdges[1:]) / 2
        pdfAx.plot(pdfY, pdfX, label="Estimated Distribution")

    pdfAx.set_xlabel(PDFtype, fontsize=fontsize)

    # Summary statistics: the mode is the location of the highest estimated density.
    std = np.std(diff)
    mean = np.average(diff)
    mode = pd.Series(pdfY, index=pdfX).idxmax()

    textstr = r"$\rm mode=$"+f"{np.round(mode,2)}"+"\n"+fr"$\mu$={np.round(mean,2)}"+"\n"+fr"$\sigma$={np.round(std,2)}"
    props = dict(boxstyle='square', facecolor='white', alpha=0.5)
    pdfAx.text(0.15, 0.95, textstr, transform=pdfAx.transAxes, fontsize=14,
            verticalalignment='top', bbox=props)

    panels = [cumAx, eachAx, hystAx, pdfAx]
    # Apply the label font size before tight_layout so the layout accounts for it.
    for ax in panels:
        ax.set_ylabel(ax.get_ylabel(), fontsize=fontsize)

    fig.set_figwidth(width)
    fig.set_figheight(height*3)
    fig.tight_layout()

    return fig, panels

def convergence_plot(theax, fs, ferr, bs, berr, fwdColor='#0072B2', bwdColor='#D55E00', lgndF=None, lgndB=None, fontsize=12):
    '''
    Convergence plot. Draws precomputed forward- and backward-sampled estimates
    against the fraction of simulation time; no estimation is done here.

    Arguments:
        theax: the axes to draw on
        fs, ferr: forward estimates and their errors, one per evenly spaced
            fraction of the data; the last entry is used as the final estimate
        bs, berr: backward estimates and their errors, laid out the same way
        fwdColor, bwdColor: colors of the forward and backward series
        lgndF, lgndB: colors of the legend entries; each one defaults to the
            color of its own series
        fontsize: font size of the axis labels
    Values are plotted as given; the y label states kcal/mol.
    Returns: theax
    '''
    # Default each legend color on its own so supplying only one of them works.
    if not lgndF:
        lgndF = fwdColor
    if not lgndB:
        lgndB = bwdColor

    # Accept lists or pandas objects as well as arrays.
    fs, ferr = np.asarray(fs), np.asarray(ferr)
    bs, berr = np.asarray(bs), np.asarray(berr)

    # Point i sits at fraction (i+1)/n, so the final estimate is at 1 for any
    # number of points (a fixed +0.1 offset is only right for n == 10).
    nPoints = len(fs)
    fwdX = (np.arange(nPoints) + 1) / nPoints
    bwdX = (np.arange(len(bs)) + 1) / nPoints

    # Shaded band: final forward estimate plus/minus its error.
    finalMean = fs[-1]
    lower = finalMean - ferr[-1]
    upper = finalMean + ferr[-1]
    theax.fill_between([0,1],[lower, lower], [upper, upper], color=bwdColor, alpha=0.25)
    theax.errorbar(fwdX, fs, yerr=ferr, marker='o', linewidth=1, color=fwdColor, markerfacecolor='white', markeredgewidth=1, markeredgecolor=fwdColor, ms=5)
    theax.errorbar(bwdX, bs, yerr=berr, marker='o', linewidth=1, color=bwdColor, markerfacecolor='white', markeredgewidth=1, markeredgecolor=bwdColor, ms=5, linestyle='--')

    theax.xaxis.set_ticks([0, 0.2, 0.4, 0.6, 0.8, 1])

    # Reference line at the final estimate, with a fixed +/-0.75 window around it.
    theax.axhline(y= finalMean, linestyle='-.', color='gray')
    theax.set_ylim((finalMean-0.75, finalMean+0.75))

    # Single-point lines that exist only to supply the legend entries.
    theax.plot(0, finalMean, linewidth=1, color=lgndF, label='Forward Time Sampling')
    theax.plot(0, finalMean, linewidth=1, color=lgndB, linestyle='--', label='Backward Time Sampling')
    theax.set_xlabel('Fraction of Simulation Time', fontsize=fontsize)
    theax.set_ylabel(r'Total $\mathrm{\Delta} G$ (kcal/mol)', fontsize=fontsize)
    theax.legend()
    return theax

## User parameters

In [None]:
path='/path/to/data' #Directory holding the fepout files; figures are saved here too
filename='*.fepout' #Glob pattern matched inside that directory

# path is combined with file names by plain concatenation (for the glob below
# and for the saved figures), so make sure it ends with a separator. An empty
# path is left alone so it keeps meaning "current directory".
if path and not path.endswith(('/', '\\')):
    path += '/'

temperature = 303.15 #in Kelvin
RT = (R/kcal) * temperature #in kcal/mol

decorrelate = True #Flag for decorrelation of samples
detectEQ = True #Flag for automated equilibrium detection

fepoutFiles = glob(path+filename)

In [None]:
# Read the fepout files into u_nk, passing the temperature and the
# decorrelation / equilibrium-detection flags chosen in the user parameters.
# NOTE(review): the processing itself is defined by safep.read_and_process and
# is not visible here.
u_nk = safep.read_and_process(fepoutFiles, temperature, decorrelate, detectEQ)

## Do the BAR fitting and get ΔG
**Note: alchemlyb operates in units of kT by default. We multiply by RT to convert to units of kcal/mol.**

In [None]:
perWindow, cumulative = safep.do_estimation(u_nk) #Run the BAR estimator on the fep data
forward, forward_error, backward, backward_error = safep.do_convergence(u_nk) #Used later in the convergence plot
per_lambda_convergence = safep.do_per_lambda_convergence(u_nk) #Used later in the per-lambda convergence plot

In [None]:
# Total free energy change: the last entry of the cumulative BAR estimate,
# multiplied by RT and rounded to one decimal place.
dG = np.round(cumulative.BAR.f.iloc[-1]*RT, 1)
# Its uncertainty, scaled and rounded the same way.
error = np.round(cumulative.BAR.errors.iloc[-1]*RT, 1)

# \u0394 is the Greek capital delta and \u00B1 is the plus-minus sign.
changeAndError = f'\u0394G = {dG}\u00B1{error} kcal/mol'
# Last expression of the cell, so this Markdown object is the cell's displayed output.
Markdown('<font size=5>{}</font><br/>'.format(changeAndError))

## Generate Plots 
- General plots:
Cumulative and per-window delta G.
- Convergence plots: 
Plot the estimated total change in free energy as a function of simulation time, using contiguous subsets of the samples that either start at t=0 ("Forward") or end at t=end ("Backward").
- Per-window Convergence:
Difference between BAR estimates using first and last halves of the samples (with respect to simulation time)

In [None]:
# Summary figure; both y-limit arguments are passed as None.
fig, axes = plot_general(
    cumulative,
    cumulativeYlim=None,
    perWindow=perWindow,
    perWindowYlim=None,
    RT=RT,
)
# Title the figure with the final free energy and write it next to the data.
fig.suptitle(changeAndError)
fig.savefig(f'{path}FEP_general_figures.pdf')

In [None]:
# Single-panel figure for the forward/backward convergence curves.
fig, convAx = plt.subplots()
# Estimates and their errors are multiplied by RT before plotting.
convAx = convergence_plot(
    convAx,
    fs=forward*RT,
    ferr=forward_error*RT,
    bs=backward*RT,
    berr=backward_error*RT,
)
fig.savefig(f'{path}FEP_convergence.pdf')

In [None]:
# Per-lambda convergence: the BAR df column, multiplied by RT, against lambda.
fig, ax = plt.subplots()
ax.errorbar(per_lambda_convergence.index, per_lambda_convergence.BAR.df*RT)
ax.set(xlabel=r"$\lambda$", ylabel=r"$D_{last-first}$ (kcal/mol)")
fig.savefig(f"{path}FEP_perLambda_convergence.pdf")