
Commit cd6bfc7

jbschlosser and voznesenskym authored and committed
Proper view support for jagged layout NestedTensor (pytorch#113279)
This PR:

* Introduces an ATen op for creating true jagged views from a dense values buffer:
  * `_nested_view_from_jagged(values, offsets, lengths, ragged_idx, dummy)`
  * This op is implemented on the Python side using torch.library so we can return a subclass instance.
  * `jagged_from_list()` now uses this instead of the old autograd.Function `NestedViewFromBuffer`; the latter is still used for non-contiguous JTs returned via `torch.nested.narrow()`.
  * `dummy` is an awful hack to ensure that `NestedTensor.__torch_dispatch__()` is invoked for our view.
* Introduces an ATen op for accessing the `values` component of an NT via a view:
  * `_nested_get_values(nt)`
* **Removes** the autograd.Functions `ViewNestedFromBuffer` and `ViewBufferFromNested` in favor of `nested_from_values_offsets()` / `nested_from_values_offsets_lengths()` and `nt.values()`, respectively.
* Changes test code to prefer `as_nested_tensor()` over calling `jagged_from_list()` directly; similarly, avoids `buffer_from_jagged()`, preferring `values()`.
* Depends on general subclass view fake-ification on the PT2 side (handled solely in previous PRs in the stack).

With these changes, jagged layout NTs are considered true views of the underlying `values` buffer. This means views of jagged NTs are views of the underlying buffer as well, simplifying some handling.

Differential Revision: [D54269922](https://our.internmc.facebook.com/intern/diff/D54269922)

Co-authored-by: voznesenskym <voznesenskym@gmail.com>
Pull Request resolved: pytorch#113279
Approved by: https://github.com/ezyang
1 parent bde2283 commit cd6bfc7
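
As a rough illustration of the new semantics, here is a minimal sketch assuming a build that includes this PR. `nested_view_from_values_offsets` is the internal helper used throughout the diffs below, not a public API:

```python
import torch
from torch.nested._internal.nested_tensor import nested_view_from_values_offsets

values = torch.randn(10, 5)            # dense buffer: 10 total rows, last dim 5
offsets = torch.tensor([0, 3, 6, 10])  # batch of 3 components with sizes 3, 3, 4

# With this PR, the jagged NT is a true view of `values`...
nt = nested_view_from_values_offsets(values, offsets)
assert nt._base is values              # the NT's view base is the dense buffer

# ...and nt.values() is a true view as well, so mutations flow through.
nt.values().mul_(2.0)
assert torch.equal(values, nt.values())
```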

17 files changed: +542 −205 lines changed

aten/src/ATen/FunctionalInverses.cpp

Lines changed: 23 additions & 0 deletions
```diff
@@ -303,6 +303,29 @@ Tensor FunctionalInverses::_nested_view_from_buffer_inverse(const Tensor& base,
     return Tensor();
 }
 
+Tensor FunctionalInverses::_nested_view_from_jagged_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, const Tensor& offsets, const Tensor& dummy, const std::optional<Tensor>& lengths, int64_t ragged_idx) {
+    auto values = at::_nested_get_values(mutated_view);
+    if (inverse_return_mode != InverseReturnMode::NeverView) {
+      return values;
+    } else {
+      return values.clone(/*memory_format=*/at::MemoryFormat::Contiguous);
+    }
+}
+
+Tensor FunctionalInverses::_nested_get_values_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode) {
+    auto offsets = at::_nested_get_offsets(base);
+    auto lengths = at::_nested_get_lengths(base);
+    auto ragged_idx = at::_nested_get_ragged_idx(base);
+    auto dummy = at::_nested_get_jagged_dummy(base);
+    auto nt = at::_nested_view_from_jagged(mutated_view, offsets, dummy, lengths, ragged_idx);
+
+    if (inverse_return_mode != InverseReturnMode::NeverView) {
+      return nt;
+    } else {
+      return nt.clone(/*memory_format=*/at::MemoryFormat::Contiguous);
+    }
+}
+
 Tensor FunctionalInverses::unsqueeze_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t dim) {
     if (inverse_return_mode != InverseReturnMode::NeverView) {
       return at::squeeze(mutated_view, dim);
```
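
These inverses let functionalization replay mutations made on a view back onto its base: `_nested_get_values` undoes `_nested_view_from_jagged` and vice versa. A rough eager-mode illustration of the round trip, again using the internal helper from this PR:

```python
import torch
from torch.nested._internal.nested_tensor import nested_view_from_values_offsets

values = torch.randn(10, 5)
offsets = torch.tensor([0, 3, 6, 10])

nt = nested_view_from_values_offsets(values, offsets)  # dense -> jagged view
back = torch.ops.aten._nested_get_values(nt)           # jagged -> dense view

# The round trip recovers the original buffer (as a view, not a copy).
assert back.shape == values.shape
assert torch.equal(back, values)
```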

aten/src/ATen/native/native_functions.yaml

Lines changed: 46 additions & 0 deletions
```diff
@@ -6158,6 +6158,52 @@
     CompositeExplicitAutogradNonFunctional: _nested_view_from_buffer_copy
   autogen: _nested_view_from_buffer_copy.out
 
+- func: _nested_view_from_jagged(Tensor(a) self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1) -> Tensor(a)
+  variants: function
+  device_check: NoCheck
+  dispatch: {}
+
+- func: _nested_view_from_jagged_copy(Tensor self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1) -> Tensor
+  variants: function
+  device_check: NoCheck
+  tags: view_copy
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: _nested_view_from_jagged_copy
+  autogen: _nested_view_from_jagged_copy.out
+
+- func: _nested_get_values(Tensor(a) self) -> Tensor(a)
+  variants: function
+  device_check: NoCheck
+  dispatch: {}
+
+- func: _nested_get_values_copy(Tensor self) -> Tensor
+  variants: function
+  device_check: NoCheck
+  tags: view_copy
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: _nested_get_values_copy
+  autogen: _nested_get_values_copy.out
+
+- func: _nested_get_offsets(Tensor self) -> Tensor
+  variants: function
+  device_check: NoCheck
+  dispatch: {}
+
+# returns undefined Tensor if no lengths present
+- func: _nested_get_lengths(Tensor self) -> Tensor
+  variants: function
+  device_check: NoCheck
+  dispatch: {}
+
+- func: _nested_get_ragged_idx(Tensor self) -> int
+  variants: function
+  device_check: NoCheck
+  dispatch: {}
+
+- func: _nested_get_jagged_dummy(Tensor any) -> Tensor
+  category_override: dummy
+  dispatch: {}
+
 - func: _trilinear(Tensor i1, Tensor i2, Tensor i3, int[] expand1, int[] expand2, int[] expand3, int[] sumdim, int unroll_dim=1) -> Tensor
   dispatch:
     # calls unsqueeze
```
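
In these schemas, the `Tensor(a) ... -> Tensor(a)` alias annotation is what marks `_nested_view_from_jagged` and `_nested_get_values` as view ops, while the autogenerated `*_copy` variants (tagged `view_copy`) give functionalization a non-aliasing fallback. The metadata accessors can be exercised directly; a sketch assuming a build with this PR:

```python
import torch
from torch.nested._internal.nested_tensor import nested_view_from_values_offsets

values = torch.randn(10, 5)
offsets = torch.tensor([0, 3, 6, 10])
nt = nested_view_from_values_offsets(values, offsets)

# Metadata accessors declared above; per the commit message, these are
# implemented Python-side via torch.library.
assert torch.equal(torch.ops.aten._nested_get_offsets(nt), offsets)
assert torch.ops.aten._nested_get_ragged_idx(nt) == 1  # dim 1 is the ragged dim
# _nested_get_lengths(nt) is only meaningful for non-contiguous NTs produced
# by torch.nested.narrow(); per the schema comment it is undefined here.
```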

test/dynamo/test_subclasses.py

Lines changed: 108 additions & 78 deletions
```diff
@@ -19,10 +19,10 @@
     StatelessSymbolicContext,
 )
 from torch.nested._internal.nested_tensor import (
-    buffer_from_jagged,
     jagged_from_list,
     jagged_from_tensor_and_lengths,
-    ViewBufferFromNested,
+    nested_view_from_values_offsets,
+    NestedTensor,
 )
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
@@ -1273,19 +1273,20 @@ def _test_autograd(self, backend):
         a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64)
         b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64)
         c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64)
-        nt, offsets = jagged_from_list([a, b, c], None)
-        nt2, _ = jagged_from_list([a, b, c], offsets)
+        nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
+        # TODO: Switch to public API when it exists
+        nt2, _ = jagged_from_list([a, b, c], nt.offsets())
 
         def fn1(nt1, nt2):
             return (nt1 + nt2).sin().cos()
 
         compiled_f = torch.compile(fn1, fullgraph=True, backend=backend, dynamic=True)
         out = compiled_f(nt, nt2)
-        out_buffer = ViewBufferFromNested.apply(out)
+        out_buffer = out.values()
         ga, gb, gc = torch.autograd.grad(out_buffer.sum(), (a, b, c))
 
         out_ref = fn1(nt, nt2)
-        out_buffer_ref = ViewBufferFromNested.apply(out_ref)
+        out_buffer_ref = out_ref.values()
         ga_ref, gb_ref, gc_ref = torch.autograd.grad(out_buffer_ref.sum(), (a, b, c))
 
         self.assertTrue(torch.allclose(ga, ga_ref))
@@ -1325,10 +1326,10 @@ def fn(x, y):
         ret = fn_c(nt, y)[0]
         ref = fn(nt_copy, y_copy)[0]
 
-        self.assertEqual(buffer_from_jagged(ret), buffer_from_jagged(ref))
+        self.assertEqual(ret.values(), ref.values())
 
-        buffer_from_jagged(ret).sum().backward()
-        buffer_from_jagged(ref).sum().backward()
+        ret.values().sum().backward()
+        ref.values().sum().backward()
         for ref_v, res_v in zip(values_copy, values):
             self.assertEqual(ref_v.grad, res_v.grad)
 
@@ -1361,83 +1362,112 @@ def fn(x):
         self._check_recompiles(fn, (nt,), (nt3,), True)
 
     def _get_views(self):
-        # There are three cases to consider here based on the logic in
-        # meta_utils.py
-        #
-        # (1) basic case:
-        # view is not a leaf and has the same requires grad as its basic case
-        x, _ = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)
-        self.assertEqual(x.is_leaf, False)
-        yield x.unsqueeze(-1)
-
-        # (2) leaf view case:
-        # the view has to be a leaf (w/ requires_grad True or requires_grad False)
-        # base w/ requires_grad True or requires_grad False
-        for requires_grad_1, requires_grad_2 in itertools.product(
-            [True, False], repeat=2
-        ):
-            x, _ = self._get_jagged_tensor(
-                ((2, 3, 4), 3), None, requires_grad=requires_grad_1
-            )
+        # Test all cases with both an NT base and a dense base
+        # Subclass -> Subclass
+        # Dense -> Subclass
+        for base_is_nt in [False, True]:
+            # There are three cases to consider here based on the logic in
+            # meta_utils.py
+            #
+            # (1) basic case:
+            # view is not a leaf and has the same requires grad as its basic case
+            x, _ = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)
+            x = x.clone() if base_is_nt else x
+            self.assertEqual(x.is_leaf, False)
+            yield x.unsqueeze(-1)
+
+            # (2) leaf view case:
+            # the view has to be a leaf (w/ requires_grad True or requires_grad False)
+            # base w/ requires_grad True or requires_grad False
+            for requires_grad_1, requires_grad_2 in itertools.product(
+                [True, False], repeat=2
+            ):
+                x, _ = self._get_jagged_tensor(
+                    ((2, 3, 4), 3), None, requires_grad=requires_grad_1
+                )
+                x = x.clone() if base_is_nt else x
+                with torch.no_grad():
+                    x_view = x.unsqueeze(-1)
+                # The issue is this doesn't quite work
+                x_view.requires_grad_(requires_grad_2)
+                yield x_view
+
+            # (3) obscure case:
+            # view is not a leaf (implies requires_grad True)
+            # base w/ requires_grad False)
+            x, _ = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=False)
+            x = x.clone() if base_is_nt else x
+            # intermediate leaf view
             with torch.no_grad():
                 x_view = x.unsqueeze(-1)
-            # The issue is this doesn't quite work
-            x_view.requires_grad_(requires_grad_2)
-            yield x_view
-
-        # (3) obscure case:
-        # view is not a leaf (implies requires_grad True)
-        # base w/ requires_grad False)
-        x, _ = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=False)
-        # intermediate leaf view
-        with torch.no_grad():
-            x_view = x.unsqueeze(-1)
-        x_view.requires_grad_(True)
-        x_view_view = x_view.unsqueeze(-1)
-        yield x_view_view
-
-    def test_inputs_to_compiled_fn_are_views(self):
-        for nt_view in self._get_views():
+            x_view.requires_grad_(True)
+            x_view_view = x_view.unsqueeze(-1)
+            yield x_view_view
+
+        # Subclass -> Dense
+        x = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone()
+        yield x.values()
+
+        # Dense -> Subclass -> Dense -> Subclass
+        values = torch.randn(10, 5)
+        offsets = torch.tensor([0, 3, 6, 10])
+        offsets2 = offsets.clone().detach()
+        yield nested_view_from_values_offsets(
+            nested_view_from_values_offsets(values, offsets).values(), offsets
+        )
 
-            def fn(x):
-                return x.sin()
+    def _input_view_test(self, nt_view):
+        def fn(x):
+            return x.sin()
 
-            out_ref = fn(nt_view)
-            torch._dynamo.reset()
-            compile_fn = torch.compile(
-                fn, fullgraph=True, backend="aot_eager", dynamic=True
-            )
-            out = compile_fn(nt_view)
+        out_ref = fn(nt_view)
+        torch._dynamo.reset()
+        compile_fn = torch.compile(
+            fn, fullgraph=True, backend="aot_eager", dynamic=True
+        )
+        out = compile_fn(nt_view)
 
-            # Check metadata and values are correct
-            self.assertTrue(out.size() == out_ref.size())
-            self.assertTrue(out.stride() == out_ref.stride())
+        # Check metadata and values are correct
+        self.assertTrue(out.size() == out_ref.size())
+        self.assertTrue(out.stride() == out_ref.stride())
+        if out.is_nested:
             self.assertTrue(torch.allclose(out.values(), out_ref.values()))
+        else:
+            self.assertTrue(torch.allclose(out, out_ref))
 
-            # Check that no upper/lower bound guards are incurred
-            def backend(gm, args):
-                context = torch._guards.TracingContext.get()
-                guards = [str(g.expr) for g in context.fake_mode.shape_env.guards]
-                ranges = [
-                    f"{s}: [{vr.lower}, {vr.upper}]"
-                    for s, vr in context.fake_mode.shape_env.var_to_range.items()
-                ]
-                self.assertExpectedInline("\n".join(guards), """Eq(s3 - 1, s0)""")
-                self.assertExpectedInline(
-                    "\n".join(ranges),
-                    """\
-s0: [2, 9223372036854775805]
-s2: [2, 9223372036854775806]
-s3: [3, 9223372036854775806]
-s5: [2, 9223372036854775806]""",
-                )
-                return gm
+        # Check that no upper/lower bound guards are incurred
+        def backend(gm, args):
+            context = torch._guards.TracingContext.get()
+            guards = [str(g.expr) for g in context.fake_mode.shape_env.guards]
 
-            torch._dynamo.reset()
-            compile_fn = torch.compile(
-                fn, fullgraph=True, backend=backend, dynamic=True
-            )
-            out = compile_fn(nt_view)
+            # varies based on the type of view
+            guard_str = "\n".join(guards)
+            if isinstance(nt_view._base, NestedTensor):
+                self.assertExpectedInline(guard_str, """Eq(s3 - 1, s0)""")
+            else:
+                self.assertExpectedInline(guard_str, """""")
+            return gm
+
+        torch._dynamo.reset()
+        compile_fn = torch.compile(fn, fullgraph=True, backend=backend, dynamic=True)
+        out = compile_fn(nt_view)
+
+    def test_inputs_to_compiled_fn_are_views(self):
+        for nt_view in self._get_views():
+            self._input_view_test(nt_view)
+
+    # NJT1 -> Dense -> NJT2 -> Dense view
+    # During view replay, the Dense -> NJT2 part will construct an intermediate,
+    # symbolically-sized NJT that is immediately deconstructed to return the final dense
+    # view. To construct this intermediate properly, we need the associated nested int
+    # to be symbolic. This view is expected to fail compilation until symbolic nested ints
+    # are cached onto fake offsets to solve this problem.
+    @unittest.expectedFailure
    def test_subclass_dense_subclass_dense_view(self):
+        x = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone()
+        offsets2 = x.offsets().clone().detach()
+        nt_view = nested_view_from_values_offsets(x.values(), offsets2).values()
+        self._input_view_test(nt_view)
 
 
 if __name__ == "__main__":
```
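
The guard expectations in `backend` hinge on whether the compiled input is a view of a nested base or of a dense base. A quick sketch of how the two kinds of views exercised above can be distinguished, using the same internal helpers as the test:

```python
import torch
from torch.nested._internal.nested_tensor import (
    nested_view_from_values_offsets,
    NestedTensor,
)

values = torch.randn(10, 5)
offsets = torch.tensor([0, 3, 6, 10])

dense_based = nested_view_from_values_offsets(values, offsets)  # Dense -> Subclass
nt_based = dense_based.clone().unsqueeze(-1)                    # Subclass -> Subclass

assert not isinstance(dense_based._base, NestedTensor)  # base is the dense buffer
assert isinstance(nt_based._base, NestedTensor)         # base is the cloned NT
```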

test/expect/HasDecompTest.test_has_decomposition.expect

Lines changed: 10 additions & 0 deletions
```diff
@@ -437,6 +437,13 @@ aten::_nested_from_padded
 aten::_nested_from_padded.out
 aten::_nested_from_padded_and_nested_example
 aten::_nested_from_padded_and_nested_example.out
+aten::_nested_get_jagged_dummy
+aten::_nested_get_lengths
+aten::_nested_get_offsets
+aten::_nested_get_ragged_idx
+aten::_nested_get_values
+aten::_nested_get_values_copy
+aten::_nested_get_values_copy.out
 aten::_nested_select_backward
 aten::_nested_sum_backward
 aten::_nested_tensor_from_mask
@@ -454,6 +461,9 @@ aten::_nested_tensor_strides.out
 aten::_nested_view_from_buffer
 aten::_nested_view_from_buffer_copy
 aten::_nested_view_from_buffer_copy.out
+aten::_nested_view_from_jagged
+aten::_nested_view_from_jagged_copy
+aten::_nested_view_from_jagged_copy.out
 aten::_new_zeros_with_same_feature_meta
 aten::_new_zeros_with_same_feature_meta.out
 aten::_nnpack_spatial_convolution
```

test/test_autograd.py

Lines changed: 32 additions & 2 deletions
```diff
@@ -8822,7 +8822,7 @@ def _assert_match_metadata(a, b):
             self.assertEqual(a.device, b.device)
             self.assertEqual(a.dtype, b.dtype)
 
-        def _test_fn(fn, inp, *args):
+        def _test_fn(fn, inp, *args, use_unsafe_view_func=False):
             outs = fn(inp, *args)
             # handle functions that return multiple views (e.g. split)
             if isinstance(outs, torch.Tensor):
@@ -8835,7 +8835,10 @@ def _test_fn(fn, inp, *args):
             # forward view_func
             new_inp = inp.clone()
             _assert_match_metadata(new_inp, inp)
-            new_out = out._view_func(new_inp)
+            if use_unsafe_view_func:
+                new_out = out._view_func_unsafe(new_inp)
+            else:
+                new_out = out._view_func(new_inp)
             _assert_match_metadata(new_out, out)
             self.assertEqual(new_out, out)
 
@@ -8901,6 +8904,33 @@ def chain_with_only_current_view_func(x):
 
         _test_fn(chain_with_only_current_view_func, torch.randn(2, 3, 4))
 
+        # TODO: Move this somewhere else
+        # test NT views
+        from torch.nested._internal.nested_tensor import nested_view_from_values_offsets
+
+        values = torch.randn(10, 5)
+        offsets = torch.tensor([0, 3, 6, 10])
+        _test_fn(nested_view_from_values_offsets, values, offsets)
+
+        nt = nested_view_from_values_offsets(values, offsets).clone().detach()
+        _test_fn(torch.ops.aten._nested_get_values.default, nt, use_unsafe_view_func=True)
+
+        def chain_nt_to_dense_back_and_forth(nt):
+            # NJT1 -> dense -> NJT2 -> dense
+            offsets2 = nt.offsets().clone().detach()
+            return nested_view_from_values_offsets(nt.values(), offsets2).values()
+
+        _test_fn(chain_nt_to_dense_back_and_forth, nt, use_unsafe_view_func=True)
+
+        def chain_dense_to_nt_back_and_forth(values, offsets):
+            offsets2 = offsets.clone().detach()
+            # dense -> NJT1 -> dense -> NJT2
+            return nested_view_from_values_offsets(
+                nested_view_from_values_offsets(values, offsets).values(),
+                offsets2)
+
+        _test_fn(chain_dense_to_nt_back_and_forth, values, offsets, use_unsafe_view_func=True)
+
     def test_view_func_replay_with_modified_state(self):
         with torch.autograd._force_original_view_tracking(True):
             base = torch.randn(3, 4, 5)
```
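
The `_test_fn` helper round-trips each view through `_view_func` replay: given a fresh base with matching metadata, replay must reproduce the view. A condensed sketch of what that checks for the dense -> NT case (`_view_func` is the private Tensor method used by the test above):

```python
import torch
from torch.nested._internal.nested_tensor import nested_view_from_values_offsets

values = torch.randn(10, 5)
offsets = torch.tensor([0, 3, 6, 10])

# Mirror the test setup: force view tracking so the view func is recorded.
with torch.autograd._force_original_view_tracking(True):
    nt = nested_view_from_values_offsets(values, offsets)

# Replay the dense -> NT view off a fresh base with matching metadata.
new_values = values.clone()
replayed = nt._view_func(new_values)
assert torch.equal(replayed.values(), new_values)
```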
