In [1]:
# Load (or reload) IPython's autoreload extension; mode 2 re-imports edited
# modules before each cell runs, so library changes are picked up without a
# kernel restart.
%reload_ext autoreload
%autoreload 2

# Model for multi-target values

In [2]:
from peptdeep.model.model_shop import (
    Model_for_Generic_AASeq_BinaryClassification_LSTM,
    Model_for_Generic_AASeq_BinaryClassification_Transformer,
    ModelInterface_for_Generic_AASeq_BinaryClassification,
    Model_for_Generic_AASeq_Regression_LSTM,
    Model_for_Generic_AASeq_Regression_Transformer,
    ModelInterface_for_Generic_AASeq_Regression,
)
import torch
import numpy as np

class ModelInterface_MultiTarget(ModelInterface_for_Generic_AASeq_BinaryClassification):
    """Model interface that trains on and predicts several target values per row.

    Each row of the training DataFrame is expected to hold a fixed-length
    sequence of target values in ``self.target_column_to_train``; predictions
    are written, one vector per row, to ``self.target_column_to_predict``.

    Parameters
    ----------
    num_target_values : int, optional
        Number of target values per row. It is forwarded as ``output_dim`` to
        the underlying model and used for the prediction placeholders, so the
        two can never disagree. Defaults to 2.
    """
    def __init__(self, num_target_values: int = 2):
        super().__init__(
            model_class=Model_for_Generic_AASeq_BinaryClassification_Transformer,
            output_dim=num_target_values, # one model output per target value
        )
        self.num_target_values = num_target_values
        self.target_column_to_train = 'target_column'
        self.target_column_to_predict = 'pred_column'

    def _get_targets_from_batch_df(self, batch_df, **kwargs):
        """Return the batch targets as a float32 tensor.

        The per-row target sequences are stacked with ``np.stack`` (so they
        must all have the same length) and converted via ``self._as_tensor``.
        """
        return self._as_tensor(
            np.stack(batch_df[self.target_column_to_train].values), 
            dtype=torch.float32
        )

    def _prepare_predict_data_df(self, precursor_df, **kwargs):
        """Add a placeholder prediction column and remember the DataFrame.

        The column is added to ``precursor_df`` in place and the frame is
        stored as ``self.predict_df``.
        """
        # Build a *separate* zero list per row. The `[[0]*n]*len(df)` idiom
        # would make every row reference the same list object, so an in-place
        # change to one placeholder would show up in all rows.
        precursor_df[self.target_column_to_predict] = [
            [0]*self.num_target_values for _ in range(len(precursor_df))
        ]
        self.predict_df = precursor_df

    def _set_batch_predict_data(self, batch_df, predict_values, **kwargs):
        """Write the predictions of one batch into ``self.predict_df``.

        ``predict_values`` is clipped in place from below at
        ``self._min_pred_value`` before being stored row by row.
        """
        predict_values[predict_values<self._min_pred_value] = self._min_pred_value
        if self._predict_in_order:
            # Positional slice from the first to the last index label of the
            # batch. NOTE(review): assumes index labels equal row positions
            # and that the batch is contiguous -- confirm against the caller.
            self.predict_df.loc[:,self.target_column_to_predict].values[
                batch_df.index.values[0]:batch_df.index.values[-1]+1
            ] = list(predict_values)
        else:
            # Label-based assignment for batches that are not in row order.
            self.predict_df.loc[
                batch_df.index,self.target_column_to_predict
            ] = list(predict_values)

# Instantiate the custom multi-target interface with its default settings.
model = ModelInterface_MultiTarget()

In [3]:
import pandas as pd

# Toy data set: four sequences, each labelled with two binary target values.
df = pd.DataFrame(
    dict(
        sequence=[
            'ABCDE',
            'FGHIJK',
            'LMNOPQ',
            'RSTUVWXYZ',
        ],
        target_column=[
            [1, 0],
            [0, 1],
            [1, 1],
            [0, 0],
        ],
    )
)
df

Unnamed: 0,sequence,target_column
0,ABCDE,"[1, 0]"
1,FGHIJK,"[0, 1]"
2,LMNOPQ,"[1, 1]"
3,RSTUVWXYZ,"[0, 0]"


In [4]:
# Fit the custom interface on the toy frame, then predict on the same rows.
# The value of the last call is the cell output; the displayed frame carries
# the additional `nAA` and `pred_column` columns.
model.train(df)
model.predict(df)

Unnamed: 0,sequence,target_column,nAA,pred_column
0,ABCDE,"[1, 0]",5,"[0.93552226, 0.07380291]"
1,FGHIJK,"[0, 1]",6,"[0.08214627, 0.89221424]"
2,LMNOPQ,"[1, 1]",6,"[0.94375694, 0.8868231]"
3,RSTUVWXYZ,"[0, 0]",9,"[0.05007053, 0.05487113]"


In [5]:
from peptdeep.model.generic_property_prediction import (
    Model_for_Generic_AASeq_BinaryClassification_Transformer,
    ModelInterface_for_Generic_AASeq_MultiTargetClassification
)
# Same idea using the multi-target interface imported from peptdeep,
# configured for two target values per sequence.
model = ModelInterface_for_Generic_AASeq_MultiTargetClassification(
    model_class=Model_for_Generic_AASeq_BinaryClassification_Transformer,
    num_target_values=2,
)
# Show which DataFrame column this interface reads its training targets from.
model.target_column_to_train

'target_probs'

In [6]:
import pandas as pd

# Toy data set whose target column is named `target_probs`.
df = pd.DataFrame(
    dict(
        sequence=[
            'ABCDE',
            'FGHIJK',
            'LMNOPQ',
            'RSTUVWXYZ',
        ],
        target_probs=[
            [1, 0],
            [0, 1],
            [1, 1],
            [0, 0],
        ],
    )
)
# Train and predict on the same toy rows; the displayed result contains the
# added `nAA` and `target_probs_pred` columns.
model.train(df)
model.predict(df)

Unnamed: 0,sequence,target_probs,nAA,target_probs_pred
0,ABCDE,"[1, 0]",5,"[0.9555132, 0.04477089]"
1,FGHIJK,"[0, 1]",6,"[0.07803793, 0.88632977]"
2,LMNOPQ,"[1, 1]",6,"[0.9190902, 0.89732695]"
3,RSTUVWXYZ,"[0, 0]",9,"[0.04630012, 0.052344453]"


In [7]:
from peptdeep.model.generic_property_prediction import (
    Model_for_Generic_ModAASeq_BinaryClassification_Transformer,
    ModelInterface_for_Generic_ModAASeq_MultiTargetClassification
)
# ModAASeq variant of the multi-target interface with two target values.
# NOTE(review): presumably this variant also consumes modification columns
# (`mods`, `mod_sites`) from the input frame -- confirm in peptdeep.
model = ModelInterface_for_Generic_ModAASeq_MultiTargetClassification(
    model_class=Model_for_Generic_ModAASeq_BinaryClassification_Transformer,
    num_target_values=2,
)
# Show which DataFrame column this interface reads its training targets from.
model.target_column_to_train

'target_probs'

In [8]:
import pandas as pd

# Toy data set for the ModAASeq interface: the scalar empty strings are
# broadcast to every row of `mods` and `mod_sites`.
df = pd.DataFrame(
    dict(
        sequence=[
            'ABCDE',
            'FGHIJK',
            'LMNOPQ',
            'RSTUVWXYZ',
        ],
        mods="",
        mod_sites="",
        target_probs=[
            [1, 0],
            [0, 1],
            [1, 1],
            [0, 0],
        ],
    )
)
# Train and predict on the same toy rows; the displayed result contains the
# added `nAA` and `target_probs_pred` columns.
model.train(df)
model.predict(df)

Unnamed: 0,sequence,mods,mod_sites,target_probs,nAA,target_probs_pred
0,ABCDE,,,"[1, 0]",5,"[0.9591324, 0.060887054]"
1,FGHIJK,,,"[0, 1]",6,"[0.11639012, 0.88370585]"
2,LMNOPQ,,,"[1, 1]",6,"[0.91906965, 0.89074755]"
3,RSTUVWXYZ,,,"[0, 0]",9,"[0.03260221, 0.045892052]"
