- from unittest.mock import patch
+ from unittest.mock import Mock, patch

  import pytest
  import torch

  from tests.kernels.utils import override_backend_env_variable
- from vllm.attention.selector import which_attn_to_use
+ from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
  from vllm.platforms.cpu import CpuPlatform
  from vllm.platforms.cuda import CudaPlatform
  from vllm.platforms.openvino import OpenVinoPlatform
  from vllm.platforms.rocm import RocmPlatform
  from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL


+ @pytest.fixture(autouse=True)
+ def clear_cache():
+     """Clear lru cache to ensure each test case runs without caching."""
+     _cached_get_attn_backend.cache_clear()
+
+
  @pytest.mark.parametrize(
      "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"])
  @pytest.mark.parametrize("device", ["cpu", "openvino", "hip", "cuda"])
@@ -24,67 +31,70 @@ def test_env(name: str, device: str, monkeypatch):

      if device == "cpu":
          with patch("vllm.attention.selector.current_platform", CpuPlatform()):
-             backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                         False)
-             assert backend.name == "TORCH_SDPA"
+             backend = get_attn_backend(16, torch.float16, torch.float16, 16,
+                                        False)
+             assert backend.get_name() == "TORCH_SDPA"
      elif device == "hip":
          with patch("vllm.attention.selector.current_platform", RocmPlatform()):
-             backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                         False)
-             assert backend.name == "ROCM_FLASH"
+             backend = get_attn_backend(16, torch.float16, torch.float16, 16,
+                                        False)
+             assert backend.get_name() == "ROCM_FLASH"
      elif device == "openvino":
          with patch("vllm.attention.selector.current_platform",
-                    OpenVinoPlatform()):
-             backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                         False)
-             assert backend.name == "OPENVINO"
+                    OpenVinoPlatform()), patch.dict('sys.modules',
+                                                    {'openvino': Mock()}):
+             backend = get_attn_backend(16, torch.float16, torch.float16, 16,
+                                        False)
+             assert backend.get_name() == "OPENVINO"
      else:
-         with patch("vllm.attention.selector.current_platform", CudaPlatform()):
-             backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                         False)
-             assert backend.name == name
+         if name in ["XFORMERS", "FLASHINFER"]:
+             with patch("vllm.attention.selector.current_platform",
+                        CudaPlatform()):
+                 backend = get_attn_backend(16, torch.float16, torch.float16,
+                                            16, False)
+                 assert backend.get_name() == name


  def test_flash_attn(monkeypatch):
      """Test FlashAttn validation."""
      # TODO: When testing for v1, pipe in `use_v1` as an argument to
-     # which_attn_to_use
+     # get_attn_backend

      override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL)

      # Unsupported CUDA arch
      with patch("torch.cuda.get_device_capability", return_value=(7, 5)):
-         backend = which_attn_to_use(16, torch.float16, None, 16, False)
-         assert backend.name != STR_FLASH_ATTN_VAL
+         backend = get_attn_backend(16, torch.float16, None, 16, False)
+         assert backend.get_name() != STR_FLASH_ATTN_VAL

      # Unsupported data type
-     backend = which_attn_to_use(16, torch.float8_e4m3fn, None, 16, False)
-     assert backend.name != STR_FLASH_ATTN_VAL
+     backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16, False)
+     assert backend.get_name() != STR_FLASH_ATTN_VAL

      # Unsupported kv cache data type
-     backend = which_attn_to_use(16, torch.float16, "fp8", 16, False)
-     assert backend.name != STR_FLASH_ATTN_VAL
+     backend = get_attn_backend(16, torch.float16, "fp8", 16, False)
+     assert backend.get_name() != STR_FLASH_ATTN_VAL

      # Unsupported block size
-     backend = which_attn_to_use(16, torch.float16, None, 8, False)
-     assert backend.name != STR_FLASH_ATTN_VAL
+     backend = get_attn_backend(16, torch.float16, None, 8, False)
+     assert backend.get_name() != STR_FLASH_ATTN_VAL

      # flash-attn is not installed
      with patch.dict('sys.modules', {'vllm_flash_attn': None}):
-         backend = which_attn_to_use(16, torch.float16, None, 16, False)
-         assert backend.name != STR_FLASH_ATTN_VAL
+         backend = get_attn_backend(16, torch.float16, None, 16, False)
+         assert backend.get_name() != STR_FLASH_ATTN_VAL

      # Unsupported head size
-     backend = which_attn_to_use(17, torch.float16, None, 16, False)
-     assert backend.name != STR_FLASH_ATTN_VAL
+     backend = get_attn_backend(17, torch.float16, None, 16, False)
+     assert backend.get_name() != STR_FLASH_ATTN_VAL

      # Attention-free models should bypass env and use PlaceholderAttention
-     backend = which_attn_to_use(16, torch.float16, torch.float16, 16, True)
-     assert backend.name != STR_FLASH_ATTN_VAL
+     backend = get_attn_backend(16, torch.float16, torch.float16, 16, True)
+     assert backend.get_name() != STR_FLASH_ATTN_VAL


  def test_invalid_env(monkeypatch):
      """Throw an exception if the backend name is invalid."""
      override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
      with pytest.raises(ValueError):
-         which_attn_to_use(16, torch.float16, None, 16, False)
+         get_attn_backend(16, torch.float16, None, 16, False)
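For context on why the new autouse clear_cache fixture is added above: the imported name _cached_get_attn_backend and its cache_clear() call suggest that get_attn_backend delegates to a helper wrapped in functools.lru_cache, so a backend chosen in one parametrized case could otherwise be reused in the next even though a different platform is patched in. The following is a minimal stand-alone sketch under that assumption; select_backend_for and the argument names are illustrative placeholders, not vllm's actual implementation.

from functools import lru_cache

def select_backend_for(head_size, dtype, kv_cache_dtype, block_size,
                       is_attention_free):
    # Hypothetical selection logic; the real selector also consults the
    # current platform and environment overrides.
    return "TORCH_SDPA"

@lru_cache(maxsize=None)
def _cached_get_attn_backend(head_size, dtype, kv_cache_dtype, block_size,
                             is_attention_free):
    return select_backend_for(head_size, dtype, kv_cache_dtype, block_size,
                              is_attention_free)

# The first call populates the cache; a later call with the same arguments
# returns the cached result even if the patched platform has changed in the
# meantime. Clearing the cache between tests, as the autouse fixture does,
# forces the selection to re-run under each test's patches.
_cached_get_attn_backend.cache_clear()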