Skip to content

Commit

Permalink
Optimize maybe. (#5839)
Browse files Browse the repository at this point in the history
* Optimize maybe.

* revert

* refine code style

* maybe: fix either_ptr and shared_or_scalar

* maybe: clang format

* maybe: fix error for nvcc

* maybe: fix either_ptr and shared_or_scalar

* maybe: clang format

* either_ptr: fix dtor

* maybe: fix

Co-authored-by: PragmaTwice <i@twice.moe>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
  • Loading branch information
3 people committed Aug 13, 2021
1 parent d3ca591 commit bdb64f7
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 180 deletions.
115 changes: 51 additions & 64 deletions oneflow/core/common/either_ptr.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,92 +25,79 @@ template<typename X, typename Y>
class EitherPtr final {
public:
static_assert(!std::is_same<X, Y>::value, "X should not be Y");
EitherPtr() : type_(UnionType<Void>::value) {}
EitherPtr(const std::shared_ptr<X>& ptr) { Set(ptr); }
EitherPtr(const std::shared_ptr<Y>& ptr) { Set(ptr); }
EitherPtr(const EitherPtr<X, Y>& either_ptr) { CopyFrom(either_ptr); }
~EitherPtr() { Reset(); }

using XPtr = std::shared_ptr<X>;
using YPtr = std::shared_ptr<Y>;

// WARNING: we should assume that the structure of shared_ptr<X> and shared_ptr<Y> is same,
// and obviously at most time the assumption holds
static_assert(sizeof(XPtr) == sizeof(YPtr), "unsupported shared_ptr implementation");

EitherPtr() : type_(UnionType<X>::value), x_ptr_(nullptr) {}
EitherPtr(const XPtr& ptr) : type_(UnionType<X>::value), x_ptr_(ptr) {}
EitherPtr(const YPtr& ptr)
: type_(UnionType<Y>::value), x_ptr_(reinterpret_cast<const XPtr&>(ptr)) {}

EitherPtr(XPtr&& ptr) : type_(UnionType<X>::value), x_ptr_(std::move(ptr)) {}
EitherPtr(YPtr&& ptr) : type_(UnionType<Y>::value), x_ptr_(reinterpret_cast<XPtr&&>(ptr)) {}

EitherPtr(const EitherPtr& either_ptr) : type_(either_ptr.type_), x_ptr_(either_ptr.x_ptr_) {}
EitherPtr(EitherPtr&& either_ptr)
: type_(either_ptr.type_), x_ptr_(std::move(either_ptr.x_ptr_)) {}

// the destructor of X or Y will be called properly because it will be stored in the deleter of
// shared_ptr while constructed
~EitherPtr() = default;

EitherPtr& operator=(const EitherPtr& either_ptr) {
x_ptr_ = either_ptr.x_ptr_;
type_ = either_ptr.type_;
return *this;
}

EitherPtr& operator=(EitherPtr&& either_ptr) {
x_ptr_ = std::move(either_ptr.x_ptr_);
type_ = either_ptr.type_;
return *this;
}

template<typename T>
bool Has() const {
return type_ == UnionType<T>::value;
}

template<typename T>
const std::shared_ptr<T>& Get() const {
CHECK(this->template Has<T>());
return Cast<T>();
}
void Reset(const std::shared_ptr<X>& ptr) {
Reset();
Set(ptr);
}
void Reset(const std::shared_ptr<Y>& ptr) {
Reset();
Set(ptr);
}

void Reset() {
if (type_ == UnionType<Void>::value) {
union_.reset();
} else if (type_ == UnionType<X>::value) {
MutCast<X>()->reset();
} else if (type_ == UnionType<Y>::value) {
MutCast<Y>()->reset();
} else {
LOG(FATAL) << "UNIMPLEMENTED";
}
return Get(tag<T>{});
}

private:
struct Void {};
template<typename T, typename Enable = void>
struct UnionType;
template<typename T>
struct UnionType<T, typename std::enable_if<std::is_same<Void, T>::value>::type> {
static const int8_t value = 0;
};
template<typename T>
struct UnionType<T, typename std::enable_if<std::is_same<X, T>::value>::type> {
static const int8_t value = 1;
static constexpr int8_t value = 0;
};
template<typename T>
struct UnionType<T, typename std::enable_if<std::is_same<Y, T>::value>::type> {
static const int8_t value = 2;
static constexpr int8_t value = 1;
};
void CopyFrom(const EitherPtr<X, Y>& either_ptr) {
if (either_ptr.template Has<X>()) {
Set(either_ptr.template Get<X>());
} else if (either_ptr.template Has<Y>()) {
Set(either_ptr.template Get<Y>());
} else {
// do nothin
}
}
void Set(const std::shared_ptr<X>& ptr) {
CHECK(union_.get() == nullptr);
*MutCast<X>() = ptr;
type_ = UnionType<X>::value;
}
void Set(const std::shared_ptr<Y>& ptr) {
CHECK(union_.get() == nullptr);
*MutCast<Y>() = ptr;
type_ = UnionType<Y>::value;
}
template<typename T>
std::shared_ptr<T>* MutCast() {
std::shared_ptr<T>* __attribute__((__may_alias__)) ptr =
reinterpret_cast<std::shared_ptr<T>*>(&union_);
return ptr;

template<typename>
struct tag {};

const XPtr& Get(tag<X>) const {
CHECK(Has<X>());
return x_ptr_;
}
template<typename T>
const std::shared_ptr<T>& Cast() const {
const std::shared_ptr<T>* __attribute__((__may_alias__)) ptr =
reinterpret_cast<const std::shared_ptr<T>*>(&union_);
return *ptr;

const YPtr& Get(tag<Y>) const {
CHECK(Has<Y>());
return reinterpret_cast<const YPtr&>(x_ptr_);
}

std::shared_ptr<Void> union_;
int8_t type_;
std::shared_ptr<X> x_ptr_;
};

} // namespace oneflow
Expand Down
66 changes: 30 additions & 36 deletions oneflow/core/common/maybe.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,10 @@ class Maybe<T, typename std::enable_if<!(std::is_same<T, void>::value || IsScala
Maybe(const T& data) : data_or_error_(std::make_shared<T>(data)) {}
Maybe(const Error& error) : data_or_error_(error.error_proto()) {}
Maybe(const std::shared_ptr<T>& data) : data_or_error_(data) {}
Maybe(std::shared_ptr<T>&& data) : data_or_error_(std::move(data)) {}
Maybe(const std::shared_ptr<cfg::ErrorProto>& error) : data_or_error_(error) {}
Maybe(const Maybe&) = default;
Maybe(Maybe&&) = default;
Maybe(Maybe&& other) : data_or_error_(std::move(other.data_or_error_)) {}
~Maybe() = default;

bool IsOk() const { return data_or_error_.template Has<T>(); }
Expand Down Expand Up @@ -268,47 +269,40 @@ inline bool MaybeIsOk(Maybe<void>&& maybe) {

#if defined(__GNUC__) || defined(__CUDACC__) || defined(__clang__)

// fix CUDA 11.1 compiler crashes
#if defined(__CUDACC__)
#define MAYBE_CONST_AUTO_REF const auto
#else
#define MAYBE_CONST_AUTO_REF const auto&
#endif // defined(__CUDACC__)

#define TRY(...) __MaybeErrorStackCheckWrapper__(__VA_ARGS__)
#define JUST(...) \
({ \
MAYBE_CONST_AUTO_REF maybe = __MaybeErrorStackCheckWrapper__(__VA_ARGS__); \
if (!maybe.IsOk()) { \
auto* stack_frame = maybe.error()->add_stack_frame(); \
stack_frame->set_file(__FILE__); \
stack_frame->set_line(__LINE__); \
stack_frame->set_function(__FUNCTION__); \
stack_frame->set_error_msg(OF_PP_STRINGIZE((__VA_ARGS__))); \
return maybe.error(); \
} \
maybe; \
#define JUST(...) \
({ \
auto&& maybe = __MaybeErrorStackCheckWrapper__(__VA_ARGS__); \
if (!maybe.IsOk()) { \
auto* stack_frame = maybe.error()->add_stack_frame(); \
stack_frame->set_file(__FILE__); \
stack_frame->set_line(__LINE__); \
stack_frame->set_function(__FUNCTION__); \
stack_frame->set_error_msg(OF_PP_STRINGIZE((__VA_ARGS__))); \
return maybe.error(); \
} \
std::move(maybe); \
}).Data_YouAreNotAllowedToCallThisFuncOutsideThisFile()
#define CHECK_JUST(...) \
([&](const char* func_name) { \
MAYBE_CONST_AUTO_REF maybe = __MaybeErrorStackCheckWrapper__(__VA_ARGS__); \
if (!maybe.IsOk()) { \
auto* stack_frame = maybe.error()->add_stack_frame(); \
stack_frame->set_file(__FILE__); \
stack_frame->set_line(__LINE__); \
stack_frame->set_function(func_name); \
stack_frame->set_error_msg(OF_PP_STRINGIZE((__VA_ARGS__))); \
LOG(FATAL) << maybe.GetSerializedError(); \
} \
return maybe; \
})(__FUNCTION__) \
#define CHECK_JUST(...) \
([&](const char* func_name) { \
auto&& maybe = __MaybeErrorStackCheckWrapper__(__VA_ARGS__); \
if (!maybe.IsOk()) { \
auto* stack_frame = maybe.error()->add_stack_frame(); \
stack_frame->set_file(__FILE__); \
stack_frame->set_line(__LINE__); \
stack_frame->set_function(func_name); \
stack_frame->set_error_msg(OF_PP_STRINGIZE((__VA_ARGS__))); \
LOG(FATAL) << maybe.GetSerializedError(); \
} \
return std::move(maybe); \
})(__FUNCTION__) \
.Data_YouAreNotAllowedToCallThisFuncOutsideThisFile()

#define CHECK_OK(...) CHECK(MaybeIsOk(__VA_ARGS__))

#define OF_RETURN_IF_ERROR(...) \
for (MAYBE_CONST_AUTO_REF maybe_##__LINE__ = __MaybeErrorStackCheckWrapper__(__VA_ARGS__); \
!maybe_##__LINE__.IsOk();) \
#define OF_RETURN_IF_ERROR(...) \
for (auto&& maybe_##__LINE__ = __MaybeErrorStackCheckWrapper__(__VA_ARGS__); \
!maybe_##__LINE__.IsOk();) \
return Error(maybe_##__LINE__.error()).AddStackFrame(__FILE__, __LINE__, __FUNCTION__)

#else
Expand Down

0 comments on commit bdb64f7

Please sign in to comment.