In [1]:
#! ensure moe-infinity is installed by 'pip install git+https://github.com/TorchMoE/MoE-Infinity'

In [1]:
import torch
import os
from transformers import AutoTokenizer, SwitchTransformersForConditionalGeneration, TextStreamer
from moe_infinity import MoE

import safetensors.torch as st
import json
import shutil

from tqdm.notebook import tqdm

Do not detect pre-installed ops, use JIT mode


In [4]:
# It's much faster to use local file than blobfuse
# we also need to convert ckpt format

def download_and_convert_ckpt(mixtral_blob_path, nlm_blob_path, local_path):
    os.makedirs(local_path, exist_ok=True)
    bar = tqdm(total=35)

    tensor_index = {
        "metadata": {
            "total_size": 0
        },
        "weight_map": {}
    }

    # input emb
    bar.set_description("input emb")
    ckpt_old = torch.load(os.path.join(nlm_blob_path, "layer_00-model_states.pt"), map_location='cpu')
    ckpt_new_name = "model_00.safetensors"
    emb_weight = ckpt_old["embed_tokens.weight"]
    ckpt_new = {
        "model.embed_tokens.weight": emb_weight
    }

    tensor_index["metadata"]["total_size"] += emb_weight.numel()
    tensor_index["weight_map"]["model.embed_tokens.weight"] = ckpt_new_name
    st.save_file(ckpt_new, os.path.join(local_path, ckpt_new_name))
    bar.update(1)

    # layer 1 to 32
    for i in range(0, 32):
        bar.set_description(f"layer {i+1}")
        ckpt_old = torch.load(os.path.join(nlm_blob_path, f"layer_{i+1:02d}-model_states.pt"), map_location='cpu')
        ckpt_new_name = f"model_{i+1:02d}.safetensors"
        ckpt_new = {}

        # Attn QKVO proj
        ckpt_new[f"model.layers.{i}.self_attn.q_proj.weight"] = ckpt_old["self_attn.q_proj.weight"]
        ckpt_new[f"model.layers.{i}.self_attn.k_proj.weight"] = ckpt_old["self_attn.k_proj.weight"]
        ckpt_new[f"model.layers.{i}.self_attn.v_proj.weight"] = ckpt_old["self_attn.v_proj.weight"]
        ckpt_new[f"model.layers.{i}.self_attn.o_proj.weight"] = ckpt_old["self_attn.o_proj.weight"]

        # MoE
        for j in range(8):
            ckpt_new[f"model.layers.{i}.block_sparse_moe.experts.{j}.w1.weight"] = ckpt_old[f"block_sparse_moe.experts.{j}.w1.weight"]
            ckpt_new[f"model.layers.{i}.block_sparse_moe.experts.{j}.w2.weight"] = ckpt_old[f"block_sparse_moe.experts.{j}.w2.weight"]
            ckpt_new[f"model.layers.{i}.block_sparse_moe.experts.{j}.w3.weight"] = ckpt_old[f"block_sparse_moe.experts.{j}.w3.weight"]
        ckpt_new[f"model.layers.{i}.block_sparse_moe.gate.weight"] = ckpt_old["block_sparse_moe.gate.weight"]

        # LN
        ckpt_new[f"model.layers.{i}.input_layernorm.weight"] = ckpt_old["input_layernorm.weight"]
        ckpt_new[f"model.layers.{i}.post_attention_layernorm.weight"] = ckpt_old["post_attention_layernorm.weight"]

        for k, v in ckpt_new.items():
            tensor_index["metadata"]["total_size"] += v.numel()
            tensor_index["weight_map"][k] = ckpt_new_name

        st.save_file(ckpt_new, os.path.join(local_path, ckpt_new_name))
        bar.update(1)

    # Final norm
    bar.set_description("final norm")
    ckpt_old = torch.load(os.path.join(nlm_blob_path, "layer_33-model_states.pt"), map_location='cpu')
    ckpt_new_name = "model_33.safetensors"
    emb_weight = ckpt_old["norm.weight"]
    ckpt_new = {
        "model.norm.weight": emb_weight
    }

    tensor_index["metadata"]["total_size"] += emb_weight.numel()
    tensor_index["weight_map"]["model.norm.weight"] = ckpt_new_name
    st.save_file(ckpt_new, os.path.join(local_path, ckpt_new_name))
    bar.update(1)

    # LM head
    bar.set_description("LM head")
    ckpt_old = torch.load(os.path.join(nlm_blob_path, "layer_34-model_states.pt"), map_location='cpu')
    ckpt_new_name = "model_34.safetensors"
    emb_weight = ckpt_old["lm_head.weight"]
    ckpt_new = {
        "lm_head.weight": emb_weight
    }

    tensor_index["metadata"]["total_size"] += emb_weight.numel()
    tensor_index["weight_map"]["lm_head.weight"] = ckpt_new_name
    st.save_file(ckpt_new, os.path.join(local_path, ckpt_new_name))
    bar.update(1)

    with open(os.path.join(local_path, "model.safetensors.index.json"), "w") as f:
        json.dump(tensor_index, f, indent=2)

    print(f"Maped {tensor_index['metadata']['total_size']} tensors")

    # Other config files
    config = json.load(open(os.path.join(mixtral_blob_path, "config.json")))
    config["vocab_size"] = 33982
    with open(os.path.join(local_path, "config.json"), "w") as f:
        json.dump(config, f, indent=2)

    for file in ["generation_config.json", "special_tokens_map.json", "tokenizer.json", "tokenizer.model", "tokenizer_config.json"]:
        shutil.copyfile(os.path.join(mixtral_blob_path, file), os.path.join(local_path, file))

    # show file list in local_path
    print("Files in local_path:")
    for root, dirs, files in os.walk(local_path):
        for file in files:
            print(os.path.relpath(os.path.join(root, file), local_path))
    print("Done")
    bar.close()

In [10]:
download_and_convert_ckpt(
    "/hai1/shufxi/Mixtral-8x7B-v0.1",
    "/nlm/shufxi/nlm/8x7b/stageA/global_step3999",
    "/tmp/nlm"
)

  0%|          | 0/35 [00:00<?, ?it/s]

Maped 46719029248 tensors
Files in local_path:
model_32.safetensors
model_20.safetensors
model.safetensors.index.json
model_34.safetensors
model_29.safetensors
model_00.safetensors
model_11.safetensors
model_06.safetensors
generation_config.json
model_13.safetensors
model_09.safetensors
model_14.safetensors
tokenizer.json
special_tokens_map.json
model_07.safetensors
model_10.safetensors
tokenizer_config.json
config.json
model_16.safetensors
model_03.safetensors
model_30.safetensors
model_22.safetensors
model_08.safetensors
model_23.safetensors
model_12.safetensors
model_24.safetensors
model_01.safetensors
model_21.safetensors
model_17.safetensors
model_19.safetensors
model_18.safetensors
model_04.safetensors
model_15.safetensors
tokenizer.model
model_33.safetensors
model_05.safetensors
model_28.safetensors
model_02.safetensors
model_31.safetensors
model_26.safetensors
model_27.safetensors
model_25.safetensors
Done


In [2]:
checkpoint = '/tmp/nlm'
config = {
    "offload_path": "/tmp/moe-infinity",
    "device_memory_ratio": 0.75,
}

model = MoE(checkpoint, config)

Using /home/shufxi/.cache/torch_extensions/py310_cu121 as PyTorch extensions root...
Emitting ninja build file /home/shufxi/.cache/torch_extensions/py310_cu121/prefetch/build.ninja...
Building extension module prefetch...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


ninja: no work to do.
Time to load prefetch op: 2.3624579906463623 seconds
SPDLOG_LEVEL : (null)
2024-04-26 09:10:43.151 INFO Create ArcherAioThread for thread: , 0
2024-04-26 09:10:43.151 INFO Loading index file from , /tmp/moe-infinity/archer_index
2024-04-26 09:10:43.152 INFO Index file size , 995
2024-04-26 09:10:43.152 INFO Device count , 1
2024-04-26 09:10:43.152 INFO Enabled peer access for all devices
Loading model from offload_path ...


Loading extension module prefetch...
Model create:   1%|          | 6/994 [00:00<00:20, 47.61it/s]MixtralBLockSparseTop2MLP is deprecated by MixtralBlockSparseTop2MLP and will be removed in v4.40.
Model create:  91%|█████████ | 905/994 [00:03<00:00, 251.00it/s]

MixtralConfig {
  "_name_or_path": "/tmp/nlm",
  "architectures": [
    "MixtralForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 32768,
  "model_type": "mixtral",
  "num_attention_heads": 32,
  "num_experts_per_tok": 2,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "num_local_experts": 8,
  "output_router_logits": false,
  "rms_norm_eps": 1e-05,
  "rope_theta": 1000000.0,
  "router_aux_loss_coef": 0.02,
  "sliding_window": null,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.39.3",
  "use_cache": true,
  "vocab_size": 33982
}



Model create:  94%|█████████▎| 930/994 [00:19<00:00, 251.00it/s]

In [4]:
from sfm.data.sci_data.NlmTokenizer import NlmTokenizer
tokenizer = NlmTokenizer.from_pretrained(checkpoint)

[2024-04-26 09:20:43,565] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)


The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'LlamaTokenizer'. 
The class this function is called from is 'NlmTokenizer'.


[[32m2024-04-26 09:20:48.061[0m][[36mINFO[0m]: Tokenizer has 33982 tokens


In [5]:
input_text = "An apple a day"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.cuda()

with torch.no_grad():
    outputs = model.generate(
            input_ids,
            max_new_tokens=20,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
        )
output_text = tokenizer.decode(outputs[0])

print(output_text)



<s>An apple a day keeps the doctor away.

We’ve all heard this saying, but is it true?


In [9]:
input_text = "<mol>C1=CC=CC=C1</mol> <mol>"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.cuda()

with torch.no_grad():
    outputs = model.generate(
            input_ids,
            max_new_tokens=20,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
        )
output_text = tokenizer.decode(outputs[0])

print(output_text)

<s> <mol> <m>C <m>1 <m>= <m>C <m>C <m>= <m>C <m>C <m>= <m>C <m>1 </mol> <mol> <m>C <m>C <m>C <m>C <m>C <m>C a great place to visit.










In [11]:
input_text = "<protein>MKQHKAMIVALIVICITAVVAALVTRKDLCEVHIRTGQTEVAVF</protein> <protein>"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.cuda()

with torch.no_grad():
    outputs = model.generate(
            input_ids,
            max_new_tokens=20,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
        )
output_text = tokenizer.decode(outputs[0])

print(output_text)

<s> <protein> <a>M <a>K <a>Q <a>H <a>K <a>A <a>M <a>I <a>V <a>A <a>L <a>I <a>V <a>I <a>C <a>I <a>T <a>A <a>V <a>V <a>A <a>A <a>L <a>V <a>T <a>R <a>K <a>D <a>L <a>C <a>E <a>V <a>H <a>I <a>R <a>T <a>G <a>Q <a>T <a>E <a>V <a>A <a>V <a>F </protein> <protein> <a>M <a>L <a>L <a>L <a>L <a>L <a>L <a>V <a>V <a>I <a>L <a>A <a>L <a>A <a>L <a>A <a>L <a>A <a>L <a>A
