diff --git a/docs/CN/source/tutorial/api_server_args.rst b/docs/CN/source/tutorial/api_server_args.rst
index 6d372fee2..42ab04e4f 100644
--- a/docs/CN/source/tutorial/api_server_args.rst
+++ b/docs/CN/source/tutorial/api_server_args.rst
@@ -272,6 +272,12 @@ PD 分离模式参数
 
     多模态资源的缓存服务器容量,默认为 ``200``
 
+.. option:: --max_image_token_count
+
+    单张图片在转换为 token 后允许的最大 token 数量,默认为 ``6128``
+
+    当任意图片超过该阈值时,请求会被拒绝。
+
 .. option:: --visual_infer_batch_size
 
     每次推理批次中处理的图像数量,默认为 ``1``
diff --git a/docs/EN/source/tutorial/api_server_args.rst b/docs/EN/source/tutorial/api_server_args.rst
index ab5143a47..7f3f8f208 100644
--- a/docs/EN/source/tutorial/api_server_args.rst
+++ b/docs/EN/source/tutorial/api_server_args.rst
@@ -270,6 +270,12 @@ Multimodal Parameters
 
     Cache server capacity for multimodal resources, default is ``200``
 
+.. option:: --max_image_token_count
+
+    Maximum allowed token count for a single image after tokenization, default is ``6128``
+
+    Requests are rejected when any image exceeds this limit.
+
 .. option:: --visual_infer_batch_size
 
     Number of images processed in each inference batch, default is ``1``
diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index 4a345000b..26c651b15 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -442,6 +442,12 @@ def make_argument_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--cache_capacity", type=int, default=200, help="cache server capacity for multimodal resources"
     )
+    parser.add_argument(
+        "--max_image_token_count",
+        type=int,
+        default=6128,
+        help="maximum allowed token count for one image after tokenization",
+    )
     parser.add_argument(
         "--embed_cache_storage_size",
         type=float,
diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py
index ce613b105..b02094eed 100644
--- a/lightllm/server/core/objs/start_args_type.py
+++ b/lightllm/server/core/objs/start_args_type.py
@@ -95,6 +95,7 @@ class StartArgs:
     enable_decode_microbatch_overlap: bool = field(default=False)
     enable_prefill_microbatch_overlap: bool = field(default=False)
     cache_capacity: int = field(default=200)
+    max_image_token_count: int = field(default=6128)
     embed_cache_storage_size: float = field(default=4)
     data_type: Optional[str] = field(
         default=None, metadata={"choices": ["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"]}
diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py
index d54de63f3..02ed716f1 100644
--- a/lightllm/server/httpserver/manager.py
+++ b/lightllm/server/httpserver/manager.py
@@ -181,6 +181,17 @@ async def _alloc_resource(self, items, md5sums, token_nums, datas):
         self.cache_client.root.set_items_data(update_data_ids)
         return
 
+    def _assert_image_token_count(self, token_num: int):
+        if token_num > self.args.max_image_token_count:
+            err_msg = (
+                f"single image token count {token_num} exceeds max_image_token_count "
+                f"{self.args.max_image_token_count}. You can increase this limit by setting --max_image_token_count "
+                f"to a larger value when starting LightLLM. Warning: increasing this limit raises runtime OOM risk."
+            )
+            logger.warning(err_msg)
+            raise ValueError(err_msg)
+        return
+
     async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams, sampling_params: SamplingParams):
         # 只有 P 和 NORMAL 节点需要真的管理多模态资源
         if self.pd_mode.is_P_or_NORMAL():
@@ -190,6 +201,7 @@ async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams,
                 data = img.read()
                 # must after init_imageitem_extral_params
                 token_num = self.tokenizer.get_image_token_length(img)
+                self._assert_image_token_count(token_num)
                 md5sum = hashlib.md5(data).hexdigest() + "_" + str(hash(frozendict(img.extra_params)))
                 md5sums.append(md5sum)
                 img.md5 = md5sum
@@ -245,7 +257,9 @@ def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwar
         for img in multimodal_params.images:
             img_count += 1
             self.tokenizer.init_imageitem_extral_params(img, multimodal_params, samping_params)
-            image_tokens += self.tokenizer.get_image_token_length(img)
+            token_num = self.tokenizer.get_image_token_length(img)
+            self._assert_image_token_count(token_num)
+            image_tokens += token_num
         for audio in multimodal_params.audios:
             audio_count += 1
             self.tokenizer.init_audioitem_extral_params(audio, multimodal_params, samping_params)
diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py
index af7a1e29f..307a3d48a 100644
--- a/lightllm/server/httpserver_for_pd_master/manager.py
+++ b/lightllm/server/httpserver_for_pd_master/manager.py
@@ -81,7 +81,16 @@ def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwar
         for img in multimodal_params.images:
             img_count += 1
             self.tokenizer.init_imageitem_extral_params(img, multimodal_params, samping_params)
-            image_tokens += self.tokenizer.get_image_token_length(img)
+            token_num = self.tokenizer.get_image_token_length(img)
+            if token_num > self.args.max_image_token_count:
+                err_msg = (
+                    f"the image token count {token_num} > max_image_token_count {self.args.max_image_token_count}. "
+                    f"You can increase this limit by setting --max_image_token_count to a larger value when starting "
+                    f"LightLLM. Warning: increasing this limit raises runtime OOM risk."
+                )
+                logger.warning(err_msg)
+                raise ValueError(err_msg)
+            image_tokens += token_num
         for audio in multimodal_params.audios:
             audio_count += 1
             self.tokenizer.init_audioitem_extral_params(audio, multimodal_params, samping_params)
diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py
index 24f8da6e6..045723d07 100644
--- a/lightllm/server/router/manager.py
+++ b/lightllm/server/router/manager.py
@@ -436,9 +436,11 @@ def _generate_new_batch(self):
             new_batch = self.req_queue.generate_new_batch(
                 Batch.merge_two_batch(self.running_batch, self.schedule_new_batch)
            )
+
+            if new_batch is not None and len(new_batch.reqs) > 0:
+                logger.info(f"generate new batch, {new_batch.simple_log()}")
+
             self.schedule_new_batch = Batch.merge_two_batch(self.schedule_new_batch, new_batch)
-            if self.schedule_new_batch is not None:
-                logger.info(f"gen new batch, {self.schedule_new_batch.simple_log()}")
         return
 
     def _multinode_tp_generate_new_batch(self):
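
The sketch below is not part of the patch; it is a minimal, hypothetical illustration of how the new per-image limit behaves. `_FakeManager`, the `SimpleNamespace` stand-in for the parsed start args, and the sample token counts are invented for demonstration; only the comparison against `args.max_image_token_count` mirrors the logic added above.

# Hypothetical sketch of the per-image token limit (not part of this patch).
from types import SimpleNamespace


class _FakeManager:
    """Stand-in with just enough state to exercise the new check."""

    def __init__(self, max_image_token_count: int = 6128):
        # Mimics the parsed start args carried by the real manager objects.
        self.args = SimpleNamespace(max_image_token_count=max_image_token_count)

    def _assert_image_token_count(self, token_num: int):
        # Same shape as the check added to lightllm/server/httpserver/manager.py.
        if token_num > self.args.max_image_token_count:
            raise ValueError(
                f"single image token count {token_num} exceeds max_image_token_count "
                f"{self.args.max_image_token_count}."
            )


mgr = _FakeManager()                     # default limit of 6128 tokens per image
mgr._assert_image_token_count(4096)      # within the limit: passes silently
try:
    mgr._assert_image_token_count(8192)  # over the limit: the request would be rejected
except ValueError as exc:
    print(exc)

As the added help text notes, the limit is configurable at startup via --max_image_token_count, at the cost of a higher runtime OOM risk when it is raised.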