[quantization] Normalize attention_mask in qwen3_vl text model #642

Merged: mhs4670go merged 2 commits into Samsung:main from mhs4670go:re2 on Apr 21, 2026

@mhs4670go
Contributor

  • Convert 2D/4D masks to additive causal form
  • Preserve padding semantics and support cache-aware shapes
  • Return tuple when return_dict=False to fix Dynamo export

Related: #630
TICO-DCO-1.0-Signed-off-by: seongwoo <mhs4670go@naver.com>

@mhs4670go mhs4670go requested a review from a team April 19, 2026 11:08
@mhs4670go
Contributor Author

@dvsav

I updated the implementation accordingly to normalize the mask.

  • Added _normalize_attention_mask(...) to convert all inputs (None / 2D / 4D) into a 4D additive causal mask.
  • Preserves padding semantics for 2D masks and combines them with the causal mask.
  • Supports cache-aware shapes (q_len x (past_len + q_len)) during decode.
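
The normalization described in the bullets above can be sketched roughly as follows. This is an illustrative stand-in, not the PR's `_normalize_attention_mask`: the function name `to_additive_causal_mask`, the constant name `MASK_VALUE`, and the choice to combine the causal and padding conditions with a boolean AND (rather than summing two additive masks) are all assumptions made for the example.

```python
import torch

MASK_VALUE = -120.0  # same heuristic value as in the diff; the name is hypothetical


def to_additive_causal_mask(padding_mask_2d, q_len, past_len, dtype=torch.float32):
    """Hypothetical stand-in for the normalization step, not the PR's code."""
    batch, kv_len = padding_mask_2d.shape
    assert kv_len == past_len + q_len, "mask must cover past and current tokens"

    # Query i sits at absolute position past_len + i and may see keys 0..past_len + i.
    q_pos = torch.arange(past_len, past_len + q_len).unsqueeze(-1)  # (q_len, 1)
    k_pos = torch.arange(kv_len).unsqueeze(0)                       # (1, kv_len)
    causal_keep = k_pos <= q_pos                                    # (q_len, kv_len)

    # Broadcast padding over the query dimension: (batch, 1, 1, kv_len).
    padding_keep = (padding_mask_2d != 0)[:, None, None, :]
    keep = causal_keep[None, None, :, :] & padding_keep             # (batch, 1, q_len, kv_len)

    additive = torch.zeros(batch, 1, q_len, kv_len, dtype=dtype)
    return additive.masked_fill(~keep, MASK_VALUE)


# Prefill: 4 tokens, the last one is padding.
prefill = to_additive_causal_mask(torch.tensor([[1, 1, 1, 0]]), q_len=4, past_len=0)
# Decode: 1 new token attending over 4 cached tokens plus itself.
decode = to_additive_causal_mask(torch.ones(1, 5, dtype=torch.long), q_len=1, past_len=4)
```

In the decode case the mask has shape `(1, 1, 1, 5)`, i.e. `q_len x (past_len + q_len)` in the last two dimensions.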

Also adjusted the export path to return tuples (return_dict=False) to avoid Dynamo issues, while keeping eager behavior unchanged.
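
As a rough illustration of the tuple return path, the sketch below uses a hypothetical `pack_outputs` helper and a stand-in `ModelOutput` dataclass; neither name comes from the PR, and the real model returns its own output type when `return_dict=True`.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ModelOutput:
    """Hypothetical stand-in for the dict-like output object."""
    last_hidden_state: Any
    past_key_values: Optional[Any] = None


def pack_outputs(hidden_states, past_key_values, use_cache, return_dict):
    """Return a plain tuple when return_dict is False, else an output object."""
    if not return_dict:
        if use_cache:
            return hidden_states, past_key_values
        return (hidden_states,)
    return ModelOutput(hidden_states, past_key_values if use_cache else None)


# Export-style call: plain tuples only.
export_out = pack_outputs("hidden", "cache", use_cache=True, return_dict=False)
# Eager-style call: structured output object.
eager_out = pack_outputs("hidden", "cache", use_cache=True, return_dict=True)
```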

Could you review this PR instead of #630?

@dvsav
Contributor

dvsav commented Apr 20, 2026

Quantization Example Script 👍

$ python tico/quantization/wrapq/examples/qwen/quantize_text_model.py
┌───────────── Quantization Error Summary ─────────────
│ Mean |diff|: 0.132374
│ PEIR       : 9.924287 %
└──────────────────────────────────────────────────────
    ┌────────────────────────────────────────────┐
 4.1┤                                            │
    │                                     •••••  │
 2.5┤                                  ••••••    │
    │                              ••••••••      │
    │                           ••••••••         │
 1.0┤                        •••••••••           │
    │                    ••••••••••              │
-0.5┤                  •••••••••                 │
    │              ••••••••••                    │
-2.1┤            ••••••••                        │
    │         ••••••••                           │
    │        •••••                               │
-3.6┤     ••••                                   │
    │  •                                         │
-5.1┤                                            │
    └┬──────────┬──────────┬─────────┬──────────┬┘
   -5.1       -2.8       -0.5       1.8       4.1 

Circle model saved as 'qwen3vl_text_model.q.circle'

@dvsav
Contributor

dvsav commented Apr 20, 2026

Unit Tests 👍

$ coverage run -m pytest test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py -v
======================================================================= test session starts ========================================================================
platform linux -- Python 3.10.12, pytest-8.4.0, pluggy-1.6.0 -- /home/d.savchenkov/myenv/bin/python3
cachedir: .pytest_cache
rootdir: /home/d.savchenkov/TICO
configfile: pyproject.toml
plugins: anyio-4.12.0, mock-3.15.1, xdist-3.7.0, cov-6.2.1
collected 21 items                                                                                                                                                 

test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_attention_mask_2d_prefill PASSED                          [  4%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_attention_mask_4d_additive_passthrough PASSED             [  9%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_attention_mask_bool_prefill PASSED                        [ 14%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_attention_mask_decode_with_cache PASSED                   [ 19%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_deepstack_injection PASSED                                [ 23%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_different_batch_sizes PASSED                              [ 28%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_different_sequence_lengths PASSED                         [ 33%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_embedding_layer_quantization PASSED                       [ 38%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_forward_diff PASSED                                       [ 42%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_inputs_embeds_path PASSED                                 [ 47%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_layers_wrapped PASSED                                     [ 52%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_mode_transitions PASSED                                   [ 57%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_no_cache_mode PASSED                                      [ 61%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_norm_wrapped PASSED                                       [ 66%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_normalize_attention_mask_shapes PASSED                    [ 71%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_observer_count PASSED                                     [ 76%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_output_shape PASSED                                       [ 80%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_per_module_override PASSED                                [ 85%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_registration_in_registry PASSED                           [ 90%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_return_dict_false_with_cache PASSED                       [ 95%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_rotary_emb_not_wrapped PASSED                             [100%]

================================================================== 21 passed, 2 warnings in 9.67s ==================================================================

Test Coverage

$ coverage report -m
Name                                                                   Stmts   Miss  Cover   Missing
----------------------------------------------------------------------------------------------------
...
tico/quantization/wrapq/wrappers/qwen_vl/quant_text_attn.py                      136      3    98%   201-203
tico/quantization/wrapq/wrappers/qwen_vl/quant_text_decoder_layer.py              42      0   100%
tico/quantization/wrapq/wrappers/qwen_vl/quant_text_mlp.py                        43      0   100%
tico/quantization/wrapq/wrappers/qwen_vl/quant_text_model.py                     177     24    86%   111-112, 305, 312-319, 322, 330, 351, 362-386, 427, 458-459, 462-463, 520
...
------------------------------------------------------------------------------------------------------------
TOTAL                                                                          11789   7432    37%

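Test coverage has degraded a bit: quant_text_model.py is at 86%, and its missing lines (305, 312-319, 322, 330, 351, 362-386) fall in the new mask-normalization code.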

@dvsav
Contributor

dvsav commented Apr 20, 2026

Trace Results 👍

python tico/quantization/wrapq/examples/qwen/trace_qwen.py \
    --model ~/models/qwen3-vl-2b \
    --no-trace-unquantized \
    --no-trace-quantized \
    | tee trace_qwen.txt

--------------------------------------------------------------------------------
MODULE NAME                                            DIFFERENCE
--------------------------------------------------------------------------------
model.language_model.embed_tokens                      {'mean': 0.0, 'min': 0.0, 'max': 0.0, 'stddev': 0.0, 'peir': 0.0}
model.visual.patch_embed.proj                          {'mean': 4.663015715777874e-07, 'min': 2.9802322387695312e-08, 'max': 3.337860107421875e-06, 'stddev': 5.199021302360052e-07, 'peir': 1.105467087481138e-06}
...
model.language_model.layers.0.self_attn.v_proj         {'mean': 1.054959852808679e-07, 'min': 0.0, 'max': 7.171183824539185e-07, 'stddev': 1.0697030461415125e-07, 'peir': 7.13171060995813e-07}
model.language_model.layers.0.self_attn.o_proj         {'mean': 1.4910252588151707e-08, 'min': 0.0, 'max': 6.891787052154541e-08, 'stddev': 1.1672808497564802e-08, 'peir': 4.7245867289065846e-07}
...
lm_head                                                {'mean': 1.4132734804661595e-07, 'min': 0.0, 'max': 8.344650268554688e-07, 'stddev': 1.1461531101986111e-07, 'peir': 6.546005702132053e-07}

PEIR at model.language_model.layers.0.self_attn.o_proj is OK now.

@mhs4670go
Contributor Author

@dvsav Please approve the PR if it looks good:)

Contributor

@dvsav dvsav left a comment

Comment on lines +305 to +308
raise ValueError(
    "2D attention_mask batch size does not match inputs_embeds batch size. "
    f"Got mask batch={attention_mask.shape[0]}, input batch={batch_size}."
)
Contributor

Not covered with unit tests.

Comment on lines +311 to +319
if mask_len == q_len and past_seen_tokens > 0:
    past_prefix = torch.ones(
        batch_size,
        past_seen_tokens,
        device=attention_mask.device,
        dtype=attention_mask.dtype,
    )
    attention_mask = torch.cat((past_prefix, attention_mask), dim=-1)
    mask_len = attention_mask.shape[1]
Contributor

Not covered with unit tests.
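
A test for this branch might look roughly like the sketch below. Because the real logic lives inside the model wrapper, the example uses a hypothetical free function, `prepend_past_prefix`, that mirrors only the quoted lines (it drops the `mask_len == q_len` check for brevity); an actual test would go through the wrapper's own API instead.

```python
import torch


def prepend_past_prefix(attention_mask, past_seen_tokens):
    """Hypothetical helper mirroring the quoted branch; not the PR's API."""
    batch_size = attention_mask.shape[0]
    if past_seen_tokens > 0:
        # Cached positions are treated as real tokens, so they are marked as 1.
        past_prefix = torch.ones(
            batch_size,
            past_seen_tokens,
            device=attention_mask.device,
            dtype=attention_mask.dtype,
        )
        attention_mask = torch.cat((past_prefix, attention_mask), dim=-1)
    return attention_mask


def test_mask_is_extended_for_cached_tokens():
    # One new token (q_len=1) with three tokens already in the cache.
    mask = torch.ones(2, 1, dtype=torch.long)
    extended = prepend_past_prefix(mask, past_seen_tokens=3)
    assert tuple(extended.shape) == (2, 4)
    assert bool((extended == 1).all())
```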

Comment on lines +321 to +325
if mask_len != kv_len:
    raise ValueError(
        "2D attention_mask length does not match the expected KV length. "
        f"Got mask length={mask_len}, expected kv_len={kv_len}."
    )
Contributor

Not covered with unit tests.

Comment on lines +329 to +330
elif torch.is_floating_point(attention_mask):
    padding_keep = attention_mask != 0
Contributor

Not covered with unit tests.

Comment on lines +349 to +354
if attention_mask.ndim == 4:
    if attention_mask.shape[-2] != q_len or attention_mask.shape[-1] != kv_len:
        raise ValueError(
            "4D attention_mask shape does not match the expected query/KV lengths. "
            f"Got shape={tuple(attention_mask.shape)}, expected (*, *, {q_len}, {kv_len})."
        )
Contributor

Not covered with unit tests.

Comment on lines +362 to +389
    if attention_mask.dtype == torch.bool:
        additive_mask = torch.zeros_like(
            attention_mask,
            device=input_embeds.device,
            dtype=input_embeds.dtype,
        )
        additive_mask = additive_mask.masked_fill(
            ~attention_mask.to(device=input_embeds.device),
            float("-120"),
        )
        return additive_mask

    bool_mask = attention_mask.to(torch.long) != 0
    additive_mask = torch.zeros_like(
        bool_mask,
        device=input_embeds.device,
        dtype=input_embeds.dtype,
    )
    additive_mask = additive_mask.masked_fill(
        ~bool_mask.to(device=input_embeds.device),
        float("-120"),
    )
    return additive_mask

raise ValueError(
    "Unsupported attention_mask rank. "
    f"Expected None, 2D, or 4D mask, but got ndim={attention_mask.ndim}."
)
Contributor

Not covered with unit tests.

if not return_dict:
    if use_cache:
        return hidden_states, past_key_values
    return (hidden_states,)
Contributor

Not covered with unit tests.

float("-120"),
)

return causal_mask + padding_mask
Contributor

This addition produces a slightly non-uniform additive mask: -240 appears wherever causal_mask and padding_mask both mask the same position:

(Pdb) pp (causal_mask + padding_mask)[0]
tensor( [[[   0., -120., -120., -120., -120., -120., -120., -120., -120., -120., -120., -120., -240., -240., -240., -240.],
          [   0.,    0., -120., -120., -120., -120., -120., -120., -120., -120., -120., -120., -240., -240., -240., -240.],
          [   0.,    0.,    0., -120., -120., -120., -120., -120., -120., -120., -120., -120., -240., -240., -240., -240.],
          [   0.,    0.,    0.,    0., -120., -120., -120., -120., -120., -120., -120., -120., -240., -240., -240., -240.],
          [   0.,    0.,    0.,    0.,    0., -120., -120., -120., -120., -120., -120., -120., -240., -240., -240., -240.],
          [   0.,    0.,    0.,    0.,    0.,    0., -120., -120., -120., -120., -120., -120., -240., -240., -240., -240.],
          [   0.,    0.,    0.,    0.,    0.,    0.,    0., -120., -120., -120., -120., -120., -240., -240., -240., -240.],
          [   0.,    0.,    0.,    0.,    0.,    0.,    0.,    0., -120., -120., -120., -120., -240., -240., -240., -240.],
          [   0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0., -120., -120., -120., -240., -240., -240., -240.],
          [   0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0., -120., -120., -240., -240., -240., -240.],
          [   0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0., -120., -240., -240., -240., -240.],
          [   0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0., -240., -240., -240., -240.],
          [   0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0., -120., -240., -240., -240.],
          [   0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0., -120., -120., -240., -240.],
          [   0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0., -120., -120., -120., -240.],
          [   0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0., -120., -120., -120., -120.]]])

Not sure if this affects anything in terms of quantization of that mask...
Anyway, how about clamping this mixture of -120 and -240?

return torch.clamp(causal_mask + padding_mask, min=-120., max=0.)
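
A small standalone reproduction of the overlap and of the clamp, using a 4x4 mask instead of the 16x16 one printed above. The variable names and the `MASK_VALUE` constant are chosen for the example and are not taken from the PR.

```python
import torch

MASK_VALUE = -120.0  # hypothetical name for the value used in the diff

kv_len = 4
# Causal part: query i may attend to keys 0..i.
causal_keep = torch.tril(torch.ones(kv_len, kv_len, dtype=torch.bool))
causal_mask = torch.zeros(kv_len, kv_len).masked_fill(~causal_keep, MASK_VALUE)

# Padding part: the last key is padding for every query.
padding_keep = torch.tensor([True, True, True, False])
padding_mask = torch.zeros(kv_len).masked_fill(~padding_keep, MASK_VALUE)

# Plain addition doubles the penalty where both masks apply.
summed = causal_mask + padding_mask  # padding_mask broadcasts over rows

# Clamping restores a uniform two-valued mask.
clamped = torch.clamp(summed, min=MASK_VALUE, max=0.0)

print(summed)
print(clamped)
```

After clamping, every masked position carries the same value regardless of how many conditions masked it.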

Contributor Author

Good point! I've applied this.

)
additive_mask = additive_mask.masked_fill(
~attention_mask.to(device=input_embeds.device),
float("-120"),
Contributor

This magic -120.0 number occurs a lot in the code now.
Maybe it'd be reasonable to define a class variable for it.

Contributor Author

Yeah. Actually, the number is somewhat heuristic. It would be helpful to have an ablation study with various mask values.
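
One possible starting point for such an ablation is to check how much attention weight "leaks" to a masked position after softmax for different mask values. The sketch below is only a floating-point illustration with made-up scores; it says nothing about quantized behavior, which is what the ablation would actually need to measure.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up attention scores for three keys; the last key is to be masked.
scores = [2.0, 1.0, 0.5]

for mask_value in (-10.0, -30.0, -120.0):
    masked = [scores[0], scores[1], scores[2] + mask_value]
    leaked = softmax(masked)[2]
    print(f"mask={mask_value:>7}: leaked attention weight = {leaked:.3e}")
```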

Contributor

I mean, I'd suggest something like this:

class QuantQwen3VLTextModel(QuantModuleBase):
    LARGE_NEGATIVE_NUMBER_FOR_MASKING: float = -120.0
...
                additive_mask = additive_mask.masked_fill(
                    ~attention_mask.to(device=input_embeds.device),
                    self.LARGE_NEGATIVE_NUMBER_FOR_MASKING,

Contributor Author

Let's consider that in #643.

@mhs4670go
Contributor Author

Ah, I missed those comments, sorry. I'll review them soon.

@mhs4670go
Contributor Author

@dvsav PTAL:)

@mhs4670go mhs4670go merged commit dee0aee into Samsung:main Apr 21, 2026
7 checks passed
@mhs4670go mhs4670go deleted the re2 branch April 22, 2026 00:27
2 participants