diff --git a/ucm/integration/vllm/patch/patch_funcs/v092/vllm_ascend_adapt.py b/ucm/integration/vllm/patch/patch_funcs/v092/vllm_ascend_adapt.py index 1350775d..b16e65a0 100644 --- a/ucm/integration/vllm/patch/patch_funcs/v092/vllm_ascend_adapt.py +++ b/ucm/integration/vllm/patch/patch_funcs/v092/vllm_ascend_adapt.py @@ -117,7 +117,25 @@ def maybe_execute_sparse_attention_finished( maybe_execute_sparse_attention_finished ) - def unified_ascend_attention_with_output( + vllm_ops = torch.ops.vllm + orig_unified_ascend_attention_with_output = ( + vllm_ops.unified_ascend_attention_with_output + ) + + def _wrap_op_overload(orig, impl): + class _Wrapper: + def __init__(self, orig): + self._orig = orig + + def __call__(self, *args, **kwargs): + return impl(*args, **kwargs) + + def __getattr__(self, name): + return getattr(self._orig, name) + + return _Wrapper(orig) + + def unified_ascend_attention_with_output_impl( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, @@ -151,8 +169,13 @@ def unified_ascend_attention_with_output( maybe_save_kv_layer_to_connector(layer_name, kv_cache) return + vllm_ops.unified_ascend_attention_with_output = _wrap_op_overload( + orig_unified_ascend_attention_with_output, + unified_ascend_attention_with_output_impl, + ) + attention_v1.unified_ascend_attention_with_output = ( - unified_ascend_attention_with_output + unified_ascend_attention_with_output_impl ) except ImportError as e: logger.error(f"Failed to patch attention_v1.py: {e}", exc_info=True)