Support Deepstack in qwen3-omni #3214
🤖 Hi @aireenmei, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
This Pull Request effectively implements the Deepstack visual embedding injection into the Qwen3 Omni model. The integration points in the Decoder and vision_encoder look mostly correct, and the vectorized implementation of _deepstack_process using jnp.cumsum is clever and highly efficient.
🔍 General Feedback
- Efficiency Highlight: The use of boolean masks with `cumsum` for aligning visual embeddings with sequence positions is an excellent, TPU-friendly pattern that avoids slow dynamic slice operations.
- A logic bug in the `scan_layers` check could cause the model to silently drop visual embeddings if `image_masks` are not explicitly provided by the caller but `deepstack_visual_embeds` are present.
- A minor style improvement can be made by making `_deepstack_process` a static method, aligning with typical JAX pure-function idioms.
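The mask-plus-`cumsum` alignment praised above can be illustrated with a minimal sketch. This is not the PR's `_deepstack_process`; the function name `scatter_visual_embeds`, the argument shapes, and the choice to add the visual embeddings to the hidden states are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def scatter_visual_embeds(hidden, image_mask, visual_embeds):
  """Adds the k-th visual embedding to the k-th True position of image_mask.

  Hypothetical helper, assumed shapes:
    hidden:        [batch, seq, dim] decoder hidden states.
    image_mask:    [batch, seq] bool, True at image-token positions.
    visual_embeds: [batch, num_visual, dim], in sequence order.
  """
  # Running count of image tokens: the k-th True position gets index k (0-based).
  idx = jnp.cumsum(image_mask.astype(jnp.int32), axis=1) - 1
  # Positions before the first image token get -1; clamp so the gather stays
  # in bounds. Those positions are masked out below anyway.
  idx = jnp.clip(idx, 0, visual_embeds.shape[1] - 1)
  # Static-shape gather per batch row: [num_visual, dim] indexed by [seq].
  gathered = jax.vmap(lambda v, i: v[i])(visual_embeds, idx)
  # Keep gathered rows only at image positions; contribute zeros elsewhere.
  update = jnp.where(image_mask[..., None], gathered, 0)
  return hidden + update.astype(hidden.dtype)
```

Every intermediate array has a shape that depends only on the input shapes, not on how many mask entries are True, which is what lets the function be traced once under `jax.jit` without dynamic slicing.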
🤖 Hi @aireenmei, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
This PR successfully integrates Deepstack visual embeddings into the decoder layers for the qwen3-omni models. The implementation effectively extends the vision encoder to output deep features and seamlessly processes them within the transformer decoder architecture.
🔍 General Feedback
- The changes correctly decouple deep feature extraction from the main projection logic while adhering to existing layer configurations.
- The unit tests added are comprehensive and correctly validate the scattering of visual tokens across the sequences based on the bidirectional mask.
- A few minor inline suggestions are provided to handle edge cases relating to implicit type promotions and to prevent potential `NaN` propagation when computing masked visual embeddings.
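The inline suggestions themselves are not reproduced in this conversation. As a generic illustration of the `NaN` concern, the sketch below contrasts masking by multiplication with masking by selection; both function names are hypothetical and neither is taken from the PR.

```python
import jax.numpy as jnp


def mask_by_multiply(embeds, mask):
  # Multiplying by a 0/1 mask keeps NaN at masked-out positions (NaN * 0 = NaN).
  return embeds * mask[..., None].astype(embeds.dtype)


def mask_by_where(embeds, mask):
  # Selecting with where replaces masked-out values outright, NaN included.
  # The zero is created in embeds.dtype so the result dtype stays explicit.
  zero = jnp.zeros((), dtype=embeds.dtype)
  return jnp.where(mask[..., None], embeds, zero)
```

If padded or unused embedding rows can hold non-finite values, the selection form keeps them out of the forward result, while the multiplication form passes them through.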
Co-authored-by: Eitan Porat <eporat@lightricks.com>
Description
Original work #2729
Tests
`qwen3_omni_layer_test.py` passes locally.
The pylint error is unrelated to this change and will be fixed by #3219.
Checklist
Before submitting this PR, please make sure (put X in square brackets):
`gemini-review` label.