Skip to content

Commit

Permalink
Added support for torch.Tensor.unfold (#136)
Browse files Browse the repository at this point in the history
  • Loading branch information
riccardofelluga committed Apr 8, 2024
1 parent cd80d08 commit 393828f
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 0 deletions.
5 changes: 5 additions & 0 deletions thunder/clang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1116,6 +1116,11 @@ def unsqueeze(a, /, dims: int | Sequence[int]) -> TensorProxy:
return prims.broadcast_in_dim(a, shape, broadcast_dims)


@clangop()
def unfold(a: TensorProxy, /, dim: int, size: int, step: int) -> TensorProxy:
return prims.unfold(a, dim, size, step)


@clangop()
def cat(tensors: list[TensorProxy], dim: int):
"""Concatenates the given sequence of tensors in the given dimension."""
Expand Down
21 changes: 21 additions & 0 deletions thunder/core/prims.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ class PrimIDs(Enum):
SLICE = auto()
SQUEEZE = auto()
TRANSPOSE = auto()
UNFOLD = auto()
VIEW = auto()
# Memory layout prims (Experimental)
STRIDE_ORDER = auto()
Expand Down Expand Up @@ -3077,6 +3078,26 @@ def transpose_meta(a: TensorProxy, /, permutation: tuple[int, ...]) -> TensorPro

view = make_prim(PrimIDs.VIEW, "view", meta=reshape_meta, tags=(OpTags.SHAPE_OP,))


def unfold_meta(a: TensorProxy, /, dim: int, size: int, step: int) -> TensorProxy:
dim = utils.canonicalize_dim(a.ndim, dim)
max_size = 1 if a.ndim == 0 else a.shape[dim]

utils.check(
size <= max_size, lambda: f"Maximum size for tensor at dimension {dim} is {max_size} but size is {size}"
)
utils.check(size >= 0, lambda: f"Size is {size} but must be >= 0")
utils.check(step > 0, lambda: f"Step is {step} but must be > 0")

shape = list(a.shape)
shape.append(size)
shape[dim] = (shape[dim] - size) // step + 1

return TensorProxy(like=a, shape=shape)


unfold = make_prim(PrimIDs.UNFOLD, "unfold", meta=unfold_meta, tags=(OpTags.SHAPE_OP,))

#
# Memory format prims (Experimental)
#
Expand Down
3 changes: 3 additions & 0 deletions thunder/executors/torchex.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,7 @@ def _tensor_from_sequence_prims_transform(
tensor_split = _register_torch_operation("tensor_split")
transpose = _register_torch_operation("transpose")
unbind = _register_torch_operation("unbind")
unfold = _register_torch_operation("unfold", module=torch.Tensor)
unsqueeze = _register_torch_operation("unsqueeze")
view = _register_torch_operation("view", module=torch.Tensor)

Expand Down Expand Up @@ -533,6 +534,7 @@ def _squeeze_transform(a: TensorLike, /, dim: None | int | Sequence[int] = None)
_register_implementation(prims.slice_prim, slice_prim_impl, checker=_always_executable)
_register_implementation(prims.squeeze, checker=_always_executable, execution_transform=_squeeze_transform)
_register_implementation(prims.transpose, checker=_always_executable, execution_transform=_transpose_prim_transform)
_register_implementation(prims.unfold, unfold, checker=_always_executable)
_register_implementation(prims.view, view, checker=_always_executable)

_register_implementation(ltorch.cat, cat, checker=_always_executable)
Expand All @@ -553,6 +555,7 @@ def _squeeze_transform(a: TensorLike, /, dim: None | int | Sequence[int] = None)
_register_implementation(ltorch.tensor_split, tensor_split, checker=_always_executable)
_register_implementation(ltorch.transpose, transpose, checker=_always_executable)
_register_implementation(ltorch.unbind, unbind, checker=_always_executable)
_register_implementation(ltorch.unfold, unfold, checker=_always_executable)
_register_implementation(ltorch.unsqueeze, unsqueeze, checker=_always_executable)
_register_implementation(ltorch.view, view, checker=_always_executable)

Expand Down
40 changes: 40 additions & 0 deletions thunder/tests/opinfos.py
Original file line number Diff line number Diff line change
Expand Up @@ -2994,6 +2994,46 @@ def unbind_sample_generator(op, device, dtype, requires_grad, **kwargs):
shape_ops.append(unbind_opinfo)


def unfold_sample_generator(op, device, dtype, requires_grad, **kwargs):
make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

cases = (
((), 0, 1, 3),
((), -1, 0, 5),
((0,), 0, 0, 1),
((8,), 0, 2, 1),
((6, 2), 0, 2, 2),
)

for shape, dim, size, step in cases:
yield SampleInput(make(shape), dim, size, step)


def unfold_error_generator(op, device, dtype=torch.float32, **kwargs):
make = partial(make_tensor, device=device, dtype=dtype)

cases = (
((), 0, 2, 1, RuntimeError, "Maximum size for tensor at dimension 0 is 1 but size is 2"),
((0,), 0, 0, -1, RuntimeError, "Step is -1 but must be > 0"),
((8,), 1, 2, 1, IndexError, r"Dimension out of range \(expected to be in range of \[-1, 0\], but got 1\)"),
((8,), 0, -5, 1, RuntimeError, "Size is -5 but must be >= 0"),
((8,), 0, 10, 1, RuntimeError, "Maximum size for tensor at dimension 0 is 8 but size is 10"),
)

for shape, dim, size, step, err_type, err_msg in cases:
yield SampleInput(make(shape), dim, size, step), err_type, err_msg


unfold_opinfo = OpInfo(
clang.unfold,
sample_input_generator=unfold_sample_generator,
error_input_generator=unfold_error_generator,
torch_reference=torch.Tensor.unfold,
)

shape_ops.append(unfold_opinfo)


def flip_sample_generator(op, device, dtype, requires_grad, **kwargs):
make = partial(make_tensor, device=device, dtype=dtype)

Expand Down
5 changes: 5 additions & 0 deletions thunder/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,6 +956,11 @@ def unbind(a: TensorLike, /, dim: int = 0) -> tuple[TensorLike, ...]:
return tuple(s.squeeze(dim) for s in tensor_split(a, a.shape[dim], dim))


@torchsymbol(torch.Tensor.unfold, is_method=True)
def unfold(a: TensorLike, /, dim: int, size: int, step: int) -> TensorLike:
return clang.unfold(a, dim, size, step)


@torchsymbol(torch.unsqueeze, is_method=True)
def unsqueeze(a: TensorLike, /, dim: int) -> TensorLike:
return clang.unsqueeze(a, dim)
Expand Down

0 comments on commit 393828f

Please sign in to comment.