In [1]:
import torch
from pathlib import Path
import numpy as np
import sys
# Local directory holding the MPT-7B-chat Hugging Face checkpoint that is loaded
# and converted below.
mpath = Path.home().joinpath('model', 'mpt-7b-chat')

In [2]:
import transformers

# The checkpoint directory carries its own MPT modeling code (see the
# "Instantiating an MPTForCausalLM model from .../modeling_mpt.py" log), so
# trust_remote_code is needed for AutoModelForCausalLM to resolve the class.
# low_cpu_mem_usage avoids materializing a second full copy of the weights
# while loading the sharded checkpoint.
model = transformers.AutoModelForCausalLM.from_pretrained(
    mpath,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)

  from .autonotebook import tqdm as notebook_tqdm


Instantiating an MPTForCausalLM model from /Users/aaron/.cache/huggingface/modules/transformers_modules/mpt-7b-chat/modeling_mpt.py
You are using config.init_device='cpu', but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.


Loading checkpoint shards: 100%|██████████| 2/2 [00:28<00:00, 14.42s/it]


In [3]:
# Show the module tree to eyeball the layout the converter relies on: a fused
# Wqkv projection (out_features = 3 * d_model), bias-free Linear layers, and
# LPLayerNorm norms.
model

MPTForCausalLM(
  (transformer): MPTModel(
    (wte): SharedEmbedding(50432, 4096)
    (emb_drop): Dropout(p=0, inplace=False)
    (blocks): ModuleList(
      (0-31): 32 x MPTBlock(
        (norm_1): LPLayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (attn): MultiheadAttention(
          (Wqkv): Linear(in_features=4096, out_features=12288, bias=False)
          (out_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (norm_2): LPLayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (ffn): MPTMLP(
          (up_proj): Linear(in_features=4096, out_features=16384, bias=False)
          (act): GELU(approximate='none')
          (down_proj): Linear(in_features=16384, out_features=4096, bias=False)
        )
        (resid_attn_dropout): Dropout(p=0, inplace=False)
        (resid_ffn_dropout): Dropout(p=0, inplace=False)
      )
    )
    (norm_f): LPLayerNorm((4096,), eps=1e-05, elementwise_affine=True)
  )
)

In [5]:
# Output precision selector: 0 -> every tensor stored as f32,
# 1 -> 2-D weight matrices stored as f16 (other tensors stay f32).
ftype_str = ["f32", "f16"]
ftype = 1
precision_tag = ftype_str[ftype]
fname_out = Path.home().joinpath('model', f"ggml-{mpath.name}-vocabless-{precision_tag}.bin")
print(fname_out)

/Users/aaron/model/ggml-mpt-7b-chat-vocabless-f16.bin


In [6]:
# Dump the full config: the architecture checks and the ggml header fields
# (vocab_size, max_seq_len, n_layers, n_heads, d_model, alibi_bias_max,
# clip_qkv) written in the following cells are all read from it.
print(model.config)

MPTConfig {
  "_name_or_path": "/Users/aaron/model/mpt-7b-chat",
  "architectures": [
    "MPTForCausalLM"
  ],
  "attn_config": {
    "alibi": true,
    "alibi_bias_max": 8,
    "attn_impl": "torch",
    "attn_pdrop": 0,
    "attn_type": "multihead_attention",
    "attn_uses_sequence_id": false,
    "clip_qkv": null,
    "prefix_lm": false,
    "qk_ln": false,
    "softmax_scale": null
  },
  "auto_map": {
    "AutoConfig": "configuration_mpt.MPTConfig",
    "AutoModelForCausalLM": "modeling_mpt.MPTForCausalLM"
  },
  "d_model": 4096,
  "emb_pdrop": 0,
  "embedding_fraction": 1.0,
  "expansion_ratio": 4,
  "init_config": {
    "emb_init_std": null,
    "emb_init_uniform_lim": null,
    "fan_mode": "fan_in",
    "init_div_is_residual": true,
    "init_gain": 0,
    "init_nonlinearity": "relu",
    "init_std": 0.02,
    "name": "kaiming_normal_",
    "verbose": 0
  },
  "init_device": "cpu",
  "learned_pos_emb": true,
  "logit_scale": null,
  "max_seq_len": 2048,
  "model_type": "mpt",


In [7]:
import struct

# Guard the architecture assumptions behind the ggml writer that follows.
# Its file header records only sizes plus alibi_bias_max / clip_qkv for
# attention. Variants it cannot describe must therefore be rejected before
# any file is written:
#   - bias terms
#   - prefix-LM attention
#   - a different norm type
#   - QK LayerNorm
#   - an FFN width other than 4x
# Uses `assert cond, msg` (not `assert(cond)`) so each failure names the
# offending setting.
attn_cfg = model.config.attn_config
assert attn_cfg['alibi'], "only ALiBi position biasing is supported (attn_config.alibi must be true)"
assert model.config.no_bias, "linear layers with bias terms are not supported (no_bias must be true)"
assert not attn_cfg['prefix_lm'], "prefix-LM attention is not supported (attn_config.prefix_lm must be false)"
assert model.config.norm_type == "low_precision_layernorm", (
    f"unsupported norm_type {model.config.norm_type!r}; expected 'low_precision_layernorm'"
)
assert not attn_cfg['qk_ln'], "QK LayerNorm is not supported (attn_config.qk_ln must be false)"
assert model.config.expansion_ratio == 4, (
    f"unsupported expansion_ratio {model.config.expansion_ratio!r}; expected 4"
)

In [8]:
# Serialize the model to the ggml "vocabless" v1 file at fname_out.
#
# Everything is written with struct's default native byte order/size and
# ndarray.tofile (also native order); presumably the ggml loader expects
# little-endian -- confirm if converting on an unusual host.
#
# File header:
#   1. magic u32 (0x67676d64)
#   2. format u32 (0 = v1, no vocab)
#   3. vocab_size, max_seq_len, n_layers, n_heads, d_model (u32 each)
#   4. alibi_bias_max (f32)
#   5. clip_qkv (f32, 0 when disabled)
#   6. ftype (u32)
#
# Per state_dict entry:
#   1. n_dims, name byte length, tensor ftype (i32 each)
#   2. shape in reverse order (i32 each)
#   3. utf-8 name
#   4. raw tensor data
#
# With ftype == 1, 2-D ".weight" matrices other than the token embedding
# (wte) are stored as f16; everything else is f32. With ftype == 0 everything
# is f32.
list_vars = model.state_dict()
# `with` guarantees the file is closed even if a tensor fails to convert.
with open(fname_out, "wb") as fout:
    fout.write(struct.pack("I", 0x67676d64))  # magic: ggmd in hex
    fout.write(struct.pack("I", 0))  # v1_no_vocab
    fout.write(struct.pack("I", model.config.vocab_size))
    fout.write(struct.pack("I", model.config.max_seq_len))
    fout.write(struct.pack("I", model.config.n_layers))
    fout.write(struct.pack("I", model.config.n_heads))
    fout.write(struct.pack("I", model.config.d_model))
    fout.write(struct.pack("f", model.config.attn_config['alibi_bias_max']))
    clip_qkv = model.config.attn_config['clip_qkv']
    fout.write(struct.pack("f", clip_qkv if clip_qkv is not None else 0))
    fout.write(struct.pack("I", ftype))

    for name, tensor in list_vars.items():
        tensor = tensor.detach().cpu()
        # numpy has no bfloat16 dtype, so .numpy() would raise; the upcast to
        # float32 is lossless.
        if tensor.dtype == torch.bfloat16:
            tensor = tensor.float()
        data = tensor.squeeze().numpy()
        print("Processing variable: " + name + " with shape: ", data.shape)

        n_dims = data.ndim

        # ftype == 0 -> float32, ftype == 1 -> float16
        ftype_cur = 0
        if ftype != 0:
            # Keep token embeddings (and 1-D tensors such as norm weights) in fp32
            if name.endswith(".weight") and n_dims == 2 and ".wte" not in name:
                print("  Converting to float16")
                data = data.astype(np.float16)
                ftype_cur = 1
            else:
                print("  Converting to float32")
                data = data.astype(np.float32)
        elif data.dtype != np.float32:
            print("  Converting to float32")
            data = data.astype(np.float32)

        # Tensor header.
        # `name_bytes` (not `str`) avoids shadowing the builtin in the kernel.
        name_bytes = name.encode('utf-8')
        fout.write(struct.pack("iii", n_dims, len(name_bytes), ftype_cur))
        # Dimensions are written innermost-first (reverse of numpy order).
        for dim in reversed(data.shape):
            fout.write(struct.pack("i", dim))
        fout.write(name_bytes)

        # Tensor data.
        data.tofile(fout)
print(fname_out)

Processing variable: transformer.wte.weight with shape:  (50432, 4096)
  Converting to float32
Processing variable: transformer.blocks.0.norm_1.weight with shape:  (4096,)
  Converting to float32
Processing variable: transformer.blocks.0.attn.Wqkv.weight with shape:  (12288, 4096)
  Converting to float16
Processing variable: transformer.blocks.0.attn.out_proj.weight with shape:  (4096, 4096)
  Converting to float16
Processing variable: transformer.blocks.0.norm_2.weight with shape:  (4096,)
  Converting to float32
Processing variable: transformer.blocks.0.ffn.up_proj.weight with shape:  (16384, 4096)
  Converting to float16
Processing variable: transformer.blocks.0.ffn.down_proj.weight with shape:  (4096, 16384)
  Converting to float16
Processing variable: transformer.blocks.1.norm_1.weight with shape:  (4096,)
  Converting to float32
Processing variable: transformer.blocks.1.attn.Wqkv.weight with shape:  (12288, 4096)
  Converting to float16
Processing variable: transformer.blocks.1.a