diff --git a/lightllm/models/internvl/model.py b/lightllm/models/internvl/model.py index a724d5668..d1d269436 100644 --- a/lightllm/models/internvl/model.py +++ b/lightllm/models/internvl/model.py @@ -64,6 +64,7 @@ def init_imageitem_extral_params( img.extra_params["image_patch_max_num"] = 6 elif num_images > 6: img.extra_params["image_patch_max_num"] = 0 + img.patch_num = self.get_image_patch(img) return def init_audioitem_extral_params( @@ -71,14 +72,14 @@ def init_audioitem_extral_params( ): return - def get_image_token_length(self, img: ImageItem): - return ( - self.get_image_patch_func( - img.image_w, img.image_h, max_num=img.extra_params["image_patch_max_num"], use_thumbnail=True - ) - * self.image_length + def get_image_patch(self, img: ImageItem): + return self.get_image_patch_func( + img.image_w, img.image_h, max_num=img.extra_params["image_patch_max_num"], use_thumbnail=True ) + def get_image_token_length(self, img: ImageItem): + return self.get_image_patch(img) * self.image_length + def get_audio_token_length(self, audio: AudioItem): L = audio.audio_length L = L if L <= 480000 else 480000 # max_length < 30s diff --git a/lightllm/models/qwen2_vl/triton_kernel/rotary_pos_emb.py b/lightllm/models/qwen2_vl/triton_kernel/rotary_pos_emb.py index 07e7c8b3f..0ae8099c6 100644 --- a/lightllm/models/qwen2_vl/triton_kernel/rotary_pos_emb.py +++ b/lightllm/models/qwen2_vl/triton_kernel/rotary_pos_emb.py @@ -1,7 +1,6 @@ -import math -import torch import triton import triton.language as tl +import torch @triton.jit @@ -17,46 +16,72 @@ def rotary_kernel( stride_cos_d, stride_sin_l, stride_sin_d, - D: tl.constexpr, - HALF_D: tl.constexpr, + L, + H, + D, + BLOCK_SEQ: tl.constexpr, + BLOCK_HEAD: tl.constexpr, BLOCK_D: tl.constexpr, ): - pid_h = tl.program_id(0).to(tl.int64) - pid_l = tl.program_id(1).to(tl.int64) - pid_blk = tl.program_id(2).to(tl.int64) + pid_head_blk = tl.program_id(0) + pid_seq_blk = tl.program_id(1) + offs_h = pid_head_blk * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) + offs_l = pid_seq_blk * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) offs_d = tl.arange(0, BLOCK_D) - d = pid_blk * BLOCK_D + offs_d - mask = d < D - base = pid_l * stride_l + pid_h * stride_h + offs_h = offs_h.to(tl.int64) + offs_l = offs_l.to(tl.int64) + offs_d = offs_d.to(tl.int64) + + mask_h = offs_h < H + mask_l = offs_l < L + mask_d = offs_d < D + + HALF_D = D // 2 + + l_b = offs_l[:, None, None] + h_b = offs_h[None, :, None] + d_b = offs_d[None, None, :] + + mask = mask_l[:, None, None] & mask_h[None, :, None] & mask_d[None, None, :] + + base = l_b * stride_l + h_b * stride_h + d_b * stride_d + x = tl.load(inp_ptr + base, mask=mask, other=0.0) - in_ptr = inp_ptr + base + d * stride_d - cos_ptr_ = cos_ptr + pid_l * stride_cos_l + d - sin_ptr_ = sin_ptr + pid_l * stride_sin_l + d + cos_base_2d = offs_l[:, None] * stride_cos_l + offs_d[None, :] * stride_cos_d + sin_base_2d = offs_l[:, None] * stride_sin_l + offs_d[None, :] * stride_sin_d + mask_ld = mask_l[:, None] & mask_d[None, :] - x = tl.load(in_ptr, mask=mask) - cos = tl.load(cos_ptr_, mask=mask) - sin = tl.load(sin_ptr_, mask=mask) + cos_2d = tl.load(cos_ptr + cos_base_2d, mask=mask_ld, other=0.0) + sin_2d = tl.load(sin_ptr + sin_base_2d, mask=mask_ld, other=0.0) - partner_d = tl.where(d < HALF_D, d + HALF_D, d - HALF_D) - partner_ptr = inp_ptr + base + partner_d * stride_d - partner_val = tl.load(partner_ptr, mask=mask) - rotated = tl.where(d < HALF_D, -partner_val, partner_val) + cos = cos_2d[:, None, :] + sin = sin_2d[:, None, :] + + partner_d = tl.where(offs_d < 
HALF_D, offs_d + HALF_D, offs_d - HALF_D) + partner_d_b = partner_d[None, None, :] + + partner_base = l_b * stride_l + h_b * stride_h + partner_d_b * stride_d + partner_val = tl.load(inp_ptr + partner_base, mask=mask, other=0.0) + + rotated = tl.where(d_b < HALF_D, -partner_val, partner_val) y = x * cos + rotated * sin - out_ptr_ = out_ptr + base + d - tl.store(out_ptr_, y, mask=mask) + tl.store(out_ptr + base, y, mask=mask) def apply_rotary_pos_emb_triton( - tensor: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, BLOCK_D: int = 128 + tensor: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, ) -> torch.Tensor: assert tensor.is_cuda and cos.is_cuda and sin.is_cuda assert cos.is_contiguous() and sin.is_contiguous() if tensor.ndim != 3: raise RuntimeError("tensor shape should be [L, H, D]") + orig_dtype = tensor.dtype x = tensor.float() @@ -64,10 +89,21 @@ def apply_rotary_pos_emb_triton( sin = sin.repeat(1, 2).view(sin.size(0), -1).contiguous().float() L, H, D = x.shape - HALF_D = D // 2 y = torch.empty_like(x) - grid = (H, L, triton.cdiv(D, BLOCK_D)) + BLOCK_SEQ = 16 + BLOCK_HEAD = 4 + BLOCK_D = triton.next_power_of_2(D) + + if D >= 128: + num_warps = 8 + else: + num_warps = 4 + + grid = ( + triton.cdiv(H, BLOCK_HEAD), + triton.cdiv(L, BLOCK_SEQ), + ) rotary_kernel[grid]( inp_ptr=x, @@ -81,9 +117,13 @@ def apply_rotary_pos_emb_triton( stride_cos_d=cos.stride(1), stride_sin_l=sin.stride(0), stride_sin_d=sin.stride(1), + L=L, + H=H, D=D, - HALF_D=HALF_D, + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_HEAD=BLOCK_HEAD, BLOCK_D=BLOCK_D, + num_warps=num_warps, ) return y.to(orig_dtype) diff --git a/lightllm/models/qwen2_vl/vision_process.py b/lightllm/models/qwen2_vl/vision_process.py index 95622fc02..90601f759 100644 --- a/lightllm/models/qwen2_vl/vision_process.py +++ b/lightllm/models/qwen2_vl/vision_process.py @@ -22,6 +22,11 @@ ) from torchvision.transforms.v2 import functional as F +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + IMAGE_FACTOR = 28 MIN_PIXELS = 4 * 28 * 28 MAX_PIXELS = 16384 * 28 * 28 @@ -160,9 +165,19 @@ def rescale_and_normalize( return images + @torch.inference_mode() def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]: + try: + return self._preprocess_bydevice(image, device="cuda") + except Exception as e: + logger.warning(f"Exception during image preprocessing on CUDA: {str(e)}") + torch.cuda.current_stream().synchronize() + return self._preprocess_bydevice(image, device="cpu") + + def _preprocess_bydevice(self, image, device="cuda") -> Tuple[torch.Tensor, torch.Tensor]: image_arr = np.asarray(image, dtype=np.uint8) - image_data = torch.from_numpy(image_arr).permute(2, 0, 1).contiguous().to("cuda", non_blocking=True) + image_data = torch.from_numpy(image_arr).permute(2, 0, 1).contiguous().to(device=device, non_blocking=True) + grouped_images, grouped_images_index = group_images_by_shape( [image_data], disable_grouping=self.disable_grouping ) @@ -183,27 +198,39 @@ def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]: interpolation=self.interpolation, ) resized_images_grouped[shape] = stacked_images + + grouped_images = None resized_images = reorder_images(resized_images_grouped, grouped_images_index) + resized_images_grouped = None - # Group images by size for further processing - # Needed in case do_resize is False, or resize returns images with different sizes grouped_images, grouped_images_index = group_images_by_shape( resized_images, disable_grouping=self.disable_grouping ) + resized_images = None + 
processed_images_grouped = {} processed_grids = {} + for shape, stacked_images in grouped_images.items(): + stacked_images = stacked_images.to("cuda", non_blocking=True) + resized_height, resized_width = stacked_images.shape[-2:] - # Fused rescale and normalize + patches = self.rescale_and_normalize( - stacked_images, self.do_rescale, self.rescale_factor, self.do_normalize, self.image_mean, self.image_std + stacked_images, + self.do_rescale, + self.rescale_factor, + self.do_normalize, + self.image_mean, + self.image_std, ) if patches.ndim == 4: - # add a temporal dimension if we have images patches = patches.unsqueeze(1) + if patches.shape[1] % self.temporal_patch_size != 0: repeats = patches[:, -1:].repeat(1, self.temporal_patch_size - 1, 1, 1, 1) patches = torch.cat([patches, repeats], dim=1) + batch_size, grid_t, channel = patches.shape[:3] grid_t = grid_t // self.temporal_patch_size grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size @@ -224,8 +251,7 @@ def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]: .permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9) .contiguous() ) - # Reorder dimensions to group grid and patch information for subsequent flattening. - # (batch, grid_t, grid_h, grid_w, merge_h, merge_w, channel, temp_patch_size, patch_h, patch_w) + flatten_patches = patches.view( batch_size, grid_t * grid_h * grid_w, @@ -235,9 +261,12 @@ def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]: processed_images_grouped[shape] = flatten_patches processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size + grouped_images = None + processed_images = reorder_images(processed_images_grouped, grouped_images_index) processed_grids = reorder_images(processed_grids, grouped_images_index) - pixel_values = torch.cat(processed_images, dim=0) # (num_patches_total, C*T*ps*ps) + + pixel_values = torch.cat(processed_images, dim=0) image_grid_thw = torch.as_tensor(processed_grids) return pixel_values, image_grid_thw diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py index b5b31a413..dd6585d60 100644 --- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py @@ -1,3 +1,5 @@ +import rpyc +import socket import torch import torch.distributed as dist @@ -6,9 +8,10 @@ from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer from lightllm.utils.infer_utils import mark_cost_time -from lightllm.server.embed_cache.utils import bytes2tensor, read_shm, get_shm_name_embed +from lightllm.server.embed_cache.utils import bytes2tensor, read_shm, get_shm_name_embed, read_afs from lightllm.common.basemodel.triton_kernel.multimodal_emb import multimodal_emb from lightllm.distributed.communication_op import all_reduce +from lightllm.utils.envs_utils import get_env_start_args """ @@ -29,6 +32,9 @@ class LlamaMultimodalPreLayerInfer(LlamaPreLayerInfer): def __init__(self, network_config, mode): super().__init__(network_config, mode) + self.args = get_env_start_args() + self.cache_client = rpyc.connect("localhost", self.args.cache_port, config={"allow_pickle": True}) + self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) return def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): @@ -50,9 +56,13 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei # skip the same image if img["token_id"] 
in img_start_token_ids or img["_prefill_"] is False: continue - # pull the img_embeds by uid from shm - data = read_shm(get_shm_name_embed(img["uuid"])) - img_weight.append(bytes2tensor(data).cuda().reshape(img["token_num"], -1)) + # pull the img_embeds by uid from shm or afs + if self.args.enable_remote_vit: + embed = read_afs(get_shm_name_embed(img["uuid"]), self.args.image_embed_dir) + else: + embed = read_shm(get_shm_name_embed(img["uuid"])) + self.cache_client.root.release([img["uuid"]]) + img_weight.append(bytes2tensor(embed).cuda().reshape(img["token_num"], -1)) img_start_token_ids.append(img["token_id"]) img_token_lens.append(img["token_num"]) img_start_locs.append(img_start_loc) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index ec11f8f1d..c38576a63 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -7,7 +7,7 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--run_mode", type=str, - choices=["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "pd_master", "config_server"], + choices=["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "pd_master", "config_server", "visual"], default="normal", help="""set run mode, normal is started for a single server, prefill decode pd_master is for pd split run mode, config_server is for pd split mode used to register pd_master node, and get pd_master node list, @@ -516,6 +516,41 @@ def make_argument_parser() -> argparse.ArgumentParser: default=0.03, help="""The interval of the schedule time, default is 30ms.""", ) + parser.add_argument( + "--image_embed_dir", + type=str, + default=None, + help="path for vit embed", + ) + parser.add_argument( + "--enable_remote_vit", + action="store_true", + help="Whether to enable remote vit for multimodal service.", + ) + parser.add_argument( + "--remote_vit_port", + type=int, + default=12346, + help="The port number for the remote vit service.", + ) + # redis for vit llm disaggregation + parser.add_argument( + "--redis_port", + type=int, + default=6379, + help="The port number for the redis service in config_server mode.", + ) + parser.add_argument( + "--redis_evict_fraction", + type=float, + default=0.3, + help="The evict fraction for the redis service in config_server mode.", + ) + parser.add_argument( + "--start_redis", + action="store_true", + help="Whether to start the redis service in config_server mode.", + ) parser.add_argument( "--enable_cpu_cache", action="store_true", diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index 8bda50fb7..5b7aab501 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -42,8 +42,11 @@ from lightllm.server.core.objs import StartArgs from .multimodal_params import MultimodalParams from .httpserver.manager import HttpServerManager +from .visualserver.manager import VisualManager from .httpserver_for_pd_master.manager import HttpServerManagerForPDMaster -from .api_lightllm import lightllm_get_score + +# from .visualserver.manager import VisualManager +from .api_lightllm import lightllm_get_score, lightllm_get_image_embedding from lightllm.utils.envs_utils import get_env_start_args, get_lightllm_websocket_max_message_size from lightllm.utils.log_utils import init_logger from lightllm.utils.error_utils import ServerBusyError @@ -71,6 +74,7 @@ class G_Objs: g_generate_func: Callable = None g_generate_stream_func: Callable = None httpserver_manager: Union[HttpServerManager, HttpServerManagerForPDMaster] = None + visual_manager: 
VisualManager = None shared_token_load: TokenLoad = None def set_args(self, args: StartArgs): @@ -92,6 +96,8 @@ def set_args(self, args: StartArgs): self.httpserver_manager = HttpServerManagerForPDMaster( args=args, ) + elif args.run_mode == "visual": + self.metric_client = MetricClient(args.metric_port) else: init_tokenizer(args) # for openai api SamplingParams.load_generation_cfg(args.model_dir) @@ -134,7 +140,7 @@ def get_model_name(): @app.get("/health", summary="Check server health") @app.head("/health", summary="Check server health") async def healthcheck(request: Request): - if g_objs.args.run_mode == "pd_master": + if g_objs.args.run_mode in ["pd_master", "visual"]: return JSONResponse({"message": "Ok"}, status_code=200) if os.environ.get("DEBUG_HEALTHCHECK_RETURN_FAIL") == "true": @@ -219,6 +225,18 @@ async def get_score(request: Request) -> Response: return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e)) +@app.post("/get_image_embedding") +async def get_image_embed(request: Request) -> Response: + try: + return await lightllm_get_image_embedding(request, g_objs.httpserver_manager) + except ServerBusyError as e: + logger.error("%s", str(e), exc_info=True) + return create_error_response(HTTPStatus.SERVICE_UNAVAILABLE, str(e)) + except Exception as e: + logger.error("An error occurred: %s", str(e), exc_info=True) + return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e)) + + @app.post("/") async def compat_generate(request: Request) -> Response: if get_env_start_args().run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"]: @@ -357,6 +375,8 @@ async def startup_event(): logger.info("server start up") loop = asyncio.get_event_loop() g_objs.set_args(get_env_start_args()) + if g_objs.httpserver_manager is None: + return loop.create_task(g_objs.httpserver_manager.handle_loop()) logger.info(f"server start up ok, loop use is {asyncio.get_event_loop()}") return diff --git a/lightllm/server/api_lightllm.py b/lightllm/server/api_lightllm.py index d3592a5f5..ecf113f38 100644 --- a/lightllm/server/api_lightllm.py +++ b/lightllm/server/api_lightllm.py @@ -5,6 +5,8 @@ from lightllm.server.core.objs.sampling_params import SamplingParams from .multimodal_params import MultimodalParams from .httpserver.manager import HttpServerManager +from .visualserver.manager import VisualManager +from fastapi.responses import JSONResponse import ujson as json @@ -150,3 +152,19 @@ async def stream_results() -> AsyncGenerator[bytes, None]: background_tasks = BackgroundTasks() return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks) + + +async def lightllm_get_image_embedding(request: Request, httpserver_manager: HttpServerManager) -> Response: + request_dict = await request.json() + # request_dict: {'parameters': {'max_new_tokens': 128}, + # 'multimodal_params': {'images': [{'type': 'base64', 'data': 'base64'}]}} + sample_params_dict = request_dict["parameters"] + sampling_params = SamplingParams() + sampling_params.init(tokenizer=None, **sample_params_dict) + sampling_params.verify() + multimodal_params_dict = request_dict.get("multimodal_params", {}) + multimodal_params = MultimodalParams(**multimodal_params_dict) + + await httpserver_manager.get_image_embeding(sampling_params, multimodal_params, request=request) + + return JSONResponse({"message": "OK"}, status_code=200) diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py index b4447d808..c6700c041 100755 --- a/lightllm/server/api_server.py +++ 
b/lightllm/server/api_server.py @@ -5,11 +5,13 @@ torch.multiprocessing.set_start_method("spawn") # this code will not be ok for settings to fork to subprocess parser = make_argument_parser() args = parser.parse_args() - from .api_start import pd_master_start, normal_or_p_d_start, config_server_start + from .api_start import pd_master_start, normal_or_p_d_start, visual_start, config_server_start if args.run_mode == "pd_master": pd_master_start(args) elif args.run_mode == "config_server": config_server_start(args) + elif args.run_mode == "visual": + visual_start(args) else: normal_or_p_d_start(args) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index f73be30db..a4db51c2e 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -5,7 +5,7 @@ import subprocess import signal from lightllm.utils.net_utils import alloc_can_use_network_port, PortLocker -from lightllm.utils.start_utils import process_manager, kill_recursive +from lightllm.utils.start_utils import process_manager, kill_recursive, is_multimodal_mode from .metrics.manager import start_metric_manager from .embed_cache.manager import start_cache_manager from lightllm.utils.log_utils import init_logger @@ -15,6 +15,7 @@ from .router.manager import start_router_process from lightllm.utils.process_check import is_process_active from lightllm.utils.multinode_utils import send_and_receive_node_ip +from lightllm.utils.redis_utils import start_redis_service from lightllm.utils.shm_size_check import check_recommended_shm_size logger = init_logger(__name__) @@ -56,11 +57,12 @@ def signal_handler(sig, frame): signal.signal(signal.SIGINT, signal_handler) logger.info(f"start process pid {os.getpid()}") - logger.info(f"http server pid {http_server_process.pid}") + if http_server_process: + logger.info(f"http server pid {http_server_process.pid}") return -def normal_or_p_d_start(args): +def check_and_set_args(args): from lightllm.server.core.objs.start_args_type import StartArgs args: StartArgs = args @@ -75,7 +77,7 @@ def normal_or_p_d_start(args): enable_mps() - if args.run_mode not in ["normal", "prefill", "decode", "nixl_prefill", "nixl_decode"]: + if args.run_mode not in ["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "visual"]: return if args.enable_cpu_cache: @@ -150,6 +152,7 @@ def normal_or_p_d_start(args): assert args.mtp_draft_model_dir is None assert args.mtp_step == 0 + args.enable_multimodal = is_multimodal_mode(args) # 检查GPU数量是否足够 if args.visual_gpu_ids is None: args.visual_gpu_ids = list(range(args.visual_dp * args.visual_tp)) @@ -162,13 +165,11 @@ def normal_or_p_d_start(args): args.visual_gpu_ids = args.visual_gpu_ids[:total_required_gpus] # 检查visual_nccl_port数量是否足够 - if len(args.visual_nccl_ports) < args.visual_dp: + if args.visual_nccl_ports is not None and len(args.visual_nccl_ports) < args.visual_dp: raise ValueError( f"Not enough visual_nccl_ports specified. You need at least {args.visual_dp}, " f"but got ({len(args.visual_nccl_ports)})." 
) - else: - args.visual_nccl_ports = args.visual_nccl_ports[: args.visual_dp] if args.visual_dp <= 0: raise ValueError("visual_dp must be a positive integer.") @@ -214,9 +215,13 @@ def normal_or_p_d_start(args): args.data_type = get_dtype(args.model_dir) assert args.data_type in ["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"] - already_uesd_ports = args.visual_nccl_ports + [args.nccl_port, args.port] + +def normal_or_p_d_start(args): + + check_and_set_args(args) + already_uesd_ports = [args.nccl_port, args.port] if args.run_mode == "decode": - already_uesd_ports = args.visual_nccl_ports + [args.nccl_port, args.port, args.pd_decode_rpyc_port] + already_uesd_ports = [args.nccl_port, args.port, args.pd_decode_rpyc_port] # 提前锁定端口,防止在单个机器上启动多个实列的时候,要到模型启动的时候才能 # 捕获到端口设置冲突的问题 @@ -225,7 +230,7 @@ def normal_or_p_d_start(args): node_world_size = args.tp // args.nnodes can_use_ports = alloc_can_use_network_port( - num=8 + node_world_size + args.visual_dp * args.visual_tp, used_nccl_ports=already_uesd_ports + num=8 + node_world_size + args.visual_dp * args.visual_tp + args.visual_dp, used_nccl_ports=already_uesd_ports ) logger.info(f"alloced ports: {can_use_ports}") ( @@ -246,6 +251,9 @@ def normal_or_p_d_start(args): can_use_ports = can_use_ports[args.visual_tp :] visual_model_tp_ports.append(tp_ports_for_dp) + args.visual_nccl_ports = can_use_ports[0 : args.visual_dp] + can_use_ports = can_use_ports[args.visual_dp :] + # 将申请好的端口放入args参数中 args.router_port = router_port args.detokenization_port = detokenization_port @@ -273,7 +281,6 @@ def normal_or_p_d_start(args): logger.info(f"all start args:{args}") ports_locker.release_port() - if args.enable_multimodal: from .visualserver.manager import start_visual_process @@ -283,15 +290,6 @@ def normal_or_p_d_start(args): ], start_args=[(args,)], ) - process_manager.start_submodule_processes( - start_funcs=[ - start_visual_process, - ], - start_args=[ - (args, visual_model_tp_ports), - ], - ) - if args.enable_multimodal_audio: from .audioserver.manager import start_audio_process @@ -303,6 +301,15 @@ def normal_or_p_d_start(args): (args,), ], ) + if not args.enable_remote_vit: + process_manager.start_submodule_processes( + start_funcs=[ + start_visual_process, + ], + start_args=[ + (args, visual_model_tp_ports), + ], + ) if args.enable_cpu_cache: from .multi_level_kv_cache.manager import start_multi_level_kv_cache_manager @@ -431,6 +438,70 @@ def pd_master_start(args): http_server_process.wait() +def visual_start(args): + check_and_set_args(args) + already_uesd_ports = args.visual_nccl_ports + [args.nccl_port, args.remote_vit_port] + can_use_ports = alloc_can_use_network_port( + num=5 + args.visual_dp * args.visual_tp, used_nccl_ports=already_uesd_ports + ) + logger.info(f"alloced ports: {can_use_ports}") + ( + router_port, + visual_port, + audio_port, + cache_port, + metric_port, + ) = can_use_ports[0:5] + can_use_ports = can_use_ports[5:] + + visual_model_tp_ports = [] + for _ in range(args.visual_dp): + tp_ports_for_dp = can_use_ports[0 : args.visual_tp] + can_use_ports = can_use_ports[args.visual_tp :] + visual_model_tp_ports.append(tp_ports_for_dp) + + # 将申请好的端口放入args参数中 + args.router_port = router_port + args.visual_port = visual_port + args.audio_port = audio_port + args.cache_port = cache_port + args.metric_port = metric_port + args.visual_model_rpc_ports = visual_model_tp_ports + + # 远程vit server 需要一个唯一的id + args.visual_node_id = uuid.uuid4().int + + logger.info(f"all start args:{args}") + + set_env_start_args(args) + + from 
.visualserver.manager import start_visual_process
+
+    process_manager.start_submodule_processes(
+        start_funcs=[
+            start_cache_manager,
+        ],
+        start_args=[(args,)],
+    )
+    process_manager.start_submodule_processes(
+        start_funcs=[
+            start_visual_process,
+        ],
+        start_args=[
+            (args, visual_model_tp_ports),
+        ],
+    )
+    setup_signal_handlers(None, process_manager)
+    try:
+        while True:
+            time.sleep(1)
+    except KeyboardInterrupt:
+        logger.info("Received keyboard interrupt, shutting down...")
+        process_manager.terminate_all_processes()
+        logger.info("All processes have been terminated gracefully.")
+        sys.exit(0)
+
+
 def config_server_start(args):
     set_unique_server_name(args)
     if args.run_mode != "config_server":
@@ -438,6 +509,9 @@ def config_server_start(args):
 
     logger.info(f"all start args:{args}")
 
+    if args.start_redis:
+        start_redis_service(args)
+
     set_env_start_args(args)
 
     command = [
diff --git a/lightllm/server/config_server/api_http.py b/lightllm/server/config_server/api_http.py
index 97551c6b6..5eb76ec08 100644
--- a/lightllm/server/config_server/api_http.py
+++ b/lightllm/server/config_server/api_http.py
@@ -9,6 +9,7 @@ from typing import Dict, List
 from fastapi.responses import JSONResponse
 from lightllm.utils.log_utils import init_logger
+from lightllm.server.visualserver.vit_connect import VIT_Obj
 from ..pd_io_struct import PD_Master_Obj
 from .nccl_tcp_store import start_tcp_store_server
 from lightllm.utils.envs_utils import get_env_start_args, get_unique_server_name
@@ -19,7 +20,9 @@
 app = FastAPI()
 
 registered_pd_master_objs: Dict[str, PD_Master_Obj] = {}
+registered_visual_server_objs: Dict[str, VIT_Obj] = {}
 registered_pd_master_obj_lock = Lock()
+registered_visual_server_obj_lock = Lock()
 
 global_req_id = 0
 global_req_id_lock = Lock()
@@ -72,6 +75,30 @@ async def websocket_endpoint(websocket: WebSocket):
     return
 
 
+@app.websocket("/visual_register")
+async def visual_websocket_endpoint(websocket: WebSocket):
+    await websocket.accept()
+    client_ip, client_port = websocket.client
+    logger.info(f"ws connected from IP: {client_ip}, Port: {client_port}")
+    registered_visual_server_obj: VIT_Obj = pickle.loads(await websocket.receive_bytes())
+    logger.info(f"received registered_visual_server_obj {registered_visual_server_obj}")
+    with registered_visual_server_obj_lock:
+        registered_visual_server_objs[registered_visual_server_obj.node_id] = registered_visual_server_obj
+
+    try:
+        while True:
+            data = await websocket.receive_text()
+            assert data == "heartbeat"
+    except (WebSocketDisconnect, Exception, RuntimeError) as e:
+        logger.error(f"registered_visual_server_obj {registered_visual_server_obj} has error {str(e)}")
+        logger.exception(str(e))
+    finally:
+        logger.error(f"registered_visual_server_obj {registered_visual_server_obj} removed")
+        with registered_visual_server_obj_lock:
+            registered_visual_server_objs.pop(registered_visual_server_obj.node_id, None)
+    return
+
+
 @app.get("/registered_objects")
 async def get_registered_objects():
     with registered_pd_master_obj_lock:
@@ -80,6 +107,14 @@ async def get_registered_objects():
     return {"data": base64_encoded}
 
 
+@app.get("/registered_visual_objects")
+async def get_vit_registered_objects():
+    with registered_visual_server_obj_lock:
+        serialized_data = pickle.dumps(registered_visual_server_objs)
+        base64_encoded = base64.b64encode(serialized_data).decode("utf-8")
+    return {"data": base64_encoded}
+
+
 @app.get("/allocate_global_unique_id_range")
 async def allocate_global_id_range():
     """
diff --git a/lightllm/server/core/objs/io_objs/group_req.py 
b/lightllm/server/core/objs/io_objs/group_req.py index dfcbdd256..75f2c0e2f 100644 --- a/lightllm/server/core/objs/io_objs/group_req.py +++ b/lightllm/server/core/objs/io_objs/group_req.py @@ -23,7 +23,9 @@ def to_group_req_index(self): return GroupReqIndexes( group_req_id=self.group_req_id, multimodal_params=self.multimodal_params, - shm_req_indexes=[req.index_in_shm_mem for req in self.shm_req_objs], + shm_req_indexes=[req.index_in_shm_mem for req in self.shm_req_objs] + if self.shm_req_objs is not None + else None, time_mark=self.time_mark, ) diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index 947f24644..cff51f637 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -246,7 +246,6 @@ def can_release(self): # 只有管理节点有一个引用 ref_count_ok = self.ref_count == 1 can_released_mark = self.can_released_mark - if self.is_aborted and can_released_mark and ref_count_ok: return True diff --git a/lightllm/server/embed_cache/impl/memory_cache_with_redis.py b/lightllm/server/embed_cache/impl/memory_cache_with_redis.py new file mode 100644 index 000000000..05bd0bc23 --- /dev/null +++ b/lightllm/server/embed_cache/impl/memory_cache_with_redis.py @@ -0,0 +1,74 @@ +import uuid +import threading +import dataclasses +import requests +from typing import Union, Optional +import torch +import time +from collections import deque +import multiprocessing.shared_memory as shm +from ..utils import get_shm_name_data, get_shm_name_embed, free_shm, EmbedRefCountRedis +from .naive_memory_cache import Record, InMemoryCache +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class MemoryCacheWithRedis(InMemoryCache): + def __init__(self, args) -> None: + super().__init__(args) + redis_url = f"redis://{args.config_server_host}:{args.redis_port}/0" + self.redis_cache = EmbedRefCountRedis( + redis_url=redis_url, + capacity=args.cache_capacity, + evict_fraction=args.redis_evict_fraction, + image_embed_dir=args.image_embed_dir, + ) + # 这里之所以把cache * 2是因为,在分离模式下,cache 服务只是为了更新redis状态,以及维护图片cache的 token_id + # 便于 dynamic prompt cache 的使用。所以要把cache_capacity * 2,保障其保留的图片cache > redis 服务维护的 + # 硬盘里的图片image embed 数量。 + self.cache_capacity = args.cache_capacity * 2 + + # llm 负责release + def release(self, ids: list[int]) -> None: + with self.lock: + for id in ids: + self._records[id].ref -= 1 + if self.redis_cache.query(str(id)): + self.redis_cache.decr(str(id)) + # print(self.redis_cache.stats(), flush=True) + + # vit 负责set + def set_items_embed(self, ids: list[int]) -> None: + with self.lock: + for id in ids: + self.redis_cache.insert(str(id)) + self._records[id].embed = True + self._records[id].ref -= 1 + self.redis_cache.decr(str(id)) # vit端alloc之后ref+1 vit完成后ref-1 + + def get_items_embed(self, ids: list[int], embeding_only: bool = False) -> list[Optional[bool]]: + ret = [] + for id in ids: + if embeding_only: + exist = self.redis_cache.query(str(id)) + else: + exist = self.redis_cache.query_and_incre(str(id)) + ret.append(exist) + if exist: + self._records[id].embed = True + return ret + + # def get_items_embed_and_incre(self, ids: list[int]) -> list[Optional[bool]]: + # ret = [] + # for id in ids: + # # if self.redis_cache.query(str(id)): + # # ret.append(True) + # # continue + # # 避免重复的引用计数增加 + # if self._records[id].embed: + # ret.append(True) + # continue + # self._records[id].embed = self.redis_cache.query_and_incre(str(id)) + # ret.append(self._records[id].embed) + # return ret diff --git 
a/lightllm/server/embed_cache/impl/naive_memory_cache.py b/lightllm/server/embed_cache/impl/naive_memory_cache.py index 5477be22b..7f9ac58f4 100644 --- a/lightllm/server/embed_cache/impl/naive_memory_cache.py +++ b/lightllm/server/embed_cache/impl/naive_memory_cache.py @@ -7,7 +7,7 @@ import time from collections import deque import multiprocessing.shared_memory as shm -from ..utils import get_shm_name_data, get_shm_name_embed, free_shm +from ..utils import get_shm_name_data, get_shm_name_embed, free_shm, free_afs from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -77,7 +77,12 @@ def _clear(self, free_max_count: int): if record.data: free_shm(get_shm_name_data(id)) if record.embed: - free_shm(get_shm_name_embed(id)) + # 仅vit释放掉afs里的, llm端不做释放 + # if self.args.run_mode == "visual": + # free_afs(get_shm_name_embed(id), self.args.image_embed_dir) + # elif not self.args.enable_remote_vit: + if not self.args.enable_remote_vit and self.args.run_mode != "visual": + free_shm(get_shm_name_embed(id)) del self._md5_to_record[record.md5sum] del self._records[id] self.occupied -= 1 @@ -103,7 +108,7 @@ def alloc(self, md5sum_list: list[str], token_num_list: list[int]) -> Optional[l rec.visittime = now rec.ref += 1 else: - uid_int = uuid.uuid1().int + uid_int = md5sum self._check_and_set_new_id_range(token_num) rec = Record( id=uid_int, @@ -139,5 +144,5 @@ def set_items_embed(self, ids: list[int]) -> None: for id_ in ids: self._records[id_].embed = True - def get_items_embed(self, ids: list[int]) -> list[Optional[bool]]: + def get_items_embed(self, ids: list[int], embeding_only: bool = False) -> list[Optional[bool]]: return [self._records.get(id_).embed if id_ in self._records else False for id_ in ids] diff --git a/lightllm/server/embed_cache/manager.py b/lightllm/server/embed_cache/manager.py index 5de4df4ab..0dc8830cd 100644 --- a/lightllm/server/embed_cache/manager.py +++ b/lightllm/server/embed_cache/manager.py @@ -6,6 +6,7 @@ from lightllm.server.core.objs import StartArgs from lightllm.utils.graceful_utils import graceful_registry from lightllm.server.embed_cache.impl.naive_memory_cache import InMemoryCache +from lightllm.server.embed_cache.impl.memory_cache_with_redis import MemoryCacheWithRedis from rpyc.utils.classic import obtain from lightllm.utils.envs_utils import get_unique_server_name @@ -25,6 +26,10 @@ def on_disconnect(self, conn): # (to finalize the service, if needed) pass + def exposed__check_and_set_new_id_range(self, token_num: int) -> int: + token_num = obtain(token_num) + return self._impl._check_and_set_new_id_range(token_num) + def exposed_alloc(self, md5sum_list: list[str], token_num_list: list[int]) -> Optional[list[dict]]: md5sum_list = obtain(md5sum_list) token_num_list = obtain(token_num_list) @@ -47,9 +52,16 @@ def exposed_set_items_embed(self, ids: list[int]) -> None: ids = obtain(ids) return self._impl.set_items_embed(ids) - def exposed_get_items_embed(self, ids: list[int]) -> list[bool]: + def exposed_get_items_embed(self, ids: list[int], embeding_only: bool = False) -> list[bool]: ids = obtain(ids) - return self._impl.get_items_embed(ids) + return self._impl.get_items_embed(ids, embeding_only) + + +def get_cache_manager(args): + if args.enable_remote_vit or args.run_mode == "visual": + return MemoryCacheWithRedis(args) + else: + return InMemoryCache(args) def start_cache_manager(args: StartArgs, pipe_writer): @@ -57,7 +69,7 @@ def start_cache_manager(args: StartArgs, pipe_writer): graceful_registry(inspect.currentframe().f_code.co_name) 
setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::cache_manager") - manager = InMemoryCache(args) + manager = get_cache_manager(args) service = CacheServer(manager) from rpyc.utils.server import ThreadedServer import lightllm.utils.rpyc_fix_utils as _ diff --git a/lightllm/server/embed_cache/utils.py b/lightllm/server/embed_cache/utils.py index 6df031293..66bb72674 100644 --- a/lightllm/server/embed_cache/utils.py +++ b/lightllm/server/embed_cache/utils.py @@ -1,7 +1,16 @@ +import os +import time import torch +import redis import numpy as np +from typing import List, Tuple from io import BytesIO +from pathlib import Path import multiprocessing.shared_memory as shm +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) def tensor2bytes(t: torch.Tensor): @@ -35,21 +44,381 @@ def create_shm(name, data): print("Warning create shm {} failed because of FileExistsError!".format(name)) +def create_afs(name, data, path): + try: + data_size = len(data) + path = os.path.join(path, name) + with open(path, "xb") as f: + mem_view = memoryview(data) + f.write(mem_view[:data_size]) + f.flush() + os.fsync(f.fileno()) + except FileExistsError: + print("Warning create afs {} failed because of FileExistsError!".format(name)) + + def read_shm(name): shared_memory = shm.SharedMemory(name=name) data = shared_memory.buf.tobytes() return data +def read_afs(name: str, base_dir) -> bytes: + + path = Path(base_dir) / name + return path.read_bytes() + + def free_shm(name): shared_memory = shm.SharedMemory(name=name) shared_memory.close() shared_memory.unlink() +def free_afs(name: str, base_dir) -> None: + path = Path(base_dir) / name + path.unlink() + + def get_shm_name_data(uid): return str(uid) + "-data" def get_shm_name_embed(uid): return str(uid) + "-embed" + + +""" +Importable Redis-backed MD5 refcount with LRU eviction. 
+
+Public API:
+    from md5_refcount import EmbedRefCountRedis
+
+    cache = EmbedRefCountRedis(
+        redis_url="redis://localhost:6379/0",
+        capacity=10000,
+        evict_fraction=0.2
+    )
+
+    # Insert a new md5 with default ref_count=1
+    success, evicted_list = cache.insert(md5)
+
+    # Query if exists and increment ref_count if found
+    exists = cache.query_and_incre(md5)
+
+    # Decrement ref_count
+    rc, deleted = cache.decr(md5)
+
+    s = cache.stats()
+"""
+
+
+class EmbedRefCountRedis:
+    def __init__(
+        self,
+        redis_url: str = "redis://localhost:6379/0",
+        capacity: int = 50000,
+        evict_fraction: float = 0.1,
+        key_prefix: str = "md5:",
+        image_embed_dir: str = None,
+        path_ext: str = "-embed",
+        **redis_kwargs,
+    ) -> None:
+        """
+        - capacity: max count of md5 entries allowed in Redis
+        - evict_fraction: fraction to evict when inserting a NEW md5 and at capacity
+        - image_embed_dir: base directory for image embed files (e.g., "/afs/embeds")
+        - path_ext: filename suffix for embed files (default: "-embed")
+        """
+        if not (0.0 <= evict_fraction <= 1.0):
+            raise ValueError("evict_fraction must be 0..1")
+        if capacity < 1:
+            raise ValueError("capacity must be >=1")
+
+        self.capacity = int(capacity)
+        self.evict_fraction = float(evict_fraction)
+        self.zset_key = f"{key_prefix}lru"
+        self.ref_prefix = f"{key_prefix}rc:"
+        self.lock_key = f"{key_prefix}evict:lock"
+        self.image_embed_dir = image_embed_dir
+        self.path_ext = path_ext
+
+        self.r = redis.Redis.from_url(redis_url, decode_responses=True, **redis_kwargs)
+
+        # Register Lua scripts
+        self._insert_script = self.r.register_script(self._INSERT_LUA)
+        self._query_incre_script = self.r.register_script(self._QUERY_INCRE_LUA)
+        self._decr_script = self.r.register_script(self._DECR_LUA)
+        self._evict_and_insert_script = self.r.register_script(self._EVICT_AND_INSERT_LUA)
+
+    def insert(self, md5: str) -> Tuple[bool, List[str]]:
+        """Insert a new md5 with default ref_count=1. May trigger LRU eviction."""
+        # Wait for any in-progress eviction to finish
+        self._wait_if_eviction()
+
+        res = self._insert_script(
+            keys=[self.zset_key, self.ref_prefix],
+            args=[md5, self.capacity, self.evict_fraction],
+        )
+
+        if res[0] == 0:  # No eviction needed
+            return True, []
+
+        # Need eviction - use atomic eviction script
+        try:
+            if self._try_acquire_lock():
+                try:
+                    # Atomically perform the eviction and the insert
+                    evict_res = self._evict_and_insert_script(
+                        keys=[self.zset_key, self.ref_prefix],
+                        args=[md5, self.capacity, self.evict_fraction],
+                    )
+                    success = bool(evict_res[0])
+                    victims = evict_res[1:] if len(evict_res) > 1 else []
+
+                    if success:
+                        # Delete the AFS files of the evicted md5s
+                        if victims and self.image_embed_dir:
+                            self._delete_afs_files(victims)
+                        return True, victims
+                    else:
+                        # Eviction failed; back off briefly and retry
+                        time.sleep(0.01)
+                        return self.insert(md5)
+                finally:
+                    self._release_lock()
+            else:
+                # Wait for the lock to be released, then retry
+                time.sleep(0.01)
+                return self.insert(md5)
+        except Exception as e:
+            self._release_lock()
+            raise e
+
+    def query(self, md5: str) -> bool:
+        """Query if md5 exists."""
+        self._wait_if_eviction()
+        return bool(self.r.exists(self.ref_prefix + md5))
+
+    def query_and_incre(self, md5: str) -> bool:
+        """Query if md5 exists and increment ref_count if found."""
+        self._wait_if_eviction()
+        res = self._query_incre_script(
+            keys=[self.zset_key, self.ref_prefix],
+            args=[md5],
+        )
+        return bool(res[0])
+
+    def decr(self, md5: str) -> Tuple[int, bool]:
+        """Decrement ref_count for md5. 
Returns (ref_count, deleted)."""
+        self._wait_if_eviction()
+
+        res = self._decr_script(
+            keys=[self.zset_key, self.ref_prefix],
+            args=[md5],
+        )
+        if res[0] == -1:
+            raise KeyError("md5 not found")
+        return int(res[0]), bool(res[1])
+
+    def stats(self) -> dict:
+        self._wait_if_eviction()
+
+        size = self.r.zcard(self.zset_key)
+        return {
+            "items": size,
+            "capacity": self.capacity,
+            "evict_fraction": self.evict_fraction,
+        }
+
+    def get_ref(self, md5: str) -> int | None:
+        self._wait_if_eviction()
+        val = self.r.get(self.ref_prefix + md5)
+        return int(val) if val is not None else None
+
+    def _wait_if_eviction(self) -> None:
+        max_wait = 30
+        start_time = time.time()
+
+        while self.r.exists(self.lock_key):
+            if time.time() - start_time > max_wait:
+                raise TimeoutError("Eviction operation timeout, waited too long")
+            time.sleep(0.01)  # short wait before checking again
+
+    def _try_acquire_lock(self) -> bool:
+        return bool(self.r.set(self.lock_key, "1", nx=True, ex=30))
+
+    def _release_lock(self) -> None:
+        try:
+            self.r.delete(self.lock_key)
+        except Exception:
+            pass
+
+    def _md5_to_afs_path(self, md5: str) -> str:
+        """Convert md5 to AFS file path."""
+        if not self.image_embed_dir:
+            return None
+        filename = os.path.join(self.image_embed_dir, md5 + self.path_ext)
+        return filename
+
+    def _delete_afs_files(self, victims: List[str]) -> None:
+        """Delete AFS files for evicted md5s."""
+        if not self.image_embed_dir:
+            return
+
+        for md5 in victims:
+            try:
+                file_path = self._md5_to_afs_path(md5)
+                if file_path and os.path.exists(file_path):
+                    os.remove(file_path)
+                    logger.debug(f"Deleted AFS file: {file_path}")
+            except Exception as e:
+                logger.debug(f"Warning: Failed to delete AFS file for {md5}: {e}")
+
+    # ---------------- Lua scripts ----------------
+    _INSERT_LUA = r"""
+-- KEYS[1] = zset key, KEYS[2] = ref_prefix
+-- ARGV[1] = md5, ARGV[2] = capacity, ARGV[3] = evict_fraction
+local zset = KEYS[1]
+local ref_prefix = KEYS[2]
+local md5 = ARGV[1]
+local capacity = tonumber(ARGV[2])
+
+local unpack = unpack or table.unpack
+local ref_key = ref_prefix .. md5
+if redis.call('GET', ref_key) then
+    return {0} -- Already exists
+end
+
+local size = redis.call('ZCARD', zset)
+if size < capacity then
+    -- Insert with ref_count=1
+    redis.call('SET', ref_key, 1)
+    local now = redis.call('TIME')[1] * 1000
+    redis.call('ZADD', zset, now, md5)
+    return {0} -- Success, no eviction
+end
+
+return {1} -- Need eviction
+"""
+
+    _QUERY_INCRE_LUA = r"""
+-- KEYS[1] = zset key, KEYS[2] = ref_prefix
+-- ARGV[1] = md5
+local zset = KEYS[1]
+local ref_prefix = KEYS[2]
+local md5 = ARGV[1]
+
+local ref_key = ref_prefix .. md5
+local val = redis.call('GET', ref_key)
+
+if not val then
+    return {0} -- Not found
+end
+
+-- Found, increment ref_count and update LRU
+local rc = tonumber(val) + 1
+redis.call('SET', ref_key, rc)
+local now = redis.call('TIME')[1] * 1000
+redis.call('ZADD', zset, now, md5)
+return {1} -- Found and incremented
+"""
+
+    _DECR_LUA = r"""
+-- KEYS[1] = zset key, KEYS[2] = ref_prefix
+-- ARGV[1] = md5
+local zset = KEYS[1]
+local ref_prefix = KEYS[2]
+local md5 = ARGV[1]
+
+local ref_key = ref_prefix .. 
md5 +local val = redis.call('GET', ref_key) + +if not val then + return {-1, 0} -- Not found +end + +--ref 递减到 0 时保留键,只更新计数与 LRU +local rc = tonumber(val) - 1 +if rc < 0 then rc = 0 end +redis.call('SET', ref_key, rc) + +if rc > 0 then + -- 只有仍被引用时才更新 LRU + local now = redis.call('TIME')[1] * 1000 + redis.call('ZADD', zset, now, md5) +end + +return {rc, 0} +""" + + _EVICT_AND_INSERT_LUA = r""" +-- KEYS[1] = zset key, KEYS[2] = ref_prefix +-- ARGV[1] = new_md5, ARGV[2] = capacity, ARGV[3] = evict_fraction +local zset = KEYS[1] +local ref_prefix = KEYS[2] +local new_md5 = ARGV[1] +local capacity = tonumber(ARGV[2]) +local evict_fraction = tonumber(ARGV[3]) + +local unpack = unpack or table.unpack + +-- helper: now millis +local function now_ms() + local t = redis.call('TIME') + return t[1] * 1000 + math.floor(t[2] / 1000) +end + +local new_ref_key = ref_prefix .. new_md5 + +-- If already exists, treat as a hit: bump ref_count and refresh LRU +local cur = redis.call('GET', new_ref_key) +if cur then + local rc = tonumber(cur) + 1 + redis.call('SET', new_ref_key, rc) + redis.call('ZADD', zset, now_ms(), new_md5) + return {1} -- success, no victims +end + +-- If not at capacity, just insert +local size = redis.call('ZCARD', zset) +if size < capacity then + redis.call('SET', new_ref_key, 1) + redis.call('ZADD', zset, now_ms(), new_md5) + return {1} -- success, no victims +end + +-- At capacity: try to evict up to max_try items with rc==0, but success if at least 1 is freed +local max_try = math.max(1, math.floor(size * evict_fraction + 0.5)) +local victims = {} +local freed = 0 + +-- Scan from LRU (smallest score) to MRU +local all_keys = redis.call('ZRANGE', zset, 0, -1, 'WITHSCORES') +local i = 1 +while freed < 1 and i <= #all_keys and #victims < max_try do + local md5 = all_keys[i] + local ref_key = ref_prefix .. md5 + local v = redis.call('GET', ref_key) + if v and tonumber(v) <= 0 then + table.insert(victims, md5) + freed = freed + 1 + end + i = i + 2 -- skip score +end + +if freed >= 1 then + -- delete victims + for _, v in ipairs(victims) do + redis.call('DEL', ref_prefix .. 
v) + redis.call('ZREM', zset, v) + end + -- insert new + redis.call('SET', new_ref_key, 1) + redis.call('ZADD', zset, now_ms(), new_md5) + return {1, unpack(victims)} +else + -- no zero-ref items found + return {0} +end +""" diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 11919398e..4cfedb1e6 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -14,6 +14,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) from typing import Union, List, Tuple, Dict, Optional +from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes from websockets import ClientConnection from fastapi import Request from ..tokenizer import get_tokenizer @@ -81,8 +82,10 @@ def __init__( if self.enable_multimodal: self.cache_client = rpyc.connect("localhost", args.cache_port, config={"allow_pickle": True}) self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - self.send_to_visual = context.socket(zmq.PUSH) - self.send_to_visual.connect(f"{args.zmq_mode}127.0.0.1:{args.visual_port}") + # 初始化VIT连接管理器 + from lightllm.server.visualserver.vit_connect import VITConnectionManager + + self.vit_manager = VITConnectionManager(args, context, args.visual_port, self.cache_client) if args.enable_cpu_cache and not self.args.enable_multimodal: self.send_to_multi_level_kv_cache = context.socket(zmq.PUSH) self.send_to_multi_level_kv_cache.connect(f"{args.zmq_mode}127.0.0.1:{args.multi_level_kv_cache_port}") @@ -116,10 +119,10 @@ def __init__( self.latest_success_infer_time_mark.set_value(int(time.time())) return - async def _alloc_resource(self, items, md5sums, token_nums, datas): + async def _alloc_resource(self, items, uuids, token_nums, datas): while True: - records = obtain(self.cache_client.root.alloc(md5sums, token_nums)) + records = obtain(self.cache_client.root.alloc(uuids, token_nums)) if records is None: await asyncio.sleep(0.1) @@ -132,6 +135,10 @@ async def _alloc_resource(self, items, md5sums, token_nums, datas): item.token_num = rec["token_num"] uid_list.append(rec["id"]) + # # If enable the vit/audio-llm disaggregation, no need to cache the data in the memory of the server + if self.args.enable_remote_vit: + return + ready_flags = obtain(self.cache_client.root.get_items_data(uid_list)) update_data_ids = [] @@ -151,14 +158,15 @@ async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams, # 如果不加任何锁,假如请求1和请求2都有6张图片,而cache_capacity为10, # 那么如果某一时刻shm中存在请求1的5张图和请求2的5张图,将会资源竞争产生死锁。 async with self._resource_lock: - items, md5sums, tokens_nums, datas = [], [], [], [] + items, uuids, tokens_nums, datas = [], [], [], [] for img in multimodal_params.images: self.tokenizer.init_imageitem_extral_params(img, multimodal_params, sampling_params) data = img.read() # must after init_imageitem_extral_params token_num = self.tokenizer.get_image_token_length(img) - md5sum = hashlib.md5(data).hexdigest() + "_" + str(hash(frozendict(img.extra_params))) - md5sums.append(md5sum) + md5sum = "{}_{}".format(hashlib.md5(data).hexdigest(), img.patch_num) + uuid = int(md5sum, 16) + uuids.append(uuid) tokens_nums.append(token_num) datas.append(data) items.append(img) @@ -166,13 +174,17 @@ async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams, self.tokenizer.init_audioitem_extral_params(audio, multimodal_params, sampling_params) data = audio.read() token_num = self.tokenizer.get_audio_token_length(audio) - md5sum = hashlib.md5(data).hexdigest() + "_" + 
str(hash(frozendict(audio.extra_params)))
-            md5sums.append(md5sum)
+            md5sum = "{}_{}".format(
+                hashlib.md5(data).hexdigest(),
+                hashlib.md5(pickle.dumps(audio.extra_params, protocol=4)).hexdigest(),
+            )
+            uuid = int(md5sum, 16)
+            uuids.append(uuid)
             tokens_nums.append(token_num)
             datas.append(data)
             items.append(audio)
 
-        await self._alloc_resource(items, md5sums, tokens_nums, datas)
+        await self._alloc_resource(items, uuids, tokens_nums, datas)
         return
 
     async def _release_multimodal_resources(self, multimodal_params: MultimodalParams):
@@ -194,8 +206,8 @@ async def _release_multimodal_resources(self, multimodal_param
                 audio.uuid = None
                 audio.token_id = None
                 audio.token_num = None
-        if ids_to_release:
-            self.cache_client.root.release(ids_to_release)
+        # if ids_to_release:
+        #     self.cache_client.root.release(ids_to_release)
         return
 
     def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwargs=None):
@@ -385,6 +397,48 @@ async def generate(
             raise e
         return
 
+    async def get_image_embeding(
+        self,
+        sampling_params: SamplingParams,
+        multimodal_params: MultimodalParams,
+        request: Request,
+        is_health_req: bool = False,
+    ) -> Tuple[int, str, dict, FinishStatus]:
+        start_time = time.time()
+        request_headers = request.headers if request is not None else {}
+        group_request_id = self.alloc_req_id(sampling_params, is_health_req)
+
+        try:
+            original_multimodal_params = None
+            if self.is_multinode_tp_master:
+                original_multimodal_params = copy.deepcopy(multimodal_params)
+
+            if self.pd_mode.is_P_or_NORMAL():
+                await multimodal_params.verify_and_preload(request)
+
+            image_count = len(multimodal_params.images)
+            # Log the incoming request information
+
+            await self._log_req_header(request_headers, group_request_id)
+            logger.info(f"image_count:{image_count}")
+            assert (
+                len(multimodal_params.images + multimodal_params.audios) <= self.args.cache_capacity
+            ), "too many multimodal items!"
+
+            await self._alloc_multimodal_resources(multimodal_params, sampling_params)
+
+            visual_req_status = GroupReqObjs(group_request_id, multimodal_params, None, start_time)
+
+            await self.transfer_to_next_module_or_node(
+                None, sampling_params, original_multimodal_params, visual_req_status, embeding_only=True
+            )
+
+        except Exception as e:
+            logger.error(f"group_request_id: {group_request_id} has exception {str(e)}")
+            await self.abort(group_request_id)
+            raise e
+        return
+
     def _count_multimodal_tokens(self, multimodal_params: MultimodalParams) -> Tuple[int, int]:
         image_tokens = 0
         audio_tokens = 0
@@ -479,6 +534,7 @@ async def transfer_to_next_module_or_node(
         sampling_params: SamplingParams,
         original_multimodal_params: MultimodalParams,
         group_req_objs: Optional[GroupReqObjs] = None,
+        embeding_only: Optional[bool] = False,
    ):
        # 多节点纯tp 运行模式下,master 节点需要将请求转发给slave节点. 
if self.is_multinode_tp_master: @@ -488,21 +544,22 @@ async def transfer_to_next_module_or_node( protocol=pickle.HIGHEST_PROTOCOL, ) - await self.transfer_to_next_module(group_req_objs) + await self.transfer_to_next_module(group_req_objs, embeding_only) return async def transfer_to_next_module( self, group_req_objs: Optional[GroupReqObjs] = None, + embeding_only: Optional[bool] = False, ): if self.pd_mode.is_P_or_NORMAL(): if self.enable_multimodal: - self.send_to_visual.send_pyobj( + await self.vit_manager.send_to_vit( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, + embeding_only=embeding_only, ) - return if self.args.enable_cpu_cache: self.send_to_multi_level_kv_cache.send_pyobj( @@ -511,14 +568,15 @@ async def transfer_to_next_module( ) return - self.send_to_router.send_pyobj( - group_req_objs.to_group_req_index(), - protocol=pickle.HIGHEST_PROTOCOL, - ) + if not self.enable_multimodal or self.args.enable_remote_vit: + self.send_to_router.send_pyobj( + group_req_objs.to_group_req_index(), + protocol=pickle.HIGHEST_PROTOCOL, + ) return if self.pd_mode.is_D(): - # 在 D 模式下,不需要传输真的多模态参数,因为其已经被 P 处理好了 + # 在 D 模式下,不需要传输真的多模态参数,因为其已经被 P 处理好了, 传输一个空的即可 self.send_to_router.send_pyobj( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, @@ -696,6 +754,9 @@ async def handle_loop(self): asyncio.create_task(pd_handle_loop(self)) + if self.enable_multimodal: + asyncio.create_task(self.vit_manager.vit_handle_loop()) + while True: try: await asyncio.wait_for(self.zmq_recv_socket.recv_pyobj(), timeout=0.05) diff --git a/lightllm/server/multimodal_params.py b/lightllm/server/multimodal_params.py index 066fe5cc2..01fc0af1e 100644 --- a/lightllm/server/multimodal_params.py +++ b/lightllm/server/multimodal_params.py @@ -24,6 +24,7 @@ def __init__(self, **kwargs): self.token_num = None # the audio length self.audio_length = None + self.afs_embed = False self._preload_data = None self.extra_params = {} @@ -52,10 +53,7 @@ async def preload(self, request: Request): def read(self): assert self._preload_data is not None - ans = self._preload_data - self._preload_data = None - self._data = None - return ans + return self._preload_data def to_dict(self): ret = {} @@ -77,6 +75,7 @@ def __init__(self, **kwargs): self.token_num = None self.image_w = 0 self.image_h = 0 + self.patch_num = 0 self._preload_data = None self.extra_params = {} @@ -110,10 +109,11 @@ async def preload(self, request: Request): def read(self): assert self._preload_data is not None - ans = self._preload_data + return self._preload_data + + def free(self): self._preload_data = None self._data = None - return ans def to_dict(self): ret = {} @@ -142,6 +142,15 @@ def __init__( self.audios = [AudioItem(**a) for a in audios] return + def free(self): + for image in self.images: + image.free() + for audio in self.audios: + audio.free() + + def get_all_uuids(self): + return [image.uuid for image in self.images] + [audio.uuid for audio in self.audios] + async def verify_and_preload(self, request: Request): for image in self.images: await image.preload(request) diff --git a/lightllm/server/pd_io_struct.py b/lightllm/server/pd_io_struct.py index 5002e4f1c..cc1e7453e 100644 --- a/lightllm/server/pd_io_struct.py +++ b/lightllm/server/pd_io_struct.py @@ -20,6 +20,7 @@ class NodeRole(enum.Enum): NORMAL = "normal" PD_MASTER = "pd_master" + LLM_ONLY = "llm_only" def is_D(self): return self == NodeRole.D or self == NodeRole.ND @@ -34,7 +35,7 @@ def is_ND(self): return self == NodeRole.ND def is_normal(self): - return self 
== NodeRole.NORMAL + return (self == NodeRole.NORMAL) or (self == NodeRole.LLM_ONLY) def is_P_or_NORMAL(self): return self.is_P() or self.is_normal() diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index b7e1ac10c..06a272b40 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -1,14 +1,26 @@ import zmq +import time import zmq.asyncio import asyncio import uvloop import rpyc import socket import pickle +import hashlib +import datetime import inspect +from fastapi import Request +from ..tokenizer import get_tokenizer import setproctitle from typing import List from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes +from lightllm.server.embed_cache.utils import get_shm_name_data, create_shm +from lightllm.server.core.objs import SamplingParams +from lightllm.server.core.objs import Req, FinishStatus +from typing import Union, Tuple, Dict, Optional +from ..req_id_generator import ReqIDGenerator +from lightllm.server.core.objs.io_objs import GroupReqObjs +from lightllm.server.embed_cache.impl.memory_cache_with_redis import MemoryCacheWithRedis from lightllm.server.core.objs import ShmReqManager, StartArgs asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) @@ -30,6 +42,8 @@ def __init__( args: StartArgs, visual_model_rpc_ports, ): + self.args = args + self.remote_vit = args.enable_remote_vit or args.run_mode == "visual" context = zmq.Context(2) if args.enable_multimodal_audio: @@ -48,48 +62,45 @@ def __init__( self.cache_client = rpyc.connect("localhost", args.cache_port, config={"allow_pickle": True}) self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) self.cache_port = args.cache_port + self.waiting_reqs: List[GroupReqIndexes] = [] - self.model_weightdir = args.model_dir - self.tp_world_size = args.tp - self.vit_dp = args.visual_dp - self.vit_tp = args.visual_tp self.infer_batch_size = args.visual_infer_batch_size self.trust_remote_code = args.trust_remote_code - self.args = args self.visual_model_rpc_ports = visual_model_rpc_ports self.shm_req_manager = ShmReqManager() + self._setup_connections() - async def wait_to_model_ready(self): + def _setup_connections(self): + context = zmq.Context(2) + if self.remote_vit: + self.vit_receiver = context.socket(zmq.PULL) + self.vit_receiver.bind(f"tcp://*:{self.args.remote_vit_port}") + else: + self.vit_receiver = context.socket(zmq.PULL) + self.vit_receiver.bind(f"{self.args.zmq_mode}127.0.0.1:{self.args.visual_port}") + self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True}) + self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - self.model_rpcs: List[List[VisualModelRpcClient]] = [[] for _ in range(self.vit_dp)] + async def wait_to_model_ready(self): + visual_dp = self.args.visual_dp + visual_tp = self.args.visual_tp + self.model_rpcs: List[List[VisualModelRpcClient]] = [[] for _ in range(visual_dp)] - for dp_rank_id in range(self.vit_dp): + for dp_rank_id in range(visual_dp): tp_ports_each_dp = self.visual_model_rpc_ports[dp_rank_id] - for tp_rank_id in range(self.vit_tp): - device_id = self.args.visual_gpu_ids[dp_rank_id * self.vit_tp + tp_rank_id] + for tp_rank_id in range(visual_tp): + device_id = self.args.visual_gpu_ids[dp_rank_id * visual_tp + tp_rank_id] rpc_model = await start_model_process( - port=tp_ports_each_dp[tp_rank_id], vit_tp=self.vit_tp, device_id=device_id + port=tp_ports_each_dp[tp_rank_id], 
vit_tp=visual_tp, device_id=device_id ) self.model_rpcs[dp_rank_id].append(rpc_model) init_model_ret = [] - for dp_rank_id in range(self.vit_dp): # async init model process - for tp_rank_id in range(self.vit_tp): + for dp_rank_id in range(visual_dp): # async init model process + for tp_rank_id in range(visual_tp): kvargs = { - "weight_dir": self.model_weightdir, - "trust_remote_code": self.trust_remote_code, - "vit_dp": self.vit_dp, - "vit_tp": self.vit_tp, - "cache_port": self.cache_port, "tp_rank_id": tp_rank_id, "dp_rank_id": dp_rank_id, - "vit_rank_id": dp_rank_id * self.vit_tp + tp_rank_id, - "data_type": self.args.data_type, - "visual_nccl_port": self.args.visual_nccl_ports[dp_rank_id], - "visual_gpu_ids": self.args.visual_gpu_ids, - "quant_type": self.args.vit_quant_type, - "quant_cfg": self.args.vit_quant_cfg, - "max_batch_size": min(self.infer_batch_size // self.vit_dp, 1), } init_model_ret.append(self.model_rpcs[dp_rank_id][tp_rank_id].init_model(kvargs)) await asyncio.gather(*init_model_ret) @@ -100,13 +111,12 @@ async def infer_imgs(self, images: List[ImageItem]): return tasks = [] - for vit_dp_rank in range(self.vit_dp): - assigned_images = [images[i] for i in range(vit_dp_rank, len(images), self.vit_dp)] + for vit_dp_rank in range(self.args.visual_dp): + assigned_images = [images[i] for i in range(vit_dp_rank, len(images), self.args.visual_dp)] if assigned_images: - for vit_tp_rank in range(self.vit_tp): + for vit_tp_rank in range(self.args.visual_tp): task = asyncio.create_task(self.model_rpcs[vit_dp_rank][vit_tp_rank].encode(assigned_images)) tasks.append(task) - await asyncio.gather(*tasks) return @@ -164,6 +174,47 @@ async def loop_for_fwd(self): processing_group_reqs = [] images_need_infer = [] + async def _recv_reqs(self): + if self.remote_vit: + recv_req: GroupReqIndexes = self.vit_receiver.recv_pyobj(zmq.NOBLOCK) + # recv_req.multimodal_params.images[:]= [ + # img for img in recv_req.multimodal_params.images + # if not self.cache_client.root.get_item_embed(img.uuid) # embed已存在的被丢弃 , ref +1 + # ] + logger.info(f"Receive req {recv_req.group_req_id}, image_count:{len(recv_req.multimodal_params.images)}") + uuids = [img.uuid for img in recv_req.multimodal_params.images] + already_embed = await asyncio.to_thread(self.cache_client.root.get_items_embed, uuids) + if all(already_embed): + return None + + uuids = [] + token_nums = [] + datas = [] + for img, embed in zip(recv_req.multimodal_params.images, already_embed): + if not embed: + uuids.append(img.uuid) + token_nums.append(img.token_num) + datas.append(img._preload_data) + img.free() + while True: + records = await asyncio.to_thread(self.cache_client.root.alloc, uuids, token_nums) + if records is not None: + break + await asyncio.sleep(0.01) + ready_flags = obtain(self.cache_client.root.get_items_data(uuids)) + update_data_ids = [] + + for uid, ready, data in zip(uuids, ready_flags, datas): + if not ready: + create_shm(get_shm_name_data(uid), data) + update_data_ids.append(uid) + + if update_data_ids: + await asyncio.to_thread(self.cache_client.root.set_items_data, update_data_ids) + return recv_req + else: + return self.vit_receiver.recv_pyobj(zmq.NOBLOCK) + async def loop_for_netio_req(self): if not hasattr(self, "visual_recv_max_count"): self.visual_recv_max_count = 64 @@ -171,17 +222,47 @@ async def loop_for_netio_req(self): while True: try: for _ in range(self.visual_recv_max_count): - recv_req: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) + recv_req: GroupReqIndexes = await self._recv_reqs() + if 
recv_req is None: + continue if isinstance(recv_req, GroupReqIndexes): self.waiting_reqs.append(recv_req) else: assert False, f"Error Req Inf {recv_req}" - self.visual_recv_max_count = min(self.visual_recv_max_count * 1.3, 256) + await asyncio.sleep(0) + self.visual_recv_max_count = min(int(self.visual_recv_max_count * 1.3), 256) except zmq.ZMQError: # 当队列已经开始清空的时候,将一次接受数量下调 self.visual_recv_max_count = 64 + except Exception as e: + logger.exception(f"Error in loop_for_netio_req: {e}") + raise e await asyncio.sleep(0.01) + # code for visual only mode + async def loop_for_fwd_visual_only(self): + while True: + if len(self.waiting_reqs) == 0: + await asyncio.sleep(0.01) # 10ms + else: + images_need_infer = [] + + while len(self.waiting_reqs) > 0: + visual_req = self.waiting_reqs.pop(0) + + for img in visual_req.multimodal_params.images: + images_need_infer.append(img) + + if len(images_need_infer) == self.infer_batch_size: + await self.infer_imgs(images_need_infer) + images_need_infer = [] + + if len(images_need_infer) > 0: + await self.infer_imgs(images_need_infer) + images_need_infer = [] + # 在这里release这个image,ref-1 + logger.info(f"req-id {visual_req.group_req_id} has been release ok") + def clean_up(self): for model_rpc in self.model_rpcs: model_rpc.rpc_server_process.kill() @@ -190,6 +271,17 @@ def clean_up(self): return +def create_forward_loop(args, visualserver: VisualManager, loop: asyncio.AbstractEventLoop): + if args.run_mode == "visual": + from .register_loop import register_loop + + loop.create_task(visualserver.loop_for_fwd_visual_only()) + loop.create_task(register_loop(args)) + else: + loop.create_task(visualserver.loop_for_fwd()) + return + + def start_visual_process(args, model_rpc_ports, pipe_writer): # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) @@ -211,6 +303,6 @@ def handle_exception(loop, context): loop = asyncio.new_event_loop() loop.set_exception_handler(handle_exception) asyncio.set_event_loop(loop) - loop.create_task(visualserver.loop_for_fwd()) + create_forward_loop(args, visualserver, loop) loop.run_until_complete(visualserver.loop_for_netio_req()) return diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index a9409ceb9..5e2a96a40 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -18,12 +18,23 @@ from lightllm.models.qwen2_vl.qwen2_visual import Qwen2VisionTransformerPretrainedModel from lightllm.models.qwen2_5_vl.qwen2_5_visual import Qwen2_5_VisionTransformerPretrainedModel from lightllm.models.tarsier2.tarsier2_visual import TarsierVisionTransformerPretrainedModel -from lightllm.server.embed_cache.utils import tensor2bytes, read_shm, create_shm, get_shm_name_data, get_shm_name_embed +from lightllm.server.embed_cache.utils import ( + tensor2bytes, + read_shm, + create_shm, + create_afs, + get_shm_name_data, + get_shm_name_embed, +) from lightllm.utils.infer_utils import set_random_seed from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end from lightllm.utils.dist_utils import init_vision_distributed_env from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.log_utils import init_logger +import pickle + +logger = init_logger(__name__) class VisualModelRpcServer(rpyc.Service): @@ -32,16 +43,22 @@ def exposed_init_model(self, kvargs): import torch import torch.distributed as dist - 
self.vit_dp = kvargs["vit_dp"] - self.vit_tp = kvargs["vit_tp"] + self.args = get_env_start_args() + + weight_dir = self.args.model_dir + cache_port = self.args.cache_port + data_type = self.args.data_type + quant_type = self.args.vit_quant_type + quant_cfg = self.args.vit_quant_cfg + max_batch_size = min(self.args.visual_infer_batch_size // self.args.visual_dp, 1) + remote_vit = True if self.args.run_mode == "visual" else False + + self.image_embed_dir = self.args.image_embed_dir self.dp_rank_id = kvargs["dp_rank_id"] self.tp_rank_id = kvargs["tp_rank_id"] - self.cache_port = kvargs["cache_port"] - weight_dir = kvargs["weight_dir"] - self.vit_rank_id = kvargs["vit_rank_id"] - self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True}) + kvargs["vit_rank_id"] = self.dp_rank_id * self.args.visual_tp + self.tp_rank_id + self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True}) self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - self.data_type = kvargs["data_type"] init_vision_distributed_env(kvargs) model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir) @@ -49,14 +66,15 @@ def exposed_init_model(self, kvargs): try: kvargs = { "weight_dir": weight_dir, - "data_type": self.data_type, - "quant_type": kvargs["quant_type"], - "quant_cfg": kvargs["quant_cfg"], - "max_batch_size": kvargs["max_batch_size"], + "data_type": data_type, + "quant_type": quant_type, + "quant_cfg": quant_cfg, + "max_batch_size": max_batch_size, + "remote_vit": remote_vit, } self.model_type = model_cfg["model_type"] if self.model_type == "qwen": - self.model = QWenVisionTransformer(**model_cfg["visual"]).eval().bfloat16() + self.model = QWenVisionTransformer(kvargs, **model_cfg["visual"]).eval().bfloat16() elif self.model_type == "qwen2_vl": self.model = ( Qwen2VisionTransformerPretrainedModel(kvargs, **model_cfg["vision_config"]).eval().bfloat16() @@ -66,17 +84,16 @@ def exposed_init_model(self, kvargs): Qwen2_5_VisionTransformerPretrainedModel(kvargs, **model_cfg["vision_config"]).eval().bfloat16() ) elif model_cfg["architectures"][0] == "TarsierForConditionalGeneration": - self.model = TarsierVisionTransformerPretrainedModel(**model_cfg).eval().bfloat16() + self.model = TarsierVisionTransformerPretrainedModel(kvargs, **model_cfg).eval().bfloat16() elif self.model_type == "llava": - self.model = LlavaVisionModel() + self.model = LlavaVisionModel(kvargs) elif self.model_type == "internvl_chat": self.model = VisionTransformer(kvargs) # self.model = InternVLVisionModel() elif self.model_type == "gemma3": - self.model = Gemma3VisionModel() + self.model = Gemma3VisionModel(kvargs) else: raise Exception(f"can not support {self.model_type} now") - self.model.load_model(weight_dir) self.model = self.model.cuda() except Exception as e: @@ -99,18 +116,20 @@ def forward(self, images: List[ImageItem]): def exposed_encode(self, images: List[ImageItem]): images = obtain(images) all_img_embeds, uuids, valid_ids = self.forward(images) - all_img_embeds = all_img_embeds.to(torch.device("cpu")) - + all_img_embeds = all_img_embeds.to(torch.device("cpu"), non_blocking=True) if self.tp_rank_id == 0: - ready_flags = obtain(self.cache_client.root.get_items_embed(uuids)) + # ready_flags = obtain(self.cache_client.root.get_items_embed(uuids)) ids_to_set = [] - for i, ready in enumerate(ready_flags): - if ready: - continue + for i in range(len(images)): + # if ready: + # continue uid = uuids[i] start, end = valid_ids[i] cur_embed_bytes = 
tensor2bytes(all_img_embeds[start:end])
-            create_shm(get_shm_name_embed(uid), cur_embed_bytes)
+            if self.args.run_mode == "visual":
+                create_afs(get_shm_name_embed(uid), cur_embed_bytes, self.image_embed_dir)
+            else:
+                create_shm(get_shm_name_embed(uid), cur_embed_bytes)
             ids_to_set.append(uid)
         if ids_to_set:
             self.cache_client.root.set_items_embed(ids_to_set)
@@ -118,11 +137,13 @@
 
 
 class VisualModelRpcClient:
-    def __init__(self, model_rpc, vit_tp, rpc_server_process=None):
-        self.model: VisualModelRpcServer = model_rpc
+    def __init__(self, conn, vit_tp, rpc_server_process=None):
+        self.conn = conn
+        self.model: VisualModelRpcServer = conn.root
         self.vit_tp = vit_tp
         self.rpc_server_process = rpc_server_process
         self.use_rpc = True
+        self._bg = rpyc.BgServingThread(self.conn)
         if self.use_rpc:
 
             def async_wrap(f):
@@ -195,4 +216,4 @@ async def start_model_process(port, vit_tp, device_id):
         raise Exception("init rpc env error!")
 
     assert proc.is_alive()
-    return VisualModelRpcClient(con.root, vit_tp, rpc_server_process=proc)
+    return VisualModelRpcClient(con, vit_tp, rpc_server_process=proc)
diff --git a/lightllm/server/visualserver/register_loop.py b/lightllm/server/visualserver/register_loop.py
new file mode 100644
index 000000000..31d0f7b8a
--- /dev/null
+++ b/lightllm/server/visualserver/register_loop.py
@@ -0,0 +1,42 @@
+import asyncio
+import pickle
+import websockets
+import socket
+from lightllm.utils.net_utils import get_hostname_ip
+from lightllm.utils.log_utils import init_logger
+from .vit_connect import VIT_Obj
+
+logger = init_logger(__name__)
+
+
+async def register_loop(args):
+    assert args.host not in ["127.0.0.1", "localhost"], "remote visual server must specify host ip"
+
+    if args.host in ["0.0.0.0"]:
+        host_ip = get_hostname_ip()
+    else:
+        host_ip = args.host
+
+    while True:
+
+        try:
+            uri = f"ws://{args.config_server_host}:{args.config_server_port}/visual_register"
+            async with websockets.connect(uri, max_queue=(2048 * 1024, 2048 * 1023)) as websocket:
+
+                sock = websocket.transport.get_extra_info("socket")
+                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+
+                vit_obj = VIT_Obj(node_id=args.visual_node_id, host_ip_port=f"{host_ip}:{args.remote_vit_port}")
+
+                await websocket.send(pickle.dumps(vit_obj))
+                logger.info(f"Sent registration vit_obj: {vit_obj}")
+
+                while True:
+                    await websocket.send("heartbeat")
+                    await asyncio.sleep(40)
+
+        except Exception as e:
+            logger.error("connection to config_server failed")
+            logger.exception(str(e))
+            await asyncio.sleep(10)
+            logger.info("reconnecting to config_server")
diff --git a/lightllm/server/visualserver/vit_connect.py b/lightllm/server/visualserver/vit_connect.py
new file mode 100644
index 000000000..7a1443f02
--- /dev/null
+++ b/lightllm/server/visualserver/vit_connect.py
@@ -0,0 +1,236 @@
+import asyncio
+import zmq
+import zmq.asyncio
+import time
+import pickle
+from typing import Dict, List, Optional, Any
+from lightllm.utils.log_utils import init_logger
+from lightllm.server.core.objs.io_objs import GroupReqObjs, GroupReqIndexes
+from lightllm.server.multimodal_params import MultimodalParams
+import httpx
+import base64
+from dataclasses import dataclass
+import rpyc
+
+logger = init_logger(__name__)
+
+
+@dataclass
+class VIT_Obj:
+    node_id: int
+    host_ip_port: str
+
+    def to_log_str(self):
+        return f"VIT host_ip_port: {self.host_ip_port} node_id: {self.node_id}"
+
+
+class VITConnectionManager:
+    """VIT connection manager."""
+
+    def __init__(self, args, context, local_visual_port: int, cache_client: rpyc.Connection):
+        self.args = args
+        self.context = context
+        self.local_visual_port = local_visual_port
+
+        self.send_to_visual = None
+        self.remote_vit_instances = {}
+        self.current_vit_index = 0
+        self.remote_vit = args.enable_remote_vit
+        self.remote_vit_port = args.remote_vit_port
+        self.cache_client = cache_client
+
+        self._setup_vit_connections()
+
+    def _setup_vit_connections(self):
+        """
+        Set up VIT connections, supporting local and remote VIT instances.
+        Two connection modes are supported:
+        1. a local VIT instance (default)
+        2. multiple remote VIT instances (load balanced)
+        """
+        if self.remote_vit:
+            # remote VIT instance mode
+            self._setup_remote_vit_connections()
+        else:
+            logger.debug("using local VIT connection")
+            self._setup_local_vit_connection()
+
+    def _setup_local_vit_connection(self):
+        self.send_to_visual = self.context.socket(zmq.PUSH)
+        self.send_to_visual.connect(f"{self.args.zmq_mode}127.0.0.1:{self.local_visual_port}")
+        logger.info(f"Connected to local VIT instance at {self.args.zmq_mode}127.0.0.1:{self.local_visual_port}")
+
+    def _setup_remote_vit_connections(self):
+        """
+        Initialize remote VIT connections, fetching the initial instances synchronously.
+        """
+        logger.info("Setting up remote VIT connections...")
+
+        self._sync_init_vit_instances()
+
+        retry_count = 0
+        max_retries = 30  # wait at most 30 seconds
+        while len(self.remote_vit_instances) == 0 and retry_count < max_retries:
+            logger.info(f"Waiting for VIT instances... (attempt {retry_count + 1}/{max_retries})")
+            time.sleep(1)
+            retry_count += 1
+            self._sync_init_vit_instances()
+
+        if len(self.remote_vit_instances) == 0:
+            logger.warning("No VIT instances available after initialization")
+        else:
+            logger.info(f"Successfully connected to {len(self.remote_vit_instances)} VIT instances")
+
+    def _sync_init_vit_instances(self):
+        """
+        Synchronously initialize connections to VIT instances.
+        """
+        try:
+            # fetch the registered VIT instances synchronously
+            vit_objs = self._sync_get_vit_objs()
+            if vit_objs:
+                self._update_vit_connections(vit_objs)
+        except Exception as e:
+            logger.error(f"Failed to initialize VIT instances: {e}")
+
+    def _sync_get_vit_objs(self) -> Optional[Dict[int, VIT_Obj]]:
+        """
+        Synchronously fetch VIT instance info from the config server.
+        """
+        import requests
+
+        uri = f"http://{self.args.config_server_host}:{self.args.config_server_port}/registered_visual_objects"
+        try:
+            response = requests.get(uri, timeout=10)
+            if response.status_code == 200:
+                base64data = response.json()["data"]
+                id_to_vit_obj = pickle.loads(base64.b64decode(base64data))
+                return id_to_vit_obj
+            else:
+                logger.error(f"Failed to get VIT instances: {response.status_code}")
+                return None
+        except Exception as e:
+            logger.error(f"Error getting VIT instances: {e}")
+            return None
+
+    def _update_vit_connections(self, id_to_vit_obj: Dict[int, VIT_Obj]):
+        """
+        Update VIT connections: add new ones and close stale ones.
+        """
+        # close connections that no longer exist
+        closed_ids = []
+        for id, remote_instance in self.remote_vit_instances.items():
+            if id not in id_to_vit_obj:
+                try:
+                    remote_instance.close()
+                except Exception:
+                    pass
+                closed_ids.append(id)
+                logger.info(f"Closed VIT connection {id}")
+
+        for id in closed_ids:
+            self.remote_vit_instances.pop(id)
+
+        # establish new connections
+        for id, vit_obj in id_to_vit_obj.items():
+            if id not in self.remote_vit_instances:
+                try:
+                    socket = self.context.socket(zmq.PUSH)
+                    # print(vit_obj.host_ip_port, self.args.remote_vit_port, flush=True)
+                    ip, port = vit_obj.host_ip_port.split(":")
+                    socket.connect(f"tcp://{ip}:{port}")
+                    self.remote_vit_instances[id] = socket
+                    logger.info(f"Connected to VIT instance {id} at {vit_obj.host_ip_port}")
+                except Exception as e:
+                    logger.error(f"Failed to connect to VIT instance {id}: {e}")
+
+    def _get_vit_instance(self):
+        """
+        Get the next available VIT instance (round-robin load balancing).
+        """
+        if not self.remote_vit:
+            return self.send_to_visual
+
+        if len(self.remote_vit_instances) == 0:
+            raise Exception("No available VIT instances")
+
+        # simple round-robin load balancing
+        index = (self.current_vit_index + 1) % len(self.remote_vit_instances)
+        self.current_vit_index = index
+        return list(self.remote_vit_instances.values())[index]
+
+    async def send_to_vit(self, req: GroupReqIndexes, protocol=pickle.HIGHEST_PROTOCOL, embeding_only=False):
+        """
+        Send a request to a VIT instance, supporting both local and remote modes.
+        """
+        instance = self._get_vit_instance()
+        # in local mode, free the image data early to reduce transfer overhead
+        if not self.remote_vit:
+            req.multimodal_params.free()
+
+        try:
+            logger.debug(f"sending group req to VIT instance {instance}")
+            instance.send_pyobj(req, protocol=protocol)
+        except Exception as e:
+            logger.error(f"Failed to send to VIT instance: {e}")
+            raise Exception(f"Failed to send to VIT instance: {e}")
+
+        # in remote mode, free the image data only after the request has been sent
+        await self._wait_visual_embed_ready(req, embeding_only)
+        if self.remote_vit:
+            req.multimodal_params.free()
+
+    async def vit_handle_loop(self):
+        """
+        Asynchronous VIT connection management loop, started by the caller.
+        """
+        if not self.remote_vit:
+            return
+        logger.info("Starting VIT connection management loop")
+        while True:
+            try:
+                id_to_vit_obj = await self._async_get_vit_objs()
+                if id_to_vit_obj:
+                    self._update_vit_connections(id_to_vit_obj)
+                await asyncio.sleep(30)
+            except Exception as e:
+                logger.exception(f"Error in VIT handle loop: {e}")
+                await asyncio.sleep(10)
+
+    async def _async_get_vit_objs(self) -> Optional[Dict[int, VIT_Obj]]:
+        """
+        Asynchronously fetch VIT instance info from the config server.
+        """
+        uri = f"http://{self.args.config_server_host}:{self.args.config_server_port}/registered_visual_objects"
+        try:
+            async with httpx.AsyncClient() as client:
+                response = await client.get(uri)
+                if response.status_code == 200:
+                    base64data = response.json()["data"]
+                    id_to_vit_obj = pickle.loads(base64.b64decode(base64data))
+                    return id_to_vit_obj
+                else:
+                    logger.error(f"Failed to get VIT instances: {response.status_code}")
+                    return None
+        except Exception as e:
+            logger.exception(f"Error getting VIT instances: {e}")
+            return None
+
+    async def _wait_visual_embed_ready(
+        self, req: GroupReqIndexes, embeding_only: bool = False, timeout_seconds: int = 1000
+    ):
+        # no need to wait in local mode
+        if not self.remote_vit:
+            return
+        uuids = req.multimodal_params.get_all_uuids()
+
+        async def wait_for_embeds():
+            while not all(self.cache_client.root.get_items_embed(uuids, embeding_only)):
+                await asyncio.sleep(0.01)
+
+        try:
+            await asyncio.wait_for(wait_for_embeds(), timeout=timeout_seconds)
+        except asyncio.TimeoutError:
+            logger.error(
+                f"Req {req.group_req_id}: timeout waiting for visual embed ready after {timeout_seconds} seconds"
+            )
diff --git a/lightllm/utils/dist_utils.py b/lightllm/utils/dist_utils.py
index 65ac401d4..f0b06ead1 100644
--- a/lightllm/utils/dist_utils.py
+++ b/lightllm/utils/dist_utils.py
@@ -55,19 +55,23 @@ def get_environ(environ_name):
 
 
 def init_vision_distributed_env(kvargs):
-    tp_world_size = kvargs["vit_tp"]
+    from lightllm.utils.envs_utils import get_env_start_args
+
+    args = get_env_start_args()
+    tp_world_size = args.visual_tp
     dp_size = 1
     tp_rank_id = kvargs["tp_rank_id"]
     set_dp_size(dp_size)
     set_dp_world_size(tp_world_size)
     set_current_rank_in_dp(tp_rank_id)
-    visual_gpu_ids = kvargs["visual_gpu_ids"]
+    visual_gpu_ids = args.visual_gpu_ids
     device_id = visual_gpu_ids[kvargs["vit_rank_id"]]
     set_current_device_id(device_id)
     torch.cuda.set_device(device_id)
+    visual_nccl_port = args.visual_nccl_ports[kvargs["dp_rank_id"]]
     dist.init_process_group(
         "nccl",
-        init_method=f'tcp://127.0.0.1:{kvargs["visual_nccl_port"]}',
+        init_method=f"tcp://127.0.0.1:{visual_nccl_port}",
         rank=kvargs["tp_rank_id"],
         world_size=tp_world_size,
         device_id=torch.device(f"cuda:{device_id}"),
diff --git a/lightllm/utils/redis_utils.py b/lightllm/utils/redis_utils.py
new file mode 100644
index 000000000..0cf19afb9
--- /dev/null
+++ b/lightllm/utils/redis_utils.py
@@ -0,0 +1,74 @@
+import subprocess
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
+
+def start_redis_service(args):
+    """Launch the redis service."""
+    if not hasattr(args, "start_redis") or not args.start_redis:
+        return None
+
+    config_server_host = args.config_server_host
+    redis_port = args.redis_port
+    try:
+        subprocess.run(
+            ["redis-cli", "-h", config_server_host, "-p", str(redis_port), "FLUSHALL", "ASYNC"], check=False, timeout=2
+        )
+        subprocess.run(
+            ["redis-cli", "-h", config_server_host, "-p", str(redis_port), "SHUTDOWN", "NOSAVE"], check=False, timeout=2
+        )
+    except Exception:
+        pass
+
+    try:
+        redis_command = [
+            "redis-server",
+            "--port",
+            str(redis_port),
+            "--bind",
+            f"{config_server_host}",
+            "--daemonize",
+            "no",
+            "--logfile",
+            "-",
+            "--loglevel",
+            "notice",
+            "--save",
+            '""',  # do not trigger RDB snapshots
+            "--appendonly",
+            "no",  # disable AOF
+        ]
+
+        logger.info(f"Starting Redis service on port {redis_port}")
+        redis_process = subprocess.Popen(redis_command)
+
+        import redis
+        import time
+
+        max_wait = 10
+        start_time = time.time()
+
+        while time.time() - start_time < max_wait:
+            try:
+                r = redis.Redis(host=args.config_server_host, port=redis_port, socket_connect_timeout=1)
+                r.ping()
+                logger.info(f"Redis service started successfully on port {redis_port}")
+                del r
+                break
+            except Exception:
+                time.sleep(0.5)
+                if redis_process.poll() is not None:
+                    logger.error("Redis service failed to start")
+                    return None
+        else:
+            logger.error("Redis service startup timeout")
+            if redis_process.poll() is None:
+                redis_process.terminate()
+            return None
+
+        return redis_process
+
+    except Exception as e:
+        logger.error(f"Failed to start Redis service: {e}")
+        return None
diff --git a/lightllm/utils/start_utils.py b/lightllm/utils/start_utils.py
index 372b7e1cf..824543108 100644
--- a/lightllm/utils/start_utils.py
+++ b/lightllm/utils/start_utils.py
@@ -111,4 +111,12 @@ def kill_recursive(proc):
             logger.warning(f"Process {proc.pid} does not exist.")
 
 
+def is_multimodal_mode(args):
+    from transformers import PretrainedConfig
+
+    model_cfg, _ = PretrainedConfig.get_config_dict(args.model_dir)
+    is_multimodal = "visual" in model_cfg or "vision_config" in model_cfg
+    return is_multimodal
+
+
 process_manager = SubmoduleManager()
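
Note: the new is_multimodal_mode helper decides purely from the checkpoint's HuggingFace config whether a vision tower is present (a "visual" or "vision_config" key). A minimal usage sketch follows; the SimpleNamespace args object, the checkpoint path, and the launcher-side branching are placeholders for illustration only, not part of this patch.

from types import SimpleNamespace

from lightllm.utils.start_utils import is_multimodal_mode

# hypothetical launcher-side check; the checkpoint path is a placeholder
args = SimpleNamespace(model_dir="/path/to/InternVL2-8B")

if is_multimodal_mode(args):
    # config carries a "visual" / "vision_config" section: a ViT worker is needed,
    # either local or a remote visual server registered via register_loop
    print("multimodal checkpoint: launch visual subprocess")
else:
    print("text-only checkpoint: skip visual subprocess")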