[Feature Request] The MatMulNBits matmul_nbits_quantizer does not support 3D weight tensors. #25362

Open
@DakeQQ

Description

Describe the feature request

The Q8 quantize_dynamic() tool already handles 3D weight tensors and works well for this purpose. It would therefore be beneficial if Q4 quantization with matmul_nbits_quantizer() offered similar support. Is there a plan to support quantization of 3D weights with matmul_nbits_quantizer(), or are there technical challenges that make this impossible?

# ONNX Runtime 1.22.1
2025-07-11 11:30:23,150 onnxruntime.quantization.matmul_nbits_quantizer [INFO] - MatMul weight is not 2D. Skip to quantize
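
For reference, a minimal sketch of how the two tools are invoked (file names and parameters are placeholders; the Q4 call assumes the standard MatMulNBitsQuantizer entry point from the onnxruntime.quantization.matmul_nbits_quantizer module named in the log above):

# Q8: quantize_dynamic() quantizes the 3D MatMul weights without complaint
import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType
from onnxruntime.quantization.matmul_nbits_quantizer import MatMulNBitsQuantizer

quantize_dynamic("model_fp32.onnx", "model_q8.onnx", weight_type=QuantType.QInt8)

# Q4: MatMulNBitsQuantizer skips every MatMul whose weight initializer is not 2D
model = onnx.load("model_fp32.onnx")
quantizer = MatMulNBitsQuantizer(model, block_size=32, is_symmetric=True)
quantizer.process()
quantizer.model.save_model_to_file("model_q4.onnx", use_external_data_format=True)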

Describe scenario use case

In a typical large language model (LLM) attention process, several view()/reshape() and transpose() operations are performed in each decoder layer. For example:

# Original 

q = self.q_proj(hidden_states_norm).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
k = self.k_proj(hidden_states_norm).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(hidden_states_norm).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
# ...

attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size)
o = self.o_proj(attn_output)

To simplify this, we can pre-rearrange the weight data during the initialization stage:

# pre-rearrange

# (num_heads * head_dim, hidden_size) -> (num_heads, hidden_size, head_dim)
q_weight = layer.self_attn.q_proj.weight.data
layer.q_weight_reshaped = torch.nn.Parameter(q_weight.view(self.num_heads, self.head_dim, hidden_size).transpose(1, 2), requires_grad=False)

# (num_key_value_heads * head_dim, hidden_size) -> (num_key_value_heads, 1, hidden_size, head_dim)
k_weight = layer.self_attn.k_proj.weight.data
layer.k_weight_reshaped = torch.nn.Parameter(k_weight.view(self.num_key_value_heads, 1, self.head_dim, hidden_size).transpose(2, 3), requires_grad=False)

# (num_key_value_heads * head_dim, hidden_size) -> (num_key_value_heads, 1, hidden_size, head_dim)
v_weight = layer.self_attn.v_proj.weight.data
layer.v_weight_reshaped = torch.nn.Parameter(v_weight.view(self.num_key_value_heads, 1, self.head_dim, hidden_size).transpose(2, 3), requires_grad=False)

# (hidden_size, num_heads * head_dim) -> (num_heads, head_dim, hidden_size)
o_weight = layer.self_attn.o_proj.weight.data
layer.o_weight_reshaped = torch.nn.Parameter(o_weight.view(hidden_size, self.num_heads, self.head_dim).permute(1, 2, 0).contiguous(), requires_grad=False)

By pre-rearranging the weights this way, we can eliminate 4 transpose() and 4 reshape() operations in each decoder layer, which significantly reduces the operator count and is especially beneficial in large models with billions of parameters.

This approach works well with Q8 quantization using quantize_dynamic(), but currently fails with Q4 quantization using matmul_nbits_quantizer() due to the lack of support for 3D weights.

# Simplified

q = torch.matmul(hidden_states_norm, layer.q_weight_reshaped)
k = torch.matmul(hidden_states_norm, layer.k_weight_reshaped)
v = torch.matmul(hidden_states_norm, layer.v_weight_reshaped)
# ...
attn_output = torch.matmul(q, k) + attention_mask
# ...
attn_output = torch.matmul(attn_output, v)
o = torch.matmul(attn_output, layer.o_weight_reshaped).sum(dim=0, keepdim=True)
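
To show that the rewrite is numerically equivalent, here is a small self-contained check for the q projection (a sketch with made-up dimensions; q_proj is a plain nn.Linear standing in for self_attn.q_proj, and bsz is assumed to be 1, as in the simplified path above):

# equivalence check between the original and pre-rearranged q projection (illustrative sizes)
import torch

q_len, num_heads, head_dim = 8, 4, 16
hidden_size = num_heads * head_dim

q_proj = torch.nn.Linear(hidden_size, num_heads * head_dim, bias=False)
hidden_states_norm = torch.randn(1, q_len, hidden_size)

# original path: Linear -> view -> transpose
q_ref = q_proj(hidden_states_norm).view(1, q_len, num_heads, head_dim).transpose(1, 2)

# pre-rearranged path: a single 3D MatMul, no view/transpose at runtime
q_weight_reshaped = q_proj.weight.data.view(num_heads, head_dim, hidden_size).transpose(1, 2)
q_new = torch.matmul(hidden_states_norm, q_weight_reshaped)  # (num_heads, q_len, head_dim)

print(torch.allclose(q_ref.squeeze(0), q_new, atol=1e-6))  # expected: True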
