Bug Description
`_get_capped_partitions` (introduced in #1823) can produce empty micro-batch partitions, causing a `ValueError: torch.cat(): expected a non-empty list of Tensors` crash in `get_batch`.
In `get_data_iterator`, each DP rank independently computes `num_microbatches` via `get_minimum_num_micro_batch_size`, then all ranks synchronize to the maximum via `dist.all_reduce(..., op=ReduceOp.MAX)`. This is required because all DP ranks must execute the same number of forward passes for the NCCL collectives to match.
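As a hypothetical sketch (the function name and shapes below are assumed for illustration, not taken from the slime source), the synchronization step amounts to:

```python
# Pure-Python stand-in for the DP-rank synchronization described above:
# each rank computes its own micro-batch count, then dist.all_reduce with
# ReduceOp.MAX forces every rank up to the largest count, so all ranks run
# the same number of forward passes.
def sync_num_microbatches(per_rank_counts):
    k = max(per_rank_counts)             # emulates all_reduce(..., op=MAX)
    return [k for _ in per_rank_counts]  # every rank now runs k passes

# Rank 2 only needed 1 micro-batch but is forced to 4, so it must now
# split its few samples into 4 partitions.
print(sync_num_microbatches([4, 2, 1, 3]))  # -> [4, 4, 4, 4]
```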
However, when the fallback to `_get_capped_partitions` is triggered on a rank that needed fewer micro-batches, the first-fit algorithm packs all samples into the first few bins and leaves the trailing bins empty (`[]`). When `DataIterator.get_next` later serves an empty partition, `get_batch` calls `torch.cat([])`, which raises the `ValueError`.
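A minimal, self-contained sketch of this failure mode (the function below is a hypothetical stand-in for the first-fit packing, not the actual `_get_capped_partitions` implementation):

```python
# First-fit packing under a per-micro-batch token cap: each sample goes
# into the first bin that still has room. When the synchronized count k
# (forced by all_reduce MAX) exceeds what this rank's samples require,
# the trailing bins are never opened and stay empty.
def first_fit_capped(seqlens, k, max_tokens_per_mb):
    bins = [[] for _ in range(k)]  # k partitions, one per forced micro-batch
    loads = [0] * k
    for idx, n in enumerate(seqlens):
        for b in range(k):         # first bin with room wins
            if loads[b] + n <= max_tokens_per_mb:
                bins[b].append(idx)
                loads[b] += n
                break
    return bins

# This rank's samples fit in 2 micro-batches, but k=4 was forced on it.
parts = first_fit_capped([300, 200, 100, 50], k=4, max_tokens_per_mb=512)
print(parts)  # -> [[0, 1], [2, 3], [], []]
# Serving one of the empty partitions later means torch.cat([]) in
# get_batch, i.e. "expected a non-empty list of Tensors".
```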
The original `get_seqlen_balanced_partitions` is immune to this because it guarantees every partition is non-empty via the precondition `len(seqlen_list) >= k_partitions` and an explicit assertion.
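For contrast, a sketch of why that guarantee holds (the greedy longest-first / least-loaded strategy here is an assumption; only the precondition and the non-emptiness assertion mirror the description above):

```python
import heapq

# Balanced partitioning: assign samples longest-first to the currently
# least-loaded bin. With len(seqlen_list) >= k_partitions, the first
# k_partitions assignments each land in a fresh (load-0) bin, so no bin
# can end up empty.
def balanced_partitions(seqlen_list, k_partitions):
    assert len(seqlen_list) >= k_partitions  # the precondition in question
    heap = [(0, b, []) for b in range(k_partitions)]  # (load, bin_id, items)
    heapq.heapify(heap)
    for idx in sorted(range(len(seqlen_list)),
                      key=lambda i: seqlen_list[i], reverse=True):
        load, b, items = heapq.heappop(heap)   # least-loaded bin
        items.append(idx)
        heapq.heappush(heap, (load + seqlen_list[idx], b, items))
    parts = [items for _, _, items in sorted(heap, key=lambda t: t[1])]
    assert all(parts), "every partition is non-empty"
    return parts

print(balanced_partitions([5, 4, 3, 2, 1], 3))  # -> [[0], [1, 4], [2, 3]]
```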
Steps to Reproduce
This is more likely to occur when:
- Using custom rollout functions that produce highly variable sequence lengths
- Using `--partial-rollout`, which accumulates samples across rounds with diverse lengths
- Using `--balance-data` with `--use-dynamic-batch-size`
- The rollout has a mix of completed (long) and failed/aborted (short) samples across DP ranks
Expected Behavior
Training should not crash: every micro-batch partition produced by the dynamic batching path should be non-empty.
Actual Behavior
Training crashes with `ValueError: torch.cat(): expected a non-empty list of Tensors` in `get_batch` (full traceback below).
Environment
- slime version: 286750a (with a few unrelated local modifications)
- Python version: 3.12.13
- PyTorch version: 2.9.1+cu129
Logs
```
ray.exceptions.RayTaskError(ValueError): ray::MegatronTrainRayActor.train() (pid=514256, ip=node-0, actor_id=203ebc9274a8f40229bed6e402000000, repr=<slime.backends.megatron_utils.actor.MegatronTrainRayActor object at 0x752e33fd7bc0>)
  File "/playground/slime/slime/backends/megatron_utils/actor.py", line 376, in train
    return self.train_actor(rollout_id, rollout_data)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/playground/slime/slime/backends/megatron_utils/actor.py", line 448, in train_actor
    self.compute_log_prob(
  File "/playground/slime/slime/backends/megatron_utils/actor.py", line 354, in compute_log_prob
    return forward_only(
           ^^^^^^^^^^^^^
  File "/playground/slime/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/playground/slime/slime/backends/megatron_utils/model.py", line 265, in forward_only
    forward_data_store += forward_backward_func(
                          ^^^^^^^^^^^^^^^^^^^^^^
  File "/playground/slime/Megatron-LM/megatron/core/pipeline_parallel/schedules.py", line 652, in forward_backward_no_pipelining
    output_tensor, num_tokens = forward_step(
                                ^^^^^^^^^^^^^
  File "/playground/slime/Megatron-LM/megatron/core/pipeline_parallel/schedules.py", line 417, in forward_step
    output_tensor, loss_func = forward_step_func(data_iterator, model)
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/playground/slime/slime/backends/megatron_utils/data.py", line 104, in get_batch
    tokens = torch.cat(tokens)
             ^^^^^^^^^^^^^^^^^
ValueError: torch.cat(): expected a non-empty list of Tensors
```
Additional Context
No response
Pre-submission Checklist