examples/speculative_decoding/main.py (16 additions & 1 deletion)
@@ -212,8 +212,23 @@ def train():
             "Either data.data_path or data.offline_data_path must be set in the config."
         )
     if training_args.cp_size > 1 or training_args.dp_shard_size > 1:
+        # Auto-compute dp_replicate_size so that
+        # dp_replicate_size * dp_shard_size * cp_size == world_size.
+        # Note: torch.cuda.device_count() returns per-node GPU count, not world_size.
+        # WORLD_SIZE (set by torchrun/accelerate) gives the correct multi-node total.
+        world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count()))
+        parallel_size = training_args.dp_shard_size * training_args.cp_size
+        if world_size % parallel_size != 0:
+            raise ValueError(
+                f"world_size ({world_size}) must be divisible by "
+                f"dp_shard_size ({training_args.dp_shard_size}) * cp_size ({training_args.cp_size}) "
+                f"= {parallel_size}"
+            )
+        dp_replicate_size = world_size // parallel_size
         training_args.parallelism_config = ParallelismConfig(
-            cp_size=training_args.cp_size, dp_shard_size=training_args.dp_shard_size
+            cp_size=training_args.cp_size,
+            dp_shard_size=training_args.dp_shard_size,
+            dp_replicate_size=dp_replicate_size,
         )
     if training_args.cp_size > 1:
         patch_ring_attention_for_ttt()
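
For reference, a minimal standalone sketch of the arithmetic this change performs. The configuration values below are hypothetical (a 2-node job with 8 GPUs per node, i.e. WORLD_SIZE=16, with dp_shard_size=4 and cp_size=2), not taken from the PR:

import os

os.environ.setdefault("WORLD_SIZE", "16")  # normally exported by torchrun/accelerate

dp_shard_size, cp_size = 4, 2              # illustrative config values
world_size = int(os.environ["WORLD_SIZE"])

parallel_size = dp_shard_size * cp_size          # 4 * 2 = 8 ranks per replica
if world_size % parallel_size != 0:
    raise ValueError("world_size must be divisible by dp_shard_size * cp_size")
dp_replicate_size = world_size // parallel_size  # 16 // 8 = 2 replicas

# Invariant the change enforces for the device mesh:
assert dp_replicate_size * dp_shard_size * cp_size == world_size

Reading WORLD_SIZE rather than torch.cuda.device_count() is the key point: on a multi-node run, device_count() only sees the local node's GPUs, so the mesh factorization would otherwise be computed against the wrong total.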