
Integrate mHC with DeepSeek custom model #3115

Merged

copybara-service[bot] merged 1 commit into main from mhc_integration on Feb 14, 2026
Conversation

@RissyRan (Collaborator) commented Feb 9, 2026

Description

  • Update the default mHC expansion rate to 1, which is equivalent to disabling the feature.
  • Update the shape of activations in decoders.py when the feature is enabled.
  • Enable loss tracking in MoE layers when using mHC.
  • Cast the weights to the activation precision before the matmul, which aligns with the existing pattern in MaxText (a minimal sketch follows this list).
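
A minimal sketch of the weight-precision bullet above, assuming float32 master weights and bfloat16 activations; the function name and shapes are illustrative, not the actual MaxText layer.

```python
import jax
import jax.numpy as jnp


def dense_in_activation_dtype(activations: jax.Array, kernel: jax.Array) -> jax.Array:
  """Matmul with the kernel cast to the activation dtype first.

  Weights are typically stored in float32 for optimizer stability while
  activations flow in bfloat16; casting the kernel before the contraction
  keeps the matmul in the activation precision, the pattern referenced in
  the description above.
  """
  kernel = kernel.astype(activations.dtype)  # float32 -> bfloat16
  return jnp.dot(activations, kernel)


x = jnp.ones((2, 8, 16), dtype=jnp.bfloat16)   # [batch, seq, emb] activations
w = jnp.ones((16, 32), dtype=jnp.float32)      # float32 master weights
y = dense_in_activation_dtype(x, w)
assert y.dtype == jnp.bfloat16
```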

General pre-norm residual path when the mHC feature is disabled:

Input (x) ───────────────────────────┐
  │                                  │
  ▼                                  │ (Residual Connection)
[ Pre-Norm ]                         │
  │                                  │
  ▼                                  │
[ Attention / MLP ]                  │
  │                                  │
  ▼                                  │
[ Layer Output ]                     │
  │                                  │
  └───────────► ( + ) ◄──────────────┘
                 │
                 ▼
           Final Output
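
For reference, a minimal JAX sketch of the plain pre-norm residual path shown above; the RMSNorm and the identity stand-in for Attention/MLP are simplified, not the MaxText implementations.

```python
import jax
import jax.numpy as jnp


def rms_norm(x: jax.Array, eps: float = 1e-6) -> jax.Array:
  # Simplified RMSNorm without a learned scale.
  return x * jax.lax.rsqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)


def pre_norm_block(x: jax.Array, layer_fn) -> jax.Array:
  # Pre-Norm -> Attention/MLP -> add the untouched input back (residual connection).
  return x + layer_fn(rms_norm(x))


x = jnp.ones((2, 8, 16))            # [batch, seq, emb]
y = pre_norm_block(x, lambda h: h)  # identity stand-in for Attention/MLP
assert y.shape == x.shape
```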

When the mHC feature is enabled:


Input (x) ───────────────────────────────────┐
    │                                        │
    │  pre mapping                           │
    │  pre-norm                              │
    ▼                                 residual mapping
[ Attention / MLP ]                          │
    │                                        │
    ▼                                        │
[ Layer Output ]                             │
    │  post mapping                          │
    ▼                                        │
layer_output                             res_output
    └──────────────► ( + ) ◄─────────────────┘
                       │
                       ▼
                 Final Output
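
A minimal sketch of the mHC-style flow in the diagram, assuming k parallel hidden streams (the expansion rate); the mapping names, shapes, and einsums are illustrative guesses, not the PR's actual decoders.py changes. With k = 1 and identity mappings the block collapses to the plain residual above, which is why an expansion rate of 1 effectively disables the feature.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, eps=1e-6):
  return x * jax.lax.rsqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)


def mhc_block(streams: jax.Array, layer_fn,
              pre_map: jax.Array, res_map: jax.Array, post_map: jax.Array) -> jax.Array:
  """Illustrative mHC-style residual over k hidden streams.

  streams:  [batch, seq, k, d]  k parallel hidden streams
  pre_map:  [k]                 combines the streams into one layer input
  res_map:  [k, k]              mixes the streams on the residual branch
  post_map: [k]                 spreads the layer output back over k streams
  """
  layer_in = rms_norm(jnp.einsum("bskd,k->bsd", streams, pre_map))  # pre mapping + pre-norm
  layer_out = layer_fn(layer_in)                                    # Attention / MLP
  layer_out = jnp.einsum("bsd,k->bskd", layer_out, post_map)        # post mapping
  res_out = jnp.einsum("bskd,kj->bsjd", streams, res_map)           # residual mapping
  return layer_out + res_out                                        # the ( + ) node


k, d = 4, 16
streams = jnp.ones((2, 8, k, d))
out = mhc_block(streams, lambda h: h,
                pre_map=jnp.ones((k,)) / k,   # average the streams
                res_map=jnp.eye(k),           # identity residual mixing
                post_map=jnp.ones((k,)))      # broadcast layer output to all streams
assert out.shape == (2, 8, k, d)
```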

Tests

  • Update unit tests
  • End-to-end sanity check test - link
  • Check that the MoE-related load-balance loss is captured in TensorBoard - link
  • Check mHC end-to-end (500 steps with seeded datasets) - comparison. Please note the paper compares mHC vs. HC, whereas here we compare mHC vs. the baseline.
    • mHC - loss: 6.701 (expansion_rate=4), slightly lower on this toy model with a real dataset
    • Without mHC - loss: 6.796 (expansion_rate=1)
# cmd to run

python3 -m MaxText.train maxtext/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=${RUN_NAME} per_device_batch_size=8 enable_checkpointing=false model_name=deepseek-custom ici_fsdp_parallelism=4 steps=500 max_target_length=4096 async_checkpointing=false dtype=bfloat16 weight_dtype=float32 scan_layers=True dataset_type=synthetic attention=dot_product train_split=train dataset_type=hf hf_path='HuggingFaceFW/fineweb-edu' hf_name=default enable_tensorboard=true tokenizer_type=huggingface tokenizer_path=deepseek-ai/DeepSeek-V3.2 data_shuffle_seed=1234
  • DeepSeek v2 sanity tests (expect no impact on existing models)
python3 -m MaxText.train maxtext/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=${RUN_NAME} per_device_batch_size=8 enable_checkpointing=false model_name=deepseek2-16b ici_fsdp_parallelism=4 steps=20 max_target_length=4096 async_checkpointing=false tokenizer_path=src/MaxText/assets/tokenizer.mistral-v1 dtype=bfloat16 weight_dtype=float32 scan_layers=True dataset_type=synthetic attention=flash

# before change

I0210 00:18:46.050221 139944566173248 metric_logger.py:181] completed step: 19, seconds: 4.363, TFLOP/s/device: 123.215, Tokens/s/device: 7510.083, total_weights: 131072, loss: 8.135

# after change

I0210 00:09:42.323213 139931725155904 metric_logger.py:181] completed step: 19, seconds: 4.365, TFLOP/s/device: 123.166, Tokens/s/device: 7507.116, total_weights: 131072, loss: 8.135

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.


codecov bot commented Feb 9, 2026

Codecov Report

❌ Patch coverage is 94.11765% with 4 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| src/MaxText/train.py | 60.00% | 2 Missing and 2 partials ⚠️ |


@RissyRan force-pushed the mhc_integration branch 5 times, most recently from 8a7e2a0 to 5b938c0, on February 10, 2026 00:57
@RissyRan changed the title from "[WIP] Integrate MHC with DeepSeek custom model" to "[WIP] Integrate mHC with DeepSeek custom model" on Feb 10, 2026
@RissyRan force-pushed the mhc_integration branch 4 times, most recently from 2fdb6e1 to 222286f, on February 11, 2026 20:57
@shuningjin (Collaborator) left a comment


Thanks for the integration and testing!

> Update mHC a little bit to align with normalization (flexible with pre-norm and post-norm). Please note mHC also has a norm across the last k * dim dimension inside.

I'm not sure I fully follow this part. My understanding was that HC is designed to replace the residual connection (in either pre-norm or post-norm form), rather than co-exist with it. HC subsumes both forms and overcomes their limitations; ref. from the HC paper, page 2.

[Update] I need another look at the mHC paper, this

@RissyRan force-pushed the mhc_integration branch 2 times, most recently from d3c1185 to 22890d5, on February 12, 2026 18:19
@RissyRan (Collaborator, Author) commented:

> Thanks for the integration and testing!
>
> Update mHC a little bit to align with normalization (flexible with pre-norm and post-norm). Please note mHC also has a norm across the last k * dim dimension inside.
>
> I'm not sure I fully follow this part. My understanding was that HC is designed to replace the residual connection (in either pre-norm or post-norm form), rather than co-exist with it. HC subsumes both forms and overcomes their limitations; ref.
>
> [Update] I need another look at the mHC paper, this

I think that's right. HC has learnable weights and a scale that make it flexible as pre/post-norm. This makes it much easier! Let me update it back to the original implementation.
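
On the "norm across the last k * dim dimension" point quoted above, a minimal sketch of what such a norm could look like on [batch, seq, k, d] streams, assuming an RMS-style norm; this is only an illustration of the idea, not the mHC paper's exact formulation.

```python
import jax
import jax.numpy as jnp


def rms_norm_over_k_times_dim(streams: jax.Array, eps: float = 1e-6) -> jax.Array:
  """Normalize jointly over the flattened k * d axis rather than per-stream d.

  streams: [batch, seq, k, d]. The mean square is computed over all k streams
  and the model dimension together, then the original shape is restored.
  """
  b, s, k, d = streams.shape
  flat = streams.reshape(b, s, k * d)
  flat = flat * jax.lax.rsqrt(jnp.mean(jnp.square(flat), axis=-1, keepdims=True) + eps)
  return flat.reshape(b, s, k, d)


x = jnp.ones((2, 8, 4, 16))
assert rms_norm_over_k_times_dim(x).shape == x.shape
```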

@RissyRan force-pushed the mhc_integration branch 4 times, most recently from 8e20dbd to cc0e744, on February 12, 2026 19:52
@RissyRan (Collaborator, Author) commented Feb 13, 2026

> Thanks for the integration and testing!
>
> Update mHC a little bit to align with normalization (flexible with pre-norm and post-norm). Please note mHC also has a norm across the last k * dim dimension inside.
>
> I'm not sure I fully follow this part. My understanding was that HC is designed to replace the residual connection (in either pre-norm or post-norm form), rather than co-exist with it. HC subsumes both forms and overcomes their limitations; ref.
>
> [Update] I need another look at the mHC paper, this
>
> I think that's right. HC has learnable weights and a scale that make it flexible as pre/post-norm. This makes it much easier! Let me update it back to the original implementation.

Discussed offline. There are misalignments between the DS mHC and the original HC implementation (HC has pre-norm, the weight shapes differ from mHC's, and HC does not use a mapping for residuals). We would like to wait for clarification until an official implementation is available. Also added a comment in the codebase:

  Note: As an official reference implementation is currently unavailable, the
  integration of pre-norm and post-norm within the mHC architecture remains
  subject to further verification.

@shuningjin (Collaborator) left a comment


Thank you! LGTM

@shuningjin (Collaborator) commented Feb 13, 2026

Thanks for the update, LGTM!

I have high confidence that the current implementation aligns with both papers.

@gagika (Collaborator) left a comment


thanks

@copybara-service[bot] merged commit 3c56dd3 into main on Feb 14, 2026
47 of 48 checks passed
@copybara-service[bot] deleted the mhc_integration branch on February 14, 2026 00:02
