Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for fp16 in arithmetic and mathematical ops #4674

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions dali/operators/math/expressions/arithmetic_meta.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -261,10 +261,9 @@ DALI_HOST_DEV constexpr bool IsComparison(ArithmeticOp op) {
}
}


// TODO(klecki): float16
#define ARITHMETIC_ALLOWED_TYPES \
(bool, uint8_t, uint16_t, uint32_t, uint64_t, int8_t, int16_t, int32_t, int64_t, float, double)
#define ARITHMETIC_ALLOWED_TYPES \
(bool, uint8_t, uint16_t, uint32_t, uint64_t, int8_t, int16_t, int32_t, int64_t, float16, float, \
double)

/**
* @brief Type promotion rules
Expand Down
7 changes: 4 additions & 3 deletions dali/test/python/operator_1/test_arithmetic_ops.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -72,8 +72,7 @@ def shape_small(arg_idx):
np.bool_, np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64
]

# float16 is marked as TODO in backend for gpu
float_types = [np.float32, np.float64]
float_types = [np.float16, np.float32, np.float64]

input_types = integer_types + float_types

Expand Down Expand Up @@ -618,8 +617,10 @@ def check_comparsion_op(kinds, types, op, shape, _):
device_id=0)
pipe.build()
pipe_out = pipe.run()
np.set_printoptions(formatter={'float':lambda x:float(x).hex()})
for sample in range(batch_size):
l_np, r_np, out = extract_data(pipe_out, sample, kinds, None)
print(f"L {l_np.dtype}={l_np},\n\nR {r_np.dtype}={r_np},\n\nOut {out.dtype}={out},\n\nnp_out={op(l_np, r_np)},\n\nOP:{_}")
assert_equals(out.dtype, np.bool_)
np.testing.assert_array_equal(out, op(l_np, r_np), err_msg=f"{l_np} op\n{r_np} =\n{out}")

Expand Down
2 changes: 1 addition & 1 deletion docs/math.rst
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ The resulting type is calculated in accordance to the table below.

``T`` stands for any one of the supported numerical types:
``bool``, ``int8``, ``int16``, ``int32``, ``int64``, ``uint8``, ``uint16``,
``uint32``, ``uint64``, ``float32``, and ``float64``.
``uint32``, ``uint64``, ``float16``, ``float32``, and ``float64``.

``bool`` type is considered the smallest unsigned integer type and is treated as ``uint1``
with respect to the table above.
Expand Down