In [1]:
#default_exp model.CCS

In [2]:
#hide
%reload_ext autoreload
%autoreload 2

In [3]:
#export
import torch
import pandas as pd
import numpy as np

from tqdm import tqdm

from alphadeep.model.featurize import \
    parse_aa_indices, \
    get_batch_mod_feature

from alphadeep._settings import model_const

import alphadeep.model.base as model_base

from alphadeep.model.RT import evaluate_linear_regression, evaluate_linear_regression_plot


# Width of the per-residue modification feature vector: equals the number
# of entries in model_const['mod_elements'].
mod_feature_size = len(
    model_const['mod_elements']
)
# Multiplier applied to raw precursor charges before they enter the model.
charge_factor = model_const['charge_factor']

In [4]:
#export
class EncDecModelCCS(torch.nn.Module):
    """Sequence encoder + linear decoder regressing one CCS value per peptide.

    Every residue is represented by a one-hot amino-acid vector, its
    modification features and the (scaled) precursor charge. The encoded
    sequence vector is concatenated with the charge once more and decoded
    into a single scalar.

    Args:
        mod_feature_size: width of the per-residue modification features
            (last dimension of ``mod_x`` in ``forward``).
        dropout: dropout probability applied to the encoder output.
        hidden: width of the vector produced by the sequence encoder
            (default 256, the previously hard-coded value).
        aa_embedding_size: number of classes of the amino-acid one-hot
            encoding; every value in ``aa_indices`` must be smaller than it
            (default 27, the previously hard-coded value).
    """
    def __init__(self,
        mod_feature_size,
        dropout=0.2,
        hidden=256,
        aa_embedding_size=27,
    ):
        super().__init__()
        self.aa_embedding_size = aa_embedding_size

        self.dropout = torch.nn.Dropout(dropout)

        # Per-residue input = one-hot AA + modification features + 1 charge channel.
        self.encoder = model_base.SeqEncoder(
            self.aa_embedding_size+mod_feature_size+1,
            hidden,
            dropout=0,
            rnn_layer=2
        )

        # Decoder input = encoded sequence vector + 1 charge channel.
        self.decoder = model_base.LinearDecoder(
            hidden+1,
            1
        )

    def forward(self,
        aa_indices,
        mod_x,
        charges,
    ):
        """Predict one value per peptide in the batch.

        Args:
            aa_indices: integer tensor of shape (batch, nAA) with amino-acid
                class indices.
            mod_x: modification features of shape (batch, nAA, mod_feature_size).
            charges: tensor of shape (batch, 1) with the (already scaled)
                precursor charge; integer tensors are accepted and cast.

        Returns:
            Tensor of shape (batch,).
        """
        aa_x = torch.nn.functional.one_hot(
            aa_indices, self.aa_embedding_size
        ).float()
        # Cast explicitly so integer charge tensors do not depend on implicit
        # dtype promotion inside torch.cat.
        charges = charges.to(aa_x.dtype)
        # Broadcast the per-peptide charge to every residue position. expand()
        # returns a view (repeat() would allocate a copy that cat copies again).
        charge_x = charges.unsqueeze(1).expand(-1, aa_x.size(1), -1)

        x = torch.cat((aa_x, mod_x, charge_x), 2)
        x = self.encoder(x)
        x = self.dropout(x)
        # Re-inject the charge next to the pooled sequence representation.
        x = torch.cat((x, charges), 1)

        return self.decoder(x).squeeze(1)

In [5]:
#export

class AlphaCCSModel(model_base.ModelImplBase):
    """CCS predictor wrapping EncDecModelCCS in the ModelImplBase machinery.

    The hooks below turn a precursor DataFrame into model tensors and write
    predictions back into a 'predict_CCS' column. Columns read directly here
    are 'sequence', 'charge' and 'CCS'; ``get_batch_mod_feature`` reads
    further columns (not visible in this module).

    Args:
        dropout: dropout probability forwarded to EncDecModelCCS.
        lr: learning rate forwarded to ``self.build``.
    """
    def __init__(self, dropout=0.2, lr=0.001):
        super().__init__()
        self.build(
            EncDecModelCCS, lr=lr,
            dropout=dropout,
            mod_feature_size=mod_feature_size
        )
        self.loss_func = torch.nn.L1Loss()
        # Raw precursor charges are multiplied by this before entering the model.
        self.charge_factor = charge_factor

    def _prepare_predict_data_df(self,
        precursor_df:pd.DataFrame,
    ):
        """Add a 'predict_CCS' column to ``precursor_df`` (in place) and keep
        a reference to the frame as ``self.predict_df`` for batch writes.
        """
        # Initialise with a float: per-batch float predictions are later
        # written via .loc, and writing floats into an int64 column is an
        # incompatible-dtype setitem (warned about since pandas 2.1).
        precursor_df['predict_CCS'] = 0.0
        self.predict_df = precursor_df

    def _get_features_from_batch_df(self,
        batch_df: pd.DataFrame,
        nAA
    ):
        """Build the model inputs for one batch.

        Args:
            batch_df: rows of one batch.
            nAA: peptide length of this batch, forwarded to
                ``get_batch_mod_feature``.

        Returns:
            (aa_indices, mod_x, charges) where ``charges`` is a float tensor
            of shape (batch, 1) already multiplied by ``self.charge_factor``.
        """
        aa_indices = torch.LongTensor(
            parse_aa_indices(
                batch_df['sequence'].values.astype('U')
            )
        )

        mod_x_batch = get_batch_mod_feature(batch_df, nAA)
        mod_x = torch.Tensor(mod_x_batch)

        charges = torch.Tensor(
            batch_df['charge'].values
        ).unsqueeze(1)*self.charge_factor

        return aa_indices, mod_x, charges

    def _get_targets_from_batch_df(self,
        batch_df: pd.DataFrame,
        nAA
    ) -> torch.Tensor:
        """Return the 'CCS' column of the batch as a float tensor (nAA unused)."""
        return torch.Tensor(batch_df['CCS'].values)

    def _set_batch_predict_data(self,
        batch_df: pd.DataFrame,
        predicts,
    ):
        """Store ``predicts`` in ``self.predict_df['predict_CCS']`` for the
        rows whose index labels appear in ``batch_df``.
        """
        self.predict_df.loc[batch_df.index,'predict_CCS'] = predicts

In [6]:
# Smoke-test one forward pass on a dummy 6-residue peptide.
torch.manual_seed(0)  # reproducible weight initialisation and dropout masks
model = AlphaCCSModel()
model.device = torch.device('cpu')
model.model.to(model.device)
model.model(
    torch.LongTensor([[1,2,3,4,5,6]]),
    # (batch=1, nAA=6, mod_feature_size) all-zero modification features,
    # sized from the module constant instead of a hard-coded 8.
    torch.zeros(1, 6, mod_feature_size),
    # Charge 1 scaled by charge_factor, matching how training batches are
    # built in AlphaCCSModel._get_features_from_batch_df.
    torch.tensor([[1.0]])*charge_factor,
)

tensor([-0.1450], grad_fn=<SqueezeBackward1>)

In [7]:
# Display the module tree to check layer widths (the encoder input width is
# aa one-hot size + mod_feature_size + 1 charge channel).
model.model

EncDecModelCCS(
  (dropout): Dropout(p=0.2, inplace=False)
  (encoder): SeqEncoder(
    (dropout): Dropout(p=0, inplace=False)
    (input_cnn): SeqCNN(
      (cnn_short): Conv1d(36, 36, kernel_size=(3,), stride=(1,), padding=(1,))
      (cnn_medium): Conv1d(36, 36, kernel_size=(5,), stride=(1,), padding=(2,))
      (cnn_long): Conv1d(36, 36, kernel_size=(7,), stride=(1,), padding=(3,))
    )
    (hidden_nn): SeqLSTM(
      (rnn): LSTM(144, 128, num_layers=2, batch_first=True, bidirectional=True)
    )
    (attn_sum): SeqAttentionSum(
      (attn): Sequential(
        (0): Linear(in_features=256, out_features=1, bias=False)
        (1): Softmax(dim=1)
      )
    )
  )
  (decoder): LinearDecoder(
    (nn): Sequential(
      (0): Linear(in_features=257, out_features=64, bias=True)
      (1): PReLU(num_parameters=1)
      (2): Linear(in_features=64, out_features=1, bias=True)
    )
  )
)

In [13]:
# Toy data set: the same modified peptide repeated `repeat` times, each row
# labelled with a dummy CCS target of 1.
repeat = 10
precursor_df = pd.DataFrame(
    {
        'sequence': 'AGHCEWQMKYR',
        'mods': 'Acetyl@Protein N-term;Carbamidomethyl@C;Oxidation@M',
        'mod_sites': '0;4;8',
        'nAA': 11,
        'charge': 2,
        'CCS': 1,
    },
    index=range(repeat),  # scalar values are broadcast to every row
)
precursor_df

Unnamed: 0,sequence,mods,mod_sites,nAA,charge,CCS
0,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1
1,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1
2,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1
3,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1
4,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1
5,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1
6,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1
7,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1
8,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1
9,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1


In [14]:
# Train on the toy frame for 5 epochs. train() is inherited from
# model_base.ModelImplBase (not shown here); presumably it minimises
# AlphaCCSModel.loss_func (an L1Loss) against the 'CCS' column - TODO confirm.
model.train(precursor_df, epoch=5)

Epoch=1, nAA=11, Batch=1, Loss=0.6656: 100%|██████████| 1/1 [00:00<00:00, 27.90it/s]
Epoch=2, nAA=11, Batch=1, Loss=0.4224: 100%|██████████| 1/1 [00:00<00:00, 29.31it/s]
Epoch=3, nAA=11, Batch=1, Loss=0.0382: 100%|██████████| 1/1 [00:00<00:00, 31.34it/s]
Epoch=4, nAA=11, Batch=1, Loss=0.4837: 100%|██████████| 1/1 [00:00<00:00, 28.95it/s]
Epoch=5, nAA=11, Batch=1, Loss=0.5005: 100%|██████████| 1/1 [00:00<00:00, 31.77it/s]


In [15]:
# Predict CCS for every row. NOTE: AlphaCCSModel._prepare_predict_data_df
# writes a 'predict_CCS' column into the frame it receives, so precursor_df
# itself gains that column (assuming ModelImplBase.predict passes the frame
# through unchanged - predict() is not defined in this notebook).
model.predict(precursor_df)

Unnamed: 0,sequence,mods,mod_sites,nAA,charge,CCS,predict_CCS
0,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,1.280028
1,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,1.280028
2,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,1.280028
3,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,1.280028
4,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,1.280028
5,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,1.280028
6,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,1.280028
7,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,1.280028
8,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,1.280028
9,AGHCEWQMKYR,Acetyl@Protein N-term;Carbamidomethyl@C;Oxidat...,0;4;8,11,2,1,1.280028
