From e5de550c5e511bb34453fc850148b2062cd021db Mon Sep 17 00:00:00 2001
From: k223kim
Date: Fri, 3 May 2024 17:13:50 +0900
Subject: [PATCH 1/9] feat: added torch.all

---
 thunder/clang/__init__.py |  7 +++++++
 thunder/tests/opinfos.py  | 48 +++++++++++++++++++++++++++++++++++++++
 thunder/torch/__init__.py | 23 +++++++++++++++++++
 3 files changed, 78 insertions(+)

diff --git a/thunder/clang/__init__.py b/thunder/clang/__init__.py
index d922cf013..131f33af3 100644
--- a/thunder/clang/__init__.py
+++ b/thunder/clang/__init__.py
@@ -1799,6 +1799,13 @@ def logical_and(a, b):
     return a & b
 
 
+@clangop()
+def logical_not(a):
+    if not utils.is_boolean_dtype(dtypes.to_dtype(a)):
+        return a == 0
+    return ~a
+
+
 @clangop(method_name="le")
 def le(a, b):
     return _elementwise_binary_wrapper(
diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py
index 07bb925ba..149ad3429 100644
--- a/thunder/tests/opinfos.py
+++ b/thunder/tests/opinfos.py
@@ -2607,6 +2607,54 @@ def nan_to_num_error_generator(op, device, dtype=torch.float32, **kwargs):
 conditional_and_mask_ops.append(nan_to_num_opinfo)
 
 
+def all_sample_generator(op, device, dtype, requires_grad, **kwargs):
+    make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # when dim is None
+    cases = (
+        (4, 4),
+        (2, 3),
+        (0, 2),
+        (2, 3, 4),
+    )
+
+    for input_shape in cases:
+        yield SampleInput(make(input_shape))
+
+    # input shape, dim, keepdim
+    dim_cases = (
+        ((4, 4), None, False),
+        ((4, 4), None, True),
+        ((2, 3), 0, True),
+        ((2, 3, 4), (1, 2), False),
+        ((2, 3, 4), (1, 2), True),
+        ((2, 3, 4), (-1, 1), False),
+        ((2, 3, 4), (-1, 1), True),
+    )
+
+    for input_shape, dim, keepdim in dim_cases:
+        yield SampleInput(make(input_shape), dim, keepdim)
+
+
+def all_error_generator(op, device, dtype=torch.float32, **kwargs):
+    make = partial(make_tensor, device=device, dtype=dtype)
+    err_msg = r"Dimension out of range \(expected to be in range of \[.*?\], but got .*\)"
+    yield (
+        SampleInput(make(5, 1, 2, 3), 4),
+        IndexError,
+        err_msg,
+    )
+
+
+all_opinfo = OpInfo(
+    ltorch.all,
+    sample_input_generator=all_sample_generator,
+    error_input_generator=all_error_generator,
+    torch_reference=torch.all,
+)
+conditional_and_mask_ops.append(all_opinfo)
+
+
 def clamp_sample_generator(op, device, dtype, requires_grad, **kwargs):
     cases = (
         ((5,), (5,), (5,)),
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index 1a4ed0b31..378f50557 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -1372,6 +1372,11 @@ def logical_and(a, b, /):
     return clang.logical_and(a, b)
 
 
+@torchsymbol(torch.logical_not, is_method=True)
+def logical_not(a, /):
+    return clang.logical_not(a)
+
+
 @torchsymbol(torch.le, is_method=True)
 def le(a, b, /):
     return clang.le(a, b)
@@ -1654,6 +1659,24 @@ def convert(x, if_none):
     return result
 
 
+@torchsymbol(torch.all, is_method=True)
+def all(
+    a: TensorLike, /, dim: None | int | Sequence[int] = None, keepdim: bool = False, *, out: None | TensorLike = None
+) -> TensorLike:
+    if isinstance(dim, Sequence) and len(dim) == 0:
+        # PyTorch returns a.clone()
+        result = a | a
+    else:
+        not_result = logical_not(a)
+        sum_result = sum(not_result, dim=dim, keepdim=keepdim)
+        result = ne(sum_result, False)
+        result = logical_not(result)
+
+    if a.dtype is dtypes.uint8:
+        result = to(result, dtype=dtypes.uint8)
+    return result
+
+
 #
 # Reduction operations
 #
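A quick plain-PyTorch sanity check of the identity this first version builds on — all(a) is logical_not(sum(logical_not(a)) != 0) — together with the uint8 result-dtype quirk that the final branch mirrors. This is a sketch against stock PyTorch, independent of Thunder:

    import torch

    a = torch.tensor([[True, True], [True, False]])
    expected = torch.all(a, dim=1)
    # "all elements are True" == "no element of logical_not(a) is True"
    via_sum = torch.logical_not(torch.sum(torch.logical_not(a), dim=1) != 0)
    assert torch.equal(expected, via_sum)

    # torch.all returns bool for every supported dtype except uint8,
    # where the result dtype stays uint8 (hence the trailing `to(...)` branch).
    u = torch.tensor([1, 2, 0], dtype=torch.uint8)
    assert torch.all(u).dtype == torch.uint8
    assert torch.all(u.float()).dtype == torch.bool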
From cdc59969cf84b6871069bccabcf832bfee32f43a Mon Sep 17 00:00:00 2001
From: k223kim
Date: Fri, 3 May 2024 19:10:57 +0900
Subject: [PATCH 2/9] feat: temp fix with logical_not test

---
 thunder/executors/torchex.py |  2 ++
 thunder/tests/opinfos.py     | 53 ++++++++++++++++++++++++------------
 thunder/torch/__init__.py    | 16 ++++++-----
 3 files changed, 46 insertions(+), 25 deletions(-)

diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py
index 01c39c082..ab50882c6 100644
--- a/thunder/executors/torchex.py
+++ b/thunder/executors/torchex.py
@@ -782,6 +782,7 @@ def _elementwise_unary_with_inplace_checker(a: TensorProxy, /, inplace: bool = F
 ge = _register_torch_operation("ge")
 gt = _register_torch_operation("gt")
 logical_and = _register_torch_operation("logical_and")
+logical_not = _register_torch_operation("logical_not")
 le = _register_torch_operation("le")
 lt = _register_torch_operation("lt")
 maximum = _register_torch_operation("maximum")
@@ -872,6 +873,7 @@ def _sub_transform(a: Number | TensorProxy, b: Number | TensorProxy, *, alpha: N
 _register_elementwise_binary_implementation(ltorch.ge, ge)
 _register_elementwise_binary_implementation(ltorch.gt, gt)
 _register_elementwise_binary_implementation(ltorch.logical_and, logical_and)
+_register_elementwise_binary_implementation(ltorch.logical_not, logical_not)
 _register_elementwise_binary_implementation(ltorch.le, le)
 _register_elementwise_binary_implementation(ltorch.lt, lt)
 _register_elementwise_binary_implementation(ltorch.maximum, maximum)
diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py
index 149ad3429..92d2ba6d4 100644
--- a/thunder/tests/opinfos.py
+++ b/thunder/tests/opinfos.py
@@ -561,6 +561,23 @@ def _is_cuda_torch(x: torch.Tensor) -> bool:
 tensor_properties.append(is_cuda_opinfo)
 
 
+def logical_not_sample_generator(op, device, dtype, requires_grad, **kwargs):
+    make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    cases = (
+        (4,4),
+        (5, 1, 3),
+    )
+    for input_shape in cases:
+        yield SampleInput(make(input_shape))
+
+
+logical_not_opinfo = OpInfo(
+    clang.logical_not,
+    sample_input_generator=logical_not_sample_generator,
+    torch_reference=torch.logical_not,
+)
+tensor_properties.append(logical_not_opinfo)
+
 opinfos.extend(tensor_properties)
 
 
@@ -2607,7 +2624,7 @@ def nan_to_num_error_generator(op, device, dtype=torch.float32, **kwargs):
 conditional_and_mask_ops.append(nan_to_num_opinfo)
 
 
-def all_sample_generator(op, device, dtype, requires_grad, **kwargs):
+def all_tensor_sample_generator(op, device, dtype, requires_grad, **kwargs):
     make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
 
     # when dim is None
@@ -2622,21 +2639,21 @@ def all_sample_generator(op, device, dtype, requires_grad, **kwargs):
         yield SampleInput(make(input_shape))
 
     # input shape, dim, keepdim
-    dim_cases = (
-        ((4, 4), None, False),
-        ((4, 4), None, True),
-        ((2, 3), 0, True),
-        ((2, 3, 4), (1, 2), False),
-        ((2, 3, 4), (1, 2), True),
-        ((2, 3, 4), (-1, 1), False),
-        ((2, 3, 4), (-1, 1), True),
-    )
+    # dim_cases = (
+    #     ((4, 4), None, False),
+    #     ((4, 4), None, True),
+    #     ((2, 3), 0, True),
+    #     ((2, 3, 4), (1, 2), False),
+    #     ((2, 3, 4), (1, 2), True),
+    #     ((2, 3, 4), (-1, 1), False),
+    #     ((2, 3, 4), (-1, 1), True),
+    # )
 
-    for input_shape, dim, keepdim in dim_cases:
-        yield SampleInput(make(input_shape), dim, keepdim)
+    # for input_shape, dim, keepdim in dim_cases:
+    #     yield SampleInput(make(input_shape), dim, keepdim)
 
 
-def all_error_generator(op, device, dtype=torch.float32, **kwargs):
+def all_tensor_error_generator(op, device, dtype=torch.float32, **kwargs):
     make = partial(make_tensor, device=device, dtype=dtype)
     err_msg = r"Dimension out of range \(expected to be in range of \[.*?\], but got .*\)"
     yield (
@@ -2646,13 +2663,13 @@ def all_error_generator(op, device, dtype=torch.float32, **kwargs):
     )
 
 
-all_opinfo = OpInfo(
-    ltorch.all,
-    sample_input_generator=all_sample_generator,
-    error_input_generator=all_error_generator,
+all_tensor_opinfo = OpInfo(
+    ltorch.all_tensor,
+    sample_input_generator=all_tensor_sample_generator,
+    error_input_generator=all_tensor_error_generator,
     torch_reference=torch.all,
 )
-conditional_and_mask_ops.append(all_opinfo)
+conditional_and_mask_ops.append(all_tensor_opinfo)
 
 
 def clamp_sample_generator(op, device, dtype, requires_grad, **kwargs):
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index 378f50557..8b9cc3c1e 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -1659,17 +1659,19 @@ def convert(x, if_none):
     return result
 
 
-@torchsymbol(torch.all, is_method=True)
-def all(
+@torchsymbol(torch.all, is_method=True, id="torch.all")
+def all_tensor(
     a: TensorLike, /, dim: None | int | Sequence[int] = None, keepdim: bool = False, *, out: None | TensorLike = None
-) -> TensorLike:
+) -> TensorLike | None:
+    result = logical_not(a)
     if isinstance(dim, Sequence) and len(dim) == 0:
-        # PyTorch returns a.clone()
-        result = a | a
+        # PyTorch returns result.clone()
+        result = result | result
     else:
-        not_result = logical_not(a)
-        sum_result = sum(not_result, dim=dim, keepdim=keepdim)
+        sum_result = sum(result, dim=dim, keepdim=keepdim)
         result = ne(sum_result, False)
+        if a.dtype is dtypes.uint8:
+            result = prims.convert_element_type(result, dtype=dtypes.uint8)
         result = logical_not(result)
 
     if a.dtype is dtypes.uint8:
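The dim_cases samples are commented out above while the reduction path is reworked; for reference, the dim/keepdim semantics they exercise look like this in stock PyTorch (tuple-valued dim for all/any requires PyTorch 2.x):

    import torch

    t = torch.ones(2, 3, 4, dtype=torch.bool)
    print(torch.all(t, dim=(1, 2)).shape)                # torch.Size([2])
    print(torch.all(t, dim=(1, 2), keepdim=True).shape)  # torch.Size([2, 1, 1])
    print(torch.all(t, dim=-1).shape)                    # torch.Size([2, 3])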
From 8e1feeac38dfdaacab71e7557f480cd97a5b4deb Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 3 May 2024 10:12:14 +0000
Subject: [PATCH 3/9] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 thunder/tests/opinfos.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py
index 92d2ba6d4..fb8b5d643 100644
--- a/thunder/tests/opinfos.py
+++ b/thunder/tests/opinfos.py
@@ -561,15 +561,16 @@ def _is_cuda_torch(x: torch.Tensor) -> bool:
 tensor_properties.append(is_cuda_opinfo)
 
+
 def logical_not_sample_generator(op, device, dtype, requires_grad, **kwargs):
     make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
     cases = (
-        (4,4),
+        (4, 4),
         (5, 1, 3),
     )
     for input_shape in cases:
         yield SampleInput(make(input_shape))
-    
+
 
 logical_not_opinfo = OpInfo(
     clang.logical_not,
     sample_input_generator=logical_not_sample_generator,
     torch_reference=torch.logical_not,
 )

From de89a30d75e5e375314635229cc2b24ea982a6ec Mon Sep 17 00:00:00 2001
From: Kaeun Kim
Date: Wed, 8 May 2024 22:46:21 +0900
Subject: [PATCH 4/9] feat: removed logical not

---
 thunder/clang/__init__.py    | 7 -------
 thunder/executors/torchex.py | 2 --
 thunder/torch/__init__.py    | 5 -----
 3 files changed, 14 deletions(-)

diff --git a/thunder/clang/__init__.py b/thunder/clang/__init__.py
index 131f33af3..d922cf013 100644
--- a/thunder/clang/__init__.py
+++ b/thunder/clang/__init__.py
@@ -1799,13 +1799,6 @@ def logical_and(a, b):
     return a & b
 
 
-@clangop()
-def logical_not(a):
-    if not utils.is_boolean_dtype(dtypes.to_dtype(a)):
-        return a == 0
-    return ~a
-
-
 @clangop(method_name="le")
 def le(a, b):
     return _elementwise_binary_wrapper(
diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py
index ab50882c6..01c39c082 100644
--- a/thunder/executors/torchex.py
+++ b/thunder/executors/torchex.py
@@ -782,7 +782,6 @@ def _elementwise_unary_with_inplace_checker(a: TensorProxy, /, inplace: bool = F
 ge = _register_torch_operation("ge")
 gt = _register_torch_operation("gt")
 logical_and = _register_torch_operation("logical_and")
-logical_not = _register_torch_operation("logical_not")
 le = _register_torch_operation("le")
 lt = _register_torch_operation("lt")
 maximum = _register_torch_operation("maximum")
@@ -873,7 +872,6 @@ def _sub_transform(a: Number | TensorProxy, b: Number | TensorProxy, *, alpha: N
 _register_elementwise_binary_implementation(ltorch.ge, ge)
 _register_elementwise_binary_implementation(ltorch.gt, gt)
 _register_elementwise_binary_implementation(ltorch.logical_and, logical_and)
-_register_elementwise_binary_implementation(ltorch.logical_not, logical_not)
 _register_elementwise_binary_implementation(ltorch.le, le)
 _register_elementwise_binary_implementation(ltorch.lt, lt)
 _register_elementwise_binary_implementation(ltorch.maximum, maximum)
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index 8b9cc3c1e..43ef6bc93 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -1372,11 +1372,6 @@ def logical_and(a, b, /):
     return clang.logical_and(a, b)
 
 
-@torchsymbol(torch.logical_not, is_method=True)
-def logical_not(a, /):
-    return clang.logical_not(a)
-
-
 @torchsymbol(torch.le, is_method=True)
 def le(a, b, /):
     return clang.le(a, b)

From 232eedd82b9e1e78656a4c43552944a845ada376 Mon Sep 17 00:00:00 2001
From: Kaeun Kim
Date: Wed, 8 May 2024 22:47:19 +0900
Subject: [PATCH 5/9] feat: removed logical not

---
 thunder/tests/opinfos.py | 20 --------------------
 1 file changed, 20 deletions(-)

diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py
index fb8b5d643..867d629eb 100644
--- a/thunder/tests/opinfos.py
+++ b/thunder/tests/opinfos.py
@@ -562,26 +562,6 @@ def _is_cuda_torch(x: torch.Tensor) -> bool:
 tensor_properties.append(is_cuda_opinfo)
 
 
-def logical_not_sample_generator(op, device, dtype, requires_grad, **kwargs):
-    make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
-    cases = (
-        (4, 4),
-        (5, 1, 3),
-    )
-    for input_shape in cases:
-        yield SampleInput(make(input_shape))
-
-
-logical_not_opinfo = OpInfo(
-    clang.logical_not,
-    sample_input_generator=logical_not_sample_generator,
-    torch_reference=torch.logical_not,
-)
-tensor_properties.append(logical_not_opinfo)
-
-opinfos.extend(tensor_properties)
-
-
 # NOTE: slightly different from generic _elementwise_unary_torch helper
 # because this returns the input when given an unsigned type
 @wraps(torch.abs)
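With the standalone logical_not backed out, the next patch rewrites all_tensor in terms of a new any_tensor using De Morgan's law: all(a) == logical_not(any(logical_not(a))). A minimal check of that identity in plain PyTorch:

    import torch

    a = torch.tensor([[True, False], [True, True]])
    lhs = torch.all(a, dim=1)
    # De Morgan: every element is True iff no element is False
    rhs = torch.logical_not(torch.any(torch.logical_not(a), dim=1))
    assert torch.equal(lhs, rhs)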
From 7b94503af2940d3f020df2dc2d996f62d6d96a68 Mon Sep 17 00:00:00 2001
From: Kaeun Kim
Date: Thu, 9 May 2024 22:57:44 +0900
Subject: [PATCH 6/9] feat: added torch.any and changed torch.all
 implementation

---
 thunder/tests/opinfos.py  | 111 +++++++++++++++++++++-----------------
 thunder/torch/__init__.py |  46 +++++++++-------
 2 files changed, 89 insertions(+), 68 deletions(-)

diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py
index aa1c311d3..df071f7dc 100644
--- a/thunder/tests/opinfos.py
+++ b/thunder/tests/opinfos.py
@@ -2636,54 +2636,6 @@ def nan_to_num_error_generator(op, device, dtype=torch.float32, **kwargs):
 conditional_and_mask_ops.append(nan_to_num_opinfo)
 
 
-def all_tensor_sample_generator(op, device, dtype, requires_grad, **kwargs):
-    make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
-
-    # when dim is None
-    cases = (
-        (4, 4),
-        (2, 3),
-        (0, 2),
-        (2, 3, 4),
-    )
-
-    for input_shape in cases:
-        yield SampleInput(make(input_shape))
-
-    # input shape, dim, keepdim
-    # dim_cases = (
-    #     ((4, 4), None, False),
-    #     ((4, 4), None, True),
-    #     ((2, 3), 0, True),
-    #     ((2, 3, 4), (1, 2), False),
-    #     ((2, 3, 4), (1, 2), True),
-    #     ((2, 3, 4), (-1, 1), False),
-    #     ((2, 3, 4), (-1, 1), True),
-    # )
-
-    # for input_shape, dim, keepdim in dim_cases:
-    #     yield SampleInput(make(input_shape), dim, keepdim)
-
-
-def all_tensor_error_generator(op, device, dtype=torch.float32, **kwargs):
-    make = partial(make_tensor, device=device, dtype=dtype)
-    err_msg = r"Dimension out of range \(expected to be in range of \[.*?\], but got .*\)"
-    yield (
-        SampleInput(make(5, 1, 2, 3), 4),
-        IndexError,
-        err_msg,
-    )
-
-
-all_tensor_opinfo = OpInfo(
-    ltorch.all_tensor,
-    sample_input_generator=all_tensor_sample_generator,
-    error_input_generator=all_tensor_error_generator,
-    torch_reference=torch.all,
-)
-conditional_and_mask_ops.append(all_tensor_opinfo)
-
-
 def clamp_sample_generator(op, device, dtype, requires_grad, **kwargs):
     cases = (
         ((5,), (5,), (5,)),
@@ -4689,6 +4641,69 @@ def unsqueeze_sample_generator(op, device, dtype, requires_grad, **kwargs):
 reduction_ops = []
 
 
+def all_tensor_sample_generator(op, device, dtype, requires_grad, **kwargs):
+    make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # input shape, dim, keepdim
+    dim_cases = (
+        ((4, 4), None, False),
+        ((4, 4), None, True),
+        ((2, 3), 0, True),
+        ((2, 3, 4), (1, 2), False),
+        ((2, 3, 4), (1, 2), True),
+        ((2, 3, 4), (-1, 1), False),
+        ((2, 3, 4), (-1, 1), True),
+    )
+
+    for input_shape, dim, keepdim in dim_cases:
+        yield SampleInput(make(input_shape), dim, keepdim)
+
+
+def all_tensor_error_generator(op, device, dtype=torch.float32, **kwargs):
+    make = partial(make_tensor, device=device, dtype=dtype)
+    err_msg = r"Dimension out of range \(expected to be in range of \[.*?\], but got .*\)"
+    yield (
+        SampleInput(make(5, 1, 2, 3), 4),
+        IndexError,
+        err_msg,
+    )
+
+
+all_tensor_opinfo = OpInfo(
+    ltorch.all_tensor,
+    sample_input_generator=all_tensor_sample_generator,
+    error_input_generator=all_tensor_error_generator,
+    torch_reference=torch.all,
+)
+
+reduction_ops.append(all_tensor_opinfo)
+
+
+def any_tensor_sample_generator(op, device, dtype, requires_grad, **kwargs):
+    make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # input shape, dim, keepdim
+    dim_cases = (
+        ((4, 4), None, False),
+        ((4, 4), None, True),
+        ((2, 3), 0, True),
+        ((2, 3, 4), (1, 2), False),
+        ((2, 3, 4), (1, 2), True),
+        ((2, 3, 4), (-1, 1), False),
+        ((2, 3, 4), (-1, 1), True),
+    )
+
+    for input_shape, dim, keepdim in dim_cases:
+        yield SampleInput(make(input_shape), dim, keepdim)
+
+
+any_tensor_opinfo = OpInfo(
+    ltorch.any_tensor,
+    sample_input_generator=any_tensor_sample_generator,
+    torch_reference=torch.any,
+)
+
+
 # TODO: increase reduction samples and refacort amax and sum generators
 def amax_amin_sample_generator(op, device, dtype, requires_grad, **kwargs):
     # For grad test stability it's better to use wider range of values
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index 5c8573b5e..a88f0575a 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -1679,26 +1679,6 @@ def convert(x, if_none):
     return result
 
 
-@torchsymbol(torch.all, is_method=True, id="torch.all")
-def all_tensor(
-    a: TensorLike, /, dim: None | int | Sequence[int] = None, keepdim: bool = False, *, out: None | TensorLike = None
-) -> TensorLike | None:
-    result = logical_not(a)
-    if isinstance(dim, Sequence) and len(dim) == 0:
-        # PyTorch returns result.clone()
-        result = result | result
-    else:
-        sum_result = sum(result, dim=dim, keepdim=keepdim)
-        result = ne(sum_result, False)
-        if a.dtype is dtypes.uint8:
-            result = prims.convert_element_type(result, dtype=dtypes.uint8)
-        result = logical_not(result)
-
-    if a.dtype is dtypes.uint8:
-        result = to(result, dtype=dtypes.uint8)
-    return result
-
-
 #
 # Reduction operations
 #
@@ -1804,6 +1784,32 @@ def _reduction(
     return result
 
 
+@torchsymbol(torch.all, is_method=True, id="torch.all")
+def all_tensor(
+    a: TensorLike, /, dim: None | int | Sequence[int] = None, keepdim: bool = False, *, out: None | TensorLike = None
+) -> TensorLike | None:
+    utils.check(out is None, lambda: "out is not None which is currently unsupported", NotImplementedError)
+    result = logical_not(a)
+    result = logical_not(any_tensor(logical_not(a), dim=dim, keepdim=keepdim))
+
+    if a.dtype is dtypes.uint8:
+        result = to(result, dtype=dtypes.uint8)
+    return result
+
+
+@torchsymbol(torch.any, is_method=True, id="torch.any")
+def any_tensor(a: TensorLike, /, dim: None | int | Sequence[int] = None, keepdim: bool = False):
+    a_ = clang.maybe_convert_to_dtype(a, dtypes.bool8)
+    if isinstance(dim, Sequence) and len(dim) == 0:
+        # PyTorch returns a_.clone()
+        result = a_ | a_
+    else:
+        result = ne(sum(a_, dim=dim, keepdim=keepdim), False)
+    if a.dtype is dtypes.uint8:
+        return prims.convert_element_type(result, dtypes.uint8)
+    return result
+
+
 @torchsymbol(torch.amax, is_method=True)
 def amax(a, /, dim=None, keepdim: bool = False):
     return _reduction(
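The any_tensor lowering above reduces with sum over a bool view of the input and compares against zero. A plain-PyTorch check of that equivalence, including the uint8-in/uint8-out behavior both ops preserve:

    import torch

    x = torch.tensor([[0, 0, 3], [0, 0, 0]], dtype=torch.uint8)
    expected = torch.any(x, dim=1)
    # any(x) over a dim == "at least one element of the bool view is True"
    via_sum = torch.sum(x.bool(), dim=1) != 0
    assert torch.equal(expected.bool(), via_sum)
    assert expected.dtype == torch.uint8  # uint8 in, uint8 out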
From 1757f16db99dca56683e212b229bcf278d216211 Mon Sep 17 00:00:00 2001
From: k223kim
Date: Fri, 10 May 2024 11:13:07 +0900
Subject: [PATCH 7/9] feat: updated torch.all with comments and registered
 torch operations

---
 thunder/executors/torchex.py |  4 ++++
 thunder/torch/__init__.py    | 12 +++++++---
 2 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py
index a2e46a512..19973aa97 100644
--- a/thunder/executors/torchex.py
+++ b/thunder/executors/torchex.py
@@ -473,6 +473,8 @@ def _tensor_from_sequence_prims_transform(
 unfold = _register_torch_operation("unfold", module=torch.Tensor)
 unsqueeze = _register_torch_operation("unsqueeze")
 view = _register_torch_operation("view", module=torch.Tensor)
+all_tensor = _register_torch_operation("all", like=ltorch.all_tensor)
+any_tensor = _register_torch_operation("any", like=ltorch.any_tensor)
 
 
 def _broadcast_in_dim_prim_transform(
@@ -565,6 +567,8 @@ def _squeeze_transform(a: TensorLike, /, dim: None | int | Sequence[int] = None)
 _register_implementation(ltorch.unfold, unfold, checker=_always_executable)
 _register_implementation(ltorch.unsqueeze, unsqueeze, checker=_always_executable)
 _register_implementation(ltorch.view, view, checker=_always_executable)
+_register_implementation(ltorch.all_tensor, all_tensor, checker=_always_executable)
+_register_implementation(ltorch.any_tensor, any_tensor, checker=_always_executable)
 
 #
 # Memory format operations
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index 5b2d805f9..ba5054471 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -1789,24 +1789,30 @@ def _reduction(
 @torchsymbol(torch.all, is_method=True, id="torch.all")
 def all_tensor(
     a: TensorLike, /, dim: None | int | Sequence[int] = None, keepdim: bool = False, *, out: None | TensorLike = None
-) -> TensorLike | None:
+) -> TensorLike:
+    # named as all_tensor to avoid confusion with Python's built-in all function
     utils.check(out is None, lambda: "out is not None which is currently unsupported", NotImplementedError)
-    result = logical_not(a)
     result = logical_not(any_tensor(logical_not(a), dim=dim, keepdim=keepdim))
 
+    # PyTorch's torch.all matches the behavior of NumPy in returning output of dtype bool for all supported dtypes except uint8.
+    # For uint8 the dtype of output is uint8 itself (https://pytorch.org/docs/stable/generated/torch.all.html)
     if a.dtype is dtypes.uint8:
         result = to(result, dtype=dtypes.uint8)
     return result
 
 
 @torchsymbol(torch.any, is_method=True, id="torch.any")
-def any_tensor(a: TensorLike, /, dim: None | int | Sequence[int] = None, keepdim: bool = False):
+def any_tensor(a: TensorLike, /, dim: None | int | Sequence[int] = None, keepdim: bool = False) -> TensorLike:
+    # named as any_tensor to avoid confusion with Python's built-in any function
     a_ = clang.maybe_convert_to_dtype(a, dtypes.bool8)
     if isinstance(dim, Sequence) and len(dim) == 0:
         # PyTorch returns a_.clone()
         result = a_ | a_
     else:
         result = ne(sum(a_, dim=dim, keepdim=keepdim), False)
+
+    # PyTorch's torch.any matches the behavior of NumPy in returning output of dtype bool for all supported dtypes except uint8.
+    # For uint8 the dtype of output is uint8 itself (https://pytorch.org/docs/stable/generated/torch.any.html)
     if a.dtype is dtypes.uint8:
         return prims.convert_element_type(result, dtypes.uint8)
     return result
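With the torchex registrations above, a jitted function can map torch.all/torch.any straight onto the eager PyTorch ops. A usage sketch, assuming a working lightning-thunder install (thunder.jit as its public entry point):

    import torch
    import thunder

    def fn(a):
        return torch.all(a, dim=1), torch.any(a, dim=1)

    jfn = thunder.jit(fn)  # compile; torch.all/any lower via the torch executor
    a = torch.tensor([[True, False], [True, True]])
    eager, jitted = fn(a), jfn(a)
    assert all(torch.equal(e, j) for e, j in zip(eager, jitted))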
From e51cdba9bc331d1c7b49a03f38aa73e2444c4f7a Mon Sep 17 00:00:00 2001
From: k223kim
Date: Fri, 10 May 2024 11:15:11 +0900
Subject: [PATCH 8/9] feat: update test set

---
 thunder/tests/opinfos.py | 20 +------------------
 1 file changed, 1 insertion(+), 19 deletions(-)

diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py
index df071f7dc..e90a7fc4c 100644
--- a/thunder/tests/opinfos.py
+++ b/thunder/tests/opinfos.py
@@ -4679,27 +4679,9 @@ def all_tensor_error_generator(op, device, dtype=torch.float32, **kwargs):
 reduction_ops.append(all_tensor_opinfo)
 
 
-def any_tensor_sample_generator(op, device, dtype, requires_grad, **kwargs):
-    make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
-
-    # input shape, dim, keepdim
-    dim_cases = (
-        ((4, 4), None, False),
-        ((4, 4), None, True),
-        ((2, 3), 0, True),
-        ((2, 3, 4), (1, 2), False),
-        ((2, 3, 4), (1, 2), True),
-        ((2, 3, 4), (-1, 1), False),
-        ((2, 3, 4), (-1, 1), True),
-    )
-
-    for input_shape, dim, keepdim in dim_cases:
-        yield SampleInput(make(input_shape), dim, keepdim)
-
-
 any_tensor_opinfo = OpInfo(
     ltorch.any_tensor,
-    sample_input_generator=any_tensor_sample_generator,
+    sample_input_generator=all_tensor_sample_generator,
     torch_reference=torch.any,
 )
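Reusing all_tensor_sample_generator for the any opinfo works because torch.all and torch.any take the same (input, dim, keepdim) arguments and produce identically shaped results. A quick illustration in plain PyTorch:

    import torch

    for dim, keepdim in ((0, False), (0, True), ((1, 2), False)):
        t = torch.rand(2, 3, 4) > 0.5
        # same samples are valid for both reductions
        assert torch.all(t, dim=dim, keepdim=keepdim).shape == torch.any(t, dim=dim, keepdim=keepdim).shape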
From f8d447d2a8bc11675d7213196c08ebe72dd33b0b Mon Sep 17 00:00:00 2001
From: Kaeun Kim
Date: Tue, 14 May 2024 01:26:24 +0900
Subject: [PATCH 9/9] feat: added to reduction_ops

---
 thunder/tests/opinfos.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py
index e70ee43f0..afea779fe 100644
--- a/thunder/tests/opinfos.py
+++ b/thunder/tests/opinfos.py
@@ -4876,6 +4876,8 @@ def all_tensor_error_generator(op, device, dtype=torch.float32, **kwargs):
     torch_reference=torch.any,
 )
 
+reduction_ops.append(any_tensor_opinfo)
+
 
 # TODO: increase reduction samples and refacort amax and sum generators
 def amax_amin_sample_generator(op, device, dtype, requires_grad, **kwargs):
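Net behavior the series implements, sketched against stock PyTorch for reference — full reduction, dimension-wise reduction with the uint8 quirk, and the out-of-range error that the opinfo's error generator matches:

    import torch

    a = torch.tensor([[1, 1], [1, 0]], dtype=torch.uint8)
    print(torch.all(a))         # tensor(0, dtype=torch.uint8)
    print(torch.any(a, dim=0))  # tensor([1, 1], dtype=torch.uint8)

    try:
        torch.all(torch.ones(5, 1, 2, 3), dim=4)
    except IndexError as e:
        # Dimension out of range (expected to be in range of [-4, 3], but got 4)
        print(e)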