1 change: 1 addition & 0 deletions src/maxtext/configs/base.yml
@@ -488,6 +488,7 @@ logical_axis_rules: [
   ['activation_stage', 'stage'],
   ['activation_exp', ['expert']],
   ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
+  ['decode_batch_moe', ['data', 'fsdp', 'fsdp_transpose']],
   ['decode_length', ['sequence']],
   ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
   ['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
1 change: 1 addition & 0 deletions src/maxtext/configs/inference/inference.yml
@@ -24,6 +24,7 @@ logical_axis_rules: [
   ['activation_stage', 'stage'],
   ['activation_exp', ['expert', 'context_autoregressive']],
   ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context_autoregressive']],
+  ['decode_batch_moe', ['data', 'fsdp', 'fsdp_transpose', 'context_autoregressive']],
   ['decode_length', []],
   ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
   ['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
1 change: 1 addition & 0 deletions src/maxtext/configs/inference/vllm.yml
@@ -54,6 +54,7 @@ logical_axis_rules: [
   ['activation_norm_length_moe', []],
   ['activation_exp', ['expert', 'attn_dp_expert']],
   ['decode_batch', ['expert', 'attn_dp_expert']],
+  ['decode_batch_moe', []],
   ['decode_length', []],
   ['mlp', ['model', 'attn_dp']],
   ['mlp_moe', ['model', 'attn_dp']],
1 change: 1 addition & 0 deletions src/maxtext/configs/post_train/rl_mt_jt.yml
@@ -39,6 +39,7 @@ logical_axis_rules: [
   ['activation_stage', 'stage'],
   ['activation_exp', ['expert', 'context_autoregressive']],
   ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context_autoregressive']],
+  ['decode_batch_moe', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context_autoregressive']],
   ['decode_length', []],
   ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
   ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive','context_autoregressive']],
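Each of the four config diffs above registers a new logical axis, decode_batch_moe, next to the existing decode_batch: base.yml and inference.yml map it to the same mesh axes minus 'expert', vllm.yml leaves it fully replicated, and rl_mt_jt.yml keeps the full decode_batch mapping. The moe.py change below annotates decode-time MoE activations with this axis. As a minimal sketch of how such first-match rules resolve a logical axis to mesh axes — in the spirit of Flax's logical partitioning; the rule subset and the helper name here are assumptions, not MaxText's implementation:

# Illustrative only: first-match resolution of a logical axis name to mesh axes.
from jax.sharding import PartitionSpec

RULES = [
    ("decode_batch", ("data", "fsdp", "fsdp_transpose", "expert")),
    ("decode_batch_moe", ("data", "fsdp", "fsdp_transpose")),  # no 'expert' axis
]

def to_mesh_axes(logical_name):
  """Return a PartitionSpec for the first rule matching `logical_name`."""
  for name, mesh_axes in RULES:
    if name == logical_name:
      return PartitionSpec(mesh_axes)
  return PartitionSpec(None)  # unmatched logical axes stay replicated

assert to_mesh_axes("decode_batch_moe") == PartitionSpec(("data", "fsdp", "fsdp_transpose"))
assert to_mesh_axes("decode_batch") == PartitionSpec(("data", "fsdp", "fsdp_transpose", "expert"))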
153 changes: 96 additions & 57 deletions src/maxtext/layers/moe.py
@@ -1027,7 +1027,13 @@ def gmm(
       output = output[: orig_inputs_shape[0]]
       return output

-    batch_logical_axis = "activation_batch"
+    # The batch is sharded by expert, except during inference decoding (where batch size == 1).
+    # In the decoding case, the expert axis is instead replicated along the tensor's batch dimension.
+    is_batch_sharded_by_expert = inputs.shape[0] > 1
+    if is_batch_sharded_by_expert:
+      batch_logical_axis = "activation_batch"
+    else:
+      batch_logical_axis = "decode_batch_moe"

     if self.get_tensor_transpose_parallelism_size() > 1:
       input_partition_pspec = self._logical_to_mesh_axes(
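The branch added above keys the batch annotation off the runtime shape: during autoregressive decode the leading batch dimension is 1, so sharding it over the 'expert' mesh axis would leave most expert shards empty, and 'decode_batch_moe' avoids that. A toy restatement of the dispatch — the tensor shapes are assumptions for illustration:

# Sketch of the shape-based dispatch; shapes are assumed, not from the PR.
import jax.numpy as jnp

def batch_logical_axis_for(inputs):
  return "activation_batch" if inputs.shape[0] > 1 else "decode_batch_moe"

prefill = jnp.zeros((8, 128, 512))  # (batch, sequence, embed) — training/prefill
decode = jnp.zeros((1, 1, 512))     # batch of 1 — autoregressive decode
assert batch_logical_axis_for(prefill) == "activation_batch"
assert batch_logical_axis_for(decode) == "decode_batch_moe"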
@@ -1142,47 +1148,59 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
       )

       if num_expert_parallelism > 1:
+        batch_axis = self._expert_parallelism_name if is_batch_sharded_by_expert else "data"
         # get group sizes for all shards
         local_expert_size = self.config.num_experts // num_expert_parallelism
         reshaped_group_sizes = jnp.sum(group_sizes.reshape(-1, local_expert_size), axis=1)
+        global_group_sizes = group_sizes

-        all_shards_group_sizes = jax.lax.all_gather(reshaped_group_sizes, axis_name=self._expert_parallelism_name)
-        input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params(
-            all_shards_group_sizes,
-            expert_shard_id,
-            num_expert_parallelism,
-        )
-
-        # TODO(ranran): For better performance, we could update the output buffer to a smaller
-        # size to replace self.get_expert_parallelism_size() for efficiency,
-        # or we could apply capacity_factor for excessive experts.
-        # Note: Reducing the buffer increases the risk of token dropping under unbalanced distribution.
-
-        # In the worst case, all of the global input data is assigned to each expert in the current shard.
-        # This would result in num_expert_shards * input_size * experts_per_shard assignments. However, if
-        # experts_per_shard > num_experts_per_tok, we cannot assign more than num_experts_per_tok to all of the inputs.
-        max_local_experts_per_tok = min(local_expert_size, self.config.num_experts_per_tok)
-        buffer_size = int(num_expert_parallelism * batch_size * sequence_length * max_local_experts_per_tok)
-        output_shape = jax.lax.empty((buffer_size, self.config.emb_dim), dtype=x.dtype)
-
-        x = jax.lax.ragged_all_to_all(
-            x,
-            output_shape,
-            input_offsets,
-            send_sizes,
-            output_offsets,
-            recv_sizes,
-            axis_name=self._expert_parallelism_name,
-        )
-        global_group_sizes = jax.lax.all_gather(group_sizes, axis_name=self._expert_parallelism_name)
-        x, local_sorted_indices, group_sizes, selected_experts = RoutedMoE.local_permute(
-            x,
-            global_group_sizes,
-            local_expert_size,
-            shard_index=expert_shard_id,
-            use_custom_sort_vjp=self.config.use_custom_sort_vjp,
-        )
+        if is_batch_sharded_by_expert:
+          all_shards_group_sizes = jax.lax.all_gather(reshaped_group_sizes, axis_name=batch_axis)
+          input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params(
+              all_shards_group_sizes,
+              expert_shard_id,
+              num_expert_parallelism,
+          )
+
+          # TODO(ranran): For better performance, we could update the output buffer to a smaller
+          # size to replace self.get_expert_parallelism_size() for efficiency,
+          # or we could apply capacity_factor for excessive experts.
+          # Note: Reducing the buffer increases the risk of token dropping under unbalanced distribution.
+
+          # In the worst case, all of the global input data is assigned to each expert in the current shard.
+          # This would result in num_expert_shards * input_size * experts_per_shard assignments. However, if
+          # experts_per_shard > num_experts_per_tok, we cannot assign more than num_experts_per_tok to all of the inputs.
+          max_local_experts_per_tok = min(local_expert_size, self.config.num_experts_per_tok)
+          buffer_size = int(num_expert_parallelism * batch_size * sequence_length * max_local_experts_per_tok)
+          output_shape = jax.lax.empty((buffer_size, self.config.emb_dim), dtype=x.dtype)
+
+          x = jax.lax.ragged_all_to_all(
+              x,
+              output_shape,
+              input_offsets,
+              send_sizes,
+              output_offsets,
+              recv_sizes,
+              axis_name=self._expert_parallelism_name,
+          )
+          global_group_sizes = jax.lax.all_gather(group_sizes, axis_name=self._expert_parallelism_name)
+          x, local_sorted_indices, group_sizes, selected_experts = RoutedMoE.local_permute(
+              x,
+              global_group_sizes,
+              local_expert_size,
+              shard_index=expert_shard_id,
+              use_custom_sort_vjp=self.config.use_custom_sort_vjp,
+          )
+        else:
+          x, local_sorted_indices, group_sizes, selected_experts = RoutedMoE.local_permute(
+              x,
+              global_group_sizes[None, :],
+              local_expert_size,
+              shard_index=expert_shard_id,
+              is_offset=True,
+              global_sorted_experts=selected_experts,
+              use_custom_sort_vjp=self.config.use_custom_sort_vjp,
+          )

       if self.config.mlp_bias:
         w0_bias, w1_bias, wo_bias = self.transform_bias(selected_experts, w0_bias, w1_bias, wo_bias)
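In the batch-sharded branch above, each expert shard all-gathers the per-shard group sizes and derives, for jax.lax.ragged_all_to_all, where its outgoing tokens start (input_offsets), how many go to each peer (send_sizes), and where its chunks land in each receiver's buffer (output_offsets, recv_sizes). A hypothetical stand-in for what get_all_to_all_params plausibly computes in this case — the matrix layout, the helper name, and the receiver-ordered-by-sender output layout are all assumptions, not the helper's actual source:

import jax.numpy as jnp

def all_to_all_params(G, shard_id):
  """Hypothetical analogue of RoutedMoE.get_all_to_all_params (batch-sharded).
  G[s, j] = number of tokens shard s routes to expert shard j. Returns the
  per-peer offsets/sizes for jax.lax.ragged_all_to_all, as seen from shard_id."""
  # Our outgoing chunks are contiguous in the locally expert-sorted buffer.
  input_offsets = jnp.concatenate(
      [jnp.zeros(1, G.dtype), jnp.cumsum(G[shard_id][:-1])])
  send_sizes = G[shard_id]
  # Our chunk lands in each receiver's buffer after lower-numbered senders'.
  output_offsets = jnp.concatenate(
      [jnp.zeros((1, G.shape[1]), G.dtype), jnp.cumsum(G, axis=0)[:-1]])[shard_id]
  recv_sizes = G[:, shard_id]
  return input_offsets, send_sizes, output_offsets, recv_sizes

# Two expert shards: shard 0 keeps 3 tokens and sends 1 over; shard 1 sends
# 2 over and keeps 4.
G = jnp.array([[3, 1], [2, 4]])
print(all_to_all_params(G, shard_id=0))
# -> input_offsets=[0, 3], send_sizes=[3, 1], output_offsets=[0, 0], recv_sizes=[3, 2]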
@@ -1325,26 +1343,47 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
             dtype=intermediate_output.dtype,
         )

-        # locally unpermute back to the original order
-        local_output = _sort_activations(
-            intermediate_output,
-            jnp.argsort(local_sorted_indices),  # pylint: disable=undefined-variable
-            self.config.use_custom_sort_vjp,
-        )
-        input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params(
-            jnp.transpose(all_shards_group_sizes),  # pylint: disable=undefined-variable
-            expert_shard_id,
-            num_expert_parallelism,
-        )
-        intermediate_output = jax.lax.ragged_all_to_all(
-            local_output,
-            output_shape,
-            input_offsets,
-            send_sizes,
-            output_offsets,
-            recv_sizes,
-            axis_name=self._expert_parallelism_name,
-        )
+        if is_batch_sharded_by_expert:
+          # locally unpermute back to the original order
+          local_output = _sort_activations(
+              intermediate_output,
+              jnp.argsort(local_sorted_indices),  # pylint: disable=undefined-variable
+              self.config.use_custom_sort_vjp,
+          )
+          input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params(
+              jnp.transpose(all_shards_group_sizes),  # pylint: disable=undefined-variable
+              expert_shard_id,
+              num_expert_parallelism,
+          )
+          intermediate_output = jax.lax.ragged_all_to_all(
+              local_output,
+              output_shape,
+              input_offsets,
+              send_sizes,
+              output_offsets,
+              recv_sizes,
+              axis_name=self._expert_parallelism_name,
+          )
+        else:
+          # If the batch is replicated across EP shards, then each shard sends
+          # its 0..local_shard_size slice of data to the other shards and
+          # receives the local_shard data from all of the other shards using
+          # ragged_all_to_all.
+          input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params(
+              reshaped_group_sizes,  # pylint: disable=undefined-variable
+              expert_shard_id,
+              num_expert_parallelism,
+              is_batch_sharded=False,
+          )
+          intermediate_output = jax.lax.ragged_all_to_all(
+              intermediate_output,
+              output_shape,
+              input_offsets,
+              send_sizes,
+              output_offsets,
+              recv_sizes,
+              axis_name=self._expert_parallelism_name,
+          )

       output = self.unpermute(
           intermediate_output,
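The new else branch covers the replicated-batch layout: every shard already sees all tokens, computes outputs only for its local experts, and must then receive every other shard's chunk to rebuild the full expert-ordered buffer. A hypothetical sketch of the parameters that exchange would need — the helper name and exact layout are assumptions inferred from the comment in the diff, not from get_all_to_all_params itself:

import jax.numpy as jnp

def all_to_all_params_replicated(t, shard_id):
  """Hypothetical analogue of get_all_to_all_params(..., is_batch_sharded=False).
  t[j] = number of tokens routed to expert shard j; identical on every shard
  because the batch is replicated. Each shard holds outputs only for its own
  experts and sends that whole chunk to every peer."""
  n = t.shape[0]
  input_offsets = jnp.zeros(n, t.dtype)           # the local chunk starts at 0
  send_sizes = jnp.full(n, t[shard_id], t.dtype)  # send everything to each peer
  # In every receiver's buffer, our chunk follows the lower-numbered shards'.
  output_offsets = jnp.full(n, jnp.sum(t[:shard_id]), t.dtype)
  recv_sizes = t                                  # peer j contributes t[j] tokens
  return input_offsets, send_sizes, output_offsets, recv_sizes

# Example: t = [3, 2]. Shard 1 sends its 2 tokens to offset 3 of every peer's
# output, so all shards reconstruct the same 5-token, expert-ordered buffer.
print(all_to_all_params_replicated(jnp.array([3, 2]), shard_id=1))
# -> input_offsets=[0, 0], send_sizes=[2, 2], output_offsets=[3, 3], recv_sizes=[3, 2]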
33 changes: 33 additions & 0 deletions tests/integration/decode_tests.py
@@ -96,6 +96,34 @@ class DecodeTests(unittest.TestCase):
           "prompt=I love to",
           "skip_jax_distributed_system=True",
       ],
+      "deepseek32": [  # tests decode for deepseek3.2-671b full EP
+          None,
+          get_test_config_path(),
+          "base_output_directory=gs://runner-maxtext-logs",
+          "run_name=decode",
+          "model_name=deepseek3.2-671b",
+          "override_model_config=True",
+          "base_num_decoder_layers=2",
+          "first_num_dense_layers=1",
+          "num_experts=16",
+          "base_mlp_dim=128",
+          "base_emb_dim=128",
+          "base_moe_mlp_dim=128",
+          "tokenizer_type=huggingface",
+          f"hf_access_token={os.environ.get('HF_TOKEN', '')}",
+          "tokenizer_path=deepseek-ai/DeepSeek-V3.2-Exp",
+          "scan_layers=False",
+          "attention=dot_product",
+          "weight_dtype=bfloat16",
+          "per_device_batch_size=1",
+          "max_prefill_predict_length=8",
+          "max_target_length=16",
+          "ici_fsdp_parallelism=1",
+          "ici_tensor_parallelism=1",
+          "ici_expert_parallelism=-1",
+          "mla_naive_kvcache=false",
+          "prompt=I love to",
+      ],
   }
   SAMPLING_STRATEGY_CONFIG = {
       "greedy": [
@@ -173,6 +201,11 @@ def test_decode_topk_sampling(self):
     expected_output = "Input `I love to` -> ` travel and I love to write"
     assert expected_output in captured_out

+  @pytest.mark.tpu_only
+  @pytest.mark.scheduled_only
+  def test_tpu_deepseek32(self):
+    decode_main(DecodeTests.CONFIGS["deepseek32"])
+

 def run_decoding(config):
   f = io.StringIO()
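The new test pins a tiny two-layer DeepSeek-3.2 variant with 16 experts and puts all remaining devices on the expert axis via ici_expert_parallelism=-1; with per_device_batch_size=1 this drives decode through the replicated-batch MoE path added above. A hedged illustration of the -1 convention — assumed behavior where one ICI axis is inferred so the axis sizes tile the device count; the helper below is illustrative, not MaxText's parser:

# Illustrative only: infer the -1 axis so the parallelism axes tile the devices.
def infer_parallelism(num_devices, fsdp=1, tensor=1, expert=-1):
  known = fsdp * tensor
  inferred = num_devices // known if expert == -1 else expert
  assert known * inferred == num_devices, "axes must tile the device count"
  return {"fsdp": fsdp, "tensor": tensor, "expert": inferred}

print(infer_parallelism(num_devices=4))  # -> expert=4: full expert parallelism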