Skip to content

Commit

Permalink
Merge pull request #165 from MannLabs/change-device-charge-model
Browse files Browse the repository at this point in the history
Allow change of device for charge models through initialization
  • Loading branch information
mo-sameh committed May 13, 2024
2 parents b1903d4 + fb7e39c commit 9006870
Showing 1 changed file with 22 additions and 4 deletions.
26 changes: 22 additions & 4 deletions peptdeep/model/charge.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,21 @@ class ChargeModelForModAASeq(
):
"""
ModelInterface for charge prediction for modified peptides
Parameters
----------
min_charge : int, optional
Minimum charge to predict, by default 1
max_charge : int, optional
Maximum charge to predict, by default 6
device : str, optional
Device to use for training and prediction, by default "gpu"
"""
def __init__(self, min_charge:int=1, max_charge:int=6):
def __init__(self, min_charge:int=1, max_charge:int=6, device:str="gpu"):
super().__init__(
num_target_values=max_charge-min_charge+1,
model_class=Model_for_Generic_ModAASeq_BinaryClassification_Transformer,
nlayers=4, hidden_dim=128, dropout=0.1
nlayers=4, hidden_dim=128, dropout=0.1, device=device
)

self.target_column_to_predict = "charge_probs"
Expand All @@ -109,12 +118,21 @@ class ChargeModelForAASeq(
):
"""
ModelInterface for charge prediction for amino acid sequence
Parameters
----------
min_charge : int, optional
Minimum charge to predict, by default 1
max_charge : int, optional
Maximum charge to predict, by default 6
device : str, optional
Device to use for training and prediction, by default "gpu"
"""
def __init__(self, min_charge:int=1, max_charge:int=6):
def __init__(self, min_charge:int=1, max_charge:int=6,device:str="gpu"):
super().__init__(
num_target_values=max_charge-min_charge+1,
model_class=Model_for_Generic_AASeq_BinaryClassification_Transformer,
nlayers=4, hidden_dim=128, dropout=0.1
nlayers=4, hidden_dim=128, dropout=0.1, device=device
)

self.target_column_to_predict = "charge_probs"
Expand Down

0 comments on commit 9006870

Please sign in to comment.