Fix memory leak at loss func (#8929)
Signed-off-by: Jaemin Choi <jaeminc@nvidia.com>
Co-authored-by: Jaemin Choi <jaeminc@nvidia.com>
2 people authored and web-flow committed Apr 16, 2024
1 parent e9d8266 commit 17038fa
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -367,6 +367,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
self.log_train_loss = bool(int(os.getenv("NEMO_LOG_TRAIN_LOSS", 1)))
self.log_memory_usage = bool(int(os.getenv("NEMO_LOG_MEMORY_USAGE", 0)))
self.loss_broadcast_src_rank = None
+ self.return_output_tensors = cfg.data.get('return_output_tensors', False)
+ self.validation_drop_last = cfg.data.get('validation_drop_last', True)
+ self.sample_weight = cfg.data.get('sample_weight', 'token')
self.validation_param_sync_overlap = self.cfg.get('validation_param_sync_overlap', False)

self.inference_params = None
@@ -621,7 +624,7 @@ def fwd_bwd_step(self, dataloader_iter, forward_only, first_val_step=None):

# only the last stages of the pipeline return losses
if losses_reduced_per_micro_batch:
- if (not forward_only) or self.cfg.data.get('validation_drop_last', True):
+ if (not forward_only) or self.validation_drop_last:
# average loss across micro batches
loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch]
loss_tensor = torch.concat(loss_tensors_list)
@@ -1136,10 +1139,9 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_
def loss_func(output_tensor):
# Loss for a micro-batch (ub)
loss_for_ub = self.loss_func(batch['loss_mask'], batch['num_valid_tokens_in_ub'], output_tensor)
- cp_size = self.cfg.get('context_parallel_size', 1)
- if self.cfg.data.get(
-     "return_output_tensors", False
- ):  # TODO: need a better way to check if loss_func is returning more stuff than just loss... (@adithyare)
+ cp_size = parallel_state.get_context_parallel_world_size()
+ if self.return_output_tensors:
+     # TODO: need a better way to check if loss_func is returning more stuff than just loss... (@adithyare)
loss_for_ub, q_hs, d_hs, pos_cs, neg_cs, diff_cs = loss_for_ub
reduced_loss = average_losses_across_data_parallel_group([loss_for_ub])
pos_cs = average_losses_across_data_parallel_group([pos_cs])
@@ -1156,15 +1158,14 @@ def loss_func(output_tensor):
'diff_cs': diff_cs,
},
)
- elif validation_step and not self.cfg.data.get('validation_drop_last', True):
-     sample_weight = self.cfg.data.get('sample_weight', 'token')
+ elif validation_step and not self.validation_drop_last:
num_valid_tokens_in_ub = batch['num_valid_tokens_in_ub']
if loss_for_ub.isnan():
assert batch['loss_mask'].count_nonzero() == 0, 'Got NaN loss with non-empty input'
loss_sum_for_ub = torch.zeros_like(loss_for_ub)
num_valid_tokens_in_ub = 0
else:
- if sample_weight == 'constant':
+ if self.sample_weight == 'constant':
num_valid_tokens_in_ub = 1
loss_sum_for_ub = num_valid_tokens_in_ub * loss_for_ub

@@ -1296,7 +1297,7 @@ def validation_step(self, dataloader_iter, dataloader_idx=0):
def on_validation_epoch_end(self):
if parallel_state.is_pipeline_last_stage():
# only the last pipeline parallel stages return loss with their batch size
- if self.cfg.data.get('validation_drop_last', True):
+ if self.validation_drop_last:
averaged_loss = torch.stack(self.validation_step_outputs).mean()
else:
# Compute the avg loss by total_loss across all samples / total number of samples
@@ -1534,7 +1535,7 @@ def setup_validation_data(self, cfg):
)

drop_last = True
- if not self.cfg.data.get('validation_drop_last', True):
+ if not self.validation_drop_last:
logging.info(f'Drop last in validation dataset is set to False')
drop_last = False
pad_samples_to_global_batch_size = False
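The diff follows one pattern throughout: the per-microbatch loss_func previously re-read values such as validation_drop_last, return_output_tensors, and sample_weight via self.cfg.data.get(...) (an OmegaConf DictConfig lookup) on every step; the commit reads them once in __init__ and keeps them as plain Python attributes, and it takes the context-parallel size from parallel_state.get_context_parallel_world_size() instead of the config. Below is a minimal sketch of that caching pattern, using an illustrative class and config rather than NeMo's actual MegatronGPTModel.

# Minimal sketch of the caching pattern in this commit (illustrative names,
# not the actual NeMo class). OmegaConf values are resolved once in __init__
# so the per-microbatch loss path never touches the DictConfig.
from omegaconf import DictConfig, OmegaConf


class SketchModel:
    def __init__(self, cfg: DictConfig):
        self.cfg = cfg
        # Cache config lookups at construction time.
        self.return_output_tensors = cfg.data.get('return_output_tensors', False)
        self.validation_drop_last = cfg.data.get('validation_drop_last', True)
        self.sample_weight = cfg.data.get('sample_weight', 'token')

    def loss_func(self, loss_for_ub, validation_step=False):
        # Hot path: use the cached plain attributes instead of calling
        # self.cfg.data.get(...) on every microbatch.
        if self.return_output_tensors:
            return loss_for_ub  # the real model unpacks extra tensors here
        elif validation_step and not self.validation_drop_last:
            weight = 1 if self.sample_weight == 'constant' else None
            return loss_for_ub, weight
        return loss_for_ub


cfg = OmegaConf.create({'data': {'validation_drop_last': False, 'sample_weight': 'token'}})
model = SketchModel(cfg)
print(model.loss_func(0.0, validation_step=True))  # -> (0.0, None)

The effect is simply that config values are resolved exactly once at construction time, so the loss function evaluated on every microbatch works with ordinary attributes and a parallel_state query instead of repeated config lookups.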
