From d498aafb4844585770977ddc60136c4d5bedd495 Mon Sep 17 00:00:00 2001 From: baishihao Date: Tue, 1 Jul 2025 19:21:13 +0800 Subject: [PATCH 01/18] vit and llm inference at a single process --- lightllm/common/image_cache_manager.py | 74 +++++++++++++++++++ .../pre_and_post_layer_weight.py | 65 ++++++++++++++++ .../qwen_vl/layer_infer/pre_layer_infer.py | 26 ++++++- .../pre_and_post_layer_weight.py | 11 ++- .../layer_weights/transformer_layer_weight.py | 10 ++- lightllm/models/vit/model.py | 23 +++++- lightllm/server/api_cli.py | 5 ++ lightllm/server/api_start.py | 4 +- lightllm/server/httpserver/manager.py | 10 ++- lightllm/server/multimodal_params.py | 1 + 10 files changed, 215 insertions(+), 14 deletions(-) create mode 100644 lightllm/common/image_cache_manager.py diff --git a/lightllm/common/image_cache_manager.py b/lightllm/common/image_cache_manager.py new file mode 100644 index 000000000..47086d818 --- /dev/null +++ b/lightllm/common/image_cache_manager.py @@ -0,0 +1,74 @@ +from collections import OrderedDict +from lightllm.utils.dist_utils import get_current_device_id + + +class ImageCacheManager: + def __init__(self): + """ + Initialize the image cache manager with a simple GPU cache and an LRU CPU cache. + """ + self._gpu_cache = dict() + self._cpu_cache = OrderedDict() + + def set_max_size(self, max_size: int): + """ + Set the maximum number of items to keep in the CPU cache. + :param max_size: Maximum number of items to keep in the CPU cache. + """ + if max_size <= 0: + raise ValueError("max_size must be greater than 0") + self._max_size = max_size + + def set_embed(self, uuid, embed): + """ + Store the embedding for the given uuid in the GPU cache. + :param uuid: Unique identifier for the image + :param embed: Embedding vector for the image (on GPU) + """ + self._gpu_cache[uuid] = embed + + def get_embed(self, uuid): + """ + Retrieve the embedding for the given uuid. Prefer GPU cache, + otherwise return CPU cache and move to GPU (simulate .cuda()). + :param uuid: Unique identifier for the image + :return: Embedding vector (on GPU if possible, else move from CPU to GPU) + """ + if uuid in self._gpu_cache: + return self._gpu_cache[uuid] + elif uuid in self._cpu_cache: + self._cpu_cache.move_to_end(uuid) + embed = self._cpu_cache[uuid].cuda(get_current_device_id()) + return embed + return None + + def query_embed(self, uuid): + """ + Query if the embedding for the given uuid is in the cache. + :param uuid: Unique identifier for the image + :return: True if the embedding is in the cache, False otherwise + """ + return uuid in self._gpu_cache or uuid in self._cpu_cache + + def filter(self, uuid_list): + """ + Given a list of uuids, move their embeddings from GPU cache to CPU cache if present, + and return a dict of those found in the cache and their embeddings (on CPU). 
+ :param uuid_list: List of uuids + """ + for uuid in uuid_list: + if uuid in self._gpu_cache: + embed_cpu = self._gpu_cache[uuid].cpu(non_blocking=True) + # Move to CPU cache and remove from GPU cache + self._gpu_cache.pop(uuid) + if uuid in self._cpu_cache: + self._cpu_cache.move_to_end(uuid) + self._cpu_cache[uuid] = embed_cpu + if len(self._cpu_cache) > self._max_size: + self._cpu_cache.popitem(last=False) + elif uuid in self._cpu_cache: + self._cpu_cache.move_to_end(uuid) + return + + +image_cache_manager = ImageCacheManager() diff --git a/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py index f19563932..486319495 100644 --- a/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py @@ -3,6 +3,9 @@ from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight from lightllm.models.internlm2.layer_weights.pre_and_post_layer_weight import Internlm2PreAndPostLayerWeight +from lightllm.models.vit.model import VisionTransformer +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.common.image_cache_manager import image_cache_manager # add key: language_model.xxx -> xxx @@ -15,9 +18,45 @@ def rename_weight_keys(weights): weights[k[len(prefix) :]] = weights[k] +class InternVLPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): + def __init__(self, data_type, network_config, mode): + super().__init__(data_type, network_config, mode) + # if we don't assign an extra process for visual model, we need initialize the image cache manager here + if get_env_start_args().disable_extra_process_for_multimodal: + kvargs = { + "weight_dir": get_env_start_args().model_dir, + "data_type": self.data_type_, + "quant_type": get_env_start_args().vit_quant_type, + "quant_cfg": get_env_start_args().vit_quant_cfg, + "max_batch_size": get_env_start_args().visual_infer_batch_size, + } + self.visual_model = VisionTransformer( + kvargs=kvargs, + ) + image_cache_manager.set_max_size(get_env_start_args().cache_capacity * 2) + return + + def load_hf_weights(self, weights): + rename_weight_keys(weights) + super().load_hf_weights(weights) + + class InternVLPhi3PreAndPostLayerWeight(LlamaPreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) + # if we don't assign an extra process for visual model, we need initialize the image cache manager here + if get_env_start_args().disable_extra_process_for_multimodal: + kvargs = { + "weight_dir": get_env_start_args().model_dir, + "data_type": self.data_type_, + "quant_type": get_env_start_args().vit_quant_type, + "quant_cfg": get_env_start_args().vit_quant_cfg, + "max_batch_size": get_env_start_args().visual_infer_batch_size, + } + self.visual_model = VisionTransformer( + kvargs=kvargs, + ) + image_cache_manager.set_max_size(get_env_start_args().cache_capacity * 2) return def load_hf_weights(self, weights): @@ -29,6 +68,19 @@ def load_hf_weights(self, weights): class InternVLInternlm2PreAndPostLayerWeight(Internlm2PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) + # if we don't assign an extra process for visual model, we need initialize the image cache manager here + if get_env_start_args().disable_extra_process_for_multimodal: + kvargs = { + "weight_dir": get_env_start_args().model_dir, + "data_type": 
self.data_type_, + "quant_type": get_env_start_args().vit_quant_type, + "quant_cfg": get_env_start_args().vit_quant_cfg, + "max_batch_size": get_env_start_args().visual_infer_batch_size, + } + self.visual_model = VisionTransformer( + kvargs=kvargs, + ) + image_cache_manager.set_max_size(get_env_start_args().cache_capacity * 2) return def load_hf_weights(self, weights): @@ -40,6 +92,19 @@ def load_hf_weights(self, weights): class InternVLLlamaPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) + # if we don't assign an extra process for visual model, we need initialize the image cache manager here + if get_env_start_args().disable_extra_process_for_multimodal: + kvargs = { + "weight_dir": get_env_start_args().model_dir, + "data_type": self.data_type_, + "quant_type": get_env_start_args().vit_quant_type, + "quant_cfg": get_env_start_args().vit_quant_cfg, + "max_batch_size": get_env_start_args().visual_infer_batch_size, + } + self.visual_model = VisionTransformer( + kvargs=kvargs, + ) + image_cache_manager.set_max_size(get_env_start_args().cache_capacity * 2) return def load_hf_weights(self, weights): diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py index 60c9e0564..92ad282d9 100644 --- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py @@ -6,7 +6,9 @@ from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer from lightllm.utils.infer_utils import mark_cost_time +from lightllm.utils.envs_utils import get_env_start_args from lightllm.server.embed_cache.utils import bytes2tensor, read_shm, get_shm_name_embed +from lightllm.common.image_cache_manager import image_cache_manager from lightllm.common.basemodel.triton_kernel.multimodal_emb import multimodal_emb from lightllm.distributed.communication_op import all_reduce @@ -29,8 +31,22 @@ class LlamaMultimodalPreLayerInfer(LlamaPreLayerInfer): def __init__(self, network_config, mode): super().__init__(network_config, mode) + self.disable_extra_process_for_multimodal = get_env_start_args().disable_extra_process_for_multimodal return + def _infer_image_embeds(self, infer_state, layer_weight): + if not self.disable_extra_process_for_multimodal: + return + infer_images = [] + for _, p in enumerate(infer_state.multimodal_params): + for img in p["images"] + p["audios"]: + if not image_cache_manager.query_embed(img["uuid"]): + infer_images.append(img) + if len(infer_images) > 0: + img_embeds, uuids, valid_ids = layer_weight.visual_model.encode(infer_images) + for uuid, valid_id in zip(uuids, valid_ids): + image_cache_manager.set_embed(uuid, img_embeds[valid_id[0] : valid_id[1]]) + def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): img_weight = [] @@ -42,14 +58,20 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei device = layer_weight.wte_weight_.device dtype = layer_weight.wte_weight_.dtype hidden_size = layer_weight.wte_weight_.shape[1] + self._infer_image_embeds(infer_state, layer_weight) for batch_id, p in enumerate(infer_state.multimodal_params): for img in p["images"] + p["audios"]: # skip the same image if img["token_id"] in img_start_token_ids: continue # pull the img_embeds by uid from shm - data = read_shm(get_shm_name_embed(img["uuid"])) - 
img_weight.append(bytes2tensor(data).cuda().reshape(img["token_num"], -1)) + if self.disable_extra_process_for_multimodal: + img_embed = image_cache_manager.get_embed(img["uuid"]) + img_weight.append(img_embed.reshape(img["token_num"], -1)) + print(img_weight[-1].shape) + else: + data = read_shm(get_shm_name_embed(img["uuid"])) + img_weight.append(bytes2tensor(data).cuda().reshape(img["token_num"], -1)) img_start_token_ids.append(img["token_id"]) img_token_lens.append(img["token_num"]) img_start_locs.append(img_start_loc) diff --git a/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py index 276d4e5d0..2f0e725c8 100644 --- a/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py @@ -3,7 +3,12 @@ import numpy as np import torch.nn.functional as F from lightllm.common.basemodel import PreAndPostLayerWeight -from lightllm.utils.dist_utils import get_current_device_id +from lightllm.utils.dist_utils import ( + get_current_device_id, + get_global_rank, + get_global_world_size, +) +from lightllm.utils.envs_utils import get_env_start_args class ViTPreAndPostLayerWeight(PreAndPostLayerWeight): @@ -13,6 +18,10 @@ def __init__(self, data_type, network_config, mode): self.image_size = self.network_config_["image_size"] self.patch_size = self.network_config_["patch_size"] self.llm_hidden_size = self.network_config_["llm_hidden_size"] + if get_env_start_args().disable_extra_process_for_multimodal: + self.tp_world_size_ = get_global_world_size() + self.tp_rank_ = get_global_rank() + return def _cuda(self, cpu_tensor): diff --git a/lightllm/models/vit/layer_weights/transformer_layer_weight.py b/lightllm/models/vit/layer_weights/transformer_layer_weight.py index 55d73fa73..4b62bff36 100644 --- a/lightllm/models/vit/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/vit/layer_weights/transformer_layer_weight.py @@ -11,12 +11,20 @@ MultiROWMMWeight, TpNormWeight, ) -from lightllm.utils.dist_utils import get_current_device_id +from lightllm.utils.dist_utils import ( + get_current_device_id, + get_global_rank, + get_global_world_size, +) +from lightllm.utils.envs_utils import get_env_start_args class ViTTransformerLayerWeight(TransformerLayerWeight): def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + if get_env_start_args().disable_extra_process_for_multimodal: + self.tp_world_size_ = get_global_world_size() + self.tp_rank_ = get_global_rank() return def _cuda(self, cpu_tensor): diff --git a/lightllm/models/vit/model.py b/lightllm/models/vit/model.py index 01bb69bdf..bf74cc1a4 100644 --- a/lightllm/models/vit/model.py +++ b/lightllm/models/vit/model.py @@ -18,7 +18,8 @@ from io import BytesIO from rpyc.utils.classic import obtain from lightllm.common.quantization import Quantcfg -from lightllm.utils.dist_utils import get_dp_world_size +from lightllm.utils.dist_utils import get_dp_world_size, get_global_world_size +from lightllm.utils.envs_utils import get_env_start_args from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager @@ -37,7 +38,11 @@ class VisionTransformer: post_layer_infer_class = ViTPostLayerInfer def __init__(self, kvargs): - self.tp_world_size_ = get_dp_world_size() + if get_env_start_args().disable_extra_process_for_multimodal: + # if we don't assign an extra process for visual model, the visual 
model uses tensor parallel by default. + self.tp_world_size_ = get_global_world_size() + else: + self.tp_world_size_ = get_dp_world_size() self.weight_dir_ = kvargs["weight_dir"] self.load_way = kvargs.get("load_way", "HF") self.mode = [m.replace("int4weight", "w4a16").replace("int8weight", "w8a16") for m in kvargs.get("mode", [])] @@ -150,6 +155,8 @@ def _init_infer_layer(self): return def _init_datatype(self): + if isinstance(self.data_type, torch.dtype): + return if self.data_type in ["fp16", "float16"]: self.data_type = torch.float16 elif self.data_type in ["bf16", "bfloat16"]: @@ -161,12 +168,14 @@ def _init_datatype(self): @torch.no_grad() def forward(self, pixel_values): - g_cache_manager.cache_env_in() + if not get_env_start_args().disable_extra_process_for_multimodal: + g_cache_manager.cache_env_in() input_embs = self.pre_infer.forward(pixel_values, self.pre_post_weight) for i in range(self.layers_num + self.select_layer + 1): input_embs = self.layers_infer[i].forward(input_embs, self.trans_layers_weight[i]) input_embs = self.post_infer.forward(input_embs[:, 1:, :], self.pre_post_weight) - g_cache_manager.cache_env_out() + if not get_env_start_args().disable_extra_process_for_multimodal: + g_cache_manager.cache_env_out() return input_embs @torch.no_grad() @@ -182,6 +191,12 @@ def encode(self, images: List[ImageItem]): image_data = Image.open(BytesIO(image_data)) t = self.load_image_func(image_data, max_num=img.extra_params["image_patch_max_num"]) img_tensors.append(t) + elif isinstance(img, dict): + uuids.append(img["uuid"]) + image_data = read_shm(get_shm_name_data(img["uuid"])) + image_data = Image.open(BytesIO(image_data)) + t = self.load_image_func(image_data, max_num=img["extra_params"]["image_patch_max_num"]) + img_tensors.append(t) else: raise Exception("Unsupport input types: {} for {}".format(type(img), img)) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 3f3eaf96f..3cb62d842 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -233,6 +233,11 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--enable_multimodal", action="store_true", help="Whether or not to allow to load additional visual models." 
) + parser.add_argument( + "--disable_extra_process_for_multimodal", + action="store_true", + help="Whether or not to disable extra process for multimodal.", + ) parser.add_argument( "--enable_multimodal_audio", action="store_true", diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index de1e690a2..1f9107fa7 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -243,7 +243,7 @@ def normal_or_p_d_start(args): ], start_args=[(cache_port, args)], ) - if args.enable_multimodal_audio: + if args.enable_multimodal_audio and not args.disable_extra_process_for_multimodal: from .audioserver.manager import start_audio_process process_manager.start_submodule_processes( @@ -263,7 +263,7 @@ def normal_or_p_d_start(args): ], ) - else: + elif not args.disable_extra_process_for_multimodal: process_manager.start_submodule_processes( start_funcs=[ start_visual_process, diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index fa455c225..967c716dd 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -81,10 +81,12 @@ def __init__( ) self.enable_multimodal = enable_multimodal + self.disable_extra_process_for_multimodal = args.disable_extra_process_for_multimodal if self.enable_multimodal: self.cache_client = rpyc.connect("localhost", cache_port) - self.send_to_visual = context.socket(zmq.PUSH) - self.send_to_visual.connect(f"{args.zmq_mode}127.0.0.1:{visual_port}") + if not self.disable_extra_process_for_multimodal: + self.send_to_visual = context.socket(zmq.PUSH) + self.send_to_visual.connect(f"{args.zmq_mode}127.0.0.1:{visual_port}") self.shm_req_manager = ShmReqManager() @@ -449,7 +451,7 @@ async def transfer_to_next_module( ): if self.pd_mode == NodeRole.P: - if self.enable_multimodal: + if self.enable_multimodal and not self.disable_extra_process_for_multimodal: self.send_to_visual.send_pyobj( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, @@ -470,7 +472,7 @@ async def transfer_to_next_module( return if self.pd_mode == NodeRole.NORMAL: - if self.enable_multimodal: + if self.enable_multimodal and not self.disable_extra_process_for_multimodal: self.send_to_visual.send_pyobj( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, diff --git a/lightllm/server/multimodal_params.py b/lightllm/server/multimodal_params.py index bf320e199..97d456355 100644 --- a/lightllm/server/multimodal_params.py +++ b/lightllm/server/multimodal_params.py @@ -119,6 +119,7 @@ def to_dict(self): ret["uuid"] = self.uuid ret["token_id"] = self.token_id ret["token_num"] = self.token_num + ret["extra_params"] = self.extra_params return ret def to_origin_dict(self): From 691b89c7cf40567e02326aa2c73489aaf644c0d3 Mon Sep 17 00:00:00 2001 From: baishihao Date: Tue, 1 Jul 2025 22:09:19 +0800 Subject: [PATCH 02/18] fix --- lightllm/common/image_cache_manager.py | 4 +++- .../models/qwen_vl/layer_infer/pre_layer_infer.py | 1 - lightllm/models/vit/layer_infer/pre_layer_infer.py | 1 + .../models/vit/layer_infer/transformer_layer_infer.py | 1 + .../vit/layer_weights/pre_and_post_layer_weight.py | 11 ++++------- .../vit/layer_weights/transformer_layer_weight.py | 6 ------ lightllm/models/vit/model.py | 6 +----- lightllm/server/router/model_infer/infer_batch.py | 6 ++++++ 8 files changed, 16 insertions(+), 20 deletions(-) diff --git a/lightllm/common/image_cache_manager.py b/lightllm/common/image_cache_manager.py index 47086d818..fb04e4b59 100644 --- 
a/lightllm/common/image_cache_manager.py +++ b/lightllm/common/image_cache_manager.py @@ -58,7 +58,7 @@ def filter(self, uuid_list): """ for uuid in uuid_list: if uuid in self._gpu_cache: - embed_cpu = self._gpu_cache[uuid].cpu(non_blocking=True) + embed_cpu = self._gpu_cache[uuid].cpu() # Move to CPU cache and remove from GPU cache self._gpu_cache.pop(uuid) if uuid in self._cpu_cache: @@ -68,6 +68,8 @@ def filter(self, uuid_list): self._cpu_cache.popitem(last=False) elif uuid in self._cpu_cache: self._cpu_cache.move_to_end(uuid) + print(self._gpu_cache.keys()) + print(self._cpu_cache.keys()) return diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py index 92ad282d9..0d88ae3d3 100644 --- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py @@ -68,7 +68,6 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei if self.disable_extra_process_for_multimodal: img_embed = image_cache_manager.get_embed(img["uuid"]) img_weight.append(img_embed.reshape(img["token_num"], -1)) - print(img_weight[-1].shape) else: data = read_shm(get_shm_name_embed(img["uuid"])) img_weight.append(bytes2tensor(data).cuda().reshape(img["token_num"], -1)) diff --git a/lightllm/models/vit/layer_infer/pre_layer_infer.py b/lightllm/models/vit/layer_infer/pre_layer_infer.py index 896e8e898..022466ec2 100644 --- a/lightllm/models/vit/layer_infer/pre_layer_infer.py +++ b/lightllm/models/vit/layer_infer/pre_layer_infer.py @@ -16,6 +16,7 @@ def __init__(self, network_config, mode): self.tp_world_size_ = get_dp_world_size() self.network_config_ = network_config self.mode = mode + print(f"tp_rank_: {self.tp_rank_}, tp_world_size_: {self.tp_world_size_}") return def forward(self, pixel_values, layer_weight: ViTPreAndPostLayerWeight): diff --git a/lightllm/models/vit/layer_infer/transformer_layer_infer.py b/lightllm/models/vit/layer_infer/transformer_layer_infer.py index 14ba9cfed..2be7a0e56 100644 --- a/lightllm/models/vit/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/vit/layer_infer/transformer_layer_infer.py @@ -21,6 +21,7 @@ class ViTTransformerLayerInfer: def __init__(self, layer_num, network_config, mode=[]): self.tp_rank_ = get_current_rank_in_dp() self.tp_world_size_ = get_dp_world_size() + print(f"tp_rank_: {self.tp_rank_}, tp_world_size_: {self.tp_world_size_}") self.eps_ = network_config["layer_norm_eps"] self.head_num = network_config["num_attention_heads"] self.tp_padding_head_num = network_config["padding_head_num"] // self.tp_world_size_ diff --git a/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py index 2f0e725c8..69e6ef317 100644 --- a/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py @@ -5,10 +5,9 @@ from lightllm.common.basemodel import PreAndPostLayerWeight from lightllm.utils.dist_utils import ( get_current_device_id, - get_global_rank, - get_global_world_size, + get_current_rank_in_dp, + get_dp_world_size, ) -from lightllm.utils.envs_utils import get_env_start_args class ViTPreAndPostLayerWeight(PreAndPostLayerWeight): @@ -18,10 +17,8 @@ def __init__(self, data_type, network_config, mode): self.image_size = self.network_config_["image_size"] self.patch_size = self.network_config_["patch_size"] self.llm_hidden_size = self.network_config_["llm_hidden_size"] - if 
get_env_start_args().disable_extra_process_for_multimodal: - self.tp_world_size_ = get_global_world_size() - self.tp_rank_ = get_global_rank() - + self.tp_rank_ = get_current_rank_in_dp() + self.tp_world_size_ = get_dp_world_size() return def _cuda(self, cpu_tensor): diff --git a/lightllm/models/vit/layer_weights/transformer_layer_weight.py b/lightllm/models/vit/layer_weights/transformer_layer_weight.py index 4b62bff36..3c42f712e 100644 --- a/lightllm/models/vit/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/vit/layer_weights/transformer_layer_weight.py @@ -13,18 +13,12 @@ ) from lightllm.utils.dist_utils import ( get_current_device_id, - get_global_rank, - get_global_world_size, ) -from lightllm.utils.envs_utils import get_env_start_args class ViTTransformerLayerWeight(TransformerLayerWeight): def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): super().__init__(layer_num, data_type, network_config, mode, quant_cfg) - if get_env_start_args().disable_extra_process_for_multimodal: - self.tp_world_size_ = get_global_world_size() - self.tp_rank_ = get_global_rank() return def _cuda(self, cpu_tensor): diff --git a/lightllm/models/vit/model.py b/lightllm/models/vit/model.py index bf74cc1a4..c89f7339d 100644 --- a/lightllm/models/vit/model.py +++ b/lightllm/models/vit/model.py @@ -38,11 +38,7 @@ class VisionTransformer: post_layer_infer_class = ViTPostLayerInfer def __init__(self, kvargs): - if get_env_start_args().disable_extra_process_for_multimodal: - # if we don't assign an extra process for visual model, the visual model uses tensor parallel by default. - self.tp_world_size_ = get_global_world_size() - else: - self.tp_world_size_ = get_dp_world_size() + self.tp_world_size_ = get_dp_world_size() self.weight_dir_ = kvargs["weight_dir"] self.load_way = kvargs.get("load_way", "HF") self.mode = [m.replace("int4weight", "w4a16").replace("int8weight", "w8a16") for m in kvargs.get("mode", [])] diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 10b68245c..551a70a22 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -9,6 +9,7 @@ from dataclasses import dataclass, field from typing import List, Dict, Tuple, Optional, Union, Any from lightllm.common.req_manager import ReqManager +from lightllm.common.image_cache_manager import image_cache_manager from lightllm.utils.infer_utils import mark_start, mark_end from lightllm.server.core.objs import Req, SamplingParams, FinishStatus, ShmReqManager from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache, TreeNode @@ -131,6 +132,7 @@ def filter(self, finished_request_ids: List[int]): free_req_index = [] free_token_index = [] + image_uuid_list = [] for request_id in finished_request_ids: req: InferReq = self.requests_mapping.pop(request_id) group_req_id = convert_sub_id_to_group_id(req.shm_req.request_id) @@ -145,6 +147,10 @@ def filter(self, finished_request_ids: List[int]): # logger.info(f"infer release req id {req.shm_req.request_id}") req.shm_req.shm_infer_released = True self.shm_req_manager.put_back_req_obj(req.shm_req) + if req.multimodal_params is not None and get_env_start_args().disable_extra_process_for_multimodal: + for img in req.multimodal_params["images"]: + image_uuid_list.append(img["uuid"]) + image_cache_manager.filter(image_uuid_list) free_token_index = custom_cat(free_token_index) self.req_manager.free(free_req_index, free_token_index) From 
8624e8f5789f0e2529d281e40dc9e48c55174f5e Mon Sep 17 00:00:00 2001 From: baishihao Date: Tue, 1 Jul 2025 22:16:37 +0800 Subject: [PATCH 03/18] fix --- .../models/vit/layer_weights/pre_and_post_layer_weight.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py index 69e6ef317..276d4e5d0 100644 --- a/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py @@ -3,11 +3,7 @@ import numpy as np import torch.nn.functional as F from lightllm.common.basemodel import PreAndPostLayerWeight -from lightllm.utils.dist_utils import ( - get_current_device_id, - get_current_rank_in_dp, - get_dp_world_size, -) +from lightllm.utils.dist_utils import get_current_device_id class ViTPreAndPostLayerWeight(PreAndPostLayerWeight): @@ -17,8 +13,6 @@ def __init__(self, data_type, network_config, mode): self.image_size = self.network_config_["image_size"] self.patch_size = self.network_config_["patch_size"] self.llm_hidden_size = self.network_config_["llm_hidden_size"] - self.tp_rank_ = get_current_rank_in_dp() - self.tp_world_size_ = get_dp_world_size() return def _cuda(self, cpu_tensor): From 3e04c7feefdd1a2e409ad0fb084db3320b2953f9 Mon Sep 17 00:00:00 2001 From: baishihao Date: Tue, 1 Jul 2025 22:19:26 +0800 Subject: [PATCH 04/18] remove print --- lightllm/models/vit/layer_infer/pre_layer_infer.py | 1 - lightllm/models/vit/layer_infer/transformer_layer_infer.py | 1 - 2 files changed, 2 deletions(-) diff --git a/lightllm/models/vit/layer_infer/pre_layer_infer.py b/lightllm/models/vit/layer_infer/pre_layer_infer.py index 022466ec2..896e8e898 100644 --- a/lightllm/models/vit/layer_infer/pre_layer_infer.py +++ b/lightllm/models/vit/layer_infer/pre_layer_infer.py @@ -16,7 +16,6 @@ def __init__(self, network_config, mode): self.tp_world_size_ = get_dp_world_size() self.network_config_ = network_config self.mode = mode - print(f"tp_rank_: {self.tp_rank_}, tp_world_size_: {self.tp_world_size_}") return def forward(self, pixel_values, layer_weight: ViTPreAndPostLayerWeight): diff --git a/lightllm/models/vit/layer_infer/transformer_layer_infer.py b/lightllm/models/vit/layer_infer/transformer_layer_infer.py index 2be7a0e56..14ba9cfed 100644 --- a/lightllm/models/vit/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/vit/layer_infer/transformer_layer_infer.py @@ -21,7 +21,6 @@ class ViTTransformerLayerInfer: def __init__(self, layer_num, network_config, mode=[]): self.tp_rank_ = get_current_rank_in_dp() self.tp_world_size_ = get_dp_world_size() - print(f"tp_rank_: {self.tp_rank_}, tp_world_size_: {self.tp_world_size_}") self.eps_ = network_config["layer_norm_eps"] self.head_num = network_config["num_attention_heads"] self.tp_padding_head_num = network_config["padding_head_num"] // self.tp_world_size_ From 06c38a085e782e0c85f9745ba243e18425cb4b52 Mon Sep 17 00:00:00 2001 From: baishihao Date: Wed, 2 Jul 2025 11:13:50 +0800 Subject: [PATCH 05/18] fix --- lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py | 8 +++++--- lightllm/models/vit/model.py | 8 ++++++-- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py index 0d88ae3d3..9b1d95f7c 100644 --- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py @@ -43,9 
+43,11 @@ def _infer_image_embeds(self, infer_state, layer_weight): if not image_cache_manager.query_embed(img["uuid"]): infer_images.append(img) if len(infer_images) > 0: - img_embeds, uuids, valid_ids = layer_weight.visual_model.encode(infer_images) - for uuid, valid_id in zip(uuids, valid_ids): - image_cache_manager.set_embed(uuid, img_embeds[valid_id[0] : valid_id[1]]) + infer_batch_size = get_env_start_args().visual_infer_batch_size + for i in range(0, len(infer_images), infer_batch_size): + img_embeds, uuids, valid_ids = layer_weight.visual_model.encode(infer_images[i : i + infer_batch_size]) + for uuid, valid_id in zip(uuids, valid_ids): + image_cache_manager.set_embed(uuid, img_embeds[valid_id[0] : valid_id[1]]) def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): diff --git a/lightllm/models/vit/model.py b/lightllm/models/vit/model.py index c89f7339d..a8b475889 100644 --- a/lightllm/models/vit/model.py +++ b/lightllm/models/vit/model.py @@ -48,6 +48,7 @@ def __init__(self, kvargs): self.quant_cfg_path = kvargs.get("quant_cfg", None) self.load_image_func = get_load_image_func(self.weight_dir_) self.max_batch_size = kvargs.get("max_batch_size", 1) + self.enable_tensor_cache = not get_env_start_args().disable_extra_process_for_multimodal self._init_datatype() self._init_config() @@ -64,6 +65,7 @@ def _check_max_len_infer(self): disable_check_max_len_infer = os.getenv("DISABLE_CHECK_MAX_LEN_INFER", None) is not None if disable_check_max_len_infer: return + self.enable_tensor_cache = True try: dummy_images = torch.randn( @@ -71,6 +73,7 @@ def _check_max_len_infer(self): ).cuda() all_img_embeds = self.forward(dummy_images) del all_img_embeds + del dummy_images logger.info(f"vit check max_len {self.max_batch_size} infer ok") except (RuntimeError, torch.OutOfMemoryError) as e: logger.exception(str(e)) @@ -79,6 +82,7 @@ def _check_max_len_infer(self): ) logger.error(exception_str) raise Exception(exception_str) + self.enable_tensor_cache = not get_env_start_args().disable_extra_process_for_multimodal return def _init_config(self): @@ -164,13 +168,13 @@ def _init_datatype(self): @torch.no_grad() def forward(self, pixel_values): - if not get_env_start_args().disable_extra_process_for_multimodal: + if self.enable_tensor_cache: g_cache_manager.cache_env_in() input_embs = self.pre_infer.forward(pixel_values, self.pre_post_weight) for i in range(self.layers_num + self.select_layer + 1): input_embs = self.layers_infer[i].forward(input_embs, self.trans_layers_weight[i]) input_embs = self.post_infer.forward(input_embs[:, 1:, :], self.pre_post_weight) - if not get_env_start_args().disable_extra_process_for_multimodal: + if self.enable_tensor_cache: g_cache_manager.cache_env_out() return input_embs From b0768e26ca63c638005b4c0673fd8706c63613a1 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 2 Jul 2025 19:35:50 +0800 Subject: [PATCH 06/18] mutli infer detect need prefill objs --- lightllm/common/basemodel/infer_struct.py | 22 +++ .../basemodel/triton_kernel/multimodal_emb.py | 138 ++++++++++++++---- .../gemma3/layer_infer/pre_layer_infer.py | 4 +- .../qwen_vl/layer_infer/pre_layer_infer.py | 7 +- .../triton_kernel/test_multimodal_emb.py | 22 +++ 5 files changed, 162 insertions(+), 31 deletions(-) create mode 100755 unit_tests/common/basemodel/triton_kernel/test_multimodal_emb.py diff --git a/lightllm/common/basemodel/infer_struct.py b/lightllm/common/basemodel/infer_struct.py index 021de6843..e45cc11c7 100755 --- 
a/lightllm/common/basemodel/infer_struct.py
+++ b/lightllm/common/basemodel/infer_struct.py
@@ -5,6 +5,7 @@
 from typing import Tuple, Any, Optional
 from .triton_kernel.gen_prefill_params import gen_prefill_params
 from .triton_kernel.gen_decode_params import gen_decode_params
+from .triton_kernel.multimodal_emb import mark_multimodal_obj
 
 
 class InferStateInfo:
@@ -98,3 +99,24 @@ def copy_for_cuda_graph(self, new_infer_state: "InferStateInfo"):
                 if attr_ is not None and attr_.data_ptr() != attr_value.data_ptr():
                     attr_.copy_(attr_value, non_blocking=True)
         return
+
+    def mark_multimodal_objs_for_prefill(self, input_ids: torch.Tensor):
+        """
+        Helper that marks which multimodal objects have tokens that must take part in the computation during
+        chunked prefill. Because prefill is split into chunks, not every multimodal object's tokens are needed.
+        """
+        multi_objs = []
+        for _, p in enumerate(self.multimodal_params):
+            for obj in p["images"] + p["audios"]:
+                multi_objs.append(obj)
+
+        if multi_objs:
+            obj_start_ids = torch.tensor([e["token_id"] for e in multi_objs], dtype=torch.int64, device="cuda")
+            obj_token_lens = torch.tensor([e["token_num"] for e in multi_objs], dtype=torch.int64, device="cuda")
+            marks = mark_multimodal_obj(
+                obj_start_token_ids=obj_start_ids, obj_token_lens=obj_token_lens, input_ids=input_ids
+            )
+            marks_array = marks.detach().cpu().numpy()
+            for mark, obj in zip(marks_array, multi_objs):
+                obj["_prefill_"] = mark > 0
+        return
diff --git a/lightllm/common/basemodel/triton_kernel/multimodal_emb.py b/lightllm/common/basemodel/triton_kernel/multimodal_emb.py
index 8b66827a5..64d45e0dc 100644
--- a/lightllm/common/basemodel/triton_kernel/multimodal_emb.py
+++ b/lightllm/common/basemodel/triton_kernel/multimodal_emb.py
@@ -5,48 +5,78 @@
 @triton.jit
 def _fwd_kernel(
-    Prompt_ids,
+    Prompt_ids,
     Text_weight_embs,
     Img_embs,
     Out,
     Img_token_lens,
     Img_start_token_ids,
     Img_start_locs,
-    stride_text_emb_s, stride_text_emb_d, # text_stride
-    stride_img_emb_s, stride_img_emb_d, # img_stride
-    stride_out_s, stride_out_d,
+    stride_text_emb_s,
+    stride_text_emb_d,  # text_stride
+    stride_img_emb_s,
+    stride_img_emb_d,  # img_stride
+    stride_out_s,
+    stride_out_d,
     tp_text_start_token_id,
     tp_text_end_token_id,
     hidden_size,
-    BLOCK_HIDDEN_DIM: tl.constexpr
-    ):
+    BLOCK_HIDDEN_DIM: tl.constexpr,
+):
     seq_index = tl.program_id(0).to(tl.int64)
     img_handle_id = tl.program_id(1)
     token_id = tl.load(Prompt_ids + seq_index)
     off_d = tl.arange(0, BLOCK_HIDDEN_DIM)
-    
+
     # load store text emb
-    for _ in range(0, tl.where((img_handle_id == 0) & (token_id < tp_text_end_token_id) & (token_id >= tp_text_start_token_id), 1, 0), 1):
-        load_emb = tl.load(Text_weight_embs + stride_text_emb_s * (token_id - tp_text_start_token_id) + off_d * stride_text_emb_d, mask=off_d < hidden_size, other=0)
+    for _ in range(
+        0,
+        tl.where((img_handle_id == 0) & (token_id < tp_text_end_token_id) & (token_id >= tp_text_start_token_id), 1, 0),
+        1,
+    ):
+        load_emb = tl.load(
+            Text_weight_embs + stride_text_emb_s * (token_id - tp_text_start_token_id) + off_d * stride_text_emb_d,
+            mask=off_d < hidden_size,
+            other=0,
+        )
         tl.store(Out + stride_out_s * seq_index + stride_out_d * off_d, load_emb, mask=off_d < hidden_size)
-    
+
     img_start_token_id = tl.load(Img_start_token_ids + img_handle_id - 1, mask=img_handle_id >= 1, other=0)
     img_start_loc = tl.load(Img_start_locs + img_handle_id - 1, mask=img_handle_id >= 1, other=0)
     img_token_len = tl.load(Img_token_lens + img_handle_id - 1, mask=img_handle_id >= 1, other=0)
     # load store img emb
-    for _ in range(0, tl.where((img_handle_id != 0) & (token_id >= img_start_token_id) & (token_id < 
img_start_token_id + img_token_len), 1, 0), 1): - load_emb = tl.load(Img_embs + stride_img_emb_s * (img_start_loc + token_id - img_start_token_id) + off_d * stride_img_emb_d, mask=off_d < hidden_size, other=0) + for _ in range( + 0, + tl.where( + (img_handle_id != 0) & (token_id >= img_start_token_id) & (token_id < img_start_token_id + img_token_len), + 1, + 0, + ), + 1, + ): + load_emb = tl.load( + Img_embs + stride_img_emb_s * (img_start_loc + token_id - img_start_token_id) + off_d * stride_img_emb_d, + mask=off_d < hidden_size, + other=0, + ) tl.store(Out + stride_out_s * seq_index + stride_out_d * off_d, load_emb, mask=off_d < hidden_size) return @torch.no_grad() -def multimodal_emb(out: torch.Tensor, prompt_ids: torch.Tensor, text_weight_embs: torch.Tensor, img_embs: torch.Tensor, - img_token_lens: torch.Tensor, img_start_token_ids: torch.Tensor, img_start_locs: torch.Tensor, - tp_text_start_token_id, - tp_text_end_token_id): +def multimodal_emb( + out: torch.Tensor, + prompt_ids: torch.Tensor, + text_weight_embs: torch.Tensor, + img_embs: torch.Tensor, + img_token_lens: torch.Tensor, + img_start_token_ids: torch.Tensor, + img_start_locs: torch.Tensor, + tp_text_start_token_id, + tp_text_end_token_id, +): total_len = prompt_ids.shape[0] BLOCK = triton.next_power_of_2(out.shape[1]) # print(len(img_token_lens)) @@ -60,9 +90,12 @@ def multimodal_emb(out: torch.Tensor, prompt_ids: torch.Tensor, text_weight_embs img_token_lens, img_start_token_ids, img_start_locs, - text_weight_embs.stride(0), text_weight_embs.stride(1), - img_embs.stride(0), img_embs.stride(1), - out.stride(0), out.stride(1), + text_weight_embs.stride(0), + text_weight_embs.stride(1), + img_embs.stride(0), + img_embs.stride(1), + out.stride(0), + out.stride(1), tp_text_start_token_id, tp_text_end_token_id, hidden_size=out.shape[1], @@ -73,6 +106,48 @@ def multimodal_emb(out: torch.Tensor, prompt_ids: torch.Tensor, text_weight_embs return +@triton.jit +def _mark_multimodal_obj_need_kernel( + obj_start_token_ids_ptr, + obj_token_lens_ptr, + obj_marks_ptr, + input_ids_ptr, + input_size, + BLOCK_SIZE: tl.constexpr, +): + + obj_index = tl.program_id(0) + start_id = tl.load(obj_start_token_ids_ptr + obj_index) + token_len = tl.load(obj_token_lens_ptr + obj_index) + + for block_start in range(0, input_size, BLOCK_SIZE): + block_range = block_start + tl.arange(0, BLOCK_SIZE) + cur_input_ids = tl.load(input_ids_ptr + block_range, mask=block_range < input_size, other=0) + mark = tl.where((cur_input_ids >= start_id) & (cur_input_ids < start_id + token_len), 1, 0) + mark = tl.sum(mark) + tl.store(obj_marks_ptr + obj_index, 1, mask=mark > 0) + return + + +@torch.no_grad() +def mark_multimodal_obj(obj_start_token_ids: torch.Tensor, obj_token_lens: torch.Tensor, input_ids: torch.Tensor): + out_mark = torch.empty_like(obj_start_token_ids) + out_mark.fill_(0) + assert obj_start_token_ids.shape == obj_token_lens.shape + BLOCK = 512 + grid = (obj_start_token_ids.shape[0],) + _mark_multimodal_obj_need_kernel[grid]( + obj_start_token_ids_ptr=obj_start_token_ids, + obj_token_lens_ptr=obj_token_lens, + obj_marks_ptr=out_mark, + input_ids_ptr=input_ids, + input_size=input_ids.shape[0], + BLOCK_SIZE=BLOCK, + num_warps=1, + num_stages=1, + ) + return out_mark + def test(): S, D = 1024 * 1000, 128 * 64 @@ -80,27 +155,35 @@ def test(): image_size = 10 image_token_size = 512 - text_weight = torch.randn((vob_size, D), device='cuda', dtype=torch.float16) - img_weight = torch.randn((image_size * image_token_size, D), device='cuda', dtype=torch.float16) 
- img_token_lens = torch.full((image_size,), image_token_size, device='cuda', dtype=torch.long) - img_start_token_ids = (torch.arange(0, image_size * image_token_size, image_token_size) + vob_size * 10).cuda().long() + text_weight = torch.randn((vob_size, D), device="cuda", dtype=torch.float16) + img_weight = torch.randn((image_size * image_token_size, D), device="cuda", dtype=torch.float16) + img_token_lens = torch.full((image_size,), image_token_size, device="cuda", dtype=torch.long) + img_start_token_ids = ( + (torch.arange(0, image_size * image_token_size, image_token_size) + vob_size * 10).cuda().long() + ) img_start_locs = torch.arange(0, image_size * image_token_size, image_token_size).cuda().long() prompt_ids = torch.arange(0, S, 1).cuda().long() - prompt_ids[0: image_size * image_token_size] = (vob_size * 10 + torch.arange(0, image_size * image_token_size, 1)).cuda().long() + prompt_ids[0 : image_size * image_token_size] = ( + (vob_size * 10 + torch.arange(0, image_size * image_token_size, 1)).cuda().long() + ) out = torch.zeros((S, D), dtype=torch.float16, device="cuda") print(out.shape) import time - - triton_output = multimodal_emb(out, prompt_ids, text_weight, img_weight, img_token_lens, img_start_token_ids, img_start_locs, 0, vob_size) + + multimodal_emb( + out, prompt_ids, text_weight, img_weight, img_token_lens, img_start_token_ids, img_start_locs, 0, vob_size + ) torch.cuda.synchronize() iters = 20 t1 = time.time() for _ in range(iters): - triton_output = multimodal_emb(out, prompt_ids, text_weight, img_weight, img_token_lens, img_start_token_ids, img_start_locs, 0, vob_size) + multimodal_emb( + out, prompt_ids, text_weight, img_weight, img_token_lens, img_start_token_ids, img_start_locs, 0, vob_size + ) torch.cuda.synchronize() t2 = time.time() print("Triton time cost", (t2 - t1) / iters) @@ -109,4 +192,3 @@ def test(): # if __name__ == "__main__": # test() - diff --git a/lightllm/models/gemma3/layer_infer/pre_layer_infer.py b/lightllm/models/gemma3/layer_infer/pre_layer_infer.py index 89c8e0d8d..46b782879 100644 --- a/lightllm/models/gemma3/layer_infer/pre_layer_infer.py +++ b/lightllm/models/gemma3/layer_infer/pre_layer_infer.py @@ -35,10 +35,12 @@ def context_forward(self, input_ids, infer_state, layer_weight): else: weight_mask[idx] = scale + infer_state.mark_multimodal_objs_for_prefill(input_ids=input_ids) + for batch_id, p in enumerate(infer_state.multimodal_params): for img in p["images"]: # skip the same image - if img["token_id"] in img_start_token_ids: + if img["token_id"] in img_start_token_ids or img["_prefill_"] is False: continue # pull the img_embeds by uid from shm data = read_shm(get_shm_name_embed(img["uuid"])) diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py index 9b1d95f7c..f4ac4c326 100644 --- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py @@ -40,7 +40,7 @@ def _infer_image_embeds(self, infer_state, layer_weight): infer_images = [] for _, p in enumerate(infer_state.multimodal_params): for img in p["images"] + p["audios"]: - if not image_cache_manager.query_embed(img["uuid"]): + if (img["_prefill_"] is True) and (not image_cache_manager.query_embed(img["uuid"])): infer_images.append(img) if len(infer_images) > 0: infer_batch_size = get_env_start_args().visual_infer_batch_size @@ -60,11 +60,14 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei device = 
layer_weight.wte_weight_.device dtype = layer_weight.wte_weight_.dtype hidden_size = layer_weight.wte_weight_.shape[1] + + infer_state.mark_multimodal_objs_for_prefill(input_ids=input_ids) + self._infer_image_embeds(infer_state, layer_weight) for batch_id, p in enumerate(infer_state.multimodal_params): for img in p["images"] + p["audios"]: # skip the same image - if img["token_id"] in img_start_token_ids: + if img["token_id"] in img_start_token_ids or img["_prefill_"] is False: continue # pull the img_embeds by uid from shm if self.disable_extra_process_for_multimodal: diff --git a/unit_tests/common/basemodel/triton_kernel/test_multimodal_emb.py b/unit_tests/common/basemodel/triton_kernel/test_multimodal_emb.py new file mode 100755 index 000000000..49fcc2b60 --- /dev/null +++ b/unit_tests/common/basemodel/triton_kernel/test_multimodal_emb.py @@ -0,0 +1,22 @@ +import torch +import pytest +from lightllm.common.basemodel.triton_kernel.multimodal_emb import mark_multimodal_obj +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +def test_mark_mubltimodal_obj(): + obj_start_ids = torch.tensor([1, 4, 100], device="cuda", dtype=torch.int64) + obj_token_lens = torch.tensor([1, 3, 2], device="cuda", dtype=torch.int64) + input_ids = torch.tensor([1, 7, 9, 333], device="cuda", dtype=torch.int64) + + mark_obj = mark_multimodal_obj( + obj_start_token_ids=obj_start_ids, obj_token_lens=obj_token_lens, input_ids=input_ids + ) + + assert torch.equal(mark_obj, torch.tensor([1, 0, 0], device="cuda")) + + +if __name__ == "__main__": + pytest.main() From f96c6ab009fd862926a0fa90e92d1e35b298c9ae Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Fri, 1 Aug 2025 09:36:09 +0000 Subject: [PATCH 07/18] 0801-fix --- lightllm/common/image_cache_manager.py | 76 -------- .../pre_and_post_layer_weight.py | 9 +- lightllm/models/qwen2_vl/model.py | 20 +-- lightllm/models/qwen2_vl/qwen2_visual.py | 42 ----- lightllm/models/qwen2_vl/vision_process.py | 167 +----------------- .../qwen_vl/layer_infer/pre_layer_infer.py | 28 +-- .../embed_cache/impl/naive_memory_cache.py | 1 + .../server/router/model_infer/infer_batch.py | 6 - 8 files changed, 35 insertions(+), 314 deletions(-) diff --git a/lightllm/common/image_cache_manager.py b/lightllm/common/image_cache_manager.py index fb04e4b59..e69de29bb 100644 --- a/lightllm/common/image_cache_manager.py +++ b/lightllm/common/image_cache_manager.py @@ -1,76 +0,0 @@ -from collections import OrderedDict -from lightllm.utils.dist_utils import get_current_device_id - - -class ImageCacheManager: - def __init__(self): - """ - Initialize the image cache manager with a simple GPU cache and an LRU CPU cache. - """ - self._gpu_cache = dict() - self._cpu_cache = OrderedDict() - - def set_max_size(self, max_size: int): - """ - Set the maximum number of items to keep in the CPU cache. - :param max_size: Maximum number of items to keep in the CPU cache. - """ - if max_size <= 0: - raise ValueError("max_size must be greater than 0") - self._max_size = max_size - - def set_embed(self, uuid, embed): - """ - Store the embedding for the given uuid in the GPU cache. - :param uuid: Unique identifier for the image - :param embed: Embedding vector for the image (on GPU) - """ - self._gpu_cache[uuid] = embed - - def get_embed(self, uuid): - """ - Retrieve the embedding for the given uuid. Prefer GPU cache, - otherwise return CPU cache and move to GPU (simulate .cuda()). 
- :param uuid: Unique identifier for the image - :return: Embedding vector (on GPU if possible, else move from CPU to GPU) - """ - if uuid in self._gpu_cache: - return self._gpu_cache[uuid] - elif uuid in self._cpu_cache: - self._cpu_cache.move_to_end(uuid) - embed = self._cpu_cache[uuid].cuda(get_current_device_id()) - return embed - return None - - def query_embed(self, uuid): - """ - Query if the embedding for the given uuid is in the cache. - :param uuid: Unique identifier for the image - :return: True if the embedding is in the cache, False otherwise - """ - return uuid in self._gpu_cache or uuid in self._cpu_cache - - def filter(self, uuid_list): - """ - Given a list of uuids, move their embeddings from GPU cache to CPU cache if present, - and return a dict of those found in the cache and their embeddings (on CPU). - :param uuid_list: List of uuids - """ - for uuid in uuid_list: - if uuid in self._gpu_cache: - embed_cpu = self._gpu_cache[uuid].cpu() - # Move to CPU cache and remove from GPU cache - self._gpu_cache.pop(uuid) - if uuid in self._cpu_cache: - self._cpu_cache.move_to_end(uuid) - self._cpu_cache[uuid] = embed_cpu - if len(self._cpu_cache) > self._max_size: - self._cpu_cache.popitem(last=False) - elif uuid in self._cpu_cache: - self._cpu_cache.move_to_end(uuid) - print(self._gpu_cache.keys()) - print(self._cpu_cache.keys()) - return - - -image_cache_manager = ImageCacheManager() diff --git a/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py index 486319495..b85a536bd 100644 --- a/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py @@ -5,7 +5,6 @@ from lightllm.models.internlm2.layer_weights.pre_and_post_layer_weight import Internlm2PreAndPostLayerWeight from lightllm.models.vit.model import VisionTransformer from lightllm.utils.envs_utils import get_env_start_args -from lightllm.common.image_cache_manager import image_cache_manager # add key: language_model.xxx -> xxx @@ -29,11 +28,11 @@ def __init__(self, data_type, network_config, mode): "quant_type": get_env_start_args().vit_quant_type, "quant_cfg": get_env_start_args().vit_quant_cfg, "max_batch_size": get_env_start_args().visual_infer_batch_size, + "cache_port": get_env_start_args().cache_port, } self.visual_model = VisionTransformer( kvargs=kvargs, ) - image_cache_manager.set_max_size(get_env_start_args().cache_capacity * 2) return def load_hf_weights(self, weights): @@ -52,11 +51,11 @@ def __init__(self, data_type, network_config, mode): "quant_type": get_env_start_args().vit_quant_type, "quant_cfg": get_env_start_args().vit_quant_cfg, "max_batch_size": get_env_start_args().visual_infer_batch_size, + "cache_port": get_env_start_args().cache_port, } self.visual_model = VisionTransformer( kvargs=kvargs, ) - image_cache_manager.set_max_size(get_env_start_args().cache_capacity * 2) return def load_hf_weights(self, weights): @@ -76,11 +75,11 @@ def __init__(self, data_type, network_config, mode): "quant_type": get_env_start_args().vit_quant_type, "quant_cfg": get_env_start_args().vit_quant_cfg, "max_batch_size": get_env_start_args().visual_infer_batch_size, + "cache_port": get_env_start_args().cache_port, } self.visual_model = VisionTransformer( kvargs=kvargs, ) - image_cache_manager.set_max_size(get_env_start_args().cache_capacity * 2) return def load_hf_weights(self, weights): @@ -100,11 +99,11 @@ def __init__(self, data_type, network_config, 
mode): "quant_type": get_env_start_args().vit_quant_type, "quant_cfg": get_env_start_args().vit_quant_cfg, "max_batch_size": get_env_start_args().visual_infer_batch_size, + "cache_port": get_env_start_args().cache_port, } self.visual_model = VisionTransformer( kvargs=kvargs, ) - image_cache_manager.set_max_size(get_env_start_args().cache_capacity * 2) return def load_hf_weights(self, weights): diff --git a/lightllm/models/qwen2_vl/model.py b/lightllm/models/qwen2_vl/model.py index f87c3d6ba..00976467c 100644 --- a/lightllm/models/qwen2_vl/model.py +++ b/lightllm/models/qwen2_vl/model.py @@ -29,6 +29,10 @@ class QWen2VLTokenizer(BaseMultiModalTokenizer): def __init__(self, tokenizer=None, image_processor=None, **kwargs): super().__init__(tokenizer) self.image_processor = image_processor + self.min_pixel = self.image_processor.image_processor.min_pixels + self.max_pixel = self.image_processor.image_processor.max_pixels + self.patch_size = self.image_processor.image_processor.patch_size + self.merge_size = self.image_processor.image_processor.merge_size self.image_start_id = kwargs["model_cfg"]["vision_start_token_id"] self.image_end_id = kwargs["model_cfg"]["vision_end_token_id"] self.image_token_id = kwargs["model_cfg"]["image_token_id"] @@ -44,17 +48,13 @@ def init_audioitem_extral_params( raise NotImplementedError def get_image_token_length(self, img: ImageItem): - width = img.image_w - height = img.image_h - resized_height, resized_width = smart_resize(height=height, width=width) - self.patch_size = self.image_processor.image_processor.patch_size - self.merge_size = self.image_processor.image_processor.merge_size - grid_t = 1 + width, height = img.image_w, img.image_h + resized_height, resized_width = smart_resize( + height=height, width=width, min_pixels=self.min_pixel, max_pixels=self.max_pixel + ) grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size - merge_length = self.merge_size ** 2 - self.token_num = (grid_t * grid_h * grid_w) // merge_length - self.image_length = self.token_num - return self.image_length + self.token_num = (grid_h * grid_w) // (self.merge_size ** 2) + return self.token_num def get_audio_token_length(self, audio: AudioItem): raise NotImplementedError diff --git a/lightllm/models/qwen2_vl/qwen2_visual.py b/lightllm/models/qwen2_vl/qwen2_visual.py index 4a9012518..591915a6e 100644 --- a/lightllm/models/qwen2_vl/qwen2_visual.py +++ b/lightllm/models/qwen2_vl/qwen2_visual.py @@ -46,15 +46,6 @@ from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager -from transformers.utils import is_flash_attn_2_available - -if is_flash_attn_2_available(): - from flash_attn import flash_attn_varlen_func - - from transformers.modeling_flash_attention_utils import _flash_attention_forward -else: - flash_attn_varlen_func = None - logger = logging.get_logger(__name__) @@ -176,39 +167,6 @@ def forward(self, x) -> torch.Tensor: return self.fc2(self.act(self.fc1(x))) -class VisionAttention(nn.Module): - def __init__(self, dim: int, num_heads: int = 16) -> None: - super().__init__() - self.num_heads = num_heads - self.head_dim = dim // num_heads # 初始化 head_dim,每个头的维度 - self.qkv = nn.Linear(dim, dim * 3, bias=True) - self.proj = nn.Linear(dim, dim) - - def forward( - self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None - ) -> torch.Tensor: - seq_length = hidden_states.shape[0] - q, k, v = 
self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) - q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) - k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) - - attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True - - q = q.transpose(0, 1) - k = k.transpose(0, 1) - v = v.transpose(0, 1) - attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim) - attn_weights = attn_weights + attention_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) - attn_output = torch.matmul(attn_weights, v) - attn_output = attn_output.transpose(0, 1) - attn_output = attn_output.reshape(seq_length, -1) - attn_output = self.proj(attn_output) - return attn_output - - # adapted from # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py class VisionFlashAttention(nn.Module): diff --git a/lightllm/models/qwen2_vl/vision_process.py b/lightllm/models/qwen2_vl/vision_process.py index 45c250378..806013f12 100644 --- a/lightllm/models/qwen2_vl/vision_process.py +++ b/lightllm/models/qwen2_vl/vision_process.py @@ -44,7 +44,6 @@ ChannelDimension, ImageInput, PILImageResampling, - VideoInput, get_image_size, infer_channel_dimension_format, is_scaled_image, @@ -95,23 +94,6 @@ def make_batched_images(images) -> List[List[ImageInput]]: raise ValueError(f"Could not make batched images from {images}") -# Copied from transformers.models.llava_next_video.image_processing_llava_next_video.make_batched_videos -def make_batched_videos(videos) -> List[VideoInput]: - if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]): - return videos - - elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]): - if isinstance(videos[0], Image.Image): - return [videos] - elif len(videos[0].shape) == 4: - return [list(video) for video in videos] - - elif is_valid_image(videos) and len(videos.shape) == 4: - return [list(videos)] - - raise ValueError(f"Could not make batched video from {videos}") - - def round_by_factor(number: int, factor: int) -> int: """Returns the closest integer to 'number' that is divisible by 'factor'.""" return round(number / factor) * factor @@ -156,71 +138,8 @@ def smart_resize( return h_bar, w_bar -def fetch_image(ele: dict[str, str | Image.Image], size_factor: int = IMAGE_FACTOR) -> Image.Image: - if "image" in ele: - image = ele["image"] - else: - image = ele["image_url"] - image_obj = None - if isinstance(image, Image.Image): - image_obj = image - elif image.startswith("http://") or image.startswith("https://"): - image_obj = Image.open(requests.get(image, stream=True).raw) - elif image.startswith("file://"): - image_obj = Image.open(image[7:]) - elif image.startswith("data:image"): - data = image.split(";", 1)[1] - if data.startswith("base64,"): - data = base64.b64decode(data[7:]) - image_obj = Image.open(BytesIO(data)) - else: - image_obj = Image.open(image) - if image_obj is None: - raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}") - image = image_obj.convert("RGB") - ## resize - if "resized_height" in ele and "resized_width" in ele: - resized_height, resized_width = smart_resize( - ele["resized_height"], - 
ele["resized_width"], - factor=size_factor, - ) - else: - width, height = image.size - min_pixels = ele.get("min_pixels", MIN_PIXELS) - max_pixels = ele.get("max_pixels", MAX_PIXELS) - resized_height, resized_width = smart_resize( - height, - width, - factor=size_factor, - min_pixels=min_pixels, - max_pixels=max_pixels, - ) - image = image.resize((resized_width, resized_height)) - return image - - -def get_image(image_file: str | Image.Image, size_factor: int = IMAGE_FACTOR) -> tuple[Image.Image, int, int]: - image_obj = None - - if isinstance(image_file, Image.Image): - image_obj = image_file - elif image_file.startswith("http://") or image_file.startswith("https://"): - image_obj = Image.open(requests.get(image_file, stream=True).raw) - elif image_file.startswith("file://"): - image_obj = Image.open(image_file[7:]) - elif image_file.startswith("data:image"): - data = image_file.split(";", 1)[1] - if data.startswith("base64,"): - data = base64.b64decode(data[7:]) - image_obj = Image.open(BytesIO(data)) - else: - image_obj = Image.open(image_file) - - if image_obj is None: - raise ValueError("Unrecognized image input. Supports local path, http url, base64, and PIL.Image.") - - image = image_obj.convert("RGB") +def get_image(image_file: Image.Image, size_factor: int = IMAGE_FACTOR) -> tuple[Image.Image, int, int]: + image = image_file.convert("RGB") # 获取原始宽度和高度 width, height = image.size @@ -240,43 +159,6 @@ def get_image(image_file: str | Image.Image, size_factor: int = IMAGE_FACTOR) -> return image -def extract_vision_info(conversations: list[dict] | list[list[dict]]) -> list[dict]: - vision_infos = [] - if isinstance(conversations[0], dict): - conversations = [conversations] - for conversation in conversations: - for message in conversation: - if isinstance(message["content"], list): - for ele in message["content"]: - if ( - "image" in ele - or "image_url" in ele - or "video" in ele - or ele["type"] in ("image", "image_url", "video") - ): - vision_infos.append(ele) - return vision_infos - - -def process_vision_info( - conversations: list[dict] | list[list[dict]], -) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None]: - vision_infos = extract_vision_info(conversations) - ## Read images or videos - image_inputs = [] - # video_inputs = [] - for vision_info in vision_infos: - if "image" in vision_info or "image_url" in vision_info: - image_inputs.append(fetch_image(vision_info)) - # elif "video" in vision_info: - # video_inputs.append(fetch_video(vision_info)) - else: - raise ValueError("image, image_url or video should in content.") - if len(image_inputs) == 0: - image_inputs = None - return image_inputs - - # adapted from # transformers/blob/main/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py class Qwen2VLImageProcessor(BaseImageProcessor): @@ -318,7 +200,7 @@ def __init__( def _preprocess( self, - images: Union[ImageInput, VideoInput], + images: Union[ImageInput], do_resize: bool = None, resample: PILImageResampling = None, do_rescale: bool = None, @@ -402,7 +284,6 @@ def _preprocess( def preprocess( self, images: ImageInput, - videos: VideoInput = None, do_resize: bool = None, size: Dict[str, int] = None, resample: PILImageResampling = None, @@ -429,26 +310,6 @@ def preprocess( if images is not None: images = make_batched_images(images) - if videos is not None: - videos = make_batched_videos(videos) - - if images is not None and not valid_images(images): - raise ValueError( - "Invalid image type. 
Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) - - validate_preprocess_arguments( - rescale_factor=rescale_factor, - do_normalize=do_normalize, - image_mean=image_mean, - image_std=image_std, - do_resize=do_resize, - size=size, - resample=resample, - ) - - if images is not None: pixel_values, vision_grid_thws = [], [] for image in images: patches, image_grid_thw = self._preprocess( @@ -470,26 +331,4 @@ def preprocess( vision_grid_thws = np.array(vision_grid_thws) data = {"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws} - if videos is not None: - pixel_values, vision_grid_thws = [], [] - for images in videos: - patches, video_grid_thw = self._preprocess( - images, - do_resize=do_resize, - resample=resample, - do_rescale=do_rescale, - rescale_factor=rescale_factor, - do_normalize=do_normalize, - image_mean=image_mean, - image_std=image_std, - data_format=data_format, - do_convert_rgb=do_convert_rgb, - input_data_format=input_data_format, - ) - pixel_values.extend(patches) - vision_grid_thws.append(video_grid_thw) - pixel_values = np.array(pixel_values) - vision_grid_thws = np.array(vision_grid_thws) - data = {"pixel_values_videos": pixel_values, "video_grid_thw": vision_grid_thws} - return BatchFeature(data=data, tensor_type=return_tensors) diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py index f4ac4c326..cd8d2ae8d 100644 --- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py @@ -1,3 +1,4 @@ +import rpyc import torch import torch.distributed as dist @@ -7,10 +8,9 @@ from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer from lightllm.utils.infer_utils import mark_cost_time from lightllm.utils.envs_utils import get_env_start_args -from lightllm.server.embed_cache.utils import bytes2tensor, read_shm, get_shm_name_embed -from lightllm.common.image_cache_manager import image_cache_manager from lightllm.common.basemodel.triton_kernel.multimodal_emb import multimodal_emb from lightllm.distributed.communication_op import all_reduce +from lightllm.server.embed_cache.utils import bytes2tensor, tensor2bytes, read_shm, create_shm, get_shm_name_embed """ @@ -31,6 +31,8 @@ class LlamaMultimodalPreLayerInfer(LlamaPreLayerInfer): def __init__(self, network_config, mode): super().__init__(network_config, mode) + print(f"network_config is {network_config}") + self.cache_client = rpyc.connect("localhost", self.cache_port) self.disable_extra_process_for_multimodal = get_env_start_args().disable_extra_process_for_multimodal return @@ -40,14 +42,22 @@ def _infer_image_embeds(self, infer_state, layer_weight): infer_images = [] for _, p in enumerate(infer_state.multimodal_params): for img in p["images"] + p["audios"]: - if (img["_prefill_"] is True) and (not image_cache_manager.query_embed(img["uuid"])): + if img["_prefill_"] is True: infer_images.append(img) if len(infer_images) > 0: infer_batch_size = get_env_start_args().visual_infer_batch_size for i in range(0, len(infer_images), infer_batch_size): img_embeds, uuids, valid_ids = layer_weight.visual_model.encode(infer_images[i : i + infer_batch_size]) - for uuid, valid_id in zip(uuids, valid_ids): - image_cache_manager.set_embed(uuid, img_embeds[valid_id[0] : valid_id[1]]) + img_embeds = img_embeds.to(torch.device("cpu")) + if not dist.is_initialized() or dist.get_rank() == 0: + for i in range(len(uuids)): + uid = uuids[i] + if 
not self.cache_client.root.get_item_embed(uid): + start, end = valid_ids[i] + cur_embed_bytes = tensor2bytes(img_embeds[start:end]) + create_shm(get_shm_name_embed(uuids[i]), cur_embed_bytes) + self.cache_client.root.set_item_embed(uuids[i]) + return def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): @@ -70,12 +80,8 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei if img["token_id"] in img_start_token_ids or img["_prefill_"] is False: continue # pull the img_embeds by uid from shm - if self.disable_extra_process_for_multimodal: - img_embed = image_cache_manager.get_embed(img["uuid"]) - img_weight.append(img_embed.reshape(img["token_num"], -1)) - else: - data = read_shm(get_shm_name_embed(img["uuid"])) - img_weight.append(bytes2tensor(data).cuda().reshape(img["token_num"], -1)) + data = read_shm(get_shm_name_embed(img["uuid"])) + img_weight.append(bytes2tensor(data).cuda().reshape(img["token_num"], -1)) img_start_token_ids.append(img["token_id"]) img_token_lens.append(img["token_num"]) img_start_locs.append(img_start_loc) diff --git a/lightllm/server/embed_cache/impl/naive_memory_cache.py b/lightllm/server/embed_cache/impl/naive_memory_cache.py index c03b084c4..be2eeb6af 100644 --- a/lightllm/server/embed_cache/impl/naive_memory_cache.py +++ b/lightllm/server/embed_cache/impl/naive_memory_cache.py @@ -31,6 +31,7 @@ class Record(object): class InMemoryCache(CacheManager): def __init__(self, args) -> None: self.args = args + self.cache_port = self.args.cache_port self._records = dict() self._md5_to_record = dict() self.capacity = max(1, args.cache_capacity) diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 551a70a22..10b68245c 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -9,7 +9,6 @@ from dataclasses import dataclass, field from typing import List, Dict, Tuple, Optional, Union, Any from lightllm.common.req_manager import ReqManager -from lightllm.common.image_cache_manager import image_cache_manager from lightllm.utils.infer_utils import mark_start, mark_end from lightllm.server.core.objs import Req, SamplingParams, FinishStatus, ShmReqManager from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache, TreeNode @@ -132,7 +131,6 @@ def filter(self, finished_request_ids: List[int]): free_req_index = [] free_token_index = [] - image_uuid_list = [] for request_id in finished_request_ids: req: InferReq = self.requests_mapping.pop(request_id) group_req_id = convert_sub_id_to_group_id(req.shm_req.request_id) @@ -147,10 +145,6 @@ def filter(self, finished_request_ids: List[int]): # logger.info(f"infer release req id {req.shm_req.request_id}") req.shm_req.shm_infer_released = True self.shm_req_manager.put_back_req_obj(req.shm_req) - if req.multimodal_params is not None and get_env_start_args().disable_extra_process_for_multimodal: - for img in req.multimodal_params["images"]: - image_uuid_list.append(img["uuid"]) - image_cache_manager.filter(image_uuid_list) free_token_index = custom_cat(free_token_index) self.req_manager.free(free_req_index, free_token_index) From d61da4db6725c779157ccc0e7482eff4c0f4b799 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Fri, 1 Aug 2025 11:10:54 +0000 Subject: [PATCH 08/18] 0801-fix-2 --- .../gemma3/layer_infer/pre_layer_infer.py | 2 +- .../pre_and_post_layer_weight.py | 6 +---- 
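For context on the shared-memory handoff used in the pre-layer hunk above: embeddings produced by the in-process ViT are serialized on rank 0, written to shared memory, and registered with the rpyc embed-cache service so later prefill steps can read them back instead of re-encoding. The sketch below illustrates that round trip using only the helpers already imported in pre_layer_infer.py; it is a simplified illustration of the flow, not the exact code path in the patch.

import torch
from lightllm.server.embed_cache.utils import (
    bytes2tensor,
    create_shm,
    get_shm_name_embed,
    read_shm,
    tensor2bytes,
)


def publish_embed(cache_client, uid, embed_cpu: torch.Tensor) -> None:
    # Producer side (rank 0): skip uuids another request already cached,
    # otherwise write the embedding bytes to shm and mark the uuid as ready.
    if not cache_client.root.get_item_embed(uid):
        create_shm(get_shm_name_embed(uid), tensor2bytes(embed_cpu))
        cache_client.root.set_item_embed(uid)


def fetch_embed(uid, token_num: int) -> torch.Tensor:
    # Consumer side: read the bytes back and reshape to (token_num, hidden_dim).
    data = read_shm(get_shm_name_embed(uid))
    return bytes2tensor(data).cuda().reshape(token_num, -1)


# cache_client is obtained the same way the patch does it, e.g.
# rpyc.connect("localhost", cache_port).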
lightllm/models/qwen2_vl/vision_process.py | 3 --- .../qwen_vl/layer_infer/pre_layer_infer.py | 24 ++++++++++--------- lightllm/server/api_start.py | 3 ++- .../embed_cache/impl/naive_memory_cache.py | 1 - lightllm/utils/envs_utils.py | 10 ++++++++ 7 files changed, 27 insertions(+), 22 deletions(-) diff --git a/lightllm/models/gemma3/layer_infer/pre_layer_infer.py b/lightllm/models/gemma3/layer_infer/pre_layer_infer.py index 46b782879..63fe6bf08 100644 --- a/lightllm/models/gemma3/layer_infer/pre_layer_infer.py +++ b/lightllm/models/gemma3/layer_infer/pre_layer_infer.py @@ -40,7 +40,7 @@ def context_forward(self, input_ids, infer_state, layer_weight): for batch_id, p in enumerate(infer_state.multimodal_params): for img in p["images"]: # skip the same image - if img["token_id"] in img_start_token_ids or img["_prefill_"] is False: + if img["token_id"] in img_start_token_ids or img.get("_prefill_", False): continue # pull the img_embeds by uid from shm data = read_shm(get_shm_name_embed(img["uuid"])) diff --git a/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py index b85a536bd..260f5c4dc 100644 --- a/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py @@ -4,7 +4,7 @@ from lightllm.models.internlm2.layer_weights.pre_and_post_layer_weight import Internlm2PreAndPostLayerWeight from lightllm.models.vit.model import VisionTransformer -from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.envs_utils import get_env_start_args, get_cache_port # add key: language_model.xxx -> xxx @@ -28,7 +28,6 @@ def __init__(self, data_type, network_config, mode): "quant_type": get_env_start_args().vit_quant_type, "quant_cfg": get_env_start_args().vit_quant_cfg, "max_batch_size": get_env_start_args().visual_infer_batch_size, - "cache_port": get_env_start_args().cache_port, } self.visual_model = VisionTransformer( kvargs=kvargs, @@ -51,7 +50,6 @@ def __init__(self, data_type, network_config, mode): "quant_type": get_env_start_args().vit_quant_type, "quant_cfg": get_env_start_args().vit_quant_cfg, "max_batch_size": get_env_start_args().visual_infer_batch_size, - "cache_port": get_env_start_args().cache_port, } self.visual_model = VisionTransformer( kvargs=kvargs, @@ -75,7 +73,6 @@ def __init__(self, data_type, network_config, mode): "quant_type": get_env_start_args().vit_quant_type, "quant_cfg": get_env_start_args().vit_quant_cfg, "max_batch_size": get_env_start_args().visual_infer_batch_size, - "cache_port": get_env_start_args().cache_port, } self.visual_model = VisionTransformer( kvargs=kvargs, @@ -99,7 +96,6 @@ def __init__(self, data_type, network_config, mode): "quant_type": get_env_start_args().vit_quant_type, "quant_cfg": get_env_start_args().vit_quant_cfg, "max_batch_size": get_env_start_args().visual_infer_batch_size, - "cache_port": get_env_start_args().cache_port, } self.visual_model = VisionTransformer( kvargs=kvargs, diff --git a/lightllm/models/qwen2_vl/vision_process.py b/lightllm/models/qwen2_vl/vision_process.py index 2433faa80..0107fd97c 100644 --- a/lightllm/models/qwen2_vl/vision_process.py +++ b/lightllm/models/qwen2_vl/vision_process.py @@ -50,10 +50,7 @@ is_valid_image, make_list_of_images, to_numpy_array, - valid_images, - validate_preprocess_arguments, ) -from transformers.video_utils import VideoInput from transformers.utils import TensorType, is_vision_available, logging logger = 
logging.get_logger(__name__) diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py index cd8d2ae8d..b3179a16c 100644 --- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py @@ -7,12 +7,11 @@ from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer from lightllm.utils.infer_utils import mark_cost_time -from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.envs_utils import get_env_start_args, get_cache_port from lightllm.common.basemodel.triton_kernel.multimodal_emb import multimodal_emb from lightllm.distributed.communication_op import all_reduce from lightllm.server.embed_cache.utils import bytes2tensor, tensor2bytes, read_shm, create_shm, get_shm_name_embed - """ infer_state.multimodal_params: batch list of MultimodalParams-dict like: { @@ -31,8 +30,6 @@ class LlamaMultimodalPreLayerInfer(LlamaPreLayerInfer): def __init__(self, network_config, mode): super().__init__(network_config, mode) - print(f"network_config is {network_config}") - self.cache_client = rpyc.connect("localhost", self.cache_port) self.disable_extra_process_for_multimodal = get_env_start_args().disable_extra_process_for_multimodal return @@ -42,21 +39,26 @@ def _infer_image_embeds(self, infer_state, layer_weight): infer_images = [] for _, p in enumerate(infer_state.multimodal_params): for img in p["images"] + p["audios"]: - if img["_prefill_"] is True: + if img.get("_prefill_", True): infer_images.append(img) if len(infer_images) > 0: + self.cache_client = rpyc.connect("localhost", get_cache_port(), config={"allow_pickle": True}) infer_batch_size = get_env_start_args().visual_infer_batch_size for i in range(0, len(infer_images), infer_batch_size): img_embeds, uuids, valid_ids = layer_weight.visual_model.encode(infer_images[i : i + infer_batch_size]) img_embeds = img_embeds.to(torch.device("cpu")) if not dist.is_initialized() or dist.get_rank() == 0: - for i in range(len(uuids)): - uid = uuids[i] - if not self.cache_client.root.get_item_embed(uid): + ready_flags = self.cache_client.root.get_items_embed(uuids) + ids_to_set = [] + for i, ready in enumerate(ready_flags): + if not ready: + uid = uuids[i] start, end = valid_ids[i] cur_embed_bytes = tensor2bytes(img_embeds[start:end]) - create_shm(get_shm_name_embed(uuids[i]), cur_embed_bytes) - self.cache_client.root.set_item_embed(uuids[i]) + create_shm(get_shm_name_embed(uid), cur_embed_bytes) + ids_to_set.append(uid) + if ids_to_set: + self.cache_client.root.set_items_embed(ids_to_set) return def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): @@ -77,7 +79,7 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei for batch_id, p in enumerate(infer_state.multimodal_params): for img in p["images"] + p["audios"]: # skip the same image - if img["token_id"] in img_start_token_ids or img["_prefill_"] is False: + if img["token_id"] in img_start_token_ids or img.get("_prefill_", False): continue # pull the img_embeds by uid from shm data = read_shm(get_shm_name_embed(img["uuid"])) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index b73e6bade..5f0d5ece2 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -9,7 +9,7 @@ from .metrics.manager import start_metric_manager from .embed_cache.manager import start_cache_manager from lightllm.utils.log_utils import 
init_logger -from lightllm.utils.envs_utils import set_env_start_args, set_unique_server_name, get_unique_server_name +from lightllm.utils.envs_utils import set_env_start_args, set_unique_server_name, get_unique_server_name, set_cache_port from lightllm.utils.envs_utils import get_lightllm_gunicorn_time_out_seconds, get_lightllm_gunicorn_keep_alive from .detokenization.manager import start_detokenization_process from .router.manager import start_router_process @@ -242,6 +242,7 @@ def normal_or_p_d_start(args): args.cache_port = cache_port args.metric_port = metric_port + set_cache_port(args.cache_port) # 申请在 p d 分离模式下,会用的端口 args.pd_node_infer_rpyc_ports = can_use_ports[0:node_world_size] # p d 分离模式下用于标识节点的id diff --git a/lightllm/server/embed_cache/impl/naive_memory_cache.py b/lightllm/server/embed_cache/impl/naive_memory_cache.py index 11ddcaaba..5477be22b 100644 --- a/lightllm/server/embed_cache/impl/naive_memory_cache.py +++ b/lightllm/server/embed_cache/impl/naive_memory_cache.py @@ -29,7 +29,6 @@ class Record(object): class InMemoryCache: def __init__(self, args) -> None: self.args = args - self.cache_port = self.args.cache_port self._records = dict() self._md5_to_record = dict() self.capacity = max(1, args.cache_capacity) diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index b78784d82..14b75628f 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -17,6 +17,16 @@ def set_unique_server_name(args): return +def set_cache_port(cache_port): + os.environ["LIGHTLLM_CACHE_PORT"] = str(cache_port) + cache_port = os.environ.get("LIGHTLLM_CACHE_PORT") + + +def get_cache_port(): + _cache_port = os.environ.get("LIGHTLLM_CACHE_PORT") + return int(_cache_port) + + @lru_cache(maxsize=None) def get_unique_server_name(): service_uni_name = os.getenv("LIGHTLLM_UNIQUE_SERVICE_NAME_ID") From ab02ccd741f96c34237541600110b3f39e670b3e Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Fri, 1 Aug 2025 11:14:18 +0000 Subject: [PATCH 09/18] 0801-fix-3 --- lightllm/models/gemma3/layer_infer/pre_layer_infer.py | 2 +- lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lightllm/models/gemma3/layer_infer/pre_layer_infer.py b/lightllm/models/gemma3/layer_infer/pre_layer_infer.py index 63fe6bf08..46b782879 100644 --- a/lightllm/models/gemma3/layer_infer/pre_layer_infer.py +++ b/lightllm/models/gemma3/layer_infer/pre_layer_infer.py @@ -40,7 +40,7 @@ def context_forward(self, input_ids, infer_state, layer_weight): for batch_id, p in enumerate(infer_state.multimodal_params): for img in p["images"]: # skip the same image - if img["token_id"] in img_start_token_ids or img.get("_prefill_", False): + if img["token_id"] in img_start_token_ids or img["_prefill_"] is False: continue # pull the img_embeds by uid from shm data = read_shm(get_shm_name_embed(img["uuid"])) diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py index b3179a16c..94aa07b78 100644 --- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py @@ -79,7 +79,7 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei for batch_id, p in enumerate(infer_state.multimodal_params): for img in p["images"] + p["audios"]: # skip the same image - if img["token_id"] in img_start_token_ids or img.get("_prefill_", False): + if img["token_id"] in img_start_token_ids or img["_prefill_"] is 
False: continue # pull the img_embeds by uid from shm data = read_shm(get_shm_name_embed(img["uuid"])) From 08e7701b7df5e7a1aac48b9dd526327160ad2c19 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Fri, 1 Aug 2025 11:16:08 +0000 Subject: [PATCH 10/18] 0801-fix-3 --- lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py index 94aa07b78..6c3df12f2 100644 --- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py @@ -39,7 +39,7 @@ def _infer_image_embeds(self, infer_state, layer_weight): infer_images = [] for _, p in enumerate(infer_state.multimodal_params): for img in p["images"] + p["audios"]: - if img.get("_prefill_", True): + if img.get("_prefill_"): infer_images.append(img) if len(infer_images) > 0: self.cache_client = rpyc.connect("localhost", get_cache_port(), config={"allow_pickle": True}) From 0d112fb4c1df00186cc58a7ed12378c04d8a8036 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Fri, 1 Aug 2025 11:18:12 +0000 Subject: [PATCH 11/18] 0801-fix-3 --- lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py index 6c3df12f2..8aab2bd7a 100644 --- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py @@ -39,8 +39,7 @@ def _infer_image_embeds(self, infer_state, layer_weight): infer_images = [] for _, p in enumerate(infer_state.multimodal_params): for img in p["images"] + p["audios"]: - if img.get("_prefill_"): - infer_images.append(img) + infer_images.append(img) if len(infer_images) > 0: self.cache_client = rpyc.connect("localhost", get_cache_port(), config={"allow_pickle": True}) infer_batch_size = get_env_start_args().visual_infer_batch_size From 257a732fa08b613d2825af4f72c14b280462cc3a Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Mon, 4 Aug 2025 07:39:35 +0000 Subject: [PATCH 12/18] 0801-fix-3 --- lightllm/common/image_cache_manager.py | 0 .../models/internvl/layer_weights/pre_and_post_layer_weight.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) delete mode 100644 lightllm/common/image_cache_manager.py diff --git a/lightllm/common/image_cache_manager.py b/lightllm/common/image_cache_manager.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py index 260f5c4dc..dcd869018 100644 --- a/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py @@ -4,7 +4,7 @@ from lightllm.models.internlm2.layer_weights.pre_and_post_layer_weight import Internlm2PreAndPostLayerWeight from lightllm.models.vit.model import VisionTransformer -from lightllm.utils.envs_utils import get_env_start_args, get_cache_port +from lightllm.utils.envs_utils import get_env_start_args # add key: language_model.xxx -> xxx From ca83f87c00f64845b7d7588eebf536b748dd0770 Mon Sep 17 00:00:00 2001 From: baishihao Date: Tue, 5 Aug 2025 11:12:17 +0800 Subject: [PATCH 13/18] update --- lightllm/common/image_cache_manager.py | 53 ++++++++ .../pre_and_post_layer_weight.py | 61 +++------ 
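A note on the `_prefill_` churn in the fixups above: the predicates being swapped are not equivalent. `img["_prefill_"] is False` skips an image only when the flag is explicitly False (and raises KeyError if the key is missing), while `img.get("_prefill_", False)` is truthy for prefill images and therefore inverts the skip condition, which the following fixup reverts. A minimal, self-contained comparison (the dict values below are placeholders following the multimodal_params layout used in these files):

def skip_original(img: dict) -> bool:
    # Pre-fixup form, restored afterwards: skip only images explicitly
    # marked as non-prefill; a missing key raises KeyError.
    return img["_prefill_"] is False


def skip_inverted(img: dict) -> bool:
    # Intermediate form: truthy for prefill images, so it skips exactly the
    # images the original kept, and silently treats a missing key as False.
    return img.get("_prefill_", False)


if __name__ == "__main__":
    prefill_img = {"uuid": 1, "token_id": 7, "token_num": 16, "_prefill_": True}
    decode_img = {"uuid": 2, "token_id": 7, "token_num": 16, "_prefill_": False}
    assert skip_original(prefill_img) is False and skip_original(decode_img) is True
    assert skip_inverted(prefill_img) is True and skip_inverted(decode_img) is False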
.../qwen_vl/layer_infer/pre_layer_infer.py | 117 +++++++++--------- lightllm/models/vit/model.py | 16 +++ .../model_infer/mode_backend/base_backend.py | 32 +++++ .../mode_backend/chunked_prefill/impl.py | 1 + 6 files changed, 177 insertions(+), 103 deletions(-) create mode 100644 lightllm/common/image_cache_manager.py diff --git a/lightllm/common/image_cache_manager.py b/lightllm/common/image_cache_manager.py new file mode 100644 index 000000000..fd337424f --- /dev/null +++ b/lightllm/common/image_cache_manager.py @@ -0,0 +1,53 @@ +from collections import OrderedDict + + +class ImageCacheManager: + def __init__(self): + """ + Initialize the image cache manager with a simple LRU CPU cache. + """ + self._cpu_cache = OrderedDict() + self._max_size = 10000 + + def set_max_size(self, max_size: int): + """ + Set the maximum number of items to keep in the CPU cache. + :param max_size: Maximum number of items to keep in the CPU cache. + """ + if max_size <= 0: + raise ValueError("max_size must be greater than 0") + self._max_size = max_size + + def set_embed(self, uuid, embed): + """ + Store the embedding for the given uuid in the GPU cache. + :param uuid: Unique identifier for the image + :param embed: Embedding vector for the image (on GPU) + """ + if len(self._cpu_cache) >= self._max_size: + self._cpu_cache.popitem(last=False) + self._cpu_cache[uuid] = embed.to("cpu", non_blocking=True) + + def get_embed(self, uuid): + """ + Retrieve the embedding for the given uuid. Prefer GPU cache, + otherwise return CPU cache and move to GPU (simulate .cuda()). + :param uuid: Unique identifier for the image + :return: Embedding vector (on GPU if possible, else move from CPU to GPU) + """ + if uuid in self._cpu_cache: + self._cpu_cache.move_to_end(uuid) + embed = self._cpu_cache[uuid] + return embed.cuda(non_blocking=True) + return None + + def query_embed(self, uuid): + """ + Query if the embedding for the given uuid is in the cache. 
+ :param uuid: Unique identifier for the image + :return: True if the embedding is in the cache, False otherwise + """ + return uuid in self._cpu_cache + + +image_cache_manager = ImageCacheManager() diff --git a/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py index dcd869018..e50f8eb2d 100644 --- a/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py @@ -17,21 +17,24 @@ def rename_weight_keys(weights): weights[k[len(prefix) :]] = weights[k] +def build_visual_model(args, data_type: torch.dtype): + if args.disable_extra_process_for_multimodal: + kvargs = { + "weight_dir": args.model_dir, + "data_type": data_type, + "quant_type": args.vit_quant_type, + "quant_cfg": args.vit_quant_cfg, + "max_batch_size": args.visual_infer_batch_size, + } + return VisionTransformer(kvargs=kvargs) + return None + + class InternVLPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) # if we don't assign an extra process for visual model, we need initialize the image cache manager here - if get_env_start_args().disable_extra_process_for_multimodal: - kvargs = { - "weight_dir": get_env_start_args().model_dir, - "data_type": self.data_type_, - "quant_type": get_env_start_args().vit_quant_type, - "quant_cfg": get_env_start_args().vit_quant_cfg, - "max_batch_size": get_env_start_args().visual_infer_batch_size, - } - self.visual_model = VisionTransformer( - kvargs=kvargs, - ) + self.visual_model = build_visual_model(get_env_start_args(), data_type) return def load_hf_weights(self, weights): @@ -43,17 +46,7 @@ class InternVLPhi3PreAndPostLayerWeight(LlamaPreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) # if we don't assign an extra process for visual model, we need initialize the image cache manager here - if get_env_start_args().disable_extra_process_for_multimodal: - kvargs = { - "weight_dir": get_env_start_args().model_dir, - "data_type": self.data_type_, - "quant_type": get_env_start_args().vit_quant_type, - "quant_cfg": get_env_start_args().vit_quant_cfg, - "max_batch_size": get_env_start_args().visual_infer_batch_size, - } - self.visual_model = VisionTransformer( - kvargs=kvargs, - ) + self.visual_model = build_visual_model(get_env_start_args(), data_type) return def load_hf_weights(self, weights): @@ -66,17 +59,7 @@ class InternVLInternlm2PreAndPostLayerWeight(Internlm2PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) # if we don't assign an extra process for visual model, we need initialize the image cache manager here - if get_env_start_args().disable_extra_process_for_multimodal: - kvargs = { - "weight_dir": get_env_start_args().model_dir, - "data_type": self.data_type_, - "quant_type": get_env_start_args().vit_quant_type, - "quant_cfg": get_env_start_args().vit_quant_cfg, - "max_batch_size": get_env_start_args().visual_infer_batch_size, - } - self.visual_model = VisionTransformer( - kvargs=kvargs, - ) + self.visual_model = build_visual_model(get_env_start_args(), data_type) return def load_hf_weights(self, weights): @@ -89,17 +72,7 @@ class InternVLLlamaPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, 
network_config, mode) # if we don't assign an extra process for visual model, we need initialize the image cache manager here - if get_env_start_args().disable_extra_process_for_multimodal: - kvargs = { - "weight_dir": get_env_start_args().model_dir, - "data_type": self.data_type_, - "quant_type": get_env_start_args().vit_quant_type, - "quant_cfg": get_env_start_args().vit_quant_cfg, - "max_batch_size": get_env_start_args().visual_infer_batch_size, - } - self.visual_model = VisionTransformer( - kvargs=kvargs, - ) + self.visual_model = build_visual_model(get_env_start_args(), data_type) return def load_hf_weights(self, weights): diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py index 8aab2bd7a..7f9297641 100644 --- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py @@ -9,6 +9,7 @@ from lightllm.utils.infer_utils import mark_cost_time from lightllm.utils.envs_utils import get_env_start_args, get_cache_port from lightllm.common.basemodel.triton_kernel.multimodal_emb import multimodal_emb +from lightllm.common.image_cache_manager import image_cache_manager from lightllm.distributed.communication_op import all_reduce from lightllm.server.embed_cache.utils import bytes2tensor, tensor2bytes, read_shm, create_shm, get_shm_name_embed @@ -30,77 +31,75 @@ class LlamaMultimodalPreLayerInfer(LlamaPreLayerInfer): def __init__(self, network_config, mode): super().__init__(network_config, mode) - self.disable_extra_process_for_multimodal = get_env_start_args().disable_extra_process_for_multimodal return def _infer_image_embeds(self, infer_state, layer_weight): - if not self.disable_extra_process_for_multimodal: + if layer_weight.visual_model is None: return - infer_images = [] - for _, p in enumerate(infer_state.multimodal_params): - for img in p["images"] + p["audios"]: - infer_images.append(img) - if len(infer_images) > 0: - self.cache_client = rpyc.connect("localhost", get_cache_port(), config={"allow_pickle": True}) - infer_batch_size = get_env_start_args().visual_infer_batch_size - for i in range(0, len(infer_images), infer_batch_size): - img_embeds, uuids, valid_ids = layer_weight.visual_model.encode(infer_images[i : i + infer_batch_size]) - img_embeds = img_embeds.to(torch.device("cpu")) - if not dist.is_initialized() or dist.get_rank() == 0: - ready_flags = self.cache_client.root.get_items_embed(uuids) - ids_to_set = [] - for i, ready in enumerate(ready_flags): - if not ready: - uid = uuids[i] - start, end = valid_ids[i] - cur_embed_bytes = tensor2bytes(img_embeds[start:end]) - create_shm(get_shm_name_embed(uid), cur_embed_bytes) - ids_to_set.append(uid) - if ids_to_set: - self.cache_client.root.set_items_embed(ids_to_set) - return + image_weight = [] + for batch_id, p in enumerate(infer_state.multimodal_params): + for uuid, image_data, token_num in p["image_data"]: + if image_cache_manager.query_embed(uuid): + image_embed = image_cache_manager.get_embed(uuid) + else: + image_data = image_data.to("cuda", non_blocking=True) + image_embed = layer_weight.visual_model.forward(image_data).view(token_num, -1) + image_cache_manager.set_embed(uuid, image_embed) + image_weight.append(image_embed) + if len(image_weight) > 0: + image_weight = torch.cat(image_weight, dim=0) + image_weight = image_weight / self.tp_world_size_ + assert image_weight.shape[1] == layer_weight.wte_weight_.shape[1], ( + f"Dimension mismatch: text weight dimension is 
{layer_weight.wte_weight_.shape[1]}, " + f"but image weight dimension is {image_weight.shape[1]}" + ) + else: + hidden_size = layer_weight.wte_weight_.shape[1] + image_weight = torch.empty((0, hidden_size), device="cpu", dtype=layer_weight.wte_weight_.dtype).to( + "cuda", non_blocking=True + ) + return image_weight def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): - - img_weight = [] - img_start_token_ids = [] - img_token_lens = [] - img_start_loc = 0 - img_start_locs = [] - - device = layer_weight.wte_weight_.device dtype = layer_weight.wte_weight_.dtype hidden_size = layer_weight.wte_weight_.shape[1] + img_weight = self._infer_image_embeds(infer_state, layer_weight) - infer_state.mark_multimodal_objs_for_prefill(input_ids=input_ids) + # infer_state.mark_multimodal_objs_for_prefill(input_ids=input_ids) - self._infer_image_embeds(infer_state, layer_weight) - for batch_id, p in enumerate(infer_state.multimodal_params): - for img in p["images"] + p["audios"]: - # skip the same image - if img["token_id"] in img_start_token_ids or img["_prefill_"] is False: - continue - # pull the img_embeds by uid from shm - data = read_shm(get_shm_name_embed(img["uuid"])) - img_weight.append(bytes2tensor(data).cuda().reshape(img["token_num"], -1)) - img_start_token_ids.append(img["token_id"]) - img_token_lens.append(img["token_num"]) - img_start_locs.append(img_start_loc) - img_start_loc += img["token_num"] - out = torch.zeros((len(input_ids), hidden_size), dtype=dtype, device=device) - if len(img_weight) > 0: - img_weight = torch.cat(img_weight, dim=0).to(device=device, dtype=dtype) - else: - img_weight = torch.empty((0, hidden_size), device=device, dtype=dtype) - assert img_weight.shape[1] == hidden_size, ( - f"Dimension mismatch: text weight dimension is {hidden_size}, " - f"but image weight dimension is {img_weight.shape[1]}" - ) + # self._infer_image_embeds(infer_state, layer_weight) + # for batch_id, p in enumerate(infer_state.multimodal_params): + # for img in p["images"] + p["audios"]: + # # skip the same image + # if img["token_id"] in img_start_token_ids or img["_prefill_"] is False: + # continue + # # pull the img_embeds by uid from shm + # data = read_shm(get_shm_name_embed(img["uuid"])) + # img_weight.append(bytes2tensor(data).cuda().reshape(img["token_num"], -1)) + # img_start_token_ids.append(img["token_id"]) + # img_token_lens.append(img["token_num"]) + # img_start_locs.append(img_start_loc) + # img_start_loc += img["token_num"] + # out = torch.zeros((len(input_ids), hidden_size), dtype=dtype, device=device) + # if len(img_weight) > 0: + # img_weight = torch.cat(img_weight, dim=0).to(device=device, dtype=dtype) + # else: + # img_weight = torch.empty((0, hidden_size), device=device, dtype=dtype) + # assert img_weight.shape[1] == hidden_size, ( + # f"Dimension mismatch: text weight dimension is {hidden_size}, " + # f"but image weight dimension is {img_weight.shape[1]}" + # ) # each tp will fill the img embeds, should divide by world_size + out = torch.zeros((len(input_ids), hidden_size), dtype=dtype, device="cpu").to("cuda", non_blocking=True) img_weight = img_weight / self.tp_world_size_ - img_start_token_ids = torch.Tensor(img_start_token_ids).to(device=device, dtype=torch.long) - img_token_lens = torch.Tensor(img_token_lens).to(device=device, dtype=torch.long) - img_start_locs = torch.Tensor(img_start_locs).to(device=device, dtype=torch.long) + if hasattr(infer_state, "image_start_token_ids"): + img_start_token_ids = 
infer_state.image_start_token_ids.to("cuda", non_blocking=True) + img_token_lens = infer_state.image_token_lens.to("cuda", non_blocking=True) + img_start_locs = infer_state.image_start_locs.to("cuda", non_blocking=True) + else: + img_start_token_ids = torch.empty((0,), device="cpu", dtype=torch.long).to("cuda", non_blocking=True) + img_token_lens = torch.empty((0,), device="cpu", dtype=torch.long).to("cuda", non_blocking=True) + img_start_locs = torch.empty((0,), device="cpu", dtype=torch.long).to("cuda", non_blocking=True) multimodal_emb( out, diff --git a/lightllm/models/vit/model.py b/lightllm/models/vit/model.py index a8b475889..44705fd8f 100644 --- a/lightllm/models/vit/model.py +++ b/lightllm/models/vit/model.py @@ -59,6 +59,22 @@ def __init__(self, kvargs): self._check_max_len_infer() return + def load_image(self, img: List[ImageItem]): + from lightllm.server.multimodal_params import ImageItem + + img_tensor = None + if isinstance(img, ImageItem): + image_data = read_shm(get_shm_name_data(img.uuid)) + image_data = Image.open(BytesIO(image_data)) + img_tensor = self.load_image_func(image_data, max_num=img.extra_params["image_patch_max_num"]) + elif isinstance(img, dict): + image_data = read_shm(get_shm_name_data(img["uuid"])) + image_data = Image.open(BytesIO(image_data)) + img_tensor = self.load_image_func(image_data, max_num=img["extra_params"]["image_patch_max_num"]) + else: + raise Exception("Unsupport input types: {} for {}".format(type(img), img)) + return img_tensor.to(dtype=self.data_type) + @final @torch.no_grad() def _check_max_len_infer(self): diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index fd75afdbf..f4c7adbf0 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -16,6 +16,7 @@ from lightllm.common.basemodel.basemodel import TpPartBaseModel from lightllm.common.basemodel.batch_objs import ModelOutput, ModelInput from lightllm.common.basemodel.triton_kernel.mtp_verify import mtp_verify +from lightllm.common.image_cache_manager import image_cache_manager from lightllm.utils.dist_utils import init_distributed_env from lightllm.utils.envs_utils import get_unique_server_name from lightllm.server.core.objs import ShmReqManager, StartArgs @@ -459,6 +460,37 @@ def _get_classed_reqs( return prefill_reqs, decode_reqs + def _preprocess_image(self, batch: ModelInput): + # 如果不是多模态模型,或者单进程推理,直接跳过 + args = get_env_start_args() + if not args.enable_multimodal: + return + # assert self.model.visual_model is not None, "visual_model is not initialized" + image_start_locs = [] + image_token_lens = [] + image_start_token_ids = [] + image_start_loc = 0 + for i, p in enumerate(batch.multimodal_params): + image_datas = [] + for img in p["images"]: + # 重复图片 + if img["token_id"] in image_start_token_ids: + continue + image_start_locs.append(image_start_loc) + image_token_lens.append(img["token_num"]) + image_start_token_ids.append(img["token_id"]) + image_start_loc += img["token_num"] + if not args.disable_extra_process_for_multimodal: + continue + # 预拉取已经存在的image embed + image_data = self.model.pre_post_weight.visual_model.load_image(img) + image_datas.append([img["uuid"], image_data, img["token_num"]]) + p["image_data"] = image_datas + batch.image_start_locs = torch.tensor(image_start_locs, device="cpu", dtype=torch.long) + batch.image_token_lens = torch.tensor(image_token_lens, device="cpu", 
dtype=torch.long) + batch.image_start_token_ids = torch.tensor(image_start_token_ids, device="cpu", dtype=torch.long) + return batch + def _pre_handle_finished_reqs(self, finished_reqs: List[InferReq]): """ 给 PD 分离模式下,prefill node 使用的继承钩子函数,用于发起 kv 传输任务。 diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 39d345ff5..4972194b2 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -96,6 +96,7 @@ def prefill_normal( model_input, run_reqs = prepare_prefill_inputs( prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill, is_multimodal=self.is_multimodal ) + self._preprocess_image(model_input) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) logits = model_output.logits From 96a0afb9ff637281c79d2478e4a407d0e1357dd6 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Wed, 6 Aug 2025 08:24:34 +0000 Subject: [PATCH 14/18] 0806-fix-no-image-cache --- lightllm/common/basemodel/basemodel.py | 3 ++ lightllm/common/basemodel/batch_objs.py | 3 ++ lightllm/common/image_cache_manager.py | 53 ------------------- .../pre_and_post_layer_weight.py | 1 - .../qwen_vl/layer_infer/pre_layer_infer.py | 43 +++------------ .../model_infer/mode_backend/base_backend.py | 7 +-- 6 files changed, 15 insertions(+), 95 deletions(-) delete mode 100644 lightllm/common/image_cache_manager.py diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 4516e18c3..5a4858331 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -265,6 +265,9 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0) infer_state.b_ready_cache_len = torch.zeros_like(input=infer_state.b_seq_len) infer_state.multimodal_params = model_input.multimodal_params + infer_state.image_start_locs = model_input.image_start_locs + infer_state.image_token_lens = model_input.image_token_lens + infer_state.image_start_token_ids = model_input.image_start_token_ids infer_state.mem_manager = self.mem_manager infer_state.req_manager = self.req_manager diff --git a/lightllm/common/basemodel/batch_objs.py b/lightllm/common/basemodel/batch_objs.py index 9b317b423..3bf372372 100644 --- a/lightllm/common/basemodel/batch_objs.py +++ b/lightllm/common/basemodel/batch_objs.py @@ -17,6 +17,9 @@ class ModelInput: mem_indexes: torch.Tensor = None is_prefill: bool = False b_ready_cache_len: torch.Tensor = None + image_start_locs: torch.Tensor = None + image_token_lens: torch.Tensor = None + image_start_token_ids: torch.Tensor = None multimodal_params: list = field(default_factory=list) # cpu 变量 diff --git a/lightllm/common/image_cache_manager.py b/lightllm/common/image_cache_manager.py deleted file mode 100644 index fd337424f..000000000 --- a/lightllm/common/image_cache_manager.py +++ /dev/null @@ -1,53 +0,0 @@ -from collections import OrderedDict - - -class ImageCacheManager: - def __init__(self): - """ - Initialize the image cache manager with a simple LRU CPU cache. - """ - self._cpu_cache = OrderedDict() - self._max_size = 10000 - - def set_max_size(self, max_size: int): - """ - Set the maximum number of items to keep in the CPU cache. - :param max_size: Maximum number of items to keep in the CPU cache. 
- """ - if max_size <= 0: - raise ValueError("max_size must be greater than 0") - self._max_size = max_size - - def set_embed(self, uuid, embed): - """ - Store the embedding for the given uuid in the GPU cache. - :param uuid: Unique identifier for the image - :param embed: Embedding vector for the image (on GPU) - """ - if len(self._cpu_cache) >= self._max_size: - self._cpu_cache.popitem(last=False) - self._cpu_cache[uuid] = embed.to("cpu", non_blocking=True) - - def get_embed(self, uuid): - """ - Retrieve the embedding for the given uuid. Prefer GPU cache, - otherwise return CPU cache and move to GPU (simulate .cuda()). - :param uuid: Unique identifier for the image - :return: Embedding vector (on GPU if possible, else move from CPU to GPU) - """ - if uuid in self._cpu_cache: - self._cpu_cache.move_to_end(uuid) - embed = self._cpu_cache[uuid] - return embed.cuda(non_blocking=True) - return None - - def query_embed(self, uuid): - """ - Query if the embedding for the given uuid is in the cache. - :param uuid: Unique identifier for the image - :return: True if the embedding is in the cache, False otherwise - """ - return uuid in self._cpu_cache - - -image_cache_manager = ImageCacheManager() diff --git a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py index 711406e3f..ef6f71382 100644 --- a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py @@ -23,7 +23,6 @@ def load_hf_weights(self, weights): self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :]) if "model.norm.weight" in weights: self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) - return def verify_load(self): diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py index 7f9297641..5c5f57e47 100644 --- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py @@ -9,7 +9,6 @@ from lightllm.utils.infer_utils import mark_cost_time from lightllm.utils.envs_utils import get_env_start_args, get_cache_port from lightllm.common.basemodel.triton_kernel.multimodal_emb import multimodal_emb -from lightllm.common.image_cache_manager import image_cache_manager from lightllm.distributed.communication_op import all_reduce from lightllm.server.embed_cache.utils import bytes2tensor, tensor2bytes, read_shm, create_shm, get_shm_name_embed @@ -38,14 +37,11 @@ def _infer_image_embeds(self, infer_state, layer_weight): return image_weight = [] for batch_id, p in enumerate(infer_state.multimodal_params): - for uuid, image_data, token_num in p["image_data"]: - if image_cache_manager.query_embed(uuid): - image_embed = image_cache_manager.get_embed(uuid) - else: - image_data = image_data.to("cuda", non_blocking=True) - image_embed = layer_weight.visual_model.forward(image_data).view(token_num, -1) - image_cache_manager.set_embed(uuid, image_embed) - image_weight.append(image_embed) + for img in p["images"] + p["audios"]: + if img.get("_prefill_", True): + image_data = img["image_data"].to("cuda", non_blocking=True) + image_embed = layer_weight.visual_model.forward(image_data).view(img["token_num"], -1) + image_weight.append(image_embed) if len(image_weight) > 0: image_weight = torch.cat(image_weight, dim=0) image_weight = image_weight / self.tp_world_size_ @@ -63,36 +59,12 @@ def _infer_image_embeds(self, infer_state, 
layer_weight): def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): dtype = layer_weight.wte_weight_.dtype hidden_size = layer_weight.wte_weight_.shape[1] + infer_state.mark_multimodal_objs_for_prefill(input_ids=input_ids) img_weight = self._infer_image_embeds(infer_state, layer_weight) - # infer_state.mark_multimodal_objs_for_prefill(input_ids=input_ids) - - # self._infer_image_embeds(infer_state, layer_weight) - # for batch_id, p in enumerate(infer_state.multimodal_params): - # for img in p["images"] + p["audios"]: - # # skip the same image - # if img["token_id"] in img_start_token_ids or img["_prefill_"] is False: - # continue - # # pull the img_embeds by uid from shm - # data = read_shm(get_shm_name_embed(img["uuid"])) - # img_weight.append(bytes2tensor(data).cuda().reshape(img["token_num"], -1)) - # img_start_token_ids.append(img["token_id"]) - # img_token_lens.append(img["token_num"]) - # img_start_locs.append(img_start_loc) - # img_start_loc += img["token_num"] - # out = torch.zeros((len(input_ids), hidden_size), dtype=dtype, device=device) - # if len(img_weight) > 0: - # img_weight = torch.cat(img_weight, dim=0).to(device=device, dtype=dtype) - # else: - # img_weight = torch.empty((0, hidden_size), device=device, dtype=dtype) - # assert img_weight.shape[1] == hidden_size, ( - # f"Dimension mismatch: text weight dimension is {hidden_size}, " - # f"but image weight dimension is {img_weight.shape[1]}" - # ) - # each tp will fill the img embeds, should divide by world_size out = torch.zeros((len(input_ids), hidden_size), dtype=dtype, device="cpu").to("cuda", non_blocking=True) img_weight = img_weight / self.tp_world_size_ - if hasattr(infer_state, "image_start_token_ids"): + if infer_state.image_start_token_ids is not None: img_start_token_ids = infer_state.image_start_token_ids.to("cuda", non_blocking=True) img_token_lens = infer_state.image_token_lens.to("cuda", non_blocking=True) img_start_locs = infer_state.image_start_locs.to("cuda", non_blocking=True) @@ -100,7 +72,6 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei img_start_token_ids = torch.empty((0,), device="cpu", dtype=torch.long).to("cuda", non_blocking=True) img_token_lens = torch.empty((0,), device="cpu", dtype=torch.long).to("cuda", non_blocking=True) img_start_locs = torch.empty((0,), device="cpu", dtype=torch.long).to("cuda", non_blocking=True) - multimodal_emb( out, input_ids, diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index f4c7adbf0..b3bc0e6f2 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -16,7 +16,6 @@ from lightllm.common.basemodel.basemodel import TpPartBaseModel from lightllm.common.basemodel.batch_objs import ModelOutput, ModelInput from lightllm.common.basemodel.triton_kernel.mtp_verify import mtp_verify -from lightllm.common.image_cache_manager import image_cache_manager from lightllm.utils.dist_utils import init_distributed_env from lightllm.utils.envs_utils import get_unique_server_name from lightllm.server.core.objs import ShmReqManager, StartArgs @@ -471,7 +470,6 @@ def _preprocess_image(self, batch: ModelInput): image_start_token_ids = [] image_start_loc = 0 for i, p in enumerate(batch.multimodal_params): - image_datas = [] for img in p["images"]: # 重复图片 if img["token_id"] in image_start_token_ids: @@ -482,10 
+480,9 @@ def _preprocess_image(self, batch: ModelInput): image_start_loc += img["token_num"] if not args.disable_extra_process_for_multimodal: continue - # 预拉取已经存在的image embed + # 预拉取已经存在的image data image_data = self.model.pre_post_weight.visual_model.load_image(img) - image_datas.append([img["uuid"], image_data, img["token_num"]]) - p["image_data"] = image_datas + img["image_data"] = image_data batch.image_start_locs = torch.tensor(image_start_locs, device="cpu", dtype=torch.long) batch.image_token_lens = torch.tensor(image_token_lens, device="cpu", dtype=torch.long) batch.image_start_token_ids = torch.tensor(image_start_token_ids, device="cpu", dtype=torch.long) From eba4b0097866a29933590e0a21ed815fb4845733 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Thu, 7 Aug 2025 13:16:12 +0000 Subject: [PATCH 15/18] 0807-fix-qwen2vl --- .../pre_and_post_layer_weight.py | 27 ++ lightllm/models/qwen2_5_vl/qwen2_5_visual.py | 37 ++- .../pre_and_post_layer_weight.py | 27 ++ lightllm/models/qwen2_vl/model.py | 36 ++- lightllm/models/qwen2_vl/qwen2_visual.py | 256 ++++++++--------- .../qwen2_vl/triton_kernel/rotary_pos_emb.py | 156 +++++++++++ lightllm/models/qwen2_vl/vision_process.py | 257 +++--------------- .../qwen_vl/layer_infer/pre_layer_infer.py | 8 +- lightllm/models/tarsier2/tarsier2_visual.py | 8 +- lightllm/models/vit/model.py | 21 +- .../model_infer/mode_backend/base_backend.py | 3 +- .../visualserver/model_infer/model_rpc.py | 24 +- 12 files changed, 456 insertions(+), 404 deletions(-) create mode 100644 lightllm/models/qwen2_5_vl/layer_weights/pre_and_post_layer_weight.py create mode 100644 lightllm/models/qwen2_vl/layer_weights/pre_and_post_layer_weight.py create mode 100644 lightllm/models/qwen2_vl/triton_kernel/rotary_pos_emb.py diff --git a/lightllm/models/qwen2_5_vl/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen2_5_vl/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 000000000..c3d3e2225 --- /dev/null +++ b/lightllm/models/qwen2_5_vl/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,27 @@ +import torch +import numpy as np +from lightllm.utils.envs_utils import get_env_start_args +from transformers.configuration_utils import PretrainedConfig +from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight +from lightllm.models.qwen2_5_vl.qwen2_5_visual import Qwen2_5VLTransformer + + +def build_visual_model(args, data_type: torch.dtype): + if args.disable_extra_process_for_multimodal: + kvargs = { + "weight_dir": args.model_dir, + "data_type": args.data_type, + "quant_type": args.vit_quant_type, + "quant_cfg": args.vit_quant_cfg, + "max_batch_size": args.visual_infer_batch_size, + } + model_cfg, _ = PretrainedConfig.get_config_dict(kvargs["weight_dir"]) + return Qwen2_5VLTransformer(kvargs=kvargs, **model_cfg["vision_config"]).eval().to(dtype=data_type) + return None + + +class Qwen2_5VLPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight): + def __init__(self, data_type, network_config, mode): + super().__init__(data_type, network_config, mode) + self.visual_model = build_visual_model(get_env_start_args(), data_type) + return diff --git a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py index a9e162e3f..693a73a38 100644 --- a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py +++ b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py @@ -16,7 +16,7 @@ from torch.nn import LayerNorm from transformers.activations import ACT2FN import math -from 
lightllm.models.qwen2_vl.vision_process import get_image, Qwen2VLImageProcessor +from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor from transformers import AutoProcessor from safetensors import safe_open from transformers.utils import TensorType @@ -212,9 +212,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class Qwen2_5_VisionTransformerPretrainedModel(nn.Module): +class Qwen2_5VLTransformer(nn.Module): def __init__( self, + weight_dir, depth=32, hidden_size=3584, hidden_act="silu", @@ -278,6 +279,11 @@ def __init__( self.gradient_checkpointing = False + processor_config_path = os.path.join(weight_dir, "preprocessor_config.json") + with open(processor_config_path, "r") as f: + processor_config_dict = json.load(f) + self.processor = Qwen2VLImageProcessor(**processor_config_dict) + self.device = self.get_device() self.dtype = self.get_dtype() @@ -416,12 +422,27 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch. return hidden_states - def load_model(self, weight_dir): + def load_image(self, img: List[ImageItem]): + pixel_values = None + if isinstance(img, ImageItem): + image_data = read_shm(get_shm_name_data(img.uuid)) + image_data = Image.open(BytesIO(image_data)) + image_data = resize_image(image_data) + image_inputs = self.processor.preprocess(images=image_data, return_tensors="pt") + pixel_values = image_inputs["pixel_values"].to(dtype=torch.bfloat16) + image_grid_thw = image_inputs["image_grid_thw"] + elif isinstance(img, dict): + image_data = read_shm(get_shm_name_data(img["uuid"])) + image_data = Image.open(BytesIO(image_data)) + image_data = resize_image(image_data) + image_inputs = self.processor.preprocess(images=image_data, return_tensors="pt") + pixel_values = image_inputs["pixel_values"].to(dtype=torch.bfloat16) + image_grid_thw = image_inputs["image_grid_thw"] + else: + raise Exception("Unsupport input types: {} for {}".format(type(img), img)) + return pixel_values.to(dtype=self.get_dtype()), image_grid_thw - processor_config_path = os.path.join(weight_dir, "preprocessor_config.json") - with open(processor_config_path, "r") as f: - processor_config_dict = json.load(f) - self.processor = Qwen2VLImageProcessor(**processor_config_dict) + def load_model(self, weight_dir): bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")] if bin_weight_files: @@ -455,7 +476,7 @@ def encode(self, images: List[ImageItem]): uuids.append(img.uuid) image_data = read_shm(get_shm_name_data(img.uuid)) image_data = Image.open(BytesIO(image_data)) - image_data = get_image(image_data) + image_data = resize_image(image_data) image_inputs = self.processor.preprocess(images=image_data, return_tensors="pt") pixel_values = image_inputs["pixel_values"].to(dtype=torch.bfloat16) image_grid_thw = image_inputs["image_grid_thw"] diff --git a/lightllm/models/qwen2_vl/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen2_vl/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 000000000..ec46102a6 --- /dev/null +++ b/lightllm/models/qwen2_vl/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,27 @@ +import torch +import numpy as np +from lightllm.utils.envs_utils import get_env_start_args +from transformers.configuration_utils import PretrainedConfig +from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight +from lightllm.models.qwen2_vl.qwen2_visual import Qwen2VLTransformer + + +def build_visual_model(args, data_type: 
torch.dtype): + if args.disable_extra_process_for_multimodal: + kvargs = { + "weight_dir": args.model_dir, + "data_type": args.data_type, + "quant_type": args.vit_quant_type, + "quant_cfg": args.vit_quant_cfg, + "max_batch_size": args.visual_infer_batch_size, + } + model_cfg, _ = PretrainedConfig.get_config_dict(kvargs["weight_dir"]) + return Qwen2VLTransformer(kvargs=kvargs, **model_cfg["vision_config"]).eval().to(dtype=data_type) + return None + + +class Qwen2VLPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight): + def __init__(self, data_type, network_config, mode): + super().__init__(data_type, network_config, mode) + self.visual_model = build_visual_model(get_env_start_args(), data_type) + return diff --git a/lightllm/models/qwen2_vl/model.py b/lightllm/models/qwen2_vl/model.py index fb8ee294b..afa6712a1 100644 --- a/lightllm/models/qwen2_vl/model.py +++ b/lightllm/models/qwen2_vl/model.py @@ -16,6 +16,8 @@ from lightllm.common.build_utils import repair_config from lightllm.models.registry import ModelRegistry from lightllm.models.qwen2_vl.infer_struct import Qwen2VLInferStateInfo +from lightllm.models.qwen2_vl.layer_weights.pre_and_post_layer_weight import Qwen2VLPreAndPostLayerWeight +from lightllm.models.qwen2_5_vl.layer_weights.pre_and_post_layer_weight import Qwen2_5VLPreAndPostLayerWeight from lightllm.models.qwen2_vl.layer_infer.transformer_layer_infer import Qwen2VLTransformerLayerInfer import torch @@ -93,12 +95,44 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs): return input_ids -@ModelRegistry(["qwen2_vl", "qwen2_5_vl"], is_multimodal=True) +@ModelRegistry(["qwen2_vl"], is_multimodal=True) class Qwen2VLTpPartModel(Qwen2TpPartModel): pre_layer_infer_class = LlamaMultimodalPreLayerInfer transformer_layer_infer_class = Qwen2VLTransformerLayerInfer + pre_and_post_weight_class = Qwen2VLPreAndPostLayerWeight + + infer_state_class = Qwen2VLInferStateInfo + + def __init__(self, kvargs): + super().__init__(kvargs) + return + + def _init_inferstate_cls(self): + if get_env_start_args().enable_fa3: + self.infer_state_class = Qwen2VLFlashAttentionStateInfo + + def _init_config(self): + with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: + self.config = json.load(json_file) + # rename keys + repair_config(self.config, same_names=["num_attention_heads", "n_head"]) + repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"]) + repair_config(self.config, same_names=["num_hidden_layers", "n_layer"]) + if self.finetune_config: + self.config["vocab_size"] = self.finetune_config.vocab_size + return + + +@ModelRegistry(["qwen2_5_vl"], is_multimodal=True) +class Qwen2_5VLTpPartModel(Qwen2TpPartModel): + + pre_layer_infer_class = LlamaMultimodalPreLayerInfer + transformer_layer_infer_class = Qwen2VLTransformerLayerInfer + + pre_and_post_weight_class = Qwen2_5VLPreAndPostLayerWeight + infer_state_class = Qwen2VLInferStateInfo def __init__(self, kvargs): diff --git a/lightllm/models/qwen2_vl/qwen2_visual.py b/lightllm/models/qwen2_vl/qwen2_visual.py index 591915a6e..e9274af82 100644 --- a/lightllm/models/qwen2_vl/qwen2_visual.py +++ b/lightllm/models/qwen2_vl/qwen2_visual.py @@ -24,6 +24,7 @@ import torch import torch.nn.functional as F from PIL import Image +from einops import rearrange from typing import List, Union from torchvision import transforms as T from torchvision.transforms.functional import InterpolationMode @@ -36,50 +37,19 @@ import torch.nn as nn from torch.nn import LayerNorm from transformers.activations import 
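Both of the new pre_and_post_layer_weight modules follow the same pattern: when --disable_extra_process_for_multimodal is set, the LLM-side weight class builds the vision tower in the same process, pulling its hyper-parameters from the "vision_config" section of the HF config. A minimal sketch of that config lookup, assuming only a local HF-style checkpoint directory (the path below is a placeholder):

from transformers.configuration_utils import PretrainedConfig

weight_dir = "/models/Qwen2-VL-7B-Instruct"  # placeholder checkpoint path

# get_config_dict returns (config_dict, unused_kwargs); the vision tower
# settings live under the "vision_config" key of config.json.
model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir)
vision_cfg = model_cfg["vision_config"]
print(sorted(vision_cfg))  # e.g. depth, num_heads, patch_size, spatial_merge_size, ...

The same dict is what gets splatted into Qwen2VLTransformer / Qwen2_5VLTransformer as keyword arguments.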
ACT2FN -import math -from .vision_process import get_image -from transformers import AutoProcessor +import time from safetensors import safe_open from transformers.utils import TensorType +from lightllm.common.build_utils import repair_config from lightllm.server.multimodal_params import MultimodalParams, ImageItem -from lightllm.models.qwen2_vl.vision_process import Qwen2VLImageProcessor +from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor +from lightllm.models.qwen2_vl.triton_kernel.rotary_pos_emb import apply_rotary_pos_emb_triton from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager - - -logger = logging.get_logger(__name__) +from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb # adapted from # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py -class Qwen2VLVisionConfig(PretrainedConfig): - model_type = "qwen2_vl" - - def __init__( - self, - depth=32, - embed_dim=1280, - hidden_size=3584, - hidden_act="quick_gelu", - mlp_ratio=4, - num_heads=16, - in_channels=3, - patch_size=14, - spatial_merge_size=2, - temporal_patch_size=2, - **kwargs, - ): - super().__init__(**kwargs) - - self.depth = depth - self.embed_dim = embed_dim - self.hidden_size = hidden_size - self.hidden_act = hidden_act - self.mlp_ratio = mlp_ratio - self.num_heads = num_heads - self.in_channels = in_channels - self.patch_size = patch_size - self.spatial_merge_size = spatial_merge_size - self.temporal_patch_size = temporal_patch_size class PatchEmbed(nn.Module): @@ -100,11 +70,10 @@ def __init__( self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - target_dtype = self.proj.weight.dtype hidden_states = hidden_states.view( -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size - ).cuda() - hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) + ) + hidden_states = self.proj(hidden_states).view(-1, self.embed_dim) return hidden_states @@ -124,38 +93,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -# Copied from transformers.models.llama.modeling_llama.rotate_half -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: - orig_dtype = tensor.dtype - tensor = tensor.float() - cos = freqs.cos() - sin = freqs.sin() - cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() - sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() - output = (tensor * cos) + (rotate_half(tensor) * sin) - output = output.to(orig_dtype) - return output - - -class VisionRotaryEmbedding(nn.Module): - def __init__(self, dim: int, theta: float = 10000.0) -> None: - super().__init__() - inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = torch.outer(seq, self.inv_freq) - return freqs - - class VisionMlp(nn.Module): def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None: 
super().__init__() @@ -167,8 +104,17 @@ def forward(self, x) -> torch.Tensor: return self.fc2(self.act(self.fc1(x))) -# adapted from -# https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +class VisionRotaryEmbedding(nn.Module): + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + self.inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + + def forward(self, seqlen: int) -> torch.Tensor: + self.seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + self.freqs = torch.outer(self.seq, self.inv_freq) + return self.freqs + + class VisionFlashAttention(nn.Module): def __init__(self, dim: int, num_heads: int = 16) -> None: super().__init__() @@ -176,17 +122,24 @@ def __init__(self, dim: int, num_heads: int = 16) -> None: self.qkv = nn.Linear(dim, dim * 3, bias=True) self.proj = nn.Linear(dim, dim) + def apply_rotary_pos_emb_vision(self, t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + t_ = t.float() + cos = freqs.cos() + sin = freqs.sin() + output = apply_rotary_emb(t_, cos, sin).type_as(t) + return output + def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None ) -> torch.Tensor: seq_length = hidden_states.shape[0] q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) - q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb) - k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb) + q = self.apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb.cuda()) + k = self.apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb.cuda()) q = q.squeeze(0) k = k.squeeze(0) - cu_seqlens = cu_seqlens.to(q.device, torch.int32) + cu_seqlens = cu_seqlens.to(q.device, torch.int32, non_blocking=True) max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() attn_output = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device) @@ -196,8 +149,6 @@ def forward( return attn_output -# adapted from -# https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py class Qwen2VLVisionBlock(nn.Module): def __init__(self, embed_dim, mlp_ratio, num_heads, hidden_act) -> None: super().__init__() @@ -216,11 +167,10 @@ def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor: return hidden_states -# adapted from -# https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py -class Qwen2VisionTransformerPretrainedModel(nn.Module): +class Qwen2VLTransformer(nn.Module): def __init__( self, + kvargs, depth=32, embed_dim=1280, hidden_size=3584, @@ -234,6 +184,15 @@ def __init__( **kwargs, ): super().__init__() + + self.weight_dir_ = kvargs["weight_dir"] + self.data_type = kvargs.get("data_type", "bfloat16") + # self.mode = [m.replace("int4weight", "w4a16").replace("int8weight", "w8a16") for m in kvargs.get("mode", [])] + # self.weight_dict = kvargs.get("weight_dict", None) + # self.quant_type = kvargs.get("quant_type", None) + # self.quant_cfg_path = kvargs.get("quant_cfg", None) + # self.max_batch_size = kvargs.get("max_batch_size", 1) + self.depth = depth self.embed_dim = embed_dim self.hidden_size = hidden_size @@ -253,7 +212,7 @@ def __init__( ) head_dim = self.embed_dim // self.num_heads - self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) + self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2).to("cuda", 
non_blocking=True) self.blocks = nn.ModuleList( [ @@ -263,38 +222,60 @@ def __init__( ) self.merger = PatchMerger(dim=self.hidden_size, context_dim=self.embed_dim) - self.device = self.get_device() - self.dtype = self.get_dtype() + processor_config_path = os.path.join(kvargs["weight_dir"], "preprocessor_config.json") + with open(processor_config_path, "r") as f: + processor_config_dict = json.load(f) + self.processor = Qwen2VLImageProcessor(**processor_config_dict) - def get_dtype(self) -> torch.dtype: - return self.blocks[0].mlp.fc2.weight.dtype + self._init_datatype() + self.load_model(kvargs["weight_dir"]) + self.cuda() + + def _init_datatype(self): + if isinstance(self.data_type, torch.dtype): + return + if self.data_type in ["fp16", "float16"]: + self.data_type = torch.float16 + elif self.data_type in ["bf16", "bfloat16"]: + self.data_type = torch.bfloat16 + elif self.data_type in ["fp32", "float32"]: + self.data_type = torch.float32 + else: + raise ValueError(f"Unsupport datatype {self.data_type}!") + return - def get_device(self) -> torch.device: - return self.blocks[0].mlp.fc2.weight.device + def load_model(self, weight_dir): + bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")] + if bin_weight_files: + weight_dict = {} + for file_ in bin_weight_files: + f = torch.load(os.path.join(weight_dir, file_), "cpu") + for k, v in f.items(): + if "visual" in k: + weight_dict[k[len("visual.") :]] = v + else: + hf_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".safetensors")] + weight_dict = {} + for file_ in hf_weight_files: + f = safe_open(os.path.join(weight_dir, file_), "pt", "cpu") + for k in f.keys(): + if "visual" in k: + weight_dict[k[len("visual.") :]] = f.get_tensor(k) + + self.load_state_dict(weight_dict) def rot_pos_emb(self, grid_thw): pos_ids = [] - for t, h, w in grid_thw: + s = self.spatial_merge_size + for _, h, w in grid_thw: + pos_shape = (h // s, s, w // s, s) hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - hpos_ids = hpos_ids.permute(0, 2, 1, 3) - hpos_ids = hpos_ids.flatten() - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - wpos_ids = wpos_ids.permute(0, 2, 1, 3) - wpos_ids = wpos_ids.flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + hpos_ids, wpos_ids = hpos_ids.reshape(pos_shape), wpos_ids.reshape(pos_shape) + hpos_ids, wpos_ids = hpos_ids.permute(0, 2, 1, 3), wpos_ids.permute(0, 2, 1, 3) + hpos_ids, wpos_ids = hpos_ids.flatten(), wpos_ids.flatten() + + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1)) pos_ids = torch.cat(pos_ids, dim=0) max_grid_size = grid_thw[:, 1:].max() rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size).type(torch.float32) @@ -302,15 +283,6 @@ def rot_pos_emb(self, grid_thw): return rotary_pos_emb def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: - hidden_states = hidden_states.to( - dtype=self.get_dtype(), - device=self.device, - ) - grid_thw = grid_thw.to( - dtype=torch.int32, - device=self.device, - ) - hidden_states = self.patch_embed(hidden_states) rotary_pos_emb = self.rot_pos_emb(grid_thw) cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 
0]).cumsum( @@ -322,32 +294,23 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch. hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb) return self.merger(hidden_states) - def load_model(self, weight_dir): - - processor_config_path = os.path.join(weight_dir, "preprocessor_config.json") - with open(processor_config_path, "r") as f: - processor_config_dict = json.load(f) - self.processor = Qwen2VLImageProcessor(**processor_config_dict) - - bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")] - if bin_weight_files: - weight_dict = {} - for file_ in bin_weight_files: - f = torch.load(os.path.join(weight_dir, file_), "cpu") - for k, v in f.items(): - if "visual" in k: - weight_dict[k[len("visual.") :]] = v - + def load_image(self, img: List[ImageItem]): + pixel_values = None + if isinstance(img, ImageItem): + image_data = read_shm(get_shm_name_data(img.uuid)) + image_data = Image.open(BytesIO(image_data)) + image_data = resize_image(image_data) + pixel_values, image_grid_thw = self.processor.preprocess(image_data) + # pixel_values, image_grid_thw = tensor["pixel_values"], tensor["image_grid_thw"] + elif isinstance(img, dict): + image_data = read_shm(get_shm_name_data(img["uuid"])) + image_data = Image.open(BytesIO(image_data)) + image_data = resize_image(image_data) + pixel_values, image_grid_thw = self.processor.preprocess(image_data) + # pixel_values, image_grid_thw = tensor["pixel_values"], tensor["image_grid_thw"] else: - hf_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".safetensors")] - weight_dict = {} - for file_ in hf_weight_files: - f = safe_open(os.path.join(weight_dir, file_), "pt", "cpu") - for k in f.keys(): - if "visual" in k: - weight_dict[k[len("visual.") :]] = f.get_tensor(k) - - self.load_state_dict(weight_dict) + raise Exception("Unsupport input types: {} for {}".format(type(img), img)) + return pixel_values.to(dtype=self.data_type), image_grid_thw def encode(self, images: List[ImageItem]): img_tensors = [] @@ -355,16 +318,16 @@ def encode(self, images: List[ImageItem]): valid_id = 0 img_grids = [] uuids = [] - + print("begin encode+++++++++") for i, img in enumerate(images): if isinstance(img, ImageItem): uuids.append(img.uuid) image_data = read_shm(get_shm_name_data(img.uuid)) image_data = Image.open(BytesIO(image_data)) - image_data = get_image(image_data) - image_inputs = self.processor.preprocess(images=image_data, return_tensors="pt") - pixel_values = image_inputs["pixel_values"].to(dtype=torch.bfloat16) - image_grid_thw = image_inputs["image_grid_thw"] + image_data = resize_image(image_data) + tensor = self.processor.preprocess(images=image_data, return_tensors="pt") + pixel_values, image_grid_thw = tensor["pixel_values"], tensor["image_grid_thw"] + pixel_values = pixel_values.to(dtype=torch.bfloat16) img_tensors.append(pixel_values) img_grids.append(image_grid_thw) else: @@ -385,7 +348,6 @@ def encode(self, images: List[ImageItem]): pixel_values = imgs.cuda().to(dtype=torch.float32) image_grid_thw = grid_thw.cuda() - pixel_values = pixel_values.type(self.get_dtype()) all_img_embeds = self.forward(pixel_values, grid_thw=image_grid_thw) return all_img_embeds, uuids, valid_ids diff --git a/lightllm/models/qwen2_vl/triton_kernel/rotary_pos_emb.py b/lightllm/models/qwen2_vl/triton_kernel/rotary_pos_emb.py new file mode 100644 index 000000000..ab9013c8a --- /dev/null +++ b/lightllm/models/qwen2_vl/triton_kernel/rotary_pos_emb.py @@ -0,0 
+1,156 @@ +import math +import torch +import triton +import triton.language as tl + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb_vision_ref(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + orig_dtype = tensor.dtype + tensor = tensor.float() + cos = freqs.cos() + sin = freqs.sin() + cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + output = (tensor * cos) + (rotate_half(tensor) * sin) + output = output.to(orig_dtype) + return output + + +@triton.jit +def rotary_kernel( + inp_ptr, + cos_ptr, + sin_ptr, + out_ptr, + stride_b, + stride_l, + stride_h, + stride_cos_l, + stride_sin_l, + H: tl.constexpr, + D: tl.constexpr, + HALF_D: tl.constexpr, + BLOCK_D: tl.constexpr, +): + pid_bh = tl.program_id(0) + pid_l = tl.program_id(1) + pid_blk = tl.program_id(2) + + b = pid_bh // H + h = pid_bh - b * H + + offs_d = tl.arange(0, BLOCK_D) + d = pid_blk * BLOCK_D + offs_d + mask = d < D + + # 64-bit 计算基址,防止 b*stride_b 溢出 + base = ( + tl.full([], b, tl.int64) * tl.full([], stride_b, tl.int64) + + tl.full([], pid_l, tl.int64) * tl.full([], stride_l, tl.int64) + + tl.full([], h, tl.int64) * tl.full([], stride_h, tl.int64) + ) + + in_ptr = inp_ptr + base + d + cos_ptr_ = cos_ptr + tl.full([], pid_l, tl.int64) * tl.full([], stride_cos_l, tl.int64) + d + sin_ptr_ = sin_ptr + tl.full([], pid_l, tl.int64) * tl.full([], stride_sin_l, tl.int64) + d + + x = tl.load(in_ptr, mask=mask) + cos = tl.load(cos_ptr_, mask=mask) + sin = tl.load(sin_ptr_, mask=mask) + + partner_d = tl.where(d < HALF_D, d + HALF_D, d - HALF_D) + partner_ptr = inp_ptr + base + partner_d + partner_val = tl.load(partner_ptr, mask=mask) + rotated = tl.where(d < HALF_D, -partner_val, partner_val) + + y = x * cos + rotated * sin + + out_ptr_ = out_ptr + base + d + tl.store(out_ptr_, y, mask=mask) + + +def apply_rotary_pos_emb_triton(tensor: torch.Tensor, freqs: torch.Tensor, BLOCK_D: int = 128) -> torch.Tensor: + assert tensor.is_cuda and freqs.is_cuda + if tensor.ndim != 4: + raise RuntimeError("tensor shape should be [B, L, H, D]") + orig_dtype = tensor.dtype + x = tensor.float() + + cos = freqs.cos().unsqueeze(1).repeat(1, 1, 2).view(freqs.size(0), -1).contiguous().float() + sin = freqs.sin().unsqueeze(1).repeat(1, 1, 2).view(freqs.size(0), -1).contiguous().float() + + B, L, H, D = x.shape + HALF_D = D // 2 + y = torch.empty_like(x) + + stride_b, stride_l, stride_h, _ = x.stride() + stride_cos_l, stride_sin_l = cos.stride(0), sin.stride(0) + + grid = (B * H, L, math.ceil(D / BLOCK_D)) + + rotary_kernel[grid]( + x, + cos, + sin, + y, + stride_b, + stride_l, + stride_h, + stride_cos_l, + stride_sin_l, + H, + D, + HALF_D, + BLOCK_D=BLOCK_D, + ) + + return y.to(orig_dtype) + + +def test_accuracy_and_speed( + B: int = 16, + L: int = 1296, + H: int = 64, + D: int = 80, + warmup: int = 10, + repeats: int = 50, +): + torch.manual_seed(0) + freqs = torch.randn(L, D // 2, device="cuda") + x = torch.randn(B, L, H, D, device="cuda") + + # 误差 + y_ref = apply_rotary_pos_emb_vision_ref(x, freqs) + y_tri = apply_rotary_pos_emb_triton(x, freqs) + print("max abs error:", (y_ref - y_tri).abs().max().item()) + + # 预热 + for _ in range(warmup): + apply_rotary_pos_emb_vision_ref(x, freqs) + apply_rotary_pos_emb_triton(x, freqs) + torch.cuda.synchronize() + + # 计时 + def bench(fn): + start = torch.cuda.Event(True) + end = 
torch.cuda.Event(True) + start.record() + for _ in range(repeats): + fn(x, freqs) + end.record() + torch.cuda.synchronize() + return start.elapsed_time(end) / repeats # ms + + print(f"PyTorch : {bench(apply_rotary_pos_emb_vision_ref):.3f} ms") + print(f"Triton : {bench(apply_rotary_pos_emb_triton):.3f} ms") + + +# if __name__ == "__main__": +# test_accuracy_and_speed() diff --git a/lightllm/models/qwen2_vl/vision_process.py b/lightllm/models/qwen2_vl/vision_process.py index 0107fd97c..0f8e8572b 100644 --- a/lightllm/models/qwen2_vl/vision_process.py +++ b/lightllm/models/qwen2_vl/vision_process.py @@ -1,36 +1,9 @@ -# coding=utf-8 -# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - from __future__ import annotations - -import base64 -from io import BytesIO import math -from typing import Dict, List, Optional, Union -import numpy as np -import requests import torch +import numpy as np from PIL import Image -from torchvision import io, transforms -from torchvision.transforms import InterpolationMode -from transformers import AutoProcessor +from typing import List, Optional, Union from transformers.image_processing_utils import BaseImageProcessor, BatchFeature from transformers.image_transforms import ( @@ -42,107 +15,48 @@ OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, ChannelDimension, - ImageInput, PILImageResampling, get_image_size, infer_channel_dimension_format, - is_scaled_image, - is_valid_image, - make_list_of_images, to_numpy_array, ) -from transformers.utils import TensorType, is_vision_available, logging - -logger = logging.get_logger(__name__) IMAGE_FACTOR = 28 MIN_PIXELS = 4 * 28 * 28 MAX_PIXELS = 16384 * 28 * 28 MAX_RATIO = 200 - -VIDEO_MIN_PIXELS = 128 * 28 * 28 -VIDEO_MAX_PIXELS = 768 * 28 * 28 -VIDEO_TOTAL_PIXELS = 24576 * 28 * 28 FRAME_FACTOR = 2 FPS = 2.0 FPS_MIN_FRAMES = 4 FPS_MAX_FRAMES = 768 -def make_batched_images(images) -> List[List[ImageInput]]: - """ - Accepts images in list or nested list format, and makes a list of images for preprocessing. - - Args: - images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`): - The input image. - - Returns: - list: A list of images. 
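The new triton_kernel/rotary_pos_emb.py keeps the eager rotate-half formula as apply_rotary_pos_emb_vision_ref, so the kernel can be checked against it on small shapes before it is wired into the attention path. A minimal sketch of such a check, assuming a CUDA device and a LightLLM checkout where the patched module imports cleanly:

import torch
from lightllm.models.qwen2_vl.triton_kernel.rotary_pos_emb import (
    apply_rotary_pos_emb_vision_ref,
    apply_rotary_pos_emb_triton,
)

# Layout expected by both functions: tensor [B, L, H, D], freqs [L, D // 2].
x = torch.randn(1, 64, 16, 80, device="cuda", dtype=torch.float16)
freqs = torch.randn(64, 40, device="cuda")

y_ref = apply_rotary_pos_emb_vision_ref(x, freqs)
y_tri = apply_rotary_pos_emb_triton(x, freqs)
assert torch.allclose(y_ref, y_tri, atol=1e-3, rtol=1e-3)

test_accuracy_and_speed() in the same file runs this comparison plus a timing loop on a larger shape.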
- """ - if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]): - return [img for img_list in images for img in img_list] - - elif isinstance(images, (list, tuple)) and is_valid_image(images[0]): - return images - - elif is_valid_image(images): - return [images] - - raise ValueError(f"Could not make batched images from {images}") - - -def round_by_factor(number: int, factor: int) -> int: - """Returns the closest integer to 'number' that is divisible by 'factor'.""" - return round(number / factor) * factor - - -def ceil_by_factor(number: int, factor: int) -> int: - """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" - return math.ceil(number / factor) * factor - - -def floor_by_factor(number: int, factor: int) -> int: - """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" - return math.floor(number / factor) * factor - - def smart_resize( height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS ) -> tuple[int, int]: - """ - Rescales the image so that the following conditions are met: - - 1. Both dimensions (height and width) are divisible by 'factor'. - - 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. - 3. The aspect ratio of the image is maintained as closely as possible. - """ if max(height, width) / min(height, width) > MAX_RATIO: raise ValueError( f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}" ) - h_bar = max(factor, round_by_factor(height, factor)) - w_bar = max(factor, round_by_factor(width, factor)) + h_bar = max(factor, round(height / factor) * factor) + w_bar = max(factor, round(width / factor) * factor) if h_bar * w_bar > max_pixels: beta = math.sqrt((height * width) / max_pixels) - h_bar = floor_by_factor(height / beta, factor) - w_bar = floor_by_factor(width / beta, factor) + h_bar = math.floor(height / beta / factor) * factor + w_bar = math.floor(width / beta / factor) * factor elif h_bar * w_bar < min_pixels: beta = math.sqrt(min_pixels / (height * width)) - h_bar = ceil_by_factor(height * beta, factor) - w_bar = ceil_by_factor(width * beta, factor) + h_bar = math.ceil(height * beta / factor) * factor + w_bar = math.ceil(width * beta / factor) * factor return h_bar, w_bar -def get_image(image_file: Image.Image, size_factor: int = IMAGE_FACTOR) -> tuple[Image.Image, int, int]: - image = image_file.convert("RGB") +def resize_image(image_file: str | Image.Image, size_factor: int = IMAGE_FACTOR) -> tuple[Image.Image, int, int]: - # 获取原始宽度和高度 + image = image_file.convert("RGB") width, height = image.size - # 使用默认的最小像素和最大像素调整大小 resized_height, resized_width = smart_resize( height, width, @@ -150,19 +64,12 @@ def get_image(image_file: Image.Image, size_factor: int = IMAGE_FACTOR) -> tuple min_pixels=MIN_PIXELS, max_pixels=MAX_PIXELS, ) - - # 调整图片大小 image = image.resize((resized_width, resized_height)) return image -# adapted from -# transformers/blob/main/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py class Qwen2VLImageProcessor(BaseImageProcessor): - - model_input_names = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"] - def __init__( self, do_resize: bool = True, @@ -186,6 +93,7 @@ def __init__( self.do_rescale = do_rescale self.rescale_factor = rescale_factor self.do_normalize = do_normalize + self.do_convert_rgb = do_convert_rgb self.image_mean = 
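The trimmed-down smart_resize keeps three invariants: both output sides are divisible by the factor (patch_size * merge_size = 28), the total pixel count stays inside [MIN_PIXELS, MAX_PIXELS], and the aspect ratio is preserved as closely as possible. A small worked example, assuming the patched vision_process module is importable:

from lightllm.models.qwen2_vl.vision_process import smart_resize, MAX_PIXELS

# 531x800 simply rounds each side to the nearest multiple of 28.
print(smart_resize(531, 800))        # (532, 812)

# An oversized input is scaled down until it fits the pixel budget,
# with both sides still divisible by 28.
h, w = smart_resize(4000, 6000)
assert h % 28 == 0 and w % 28 == 0 and h * w <= MAX_PIXELS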
image_mean if image_mean is not None else OPENAI_CLIP_MEAN self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD self.min_pixels = min_pixels @@ -193,76 +101,46 @@ def __init__( self.patch_size = patch_size self.temporal_patch_size = temporal_patch_size self.merge_size = merge_size - self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels} - self.do_convert_rgb = do_convert_rgb - - def _preprocess( - self, - images: Union[ImageInput], - do_resize: bool = None, - resample: PILImageResampling = None, - do_rescale: bool = None, - rescale_factor: float = None, - do_normalize: bool = None, - image_mean: Optional[Union[float, List[float]]] = None, - image_std: Optional[Union[float, List[float]]] = None, - do_convert_rgb: bool = None, - data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, - input_data_format: Optional[Union[str, ChannelDimension]] = None, - ): - - images = make_list_of_images(images) - - if do_convert_rgb: - images = [convert_to_rgb(image) for image in images] + self.data_format = ChannelDimension.FIRST - # All transformations expect numpy arrays. - images = [to_numpy_array(image) for image in images] + def preprocess(self, image): + if self.do_convert_rgb: + image = convert_to_rgb(image) + image = to_numpy_array(image) + input_data_format = infer_channel_dimension_format(image) + height, width = get_image_size(image, channel_dim=input_data_format) - if is_scaled_image(images[0]) and do_rescale: - logger.warning_once( - "It looks like you are trying to rescale already rescaled images. If the input" - " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + resized_height, resized_width = height, width + if self.do_resize: + resized_height, resized_width = smart_resize( + height, + width, + factor=self.patch_size * self.merge_size, + min_pixels=self.min_pixels, + max_pixels=self.max_pixels, + ) + image = resize( + image, size=(resized_height, resized_width), resample=self.resample, input_data_format=input_data_format ) - if input_data_format is None: - # We assume that all images have the same channel dimension format. 
- input_data_format = infer_channel_dimension_format(images[0]) - height, width = get_image_size(images[0], channel_dim=input_data_format) - resized_height, resized_width = height, width - processed_images = [] - for image in images: - if do_resize: - resized_height, resized_width = smart_resize( - height, - width, - factor=self.patch_size * self.merge_size, - min_pixels=self.min_pixels, - max_pixels=self.max_pixels, - ) - image = resize( - image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format - ) + if self.do_rescale: + image = self.rescale(image, scale=self.rescale_factor, input_data_format=input_data_format) - if do_rescale: - image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format) + if self.do_normalize: + image = self.normalize( + image=image, mean=self.image_mean, std=self.image_std, input_data_format=input_data_format + ) - if do_normalize: - image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + image = to_channel_dimension_format(image, self.data_format, input_channel_dim=input_data_format) - image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) - processed_images.append(image) + patches = np.array([image]) - patches = np.array(processed_images) - if data_format == ChannelDimension.LAST: - patches = patches.transpose(0, 3, 1, 2) if patches.shape[0] == 1: patches = np.tile(patches, (self.temporal_patch_size, 1, 1, 1)) channel = patches.shape[1] - grid_t = patches.shape[0] // self.temporal_patch_size grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size patches = patches.reshape( - grid_t, + 1, self.temporal_patch_size, channel, grid_h // self.merge_size, @@ -274,59 +152,10 @@ def _preprocess( ) patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8) flatten_patches = patches.reshape( - grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size + grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size ) + image_grid_thw = (1, grid_h, grid_w) + pixel_values = torch.as_tensor(flatten_patches) + grid_thw = torch.as_tensor([image_grid_thw]) - return flatten_patches, (grid_t, grid_h, grid_w) - - def preprocess( - self, - images: ImageInput, - do_resize: bool = None, - size: Dict[str, int] = None, - resample: PILImageResampling = None, - do_rescale: bool = None, - rescale_factor: float = None, - do_normalize: bool = None, - image_mean: Optional[Union[float, List[float]]] = None, - image_std: Optional[Union[float, List[float]]] = None, - do_convert_rgb: bool = None, - return_tensors: Optional[Union[str, TensorType]] = "pt", - data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, - input_data_format: Optional[Union[str, ChannelDimension]] = None, - ): - - do_resize = do_resize if do_resize is not None else self.do_resize - size = size if size is not None else self.size - resample = resample if resample is not None else self.resample - do_rescale = do_rescale if do_rescale is not None else self.do_rescale - rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor - do_normalize = do_normalize if do_normalize is not None else self.do_normalize - image_mean = image_mean if image_mean is not None else self.image_mean - image_std = image_std if image_std is not None else self.image_std - do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb - - if images is 
not None: - images = make_batched_images(images) - pixel_values, vision_grid_thws = [], [] - for image in images: - patches, image_grid_thw = self._preprocess( - image, - do_resize=do_resize, - resample=resample, - do_rescale=do_rescale, - rescale_factor=rescale_factor, - do_normalize=do_normalize, - image_mean=image_mean, - image_std=image_std, - data_format=data_format, - do_convert_rgb=do_convert_rgb, - input_data_format=input_data_format, - ) - pixel_values.extend(patches) - vision_grid_thws.append(image_grid_thw) - pixel_values = np.array(pixel_values) - vision_grid_thws = np.array(vision_grid_thws) - data = {"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws} - - return BatchFeature(data=data, tensor_type=return_tensors) + return pixel_values, grid_thw diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py index 5c5f57e47..2f5d07a22 100644 --- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py @@ -40,7 +40,12 @@ def _infer_image_embeds(self, infer_state, layer_weight): for img in p["images"] + p["audios"]: if img.get("_prefill_", True): image_data = img["image_data"].to("cuda", non_blocking=True) - image_embed = layer_weight.visual_model.forward(image_data).view(img["token_num"], -1) + image_grid_thw = img["image_grid_thw"] + # image_embed = torch.zeros( + # (img["token_num"],layer_weight.wte_weight_.shape[1]),device="cuda",dtype=torch.bfloat16) + image_embed = layer_weight.visual_model.forward(image_data, image_grid_thw).view( + img["token_num"], -1 + ) image_weight.append(image_embed) if len(image_weight) > 0: image_weight = torch.cat(image_weight, dim=0) @@ -63,7 +68,6 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei img_weight = self._infer_image_embeds(infer_state, layer_weight) out = torch.zeros((len(input_ids), hidden_size), dtype=dtype, device="cpu").to("cuda", non_blocking=True) - img_weight = img_weight / self.tp_world_size_ if infer_state.image_start_token_ids is not None: img_start_token_ids = infer_state.image_start_token_ids.to("cuda", non_blocking=True) img_token_lens = infer_state.image_token_lens.to("cuda", non_blocking=True) diff --git a/lightllm/models/tarsier2/tarsier2_visual.py b/lightllm/models/tarsier2/tarsier2_visual.py index 978b9bd17..d7171dfe5 100644 --- a/lightllm/models/tarsier2/tarsier2_visual.py +++ b/lightllm/models/tarsier2/tarsier2_visual.py @@ -13,10 +13,10 @@ from transformers import AutoModel from safetensors import safe_open -from lightllm.models.qwen2_vl.qwen2_visual import Qwen2VisionTransformerPretrainedModel +from lightllm.models.qwen2_vl.qwen2_visual import Qwen2VLTransformer from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data from lightllm.server.multimodal_params import ImageItem -from lightllm.models.qwen2_vl.vision_process import Qwen2VLImageProcessor, get_image +from lightllm.models.qwen2_vl.vision_process import Qwen2VLImageProcessor, resize_image def add_split_tokens(image_features, image_newline_embed, image_new_embed): @@ -165,7 +165,7 @@ def __init__( **kwargs, ): super().__init__() - self.vision_tower = Qwen2VisionTransformerPretrainedModel(**vision_config) + self.vision_tower = Qwen2VLTransformer(**vision_config) if projection_head == "Pixel_Shuffle": self.multi_modal_projector = PixelShuffleMultiModalProjector( @@ -253,7 +253,7 @@ def encode(self, images: List[ImageItem]): uuids.append(img.uuid) image_data = 
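With the BatchFeature plumbing removed, preprocess now takes a single PIL image and returns (pixel_values, grid_thw) directly; the shape contract is pixel_values of (grid_h * grid_w, channels * temporal_patch_size * patch_size ** 2) with grid_thw fixed at t = 1 for still images. A shape-only sketch of that contract, assuming the patched processor, the default patch_size=14 / temporal_patch_size=2, and a placeholder checkpoint path for preprocessor_config.json:

import json, os
from PIL import Image
from lightllm.models.qwen2_vl.vision_process import Qwen2VLImageProcessor, resize_image

weight_dir = "/models/Qwen2-VL-7B-Instruct"  # placeholder path
with open(os.path.join(weight_dir, "preprocessor_config.json")) as f:
    processor = Qwen2VLImageProcessor(**json.load(f))

image = resize_image(Image.new("RGB", (800, 531)))   # dummy 800x531 RGB image
pixel_values, grid_thw = processor.preprocess(image)

t, grid_h, grid_w = grid_thw[0].tolist()
assert t == 1
# 3 channels * temporal_patch_size 2 * 14 * 14 = 1176 features per flattened patch
assert pixel_values.shape == (grid_h * grid_w, 3 * 2 * 14 * 14)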
read_shm(get_shm_name_data(img.uuid)) image_data = Image.open(BytesIO(image_data)) - image_data = get_image(image_data) + image_data = resize_image(image_data) image_inputs = self.processor.preprocess(images=image_data, return_tensors="pt") pixel_values = image_inputs["pixel_values"].to(dtype=torch.bfloat16) image_grid_thw = image_inputs["image_grid_thw"] diff --git a/lightllm/models/vit/model.py b/lightllm/models/vit/model.py index 44705fd8f..7aa95eae8 100644 --- a/lightllm/models/vit/model.py +++ b/lightllm/models/vit/model.py @@ -48,7 +48,6 @@ def __init__(self, kvargs): self.quant_cfg_path = kvargs.get("quant_cfg", None) self.load_image_func = get_load_image_func(self.weight_dir_) self.max_batch_size = kvargs.get("max_batch_size", 1) - self.enable_tensor_cache = not get_env_start_args().disable_extra_process_for_multimodal self._init_datatype() self._init_config() @@ -60,20 +59,18 @@ def __init__(self, kvargs): return def load_image(self, img: List[ImageItem]): - from lightllm.server.multimodal_params import ImageItem - - img_tensor = None + pixel_values = None if isinstance(img, ImageItem): image_data = read_shm(get_shm_name_data(img.uuid)) image_data = Image.open(BytesIO(image_data)) - img_tensor = self.load_image_func(image_data, max_num=img.extra_params["image_patch_max_num"]) + pixel_values = self.load_image_func(image_data, max_num=img.extra_params["image_patch_max_num"]) elif isinstance(img, dict): image_data = read_shm(get_shm_name_data(img["uuid"])) image_data = Image.open(BytesIO(image_data)) - img_tensor = self.load_image_func(image_data, max_num=img["extra_params"]["image_patch_max_num"]) + pixel_values = self.load_image_func(image_data, max_num=img["extra_params"]["image_patch_max_num"]) else: raise Exception("Unsupport input types: {} for {}".format(type(img), img)) - return img_tensor.to(dtype=self.data_type) + return pixel_values.to(dtype=self.data_type), None @final @torch.no_grad() @@ -81,13 +78,12 @@ def _check_max_len_infer(self): disable_check_max_len_infer = os.getenv("DISABLE_CHECK_MAX_LEN_INFER", None) is not None if disable_check_max_len_infer: return - self.enable_tensor_cache = True try: dummy_images = torch.randn( (self.MAX_PATH_NUM * self.max_batch_size, 3, self.IMAGE_H, self.IMAGE_W), dtype=self.data_type ).cuda() - all_img_embeds = self.forward(dummy_images) + all_img_embeds = self.forward(dummy_images, image_gird_thw=None) del all_img_embeds del dummy_images logger.info(f"vit check max_len {self.max_batch_size} infer ok") @@ -98,7 +94,6 @@ def _check_max_len_infer(self): ) logger.error(exception_str) raise Exception(exception_str) - self.enable_tensor_cache = not get_env_start_args().disable_extra_process_for_multimodal return def _init_config(self): @@ -183,15 +178,11 @@ def _init_datatype(self): raise ValueError(f"Unsupport datatype {self.data_type}!") @torch.no_grad() - def forward(self, pixel_values): - if self.enable_tensor_cache: - g_cache_manager.cache_env_in() + def forward(self, pixel_values, image_gird_thw): input_embs = self.pre_infer.forward(pixel_values, self.pre_post_weight) for i in range(self.layers_num + self.select_layer + 1): input_embs = self.layers_infer[i].forward(input_embs, self.trans_layers_weight[i]) input_embs = self.post_infer.forward(input_embs[:, 1:, :], self.pre_post_weight) - if self.enable_tensor_cache: - g_cache_manager.cache_env_out() return input_embs @torch.no_grad() diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 
b3bc0e6f2..3a6572839 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -481,8 +481,9 @@ def _preprocess_image(self, batch: ModelInput): if not args.disable_extra_process_for_multimodal: continue # 预拉取已经存在的image data - image_data = self.model.pre_post_weight.visual_model.load_image(img) + image_data, image_grid_thw = self.model.pre_post_weight.visual_model.load_image(img) img["image_data"] = image_data + img["image_grid_thw"] = image_grid_thw batch.image_start_locs = torch.tensor(image_start_locs, device="cpu", dtype=torch.long) batch.image_token_lens = torch.tensor(image_token_lens, device="cpu", dtype=torch.long) batch.image_start_token_ids = torch.tensor(image_start_token_ids, device="cpu", dtype=torch.long) diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index d2d45f2fd..679ab7ffd 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -14,8 +14,8 @@ from lightllm.models.gemma3.gemma3_visual import Gemma3VisionModel from lightllm.models.vit.model import VisionTransformer from lightllm.server.multimodal_params import MultimodalParams, ImageItem -from lightllm.models.qwen2_vl.qwen2_visual import Qwen2VisionTransformerPretrainedModel -from lightllm.models.qwen2_5_vl.qwen2_5_visual import Qwen2_5_VisionTransformerPretrainedModel +from lightllm.models.qwen2_vl.qwen2_visual import Qwen2VLTransformer +from lightllm.models.qwen2_5_vl.qwen2_5_visual import Qwen2_5VLTransformer from lightllm.models.tarsier2.tarsier2_visual import TarsierVisionTransformerPretrainedModel from lightllm.server.embed_cache.utils import tensor2bytes, read_shm, create_shm, get_shm_name_data, get_shm_name_embed from lightllm.utils.infer_utils import set_random_seed @@ -46,24 +46,24 @@ def exposed_init_model(self, kvargs): try: self.model_type = model_cfg["model_type"] + kvargs = { + "weight_dir": weight_dir, + "data_type": self.data_type, + "quant_type": kvargs["quant_type"], + "quant_cfg": kvargs["quant_cfg"], + "max_batch_size": kvargs["max_batch_size"], + } if self.model_type == "qwen": self.model = QWenVisionTransformer(**model_cfg["visual"]).eval().bfloat16() elif self.model_type == "qwen2_vl": - self.model = Qwen2VisionTransformerPretrainedModel(**model_cfg["vision_config"]).eval().bfloat16() + self.model = Qwen2VLTransformer(**model_cfg["vision_config"]).eval().bfloat16() elif self.model_type == "qwen2_5_vl": - self.model = Qwen2_5_VisionTransformerPretrainedModel(**model_cfg["vision_config"]).eval().bfloat16() + self.model = Qwen2_5VLTransformer(kvargs, **model_cfg["vision_config"]).eval().bfloat16() elif model_cfg["architectures"][0] == "TarsierForConditionalGeneration": - self.model = TarsierVisionTransformerPretrainedModel(**model_cfg).eval().bfloat16() + self.model = TarsierVisionTransformerPretrainedModel(kvargs, **model_cfg).eval().bfloat16() elif self.model_type == "llava": self.model = LlavaVisionModel() elif self.model_type == "internvl_chat": - kvargs = { - "weight_dir": weight_dir, - "data_type": self.data_type, - "quant_type": kvargs["quant_type"], - "quant_cfg": kvargs["quant_cfg"], - "max_batch_size": kvargs["max_batch_size"], - } self.model = VisionTransformer(kvargs) # self.model = InternVLVisionModel() elif self.model_type == "gemma3": From 88577b3ea75fd0ac422885fb5c50ea257d345f92 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Fri, 8 Aug 2025 
07:47:56 +0000 Subject: [PATCH 16/18] 0808-temp --- lightllm/models/qwen2_vl/qwen2_visual.py | 50 +++++++++++++------ .../vit/triton_kernel/flashattention_nopad.py | 2 +- 2 files changed, 36 insertions(+), 16 deletions(-) diff --git a/lightllm/models/qwen2_vl/qwen2_visual.py b/lightllm/models/qwen2_vl/qwen2_visual.py index e9274af82..b08788a7e 100644 --- a/lightllm/models/qwen2_vl/qwen2_visual.py +++ b/lightllm/models/qwen2_vl/qwen2_visual.py @@ -104,15 +104,30 @@ def forward(self, x) -> torch.Tensor: return self.fc2(self.act(self.fc1(x))) +# copy form vllm class VisionRotaryEmbedding(nn.Module): def __init__(self, dim: int, theta: float = 10000.0) -> None: super().__init__() + self.dim = dim + self.theta = theta self.inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self._seq_len_cached = 0 + self._freqs_cached = None + + def update_freqs_cache(self, seqlen: int) -> None: + if seqlen > self._seq_len_cached: + seqlen *= 2 + self._seq_len_cached = seqlen + self.inv_freq = 1.0 / ( + self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device) / self.dim) + ) + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + self._freqs_cached = freqs def forward(self, seqlen: int) -> torch.Tensor: - self.seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - self.freqs = torch.outer(self.seq, self.inv_freq) - return self.freqs + self.update_freqs_cache(seqlen) + return self._freqs_cached[:seqlen] class VisionFlashAttention(nn.Module): @@ -130,17 +145,19 @@ def apply_rotary_pos_emb_vision(self, t: torch.Tensor, freqs: torch.Tensor) -> t return output def forward( - self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int = 0, + rotary_pos_emb: torch.Tensor = None, ) -> torch.Tensor: seq_length = hidden_states.shape[0] q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) - q = self.apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb.cuda()) - k = self.apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb.cuda()) + q = self.apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb) + k = self.apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb) q = q.squeeze(0) k = k.squeeze(0) - cu_seqlens = cu_seqlens.to(q.device, torch.int32, non_blocking=True) - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() attn_output = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device) flash_attention_fwd(q, k, v, attn_output, cu_seqlens, max_seqlen) @@ -159,9 +176,9 @@ def __init__(self, embed_dim, mlp_ratio, num_heads, hidden_act) -> None: self.attn = VisionFlashAttention(embed_dim, num_heads=num_heads) self.mlp = VisionMlp(dim=embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=hidden_act) - def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor: + def forward(self, hidden_states, cu_seqlens, max_seqlen, rotary_pos_emb) -> torch.Tensor: hidden_states = hidden_states + self.attn( - self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb + self.norm1(hidden_states), cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, rotary_pos_emb=rotary_pos_emb ) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) return hidden_states @@ -271,9 +288,8 @@ def rot_pos_emb(self, grid_thw): pos_shape = (h // s, s, 
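The reworked VisionRotaryEmbedding memoizes the outer-product frequency table and grows it geometrically (caching twice the requested length) instead of rebuilding it on every call, so repeated forward passes over typical grid sizes reuse the same tensor. A small sketch of the expected behaviour, assuming a checkout where the patched qwen2_visual module imports cleanly:

import torch
from lightllm.models.qwen2_vl.qwen2_visual import VisionRotaryEmbedding

rope = VisionRotaryEmbedding(dim=40)   # head_dim // 2 for an 80-dim head
f1 = rope.forward(64)                  # builds a cache of length 128
f2 = rope.forward(100)                 # served from the existing cache
assert f1.shape == (64, 20) and f2.shape == (100, 20)
assert torch.equal(rope.forward(64), f1)   # no recompute for a cached prefix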
w // s, s) hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - hpos_ids, wpos_ids = hpos_ids.reshape(pos_shape), wpos_ids.reshape(pos_shape) - hpos_ids, wpos_ids = hpos_ids.permute(0, 2, 1, 3), wpos_ids.permute(0, 2, 1, 3) - hpos_ids, wpos_ids = hpos_ids.flatten(), wpos_ids.flatten() + hpos_ids = hpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten() + wpos_ids = wpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten() pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1)) pos_ids = torch.cat(pos_ids, dim=0) @@ -284,14 +300,18 @@ def rot_pos_emb(self, grid_thw): def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: hidden_states = self.patch_embed(hidden_states) - rotary_pos_emb = self.rot_pos_emb(grid_thw) + rotary_pos_emb = self.rot_pos_emb(grid_thw).to("cuda", non_blocking=True) cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( dim=0, dtype=torch.int32 ) cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + cu_seqlens = cu_seqlens.to("cuda", non_blocking=True) for blk in self.blocks: - hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb) + hidden_states = blk( + hidden_states, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, rotary_pos_emb=rotary_pos_emb + ) return self.merger(hidden_states) def load_image(self, img: List[ImageItem]): diff --git a/lightllm/models/vit/triton_kernel/flashattention_nopad.py b/lightllm/models/vit/triton_kernel/flashattention_nopad.py index ab3770a36..fcb6fbb5d 100644 --- a/lightllm/models/vit/triton_kernel/flashattention_nopad.py +++ b/lightllm/models/vit/triton_kernel/flashattention_nopad.py @@ -215,7 +215,7 @@ def flash_attention_fwd(q, k, v, o, cu_seqlens, max_seqlen): 统一的 Flash Attention 接口。如果 sgl_kernel 存在, 则使用 sgl_kernel里的接口,否则使用 Triton 版本。 """ - if _flash_attn_v3_available and is_hopper() and False: + if _flash_attn_v3_available and is_hopper(): flash_attention_v3_fwd(q, k, v, o, cu_seqlens, max_seqlen) else: _flash_attention_triton_fwd(q, k, v, o, cu_seqlens, max_seqlen) From 24974072a940e3e17383e12b8fbc189a1f71d928 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Mon, 11 Aug 2025 10:08:54 +0000 Subject: [PATCH 17/18] 0811-fix-qwen2-5 --- lightllm/models/qwen2_5_vl/qwen2_5_visual.py | 217 +++++++----------- lightllm/models/qwen2_vl/qwen2_visual.py | 15 +- .../qwen2_vl/triton_kernel/rotary_pos_emb.py | 32 +-- lightllm/models/qwen2_vl/vision_process.py | 2 +- 4 files changed, 110 insertions(+), 156 deletions(-) diff --git a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py index 693a73a38..0a599475d 100644 --- a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py +++ b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py @@ -24,48 +24,12 @@ from lightllm.models.qwen2_vl.qwen2_visual import PatchEmbed, VisionRotaryEmbedding from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager +from lightllm.models.qwen2_vl.triton_kernel.rotary_pos_emb import apply_rotary_pos_emb_triton # adapted from # https://github.com/huggingface/transformers/blob/ # be37d34f44ff1bc928e59ffb8a30adecab8835a8/src # /transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py#L30C1-L31C1 -class Qwen2_5_VLVisionConfig(PretrainedConfig): - model_type = "qwen2_5_vl" - - def __init__( - self, - depth=32, - 
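The compacted rot_pos_emb produces, per image, an (h * w, 2) table of (row, col) indices ordered so that the four patches of each spatial-merge window sit next to each other; those indices then gather rows of the cached frequency table. A standalone reproduction of that index layout for a single 4x4 grid with spatial_merge_size = 2 (plain PyTorch, no LightLLM imports):

import torch

h, w, s = 4, 4, 2                      # patch-grid height/width, merge size
pos_shape = (h // s, s, w // s, s)

hpos = torch.arange(h).unsqueeze(1).expand(-1, w)   # row index of every patch
wpos = torch.arange(w).unsqueeze(0).expand(h, -1)   # col index of every patch
hpos = hpos.reshape(pos_shape).permute(0, 2, 1, 3).flatten()
wpos = wpos.reshape(pos_shape).permute(0, 2, 1, 3).flatten()

pos_ids = torch.stack([hpos, wpos], dim=-1)
print(pos_ids[:4].tolist())   # [[0, 0], [0, 1], [1, 0], [1, 1]] -> one 2x2 merge window
print(pos_ids[4:8].tolist())  # [[0, 2], [0, 3], [1, 2], [1, 3]] -> the next window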
hidden_size=3584, - hidden_act="silu", - intermediate_size=3420, - num_heads=16, - in_channels=3, - patch_size=14, - spatial_merge_size=2, - temporal_patch_size=2, - tokens_per_second=4, - window_size=112, - out_hidden_size=3584, - fullatt_block_indexes=[7, 15, 23, 31], - **kwargs, - ): - super().__init__(**kwargs) - - self.depth = depth - self.hidden_size = hidden_size - self.hidden_act = hidden_act - self.intermediate_size = intermediate_size - self.num_heads = num_heads - self.in_channels = in_channels - self.patch_size = patch_size - self.spatial_merge_size = spatial_merge_size - self.temporal_patch_size = temporal_patch_size - self.tokens_per_second = tokens_per_second - self.window_size = window_size - self.fullatt_block_indexes = fullatt_block_indexes - self.out_hidden_size = out_hidden_size - - class Qwen2RMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -104,27 +68,6 @@ def forward(self, hidden_state): return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb_vision( - q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor]: - orig_q_dtype = q.dtype - orig_k_dtype = k.dtype - q, k = q.float(), k.float() - cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float() - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - q_embed = q_embed.to(orig_q_dtype) - k_embed = k_embed.to(orig_k_dtype) - return q_embed, k_embed - - class Qwen2_5_VLVisionFlashAttention(nn.Module): def __init__(self, dim: int, num_heads: int = 16) -> None: super().__init__() @@ -132,26 +75,39 @@ def __init__(self, dim: int, num_heads: int = 16) -> None: self.head_dim = dim // num_heads self.qkv = nn.Linear(dim, dim * 3, bias=True) self.proj = nn.Linear(dim, dim) + try: + from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb + + self.has_vllm = True + self.apply_rotary_emb = apply_rotary_emb + except ImportError: + print("Failed to import _flash_attn_forward from hopper.flash_attn_interface.") + self.has_vllm = False + self.apply_rotary_emb = apply_rotary_pos_emb_triton + + def apply_rotary_pos_emb_vision(self, t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + t_ = t.float() + cos = freqs.cos() + sin = freqs.sin() + output = self.apply_rotary_emb(t_, cos, sin).type_as(t) + return output def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, + max_seqlen: int = 0, rotary_pos_emb: Optional[torch.Tensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> torch.Tensor: seq_length = hidden_states.shape[0] q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) - if position_embeddings is None: - emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) - cos = emb.cos() - sin = emb.sin() - else: - cos, sin = position_embeddings - q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) + # if position_embeddings is None: + # position_embeddings = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + q = self.apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb) + k = self.apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb) + q = q.squeeze(0) + k = k.squeeze(0) - cu_seqlens = cu_seqlens.to(q.device, torch.int32) - max_seqlen = 
(cu_seqlens[1:] - cu_seqlens[:-1]).max().item() attn_output = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device) flash_attention_fwd(q, k, v, attn_output, cu_seqlens, max_seqlen) attn_output = attn_output.reshape(seq_length, -1) @@ -183,14 +139,14 @@ def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, + max_seqlen: int = 0, rotary_pos_emb: Optional[torch.Tensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> torch.Tensor: hidden_states = hidden_states + self.attn( self.norm1(hidden_states), cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, rotary_pos_emb=rotary_pos_emb, - position_embeddings=position_embeddings, ) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) return hidden_states @@ -215,7 +171,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Qwen2_5VLTransformer(nn.Module): def __init__( self, - weight_dir, + kvargs, depth=32, hidden_size=3584, hidden_act="silu", @@ -232,7 +188,13 @@ def __init__( **kwargs, ): super().__init__() - + self.weight_dir = kvargs["weight_dir"] + self.data_type = kvargs.get("data_type", "bfloat16") + # self.mode = [m.replace("int4weight", "w4a16").replace("int8weight", "w8a16") for m in kvargs.get("mode", [])] + # self.weight_dict = kvargs.get("weight_dict", None) + # self.quant_type = kvargs.get("quant_type", None) + # self.quant_cfg_path = kvargs.get("quant_cfg", None) + # self.max_batch_size = kvargs.get("max_batch_size", 1) self.depth = depth self.hidden_size = hidden_size self.hidden_act = hidden_act @@ -279,46 +241,42 @@ def __init__( self.gradient_checkpointing = False - processor_config_path = os.path.join(weight_dir, "preprocessor_config.json") + processor_config_path = os.path.join(self.weight_dir, "preprocessor_config.json") with open(processor_config_path, "r") as f: processor_config_dict = json.load(f) self.processor = Qwen2VLImageProcessor(**processor_config_dict) - self.device = self.get_device() - self.dtype = self.get_dtype() - - def get_dtype(self) -> torch.dtype: - return self.blocks[0].mlp.down_proj.weight.dtype - - def get_device(self) -> torch.device: - return self.blocks[0].mlp.down_proj.weight.device + self._init_datatype() + self.load_model(kvargs["weight_dir"]) + self.cuda() + + def _init_datatype(self): + if isinstance(self.data_type, torch.dtype): + return + if self.data_type in ["fp16", "float16"]: + self.data_type = torch.float16 + elif self.data_type in ["bf16", "bfloat16"]: + self.data_type = torch.bfloat16 + elif self.data_type in ["fp32", "float32"]: + self.data_type = torch.float32 + else: + raise ValueError(f"Unsupport datatype {self.data_type}!") + return def rot_pos_emb(self, grid_thw): pos_ids = [] - for t, h, w in grid_thw: + s = self.spatial_merge_size + for _, h, w in grid_thw: + pos_shape = (h // s, s, w // s, s) hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - hpos_ids = hpos_ids.permute(0, 2, 1, 3) - hpos_ids = hpos_ids.flatten() - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - wpos_ids = wpos_ids.permute(0, 2, 1, 3) - wpos_ids = wpos_ids.flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + hpos_ids = hpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten() + 
wpos_ids = wpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten() + + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1)) pos_ids = torch.cat(pos_ids, dim=0) max_grid_size = grid_thw[:, 1:].max() - rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size).type(torch.float32) rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) return rotary_pos_emb @@ -365,7 +323,14 @@ def get_window_index(self, grid_thw): def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: hidden_states = self.patch_embed(hidden_states) - rotary_pos_emb = self.rot_pos_emb(grid_thw) + rotary_pos_emb = self.rot_pos_emb(grid_thw).to("cuda", non_blocking=True) + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, dtype=torch.int32 + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + cu_seqlens = cu_seqlens.to("cuda", non_blocking=True) + window_index, cu_window_seqlens = self.get_window_index(grid_thw) cu_window_seqlens = torch.tensor( cu_window_seqlens, @@ -373,6 +338,7 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch. dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, ) cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + max_window_seqlen = (cu_window_seqlens[1:] - cu_window_seqlens[:-1]).max().item() seq_len, _ = hidden_states.size() hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) @@ -381,40 +347,21 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch. rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) rotary_pos_emb = rotary_pos_emb[window_index, :, :] rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) - emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) - position_embeddings = (emb.cos(), emb.sin()) - - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same - # dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 - # for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) for layer_num, blk in enumerate(self.blocks): if layer_num in self.fullatt_block_indexes: cu_seqlens_now = cu_seqlens + max_seqlen_now = max_seqlen else: cu_seqlens_now = cu_window_seqlens - if self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - blk.__call__, - hidden_states, - cu_seqlens_now, - None, - position_embeddings, - ) - else: - hidden_states = blk( - hidden_states, - cu_seqlens=cu_seqlens_now, - position_embeddings=position_embeddings, - ) + max_seqlen_now = max_window_seqlen + + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens_now, + max_seqlen=max_seqlen_now, + rotary_pos_emb=rotary_pos_emb, + ) hidden_states = self.merger(hidden_states) reverse_indices = torch.argsort(window_index) @@ -428,19 +375,15 @@ def load_image(self, img: List[ImageItem]): image_data = read_shm(get_shm_name_data(img.uuid)) image_data = Image.open(BytesIO(image_data)) image_data = resize_image(image_data) - image_inputs = 
self.processor.preprocess(images=image_data, return_tensors="pt") - pixel_values = image_inputs["pixel_values"].to(dtype=torch.bfloat16) - image_grid_thw = image_inputs["image_grid_thw"] + pixel_values, image_grid_thw = self.processor.preprocess(image_data) elif isinstance(img, dict): image_data = read_shm(get_shm_name_data(img["uuid"])) image_data = Image.open(BytesIO(image_data)) image_data = resize_image(image_data) - image_inputs = self.processor.preprocess(images=image_data, return_tensors="pt") - pixel_values = image_inputs["pixel_values"].to(dtype=torch.bfloat16) - image_grid_thw = image_inputs["image_grid_thw"] + pixel_values, image_grid_thw = self.processor.preprocess(image_data) else: raise Exception("Unsupport input types: {} for {}".format(type(img), img)) - return pixel_values.to(dtype=self.get_dtype()), image_grid_thw + return pixel_values.to(dtype=self.data_type), image_grid_thw def load_model(self, weight_dir): diff --git a/lightllm/models/qwen2_vl/qwen2_visual.py b/lightllm/models/qwen2_vl/qwen2_visual.py index b08788a7e..467c1c59a 100644 --- a/lightllm/models/qwen2_vl/qwen2_visual.py +++ b/lightllm/models/qwen2_vl/qwen2_visual.py @@ -46,7 +46,6 @@ from lightllm.models.qwen2_vl.triton_kernel.rotary_pos_emb import apply_rotary_pos_emb_triton from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager -from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb # adapted from # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -136,12 +135,22 @@ def __init__(self, dim: int, num_heads: int = 16) -> None: self.num_heads = num_heads self.qkv = nn.Linear(dim, dim * 3, bias=True) self.proj = nn.Linear(dim, dim) + self.has_vllm = False + try: + from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb + + self.has_vllm = True + self.apply_rotary_emb = apply_rotary_emb + except ImportError: + print("Failed to import _flash_attn_forward from hopper.flash_attn_interface.") + self.has_vllm = False + self.apply_rotary_emb = apply_rotary_pos_emb_triton def apply_rotary_pos_emb_vision(self, t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: t_ = t.float() cos = freqs.cos() sin = freqs.sin() - output = apply_rotary_emb(t_, cos, sin).type_as(t) + output = self.apply_rotary_emb(t_, cos, sin).type_as(t) return output def forward( @@ -321,13 +330,11 @@ def load_image(self, img: List[ImageItem]): image_data = Image.open(BytesIO(image_data)) image_data = resize_image(image_data) pixel_values, image_grid_thw = self.processor.preprocess(image_data) - # pixel_values, image_grid_thw = tensor["pixel_values"], tensor["image_grid_thw"] elif isinstance(img, dict): image_data = read_shm(get_shm_name_data(img["uuid"])) image_data = Image.open(BytesIO(image_data)) image_data = resize_image(image_data) pixel_values, image_grid_thw = self.processor.preprocess(image_data) - # pixel_values, image_grid_thw = tensor["pixel_values"], tensor["image_grid_thw"] else: raise Exception("Unsupport input types: {} for {}".format(type(img), img)) return pixel_values.to(dtype=self.data_type), image_grid_thw diff --git a/lightllm/models/qwen2_vl/triton_kernel/rotary_pos_emb.py b/lightllm/models/qwen2_vl/triton_kernel/rotary_pos_emb.py index ab9013c8a..499211620 100644 --- a/lightllm/models/qwen2_vl/triton_kernel/rotary_pos_emb.py +++ b/lightllm/models/qwen2_vl/triton_kernel/rotary_pos_emb.py @@ -11,11 +11,13 @@ def rotate_half(x): return 
torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb_vision_ref(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: +def apply_rotary_pos_emb_vision_ref( + tensor: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, +) -> torch.Tensor: orig_dtype = tensor.dtype tensor = tensor.float() - cos = freqs.cos() - sin = freqs.sin() cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() output = (tensor * cos) + (rotate_half(tensor) * sin) @@ -76,15 +78,17 @@ def rotary_kernel( tl.store(out_ptr_, y, mask=mask) -def apply_rotary_pos_emb_triton(tensor: torch.Tensor, freqs: torch.Tensor, BLOCK_D: int = 128) -> torch.Tensor: - assert tensor.is_cuda and freqs.is_cuda +def apply_rotary_pos_emb_triton( + tensor: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, BLOCK_D: int = 128 +) -> torch.Tensor: + assert tensor.is_cuda and cos.is_cuda and sin.is_cuda if tensor.ndim != 4: raise RuntimeError("tensor shape should be [B, L, H, D]") orig_dtype = tensor.dtype x = tensor.float() - cos = freqs.cos().unsqueeze(1).repeat(1, 1, 2).view(freqs.size(0), -1).contiguous().float() - sin = freqs.sin().unsqueeze(1).repeat(1, 1, 2).view(freqs.size(0), -1).contiguous().float() + cos = cos.unsqueeze(1).repeat(1, 1, 2).view(cos.size(0), -1).contiguous().float() + sin = sin.unsqueeze(1).repeat(1, 1, 2).view(sin.size(0), -1).contiguous().float() B, L, H, D = x.shape HALF_D = D // 2 @@ -127,14 +131,14 @@ def test_accuracy_and_speed( x = torch.randn(B, L, H, D, device="cuda") # 误差 - y_ref = apply_rotary_pos_emb_vision_ref(x, freqs) - y_tri = apply_rotary_pos_emb_triton(x, freqs) + y_ref = apply_rotary_pos_emb_vision_ref(x, freqs.cos(), freqs.sin()) + y_tri = apply_rotary_pos_emb_triton(x, freqs.cos(), freqs.sin()) print("max abs error:", (y_ref - y_tri).abs().max().item()) # 预热 for _ in range(warmup): - apply_rotary_pos_emb_vision_ref(x, freqs) - apply_rotary_pos_emb_triton(x, freqs) + apply_rotary_pos_emb_vision_ref(x, freqs.cos(), freqs.sin()) + apply_rotary_pos_emb_triton(x, freqs.cos(), freqs.sin()) torch.cuda.synchronize() # 计时 @@ -143,7 +147,7 @@ def bench(fn): end = torch.cuda.Event(True) start.record() for _ in range(repeats): - fn(x, freqs) + fn(x, freqs.cos(), freqs.sin()) end.record() torch.cuda.synchronize() return start.elapsed_time(end) / repeats # ms @@ -152,5 +156,5 @@ def bench(fn): print(f"Triton : {bench(apply_rotary_pos_emb_triton):.3f} ms") -# if __name__ == "__main__": -# test_accuracy_and_speed() +if __name__ == "__main__": + test_accuracy_and_speed() diff --git a/lightllm/models/qwen2_vl/vision_process.py b/lightllm/models/qwen2_vl/vision_process.py index 0f8e8572b..ce0f7c771 100644 --- a/lightllm/models/qwen2_vl/vision_process.py +++ b/lightllm/models/qwen2_vl/vision_process.py @@ -5,7 +5,7 @@ from PIL import Image from typing import List, Optional, Union -from transformers.image_processing_utils import BaseImageProcessor, BatchFeature +from transformers.image_processing_utils import BaseImageProcessor from transformers.image_transforms import ( convert_to_rgb, resize, From 1742f3e8dfa4f4253b7f72e4633f2038574cf27f Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Wed, 13 Aug 2025 06:45:32 +0000 Subject: [PATCH 18/18] 0813-fix-visual-server --- lightllm/models/gemma3/gemma3_visual.py | 5 ++- lightllm/models/llava/llava_visual.py | 5 ++- lightllm/models/qwen2_5_vl/qwen2_5_visual.py | 41 +++++++++---------- lightllm/models/qwen2_vl/qwen2_visual.py | 8 ++-- .../qwen_vl/layer_infer/pre_layer_infer.py | 33 +++++++++------ 
lightllm/models/qwen_vl/qwen_visual.py | 5 +++ lightllm/models/tarsier2/tarsier2_visual.py | 5 +++ .../visualserver/model_infer/model_rpc.py | 10 ++--- 8 files changed, 64 insertions(+), 48 deletions(-) diff --git a/lightllm/models/gemma3/gemma3_visual.py b/lightllm/models/gemma3/gemma3_visual.py index b2f7a6b77..16e244c3e 100644 --- a/lightllm/models/gemma3/gemma3_visual.py +++ b/lightllm/models/gemma3/gemma3_visual.py @@ -16,7 +16,10 @@ class Gemma3VisionModel: - def __init__(self): + def __init__(self, kvargs): + self.weight_dir = kvargs["weight_dir"] + self.load_model(self.weight_dir) + self.cuda() pass def load_model(self, weight_dir): diff --git a/lightllm/models/llava/llava_visual.py b/lightllm/models/llava/llava_visual.py index 293bcd445..398bb66b6 100644 --- a/lightllm/models/llava/llava_visual.py +++ b/lightllm/models/llava/llava_visual.py @@ -15,7 +15,10 @@ class LlavaVisionModel: - def __init__(self): + def __init__(self, kvargs): + self.weight_dir = kvargs["weight_dir"] + self.load_model(self.weight_dir) + self.cuda() pass def load_model(self, weight_dir): diff --git a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py index 0a599475d..addd448a2 100644 --- a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py +++ b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py @@ -369,22 +369,6 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch. return hidden_states - def load_image(self, img: List[ImageItem]): - pixel_values = None - if isinstance(img, ImageItem): - image_data = read_shm(get_shm_name_data(img.uuid)) - image_data = Image.open(BytesIO(image_data)) - image_data = resize_image(image_data) - pixel_values, image_grid_thw = self.processor.preprocess(image_data) - elif isinstance(img, dict): - image_data = read_shm(get_shm_name_data(img["uuid"])) - image_data = Image.open(BytesIO(image_data)) - image_data = resize_image(image_data) - pixel_values, image_grid_thw = self.processor.preprocess(image_data) - else: - raise Exception("Unsupport input types: {} for {}".format(type(img), img)) - return pixel_values.to(dtype=self.data_type), image_grid_thw - def load_model(self, weight_dir): bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")] @@ -407,6 +391,22 @@ def load_model(self, weight_dir): self.load_state_dict(weight_dict) + def load_image(self, img: List[ImageItem]): + pixel_values = None + if isinstance(img, ImageItem): + image_data = read_shm(get_shm_name_data(img.uuid)) + image_data = Image.open(BytesIO(image_data)) + image_data = resize_image(image_data) + pixel_values, image_grid_thw = self.processor.preprocess(image_data) + elif isinstance(img, dict): + image_data = read_shm(get_shm_name_data(img["uuid"])) + image_data = Image.open(BytesIO(image_data)) + image_data = resize_image(image_data) + pixel_values, image_grid_thw = self.processor.preprocess(image_data) + else: + raise Exception("Unsupport input types: {} for {}".format(type(img), img)) + return pixel_values.to(dtype=self.data_type), image_grid_thw + def encode(self, images: List[ImageItem]): img_tensors = [] valid_ids = [] @@ -420,9 +420,7 @@ def encode(self, images: List[ImageItem]): image_data = read_shm(get_shm_name_data(img.uuid)) image_data = Image.open(BytesIO(image_data)) image_data = resize_image(image_data) - image_inputs = self.processor.preprocess(images=image_data, return_tensors="pt") - pixel_values = image_inputs["pixel_values"].to(dtype=torch.bfloat16) - image_grid_thw = image_inputs["image_grid_thw"] + 
pixel_values, image_grid_thw = self.processor.preprocess(image_data) img_tensors.append(pixel_values) img_grids.append(image_grid_thw) else: @@ -440,10 +438,9 @@ def encode(self, images: List[ImageItem]): imgs = torch.cat(img_tensors, dim=0) grid_thw = torch.cat(img_grids, dim=0) - pixel_values = imgs.cuda().to(dtype=torch.float32) - image_grid_thw = grid_thw.cuda() + pixel_values = imgs.to("cuda", dtype=self.data_type, non_blocking=True) + image_grid_thw = grid_thw.to("cuda", non_blocking=True) - pixel_values = pixel_values.type(self.get_dtype()) all_img_embeds = self.forward(pixel_values, grid_thw=image_grid_thw) return all_img_embeds, uuids, valid_ids diff --git a/lightllm/models/qwen2_vl/qwen2_visual.py b/lightllm/models/qwen2_vl/qwen2_visual.py index 467c1c59a..970fe78b4 100644 --- a/lightllm/models/qwen2_vl/qwen2_visual.py +++ b/lightllm/models/qwen2_vl/qwen2_visual.py @@ -352,9 +352,7 @@ def encode(self, images: List[ImageItem]): image_data = read_shm(get_shm_name_data(img.uuid)) image_data = Image.open(BytesIO(image_data)) image_data = resize_image(image_data) - tensor = self.processor.preprocess(images=image_data, return_tensors="pt") - pixel_values, image_grid_thw = tensor["pixel_values"], tensor["image_grid_thw"] - pixel_values = pixel_values.to(dtype=torch.bfloat16) + pixel_values, image_grid_thw = self.processor.preprocess(image_data) img_tensors.append(pixel_values) img_grids.append(image_grid_thw) else: @@ -372,8 +370,8 @@ def encode(self, images: List[ImageItem]): imgs = torch.cat(img_tensors, dim=0) grid_thw = torch.cat(img_grids, dim=0) - pixel_values = imgs.cuda().to(dtype=torch.float32) - image_grid_thw = grid_thw.cuda() + pixel_values = imgs.to("cuda", dtype=self.data_type, non_blocking=True) + image_grid_thw = grid_thw.to("cuda", non_blocking=True) all_img_embeds = self.forward(pixel_values, grid_thw=image_grid_thw) diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py index 2f5d07a22..e6d0dba2d 100644 --- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py @@ -33,20 +33,27 @@ def __init__(self, network_config, mode): return def _infer_image_embeds(self, infer_state, layer_weight): - if layer_weight.visual_model is None: - return image_weight = [] - for batch_id, p in enumerate(infer_state.multimodal_params): - for img in p["images"] + p["audios"]: - if img.get("_prefill_", True): - image_data = img["image_data"].to("cuda", non_blocking=True) - image_grid_thw = img["image_grid_thw"] - # image_embed = torch.zeros( - # (img["token_num"],layer_weight.wte_weight_.shape[1]),device="cuda",dtype=torch.bfloat16) - image_embed = layer_weight.visual_model.forward(image_data, image_grid_thw).view( - img["token_num"], -1 - ) - image_weight.append(image_embed) + if layer_weight.visual_model is None: + for batch_id, p in enumerate(infer_state.multimodal_params): + for img in p["images"] + p["audios"]: + # skip the same image + if img.get("_prefill_", True): + # pull the img_embeds by uid from shm + image_embed = read_shm(get_shm_name_embed(img["uuid"])) + image_weight.append(bytes2tensor(image_embed).cuda().reshape(img["token_num"], -1)) + else: + for batch_id, p in enumerate(infer_state.multimodal_params): + for img in p["images"] + p["audios"]: + if img.get("_prefill_", True): + image_data = img["image_data"].to("cuda", non_blocking=True) + image_grid_thw = img["image_grid_thw"] + # image_embed = torch.zeros( + # 
(img["token_num"],layer_weight.wte_weight_.shape[1]),device="cuda",dtype=torch.bfloat16) + image_embed = layer_weight.visual_model.forward(image_data, image_grid_thw).view( + img["token_num"], -1 + ) + image_weight.append(image_embed) if len(image_weight) > 0: image_weight = torch.cat(image_weight, dim=0) image_weight = image_weight / self.tp_world_size_ diff --git a/lightllm/models/qwen_vl/qwen_visual.py b/lightllm/models/qwen_vl/qwen_visual.py index f6468b144..b826b52d4 100644 --- a/lightllm/models/qwen_vl/qwen_visual.py +++ b/lightllm/models/qwen_vl/qwen_visual.py @@ -333,6 +333,7 @@ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): class QWenVisionTransformer(nn.Module): def __init__( self, + kvargs, image_size: int, patch_size: int, width: int, @@ -344,6 +345,7 @@ def __init__( **kwargs, ): super().__init__() + self.weight_dir = kvargs["weight_dir"] image_height, image_width = self.image_size = (image_size, image_size) patch_height, patch_width = self.patch_size = (patch_size, patch_size) self.grid_size = (image_height // patch_height, image_width // patch_width) @@ -388,6 +390,9 @@ def __init__( self.ln_post = norm_layer(output_dim) self.proj = nn.Parameter((output_dim ** -0.5) * torch.randn(output_dim, output_dim)) + self.load_model(self.weight_dir) + self.cuda() + def forward(self, x: torch.Tensor): x = x.to( dtype=self.transformer.get_cast_dtype(), diff --git a/lightllm/models/tarsier2/tarsier2_visual.py b/lightllm/models/tarsier2/tarsier2_visual.py index d7171dfe5..3da66423e 100644 --- a/lightllm/models/tarsier2/tarsier2_visual.py +++ b/lightllm/models/tarsier2/tarsier2_visual.py @@ -152,6 +152,7 @@ def forward(self, image_features, input_embeddings): class TarsierVisionTransformerPretrainedModel(nn.Module): def __init__( self, + kvargs, vision_config=None, text_config=None, ignore_index=-100, @@ -165,6 +166,7 @@ def __init__( **kwargs, ): super().__init__() + self.weight_dir = kvargs["weight_dir"] self.vision_tower = Qwen2VLTransformer(**vision_config) if projection_head == "Pixel_Shuffle": @@ -195,6 +197,9 @@ def __init__( self.image_token_index = image_token_index self.merge_size = 1 + self.load_model(self.weight_dir) + self.cuda() + def forward( self, pixel_values: torch.Tensor = None, diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 679ab7ffd..d27a31207 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -54,25 +54,23 @@ def exposed_init_model(self, kvargs): "max_batch_size": kvargs["max_batch_size"], } if self.model_type == "qwen": - self.model = QWenVisionTransformer(**model_cfg["visual"]).eval().bfloat16() + self.model = QWenVisionTransformer(kvargs, **model_cfg["visual"]).eval().bfloat16() elif self.model_type == "qwen2_vl": - self.model = Qwen2VLTransformer(**model_cfg["vision_config"]).eval().bfloat16() + self.model = Qwen2VLTransformer(kvargs, **model_cfg["vision_config"]).eval().bfloat16() elif self.model_type == "qwen2_5_vl": self.model = Qwen2_5VLTransformer(kvargs, **model_cfg["vision_config"]).eval().bfloat16() elif model_cfg["architectures"][0] == "TarsierForConditionalGeneration": self.model = TarsierVisionTransformerPretrainedModel(kvargs, **model_cfg).eval().bfloat16() elif self.model_type == "llava": - self.model = LlavaVisionModel() + self.model = LlavaVisionModel(kvargs) elif self.model_type == "internvl_chat": self.model = VisionTransformer(kvargs) # self.model = 
InternVLVisionModel() elif self.model_type == "gemma3": - self.model = Gemma3VisionModel() + self.model = Gemma3VisionModel(kvargs) else: raise Exception(f"can not support {self.model_type} now") - self.model.load_model(weight_dir) - self.model = self.model.cuda() except Exception as e: print("#" * 16) print("load model error:", str(e), e, type(e))
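
Both vision attention modules above now resolve the rotary kernel at construction time — vllm's fused apply_rotary_emb when it imports, the in-repo Triton kernel otherwise — and call it with (tensor, cos, sin) now that cos and sin are evaluated outside the kernel. The snippet below is a plain-PyTorch rendering of the reference path the Triton test compares against (apply_rotary_pos_emb_vision_ref); the function name apply_rotary_ref and the random demo shapes are mine, not part of the patch.

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # GPT-NeoX style rotation: negate the second half and swap it in front of the first
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_ref(t: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # cos/sin arrive with shape (seq, head_dim // 2); duplicate them across the two
    # halves, broadcast over batch and heads, and compute in fp32 for stability
    cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()   # (1, seq, 1, head_dim)
    sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
    return (t.float() * cos + rotate_half(t.float()) * sin).type_as(t)

q = torch.randn(1, 6, 2, 8, dtype=torch.float16)   # (batch, seq, heads, head_dim)
freqs = torch.randn(6, 4)                          # (seq, head_dim // 2)
q_rot = apply_rotary_ref(q, freqs.cos(), freqs.sin())
print(q_rot.shape, q_rot.dtype)                    # torch.Size([1, 6, 2, 8]) torch.float16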
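
rot_pos_emb builds a (row, column) index for every patch and orders the indices so that the spatial_merge_size x spatial_merge_size patches of one merge window sit next to each other; the patch collapses the earlier per-axis reshape/permute chain into a single expression per axis. Below is a standalone sketch of that indexing (the helper name and the 4x4 example grid are illustrative); as in the patched loop, the temporal count in grid_thw is not repeated.

import torch

def vision_rot_pos_ids(grid_thw: torch.Tensor, spatial_merge_size: int = 2) -> torch.Tensor:
    """(row, col) index per patch, ordered so each merge window's patches are adjacent."""
    s = spatial_merge_size
    pos_ids = []
    for _, h, w in grid_thw.tolist():
        pos_shape = (h // s, s, w // s, s)
        hpos = torch.arange(h).unsqueeze(1).expand(-1, w)   # row index of every patch
        wpos = torch.arange(w).unsqueeze(0).expand(h, -1)   # column index of every patch
        # group into (H/s, s, W/s, s) blocks, then interleave so the s*s patches of one
        # merge window become consecutive entries after flatten()
        hpos = hpos.reshape(pos_shape).permute(0, 2, 1, 3).flatten()
        wpos = wpos.reshape(pos_shape).permute(0, 2, 1, 3).flatten()
        pos_ids.append(torch.stack([hpos, wpos], dim=-1))
    return torch.cat(pos_ids, dim=0)

ids = vision_rot_pos_ids(torch.tensor([[1, 4, 4]]))
print(ids[:4].tolist())  # [[0, 0], [0, 1], [1, 0], [1, 1]] -- the first 2x2 merge window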
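
cu_seqlens and max_seqlen are now derived once from grid_thw before the block loop (plus max_window_seqlen for the windowed layers) and handed to every attention call, rather than recomputed inside each layer. Here is a minimal CPU version of that computation; the helper name and the example grid are invented for illustration.

import torch
import torch.nn.functional as F

def build_cu_seqlens(grid_thw: torch.Tensor):
    """Per-frame token counts -> cumulative boundaries for varlen flash attention."""
    # each (t, h, w) row contributes t frames of h * w patch tokens
    seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
    cu_seqlens = seqlens.cumsum(dim=0, dtype=torch.int32)
    cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)          # prepend 0 so diffs give lengths
    max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())
    return cu_seqlens, max_seqlen

# one 4x6 image and one 2-frame 8x8 clip (h, w already given in patch units)
grid_thw = torch.tensor([[1, 4, 6], [2, 8, 8]])
cu_seqlens, max_seqlen = build_cu_seqlens(grid_thw)
print(cu_seqlens.tolist(), max_seqlen)                       # [0, 24, 88, 152] 64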
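
_init_datatype maps the data_type string taken from kvargs onto a torch dtype. The same mapping written as a lookup table — an equivalent sketch, not the code in the patch.

import torch

_DTYPE_ALIASES = {
    "fp16": torch.float16, "float16": torch.float16,
    "bf16": torch.bfloat16, "bfloat16": torch.bfloat16,
    "fp32": torch.float32, "float32": torch.float32,
}

def resolve_dtype(data_type):
    """Accept either a torch.dtype or one of the alias strings used in kvargs."""
    if isinstance(data_type, torch.dtype):
        return data_type
    try:
        return _DTYPE_ALIASES[data_type]
    except KeyError:
        raise ValueError(f"Unsupported datatype {data_type}!") from None

print(resolve_dtype("bf16"))  # torch.bfloat16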
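
After the visual-server fix, every vision model follows one constructor contract: it receives the kvargs dict, pulls weight_dir (and optionally data_type) from it, and finishes load_model() plus the move to the GPU inside __init__, which is why model_rpc.py drops its explicit load_model/.cuda() calls. A toy model illustrating that contract; ToyVisionModel and its config values are placeholders, not lightllm classes.

import torch
import torch.nn as nn

class ToyVisionModel(nn.Module):
    """Placeholder illustrating the constructor contract, not a real ViT."""

    def __init__(self, kvargs: dict, **model_cfg):
        super().__init__()
        self.weight_dir = kvargs["weight_dir"]
        self.data_type = kvargs.get("data_type", "bfloat16")
        self.proj = nn.Linear(model_cfg.get("hidden_size", 16), 16)
        self.load_model(self.weight_dir)
        if torch.cuda.is_available():        # the real models call self.cuda() unconditionally
            self.cuda()

    def load_model(self, weight_dir: str):
        # a real implementation reads the weight files found under weight_dir
        print(f"loading weights from {weight_dir}")

model = ToyVisionModel({"weight_dir": "/path/to/vit-weights", "data_type": "bfloat16"}, hidden_size=32)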
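
_infer_image_embeds now branches on whether the layer weights carry an in-process visual model: without one, the embedding bytes the visual server wrote to shared memory are read back (read_shm / get_shm_name_embed / bytes2tensor) and reshaped to (token_num, hidden); with one, the ViT forward runs directly on the image data. A simplified CPU-only sketch of that branch; gather_image_embeds and the fake reader are stand-ins for the lightllm shared-memory helpers.

import torch

def gather_image_embeds(multimodal_params, visual_model=None, read_embed_from_shm=None):
    """Collect one embedding tensor per image that still needs prefill."""
    embeds = []
    for params in multimodal_params:
        for img in params["images"] + params["audios"]:
            if not img.get("_prefill_", True):
                continue  # this image was already injected in an earlier prefill
            if visual_model is None:
                # separate visual server: embedding bytes already sit in shared memory
                embeds.append(read_embed_from_shm(img["uuid"]).reshape(img["token_num"], -1))
            else:
                # in-process visual model: run the ViT on the raw pixel data now
                emb = visual_model.forward(img["image_data"], img["image_grid_thw"])
                embeds.append(emb.view(img["token_num"], -1))
    return embeds

# toy usage with a fake shared-memory reader and one pending image of 4 tokens
def fake_reader(uuid):
    return torch.arange(4 * 8, dtype=torch.float32)

fake_params = [{"images": [{"uuid": "img-0", "_prefill_": True, "token_num": 4}], "audios": []}]
print(gather_image_embeds(fake_params, read_embed_from_shm=fake_reader)[0].shape)  # torch.Size([4, 8])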