Description
Describe the feature request
The Q8 quantize_dynamic() tool already supports quantizing 3D MatMul weights and works well for this purpose. It would therefore be beneficial if Q4 quantization using MatMulNBitsQuantizer could offer similar support. Is there a plan to support quantization of 3D weights with MatMulNBitsQuantizer, or are there technical challenges that make this impossible?
# ONNX Runtime 1.22.1
2025-07-11 11:30:23,150 onnxruntime.quantization.matmul_nbits_quantizer [INFO] - MatMul weight is not 2D. Skip to quantize
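For reference, a minimal sketch of how the two quantization paths are typically invoked (the file names and settings below are illustrative, not taken from the original run); the Q4 path emits the log above whenever a MatMul weight is not 2D:
# Quantization invocation sketch (illustrative file names and settings)
import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType
from onnxruntime.quantization.matmul_nbits_quantizer import MatMulNBitsQuantizer

# Q8 dynamic quantization: handles the pre-rearranged 3D MatMul weights
quantize_dynamic("model_fp32.onnx", "model_int8.onnx", weight_type=QuantType.QInt8)

# Q4 block-wise quantization: currently skips any MatMul whose weight initializer is not 2D
model = onnx.load("model_fp32.onnx")
quantizer = MatMulNBitsQuantizer(model, block_size=32, is_symmetric=True)
quantizer.process()
quantizer.model.save_model_to_file("model_int4.onnx", use_external_data_format=True)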
Describe scenario use case
In a typical large language model (LLM) attention process, several view()/reshape() and transpose() operations are performed in each decoder layer. For example:
# Original
q = self.q_proj(hidden_states_norm).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
k = self.k_proj(hidden_states_norm).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(hidden_states_norm).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
# ...
attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size)
o = self.o_proj(attn_output)
To simplify this, we can pre-rearrange the weight data during the initialization stage:
# pre-rearrange (assuming the standard PyTorch Linear weight layout of (out_features, in_features))
q_weight = layer.self_attn.q_proj.weight.data
# (num_heads * head_dim, hidden_size) -> (num_heads, hidden_size, head_dim)
layer.q_weight_reshaped = torch.nn.Parameter(q_weight.view(self.num_heads, self.head_dim, hidden_size).transpose(1, 2), requires_grad=False)
k_weight = layer.self_attn.k_proj.weight.data
# (num_key_value_heads * head_dim, hidden_size) -> (num_key_value_heads, 1, hidden_size, head_dim)
layer.k_weight_reshaped = torch.nn.Parameter(k_weight.view(self.num_key_value_heads, 1, self.head_dim, hidden_size).transpose(2, 3), requires_grad=False)
v_weight = layer.self_attn.v_proj.weight.data
layer.v_weight_reshaped = torch.nn.Parameter(v_weight.view(self.num_key_value_heads, 1, self.head_dim, hidden_size).transpose(2, 3), requires_grad=False)
o_weight = layer.self_attn.o_proj.weight.data
# (hidden_size, num_heads * head_dim) -> (num_heads, head_dim, hidden_size)
layer.o_weight_reshaped = torch.nn.Parameter(o_weight.view(hidden_size, self.num_heads, self.head_dim).permute(1, 2, 0).contiguous(), requires_grad=False)
By pre-rearranging the weights this way, we can eliminate 4 transpose() and 4 reshape() operations in each decoder layer, which significantly reduces the operator count; this is especially beneficial in large models with billions of parameters.
This approach works well with Q8 quantization using quantize_dynamic(), but it currently fails with Q4 quantization using MatMulNBitsQuantizer due to the lack of support for 3D weights.
# Simplified
q = torch.matmul(hidden_states_norm, layer.q_weight_reshaped)
k = torch.matmul(hidden_states_norm, layer.k_weight_reshaped)
v = torch.matmul(hidden_states_norm, layer.v_weight_reshaped)
# ...
# k was pre-rearranged like q, so a final transpose is needed for the Q·K^T scores
attn_output = torch.matmul(q, k.transpose(-1, -2)) + attention_mask
# ...
attn_output = torch.matmul(attn_output, v)
o = torch.matmul(attn_output, layer.o_weight_reshaped).sum(dim=0, keepdim=True)
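To make the gap easy to reproduce, here is a minimal, self-contained sketch (the module, shapes, and file names are hypothetical): a single 3D-weight MatMul is exported to ONNX and then passed through MatMulNBitsQuantizer, which skips it with the "MatMul weight is not 2D" message shown above.
# Minimal 3D-weight MatMul repro sketch (hypothetical module, shapes, and file names)
import torch
import onnx
from onnxruntime.quantization.matmul_nbits_quantizer import MatMulNBitsQuantizer

class Projection3D(torch.nn.Module):
    def __init__(self, num_heads=4, head_dim=8, hidden_size=32):
        super().__init__()
        # pre-rearranged q-style weight: (num_heads, hidden_size, head_dim)
        self.weight = torch.nn.Parameter(torch.randn(num_heads, hidden_size, head_dim), requires_grad=False)

    def forward(self, hidden_states_norm):
        # (1, q_len, hidden_size) @ (num_heads, hidden_size, head_dim) -> (num_heads, q_len, head_dim)
        return torch.matmul(hidden_states_norm, self.weight)

torch.onnx.export(Projection3D(), torch.randn(1, 16, 32), "proj3d_fp32.onnx", opset_version=17)

model = onnx.load("proj3d_fp32.onnx")
quantizer = MatMulNBitsQuantizer(model, block_size=32, is_symmetric=True)
quantizer.process()  # logs: MatMul weight is not 2D. Skip to quantize
quantizer.model.save_model_to_file("proj3d_q4.onnx")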