In [36]:
#default_exp model.ccs

In [37]:
#hide
%reload_ext autoreload
%autoreload 2

In [38]:
#export
import torch
import pandas as pd
import numpy as np

from tqdm import tqdm

from alphabase.peptide.fragment import update_precursor_mz

from alphabase.peptide.mobility import (
    ccs_to_mobility_for_df,
    mobility_to_ccs_for_df
)

from alphadeep.model.featurize import (
    parse_aa_indices, 
    get_batch_mod_feature
)

from alphadeep.settings import model_const

import alphadeep.model.base as model_base

from alphadeep.model.rt import (
    evaluate_linear_regression, 
    evaluate_linear_regression_plot
)

# Multiplier applied to raw precursor charges before they are fed to the models
# (see AlphaCCSModel._get_features_from_batch_df); read from alphadeep settings.
charge_factor = model_const['charge_factor']

In [39]:
#export
class ModelCCSTransformer(torch.nn.Module):
    """Transformer-based CCS regressor.

    Amino-acid indices plus modification features are encoded into
    ``hidden - 2`` channels. The (already scaled) precursor charge is tiled
    into two extra channels so the transformer body sees exactly ``hidden``
    channels. The same charge channels are appended again after the
    transformer, and a ``SeqAttentionSum`` head (presumably pooling over the
    sequence axis) followed by a linear layer yields one value per precursor.

    Unrecognised keyword arguments are accepted and ignored.
    """
    def __init__(self,
        dropout = 0.1,
        nlayers = 4,
        hidden = 128,
        **kwargs,
    ):
        super().__init__()

        self.dropout = torch.nn.Dropout(dropout)

        # Two channels are reserved for the tiled charge feature in forward().
        self.input_nn = model_base.AATransformerEncoding(hidden-2)

        self.hidden_nn = model_base.HiddenTransformer(
            hidden, nlayers=nlayers, dropout=dropout
        )

        # Transformer output (hidden) + the two charge channels again.
        head_width = hidden + 2
        self.output_nn = torch.nn.Sequential(
            model_base.SeqAttentionSum(head_width),
            torch.nn.PReLU(),
            self.dropout,
            torch.nn.Linear(head_width, 1),
        )

    def forward(self, 
        aa_indices, 
        mod_x,
        charges:torch.Tensor,
    ):
        """Predict one CCS value per precursor.

        Assumes ``charges`` is a (batch, 1) tensor, as produced by
        AlphaCCSModel._get_features_from_batch_df. Returns a 1-D tensor with
        one prediction per batch row.
        """
        encoded = self.dropout(self.input_nn(aa_indices, mod_x))
        seq_len = encoded.size(1)

        # (batch, 1) -> (batch, 1, 1) -> (batch, seq_len, 2): the charge is
        # copied to every position and into two channels.
        charge_feat = charges.unsqueeze(1).repeat(1, seq_len, 2)

        body_in = torch.cat((encoded, charge_feat), 2)
        body_out = self.hidden_nn(body_in)
        head_in = torch.cat((body_out, charge_feat), 2)

        return self.output_nn(head_in).squeeze(1)

In [40]:
#export
class ModelCCSLSTM(torch.nn.Module):
    """CCS regressor: CNN+LSTM encoder (charge-aware) followed by a linear decoder.

    Extra positional and keyword arguments are accepted and ignored, so this
    class can be swapped in for ModelCCSTransformer via
    ``AlphaCCSModel(model_class=...)`` with the same keyword configuration.
    """
    def __init__(self,
        dropout=0.1,
        *args,
        **kwargs,
    ):
        # Previously declared as `*kwargs` (positional varargs), so any extra
        # *keyword* argument (e.g. nlayers=..., hidden=...) raised TypeError.
        # `*args` keeps accepting stray positionals for backward compatibility.
        super().__init__()
        
        self.dropout = torch.nn.Dropout(dropout)
        
        # Fixed encoder width; a `hidden=` keyword is deliberately ignored.
        hidden = 256

        self.ccs_encoder = (
            model_base.Input_AA_CNN_LSTM_cat_Charge_Encoder(
                hidden
            )
        )

        # The decoder input is hidden+1 because forward() appends the
        # (batch, 1) charge column to the encoder output. This presumes the
        # encoder returns `hidden` features per precursor -- TODO confirm.
        self.ccs_decoder = model_base.LinearDecoder(
            hidden+1, 1
        )

    def forward(self, 
        aa_indices, 
        mod_x,
        charges,
    ):
        """Predict one CCS value per precursor.

        Assumes ``charges`` is a (batch, 1) tensor, as built by
        AlphaCCSModel._get_features_from_batch_df. Returns a 1-D tensor with
        one prediction per batch row.
        """
        x = self.ccs_encoder(aa_indices, mod_x, charges)
        x = self.dropout(x)
        # Re-inject the charge next to the encoded features before decoding.
        x = torch.cat((x, charges),1)
        return self.ccs_decoder(x).squeeze(1)

In [41]:
#export

def ccs_to_mobility_pred_df(
    precursor_df: pd.DataFrame,
) -> pd.DataFrame:
    """Convert predicted CCS into predicted ion mobility, in place.

    The 'ccs_pred' column is passed to ``ccs_to_mobility_for_df`` and its
    result is stored as a 'mobility_pred' column on the same DataFrame.

    Returns
    -------
    pd.DataFrame
        The very same ``precursor_df`` object, to allow chaining.
    """
    mobility = ccs_to_mobility_for_df(precursor_df, 'ccs_pred')
    precursor_df['mobility_pred'] = mobility
    return precursor_df

def mobility_to_ccs_df_(
    precursor_df: pd.DataFrame,
) -> pd.DataFrame:
    """Derive a 'ccs' column from the measured 'mobility' column, in place.

    The conversion itself is delegated to ``mobility_to_ccs_for_df``.

    Returns
    -------
    pd.DataFrame
        The very same ``precursor_df`` object, to allow chaining.
    """
    ccs_values = mobility_to_ccs_for_df(precursor_df, 'mobility')
    precursor_df['ccs'] = ccs_values
    return precursor_df

In [42]:
#export

class AlphaCCSModel(model_base.ModelImplBase):
    """Train/predict wrapper around a CCS regression network.

    Per batch, the features are amino-acid indices, modification features
    and the precursor charge multiplied by ``charge_factor``. The training
    target is the 'ccs' column and the loss is L1. Predictions are written to
    a 'ccs_pred' column of the DataFrame being predicted, which is modified
    in place.
    """
    def __init__(self, 
        dropout=0.1, lr=0.001,
        model_class:type=ModelCCSLSTM,
        **kwargs,
    ):
        """
        Parameters
        ----------
        dropout : float
            Forwarded to ``self.build`` (presumably on to ``model_class``).
        lr : float
            NOTE(review): accepted but never read in this class; confirm
            whether ModelImplBase needs it or whether it can be dropped.
        model_class : type
            The torch.nn.Module subclass (not an instance) handed to
            ``self.build``.
        **kwargs
            Additional keyword arguments forwarded to ``self.build``.
        """
        super().__init__()
        self.build(
            model_class,
            dropout=dropout, 
            **kwargs
        )
        self.loss_func = torch.nn.L1Loss()
        self.charge_factor = charge_factor

    def _prepare_predict_data_df(self,
        precursor_df:pd.DataFrame,
    ):
        """Add a zero-initialised 'ccs_pred' column to ``precursor_df`` (in
        place) and keep a reference to it as ``self.predict_df``."""
        precursor_df['ccs_pred'] = 0.
        self.predict_df = precursor_df

    def _get_features_from_batch_df(self, 
        batch_df: pd.DataFrame, 
        nAA, **kwargs,
    ):
        """Build the model inputs for one batch.

        Returns
        -------
        tuple
            ``(aa_indices, mod_x, charges)``. ``aa_indices`` is a LongTensor
            from parse_aa_indices, ``mod_x`` is a float tensor from
            get_batch_mod_feature, and ``charges`` is a float (batch, 1)
            tensor already multiplied by ``self.charge_factor``.
        """
        aa_indices = torch.LongTensor(
            parse_aa_indices(
                batch_df['sequence'].values.astype('U')
            )
        )

        mod_x_batch = get_batch_mod_feature(batch_df, nAA)
        mod_x = torch.Tensor(mod_x_batch)

        charges = torch.Tensor(
            batch_df['charge'].values
        ).unsqueeze(1)*self.charge_factor

        return aa_indices, mod_x, charges

    def _get_targets_from_batch_df(self, 
        batch_df: pd.DataFrame, 
        **kwargs,
    ) -> torch.Tensor:
        """Return the batch's 'ccs' column as a float32 tensor."""
        return torch.Tensor(batch_df['ccs'].values)

    def _set_batch_predict_data(self, 
        batch_df: pd.DataFrame, 
        predicts,
    ):
        """Store one batch of predictions into ``self.predict_df['ccs_pred']``.

        Negative predictions are clipped to 0 (``predicts`` is modified in
        place).
        """
        predicts[predicts<0] = 0.0
        if self._predict_in_order:
            # Fast path for batches of consecutive rows. Index labels are
            # used as row positions, which presumes predict_df has a default
            # 0..n-1 index -- TODO confirm the base class guarantees this.
            # Writing through .iloc instead of
            # `.loc[:, 'ccs_pred'].values[...] = ...` keeps this working
            # under pandas Copy-on-Write, where `.values` is a read-only view
            # and in-place writes to it fail.
            start = batch_df.index.values[0]
            stop = batch_df.index.values[-1]+1
            self.predict_df.iloc[
                start:stop, self.predict_df.columns.get_loc('ccs_pred')
            ] = predicts
        else:
            self.predict_df.loc[
                batch_df.index,'ccs_pred'
            ] = predicts

    def ccs_to_mobility_pred(self,
        precursor_df:pd.DataFrame
    )->pd.DataFrame:
        """Add 'mobility_pred' (derived from 'ccs_pred') to ``precursor_df``
        in place and return it."""
        return ccs_to_mobility_pred_df(precursor_df)

In [43]:
#hide
from alphadeep.model.base import aa_one_hot  # NOTE(review): unused in the visible cells
model = AlphaCCSModel()
model.device = torch.device('cpu')
model.model.to(model.device)
mod_hidden = len(model_const['mod_elements'])
# Smoke test: one peptide of six arbitrary amino-acid indices, no
# modifications, charge 1. Charges are passed as a float (batch, 1) tensor
# scaled by charge_factor, mirroring AlphaCCSModel._get_features_from_batch_df
# (the previous integer tensor did not match what the model sees in practice).
model.model(
    torch.LongTensor([[1,2,3,4,5,6]]), 
    torch.tensor([[[0.0]*mod_hidden]*6]), 
    torch.tensor([[1.0]])*model.charge_factor
)

tensor([-0.2777], grad_fn=<SqueezeBackward1>)

In [44]:
# Parameter count of the wrapped torch model (helper inherited from the base class).
model.get_parameter_num()

796742

In [45]:
# Display the module tree of the underlying torch model.
# NOTE(review): the saved output lists ModelCCSTransformer although
# AlphaCCSModel() defaults to model_class=ModelCCSLSTM, so the saved output
# (and presumably the parameter count above) looks stale -- re-run the notebook.
model.model

ModelCCSTransformer(
  (dropout): Dropout(p=0.1, inplace=False)
  (input_nn): AATransformerEncoding(
    (mod_nn): InputModNetFixFirstK(
      (nn): Linear(in_features=103, out_features=2, bias=False)
    )
    (aa_emb): Embedding(27, 118, padding_idx=0)
    (pos_encoder): PositionalEncoding()
  )
  (hidden_nn): HiddenTransformer(
    (transormer): SeqTransformer(
      (transformer_encoder): TransformerEncoder(
        (layers): ModuleList(
          (0): TransformerEncoderLayer(
            (self_attn): MultiheadAttention(
              (out_proj): _LinearWithBias(in_features=128, out_features=128, bias=True)
            )
            (linear1): Linear(in_features=128, out_features=512, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
            (linear2): Linear(in_features=512, out_features=128, bias=True)
            (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
            (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
            

In [46]:
repeat = 10
# One toy precursor replicated `repeat` times. In mod_sites, 0 marks the
# N-terminus and other sites are 1-based residue positions (C is residue 4,
# M is residue 8). 'ccs' is a dummy training target, not a real measurement.
toy_row = {
    'sequence': 'AGHCEWQMKYR',
    'mods': 'Acetyl@Protein N-term;Carbamidomethyl@C;Oxidation@M',
    'mod_sites': '0;4;8',
    'nAA': 11,
    'charge': 2,
    'ccs': 1,
}
precursor_df = pd.DataFrame([toy_row] * repeat)
precursor_df

Unnamed: 0,sequence,mods,mod_sites,nAA,charge,ccs
0,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1
1,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1
2,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1
3,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1
4,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1
5,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1
6,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1
7,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1
8,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1
9,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1


In [47]:
# Fit on the toy frame for 5 epochs. Smoke test only: the same rows are
# predicted on afterwards, so the result says nothing about generalization.
model.train(precursor_df, epoch=5)

In [48]:
# Predict CCS; a 'ccs_pred' column is added to precursor_df in place and the frame is displayed.
model.predict(precursor_df)

Unnamed: 0,sequence,mods,mod_sites,nAA,charge,ccs,ccs_pred
0,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,1.445119
1,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,1.445119
2,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,1.445119
3,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,1.445119
4,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,1.445119
5,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,1.445119
6,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,1.445119
7,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,1.445119
8,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,1.445119
9,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,1.445119


In [49]:
# Convert predicted CCS to ion mobility; 'mobility_pred' is added to
# precursor_df in place. The saved output also gains 'precursor_mz',
# presumably added by the mobility helper. Values are not physically
# meaningful here because the toy 'ccs' target is 1.
model.ccs_to_mobility_pred(precursor_df)

Unnamed: 0,sequence,mods,mod_sites,nAA,charge,ccs,ccs_pred,precursor_mz,mobility_pred
0,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,1.445119,762.329553,0.003576
1,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,1.445119,762.329553,0.003576
2,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,1.445119,762.329553,0.003576
3,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,1.445119,762.329553,0.003576
4,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,1.445119,762.329553,0.003576
5,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,1.445119,762.329553,0.003576
6,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,1.445119,762.329553,0.003576
7,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,1.445119,762.329553,0.003576
8,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,1.445119,762.329553,0.003576
9,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,1.445119,762.329553,0.003576
