Introduce int_oo (pytorch#127693)
In a previous life, we used sympy.oo to represent the lower/upper bounds of integer ranges. Later, we changed this to be sys.maxsize - 1 for a few reasons: (1) sometimes we do tests on a value being exactly sys.maxsize, and we wanted to avoid a data-dependent guard in this case, (2) sympy.oo corresponds to floating point infinity, so you get incorrect types for value ranges with oo, and (3) you can do slightly better reasoning if you assume that input sizes fall within the representable 64-bit integer range.

After working in the sys.maxsize regime for a bit, I've concluded that this was actually a bad idea. Specifically, the problem is that sys.maxsize ends up in your upper bound, and then any size-increasing computation like size * 2 produces 2 * sys.maxsize, so you end up doing a ton of arbitrary-precision int computation that is totally unnecessary. A symbolic bound is better.
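A quick illustration of the blowup in plain Python (not PyTorch-specific): one doubling already pushes a sys.maxsize-based bound out of machine-int range, and every later comparison against it becomes arbitrary-precision work.

```python
import sys

bound = sys.maxsize - 1
for _ in range(3):
    bound *= 2  # bound propagation through computations like size * 2
print(bound)                # 73786976294838206448, far beyond any real size
print(bound > sys.maxsize)  # True; everything downstream is big-int math
```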

But especially after pytorch#126905, we can't go back to using sympy.oo, because that advertises that it's not an integer, and now your ValueRanges is typed incorrectly. So what do we do? We define a new numeric constant `int_oo`, which is like `sympy.oo` but it advertises `is_integer`. **test/test_sympy_utils.py** describes some basic properties of the number, and **torch/utils/_sympy/numbers.py** has the actual implementation.
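A few of the key behaviors, distilled from the new tests (this snippet assumes a build where torch/utils/_sympy/numbers.py from this PR is importable):

```python
import sys

import sympy
from torch.utils._sympy.numbers import int_oo

assert int_oo.is_integer             # the whole point: typed as an integer
assert not sympy.oo.is_integer       # sympy.oo is floating-point infinity
assert (int_oo + 1) is int_oo        # absorbing under integer-preserving ops
assert (2 * int_oo) is int_oo        # no more 2 * sys.maxsize blowup
assert (int_oo / 2) is sympy.oo      # true division leaves the integers
assert int_oo >= sys.maxsize         # sits above every finite int
```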

The rest of the PR works out the implications of this change. I'll give more commentary as inline comments.

Fixes pytorch#127396

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: pytorch#127693
Approved by: https://github.com/lezcano
ghstack dependencies: pytorch#126905
ezyang authored and TharinduRusira committed Jun 14, 2024
1 parent 5733fcf commit 3e6f084
Showing 19 changed files with 746 additions and 145 deletions.
9 changes: 2 additions & 7 deletions test/dynamo/test_exc.py
@@ -253,7 +253,6 @@ def fn(x, shape):
==> (>= 0 s1)
==> (>= 0 s2)
==> (>= 0 s3)
-==> (>= 9223372036854775806 s0)
Failed Source Expressions:
==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""",
@@ -287,14 +286,14 @@ def fn(x, shape):
Model:
==> L['shape'][0]: 1
==> L['shape'][1]: 1
-==> L['shape'][2]: 2
+==> L['shape'][2]: 0
==> L['x'].size()[0]: 3
==> L['x'].storage_offset(): 0
==> L['x'].stride()[0]: 1
==> s0: 3
==> s1: 1
==> s2: 1
-==> s3: 2
+==> s3: 0
Assertions:
==> (== 0 L['x'].storage_offset())
@@ -318,10 +317,6 @@ def fn(x, shape):
==> (== L['shape'][2] s3)
==> (== L['x'].size()[0] s0)
==> (> s0 0)
-==> (>= 9223372036854775806 s0)
-==> (>= 9223372036854775807 s1)
-==> (>= 9223372036854775807 s2)
-==> (>= 9223372036854775807 s3)
Failed Source Expressions:
==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""",
1 change: 0 additions & 1 deletion test/dynamo/test_export.py
@@ -3473,7 +3473,6 @@ def forward(self, pred, x):
]
false_guard_code = [
"Ne(cast_symbool_to_symint_guardless(L['pred']), 1)",
"-9223372036854775808 <= cast_symbool_to_symint_guardless(L['pred'])",
]
test_symbool_guards(
f,
12 changes: 6 additions & 6 deletions test/dynamo/test_misc.py
@@ -9341,7 +9341,7 @@ def test_shape_env_equal_create_symbolic_sizes_strides_storage_offset(self):
> Left: {0: 0, 1: 1, 2: s1, 3: s0}
> Right: {0: 0, 1: 1}
==> var_to_range: values don't match.
-> Left: {s0: VR[2, 9223372036854775806], s1: VR[2, 9223372036854775806]}
+> Left: {s0: VR[2, int_oo], s1: VR[2, int_oo]}
> Right: {}
==> var_to_sources: values don't match.
> Left: {s0: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=0)], s1: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=1)]}
@@ -9375,7 +9375,7 @@ def test_shape_env_equal_unbacked(self):
> Left: 2
> Right: 0
==> var_to_range: values don't match.
-> Left: {u0: VR[-9223372036854775808, 9223372036854775807], u1: VR[0, 1], zuf0: VR[-oo, oo]}
+> Left: {u0: VR[-int_oo, int_oo], u1: VR[0, 1], zuf0: VR[-oo, oo]}
> Right: {}
""",
)
@@ -9452,8 +9452,8 @@ def test_shape_env_equal_evaluate_expr_replacement(self):
> Left: {s0: 3}
> Right: {}
==> var_to_range: values don't match.
-> Left: {s0: VR[3, 3], s1: VR[2, 9223372036854775806]}
-> Right: {s0: VR[2, 9223372036854775806], s1: VR[2, 9223372036854775806]}
+> Left: {s0: VR[3, 3], s1: VR[2, int_oo]}
+> Right: {s0: VR[2, int_oo], s1: VR[2, int_oo]}
""",
)
self._replay_and_check(main)
@@ -9490,8 +9490,8 @@ def test_shape_env_equal_evaluate_expr_refinement(self):
> Left: {_assert, ge, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
> Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
==> var_to_range: values don't match.
-> Left: {s0: VR[3, 9223372036854775806], s1: VR[2, 9223372036854775806]}
-> Right: {s0: VR[2, 9223372036854775806], s1: VR[2, 9223372036854775806]}
+> Left: {s0: VR[3, int_oo], s1: VR[2, int_oo]}
+> Right: {s0: VR[2, int_oo], s1: VR[2, int_oo]}
""",
)
self._replay_and_check(main)
15 changes: 14 additions & 1 deletion test/export/test_export.py
@@ -201,6 +201,19 @@ def forward(self, x):
dynamic_shapes={"x": {0: dim_x}},
)

+    def test_export_slice_maxsize(self):
+        class Slice(torch.nn.Module):
+            def forward(self, *args):
+                return torch.ops.aten.slice.Tensor(*args)
+
+        inp = (torch.rand((10, 3, 224, 224)), 0, 0, 9223372036854775807)
+        dynamic_shapes = (({0: Dim("dim")}, None, None, None),)
+        torch.export.export(
+            Slice(),
+            inp,
+            dynamic_shapes=dynamic_shapes,
+        )

def test_export_constraints_error(self):
class ConflictingConstraints(torch.nn.Module):
def forward(self, x):
@@ -5183,7 +5196,7 @@ def forward(self, x):
}
export(f, (inputs,), dynamic_shapes=dynamic_shapes)

-    def test_disable_forced_specializations(self):
+    def test_disable_forced_specializations_ok(self):
# check that _disable_forced_specializations and _allow_complex_guards_as_runtime_asserts flags
# both behave correctly, avoiding forced specializations and deferring to runtime.
# case 1: modulo guards
4 changes: 0 additions & 4 deletions test/onnx/test_fx_to_onnx_with_onnxruntime.py
@@ -633,10 +633,6 @@ def forward(self, x):
func, (torch.randn(3, 4),)
)

-    @pytorch_test_common.xfail_if_model_type_is_exportedprogram(
-        error_message="Unsupported FX nodes: {'call_function': ['aten._assert_async.msg']}.",
-        reason="https://github.com/pytorch/pytorch/issues/112622",
-    )
def test_operator_with_scalar_output(self):
class Foo(torch.nn.Module):
def forward(self, x, y):
11 changes: 11 additions & 0 deletions test/test_dynamic_shapes.py
@@ -386,6 +386,17 @@ def test_size_expressions(self):
self.assertTrue(str(expand_x.shape[1]), str(x.shape[0]))
self.assertTrue(str(expand_x.shape[1]), str(result.shape[0]))

+    def test_floordiv_static(self):
+        shape_env = ShapeEnv()
+        s0 = create_symint(shape_env, 8)
+        # This was extracted from
+        # python test/inductor/test_cuda_cpp_wrapper.py -k
+        # DynamicShapesCudaWrapperCudaTests.test_insignificant_strides_cuda_dynamic_shapes_cuda_wrapper
+        bool(s0 % 2 == 0)
+        bool(s0 % (s0 // 2) == 0)
+        bool(2 * (s0 // 2) == s0)
+        self.assertTrue(statically_known_true(s0 // (s0 // 2) == 2))

def test_numel(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5), shape_env)
4 changes: 3 additions & 1 deletion test/test_proxy_tensor.py
@@ -1201,7 +1201,9 @@ def f(src_tokens):
batch_size = 4
src_tokens = torch.randint(1, vocab_size, (batch_size, prompt_size))
gm = make_fx(f, tracing_mode="symbolic")(src_tokens)
-        self.assertEqual(len(gm.shape_env.guards), 0)
+        # Guards to rule out batch_size == sys.maxsize (wobbling between 2 and
+        # 1 ok)
+        self.assertEqual(len(gm.shape_env.guards), 1)

@unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
def test_cpu_scalar_cuda(self):
70 changes: 70 additions & 0 deletions test/test_sympy_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Owner(s): ["oncall: pt2"]

import itertools
import math
import sys

import sympy
@@ -19,6 +20,7 @@
from torch.utils._sympy.reference import ReferenceAnalysis, PythonReferenceAnalysis
from torch.utils._sympy.interp import sympy_interp
from torch.utils._sympy.singleton_int import SingletonInt
+from torch.utils._sympy.numbers import int_oo, IntInfinity, NegativeIntInfinity
from sympy.core.relational import is_ge, is_le, is_gt, is_lt
import functools
import torch.fx as fx
@@ -122,6 +124,74 @@ def generate_range(vals):
yield ValueRanges(a1, a2)


+class TestNumbers(TestCase):
+    def test_int_infinity(self):
+        self.assertIsInstance(int_oo, IntInfinity)
+        self.assertIsInstance(-int_oo, NegativeIntInfinity)
+        self.assertTrue(int_oo.is_integer)
+        # is tests here are for singleton-ness, don't use it for comparisons
+        # against numbers
+        self.assertIs(int_oo + int_oo, int_oo)
+        self.assertIs(int_oo + 1, int_oo)
+        self.assertIs(int_oo - 1, int_oo)
+        self.assertIs(-int_oo - 1, -int_oo)
+        self.assertIs(-int_oo + 1, -int_oo)
+        self.assertIs(-int_oo + (-int_oo), -int_oo)
+        self.assertIs(-int_oo - int_oo, -int_oo)
+        self.assertIs(1 + int_oo, int_oo)
+        self.assertIs(1 - int_oo, -int_oo)
+        self.assertIs(int_oo * int_oo, int_oo)
+        self.assertIs(2 * int_oo, int_oo)
+        self.assertIs(int_oo * 2, int_oo)
+        self.assertIs(-1 * int_oo, -int_oo)
+        self.assertIs(-int_oo * int_oo, -int_oo)
+        self.assertIs(2 * -int_oo, -int_oo)
+        self.assertIs(-int_oo * 2, -int_oo)
+        self.assertIs(-1 * -int_oo, int_oo)
+        self.assertIs(int_oo / 2, sympy.oo)
+        self.assertIs(-(-int_oo), int_oo)  # noqa: B002
+        self.assertIs(abs(int_oo), int_oo)
+        self.assertIs(abs(-int_oo), int_oo)
+        self.assertIs(int_oo ** 2, int_oo)
+        self.assertIs((-int_oo) ** 2, int_oo)
+        self.assertIs((-int_oo) ** 3, -int_oo)
+        self.assertEqual(int_oo ** -1, 0)
+        self.assertEqual((-int_oo) ** -1, 0)
+        self.assertIs(int_oo ** int_oo, int_oo)
+        self.assertTrue(int_oo == int_oo)
+        self.assertFalse(int_oo != int_oo)
+        self.assertTrue(-int_oo == -int_oo)
+        self.assertFalse(int_oo == 2)
+        self.assertTrue(int_oo != 2)
+        self.assertFalse(int_oo == sys.maxsize)
+        self.assertTrue(int_oo >= sys.maxsize)
+        self.assertTrue(int_oo >= 2)
+        self.assertTrue(int_oo >= -int_oo)
+
+    def test_relation(self):
+        self.assertIs(sympy.Add(2, int_oo), int_oo)
+        self.assertFalse(-int_oo > 2)
+
+    def test_lt_self(self):
+        self.assertFalse(int_oo < int_oo)
+        self.assertIs(min(-int_oo, -4), -int_oo)
+        self.assertIs(min(-int_oo, -int_oo), -int_oo)
+
+    def test_float_cast(self):
+        self.assertEqual(float(int_oo), math.inf)
+        self.assertEqual(float(-int_oo), -math.inf)
+
+    def test_mixed_oo_int_oo(self):
+        # Arbitrary choice
+        self.assertTrue(int_oo < sympy.oo)
+        self.assertFalse(int_oo > sympy.oo)
+        self.assertTrue(sympy.oo > int_oo)
+        self.assertFalse(sympy.oo < int_oo)
+        self.assertIs(max(int_oo, sympy.oo), sympy.oo)
+        self.assertTrue(-int_oo > -sympy.oo)
+        self.assertIs(min(-int_oo, -sympy.oo), -sympy.oo)


class TestValueRanges(TestCase):
@parametrize("fn", UNARY_OPS)
@parametrize("dtype", ("int", "float"))
9 changes: 8 additions & 1 deletion torch/_decomp/decompositions.py
@@ -734,6 +734,11 @@ def slice_forward(
    end: Optional[int] = None,
    step: int = 1,
):
+    from torch.fx.experimental.symbolic_shapes import (
+        guard_size_oblivious,
+        statically_known_true,
+    )
+
    ndim = self.dim()
    if ndim == 0:
        raise RuntimeError("slice() cannot be applied to a 0-dim tensor.")
@@ -760,7 +765,9 @@

    if end_val < start_val:
        end_val = start_val
-    elif end_val > sizes[dim]:
+    elif statically_known_true(end_val == sys.maxsize) or guard_size_oblivious(
+        end_val > sizes[dim]
+    ):
        end_val = sizes[dim]

    storage_offset = self.storage_offset() + start_val * strides[dim]
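For context on the slice_forward change above: open-ended Python slices like x[2:] reach this decomposition with end = sys.maxsize, so that value has to clamp to the dimension size without guarding a symbolic size against sys.maxsize. A small eager-mode sanity check of the semantics being preserved (a standalone sketch mirroring test_export_slice_maxsize, not part of the PR):

```python
import sys

import torch

x = torch.randn(10)
# x[2:] lowers to aten.slice with end = sys.maxsize; the decomposition
# must clamp end to the dim size rather than compare it against the size.
y = torch.ops.aten.slice.Tensor(x, 0, 2, sys.maxsize)
assert torch.equal(y, x[2:])
```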
torch/_export/passes/add_runtime_assertions_for_constraints_pass.py
@@ -10,6 +10,7 @@
import torch
import torch.fx
from torch.utils._sympy.value_ranges import ValueRanges
+from torch.utils._sympy.numbers import int_oo
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
from torch.fx.passes.infra.pass_base import PassBase, PassResult

@@ -23,9 +24,9 @@ class InputDim(NamedTuple):

def _convert_to_int(val):
    # Convert simple sympy Integers into concrete int
-    if val == sympy.oo:
+    if val in (sympy.oo, int_oo):
        return math.inf
-    if val == -sympy.oo:
+    if val in (-sympy.oo, -int_oo):
        return -math.inf
    if isinstance(val, sympy.Integer):
        return int(val)
11 changes: 6 additions & 5 deletions torch/_export/serde/serialize.py
@@ -42,6 +42,7 @@
from torch.utils import _pytree as pytree
from torch.utils._pytree import treespec_dumps, treespec_loads
from torch.utils._sympy.value_ranges import ValueRanges
+from torch.utils._sympy.numbers import int_oo

from .schema import ( # type: ignore[attr-defined]
Argument,
@@ -321,9 +322,9 @@ def deserialize_torch_artifact(serialized: Union[Dict[str, Any], Tuple[Any, ...]

def _sympy_int_to_int(val: sympy.Expr, adjust: str):
    # Convert simple sympy Integers into concrete int
-    if val == sympy.oo:
+    if val in (sympy.oo, int_oo):
        return math.inf
-    if val == -sympy.oo:
+    if val in (-sympy.oo, -int_oo):
        return -math.inf
    if isinstance(val, sympy.Integer):
        return int(val)
@@ -346,9 +347,9 @@ def _sympy_int_to_int(val: sympy.Expr, adjust: str):
def _int_to_sympy_int(val) -> sympy.Expr:
    # Convert concrete int into simple sympy Integers
    if val == math.inf:
-        return sympy.oo
+        return int_oo
    if val == -math.inf:
-        return -sympy.oo
+        return -int_oo
    return sympy.Integer(val)


@@ -1826,7 +1827,7 @@ def deserialize(
        self.symbol_name_to_range = {}
        if symbol_name_to_range:
            for k, vr in symbol_name_to_range.items():
-                lower = int(vr.lower)
+                lower = vr.lower
                if vr.upper >= 2:  # max is >= 2, not sym bool range
                    lower = max(2, lower)
                self.symbol_name_to_range[k] = symbolic_shapes.ValueRanges(_int_to_sympy_int(lower), vr.upper)
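A note on the lower = vr.lower change above: after this PR a deserialized lower bound may be -int_oo, which has no finite integer value for int() to return, while the max(2, lower) clamp still works because int_oo and -int_oo compare against ordinary ints. A quick check of that assumption (requires the post-PR numbers module):

```python
from torch.utils._sympy.numbers import int_oo

# -int_oo orders below any finite int, so Python's max() picks 2.
assert max(2, -int_oo) == 2
```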
14 changes: 9 additions & 5 deletions torch/_inductor/graph.py
@@ -42,6 +42,7 @@
SymTypes,
)
from torch.utils._mode_utils import no_dispatch
+from torch.utils._sympy.numbers import int_oo

from . import config, ir
from .codegen.common import (
@@ -1427,18 +1428,21 @@ def format_buffers():
            vr = shape_env.var_to_range[i0]
            if not shape_env._default_unspecified_value_range().issubset(vr):

-                def convert(s):
+                def is_convertible(s):
+                    if s in (int_oo, -int_oo):
+                        return False
                    try:
-                        return int(s)
+                        int(s)
+                        return True
                    except TypeError:
-                        return None
+                        return False

-                if (lower := convert(vr.lower)) is not None:
+                if is_convertible(vr.lower):
                    self.register_buffer(
                        ir.AssertScalar(i0 >= vr.lower, f"{i0} >= {vr.lower}"),
                        set_name=True,
                    )
-                if (upper := convert(vr.upper)) is not None:
+                if is_convertible(vr.upper):
                    self.register_buffer(
                        ir.AssertScalar(i0 <= vr.upper, f"{i0} <= {vr.upper}"),
                        set_name=True,
