
fix: guard against zero label tokens causing NaN loss in VLM training #1985

Merged
HuiyingLi merged 1 commit into NVIDIA-NeMo:main from khazic:fix/vlm-nan-sdpa-empty-supervision
Apr 23, 2026

@khazic khazic commented Apr 22, 2026

What does this PR do?

Add a defensive guard against division by zero in `MaskedCrossEntropy` when `num_label_tokens=0`.

Changelog

  • `masked_ce.py`: return `loss * 0.0` instead of dividing by zero when `num_label_tokens=0`
  • `finetune.py`: guard PP reporting loss normalization against `num_label_tokens=0`
  • `test_masked_ce.py`: add regression test for the empty-supervision case

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?

Additional Information

Related to #1883.

I attempted to reproduce the SDPA NaN described in #1883 using the exact repro command on transformers 5.5.0 / PyTorch 2.11.0 (CUDA 13.1) / 8xH100, but could not reproduce it. The issue was filed against `transformers==5.5.0.dev0`, and I believe the underlying SDPA masking bug has since been fixed in the stable 5.5.0 release.

While investigating, I noticed that `MaskedCrossEntropy` has no guard for `num_label_tokens=0`, which produces `NaN` via division by zero. In the multi-GPU training path, `num_label_tokens` is all-reduced across DP ranks, so hitting zero in practice would require every sample across every rank to have no valid labels simultaneously -- extremely unlikely. However, two paths are genuinely exposed:

  1. Validation loop: `num_label_tokens` is computed per-batch with no all-reduce. If a single validation batch has all labels set to `-100` (e.g. due to label-building failure, as already noted in the `collate_fns.py` comment: "may produce nan loss"), `MaskedCrossEntropy` produces NaN, which propagates into the reported validation loss.
  2. Single-GPU training: the all-reduce is a no-op (world size = 1), so a single bad sample is enough to trigger it.

This PR adds a minimal defensive guard so that an empty-supervision batch contributes zero loss instead of NaN, keeping training and validation metrics clean.
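The guard described above can be sketched as follows. This is a pure-Python stand-in written for illustration, not the code in `masked_ce.py`: the function name `masked_mean_loss` and its signature are hypothetical, and plain lists replace tensors. In PyTorch a tensor `0 / 0` yields `nan` silently, whereas plain Python would raise `ZeroDivisionError`, so only the guarded path is exercised here.

```python
IGNORE_INDEX = -100  # label value that marks unsupervised positions


def masked_mean_loss(per_token_loss, labels, num_label_tokens=None):
    """Average per-token losses over supervised positions only.

    Hypothetical helper for illustration; not the NeMo implementation.
    """
    # Sum the loss only where the label is supervised.
    loss_sum = sum(l for l, y in zip(per_token_loss, labels) if y != IGNORE_INDEX)
    if num_label_tokens is None:
        num_label_tokens = sum(1 for y in labels if y != IGNORE_INDEX)
    # Guard: an empty-supervision batch contributes zero loss instead of 0 / 0.
    if num_label_tokens == 0:
        return loss_sum * 0.0
    return loss_sum / num_label_tokens


print(masked_mean_loss([0.5, 1.5, 2.0], [-100, 7, 9]))  # 1.75
print(masked_mean_loss([0.5, 1.5, 2.0], [-100] * 3))    # 0.0
```

Multiplying by `0.0` rather than returning a fresh constant mirrors the `loss * 0.0` wording in the changelog; with tensors, that form presumably keeps the result attached to the autograd graph.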

When all labels in a batch are -100 (empty supervision), num_label_tokens
is 0, causing division by zero and NaN loss that corrupts training.

- masked_ce.py: return 0.0 instead of dividing by zero
- finetune.py: guard PP reporting_loss normalization against zero
- test_masked_ce.py: add regression test for the empty-supervision case

Signed-off-by: khazic <khazzz1c@gmail.com>

copy-pr-bot Bot commented Apr 22, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.


khazic commented Apr 22, 2026

Additional reproduction note

While the trigger condition is unlikely with the default `model_max_length=131072`, we confirmed the bug fires in practice under high-resolution multi-image inputs combined with a shorter `max_length` setting.

In that scenario, the image tokens occupy most of the sequence budget, and right-truncation removes the assistant response entirely. The resulting sample has all labels set to `-100`, so `num_label_tokens=0` in the validation loop, which computes the count per batch with no all-reduce. `MaskedCrossEntropy` then divides by zero, producing `loss: tensor(nan)`, and the NaN propagates into the reported validation loss via `total_loss += local_loss.item() * num_label_tokens` (`nan * 0` is `nan` in Python float arithmetic).
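The truncation scenario can be illustrated with a small sketch. The helper names (`build_labels`, `count_label_tokens`) and the token counts are hypothetical, chosen only to show how right-truncation can leave a sample with no supervised positions; the real label-building code in the collate functions is not reproduced here.

```python
IGNORE_INDEX = -100  # label value for positions excluded from the loss


def build_labels(num_image_tokens, num_prompt_tokens, response_ids, max_length):
    """Build a label sequence and right-truncate it to max_length."""
    # Image and prompt positions are never supervised; only the response is.
    labels = [IGNORE_INDEX] * (num_image_tokens + num_prompt_tokens) + list(response_ids)
    return labels[:max_length]  # right-truncation drops the tail first


def count_label_tokens(labels):
    """Count supervised positions, i.e. the num_label_tokens of one sample."""
    return sum(1 for y in labels if y != IGNORE_INDEX)


roomy = build_labels(4096, 32, [11, 12, 13], max_length=8192)
tight = build_labels(4096, 32, [11, 12, 13], max_length=4096)
print(count_label_tokens(roomy))  # 3
print(count_label_tokens(tight))  # 0
```

Because the response sits at the end of the sequence, it is the first thing a right-truncation removes once the image tokens alone fill the budget.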

Multi-GPU training is not affected in practice because `num_label_tokens` is all-reduced across DP ranks there, so a single bad sample cannot zero out the global count. The exposed paths are the validation loop (per-batch count, no all-reduce) and single-GPU training (all-reduce is a no-op).
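A short sketch of why one empty batch is enough to poison the validation metric: weighting a NaN loss by a token count of zero does not cancel it. The accumulation line follows the `total_loss += local_loss.item() * num_label_tokens` expression quoted above; the function name `weighted_val_loss` and the final division by the total token count are assumptions made for illustration.

```python
import math

nan = float("nan")
print(nan * 0)  # nan: multiplying by zero does not clear a NaN


def weighted_val_loss(batches):
    """Token-weighted mean over (local_loss, num_label_tokens) pairs.

    Hypothetical stand-in for a validation loop; not the NeMo code.
    """
    total_loss, total_tokens = 0.0, 0
    for local_loss, num_label_tokens in batches:
        total_loss += local_loss * num_label_tokens
        total_tokens += num_label_tokens
    return total_loss / total_tokens if total_tokens > 0 else 0.0


# Middle batch has no supervised tokens: NaN without the guard, 0.0 with it.
unguarded = [(2.0, 10), (nan, 0), (1.0, 30)]
guarded = [(2.0, 10), (0.0, 0), (1.0, 30)]
print(weighted_val_loss(unguarded))  # nan
print(weighted_val_loss(guarded))    # 1.25
```

With the guarded loss, the empty batch adds zero to both the numerator and the token count, so the reported mean reflects only the batches that carried supervision.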

@HuiyingLi

/ok to test 4011752


@athitten athitten left a comment

LGTM, TYSM @khazic !

@HuiyingLi HuiyingLi enabled auto-merge (squash) April 22, 2026 18:48
@svcnvidia-nemo-ci svcnvidia-nemo-ci added the waiting-on-customer label (Waiting on the original author to respond) Apr 22, 2026
@HuiyingLi HuiyingLi disabled auto-merge April 23, 2026 00:05
@HuiyingLi HuiyingLi merged commit 5876c8a into NVIDIA-NeMo:main Apr 23, 2026
58 of 59 checks passed
@HuiyingLi HuiyingLi mentioned this pull request Apr 23, 2026
4 tasks
HuiyingLi added a commit that referenced this pull request Apr 23, 2026
Adds two regression tests for _run_train_optim_step with pp_enabled=True,
covering the num_label_tokens=0 guard added in #1985 (finetune.py:1142)
and the standard num_label_tokens>0 division branch. Neither branch had
prior coverage since no existing test exercised _run_train_optim_step
with pipeline parallelism enabled.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@svcnvidia-nemo-ci svcnvidia-nemo-ci removed the waiting-on-customer label (Waiting on the original author to respond) Apr 23, 2026
HuiyingLi added a commit that referenced this pull request Apr 23, 2026
test: cover PP reporting loss guard for zero label tokens

linnanwang pushed a commit that referenced this pull request Apr 24, 2026
…#1985)

linnanwang pushed a commit that referenced this pull request Apr 24, 2026
test: cover PP reporting loss guard for zero label tokens
