Skip to content

Commit

Permalink
Enhance check_nan_inf implementation for CPU. (#48591)
Browse files Browse the repository at this point in the history
* Enable to print device info.

* Enhance the nan and inf checking for cpu.

* Implement a common print function.

* Unify the check of complex numbers.

* Rewrite the omp method.

* Count and print the number of nan and inf.

* Change the print content.

* Add unittest.
  • Loading branch information
Xreki committed Dec 12, 2022
1 parent 6698e8d commit 69e695b
Show file tree
Hide file tree
Showing 5 changed files with 441 additions and 336 deletions.
46 changes: 22 additions & 24 deletions paddle/fluid/eager/tests/task_tests/nan_inf_utils_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,32 +30,30 @@ PD_DECLARE_KERNEL(strings_empty, CPU, ALL_LAYOUT);

namespace egr {

#define CHECK_NAN_INF(tensors) \
{ \
bool caught_exception = false; \
try { \
CheckTensorHasNanOrInf("nan_inf_test", tensors); \
} catch (paddle::platform::EnforceNotMet & error) { \
caught_exception = true; \
std::string ex_msg = error.what(); \
EXPECT_TRUE(ex_msg.find("There are `nan` or `inf` in tensor") != \
std::string::npos); \
} \
EXPECT_TRUE(caught_exception); \
#define CHECK_NAN_INF(tensors) \
{ \
bool caught_exception = false; \
try { \
CheckTensorHasNanOrInf("nan_inf_test", tensors); \
} catch (paddle::platform::EnforceNotMet & error) { \
caught_exception = true; \
std::string ex_msg = error.what(); \
EXPECT_TRUE(ex_msg.find("There are NAN or INF") != std::string::npos); \
} \
EXPECT_TRUE(caught_exception); \
}

#define CHECK_NO_NAN_INF(tensors) \
{ \
bool caught_exception = false; \
try { \
CheckTensorHasNanOrInf("nan_inf_test", tensors); \
} catch (paddle::platform::EnforceNotMet & error) { \
caught_exception = true; \
std::string ex_msg = error.what(); \
EXPECT_TRUE(ex_msg.find("There are `nan` or `inf` in tensor") != \
std::string::npos); \
} \
EXPECT_FALSE(caught_exception); \
#define CHECK_NO_NAN_INF(tensors) \
{ \
bool caught_exception = false; \
try { \
CheckTensorHasNanOrInf("nan_inf_test", tensors); \
} catch (paddle::platform::EnforceNotMet & error) { \
caught_exception = true; \
std::string ex_msg = error.what(); \
EXPECT_TRUE(ex_msg.find("There are NAN or INF") != std::string::npos); \
} \
EXPECT_FALSE(caught_exception); \
}

TEST(NanInfUtils, Functions) {
Expand Down
Loading

0 comments on commit 69e695b

Please sign in to comment.