diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml
index 2340fd72aa..4671d76d4c 100644
--- a/src/maxtext/configs/base.yml
+++ b/src/maxtext/configs/base.yml
@@ -488,6 +488,7 @@ logical_axis_rules: [
   ['activation_stage', 'stage'],
   ['activation_exp', ['expert']],
   ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
+  ['decode_batch_moe', ['data', 'fsdp', 'fsdp_transpose']],
   ['decode_length', ['sequence']],
   ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
   ['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
diff --git a/src/maxtext/configs/inference/inference.yml b/src/maxtext/configs/inference/inference.yml
index 55407b3edc..1206675599 100644
--- a/src/maxtext/configs/inference/inference.yml
+++ b/src/maxtext/configs/inference/inference.yml
@@ -24,6 +24,7 @@ logical_axis_rules: [
   ['activation_stage', 'stage'],
   ['activation_exp', ['expert', 'context_autoregressive']],
   ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context_autoregressive']],
+  ['decode_batch_moe', ['data', 'fsdp', 'fsdp_transpose', 'context_autoregressive']],
   ['decode_length', []],
   ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
   ['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
diff --git a/src/maxtext/configs/inference/vllm.yml b/src/maxtext/configs/inference/vllm.yml
index ffa984296d..d939663a69 100644
--- a/src/maxtext/configs/inference/vllm.yml
+++ b/src/maxtext/configs/inference/vllm.yml
@@ -54,6 +54,7 @@ logical_axis_rules: [
   ['activation_norm_length_moe', []],
   ['activation_exp', ['expert', 'attn_dp_expert']],
   ['decode_batch', ['expert', 'attn_dp_expert']],
+  ['decode_batch_moe', []],
   ['decode_length', []],
   ['mlp', ['model', 'attn_dp']],
   ['mlp_moe', ['model', 'attn_dp']],
diff --git a/src/maxtext/configs/post_train/rl_mt_jt.yml b/src/maxtext/configs/post_train/rl_mt_jt.yml
index 4383b1c4ac..34829fbc19 100644
--- a/src/maxtext/configs/post_train/rl_mt_jt.yml
+++ b/src/maxtext/configs/post_train/rl_mt_jt.yml
@@ -39,6 +39,7 @@ logical_axis_rules: [
   ['activation_stage', 'stage'],
   ['activation_exp', ['expert', 'context_autoregressive']],
   ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context_autoregressive']],
+  ['decode_batch_moe', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context_autoregressive']],
   ['decode_length', []],
   ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
   ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive','context_autoregressive']],
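The four config hunks above register the new `decode_batch_moe` logical axis. What matters is which mesh axes each rule omits: in base.yml and the inference configs it drops `'expert'` (and in vllm.yml it maps to nothing at all), so an activation batch dimension annotated with it stays replicated across the expert mesh axis, while rl_mt_jt.yml keeps the same mapping as `decode_batch`. A minimal sketch of how such a rule resolves to a `PartitionSpec`, assuming Flax's `nn.logical_to_mesh_axes` and a rule table trimmed to two entries:

```python
from flax import linen as nn

# Trimmed two-rule table mirroring base.yml; the real configs carry many more rules.
rules = (
    ("decode_batch", ("data", "fsdp", "fsdp_transpose", "expert")),
    ("decode_batch_moe", ("data", "fsdp", "fsdp_transpose")),  # no 'expert': replicated over it
)

# Expected: PartitionSpec(('data', 'fsdp', 'fsdp_transpose'), None); the None dim
# stays unsharded and nothing maps onto the 'expert' mesh axis.
print(nn.logical_to_mesh_axes(("decode_batch_moe", None), rules))
```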
diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py
index 89325d14f8..3dd1322f7b 100644
--- a/src/maxtext/layers/moe.py
+++ b/src/maxtext/layers/moe.py
@@ -1027,7 +1027,13 @@ def gmm(
         output = output[: orig_inputs_shape[0]]
       return output

-    batch_logical_axis = "activation_batch"
+    # The batch is sharded by expert, except during inference decoding (where batch size == 1).
+    # In the decoding case, the batch dimension is instead replicated across the expert mesh axis.
+    is_batch_sharded_by_expert = inputs.shape[0] > 1
+    if is_batch_sharded_by_expert:
+      batch_logical_axis = "activation_batch"
+    else:
+      batch_logical_axis = "decode_batch_moe"

     if self.get_tensor_transpose_parallelism_size() > 1:
       input_partition_pspec = self._logical_to_mesh_axes(
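The hunk above is the heart of the change: with a single token row at decode time, the batch cannot be split over the expert axis, so the code falls back to the expert-free `decode_batch_moe` annotation. A sketch of the same pattern, assuming a hypothetical helper that runs under an `nn.logical_axis_rules(...)` context carrying the YAML rules (`activation_embed` is one of MaxText's logical axis names):

```python
import jax.numpy as jnp
from flax import linen as nn

def constrain_token_batch(x: jnp.ndarray) -> jnp.ndarray:
  # Shapes are static under jit, so this branch resolves at trace time.
  # A single-row batch cannot be split over the 'expert' mesh axis, so pick
  # the rule that leaves it replicated there instead.
  axis = "activation_batch" if x.shape[0] > 1 else "decode_batch_moe"
  return nn.with_logical_constraint(x, (axis, "activation_embed"))
```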
@@ -1142,47 +1148,59 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
       )
       if num_expert_parallelism > 1:
+        batch_axis = self._expert_parallelism_name if is_batch_sharded_by_expert else "data"
         # get group sizes for all shards
         local_expert_size = self.config.num_experts // num_expert_parallelism
         reshaped_group_sizes = jnp.sum(group_sizes.reshape(-1, local_expert_size), axis=1)
         global_group_sizes = group_sizes
-        all_shards_group_sizes = jax.lax.all_gather(reshaped_group_sizes, axis_name=self._expert_parallelism_name)
-        input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params(
-            all_shards_group_sizes,
-            expert_shard_id,
-            num_expert_parallelism,
-        )
-
-        # TODO(ranran): For better performance, we could update output buffer to a smaller
-        # size to replace self.get_expert_parallelism_size() for efficiency,
-        # Or we could apply capacity_factor for excessive experts.
-        # Note: Reducing buffer increase the risk of token dropping under unbalanced distribution.
-
-        # In the worst case, all of the global input data is assigned to each expert in the current shard.
-        # This would result in num_expert_shards * input_size * experts_per_shard assignments. However, if
-        # experts_per_shard > num_experts_per_tok we cannot assign more than num_experts_per_tok to all of the inputs.
-        max_local_experts_per_tok = min(local_expert_size, self.config.num_experts_per_tok)
-        buffer_size = int(num_expert_parallelism * batch_size * sequence_length * max_local_experts_per_tok)
-        output_shape = jax.lax.empty((buffer_size, self.config.emb_dim), dtype=x.dtype)
-
-        x = jax.lax.ragged_all_to_all(
-            x,
-            output_shape,
-            input_offsets,
-            send_sizes,
-            output_offsets,
-            recv_sizes,
-            axis_name=self._expert_parallelism_name,
-        )
-        global_group_sizes = jax.lax.all_gather(group_sizes, axis_name=self._expert_parallelism_name)
-        x, local_sorted_indices, group_sizes, selected_experts = RoutedMoE.local_permute(
-            x,
-            global_group_sizes,
-            local_expert_size,
-            shard_index=expert_shard_id,
-            use_custom_sort_vjp=self.config.use_custom_sort_vjp,
-        )
+        if is_batch_sharded_by_expert:
+          all_shards_group_sizes = jax.lax.all_gather(reshaped_group_sizes, axis_name=batch_axis)
+          input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params(
+              all_shards_group_sizes,
+              expert_shard_id,
+              num_expert_parallelism,
+          )
+
+          # TODO(ranran): For better performance, we could shrink the output buffer instead of
+          # sizing it by self.get_expert_parallelism_size(), or we could apply capacity_factor
+          # for excessive experts.
+          # Note: reducing the buffer increases the risk of token dropping under an unbalanced distribution.
+
+          # In the worst case, all of the global input data is assigned to each expert in the current shard.
+          # This would result in num_expert_shards * input_size * experts_per_shard assignments. However, if
+          # experts_per_shard > num_experts_per_tok we cannot assign more than num_experts_per_tok to all of the inputs.
+          max_local_experts_per_tok = min(local_expert_size, self.config.num_experts_per_tok)
+          buffer_size = int(num_expert_parallelism * batch_size * sequence_length * max_local_experts_per_tok)
+          output_shape = jax.lax.empty((buffer_size, self.config.emb_dim), dtype=x.dtype)
+
+          x = jax.lax.ragged_all_to_all(
+              x,
+              output_shape,
+              input_offsets,
+              send_sizes,
+              output_offsets,
+              recv_sizes,
+              axis_name=self._expert_parallelism_name,
+          )
+          global_group_sizes = jax.lax.all_gather(group_sizes, axis_name=self._expert_parallelism_name)
+          x, local_sorted_indices, group_sizes, selected_experts = RoutedMoE.local_permute(
+              x,
+              global_group_sizes,
+              local_expert_size,
+              shard_index=expert_shard_id,
+              use_custom_sort_vjp=self.config.use_custom_sort_vjp,
+          )
+        else:
+          x, local_sorted_indices, group_sizes, selected_experts = RoutedMoE.local_permute(
+              x,
+              global_group_sizes[None, :],
+              local_expert_size,
+              shard_index=expert_shard_id,
+              is_offset=True,
+              global_sorted_experts=selected_experts,
+              use_custom_sort_vjp=self.config.use_custom_sort_vjp,
+          )

       if self.config.mlp_bias:
         w0_bias, w1_bias, wo_bias = self.transform_bias(selected_experts, w0_bias, w1_bias, wo_bias)

@@ -1325,26 +1343,47 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
           dtype=intermediate_output.dtype,
       )

-      # locally unpermute back to the original order
-      local_output = _sort_activations(
-          intermediate_output,
-          jnp.argsort(local_sorted_indices),  # pylint: disable=undefined-variable
-          self.config.use_custom_sort_vjp,
-      )
-      input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params(
-          jnp.transpose(all_shards_group_sizes),  # pylint: disable=undefined-variable
-          expert_shard_id,
-          num_expert_parallelism,
-      )
-      intermediate_output = jax.lax.ragged_all_to_all(
-          local_output,
-          output_shape,
-          input_offsets,
-          send_sizes,
-          output_offsets,
-          recv_sizes,
-          axis_name=self._expert_parallelism_name,
-      )
+      if is_batch_sharded_by_expert:
+        # locally unpermute back to the original order
+        local_output = _sort_activations(
+            intermediate_output,
+            jnp.argsort(local_sorted_indices),  # pylint: disable=undefined-variable
+            self.config.use_custom_sort_vjp,
+        )
+        input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params(
+            jnp.transpose(all_shards_group_sizes),  # pylint: disable=undefined-variable
+            expert_shard_id,
+            num_expert_parallelism,
+        )
+        intermediate_output = jax.lax.ragged_all_to_all(
+            local_output,
+            output_shape,
+            input_offsets,
+            send_sizes,
+            output_offsets,
+            recv_sizes,
+            axis_name=self._expert_parallelism_name,
+        )
+      else:
+        # If the batch is replicated across EP shards, then each shard should send
+        # its 0..local_shard_size data to the other shards and receive the
+        # local_shard data from all of the other shards using
+        # ragged_all_to_all.
+        input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params(
+            reshaped_group_sizes,  # pylint: disable=undefined-variable
+            expert_shard_id,
+            num_expert_parallelism,
+            is_batch_sharded=False,
+        )
+        intermediate_output = jax.lax.ragged_all_to_all(
+            intermediate_output,
+            output_shape,
+            input_offsets,
+            send_sizes,
+            output_offsets,
+            recv_sizes,
+            axis_name=self._expert_parallelism_name,
+        )

     output = self.unpermute(
         intermediate_output,
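Both branches above funnel into `jax.lax.ragged_all_to_all`, which takes four per-peer vectors: where to read (`input_offsets`), how much to send (`send_sizes`), where to write in each receiver (`output_offsets`), and how much arrives from each peer (`recv_sizes`). For the replicated-batch combine in the `else` branch, the comment implies every shard broadcasts its local expert block to all peers, landing at the sender's block offset in each receiver. A NumPy reconstruction of that arithmetic, based only on that comment (names and conventions here are assumptions; `RoutedMoE.get_all_to_all_params` is the authoritative version):

```python
import numpy as np

def replicated_batch_a2a_params(tokens_per_shard: np.ndarray, shard_id: int):
  """tokens_per_shard[s] = rows computed by expert shard s (identical on every shard)."""
  num_shards = tokens_per_shard.shape[0]
  block_starts = np.concatenate([[0], np.cumsum(tokens_per_shard)[:-1]])
  # Every shard reads its whole local result starting at row 0 ...
  input_offsets = np.zeros(num_shards, dtype=np.int32)
  # ... and sends all of it to each peer, ...
  send_sizes = np.full(num_shards, tokens_per_shard[shard_id], dtype=np.int32)
  # ... which lands at this shard's block offset in every receiver's buffer, ...
  output_offsets = np.full(num_shards, block_starts[shard_id], dtype=np.int32)
  # ... while it receives each peer's block, sized by that peer's row count.
  recv_sizes = tokens_per_shard.astype(np.int32)
  return input_offsets, send_sizes, output_offsets, recv_sizes

# 4 EP shards with uneven routing: shard 2 holds 4 rows starting at global row 4.
print(replicated_batch_a2a_params(np.array([3, 1, 4, 2]), shard_id=2))
# -> (array([0, 0, 0, 0]), array([4, 4, 4, 4]), array([4, 4, 4, 4]), array([3, 1, 4, 2]))
```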
diff --git a/tests/integration/decode_tests.py b/tests/integration/decode_tests.py
index f36ecf9efd..9dbe32db53 100644
--- a/tests/integration/decode_tests.py
+++ b/tests/integration/decode_tests.py
@@ -96,6 +96,34 @@ class DecodeTests(unittest.TestCase):
           "prompt=I love to",
           "skip_jax_distributed_system=True",
       ],
+      "deepseek32": [  # tests decode for deepseek3.2-671b full EP
+          None,
+          get_test_config_path(),
+          "base_output_directory=gs://runner-maxtext-logs",
+          "run_name=decode",
+          "model_name=deepseek3.2-671b",
+          "override_model_config=True",
+          "base_num_decoder_layers=2",
+          "first_num_dense_layers=1",
+          "num_experts=16",
+          "base_mlp_dim=128",
+          "base_emb_dim=128",
+          "base_moe_mlp_dim=128",
+          "tokenizer_type=huggingface",
+          f"hf_access_token={os.environ.get('HF_TOKEN', '')}",
+          "tokenizer_path=deepseek-ai/DeepSeek-V3.2-Exp",
+          "scan_layers=False",
+          "attention=dot_product",
+          "weight_dtype=bfloat16",
+          "per_device_batch_size=1",
+          "max_prefill_predict_length=8",
+          "max_target_length=16",
+          "ici_fsdp_parallelism=1",
+          "ici_tensor_parallelism=1",
+          "ici_expert_parallelism=-1",
+          "mla_naive_kvcache=false",
+          "prompt=I love to",
+      ],
   }
   SAMPLING_STRATEGY_CONFIG = {
       "greedy": [
@@ -173,6 +201,11 @@ def test_decode_topk_sampling(self):
     expected_output = "Input `I love to` -> ` travel and I love to write"
     assert expected_output in captured_out

+  @pytest.mark.tpu_only
+  @pytest.mark.scheduled_only
+  def test_tpu_deepseek32(self):
+    decode_main(DecodeTests.CONFIGS["deepseek32"])
+
 def run_decoding(config):
   f = io.StringIO()
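The new test exercises the replicated-batch path end to end: `per_device_batch_size=1` forces the single-row decode branch, and `ici_expert_parallelism=-1` requests the "full EP" the comment mentions. MaxText treats a single -1 ICI axis as inferred, filling it so the mesh covers every device; a sketch of that resolution under this assumption (illustrative only, not MaxText's actual mesh helper):

```python
import jax

def resolve_ici_axes(ici: dict[str, int]) -> dict[str, int]:
  """Fill a single -1 entry so the mesh covers every device (illustrative only)."""
  known = 1
  for size in ici.values():
    if size != -1:
      known *= size
  remaining = jax.device_count() // known
  return {name: (remaining if size == -1 else size) for name, size in ici.items()}

# With fsdp=1 and tensor=1 on an 8-chip slice, 'expert' absorbs all 8 devices.
print(resolve_ici_axes({"fsdp": 1, "tensor": 1, "expert": -1}))
```

Note that the test is gated behind the `tpu_only` and `scheduled_only` pytest markers, and it needs `HF_TOKEN` exported so the DeepSeek-V3.2-Exp tokenizer can be fetched.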