diff --git a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py
index fd0381502a53..2c2d8ad93ecf 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py
@@ -1068,7 +1068,7 @@ def dummy():
                 arg_names.append('enc_input')
 
             forward_step_func = self._get_forward_output_only_func(
-                arg_names=arg_names, output_name="hiddens", output_enc_hidden_only=True
+                arg_names=arg_names, output_name="enc_output", output_enc_hidden_only=True
             )
 
             fwd_bwd_func = get_forward_backward_func()
@@ -1089,7 +1089,7 @@ def dummy():
             )
 
             if output_tensor:
-                output_tensor = output_tensor[0]['hiddens']
+                output_tensor = output_tensor[0]['enc_output']
             else:
                 output_tensor = torch.zeros(tensor_shape, dtype=self.autocast_dtype).cuda()
 
diff --git a/nemo/collections/nlp/modules/common/megatron/hiddens/megatron_hidden_loss.py b/nemo/collections/nlp/modules/common/megatron/hiddens/megatron_hidden_loss.py
index f10c34d3fad3..afe847f2521b 100644
--- a/nemo/collections/nlp/modules/common/megatron/hiddens/megatron_hidden_loss.py
+++ b/nemo/collections/nlp/modules/common/megatron/hiddens/megatron_hidden_loss.py
@@ -81,7 +81,7 @@ def loss(self, inputs, batch_data=None):
             loss = loss.sum(dim=1) / hiddens_mask.sum(dim=1).clamp(min=1.0)
 
         # compute batch level weighted loss (scalar)
-        weighted_loss = loss.sum() * self.loss_weight
+        weighted_loss = loss.mean() * self.loss_weight
 
         # store updated losses
         loss_dict["loss"] = loss
diff --git a/nemo/collections/nlp/modules/common/megatron/hiddens/megatron_hiddens.py b/nemo/collections/nlp/modules/common/megatron/hiddens/megatron_hiddens.py
index 1a2e48ef7fc1..82f73dd12ea7 100644
--- a/nemo/collections/nlp/modules/common/megatron/hiddens/megatron_hiddens.py
+++ b/nemo/collections/nlp/modules/common/megatron/hiddens/megatron_hiddens.py
@@ -21,7 +21,7 @@
 import functools
 import itertools
-from typing import List
+from typing import List, Optional
 
 import torch
 from omegaconf.dictconfig import DictConfig
@@ -192,6 +192,7 @@ def __init__(
         hidden_transforms: List[MegatronBaseHiddenLoss] = [],
         hidden_loss_transforms: List[MegatronBaseHiddenTransform] = [],
         enc_output_name: str = "hiddens",  # name (key) of the encoder output
+        enc_inference_output_name: Optional[str] = None,  # if provided, use different key when self.training is False
         tokens_loss_weight: float = 1.0,  # weight of the tokens loss
         loss_prefix: str = "hiddens_",  # if not None or "", add this prefix to all loss names
     ):
@@ -199,6 +200,9 @@ def __init__(
         self.hidden_transforms = hidden_transforms
         self.hidden_loss_transforms = hidden_loss_transforms
         self.enc_output_name = enc_output_name
+        self.enc_inference_output_name = (
+            enc_output_name if enc_inference_output_name is None else enc_inference_output_name
+        )
         self.tokens_loss_weight = tokens_loss_weight
         self.loss_prefix = loss_prefix
 
@@ -276,9 +280,11 @@ def apply_hidden_transforms(self, inputs, batch_data=None):
             # make sure to collect all outputs from hidden transforms
             outputs.update(hidden_transform.transform(outputs, batch_data=batch_data))
 
-        # update final encoder output
-        outputs["enc_output"] = outputs[self.enc_output_name]
-
+        # update final encoder output. Split into output_name/inference output name to support z vs z_mean for example with VAE style hiddens
+        if self.training:
+            outputs["enc_output"] = outputs[self.enc_output_name]
+        else:
+            outputs["enc_output"] = outputs[self.enc_inference_output_name]
         return outputs
 
     def apply_loss_transforms(self, outputs, batch_data=None):
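For context on the last hunk: a minimal, self-contained sketch of the train-vs-inference selection that the new enc_inference_output_name argument enables. This is a toy stand-in, not NeMo code; the transform that produces both a sampled z and a deterministic z_mean is assumed (e.g. a VAE/Gaussian-style hidden transform), and all names below are illustrative.

import torch
import torch.nn as nn


class ToyHiddens(nn.Module):
    """Toy stand-in (not NeMo code) for the enc_output_name / enc_inference_output_name split."""

    def __init__(self, enc_output_name="z", enc_inference_output_name=None):
        super().__init__()
        self.enc_output_name = enc_output_name
        # fall back to the training key when no separate inference key is given (mirrors the diff)
        self.enc_inference_output_name = (
            enc_output_name if enc_inference_output_name is None else enc_inference_output_name
        )

    def apply_hidden_transforms(self, hiddens):
        z_mean = hiddens                       # pretend the transform predicts this mean
        z = z_mean + torch.randn_like(z_mean)  # pretend this is the sampled latent
        outputs = {"hiddens": hiddens, "z": z, "z_mean": z_mean}
        # same selection logic as the diff: sampled z while training, z_mean during inference
        key = self.enc_output_name if self.training else self.enc_inference_output_name
        outputs["enc_output"] = outputs[key]
        return outputs


module = ToyHiddens(enc_output_name="z", enc_inference_output_name="z_mean")
hiddens = torch.zeros(2, 4)

module.train()
print(module.apply_hidden_transforms(hiddens)["enc_output"])  # stochastic sample

module.eval()
print(module.apply_hidden_transforms(hiddens)["enc_output"])  # deterministic z_mean (zeros here)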