Skip to content

Commit

Permalink
Refine acc actor (#6444)
Browse files Browse the repository at this point in the history
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
  • Loading branch information
liujuncheng and oneflow-ci-bot committed Oct 8, 2021
1 parent d893f25 commit 57d0d18
Show file tree
Hide file tree
Showing 12 changed files with 16 additions and 607 deletions.
2 changes: 0 additions & 2 deletions cmake/cfg.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ function(GENERATE_CFG_AND_PYBIND11_CPP SRCS HDRS PYBIND_SRCS ROOT_DIR)
oneflow/core/job/sbp_parallel.proto
oneflow/core/graph/boxing/collective_boxing.proto
oneflow/core/register/blob_desc.proto
oneflow/core/register/pod.proto
oneflow/core/job/scope.proto
oneflow/core/job/mirrored_parallel.proto
oneflow/core/operator/op_attribute.proto
Expand Down Expand Up @@ -100,7 +99,6 @@ function(GENERATE_CFG_AND_PYBIND11_CPP SRCS HDRS PYBIND_SRCS ROOT_DIR)
oneflow/core/operator/interface_blob_conf.proto
oneflow/core/common/shape.proto
oneflow/core/register/blob_desc.proto
oneflow/core/register/pod.proto
oneflow/core/operator/op_conf.proto
)

Expand Down
47 changes: 10 additions & 37 deletions oneflow/core/actor/acc_actor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,14 @@ class AccActor final : public Actor {
AccActor() = default;
~AccActor() override = default;

using Actor::Init;

private:
void Act() override;
void VirtualAsyncSendNaiveProducedRegstMsgToConsumer() override;

void VirtualActorInit(const TaskProto& proto) override;
void Init(const TaskProto&, int32_t max_acc_cnt);

std::function<void(DeviceCtx*, void* dst, const void* src, size_t)> cpy_func_;
int32_t acc_cnt_;
int32_t max_acc_cnt_;
int32_t acc_cnt_{};
int32_t max_acc_cnt_{};
};

void AccActor::VirtualActorInit(const TaskProto& proto) {
Expand All @@ -45,44 +41,21 @@ void AccActor::VirtualActorInit(const TaskProto& proto) {
->RegstDesc4RegstDescId(Name2SoleRegstDescId("out"))
.data_regst_time_shape();
CHECK_GE(in_time_shape.elem_cnt(), out_time_shape.elem_cnt());
Init(proto, in_time_shape.elem_cnt() / out_time_shape.elem_cnt());
}

void AccActor::Init(const TaskProto& task_proto, int32_t max_acc_cnt) {
using namespace std::placeholders;
if (GetDeviceType() == DeviceType::kCPU) {
cpy_func_ = std::bind(Memcpy<DeviceType::kCPU>, _1, _2, _3, _4);
} else {
#ifdef WITH_CUDA
cpy_func_ = std::bind(Memcpy<DeviceType::kGPU>, _1, _2, _3, _4);
#else
UNIMPLEMENTED();
#endif
}
OF_SET_MSG_HANDLER(&AccActor::HandlerNormal);
max_acc_cnt_ = in_time_shape.elem_cnt() / out_time_shape.elem_cnt();
acc_cnt_ = 0;
max_acc_cnt_ = max_acc_cnt;
OF_SET_MSG_HANDLER(&AccActor::HandlerNormal);
}

void AccActor::Act() {
Regst* out_regst = GetNaiveCurWriteable("out");
Regst* in_regst = GetNaiveCurReadable("in");
if (acc_cnt_ == 0) {
Regst* out_regst = GetNaiveCurWriteable("out");
Regst* in_regst = GetNaiveCurReadable("in");
const Blob* in_blob = in_regst->GetMutSoleBlob();
Blob* out_blob = out_regst->GetMutSoleBlob();
if (GetDeviceType() == DeviceType::kCPU) {
Memcpy<DeviceType::kCPU>(mut_device_ctx().get(), out_blob->ForceMutDptr(), in_blob->dptr(),
out_blob->ByteSizeOfBlobBody());
} else if (GetDeviceType() == DeviceType::kGPU) {
#ifdef WITH_CUDA
Memcpy<DeviceType::kGPU>(mut_device_ctx().get(), out_blob->ForceMutDptr(), in_blob->dptr(),
out_blob->ByteSizeOfBlobBody());
#else
UNIMPLEMENTED();
#endif
} else {
UNIMPLEMENTED();
}
const size_t size = in_blob->ByteSizeOfBlobBody();
CHECK_EQ(out_blob->ByteSizeOfBlobBody(), size);
AutoMemcpy(mut_device_ctx().get(), out_blob->ForceMutDptr(), in_blob->dptr(), size,
out_blob->mem_case(), in_blob->mem_case());
} else {
AsyncLaunchKernel();
}
Expand Down
4 changes: 0 additions & 4 deletions oneflow/core/actor/actor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,10 +297,6 @@ void Actor::ForEachProducedRegst(const std::function<void(Regst*)>& Handler) con
}
}

DeviceType Actor::GetDeviceType() const {
return Global<IDMgr>::Get()->GetDeviceTypeFromActorId(actor_id_);
}

int64_t Actor::Name2SoleRegstDescId(const std::string& name) const {
auto find_it = name2regst_desc_id_.find(name);
if (find_it != name2regst_desc_id_.end()) {
Expand Down
1 change: 0 additions & 1 deletion oneflow/core/actor/actor.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ class Actor : public ActorBase {
const ParallelContext* parallel_ctx() const { return parallel_ctx_.get(); }
bool ReceiveAllEordMsg() const { return remaining_eord_cnt_ == 0; }
bool ReceiveEordMsg(int64_t regst_desc_id) const;
DeviceType GetDeviceType() const;
virtual void VirtualActorInit(const TaskProto&) {}
int64_t Name2SoleRegstDescId(const std::string& name) const;
const std::vector<int64_t>& Name2RegstDescIds(const std::string& name) const;
Expand Down
7 changes: 6 additions & 1 deletion oneflow/core/framework/op_arg_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ limitations under the License.
#include "oneflow/core/common/shape.cfg.h"
#include "oneflow/core/register/logical_blob_id.cfg.h"
#include "oneflow/core/operator/interface_blob_conf.cfg.h"
#include "oneflow/core/register/pod.cfg.h"
#include "oneflow/core/register/blob_desc.cfg.h"
#include "oneflow/core/operator/op_node_signature.cfg.h"
#include "oneflow/core/job/parallel_signature.cfg.h"
Expand All @@ -42,6 +41,9 @@ class OpArgBlobAttribute {
const std::string& logical_blob_name);

OpArgBlobAttribute(const OpArgBlobAttribute& op_arg_blob_attr) = default;
OpArgBlobAttribute(OpArgBlobAttribute&& op_arg_blob_attr) = delete;
OpArgBlobAttribute& operator=(const OpArgBlobAttribute&) = delete;
OpArgBlobAttribute& operator=(OpArgBlobAttribute&&) = delete;
virtual ~OpArgBlobAttribute() = default;

std::shared_ptr<cfg::BlobDescProto> blob_desc() const;
Expand Down Expand Up @@ -78,6 +80,9 @@ class OpArgParallelAttribute {
const std::shared_ptr<cfg::OptMirroredParallel>& opt_mirrored_parallel);

OpArgParallelAttribute(const OpArgParallelAttribute& op_arg_para_attr) = default;
OpArgParallelAttribute(OpArgParallelAttribute&& op_arg_blob_attr) = delete;
OpArgParallelAttribute& operator=(const OpArgParallelAttribute&) = delete;
OpArgParallelAttribute& operator=(OpArgParallelAttribute&&) = delete;
virtual ~OpArgParallelAttribute() = default;

std::shared_ptr<ParallelDesc> parallel_desc_symbol() const;
Expand Down
4 changes: 0 additions & 4 deletions oneflow/core/job/id_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,6 @@ limitations under the License.

namespace oneflow {

DeviceType IDMgr::GetDeviceTypeFromActorId(int64_t actor_id) const {
return DeserializeTaskIdFromInt64(actor_id).stream_id().device_id().device_type();
}

int64_t IDMgr::MachineId4ActorId(int64_t actor_id) const {
// TODO: change this inferface semantics, rank does not indicate machine_id in multi-client
return DeserializeTaskIdFromInt64(actor_id).stream_id().device_id().rank();
Expand Down
1 change: 0 additions & 1 deletion oneflow/core/job/id_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ class IDMgr final {
int64_t NewChunkId() { return chunk_id_count_++; }

// Runtime
DeviceType GetDeviceTypeFromActorId(int64_t actor_id) const;
int64_t MachineId4ActorId(int64_t actor_id) const;
int64_t ThrdId4ActorId(int64_t actor_id) const;

Expand Down
41 changes: 0 additions & 41 deletions oneflow/core/register/pod.proto

This file was deleted.

Loading

0 comments on commit 57d0d18

Please sign in to comment.