Merge pull request #162 from MannLabs/training-callback-handler
Feat: Add train callbacks
GeorgWa authored May 3, 2024
2 parents 4c06758 + 4392159 commit b1903d4
Showing 1 changed file with 66 additions and 5 deletions: peptdeep/model/model_interface.py
@@ -151,6 +151,45 @@ def get_cosine_schedule_with_warmup(self,
        )
        return LambdaLR(optimizer, lr_lambda, last_epoch)

class CallbackHandler:
    """
    A callback handler that hooks into the training process at both
    epoch-level and batch-level events. For finer control over training,
    create a subclass and override the methods you need.
    """
    def epoch_callback(self, epoch: int, epoch_loss: float) -> bool:
        """
        Called at the end of each epoch. The callback can stop training
        by returning False; return True to continue. Note that the training
        loop checks `if not continue_training`, so any falsy return value
        (including None) also stops training.

        Parameters
        ----------
        epoch : int
            The current epoch number.
        epoch_loss : float
            The loss value of the current epoch.

        Returns
        -------
        continue_training : bool
            If False, the training will stop.
        """
        continue_training = True
        return continue_training

    def batch_callback(self, batch: int, batch_loss: float):
        """
        Called at the end of each batch.

        Parameters
        ----------
        batch : int
            The current batch number.
        batch_loss : float
            The loss value of the current batch.
        """
        pass
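
As a usage sketch (not part of this commit): a custom handler could implement early stopping plus lightweight batch logging. The class name, `patience`, and `min_delta` below are illustrative assumptions, not peptdeep API.

# Hypothetical example, not part of this diff: stop training after
# `patience` epochs without an improvement of at least `min_delta`.
class EarlyStoppingHandler(CallbackHandler):
    def __init__(self, patience: int = 3, min_delta: float = 1e-4):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.stale_epochs = 0

    def epoch_callback(self, epoch: int, epoch_loss: float) -> bool:
        if epoch_loss < self.best_loss - self.min_delta:
            self.best_loss = epoch_loss
            self.stale_epochs = 0
        else:
            self.stale_epochs += 1
        # Returning False makes train()/train_with_warmup() break early.
        return self.stale_epochs < self.patience

    def batch_callback(self, batch: int, batch_loss: float):
        # Optional batch-level hook: log every 100th batch.
        if batch % 100 == 0:
            print(f"batch={batch}, loss={batch_loss:.4f}")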

def append_nAA_column_if_missing(precursor_df):
"""
@@ -195,6 +234,7 @@ def __init__(self,
        self.fixed_sequence_len = fixed_sequence_len
        self.min_pred_value = min_pred_value
        self.lr_scheduler_class = WarmupLR_Scheduler
        self.callback_handler = CallbackHandler()

    @property
    def fixed_sequence_len(self) -> int:
@@ -273,6 +313,16 @@ def set_lr_scheduler_class(self, lr_scheduler_class:LR_SchedulerInterface) -> None:
            )
        else:
            self.lr_scheduler_class = lr_scheduler_class

    def set_callback_handler(self, callback_handler: CallbackHandler) -> None:
        """
        Set the callback handler. It must be an instance of CallbackHandler
        (or a subclass).
        """
        if isinstance(callback_handler, CallbackHandler):
            self.callback_handler = callback_handler
        else:
            raise ValueError(
                "The callback handler must be an instance of model_interface.CallbackHandler"
            )
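
A hedged usage sketch: assuming `model` is an instance of a ModelInterface subclass and `precursor_df` is a prepared training DataFrame (both placeholders, not defined in this diff):

# Hypothetical wiring, not part of this commit; the `epoch` keyword
# is an assumption about the train() signature.
model.set_callback_handler(EarlyStoppingHandler(patience=5))
model.train(precursor_df, epoch=50)  # stops early if the callback returns False
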
    def set_device(self,
        device_type: str = 'gpu',
        device_ids: list = []
@@ -392,12 +442,17 @@ def train_with_warmup(self,
                batch_size, verbose_each_epoch,
                **kwargs
            )

            lr_scheduler.step(epoch=epoch, loss=np.mean(batch_cost))
            if verbose: print(
                f'[Training] Epoch={epoch+1}, lr={lr_scheduler.get_last_lr()[0]}, loss={np.mean(batch_cost)}'
            )

            continue_training = self.callback_handler.epoch_callback(
                epoch=epoch, epoch_loss=np.mean(batch_cost)
            )
            if not continue_training:
                print(f"Training stopped at epoch {epoch+1}")
                break
        torch.cuda.empty_cache()

    def train(self,
@@ -447,7 +502,13 @@ def train(self,
                **kwargs
            )
            if verbose: print(f'[Training] Epoch={epoch+1}, Mean Loss={np.mean(batch_cost)}')

            continue_training = self.callback_handler.epoch_callback(
                epoch=epoch, epoch_loss=np.mean(batch_cost)
            )
            if not continue_training:
                print(f"Training stopped at epoch {epoch+1}")
                break
        torch.cuda.empty_cache()

    def predict(self,
@@ -724,7 +785,7 @@ def _train_one_epoch_by_padding_zeros(self,
                batch_cost.append(
                    self._train_one_batch(targets, features)
                )

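                # Clarifying comment (added editorially): `i` advances by
                # batch_size, so i//batch_size is the zero-based batch index.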
                self.callback_handler.batch_callback(i//batch_size, batch_cost[-1])
                if verbose_each_epoch:
                    batch_tqdm.set_description(
                        f'Epoch={epoch+1}, batch={len(batch_cost)}, loss={batch_cost[-1]:.4f}'
@@ -764,7 +825,7 @@ def _train_one_epoch(self,
                batch_cost.append(
                    self._train_one_batch(targets, features)
                )

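                # As above, i//batch_size is the zero-based batch index.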
                self.callback_handler.batch_callback(i//batch_size, batch_cost[-1])
                if verbose_each_epoch:
                    batch_tqdm.set_description(
                        f'Epoch={epoch+1}, nAA={nAA}, batch={len(batch_cost)}, loss={batch_cost[-1]:.4f}'
