In [1]:
# Convert a Mixtral MoE safetensors checkpoint into DeepSpeed pipeline-parallel (PP) per-layer checkpoint files

In [3]:
import torch
import json
import safetensors

In [4]:
# Checkpoint directory: the safetensors index and shard files are read from here,
# and the converted layer_XX-model_states.pt files are written back into it.
# NOTE(review): hard-coded absolute path — adjust for your environment.
ckpt_home = '/hai1/shufxi/Mixtral-8x7B-v0.1'

In [5]:
# Read the safetensors index; its "weight_map" entry is used to find the
# shard file that stores each tensor.
index_path = f'{ckpt_home}/model.safetensors.index.json'
with open(index_path) as index_file:
    st_index = json.load(index_file)

In [11]:
# Names of tensors that have been read successfully. The final cell compares
# this set against the index to report anything that was never converted.
# (The spelling is kept as-is because that cell refers to this exact name.)
procssed_keys = set()


def get_tensor_by_name(name):
    """Load a single tensor from the sharded safetensors checkpoint.

    The shard file is looked up through ``st_index["weight_map"]`` and opened
    with ``framework='pt'`` on the CPU.

    Args:
        name: Tensor key as listed in the index's weight map.

    Returns:
        The tensor returned by the shard's ``get_tensor(name)``.

    Raises:
        KeyError: If ``name`` is not present in the index's weight map.
    """
    st_file = f'{ckpt_home}/{st_index["weight_map"][name]}'
    with safetensors.safe_open(st_file, framework='pt', device='cpu') as f:
        tensor = f.get_tensor(name)
    # Record the key only after the read succeeded, so a failed lookup or load
    # is not silently counted as processed by the coverage check.
    procssed_keys.add(name)
    return tensor

In [12]:
# Input embedding -> pipeline layer file 00.
state_dict = {
    'embed_tokens.weight': get_tensor_by_name('model.embed_tokens.weight'),
}
torch.save(state_dict, f'{ckpt_home}/layer_00-model_states.pt')

In [14]:
from tqdm import tqdm

In [20]:
# Each transformer layer: source layer `idx` is written to layer file idx+1.

# Parameter names relative to 'model.layers.{idx}.', in the order they are
# stored in each per-layer state dict.
layer_param_names = [
    'self_attn.q_proj.weight',
    'self_attn.k_proj.weight',
    'self_attn.v_proj.weight',
    'self_attn.o_proj.weight',
]
for eid in range(8):
    layer_param_names += [
        f'block_sparse_moe.experts.{eid}.w1.weight',
        f'block_sparse_moe.experts.{eid}.w2.weight',
        f'block_sparse_moe.experts.{eid}.w3.weight',
    ]
layer_param_names += [
    'block_sparse_moe.gate.weight',
    'input_layernorm.weight',
    'post_attention_layernorm.weight',
]

for idx in tqdm(range(32)):
    # Same keys as in the source checkpoint, minus the per-layer prefix.
    state_dict = {
        param_name: get_tensor_by_name(f'model.layers.{idx}.{param_name}')
        for param_name in layer_param_names
    }
    torch.save(state_dict, f'{ckpt_home}/layer_{idx+1:02d}-model_states.pt')

100%|██████████| 32/32 [13:23<00:00, 25.10s/it]


In [16]:
# Final norm -> pipeline layer file 33.
state_dict = {
    'norm.weight': get_tensor_by_name('model.norm.weight'),
}
torch.save(state_dict, f'{ckpt_home}/layer_33-model_states.pt')

In [17]:
# LM head -> pipeline layer file 34.
state_dict = {
    'lm_head.weight': get_tensor_by_name('lm_head.weight'),
}
torch.save(state_dict, f'{ckpt_home}/layer_34-model_states.pt')

In [21]:
# Coverage check: warn about every tensor listed in the index that was never
# requested through get_tensor_by_name.
unprocessed = [key for key in st_index['weight_map'] if key not in procssed_keys]
for k in unprocessed:
    print(f'Warning: {k} not processed')