diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index 817ef0bd6442..44b484b28949 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -22,7 +22,6 @@
 import torch
 from omegaconf.dictconfig import DictConfig
 from pytorch_lightning.accelerators import CPUAccelerator
-from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
 from pytorch_lightning.trainer.trainer import Trainer

 from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import (
@@ -53,7 +52,6 @@
     SamplingParam,
     TextGeneration,
 )
-from nemo.collections.nlp.parts.nlp_overrides import GradScaler
 from nemo.collections.nlp.parts.utils_funcs import get_last_rank
 from nemo.core.classes import Exportable
 from nemo.core.classes.common import PretrainedModelInfo
@@ -512,37 +510,38 @@ def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only):

         return loss_mean

+    def initialize_ub_func(self):
+        input_shape = [
+            self.cfg.get('encoder_seq_length') * self.cfg.get('micro_batch_size'),
+            self.cfg.get('hidden_size'),
+        ]
+        ub_cfg_file_name = self.cfg.get('ub_tp_comm_overlap_cfg', None)
+        ub_cfgs = None
+        if ub_cfg_file_name is not None:
+            try:
+                import yaml
+
+                with open(ub_cfg_file_name, 'r') as ub_cfg_file:
+                    ub_cfgs = yaml.safe_load(ub_cfg_file)
+            except (ImportError, TypeError):
+                logging.error(f"Fail to read ub_tp_comm_overlap config file: {ub_cfg_file_name}.")
+        te_module.initialize_ub(
+            shape=input_shape,
+            tp_size=self.cfg.get('tensor_model_parallel_size'),
+            use_fp8=self.cfg.get('fp8'),
+            ub_cfgs=ub_cfgs,
+        )
+        self.initialize_ub = False
+
     def training_step(self, dataloader_iter, batch_idx):
         """
             We pass the dataloader iterator function to the micro-batch
             scheduler. The input batch to each micro-batch is fetched using the dataloader function
             in the micro-batch fwd function.
         """
-        # Initialize userbuffer communicators. Initialization is done only once at the
-        # beginning of the first training step.
+        # Initialize userbuffer communicators.
         if self.initialize_ub:
-            input_shape = [
-                self.cfg.get('encoder_seq_length') * self.cfg.get('micro_batch_size'),
-                self.cfg.get('hidden_size'),
-            ]
-            ub_cfg_file_name = self.cfg.get('ub_tp_comm_overlap_cfg', None)
-            ub_cfgs = None
-            if ub_cfg_file_name is not None:
-                try:
-                    import yaml
-
-                    with open(ub_cfg_file_name, 'r') as ub_cfg_file:
-                        ub_cfgs = yaml.safe_load(ub_cfg_file)
-                except (ImportError, TypeError):
-                    print("Fail to read ub_tp_comm_overlap config file.")
-
-            te_module.initialize_ub(
-                shape=input_shape,
-                tp_size=self.cfg.get('tensor_model_parallel_size'),
-                use_fp8=self.cfg.get('fp8'),
-                ub_cfgs=ub_cfgs,
-            )
-            self.initialize_ub = False
+            self.initialize_ub_func()

         if self.rampup_batch_size:
             num_microbatch_calculator = apex.transformer.pipeline_parallel.utils._GLOBAL_NUM_MICROBATCHES_CALCULATOR
@@ -873,6 +872,10 @@ def validation_step(self, dataloader_iter, batch_idx):
             from the dataloader to produce a list of microbatches.
             The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions.
         """
+        # Initialize userbuffer communicators.
+        if self.initialize_ub:
+            self.initialize_ub_func()
+
         if isinstance(self.model, list):
             for model_module in self.model:
                 model_module.eval()
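
For context, the refactor above moves the one-time userbuffer (UB) initialization into a shared initialize_ub_func and calls it, behind the same flag, from both training_step and validation_step, presumably so the communicators also get set up when validation runs before any training step (for example a validation-only run or a sanity check). The sketch below only illustrates that guarded one-shot initialization pattern in isolation; the names (CommOverlapMixin, _setup_comm_overlap, cfg_path) are hypothetical, the exception handling is simplified, and it is not NeMo's actual implementation.

# Illustrative sketch of the "initialize once, on first use" pattern factored out above.
# Hypothetical names; only PyYAML is assumed.
import yaml


class CommOverlapMixin:
    def __init__(self, cfg_path=None):
        self.cfg_path = cfg_path
        self.needs_comm_init = True  # plays the role of self.initialize_ub

    def _setup_comm_overlap(self):
        """One-shot setup shared by the training and validation paths."""
        overlap_cfg = None
        if self.cfg_path is not None:
            try:
                with open(self.cfg_path, 'r') as f:
                    overlap_cfg = yaml.safe_load(f)
            except (OSError, yaml.YAMLError):
                print(f"Failed to read comm overlap config: {self.cfg_path}")
        # ... expensive communicator setup would happen here, driven by overlap_cfg ...
        self.needs_comm_init = False  # never run again for this object

    def training_step(self):
        if self.needs_comm_init:
            self._setup_comm_overlap()
        # ... rest of the training step ...

    def validation_step(self):
        if self.needs_comm_init:  # same guard, covers validation-first runs
            self._setup_comm_overlap()
        # ... rest of the validation step ...

Keeping the flag check at the call sites and the flag reset inside the helper mirrors the diff: either entry point can trigger initialization, but it only ever runs once.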