diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py
index 3e61178f3..e862d25ae 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py
@@ -134,6 +134,8 @@ def _fuse(self):
inter_shape, hidden_size = self.w2_list[0].shape[0], self.w2_list[0].shape[1]
w2 = torch._utils._flatten_dense_tensors(self.w2_list).view(len(self.w2_list), inter_shape, hidden_size)
+ if self.fused_gate_up:
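+            # with the fused gate_up layout, w2 entries were loaded as [I_split, H]; transpose
+            # back to the [E, H, I_split] layout that the per-expert loading path produces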
+ w2 = w2.transpose(1, 2).contiguous()
if not self.quantized_weight and self.quant_method is not None:
self.w1 = self.quant_method.quantize(w1)
self.w2 = self.quant_method.quantize(w2)
@@ -178,26 +180,53 @@ def _fuse_weight_scale(self):
def load_hf_weights(self, weights):
if self.e_score_correction_bias_name in weights:
self.e_score_correction_bias = self._cuda(weights[self.e_score_correction_bias_name])
- for i_experts in range(self.n_routed_experts):
- w1_weight = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.weight"
- w2_weight = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.weight"
- w3_weight = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.weight"
+ self.fused_gate_up = self.w3_weight_name is None # gate_up: [E,H,2I] down: [E,I,H]
+ key_gateup_3d = f"{self.weight_prefix}.{self.w1_weight_name}" # ...experts.gate_up_proj
+ key_down_3d = f"{self.weight_prefix}.{self.w2_weight_name}"
- if w1_weight in weights:
- self.experts_gate_projs[i_experts] = weights[w1_weight][
- self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), :
- ]
- if w3_weight in weights:
- self.experts_up_projs[i_experts] = weights[w3_weight][
- self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), :
- ]
+ if self.fused_gate_up and (key_gateup_3d in weights) and (key_down_3d in weights):
+ gate_up_3d = weights[key_gateup_3d]
+ down_3d = weights[key_down_3d]
+ assert gate_up_3d.dim() == 3 and down_3d.dim() == 3
- if w2_weight in weights:
- self.w2_list[i_experts] = weights[w2_weight][
- :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1)
- ]
+ E_ckpt, H_, twoE = gate_up_3d.shape
+ assert E_ckpt == self.n_routed_experts, f"experts mismatch: ckpt {E_ckpt} vs cfg {self.n_routed_experts}"
+ Eint_total = twoE // 2
+ start, end = self.tp_rank_ * self.split_inter_size, (self.tp_rank_ + 1) * self.split_inter_size
+ assert end <= Eint_total, "TP split exceeds total expert-intermediate size"
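+            # gate_up_proj keeps gate in columns [:I] and up in columns [I:]; take this rank's
+            # TP slice along I and transpose to the [I_split, H] layout used by the per-expert
+            # (non-fused) loading path below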
+
+ for i in range(self.n_routed_experts):
+ gu2d = gate_up_3d[i]
+ gate2d = gu2d[:, :Eint_total][:, start:end].t().contiguous()
+ up2d = gu2d[:, Eint_total:][:, start:end].t().contiguous()
+ self.experts_gate_projs[i] = gate2d
+ self.experts_up_projs[i] = up2d
+
+ self.w2_list[i] = down_3d[i][start:end, :].contiguous()
+ else:
+ for i_experts in range(self.n_routed_experts):
+ w1_weight = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.weight"
+ w2_weight = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.weight"
+ w3_weight = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.weight"
+
+ if w1_weight in weights:
+ self.experts_gate_projs[i_experts] = weights[w1_weight][
+ self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), :
+ ]
+ if w3_weight in weights:
+ self.experts_up_projs[i_experts] = weights[w3_weight][
+ self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), :
+ ]
+
+ if w2_weight in weights:
+ self.w2_list[i_experts] = weights[w2_weight][
+ :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1)
+ ]
if self.quant_method is not None:
- self._load_weight_scale(weights)
+ if self.fused_gate_up:
+                raise ValueError("qwen3_vl_moe does not support quantization yet")
+ else:
+ self._load_weight_scale(weights)
self._fuse()
def _load_weight_scale(self, weights: Dict[str, torch.Tensor]) -> None:
diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py
index d49d8fa75..3eab0f84f 100644
--- a/lightllm/models/__init__.py
+++ b/lightllm/models/__init__.py
@@ -29,6 +29,7 @@
from lightllm.models.internvl.model import InternVLInternlm2TpPartModel
from lightllm.models.qwen2_vl.model import Qwen2VLTpPartModel
from lightllm.models.qwen2_reward.model import Qwen2RewardTpPartModel
+from lightllm.models.qwen3_vl.model import Qwen3VLTpPartModel, Qwen3VLMOETpPartModel
from lightllm.models.gemma3.model import Gemma3TpPartModel
from lightllm.models.tarsier2.model import (
Tarsier2Qwen2TpPartModel,
diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py
index abc258e8b..d7be45229 100644
--- a/lightllm/models/llama/model.py
+++ b/lightllm/models/llama/model.py
@@ -111,6 +111,8 @@ def _init_custom(self):
if rope_scaling is None:
self._init_to_get_rotary()
return
+ if "mrope_section" in rope_scaling:
+ self.mrope_section = rope_scaling["mrope_section"]
if "rope_type" in rope_scaling:
scaling_type = rope_scaling["rope_type"]
@@ -128,6 +130,8 @@ def _init_custom(self):
self._init_to_get_llama3_rotary()
elif scaling_type == "mrope":
self._init_to_get_mrope_rotary()
+ elif scaling_type == "default":
+ self._init_to_get_rotary()
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
return
@@ -204,7 +208,7 @@ def _init_to_get_rotary(self, default_base=10000):
/ rope_scaling_factor
)
freqs = torch.outer(t, inv_freq)
-
+ self.freqs = freqs.cuda()
self._cos_cached = torch.cos(freqs).to(self.data_type).cuda()
self._sin_cached = torch.sin(freqs).to(self.data_type).cuda()
return
diff --git a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py
index 01e3a7268..c973716ff 100644
--- a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py
+++ b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py
@@ -383,7 +383,12 @@ def encode(self, images: List[ImageItem]):
uuids.append(img.uuid)
image_data = read_shm(get_shm_name_data(img.uuid))
image_data = Image.open(BytesIO(image_data))
- image_data = resize_image(image_data)
+ image_data = resize_image(
+ image_file=image_data,
+ factor=self.processor.patch_size * self.processor.merge_size,
+ min_pixels=self.processor.min_pixels,
+ max_pixels=self.processor.max_pixels,
+ )
pixel_values, image_grid_thw = self.processor.preprocess(image_data)
img_tensors.append(pixel_values)
img_grids.append(image_grid_thw)
diff --git a/lightllm/models/qwen2_vl/model.py b/lightllm/models/qwen2_vl/model.py
index 6d179e6f9..d98a0c130 100644
--- a/lightllm/models/qwen2_vl/model.py
+++ b/lightllm/models/qwen2_vl/model.py
@@ -51,8 +51,9 @@ def init_audioitem_extral_params(
def get_image_token_length(self, img: ImageItem):
width, height = img.image_w, img.image_h
+ factor = self.patch_size * self.merge_size
resized_height, resized_width = smart_resize(
- height=height, width=width, min_pixels=self.min_pixel, max_pixels=self.max_pixel
+ height=height, width=width, factor=factor, min_pixels=self.min_pixel, max_pixels=self.max_pixel
)
grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
token_num = (grid_h * grid_w) // (self.merge_size ** 2)
diff --git a/lightllm/models/qwen2_vl/qwen2_visual.py b/lightllm/models/qwen2_vl/qwen2_visual.py
index 68e161737..af30bdbe0 100644
--- a/lightllm/models/qwen2_vl/qwen2_visual.py
+++ b/lightllm/models/qwen2_vl/qwen2_visual.py
@@ -311,7 +311,12 @@ def encode(self, images: List[ImageItem]):
uuids.append(img.uuid)
image_data = read_shm(get_shm_name_data(img.uuid))
image_data = Image.open(BytesIO(image_data))
- image_data = resize_image(image_data)
+ image_data = resize_image(
+ image_file=image_data,
+ factor=self.processor.patch_size * self.processor.merge_size,
+ min_pixels=self.processor.min_pixels,
+ max_pixels=self.processor.max_pixels,
+ )
pixel_values, image_grid_thw = self.processor.preprocess(image_data)
img_tensors.append(pixel_values)
img_grids.append(image_grid_thw)
diff --git a/lightllm/models/qwen2_vl/vision_process.py b/lightllm/models/qwen2_vl/vision_process.py
index 692f3aac3..75ea36f6d 100644
--- a/lightllm/models/qwen2_vl/vision_process.py
+++ b/lightllm/models/qwen2_vl/vision_process.py
@@ -35,16 +35,16 @@ def smart_resize(
height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
- if max(height, width) / min(height, width) > MAX_RATIO:
+ if max(height, width) / min(height, width) > 200:
raise ValueError(
- f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
+ f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
)
- h_bar = max(factor, round(height / factor) * factor)
- w_bar = max(factor, round(width / factor) * factor)
+ h_bar = round(height / factor) * factor
+ w_bar = round(width / factor) * factor
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
- h_bar = math.floor(height / beta / factor) * factor
- w_bar = math.floor(width / beta / factor) * factor
+ h_bar = max(factor, math.floor(height / beta / factor) * factor)
+ w_bar = max(factor, math.floor(width / beta / factor) * factor)
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = math.ceil(height * beta / factor) * factor
@@ -52,7 +52,9 @@ def smart_resize(
return h_bar, w_bar
-def resize_image(image_file: Image.Image, size_factor: int = IMAGE_FACTOR) -> tuple[Image.Image, int, int]:
+def resize_image(
+ image_file: Image.Image, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
+) -> tuple[Image.Image, int, int]:
image = image_file.convert("RGB")
width, height = image.size
@@ -60,9 +62,9 @@ def resize_image(image_file: Image.Image, size_factor: int = IMAGE_FACTOR) -> tu
resized_height, resized_width = smart_resize(
height,
width,
- factor=size_factor,
- min_pixels=MIN_PIXELS,
- max_pixels=MAX_PIXELS,
+ factor=factor,
+ min_pixels=min_pixels,
+ max_pixels=max_pixels,
)
image = image.resize((resized_width, resized_height))
@@ -72,6 +74,7 @@ def resize_image(image_file: Image.Image, size_factor: int = IMAGE_FACTOR) -> tu
class Qwen2VLImageProcessor(BaseImageProcessor):
def __init__(
self,
+ size: dict = None,
do_resize: bool = True,
resample: PILImageResampling = PILImageResampling.BICUBIC,
do_rescale: bool = True,
@@ -88,6 +91,7 @@ def __init__(
**kwargs,
) -> None:
super().__init__(**kwargs)
+ self.size = size
self.do_resize = do_resize
self.resample = resample
self.do_rescale = do_rescale
@@ -102,6 +106,13 @@ def __init__(
self.temporal_patch_size = temporal_patch_size
self.merge_size = merge_size
self.data_format = ChannelDimension.FIRST
+ if isinstance(self.size, dict):
+ shortest = self.size.get("shortest_edge", None)
+ longest = self.size.get("longest_edge", None)
+ if shortest is not None:
+ self.min_pixels = shortest
+ if longest is not None:
+ self.max_pixels = longest
def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]:
if self.do_convert_rgb:
diff --git a/lightllm/models/qwen3_vl/infer_struct.py b/lightllm/models/qwen3_vl/infer_struct.py
new file mode 100644
index 000000000..6c7ac96f9
--- /dev/null
+++ b/lightllm/models/qwen3_vl/infer_struct.py
@@ -0,0 +1,51 @@
+import torch
+import numpy as np
+from lightllm.models.llama.infer_struct import LlamaInferStateInfo
+from lightllm.common.basemodel.infer_struct import InferStateInfo
+
+
+class Qwen3VLInferStateInfo(LlamaInferStateInfo):
+ def __init__(self):
+ super().__init__()
+ self.deepstack_features = []
+ self.img_first_token_locs = []
+ self.img_last_token_locs = []
+
+ def apply_interleaved_mrope(self, freqs, mrope_section):
+ """Apply interleaved MRoPE to 3D rotary embeddings.
+ Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
+ interleaved [THTHWHTHW...TT], preserving frequency continuity.
+        args:
+            freqs: (3, bs, seq_len, head_dim // 2)
+            mrope_section: (3,)
+        returns:
+            freqs_t: (bs, seq_len, head_dim // 2)
+        """
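+        # e.g. with mrope_section = (24, 20, 20) (illustrative values) and head_dim // 2 = 64,
+        # indices 0..59 cycle T, H, W while the remaining indices 60..63 keep the T frequencies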
+        freqs_t = freqs[0]  # start from the T frequencies; the H/W slices are overwritten below
+ for dim, offset in enumerate((1, 2), start=1): # H, W
+ length = mrope_section[dim] * 3
+ idx = slice(offset, length, 3)
+ freqs_t[..., idx] = freqs[dim, ..., idx]
+ return freqs_t
+
+ def init_some_extra_state(self, model, input_ids: torch.Tensor):
+ InferStateInfo.init_some_extra_state(self, model, input_ids)
+ pos = self.position_ids[None, :].expand(3, -1)
+ cos_T = torch.index_select(model._cos_cached, 0, pos[0]) # [L, d/2]
+ cos_H = torch.index_select(model._cos_cached, 0, pos[1])
+ cos_W = torch.index_select(model._cos_cached, 0, pos[2])
+ sin_T = torch.index_select(model._sin_cached, 0, pos[0])
+ sin_H = torch.index_select(model._sin_cached, 0, pos[1])
+ sin_W = torch.index_select(model._sin_cached, 0, pos[2])
+ cos_half = self.apply_interleaved_mrope(
+ torch.stack([cos_T, cos_H, cos_W], dim=0), model.mrope_section
+ ) # [L, d/2]
+ sin_half = self.apply_interleaved_mrope(
+ torch.stack([sin_T, sin_H, sin_W], dim=0), model.mrope_section
+ ) # [L, d/2]
+
+ self.position_cos = torch.cat([cos_half, cos_half], dim=-1).contiguous() # [L, d]
+ self.position_sin = torch.cat([sin_half, sin_half], dim=-1).contiguous()
+        return
diff --git a/lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py
new file mode 100644
index 000000000..f82c4e55b
--- /dev/null
+++ b/lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py
@@ -0,0 +1,8 @@
+from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer
+
+
+class Qwen3VLMultimodalPreLayerInfer(LlamaMultimodalPreLayerInfer):
+ def __init__(self, network_config, mode):
+ super().__init__(network_config, mode)
+ self.use_deepstack = True
+ return
diff --git a/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py
new file mode 100644
index 000000000..a50f817a6
--- /dev/null
+++ b/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py
@@ -0,0 +1,64 @@
+import torch
+import torch.functional as F
+import torch.distributed as dist
+import numpy as np
+from functools import partial
+from typing import Tuple
+from lightllm.common.basemodel.infer_struct import InferStateInfo
+from lightllm.models.qwen2_vl.triton_kernel.mrope import mrope_triton
+from lightllm.models.qwen3.layer_infer.transformer_layer_infer import Qwen3TransformerLayerInfer
+from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight
+from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
+from lightllm.models.llama.infer_struct import LlamaInferStateInfo
+from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo
+from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
+from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
+from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd
+from lightllm.distributed import all_reduce
+from lightllm.utils.dist_utils import get_global_world_size
+
+
+class Qwen3VLTransformerLayerInfer(Qwen3TransformerLayerInfer):
+ def __init__(self, layer_num, network_config, mode=[]):
+ super().__init__(layer_num, network_config, mode)
+ self.mrope_section = network_config["rope_scaling"]["mrope_section"]
+ axis_map = []
+ for i, n in enumerate(self.mrope_section * 2):
+ axis_map += [i % 3] * n
+ self.axis_map = torch.tensor(axis_map, dtype=torch.int32, device="cuda")
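+        # axis_map[d] maps rotary dim d to its mrope axis (0=T, 1=H, 2=W); mrope_section is
+        # repeated for both rotary halves so the map spans the full head_dim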
+
+ def context_forward(self, input_embdings, infer_state: Qwen3VLInferStateInfo, layer_weight):
+ input1 = self._att_norm(input_embdings, infer_state, layer_weight)
+ q, cache_kv = self._get_qkv(input1, infer_state, layer_weight)
+ input1 = None
+ self._post_cache_kv(cache_kv, infer_state, layer_weight)
+ o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight)
+ q = None
+ o = self._get_o(o, infer_state, layer_weight)
+ if self.tp_world_size_ > 1:
+ all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
+ input_embdings.add_(o.view(-1, self.embed_dim_))
+ o = None
+
+ input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
+ ffn_out = self._ffn(input1, infer_state, layer_weight)
+ input1 = None
+ if self.tp_world_size_ > 1:
+ all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
+ input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
+ if infer_state.deepstack_features:
+ for i in range(len(infer_state.img_first_token_locs)):
+ start = infer_state.img_first_token_locs[i]
+ end = infer_state.img_last_token_locs[i]
+ deepstack_features = infer_state.deepstack_features[i]
+                if end <= input_embdings.shape[0] and self.layer_num_ < len(deepstack_features):
+                    deepstack_features_cur_layer = deepstack_features[self.layer_num_].to(
+                        device=input_embdings.device, non_blocking=True
+                    )
+                    input_embdings[start:end].add_(deepstack_features_cur_layer)
+            # release the deepstack state only after the last deepstack layer has consumed it,
+            # so every early layer receives its own feature map
+            if all(self.layer_num_ + 1 >= len(feats) for feats in infer_state.deepstack_features):
+                infer_state.img_first_token_locs = []
+                infer_state.img_last_token_locs = []
+                infer_state.deepstack_features = []
+ return input_embdings
diff --git a/lightllm/models/qwen3_vl/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_vl/layer_weights/pre_and_post_layer_weight.py
new file mode 100644
index 000000000..98aae3f37
--- /dev/null
+++ b/lightllm/models/qwen3_vl/layer_weights/pre_and_post_layer_weight.py
@@ -0,0 +1,37 @@
+import torch
+import numpy as np
+from lightllm.common.basemodel import PreAndPostLayerWeight
+
+
+class Qwen3VLPreAndPostLayerWeight(PreAndPostLayerWeight):
+ def __init__(self, data_type, network_config, mode):
+ super().__init__(data_type, network_config, mode)
+ return
+
+ def load_hf_weights(self, weights):
+ vob_size = self.network_config_["vocab_size"]
+ split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64)
+ split_start = split_indexes[self.tp_rank_]
+ split_end = split_indexes[self.tp_rank_ + 1]
+ if "model.language_model.embed_tokens.weight" in weights:
+ self.wte_weight_ = self._cuda(weights["model.language_model.embed_tokens.weight"][split_start:split_end, :])
+ tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False)
+ if tie_word_embeddings:
+ self.lm_head_weight_ = self.wte_weight_
+ if "lm_head.weight" in weights:
+ self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :])
+ if "model.language_model.norm.weight" in weights:
+ self.final_norm_weight_ = self._cuda(weights["model.language_model.norm.weight"])
+
+ return
+
+ def verify_load(self):
+ errors = "weights load not ok"
+ weights = [
+ self.wte_weight_,
+ self.lm_head_weight_,
+ self.final_norm_weight_,
+ ]
+ for i in range(len(weights)):
+ assert weights[i] is not None, "index:" + str(i) + " " + errors
+ return
diff --git a/lightllm/models/qwen3_vl/layer_weights/transformers_layer_weight.py b/lightllm/models/qwen3_vl/layer_weights/transformers_layer_weight.py
new file mode 100644
index 000000000..f51025527
--- /dev/null
+++ b/lightllm/models/qwen3_vl/layer_weights/transformers_layer_weight.py
@@ -0,0 +1,44 @@
+import os
+import torch
+import math
+import numpy as np
+from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight
+from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight
+from lightllm.common.basemodel.layer_weights.meta_weights import (
+ ROWMMWeight,
+ MultiROWMMWeight,
+ COLMMWeight,
+ NormWeight,
+ FusedMoeWeightTP,
+ FusedMoeWeightEP,
+ ROWBMMWeight,
+)
+
+
+class Qwen3VLTransformerLayerWeight(Qwen3TransformerLayerWeight):  # TODO: revisit whether this needs further changes
+ def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None):
+ super().__init__(layer_num, data_type, network_config, mode, quant_cfg)
+
+ def _init_weight_names(self):
+ super()._init_weight_names()
+ self._q_weight_name = f"model.language_model.layers.{self.layer_num_}.self_attn.q_proj.weight"
+ self._q_norm_name = f"model.language_model.layers.{self.layer_num_}.self_attn.q_norm.weight"
+ self._q_bias_name = None
+ self._k_weight_name = f"model.language_model.layers.{self.layer_num_}.self_attn.k_proj.weight"
+ self._k_norm_name = f"model.language_model.layers.{self.layer_num_}.self_attn.k_norm.weight"
+ self._k_bias_name = None
+ self._v_weight_name = f"model.language_model.layers.{self.layer_num_}.self_attn.v_proj.weight"
+ self._v_bias_name = None
+ self._o_weight_name = f"model.language_model.layers.{self.layer_num_}.self_attn.o_proj.weight"
+ self._o_bias_name = None
+ self._att_norm_weight_name = f"model.language_model.layers.{self.layer_num_}.input_layernorm.weight"
+ self._att_norm_bias_name = None
+ self._ffn_norm_weight_name = f"model.language_model.layers.{self.layer_num_}.post_attention_layernorm.weight"
+ self._ffn_norm_bias_name = None
+
+ self._gate_weight_name = f"model.language_model.layers.{self.layer_num_}.mlp.gate_proj.weight"
+ self._gate_bias_name = None
+ self._up_weight_name = f"model.language_model.layers.{self.layer_num_}.mlp.up_proj.weight"
+ self._up_bias_name = None
+ self._down_weight_name = f"model.language_model.layers.{self.layer_num_}.mlp.down_proj.weight"
+ self._down_bias_name = None
diff --git a/lightllm/models/qwen3_vl/model.py b/lightllm/models/qwen3_vl/model.py
new file mode 100644
index 000000000..6fed68488
--- /dev/null
+++ b/lightllm/models/qwen3_vl/model.py
@@ -0,0 +1,155 @@
+import json
+import numpy as np
+import unicodedata
+from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer
+from lightllm.models.qwen.model import QWenTpPartModel
+from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer
+from lightllm.server.multimodal_params import AudioItem, MultimodalParams, ImageItem
+from transformers.feature_extraction_utils import BatchFeature
+from transformers.image_utils import ImageInput
+from transformers.processing_utils import ProcessorMixin
+from lightllm.server.core.objs import SamplingParams
+from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
+from typing import List, Optional, Union
+from transformers.utils import TensorType, logging
+from lightllm.models.qwen2_vl.flashattention_infer_struct import Qwen2VLFlashAttentionStateInfo
+from lightllm.common.build_utils import repair_config
+from lightllm.models.registry import ModelRegistry
+from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo
+from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import Qwen3VLMultimodalPreLayerInfer
+from lightllm.models.qwen3_vl.layer_infer.transformer_layer_infer import Qwen3VLTransformerLayerInfer
+from lightllm.models.qwen3_vl.layer_weights.pre_and_post_layer_weight import Qwen3VLPreAndPostLayerWeight
+from lightllm.models.qwen3_vl.layer_weights.transformers_layer_weight import Qwen3VLTransformerLayerWeight
+from lightllm.models.qwen3_vl_moe.layer_weights.transformers_layer_weight import Qwen3VLMOETransformerLayerWeight
+from lightllm.models.qwen3_vl_moe.layer_infer.transformer_layer_infer import Qwen3VLMOETransformerLayerInfer
+
+import torch
+from PIL import Image
+from lightllm.models.qwen2_vl.vision_process import smart_resize
+from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args
+from lightllm.models.qwen3.model import Qwen3TpPartModel
+from lightllm.models.qwen3_moe.model import Qwen3MOEModel
+import os
+
+
+class QWen3VLTokenizer(BaseMultiModalTokenizer):
+ def __init__(self, tokenizer=None, image_processor=None, **kwargs):
+ super().__init__(tokenizer)
+ self.image_processor = image_processor
+ self.min_pixel = self.image_processor.size["shortest_edge"]
+ self.max_pixel = self.image_processor.size["longest_edge"]
+ self.patch_size = self.image_processor.patch_size
+ self.merge_size = self.image_processor.merge_size
+ self.image_start_id = kwargs["model_cfg"]["vision_start_token_id"]
+ self.image_end_id = kwargs["model_cfg"]["vision_end_token_id"]
+ self.image_token_id = kwargs["model_cfg"]["image_token_id"]
+
+ def init_imageitem_extral_params(
+ self, img: ImageItem, multi_params: MultimodalParams, sampling_params: SamplingParams
+ ):
+ return
+
+ def init_audioitem_extral_params(
+ self, audio: AudioItem, multi_params: MultimodalParams, sampling_params: SamplingParams
+ ):
+ raise NotImplementedError
+
+ def get_image_token_length(self, img: ImageItem):
+ width, height = img.image_w, img.image_h
+ factor = self.patch_size * self.merge_size
+ resized_height, resized_width = smart_resize(
+ height=height, width=width, factor=factor, min_pixels=self.min_pixel, max_pixels=self.max_pixel
+ )
+ grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
+ token_num = (grid_h * grid_w) // (self.merge_size ** 2)
+ return token_num
+
+ def get_audio_token_length(self, audio: AudioItem):
+ raise NotImplementedError
+
+ def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs):
+
+ origin_ids = self.tokenizer.encode(prompt)
+
+        # <img><image_pad></img> -> <img></img>
+        origin_ids = [token for token in origin_ids if token != self.image_token_id]
+        # <img></img> --> <img>id,id+1...id+num</img>
+        input_ids = []
+ image_id = 0
+ start_idx = 0
+ while True:
+ try:
+ start_idx = origin_ids.index(self.image_start_id, start_idx)
+ if start_idx + 1 >= len(origin_ids):
+ break
+ if origin_ids[start_idx + 1] == self.image_end_id:
+ input_ids.extend(origin_ids[: start_idx + 1])
+ token_id = multimodal_params.images[image_id].token_id
+ token_num = multimodal_params.images[image_id].token_num
+ input_ids.extend(range(token_id, token_id + token_num))
+ input_ids.append(self.image_end_id)
+ origin_ids = origin_ids[start_idx + 2 :]
+ start_idx = 0
+ image_id += 1
+ else:
+ raise ValueError("image token error")
+ except ValueError:
+ break
+ input_ids.extend(origin_ids[start_idx:])
+ return input_ids
+
+
+@ModelRegistry(["qwen3_vl"], is_multimodal=True)
+class Qwen3VLTpPartModel(Qwen3TpPartModel):
+
+ pre_layer_infer_class = Qwen3VLMultimodalPreLayerInfer
+ transformer_layer_infer_class = Qwen3VLTransformerLayerInfer
+
+ pre_and_post_weight_class = Qwen3VLPreAndPostLayerWeight
+ transformer_weight_class = Qwen3VLTransformerLayerWeight
+
+ infer_state_class = Qwen3VLInferStateInfo
+
+ def __init__(self, kvargs):
+ super().__init__(kvargs)
+ return
+
+ def _init_config(self):
+ with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
+ all_config = json.load(json_file)
+ self.config = all_config["text_config"]
+ # rename keys
+ repair_config(self.config, same_names=["num_attention_heads", "n_head"])
+ repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"])
+ repair_config(self.config, same_names=["num_hidden_layers", "n_layer"])
+ if self.finetune_config:
+ self.config["vocab_size"] = self.finetune_config.vocab_size
+ return
+
+
+@ModelRegistry(["qwen3_vl_moe"], is_multimodal=True)
+class Qwen3VLMOETpPartModel(Qwen3MOEModel):
+
+ pre_layer_infer_class = Qwen3VLMultimodalPreLayerInfer
+ transformer_layer_infer_class = Qwen3VLMOETransformerLayerInfer
+
+ pre_and_post_weight_class = Qwen3VLPreAndPostLayerWeight
+ transformer_weight_class = Qwen3VLMOETransformerLayerWeight
+
+ infer_state_class = Qwen3VLInferStateInfo
+
+ def __init__(self, kvargs):
+ super().__init__(kvargs)
+ return
+
+ def _init_config(self):
+ with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
+ all_config = json.load(json_file)
+ self.config = all_config["text_config"]
+ # rename keys
+ repair_config(self.config, same_names=["num_attention_heads", "n_head"])
+ repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"])
+ repair_config(self.config, same_names=["num_hidden_layers", "n_layer"])
+ if self.finetune_config:
+ self.config["vocab_size"] = self.finetune_config.vocab_size
+ return
diff --git a/lightllm/models/qwen3_vl/qwen3_visual.py b/lightllm/models/qwen3_vl/qwen3_visual.py
new file mode 100644
index 000000000..02b7fa73e
--- /dev/null
+++ b/lightllm/models/qwen3_vl/qwen3_visual.py
@@ -0,0 +1,397 @@
+# coding=utf-8
+# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import json
+from PIL import Image
+from io import BytesIO
+from typing import List
+from safetensors import safe_open
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers.activations import ACT2FN
+
+from lightllm.server.multimodal_params import ImageItem
+from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
+from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor
+from lightllm.models.qwen2_vl.qwen2_visual import VisionRotaryEmbedding, VisionFlashAttention
+
+
+class Qwen3VLVisionMLP(nn.Module):
+ def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True)
+ self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True)
+ self.act_fn = ACT2FN[hidden_act]
+
+ def forward(self, hidden_state):
+ return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state)))
+
+
+class Qwen3VLPatchEmbed(nn.Module):
+ def __init__(
+ self,
+ patch_size: int = 14,
+ temporal_patch_size: int = 2,
+ in_channels: int = 3,
+ embed_dim: int = 1152,
+ ) -> None:
+ super().__init__()
+ self.patch_size = patch_size
+ self.temporal_patch_size = temporal_patch_size
+ self.in_channels = in_channels
+ self.embed_dim = embed_dim
+
+ kernel_size = [temporal_patch_size, patch_size, patch_size]
+ self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ target_dtype = self.proj.weight.dtype
+ if hidden_states.dtype != target_dtype:
+ hidden_states = hidden_states.to(target_dtype)
+ hidden_states = hidden_states.view(
+ -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
+ )
+ hidden_states = self.proj(hidden_states).view(-1, self.embed_dim)
+ return hidden_states
+
+
+class Qwen3VLVisionPatchMerger(nn.Module):
+ def __init__(self, hidden_size, out_hidden_size, spatial_merge_size, use_postshuffle_norm=False) -> None:
+ super().__init__()
+ self.hidden_size = hidden_size * (spatial_merge_size ** 2)
+ self.use_postshuffle_norm = use_postshuffle_norm
+ self.norm = nn.LayerNorm(self.hidden_size if use_postshuffle_norm else hidden_size, eps=1e-6)
+ self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size)
+ self.act_fn = nn.GELU()
+ self.linear_fc2 = nn.Linear(self.hidden_size, out_hidden_size)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.norm(x.view(-1, self.hidden_size) if self.use_postshuffle_norm else x).view(-1, self.hidden_size)
+ x = self.linear_fc2(self.act_fn(self.linear_fc1(x)))
+ return x
+
+
+class Qwen3VLVisionBlock(nn.Module):
+ def __init__(self, hidden_size, intermediate_size, num_heads, hidden_act) -> None:
+ super().__init__()
+ self.norm1 = nn.LayerNorm(hidden_size, eps=1e-6)
+ self.norm2 = nn.LayerNorm(hidden_size, eps=1e-6)
+
+ self.attn = VisionFlashAttention(hidden_size, num_heads=num_heads)
+ self.mlp = Qwen3VLVisionMLP(hidden_size, intermediate_size, hidden_act)
+
+ def forward(self, hidden_states, cu_seqlens, max_seqlen, rotary_cos, rotary_sin) -> torch.Tensor:
+ hidden_states = hidden_states + self.attn(
+ self.norm1(hidden_states),
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ rotary_cos=rotary_cos,
+ rotary_sin=rotary_sin,
+ )
+ hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+ return hidden_states
+
+
+class Qwen3VisionTransformerPretrainedModel(nn.Module):
+ def __init__(
+ self,
+ kvargs,
+ depth=27,
+ out_hidden_size=4096,
+ hidden_size=1152,
+ hidden_act="gelu_pytorch_tanh",
+ intermediate_size=4304,
+ deepstack_visual_indexes=[8, 16, 24],
+ num_heads=16,
+ in_channels=3,
+ patch_size=16,
+ spatial_merge_size=2,
+ temporal_patch_size=2,
+ num_position_embeddings=2304,
+ **kwargs,
+ ):
+ super().__init__()
+ self.data_type = kvargs.get("data_type", "bfloat16")
+
+ self.depth = depth
+ self.out_hidden_size = out_hidden_size
+ self.hidden_size = hidden_size
+ self.hidden_act = hidden_act
+ self.num_heads = num_heads
+ self.in_channels = in_channels
+ self.patch_size = patch_size
+ self.intermediate_size = intermediate_size
+ self.spatial_merge_size = spatial_merge_size
+ self.temporal_patch_size = temporal_patch_size
+ self.num_position_embeddings = num_position_embeddings
+ self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
+
+ self.patch_embed = Qwen3VLPatchEmbed(
+ patch_size=self.patch_size,
+ temporal_patch_size=self.temporal_patch_size,
+ in_channels=self.in_channels,
+ embed_dim=self.hidden_size,
+ )
+
+ self.pos_embed = nn.Embedding(self.num_position_embeddings, self.hidden_size)
+ self.num_grid_per_side = int(self.num_position_embeddings ** 0.5)
+
+ head_dim = self.hidden_size // self.num_heads
+ self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2).cuda()
+
+ self.blocks = nn.ModuleList(
+ [
+ Qwen3VLVisionBlock(self.hidden_size, self.intermediate_size, self.num_heads, self.hidden_act)
+ for _ in range(self.depth)
+ ]
+ )
+ self.merger = Qwen3VLVisionPatchMerger(
+ hidden_size=self.hidden_size,
+ out_hidden_size=self.out_hidden_size,
+ spatial_merge_size=self.spatial_merge_size,
+ use_postshuffle_norm=False,
+ )
+ self.deepstack_visual_indexes = deepstack_visual_indexes
+ self.deepstack_merger_list = nn.ModuleList(
+ [
+ Qwen3VLVisionPatchMerger(
+ hidden_size=self.hidden_size,
+ out_hidden_size=self.out_hidden_size,
+ spatial_merge_size=self.spatial_merge_size,
+ use_postshuffle_norm=True,
+ )
+ for _ in range(len(self.deepstack_visual_indexes))
+ ]
+ )
+ self._init_datatype()
+
+ def _init_datatype(self):
+ if isinstance(self.data_type, torch.dtype):
+ return
+ if self.data_type in ["fp16", "float16"]:
+ self.data_type = torch.float16
+ elif self.data_type in ["bf16", "bfloat16"]:
+ self.data_type = torch.bfloat16
+ elif self.data_type in ["fp32", "float32"]:
+ self.data_type = torch.float32
+ else:
+            raise ValueError(f"Unsupported datatype {self.data_type}!")
+ return
+
+ def load_model(self, weight_dir):
+
+ processor_config_path = os.path.join(weight_dir, "preprocessor_config.json")
+ with open(processor_config_path, "r") as f:
+ processor_config_dict = json.load(f)
+ self.processor = Qwen2VLImageProcessor(**processor_config_dict)
+
+ bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")]
+ if bin_weight_files:
+ weight_dict = {}
+ for file_ in bin_weight_files:
+ f = torch.load(os.path.join(weight_dir, file_), "cpu")
+ for k, v in f.items():
+ if "model.visual" in k:
+ weight_dict[k[len("model.visual.") :]] = v
+ else:
+ hf_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".safetensors")]
+ weight_dict = {}
+ for file_ in hf_weight_files:
+ f = safe_open(os.path.join(weight_dir, file_), "pt", "cpu")
+ for k in f.keys():
+ if "model.visual" in k:
+ weight_dict[k[len("model.visual.") :]] = f.get_tensor(k)
+
+ self.load_state_dict(weight_dict)
+
+ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
+ merge_size = self.spatial_merge_size
+
+ total_tokens = int(torch.prod(grid_thw, dim=1).sum().item())
+ pos_ids = torch.empty((total_tokens, 2), dtype=torch.long)
+
+ offset = 0
+ for num_frames, height, width in grid_thw:
+ merged_h, merged_w = height // merge_size, width // merge_size
+
+ block_rows = torch.arange(merged_h) # block row indices
+ block_cols = torch.arange(merged_w) # block col indices
+ intra_row = torch.arange(merge_size) # intra-block row offsets
+ intra_col = torch.arange(merge_size) # intra-block col offsets
+
+ # Compute full-resolution positions
+ row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None]
+ col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :]
+
+ row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
+ col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
+
+ coords = torch.stack((row_idx, col_idx), dim=-1)
+
+ if num_frames > 1:
+ coords = coords.repeat(num_frames, 1)
+
+ num_tokens = coords.shape[0]
+ pos_ids[offset : offset + num_tokens] = coords
+ offset += num_tokens
+
+ max_hw = int(grid_thw[:, 1:].max().item())
+ cos_full, sin_full = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2)
+ cos = cos_full[pos_ids].flatten(1)
+ sin = sin_full[pos_ids].flatten(1)
+ return cos, sin
+
+ def fast_pos_embed_interpolate(self, grid_thw):
+ grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]
+
+ idx_list = [[] for _ in range(4)]
+ weight_list = [[] for _ in range(4)]
+
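+        # bilinearly interpolate the learned num_grid_per_side x num_grid_per_side position
+        # table onto each image's (h, w) patch grid: gather the four surrounding grid points
+        # and blend them with their bilinear weights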
+ for t, h, w in zip(grid_ts, grid_hs, grid_ws):
+ h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h)
+ w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w)
+
+ h_idxs_floor = h_idxs.int()
+ w_idxs_floor = w_idxs.int()
+ h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
+ w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
+
+ dh = h_idxs - h_idxs_floor
+ dw = w_idxs - w_idxs_floor
+
+ base_h = h_idxs_floor * self.num_grid_per_side
+ base_h_ceil = h_idxs_ceil * self.num_grid_per_side
+
+ indices = [
+ (base_h[None].T + w_idxs_floor[None]).flatten(),
+ (base_h[None].T + w_idxs_ceil[None]).flatten(),
+ (base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
+ (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
+ ]
+
+ weights = [
+ ((1 - dh)[None].T * (1 - dw)[None]).flatten(),
+ ((1 - dh)[None].T * dw[None]).flatten(),
+ (dh[None].T * (1 - dw)[None]).flatten(),
+ (dh[None].T * dw[None]).flatten(),
+ ]
+
+ for i in range(4):
+ idx_list[i].extend(indices[i].tolist())
+ weight_list[i].extend(weights[i].tolist())
+
+ idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=self.pos_embed.weight.device)
+ weight_tensor = torch.tensor(
+ weight_list, dtype=self.pos_embed.weight.dtype, device=self.pos_embed.weight.device
+ )
+ pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
+ patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
+
+ patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)])
+
+ patch_pos_embeds_permute = []
+ merge_size = self.spatial_merge_size
+ for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
+ pos_embed = pos_embed.repeat(t, 1)
+ pos_embed = (
+ pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
+ .permute(0, 1, 3, 2, 4, 5)
+ .flatten(0, 4)
+ )
+ patch_pos_embeds_permute.append(pos_embed)
+ patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
+ return patch_pos_embeds
+
+ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
+ hidden_states = self.patch_embed(hidden_states)
+
+ pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
+ hidden_states = hidden_states + pos_embeds
+
+ rotary_cos, rotary_sin = self.rot_pos_emb(grid_thw)
+ rotary_cos = rotary_cos.to("cuda", non_blocking=True)
+ rotary_sin = rotary_sin.to("cuda", non_blocking=True)
+
+ cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+ dim=0,
+ dtype=torch.int32,
+ )
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to("cuda", non_blocking=True)
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+
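+        # deepstack: after each vision block listed in deepstack_visual_indexes, run a dedicated
+        # patch merger and keep the result; these features are later added to the early language
+        # model layers' hidden states at the image token positions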
+ deepstack_feature_lists = []
+ for layer_num, blk in enumerate(self.blocks):
+ hidden_states = blk(
+ hidden_states,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ rotary_cos=rotary_cos,
+ rotary_sin=rotary_sin,
+ )
+ if layer_num in self.deepstack_visual_indexes:
+ deepstack_feature = self.deepstack_merger_list[self.deepstack_visual_indexes.index(layer_num)](
+ hidden_states
+ )
+ deepstack_feature_lists.append(deepstack_feature)
+
+ hidden_states = self.merger(hidden_states)
+
+ return hidden_states, deepstack_feature_lists
+
+ def encode(self, images: List[ImageItem]):
+ img_tensors = []
+ valid_ids = []
+ valid_id = 0
+ img_grids = []
+ uuids = []
+ for i, img in enumerate(images):
+ if isinstance(img, ImageItem):
+ uuids.append(img.uuid)
+ image_data = read_shm(get_shm_name_data(img.uuid))
+ image_data = Image.open(BytesIO(image_data))
+ image_data = resize_image(
+ image_file=image_data,
+ factor=self.processor.patch_size * self.processor.merge_size,
+ min_pixels=self.processor.min_pixels,
+ max_pixels=self.processor.max_pixels,
+ )
+ pixel_values, image_grid_thw = self.processor.preprocess(image_data)
+ img_tensors.append(pixel_values)
+ img_grids.append(image_grid_thw)
+ else:
+                raise Exception("Unsupported input type: {} for {}".format(type(img), img))
+
+            # the number of patches must be divisible by merge_length (merge_size ** 2)
+ cur_num = img_tensors[-1].shape[0] // (self.spatial_merge_size ** 2)
+
+ valid_ids.append([valid_id, valid_id + cur_num])
+ valid_id += cur_num
+
+ if len(img_tensors) <= 0:
+ return None
+
+ imgs = torch.cat(img_tensors, dim=0)
+ grid_thw = torch.cat(img_grids, dim=0)
+
+ pixel_values = imgs.to("cuda", dtype=self.data_type, non_blocking=True)
+ image_grid_thw = grid_thw.to("cuda", non_blocking=True)
+
+ all_img_embeds, deepstack_feature_lists = self.forward(pixel_values, grid_thw=image_grid_thw)
+
+ return all_img_embeds, uuids, valid_ids, deepstack_feature_lists
diff --git a/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py
new file mode 100644
index 000000000..3cf4b0d2d
--- /dev/null
+++ b/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py
@@ -0,0 +1,96 @@
+import torch
+import torch.functional as F
+import torch.distributed as dist
+import numpy as np
+from functools import partial
+from typing import Tuple
+from lightllm.common.basemodel.infer_struct import InferStateInfo
+from lightllm.models.qwen2_vl.triton_kernel.mrope import mrope_triton
+from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer
+from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight
+from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
+from lightllm.models.llama.infer_struct import LlamaInferStateInfo
+from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo
+from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
+from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
+from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd
+from lightllm.distributed import all_reduce
+from lightllm.utils.dist_utils import get_global_world_size
+
+
+class Qwen3VLMOETransformerLayerInfer(Qwen3MOETransformerLayerInfer):
+ def __init__(self, layer_num, network_config, mode=[]):
+ super().__init__(layer_num, network_config, mode)
+ self.mrope_section = network_config["rope_scaling"]["mrope_section"]
+ axis_map = []
+ for i, n in enumerate(self.mrope_section * 2):
+ axis_map += [i % 3] * n
+ self.axis_map = torch.tensor(axis_map, dtype=torch.int32, device="cuda")
+
+ def _get_qkv(
+ self,
+ input: torch.Tensor,
+ infer_state: Qwen3VLInferStateInfo,
+ layer_weight: Qwen3MOETransformerLayerWeight,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ input = input.view(-1, self.embed_dim_)
+ q = layer_weight.q_proj.mm(input)
+ cache_kv = self._pre_cache_kv(infer_state=infer_state, layer_weight=layer_weight)
+ cache_kv = layer_weight.kv_proj.mm(
+ input, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_)
+ ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
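+        # per-head RMSNorm on q and on the k part of cache_kv (Qwen3-style qk-norm) before rope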
+ rmsnorm_forward(
+ q.view(-1, self.head_dim_),
+ weight=layer_weight.q_norm_weight_.weight,
+ eps=self.eps_,
+ out=q.view(-1, self.head_dim_),
+ )
+
+ cache_kv[:, : self.tp_k_head_num_, :] = rmsnorm_forward(
+ cache_kv[:, : self.tp_k_head_num_, :].reshape(-1, cache_kv.shape[-1]),
+ weight=layer_weight.k_norm_weight_.weight,
+ eps=self.eps_,
+ ).view(-1, self.tp_k_head_num_, cache_kv.shape[-1])
+ rotary_emb_fwd(
+ q.view(-1, self.tp_q_head_num_, self.head_dim_),
+ cache_kv[:, : self.tp_k_head_num_, :],
+ infer_state.position_cos,
+ infer_state.position_sin,
+ )
+ return q, cache_kv
+
+ def context_forward(self, input_embdings, infer_state: Qwen3VLInferStateInfo, layer_weight):
+ input1 = self._att_norm(input_embdings, infer_state, layer_weight)
+ q, cache_kv = self._get_qkv(input1, infer_state, layer_weight)
+ input1 = None
+ self._post_cache_kv(cache_kv, infer_state, layer_weight)
+ o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight)
+ q = None
+ o = self._get_o(o, infer_state, layer_weight)
+ if self.tp_world_size_ > 1:
+ all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
+ input_embdings.add_(o.view(-1, self.embed_dim_))
+ o = None
+
+ input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
+ ffn_out = self._ffn(input1, infer_state, layer_weight)
+ input1 = None
+ if self.tp_world_size_ > 1:
+ all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
+ input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
+ if infer_state.deepstack_features:
+ for i in range(len(infer_state.img_first_token_locs)):
+ start = infer_state.img_first_token_locs[i]
+ end = infer_state.img_last_token_locs[i]
+ deepstack_features = infer_state.deepstack_features[i]
+                if end <= input_embdings.shape[0] and self.layer_num_ < len(deepstack_features):
+                    deepstack_features_cur_layer = deepstack_features[self.layer_num_].to(
+                        device=input_embdings.device, non_blocking=True
+                    )
+                    input_embdings[start:end].add_(deepstack_features_cur_layer)
+            # release the deepstack state only after the last deepstack layer has consumed it,
+            # so every early layer receives its own feature map
+            if all(self.layer_num_ + 1 >= len(feats) for feats in infer_state.deepstack_features):
+                infer_state.img_first_token_locs = []
+                infer_state.img_last_token_locs = []
+                infer_state.deepstack_features = []
+ return input_embdings
diff --git a/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py
new file mode 100644
index 000000000..0a7f82a93
--- /dev/null
+++ b/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py
@@ -0,0 +1,37 @@
+import torch
+import numpy as np
+from lightllm.common.basemodel import PreAndPostLayerWeight
+
+
+class Qwen3VLMOEPreAndPostLayerWeight(PreAndPostLayerWeight):
+ def __init__(self, data_type, network_config, mode):
+ super().__init__(data_type, network_config, mode)
+ return
+
+ def load_hf_weights(self, weights):
+ vob_size = self.network_config_["vocab_size"]
+ split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64)
+ split_start = split_indexes[self.tp_rank_]
+ split_end = split_indexes[self.tp_rank_ + 1]
+ if "model.language_model.embed_tokens.weight" in weights:
+ self.wte_weight_ = self._cuda(weights["model.language_model.embed_tokens.weight"][split_start:split_end, :])
+ tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False)
+ if tie_word_embeddings:
+ self.lm_head_weight_ = self.wte_weight_
+ if "lm_head.weight" in weights:
+ self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :])
+ if "model.language_model.norm.weight" in weights:
+ self.final_norm_weight_ = self._cuda(weights["model.language_model.norm.weight"])
+
+ return
+
+ def verify_load(self):
+ errors = "weights load not ok"
+ weights = [
+ self.wte_weight_,
+ self.lm_head_weight_,
+ self.final_norm_weight_,
+ ]
+ for i in range(len(weights)):
+ assert weights[i] is not None, "index:" + str(i) + " " + errors
+ return
diff --git a/lightllm/models/qwen3_vl_moe/layer_weights/transformers_layer_weight.py b/lightllm/models/qwen3_vl_moe/layer_weights/transformers_layer_weight.py
new file mode 100644
index 000000000..3a3af7e5a
--- /dev/null
+++ b/lightllm/models/qwen3_vl_moe/layer_weights/transformers_layer_weight.py
@@ -0,0 +1,81 @@
+import os
+import torch
+import math
+import numpy as np
+from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight
+from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight
+from lightllm.common.basemodel.layer_weights.meta_weights import (
+ ROWMMWeight,
+ MultiROWMMWeight,
+ COLMMWeight,
+ NormWeight,
+ FusedMoeWeightTP,
+ FusedMoeWeightEP,
+ ROWBMMWeight,
+)
+
+
+class Qwen3VLMOETransformerLayerWeight(Qwen3MOETransformerLayerWeight):
+ def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None):
+ super().__init__(layer_num, data_type, network_config, mode, quant_cfg)
+
+ def _init_weight_names(self):
+ super()._init_weight_names()
+ self._q_weight_name = f"model.language_model.layers.{self.layer_num_}.self_attn.q_proj.weight"
+ self._q_norm_name = f"model.language_model.layers.{self.layer_num_}.self_attn.q_norm.weight"
+ self._q_bias_name = None
+ self._k_weight_name = f"model.language_model.layers.{self.layer_num_}.self_attn.k_proj.weight"
+ self._k_norm_name = f"model.language_model.layers.{self.layer_num_}.self_attn.k_norm.weight"
+ self._k_bias_name = None
+ self._v_weight_name = f"model.language_model.layers.{self.layer_num_}.self_attn.v_proj.weight"
+ self._v_bias_name = None
+ self._o_weight_name = f"model.language_model.layers.{self.layer_num_}.self_attn.o_proj.weight"
+ self._o_bias_name = None
+ self._att_norm_weight_name = f"model.language_model.layers.{self.layer_num_}.input_layernorm.weight"
+ self._att_norm_bias_name = None
+ self._ffn_norm_weight_name = f"model.language_model.layers.{self.layer_num_}.post_attention_layernorm.weight"
+ self._ffn_norm_bias_name = None
+
+ def _init_moe(self):
+ moe_intermediate_size = self.network_config_["moe_intermediate_size"]
+ self.moe_gate = ROWMMWeight(
+ weight_name=f"model.language_model.layers.{self.layer_num_}.mlp.gate.weight",
+ data_type=self.data_type_,
+ layer_num=self.layer_num_,
+ name="moe_gate",
+ tp_rank=0,
+ tp_world_size=1,
+ )
+ moe_mode = os.getenv("MOE_MODE", "TP")
+ assert moe_mode in ["EP", "TP"]
+
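+        # up_proj_name=None signals the fused gate_up_proj checkpoint layout ([E, H, 2I]),
+        # which the fused MoE weight loader splits back into per-rank gate/up slices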
+ if moe_mode == "TP":
+ self.experts = FusedMoeWeightTP(
+ gate_proj_name="gate_up_proj",
+ down_proj_name="down_proj",
+ up_proj_name=None,
+ e_score_correction_bias_name="",
+ weight_prefix=f"model.language_model.layers.{self.layer_num_}.mlp.experts",
+ n_routed_experts=self.n_routed_experts,
+ split_inter_size=moe_intermediate_size // self.tp_world_size_,
+ data_type=self.data_type_,
+ network_config=self.network_config_,
+ layer_num=self.layer_num_,
+ quant_cfg=self.quant_cfg,
+ num_fused_shared_experts=0,
+ )
+ elif moe_mode == "EP":
+ self.experts = FusedMoeWeightEP(
+ gate_proj_name="gate_up_proj",
+ down_proj_name="down_proj",
+ up_proj_name=None,
+ e_score_correction_bias_name="",
+ weight_prefix=f"model.language_model.layers.{self.layer_num_}.mlp.experts",
+ n_routed_experts=self.n_routed_experts,
+ data_type=self.data_type_,
+ network_config=self.network_config_,
+ layer_num=self.layer_num_,
+ quant_cfg=self.quant_cfg,
+ )
+ else:
+ raise ValueError(f"Unsupported moe mode: {moe_mode}")
diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py
index b5b31a413..fa8e7b414 100644
--- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py
+++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py
@@ -6,7 +6,13 @@
from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer
from lightllm.utils.infer_utils import mark_cost_time
-from lightllm.server.embed_cache.utils import bytes2tensor, read_shm, get_shm_name_embed
+from lightllm.server.embed_cache.utils import (
+ bytes2tensor,
+ read_shm,
+ get_shm_name_embed,
+ get_shm_name_deepstack,
+ bytes2list,
+)
from lightllm.common.basemodel.triton_kernel.multimodal_emb import multimodal_emb
from lightllm.distributed.communication_op import all_reduce
@@ -29,6 +35,7 @@
class LlamaMultimodalPreLayerInfer(LlamaPreLayerInfer):
def __init__(self, network_config, mode):
super().__init__(network_config, mode)
+ self.use_deepstack = False
return
def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight):
@@ -50,9 +57,18 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei
# skip the same image
if img["token_id"] in img_start_token_ids or img["_prefill_"] is False:
continue
+ pos = (input_ids == img["token_id"]).nonzero(as_tuple=True)
+ if pos[0].numel() == 0:
+ continue
# pull the img_embeds by uid from shm
data = read_shm(get_shm_name_embed(img["uuid"]))
img_weight.append(bytes2tensor(data).cuda().reshape(img["token_num"], -1))
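+            # stash this image's deepstack features and its token span on infer_state so the
+            # early transformer layers can add them back during context_forward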
+ if self.use_deepstack:
+ deepstack_features = read_shm(get_shm_name_deepstack(img["uuid"]))
+ infer_state.deepstack_features.append(bytes2list(deepstack_features))
+ img_insert_locs = int(pos[0][0])
+ infer_state.img_first_token_locs.append(img_insert_locs)
+ infer_state.img_last_token_locs.append(img_insert_locs + img["token_num"])
img_start_token_ids.append(img["token_id"])
img_token_lens.append(img["token_num"])
img_start_locs.append(img_start_loc)
diff --git a/lightllm/server/embed_cache/utils.py b/lightllm/server/embed_cache/utils.py
index 6df031293..08d49c17d 100644
--- a/lightllm/server/embed_cache/utils.py
+++ b/lightllm/server/embed_cache/utils.py
@@ -1,6 +1,7 @@
import torch
import numpy as np
from io import BytesIO
+from typing import List, Optional
import multiprocessing.shared_memory as shm
@@ -20,11 +21,52 @@ def tensor2bytes(t: torch.Tensor):
return buf.read()
+def list2bytes(tensors: List[torch.Tensor]) -> bytes:
+    # detach each tensor, move it to CPU, and make an owned copy
+ safe_list = []
+ for t in tensors:
+ if t is None:
+ safe_list.append(None)
+ continue
+ t = t.detach().cpu()
+ if not t.is_contiguous():
+ t = t.contiguous()
+ dest = torch.empty_like(t)
+ dest.copy_(t)
+ safe_list.append(dest)
+ buf = BytesIO()
+ torch.save(safe_list, buf, _use_new_zipfile_serialization=False, pickle_protocol=4)
+ buf.seek(0)
+ return buf.read()
+
+
def bytes2tensor(b):
# return torch.from_numpy(np.frombuffer(b, dtype=np.float16)).cuda()
return torch.load(BytesIO(b), weights_only=False)
+def bytes2list(b: bytes, device: Optional[torch.device] = None, non_blocking: bool = False) -> List[torch.Tensor]:
+ obj = torch.load(BytesIO(b), map_location="cpu", weights_only=False)
+
+ if isinstance(obj, tuple):
+ obj = list(obj)
+ if not isinstance(obj, list):
+ raise TypeError(f"Loaded object is {type(obj)}, expected list or tuple.")
+
+ if device is None:
+ return obj
+
+ out: List[torch.Tensor] = []
+ for x in obj:
+ if x is None:
+ out.append(None)
+ elif isinstance(x, torch.Tensor):
+ out.append(x.to(device, non_blocking=non_blocking))
+ else:
+ raise TypeError(f"List element is {type(x)}, expected Tensor or None.")
+ return out
+
+
def create_shm(name, data):
try:
data_size = len(data)
@@ -53,3 +95,7 @@ def get_shm_name_data(uid):
def get_shm_name_embed(uid):
return str(uid) + "-embed"
+
+
+def get_shm_name_deepstack(uid):
+ return str(uid) + "-deepstack"
diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py
index 1f10aa5ec..e0b2bd425 100644
--- a/lightllm/server/tokenizer.py
+++ b/lightllm/server/tokenizer.py
@@ -28,6 +28,7 @@
from ..models.llava.model import LlavaTokenizer
from ..models.qwen_vl.model import QWenVLTokenizer
from ..models.qwen2_vl.model import QWen2VLTokenizer
+from ..models.qwen3_vl.model import QWen3VLTokenizer
from ..models.internvl.model import InternvlTokenizer
from ..models.gemma3.model import Gemma3Tokenizer
@@ -92,6 +93,13 @@ def get_tokenizer(
tokenizer = QWen2VLTokenizer(
tokenizer=tokenizer, image_processor=processor.image_processor, model_cfg=model_cfg
)
+ elif model_type in ["qwen3_vl", "qwen3_vl_moe"] and "vision_config" in model_cfg:
+ from transformers import AutoProcessor
+
+ processor = AutoProcessor.from_pretrained(tokenizer_name)
+ tokenizer = QWen3VLTokenizer(
+ tokenizer=tokenizer, image_processor=processor.image_processor, model_cfg=model_cfg
+ )
elif model_type == "internvl_chat":
tokenizer = InternvlTokenizer(tokenizer, model_cfg, weight_dir=tokenizer_name)
elif model_type == "gemma3":
diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py
index a25065e42..5815e7b81 100644
--- a/lightllm/server/visualserver/model_infer/model_rpc.py
+++ b/lightllm/server/visualserver/model_infer/model_rpc.py
@@ -16,8 +16,17 @@
from lightllm.server.multimodal_params import MultimodalParams, ImageItem
from lightllm.models.qwen2_vl.qwen2_visual import Qwen2VisionTransformerPretrainedModel
from lightllm.models.qwen2_5_vl.qwen2_5_visual import Qwen2_5_VisionTransformerPretrainedModel
+from lightllm.models.qwen3_vl.qwen3_visual import Qwen3VisionTransformerPretrainedModel
from lightllm.models.tarsier2.tarsier2_visual import TarsierVisionTransformerPretrainedModel
-from lightllm.server.embed_cache.utils import tensor2bytes, read_shm, create_shm, get_shm_name_data, get_shm_name_embed
+from lightllm.server.embed_cache.utils import (
+ tensor2bytes,
+ read_shm,
+ create_shm,
+ get_shm_name_data,
+ get_shm_name_embed,
+ get_shm_name_deepstack,
+ list2bytes,
+)
from lightllm.utils.infer_utils import set_random_seed
from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end
from lightllm.utils.dist_utils import init_vision_distributed_env
@@ -63,6 +72,10 @@ def exposed_init_model(self, kvargs):
self.model = (
Qwen2_5_VisionTransformerPretrainedModel(kvargs, **model_cfg["vision_config"]).eval().bfloat16()
)
+ elif self.model_type in ["qwen3_vl", "qwen3_vl_moe"]:
+ self.model = (
+ Qwen3VisionTransformerPretrainedModel(kvargs, **model_cfg["vision_config"]).eval().bfloat16()
+ )
elif model_cfg["architectures"][0] == "TarsierForConditionalGeneration":
self.model = TarsierVisionTransformerPretrainedModel(**model_cfg).eval().bfloat16()
elif self.model_type == "llava":
@@ -96,7 +109,8 @@ def forward(self, images: List[ImageItem]):
# @calculate_time(show=False, min_cost_ms=300)
def exposed_encode(self, images: List[ImageItem]):
images = obtain(images)
- all_img_embeds, uuids, valid_ids = self.forward(images)
+ all_img_embeds, uuids, valid_ids, *deepstack_features = self.forward(images)
+ deepstack_feature_lists = deepstack_features[0] if deepstack_features else None
all_img_embeds = all_img_embeds.to(torch.device("cpu"))
if self.tp_rank_id == 0:
@@ -109,6 +123,10 @@ def exposed_encode(self, images: List[ImageItem]):
start, end = valid_ids[i]
cur_embed_bytes = tensor2bytes(all_img_embeds[start:end])
create_shm(get_shm_name_embed(uid), cur_embed_bytes)
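+                # each tensor in deepstack_feature_lists covers all images in the batch; slice
+                # out this image's rows before writing them to shared memory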
+ if deepstack_feature_lists is not None:
+ per_image_deepstack = [feat[start:end] for feat in deepstack_feature_lists]
+ deepstack_features_bytes = list2bytes(per_image_deepstack)
+ create_shm(get_shm_name_deepstack(uid), deepstack_features_bytes)
ids_to_set.append(uid)
if ids_to_set:
self.cache_client.root.set_items_embed(ids_to_set)
diff --git a/lightllm/utils/config_utils.py b/lightllm/utils/config_utils.py
index afaf71a25..91543f14c 100644
--- a/lightllm/utils/config_utils.py
+++ b/lightllm/utils/config_utils.py
@@ -78,6 +78,9 @@ def get_vocab_size(model_path: str):
if "llm_config" in config_json:
vocab_size = int(config_json["llm_config"]["vocab_size"])
return vocab_size
+ elif "text_config" in config_json:
+ vocab_size = int(config_json["text_config"]["vocab_size"])
+ return vocab_size
vocab_size = config_json["vocab_size"]
if not isinstance(vocab_size, int):
vocab_size = int(vocab_size)