Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions ucm/integration/vllm/patch/patch_funcs/v092/vllm_ascend_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,7 @@ def _patch_model_runner_v1() -> None:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.distributed.kv_transfer import (
get_kv_transfer_group,
has_kv_transfer_group,
)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.forward_context import get_forward_context, set_forward_context
Expand Down Expand Up @@ -1044,6 +1045,7 @@ def _process_reqs(
maybe_converting_weight_acl_format(
self.model, ACL_FORMAT_FRACTAL_ND
)
self.maybe_setup_kv_connector(scheduler_output)
self.maybe_execute_ucm_sparse_begin(
scheduler_output, attn_metadata
)
Expand All @@ -1055,6 +1057,7 @@ def _process_reqs(
inputs_embeds=inputs_embeds,
**model_kwargs,
)
self.maybe_wait_for_kv_save()
self.maybe_execute_ucm_sparse_finished()

use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
Expand Down Expand Up @@ -1310,6 +1313,31 @@ def execute_model(

NPUModelRunner.execute_model = execute_model

@staticmethod
def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
# Update KVConnector with the KVConnector metadata forward().
if has_kv_transfer_group():
kv_connector = get_kv_transfer_group()
assert isinstance(kv_connector, KVConnectorBase_V1)
assert scheduler_output.kv_connector_metadata is not None
kv_connector.bind_connector_metadata(
scheduler_output.kv_connector_metadata
)
# Background KV cache transfers happen here.
# These transfers are designed to be async and the requests
# involved may be disjoint from the running requests.
# Do this here to save a collective_rpc.
kv_connector.start_load_kv(get_forward_context())

NPUModelRunner.maybe_setup_kv_connector = maybe_setup_kv_connector

@staticmethod
def maybe_wait_for_kv_save():
if has_kv_transfer_group():
get_kv_transfer_group().wait_for_save()

NPUModelRunner.maybe_wait_for_kv_save = maybe_wait_for_kv_save

def maybe_execute_ucm_sparse_begin(
self,
scheduler_output: "SchedulerOutput",
Expand Down