Add Half support for cummax, cummin, cumprod, logcumsumexp, and prod on CPU (pytorch#112132)

Add Half support for cummax, cummin, cumprod, logcumsumexp, and prod on CPU.

Pull Request resolved: pytorch#112132
Approved by: https://github.com/cpuhrsch
CaoE authored and Skylion007 committed Nov 14, 2023
1 parent 9b1481a commit 4da8534
Showing 7 changed files with 44 additions and 23 deletions.
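Taken together, the changes below let these five reductions run on half-precision CPU tensors. A minimal sketch of the user-facing effect (assuming a build that includes this commit; older CPU builds reject these calls with errors such as "prod_cpu" not implemented for 'Half', quoted in the FIXME removed in _masked.py below):

import torch

x = torch.rand(4, 5, dtype=torch.half)          # plain CPU tensor
values, idx = torch.cummax(x, dim=1)            # previously Half was CUDA-only here
values, idx = torch.cummin(x, dim=1)
c = torch.cumprod(x, dim=1)
l = torch.logcumsumexp(x, dim=1)
p = torch.prod(x, dim=1)
print(values.dtype, c.dtype, l.dtype, p.dtype)  # all torch.float16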
aten/src/ATen/native/ReduceOps.cpp (4 changes: 2 additions & 2 deletions)
@@ -794,7 +794,7 @@ void cummax_cummin_helper(const T1* self_data, T1* values_data, T2* indices_data
}

void cummax_helper_cpu(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim) {
- AT_DISPATCH_ALL_TYPES_AND2(kBool, kBFloat16,
+ AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf,
self.scalar_type(), "cummax_cpu",
[&] {
at::native::tensor_dim_apply3<scalar_t, int64_t>(self, values, indices, dim, cummax_cummin_helper<scalar_t, int64_t, std::greater_equal<scalar_t>>);
@@ -829,7 +829,7 @@ std::tuple<Tensor, Tensor> cummax(const Tensor& self, int64_t dim) {
}

void cummin_helper_cpu(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim) {
- AT_DISPATCH_ALL_TYPES_AND2(kBool, kBFloat16,
+ AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf,
self.scalar_type(), "cummin_cpu",
[&] {
at::native::tensor_dim_apply3<scalar_t, int64_t>(self, values, indices, dim, cummax_cummin_helper<scalar_t, int64_t, std::less_equal<scalar_t>>);
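The only change here is the dispatch macro: AT_DISPATCH_ALL_TYPES_AND3 adds kHalf to the compile-time dtype list, so cummax_cummin_helper is now also instantiated with scalar_t = at::Half; the comparison logic itself is untouched.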
aten/src/ATen/native/cpu/ReduceOpsKernel.cpp (6 changes: 3 additions & 3 deletions)
@@ -100,7 +100,7 @@ static void cumprod_cpu_kernel(const Tensor& result, const Tensor& self, int64_t
auto wrap_dim = maybe_wrap_dim(dim, self.dim());
int64_t self_dim_size = ensure_nonempty_size(self, wrap_dim);

- AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, self.scalar_type(), "cumprod_out_cpu", [&] {
+ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, self.scalar_type(), "cumprod_out_cpu", [&] {
cpu_cum_base_kernel<scalar_t>(result, self, wrap_dim, [&] (
scalar_t* result_data, auto result_dim_stride,
const scalar_t* self_data, auto self_dim_stride, scalar_t init_val) {
@@ -119,7 +119,7 @@ static void logcumsumexp_cpu_kernel(Tensor& result, const Tensor& self, int64_t
auto wrap_dim = maybe_wrap_dim(dim, self.dim());
int64_t self_dim_size = ensure_nonempty_size(self, wrap_dim);

- AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, self.scalar_type(), "logcumsumexp_out_cpu", [&] {
+ AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, self.scalar_type(), "logcumsumexp_out_cpu", [&] {
cpu_cum_base_kernel<scalar_t>(result, self, wrap_dim, [&] (
scalar_t* result_data, auto result_dim_stride,
const scalar_t* self_data, auto self_dim_stride, scalar_t init_val) {
@@ -176,7 +176,7 @@ static void prod_kernel_impl(TensorIterator& iter) {
// NOLINTNEXTLINE(bugprone-argument-comment)
/*identity=*/1);
} else {
- AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, iter.dtype(), "prod_out_cpu", [&] {
+ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.dtype(), "prod_out_cpu", [&] {
binary_kernel_reduce_vec(
iter,
[=](scalar_t a, scalar_t b)
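With these dispatches in place, the cumulative kernels hold intermediate values in the half output dtype, so results can drift slightly from a float32 reference; the relaxed tolerances added to the tests below account for this. A hypothetical, standalone illustration (not part of the commit):

import torch

x = torch.rand(32, dtype=torch.half)
out = torch.logcumsumexp(x, dim=0)          # the new half CPU path
ref = torch.logcumsumexp(x.float(), dim=0)  # float32 reference
print((out.float() - ref).abs().max())      # small but typically nonzero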
test/test_mps.py (5 changes: 4 additions & 1 deletion)
@@ -164,7 +164,9 @@ def mps_ops_grad_modifier(ops):
'__rpow__': [torch.float32],

# See https://github.com/pytorch/pytorch/issues/106112 for more information
- 'cumprod': [torch.float32],
+ 'cumprod': [torch.float32, torch.float16],
+ # See https://github.com/pytorch/pytorch/issues/109166 for more information
+ 'masked.cumprod': [torch.float16],
}

SKIPLIST_GRAD = {
@@ -10943,6 +10945,7 @@ class TestConsistency(TestCaseMPS):
'nn.functional.kl_div',
'nn.functional.softmin',
'cross', 'linalg.cross',
+ 'prod', 'masked.prod',

# for macOS 12
'masked.normalize', 'masked.sum', 'masked.var',
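These test_mps.py updates track the wider float16 coverage introduced by the OpInfo changes further down: once torch.float16 is part of an op's shared dtype list, the MPS gradient and consistency tests start exercising it, and the cases known to misbehave (see the linked issues) are added to the appropriate xfail/low-precision lists.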
test/test_reductions.py (12 changes: 12 additions & 0 deletions)
@@ -1435,6 +1435,18 @@ def test_prod(self, device, dtype):
torch.prod(x, 1, out=res2)
self.assertEqual(res1, res2)

+ @onlyCPU
+ @dtypes(torch.float16, torch.bfloat16)
+ def test_prod_lowp(self, device, dtype):
+     x = torch.rand(100, 100, dtype=dtype, device=device)
+     x_ref = x.float()
+     res1 = torch.prod(x, 1)
+     res2 = torch.prod(x_ref, 1)
+     self.assertEqual(res1, res2.to(dtype=dtype))
+     res1 = torch.prod(x, 0)
+     res2 = torch.prod(x_ref, 0)
+     self.assertEqual(res1, res2.to(dtype=dtype))

def test_prod_bool(self, device):
vals = [[True, True], [True, False], [False, False], []]
for val in vals:
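The new test_prod_lowp compares against a float32 reference cast back to the low-precision dtype, so the check tolerates the rounding inherent to half and bfloat16 accumulation instead of demanding bit-exact results. It should be runnable in isolation with something like pytest test/test_reductions.py -k test_prod_lowp, assuming the usual PyTorch test setup.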
test/test_sparse.py (11 changes: 10 additions & 1 deletion)
@@ -4225,6 +4225,10 @@ def fn(x):
class TestSparseMaskedReductions(TestCase):
exact_dtype = True

+ fp16_low_precision_list = {
+     'masked.prod',
+ }

@ops(sparse_masked_reduction_ops)
def test_future_empty_dim(self, device, dtype, op):
"""Currently, `dim=()` in reductions operations means "reduce over
@@ -4263,7 +4267,12 @@ def test_future_empty_dim(self, device, dtype, op):
self.assertEqual(actual.layout, torch.sparse_coo)

expected = op(t, *sample_input.args, **sample_input_kwargs).to_sparse()
- self.assertEqual(actual, expected)
+ atol = None
+ rtol = None
+ if op.name in self.fp16_low_precision_list and dtype == torch.half:
+     atol = 1e-5
+     rtol = 2e-3
+ self.assertEqual(actual, expected, atol=atol, rtol=rtol)


class TestSparseMeta(TestCase):
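Half-precision products pick up rounding error with every factor, so an exact comparison for masked.prod would be flaky; the override above loosens atol/rtol only for torch.half and only for ops listed in fp16_low_precision_list. A standalone sketch of the effect, assuming a build with this commit (illustrative only):

import torch

a = torch.rand(16, dtype=torch.half) + 0.5  # keep the product well away from half underflow
print(a.prod(), a.float().prod())           # the half result usually differs in the last few bits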
torch/testing/_internal/common_methods_invocations.py (14 changes: 5 additions & 9 deletions)
@@ -10821,8 +10821,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
),
sample_inputs_func=sample_inputs_cumulative_ops),
OpInfo('cumprod',
- dtypes=all_types_and_complex_and(torch.bfloat16),
- dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
skips=(
@@ -10833,17 +10832,15 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
sample_inputs_func=sample_inputs_cumprod,
gradcheck_fast_mode=False),
OpInfo('cummax',
- dtypes=all_types_and(torch.bool, torch.bfloat16),
- dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
+ dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
sample_inputs_func=partial(sample_inputs_cumulative_ops, supports_dtype_kwargs=False),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
skips=(
),
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL),
OpInfo('cummin',
- dtypes=all_types_and(torch.bool, torch.bfloat16),
- dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
+ dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
sample_inputs_func=partial(sample_inputs_cumulative_ops, supports_dtype_kwargs=False),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
@@ -17294,8 +17291,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
)
),
OpInfo('logcumsumexp',
- dtypes=floating_and_complex_types_and(torch.bfloat16),
- dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
+ dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half),
backward_dtypes=floating_and_complex_types_and(torch.bfloat16),
backward_dtypesIfCUDA=floating_and_complex_types_and(torch.bfloat16),
supports_forward_ad=True,
@@ -18371,7 +18367,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
supports_fwgrad_bwgrad=True,
promotes_int_to_int64=True,
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
- dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
sample_inputs_func=sample_inputs_prod,
ref=prod_numpy,
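Dropping the separate dtypesIfCUDA entries is behavior-preserving for CUDA: when dtypesIfCUDA is not given, OpInfo falls back to dtypes, so a single list now covers both backends while extending CPU coverage to float16. For prod, dtypesIfCUDA is kept because CUDA additionally supports torch.chalf.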
torch/testing/_internal/opinfo/definitions/_masked.py (15 changes: 8 additions & 7 deletions)
@@ -504,11 +504,7 @@ def sample_inputs_masked_normalize(op_info, device, dtype, requires_grad, **kwar
supports_sparse=True,
supports_sparse_csr=True,
promotes_int_to_int64=True,
- # FIXME: "prod_cpu" not implemented for 'Half'
- dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
- dtypesIfCUDA=all_types_and_complex_and(
-     torch.bool, torch.float16, torch.bfloat16
- ),
+ dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
skips=(
DecorateInfo(
unittest.expectedFailure,
@@ -554,6 +550,12 @@ def sample_inputs_masked_normalize(op_info, device, dtype, requires_grad, **kwar
"TestReductions",
"test_ref_small_input",
),
+ DecorateInfo(
+     toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1.5e-03)}),
+     "TestMasked",
+     "test_mask_layout",
+     device_type="cpu",
+ ),
],
sample_inputs_func=sample_inputs_masked_reduction,
sample_inputs_sparse_coo_func=sample_inputs_sparse_coo_masked_reduction,
@@ -585,8 +587,7 @@ def sample_inputs_masked_normalize(op_info, device, dtype, requires_grad, **kwar
),
OpInfo(
"masked.cumprod",
- dtypes=all_types_and_complex_and(torch.bfloat16),
- dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
method_variant=None,
# Runs very slowly on slow gradcheck - alternatively reduce input sizes
gradcheck_fast_mode=True,
