diff --git a/oneflow/core/common/either_ptr.h b/oneflow/core/common/either_ptr.h index f63d2de67a3..d7b32f9acf4 100644 --- a/oneflow/core/common/either_ptr.h +++ b/oneflow/core/common/either_ptr.h @@ -25,92 +25,79 @@ template class EitherPtr final { public: static_assert(!std::is_same::value, "X should not be Y"); - EitherPtr() : type_(UnionType::value) {} - EitherPtr(const std::shared_ptr& ptr) { Set(ptr); } - EitherPtr(const std::shared_ptr& ptr) { Set(ptr); } - EitherPtr(const EitherPtr& either_ptr) { CopyFrom(either_ptr); } - ~EitherPtr() { Reset(); } + + using XPtr = std::shared_ptr; + using YPtr = std::shared_ptr; + + // WARNING: we should assume that the structure of shared_ptr and shared_ptr is same, + // and obviously at most time the assumption holds + static_assert(sizeof(XPtr) == sizeof(YPtr), "unsupported shared_ptr implementation"); + + EitherPtr() : type_(UnionType::value), x_ptr_(nullptr) {} + EitherPtr(const XPtr& ptr) : type_(UnionType::value), x_ptr_(ptr) {} + EitherPtr(const YPtr& ptr) + : type_(UnionType::value), x_ptr_(reinterpret_cast(ptr)) {} + + EitherPtr(XPtr&& ptr) : type_(UnionType::value), x_ptr_(std::move(ptr)) {} + EitherPtr(YPtr&& ptr) : type_(UnionType::value), x_ptr_(reinterpret_cast(ptr)) {} + + EitherPtr(const EitherPtr& either_ptr) : type_(either_ptr.type_), x_ptr_(either_ptr.x_ptr_) {} + EitherPtr(EitherPtr&& either_ptr) + : type_(either_ptr.type_), x_ptr_(std::move(either_ptr.x_ptr_)) {} + + // the destructor of X or Y will be called properly because it will be stored in the deleter of + // shared_ptr while constructed + ~EitherPtr() = default; + + EitherPtr& operator=(const EitherPtr& either_ptr) { + x_ptr_ = either_ptr.x_ptr_; + type_ = either_ptr.type_; + return *this; + } + + EitherPtr& operator=(EitherPtr&& either_ptr) { + x_ptr_ = std::move(either_ptr.x_ptr_); + type_ = either_ptr.type_; + return *this; + } template bool Has() const { return type_ == UnionType::value; } + template const std::shared_ptr& Get() const { - CHECK(this->template Has()); - return Cast(); - } - void Reset(const std::shared_ptr& ptr) { - Reset(); - Set(ptr); - } - void Reset(const std::shared_ptr& ptr) { - Reset(); - Set(ptr); - } - - void Reset() { - if (type_ == UnionType::value) { - union_.reset(); - } else if (type_ == UnionType::value) { - MutCast()->reset(); - } else if (type_ == UnionType::value) { - MutCast()->reset(); - } else { - LOG(FATAL) << "UNIMPLEMENTED"; - } + return Get(tag{}); } private: - struct Void {}; template struct UnionType; template - struct UnionType::value>::type> { - static const int8_t value = 0; - }; - template struct UnionType::value>::type> { - static const int8_t value = 1; + static constexpr int8_t value = 0; }; template struct UnionType::value>::type> { - static const int8_t value = 2; + static constexpr int8_t value = 1; }; - void CopyFrom(const EitherPtr& either_ptr) { - if (either_ptr.template Has()) { - Set(either_ptr.template Get()); - } else if (either_ptr.template Has()) { - Set(either_ptr.template Get()); - } else { - // do nothin - } - } - void Set(const std::shared_ptr& ptr) { - CHECK(union_.get() == nullptr); - *MutCast() = ptr; - type_ = UnionType::value; - } - void Set(const std::shared_ptr& ptr) { - CHECK(union_.get() == nullptr); - *MutCast() = ptr; - type_ = UnionType::value; - } - template - std::shared_ptr* MutCast() { - std::shared_ptr* __attribute__((__may_alias__)) ptr = - reinterpret_cast*>(&union_); - return ptr; + + template + struct tag {}; + + const XPtr& Get(tag) const { + CHECK(Has()); + return x_ptr_; } - template - const std::shared_ptr& Cast() const { - const std::shared_ptr* __attribute__((__may_alias__)) ptr = - reinterpret_cast*>(&union_); - return *ptr; + + const YPtr& Get(tag) const { + CHECK(Has()); + return reinterpret_cast(x_ptr_); } - std::shared_ptr union_; int8_t type_; + std::shared_ptr x_ptr_; }; } // namespace oneflow diff --git a/oneflow/core/common/maybe.h b/oneflow/core/common/maybe.h index d577ae5bbe6..48a270f8817 100644 --- a/oneflow/core/common/maybe.h +++ b/oneflow/core/common/maybe.h @@ -47,9 +47,10 @@ class Maybe::value || IsScala Maybe(const T& data) : data_or_error_(std::make_shared(data)) {} Maybe(const Error& error) : data_or_error_(error.error_proto()) {} Maybe(const std::shared_ptr& data) : data_or_error_(data) {} + Maybe(std::shared_ptr&& data) : data_or_error_(std::move(data)) {} Maybe(const std::shared_ptr& error) : data_or_error_(error) {} Maybe(const Maybe&) = default; - Maybe(Maybe&&) = default; + Maybe(Maybe&& other) : data_or_error_(std::move(other.data_or_error_)) {} ~Maybe() = default; bool IsOk() const { return data_or_error_.template Has(); } @@ -268,47 +269,40 @@ inline bool MaybeIsOk(Maybe&& maybe) { #if defined(__GNUC__) || defined(__CUDACC__) || defined(__clang__) -// fix CUDA 11.1 compiler crashes -#if defined(__CUDACC__) -#define MAYBE_CONST_AUTO_REF const auto -#else -#define MAYBE_CONST_AUTO_REF const auto& -#endif // defined(__CUDACC__) - #define TRY(...) __MaybeErrorStackCheckWrapper__(__VA_ARGS__) -#define JUST(...) \ - ({ \ - MAYBE_CONST_AUTO_REF maybe = __MaybeErrorStackCheckWrapper__(__VA_ARGS__); \ - if (!maybe.IsOk()) { \ - auto* stack_frame = maybe.error()->add_stack_frame(); \ - stack_frame->set_file(__FILE__); \ - stack_frame->set_line(__LINE__); \ - stack_frame->set_function(__FUNCTION__); \ - stack_frame->set_error_msg(OF_PP_STRINGIZE((__VA_ARGS__))); \ - return maybe.error(); \ - } \ - maybe; \ +#define JUST(...) \ + ({ \ + auto&& maybe = __MaybeErrorStackCheckWrapper__(__VA_ARGS__); \ + if (!maybe.IsOk()) { \ + auto* stack_frame = maybe.error()->add_stack_frame(); \ + stack_frame->set_file(__FILE__); \ + stack_frame->set_line(__LINE__); \ + stack_frame->set_function(__FUNCTION__); \ + stack_frame->set_error_msg(OF_PP_STRINGIZE((__VA_ARGS__))); \ + return maybe.error(); \ + } \ + std::move(maybe); \ }).Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() -#define CHECK_JUST(...) \ - ([&](const char* func_name) { \ - MAYBE_CONST_AUTO_REF maybe = __MaybeErrorStackCheckWrapper__(__VA_ARGS__); \ - if (!maybe.IsOk()) { \ - auto* stack_frame = maybe.error()->add_stack_frame(); \ - stack_frame->set_file(__FILE__); \ - stack_frame->set_line(__LINE__); \ - stack_frame->set_function(func_name); \ - stack_frame->set_error_msg(OF_PP_STRINGIZE((__VA_ARGS__))); \ - LOG(FATAL) << maybe.GetSerializedError(); \ - } \ - return maybe; \ - })(__FUNCTION__) \ +#define CHECK_JUST(...) \ + ([&](const char* func_name) { \ + auto&& maybe = __MaybeErrorStackCheckWrapper__(__VA_ARGS__); \ + if (!maybe.IsOk()) { \ + auto* stack_frame = maybe.error()->add_stack_frame(); \ + stack_frame->set_file(__FILE__); \ + stack_frame->set_line(__LINE__); \ + stack_frame->set_function(func_name); \ + stack_frame->set_error_msg(OF_PP_STRINGIZE((__VA_ARGS__))); \ + LOG(FATAL) << maybe.GetSerializedError(); \ + } \ + return std::move(maybe); \ + })(__FUNCTION__) \ .Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() #define CHECK_OK(...) CHECK(MaybeIsOk(__VA_ARGS__)) -#define OF_RETURN_IF_ERROR(...) \ - for (MAYBE_CONST_AUTO_REF maybe_##__LINE__ = __MaybeErrorStackCheckWrapper__(__VA_ARGS__); \ - !maybe_##__LINE__.IsOk();) \ +#define OF_RETURN_IF_ERROR(...) \ + for (auto&& maybe_##__LINE__ = __MaybeErrorStackCheckWrapper__(__VA_ARGS__); \ + !maybe_##__LINE__.IsOk();) \ return Error(maybe_##__LINE__.error()).AddStackFrame(__FILE__, __LINE__, __FUNCTION__) #else diff --git a/oneflow/core/common/shared_or_scalar.h b/oneflow/core/common/shared_or_scalar.h index b102430ec43..e2cedfbdbfe 100644 --- a/oneflow/core/common/shared_or_scalar.h +++ b/oneflow/core/common/shared_or_scalar.h @@ -16,8 +16,10 @@ limitations under the License. #ifndef ONEFLOW_CORE_COMMON_SHARED_OR_SCALAR_H_ #define ONEFLOW_CORE_COMMON_SHARED_OR_SCALAR_H_ -#include +#include + #include +#include "oneflow/core/common/type_traits.h" #include "oneflow/core/common/preprocessor.h" namespace oneflow { @@ -25,100 +27,101 @@ namespace oneflow { template class SharedOrScalar final { public: - SharedOrScalar(ScalarT scalar_value) : shared_ptr_() { SetScalar(scalar_value); } - SharedOrScalar(const SharedOrScalar& rhs) { *this = rhs; } - SharedOrScalar(const std::shared_ptr& shared_ptr) : shared_ptr_(shared_ptr) { - CHECK(!IsScalar()); - } - ~SharedOrScalar(); + static_assert(IsScalarType::value, "ScalarT should be scalar type."); - SharedOrScalar& operator=(const SharedOrScalar& rhs); + using Shared = std::shared_ptr; - bool IsScalar() const; - ScalarT scalar_value() const; - std::shared_ptr shared_ptr() const; + SharedOrScalar(const ScalarT& scalar_value) : is_scalar_(true), scalar_value_(scalar_value) {} - ScalarT operator*() const { return scalar_value(); } + SharedOrScalar(const std::shared_ptr& shared_ptr) : is_scalar_(false) { + new (&shared_mem_) Shared(shared_ptr); + } - private: - struct ScalarStruct final { - uint64_t _ : 62, is_scalar_value : 2; - ScalarT scalar_value; - }; - static_assert(sizeof(StructT*) == 8, "only 64-bit pointer supported"); - static_assert(sizeof(ScalarT) <= 8, "only scalar data type supported"); - static_assert(sizeof(std::shared_ptr) >= sizeof(ScalarStruct), - "unsupported shared_ptr implemenet"); + SharedOrScalar(std::shared_ptr&& shared_ptr) : is_scalar_(false) { + new (&shared_mem_) Shared(std::move(shared_ptr)); + } - void SetScalar(ScalarT scalar_value); - const ScalarStruct* CastToScalarStruct() const; - ScalarStruct* MutCastToScalarStruct(); + SharedOrScalar(const SharedOrScalar& rhs) : is_scalar_(rhs.is_scalar_) { + if (rhs.is_scalar_) { + scalar_value_ = rhs.scalar_value_; + } else { + new (&shared_mem_) Shared(rhs.GetShared()); + } + } - std::shared_ptr shared_ptr_; -}; + SharedOrScalar(SharedOrScalar&& rhs) : is_scalar_(rhs.is_scalar_) { + if (rhs.is_scalar_) { + scalar_value_ = rhs.scalar_value_; + } else { + new (&shared_mem_) Shared(std::move(*rhs.MutableShared())); + } + } -template -SharedOrScalar& SharedOrScalar::operator=( - const SharedOrScalar& rhs) { - if (rhs.IsScalar()) { -#if defined(__GNUC__) && __GNUC__ >= 8 -#pragma GCC diagnostic ignored "-Wclass-memaccess" -#endif - std::memcpy(this, &rhs, sizeof(*this)); - } else { - shared_ptr_ = rhs.shared_ptr_; + SharedOrScalar& operator=(const SharedOrScalar& rhs) { + if (rhs.is_scalar_) { + scalar_value_ = rhs.scalar_value_; + } else { + if (is_scalar_) { + scalar_value_.~ScalarT(); + new (&shared_mem_) Shared(rhs.GetShared()); + } else { + *MutableShared() = rhs.GetShared(); + } + } + is_scalar_ = rhs.is_scalar_; + return *this; } - return *this; -} -template -const typename SharedOrScalar::ScalarStruct* -SharedOrScalar::CastToScalarStruct() const { - const ScalarStruct* __attribute__((__may_alias__)) ptr = - reinterpret_cast(&shared_ptr_); - return ptr; -} + SharedOrScalar& operator=(SharedOrScalar&& rhs) { + if (rhs.is_scalar_) { + scalar_value_ = rhs.scalar_value_; + } else { + if (is_scalar_) { + scalar_value_.~ScalarT(); + new (&shared_mem_) Shared(std::move(*rhs.MutableShared())); + } else { + *MutableShared() = std::move(*rhs.MutableShared()); + } + } + is_scalar_ = rhs.is_scalar_; + return *this; + } -template -typename SharedOrScalar::ScalarStruct* -SharedOrScalar::MutCastToScalarStruct() { - ScalarStruct* __attribute__((__may_alias__)) ptr = reinterpret_cast(&shared_ptr_); - return ptr; -} + ~SharedOrScalar() { + if (is_scalar_) { + scalar_value_.~ScalarT(); + } else { + GetShared().~Shared(); + } + } -template -void SharedOrScalar::SetScalar(ScalarT scalar_value) { - ScalarStruct* const ptr = MutCastToScalarStruct(); - ptr->is_scalar_value = 1; - ptr->scalar_value = scalar_value; -} + bool IsScalar() const { return is_scalar_; } + const ScalarT& scalar_value() const { + CHECK(is_scalar_); + return scalar_value_; + } -template -std::shared_ptr SharedOrScalar::shared_ptr() const { - CHECK(!IsScalar()); - return shared_ptr_; -} + const std::shared_ptr& shared_ptr() const { + CHECK(!is_scalar_); + return GetShared(); + } -template -ScalarT SharedOrScalar::scalar_value() const { - const ScalarStruct* const ptr = CastToScalarStruct(); - CHECK(ptr->is_scalar_value); - return ptr->scalar_value; -} + const ScalarT& operator*() const { return scalar_value(); } -template -bool SharedOrScalar::IsScalar() const { - const ScalarStruct* const ptr = CastToScalarStruct(); - return ptr->is_scalar_value; -} + private: + bool is_scalar_; + union { + ScalarT scalar_value_; -template -SharedOrScalar::~SharedOrScalar() { - if (IsScalar()) { - std::shared_ptr empty_ptr; - std::memcpy(&shared_ptr_, &empty_ptr, sizeof(empty_ptr)); - } -} + // to avoid error(a non-POD class definition is not allowed inside of a statement expression) + // in nvcc while using with JUST macro (this type is used in Maybe) + alignas(Shared) char shared_mem_[sizeof(Shared)]; + }; + + const Shared& GetShared() const { return reinterpret_cast(shared_mem_); } + + Shared* MutableShared() { return reinterpret_cast(&shared_mem_); } +}; } // namespace oneflow