Skip to content

Commit

Permalink
[AMP] Add bfloat16 and float16 tests for compare ops (#51978)
Browse files Browse the repository at this point in the history
* add bf16 and fp16 tests

* fix dtype check
  • Loading branch information
yeliang2258 committed Mar 23, 2023
1 parent 9c853d1 commit a7397e0
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 4 deletions.
6 changes: 4 additions & 2 deletions paddle/phi/kernels/cpu/compare_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ PD_REGISTER_KERNEL(equal_all,
int64_t, \
float, \
double, \
phi::dtype::float16) { \
phi::dtype::float16, \
phi::dtype::bfloat16) { \
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
} \
PD_REGISTER_KERNEL(name##_raw, \
Expand All @@ -107,7 +108,8 @@ PD_REGISTER_KERNEL(equal_all,
int64_t, \
float, \
double, \
phi::dtype::float16) { \
phi::dtype::float16, \
phi::dtype::bfloat16) { \
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
}
PD_REGISTER_COMPARE_KERNEL(less_than, LessThan)
Expand Down
42 changes: 42 additions & 0 deletions python/paddle/fluid/tests/unittests/test_compare_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,15 @@ def test_dynamic_api_float(self):
self.assertEqual((out.numpy() == self.real_result).all(), True)
paddle.enable_static()

def test_dynamic_api_float16(self):
    """Run the comparison op in dynamic-graph mode with float16 inputs
    and check the result against the precomputed float32 reference.

    Assumes ``self.input_x`` / ``self.input_y`` / ``self.real_result`` are
    set by the enclosing test case's ``setUp`` (not visible here).
    """
    paddle.disable_static()
    try:
        x = paddle.to_tensor(self.input_x, dtype="float16")
        y = paddle.to_tensor(self.input_y, dtype="float16")
        # getattr is the idiomatic (and safer) replacement for
        # eval("paddle.%s" % self.op_type).
        op = getattr(paddle, self.op_type)
        out = op(x, y)
        self.assertEqual((out.numpy() == self.real_result).all(), True)
    finally:
        # Restore static mode even when the assertion fails, so a failure
        # here cannot leak dynamic mode into subsequently run tests.
        paddle.enable_static()

def test_dynamic_api_inf_1(self):
if self.op_type == "equal":
paddle.disable_static()
Expand Down Expand Up @@ -434,6 +443,39 @@ def test_attr_name(self):
create_paddle_case('not_equal', lambda _a, _b: _a != _b)


# bf16 (bfloat16) OpTest cases for the compare ops
def create_bf16_case(op_type, callback):
    """Create and register an OpTest case that feeds bfloat16 inputs to
    ``op_type`` and checks the boolean output.

    Args:
        op_type (str): name of the compare op, e.g. ``'less_than'``.
        callback: binary callable computing the reference result on
            numpy arrays, e.g. ``lambda a, b: a < b``.
    """

    class TestCompareOpBF16Op(op_test.OpTest):
        def setUp(self):
            self.op_type = op_type
            self.dtype = np.uint16
            # getattr is the idiomatic replacement for eval("paddle." + op_type).
            self.python_api = getattr(paddle, op_type)

            x = np.random.uniform(0, 1, [5, 5]).astype(np.float32)
            y = np.random.uniform(0, 1, [5, 5]).astype(np.float32)
            x_bf16 = op_test.convert_float_to_uint16(x)
            y_bf16 = op_test.convert_float_to_uint16(y)
            # The kernel sees the bf16-rounded values, not the original
            # float32 ones, so compute the expected result from the rounded
            # values as well. Otherwise a pair straddling a bf16 rounding
            # boundary flips the comparison and the test is flaky.
            # uint16 -> float32 reconstruction: bf16 bits are the high 16
            # bits of the corresponding float32.
            x_rounded = (x_bf16.astype(np.uint32) << 16).view(np.float32)
            y_rounded = (y_bf16.astype(np.uint32) << 16).view(np.float32)
            real_result = callback(x_rounded, y_rounded)
            self.inputs = {'X': x_bf16, 'Y': y_bf16}
            self.outputs = {'Out': real_result}

        def test_check_output(self):
            self.check_output()

    cls_name = "BF16TestCase_{}".format(op_type)
    TestCompareOpBF16Op.__name__ = cls_name
    # Register under a per-op name so unittest discovery picks it up.
    globals()[cls_name] = TestCompareOpBF16Op


# Instantiate one bf16 test case per compare op.
create_bf16_case('less_than', lambda lhs, rhs: lhs < rhs)
create_bf16_case('less_equal', lambda lhs, rhs: lhs <= rhs)
create_bf16_case('greater_than', lambda lhs, rhs: lhs > rhs)
create_bf16_case('greater_equal', lambda lhs, rhs: lhs >= rhs)
create_bf16_case('equal', lambda lhs, rhs: lhs == rhs)
create_bf16_case('not_equal', lambda lhs, rhs: lhs != rhs)


class TestCompareOpError(unittest.TestCase):
def test_errors(self):
paddle.enable_static()
Expand Down
4 changes: 2 additions & 2 deletions python/paddle/tensor/logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,13 +746,13 @@ def not_equal(x, y, name=None):
check_variable_and_dtype(
x,
"x",
["bool", "float32", "float64", "int32", "int64"],
["bool", "float16", "float32", "float64", "int32", "int64"],
"not_equal",
)
check_variable_and_dtype(
y,
"y",
["bool", "float32", "float64", "int32", "int64"],
["bool", "float16", "float32", "float64", "int32", "int64"],
"not_equal",
)
helper = LayerHelper("not_equal", **locals())
Expand Down

0 comments on commit a7397e0

Please sign in to comment.