234 changes: 194 additions & 40 deletions ucm/integration/vllm/patch/0.9.2/vllm-adapt-sparse.patch
@@ -1,70 +1,84 @@
From 8c02671e05ed23d7a0c9dc112f8474b26d579f99 Mon Sep 17 00:00:00 2001
From: harrisonyhq <harrisonyhq@gmail.com>
Date: Wed, 5 Nov 2025 00:22:36 -0800
Subject: [PATCH 3/3] [Patch2] UCM patch for sparsed attention
From 0431022b90649f7115b89b61aaf2a0f83e896d5a Mon Sep 17 00:00:00 2001
From: wenxinwang <wangwenxin21@huawei.com>
Date: Mon, 10 Nov 2025 20:35:47 +0800
Subject: [PATCH] adapt to deepseek patch

---
vllm/attention/layer.py | 43 ++++++++++++++++++
vllm/v1/core/kv_cache_manager.py | 7 ++-
vllm/v1/core/sched/output.py | 3 ++
vllm/v1/core/sched/scheduler.py | 26 ++++++++++-
vllm/v1/worker/block_table.py | 13 ++++++
vllm/v1/worker/gpu_model_runner.py | 70 +++++++++++++++++++++++++-----
vllm/v1/worker/gpu_worker.py | 2 +
7 files changed, 151 insertions(+), 13 deletions(-)
vllm/attention/layer.py | 49 ++++++++++++-
.../kv_transfer/kv_connector/utils.py | 5 ++
.../v1/shared_storage_connector.py | 7 +-
vllm/v1/attention/backends/mla/common.py | 10 ++-
vllm/v1/core/kv_cache_manager.py | 7 +-
vllm/v1/core/sched/output.py | 3 +
vllm/v1/core/sched/scheduler.py | 37 +++++++---
vllm/v1/worker/block_table.py | 13 ++++
vllm/v1/worker/gpu_model_runner.py | 71 +++++++++++++++----
vllm/v1/worker/gpu_worker.py | 2 +
10 files changed, 171 insertions(+), 33 deletions(-)

diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index f0ad68b16..d55f3d689 100644
index f0ad68b16..728ab99fd 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -22,6 +22,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer."""
from typing import Any, Dict, List, Optional
-
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -22,6 +21,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.platforms import _Backend, current_platform
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.utils import validate_kv_sharing_target
+from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse


class Attention(nn.Module):
@@ -409,9 +410,11 @@ def unified_attention(
@@ -409,9 +409,10 @@ def unified_attention(
attn_metadata = attn_metadata[layer_name]
self = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
+ maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context)
output = self.impl.forward(self, query, key, value, kv_cache,
attn_metadata)
-
+ maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
return output

@@ -449,6 +452,7 @@ def unified_attention_with_output(
@@ -449,6 +450,8 @@ def unified_attention_with_output(
attn_metadata = attn_metadata[layer_name]
self = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
+ maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context)
+ if not self.use_mla:
+ maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context)
self.impl.forward(self,
query,
key,
@@ -458,6 +462,7 @@ def unified_attention_with_output(
@@ -457,7 +460,8 @@ def unified_attention_with_output(
attn_metadata,
output=output,
output_scale=output_scale)

+ maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context)
-
+ if not self.use_mla:
+ maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)


@@ -479,3 +484,41 @@ direct_register_custom_op(
@@ -479,3 +483,42 @@ direct_register_custom_op(
fake_impl=unified_attention_with_output_fake,
dispatch_key=current_platform.dispatch_key,
)
+
+
+def maybe_execute_sparse_attention_begin(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ layer_name: str,
+ forward_context: ForwardContext,
+ phase: Optional[str] = None,
+):
+ if not has_ucm_sparse():
+ return
@@ -75,7 +89,7 @@ index f0ad68b16..d55f3d689 100644
+ if attn_metadata is None:
+ return
+
+ ucm_sparse.attention_begin(query, key, value, layer_name, forward_context)
+ ucm_sparse.attention_begin(query, key, value, layer_name, forward_context, phase)
+
+def maybe_execute_sparse_attention_finished(
+ query: torch.Tensor,
@@ -84,6 +98,7 @@ index f0ad68b16..d55f3d689 100644
+ attn_output: torch.Tensor,
+ layer_name: str,
+ forward_context: ForwardContext,
+ phase: Optional[str] = None,
+):
+ if not has_ucm_sparse():
+ return
@@ -94,8 +109,101 @@ index f0ad68b16..d55f3d689 100644
+ if attn_metadata is None:
+ return
+
+ ucm_sparse.attention_finished(query, key, value, attn_output, layer_name, forward_context)
\ No newline at end of file
+ ucm_sparse.attention_finished(query, key, value, attn_output, layer_name, forward_context, phase)
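> Note (not part of the patch): the helpers added to vllm/attention/layer.py bracket every non-MLA attention call with UCM sparse hooks, gated on has_ucm_sparse() and skipped when attn_metadata is absent. A minimal sketch of the intended call pattern is below; `attn_impl` and `run_attention` are hypothetical stand-ins for the layer's backend impl and its caller, and the hook signatures mirror the ones introduced above.

```python
from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse


def run_attention(attn_impl, query, key, value, kv_cache, attn_metadata,
                  layer_name, forward_context, phase=None):
    if has_ucm_sparse():
        # Give the sparse algorithm a chance to select/remap KV slots first.
        get_ucm_sparse().attention_begin(query, key, value, layer_name,
                                         forward_context, phase)
    output = attn_impl.forward(query, key, value, kv_cache, attn_metadata)
    if has_ucm_sparse():
        # Let the sparse algorithm inspect or post-process the attention output.
        get_ucm_sparse().attention_finished(query, key, value, output,
                                            layer_name, forward_context, phase)
    return output
```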
diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py
index b63bf5965..155597c51 100644
--- a/vllm/distributed/kv_transfer/kv_connector/utils.py
+++ b/vllm/distributed/kv_transfer/kv_connector/utils.py
@@ -3,6 +3,11 @@
"""
KV cache helper for store.
"""
+from collections import defaultdict
+from collections.abc import Sequence
+from concurrent.futures import CancelledError, Future
+from typing import Optional, cast
+
import torch

from collections import defaultdict
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
index 3c574d065..223106def 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
import os
-from dataclasses import dataclass
+from dataclasses import dataclass, field
from typing import TYPE_CHECKING

import safetensors
@@ -53,10 +53,7 @@ class ReqMeta:

@dataclass
class SharedStorageConnectorMetadata(KVConnectorMetadata):
- requests: list[ReqMeta]
-
- def __init__(self):
- self.requests = []
+ requests: list[ReqMeta] = field(default_factory=list)

def add_request(
self,
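> Note (not part of the patch): the SharedStorageConnectorMetadata change above swaps a hand-written `__init__` for `field(default_factory=list)`, the idiomatic way to give a dataclass a per-instance mutable default. A tiny illustration, with a hypothetical class name:

```python
from dataclasses import dataclass, field


@dataclass
class ConnectorMetadataSketch:  # illustrative name only
    requests: list = field(default_factory=list)  # fresh list per instance


a = ConnectorMetadataSketch()
b = ConnectorMetadataSketch()
a.requests.append("req-0")
assert b.requests == []  # no accidental sharing between instances
```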
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index f2aaf59a4..b56f62b39 100644
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -200,6 +200,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
MLAAttentionImpl)
from vllm.attention.backends.utils import get_mla_dims
from vllm.attention.ops.merge_attn_states import merge_attn_states
+from vllm.forward_context import ForwardContext, get_forward_context
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -211,6 +212,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
+from vllm.attention.layer import (maybe_execute_sparse_attention_begin, maybe_execute_sparse_attention_finished)

try:
from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -908,7 +910,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
-
+ forward_context: ForwardContext = get_forward_context()
assert output is not None, "Output tensor must be provided."

if output_scale is not None:
@@ -957,10 +959,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
)

if has_prefill:
+ maybe_execute_sparse_attention_begin(prefill_q, prefill_k_c_normed, prefill_k_pe, layer.layer_name, forward_context, "prefill")
output[num_decode_tokens:] = self._forward_prefill(
prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
attn_metadata)
-
+ maybe_execute_sparse_attention_finished(prefill_q, prefill_k_c_normed, prefill_k_pe, output[num_decode_tokens:], layer.layer_name, forward_context, "prefill")
if has_decode:
assert attn_metadata.decode is not None
decode_q_nope, decode_q_pe = decode_q.split(
@@ -971,8 +974,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
decode_ql_nope = decode_ql_nope.transpose(0, 1)
-
+ maybe_execute_sparse_attention_begin(torch.cat([decode_ql_nope, decode_q_pe],dim=-1), decode_ql_nope, decode_q_pe, layer.layer_name, forward_context, "decode")
output[:num_decode_tokens] = self._forward_decode(
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
+ maybe_execute_sparse_attention_finished(torch.cat([decode_ql_nope, decode_q_pe],dim=-1), decode_ql_nope, decode_q_pe, output[:num_decode_tokens], layer.layer_name, forward_context, "decode")

return output_padded
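> Note (not part of the patch): the MLA change above brackets the prefill and decode halves of a mixed batch separately, tagging each hook call with a "prefill" or "decode" phase. A sketch of that pattern, assuming the patched vllm.attention.layer helpers; `forward_prefill` / `forward_decode` stand in for MLACommonImpl._forward_prefill / _forward_decode:

```python
import torch
from vllm.attention.layer import (maybe_execute_sparse_attention_begin,
                                  maybe_execute_sparse_attention_finished)


def bracket_mla_phases(prefill_q, prefill_k_c_normed, prefill_k_pe,
                       decode_ql_nope, decode_q_pe, output, num_decode_tokens,
                       layer_name, forward_context,
                       forward_prefill, forward_decode):
    # Prefill tokens occupy output[num_decode_tokens:].
    maybe_execute_sparse_attention_begin(
        prefill_q, prefill_k_c_normed, prefill_k_pe,
        layer_name, forward_context, "prefill")
    output[num_decode_tokens:] = forward_prefill()
    maybe_execute_sparse_attention_finished(
        prefill_q, prefill_k_c_normed, prefill_k_pe,
        output[num_decode_tokens:], layer_name, forward_context, "prefill")

    # Decode tokens occupy output[:num_decode_tokens]; the query handed to the
    # hooks is the concatenated (nope | rope) latent query, as in the patch.
    decode_q = torch.cat([decode_ql_nope, decode_q_pe], dim=-1)
    maybe_execute_sparse_attention_begin(
        decode_q, decode_ql_nope, decode_q_pe,
        layer_name, forward_context, "decode")
    output[:num_decode_tokens] = forward_decode()
    maybe_execute_sparse_attention_finished(
        decode_q, decode_ql_nope, decode_q_pe,
        output[:num_decode_tokens], layer_name, forward_context, "decode")
    return output
```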
diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py
index 6937455e7..bf9aec864 100644
--- a/vllm/v1/core/kv_cache_manager.py
@@ -136,7 +244,7 @@ index 6937455e7..bf9aec864 100644
if new_computed_blocks is not None:
new_computed_block_list = new_computed_blocks.blocks
diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py
index c94e421c0..f6f170e10 100644
index c94e421c0..fff0eeb42 100644
--- a/vllm/v1/core/sched/output.py
+++ b/vllm/v1/core/sched/output.py
@@ -157,3 +157,6 @@ class SchedulerOutput:
@@ -146,9 +254,8 @@ index c94e421c0..f6f170e10 100644
+
+ # modified slots by sparse algorithm
+ req_sparsed_slots: dict[str, int] = None
\ No newline at end of file
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 2d4fd4d59..8268c1409 100644
index 2d4fd4d59..e99a51788 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -35,6 +35,8 @@ from vllm.v1.request import Request, RequestStatus
@@ -230,7 +337,42 @@ index 2d4fd4d59..8268c1409 100644
# finished_req_ids is an existing state in the scheduler,
# instead of being newly scheduled in this step.
# It contains the request IDs that are finished in between
@@ -955,6 +975,8 @@ class Scheduler(SchedulerInterface):
@@ -809,16 +829,12 @@ class Scheduler(SchedulerInterface):
new_logprobs = None
new_token_ids = generated_token_ids
kv_transfer_params = None
+
if model_runner_output.finished_dumping is not None:
request.succeed_dumped_blocks.extend(model_runner_output.finished_dumping.get(req_id, []))
is_prefill = request.num_output_tokens == 0
if is_prefill:
- if isinstance(self.connector, MultiConnector):
- for c in self.connector._connectors:
- if hasattr(c, 'connector') and hasattr(c.connector, 'commit'):
- c.connector.commit(model_runner_output.finished_dumping.get(req_id, []), True)
- else:
- self.connector.connector.commit(model_runner_output.finished_dumping.get(req_id, []), True)
+ self.connector.connector.commit(model_runner_output.finished_dumping.get(req_id, []), True)

# Append generated tokens and check for stop. Note that if
# a request is still being prefilled, we expect the model runner
@@ -870,7 +886,6 @@ class Scheduler(SchedulerInterface):
spec_token_ids[req_index])
else:
request.spec_token_ids = spec_token_ids[req_index]
-
# Get prompt logprobs for this request.
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
if new_token_ids or pooler_output is not None \
@@ -897,6 +912,7 @@ class Scheduler(SchedulerInterface):

if not stopped:
new_running.append(request)
+
self.running = new_running

# KV Connector: update state for finished KV Transfers.
@@ -955,6 +971,8 @@ class Scheduler(SchedulerInterface):
def add_request(self, request: Request) -> None:
self.waiting.add_request(request)
self.requests[request.request_id] = request
@@ -239,7 +381,7 @@ index 2d4fd4d59..8268c1409 100644
if self.log_stats:
request.record_event(EngineCoreEventType.QUEUED)

@@ -1004,6 +1026,8 @@ class Scheduler(SchedulerInterface):
@@ -1004,6 +1022,8 @@ class Scheduler(SchedulerInterface):

def _free_request(self, request: Request) -> Optional[dict[str, Any]]:
assert request.is_finished()
@@ -248,6 +390,14 @@ index 2d4fd4d59..8268c1409 100644

delay_free_blocks, kv_xfer_params = self._connector_finished(request)
self.encoder_cache_manager.free(request)
@@ -1155,7 +1175,6 @@ class Scheduler(SchedulerInterface):
logger.debug("Finished sending KV transfer for request %s", req_id)
self._free_blocks(self.requests[req_id])

-
def _update_requests_with_invalid_blocks(
self, requests: Iterable[Request],
invalid_block_ids: set[int]) -> tuple[set[str], int]:
diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py
index 8f4e8d64c..f45e39f5c 100644
--- a/vllm/v1/worker/block_table.py
@@ -280,7 +430,7 @@ index 8f4e8d64c..f45e39f5c 100644
for i, block_table in enumerate(self.block_tables):
block_table.add_row(block_ids[i], row_idx)
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index c3df1d5d2..6341efc70 100644
index c3df1d5d2..dbf1ea7d7 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -72,6 +72,9 @@ from ..sample.logits_processor import LogitsProcessorManager
@@ -347,7 +497,7 @@ index c3df1d5d2..6341efc70 100644
self.input_batch.block_table.append_row(new_block_ids, req_index)

# For the last rank, we don't need to update the token_ids_cpu
@@ -639,6 +647,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
@@ -639,6 +647,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.uses_mrope:
self._calc_mrope_positions(scheduler_output)

@@ -364,11 +514,10 @@ index c3df1d5d2..6341efc70 100644
+ offset = 0 if req_index == 0 else cu_num_tokens[req_index - 1] # TODO: support MTP
+ if is_sparsed_request:
+ sparsed_positions[offset] = req_sparsed_slots[req_id] - 1
+
# Get token indices.
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
@@ -668,11 +690,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
@@ -668,11 +689,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# block_size.
block_table_indices = (
req_indices * block_table.max_num_blocks_per_req +
@@ -382,7 +531,7 @@ index c3df1d5d2..6341efc70 100644
np.add(
block_numbers * block_size,
block_offsets,
@@ -682,9 +704,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
@@ -682,9 +703,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.query_start_loc_np[0] = 0
self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens

@@ -397,7 +546,7 @@ index c3df1d5d2..6341efc70 100644

# Copy the tensors to the GPU.
self.input_ids[:total_num_scheduled_tokens].copy_(
@@ -696,6 +720,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
@@ -696,6 +719,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
non_blocking=True)
else:
# Common case (1D positions)
@@ -406,15 +555,15 @@ index c3df1d5d2..6341efc70 100644
self.positions[:total_num_scheduled_tokens].copy_(
self.positions_cpu[:total_num_scheduled_tokens],
non_blocking=True)
@@ -1386,6 +1412,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
@@ -1386,6 +1411,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
skip_cuda_graphs=skip_cuda_graphs,
):
self.maybe_setup_kv_connector(scheduler_output)
+ self.maybe_execute_ucm_sparse_begin(scheduler_output, attn_metadata)

model_output = self.model(
input_ids=input_ids,
@@ -1395,6 +1422,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
@@ -1395,6 +1421,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)

finished_dumping = self.maybe_wait_for_kv_save()
@@ -423,7 +572,12 @@ index c3df1d5d2..6341efc70 100644
finished_sending, finished_recving = (
self.get_finished_kv_transfers(scheduler_output))
invalid_block_ids = self.get_block_ids_with_load_errors()
@@ -1745,6 +1774,25 @@ class GPUModelRunner(LoRAModelRunnerMixin):
@@ -1741,10 +1769,29 @@ class GPUModelRunner(LoRAModelRunnerMixin):
kv_connector.start_load_kv(get_forward_context())

@staticmethod
- def maybe_wait_for_kv_save():
+ def maybe_wait_for_kv_save() -> Optional[dict[str, list[str]]]:
if has_kv_transfer_group():
return get_kv_transfer_group().wait_for_save()
