Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/maxtext/checkpoint_conversion/to_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@ def _get_model_mappings(
):
"""Retrieves parameter, shape, and hook function mappings for the model.

This handles both `atomic` keys and `composite_mt_key` architectures.
A `composite_mt_key` occurs when multiple MaxText keys must be fused back into a
single HF parameter (e.g., fusing MT `wi_0` and `wi_1` back into HF `gate_up_proj`).

Args:
model_name: The name of the model (e.g., "gemma2-2b").
scan_layers: Boolean indicating if the model was trained with scanned layers.
Expand Down
22 changes: 19 additions & 3 deletions src/maxtext/checkpoint_conversion/to_maxtext.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,10 @@ def _build_multi_axis_stacked_tensor(
layer_tensors_for_expert = []
# Inner loop iterates through layers for the current expert
for hf_key_single in layer_keys_for_expert:
hf_tensor_numpy = tensor_getter_fn(hf_key_single)
if isinstance(hf_key_single, (list, tuple)):
hf_tensor_numpy = tuple(tensor_getter_fn(k) for k in hf_key_single)
else:
hf_tensor_numpy = tensor_getter_fn(hf_key_single)
processed_hf_tensor = apply_hook_fns(hf_tensor_numpy, mt_slice_shape, hook_fns)
layer_tensors_for_expert.append(processed_hf_tensor)
all_expert_tensors.append(np.stack(layer_tensors_for_expert, axis=0))
Expand Down Expand Up @@ -419,7 +422,10 @@ def _build_single_axis_stacked_tensor(
mt_slice_shape = tuple(mt_slice_shape_list)

for hf_key_single in hf_source_keys:
hf_tensor_numpy = tensor_getter_fn(hf_key_single)
if isinstance(hf_key_single, (list, tuple)):
hf_tensor_numpy = tuple(tensor_getter_fn(k) for k in hf_key_single)
else:
hf_tensor_numpy = tensor_getter_fn(hf_key_single)
processed_hf_tensor = apply_hook_fns(hf_tensor_numpy, mt_slice_shape, hook_fns)
tensors_to_stack.append(processed_hf_tensor)

Expand All @@ -429,6 +435,12 @@ def _build_single_axis_stacked_tensor(

def _get_hf_loading_function(hf_source_keys_or_key, tensor_getter, hook_fn, mt_target_shape_or_shapes, config):
"""Determine the loading function for HF keys.

This function natively supports `composite_hf_key` mapping (where multiple HF keys
combine into a single MaxText parameter, like Qwen3.5's qkv and z -> in_proj_qkvz).
If the input is a tuple of strings, they are fetched as a tuple of arrays and passed
together into the model hook.

HF keys can take four forms:
Case 1: Unscanned (single string)
Case 2: Scanned (list of strings)
Expand All @@ -439,6 +451,9 @@ def _get_hf_loading_function(hf_source_keys_or_key, tensor_getter, hook_fn, mt_t
if not isinstance(hf_source_keys_or_key, list):
# Case 1: Single hf key (str)
def _loader(getter, key, shape, hook):
if isinstance(key, (list, tuple)):
tensors = tuple(getter(k) for k in key)
return apply_hook_fns(tensors, shape, hook)
return apply_hook_fns(getter(key), shape, hook)

load_fn = partial(
Expand Down Expand Up @@ -479,7 +494,8 @@ def _get_maxtext_indices_and_shapes(mt_param_key_or_keys, maxtext_abstract_dict)
The index is the parameter's order in `maxtext_abstract_dict.keys()`.
This function handles two forms of MaxText keys:
- `atomic_mt_key`: A single string representing one MaxText parameter that maps to HF parameter(s).
- `composite_mt_key`: A tuple of strings for multiple MaxText parameters that map to HF parameter(s).
- `composite_mt_key`: A tuple of strings representing multiple MaxText parameters derived from
a single/bundled HF parameter source (e.g., HF gate_up_proj splitting into MT wi_0 and wi_1).
"""
is_composite_mt_key = isinstance(mt_param_key_or_keys, tuple)
# atomic_mt_key
Expand Down
252 changes: 252 additions & 0 deletions src/maxtext/checkpoint_conversion/utils/hf_model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1200,6 +1200,256 @@ def __init__(self, **kwargs):
qwen3_next_80b_a3b_config = transformers.Qwen3NextConfig(**qwen3_next_80b_a3b_dict)


# Keyword arguments for the config registered under "qwen3.5-397b-a17b".
# The top level carries the multimodal wrapper fields; `text_config` and
# `vision_config` hold the language-model and vision-tower settings.
qwen3_5_397b_a17b_dict = {
    "architectures": ["Qwen3_5MoeForConditionalGeneration"],
    "image_token_id": 248056,
    "model_type": "qwen3_5_moe",
    "text_config": {
        "attention_bias": False,
        "attention_dropout": 0.0,
        "attn_output_gate": True,
        "dtype": "bfloat16",
        "eos_token_id": 248044,
        "full_attention_interval": 4,
        "head_dim": 256,
        "hidden_act": "silu",
        "hidden_size": 4096,
        "initializer_range": 0.02,
        # 60 entries (one per hidden layer): three linear-attention layers
        # followed by one full-attention layer, repeated 15 times.
        "layer_types": (["linear_attention"] * 3 + ["full_attention"]) * 15,
        "linear_conv_kernel_dim": 4,
        "linear_key_head_dim": 128,
        "linear_num_key_heads": 16,
        "linear_num_value_heads": 64,
        "linear_value_head_dim": 128,
        "max_position_embeddings": 262_144,
        "mlp_only_layers": [],
        "model_type": "qwen3_5_moe_text",
        "moe_intermediate_size": 1024,
        "mtp_num_hidden_layers": 1,
        "mtp_use_dedicated_embeddings": False,
        "num_attention_heads": 32,
        "num_experts": 512,
        "num_experts_per_tok": 10,
        "num_hidden_layers": 60,
        "num_key_value_heads": 2,
        "rms_norm_eps": 1e-6,
        "router_aux_loss_coef": 0.001,
        "shared_expert_intermediate_size": 1024,
        "use_cache": True,
        "vocab_size": 248320,
        "mamba_ssm_dtype": "float32",
        "rope_parameters": {
            "mrope_interleaved": True,
            "mrope_section": [11, 11, 10],
            "rope_type": "default",
            "rope_theta": 10_000_000,
            "partial_rotary_factor": 0.25,
        },
    },
    "tie_word_embeddings": False,
    "transformers_version": "4.57.0.dev0",
    "video_token_id": 248057,
    "vision_config": {
        "deepstack_visual_indexes": [],
        "depth": 27,
        "hidden_act": "gelu_pytorch_tanh",
        "hidden_size": 1152,
        "in_channels": 3,
        "initializer_range": 0.02,
        "intermediate_size": 4304,
        "model_type": "qwen3_5_moe",
        "num_heads": 16,
        "num_position_embeddings": 2304,
        # Equals the text model's hidden_size above.
        "out_hidden_size": 4096,
        "patch_size": 16,
        "spatial_merge_size": 2,
        "temporal_patch_size": 2,
    },
    "vision_end_token_id": 248054,
    "vision_start_token_id": 248053,
}


# Keyword arguments for the config registered under "qwen3.5-35b-a3b".
# The top level carries the multimodal wrapper fields; `text_config` and
# `vision_config` hold the language-model and vision-tower settings.
qwen3_5_35b_a3b_dict = {
    "architectures": ["Qwen3_5MoeForConditionalGeneration"],
    "image_token_id": 248056,
    "model_type": "qwen3_5_moe",
    "text_config": {
        "attention_bias": False,
        "attention_dropout": 0.0,
        "attn_output_gate": True,
        "dtype": "bfloat16",
        "eos_token_id": 248044,
        "full_attention_interval": 4,
        "head_dim": 256,
        "hidden_act": "silu",
        "hidden_size": 2048,
        "initializer_range": 0.02,
        # 40 entries (one per hidden layer): three linear-attention layers
        # followed by one full-attention layer, repeated 10 times.
        "layer_types": (["linear_attention"] * 3 + ["full_attention"]) * 10,
        "linear_conv_kernel_dim": 4,
        "linear_key_head_dim": 128,
        "linear_num_key_heads": 16,
        "linear_num_value_heads": 32,
        "linear_value_head_dim": 128,
        "max_position_embeddings": 262_144,
        "mlp_only_layers": [],
        "model_type": "qwen3_5_moe_text",
        "moe_intermediate_size": 512,
        "mtp_num_hidden_layers": 1,
        "mtp_use_dedicated_embeddings": False,
        "num_attention_heads": 16,
        "num_experts": 256,
        "num_experts_per_tok": 8,
        "num_hidden_layers": 40,
        "num_key_value_heads": 2,
        "rms_norm_eps": 1e-6,
        "router_aux_loss_coef": 0.001,
        "shared_expert_intermediate_size": 512,
        "use_cache": True,
        "vocab_size": 248320,
        "mamba_ssm_dtype": "float32",
        "rope_parameters": {
            "mrope_interleaved": True,
            "mrope_section": [11, 11, 10],
            "rope_type": "default",
            "rope_theta": 10_000_000,
            "partial_rotary_factor": 0.25,
        },
    },
    "tie_word_embeddings": False,
    "transformers_version": "4.57.0.dev0",
    "video_token_id": 248057,
    "vision_config": {
        "deepstack_visual_indexes": [],
        "depth": 27,
        "hidden_act": "gelu_pytorch_tanh",
        "hidden_size": 1152,
        "in_channels": 3,
        "initializer_range": 0.02,
        "intermediate_size": 4304,
        "model_type": "qwen3_5_moe",
        "num_heads": 16,
        "num_position_embeddings": 2304,
        # Equals the text model's hidden_size above.
        "out_hidden_size": 2048,
        "patch_size": 16,
        "spatial_merge_size": 2,
        "temporal_patch_size": 2,
    },
    "vision_end_token_id": 248054,
    "vision_start_token_id": 248053,
}

# Build both Qwen3.5 configs with the native Transformers class when the
# installed Transformers exposes `Qwen3_5MoeConfig`. Otherwise the attribute
# lookup raises AttributeError and both configs are built with PTConfig.
try:
  qwen3_5_35b_a3b_config, qwen3_5_397b_a17b_config = (
      transformers.Qwen3_5MoeConfig(**qwen3_5_35b_a3b_dict),
      transformers.Qwen3_5MoeConfig(**qwen3_5_397b_a17b_dict),
  )
except AttributeError:
  qwen3_5_35b_a3b_config = PTConfig(**qwen3_5_35b_a3b_dict)  # pytype: disable=wrong-arg-types
  qwen3_5_397b_a17b_config = PTConfig(**qwen3_5_397b_a17b_dict)  # pytype: disable=wrong-arg-types


# from https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/config.json
mixtral_8x7b_dict = {
"architectures": ["MixtralForCausalLM"],
Expand Down Expand Up @@ -1339,6 +1589,8 @@ def __init__(self, **kwargs):
"gpt-oss-120b": gpt_oss_120b_config,
"qwen3-omni-30b-a3b": qwen3_omni_30b_a3b_config,
"qwen3-next-80b-a3b": qwen3_next_80b_a3b_config,
"qwen3.5-397b-a17b": qwen3_5_397b_a17b_config,
"qwen3.5-35b-a3b": qwen3_5_35b_a3b_config,
"mixtral-8x7b": mixtral_8x7b_config,
"mixtral-8x22b": mixtral_8x22b_config,
"olmo3-7b": olmo3_7b_config,
Expand Down
Loading
Loading