21 changes: 11 additions & 10 deletions test/test_nn.py
@@ -5176,18 +5176,20 @@ def test_batchnorm_buffer_update_when_stats_are_not_tracked(self):
     def test_batchnorm(self, dims, mode, memory_format, ref_backend, mixed, dtype):
         if torch.version.cuda:
             if self._testMethodName in ("test_batchnorm_2D_train_NCHW_vs_cpu_mixed_bfloat16",
-                                        "test_batchnorm_3D_train_NCHW_vs_cpu_mixed_bfloat16"):
-                self.skipTest("bfloat16 NHWC train failed on CUDA due to native tolerance issue "
-                              "https://github.com/pytorch/pytorch/issues/156513")
-            if self._testMethodName == "test_batchnorm_3D_train_NCHW_vs_native_mixed_float16":
-                self.skipTest("Batchnorm 3D NHWC train failed on CUDA")
+                                        "test_batchnorm_3D_train_NCHW_vs_cpu_mixed_bfloat16",
+                                        "test_batchnorm_2D_train_NHWC_vs_NCHW_mixed_bfloat16",
+                                        "test_batchnorm_3D_train_NHWC_vs_NCHW_mixed_bfloat16",
+                                        "test_batchnorm_3D_train_NCHW_vs_native_mixed_float16"):
+                self.skipTest("Failed on CUDA")

         if torch.version.hip:
             if self._testMethodName in ("test_batchnorm_2D_train_NCHW_vs_cpu_mixed_bfloat16",
-                                        "test_batchnorm_3D_train_NCHW_vs_cpu_mixed_bfloat16") \
+                                        "test_batchnorm_3D_train_NCHW_vs_cpu_mixed_bfloat16",
+                                        "test_batchnorm_2D_train_NHWC_vs_NCHW_mixed_bfloat16",
+                                        "test_batchnorm_3D_train_NHWC_vs_NCHW_mixed_bfloat16") \
                     and _get_torch_rocm_version() < (6, 4):
                 # NCHW bfloat16 path uses native kernels for rocm<=6.3
-                # train failed on rocm<=6.3 due to native tolerance issue
+                # train failed on rocm<=6.3 due to native accuracy issue
                 # https://github.com/pytorch/pytorch/issues/156513
                 self.skipTest("bfloat16 NHWC train failed on ROCm <= 6.3")

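For context on the skip mechanism used in the hunk above: unittest stores the running test's name in self._testMethodName, and self.skipTest() aborts the current test and reports it as skipped rather than failed. Below is a minimal, self-contained sketch of that pattern; the class name, test names, and RUNNING_ON_CUDA flag are hypothetical stand-ins for the torch.version.cuda check, not code from test_nn.py.

import unittest

# Hypothetical stand-in for the torch.version.cuda / torch.version.hip checks.
RUNNING_ON_CUDA = False

class BatchNormSkipSketch(unittest.TestCase):
    # Names of parametrized cases known to fail on this backend.
    _KNOWN_BAD_ON_CUDA = {
        "test_case_bfloat16_nchw",
        "test_case_bfloat16_nhwc",
    }

    def setUp(self):
        # Skip before the body runs; the case is reported as skipped, not failed.
        if RUNNING_ON_CUDA and self._testMethodName in self._KNOWN_BAD_ON_CUDA:
            self.skipTest("known accuracy issue on CUDA")

    def test_case_bfloat16_nchw(self):
        self.assertTrue(True)

    def test_case_bfloat16_nhwc(self):
        self.assertTrue(True)

if __name__ == "__main__":
    unittest.main()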
@@ -5197,9 +5199,8 @@ def test_batchnorm(self, dims, mode, memory_format, ref_backend, mixed, dtype):
                 # https://github.com/pytorch/pytorch/issues/156513
                 self.skipTest("bfloat16 NCHW train failed due to native tolerance issue")

-            if self._testMethodName == "test_batchnorm_3D_train_NCHW_vs_native_mixed_float16" \
-                    and _get_torch_rocm_version() < (7, 0):
-                self.skipTest("3D float16 NCHW train failed on ROCm<7.0")
+            if self._testMethodName == "test_batchnorm_3D_train_NCHW_vs_native_mixed_float16":
+                self.skipTest("3D float16 NCHW train failed on ROCm")

         if dims == 3 and memory_format in ("NHWC", "NCHW"):
             memory_format = memory_format + "3D"
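The ROCm branch gates its skips on a version tuple: _get_torch_rocm_version() returns a tuple that Python compares element-wise, so "< (6, 4)" matches 6.3.x and earlier. A hedged sketch of that comparison follows; parse_rocm_version() is a hypothetical stand-in for the real helper, shown only to illustrate the tuple ordering.

from typing import Optional, Tuple

def parse_rocm_version(version: Optional[str]) -> Tuple[int, ...]:
    # Hypothetical helper: turn a version string like "6.3.1" into (6, 3).
    # Returns (0,) when no ROCm build is present (version is None).
    if not version:
        return (0,)
    return tuple(int(part) for part in version.split(".")[:2])

# Tuples compare lexicographically, which is what the skip condition relies on.
assert parse_rocm_version("6.3.1") < (6, 4)          # older than 6.4 -> skipped
assert not (parse_rocm_version("6.4.0") < (6, 4))    # 6.4 or newer -> test runs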