Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

StreamContext #6129

Merged
merged 15 commits into from
Sep 4, 2021
Merged
85 changes: 75 additions & 10 deletions oneflow/core/actor/actor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.
#include "oneflow/core/thread/thread_manager.h"
#include "oneflow/core/job/runtime_job_descs.h"
#include "oneflow/core/control/global_process_ctx.h"
#include "oneflow/core/stream/stream_context.h"

namespace oneflow {

Expand All @@ -27,7 +28,19 @@ class KernelContextImpl : public KernelContext {
public:
OF_DISALLOW_COPY_AND_MOVE(KernelContextImpl);
explicit KernelContextImpl(const JobDesc* job_desc, DeviceCtx* device_ctx)
: job_desc_(job_desc), device_ctx_(device_ctx), state_(nullptr) {}
: job_desc_(job_desc),
device_ctx_(device_ctx),
state_(nullptr),
stream_kernel_observer_(nullptr) {
auto* stream_context_provider = dynamic_cast<StreamContextProvider*>(device_ctx);
if (stream_context_provider != nullptr) {
auto* kernel_observer_provider =
dynamic_cast<KernelObserverProvider*>(stream_context_provider->GetStreamContext());
if (kernel_observer_provider != nullptr) {
stream_kernel_observer_ = kernel_observer_provider->GetKernelObserver();
}
}
}
~KernelContextImpl() = default;

DeviceCtx* device_ctx() const override { return device_ctx_; }
Expand All @@ -43,6 +56,15 @@ class KernelContextImpl : public KernelContext {

const JobDesc* job_desc() const override { return job_desc_; }

void WillForward(KernelContext* kernel_ctx, const Kernel* kernel) override;
void DidForward(KernelContext* kernel_ctx, const Kernel* kernel) override;

void WillForwardHeader(KernelContext* kernel_ctx, const Kernel* kernel) override;
void DidForwardHeader(KernelContext* kernel_ctx, const Kernel* kernel) override;

void WillForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) override;
void DidForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) override;

void UpdateBnInOp2BlobFn(std::function<Blob*(const std::string&)> fn) {
bn_in_op2blob_fn_ = std::move(fn);
}
Expand All @@ -52,8 +74,51 @@ class KernelContextImpl : public KernelContext {
DeviceCtx* device_ctx_;
std::function<Blob*(const std::string&)> bn_in_op2blob_fn_;
void* state_;
KernelObserver* stream_kernel_observer_;
};

void KernelContextImpl::WillForward(KernelContext* kernel_ctx, const Kernel* kernel) {
Global<KernelObserver>::Get()->WillForward(kernel_ctx, kernel);
if (stream_kernel_observer_ != nullptr) {
stream_kernel_observer_->WillForward(kernel_ctx, kernel);
}
}

void KernelContextImpl::DidForward(KernelContext* kernel_ctx, const Kernel* kernel) {
Global<KernelObserver>::Get()->DidForward(kernel_ctx, kernel);
if (stream_kernel_observer_ != nullptr) {
stream_kernel_observer_->DidForward(kernel_ctx, kernel);
}
}

void KernelContextImpl::WillForwardHeader(KernelContext* kernel_ctx, const Kernel* kernel) {
Global<KernelObserver>::Get()->WillForwardHeader(kernel_ctx, kernel);
if (stream_kernel_observer_ != nullptr) {
stream_kernel_observer_->WillForwardHeader(kernel_ctx, kernel);
}
}

void KernelContextImpl::DidForwardHeader(KernelContext* kernel_ctx, const Kernel* kernel) {
Global<KernelObserver>::Get()->DidForwardHeader(kernel_ctx, kernel);
if (stream_kernel_observer_ != nullptr) {
stream_kernel_observer_->DidForwardHeader(kernel_ctx, kernel);
}
}

void KernelContextImpl::WillForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) {
Global<KernelObserver>::Get()->WillForwardDataContent(kernel_ctx, kernel);
if (stream_kernel_observer_ != nullptr) {
stream_kernel_observer_->WillForwardDataContent(kernel_ctx, kernel);
}
}

void KernelContextImpl::DidForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) {
Global<KernelObserver>::Get()->DidForwardDataContent(kernel_ctx, kernel);
if (stream_kernel_observer_ != nullptr) {
stream_kernel_observer_->DidForwardDataContent(kernel_ctx, kernel);
}
}

void CheckInplaceRegstDescId(const TaskProto& task_proto) {
HashSet<int64_t> consumed_regst_desc_ids;
for (const auto& pair : task_proto.consumed_regst_desc_id()) {
Expand All @@ -72,13 +137,12 @@ Actor::~Actor() {
for (ExecKernel& ek : exec_kernel_vec_) { ek.kernel->DestroyState(ek.kernel_ctx->state()); }
}

void Actor::Init(const JobDesc* job_desc, const TaskProto& task_proto,
const ThreadCtx& thread_ctx) {
void Actor::Init(const JobDesc* job_desc, const TaskProto& task_proto, StreamContext* stream_ctx) {
job_desc_ = job_desc;
actor_id_ = task_proto.task_id();
thrd_id_ = Global<IDMgr>::Get()->ThrdId4ActorId(actor_id_);
job_id_ = task_proto.job_id();
InitDeviceCtx(thread_ctx);
InitDeviceCtx(stream_ctx);
if (task_proto.has_parallel_ctx()) {
parallel_ctx_.reset(new ParallelContext(task_proto.parallel_ctx()));
}
Expand Down Expand Up @@ -269,10 +333,7 @@ void Actor::IncreaseReadingCnt4ProducedRegst(Regst* regst, int64_t val) {
produced_regst2reading_cnt_.at(regst) += val;
}

void Actor::InitDeviceCtx(const ThreadCtx& thread_ctx) {
DeviceCtx* dev_ctx = NewObj<int, DeviceCtx, const ThreadCtx&>(GetDeviceType(), thread_ctx);
device_ctx_.reset(dev_ctx);
}
void Actor::InitDeviceCtx(StreamContext* stream_ctx) { device_ctx_ = stream_ctx->device_ctx(); }

void Actor::ForEachCurNaiveReadableDataRegst(std::function<void(const Regst*)> func) const {
naive_consumed_rs_.ForEachFrontRegst([func](int64_t regst_desc_id, Regst* regst) {
Expand Down Expand Up @@ -568,7 +629,7 @@ void Actor::AsyncSendEORDMsgForAllProducedRegstDesc() {
for (auto& pair : produced_regsts_) {
CHECK(!pair.second.empty());
const RtRegstDesc* regst_desc = pair.second.front()->regst_desc();
device_ctx_->AddCallBack([regst_desc]() {
AddCallback([regst_desc]() {
for (int64_t consumer : regst_desc->consumers_actor_id()) {
Global<ActorMsgBus>::Get()->SendMsg(
ActorMsg::BuildEordMsg(consumer, regst_desc->regst_desc_id()));
Expand Down Expand Up @@ -649,10 +710,14 @@ void Actor::AsyncSendQueuedMsg() {
if (!async_msg_queue_.empty()) {
std::deque<ActorMsg> msgs;
msgs.swap(async_msg_queue_);
device_ctx_->AddCallBack([msgs]() {
AddCallback([msgs]() {
for (const ActorMsg& msg : msgs) { Global<ActorMsgBus>::Get()->SendMsg(msg); }
});
}
}

void Actor::AddCallback(std::function<void()> callback) {
device_ctx_->AddCallBack(std::move(callback));
}

} // namespace oneflow
10 changes: 6 additions & 4 deletions oneflow/core/actor/actor.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class Actor : public ActorBase {

const JobDesc& job_desc() const { return *job_desc_; }

void Init(const JobDesc* job_desc, const TaskProto&, const ThreadCtx&) override;
void Init(const JobDesc* job_desc, const TaskProto&, StreamContext* stream_ctx) override;

// 1: success, and actor finish
// 0: success, and actor not finish
Expand Down Expand Up @@ -72,8 +72,8 @@ class Actor : public ActorBase {
virtual void VirtualActorInit(const TaskProto&) {}
int64_t Name2SoleRegstDescId(const std::string& name) const;
const std::vector<int64_t>& Name2RegstDescIds(const std::string& name) const;
virtual void InitDeviceCtx(const ThreadCtx&);
std::unique_ptr<DeviceCtx>& mut_device_ctx() { return device_ctx_; }
virtual void InitDeviceCtx(StreamContext* stream_ctx);
std::shared_ptr<DeviceCtx>& mut_device_ctx() { return device_ctx_; }
const std::vector<ExecKernel>& exec_kernel_vec() { return exec_kernel_vec_; }
void ForEachCurNaiveReadableDataRegst(std::function<void(const Regst*)>) const;

Expand Down Expand Up @@ -194,6 +194,8 @@ class Actor : public ActorBase {
virtual void AsyncSendCustomizedConsumedRegstMsgToProducer() {}
void AsyncRetInplaceConsumedRegstIfNoConsumer();

virtual void AddCallback(std::function<void()> callback);

const JobDesc* job_desc_;
int64_t actor_id_;
int64_t thrd_id_;
Expand All @@ -202,7 +204,7 @@ class Actor : public ActorBase {
std::vector<ExecKernel> exec_kernel_vec_;
HashMap<std::string, std::vector<int64_t>> name2regst_desc_id_;
MsgHandler msg_handler_;
std::unique_ptr<DeviceCtx> device_ctx_;
std::shared_ptr<DeviceCtx> device_ctx_;
HashSet<int64_t> eord_regst_desc_ids_;
int64_t remaining_eord_cnt_;

Expand Down
4 changes: 2 additions & 2 deletions oneflow/core/actor/actor_base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ limitations under the License.

namespace oneflow {

std::unique_ptr<ActorBase> NewActor(const TaskProto& task_proto, const ThreadCtx& thread_ctx) {
std::unique_ptr<ActorBase> NewActor(const TaskProto& task_proto, StreamContext* stream_ctx) {
ActorBase* rptr = NewObj<int32_t, ActorBase>(task_proto.task_type());
const auto& job_descs = *Global<RuntimeJobDescs>::Get();
rptr->Init(&job_descs.job_desc(task_proto.job_id()), task_proto, thread_ctx);
rptr->Init(&job_descs.job_desc(task_proto.job_id()), task_proto, stream_ctx);
return std::unique_ptr<ActorBase>(rptr);
}

Expand Down
6 changes: 3 additions & 3 deletions oneflow/core/actor/actor_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ namespace oneflow {

class JobDesc;
class TaskProto;
class ThreadCtx;
class StreamContext;
class ActorMsg;

class ActorBase {
Expand All @@ -32,14 +32,14 @@ class ActorBase {
ActorBase() = default;
virtual ~ActorBase() = default;

virtual void Init(const JobDesc* job_desc, const TaskProto&, const ThreadCtx&) = 0;
virtual void Init(const JobDesc* job_desc, const TaskProto&, StreamContext* stream_ctx) = 0;

// 1: success, and actor finish
// 0: success, and actor not finish
virtual int ProcessMsg(const ActorMsg& msg) = 0;
};

std::unique_ptr<ActorBase> NewActor(const TaskProto&, const ThreadCtx&);
std::unique_ptr<ActorBase> NewActor(const TaskProto&, StreamContext* stream_ctx);

#define REGISTER_ACTOR(task_type, ActorType) \
REGISTER_CLASS(int32_t, task_type, ActorBase, ActorType)
Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/actor/collective_boxing_generic_actor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class CollectiveBoxingGenericActor : public NaiveActor {
~CollectiveBoxingGenericActor() override = default;

private:
void InitDeviceCtx(const ThreadCtx& thread_ctx) override {
void InitDeviceCtx(StreamContext* stream_ctx) override {
mut_device_ctx().reset(new CollectiveBoxingDeviceCtx());
}
};
Expand Down
32 changes: 6 additions & 26 deletions oneflow/core/actor/copy_comm_net_actor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ class CopyCommNetActor final : public Actor {
~CopyCommNetActor();

private:
class CommNetDeviceCtx;
struct RegstCtx {
void* comm_net_token;
Regst* regst_raw_ptr;
Expand All @@ -35,7 +34,6 @@ class CopyCommNetActor final : public Actor {
};

void VirtualActorInit(const TaskProto&) override;
void InitDeviceCtx(const ThreadCtx&) override;

std::pair<RegstNameType, HashSet<std::string>> GetNaiveOrCustomizedConsumedRegstDescName()
override {
Expand All @@ -50,44 +48,22 @@ class CopyCommNetActor final : public Actor {
bool IsCustomizedReadReady() const override;
bool IsCustomizedReadAlwaysUnReadyFromNow() const override;
void AsyncReturnAllCustomizedReadableRegst() override;

void AddCallback(std::function<void()> callback) override;
bool is_in_eord_;
HashMap<int64_t, RegstCtx> sequence_number2regst_ctx_;
void* actor_read_id_;
CommNetDeviceCtx* comm_net_device_ctx_;
int64_t next_sequence_number_;
int64_t in_regst_desc_id_;
};

CopyCommNetActor::~CopyCommNetActor() { Global<CommNet>::Get()->DeleteActorReadId(actor_read_id_); }

class CopyCommNetActor::CommNetDeviceCtx final : public DeviceCtx {
public:
CommNetDeviceCtx() = delete;
~CommNetDeviceCtx() = default;

CommNetDeviceCtx(void* actor_read_id) : actor_read_id_(actor_read_id) {}
std::unique_ptr<DeviceCtx> Copy() const { UNIMPLEMENTED(); }

void AddCallBack(std::function<void()> callback) const override {
Global<CommNet>::Get()->AddReadCallBack(actor_read_id_, callback);
}

private:
void* actor_read_id_;
};

void CopyCommNetActor::VirtualActorInit(const TaskProto& task_proto) {
is_in_eord_ = false;
next_sequence_number_ = 0;
in_regst_desc_id_ = Name2SoleRegstDescId("copy_in");
OF_SET_MSG_HANDLER(&CopyCommNetActor::HandlerNormal);
}

void CopyCommNetActor::InitDeviceCtx(const ThreadCtx&) {
actor_read_id_ = Global<CommNet>::Get()->NewActorReadId();
comm_net_device_ctx_ = new CommNetDeviceCtx(actor_read_id_);
mut_device_ctx().reset(comm_net_device_ctx_);
OF_SET_MSG_HANDLER(&CopyCommNetActor::HandlerNormal);
}

void CopyCommNetActor::ForEachCurCustomizedReadableRegst(
Expand Down Expand Up @@ -150,6 +126,10 @@ void CopyCommNetActor::AsyncReturnAllCustomizedReadableRegst() {
CHECK(sequence_number2regst_ctx_.empty());
}

void CopyCommNetActor::AddCallback(std::function<void()> callback) {
Global<CommNet>::Get()->AddReadCallBack(actor_read_id_, callback);
}

REGISTER_ACTOR(TaskType::kCopyCommNet, CopyCommNetActor);

} // namespace oneflow
Loading