@@ -4,14 +4,21 @@
 import torch
 
 from tests.kernels.utils import override_backend_env_variable
-from vllm.attention.selector import get_attn_backend
+from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
 from vllm.platforms.cpu import CpuPlatform
 from vllm.platforms.cuda import CudaPlatform
 from vllm.platforms.openvino import OpenVinoPlatform
 from vllm.platforms.rocm import RocmPlatform
 from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL
 
 
+@pytest.fixture(autouse=True)
+def clear_cache():
+    """Clear lru cache to ensure each test case runs without caching.
+    """
+    _cached_get_attn_backend.cache_clear()
+
+
 @pytest.mark.parametrize(
     "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"])
 @pytest.mark.parametrize("device", ["cpu", "openvino", "hip", "cuda"])
@@ -39,10 +46,12 @@ def test_env(name: str, device: str, monkeypatch):
                                    False)
         assert backend.get_name() == "OPENVINO"
     else:
-        with patch("vllm.attention.selector.current_platform", CudaPlatform()):
-            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
-                                       False)
-            assert backend.get_name() == name
+        if name in ["XFORMERS", "FLASHINFER"]:
+            with patch("vllm.attention.selector.current_platform",
+                       CudaPlatform()):
+                backend = get_attn_backend(16, torch.float16, torch.float16,
+                                           16, False)
+                assert backend.get_name() == name
 
 
 def test_flash_attn(monkeypatch):
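
Editorial note, not part of the diff: the new autouse fixture exists because _cached_get_attn_backend is memoized (it exposes a cache_clear() method, the functools.lru_cache API), so without clearing, the backend selected by one parametrized case would leak into the next. A minimal standalone sketch of the same pattern, assuming lru_cache-style caching; get_backend here is an illustrative stand-in, not the vllm function:

import functools

import pytest


@functools.lru_cache
def get_backend(name: str) -> str:
    # Stand-in for _cached_get_attn_backend: repeated calls with the same
    # arguments return the first (cached) result instead of recomputing.
    return f"backend:{name}"


@pytest.fixture(autouse=True)
def clear_cache():
    # Runs before every test in the module, so no test observes a backend
    # cached by an earlier parametrized case.
    get_backend.cache_clear()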
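The test's patch("vllm.attention.selector.current_platform", CudaPlatform()) swaps a module-level attribute for the duration of the with-block and restores it on exit, which is what lets a CPU-only CI node exercise the CUDA selection path. A hedged sketch of that pattern; fake_selector is a hypothetical stand-in object, not the real module:

import types
from unittest.mock import patch

# Hypothetical stand-in for the vllm.attention.selector module.
fake_selector = types.SimpleNamespace(current_platform="cpu")


def pick_backend() -> str:
    # Reads the module-level attribute at call time, the way
    # get_attn_backend consults current_platform.
    return f"backend-for-{fake_selector.current_platform}"


with patch.object(fake_selector, "current_platform", "cuda"):
    assert pick_backend() == "backend-for-cuda"
assert pick_backend() == "backend-for-cpu"  # attribute restored on exit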