15 changes: 14 additions & 1 deletion ucm/__init__.py
@@ -1,4 +1,17 @@
from ucm.integration.vllm.uc_connector import UnifiedCacheConnectorV1
from ucm.integration.vllm.ucm_connector import UCMConnector

__all__ = ["UnifiedCacheConnectorV1", "UCMConnector"]
try:
from ucm.integration.vllm.patch.apply_patch import ensure_patches_applied

ensure_patches_applied()
except Exception as e:
# Don't fail if patches can't be applied - might be running in environment without vLLM
import warnings

warnings.warn(
f"Failed to apply vLLM patches: {e}. "
f"If you're using vLLM, ensure it's installed and patches are compatible."
)

__all__ = ["UCMConnector"]
36 changes: 12 additions & 24 deletions ucm/integration/vllm/patch/apply_patch.py
@@ -88,19 +88,19 @@ def apply_all_patches() -> None:
supported_versions = get_supported_versions()
if version not in supported_versions:
logger.warning(
f"vLLM version {version} is not explicitly supported. "
f"vLLM version {version} is not explicitly supported to apply UCM patches. "
f"Supported versions: {', '.join(supported_versions)}. "
f"Attempting to apply 0.9.2 patches..."
)
raise ValueError(f"vLLM version {version} is not explicitly supported")

# Apply version-specific patches
if version == "0.9.1":
_apply_patches_v091()
elif version == "0.9.2":
_apply_patches_v092()
else:
raise ValueError(f"Unsupported vLLM version: {version}")
match version:
case "0.9.2":
_apply_patches_v092()
case _:
logger.warning(
f"Unsupported vLLM version: {version} to apply UCM patches. "
f"Supported versions: {', '.join(supported_versions)}."
)

_patches_applied = True
logger.info(f"All vLLM patches applied successfully for version {version}")
@@ -109,25 +109,13 @@ def apply_all_patches() -> None:
raise


def _apply_patches_v091() -> None:
"""Apply patches for vLLM 0.9.1."""
from .patch_funcs.v091.vllm_adapt import _apply_adapt_patch

_apply_adapt_patch() # apply vllm-adapt-pc.patch
if _patch_ascend():
from .patch_funcs.v091.vllm_ascend_adapt import _apply_ascend_patch

_apply_ascend_patch() # apply vllm-ascend-adapt.patch


def _apply_patches_v092() -> None:
"""Apply patches for vLLM 0.9.2."""
from .patch_funcs.v092.vllm_adapt import _apply_adapt_patches

_apply_adapt_patches()
from .patch_funcs.v092.vllm_patch import _apply_sparse_adapt

_apply_sparse_adapt() # apply vllm-sparse-adapt.patch
if _patch_ascend():
from .patch_funcs.v092.vllm_ascend_adapt import _apply_ascend_patch
from .patch_funcs.v092.vllm_ascend_patch import _apply_ascend_patch

_apply_ascend_patch() # apply vllm-ascend-adapt.patch

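Note (not part of the diff): the apply_patch.py hunks above replace the hard ValueError on unknown versions with a match/case dispatch that only warns. A minimal sketch of that dispatch shape; _detect_vllm_version and _dispatch_patches are hypothetical stand-ins, not the module's real helpers:

    from importlib.metadata import PackageNotFoundError, version as pkg_version

    def _detect_vllm_version() -> str | None:
        # Hypothetical helper: the real apply_patch module resolves the version itself.
        try:
            return pkg_version("vllm")
        except PackageNotFoundError:
            return None

    def _dispatch_patches(version: str, supported_versions: list[str]) -> None:
        # Mirrors the new match/case: patch only known versions, warn otherwise.
        match version:
            case "0.9.2":
                print("would call _apply_patches_v092()")
            case _:
                print(
                    f"Unsupported vLLM version: {version} to apply UCM patches. "
                    f"Supported versions: {', '.join(supported_versions)}."
                )

    _dispatch_patches(_detect_vllm_version() or "unknown", ["0.9.2"])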
28 changes: 0 additions & 28 deletions ucm/integration/vllm/patch/patch_funcs/v091/vllm_adapt.py

This file was deleted.

This file was deleted.

@@ -44,44 +44,11 @@ def _patch_attention_v1() -> None:
from typing import List

import torch
from vllm.distributed.kv_transfer import (
get_kv_transfer_group,
has_kv_transfer_group,
is_v1_kv_transfer_group,
)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm_ascend.attention import attention_v1

from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse

def wait_for_kv_layer_from_connector(layer_name: str):
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
return

connector = get_kv_transfer_group()
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
return
connector.wait_for_layer_load(layer_name)

attention_v1.wait_for_kv_layer_from_connector = wait_for_kv_layer_from_connector

def maybe_save_kv_layer_to_connector(
layer_name: str,
kv_cache_layer: List[torch.Tensor],
):
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
return
connector = get_kv_transfer_group()
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
return
connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata)

attention_v1.maybe_save_kv_layer_to_connector = maybe_save_kv_layer_to_connector

def maybe_execute_sparse_attention_begin(
query: torch.Tensor,
key: torch.Tensor,
@@ -142,7 +109,6 @@ def unified_ascend_attention_with_output_impl(
output: torch.Tensor,
layer_name: str,
) -> None:
wait_for_kv_layer_from_connector(layer_name)

forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
@@ -166,7 +132,6 @@ def unified_ascend_attention_with_output_impl(
maybe_execute_sparse_attention_finished(
query, key, value, output, layer_name, forward_context
)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
return

vllm_ops.unified_ascend_attention_with_output = _wrap_op_overload(
@@ -198,8 +163,6 @@ def _patch_mla_v1() -> None:
from vllm.forward_context import ForwardContext, get_forward_context
from vllm_ascend.attention.attention_v1 import (
AscendAttentionState,
maybe_save_kv_layer_to_connector,
wait_for_kv_layer_from_connector,
)
from vllm_ascend.attention.mla_v1 import AscendMLAImpl
from vllm_ascend.multistream.context import get_multistream_comm_context
@@ -399,7 +362,6 @@ def forward(
# FIX: aicore move should be also placed on the comm stream in dbo,
# otherwise it may affect the accuracy
# TODO: use an elegant way to overlap
wait_for_kv_layer_from_connector(layer.layer_name)
maybe_execute_sparse_attention_begin(
prefill_q,
prefill_k_c_normed,
@@ -427,9 +389,7 @@ def forward(
forward_context,
"prefill",
)
maybe_save_kv_layer_to_connector(layer.layer_name, kv_cache)
if has_decode:
wait_for_kv_layer_from_connector(layer.layer_name)
maybe_execute_sparse_attention_begin(
torch.cat([decode_ql_nope, decode_q_pe], dim=-1),
decode_ql_nope,
@@ -473,7 +433,6 @@ def forward(
forward_context,
"decode",
)
maybe_save_kv_layer_to_connector(layer.layer_name, kv_cache)

return output_padded

@@ -523,7 +482,6 @@ def _patch_model_runner_v1() -> None:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.distributed.kv_transfer import (
get_kv_transfer_group,
has_kv_transfer_group,
)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.forward_context import get_forward_context, set_forward_context
@@ -1034,7 +992,6 @@ def _process_reqs(
positions = self.positions[:padded_batch_size]

# Run forward pass
finished_dumping = None
with set_forward_context(
attn_metadata, self.vllm_config, num_tokens=num_input_tokens
):
@@ -1063,7 +1020,6 @@ def _process_reqs(
maybe_converting_weight_acl_format(
self.model, ACL_FORMAT_FRACTAL_ND
)
self.maybe_setup_kv_connector(scheduler_output)
self.maybe_execute_ucm_sparse_begin(
scheduler_output, attn_metadata
)
@@ -1075,7 +1031,6 @@ def _process_reqs(
inputs_embeds=inputs_embeds,
**model_kwargs,
)
finished_dumping = self.maybe_wait_for_kv_save()
self.maybe_execute_ucm_sparse_finished()

use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
@@ -1116,7 +1071,6 @@ def _process_reqs(
logits_indices,
aux_hidden_states,
num_scheduled_tokens,
finished_dumping,
)

NPUModelRunner._process_reqs = _process_reqs
@@ -1141,7 +1095,6 @@ def execute_model(
logits_indices,
aux_hidden_states,
num_scheduled_tokens_np,
finished_dumping,
) = self._process_reqs(scheduler_output, intermediate_tensors)

with ProfileExecuteDuration().capture_async("post process"):
@@ -1313,7 +1266,6 @@ def execute_model(
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
finished_dumping=finished_dumping,
)

durations = ProfileExecuteDuration().pop_captured_sync()
@@ -1334,27 +1286,6 @@ def execute_model(

NPUModelRunner.execute_model = execute_model

@staticmethod
def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
# Update KVConnector with the KVConnector metadata forward().
if has_kv_transfer_group():
kv_connector = get_kv_transfer_group()
assert isinstance(kv_connector, KVConnectorBase_V1)
assert scheduler_output.kv_connector_metadata is not None
kv_connector.bind_connector_metadata(
scheduler_output.kv_connector_metadata
)
# Background KV cache transfers happen here.
# These transfers are designed to be async and the requests
# involved may be disjoint from the running requests.
# Do this here to save a collective_rpc.
kv_connector.start_load_kv(get_forward_context())

@staticmethod
def maybe_wait_for_kv_save():
if has_kv_transfer_group():
return get_kv_transfer_group().wait_for_save()

def maybe_execute_ucm_sparse_begin(
self,
scheduler_output: "SchedulerOutput",
@@ -1380,8 +1311,6 @@ def ucm_sparse_request_finished_in_worker(self, request_id: str | int):
ucm_sparse = get_ucm_sparse()
ucm_sparse.request_finished_in_worker(request_id)

NPUModelRunner.maybe_setup_kv_connector = maybe_setup_kv_connector
NPUModelRunner.maybe_wait_for_kv_save = maybe_wait_for_kv_save
NPUModelRunner.maybe_execute_ucm_sparse_begin = maybe_execute_ucm_sparse_begin
NPUModelRunner.maybe_execute_ucm_sparse_finished = (
maybe_execute_ucm_sparse_finished
@@ -1401,9 +1330,6 @@ def _patch_worker_v1() -> None:
import copy
from typing import Optional

from vllm.distributed.kv_transfer import (
has_kv_transfer_group,
)
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import logger
from vllm.sequence import IntermediateTensors
@@ -1435,8 +1361,6 @@ def execute_model(
get_pp_group().send_tensor_dict(
output.tensors, all_gather_group=get_tp_group()
)
if not has_kv_transfer_group():
return None

kv_connector_output = output.kv_connector_output
finished_sending = kv_connector_output.finished_sending
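Note (not part of the diff): the patch functions above all rely on the same rebinding idiom, e.g. NPUModelRunner._process_reqs = _process_reqs. A generic sketch of that idiom with a stand-in Runner class (hypothetical names, not the real NPUModelRunner):

    class Runner:  # stand-in for e.g. vllm_ascend's NPUModelRunner
        def execute_model(self, scheduler_output):
            return f"original({scheduler_output})"

    # Keep a reference to the original so the replacement can delegate to it.
    _original_execute_model = Runner.execute_model

    def execute_model(self, scheduler_output):
        # pre-hook: e.g. maybe_execute_ucm_sparse_begin(...) would run here
        result = _original_execute_model(self, scheduler_output)
        # post-hook: e.g. maybe_execute_ucm_sparse_finished() would run here
        return result

    # Rebind on the class so every instance picks up the wrapped behaviour.
    Runner.execute_model = execute_model

    print(Runner().execute_model("batch-0"))  # -> original(batch-0)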