Conversation
@NuojCheng commented Dec 22, 2025

Most changes in this PR are in bfa20a4

Description

Adds a sharding debug feature that prints weight sharding details. For example, when running deepseek-test with tp=2 and ep=2 and setting debug_sharding=True, you get:

tensor: 2
expert: 2
params/decoder/decoder_norm/scale............................................... float32[512] PartitionSpec('tensor',)
params/decoder/dense_layers/mlp/wi_0/kernel..................................... float32[512,3,18432] PartitionSpec('expert', None, 'tensor')
params/decoder/dense_layers/mlp/wi_1/kernel..................................... float32[512,3,18432] PartitionSpec('expert', None, 'tensor')
params/decoder/dense_layers/mlp/wo/kernel....................................... float32[18432,3,512] PartitionSpec('tensor', None, 'expert')
params/decoder/dense_layers/post_self_attention_layer_norm/scale................ float32[512,3] PartitionSpec('tensor', None)
params/decoder/dense_layers/pre_self_attention_layer_norm/scale................. float32[512,3] PartitionSpec('tensor', None)
params/decoder/dense_layers/self_attention/kv_norm/scale........................ float32[512,3] PartitionSpec('tensor', None)
params/decoder/dense_layers/self_attention/out/kernel........................... float32[128,3,128,512] PartitionSpec('tensor', None, None, 'expert')
params/decoder/dense_layers/self_attention/q_norm/scale......................... float32[1536,3] PartitionSpec('tensor', None)
params/decoder/dense_layers/self_attention/wkv_a/kernel......................... float32[512,3,576] PartitionSpec('expert', None, None)
params/decoder/dense_layers/self_attention/wkv_b/kernel......................... float32[512,3,128,256] PartitionSpec('expert', None, 'tensor', None)
params/decoder/dense_layers/self_attention/wq_a/kernel.......................... float32[512,3,1536] PartitionSpec('expert', None, None)
params/decoder/dense_layers/self_attention/wq_b/kernel.......................... float32[1536,3,128,192] PartitionSpec('expert', None, 'tensor', None)
params/decoder/logits_dense/kernel.............................................. float32[512,129280] PartitionSpec('expert', 'tensor')
params/decoder/moe_layers/DeepSeekMoeBlock_0/MoeBlock_0/gate/bias............... float32[256,58] PartitionSpec(None, None)
params/decoder/moe_layers/DeepSeekMoeBlock_0/MoeBlock_0/gate/kernel............. float32[512,58,256] PartitionSpec('expert', None, None)
params/decoder/moe_layers/DeepSeekMoeBlock_0/MoeBlock_0/wi_0.................... float32[256,58,512,512] PartitionSpec('expert', None, None, 'tensor')
params/decoder/moe_layers/DeepSeekMoeBlock_0/MoeBlock_0/wi_1.................... float32[256,58,512,512] PartitionSpec('expert', None, None, 'tensor')
params/decoder/moe_layers/DeepSeekMoeBlock_0/MoeBlock_0/wo...................... float32[256,58,512,512] PartitionSpec('expert', None, 'tensor', None)
params/decoder/moe_layers/DeepSeekMoeBlock_0/shared_experts/wi_0/kernel......... float32[512,58,512] PartitionSpec('expert', None, 'tensor')
params/decoder/moe_layers/DeepSeekMoeBlock_0/shared_experts/wi_1/kernel......... float32[512,58,512] PartitionSpec('expert', None, 'tensor')
params/decoder/moe_layers/DeepSeekMoeBlock_0/shared_experts/wo/kernel........... float32[512,58,512] PartitionSpec('tensor', None, 'expert')
params/decoder/moe_layers/post_self_attention_layer_norm/scale.................. float32[512,58] PartitionSpec('tensor', None)
params/decoder/moe_layers/pre_self_attention_layer_norm/scale................... float32[512,58] PartitionSpec('tensor', None)
params/decoder/moe_layers/self_attention/kv_norm/scale.......................... float32[512,58] PartitionSpec('tensor', None)
params/decoder/moe_layers/self_attention/out/kernel............................. float32[128,58,128,512] PartitionSpec('tensor', None, None, 'expert')
params/decoder/moe_layers/self_attention/q_norm/scale........................... float32[1536,58] PartitionSpec('tensor', None)
params/decoder/moe_layers/self_attention/wkv_a/kernel........................... float32[512,58,576] PartitionSpec('expert', None, None)
params/decoder/moe_layers/self_attention/wkv_b/kernel........................... float32[512,58,128,256] PartitionSpec('expert', None, 'tensor', None)
params/decoder/moe_layers/self_attention/wq_a/kernel............................ float32[512,58,1536] PartitionSpec('expert', None, None)
params/decoder/moe_layers/self_attention/wq_b/kernel............................ float32[1536,58,128,192] PartitionSpec('expert', None, 'tensor', None)
params/token_embedder/embedding................................................. float32[129280,512] PartitionSpec('tensor', 'expert')
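
For reference, a printer like this can be sketched in a few lines of JAX. The sketch below is illustrative only, not this PR's actual implementation: the helper name print_sharding_info is an assumption, and jax.tree_util.keystr produces bracketed keys rather than the slash-separated paths shown above.

# Minimal sketch of a weight-sharding debug printer (illustrative; the
# PR's real implementation may differ). Flattens the parameter pytree and
# prints one line per leaf: dot-padded name, dtype[shape], PartitionSpec.
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

def print_sharding_info(params, mesh: Mesh, pad: int = 80):
  # Mesh axis sizes first, matching the "tensor: 2 / expert: 2" header above.
  for axis, size in mesh.shape.items():
    if size > 1:
      print(f"{axis}: {size}")
  for path, leaf in jax.tree_util.tree_flatten_with_path(params)[0]:
    name = jax.tree_util.keystr(path)  # e.g. "['decoder']['decoder_norm']['scale']"
    sharding = getattr(leaf, "sharding", None)
    spec = sharding.spec if isinstance(sharding, NamedSharding) else PartitionSpec()
    shape = f"{leaf.dtype}[{','.join(map(str, leaf.shape))}]"
    print(f"{name:.<{pad}} {shape} {spec}")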

To easily check sharding shapes at large scale, we also support sharding debug via train_compile.py (CPU only!), using the following command:

python -m MaxText.train_compile MaxText/configs/base.yml compile_topology=v5p-1024 compile_topology_num_slices=1 model_name=deepseek3-671b per_device_batch_size=1  ici_tensor_parallelism=8 ici_expert_parallelism=8  log_config=false debug_sharding=true

And the output is

fsdp: 8
tensor: 8
expert: 8
params/decoder/decoder_norm/scale............................................... float32[7168] PartitionSpec('tensor',)
params/decoder/dense_layers/mlp/wi_0/kernel..................................... float32[7168,3,18432] PartitionSpec(('fsdp', 'expert'), None, 'tensor')
params/decoder/dense_layers/mlp/wi_1/kernel..................................... float32[7168,3,18432] PartitionSpec(('fsdp', 'expert'), None, 'tensor')
params/decoder/dense_layers/mlp/wo/kernel....................................... float32[18432,3,7168] PartitionSpec('tensor', None, ('fsdp', 'expert'))
params/decoder/dense_layers/post_self_attention_layer_norm/scale................ float32[7168,3] PartitionSpec('tensor', None)
params/decoder/dense_layers/pre_self_attention_layer_norm/scale................. float32[7168,3] PartitionSpec('tensor', None)
params/decoder/dense_layers/self_attention/kv_norm/scale........................ float32[512,3] PartitionSpec('tensor', None)
params/decoder/dense_layers/self_attention/out/kernel........................... float32[128,3,128,7168] PartitionSpec('tensor', None, None, ('fsdp', 'expert'))
params/decoder/dense_layers/self_attention/q_norm/scale......................... float32[1536,3] PartitionSpec('tensor', None)
params/decoder/dense_layers/self_attention/wkv_a/kernel......................... float32[7168,3,576] PartitionSpec(('fsdp', 'expert'), None, None)
params/decoder/dense_layers/self_attention/wkv_b/kernel......................... float32[512,3,128,256] PartitionSpec(('fsdp', 'expert'), None, 'tensor', None)
params/decoder/dense_layers/self_attention/wq_a/kernel.......................... float32[7168,3,1536] PartitionSpec(('fsdp', 'expert'), None, None)
params/decoder/dense_layers/self_attention/wq_b/kernel.......................... float32[1536,3,128,192] PartitionSpec(('fsdp', 'expert'), None, 'tensor', None)
params/decoder/logits_dense/kernel.............................................. float32[7168,129280] PartitionSpec(('fsdp', 'expert'), 'tensor')
params/decoder/moe_layers/DeepSeekMoeBlock_0/MoeBlock_0/gate/bias............... float32[256,58] PartitionSpec(None, None)
params/decoder/moe_layers/DeepSeekMoeBlock_0/MoeBlock_0/gate/kernel............. float32[7168,58,256] PartitionSpec(('fsdp', 'expert'), None, None)
params/decoder/moe_layers/DeepSeekMoeBlock_0/MoeBlock_0/wi_0.................... float32[256,58,7168,2048] PartitionSpec('expert', None, 'fsdp', 'tensor')
params/decoder/moe_layers/DeepSeekMoeBlock_0/MoeBlock_0/wi_1.................... float32[256,58,7168,2048] PartitionSpec('expert', None, 'fsdp', 'tensor')
params/decoder/moe_layers/DeepSeekMoeBlock_0/MoeBlock_0/wo...................... float32[256,58,2048,7168] PartitionSpec('expert', None, 'tensor', 'fsdp')
params/decoder/moe_layers/DeepSeekMoeBlock_0/shared_experts/wi_0/kernel......... float32[7168,58,2048] PartitionSpec(('fsdp', 'expert'), None, 'tensor')
params/decoder/moe_layers/DeepSeekMoeBlock_0/shared_experts/wi_1/kernel......... float32[7168,58,2048] PartitionSpec(('fsdp', 'expert'), None, 'tensor')
params/decoder/moe_layers/DeepSeekMoeBlock_0/shared_experts/wo/kernel........... float32[2048,58,7168] PartitionSpec('tensor', None, ('fsdp', 'expert'))
params/decoder/moe_layers/post_self_attention_layer_norm/scale.................. float32[7168,58] PartitionSpec('tensor', None)
params/decoder/moe_layers/pre_self_attention_layer_norm/scale................... float32[7168,58] PartitionSpec('tensor', None)
params/decoder/moe_layers/self_attention/kv_norm/scale.......................... float32[512,58] PartitionSpec('tensor', None)
params/decoder/moe_layers/self_attention/out/kernel............................. float32[128,58,128,7168] PartitionSpec('tensor', None, None, ('fsdp', 'expert'))
params/decoder/moe_layers/self_attention/q_norm/scale........................... float32[1536,58] PartitionSpec('tensor', None)
params/decoder/moe_layers/self_attention/wkv_a/kernel........................... float32[7168,58,576] PartitionSpec(('fsdp', 'expert'), None, None)
params/decoder/moe_layers/self_attention/wkv_b/kernel........................... float32[512,58,128,256] PartitionSpec(('fsdp', 'expert'), None, 'tensor', None)
params/decoder/moe_layers/self_attention/wq_a/kernel............................ float32[7168,58,1536] PartitionSpec(('fsdp', 'expert'), None, None)
params/decoder/moe_layers/self_attention/wq_b/kernel............................ float32[1536,58,128,192] PartitionSpec(('fsdp', 'expert'), None, 'tensor', None)
params/token_embedder/embedding................................................. float32[129280,7168] PartitionSpec('tensor', ('fsdp', 'expert'))
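
A quick way to sanity-check such a dump is to compute the per-device shard shape by hand: each array dimension is divided by the product of the sizes of the mesh axes that shard it, and None leaves a dimension replicated. Below is a self-contained sketch; shard_shape is a hypothetical helper for illustration, not part of this PR.

import math

# Per-device shard shape implied by a PartitionSpec-like tuple of mesh
# axis names. None keeps a dimension replicated; a tuple of axes divides
# the dimension by the product of those axis sizes.
def shard_shape(shape, spec, mesh_sizes):
  out = []
  for dim, axes in zip(shape, spec):
    if axes is None:
      out.append(dim)                       # replicated dimension
      continue
    if isinstance(axes, str):
      axes = (axes,)                        # single axis, e.g. 'tensor'
    out.append(dim // math.prod(mesh_sizes[a] for a in axes))
  return out

mesh = {"fsdp": 8, "tensor": 8, "expert": 8}
# wi_0 above: float32[7168,3,18432] with PartitionSpec(('fsdp','expert'), None, 'tensor')
print(shard_shape((7168, 3, 18432), (("fsdp", "expert"), None, "tensor"), mesh))
# -> [112, 3, 2304] per device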

Tests

Please describe how you tested this change, and include any instructions and/or
commands to reproduce.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@NuojCheng added the stale and draft labels and removed the stale label Dec 22, 2025
@NuojCheng force-pushed the chengnuojin-sharding-debug branch 2 times, most recently from 47c7852 to fa814f7 on December 22, 2025 18:14
@NuojCheng marked this pull request as ready for review December 22, 2025 18:16
@NuojCheng force-pushed the chengnuojin-sharding-debug branch 2 times, most recently from 16f1731 to e5229d4 on December 22, 2025 19:26
@NuojCheng force-pushed the chengnuojin-sharding-debug branch 2 times, most recently from 4755650 to 6c94fcd on December 22, 2025 19:47
@NuojCheng added the pull ready label and removed the draft label Dec 22, 2025
@NuojCheng force-pushed the chengnuojin-sharding-debug branch from 6c94fcd to 0627ad6 on December 22, 2025 20:39
copybara-service bot pushed a commit that referenced this pull request Dec 22, 2025
--
6c94fcd by NuojCheng <chengnuojin@google.com>:

add sharding debug feature

COPYBARA_INTEGRATE_REVIEW=#2866 from AI-Hypercomputer:chengnuojin-sharding-debug 6c94fcd
PiperOrigin-RevId: 847853078
@NuojCheng force-pushed the chengnuojin-sharding-debug branch from 0627ad6 to 79afcb8 on December 22, 2025 20:45
@NuojCheng force-pushed the chengnuojin-sharding-debug branch from 79afcb8 to 36029ce on December 22, 2025 20:46
copybara-service bot merged commit b2b7d8f into main Dec 22, 2025
22 checks passed
copybara-service bot deleted the chengnuojin-sharding-debug branch December 22, 2025 21:36
@Shuang-cnt mentioned this pull request Jan 21, 2026