diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index e3c3155aa902c..b06ecc26286de 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -1,6 +1,5 @@ cc_library(ddim SRCS ddim.cc) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) - nv_test(dim_test SRCS dim_test.cu DEPS ddim) - cc_test(variable_test SRCS variable_test.cc) +cc_test(enforce_test SRCS enforce_test.cc) diff --git a/paddle/framework/enforce.h b/paddle/framework/enforce.h new file mode 100644 index 0000000000000..56cb7f95647e8 --- /dev/null +++ b/paddle/framework/enforce.h @@ -0,0 +1,69 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include + +namespace paddle { +namespace framework { + +/** + * @brief Enforce exception. Inherits std::exception + * + * All enforce condition not met, will throw an EnforceNotMet exception. + */ +class EnforceNotMet : public std::exception { + public: + EnforceNotMet(const std::string& msg, const char* file, int fileline) { + std::ostringstream sout; + sout << msg << " at [" << file << ":" << fileline << "];"; + all_msg_ = sout.str(); + } + + const char* what() const noexcept override { return all_msg_.c_str(); } + + private: + std::string all_msg_; +}; + +// From https://stackoverflow.com/questions/30130930/ +// __buildin_expect is in C++ 11 standard. Since the condition which enforced +// should be true in most situation, it will make the compiler generate faster +// code by adding `UNLIKELY` macro. +#define UNLIKELY(condition) __builtin_expect(static_cast(condition), 0) + +/** + * @brief Throw a EnforceNotMet exception, automatically filled __FILE__ & + * __LINE__ + * + * This macro take __VA_ARGS__, user can pass any type if that type can + * serialize to std::ostream + */ +#define PADDLE_THROW(...) \ + do { \ + throw ::paddle::framework::EnforceNotMet( \ + ::paddle::string::Sprintf(__VA_ARGS__), __FILE__, __LINE__); \ + } while (0) + +/** + * @brief Enforce a condition, otherwise throw an EnforceNotMet + */ +#define PADDLE_ENFORCE(condition, ...) \ + do { \ + if (UNLIKELY(!(condition))) { \ + PADDLE_THROW(__VA_ARGS__); \ + } \ + } while (0) + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/enforce_test.cc b/paddle/framework/enforce_test.cc new file mode 100644 index 0000000000000..f8da1a192f63a --- /dev/null +++ b/paddle/framework/enforce_test.cc @@ -0,0 +1,35 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +TEST(ENFORCE, OK) { + PADDLE_ENFORCE(true, "Enforce is ok %d now %f", 123, 0.345); + size_t val = 1; + const size_t limit = 10; + PADDLE_ENFORCE(val < limit, "Enforce is OK too"); +} + +TEST(ENFORCE, FAILED) { + bool in_catch = false; + try { + PADDLE_ENFORCE(false, "Enforce is not ok %d at all", 123); + } catch (paddle::framework::EnforceNotMet err) { + in_catch = true; + std::string msg = "Enforce is not ok 123 at all"; + const char* what = err.what(); + for (size_t i = 0; i < msg.length(); ++i) { + ASSERT_EQ(what[i], msg[i]); + } + } + ASSERT_TRUE(in_catch); +} \ No newline at end of file