In [1]:
from glob import glob
import mdtraj as md
import numpy as np
from pathlib import Path

import proteka
from proteka.metrics import Featurizer

In [2]:
# Folded CLN structure; used further down as the reference for both the
# fraction of native contacts (Q) and the RMSD.
cln_folded_pdb = "./example_dataset_files/cln_folded.pdb"
cln_folded_ref = md.load_pdb(cln_folded_pdb)

In [None]:
# Time between consecutive DCD frames, known a priori: 5 ns per frame.
ts = 5
traj_pattern = "./example_dataset_files/cln_amber_300K_mini/*.dcd"
# glob() returns matches in filesystem-dependent order; sort them so the
# trajectories are always loaded (and later concatenated) in the same order.
dcd_path_list = sorted(glob(traj_pattern))
print(f"Found {len(dcd_path_list)} trajectories at pattern {traj_pattern}")

In [4]:
# Load every DCD file using the PDB with the same stem as its topology
# (e.g. "run1.dcd" + "run1.pdb"), keyed by that stem ("run1").
traj_dict = {}
for dcd_file in dcd_path_list:
    dcd_path = Path(dcd_file)
    topology_path = dcd_path.with_suffix(".pdb")
    traj = md.load_dcd(dcd_path, top=topology_path)
    traj_dict[dcd_path.stem] = traj

In [5]:
# Map each trajectory name to the contiguous [start, stop) frame range it will
# occupy once all trajectories are concatenated in traj_dict insertion order.
traj_slices = {}
frame_count = 0
for name, traj in traj_dict.items():
    start, frame_count = frame_count, frame_count + traj.n_frames
    traj_slices[name] = slice(start, frame_count)

In [6]:
# Concatenate all loaded trajectories in traj_dict insertion order -- the same
# order traj_slices was built in -- so that traj_slices[name] selects exactly
# the frames of trajectory `name` inside single_traj.
if not traj_dict:
    raise FileNotFoundError(f"No trajectories were loaded from {traj_pattern!r}")
single_traj = md.join(list(traj_dict.values()))
# Sanity check: the per-trajectory slices must tile the joined trajectory.
assert single_traj.n_frames == sum(s.stop - s.start for s in traj_slices.values())
ens = proteka.Ensemble.from_mdtraj_trj(
    "cln-amber-300K", single_traj, trajectory_slices=traj_slices
)

In [None]:
feat = Featurizer()

# Fraction of native contacts (Q) with respect to the folded reference.
# These options are passed straight through to proteka; `lam` and `beta` are
# presumably the tolerance and steepness of the contact switching function --
# confirm against the proteka Featurizer documentation.
native_contact_opts = dict(
    rep_atoms=["CA"],
    lam=1.5,
    beta=10,
    atom_selection="all and not element H",
)
feat.add_fraction_native_contacts(ens, cln_folded_ref, **native_contact_opts)

# RMSD to the same folded reference, computed over CA atoms only.
feat.add_rmsd(ens, cln_folded_ref, atom_selection="name CA")



In [8]:
# The plots below use matplotlib, which must be installed separately because it is not part of the requirements.
import matplotlib.pyplot as plt

In [None]:
# One panel per trajectory: RMSD to the folded reference on the left axis and
# the fraction of native contacts Q on a twinned right axis, vs simulation time.
rmsd_all = ens.get_quantity("rmsd")
q_all = ens.get_quantity("fraction_native_contacts")

# squeeze=False keeps `axs` a 2-D array of Axes even when the ensemble holds a
# single trajectory; otherwise plt.subplots returns a bare Axes object and
# `.flatten()` would fail.
fig, axs = plt.subplots(
    ncols=1, nrows=ens.n_trjs, figsize=(10, 8), sharex=True, squeeze=False
)
axes_flat = axs.flatten()

for i, (name, traj_slice) in enumerate(ens.trajectories.items()):
    ax = axes_flat[i]
    rmsd_trace = rmsd_all[traj_slice]
    q_trace = q_all[traj_slice]
    # Frames are `ts` ns apart, so frame index * ts is the time in ns.
    xs = np.arange(len(rmsd_trace)) * ts

    ax.plot(xs, rmsd_trace, "salmon", label="RMSD", alpha=0.5)
    ax.set_title(f"{name}", fontsize=20)
    if i == ens.n_trjs - 1:
        # x axes are shared, so only the bottom panel needs an x label.
        ax.set_xlabel("Simulation time (ns)", fontsize=16)
    ax.set_ylabel("RMSD (nm)", fontsize=16)
    ax.set_ylim(0, 2.5)
    ax.legend(loc="upper right")
    ax.grid()
    ax.tick_params(axis="both", which="major", labelsize=12)

    ax2 = ax.twinx()
    ax2.plot(xs, q_trace, "mediumslateblue", label="Q", alpha=0.5)
    ax2.set_ylabel("Q", fontsize=16)
    ax2.set_ylim(0, 1)
    ax2.legend(loc="lower right")

plt.show()