Support for as_nested_tensor() with jagged layout + fixed nested_tensor() semantics (pytorch#112304)

This PR:
* Adds support for the `layout` kwarg to `torch.nested.as_nested_tensor()`
* Fixes `torch.nested.nested_tensor()`
    * It should accept a list of lists of scalars
    * It should not preserve autograd history
* Adds extensive testing for these two functions

Semantics for the two functions follow those of the strided layout (see the usage sketch after this list):
* `torch.nested.nested_tensor(tensor_list, layout=torch.jagged)`: Creates a new jagged layout NT **with no autograd history**
    * `tensor_list` can be a list of Tensors or list of lists of scalars
* `torch.nested.as_nested_tensor(tensor_list, layout=torch.jagged)`: Creates a new jagged layout NT **preserving autograd history of `tensor_list`**
    * `tensor_list` must be a list of Tensors
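
A minimal usage sketch of the semantics above (illustrative only; assumes a build that includes this change):

```python
import torch

# nested_tensor(): accepts a list of Tensors or a list of lists of scalars;
# the result carries no autograd history back to the inputs.
nt = torch.nested.nested_tensor(
    [[0.0, 1.0], [2.0, 3.0, 4.0]],  # list of lists of scalars
    layout=torch.jagged,
)

# as_nested_tensor(): Tensors only; autograd history of the inputs is preserved.
a = torch.randn(2, 5, requires_grad=True)
b = torch.randn(3, 5, requires_grad=True)
nt2 = torch.nested.as_nested_tensor([a, b], layout=torch.jagged)
```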
Pull Request resolved: pytorch#112304
Approved by: https://github.com/cpuhrsch, https://github.com/soulitzer
jbschlosser authored and andreigh committed Nov 19, 2023
1 parent cdca441 commit fc3ae96
Showing 2 changed files with 109 additions and 30 deletions.
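
For reference, a rough sketch of the autograd-history distinction exercised by the new tests below (not part of the diff):

```python
import torch

a = torch.randn(2, 5, requires_grad=True)
b = torch.randn(3, 5, requires_grad=True)

# nested_tensor() detaches its inputs: no grads flow back to a / b.
nt = torch.nested.nested_tensor([a, b], layout=torch.jagged, requires_grad=True)
(nt * 2).backward(torch.ones_like(nt))
assert a.grad is None and b.grad is None

# as_nested_tensor() preserves history: grads flow back to a / b.
nt2 = torch.nested.as_nested_tensor([a, b], layout=torch.jagged)
(nt2 * 2).backward(torch.ones_like(nt2))
assert torch.equal(a.grad, torch.ones_like(a) * 2)
```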
113 changes: 86 additions & 27 deletions test/test_nestedtensor.py
@@ -2843,25 +2843,45 @@ def grad_test_func(a, b, c):
# We can probably parametrize existing tests instead of having a separate
# test class as we begin to support more ops. Also maybe rewrite with OpInfos.
class TestNestedTensorSubclass(TestCase):
def _get_example_tensor_lists(self):
return [
def _get_example_tensor_lists(self, include_list_of_lists=True, include_requires_grad=True):

def _make_tensor(*shape, include_requires_grad=include_requires_grad, requires_grad=True):
return torch.randn(
*shape,
requires_grad=(requires_grad if include_requires_grad else False)
)

# Purposefully introduce mixed requires_grad settings for the components
# when include_requires_grad=True.
example_lists = [
# (B, *, D) with B=4
[
torch.randn(2, 5),
torch.randn(3, 5),
torch.randn(4, 5),
torch.randn(6, 5)
_make_tensor(2, 5),
_make_tensor(3, 5, requires_grad=False),
_make_tensor(4, 5, requires_grad=False),
_make_tensor(6, 5)
],
# (B, *, D_0, D_1) with B=5
[
torch.randn(2, 5, 6),
torch.randn(3, 5, 6),
torch.randn(4, 5, 6),
torch.randn(5, 5, 6),
torch.randn(6, 5, 6),
_make_tensor(2, 5, 6),
_make_tensor(3, 5, 6),
_make_tensor(4, 5, 6, requires_grad=False),
_make_tensor(5, 5, 6),
_make_tensor(6, 5, 6),
],
]

if include_list_of_lists:
example_lists.append(
# (B, *, D) with B=3 in list form
[
_make_tensor(2, 5, requires_grad=False).tolist(),
_make_tensor(3, 5).tolist(),
_make_tensor(4, 5).tolist(),
])

return example_lists

def test_tensor_attributes(self, device):
a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
@@ -2975,30 +2995,69 @@ def test_pin_memory(self, device):
self.assertIs(pinned, pinned.pin_memory())
self.assertEqual(pinned.data_ptr(), pinned.pin_memory().data_ptr())

def _validate_nt(self, nt, tensor_list, device, dtype, requires_grad):
# Validate a bunch of properties after NT construction.
device = torch.device(device)
first_t = torch.as_tensor(tensor_list[0])
expected_dim = first_t.dim() + 1
batch_size = len(tensor_list)
self.assertEqual(nt.dim(), expected_dim)
self.assertEqual(nt.device, device)
self.assertEqual(nt.dtype, dtype)
self.assertEqual(nt.layout, torch.jagged)
self.assertEqual(nt.requires_grad, requires_grad)
self.assertEqual(nt.values().device, device)
self.assertEqual(nt.offsets().device, device)
self.assertEqual(nt.shape[0], batch_size)
self.assertTrue(isinstance(nt.shape[1], torch.SymInt))
self.assertEqual(nt.shape[2:], first_t.shape[1:])

@dtypes(torch.float, torch.double, torch.half)
@parametrize("requires_grad", [False, True])
def test_jagged_layout_construction(self, device, dtype, requires_grad):
for tensor_list in self._get_example_tensor_lists():
@parametrize("components_require_grad", [False, True])
def test_jagged_layout_construction_nested_tensor(
self, device, dtype, requires_grad, components_require_grad):
for tensor_list in self._get_example_tensor_lists(
include_list_of_lists=True, include_requires_grad=components_require_grad):
nt = torch.nested.nested_tensor(
tensor_list,
device=device,
dtype=dtype,
layout=torch.jagged,
requires_grad=requires_grad)
self._validate_nt(nt, tensor_list, device, dtype, requires_grad)

device = torch.device(device)
expected_dim = tensor_list[0].dim() + 1
batch_size = len(tensor_list)
self.assertEqual(nt.dim(), expected_dim)
self.assertEqual(nt.device, device)
self.assertEqual(nt.dtype, dtype)
self.assertEqual(nt.layout, torch.jagged)
self.assertEqual(nt.requires_grad, requires_grad)
self.assertEqual(nt.values().device, device)
self.assertEqual(nt.offsets().device, device)
self.assertEqual(nt.shape[0], batch_size)
self.assertTrue(isinstance(nt.shape[1], torch.SymInt))
self.assertEqual(nt.shape[2:], tensor_list[0].shape[1:])
# Make sure grads -don't- flow back into original tensors for nested_tensor()
if requires_grad:
(nt * 2).backward(torch.ones_like(nt))
for t in tensor_list:
t = t if isinstance(t, torch.Tensor) else torch.as_tensor(t)
self.assertTrue(t.grad is None)

@dtypes(torch.float, torch.double, torch.half)
@parametrize("components_require_grad", [False, True])
def test_jagged_layout_construction_as_nested_tensor(
self, device, dtype, components_require_grad):
# NB: as_nested_tensor(tensor_list) doesn't support lists of lists for tensor_list
for tensor_list in self._get_example_tensor_lists(
include_list_of_lists=False, include_requires_grad=components_require_grad):
nt = torch.nested.as_nested_tensor(
tensor_list,
device=device,
dtype=dtype,
layout=torch.jagged)

# nt.requires_grad=True should be set if at least one component requires grad
self._validate_nt(nt, tensor_list, device, dtype, components_require_grad)

# Make sure grads flow back into original tensors for as_nested_tensor()
if components_require_grad:
(nt * 2).backward(torch.ones_like(nt))
for t in tensor_list:
if t.requires_grad:
self.assertEqual(t.grad, torch.ones_like(t) * 2)
else:
self.assertTrue(t.grad is None)

@unittest.skipIf(PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property")
@onlyCUDA
@@ -3010,7 +3069,7 @@ def test_jagged_layout_construction_with_pinned_memory(self, device):
device="cpu",
pin_memory=True)

self.assertEqual(nt.device, torch.device('cpu'))
self._validate_nt(nt, tensor_list, "cpu", torch.float32, requires_grad=False)
self.assertTrue(nt.is_pinned())

@dtypes(torch.double, torch.half)
26 changes: 23 additions & 3 deletions torch/nested/__init__.py
@@ -19,6 +19,7 @@ def as_nested_tensor(
tensor_list: List[Tensor],
dtype: Optional[DType] = None,
device: Optional[Device] = None,
layout=None
) -> Tensor:
r"""
Constructs a nested tensor preserving autograd history from :attr:`tensor_list`, a list of tensors.
@@ -34,6 +35,8 @@ def as_nested_tensor(
Default: if None, same :class:`torch.dtype` as leftmost tensor in the list.
device (:class:`torch.device`, optional): the desired device of returned nested tensor.
Default: if None, same :class:`torch.device` as leftmost tensor in the list
layout (:class:`torch.layout`, optional): the desired layout of returned nested tensor.
Only strided and jagged layouts are supported. Default: if None, the strided layout.
Example::
@@ -53,9 +56,20 @@ def as_nested_tensor(
not isinstance(t, Tensor) for t in tensor_list
):
raise TypeError(
"nested_tensor(): Expected first argument to be a list of tensors "
"as_nested_tensor(): Expected first argument to be a list of tensors "
)
return torch._nested_tensor_from_tensor_list(tensor_list, dtype, None, device, None)

if layout is None:
layout = torch.strided
if layout == torch.strided:
return torch._nested_tensor_from_tensor_list(tensor_list, dtype, None, device, None)
elif layout == torch.jagged:
from torch.nested._internal.nested_tensor import jagged_from_list

nt, _ = jagged_from_list(tensor_list, offsets=None, device=device, dtype=dtype)
return nt
else:
raise RuntimeError(f"Specified layout is unsupported for nested tensors: {layout}")


# Note: This not only adds doc strings for the nested ops, but
@@ -155,9 +169,15 @@ def nested_tensor(tensor_list, *, dtype=None, layout=None, device=None, requires
requires_grad=requires_grad,
pin_memory=pin_memory)
elif layout == torch.jagged:
# Need to:
# * Detach tensors to discard autograd history
# * Wrap lists of scalars as tensors
list_of_tensors = [t.detach() if isinstance(t, Tensor) else torch.as_tensor(t)
for t in tensor_list]

from torch.nested._internal.nested_tensor import jagged_from_list

nt, _ = jagged_from_list(tensor_list, offsets=None, device=device, dtype=dtype)
nt, _ = jagged_from_list(list_of_tensors, offsets=None, device=device, dtype=dtype)

nt.requires_grad_(requires_grad)
if pin_memory:
