In [10]:
import os
import re
import sys
import shutil
import cloudpickle
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor

import datamol as dm

import pytraj as pt
from pytraj.cluster import kmeans
import parmed as pmd

import MDAnalysis as mda
from MDAnalysis.analysis import rms, align, rms, gnm, pca
from MDAnalysis.analysis.base import (AnalysisBase,
                                      AnalysisFromFunction,
                                      analysis_class)
import pandas as pd
import numpy as np

import seaborn as sns
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly.express as px

from src.Calc_MMPBSA import Wrapper_MMPBSA

In [11]:
def find_matching_directories(dir_list):
    file_paths = {}
    for root, dirs, files in os.walk("."):
        for dir in dirs:
            # Check if directory matches any pattern in dir_list
            for pattern in dir_list:
                if (pattern == dir) or (re.match(pattern, dir)):
                    dirpath = os.path.join(root, dir)
                    file_paths[dir] = {"PRMTOP": "", "DCD": ""}
                    
                    for file in os.listdir(dirpath):
                        if file.endswith("m.prmtop"):
                            filepath = os.path.join(root, dir, file)
                            file_paths[dir]["PRMTOP"] = filepath
                        elif file.endswith("0.dcd"):
                            filepath = os.path.join(root, dir, file)
                            file_paths[dir]["DCD"] = filepath
                    break  # Once matched, no need to check other patterns
    
    return file_paths

# Example usage
dir_list = ["idx_\d+_Rank_\d+_.*"]
file_paths = find_matching_directories(dir_list)

In [12]:
class Pytraj_Analysis():
    def __init__(self, md_dir, traj_path, top_path, overwrite=False):
        self.md_dir = md_dir
        self.traj_path = traj_path
        self.top_path = top_path
        self.overwrite = overwrite

        # Set the slice factor
        SLICE = 20

        # Define output paths
        self.xtc_path = os.path.join(md_dir, "Step3_Md_Rep0_noWAT.xtc")
        self.dcd_path = os.path.join(md_dir, "Step3_Md_Rep0_noWAT.dcd")
        self.nowat_top_path = os.path.join(md_dir, "system_noWAT.prmtop")
        self.pdb_path = os.path.join(md_dir, "Minimized_noWAT.pdb")
        self.cluster_path = os.path.join(md_dir, "Clusters.pdb")

        # Handle trajectory loading and processing
        if traj_path.endswith(".dcd"):
            if not overwrite and os.path.exists(self.xtc_path) and os.path.exists(self.nowat_top_path):
                # Skip processing, load from existing XTC
                self.traj_noWAT = pt.iterload(self.xtc_path, self.nowat_top_path)
            else:
                # Process DCD (heavy computation)
                traj_WAT = pt.iterload(traj_path, self.top_path)
                traj_WAT.autoimage()
                traj_WAT.superpose(mask="@CA", ref=0)
                traj_noWAT = traj_WAT["!(:HOH,NA,CL)"]
                self.traj_WAT = traj_WAT
                self.traj_noWAT = traj_noWAT[::SLICE]
                # Write XTC
                pt.write_traj(self.xtc_path, self.traj_noWAT, overwrite=True)
                # Write DCD
                pt.write_traj(self.dcd_path, self.traj_noWAT, overwrite=True)
                # Write TOP
                pt.save(self.nowat_top_path, self.traj_noWAT.top, overwrite=True)
        elif traj_path.endswith(".xtc"):
            # Load directly from XTC
            self.traj_noWAT = pt.iterload(traj_path, self.top_path)

        # Set clustering trajectory
        self.traj_Cluster = self.traj_noWAT

        # Convert to PDB if needed (heavy if trajectory is large)
        if overwrite or not os.path.exists(self.pdb_path):
            pt.write_traj(self.pdb_path, self.traj_noWAT, frame_indices=[0], overwrite=True)

        # Cluster if needed (heavy computation)
        if overwrite or not os.path.exists(self.cluster_path):
            cluster_opts = {"MASK": "!@H=", "NUM": 10}
            self.cluster_traj(cluster_opts)

    def cluster_traj(self, cluster_opts):
        try:
            mask = cluster_opts.get("MASK", "!@H=")
            n_clust = cluster_opts.get("NUM", 10)
            
            cluster_data = kmeans(self.traj_Cluster, mask=mask, n_clusters=n_clust)
            centroids = list(cluster_data.centroids)
            cluster_traj = self.traj_Cluster[centroids]
    
            # FIXME: Probably sub-optimal
            cluster_df = pd.DataFrame({
                "Cluster_ID": list(range(len(cluster_data.centroids))),
                "Fraction": cluster_data.fraction,
            })
            
            pt.write_traj(
                self.cluster_path,
                cluster_traj,
                options="model",
                overwrite=True
            )
        except Exception as e:
            print(f"Cluster FAIL for {self.md_dir}\nError: {e}")
    
    def PCA(self):
        traj_PCA = self.traj_noWAT
        data_PCA = pt.pca(traj_PCA, mask='!@H=', n_vecs=2)
        data_PCA = pt.pca(traj_PCA, mask='@CA,@N', n_vecs=2)
        
        PCA = data_PCA[0]
        return PCA, traj_PCA
    
    def plot_PCA(self, title="PCA", filepath=None):
        PCA, traj_PCA = self.PCA()
        
        fig, ax = plt.subplots(figsize=(10, 5))
        scatter = ax.scatter(PCA[0], PCA[1], marker='o', c=range(traj_PCA.n_frames), alpha=0.5)
        
        ax.set_title(title)
        ax.set_xlabel("PC1")
        ax.set_ylabel("PC2")
        ax.axhline(y=0, color='gray', linestyle='--')
        ax.axvline(x=0, color='gray', linestyle='--')
        ax.grid(False)
        
        cbar = plt.colorbar(scatter, ax=ax)
        cbar.set_label("Frame")
        
        if filepath is not None:
            fig.savefig(filepath)
        
        plt.show()

class MDA_Analysis():
    def __init__(self, md_dir):
        traj_harmon_path = os.path.join(md_dir, "Step3_Md_Rep0_noWAT.xtc")
        assert os.path.exists(traj_harmon_path)
        
        prmtop_harmon_path = os.path.join(md_dir, "system_noWAT.prmtop")
        assert os.path.exists(prmtop_harmon_path)
        
        self.md_dir = md_dir
        self.universe = mda.Universe(prmtop_harmon_path, traj_harmon_path)
    
    def calc_rmsf(self):
        self.universe.trajectory[0]
        # Precompute RMSF
        average = align.AverageStructure(self.universe, self.universe, select='protein and name CA', ref_frame=0).run()
        ref = average.results.universe
        aligner = align.AlignTraj(self.universe, ref, select='protein and name CA', in_memory=True).run()
        
        c_alphas = self.universe.select_atoms('protein and name CA')
        R = rms.RMSF(c_alphas).run()
        rmsf = R.results.rmsf
        return rmsf, c_alphas
    
    def plot_rmsf(self):
        rmsf, c_alphas = self.calc_rmsf()
        
        # Plot the chart
        fig, ax = plt.subplots(figsize=(10, 5))
        ax.plot(c_alphas.resids, rmsf, label=self.md_dir)
        ax.legend()
        ax.set_xlabel("Residue")
        ax.set_ylabel("RMSF (Å)")
        
        plt.show()
    
    def calc_rmsd(self):
        # Restart trajectory
        self.universe.trajectory[0]
        rmsd = rms.RMSD(self.universe, select="all", groupselections=["protein and backbone", "resname UNK"]).run().results.rmsd
        return rmsd
    
    def plot_rmsd(self):
        rmsd = self.calc_rmsd()
        
        # Plot the chart
        fig, ax = plt.subplots(figsize=(10, 5))
        ax.plot(rmsd[:, 0], rmsd[:, -2:])
        ax.set_xlabel("Frame")
        ax.set_ylabel("RMSD (Å)")
        
        plt.show()
    
    def gaussian_elastic(self, close=False):
        self.universe.trajectory[0]
        if close == False:
            nma = gnm.GNMAnalysis(self.universe, select='protein and name CA', cutoff=7.0).run()
            nma_res = nma.results
            return nma_res
        else:
            nma = gnm.closeContactGNMAnalysis(self.universe, select='protein and name CA',
                                              cutoff=7.0, weights="size").run()
            nma_res = nma.results
            return nma_res
    
    def radgyr_run(self):
        self.universe.trajectory[0]
        def radgyr(atomgroup, masses, total_mass=None):
            # coordinates change for each frame
            coordinates = atomgroup.positions
            center_of_mass = atomgroup.center_of_mass()
        
            # get squared distance from center
            ri_sq = (coordinates-center_of_mass)**2
            # sum the unweighted positions
            sq = np.sum(ri_sq, axis=1)
            sq_x = np.sum(ri_sq[:,[1,2]], axis=1) # sum over y and z
            sq_y = np.sum(ri_sq[:,[0,2]], axis=1) # sum over x and z
            sq_z = np.sum(ri_sq[:,[0,1]], axis=1) # sum over x and y
        
            # make into array
            sq_rs = np.array([sq, sq_x, sq_y, sq_z])
        
            # weight positions
            rog_sq = np.sum(masses*sq_rs, axis=1)/total_mass
            # square root and return
            return np.sqrt(rog_sq)
        
        protein_sel = self.universe.select_atoms("protein")
        
        rog = AnalysisFromFunction(radgyr, self.universe.trajectory, protein_sel, protein_sel.masses, total_mass=np.sum(protein_sel.masses)).run()
        return rog.results["timeseries"]
    
    def radgyr_run_plot(self):
        rog = self.radgyr_run()
        
        labels = ['all', 'x-axis', 'y-axis', 'z-axis']
        for col, label in zip(rog.T, labels):
            plt.plot(col, label=label)
            
        plt.legend()
        plt.ylabel('Radius of gyration (Å)')
        plt.xlabel('Frame')
        plt.show()
    
    def PCA_2(self, components=2):
        # Reload the universe
        u = self.universe
        
        # Define the backbone
        selection_str = "backbone"
        selection_md = u.select_atoms(selection_str)
        
        # Run the PCA
        pc = pca.PCA(u, select=selection_str,
                     align=True, mean=None,
                     n_components=None).run()
        
        # Obtain the cumulative variance
        cumulative_variance_df = pd.DataFrame(pc.results.cumulated_variance[:3], columns=["Cumulative Variance"])
        cumulative_variance_df.index = ["PC1", "PC2", "PC3"]
        cumulative_df = cumulative_variance_df.to_dict()["Cumulative Variance"]
        
        # Obtain the transformed frames
        transformed = pc.transform(selection_md, n_components=3)
        transformed_df = pd.DataFrame(transformed, columns=["PC1", "PC2", "PC3"])
        transformed_df["Frames"] = range(u.trajectory.n_frames)
        
        return cumulative_df, transformed_df
    
    def plot_PCA_3D(self, title="PairGrid PCA", filepath=None):
        
        _, transformed_df = self.PCA_2()
        
        g = sns.PairGrid(transformed_df, hue="Frames", palette=sns.color_palette("viridis", self.universe.trajectory.n_frames))
        g.map(plt.scatter, marker=".")
        
        plt.subplots_adjust(top=0.9)
        g.fig.suptitle(title)
    
        if filepath is not None:
            g.fig.savefig(filepath)
        
        plt.show()
    
class MD_Analyzer(Pytraj_Analysis, MDA_Analysis):
    def __init__(self, md_dir, traj_path, top_path, overwrite):
        
        Pytraj_Analysis.__init__(self, md_dir, traj_path, top_path, overwrite)
        MDA_Analysis.__init__(self, md_dir)

In [13]:
# analyzer_classes = []
# analyzer_names = []
# for dirname, files in file_paths.items():
#     try:
#         analyzer = MD_Analyzer(dirname, files["DCD"], files["PRMTOP"])
#         analyzer_classes.append(analyzer)
#         analyzer_names.append(dirname)
#     except Exception as e:
#         print(f"Error processing {dirname}: {e}")

# # Add PDB and XTC
# for dir, files in file_paths.items():
#     try:
#         dcd_file = files["DCD"]
#         prmtop_file = files["PRMTOP"]
        
#         xtc_file = dcd_file.replace(".dcd", "_noWAT.xtc")
#         mod_prmtop = prmtop_file.replace("system.prmtop", "system_noWAT.prmtop")
#         files["PDB"] = os.path.join(dir, "Minimized_noWAT.pdb")
#         files["XTC"] = xtc_file
#         files["PRMTOP"] = mod_prmtop
#     except Exception as e:
#         print(f"Error updating files for {dir}: {e}")

In [14]:
analyzer_classes = []
analyzer_names = []
for dirname, files in tqdm(file_paths.items(), desc="Creating analyzers"):
   try:
       analyzer = MD_Analyzer(dirname, files["DCD"], files["PRMTOP"], True)
       analyzer_classes.append(analyzer)
       analyzer_names.append(dirname)
   except Exception as e:
       print(f"Error processing {dirname}: {e}")

analyzer_dict = {}
for cls, name in zip(analyzer_classes, analyzer_names):
    analyzer_dict[name] = cls

# Add PDB and XTC
for mydir, files in tqdm(file_paths.items(), desc="Updating file paths"):
   try:
       dcd_file = files["DCD"]
       prmtop_file = files["PRMTOP"]
       
       xtc_file = dcd_file.replace(".dcd", "_noWAT.xtc")
       dcd_file_noWAT =  dcd_file.replace(".dcd", "_noWAT.dcd")
       mod_prmtop = prmtop_file.replace("system.prmtop", "system_noWAT.prmtop")
       files["PDB"] = os.path.join(mydir, "Minimized_noWAT.pdb")
       files["XTC"] = xtc_file
       files["DCD_noWAT"] = dcd_file_noWAT
       files["PRMTOP_noWAT"] = mod_prmtop
       files["PRMTOP_WAT"] = prmtop_file
       files["CLUSTER"] = os.path.join(mydir, "Clusters.pdb")
       
   except Exception as e:
       print(f"Error updating files for {mydir}: {e}")

 ctime or size or n_atoms did not match
Creating analyzers:  10%|█         | 4/40 [11:09<1:23:26, 139.06s/it]

Cluster FAIL for idx_3_Rank_3_Chrysin
Error: index 50 is out of bounds for axis 0 with size 50


Creating analyzers:  22%|██▎       | 9/40 [24:04<1:04:47, 125.40s/it]

Cluster FAIL for idx_4_Rank_8_Luteolin
Error: index 84 is out of bounds for axis 0 with size 84


Creating analyzers:  28%|██▊       | 11/40 [24:51<34:56, 72.30s/it]  

Cluster FAIL for idx_0_Rank_1_Apigenin
Error: index 50 is out of bounds for axis 0 with size 50


Creating analyzers:  48%|████▊     | 19/40 [39:55<36:20, 103.83s/it]

Cluster FAIL for idx_4_Rank_2_Luteolin
Error: index 85 is out of bounds for axis 0 with size 85


Creating analyzers:  98%|█████████▊| 39/40 [1:13:35<01:11, 71.46s/it] 

Cluster FAIL for idx_0_Rank_8_Apigenin
Error: index 87 is out of bounds for axis 0 with size 87


Creating analyzers: 100%|██████████| 40/40 [1:14:50<00:00, 112.26s/it]
Updating file paths: 100%|██████████| 40/40 [00:00<00:00, 207126.12it/s]


In [17]:
analyzer_dict["idx_4_Rank_2_Luteolin"].traj_noWAT

pytraj.Trajectory, 85 frames: 
Size: 0.006353 (GB)
<Topology: 3344 atoms, 211 residues, 2 mols, PBC with box type = cubic>
           

In [None]:
def plot_RMSF(analyzers, labels):
    fig = go.Figure()
    for analyzer, label in zip(analyzers, labels):
        RMSF, c_alphas = analyzer.calc_rmsf()
        fig.add_trace(go.Scatter(
            x=c_alphas.resids,
            y=RMSF[:300],
            mode='lines',
            name=label
        ))
        
        fig.update_layout(
            title='RMSF Plot',
            xaxis_title='Residue',
            yaxis_title='RMSF (Å)',
            legend_title='MD Directory',
            height=600
        )
    fig.show()

plot_RMSF(analyzer_classes, analyzer_names)

In [None]:
def plot_RMSD(analyzers, labels):
    fig = go.Figure()
    df_RMSD_lst = []
    
    for analyzer, label in zip(analyzers, labels):
        RMSD = analyzer.calc_rmsd()
        fig.add_trace(go.Scatter(
            x=RMSD[:, 0],
            y=RMSD[:, 4],
            mode='lines',
            name=label
        ))
        
        # Save a DF
        df_RMSD = pd.DataFrame({
            "Complex": label,
            "Frame": RMSD[:, 0],
            "RMSD": RMSD[:, 4]
        })
        
        # Calculate mean RMSD for this complex
        mean_rmsd = np.mean(RMSD[:, 4])
        
        # Add mean as a column
        df_RMSD["Mean_RMSD"] = mean_rmsd
        
        df_RMSD_lst.append(df_RMSD)
    
    fig.update_layout(
        title='RMSD Plot',
        xaxis_title='Frame',
        yaxis_title='RMSD (Å)',
        legend_title='MD Directory',
        height=800,
    )
    fig.show()
    
    # Concatenate all dataframes
    df_combined = pd.concat(df_RMSD_lst)
    
    # Create summary dataframe with mean values
    df_summary = df_combined.groupby('Complex')['RMSD'].agg(['mean', 'std', 'min', 'max']).reset_index()
    df_summary = df_summary.sort_values('mean')
    
    print("Summary of RMSD values by complex (sorted by mean):")
    print(df_summary)
    
    return df_combined, df_summary

# Modified function call to unpack both return values
df_RMSD, df_summary = plot_RMSD(analyzer_classes, analyzer_names)

In [None]:
def plot_Radius(analyzers, labels):
    fig = go.Figure()
    for analyzer, label in zip(analyzers, labels):
        radgyr_array = analyzer.radgyr_run()
        frames = list(range(0, len(radgyr_array)))
        fig.add_trace(go.Scatter(
            x=frames,
            y=radgyr_array[:, 0],
            mode='lines',
            name=label
        ))
    
    fig.update_layout(
        title='Radius of Gyration',
        xaxis_title='Frame',
        yaxis_title='RMSD (Å)',
        legend_title='MD Directory',
        height=600
    )
    fig.show()

plot_Radius(analyzer_classes, analyzer_names)

In [None]:
def plot_Gaussian(analyzers, labels):
    fig = go.Figure()
    for analyzer, label in zip(analyzers, labels):
        gaussian = analyzer.gaussian_elastic(close=False)
        eigenvalues = gaussian["eigenvalues"]
        time = gaussian["times"]
        
        fig.add_trace(go.Scatter(x=time, y=eigenvalues, mode='lines', name=label))
        
        fig.update_layout(
            title='Gaussian Network Model',
            xaxis_title='Time',
            yaxis_title='Eigenvalue',
            legend_title='MD Directory',
            height=600
        )

    fig.show()

plot_Gaussian(analyzer_classes, analyzer_names)

In [None]:
def plot_PCA(analyzers, titles):

    n_plots = len(analyzers)
    n_cols = 2
    n_rows = int(np.ceil(n_plots / n_cols))

    fig_width = 6 * n_cols  # 6 inches per column
    fig_height = 5 * n_rows  # 5 inches per row

    fig, axs = plt.subplots(n_rows, n_cols, figsize=(fig_width, fig_height))
    for ax, analyzer, title in zip(axs.flatten(), analyzers, titles):
        PCA, traj_PCA = analyzer.PCA()
        scatter = ax.scatter(PCA[0], PCA[1], marker='o', c=range(traj_PCA.n_frames), alpha=0.5)
        
        ax.set_title(title)
        ax.set_xlabel("PC1")
        ax.set_ylabel("PC2")
        ax.axhline(y=0, color='gray', linestyle='--')
        ax.axvline(x=0, color='gray', linestyle='--')
        ax.grid(False)
        
        cbar = plt.colorbar(scatter, ax=ax)
        cbar.set_label("Frame")
    plt.tight_layout()
    plt.show()

plot_PCA(analyzer_classes, analyzer_names)

## View

In [None]:
def find_neight(traj):
    from itertools import chain
    atom_ndx = pt.search_neighbors(traj, mask=":UNK<:5.5")
    
    atom_ndx_2d = []
    for sub in atom_ndx:
        sub_list = list(sub)
        atom_ndx_2d.append(sub_list)
    
    atom_ndx_flatten = list(set(list(chain(*atom_ndx_2d))))

    all_resids = []
    for atom in atom_ndx_flatten:
        resid = traj.top.atom(atom).resid
        all_resids.append(resid)
    
    resids_unique = list(set(all_resids))
    resids_unique_str = " or ".join(map(str, resids_unique))

    return resids_unique_str
        
def view_lig(dirname):
    import nglview as nv
    traj = pt.iterload(file_paths[dirname]["XTC"], file_paths[dirname]["PRMTOP_noWAT"])

    N_list = find_neight(traj)
    
    view = nv.show_pytraj(traj)
    
    view.clear_representations()
    view.representations = [
        {
            "type":"cartoon",
            "params": {"sele":"protein", "color":"residueindex"}
        },
        {
            "type":"licorice",
            "params": {"sele":"(ligand) and not (_H)"}
        },
        {
            "type":"licorice",
            "params": {"sele":f"({N_list}) and not (_H)"}
        },
        # {
        #     "type":"surface",
        #     "params": {"sele":"protein and not ligand", "color":"blue", "wireframe":True, "opacity":0.6, "isolevel":3.}
        # }
    ]
    
    view.center("ligand")
    return view

In [None]:
view_lig("idx_0_Rank_12_Apigenin")

## Download

In [None]:
import zipfile
zip_filename = 'to_download.zip'

try:
    os.remove(zip_filename)
except:
    pass

with zipfile.ZipFile(zip_filename, 'w') as zipf:
    for run_dir, files in file_paths.items():
        # Add file to the zip file, with the appropriate directory structure
        arcname = os.path.join(run_dir, os.path.basename(run_dir))
        
        prmtop = files["PRMTOP_noWAT"]
        xtc = files["XTC"]
        pdb = files["PDB"]
        cluster = files["CLUSTER"]
        
        if prmtop:
            arcname = os.path.join(run_dir, os.path.basename(prmtop))
            zipf.write(prmtop, arcname)
        
        if xtc:
            arcname = os.path.join(run_dir, os.path.basename(xtc))
            zipf.write(xtc, arcname)

        if pdb:
            arcname = os.path.join(run_dir, os.path.basename(pdb))
            zipf.write(pdb, arcname)

        if cluster:
            arcname = os.path.join(run_dir, os.path.basename(cluster))
            zipf.write(cluster, arcname)

## MMPBA

In [None]:
from itertools import islice

In [None]:
file_paths_WAT = file_paths
for key, value in file_paths_WAT.items():
    prmtop = value["PRMTOP"]
    new_prmtop = prmtop.replace("_noWAT", "")
    value["PRMTOP"] = new_prmtop

In [None]:
some_keys = list(islice(file_paths.keys(), 2))
file_paths_sliced = {key: file_paths[key] for key in some_keys}

In [None]:
# mmpbsa = Wrapper_MMPBSA(file_paths_sliced)
# mmpbsa_df = mmpbsa()

## Prolif

In [None]:
file_paths

In [None]:
import MDAnalysis as mda
import prolif as plf

# load topology and trajectory
u = mda.Universe(file_paths["idx_6_Rank_10_Eriodictyol"]["PRMTOP_WAT"], file_paths["idx_6_Rank_10_Eriodictyol"]["DCD"])

# create selections for the ligand and protein
ligand_selection = u.select_atoms("resname UNK")
protein_selection = u.select_atoms(
    "protein and byres around 20.0 group ligand", ligand=ligand_selection
)
ligand_selection, protein_selection

In [None]:
# use default interactions
fp = plf.Fingerprint(count=True)

# run on a slice of the trajectory frames: from begining to end with a step of 10
fp.run(u.trajectory[::10], ligand_selection, protein_selection)

In [None]:
fp.plot_barcode(interactive=True, figsize=(8, 8))

In [None]:
# frame = 10
# # seek specific frame
# u.trajectory[frame]
# ligand_mol = plf.Molecule.from_mda(ligand_selection)
# protein_mol = plf.Molecule.from_mda(protein_selection)
# # display
# view = fp.plot_3d(ligand_mol, protein_mol, frame=frame, display_all=False)
# view