feat: added torch.all (#355)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
k223kim and pre-commit-ci[bot] committed May 13, 2024
1 parent 8a8a6cf commit 084101e
Showing 3 changed files with 83 additions and 0 deletions.
4 changes: 4 additions & 0 deletions thunder/executors/torchex.py
@@ -487,6 +487,8 @@ def _tensor_from_sequence_prims_transform(
unsqueeze = _register_torch_operation("unsqueeze")
view = _register_torch_operation("view", module=torch.Tensor)
view_as = _register_torch_operation("view_as", module=torch.Tensor)
all_tensor = _register_torch_operation("all", like=ltorch.all_tensor)
any_tensor = _register_torch_operation("any", like=ltorch.any_tensor)


def _broadcast_in_dim_prim_transform(
@@ -605,6 +607,8 @@ def _empty_transform(
_register_implementation(ltorch.view, view, checker=_always_executable)
_register_implementation(ltorch.view_as, view_as, checker=_always_executable)
_register_implementation(ltorch.empty, empty, checker=_always_executable, execution_transform=_empty_transform)
_register_implementation(ltorch.all_tensor, all_tensor, checker=_always_executable)
_register_implementation(ltorch.any_tensor, any_tensor, checker=_always_executable)

#
# Memory format operations
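With these registrations in place, the torch executor can lower ltorch.all_tensor and ltorch.any_tensor directly to torch.all and torch.any. A minimal end-to-end sketch (not part of this commit; it assumes the usual thunder.jit entry point):

import torch
import thunder

def fn(x):
    # torch.all / torch.any trace to ltorch.all_tensor / ltorch.any_tensor,
    # which the torch executor now maps back to the native torch calls.
    return torch.all(x > 0, dim=1, keepdim=True), torch.any(x > 0)

jfn = thunder.jit(fn)
x = torch.randn(4, 4)
torch.testing.assert_close(jfn(x), fn(x))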
47 changes: 47 additions & 0 deletions thunder/tests/opinfos.py
@@ -4832,6 +4832,53 @@ def unsqueeze_sample_generator(op, device, dtype, requires_grad, **kwargs):
reduction_ops = []


def all_tensor_sample_generator(op, device, dtype, requires_grad, **kwargs):
    make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # input shape, dim, keepdim
    dim_cases = (
        ((4, 4), None, False),
        ((4, 4), None, True),
        ((2, 3), 0, True),
        ((2, 3, 4), (1, 2), False),
        ((2, 3, 4), (1, 2), True),
        ((2, 3, 4), (-1, 1), False),
        ((2, 3, 4), (-1, 1), True),
    )

    for input_shape, dim, keepdim in dim_cases:
        yield SampleInput(make(input_shape), dim, keepdim)


def all_tensor_error_generator(op, device, dtype=torch.float32, **kwargs):
    make = partial(make_tensor, device=device, dtype=dtype)
    err_msg = r"Dimension out of range \(expected to be in range of \[.*?\], but got .*\)"
    yield (
        SampleInput(make(5, 1, 2, 3), 4),
        IndexError,
        err_msg,
    )


all_tensor_opinfo = OpInfo(
    ltorch.all_tensor,
    sample_input_generator=all_tensor_sample_generator,
    error_input_generator=all_tensor_error_generator,
    torch_reference=torch.all,
)

reduction_ops.append(all_tensor_opinfo)


any_tensor_opinfo = OpInfo(
    ltorch.any_tensor,
    sample_input_generator=all_tensor_sample_generator,
    torch_reference=torch.any,
)

reduction_ops.append(any_tensor_opinfo)


# TODO: increase reduction samples and refactor amax and sum generators
def amax_amin_sample_generator(op, device, dtype, requires_grad, **kwargs):
    # For grad test stability it's better to use a wider range of values
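Each SampleInput above packs (input, dim, keepdim) positionally; the OpInfo harness calls the Thunder symbol and the torch_reference with the same arguments and compares the results. A standalone sketch (not part of the commit) of what those cases look like against plain torch.all:

import torch

# Mirror of a few (shape, dim, keepdim) cases from all_tensor_sample_generator.
dim_cases = (
    ((4, 4), None, False),
    ((2, 3), 0, True),
    ((2, 3, 4), (1, 2), False),
    ((2, 3, 4), (-1, 1), True),
)

for shape, dim, keepdim in dim_cases:
    x = torch.randn(shape)
    if dim is None:
        result = torch.all(x > 0)  # full reduction when no dim is given
    else:
        result = torch.all(x > 0, dim=dim, keepdim=keepdim)
    print(shape, dim, keepdim, tuple(result.shape), result.dtype)  # dtype is torch.bool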
32 changes: 32 additions & 0 deletions thunder/torch/__init__.py
@@ -1931,6 +1931,38 @@ def _reduction(
    return result


@torchsymbol(torch.all, is_method=True, id="torch.all")
def all_tensor(
    a: TensorLike, /, dim: None | int | Sequence[int] = None, keepdim: bool = False, *, out: None | TensorLike = None
) -> TensorLike:
    # Named all_tensor to avoid confusion with Python's built-in all function
    utils.check(out is None, lambda: "out is not None which is currently unsupported", NotImplementedError)
    result = logical_not(any_tensor(logical_not(a), dim=dim, keepdim=keepdim))

    # PyTorch's torch.all matches NumPy's behavior in returning a bool output for all supported dtypes except uint8.
    # For uint8, the output dtype is uint8 itself (https://pytorch.org/docs/stable/generated/torch.all.html).
    if a.dtype is dtypes.uint8:
        result = to(result, dtype=dtypes.uint8)
    return result


@torchsymbol(torch.any, is_method=True, id="torch.any")
def any_tensor(a: TensorLike, /, dim: None | int | Sequence[int] = None, keepdim: bool = False) -> TensorLike:
    # Named any_tensor to avoid confusion with Python's built-in any function
    a_ = clang.maybe_convert_to_dtype(a, dtypes.bool8)
    if isinstance(dim, Sequence) and len(dim) == 0:
        # PyTorch returns a_.clone() for an empty dim sequence
        result = a_ | a_
    else:
        result = ne(sum(a_, dim=dim, keepdim=keepdim), False)

    # PyTorch's torch.any matches NumPy's behavior in returning a bool output for all supported dtypes except uint8.
    # For uint8, the output dtype is uint8 itself (https://pytorch.org/docs/stable/generated/torch.any.html).
    if a.dtype is dtypes.uint8:
        return prims.convert_element_type(result, dtypes.uint8)
    return result


@torchsymbol(torch.amax, is_method=True)
def amax(a, /, dim=None, keepdim: bool = False):
    return _reduction(
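all_tensor is built from any_tensor via De Morgan's law (all(a) == not any(not a)), and any_tensor reduces by casting to bool, summing over dim, and testing the count against zero; an empty dim sequence yields a copy of the bool input. A plain-PyTorch sketch (not part of the commit) of the same identities, plus the uint8 quirk both functions special-case:

import torch

a = torch.tensor([[1.0, 0.0, 2.0], [3.0, 4.0, 5.0]])

# De Morgan composition used by all_tensor: not any(not a).
composed_all = torch.logical_not(torch.any(torch.logical_not(a), dim=1))
assert torch.equal(composed_all, torch.all(a, dim=1))

# Sum-based reduction used by any_tensor: cast to bool, sum over dim, compare against zero.
composed_any = torch.ne(torch.sum(a.to(torch.bool), dim=1), False)
assert torch.equal(composed_any, torch.any(a, dim=1))

# The uint8 special case: torch.all / torch.any return uint8 (not bool) for uint8 input.
u = torch.tensor([0, 1, 2], dtype=torch.uint8)
print(torch.all(u).dtype, torch.any(u).dtype)  # torch.uint8 torch.uint8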
