From 5dd1039dc860b655c2d10b5a68ecabad60062d19 Mon Sep 17 00:00:00 2001 From: NuojCheng Date: Mon, 13 Apr 2026 22:44:34 +0000 Subject: [PATCH 1/3] reorder logical rule and add embed_vocab --- src/maxtext/configs/custom_mesh_and_rule/ep-as-cp.yml | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/maxtext/configs/custom_mesh_and_rule/ep-as-cp.yml diff --git a/src/maxtext/configs/custom_mesh_and_rule/ep-as-cp.yml b/src/maxtext/configs/custom_mesh_and_rule/ep-as-cp.yml new file mode 100644 index 0000000000..e69de29bb2 From 38326d417e30917e4aa27287020a13466038a517 Mon Sep 17 00:00:00 2001 From: NuojCheng Date: Tue, 14 Apr 2026 00:11:12 +0000 Subject: [PATCH 2/3] add ep-as-cp custom rule --- src/maxtext/configs/base.yml | 1 + .../configs/custom_mesh_and_rule/ep-as-cp.yml | 59 + src/maxtext/configs/types.py | 10 +- src/maxtext/layers/attention_op.py | 4 +- src/maxtext/utils/train_utils.py | 11 +- tests/utils/sharding_dump.py | 7 + .../input_shardings.json | 148 ++ .../logical_shardings.json | 980 ++++++++ .../named_shardings.json | 2132 +++++++++++++++++ 9 files changed, 3341 insertions(+), 11 deletions(-) create mode 100644 tests/utils/sharding_info/deepseek2-16b/tpu7x-8/slice_1/rule_ep-as-cp_ici_fsdp_parallelism=-1_ici_expert_parallelism=2/input_shardings.json create mode 100644 tests/utils/sharding_info/deepseek2-16b/tpu7x-8/slice_1/rule_ep-as-cp_ici_fsdp_parallelism=-1_ici_expert_parallelism=2/logical_shardings.json create mode 100644 tests/utils/sharding_info/deepseek2-16b/tpu7x-8/slice_1/rule_ep-as-cp_ici_fsdp_parallelism=-1_ici_expert_parallelism=2/named_shardings.json diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 0fe9899778..44d7bde339 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -569,6 +569,7 @@ logical_axis_rules: [ # Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 
'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']] input_data_sharding_logical_axes: ['activation_embed_and_logits_batch', 'activation_norm_length'] +context_sharding: "context" # sharding tolerance: float between 0.0 and 1.0 representing the allowed percentage of non-sharded parameters. sharding_tolerance: 0.02 diff --git a/src/maxtext/configs/custom_mesh_and_rule/ep-as-cp.yml b/src/maxtext/configs/custom_mesh_and_rule/ep-as-cp.yml index e69de29bb2..00969328ca 100644 --- a/src/maxtext/configs/custom_mesh_and_rule/ep-as-cp.yml +++ b/src/maxtext/configs/custom_mesh_and_rule/ep-as-cp.yml @@ -0,0 +1,59 @@ +mesh_axes: ['data', 'fsdp', 'expert'] +data_sharding: [['data', 'fsdp', 'expert']] +context_sharding: 'expert' +logical_axis_rules: [ + # ========================================== + # Vocabulary Embedding + # ========================================== + # Vocab Activations + ['activation_embed_and_logits_batch', ['data', 'fsdp']], + ['activation_embed_and_logits_batch_sequence', ['data', 'fsdp', 'expert']], + # Vocab Weights + ['vocab', []], + ['embed_vocab', ['fsdp', 'expert']], + # ========================================== + # Attention + # ========================================== + # Attention Activations + ['activation_heads', []], + ['activation_kv_heads', []], + ['activation_attn_length', ['expert']], + ['activation_q_length', ['expert']], + ['activation_kv_length', []], + ['activation_attn_embed', []], + ['activation_kv', []], + ['activation_kv_batch', ['data', 'fsdp']], + ['activation_kv_head_dim', []], + # Attention Weights + ['heads', []], + ['q_heads', []], + ['kv_heads', []], + ['qkv', []], + ['kv', []], + ['kv_head_dim', []], + ['q_lora', ['fsdp', 'expert']], + ["q_lora_up_proj", []], + ['kv_lora', ['fsdp', 'expert']], + ["kv_lora_up_proj", []], + # ========================================== + # Mixture of Experts (MoE) + # ========================================== + # MoE 
Activations + ['activation_batch_moe', ['data', 'fsdp']], + ['activation_exp', ['expert']], + # MoE Weights + ['exp', 'expert'], + ['embed_moe', ['fsdp']], + # ========================================== + # Standard MLP / Dense Layers / Model Structure + # ========================================== + # Dense Activations + ['activation_mlp', []], + ['activation_batch', ['data', 'fsdp']], + ['activation_length', ['expert']], + ['activation_norm_length', ['expert']], + ['activation_embed', []], + # General Weights + ['mlp', []], + ['embed', ['fsdp', 'expert']], + ] \ No newline at end of file diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index c734ac2f87..22765c1843 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -842,6 +842,7 @@ class LayoutAndSharding(BaseModel): logical_axis_rules: Any = Field([], description="Rules for mapping logical axes to physical mesh axes.") data_sharding: Any = Field([], description="Sharding for input data.") + context_sharding: str = Field("context", description="Physical axis name for context parallelism.") input_data_sharding_logical_axes: list[str] = Field( ["activation_embed_and_logits_batch", "activation_norm_length"], description="Logical axes for sharding input data.", @@ -2116,6 +2117,8 @@ def set_derived_and_validate_values(self) -> "MaxTextConfig": self.logical_axis_rules = custom_mesh_config["logical_axis_rules"] if "data_sharding" in custom_mesh_config: self.data_sharding = custom_mesh_config["data_sharding"] + if "context_sharding" in custom_mesh_config: + self.context_sharding = custom_mesh_config["context_sharding"] else: raise NotImplementedError(f"Custom mesh config file not found at {custom_mesh_path}") @@ -2398,10 +2401,9 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de self.tensors_on_device = [t for t in tensors if getattr(self, t) == "device"] self.tensors_to_offload = [t for t in tensors if getattr(self, t) == "offload"] - 
cp_size = self.ici_context_parallelism * self.dcn_context_parallelism - if self.expert_shard_attention_option == "context": - cp_size *= self.ici_expert_parallelism * self.dcn_expert_parallelism - self.context_parallel_size = cp_size + self.context_parallel_size = getattr(self, f"ici_{self.context_sharding}_parallelism", 1) * getattr( + self, f"dcn_{self.context_sharding}_parallelism", 1 + ) if self.pipeline_parallel_layers == -1: if self.decoder_block == DecoderBlockType.DEEPSEEK: moe_layers = self.num_decoder_layers - self.first_num_dense_layers diff --git a/src/maxtext/layers/attention_op.py b/src/maxtext/layers/attention_op.py index f634e176d7..ddc4fff9a1 100644 --- a/src/maxtext/layers/attention_op.py +++ b/src/maxtext/layers/attention_op.py @@ -1517,7 +1517,7 @@ def cudnn_flash_attention( _, _, _, head_dim = query.shape # pylint: disable=unused-variable - using_context_parallelism = self.mesh.shape["context"] > 1 + using_context_parallelism = self.mesh.shape[self.config.context_sharding] > 1 # Initialize default attention configuration sliding_window_size = None @@ -1575,7 +1575,7 @@ def cudnn_flash_attention( transpose_batch_sequence=False, window_size=sliding_window_size, context_parallel_causal_load_balanced=self.config.context_parallel_load_balance, - context_parallel_axis="context", + context_parallel_axis=self.config.context_sharding, context_parallel_strategy=self.config.context_parallel_strategy, max_segments_per_seq=max_segments_per_seq, ) diff --git a/src/maxtext/utils/train_utils.py b/src/maxtext/utils/train_utils.py index 9e0a00c8e6..8565f7ec83 100644 --- a/src/maxtext/utils/train_utils.py +++ b/src/maxtext/utils/train_utils.py @@ -230,9 +230,8 @@ def setup_train_loop(config, recorder, devices=None): data_iterator, eval_data_iterator = create_data_iterator(config, mesh) rampup_manager = create_rampup_manager(config, checkpoint_manager) data_loader = create_dataloader(config, mesh, data_iterator, recorder, rampup_manager) - context_parallel_size = 
mesh.shape.get("context", 1) # Check if context parallelism is being used with sequence packing - if context_parallel_size > 1 and config.packing and config.dataset_type != "synthetic": + if config.context_parallel_size > 1 and config.packing and config.dataset_type != "synthetic": raise ValueError( "Context parallelism cannot be used with sequence packing. " "Disable sequence packing (set packing=False). " @@ -241,11 +240,13 @@ def setup_train_loop(config, recorder, devices=None): # Apply reordering wrapper to data iterators if context parallelism is enabled with jax.set_mesh(mesh): - if context_parallel_size > 1 and config.context_parallel_load_balance: - data_iterator = map(maxtext_utils.get_reorder_callable(context_parallel_size, config.shard_mode), data_iterator) + if config.context_parallel_size > 1 and config.context_parallel_load_balance: + data_iterator = map( + maxtext_utils.get_reorder_callable(config.context_parallel_size, config.shard_mode), data_iterator + ) if eval_data_iterator: eval_data_iterator = map( - maxtext_utils.get_reorder_callable(context_parallel_size, config.shard_mode), + maxtext_utils.get_reorder_callable(config.context_parallel_size, config.shard_mode), eval_data_iterator, ) diff --git a/tests/utils/sharding_dump.py b/tests/utils/sharding_dump.py index d8075b6e44..cd19a49d63 100644 --- a/tests/utils/sharding_dump.py +++ b/tests/utils/sharding_dump.py @@ -48,6 +48,13 @@ "pipeline-large-moe", ("ici_fsdp_parallelism=-1", "ici_expert_parallelism=4", "use_ring_of_experts=true"), ), + ( + "deepseek2-16b", + "tpu7x-8", + 1, + "ep-as-cp", + ("ici_fsdp_parallelism=-1", "ici_expert_parallelism=2"), + ), ("qwen3-0.6b", "tpu7x-16", 1, "", ()), ("gpt-oss-20b", "tpu7x-16", 1, "", ()), ("gpt-oss-20b", "tpu7x-16", 1, "", ("ici_fsdp_parallelism=-1", "ici_expert_parallelism=2")), diff --git a/tests/utils/sharding_info/deepseek2-16b/tpu7x-8/slice_1/rule_ep-as-cp_ici_fsdp_parallelism=-1_ici_expert_parallelism=2/input_shardings.json 
b/tests/utils/sharding_info/deepseek2-16b/tpu7x-8/slice_1/rule_ep-as-cp_ici_fsdp_parallelism=-1_ici_expert_parallelism=2/input_shardings.json new file mode 100644 index 0000000000..0fe4f2405f --- /dev/null +++ b/tests/utils/sharding_info/deepseek2-16b/tpu7x-8/slice_1/rule_ep-as-cp_ici_fsdp_parallelism=-1_ici_expert_parallelism=2/input_shardings.json @@ -0,0 +1,148 @@ +{ + "Activation Sharding Dump": [ + { + "deepseek/inputs: bfloat16[96,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P('fsdp', 'expert', None)" + } + }, + { + "deepseek/pre_attention_norm: bfloat16[96,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P('fsdp', 'expert', None)" + } + }, + { + "attention_mla/inputs_q: bfloat16[96,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_length', 'activation_embed')", + "PartitionSpec": "P('fsdp', 'expert', None)" + } + }, + { + "attention_mla/inputs_kv: bfloat16[96,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_length', 'activation_embed')", + "PartitionSpec": "P('fsdp', 'expert', None)" + } + }, + { + "attention_mla/q_nope: bfloat16[96,2048,16,128]": { + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', 'expert', None, None)" + } + }, + { + "attention_mla/q_pe: bfloat16[96,2048,16,64]": { + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', 'expert', None, None)" + } + }, + { + "attention_mla/query: bfloat16[96,2048,16,192]": { + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', 'expert', None, None)" + } + }, + { + "attention_mla/key_nope: bfloat16[96,2048,16,128]": { + "logic_axes": "('activation_kv_batch', 
'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', 'expert', None, None)" + } + }, + { + "attention_mla/key_rope: bfloat16[96,2048,16,64]": { + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', 'expert', None, None)" + } + }, + { + "attention_mla/key: bfloat16[96,2048,16,192]": { + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', 'expert', None, None)" + } + }, + { + "attention_mla/value: bfloat16[96,2048,16,128]": { + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", + "PartitionSpec": "P('fsdp', 'expert', None, None)" + } + }, + { + "attention_op/query: bfloat16[96,16,2048,192]": { + "logic_axes": "Unknown", + "PartitionSpec": "P('fsdp', None, 'expert', None)" + } + }, + { + "attention_op/key: bfloat16[96,16,2048,192]": { + "logic_axes": "Unknown", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_op/value: bfloat16[96,16,2048,128]": { + "logic_axes": "Unknown", + "PartitionSpec": "P('fsdp', None, None, None)" + } + }, + { + "attention_mla/out: bfloat16[96,2048,16,128]": { + "logic_axes": "('activation_batch', 'activation_length', 'activation_heads', 'activation_kv')", + "PartitionSpec": "P('fsdp', 'expert', None, None)" + } + }, + { + "deepseek/attention_result: bfloat16[96,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P('fsdp', 'expert', None)" + } + }, + { + "deepseek/post_attention_norm: bfloat16[96,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P('fsdp', 'expert', None)" + } + }, + { + "linears/x: bfloat16[96,2048,10944]": { + "logic_axes": "('activation_batch', 'activation_length', 
'activation_mlp')", + "PartitionSpec": "P('fsdp', 'expert', None)" + } + }, + { + "deepseek/mlp: bfloat16[96,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P('fsdp', 'expert', None)" + } + }, + { + "deepseek/x: bfloat16[96,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P('fsdp', 'expert', None)" + } + }, + { + "moe/inputs: bfloat16[96,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "PartitionSpec": "P('fsdp', 'expert', None)" + } + }, + { + "moe/gate_logits: bfloat16[96,2048,64]": { + "logic_axes": "('activation_batch', 'activation_norm_length', None)", + "PartitionSpec": "P('fsdp', 'expert', None)" + } + }, + { + "linears/x: bfloat16[96,2048,2816]": { + "logic_axes": "('activation_batch', 'activation_length', 'activation_mlp')", + "PartitionSpec": "P('fsdp', 'expert', None)" + } + }, + { + "deepseek/mlp_lnx: bfloat16[96,2048,2048]": { + "logic_axes": "('activation_batch', 'activation_norm_length', 'activation_embed')", + "PartitionSpec": "P('fsdp', 'expert', None)" + } + } + ] +} \ No newline at end of file diff --git a/tests/utils/sharding_info/deepseek2-16b/tpu7x-8/slice_1/rule_ep-as-cp_ici_fsdp_parallelism=-1_ici_expert_parallelism=2/logical_shardings.json b/tests/utils/sharding_info/deepseek2-16b/tpu7x-8/slice_1/rule_ep-as-cp_ici_fsdp_parallelism=-1_ici_expert_parallelism=2/logical_shardings.json new file mode 100644 index 0000000000..8d30b919f8 --- /dev/null +++ b/tests/utils/sharding_info/deepseek2-16b/tpu7x-8/slice_1/rule_ep-as-cp_ici_fsdp_parallelism=-1_ici_expert_parallelism=2/logical_shardings.json @@ -0,0 +1,980 @@ +{ + ".step": { + "partition_spec": [], + "shape": [] + }, + ".params/['params']/['decoder']/['decoder_norm']/['scale']": { + "partition_spec": [ + "norm" + ], + "shape": [ + 2048 + ] + }, + 
".params/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "partition_spec": [ + "embed", + "dense_layers", + "mlp" + ], + "shape": [ + 2048, + 1, + 10944 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "partition_spec": [ + "embed", + "dense_layers", + "mlp" + ], + "shape": [ + 2048, + 1, + 10944 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "partition_spec": [ + "mlp", + "dense_layers", + "embed" + ], + "shape": [ + 10944, + 1, + 2048 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "partition_spec": [ + "norm", + "dense_layers" + ], + "shape": [ + 2048, + 1 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "partition_spec": [ + "norm", + "dense_layers" + ], + "shape": [ + 2048, + 1 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "partition_spec": [ + "norm", + "dense_layers" + ], + "shape": [ + 512, + 1 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "partition_spec": [ + "heads", + "dense_layers", + "kv", + "embed" + ], + "shape": [ + 16, + 1, + 128, + 2048 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { + "partition_spec": [ + "embed", + "dense_layers", + "q_heads", + "kv" + ], + "shape": [ + 2048, + 1, + 16, + 192 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "partition_spec": [ + "embed", + "dense_layers", + "kv_lora_up_proj" + ], + "shape": [ + 2048, + 1, + 576 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { + "partition_spec": [ + "kv_lora", + "dense_layers", + "kv_heads", + "kv_head_dim" + ], + "shape": [ + 512, + 1, + 16, + 256 + ] + }, + 
".params/['params']/['decoder']/['logits_dense']/['kernel']": { + "partition_spec": [ + "embed_vocab", + "vocab" + ], + "shape": [ + 2048, + 102400 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "partition_spec": [ + "embed_moe", + "moe_layers", + null + ], + "shape": [ + 2048, + 26, + 64 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "partition_spec": [ + "exp", + "moe_layers", + "embed_moe", + "mlp_moe" + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "partition_spec": [ + "exp", + "moe_layers", + "embed_moe", + "mlp_moe" + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "partition_spec": [ + "exp", + "moe_layers", + "mlp_moe", + "embed_moe" + ], + "shape": [ + 64, + 26, + 1408, + 2048 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "partition_spec": [ + "embed", + "moe_layers", + "mlp" + ], + "shape": [ + 2048, + 26, + 2816 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "partition_spec": [ + "embed", + "moe_layers", + "mlp" + ], + "shape": [ + 2048, + 26, + 2816 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "partition_spec": [ + "mlp", + "moe_layers", + "embed" + ], + "shape": [ + 2816, + 26, + 2048 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "partition_spec": [ + "norm", + "moe_layers" + ], + "shape": [ + 2048, + 26 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + 
"partition_spec": [ + "norm", + "moe_layers" + ], + "shape": [ + 2048, + 26 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "partition_spec": [ + "norm", + "moe_layers" + ], + "shape": [ + 512, + 26 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "partition_spec": [ + "heads", + "moe_layers", + "kv", + "embed" + ], + "shape": [ + 16, + 26, + 128, + 2048 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { + "partition_spec": [ + "embed", + "moe_layers", + "q_heads", + "kv" + ], + "shape": [ + 2048, + 26, + 16, + 192 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "partition_spec": [ + "embed", + "moe_layers", + "kv_lora_up_proj" + ], + "shape": [ + 2048, + 26, + 576 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { + "partition_spec": [ + "kv_lora", + "moe_layers", + "kv_heads", + "kv_head_dim" + ], + "shape": [ + 512, + 26, + 16, + 256 + ] + }, + ".params/['params']/['token_embedder']/['embedding']": { + "partition_spec": [ + "vocab", + "embed_vocab" + ], + "shape": [ + 102400, + 2048 + ] + }, + ".opt_state/[0]/.count": { + "partition_spec": [], + "shape": [] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['decoder_norm']/['scale']": { + "partition_spec": [ + "norm" + ], + "shape": [ + 2048 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "partition_spec": [ + "embed", + "dense_layers", + "mlp" + ], + "shape": [ + 2048, + 1, + 10944 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "partition_spec": [ + "embed", + "dense_layers", + "mlp" + ], + "shape": [ + 2048, + 1, + 10944 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + 
"partition_spec": [ + "mlp", + "dense_layers", + "embed" + ], + "shape": [ + 10944, + 1, + 2048 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "partition_spec": [ + "norm", + "dense_layers" + ], + "shape": [ + 2048, + 1 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "partition_spec": [ + "norm", + "dense_layers" + ], + "shape": [ + 2048, + 1 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "partition_spec": [ + "norm", + "dense_layers" + ], + "shape": [ + 512, + 1 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "partition_spec": [ + "heads", + "dense_layers", + "kv", + "embed" + ], + "shape": [ + 16, + 1, + 128, + 2048 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { + "partition_spec": [ + "embed", + "dense_layers", + "q_heads", + "kv" + ], + "shape": [ + 2048, + 1, + 16, + 192 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "partition_spec": [ + "embed", + "dense_layers", + "kv_lora_up_proj" + ], + "shape": [ + 2048, + 1, + 576 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { + "partition_spec": [ + "kv_lora", + "dense_layers", + "kv_heads", + "kv_head_dim" + ], + "shape": [ + 512, + 1, + 16, + 256 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": { + "partition_spec": [ + "embed_vocab", + "vocab" + ], + "shape": [ + 2048, + 102400 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "partition_spec": [ + "embed_moe", + "moe_layers", + null + ], + "shape": [ + 2048, + 26, + 
64 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "partition_spec": [ + "exp", + "moe_layers", + "embed_moe", + "mlp_moe" + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "partition_spec": [ + "exp", + "moe_layers", + "embed_moe", + "mlp_moe" + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "partition_spec": [ + "exp", + "moe_layers", + "mlp_moe", + "embed_moe" + ], + "shape": [ + 64, + 26, + 1408, + 2048 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "partition_spec": [ + "embed", + "moe_layers", + "mlp" + ], + "shape": [ + 2048, + 26, + 2816 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "partition_spec": [ + "embed", + "moe_layers", + "mlp" + ], + "shape": [ + 2048, + 26, + 2816 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "partition_spec": [ + "mlp", + "moe_layers", + "embed" + ], + "shape": [ + 2816, + 26, + 2048 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "partition_spec": [ + "norm", + "moe_layers" + ], + "shape": [ + 2048, + 26 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "partition_spec": [ + "norm", + "moe_layers" + ], + "shape": [ + 2048, + 26 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "partition_spec": [ + "norm", + "moe_layers" + ], + "shape": [ + 512, + 26 + ] + }, 
+ ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "partition_spec": [ + "heads", + "moe_layers", + "kv", + "embed" + ], + "shape": [ + 16, + 26, + 128, + 2048 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { + "partition_spec": [ + "embed", + "moe_layers", + "q_heads", + "kv" + ], + "shape": [ + 2048, + 26, + 16, + 192 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "partition_spec": [ + "embed", + "moe_layers", + "kv_lora_up_proj" + ], + "shape": [ + 2048, + 26, + 576 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { + "partition_spec": [ + "kv_lora", + "moe_layers", + "kv_heads", + "kv_head_dim" + ], + "shape": [ + 512, + 26, + 16, + 256 + ] + }, + ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { + "partition_spec": [ + "vocab", + "embed_vocab" + ], + "shape": [ + 102400, + 2048 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['decoder_norm']/['scale']": { + "partition_spec": [ + "norm" + ], + "shape": [ + 2048 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "partition_spec": [ + "embed", + "dense_layers", + "mlp" + ], + "shape": [ + 2048, + 1, + 10944 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "partition_spec": [ + "embed", + "dense_layers", + "mlp" + ], + "shape": [ + 2048, + 1, + 10944 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "partition_spec": [ + "mlp", + "dense_layers", + "embed" + ], + "shape": [ + 10944, + 1, + 2048 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "partition_spec": [ + "norm", + "dense_layers" + ], + "shape": [ + 2048, + 1 
+ ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "partition_spec": [ + "norm", + "dense_layers" + ], + "shape": [ + 2048, + 1 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "partition_spec": [ + "norm", + "dense_layers" + ], + "shape": [ + 512, + 1 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "partition_spec": [ + "heads", + "dense_layers", + "kv", + "embed" + ], + "shape": [ + 16, + 1, + 128, + 2048 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { + "partition_spec": [ + "embed", + "dense_layers", + "q_heads", + "kv" + ], + "shape": [ + 2048, + 1, + 16, + 192 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "partition_spec": [ + "embed", + "dense_layers", + "kv_lora_up_proj" + ], + "shape": [ + 2048, + 1, + 576 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { + "partition_spec": [ + "kv_lora", + "dense_layers", + "kv_heads", + "kv_head_dim" + ], + "shape": [ + 512, + 1, + 16, + 256 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": { + "partition_spec": [ + "embed_vocab", + "vocab" + ], + "shape": [ + 2048, + 102400 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "partition_spec": [ + "embed_moe", + "moe_layers", + null + ], + "shape": [ + 2048, + 26, + 64 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "partition_spec": [ + "exp", + "moe_layers", + "embed_moe", + "mlp_moe" + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + 
".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "partition_spec": [ + "exp", + "moe_layers", + "embed_moe", + "mlp_moe" + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "partition_spec": [ + "exp", + "moe_layers", + "mlp_moe", + "embed_moe" + ], + "shape": [ + 64, + 26, + 1408, + 2048 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "partition_spec": [ + "embed", + "moe_layers", + "mlp" + ], + "shape": [ + 2048, + 26, + 2816 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "partition_spec": [ + "embed", + "moe_layers", + "mlp" + ], + "shape": [ + 2048, + 26, + 2816 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "partition_spec": [ + "mlp", + "moe_layers", + "embed" + ], + "shape": [ + 2816, + 26, + 2048 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "partition_spec": [ + "norm", + "moe_layers" + ], + "shape": [ + 2048, + 26 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "partition_spec": [ + "norm", + "moe_layers" + ], + "shape": [ + 2048, + 26 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "partition_spec": [ + "norm", + "moe_layers" + ], + "shape": [ + 512, + 26 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "partition_spec": [ + "heads", + "moe_layers", + "kv", + "embed" + ], + "shape": [ + 16, + 26, + 128, + 2048 + ] + }, + 
".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { + "partition_spec": [ + "embed", + "moe_layers", + "q_heads", + "kv" + ], + "shape": [ + 2048, + 26, + 16, + 192 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "partition_spec": [ + "embed", + "moe_layers", + "kv_lora_up_proj" + ], + "shape": [ + 2048, + 26, + 576 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { + "partition_spec": [ + "kv_lora", + "moe_layers", + "kv_heads", + "kv_head_dim" + ], + "shape": [ + 512, + 26, + 16, + 256 + ] + }, + ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { + "partition_spec": [ + "vocab", + "embed_vocab" + ], + "shape": [ + 102400, + 2048 + ] + }, + ".opt_state/[2]/.count": { + "partition_spec": [], + "shape": [] + } +} \ No newline at end of file diff --git a/tests/utils/sharding_info/deepseek2-16b/tpu7x-8/slice_1/rule_ep-as-cp_ici_fsdp_parallelism=-1_ici_expert_parallelism=2/named_shardings.json b/tests/utils/sharding_info/deepseek2-16b/tpu7x-8/slice_1/rule_ep-as-cp_ici_fsdp_parallelism=-1_ici_expert_parallelism=2/named_shardings.json new file mode 100644 index 0000000000..76e3036a34 --- /dev/null +++ b/tests/utils/sharding_info/deepseek2-16b/tpu7x-8/slice_1/rule_ep-as-cp_ici_fsdp_parallelism=-1_ici_expert_parallelism=2/named_shardings.json @@ -0,0 +1,2132 @@ +{ + ".step": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [], + "shape": [] + }, + ".params/['params']/['decoder']/['decoder_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + null + ], + "shape": [ + 2048 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + 
"mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 1, + 10944 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 1, + 10944 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + null, + null, + [ + "fsdp", + "expert" + ] + ], + "shape": [ + 10944, + 1, + 2048 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 2048, + 1 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 2048, + 1 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 512, + 1 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + 
"expert": 2 + } + }, + "partition_spec": [ + null, + null, + null, + [ + "fsdp", + "expert" + ] + ], + "shape": [ + 16, + 1, + 128, + 2048 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "expert" + ], + null, + null, + null + ], + "shape": [ + 2048, + 1, + 16, + 192 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 1, + 576 + ] + }, + ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "expert" + ], + null, + null, + null + ], + "shape": [ + 512, + 1, + 16, + 256 + ] + }, + ".params/['params']/['decoder']/['logits_dense']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "expert" + ], + null + ], + "shape": [ + 2048, + 102400 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + "fsdp", + null, + null + ], + "shape": [ + 2048, + 26, + 64 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + 
"data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + "expert", + null, + "fsdp", + null + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + "expert", + null, + "fsdp", + null + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + "expert", + null, + null, + "fsdp" + ], + "shape": [ + 64, + 26, + 1408, + 2048 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 26, + 2816 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 26, + 2816 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + null, + null, + [ + "fsdp", + "expert" + ] + ], + "shape": [ + 2816, + 26, + 2048 + ] + }, + 
".params/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 2048, + 26 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 2048, + 26 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 512, + 26 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + null, + null, + null, + [ + "fsdp", + "expert" + ] + ], + "shape": [ + 16, + 26, + 128, + 2048 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "expert" + ], + null, + null, + null + ], + "shape": [ + 2048, + 26, + 16, + 192 + ] + }, + ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 26, + 576 + ] + }, + 
".params/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "expert" + ], + null, + null, + null + ], + "shape": [ + 512, + 26, + 16, + 256 + ] + }, + ".params/['params']/['token_embedder']/['embedding']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + null, + [ + "fsdp", + "expert" + ] + ], + "shape": [ + 102400, + 2048 + ] + }, + ".opt_state/[0]/.count": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [], + "shape": [] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['decoder_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + null + ], + "shape": [ + 2048 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 1, + 10944 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 1, + 10944 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + 
"partition_spec": [ + null, + null, + [ + "fsdp", + "expert" + ] + ], + "shape": [ + 10944, + 1, + 2048 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 2048, + 1 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 2048, + 1 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 512, + 1 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + null, + null, + null, + [ + "fsdp", + "expert" + ] + ], + "shape": [ + 16, + 1, + 128, + 2048 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "expert" + ], + null, + null, + null + ], + "shape": [ + 2048, + 1, + 16, + 192 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + 
} + }, + "partition_spec": [ + [ + "fsdp", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 1, + 576 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "expert" + ], + null, + null, + null + ], + "shape": [ + 512, + 1, + 16, + 256 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "expert" + ], + null + ], + "shape": [ + 2048, + 102400 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + "fsdp", + null, + null + ], + "shape": [ + 2048, + 26, + 64 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + "expert", + null, + "fsdp", + null + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + "expert", + null, + "fsdp", + null + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + 
"expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + "expert", + null, + null, + "fsdp" + ], + "shape": [ + 64, + 26, + 1408, + 2048 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 26, + 2816 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 26, + 2816 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + null, + null, + [ + "fsdp", + "expert" + ] + ], + "shape": [ + 2816, + 26, + 2048 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 2048, + 26 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 2048, + 26 + ] + }, + 
".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 512, + 26 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + null, + null, + null, + [ + "fsdp", + "expert" + ] + ], + "shape": [ + 16, + 26, + 128, + 2048 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "expert" + ], + null, + null, + null + ], + "shape": [ + 2048, + 26, + 16, + 192 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 26, + 576 + ] + }, + ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "expert" + ], + null, + null, + null + ], + "shape": [ + 512, + 26, + 16, + 256 + ] + }, + ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + null, + [ + "fsdp", + "expert" + ] + ], + 
"shape": [ + 102400, + 2048 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['decoder_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + null + ], + "shape": [ + 2048 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 1, + 10944 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 1, + 10944 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + null, + null, + [ + "fsdp", + "expert" + ] + ], + "shape": [ + 10944, + 1, + 2048 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 2048, + 1 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 2048, + 1 + ] + }, + 
".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 512, + 1 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + null, + null, + null, + [ + "fsdp", + "expert" + ] + ], + "shape": [ + 16, + 1, + 128, + 2048 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "expert" + ], + null, + null, + null + ], + "shape": [ + 2048, + 1, + 16, + 192 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 1, + 576 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "expert" + ], + null, + null, + null + ], + "shape": [ + 512, + 1, + 16, + 256 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "expert" + ], + null 
+ ], + "shape": [ + 2048, + 102400 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + "fsdp", + null, + null + ], + "shape": [ + 2048, + 26, + 64 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + "expert", + null, + "fsdp", + null + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + "expert", + null, + "fsdp", + null + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + "expert", + null, + null, + "fsdp" + ], + "shape": [ + 64, + 26, + 1408, + 2048 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 26, + 2816 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + 
], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 26, + 2816 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + null, + null, + [ + "fsdp", + "expert" + ] + ], + "shape": [ + 2816, + 26, + 2048 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 2048, + 26 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 2048, + 26 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + null, + null + ], + "shape": [ + 512, + 26 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + null, + null, + null, + [ + "fsdp", + "expert" + ] + ], + "shape": [ + 16, + 26, + 128, + 2048 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + 
"expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "expert" + ], + null, + null, + null + ], + "shape": [ + 2048, + 26, + 16, + 192 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 26, + 576 + ] + }, + ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + [ + "fsdp", + "expert" + ], + null, + null, + null + ], + "shape": [ + 512, + 26, + 16, + 256 + ] + }, + ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [ + null, + [ + "fsdp", + "expert" + ] + ], + "shape": [ + 102400, + 2048 + ] + }, + ".opt_state/[2]/.count": { + "mesh": { + "axis_names": [ + "data", + "fsdp", + "expert" + ], + "shape": { + "data": 1, + "fsdp": 4, + "expert": 2 + } + }, + "partition_spec": [], + "shape": [] + } +} \ No newline at end of file From 73cb0957c31a04bd0a2f41145d9df3ce44bbe9cf Mon Sep 17 00:00:00 2001 From: NuojCheng Date: Wed, 15 Apr 2026 16:00:25 +0000 Subject: [PATCH 3/3] deprecate expert_shard_attention_option config --- .../core_concepts/moe_configuration.md | 5 - src/maxtext/common/common_types.py | 4 - src/maxtext/configs/base.yml | 8 +- .../configs/custom_mesh_and_rule/ep-as-cp.yml | 32 ++- src/maxtext/configs/types.py | 6 +- src/maxtext/layers/attention_mla.py | 9 +- src/maxtext/layers/attention_op.py | 3 +- .../named_shardings.json | 198 +++++++++++++++--- 8 files 
changed, 206 insertions(+), 59 deletions(-) diff --git a/docs/reference/core_concepts/moe_configuration.md b/docs/reference/core_concepts/moe_configuration.md index 7ce7d63110..c150022b15 100644 --- a/docs/reference/core_concepts/moe_configuration.md +++ b/docs/reference/core_concepts/moe_configuration.md @@ -96,11 +96,6 @@ Dropping: ## 2. Sharding -`expert_shard_attention_option`: Determines how the "expert" axis is interpreted when sharding attention layers. Options include: - -- `fsdp`: Treats the expert axis as a FSDP axis. -- `context`: Treats the expert axis as a context parallelism axis, useful for long context. - `use_ring_of_experts` (experimental): This feature requires expert parallelism. If enabled, it replaces the standard two All-to-All communications with All-Gather in dispatch and Reduce-Scatter in collect. By gathering inputs across all shards, it allows for local routing and Top-K calculations, followed by result aggregation via Reduce-Scatter. This approach is particularly effective for models with a large Top-K, as it gathers activations before they are replicated k times to reduce communication. `moe_fsdp_use_two_stage_all_gather`: If enabled, split the All-Gather operation for MoE weights into two separate stages when using FSDP/FSDP-transpose sharding. This is preferred when 3D All-Gather support is unavailable. 
diff --git a/src/maxtext/common/common_types.py b/src/maxtext/common/common_types.py index 8ab7182779..ec2e96333f 100644 --- a/src/maxtext/common/common_types.py +++ b/src/maxtext/common/common_types.py @@ -66,10 +66,6 @@ MODEL_MODE_PREFILL = "prefill" MODEL_MODE_TRAIN = "train" -# expert_shard_attention_option -EP_AS_CONTEXT = "context" -EP_AS_FSDP = "fsdp" - DECODING_ACTIVE_SEQUENCE_INDICATOR = 1 # A large negative mask value is used for masking to ensure that the diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 44d7bde339..b680912ae0 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -237,11 +237,6 @@ merge_gating_gmm: False norm_topk_prob: false # boolean to enable the top-k probability normalization. qwen3-specific normalization of router weights. -# how the expert axis is used to shard attention weights and activations -# "fsdp" (ep acts as fsdp parallelism) -# "context" (ep acts as context parallelism, training only) -expert_shard_attention_option: "fsdp" - # when moe weight matrices are sharded on both fsdp and fsdp-transpose axes, use two separate all-gather calls moe_fsdp_use_two_stage_all_gather: false # Shard the expert dimension of the MLP weights on the FSDP axis. 
@@ -521,6 +516,7 @@ logical_axis_rules: [ # ========================================== # Dense Activations ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']], + # Note activation batch and length also get used in attention and vocab ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_length', ['sequence', 'context']], ['activation_length', ['context']], @@ -569,6 +565,8 @@ logical_axis_rules: [ # Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']] input_data_sharding_logical_axes: ['activation_embed_and_logits_batch', 'activation_norm_length'] +# Determines which physical axis plays the role of context parallelism for input data processing and load balancing +# only supports "context" or "expert" (when custom_mesh_and_rule=ep-as-cp) context_sharding: "context" # sharding tolerance: float between 0.0 and 1.0 representing the allowed percentage of non-sharded parameters. diff --git a/src/maxtext/configs/custom_mesh_and_rule/ep-as-cp.yml b/src/maxtext/configs/custom_mesh_and_rule/ep-as-cp.yml index 00969328ca..c13be7e266 100644 --- a/src/maxtext/configs/custom_mesh_and_rule/ep-as-cp.yml +++ b/src/maxtext/configs/custom_mesh_and_rule/ep-as-cp.yml @@ -1,13 +1,29 @@ -mesh_axes: ['data', 'fsdp', 'expert'] -data_sharding: [['data', 'fsdp', 'expert']] +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This rule uses data, stage, FSDP, and expert. Expert axis acts as context parallelism in +# components except core dMoE part (between EP all2all). +mesh_axes: ['data', 'stage', 'fsdp', 'expert'] +data_sharding: [['data', 'stage', 'fsdp', 'expert']] context_sharding: 'expert' logical_axis_rules: [ # ========================================== # Vocabulary Embedding # ========================================== # Vocab Activations - ['activation_embed_and_logits_batch', ['data', 'fsdp']], - ['activation_embed_and_logits_batch_sequence', ['data', 'fsdp', 'expert']], + ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp']], + ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'expert']], # Vocab Weights ['vocab', []], ['embed_vocab', ['fsdp', 'expert']], @@ -31,9 +47,9 @@ logical_axis_rules: [ ['qkv', []], ['kv', []], ['kv_head_dim', []], - ['q_lora', ['fsdp', 'expert']], + ['q_lora', ['fsdp']], ["q_lora_up_proj", []], - ['kv_lora', ['fsdp', 'expert']], + ['kv_lora', ['fsdp']], ["kv_lora_up_proj", []], # ========================================== # Mixture of Experts (MoE) @@ -53,7 +69,9 @@ logical_axis_rules: [ ['activation_length', ['expert']], ['activation_norm_length', ['expert']], ['activation_embed', []], + ['activation_stage', 'stage'], # General Weights ['mlp', []], + ['layers', 'stage'], ['embed', ['fsdp', 'expert']], - ] \ No newline at end of file + ] diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 22765c1843..ee5b916f21 100644 --- a/src/maxtext/configs/types.py +++ 
b/src/maxtext/configs/types.py @@ -661,10 +661,6 @@ class MoEGeneral(BaseModel): ) use_random_routing: bool = Field(False, description="Whether to use random routing for debugging.") interleave_moe_layer_step: int = Field(1, description="Frequency of MoE layers, e.g., 2 means every 2nd layer is MoE.") - expert_shard_attention_option: Literal["fsdp", "context"] = Field( - "fsdp", - description="How the expert axis is used to shard attention weights and activations.", - ) moe_fsdp_use_two_stage_all_gather: bool = Field( False, description="Use two separate All-Gather calls for MoE weights sharded on both FSDP and FSDP-transpose.", @@ -2605,6 +2601,8 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de ) if self.quantization: raise ValueError("Quantization is not supported with 'explicit' sharding.") + if self.context_sharding not in ("context", "expert"): + raise ValueError(f"Assigned context_sharding {self.context_sharding} is not supported.") if ( self.per_device_batch_size > 0 and (self.per_device_batch_size * self.max_target_length) % self.num_vocab_tiling != 0 diff --git a/src/maxtext/layers/attention_mla.py b/src/maxtext/layers/attention_mla.py index 5b2ab5c119..f0d7791a0d 100644 --- a/src/maxtext/layers/attention_mla.py +++ b/src/maxtext/layers/attention_mla.py @@ -48,7 +48,6 @@ D_KV, DType, EMBED, - EP_AS_CONTEXT, HEAD, Q_LORA_UP_PROJ, KV_BATCH, @@ -905,9 +904,6 @@ def mla_get_key_value(self, low_rank_main, key_rope, model_mode): if model_mode == MODEL_MODE_PREFILL: key_logical_name = self.prefill_key_axis_names value_logical_name = self.prefill_value_axis_names - elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT: - key_logical_name = self.ep_key_axis_names - value_logical_name = self.ep_value_axis_names else: key_logical_name = self.key_axis_names value_logical_name = self.value_axis_names @@ -1227,11 +1223,8 @@ def __call__( record_max_logits=use_qk_clip, ) + out = 
self._maybe_shard_with_logical(out, self.out_axis_names) out = jax.ad_checkpoint.checkpoint_name(out, "attention_out") - if model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT: - out = self._maybe_shard_with_logical(out, self.ep_out_axis_names) - else: - out = self._maybe_shard_with_logical(out, self.out_axis_names) out_sharding = create_sharding(self.mesh, out_logical_name) out = self.out_projection(out, out_sharding=out_sharding) diff --git a/src/maxtext/layers/attention_op.py b/src/maxtext/layers/attention_op.py index ddc4fff9a1..c770501309 100644 --- a/src/maxtext/layers/attention_op.py +++ b/src/maxtext/layers/attention_op.py @@ -55,7 +55,6 @@ DEFAULT_MASK_VALUE, DType, D_KV, - EP_AS_FSDP, HEAD, KV_LENGTH, LENGTH, @@ -1270,7 +1269,7 @@ def wrap_splash_kernel(single_head_mask): splash_kernel = wrap_splash_kernel(single_head_mask) segment_axis_names_splash_kernel = self._logical_to_mesh_axes((Q_LENGTH,)) - elif self.config.use_jax_splash and self.config.expert_shard_attention_option == EP_AS_FSDP: + elif self.config.use_jax_splash: if self.config.use_max_logit_estimate > 0: sa_config = dataclasses.replace(sa_config, max_logit_const=self.config.use_max_logit_estimate) segment_axis_names_splash_kernel = nn.logical_to_mesh_axes((Q_LENGTH,)) diff --git a/tests/utils/sharding_info/deepseek2-16b/tpu7x-8/slice_1/rule_ep-as-cp_ici_fsdp_parallelism=-1_ici_expert_parallelism=2/named_shardings.json b/tests/utils/sharding_info/deepseek2-16b/tpu7x-8/slice_1/rule_ep-as-cp_ici_fsdp_parallelism=-1_ici_expert_parallelism=2/named_shardings.json index 76e3036a34..b8bdbc2c58 100644 --- a/tests/utils/sharding_info/deepseek2-16b/tpu7x-8/slice_1/rule_ep-as-cp_ici_fsdp_parallelism=-1_ici_expert_parallelism=2/named_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/tpu7x-8/slice_1/rule_ep-as-cp_ici_fsdp_parallelism=-1_ici_expert_parallelism=2/named_shardings.json @@ -3,11 +3,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", 
"expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -19,11 +21,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -39,11 +43,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -66,11 +72,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -93,11 +101,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -120,11 +130,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -142,11 +154,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -164,11 +178,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -186,11 +202,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -215,11 +233,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -244,11 +264,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -271,20 +293,19 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } }, "partition_spec": [ - [ - "fsdp", - "expert" - ], + "fsdp", null, null, null @@ -300,11 +321,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -325,11 +348,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", 
"expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -349,11 +374,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -375,11 +402,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -401,11 +430,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -427,11 +458,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -454,11 +487,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -481,11 +516,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -508,11 +545,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -530,11 +569,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -552,11 +593,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -574,11 +617,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -603,11 +648,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -632,11 +679,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -659,20 +708,19 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } }, 
"partition_spec": [ - [ - "fsdp", - "expert" - ], + "fsdp", null, null, null @@ -688,11 +736,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -713,11 +763,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -729,11 +781,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -749,11 +803,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -776,11 +832,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -803,11 +861,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -830,11 +890,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -852,11 +914,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -874,11 +938,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -896,11 +962,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -925,11 +993,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -954,11 +1024,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -981,20 +1053,19 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 
2 } }, "partition_spec": [ - [ - "fsdp", - "expert" - ], + "fsdp", null, null, null @@ -1010,11 +1081,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -1035,11 +1108,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -1059,11 +1134,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -1085,11 +1162,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -1111,11 +1190,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -1137,11 +1218,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -1164,11 +1247,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -1191,11 +1276,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -1218,11 +1305,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -1240,11 +1329,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -1262,11 +1353,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -1284,11 +1377,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -1313,11 +1408,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + 
"stage": 1, "fsdp": 4, "expert": 2 } @@ -1342,11 +1439,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -1369,20 +1468,19 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } }, "partition_spec": [ - [ - "fsdp", - "expert" - ], + "fsdp", null, null, null @@ -1398,11 +1496,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -1423,11 +1523,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -1443,11 +1545,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -1470,11 +1574,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -1497,11 +1603,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -1524,11 +1632,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -1546,11 +1656,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -1568,11 +1680,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -1590,11 +1704,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -1619,11 +1735,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -1648,11 +1766,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", 
"expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -1675,20 +1795,19 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } }, "partition_spec": [ - [ - "fsdp", - "expert" - ], + "fsdp", null, null, null @@ -1704,11 +1823,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -1729,11 +1850,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -1753,11 +1876,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -1779,11 +1904,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -1805,11 +1932,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -1831,11 +1960,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -1858,11 +1989,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -1885,11 +2018,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -1912,11 +2047,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -1934,11 +2071,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -1956,11 +2095,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -1978,11 +2119,13 @@ "mesh": { 
"axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -2007,11 +2150,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -2036,11 +2181,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -2063,20 +2210,19 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } }, "partition_spec": [ - [ - "fsdp", - "expert" - ], + "fsdp", null, null, null @@ -2092,11 +2238,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 } @@ -2117,11 +2265,13 @@ "mesh": { "axis_names": [ "data", + "stage", "fsdp", "expert" ], "shape": { "data": 1, + "stage": 1, "fsdp": 4, "expert": 2 }