In [1]:
#default_exp model.msms

In [2]:
#hide
%reload_ext autoreload
%autoreload 2

In [3]:
#export
import torch
import pandas as pd
import numpy as np

import typing

from tqdm import tqdm

from alphabase.peptide.fragment import \
    init_fragment_by_precursor_dataframe, \
    set_sliced_fragment_dataframe, \
    get_sliced_fragment_dataframe, \
    get_charged_frag_types

from alphadeep.model.featurize import \
    parse_aa_indices, parse_instrument_indices, \
    get_batch_mod_feature

from alphadeep._settings import \
    global_settings as settings, \
    model_const

import alphadeep.model.base as model_base

class ModelMSMSpDeep3(torch.nn.Module):
    """LSTM stack that maps a peptide plus acquisition metadata to per-position
    fragment intensities.

    The network is made of three `model_base.SeqLSTM` stages. The amino-acid
    one-hot encoding is concatenated with the modification features and fed to
    the first stage; the instrument/NCE embedding and the charge are appended
    to every sequence position before the second and third stages.
    """
    def __init__(self,
        mod_feature_size,
        num_ion_types,
        max_instrument_num,
        dropout=0.2
    ):
        """
        Args:
            mod_feature_size: size of the last axis of `mod_x` in `forward`.
            num_ion_types: size of the last axis of the output.
            max_instrument_num: number of classes of the instrument one-hot.
            dropout: dropout probability applied after the first two stages.
        """
        super().__init__()
        use_birnn = True
        n_hidden = model_const['hidden']
        n_meta_embed = 3
        n_hidden_layers = 2

        self.aa_embedding_size = 27
        self.max_instrument_num = max_instrument_num

        # Instrument one-hot (max_instrument_num) + one NCE value -> small embedding.
        self.instrument_nce_embed = torch.nn.Linear(
            max_instrument_num + 1, n_meta_embed
        )

        # Width seen by the 2nd and 3rd stages:
        # RNN output + instrument/NCE embedding + one charge value.
        n_rnn_out = n_hidden if use_birnn else n_hidden // 2
        n_stage_input = n_rnn_out + n_meta_embed + 1

        self.dropout = torch.nn.Dropout(dropout)

        self.input_rnn = model_base.SeqLSTM(
            self.aa_embedding_size + mod_feature_size,
            n_hidden,
            rnn_layer=1,
            bidirectional=use_birnn,
        )

        self.hidden = model_base.SeqLSTM(
            n_stage_input,
            n_hidden,
            rnn_layer=n_hidden_layers,
            bidirectional=use_birnn,
        )

        self.output = model_base.SeqLSTM(
            n_stage_input,
            num_ion_types,
            rnn_layer=1,
            bidirectional=False,
        )

    def forward(self,
        aa_indices,
        mod_x,
        charges:torch.Tensor,
        NCEs:torch.Tensor,
        instrument_indices,
    ):
        """Run the network on one batch.

        `aa_indices` is one-hot encoded and concatenated with `mod_x` along
        axis 2; `NCEs` and `charges` are concatenated with the instrument
        one-hot / embedding along axis 1. The first three sequence positions
        of the last stage's output are dropped before returning.
        """
        F = torch.nn.functional
        seq_len = aa_indices.size(1)

        aa_onehot = F.one_hot(aa_indices, self.aa_embedding_size)
        inst_onehot = F.one_hot(instrument_indices, self.max_instrument_num)

        # Per-precursor metadata vector: embed(instrument, NCE) followed by charge ...
        meta = self.instrument_nce_embed(torch.cat((inst_onehot, NCEs), 1))
        meta = torch.cat((meta, charges), 1)
        # ... broadcast to every sequence position.
        meta = meta.unsqueeze(1).repeat(1, seq_len, 1)

        out = self.input_rnn(torch.cat((aa_onehot, mod_x), 2))
        out = self.dropout(out)

        out = self.hidden(torch.cat((out, meta), 2))
        out = self.dropout(out)

        out = self.output(torch.cat((out, meta), 2))
        return out[:, 3:, :]


In [4]:
# Build a one-peptide toy batch used to smoke-test ModelMSMSpDeep3.forward below.
import numpy as np
sequence = 'ACDEFGIK'
# Small demo sizes; the exported constants cell further down reassigns
# max_instrument_num, num_ion_types and mod_feature_size from the settings.
max_instrument_num = 4
num_ion_types=2
mod_feature_size=2
aa_indices = torch.LongTensor(parse_aa_indices([sequence]))
# NOTE(review): pDeepModel.train/predict pass parse_aa_indices output to the
# model without this extra padding -- confirm whether padding here is intended.
aa_indices = torch.nn.functional.pad(aa_indices, (1,1))
charges = torch.tensor([[2]])
instrument_indices = torch.LongTensor([1])
NCEs = torch.tensor([[0.3]])
# All-zero modification features, one row per (padded) sequence position.
mod_x = torch.zeros((1, aa_indices.size(1), mod_feature_size))

In [5]:
# Instantiate the network with the demo sizes; the bare name on the last line
# displays the module structure.
msms_model = ModelMSMSpDeep3(mod_feature_size, num_ion_types, max_instrument_num)
msms_model

ModelMSMSpDeep3(
  (instrument_nce_embed): Linear(in_features=5, out_features=3, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
  (input_rnn): SeqLSTM(
    (rnn): LSTM(29, 128, batch_first=True, bidirectional=True)
  )
  (hidden): SeqLSTM(
    (rnn): LSTM(260, 128, num_layers=2, batch_first=True, bidirectional=True)
  )
  (output): SeqLSTM(
    (rnn): LSTM(260, 2, batch_first=True)
  )
)

In [6]:
# Switch to eval mode (disables dropout), then run one forward pass on the toy batch.
msms_model.eval()
msms_model(aa_indices, mod_x, charges, NCEs, instrument_indices)

tensor([[[ 0.2332, -0.3277],
         [ 0.2418, -0.3441],
         [ 0.2477, -0.3532],
         [ 0.2529, -0.3585],
         [ 0.2557, -0.3638],
         [ 0.2576, -0.3678],
         [ 0.2615, -0.3755],
         [ 0.2665, -0.3889],
         [ 0.2703, -0.4171]]], grad_fn=<SliceBackward>)

# pDeepModel for MS/MS prediction

In [7]:
#export
class IntenAwareLoss(torch.nn.Module):
    """Intensity-weighted L1 loss.

    Every element's absolute error is weighted by `target + base_weight`, so
    elements with a larger target value contribute more to the mean.
    """
    def __init__(self, base_weight=0.2):
        super().__init__()
        # Constant added to each target value to form its weight.
        self.w = base_weight

    def forward(self, pred, target):
        """Return mean((target + w) * |pred - target|) over all elements."""
        flat_pred = pred.reshape(-1)
        flat_target = target.reshape(-1)
        weights = flat_target + self.w
        abs_err = torch.abs(flat_pred - flat_target)
        return torch.mean(weights * abs_err)

In [8]:
#export
# Module-level defaults, read once from the alphadeep settings when this cell runs.
# pDeepModel.__init__ below reads mod_feature_size, max_instrument_num, frag_types,
# max_frag_charge, nce_factor and charge_factor from here.
# Note: these assignments overwrite the small demo values of mod_feature_size,
# max_instrument_num and num_ion_types set in the earlier example cell.
mod_feature_size = len(model_const['mod_elements'])
max_instrument_num = model_const['max_instrument_num']
frag_types = settings['model']['frag_types']
max_frag_charge = settings['model']['max_frag_charge']
# One output channel per (fragment type, fragment charge) combination.
num_ion_types = len(frag_types)*max_frag_charge
# Multipliers that pDeepModel applies to the 'NCE' and 'charge' columns.
nce_factor = model_const['nce_factor']
charge_factor = model_const['charge_factor']

class pDeepModel(model_base.ModelImplBase):
    """Training / prediction wrapper around an MS/MS intensity network.

    Precursors are processed grouped by peptide length (`nAA`) so that every
    batch has a single sequence length. Fragment intensities live in a
    separate dataframe and are addressed through the `frag_start_idx` /
    `frag_end_idx` columns of the precursor dataframe.
    """
    def __init__(self,
        dropout=0.2,
        lr=0.001,
        model_class:typing.Type[torch.nn.Module]=ModelMSMSpDeep3,
    ):
        """
        Args:
            dropout: dropout probability forwarded to `model_class`.
            lr: learning rate forwarded to `self.build`.
            model_class: network class to build; it is constructed with
                `mod_feature_size`, `num_ion_types`, `max_instrument_num`
                and `dropout` keyword arguments.
        """
        super().__init__()
        self.charged_frag_types = get_charged_frag_types(
            frag_types, max_frag_charge
        )
        self.charge_factor = charge_factor
        self.NCE_factor = nce_factor
        self.build(
            model_class,
            mod_feature_size = mod_feature_size,
            num_ion_types = len(self.charged_frag_types),
            max_instrument_num = max_instrument_num,
            dropout=dropout,
            lr=lr
        )
        self.loss_func = IntenAwareLoss()
        # self.loss_func = torch.nn.L1Loss()
        # Predicted intensities below this value are set to zero in `predict`.
        self.min_inten = 1e-4

    def _get_msms_batch_features(self, batch_df: pd.DataFrame, nAA):
        """Turn one batch of precursor rows into the model's input tensors.

        Shared by `train` and `predict` so both featurize identically.

        Args:
            batch_df: rows of one `nAA` group; reads the 'sequence', 'charge',
                'NCE' and 'instrument' columns and is passed as a whole to
                `get_batch_mod_feature`.
            nAA: peptide length of the group.

        Returns:
            Tuple `(aa_indices, mod_x, charges, nces, instrument_indices)` in
            the positional order expected by the model's `forward`.
        """
        aa_indices = torch.LongTensor(
            parse_aa_indices(batch_df['sequence'].values.astype('U'))
        )
        mod_x = torch.Tensor(get_batch_mod_feature(batch_df, nAA))
        # Charge and NCE become column vectors with one row per precursor.
        charges = torch.Tensor(
            batch_df['charge'].values
        ).unsqueeze(1)*self.charge_factor
        nces = torch.Tensor(batch_df['NCE'].values).unsqueeze(1)
        instrument_indices = torch.LongTensor(
            parse_instrument_indices(batch_df['instrument'])
        )
        return aa_indices, mod_x, charges, nces, instrument_indices

    def train(self,
        precursor_df: pd.DataFrame,
        fragment_inten_df: pd.DataFrame,
        epoch=10,
        batch_size=1024,
        verbose=False,
        verbose_each_epoch=True,
    ):
        """Train the model on the given precursors and their fragment intensities.

        Args:
            precursor_df: must provide 'sequence', 'charge', 'NCE',
                'instrument', 'nAA', 'frag_start_idx' and 'frag_end_idx',
                plus whatever `get_batch_mod_feature` reads.
            fragment_inten_df: target intensities; only the columns listed in
                `self.charged_frag_types` are used.
            epoch: number of passes over the data.
            batch_size: maximum number of precursors per batch.
            verbose: print the mean loss after every epoch.
            verbose_each_epoch: show a tqdm progress bar for every epoch.

        Side effect: when every 'NCE' value is > 1, the 'NCE' column of the
        caller's `precursor_df` is rescaled in place by `self.NCE_factor`.
        """
        self.model.train()

        fragment_inten_df = fragment_inten_df[self.charged_frag_types]

        if np.all(precursor_df['NCE'].values > 1):
            precursor_df['NCE'] = precursor_df['NCE']*self.NCE_factor

        # Separate loop variable so the `epoch` parameter is not shadowed.
        for i_epoch in range(epoch):
            batch_cost = []
            # Shuffle rows, then visit the nAA groups in random order.
            _grouped = list(precursor_df.sample(frac=1).groupby('nAA'))
            rnd_nAA = np.random.permutation(len(_grouped))
            batch_tqdm = tqdm(rnd_nAA) if verbose_each_epoch else rnd_nAA
            for i_group in batch_tqdm:
                nAA, df_group = _grouped[i_group]
                df_group = df_group.reset_index(drop=True)
                for i in range(0, len(df_group), batch_size):
                    batch_df = df_group.iloc[i:i+batch_size]

                    features = self._get_msms_batch_features(batch_df, nAA)

                    # Targets reshaped to (batch, nAA-1, n_frag_types).
                    intens = torch.Tensor(
                        get_sliced_fragment_dataframe(
                            fragment_inten_df,
                            batch_df[['frag_start_idx','frag_end_idx']].values
                        ).values
                    ).view(-1, nAA-1, len(self.charged_frag_types))

                    cost = self._train_one_batch(intens, *features)
                    batch_cost.append(cost.item())
                if verbose_each_epoch:
                    batch_tqdm.set_description(
                        f'Epoch={i_epoch+1}, nAA={nAA}, Batch={len(batch_cost)}, Loss={cost.item():.4f}'
                    )
            if verbose: print(f'[MS/MS training] Epoch={i_epoch+1}, Mean Loss={np.mean(batch_cost)}')

    def predict(self,
        precursor_df: pd.DataFrame,
        reference_frag_df: typing.Optional[pd.DataFrame] = None,
        batch_size: int=1024,
        verbose=False,
    )->pd.DataFrame:
        """Predict fragment intensities for every precursor.

        Args:
            precursor_df: same column requirements as in `train`.
            reference_frag_df: forwarded to
                `init_fragment_by_precursor_dataframe`.
            batch_size: maximum number of precursors per batch.
            verbose: show a tqdm progress bar over the nAA groups.

        Returns:
            Dataframe created by `init_fragment_by_precursor_dataframe` and
            filled with predictions; values > 1 are set to 1 and values below
            `self.min_inten` are set to 0.

        Side effect: when every 'NCE' value is > 1, the 'NCE' column of the
        caller's `precursor_df` is rescaled in place by `self.NCE_factor`.
        """
        self.model.eval()

        predict_inten_df = init_fragment_by_precursor_dataframe(
            precursor_df, self.charged_frag_types, reference_frag_df
        )

        if np.all(precursor_df['NCE'].values > 1):
            precursor_df['NCE'] = precursor_df['NCE']*self.NCE_factor

        _grouped = precursor_df.groupby('nAA')
        batch_tqdm = tqdm(_grouped) if verbose else _grouped

        # Inference only: skip autograd bookkeeping for the forward passes.
        with torch.no_grad():
            for nAA, df_group in batch_tqdm:
                df_group = df_group.reset_index(drop=True)
                for i in range(0, len(df_group), batch_size):
                    batch_df = df_group.iloc[i:i+batch_size]

                    features = self._get_msms_batch_features(batch_df, nAA)

                    predicts = self.model(
                        *[fea.to(self.device) for fea in features]
                    ).cpu().detach().numpy()
                    predicts[predicts>1] = 1
                    predicts[predicts<self.min_inten] = 0

                    set_sliced_fragment_dataframe(
                        predict_inten_df,
                        predicts.reshape((-1, len(self.charged_frag_types))),
                        batch_df[['frag_start_idx','frag_end_idx']].values,
                        self.charged_frag_types
                    )

        return predict_inten_df

In [9]:
#export
from scipy.stats import pearsonr
from scipy.stats import spearmanr
from scipy.spatial.distance import cosine

def cosine(x1, x2, eps=1e-8):
    """Cosine similarity of two 1-D vectors, clamped to [-1, 1].

    Note: this definition shadows the `cosine` name imported from
    `scipy.spatial.distance` just above.

    Args:
        x1, x2: vectors of equal length.
        eps: added to the denominator so that all-zero vectors give 0
            instead of a division by zero.

    Returns:
        The similarity; values outside [-1, 1] are clamped so the result is
        always a valid argument for `np.arccos`.
    """
    _cos = np.dot(x1,x2)/(np.linalg.norm(x1)*np.linalg.norm(x2)+eps)
    # Clamp both ends: np.arccos returns NaN outside [-1, 1].
    if _cos > 1: _cos = 1
    elif _cos < -1: _cos = -1
    return _cos

def spectral_angle(x1, x2, eps=1e-8):
    """Spectral-angle similarity: 1 - 2*arccos(cosine(x1, x2, eps))/pi.

    `eps` is forwarded unchanged to `cosine`.
    """
    angle = np.arccos(cosine(x1, x2, eps))
    return 1 - 2 * angle / np.pi

def pearson(x1, x2):
    """Pearson correlation coefficient of `x1` and `x2` via `scipy.stats.pearsonr`.

    Returns 0 when the coefficient is NaN.
    """
    coef = pearsonr(x1, x2)[0]
    return 0 if np.isnan(coef) else coef

def spearman(x1, x2):
    """Spearman rank correlation of `x1` and `x2` via `scipy.stats.spearmanr`.

    Returns 0 when the coefficient is NaN.
    """
    coef = spearmanr(x1, x2)[0]
    return 0 if np.isnan(coef) else coef

def get_metric_fuc(metric):
    """Return the similarity function for a metric name (case-insensitive).

    Recognised names: 'COS' (cosine), 'PCC' (pearson), 'SPC' (spearman) and
    'SA' (spectral_angle).

    Raises:
        NotImplementedError: for any other name.
    """
    key = metric.upper()
    if key == 'COS':
        return cosine
    if key == 'PCC':
        return pearson
    if key == 'SPC':
        return spearman
    if key == 'SA':
        return spectral_angle
    raise NotImplementedError(f'Unknown metric "{key}".')

def batch_metric(batch1, batch2, frag_start_end_list, metric_func):
    """Apply `metric_func` to matching row slices of two arrays.

    For every `(start, end)` pair, rows `start:end` of both arrays are
    flattened and passed to `metric_func`; the results are returned as a
    list in the same order as `frag_start_end_list`.
    """
    return [
        metric_func(
            batch1[lo:hi].reshape(-1),
            batch2[lo:hi].reshape(-1),
        )
        for lo, hi in frag_start_end_list
    ]

def evaluate_msms(
    precursor_df: pd.DataFrame,
    predict_inten_df: pd.DataFrame,
    fragment_inten_df: pd.DataFrame,
    charged_frag_types: typing.List,
    metrics = ['PCC','COS','SA'], #+['SPC']
)->pd.DataFrame:
    """Score predicted against reference fragment intensities per precursor.

    Args:
        precursor_df: provides 'frag_start_idx' / 'frag_end_idx', the row
            range of each precursor in both intensity dataframes.
        predict_inten_df: predicted intensities.
        fragment_inten_df: reference intensities.
        charged_frag_types: intensity columns to compare.
        metrics: metric names understood by `get_metric_fuc`.

    Returns:
        Dataframe with one row per precursor and one column per metric.
    """
    n_precursors = len(precursor_df)
    scores = pd.DataFrame(
        np.zeros((n_precursors, len(metrics))),
        columns=metrics
    )

    for name in tqdm(metrics):
        predicted = predict_inten_df[charged_frag_types].values
        reference = fragment_inten_df[charged_frag_types].values
        frag_ranges = precursor_df[['frag_start_idx','frag_end_idx']].values
        scores[name] = batch_metric(
            predicted, reference, frag_ranges, get_metric_fuc(name)
        )
    return scores

def add_cutoff_metric(
    metrics_describ, metrics_df, thres=0.9
):
    """Append a row with the fraction of values above `thres` per column.

    For each column of `metrics_describ`, counts the entries of the same
    column in `metrics_df` that are > `thres` and divides by the number of
    rows of `metrics_df`. The row is labelled '>{thres:.2f}' and written
    into `metrics_describ` in place; the same object is returned.
    """
    n_rows = len(metrics_df)
    fractions = [
        metrics_df.loc[metrics_df[name]>thres, name].count()/n_rows
        for name in metrics_describ.columns.values
    ]
    metrics_describ.loc[f'>{thres:.2f}'] = fractions
    return metrics_describ

### Examples

In [10]:
# Toy reference intensities: 100 fragment rows with identical 'b' and 'y' ramps (0.00 .. 0.99).
fragment_inten_df = pd.DataFrame({'b':np.arange(100)/100, 'y':np.arange(100)/100})
# Ten copies of one modified peptide that differ only in charge (1..10).
# Each precursor owns a block of 10 fragment rows (nAA - 1 = 10), addressed
# through frag_start_idx / frag_end_idx.
repeat = 10
precursor_df = pd.DataFrame({
    'sequence': ['AGHCEWQMKYR']*repeat,
    'mods': ['Acetyl@Protein N-term;Carbamidomethyl@C;Oxidation@M']*repeat,
    'mod_sites': ['0;4;8']*repeat,
    'nAA': [11]*repeat,
    'NCE': [20]*repeat,
    'instrument': 'QE',
    'charge': np.arange(1,repeat+1),
    'frag_start_idx':np.arange(10, dtype=int)*10,
    'frag_end_idx':np.arange(10, dtype=int)*10+10,
})
precursor_df

Unnamed: 0,sequence,mods,mod_sites,nAA,NCE,instrument,charge,frag_start_idx,frag_end_idx
0,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,20,QE,1,0,10
1,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,20,QE,2,10,20
2,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,20,QE,3,20,30
3,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,20,QE,4,30,40
4,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,20,QE,5,40,50
5,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,20,QE,6,50,60
6,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,20,QE,7,60,70
7,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,20,QE,8,70,80
8,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,20,QE,9,80,90
9,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,20,QE,10,90,100


In [11]:
# Display the toy reference intensities.
fragment_inten_df

Unnamed: 0,b,y
0,0.00,0.00
1,0.01,0.01
2,0.02,0.02
3,0.03,0.03
4,0.04,0.04
...,...,...
95,0.95,0.95
96,0.96,0.96
97,0.97,0.97
98,0.98,0.98


### Train an MS/MS model

In [12]:
# pDeepModel() builds a network for the fragment types configured in the settings.
# The toy fragment_inten_df only has 'b' and 'y' columns, so override
# charged_frag_types and rebuild the network with two output ion types.
pdeep = pDeepModel()
pdeep.charged_frag_types=['b','y']
pdeep.build(
    ModelMSMSpDeep3, 
    mod_feature_size = mod_feature_size,
    num_ion_types = len(pdeep.charged_frag_types),
    max_instrument_num = max_instrument_num,
    dropout=0.2,
    lr=1e-3
)
# Note: train() rescales precursor_df['NCE'] in place when all values are > 1.
pdeep.train(precursor_df, fragment_inten_df, epoch=5, batch_size=3)

Epoch=1, nAA=11, Batch=4, Loss=0.1162: 100%|██████████| 1/1 [00:00<00:00,  5.14it/s]
Epoch=2, nAA=11, Batch=4, Loss=0.2476: 100%|██████████| 1/1 [00:00<00:00,  5.77it/s]
Epoch=3, nAA=11, Batch=4, Loss=0.3574: 100%|██████████| 1/1 [00:00<00:00,  5.77it/s]
Epoch=4, nAA=11, Batch=4, Loss=0.1601: 100%|██████████| 1/1 [00:00<00:00,  7.75it/s]
Epoch=5, nAA=11, Batch=4, Loss=0.0641: 100%|██████████| 1/1 [00:00<00:00,  7.60it/s]


### Predict from the model

In [13]:
# Predict 'b'/'y' intensities for the same precursors the model was trained on.
predict_inten_df = pdeep.predict(precursor_df)
predict_inten_df



Unnamed: 0,b,y
0,0.261805,0.210516
1,0.269720,0.215043
2,0.276544,0.221313
3,0.281586,0.226752
4,0.285407,0.234778
...,...,...
95,0.531534,0.409498
96,0.532575,0.415461
97,0.530297,0.417315
98,0.519031,0.412681


### Metrics

In [14]:
# One row per precursor, one column per default metric (PCC, COS, SA).
metrics = evaluate_msms(precursor_df, predict_inten_df, fragment_inten_df, ['b','y'])
metrics

100%|██████████| 3/3 [00:00<00:00, 568.59it/s]


Unnamed: 0,PCC,COS,SA
0,0.604238,0.873292,0.676038
1,0.550038,0.986744,0.896227
2,0.487563,0.993573,0.927782
3,0.421679,0.994451,0.932903
4,0.356834,0.994247,0.931679
5,0.296188,0.993726,0.928649
6,0.241453,0.993088,0.925106
7,0.193191,0.992411,0.921517
8,0.151215,0.991736,0.918101
9,0.114926,0.991095,0.914979


In [15]:
# Summary statistics per metric, extended with the fraction of precursors above 0.9 and above 0.0.
metrics_describ = metrics.describe()
# add_cutoff_metric appends its row in place, so the first call's return value can be ignored;
# the second call's return value is what the cell displays.
add_cutoff_metric(metrics_describ, metrics, thres=0.9)
add_cutoff_metric(metrics_describ, metrics, thres=0.0)

Unnamed: 0,PCC,COS,SA
count,10.0,10.0,10.0
mean,0.341732,0.980436,0.897298
std,0.170659,0.037714,0.078474
min,0.114926,0.873292,0.676038
25%,0.205257,0.991256,0.91576
50%,0.326511,0.992749,0.923312
75%,0.471092,0.993688,0.928432
max,0.604238,0.994451,0.932903
>0.90,0.0,0.9,0.8
>0.00,1.0,1.0,1.0
