Skip to content

Commit

Permalink
KernelState (#6198)
Browse files Browse the repository at this point in the history
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
  • Loading branch information
liujuncheng and oneflow-ci-bot committed Sep 7, 2021
1 parent 5c667f5 commit eae9ff3
Show file tree
Hide file tree
Showing 18 changed files with 55 additions and 76 deletions.
15 changes: 4 additions & 11 deletions oneflow/core/actor/actor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@ limitations under the License.
*/
#include "oneflow/core/actor/actor.h"
#include "oneflow/core/control/global_process_ctx.h"
#include "oneflow/core/thread/thread_manager.h"
#include "oneflow/core/job/runtime_job_descs.h"
#include "oneflow/core/control/global_process_ctx.h"
#include "oneflow/core/stream/stream_context.h"

namespace oneflow {
Expand All @@ -44,12 +42,9 @@ class KernelContextImpl : public KernelContext {

Blob* BnInOp2Blob(const std::string& bn) const override { return bn_in_op2blob_fn_(bn); }

void* state() const override { return state_; }
const std::shared_ptr<KernelState>& state() const override { return state_; }

void set_state(void* state) override {
CHECK(state_ == nullptr);
state_ = state;
}
void set_state(std::shared_ptr<KernelState> state) override { state_ = std::move(state); }

void WillForward(KernelContext* kernel_ctx, const Kernel* kernel) override;
void DidForward(KernelContext* kernel_ctx, const Kernel* kernel) override;
Expand All @@ -67,7 +62,7 @@ class KernelContextImpl : public KernelContext {
private:
DeviceCtx* device_ctx_;
std::function<Blob*(const std::string&)> bn_in_op2blob_fn_;
void* state_;
std::shared_ptr<KernelState> state_;
KernelObserver* stream_kernel_observer_;
};

Expand Down Expand Up @@ -127,9 +122,7 @@ void CheckInplaceRegstDescId(const TaskProto& task_proto) {

} // namespace

Actor::~Actor() {
for (ExecKernel& ek : exec_kernel_vec_) { ek.kernel->DestroyState(ek.kernel_ctx->state()); }
}
Actor::~Actor() = default;

void Actor::Init(const JobDesc* job_desc, const TaskProto& task_proto, StreamContext* stream_ctx) {
job_desc_ = job_desc;
Expand Down
3 changes: 2 additions & 1 deletion oneflow/core/actor/case_actor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ class CaseActor final : public Actor {

void CaseActor::VirtualActorInit(const TaskProto& task_proto) {
CHECK_EQ(1, exec_kernel_vec().size());
case_status_ = static_cast<CaseStatus*>(exec_kernel_vec().at(0).kernel_ctx->state());
case_status_ =
CHECK_NOTNULL(dynamic_cast<CaseStatus*>(exec_kernel_vec().at(0).kernel_ctx->state().get()));
const int32_t output_bns_size =
task_proto.exec_sequence().exec_node().Get(0).kernel_conf().op_attribute().output_bns_size();
FOR_RANGE(int64_t, i, 0, output_bns_size) {
Expand Down
6 changes: 4 additions & 2 deletions oneflow/core/actor/esac_actor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License.
*/
#include "oneflow/core/actor/actor.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/kernel/esac_kernel.h"

namespace oneflow {

Expand Down Expand Up @@ -49,7 +50,7 @@ class EsacActor final : public Actor {
int64_t GetCurProcessedRegstDescId() const;

RegstSlot consumed_rs_;
int64_t cur_processed_regst_desc_id_;
int64_t cur_processed_regst_desc_id_{};
HashMap<int64_t, int64_t> regst_desc_id2in_bn_id_;
};

Expand Down Expand Up @@ -88,7 +89,8 @@ void EsacActor::Act() {
CHECK(cur_regst);
int64_t in_bn_id = InBnId4RegstDescId(cur_processed_regst_desc_id_);
CHECK_EQ(exec_kernel_vec().size(), 1);
*static_cast<int64_t*>(exec_kernel_vec().at(0).kernel_ctx->state()) = in_bn_id;
CHECK_NOTNULL(dynamic_cast<EsacKernelState*>(exec_kernel_vec().at(0).kernel_ctx->state().get()))
->value = in_bn_id;
AsyncLaunchKernel([&](int64_t regst_desc_id) -> Regst* {
if (cur_processed_regst_desc_id_ != regst_desc_id) { return nullptr; }
return cur_regst;
Expand Down
16 changes: 7 additions & 9 deletions oneflow/core/actor/light_actor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ struct RegstState {
struct KernelInfo {
std::unique_ptr<const Kernel> kernel;
HashMap<std::string, Blob*> bn_in_op2blob;
void* state = nullptr;
std::shared_ptr<KernelState> state;
};

template<typename IndexType, int max_size>
Expand Down Expand Up @@ -218,9 +218,7 @@ class LightActor : public ActorBase, public KernelContext {
}
}
}
~LightActor() override {
if (exec_kernel) { kernel_info_[0]->kernel->DestroyState(kernel_info_[0]->state); }
}
~LightActor() override = default;

void Init(const JobDesc* job_desc, const TaskProto& task_proto,
StreamContext* stream_ctx) override {
Expand Down Expand Up @@ -513,18 +511,18 @@ class LightActor : public ActorBase, public KernelContext {
}
}

void* state() const override {
const std::shared_ptr<KernelState>& state() const override {
if (exec_kernel) {
return kernel_info_[0]->state;
} else {
return nullptr;
static const std::shared_ptr<KernelState> null_state;
return null_state;
}
}

void set_state(void* state) override {
void set_state(std::shared_ptr<KernelState> state) override {
CHECK(exec_kernel);
CHECK(kernel_info_[0]->state == nullptr);
kernel_info_[0]->state = state;
kernel_info_[0]->state = std::move(state);
}

void WillForward(KernelContext* kernel_ctx, const Kernel* kernel) override {
Expand Down
4 changes: 2 additions & 2 deletions oneflow/core/actor/reentrant_lock_actor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ class ReentrantLockActor final : public Actor {

void ReentrantLockActor::VirtualActorInit(const TaskProto& task_proto) {
CHECK_EQ(1, exec_kernel_vec().size());
reentrant_lock_status_ =
static_cast<ReentrantLockStatus*>(exec_kernel_vec().at(0).kernel_ctx->state());
reentrant_lock_status_ = CHECK_NOTNULL(
dynamic_cast<ReentrantLockStatus*>(exec_kernel_vec().at(0).kernel_ctx->state().get()));
act_id_ = 0;
const auto& kernel_conf = task_proto.exec_sequence().exec_node().Get(0).kernel_conf();
const auto& ibns = kernel_conf.op_attribute().input_bns();
Expand Down
4 changes: 2 additions & 2 deletions oneflow/core/actor/wait_and_send_ids_actor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ class WaitAndSendIdsActor final : public Actor {

void WaitAndSendIdsActor::VirtualActorInit(const TaskProto& task_proto) {
CHECK_EQ(exec_kernel_vec().size(), 1);
wait_and_send_ids_status_ =
static_cast<WaitAndSendIdsStatus*>(exec_kernel_vec().at(0).kernel_ctx->state());
wait_and_send_ids_status_ = CHECK_NOTNULL(
dynamic_cast<WaitAndSendIdsStatus*>(exec_kernel_vec().at(0).kernel_ctx->state().get()));
wait_and_send_ids_status_->buffer_status_ = kBufferStatusSuccess;
wait_and_send_ids_status_->in_id_ = 0;
wait_and_send_ids_status_->out_idx_ = 0;
Expand Down
8 changes: 4 additions & 4 deletions oneflow/core/eager/opkernel_object.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,12 @@ class SystemOpKernelContext : public KernelContext {

Blob* BnInOp2Blob(const std::string& bn) const override { return bn_in_op2blob_fn_(bn); }

void* state() const override {
UNIMPLEMENTED();
return nullptr;
const std::shared_ptr<KernelState>& state() const override {
static const std::shared_ptr<KernelState> null_state;
return null_state;
}

void set_state(void* state) override { UNIMPLEMENTED(); }
void set_state(std::shared_ptr<KernelState> state) override { UNIMPLEMENTED(); }

void set_device_ctx(DeviceCtx* ctx) { device_ctx_ = ctx; }

Expand Down
9 changes: 2 additions & 7 deletions oneflow/core/kernel/case_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,12 @@ namespace oneflow {

template<typename T>
void CaseKernel<T>::VirtualKernelInit(KernelContext* ctx) {
ctx->set_state(new CaseStatus);
}

template<typename T>
void CaseKernel<T>::DestroyState(void* state) const {
delete static_cast<CaseStatus*>(state);
ctx->set_state(std::make_shared<CaseStatus>());
}

template<typename T>
void CaseKernel<T>::ForwardDataContent(KernelContext* ctx) const {
CaseStatus* const case_status = static_cast<CaseStatus*>(ctx->state());
CaseStatus* const case_status = CHECK_NOTNULL(dynamic_cast<CaseStatus*>(ctx->state().get()));
if (case_status->cmd == kCaseCmdHandleInput) {
int64_t cur_selected_id = static_cast<int64_t>(ctx->BnInOp2Blob("in")->dptr<T>()[0]);
case_status->select_id2request_cnt[cur_selected_id] += 1;
Expand Down
3 changes: 1 addition & 2 deletions oneflow/core/kernel/case_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ enum CaseCmd {
kCaseCmdHandleOutput = 2,
};

struct CaseStatus final {
struct CaseStatus final : public KernelState {
CaseStatus() : cmd(kCaseCmdInvalid), cur_selected_id(-1) {}
~CaseStatus() = default;

Expand All @@ -44,7 +44,6 @@ class CaseKernel final : public Kernel {

private:
void VirtualKernelInit(KernelContext* ctx) override;
void DestroyState(void* state) const override;
void ForwardDataContent(KernelContext* ctx) const override;
};

Expand Down
10 changes: 3 additions & 7 deletions oneflow/core/kernel/esac_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,13 @@ namespace oneflow {

template<typename T>
void EsacKernel<T>::VirtualKernelInit(KernelContext* ctx) {
ctx->set_state(new int64_t);
}

template<typename T>
void EsacKernel<T>::DestroyState(void* state) const {
delete static_cast<int64_t*>(state);
ctx->set_state(std::make_shared<EsacKernelState>());
}

template<typename T>
void EsacKernel<T>::ForwardDataContent(KernelContext* ctx) const {
T value = static_cast<T>(*static_cast<int64_t*>(ctx->state()));
T value =
static_cast<T>(CHECK_NOTNULL(dynamic_cast<EsacKernelState*>(ctx->state().get()))->value);
KernelUtil<DeviceType::kCPU, T>::Set(ctx->device_ctx(), value,
ctx->BnInOp2Blob("out")->mut_dptr<T>());
}
Expand Down
5 changes: 4 additions & 1 deletion oneflow/core/kernel/esac_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ limitations under the License.

namespace oneflow {

struct EsacKernelState : public KernelState {
int64_t value{};
};

template<typename T>
class EsacKernel final : public Kernel {
public:
Expand All @@ -28,7 +32,6 @@ class EsacKernel final : public Kernel {
~EsacKernel() override = default;

private:
void DestroyState(void* state) const override;
void VirtualKernelInit(KernelContext* ctx) override;
void ForwardDataContent(KernelContext* ctx) const override;
};
Expand Down
2 changes: 0 additions & 2 deletions oneflow/core/kernel/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ void Kernel::Init(const KernelConf& kernel_conf, KernelContext* ctx) {
VirtualKernelInit(ctx);
}

void Kernel::DestroyState(void* state) const { CHECK(state == nullptr); }

void Kernel::Launch(KernelContext* ctx) const {
ctx->WillForward(ctx, this);
Forward(ctx);
Expand Down
1 change: 0 additions & 1 deletion oneflow/core/kernel/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ class Kernel {
virtual ~Kernel();

void Init(const KernelConf& kernel_conf, KernelContext* ctx);
virtual void DestroyState(void* state) const;
void Launch(KernelContext* ctx) const;

const OperatorConf& op_conf() const { return op_attribute().op_conf(); }
Expand Down
12 changes: 10 additions & 2 deletions oneflow/core/kernel/kernel_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ namespace oneflow {

class Blob;
class JobDesc;

class KernelState {
public:
OF_DISALLOW_COPY_AND_MOVE(KernelState);
KernelState() = default;
virtual ~KernelState() = default;
};

class KernelContext : public KernelObserver {
public:
OF_DISALLOW_COPY_AND_MOVE(KernelContext);
Expand All @@ -31,8 +39,8 @@ class KernelContext : public KernelObserver {

virtual DeviceCtx* device_ctx() const = 0;
virtual Blob* BnInOp2Blob(const std::string& bn) const = 0;
virtual void* state() const = 0;
virtual void set_state(void* state) = 0;
virtual const std::shared_ptr<KernelState>& state() const = 0;
virtual void set_state(std::shared_ptr<KernelState> state) = 0;
};

} // namespace oneflow
Expand Down
9 changes: 2 additions & 7 deletions oneflow/core/kernel/reentrant_lock_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,17 +82,12 @@ void ReentrantLockStatus::ReleaseLock(int64_t lock_id, std::queue<int64_t>* unlo

template<typename T>
void ReentrantLockKernel<T>::VirtualKernelInit(KernelContext* ctx) {
ctx->set_state(new ReentrantLockStatus);
}

template<typename T>
void ReentrantLockKernel<T>::DestroyState(void* state) const {
delete static_cast<ReentrantLockStatus*>(state);
ctx->set_state(std::make_shared<ReentrantLockStatus>());
}

template<typename T>
void ReentrantLockKernel<T>::ForwardDataContent(KernelContext* ctx) const {
auto* const status = static_cast<ReentrantLockStatus*>(ctx->state());
auto* const status = CHECK_NOTNULL(dynamic_cast<ReentrantLockStatus*>(ctx->state().get()));
if (status->cur_ibn() == "start") {
T lock_id = *ctx->BnInOp2Blob("start")->dptr<T>();
status->RequestLock(lock_id, status->mut_cur_unlocked_ids());
Expand Down
11 changes: 5 additions & 6 deletions oneflow/core/kernel/reentrant_lock_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ limitations under the License.

namespace oneflow {

class ReentrantLockStatus final {
class ReentrantLockStatus final : public KernelState {
public:
OF_DISALLOW_COPY_AND_MOVE(ReentrantLockStatus);
ReentrantLockStatus() = default;
Expand Down Expand Up @@ -60,10 +60,10 @@ class ReentrantLockStatus final {
bool TryAcquireLock(int64_t lock_id);

std::string cur_ibn_;
int64_t cur_act_id_;
bool acquired_lock_to_be_sent_;
size_t total_queued_request_lock_num_;
size_t total_acquired_lock_num_;
int64_t cur_act_id_{};
bool acquired_lock_to_be_sent_{};
size_t total_queued_request_lock_num_{};
size_t total_acquired_lock_num_{};
std::vector<std::queue<int64_t>> lock_id2queued_request_act_id_;
std::vector<size_t> lock_id2acquired_num_;
std::vector<std::vector<int64_t>> lock_id2intersecting_lock_ids_;
Expand All @@ -79,7 +79,6 @@ class ReentrantLockKernel final : public Kernel {

private:
void VirtualKernelInit(KernelContext* ctx) override;
void DestroyState(void* state) const override;
void ForwardDataContent(KernelContext* ctx) const override;
};

Expand Down
10 changes: 2 additions & 8 deletions oneflow/core/kernel/wait_and_send_ids_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,12 @@ namespace oneflow {

template<typename T>
void WaitAndSendIdsKernel<T>::VirtualKernelInit(KernelContext* ctx) {
ctx->set_state(new WaitAndSendIdsStatus);
}

template<typename T>
void WaitAndSendIdsKernel<T>::DestroyState(void* state) const {
delete static_cast<WaitAndSendIdsStatus*>(state);
ctx->set_state(std::make_shared<WaitAndSendIdsStatus>());
}

template<typename T>
void WaitAndSendIdsKernel<T>::ForwardDataContent(KernelContext* ctx) const {
CHECK(ctx->state());
auto* status = static_cast<WaitAndSendIdsStatus*>(ctx->state());
auto* status = CHECK_NOTNULL(dynamic_cast<WaitAndSendIdsStatus*>(ctx->state().get()));
const auto& conf = this->op_conf().wait_and_send_ids_conf();
if (status->out_idx_ >= status->out_num_) {
if (CHECK_JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
Expand Down
3 changes: 1 addition & 2 deletions oneflow/core/kernel/wait_and_send_ids_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ limitations under the License.

namespace oneflow {

struct WaitAndSendIdsStatus final {
struct WaitAndSendIdsStatus final : public KernelState {
BufferStatus buffer_status_;
int64_t in_id_;
int64_t out_idx_;
Expand All @@ -37,7 +37,6 @@ class WaitAndSendIdsKernel final : public Kernel {

private:
void VirtualKernelInit(KernelContext* ctx) override;
void DestroyState(void* state) const override;
void ForwardDataContent(KernelContext* ctx) const override;
};

Expand Down

0 comments on commit eae9ff3

Please sign in to comment.