From 83226f41c20e6f0fea00740d5968b88243161544 Mon Sep 17 00:00:00 2001
From: Ye Yu
Date: Mon, 20 Apr 2026 12:28:03 -0700
Subject: [PATCH 1/2] fix: auto-compute dp_replicate_size from world_size in
 ParallelismConfig

When dp_shard_size < world_size (e.g., dp_shard_size=4 on 8 GPUs),
ParallelismConfig raises "total_size does not match num_processes"
because dp_replicate_size defaults to 1.

Auto-compute dp_replicate_size = world_size // (dp_shard_size * cp_size)
so that intra-node FSDP2 sharding + inter-node data-parallel replication
works without requiring users to manually set dp_replicate_size.

Co-Authored-By: Claude Opus 4.6
Signed-off-by: Ye Yu
---
 examples/speculative_decoding/main.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py
index efc4ba82bd..530c68574a 100644
--- a/examples/speculative_decoding/main.py
+++ b/examples/speculative_decoding/main.py
@@ -212,8 +212,15 @@ def train():
             "Either data.data_path or data.offline_data_path must be set in the config."
         )
     if training_args.cp_size > 1 or training_args.dp_shard_size > 1:
+        # Auto-compute dp_replicate_size so that
+        # dp_replicate_size * dp_shard_size * cp_size == world_size.
+        world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count()))
+        parallel_size = training_args.dp_shard_size * training_args.cp_size
+        dp_replicate_size = world_size // parallel_size
         training_args.parallelism_config = ParallelismConfig(
-            cp_size=training_args.cp_size, dp_shard_size=training_args.dp_shard_size
+            cp_size=training_args.cp_size,
+            dp_shard_size=training_args.dp_shard_size,
+            dp_replicate_size=dp_replicate_size,
         )
     if training_args.cp_size > 1:
         patch_ring_attention_for_ttt()

From d9bb6c402f4a9effbb372c0d2a27c2a40a8f07ea Mon Sep 17 00:00:00 2001
From: Ye Yu
Date: Mon, 20 Apr 2026 12:50:37 -0700
Subject: [PATCH 2/2] fix: add divisibility guard and clarify WORLD_SIZE
 fallback

Address review feedback:
- Add ValueError if world_size is not divisible by dp_shard_size * cp_size
- Comment that torch.cuda.device_count() is per-node, not world_size

Co-Authored-By: Claude Opus 4.6
Signed-off-by: Ye Yu
---
 examples/speculative_decoding/main.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py
index 530c68574a..31c73d0427 100644
--- a/examples/speculative_decoding/main.py
+++ b/examples/speculative_decoding/main.py
@@ -214,8 +214,16 @@ def train():
     if training_args.cp_size > 1 or training_args.dp_shard_size > 1:
         # Auto-compute dp_replicate_size so that
         # dp_replicate_size * dp_shard_size * cp_size == world_size.
+        # Note: torch.cuda.device_count() returns per-node GPU count, not world_size.
+        # WORLD_SIZE (set by torchrun/accelerate) gives the correct multi-node total.
         world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count()))
         parallel_size = training_args.dp_shard_size * training_args.cp_size
+        if world_size % parallel_size != 0:
+            raise ValueError(
+                f"world_size ({world_size}) must be divisible by "
+                f"dp_shard_size ({training_args.dp_shard_size}) * cp_size ({training_args.cp_size}) "
+                f"= {parallel_size}"
+            )
         dp_replicate_size = world_size // parallel_size
         training_args.parallelism_config = ParallelismConfig(
             cp_size=training_args.cp_size,