[functorch] roll : fix batching rule for scalar tensor (pytorch#95048)
kshitij12345 authored and pytorchmergebot committed Feb 19, 2023
1 parent 039b4c8 commit 06489a3
Showing 3 changed files with 18 additions and 0 deletions.
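
For orientation, a minimal sketch of the case this commit targets (illustrative snippet, not part of the commit; torch.func.vmap is functorch.vmap on older builds): vmapping torch.roll over per-example scalar tensors with only shifts given. The unsqueeze added below gives the batch rule's flatten(1) step a non-batch dimension to work with in that case.

    import torch
    from torch.func import vmap  # functorch.vmap on older builds

    x = torch.randn(2)                          # each element is a 0-d tensor under vmap
    out = vmap(lambda t: torch.roll(t, 10))(x)  # shifts only, no dims: exercises the fixed path
    ref = torch.stack([torch.roll(t, 10) for t in x.unbind(0)])  # per-example reference (a no-op on scalars)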
7 changes: 7 additions & 0 deletions aten/src/ATen/functorch/BatchRulesViews.cpp
@@ -319,7 +319,14 @@ std::tuple<Tensor, optional<int64_t>> roll_batch_rule(const Tensor& self, option
   // We will do something like: t.reshape(a, -1).roll(1, dims=[1, ]).reshape(old_shape)
   auto old_shape = self_.sizes();
   new_dims.push_back(1);
+  auto logical_rank = rankWithoutBatchDim(self, bdim);
+  if (logical_rank == 0) {
+    self_ = self_.unsqueeze(0);
+  }
+
   auto output = at::roll(self_.flatten(1), shifts, new_dims);
+  // NOTE: For scalar tensor, we don't need to unsqueeze as reshape
+  // with `old_shape` takes care of it.
   output = output.reshape(old_shape);
   return std::make_tuple(output, 0);
 }
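
For readers less familiar with the functorch helpers, a rough Python rendering of the shifts-only path above (my sketch, not from the commit), with moveBatchDimToFront/rankWithoutBatchDim replaced by plain arguments:

    import torch

    def roll_shifts_only_sketch(self_phys, shifts, logical_rank):
        # self_phys: physical tensor with the vmap batch dimension already at the front
        old_shape = self_phys.shape
        if logical_rank == 0:
            # a per-example scalar leaves only the batch dim, so give flatten(1) a dim to flatten from
            self_phys = self_phys.unsqueeze(0)
        out = torch.roll(self_phys.flatten(1), shifts, dims=1)
        # in the scalar case, reshape(old_shape) also undoes the unsqueeze (see the NOTE above)
        return out.reshape(old_shape)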
5 changes: 5 additions & 0 deletions torch/_refs/__init__.py
@@ -3282,6 +3282,11 @@ def roll(
         # Keeping this as ref for now as FakeTensor runs into some issues with complex tensors
         return clone(a)
 
+    if a.dim() == 0 and len(dims) > 0:
+        raise IndexError(
+            f"Dimension specified as {dims[0]} but tensor has no dimensions"
+        )
+
     len_shifts = len(shifts)
     len_dims = len(dims)
     if len_shifts != 1 or len_dims != 1:
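
For illustration (again not part of the commit), the new guard makes the Python reference reject the same dimensionless case that the eager op rejects, while the shifts-only case keeps working:

    import torch
    import torch._refs as refs

    t = torch.tensor(3.0)      # 0-d (scalar) tensor
    refs.roll(t, shifts=10)    # fine: no dims given, rolling a scalar is a no-op
    refs.roll(t, 0, 0)         # IndexError: Dimension specified as 0 but tensor has no dimensions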
6 changes: 6 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
@@ -5216,6 +5216,8 @@ def sample_inputs_roll(op_info, device, dtype, requires_grad=False, **kwargs):
         yield SampleInput(make_arg((0, 0, 0)), args=arg)
         yield SampleInput(make_arg((S, S, S)), args=arg)
 
+    # Scalar tensor
+    yield SampleInput(make_arg(()), args=(10, ))
 
 def error_inputs_roll(op_info, device, **kwargs):
     make_arg = partial(make_tensor, device=device, dtype=torch.float32)
@@ -5231,6 +5233,10 @@ def error_inputs_roll(op_info, device, **kwargs):
     s3 = SampleInput(make_arg((S, )), 0, 2)
     yield ErrorInput(s3, error_regex=err_msg3, error_type=IndexError)
 
+    err_msg4 = ("Dimension specified as 0")
+    s4 = SampleInput(make_arg(()), 0, 0)
+    yield ErrorInput(s4, error_regex=err_msg4, error_type=IndexError)
+
 def sample_inputs_rot90(op_info, device, dtype, requires_grad=False, **kwargs):
     make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

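As a usage sketch of how the new OpInfo entries get exercised (hypothetical test code; attribute names assumed from torch.testing._internal conventions, while the real checks live in the OpInfo-driven test suites):

    import re
    import torch
    from torch.testing._internal.common_methods_invocations import op_db

    roll_opinfo = next(op for op in op_db if op.name == "roll")
    for ei in roll_opinfo.error_inputs("cpu"):
        si = ei.sample_input
        try:
            torch.roll(si.input, *si.args, **si.kwargs)
        except ei.error_type as e:
            assert re.search(ei.error_regex, str(e))
        else:
            raise AssertionError(f"expected {ei.error_type} for {si}")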
