
Commit 8493c96

Support a training key and an optionally separate key for inference in hiddens modules

Signed-off-by: John St John <jstjohn@nvidia.com>
jstjohn committed Feb 21, 2024
1 parent 81f56ea commit 8493c96
Showing 3 changed files with 13 additions and 7 deletions.
```diff
@@ -1068,7 +1068,7 @@ def dummy():
             arg_names.append('enc_input')
 
         forward_step_func = self._get_forward_output_only_func(
-            arg_names=arg_names, output_name="hiddens", output_enc_hidden_only=True
+            arg_names=arg_names, output_name="enc_output", output_enc_hidden_only=True
         )
 
         fwd_bwd_func = get_forward_backward_func()
@@ -1089,7 +1089,7 @@ def dummy():
         )
 
         if output_tensor:
-            output_tensor = output_tensor[0]['hiddens']
+            output_tensor = output_tensor[0]['enc_output']
         else:
             output_tensor = torch.zeros(tensor_shape, dtype=self.autocast_dtype).cuda()
 
```
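This hunk is the consumer side of the change: the encode path now reads the canonical `'enc_output'` key, which `apply_hidden_transforms` (changed below) fills with a mode-dependent tensor. A minimal sketch of the returned dictionary; every key except `'enc_output'` is an assumption for a VAE-style setup, not taken from this diff:

```python
import torch

raw = torch.randn(2, 4)     # untransformed encoder states
z = torch.randn(2, 4)       # stochastic sample, used while training
z_mean = torch.zeros(2, 4)  # deterministic mean, used at inference

# What the hiddens module might return in eval mode (keys other than
# "enc_output" are assumptions for a VAE-style setup, not from this diff):
output_tensor = [{"hiddens": raw, "z": z, "z_mean": z_mean, "enc_output": z_mean}]

# Reading the canonical key picks up whichever tensor the train/eval switch
# placed there, instead of always getting the raw hidden states:
enc_output = output_tensor[0]["enc_output"]
assert torch.equal(enc_output, z_mean)
```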
```diff
@@ -81,7 +81,7 @@ def loss(self, inputs, batch_data=None):
         loss = loss.sum(dim=1) / hiddens_mask.sum(dim=1).clamp(min=1.0)
 
         # compute batch level weighted loss (scalar)
-        weighted_loss = loss.sum() * self.loss_weight
+        weighted_loss = loss.mean() * self.loss_weight
 
         # store updated losses
         loss_dict["loss"] = loss
```
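Switching the batch reduction from `sum()` to `mean()` keeps the weighted scalar on the same scale regardless of batch size, so `loss_weight` carries a consistent meaning. An illustrative sketch with made-up values, not NeMo code:

```python
import torch

# Per-example losses for two batch sizes; values chosen for illustration.
loss_small = torch.tensor([0.5, 0.7])            # batch of 2
loss_large = torch.tensor([0.5, 0.7, 0.5, 0.7])  # batch of 4
loss_weight = 2.0

# Before: .sum() grows linearly with batch size.
print(loss_small.sum() * loss_weight)   # tensor(2.4000)
print(loss_large.sum() * loss_weight)   # tensor(4.8000)

# After: .mean() keeps the weighted loss comparable across batch sizes.
print(loss_small.mean() * loss_weight)  # tensor(1.2000)
print(loss_large.mean() * loss_weight)  # tensor(1.2000)
```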
```diff
@@ -21,7 +21,7 @@
 
 import functools
 import itertools
-from typing import List
+from typing import List, Optional
 
 import torch
 from omegaconf.dictconfig import DictConfig
@@ -192,13 +192,17 @@ def __init__(
         hidden_transforms: List[MegatronBaseHiddenLoss] = [],
         hidden_loss_transforms: List[MegatronBaseHiddenTransform] = [],
         enc_output_name: str = "hiddens",  # name (key) of the encoder output
+        enc_inference_output_name: Optional[str] = None,  # if provided, use a different key when self.training is False
         tokens_loss_weight: float = 1.0,  # weight of the tokens loss
         loss_prefix: str = "hiddens_",  # if not None or "", add this prefix to all loss names
     ):
         super().__init__()
         self.hidden_transforms = hidden_transforms
         self.hidden_loss_transforms = hidden_loss_transforms
         self.enc_output_name = enc_output_name
+        self.enc_inference_output_name = (
+            enc_output_name if enc_inference_output_name is None else enc_inference_output_name
+        )
         self.tokens_loss_weight = tokens_loss_weight
         self.loss_prefix = loss_prefix
 
```
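When `enc_inference_output_name` is left as `None`, the constructor falls back to `enc_output_name`, so existing configs keep their old behavior. A hypothetical wiring for a VAE-style latent follows; the transform instances and the `"z"`/`"z_mean"` key names are stand-ins, not taken from this diff:

```python
# Hypothetical usage of the new constructor argument. The transform objects
# below are placeholders for ones that emit a sampled latent "z" and its
# mean "z_mean"; those key names are assumptions for a VAE-style setup.
hiddens_module = MegatronHiddensModule(
    hidden_transforms=[gaussian_hidden_transform],   # placeholder transform
    hidden_loss_transforms=[vae_kl_loss_transform],  # placeholder loss transform
    enc_output_name="z",                 # stochastic sample consumed in training
    enc_inference_output_name="z_mean",  # deterministic mean consumed in eval
)
# Omitting enc_inference_output_name reproduces the old behavior:
# both modes read enc_output_name.
```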
```diff
@@ -276,9 +280,11 @@ def apply_hidden_transforms(self, inputs, batch_data=None):
             # make sure to collect all outputs from hidden transforms
             outputs.update(hidden_transform.transform(outputs, batch_data=batch_data))
 
-        # update final encoder output
-        outputs["enc_output"] = outputs[self.enc_output_name]
-
+        # update final encoder output. Split into output_name / inference output name to support, e.g., z vs z_mean with VAE-style hiddens
+        if self.training:
+            outputs["enc_output"] = outputs[self.enc_output_name]
+        else:
+            outputs["enc_output"] = outputs[self.enc_inference_output_name]
         return outputs
 
     def apply_loss_transforms(self, outputs, batch_data=None):
```
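The dispatch piggybacks on `torch.nn.Module.training`, so the usual `.train()` / `.eval()` calls select the key with no extra plumbing. A self-contained sketch of the same pattern; the class and key names here are illustrative:

```python
import torch
from torch import nn

class KeySwitch(nn.Module):
    """Minimal stand-in for the mode-dependent key selection shown above."""

    def __init__(self, train_key: str = "z", inference_key: str = "z_mean"):
        super().__init__()
        self.train_key = train_key
        self.inference_key = inference_key

    def forward(self, outputs: dict) -> dict:
        # nn.Module.training flips with .train() / .eval(); the diff relies
        # on exactly this flag via self.training.
        key = self.train_key if self.training else self.inference_key
        outputs["enc_output"] = outputs[key]
        return outputs

m = KeySwitch()
outs = {"z": torch.ones(2, 4), "z_mean": torch.zeros(2, 4)}
print(m(outs)["enc_output"][0, 0].item())  # 1.0 -> training mode picked "z"
m.eval()
print(m(outs)["enc_output"][0, 0].item())  # 0.0 -> eval mode picked "z_mean"
```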
