diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index cefbaa828b45c..e02447fa1f0f4 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -30,7 +30,8 @@ from torch.testing._internal.common_cuda import ( SM80OrLater, SM90OrLater, - PLATFORM_SUPPORTS_FLASH_ATTENTION + PLATFORM_SUPPORTS_FLASH_ATTENTION, + PLATFORM_SUPPORTS_FP8 ) from torch.testing._internal.common_device_type import ( _has_sufficient_memory,