In [None]:
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [2]:
%pip install -U -q 'google-genai>=1.0.0'

Note: you may need to restart the kernel to use updated packages.


In [3]:
import getpass
import os

from google import genai
from IPython.display import Markdown

In [None]:
# Never store the key in the notebook: read it from the GOOGLE_API_KEY
# environment variable, and only fall back to an interactive getpass prompt
# when that variable is unset or empty.
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY") or getpass.getpass("Gemini API key: ")

# Client used by the later cells for file uploads and content generation.
client = genai.Client(api_key=GOOGLE_API_KEY)

In [None]:
# Model family name that is interpolated into the prompt below.
target_model = "Gemma3"

# Gemini model that serves the generation request.
# NOTE(review): confirm this model ID is available for your API key.
MODEL_ID = "gemini-2.0-pro"

In [None]:
# Upload the two reference sources that are later passed to the model as context.
param_file = client.files.upload(file="context/param_mapping.py")
shape_file = client.files.upload(file="context/hf_shape.py")

# Report the server-side name and URI of each upload, one line per file.
print(
    "\n".join(
        f"Uploaded file '{uploaded.name}' as: {uploaded.uri}"
        for uploaded in (param_file, shape_file)
    )
)

Uploaded file 'files/96zbvn0jl8v9' as: https://generativelanguage.googleapis.com/v1beta/files/96zbvn0jl8v9
Uploaded file 'files/i81ci5tcyjwa' as: https://generativelanguage.googleapis.com/v1beta/files/i81ci5tcyjwa
Uploaded file 'files/poo15cv7l54e' as: https://generativelanguage.googleapis.com/v1beta/files/poo15cv7l54e
Uploaded file 'files/84qrwp42q92e' as: https://generativelanguage.googleapis.com/v1beta/files/84qrwp42q92e


In [10]:

# Instructions for the model; {target_model} is substituted into the text and
# into the names of the three functions it is asked to produce.
prompt = f"""
  You are a code assist to help me find the checkpoint conversion from maxtext to huggingface. 
  The checkpoint does not fuse QKV vectors. 
  The transformer configs should be completely aligned with given model config for {target_model}
  You need to generate the following code functions of {target_model} Model:
    {target_model}_MAXTEXT_TO_HF_PARAM_MAPPING(); 
    {target_model}_MAXTEXT_TO_HF_PARAM_HOOK_FN();
    {target_model}_HF_WEIGHTS_TO_SHAPE_MAPPING();
"""

# The prompt is sent together with the two uploaded reference files.
response = client.models.generate_content(
    contents=[prompt, param_file, shape_file],
    model=MODEL_ID,
)

# Last expression of the cell: render the model's answer as Markdown.
Markdown(response.text)

```python
"""
 Copyright 2025 Google LLC

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      https://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 """

import numpy as np
import jax
import jax.numpy as jnp


def Gemma3_MAXTEXT_TO_HF_PARAM_MAPPING(config, scan_layers=False):
  """Builds the MaxText -> HuggingFace parameter-path table for Gemma3.

  Args:
      config (dict): Model configuration; only 'num_hidden_layers' is read.
      scan_layers (bool, optional): Selects the MaxText key layout.
          - True: one "params-decoder-layers-..." key per parameter kind, whose
            value lists the HF path of every layer in layer order.
          - False: one "params-decoder-layers_<idx>-..." key per layer and
            parameter kind, whose value is a single HF path.
          Defaults to False.

  Returns:
      dict: MaxText parameter path -> HF parameter path (str), or list of HF
      parameter paths (list[str]) for the per-layer entries when scan_layers=True.
      The embedding and final-norm entries are always single strings.
  """
  num_layers = config["num_hidden_layers"]

  # (MaxText suffix, HF suffix) for each per-layer parameter, in emission order.
  per_layer_suffixes = (
      ("attention-key-kernel", "self_attn.k_proj.weight"),
      ("attention-value-kernel", "self_attn.v_proj.weight"),
      ("attention-query-kernel", "self_attn.q_proj.weight"),
      ("attention-out-kernel", "self_attn.o_proj.weight"),
      ("mlp-wi_0-kernel", "mlp.gate_proj.weight"),
      ("mlp-wi_1-kernel", "mlp.up_proj.weight"),
      ("mlp-wo-kernel", "mlp.down_proj.weight"),
      ("rms_norm-scale", "input_layernorm.weight"),
      ("ffn_rms_norm-scale", "post_attention_layernorm.weight"),
  )

  # Parameters that live outside the decoder layers.
  table = {
      "params-token_embedder-embedding": "model.embed_tokens.weight",
      "params-decoder-decoder_norm-scale": "model.norm.weight",
  }

  if scan_layers:
    # Scanned layout: a single MaxText key fans out to every layer's HF path.
    for maxtext_suffix, hf_suffix in per_layer_suffixes:
      table[f"params-decoder-layers-{maxtext_suffix}"] = [
          f"model.layers.{idx}.{hf_suffix}" for idx in range(num_layers)
      ]
    return table

  # Unscanned layout: one MaxText key per (layer, parameter) pair.
  for idx in range(num_layers):
    for maxtext_suffix, hf_suffix in per_layer_suffixes:
      table[f"params-decoder-layers_{idx}-{maxtext_suffix}"] = f"model.layers.{idx}.{hf_suffix}"
  return table


def Gemma3_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, scan_layers=False, saving_to_hf=False):
  """Creates parameter transformation functions for converting between MaxText and
  HuggingFace formats.

  Every hook has the signature `hook(input_tensor, target_shape)` and returns the
  converted tensor. The conversion direction is fixed when this function is called.

  Args:
      config (dict): Model configuration dictionary. Keys read here:
          - num_hidden_layers (int): Number of layers in the model (read eagerly).
          - head_dim (int): Dimension of attention heads (read when a query
            scaling hook runs).

      scan_layers (bool, optional): Controls the key layout for layer parameters:
          - True: one "params-decoder-layers-..." key per parameter kind
          - False: one "params-decoder-layers_<idx>-..." key per layer and kind
          Defaults to False.

      saving_to_hf (bool, optional): Determines the direction of transformation:
          - True: MaxText → HuggingFace conversion
          - False: HuggingFace → MaxText conversion
          Defaults to False.

  Returns:
      dict: Parameter transformation mapping where:
          - Keys: MaxText parameter names (str)
          - Values: Either:
              - callable: Single transformation function
              - list[callable]: List of transformation functions to be applied in sequence

  Transformation Details:
      1. Embedding layer padding:
          - HF shape: [vocab_size, d_model]
          - MaxText shape: [padded_vocab_size, d_model] (padded for performance)
          - No scale factor is applied to the embedding weights.
      2. Layer normalization scaling:
          - Subtracts 1.0 (to HF) or adds 1.0 (from HF)
      3. Attention query scaling:
          - Multiplies by sqrt(head_dim) (to HF) or 1/sqrt(head_dim) (from HF)
      4. Kernel reshaping:
          - Handles dimension transposition and reshaping between formats
  """
  nlayers = config["num_hidden_layers"]

  def pad_hf_embedding_layer(input_tensor, target_shape):
    """Crops (to HF) or zero-pads (from HF) the 2-D embedding to `target_shape`.

    Note:
        HF embedding weights shape =  [vocab_size,d_model]
        MaxText embedding weights shape = [padded_vocab_size,d_model]
        MaxText pad Gemma3 embedding to padded_vocab_size for better performance.
    """
    if saving_to_hf:
      cropped = input_tensor[: target_shape[0], : target_shape[1]]
      # For NumPy inputs astype() copies, so the result does not alias the input.
      return cropped.astype(input_tensor.dtype)

    # Fresh zero buffer of the padded shape; the HF weights fill its top-left corner.
    padded = np.zeros(target_shape, dtype=input_tensor.dtype)
    padded[: input_tensor.shape[0], : input_tensor.shape[1]] = input_tensor
    return padded

  def reshape_kernel(input_tensor, target_shape):
    """Transposes a kernel between the two layouts and reshapes it to `target_shape`."""
    if saving_to_hf:
      # Reshape to the reversed target first so the final transpose lands on target_shape.
      return input_tensor.reshape(tuple(target_shape)[::-1]).T
    return input_tensor.T.reshape(target_shape)

  def scale_rmsnorm_layer(input_tensor, target_shape):
    """Shifts the norm scale by 1.0 (minus to HF, plus from HF) and reshapes it."""
    if saving_to_hf:
      return (input_tensor - 1.0).reshape(target_shape)
    return (input_tensor + 1.0).reshape(target_shape)

  def scale_query_layer(input_tensor, target_shape):
    """Rescales the query kernel by sqrt(head_dim) or its inverse, keeping the input dtype.

    `target_shape` is accepted for signature uniformity and is not used.
    """
    sqrt_head_dim = np.sqrt(config["head_dim"])
    # The factor is held as float32; the product is cast back to the input dtype.
    depth_scale = np.float32(sqrt_head_dim if saving_to_hf else 1 / sqrt_head_dim)
    return (input_tensor * depth_scale).astype(input_tensor.dtype)

  def layer_hooks(prefix):
    """Returns the hooks of one decoder-layer key prefix, with a fresh list per call."""
    return {
        f"{prefix}-attention-query-kernel": [reshape_kernel, scale_query_layer],
        f"{prefix}-attention-key-kernel": reshape_kernel,
        f"{prefix}-attention-value-kernel": reshape_kernel,
        f"{prefix}-mlp-wo-kernel": reshape_kernel,
        f"{prefix}-mlp-wi_1-kernel": reshape_kernel,
        f"{prefix}-mlp-wi_0-kernel": reshape_kernel,
        f"{prefix}-attention-out-kernel": reshape_kernel,
        f"{prefix}-rms_norm-scale": scale_rmsnorm_layer,
        f"{prefix}-ffn_rms_norm-scale": scale_rmsnorm_layer,
    }

  mapping = {
      "params-token_embedder-embedding": pad_hf_embedding_layer,
      "params-decoder-decoder_norm-scale": scale_rmsnorm_layer,
  }
  if scan_layers:
    mapping.update(layer_hooks("params-decoder-layers"))
  else:
    # Update in place instead of rebuilding the whole dict for every layer.
    for layer_idx in range(nlayers):
      mapping.update(layer_hooks(f"params-decoder-layers_{layer_idx}"))
  return mapping


def Gemma3_HF_WEIGHTS_TO_SHAPE_MAPPING(config):
  """Builds the table of HuggingFace Gemma3 weight paths and their shapes.

  Args:
      config (dict): Model configuration dictionary. Keys read: 'vocab_size',
          'hidden_size', 'num_hidden_layers' and, once per layer,
          'num_attention_heads', 'num_key_value_heads', 'head_dim' and
          'intermediate_size'.

  Returns:
      dict: HuggingFace parameter path (str) -> parameter shape (list of int).
  """
  shapes = {
      "model.embed_tokens.weight": [config["vocab_size"], config["hidden_size"]],
      "model.norm.weight": [config["hidden_size"]],
  }

  for layer_idx in range(config["num_hidden_layers"]):
    prefix = f"model.layers.{layer_idx}"

    # Widths are looked up per layer, so the per-layer keys are only required
    # when there is at least one layer.
    hidden = config["hidden_size"]
    q_width = config["num_attention_heads"] * config["head_dim"]
    kv_width = config["num_key_value_heads"] * config["head_dim"]
    ffn_width = config["intermediate_size"]

    # Each entry gets its own list object; shapes are [out_features, in_features]
    # for the 2-D projection weights.
    shapes.update({
        f"{prefix}.input_layernorm.weight": [hidden],
        f"{prefix}.post_attention_layernorm.weight": [hidden],
        f"{prefix}.self_attn.q_proj.weight": [q_width, hidden],
        f"{prefix}.self_attn.k_proj.weight": [kv_width, hidden],
        f"{prefix}.self_attn.v_proj.weight": [kv_width, hidden],
        f"{prefix}.self_attn.o_proj.weight": [hidden, q_width],
        f"{prefix}.mlp.gate_proj.weight": [ffn_width, hidden],
        f"{prefix}.mlp.up_proj.weight": [ffn_width, hidden],
        f"{prefix}.mlp.down_proj.weight": [hidden, ffn_width],
    })
  return shapes

```