Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 5 additions & 13 deletions tests/python/direct/test_narrow_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
pytorch_nvfp4_quantize,
microarchitecture_is,
is_pre_blackwell,
microarchitecture_is_pre,
linear_to_swizzled_128_4,
Expand Down Expand Up @@ -241,10 +242,7 @@ def nvfuser_fusion_id0(fd: FusionDefinition):

# Cannot use an opinfo test because the input tensor dtype and the fusion definition dtype don't match.
@pytest.mark.skipif(
is_pre_blackwell(), reason="Only supported on blackwell and newer devices."
)
@pytest.mark.skipif(
not microarchitecture_is_pre(12), reason="Does not support blackwell compute 12.0"
not microarchitecture_is(10, 0), reason="Only supported on blackwell compute 10.0."
)
@pytest.mark.parametrize("config", [[128, 256, 512], [128, 256, 512]])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16])
Expand Down Expand Up @@ -324,7 +322,7 @@ def nvfuser_fusion_id0(fd: FusionDefinition) -> None:


@pytest.mark.skipif(
is_pre_blackwell(), reason="Only supported on blackwell and newer devices."
not microarchitecture_is(10, 0), reason="Only supported on blackwell compute 10.0."
)
@pytest.mark.parametrize("config", [[1024, 1024, 1024]])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16])
Expand Down Expand Up @@ -454,10 +452,7 @@ def fusion_baseline(fd: FusionDefinition) -> None:


@pytest.mark.skipif(
is_pre_blackwell(), reason="Only supported on blackwell and newer devices."
)
@pytest.mark.skipif(
not microarchitecture_is_pre(12), reason="Does not support blackwell compute 12.0"
not microarchitecture_is(10, 0), reason="Only supported on blackwell compute 10.0."
)
@pytest.mark.parametrize("config", [[1024, 128, 256]])
@pytest.mark.parametrize("tokens_per_expert_neg_one", [[115, 144, 8]])
Expand Down Expand Up @@ -661,10 +656,7 @@ def nvfuser_fusion_id0(fd: FusionDefinition) -> None:
# 1. input data needs to be changed from `torch.testing.make_tensor` to `torch.randn`;
# 2. output errors are much more relaxed.
@pytest.mark.skipif(
is_pre_blackwell(), reason="Only supported on blackwell and newer devices."
)
@pytest.mark.skipif(
not microarchitecture_is_pre(12), reason="Does not support blackwell compute 12.0"
not microarchitecture_is(10, 0), reason="Only supported on blackwell compute 10.0."
)
@pytest.mark.parametrize("config", [[1024, 128, 256]])
@pytest.mark.parametrize("tokens_per_expert_neg_one", [[115, 144, 8]])
Expand Down