Skip to content

Commit

Permalink
fix ptl issue #18803
Browse files Browse the repository at this point in the history
Signed-off-by: stevehuang52 <heh@nvidia.com>
  • Loading branch information
stevehuang52 committed Dec 27, 2023
1 parent f97c901 commit fcc0f9f
Showing 1 changed file with 4 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -1169,7 +1169,7 @@ def inference_epoch_end(self, outputs, mode, data_cfg):
if metric_name == 'bleu':
metric_result = torch.Tensor(
[sacrebleu.corpus_bleu(deduplicated_outputs['preds'], [labels]).score]
)
).to(self.device)
else:
for pred, label in zip(deduplicated_outputs['preds'], labels):
_ = metric_fn(pred, label)
Expand Down Expand Up @@ -1212,6 +1212,9 @@ def inference_epoch_end(self, outputs, mode, data_cfg):
# Logging of the averaged metrics:
averaged_loss = sum(averaged_loss) / len(averaged_loss)
averaged_metric = sum(averaged_metric) / len(averaged_metric) if len(averaged_metric) > 0 else None
averaged_loss = averaged_loss.to(self.device)
if averaged_metric is not None:
averaged_metric = averaged_metric.to(self.device)

# Handle case where metrics can be nan or inf. This can break checkpoint save/load.
if averaged_metric is not None and (torch.isinf(averaged_metric) or torch.isnan(averaged_metric)):
Expand Down

0 comments on commit fcc0f9f

Please sign in to comment.