Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Fix memory leak in loss func #8929

Merged
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
21 changes: 11 additions & 10 deletions nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
self.log_train_loss = bool(int(os.getenv("NEMO_LOG_TRAIN_LOSS", 1)))
self.log_memory_usage = bool(int(os.getenv("NEMO_LOG_MEMORY_USAGE", 0)))
self.loss_broadcast_src_rank = None
self.return_output_tensors = cfg.data.get('return_output_tensors', False)
self.validation_drop_last = cfg.data.get('validation_drop_last', True)
self.sample_weight = cfg.data.get('sample_weight', 'token')
self.validation_param_sync_overlap = self.cfg.get('validation_param_sync_overlap', False)

self.inference_params = None
Expand Down Expand Up @@ -621,7 +624,7 @@ def fwd_bwd_step(self, dataloader_iter, forward_only, first_val_step=None):

# only the last stages of the pipeline return losses
if losses_reduced_per_micro_batch:
if (not forward_only) or self.cfg.data.get('validation_drop_last', True):
if (not forward_only) or self.validation_drop_last:
# average loss across micro batches
loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch]
loss_tensor = torch.concat(loss_tensors_list)
Expand Down Expand Up @@ -1136,10 +1139,9 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_
def loss_func(output_tensor):
# Loss for a micro-batch (ub)
loss_for_ub = self.loss_func(batch['loss_mask'], batch['num_valid_tokens_in_ub'], output_tensor)
cp_size = self.cfg.get('context_parallel_size', 1)
if self.cfg.data.get(
"return_output_tensors", False
): # TODO: need a better way to check if loss_func is returning more stuff than just loss... (@adithyare)
cp_size = parallel_state.get_context_parallel_world_size()
if self.return_output_tensors:
# TODO: need a better way to check if loss_func is returning more stuff than just loss... (@adithyare)
loss_for_ub, q_hs, d_hs, pos_cs, neg_cs, diff_cs = loss_for_ub
reduced_loss = average_losses_across_data_parallel_group([loss_for_ub])
pos_cs = average_losses_across_data_parallel_group([pos_cs])
Expand All @@ -1156,15 +1158,14 @@ def loss_func(output_tensor):
'diff_cs': diff_cs,
},
)
elif validation_step and not self.cfg.data.get('validation_drop_last', True):
sample_weight = self.cfg.data.get('sample_weight', 'token')
elif validation_step and not self.validation_drop_last:
num_valid_tokens_in_ub = batch['num_valid_tokens_in_ub']
if loss_for_ub.isnan():
assert batch['loss_mask'].count_nonzero() == 0, 'Got NaN loss with non-empty input'
loss_sum_for_ub = torch.zeros_like(loss_for_ub)
num_valid_tokens_in_ub = 0
else:
if sample_weight == 'constant':
if self.sample_weight == 'constant':
num_valid_tokens_in_ub = 1
loss_sum_for_ub = num_valid_tokens_in_ub * loss_for_ub

Expand Down Expand Up @@ -1296,7 +1297,7 @@ def validation_step(self, dataloader_iter, dataloader_idx=0):
def on_validation_epoch_end(self):
if parallel_state.is_pipeline_last_stage():
# only the last pipeline parallel stages return loss with their batch size
if self.cfg.data.get('validation_drop_last', True):
if self.validation_drop_last:
averaged_loss = torch.stack(self.validation_step_outputs).mean()
else:
# Compute the avg loss by total_loss across all samples / total number of samples
Expand Down Expand Up @@ -1534,7 +1535,7 @@ def setup_validation_data(self, cfg):
)

drop_last = True
if not self.cfg.data.get('validation_drop_last', True):
if not self.validation_drop_last:
logging.info(f'Drop last in validation dataset is set to False')
drop_last = False
pad_samples_to_global_batch_size = False
Expand Down