diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 4516e18c3..5a4858331 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -265,6 +265,9 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0) infer_state.b_ready_cache_len = torch.zeros_like(input=infer_state.b_seq_len) infer_state.multimodal_params = model_input.multimodal_params + infer_state.image_start_locs = model_input.image_start_locs + infer_state.image_token_lens = model_input.image_token_lens + infer_state.image_start_token_ids = model_input.image_start_token_ids infer_state.mem_manager = self.mem_manager infer_state.req_manager = self.req_manager
diff --git a/lightllm/common/basemodel/batch_objs.py b/lightllm/common/basemodel/batch_objs.py index 9b317b423..3bf372372 100644 --- a/lightllm/common/basemodel/batch_objs.py +++ b/lightllm/common/basemodel/batch_objs.py @@ -17,6 +17,9 @@ class ModelInput: mem_indexes: torch.Tensor = None is_prefill: bool = False b_ready_cache_len: torch.Tensor = None + image_start_locs: torch.Tensor = None + image_token_lens: torch.Tensor = None + image_start_token_ids: torch.Tensor = None multimodal_params: list = field(default_factory=list) # cpu 变量
diff --git a/lightllm/models/gemma3/gemma3_visual.py b/lightllm/models/gemma3/gemma3_visual.py index b2f7a6b77..16e244c3e 100644 --- a/lightllm/models/gemma3/gemma3_visual.py +++ b/lightllm/models/gemma3/gemma3_visual.py @@ -16,7 +16,10 @@ class Gemma3VisionModel: - def __init__(self): + def __init__(self, kvargs): + self.weight_dir = kvargs["weight_dir"] + self.load_model(self.weight_dir) + self.cuda() pass def load_model(self, weight_dir):
diff --git a/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py index f19563932..e50f8eb2d 100644 --- a/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py @@ -3,6 +3,8 @@ from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight from lightllm.models.internlm2.layer_weights.pre_and_post_layer_weight import Internlm2PreAndPostLayerWeight +from lightllm.models.vit.model import VisionTransformer +from lightllm.utils.envs_utils import get_env_start_args # add key: language_model.xxx -> xxx @@ -15,9 +17,36 @@ def rename_weight_keys(weights): weights[k[len(prefix) :]] = weights[k] +def build_visual_model(args, data_type: torch.dtype): + if args.disable_extra_process_for_multimodal: + kvargs = { + "weight_dir": args.model_dir, + "data_type": data_type, + "quant_type": args.vit_quant_type, + "quant_cfg": args.vit_quant_cfg, + "max_batch_size": args.visual_infer_batch_size, + } + return VisionTransformer(kvargs=kvargs) + return None + + +class InternVLPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): + def __init__(self, data_type, network_config, mode): + super().__init__(data_type, network_config, mode) + # if we don't assign an extra process for the visual model, we need to initialize the visual model here + self.visual_model = build_visual_model(get_env_start_args(), data_type) + return + + def load_hf_weights(self, weights): + rename_weight_keys(weights) + super().load_hf_weights(weights) + + class InternVLPhi3PreAndPostLayerWeight(LlamaPreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) + # if we don't assign an extra process for the visual model, we need to initialize the visual model here + self.visual_model = build_visual_model(get_env_start_args(), data_type) return def load_hf_weights(self, weights): @@ -29,6 +58,8 @@ def load_hf_weights(self, weights): class InternVLInternlm2PreAndPostLayerWeight(Internlm2PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) + # if we don't assign an extra process for the visual model, we need to initialize the visual model here + self.visual_model = build_visual_model(get_env_start_args(), data_type) return def load_hf_weights(self, weights): @@ -40,6 +71,8 @@ def load_hf_weights(self, weights): class InternVLLlamaPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) + # if we don't assign an extra process for the visual model, we need to initialize the visual model here + self.visual_model = build_visual_model(get_env_start_args(), data_type) return def load_hf_weights(self, weights):
diff --git a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py index 711406e3f..ef6f71382 100644 --- a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py @@ -23,7 +23,6 @@ def load_hf_weights(self, weights): self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :]) if "model.norm.weight" in weights: self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) - return def verify_load(self):
diff --git a/lightllm/models/llava/llava_visual.py b/lightllm/models/llava/llava_visual.py index 293bcd445..398bb66b6 100644 --- a/lightllm/models/llava/llava_visual.py +++ b/lightllm/models/llava/llava_visual.py @@ -15,7 +15,10 @@ class LlavaVisionModel: - def __init__(self): + def __init__(self, kvargs): + self.weight_dir = kvargs["weight_dir"] + self.load_model(self.weight_dir) + self.cuda() pass def load_model(self, weight_dir):
diff --git a/lightllm/models/qwen2_5_vl/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen2_5_vl/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 000000000..c3d3e2225 --- /dev/null +++ b/lightllm/models/qwen2_5_vl/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,27 @@ +import torch +import numpy as np +from lightllm.utils.envs_utils import get_env_start_args +from transformers.configuration_utils import PretrainedConfig +from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight +from lightllm.models.qwen2_5_vl.qwen2_5_visual import Qwen2_5VLTransformer + + +def build_visual_model(args, data_type: torch.dtype): + if args.disable_extra_process_for_multimodal: + kvargs = { + "weight_dir": args.model_dir, + "data_type": args.data_type, + "quant_type": args.vit_quant_type, + "quant_cfg": args.vit_quant_cfg, + "max_batch_size": args.visual_infer_batch_size, + } + model_cfg, _ = PretrainedConfig.get_config_dict(kvargs["weight_dir"]) + return Qwen2_5VLTransformer(kvargs=kvargs, **model_cfg["vision_config"]).eval().to(dtype=data_type) + return None + + +class Qwen2_5VLPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight): + def __init__(self, data_type, network_config, mode): + super().__init__(data_type, network_config, mode) + self.visual_model = build_visual_model(get_env_start_args(), data_type) + return
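Note: the build_visual_model helpers above all follow one pattern: when --disable_extra_process_for_multimodal is set, the ViT is constructed inside the pre/post layer-weight object, and it stays None when a separate visual process owns the model. Below is a minimal sketch of that pattern, assuming only what the hunks above show; everything except the build_visual_model / visual_model names and the flag is illustrative and not part of the patch.

from types import SimpleNamespace
import torch

def build_visual_model(args, data_type: torch.dtype):
    # In-process path: construct the vision model here; otherwise the
    # dedicated visual process owns it and the weight object keeps None.
    if args.disable_extra_process_for_multimodal:
        return torch.nn.Identity()  # stand-in for the real VisionTransformer(kvargs=...)
    return None

class DummyPreAndPostLayerWeight:  # illustrative holder, mirroring the classes above
    def __init__(self, args, data_type):
        self.visual_model = build_visual_model(args, data_type)

args = SimpleNamespace(disable_extra_process_for_multimodal=True)
weight = DummyPreAndPostLayerWeight(args, torch.bfloat16)
if weight.visual_model is None:
    print("image encoding is handled by the separate visual server process")
else:
    print("image encoding runs in-process next to the language model")

diff 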
--git a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py index a9e162e3f..addd448a2 100644 --- a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py +++ b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py @@ -16,7 +16,7 @@ from torch.nn import LayerNorm from transformers.activations import ACT2FN import math -from lightllm.models.qwen2_vl.vision_process import get_image, Qwen2VLImageProcessor +from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor from transformers import AutoProcessor from safetensors import safe_open from transformers.utils import TensorType @@ -24,48 +24,12 @@ from lightllm.models.qwen2_vl.qwen2_visual import PatchEmbed, VisionRotaryEmbedding from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager +from lightllm.models.qwen2_vl.triton_kernel.rotary_pos_emb import apply_rotary_pos_emb_triton # adapted from # https://github.com/huggingface/transformers/blob/ # be37d34f44ff1bc928e59ffb8a30adecab8835a8/src # /transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py#L30C1-L31C1 -class Qwen2_5_VLVisionConfig(PretrainedConfig): - model_type = "qwen2_5_vl" - - def __init__( - self, - depth=32, - hidden_size=3584, - hidden_act="silu", - intermediate_size=3420, - num_heads=16, - in_channels=3, - patch_size=14, - spatial_merge_size=2, - temporal_patch_size=2, - tokens_per_second=4, - window_size=112, - out_hidden_size=3584, - fullatt_block_indexes=[7, 15, 23, 31], - **kwargs, - ): - super().__init__(**kwargs) - - self.depth = depth - self.hidden_size = hidden_size - self.hidden_act = hidden_act - self.intermediate_size = intermediate_size - self.num_heads = num_heads - self.in_channels = in_channels - self.patch_size = patch_size - self.spatial_merge_size = spatial_merge_size - self.temporal_patch_size = temporal_patch_size - self.tokens_per_second = tokens_per_second - self.window_size = window_size - self.fullatt_block_indexes = fullatt_block_indexes - self.out_hidden_size = out_hidden_size - - class Qwen2RMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -104,27 +68,6 @@ def forward(self, hidden_state): return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb_vision( - q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor]: - orig_q_dtype = q.dtype - orig_k_dtype = k.dtype - q, k = q.float(), k.float() - cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float() - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - q_embed = q_embed.to(orig_q_dtype) - k_embed = k_embed.to(orig_k_dtype) - return q_embed, k_embed - - class Qwen2_5_VLVisionFlashAttention(nn.Module): def __init__(self, dim: int, num_heads: int = 16) -> None: super().__init__() @@ -132,26 +75,39 @@ def __init__(self, dim: int, num_heads: int = 16) -> None: self.head_dim = dim // num_heads self.qkv = nn.Linear(dim, dim * 3, bias=True) self.proj = nn.Linear(dim, dim) + try: + from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb + + self.has_vllm = True + self.apply_rotary_emb = apply_rotary_emb + except ImportError: + print("Failed to import _flash_attn_forward from 
hopper.flash_attn_interface.") + self.has_vllm = False + self.apply_rotary_emb = apply_rotary_pos_emb_triton + + def apply_rotary_pos_emb_vision(self, t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + t_ = t.float() + cos = freqs.cos() + sin = freqs.sin() + output = self.apply_rotary_emb(t_, cos, sin).type_as(t) + return output def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, + max_seqlen: int = 0, rotary_pos_emb: Optional[torch.Tensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> torch.Tensor: seq_length = hidden_states.shape[0] q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) - if position_embeddings is None: - emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) - cos = emb.cos() - sin = emb.sin() - else: - cos, sin = position_embeddings - q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) + # if position_embeddings is None: + # position_embeddings = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + q = self.apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb) + k = self.apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb) + q = q.squeeze(0) + k = k.squeeze(0) - cu_seqlens = cu_seqlens.to(q.device, torch.int32) - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() attn_output = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device) flash_attention_fwd(q, k, v, attn_output, cu_seqlens, max_seqlen) attn_output = attn_output.reshape(seq_length, -1) @@ -183,14 +139,14 @@ def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, + max_seqlen: int = 0, rotary_pos_emb: Optional[torch.Tensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> torch.Tensor: hidden_states = hidden_states + self.attn( self.norm1(hidden_states), cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, rotary_pos_emb=rotary_pos_emb, - position_embeddings=position_embeddings, ) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) return hidden_states @@ -212,9 +168,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class Qwen2_5_VisionTransformerPretrainedModel(nn.Module): +class Qwen2_5VLTransformer(nn.Module): def __init__( self, + kvargs, depth=32, hidden_size=3584, hidden_act="silu", @@ -231,7 +188,13 @@ def __init__( **kwargs, ): super().__init__() - + self.weight_dir = kvargs["weight_dir"] + self.data_type = kvargs.get("data_type", "bfloat16") + # self.mode = [m.replace("int4weight", "w4a16").replace("int8weight", "w8a16") for m in kvargs.get("mode", [])] + # self.weight_dict = kvargs.get("weight_dict", None) + # self.quant_type = kvargs.get("quant_type", None) + # self.quant_cfg_path = kvargs.get("quant_cfg", None) + # self.max_batch_size = kvargs.get("max_batch_size", 1) self.depth = depth self.hidden_size = hidden_size self.hidden_act = hidden_act @@ -278,41 +241,42 @@ def __init__( self.gradient_checkpointing = False - self.device = self.get_device() - self.dtype = self.get_dtype() - - def get_dtype(self) -> torch.dtype: - return self.blocks[0].mlp.down_proj.weight.dtype + processor_config_path = os.path.join(self.weight_dir, "preprocessor_config.json") + with open(processor_config_path, "r") as f: + processor_config_dict = json.load(f) + self.processor = Qwen2VLImageProcessor(**processor_config_dict) - def get_device(self) -> torch.device: - return self.blocks[0].mlp.down_proj.weight.device + self._init_datatype() + self.load_model(kvargs["weight_dir"]) + 
self.cuda() + + def _init_datatype(self): + if isinstance(self.data_type, torch.dtype): + return + if self.data_type in ["fp16", "float16"]: + self.data_type = torch.float16 + elif self.data_type in ["bf16", "bfloat16"]: + self.data_type = torch.bfloat16 + elif self.data_type in ["fp32", "float32"]: + self.data_type = torch.float32 + else: + raise ValueError(f"Unsupport datatype {self.data_type}!") + return def rot_pos_emb(self, grid_thw): pos_ids = [] - for t, h, w in grid_thw: + s = self.spatial_merge_size + for _, h, w in grid_thw: + pos_shape = (h // s, s, w // s, s) hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - hpos_ids = hpos_ids.permute(0, 2, 1, 3) - hpos_ids = hpos_ids.flatten() - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - wpos_ids = wpos_ids.permute(0, 2, 1, 3) - wpos_ids = wpos_ids.flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + hpos_ids = hpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten() + wpos_ids = wpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten() + + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1)) pos_ids = torch.cat(pos_ids, dim=0) max_grid_size = grid_thw[:, 1:].max() - rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size).type(torch.float32) rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) return rotary_pos_emb @@ -359,7 +323,14 @@ def get_window_index(self, grid_thw): def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: hidden_states = self.patch_embed(hidden_states) - rotary_pos_emb = self.rot_pos_emb(grid_thw) + rotary_pos_emb = self.rot_pos_emb(grid_thw).to("cuda", non_blocking=True) + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, dtype=torch.int32 + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + cu_seqlens = cu_seqlens.to("cuda", non_blocking=True) + window_index, cu_window_seqlens = self.get_window_index(grid_thw) cu_window_seqlens = torch.tensor( cu_window_seqlens, @@ -367,6 +338,7 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch. dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, ) cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + max_window_seqlen = (cu_window_seqlens[1:] - cu_window_seqlens[:-1]).max().item() seq_len, _ = hidden_states.size() hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) @@ -375,40 +347,21 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch. 
rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) rotary_pos_emb = rotary_pos_emb[window_index, :, :] rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) - emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) - position_embeddings = (emb.cos(), emb.sin()) - - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same - # dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 - # for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) for layer_num, blk in enumerate(self.blocks): if layer_num in self.fullatt_block_indexes: cu_seqlens_now = cu_seqlens + max_seqlen_now = max_seqlen else: cu_seqlens_now = cu_window_seqlens - if self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - blk.__call__, - hidden_states, - cu_seqlens_now, - None, - position_embeddings, - ) - else: - hidden_states = blk( - hidden_states, - cu_seqlens=cu_seqlens_now, - position_embeddings=position_embeddings, - ) + max_seqlen_now = max_window_seqlen + + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens_now, + max_seqlen=max_seqlen_now, + rotary_pos_emb=rotary_pos_emb, + ) hidden_states = self.merger(hidden_states) reverse_indices = torch.argsort(window_index) @@ -418,11 +371,6 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch. def load_model(self, weight_dir): - processor_config_path = os.path.join(weight_dir, "preprocessor_config.json") - with open(processor_config_path, "r") as f: - processor_config_dict = json.load(f) - self.processor = Qwen2VLImageProcessor(**processor_config_dict) - bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")] if bin_weight_files: weight_dict = {} @@ -443,6 +391,22 @@ def load_model(self, weight_dir): self.load_state_dict(weight_dict) + def load_image(self, img: List[ImageItem]): + pixel_values = None + if isinstance(img, ImageItem): + image_data = read_shm(get_shm_name_data(img.uuid)) + image_data = Image.open(BytesIO(image_data)) + image_data = resize_image(image_data) + pixel_values, image_grid_thw = self.processor.preprocess(image_data) + elif isinstance(img, dict): + image_data = read_shm(get_shm_name_data(img["uuid"])) + image_data = Image.open(BytesIO(image_data)) + image_data = resize_image(image_data) + pixel_values, image_grid_thw = self.processor.preprocess(image_data) + else: + raise Exception("Unsupport input types: {} for {}".format(type(img), img)) + return pixel_values.to(dtype=self.data_type), image_grid_thw + def encode(self, images: List[ImageItem]): img_tensors = [] valid_ids = [] @@ -455,10 +419,8 @@ def encode(self, images: List[ImageItem]): uuids.append(img.uuid) image_data = read_shm(get_shm_name_data(img.uuid)) image_data = Image.open(BytesIO(image_data)) - image_data = get_image(image_data) - image_inputs = self.processor.preprocess(images=image_data, return_tensors="pt") - pixel_values = image_inputs["pixel_values"].to(dtype=torch.bfloat16) - image_grid_thw = image_inputs["image_grid_thw"] + image_data = resize_image(image_data) + pixel_values, image_grid_thw = self.processor.preprocess(image_data) img_tensors.append(pixel_values) 
img_grids.append(image_grid_thw) else: @@ -476,10 +438,9 @@ def encode(self, images: List[ImageItem]): imgs = torch.cat(img_tensors, dim=0) grid_thw = torch.cat(img_grids, dim=0) - pixel_values = imgs.cuda().to(dtype=torch.float32) - image_grid_thw = grid_thw.cuda() + pixel_values = imgs.to("cuda", dtype=self.data_type, non_blocking=True) + image_grid_thw = grid_thw.to("cuda", non_blocking=True) - pixel_values = pixel_values.type(self.get_dtype()) all_img_embeds = self.forward(pixel_values, grid_thw=image_grid_thw) return all_img_embeds, uuids, valid_ids diff --git a/lightllm/models/qwen2_vl/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen2_vl/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 000000000..ec46102a6 --- /dev/null +++ b/lightllm/models/qwen2_vl/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,27 @@ +import torch +import numpy as np +from lightllm.utils.envs_utils import get_env_start_args +from transformers.configuration_utils import PretrainedConfig +from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight +from lightllm.models.qwen2_vl.qwen2_visual import Qwen2VLTransformer + + +def build_visual_model(args, data_type: torch.dtype): + if args.disable_extra_process_for_multimodal: + kvargs = { + "weight_dir": args.model_dir, + "data_type": args.data_type, + "quant_type": args.vit_quant_type, + "quant_cfg": args.vit_quant_cfg, + "max_batch_size": args.visual_infer_batch_size, + } + model_cfg, _ = PretrainedConfig.get_config_dict(kvargs["weight_dir"]) + return Qwen2VLTransformer(kvargs=kvargs, **model_cfg["vision_config"]).eval().to(dtype=data_type) + return None + + +class Qwen2VLPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight): + def __init__(self, data_type, network_config, mode): + super().__init__(data_type, network_config, mode) + self.visual_model = build_visual_model(get_env_start_args(), data_type) + return diff --git a/lightllm/models/qwen2_vl/model.py b/lightllm/models/qwen2_vl/model.py index 70c8bf32e..afa6712a1 100644 --- a/lightllm/models/qwen2_vl/model.py +++ b/lightllm/models/qwen2_vl/model.py @@ -16,6 +16,8 @@ from lightllm.common.build_utils import repair_config from lightllm.models.registry import ModelRegistry from lightllm.models.qwen2_vl.infer_struct import Qwen2VLInferStateInfo +from lightllm.models.qwen2_vl.layer_weights.pre_and_post_layer_weight import Qwen2VLPreAndPostLayerWeight +from lightllm.models.qwen2_5_vl.layer_weights.pre_and_post_layer_weight import Qwen2_5VLPreAndPostLayerWeight from lightllm.models.qwen2_vl.layer_infer.transformer_layer_infer import Qwen2VLTransformerLayerInfer import torch @@ -31,6 +33,10 @@ class QWen2VLTokenizer(BaseMultiModalTokenizer): def __init__(self, tokenizer=None, image_processor=None, **kwargs): super().__init__(tokenizer) self.image_processor = image_processor + self.min_pixel = self.image_processor.image_processor.min_pixels + self.max_pixel = self.image_processor.image_processor.max_pixels + self.patch_size = self.image_processor.image_processor.patch_size + self.merge_size = self.image_processor.image_processor.merge_size self.image_start_id = kwargs["model_cfg"]["vision_start_token_id"] self.image_end_id = kwargs["model_cfg"]["vision_end_token_id"] self.image_token_id = kwargs["model_cfg"]["image_token_id"] @@ -46,17 +52,13 @@ def init_audioitem_extral_params( raise NotImplementedError def get_image_token_length(self, img: ImageItem): - width = img.image_w - height = img.image_h - resized_height, resized_width = 
smart_resize(height=height, width=width) - self.patch_size = self.image_processor.image_processor.patch_size - self.merge_size = self.image_processor.image_processor.merge_size - grid_t = 1 + width, height = img.image_w, img.image_h + resized_height, resized_width = smart_resize( + height=height, width=width, min_pixels=self.min_pixel, max_pixels=self.max_pixel + ) grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size - merge_length = self.merge_size ** 2 - self.token_num = (grid_t * grid_h * grid_w) // merge_length - self.image_length = self.token_num - return self.image_length + self.token_num = (grid_h * grid_w) // (self.merge_size ** 2) + return self.token_num def get_audio_token_length(self, audio: AudioItem): raise NotImplementedError @@ -93,12 +95,44 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs): return input_ids -@ModelRegistry(["qwen2_vl", "qwen2_5_vl"], is_multimodal=True) +@ModelRegistry(["qwen2_vl"], is_multimodal=True) class Qwen2VLTpPartModel(Qwen2TpPartModel): pre_layer_infer_class = LlamaMultimodalPreLayerInfer transformer_layer_infer_class = Qwen2VLTransformerLayerInfer + pre_and_post_weight_class = Qwen2VLPreAndPostLayerWeight + + infer_state_class = Qwen2VLInferStateInfo + + def __init__(self, kvargs): + super().__init__(kvargs) + return + + def _init_inferstate_cls(self): + if get_env_start_args().enable_fa3: + self.infer_state_class = Qwen2VLFlashAttentionStateInfo + + def _init_config(self): + with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: + self.config = json.load(json_file) + # rename keys + repair_config(self.config, same_names=["num_attention_heads", "n_head"]) + repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"]) + repair_config(self.config, same_names=["num_hidden_layers", "n_layer"]) + if self.finetune_config: + self.config["vocab_size"] = self.finetune_config.vocab_size + return + + +@ModelRegistry(["qwen2_5_vl"], is_multimodal=True) +class Qwen2_5VLTpPartModel(Qwen2TpPartModel): + + pre_layer_infer_class = LlamaMultimodalPreLayerInfer + transformer_layer_infer_class = Qwen2VLTransformerLayerInfer + + pre_and_post_weight_class = Qwen2_5VLPreAndPostLayerWeight + infer_state_class = Qwen2VLInferStateInfo def __init__(self, kvargs): diff --git a/lightllm/models/qwen2_vl/qwen2_visual.py b/lightllm/models/qwen2_vl/qwen2_visual.py index 4a9012518..970fe78b4 100644 --- a/lightllm/models/qwen2_vl/qwen2_visual.py +++ b/lightllm/models/qwen2_vl/qwen2_visual.py @@ -24,6 +24,7 @@ import torch import torch.nn.functional as F from PIL import Image +from einops import rearrange from typing import List, Union from torchvision import transforms as T from torchvision.transforms.functional import InterpolationMode @@ -36,59 +37,18 @@ import torch.nn as nn from torch.nn import LayerNorm from transformers.activations import ACT2FN -import math -from .vision_process import get_image -from transformers import AutoProcessor +import time from safetensors import safe_open from transformers.utils import TensorType +from lightllm.common.build_utils import repair_config from lightllm.server.multimodal_params import MultimodalParams, ImageItem -from lightllm.models.qwen2_vl.vision_process import Qwen2VLImageProcessor +from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor +from lightllm.models.qwen2_vl.triton_kernel.rotary_pos_emb import apply_rotary_pos_emb_triton from lightllm.models.vit.triton_kernel.flashattention_nopad import 
flash_attention_fwd from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager -from transformers.utils import is_flash_attn_2_available - -if is_flash_attn_2_available(): - from flash_attn import flash_attn_varlen_func - - from transformers.modeling_flash_attention_utils import _flash_attention_forward -else: - flash_attn_varlen_func = None - - -logger = logging.get_logger(__name__) - # adapted from # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py -class Qwen2VLVisionConfig(PretrainedConfig): - model_type = "qwen2_vl" - - def __init__( - self, - depth=32, - embed_dim=1280, - hidden_size=3584, - hidden_act="quick_gelu", - mlp_ratio=4, - num_heads=16, - in_channels=3, - patch_size=14, - spatial_merge_size=2, - temporal_patch_size=2, - **kwargs, - ): - super().__init__(**kwargs) - - self.depth = depth - self.embed_dim = embed_dim - self.hidden_size = hidden_size - self.hidden_act = hidden_act - self.mlp_ratio = mlp_ratio - self.num_heads = num_heads - self.in_channels = in_channels - self.patch_size = patch_size - self.spatial_merge_size = spatial_merge_size - self.temporal_patch_size = temporal_patch_size class PatchEmbed(nn.Module): @@ -109,11 +69,10 @@ def __init__( self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - target_dtype = self.proj.weight.dtype hidden_states = hidden_states.view( -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size - ).cuda() - hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) + ) + hidden_states = self.proj(hidden_states).view(-1, self.embed_dim) return hidden_states @@ -133,38 +92,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -# Copied from transformers.models.llama.modeling_llama.rotate_half -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: - orig_dtype = tensor.dtype - tensor = tensor.float() - cos = freqs.cos() - sin = freqs.sin() - cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() - sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() - output = (tensor * cos) + (rotate_half(tensor) * sin) - output = output.to(orig_dtype) - return output - - -class VisionRotaryEmbedding(nn.Module): - def __init__(self, dim: int, theta: float = 10000.0) -> None: - super().__init__() - inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = torch.outer(seq, self.inv_freq) - return freqs - - class VisionMlp(nn.Module): def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None: super().__init__() @@ -176,60 +103,70 @@ def forward(self, x) -> torch.Tensor: return self.fc2(self.act(self.fc1(x))) -class VisionAttention(nn.Module): - def __init__(self, dim: int, num_heads: int = 16) -> None: +# copy form vllm +class VisionRotaryEmbedding(nn.Module): + def __init__(self, dim: int, theta: float = 10000.0) -> None: super().__init__() - self.num_heads = num_heads - self.head_dim = dim // num_heads # 初始化 head_dim,每个头的维度 
- self.qkv = nn.Linear(dim, dim * 3, bias=True) - self.proj = nn.Linear(dim, dim) + self.dim = dim + self.theta = theta + self.inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self._seq_len_cached = 0 + self._freqs_cached = None + + def update_freqs_cache(self, seqlen: int) -> None: + if seqlen > self._seq_len_cached: + seqlen *= 2 + self._seq_len_cached = seqlen + self.inv_freq = 1.0 / ( + self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device) / self.dim) + ) + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + self._freqs_cached = freqs - def forward( - self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None - ) -> torch.Tensor: - seq_length = hidden_states.shape[0] - q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) - q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) - k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) - - attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True - - q = q.transpose(0, 1) - k = k.transpose(0, 1) - v = v.transpose(0, 1) - attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim) - attn_weights = attn_weights + attention_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) - attn_output = torch.matmul(attn_weights, v) - attn_output = attn_output.transpose(0, 1) - attn_output = attn_output.reshape(seq_length, -1) - attn_output = self.proj(attn_output) - return attn_output + def forward(self, seqlen: int) -> torch.Tensor: + self.update_freqs_cache(seqlen) + return self._freqs_cached[:seqlen] -# adapted from -# https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py class VisionFlashAttention(nn.Module): def __init__(self, dim: int, num_heads: int = 16) -> None: super().__init__() self.num_heads = num_heads self.qkv = nn.Linear(dim, dim * 3, bias=True) self.proj = nn.Linear(dim, dim) + self.has_vllm = False + try: + from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb + + self.has_vllm = True + self.apply_rotary_emb = apply_rotary_emb + except ImportError: + print("Failed to import _flash_attn_forward from hopper.flash_attn_interface.") + self.has_vllm = False + self.apply_rotary_emb = apply_rotary_pos_emb_triton + + def apply_rotary_pos_emb_vision(self, t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + t_ = t.float() + cos = freqs.cos() + sin = freqs.sin() + output = self.apply_rotary_emb(t_, cos, sin).type_as(t) + return output def forward( - self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int = 0, + rotary_pos_emb: torch.Tensor = None, ) -> torch.Tensor: seq_length = hidden_states.shape[0] q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) - q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb) - k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb) + q = self.apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb) + k = 
self.apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb) q = q.squeeze(0) k = k.squeeze(0) - cu_seqlens = cu_seqlens.to(q.device, torch.int32) - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() attn_output = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device) flash_attention_fwd(q, k, v, attn_output, cu_seqlens, max_seqlen) @@ -238,8 +175,6 @@ def forward( return attn_output -# adapted from -# https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py class Qwen2VLVisionBlock(nn.Module): def __init__(self, embed_dim, mlp_ratio, num_heads, hidden_act) -> None: super().__init__() @@ -250,19 +185,18 @@ def __init__(self, embed_dim, mlp_ratio, num_heads, hidden_act) -> None: self.attn = VisionFlashAttention(embed_dim, num_heads=num_heads) self.mlp = VisionMlp(dim=embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=hidden_act) - def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor: + def forward(self, hidden_states, cu_seqlens, max_seqlen, rotary_pos_emb) -> torch.Tensor: hidden_states = hidden_states + self.attn( - self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb + self.norm1(hidden_states), cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, rotary_pos_emb=rotary_pos_emb ) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) return hidden_states -# adapted from -# https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py -class Qwen2VisionTransformerPretrainedModel(nn.Module): +class Qwen2VLTransformer(nn.Module): def __init__( self, + kvargs, depth=32, embed_dim=1280, hidden_size=3584, @@ -276,6 +210,15 @@ def __init__( **kwargs, ): super().__init__() + + self.weight_dir_ = kvargs["weight_dir"] + self.data_type = kvargs.get("data_type", "bfloat16") + # self.mode = [m.replace("int4weight", "w4a16").replace("int8weight", "w8a16") for m in kvargs.get("mode", [])] + # self.weight_dict = kvargs.get("weight_dict", None) + # self.quant_type = kvargs.get("quant_type", None) + # self.quant_cfg_path = kvargs.get("quant_cfg", None) + # self.max_batch_size = kvargs.get("max_batch_size", 1) + self.depth = depth self.embed_dim = embed_dim self.hidden_size = hidden_size @@ -295,7 +238,7 @@ def __init__( ) head_dim = self.embed_dim // self.num_heads - self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) + self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2).to("cuda", non_blocking=True) self.blocks = nn.ModuleList( [ @@ -305,38 +248,59 @@ def __init__( ) self.merger = PatchMerger(dim=self.hidden_size, context_dim=self.embed_dim) - self.device = self.get_device() - self.dtype = self.get_dtype() + processor_config_path = os.path.join(kvargs["weight_dir"], "preprocessor_config.json") + with open(processor_config_path, "r") as f: + processor_config_dict = json.load(f) + self.processor = Qwen2VLImageProcessor(**processor_config_dict) - def get_dtype(self) -> torch.dtype: - return self.blocks[0].mlp.fc2.weight.dtype + self._init_datatype() + self.load_model(kvargs["weight_dir"]) + self.cuda() + + def _init_datatype(self): + if isinstance(self.data_type, torch.dtype): + return + if self.data_type in ["fp16", "float16"]: + self.data_type = torch.float16 + elif self.data_type in ["bf16", "bfloat16"]: + self.data_type = torch.bfloat16 + elif self.data_type in ["fp32", "float32"]: + self.data_type = torch.float32 + else: + raise ValueError(f"Unsupport datatype {self.data_type}!") + return - def get_device(self) -> 
torch.device: - return self.blocks[0].mlp.fc2.weight.device + def load_model(self, weight_dir): + bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")] + if bin_weight_files: + weight_dict = {} + for file_ in bin_weight_files: + f = torch.load(os.path.join(weight_dir, file_), "cpu") + for k, v in f.items(): + if "visual" in k: + weight_dict[k[len("visual.") :]] = v + else: + hf_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".safetensors")] + weight_dict = {} + for file_ in hf_weight_files: + f = safe_open(os.path.join(weight_dir, file_), "pt", "cpu") + for k in f.keys(): + if "visual" in k: + weight_dict[k[len("visual.") :]] = f.get_tensor(k) + + self.load_state_dict(weight_dict) def rot_pos_emb(self, grid_thw): pos_ids = [] - for t, h, w in grid_thw: + s = self.spatial_merge_size + for _, h, w in grid_thw: + pos_shape = (h // s, s, w // s, s) hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - hpos_ids = hpos_ids.permute(0, 2, 1, 3) - hpos_ids = hpos_ids.flatten() - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - wpos_ids = wpos_ids.permute(0, 2, 1, 3) - wpos_ids = wpos_ids.flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + hpos_ids = hpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten() + wpos_ids = wpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten() + + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1)) pos_ids = torch.cat(pos_ids, dim=0) max_grid_size = grid_thw[:, 1:].max() rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size).type(torch.float32) @@ -344,52 +308,36 @@ def rot_pos_emb(self, grid_thw): return rotary_pos_emb def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: - hidden_states = hidden_states.to( - dtype=self.get_dtype(), - device=self.device, - ) - grid_thw = grid_thw.to( - dtype=torch.int32, - device=self.device, - ) - hidden_states = self.patch_embed(hidden_states) - rotary_pos_emb = self.rot_pos_emb(grid_thw) + rotary_pos_emb = self.rot_pos_emb(grid_thw).to("cuda", non_blocking=True) cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( dim=0, dtype=torch.int32 ) cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + cu_seqlens = cu_seqlens.to("cuda", non_blocking=True) for blk in self.blocks: - hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb) + hidden_states = blk( + hidden_states, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, rotary_pos_emb=rotary_pos_emb + ) return self.merger(hidden_states) - def load_model(self, weight_dir): - - processor_config_path = os.path.join(weight_dir, "preprocessor_config.json") - with open(processor_config_path, "r") as f: - processor_config_dict = json.load(f) - self.processor = Qwen2VLImageProcessor(**processor_config_dict) - - bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")] - if bin_weight_files: - weight_dict = {} - for file_ in bin_weight_files: - f = torch.load(os.path.join(weight_dir, file_), "cpu") - for k, v in f.items(): - if "visual" in k: - weight_dict[k[len("visual.") :]] = v - + def 
load_image(self, img: List[ImageItem]): + pixel_values = None + if isinstance(img, ImageItem): + image_data = read_shm(get_shm_name_data(img.uuid)) + image_data = Image.open(BytesIO(image_data)) + image_data = resize_image(image_data) + pixel_values, image_grid_thw = self.processor.preprocess(image_data) + elif isinstance(img, dict): + image_data = read_shm(get_shm_name_data(img["uuid"])) + image_data = Image.open(BytesIO(image_data)) + image_data = resize_image(image_data) + pixel_values, image_grid_thw = self.processor.preprocess(image_data) else: - hf_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".safetensors")] - weight_dict = {} - for file_ in hf_weight_files: - f = safe_open(os.path.join(weight_dir, file_), "pt", "cpu") - for k in f.keys(): - if "visual" in k: - weight_dict[k[len("visual.") :]] = f.get_tensor(k) - - self.load_state_dict(weight_dict) + raise Exception("Unsupported input type: {} for {}".format(type(img), img)) + return pixel_values.to(dtype=self.data_type), image_grid_thw def encode(self, images: List[ImageItem]): img_tensors = [] @@ -397,16 +345,14 @@ def encode(self, images: List[ImageItem]): valid_id = 0 img_grids = [] uuids = [] for i, img in enumerate(images): if isinstance(img, ImageItem): uuids.append(img.uuid) image_data = read_shm(get_shm_name_data(img.uuid)) image_data = Image.open(BytesIO(image_data)) - image_data = get_image(image_data) - image_inputs = self.processor.preprocess(images=image_data, return_tensors="pt") - pixel_values = image_inputs["pixel_values"].to(dtype=torch.bfloat16) - image_grid_thw = image_inputs["image_grid_thw"] + image_data = resize_image(image_data) + pixel_values, image_grid_thw = self.processor.preprocess(image_data) img_tensors.append(pixel_values) img_grids.append(image_grid_thw) else: @@ -424,10 +370,9 @@ imgs = torch.cat(img_tensors, dim=0) grid_thw = torch.cat(img_grids, dim=0) - pixel_values = imgs.cuda().to(dtype=torch.float32) - image_grid_thw = grid_thw.cuda() + pixel_values = imgs.to("cuda", dtype=self.data_type, non_blocking=True) + image_grid_thw = grid_thw.to("cuda", non_blocking=True) - pixel_values = pixel_values.type(self.get_dtype()) all_img_embeds = self.forward(pixel_values, grid_thw=image_grid_thw) return all_img_embeds, uuids, valid_ids
diff --git a/lightllm/models/qwen2_vl/triton_kernel/rotary_pos_emb.py b/lightllm/models/qwen2_vl/triton_kernel/rotary_pos_emb.py new file mode 100644 index 000000000..499211620 --- /dev/null +++ b/lightllm/models/qwen2_vl/triton_kernel/rotary_pos_emb.py @@ -0,0 +1,160 @@ +import math +import torch +import triton +import triton.language as tl + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb_vision_ref( + tensor: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, +) -> torch.Tensor: + orig_dtype = tensor.dtype + tensor = tensor.float() + cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + output = (tensor * cos) + (rotate_half(tensor) * sin) + output = output.to(orig_dtype) + return output + + +@triton.jit +def rotary_kernel( + inp_ptr, + cos_ptr, + sin_ptr, + out_ptr, + stride_b, + stride_l, + stride_h, + stride_cos_l, + stride_sin_l, + H: tl.constexpr, + D: tl.constexpr, + HALF_D: tl.constexpr, + BLOCK_D: tl.constexpr, +): + pid_bh = 
tl.program_id(0) + pid_l = tl.program_id(1) + pid_blk = tl.program_id(2) + + b = pid_bh // H + h = pid_bh - b * H + + offs_d = tl.arange(0, BLOCK_D) + d = pid_blk * BLOCK_D + offs_d + mask = d < D + + # compute the base address in 64-bit to keep b * stride_b from overflowing + base = ( + tl.full([], b, tl.int64) * tl.full([], stride_b, tl.int64) + + tl.full([], pid_l, tl.int64) * tl.full([], stride_l, tl.int64) + + tl.full([], h, tl.int64) * tl.full([], stride_h, tl.int64) + ) + + in_ptr = inp_ptr + base + d + cos_ptr_ = cos_ptr + tl.full([], pid_l, tl.int64) * tl.full([], stride_cos_l, tl.int64) + d + sin_ptr_ = sin_ptr + tl.full([], pid_l, tl.int64) * tl.full([], stride_sin_l, tl.int64) + d + + x = tl.load(in_ptr, mask=mask) + cos = tl.load(cos_ptr_, mask=mask) + sin = tl.load(sin_ptr_, mask=mask) + + partner_d = tl.where(d < HALF_D, d + HALF_D, d - HALF_D) + partner_ptr = inp_ptr + base + partner_d + partner_val = tl.load(partner_ptr, mask=mask) + rotated = tl.where(d < HALF_D, -partner_val, partner_val) + + y = x * cos + rotated * sin + + out_ptr_ = out_ptr + base + d + tl.store(out_ptr_, y, mask=mask) + + +def apply_rotary_pos_emb_triton( + tensor: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, BLOCK_D: int = 128 +) -> torch.Tensor: + assert tensor.is_cuda and cos.is_cuda and sin.is_cuda + if tensor.ndim != 4: + raise RuntimeError("tensor shape should be [B, L, H, D]") + orig_dtype = tensor.dtype + x = tensor.float() + + cos = cos.unsqueeze(1).repeat(1, 1, 2).view(cos.size(0), -1).contiguous().float() + sin = sin.unsqueeze(1).repeat(1, 1, 2).view(sin.size(0), -1).contiguous().float() + + B, L, H, D = x.shape + HALF_D = D // 2 + y = torch.empty_like(x) + + stride_b, stride_l, stride_h, _ = x.stride() + stride_cos_l, stride_sin_l = cos.stride(0), sin.stride(0) + + grid = (B * H, L, math.ceil(D / BLOCK_D)) + + rotary_kernel[grid]( + x, + cos, + sin, + y, + stride_b, + stride_l, + stride_h, + stride_cos_l, + stride_sin_l, + H, + D, + HALF_D, + BLOCK_D=BLOCK_D, + ) + + return y.to(orig_dtype) + + +def test_accuracy_and_speed( + B: int = 16, + L: int = 1296, + H: int = 64, + D: int = 80, + warmup: int = 10, + repeats: int = 50, +): + torch.manual_seed(0) + freqs = torch.randn(L, D // 2, device="cuda") + x = torch.randn(B, L, H, D, device="cuda") + + # accuracy check + y_ref = apply_rotary_pos_emb_vision_ref(x, freqs.cos(), freqs.sin()) + y_tri = apply_rotary_pos_emb_triton(x, freqs.cos(), freqs.sin()) + print("max abs error:", (y_ref - y_tri).abs().max().item()) + + # warmup + for _ in range(warmup): + apply_rotary_pos_emb_vision_ref(x, freqs.cos(), freqs.sin()) + apply_rotary_pos_emb_triton(x, freqs.cos(), freqs.sin()) + torch.cuda.synchronize() + + # timing + def bench(fn): + start = torch.cuda.Event(True) + end = torch.cuda.Event(True) + start.record() + for _ in range(repeats): + fn(x, freqs.cos(), freqs.sin()) + end.record() + torch.cuda.synchronize() + return start.elapsed_time(end) / repeats # ms + + print(f"PyTorch : {bench(apply_rotary_pos_emb_vision_ref):.3f} ms") + print(f"Triton : {bench(apply_rotary_pos_emb_triton):.3f} ms") + + +if __name__ == "__main__": + test_accuracy_and_speed()
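Note: the reworked vision forward() methods above now compute cu_seqlens and max_seqlen once per call and pass both into every block, instead of re-deriving max_seqlen inside each attention layer. A small, self-contained sketch of that bookkeeping follows; the toy grid_thw values are invented for the example.

import torch
import torch.nn.functional as F

# two images, (t, h, w) in patch units
grid_thw = torch.tensor([[1, 4, 6], [1, 2, 8]])
# one sequence of length h * w per temporal frame
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)               # tensor([ 0, 24, 40], dtype=torch.int32)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()  # 24
print(cu_seqlens, max_seqlen)

diff --git a/lightllm/models/qwen2_vl/vision_process.py b/lightllm/models/qwen2_vl/vision_process.py index 9366ca747..ce0f7c771 100644 --- a/lightllm/models/qwen2_vl/vision_process.py +++ b/lightllm/models/qwen2_vl/vision_process.py @@ -1,38 +1,11 @@ -# coding=utf-8 -# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. 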
-# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - from __future__ import annotations - -import base64 -from io import BytesIO import math -from typing import Dict, List, Optional, Union -import numpy as np -import requests import torch +import numpy as np from PIL import Image -from torchvision import io, transforms -from torchvision.transforms import InterpolationMode -from transformers import AutoProcessor +from typing import List, Optional, Union -from transformers.image_processing_utils import BaseImageProcessor, BatchFeature +from transformers.image_processing_utils import BaseImageProcessor from transformers.image_transforms import ( convert_to_rgb, resize, @@ -42,190 +15,48 @@ OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, ChannelDimension, - ImageInput, PILImageResampling, get_image_size, infer_channel_dimension_format, - is_scaled_image, - is_valid_image, - make_list_of_images, to_numpy_array, - valid_images, - validate_preprocess_arguments, ) -from transformers.video_utils import VideoInput -from transformers.utils import TensorType, is_vision_available, logging - -logger = logging.get_logger(__name__) IMAGE_FACTOR = 28 MIN_PIXELS = 4 * 28 * 28 MAX_PIXELS = 16384 * 28 * 28 MAX_RATIO = 200 - -VIDEO_MIN_PIXELS = 128 * 28 * 28 -VIDEO_MAX_PIXELS = 768 * 28 * 28 -VIDEO_TOTAL_PIXELS = 24576 * 28 * 28 FRAME_FACTOR = 2 FPS = 2.0 FPS_MIN_FRAMES = 4 FPS_MAX_FRAMES = 768 -def make_batched_images(images) -> List[List[ImageInput]]: - """ - Accepts images in list or nested list format, and makes a list of images for preprocessing. - - Args: - images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`): - The input image. - - Returns: - list: A list of images. 
- """ - if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]): - return [img for img_list in images for img in img_list] - - elif isinstance(images, (list, tuple)) and is_valid_image(images[0]): - return images - - elif is_valid_image(images): - return [images] - - raise ValueError(f"Could not make batched images from {images}") - - -# Copied from transformers.models.llava_next_video.image_processing_llava_next_video.make_batched_videos -def make_batched_videos(videos) -> List[VideoInput]: - if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]): - return videos - - elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]): - if isinstance(videos[0], Image.Image): - return [videos] - elif len(videos[0].shape) == 4: - return [list(video) for video in videos] - - elif is_valid_image(videos) and len(videos.shape) == 4: - return [list(videos)] - - raise ValueError(f"Could not make batched video from {videos}") - - -def round_by_factor(number: int, factor: int) -> int: - """Returns the closest integer to 'number' that is divisible by 'factor'.""" - return round(number / factor) * factor - - -def ceil_by_factor(number: int, factor: int) -> int: - """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" - return math.ceil(number / factor) * factor - - -def floor_by_factor(number: int, factor: int) -> int: - """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" - return math.floor(number / factor) * factor - - def smart_resize( height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS ) -> tuple[int, int]: - """ - Rescales the image so that the following conditions are met: - - 1. Both dimensions (height and width) are divisible by 'factor'. - - 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. - 3. The aspect ratio of the image is maintained as closely as possible. 
- """ if max(height, width) / min(height, width) > MAX_RATIO: raise ValueError( f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}" ) - h_bar = max(factor, round_by_factor(height, factor)) - w_bar = max(factor, round_by_factor(width, factor)) + h_bar = max(factor, round(height / factor) * factor) + w_bar = max(factor, round(width / factor) * factor) if h_bar * w_bar > max_pixels: beta = math.sqrt((height * width) / max_pixels) - h_bar = floor_by_factor(height / beta, factor) - w_bar = floor_by_factor(width / beta, factor) + h_bar = math.floor(height / beta / factor) * factor + w_bar = math.floor(width / beta / factor) * factor elif h_bar * w_bar < min_pixels: beta = math.sqrt(min_pixels / (height * width)) - h_bar = ceil_by_factor(height * beta, factor) - w_bar = ceil_by_factor(width * beta, factor) + h_bar = math.ceil(height * beta / factor) * factor + w_bar = math.ceil(width * beta / factor) * factor return h_bar, w_bar -def fetch_image(ele: dict[str, str | Image.Image], size_factor: int = IMAGE_FACTOR) -> Image.Image: - if "image" in ele: - image = ele["image"] - else: - image = ele["image_url"] - image_obj = None - if isinstance(image, Image.Image): - image_obj = image - elif image.startswith("http://") or image.startswith("https://"): - image_obj = Image.open(requests.get(image, stream=True).raw) - elif image.startswith("file://"): - image_obj = Image.open(image[7:]) - elif image.startswith("data:image"): - data = image.split(";", 1)[1] - if data.startswith("base64,"): - data = base64.b64decode(data[7:]) - image_obj = Image.open(BytesIO(data)) - else: - image_obj = Image.open(image) - if image_obj is None: - raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}") - image = image_obj.convert("RGB") - ## resize - if "resized_height" in ele and "resized_width" in ele: - resized_height, resized_width = smart_resize( - ele["resized_height"], - ele["resized_width"], - factor=size_factor, - ) - else: - width, height = image.size - min_pixels = ele.get("min_pixels", MIN_PIXELS) - max_pixels = ele.get("max_pixels", MAX_PIXELS) - resized_height, resized_width = smart_resize( - height, - width, - factor=size_factor, - min_pixels=min_pixels, - max_pixels=max_pixels, - ) - image = image.resize((resized_width, resized_height)) - return image - - -def get_image(image_file: str | Image.Image, size_factor: int = IMAGE_FACTOR) -> tuple[Image.Image, int, int]: - image_obj = None - - if isinstance(image_file, Image.Image): - image_obj = image_file - elif image_file.startswith("http://") or image_file.startswith("https://"): - image_obj = Image.open(requests.get(image_file, stream=True).raw) - elif image_file.startswith("file://"): - image_obj = Image.open(image_file[7:]) - elif image_file.startswith("data:image"): - data = image_file.split(";", 1)[1] - if data.startswith("base64,"): - data = base64.b64decode(data[7:]) - image_obj = Image.open(BytesIO(data)) - else: - image_obj = Image.open(image_file) - - if image_obj is None: - raise ValueError("Unrecognized image input. 
Supports local path, http url, base64, and PIL.Image.") +def resize_image(image_file: str | Image.Image, size_factor: int = IMAGE_FACTOR) -> tuple[Image.Image, int, int]: - image = image_obj.convert("RGB") - - # 获取原始宽度和高度 + image = image_file.convert("RGB") width, height = image.size - # 使用默认的最小像素和最大像素调整大小 resized_height, resized_width = smart_resize( height, width, @@ -233,56 +64,12 @@ def get_image(image_file: str | Image.Image, size_factor: int = IMAGE_FACTOR) -> min_pixels=MIN_PIXELS, max_pixels=MAX_PIXELS, ) - - # 调整图片大小 image = image.resize((resized_width, resized_height)) return image -def extract_vision_info(conversations: list[dict] | list[list[dict]]) -> list[dict]: - vision_infos = [] - if isinstance(conversations[0], dict): - conversations = [conversations] - for conversation in conversations: - for message in conversation: - if isinstance(message["content"], list): - for ele in message["content"]: - if ( - "image" in ele - or "image_url" in ele - or "video" in ele - or ele["type"] in ("image", "image_url", "video") - ): - vision_infos.append(ele) - return vision_infos - - -def process_vision_info( - conversations: list[dict] | list[list[dict]], -) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None]: - vision_infos = extract_vision_info(conversations) - ## Read images or videos - image_inputs = [] - # video_inputs = [] - for vision_info in vision_infos: - if "image" in vision_info or "image_url" in vision_info: - image_inputs.append(fetch_image(vision_info)) - # elif "video" in vision_info: - # video_inputs.append(fetch_video(vision_info)) - else: - raise ValueError("image, image_url or video should in content.") - if len(image_inputs) == 0: - image_inputs = None - return image_inputs - - -# adapted from -# transformers/blob/main/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py class Qwen2VLImageProcessor(BaseImageProcessor): - - model_input_names = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"] - def __init__( self, do_resize: bool = True, @@ -306,6 +93,7 @@ def __init__( self.do_rescale = do_rescale self.rescale_factor = rescale_factor self.do_normalize = do_normalize + self.do_convert_rgb = do_convert_rgb self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD self.min_pixels = min_pixels @@ -313,76 +101,46 @@ def __init__( self.patch_size = patch_size self.temporal_patch_size = temporal_patch_size self.merge_size = merge_size - self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels} - self.do_convert_rgb = do_convert_rgb - - def _preprocess( - self, - images: Union[ImageInput, VideoInput], - do_resize: bool = None, - resample: PILImageResampling = None, - do_rescale: bool = None, - rescale_factor: float = None, - do_normalize: bool = None, - image_mean: Optional[Union[float, List[float]]] = None, - image_std: Optional[Union[float, List[float]]] = None, - do_convert_rgb: bool = None, - data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, - input_data_format: Optional[Union[str, ChannelDimension]] = None, - ): - - images = make_list_of_images(images) + self.data_format = ChannelDimension.FIRST - if do_convert_rgb: - images = [convert_to_rgb(image) for image in images] + def preprocess(self, image): + if self.do_convert_rgb: + image = convert_to_rgb(image) + image = to_numpy_array(image) + input_data_format = infer_channel_dimension_format(image) + height, width = get_image_size(image, 
channel_dim=input_data_format) - # All transformations expect numpy arrays. - images = [to_numpy_array(image) for image in images] - - if is_scaled_image(images[0]) and do_rescale: - logger.warning_once( - "It looks like you are trying to rescale already rescaled images. If the input" - " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + resized_height, resized_width = height, width + if self.do_resize: + resized_height, resized_width = smart_resize( + height, + width, + factor=self.patch_size * self.merge_size, + min_pixels=self.min_pixels, + max_pixels=self.max_pixels, + ) + image = resize( + image, size=(resized_height, resized_width), resample=self.resample, input_data_format=input_data_format ) - if input_data_format is None: - # We assume that all images have the same channel dimension format. - input_data_format = infer_channel_dimension_format(images[0]) - height, width = get_image_size(images[0], channel_dim=input_data_format) - resized_height, resized_width = height, width - processed_images = [] - for image in images: - if do_resize: - resized_height, resized_width = smart_resize( - height, - width, - factor=self.patch_size * self.merge_size, - min_pixels=self.min_pixels, - max_pixels=self.max_pixels, - ) - image = resize( - image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format - ) + if self.do_rescale: + image = self.rescale(image, scale=self.rescale_factor, input_data_format=input_data_format) - if do_rescale: - image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format) + if self.do_normalize: + image = self.normalize( + image=image, mean=self.image_mean, std=self.image_std, input_data_format=input_data_format + ) - if do_normalize: - image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + image = to_channel_dimension_format(image, self.data_format, input_channel_dim=input_data_format) - image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) - processed_images.append(image) + patches = np.array([image]) - patches = np.array(processed_images) - if data_format == ChannelDimension.LAST: - patches = patches.transpose(0, 3, 1, 2) if patches.shape[0] == 1: patches = np.tile(patches, (self.temporal_patch_size, 1, 1, 1)) channel = patches.shape[1] - grid_t = patches.shape[0] // self.temporal_patch_size grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size patches = patches.reshape( - grid_t, + 1, self.temporal_patch_size, channel, grid_h // self.merge_size, @@ -394,102 +152,10 @@ def _preprocess( ) patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8) flatten_patches = patches.reshape( - grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size - ) - - return flatten_patches, (grid_t, grid_h, grid_w) - - def preprocess( - self, - images: ImageInput, - videos: VideoInput = None, - do_resize: bool = None, - size: Dict[str, int] = None, - resample: PILImageResampling = None, - do_rescale: bool = None, - rescale_factor: float = None, - do_normalize: bool = None, - image_mean: Optional[Union[float, List[float]]] = None, - image_std: Optional[Union[float, List[float]]] = None, - do_convert_rgb: bool = None, - return_tensors: Optional[Union[str, TensorType]] = "pt", - data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, - input_data_format: Optional[Union[str, ChannelDimension]] = None, - ): - - 
do_resize = do_resize if do_resize is not None else self.do_resize - size = size if size is not None else self.size - resample = resample if resample is not None else self.resample - do_rescale = do_rescale if do_rescale is not None else self.do_rescale - rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor - do_normalize = do_normalize if do_normalize is not None else self.do_normalize - image_mean = image_mean if image_mean is not None else self.image_mean - image_std = image_std if image_std is not None else self.image_std - do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb - - if images is not None: - images = make_batched_images(images) - if videos is not None: - videos = make_batched_videos(videos) - - if images is not None and not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) - - validate_preprocess_arguments( - rescale_factor=rescale_factor, - do_normalize=do_normalize, - image_mean=image_mean, - image_std=image_std, - do_resize=do_resize, - size=size, - resample=resample, + grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size ) + image_grid_thw = (1, grid_h, grid_w) + pixel_values = torch.as_tensor(flatten_patches) + grid_thw = torch.as_tensor([image_grid_thw]) - if images is not None: - pixel_values, vision_grid_thws = [], [] - for image in images: - patches, image_grid_thw = self._preprocess( - image, - do_resize=do_resize, - resample=resample, - do_rescale=do_rescale, - rescale_factor=rescale_factor, - do_normalize=do_normalize, - image_mean=image_mean, - image_std=image_std, - data_format=data_format, - do_convert_rgb=do_convert_rgb, - input_data_format=input_data_format, - ) - pixel_values.extend(patches) - vision_grid_thws.append(image_grid_thw) - pixel_values = np.array(pixel_values) - vision_grid_thws = np.array(vision_grid_thws) - data = {"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws} - - if videos is not None: - pixel_values, vision_grid_thws = [], [] - for images in videos: - patches, video_grid_thw = self._preprocess( - images, - do_resize=do_resize, - resample=resample, - do_rescale=do_rescale, - rescale_factor=rescale_factor, - do_normalize=do_normalize, - image_mean=image_mean, - image_std=image_std, - data_format=data_format, - do_convert_rgb=do_convert_rgb, - input_data_format=input_data_format, - ) - pixel_values.extend(patches) - vision_grid_thws.append(video_grid_thw) - pixel_values = np.array(pixel_values) - vision_grid_thws = np.array(vision_grid_thws) - data = {"pixel_values_videos": pixel_values, "video_grid_thw": vision_grid_thws} - - return BatchFeature(data=data, tensor_type=return_tensors) + return pixel_values, grid_thw diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py index b5b31a413..e6d0dba2d 100644 --- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py @@ -1,3 +1,4 @@ +import rpyc import torch import torch.distributed as dist @@ -6,10 +7,10 @@ from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer from lightllm.utils.infer_utils import mark_cost_time -from lightllm.server.embed_cache.utils import bytes2tensor, read_shm, get_shm_name_embed +from lightllm.utils.envs_utils import get_env_start_args, get_cache_port from 
lightllm.common.basemodel.triton_kernel.multimodal_emb import multimodal_emb from lightllm.distributed.communication_op import all_reduce - +from lightllm.server.embed_cache.utils import bytes2tensor, tensor2bytes, read_shm, create_shm, get_shm_name_embed """ infer_state.multimodal_params: batch list of MultimodalParams-dict like: @@ -31,47 +32,57 @@ def __init__(self, network_config, mode): super().__init__(network_config, mode) return - def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): - - img_weight = [] - img_start_token_ids = [] - img_token_lens = [] - img_start_loc = 0 - img_start_locs = [] + def _infer_image_embeds(self, infer_state, layer_weight): + image_weight = [] + if layer_weight.visual_model is None: + for batch_id, p in enumerate(infer_state.multimodal_params): + for img in p["images"] + p["audios"]: + # skip the same image + if img.get("_prefill_", True): + # pull the img_embeds by uid from shm + image_embed = read_shm(get_shm_name_embed(img["uuid"])) + image_weight.append(bytes2tensor(image_embed).cuda().reshape(img["token_num"], -1)) + else: + for batch_id, p in enumerate(infer_state.multimodal_params): + for img in p["images"] + p["audios"]: + if img.get("_prefill_", True): + image_data = img["image_data"].to("cuda", non_blocking=True) + image_grid_thw = img["image_grid_thw"] + # image_embed = torch.zeros( + # (img["token_num"],layer_weight.wte_weight_.shape[1]),device="cuda",dtype=torch.bfloat16) + image_embed = layer_weight.visual_model.forward(image_data, image_grid_thw).view( + img["token_num"], -1 + ) + image_weight.append(image_embed) + if len(image_weight) > 0: + image_weight = torch.cat(image_weight, dim=0) + image_weight = image_weight / self.tp_world_size_ + assert image_weight.shape[1] == layer_weight.wte_weight_.shape[1], ( + f"Dimension mismatch: text weight dimension is {layer_weight.wte_weight_.shape[1]}, " + f"but image weight dimension is {image_weight.shape[1]}" + ) + else: + hidden_size = layer_weight.wte_weight_.shape[1] + image_weight = torch.empty((0, hidden_size), device="cpu", dtype=layer_weight.wte_weight_.dtype).to( + "cuda", non_blocking=True + ) + return image_weight - device = layer_weight.wte_weight_.device + def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): dtype = layer_weight.wte_weight_.dtype hidden_size = layer_weight.wte_weight_.shape[1] - infer_state.mark_multimodal_objs_for_prefill(input_ids=input_ids) + img_weight = self._infer_image_embeds(infer_state, layer_weight) - for batch_id, p in enumerate(infer_state.multimodal_params): - for img in p["images"] + p["audios"]: - # skip the same image - if img["token_id"] in img_start_token_ids or img["_prefill_"] is False: - continue - # pull the img_embeds by uid from shm - data = read_shm(get_shm_name_embed(img["uuid"])) - img_weight.append(bytes2tensor(data).cuda().reshape(img["token_num"], -1)) - img_start_token_ids.append(img["token_id"]) - img_token_lens.append(img["token_num"]) - img_start_locs.append(img_start_loc) - img_start_loc += img["token_num"] - out = torch.zeros((len(input_ids), hidden_size), dtype=dtype, device=device) - if len(img_weight) > 0: - img_weight = torch.cat(img_weight, dim=0).to(device=device, dtype=dtype) + out = torch.zeros((len(input_ids), hidden_size), dtype=dtype, device="cpu").to("cuda", non_blocking=True) + if infer_state.image_start_token_ids is not None: + img_start_token_ids = infer_state.image_start_token_ids.to("cuda", 
non_blocking=True) + img_token_lens = infer_state.image_token_lens.to("cuda", non_blocking=True) + img_start_locs = infer_state.image_start_locs.to("cuda", non_blocking=True) else: - img_weight = torch.empty((0, hidden_size), device=device, dtype=dtype) - assert img_weight.shape[1] == hidden_size, ( - f"Dimension mismatch: text weight dimension is {hidden_size}, " - f"but image weight dimension is {img_weight.shape[1]}" - ) - # each tp will fill the img embeds, should divide by world_size - img_weight = img_weight / self.tp_world_size_ - img_start_token_ids = torch.Tensor(img_start_token_ids).to(device=device, dtype=torch.long) - img_token_lens = torch.Tensor(img_token_lens).to(device=device, dtype=torch.long) - img_start_locs = torch.Tensor(img_start_locs).to(device=device, dtype=torch.long) - + img_start_token_ids = torch.empty((0,), device="cpu", dtype=torch.long).to("cuda", non_blocking=True) + img_token_lens = torch.empty((0,), device="cpu", dtype=torch.long).to("cuda", non_blocking=True) + img_start_locs = torch.empty((0,), device="cpu", dtype=torch.long).to("cuda", non_blocking=True) multimodal_emb( out, input_ids, diff --git a/lightllm/models/qwen_vl/qwen_visual.py b/lightllm/models/qwen_vl/qwen_visual.py index f6468b144..b826b52d4 100644 --- a/lightllm/models/qwen_vl/qwen_visual.py +++ b/lightllm/models/qwen_vl/qwen_visual.py @@ -333,6 +333,7 @@ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): class QWenVisionTransformer(nn.Module): def __init__( self, + kvargs, image_size: int, patch_size: int, width: int, @@ -344,6 +345,7 @@ def __init__( **kwargs, ): super().__init__() + self.weight_dir = kvargs["weight_dir"] image_height, image_width = self.image_size = (image_size, image_size) patch_height, patch_width = self.patch_size = (patch_size, patch_size) self.grid_size = (image_height // patch_height, image_width // patch_width) @@ -388,6 +390,9 @@ def __init__( self.ln_post = norm_layer(output_dim) self.proj = nn.Parameter((output_dim ** -0.5) * torch.randn(output_dim, output_dim)) + self.load_model(self.weight_dir) + self.cuda() + def forward(self, x: torch.Tensor): x = x.to( dtype=self.transformer.get_cast_dtype(), diff --git a/lightllm/models/tarsier2/tarsier2_visual.py b/lightllm/models/tarsier2/tarsier2_visual.py index 978b9bd17..3da66423e 100644 --- a/lightllm/models/tarsier2/tarsier2_visual.py +++ b/lightllm/models/tarsier2/tarsier2_visual.py @@ -13,10 +13,10 @@ from transformers import AutoModel from safetensors import safe_open -from lightllm.models.qwen2_vl.qwen2_visual import Qwen2VisionTransformerPretrainedModel +from lightllm.models.qwen2_vl.qwen2_visual import Qwen2VLTransformer from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data from lightllm.server.multimodal_params import ImageItem -from lightllm.models.qwen2_vl.vision_process import Qwen2VLImageProcessor, get_image +from lightllm.models.qwen2_vl.vision_process import Qwen2VLImageProcessor, resize_image def add_split_tokens(image_features, image_newline_embed, image_new_embed): @@ -152,6 +152,7 @@ def forward(self, image_features, input_embeddings): class TarsierVisionTransformerPretrainedModel(nn.Module): def __init__( self, + kvargs, vision_config=None, text_config=None, ignore_index=-100, @@ -165,7 +166,8 @@ def __init__( **kwargs, ): super().__init__() - self.vision_tower = Qwen2VisionTransformerPretrainedModel(**vision_config) + self.weight_dir = kvargs["weight_dir"] + self.vision_tower = Qwen2VLTransformer(**vision_config) if projection_head == 
"Pixel_Shuffle": self.multi_modal_projector = PixelShuffleMultiModalProjector( @@ -195,6 +197,9 @@ def __init__( self.image_token_index = image_token_index self.merge_size = 1 + self.load_model(self.weight_dir) + self.cuda() + def forward( self, pixel_values: torch.Tensor = None, @@ -253,7 +258,7 @@ def encode(self, images: List[ImageItem]): uuids.append(img.uuid) image_data = read_shm(get_shm_name_data(img.uuid)) image_data = Image.open(BytesIO(image_data)) - image_data = get_image(image_data) + image_data = resize_image(image_data) image_inputs = self.processor.preprocess(images=image_data, return_tensors="pt") pixel_values = image_inputs["pixel_values"].to(dtype=torch.bfloat16) image_grid_thw = image_inputs["image_grid_thw"] diff --git a/lightllm/models/vit/layer_weights/transformer_layer_weight.py b/lightllm/models/vit/layer_weights/transformer_layer_weight.py index 55d73fa73..3c42f712e 100644 --- a/lightllm/models/vit/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/vit/layer_weights/transformer_layer_weight.py @@ -11,7 +11,9 @@ MultiROWMMWeight, TpNormWeight, ) -from lightllm.utils.dist_utils import get_current_device_id +from lightllm.utils.dist_utils import ( + get_current_device_id, +) class ViTTransformerLayerWeight(TransformerLayerWeight): diff --git a/lightllm/models/vit/model.py b/lightllm/models/vit/model.py index 01bb69bdf..7aa95eae8 100644 --- a/lightllm/models/vit/model.py +++ b/lightllm/models/vit/model.py @@ -18,7 +18,8 @@ from io import BytesIO from rpyc.utils.classic import obtain from lightllm.common.quantization import Quantcfg -from lightllm.utils.dist_utils import get_dp_world_size +from lightllm.utils.dist_utils import get_dp_world_size, get_global_world_size +from lightllm.utils.envs_utils import get_env_start_args from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager @@ -57,6 +58,20 @@ def __init__(self, kvargs): self._check_max_len_infer() return + def load_image(self, img: List[ImageItem]): + pixel_values = None + if isinstance(img, ImageItem): + image_data = read_shm(get_shm_name_data(img.uuid)) + image_data = Image.open(BytesIO(image_data)) + pixel_values = self.load_image_func(image_data, max_num=img.extra_params["image_patch_max_num"]) + elif isinstance(img, dict): + image_data = read_shm(get_shm_name_data(img["uuid"])) + image_data = Image.open(BytesIO(image_data)) + pixel_values = self.load_image_func(image_data, max_num=img["extra_params"]["image_patch_max_num"]) + else: + raise Exception("Unsupport input types: {} for {}".format(type(img), img)) + return pixel_values.to(dtype=self.data_type), None + @final @torch.no_grad() def _check_max_len_infer(self): @@ -68,8 +83,9 @@ def _check_max_len_infer(self): dummy_images = torch.randn( (self.MAX_PATH_NUM * self.max_batch_size, 3, self.IMAGE_H, self.IMAGE_W), dtype=self.data_type ).cuda() - all_img_embeds = self.forward(dummy_images) + all_img_embeds = self.forward(dummy_images, image_grid_thw=None) del all_img_embeds + del dummy_images logger.info(f"vit check max_len {self.max_batch_size} infer ok") except (RuntimeError, torch.OutOfMemoryError) as e: logger.exception(str(e)) @@ -150,6 +166,8 @@ def _init_infer_layer(self): return def _init_datatype(self): + if isinstance(self.data_type, torch.dtype): + return if self.data_type in ["fp16", "float16"]: self.data_type = torch.float16 elif self.data_type in ["bf16", "bfloat16"]: @@ -160,13 +178,11 @@ def _init_datatype(self): raise ValueError(f"Unsupport datatype {self.data_type}!") @torch.no_grad() - def 
forward(self, pixel_values): - g_cache_manager.cache_env_in() + def forward(self, pixel_values, image_grid_thw): input_embs = self.pre_infer.forward(pixel_values, self.pre_post_weight) for i in range(self.layers_num + self.select_layer + 1): input_embs = self.layers_infer[i].forward(input_embs, self.trans_layers_weight[i]) input_embs = self.post_infer.forward(input_embs[:, 1:, :], self.pre_post_weight) - g_cache_manager.cache_env_out() return input_embs @torch.no_grad() @@ -182,6 +198,12 @@ def encode(self, images: List[ImageItem]): image_data = Image.open(BytesIO(image_data)) t = self.load_image_func(image_data, max_num=img.extra_params["image_patch_max_num"]) img_tensors.append(t) + elif isinstance(img, dict): + uuids.append(img["uuid"]) + image_data = read_shm(get_shm_name_data(img["uuid"])) + image_data = Image.open(BytesIO(image_data)) + t = self.load_image_func(image_data, max_num=img["extra_params"]["image_patch_max_num"]) + img_tensors.append(t) else: raise Exception("Unsupport input types: {} for {}".format(type(img), img)) diff --git a/lightllm/models/vit/triton_kernel/flashattention_nopad.py b/lightllm/models/vit/triton_kernel/flashattention_nopad.py index ab3770a36..fcb6fbb5d 100644 --- a/lightllm/models/vit/triton_kernel/flashattention_nopad.py +++ b/lightllm/models/vit/triton_kernel/flashattention_nopad.py @@ -215,7 +215,7 @@ def flash_attention_fwd(q, k, v, o, cu_seqlens, max_seqlen): 统一的 Flash Attention 接口。如果 sgl_kernel 存在, 则使用 sgl_kernel里的接口,否则使用 Triton 版本。 """ - if _flash_attn_v3_available and is_hopper() and False: + if _flash_attn_v3_available and is_hopper(): flash_attention_v3_fwd(q, k, v, o, cu_seqlens, max_seqlen) else: _flash_attention_triton_fwd(q, k, v, o, cu_seqlens, max_seqlen) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index dfbf0c84b..b784a4040 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -239,6 +239,11 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--enable_multimodal", action="store_true", help="Whether or not to allow to load additional visual models." 
) + parser.add_argument( + "--disable_extra_process_for_multimodal", + action="store_true", + help="Whether or not to disable extra process for multimodal.", + ) parser.add_argument( "--enable_multimodal_audio", action="store_true", diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index c2a87b4c3..5f0d5ece2 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -9,7 +9,7 @@ from .metrics.manager import start_metric_manager from .embed_cache.manager import start_cache_manager from lightllm.utils.log_utils import init_logger -from lightllm.utils.envs_utils import set_env_start_args, set_unique_server_name, get_unique_server_name +from lightllm.utils.envs_utils import set_env_start_args, set_unique_server_name, get_unique_server_name, set_cache_port from lightllm.utils.envs_utils import get_lightllm_gunicorn_time_out_seconds, get_lightllm_gunicorn_keep_alive from .detokenization.manager import start_detokenization_process from .router.manager import start_router_process @@ -242,6 +242,7 @@ def normal_or_p_d_start(args): args.cache_port = cache_port args.metric_port = metric_port + set_cache_port(args.cache_port) # 申请在 p d 分离模式下,会用的端口 args.pd_node_infer_rpyc_ports = can_use_ports[0:node_world_size] # p d 分离模式下用于标识节点的id @@ -269,7 +270,7 @@ def normal_or_p_d_start(args): ], start_args=[(cache_port, args)], ) - if args.enable_multimodal_audio: + if args.enable_multimodal_audio and not args.disable_extra_process_for_multimodal: from .audioserver.manager import start_audio_process process_manager.start_submodule_processes( @@ -289,7 +290,7 @@ def normal_or_p_d_start(args): ], ) - else: + elif not args.disable_extra_process_for_multimodal: process_manager.start_submodule_processes( start_funcs=[ start_visual_process, diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index de552d80c..b943be3a3 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -81,11 +81,12 @@ def __init__( ) self.enable_multimodal = enable_multimodal + self.disable_extra_process_for_multimodal = args.disable_extra_process_for_multimodal if self.enable_multimodal: self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True}) - self.send_to_visual = context.socket(zmq.PUSH) - self.send_to_visual.connect(f"{args.zmq_mode}127.0.0.1:{visual_port}") - + if not self.disable_extra_process_for_multimodal: + self.send_to_visual = context.socket(zmq.PUSH) + self.send_to_visual.connect(f"{args.zmq_mode}127.0.0.1:{visual_port}") self.shm_req_manager = ShmReqManager() self.recv_from_detokenization = context.socket(zmq.SUB) @@ -438,7 +439,7 @@ async def transfer_to_next_module( ): if self.pd_mode == NodeRole.P: - if self.enable_multimodal: + if self.enable_multimodal and not self.disable_extra_process_for_multimodal: self.send_to_visual.send_pyobj( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, @@ -459,7 +460,7 @@ async def transfer_to_next_module( return if self.pd_mode == NodeRole.NORMAL: - if self.enable_multimodal: + if self.enable_multimodal and not self.disable_extra_process_for_multimodal: self.send_to_visual.send_pyobj( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, diff --git a/lightllm/server/multimodal_params.py b/lightllm/server/multimodal_params.py index e3c1d19d2..5af5d5d95 100644 --- a/lightllm/server/multimodal_params.py +++ b/lightllm/server/multimodal_params.py @@ -120,6 +120,7 @@ def to_dict(self): ret["uuid"] = 
self.uuid ret["token_id"] = self.token_id ret["token_num"] = self.token_num + ret["extra_params"] = self.extra_params return ret def to_origin_dict(self): diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index fd75afdbf..3a6572839 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -459,6 +459,36 @@ def _get_classed_reqs( return prefill_reqs, decode_reqs + def _preprocess_image(self, batch: ModelInput): + # 如果不是多模态模型,或者单进程推理,直接跳过 + args = get_env_start_args() + if not args.enable_multimodal: + return + # assert self.model.visual_model is not None, "visual_model is not initialized" + image_start_locs = [] + image_token_lens = [] + image_start_token_ids = [] + image_start_loc = 0 + for i, p in enumerate(batch.multimodal_params): + for img in p["images"]: + # 重复图片 + if img["token_id"] in image_start_token_ids: + continue + image_start_locs.append(image_start_loc) + image_token_lens.append(img["token_num"]) + image_start_token_ids.append(img["token_id"]) + image_start_loc += img["token_num"] + if not args.disable_extra_process_for_multimodal: + continue + # 预拉取已经存在的image data + image_data, image_grid_thw = self.model.pre_post_weight.visual_model.load_image(img) + img["image_data"] = image_data + img["image_grid_thw"] = image_grid_thw + batch.image_start_locs = torch.tensor(image_start_locs, device="cpu", dtype=torch.long) + batch.image_token_lens = torch.tensor(image_token_lens, device="cpu", dtype=torch.long) + batch.image_start_token_ids = torch.tensor(image_start_token_ids, device="cpu", dtype=torch.long) + return batch + def _pre_handle_finished_reqs(self, finished_reqs: List[InferReq]): """ 给 PD 分离模式下,prefill node 使用的继承钩子函数,用于发起 kv 传输任务。 diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 39d345ff5..4972194b2 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -96,6 +96,7 @@ def prefill_normal( model_input, run_reqs = prepare_prefill_inputs( prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill, is_multimodal=self.is_multimodal ) + self._preprocess_image(model_input) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) logits = model_output.logits diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index d2d45f2fd..d27a31207 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -14,8 +14,8 @@ from lightllm.models.gemma3.gemma3_visual import Gemma3VisionModel from lightllm.models.vit.model import VisionTransformer from lightllm.server.multimodal_params import MultimodalParams, ImageItem -from lightllm.models.qwen2_vl.qwen2_visual import Qwen2VisionTransformerPretrainedModel -from lightllm.models.qwen2_5_vl.qwen2_5_visual import Qwen2_5_VisionTransformerPretrainedModel +from lightllm.models.qwen2_vl.qwen2_visual import Qwen2VLTransformer +from lightllm.models.qwen2_5_vl.qwen2_5_visual import Qwen2_5VLTransformer from lightllm.models.tarsier2.tarsier2_visual import TarsierVisionTransformerPretrainedModel from lightllm.server.embed_cache.utils import tensor2bytes, 
read_shm, create_shm, get_shm_name_data, get_shm_name_embed from lightllm.utils.infer_utils import set_random_seed @@ -46,33 +46,31 @@ def exposed_init_model(self, kvargs): try: self.model_type = model_cfg["model_type"] + kvargs = { + "weight_dir": weight_dir, + "data_type": self.data_type, + "quant_type": kvargs["quant_type"], + "quant_cfg": kvargs["quant_cfg"], + "max_batch_size": kvargs["max_batch_size"], + } if self.model_type == "qwen": - self.model = QWenVisionTransformer(**model_cfg["visual"]).eval().bfloat16() + self.model = QWenVisionTransformer(kvargs, **model_cfg["visual"]).eval().bfloat16() elif self.model_type == "qwen2_vl": - self.model = Qwen2VisionTransformerPretrainedModel(**model_cfg["vision_config"]).eval().bfloat16() + self.model = Qwen2VLTransformer(kvargs, **model_cfg["vision_config"]).eval().bfloat16() elif self.model_type == "qwen2_5_vl": - self.model = Qwen2_5_VisionTransformerPretrainedModel(**model_cfg["vision_config"]).eval().bfloat16() + self.model = Qwen2_5VLTransformer(kvargs, **model_cfg["vision_config"]).eval().bfloat16() elif model_cfg["architectures"][0] == "TarsierForConditionalGeneration": - self.model = TarsierVisionTransformerPretrainedModel(**model_cfg).eval().bfloat16() + self.model = TarsierVisionTransformerPretrainedModel(kvargs, **model_cfg).eval().bfloat16() elif self.model_type == "llava": - self.model = LlavaVisionModel() + self.model = LlavaVisionModel(kvargs) elif self.model_type == "internvl_chat": - kvargs = { - "weight_dir": weight_dir, - "data_type": self.data_type, - "quant_type": kvargs["quant_type"], - "quant_cfg": kvargs["quant_cfg"], - "max_batch_size": kvargs["max_batch_size"], - } self.model = VisionTransformer(kvargs) # self.model = InternVLVisionModel() elif self.model_type == "gemma3": - self.model = Gemma3VisionModel() + self.model = Gemma3VisionModel(kvargs) else: raise Exception(f"can not support {self.model_type} now") - self.model.load_model(weight_dir) - self.model = self.model.cuda() except Exception as e: print("#" * 16) print("load model error:", str(e), e, type(e)) diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index b78784d82..14b75628f 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -17,6 +17,16 @@ def set_unique_server_name(args): return +def set_cache_port(cache_port): + os.environ["LIGHTLLM_CACHE_PORT"] = str(cache_port) + cache_port = os.environ.get("LIGHTLLM_CACHE_PORT") + + +def get_cache_port(): + _cache_port = os.environ.get("LIGHTLLM_CACHE_PORT") + return int(_cache_port) + + @lru_cache(maxsize=None) def get_unique_server_name(): service_uni_name = os.getenv("LIGHTLLM_UNIQUE_SERVICE_NAME_ID") diff --git a/unit_tests/models/llama/llama_gqa_decode_vsm.py b/unit_tests/models/llama/llama_gqa_decode_vsm.py new file mode 100644 index 000000000..f124a28eb --- /dev/null +++ 
b/unit_tests/models/llama/llama_gqa_decode_vsm.py @@ -0,0 +1,104 @@ +import unittest +import random +import torch +from tqdm import tqdm +from lightllm.common.basemodel.infer_struct import InferStateInfo +from lightllm.common.req_manager import ReqManager +from lightllm.models.llama.triton_kernel.gqa_flash_decoding_vsm import ( + gqa_token_decode_attention_flash_decoding_vsm, +) +from lightllm.models.llama.triton_kernel.gqa_flash_decoding import ( + gqa_token_decode_attention_flash_decoding, +) + + +class TestVSMGQADecoding(unittest.TestCase): + def test_vsm_gqa_decoding_align(self): + random.seed(0) + torch.manual_seed(0) + torch.cuda.manual_seed(0) + torch.cuda.manual_seed_all(0) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + bs_list = [1, 8, 16, 32, 64, 128, 256] + group_size_list = [16, 32, 64] + seq_len_list = [128, 512, 1024, 2048, 4096, 8192] + q_head_dim_list = [64, 128] + q_head_num_list = [8, 16, 32] + + def get_test_configs(): + for bs in bs_list: + for group_size in group_size_list: + for seq_len_m in seq_len_list: + for q_head_dim in q_head_dim_list: + for q_head_num in q_head_num_list: + if q_head_num < group_size: + continue + yield bs, group_size, seq_len_m, q_head_dim, q_head_num + + for bs, group_size, seq_len_m, q_head_dim, q_head_num in tqdm(list(get_test_configs())): + kv_head_num = q_head_num // group_size + q_head_dim = q_head_dim + kv_head_dim = q_head_dim + seq_len = (torch.zeros(bs, dtype=torch.int32) + seq_len_m).to(torch.int32) + total_token_in_the_batch = seq_len.sum().item() + rounded_total_token_in_the_batch = (total_token_in_the_batch + 128 - 1) // 128 * 128 + + q_shape = [bs, q_head_num, q_head_dim] + kv_shape = [ + rounded_total_token_in_the_batch, + kv_head_num, + kv_head_dim, + ] + qkv_dtype = torch.float16 + + q, k, v = ( + torch.randn(q_shape, dtype=qkv_dtype, device="cuda"), + torch.randn(kv_shape, dtype=qkv_dtype, device="cuda"), + torch.randn(kv_shape, dtype=qkv_dtype, device="cuda"), + ) + q, k, v = q / 10, k / 10, v / 10 + + req_to_token_index = torch.zeros((bs, seq_len_m)) - 1 + token_index = torch.arange(rounded_total_token_in_the_batch) + + total_count = 0 + for i in range(bs): + req_to_token_index[i, : seq_len[i]] = token_index[total_count : total_count + seq_len[i]] + total_count += seq_len[i] + + req_to_token_index = req_to_token_index.long().cuda() + + b_req_idx = torch.arange(bs, device="cuda") + infer_state = InferStateInfo() + infer_state.req_manager = ReqManager(bs, 2048, None) + infer_state.req_manager.req_to_token_indexs = req_to_token_index + infer_state.b_req_idx = b_req_idx.cuda() + infer_state.b_seq_len = seq_len.cuda() + infer_state.max_len_in_batch = seq_len_m + infer_state.batch_size = bs + infer_state.q_head_num = q_head_num + infer_state.q_head_dim = q_head_dim + infer_state.kv_head_num = kv_head_num + infer_state.softmax_scale = 1 / (q_head_dim ** 0.5) + infer_state.total_token_num = torch.tensor([total_token_in_the_batch], dtype=torch.int32).cuda() + new_out = gqa_token_decode_attention_flash_decoding_vsm(q, k, v, infer_state) + old_out = gqa_token_decode_attention_flash_decoding( + q, + infer_state, + infer_state.q_head_num, + infer_state.q_head_dim, + k, + v, + ) + cos_sim = torch.nn.functional.cosine_similarity(new_out, old_out, dim=-1).mean().cpu().item() + self.assertGreaterEqual( + cos_sim, + 0.9, + f"bs={bs},group_size={group_size},seq_len={seq_len_m},q_head_dim={q_head_dim},q_head_num={q_head_num}", + ) + + +if __name__ == "__main__": + unittest.main()
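Reviewer note on the new ModelInput fields (image_start_locs, image_token_lens, image_start_token_ids): they describe, per prefill batch, where each image's placeholder tokens sit in the flattened prompt and which rows of the concatenated image embedding they should pull from. The sketch below is a rough pure-PyTorch rendering of the indexing contract these tensors imply for the fused multimodal_emb kernel called in context_forward; the function name and the exact mapping (placeholder token ids assumed consecutive from each image's token_id) are assumptions made for illustration, not code from this patch.

    import torch

    def scatter_multimodal_embeds(out, input_ids, wte_weight, img_weight,
                                  img_token_lens, img_start_token_ids, img_start_locs):
        # out: (num_tokens, hidden), zero-initialized.
        # img_weight: image embeddings concatenated in the order given by img_start_locs.
        text_mask = torch.ones_like(input_ids, dtype=torch.bool)
        for length, start_id, start_loc in zip(img_token_lens.tolist(),
                                               img_start_token_ids.tolist(),
                                               img_start_locs.tolist()):
            # placeholder ids start_id .. start_id+length-1 take rows
            # start_loc .. start_loc+length-1 of img_weight
            in_img = (input_ids >= start_id) & (input_ids < start_id + length)
            out[in_img] = img_weight[start_loc + (input_ids[in_img] - start_id)]
            text_mask &= ~in_img
        # every remaining position is an ordinary text-embedding lookup
        out[text_mask] = wte_weight[input_ids[text_mask]]
        return out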
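Worked example of how _preprocess_image (base_backend.py) fills those fields, with illustrative values only: for a batch whose first request carries two distinct images with token_num 4 and 6 and token_id 10008 and 10012, the metadata comes out as image_start_locs = [0, 4], image_token_lens = [4, 6], image_start_token_ids = [10008, 10012]; a repeated token_id is skipped so a reused image contributes its embedding only once.

    import torch

    # Illustrative stand-in for the accumulation loop in _preprocess_image.
    images = [{"token_id": 10008, "token_num": 4}, {"token_id": 10012, "token_num": 6}]
    starts, lens, ids, cursor = [], [], [], 0
    for img in images:
        if img["token_id"] in ids:  # duplicate images are skipped
            continue
        starts.append(cursor)
        lens.append(img["token_num"])
        ids.append(img["token_id"])
        cursor += img["token_num"]
    image_start_locs = torch.tensor(starts, dtype=torch.long)       # tensor([0, 4])
    image_token_lens = torch.tensor(lens, dtype=torch.long)         # tensor([4, 6])
    image_start_token_ids = torch.tensor(ids, dtype=torch.long)     # tensor([10008, 10012])

With --disable_extra_process_for_multimodal set alongside --enable_multimodal, the visual weights are built inside the pre/post layer weights and no separate visual server process is started, so a launch along the lines of python -m lightllm.server.api_server --model_dir <path> --enable_multimodal --disable_extra_process_for_multimodal keeps the ViT on the same GPUs as the language model; when the flag is omitted the existing extra-process flow is unchanged.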