# MDFP dataset generation

## TODO
- If scaling data, also pickle the fitted mins and maxes (e.g. the scaler objects), so the scaling can easily be inverted later!

In [1]:
from Dataset_subclasses import DatasetMDFP

import os
import time
import pickle
import numpy as np
import torch
import mdtraj as md
from sklearn.preprocessing import MinMaxScaler

## Function: DE Shaw data folder parser

In [2]:
def _parse_deshaw_pdb_filename(pdb_filename):
    """
    Parse one DE Shaw pdb file name of the form `<prefix>_<a>-<suffix>.<ext>`.

    Returns:
        (suffix_val, t): the suffix string exactly as it appears in the
        file name, and the float timestamp
        `(int(a) * 1e4 + int(suffix_val)) / 1e4`.

    Raises:
        ValueError: if the name does not follow that pattern.
    """
    try:
        _, tail = pdb_filename.split('_')
        a, val = tail.split('-')
        suffix_val = val.split('.')[0]
        # author's note: int(suffix_val) is 0-2 microseconds
        t = (int(a) * 1e4 + int(suffix_val)) / 1e4
    except ValueError as err:
        raise ValueError(
            f'unexpected DE Shaw pdb file name: {pdb_filename!r}'
        ) from err
    return suffix_val, t


def get_deshaw_data_info(deshaw_folderpath):
    """
    Walk a DE Shaw data folder and collect info on its pdb files.

    DE Shaw pdb files are grouped in subfolders and use a compound
    naming convention (see `_parse_deshaw_pdb_filename`). Subfolders
    are visited in sorted path order. Within each subfolder, files are
    ordered by timestamp, with ties broken by file name. The
    per-subfolder results are then concatenated.

    All three returned lists share one ordering, so index `i` always
    describes the same file. Hidden entries (names starting with '.',
    e.g. `.DS_Store`) are skipped.

    Args:
        deshaw_folderpath: folder that contains the pdb subfolders.

    Returns:
        dict of lists with keys 'pdb_filepaths' (str), 'suffix_vals'
        (str, as in the file names) and 'timestamps' (float).

    Raises:
        ValueError: if a (non-hidden) file name does not match the
            expected naming convention.
    """
    records = {
        'pdb_filepaths': [],
        'suffix_vals': [],
        'timestamps': []
    }

    # deshaw pdb files are grouped in subfolders
    deshaw_subfolders = sorted(
        entry.path for entry in os.scandir(deshaw_folderpath)
        if entry.is_dir()
    )

    for deshaw_subf in deshaw_subfolders:
        subf_entries = []
        for pdb_filename in os.listdir(deshaw_subf):
            if pdb_filename.startswith('.'):
                continue  # OS/editor artefacts, not simulation frames
            suffix_val, t = _parse_deshaw_pdb_filename(pdb_filename)
            subf_entries.append(
                (t, pdb_filename, f'{deshaw_subf}/{pdb_filename}', suffix_val)
            )

        # a single sort per subfolder keeps the three lists aligned,
        # even when two files share a timestamp
        subf_entries.sort(key=lambda e: (e[0], e[1]))

        for t, _, pdb_filepath, suffix_val in subf_entries:
            records['pdb_filepaths'].append(pdb_filepath)
            records['suffix_vals'].append(suffix_val)
            records['timestamps'].append(t)

    return records
        


## Class: MDFP data processor

In [3]:
class DataProcessorMDFP:
    """
    Processes an ATLAS or DE Shaw dataset, passed
    as an MDTraj `Trajectory`, into a `DatasetMDFP` (a
    subclass of `torch.utils.data.Dataset`).

    When `minmaxscale=True`, the fitted scaling parameters are kept on
    the instance after `get_mdfp_dataset()` runs, so the scaling can be
    inverted later:
      - `feat_scaler`: `MinMaxScaler` fitted on the MDFP feature matrix
      - `coords_scaler`: `MinMaxScaler` fitted on residue-centre xyz columns
      - `timestep_scale`: the frame ids were divided by this number
    All three are `None` when no scaling is done.
    """

    def __init__(self,
                 traj,
                 dtype=torch.float32,
                 minmaxscale=False):
        self.traj = traj
        self.dtype = dtype
        self.minmaxscale = minmaxscale
        # fitted scaling parameters; only set when minmaxscale=True
        self.feat_scaler = None
        self.coords_scaler = None
        self.timestep_scale = None

    @staticmethod
    def _hbe_frame_stats(ks_hbe):
        """
        Reduce per-frame Kabsch-Sander H-bond energy matrices to
        (mean, median, st. dev.) of their NONZERO strictly-lower-triangle
        entries. This reduction approach follows Riniker 2017.

        Args:
            ks_hbe: sequence (one item per frame) of square matrices, as
                returned by `md.kabsch_sander`.

        Returns:
            np.ndarray of shape (n_frames, 3), with columns mean, median
            and stdev. A frame with no nonzero entries gets all zeros,
            instead of the NaN (and RuntimeWarning) np.mean([]) would give.
        """
        stats = np.zeros((len(ks_hbe), 3))
        for i, sm in enumerate(ks_hbe):
            tril_mask = np.tril_indices_from(sm, k=-1)
            # reshape(-1) flattens both ndarray and np.matrix results to 1-D
            v = np.array(sm[tril_mask]).reshape(-1)
            v_nonzero = v[v != 0.0]
            if v_nonzero.size:
                stats[i] = (np.mean(v_nonzero),
                            np.median(v_nonzero),
                            np.std(v_nonzero))
        return stats

    def _get_mdfp_feat_matrix(self):
        """
        Calculates a 'Molecular Dynamics FingerPrint' (MDFP) feature
        matrix for a protein trajectory, using the analysis module of the
        MDTraj package.

        Feature columns, in order: centre of mass (3), residue contact
        distances, radius of gyration (1), summed SASA (1), phi, psi and
        omega torsion angles, and H-bond energy mean/median/stdev (3).
        The contact distances usually account for most of the columns,
        so they are the first candidate to drop when reducing features.

        Note if minmaxscale=True, the scaler is trained on the entire
        dataset, not just a train subset. The fitted scaler is stored in
        `self.feat_scaler`.

        This is based on the MDFP concept introduced in Riniker 2017.
        https://doi.org/10.1021/acs.jcim.6b00778
        https://github.com/rinikerlab/mdfptools/blob/master/examples/Example.ipynb

        Returns:
            np.ndarray of shape (n_frames, n_features).
        """
        # track computation time
        time_0 = time.time()

        # distances
        ctr_mass = md.compute_center_of_mass(self.traj)  # (n_frames, 3)
        contact_dists, _ = md.compute_contacts(
            self.traj,
            contacts='all',
            scheme='closest-heavy'
        )  # (n_frames, n_contacts)

        # radius of gyration, SASA (summed over axis 1 -> one value per frame)
        rg = md.compute_rg(self.traj)  # (n_frames,)
        sasa = md.shrake_rupley(self.traj).sum(axis=1)  # (n_frames,)

        # torsion angles, each (n_frames, n_angles)
        _, phi_angles = md.compute_phi(self.traj)
        _, psi_angles = md.compute_psi(self.traj)
        _, omega_angles = md.compute_omega(self.traj)

        # K-S hydrogen bond energy (HBE), reduced to 3 stats per frame
        hbe_stats = self._hbe_frame_stats(md.kabsch_sander(self.traj))

        # Not included: dipole moments and static dielectric (they need
        # charges), and thermal expansion (needs temperatures and
        # potential energies).

        # every feature as a 2-D (n_frames, k) array
        feat_tup = (
            ctr_mass,
            contact_dists,
            rg[:, np.newaxis],
            sasa[:, np.newaxis],
            phi_angles,
            psi_angles,
            omega_angles,
            hbe_stats,
        )

        # stack features into matrix of shape = (n_frames, n_features)
        X = np.concatenate(feat_tup, axis=1)
        print('final mdfp matrix shape (n_frames, n_features):',
              X.shape)

        if self.minmaxscale:
            self.feat_scaler = MinMaxScaler()
            X = self.feat_scaler.fit_transform(X)

        # print computation time
        time_elapsed = time.time() - time_0
        print(f'feature matrix generated in: {time_elapsed // 60:.0f}min, {time_elapsed % 60:.1f}sec')

        return X

    def _get_mdfp_targets(self):
        """
        Generates the targets from `self.traj`:

        - 'id': frame indices 0..n_frames-1 (torch.long)
        - 'timestep': the same indices as `self.dtype`. With
          minmaxscale=True they are mapped onto [0, 1] by dividing by
          `max(n_frames - 1, 1)`, which is stored in `self.timestep_scale`.
        - 'coords': one flat tensor per frame with the mean xyz of each
          residue's atoms, laid out as [x_1..x_R, y_1..y_R, z_1..z_R].
          With minmaxscale=True the x, y and z columns are min-max scaled
          over all frames and residues. The scaler is kept in
          `self.coords_scaler`.
        """
        n_frames = self.traj.n_frames

        # frame ids
        ids = torch.arange(n_frames, dtype=torch.long)

        # timestep targets (clone so 'timestep' never aliases 'id')
        timesteps = ids.clone().to(self.dtype)
        if self.minmaxscale:
            # true min-max scaling: first frame -> 0, last frame -> 1
            self.timestep_scale = max(n_frames - 1, 1)
            timesteps = timesteps / self.timestep_scale

        # residues' centre xyz coords, vectorized over frames
        # (traj.xyz has shape (n_frames, n_atoms, 3))
        xyz = self.traj.xyz
        residue_atom_idx = [
            [atom.index for atom in residue.atoms]
            for residue in self.traj.topology.residues
        ]
        ctr_coords = np.stack(
            [xyz[:, atom_idx, :].mean(axis=1) for atom_idx in residue_atom_idx],
            axis=1,
        )  # (n_frames, n_residues, 3)

        if self.minmaxscale:
            # scale x, y, z columns over all frames' residues at once
            self.coords_scaler = MinMaxScaler()
            flat = self.coords_scaler.fit_transform(ctr_coords.reshape(-1, 3))
            ctr_coords = flat.reshape(ctr_coords.shape)

        # unroll each (n_residues, 3)-array coordinate-major into a vector
        coords_tensors_l = [
            torch.tensor(frame_coords, dtype=self.dtype).T.reshape(-1)
            for frame_coords in ctr_coords
        ]

        targets_dict = {
            'id': ids,
            'timestep': timesteps,
            'coords': coords_tensors_l,
        }
        return targets_dict

    def get_mdfp_dataset(self):
        """
        Calls internal functions above to generate
        a dataset (inputs and targets dict) from
        the `self.traj`.

        Returns:
            `DatasetMDFP` whose inputs are one feature-row tensor per frame.
        """
        # inputs: one tensor per feature-matrix row (each row is copied)
        X = self._get_mdfp_feat_matrix()
        inputs = [torch.tensor(row, dtype=self.dtype) for row in X]

        # targets
        targets_dict = self._get_mdfp_targets()

        dataset_mdfp = DatasetMDFP(
            inputs=inputs,
            targets=targets_dict
        )
        return dataset_mdfp



## Run
- Note this example only loads a partial trajectory!

In [88]:
# root folders of the raw data and of the generated datasets (set before running)
ATLAS_DATA_FOLDER = "YOUR_PATH_HERE"
DESHAW_DATA_FOLDER = "YOUR_PATH_HERE"
DATASET_SAVE_FOLDER = "YOUR_PATH_HERE"

# protein name -> data subfolder ("dir") and data source ("kind": "atlas" or "deshaw")
PROTEIN_INFO = {
    protein: {"dir": subdir, "kind": source}
    for protein, subdir, source in (
        ("1bx7_A", "1bx7_A_analysis", "atlas"),
        ("1bxy_A", "1bxy_A_analysis", "atlas"),
        # NOTE(review): spelled "analysise", unlike the other ATLAS folders — confirm on disk
        ("1ptq_A", "1ptq_A_analysise", "atlas"),
        ("GB3", "GB3", "deshaw"),
        ("BPTI", "BPTI", "deshaw"),
        ("Ubiquitin", "Ubiquitin", "deshaw"),
    )
}

# ATLAS run index, used in the `{name}_R{MD_RUN_INDEX}.xtc` file name
MD_RUN_INDEX = 1
# torch dtype handed to the data processor
TORCH_DTYPE = torch.float32
# whether the data processor should min-max scale the data
MINMAXSCALE = False

### Load trajectory

Each protein's trajectory is loaded from one of two sources:

- ATLAS: from two files, an `xtc` trajectory and a `pdb` topology.
- DE Shaw: from a large set of `pdb` files stored in subfolders.

In [5]:
for name, info_d in PROTEIN_INFO.items():
    # 1
    # create trajectory from data (and track time)
    time_0 = time.time()
    print(f'Processing {name}...')

    kind = info_d['kind']
    if kind == 'atlas':
        # ATLAS: one MD run (`xtc`) plus its `pdb` topology
        atlas_dir = f"{ATLAS_DATA_FOLDER}/{info_d['dir']}"
        traj = md.load(f"{atlas_dir}/{name}_R{MD_RUN_INDEX}.xtc",
                       top=f"{atlas_dir}/{name}.pdb")
    elif kind == 'deshaw':
        # DE Shaw: parse the data folder first, then load every pdb file
        # in the order returned by get_deshaw_data_info
        deshaw_records = get_deshaw_data_info(f"{DESHAW_DATA_FOLDER}/{info_d['dir']}")
        traj = md.load(deshaw_records['pdb_filepaths'])
    else:
        # fail loudly: otherwise `traj` would still hold the previous
        # protein's trajectory and be silently processed and saved again
        raise ValueError(f"Unknown data kind {kind!r} for protein {name!r}")

    te = time.time() - time_0
    print(f'\t{name} trajectory created in: {te // 60:.0f}min, {te % 60:.1f}sec')

    # 2
    # generate dataset
    time_1 = time.time()
    data_proc_mdfp = DataProcessorMDFP(
        traj,
        dtype=TORCH_DTYPE,
        minmaxscale=MINMAXSCALE
    )
    dataset_mdfp = data_proc_mdfp.get_mdfp_dataset()
    te2 = time.time() - time_1
    print(f'\t{name} dataset created in: {te2 // 60:.0f}min, {te2 % 60:.1f}sec')
    print(f'\tWas data min-max scaled? {MINMAXSCALE}')

    # 3
    # pickle the reference trajectory (MDTraj.Trajectory object)
    reftraj_filename = f'{DATASET_SAVE_FOLDER}/reftraj_{name}.p'
    os.makedirs(os.path.dirname(reftraj_filename), exist_ok=True)
    with open(reftraj_filename, 'wb') as f:
        pickle.dump(traj, f, protocol=pickle.HIGHEST_PROTOCOL)

    # 4
    # pickle the dataset (scaled and unscaled datasets go to separate subfolders)
    if MINMAXSCALE:
        filename = f'{DATASET_SAVE_FOLDER}/scaled/mdfp_dataset_{name}_minmaxscaled.p'
    else:
        filename = f'{DATASET_SAVE_FOLDER}/unscaled/mdfp_dataset_{name}.p'

    os.makedirs(os.path.dirname(filename), exist_ok=True)
    with open(filename, 'wb') as f:
        pickle.dump(dataset_mdfp, f, protocol=pickle.HIGHEST_PROTOCOL)

    print('\n')


Processing 1bx7_A...
1bx7_A trajectory created in: 0min, 0.1sec
final mdfp matrix shape (n_frames, n_features): (1001, 1548)
feature matrix generated in: 0min, 23.8sec
1bx7_A dataset created in: 0min, 25.5sec
Was data min-max scaled? False

Processing 1bxy_A...
1bxy_A trajectory created in: 0min, 0.0sec
final mdfp matrix shape (n_frames, n_features): (1001, 1838)
feature matrix generated in: 0min, 33.6sec
1bxy_A dataset created in: 0min, 36.0sec
Was data min-max scaled? False

Processing 1ptq_A...
1ptq_A trajectory created in: 0min, 0.0sec
final mdfp matrix shape (n_frames, n_features): (1001, 1283)
feature matrix generated in: 0min, 25.8sec
1ptq_A dataset created in: 0min, 27.6sec
Was data min-max scaled? False

Processing GB3...
GB3 trajectory created in: 1min, 55.9sec
final mdfp matrix shape (n_frames, n_features): (9985, 1604)
feature matrix generated in: 4min, 32.7sec
GB3 dataset created in: 4min, 52.5sec
Was data min-max scaled? False

Processing BPTI...
BPTI trajectory created 

In [13]:
# optional: 
torch.set_printoptions(precision=2)

# index of the sample to inspect. It must be smaller than the number of
# frames in the dataset left over from the loop above (9984 is the last
# GB3 frame).
sample_idx = 9984

# inspect one sample in the dataset
print(dataset_mdfp[sample_idx]) # ['x'].shape

# check reloaded dataset: read back the file the loop above actually wrote.
# `filename` already accounts for MINMAXSCALE (scaled/ vs unscaled/).
with open(filename, 'rb') as f:
    unpkled_dataset = pickle.load(f)
unpkled_dataset[sample_idx]

{'x': tensor([ 2.55e-06, -3.73e-06, -2.98e-06,  ..., -2.25e+00, -2.24e+00,
         4.73e-01]), 'target': {'id': tensor(9984), 'timestep': tensor(9984.), 'coords': tensor([ 1.00e-01,  6.86e-01,  3.76e-01,  8.75e-01,  5.65e-01,  1.10e+00,
         9.12e-01,  8.92e-01,  1.29e+00,  1.41e+00,  1.24e+00,  1.24e+00,
         8.41e-01,  9.22e-01,  4.36e-01,  3.27e-01, -1.95e-02, -4.63e-01,
        -3.54e-01, -8.16e-01, -7.16e-01, -9.59e-01, -5.04e-01, -8.72e-01,
        -6.40e-01, -1.71e-01, -4.01e-01, -3.75e-01, -9.89e-02,  1.62e-01,
        -1.82e-01,  2.05e-01,  6.44e-01,  5.48e-01,  2.64e-01,  8.71e-02,
        -3.34e-01, -5.41e-01, -7.56e-01, -2.14e-01, -1.57e-01, -4.08e-01,
        -5.30e-02,  1.27e-01,  1.34e-01,  1.87e-01,  9.85e-02, -3.46e-01,
        -4.47e-01, -3.28e-01, -9.19e-01, -7.80e-01, -1.10e+00, -9.80e-01,
        -8.71e-01, -2.97e-01, -6.76e-01, -9.17e-01, -5.86e-01, -4.36e-01,
        -8.61e-02,  2.44e-01,  4.48e-01,  8.13e-01,  4.84e-01,  7.78e-01,
         2.87e-01,  6.

{'x': tensor([ 2.55e-06, -3.73e-06, -2.98e-06,  ..., -2.25e+00, -2.24e+00,
          4.73e-01]),
 'target': {'id': tensor(9984),
  'timestep': tensor(9984.),
  'coords': tensor([ 1.00e-01,  6.86e-01,  3.76e-01,  8.75e-01,  5.65e-01,  1.10e+00,
           9.12e-01,  8.92e-01,  1.29e+00,  1.41e+00,  1.24e+00,  1.24e+00,
           8.41e-01,  9.22e-01,  4.36e-01,  3.27e-01, -1.95e-02, -4.63e-01,
          -3.54e-01, -8.16e-01, -7.16e-01, -9.59e-01, -5.04e-01, -8.72e-01,
          -6.40e-01, -1.71e-01, -4.01e-01, -3.75e-01, -9.89e-02,  1.62e-01,
          -1.82e-01,  2.05e-01,  6.44e-01,  5.48e-01,  2.64e-01,  8.71e-02,
          -3.34e-01, -5.41e-01, -7.56e-01, -2.14e-01, -1.57e-01, -4.08e-01,
          -5.30e-02,  1.27e-01,  1.34e-01,  1.87e-01,  9.85e-02, -3.46e-01,
          -4.47e-01, -3.28e-01, -9.19e-01, -7.80e-01, -1.10e+00, -9.80e-01,
          -8.71e-01, -2.97e-01, -6.76e-01, -9.17e-01, -5.86e-01, -4.36e-01,
          -8.61e-02,  2.44e-01,  4.48e-01,  8.13e-01,  4.84e-01,  7.78e-