From d8dc8c9c066fbe5639228763da5c11d7ca23425e Mon Sep 17 00:00:00 2001 From: huvunvidia Date: Fri, 10 May 2024 14:33:06 +0000 Subject: [PATCH] Apply isort and black reformatting Signed-off-by: huvunvidia --- .../language_modeling/megatron_retro_eval_legacy.py | 10 ++++++++-- .../megatron_retro_pretraining_legacy.py | 2 +- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/examples/nlp/language_modeling/megatron_retro_eval_legacy.py b/examples/nlp/language_modeling/megatron_retro_eval_legacy.py index 9bd07c135ca7..df88f52ed8d1 100644 --- a/examples/nlp/language_modeling/megatron_retro_eval_legacy.py +++ b/examples/nlp/language_modeling/megatron_retro_eval_legacy.py @@ -77,7 +77,10 @@ def main(cfg) -> None: save_restore_connector.model_extracted_dir = model_path model_cfg = MegatronRetrievalModel.restore_from( - model_path, trainer=trainer, return_config=True, save_restore_connector=save_restore_connector, + model_path, + trainer=trainer, + return_config=True, + save_restore_connector=save_restore_connector, ) with open_dict(model_cfg): @@ -97,7 +100,10 @@ def main(cfg) -> None: cfg.pipeline_model_parallel_split_rank = model_cfg.get('pipeline_model_parallel_split_rank', 0) model = MegatronRetrievalModel.restore_from( - model_path, trainer=trainer, save_restore_connector=save_restore_connector, override_config_path=model_cfg, + model_path, + trainer=trainer, + save_restore_connector=save_restore_connector, + override_config_path=model_cfg, ) length_params: LengthParam = { diff --git a/examples/nlp/language_modeling/megatron_retro_pretraining_legacy.py b/examples/nlp/language_modeling/megatron_retro_pretraining_legacy.py index 6386d495986b..a975f9c9f0d7 100644 --- a/examples/nlp/language_modeling/megatron_retro_pretraining_legacy.py +++ b/examples/nlp/language_modeling/megatron_retro_pretraining_legacy.py @@ -59,7 +59,7 @@ def main(cfg) -> None: scaler = None if cfg.trainer.precision in [16, '16', '16-mixed']: scaler = GradScaler( - init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32), + init_scale=cfg.model.get('native_amp_init_scale', 2**32), growth_interval=cfg.model.get('native_amp_growth_interval', 1000), hysteresis=cfg.model.get('hysteresis', 2), )