diff --git a/test/test_linalg.py b/test/test_linalg.py
index 6fc46663a67e7..4208ada17d877 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -9003,7 +9003,8 @@ def dims_full_for_fn():
             r1 = fntorch(t0_full, t1, t2)
             self.assertEqual(r0, r1)
 
-    @tf32_on_and_off(0.001)
+    # ROCm 6.4 passes with tf32=on, but 6.4.1 needs the tolerance relaxed slightly
+    @tf32_on_and_off(0.002 if torch.version.hip else 0.001)
     @bf32_on_and_off(0.001)
     def test_broadcast_batched_matmul(self, device):
         n_dim = random.randint(1, 8)
diff --git a/test/test_transformers.py b/test/test_transformers.py
index 798095e065785..1bd836a3eb15b 100644
--- a/test/test_transformers.py
+++ b/test/test_transformers.py
@@ -3194,6 +3194,8 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset,
                 fudge_factors['grad_query'] = 650.0
             if dtype == torch.float32:
                 fudge_factors['grad_key'] = 90.0
+            if "gfx95" in torch.cuda.get_device_properties(0).gcnArchName:
+                fudge_factors['grad_value'] = 15.0
 
         check_out_and_grad(
             (out_ref, out_lp_ref, out),
@@ -3315,6 +3317,8 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset,
                 fudge_factors['grad_query'] = 650.0
             if dtype == torch.float32:
                 fudge_factors['grad_key'] = 90.0
+            if "gfx95" in torch.cuda.get_device_properties(0).gcnArchName:
+                fudge_factors['grad_value'] = 15.0
 
         check_out_and_grad(
             (out_ref, out_lp_ref, out),
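
For context, the gfx95 check reads `torch.cuda.get_device_properties(0).gcnArchName`, which on ROCm builds reports the GPU architecture string (e.g. "gfx950"). Below is a minimal sketch of how an architecture-conditional tolerance bump of this kind can be expressed; the helper name `rocm_arch_matches` is hypothetical and not part of the patch, which inlines the check directly.

```python
import torch

def rocm_arch_matches(prefix: str, device_index: int = 0) -> bool:
    """Return True when running on a ROCm build whose GPU architecture
    string (e.g. "gfx950") contains the given prefix.

    Hypothetical helper for illustration only.
    """
    if torch.version.hip is None or not torch.cuda.is_available():
        return False
    arch = torch.cuda.get_device_properties(device_index).gcnArchName
    return prefix in arch

# Usage mirroring the patch: loosen the grad_value fudge factor on gfx95*.
fudge_factors = {"grad_query": 650.0, "grad_key": 90.0, "grad_value": 8.5}
if rocm_arch_matches("gfx95"):
    fudge_factors["grad_value"] = 15.0
```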