build bitnet from HF bf16 model #1421
base: main
Conversation
Signed-off-by: Liqun Fu <liqun.fu@microsoft.com>
        # Make MatMul node (output projection weight node)
        o_proj = 'o_proj' if hasattr(attention, 'o_proj') else 'dense'
        o_matmul_basename = f"/model/layers.{layer_id}/attn/o_proj/MatMul"
        o_weight = eval(f"attention.{o_proj}")
Avoid eval: it is a code smell
Suggested change:
-        o_weight = eval(f"attention.{o_proj}")
+        o_weight = getattr(attention, o_proj)
        o_bias_exists = eval(f"attention.{o_proj}.bias") is not None
        if o_bias_exists:
            o_add_name = f"/model/layers.{layer_id}/attn/o_proj/Add"
            o_bias = eval(f"attention.{o_proj}.bias.detach().numpy()")
nit: avoid eval here as well.
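For illustration, a minimal sketch of the same getattr-based rewrite applied to the bias handling above (variable names are taken from the snippet; this is not a committed change):

        o_proj_module = getattr(attention, o_proj)
        o_bias_exists = o_proj_module.bias is not None
        if o_bias_exists:
            o_add_name = f"/model/layers.{layer_id}/attn/o_proj/Add"
            o_bias = o_proj_module.bias.detach().numpy()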
            cos_cache=cos_cache_name, sin_cache=sin_cache_name, **kwargs,
        )

        # add an extra SimplifiedLayerNorm before the output projection for attention
This can be simplified by adding a setting in the attention_attrs for the RMSNorm before the output projection MatMul.
onnxruntime-genai/src/python/py/models/builder.py, lines 1619 to 1621 in 36cd2ca:
        # Make Q/K SimplifiedLayerNorm nodes
        if self.attention_attrs["q_norm"] and self.attention_attrs["k_norm"]:
            self.make_qk_norm(layer_id, attention)
For example:
"q_norm": False, # LayerNorm after MatMul in Q path
"k_norm": False, # LayerNorm after MatMul in K path
"o_norm": False, # LayerNorm before MatMul in output path
Then we can set o_norm = True in the BitNetModel class constructor and insert the following logic here.
        # Make SimplifiedLayerNorm node before output MatMul
        if self.attention_attrs["o_norm"]:
            self.make_o_norm(layer_id, attention)
Once that's done, we can remove this code that overrides the make_attention method.
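A minimal sketch of that suggestion, assuming the o_norm key and make_o_norm helper are added as proposed (neither exists in builder.py today; the constructor signature is taken from the snippet further down in this review):

        # In the BitNetModel class constructor ("o_norm" is the proposed key, not existing API):
        def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
            super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options)
            self.rms_norm_eps = config.rms_norm_eps
            self.attention_attrs["o_norm"] = True  # RMSNorm before the output projection MatMul

        # In the base make_attention, next to the existing q_norm/k_norm check:
        # if self.attention_attrs["o_norm"]:
        #     self.make_o_norm(layer_id, attention)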
@@ -282,7 +282,9 @@ def from_pretrained(model_type, input_path, head_size, hidden_size, intermediate
         Also performs any pre-processing and post-processing to the GGUF models to ensure the
         weights are the same as the PyTorch models.
         """
-        if model_type == "ChatGLMModel":
+        if model_type == "BitNetForCausalLM":
+            model = GGUFModel(input_path, head_size, hidden_size, intermediate_size, num_attn_heads, num_kv_heads, vocab_size)
Does BitNet not require any post-processing (e.g. undo_permute, swap_norm_types, swap_mlp_types) to match the PyTorch model's class attributes?
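For illustration only, this is roughly the shape such post-processing takes for other GGUF model types; the helper names come from the comment above, and the exact call and its signature are assumed here, not taken from the builder:

        if model_type == "BitNetForCausalLM":
            model = GGUFModel(input_path, head_size, hidden_size, intermediate_size, num_attn_heads, num_kv_heads, vocab_size)
            # Hypothetical: only needed if the GGUF weight layout differs from the PyTorch layout
            model.undo_permute(head_size, hidden_size, num_attn_heads, num_kv_heads)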
        super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options)
        self.rms_norm_eps = config.rms_norm_eps

    def make_mlp_proj(self, layer_id, mlp, root_input):
BitNet uses three MatMuls (gate_proj, up_proj, and down_proj) but this method only creates two of them (up_proj and down_proj). You would have to override the base make_mlp_proj method. This version of make_mlp_proj is specific to the Nemotron model because Nemotron does not have a gate projection MatMul.
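For reference, the three-MatMul gated MLP pattern looks roughly like this (a generic sketch, not builder.py code; BitNet additionally inserts an RMSNorm before down_proj, as the following review comment notes):

        def gated_mlp(x, gate_proj, up_proj, down_proj, act_fn):
            # gate and up are separate MatMuls; their elementwise product feeds the down projection
            return down_proj(act_fn(gate_proj(x)) * up_proj(x))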
        act_fn_name = self.make_activation(layer_id, root_input=f"{up_name}/output_0")

        # add an extra SimplifiedLayerNorm after the MLP activation before the down projection MatMul
Similar to the make_attention method, it would be easier to add a setting in mlp_attrs for the extra RMSNorm before the down projection MatMul.
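A minimal sketch of that approach; the "down_norm" key and the make_down_norm helper are hypothetical names mirroring the attention_attrs suggestion above, not existing builder.py API:

        # In the BitNetModel class constructor:
        self.mlp_attrs["down_norm"] = True  # RMSNorm between the activation and the down_proj MatMul

        # In the base make_mlp_proj, after the activation node is created:
        if self.mlp_attrs["down_norm"]:
            self.make_down_norm(layer_id, mlp)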
@@ -602,7 +602,7 @@ std::shared_ptr<Model> CreateModel(OrtEnv& ort_env, const char* config_path, con
 std::shared_ptr<Model> CreateModel(OrtEnv& ort_env, std::unique_ptr<Config> config) {
   std::set<std::string> llm_types = {"chatglm", "decoder", "gemma", "gemma2", "gemma3_text",
Let's add bitnet to llm_types so that the alphabetical order is maintained.
@@ -57,6 +57,9 @@ def main(args):
         args.chat_template = '{system_prompt}<|im_start|>user\n{input}<|im_end|>\n<|im_start|>assistant\n'
     elif model_type == "gemma3_text":
         args.chat_template = '<start_of_turn>user\n{system_prompt}{input}<end_of_turn>\n<start_of_turn>model\n'
+    elif model_type.startswith("bitnet"):
+        # args.chat_template = '{system_prompt}{"role": "user", "content": "{input}"}'
Can you also add the chat template to the model-chat.py example?
@@ -72,6 +75,8 @@ def main(args):
     elif model_type.startswith("llama"):
         system_prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{args.system_prompt}<|eot_id|>"
         print("Using System Prompt for LLAMA 3, if you are using LLAMA 2 please pass the argument --system_prompt '<s>[INST] <<SYS>>\\n{args.system_prompt}\\n<</SYS>>')")
+    elif model_type.startswith("bitnet"):
Can you also add the system prompt to the model-chat.py example?
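A hedged sketch of the two requested additions to the model-chat.py example; the actual BitNet chat template and system-prompt strings are not shown in this thread and would need to come from the model's tokenizer configuration:

        # In the chat-template chain:
        elif model_type.startswith("bitnet"):
            args.chat_template = ...  # fill in with the BitNet chat template

        # In the system-prompt chain:
        elif model_type.startswith("bitnet"):
            system_prompt = ...  # format args.system_prompt the way the BitNet template expects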
The PR to add bfloat16 support in the model builder has been opened here. Once it is merged, you can retarget this PR to merge into the main branch instead.
Signed-off-by: Liqun Fu <liqun.fu@microsoft.com>