In [None]:
import os
import struct
import sys
from pathlib import Path

import numpy as np
import torch
# Checkpoint directory of the Hugging Face MPT model to convert.
# Set the MPT_MODEL_DIR environment variable to point elsewhere instead of
# editing the notebook; the default keeps the original location.
mpath = Path(os.environ.get("MPT_MODEL_DIR", "e:/big_model/mpt-7b-chat"))

In [None]:
import transformers

# Load the checkpoint from `mpath`.
# - trust_remote_code=True lets transformers execute the model code shipped in
#   the checkpoint directory (required for checkpoints with custom architectures).
# - low_cpu_mem_usage=True asks transformers to avoid materialising an extra
#   full copy of the weights while loading.
model = transformers.AutoModelForCausalLM.from_pretrained(
    mpath,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)

In [None]:
# Display the loaded model's repr (its module hierarchy) for a quick visual check before conversion.
model

In [None]:
# Output precision selector, written into the ggml header below.
# Index into ftype_str: 0 -> f32 (all tensors float32), 1 -> f16 (2-D weights float16).
ftype_str = ["f32", "f16"]
ftype = 0
# Guard against silently picking a name via negative indexing (e.g. ftype = -1 -> "f16").
assert 0 <= ftype < len(ftype_str), f"ftype must be 0 (f32) or 1 (f16), got {ftype}"

# Write the .bin next to the model directory, e.g. <parent>/ggml-mpt-7b-chat-f32.bin,
# so changing `mpath` is enough to relocate both input and output.
fname_out = mpath.parent / f"ggml-{mpath.name}-{ftype_str[ftype]}.bin"
fname_out

In [None]:
# Show the model configuration; the asserts and header fields below read from it.
# A bare last expression uses the notebook's display machinery instead of stdout.
model.config

In [None]:
# Fail fast if the checkpoint's architecture differs from the variant this
# conversion assumes (the writer below only emits the fields listed here).
attn_config = model.config.attn_config
assert attn_config['alibi'], "expected ALiBi attention (attn_config['alibi'] == True)"
assert model.config.no_bias, "expected model.config.no_bias == True"
assert not attn_config['prefix_lm'], "prefix_lm attention is not supported by this converter"
assert model.config.norm_type == "low_precision_layernorm", (
    f"unsupported norm_type {model.config.norm_type!r}; expected 'low_precision_layernorm'"
)
assert not attn_config['qk_ln'], "qk_ln is not supported by this converter"
assert model.config.expansion_ratio == 4, (
    f"expected expansion_ratio == 4, got {model.config.expansion_ratio}"
)

In [None]:
# Serialize the model to `fname_out` in the ggml layout written below:
#   header  : magic, version, vocab_size, max_seq_len, n_layers, n_heads, d_model (uint32),
#             alibi_bias_max, clip_qkv (float32), ftype (uint32)
#   tensors : for each state_dict entry -> int32 n_dims, name length, per-tensor ftype;
#             int32 dims (last numpy axis first); UTF-8 name bytes; raw array data.
# All values use the machine's native byte order (struct native mode, ndarray.tofile).


def _write_hparams(fout, config, ftype):
    """Write the magic number, format version and MPT hyper-parameters to `fout`."""
    attn = config.attn_config
    fout.write(struct.pack("I", 0x67676d64))  # magic: "ggmd" as a uint32
    fout.write(struct.pack("I", 0))           # format version: v1_no_vocab
    for value in (config.vocab_size, config.max_seq_len, config.n_layers,
                  config.n_heads, config.d_model):
        fout.write(struct.pack("I", value))
    fout.write(struct.pack("f", attn['alibi_bias_max']))
    # clip_qkv may be None in the config; 0 is written in that case
    # (presumably read as "no clipping" by the loader — confirm).
    clip_qkv = attn['clip_qkv']
    fout.write(struct.pack("f", clip_qkv if clip_qkv is not None else 0))
    fout.write(struct.pack("I", ftype))


def _write_tensor(fout, name, tensor, ftype):
    """Write one named tensor record (header, name, data) to `fout`.

    ftype == 0 writes every tensor as float32. ftype == 1 writes 2-D ``*.weight``
    tensors as float16, except the token embeddings (``.wte``), which stay float32.
    """
    # NumPy has no bfloat16, so .numpy() would raise; upcast first.
    if tensor.dtype == torch.bfloat16:
        tensor = tensor.to(torch.float32)
    # squeeze() drops any size-1 axes before the shape is recorded.
    data = tensor.squeeze().cpu().numpy()
    n_dims = data.ndim
    print(f"Processing variable: {name} with shape: {data.shape}")

    use_f16 = (ftype != 0 and name.endswith(".weight")
               and n_dims == 2 and ".wte" not in name)
    if use_f16:
        print("  Converting to float16")
        data = data.astype(np.float16)
        tensor_ftype = 1
    else:
        if data.dtype != np.float32:
            print("  Converting to float32")
            data = data.astype(np.float32)
        tensor_ftype = 0

    # Named name_bytes (not `str`) so the builtin is not rebound in the kernel.
    name_bytes = name.encode("utf-8")
    fout.write(struct.pack("iii", n_dims, len(name_bytes), tensor_ftype))
    # Dimensions are written in reverse order: last numpy axis first.
    for dim in reversed(data.shape):
        fout.write(struct.pack("i", dim))
    fout.write(name_bytes)
    data.tofile(fout)


list_vars = model.state_dict()
# `with` guarantees the file is closed even if a tensor fails to convert.
with open(fname_out, "wb") as fout:
    _write_hparams(fout, model.config, ftype)
    for name, tensor in list_vars.items():
        _write_tensor(fout, name, tensor, ftype)
fname_out