Support specifying tokamax gmm tile sizes in MaxText#3779

Open
darisoy wants to merge 1 commit into main from darisoy-gmm-tile

@darisoy darisoy commented Apr 29, 2026

Description

This PR lets MaxText override the GMM tile sizes used by Tokamax (ragged_dot) in both the forward and backward passes, by monkey-patching the kernel's heuristics globally.

Proposed Changes

  • Heuristics Monkey-Patching (src/maxtext/layers/moe.py): Instead of introducing complex custom vjp function wrappers or polluting model layers with low-level compilation code, this PR dynamically monkey-patches PallasMosaicTpuRaggedDot._get_heuristics_config globally during RoutedMoE initialization.
  • Dynamic Shapes & Pass Routing: The monkey-patch inspects the shapes of the operands (lhs/rhs) and the JAX dimension numbers at call time. This lets it detect the active pass (Forward, Backward DLHS, or Backward DRHS) and the layer type (Weight Input wi or Weight Output wo), and apply the manually specified tile configuration for that combination:
    • wi_tile_fwd, wi_tile_dlhs, wi_tile_drhs configs.
    • wo_tile_fwd, wo_tile_dlhs, wo_tile_drhs configs.
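
The routing described above can be sketched roughly as follows. This is only an illustration under stated assumptions: `FakeRaggedDot`, `Tiles`, `detect_pass`, `detect_layer`, and `install_tile_overrides` are hypothetical names, the method signature is simplified and is not the Tokamax API, and the operand layouts in the comments are assumed. The shapes and dimensions are illustrative values.

```python
from typing import NamedTuple, Optional


class Tiles(NamedTuple):
    batch_seq: int
    embed_dim: int
    mlp_dim: int


class FakeRaggedDot:
    """Hypothetical stand-in for the kernel class; this is not the Tokamax API."""

    def _get_heuristics_config(self, lhs_shape, rhs_shape, rhs_contracting_dim):
        return Tiles(64, 64, 64)  # pretend built-in heuristic


def detect_pass(rhs_shape, rhs_contracting_dim) -> str:
    # Assumed layouts: fwd is [m,k] x [g,k,n]; dlhs is [m,n] x [g,k,n]
    # contracting n; drhs is [m,k] x [m,n], so its rhs has no group axis.
    if len(rhs_shape) == 2:
        return "drhs"
    return "fwd" if rhs_contracting_dim == 1 else "dlhs"


def detect_layer(lhs_shape, rhs_shape, embed_dim, mlp_dim) -> Optional[str]:
    # Recover the weight's (in, out) dims from whichever operands carry them.
    if len(rhs_shape) == 3:
        in_dim, out_dim = rhs_shape[1], rhs_shape[2]
    else:
        in_dim, out_dim = lhs_shape[1], rhs_shape[1]
    if embed_dim == mlp_dim:
        return None  # shapes alone cannot tell wi from wo
    if (in_dim, out_dim) == (embed_dim, mlp_dim):
        return "wi"
    if (in_dim, out_dim) == (mlp_dim, embed_dim):
        return "wo"
    return None


def install_tile_overrides(cls, overrides, embed_dim, mlp_dim):
    original = cls._get_heuristics_config

    def patched(self, lhs_shape, rhs_shape, rhs_contracting_dim):
        key = (detect_layer(lhs_shape, rhs_shape, embed_dim, mlp_dim),
               detect_pass(rhs_shape, rhs_contracting_dim))
        if key in overrides:
            return overrides[key]
        # No manual entry for this (layer, pass): keep the original heuristic.
        return original(self, lhs_shape, rhs_shape, rhs_contracting_dim)

    cls._get_heuristics_config = patched  # global: affects every instance


overrides = {
    ("wi", "fwd"): Tiles(128, 128, 128),
    ("wi", "dlhs"): Tiles(256, 256, 256),
    ("wo", "drhs"): Tiles(512, 512, 512),
}
install_tile_overrides(FakeRaggedDot, overrides, embed_dim=7168, mlp_dim=2048)

kernel = FakeRaggedDot()
wi_fwd = kernel._get_heuristics_config((4096, 7168), (8, 7168, 2048), 1)
wi_dlhs = kernel._get_heuristics_config((4096, 2048), (8, 7168, 2048), 2)
wo_drhs = kernel._get_heuristics_config((4096, 2048), (4096, 7168), 0)
wo_fwd = kernel._get_heuristics_config((4096, 2048), (8, 2048, 7168), 1)
```

In this sketch, `detect_layer` returns `None` whenever the embed and MLP dimensions are equal, which shows one way shape-based detection of wi versus wo can become ambiguous.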

FIXES: b/506157856

Tests

Verified compilation and trace collection via a 5-step synthetic pretraining run on a v6e-4 VM using the deepseek3-671b MoE model configuration with the manual tiling overrides:

JAX_PLATFORMS=tpu,cpu PYTHONPATH=src python3 -m maxtext.trainers.pre_train.train \
    run_name=darisoy-moe-train \
    base_output_directory=gs://maxtext-experiments-multipod \
    model_name=deepseek3-671b \
    dataset_type=synthetic \
    steps=5 \
    enable_checkpointing=False \
    enable_goodput_recording=False \
    attention_type=mla \
    sparse_matmul=True \
    megablox=True \
    per_device_batch_size=2 \
    base_num_decoder_layers=1 \
    first_num_dense_layers=0 \
    shared_experts=1 \
    base_mlp_dim=2048 \
    profiler=xplane \
    use_tokamax_gmm=True \
    override_model_config=True \
    max_target_length=4096 \
    ici_expert_parallelism=4 \
    use_ring_of_experts=True \
    tokamax_gmm_autotune=True \
    weight_dtype=bfloat16 \
    opt_type=sgd \
    wi_tile_fwd_batch_seq=128 \
    wi_tile_fwd_embed_dim=128 \
    wi_tile_fwd_mlp_dim=128 \
    wi_tile_dlhs_batch_seq=256 \
    wi_tile_dlhs_embed_dim=256 \
    wi_tile_dlhs_mlp_dim=256 \
    wi_tile_drhs_batch_seq=512 \
    wi_tile_drhs_embed_dim=512 \
    wi_tile_drhs_mlp_dim=512 \
    wo_tile_fwd_batch_seq=128 \
    wo_tile_fwd_embed_dim=128 \
    wo_tile_fwd_mlp_dim=128 \
    wo_tile_dlhs_batch_seq=256 \
    wo_tile_dlhs_embed_dim=256 \
    wo_tile_dlhs_mlp_dim=256 \
    wo_tile_drhs_batch_seq=512 \
    wo_tile_drhs_embed_dim=512 \
    wo_tile_drhs_mlp_dim=512
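
The eighteen tile flags follow a `<layer>_tile_<pass>_<axis>` naming pattern. As a rough illustration of how such flat keys could be grouped into one tile triple per (layer, pass), here is a minimal sketch. `parse_overrides` and `collect_tiles` are hypothetical helpers, not MaxText functions, and the sketch assumes overrides arrive as `key=value` strings.

```python
LAYERS = ("wi", "wo")
PASSES = ("fwd", "dlhs", "drhs")
AXES = ("batch_seq", "embed_dim", "mlp_dim")


def parse_overrides(argv):
    """Split 'key=value' strings into a dict; other arguments are ignored."""
    return dict(arg.split("=", 1) for arg in argv if "=" in arg)


def collect_tiles(flat):
    """Group flat tile keys into {(layer, pass): (batch_seq, embed_dim, mlp_dim)}."""
    table = {}
    for layer in LAYERS:
        for pass_name in PASSES:
            values = [flat.get(f"{layer}_tile_{pass_name}_{axis}") for axis in AXES]
            if all(v is None for v in values):
                continue  # nothing specified for this (layer, pass)
            if any(v is None for v in values):
                raise ValueError(f"incomplete tile spec for {layer}/{pass_name}")
            table[(layer, pass_name)] = tuple(int(v) for v in values)
    return table


args = [
    "steps=5",
    "wi_tile_fwd_batch_seq=128",
    "wi_tile_fwd_embed_dim=128",
    "wi_tile_fwd_mlp_dim=128",
    "wo_tile_drhs_batch_seq=512",
    "wo_tile_drhs_embed_dim=512",
    "wo_tile_drhs_mlp_dim=512",
]
tile_table = collect_tiles(parse_overrides(args))
```

Rejecting a partially specified triple, as this sketch does, is one possible design choice; falling back to a default for the missing axis would be another.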

See details in http://b/506157856

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

codecov Bot commented May 5, 2026

Codecov Report

❌ Patch coverage is 82.35294% with 6 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/maxtext/layers/moe.py | 87.50% | 2 Missing and 2 partials ⚠️ |
| src/maxtext/models/deepseek_batchsplit_fp8.py | 0.00% | 2 Missing ⚠️ |

Comment thread src/maxtext/layers/moe.py Outdated

@gobbleturk gobbleturk left a comment

Let's try to find a more robust way to differentiate wi and wo. My idea is to monkey-patch twice:

gmm_up = monkey_patch_up(PallasMosaicTpuRaggedDot)
gmm_down = monkey_patch_down(PallasMosaicTpuRaggedDot) - see also comment

@darisoy darisoy force-pushed the darisoy-gmm-tile branch from f0615e2 to 504dcb0 Compare June 1, 2026 21:00
darisoy pushed a commit that referenced this pull request Jun 2, 2026
…_sizes and use Pallas subclasses

This change addresses review comments on PR #3779:
1. Replaces the brittle global monkey-patch with clean Pallas subclasses (PallasMosaicTpuRaggedDotWI and PallasMosaicTpuRaggedDotWO) in moe.py and deepseek_batchsplit_fp8.py.
2. Implements custom __post_init__ in these subclasses to ensure JAX backward passes (VJP) use the correct subclass instead of reverting to the base class.
3. Renames the configuration flag `tokamax_gmm_autotune` to `tokamax_gmm_custom_tile_sizes` (defaulting to true) across configs, types, and layers.
4. Updates unit tests in moe_test.py to verify the new subclass-based tiling overrides.

TAG=agy
CONV=2c6843af-dcf7-403b-b67e-2fedd5f81b95
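
Item 2 above concerns a subclass losing its overrides in the backward pass. A minimal sketch of that pattern follows, assuming a frozen-dataclass op whose `__post_init__` picks the class used for the backward pass. `BaseRaggedDot`, `NaiveWI`, `RaggedDotWI`, and the `backward_cls` field are hypothetical stand-ins, not the Tokamax classes or their fields.

```python
import dataclasses
from typing import Optional

DEFAULT_TILES = (64, 64, 64)
WI_TILES = {"fwd": (128, 128, 128), "dlhs": (256, 256, 256), "drhs": (512, 512, 512)}


@dataclasses.dataclass(frozen=True)
class BaseRaggedDot:
    """Hypothetical stand-in for the base op; not the Tokamax class."""

    backward_cls: Optional[type] = None

    def __post_init__(self):
        if self.backward_cls is None:
            # The base class hard-codes itself for the backward pass.
            object.__setattr__(self, "backward_cls", BaseRaggedDot)

    def tiles(self, pass_name):
        return DEFAULT_TILES

    def backward_tiles(self, pass_name):
        # The backward pass builds a fresh op from backward_cls.
        return self.backward_cls().tiles(pass_name)


@dataclasses.dataclass(frozen=True)
class NaiveWI(BaseRaggedDot):
    """Overrides tiles only, so the backward pass reverts to the base class."""

    def tiles(self, pass_name):
        return WI_TILES.get(pass_name, DEFAULT_TILES)


@dataclasses.dataclass(frozen=True)
class RaggedDotWI(BaseRaggedDot):
    """Also overrides __post_init__ so the backward pass stays in the subclass."""

    def __post_init__(self):
        if self.backward_cls is None:
            object.__setattr__(self, "backward_cls", type(self))
        super().__post_init__()

    def tiles(self, pass_name):
        return WI_TILES.get(pass_name, DEFAULT_TILES)


naive_dlhs = NaiveWI().backward_tiles("dlhs")      # falls back to DEFAULT_TILES
fixed_dlhs = RaggedDotWI().backward_tiles("dlhs")  # keeps the wi override
```

`object.__setattr__` is needed here because the dataclass is frozen; whether the real op resolves its backward pass this way is an assumption of the sketch.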
darisoy pushed a commit that referenced this pull request Jun 2, 2026
…stom tile sizes with Pallas subclasses

This change addresses review comments on PR #3779:
1. Replaces the brittle global monkey-patch with clean Pallas subclasses (PallasMosaicTpuRaggedDotWI and PallasMosaicTpuRaggedDotWO) in moe.py and deepseek_batchsplit_fp8.py.
2. Implements custom __post_init__ in these subclasses to ensure JAX backward passes (VJP) use the correct subclass instead of reverting to the base class.
3. Removes the configuration flag `tokamax_gmm_autotune` entirely (unconditionally enabling custom tile sizes when Tokamax GMM is active).
4. Updates unit tests in moe_test.py to verify the new subclass-based tiling overrides.

TAG=agy
CONV=2c6843af-dcf7-403b-b67e-2fedd5f81b95
@darisoy darisoy force-pushed the darisoy-gmm-tile branch from e0adc78 to c390387 Compare June 2, 2026 17:39
@darisoy darisoy force-pushed the darisoy-gmm-tile branch from c390387 to cfeea9c Compare June 2, 2026 17:46
Comment thread src/maxtext/layers/moe.py
def compute(x, w0, w1, wo, group_sizes, weights, *, config, mesh):
"""Processes routed tokens through the MLP."""

class PallasMosaicTpuRaggedDotWI(PallasMosaicTpuRaggedDot):

Can we not reuse the classes defined in moe.py? Is quantization different here?

cc @shuningjin

darisoy added a commit that referenced this pull request Jun 2, 2026
…stom GMM tile sizes with Pallas subclass

This change addresses review comments on PR #3779:
1. Replaces the brittle global monkey-patch and separate WI/WO subclasses with a single, reusable `PallasMosaicTpuRaggedDotCustom` class defined in `moe.py`.
2. Implements custom __post_init__ in this subclass to ensure JAX backward passes (VJP) preserve tile configurations correctly.
3. Removes the configuration flag `tokamax_gmm_autotune` entirely (unconditionally enabling custom tile sizes when Tokamax GMM is active).
4. Updates unit tests in `moe_test.py` to verify the new subclass-based tiling overrides.

CONV=2c6843af-dcf7-403b-b67e-2fedd5f81b95
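
One way to picture the single reusable class from item 1 is a frozen dataclass that carries its tile triples as fields, so one class can serve both wi and wo and a copy made for the backward pass keeps the same values. `CustomTileRaggedDot`, its fields, and `backward_op` are hypothetical names for this sketch, and the tile values are illustrative; the actual `PallasMosaicTpuRaggedDotCustom` may be structured differently.

```python
import dataclasses
from typing import Tuple

TileTriple = Tuple[int, int, int]
DEFAULT_TILES: TileTriple = (64, 64, 64)


@dataclasses.dataclass(frozen=True)
class CustomTileRaggedDot:
    """Hypothetical single op class parameterized by per-pass tile triples."""

    fwd: TileTriple = DEFAULT_TILES
    dlhs: TileTriple = DEFAULT_TILES
    drhs: TileTriple = DEFAULT_TILES

    def tiles_for(self, pass_name: str) -> TileTriple:
        if pass_name not in ("fwd", "dlhs", "drhs"):
            raise ValueError(f"unknown pass: {pass_name}")
        return getattr(self, pass_name)

    def backward_op(self) -> "CustomTileRaggedDot":
        # Copy the op field-for-field so the backward pass sees the same tiles.
        return dataclasses.replace(self)


# Two instances of the same class, one per layer type.
wi_op = CustomTileRaggedDot(fwd=(128, 128, 128), dlhs=(256, 256, 256), drhs=(512, 512, 512))
wo_op = CustomTileRaggedDot(fwd=(256, 128, 128))

wi_backward_dlhs = wi_op.backward_op().tiles_for("dlhs")
wo_backward_drhs = wo_op.backward_op().tiles_for("drhs")
```

Keeping the tiles in hashable tuple fields means instances compare and hash by value, which tends to matter when ops are used as cache keys or static arguments.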
@darisoy darisoy force-pushed the darisoy-gmm-tile branch from cfeea9c to 2a36c57 Compare June 2, 2026 18:27
@gobbleturk gobbleturk left a comment

Looks good. Let's have the Tokamax team take a look as well.

@gobbleturk gobbleturk left a comment

Note that we currently don't support passing in input_buffer_count. Perhaps this can be a separate addition if we find it gives a significant lift. I think it defaults to 2 (with or without this change?).

Comment thread src/maxtext/layers/moe.py Outdated
@darisoy darisoy force-pushed the darisoy-gmm-tile branch from 2a36c57 to d6d8229 Compare June 2, 2026 20:43
@darisoy darisoy force-pushed the darisoy-gmm-tile branch from d6d8229 to a1af8b5 Compare June 2, 2026 21:22
@darisoy darisoy force-pushed the darisoy-gmm-tile branch from a1af8b5 to d08174a Compare June 2, 2026 21:35
@darisoy darisoy force-pushed the darisoy-gmm-tile branch 3 times, most recently from 368e244 to 504dcb0 Compare June 2, 2026 22:42
@darisoy darisoy force-pushed the darisoy-gmm-tile branch from 8aca22d to e1189e9 Compare June 2, 2026 23:08