test/distributed/_tools/test_sac_ilp.py (5 changes: 4 additions & 1 deletion)
@@ -20,9 +20,11 @@
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_utils import (
     run_tests,
+    skipIfRocm,
     skipIfTorchDynamo,
     TestCase,
-    skipIfRocm,
+    skipIfRocmArch,
+    NAVI_ARCH,
 )
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     ModelArgs,
@@ -178,6 +180,7 @@ def test_sac_ilp_case1(self):

     @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653")
     @unittest.skipIf(not TEST_CUDA, "CUDA not available")
+    @skipIfRocmArch(NAVI_ARCH)
     def test_sac_ilp_case2(self):
         """
         This is a case where the memory budget is not binding, meaning that no
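
For readers unfamiliar with the helper imported above: skipIfRocmArch(NAVI_ARCH) comes from torch.testing._internal.common_utils and skips a test when it runs on a ROCm GPU whose architecture is listed in the given tuple. Below is a minimal sketch of the idea, assuming the gcnArchName device property exposed by ROCm builds; the tuple contents are illustrative placeholders, not the real NAVI_ARCH values, and the actual helper in common_utils may differ in detail.

import functools
import unittest

import torch

# Illustrative placeholder; the real NAVI_ARCH tuple lives in
# torch.testing._internal.common_utils and is maintained there.
EXAMPLE_NAVI_ARCH = ("gfx1030", "gfx1100")


def skip_if_rocm_arch(arch_tuple):
    """Sketch of an architecture-gated skip in the spirit of skipIfRocmArch."""

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(self, *args, **kwargs):
            if torch.version.hip is not None and torch.cuda.is_available():
                # gcnArchName looks like "gfx1100:sramecc+:xnack-"; compare the prefix.
                gcn_arch = torch.cuda.get_device_properties(0).gcnArchName.split(":")[0]
                if gcn_arch in arch_tuple:
                    self.skipTest(f"skipped on ROCm architecture {gcn_arch}")
            return fn(self, *args, **kwargs)

        return wrapper

    return decorator


class ExampleTest(unittest.TestCase):
    @skip_if_rocm_arch(EXAMPLE_NAVI_ARCH)
    def test_something(self):
        self.assertTrue(True)
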

test/distributed/pipelining/test_schedule_multiproc.py (2 changes: 1 addition & 1 deletion)
@@ -604,7 +604,7 @@ def test_schedule_with_native_zero_bubble(self, ScheduleClass):
             for name, p in stage_module.named_parameters():
                 ref_p = ref_submod.get_parameter(name)
                 try:
-                    torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5)
+                    torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=9e-5)
                 except AssertionError:
                     print(
                         f"Parameter test failed for {submod_name}.{name}: {p.grad} vs {ref_p.grad}"
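
For context on the tolerance bump above: torch.testing.assert_close treats two values as equal when |actual - expected| <= atol + rtol * |expected|, so raising atol from 4e-5 to 9e-5 widens only the absolute part of that band. A small standalone illustration with toy tensors (not the pipeline gradients the test actually compares):

import torch

expected = torch.tensor([1e-3, 2e-3])
actual = expected + 6e-5  # drift above the old atol=4e-5 but below the new atol=9e-5

# Would raise AssertionError under the old tolerance:
#   6e-5 > 4e-5 + 1e-5 * |expected|
# torch.testing.assert_close(actual, expected, rtol=1e-5, atol=4e-5)

# Passes under the loosened tolerance:
#   6e-5 <= 9e-5 + 1e-5 * |expected|
torch.testing.assert_close(actual, expected, rtol=1e-5, atol=9e-5)
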

test/distributed/tensor/debug/test_comm_mode_features.py (4 changes: 3 additions & 1 deletion)
@@ -24,7 +24,8 @@
     with_comms,
 )
 
-
+from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FUSED_ATTENTION
+import unittest
 c10d_functional = torch.ops.c10d_functional
 
 
@@ -221,6 +222,7 @@ def test_MLP_module_tracing(self):

     @skip_unless_torch_gpu
     @with_comms
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
     def test_transformer_module_tracing(self, is_seq_parallel=False):
         """
         tests module-level tracing for more complicated transformer module and
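
The new skip relies on PLATFORM_SUPPORTS_FUSED_ATTENTION from torch.testing._internal.common_cuda, a precomputed flag derived from device-capability checks. If you only want to probe locally whether the fused scaled-dot-product-attention path runs on your GPU, a rough trial-run check along these lines can stand in for it; this is an assumption-laden sketch that requires a recent PyTorch exposing torch.nn.attention, and it is not how the flag itself is computed.

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel


def fused_sdpa_runs() -> bool:
    """Trial-run probe: True if a fused SDPA backend handles a small fp16 input."""
    if not torch.cuda.is_available():
        return False
    q = torch.randn(1, 1, 8, 64, device="cuda", dtype=torch.float16)
    try:
        # Restrict dispatch to the fused backends; raises if neither supports the input.
        with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
            torch.nn.functional.scaled_dot_product_attention(q, q, q)
        return True
    except RuntimeError:
        return False


print("fused SDPA available:", fused_sdpa_runs())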