Add qwen3 vl #1095
```diff
@@ -35,34 +35,36 @@ def smart_resize(
     height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
 ) -> tuple[int, int]:
-    if max(height, width) / min(height, width) > MAX_RATIO:
+    if max(height, width) / min(height, width) > 200:
         raise ValueError(
-            f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
+            f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
         )
-    h_bar = max(factor, round(height / factor) * factor)
-    w_bar = max(factor, round(width / factor) * factor)
+    h_bar = round(height / factor) * factor
+    w_bar = round(width / factor) * factor
```
Contributor (comment on lines +42 to +43): The calculation for `h_bar` and `w_bar` drops the `max(factor, ...)` clamp that the old code applied, so a dimension smaller than `factor / 2` can round to 0 before the pixel-budget branches run; consider whether the clamp should be kept.
```diff
     if h_bar * w_bar > max_pixels:
         beta = math.sqrt((height * width) / max_pixels)
-        h_bar = math.floor(height / beta / factor) * factor
-        w_bar = math.floor(width / beta / factor) * factor
+        h_bar = max(factor, math.floor(height / beta / factor) * factor)
+        w_bar = max(factor, math.floor(width / beta / factor) * factor)
     elif h_bar * w_bar < min_pixels:
         beta = math.sqrt(min_pixels / (height * width))
         h_bar = math.ceil(height * beta / factor) * factor
         w_bar = math.ceil(width * beta / factor) * factor
     return h_bar, w_bar
```
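To see what the new rounding does at the extremes, here is a standalone rerun of the function above with illustrative defaults — `factor=28` and the usual Qwen pixel budget are assumptions for the demo, not this repo's constants. It also exercises the edge case flagged in the comment above: a tiny dimension rounds to 0, and only the `min_pixels` branch pulls it back up.

```python
import math

def smart_resize(height, width, factor=28, min_pixels=4 * 28 * 28, max_pixels=16384 * 28 * 28):
    if max(height, width) / min(height, width) > 200:
        raise ValueError("absolute aspect ratio must be smaller than 200")
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = max(factor, math.floor(height / beta / factor) * factor)
        w_bar = max(factor, math.floor(width / beta / factor) * factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar

print(smart_resize(1080, 1920))  # (1092, 1932): both sides snapped to multiples of 28
print(smart_resize(10, 1500))    # (28, 700): 10 rounds to 0, then the min_pixels branch rescales it
```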
```diff
-def resize_image(image_file: Image.Image, size_factor: int = IMAGE_FACTOR) -> tuple[Image.Image, int, int]:
+def resize_image(
+    image_file: Image.Image, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
+) -> tuple[Image.Image, int, int]:
     image = image_file.convert("RGB")
     width, height = image.size
     resized_height, resized_width = smart_resize(
         height,
         width,
-        factor=size_factor,
-        min_pixels=MIN_PIXELS,
-        max_pixels=MAX_PIXELS,
+        factor=factor,
+        min_pixels=min_pixels,
+        max_pixels=max_pixels,
     )
     image = image.resize((resized_width, resized_height))
```
```diff
@@ -72,6 +74,7 @@ def resize_image(image_file: Image.Image, size_factor: int = IMAGE_FACTOR) -> tu
 class Qwen2VLImageProcessor(BaseImageProcessor):
     def __init__(
         self,
+        size: dict = None,
         do_resize: bool = True,
         resample: PILImageResampling = PILImageResampling.BICUBIC,
         do_rescale: bool = True,
@@ -88,6 +91,7 @@ def __init__(
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
+        self.size = size
        self.do_resize = do_resize
         self.resample = resample
         self.do_rescale = do_rescale
@@ -102,6 +106,13 @@ def __init__(
         self.temporal_patch_size = temporal_patch_size
         self.merge_size = merge_size
         self.data_format = ChannelDimension.FIRST
+        if isinstance(self.size, dict):
+            shortest = self.size.get("shortest_edge", None)
+            longest = self.size.get("longest_edge", None)
+            if shortest is not None:
+                self.min_pixels = shortest
+            if longest is not None:
+                self.max_pixels = longest

     def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]:
         if self.do_convert_rgb:
```
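For context, the new `size` handling lets an HF-style `preprocessor_config.json` override the pixel budget. A minimal sketch, assuming `shortest_edge`/`longest_edge` carry min/max pixel counts as in Qwen2-VL style configs (the numbers below are illustrative, not the class defaults):

```python
# Illustrative size dict as it would appear in a Qwen2-VL style preprocessor config.
size = {"shortest_edge": 256 * 28 * 28, "longest_edge": 1280 * 28 * 28}

min_pixels, max_pixels = 4 * 28 * 28, 16384 * 28 * 28  # assumed built-in defaults
if isinstance(size, dict):
    shortest = size.get("shortest_edge", None)
    longest = size.get("longest_edge", None)
    if shortest is not None:
        min_pixels = shortest  # shortest_edge overrides the minimum pixel count
    if longest is not None:
        max_pixels = longest   # longest_edge overrides the maximum pixel count
print(min_pixels, max_pixels)  # 200704 1003520
```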
New file: lightllm/models/qwen3_vl/infer_struct.py (@@ -0,0 +1,51 @@)

```python
import torch
import numpy as np
from lightllm.models.llama.infer_struct import LlamaInferStateInfo
from lightllm.common.basemodel.infer_struct import InferStateInfo


class Qwen3VLInferStateInfo(LlamaInferStateInfo):
    def __init__(self):
        super().__init__()
        self.deepstack_features = []
        self.img_first_token_locs = []
        self.img_last_token_locs = []

    def apply_interleaved_mrope(self, freqs, mrope_section):
        """Apply interleaved MRoPE to 3D rotary embeddings.
        Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
        interleaved [THTHWHTHW...TT], preserving frequency continuity.
        args:
            x: (3, bs, seq_len, head_dim // 2)
            mrope_section: (3,)
        returns:
            x_t: (bs, seq_len, head_dim // 2)
        """
        freqs_t = freqs[0]  # just overwrite the first dimension T
        for dim, offset in enumerate((1, 2), start=1):  # H, W
            length = mrope_section[dim] * 3
            idx = slice(offset, length, 3)
            freqs_t[..., idx] = freqs[dim, ..., idx]
        return freqs_t
```
Contributor (comment on lines +18 to +22): The docstring for `apply_interleaved_mrope` documents parameters `x` and `x_t`, but the function signature takes `freqs` and returns `freqs_t`; the names should be updated to match.
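A quick toy check of the interleaving (not part of the PR): tag each axis with a constant and watch the chunked [TTT...HHH...WWW] layout become [T H W T H W ...]. The loop is the same as in `apply_interleaved_mrope` above; `mrope_section = (2, 2, 2)` is a toy value chosen so `head_dim // 2 == 6`.

```python
import torch

mrope_section = (2, 2, 2)
# Mark each axis with a constant: T=0, H=1, W=2; shape (3, L=1, head_dim//2=6).
freqs = torch.stack([torch.full((1, 6), float(axis)) for axis in range(3)])

freqs_t = freqs[0].clone()
for dim, offset in enumerate((1, 2), start=1):  # H, W
    length = mrope_section[dim] * 3
    idx = slice(offset, length, 3)
    freqs_t[..., idx] = freqs[dim, ..., idx]
print(freqs_t)  # tensor([[0., 1., 2., 0., 1., 2.]]) -> T H W T H W
```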
The same file continues with `init_some_extra_state`, which builds the interleaved cos/sin tables:

```python
    def init_some_extra_state(self, model, input_ids: torch.Tensor):
        InferStateInfo.init_some_extra_state(self, model, input_ids)
        pos = self.position_ids[None, :].expand(3, -1)
        cos_T = torch.index_select(model._cos_cached, 0, pos[0])  # [L, d/2]
        cos_H = torch.index_select(model._cos_cached, 0, pos[1])
        cos_W = torch.index_select(model._cos_cached, 0, pos[2])
        sin_T = torch.index_select(model._sin_cached, 0, pos[0])
        sin_H = torch.index_select(model._sin_cached, 0, pos[1])
        sin_W = torch.index_select(model._sin_cached, 0, pos[2])
        cos_half = self.apply_interleaved_mrope(
            torch.stack([cos_T, cos_H, cos_W], dim=0), model.mrope_section
        )  # [L, d/2]
        sin_half = self.apply_interleaved_mrope(
            torch.stack([sin_T, sin_H, sin_W], dim=0), model.mrope_section
        )  # [L, d/2]

        self.position_cos = torch.cat([cos_half, cos_half], dim=-1).contiguous()  # [L, d]
        self.position_sin = torch.cat([sin_half, sin_half], dim=-1).contiguous()
        if self.is_prefill:
            pos = None
        return
```
New file (@@ -0,0 +1,8 @@):

```python
from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer


class Qwen3VLMultimodalPreLayerInfer(LlamaMultimodalPreLayerInfer):
    def __init__(self, network_config, mode):
        super().__init__(network_config, mode)
        self.use_deepstack = True
        return
```
New file (@@ -0,0 +1,64 @@):

```python
import torch
import torch.nn.functional as F
import torch.distributed as dist
import numpy as np
from functools import partial
from typing import Tuple
from lightllm.common.basemodel.infer_struct import InferStateInfo
from lightllm.models.qwen2_vl.triton_kernel.mrope import mrope_triton
from lightllm.models.qwen3.layer_infer.transformer_layer_infer import Qwen3TransformerLayerInfer
from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
from lightllm.models.llama.infer_struct import LlamaInferStateInfo
from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo
from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd
from lightllm.distributed import all_reduce
from lightllm.utils.dist_utils import get_global_world_size


class Qwen3VLTransformerLayerInfer(Qwen3TransformerLayerInfer):
    def __init__(self, layer_num, network_config, mode=[]):
        super().__init__(layer_num, network_config, mode)
        self.mrope_section = network_config["rope_scaling"]["mrope_section"]
        axis_map = []
        for i, n in enumerate(self.mrope_section * 2):
            axis_map += [i % 3] * n
        self.axis_map = torch.tensor(axis_map, dtype=torch.int32, device="cuda")
```
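For reference, here is what the `axis_map` loop in `__init__` produces on a toy `mrope_section` (real Qwen3-VL configs ship their own section sizes): each rotary dimension is tagged with its T/H/W axis index (0/1/2), and the pattern repeats for the second half of `head_dim` to match the duplicated cos/sin halves.

```python
mrope_section = [2, 1, 1]  # toy value; doubled so it covers both head_dim halves
axis_map = []
for i, n in enumerate(mrope_section * 2):
    axis_map += [i % 3] * n
print(axis_map)  # [0, 0, 1, 2, 0, 0, 1, 2] -> T T H W | T T H W
```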
The same file continues with `context_forward`, which injects the deepstack features after the FFN:

```python
    def context_forward(self, input_embdings, infer_state: Qwen3VLInferStateInfo, layer_weight):
        input1 = self._att_norm(input_embdings, infer_state, layer_weight)
        q, cache_kv = self._get_qkv(input1, infer_state, layer_weight)
        input1 = None
        self._post_cache_kv(cache_kv, infer_state, layer_weight)
        o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight)
        q = None
        o = self._get_o(o, infer_state, layer_weight)
        if self.tp_world_size_ > 1:
            all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
        input_embdings.add_(o.view(-1, self.embed_dim_))
        o = None

        input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
        ffn_out = self._ffn(input1, infer_state, layer_weight)
        input1 = None
        if self.tp_world_size_ > 1:
            all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
        input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
        if infer_state.deepstack_features:
            for i in range(len(infer_state.img_first_token_locs)):
                start = infer_state.img_first_token_locs[i]
                end = infer_state.img_last_token_locs[i]
                deepstack_features = infer_state.deepstack_features[i]
                if end <= input_embdings.shape[0] and self.layer_num_ in range(len(deepstack_features)):
                    deepstack_features_cur_layer = deepstack_features[self.layer_num_].to(
                        device=input_embdings.device, non_blocking=True
                    )
                    input_embdings[start:end].add_(deepstack_features_cur_layer)
            infer_state.img_first_token_locs = []
            infer_state.img_last_token_locs = []
            infer_state.deepstack_features = []
        return input_embdings
```
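As a stand-in for the deepstack branch above (shapes and values are illustrative, not from the PR), the injection is just an in-place add onto the image-token rows of the hidden states:

```python
import torch

hidden = torch.zeros(32, 8)                # (tokens_in_batch, hidden_dim)
deepstack_features = torch.ones(3, 16, 8)  # (deepstack_layers, img_tokens, hidden_dim)
layer_num, start, end = 1, 8, 24           # this layer adds its own feature slice
# `x in range(n)` in the code above is just a bounds check: 0 <= x < n.
if end <= hidden.shape[0] and layer_num < deepstack_features.shape[0]:
    hidden[start:end].add_(deepstack_features[layer_num])
print(hidden[7:10, 0])  # tensor([0., 1., 1.]) -> only the image rows were shifted
```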
New file (@@ -0,0 +1,37 @@):

```python
import torch
import numpy as np
from lightllm.common.basemodel import PreAndPostLayerWeight


class Qwen3VLPreAndPostLayerWeight(PreAndPostLayerWeight):
    def __init__(self, data_type, network_config, mode):
        super().__init__(data_type, network_config, mode)
        return

    def load_hf_weights(self, weights):
        vob_size = self.network_config_["vocab_size"]
        split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64)
        split_start = split_indexes[self.tp_rank_]
        split_end = split_indexes[self.tp_rank_ + 1]
        if "model.language_model.embed_tokens.weight" in weights:
            self.wte_weight_ = self._cuda(weights["model.language_model.embed_tokens.weight"][split_start:split_end, :])
            tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False)
            if tie_word_embeddings:
                self.lm_head_weight_ = self.wte_weight_
        if "lm_head.weight" in weights:
            self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :])
        if "model.language_model.norm.weight" in weights:
            self.final_norm_weight_ = self._cuda(weights["model.language_model.norm.weight"])

        return

    def verify_load(self):
        errors = "weights load not ok"
        weights = [
            self.wte_weight_,
            self.lm_head_weight_,
            self.final_norm_weight_,
        ]
        for i in range(len(weights)):
            assert weights[i] is not None, "index:" + str(i) + " " + errors
        return
```
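The `np.linspace` call in `load_hf_weights` is how the vocabulary rows are sharded across tensor-parallel ranks: it cuts the vocab into contiguous, nearly equal slices, one per rank. A toy run (sizes illustrative):

```python
import numpy as np

vob_size, tp_world_size = 10, 4
split_indexes = np.linspace(0, vob_size, tp_world_size + 1, dtype=np.int64)
print(split_indexes)  # [ 0  2  5  7 10]
for rank in range(tp_world_size):
    # each rank embeds/projects only its own row slice
    print("rank", rank, "rows", split_indexes[rank], "to", split_indexes[rank + 1])
```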
Contributor: This is a good safeguard to prevent running quantization on a model variant that does not support it yet. However, instead of raising a `ValueError`, it would be more informative to log a warning and skip quantization for this case, allowing the model to load and run in non-quantized mode. This would provide more flexibility.
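A minimal sketch of that suggestion, assuming a hypothetical quantization entry point — `maybe_apply_quantization`, `quant_cfg`, and the `supports_quantization` flag are illustrative names, not this repo's API:

```python
import logging

logger = logging.getLogger(__name__)

def maybe_apply_quantization(model, quant_cfg):
    # Warn and fall back to the unquantized path instead of raising.
    if quant_cfg is not None and not getattr(model, "supports_quantization", False):
        logger.warning(
            "Quantization is not supported for %s yet; loading unquantized weights instead.",
            type(model).__name__,
        )
        quant_cfg = None
    return quant_cfg
```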