In [1]:
from accelerate import init_empty_weights, load_checkpoint_and_dispatch, load_checkpoint_in_model, dispatch_model
from sfm.data.sci_data.NlmTokenizer import NlmTokenizer
from transformers import AutoTokenizer, AutoModelForCausalLM

import torch


[2024-07-08 07:54:59,172] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [32]:
import json
import os
import shutil
from tqdm import tqdm


def download_and_convert_ckpt(mixtral_blob_path, nlm_blob_path, local_path):
        import safetensors.torch as st

        # skip if local path is not empty
        if os.path.exists(local_path) and len(os.listdir(local_path)) > 0:
            print(f"Local path {local_path} is not empty, skip downloading and converting ckpt")
            return

        os.makedirs(local_path, exist_ok=True)
        bar = tqdm(total=35)

        metadata = {"format": "pt"}
        tensor_index = {"metadata": {"total_size": 0}, "weight_map": {}}

        # input emb
        bar.set_description("input emb")
        ckpt_old = torch.load(
            os.path.join(nlm_blob_path, "layer_00-model_states.pt"), map_location="cpu"
        )
        ckpt_new_name = "model_00.safetensors"
        emb_weight = ckpt_old["embed_tokens.weight"]
        ckpt_new = {"model.embed_tokens.weight": emb_weight}

        tensor_index["metadata"]["total_size"] += emb_weight.numel()
        tensor_index["weight_map"]["model.embed_tokens.weight"] = ckpt_new_name
        st.save_file(ckpt_new, os.path.join(local_path, ckpt_new_name), metadata=metadata)
        bar.update(1)

        # layer 1 to 32
        for i in range(0, 32):
            bar.set_description(f"layer {i+1}")
            ckpt_old = torch.load(
                os.path.join(nlm_blob_path, f"layer_{i+1:02d}-model_states.pt"),
                map_location="cpu",
            )
            ckpt_new_name = f"model_{i+1:02d}.safetensors"
            ckpt_new = {}

            # Attn QKVO proj
            ckpt_new[f"model.layers.{i}.self_attn.q_proj.weight"] = ckpt_old[
                "self_attn.q_proj.weight"
            ]
            ckpt_new[f"model.layers.{i}.self_attn.k_proj.weight"] = ckpt_old[
                "self_attn.k_proj.weight"
            ]
            ckpt_new[f"model.layers.{i}.self_attn.v_proj.weight"] = ckpt_old[
                "self_attn.v_proj.weight"
            ]
            ckpt_new[f"model.layers.{i}.self_attn.o_proj.weight"] = ckpt_old[
                "self_attn.o_proj.weight"
            ]

            # MoE
            for j in range(8):
                ckpt_new[
                    f"model.layers.{i}.block_sparse_moe.experts.{j}.w1.weight"
                ] = ckpt_old[f"block_sparse_moe.experts.{j}.w1.weight"]
                ckpt_new[
                    f"model.layers.{i}.block_sparse_moe.experts.{j}.w2.weight"
                ] = ckpt_old[f"block_sparse_moe.experts.{j}.w2.weight"]
                ckpt_new[
                    f"model.layers.{i}.block_sparse_moe.experts.{j}.w3.weight"
                ] = ckpt_old[f"block_sparse_moe.experts.{j}.w3.weight"]
            ckpt_new[f"model.layers.{i}.block_sparse_moe.gate.weight"] = ckpt_old[
                "block_sparse_moe.gate.weight"
            ]

            # LN
            ckpt_new[f"model.layers.{i}.input_layernorm.weight"] = ckpt_old[
                "input_layernorm.weight"
            ]
            ckpt_new[f"model.layers.{i}.post_attention_layernorm.weight"] = ckpt_old[
                "post_attention_layernorm.weight"
            ]

            for k, v in ckpt_new.items():
                tensor_index["metadata"]["total_size"] += v.numel()
                tensor_index["weight_map"][k] = ckpt_new_name

            st.save_file(
                ckpt_new, os.path.join(local_path, ckpt_new_name), metadata=metadata
            )
            bar.update(1)

        # Final norm
        bar.set_description("final norm")
        ckpt_old = torch.load(
            os.path.join(nlm_blob_path, "layer_33-model_states.pt"), map_location="cpu"
        )
        ckpt_new_name = "model_33.safetensors"
        emb_weight = ckpt_old["norm.weight"]
        ckpt_new = {"model.norm.weight": emb_weight}

        tensor_index["metadata"]["total_size"] += emb_weight.numel()
        tensor_index["weight_map"]["model.norm.weight"] = ckpt_new_name
        st.save_file(ckpt_new, os.path.join(local_path, ckpt_new_name), metadata=metadata)
        bar.update(1)

        # LM head
        bar.set_description("LM head")
        ckpt_old = torch.load(
            os.path.join(nlm_blob_path, "layer_34-model_states.pt"), map_location="cpu"
        )
        ckpt_new_name = "model_34.safetensors"
        emb_weight = ckpt_old["lm_head.weight"]
        ckpt_new = {"lm_head.weight": emb_weight}

        tensor_index["metadata"]["total_size"] += emb_weight.numel()
        tensor_index["weight_map"]["lm_head.weight"] = ckpt_new_name
        st.save_file(ckpt_new, os.path.join(local_path, ckpt_new_name), metadata=metadata)
        bar.update(1)

        with open(os.path.join(local_path, "model.safetensors.index.json"), "w") as f:
            json.dump(tensor_index, f, indent=2)

        print(f"Maped {tensor_index['metadata']['total_size']} tensors")

        # Other config files
        config = json.load(open(os.path.join(mixtral_blob_path, "config.json")))
        config["vocab_size"] = 33982
        with open(os.path.join(local_path, "config.json"), "w") as f:
            json.dump(config, f, indent=2)

        for file in [
            "generation_config.json",
            "special_tokens_map.json",
            "tokenizer.json",
            "tokenizer.model",
            "tokenizer_config.json",
        ]:
            shutil.copyfile(
                os.path.join(mixtral_blob_path, file), os.path.join(local_path, file)
            )

        # show file list in local_path
        print("Files in local_path:")
        for root, dirs, files in os.walk(local_path):
            for file in files:
                print(os.path.relpath(os.path.join(root, file), local_path))
        print("Done")
        bar.close()

download_and_convert_ckpt(
    '/nlm/Mixtral-8x7B-v0.1/',
    "/nlm/shufxi/nlm/8x7b/inst/uspto50k/global_step312/",
    '/dev/shm/nlm'
)

LM head: 100%|██████████| 35/35 [30:31<00:00, 25.81s/it]   

Maped 46719029248 tensors


LM head: 100%|██████████| 35/35 [30:31<00:00, 52.34s/it]

Files in local_path:
tokenizer_config.json
tokenizer.model
tokenizer.json
special_tokens_map.json
generation_config.json
config.json
model.safetensors.index.json
model_34.safetensors
model_33.safetensors
model_32.safetensors
model_31.safetensors
model_30.safetensors
model_29.safetensors
model_28.safetensors
model_27.safetensors
model_26.safetensors
model_25.safetensors
model_24.safetensors
model_23.safetensors
model_22.safetensors
model_21.safetensors
model_20.safetensors
model_19.safetensors
model_18.safetensors
model_17.safetensors
model_16.safetensors
model_15.safetensors
model_14.safetensors
model_13.safetensors
model_12.safetensors
model_11.safetensors
model_10.safetensors
model_09.safetensors
model_08.safetensors
model_07.safetensors
model_06.safetensors
model_05.safetensors
model_04.safetensors
model_03.safetensors
model_02.safetensors
model_01.safetensors
model_00.safetensors
Done





In [2]:
def create_device_map():
    rank_start = 0

    n_layers = 32
    layer_per_rank = n_layers // 4
    device_map = {}
    device_map["model.embed_tokens.weight"] = rank_start
    for i in range(n_layers):
        device_idx = rank_start + i // layer_per_rank
        device_map[f"model.layers.{i}"] = device_idx

    device_map["model.norm.weight"] = rank_start + (n_layers-1) // layer_per_rank
    device_map["lm_head.weight"] = rank_start + (n_layers-1) // layer_per_rank

    return device_map

In [20]:
tokenizer = NlmTokenizer.from_pretrained('/nlm/Mixtral-8x7B-v0.1/')

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'LlamaTokenizer'. 
The class this function is called from is 'NlmTokenizer'.


In [33]:
local_path = '/dev/shm/nlm'
with init_empty_weights():
    model = AutoModelForCausalLM.from_pretrained(
        local_path,
        torch_dtype=torch.bfloat16,
    )
model = load_checkpoint_and_dispatch(
    model,
    local_path,
    device_map=create_device_map(),
    no_split_module_classes=['MixtralDecoderLayer'],
    dtype=torch.bfloat16,
    offload_folder=None,
    offload_state_dict=True
)

model._hf_hook.skip_keys = ['past_key_values']
model.eval()


Loading checkpoint shards:   0%|          | 0/35 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?w/s]

  0%|          | 0/31 [00:00<?, ?w/s]

  0%|          | 0/31 [00:00<?, ?w/s]

  0%|          | 0/31 [00:00<?, ?w/s]

  0%|          | 0/31 [00:00<?, ?w/s]

  0%|          | 0/31 [00:00<?, ?w/s]

  0%|          | 0/31 [00:00<?, ?w/s]

  0%|          | 0/31 [00:00<?, ?w/s]

  0%|          | 0/31 [00:00<?, ?w/s]

  0%|          | 0/31 [00:00<?, ?w/s]

  0%|          | 0/31 [00:00<?, ?w/s]

  0%|          | 0/31 [00:00<?, ?w/s]

  0%|          | 0/31 [00:00<?, ?w/s]

  0%|          | 0/31 [00:00<?, ?w/s]

  0%|          | 0/31 [00:00<?, ?w/s]

  0%|          | 0/31 [00:00<?, ?w/s]

  0%|          | 0/31 [00:00<?, ?w/s]

  0%|          | 0/31 [00:00<?, ?w/s]

  0%|          | 0/31 [00:00<?, ?w/s]

  0%|          | 0/31 [00:00<?, ?w/s]

  0%|          | 0/31 [00:00<?, ?w/s]

  0%|          | 0/31 [00:00<?, ?w/s]

  0%|          | 0/31 [00:00<?, ?w/s]

  0%|          | 0/31 [00:00<?, ?w/s]

  0%|          | 0/31 [00:00<?, ?w/s]

  0%|          | 0/31 [00:00<?, ?w/s]

  0%|          | 0/31 [00:00<?, ?w/s]

  0%|          | 0/31 [00:00<?, ?w/s]

  0%|          | 0/31 [00:00<?, ?w/s]

  0%|          | 0/31 [00:00<?, ?w/s]

  0%|          | 0/31 [00:00<?, ?w/s]

  0%|          | 0/31 [00:00<?, ?w/s]

  0%|          | 0/31 [00:00<?, ?w/s]

  0%|          | 0/1 [00:00<?, ?w/s]

  0%|          | 0/1 [00:00<?, ?w/s]

MixtralForCausalLM(
  (model): MixtralModel(
    (embed_tokens): Embedding(33982, 4096)
    (layers): ModuleList(
      (0-31): 32 x MixtralDecoderLayer(
        (self_attn): MixtralSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MixtralRotaryEmbedding()
        )
        (block_sparse_moe): MixtralSparseMoeBlock(
          (gate): Linear(in_features=4096, out_features=8, bias=False)
          (experts): ModuleList(
            (0-7): 8 x MixtralBlockSparseTop2MLP(
              (w1): Linear(in_features=4096, out_features=14336, bias=False)
              (w2): Linear(in_features=14336, out_features=4096, bias=False)
              (w3): Linear(in_features=4096, out_features=14336, bias=False)
    

In [5]:
model

MixtralForCausalLM(
  (model): MixtralModel(
    (embed_tokens): Embedding(33982, 4096)
    (layers): ModuleList(
      (0-31): 32 x MixtralDecoderLayer(
        (self_attn): MixtralFlashAttention2(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MixtralRotaryEmbedding()
        )
        (block_sparse_moe): MixtralSparseMoeBlock(
          (gate): Linear(in_features=4096, out_features=8, bias=False)
          (experts): ModuleList(
            (0-7): 8 x MixtralBlockSparseTop2MLP(
              (w1): Linear(in_features=4096, out_features=14336, bias=False)
              (w2): Linear(in_features=14336, out_features=4096, bias=False)
              (w3): Linear(in_features=4096, out_features=14336, bias=False)
  

In [34]:
def compute_ppl(seq):
    input_ids = tokenizer(seq, return_tensors="pt").input_ids.cuda()
    labels = input_ids.clone()

    with torch.no_grad():
        outputs = model(input_ids, labels=labels, return_dict=True) # shift inside
    loss = outputs.loss
    return loss.item(), torch.exp(loss).item()


def make_seq(p, r, use_tpl=True):
    if use_tpl:
        return f"Instruction: {p}\n\n\nResponse: {r}"
    return f"{p}\n\n\n{r}"

compute_ppl(make_seq(
    "Provided the product below, propose some possible reactants that could have been used in the reaction. <product>CC(C)(C)OC(=O)n1c2ccc(C(C)=O)cc2cc1</product>",
    "<reactants>CC(C)(C)OC(=O)n1c2ccccc2cc1.C(C)(=O)Cl</reactants>",
    #"<reactants>CC(C)(C)OC(=O)OC(=O)OC(C)(C)C.[nH]1c2ccc(C(C)=O)cc2cc1</reactants>"
    )
)

(2.076591968536377, 7.977236270904541)

In [16]:
compute_ppl(make_seq(
    "Please suggest possible reactants for the given product. <product>Nc1cccc(OC)c1C#N</product>",
    "<reactants>[N+](c1cccc(OC)c1C#N)([O-])=O</reactants>",
    True
    )
)

(4.902944564819336, 134.685791015625)

In [17]:
compute_ppl(make_seq(
    "Please suggest possible reactants for the given product. <product>Nc1cccc(OC)c1C#N</product>",
    "<reactants>[N+](c1cccc(OC)c1C#N)([O-])=O</reactants>",
    False
    )
)

(1.878980040550232, 6.546823978424072)

In [18]:
input_ids = tokenizer("Please suggest possible reactants for the given product. <product>Nc1cccc(OC)c1C#N</product>", return_tensors="pt").input_ids.cuda()
pred = model.generate(
    input_ids,
    num_beams=4,
    max_new_tokens=300,
)

output = tokenizer.decode(pred[0], skip_special_tokens=True)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


In [19]:
output

'Please suggest possible reactants for the given product. <product> <m>N <m>c <m>1 <m>c <m>c <m>c <m>c <m>( <m>O <m>C <m>) <m>c <m>1 <m>C <m># <m>N </product> <reactants> <m>N <m>c <m>1 <m>c <m>c <m>c <m>c <m>( <m>O <m>) <m>c <m>1 <m>C <m># <m>N <m>. <m>C <m>I </reactants>'

In [22]:
prompt = "Instruction: Please suggest possible reactants for the given product. <product>Nc1cccc(OC)c1C#N</product>\n\n\nResponse:"
target = "<reactants>[N+](c1cccc(OC)c1C#N)([O-])=O</reactants>"

In [23]:
tokenizer.tokenize(prompt), tokenizer.tokenize(target)

(['▁Inst',
  'ruction',
  ':',
  '▁Please',
  '▁suggest',
  '▁possible',
  '▁react',
  'ants',
  '▁for',
  '▁the',
  '▁given',
  '▁product',
  '.',
  '<product>',
  '<m>N',
  '<m>c',
  '<m>1',
  '<m>c',
  '<m>c',
  '<m>c',
  '<m>c',
  '<m>(',
  '<m>O',
  '<m>C',
  '<m>)',
  '<m>c',
  '<m>1',
  '<m>C',
  '<m>#',
  '<m>N',
  '</product>',
  '▁Response',
  ':'],
 ['<reactants>',
  '<m>[N+]',
  '<m>(',
  '<m>c',
  '<m>1',
  '<m>c',
  '<m>c',
  '<m>c',
  '<m>c',
  '<m>(',
  '<m>O',
  '<m>C',
  '<m>)',
  '<m>c',
  '<m>1',
  '<m>C',
  '<m>#',
  '<m>N',
  '<m>)',
  '<m>(',
  '<m>[O-]',
  '<m>)',
  '<m>=',
  '<m>O',
  '</reactants>'])

In [24]:
prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)
target_tokens = tokenizer.encode(target, add_special_tokens=False)

tokens = (
                [tokenizer.bos_token_id]
                + prompt_tokens
                + target_tokens
                + [tokenizer.eos_token_id]
            )
labels = tokens[:]
labels[: len(prompt_tokens) + 1] = [-100] * (len(prompt_tokens) + 1)


In [25]:
tokens

[1,
 3133,
 3112,
 28747,
 5919,
 3397,
 2572,
 13035,
 1549,
 354,
 272,
 2078,
 2093,
 28723,
 32017,
 32135,
 32127,
 32131,
 32127,
 32127,
 32127,
 32127,
 32129,
 32132,
 32128,
 32130,
 32127,
 32131,
 32128,
 32152,
 32135,
 32018,
 12107,
 28747,
 32019,
 32153,
 32129,
 32127,
 32131,
 32127,
 32127,
 32127,
 32127,
 32129,
 32132,
 32128,
 32130,
 32127,
 32131,
 32128,
 32152,
 32135,
 32130,
 32129,
 32154,
 32130,
 32133,
 32132,
 32020,
 2]

In [26]:
labels

[-100,
 -100,
 -100,
 -100,
 -100,
 -100,
 -100,
 -100,
 -100,
 -100,
 -100,
 -100,
 -100,
 -100,
 -100,
 -100,
 -100,
 -100,
 -100,
 -100,
 -100,
 -100,
 -100,
 -100,
 -100,
 -100,
 -100,
 -100,
 -100,
 -100,
 -100,
 -100,
 -100,
 -100,
 32019,
 32153,
 32129,
 32127,
 32131,
 32127,
 32127,
 32127,
 32127,
 32129,
 32132,
 32128,
 32130,
 32127,
 32131,
 32128,
 32152,
 32135,
 32130,
 32129,
 32154,
 32130,
 32133,
 32132,
 32020,
 2]

In [27]:
tokenizer(prompt, return_tensors="pt").input_ids

tensor([[    1,  3133,  3112, 28747,  5919,  3397,  2572, 13035,  1549,   354,
           272,  2078,  2093, 28723, 32017, 32135, 32127, 32131, 32127, 32127,
         32127, 32127, 32129, 32132, 32128, 32130, 32127, 32131, 32128, 32152,
         32135, 32018, 12107, 28747]])

In [29]:
with torch.no_grad():
    ret = model(torch.tensor([tokens]), labels=torch.tensor([labels]), return_dict=True) # shift inside

In [31]:
ret.loss

tensor(0.0589)

In [35]:
with torch.no_grad():
    ret = model(torch.tensor([tokens]), labels=torch.tensor([labels]), return_dict=True) # shift inside
print(ret.loss)

tensor(0.0231)


: 