@@ -134,6 +134,8 @@ def _fuse(self):

inter_shape, hidden_size = self.w2_list[0].shape[0], self.w2_list[0].shape[1]
w2 = torch._utils._flatten_dense_tensors(self.w2_list).view(len(self.w2_list), inter_shape, hidden_size)
if self.fused_gate_up:
w2 = w2.transpose(1, 2).contiguous()
if not self.quantized_weight and self.quant_method is not None:
self.w1 = self.quant_method.quantize(w1)
self.w2 = self.quant_method.quantize(w2)
@@ -178,26 +180,53 @@ def _fuse_weight_scale(self):
def load_hf_weights(self, weights):
if self.e_score_correction_bias_name in weights:
self.e_score_correction_bias = self._cuda(weights[self.e_score_correction_bias_name])
for i_experts in range(self.n_routed_experts):
w1_weight = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.weight"
w2_weight = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.weight"
w3_weight = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.weight"
self.fused_gate_up = self.w3_weight_name is None # gate_up: [E,H,2I] down: [E,I,H]
key_gateup_3d = f"{self.weight_prefix}.{self.w1_weight_name}" # ...experts.gate_up_proj
key_down_3d = f"{self.weight_prefix}.{self.w2_weight_name}"

if w1_weight in weights:
self.experts_gate_projs[i_experts] = weights[w1_weight][
self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), :
]
if w3_weight in weights:
self.experts_up_projs[i_experts] = weights[w3_weight][
self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), :
]
if self.fused_gate_up and (key_gateup_3d in weights) and (key_down_3d in weights):
gate_up_3d = weights[key_gateup_3d]
down_3d = weights[key_down_3d]
assert gate_up_3d.dim() == 3 and down_3d.dim() == 3

if w2_weight in weights:
self.w2_list[i_experts] = weights[w2_weight][
:, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1)
]
E_ckpt, H_, twoE = gate_up_3d.shape
assert E_ckpt == self.n_routed_experts, f"experts mismatch: ckpt {E_ckpt} vs cfg {self.n_routed_experts}"
Eint_total = twoE // 2
start, end = self.tp_rank_ * self.split_inter_size, (self.tp_rank_ + 1) * self.split_inter_size
assert end <= Eint_total, "TP split exceeds total expert-intermediate size"

for i in range(self.n_routed_experts):
gu2d = gate_up_3d[i]
gate2d = gu2d[:, :Eint_total][:, start:end].t().contiguous()
up2d = gu2d[:, Eint_total:][:, start:end].t().contiguous()
self.experts_gate_projs[i] = gate2d
self.experts_up_projs[i] = up2d

self.w2_list[i] = down_3d[i][start:end, :].contiguous()
else:
for i_experts in range(self.n_routed_experts):
w1_weight = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.weight"
w2_weight = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.weight"
w3_weight = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.weight"

if w1_weight in weights:
self.experts_gate_projs[i_experts] = weights[w1_weight][
self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), :
]
if w3_weight in weights:
self.experts_up_projs[i_experts] = weights[w3_weight][
self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), :
]

if w2_weight in weights:
self.w2_list[i_experts] = weights[w2_weight][
:, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1)
]
if self.quant_method is not None:
self._load_weight_scale(weights)
if self.fused_gate_up:
raise ValueError("qwen3_vl_moe not support quant now")
Comment on lines +226 to +227 (Contributor, medium):

This is a good safeguard to prevent running quantization on a model variant that does not support it yet. However, instead of raising a ValueError, it would be more informative to log a warning and skip quantization for this case, allowing the model to load and run in a non-quantized mode. This would provide more flexibility.
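A minimal sketch of the warn-and-skip alternative described above (standard-library logging is used only for illustration; the attribute names mirror the surrounding code, and clearing quant_method is just one possible way to fall back):

import logging

logger = logging.getLogger(__name__)

if self.quant_method is not None:
    if self.fused_gate_up:
        # quantization is not wired up for the fused gate_up layout yet:
        # warn and keep the weights unquantized instead of aborting the load
        logger.warning("qwen3_vl_moe does not support quantization yet; loading weights unquantized")
        self.quant_method = None
    else:
        self._load_weight_scale(weights)
self._fuse()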

else:
self._load_weight_scale(weights)
self._fuse()

def _load_weight_scale(self, weights: Dict[str, torch.Tensor]) -> None:
1 change: 1 addition & 0 deletions lightllm/models/__init__.py
@@ -29,6 +29,7 @@
from lightllm.models.internvl.model import InternVLInternlm2TpPartModel
from lightllm.models.qwen2_vl.model import Qwen2VLTpPartModel
from lightllm.models.qwen2_reward.model import Qwen2RewardTpPartModel
from lightllm.models.qwen3_vl.model import Qwen3VLTpPartModel, Qwen3VLMOETpPartModel
from lightllm.models.gemma3.model import Gemma3TpPartModel
from lightllm.models.tarsier2.model import (
Tarsier2Qwen2TpPartModel,
6 changes: 5 additions & 1 deletion lightllm/models/llama/model.py
@@ -111,6 +111,8 @@ def _init_custom(self):
if rope_scaling is None:
self._init_to_get_rotary()
return
if "mrope_section" in rope_scaling:
self.mrope_section = rope_scaling["mrope_section"]

if "rope_type" in rope_scaling:
scaling_type = rope_scaling["rope_type"]
@@ -128,6 +130,8 @@ def _init_custom(self):
self._init_to_get_llama3_rotary()
elif scaling_type == "mrope":
self._init_to_get_mrope_rotary()
elif scaling_type == "default":
self._init_to_get_rotary()
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
return
@@ -204,7 +208,7 @@ def _init_to_get_rotary(self, default_base=10000):
/ rope_scaling_factor
)
freqs = torch.outer(t, inv_freq)

self.freqs = freqs.cuda()
self._cos_cached = torch.cos(freqs).to(self.data_type).cuda()
self._sin_cached = torch.sin(freqs).to(self.data_type).cuda()
return
7 changes: 6 additions & 1 deletion lightllm/models/qwen2_5_vl/qwen2_5_visual.py
@@ -383,7 +383,12 @@ def encode(self, images: List[ImageItem]):
uuids.append(img.uuid)
image_data = read_shm(get_shm_name_data(img.uuid))
image_data = Image.open(BytesIO(image_data))
image_data = resize_image(image_data)
image_data = resize_image(
image_file=image_data,
factor=self.processor.patch_size * self.processor.merge_size,
min_pixels=self.processor.min_pixels,
max_pixels=self.processor.max_pixels,
)
pixel_values, image_grid_thw = self.processor.preprocess(image_data)
img_tensors.append(pixel_values)
img_grids.append(image_grid_thw)
3 changes: 2 additions & 1 deletion lightllm/models/qwen2_vl/model.py
@@ -51,8 +51,9 @@ def init_audioitem_extral_params(

def get_image_token_length(self, img: ImageItem):
width, height = img.image_w, img.image_h
factor = self.patch_size * self.merge_size
resized_height, resized_width = smart_resize(
height=height, width=width, min_pixels=self.min_pixel, max_pixels=self.max_pixel
height=height, width=width, factor=factor, min_pixels=self.min_pixel, max_pixels=self.max_pixel
)
grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
token_num = (grid_h * grid_w) // (self.merge_size ** 2)
7 changes: 6 additions & 1 deletion lightllm/models/qwen2_vl/qwen2_visual.py
@@ -311,7 +311,12 @@ def encode(self, images: List[ImageItem]):
uuids.append(img.uuid)
image_data = read_shm(get_shm_name_data(img.uuid))
image_data = Image.open(BytesIO(image_data))
image_data = resize_image(image_data)
image_data = resize_image(
image_file=image_data,
factor=self.processor.patch_size * self.processor.merge_size,
min_pixels=self.processor.min_pixels,
max_pixels=self.processor.max_pixels,
)
pixel_values, image_grid_thw = self.processor.preprocess(image_data)
img_tensors.append(pixel_values)
img_grids.append(image_grid_thw)
31 changes: 21 additions & 10 deletions lightllm/models/qwen2_vl/vision_process.py
@@ -35,34 +35,36 @@ def smart_resize(
height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:

if max(height, width) / min(height, width) > MAX_RATIO:
if max(height, width) / min(height, width) > 200:
raise ValueError(
f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
)
h_bar = max(factor, round(height / factor) * factor)
w_bar = max(factor, round(width / factor) * factor)
h_bar = round(height / factor) * factor
w_bar = round(width / factor) * factor
Comment on lines +42 to +43 (Contributor, high):

The calculation for h_bar and w_bar could result in zero if height or width are very small compared to factor (e.g., if height / factor < 0.5). The previous implementation prevented this by using max(factor, ...). This logic has been moved into one branch of the conditional but is missing from the initial calculation. This could lead to errors or incorrect resizing for small images. It's safer to restore the max guard for the initial calculation.

Suggested change:
-    h_bar = round(height / factor) * factor
-    w_bar = round(width / factor) * factor
+    h_bar = max(factor, round(height / factor) * factor)
+    w_bar = max(factor, round(width / factor) * factor)
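To make the failure mode concrete, a quick arithmetic check (factor = 28 and the pixel bounds below are illustrative assumptions, not values taken from this diff):

factor = 28                    # illustrative: patch_size (14) * merge_size (2)
height, width = 10, 2000       # extreme but allowed: aspect ratio 200 passes the check above

h_bar = round(height / factor) * factor   # round(0.357) * 28 = 0
w_bar = round(width / factor) * factor    # round(71.43) * 28 = 1988

# With min_pixels > 0 the `elif h_bar * w_bar < min_pixels` branch still rescues this
# case (0 < min_pixels, and it recomputes h_bar from the original height with ceil),
# but if min_pixels is configured as 0 the zero survives and image.resize((w_bar, h_bar))
# later receives a zero-height target. The suggested max(factor, ...) guard removes
# that dependence on min_pixels.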

if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = math.floor(height / beta / factor) * factor
w_bar = math.floor(width / beta / factor) * factor
h_bar = max(factor, math.floor(height / beta / factor) * factor)
w_bar = max(factor, math.floor(width / beta / factor) * factor)
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = math.ceil(height * beta / factor) * factor
w_bar = math.ceil(width * beta / factor) * factor
return h_bar, w_bar


def resize_image(image_file: Image.Image, size_factor: int = IMAGE_FACTOR) -> tuple[Image.Image, int, int]:
def resize_image(
image_file: Image.Image, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[Image.Image, int, int]:

image = image_file.convert("RGB")
width, height = image.size

resized_height, resized_width = smart_resize(
height,
width,
factor=size_factor,
min_pixels=MIN_PIXELS,
max_pixels=MAX_PIXELS,
factor=factor,
min_pixels=min_pixels,
max_pixels=max_pixels,
)
image = image.resize((resized_width, resized_height))

@@ -72,6 +74,7 @@ def resize_image(image_file: Image.Image, size_factor: int = IMAGE_FACTOR) -> tuple[Image.Image, int, int]:
class Qwen2VLImageProcessor(BaseImageProcessor):
def __init__(
self,
size: dict = None,
do_resize: bool = True,
resample: PILImageResampling = PILImageResampling.BICUBIC,
do_rescale: bool = True,
Expand All @@ -88,6 +91,7 @@ def __init__(
**kwargs,
) -> None:
super().__init__(**kwargs)
self.size = size
self.do_resize = do_resize
self.resample = resample
self.do_rescale = do_rescale
Expand All @@ -102,6 +106,13 @@ def __init__(
self.temporal_patch_size = temporal_patch_size
self.merge_size = merge_size
self.data_format = ChannelDimension.FIRST
if isinstance(self.size, dict):
shortest = self.size.get("shortest_edge", None)
longest = self.size.get("longest_edge", None)
if shortest is not None:
self.min_pixels = shortest
if longest is not None:
self.max_pixels = longest

def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]:
if self.do_convert_rgb:
51 changes: 51 additions & 0 deletions lightllm/models/qwen3_vl/infer_struct.py
@@ -0,0 +1,51 @@
import torch
import numpy as np
from lightllm.models.llama.infer_struct import LlamaInferStateInfo
from lightllm.common.basemodel.infer_struct import InferStateInfo


class Qwen3VLInferStateInfo(LlamaInferStateInfo):
def __init__(self):
super().__init__()
self.deepstack_features = []
self.img_first_token_locs = []
self.img_last_token_locs = []

def apply_interleaved_mrope(self, freqs, mrope_section):
"""Apply interleaved MRoPE to 3D rotary embeddings.
Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
interleaved [THTHWHTHW...TT], preserving frequency continuity.
args:
x: (3, bs, seq_len, head_dim // 2)
mrope_section: (3,)
returns:
x_t: (bs, seq_len, head_dim // 2)
Comment on lines +18 to +22 (Contributor, medium):

The docstring for apply_interleaved_mrope is inaccurate. The argument is freqs with a shape of (3, seq_len, head_dim // 2), but the docstring refers to x with a bs (batch size) dimension, which is not present. The return value shape is also (seq_len, head_dim // 2). Please update the docstring for clarity.

Suggested change:
-    args:
-        x: (3, bs, seq_len, head_dim // 2)
-        mrope_section: (3,)
-    returns:
-        x_t: (bs, seq_len, head_dim // 2)
+    args:
+        freqs: (3, seq_len, head_dim // 2)
+        mrope_section: (3,)
+    returns:
+        freqs_t: (seq_len, head_dim // 2)

"""
freqs_t = freqs[0] # just overwrite the first dimension T
for dim, offset in enumerate((1, 2), start=1): # H, W
length = mrope_section[dim] * 3
idx = slice(offset, length, 3)
freqs_t[..., idx] = freqs[dim, ..., idx]
return freqs_t

def init_some_extra_state(self, model, input_ids: torch.Tensor):
InferStateInfo.init_some_extra_state(self, model, input_ids)
pos = self.position_ids[None, :].expand(3, -1)
cos_T = torch.index_select(model._cos_cached, 0, pos[0]) # [L, d/2]
cos_H = torch.index_select(model._cos_cached, 0, pos[1])
cos_W = torch.index_select(model._cos_cached, 0, pos[2])
sin_T = torch.index_select(model._sin_cached, 0, pos[0])
sin_H = torch.index_select(model._sin_cached, 0, pos[1])
sin_W = torch.index_select(model._sin_cached, 0, pos[2])
cos_half = self.apply_interleaved_mrope(
torch.stack([cos_T, cos_H, cos_W], dim=0), model.mrope_section
) # [L, d/2]
sin_half = self.apply_interleaved_mrope(
torch.stack([sin_T, sin_H, sin_W], dim=0), model.mrope_section
) # [L, d/2]

self.position_cos = torch.cat([cos_half, cos_half], dim=-1).contiguous() # [L, d]
self.position_sin = torch.cat([sin_half, sin_half], dim=-1).contiguous()
if self.is_prefill:
pos = None
Comment (Contributor, medium):

The line pos = None appears to be leftover debugging code, as the pos variable is not used after this assignment. It should be removed to improve code clarity.

return
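As a side note on apply_interleaved_mrope above, a tiny standalone sketch (mrope_section = [4, 1, 1] and head_dim // 2 = 6 are made-up values for the demo) shows the index layout it produces: positions 1, 4, 7, ... below 3 * mrope_section[1] take the H frequencies, positions 2, 5, 8, ... below 3 * mrope_section[2] take W, and every remaining slot keeps T.

import torch

mrope_section = [4, 1, 1]                 # made-up sections; real configs sum to head_dim // 2
half_dim = sum(mrope_section)             # 6
# encode the source axis as a value so the layout is visible: 0 = T, 1 = H, 2 = W
freqs = torch.stack([torch.full((1, half_dim), float(axis)) for axis in range(3)])

freqs_t = freqs[0].clone()                # the real code writes into freqs[0] in place
for dim, offset in enumerate((1, 2), start=1):     # H, W
    idx = slice(offset, mrope_section[dim] * 3, 3)
    freqs_t[..., idx] = freqs[dim, ..., idx]

print(freqs_t[0].tolist())                # [0.0, 1.0, 2.0, 0.0, 0.0, 0.0] -> T, H, W, then trailing T slots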
8 changes: 8 additions & 0 deletions lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py
@@ -0,0 +1,8 @@
from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer


class Qwen3VLMultimodalPreLayerInfer(LlamaMultimodalPreLayerInfer):
def __init__(self, network_config, mode):
super().__init__(network_config, mode)
self.use_deepstack = True
return
64 changes: 64 additions & 0 deletions lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py
@@ -0,0 +1,64 @@
import torch
import torch.functional as F
import torch.distributed as dist
import numpy as np
from functools import partial
from typing import Tuple
from lightllm.common.basemodel.infer_struct import InferStateInfo
from lightllm.models.qwen2_vl.triton_kernel.mrope import mrope_triton
from lightllm.models.qwen3.layer_infer.transformer_layer_infer import Qwen3TransformerLayerInfer
from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
from lightllm.models.llama.infer_struct import LlamaInferStateInfo
from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo
from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd
from lightllm.distributed import all_reduce
from lightllm.utils.dist_utils import get_global_world_size


class Qwen3VLTransformerLayerInfer(Qwen3TransformerLayerInfer):
def __init__(self, layer_num, network_config, mode=[]):
super().__init__(layer_num, network_config, mode)
self.mrope_section = network_config["rope_scaling"]["mrope_section"]
axis_map = []
for i, n in enumerate(self.mrope_section * 2):
axis_map += [i % 3] * n
self.axis_map = torch.tensor(axis_map, dtype=torch.int32, device="cuda")
Comment on lines +25 to +28 (Contributor, medium):

The instance variable self.axis_map is initialized here but is not used anywhere in the class. This appears to be dead code and should be removed.


def context_forward(self, input_embdings, infer_state: Qwen3VLInferStateInfo, layer_weight):
input1 = self._att_norm(input_embdings, infer_state, layer_weight)
q, cache_kv = self._get_qkv(input1, infer_state, layer_weight)
input1 = None
self._post_cache_kv(cache_kv, infer_state, layer_weight)
o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight)
q = None
o = self._get_o(o, infer_state, layer_weight)
if self.tp_world_size_ > 1:
all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
input_embdings.add_(o.view(-1, self.embed_dim_))
o = None

input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
ffn_out = self._ffn(input1, infer_state, layer_weight)
input1 = None
if self.tp_world_size_ > 1:
all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
if infer_state.deepstack_features:
for i in range(len(infer_state.img_first_token_locs)):
start = infer_state.img_first_token_locs[i]
end = infer_state.img_last_token_locs[i]
deepstack_features = infer_state.deepstack_features[i]
if end <= input_embdings.shape[0] and self.layer_num_ in range(len(deepstack_features)):
deepstack_features_cur_layer = deepstack_features[self.layer_num_].to(
device=input_embdings.device, non_blocking=True
)
input_embdings[
start:end,
].add_(deepstack_features_cur_layer)
infer_state.img_first_token_locs = []
infer_state.img_last_token_locs = []
infer_state.deepstack_features = []
return input_embdings
@@ -0,0 +1,37 @@
import torch
import numpy as np
from lightllm.common.basemodel import PreAndPostLayerWeight


class Qwen3VLPreAndPostLayerWeight(PreAndPostLayerWeight):
def __init__(self, data_type, network_config, mode):
super().__init__(data_type, network_config, mode)
return

def load_hf_weights(self, weights):
vob_size = self.network_config_["vocab_size"]
split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64)
split_start = split_indexes[self.tp_rank_]
split_end = split_indexes[self.tp_rank_ + 1]
if "model.language_model.embed_tokens.weight" in weights:
self.wte_weight_ = self._cuda(weights["model.language_model.embed_tokens.weight"][split_start:split_end, :])
tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False)
if tie_word_embeddings:
self.lm_head_weight_ = self.wte_weight_
if "lm_head.weight" in weights:
self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :])
if "model.language_model.norm.weight" in weights:
self.final_norm_weight_ = self._cuda(weights["model.language_model.norm.weight"])

return

def verify_load(self):
errors = "weights load not ok"
weights = [
self.wte_weight_,
self.lm_head_weight_,
self.final_norm_weight_,
]
for i in range(len(weights)):
assert weights[i] is not None, "index:" + str(i) + " " + errors
return