diff --git a/tests/pytorch/distributed/run_numerics.py b/tests/pytorch/distributed/run_numerics.py index 3c3c807a90..1e34b06632 100644 --- a/tests/pytorch/distributed/run_numerics.py +++ b/tests/pytorch/distributed/run_numerics.py @@ -47,11 +47,6 @@ ) -# Disable TF32 -torch.backends.cuda.matmul.allow_tf32 = False -torch.backends.cudnn.allow_tf32 = False - - # Quantization recipe setup def quantization_recipe() -> Recipe: if QUANTIZATION == "fp8": @@ -166,7 +161,7 @@ def backward(ctx, grad_output): def _constant(tensor): - return nn.init.constant_(tensor, 0.5) + return nn.init.constant_(tensor, 0.05) def dist_print(msg, src=None, end="\n", error=False): @@ -189,7 +184,8 @@ def _get_tolerances(dtype): if dtype == torch.bfloat16: return {"rtol": 1.6e-2, "atol": 1e-5} if dtype == torch.float32: - return {"rtol": 1.2e-4, "atol": 1e-4} + # TF32 has same mantissa bits as FP16 + return {"rtol": 1e-3, "atol": 1e-5} raise ValueError(f"Unsupported dtype ({dtype})") diff --git a/tests/pytorch/distributed/test_numerics.py b/tests/pytorch/distributed/test_numerics.py index 632f50e90a..1ff5aff997 100644 --- a/tests/pytorch/distributed/test_numerics.py +++ b/tests/pytorch/distributed/test_numerics.py @@ -56,7 +56,7 @@ def test_distributed(quantization): if quantization == "fp8" and not fp8_available: pytest.skip(reason_for_no_fp8) if quantization == "fp8_cs" and not fp8_available: - pytest.skip(fp8_available) + pytest.skip(reason_for_no_fp8) if quantization == "mxfp8" and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if quantization == "fp8_block_scaling" and not fp8_block_scaling_available: