In [1]:
#default_exp model.msms

In [2]:
#hide
%reload_ext autoreload
%autoreload 2

In [3]:
#export
import torch
import pandas as pd
import numpy as np

from tqdm import tqdm

from alphabase.peptide.fragment import \
    init_fragment_by_precursor_dataframe, \
    set_sliced_fragment_dataframe, \
    get_sliced_fragment_dataframe, \
    get_charged_frag_types

from alphadeep.model.featurize import \
    parse_aa_indices, parse_instrument_indices, \
    get_batch_mod_feature

from alphadeep._settings import \
    global_settings as settings, \
    const_settings

import alphadeep.model.base as model_base

class ModelMSMSpDeep3(torch.nn.Module):
    def __init__(self, 
        mod_feature_size,
        num_ion_types,
        max_instrument_num,
        dropout=0.2
    ):
        super().__init__()
        BiRNN = True
        self.aa_embedding_size = 27
        hidden=256
        ins_nce_embed_size = 3
        
        self.max_instrument_num = max_instrument_num
        self.instrument_nce_embed = torch.nn.Linear(max_instrument_num+1,ins_nce_embed_size)
        # ins_nce_embed_size = conf.max_instrument_num+1
        # self.instrument_nce_embed = torch.nn.Identity()
        
        output_hidden_size = hidden*(2 if BiRNN else 1) + ins_nce_embed_size + 1
        
        # mod_embed_size = 8
        # self.mod_embed_weights = torch.nn.Parameter(
            # torch.empty(mod_size, mod_embed_size), 
            # requires_grad = True
        # )
        self.dropout = torch.nn.Dropout(dropout)
        
        self.input = model_base.SeqLSTM(
            self.aa_embedding_size+mod_feature_size+ins_nce_embed_size+1, 
            hidden,
            rnn_layer=1, bidirectional=BiRNN
        )
        
        self.hidden = model_base.SeqLSTM(
            output_hidden_size,
            hidden,
            rnn_layer=1, bidirectional=BiRNN
        )
        
        self.output = model_base.SeqLSTM(
            output_hidden_size,
            num_ion_types,
            rnn_layer=1, bidirectional=False
        )

    def forward(self, 
        aa_indices, 
        mod_x, 
        charges:torch.Tensor,
        NCEs:torch.Tensor, 
        instrument_indices,
    ):
        aa_x = torch.nn.functional.one_hot(aa_indices, self.aa_embedding_size)
        inst_x = torch.nn.functional.one_hot(instrument_indices, self.max_instrument_num)
        
        ins_nce = torch.cat((inst_x, NCEs), 1)
        ins_nce = self.instrument_nce_embed(ins_nce)
        ins_nce_charge = torch.cat((ins_nce, charges), 1)
        ins_nce_charge = ins_nce_charge.unsqueeze(1).repeat(1, aa_x.size(1), 1)

        x = torch.cat((aa_x, mod_x, ins_nce_charge), 2)
        x = self.input(x)
        x = self.dropout(x)
        
        x = torch.cat((x, ins_nce_charge), 2)
        x = self.hidden(x)
        x = self.dropout(x)

        x = torch.cat((x, ins_nce_charge), 2)

        return self.output(x[:,1:-2,:])


In [4]:
class ModelMSMSTest(torch.nn.Module):
    def __init__(self, 
        mod_feature_size,
        num_ion_types,
        max_instrument_num,
        dropout=0.2
    ):
        super().__init__()
        embedding_hidden = 27
        instrument_hidden = 8
        
       # self.aa_embedding = model_base.aa_embedding(embedding_hidden)
        
        self.instrument_embedding = torch.nn.Embedding(
            max_instrument_num, 
            instrument_hidden
        )
        hidden = 256
        self.input = model_base.SeqLSTM(
            embedding_hidden+mod_feature_size, hidden,
            rnn_layer=1, bidirectional=False
        )
        self.hidden = model_base.SeqLSTM(hidden, hidden)

        self.output = model_base.SeqLSTM(
            # +2 for charge and NCE
            hidden*2+instrument_hidden+2, 
            num_ion_types, rnn_layer=1, bidirectional=False
        )

    def forward(self, 
        aa_indices, 
        mod_x, 
        charges:torch.Tensor,
        NCEs:torch.Tensor, 
        instrument_indices,
    ):
        #aa_x = self.aa_embedding(aa_indices)
        aa_x = torch.nn.functional.one_hot(aa_indices, 27)
        inst_x = self.instrument_embedding(instrument_indices)

        x = self.input(torch.cat((aa_x, mod_x), dim=2))
        x = self.hidden(x)

        x = torch.cat((
            x, 
            charges.unsqueeze(1).repeat(1,aa_x.size(1),1), 
            NCEs.unsqueeze(1).repeat(1,aa_x.size(1),1), 
            inst_x.unsqueeze(1).repeat(1,aa_x.size(1),1),
        ), dim=2)

        return self.output(x[:,1:-2,:])

In [5]:
import numpy as np
sequence = 'ACDEFGIK'
max_instrument_num = 4
num_ion_types=2
mod_feature_size=2
aa_indices = torch.LongTensor(parse_aa_indices([sequence]))
aa_indices = torch.nn.functional.pad(aa_indices, (1,1))
charges = torch.tensor([[2]])
instrument_indices = torch.LongTensor([1])
NCEs = torch.tensor([[0.3]])
mod_x = torch.zeros((1, aa_indices.size(1), mod_feature_size))

In [6]:
msms_model = ModelMSMSpDeep3(mod_feature_size, num_ion_types, max_instrument_num)
msms_model

ModelMSMSpDeep3(
  (instrument_nce_embed): Linear(in_features=5, out_features=3, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
  (input): SeqLSTM(
    (rnn): LSTM(33, 256, batch_first=True, bidirectional=True)
  )
  (hidden): SeqLSTM(
    (rnn): LSTM(516, 256, batch_first=True, bidirectional=True)
  )
  (output): SeqLSTM(
    (rnn): LSTM(516, 2, batch_first=True)
  )
)

In [7]:
msms_model.eval()
msms_model(aa_indices, mod_x, charges, NCEs, instrument_indices)

tensor([[[-0.3327, -0.3021],
         [-0.3553, -0.2417],
         [-0.2875, -0.2250],
         [-0.2134, -0.2110],
         [-0.1467, -0.1926],
         [-0.0854, -0.1808],
         [-0.0279, -0.1736],
         [ 0.0349, -0.1684],
         [ 0.1028, -0.1675]]], grad_fn=<TransposeBackward0>)

# pDeepModel for MS/MS prediction

In [8]:
#export
class IntenAwareLoss(torch.nn.Module):
    def __init__(self, base_weight = 0.1):
        super().__init__()
        self.w = base_weight
    
    def forward(self, pred, target):
        x = pred.reshape(-1)
        y = target.reshape(-1)
        #return torch.sum(torch.abs(x-y))/pred.size(0)
        return torch.mean((y+self.w)*torch.abs(x-y))

In [9]:
#export

mod_feature_size = len(const_settings['mod_elements'])
max_instrument_num = const_settings['max_instrument_num']
frag_types = settings['model']['frag_types']
max_frag_charge = settings['model']['max_frag_charge']
num_ion_types = len(frag_types)*max_frag_charge
nce_factor = const_settings['nce_factor']
charge_factor = const_settings['charge_factor']

class pDeepModel(model_base.ModelImplBase):
    def __init__(self,
        dropout=0.2,
        lr=0.001,
    ):
        super().__init__()
        self.charged_frag_types = get_charged_frag_types(
            frag_types, max_frag_charge
        )
        self.charge_factor = charge_factor
        self.NCE_factor = nce_factor
        self.build(
            ModelMSMSpDeep3, 
            mod_feature_size = mod_feature_size,
            num_ion_types = len(self.charged_frag_types),
            max_instrument_num = max_instrument_num,
            dropout=dropout,
            lr=lr
        )
        self.loss_func = IntenAwareLoss()
        # self.loss_func = torch.nn.L1Loss()
        self.min_inten = 1e-4

    def train(self, 
        precursor_df: pd.DataFrame, 
        fragment_inten_df: pd.DataFrame, 
        epoch=10, 
        batch_size=1024,
        verbose=True
    ):
        self.model.train()

        fragment_inten_df = fragment_inten_df[self.charged_frag_types]

        if np.all(precursor_df['NCE'].values > 1):
            precursor_df['NCE'] = precursor_df['NCE']*self.NCE_factor

        for epoch in range(epoch):
            batch_cost = []
            _grouped = list(precursor_df.sample(frac=1).groupby('nAA'))
            rnd_nAA = np.random.permutation(len(_grouped))
            batch_tqdm = tqdm(rnd_nAA)
            for i_group in batch_tqdm:
                nAA, df_group = _grouped[i_group]
                df_group = df_group.reset_index(drop=True)
                for i in range(0, len(df_group), batch_size):
                    batch_end = i+batch_size-1 # DataFrame.loc[start:end] inlcudes the end

                    aa_indices = torch.LongTensor(
                        parse_aa_indices(
                            df_group.loc[i:batch_end, 'sequence'].values.astype('U')
                        )
                    )

                    mod_x_batch = get_batch_mod_feature(df_group.loc[i:batch_end,:], nAA)
                    mod_x = torch.Tensor(mod_x_batch)

                    charges = torch.Tensor(
                        df_group.loc[i:batch_end, 'charge'].values
                    ).unsqueeze(1)*self.charge_factor

                    nces = torch.Tensor(df_group.loc[i:batch_end, 'NCE'].values).unsqueeze(1)

                    instrument_indices = torch.LongTensor(
                        parse_instrument_indices(df_group.loc[i:batch_end, 'instrument'])
                    )
                    intens = torch.Tensor(
                        get_sliced_fragment_dataframe(
                            fragment_inten_df, 
                            df_group.loc[
                                i:batch_end, ['frag_start_idx','frag_end_idx']
                            ].values
                        ).values
                    ).view(-1, nAA-1, len(self.charged_frag_types))
                    
                    cost = self._train_one_batch(
                        intens, 
                        aa_indices, mod_x, charges,
                        nces, instrument_indices
                    )
                    batch_cost.append(cost.item())
                batch_tqdm.set_description(
                    f'Epoch={epoch+1}, nAA={nAA}, Batch={len(batch_cost)}, Loss={cost.item():.4f}'
                )
            if verbose: print(f'[MS/MS training] epoch={epoch+1}, mean Loss={np.mean(batch_cost)}')

    def predict(self, 
        precursor_df: pd.DataFrame,
        batch_size: int=1024
    )->pd.DataFrame:

        self.model.eval()

        predict_inten_df = init_fragment_by_precursor_dataframe(
            precursor_df, self.charged_frag_types
        )

        if np.all(precursor_df['NCE'].values > 1):
            precursor_df['NCE'] = precursor_df['NCE']*self.NCE_factor

        _grouped = precursor_df.groupby('nAA')
        for nAA, df_group in tqdm(_grouped):
            df_group = df_group.reset_index(drop=True)
            for i in range(0, len(df_group), batch_size):
                batch_end = i+batch_size-1 # DataFrame.loc[start:end] inlcudes the end
                
                mod_x_batch = get_batch_mod_feature(df_group.loc[i:batch_end,:], nAA)

                aa_indices = torch.LongTensor(parse_aa_indices(
                    df_group.loc[i:batch_end, 'sequence'].values.astype('U')
                ))
                mod_x = torch.Tensor(mod_x_batch)
                charges = torch.Tensor(
                    df_group.loc[i:batch_end, 'charge'].values
                ).view(-1,1)*self.charge_factor
                
                nces = torch.Tensor(df_group.loc[i:batch_end, 'NCE'].values).view(-1,1)
                instrument_indices = torch.LongTensor(
                    parse_instrument_indices(df_group.loc[i:batch_end, 'instrument'])
                )

                predicts = self.model(
                    *[fea.to(self.device) for fea in 
                    [aa_indices, mod_x, charges, nces, instrument_indices]
                ]).cpu().detach().numpy()
                predicts[predicts>1] = 1
                predicts[predicts<self.min_inten] = 0

                set_sliced_fragment_dataframe(
                    predict_inten_df,
                    predicts.reshape((-1, len(self.charged_frag_types))),
                    df_group.loc[
                        i:batch_end, 
                        ['frag_start_idx','frag_end_idx']
                    ].values,
                    self.charged_frag_types
                )

        return predict_inten_df

In [10]:
#export
from scipy.stats import pearsonr
from scipy.stats import spearmanr
from scipy.spatial.distance import cosine
import typing

def cosine(x1, x2, eps=1e-8):
    _cos = np.dot(x1,x2)/(np.linalg.norm(x1)*np.linalg.norm(x2)+eps)
    if _cos > 1: _cos = 1
    return _cos

def spectral_angle(x1, x2, eps=1e-8):
    return 1 - 2 * np.arccos(cosine(x1,x2,eps)) / np.pi

def pearson(x1, x2):
    ret = pearsonr(x1, x2)[0]
    if np.isnan(ret):
        return 0
    else: return ret

def spearman(x1, x2):
    ret = spearmanr(x1, x2)[0]
    if np.isnan(ret):
        return 0
    else: return ret

def get_metric_fuc(metric):
    metric = metric.upper()
    if metric == 'COS':
        return cosine
    elif metric == 'PCC':
        return pearson
    elif metric == 'SPC':
        return spearman
    elif metric == 'SA':
        return spectral_angle
    else:
        raise NotImplementedError(f'Unknown metric "{metric}".')

def batch_metric(batch1, batch2, frag_start_end_list, metric_func):
    sim_list = []
    for start, end in frag_start_end_list:
        sim_list.append(
            metric_func(
                batch1[start:end].reshape(-1), 
                batch2[start:end].reshape(-1)
            )
        )
    return sim_list

def evaluate_msms(
    precursor_df: pd.DataFrame,
    predict_inten_df: pd.DataFrame,
    fragment_inten_df: pd.DataFrame,
    charged_frag_types: typing.List,
    metrics = ['PCC','COS','SA'], #+['SPC']
)->pd.DataFrame:
    ret = pd.DataFrame(
        np.zeros((len(precursor_df), len(metrics))), 
        columns=metrics
    )

    for metric in tqdm(metrics):
        ret[metric] = batch_metric(
            predict_inten_df[charged_frag_types].values,
            fragment_inten_df[charged_frag_types].values,
            precursor_df[['frag_start_idx','frag_end_idx']].values,
            get_metric_fuc(metric)
        )
    return ret

def add_cutoff_metric(
    metrics_describ, metrics_df, thres=0.9
):
    vals = []
    for col in metrics_describ.columns.values:
        vals.append(metrics_df.loc[metrics_df[col]>thres, col].count()/len(metrics_df))
    metrics_describ.loc[f'>{thres:.2f}'] = vals
    return metrics_describ

### Examples

In [11]:
fragment_inten_df = pd.DataFrame({'b':np.arange(100)/100, 'y':np.arange(100)/100})
repeat = 10
precursor_df = pd.DataFrame({
    'sequence': ['AGHCEWQMKYR']*repeat,
    'mods': ['Acetyl@Protein N-term;Carbamidomethyl@C;Oxidation@M']*repeat,
    'mod_sites': ['0;4;8']*repeat,
    'nAA': [11]*repeat,
    'NCE': [20]*repeat,
    'instrument': 'QE',
    'charge': np.arange(1,repeat+1),
    'frag_start_idx':np.arange(10, dtype=int)*10,
    'frag_end_idx':np.arange(10, dtype=int)*10+10,
})
precursor_df

Unnamed: 0,sequence,mods,mod_sites,nAA,NCE,instrument,charge,frag_start_idx,frag_end_idx
0,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,20,QE,1,0,10
1,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,20,QE,2,10,20
2,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,20,QE,3,20,30
3,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,20,QE,4,30,40
4,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,20,QE,5,40,50
5,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,20,QE,6,50,60
6,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,20,QE,7,60,70
7,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,20,QE,8,70,80
8,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,20,QE,9,80,90
9,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,20,QE,10,90,100


In [12]:
fragment_inten_df

Unnamed: 0,b,y
0,0.00,0.00
1,0.01,0.01
2,0.02,0.02
3,0.03,0.03
4,0.04,0.04
...,...,...
95,0.95,0.95
96,0.96,0.96
97,0.97,0.97
98,0.98,0.98


### Train an MSMS model

In [13]:
pdeep = pDeepModel()
pdeep.charged_frag_types=['b','y']
pdeep.build(
    ModelMSMSpDeep3, 
    mod_feature_size = mod_feature_size,
    num_ion_types = len(pdeep.charged_frag_types),
    max_instrument_num = max_instrument_num,
    dropout=0.2,
    lr=1e-3
)
pdeep.train(precursor_df, fragment_inten_df, epoch=5, batch_size=3)

Epoch=1, nAA=11, Batch=4, Loss=0.3291: 100%|██████████| 1/1 [00:00<00:00,  5.48it/s]
Epoch=2, nAA=11, Batch=4, Loss=0.1156: 100%|██████████| 1/1 [00:00<00:00,  6.45it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[MS/MS training] epoch=1, mean Loss=0.3151082172989845
[MS/MS training] epoch=2, mean Loss=0.23548224568367004


Epoch=3, nAA=11, Batch=4, Loss=0.1747: 100%|██████████| 1/1 [00:00<00:00,  5.99it/s]
Epoch=4, nAA=11, Batch=4, Loss=0.0923: 100%|██████████| 1/1 [00:00<00:00,  6.25it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[MS/MS training] epoch=3, mean Loss=0.2127169594168663
[MS/MS training] epoch=4, mean Loss=0.10877891257405281


Epoch=5, nAA=11, Batch=4, Loss=0.0772: 100%|██████████| 1/1 [00:00<00:00,  6.27it/s]

[MS/MS training] epoch=5, mean Loss=0.0707893492653966





### Predict from the model

In [14]:
predict_inten_df = pdeep.predict(precursor_df)
predict_inten_df

100%|██████████| 1/1 [00:00<00:00, 37.91it/s]


Unnamed: 0,b,y
0,0.034328,0.054884
1,0.051684,0.038291
2,0.040419,0.032964
3,0.026396,0.022558
4,0.028102,0.025052
...,...,...
95,0.944994,0.954228
96,0.943682,0.954656
97,0.941017,0.953547
98,0.930580,0.945205


### Metrics

In [15]:
metrics = evaluate_msms(precursor_df, predict_inten_df, fragment_inten_df, ['b','y'])
metrics

100%|██████████| 4/4 [00:00<00:00, 228.04it/s]


Unnamed: 0,PCC,COS,SPC,SA
0,0.597007,0.90109,0.48906,0.714463
1,0.630492,0.969698,0.546419,0.84288
2,0.684962,0.984044,0.609816,0.886123
3,0.756425,0.992494,0.691326,0.921951
4,0.798101,0.995868,0.830195,0.942107
5,0.749426,0.996099,0.881516,0.943748
6,0.676514,0.995651,0.70642,0.940603
7,0.636446,0.995826,0.564532,0.941816
8,0.613832,0.9964,0.507173,0.945965
9,0.59477,0.996982,0.443777,0.950528


In [16]:
metrics_describ = metrics.describe()
add_cutoff_metric(metrics_describ, metrics, thres=0.9)
add_cutoff_metric(metrics_describ, metrics, thres=0.0)

Unnamed: 0,PCC,COS,SPC,SA
count,10.0,10.0,10.0,10.0
mean,0.673797,0.982415,0.627023,0.903018
std,0.072372,0.029834,0.146878,0.07455
min,0.59477,0.90109,0.443777,0.714463
25%,0.617997,0.986156,0.516985,0.89508
50%,0.65648,0.995739,0.587174,0.94121
75%,0.73331,0.996041,0.702647,0.943338
max,0.798101,0.996982,0.881516,0.950528
>0.90,0.0,1.0,0.0,0.7
>0.00,1.0,1.0,1.0,1.0
