In [None]:
import os
import warnings
from dataclasses import dataclass
from glob import glob
from pathlib import Path
from typing import Optional

import alchemlyb
import matplotlib.pyplot as plt
import numpy as np
import safep
import scipy as sp
from alchemlyb.estimators import BAR
from alchemlyb.parsing import namd
from IPython.display import display, Markdown
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
import pandas as pd
@dataclass
class FepRun:
    """Container for the data and derived results of one FEP replica.

    Only ``u_nk`` is required. The analysis fields start out as ``None`` and
    are filled in by the estimation cell, so a record can be created as soon
    as the raw data has been read. The field order is unchanged, so existing
    positional construction keeps working.
    """
    # Table returned by safep.read_UNK / namd.extract_u_nk for this replica.
    u_nk:           pd.DataFrame
    # The two values returned by safep.do_estimation(u_nk).
    perWindow:      Optional[pd.DataFrame] = None
    cumulative:     Optional[pd.DataFrame] = None
    # The four values returned by safep.do_convergence(u_nk).
    # NOTE(review): annotated as DataFrame in the original; confirm the actual return types.
    forward:        Optional[pd.DataFrame] = None
    forward_error:  Optional[pd.DataFrame] = None
    backward:       Optional[pd.DataFrame] = None
    backward_error: Optional[pd.DataFrame] = None
    # Value returned by safep.do_per_lambda_convergence(u_nk).
    per_lambda_convergence: Optional[pd.DataFrame] = None
    # Matplotlib colour used for this replica in every figure. A default is
    # required here because the preceding fields now have defaults.
    color: str = 'blue'

# User parameters

In [None]:
# Directory that contains one sub-directory per replica.
dataroot = Path('.')
# Glob pattern (relative to dataroot) that selects the replica directories.
replica_pattern='Replica?'
# Path.glob returns a one-shot generator in arbitrary order. Materialise it as
# a sorted list so that cells iterating over it can be re-run and so that
# replicas are always processed in the same order.
replicas = sorted(dataroot.glob(replica_pattern))
# Glob pattern (relative to each replica directory) for the raw NAMD output.
filename_pattern='*.fepout'

# Simulation temperature, passed to the fepout parser.
# NOTE(review): presumably in kelvin -- confirm against the simulation setup.
temperature = 303.15
# Conversion factor from kT to kcal/mol (gas constant in kcal/(mol*K) times T).
RT = 0.00198720650096 * temperature
detectEQ = True #Flag for automated equilibrium detection

In [None]:
# Palette of matplotlib colour names; each replica is drawn in one of them.
colors = [
    'blue',
    'red',
    'green',
    'purple',
    'orange',
    'violet',
    'cyan',
]
# Iterator over the palette, for handing out one colour at a time.
itcolors = iter(colors)

# Extract key features from the BAR fitting and get ΔG
Note: alchemlyb operates in units of kT by default. We multiply by RT to convert to units of kcal/mol.

# Read and plot number of samples after detecting EQ

In [None]:
# Build one FepRun per replica, keyed by the replica path.
# For every replica, either reuse the cached table ('decorrelated.csv') or
# parse the raw fepout files, optionally trim to the equilibrated part, plot
# the number of samples, and write the cache for the next run.
fepruns = {}
for replica in replicas:
    print(f"Reading {replica}")
    unkpath = replica.joinpath('decorrelated.csv')
    u_nk = None
    if unkpath.is_file():
        print("Found existing dataframe. Reading.")
        u_nk = safep.read_UNK(unkpath)
    else:
        print(f"Didn't find existing dataframe at {unkpath}. Checking for raw fepout files.")
        # Sorted so the list handed to the parser is reproducible across runs.
        fepoutFiles = sorted(replica.glob(filename_pattern))
        totalSize = sum(os.path.getsize(file) for file in fepoutFiles)
        print(f"Will process {len(fepoutFiles)} fepout files.\nTotal size:{np.round(totalSize/10**9, 2)}GB")

        if fepoutFiles:
            print("Reading fepout files")
            fig, ax = plt.subplots()

            u_nk = namd.extract_u_nk(fepoutFiles, temperature)
            u_nk = u_nk.sort_index(axis=0, level=1).sort_index(axis=1)
            safep.plot_samples(ax, u_nk, color='blue', label='Raw Data')

            if detectEQ:
                print("Detecting equilibrium")
                u_nk = safep.detect_equilibrium_u_nk(u_nk)
                safep.plot_samples(ax, u_nk, color='orange', label='Equilibrium-Detected')

            # Save next to the replica directory. with_name keeps the same
            # location as before for relative paths and also works when
            # dataroot is absolute (the old "./<path>" prefix did not).
            plt.savefig(replica.with_name(f"{replica.name}_FEP_number_of_samples.pdf"))
            plt.show()
            # One figure is created per replica; release it once it is saved.
            plt.close(fig)
            safep.save_UNK(u_nk, unkpath)
        else:
            print(f"WARNING: no fepout files found for {replica}. Skipping.")

    if u_nk is not None:
        # Pick the colour by position instead of consuming a shared iterator:
        # this wraps around when there are more replicas than colours and
        # gives the same assignment every time the cell is re-run.
        color = colors[len(fepruns) % len(colors)]
        fepruns[str(replica)] = FepRun(u_nk, None, None, None, None, None, None, None, color)
        

In [None]:
# Run every analysis once per replica and store the results on its FepRun record.
for feprun in fepruns.values():
    samples = feprun.u_nk

    # BAR estimator on the FEP data.
    estimation = safep.do_estimation(samples)
    feprun.perWindow, feprun.cumulative = estimation

    # Forward/backward estimates, used later in the convergence plot.
    convergence = safep.do_convergence(samples)
    (feprun.forward,
     feprun.forward_error,
     feprun.backward,
     feprun.backward_error) = convergence

    feprun.per_lambda_convergence = safep.do_per_lambda_convergence(samples)



# Plot data

In [None]:
# Summarise the total free-energy change of every replica and of the ensemble.
# Estimator output is in kT; multiplying by RT converts it to kcal/mol.
if not fepruns:
    # Fail here with a clear message rather than printing "nan" and crashing
    # in a later plotting cell with an unrelated error.
    raise ValueError("No replicas were loaded, so there is nothing to summarise. "
                     "Check dataroot, replica_pattern and filename_pattern.")

toprint = ""
dGs = []
errors = []
for key, feprun in fepruns.items():
    cumulative = feprun.cumulative
    # The last row of the cumulative table holds the total over all windows.
    dG = np.round(cumulative.BAR.f.iloc[-1]*RT, 1)
    error = np.round(cumulative.BAR.errors.iloc[-1]*RT, 1)
    dGs.append(dG)
    errors.append(error)

    changeAndError = f'{key}: \u0394G = {dG}\u00B1{error} kcal/mol'
    toprint += '<font size=5>{}</font><br/>'.format(changeAndError)

toprint += '<font size=5>{}</font><br/>'.format('__________________')
mean = np.average(dGs)

# If there are only a few replicas, the estimator-reported (BAR) error will be
# more reliable than the spread between replicas, albeit underestimated.
if len(dGs)<3:
    # NOTE(review): this quadrature sum is the propagated error of the *sum* of
    # the replicas; the error of their mean would divide by len(dGs). Confirm
    # which one is intended before changing it.
    sterr = np.sqrt(np.sum(np.square(errors)))
else:
    # NOTE(review): this is the standard deviation between replicas (ddof=0),
    # not a standard error of the mean, despite the variable name.
    sterr = np.round(np.std(dGs),1)
# Show both numbers with the same single decimal used for the per-replica
# values; the variables themselves keep their original precision.
toprint += '<font size=5>{}</font><br/>'.format(f'mean: {mean:.1f}')
toprint += '<font size=5>{}</font><br/>'.format(f'sterr: {sterr:.1f}')
Markdown(toprint)

In [None]:
def do_agg_data(dataax, plotax, pdf_range=(-1, 1), pdf_points=1000):
    """Annotate ``plotax`` with summary statistics of every line in ``dataax``.

    The y-values of all lines drawn on ``dataax`` are pooled into one sample.
    Its mean, standard deviation and the mode of a Gaussian kernel density
    estimate are written into a text box on ``plotax``.

    Parameters
    ----------
    dataax : matplotlib Axes
        Axes whose ``lines`` provide the data (read through ``get_ydata``).
    plotax : matplotlib Axes
        Axes that receives the text box.
    pdf_range : tuple of two floats, optional
        Interval on which the density is evaluated when searching for the
        mode. The mode can only be found inside this interval.
    pdf_points : int, optional
        Number of evaluation points inside ``pdf_range``.

    Returns
    -------
    The ``plotax`` object that was passed in.

    Raises
    ------
    ValueError
        If ``dataax`` has no lines.
    """
    samples = [np.asarray(line.get_ydata()).ravel() for line in dataax.lines]
    if not samples:
        raise ValueError("dataax contains no lines to aggregate")
    # Concatenate instead of stacking so lines of different length are
    # supported; for equal-length lines the result is identical.
    flat = np.concatenate(samples)

    kernel = sp.stats.gaussian_kde(flat)
    pdfX = np.linspace(pdf_range[0], pdf_range[1], pdf_points)
    pdfY = kernel(pdfX)
    std = np.std(flat)
    mean = np.average(flat)
    # Grid position of the (first) density maximum.
    mode = pdfX[np.argmax(pdfY)]

    textstr = r"$\rm mode=$"+f"{np.round(mode,2)}"+"\n"+fr"$\mu$={np.round(mean,2)}"+"\n"+fr"$\sigma$={np.round(std,2)}"
    props = dict(boxstyle='square', facecolor='white', alpha=1)
    plotax.text(0.175, 0.95, textstr, transform=plotax.transAxes, fontsize=14,
            verticalalignment='top', bbox=props)

    return plotax

In [None]:
# Overlay every replica on one shared set of axes. The first call creates the
# figure; later calls receive it back through fig= / axes= and draw on top.
fig = None
for key, feprun in fepruns.items():
    if fig is None:
        fig, axes = safep.plot_general(feprun.cumulative, None, feprun.perWindow, None, RT, hysttype='lines', label=key, color=feprun.color)
        axes[1].legend()
    else:
        fig, axes = safep.plot_general(feprun.cumulative, None, feprun.perWindow, None, RT, fig=fig, axes=axes, hysttype='lines', label=key, color=feprun.color)

# hack to get aggregate data:
# pool the lines drawn on axes[2] and write their statistics onto axes[3].
axes[3] = do_agg_data(axes[2], axes[3])

# Format to one decimal (the precision of the per-replica values) so the
# title does not show a full-precision float such as -3.4333333333333336.
axes[0].set_title(f'{mean:.1f}'+r'$\pm$'+f'{sterr:.1f}'+' kcal/mol')
axes[0].legend()
plt.savefig(dataroot.joinpath('FEP_general_figures.pdf'))

# Plot the estimated total change in free energy as a function of simulation time, using contiguous subsets that start at t=0 ("Forward") or at t=end ("Backward")

In [None]:
# Forward/backward convergence of the total free energy, all replicas overlaid.
fig, convAx = plt.subplots(1,1)

for key, feprun in fepruns.items():
    # Stored estimates are in kT; multiply by RT to plot kcal/mol.
    convAx = safep.convergence_plot(convAx, 
                                    feprun.forward*RT, 
                                    feprun.forward_error*RT, 
                                    feprun.backward*RT,
                                    feprun.backward_error*RT,
                                    fwd_color=feprun.color,
                                    bwd_color=feprun.color,
                                    errorbars=False
                                    )
    # Drop any per-replica legend; a single generic legend is added below.
    # get_legend() returns None when the axes has no legend, so guard the call.
    legend = convAx.get_legend()
    if legend is not None:
        legend.remove()

# Empty proxy lines that only exist to explain the two line styles.
forward_line, = convAx.plot([],[],linestyle='-', color='black', label='Forward Time Sampling')
backward_line, = convAx.plot([],[],linestyle='--', color='black', label='Backward Time Sampling')
convAx.legend(handles=[forward_line, backward_line])
# Zoom to within 1 kcal/mol of the lowest and highest final estimates.
ymin = np.min(dGs)-1
ymax = np.max(dGs)+1
convAx.set_ylim((ymin,ymax))
plt.savefig(dataroot.joinpath('FEP_convergence.pdf'))

In [None]:
import scipy as sp
# Per-lambda convergence: the value for every lambda window on the left, and
# a kernel density estimate of those values on the right (shared y axis).
fig, (Dax, pdfAx) = plt.subplots(1,2, gridspec_kw={'width_ratios': [2, 1]}, sharey='row',  figsize=(10,5))

# Axis labels do not depend on the replica, so set them once.
Dax.set_xlabel(r"$\lambda$")
Dax.set_ylabel(r"$\delta_\mathrm{50\%}$ (kcal/mol)")

for key, feprun in fepruns.items():
    # Stored values are in kT; multiply by RT to plot kcal/mol.
    deltas = feprun.per_lambda_convergence.BAR.df*RT
    lambdas = feprun.per_lambda_convergence.index
    Dax.errorbar(lambdas, deltas, color=feprun.color, label=key)

    kernel = sp.stats.gaussian_kde(deltas)
    # Evaluate the density on +/-0.3 kcal/mol as before, but widen the grid
    # when the data reach further so the curve is not cut off inside the
    # shared y range.
    half_width = max(0.3, float(np.nanmax(np.abs(np.asarray(deltas, dtype=float)))))
    pdfX = np.linspace(-half_width, half_width, 1000)
    pdfY = kernel(pdfX)
    # Density on x and value on y, so the curve lines up with the left panel.
    pdfAx.plot(pdfY, pdfX, label='KDE', color=feprun.color)

Dax.legend()
pdfAx.set_xlabel("KDE")
plt.savefig(dataroot.joinpath('FEP_perLambda_convergence.pdf'))