
Commit afb6c12

Committed Jan 9, 2025
[platform] fix attn backend for cuda
Signed-off-by: Mengqing Cao <cmq0113@163.com>
Parent: 7e83803

2 files changed, +15 -7 lines changed


tests/kernels/test_attention_selector.py (+14 -5)
@@ -4,14 +4,21 @@
 import torch
 
 from tests.kernels.utils import override_backend_env_variable
-from vllm.attention.selector import get_attn_backend
+from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
 from vllm.platforms.cpu import CpuPlatform
 from vllm.platforms.cuda import CudaPlatform
 from vllm.platforms.openvino import OpenVinoPlatform
 from vllm.platforms.rocm import RocmPlatform
 from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL
 
 
+@pytest.fixture(autouse=True)
+def clear_cache():
+    """Clear lru cache to ensure each test case runs without caching.
+    """
+    _cached_get_attn_backend.cache_clear()
+
+
 @pytest.mark.parametrize(
     "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"])
 @pytest.mark.parametrize("device", ["cpu", "openvino", "hip", "cuda"])
@@ -39,10 +46,12 @@ def test_env(name: str, device: str, monkeypatch):
                                        False)
             assert backend.get_name() == "OPENVINO"
     else:
-        with patch("vllm.attention.selector.current_platform", CudaPlatform()):
-            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
-                                       False)
-            assert backend.get_name() == name
+        if name in ["XFORMERS", "FLASHINFER"]:
+            with patch("vllm.attention.selector.current_platform",
+                       CudaPlatform()):
+                backend = get_attn_backend(16, torch.float16, torch.float16,
+                                           16, False)
+                assert backend.get_name() == name
 
 
 def test_flash_attn(monkeypatch):
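
Why the autouse fixture is needed: vLLM memoizes backend selection, so _cached_get_attn_backend (an lru_cache-wrapped helper) would keep returning the backend chosen by the first parametrized case even after a later case patches the platform or environment. The standalone sketch below uses a stand-in selector, not the real vLLM function, to show the stale-cache behavior that the fixture's cache_clear() call prevents.

# Minimal standalone sketch of the caching behavior the new autouse fixture
# guards against; _cached_select stands in for vLLM's _cached_get_attn_backend.
import functools

current_platform = "cpu"  # in the real tests this is patched per case

@functools.lru_cache(maxsize=None)
def _cached_select(head_size: int) -> str:
    # The test cases call the selector with identical arguments, so
    # lru_cache would return the first result forever.
    return f"{current_platform}-backend"

assert _cached_select(16) == "cpu-backend"

current_platform = "cuda"                     # a later case patches the platform
assert _cached_select(16) == "cpu-backend"    # stale: cache hit ignores the patch

_cached_select.cache_clear()                  # what the clear_cache fixture does
assert _cached_select(16) == "cuda-backend"   # fresh lookup sees the new platform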

vllm/platforms/cuda.py (+1 -2)
@@ -154,8 +154,7 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
             logger.info("Using XFormers backend.")
             return "vllm.attention.backends.xformers.XFormersBackend"
         elif selected_backend == _Backend.FLASH_ATTN:
-            logger.info("Using FlashAttention backend.")
-            return "vllm.attention.backends.flash_attn.FlashAttentionBackend"
+            pass
         elif selected_backend:
             raise ValueError(
                 f"Invalid attention backend for {cls.device_name}")
