In [None]:
import numpy as np
import os
from glob import glob
from natsort import natsorted
import re
from tqdm.auto import tqdm

from ipywidgets import interact, widgets

import matplotlib
from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation
import functools

matplotlib.rcParams.update({'figure.autolayout': True, 'toolbar':'None'})
%matplotlib widget

# Functions

## Load data

In [None]:
# data holds the string-image centers with shape (images, cvs, iterations).
# data[:, :, 0] is the initial path, read from each image's setup colvars .conf
# file; later slices are the centers reported in the NAMD log after each iteration.
def extract_data(workdir, outname="data.npy"):
    """Collect the image centers of a string-method run into one array.

    For every image ``i`` (the numerically-named directories in
    ``{workdir}/output``) this reads:

    * ``{workdir}/setup/image_{i}_colvars.conf`` -- every ``centers <number>``
      entry, giving the initial center of each CV;
    * ``{workdir}/output/{i}/job00.{i}.log`` -- each line containing
      ``Updating bias_swarms_cv<k>`` is followed by a line ``{ <value> }``
      holding the new center of CV ``k`` (1-based). Every ``n_cvs``
      consecutive updates form one iteration.

    Parameters
    ----------
    workdir : str
        Root directory of the string run.
    outname : str or None, optional
        If not None, the result is written there with ``np.save``. The path is
        used as given (a relative path resolves against the current working
        directory, not ``workdir``).

    Returns
    -------
    numpy.ndarray
        Float array of shape (images, cvs, iterations); slice ``[:, :, 0]``
        holds the initial centers.

    Raises
    ------
    FileNotFoundError
        If ``{workdir}/output`` contains no image directories.
    ValueError
        If an image's .conf file has no parsable centers, a center line cannot
        be parsed, or images disagree on the number of iterations/CVs.
    """
    # Accepts integers ("7"), decimals ("12.5") and exponents; the previous
    # pattern `\d+.?\d+` silently skipped single-digit centers such as "centers 7".
    center_re = re.compile(r"centers\s+(-?\d+(?:\.\d*)?(?:[eE][-+]?\d+)?)")
    # Only numerically-named directories are images; stray files are ignored.
    image_dirs = [p for p in glob(f"{workdir}/output/*")
                  if os.path.isdir(p) and os.path.basename(p).isdigit()]
    n_imgs = len(image_dirs)
    if n_imgs == 0:
        raise FileNotFoundError(f"no image directories found in {workdir}/output")

    per_image = []
    for im in tqdm(range(n_imgs), leave=False):
        with open(f"{workdir}/setup/image_{im}_colvars.conf", "r") as f:
            initial = np.array(center_re.findall(f.read()), dtype=float)
        n_cvs = len(initial)
        if n_cvs == 0:
            raise ValueError(f"no 'centers' entries found for image {im}")
        rows = [initial]
        log_path = f"{workdir}/output/{im}/job00.{im}.log"
        it_data = None
        with open(log_path, "r") as log:
            n_reads = 0
            for line in log:
                if "Updating bias_swarms_cv" not in line:
                    continue
                # 1-based CV index: the first integer on the "Updating" line.
                cv = int(re.search(r"\d+", line)[0])
                if n_reads % n_cvs == 0:
                    # First update of a new iteration: keep the finished one.
                    if it_data is not None:
                        rows.append(it_data)
                    it_data = np.zeros(n_cvs)
                # The new center is on the following line as "{ value }".
                match = re.search(r"\{\s+(.*)\s+\}", next(log, ""))
                if match is None:
                    raise ValueError(f"{log_path}: no '{{ value }}' line after {line.strip()!r}")
                it_data[cv - 1] = float(match.group(1))
                n_reads += 1
        if it_data is not None:
            rows.append(it_data)
        per_image.append(np.vstack(rows))  # (iterations, cvs)

    shapes = {a.shape for a in per_image}
    if len(shapes) != 1:
        raise ValueError(f"images disagree on (iterations, cvs): {sorted(shapes)}")
    # (images, iterations, cvs) -> (images, cvs, iterations); works for 1 image too.
    data = np.stack(per_image).transpose(0, 2, 1)
    if outname is not None:
        np.save(outname, data, allow_pickle=False)
    return data

In [None]:
# traj_data holds per-frame CV values from the drift trajectories, with shape
# (images, cvs, iterations, frames); read from NAMD colvars .traj files.
def extract_traj_data(workdir, outname="traj_data.npy", stride=1):
    """Load every image's drift colvars trajectories into one 4-D array.

    For each image ``i`` (the numerically-named directories in
    ``{workdir}/output``), all files matching
    ``{workdir}/output/{i}/*drift*colvars.traj`` are read in natural-sort
    order, one file per string iteration. Column 0 of each file is dropped
    (presumably the step counter -- TODO confirm) and only every
    ``stride``-th row is kept.

    Parameters
    ----------
    workdir : str
        Root directory of the string run.
    outname : str or None, optional
        If not None, the result is written there with ``np.save`` (path used
        as given, i.e. relative to the current working directory).
    stride : int, optional
        Keep every ``stride``-th frame.

    Returns
    -------
    numpy.ndarray
        Shape (images, cvs, iterations, frames).

    Raises
    ------
    FileNotFoundError
        If there are no image directories, or an image has no trajectory files.
    ValueError
        If files disagree on frame or CV count (raised by ``np.stack``).
    """
    # Only numerically-named directories are images; stray files are ignored.
    n_imgs = len([p for p in glob(f"{workdir}/output/*")
                  if os.path.isdir(p) and os.path.basename(p).isdigit()])
    if n_imgs == 0:
        raise FileNotFoundError(f"no image directories found in {workdir}/output")
    per_image = []
    for im in tqdm(range(n_imgs), position=0, leave=False):
        cv_traj_files = natsorted(glob(f"{workdir}/output/{im}/*drift*colvars.traj"))
        if not cv_traj_files:
            raise FileNotFoundError(f"no *drift*colvars.traj files for image {im} in {workdir}/output/{im}")
        # ndmin=2 keeps a single-row file 2-D so stacking stays regular.
        iterations = [np.loadtxt(path, ndmin=2)[::stride, 1:]
                      for path in tqdm(cv_traj_files, position=1, leave=False)]
        per_image.append(np.stack(iterations))  # (iterations, frames, cvs)
    # (images, iterations, frames, cvs) -> (images, cvs, iterations, frames);
    # unlike dstack this also works with a single iteration or image.
    traj_data = np.stack(per_image).transpose(0, 3, 1, 2)
    if outname is not None:
        np.save(outname, traj_data, allow_pickle=False)
    return traj_data

In [None]:
def get_data(string_dir=None,data_name="data.npy",traj_data_name="traj_data.npy",stride=1):
    """Load a string run's centers and trajectory arrays, extracting on first use.

    Each array is loaded from ``{string_dir}/{name}`` when that file exists;
    otherwise it is extracted from the raw run files and saved to exactly that
    path, so the next call hits the cache.

    Parameters
    ----------
    string_dir : str or None
        Root directory of the string run; None means the current directory.
    data_name : str
        File name (inside ``string_dir``) of the cached centers array.
    traj_data_name : str
        File name (inside ``string_dir``) of the cached trajectory array.
        The stride is not recorded in the cache, so encode it in this name
        (e.g. ``traj_data_s100.npy``) when using different strides.
    stride : int
        Frame stride used only when the trajectory array must be extracted.

    Returns
    -------
    tuple
        ``(data, traj_data)`` as produced by ``extract_data`` and
        ``extract_traj_data``.
    """
    if string_dir is None:
        string_dir = "."
    data_path = os.path.join(string_dir, data_name)
    traj_path = os.path.join(string_dir, traj_data_name)
    if os.path.exists(data_path):
        data = np.load(data_path)
    else:
        # Save next to the run (not in the CWD) so the existence check above hits next time.
        data = extract_data(string_dir, outname=data_path)
    if os.path.exists(traj_path):
        traj_data = np.load(traj_path)
    else:
        traj_data = extract_traj_data(string_dir, outname=traj_path, stride=stride)
    return data, traj_data

## Plotting

In [None]:
def render_interactive(data, traj_data=None, title="", ani_outname=None,ani_interval=20,ani_cv=1,
                       ylim=(3, 24)):
    '''
    Produce an interactive figure of the image coordinates, driven by an
    iteration slider and a CV slider (CV is 1-indexed).

    If only ``data`` (shape (images, cvs, iterations)) is provided:
        shows the string at the selected iteration, with every ``s_fade``-th
        earlier iteration drawn as a progressively fainter line.
    If ``traj_data`` (shape (images, cvs, iterations, frames)) is also provided:
        for each image, draws the per-frame CV values across that image's
        x-slot, the reference center (solid black) and the trajectory mean
        (dashed black).

    If ``ani_outname`` is given, an animated GIF covering every iteration is
    also written:
        ani_cv: CV shown in the GIF (1-indexed)
        ani_interval: milliseconds between frames (also sets the GIF frame rate)

    ylim: y-axis limits applied on every redraw.

    Returns the matplotlib Figure.
    '''
    s_fade = 20  # iterations between consecutive faded history lines
    f_fade = 0.80  # alpha multiplier per step back in history
    matplotlib.rcParams['figure.figsize'] = (15, 5)
    matplotlib.rcParams.update({'font.size': 20})

    fig, ax = plt.subplots(1, 1)
    fig.canvas.header_visible = False
    n_imgs = data.shape[0]
    x_vals = range(n_imgs)

    def _setup_axes(axes_title):
        # Reset the axes and apply the labels/limits shared by both views.
        ax.clear()
        ax.set_title(axes_title)
        ax.set_xlim(0, n_imgs - 1)
        ax.set_ylim(*ylim)
        ax.set_xlabel("Image")
        ax.set_ylabel("Distance Å")

    def _label(i, cv):
        # Corner annotation with the current slider values.
        ax.text(0.95,0.95,"CV = "+str(cv)+" It = "+str(i),transform=ax.transAxes, size=16, ha='right', va='top')

    def update_data(i, cv):
        _setup_axes(f"{title}")
        fade_lines = range(0, i, s_fade)
        n_fade = len(fade_lines)
        # Fraction of the way to the next fade step; interpolates alpha so the
        # history dims smoothly as i advances.
        phase = ((i - 1) % s_fade) / s_fade
        for k, f in enumerate(fade_lines):
            age = n_fade - k
            a = f_fade**age - (f_fade**age - f_fade**(age + 1)) * phase
            if a > 0.01:
                ax.plot(x_vals, data[:, cv-1, f], color='#1f77b4', alpha=a)
        ax.plot(x_vals, data[:, cv-1, i], color='#1f77b4', alpha=1)
        _label(i, cv)

    def update_traj_data(i, cv):
        _setup_axes(f"{title}\nZref (solid), Mean of trajectory (dashed)")
        # Thin long trajectories for drawing only; the mean below uses every frame.
        stride = 10 if traj_data.shape[3] > 100 else 1
        for im in range(n_imgs):
            x = [im-0.5, im-0.5+1]
            traj_y = traj_data[im, cv-1, i, ::stride]
            ax.plot(np.linspace(x[0], x[1], len(traj_y)), traj_y)

            zref = data[im, cv-1, i]
            ax.plot(x, [zref, zref], color="black")

            mean_y = np.mean(traj_data[im, cv-1, i, :])
            ax.plot(x, [mean_y, mean_y], color="black", ls='dashed')
        _label(i, cv)

    if traj_data is not None:
        update = update_traj_data
        max_i = traj_data.shape[2] - 1
    else:
        update = update_data
        max_i = data.shape[2] - 1

    i_slider = widgets.IntSlider(min=0,max=max_i,step=1,value=0,layout=widgets.Layout(width='500px'))
    cv_slider = widgets.IntSlider(min=1,max=data.shape[1],step=1,value=1,layout=widgets.Layout(width='500px'))
    interact(update, i=i_slider, cv=cv_slider)

    if ani_outname is not None:
        # max_i is the last valid iteration index, so frames must include it.
        ani = FuncAnimation(fig, functools.partial(update, cv=ani_cv),
                            frames=tqdm(range(max_i + 1), leave=False),
                            interval=ani_interval)
        # No explicit fps: save() then derives it from ani_interval; a fixed
        # fps would silently override the requested interval.
        ani.save(ani_outname, writer="imagemagick", dpi=100)
    return fig

In [None]:
# NOTE(review): this cell re-defines extract_traj_data, which an earlier cell
# already defines; being later, this definition is the one in effect. Keep a
# single copy (ideally the one in the data-loading section).
def extract_traj_data(workdir, outname="traj_data.npy", stride=1):
    """Load every image's drift colvars trajectories into one 4-D array.

    For each image ``i`` (the numerically-named directories in
    ``{workdir}/output``), all files matching
    ``{workdir}/output/{i}/*drift*colvars.traj`` are read in natural-sort
    order, one file per string iteration. Column 0 of each file is dropped
    (presumably the step counter -- TODO confirm) and only every
    ``stride``-th row is kept.

    Parameters
    ----------
    workdir : str
        Root directory of the string run.
    outname : str or None, optional
        If not None, the result is written there with ``np.save`` (path used
        as given, i.e. relative to the current working directory).
    stride : int, optional
        Keep every ``stride``-th frame.

    Returns
    -------
    numpy.ndarray
        Shape (images, cvs, iterations, frames).

    Raises
    ------
    FileNotFoundError
        If there are no image directories, or an image has no trajectory files.
    ValueError
        If files disagree on frame or CV count (raised by ``np.stack``).
    """
    # Only numerically-named directories are images; stray files are ignored.
    n_imgs = len([p for p in glob(f"{workdir}/output/*")
                  if os.path.isdir(p) and os.path.basename(p).isdigit()])
    if n_imgs == 0:
        raise FileNotFoundError(f"no image directories found in {workdir}/output")
    per_image = []
    for im in tqdm(range(n_imgs), position=0, leave=False):
        cv_traj_files = natsorted(glob(f"{workdir}/output/{im}/*drift*colvars.traj"))
        if not cv_traj_files:
            raise FileNotFoundError(f"no *drift*colvars.traj files for image {im} in {workdir}/output/{im}")
        # ndmin=2 keeps a single-row file 2-D so stacking stays regular.
        iterations = [np.loadtxt(path, ndmin=2)[::stride, 1:]
                      for path in tqdm(cv_traj_files, position=1, leave=False)]
        per_image.append(np.stack(iterations))  # (iterations, frames, cvs)
    # (images, iterations, frames, cvs) -> (images, cvs, iterations, frames);
    # unlike dstack this also works with a single iteration or image.
    traj_data = np.stack(per_image).transpose(0, 3, 1, 2)
    if outname is not None:
        np.save(outname, traj_data, allow_pickle=False)
    return traj_data

In [None]:
def pmf(pmf_dir=None,
        colvar_traj_glob="*drift*.colvars.traj",
        ref_centers_glob="*.restart.conf",
        use_last_it=False,
        force_constant=0.5):
    """Estimate a PMF along the string by integrating the mean restraint force.

    For each image ``i`` (the numerically-named directories in ``pmf_dir``):

    * ``Z[i]`` -- colvars trajectory (column 0 dropped) from the file matching
      ``colvar_traj_glob``;
    * ``Zref[i]`` -- centers parsed from ``centers <number>`` entries in the
      file matching ``ref_centers_glob``.

    The mean force is ``F = force_constant * mean_over_frames(Z - Zref)``, and
    the work between neighbouring images is ``-F * (Zref[i+1] - Zref[i])``.

    Parameters
    ----------
    pmf_dir : str
        Directory holding one sub-directory per image.
    colvar_traj_glob, ref_centers_glob : str
        Glob patterns evaluated inside each image directory.
    use_last_it : bool
        If False, each glob must match exactly one file (AssertionError
        otherwise). If True, the last natural-sort match is used. When the
        trajectory file names carry an ``iter<N>`` tag, the reference file
        with the same tag is chosen.
    force_constant : float
        Multiplier K in ``F = K * <Z - Zref>``; defaults to the previously
        hard-coded 0.5.

    Returns
    -------
    PMF : numpy.ndarray, shape (images,)
        Cumulative trapezoid work summed over CVs; ``PMF[0] == 0``.
    Wn : list of numpy.ndarray
        Per-segment, per-CV work using the left image's force.
    WnT : list of numpy.ndarray
        Per-segment, per-CV work using the mean force of both segment ends.
    """
    iter_re = re.compile(r'iter(\d+)')
    # Accepts integers ("7"), decimals and exponents; `\d+.?\d+` skipped single digits.
    center_re = re.compile(r"centers\s+(-?\d+(?:\.\d*)?(?:[eE][-+]?\d+)?)")

    def _matches(im, pattern, kind):
        # Natural-sorted matches of `pattern` in image directory `im`.
        found = natsorted(glob(f"{pmf_dir}/{im}/{pattern}"))
        if not use_last_it:
            assert len(found) == 1, (
                f"ERROR: expected exactly one {kind} file, matched {len(found)}: "
                f"{pmf_dir}/{im}/{pattern}")
        if not found:
            raise FileNotFoundError(f"no {kind} file matches {pmf_dir}/{im}/{pattern}")
        return found

    def _iter_tag(path):
        # "iter<N>" tag from the file name (not the directories), or None.
        m = iter_re.search(os.path.basename(path))
        return m.group(1) if m else None

    # Only numerically-named directories are images; stray files are ignored.
    n_imgs = len([p for p in glob(f"{pmf_dir}/*")
                  if os.path.isdir(p) and os.path.basename(p).isdigit()])
    if n_imgs == 0:
        raise FileNotFoundError(f"no image directories found in {pmf_dir}")

    # If duplicates are allowed (use_last_it), the latest job per image is taken.
    colvar_traj_files = [_matches(im, colvar_traj_glob, "colvars.traj")[-1]
                         for im in range(n_imgs)]
    traj_tags = [_iter_tag(f) for f in colvar_traj_files]
    iter_num = next((t for t in traj_tags if t is not None), None)
    if iter_num is not None:
        # All images must come from the same string iteration.
        assert len(set(traj_tags)) == 1, "mismatched iter nums"

    # ndmin=2 keeps single-frame files 2-D; column 0 is dropped.
    Z = np.asarray([np.loadtxt(f, ndmin=2) for f in colvar_traj_files])[:, :, 1:]
    print("Z.shape: "+str(Z.shape))

    Zref = []
    for im in range(n_imgs):
        candidates = _matches(im, ref_centers_glob, "ref_centers")
        if iter_num is not None:
            # Newest first: the reference used by the same iteration as the trajectories.
            same_iter = [f for f in reversed(candidates) if _iter_tag(f) == iter_num]
            if not same_iter:
                raise FileNotFoundError(
                    f"no {ref_centers_glob} file tagged iter{iter_num} in {pmf_dir}/{im}")
            ref_centers_file = same_iter[0]
        else:
            ref_centers_file = candidates[-1]
        with open(ref_centers_file, "r") as f:
            Zref.append(np.array(center_re.findall(f.read()), dtype=float))
    Zref = np.asarray(Zref)

    # Per-frame deviation from each image's reference: (images, frames, cvs).
    DeltaZ = Z - Zref[:, np.newaxis, :]
    MeanDeltaZ = np.mean(DeltaZ, axis=1)

    K = force_constant
    F = K*MeanDeltaZ
    Wn = [-F[i]*(Zref[i+1]-Zref[i]) for i in range(len(Zref)-1)]
    WnT = [-0.5*(F[i]+F[i+1])*(Zref[i+1]-Zref[i]) for i in range(len(Zref)-1)]

    # PMF[i] = trapezoid work summed over CVs from image 0 to image i; one
    # entry per image (previously hard-coded to 32 entries).
    PMF = np.concatenate(([0.0], np.cumsum([np.sum(w) for w in WnT])))
    return PMF, Wn, WnT

# Plots

In [None]:
# Load (or extract and cache) the centers and the 100-frame-strided trajectories.
data, traj_data = get_data(
    string_dir="./src_barrier_string000/",
    data_name="data.npy",
    traj_data_name="traj_data_s100.npy",
    stride=100,
)

In [None]:
# String evolution over iterations (sliders: iteration, CV).
fig_data = render_interactive(data=data)
# To also write a GIF of one CV's evolution:
# fig_data = render_interactive(data, ani_cv=28, ani_outname="cv_evolution.gif")

In [None]:
# Per-frame trajectories vs. Zref and trajectory mean (sliders: iteration, CV).
fig_traj_data = render_interactive(data=data, traj_data=traj_data)
# To also write a GIF of one CV's trajectories:
#fig_traj_data = render_interactive(data,traj_data, ani_cv=28, ani_outname="cv_traj_evolution.gif")

In [None]:
# Reminder: data[:, :, it] is the Zref used for iteration `it` (it = 0 is the
# initial Zref) and traj_data[:, :, it] is the trajectory run with that Zref.
# Both results below have shape (images, iterations).
n_traj_it = traj_data.shape[2]
# Per-CV difference between Zref and the trajectory mean: (images, cvs, iterations).
zref_minus_mean = data[:, :, :n_traj_it] - traj_data.mean(axis=3)
# Euclidean distance in CV space between Zref and the trajectory mean.
converge_data = np.linalg.norm(zref_minus_mean, axis=1)
# Largest absolute single-CV deviation, so drift in either direction counts
# (a signed max would hide CVs whose mean moved above Zref).
max_cvdist_data = np.abs(zref_minus_mean).max(axis=1)

In [None]:
# per image convergence: for the iteration chosen with the slider, plot each
# image's Zref-to-trajectory-mean distance in CV space (left axis, blue) and
# its largest single-CV deviation (right axis, orange).
# rcParams only apply to figures created afterwards, so set them before subplots().
matplotlib.rcParams['figure.figsize'] = (15,5)
matplotlib.rcParams.update({'font.size': 20})
x_vals = range(data.shape[0])
fig,ax = plt.subplots(1,1)
ax2 = ax.twinx()
fig.canvas.header_visible = False
CONV_YLIM = (0, 10)  # fixed y-range so iterations are visually comparable

def update(i):
    ax.clear()
    # Same x-range convention as render_interactive: first to last image.
    ax.set_xlim(0, len(x_vals) - 1)
    ax.set_ylim(*CONV_YLIM)
    ax.set_xlabel("Image")
    ax.set_ylabel("Distance in CV space\n from previous Zref", color='#1f77b4')
    ax.plot(x_vals, converge_data[:, i])
    ax2.clear()
    # clear() may reset the twin's ticks to the left side; pin them to the right.
    ax2.yaxis.tick_right()
    ax2.set_ylim(*CONV_YLIM)
    ax2.yaxis.set_label_position('right')
    ax2.set_ylabel("Max distance of any CV\n from previous Zref", color='orange')
    ax2.plot(x_vals, max_cvdist_data[:, i], color='orange')
    ax.text(0.95,0.95,"It = "+str(i),transform=ax.transAxes, size=16, ha='right', va='top')

fig_conv = interact(update,i=widgets.IntSlider(min=0,max=traj_data.shape[2]-1,step=1,value=0,layout=widgets.Layout(width='500px')))

# ani = FuncAnimation(fig, update, frames=tqdm(range(traj_data.shape[2]),leave=False), interval=20)
# ani.save("per_image_convergence.gif",writer="imagemagick",dpi=100)

In [None]:
# convergence over iterations: image-averaged Zref-to-trajectory-mean distance.
# Uses its own figure name so the global `fig` of the previous cell is not rebound.
fig_iter, ax_iter = plt.subplots(1, 1)
ax_iter.plot(np.mean(converge_data, axis=0))
ax_iter.set_title("Convergence over iterations")
ax_iter.set_xlabel("Iteration")
ax_iter.set_ylabel("Mean distance in CV space\n from Zref")
plt.show()