diff --git a/recbole/trainer/trainer.py b/recbole/trainer/trainer.py index fb0b8d7..7fc5781 100644 --- a/recbole/trainer/trainer.py +++ b/recbole/trainer/trainer.py @@ -218,7 +218,7 @@ class Trainer(AbstractTrainer): train_loss_output += "train loss: %.4f" % losses return train_loss_output + ']' - def fit(self, train_data, valid_data=None, verbose=True, saved=True): + def fit(self, train_data, valid_data=None, verbose=True, saved=True, callback_fn=None): r"""Train the model based on the train data and the valid data. Args: @@ -227,6 +227,8 @@ class Trainer(AbstractTrainer): If it's None, the early_stopping is invalid. verbose (bool, optional): whether to write training and evaluation information to logger, default: True saved (bool, optional): whether to save the model parameters, default: True + callback_fn (callable, optional): callback function executed at the end of each epoch, taking + (epoch_idx, valid_score) as input arguments, default: None Returns: (float, dict): best valid score and best valid result. If valid_data is None, it returns (-1, None) @@ -273,6 +275,9 @@ class Trainer(AbstractTrainer): self.logger.info(update_output) self.best_valid_result = valid_result + if callback_fn: + callback_fn(epoch_idx, valid_score) + if stop_flag: stop_output = 'Finished training, best eval result in epoch %d' % \ (epoch_idx - self.cur_step * self.eval_step) @@ -647,12 +652,12 @@ class MKRTrainer(Trainer): interaction = interaction.to(self.device) self.optimizer.zero_grad() loss_rs = self.model.calculate_rs_loss(interaction) - + self._check_nan(loss_rs) loss_rs.backward() self.optimizer.step() rs_total_loss += loss_rs - + # train kg if epoch_idx % self.kge_interval == 0: print('Train KG') @@ -661,7 +666,7 @@ class MKRTrainer(Trainer): interaction = interaction.to(self.device) self.optimizer.zero_grad() loss_kge = self.model.calculate_kg_loss(interaction) - + self._check_nan(loss_kge) loss_kge.backward() self.optimizer.step()