Support NNX checkpoint conversion to HF #2946

Merged
copybara-service[bot] merged 1 commit into main from hengtaoguo-conversion on Jan 16, 2026

@hengtaoguo (Collaborator) commented Jan 14, 2026

Description

This PR adds support for converting NNX-SFT and NNX-RL checkpoints to the HuggingFace safetensors format. For the same model, the three checkpoint types differ structurally as follows:

  • Linen: {'params': {'params': {'decoder': {...}, 'token_embedder': ... {WEIGHT_ARRAY}}}}. This is the conventional MaxText structure that nests weights directly within a double params wrapper.
  • NNX-SFT: {'decoder': {...}, 'token_embedder': ... {'value': WEIGHT_ARRAY}}. This format removes the top-level params wrappers but adds a value key to contain the weights at every leaf node.
  • NNX-RL: {'base': {'decoder': {...}, 'token_embedder': ... {'value': WEIGHT_ARRAY}}}. Similar to the SFT structure, this format wraps the entire model under a top-level base key while maintaining the leaf-level value convention.
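
The three layouts above can be told apart from their top-level keys alone. The following is a minimal sketch of that idea, not the code from this PR: the function names (`detect_checkpoint_type`, `unwrap_values`, `to_linen_layout`) are hypothetical, and it assumes the restored checkpoint is a plain nested dict shaped exactly as described in the bullets.

```python
from collections.abc import Mapping
from typing import Any


def detect_checkpoint_type(tree: Mapping) -> str:
    """Classify a restored checkpoint tree by its top-level keys (assumed layouts)."""
    inner = tree.get("params")
    if isinstance(inner, Mapping) and "params" in inner:
        return "linen"    # {'params': {'params': {...}}}
    if "base" in tree:
        return "nnx-rl"   # {'base': {...}}
    return "nnx-sft"      # {'decoder': {...}, 'token_embedder': {...}}


def unwrap_values(node: Any) -> Any:
    """Collapse NNX leaves of the form {'value': WEIGHT_ARRAY} into WEIGHT_ARRAY."""
    if isinstance(node, Mapping):
        if set(node) == {"value"} and not isinstance(node["value"], Mapping):
            return node["value"]
        return {key: unwrap_values(child) for key, child in node.items()}
    return node


def to_linen_layout(tree: Mapping) -> dict:
    """Rewrap an NNX-SFT or NNX-RL tree in the Linen double-params layout."""
    kind = detect_checkpoint_type(tree)
    if kind == "linen":
        return dict(tree)
    body = tree["base"] if kind == "nnx-rl" else tree
    return {"params": {"params": unwrap_values(body)}}


# Toy trees; 1.0 stands in for a weight array.
sft = {"decoder": {"decoder_norm": {"scale": {"value": 1.0}}}}
rl = {"base": sft}
print(detect_checkpoint_type(sft), detect_checkpoint_type(rl))
print(to_linen_layout(rl))
```

Normalizing every input to one layout first means the rest of a converter only has to understand a single tree shape.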

Changes in this PR:

  • Replace maxengine with ocp.Checkpointer to load weights directly from the source, removing the dependency on MaxText model initialization.
  • Automatically identify the checkpoint type and transform weight names to follow the standard Linen convention, e.g. params-decoder-decoder_norm-scale.
  • Remove unused NNX RNG flags during the to_huggingface conversion process.
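
To illustrate the naming convention mentioned above, here is a rough sketch of how a nested weight tree could be flattened into hyphen-joined names such as params-decoder-decoder_norm-scale. The helper `flatten_names` is hypothetical and is not taken from this PR; the sketch also assumes the tree passed in starts at a single `params` key, since the PR does not show how the outer wrapper is handled. The key `hypothetical_weight` is made up for the example.

```python
from collections.abc import Mapping
from typing import Any, Dict, Tuple


def flatten_names(tree: Mapping, prefix: Tuple[str, ...] = (), sep: str = "-") -> Dict[str, Any]:
    """Map each leaf of a nested dict to a name built by joining its key path with `sep`."""
    flat: Dict[str, Any] = {}
    for key, child in tree.items():
        path = prefix + (str(key),)
        if isinstance(child, Mapping):
            # Recurse into sub-modules, carrying the path so far.
            flat.update(flatten_names(child, path, sep))
        else:
            # A leaf: a weight array (a plain float stands in for it here).
            flat[sep.join(path)] = child
    return flat


# Toy tree; only the decoder_norm/scale path comes from the PR description.
tree = {
    "params": {
        "decoder": {
            "decoder_norm": {"scale": 1.0},
            "hypothetical_weight": 2.0,
        }
    }
}
for name in sorted(flatten_names(tree)):
    print(name)
```

Flat string keys like these are convenient for looking up per-weight mapping rules between the MaxText and HuggingFace naming schemes.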

Fixes: b/474525257, b/475223303

Tests

Tested conversion of all three checkpoint types (Linen, NNX-SFT, NNX-RL) to HuggingFace (v6e-8), then ran offline inference on the converted safetensors with HuggingFace scripts.

  • Linen (Qwen3-4B)
# Command
JAX_PLATFORMS=cpu python -m MaxText.utils.ckpt_conversion.to_huggingface src/MaxText/configs/base.yml model_name=qwen3-4b load_parameters_path=gs://hengtaoguo-maxtext-logs/checkpoints/qwen3-4b/scanned/0/items scan_layers=true base_output_directory=/home/hengtaoguo_google_com/projects/hf_safetensor/qwen3-4b hf_access_token=xxx weight_dtype=bfloat16 hardware=cpu skip_jax_distributed_system=True

# Conversion Logs
I0115 07:06:34.949318 139917773788992 utils.py:501]    Saved model.safetensors.index.json to /home/hengtaoguo_google_com/projects/hf_safetensor/qwen3-4b/model.safetensors.index.json
I0115 07:06:34.949672 139917773788992 utils.py:644] ✅ Model and tokenizer (if provided) successfully processed for /home/hengtaoguo_google_com/projects/hf_safetensor/qwen3-4b
I0115 07:06:34.949819 139917773788992 to_huggingface.py:210] ✅ MaxText model successfully saved in HuggingFace format at /home/hengtaoguo_google_com/projects/hf_safetensor/qwen3-4b
I0115 07:06:34.949920 139917773788992 to_huggingface.py:211] Elapse for save: 0.83 min
I0115 07:06:34.949992 139917773788992 to_huggingface.py:212] Overall Elapse: 2.48 min

# Inference with converted safetensors
Prompt: Artificial Intelligence is
--------------------------------------------------------------------------------
Generated: Artificial Intelligence is a part of modern technology. It is an area of computer science that deals with creating machines that can perform tasks that normally require human intelligence, such as learning, reasoning, problem-solving, perception, and language understanding. AI has been around for decades, but it has become more prominent in recent years due to advancements in computing power, data storage, and machine learning algorithms. AI is used in various fields, including healthcare, finance, transportation, and entertainment. In healthcare, AI is used to analyze medical

  • NNX-SFT (Qwen3-4B)
# Command
JAX_PLATFORMS=cpu python -m MaxText.utils.ckpt_conversion.to_huggingface src/MaxText/configs/base.yml model_name=qwen3-4b load_parameters_path=gs://horacehylin-ml-exp/qwen/qwen3-4b/2026-01-09-10-19-03/checkpoints/1/model_params/ scan_layers=true base_output_directory=/home/hengtaoguo_google_com/projects/hf_safetensor/qwen3-4b hf_access_token=xxx weight_dtype=bfloat16 hardware=cpu skip_jax_distributed_system=True

# Conversion Logs
I0115 07:20:55.308486 140633909978944 utils.py:501]    Saved model.safetensors.index.json to /home/hengtaoguo_google_com/projects/hf_safetensor/qwen3-4b/model.safetensors.index.json
I0115 07:20:55.308830 140633909978944 utils.py:644] ✅ Model and tokenizer (if provided) successfully processed for /home/hengtaoguo_google_com/projects/hf_safetensor/qwen3-4b
I0115 07:20:55.308993 140633909978944 to_huggingface.py:210] ✅ MaxText model successfully saved in HuggingFace format at /home/hengtaoguo_google_com/projects/hf_safetensor/qwen3-4b
I0115 07:20:55.309100 140633909978944 to_huggingface.py:211] Elapse for save: 3.12 min
I0115 07:20:55.309181 140633909978944 to_huggingface.py:212] Overall Elapse: 4.17 min

# Inference with converted safetensors
Prompt: Artificial Intelligence is
--------------------------------------------------------------------------------
Generated: Artificial Intelligence is a hot topic. It has been used in many areas of our life. It's very important for us to understand its functions.  The following is a question that I want to ask.  I have a list of 1000 items, and I want to use AI to process them.  What are the main steps I should take to do this?  Please give me a detailed answer.  I have no prior experience in AI processing, so I need to be very clear and

  • NNX-RL (Llama3.1-8B)
# Command
JAX_PLATFORMS=cpu python -m MaxText.utils.ckpt_conversion.to_huggingface src/MaxText/configs/base.yml model_name=llama3.1-8b load_parameters_path=gs://agagik-us/distillation/exp_xpk_v764_2/distill_llama/checkpoints/1986/model_params/ scan_layers=true base_output_directory=/home/hengtaoguo_google_com/projects/hf_safetensor/llama31-8b-2 hf_access_token=xxx weight_dtype=bfloat16 hardware=cpu skip_jax_distributed_system=True base_num_query_heads=16 head_dim=256 base_num_kv_heads=4

# Conversion Logs
I0115 07:31:21.982063 140431949666112 utils.py:501]    Saved model.safetensors.index.json to /home/hengtaoguo_google_com/projects/hf_safetensor/llama31-8b-2/model.safetensors.index.json
I0115 07:31:21.982394 140431949666112 utils.py:644] ✅ Model and tokenizer (if provided) successfully processed for /home/hengtaoguo_google_com/projects/hf_safetensor/llama31-8b-2
I0115 07:31:21.982536 140431949666112 to_huggingface.py:210] ✅ MaxText model successfully saved in HuggingFace format at /home/hengtaoguo_google_com/projects/hf_safetensor/llama31-8b-2
I0115 07:31:21.982625 140431949666112 to_huggingface.py:211] Elapse for save: 1.66 min
I0115 07:31:21.982692 140431949666112 to_huggingface.py:212] Overall Elapse: 4.66 min

# Inference with converted safetensors
Prompt: Artificial Intelligence is
--------------------------------------------------------------------------------
Generated: Artificial Intelligence is revolutionizing the world of music production and sound design. With the advancements in technology, AI-powered tools are becoming increasingly sophisticated, offering a wide range of applications in various industries. In this article, we will explore the benefits of using AI in music production and sound design and how it is transforming the industry.
AI in Music Production and Sound Design
AI-powered tools are being used in various aspects of music production and sound design. These tools are designed to assist in the creative process, providing assistance in tasks
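
Each conversion log above reports that model.safetensors.index.json was written to the output directory. As a quick sanity check before running inference, that index can be inspected with the standard library alone. The sketch below is not part of this PR: the helper name `summarize_safetensors_index` is hypothetical, and it assumes the index follows the usual HuggingFace sharded-checkpoint layout with a top-level `weight_map` from tensor name to shard file.

```python
import json
from collections import Counter
from pathlib import Path


def summarize_safetensors_index(output_dir: str) -> Counter:
    """Count how many tensors the index assigns to each shard file."""
    index_path = Path(output_dir) / "model.safetensors.index.json"
    with index_path.open() as f:
        index = json.load(f)
    # weight_map: tensor name -> shard file name (assumed HuggingFace layout).
    return Counter(index["weight_map"].values())


if __name__ == "__main__":
    import sys

    # Usage: python check_index.py <base_output_directory>
    if len(sys.argv) > 1:
        for shard, count in sorted(summarize_safetensors_index(sys.argv[1]).items()):
            print(f"{shard}: {count} tensors")
```

An empty or unexpectedly small count for a shard would suggest that some weights were not mapped during conversion.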

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

codecov Bot commented Jan 14, 2026

Codecov Report

❌ Patch coverage is 0% with 56 lines in your changes missing coverage. Please review.

Files with missing lines | Patch % | Lines
src/MaxText/utils/ckpt_conversion/utils/utils.py | 0.00% | 53 Missing ⚠️
...rc/MaxText/utils/ckpt_conversion/to_huggingface.py | 0.00% | 3 Missing ⚠️

@hengtaoguo force-pushed the hengtaoguo-conversion branch 3 times, most recently from 5812b57 to 1837cbd (January 15, 2026 07:01)
@hengtaoguo marked this pull request as ready for review (January 15, 2026 07:28)
@hengtaoguo force-pushed the hengtaoguo-conversion branch from 59b2503 to 1fe8fd3 (January 15, 2026 08:01)

@shuningjin (Collaborator) left a comment

Thank you for the great work!

@gagika (Collaborator) left a comment

Thanks!

copybara-service[bot] merged commit 4d99240 into main on Jan 16, 2026
31 of 32 checks passed
copybara-service[bot] deleted the hengtaoguo-conversion branch (January 16, 2026 02:31)

@ChingTsai (Collaborator)

Thanks so much!
