### Convert Megatron checkpoint to Huggingface model
1. Convert the Megatron checkpoint to Huggingface Model.
2. Compare the model weights of the original Hugging Face model with those of the converted one.

In [1]:
import os
import sys
import torch
import argparse

# Root of the Megatron-LM checkout. Set the MEGATRON_ROOT environment variable
# to point at your own checkout; when it is unset (or empty) the path this
# notebook was originally run with is used.
MEGATRON_ROOT = os.environ.get("MEGATRON_ROOT") or (
    "/cpfs/29ccba8f16c61395/data/user/liushan/projects/Megatron-LM-master/"
)
# Put the checkout first on the import path so its modules take precedence.
sys.path.insert(0, MEGATRON_ROOT)

# Add tools/unicorn to the import path, then import unicorn.
sys.path.append(os.path.join(MEGATRON_ROOT, "tools", "unicorn"))
import unicorn

  from .autonotebook import tqdm as notebook_tqdm


[2023-10-18 06:02:24,765] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)


#### Convert Megatron checkpoint to Huggingface model
- You can also run the shell script `tools/unicorn/examples/llama/convert_megatron_to_hf.sh`.

In [2]:
# The input is the Megatron checkpoint produced by `tests/unicorn/convert_hf_to_megatron_and_resume.ipynb`.

def parse_args(argv=None):
    """Build the conversion argument parser and parse the given arguments.

    The parser is passed in turn through unicorn's checkpointing,
    transformers-checkpoint and Megatron-checkpoint ``add_*_args`` helpers.

    Args:
        argv: Optional list of argument strings (without the program name).
            When ``None`` (the default), argparse falls back to
            ``sys.argv[1:]``, which is what the original zero-argument call did.

    Returns:
        The ``argparse.Namespace`` holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser = unicorn.add_checkpointing_args(parser)
    parser = unicorn.add_transformers_checkpoint_args(parser)
    parser = unicorn.add_megatron_checkpoint_args(parser)
    # Passing argv explicitly lets callers avoid mutating the global sys.argv.
    return parser.parse_args(argv)

# Directory holding the Megatron checkpoint to convert.
megatron_ckpt = os.path.join(MEGATRON_ROOT, "models", "llama-megatron")

# parse_args() reads the process command line, so emulate one here.
sys.argv = [
    "script.py",
    "--convert_checkpoint_from_megatron_to_transformers",
    "--megatron-path", MEGATRON_ROOT,
    "--load-path", os.path.join(megatron_ckpt, "release"),
    "--save-path", os.path.join(MEGATRON_ROOT, "models", "llama-hf"),
    "--model-name", "llama2-13b",
    "--template-name", "llama",
    "--print-checkpoint-structure",
    "--target_params_dtype", "fp16",
]

args = parse_args()

In [3]:
# Run the Megatron -> transformers checkpoint conversion with the options parsed into `args`.
unicorn.convert_checkpoint_from_megatron_to_transformers(args)

=> Loading Megatron-LM checkpoint from: /cpfs/29ccba8f16c61395/data/user/liushan/projects/Megatron-LM-master/models/llama-megatron/release/mp_rank_00_000/model_optim_rng.pt
=> vocab_size: 32000
=> Saving <class 'transformers.models.llama.configuration_llama.LlamaConfig'> to /cpfs/29ccba8f16c61395/data/user/liushan/projects/Megatron-LM-master/models/llama-hf ...
=> converting ...
=> converting word embeddings ...
=> Converting transformer layers ...
=> converting pipeline parallel rank 0 ...
	=> processing layers.0.input_norm.weight ...
	=> processing layers.0.self_attention.query_key_value.weight ...
	=> processing layers.0.self_attention.dense.weight ...
	=> processing layers.0.post_attention_norm.weight ...
	=> processing layers.0.mlp.dense_h_to_4h.weight ...
	=> processing layers.0.mlp.dense_4h_to_h.weight ...
	=> processing layers.1.input_norm.weight ...
	=> processing layers.1.self_attention.query_key_value.weight ...
	=> processing layers.1.self_attention.dense.weight ...
	=> pro

In [4]:
# Same directory that was passed as `--save-path` to the conversion.
target_path = os.path.join(MEGATRON_ROOT, "models", "llama-hf")
# IPython shell escape: list the contents of that directory.
!ls {target_path}

config.json			  pytorch_model-00003-of-00003.bin
pytorch_model-00001-of-00003.bin  pytorch_model.bin.index.json
pytorch_model-00002-of-00003.bin


#### Compare model weights.
* You can check the shell example `tools/unicorn/examples/llama/compare_hf_model.sh`.
* The results list the keys whose contents differ, i.e. whose values differ in dtype, shape, or element count.
    * Differences in **\*.rotary_emb.inv_freq** keys can be ignored; pay attention to any other keys.
    * An empty result is okay.

In [10]:
# Add tests/unicorn to the import path so `compare_hf_model` can be imported.
sys.path.append(os.path.join(MEGATRON_ROOT, "tests", "unicorn"))
from compare_hf_model import compare_state_dicts, load_hf_model

# SRC is loaded from models/Llama-2-13b-hf; DST is loaded from models/llama-hf,
# the directory used as `--save-path` for the conversion.
# NOTE(review): presumably load_hf_model returns the model's state dict — confirm in compare_hf_model.
src_state_dict = load_hf_model(os.path.join(MEGATRON_ROOT, "models", "Llama-2-13b-hf"))
dst_state_dict = load_hf_model(os.path.join(MEGATRON_ROOT, "models", "llama-hf"))
# Bind the return value to `_` so it is not displayed as the cell's result.
_ = compare_state_dicts(src_state_dict, dst_state_dict)

=> Comparing SRC[/cpfs/29ccba8f16c61395/data/user/liushan/projects/Megatron-LM-master/models/Llama-2-13b-hf] to DST[/cpfs/29ccba8f16c61395/data/user/liushan/projects/Megatron-LM-master/models/llama-hf] ...
=> Only these keys in src_model: [/cpfs/29ccba8f16c61395/data/user/liushan/projects/Megatron-LM-master/models/Llama-2-13b-hf]: == {'model.layers.9.self_attn.rotary_emb.inv_freq', 'model.layers.33.self_attn.rotary_emb.inv_freq', 'model.layers.24.self_attn.rotary_emb.inv_freq', 'model.layers.36.self_attn.rotary_emb.inv_freq', 'model.layers.38.self_attn.rotary_emb.inv_freq', 'model.layers.6.self_attn.rotary_emb.inv_freq', 'model.layers.3.self_attn.rotary_emb.inv_freq', 'model.layers.1.self_attn.rotary_emb.inv_freq', 'model.layers.26.self_attn.rotary_emb.inv_freq', 'model.layers.14.self_attn.rotary_emb.inv_freq', 'model.layers.5.self_attn.rotary_emb.inv_freq', 'model.layers.22.self_attn.rotary_emb.inv_freq', 'model.layers.29.self_attn.rotary_emb.inv_freq', 'model.layers.2.self_attn.rotar