Skip to content

Commit

Permalink
cc: 80->100
Browse files Browse the repository at this point in the history
  • Loading branch information
willzhang4a58 committed Apr 23, 2018
1 parent 2ce2f34 commit b992ff5
Show file tree
Hide file tree
Showing 364 changed files with 5,736 additions and 8,823 deletions.
2 changes: 1 addition & 1 deletion .clang-format
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ BreakBeforeTernaryOperators: true
BreakConstructorInitializersBeforeComma: false
BreakAfterJavaFieldAnnotations: false
BreakStringLiterals: true
ColumnLimit: 80
ColumnLimit: 100
CommentPragmas: '^ IWYU pragma:'
BreakBeforeInheritanceComma: false
ConstructorInitializerAllOnOneLineOrOnePerLine: true
Expand Down
17 changes: 5 additions & 12 deletions oneflow/core/actor/accumulate_compute_actor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

namespace oneflow {

void AccumulateCompActor::Init(const TaskProto& task_proto, int32_t max_acc_cnt,
ColIdOrder order) {
void AccumulateCompActor::Init(const TaskProto& task_proto, int32_t max_acc_cnt, ColIdOrder order) {
using namespace std::placeholders;
is_in_eord_ = false;
order_ = order;
Expand All @@ -16,8 +15,7 @@ void AccumulateCompActor::Init(const TaskProto& task_proto, int32_t max_acc_cnt,
);
} else {
#ifdef WITH_CUDA
cpy_func_ = std::bind(Memcpy<DeviceType::kGPU>, _1, _2, _3, _4,
cudaMemcpyDeviceToDevice);
cpy_func_ = std::bind(Memcpy<DeviceType::kGPU>, _1, _2, _3, _4, cudaMemcpyDeviceToDevice);
#else
UNIMPLEMENTED();
#endif
Expand All @@ -34,9 +32,7 @@ int AccumulateCompActor::HandlerNormal(const ActorMsg& msg) {
DecreaseRemainingEordCnt();
} else if (msg.msg_type() == ActorMsgType::kRegstMsg) {
Regst* regst = msg.regst();
if (TryUpdtStateAsProducedRegst(regst) != 0) {
pending_in_regst_.push(regst);
}
if (TryUpdtStateAsProducedRegst(regst) != 0) { pending_in_regst_.push(regst); }
ActUntilFail();
} else {
UNIMPLEMENTED();
Expand All @@ -48,9 +44,7 @@ bool AccumulateCompActor::IsReadAlwaysUnReadyFromNow() {
return is_in_eord_ && pending_in_regst_.empty();
}

void AccumulateCompActor::AsyncReturnAllReadableRegst() {
CHECK(pending_in_regst_.empty());
}
void AccumulateCompActor::AsyncReturnAllReadableRegst() { CHECK(pending_in_regst_.empty()); }

void AccumulateCompActor::Act() {
Regst* in_regst = pending_in_regst_.front();
Expand Down Expand Up @@ -85,8 +79,7 @@ void AccumulateCompActor::Act() {
pending_in_regst_.pop();
}

void AccumulateCompActor::ForEachCurReadableRegst(
std::function<void(const Regst*)> handler) {
void AccumulateCompActor::ForEachCurReadableRegst(std::function<void(const Regst*)> handler) {
handler(pending_in_regst_.front());
}

Expand Down
80 changes: 26 additions & 54 deletions oneflow/core/actor/actor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,11 @@ bool IsLastRegstInPieceWithOrder(const Regst* regst, ColIdOrder order) {

bool NeedModelSave(int64_t model_version_id) {
return model_version_id + 1 == Global<JobDesc>::Get()->TotalBatchNum()
|| (model_version_id + 1)
% Global<JobDesc>::Get()->NumOfBatchesInSnapshot()
== 0;
|| (model_version_id + 1) % Global<JobDesc>::Get()->NumOfBatchesInSnapshot() == 0;
}

Actor::~Actor() {
if (Global<RuntimeCtx>::Get()->is_experiment_phase() == false
&& act_id_ >= 0) {
if (Global<RuntimeCtx>::Get()->is_experiment_phase() == false && act_id_ >= 0) {
double avg_act_interval = act_interval_acc_ / (act_id_ + 1);
Global<CtrlClient>::Get()->PushAvgActInterval(actor_id_, avg_act_interval);
}
Expand All @@ -42,9 +39,7 @@ void Actor::Init(const TaskProto& task_proto, const ThreadCtx& thread_ctx) {
for (const auto& pair : task_proto.produced_regst_desc()) {
Global<RegstMgr>::Get()->NewRegsts(
pair.second, GetDeviceType(), task_proto.record_type(),
[this](Regst* regst) {
produced_regsts_[regst->regst_desc_id()].emplace_back(regst);
});
[this](Regst* regst) { produced_regsts_[regst->regst_desc_id()].emplace_back(regst); });
int64_t regst_desc_id = pair.second.regst_desc_id();
CHECK(name2regst_desc_id_.emplace(pair.first, regst_desc_id).second);
}
Expand All @@ -67,13 +62,9 @@ void Actor::Init(const TaskProto& task_proto, const ThreadCtx& thread_ctx) {
VirtualActorInit(task_proto);
}

int64_t Actor::machine_id() const {
return Global<IDMgr>::Get()->MachineId4ActorId(actor_id_);
}
int64_t Actor::machine_id() const { return Global<IDMgr>::Get()->MachineId4ActorId(actor_id_); }

int64_t Actor::thrd_id() const {
return Global<IDMgr>::Get()->ThrdId4ActorId(actor_id_);
}
int64_t Actor::thrd_id() const { return Global<IDMgr>::Get()->ThrdId4ActorId(actor_id_); }

int64_t Actor::RegstDescId4Name(const std::string& name) const {
auto find_it = name2regst_desc_id_.find(name);
Expand All @@ -89,10 +80,10 @@ void Actor::InitDeviceCtx(const ThreadCtx&) {
}
#ifdef WITH_CUDA
case DeviceType::kGPU: {
device_ctx_.reset(new CudaDeviceCtx(
NewWorkStreamId(), cuda_handle_.cuda_stream(),
cuda_handle_.cublas_pmh_handle(), cuda_handle_.cublas_pmd_handle(),
cuda_handle_.cudnn_handle(), cuda_handle_.eigen_gpu_device()));
device_ctx_.reset(
new CudaDeviceCtx(NewWorkStreamId(), cuda_handle_.cuda_stream(),
cuda_handle_.cublas_pmh_handle(), cuda_handle_.cublas_pmd_handle(),
cuda_handle_.cudnn_handle(), cuda_handle_.eigen_gpu_device()));
break;
}
#endif
Expand All @@ -115,9 +106,7 @@ int Actor::HandlerZombie(const ActorMsg& msg) {
if (msg.msg_type() == ActorMsgType::kEordMsg) {
remaining_eord_cnt_ -= 1;
} else if (msg.msg_type() == ActorMsgType::kRegstMsg) {
if (TryUpdtStateAsProducedRegst(msg.regst()) != 0) {
AsyncSendRegstMsgToProducer(msg.regst());
}
if (TryUpdtStateAsProducedRegst(msg.regst()) != 0) { AsyncSendRegstMsgToProducer(msg.regst()); }
} else {
UNIMPLEMENTED();
}
Expand All @@ -141,8 +130,7 @@ void Actor::ActUntilFail() {
ReadableRegstInfo* info = act_event->add_readable_regst_infos();
SetReadableRegstInfo(readable_regst, info);
});
device_ctx_->AddCallBack(
[act_event]() { act_event->set_start_time(GetCurTime()); });
device_ctx_->AddCallBack([act_event]() { act_event->set_start_time(GetCurTime()); });
}
double cur_time = GetCurTime();
if (last_act_start_time_ > 0.0) {
Expand Down Expand Up @@ -182,25 +170,21 @@ int Actor::TrySwitchToZombieOrFinish() {
return 0;
}

void Actor::AsyncLaunchKernel(
const KernelCtx& kernel_ctx,
std::function<Regst*(int64_t)> Regst4RegstDescId) {
void Actor::AsyncLaunchKernel(const KernelCtx& kernel_ctx,
std::function<Regst*(int64_t)> Regst4RegstDescId) {
for (const ExecKernel& ek : exec_kernel_vec_) {
ek.kernel->Launch(kernel_ctx, [&](const std::string& bn_in_op) -> Blob* {
auto regst_desc_id_it = ek.bn_in_op2regst_desc_id.find(bn_in_op);
if (regst_desc_id_it == ek.bn_in_op2regst_desc_id.end()) {
return nullptr;
}
if (regst_desc_id_it == ek.bn_in_op2regst_desc_id.end()) { return nullptr; }
Regst* regst = Regst4RegstDescId(regst_desc_id_it->second);
const LogicalBlobId& lbi = ek.kernel->BnInOp2Lbi(bn_in_op);
return regst->GetBlobByLbi(lbi);
});
}
}

void Actor::AsyncSendRegstMsgToConsumer(
std::function<bool(Regst*)> RegstPreProcess,
std::function<bool(int64_t)> IsAllowedActor) {
void Actor::AsyncSendRegstMsgToConsumer(std::function<bool(Regst*)> RegstPreProcess,
std::function<bool(int64_t)> IsAllowedActor) {
int64_t this_actor_id = actor_id_;
for (auto& pair : writeable_produced_regst_) {
if (pair.second.empty()) { continue; }
Expand All @@ -214,8 +198,7 @@ void Actor::AsyncSendRegstMsgToConsumer(
regst_reading_cnt_it->second += 1;
regst->set_act_id(act_id_);
device_ctx_->AddCallBack([consumer, regst, this_actor_id]() {
ActorMsg msg =
ActorMsg::BuildRegstMsgToConsumer(this_actor_id, consumer, regst);
ActorMsg msg = ActorMsg::BuildRegstMsgToConsumer(this_actor_id, consumer, regst);
Global<ActorMsgBus>::Get()->SendMsg(std::move(msg));
});
}
Expand All @@ -224,13 +207,11 @@ void Actor::AsyncSendRegstMsgToConsumer(
}
}

void Actor::AsyncSendRegstMsgToConsumer(
std::function<bool(Regst*)> RegstPreProcess) {
void Actor::AsyncSendRegstMsgToConsumer(std::function<bool(Regst*)> RegstPreProcess) {
AsyncSendRegstMsgToConsumer(RegstPreProcess, [](int64_t) { return true; });
}

void Actor::AsyncSendRegstMsgToConsumer(
std::function<bool(int64_t)> IsAllowedActor) {
void Actor::AsyncSendRegstMsgToConsumer(std::function<bool(int64_t)> IsAllowedActor) {
AsyncSendRegstMsgToConsumer([](Regst*) { return true; }, IsAllowedActor);
}

Expand All @@ -239,21 +220,17 @@ void Actor::AsyncSendRegstMsgToConsumer() {
}

void Actor::AsyncSendEORDMsgToConsumers(int64_t regst_desc_id) {
const RtRegstDesc* regst_desc =
produced_regsts_.at(regst_desc_id).front()->regst_desc();
const RtRegstDesc* regst_desc = produced_regsts_.at(regst_desc_id).front()->regst_desc();
device_ctx_->AddCallBack([regst_desc]() {
for (int64_t consumer : regst_desc->consumers_actor_id()) {
ActorMsg msg =
ActorMsg::BuildEordMsg(consumer, regst_desc->regst_desc_id());
ActorMsg msg = ActorMsg::BuildEordMsg(consumer, regst_desc->regst_desc_id());
Global<ActorMsgBus>::Get()->SendMsg(std::move(msg));
}
});
}

void Actor::AsyncSendEORDMsgForAllProducedRegstDesc() {
for (const auto& pair : produced_regsts_) {
AsyncSendEORDMsgToConsumers(pair.first);
}
for (const auto& pair : produced_regsts_) { AsyncSendEORDMsgToConsumers(pair.first); }
}

void Actor::AsyncSendRegstMsgToProducer(Regst* regst) {
Expand All @@ -262,13 +239,10 @@ void Actor::AsyncSendRegstMsgToProducer(Regst* regst) {

void Actor::AsyncSendRegstMsgToProducer(Regst* regst, int64_t producer) {
ActorMsg msg = ActorMsg::BuildRegstMsgToProducer(actor_id_, producer, regst);
device_ctx_->AddCallBack(
[msg]() { Global<ActorMsgBus>::Get()->SendMsg(msg); });
device_ctx_->AddCallBack([msg]() { Global<ActorMsgBus>::Get()->SendMsg(msg); });
}

void Actor::AsyncDo(std::function<void()> func) {
device_ctx_->AddCallBack(func);
}
void Actor::AsyncDo(std::function<void()> func) { device_ctx_->AddCallBack(func); }

int Actor::TryUpdtStateAsProducedRegst(Regst* regst) {
auto reading_cnt_it = produced_regst2reading_cnt_.find(regst);
Expand Down Expand Up @@ -301,8 +275,7 @@ Regst* Actor::GetCurSoleWriteableRegst() {
}

int64_t Actor::GetReservedWorkStreamId(int64_t reserved_id) {
return Global<IDMgr>::Get()->GetReservedWorkStreamId(machine_id(), thrd_id(),
reserved_id);
return Global<IDMgr>::Get()->GetReservedWorkStreamId(machine_id(), thrd_id(), reserved_id);
}

int64_t Actor::NewWorkStreamId() {
Expand All @@ -322,8 +295,7 @@ void AddActorCreator(TaskType task_type, std::function<Actor*()> creator) {
CHECK(ActorCreatorMap().emplace(task_type, creator).second);
}

std::unique_ptr<Actor> NewActor(const TaskProto& task_proto,
const ThreadCtx& thread_ctx) {
std::unique_ptr<Actor> NewActor(const TaskProto& task_proto, const ThreadCtx& thread_ctx) {
auto it = ActorCreatorMap().find(task_proto.task_type());
CHECK(it != ActorCreatorMap().end()) << TaskType_Name(task_proto.task_type());
Actor* rptr = it->second();
Expand Down
8 changes: 3 additions & 5 deletions oneflow/core/actor/actor.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,7 @@ class Actor {
int TrySwitchToZombieOrFinish();

// Async Do on device_ctx_
void AsyncLaunchKernel(const KernelCtx&,
std::function<Regst*(int64_t)> Regst4RegstDescId);
void AsyncLaunchKernel(const KernelCtx&, std::function<Regst*(int64_t)> Regst4RegstDescId);
void AsyncSendRegstMsgToConsumer(std::function<bool(Regst*)> RegstPreProcess,
std::function<bool(int64_t)> IsAllowedActor);
void AsyncSendRegstMsgToConsumer(std::function<bool(Regst*)> RegstPreProcess);
Expand Down Expand Up @@ -139,9 +138,8 @@ struct ActorRegistry {
}
};

#define REGISTER_ACTOR(TaskType, ActorType) \
static ActorRegistry<TaskType, ActorType> OF_PP_CAT( \
g_actor_##ActorType##registry_var, __LINE__)
#define REGISTER_ACTOR(TaskType, ActorType) \
static ActorRegistry<TaskType, ActorType> OF_PP_CAT(g_actor_##ActorType##registry_var, __LINE__)

} // namespace oneflow

Expand Down
3 changes: 1 addition & 2 deletions oneflow/core/actor/actor_message.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ ActorMsg ActorMsg::BuildRegstMsgToConsumer(int64_t producer, int64_t consumer,
== Global<MachineCtx>::Get()->this_machine_id()) {
msg.regst_wrapper_.comm_net_token = nullptr;
} else {
msg.regst_wrapper_.comm_net_token =
regst_raw_ptr->packed_blob()->comm_net_token();
msg.regst_wrapper_.comm_net_token = regst_raw_ptr->packed_blob()->comm_net_token();
msg.regst_wrapper_.regst_status = regst_raw_ptr->status();
}
return msg;
Expand Down
6 changes: 2 additions & 4 deletions oneflow/core/actor/actor_message.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,8 @@ class ActorMsg final {
~ActorMsg() = default;

// Build Msg
static ActorMsg BuildRegstMsgToConsumer(int64_t producer, int64_t consumer,
Regst*);
static ActorMsg BuildRegstMsgToProducer(int64_t consumer, int64_t producer,
Regst*);
static ActorMsg BuildRegstMsgToConsumer(int64_t producer, int64_t consumer, Regst*);
static ActorMsg BuildRegstMsgToProducer(int64_t consumer, int64_t producer, Regst*);
static ActorMsg BuildEordMsg(int64_t consumer, int64_t regst_desc_id);
static ActorMsg BuildCommandMsg(int64_t dst_actor_id, ActorCmd cmd);

Expand Down
3 changes: 1 addition & 2 deletions oneflow/core/actor/actor_message_bus.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
namespace oneflow {

void ActorMsgBus::SendMsg(const ActorMsg& msg) {
int64_t dst_machine_id =
Global<IDMgr>::Get()->MachineId4ActorId(msg.dst_actor_id());
int64_t dst_machine_id = Global<IDMgr>::Get()->MachineId4ActorId(msg.dst_actor_id());
if (dst_machine_id == Global<MachineCtx>::Get()->this_machine_id()) {
int64_t thrd_id = Global<IDMgr>::Get()->ThrdId4ActorId(msg.dst_actor_id());
Global<ThreadMgr>::Get()->GetThrd(thrd_id)->GetMsgChannelPtr()->Send(msg);
Expand Down
41 changes: 15 additions & 26 deletions oneflow/core/actor/backward_compute_actor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ void BackwardCompActor::VirtualCompActorInit(const TaskProto& task_proto) {

int BackwardCompActor::HandlerNormal(const ActorMsg& msg) {
if (msg.msg_type() == ActorMsgType::kEordMsg) {
if (msg.eord_regst_desc_id() == out_diff_regst_desc_id_) {
is_out_diff_eord_ = true;
}
if (msg.eord_regst_desc_id() == out_diff_regst_desc_id_) { is_out_diff_eord_ = true; }
DecreaseRemainingEordCnt();
} else if (msg.msg_type() == ActorMsgType::kRegstMsg) {
Regst* regst = msg.regst();
Expand All @@ -42,13 +40,10 @@ int BackwardCompActor::HandlerNormal(const ActorMsg& msg) {
return TrySwitchToZombieOrFinish();
}

bool BackwardCompActor::IsReadReady() {
return readable_regsts_.size() == readable_regst_cnt_;
}
bool BackwardCompActor::IsReadReady() { return readable_regsts_.size() == readable_regst_cnt_; }

bool BackwardCompActor::IsReadAlwaysUnReadyFromNow() {
return is_out_diff_eord_
&& readable_regsts_.at(out_diff_regst_desc_id_).empty();
return is_out_diff_eord_ && readable_regsts_.at(out_diff_regst_desc_id_).empty();
}

void BackwardCompActor::AsyncReturnAllReadableRegst() {
Expand All @@ -66,19 +61,15 @@ void BackwardCompActor::AsyncReturnModelRegstUntilMatchCurOutRegst() {
const Regst* cur_out_regst = readable_regsts_.at(out_regst_desc_id_).front();
int64_t cur_model_id = cur_out_regst->model_version_id();
std::queue<Regst*>& model_rq = readable_regsts_.at(model_regst_desc_id_);
while (!model_rq.empty()
&& model_rq.front()->model_version_id() < cur_model_id) {
while (!model_rq.empty() && model_rq.front()->model_version_id() < cur_model_id) {
AsyncSendRegstMsgToProducer(model_rq.front());
model_rq.pop();
if (model_rq.empty()) { readable_regst_cnt_ -= 1; }
}
if (!model_rq.empty()) {
CHECK_EQ(model_rq.front()->model_version_id(), cur_model_id);
}
if (!model_rq.empty()) { CHECK_EQ(model_rq.front()->model_version_id(), cur_model_id); }
}

void BackwardCompActor::AsyncReturnModelRegstUntilLastPieceIdGreaterThan(
int64_t piece_id) {
void BackwardCompActor::AsyncReturnModelRegstUntilLastPieceIdGreaterThan(int64_t piece_id) {
if (model_regst_desc_id_ == -1) { return; }
std::queue<Regst*>& model_rq = readable_regsts_.at(model_regst_desc_id_);
while (model_rq.empty() == false) {
Expand All @@ -94,15 +85,14 @@ void BackwardCompActor::AsyncReturnModelRegstUntilLastPieceIdGreaterThan(
void BackwardCompActor::Act() {
std::queue<Regst*>& out_rq = readable_regsts_.at(out_regst_desc_id_);
int64_t piece_id = out_rq.front()->piece_id();
AsyncLaunchKernel(GenDefaultKernelCtx(),
[this](int64_t regst_desc_id) -> Regst* {
Regst* regst = GetCurWriteableRegst(regst_desc_id);
if (regst == nullptr) {
return readable_regsts_.at(regst_desc_id).front();
} else {
return regst;
}
});
AsyncLaunchKernel(GenDefaultKernelCtx(), [this](int64_t regst_desc_id) -> Regst* {
Regst* regst = GetCurWriteableRegst(regst_desc_id);
if (regst == nullptr) {
return readable_regsts_.at(regst_desc_id).front();
} else {
return regst;
}
});
AsyncSendRegstMsgToConsumer([&](Regst* regst) {
regst->set_piece_id(piece_id);
return true;
Expand All @@ -125,8 +115,7 @@ void BackwardCompActor::Act() {
}
}

void BackwardCompActor::ForEachCurReadableRegst(
std::function<void(const Regst*)> handler) {
void BackwardCompActor::ForEachCurReadableRegst(std::function<void(const Regst*)> handler) {
for (const auto& pair : readable_regsts_) { handler(pair.second.front()); }
}

Expand Down
Loading

0 comments on commit b992ff5

Please sign in to comment.