diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py index 3e61178f3..e862d25ae 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py @@ -134,6 +134,8 @@ def _fuse(self): inter_shape, hidden_size = self.w2_list[0].shape[0], self.w2_list[0].shape[1] w2 = torch._utils._flatten_dense_tensors(self.w2_list).view(len(self.w2_list), inter_shape, hidden_size) + if self.fused_gate_up: + w2 = w2.transpose(1, 2).contiguous() if not self.quantized_weight and self.quant_method is not None: self.w1 = self.quant_method.quantize(w1) self.w2 = self.quant_method.quantize(w2) @@ -178,26 +180,53 @@ def _fuse_weight_scale(self): def load_hf_weights(self, weights): if self.e_score_correction_bias_name in weights: self.e_score_correction_bias = self._cuda(weights[self.e_score_correction_bias_name]) - for i_experts in range(self.n_routed_experts): - w1_weight = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.weight" - w2_weight = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.weight" - w3_weight = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.weight" + self.fused_gate_up = self.w3_weight_name is None # gate_up: [E,H,2I] down: [E,I,H] + key_gateup_3d = f"{self.weight_prefix}.{self.w1_weight_name}" # ...experts.gate_up_proj + key_down_3d = f"{self.weight_prefix}.{self.w2_weight_name}" - if w1_weight in weights: - self.experts_gate_projs[i_experts] = weights[w1_weight][ - self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), : - ] - if w3_weight in weights: - self.experts_up_projs[i_experts] = weights[w3_weight][ - self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), : - ] + if self.fused_gate_up and (key_gateup_3d in weights) and (key_down_3d in weights): + gate_up_3d = weights[key_gateup_3d] + down_3d = weights[key_down_3d] + assert gate_up_3d.dim() == 3 and down_3d.dim() == 3 - if w2_weight in weights: - self.w2_list[i_experts] = weights[w2_weight][ - :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1) - ] + E_ckpt, H_, twoE = gate_up_3d.shape + assert E_ckpt == self.n_routed_experts, f"experts mismatch: ckpt {E_ckpt} vs cfg {self.n_routed_experts}" + Eint_total = twoE // 2 + start, end = self.tp_rank_ * self.split_inter_size, (self.tp_rank_ + 1) * self.split_inter_size + assert end <= Eint_total, "TP split exceeds total expert-intermediate size" + + for i in range(self.n_routed_experts): + gu2d = gate_up_3d[i] + gate2d = gu2d[:, :Eint_total][:, start:end].t().contiguous() + up2d = gu2d[:, Eint_total:][:, start:end].t().contiguous() + self.experts_gate_projs[i] = gate2d + self.experts_up_projs[i] = up2d + + self.w2_list[i] = down_3d[i][start:end, :].contiguous() + else: + for i_experts in range(self.n_routed_experts): + w1_weight = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.weight" + w2_weight = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.weight" + w3_weight = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.weight" + + if w1_weight in weights: + self.experts_gate_projs[i_experts] = weights[w1_weight][ + self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), : + ] + if w3_weight in weights: + self.experts_up_projs[i_experts] = weights[w3_weight][ + self.split_inter_size * self.tp_rank_ : 
self.split_inter_size * (self.tp_rank_ + 1), : + ] + + if w2_weight in weights: + self.w2_list[i_experts] = weights[w2_weight][ + :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1) + ] if self.quant_method is not None: - self._load_weight_scale(weights) + if self.fused_gate_up: + raise ValueError("qwen3_vl_moe not support quant now") + else: + self._load_weight_scale(weights) self._fuse() def _load_weight_scale(self, weights: Dict[str, torch.Tensor]) -> None: diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py index d49d8fa75..3eab0f84f 100644 --- a/lightllm/models/__init__.py +++ b/lightllm/models/__init__.py @@ -29,6 +29,7 @@ from lightllm.models.internvl.model import InternVLInternlm2TpPartModel from lightllm.models.qwen2_vl.model import Qwen2VLTpPartModel from lightllm.models.qwen2_reward.model import Qwen2RewardTpPartModel +from lightllm.models.qwen3_vl.model import Qwen3VLTpPartModel, Qwen3VLMOETpPartModel from lightllm.models.gemma3.model import Gemma3TpPartModel from lightllm.models.tarsier2.model import ( Tarsier2Qwen2TpPartModel, diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py index abc258e8b..d7be45229 100644 --- a/lightllm/models/llama/model.py +++ b/lightllm/models/llama/model.py @@ -111,6 +111,8 @@ def _init_custom(self): if rope_scaling is None: self._init_to_get_rotary() return + if "mrope_section" in rope_scaling: + self.mrope_section = rope_scaling["mrope_section"] if "rope_type" in rope_scaling: scaling_type = rope_scaling["rope_type"] @@ -128,6 +130,8 @@ def _init_custom(self): self._init_to_get_llama3_rotary() elif scaling_type == "mrope": self._init_to_get_mrope_rotary() + elif scaling_type == "default": + self._init_to_get_rotary() else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") return @@ -204,7 +208,7 @@ def _init_to_get_rotary(self, default_base=10000): / rope_scaling_factor ) freqs = torch.outer(t, inv_freq) - + self.freqs = freqs.cuda() self._cos_cached = torch.cos(freqs).to(self.data_type).cuda() self._sin_cached = torch.sin(freqs).to(self.data_type).cuda() return diff --git a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py index 01e3a7268..c973716ff 100644 --- a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py +++ b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py @@ -383,7 +383,12 @@ def encode(self, images: List[ImageItem]): uuids.append(img.uuid) image_data = read_shm(get_shm_name_data(img.uuid)) image_data = Image.open(BytesIO(image_data)) - image_data = resize_image(image_data) + image_data = resize_image( + image_file=image_data, + factor=self.processor.patch_size * self.processor.merge_size, + min_pixels=self.processor.min_pixels, + max_pixels=self.processor.max_pixels, + ) pixel_values, image_grid_thw = self.processor.preprocess(image_data) img_tensors.append(pixel_values) img_grids.append(image_grid_thw) diff --git a/lightllm/models/qwen2_vl/model.py b/lightllm/models/qwen2_vl/model.py index 6d179e6f9..d98a0c130 100644 --- a/lightllm/models/qwen2_vl/model.py +++ b/lightllm/models/qwen2_vl/model.py @@ -51,8 +51,9 @@ def init_audioitem_extral_params( def get_image_token_length(self, img: ImageItem): width, height = img.image_w, img.image_h + factor = self.patch_size * self.merge_size resized_height, resized_width = smart_resize( - height=height, width=width, min_pixels=self.min_pixel, max_pixels=self.max_pixel + height=height, width=width, factor=factor, min_pixels=self.min_pixel, max_pixels=self.max_pixel ) grid_h, 
grid_w = resized_height // self.patch_size, resized_width // self.patch_size token_num = (grid_h * grid_w) // (self.merge_size ** 2) diff --git a/lightllm/models/qwen2_vl/qwen2_visual.py b/lightllm/models/qwen2_vl/qwen2_visual.py index 68e161737..af30bdbe0 100644 --- a/lightllm/models/qwen2_vl/qwen2_visual.py +++ b/lightllm/models/qwen2_vl/qwen2_visual.py @@ -311,7 +311,12 @@ def encode(self, images: List[ImageItem]): uuids.append(img.uuid) image_data = read_shm(get_shm_name_data(img.uuid)) image_data = Image.open(BytesIO(image_data)) - image_data = resize_image(image_data) + image_data = resize_image( + image_file=image_data, + factor=self.processor.patch_size * self.processor.merge_size, + min_pixels=self.processor.min_pixels, + max_pixels=self.processor.max_pixels, + ) pixel_values, image_grid_thw = self.processor.preprocess(image_data) img_tensors.append(pixel_values) img_grids.append(image_grid_thw) diff --git a/lightllm/models/qwen2_vl/vision_process.py b/lightllm/models/qwen2_vl/vision_process.py index 692f3aac3..75ea36f6d 100644 --- a/lightllm/models/qwen2_vl/vision_process.py +++ b/lightllm/models/qwen2_vl/vision_process.py @@ -35,16 +35,16 @@ def smart_resize( height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS ) -> tuple[int, int]: - if max(height, width) / min(height, width) > MAX_RATIO: + if max(height, width) / min(height, width) > 200: raise ValueError( - f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}" + f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}" ) - h_bar = max(factor, round(height / factor) * factor) - w_bar = max(factor, round(width / factor) * factor) + h_bar = round(height / factor) * factor + w_bar = round(width / factor) * factor if h_bar * w_bar > max_pixels: beta = math.sqrt((height * width) / max_pixels) - h_bar = math.floor(height / beta / factor) * factor - w_bar = math.floor(width / beta / factor) * factor + h_bar = max(factor, math.floor(height / beta / factor) * factor) + w_bar = max(factor, math.floor(width / beta / factor) * factor) elif h_bar * w_bar < min_pixels: beta = math.sqrt(min_pixels / (height * width)) h_bar = math.ceil(height * beta / factor) * factor @@ -52,7 +52,9 @@ def smart_resize( return h_bar, w_bar -def resize_image(image_file: Image.Image, size_factor: int = IMAGE_FACTOR) -> tuple[Image.Image, int, int]: +def resize_image( + image_file: Image.Image, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS +) -> tuple[Image.Image, int, int]: image = image_file.convert("RGB") width, height = image.size @@ -60,9 +62,9 @@ def resize_image(image_file: Image.Image, size_factor: int = IMAGE_FACTOR) -> tu resized_height, resized_width = smart_resize( height, width, - factor=size_factor, - min_pixels=MIN_PIXELS, - max_pixels=MAX_PIXELS, + factor=factor, + min_pixels=min_pixels, + max_pixels=max_pixels, ) image = image.resize((resized_width, resized_height)) @@ -72,6 +74,7 @@ def resize_image(image_file: Image.Image, size_factor: int = IMAGE_FACTOR) -> tu class Qwen2VLImageProcessor(BaseImageProcessor): def __init__( self, + size: dict = None, do_resize: bool = True, resample: PILImageResampling = PILImageResampling.BICUBIC, do_rescale: bool = True, @@ -88,6 +91,7 @@ def __init__( **kwargs, ) -> None: super().__init__(**kwargs) + self.size = size self.do_resize = do_resize self.resample = resample self.do_rescale = do_rescale @@ -102,6 +106,13 @@ def 
__init__( self.temporal_patch_size = temporal_patch_size self.merge_size = merge_size self.data_format = ChannelDimension.FIRST + if isinstance(self.size, dict): + shortest = self.size.get("shortest_edge", None) + longest = self.size.get("longest_edge", None) + if shortest is not None: + self.min_pixels = shortest + if longest is not None: + self.max_pixels = longest def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]: if self.do_convert_rgb: diff --git a/lightllm/models/qwen3_vl/infer_struct.py b/lightllm/models/qwen3_vl/infer_struct.py new file mode 100644 index 000000000..6c7ac96f9 --- /dev/null +++ b/lightllm/models/qwen3_vl/infer_struct.py @@ -0,0 +1,51 @@ +import torch +import numpy as np +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.common.basemodel.infer_struct import InferStateInfo + + +class Qwen3VLInferStateInfo(LlamaInferStateInfo): + def __init__(self): + super().__init__() + self.deepstack_features = [] + self.img_first_token_locs = [] + self.img_last_token_locs = [] + + def apply_interleaved_mrope(self, freqs, mrope_section): + """Apply interleaved MRoPE to 3D rotary embeddings. + Reorganizes frequency layout from chunked [TTT...HHH...WWW] to + interleaved [THTHWHTHW...TT], preserving frequency continuity. + args: + x: (3, bs, seq_len, head_dim // 2) + mrope_section: (3,) + returns: + x_t: (bs, seq_len, head_dim // 2) + """ + freqs_t = freqs[0] # just overwrite the first dimension T + for dim, offset in enumerate((1, 2), start=1): # H, W + length = mrope_section[dim] * 3 + idx = slice(offset, length, 3) + freqs_t[..., idx] = freqs[dim, ..., idx] + return freqs_t + + def init_some_extra_state(self, model, input_ids: torch.Tensor): + InferStateInfo.init_some_extra_state(self, model, input_ids) + pos = self.position_ids[None, :].expand(3, -1) + cos_T = torch.index_select(model._cos_cached, 0, pos[0]) # [L, d/2] + cos_H = torch.index_select(model._cos_cached, 0, pos[1]) + cos_W = torch.index_select(model._cos_cached, 0, pos[2]) + sin_T = torch.index_select(model._sin_cached, 0, pos[0]) + sin_H = torch.index_select(model._sin_cached, 0, pos[1]) + sin_W = torch.index_select(model._sin_cached, 0, pos[2]) + cos_half = self.apply_interleaved_mrope( + torch.stack([cos_T, cos_H, cos_W], dim=0), model.mrope_section + ) # [L, d/2] + sin_half = self.apply_interleaved_mrope( + torch.stack([sin_T, sin_H, sin_W], dim=0), model.mrope_section + ) # [L, d/2] + + self.position_cos = torch.cat([cos_half, cos_half], dim=-1).contiguous() # [L, d] + self.position_sin = torch.cat([sin_half, sin_half], dim=-1).contiguous() + if self.is_prefill: + pos = None + return diff --git a/lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py new file mode 100644 index 000000000..f82c4e55b --- /dev/null +++ b/lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py @@ -0,0 +1,8 @@ +from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer + + +class Qwen3VLMultimodalPreLayerInfer(LlamaMultimodalPreLayerInfer): + def __init__(self, network_config, mode): + super().__init__(network_config, mode) + self.use_deepstack = True + return diff --git a/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py new file mode 100644 index 000000000..a50f817a6 --- /dev/null +++ b/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py @@ -0,0 +1,64 @@ +import torch +import torch.functional as F +import 
torch.distributed as dist +import numpy as np +from functools import partial +from typing import Tuple +from lightllm.common.basemodel.infer_struct import InferStateInfo +from lightllm.models.qwen2_vl.triton_kernel.mrope import mrope_triton +from lightllm.models.qwen3.layer_infer.transformer_layer_infer import Qwen3TransformerLayerInfer +from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight +from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo +from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward +from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd +from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd +from lightllm.distributed import all_reduce +from lightllm.utils.dist_utils import get_global_world_size + + +class Qwen3VLTransformerLayerInfer(Qwen3TransformerLayerInfer): + def __init__(self, layer_num, network_config, mode=[]): + super().__init__(layer_num, network_config, mode) + self.mrope_section = network_config["rope_scaling"]["mrope_section"] + axis_map = [] + for i, n in enumerate(self.mrope_section * 2): + axis_map += [i % 3] * n + self.axis_map = torch.tensor(axis_map, dtype=torch.int32, device="cuda") + + def context_forward(self, input_embdings, infer_state: Qwen3VLInferStateInfo, layer_weight): + input1 = self._att_norm(input_embdings, infer_state, layer_weight) + q, cache_kv = self._get_qkv(input1, infer_state, layer_weight) + input1 = None + self._post_cache_kv(cache_kv, infer_state, layer_weight) + o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight) + q = None + o = self._get_o(o, infer_state, layer_weight) + if self.tp_world_size_ > 1: + all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + input_embdings.add_(o.view(-1, self.embed_dim_)) + o = None + + input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) + ffn_out = self._ffn(input1, infer_state, layer_weight) + input1 = None + if self.tp_world_size_ > 1: + all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) + if infer_state.deepstack_features: + for i in range(len(infer_state.img_first_token_locs)): + start = infer_state.img_first_token_locs[i] + end = infer_state.img_last_token_locs[i] + deepstack_features = infer_state.deepstack_features[i] + if end <= input_embdings.shape[0] and self.layer_num_ in range(len(deepstack_features)): + deepstack_features_cur_layer = deepstack_features[self.layer_num_].to( + device=input_embdings.device, non_blocking=True + ) + input_embdings[ + start:end, + ].add_(deepstack_features_cur_layer) + infer_state.img_first_token_locs = [] + infer_state.img_last_token_locs = [] + infer_state.deepstack_features = [] + return input_embdings diff --git a/lightllm/models/qwen3_vl/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_vl/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 000000000..98aae3f37 --- /dev/null +++ b/lightllm/models/qwen3_vl/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,37 @@ +import torch +import numpy as np +from lightllm.common.basemodel import PreAndPostLayerWeight + + +class Qwen3VLPreAndPostLayerWeight(PreAndPostLayerWeight): + def __init__(self, data_type, network_config, mode): 
+ super().__init__(data_type, network_config, mode) + return + + def load_hf_weights(self, weights): + vob_size = self.network_config_["vocab_size"] + split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) + split_start = split_indexes[self.tp_rank_] + split_end = split_indexes[self.tp_rank_ + 1] + if "model.language_model.embed_tokens.weight" in weights: + self.wte_weight_ = self._cuda(weights["model.language_model.embed_tokens.weight"][split_start:split_end, :]) + tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) + if tie_word_embeddings: + self.lm_head_weight_ = self.wte_weight_ + if "lm_head.weight" in weights: + self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :]) + if "model.language_model.norm.weight" in weights: + self.final_norm_weight_ = self._cuda(weights["model.language_model.norm.weight"]) + + return + + def verify_load(self): + errors = "weights load not ok" + weights = [ + self.wte_weight_, + self.lm_head_weight_, + self.final_norm_weight_, + ] + for i in range(len(weights)): + assert weights[i] is not None, "index:" + str(i) + " " + errors + return diff --git a/lightllm/models/qwen3_vl/layer_weights/transformers_layer_weight.py b/lightllm/models/qwen3_vl/layer_weights/transformers_layer_weight.py new file mode 100644 index 000000000..f51025527 --- /dev/null +++ b/lightllm/models/qwen3_vl/layer_weights/transformers_layer_weight.py @@ -0,0 +1,44 @@ +import os +import torch +import math +import numpy as np +from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight +from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ( + ROWMMWeight, + MultiROWMMWeight, + COLMMWeight, + NormWeight, + FusedMoeWeightTP, + FusedMoeWeightEP, + ROWBMMWeight, +) + + +class Qwen3VLTransformerLayerWeight(Qwen3TransformerLayerWeight): # TODO: check later whether this still needs changes + def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): + super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + + def _init_weight_names(self): + super()._init_weight_names() + self._q_weight_name = f"model.language_model.layers.{self.layer_num_}.self_attn.q_proj.weight" + self._q_norm_name = f"model.language_model.layers.{self.layer_num_}.self_attn.q_norm.weight" + self._q_bias_name = None + self._k_weight_name = f"model.language_model.layers.{self.layer_num_}.self_attn.k_proj.weight" + self._k_norm_name = f"model.language_model.layers.{self.layer_num_}.self_attn.k_norm.weight" + self._k_bias_name = None + self._v_weight_name = f"model.language_model.layers.{self.layer_num_}.self_attn.v_proj.weight" + self._v_bias_name = None + self._o_weight_name = f"model.language_model.layers.{self.layer_num_}.self_attn.o_proj.weight" + self._o_bias_name = None + self._att_norm_weight_name = f"model.language_model.layers.{self.layer_num_}.input_layernorm.weight" + self._att_norm_bias_name = None + self._ffn_norm_weight_name = f"model.language_model.layers.{self.layer_num_}.post_attention_layernorm.weight" + self._ffn_norm_bias_name = None + + self._gate_weight_name = f"model.language_model.layers.{self.layer_num_}.mlp.gate_proj.weight" + self._gate_bias_name = None + self._up_weight_name = f"model.language_model.layers.{self.layer_num_}.mlp.up_proj.weight" + self._up_bias_name = None + self._down_weight_name = f"model.language_model.layers.{self.layer_num_}.mlp.down_proj.weight" 
+ self._down_bias_name = None diff --git a/lightllm/models/qwen3_vl/model.py b/lightllm/models/qwen3_vl/model.py new file mode 100644 index 000000000..6fed68488 --- /dev/null +++ b/lightllm/models/qwen3_vl/model.py @@ -0,0 +1,155 @@ +import json +import numpy as np +import unicodedata +from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer +from lightllm.models.qwen.model import QWenTpPartModel +from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer +from lightllm.server.multimodal_params import AudioItem, MultimodalParams, ImageItem +from transformers.feature_extraction_utils import BatchFeature +from transformers.image_utils import ImageInput +from transformers.processing_utils import ProcessorMixin +from lightllm.server.core.objs import SamplingParams +from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy +from typing import List, Optional, Union +from transformers.utils import TensorType, logging +from lightllm.models.qwen2_vl.flashattention_infer_struct import Qwen2VLFlashAttentionStateInfo +from lightllm.common.build_utils import repair_config +from lightllm.models.registry import ModelRegistry +from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo +from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import Qwen3VLMultimodalPreLayerInfer +from lightllm.models.qwen3_vl.layer_infer.transformer_layer_infer import Qwen3VLTransformerLayerInfer +from lightllm.models.qwen3_vl.layer_weights.pre_and_post_layer_weight import Qwen3VLPreAndPostLayerWeight +from lightllm.models.qwen3_vl.layer_weights.transformers_layer_weight import Qwen3VLTransformerLayerWeight +from lightllm.models.qwen3_vl_moe.layer_weights.transformers_layer_weight import Qwen3VLMOETransformerLayerWeight +from lightllm.models.qwen3_vl_moe.layer_infer.transformer_layer_infer import Qwen3VLMOETransformerLayerInfer + +import torch +from PIL import Image +from lightllm.models.qwen2_vl.vision_process import smart_resize +from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args +from lightllm.models.qwen3.model import Qwen3TpPartModel +from lightllm.models.qwen3_moe.model import Qwen3MOEModel +import os + + +class QWen3VLTokenizer(BaseMultiModalTokenizer): + def __init__(self, tokenizer=None, image_processor=None, **kwargs): + super().__init__(tokenizer) + self.image_processor = image_processor + self.min_pixel = self.image_processor.size["shortest_edge"] + self.max_pixel = self.image_processor.size["longest_edge"] + self.patch_size = self.image_processor.patch_size + self.merge_size = self.image_processor.merge_size + self.image_start_id = kwargs["model_cfg"]["vision_start_token_id"] + self.image_end_id = kwargs["model_cfg"]["vision_end_token_id"] + self.image_token_id = kwargs["model_cfg"]["image_token_id"] + + def init_imageitem_extral_params( + self, img: ImageItem, multi_params: MultimodalParams, sampling_params: SamplingParams + ): + return + + def init_audioitem_extral_params( + self, audio: AudioItem, multi_params: MultimodalParams, sampling_params: SamplingParams + ): + raise NotImplementedError + + def get_image_token_length(self, img: ImageItem): + width, height = img.image_w, img.image_h + factor = self.patch_size * self.merge_size + resized_height, resized_width = smart_resize( + height=height, width=width, factor=factor, min_pixels=self.min_pixel, max_pixels=self.max_pixel + ) + grid_h, grid_w = resized_height // self.patch_size, resized_width // 
self.patch_size + token_num = (grid_h * grid_w) // (self.merge_size ** 2) + return token_num + + def get_audio_token_length(self, audio: AudioItem): + raise NotImplementedError + + def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs): + + origin_ids = self.tokenizer.encode(prompt) + + # -> + origin_ids = [token for token in origin_ids if token != self.image_token_id] + # --> id,id+1...id+num + input_ids = [] + image_id = 0 + start_idx = 0 + while True: + try: + start_idx = origin_ids.index(self.image_start_id, start_idx) + if start_idx + 1 >= len(origin_ids): + break + if origin_ids[start_idx + 1] == self.image_end_id: + input_ids.extend(origin_ids[: start_idx + 1]) + token_id = multimodal_params.images[image_id].token_id + token_num = multimodal_params.images[image_id].token_num + input_ids.extend(range(token_id, token_id + token_num)) + input_ids.append(self.image_end_id) + origin_ids = origin_ids[start_idx + 2 :] + start_idx = 0 + image_id += 1 + else: + raise ValueError("image token error") + except ValueError: + break + input_ids.extend(origin_ids[start_idx:]) + return input_ids + + +@ModelRegistry(["qwen3_vl"], is_multimodal=True) +class Qwen3VLTpPartModel(Qwen3TpPartModel): + + pre_layer_infer_class = Qwen3VLMultimodalPreLayerInfer + transformer_layer_infer_class = Qwen3VLTransformerLayerInfer + + pre_and_post_weight_class = Qwen3VLPreAndPostLayerWeight + transformer_weight_class = Qwen3VLTransformerLayerWeight + + infer_state_class = Qwen3VLInferStateInfo + + def __init__(self, kvargs): + super().__init__(kvargs) + return + + def _init_config(self): + with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: + all_config = json.load(json_file) + self.config = all_config["text_config"] + # rename keys + repair_config(self.config, same_names=["num_attention_heads", "n_head"]) + repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"]) + repair_config(self.config, same_names=["num_hidden_layers", "n_layer"]) + if self.finetune_config: + self.config["vocab_size"] = self.finetune_config.vocab_size + return + + +@ModelRegistry(["qwen3_vl_moe"], is_multimodal=True) +class Qwen3VLMOETpPartModel(Qwen3MOEModel): + + pre_layer_infer_class = Qwen3VLMultimodalPreLayerInfer + transformer_layer_infer_class = Qwen3VLMOETransformerLayerInfer + + pre_and_post_weight_class = Qwen3VLPreAndPostLayerWeight + transformer_weight_class = Qwen3VLMOETransformerLayerWeight + + infer_state_class = Qwen3VLInferStateInfo + + def __init__(self, kvargs): + super().__init__(kvargs) + return + + def _init_config(self): + with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: + all_config = json.load(json_file) + self.config = all_config["text_config"] + # rename keys + repair_config(self.config, same_names=["num_attention_heads", "n_head"]) + repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"]) + repair_config(self.config, same_names=["num_hidden_layers", "n_layer"]) + if self.finetune_config: + self.config["vocab_size"] = self.finetune_config.vocab_size + return diff --git a/lightllm/models/qwen3_vl/qwen3_visual.py b/lightllm/models/qwen3_vl/qwen3_visual.py new file mode 100644 index 000000000..02b7fa73e --- /dev/null +++ b/lightllm/models/qwen3_vl/qwen3_visual.py @@ -0,0 +1,397 @@ +# coding=utf-8 +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import json +from PIL import Image +from io import BytesIO +from typing import List +from safetensors import safe_open +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.activations import ACT2FN + +from lightllm.server.multimodal_params import ImageItem +from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data +from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor +from lightllm.models.qwen2_vl.qwen2_visual import VisionRotaryEmbedding, VisionFlashAttention + + +class Qwen3VLVisionMLP(nn.Module): + def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True) + self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True) + self.act_fn = ACT2FN[hidden_act] + + def forward(self, hidden_state): + return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state))) + + +class Qwen3VLPatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + embed_dim: int = 1152, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + + kernel_size = [temporal_patch_size, patch_size, patch_size] + self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + if hidden_states.dtype != target_dtype: + hidden_states = hidden_states.to(target_dtype) + hidden_states = hidden_states.view( + -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size + ) + hidden_states = self.proj(hidden_states).view(-1, self.embed_dim) + return hidden_states + + +class Qwen3VLVisionPatchMerger(nn.Module): + def __init__(self, hidden_size, out_hidden_size, spatial_merge_size, use_postshuffle_norm=False) -> None: + super().__init__() + self.hidden_size = hidden_size * (spatial_merge_size ** 2) + self.use_postshuffle_norm = use_postshuffle_norm + self.norm = nn.LayerNorm(self.hidden_size if use_postshuffle_norm else hidden_size, eps=1e-6) + self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size) + self.act_fn = nn.GELU() + self.linear_fc2 = nn.Linear(self.hidden_size, out_hidden_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.norm(x.view(-1, self.hidden_size) if self.use_postshuffle_norm else x).view(-1, self.hidden_size) + x = self.linear_fc2(self.act_fn(self.linear_fc1(x))) + return x + + +class Qwen3VLVisionBlock(nn.Module): + def __init__(self, hidden_size, intermediate_size, num_heads, hidden_act) -> None: + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, eps=1e-6) + 
self.norm2 = nn.LayerNorm(hidden_size, eps=1e-6) + + self.attn = VisionFlashAttention(hidden_size, num_heads=num_heads) + self.mlp = Qwen3VLVisionMLP(hidden_size, intermediate_size, hidden_act) + + def forward(self, hidden_states, cu_seqlens, max_seqlen, rotary_cos, rotary_sin) -> torch.Tensor: + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + rotary_cos=rotary_cos, + rotary_sin=rotary_sin, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class Qwen3VisionTransformerPretrainedModel(nn.Module): + def __init__( + self, + kvargs, + depth=27, + out_hidden_size=4096, + hidden_size=1152, + hidden_act="gelu_pytorch_tanh", + intermediate_size=4304, + deepstack_visual_indexes=[8, 16, 24], + num_heads=16, + in_channels=3, + patch_size=16, + spatial_merge_size=2, + temporal_patch_size=2, + num_position_embeddings=2304, + **kwargs, + ): + super().__init__() + self.data_type = kvargs.get("data_type", "bfloat16") + + self.depth = depth + self.out_hidden_size = out_hidden_size + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.num_heads = num_heads + self.in_channels = in_channels + self.patch_size = patch_size + self.intermediate_size = intermediate_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + self.num_position_embeddings = num_position_embeddings + self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size + + self.patch_embed = Qwen3VLPatchEmbed( + patch_size=self.patch_size, + temporal_patch_size=self.temporal_patch_size, + in_channels=self.in_channels, + embed_dim=self.hidden_size, + ) + + self.pos_embed = nn.Embedding(self.num_position_embeddings, self.hidden_size) + self.num_grid_per_side = int(self.num_position_embeddings ** 0.5) + + head_dim = self.hidden_size // self.num_heads + self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2).cuda() + + self.blocks = nn.ModuleList( + [ + Qwen3VLVisionBlock(self.hidden_size, self.intermediate_size, self.num_heads, self.hidden_act) + for _ in range(self.depth) + ] + ) + self.merger = Qwen3VLVisionPatchMerger( + hidden_size=self.hidden_size, + out_hidden_size=self.out_hidden_size, + spatial_merge_size=self.spatial_merge_size, + use_postshuffle_norm=False, + ) + self.deepstack_visual_indexes = deepstack_visual_indexes + self.deepstack_merger_list = nn.ModuleList( + [ + Qwen3VLVisionPatchMerger( + hidden_size=self.hidden_size, + out_hidden_size=self.out_hidden_size, + spatial_merge_size=self.spatial_merge_size, + use_postshuffle_norm=True, + ) + for _ in range(len(self.deepstack_visual_indexes)) + ] + ) + self._init_datatype() + + def _init_datatype(self): + if isinstance(self.data_type, torch.dtype): + return + if self.data_type in ["fp16", "float16"]: + self.data_type = torch.float16 + elif self.data_type in ["bf16", "bfloat16"]: + self.data_type = torch.bfloat16 + elif self.data_type in ["fp32", "float32"]: + self.data_type = torch.float32 + else: + raise ValueError(f"Unsupport datatype {self.data_type}!") + return + + def load_model(self, weight_dir): + + processor_config_path = os.path.join(weight_dir, "preprocessor_config.json") + with open(processor_config_path, "r") as f: + processor_config_dict = json.load(f) + self.processor = Qwen2VLImageProcessor(**processor_config_dict) + + bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")] + if bin_weight_files: + weight_dict = {} + for file_ in 
bin_weight_files: + f = torch.load(os.path.join(weight_dir, file_), "cpu") + for k, v in f.items(): + if "model.visual" in k: + weight_dict[k[len("model.visual.") :]] = v + else: + hf_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".safetensors")] + weight_dict = {} + for file_ in hf_weight_files: + f = safe_open(os.path.join(weight_dir, file_), "pt", "cpu") + for k in f.keys(): + if "model.visual" in k: + weight_dict[k[len("model.visual.") :]] = f.get_tensor(k) + + self.load_state_dict(weight_dict) + + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + merge_size = self.spatial_merge_size + + total_tokens = int(torch.prod(grid_thw, dim=1).sum().item()) + pos_ids = torch.empty((total_tokens, 2), dtype=torch.long) + + offset = 0 + for num_frames, height, width in grid_thw: + merged_h, merged_w = height // merge_size, width // merge_size + + block_rows = torch.arange(merged_h) # block row indices + block_cols = torch.arange(merged_w) # block col indices + intra_row = torch.arange(merge_size) # intra-block row offsets + intra_col = torch.arange(merge_size) # intra-block col offsets + + # Compute full-resolution positions + row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None] + col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :] + + row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + + coords = torch.stack((row_idx, col_idx), dim=-1) + + if num_frames > 1: + coords = coords.repeat(num_frames, 1) + + num_tokens = coords.shape[0] + pos_ids[offset : offset + num_tokens] = coords + offset += num_tokens + + max_hw = int(grid_thw[:, 1:].max().item()) + cos_full, sin_full = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2) + cos = cos_full[pos_ids].flatten(1) + sin = sin_full[pos_ids].flatten(1) + return cos, sin + + def fast_pos_embed_interpolate(self, grid_thw): + grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2] + + idx_list = [[] for _ in range(4)] + weight_list = [[] for _ in range(4)] + + for t, h, w in zip(grid_ts, grid_hs, grid_ws): + h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) + w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) + + h_idxs_floor = h_idxs.int() + w_idxs_floor = w_idxs.int() + h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + + dh = h_idxs - h_idxs_floor + dw = w_idxs - w_idxs_floor + + base_h = h_idxs_floor * self.num_grid_per_side + base_h_ceil = h_idxs_ceil * self.num_grid_per_side + + indices = [ + (base_h[None].T + w_idxs_floor[None]).flatten(), + (base_h[None].T + w_idxs_ceil[None]).flatten(), + (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), + (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), + ] + + weights = [ + ((1 - dh)[None].T * (1 - dw)[None]).flatten(), + ((1 - dh)[None].T * dw[None]).flatten(), + (dh[None].T * (1 - dw)[None]).flatten(), + (dh[None].T * dw[None]).flatten(), + ] + + for i in range(4): + idx_list[i].extend(indices[i].tolist()) + weight_list[i].extend(weights[i].tolist()) + + idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=self.pos_embed.weight.device) + weight_tensor = torch.tensor( + weight_list, dtype=self.pos_embed.weight.dtype, device=self.pos_embed.weight.device + ) + pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None] + patch_pos_embeds = 
pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] + + patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)]) + + patch_pos_embeds_permute = [] + merge_size = self.spatial_merge_size + for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): + pos_embed = pos_embed.repeat(t, 1) + pos_embed = ( + pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1) + .permute(0, 1, 3, 2, 4, 5) + .flatten(0, 4) + ) + patch_pos_embeds_permute.append(pos_embed) + patch_pos_embeds = torch.cat(patch_pos_embeds_permute) + return patch_pos_embeds + + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: + hidden_states = self.patch_embed(hidden_states) + + pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + hidden_states = hidden_states + pos_embeds + + rotary_cos, rotary_sin = self.rot_pos_emb(grid_thw) + rotary_cos = rotary_cos.to("cuda", non_blocking=True) + rotary_sin = rotary_sin.to("cuda", non_blocking=True) + + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, + dtype=torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to("cuda", non_blocking=True) + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + + deepstack_feature_lists = [] + for layer_num, blk in enumerate(self.blocks): + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + rotary_cos=rotary_cos, + rotary_sin=rotary_sin, + ) + if layer_num in self.deepstack_visual_indexes: + deepstack_feature = self.deepstack_merger_list[self.deepstack_visual_indexes.index(layer_num)]( + hidden_states + ) + deepstack_feature_lists.append(deepstack_feature) + + hidden_states = self.merger(hidden_states) + + return hidden_states, deepstack_feature_lists + + def encode(self, images: List[ImageItem]): + img_tensors = [] + valid_ids = [] + valid_id = 0 + img_grids = [] + uuids = [] + for i, img in enumerate(images): + if isinstance(img, ImageItem): + uuids.append(img.uuid) + image_data = read_shm(get_shm_name_data(img.uuid)) + image_data = Image.open(BytesIO(image_data)) + image_data = resize_image( + image_file=image_data, + factor=self.processor.patch_size * self.processor.merge_size, + min_pixels=self.processor.min_pixels, + max_pixels=self.processor.max_pixels, + ) + pixel_values, image_grid_thw = self.processor.preprocess(image_data) + img_tensors.append(pixel_values) + img_grids.append(image_grid_thw) + else: + raise Exception("Unsupported input type: {} for {}".format(type(img), img)) + + # patch count must be divisible by the merge length + cur_num = img_tensors[-1].shape[0] // (self.spatial_merge_size ** 2) + + valid_ids.append([valid_id, valid_id + cur_num]) + valid_id += cur_num + + if len(img_tensors) <= 0: + return None + + imgs = torch.cat(img_tensors, dim=0) + grid_thw = torch.cat(img_grids, dim=0) + + pixel_values = imgs.to("cuda", dtype=self.data_type, non_blocking=True) + image_grid_thw = grid_thw.to("cuda", non_blocking=True) + + all_img_embeds, deepstack_feature_lists = self.forward(pixel_values, grid_thw=image_grid_thw) + + return all_img_embeds, uuids, valid_ids, deepstack_feature_lists diff --git a/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py new file mode 100644 index 000000000..3cf4b0d2d --- /dev/null +++ b/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py @@ -0,0 +1,96 @@ +import torch +import torch.functional as 
F +import torch.distributed as dist +import numpy as np +from functools import partial +from typing import Tuple +from lightllm.common.basemodel.infer_struct import InferStateInfo +from lightllm.models.qwen2_vl.triton_kernel.mrope import mrope_triton +from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer +from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight +from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo +from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward +from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd +from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd +from lightllm.distributed import all_reduce +from lightllm.utils.dist_utils import get_global_world_size + + +class Qwen3VLMOETransformerLayerInfer(Qwen3MOETransformerLayerInfer): + def __init__(self, layer_num, network_config, mode=[]): + super().__init__(layer_num, network_config, mode) + self.mrope_section = network_config["rope_scaling"]["mrope_section"] + axis_map = [] + for i, n in enumerate(self.mrope_section * 2): + axis_map += [i % 3] * n + self.axis_map = torch.tensor(axis_map, dtype=torch.int32, device="cuda") + + def _get_qkv( + self, + input: torch.Tensor, + infer_state: Qwen3VLInferStateInfo, + layer_weight: Qwen3MOETransformerLayerWeight, + ) -> Tuple[torch.Tensor, torch.Tensor]: + input = input.view(-1, self.embed_dim_) + q = layer_weight.q_proj.mm(input) + cache_kv = self._pre_cache_kv(infer_state=infer_state, layer_weight=layer_weight) + cache_kv = layer_weight.kv_proj.mm( + input, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_) + ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) + rmsnorm_forward( + q.view(-1, self.head_dim_), + weight=layer_weight.q_norm_weight_.weight, + eps=self.eps_, + out=q.view(-1, self.head_dim_), + ) + + cache_kv[:, : self.tp_k_head_num_, :] = rmsnorm_forward( + cache_kv[:, : self.tp_k_head_num_, :].reshape(-1, cache_kv.shape[-1]), + weight=layer_weight.k_norm_weight_.weight, + eps=self.eps_, + ).view(-1, self.tp_k_head_num_, cache_kv.shape[-1]) + rotary_emb_fwd( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + cache_kv[:, : self.tp_k_head_num_, :], + infer_state.position_cos, + infer_state.position_sin, + ) + return q, cache_kv + + def context_forward(self, input_embdings, infer_state: Qwen3VLInferStateInfo, layer_weight): + input1 = self._att_norm(input_embdings, infer_state, layer_weight) + q, cache_kv = self._get_qkv(input1, infer_state, layer_weight) + input1 = None + self._post_cache_kv(cache_kv, infer_state, layer_weight) + o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight) + q = None + o = self._get_o(o, infer_state, layer_weight) + if self.tp_world_size_ > 1: + all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + input_embdings.add_(o.view(-1, self.embed_dim_)) + o = None + + input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) + ffn_out = self._ffn(input1, infer_state, layer_weight) + input1 = None + if self.tp_world_size_ > 1: + all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) + if 
infer_state.deepstack_features: + for i in range(len(infer_state.img_first_token_locs)): + start = infer_state.img_first_token_locs[i] + end = infer_state.img_last_token_locs[i] + deepstack_features = infer_state.deepstack_features[i] + if end <= input_embdings.shape[0] and self.layer_num_ in range(len(deepstack_features)): + deepstack_features_cur_layer = deepstack_features[self.layer_num_].to( + device=input_embdings.device, non_blocking=True + ) + input_embdings[ + start:end, + ].add_(deepstack_features_cur_layer) + infer_state.img_first_token_locs = [] + infer_state.img_last_token_locs = [] + infer_state.deepstack_features = [] + return input_embdings diff --git a/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 000000000..0a7f82a93 --- /dev/null +++ b/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,37 @@ +import torch +import numpy as np +from lightllm.common.basemodel import PreAndPostLayerWeight + + +class Qwen3VLMOEPreAndPostLayerWeight(PreAndPostLayerWeight): + def __init__(self, data_type, network_config, mode): + super().__init__(data_type, network_config, mode) + return + + def load_hf_weights(self, weights): + vob_size = self.network_config_["vocab_size"] + split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) + split_start = split_indexes[self.tp_rank_] + split_end = split_indexes[self.tp_rank_ + 1] + if "model.language_model.embed_tokens.weight" in weights: + self.wte_weight_ = self._cuda(weights["model.language_model.embed_tokens.weight"][split_start:split_end, :]) + tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) + if tie_word_embeddings: + self.lm_head_weight_ = self.wte_weight_ + if "lm_head.weight" in weights: + self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :]) + if "model.language_model.norm.weight" in weights: + self.final_norm_weight_ = self._cuda(weights["model.language_model.norm.weight"]) + + return + + def verify_load(self): + errors = "weights load not ok" + weights = [ + self.wte_weight_, + self.lm_head_weight_, + self.final_norm_weight_, + ] + for i in range(len(weights)): + assert weights[i] is not None, "index:" + str(i) + " " + errors + return diff --git a/lightllm/models/qwen3_vl_moe/layer_weights/transformers_layer_weight.py b/lightllm/models/qwen3_vl_moe/layer_weights/transformers_layer_weight.py new file mode 100644 index 000000000..3a3af7e5a --- /dev/null +++ b/lightllm/models/qwen3_vl_moe/layer_weights/transformers_layer_weight.py @@ -0,0 +1,81 @@ +import os +import torch +import math +import numpy as np +from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight +from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ( + ROWMMWeight, + MultiROWMMWeight, + COLMMWeight, + NormWeight, + FusedMoeWeightTP, + FusedMoeWeightEP, + ROWBMMWeight, +) + + +class Qwen3VLMOETransformerLayerWeight(Qwen3MOETransformerLayerWeight): + def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): + super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + + def _init_weight_names(self): + super()._init_weight_names() + self._q_weight_name = f"model.language_model.layers.{self.layer_num_}.self_attn.q_proj.weight" + 
self._q_norm_name = f"model.language_model.layers.{self.layer_num_}.self_attn.q_norm.weight" + self._q_bias_name = None + self._k_weight_name = f"model.language_model.layers.{self.layer_num_}.self_attn.k_proj.weight" + self._k_norm_name = f"model.language_model.layers.{self.layer_num_}.self_attn.k_norm.weight" + self._k_bias_name = None + self._v_weight_name = f"model.language_model.layers.{self.layer_num_}.self_attn.v_proj.weight" + self._v_bias_name = None + self._o_weight_name = f"model.language_model.layers.{self.layer_num_}.self_attn.o_proj.weight" + self._o_bias_name = None + self._att_norm_weight_name = f"model.language_model.layers.{self.layer_num_}.input_layernorm.weight" + self._att_norm_bias_name = None + self._ffn_norm_weight_name = f"model.language_model.layers.{self.layer_num_}.post_attention_layernorm.weight" + self._ffn_norm_bias_name = None + + def _init_moe(self): + moe_intermediate_size = self.network_config_["moe_intermediate_size"] + self.moe_gate = ROWMMWeight( + weight_name=f"model.language_model.layers.{self.layer_num_}.mlp.gate.weight", + data_type=self.data_type_, + layer_num=self.layer_num_, + name="moe_gate", + tp_rank=0, + tp_world_size=1, + ) + moe_mode = os.getenv("MOE_MODE", "TP") + assert moe_mode in ["EP", "TP"] + + if moe_mode == "TP": + self.experts = FusedMoeWeightTP( + gate_proj_name="gate_up_proj", + down_proj_name="down_proj", + up_proj_name=None, + e_score_correction_bias_name="", + weight_prefix=f"model.language_model.layers.{self.layer_num_}.mlp.experts", + n_routed_experts=self.n_routed_experts, + split_inter_size=moe_intermediate_size // self.tp_world_size_, + data_type=self.data_type_, + network_config=self.network_config_, + layer_num=self.layer_num_, + quant_cfg=self.quant_cfg, + num_fused_shared_experts=0, + ) + elif moe_mode == "EP": + self.experts = FusedMoeWeightEP( + gate_proj_name="gate_up_proj", + down_proj_name="down_proj", + up_proj_name=None, + e_score_correction_bias_name="", + weight_prefix=f"model.language_model.layers.{self.layer_num_}.mlp.experts", + n_routed_experts=self.n_routed_experts, + data_type=self.data_type_, + network_config=self.network_config_, + layer_num=self.layer_num_, + quant_cfg=self.quant_cfg, + ) + else: + raise ValueError(f"Unsupported moe mode: {moe_mode}") diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py index b5b31a413..fa8e7b414 100644 --- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py @@ -6,7 +6,13 @@ from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer from lightllm.utils.infer_utils import mark_cost_time -from lightllm.server.embed_cache.utils import bytes2tensor, read_shm, get_shm_name_embed +from lightllm.server.embed_cache.utils import ( + bytes2tensor, + read_shm, + get_shm_name_embed, + get_shm_name_deepstack, + bytes2list, +) from lightllm.common.basemodel.triton_kernel.multimodal_emb import multimodal_emb from lightllm.distributed.communication_op import all_reduce @@ -29,6 +35,7 @@ class LlamaMultimodalPreLayerInfer(LlamaPreLayerInfer): def __init__(self, network_config, mode): super().__init__(network_config, mode) + self.use_deepstack = False return def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): @@ -50,9 +57,18 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei # skip the same image if img["token_id"] in 
img_start_token_ids or img["_prefill_"] is False: continue + pos = (input_ids == img["token_id"]).nonzero(as_tuple=True) + if pos[0].numel() == 0: + continue # pull the img_embeds by uid from shm data = read_shm(get_shm_name_embed(img["uuid"])) img_weight.append(bytes2tensor(data).cuda().reshape(img["token_num"], -1)) + if self.use_deepstack: + deepstack_features = read_shm(get_shm_name_deepstack(img["uuid"])) + infer_state.deepstack_features.append(bytes2list(deepstack_features)) + img_insert_locs = int(pos[0][0]) + infer_state.img_first_token_locs.append(img_insert_locs) + infer_state.img_last_token_locs.append(img_insert_locs + img["token_num"]) img_start_token_ids.append(img["token_id"]) img_token_lens.append(img["token_num"]) img_start_locs.append(img_start_loc) diff --git a/lightllm/server/embed_cache/utils.py b/lightllm/server/embed_cache/utils.py index 6df031293..08d49c17d 100644 --- a/lightllm/server/embed_cache/utils.py +++ b/lightllm/server/embed_cache/utils.py @@ -1,6 +1,7 @@ import torch import numpy as np from io import BytesIO +from typing import List, Optional import multiprocessing.shared_memory as shm @@ -20,11 +21,52 @@ def tensor2bytes(t: torch.Tensor): return buf.read() + +def list2bytes(tensors: List[torch.Tensor]) -> bytes: + # detach().cpu() and copy each tensor one by one + safe_list = [] + for t in tensors: + if t is None: + safe_list.append(None) + continue + t = t.detach().cpu() + if not t.is_contiguous(): + t = t.contiguous() + dest = torch.empty_like(t) + dest.copy_(t) + safe_list.append(dest) + buf = BytesIO() + torch.save(safe_list, buf, _use_new_zipfile_serialization=False, pickle_protocol=4) + buf.seek(0) + return buf.read() + + def bytes2tensor(b): # return torch.from_numpy(np.frombuffer(b, dtype=np.float16)).cuda() return torch.load(BytesIO(b), weights_only=False) + +def bytes2list(b: bytes, device: Optional[torch.device] = None, non_blocking: bool = False) -> List[torch.Tensor]: + obj = torch.load(BytesIO(b), map_location="cpu", weights_only=False) + + if isinstance(obj, tuple): + obj = list(obj) + if not isinstance(obj, list): + raise TypeError(f"Loaded object is {type(obj)}, expected list or tuple.") + + if device is None: + return obj + + out: List[torch.Tensor] = [] + for x in obj: + if x is None: + out.append(None) + elif isinstance(x, torch.Tensor): + out.append(x.to(device, non_blocking=non_blocking)) + else: + raise TypeError(f"List element is {type(x)}, expected Tensor or None.") + return out + + def create_shm(name, data): try: data_size = len(data) @@ -53,3 +95,7 @@ def get_shm_name_data(uid): return str(uid) + "-embed" + + +def get_shm_name_deepstack(uid): + return str(uid) + "-deepstack" diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py index 1f10aa5ec..e0b2bd425 100.644 --- a/lightllm/server/tokenizer.py +++ b/lightllm/server/tokenizer.py @@ -28,6 +28,7 @@ from ..models.llava.model import LlavaTokenizer from ..models.qwen_vl.model import QWenVLTokenizer from ..models.qwen2_vl.model import QWen2VLTokenizer +from ..models.qwen3_vl.model import QWen3VLTokenizer from ..models.internvl.model import InternvlTokenizer from ..models.gemma3.model import Gemma3Tokenizer @@ -92,6 +93,13 @@ def get_tokenizer( tokenizer = QWen2VLTokenizer( tokenizer=tokenizer, image_processor=processor.image_processor, model_cfg=model_cfg ) + elif model_type in ["qwen3_vl", "qwen3_vl_moe"] and "vision_config" in model_cfg: + from transformers import AutoProcessor + + processor = AutoProcessor.from_pretrained(tokenizer_name) + tokenizer = 
QWen3VLTokenizer( + tokenizer=tokenizer, image_processor=processor.image_processor, model_cfg=model_cfg + ) elif model_type == "internvl_chat": tokenizer = InternvlTokenizer(tokenizer, model_cfg, weight_dir=tokenizer_name) elif model_type == "gemma3": diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index a25065e42..5815e7b81 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -16,8 +16,17 @@ from lightllm.server.multimodal_params import MultimodalParams, ImageItem from lightllm.models.qwen2_vl.qwen2_visual import Qwen2VisionTransformerPretrainedModel from lightllm.models.qwen2_5_vl.qwen2_5_visual import Qwen2_5_VisionTransformerPretrainedModel +from lightllm.models.qwen3_vl.qwen3_visual import Qwen3VisionTransformerPretrainedModel from lightllm.models.tarsier2.tarsier2_visual import TarsierVisionTransformerPretrainedModel -from lightllm.server.embed_cache.utils import tensor2bytes, read_shm, create_shm, get_shm_name_data, get_shm_name_embed +from lightllm.server.embed_cache.utils import ( + tensor2bytes, + read_shm, + create_shm, + get_shm_name_data, + get_shm_name_embed, + get_shm_name_deepstack, + list2bytes, +) from lightllm.utils.infer_utils import set_random_seed from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end from lightllm.utils.dist_utils import init_vision_distributed_env @@ -63,6 +72,10 @@ def exposed_init_model(self, kvargs): self.model = ( Qwen2_5_VisionTransformerPretrainedModel(kvargs, **model_cfg["vision_config"]).eval().bfloat16() ) + elif self.model_type in ["qwen3_vl", "qwen3_vl_moe"]: + self.model = ( + Qwen3VisionTransformerPretrainedModel(kvargs, **model_cfg["vision_config"]).eval().bfloat16() + ) elif model_cfg["architectures"][0] == "TarsierForConditionalGeneration": self.model = TarsierVisionTransformerPretrainedModel(**model_cfg).eval().bfloat16() elif self.model_type == "llava": @@ -96,7 +109,8 @@ def forward(self, images: List[ImageItem]): # @calculate_time(show=False, min_cost_ms=300) def exposed_encode(self, images: List[ImageItem]): images = obtain(images) - all_img_embeds, uuids, valid_ids = self.forward(images) + all_img_embeds, uuids, valid_ids, *deepstack_features = self.forward(images) + deepstack_feature_lists = deepstack_features[0] if deepstack_features else None all_img_embeds = all_img_embeds.to(torch.device("cpu")) if self.tp_rank_id == 0: @@ -109,6 +123,10 @@ def exposed_encode(self, images: List[ImageItem]): start, end = valid_ids[i] cur_embed_bytes = tensor2bytes(all_img_embeds[start:end]) create_shm(get_shm_name_embed(uid), cur_embed_bytes) + if deepstack_feature_lists is not None: + per_image_deepstack = [feat[start:end] for feat in deepstack_feature_lists] + deepstack_features_bytes = list2bytes(per_image_deepstack) + create_shm(get_shm_name_deepstack(uid), deepstack_features_bytes) ids_to_set.append(uid) if ids_to_set: self.cache_client.root.set_items_embed(ids_to_set) diff --git a/lightllm/utils/config_utils.py b/lightllm/utils/config_utils.py index afaf71a25..91543f14c 100644 --- a/lightllm/utils/config_utils.py +++ b/lightllm/utils/config_utils.py @@ -78,6 +78,9 @@ def get_vocab_size(model_path: str): if "llm_config" in config_json: vocab_size = int(config_json["llm_config"]["vocab_size"]) return vocab_size + elif "text_config" in config_json: + vocab_size = int(config_json["text_config"]["vocab_size"]) + return vocab_size vocab_size = config_json["vocab_size"] 
if not isinstance(vocab_size, int): vocab_size = int(vocab_size)
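
Reviewer note (appended, not part of the patch): the image token count returned by get_image_token_length in the qwen2_vl/qwen3_vl tokenizers follows directly from the smart_resize changes above. Below is a minimal standalone sketch of that arithmetic for sanity-checking counts; patch_size=16 and merge_size=2 mirror the vision defaults in qwen3_visual.py, while the min_pixels/max_pixels bounds here are illustrative assumptions only.

import math


def smart_resize_sketch(height, width, factor, min_pixels, max_pixels):
    # Round both sides to a multiple of `factor` (= patch_size * merge_size),
    # then rescale so the pixel count lands inside [min_pixels, max_pixels].
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = max(factor, math.floor(height / beta / factor) * factor)
        w_bar = max(factor, math.floor(width / beta / factor) * factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar


def image_token_num(height, width, patch_size=16, merge_size=2,
                    min_pixels=256 * 28 * 28, max_pixels=1280 * 28 * 28):
    # grid_h x grid_w patches are merged in (merge_size x merge_size) blocks,
    # so the language model sees (grid_h * grid_w) // merge_size ** 2 image tokens.
    factor = patch_size * merge_size
    resized_h, resized_w = smart_resize_sketch(height, width, factor, min_pixels, max_pixels)
    grid_h, grid_w = resized_h // patch_size, resized_w // patch_size
    return (grid_h * grid_w) // (merge_size ** 2)


if __name__ == "__main__":
    print(image_token_num(1080, 1920))  # -> 943 with these illustrative bounds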