diff --git a/examples/nlp/language_modeling/conf/megatron_hiddens_base_config.yaml b/examples/nlp/language_modeling/conf/megatron_hiddens_base_config.yaml index 560b966a9c08..08d4c5dc6cb0 100644 --- a/examples/nlp/language_modeling/conf/megatron_hiddens_base_config.yaml +++ b/examples/nlp/language_modeling/conf/megatron_hiddens_base_config.yaml @@ -1,5 +1,9 @@ # this file main purpose is documentation, and it should not be used directly enc_output_name: z # name of key in hidden transforms output to pass to decoder (default: hiddens). e.g., z for VAE/MIM. +enc_inference_output_name: z_mean # if provided, this key is used instead of enc_output_name when self.training is False (e.g., inference) + +hidden_aggregation_method: mean # how to aggregate the hiddens loss over the hidden (latent) dimension: 'mean' or 'sum'. +token_aggregation_method: mean # how to aggregate the per-token reconstruction loss over the sequence (token) dimension: 'mean' or 'sum'. tokens_loss_weight: 1.0 # weight of tokens loss (if not specified defaults to 1.0) # the lists below are useful for adding multiple transforms and losses according to order # if order is not important, you can use a single dictionary in the list with multiple keys diff --git a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py index fd0381502a53..8b9d218385c1 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py @@ -134,6 +134,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): self.enc_dec_model.model_type = ModelType.encoder_and_decoder + # Get the token aggregation method from the hiddens module config for the loss. + self.token_aggregation_method = self.cfg.get('hiddens', {}).get("token_aggregation_method", None) + def setup_optimizer_param_groups(self): """ModelPT override. Optimizer will get self._optimizer_param_groups""" self._optimizer_param_groups = get_params_for_weight_decay_optimization([self.enc_dec_model]) @@ -598,10 +601,13 @@ def loss_func(output_tensor): loss_dict = output_tensor output_tensor = loss_dict.pop("output") # compute reconstruction (tokens) only loss from per-token reconstruction loss - tokens_loss = self.loss_func(loss_mask, output_tensor) + tokens_loss = self.loss_func( + loss_mask, output_tensor, token_aggregation_method=self.token_aggregation_method + ) loss_dict["tokens_loss"] = tokens_loss tokens_loss_weight = loss_dict.get("tokens_loss_weight", 1.0) - # compute total loss + # compute total loss. Note that we keep `loss` pointing at the pre-reduced value so that the reduction can be performed + # by the parallel optimizer.
loss = loss_dict["loss"] = loss_dict["hiddens_loss"] + tokens_loss_weight * tokens_loss # average losses across data parallel group loss_dict = { @@ -609,7 +615,9 @@ def loss_func(output_tensor): } else: # compute reconstruction (tokens) only loss from per-token reconstruction loss - loss = self.loss_func(loss_mask, output_tensor) + loss = self.loss_func( + loss_mask, output_tensor, token_aggregation_method=self.token_aggregation_method + ) # average losses across data parallel group reduced_loss = average_losses_across_data_parallel_group([loss]) loss_dict = {'loss': reduced_loss} @@ -684,12 +692,17 @@ def fwd_output_only_func(dataloader_iter, model): # map batch and shared args into forward args args = self._build_forward_args_from_kwargs(args_name=arg_names, args=batch, **kwargs) - output = model(*args).contiguous() + output = model(*args) + if torch.is_tensor(output): + output = output.contiguous() + else: + # support the hiddens module, where the returned output is a dict of tensors + output = {k: v.contiguous() for k, v in output.items()} def id_func(output_tensor): if isinstance(output_tensor, dict): # handle loss of hidden transformations ("output" is the default output) - output_tensor = output_tensor["output"] + output_tensor = output_tensor[output_name] return output_tensor, {output_name: output_tensor} @@ -779,14 +792,27 @@ def on_validation_epoch_end(self): def on_test_epoch_end(self): return self._test_validation_epoch_end(step_outputs=self.test_step_outputs, prefix="test",) - def loss_func(self, loss_mask, tokens_loss): + def loss_func(self, loss_mask, tokens_loss, token_aggregation_method: Optional[str] = None): """ This function takes as input per-token loss and masks non-required values. """ - losses = tokens_loss.view(-1).float() - loss_mask = loss_mask.view(-1).float() + losses = tokens_loss.float() # Batch x Sequence + loss_mask = loss_mask.float() # Batch x Sequence # TODO: add nemo version here - loss = torch.sum(losses * loss_mask) / loss_mask.sum() # sequence level nll + if token_aggregation_method is None: + loss = torch.sum(losses * loss_mask) / loss_mask.sum() # sequence level nll + elif token_aggregation_method == "mean": + # This variant puts equal weight on every sequence in the batch rather than weighting longer + # sequences more heavily. + sample_loss = torch.sum(losses * loss_mask, dim=1) / loss_mask.sum(dim=1).clamp(min=1) + loss = sample_loss.mean() + elif token_aggregation_method == "sum": + sample_loss = torch.sum(losses * loss_mask, dim=1) + loss = sample_loss.mean() + else: + raise ValueError( + f"token_aggregation_method={token_aggregation_method}, expected one of None, 'mean', or 'sum'.
" ) return loss def process_micro_batch(self, micro_batch): @@ -1068,7 +1094,7 @@ def dummy(): arg_names.append('enc_input') forward_step_func = self._get_forward_output_only_func( - arg_names=arg_names, output_name="hiddens", output_enc_hidden_only=True + arg_names=arg_names, output_name="enc_output", output_enc_hidden_only=True ) fwd_bwd_func = get_forward_backward_func() @@ -1089,7 +1115,7 @@ def dummy(): ) if output_tensor: - output_tensor = output_tensor[0]['hiddens'] + output_tensor = output_tensor[0]['enc_output'] else: output_tensor = torch.zeros(tensor_shape, dtype=self.autocast_dtype).cuda() @@ -1229,6 +1255,10 @@ def dummy(): if enc_output_attn_mask is None: enc_output_attn_mask = enc_mask + # read these sizes here; they are only used by the beam search below. Note that hidden_steps (from enc_output) can differ from src_length (from the attention mask). + batch_size, hidden_steps, hidden_size = enc_output.size() + src_length = enc_output_attn_mask.shape[1] + for i in range(num_tokens_to_generate): # No microbatches in decoding. Just the global batch. decoder_seq_length = predicted_tokens_dec.size(1) @@ -1238,7 +1268,7 @@ def dummy(): batch_for_pipeline = [enc_output, enc_output_attn_mask, predicted_tokens_dec, dec_mask, batch_data] arg_names = ['enc_output', 'enc_output_attn_mask', 'dec_input_ids', 'dec_attn_mask', 'batch_data'] - forward_step_func = self._get_forward_output_only_func(arg_names=arg_names, output_name="logits") + forward_step_func = self._get_forward_output_only_func(arg_names=arg_names, output_name="token_logits") fwd_bwd_func = get_forward_backward_func() output_tensor = fwd_bwd_func( @@ -1253,7 +1283,7 @@ def dummy(): ) # get output tensor if parallel_state.is_pipeline_last_stage(): - output_tensor = output_tensor[0]['logits'] + output_tensor = output_tensor[0]['token_logits'] output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(output_tensor) # make sure it won't sample outside the vocab_size range output_tensor[:, :, tokenizer.vocab_size :] = -float('Inf') @@ -1274,9 +1304,8 @@ def dummy(): log_probs, token_ids = log_probs.view(-1), token_ids.view(-1) scores = log_probs.unsqueeze(1).clone() - batch_size, src_length, hidden_size = enc_output.size() enc_output_attn_mask = enc_output_attn_mask.repeat(1, beam_size).view(-1, src_length) - enc_output = enc_output.repeat(1, beam_size, 1).view(-1, src_length, hidden_size) + enc_output = enc_output.repeat(1, beam_size, 1).view(-1, hidden_steps, hidden_size) # resize tensors that collect predicted tokens and logits per iteration to # match shape of tensors augmented with the beam size diff --git a/nemo/collections/nlp/modules/common/megatron/hiddens/megatron_hidden_loss.py b/nemo/collections/nlp/modules/common/megatron/hiddens/megatron_hidden_loss.py index f10c34d3fad3..da60584cb5d0 100644 --- a/nemo/collections/nlp/modules/common/megatron/hiddens/megatron_hidden_loss.py +++ b/nemo/collections/nlp/modules/common/megatron/hiddens/megatron_hidden_loss.py @@ -25,10 +25,12 @@ class MegatronBaseHiddenLoss(torch.nn.Module): Returned dict includes a loss value and additional outputs.
""" - def __init__(self, loss_weight=1.0, name=""): + def __init__(self, loss_weight=1.0, name="", hidden_aggregation_method: str = "mean"): super().__init__() self.name = name self.loss_weight = float(loss_weight) + # determines how the hiddens loss is aggregated ('mean' or 'sum') over the hidden and sequence dimensions + self.hidden_aggregation_method = hidden_aggregation_method def __str__(self): return super().__str__() + f"(name={self.name})" @@ -78,10 +80,17 @@ def loss(self, inputs, batch_data=None): hiddens_mask = inputs["hiddens_mask"].to(loss) loss = loss * hiddens_mask # sequence level loss [B x S] -> batch level loss [B] - loss = loss.sum(dim=1) / hiddens_mask.sum(dim=1).clamp(min=1.0) + if self.hidden_aggregation_method == "mean": + loss = loss.sum(dim=1) / hiddens_mask.sum(dim=1).clamp(min=1.0) + elif self.hidden_aggregation_method == "sum": + loss = loss.sum(dim=1) + else: + raise ValueError( + f"hidden_aggregation_method={self.hidden_aggregation_method} not recognized, supported values are 'mean' and 'sum'" + ) # compute batch level weighted loss (scalar) - weighted_loss = loss.sum() * self.loss_weight + weighted_loss = loss.mean() * self.loss_weight # store updated losses loss_dict["loss"] = loss @@ -98,13 +107,8 @@ class MegatronAMIMHiddenLoss(MegatronBaseHiddenLoss): A-MIM - asymmetric MIM (without sampling) """ - def __init__(self, loss_weight=1.0, hidden_aggregation_method="sum", name="mim"): - super().__init__( - name=name, loss_weight=loss_weight, - ) - - # allows to determine how to aggregate hidden loss over hidden dimension - self.hidden_aggregation_method = hidden_aggregation_method + def __init__(self, loss_weight=1.0, hidden_aggregation_method="mean", name="mim"): + super().__init__(name=name, loss_weight=loss_weight, hidden_aggregation_method=hidden_aggregation_method) def _input_names(self): """Add here all required inputs""" @@ -121,12 +125,12 @@ def _loss(self, inputs, batch_data=None): """ z = inputs["z"] # get posterior - log_prob_q_z_given_x = inputs["z_log_prob"] + aggregator = getattr(torch, self.hidden_aggregation_method) + log_prob_q_z_given_x = aggregator(inputs["z_log_prob"], dim=-1) # [B x S x H] -> [B x S] # compute log prob of anchor a unit Normal distribution log_prob_P_z = -0.5 * (math.log(2 * math.pi) + z.pow(2)) # aggregate over hidden dimension, default is sum - log_prob_P_z = getattr(log_prob_P_z, self.hidden_aggregation_method)(dim=-1) - + log_prob_P_z = aggregator(log_prob_P_z, dim=-1) # A-MIM loss = log_p_x_given_z - 0.5 * (log_prob_P_z + log_prob_q_z_given_x) # here we return only the hidden loss part loss = -0.5 * (log_prob_P_z + log_prob_q_z_given_x) @@ -145,11 +149,8 @@ class MegatronVAEHiddenLoss(MegatronBaseHiddenLoss): Implements VAE loss with a unit Normal anchor.
""" - def __init__(self, loss_weight=1.0, min_kl_value=None, name="vae"): - super().__init__( - name=name, loss_weight=loss_weight, - ) - + def __init__(self, loss_weight=1.0, min_kl_value=None, hidden_aggregation_method="mean", name="vae"): + super().__init__(name=name, loss_weight=loss_weight, hidden_aggregation_method=hidden_aggregation_method) # minimum value for KL divergence if min_kl_value is None: self.min_kl_value = min_kl_value @@ -171,9 +172,14 @@ def _loss(self, inputs, batch_data=None): """ z = inputs["z"] # get posterior - log_prob_q_z_given_x = inputs["z_log_prob"] + aggregator = getattr(torch, self.hidden_aggregation_method) + log_prob_q_z_given_x = aggregator(inputs["z_log_prob"], dim=-1) # [B x S x H] -> [B x S] + # compute log prob of anchor a unit Normal distribution - log_prob_p_z = -0.5 * (math.log(2 * math.pi) + z.pow(2)).sum(dim=-1) + log_prob_p_z = -0.5 * (math.log(2 * math.pi) + z.pow(2)) + + # aggregate over the hidden dimension (default is 'mean') + log_prob_p_z = aggregator(log_prob_p_z, dim=-1) # VAE loss = log_p_x_given_z - KL(q(z|x) || p(z)) kl_div = log_prob_q_z_given_x - log_prob_p_z diff --git a/nemo/collections/nlp/modules/common/megatron/hiddens/megatron_hidden_transform.py b/nemo/collections/nlp/modules/common/megatron/hiddens/megatron_hidden_transform.py index a604819e4df7..e780b7fb0778 100644 --- a/nemo/collections/nlp/modules/common/megatron/hiddens/megatron_hidden_transform.py +++ b/nemo/collections/nlp/modules/common/megatron/hiddens/megatron_hidden_transform.py @@ -168,12 +168,10 @@ def _transform(self, inputs, batch_data=None): if z_log_prob is None: # compute log probability of z under a diagonal Gaussian distribution z_log_prob = -0.5 * (math.log(2 * math.pi) + z_logvar + (z - z_mean).pow(2) / z_logvar.exp()) - # sum over the last dimension (hidden_size) - z_log_prob = z_log_prob.sum(dim=-1) return { "z": z, # [S x B x H] "z_mean": z_mean, # [S x B x H] "z_logvar": z_logvar, # [S x B x H] - "z_log_prob": z_log_prob, # [S x B] + "z_log_prob": z_log_prob, # [S x B x H] } diff --git a/nemo/collections/nlp/modules/common/megatron/hiddens/megatron_hiddens.py b/nemo/collections/nlp/modules/common/megatron/hiddens/megatron_hiddens.py index 1a2e48ef7fc1..de6515bda725 100644 --- a/nemo/collections/nlp/modules/common/megatron/hiddens/megatron_hiddens.py +++ b/nemo/collections/nlp/modules/common/megatron/hiddens/megatron_hiddens.py @@ -21,7 +21,7 @@ import functools import itertools -from typing import List +from typing import List, Optional import torch from omegaconf.dictconfig import DictConfig @@ -116,6 +116,8 @@ def get_hiddens_module(cfg=None, model_parallel_cfg: ModelParallelConfig = None) if cfg is None: return None + hidden_aggregation_method = cfg.get("hidden_aggregation_method", "mean") + logging.info(f"NOTE: Adding hiddens transforms and losses") # build all hidden transforms.
We support a list or a dictionary of transforms (list enforces order) @@ -155,6 +157,8 @@ def get_hiddens_module(cfg=None, model_parallel_cfg: ModelParallelConfig = None) for cur_list_cfg in loss_cfg: for name, cur_cfg in cur_list_cfg.items(): cls_kwargs = OmegaConf.to_container(cur_cfg) + this_hidden_aggregation_method = cls_kwargs.get("hidden_aggregation_method", hidden_aggregation_method) + cls_kwargs["hidden_aggregation_method"] = this_hidden_aggregation_method if not "cls_name" in cls_kwargs: raise KeyError(f"Missing 'cls_name' in hidden loss {name}") @@ -173,11 +177,15 @@ def get_hiddens_module(cfg=None, model_parallel_cfg: ModelParallelConfig = None) logging.info(f"Added loss {name} with cfg={cur_cfg}") enc_output_name = cfg.get("enc_output_name", "hiddens") + enc_inference_output_name = cfg.get("enc_inference_output_name", None) + tokens_loss_weight = cfg.get("tokens_loss_weight", 1.0) return MegatronHiddensModule( hidden_transforms=hidden_transforms, hidden_loss_transforms=hidden_loss_transforms, enc_output_name=enc_output_name, + enc_inference_output_name=enc_inference_output_name, + tokens_loss_weight=tokens_loss_weight, ) @@ -192,6 +200,7 @@ def __init__( hidden_transforms: List[MegatronBaseHiddenLoss] = [], hidden_loss_transforms: List[MegatronBaseHiddenTransform] = [], enc_output_name: str = "hiddens", # name (key) of the encoder output + enc_inference_output_name: Optional[str] = None, # if provided, use this key instead of enc_output_name when self.training is False tokens_loss_weight: float = 1.0, # weight of the tokens loss loss_prefix: str = "hiddens_", # if not None or "", add this prefix to all loss names ): @@ -199,6 +208,9 @@ def __init__( self.hidden_transforms = hidden_transforms self.hidden_loss_transforms = hidden_loss_transforms self.enc_output_name = enc_output_name + self.enc_inference_output_name = ( + enc_output_name if enc_inference_output_name is None else enc_inference_output_name + ) self.tokens_loss_weight = tokens_loss_weight self.loss_prefix = loss_prefix @@ -276,9 +288,11 @@ def apply_hidden_transforms(self, inputs, batch_data=None): # make sure to collect all outputs from hidden transforms outputs.update(hidden_transform.transform(outputs, batch_data=batch_data)) - # update final encoder output - outputs["enc_output"] = outputs[self.enc_output_name] - + # update final encoder output. enc_output_name and enc_inference_output_name are kept separate so that, e.g., VAE-style hiddens can expose z during training and z_mean during inference + if self.training: + outputs["enc_output"] = outputs[self.enc_output_name] + else: + outputs["enc_output"] = outputs[self.enc_inference_output_name] return outputs def apply_loss_transforms(self, outputs, batch_data=None): diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_encoder_decoder.py b/nemo/collections/nlp/modules/common/megatron/megatron_encoder_decoder.py index c4192dacb45a..dc3f1a3ee6b9 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_encoder_decoder.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_encoder_decoder.py @@ -217,7 +217,7 @@ def forward( dec_input=dec_input, dec_attn_mask=dec_attn_mask, enc_output=enc_output["enc_output"] # enc_output is a dict if we used hidden transformations - if self.hiddens_module is not None + if self.hiddens_module is not None and not torch.is_tensor(enc_output) else enc_output, # Adjust encoder attention mask if encoder is a perceiver.
enc_attn_mask=self.get_hiddens_mask(enc_attn_mask), diff --git a/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py b/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py index b7b377940eb4..00b12fedf849 100644 --- a/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py +++ b/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py @@ -669,6 +669,9 @@ def forward( # [s, b, h] -> [b, s, h] token_logits = token_logits.transpose(0, 1).contiguous() if self.hiddens_cfg is not None: + # enc_output may still be a plain tensor even when a hiddens module is used; wrap it in a dict for a uniform interface + if torch.is_tensor(enc_output): + enc_output = {"enc_output": enc_output} # return all hiddens and token logits hiddens_dict = enc_output hiddens_dict["token_logits"] = token_logits
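As a quick illustration of the new token_aggregation_method behaviour in loss_func above, here is a minimal standalone sketch (the helper name token_loss and the toy inputs are made up for illustration and are not part of this PR) showing how the three reduction variants treat a [batch x sequence] per-token loss:

import torch

def token_loss(losses, loss_mask, token_aggregation_method=None):
    # mirrors loss_func above: reduce a [B x S] per-token loss to a scalar
    losses = losses.float()
    loss_mask = loss_mask.float()
    if token_aggregation_method is None:
        # legacy behaviour: global average over all valid tokens (longer sequences weigh more)
        return torch.sum(losses * loss_mask) / loss_mask.sum()
    if token_aggregation_method == "mean":
        # per-sequence mean, then batch mean (every sequence weighs the same)
        sample_loss = torch.sum(losses * loss_mask, dim=1) / loss_mask.sum(dim=1).clamp(min=1)
        return sample_loss.mean()
    if token_aggregation_method == "sum":
        # per-sequence sum, then batch mean
        return torch.sum(losses * loss_mask, dim=1).mean()
    raise ValueError(f"token_aggregation_method={token_aggregation_method}, expected one of None, 'mean', or 'sum'.")

# toy batch: 2 sequences of lengths 2 and 4, constant per-token loss of 1.0
losses = torch.ones(2, 4)
mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 1]])
print(token_loss(losses, mask))          # tensor(1.)  -> 6 / 6 valid tokens
print(token_loss(losses, mask, "mean"))  # tensor(1.)  -> (1.0 + 1.0) / 2
print(token_loss(losses, mask, "sum"))   # tensor(3.)  -> (2.0 + 4.0) / 2

The 'mean' variant keeps longer sequences from dominating the batch loss, mirroring the comment in loss_func above; 'sum' instead scales each sequence's contribution with its number of valid tokens.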