diff --git a/thunder/clang/__init__.py b/thunder/clang/__init__.py
index 9c592c26d..0f28c990e 100644
--- a/thunder/clang/__init__.py
+++ b/thunder/clang/__init__.py
@@ -1116,6 +1116,11 @@ def unsqueeze(a, /, dims: int | Sequence[int]) -> TensorProxy:
     return prims.broadcast_in_dim(a, shape, broadcast_dims)
 
 
+@clangop()
+def unfold(a: TensorProxy, /, dim: int, size: int, step: int) -> TensorProxy:
+    return prims.unfold(a, dim, size, step)
+
+
 @clangop()
 def cat(tensors: list[TensorProxy], dim: int):
     """Concatenates the given sequence of tensors in the given dimension."""
diff --git a/thunder/core/prims.py b/thunder/core/prims.py
index 778ce0121..ec46cd153 100644
--- a/thunder/core/prims.py
+++ b/thunder/core/prims.py
@@ -152,6 +152,7 @@ class PrimIDs(Enum):
     SLICE = auto()
     SQUEEZE = auto()
     TRANSPOSE = auto()
+    UNFOLD = auto()
     VIEW = auto()
     # Memory layout prims (Experimental)
     STRIDE_ORDER = auto()
@@ -3077,6 +3078,29 @@ def transpose_meta(a: TensorProxy, /, permutation: tuple[int, ...]) -> TensorPro
 
 view = make_prim(PrimIDs.VIEW, "view", meta=reshape_meta, tags=(OpTags.SHAPE_OP,))
 
+
+def unfold_meta(a: TensorProxy, /, dim: int, size: int, step: int) -> TensorProxy:
+    dim = utils.canonicalize_dim(a.ndim, dim)
+    max_size = 1 if a.ndim == 0 else a.shape[dim]
+
+    utils.check(
+        size <= max_size, lambda: f"Maximum size for tensor at dimension {dim} is {max_size} but size is {size}"
+    )
+    utils.check(size >= 0, lambda: f"Size is {size} but must be >= 0")
+    utils.check(step > 0, lambda: f"Step is {step} but must be > 0")
+
+    shape = list(a.shape)
+    shape.append(size)
+    # NOTE: PyTorch's unfold on a 0-dim tensor returns a tensor of shape
+    # (size,); the unfolded-dim rewrite below only applies when a.ndim > 0.
+    if a.ndim > 0:
+        shape[dim] = (shape[dim] - size) // step + 1
+
+    return TensorProxy(like=a, shape=shape)
+
+
+unfold = make_prim(PrimIDs.UNFOLD, "unfold", meta=unfold_meta, tags=(OpTags.SHAPE_OP,))
+
 #
 # Memory format prims (Experimental)
 #
diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py
index 7317b0587..eaddda043 100644
--- a/thunder/executors/torchex.py
+++ b/thunder/executors/torchex.py
@@ -463,6 +463,7 @@ def _tensor_from_sequence_prims_transform(
 tensor_split = _register_torch_operation("tensor_split")
 transpose = _register_torch_operation("transpose")
 unbind = _register_torch_operation("unbind")
+unfold = _register_torch_operation("unfold", module=torch.Tensor)
 unsqueeze = _register_torch_operation("unsqueeze")
 view = _register_torch_operation("view", module=torch.Tensor)
 
@@ -533,6 +534,7 @@ def _squeeze_transform(a: TensorLike, /, dim: None | int | Sequence[int] = None)
 _register_implementation(prims.slice_prim, slice_prim_impl, checker=_always_executable)
 _register_implementation(prims.squeeze, checker=_always_executable, execution_transform=_squeeze_transform)
 _register_implementation(prims.transpose, checker=_always_executable, execution_transform=_transpose_prim_transform)
+_register_implementation(prims.unfold, unfold, checker=_always_executable)
 _register_implementation(prims.view, view, checker=_always_executable)
 
 _register_implementation(ltorch.cat, cat, checker=_always_executable)
@@ -553,6 +555,7 @@ def _squeeze_transform(a: TensorLike, /, dim: None | int | Sequence[int] = None)
 _register_implementation(ltorch.tensor_split, tensor_split, checker=_always_executable)
 _register_implementation(ltorch.transpose, transpose, checker=_always_executable)
 _register_implementation(ltorch.unbind, unbind, checker=_always_executable)
+_register_implementation(ltorch.unfold, unfold, checker=_always_executable)
 _register_implementation(ltorch.unsqueeze, unsqueeze, checker=_always_executable)
 _register_implementation(ltorch.view, view, checker=_always_executable)
 
diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py
index 3e9c6b68c..462d98f3b 100644
--- a/thunder/tests/opinfos.py
+++ b/thunder/tests/opinfos.py
@@ -2994,6 +2994,46 @@ def unbind_sample_generator(op, device, dtype, requires_grad, **kwargs):
 shape_ops.append(unbind_opinfo)
 
 
+def unfold_sample_generator(op, device, dtype, requires_grad, **kwargs):
+    make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    cases = (
+        ((), 0, 1, 3),
+        ((), -1, 0, 5),
+        ((0,), 0, 0, 1),
+        ((8,), 0, 2, 1),
+        ((6, 2), 0, 2, 2),
+    )
+
+    for shape, dim, size, step in cases:
+        yield SampleInput(make(shape), dim, size, step)
+
+
+def unfold_error_generator(op, device, dtype=torch.float32, **kwargs):
+    make = partial(make_tensor, device=device, dtype=dtype)
+
+    cases = (
+        ((), 0, 2, 1, RuntimeError, "Maximum size for tensor at dimension 0 is 1 but size is 2"),
+        ((0,), 0, 0, -1, RuntimeError, "Step is -1 but must be > 0"),
+        ((8,), 1, 2, 1, IndexError, r"Dimension out of range \(expected to be in range of \[-1, 0\], but got 1\)"),
+        ((8,), 0, -5, 1, RuntimeError, "Size is -5 but must be >= 0"),
+        ((8,), 0, 10, 1, RuntimeError, "Maximum size for tensor at dimension 0 is 8 but size is 10"),
+    )
+
+    for shape, dim, size, step, err_type, err_msg in cases:
+        yield SampleInput(make(shape), dim, size, step), err_type, err_msg
+
+
+unfold_opinfo = OpInfo(
+    clang.unfold,
+    sample_input_generator=unfold_sample_generator,
+    error_input_generator=unfold_error_generator,
+    torch_reference=torch.Tensor.unfold,
+)
+
+shape_ops.append(unfold_opinfo)
+
+
 def flip_sample_generator(op, device, dtype, requires_grad, **kwargs):
     make = partial(make_tensor, device=device, dtype=dtype)
 
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index a408dc768..aec1f4383 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -956,6 +956,11 @@ def unbind(a: TensorLike, /, dim: int = 0) -> tuple[TensorLike, ...]:
     return tuple(s.squeeze(dim) for s in tensor_split(a, a.shape[dim], dim))
 
 
+@torchsymbol(torch.Tensor.unfold, is_method=True)
+def unfold(a: TensorLike, /, dim: int, size: int, step: int) -> TensorLike:
+    return clang.unfold(a, dim, size, step)
+
+
 @torchsymbol(torch.unsqueeze, is_method=True)
 def unsqueeze(a: TensorLike, /, dim: int) -> TensorLike:
     return clang.unsqueeze(a, dim)