
Commit 405eb8e

Authored by wangxiyuan and Mengqing Cao · Jan 9, 2025

[platform] Allow platform specify attention backend (#11609)

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: Mengqing Cao <cmq0113@163.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>

1 parent 65097ca · commit 405eb8e

File tree: 10 files changed, +164 -175 lines changed
 
+42 -32

@@ -1,17 +1,24 @@
-from unittest.mock import patch
+from unittest.mock import Mock, patch
 
 import pytest
 import torch
 
 from tests.kernels.utils import override_backend_env_variable
-from vllm.attention.selector import which_attn_to_use
+from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
 from vllm.platforms.cpu import CpuPlatform
 from vllm.platforms.cuda import CudaPlatform
 from vllm.platforms.openvino import OpenVinoPlatform
 from vllm.platforms.rocm import RocmPlatform
 from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL
 
 
+@pytest.fixture(autouse=True)
+def clear_cache():
+    """Clear lru cache to ensure each test case runs without caching.
+    """
+    _cached_get_attn_backend.cache_clear()
+
+
 @pytest.mark.parametrize(
     "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"])
 @pytest.mark.parametrize("device", ["cpu", "openvino", "hip", "cuda"])
@@ -24,67 +31,70 @@ def test_env(name: str, device: str, monkeypatch):
 
     if device == "cpu":
         with patch("vllm.attention.selector.current_platform", CpuPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-            assert backend.name == "TORCH_SDPA"
+            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
+                                       False)
+            assert backend.get_name() == "TORCH_SDPA"
     elif device == "hip":
         with patch("vllm.attention.selector.current_platform", RocmPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-            assert backend.name == "ROCM_FLASH"
+            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
+                                       False)
+            assert backend.get_name() == "ROCM_FLASH"
     elif device == "openvino":
         with patch("vllm.attention.selector.current_platform",
-                   OpenVinoPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-            assert backend.name == "OPENVINO"
+                   OpenVinoPlatform()), patch.dict('sys.modules',
+                                                   {'openvino': Mock()}):
+            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
+                                       False)
+            assert backend.get_name() == "OPENVINO"
     else:
-        with patch("vllm.attention.selector.current_platform", CudaPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-            assert backend.name == name
+        if name in ["XFORMERS", "FLASHINFER"]:
+            with patch("vllm.attention.selector.current_platform",
+                       CudaPlatform()):
+                backend = get_attn_backend(16, torch.float16, torch.float16,
+                                           16, False)
+                assert backend.get_name() == name
 
 
 def test_flash_attn(monkeypatch):
     """Test FlashAttn validation."""
     # TODO: When testing for v1, pipe in `use_v1` as an argument to
-    # which_attn_to_use
+    # get_attn_backend
 
     override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL)
 
     # Unsupported CUDA arch
     with patch("torch.cuda.get_device_capability", return_value=(7, 5)):
-        backend = which_attn_to_use(16, torch.float16, None, 16, False)
-        assert backend.name != STR_FLASH_ATTN_VAL
+        backend = get_attn_backend(16, torch.float16, None, 16, False)
+        assert backend.get_name() != STR_FLASH_ATTN_VAL
 
     # Unsupported data type
-    backend = which_attn_to_use(16, torch.float8_e4m3fn, None, 16, False)
-    assert backend.name != STR_FLASH_ATTN_VAL
+    backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16, False)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL
 
     # Unsupported kv cache data type
-    backend = which_attn_to_use(16, torch.float16, "fp8", 16, False)
-    assert backend.name != STR_FLASH_ATTN_VAL
+    backend = get_attn_backend(16, torch.float16, "fp8", 16, False)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL
 
     # Unsupported block size
-    backend = which_attn_to_use(16, torch.float16, None, 8, False)
-    assert backend.name != STR_FLASH_ATTN_VAL
+    backend = get_attn_backend(16, torch.float16, None, 8, False)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL
 
     # flash-attn is not installed
     with patch.dict('sys.modules', {'vllm_flash_attn': None}):
-        backend = which_attn_to_use(16, torch.float16, None, 16, False)
-        assert backend.name != STR_FLASH_ATTN_VAL
+        backend = get_attn_backend(16, torch.float16, None, 16, False)
+        assert backend.get_name() != STR_FLASH_ATTN_VAL
 
     # Unsupported head size
-    backend = which_attn_to_use(17, torch.float16, None, 16, False)
-    assert backend.name != STR_FLASH_ATTN_VAL
+    backend = get_attn_backend(17, torch.float16, None, 16, False)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL
 
     # Attention-free models should bypass env and use PlaceholderAttention
-    backend = which_attn_to_use(16, torch.float16, torch.float16, 16, True)
-    assert backend.name != STR_FLASH_ATTN_VAL
+    backend = get_attn_backend(16, torch.float16, torch.float16, 16, True)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL
 
 
 def test_invalid_env(monkeypatch):
     """Throw an exception if the backend name is invalid."""
     override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
     with pytest.raises(ValueError):
-        which_attn_to_use(16, torch.float16, None, 16, False)
+        get_attn_backend(16, torch.float16, None, 16, False)
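
A side note on the new autouse fixture: _cached_get_attn_backend is a memoized selector (its cache_clear() call suggests functools.lru_cache), so without clearing it the backend chosen under the first parametrized case would be served from cache in every later case. A minimal standalone sketch of that interaction, using a hypothetical cached_select stand-in rather than the real vLLM function:

from functools import lru_cache

@lru_cache(maxsize=None)
def cached_select(device: str) -> str:
    # Stand-in for _cached_get_attn_backend; the real signature differs.
    return f"backend-for-{device}"

cached_select("cpu")         # computed once
cached_select("cpu")         # served from the cache
cached_select.cache_clear()  # what the clear_cache fixture does before each test
cached_select("cpu")         # recomputed after clearing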

‎vllm/attention/selector.py

+12 -127

@@ -9,7 +9,7 @@
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.logger import init_logger
 from vllm.platforms import _Backend, current_platform
-from vllm.utils import STR_BACKEND_ENV_VAR
+from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname
 
 logger = init_logger(__name__)
 
@@ -114,83 +114,19 @@ def _cached_get_attn_backend(
             BlocksparseFlashAttentionBackend)
         return BlocksparseFlashAttentionBackend
 
-    backend = which_attn_to_use(head_size, dtype, kv_cache_dtype, block_size,
-                                is_attention_free, use_v1)
-    if backend == _Backend.FLASH_ATTN:
-        logger.info("Using Flash Attention backend.")
-        from vllm.attention.backends.flash_attn import (  # noqa: F401
-            FlashAttentionBackend)
-        return FlashAttentionBackend
-    if backend == _Backend.FLASH_ATTN_VLLM_V1:
-        from vllm.v1.attention.backends.flash_attn import (  # noqa: F401
-            FlashAttentionBackend as FlashAttentionBackendV1)
-        return FlashAttentionBackendV1
-    if backend == _Backend.XFORMERS:
-        logger.info("Using XFormers backend.")
-        from vllm.attention.backends.xformers import (  # noqa: F401
-            XFormersBackend)
-        return XFormersBackend
-    elif backend == _Backend.ROCM_FLASH:
-        logger.info("Using ROCmFlashAttention backend.")
-        from vllm.attention.backends.rocm_flash_attn import (  # noqa: F401
-            ROCmFlashAttentionBackend)
-        return ROCmFlashAttentionBackend
-    elif backend == _Backend.TORCH_SDPA:
-        assert current_platform.is_cpu(), RuntimeError(
-            "Torch SDPA backend is only used for the CPU device.")
-        logger.info("Using Torch SDPA backend.")
-        from vllm.attention.backends.torch_sdpa import TorchSDPABackend
-        return TorchSDPABackend
-    elif backend == _Backend.OPENVINO:
-        logger.info("Using OpenVINO Attention backend.")
-        from vllm.attention.backends.openvino import OpenVINOAttentionBackend
-        return OpenVINOAttentionBackend
-    elif backend == _Backend.IPEX:
-        assert current_platform.is_xpu(), RuntimeError(
-            "IPEX attention backend is only used for the XPU device.")
-        logger.info("Using IPEX attention backend.")
-        from vllm.attention.backends.ipex_attn import IpexAttnBackend
-        return IpexAttnBackend
-    elif backend == _Backend.FLASHINFER:
-        logger.info("Using Flashinfer backend.")
-        from vllm.attention.backends.flashinfer import FlashInferBackend
-        return FlashInferBackend
-    elif backend == _Backend.HPU_ATTN:
-        logger.info("Using HPUAttention backend.")
-        from vllm.attention.backends.hpu_attn import HPUAttentionBackend
-        return HPUAttentionBackend
-    elif backend == _Backend.PALLAS:
-        logger.info("Using Pallas backend.")
-        from vllm.attention.backends.pallas import PallasAttentionBackend
-        return PallasAttentionBackend
-    elif backend == _Backend.NO_ATTENTION:
-        from vllm.attention.backends.placeholder_attn import (
-            PlaceholderAttentionBackend)
-        return PlaceholderAttentionBackend
-    else:
-        raise ValueError("Invalid attention backend.")
-
-
-def which_attn_to_use(head_size: int,
-                      dtype: torch.dtype,
-                      kv_cache_dtype: Optional[str],
-                      block_size: int,
-                      is_attention_free: bool,
-                      use_v1: bool = False) -> _Backend:
-    """Returns which flash attention backend to use."""
-    # Default case.
-    selected_backend = _Backend.FLASH_ATTN
-
     # If there are no attention layers (e.g. we are running Mamba),
     # use the placeholder NO_ATTENTION
     if is_attention_free:
-        return _Backend.NO_ATTENTION
+        from vllm.attention.backends.placeholder_attn import (
+            PlaceholderAttentionBackend)
+        return PlaceholderAttentionBackend
 
     # Check whether a particular choice of backend was
     # previously forced.
     #
     # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
     # ENVIRONMENT VARIABLE.
+    selected_backend = None
     backend_by_global_setting: Optional[_Backend] = (
         get_global_forced_attn_backend())
     if backend_by_global_setting is not None:
@@ -201,64 +137,13 @@ def which_attn_to_use(head_size: int,
     if backend_by_env_var is not None:
         selected_backend = backend_name_to_enum(backend_by_env_var)
 
-    # get device-specific default attn_backend
-    default_backend = current_platform.get_default_attn_backend(
-        selected_backend)
-    if default_backend is not None:
-        return default_backend
-
-    if use_v1:
-        return _Backend.FLASH_ATTN_VLLM_V1
-
-    # FlashAttn in NVIDIA GPUs.
-    if selected_backend == _Backend.FLASH_ATTN:
-        if not current_platform.has_device_capability(80):
-            # Volta and Turing NVIDIA GPUs.
-            logger.info(
-                "Cannot use FlashAttention-2 backend for Volta and Turing "
-                "GPUs.")
-            selected_backend = _Backend.XFORMERS
-        elif dtype not in (torch.float16, torch.bfloat16):
-            logger.info(
-                "Cannot use FlashAttention-2 backend for dtype other than "
-                "torch.float16 or torch.bfloat16.")
-            selected_backend = _Backend.XFORMERS
-        elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
-            logger.info(
-                "Cannot use FlashAttention-2 backend for FP8 KV cache.")
-            logger.warning(
-                "Please use FlashInfer backend with FP8 KV Cache for "
-                "better performance by setting environment variable "
-                "VLLM_ATTENTION_BACKEND=FLASHINFER")
-            selected_backend = _Backend.XFORMERS
-        elif block_size % 16 != 0:
-            logger.info(
-                "Cannot use FlashAttention-2 backend for block size not "
-                "divisible by 16.")
-            selected_backend = _Backend.XFORMERS
-
-    # FlashAttn is valid for the model, checking if the package is installed.
-    if selected_backend == _Backend.FLASH_ATTN:
-        try:
-            import vllm.vllm_flash_attn  # noqa: F401
-            from vllm.attention.backends.flash_attn import (  # noqa: F401
-                FlashAttentionBackend)
-
-            supported_sizes = FlashAttentionBackend.get_supported_head_sizes()
-            if head_size not in supported_sizes:
-                logger.info(
-                    "Cannot use FlashAttention-2 backend for head size %d.",
-                    head_size)
-                selected_backend = _Backend.XFORMERS
-        except ImportError:
-            logger.info(
-                "Cannot use FlashAttention-2 backend because the "
-                "vllm.vllm_flash_attn package is not found. "
-                "Make sure that vllm_flash_attn was built and installed "
-                "(on by default).")
-            selected_backend = _Backend.XFORMERS
-
-    return selected_backend
+    # get device-specific attn_backend
+    attention_cls = current_platform.get_attn_backend_cls(
+        selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1)
+    if not attention_cls:
+        raise ValueError(
+            f"Invalid attention backend for {current_platform.device_name}")
+    return resolve_obj_by_qualname(attention_cls)
 
 
 @contextmanager
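
The selector now delegates backend choice to the platform and only turns the returned string into a class. A sketch of what a qualified-name resolver along the lines of resolve_obj_by_qualname typically does is below; the actual vllm.utils implementation may differ in details:

import importlib
from typing import Any

def resolve_by_qualname_sketch(qualname: str) -> Any:
    # Split "pkg.module.ClassName" into module path and attribute name,
    # import the module, and return the attribute (here, a backend class).
    module_name, _, obj_name = qualname.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, obj_name)

# e.g. the string CpuPlatform returns in this commit:
# backend_cls = resolve_by_qualname_sketch(
#     "vllm.attention.backends.torch_sdpa.TorchSDPABackend")

Returning a dotted path instead of an enum means the selector no longer needs to import every backend module up front; a backend is imported only when its platform actually names it.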

‎vllm/platforms/cpu.py

+5 -2

@@ -28,10 +28,13 @@ def get_device_name(cls, device_id: int = 0) -> str:
         return "cpu"
 
     @classmethod
-    def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
+    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
+                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
+                             block_size: int, use_v1: bool) -> str:
         if selected_backend != _Backend.TORCH_SDPA:
             logger.info("Cannot use %s backend on CPU.", selected_backend)
-        return _Backend.TORCH_SDPA
+        logger.info("Using Torch SDPA backend.")
+        return "vllm.attention.backends.torch_sdpa.TorchSDPABackend"
 
     @classmethod
     def get_device_total_memory(cls, device_id: int = 0) -> int:
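
With this hook in place, any platform can point the selector at its own backend by returning a dotted class path. A hypothetical out-of-tree platform sketch using the signature shown above; the class name and the returned module path are illustrative only, not part of this commit:

from typing import Optional

import torch

class MyAcceleratorPlatform:  # would derive from vLLM's Platform base class

    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool) -> str:
        # Return the fully qualified name of an attention backend class.
        # An empty return makes get_attn_backend raise ValueError.
        if use_v1:
            return ""  # this illustrative plugin only supports the v0 engine
        return "my_plugin.attention.MyAcceleratorAttnBackend"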
