diff --git a/ucm/integration/vllm/patch/0.9.2/vllm-adapt-sparse.patch b/ucm/integration/vllm/patch/0.9.2/vllm-adapt-sparse.patch index 5f97d632..eb984875 100644 --- a/ucm/integration/vllm/patch/0.9.2/vllm-adapt-sparse.patch +++ b/ucm/integration/vllm/patch/0.9.2/vllm-adapt-sparse.patch @@ -1,23 +1,34 @@ -From 8c02671e05ed23d7a0c9dc112f8474b26d579f99 Mon Sep 17 00:00:00 2001 -From: harrisonyhq -Date: Wed, 5 Nov 2025 00:22:36 -0800 -Subject: [PATCH 3/3] [Patch2] UCM patch for sparsed attention +From 0431022b90649f7115b89b61aaf2a0f83e896d5a Mon Sep 17 00:00:00 2001 +From: wenxinwang +Date: Mon, 10 Nov 2025 20:35:47 +0800 +Subject: [PATCH] adapt to deepseek patch --- - vllm/attention/layer.py | 43 ++++++++++++++++++ - vllm/v1/core/kv_cache_manager.py | 7 ++- - vllm/v1/core/sched/output.py | 3 ++ - vllm/v1/core/sched/scheduler.py | 26 ++++++++++- - vllm/v1/worker/block_table.py | 13 ++++++ - vllm/v1/worker/gpu_model_runner.py | 70 +++++++++++++++++++++++++----- - vllm/v1/worker/gpu_worker.py | 2 + - 7 files changed, 151 insertions(+), 13 deletions(-) + vllm/attention/layer.py | 49 ++++++++++++- + .../kv_transfer/kv_connector/utils.py | 5 ++ + .../v1/shared_storage_connector.py | 7 +- + vllm/v1/attention/backends/mla/common.py | 10 ++- + vllm/v1/core/kv_cache_manager.py | 7 +- + vllm/v1/core/sched/output.py | 3 + + vllm/v1/core/sched/scheduler.py | 37 +++++++--- + vllm/v1/worker/block_table.py | 13 ++++ + vllm/v1/worker/gpu_model_runner.py | 71 +++++++++++++++---- + vllm/v1/worker/gpu_worker.py | 2 + + 10 files changed, 171 insertions(+), 33 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py -index f0ad68b16..d55f3d689 100644 +index f0ad68b16..728ab99fd 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py -@@ -22,6 +22,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +@@ -2,7 +2,6 @@ + # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + """Attention layer.""" + from typing import Any, Dict, List, Optional +- + import torch + import torch.nn as nn + import torch.nn.functional as F +@@ -22,6 +21,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.platforms import _Backend, current_platform from vllm.utils import direct_register_custom_op from vllm.v1.attention.backends.utils import validate_kv_sharing_target @@ -25,46 +36,49 @@ index f0ad68b16..d55f3d689 100644 class Attention(nn.Module): -@@ -409,9 +410,11 @@ def unified_attention( +@@ -409,9 +409,10 @@ def unified_attention( attn_metadata = attn_metadata[layer_name] self = forward_context.no_compile_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] + maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context) output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata) - +- + maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context) maybe_save_kv_layer_to_connector(layer_name, kv_cache) return output -@@ -449,6 +452,7 @@ def unified_attention_with_output( +@@ -449,6 +450,8 @@ def unified_attention_with_output( attn_metadata = attn_metadata[layer_name] self = forward_context.no_compile_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] -+ maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context) ++ if not self.use_mla: ++ maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context) self.impl.forward(self, query, key, -@@ -458,6 +462,7 @@ def unified_attention_with_output( +@@ -457,7 +460,8 @@ def unified_attention_with_output( + attn_metadata, output=output, output_scale=output_scale) - -+ maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context) +- ++ if not self.use_mla: ++ maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context) maybe_save_kv_layer_to_connector(layer_name, kv_cache) -@@ -479,3 +484,41 @@ direct_register_custom_op( +@@ -479,3 +483,42 @@ direct_register_custom_op( fake_impl=unified_attention_with_output_fake, dispatch_key=current_platform.dispatch_key, ) + -+ +def maybe_execute_sparse_attention_begin( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + layer_name: str, + forward_context: ForwardContext, ++ phase: Optional[str] = None, +): + if not has_ucm_sparse(): + return @@ -75,7 +89,7 @@ index f0ad68b16..d55f3d689 100644 + if attn_metadata is None: + return + -+ ucm_sparse.attention_begin(query, key, value, layer_name, forward_context) ++ ucm_sparse.attention_begin(query, key, value, layer_name, forward_context, phase) + +def maybe_execute_sparse_attention_finished( + query: torch.Tensor, @@ -84,6 +98,7 @@ index f0ad68b16..d55f3d689 100644 + attn_output: torch.Tensor, + layer_name: str, + forward_context: ForwardContext, ++ phase: Optional[str] = None, +): + if not has_ucm_sparse(): + return @@ -94,8 +109,101 @@ index f0ad68b16..d55f3d689 100644 + if attn_metadata is None: + return + -+ ucm_sparse.attention_finished(query, key, value, attn_output, layer_name, forward_context) -\ No newline at end of file ++ ucm_sparse.attention_finished(query, key, value, attn_output, layer_name, forward_context, phase) +diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py +index b63bf5965..155597c51 100644 +--- a/vllm/distributed/kv_transfer/kv_connector/utils.py ++++ b/vllm/distributed/kv_transfer/kv_connector/utils.py +@@ -3,6 +3,11 @@ + """ + KV cache helper for store. + """ ++from collections import defaultdict ++from collections.abc import Sequence ++from concurrent.futures import CancelledError, Future ++from typing import Optional, cast ++ + import torch + + from collections import defaultdict +diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py +index 3c574d065..223106def 100644 +--- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py ++++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py +@@ -2,7 +2,7 @@ + # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + import hashlib + import os +-from dataclasses import dataclass ++from dataclasses import dataclass, field + from typing import TYPE_CHECKING + + import safetensors +@@ -53,10 +53,7 @@ class ReqMeta: + + @dataclass + class SharedStorageConnectorMetadata(KVConnectorMetadata): +- requests: list[ReqMeta] +- +- def __init__(self): +- self.requests = [] ++ requests: list[ReqMeta] = field(default_factory=list) + + def add_request( + self, +diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py +index f2aaf59a4..b56f62b39 100644 +--- a/vllm/v1/attention/backends/mla/common.py ++++ b/vllm/v1/attention/backends/mla/common.py +@@ -200,6 +200,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, + MLAAttentionImpl) + from vllm.attention.backends.utils import get_mla_dims + from vllm.attention.ops.merge_attn_states import merge_attn_states ++from vllm.forward_context import ForwardContext, get_forward_context + from vllm.attention.utils.fa_utils import get_flash_attn_version + from vllm.logger import init_logger + from vllm.model_executor.layers.linear import (ColumnParallelLinear, +@@ -211,6 +212,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, + CommonAttentionMetadata) + from vllm.v1.kv_cache_interface import AttentionSpec + from vllm.v1.worker.block_table import BlockTable ++from vllm.attention.layer import (maybe_execute_sparse_attention_begin, maybe_execute_sparse_attention_finished) + + try: + from vllm.vllm_flash_attn import flash_attn_varlen_func +@@ -908,7 +910,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: +- ++ forward_context: ForwardContext = get_forward_context() + assert output is not None, "Output tensor must be provided." + + if output_scale is not None: +@@ -957,10 +959,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): + ) + + if has_prefill: ++ maybe_execute_sparse_attention_begin(prefill_q, prefill_k_c_normed, prefill_k_pe, layer.layer_name, forward_context, "prefill") + output[num_decode_tokens:] = self._forward_prefill( + prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, + attn_metadata) +- ++ maybe_execute_sparse_attention_finished(prefill_q, prefill_k_c_normed, prefill_k_pe, output[num_decode_tokens:], layer.layer_name, forward_context, "prefill") + if has_decode: + assert attn_metadata.decode is not None + decode_q_nope, decode_q_pe = decode_q.split( +@@ -971,8 +974,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): + decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T) + # Convert from (N, B, L) to (B, N, L) + decode_ql_nope = decode_ql_nope.transpose(0, 1) +- ++ maybe_execute_sparse_attention_begin(torch.cat([decode_ql_nope, decode_q_pe],dim=-1), decode_ql_nope, decode_q_pe, layer.layer_name, forward_context, "decode") + output[:num_decode_tokens] = self._forward_decode( + decode_ql_nope, decode_q_pe, kv_cache, attn_metadata) ++ maybe_execute_sparse_attention_finished(torch.cat([decode_ql_nope, decode_q_pe],dim=-1), decode_ql_nope, decode_q_pe, output[:num_decode_tokens], layer.layer_name, forward_context, "decode") + + return output_padded diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 6937455e7..bf9aec864 100644 --- a/vllm/v1/core/kv_cache_manager.py @@ -136,7 +244,7 @@ index 6937455e7..bf9aec864 100644 if new_computed_blocks is not None: new_computed_block_list = new_computed_blocks.blocks diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py -index c94e421c0..f6f170e10 100644 +index c94e421c0..fff0eeb42 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -157,3 +157,6 @@ class SchedulerOutput: @@ -146,9 +254,8 @@ index c94e421c0..f6f170e10 100644 + + # modified slots by sparse algorithm + req_sparsed_slots: dict[str, int] = None -\ No newline at end of file diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py -index 2d4fd4d59..8268c1409 100644 +index 2d4fd4d59..e99a51788 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -35,6 +35,8 @@ from vllm.v1.request import Request, RequestStatus @@ -230,7 +337,42 @@ index 2d4fd4d59..8268c1409 100644 # finished_req_ids is an existing state in the scheduler, # instead of being newly scheduled in this step. # It contains the request IDs that are finished in between -@@ -955,6 +975,8 @@ class Scheduler(SchedulerInterface): +@@ -809,16 +829,12 @@ class Scheduler(SchedulerInterface): + new_logprobs = None + new_token_ids = generated_token_ids + kv_transfer_params = None ++ + if model_runner_output.finished_dumping is not None: + request.succeed_dumped_blocks.extend(model_runner_output.finished_dumping.get(req_id, [])) + is_prefill = request.num_output_tokens == 0 + if is_prefill: +- if isinstance(self.connector, MultiConnector): +- for c in self.connector._connectors: +- if hasattr(c, 'connector') and hasattr(c.connector, 'commit'): +- c.connector.commit(model_runner_output.finished_dumping.get(req_id, []), True) +- else: +- self.connector.connector.commit(model_runner_output.finished_dumping.get(req_id, []), True) ++ self.connector.connector.commit(model_runner_output.finished_dumping.get(req_id, []), True) + + # Append generated tokens and check for stop. Note that if + # a request is still being prefilled, we expect the model runner +@@ -870,7 +886,6 @@ class Scheduler(SchedulerInterface): + spec_token_ids[req_index]) + else: + request.spec_token_ids = spec_token_ids[req_index] +- + # Get prompt logprobs for this request. + prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) + if new_token_ids or pooler_output is not None \ +@@ -897,6 +912,7 @@ class Scheduler(SchedulerInterface): + + if not stopped: + new_running.append(request) ++ + self.running = new_running + + # KV Connector: update state for finished KV Transfers. +@@ -955,6 +971,8 @@ class Scheduler(SchedulerInterface): def add_request(self, request: Request) -> None: self.waiting.add_request(request) self.requests[request.request_id] = request @@ -239,7 +381,7 @@ index 2d4fd4d59..8268c1409 100644 if self.log_stats: request.record_event(EngineCoreEventType.QUEUED) -@@ -1004,6 +1026,8 @@ class Scheduler(SchedulerInterface): +@@ -1004,6 +1022,8 @@ class Scheduler(SchedulerInterface): def _free_request(self, request: Request) -> Optional[dict[str, Any]]: assert request.is_finished() @@ -248,6 +390,14 @@ index 2d4fd4d59..8268c1409 100644 delay_free_blocks, kv_xfer_params = self._connector_finished(request) self.encoder_cache_manager.free(request) +@@ -1155,7 +1175,6 @@ class Scheduler(SchedulerInterface): + logger.debug("Finished sending KV transfer for request %s", req_id) + self._free_blocks(self.requests[req_id]) + +- + def _update_requests_with_invalid_blocks( + self, requests: Iterable[Request], + invalid_block_ids: set[int]) -> tuple[set[str], int]: diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 8f4e8d64c..f45e39f5c 100644 --- a/vllm/v1/worker/block_table.py @@ -280,7 +430,7 @@ index 8f4e8d64c..f45e39f5c 100644 for i, block_table in enumerate(self.block_tables): block_table.add_row(block_ids[i], row_idx) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py -index c3df1d5d2..6341efc70 100644 +index c3df1d5d2..dbf1ea7d7 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -72,6 +72,9 @@ from ..sample.logits_processor import LogitsProcessorManager @@ -347,7 +497,7 @@ index c3df1d5d2..6341efc70 100644 self.input_batch.block_table.append_row(new_block_ids, req_index) # For the last rank, we don't need to update the token_ids_cpu -@@ -639,6 +647,20 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -639,6 +647,19 @@ class GPUModelRunner(LoRAModelRunnerMixin): if self.uses_mrope: self._calc_mrope_positions(scheduler_output) @@ -364,11 +514,10 @@ index c3df1d5d2..6341efc70 100644 + offset = 0 if req_index == 0 else cu_num_tokens[req_index - 1] # TODO: support MTP + if is_sparsed_request: + sparsed_positions[offset] = req_sparsed_slots[req_id] - 1 -+ # Get token indices. # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] -@@ -668,11 +690,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -668,11 +689,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): # block_size. block_table_indices = ( req_indices * block_table.max_num_blocks_per_req + @@ -382,7 +531,7 @@ index c3df1d5d2..6341efc70 100644 np.add( block_numbers * block_size, block_offsets, -@@ -682,9 +704,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -682,9 +703,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): self.query_start_loc_np[0] = 0 self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens @@ -397,7 +546,7 @@ index c3df1d5d2..6341efc70 100644 # Copy the tensors to the GPU. self.input_ids[:total_num_scheduled_tokens].copy_( -@@ -696,6 +720,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -696,6 +719,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): non_blocking=True) else: # Common case (1D positions) @@ -406,7 +555,7 @@ index c3df1d5d2..6341efc70 100644 self.positions[:total_num_scheduled_tokens].copy_( self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True) -@@ -1386,6 +1412,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -1386,6 +1411,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): skip_cuda_graphs=skip_cuda_graphs, ): self.maybe_setup_kv_connector(scheduler_output) @@ -414,7 +563,7 @@ index c3df1d5d2..6341efc70 100644 model_output = self.model( input_ids=input_ids, -@@ -1395,6 +1422,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -1395,6 +1421,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ) finished_dumping = self.maybe_wait_for_kv_save() @@ -423,7 +572,12 @@ index c3df1d5d2..6341efc70 100644 finished_sending, finished_recving = ( self.get_finished_kv_transfers(scheduler_output)) invalid_block_ids = self.get_block_ids_with_load_errors() -@@ -1745,6 +1774,25 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -1741,10 +1769,29 @@ class GPUModelRunner(LoRAModelRunnerMixin): + kv_connector.start_load_kv(get_forward_context()) + + @staticmethod +- def maybe_wait_for_kv_save(): ++ def maybe_wait_for_kv_save() -> Optional[dict[str, list[str]]]: if has_kv_transfer_group(): return get_kv_transfer_group().wait_for_save() diff --git a/ucm/integration/vllm/patch/0.9.2/vllm-ascend-adapt.patch b/ucm/integration/vllm/patch/0.9.2/vllm-ascend-adapt.patch index e15f7ab5..8c459aa7 100644 --- a/ucm/integration/vllm/patch/0.9.2/vllm-ascend-adapt.patch +++ b/ucm/integration/vllm/patch/0.9.2/vllm-ascend-adapt.patch @@ -1,16 +1,17 @@ -From 67b10fc431e5aac0155ca5b77cd9a99e35656521 Mon Sep 17 00:00:00 2001 +From 73de421dd3a9d3877b8903b8ee419e692da62b29 Mon Sep 17 00:00:00 2001 From: wenxinwang -Date: Thu, 25 Sep 2025 05:31:48 -0700 -Subject: [PATCH] UCM adaptor +Date: Mon, 10 Nov 2025 20:44:02 +0800 +Subject: [PATCH] adapt to deepseek --- - vllm_ascend/attention/attention_v1.py | 75 ++++++++++++++++++++ + vllm_ascend/attention/attention_v1.py | 76 ++++++++++++++++++++ + vllm_ascend/attention/mla_v1.py | 14 +++- vllm_ascend/worker/model_runner_v1.py | 99 +++++++++++++++++++++++---- vllm_ascend/worker/worker_v1.py | 25 +++++-- - 3 files changed, 183 insertions(+), 16 deletions(-) + 4 files changed, 196 insertions(+), 18 deletions(-) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py -index 7d7f488..09c4345 100644 +index 7d7f488f..18039f42 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -24,6 +24,9 @@ import torch_npu @@ -26,10 +27,10 @@ index 7d7f488..09c4345 100644 @@ -33,6 +36,8 @@ from vllm_ascend.ops.attention import vanilla_chunked_prefill from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p, nd_to_nz_2d, nd_to_nz_spec) - + +from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse + - + class AscendAttentionBackend(AttentionBackend): accept_output_buffer: bool = True @@ -444,10 +449,14 @@ def unified_ascend_attention_with_output( @@ -42,19 +43,20 @@ index 7d7f488..09c4345 100644 attn_metadata = forward_context.attn_metadata self = forward_context.no_compile_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] -+ -+ maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context) ++ if not self.use_mla: ++ maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context) self.impl.forward(self, query, key, -@@ -456,8 +465,74 @@ def unified_ascend_attention_with_output( +@@ -456,8 +465,75 @@ def unified_ascend_attention_with_output( attn_metadata, output, trace_flag=False) -+ maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context) ++ if not self.use_mla: ++ maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context) + maybe_save_kv_layer_to_connector(layer_name, kv_cache) return - + +def wait_for_kv_layer_from_connector(layer_name: str): + if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): + return @@ -119,11 +121,67 @@ index 7d7f488..09c4345 100644 + return + + ucm_sparse.attention_finished(query, key, value, attn_output, layer_name, forward_context) - + def unified_attention_with_output_fake( query: torch.Tensor, +diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py +index f50fe56e..4a27c22f 100644 +--- a/vllm_ascend/attention/mla_v1.py ++++ b/vllm_ascend/attention/mla_v1.py +@@ -13,10 +13,12 @@ from vllm.distributed import get_tensor_model_parallel_world_size + from vllm.model_executor.layers.linear import (LinearBase, + UnquantizedLinearMethod) + from vllm.utils import cdiv, round_down ++from vllm.forward_context import ForwardContext, get_forward_context ++from vllm.attention.layer import (maybe_execute_sparse_attention_begin, maybe_execute_sparse_attention_finished) + + from vllm_ascend.ascend_config import get_ascend_config + from vllm_ascend.attention.attention import _ALLOWED_NUM_QUERIES_PER_KV +-from vllm_ascend.attention.attention_v1 import AscendAttentionState ++from vllm_ascend.attention.attention_v1 import AscendAttentionState, wait_for_kv_layer_from_connector, maybe_save_kv_layer_to_connector + from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig + from vllm_ascend.multistream.context import get_multistream_comm_context + from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn +@@ -1042,6 +1044,7 @@ class AscendMLAImpl(MLAAttentionImpl): + enable_multistream_mla: bool = False, + ckq: Optional[torch.Tensor] = None, + ) -> torch.Tensor: ++ forward_context: ForwardContext = get_forward_context() + assert output is not None, "Output tensor must be provided." + if attn_metadata is None: + # Profiling run. +@@ -1192,6 +1195,8 @@ class AscendMLAImpl(MLAAttentionImpl): + # FIX: aicore move should be also placed on the comm stream in dbo, + # otherwise it may affect the accuracy + # TODO: use an elegant way to overlap ++ wait_for_kv_layer_from_connector(layer.layer_name) ++ maybe_execute_sparse_attention_begin(prefill_q, prefill_k_c_normed, prefill_k_pe, layer.layer_name, forward_context, "prefill") + output_prefill = self._forward_prefill(prefill_q, + prefill_k_c_normed, + prefill_k_pe, kv_cache, +@@ -1203,8 +1208,11 @@ class AscendMLAImpl(MLAAttentionImpl): + current_ms_metadata.after_comm_event.record() + else: + output[num_decode_tokens:] = output_prefill +- ++ maybe_execute_sparse_attention_finished(prefill_q, prefill_k_c_normed, prefill_k_pe, output[num_decode_tokens:], layer.layer_name, forward_context, "prefill") ++ maybe_save_kv_layer_to_connector(layer.layer_name, kv_cache) + if has_decode: ++ wait_for_kv_layer_from_connector(layer.layer_name) ++ maybe_execute_sparse_attention_begin(torch.cat([decode_ql_nope, decode_q_pe],dim=-1), decode_ql_nope, decode_q_pe, layer.layer_name, forward_context, "decode") + if self.running_in_graph: + return self._forward_decode(decode_ql_nope, decode_q_pe, + decode_k_nope, decode_k_pe, +@@ -1223,5 +1231,7 @@ class AscendMLAImpl(MLAAttentionImpl): + current_ms_metadata.after_comm_event.record() + else: + output[:num_decode_tokens] = output_decode ++ maybe_execute_sparse_attention_finished(torch.cat([decode_ql_nope, decode_q_pe],dim=-1), decode_ql_nope, decode_q_pe, output[:num_decode_tokens], layer.layer_name, forward_context, "decode") ++ maybe_save_kv_layer_to_connector(layer.layer_name, kv_cache) + + return output_padded diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py -index eabcdbc..e51f46e 100644 +index eabcdbcc..179dffde 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -39,7 +39,10 @@ from vllm.config import CompilationLevel, VllmConfig @@ -141,7 +199,7 @@ index eabcdbc..e51f46e 100644 @@ -88,6 +91,9 @@ from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch - + +from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse +from ucm.sparse.base import UcmSparseMetadata, INVALID_SLOT + @@ -157,7 +215,7 @@ index eabcdbc..e51f46e 100644 self.encoder_cache.pop(req_id, None) # Remove the finished requests from the persistent batch. @@ -453,12 +460,14 @@ class NPUModelRunner(LoRAModelRunnerMixin): - + # Update the states of the running/resumed requests. req_data = scheduler_output.scheduled_cached_reqs + req_sparsed_slots = scheduler_output.req_sparsed_slots @@ -168,7 +226,7 @@ index eabcdbc..e51f46e 100644 new_block_ids = req_data.new_block_ids[i] resumed_from_preemption = req_data.resumed_from_preemption[i] + is_sparsed_request = req_sparsed_slots[req_id] != INVALID_SLOT - + req_state.num_computed_tokens = num_computed_tokens if not is_last_rank: @@ -474,15 +483,15 @@ class NPUModelRunner(LoRAModelRunnerMixin): @@ -189,18 +247,18 @@ index eabcdbc..e51f46e 100644 - # The request is resumed from preemption. - # Replace the existing block IDs with the new ones. - req_state.block_ids = new_block_ids - + req_index = self.input_batch.req_id_to_index.get(req_id) if req_index is None: @@ -496,6 +505,9 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.input_batch.num_computed_tokens_cpu[req_index] = ( num_computed_tokens) - + + if is_sparsed_request: + self.input_batch.block_table.reset_row(req_index) + self.input_batch.block_table.append_row(new_block_ids, req_index) - + if not is_last_rank: @@ -876,7 +888,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): intermediate_tensors: Optional[IntermediateTensors] = None, @@ -215,7 +273,7 @@ index eabcdbc..e51f46e 100644 @@ -955,12 +968,22 @@ class NPUModelRunner(LoRAModelRunnerMixin): num_scheduled_tokens) seq_lens = self.seq_lens_cpu[:num_reqs] - + + # TODO: improve performance, no `positions_np.copy()` + sparsed_positions = positions_np.copy() + req_sparsed_slots = scheduler_output.req_sparsed_slots @@ -229,7 +287,7 @@ index eabcdbc..e51f46e 100644 block_table_indices = (req_indices * self.max_num_blocks_per_req + - positions_np // self.block_size) + sparsed_positions // self.block_size) - + block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor() block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() - block_offsets = positions_np % self.block_size @@ -240,7 +298,7 @@ index eabcdbc..e51f46e 100644 @@ -985,10 +1008,16 @@ class NPUModelRunner(LoRAModelRunnerMixin): else: attn_state = AscendAttentionState.PrefillCacheHit - + + for req_id in self.input_batch.req_id_to_index: + is_sparsed_request = req_sparsed_slots[req_id] != INVALID_SLOT + req_index = self.input_batch.req_id_to_index[req_id] @@ -254,10 +312,10 @@ index eabcdbc..e51f46e 100644 + position=torch.tensor(sparsed_positions).npu(), attn_state=attn_state) self.attn_state = attn_state # type: ignore - + @@ -1100,6 +1129,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): positions = self.positions[:padded_batch_size] - + # Run forward pass + finished_dumping = None with set_forward_context(attn_metadata, @@ -269,7 +327,7 @@ index eabcdbc..e51f46e 100644 ACL_FORMAT_FRACTAL_ND) + self.maybe_setup_kv_connector(scheduler_output) + self.maybe_execute_ucm_sparse_begin(scheduler_output, attn_metadata) - + hidden_states = self.model( input_ids=input_ids, @@ -1133,6 +1165,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): @@ -278,16 +336,16 @@ index eabcdbc..e51f46e 100644 ) + finished_dumping = self.maybe_wait_for_kv_save() + self.maybe_execute_ucm_sparse_finished() - + use_spec_decode = len( scheduler_output.scheduled_spec_decode_tokens) > 0 @@ -1163,7 +1197,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): - + return (attn_metadata, hidden_states, spec_decode_metadata, positions, total_num_scheduled_tokens, logits_indices, aux_hidden_states, - num_scheduled_tokens) + num_scheduled_tokens, finished_dumping) - + def _get_cumsum_and_arange( self, @@ -1400,7 +1434,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): @@ -297,7 +355,7 @@ index eabcdbc..e51f46e 100644 - num_scheduled_tokens_np) = (self._process_reqs( + num_scheduled_tokens_np, finished_dumping) = (self._process_reqs( scheduler_output, intermediate_tensors)) - + with ProfileExecuteDuration().capture_async("post process"): @@ -1561,6 +1595,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): logprobs=logprobs_lists, @@ -305,7 +363,7 @@ index eabcdbc..e51f46e 100644 pooler_output=[], + finished_dumping=finished_dumping ) - + durations = ProfileExecuteDuration().pop_captured_sync() @@ -2369,3 +2404,43 @@ class NPUModelRunner(LoRAModelRunnerMixin): if batch_size <= padded_batch_size < selected_batch_size: @@ -353,16 +411,16 @@ index eabcdbc..e51f46e 100644 + ucm_sparse.request_finished_in_worker(request_id) \ No newline at end of file diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py -index df03d50..a854923 100644 +index df03d508..5d5d9b5a 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -17,6 +17,7 @@ # Adapted from vllm-project/vllm/vllm/worker/gpu_worker.py # - + +import copy from typing import Optional - + import torch @@ -27,7 +28,8 @@ from vllm import envs from vllm.config import VllmConfig @@ -381,15 +439,15 @@ index df03d50..a854923 100644 -from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput from vllm.v1.worker.worker_base import WorkerBase - + import vllm_ascend.envs as envs_ascend @@ -49,6 +51,7 @@ from vllm_ascend.utils import (check_kv_cache_bytes_cache_exist, read_kv_cache_bytes_from_file, sleep_mode_enabled, try_register_lib) from vllm_ascend.worker.model_runner_v1 import NPUModelRunner +from ucm.sparse.state import ensure_ucm_sparse_initialized - - + + class NPUWorker(WorkerBase): @@ -222,9 +225,22 @@ class NPUWorker(WorkerBase): assert isinstance(output, IntermediateTensors) @@ -413,7 +471,7 @@ index df03d50..a854923 100644 assert isinstance(output, ModelRunnerOutput) - return output if self.is_driver_worker else None + return output - + def load_model(self) -> None: if self.vllm_config.model_config.enable_sleep_mode: @@ -321,6 +337,7 @@ class NPUWorker(WorkerBase): @@ -421,8 +479,9 @@ index df03d50..a854923 100644 ) ensure_kv_transfer_initialized(self.vllm_config) + ensure_ucm_sparse_initialized(self.vllm_config) - + def _init_profiler(self): # Torch profiler. Enabled and configured through env vars: --- +-- 2.34.1 + diff --git a/ucm/sparse/base.py b/ucm/sparse/base.py index bc152b02..ed62ab30 100644 --- a/ucm/sparse/base.py +++ b/ucm/sparse/base.py @@ -130,6 +130,7 @@ def attention_begin( value: torch.Tensor, layer_name: str, forward_context: ForwardContext, + phase: Optional[str] = None, ) -> None: """ This is called at the beginning of "unified_attention". @@ -146,6 +147,7 @@ def attention_finished( attn_output: torch.Tensor, layer_name: str, forward_context: ForwardContext, + phase: Optional[str] = None, ) -> None: """ This is called at the end of "unified_attention". diff --git a/ucm/sparse/esa/esa.py b/ucm/sparse/esa/esa.py index 5ccaec4a..c7047f87 100644 --- a/ucm/sparse/esa/esa.py +++ b/ucm/sparse/esa/esa.py @@ -290,9 +290,7 @@ def launch_transfer_task(self, transfer_type, block_hashes, vllm_block_ids): self.tasks[task_v_hash] = task_v def extract_block_repre(self, vllm_block_ids): - if not self.is_mla: - return self.k_cache[vllm_block_ids].mean(1) - return self.k_cache[vllm_block_ids].mean(1).unsqueeze(-2) + return self.k_cache[vllm_block_ids].mean(1) def maybe_register_static_data(self, forward_context: ForwardContext): if self.init_static_flag: