From 49e106686d63c0083e6d95074a0e38ce046971b8 Mon Sep 17 00:00:00 2001
From: xiaozude
Date: Mon, 24 Nov 2025 13:41:04 +0800
Subject: [PATCH] [Metax] support ENABLE_V1_KVCACHE_SCHEDULER

---
 fastdeploy/engine/args_utils.py                  |  2 +-
 .../metax/attention/mla_attn_metax_backend.py    |  8 ++++++--
 fastdeploy/worker/metax_worker.py                | 20 ++++++++++---------
 fastdeploy/worker/worker_process.py              |  2 +-
 requirements_metaxgpu.txt                        |  6 +++++-
 5 files changed, 24 insertions(+), 14 deletions(-)

diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py
index 05781352d77..26cb4e16ba1 100644
--- a/fastdeploy/engine/args_utils.py
+++ b/fastdeploy/engine/args_utils.py
@@ -523,7 +523,7 @@ def __post_init__(self):
                 f"= {expected_ports}, but got {len(self.rdma_comm_ports)}."
             )
 
-        if not current_platform.is_cuda() and not current_platform.is_xpu():
+        if not (current_platform.is_cuda() or current_platform.is_xpu() or current_platform.is_maca()):
             envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
         if self.guided_decoding_backend != "off":
             envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
diff --git a/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py b/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py
index 8800d497eae..9d4913425af 100644
--- a/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py
+++ b/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py
@@ -141,8 +141,11 @@ def __init__(
         self.flash_attn_func = flash_attn_unpadded_func
         self.flash_attn_kwargs = {"softmax_scale": self.attn_softmax_scale}
 
+    @paddle.no_grad()
     def init_attention_metadata(self, forward_meta: ForwardMeta):
         """Initialize attention metadata so that all layers in the forward pass can reuse it."""
+        paddle.device.empty_cache()
+
         metadata = MLAAttentionMetadata()
         metadata.max_partition_size = 32768
         metadata.encoder_max_partition_size = self.max_seq_len
@@ -203,8 +206,6 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
             self.seq_lens = seq_lens_decoder + seq_lens_this_time
             self.block_tables = forward_meta.block_tables[non_zero_index]
 
-        paddle.device.empty_cache()
-
     def get_attntion_meta(self) -> AttentionMetadata:
         """get_attntion_meta"""
         return self.attention_metadata
@@ -290,6 +291,8 @@ def forward_extend(
         """
         Forward pass for the prefill stage.
         """
+        paddle.device.empty_cache()
+
         metadata = self.attention_metadata
         latent_cache = forward_meta.caches[layer.layer_id] if hasattr(forward_meta, "caches") else None
 
@@ -364,6 +367,7 @@ def forward_decode(
 
         return fmha_out
 
+    @paddle.no_grad()
     def forward_mixed(
         self,
         q: paddle.Tensor,
diff --git a/fastdeploy/worker/metax_worker.py b/fastdeploy/worker/metax_worker.py
index 675c2a9e0ea..b57b4e3dd14 100644
--- a/fastdeploy/worker/metax_worker.py
+++ b/fastdeploy/worker/metax_worker.py
@@ -103,12 +103,12 @@ def determine_available_memory(self) -> int:
         Gb = 1024**3
         local_rank = self.local_rank % self.max_chips_per_node
 
-        paddle.device.cuda.reset_max_memory_reserved(local_rank)
-        paddle.device.cuda.reset_max_memory_allocated(local_rank)
+        paddle.device.reset_max_memory_reserved(local_rank)
+        paddle.device.reset_max_memory_allocated(local_rank)
         # max memory for Allocator
-        paddle_reserved_mem_before_run = paddle.device.cuda.max_memory_reserved(local_rank)
+        paddle_reserved_mem_before_run = paddle.device.max_memory_reserved(local_rank)
         # max memory for Tensor
-        paddle_allocated_mem_before_run = paddle.device.cuda.max_memory_allocated(local_rank)  # not reserved
+        paddle_allocated_mem_before_run = paddle.device.max_memory_allocated(local_rank)  # not reserved
 
         device_id = int(self.device_ids[local_rank])
         if os.getenv("MACA_VISIBLE_DEVICES") is not None:
@@ -132,13 +132,13 @@ def determine_available_memory(self) -> int:
         self.model_runner.profile_run()
 
         # 3. Statistical memory information
-        paddle_reserved_mem_after_run = paddle.device.cuda.max_memory_reserved(local_rank)
-        paddle_allocated_mem_after_run = paddle.device.cuda.max_memory_allocated(local_rank)
+        paddle_reserved_mem_after_run = paddle.device.max_memory_reserved(local_rank)
+        paddle_allocated_mem_after_run = paddle.device.max_memory_allocated(local_rank)
 
         model_block_memory_used = self.cal_theortical_kvcache()
         paddle_peak_increase = paddle_reserved_mem_after_run - paddle_allocated_mem_before_run
 
-        paddle.device.cuda.empty_cache()
+        paddle.device.empty_cache()
 
         info = pymxsml.mxSmlGetMemoryInfo(device_id)
         after_run_meminfo_total = info.vramTotal * 1024
@@ -146,8 +146,10 @@ def determine_available_memory(self) -> int:
         after_run_meminfo_free = after_run_meminfo_total - after_run_meminfo_used
 
         available_kv_cache_memory = (
-            after_run_meminfo_free - paddle_peak_increase
-        ) * self.cache_config.gpu_memory_utilization
+            after_run_meminfo_total * self.cache_config.gpu_memory_utilization
+            - after_run_meminfo_used
+            - paddle_peak_increase
+        )
         available_kv_cache_memory += model_block_memory_used * self.cache_config.total_block_num
 
         end_time = time.perf_counter()
diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py
index b30db63a7ec..09f765c25b7 100644
--- a/fastdeploy/worker/worker_process.py
+++ b/fastdeploy/worker/worker_process.py
@@ -929,7 +929,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
     logger.info(f"- Dynamic load weight: {load_config.dynamic_load_weight}")
     logger.info(f"- Load strategy: {load_config.load_strategy}")
 
-    if not current_platform.is_cuda() and not current_platform.is_xpu():
+    if not (current_platform.is_cuda() or current_platform.is_xpu() or current_platform.is_maca()):
         logger.info("Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not supported.")
         envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
     if structured_outputs_config.guided_decoding_backend != "off":
diff --git a/requirements_metaxgpu.txt b/requirements_metaxgpu.txt
index cd7d96a0f75..f04659410a9 100644
--- a/requirements_metaxgpu.txt
+++ b/requirements_metaxgpu.txt
@@ -10,7 +10,7 @@ tqdm
 pynvml
 uvicorn==0.29.0
 fastapi
-paddleformers>=0.2
+paddleformers==0.3.2
 redis
 etcd3
 httpx
@@ -42,3 +42,7 @@ opentelemetry-exporter-otlp
 opentelemetry-instrumentation-fastapi
 opentelemetry-instrumentation-logging
 partial_json_parser
+msgspec
+einops
+setproctitle
+aistudio_sdk
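
Review note (illustrative, not part of the patch): the last metax_worker.py hunk changes how
the KV-cache budget is derived. Before, the utilization ratio scaled the post-profiling
headroom; now utilization caps total VRAM usage, and used memory plus the profiling peak are
subtracted from that cap. A minimal standalone sketch with hypothetical numbers for a 64 GB
card; only the two formulas come from the diff, the byte figures are invented:

    Gb = 1024**3

    # Hypothetical meminfo snapshot after the profile run.
    after_run_meminfo_total = 64 * Gb  # total VRAM reported by pymxsml
    after_run_meminfo_used = 30 * Gb   # VRAM already in use
    paddle_peak_increase = 6 * Gb      # reserved-memory growth during profiling
    gpu_memory_utilization = 0.9       # cache_config.gpu_memory_utilization

    # Old formula: scale the remaining headroom by the utilization ratio.
    after_run_meminfo_free = after_run_meminfo_total - after_run_meminfo_used
    old_budget = (after_run_meminfo_free - paddle_peak_increase) * gpu_memory_utilization

    # New formula: cap total usage at utilization * total VRAM, then subtract
    # what is already used and the profiling peak.
    new_budget = (
        after_run_meminfo_total * gpu_memory_utilization
        - after_run_meminfo_used
        - paddle_peak_increase
    )

    print(f"old: {old_budget / Gb:.2f} GB, new: {new_budget / Gb:.2f} GB")
    # old: 25.20 GB, new: 21.60 GB

In both versions, the theoretical size of the blocks already allocated during profiling
(model_block_memory_used * cache_config.total_block_num) is added back afterwards.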