In [None]:
# Prerequisite: install moe-infinity with `pip install git+https://github.com/TorchMoE/MoE-Infinity`

In [1]:
import torch
import os
from transformers import AutoTokenizer, SwitchTransformersForConditionalGeneration, TextStreamer
from moe_infinity import MoE

import safetensors.torch as st
import json
import shutil

from tqdm.notebook import tqdm

Do not detect pre-installed ops, use JIT mode


In [None]:
# Reading from local disk is much faster than reading through blobfuse, so the
# checkpoint is copied locally; its format also has to be converted on the way.

def download_and_convert_ckpt(mixtral_blob_path, nlm_blob_path, local_path,
                              num_layers=32, num_experts=8, vocab_size=33982):
    """Convert a per-stage ``.pt`` checkpoint into a sharded safetensors checkpoint.

    Reads ``layer_XX-model_states.pt`` files from ``nlm_blob_path``, renames the
    tensors to ``model.*`` / ``lm_head.*`` keys and writes one
    ``model_XX.safetensors`` shard per input file into ``local_path``. It also
    writes ``model.safetensors.index.json``, a copy of ``config.json`` with
    ``vocab_size`` overridden, and copies the generation/tokenizer files from
    ``mixtral_blob_path``.

    Input layout read by this function (stage index -> content):
        0                -> ``embed_tokens.weight``
        1 .. num_layers  -> one decoder layer each
        num_layers + 1   -> ``norm.weight``
        num_layers + 2   -> ``lm_head.weight``
    NOTE(review): this layout is only confirmed for the default 32-layer case;
    verify it before passing a different ``num_layers``.

    Args:
        mixtral_blob_path: Directory holding ``config.json`` and the tokenizer /
            generation config files to reuse.
        nlm_blob_path: Directory holding the ``layer_XX-model_states.pt`` files.
        local_path: Output directory; created if it does not exist.
        num_layers: Number of decoder layers to convert (default 32).
        num_experts: Number of experts per MoE block (default 8).
        vocab_size: Value written to ``vocab_size`` in the output config
            (default 33982).

    Returns:
        None. Results are written to ``local_path`` and a summary is printed.

    Raises:
        KeyError: If an expected tensor name is missing from an input file.
        OSError: If an input file is missing or the output cannot be written.
    """
    os.makedirs(local_path, exist_ok=True)

    tensor_index = {
        "metadata": {
            "total_size": 0
        },
        "weight_map": {}
    }

    def _load_stage(stage):
        # Load one per-stage state dict onto the CPU.
        # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
        path = os.path.join(nlm_blob_path, f"layer_{stage:02d}-model_states.pt")
        return torch.load(path, map_location='cpu')

    def _save_shard(stage, tensors):
        # Write one safetensors shard and record its tensors in the index.
        shard_name = f"model_{stage:02d}.safetensors"
        for key, tensor in tensors.items():
            # The index's total_size is a byte count, not an element count.
            tensor_index["metadata"]["total_size"] += tensor.numel() * tensor.element_size()
            tensor_index["weight_map"][key] = shard_name
        st.save_file(tensors, os.path.join(local_path, shard_name))

    # Per-layer tensor names; each gets the prefix "model.layers.{i}." on output.
    layer_keys = [f"self_attn.{proj}_proj.weight" for proj in ("q", "k", "v", "o")]
    for j in range(num_experts):
        layer_keys += [f"block_sparse_moe.experts.{j}.{w}.weight" for w in ("w1", "w2", "w3")]
    layer_keys += [
        "block_sparse_moe.gate.weight",
        "input_layernorm.weight",
        "post_attention_layernorm.weight",
    ]

    # The context manager closes the bar even if a conversion step raises.
    with tqdm(total=num_layers + 3) as bar:
        # input emb
        bar.set_description("input emb")
        _save_shard(0, {"model.embed_tokens.weight": _load_stage(0)["embed_tokens.weight"]})
        bar.update(1)

        # decoder layers: input stage i+1 becomes output layer i
        for i in range(num_layers):
            bar.set_description(f"layer {i+1}")
            ckpt_old = _load_stage(i + 1)
            _save_shard(i + 1, {f"model.layers.{i}.{key}": ckpt_old[key] for key in layer_keys})
            bar.update(1)

        # Final norm
        bar.set_description("final norm")
        _save_shard(num_layers + 1, {"model.norm.weight": _load_stage(num_layers + 1)["norm.weight"]})
        bar.update(1)

        # LM head
        bar.set_description("LM head")
        _save_shard(num_layers + 2, {"lm_head.weight": _load_stage(num_layers + 2)["lm_head.weight"]})
        bar.update(1)

    with open(os.path.join(local_path, "model.safetensors.index.json"), "w") as f:
        json.dump(tensor_index, f, indent=2)

    print(f"Mapped {len(tensor_index['weight_map'])} tensors "
          f"({tensor_index['metadata']['total_size']} bytes)")

    # Other config files
    with open(os.path.join(mixtral_blob_path, "config.json")) as f:
        config = json.load(f)
    config["vocab_size"] = vocab_size
    with open(os.path.join(local_path, "config.json"), "w") as f:
        json.dump(config, f, indent=2)

    for file in ["generation_config.json", "special_tokens_map.json", "tokenizer.json", "tokenizer.model", "tokenizer_config.json"]:
        shutil.copyfile(os.path.join(mixtral_blob_path, file), os.path.join(local_path, file))

    # show file list in local_path
    print("Files in local_path:")
    for root, dirs, files in os.walk(local_path):
        for file in files:
            print(os.path.relpath(os.path.join(root, file), local_path))
    print("Done")

In [None]:
# Convert the checkpoint into /tmp/nlm; keyword arguments make the role of
# each path explicit.
download_and_convert_ckpt(
    mixtral_blob_path="/hai1/shufxi/Mixtral-8x7B-v0.1",
    nlm_blob_path="/nlm/shufxi/nlm/8x7b/stageB/global_step54999",
    local_path="/tmp/nlm",
)

In [2]:
# Directory holding the converted checkpoint.
checkpoint = '/tmp/nlm'

# Settings handed to MoE alongside the checkpoint path.
config = dict(
    offload_path="/tmp/moe-infinity",
    device_memory_ratio=0.75,
)

model = MoE(checkpoint, config)

Using /home/shufxi/.cache/torch_extensions/py311_cu121 as PyTorch extensions root...
Emitting ninja build file /home/shufxi/.cache/torch_extensions/py311_cu121/prefetch/build.ninja...
Building extension module prefetch...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


ninja: no work to do.
Time to load prefetch op: 2.3728654384613037 seconds
SPDLOG_LEVEL : (null)
2024-05-10 08:06:09.635 INFO Create ArcherAioThread for thread: , 0
2024-05-10 08:06:09.635 INFO Loading index file from , /tmp/moe-infinity/archer_index
2024-05-10 08:06:09.636 INFO Index file size , 995
2024-05-10 08:06:09.636 INFO Device count , 1
2024-05-10 08:06:09.636 INFO Enabled peer access for all devices
Loading model from offload_path ...


Loading extension module prefetch...
Model create:   0%|          | 0/994 [00:00<?, ?it/s]MixtralBLockSparseTop2MLP is deprecated by MixtralBlockSparseTop2MLP and will be removed in v4.40.
Model create:  91%|█████████ | 905/994 [00:00<00:00, 2330.79it/s]

MixtralConfig {
  "_name_or_path": "/tmp/nlm",
  "architectures": [
    "MixtralForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 32768,
  "model_type": "mixtral",
  "num_attention_heads": 32,
  "num_experts_per_tok": 2,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "num_local_experts": 8,
  "output_router_logits": false,
  "rms_norm_eps": 1e-05,
  "rope_theta": 1000000.0,
  "router_aux_loss_coef": 0.02,
  "router_jitter_noise": 0.0,
  "sliding_window": null,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.40.1",
  "use_cache": true,
  "vocab_size": 33982
}



In [3]:
import numpy as np

In [4]:
# Load the validation array and display its shape.
valid_path = '/nlm/shufxi/data/SFM.Mixtral.v0/valid.npy'
data = np.load(valid_path)
data.shape

(4204, 8192)

In [5]:
# Convert the array to an int64 tensor on the GPU.
# The guard keeps this cell re-runnable: once `data` is already a tensor it has
# no `.astype`, so the conversion is skipped and only `.cuda()` is applied.
if not torch.is_tensor(data):
    data = torch.from_numpy(data.astype('int64'))
data = data.cuda()

In [9]:
# Sum the model loss over every row of `data`, one row per forward pass,
# then print the mean loss per row.
num_sequences = data.shape[0]
loss_sum = 0
with torch.no_grad():  # evaluation only, so no gradients are needed
    for i in tqdm(range(num_sequences)):
        input_ids = data[i].unsqueeze(0)  # add a leading dimension of size 1
        labels = input_ids.clone()
        outputs = model(input_ids, labels=labels, return_dict=True)
        loss = outputs.loss
        loss_sum += loss.item()

print(loss_sum / num_sequences)

  0%|          | 0/4204 [00:00<?, ?it/s]

In [8]:
# NOTE(review): scratch calculation with hard-coded values. Presumably
# 0.00348... is a `loss_sum / data.shape[0]` printout from a run that covered
# only 10 rows, rescaled here to a mean over those 10 rows — confirm, or
# replace with a computed value.
0.0034876863634325456 * data.shape[0] / 10

1.4662233471870423

In [None]:
# NlmTokenizer comes from the `sfm` package; load it from the `checkpoint`
# directory defined in an earlier cell.
from sfm.data.sci_data.NlmTokenizer import NlmTokenizer
tokenizer = NlmTokenizer.from_pretrained(checkpoint)

In [None]:
# Greedy continuation of a plain-text prompt.
input_text = "An apple a day"
encoded = tokenizer(input_text, return_tensors="pt")
input_ids = encoded.input_ids.cuda()

generation_kwargs = dict(
    max_new_tokens=20,
    do_sample=False,
    pad_token_id=tokenizer.pad_token_id,
)
with torch.no_grad():
    outputs = model.generate(input_ids, **generation_kwargs)

# Decode the first returned sequence.
output_text = tokenizer.decode(outputs[0])
print(output_text)

In [None]:
# Greedy continuation of a prompt that ends with an open <mol> tag.
input_text = "<mol>C1=CC=CC=C1</mol> <mol>"
encoded = tokenizer(input_text, return_tensors="pt")
input_ids = encoded.input_ids.cuda()

generation_kwargs = dict(
    max_new_tokens=20,
    do_sample=False,
    pad_token_id=tokenizer.pad_token_id,
)
with torch.no_grad():
    outputs = model.generate(input_ids, **generation_kwargs)

# Decode the first returned sequence.
output_text = tokenizer.decode(outputs[0])
print(output_text)

In [None]:
# Greedy continuation of a prompt that ends with an open <protein> tag.
input_text = "<protein>MKQHKAMIVALIVICITAVVAALVTRKDLCEVHIRTGQTEVAVF</protein> <protein>"
encoded = tokenizer(input_text, return_tensors="pt")
input_ids = encoded.input_ids.cuda()

generation_kwargs = dict(
    max_new_tokens=20,
    do_sample=False,
    pad_token_id=tokenizer.pad_token_id,
)
with torch.no_grad():
    outputs = model.generate(input_ids, **generation_kwargs)

# Decode the first returned sequence.
output_text = tokenizer.decode(outputs[0])
print(output_text)