diff --git a/test/test_nn.py b/test/test_nn.py
index d5c245c5887d2..0c84d6ffe129e 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -5176,18 +5176,20 @@ def test_batchnorm_buffer_update_when_stats_are_not_tracked(self):
     def test_batchnorm(self, dims, mode, memory_format, ref_backend, mixed, dtype):
         if torch.version.cuda:
             if self._testMethodName in ("test_batchnorm_2D_train_NCHW_vs_cpu_mixed_bfloat16",
-                                        "test_batchnorm_3D_train_NCHW_vs_cpu_mixed_bfloat16"):
-                self.skipTest("bfloat16 NHWC train failed on CUDA due to native tolerance issue "
-                              "https://github.com/pytorch/pytorch/issues/156513")
-            if self._testMethodName == "test_batchnorm_3D_train_NCHW_vs_native_mixed_float16":
-                self.skipTest("Batchnorm 3D NHWC train failed on CUDA")
+                                        "test_batchnorm_3D_train_NCHW_vs_cpu_mixed_bfloat16",
+                                        "test_batchnorm_2D_train_NHWC_vs_NCHW_mixed_bfloat16",
+                                        "test_batchnorm_3D_train_NHWC_vs_NCHW_mixed_bfloat16",
+                                        "test_batchnorm_3D_train_NCHW_vs_native_mixed_float16"):
+                self.skipTest("Failed on CUDA")
 
         if torch.version.hip:
             if self._testMethodName in ("test_batchnorm_2D_train_NCHW_vs_cpu_mixed_bfloat16",
-                                        "test_batchnorm_3D_train_NCHW_vs_cpu_mixed_bfloat16") \
+                                        "test_batchnorm_3D_train_NCHW_vs_cpu_mixed_bfloat16",
+                                        "test_batchnorm_2D_train_NHWC_vs_NCHW_mixed_bfloat16",
+                                        "test_batchnorm_3D_train_NHWC_vs_NCHW_mixed_bfloat16") \
                     and _get_torch_rocm_version() < (6, 4):
                 # NCHW bfloat16 path uses native kernels for rocm<=6.3
-                # train failed on rocm<=6.3 due to native tolerance issue
+                # train failed on rocm<=6.3 due to native accuracy issue
                 # https://github.com/pytorch/pytorch/issues/156513
                 self.skipTest("bfloat16 NHWC train failed on ROCm <= 6.3")
 
@@ -5197,9 +5199,8 @@ def test_batchnorm(self, dims, mode, memory_format, ref_backend, mixed, dtype):
                 # https://github.com/pytorch/pytorch/issues/156513
                 self.skipTest("bfloat16 NCHW train failed due to native tolerance issue")
 
-            if self._testMethodName == "test_batchnorm_3D_train_NCHW_vs_native_mixed_float16" \
-                    and _get_torch_rocm_version() < (7, 0):
-                self.skipTest("3D float16 NCHW train failed on ROCm<7.0")
+            if self._testMethodName == "test_batchnorm_3D_train_NCHW_vs_native_mixed_float16":
+                self.skipTest("3D float16 NCHW train failed on ROCm")
 
         if dims == 3 and memory_format in ("NHWC", "NCHW"):
             memory_format = memory_format + "3D"