7 changes: 5 additions & 2 deletions qa/L3_pytorch_FA_versions_test/test.sh
@@ -18,10 +18,10 @@ sm_arch=`python3 -c "import torch; sm = torch.cuda.get_device_capability(0); pri
export FLASH_ATTN_CUDA_ARCHS=$sm_arch
if [ $sm_arch -gt 90 ]
then
FA_versions=(2.8.3)
FA_versions=(2.8.3 4.0.0b8)
elif [ $sm_arch -eq 90 ]
then
FA_versions=(2.7.3 2.8.3 3.0.0b1)
FA_versions=(2.7.3 2.8.3 3.0.0b1 4.0.0b8)
fi

for fa_version in "${FA_versions[@]}"
@@ -31,6 +31,9 @@ do
if [ "${fa_version}" \< "3.0.0" ]
then
pip3 install flash-attn==${fa_version} --no-build-isolation
elif [[ "${fa_version}" == 4.* ]]
then
pip3 install flash-attn-4==${fa_version} nvidia-cutlass-dsl[cu13]==4.4.2 --no-build-isolation
else
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/hopper && python setup.py install
135 changes: 134 additions & 1 deletion tests/pytorch/attention/test_attention.py
@@ -53,7 +53,7 @@
)

_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
sys.path = [str(_current_file.parent.parent)] + sys.path
from utils import (
reset_rng_states,
compare_and_assert,
@@ -362,6 +362,139 @@ def test_dpa_num_splits(dtype, model_configs, model):
)


# ==============================
# Flash Attention 4 (FA4) tests
# ==============================

model_configs_fa4_base = {
# test: ModelConfig(b, sq, hq, dqk)
# Standard head dims
"fa4_base_1": ModelConfig(4, 128, 16, 64),
"fa4_base_2": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"),
"fa4_base_3": ModelConfig(2, 1024, 8, 96, attn_mask_type="causal"),
# GQA
"fa4_gqa_1": ModelConfig(2, 1024, 32, 128, num_gqa_groups=8, attn_mask_type="causal"),
"fa4_gqa_2": ModelConfig(2, 1024, 16, 128, num_gqa_groups=1, attn_mask_type="causal"),
# num_splits
"fa4_splits_1": ModelConfig(2, 2048, 24, 128, num_splits=2),
"fa4_splits_2": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096, num_splits=4),
}


@pytest.mark.skipif(
not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
)
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_fa4_base])
@pytest.mark.parametrize("model", model_configs_fa4_base.keys())
def test_dpa_fa4_base(dtype, model_configs, model):
"""Test DotProductAttention with FA4: base configs, extended head dims, GQA, num_splits"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)


model_configs_fa4_mla = {
# test: ModelConfig(b, sq, hq, dqk, head_dim_v=dv)
"fa4_mla_1": ModelConfig(4, 128, 16, 128, head_dim_v=64),
"fa4_mla_2": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128),
"fa4_mla_3": ModelConfig(2, 1024, 16, 96, head_dim_v=64, attn_mask_type="causal"),
# dqk=128, dv=96: FA4 SM100 backward has dK_reduce_ncol misalignment for dV;
# the backend filter should reject FA4 and fall back to another backend.
"fa4_mla_4": ModelConfig(2, 1024, 16, 128, head_dim_v=96, attn_mask_type="causal"),
# DeepSeek-style MLA: dqk=192, dv=128 (supported on SM100 as special case)
"fa4_mla_deepseek": ModelConfig(2, 1024, 16, 192, head_dim_v=128, attn_mask_type="causal"),
}


@pytest.mark.skipif(
not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
)
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_fa4_mla])
@pytest.mark.parametrize("model", model_configs_fa4_mla.keys())
def test_dpa_fa4_mla(dtype, model_configs, model):
"""Test DotProductAttention with FA4: MLA (head_dim_qk != head_dim_v)"""
test_dot_product_attention(
dtype, model_configs, model, False, True, "bshd_bshd_bshd", False, False
)


model_configs_fa4_swa = {
# test: ModelConfig(b, sq, hq, dqk, window_size=(left, right))
"fa4_swa_1": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal", window_size=(128, 0)),
"fa4_swa_2": ModelConfig(2, 2048, 24, 64, attn_mask_type="causal", window_size=(64, 0)),
"fa4_swa_3": ModelConfig(
2, 2048, 16, 128, num_gqa_groups=4, attn_mask_type="causal", window_size=(256, 0)
),
"fa4_swa_4": ModelConfig(
2, 2048, 16, 128, attn_mask_type="padding_causal", window_size=(128, 0)
),
}


@pytest.mark.skipif(
not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
)
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_fa4_swa])
@pytest.mark.parametrize("model", model_configs_fa4_swa.keys())
@pytest.mark.parametrize("qkv_layout", ["sbhd_sbhd_sbhd", "bshd_bshd_bshd"])
def test_dpa_fa4_sliding_window(dtype, model_configs, model, qkv_layout):
"""Test DotProductAttention with FA4: sliding window attention"""
test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, True, False)


model_configs_fa4_varlen = {
# test: ModelConfig(b, sq, hq, dqk)
"fa4_varlen_1": ModelConfig(4, 128, 16, 64, attn_mask_type="padding"),
"fa4_varlen_2": ModelConfig(2, 2048, 24, 128, attn_mask_type="padding_causal"),
"fa4_varlen_3": ModelConfig(
2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding_causal"
),
"fa4_varlen_4": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding"),
}


@pytest.mark.skipif(
not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
)
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_fa4_varlen])
@pytest.mark.parametrize("model", model_configs_fa4_varlen.keys())
@pytest.mark.parametrize("qkv_layout", ["thd_thd_thd", "bshd_bshd_bshd"])
def test_dpa_fa4_varlen(dtype, model_configs, model, qkv_layout):
"""Test DotProductAttention with FA4: variable-length sequences (varlen/thd)"""
test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False)


model_configs_fa4_mask = {
# test: ModelConfig(b, sq, hq, dqk)
"fa4_mask_no_mask": ModelConfig(2, 1024, 16, 128),
"fa4_mask_causal": ModelConfig(2, 1024, 16, 128, attn_mask_type="causal"),
"fa4_mask_causal_br": ModelConfig(2, 1024, 16, 128, attn_mask_type="causal_bottom_right"),
"fa4_mask_padding": ModelConfig(2, 1024, 16, 128, attn_mask_type="padding"),
"fa4_mask_padding_causal": ModelConfig(2, 1024, 16, 128, attn_mask_type="padding_causal"),
"fa4_mask_padding_causal_br": ModelConfig(
2, 1024, 16, 128, attn_mask_type="padding_causal_bottom_right"
),
}


@pytest.mark.skipif(
not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
)
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_fa4_mask])
@pytest.mark.parametrize("model", model_configs_fa4_mask.keys())
def test_dpa_fa4_mask(dtype, model_configs, model):
"""Test DotProductAttention with FA4: various attention mask types"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)


model_configs_softmax = {
# test: ModelConfig(b, sq, hq, dqk)
"softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8),
@@ -80,14 +80,16 @@
from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.graph import is_graph_capturing

# Global vars for flash attn v2 and v3 imports
# Global vars for flash attn v2
flash_attn_cuda_bwd = None
flash_attn_func = None
flash_attn_varlen_func = None
_flash_attn_fwd = None
_flash_attn_bwd = None
_flash_attn_varlen_fwd = None
_flash_attn_varlen_bwd = None

# Try to import Flash Attention v2
try:
fa_utils.version = PkgVersion(get_pkg_version("flash-attn"))
except PackageNotFoundError:
@@ -130,12 +132,16 @@
),
fa_utils.version,
)

# Try to import Flash Attention v3
try:
fa_utils.fa3_version = PkgVersion(get_pkg_version("flash-attn-3"))
except PackageNotFoundError:
flash_attn_func_v3 = None
flash_attn_varlen_func_v3 = None
flash_attn_with_kvcache_v3 = None
_flash_attn_fwd_v3 = None
_flash_attn_bwd_v3 = None
# pass # only print warning if use_flash_attention_3 = True in get_attention_backend
else:
from flash_attn_interface import flash_attn_func as flash_attn_func_v3
@@ -150,6 +156,20 @@

fa_utils.set_flash_attention_3_params()

# Try to import Flash Attention v4
try:
fa_utils.fa4_version = PkgVersion(get_pkg_version("flash-attn-4"))
except PackageNotFoundError:
flash_attn_func_v4 = None
flash_attn_varlen_func_v4 = None
else:
from flash_attn.cute.interface import ( # pylint: disable=ungrouped-imports,no-name-in-module
flash_attn_func as flash_attn_func_v4,
flash_attn_varlen_func as flash_attn_varlen_func_v4,
)

fa_utils.set_flash_attention_4_params()

# Float8CurrentScaling: fused_attn_bwd takes O in FP8 by default, this flag allows it in F16
_dpa_fp8_cs_o_in_f16 = os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1") == "1"

@@ -916,8 +936,13 @@ def forward(
batch_size * context_len,
)

use_flash_attn_4 = False
if flash_attention_backend is not None and flash_attention_backend > PkgVersion("4.0.0b"):
use_flash_attn_4 = True
use_flash_attn_3 = False
if flash_attention_backend is not None and flash_attention_backend > PkgVersion("3.0.0b"):
if flash_attention_backend is not None and PkgVersion(
"3.0.0b"
) < flash_attention_backend < PkgVersion("4.0.0"):
use_flash_attn_3 = True
if context_parallel and all(
not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]
@@ -971,24 +996,55 @@ def forward(
# | | thd + padding
# | flash_attn_with_kvcache | KV cache (not-paged/paged), i.e.
# | | bshd/sbhd/thd + padding
# FA v4 | flash_attn_func | bshd/sbhd + not padding
# | flash_attn_varlen_func | bshd/sbhd + padding
# | | thd + padding
fa_optional_forward_args_thd = []
if qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type:
func = (
flash_attn_func if not use_flash_attn_3 else flash_attn_func_v3
) # pylint: disable=possibly-used-before-assignment
func = None
if use_flash_attn_4:
func = flash_attn_func_v4
elif use_flash_attn_3:
func = flash_attn_func_v3
else:
func = flash_attn_func
else:
if not use_flash_attn_3:
if use_flash_attn_4:
func = flash_attn_varlen_func_v4
elif not use_flash_attn_3:
func = flash_attn_varlen_func
elif inference_params is None:
func = flash_attn_varlen_func_v3 # pylint: disable=possibly-used-before-assignment
else:
func = flash_attn_with_kvcache_v3 # pylint: disable=possibly-used-before-assignment
if not use_flash_attn_3 or inference_params is None:
if not use_flash_attn_4 and (not use_flash_attn_3 or inference_params is None):
fa_optional_forward_args_thd.append(cu_seqlens_q)
fa_optional_forward_args_thd.append(cu_seqlens_kv)
fa_optional_forward_args_thd.append(max_seqlen_q)
fa_optional_forward_args_thd.append(max_seqlen_kv)
if not use_flash_attn_3:
if use_flash_attn_4:
fa_4_optional_forward_kwargs = {
"window_size": window_size,
"num_splits": num_splits,
}
if inference_params is None:
fa_4_optional_forward_kwargs["deterministic"] = self.deterministic
if func is flash_attn_varlen_func_v4:
fa_4_optional_forward_kwargs["cu_seqlens_q"] = cu_seqlens_q
fa_4_optional_forward_kwargs["cu_seqlens_k"] = cu_seqlens_kv
fa_4_optional_forward_kwargs["max_seqlen_q"] = max_seqlen_q
fa_4_optional_forward_kwargs["max_seqlen_k"] = max_seqlen_kv
output = func(
query_layer,
key_layer,
value_layer,
softmax_scale=self.softmax_scale,
causal="causal" in attn_mask_type,
Comment on lines +1037 to +1042 (Contributor review comment):

P2 causal_bottom_right treated identically to causal for FA4

causal="causal" in attn_mask_type evaluates to True for both "causal" and "causal_bottom_right". If FA4's flash_attn_func supports a separate bottom-right diagonal alignment flag (similar to how cuDNN fused attention distinguishes the two), passing only causal=True would produce incorrect results for causal_bottom_right configs.

This is consistent with the existing FA2 path, but since fa4_mask_causal_br is explicitly added as a test case, it is worth verifying that the FA4 causal parameter correctly implements both variants, or adding a dedicated causal_bottom_right kwarg if the FA4 API exposes one.
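
A quick way to see the substring behavior the comment refers to (plain Python, independent of any FA4 API; purely illustrative):

# Both mask strings contain "causal", so the expression in the diff derives the
# same flag value for the regular and the bottom-right-aligned causal masks.
for attn_mask_type in ("causal", "causal_bottom_right"):
    print(attn_mask_type, "->", "causal" in attn_mask_type)
# causal -> True
# causal_bottom_right -> True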

**fa_4_optional_forward_kwargs,
)
if isinstance(output, (List, Tuple)):
output = output[0]
elif not use_flash_attn_3:
fa_optional_forward_kwargs = {}
if fa_utils.v2_3_plus:
fa_optional_forward_kwargs["window_size"] = window_size