diff --git a/src/abstractive.py b/src/abstractive.py index 5bfdd0f..4b18aaf 100644 --- a/src/abstractive.py +++ b/src/abstractive.py @@ -90,7 +90,7 @@ class AbstractiveSummarizer(pl.LightningModule): def __init__(self, hparams): super(AbstractiveSummarizer, self).__init__() - self.hparams = hparams + self.save_hyperparameters(hparams) if len(self.hparams.dataset) <= 1: self.hparams.dataset = self.hparams.dataset[0] diff --git a/src/extractive.py b/src/extractive.py index 33c647a..98267c9 100644 --- a/src/extractive.py +++ b/src/extractive.py @@ -106,7 +106,7 @@ def __init__(self, hparams, embedding_model_config=None, classifier_obj=None): hparams.tokenizer_no_use_fast = getattr(hparams, "tokenizer_no_use_fast", False) hparams.data_type = getattr(hparams, "data_type", "none") - self.hparams = hparams + self.save_hyperparameters(hparams) self.forward_modify_inputs_callback = None if not embedding_model_config: