[quantization] Normalize attention_mask in qwen3_vl text model#642
Conversation
- Convert 2D/4D masks to additive causal form - Preserve padding semantics and support cache-aware shapes - Return tuple when return_dict=False to fix Dynamo export TICO-DCO-1.0-Signed-off-by: seongwoo <mhs4670go@naver.com>
|
I updated the implementation accordingly to normalize the mask.
Also adjusted the export path to return tuples ( Could you review this PR instead of #630? |
Qunatization Example Script 👍$ python tico/quantization/wrapq/examples/qwen/quantize_text_model.py
┌───────────── Quantization Error Summary ─────────────
│ Mean |diff|: 0.132374
│ PEIR : 9.924287 %
└──────────────────────────────────────────────────────
┌────────────────────────────────────────────┐
4.1┤ │
│ ••••• │
2.5┤ •••••• │
│ •••••••• │
│ •••••••• │
1.0┤ ••••••••• │
│ •••••••••• │
-0.5┤ ••••••••• │
│ •••••••••• │
-2.1┤ •••••••• │
│ •••••••• │
│ ••••• │
-3.6┤ •••• │
│ • │
-5.1┤ │
└┬──────────┬──────────┬─────────┬──────────┬┘
-5.1 -2.8 -0.5 1.8 4.1
Circle model saved as 'qwen3vl_text_model.q.circle' |
Unit Tests 👍$ coverage run -m pytest test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py -v
======================================================================= test session starts ========================================================================
platform linux -- Python 3.10.12, pytest-8.4.0, pluggy-1.6.0 -- /home/d.savchenkov/myenv/bin/python3
cachedir: .pytest_cache
rootdir: /home/d.savchenkov/TICO
configfile: pyproject.toml
plugins: anyio-4.12.0, mock-3.15.1, xdist-3.7.0, cov-6.2.1
collected 21 items
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_attention_mask_2d_prefill PASSED [ 4%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_attention_mask_4d_additive_passthrough PASSED [ 9%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_attention_mask_bool_prefill PASSED [ 14%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_attention_mask_decode_with_cache PASSED [ 19%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_deepstack_injection PASSED [ 23%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_different_batch_sizes PASSED [ 28%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_different_sequence_lengths PASSED [ 33%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_embedding_layer_quantization PASSED [ 38%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_forward_diff PASSED [ 42%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_inputs_embeds_path PASSED [ 47%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_layers_wrapped PASSED [ 52%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_mode_transitions PASSED [ 57%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_no_cache_mode PASSED [ 61%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_norm_wrapped PASSED [ 66%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_normalize_attention_mask_shapes PASSED [ 71%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_observer_count PASSED [ 76%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_output_shape PASSED [ 80%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_per_module_override PASSED [ 85%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_registration_in_registry PASSED [ 90%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_return_dict_false_with_cache PASSED [ 95%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_rotary_emb_not_wrapped PASSED [100%]
================================================================== 21 passed, 2 warnings in 9.67s ==================================================================Test Coverage$ coverage report -m
Name Stmts Miss Cover Missing
----------------------------------------------------------------------------------------------------
...
tico/quantization/wrapq/wrappers/qwen_vl/quant_text_attn.py 136 3 98% 201-203
tico/quantization/wrapq/wrappers/qwen_vl/quant_text_decoder_layer.py 42 0 100%
tico/quantization/wrapq/wrappers/qwen_vl/quant_text_mlp.py 43 0 100%
tico/quantization/wrapq/wrappers/qwen_vl/quant_text_model.py 177 24 86% 111-112, 305, 312-319, 322, 330, 351, 362-386, 427, 458-459, 462-463, 520
...
------------------------------------------------------------------------------------------------------------
TOTAL 11789 7432 37%Test coverage has degraded a bit. |
Trace Results 👍python tico/quantization/wrapq/examples/qwen/trace_qwen.py \
--model ~/models/qwen3-vl-2b \
--no-trace-unquantized \
--no-trace-quantized \
| tee trace_qwen.txt
--------------------------------------------------------------------------------
MODULE NAME DIFFERENCE
--------------------------------------------------------------------------------
model.language_model.embed_tokens {'mean': 0.0, 'min': 0.0, 'max': 0.0, 'stddev': 0.0, 'peir': 0.0}
model.visual.patch_embed.proj {'mean': 4.663015715777874e-07, 'min': 2.9802322387695312e-08, 'max': 3.337860107421875e-06, 'stddev': 5.199021302360052e-07, 'peir': 1.105467087481138e-06}
...
model.language_model.layers.0.self_attn.v_proj {'mean': 1.054959852808679e-07, 'min': 0.0, 'max': 7.171183824539185e-07, 'stddev': 1.0697030461415125e-07, 'peir': 7.13171060995813e-07}
model.language_model.layers.0.self_attn.o_proj {'mean': 1.4910252588151707e-08, 'min': 0.0, 'max': 6.891787052154541e-08, 'stddev': 1.1672808497564802e-08, 'peir': 4.7245867289065846e-07}
...
lm_head {'mean': 1.4132734804661595e-07, 'min': 0.0, 'max': 8.344650268554688e-07, 'stddev': 1.1461531101986111e-07, 'peir': 6.546005702132053e-07}PEIR at |
|
@dvsav Please approve the PR if it looks good:) |
There was a problem hiding this comment.
LGTM
Will appreciate if you respond to the following comments though:
https://github.com/Samsung/TICO/pull/642/changes#r3109802326
https://github.com/Samsung/TICO/pull/642/changes#r3109837986
| raise ValueError( | ||
| "2D attention_mask batch size does not match inputs_embeds batch size. " | ||
| f"Got mask batch={attention_mask.shape[0]}, input batch={batch_size}." | ||
| ) |
There was a problem hiding this comment.
Not covered with unit tests.
| if mask_len == q_len and past_seen_tokens > 0: | ||
| past_prefix = torch.ones( | ||
| batch_size, | ||
| past_seen_tokens, | ||
| device=attention_mask.device, | ||
| dtype=attention_mask.dtype, | ||
| ) | ||
| attention_mask = torch.cat((past_prefix, attention_mask), dim=-1) | ||
| mask_len = attention_mask.shape[1] |
There was a problem hiding this comment.
Not covered with unit tests.
| if mask_len != kv_len: | ||
| raise ValueError( | ||
| "2D attention_mask length does not match the expected KV length. " | ||
| f"Got mask length={mask_len}, expected kv_len={kv_len}." | ||
| ) |
There was a problem hiding this comment.
Not covered with unit tests.
| elif torch.is_floating_point(attention_mask): | ||
| padding_keep = attention_mask != 0 |
There was a problem hiding this comment.
Not covered with unit tests.
| if attention_mask.ndim == 4: | ||
| if attention_mask.shape[-2] != q_len or attention_mask.shape[-1] != kv_len: | ||
| raise ValueError( | ||
| "4D attention_mask shape does not match the expected query/KV lengths. " | ||
| f"Got shape={tuple(attention_mask.shape)}, expected (*, *, {q_len}, {kv_len})." | ||
| ) |
There was a problem hiding this comment.
Not covered with unit tests.
| if attention_mask.dtype == torch.bool: | ||
| additive_mask = torch.zeros_like( | ||
| attention_mask, | ||
| device=input_embeds.device, | ||
| dtype=input_embeds.dtype, | ||
| ) | ||
| additive_mask = additive_mask.masked_fill( | ||
| ~attention_mask.to(device=input_embeds.device), | ||
| float("-120"), | ||
| ) | ||
| return additive_mask | ||
|
|
||
| bool_mask = attention_mask.to(torch.long) != 0 | ||
| additive_mask = torch.zeros_like( | ||
| bool_mask, | ||
| device=input_embeds.device, | ||
| dtype=input_embeds.dtype, | ||
| ) | ||
| additive_mask = additive_mask.masked_fill( | ||
| ~bool_mask.to(device=input_embeds.device), | ||
| float("-120"), | ||
| ) | ||
| return additive_mask | ||
|
|
||
| raise ValueError( | ||
| "Unsupported attention_mask rank. " | ||
| f"Expected None, 2D, or 4D mask, but got ndim={attention_mask.ndim}." | ||
| ) |
There was a problem hiding this comment.
Not covered with unit tests.
| if not return_dict: | ||
| if use_cache: | ||
| return hidden_states, past_key_values | ||
| return (hidden_states,) |
There was a problem hiding this comment.
Not covered with unit tests.
| float("-120"), | ||
| ) | ||
|
|
||
| return causal_mask + padding_mask |
There was a problem hiding this comment.
This addition leads to a bit nonuniform additive mask where -240. elements may appear at positions where causal_mask overlaps with padding_mask:
(Pdb) pp (causal_mask + padding_mask)[0]
tensor( [[[ 0., -120., -120., -120., -120., -120., -120., -120., -120., -120., -120., -120., -240., -240., -240., -240.],
[ 0., 0., -120., -120., -120., -120., -120., -120., -120., -120., -120., -120., -240., -240., -240., -240.],
[ 0., 0., 0., -120., -120., -120., -120., -120., -120., -120., -120., -120., -240., -240., -240., -240.],
[ 0., 0., 0., 0., -120., -120., -120., -120., -120., -120., -120., -120., -240., -240., -240., -240.],
[ 0., 0., 0., 0., 0., -120., -120., -120., -120., -120., -120., -120., -240., -240., -240., -240.],
[ 0., 0., 0., 0., 0., 0., -120., -120., -120., -120., -120., -120., -240., -240., -240., -240.],
[ 0., 0., 0., 0., 0., 0., 0., -120., -120., -120., -120., -120., -240., -240., -240., -240.],
[ 0., 0., 0., 0., 0., 0., 0., 0., -120., -120., -120., -120., -240., -240., -240., -240.],
[ 0., 0., 0., 0., 0., 0., 0., 0., 0., -120., -120., -120., -240., -240., -240., -240.],
[ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -120., -120., -240., -240., -240., -240.],
[ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -120., -240., -240., -240., -240.],
[ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -240., -240., -240., -240.],
[ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -120., -240., -240., -240.],
[ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -120., -120., -240., -240.],
[ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -120., -120., -120., -240.],
[ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -120., -120., -120., -120.]]])Not sure if this affects anything in terms of quantization of that mask...
Anyway, how about clamping this mixture of -120 and -240?
return torch.clamp(causal_mask + padding_mask, min=-120., max=0.)There was a problem hiding this comment.
Good point! I've applied this.
| ) | ||
| additive_mask = additive_mask.masked_fill( | ||
| ~attention_mask.to(device=input_embeds.device), | ||
| float("-120"), |
There was a problem hiding this comment.
This magic -120.0 number occurs a lot in the code now.
Maybe it'd be reasonable to define a class variable for it.
There was a problem hiding this comment.
Yeah. Actually, the number is kind of heuristic. It would be helpful to have ablation study with various mask values.
There was a problem hiding this comment.
I mean, I'd suggest something like this:
class QuantQwen3VLTextModel(QuantModuleBase):
LARGE_NEGATIVE_NUMBER_FOR_MASKING: float = -120.0
...
additive_mask = additive_mask.masked_fill(
~attention_mask.to(device=input_embeds.device),
self.LARGE_NEGATIVE_NUMBER_FOR_MASKING,|
Ah I miseed that comments sorry. I'll review it soon |
|
@dvsav PTAL:) |
Related: #630
TICO-DCO-1.0-Signed-off-by: seongwoo mhs4670go@naver.com