In [None]:
# installation
!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
!pip3 install tokenizers -U
!pip3 install transformers -U

In [None]:
# go to maxtext/MaxText for library import

current_dir = %pwd
working_dir = current_dir.replace("scratch_code", "") 
%cd $working_dir

# One-layer Mixtral model

In [None]:
import pyconfig
from transformers.models.mixtral.configuration_mixtral import MixtralConfig

# MaxText configuration: a single unscanned decoder layer with 8 experts (top-2 routing),
# a 4-token target length and checkpointing disabled, used only to compare logits against HF.
pyconfig.initialize(
    [None, "configs/base.yml"],
    base_emb_dim=4096,
    base_num_query_heads=32,
    base_num_kv_heads=8,
    base_mlp_dim=14336,
    base_num_decoder_layers=1,  # 1 layer for simplicity
    head_dim=128,
    mlp_activations=["silu","linear"],
    vocab_size=32000,
    enable_dropout=False,
    logits_via_embedding=False,
    normalization_layer_epsilon=1.0e-5,
    num_experts=8,
    num_experts_per_tok=2,
    rope_max_timescale=1_000_000,
    decoder_block="mistral",
    run_name="moe_test",
    enable_checkpointing=False,
    dtype="bfloat16",
    weight_dtype="bfloat16",
    megablox=True,  # set False (together with capacity_factor=-1) to test the non-megablox path
    max_target_length=4,
    max_prefill_predict_length=3,
    per_device_batch_size=1,
    capacity_factor=-1,
    scan_layers=False,  # keeps per-layer parameter names (layers_0, ...) for the weight mapping
)
config_maxtext = pyconfig.config

# Hugging Face config mirroring the MaxText one. All sizes come from the derived MaxText
# fields (emb_dim, mlp_dim, num_query_heads, num_kv_heads, ...) rather than the raw `base_*`
# inputs, so both configs stay consistent if pyconfig scales the base values.
config_hf = MixtralConfig(
    vocab_size=config_maxtext.vocab_size,
    hidden_size=config_maxtext.emb_dim,
    intermediate_size=config_maxtext.mlp_dim,
    num_hidden_layers=config_maxtext.num_decoder_layers,
    num_attention_heads=config_maxtext.num_query_heads,
    num_key_value_heads=config_maxtext.num_kv_heads,
    rms_norm_eps=config_maxtext.normalization_layer_epsilon,
    rope_theta=config_maxtext.rope_max_timescale,
    attention_dropout=0.0,
    num_experts_per_tok=config_maxtext.num_experts_per_tok,
    num_local_experts=config_maxtext.num_experts,
    tie_word_embeddings=config_maxtext.logits_via_embedding,
    output_router_logits=False,
    router_aux_loss_coef=0.001,
    router_jitter_noise=0.0,
    torch_dtype="bfloat16",
)

In [None]:
import jax
import jax.numpy as jnp
from jax.sharding import Mesh
from transformers import AutoModelForCausalLM, set_seed

import max_utils
from layers.models import Transformer

# Hugging Face reference model. Seeding first makes its random initialization reproducible;
# these weights are the ones later copied into the MaxText parameter tree.
set_seed(0)
model_hf = AutoModelForCausalLM.from_config(config_hf)

# MaxText model, placed on a device mesh laid out according to the config's mesh axes.
devices_array = max_utils.create_device_mesh(config_maxtext)
mesh = Mesh(devices_array, config_maxtext.mesh_axes)
model_maxtext = Transformer(config=config_maxtext, mesh=mesh, quant=None)

# Key used to initialize the MaxText parameters (their values are replaced by the HF weights).
prng_key = jax.random.PRNGKey(1234)

In [None]:
import numpy as np

# Sample inputs shared by both models. A dedicated, seeded generator makes the token ids
# reproducible and keeps this cell idempotent: re-running it yields the same inputs, which
# drawing from the global np.random state does not.
INPUT_SEED = 0
_input_rng = np.random.default_rng(INPUT_SEED)
_batch_size = int(config_maxtext.per_device_batch_size)
_seq_len = config_maxtext.max_target_length

input_np = {
    # Random token ids in [0, vocab_size), shape (batch, seq_len). int64 so that
    # torch.from_numpy yields a LongTensor for the HF embedding lookup.
    'inputs': _input_rng.integers(0, config_maxtext.vocab_size, size=(_batch_size, _seq_len), dtype=np.int64),
    # Positions 0..seq_len-1 repeated for every row, shape (batch, seq_len).
    'inputs_position': np.tile(np.arange(_seq_len, dtype=np.int64), (_batch_size, 1)),
}

In [None]:
# Trace the MaxText model once on the sample inputs to materialize its variable tree.
# A single key feeds every RNG stream; the resulting values are only placeholders that get
# replaced by the Hugging Face weights.
state_maxtext = model_maxtext.init(
    {stream: prng_key for stream in ('params', 'dropout', 'aqt')},
    jnp.array(input_np['inputs']),
    jnp.array(input_np['inputs_position']),
    enable_dropout=config_maxtext.enable_dropout,
)

In [None]:
import torch
from flax import linen as nn


def _stack_experts(tensors):
    """Stack per-expert HF weight matrices on a new leading axis and swap the last two axes."""
    return torch.stack(tensors, dim=0).transpose(1, 2)


# Maps each MaxText parameter path (as rendered by `jax.tree_util.keystr`) to
# (HF state_dict key, transform from the HF tensor(s) to MaxText's layout).
# Keys containing "<exp_idx>" are expanded over all experts, and their transform receives
# the list of per-expert tensors. Per-layer entries are generated for every decoder layer
# (layers are unscanned, so MaxText names them layers_0, layers_1, ...).
state_map = {
    "['params']['decoder']['decoder_norm']['scale'].value": ("model.norm.weight", lambda x: x),
    "['params']['decoder']['logits_dense']['kernel'].value": ("lm_head.weight", lambda x: x.T),
    "['params']['token_embedder']['embedding'].value": ("model.embed_tokens.weight", lambda x: x),
}
for _layer in range(config_hf.num_hidden_layers):
    _mt = f"['params']['decoder']['layers_{_layer}']"
    _hf = f"model.layers.{_layer}"
    state_map.update({
        f"{_mt}['MoeBlock_0']['gate']['kernel'].value": (f"{_hf}.block_sparse_moe.gate.weight", lambda x: x.T),
        f"{_mt}['MoeBlock_0']['wi_0'].value": (f"{_hf}.block_sparse_moe.experts.<exp_idx>.w1.weight", _stack_experts),
        f"{_mt}['MoeBlock_0']['wi_1'].value": (f"{_hf}.block_sparse_moe.experts.<exp_idx>.w3.weight", _stack_experts),
        f"{_mt}['MoeBlock_0']['wo'].value": (f"{_hf}.block_sparse_moe.experts.<exp_idx>.w2.weight", _stack_experts),
        f"{_mt}['post_self_attention_layer_norm']['scale'].value": (f"{_hf}.post_attention_layernorm.weight", lambda x: x),
        f"{_mt}['pre_self_attention_layer_norm']['scale'].value": (f"{_hf}.input_layernorm.weight", lambda x: x),
        f"{_mt}['self_attention']['key']['kernel'].value": (
            f"{_hf}.self_attn.k_proj.weight",
            lambda x: x.T.reshape(config_hf.hidden_size, config_hf.num_key_value_heads, config_maxtext.head_dim),
        ),
        f"{_mt}['self_attention']['out']['kernel'].value": (
            f"{_hf}.self_attn.o_proj.weight",
            lambda x: x.T.reshape(config_hf.num_attention_heads, config_maxtext.head_dim, config_hf.hidden_size),
        ),
        # Query weights are pre-scaled by 1/sqrt(head_dim); presumably MaxText does not apply
        # that scaling inside attention itself — TODO confirm against layers/attentions.py.
        f"{_mt}['self_attention']['query']['kernel'].value": (
            f"{_hf}.self_attn.q_proj.weight",
            lambda x: x.T.reshape(config_hf.hidden_size, config_hf.num_attention_heads, config_maxtext.head_dim) / np.sqrt(config_maxtext.head_dim),
        ),
        f"{_mt}['self_attention']['value']['kernel'].value": (
            f"{_hf}.self_attn.v_proj.weight",
            lambda x: x.T.reshape(config_hf.hidden_size, config_hf.num_key_value_heads, config_maxtext.head_dim),
        ),
    })

state_hf = model_hf.state_dict()


def map_fn(key_path, value):
    """Return the HF weight for one MaxText leaf, converted to that leaf's layout and dtype.

    Args:
      key_path: pytree path of the leaf, as passed by `jax.tree_util.tree_map_with_path`.
      value: the initialized MaxText array at that path; supplies the target shape and dtype.

    Returns:
      A jax array with `value`'s shape and dtype (re-boxed if `value` is LogicallyPartitioned).

    Raises:
      KeyError: if the MaxText path has no entry in `state_map`.
      AssertionError: if the transformed HF tensor's shape differs from `value.shape`.
    """
    key_path_str = jax.tree_util.keystr(key_path)
    if key_path_str not in state_map:
        raise KeyError(f"No Hugging Face mapping for MaxText parameter {key_path_str}")
    torch_key, transform_fn = state_map[key_path_str]
    if "<exp_idx>" in torch_key:
        # One HF tensor per expert; the transform stacks them into a single array.
        torch_tensors = [state_hf[torch_key.replace("<exp_idx>", str(i))] for i in range(config_hf.num_local_experts)]
    else:
        torch_tensors = state_hf[torch_key]

    torch_tensors = transform_fn(torch_tensors)

    assert value.shape == torch_tensors.shape, f"{key_path_str}, {value.shape}, {torch_tensors.shape}"
    # Go through float32 on CPU: torch cannot export bfloat16 (or device) tensors to numpy directly.
    new_value = jnp.array(torch_tensors.detach().cpu().to(torch.float32).numpy(), dtype=value.dtype)
    if isinstance(value, nn.LogicallyPartitioned):
        new_value = value.replace_boxed(new_value)
    return new_value


loaded_state_maxtext = jax.tree_util.tree_map_with_path(map_fn, state_maxtext)

In [None]:
# HF forward pass, inference only: eval() turns off training-mode behavior and no_grad()
# avoids building an autograd graph. The same positions MaxText receives are passed
# explicitly instead of relying on HF's default position ids.
model_hf.eval()
with torch.no_grad():
    logits_hf = model_hf(
        torch.from_numpy(input_np['inputs']),
        position_ids=torch.from_numpy(input_np['inputs_position']),
    ).logits
# Upcast: the HF model may run in bfloat16, and bfloat16 tensors cannot be turned into numpy arrays.
logits_hf = logits_hf.to(torch.float32)

logits_maxtext = model_maxtext.apply(
    loaded_state_maxtext,
    input_np['inputs'],
    input_np['inputs_position'],
    enable_dropout=False,
    )

In [None]:
# Currently passes with both `megablox=True` and `megablox=False, capacity_factor=-1`.
# The loose tolerance presumably accounts for the bfloat16 weights/activations in both models.
# Both sides are cast to float32 so the comparison never sees a bfloat16 array.
np.testing.assert_allclose(
    np.asarray(logits_maxtext, dtype=np.float32),
    logits_hf.to(torch.float32).numpy(),
    rtol=1e-1,
    atol=1e-1,
)