diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
index 20228ddb80..bcda2cc26b 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
@@ -464,9 +464,14 @@ def get_attention_backend(
         # On SM90, prefer FA3 over FA4 when FA3 is available.
         # FA3 is more mature on Hopper; FA4's SM90 backward has limitations
         # (MLA, non-standard head dims, SplitKV).
-        if use_flash_attention_4 and use_flash_attention_3 and device_compute_capability == (9, 0):
-            if FlashAttentionUtils.v4_is_installed:
-                logger.debug("Disabling FlashAttention 4 to prefer FlashAttention 3 on SM90")
+        if (
+            device_compute_capability == (9, 0)
+            and use_flash_attention_3
+            and FlashAttentionUtils.v3_is_installed
+            and use_flash_attention_4
+            and FlashAttentionUtils.v4_is_installed
+        ):
+            logger.debug("Disabling FlashAttention 4 to prefer FlashAttention 3 on SM90")
             use_flash_attention_4 = False
 
     # Filter: Data type