In [None]:
import safep
import alchemlyb
from glob import glob
import numpy as np
import matplotlib.pyplot as plt
import os
from alchemlyb.parsing import namd
from IPython.display import display, Markdown

from alchemlyb.estimators import BAR
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

from matplotlib.lines import Line2D
from math import ceil

# User parameters

In [None]:
# Directory that holds the NAMD fepout files; the figure cells also save into it.
path='./Sample_Data/'
# Glob pattern, relative to `path`, selecting the fepout files to process.
filename='*.fepout'

temperature = 303.15  # handed to the fepout parser; presumably in kelvin
RT = 0.00198720650096 * temperature  # used to convert kT results to kcal/mol
decorrelate = True #Flag for decorrelation of samples
detectEQ = True #Flag for automated equilibrium detection

# os.path.join works whether or not `path` ends with a separator and avoids a
# doubled '/' in the matched file names; sorting makes the order reproducible.
fepoutFiles = sorted(glob(os.path.join(path, filename)))
if not fepoutFiles:
    # Make the empty case explicit instead of letting a later cell fail obscurely.
    warnings.warn(f"No files matching '{filename}' were found in '{path}'.")

totalSize = sum(os.path.getsize(file) for file in fepoutFiles)
print(f"Will process {len(fepoutFiles)} fepout files.\nTotal size:{np.round(totalSize/10**9, 2)}GB")

# Extract key features from the BAR fitting and get ΔG
Note: alchemlyb operates in units of kT by default. We multiply by RT to convert to units of kcal/mol.

In [None]:
fig, ax = plt.subplots()
# Parse the fepout files into alchemlyb's u_nk table at the chosen temperature.
u_nk = namd.extract_u_nk(fepoutFiles, temperature)
safep.plot_samples(ax, u_nk, color='blue', label='Raw Data')

# NOTE(review): only `detectEQ` is consulted here; the `decorrelate` flag set in
# the parameters cell is not used anywhere in the visible notebook - confirm intent.
if detectEQ:
    print("Detecting equilibrium")
    # From here on u_nk refers to the equilibrium-detected subset of the samples.
    u_nk = safep.detect_equilibrium_u_nk(u_nk)
    safep.plot_samples(ax, u_nk, color='orange', label='Equilibrium-Detected')

# Build the output name with os.path.join so the file lands inside `path`
# even when `path` was given without a trailing separator.
plt.savefig(os.path.join(path, "FEP_number_of_samples.pdf"))

In [None]:
# Run the BAR estimator on the fep data.
perWindow, cumulative = safep.do_estimation(u_nk)

# Forward/backward estimates and their errors; used later in the convergence plot.
convergence_series = safep.do_convergence(u_nk)
forward, forward_error, backward, backward_error = convergence_series

# Per-lambda convergence table; used later in the per-lambda convergence plot.
per_lambda_convergence = safep.do_per_lambda_convergence(u_nk)

In [None]:
# Take the final entry of the cumulative BAR estimate and of its error,
# multiply by RT (kT -> kcal/mol) and keep one decimal place.
bar_cumulative = cumulative.BAR
dG = np.round(bar_cumulative.f.iloc[-1] * RT, 1)
error = np.round(bar_cumulative.errors.iloc[-1] * RT, 1)

# Reused below as the title of the summary figure.
changeAndError = f'\u0394G = {dG}\u00B1{error} kcal/mol'
Markdown(f'<font size=5>{changeAndError}</font><br/>')

In [None]:
# Summary figure drawn by safep from the cumulative and per-window estimates;
# RT is passed so the plot can be scaled out of kT units.
fig, axes = safep.plot_general(cumulative, None, perWindow, None, RT)
fig.suptitle(changeAndError)
# os.path.join keeps the file inside `path` with or without a trailing separator.
plt.savefig(os.path.join(path, 'FEP_general_figures.pdf'))

# Plot the estimated total change in free energy as a function of simulation time; contiguous subsets starting at t=0 ("Forward") and t=end ("Reverse")

In [None]:
fig, convAx = plt.subplots(1,1)
# Forward and backward estimates (and their errors) are converted from kT to
# kcal/mol by multiplying with RT before plotting.
convAx = safep.convergence_plot(convAx, forward*RT, forward_error*RT, backward*RT, backward_error*RT)
# os.path.join keeps the file inside `path` with or without a trailing separator.
plt.savefig(os.path.join(path, 'FEP_convergence.pdf'))

In [None]:
fig, ax = plt.subplots()
# No yerr is supplied, so errorbar draws only the connected data points.
ax.errorbar(per_lambda_convergence.index, per_lambda_convergence.BAR.df*RT)
ax.set_xlabel(r"$\lambda$")
ax.set_ylabel(r"$D_{last-first}$ (kcal/mol)")
# os.path.join keeps the file inside `path` with or without a trailing separator.
plt.savefig(os.path.join(path, "FEP_perLambda_convergence.pdf"))

# Comparison energy distribution plots

In [None]:
from scipy import stats

def get_kde(data, bins):
    """Evaluate a Gaussian kernel density estimate of `data` at the points `bins`.

    Parameters
    ----------
    data : samples used to fit the estimate (passed to scipy's gaussian_kde).
    bins : points at which the fitted density is evaluated.

    Returns
    -------
    The density values at `bins`, as returned by scipy.
    """
    density = stats.gaussian_kde(data)
    return density.evaluate(bins)

def normalize(data):
    """Scale `data` so that its elements sum to one.

    No guard is applied for a zero sum; the division result is returned as is.
    """
    total = np.sum(data)
    return data / total

def plot(ax, fwd, bwd, bar, vlines=True, colors=plt.rcParams['axes.prop_cycle'].by_key()['color']):
    """Draw the forward and sign-flipped backward distributions and their overlap.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    fwd : forward energy differences (needs .min()/.max(), e.g. a Series or array).
    bwd : backward energy differences; plotted as ``-bwd`` so both curves share a sign.
    bar : x position of the third vertical line (only used when `vlines` is true).
    vlines : when true, add vertical lines at EXP(fwd), -EXP(bwd) and `bar`.
    colors : colours of the three vertical lines; the default is read from
        matplotlib's colour cycle once, when the function is defined.

    NOTE(review): EXP divides its input by RT, so the vertical lines assume that
    fwd/bwd are expressed in the same energy units as RT - confirm against callers.
    """
    rev = -1*bwd
    # Common evaluation grid covering both distributions and always including 0.
    lower = np.min([fwd.min(), rev.min(), 0])
    upper = np.max([fwd.max(), rev.max(), 0])
    bins = np.linspace(lower, upper, 200)

    # Kernel density estimates on the grid, each rescaled to sum to one.
    fYs = normalize(get_kde(fwd, bins))
    bYs = normalize(get_kde(rev, bins))

    ax.plot(bins, fYs)
    ax.plot(bins, bYs)

    # The overlap is the pointwise minimum of the two curves.
    overlap = np.min([fYs, bYs], axis=0)
    ax.fill_between(bins, overlap, color='gray', label="overlap")

    if vlines:
        # The exponential averages are only needed for the markers, so they are
        # computed here rather than unconditionally.
        f = EXP(fwd)
        b = -EXP(bwd)
        ax.axvline(f,linestyle='-',color=colors[0])
        ax.axvline(b,linestyle='-',color=colors[1])
        ax.axvline(bar,linestyle='-',color=colors[2])

def EXP(data):
    """Exponential average of `data`: ``-RT * ln(mean(exp(-data / RT)))``.

    Uses the module-level `RT`. The log-mean-exp is evaluated after subtracting
    the largest exponent, so large-magnitude inputs no longer overflow to inf.

    Parameters
    ----------
    data : array-like of energy differences, in the same units as RT.

    Returns
    -------
    The exponential-average estimate as a float (nan for empty input).
    """
    scaled = -np.asarray(data, dtype=float) / RT
    if scaled.size == 0:
        # Same outcome as the mean of an empty array in the naive formula.
        return np.nan
    shift = np.max(scaled)
    if not np.isfinite(shift):
        # Shifting by inf/nan would turn every term into nan; fall back to no shift.
        shift = 0.0
    # log(mean(exp(x))) == shift + log(mean(exp(x - shift)))
    return -RT * (shift + np.log(np.mean(np.exp(scaled - shift))))

In [None]:
# Reformat dataframe to access dE series

states = u_nk.columns.values.tolist()
# Group the samples by every index level after the first one.
groups = u_nk.groupby(level=u_nk.index.names[1:])

# Choose windows for which to plot overlap

overlap_plots = [0, len(states)-2]     # First and last windows
# overlap_plots = range(len(states)-1)   # All windows
# overlap_plots = [3,12]                 #Custom set of windows

# Number of columns of graph array
cols = 2
# NOTE(review): figsize is not passed to plt.subplots below, which sizes the
# figure from the grid shape instead.
figsize = (16,7)

for k in overlap_plots:
    # Window k pairs states[k] with states[k+1], so the largest valid index is
    # len(states) - 2; negative values would silently wrap around.
    assert 0 <= k <= len(states) - 2, ('Invalid window number: ' + str(k) +
                                       ' - valid numbers are 0 to ' + str(len(states) - 2))

nplots = len(overlap_plots)
rows = ceil(nplots / cols)

# squeeze=False always returns a 2-D array of Axes, so axes.flat below works for
# any grid shape (a single panel, one row, or several rows).
fig, axes = plt.subplots(rows,cols, figsize=(5*cols,5*rows), sharex='col', sharey=True, squeeze=False)

# plot() is called with vlines=False below, so no EXP/BAR lines are drawn; the
# legend therefore names what is on the panels: the two curves and the shading.
fig.legend([Line2D([0],[0],color='tab:blue'),Line2D([0],[0],color='tab:orange'),Line2D([0],[0],color='gray',linewidth=6)],
           ['forward','backward (sign-flipped)', 'overlap'])

for (k, ax) in zip(overlap_plots, axes.flat):
    # Forward: samples taken at state k, energy difference towards state k+1.
    uk = groups.get_group(states[k])
    w_f = uk.iloc[:, k+1] - uk.iloc[:, k]
    # Backward: samples taken at state k+1, energy difference towards state k.
    uk1 = groups.get_group(states[k+1])
    w_r = uk1.iloc[:, k] - uk1.iloc[:, k+1]

    plot(ax, w_f, w_r, perWindow['BAR']['df'].iloc[k], vlines=False)

    ax.set_ylabel("P(dE)")
    ax.set_title(str(states[k])+' ↔ '+str(states[k+1]))
    ax.set_xlabel("dE")

fig.tight_layout()

# NOTE(review): unlike the other figures, this one is written to the current
# working directory rather than into `path` - confirm that this is intended.
fig.savefig("energy_distributions.pdf")