
Commit

Support a training key and an optionally separate key for inference in hiddens modules

Signed-off-by: John St John <jstjohn@nvidia.com>
jstjohn committed Feb 29, 2024
1 parent 55d6059 commit aefe1d5
Showing 7 changed files with 97 additions and 43 deletions.
@@ -1,5 +1,9 @@
# this file's main purpose is documentation; it should not be used directly
enc_output_name: z # name of key in hidden transforms output to pass to decoder (default: hiddens). e.g., z for VAE/MIM.
enc_inference_output_name: z_mean # if provided, this value will be used when self.training is False (e.g., inference)

hidden_aggregation_method: mean # which collapse method to use for the hiddens loss over the hidden (latent) dimension.
token_aggregation_method: mean # which collapse method to use for the tokens loss over the sequence (token) dimension.
tokens_loss_weight: 1.0 # weight of tokens loss (if not specified defaults to 1.0)
# the lists below are useful for adding multiple transforms and losses according to order
# if order is not important, you can use a single dictionary in the list with multiple keys
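
For illustration, a minimal sketch of building an equivalent config programmatically with OmegaConf; the keys mirror the file above, and the values are examples rather than shipped defaults:

from omegaconf import OmegaConf

hiddens_cfg = OmegaConf.create(
    {
        "enc_output_name": "z",                 # sampled latent, used while training
        "enc_inference_output_name": "z_mean",  # deterministic mean, used at inference
        "hidden_aggregation_method": "mean",    # collapse hiddens loss over the hidden dim
        "token_aggregation_method": "mean",     # collapse tokens loss over the sequence dim
        "tokens_loss_weight": 1.0,
    }
)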
@@ -134,6 +134,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):

self.enc_dec_model.model_type = ModelType.encoder_and_decoder

# Get the token aggregation method from the hiddens module for the loss.
self.token_aggregation_method = self.cfg.get('hiddens', {}).get("token_aggregation_method", None)

def setup_optimizer_param_groups(self):
"""ModelPT override. Optimizer will get self._optimizer_param_groups"""
self._optimizer_param_groups = get_params_for_weight_decay_optimization([self.enc_dec_model])
@@ -598,18 +601,23 @@ def loss_func(output_tensor):
loss_dict = output_tensor
output_tensor = loss_dict.pop("output")
# compute reconstruction (tokens) only loss from per-token reconstruction loss
tokens_loss = self.loss_func(loss_mask, output_tensor)
tokens_loss = self.loss_func(
loss_mask, output_tensor, token_aggregation_method=self.token_aggregation_method
)
loss_dict["tokens_loss"] = tokens_loss
tokens_loss_weight = loss_dict.get("tokens_loss_weight", 1.0)
# compute total loss
# compute total loss. Note that we want the `loss` variable to point to the pre-reduced form so that reduction can happen
# by the parallel optimizer.
loss = loss_dict["loss"] = loss_dict["hiddens_loss"] + tokens_loss_weight * tokens_loss
# average losses across data parallel group
loss_dict = {
k: average_losses_across_data_parallel_group([v.mean()]) for k, v in loss_dict.items()
}
else:
# compute reconstruction (tokens) only loss from per-token reconstruction loss
loss = self.loss_func(loss_mask, output_tensor)
loss = self.loss_func(
loss_mask, output_tensor, token_aggregation_method=self.token_aggregation_method
)
# average losses across data parallel group
reduced_loss = average_losses_across_data_parallel_group([loss])
loss_dict = {'loss': reduced_loss}
@@ -684,12 +692,17 @@ def fwd_output_only_func(dataloader_iter, model):

# map batch and shared args into forward args
args = self._build_forward_args_from_kwargs(args_name=arg_names, args=batch, **kwargs)
output = model(*args).contiguous()
output = model(*args)
if torch.is_tensor(output):
output = output.contiguous()
else:
# support hiddens module where returned output is a dict
output = {k: v.contiguous() for k, v in output.items()}

def id_func(output_tensor):
if isinstance(output_tensor, dict):
# handle loss of hidden transformations ("output" is the default output)
output_tensor = output_tensor["output"]
output_tensor = output_tensor[output_name]

return output_tensor, {output_name: output_tensor}
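
A toy, self-contained restatement of the branch above (tensors are placeholders): dict-valued outputs from a hiddens module get .contiguous() applied per value, after which a single named output is selected:

import torch

output = {"enc_output": torch.randn(2, 4, 8), "hiddens_loss": torch.randn(2)}
if torch.is_tensor(output):
    output = output.contiguous()
else:
    # hiddens modules return a dict of tensors
    output = {k: v.contiguous() for k, v in output.items()}

output_name = "enc_output"
output_tensor = output[output_name] if isinstance(output, dict) else output
assert output_tensor.shape == (2, 4, 8)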

@@ -779,14 +792,27 @@ def on_validation_epoch_end(self):
def on_test_epoch_end(self):
return self._test_validation_epoch_end(step_outputs=self.test_step_outputs, prefix="test",)

def loss_func(self, loss_mask, tokens_loss):
def loss_func(self, loss_mask, tokens_loss, token_aggregation_method: Optional[str] = None):
"""
This function takes as input per-token loss and masks non-required values.
"""
losses = tokens_loss.view(-1).float()
loss_mask = loss_mask.view(-1).float()
losses = tokens_loss.float() # Batch x Sequence
loss_mask = loss_mask.float() # Batch x Sequence
# TODO: add nemo version here
loss = torch.sum(losses * loss_mask) / loss_mask.sum() # sequence level nll
if token_aggregation_method is None:
loss = torch.sum(losses * loss_mask) / loss_mask.sum() # sequence level nll
elif token_aggregation_method == "mean":
# This variant will put equal weight on every element of the batch rather than increased weight on
# sequences that have longer token lengths.
sample_loss = torch.sum(losses * loss_mask, dim=1) / loss_mask.sum(dim=1).clamp(min=1)
loss = sample_loss.mean()
elif token_aggregation_method == "sum":
sample_loss = torch.sum(losses * loss_mask, dim=1)
loss = sample_loss.mean()
else:
raise ValueError(
f"token_aggregation_method={token_aggregation_method}, expect one of None, 'sum', 'mean'. "
)
return loss
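
A self-contained illustration of the three modes, using invented toy tensors with one full-length and one short sequence:

import torch

losses = torch.tensor([[1.0, 1.0, 1.0, 1.0], [2.0, 0.0, 0.0, 0.0]])     # [B x S]
loss_mask = torch.tensor([[1.0, 1.0, 1.0, 1.0], [1.0, 0.0, 0.0, 0.0]])  # [B x S]

# None: global token-level mean; longer sequences contribute more tokens.
global_mean = torch.sum(losses * loss_mask) / loss_mask.sum()  # 6/5 = 1.2

# "mean": per-sample mean first, so every sample weighs equally.
sample_mean = torch.sum(losses * loss_mask, dim=1) / loss_mask.sum(dim=1).clamp(min=1)
per_sample = sample_mean.mean()  # (1.0 + 2.0) / 2 = 1.5

# "sum": per-sample sum, then batch mean.
per_sum = torch.sum(losses * loss_mask, dim=1).mean()  # (4.0 + 2.0) / 2 = 3.0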

def process_micro_batch(self, micro_batch):
@@ -1068,7 +1094,7 @@ def dummy():
arg_names.append('enc_input')

forward_step_func = self._get_forward_output_only_func(
arg_names=arg_names, output_name="hiddens", output_enc_hidden_only=True
arg_names=arg_names, output_name="enc_output", output_enc_hidden_only=True
)

fwd_bwd_func = get_forward_backward_func()
@@ -1089,7 +1115,7 @@ def dummy():
)

if output_tensor:
output_tensor = output_tensor[0]['hiddens']
output_tensor = output_tensor[0]['enc_output']
else:
output_tensor = torch.zeros(tensor_shape, dtype=self.autocast_dtype).cuda()

@@ -1229,6 +1255,10 @@ def dummy():
if enc_output_attn_mask is None:
enc_output_attn_mask = enc_mask

# read these variables here; they are only used by beam search
batch_size, hidden_steps, hidden_size = enc_output.size()
src_length = enc_output_attn_mask.shape[1]

for i in range(num_tokens_to_generate):
# No microbatches in decoding. Just the global batch.
decoder_seq_length = predicted_tokens_dec.size(1)
@@ -1238,7 +1268,7 @@ def dummy():
batch_for_pipeline = [enc_output, enc_output_attn_mask, predicted_tokens_dec, dec_mask, batch_data]
arg_names = ['enc_output', 'enc_output_attn_mask', 'dec_input_ids', 'dec_attn_mask', 'batch_data']

forward_step_func = self._get_forward_output_only_func(arg_names=arg_names, output_name="logits")
forward_step_func = self._get_forward_output_only_func(arg_names=arg_names, output_name="token_logits")
fwd_bwd_func = get_forward_backward_func()

output_tensor = fwd_bwd_func(
@@ -1253,7 +1283,7 @@ def dummy():
)
# get output tensor
if parallel_state.is_pipeline_last_stage():
output_tensor = output_tensor[0]['logits']
output_tensor = output_tensor[0]['token_logits']
output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(output_tensor)
# make sure it won't sample outside the vocab_size range
output_tensor[:, :, tokenizer.vocab_size :] = -float('Inf')
@@ -1274,9 +1304,8 @@ def dummy():
log_probs, token_ids = log_probs.view(-1), token_ids.view(-1)
scores = log_probs.unsqueeze(1).clone()

batch_size, src_length, hidden_size = enc_output.size()
enc_output_attn_mask = enc_output_attn_mask.repeat(1, beam_size).view(-1, src_length)
enc_output = enc_output.repeat(1, beam_size, 1).view(-1, src_length, hidden_size)
enc_output = enc_output.repeat(1, beam_size, 1).view(-1, hidden_steps, hidden_size)

# resize tensors that collect predicted tokens and logits per iteration to
# match shape of tensors augmented with the beam size
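
A standalone shape sketch (all sizes invented) of why hidden_steps and src_length are now read separately: with a perceiver encoder the latent length can differ from the source length, so each tensor is tiled for the beam with its own length:

import torch

batch_size, hidden_steps, hidden_size = 2, 32, 512  # latent steps from enc_output
src_length, beam_size = 100, 4                      # source tokens, beam width

enc_output = torch.randn(batch_size, hidden_steps, hidden_size)
enc_output_attn_mask = torch.ones(batch_size, src_length)

enc_output_attn_mask = enc_output_attn_mask.repeat(1, beam_size).view(-1, src_length)
enc_output = enc_output.repeat(1, beam_size, 1).view(-1, hidden_steps, hidden_size)
assert enc_output.shape == (batch_size * beam_size, hidden_steps, hidden_size)
assert enc_output_attn_mask.shape == (batch_size * beam_size, src_length)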
@@ -25,10 +25,12 @@ class MegatronBaseHiddenLoss(torch.nn.Module):
Returned dict includes a loss value and additional outputs.
"""

def __init__(self, loss_weight=1.0, name=""):
def __init__(self, loss_weight=1.0, name="", hidden_aggregation_method: str = "mean"):
super().__init__()
self.name = name
self.loss_weight = float(loss_weight)
# determines how the hidden loss is aggregated over the hidden dimension
self.hidden_aggregation_method = hidden_aggregation_method

def __str__(self):
return super().__str__() + f"(name={self.name})"
@@ -78,10 +80,17 @@ def loss(self, inputs, batch_data=None):
hiddens_mask = inputs["hiddens_mask"].to(loss)
loss = loss * hiddens_mask
# sequence level loss [B x S] -> batch level loss [B]
loss = loss.sum(dim=1) / hiddens_mask.sum(dim=1).clamp(min=1.0)
if self.hidden_aggregation_method == "mean":
loss = loss.sum(dim=1) / hiddens_mask.sum(dim=1).clamp(min=1.0)
elif self.hidden_aggregation_method == "sum":
loss = loss.sum(dim=1)
else:
raise ValueError(
f"hidden_aggregation_method={self.hidden_aggregation_method} not recognized, support 'mean' or 'sum'"
)

# compute batch level weighted loss (scalar)
weighted_loss = loss.sum() * self.loss_weight
weighted_loss = loss.mean() * self.loss_weight

# store updated losses
loss_dict["loss"] = loss
@@ -98,13 +107,8 @@ class MegatronAMIMHiddenLoss(MegatronBaseHiddenLoss):
A-MIM - asymmetric MIM (without sampling)
"""

def __init__(self, loss_weight=1.0, hidden_aggregation_method="sum", name="mim"):
super().__init__(
name=name, loss_weight=loss_weight,
)

# allows to determine how to aggregate hidden loss over hidden dimension
self.hidden_aggregation_method = hidden_aggregation_method
def __init__(self, loss_weight=1.0, hidden_aggregation_method="mean", name="mim"):
super().__init__(name=name, loss_weight=loss_weight, hidden_aggregation_method=hidden_aggregation_method)

def _input_names(self):
"""Add here all required inputs"""
@@ -121,12 +125,12 @@ def _loss(self, inputs, batch_data=None):
"""
z = inputs["z"]
# get posterior
log_prob_q_z_given_x = inputs["z_log_prob"]
aggregator = getattr(torch, self.hidden_aggregation_method)
log_prob_q_z_given_x = aggregator(inputs["z_log_prob"], dim=-1) # [B x S x H] -> [B x S]
# compute log prob of anchor a unit Normal distribution
log_prob_P_z = -0.5 * (math.log(2 * math.pi) + z.pow(2))
# aggregate over hidden dimension (default: mean)
log_prob_P_z = getattr(log_prob_P_z, self.hidden_aggregation_method)(dim=-1)

log_prob_P_z = aggregator(log_prob_P_z, dim=-1)
# A-MIM loss = log_p_x_given_z - 0.5 * (log_prob_P_z + log_prob_q_z_given_x)
# here we return only the hidden loss part
loss = -0.5 * (log_prob_P_z + log_prob_q_z_given_x)
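
A small sketch of the aggregator lookup used above: getattr(torch, name) resolves "mean" to torch.mean and "sum" to torch.sum, both of which accept dim=:

import torch

z_log_prob = torch.randn(2, 5, 8)              # [B x S x H]
aggregator = getattr(torch, "mean")            # or getattr(torch, "sum")
per_position = aggregator(z_log_prob, dim=-1)  # [B x S]
assert per_position.shape == (2, 5)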
@@ -145,11 +149,8 @@ class MegatronVAEHiddenLoss(MegatronBaseHiddenLoss):
Implements VAE loss with a unit Normal anchor.
"""

def __init__(self, loss_weight=1.0, min_kl_value=None, name="vae"):
super().__init__(
name=name, loss_weight=loss_weight,
)

def __init__(self, loss_weight=1.0, min_kl_value=None, hidden_aggregation_method="mean", name="vae"):
super().__init__(name=name, loss_weight=loss_weight, hidden_aggregation_method=hidden_aggregation_method)
# minimum value for KL divergence
if min_kl_value is None:
self.min_kl_value = min_kl_value
@@ -171,9 +172,14 @@ def _loss(self, inputs, batch_data=None):
"""
z = inputs["z"]
# get posterior
log_prob_q_z_given_x = inputs["z_log_prob"]
aggregator = getattr(torch, self.hidden_aggregation_method)
log_prob_q_z_given_x = aggregator(inputs["z_log_prob"], dim=-1) # [B x S x H] -> [B x S]

# compute log prob of anchor a unit Normal distribution
log_prob_p_z = -0.5 * (math.log(2 * math.pi) + z.pow(2)).sum(dim=-1)
log_prob_p_z = -0.5 * (math.log(2 * math.pi) + z.pow(2))

# aggregate over hidden dimension (default: mean)
log_prob_p_z = aggregator(log_prob_p_z, dim=-1)

# VAE loss = log_p_x_given_z - KL(q(z|x) || p(z))
kl_div = log_prob_q_z_given_x - log_prob_p_z
@@ -168,12 +168,10 @@ def _transform(self, inputs, batch_data=None):
if z_log_prob is None:
# compute log probability of z under a diagonal Gaussian distribution
z_log_prob = -0.5 * (math.log(2 * math.pi) + z_logvar + (z - z_mean).pow(2) / z_logvar.exp())
# sum over the last dimension (hidden_size)
z_log_prob = z_log_prob.sum(dim=-1)

return {
"z": z, # [S x B x H]
"z_mean": z_mean, # [S x B x H]
"z_logvar": z_logvar, # [S x B x H]
"z_log_prob": z_log_prob, # [S x B]
"z_log_prob": z_log_prob, # [S x B x H]
}
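
A quick shape check (sizes invented) for this change: the diagonal-Gaussian log probability is now returned per hidden dimension, with aggregation deferred to the loss classes:

import math
import torch

S, B, H = 4, 2, 8
z_mean = torch.zeros(S, B, H)
z_logvar = torch.zeros(S, B, H)
z = z_mean + torch.randn(S, B, H) * torch.exp(0.5 * z_logvar)

z_log_prob = -0.5 * (math.log(2 * math.pi) + z_logvar + (z - z_mean).pow(2) / z_logvar.exp())
assert z_log_prob.shape == (S, B, H)  # previously aggregated to [S x B] here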
@@ -21,7 +21,7 @@

import functools
import itertools
from typing import List
from typing import List, Optional

import torch
from omegaconf.dictconfig import DictConfig
@@ -116,6 +116,8 @@ def get_hiddens_module(cfg=None, model_parallel_cfg: ModelParallelConfig = None)
if cfg is None:
return None

hidden_aggregation_method = cfg.get("hidden_aggregation_method", "mean")

logging.info(f"NOTE: Adding hiddens transforms and losses")

# build all hidden transforms. We support a list or a dictionary of transforms (list enforces order)
@@ -155,6 +157,8 @@ def get_hiddens_module(cfg=None, model_parallel_cfg: ModelParallelConfig = None)
for cur_list_cfg in loss_cfg:
for name, cur_cfg in cur_list_cfg.items():
cls_kwargs = OmegaConf.to_container(cur_cfg)
this_hidden_aggregation_method = cls_kwargs.get("hidden_aggregation_method", hidden_aggregation_method)
cls_kwargs["hidden_aggregation_method"] = this_hidden_aggregation_method
if not "cls_name" in cls_kwargs:
raise KeyError(f"Missing 'cls_name' in hidden loss {name}")

@@ -173,11 +177,15 @@ def get_hiddens_module(cfg=None, model_parallel_cfg: ModelParallelConfig = None)
logging.info(f"Added loss {name} with cfg={cur_cfg}")

enc_output_name = cfg.get("enc_output_name", "hiddens")
enc_inference_output_name = cfg.get("enc_inference_output_name", None)
tokens_loss_weight = cfg.get("tokens_loss_weight", 1.0)

return MegatronHiddensModule(
hidden_transforms=hidden_transforms,
hidden_loss_transforms=hidden_loss_transforms,
enc_output_name=enc_output_name,
enc_inference_output_name=enc_inference_output_name,
tokens_loss_weight=tokens_loss_weight,
)
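
For illustration, a hedged sketch of the per-loss fallback wired in above; the loss entry and its cls_name are hypothetical:

hidden_aggregation_method = "mean"  # top-level cfg value

# a loss entry that does not set its own method inherits the top-level one
cls_kwargs = {"cls_name": "a_mim", "loss_weight": 1.0}
cls_kwargs["hidden_aggregation_method"] = cls_kwargs.get(
    "hidden_aggregation_method", hidden_aggregation_method
)
assert cls_kwargs["hidden_aggregation_method"] == "mean"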


@@ -192,13 +200,17 @@ def __init__(
hidden_transforms: List[MegatronBaseHiddenTransform] = [],
hidden_loss_transforms: List[MegatronBaseHiddenLoss] = [],
enc_output_name: str = "hiddens", # name (key) of the encoder output
enc_inference_output_name: Optional[str] = None, # if provided, use different key when self.training is False
tokens_loss_weight: float = 1.0, # weight of the tokens loss
loss_prefix: str = "hiddens_", # if not None or "", add this prefix to all loss names
):
super().__init__()
self.hidden_transforms = hidden_transforms
self.hidden_loss_transforms = hidden_loss_transforms
self.enc_output_name = enc_output_name
self.enc_inference_output_name = (
enc_output_name if enc_inference_output_name is None else enc_inference_output_name
)
self.tokens_loss_weight = tokens_loss_weight
self.loss_prefix = loss_prefix

@@ -276,9 +288,11 @@ def apply_hidden_transforms(self, inputs, batch_data=None):
# make sure to collect all outputs from hidden transforms
outputs.update(hidden_transform.transform(outputs, batch_data=batch_data))

# update final encoder output
outputs["enc_output"] = outputs[self.enc_output_name]

# update final encoder output; use the training or inference output name so that, e.g., VAE-style hiddens can return z during training and z_mean at inference
if self.training:
outputs["enc_output"] = outputs[self.enc_output_name]
else:
outputs["enc_output"] = outputs[self.enc_inference_output_name]
return outputs
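
A runnable toy mimic of the new switch (the class below is a stand-in, not MegatronHiddensModule itself): enc_output resolves to z in training mode and to z_mean in eval mode:

import torch


class ToyHiddens(torch.nn.Module):
    """Stand-in reproducing only the output-name switch (illustrative)."""

    def __init__(self, enc_output_name="z", enc_inference_output_name="z_mean"):
        super().__init__()
        self.enc_output_name = enc_output_name
        self.enc_inference_output_name = (
            enc_output_name if enc_inference_output_name is None else enc_inference_output_name
        )

    def apply_hidden_transforms(self, outputs):
        key = self.enc_output_name if self.training else self.enc_inference_output_name
        outputs["enc_output"] = outputs[key]
        return outputs


m = ToyHiddens()
outputs = {"z": torch.randn(4, 2, 8), "z_mean": torch.zeros(4, 2, 8)}
m.train()
assert m.apply_hidden_transforms(dict(outputs))["enc_output"].equal(outputs["z"])
m.eval()
assert m.apply_hidden_transforms(dict(outputs))["enc_output"].equal(outputs["z_mean"])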

def apply_loss_transforms(self, outputs, batch_data=None):
@@ -217,7 +217,7 @@ def forward(
dec_input=dec_input,
dec_attn_mask=dec_attn_mask,
enc_output=enc_output["enc_output"] # enc_output is a dict if we used hidden transformations
if self.hiddens_module is not None
if self.hiddens_module is not None and not torch.is_tensor(enc_output)
else enc_output,
# Adjust encoder attention mask if encoder is a perceiver.
enc_attn_mask=self.get_hiddens_mask(enc_attn_mask),
@@ -669,6 +669,9 @@ def forward(
# [s, b, h] -> [b, s, h]
token_logits = token_logits.transpose(0, 1).contiguous()
if self.hiddens_cfg is not None:
# we support enc_output being a tensor even if hiddens_module is used
if torch.is_tensor(enc_output):
enc_output = {"enc_output": enc_output}
# return all hiddens and token logits
hiddens_dict = enc_output
hiddens_dict["token_logits"] = token_logits
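
A toy shape sketch of this branch (sizes invented): a bare-tensor enc_output is wrapped into the dict form before the token logits are attached:

import torch

token_logits = torch.randn(4, 2, 128).transpose(0, 1).contiguous()  # [s, b, v] -> [b, s, v]
enc_output = torch.randn(4, 2, 64)  # bare tensor from the encoder

if torch.is_tensor(enc_output):
    enc_output = {"enc_output": enc_output}
hiddens_dict = enc_output
hiddens_dict["token_logits"] = token_logits
assert set(hiddens_dict) == {"enc_output", "token_logits"}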
