diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py index 1d5dc9c33f9..5b5dce2c079 100644 --- a/fastdeploy/cache_manager/prefix_cache_manager.py +++ b/fastdeploy/cache_manager/prefix_cache_manager.py @@ -598,6 +598,21 @@ def update_cache_blocks(self, task, block_size, num_computed_tokens): logger.error(f"update_cache_blocks, error: {type(e)} {e}, {str(traceback.format_exc())}") raise e + def is_chunked_mm_input(self, mm_inputs, matched_token_num): + """ + check if mm_inputs is chunked + """ + if mm_inputs is None or "mm_positions" not in mm_inputs or len(mm_inputs["mm_positions"]) == 0: + return False, 0 + + for idx in range(len(mm_inputs["mm_positions"])): + position = mm_inputs["mm_positions"][idx] + if position.offset < matched_token_num < position.offset + position.length: + return True, idx + elif matched_token_num < position.offset: + break + return False, 0 + def request_match_blocks(self, task, block_size, *args): """ get match blocks info for a task. @@ -617,9 +632,12 @@ def request_match_blocks(self, task, block_size, *args): """ with self.request_release_lock: try: - hit_info = {} - hit_info["gpu_cache_blocks"] = 0 - hit_info["cpu_cache_blocks"] = 0 + hit_info = { + "gpu_cache_blocks": 0, + "cpu_cache_blocks": 0, + "gpu_match_token_num": 0, + "cpu_match_token_num": 0, + } self.metrics.req_count += 1 if isinstance(task.prompt_token_ids, np.ndarray): prompt_token_ids = task.prompt_token_ids.tolist() @@ -673,8 +691,10 @@ def request_match_blocks(self, task, block_size, *args): gpu_match_token_num, input_token_num, ) - hit_info["gpu_cache_blocks"] = gpu_match_token_num // block_size - hit_info["cpu_cache_blocks"] = cpu_match_token_num // block_size + hit_info["gpu_cache_blocks"] = len(match_gpu_block_ids) + hit_info["cpu_cache_blocks"] = len(match_cpu_block_ids) + hit_info["gpu_match_token_num"] = gpu_match_token_num + hit_info["cpu_match_token_num"] = cpu_match_token_num self.metrics._update_history_hit_metrics() if self.metrics.req_count % 10000 == 0: self.metrics.reset_metrics() @@ -685,8 +705,8 @@ def request_match_blocks(self, task, block_size, *args): self.req_leaf_map[req_id] = match_block_node self.leaf_req_map[match_block_node].add(req_id) # record request cache info - self.cache_info[req_id] = (match_block_node, matched_token_num) - task.cached_block_num = matched_token_num // block_size + self.cache_info[req_id] = (match_block_node, len(common_block_ids) * block_size) + task.cached_block_num = len(common_block_ids) return common_block_ids, matched_token_num, hit_info except Exception as e: logger.error(f"request_match_blocks: request_block_ids: error: {type(e)} {e}") @@ -1202,6 +1222,64 @@ def hash_block_features(self, input_ids, extra_keys: list = []): """ return hashlib.sha256(pickle.dumps((input_ids, extra_keys))).hexdigest() + def _revert_match_blocks( + self, + request, + matched_token_num: int, + block_size: int, + chunk_idx: int, + match_node_ids: list, + matche_nodes: list, + match_gpu_block_ids: list, + match_cpu_block_ids: list, + gpu_match_token_num: int, + cpu_match_token_num: int, + swap_node_ids: list, + ): + position = request.multimodal_inputs["mm_positions"][chunk_idx] + revert_tokens = matched_token_num - position.offset + match_block_ids = [node.block_id for node in matche_nodes] + logger.warning( + f"match_block: req_id {request.request_id} revert tokens: {revert_tokens} from matched nodes: {match_block_ids}" + ) + while revert_tokens >= block_size: + if len(matche_nodes) == 0: + logger.error(f"req_id {request.request_id} revert nodes error, tokens: {revert_tokens}") + break + revert_tokens -= block_size + revert_block = matche_nodes.pop() + revert_block_id = revert_block.block_id + if revert_block_id in match_gpu_block_ids: + match_gpu_block_ids.remove(revert_block_id) + match_node_ids.remove(revert_block.node_id) + gpu_match_token_num -= block_size + elif revert_block_id in match_cpu_block_ids: + match_cpu_block_ids.remove(revert_block_id) + match_node_ids.remove(revert_block.node_id) + cpu_match_token_num -= block_size + else: + logger.error( + f"req_id {request.request_id} revert nodes error, nodes: {revert_block_id}, " + f"match_gpu_block_ids: {match_gpu_block_ids}, match_cpu_block_ids: {match_cpu_block_ids}" + ) + break + if revert_block_id in swap_node_ids: + swap_node_ids.remove(revert_block_id) + + if revert_tokens > 0: + last_block_id = matche_nodes[-1].block_id + if last_block_id in match_gpu_block_ids: + gpu_match_token_num -= revert_tokens + elif last_block_id in match_cpu_block_ids: + cpu_match_token_num -= revert_tokens + else: + logger.error( + f"req_id {request.request_id} revert nodes error, revert_tokens: {revert_tokens}, nodes: {last_block_id}, " + f"match_gpu_block_ids: {match_gpu_block_ids}, match_cpu_block_ids: {match_cpu_block_ids}" + ) + current_node = self.radix_tree_root if len(matche_nodes) == 0 else matche_nodes[-1] + return gpu_match_token_num, cpu_match_token_num, current_node + def mm_match_block(self, request, block_size): """ Match and retrieve cached blocks for multimodal requests using a radix tree structure. @@ -1290,6 +1368,28 @@ def mm_match_block(self, request, block_size): if has_modified_cpu_lru_leaf_heap: heapq.heapify(self.cpu_lru_leaf_heap) + if self.cache_config.disable_chunked_mm_input: + matched_token_num = gpu_match_token_num + cpu_match_token_num + is_chunked, chunk_idx = self.is_chunked_mm_input(request.multimodal_inputs, matched_token_num) + if is_chunked: + ( + gpu_match_token_num, + cpu_match_token_num, + current_match_node, + ) = self._revert_match_blocks( + request=request, + matched_token_num=matched_token_num, + block_size=block_size, + chunk_idx=chunk_idx, + match_node_ids=match_node_ids, + matche_nodes=matche_nodes, + match_gpu_block_ids=match_gpu_block_ids, + match_cpu_block_ids=match_cpu_block_ids, + gpu_match_token_num=gpu_match_token_num, + cpu_match_token_num=cpu_match_token_num, + swap_node_ids=swap_node_ids, + ) + logger.info(f"match_block: req_id {request.request_id} matched nodes: {match_node_ids}") return ( match_gpu_block_ids, diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 2b04bd2c42f..245549440cf 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -1202,6 +1202,7 @@ def __init__(self, args): self.swap_space = None self.max_encoder_cache = None self.max_processor_cache = None + self.disable_chunked_mm_input = False for key, value in args.items(): if hasattr(self, key): setattr(self, key, value) diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index e6eb22f95e5..a48784978ef 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -314,6 +314,10 @@ class EngineArgs: """ additional decode block num """ + disable_chunked_mm_input: bool = False + """ + Disable chunked_mm_input for multi-model inference. + """ scheduler_name: str = "local" """ @@ -936,6 +940,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help="ports for rdma communication.", ) + perf_group.add_argument( + "--disable-chunked-mm-input", + action="store_true", + default=EngineArgs.disable_chunked_mm_input, + help="Disable chunked mm input.", + ) + # Router parameters group router_group = parser.add_argument_group("Router") router_group.add_argument( diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index 904edbbbbc5..0ffbc3aa7c8 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -771,8 +771,8 @@ def get_prefix_cached_blocks(self, request: Request): ) request.num_cached_tokens = matched_token_num - request.gpu_cache_token_num = hit_info["gpu_cache_blocks"] * self.config.cache_config.block_size - request.cpu_cache_token_num = hit_info["cpu_cache_blocks"] * self.config.cache_config.block_size + request.gpu_cache_token_num = hit_info["gpu_match_token_num"] + request.cpu_cache_token_num = hit_info["cpu_match_token_num"] request.cache_info = (matched_block_num, no_cache_block_num) request.block_tables = common_block_ids request.skip_allocate = False diff --git a/tests/v1/cache_manager/test_revert_blocks.py b/tests/v1/cache_manager/test_revert_blocks.py new file mode 100644 index 00000000000..e47510d0c8c --- /dev/null +++ b/tests/v1/cache_manager/test_revert_blocks.py @@ -0,0 +1,300 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from dataclasses import asdict +from types import SimpleNamespace + +from fastdeploy.cache_manager.cache_data import BlockNode +from fastdeploy.cache_manager.prefix_cache_manager import PrefixCacheManager +from fastdeploy.config import CacheConfig, FDConfig, ParallelConfig +from fastdeploy.engine.args_utils import EngineArgs +from fastdeploy.engine.request import ImagePosition, Request +from fastdeploy.scheduler import SchedulerConfig + + +def make_prefix_cache_manager(max_num_seqs, enable_mm=False, num_gpu_blocks_override=100, max_num_batched_tokens=3200): + engine_args = EngineArgs( + max_num_seqs=max_num_seqs, + num_gpu_blocks_override=num_gpu_blocks_override, + max_num_batched_tokens=max_num_batched_tokens, + ) + args = asdict(engine_args) + cache_cfg = CacheConfig(args) + model_cfg = SimpleNamespace(enable_mm=enable_mm, max_model_len=8192) + speculative_cfg = SimpleNamespace(method=None) + model_cfg.print = print + cache_cfg.bytes_per_layer_per_block = 1 + parallel_cfg = ParallelConfig(args) + scheduler_cfg = SchedulerConfig(args) + graph_opt_cfg = engine_args.create_graph_optimization_config() + fd_config = FDConfig( + model_config=model_cfg, + cache_config=cache_cfg, + parallel_config=parallel_cfg, + graph_opt_config=graph_opt_cfg, + speculative_config=speculative_cfg, + scheduler_config=scheduler_cfg, + ) + return PrefixCacheManager(config=fd_config, tensor_parallel_size=8, splitwise_role="mixed") + + +class TestIsChunkedMMInput(unittest.TestCase): + def setUp(self): + self.cache_manager = make_prefix_cache_manager(max_num_seqs=3, enable_mm=True, num_gpu_blocks_override=100) + + def test_is_chunked_mm_input_none_input(self): + result, idx = self.cache_manager.is_chunked_mm_input(None, 10) + self.assertFalse(result) + self.assertEqual(idx, 0) + + def test_is_chunked_mm_input_no_mm_positions(self): + mm_inputs = {"other_field": "value"} + result, idx = self.cache_manager.is_chunked_mm_input(mm_inputs, 10) + self.assertFalse(result) + self.assertEqual(idx, 0) + + def test_is_chunked_mm_input_empty_positions(self): + mm_inputs = {"mm_positions": []} + result, idx = self.cache_manager.is_chunked_mm_input(mm_inputs, 10) + self.assertFalse(result) + self.assertEqual(idx, 0) + + def test_is_chunked_mm_input_matched_in_chunk(self): + mm_inputs = { + "mm_positions": [ + ImagePosition(offset=5, length=10), + ImagePosition(offset=20, length=10), + ] + } + result, idx = self.cache_manager.is_chunked_mm_input(mm_inputs, 8) + self.assertTrue(result) + self.assertEqual(idx, 0) + + def test_is_chunked_mm_input_matched_in_second_chunk(self): + mm_inputs = { + "mm_positions": [ + ImagePosition(offset=5, length=10), + ImagePosition(offset=20, length=10), + ] + } + result, idx = self.cache_manager.is_chunked_mm_input(mm_inputs, 25) + self.assertTrue(result) + self.assertEqual(idx, 1) + + def test_is_chunked_mm_input_before_first_chunk(self): + mm_inputs = { + "mm_positions": [ + ImagePosition(offset=5, length=10), + ImagePosition(offset=20, length=10), + ] + } + result, idx = self.cache_manager.is_chunked_mm_input(mm_inputs, 3) + self.assertFalse(result) + self.assertEqual(idx, 0) + + def test_is_chunked_mm_input_after_last_chunk(self): + mm_inputs = { + "mm_positions": [ + ImagePosition(offset=5, length=10), + ImagePosition(offset=20, length=10), + ] + } + result, idx = self.cache_manager.is_chunked_mm_input(mm_inputs, 35) + self.assertFalse(result) + self.assertEqual(idx, 0) + + +class TestRevertMatchBlocks(unittest.TestCase): + def setUp(self): + self.block_size = 64 + self.cache_manager = make_prefix_cache_manager(max_num_seqs=3, enable_mm=True, num_gpu_blocks_override=100) + + def make_match_blocks(self, gpu_block_num, cpu_block_num): + block_num = gpu_block_num + cpu_block_num + matched_token_num = block_num * self.block_size + match_node_ids = [] + matche_nodes = [] + match_gpu_block_ids = [] + match_cpu_block_ids = [] + for idx in range(block_num): + node_id = idx + 10 + block = BlockNode(node_id, [], 0, 0, idx, 0, None, None, None) + match_node_ids.append(node_id) + matche_nodes.append(block) + match_gpu_block_ids.append(idx) + + for _ in range(cpu_block_num): + match_cpu_block_ids.append(match_gpu_block_ids.pop()) + + gpu_match_token_num = len(match_gpu_block_ids) * self.block_size + cpu_match_token_num = len(match_cpu_block_ids) * self.block_size + return ( + matched_token_num, + match_node_ids, + matche_nodes, + match_gpu_block_ids, + match_cpu_block_ids, + gpu_match_token_num, + cpu_match_token_num, + ) + + def test_revert_full_blocks(self): + # Setup test data + multimodal_inputs = { + "mm_positions": [ImagePosition(offset=0, length=1200)], + "mm_hashes": ["image1"], + } + req_dict = { + "request_id": "req1", + "prompt_token_ids": [-1] * 1200 + [2] * 120, + "prompt_token_ids_len": 1320, + "multimodal_inputs": multimodal_inputs, + } + + ( + matched_token_num, + match_node_ids, + matche_nodes, + match_gpu_block_ids, + match_cpu_block_ids, + gpu_match_token_num, + cpu_match_token_num, + ) = self.make_match_blocks(gpu_block_num=2, cpu_block_num=0) + + # Call method + ( + gpu_match_token_num, + cpu_match_token_num, + current_match_node, + ) = self.cache_manager._revert_match_blocks( + request=Request.from_dict(req_dict), + matched_token_num=matched_token_num, + block_size=self.block_size, + chunk_idx=0, + match_node_ids=match_node_ids, + matche_nodes=matche_nodes, + match_gpu_block_ids=match_gpu_block_ids, + match_cpu_block_ids=match_cpu_block_ids, + gpu_match_token_num=gpu_match_token_num, + cpu_match_token_num=cpu_match_token_num, + swap_node_ids=[], + ) + + # Assertions + self.assertEqual(gpu_match_token_num, 0) + self.assertEqual(cpu_match_token_num, 0) + self.assertEqual(len(match_node_ids), 0) + self.assertEqual(len(match_gpu_block_ids), 0) + + def test_revert_partial_block(self): + # Setup test data + multimodal_inputs = { + "mm_positions": [ImagePosition(offset=120, length=1200)], + "mm_hashes": ["image1"], + } + req_dict = { + "request_id": "req1", + "prompt_token_ids": [1] * 120 + [-1] * 1200 + [2] * 120, + "prompt_token_ids_len": 1440, + "multimodal_inputs": multimodal_inputs, + } + + ( + matched_token_num, + match_node_ids, + matche_nodes, + match_gpu_block_ids, + match_cpu_block_ids, + gpu_match_token_num, + cpu_match_token_num, + ) = self.make_match_blocks(gpu_block_num=20, cpu_block_num=0) + + # Call method + ( + gpu_match_token_num, + cpu_match_token_num, + current_match_node, + ) = self.cache_manager._revert_match_blocks( + request=Request.from_dict(req_dict), + matched_token_num=matched_token_num, + block_size=self.block_size, + chunk_idx=0, + match_node_ids=match_node_ids, + matche_nodes=matche_nodes, + match_gpu_block_ids=match_gpu_block_ids, + match_cpu_block_ids=match_cpu_block_ids, + gpu_match_token_num=gpu_match_token_num, + cpu_match_token_num=cpu_match_token_num, + swap_node_ids=[], + ) + + # Assertions + self.assertEqual(gpu_match_token_num, 120) + self.assertEqual(cpu_match_token_num, 0) + self.assertEqual(len(match_node_ids), 2) + self.assertEqual(len(match_gpu_block_ids), 2) + + def test_revert_with_cpu_blocks(self): + # Setup test data + multimodal_inputs = { + "mm_positions": [ImagePosition(offset=120, length=1200), ImagePosition(offset=1440, length=420)], + "mm_hashes": ["image1", "image2"], + } + req_dict = { + "request_id": "req1", + "prompt_token_ids": [1] * 120 + [-1] * 1200 + [2] * 120 + [-1] * 420, + "prompt_token_ids_len": 1860, + "multimodal_inputs": multimodal_inputs, + } + + ( + matched_token_num, + match_node_ids, + matche_nodes, + match_gpu_block_ids, + match_cpu_block_ids, + gpu_match_token_num, + cpu_match_token_num, + ) = self.make_match_blocks(gpu_block_num=22, cpu_block_num=6) + + # Call method + ( + gpu_match_token_num, + cpu_match_token_num, + current_match_node, + ) = self.cache_manager._revert_match_blocks( + request=Request.from_dict(req_dict), + matched_token_num=matched_token_num, + block_size=self.block_size, + chunk_idx=1, + match_node_ids=match_node_ids, + matche_nodes=matche_nodes, + match_gpu_block_ids=match_gpu_block_ids, + match_cpu_block_ids=match_cpu_block_ids, + gpu_match_token_num=gpu_match_token_num, + cpu_match_token_num=cpu_match_token_num, + swap_node_ids=[], + ) + + # Assertions + self.assertEqual(gpu_match_token_num, 22 * self.block_size) + self.assertEqual(cpu_match_token_num, 32) + self.assertEqual(len(match_node_ids), 23) + self.assertEqual(len(match_gpu_block_ids), 22) + self.assertEqual(len(match_cpu_block_ids), 1) + + +if __name__ == "__main__": + unittest.main()