114 changes: 107 additions & 7 deletions fastdeploy/cache_manager/prefix_cache_manager.py
@@ -598,6 +598,21 @@ def update_cache_blocks(self, task, block_size, num_computed_tokens):
logger.error(f"update_cache_blocks, error: {type(e)} {e}, {str(traceback.format_exc())}")
raise e

def is_chunked_mm_input(self, mm_inputs, matched_token_num):
"""
Check whether matched_token_num falls strictly inside one of the spans in mm_inputs["mm_positions"], i.e. whether the prefix-cache match would split a multimodal item. Returns (is_chunked, chunk_idx).
"""
if mm_inputs is None or "mm_positions" not in mm_inputs or len(mm_inputs["mm_positions"]) == 0:
return False, 0

for idx in range(len(mm_inputs["mm_positions"])):
position = mm_inputs["mm_positions"][idx]
if position.offset < matched_token_num < position.offset + position.length:
return True, idx
elif matched_token_num < position.offset:
break
return False, 0
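For readers skimming the diff, a minimal standalone sketch of the boundary check above. The `Position` type and the free function are assumed stand-ins for illustration only; the real `mm_positions` entries just need `.offset` and `.length`:

```python
from collections import namedtuple

# Hypothetical stand-in for an mm_positions entry: one multimodal item
# occupying prompt tokens [offset, offset + length).
Position = namedtuple("Position", ["offset", "length"])

def is_chunked(mm_positions, matched_token_num):
    """Standalone mirror of is_chunked_mm_input, for illustration only."""
    for idx, pos in enumerate(mm_positions):
        if pos.offset < matched_token_num < pos.offset + pos.length:
            return True, idx      # the prefix match ends inside this item
        if matched_token_num < pos.offset:
            break                 # later items start even further to the right
    return False, 0

positions = [Position(offset=10, length=8), Position(offset=30, length=8)]
assert is_chunked(positions, 14) == (True, 0)   # cuts into the first span
assert is_chunked(positions, 10) == (False, 0)  # ends exactly at its start
assert is_chunked(positions, 20) == (False, 0)  # ends between the two spans
```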

def request_match_blocks(self, task, block_size, *args):
"""
Get matched-block info for a task.
@@ -617,9 +632,12 @@ def request_match_blocks(self, task, block_size, *args):
"""
with self.request_release_lock:
try:
hit_info = {}
hit_info["gpu_cache_blocks"] = 0
hit_info["cpu_cache_blocks"] = 0
hit_info = {
"gpu_cache_blocks": 0,
"cpu_cache_blocks": 0,
"gpu_match_token_num": 0,
"cpu_match_token_num": 0,
}
self.metrics.req_count += 1
if isinstance(task.prompt_token_ids, np.ndarray):
prompt_token_ids = task.prompt_token_ids.tolist()
@@ -673,8 +691,10 @@ def request_match_blocks(self, task, block_size, *args):
gpu_match_token_num,
input_token_num,
)
hit_info["gpu_cache_blocks"] = gpu_match_token_num // block_size
hit_info["cpu_cache_blocks"] = cpu_match_token_num // block_size
hit_info["gpu_cache_blocks"] = len(match_gpu_block_ids)
hit_info["cpu_cache_blocks"] = len(match_cpu_block_ids)
hit_info["gpu_match_token_num"] = gpu_match_token_num
hit_info["cpu_match_token_num"] = cpu_match_token_num
self.metrics._update_history_hit_metrics()
if self.metrics.req_count % 10000 == 0:
self.metrics.reset_metrics()
@@ -685,8 +705,8 @@ def request_match_blocks(self, task, block_size, *args):
self.req_leaf_map[req_id] = match_block_node
self.leaf_req_map[match_block_node].add(req_id)
# record request cache info
self.cache_info[req_id] = (match_block_node, matched_token_num)
task.cached_block_num = matched_token_num // block_size
self.cache_info[req_id] = (match_block_node, len(common_block_ids) * block_size)
task.cached_block_num = len(common_block_ids)
return common_block_ids, matched_token_num, hit_info
except Exception as e:
logger.error(f"request_match_blocks: request_block_ids: error: {type(e)} {e}")
@@ -1202,6 +1222,64 @@ def hash_block_features(self, input_ids, extra_keys: list = []):
"""
return hashlib.sha256(pickle.dumps((input_ids, extra_keys))).hexdigest()

def _revert_match_blocks(
self,
request,
matched_token_num: int,
block_size: int,
chunk_idx: int,
match_node_ids: list,
matche_nodes: list,
match_gpu_block_ids: list,
match_cpu_block_ids: list,
gpu_match_token_num: int,
cpu_match_token_num: int,
swap_node_ids: list,
):
position = request.multimodal_inputs["mm_positions"][chunk_idx]
revert_tokens = matched_token_num - position.offset
match_block_ids = [node.block_id for node in matche_nodes]
logger.warning(
f"match_block: req_id {request.request_id} revert tokens: {revert_tokens} from matched nodes: {match_block_ids}"
)
while revert_tokens >= block_size:
if len(matche_nodes) == 0:
logger.error(f"req_id {request.request_id} revert nodes error, tokens: {revert_tokens}")
break
revert_tokens -= block_size
revert_block = matche_nodes.pop()
revert_block_id = revert_block.block_id
if revert_block_id in match_gpu_block_ids:
match_gpu_block_ids.remove(revert_block_id)
match_node_ids.remove(revert_block.node_id)
gpu_match_token_num -= block_size
elif revert_block_id in match_cpu_block_ids:
match_cpu_block_ids.remove(revert_block_id)
match_node_ids.remove(revert_block.node_id)
cpu_match_token_num -= block_size
else:
logger.error(
f"req_id {request.request_id} revert nodes error, nodes: {revert_block_id}, "
f"match_gpu_block_ids: {match_gpu_block_ids}, match_cpu_block_ids: {match_cpu_block_ids}"
)
break
if revert_block_id in swap_node_ids:
swap_node_ids.remove(revert_block_id)

if revert_tokens > 0:
last_block_id = matche_nodes[-1].block_id
if last_block_id in match_gpu_block_ids:
gpu_match_token_num -= revert_tokens
elif last_block_id in match_cpu_block_ids:
cpu_match_token_num -= revert_tokens
else:
logger.error(
f"req_id {request.request_id} revert nodes error, revert_tokens: {revert_tokens}, nodes: {last_block_id}, "
f"match_gpu_block_ids: {match_gpu_block_ids}, match_cpu_block_ids: {match_cpu_block_ids}"
)
current_node = self.radix_tree_root if len(matche_nodes) == 0 else matche_nodes[-1]
return gpu_match_token_num, cpu_match_token_num, current_node
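A small worked example of the revert arithmetic above, with invented numbers (the block size, chunk offset, and 5-block match are assumptions, not values from the PR):

```python
# Invented numbers for illustration only.
block_size = 64
matched_token_num = 5 * block_size   # the prefix match covered 320 tokens
chunk_offset = 200                   # ...but an image span starts at token 200

revert_tokens = matched_token_num - chunk_offset             # 120 tokens to give back

# The while-loop above pops whole trailing blocks first:
full_blocks_popped = revert_tokens // block_size             # 1 block (64 tokens)
remainder = revert_tokens - full_blocks_popped * block_size  # 56 tokens left over

# The remainder is then trimmed off gpu/cpu_match_token_num without popping
# another node, so the usable match ends exactly at the chunk boundary:
usable_tokens = matched_token_num - full_blocks_popped * block_size - remainder
assert usable_tokens == chunk_offset == 200
```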

def mm_match_block(self, request, block_size):
"""
Match and retrieve cached blocks for multimodal requests using a radix tree structure.
@@ -1290,6 +1368,28 @@ def mm_match_block(self, request, block_size):
if has_modified_cpu_lru_leaf_heap:
heapq.heapify(self.cpu_lru_leaf_heap)

if self.cache_config.disable_chunked_mm_input:
matched_token_num = gpu_match_token_num + cpu_match_token_num
is_chunked, chunk_idx = self.is_chunked_mm_input(request.multimodal_inputs, matched_token_num)
if is_chunked:
(
gpu_match_token_num,
cpu_match_token_num,
current_match_node,
) = self._revert_match_blocks(
request=request,
matched_token_num=matched_token_num,
block_size=block_size,
chunk_idx=chunk_idx,
match_node_ids=match_node_ids,
matche_nodes=matche_nodes,
match_gpu_block_ids=match_gpu_block_ids,
match_cpu_block_ids=match_cpu_block_ids,
gpu_match_token_num=gpu_match_token_num,
cpu_match_token_num=cpu_match_token_num,
swap_node_ids=swap_node_ids,
)

logger.info(f"match_block: req_id {request.request_id} matched nodes: {match_node_ids}")
return (
match_gpu_block_ids,
1 change: 1 addition & 0 deletions fastdeploy/config.py
@@ -1202,6 +1202,7 @@ def __init__(self, args):
self.swap_space = None
self.max_encoder_cache = None
self.max_processor_cache = None
self.disable_chunked_mm_input = False
for key, value in args.items():
if hasattr(self, key):
setattr(self, key, value)
11 changes: 11 additions & 0 deletions fastdeploy/engine/args_utils.py
@@ -314,6 +314,10 @@ class EngineArgs:
"""
additional decode block num
"""
disable_chunked_mm_input: bool = False
"""
Disable chunked_mm_input for multimodal inference.
"""

scheduler_name: str = "local"
"""
@@ -936,6 +940,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
help="ports for rdma communication.",
)

perf_group.add_argument(
"--disable-chunked-mm-input",
action="store_true",
default=EngineArgs.disable_chunked_mm_input,
help="Disable chunked mm input.",
)
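To see how the flag travels from the CLI into the cache config, here is a hedged sketch that uses plain argparse in place of FastDeploy's FlexibleArgumentParser and a simplified stand-in class in place of CacheConfig; both substitutions are assumptions, and only the flag name and the hasattr/setattr pattern come from this diff:

```python
import argparse

# Plain argparse as a stand-in for FlexibleArgumentParser.
parser = argparse.ArgumentParser()
parser.add_argument("--disable-chunked-mm-input", action="store_true", default=False)

args = vars(parser.parse_args(["--disable-chunked-mm-input"]))
assert args["disable_chunked_mm_input"] is True

class CacheConfigSketch:
    """Simplified stand-in mirroring the config.py hunk above."""
    def __init__(self, args):
        # Declaring the attribute with a default is what lets the generic
        # hasattr/setattr loop pick the CLI value up at all.
        self.disable_chunked_mm_input = False
        for key, value in args.items():
            if hasattr(self, key):
                setattr(self, key, value)

assert CacheConfigSketch(args).disable_chunked_mm_input is True
```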

# Router parameters group
router_group = parser.add_argument_group("Router")
router_group.add_argument(
4 changes: 2 additions & 2 deletions fastdeploy/engine/sched/resource_manager_v1.py
@@ -771,8 +771,8 @@ def get_prefix_cached_blocks(self, request: Request):
)

request.num_cached_tokens = matched_token_num
request.gpu_cache_token_num = hit_info["gpu_cache_blocks"] * self.config.cache_config.block_size
request.cpu_cache_token_num = hit_info["cpu_cache_blocks"] * self.config.cache_config.block_size
request.gpu_cache_token_num = hit_info["gpu_match_token_num"]
request.cpu_cache_token_num = hit_info["cpu_match_token_num"]
request.cache_info = (matched_block_num, no_cache_block_num)
request.block_tables = common_block_ids
request.skip_allocate = False