fix: batch_size must be multiple of num_generations, pad dataset #244

Merged
abrichr merged 1 commit into main from fix/trl-batch-size-divisibility on Mar 29, 2026

@abrichr abrichr commented Mar 29, 2026

Summary

Supersedes #240. TRL requires generation_batch_size % num_generations == 0. The previous fix (batch_size=1) violated this with num_generations=4, since 1 % 4 != 0.
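For reviewers, the constraint can be illustrated with a minimal sketch in plain Python. It does not call TRL; `check_divisibility` is a hypothetical helper name, and the sketch assumes generation_batch_size equals the per-device batch size, as in the client's configuration described in this PR.

```python
def check_divisibility(generation_batch_size: int, num_generations: int) -> bool:
    """Return True when generation_batch_size % num_generations == 0.

    Hypothetical helper for illustration only; it mirrors the constraint
    quoted in this PR and is not part of TRL.
    """
    if num_generations <= 0:
        raise ValueError("num_generations must be positive")
    return generation_batch_size % num_generations == 0


# Previous fix (#240): batch_size=1 with num_generations=4.
previous_fix_ok = check_divisibility(1, 4)

# This PR: batch_size=4 with num_generations=4.
this_fix_ok = check_divisibility(4, 4)
```

Here `previous_fix_ok` is `False` and `this_fix_ok` is `True`, which matches the behavior reported above.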

Fix:

  1. Set per_device_train_batch_size = num_generations (minimum valid value)
  2. Pad dataset by repeating tasks if len(dataset) < batch_size
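The two steps above might look roughly like the following sketch. It is not the code merged in this PR: `pad_tasks` and `resolve_batch_config` are hypothetical names, tasks are modeled as a plain list, and repeating tasks in their original order is an assumption about how the padding is done.

```python
from itertools import cycle, islice


def pad_tasks(tasks: list, batch_size: int) -> list:
    """Repeat tasks in order until there are at least batch_size rows.

    Hypothetical helper; lists that are already long enough are returned
    unchanged (as a copy).
    """
    if not tasks:
        raise ValueError("cannot pad an empty task list")
    if len(tasks) >= batch_size:
        return list(tasks)
    return list(islice(cycle(tasks), batch_size))


def resolve_batch_config(tasks: list, num_generations: int) -> tuple:
    """Pick the minimum valid batch size and pad the tasks to match it."""
    # Step 1: the smallest batch size divisible by num_generations is
    # num_generations itself.
    batch_size = num_generations
    # Step 2: pad only when the dataset is smaller than the batch size.
    return batch_size, pad_tasks(tasks, batch_size)


# One task with num_generations=4, as in the client's config.
batch_size, padded = resolve_batch_config(["task-a"], num_generations=4)
```

With one task and num_generations=4, this yields `batch_size == 4` and a padded list of four copies of the same task.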

With the client's config (1 task, num_generations=4):

  • Dataset padded from 1 → 4 rows (same task repeated)
  • batch_size=4, generation_batch_size=4
  • 4 % 4 == 0 ✓
  • Each step: 4 prompts × 4 rollouts = 16 total rollouts

Repeating tasks is fine for RL: the same task with more rollouts yields more learning signal per step.

Test plan

  • 27 TRL tests pass
  • Client re-test with num_rollouts_per_step=4

🤖 Generated with Claude Code

…eeded

TRL requires generation_batch_size % num_generations == 0. With
batch_size=1 and num_generations=4, TRL rejects it. Fix:

1. Set per_device_train_batch_size = num_generations (minimum valid)
2. Pad dataset by repeating tasks if len(dataset) < batch_size

With 1 task and num_generations=4: dataset padded to 4 rows,
batch_size=4, generation_batch_size=4, 4 % 4 == 0 ✓

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@abrichr abrichr merged commit d6e1b5b into main Mar 29, 2026
1 check passed