
feat: per-task validation metrics in GRPO/Distillation, optional max_val_samples #2499

Closed

bzantium wants to merge 2 commits into NVIDIA-NeMo:main from bzantium:feat/per-task-validation-metrics


Conversation


@bzantium bzantium commented May 15, 2026

What does this PR do ?

Closes #2497

Two related improvements to the validate() function in GRPO and Distillation, bundled into one PR because they touch the same function in the same way.

1. Per-task validation metrics

When data.validation is configured as a list of multiple datasets, the multi-validation path correctly loads them all and dispatches each sample to the right environment by task during rollout, but the validation aggregator collapses everything into a single sample-weighted accuracy and avg_length. Per-task progress (e.g. gsm8k vs math500) is silently lost.

task_name is already on every sample (DatumSpec.task_name, preserved through rl_collate_fn into val_batch["task_name"]); validate() simply did not read it. This PR teaches both validate() functions to track rewards per task during the loop and emit:

  • accuracy_<task> for each task seen
  • num_samples_<task> for each task seen

The aggregated accuracy key is preserved unchanged so existing dashboards continue to work. Single-task runs and legacy datasets without task_name are unaffected (the per-task block is skipped).
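Roughly, the accumulation looks like the sketch below. It is illustrative only: run_rollout_and_score stands in for the existing rollout-and-reward path, and val_metrics for the dict that validate() already returns.

from collections import defaultdict

# Illustrative sketch, not the exact diff.
per_task_rewards: dict[str, list[float]] = defaultdict(list)

for val_batch in val_dataloader:
    rewards = run_rollout_and_score(val_batch)  # existing rollout + reward computation
    task_names = val_batch.get("task_name")
    if task_names is None:
        continue  # legacy dataset without task_name: the per-task block is skipped
    for task, reward in zip(task_names, rewards):
        per_task_rewards[task].append(float(reward))

# After the loop, alongside the existing aggregated accuracy:
for task, task_rewards in per_task_rewards.items():
    val_metrics[f"accuracy_{task}"] = sum(task_rewards) / len(task_rewards)
    val_metrics[f"num_samples_{task}"] = len(task_rewards)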

The driver-log summary also gains a per-task block:

📊 Validation Results:
    • Accuracy: 0.5957
    • Average response length: 512.8 tokens
    • Samples processed: 1819
    • Per-task accuracy:
        - data-math500: 0.4320 (n=500)
        - gsm8k: 0.6580 (n=1319)
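Continuing the sketch above, the per-task lines in that summary can be produced by something as simple as the following (formatting illustrative):

if per_task_rewards:
    print("    • Per-task accuracy:")
    for task in sorted(per_task_rewards):
        acc = val_metrics[f"accuracy_{task}"]
        n = val_metrics[f"num_samples_{task}"]
        print(f"        - {task}: {acc:.4f} (n={n})")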

2. max_val_samples becomes optional, Distillation truncation matches GRPO

Bundled because the patches sit a few lines apart in the same validate() block.

  • grpo.GRPOConfig.max_val_samples was already typed as int | None (grpo.py:150) for NeMo-Gym compatibility, but the main path crashed on None. Distillation's TypedDict required int outright. Both now accept None/absent and fall back to the full val_dataloader.
  • Distillation's truncation switches from ceiling division to floor division, matching GRPO. The behaviour change is bounded to runs whose max_val_samples is not divisible by val_batch_size; all shipped recipes under examples/configs/recipes/llm/ use values that divide cleanly, so none are affected. A sketch of the combined read-site behaviour follows this list.
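A minimal sketch of the intended read-site behaviour, assuming illustrative names (master_config, val_batch_size, val_dataloader) that stand in for the real config lookup and loop:

max_val_samples = master_config["grpo"].get("max_val_samples")  # None or absent is now allowed

if max_val_samples is None:
    max_batches = None  # iterate the entire val_dataloader
else:
    max_batches = max_val_samples // val_batch_size  # floor division in both algorithms

for batch_idx, val_batch in enumerate(val_dataloader):
    if max_batches is not None and batch_idx >= max_batches:
        break
    # ... existing per-batch validation work ...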

Files touched

  • nemo_rl/algorithms/grpo.py: per-task tracking in validate(), optional max_val_samples branch, per-task block in the summary print.
  • nemo_rl/algorithms/distillation.py: same per-task tracking; TypedDict widened to NotRequired[int]; truncation switched to floor division.
  • tests/unit/algorithms/test_grpo.py: test_validate_emits_per_task_accuracy_keys, test_validate_iterates_full_dataloader_when_max_val_samples_is_none.
  • tests/unit/algorithms/test_distillation.py: same, plus test_validate_floor_divides_max_val_samples_by_val_batch_size to guard GRPO/Distillation parity.

Out of scope

  • DPO already emits per-dataset metrics via its dict[str, StatefulDataLoader] architecture (see nemo_rl/algorithms/dpo.py:332-377, prefix validation-{dataset_name}). No change needed.
  • NeMo-Gym path in GRPO is left untouched. It clobbers val_batch_size to len(val_dataset) at setup time, which is specific to its single-batch eval and not something the main paths should adopt.
  • Exemplar YAMLs (examples/configs/grpo_math_1B.yaml, examples/configs/distillation_math.yaml) keep their explicit max_val_samples so the recommended default stays documented.

Backwards compatibility

  • Aggregated validation/accuracy key unchanged in value.
  • Single-task validation: per_task_rewards ends up with one key, so you get one extra validation/accuracy_<the-only-task> metric. Harmless.
  • Datasets without task_name: per-task block is skipped, behaviour matches the old code.

Issues

Closes #2497.

Usage

A recipe with multiple validation datasets now reports per-task metrics:

data:
  validation:
    - dataset_name: gsm8k
      split: test
    - dataset_name: ResponseDataset
      data_path: data/math500.parquet
distillation:
  val_batch_size: 8
  # max_val_samples omitted -> evaluate the entire val dataset

wandb will plot validation/accuracy, validation/accuracy_gsm8k, and validation/accuracy_ResponseDataset separately. The existing aggregated validation/accuracy panel keeps working unchanged.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • Followed the config-conventions skill: optional max_val_samples field expressed via NotRequired, no hidden defaults, exemplar YAMLs keep their explicit recommended values.
  • Two commits, one per logical change, to make the per-task and the max_val_samples patches reviewable independently.
  • Open question for reviewers: should task-name slugification be added (e.g. data-math500 → data_math500)? Hyphens work in wandb but render slightly oddly in some downstream stores. The current patch keeps the name verbatim; happy to slugify if preferred (a one-line sketch follows this list).
  • Branched off origin/main in a separate worktree to keep the change isolated from #2496 ("fix: pass trust_remote_code=True to remaining AutoConfig.from_pretrained sites", branch fix/autoconfig-trust-remote-code); the two PRs touch disjoint files.
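For reference, the slugification mentioned above could be as simple as the following hypothetical helper (not part of this patch):

import re

def slugify_task_name(task_name: str) -> str:
    # Replace anything outside [A-Za-z0-9_], e.g. "data-math500" -> "data_math500"
    return re.sub(r"[^A-Za-z0-9_]", "_", task_name)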

Two related changes to the validation truncation logic.

1. Make max_val_samples optional. When the field is absent or set to None
   in the recipe, validate() now iterates the entire val_dataloader.
   * GRPO already typed it as `int | None  # None for NeMo-Gym
     compatibility` but the main validation path crashed when reading
     None. Patch the read site so the main path matches the type.
   * Distillation widens the TypedDict from `int` to `NotRequired[int]`
     and applies the same read-site change.
   The exemplar YAMLs (examples/configs/grpo_math_1B.yaml and
   examples/configs/distillation_math.yaml) keep their explicit values
   so the recommended default is still documented.

2. Unify Distillation truncation with GRPO. GRPO uses floor division
   (max_val_samples // val_batch_size); Distillation used ceiling
   division ((max_val_samples + val_batch_size - 1) // val_batch_size).
   With the new None-handling branch already in place, switch
   Distillation to floor division so the two algorithms behave
   identically when the field is set.

Behaviour impact for existing recipes: only Distillation runs whose
max_val_samples is not divisible by val_batch_size are affected; they
now evaluate one fewer batch than before. Recipes in
examples/configs/recipes/llm all use values that divide cleanly
(256/8, 512/8, etc.), so no recipe under examples/ is affected.
Recipes that previously set an integer that divides cleanly remain
identical; recipes that previously omitted the field could not run at
all and now do.
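
As a worked example with hypothetical values max_val_samples=100 and
val_batch_size=8 (chosen only to illustrate the non-divisible case):

  floor:   100 // 8           = 12 batches
  ceiling: (100 + 8 - 1) // 8 = 13 batches

so the switch drops one batch in that case, while divisible pairs such
as 256/8 give 32 batches under either rule.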

Tests:
* tests/unit/algorithms/test_grpo.py adds
  test_validate_iterates_full_dataloader_when_max_val_samples_is_none
* tests/unit/algorithms/test_distillation.py adds the same plus
  test_validate_floor_divides_max_val_samples_by_val_batch_size to
  guard the GRPO/Distillation parity.

Signed-off-by: Minho Ryu <ryumin93@gmail.com>
@bzantium bzantium requested review from a team as code owners May 15, 2026 05:12

copy-pr-bot Bot commented May 15, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Multi-validation (data.validation as a list of datasets) currently runs
correctly but the validation aggregator collapses everything into a
single sample-weighted accuracy. Per-task progress (e.g. gsm8k vs
math500) is silently lost.

task_name is already on every sample (DatumSpec.task_name preserved
through rl_collate_fn into val_batch["task_name"]); validate() simply
did not read it.

This commit teaches both validate() functions to track rewards per
task during the loop, then emit accuracy_<task> and num_samples_<task>
keys alongside the existing aggregated accuracy. logger.log_metrics
plots each as its own metric automatically.

The aggregated accuracy key is preserved unchanged for dashboard
backwards compatibility. For datasets without task_name the per-task
block is skipped, so single-task and legacy recipes behave identically.

DPO already does per-dataset metrics via its dict-of-dataloaders
architecture (see dpo.validate at nemo_rl/algorithms/dpo.py:332-377),
so it is not touched here.

Tests:
* test_grpo.py adds test_validate_emits_per_task_accuracy_keys.
* test_distillation.py adds the same plus a check that the
  aggregated accuracy key matches the sample-weighted mean across
  tasks.

Signed-off-by: Minho Ryu <ryumin93@gmail.com>
bzantium (Author) commented

FYI for reviewers: filed #2500 as a follow-up RFC for the larger architectural alignment with DPO's dict-of-dataloaders pattern. This PR (Option A) is the non-breaking quick fix; #2500 is the discussion thread for whether to do the proper refactor afterwards. The two are independent — this PR can land on its own merits.

@bzantium bzantium closed this May 15, 2026