In [None]:
import safep
import alchemlyb
from glob import glob
import numpy as np
import matplotlib.pyplot as plt
import os
from alchemlyb.parsing import namd
from IPython.display import display, Markdown

from alchemlyb.estimators import BAR
import warnings
# Ignore every FutureWarning for the rest of the session (presumably to keep
# cell output readable). Note that this also hides FutureWarning-category
# deprecation notices raised by any library used below.
warnings.simplefilter(action='ignore', category=FutureWarning)

# User parameters

In [None]:
# Where to look for input: every file in `path` matching `filename` is used.
path = '.'
filename = '*.fepout'

# Thermodynamic settings. RT is the factor used further down to convert
# alchemlyb's kT-based results to kcal/mol.
temperature = 303.15
RT = 0.00198720650096 * temperature

# Optional preprocessing switches.
decorrelate = True  # decorrelate the samples before estimation
detectEQ = True     # run automated equilibrium detection first

# Collect the input files and report how much data is about to be read.
fepoutFiles = glob(path + '/' + filename)
totalSize = sum(os.path.getsize(fepout) for fepout in fepoutFiles)
print(f"Will process {len(fepoutFiles)} fepout files.\nTotal size:{np.round(totalSize/10**9, 2)}GB")

# Extract key features from the BAR fitting and get ΔG
Note: alchemlyb operates in units of kT by default. We multiply by RT to convert to units of kcal/mol.

In [None]:
# Parse the fepout files into a u_nk table and plot the sample set after each
# optional preprocessing stage (raw -> equilibrium-detected -> decorrelated).
fig, ax = plt.subplots()
u_nk = namd.extract_u_nk(fepoutFiles, temperature)
safep.plot_samples(ax, u_nk, color='blue', label='Raw Data')

if detectEQ:
    print("Detecting equilibrium")
    u_nk = safep.detect_equilibrium_u_nk(u_nk)
    safep.plot_samples(ax, u_nk, color='orange', label='Equilibrium-Detected')
if decorrelate:
    print("Decorrelating")
    u_nk = safep.decorrelate_u_nk(u_nk)
    safep.plot_samples(ax, u_nk, color='green', label='Decorrelated')

# Build the output path with os.path.join: concatenating `path` and the file
# name directly drops the separator (path='.' produced the hidden file
# '.FEP_number_of_samples.pdf' instead of './FEP_number_of_samples.pdf').
plt.savefig(os.path.join(path, "FEP_number_of_samples.pdf"))

In [None]:
# Run the BAR estimator on the FEP data: per-window and cumulative results.
(perWindow, cumulative) = safep.do_estimation(u_nk)

# Forward/backward estimates and their errors; used later in the convergence plot.
(forward, forward_error,
 backward, backward_error) = safep.do_convergence(u_nk)

# Per-lambda convergence table; plotted in a later cell.
per_lambda_convergence = safep.do_per_lambda_convergence(u_nk)

In [None]:
# Total free-energy change and its error: the last entry of the cumulative BAR
# columns, scaled by RT (kT -> kcal/mol) and rounded to one decimal place.
dG = np.round(RT * cumulative['BAR']['f'].iloc[-1], 1)
error = np.round(RT * cumulative['BAR']['errors'].iloc[-1], 1)

# Human-readable summary; reused below as the title of the general figure.
changeAndError = f'\u0394G = {dG}\u00B1{error} kcal/mol'

# Last expression of the cell, so the notebook renders it as rich Markdown.
Markdown('<font size=5>{}</font><br/>'.format(changeAndError))

In [None]:
# Overview figure built by safep from the cumulative and per-window BAR results,
# with RT passed for conversion to kcal/mol; titled with the ΔG summary string.
fig, axes = safep.plot_general(cumulative, None, perWindow, None, RT)
fig.suptitle(changeAndError)

# os.path.join supplies the separator that plain concatenation dropped
# (path='.' used to produce the hidden file '.FEP_general_figures.pdf').
plt.savefig(os.path.join(path, 'FEP_general_figures.pdf'))

# Plot the estimated total change in free energy as a function of simulation time; contiguous subsets starting at t=0 ("Forward") and t=end ("Reverse")

In [None]:
# Forward/backward convergence of the total free-energy estimate, with every
# series scaled by RT (kT -> kcal/mol) before plotting.
fig, convAx = plt.subplots(1,1)
convAx = safep.convergence_plot(convAx, forward*RT, forward_error*RT, backward*RT, backward_error*RT)

# os.path.join supplies the separator that plain concatenation dropped
# (path='.' used to produce the hidden file '.FEP_convergence.pdf').
plt.savefig(os.path.join(path, 'FEP_convergence.pdf'))

In [None]:
# Per-lambda convergence: the BAR 'df' column (scaled by RT) against the table index.
fig, ax = plt.subplots()
# NOTE(review): no yerr is passed, so errorbar draws a plain line without error
# bars - confirm whether an uncertainty column should be supplied here.
ax.errorbar(per_lambda_convergence.index, per_lambda_convergence.BAR.df*RT)
ax.set_xlabel(r"$\lambda$")
ax.set_ylabel(r"$D_{last-first}$ (kcal/mol)")

# os.path.join supplies the separator that plain concatenation dropped
# (path='.' used to produce the hidden file '.FEP_perLambda_convergence.pdf').
plt.savefig(os.path.join(path, "FEP_perLambda_convergence.pdf"))

# Comparison energy distribution plots

In [None]:
import seaborn as sns

def plot(ax, fwd, bwd, bar, vlines=True, colors=None):
    """Draw forward and sign-flipped backward energy-difference densities on ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes that receives both KDE curves and the optional vertical lines.
    fwd : array-like
        Forward energy differences; plotted as-is and passed to ``EXP``.
    bwd : array-like
        Backward energy differences; plotted as ``-1 * bwd`` and passed to ``EXP``.
    bar : float
        x position of the third vertical line (the caller passes a BAR estimate).
    vlines : bool, default True
        When True, draw vertical lines at ``EXP(fwd)``, ``-EXP(bwd)`` and ``bar``.
    colors : sequence of colours, optional
        At least three colours for the vertical lines. Defaults to matplotlib's
        current property-cycle colours, so the function does not depend on a
        notebook-level ``colors`` variable that a later cell happens to define.

    Returns
    -------
    None
    """
    sns.kdeplot(fwd, ax=ax)
    sns.kdeplot(-1 * bwd, ax=ax)
    if not vlines:
        return

    if colors is None:
        # Same lookup the plotting cell performs; done here so that calling
        # plot() never raises NameError on a missing global.
        colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

    # Forward EXP, backward EXP (sign flipped to match the plotted -bwd) and BAR.
    marks = (EXP(fwd), -EXP(bwd), bar)
    for position, color in zip(marks, colors[:3]):
        ax.axvline(position, linestyle='-', color=color)

def EXP(data, rt=None):
    """Exponential-averaging estimate ``-rt * ln(mean(exp(-data / rt)))``.

    The mean is evaluated with a max-shift (log-mean-exp), so large-magnitude
    negative values in ``data`` no longer overflow ``exp`` and yield ``-inf``.

    Parameters
    ----------
    data : array-like
        Energy differences, expressed in the same units as ``rt``.
    rt : float, optional
        Thermal energy used for scaling. Defaults to the notebook-level ``RT``.

    Returns
    -------
    float
        The estimate, or ``nan`` when ``data`` is empty.
    """
    if rt is None:
        rt = RT
    exponents = -np.asarray(data, dtype=float) / rt
    if exponents.size == 0:
        return np.nan
    shift = exponents.max()
    if not np.isfinite(shift):
        # inf/nan input: skip the shift so the result matches the direct formula.
        shift = 0.0
    return -rt * (shift + np.log(np.mean(np.exp(exponents - shift))))

In [None]:
# Reformat dataframe to access dE series

# Column labels of u_nk; used below as the keys passed to groups.get_group().
states = u_nk.columns.values.tolist()

# Group the samples by every index level after the first. When there is only
# one such level, group by its name rather than by a length-1 list: pandas has
# deprecated calling get_group() with a scalar key on a length-1 list grouper
# (the FutureWarning is silenced at the top of this notebook), and the loop
# below looks groups up with scalar keys.
lambda_levels = list(u_nk.index.names[1:])
groups = u_nk.groupby(level=lambda_levels[0] if len(lambda_levels) == 1 else lambda_levels)

In [None]:
# Windows for which the overlap is plotted; window k pairs states[k] with states[k+1].

overlap_plots = [0, len(states) - 2]       # first and last windows
# Alternative selections:
# overlap_plots = range(len(states)-1)   # all windows
# overlap_plots = [3,12]                 # custom set of windows

# Layout of the graph array: number of columns and overall figure size
cols = 2
figsize = (16, 7)

In [None]:
from matplotlib.lines import Line2D
from math import ceil

# Window k compares states[k] with states[k+1], so the last usable index is
# len(states) - 2. Negative indices are rejected too: they would silently wrap
# around to the other end of the state list.
for k in overlap_plots:
    assert 0 <= k <= len(states) - 2, ('Invalid window number: ' + str(k) +
                                       ' - valid numbers are 0 to ' + str(len(states) - 2))

nplots = len(overlap_plots)
rows = ceil(nplots / cols)
# Pad the end of the rectangular layout with blank spaces
layout_items = [str(i) for i in overlap_plots] + ['.'] * (rows * cols - nplots)
layout = [[layout_items[r * cols + c] for c in range(cols)] for r in range(rows)]

fig = plt.figure(figsize=figsize)
# Dict of Axes keyed by the window number as a string ('.' marks an empty slot).
gs = fig.subplot_mosaic(layout, gridspec_kw={'hspace' : 0.25})

prop_cycle = plt.rcParams['axes.prop_cycle']
colors = prop_cycle.by_key()['color']

# Take the legend colours from the same property cycle that plot() uses for its
# vertical lines, so the legend stays correct under a non-default style.
fig.legend([Line2D([0], [0], color=c) for c in colors[:3]],
           ['forward EXP', 'backward EXP', 'BAR'])

for k in overlap_plots:
    # Forward: samples grouped under states[k], difference between columns k+1 and k.
    uk = groups.get_group(states[k])
    w_f = uk.iloc[:, k+1] - uk.iloc[:, k]
    # Backward: samples grouped under states[k+1], difference between columns k and k+1.
    uk1 = groups.get_group(states[k+1])
    w_r = uk1.iloc[:, k] - uk1.iloc[:, k+1]

    # NOTE(review): w_f / w_r come straight from u_nk and the BAR df is not
    # multiplied by RT here, while EXP() divides its input by RT - confirm that
    # all three quantities drawn by plot() are in the same units.
    plot(gs[str(k)], w_f, w_r, perWindow['BAR']['df'].iloc[k])
    gs[str(k)].set(ylabel=None)
    gs[str(k)].set_title(str(states[k])+' ↔ '+str(states[k+1]))
    # gs[str(k)].set_title(str(states[k])+' ↔ '+str(states[k+1]), x=0.25, y=0.8)