Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions tests/pytorch/fused_attn/test_fused_attn_with_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
# See LICENSE for license information.

import os
import pytest
import subprocess
from test_fused_attn import ModelConfig

import pytest
import torch
from transformer_engine.pytorch.attention import (
_flash_attn_2_plus,
_flash_attn_2_3_plus,
Expand All @@ -15,6 +16,8 @@
get_cudnn_version,
)

from test_fused_attn import ModelConfig

model_configs_flash_attn = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA
Expand Down Expand Up @@ -58,6 +61,10 @@ def get_bash_arguments(num_gpus_per_node, **kwargs):
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"])
def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2
if num_gpus > torch.cuda.device_count():
pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}")

config = model_configs_flash_attn[model]
if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip("CP implementation with KV P2P does not support sliding window yet!")
Expand All @@ -77,7 +84,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):

subprocess.run(
get_bash_arguments(
num_gpus_per_node=4 if cp_comm_type == "a2a+p2p" else 2,
num_gpus_per_node=num_gpus,
dtype=dtype,
model=model,
qkv_format=qkv_format,
Expand Down Expand Up @@ -115,6 +122,10 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"])
@pytest.mark.parametrize("fp8_mha", [False, True])
def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha):
num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2
if num_gpus > torch.cuda.device_count():
pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}")

if qkv_format == "thd" and get_device_compute_capability() < (9, 0):
pytest.skip("THD format is only supported on sm90+!")
if cp_comm_type == "all_gather" and get_cudnn_version() < (9, 3, 0):
Expand Down Expand Up @@ -155,7 +166,7 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha

subprocess.run(
get_bash_arguments(
num_gpus_per_node=4 if cp_comm_type == "a2a+p2p" else 2,
num_gpus_per_node=num_gpus,
dtype=dtype,
model=model,
qkv_format=qkv_format,
Expand Down