From 8bf854e0316319e11f8900d7227cac8d01c40c60 Mon Sep 17 00:00:00 2001 From: ltd0924 Date: Sat, 8 Nov 2025 21:26:11 +0800 Subject: [PATCH 1/9] [Feature] support unified cache backend --- fastdeploy/cache_manager/cache_messager.py | 54 +++++---- .../cache_manager/cache_transfer_manager.py | 107 +++++++++++------- .../cache_manager/prefix_cache_manager.py | 84 +++++++++----- .../layers/attention/append_attn_backend.py | 13 +-- .../attention/block_multihead_attn_backend.py | 13 +-- .../layers/attention/flash_attn_backend.py | 13 +-- .../layers/attention/iluvatar_attn_backend.py | 8 +- .../layers/attention/mla_attention_backend.py | 9 +- .../attention/moba_attention_backend.py | 13 +-- .../layers/attention/xpu_attn_backend.py | 11 +- fastdeploy/worker/gcu_model_runner.py | 6 +- fastdeploy/worker/gpu_model_runner.py | 25 ++-- fastdeploy/worker/hpu_model_runner.py | 6 +- fastdeploy/worker/metax_model_runner.py | 16 +-- fastdeploy/worker/xpu_model_runner.py | 14 +-- 15 files changed, 212 insertions(+), 180 deletions(-) diff --git a/fastdeploy/cache_manager/cache_messager.py b/fastdeploy/cache_manager/cache_messager.py index e6e6aa15218..dedd6d30526 100644 --- a/fastdeploy/cache_manager/cache_messager.py +++ b/fastdeploy/cache_manager/cache_messager.py @@ -52,8 +52,8 @@ def parse_args(): parser.add_argument("--rank", type=int, default=0, help="current rank") parser.add_argument("--device_id", type=int, default=0, help="device id") parser.add_argument("--num_layers", type=int, default=1, help="model num layers") - parser.add_argument("--head_dim", type=int, default=1, help="model head dim") - parser.add_argument("--kv_num_head", type=int, default=1, help="model kv num head") + parser.add_argument("--key_cache_shape", type=list, default=[], help="key cache shape") + parser.add_argument("--value_cache_shape", type=list, default=[], help="value cache shape") parser.add_argument("--rdma_port", type=str, default="", help="rmda port") parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel") parser.add_argument("--engine_pid", type=str, default=None, help="engine pid") @@ -71,8 +71,6 @@ def parse_args(): default=9923, help="engine worker queue port", ) - parser.add_argument("--num_gpu_blocks", type=int, default=1, help="gpu cache block number") - parser.add_argument("--block_size", type=int, default=64, help="cache block size(tokens)") parser.add_argument( "--cache_dtype", type=str, @@ -755,38 +753,54 @@ def main(): cache_type = args.cache_dtype speculative_config = SpeculativeConfig(args.speculative_config) num_extra_layers = speculative_config.num_extra_cache_layer - num_extra_layer_gpu_blocks = int(args.num_gpu_blocks * speculative_config.num_gpu_block_expand_ratio) + total_gpu_blocks = args.key_cache_shape[0] + num_extra_layer_gpu_blocks = int(total_gpu_blocks * speculative_config.num_gpu_block_expand_ratio) gpu_cache_kvs = {} gpu_cache_k_tensors = [] gpu_cache_v_tensors = [] logger.info(f"[rank {rank}/{args.mp_num}] Initializing kv cache for all layers.") for i in range(args.num_layers + num_extra_layers): - num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else num_extra_layer_gpu_blocks - cache_shape = [num_gpu_blocks, args.kv_num_head, args.block_size, args.head_dim] - logger.info(f"[rank {rank}/{args.mp_num}] ..creating kv cache for layer {i}: {cache_shape}") + num_gpu_blocks = total_gpu_blocks if i < args.num_layers else num_extra_layer_gpu_blocks + key_cache_shape = [ + num_gpu_blocks, + args.key_cache_shape[1], + args.key_cache_shape[2], + 
args.key_cache_shape[3], + ] + if args.value_cache_shape: + value_cache_shape = [ + num_gpu_blocks, + args.key_cache_shape[1], + args.key_cache_shape[2], + args.key_cache_shape[3], + ] + logger.info( + f"[rank {rank}/{args.mp_num}] ..creating kv cache for layer {i}: {key_cache_shape} {value_cache_shape}" + ) gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"] = paddle.full( - shape=cache_shape, + shape=key_cache_shape, fill_value=0, dtype=cache_type, ) gpu_cache_k_tensors.append(gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"]) - gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full( - shape=cache_shape, - fill_value=0, - dtype=cache_type, - ) - gpu_cache_v_tensors.append(gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"]) - set_data_ipc( gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"], f"key_caches_{i}_rank{rank}.device{device}", ) - set_data_ipc( - gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"], - f"value_caches_{i}_rank{rank}.device{device}", - ) + if args.value_cache_shape: + gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full( + shape=value_cache_shape, + fill_value=0, + dtype=cache_type, + ) + gpu_cache_v_tensors.append(gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"]) + + set_data_ipc( + gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"], + f"value_caches_{i}_rank{rank}.device{device}", + ) cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in gpu_cache_kvs.items()]) logger.info(f"device :{device}") logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}") diff --git a/fastdeploy/cache_manager/cache_transfer_manager.py b/fastdeploy/cache_manager/cache_transfer_manager.py index 8694a3787a8..d5d0f96c939 100644 --- a/fastdeploy/cache_manager/cache_transfer_manager.py +++ b/fastdeploy/cache_manager/cache_transfer_manager.py @@ -58,18 +58,18 @@ def parse_args(): parser.add_argument("--rank", type=int, default=0, help="current rank") parser.add_argument("--device_id", type=int, default=0, help="device id") parser.add_argument("--num_layers", type=int, default=1, help="model num layers") - parser.add_argument("--head_dim", type=int, default=1, help="model head dim") - parser.add_argument("--kv_num_head", type=int, default=1, help="model kv num head") - parser.add_argument("--rdma_port", type=str, default="", help="rmda port") parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel") parser.add_argument( - "--protocol", + "--cache_dtype", type=str, - default="ipc", - help="cache transfer protocol, only support ipc now", + default="bfloat16", + choices=["uint8", "bfloat16"], + help="cache dtype", ) - parser.add_argument("--enable_splitwise", type=int, default=0, help="enable splitwise ") + parser.add_argument("--key_cache_shape", type=list, default=[], help="key cache shape") + parser.add_argument("--value_cache_shape", type=list, default=[], help="value cache shape") parser.add_argument("--cache_queue_port", type=int, default=9923, help="cache queue port") + parser.add_argument("--enable_splitwise", type=int, default=0, help="enable splitwise ") parser.add_argument("--pod_ip", type=str, default="0.0.0.0", help="pod ip") parser.add_argument( "--engine_worker_queue_port", @@ -77,31 +77,22 @@ def parse_args(): default=9923, help="engine worker queue port", ) - parser.add_argument("--engine_pid", type=str, default=None, help="engine pid") - - parser.add_argument("--num_gpu_blocks", type=int, default=1, help="gpu cache block number") 
parser.add_argument("--num_cpu_blocks", type=int, default=4, help="cpu cache block number") - parser.add_argument("--block_size", type=int, default=64, help="cache block size(tokens)") - parser.add_argument( - "--bytes_per_layer_per_block", - type=int, - default=1024, - help="per layer per block bytes", - ) + parser.add_argument("--engine_pid", type=str, default=None, help="engine pid") parser.add_argument( - "--cache_dtype", + "--protocol", type=str, - default="bfloat16", - choices=["uint8", "bfloat16"], - help="cache dtype", + default="ipc", + help="cache transfer protocol, only support ipc now", ) + parser.add_argument("--local_data_parallel_id", type=int, default=0) + parser.add_argument("--rdma_port", type=str, default="", help="rmda port") parser.add_argument( "--speculative_config", type=json.loads, default="{}", help="speculative config", ) - parser.add_argument("--local_data_parallel_id", type=int, default=0) parser.add_argument("--create_cache_tensor", action="store_true") args = parser.parse_args() @@ -124,8 +115,9 @@ def __init__(self, args): self.gpu_cache_k_tensors = [] self.gpu_cache_v_tensors = [] self.speculative_config = SpeculativeConfig(args.speculative_config) + self.num_gpu_blocks = args.key_cache_shape[0] self.num_extra_layers = self.speculative_config.num_extra_cache_layer - self.num_extra_layer_gpu_blocks = int(args.num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio) + self.num_extra_layer_gpu_blocks = int(self.num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio) self.swap_to_cpu_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) self.swap_to_gpu_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) @@ -164,8 +156,9 @@ def __init__(self, args): self.num_cpu_blocks = args.num_cpu_blocks - self._init_cpu_cache(args) self._init_gpu_cache(args) + if self.num_cpu_blocks > 0: + self._init_cpu_cache(args) cache_task_broadcast_data = np.zeros(shape=[1], dtype=np.int32) self.cache_task_broadcast_signal = IPCSignal( @@ -209,28 +202,46 @@ def _init_gpu_cache(self, args): logger.info(f"[rank {self.rank}/{self.n_ranks}] Initializing kv cache for all layers.") set_device(self.device) for i in range(args.num_layers + self.num_extra_layers): - num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks - cache_shape = [num_gpu_blocks, args.kv_num_head, args.block_size, args.head_dim] + num_gpu_blocks = self.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks key_name = f"key_caches_{i}_rank{self.rank}.device{self.device}" val_name = f"value_caches_{i}_rank{self.rank}.device{self.device}" - + key_cache_shape = [ + num_gpu_blocks, + args.key_cache_shape[1], + args.key_cache_shape[2], + args.key_cache_shape[3], + ] + if args.value_cache_shape: + value_cache_shape = [ + num_gpu_blocks, + args.key_cache_shape[1], + args.key_cache_shape[2], + args.key_cache_shape[3], + ] if args.create_cache_tensor: - logger.info(f"[rank {self.rank}/{self.n_ranks}] ..creating kv cache for layer {i}: {cache_shape}") - key_cache = paddle.full(shape=cache_shape, fill_value=0, dtype=args.cache_dtype) - val_cache = paddle.full(shape=cache_shape, fill_value=0, dtype=args.cache_dtype) + logger.info( + f"[rank {self.rank}/{self.n_ranks}] ..creating kv cache for layer {i}: {key_cache_shape} {value_cache_shape}" + ) + key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=args.cache_dtype) set_data_ipc(key_cache, key_name) - set_data_ipc(val_cache, val_name) + if 
args.value_cache_shape: + val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=args.cache_dtype) + set_data_ipc(val_cache, val_name) else: - logger.info(f"[rank {self.rank}/{self.n_ranks}] ..attaching kv cache for layer {i}: {cache_shape}") + logger.info( + f"[rank {self.rank}/{self.n_ranks}] ..attaching kv cache for layer {i}: {key_cache_shape} {value_cache_shape}" + ) key_cache = paddle.empty(shape=[], dtype=args.cache_dtype) val_cache = paddle.empty(shape=[], dtype=args.cache_dtype) - key_cache = share_external_data_(key_cache, key_name, cache_shape, True) - val_cache = share_external_data_(val_cache, val_name, cache_shape, True) + key_cache = share_external_data_(key_cache, key_name, key_cache_shape, True) + if args.value_cache_shape: + val_cache = share_external_data_(val_cache, val_name, value_cache_shape, True) self.gpu_cache_kvs[key_name] = key_cache - self.gpu_cache_kvs[val_name] = val_cache self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[key_name]) - self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[val_name]) + if args.value_cache_shape: + self.gpu_cache_kvs[val_name] = val_cache + self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[val_name]) if args.create_cache_tensor: logger.info(f"[rank {self.rank}/{self.n_ranks}] ✅ kv cache is ready!") @@ -242,6 +253,20 @@ def _init_gpu_cache(self, args): logger.info(f"[rank {self.rank}/{self.n_ranks}] done init cache (full) gmem alloc : {memory_allocated()}") def _init_cpu_cache(self, args): + key_cache_size = args.key_cache_shape[1] * args.key_cache_shape[2] * args.key_cache_shape[3] + if args.value_cache_shape: + value_cache_size = args.value_cache_shape[1] * args.value_cache_shape[2] * args.value_cache_shape[3] + else: + value_cache_size = 0 + if args.cache_dtype == "bfloat16": + cache_bytes = 2 + elif args.cache_dtype == "uint8": + cache_bytes = 1 + else: + raise ValueError(f"Unsupported cache dtype: {args.cache_dtype}") + key_need_to_allocate_bytes = args.num_cpu_blocks * cache_bytes * key_cache_size + value_need_to_allocate_bytes = args.num_cpu_blocks * cache_bytes * value_cache_size + # logger.info(f"[rank {self.rank}/{self.n_ranks}] ..swap space size : { / 1024 ** 3:.2f}GB") if args.num_cpu_blocks == 0: logger.info(f"[rank {self.rank}/{self.n_ranks}] 💡 no swap space (cpu cache) is specified.") self.swap_space_ready_signal.value[self.rank] = 1 @@ -253,14 +278,14 @@ def _init_cpu_cache(self, args): for i in range(args.num_layers + self.num_extra_layers): key_name = f"key_caches_{i}_rank{self.rank}" val_name = f"value_caches_{i}_rank{self.rank}" - need_to_allocate_bytes = args.num_cpu_blocks * args.bytes_per_layer_per_block logger.info( - f"[rank {self.rank}/{self.n_ranks}] ..creating cpu cache for layer {i}: {2 * need_to_allocate_bytes / 1024 ** 3:.2f}GB" + f"[rank {self.rank}/{self.n_ranks}] ..creating cpu cache for layer {i}: {(key_need_to_allocate_bytes + value_need_to_allocate_bytes) / 1024 ** 3:.2f}GB" ) - self.cpu_cache_kvs[key_name] = cuda_host_alloc(need_to_allocate_bytes) + self.cpu_cache_kvs[key_name] = cuda_host_alloc(key_need_to_allocate_bytes) self.k_dst_ptrs.append(self.cpu_cache_kvs[key_name]) - self.cpu_cache_kvs[val_name] = cuda_host_alloc(need_to_allocate_bytes) - self.v_dst_ptrs.append(self.cpu_cache_kvs[val_name]) + if value_need_to_allocate_bytes > 0: + self.cpu_cache_kvs[val_name] = cuda_host_alloc(value_need_to_allocate_bytes) + self.v_dst_ptrs.append(self.cpu_cache_kvs[val_name]) logger.info(f"[rank {self.rank}/{self.n_ranks}] ✅ swap space (cpu cache) is ready!") 
self.swap_space_ready_signal.value[self.rank] = 1 diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py index 1d5dc9c33f9..a0d48f30ca1 100644 --- a/fastdeploy/cache_manager/prefix_cache_manager.py +++ b/fastdeploy/cache_manager/prefix_cache_manager.py @@ -63,7 +63,7 @@ def __init__( else: self.enable_splitwise = 0 self.splitwise_role = splitwise_role - + self.config = config self.cache_config = config.cache_config self.speculative_config = config.speculative_config self.local_data_parallel_id = local_data_parallel_id @@ -82,6 +82,8 @@ def __init__( heapq.heapify(self.gpu_free_block_list) heapq.heapify(self.cpu_free_block_list) + self.key_cache_shape = [] + self.val_cache_shape = [] self.node_id_pool = list(range(self.num_gpu_blocks + self.num_cpu_blocks)) self.radix_tree_root = BlockNode(-1, [], 0, 0, -1, 0, None, None, None) @@ -120,6 +122,39 @@ def __init__( main_process_metrics.free_gpu_block_num.set(self.num_gpu_blocks) main_process_metrics.available_gpu_resource.set(1.0) + def _get_kv_cache_shape(self, max_block_num): + from fastdeploy.model_executor.layers.attention import get_attention_backend + + attn_cls = get_attention_backend() + num_heads = self.config.model_config.num_attention_heads // self.config.parallel_config.tensor_parallel_size + kv_num_heads = max( + 1, + int(self.config.model_config.num_key_value_heads) // self.config.parallel_config.tensor_parallel_size, + ) + head_dim = self.config.model_config.head_dim + + kv_cache_quant_type = None + if ( + self.config.quant_config + and hasattr(self.config.quant_config, "kv_cache_quant_type") + and self.config.quant_config.kv_cache_quant_type is not None + ): + kv_cache_quant_type = self.config.quant_config.kv_cache_quant_type + + # Initialize AttentionBackend buffers + encoder_block_shape_q = 64 + decoder_block_shape_q = 16 + key_cache_shape, value_cache_shape = attn_cls( + self.config, + kv_num_heads=kv_num_heads, + num_heads=num_heads, + head_dim=head_dim, + encoder_block_shape_q=encoder_block_shape_q, + decoder_block_shape_q=decoder_block_shape_q, + ).get_kv_cache_shape(max_num_blocks=max_block_num, kv_cache_quant_type=kv_cache_quant_type) + logger.info(f"key_cache_shape {key_cache_shape} value_cache_shape {value_cache_shape}") + return key_cache_shape, value_cache_shape + @property def available_gpu_resource(self): return len(self.gpu_free_block_list) / self.num_gpu_blocks if self.num_gpu_blocks > 0 else 0.0 @@ -161,11 +196,15 @@ def launch_cache_manager( py_path = os.path.join(current_dir_path, filename) cache_messager_processes = [] + key_cache_shape, val_cache_shape = self._get_kv_cache_shape(cache_config.total_block_num) + if self.enable_splitwise: cache_messager_processes = self.launch_cache_messager( cache_config, tensor_parallel_size, device_ids, + key_cache_shape, + val_cache_shape, pod_ip, engine_worker_queue_port, pid_suffix, @@ -174,17 +213,6 @@ def launch_cache_manager( raise RuntimeError("Launch cache messager failed") return [] - if ( - hasattr(cache_config.model_cfg, "num_key_value_heads") - and hasattr(cache_config.model_cfg, "num_key_value_heads") - and cache_config.model_cfg.num_key_value_heads is not None - and int(cache_config.model_cfg.num_key_value_heads) > 0 - ): - kv_num_head = int(cache_config.model_cfg.num_key_value_heads) // tensor_parallel_size - else: - kv_num_head = cache_config.model_cfg.num_attention_heads // tensor_parallel_size - kv_num_head = max(1, kv_num_head) - cache_ready_signal_data = np.zeros(shape=[tensor_parallel_size], 
dtype=np.int32) self.cache_ready_signal = IPCSignal( name="cache_ready_signal", @@ -223,18 +251,15 @@ def launch_cache_manager( + f" --rank {i}" + f" --splitwise_role {self.splitwise_role}" + f" --num_layers {cache_config.model_cfg.num_hidden_layers}" - + f" --head_dim {cache_config.model_cfg.head_dim}" - + f" --kv_num_head {kv_num_head}" + f" --mp_num {tensor_parallel_size}" + f" --cache_dtype {cache_config.cache_dtype}" + + f" --key_cache_shape '{key_cache_shape}'" + + f" --value_cache_shape '{val_cache_shape}'" + f" --cache_queue_port {cache_config.cache_queue_port}" + f" --enable_splitwise {int(self.enable_splitwise)}" + f" --pod_ip {pod_ip}" + f" --engine_worker_queue_port {engine_worker_queue_port}" - + f" --num_gpu_blocks {cache_config.total_block_num}" + f" --num_cpu_blocks {cache_config.num_cpu_blocks}" - + f" --bytes_per_layer_per_block {cache_config.bytes_per_layer_per_block}" - + f" --block_size {cache_config.block_size}" + f" --engine_pid {pid_suffix}" + f" --protocol {cache_config.cache_transfer_protocol}" + f" --local_data_parallel_id {self.local_data_parallel_id}" @@ -273,22 +298,21 @@ def launch_cache_manager( return all_cache_processes def launch_cache_messager( - self, cache_config, tensor_parallel_size, device_ids, pod_ip, engine_worker_queue_port, pid_suffix + self, + cache_config, + tensor_parallel_size, + device_ids, + key_cache_shape, + value_cache_shape, + pod_ip, + engine_worker_queue_port, + pid_suffix, ): """ launch_cache_messager function used to initialize the cache messager. """ current_dir_path = os.path.split(os.path.abspath(__file__))[0] filename = "cache_messager.py" - if ( - hasattr(cache_config.model_cfg, "num_key_value_heads") - and hasattr(cache_config.model_cfg, "num_key_value_heads") - and cache_config.model_cfg.num_key_value_heads is not None - and int(cache_config.model_cfg.num_key_value_heads) > 0 - ): - kv_num_head = int(cache_config.model_cfg.num_key_value_heads) // tensor_parallel_size - else: - kv_num_head = cache_config.model_cfg.num_attention_heads // tensor_parallel_size cache_ready_signal_data = np.zeros(shape=[tensor_parallel_size], dtype=np.int32) self.cache_ready_signal = IPCSignal( @@ -311,15 +335,13 @@ def launch_cache_messager( + f" --rank {i}" + f" --splitwise_role {self.splitwise_role}" + f" --num_layers {cache_config.model_cfg.num_hidden_layers}" - + f" --head_dim {cache_config.model_cfg.head_dim}" - + f" --kv_num_head {kv_num_head}" + f" --mp_num {tensor_parallel_size}" + f" --cache_dtype {cache_config.cache_dtype}" + + f" --key_cache_shape '{key_cache_shape}'" + + f" --value_cache_shape '{value_cache_shape}'" + f" --pod_ip {pod_ip}" + f" --cache_queue_port {cache_config.cache_queue_port}" + f" --engine_worker_queue_port {engine_worker_queue_port}" - + f" --num_gpu_blocks {cache_config.total_block_num}" - + f" --block_size {cache_config.block_size}" + f" --protocol {cache_config.cache_transfer_protocol}" + f" --local_data_parallel_id {self.local_data_parallel_id}" + f" --engine_pid {pid_suffix}" diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py index 23a05590a6d..99976333792 100644 --- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py @@ -167,20 +167,15 @@ def get_kv_cache_shape( """ Calculate kv cache shape """ + key_cache_shape = value_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim] if kv_cache_quant_type is 
not None and kv_cache_quant_type == "int4_zp": - return ( + key_cache_shape = value_cache_shape = [ max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim // 2, - ) - else: - return ( - max_num_blocks, - self.kv_num_heads, - self.block_size, - self.head_dim, - ) + ] + return key_cache_shape, value_cache_shape def forward_mixed( self, diff --git a/fastdeploy/model_executor/layers/attention/block_multihead_attn_backend.py b/fastdeploy/model_executor/layers/attention/block_multihead_attn_backend.py index b16a6681753..f40c060b18f 100644 --- a/fastdeploy/model_executor/layers/attention/block_multihead_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/block_multihead_attn_backend.py @@ -116,20 +116,15 @@ def get_kv_cache_shape( """ Calculate kv cache shape """ + key_cache_shape = value_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim] if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp": - return ( + key_cache_shape = value_cache_shape = [ max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim // 2, - ) - else: - return ( - max_num_blocks, - self.kv_num_heads, - self.block_size, - self.head_dim, - ) + ] + return key_cache_shape, value_cache_shape def forward_mixed( self, diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py index ccde8a502b9..71b11407c08 100644 --- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py @@ -171,20 +171,15 @@ def get_kv_cache_shape( """ Calculate kv cache shape """ + key_cache_shape = value_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim] if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp": - return ( + key_cache_shape = value_cache_shape = [ max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim // 2, - ) - else: - return ( - max_num_blocks, - self.kv_num_heads, - self.block_size, - self.head_dim, - ) + ] + return key_cache_shape, value_cache_shape def init_attention_metadata(self, forward_meta: ForwardMeta): metadata = FlashAttentionMetadata() diff --git a/fastdeploy/model_executor/layers/attention/iluvatar_attn_backend.py b/fastdeploy/model_executor/layers/attention/iluvatar_attn_backend.py index a913700af7d..8dbe06b76b5 100644 --- a/fastdeploy/model_executor/layers/attention/iluvatar_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/iluvatar_attn_backend.py @@ -186,12 +186,8 @@ def get_kv_cache_shape( """ Calculate kv cache shape """ - return ( - max_num_blocks, - self.num_kv_heads, - self.block_size, - self.head_dim, - ) + key_cache_shape = value_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim] + return key_cache_shape, value_cache_shape def transpose(self, hidden_states): for ids, reverse_ids in zip(self.id_group, self.reverse_id_group): diff --git a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py index d7d18526f93..54e72379eab 100644 --- a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py +++ b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py @@ -242,12 +242,9 @@ def get_kv_cache_shape( """ Calculate kv cache shape for MLA """ - return ( - max_num_blocks, - 1, - self.block_size, - self.kv_lora_rank + self.qk_rope_head_dim, - ) + key_cache_shape = [max_num_blocks, 1, 
self.block_size, self.kv_lora_rank + self.qk_rope_head_dim] + value_cache_shape = [] + return key_cache_shape, value_cache_shape def forward_extend( self, diff --git a/fastdeploy/model_executor/layers/attention/moba_attention_backend.py b/fastdeploy/model_executor/layers/attention/moba_attention_backend.py index f292ed65518..b89abb357af 100644 --- a/fastdeploy/model_executor/layers/attention/moba_attention_backend.py +++ b/fastdeploy/model_executor/layers/attention/moba_attention_backend.py @@ -126,20 +126,15 @@ def get_kv_cache_shape( """ Calculate kv cache shape """ + key_cache_shape = value_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim] if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp": - return ( + key_cache_shape = value_cache_shape = [ max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim // 2, - ) - else: - return ( - max_num_blocks, - self.kv_num_heads, - self.block_size, - self.head_dim, - ) + ] + return key_cache_shape, value_cache_shape def forward_mixed( self, diff --git a/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py b/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py index 0073ea3b89d..9b0a01e099e 100644 --- a/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py @@ -129,16 +129,13 @@ def get_attntion_meta(self) -> AttentionMetadata: def get_kv_cache_shape( self, max_num_blocks: int, - ) -> Tuple[int, int, int, int]: + kv_cache_quant_type: str = None, + ) -> Tuple[list, list]: """ Calculate kv cache shape """ - return ( - max_num_blocks, - self.kv_num_heads, - self.block_size, - self.head_dim, - ) + key_cache_shape = value_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim] + return key_cache_shape, value_cache_shape def forward_mixed( self, diff --git a/fastdeploy/worker/gcu_model_runner.py b/fastdeploy/worker/gcu_model_runner.py index cd79e677b10..d9401a51aae 100644 --- a/fastdeploy/worker/gcu_model_runner.py +++ b/fastdeploy/worker/gcu_model_runner.py @@ -639,7 +639,7 @@ def initialize_kv_cache(self, profile: bool = False) -> None: kv_cache_quant_type = self.quant_config.kv_cache_quant_type # Get kv cache shape - kv_cache_shape = self.attn_backends[0].get_kv_cache_shape( + key_cache_shape, value_cache_shape = self.attn_backends[0].get_kv_cache_shape( max_num_blocks=max_block_num, kv_cache_quant_type=kv_cache_quant_type ) # local_rank = self.local_rank % self.parallel_config.tensor_parallel_size @@ -652,12 +652,12 @@ def initialize_kv_cache(self, profile: bool = False) -> None: for i in range(self.model_config.num_hidden_layers): cache_kvs[f"key_caches_{i}"] = paddle.full( - shape=kv_cache_shape, + shape=key_cache_shape, fill_value=0, dtype=cache_type, ) cache_kvs[f"value_caches_{i}"] = paddle.full( - shape=kv_cache_shape, + shape=value_cache_shape, fill_value=0, dtype=cache_type, ) diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 67a09f5c5d4..59154c4584e 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -1399,11 +1399,11 @@ def initialize_kv_cache(self, profile: bool = False) -> None: kv_cache_quant_type = self.quant_config.kv_cache_quant_type # Get kv cache shape - kv_cache_shape = self.attn_backends[0].get_kv_cache_shape( + key_cache_shape, value_cache_shape = self.attn_backends[0].get_kv_cache_shape( max_num_blocks=max_block_num, kv_cache_quant_type=kv_cache_quant_type ) 
if kv_cache_quant_type == "block_wise_fp8": - kv_cache_scale_shape = [kv_cache_shape[0], kv_cache_shape[1], kv_cache_shape[2]] + kv_cache_scale_shape = [key_cache_shape[0], key_cache_shape[1], key_cache_shape[2]] local_rank = self.local_rank % self.parallel_config.tensor_parallel_size cache_ready_signal_data = np.zeros(shape=[self.parallel_config.tensor_parallel_size], dtype=np.int32) @@ -1435,15 +1435,16 @@ def initialize_kv_cache(self, profile: bool = False) -> None: self.mla_cache = envs.FD_ATTENTION_BACKEND == "MLA_ATTN" for i in range(self.model_config.num_hidden_layers): + # init key cache key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}" - if not self.mla_cache: + if value_cache_shape: val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}" if create_cache_tensor: - logger.info(f"..creating kv cache for layer {i}: {kv_cache_shape}") - key_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype=cache_type) + logger.info(f"..creating kv cache for layer {i}: key:{key_cache_shape}, value:{value_cache_shape}") + key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=cache_type) set_data_ipc(key_cache, key_cache_name) - if not self.mla_cache: - val_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype=cache_type) + if value_cache_shape: + val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=cache_type) set_data_ipc(val_cache, val_cache_name) cache_kvs_list.extend([key_cache, val_cache]) else: @@ -1452,7 +1453,7 @@ def initialize_kv_cache(self, profile: bool = False) -> None: key_cache_scales = paddle.full( shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype() ) - if not self.mla_cache: + if value_cache_shape: val_cache_scales = paddle.full( shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype() ) @@ -1460,12 +1461,12 @@ def initialize_kv_cache(self, profile: bool = False) -> None: else: cache_kvs_list.extend([key_cache_scales]) else: - logger.info(f"..attaching kv cache for layer {i}: {kv_cache_shape}") + logger.info(f"..attaching kv cache for layer {i}: key:{key_cache_shape}, value:{value_cache_shape}") key_cache = paddle.empty(shape=[], dtype=cache_type) - key_cache = share_external_data(key_cache, key_cache_name, kv_cache_shape) - if not self.mla_cache: + key_cache = share_external_data(key_cache, key_cache_name, key_cache_shape) + if value_cache_shape: val_cache = paddle.empty(shape=[], dtype=cache_type) - val_cache = share_external_data(val_cache, val_cache_name, kv_cache_shape) + val_cache = share_external_data(val_cache, val_cache_name, value_cache_shape) cache_kvs_list.extend([key_cache, val_cache]) else: cache_kvs_list.extend([key_cache]) diff --git a/fastdeploy/worker/hpu_model_runner.py b/fastdeploy/worker/hpu_model_runner.py index 14556147b2c..e1cc1e3e705 100644 --- a/fastdeploy/worker/hpu_model_runner.py +++ b/fastdeploy/worker/hpu_model_runner.py @@ -851,17 +851,17 @@ def initialize_kv_cache(self) -> None: cache_kvs = {} max_block_num = self.num_gpu_blocks - kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(max_num_blocks=max_block_num) + key_cache_shape, value_cache_shape = self.attn_backends[0].get_kv_cache_shape(max_num_blocks=max_block_num) for i in range(self.model_config.num_hidden_layers): cache_type = self.model_config.dtype cache_kvs["key_caches_{}".format(i)] = paddle.full( - shape=kv_cache_shape, + shape=key_cache_shape, fill_value=0, dtype=cache_type, ) cache_kvs["value_caches_{}".format(i)] = paddle.full( - 
shape=kv_cache_shape, + shape=value_cache_shape, fill_value=0, dtype=cache_type, ) diff --git a/fastdeploy/worker/metax_model_runner.py b/fastdeploy/worker/metax_model_runner.py index 7361e717f12..1d2ee6d81f1 100644 --- a/fastdeploy/worker/metax_model_runner.py +++ b/fastdeploy/worker/metax_model_runner.py @@ -1196,11 +1196,11 @@ def initialize_kv_cache(self, profile: bool = False) -> None: kv_cache_quant_type = self.quant_config.kv_cache_quant_type # Get kv cache shape - kv_cache_shape = self.attn_backends[0].get_kv_cache_shape( + key_cache_shape, value_cache_shape = self.attn_backends[0].get_kv_cache_shape( max_num_blocks=max_block_num, kv_cache_quant_type=kv_cache_quant_type ) if kv_cache_quant_type == "block_wise_fp8": - kv_cache_scale_shape = [kv_cache_shape[0], kv_cache_shape[1], kv_cache_shape[2]] + kv_cache_scale_shape = [key_cache_shape[0], key_cache_shape[1], key_cache_shape[2]] local_rank = self.local_rank % self.parallel_config.tensor_parallel_size cache_ready_signal_data = np.zeros(shape=[self.parallel_config.tensor_parallel_size], dtype=np.int32) @@ -1236,11 +1236,11 @@ def initialize_kv_cache(self, profile: bool = False) -> None: if not self.mla_cache: val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}" if create_cache_tensor: - logger.info(f"..creating kv cache for layer {i}: {kv_cache_shape}") - key_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype=cache_type) + logger.info(f"..creating kv cache for layer {i}: {key_cache_shape} {value_cache_shape}") + key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=cache_type) set_data_ipc(key_cache, key_cache_name) if not self.mla_cache: - val_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype=cache_type) + val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=cache_type) set_data_ipc(val_cache, val_cache_name) cache_kvs_list.extend([key_cache, val_cache]) else: @@ -1257,12 +1257,12 @@ def initialize_kv_cache(self, profile: bool = False) -> None: else: cache_kvs_list.extend([key_cache_scales]) else: - logger.info(f"..attaching kv cache for layer {i}: {kv_cache_shape}") + logger.info(f"..attaching kv cache for layer {i}: {key_cache_shape} {value_cache_shape}") key_cache = paddle.empty(shape=[], dtype=cache_type) - key_cache = share_external_data(key_cache, key_cache_name, kv_cache_shape) + key_cache = share_external_data(key_cache, key_cache_name, key_cache_shape) if not self.mla_cache: val_cache = paddle.empty(shape=[], dtype=cache_type) - val_cache = share_external_data(val_cache, val_cache_name, kv_cache_shape) + val_cache = share_external_data(val_cache, val_cache_name, value_cache_shape) cache_kvs_list.extend([key_cache, val_cache]) else: cache_kvs_list.extend([key_cache]) diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py index e7374c005ef..4ab4ee2ff3c 100644 --- a/fastdeploy/worker/xpu_model_runner.py +++ b/fastdeploy/worker/xpu_model_runner.py @@ -966,7 +966,7 @@ def initialize_kv_cache(self, profile: bool = False) -> None: cache_type = "int8" # Get kv cache shape - kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(max_num_blocks=max_block_num) + key_cache_shape, value_cache_shape = self.attn_backends[0].get_kv_cache_shape(max_num_blocks=max_block_num) local_rank = self.local_rank % self.parallel_config.tensor_parallel_size cache_ready_signal_data = np.zeros(shape=[self.parallel_config.tensor_parallel_size], dtype=np.int32) @@ -996,19 +996,19 @@ def initialize_kv_cache(self, profile: bool = False) -> 
None: val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}" if create_cache_tensor: - logger.info(f"..creating kv cache for layer {i}: {kv_cache_shape}") - key_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype=cache_type) + logger.info(f"..creating kv cache for layer {i}: {key_cache_shape} {value_cache_shape}") + key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=cache_type) set_data_ipc(key_cache, key_cache_name) - val_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype=cache_type) + val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=cache_type) set_data_ipc(val_cache, val_cache_name) cache_kvs_list.extend([key_cache, val_cache]) else: - logger.info(f"..attaching kv cache for layer {i}: {kv_cache_shape}") + logger.info(f"..attaching kv cache for layer {i}: {key_cache_shape} {value_cache_shape}") key_cache = paddle.empty(shape=[], dtype=cache_type) - key_cache = share_external_data(key_cache, key_cache_name, kv_cache_shape, False) + key_cache = share_external_data(key_cache, key_cache_name, key_cache_shape, False) val_cache = paddle.empty(shape=[], dtype=cache_type) - val_cache = share_external_data(val_cache, val_cache_name, kv_cache_shape, False) + val_cache = share_external_data(val_cache, val_cache_name, value_cache_shape, False) cache_kvs_list.extend([key_cache, val_cache]) self.share_inputs["caches"] = cache_kvs_list From c2553c05d883a31381a97312920f5b26cf01c4c4 Mon Sep 17 00:00:00 2001 From: ltd0924 Date: Sat, 8 Nov 2025 22:22:37 +0800 Subject: [PATCH 2/9] fix --- .../layers/attention/iluvatar_attn_backend.py | 10 +++++++++- .../layers/attention/xpu_attn_backend.py | 2 ++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/fastdeploy/model_executor/layers/attention/iluvatar_attn_backend.py b/fastdeploy/model_executor/layers/attention/iluvatar_attn_backend.py index 8dbe06b76b5..191a779bf56 100644 --- a/fastdeploy/model_executor/layers/attention/iluvatar_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/iluvatar_attn_backend.py @@ -68,7 +68,15 @@ class IluvatarAttnBackend(AttentionBackend): Which is used only for testing purpose. 
""" - def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int, head_dim: int): + def __init__( + self, + fd_config: FDConfig, + kv_num_heads: int, + num_heads: int, + head_dim: int, + encoder_block_shape_q: int = -1, + decoder_block_shape_q: int = -1, + ): super().__init__() self.attention_metadata = IluvatarAttentionMetadata() self.block_size = fd_config.cache_config.block_size diff --git a/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py b/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py index 9b0a01e099e..0c2c02a1e42 100644 --- a/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py @@ -72,6 +72,8 @@ def __init__( kv_num_heads: int, num_heads: int, head_dim: int, + encoder_block_shape_q: int = -1, + decoder_block_shape_q: int = -1, ): """ XPUAttentionBackend __init__ From e98bee6c3d7dea716291c99df85a877c0fa162dd Mon Sep 17 00:00:00 2001 From: ltd0924 Date: Sat, 8 Nov 2025 23:08:41 +0800 Subject: [PATCH 3/9] fix --- fastdeploy/cache_manager/cache_messager.py | 24 +++++++++-------- .../cache_manager/cache_transfer_manager.py | 26 ++++++++++--------- .../cache_manager/prefix_cache_manager.py | 12 +++++---- 3 files changed, 34 insertions(+), 28 deletions(-) diff --git a/fastdeploy/cache_manager/cache_messager.py b/fastdeploy/cache_manager/cache_messager.py index 90397a4d244..9b4d2d0514a 100644 --- a/fastdeploy/cache_manager/cache_messager.py +++ b/fastdeploy/cache_manager/cache_messager.py @@ -52,8 +52,8 @@ def parse_args(): parser.add_argument("--rank", type=int, default=0, help="current rank") parser.add_argument("--device_id", type=int, default=0, help="device id") parser.add_argument("--num_layers", type=int, default=1, help="model num layers") - parser.add_argument("--key_cache_shape", type=list, default=[], help="key cache shape") - parser.add_argument("--value_cache_shape", type=list, default=[], help="value cache shape") + parser.add_argument("--key_cache_shape", type=str, default="", help="key cache shape") + parser.add_argument("--value_cache_shape", type=str, default="", help="value cache shape") parser.add_argument("--rdma_port", type=str, default="", help="rmda port") parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel") parser.add_argument("--engine_pid", type=str, default=None, help="engine pid") @@ -756,7 +756,9 @@ def main(): cache_type = args.cache_dtype speculative_config = SpeculativeConfig(args.speculative_config) num_extra_layers = speculative_config.num_extra_cache_layer - total_gpu_blocks = args.key_cache_shape[0] + key_cache_shape_list = [int(i) for i in args.key_cache_shape.split(",")] + value_cache_shape_list = [int(i) for i in args.value_cache_shape.split(",")] + total_gpu_blocks = key_cache_shape_list[0] num_extra_layer_gpu_blocks = int(total_gpu_blocks * speculative_config.num_gpu_block_expand_ratio) gpu_cache_kvs = {} gpu_cache_k_tensors = [] @@ -767,16 +769,16 @@ def main(): num_gpu_blocks = total_gpu_blocks if i < args.num_layers else num_extra_layer_gpu_blocks key_cache_shape = [ num_gpu_blocks, - args.key_cache_shape[1], - args.key_cache_shape[2], - args.key_cache_shape[3], + key_cache_shape_list[1], + key_cache_shape_list[2], + key_cache_shape_list[3], ] - if args.value_cache_shape: + if value_cache_shape_list: value_cache_shape = [ num_gpu_blocks, - args.key_cache_shape[1], - args.key_cache_shape[2], - args.key_cache_shape[3], + value_cache_shape_list[1], + value_cache_shape_list[2], + 
value_cache_shape_list[3], ] logger.info( f"[rank {rank}/{args.mp_num}] ..creating kv cache for layer {i}: {key_cache_shape} {value_cache_shape}" @@ -792,7 +794,7 @@ def main(): gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"], f"key_caches_{i}_rank{rank}.device{device}", ) - if args.value_cache_shape: + if value_cache_shape_list: gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full( shape=value_cache_shape, fill_value=0, diff --git a/fastdeploy/cache_manager/cache_transfer_manager.py b/fastdeploy/cache_manager/cache_transfer_manager.py index d5d0f96c939..02f3e7fff2f 100644 --- a/fastdeploy/cache_manager/cache_transfer_manager.py +++ b/fastdeploy/cache_manager/cache_transfer_manager.py @@ -66,8 +66,8 @@ def parse_args(): choices=["uint8", "bfloat16"], help="cache dtype", ) - parser.add_argument("--key_cache_shape", type=list, default=[], help="key cache shape") - parser.add_argument("--value_cache_shape", type=list, default=[], help="value cache shape") + parser.add_argument("--key_cache_shape", type=str, default="", help="key cache shape") + parser.add_argument("--value_cache_shape", type=str, default="", help="value cache shape") parser.add_argument("--cache_queue_port", type=int, default=9923, help="cache queue port") parser.add_argument("--enable_splitwise", type=int, default=0, help="enable splitwise ") parser.add_argument("--pod_ip", type=str, default="0.0.0.0", help="pod ip") @@ -115,7 +115,9 @@ def __init__(self, args): self.gpu_cache_k_tensors = [] self.gpu_cache_v_tensors = [] self.speculative_config = SpeculativeConfig(args.speculative_config) - self.num_gpu_blocks = args.key_cache_shape[0] + self.key_cache_shape = [int(i) for i in args.key_cache_shape.split(",")] + self.value_cache_shape = [int(i) for i in args.value_cache_shape.split(",")] + self.num_gpu_blocks = self.key_cache_shape[0] self.num_extra_layers = self.speculative_config.num_extra_cache_layer self.num_extra_layer_gpu_blocks = int(self.num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio) @@ -207,16 +209,16 @@ def _init_gpu_cache(self, args): val_name = f"value_caches_{i}_rank{self.rank}.device{self.device}" key_cache_shape = [ num_gpu_blocks, - args.key_cache_shape[1], - args.key_cache_shape[2], - args.key_cache_shape[3], + self.key_cache_shape[1], + self.key_cache_shape[2], + self.key_cache_shape[3], ] - if args.value_cache_shape: + if self.value_cache_shape: value_cache_shape = [ num_gpu_blocks, - args.key_cache_shape[1], - args.key_cache_shape[2], - args.key_cache_shape[3], + self.value_cache_shape[1], + self.value_cache_shape[2], + self.value_cache_shape[3], ] if args.create_cache_tensor: logger.info( @@ -224,7 +226,7 @@ def _init_gpu_cache(self, args): ) key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=args.cache_dtype) set_data_ipc(key_cache, key_name) - if args.value_cache_shape: + if self.value_cache_shape: val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=args.cache_dtype) set_data_ipc(val_cache, val_name) else: @@ -234,7 +236,7 @@ def _init_gpu_cache(self, args): key_cache = paddle.empty(shape=[], dtype=args.cache_dtype) val_cache = paddle.empty(shape=[], dtype=args.cache_dtype) key_cache = share_external_data_(key_cache, key_name, key_cache_shape, True) - if args.value_cache_shape: + if self.value_cache_shape: val_cache = share_external_data_(val_cache, val_name, value_cache_shape, True) self.gpu_cache_kvs[key_name] = key_cache diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py 
b/fastdeploy/cache_manager/prefix_cache_manager.py index 0384703c790..0b66bf24b26 100644 --- a/fastdeploy/cache_manager/prefix_cache_manager.py +++ b/fastdeploy/cache_manager/prefix_cache_manager.py @@ -197,7 +197,9 @@ def launch_cache_manager( cache_messager_processes = [] key_cache_shape, val_cache_shape = self._get_kv_cache_shape(cache_config.total_block_num) - + key_cache_shape = ",".join([str(i) for i in key_cache_shape]) + val_cache_shape = ",".join([str(i) for i in val_cache_shape]) + logger.info(f"key_cache_shape {key_cache_shape} value_cache_shape {val_cache_shape}") if self.enable_splitwise: cache_messager_processes = self.launch_cache_messager( cache_config, @@ -253,8 +255,8 @@ def launch_cache_manager( + f" --num_layers {cache_config.model_cfg.num_hidden_layers}" + f" --mp_num {tensor_parallel_size}" + f" --cache_dtype {cache_config.cache_dtype}" - + f" --key_cache_shape '{key_cache_shape}'" - + f" --value_cache_shape '{val_cache_shape}'" + + f" --key_cache_shape {key_cache_shape}" + + f" --value_cache_shape {val_cache_shape}" + f" --cache_queue_port {cache_config.cache_queue_port}" + f" --enable_splitwise {int(self.enable_splitwise)}" + f" --pod_ip {pod_ip}" @@ -337,8 +339,8 @@ def launch_cache_messager( + f" --num_layers {cache_config.model_cfg.num_hidden_layers}" + f" --mp_num {tensor_parallel_size}" + f" --cache_dtype {cache_config.cache_dtype}" - + f" --key_cache_shape '{key_cache_shape}'" - + f" --value_cache_shape '{value_cache_shape}'" + + f" --key_cache_shape {key_cache_shape}" + + f" --value_cache_shape {value_cache_shape}" + f" --pod_ip {pod_ip}" + f" --cache_queue_port {cache_config.cache_queue_port}" + f" --engine_worker_queue_port {engine_worker_queue_port}" From aeae8dd2a8cdae777f78401a15ed63923dfc0dc4 Mon Sep 17 00:00:00 2001 From: ltd0924 Date: Sat, 8 Nov 2025 23:16:02 +0800 Subject: [PATCH 4/9] fix --- fastdeploy/cache_manager/cache_messager.py | 4 +++- fastdeploy/cache_manager/cache_transfer_manager.py | 4 +++- fastdeploy/demo/offline_disaggregated_demo.py | 4 ++-- tests/cache_manager/test_cache_transfer_manager.py | 5 ++--- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/fastdeploy/cache_manager/cache_messager.py b/fastdeploy/cache_manager/cache_messager.py index 9b4d2d0514a..06636fb1857 100644 --- a/fastdeploy/cache_manager/cache_messager.py +++ b/fastdeploy/cache_manager/cache_messager.py @@ -757,7 +757,9 @@ def main(): speculative_config = SpeculativeConfig(args.speculative_config) num_extra_layers = speculative_config.num_extra_cache_layer key_cache_shape_list = [int(i) for i in args.key_cache_shape.split(",")] - value_cache_shape_list = [int(i) for i in args.value_cache_shape.split(",")] + value_cache_shape_list = [] + if args.value_cache_shape: + value_cache_shape_list = [int(i) for i in args.value_cache_shape.split(",")] total_gpu_blocks = key_cache_shape_list[0] num_extra_layer_gpu_blocks = int(total_gpu_blocks * speculative_config.num_gpu_block_expand_ratio) gpu_cache_kvs = {} diff --git a/fastdeploy/cache_manager/cache_transfer_manager.py b/fastdeploy/cache_manager/cache_transfer_manager.py index 02f3e7fff2f..3b03c40166b 100644 --- a/fastdeploy/cache_manager/cache_transfer_manager.py +++ b/fastdeploy/cache_manager/cache_transfer_manager.py @@ -116,7 +116,9 @@ def __init__(self, args): self.gpu_cache_v_tensors = [] self.speculative_config = SpeculativeConfig(args.speculative_config) self.key_cache_shape = [int(i) for i in args.key_cache_shape.split(",")] - self.value_cache_shape = [int(i) for i in args.value_cache_shape.split(",")] + 
self.value_cache_shape = [] + if args.value_cache_shape: + self.value_cache_shape = [int(i) for i in args.value_cache_shape.split(",")] self.num_gpu_blocks = self.key_cache_shape[0] self.num_extra_layers = self.speculative_config.num_extra_cache_layer self.num_extra_layer_gpu_blocks = int(self.num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio) diff --git a/fastdeploy/demo/offline_disaggregated_demo.py b/fastdeploy/demo/offline_disaggregated_demo.py index 9dbb5365531..fb1e1dd30ee 100644 --- a/fastdeploy/demo/offline_disaggregated_demo.py +++ b/fastdeploy/demo/offline_disaggregated_demo.py @@ -20,7 +20,7 @@ from fastdeploy.entrypoints.llm import LLM -model_name_or_path = "baidu/ERNIE-4.5-21B-A3B-Paddle" +model_name_or_path = "/root/PaddlePaddle/ERNIE-4.5-0.3B-Paddle" def start_decode(model_name_or_path): @@ -31,7 +31,7 @@ def start_decode(model_name_or_path): tensor_parallel_size=1, splitwise_role="decode", engine_worker_queue_port=6678, - innode_prefill_ports=[6676], + innode_prefill_ports=[6677], cache_queue_port=55668, ) return llm_decode diff --git a/tests/cache_manager/test_cache_transfer_manager.py b/tests/cache_manager/test_cache_transfer_manager.py index 954ba5624a7..96f0b2ada26 100644 --- a/tests/cache_manager/test_cache_transfer_manager.py +++ b/tests/cache_manager/test_cache_transfer_manager.py @@ -22,9 +22,8 @@ class Args: num_cpu_blocks = 1 num_gpu_blocks = 1 num_layers = 1 - head_dim = 1 - kv_num_head = 1 - bytes_per_layer_per_block = 1024 + key_cache_shape = "1,1,1,1" + value_cache_shape = "" create_cache_tensor = False From 5ec4414af74213ffc5e5fffc46efd8155f783080 Mon Sep 17 00:00:00 2001 From: ltd0924 Date: Sun, 9 Nov 2025 00:47:12 +0800 Subject: [PATCH 5/9] fix --- fastdeploy/spec_decode/mtp.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 58b9a4632d0..d7a9c52ff22 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -169,11 +169,11 @@ def initialize_kv_cache(self, main_model_num_blocks, profile: bool = False): kv_cache_quant_type = self.quant_config.kv_cache_quant_type # Get kv cache shape - kv_cache_shape = self.attn_backends[0].get_kv_cache_shape( + key_cache_shape, value_cache_shape = self.attn_backends[0].get_kv_cache_shape( max_num_blocks=self.num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type ) if kv_cache_quant_type == "block_wise_fp8": - kv_cache_scale_shape = [kv_cache_shape[0], kv_cache_shape[1], kv_cache_shape[2]] + kv_cache_scale_shape = [key_cache_shape[0], key_cache_shape[1], key_cache_shape[2]] local_rank = self.local_rank % self.parallel_config.tensor_parallel_size if not profile and ( self.cache_config.enable_prefix_caching or self.scheduler_config.splitwise_role != "mixed" @@ -186,22 +186,22 @@ def initialize_kv_cache(self, main_model_num_blocks, profile: bool = False): key_cache = paddle.empty(shape=[], dtype=cache_type) key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}" val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}" - key_cache = share_external_data(key_cache, key_cache_name, kv_cache_shape) + key_cache = share_external_data(key_cache, key_cache_name, key_cache_shape) cache_kvs_list.append(key_cache) value_cache = paddle.empty(shape=[], dtype=cache_type) - value_cache = share_external_data(value_cache, val_cache_name, kv_cache_shape) + value_cache = share_external_data(value_cache, val_cache_name, value_cache_shape) cache_kvs_list.append(value_cache) 
self.model_inputs["caches"] = cache_kvs_list else: for i in range(self.model_config.num_hidden_layers): self.cache_kvs[f"key_caches_{i}"] = paddle.full( - shape=kv_cache_shape, + shape=key_cache_shape, fill_value=0, dtype=cache_type, ) self.cache_kvs[f"value_caches_{i}"] = paddle.full( - shape=kv_cache_shape, + shape=value_cache_shape, fill_value=0, dtype=cache_type, ) From bd27a03b7f09035699ec1c059eb818e15ccf32be Mon Sep 17 00:00:00 2001 From: ltd0924 <32387785+ltd0924@users.noreply.github.com> Date: Mon, 10 Nov 2025 13:02:57 +0800 Subject: [PATCH 6/9] Update metax_model_runner.py --- fastdeploy/worker/metax_model_runner.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/fastdeploy/worker/metax_model_runner.py b/fastdeploy/worker/metax_model_runner.py index acb4c6e541e..75be437e70a 100644 --- a/fastdeploy/worker/metax_model_runner.py +++ b/fastdeploy/worker/metax_model_runner.py @@ -1226,20 +1226,15 @@ def initialize_kv_cache(self, profile: bool = False) -> None: logger.info(f"Initializing kv cache for all layers. {cache_ready_signal.value}") cache_kvs_list = [] - # NOTE:(changwenbin) Determine whether it is Multi-Head Latent Attention, - # To rationalize the allocation of kvcache. - from fastdeploy import envs - - self.mla_cache = envs.FD_ATTENTION_BACKEND == "MLA_ATTN" for i in range(self.model_config.num_hidden_layers): key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}" - if not self.mla_cache: + if value_cache_shape: val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}" if create_cache_tensor: logger.info(f"..creating kv cache for layer {i}: {key_cache_shape} {value_cache_shape}") key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=cache_type) set_data_ipc(key_cache, key_cache_name) - if not self.mla_cache: + if value_cache_shape: val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=cache_type) set_data_ipc(val_cache, val_cache_name) cache_kvs_list.extend([key_cache, val_cache]) @@ -1260,7 +1255,7 @@ def initialize_kv_cache(self, profile: bool = False) -> None: logger.info(f"..attaching kv cache for layer {i}: {key_cache_shape} {value_cache_shape}") key_cache = paddle.empty(shape=[], dtype=cache_type) key_cache = share_external_data(key_cache, key_cache_name, key_cache_shape) - if not self.mla_cache: + if value_cache_shape: val_cache = paddle.empty(shape=[], dtype=cache_type) val_cache = share_external_data(val_cache, val_cache_name, value_cache_shape) cache_kvs_list.extend([key_cache, val_cache]) From e8030d0cb1419be406446d73975f9962d21b318f Mon Sep 17 00:00:00 2001 From: ltd0924 Date: Mon, 10 Nov 2025 14:16:43 +0800 Subject: [PATCH 7/9] fix --- fastdeploy/cache_manager/cache_messager.py | 1 + fastdeploy/cache_manager/cache_transfer_manager.py | 9 ++++++--- fastdeploy/demo/offline_disaggregated_demo.py | 2 +- .../layers/attention/iluvatar_attn_backend.py | 2 +- 4 files changed, 9 insertions(+), 5 deletions(-) diff --git a/fastdeploy/cache_manager/cache_messager.py b/fastdeploy/cache_manager/cache_messager.py index 06636fb1857..85dda11cd11 100644 --- a/fastdeploy/cache_manager/cache_messager.py +++ b/fastdeploy/cache_manager/cache_messager.py @@ -775,6 +775,7 @@ def main(): key_cache_shape_list[2], key_cache_shape_list[3], ] + value_cache_shape = [] if value_cache_shape_list: value_cache_shape = [ num_gpu_blocks, diff --git a/fastdeploy/cache_manager/cache_transfer_manager.py b/fastdeploy/cache_manager/cache_transfer_manager.py index 3b03c40166b..c9b6cd83f5b 100644 --- 
a/fastdeploy/cache_manager/cache_transfer_manager.py
+++ b/fastdeploy/cache_manager/cache_transfer_manager.py
@@ -215,6 +215,7 @@ def _init_gpu_cache(self, args):
                 self.key_cache_shape[2],
                 self.key_cache_shape[3],
             ]
+            value_cache_shape = []
             if self.value_cache_shape:
                 value_cache_shape = [
                     num_gpu_blocks,
@@ -257,9 +258,9 @@ def _init_gpu_cache(self, args):
         logger.info(f"[rank {self.rank}/{self.n_ranks}] done init cache (full) gmem alloc : {memory_allocated()}")

     def _init_cpu_cache(self, args):
-        key_cache_size = args.key_cache_shape[1] * args.key_cache_shape[2] * args.key_cache_shape[3]
+        key_cache_size = self.key_cache_shape[1] * self.key_cache_shape[2] * self.key_cache_shape[3]
         if args.value_cache_shape:
-            value_cache_size = args.value_cache_shape[1] * args.value_cache_shape[2] * args.value_cache_shape[3]
+            value_cache_size = self.value_cache_shape[1] * self.value_cache_shape[2] * self.value_cache_shape[3]
         else:
             value_cache_size = 0
         if args.cache_dtype == "bfloat16":
@@ -270,7 +271,9 @@ def _init_cpu_cache(self, args):
             raise ValueError(f"Unsupported cache dtype: {args.cache_dtype}")
         key_need_to_allocate_bytes = args.num_cpu_blocks * cache_bytes * key_cache_size
         value_need_to_allocate_bytes = args.num_cpu_blocks * cache_bytes * value_cache_size
-        # logger.info(f"[rank {self.rank}/{self.n_ranks}] ..swap space size : { / 1024 ** 3:.2f}GB")
+        logger.info(
+            f"[rank {self.rank}/{self.n_ranks}] ..swap space size : {(key_need_to_allocate_bytes + value_need_to_allocate_bytes) / 1024 ** 3:.2f}GB"
+        )
         if args.num_cpu_blocks == 0:
             logger.info(f"[rank {self.rank}/{self.n_ranks}] 💡 no swap space (cpu cache) is specified.")
             self.swap_space_ready_signal.value[self.rank] = 1
diff --git a/fastdeploy/demo/offline_disaggregated_demo.py b/fastdeploy/demo/offline_disaggregated_demo.py
index fb1e1dd30ee..26e34794168 100644
--- a/fastdeploy/demo/offline_disaggregated_demo.py
+++ b/fastdeploy/demo/offline_disaggregated_demo.py
@@ -20,7 +20,7 @@

 from fastdeploy.entrypoints.llm import LLM

-model_name_or_path = "/root/PaddlePaddle/ERNIE-4.5-0.3B-Paddle"
+model_name_or_path = "baidu/ERNIE-4.5-0.3B-Paddle"


 def start_decode(model_name_or_path):
diff --git a/fastdeploy/model_executor/layers/attention/iluvatar_attn_backend.py b/fastdeploy/model_executor/layers/attention/iluvatar_attn_backend.py
index 191a779bf56..07c2b2293f5 100644
--- a/fastdeploy/model_executor/layers/attention/iluvatar_attn_backend.py
+++ b/fastdeploy/model_executor/layers/attention/iluvatar_attn_backend.py
@@ -194,7 +194,7 @@ def get_kv_cache_shape(
         """
         Calculate kv cache shape
         """
-        key_cache_shape = value_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim]
+        key_cache_shape = value_cache_shape = [max_num_blocks, self.num_kv_heads, self.block_size, self.head_dim]
         return key_cache_shape, value_cache_shape

     def transpose(self, hidden_states):

From 0f5e48166ea56f9739d45070812169754e6f4495 Mon Sep 17 00:00:00 2001
From: ltd0924
Date: Mon, 10 Nov 2025 14:26:44 +0800
Subject: [PATCH 8/9] update

---
 .../layers/attention/append_attn_backend.py          | 11 +++++++++--
 .../layers/attention/block_multihead_attn_backend.py | 11 +++++++++--
 .../layers/attention/flash_attn_backend.py           | 11 +++++++++--
 .../layers/attention/iluvatar_attn_backend.py        |  3 ++-
 .../layers/attention/moba_attention_backend.py       | 11 +++++++++--
 5 files changed, 38 insertions(+), 9 deletions(-)

diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py
index 99976333792..b1b4e9df9c6 100644
--- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py
+++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py
@@ -167,9 +167,16 @@ def get_kv_cache_shape(
         """
         Calculate kv cache shape
         """
-        key_cache_shape = value_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim]
+        key_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim]
+        value_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim]
         if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp":
-            key_cache_shape = value_cache_shape = [
+            key_cache_shape = [
+                max_num_blocks,
+                self.kv_num_heads,
+                self.block_size,
+                self.head_dim // 2,
+            ]
+            value_cache_shape = [
                 max_num_blocks,
                 self.kv_num_heads,
                 self.block_size,
diff --git a/fastdeploy/model_executor/layers/attention/block_multihead_attn_backend.py b/fastdeploy/model_executor/layers/attention/block_multihead_attn_backend.py
index f40c060b18f..c2279dd2bcf 100644
--- a/fastdeploy/model_executor/layers/attention/block_multihead_attn_backend.py
+++ b/fastdeploy/model_executor/layers/attention/block_multihead_attn_backend.py
@@ -116,9 +116,16 @@ def get_kv_cache_shape(
         """
         Calculate kv cache shape
         """
-        key_cache_shape = value_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim]
+        key_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim]
+        value_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim]
         if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp":
-            key_cache_shape = value_cache_shape = [
+            key_cache_shape = [
+                max_num_blocks,
+                self.kv_num_heads,
+                self.block_size,
+                self.head_dim // 2,
+            ]
+            value_cache_shape = [
                 max_num_blocks,
                 self.kv_num_heads,
                 self.block_size,
diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py
index 71b11407c08..bce361eb5dd 100644
--- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py
+++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py
@@ -171,9 +171,16 @@ def get_kv_cache_shape(
         """
         Calculate kv cache shape
         """
-        key_cache_shape = value_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim]
+        key_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim]
+        value_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim]
         if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp":
-            key_cache_shape = value_cache_shape = [
+            key_cache_shape = [
+                max_num_blocks,
+                self.kv_num_heads,
+                self.block_size,
+                self.head_dim // 2,
+            ]
+            value_cache_shape = [
                 max_num_blocks,
                 self.kv_num_heads,
                 self.block_size,
diff --git a/fastdeploy/model_executor/layers/attention/iluvatar_attn_backend.py b/fastdeploy/model_executor/layers/attention/iluvatar_attn_backend.py
index 07c2b2293f5..6fa82573e5f 100644
--- a/fastdeploy/model_executor/layers/attention/iluvatar_attn_backend.py
+++ b/fastdeploy/model_executor/layers/attention/iluvatar_attn_backend.py
@@ -194,7 +194,8 @@ def get_kv_cache_shape(
         """
         Calculate kv cache shape
         """
-        key_cache_shape = value_cache_shape = [max_num_blocks, self.num_kv_heads, self.block_size, self.head_dim]
+        key_cache_shape = [max_num_blocks, self.num_kv_heads, self.block_size, self.head_dim]
+        value_cache_shape = [max_num_blocks, self.num_kv_heads, self.block_size, self.head_dim]
         return key_cache_shape, value_cache_shape

     def transpose(self, hidden_states):
diff --git a/fastdeploy/model_executor/layers/attention/moba_attention_backend.py b/fastdeploy/model_executor/layers/attention/moba_attention_backend.py
index b89abb357af..ea2915f4306 100644
--- a/fastdeploy/model_executor/layers/attention/moba_attention_backend.py
+++ b/fastdeploy/model_executor/layers/attention/moba_attention_backend.py
@@ -126,9 +126,16 @@ def get_kv_cache_shape(
         """
         Calculate kv cache shape
         """
-        key_cache_shape = value_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim]
+        key_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim]
+        value_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim]
         if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp":
-            key_cache_shape = value_cache_shape = [
+            key_cache_shape = [
+                max_num_blocks,
+                self.kv_num_heads,
+                self.block_size,
+                self.head_dim // 2,
+            ]
+            value_cache_shape = [
                 max_num_blocks,
                 self.kv_num_heads,
                 self.block_size,

From ed2ba1e1417cccfaee7a2d9a1eb662cdf9e289e8 Mon Sep 17 00:00:00 2001
From: ltd0924 <32387785+ltd0924@users.noreply.github.com>
Date: Mon, 10 Nov 2025 15:51:10 +0800
Subject: [PATCH 9/9] Update test_moba_attention_backend.py

---
 tests/layers/test_moba_attention_backend.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/tests/layers/test_moba_attention_backend.py b/tests/layers/test_moba_attention_backend.py
index 7b0d24e6187..b5e17181ffa 100644
--- a/tests/layers/test_moba_attention_backend.py
+++ b/tests/layers/test_moba_attention_backend.py
@@ -92,16 +92,16 @@ def test_get_kv_cache_shape(self):
         backend = PlasAttentionBackend(fd_config, kv_num_heads=2, num_heads=2, head_dim=8)

         # Default
-        shape = backend.get_kv_cache_shape(max_num_blocks=2)
-        self.assertEqual(shape, (2, 2, 4, 8))
+        key_shape, value_shape = backend.get_kv_cache_shape(max_num_blocks=2)
+        self.assertEqual(key_shape, [2, 2, 4, 8])

         # int4_zp quant
-        shape_int4 = backend.get_kv_cache_shape(max_num_blocks=2, kv_cache_quant_type="int4_zp")
-        self.assertEqual(shape_int4, (2, 2, 4, 4))
+        key_shape_int4, value_shape_int4 = backend.get_kv_cache_shape(max_num_blocks=2, kv_cache_quant_type="int4_zp")
+        self.assertEqual(key_shape_int4, [2, 2, 4, 4])

         # Other quant types
-        shape_other = backend.get_kv_cache_shape(max_num_blocks=2, kv_cache_quant_type="int8")
-        self.assertEqual(shape_other, (2, 2, 4, 8))
+        key_shape_other, value_shape_other = backend.get_kv_cache_shape(max_num_blocks=2, kv_cache_quant_type="int8")
+        self.assertEqual(key_shape_other, [2, 2, 4, 8])

     @patch(
         "fastdeploy.model_executor.layers.attention.moba_attention_backend.moba_attention",