Revert "Improve make_tensor performance for float and complex types (p…
Browse files Browse the repository at this point in the history
…ytorch#85473)"

This reverts commit a76995e.

Reverted pytorch#85473 on behalf of https://github.com/huydhn due to: Sorry for reverting your PR, but it seems to cause a bunch of flaky tests in the pull and periodic CI workflows.
pytorchmergebot authored and alvgaona committed Oct 11, 2022
1 parent 9adeed6 commit 6e8845f
Showing 6 changed files with 39 additions and 185 deletions.
functorch/test/test_ops.py (22 changes: 5 additions & 17 deletions)
@@ -304,8 +304,6 @@ class TestOperators(TestCase):
@opsToleranceOverride('TestOperators', 'test_grad', (
tol1('nn.functional.binary_cross_entropy_with_logits',
{torch.float32: tol(atol=1e-04, rtol=1e-04)}),
-tol1('masked.cumprod',
-{torch.float32: tol(atol=1e-05, rtol=1e-05)}),
))
def test_grad(self, device, dtype, op):
if op.name in vjp_fail:
@@ -364,8 +362,6 @@ def wrapped_fn(*args, **kwargs):
{torch.float32: tol(atol=1e-04, rtol=1.3e-06)}, device_type='cuda'),
tol1('nn.functional.binary_cross_entropy_with_logits',
{torch.float32: tol(atol=4e-04, rtol=4e-04)}),
-tol1('nn.functional.batch_norm',
-{torch.float32: tol(atol=4e-05, rtol=5e-05)}),
))
def test_jvp(self, device, dtype, op):
# TODO: get rid of vjp_decomp when we add decomposition support to
@@ -611,11 +607,11 @@ def fn(inp, *args, **kwargs):
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
@opsToleranceOverride('TestOperators', 'test_vmapvjpvjp', (
tol1('linalg.svd',
-{torch.float32: tol(atol=1e-03, rtol=5e-04)}),
+{torch.float32: tol(atol=5e-04, rtol=5e-04)}),
tol1('linalg.lu_factor',
{torch.float32: tol(atol=2e-03, rtol=2e-02)}),
tol1('svd',
-{torch.float32: tol(atol=1e-03, rtol=5e-04)}),
+{torch.float32: tol(atol=5e-04, rtol=5e-04)}),
))
def test_vmapvjpvjp(self, device, dtype, op):
# Since, we test `vjpvjp` independently,
@@ -722,9 +718,9 @@ def vjp_of_vjp(*args_and_cotangents):
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
@opsToleranceOverride('TestOperators', 'test_vmapvjp', (
tol1('linalg.svd',
-{torch.float32: tol(atol=5e-04, rtol=1e-04)}, device_type="cuda"),
+{torch.float32: tol(atol=1.5e-04, rtol=1e-04)}, device_type="cuda"),
tol1('svd',
-{torch.float32: tol(atol=5e-04, rtol=1e-04)}, device_type="cuda"),
+{torch.float32: tol(atol=1.5e-04, rtol=1e-04)}, device_type="cuda"),
))
@skipOps('TestOperators', 'test_vmapvjp', vmapvjp_fail)
def test_vmapvjp(self, device, dtype, op):
@@ -1160,15 +1156,11 @@ def get_vjp(cotangents, *primals):
tol1('masked.prod',
{torch.float32: tol(atol=1e-04, rtol=1.3e-05)}),
tol1('masked.cumprod',
-{torch.float32: tol(atol=1e-04, rtol=5e-04)}),
+{torch.float32: tol(atol=1e-04, rtol=1e-04)}),
tol1('cumprod',
{torch.float32: tol(atol=1e-04, rtol=1.3e-05)}, device_type='cuda'),
tol1('linalg.vander',
{torch.float32: tol(atol=1e-04, rtol=1.3e-05)}, device_type='cuda'),
-tol1('nn.functional.group_norm',
-{torch.float32: tol(atol=1e-03, rtol=1e-03)}),
-tol2('linalg.pinv', 'hermitian',
-{torch.float32: tol(atol=5e-03, rtol=5e-03)}),
))
def test_jvpvjp(self, device, dtype, op):
if not op.supports_autograd:
@@ -1317,8 +1309,6 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents):
{torch.float32: tol(atol=5e-04, rtol=5e-04)}),
tol1('linalg.multi_dot',
{torch.float32: tol(atol=5e-04, rtol=5e-04)}),
-tol2('linalg.pinv', 'hermitian',
-{torch.float32: tol(atol=5e-04, rtol=5e-04)}),
tol1('svd',
{torch.float32: tol(atol=5e-04, rtol=5e-04)}),
))
@@ -1552,8 +1542,6 @@ def fn(input, weight, bias):
{torch.float32: tol(atol=5e-04, rtol=9e-03)}, device_type='cuda'),
tol1('linalg.householder_product',
{torch.float32: tol(atol=1e-04, rtol=1e-04)}, device_type='cpu'),
-tol2('linalg.pinv', 'hermitian',
-{torch.float32: tol(atol=5e-06, rtol=5e-06)}),
))
def test_vmap_autograd_grad(self, device, dtype, op):
def is_differentiable(inp):
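
The atol/rtol pairs tuned in the overrides above follow the usual absolute-plus-relative convention: a comparison passes when |actual - expected| <= atol + rtol * |expected|, the same rule used by torch.allclose and torch.testing.assert_close. A minimal, self-contained sketch; the tensors and tolerance values below are illustrative and not taken from any particular override:

import torch

# Illustrative values only; not from any specific override in the diff above.
expected = torch.tensor([1.0, 2.0, 3.0])
actual = expected + 5e-5  # inject a small absolute error

# Passes: 5e-5 <= 1e-4 + 1e-4 * |expected|
print(torch.allclose(actual, expected, atol=1e-4, rtol=1e-4))  # True

# Fails: 5e-5 > 1e-6 + 1e-6 * |expected|
print(torch.allclose(actual, expected, atol=1e-6, rtol=1e-6))  # False
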
test/test_decomp.py (2 changes: 0 additions & 2 deletions)
@@ -162,8 +162,6 @@ def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs)
(torch.float16, torch.ops.aten.native_batch_norm.default): 1e-5,
(torch.bfloat16, torch.ops.aten.linalg_vector_norm.default): 1e-6,
(torch.float16, torch.ops.aten.linalg_vector_norm.default): 1e-6,
-(torch.float16, torch.ops.aten.nll_loss_forward.default): 1e-2,
-(torch.bfloat16, torch.ops.aten.nll_loss_forward.default): 1e-1,
}
if ref.is_floating_point():
orig_diff = (orig - ref).abs().max()
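
The two entries removed above belong to a per-(dtype, aten overload) tolerance table that op_assert_ref consults when comparing results against a reference. A rough, self-contained sketch of that lookup pattern follows; this is not the actual test_decomp.py logic, and the helper name and default tolerance are assumptions:

import torch

# Keys mirror the style of the table above; values are the maximum allowed
# deviation for that (dtype, op) pair.
tol_table = {
    (torch.bfloat16, torch.ops.aten.linalg_vector_norm.default): 1e-6,
    (torch.float16, torch.ops.aten.linalg_vector_norm.default): 1e-6,
}

def max_diff_ok(op_overload, dtype, orig, ref, default_tol=1e-7):
    # Largest elementwise deviation between the computed result and the reference.
    diff = (orig.double() - ref.double()).abs().max().item()
    return diff <= tol_table.get((dtype, op_overload), default_tol)

# Synthetic tensors differing by 5e-7: accepted under the table entry (1e-6),
# rejected under the strict default (1e-7).
ref = torch.zeros(4, dtype=torch.float64)
orig = ref + 5e-7
print(max_diff_ok(torch.ops.aten.linalg_vector_norm.default, torch.float16, orig, ref))  # True
print(max_diff_ok(torch.ops.aten.mul.Tensor, torch.float16, orig, ref))                  # False
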
torch/testing/_creation.py (26 changes: 9 additions & 17 deletions)
@@ -13,16 +13,6 @@
torch.complex128: torch.float64}
float_to_corresponding_complex_type_map = {v: k for k, v in complex_to_corresponding_float_type_map.items()}


-def _uniform_random(t: torch.Tensor, low: float, high: float):
-    # uniform_ requires to-from <= std::numeric_limits<scalar_t>::max()
-    # Work around this by scaling the range before and after the PRNG
-    if high - low >= torch.finfo(t.dtype).max:
-        return t.uniform_(low / 2, high / 2).mul_(2)
-    else:
-        return t.uniform_(low, high)


def make_tensor(
*shape: Union[int, torch.Size, List[int], Tuple[int, ...]],
dtype: torch.dtype,
@@ -138,16 +128,18 @@ def clamp(a, l, h):
result = torch.randint(low, high, shape, device=device, dtype=dtype) # type: ignore[call-overload]
elif dtype in _floating_types:
ranges_floats = (torch.finfo(dtype).min, torch.finfo(dtype).max)
-m_low, m_high = _modify_low_high(low, high, ranges_floats[0], ranges_floats[1], -9, 9, dtype)
-result = torch.empty(shape, device=device, dtype=dtype)
-_uniform_random(result, m_low, m_high)
+low, high = _modify_low_high(low, high, ranges_floats[0], ranges_floats[1], -9, 9, dtype)
+rand_val = torch.rand(shape, device=device, dtype=dtype)
+result = high * rand_val + low * (1 - rand_val)
elif dtype in _complex_types:
float_dtype = complex_to_corresponding_float_type_map[dtype]
ranges_floats = (torch.finfo(float_dtype).min, torch.finfo(float_dtype).max)
-m_low, m_high = _modify_low_high(low, high, ranges_floats[0], ranges_floats[1], -9, 9, dtype)
-result = torch.empty(shape, device=device, dtype=dtype)
-result_real = torch.view_as_real(result)
-_uniform_random(result_real, m_low, m_high)
+low, high = _modify_low_high(low, high, ranges_floats[0], ranges_floats[1], -9, 9, dtype)
+real_rand_val = torch.rand(shape, device=device, dtype=float_dtype)
+imag_rand_val = torch.rand(shape, device=device, dtype=float_dtype)
+real = high * real_rand_val + low * (1 - real_rand_val)
+imag = high * imag_rand_val + low * (1 - imag_rand_val)
+result = torch.complex(real, imag)
else:
raise TypeError(f"The requested dtype '{dtype}' is not supported by torch.testing.make_tensor()."
" To request support, file an issue at: https://github.com/pytorch/pytorch/issues")
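
For context on the change above: the revert swaps make_tensor's in-place Tensor.uniform_ sampling (introduced in pytorch#85473) back to scaling the output of torch.rand. A minimal, self-contained sketch of both strategies; only the two sampling expressions come from the diff, while the variable names and the range used here are illustrative:

import torch

low, high, shape = -9.0, 9.0, (3,)

# Strategy restored by this revert: shift and scale the output of torch.rand.
rand_val = torch.rand(shape, dtype=torch.float32)
restored = high * rand_val + low * (1 - rand_val)

# Strategy being reverted: fill in place with Tensor.uniform_, which is what the
# removed _uniform_random helper wrapped (plus handling for ranges wider than
# the dtype's finfo max).
reverted = torch.empty(shape, dtype=torch.float32).uniform_(low, high)

# Both draw values from [low, high); only the first form remains after this commit.
print(restored)
print(reverted)
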
torch/testing/_internal/common_methods_invocations.py (69 changes: 16 additions & 53 deletions)
@@ -8408,9 +8408,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
supports_two_python_scalars=True,
decorators=(
DecorateInfo(
-toleranceOverride({torch.float16: tol(atol=1e-2, rtol=0),
-torch.bfloat16: tol(atol=1e-5, rtol=5e-3),
-torch.complex32: tol(atol=1e-5, rtol=1e-3)}),
+toleranceOverride({torch.float16: tol(atol=1e-2, rtol=0)}),
'TestBinaryUfuncs', 'test_reference_numerics'),
DecorateInfo(
toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}),
@@ -8499,13 +8497,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
DecorateInfo(
toleranceOverride({torch.float32: tol(atol=1.3e-05, rtol=1.3e-05),
torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}),
-'TestCommon', 'test_numpy_refs'),
-DecorateInfo(
-toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-5)}),
-'TestConsistency',
-'test_output_match',
-),
-],
+'TestCommon', 'test_numpy_refs')],
skips=(
# NVIDIA only assures that bfloat16 is supported by bmm if SM >= 5.3
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=not SM53OrLater),
@@ -8585,8 +8577,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
skips=(
# NVIDIA only assures that bfloat16 is supported by bmm if SM >= 5.3
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=not SM53OrLater),
-DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-5)}),
-"TestCommon", "test_out")
),
sample_inputs_func=sample_inputs_bmm),
OpInfo('mv',
@@ -9746,8 +9736,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
# The following tests fails on some jobs
DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs', 'test_reference_numerics_extremal_values',
dtypes=(torch.float16,)),
-DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=5e-3)}),
-'TestBinaryUfuncs', 'test_reference_numerics'),
)),
UnaryUfuncInfo('frexp',
op=torch.frexp,
@@ -10199,15 +10187,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
# backward on CPU
DecorateInfo(toleranceOverride({torch.float32: tol(atol=0, rtol=1e-5)}),
'TestCommon', 'test_noncontiguous_samples',
-device_type='cpu'),
-DecorateInfo(
-toleranceOverride({
-torch.float32: tol(atol=1e-5, rtol=1e-5),
-torch.complex64: tol(atol=1e-5, rtol=1e-5),
-}),
-"TestDecomp", "test_comprehensive", device_type="cuda",
-),
-],
+device_type='cpu'), ],
skips=(
# Strides are not the same!
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
@@ -10878,10 +10858,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
'TestCommon', 'test_variant_consistency_eager', device_type='cuda'),
DecorateInfo(
toleranceOverride({torch.chalf: tol(atol=5e-2, rtol=5e-2), }),
-'TestCommon', 'test_complex_half_reference_testing'),
-DecorateInfo(
-toleranceOverride({torch.complex32: tol(atol=1e-5, rtol=5e-3)}),
-"TestCudaFuserOpInfo", "test_nvfuser_correctness"),
+'TestCommon', 'test_complex_half_reference_testing')
),
skips=(
# Reason for Skip: https://github.com/pytorch/pytorch/pull/79694#issuecomment-1186949486
@@ -10968,7 +10945,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
decorators=(
DecorateInfo(
-toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=5e-2)}),
+toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=1e-2)}),
'TestCommon', 'test_complex_half_reference_testing'
),
DecorateInfo(
@@ -11537,7 +11514,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_grad'),
DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_gradgrad'),
DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD'),
DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_forward_ad',
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad',
device_type='cpu'),
)),
OpInfo('nn.functional.max_unpool1d',
@@ -12157,8 +12134,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
# because it has not been implemented yet.
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad',
device_type="cuda", active_if=TEST_WITH_ROCM),
-DecorateInfo(toleranceOverride({torch.float32: tol(atol=5e-05, rtol=1e-05)}),
-'TestCompositeCompliance', 'test_forward_ad', device_type="cpu"),
DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
'TestCudaFuserOpInfo', 'test_nvfuser_correctness'),
)),
@@ -12175,8 +12150,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
skips=(
DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
'TestCudaFuserOpInfo', 'test_nvfuser_correctness'),
-DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-03, rtol=1e-04)}),
-'TestJit', 'test_variant_consistency_jit'),
),
sample_inputs_func=sample_inputs_batch_norm),
OpInfo(
@@ -12718,12 +12691,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
DecorateInfo(unittest.skip("Skipped!"), 'TestGradients'),
DecorateInfo(unittest.skip("Skipped!"), 'TestJit'),
DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits'),
-DecorateInfo(toleranceOverride({torch.bfloat16: tol(atol=1e-3, rtol=0.016)}),
-"TestUnaryUfuncs", "test_reference_numerics_extremal",
-device_type="cuda"),
-DecorateInfo(toleranceOverride({torch.bfloat16: tol(atol=1e-3, rtol=0.016)}),
-"TestUnaryUfuncs", "test_reference_numerics_normal",
-device_type="cuda"),
),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
@@ -12978,9 +12945,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
'TestMathBits', 'test_conj_view'),
DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1.2e-03)}),
'TestCommon', 'test_noncontiguous_samples'),
-DecorateInfo(toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1e-05)}),
-"TestDecomp", "test_comprehensive", device_type="cuda",
-active_if=TEST_WITH_ROCM),
),
skips=(
DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
@@ -16079,9 +16043,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
dtypes=[torch.float16, torch.complex64]),
DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values',
dtypes=[torch.uint8, torch.float16, torch.complex64]),
-# FIXME: ValueError: The data in MaskedTensor a and Tensor b do not match
-DecorateInfo(unittest.skip("Skipped!"), 'TestOperators', 'test_reduction_all',
-dtypes=[torch.float16]),
),
),
ReductionOpInfo(
Expand All @@ -16106,8 +16067,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
dtypes=[torch.float16]),
DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values',
dtypes=[torch.float16]),
DecorateInfo(unittest.skip("Skipped!"), 'TestOperators', 'test_reduction_all',
dtypes=[torch.float32]),
),
),
ReductionOpInfo(
@@ -17442,6 +17401,16 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
# https://github.com/pytorch/pytorch/issues/76944
supports_two_python_scalars=False,
supports_one_python_scalar=True,
+skips=(
+# Reference result was farther (nan) from the precise computation than
+# the torch result was (nan)!
+DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',
+dtypes=(torch.chalf,), device_type='cpu'),
+# Reference result was farther (nan) from the precise computation than
+# the torch result was (nan)!
+DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
+dtypes=(torch.chalf,), device_type='cpu'),
+),
),
ElementwiseBinaryPythonRefInfo(
"_refs.true_divide",
@@ -17663,12 +17632,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
PythonRefInfo(
"_refs.native_layer_norm",
torch_opinfo_name="native_layer_norm",
-skips=(
-DecorateInfo(unittest.skip("Skipped!"), "TestCommon", "test_python_ref",
-device_type="cpu", dtypes=(torch.float32,)),
-DecorateInfo(unittest.skip("Skipped!"), "TestCommon", "test_python_ref_torch_fallback",
-device_type="cpu", dtypes=(torch.float32,)),
-),
),
PythonRefInfo(
"_refs.permute",
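
One behavioral note on the hunks above: a few entries swap unittest.skip for unittest.expectedFailure (for example, TestCompositeCompliance.test_forward_ad). A skipped test never runs, while an expected-failure test still runs and is reported as an expected failure when it fails (and as an unexpected success when it passes). A minimal sketch with plain unittest; the class and test names here are made up:

import unittest

class Demo(unittest.TestCase):
    @unittest.skip("Skipped!")
    def test_skipped(self):
        self.fail("never runs")  # the body is not executed at all

    @unittest.expectedFailure
    def test_expected_failure(self):
        self.assertEqual(1, 2)  # runs and fails, reported as an expected failure

if __name__ == "__main__":
    unittest.main()
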
(Diffs for the remaining two changed files were not loaded.)
