Support specifying tokamax gmm tile sizes in MaxText#3779
Conversation
da5a5c3 to
7d49b8d
Compare
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
gobbleturk
left a comment
There was a problem hiding this comment.
Lets try to find a more robust way to differentiate wi and wo - my idea is to monkey patch twice
gmm_up = monkey_patch_up(PallasMosaicTpuRaggedDot)
gmm_down = monkey_patch_down(PallasMosaicTpuRaggedDot) - see also comment
…_sizes and use Pallas subclasses This change addresses review comments on PR #3779: 1. Replaces the brittle global monkey-patch with clean Pallas subclasses (PallasMosaicTpuRaggedDotWI and PallasMosaicTpuRaggedDotWO) in moe.py and deepseek_batchsplit_fp8.py. 2. Implements custom __post_init__ in these subclasses to ensure JAX backward passes (VJP) use the correct subclass instead of reverting to the base class. 3. Renames the configuration flag `tokamax_gmm_autotune` to `tokamax_gmm_custom_tile_sizes` (defaulting to true) across configs, types, and layers. 4. Updates unit tests in moe_test.py to verify the new subclass-based tiling overrides. TAG=agy CONV=2c6843af-dcf7-403b-b67e-2fedd5f81b95
…stom tile sizes with Pallas subclasses This change addresses review comments on PR #3779: 1. Replaces the brittle global monkey-patch with clean Pallas subclasses (PallasMosaicTpuRaggedDotWI and PallasMosaicTpuRaggedDotWO) in moe.py and deepseek_batchsplit_fp8.py. 2. Implements custom __post_init__ in these subclasses to ensure JAX backward passes (VJP) use the correct subclass instead of reverting to the base class. 3. Removes the configuration flag `tokamax_gmm_autotune` entirely (unconditionally enabling custom tile sizes when Tokamax GMM is active). 4. Updates unit tests in moe_test.py to verify the new subclass-based tiling overrides. TAG=agy CONV=2c6843af-dcf7-403b-b67e-2fedd5f81b95
…stom tile sizes with Pallas subclasses This change addresses review comments on PR #3779: 1. Replaces the brittle global monkey-patch with clean Pallas subclasses (PallasMosaicTpuRaggedDotWI and PallasMosaicTpuRaggedDotWO) in moe.py and deepseek_batchsplit_fp8.py. 2. Implements custom __post_init__ in these subclasses to ensure JAX backward passes (VJP) use the correct subclass instead of reverting to the base class. 3. Removes the configuration flag `tokamax_gmm_autotune` entirely (unconditionally enabling custom tile sizes when Tokamax GMM is active). 4. Updates unit tests in moe_test.py to verify the new subclass-based tiling overrides. CONV=2c6843af-dcf7-403b-b67e-2fedd5f81b95
| def compute(x, w0, w1, wo, group_sizes, weights, *, config, mesh): | ||
| """Processes routed tokens through the MLP.""" | ||
|
|
||
| class PallasMosaicTpuRaggedDotWI(PallasMosaicTpuRaggedDot): |
There was a problem hiding this comment.
can we not re-use the classes defined in moe.py? Is quantization different here?
cc @shuningjin
…stom GMM tile sizes with Pallas subclass This change addresses review comments on PR #3779: 1. Replaces the brittle global monkey-patch and separate WI/WO subclasses with a single, reusable `PallasMosaicTpuRaggedDotCustom` class defined in `moe.py`. 2. Implements custom __post_init__ in this subclass to ensure JAX backward passes (VJP) preserve tile configurations correctly. 3. Removes the configuration flag `tokamax_gmm_autotune` entirely (unconditionally enabling custom tile sizes when Tokamax GMM is active). 4. Updates unit tests in `moe_test.py` to verify the new subclass-based tiling overrides. CONV=2c6843af-dcf7-403b-b67e-2fedd5f81b95
gobbleturk
left a comment
There was a problem hiding this comment.
Looks good, lets have the tokamax team take a look as well
gobbleturk
left a comment
There was a problem hiding this comment.
Note we currently don't support passing in input_buffer_count - perhaps this can be a separate add if we find significant lift. I think this defaults to 2 (with or without this change?)
…stom GMM tile sizes with Pallas subclass This change addresses review comments on PR #3779: 1. Replaces the brittle global monkey-patch and separate WI/WO subclasses with a single, reusable `PallasMosaicTpuRaggedDotCustom` class defined in `moe.py`. 2. Implements custom __post_init__ in this subclass to ensure JAX backward passes (VJP) preserve tile configurations correctly. 3. Removes the configuration flag `tokamax_gmm_autotune` entirely (unconditionally enabling custom tile sizes when Tokamax GMM is active). 4. Updates unit tests in `moe_test.py` to verify the new subclass-based tiling overrides. CONV=2c6843af-dcf7-403b-b67e-2fedd5f81b95
…stom GMM tile sizes with Pallas subclass This change addresses review comments on PR #3779: 1. Replaces the brittle global monkey-patch and separate WI/WO subclasses with a single, reusable `PallasMosaicTpuRaggedDotCustom` class defined in `moe.py`. 2. Implements custom __post_init__ in this subclass to ensure JAX backward passes (VJP) preserve tile configurations correctly. 3. Removes the configuration flag `tokamax_gmm_autotune` entirely (unconditionally enabling custom tile sizes when Tokamax GMM is active). 4. Updates unit tests in `moe_test.py` to verify the new subclass-based tiling overrides. CONV=2c6843af-dcf7-403b-b67e-2fedd5f81b95
…stom GMM tile sizes with Pallas subclass This change addresses review comments on PR #3779: 1. Replaces the brittle global monkey-patch and separate WI/WO subclasses with a single, reusable `PallasMosaicTpuRaggedDotCustom` class defined in `moe.py`. 2. Implements custom __post_init__ in this subclass to ensure JAX backward passes (VJP) preserve tile configurations correctly. 3. Removes the configuration flag `tokamax_gmm_autotune` entirely (unconditionally enabling custom tile sizes when Tokamax GMM is active). 4. Updates unit tests in `moe_test.py` to verify the new subclass-based tiling overrides. CONV=2c6843af-dcf7-403b-b67e-2fedd5f81b95
…stom GMM tile sizes with Pallas subclass This change addresses review comments on PR #3779: 1. Replaces the brittle global monkey-patch and separate WI/WO subclasses with a single, reusable `PallasMosaicTpuRaggedDotCustom` class defined in `moe.py`. 2. Implements custom __post_init__ in this subclass to ensure JAX backward passes (VJP) preserve tile configurations correctly. 3. Removes the configuration flag `tokamax_gmm_autotune` entirely (unconditionally enabling custom tile sizes when Tokamax GMM is active). 4. Updates unit tests in `moe_test.py` to verify the new subclass-based tiling overrides. CONV=2c6843af-dcf7-403b-b67e-2fedd5f81b95
368e244 to
504dcb0
Compare
…stom GMM tile sizes with Pallas subclass This change addresses review comments on PR #3779: 1. Replaces the brittle global monkey-patch and separate WI/WO subclasses with a single, reusable `PallasMosaicTpuRaggedDotCustom` class defined in `moe.py`. 2. Implements custom __post_init__ in this subclass to ensure JAX backward passes (VJP) preserve tile configurations correctly. 3. Removes the configuration flag `tokamax_gmm_autotune` entirely (unconditionally enabling custom tile sizes when Tokamax GMM is active). 4. Updates unit tests in `moe_test.py` to verify the new subclass-based tiling overrides. CONV=2c6843af-dcf7-403b-b67e-2fedd5f81b95
…stom GMM tile sizes with Pallas subclass This change addresses review comments on PR #3779: 1. Replaces the brittle global monkey-patch and separate WI/WO subclasses with a single, reusable `PallasMosaicTpuRaggedDotCustom` class defined in `moe.py`. 2. Implements custom __post_init__ in this subclass to ensure JAX backward passes (VJP) preserve tile configurations correctly. 3. Removes the configuration flag `tokamax_gmm_autotune` entirely (unconditionally enabling custom tile sizes when Tokamax GMM is active). 4. Updates unit tests in `moe_test.py` to verify the new subclass-based tiling overrides. CONV=2c6843af-dcf7-403b-b67e-2fedd5f81b95
Description
This PR enables cleanly overriding the GMM tile sizes for both the forward and backward passes of Tokamax (
ragged_dot) in MaxText using an elegant global heuristics monkey-patching approach.Proposed Changes
src/maxtext/layers/moe.py): Instead of introducing complex customvjpfunction wrappers or polluting model layers with low-level compilation code, this PR dynamically monkey-patchesPallasMosaicTpuRaggedDot._get_heuristics_configglobally during RoutedMoE initialization.lhs/rhs) and the JAX dimension numbers. This allows it to detect the active pass (Forward, Backward DLHS, or Backward DRHS) and layer type (Weight Inputwior Weight Outputwo) on the fly, applying your exact manual tile configurations:wi_tile_fwd,wi_tile_dlhs,wi_tile_drhsconfigs.wo_tile_fwd,wo_tile_dlhs,wo_tile_drhsconfigs.FIXES: b/506157856
Tests
Verified compilation and trace collection via a 5-step synthetic pretraining run on a
v6e-4VM using thedeepseek3-671bMoE model configuration with your manual tiling overrides:JAX_PLATFORMS=tpu,cpu PYTHONPATH=src python3 -m maxtext.trainers.pre_train.train \ run_name=darisoy-moe-train \ base_output_directory=gs://maxtext-experiments-multipod \ model_name=deepseek3-671b \ dataset_type=synthetic \ steps=5 \ enable_checkpointing=False \ enable_goodput_recording=False \ attention_type=mla \ sparse_matmul=True \ megablox=True \ per_device_batch_size=2 \ base_num_decoder_layers=1 \ first_num_dense_layers=0 \ shared_experts=1 \ base_mlp_dim=2048 \ profiler=xplane \ use_tokamax_gmm=True \ override_model_config=True \ max_target_length=4096 \ ici_expert_parallelism=4 \ use_ring_of_experts=True \ tokamax_gmm_autotune=True \ weight_dtype=bfloat16 \ opt_type=sgd \ wi_tile_fwd_batch_seq=128 \ wi_tile_fwd_embed_dim=128 \ wi_tile_fwd_mlp_dim=128 \ wi_tile_dlhs_batch_seq=256 \ wi_tile_dlhs_embed_dim=256 \ wi_tile_dlhs_mlp_dim=256 \ wi_tile_drhs_batch_seq=512 \ wi_tile_drhs_embed_dim=512 \ wi_tile_drhs_mlp_dim=512 \ wo_tile_fwd_batch_seq: 128 \ wo_tile_fwd_embed_dim: 128 \ wo_tile_fwd_mlp_dim: 128 \ wo_tile_dlhs_batch_seq: 256 \ wo_tile_dlhs_embed_dim: 256 \ wo_tile_dlhs_mlp_dim: 256 \ wo_tile_drhs_batch_seq: 512 \ wo_tile_drhs_embed_dim: 512 \ wo_tile_drhs_mlp_dim: 512See details in http://b/506157856
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.