In [None]:
from matplotlib import pyplot as plt
import numpy as np
import os
import pandas as pd

class MSDData:
    """Time / MSD series read from a two-column text file.

    Paths ending in ``dat`` or ``csv`` are parsed with the cpptraj reader;
    every other path is parsed with the GROMACS (gmx) reader.

    Attributes are ``time`` and ``msd``, two 1-D numpy arrays of equal length.
    """

    def __init__(self, filepath):
        """Load ``filepath`` (a ``str`` or ``os.PathLike``) into ``time``/``msd``."""
        filepath = os.fspath(filepath)
        if filepath.endswith(('dat', 'csv')):
            self.time, self.msd = self.read_from_cpptraj(filepath)
        else:
            self.time, self.msd = self.read_from_gmx(filepath)

    @staticmethod
    def _parse_columns(lines):
        """Convert the first two whitespace-separated fields of each line to arrays.

        Blank lines are skipped. A line whose first two fields are not
        floats raises ``ValueError``; a line with a single field raises
        ``IndexError``.
        """
        time = []
        msd = []
        for line in lines:
            item = line.split()
            if item:
                time.append(float(item[0]))
                msd.append(float(item[1]))
        return np.array(time), np.array(msd)

    def read_from_cpptraj(self, filepath):
        """Read a cpptraj-style file: one header line, then ``time msd`` rows.

        Fields are split on whitespace, so a comma-separated file is not
        handled here.

        Returns:
            (time, msd) as numpy arrays; both empty for an empty file.
        """
        with open(filepath) as f:
            # Skip the header line; the default keeps an empty file from
            # raising StopIteration.
            next(f, None)
            return self._parse_columns(f)

    def read_from_gmx(self, filepath):
        """Read a gmx-style file whose data follow the ``s0`` legend line.

        Every line up to and including the first one with ``line[2:4] == 's0'``
        is skipped, plus the single line after it; the rest is parsed as
        ``time msd`` rows. If no such line exists the arrays are empty.

        Returns:
            (time, msd) as numpy arrays.
        """
        with open(filepath) as f:
            for line in f:
                if line[2:4] == 's0':
                    # One more header line follows the s0 legend line.
                    next(f, None)
                    break
            return self._parse_columns(f)

    def output_data(self):
        """Return the ``(time, msd)`` arrays."""
        return self.time, self.msd
    
class DrawScatter:
    """Paired X/Y data with scatter plotting and straight-line fitting helpers.

    ``X`` and ``Y`` are stored as given; the fitting methods slice them and
    call ``.mean()`` on ``Y``, so numpy arrays are expected.
    """

    def __init__(self, X, Y):
        self.X = X
        self.Y = Y

    def draw(self, ax):
        """Scatter-plot the stored data on ``ax``."""
        ax.scatter(self.X, self.Y)

    def fit1(self, begin_fit, end_fit):
        """Fit a straight line to the index slice ``[begin_fit:end_fit]``.

        Returns:
            (slope, intercept, R2, fit_func), where ``fit_func(x)`` evaluates
            the fitted line and ``R2`` is computed over the same slice.
        """
        slope, intercept = np.polyfit(self.X[begin_fit:end_fit], self.Y[begin_fit:end_fit], 1)

        def fit_func(x):
            return x * slope + intercept

        R2 = self.calculate_R2(fit_func, begin_fit, end_fit)
        return slope, intercept, R2, fit_func

    def calculate_R2(self, func, begin_fit, end_fit):
        """Return the coefficient of determination of ``func`` on the slice.

        Computed as ``1 - SS_res / SS_tot``. If Y is constant over the slice
        ``SS_tot`` is zero and the result is not finite.
        """
        y = self.Y[begin_fit:end_fit]
        ss_tot = np.sum((y - y.mean()) ** 2)
        ss_res = np.sum((y - func(self.X[begin_fit:end_fit])) ** 2)
        return 1 - ss_res / ss_tot

    def fit_interval(self, begin_fit, end_fit):
        '''
        Convert a fit window given in X units into slice indices.

        begin_fit, end_fit: fit from begin_fit to end_fit (same units as X).
        Passing equal values selects the whole data set.

        The conversion divides by the spacing ``X[1] - X[0]``, so it assumes X
        is uniformly spaced and starts at zero.

        Returns (begin_index, end_index) for use as ``X[begin:end]``.
        '''
        if begin_fit == end_fit:
            # len(X), not -1: a slice end of -1 would drop the last point.
            return 0, len(self.X)
        step = self.X[1] - self.X[0]
        # Round instead of truncating so float error (0.3 / 0.1 == 2.999...)
        # cannot shift the window by one sample.
        begin_fit = int(round(begin_fit / step))
        end_fit = int(round(end_fit / step))
        return begin_fit, end_fit


class DiffusionDrawScatter(DrawScatter):
    """MSD-vs-time data with a linear fit that yields a diffusion coefficient."""

    def fit_least_squares(self, begin_fit, end_fit):
        """Fit a line to the MSD inside a time window and derive D.

        Args:
            begin_fit, end_fit: fit window in X (time) units; converted to
                indices by ``fit_interval``.

        Returns:
            dict with the index window (``begin_fit``, ``end_fit``), the fit
            (``slope``, ``intercept``, ``R2``, ``fit_func``) and ``D``.
        """
        begin_fit, end_fit = self.fit_interval(begin_fit, end_fit)
        slope, intercept, R2, fit_func = self.fit1(begin_fit, end_fit)
        # D = slope / 6. The factor 10 matches a conversion from Å^2/ps to
        # 1e-5 cm^2/s, the units the plot annotations in this class use
        # (assumes MSD in Å^2 and time in ps).
        diffusion_coefficient = slope / 6 * 10
        return {
            "begin_fit": begin_fit,
            "end_fit": end_fit,
            "slope": slope,
            "intercept": intercept,
            "R2": R2,
            "fit_func": fit_func,
            "D": diffusion_coefficient
        }

    def draw_with_fit_result(self, fit_result, text_x, text_y, ax=None, color=None):
        """Plot the data, the fitted line and a text box describing the fit.

        Args:
            fit_result: dict returned by ``fit_least_squares``.
            text_x, text_y: data coordinates of the annotation text.
            ax: target axes; the current pyplot axes are used when omitted.
            color: optional scatter colour.
        """
        if ax is None:
            ax = plt.gca()

        begin_fit = fit_result["begin_fit"]
        end_fit = fit_result["end_fit"]

        scatter_kwargs = {'s': 3}
        if color:
            scatter_kwargs['c'] = color
        ax.scatter(self.X, self.Y, **scatter_kwargs)
        ax.set_xlabel('Time  ps', fontdict={'fontsize': 12})
        # Raw string: '\m' and '\A' are not valid Python escape sequences.
        ax.set_ylabel(r'MSD  $\mathrm{\AA}^2$/ps', fontdict={'fontsize': 12})

        fit_x = self.X[begin_fit:end_fit]
        ax.plot(fit_x, fit_result["fit_func"](fit_x), color='red', linewidth=1)

        slope = fit_result["slope"]
        intercept = fit_result["intercept"]
        # The '+' format flag always prints the sign, including for zero.
        text = (
            f'y={slope:.4f}x{intercept:+.4f}\n'
            f'R$^2$={fit_result["R2"]:.4f}\n'
            f'D={fit_result["D"]:.4f} $10^-$$^5$ $cm^2$/s'
        )
        ax.text(text_x, text_y, text, fontdict={'fontsize': 12})

        
class BoxDrawScatter(DrawScatter):
    """Diffusion coefficient versus inverse box length, with a linear fit.

    ``X`` holds box lengths and is stored as ``1 / X``. ``Y`` is either one
    value per box or a 2-D array with one row per box, which is reduced to
    row means. The fitted intercept is reported as D at infinite box size.
    """

    def __init__(self, X, Y, fit_method='mean'):
        """
        Args:
            X: box lengths, one per box.
            Y: 1-D values per box, or 2-D array of shape (boxes, samples).
            fit_method: ``'mean'`` for an ordinary fit of the row means, or
                ``'wls'`` to weight each box by the inverse sample variance of
                its row. Only consulted when ``Y`` is not 1-D.
        """
        super().__init__(X, Y)
        self.fit_method = fit_method
        # Weights handed to np.polyfit; None selects the unweighted fit.
        self.fit_weights = None

        if self.Y.ndim != 1:
            if fit_method == 'wls':
                var = np.var(self.Y, axis=1, ddof=1)
                var[var == 0] = 1e-8  # avoid division by zero
                # np.polyfit expects w = 1 / sigma for inverse-variance
                # weighting. The weights belong to the fit: within one row
                # they are constant, so they cannot alter the row mean.
                self.fit_weights = 1 / np.sqrt(var)
                self.Y = np.mean(self.Y, axis=1)
            if fit_method == 'mean':
                self.Y = np.mean(self.Y, axis=1)

        self.X = 1 / self.X

    def fit1(self, begin_fit, end_fit):
        """Linear fit over ``[begin_fit:end_fit]``, weighted when 'wls' is active.

        Returns (slope, intercept, R2, fit_func); ``R2`` is the unweighted
        coefficient of determination from ``calculate_R2``.
        """
        if self.fit_weights is None:
            return super().fit1(begin_fit, end_fit)

        slope, intercept = np.polyfit(
            self.X[begin_fit:end_fit],
            self.Y[begin_fit:end_fit],
            1,
            w=self.fit_weights[begin_fit:end_fit],
        )

        def fit_func(x):
            return x * slope + intercept

        R2 = self.calculate_R2(fit_func, begin_fit, end_fit)
        return slope, intercept, R2, fit_func

    def draw_and_fit(self, text_x, text_y, ax=None,):
        """Plot the points, the fitted line and an annotation at (text_x, text_y).

        The fit uses the first three points. ``ax`` defaults to the current
        pyplot axes.
        """
        if ax is None:
            ax = plt.gca()

        ax.set_xlabel('1/L  $nm^{-1}$', fontdict={'fontsize': 12})
        ax.set_ylabel('$D_{tr}$  $10^-$$^5$ $cm^2$/s', fontdict={'fontsize': 12})
        slope, intercept, R2, fit_func = self.fit1(0, 3)
        ax.scatter(self.X, self.Y, s=25)
        ax.plot(self.X, fit_func(self.X), color='red', linewidth=1, linestyle='--')
        # '\\infty' keeps the backslash for mathtext; '+' prints the sign of
        # the intercept so a negative value no longer renders as '+-'.
        text = (
            f'y={slope:.4f}x{intercept:+.4f}\n'
            f'R$^2$={R2:.4f}\n'
            f'D$_\\infty$={intercept:.4f}  $10^-$$^5$ $cm^2$/s'
        )
        ax.text(text_x, text_y, text, fontdict={'fontsize': 12})

    def draw_and_fit2(self, ax, label, color=None):
        """Plot the points and the fit extended down to 1/L = 0.

        Prints ``label``, the intercept and R2. The fit uses the first three
        points; the line is drawn on 20 points from 0 to ``max(X)``.
        """
        ax.set_xlabel('1/L  $nm^{-1}$', fontdict={'fontsize': 12})
        ax.set_ylabel('$D_{tr}$  $10^-$$^5$ $cm^2$/s', fontdict={'fontsize': 12})
        slope, intercept, R2, fit_func = self.fit1(0, 3)
        print(f'{label}: {intercept:.4f} {R2:.4f}')
        x = np.linspace(0, max(self.X), 20)

        scatter_kwargs = {'s': 15, 'label': label}
        line_kwargs = {'linewidth': 1, 'linestyle': '--'}
        if color:
            # Use the same colour for the points and their fit line.
            scatter_kwargs['c'] = color
            line_kwargs['color'] = color
        ax.scatter(self.X, self.Y, **scatter_kwargs)
        ax.plot(x, fit_func(x), **line_kwargs)

        
        
class MeanMSDData:
    """Point-wise mean of the MSD series read from several files.

    Each file is loaded with ``MSDData``. ``time`` is taken from the last
    file; the files are not checked for identical time axes, but
    ``np.stack`` requires all MSD arrays to have the same shape.
    """

    def __init__(self, *args):
        """
        Args:
            *args: one or more file paths accepted by ``MSDData``.

        Raises:
            ValueError: if no path is given.
        """
        if not args:
            raise ValueError('MeanMSDData needs at least one MSD file path')
        arr = []
        for file in args:
            time, msd = MSDData(file).output_data()
            arr.append(msd)
        self.time = time
        self.msd = np.mean(np.stack(arr), axis=0)

    def output_data(self,):
        """Return the ``(time, mean_msd)`` arrays."""
        return self.time, self.msd

In [None]:
# Turn off matplotlib's interactive mode.
plt.ioff()

# Plot colour for each system.
color_dic = {
    '1-butanol': '#925EB0',
    '1-propanol': '#7E99F4',
    'aspirin': '#9D9EA3',
    'methyl_butyrate': '#7AB656',
    't-butanol': '#E29135',
    'apo': '#94C6CD',
}

# Average box length during MD, three boxes per system
# (presumably in nm, matching the 1/L axis label used when plotting).
box_length_dic = {
    '1-butanol': np.array([3.42739, 3.86156, 4.29555]),
    '1-propanol': np.array([3.42633, 3.86072, 4.29489]),
    'aspirin': np.array([3.42898, 3.86265, 4.29653]),
    'methyl_butyrate': np.array([3.42837, 3.86222, 4.29622]),
    't-butanol': np.array([3.42675, 3.86135, 4.29535]),
    'water': np.array([2.1714, 3.7270, 5.5923]),
    'apo': np.array([3.42284, 3.85798, 4.29252]),
}

# Systems processed by the cells below.
solute_list = (
    '1-butanol',
    '1-propanol',
    'aspirin',
    'methyl_butyrate',
    't-butanol',
    'apo',
)

# Calculate D for every single trajectory and optionally plot each fit.
plot = False
savefilepath = 'result_5-50/result_traj.csv'
if not os.path.exists(savefilepath):
    # First run: build a table with one zero-filled row per
    # (solute, box 1-3, trajectory 1-17) and write it out.
    df_list = [
        [solute, box_id, i, 0., 0.]
        for solute in solute_list
        for box_id in [1, 2, 3]
        for i in range(1, 18)
    ]
    df = pd.DataFrame(df_list, columns=['Solute', 'Box', 'Index', 'R', 'D'])
    # to_csv fails if the target directory does not exist yet.
    os.makedirs(os.path.dirname(savefilepath), exist_ok=True)
    df.to_csv(savefilepath, index=False)
else:
    df = pd.read_csv(savefilepath)

for solute in ('methyl_butyrate',):
    for box_size in (40, 45, 50):
        # Box sizes 40 / 45 / 50 map to the table's Box values 1 / 2 / 3.
        box_index = (box_size - 35) // 5
        if plot:
            fig, axes = plt.subplots(4, 5, figsize=(20, 25), dpi=600)
            axes = axes.ravel()
        c = 0  # number of trajectories processed (and panels drawn when plotting)
        for i in range(1, 18):
            file = f'/home/databank/zjhan/diffusion/calculation/{solute}/box_{box_size}/md{i}/msd-1.dat'
            data = MSDData(file)
            diffusion_obj = DiffusionDrawScatter(*data.output_data())
            # Fit window 5000-50000 in the time units of the MSD file.
            fit_result = diffusion_obj.fit_least_squares(5000, 50000,)
            if plot:
                diffusion_obj.draw_with_fit_result(fit_result, 16000, 700, ax=axes[i-1])
            df.loc[(df['Solute'] == solute) & (df['Box'] == box_index) & (df['Index'] == i), ['R', 'D']] = [fit_result['R2'], fit_result['D']]
            c += 1
        if plot:
            # hide subplots that don't have data
            for i in range(c, 20):
                axes[i].axis('off')
            # Figure-level title: plt.title would attach the text to the
            # last (empty) subplot instead of the whole figure.
            fig.suptitle(f'MSD and D for all trajs of {solute}-box{box_index}')
            fig.tight_layout()
            fig.savefig(f'result_5-50/{solute}-box{box_index}-all_trajs.png')
            # Release the 600-dpi figure before creating the next one.
            plt.close(fig)
df.to_csv(savefilepath, index=False)



In [None]:
#infinite-box
def get_D(df, solute, number=False):
    """Collect the 'D' values of one solute, one row per 'Box' group.

    Args:
        df: DataFrame with 'Solute', 'Box' and 'D' columns.
        solute: value of the 'Solute' column to select.
        number: if truthy, keep only the first ``number`` values of each box.

    Returns:
        numpy array built from the per-box arrays, in the order ``groupby``
        yields the boxes.
    """
    selected = df.loc[df['Solute'] == solute]
    per_box = [box_rows['D'].to_numpy() for _, box_rows in selected.groupby('Box')]
    if number:
        per_box = [values[:number] for values in per_box]
    return np.array(per_box)

# Load the per-trajectory results and draw, for every solute, D against
# 1/L together with the linear fit extended to 1/L = 0.
df = pd.read_csv('result_5-50/result_traj.csv')
fig = plt.figure(dpi=800)
ax = fig.add_subplot(111)

for solute in solute_list:
    # Up to the first 15 D values of each box.
    df_solute = get_D(df, solute, 15)
    box_lengths = np.array(box_length_dic[solute])
    box_draw = BoxDrawScatter(box_lengths, df_solute)
    box_draw.draw_and_fit2(ax=ax, label=solute, color=color_dic[solute])

# Start the x axis at zero so the extrapolation to 1/L = 0 is visible.
ax.set_xlim(xmin=0)
ax.legend()
fig.savefig('0_diffusion_coefficient')
# Only for water:
#box_draw.draw_and_fit(0.2,2.15, f'result/{solute}-diffusion_coefficient')