Skip to content

Commit

Permalink
refactor message OperatorConf, change device_type to device_tag (#3411)
Browse files Browse the repository at this point in the history
* refactor message OperatorConf, change device_type to device_tag

* subsititute HobDeviceType with HobDeviceTag in user_op kernel registration

* remove c_api_util.DeviceType4DeviceTag

* fix error when buil with cuda off

* fix can not use CHECK_JUST macro in another macro
  • Loading branch information
Ldpe2G committed Aug 11, 2020
1 parent 4d44113 commit 120a7f8
Show file tree
Hide file tree
Showing 146 changed files with 612 additions and 575 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ if(WIN32)
#set(CMAKE_EXE_LINKER_FLAGS_DEBUG "${CMAKE_EXE_LINKER_FLAGS} /DEBUG:FASTLINK")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /D_ITERATOR_DEBUG_LEVEL=0")
else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -Wall -Wno-sign-compare -Wno-unused-function -fPIC -Werror=return-type")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -Wall -Wno-sign-compare -Wno-unused-function -fPIC")
endif()

if (THIRD_PARTY)
Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/common/error.proto
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ enum JobBuildAndInferError {
kLogicalBlobNameInvalid = 402;

kOpNameExist = 450;
kOpConfDeviceTypeNoSet = 460;
kOpConfDeviceTagNoSet = 460;
kPlacementError = 470;
kBlobSplitAxisInferError = 480;
kUnknownJobBuildAndInferError = 500;
Expand Down
3 changes: 2 additions & 1 deletion oneflow/core/framework/to_string.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/util.h"
#include "oneflow/core/framework/to_string.h"

namespace oneflow {

Maybe<const char*> DeviceTag4DeviceType(DeviceType device_type) {
if (device_type == kCPU) { return "cpu"; }
if (device_type == kGPU) { return "gpu"; }
return Error::DeviceTagNotFound() << "invalid";
return Error::DeviceTagNotFound() << "invalid_device";
}

Maybe<DeviceType> DeviceType4DeviceTag(const std::string& device_tag) {
Expand Down
15 changes: 8 additions & 7 deletions oneflow/core/framework/user_op_hob.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,6 @@ hob::BoolFunctorPtr<KernelRegContext> HobFalse() {
return krbf_ptr;
}

hob::HobContextGetter<KernelRegContext, DeviceType> HobDeviceType() {
std::ostringstream string_stream;
string_stream << "device_type";
return hob::HobContextGetter<KernelRegContext, DeviceType>(
string_stream.str(), [](const KernelRegContext& ctx) { return ctx.device_type(); });
}

hob::HobContextGetter<KernelRegContext, DataType> HobDataType(const std::string& tensor_name,
int tensor_idx) {
std::ostringstream string_stream;
Expand All @@ -58,6 +51,14 @@ hob::HobContextGetter<KernelRegContext, DataType> HobDataType(const std::string&
});
}

HobStringContextGetter<KernelRegContext> HobDeviceTag() {
std::ostringstream string_stream;
string_stream << "device_tag";
return HobStringContextGetter<KernelRegContext>(
string_stream.str(),
[](const KernelRegContext& ctx) -> const std::string& { return ctx.device_tag(); });
}

} // namespace user_op

} // namespace oneflow
45 changes: 42 additions & 3 deletions oneflow/core/framework/user_op_hob.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ limitations under the License.
#ifndef ONEFLOW_CORE_FRAMEWORK_USER_OP_HOB_H_
#define ONEFLOW_CORE_FRAMEWORK_USER_OP_HOB_H_

#include "oneflow/core/common/high_order_bool.h"
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/common/high_order_bool.h"
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/framework/user_op_registry_manager.h"

namespace oneflow {
Expand All @@ -28,8 +29,6 @@ hob::BoolFunctorPtr<KernelRegContext> HobTrue();

hob::BoolFunctorPtr<KernelRegContext> HobFalse();

hob::HobContextGetter<KernelRegContext, DeviceType> HobDeviceType();

hob::HobContextGetter<KernelRegContext, DataType> HobDataType(const std::string& tensor_name,
int tensor_idx);

Expand All @@ -47,6 +46,46 @@ hob::HobContextGetter<user_op::KernelRegContext, T> HobAttr(const std::string& a
});
}

template<typename ContextT>
class HobStringContextGetter final {
public:
HobStringContextGetter(const DeviceType& device_type) {
std::string str = ToString(device_type);
debug_str_ = str;
context_getter_ = [str](const ContextT&) -> const std::string& { return str; };
}
HobStringContextGetter(const char* const_value) {
std::string str(const_value);
debug_str_ = str;
context_getter_ = [str](const ContextT&) -> const std::string& { return str; };
}
HobStringContextGetter(const std::string& const_value)
: debug_str_(const_value),
context_getter_(
[const_value](const ContextT&) -> const std::string& { return const_value; }) {}
HobStringContextGetter(const std::string& debug_str,
const std::function<const std::string&(const ContextT&)>& context_getter)
: debug_str_(debug_str), context_getter_(context_getter) {}

hob::BoolFunctorPtr<ContextT> operator==(const HobStringContextGetter& other) const {
std::ostringstream string_stream;
string_stream << debug_str_ << " == " << other.debug_str_;
std::function<std::string(const ContextT&)> l_fn = this->context_getter_;
std::function<std::string(const ContextT&)> r_fn = other.context_getter_;
std::shared_ptr<const hob::BoolFunctor<ContextT>> krbf_ptr =
std::make_shared<const hob::HighOrderBoolFunctor<ContextT>>(
string_stream.str(),
[l_fn, r_fn](const ContextT& ctx) { return l_fn(ctx) == r_fn(ctx); });
return krbf_ptr;
}

private:
std::string debug_str_;
std::function<const std::string&(const ContextT&)> context_getter_;
};

HobStringContextGetter<KernelRegContext> HobDeviceTag();

} // namespace user_op

} // namespace oneflow
Expand Down
1 change: 1 addition & 0 deletions oneflow/core/framework/user_op_kernel_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class KernelRegContext {
virtual ~KernelRegContext() = default;

virtual DeviceType device_type() const = 0;
virtual const std::string& device_tag() const = 0;
virtual const ParallelContext& parallel_ctx() const = 0;
virtual const TensorDesc* TensorDesc4ArgNameAndIndex(const std::string&, int32_t) const = 0;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/graph/boxing/chain_sub_task_graph_builder.h"
#include "oneflow/core/graph/boxing/collective_boxing_sub_task_graph_builder.h"
#include "oneflow/core/graph/boxing/sub_task_graph_builder_util.h"
#include "oneflow/core/graph/collective_boxing_task_node.h"
#include "oneflow/core/graph/boxing/chain_sub_task_graph_builder.h"
#include "oneflow/core/graph/slice_boxing_task_node.h"

namespace oneflow {
Expand All @@ -31,7 +32,7 @@ void NcclInitCollectiveNode(CollectiveBoxingGenericTaskNode* node,
const BlobDesc& logical_blob_desc, OpType op_type, int64_t root) {
OperatorConf op_conf;
op_conf.set_name(name);
op_conf.set_device_type(DeviceType::kGPU);
op_conf.set_device_tag(CHECK_JUST(DeviceTag4DeviceType(DeviceType::kGPU)));
CollectiveBoxingGenericOpConf* conf = op_conf.mutable_collective_boxing_generic_conf();
*conf->mutable_lbi() = lbi;
RankDesc* rank_desc = conf->mutable_rank_desc();
Expand Down
3 changes: 2 additions & 1 deletion oneflow/core/graph/boxing_identity_compute_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/graph/boxing_identity_compute_task_node.h"
#include "oneflow/core/graph/logical_node.h"

Expand Down Expand Up @@ -41,7 +42,7 @@ void BoxingIdentityCompTaskNode::BuildExecGphAndRegst() {
ExecNode* node = mut_exec_gph().NewNode();
OperatorConf op_conf;
op_conf.set_name("System-Boxing-Identity-" + NewUniqueId());
op_conf.set_device_type(this->device_type());
op_conf.set_device_tag(CHECK_JUST(DeviceTag4DeviceType(this->device_type())));
*op_conf.mutable_boxing_identity_conf()->mutable_lbi() = lbi_;
std::shared_ptr<Operator> sole_op = ConstructOp(op_conf, &GlobalJobDesc());
node->mut_op() = sole_op;
Expand Down
7 changes: 3 additions & 4 deletions oneflow/core/graph/chain_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ limitations under the License.
#include "oneflow/core/thread/thread_pool.h"
#include "oneflow/core/common/blocking_counter.h"
#include "oneflow/core/framework/config_def.h"
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/job/global_for.h"

namespace oneflow {
Expand Down Expand Up @@ -180,10 +181,8 @@ void CollectIgnoreTaskEdgesInFirstMergedChains(const std::vector<std::vector<Tas
if (fw_node == nullptr) { continue; }
if (fw_node->logical_node()->op_vec().size() != 1) { continue; }
const auto& src_op = *fw_node->logical_node()->SoleOp();
if (src_op.op_conf().has_variable_conf()
&& src_op.op_conf().device_type() == DeviceType::kGPU) {
return true;
}
DeviceType device_type = CHECK_JUST(DeviceType4DeviceTag(src_op.op_conf().device_tag()));
if (src_op.op_conf().has_variable_conf() && device_type == DeviceType::kGPU) { return true; }
}
return false;
};
Expand Down
7 changes: 4 additions & 3 deletions oneflow/core/graph/copy_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/graph/copy_task_node.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/job/thrd_id_generator.h"
#include "oneflow/core/operator/operator.h"

namespace oneflow {

Expand Down Expand Up @@ -79,7 +80,7 @@ void CopyHdTaskNode::InitProducedRegstMemCase(MemoryCase* mem_case) {
OperatorConf CopyHdTaskNode::NewCopyOpConf() {
OperatorConf conf;
conf.set_name("copy_hd_" + NewUniqueId());
conf.set_device_type(device_type());
conf.set_device_tag(CHECK_JUST(DeviceTag4DeviceType(device_type())));
conf.mutable_copy_hd_conf()->set_type(copy_type_);
auto in_regst = GetSoleConsumedRegst("copy_in");
if (in_regst->NumOfLbi() == 1) {
Expand Down Expand Up @@ -141,7 +142,7 @@ void CopyCommNetTaskNode::PinConsumedRegstMemCase(MemoryCase* mem_case) {
OperatorConf CopyCommNetTaskNode::NewCopyOpConf() {
OperatorConf conf;
conf.set_name("copy_comm_net_" + NewUniqueId());
conf.set_device_type(device_type());
conf.set_device_tag(CHECK_JUST(DeviceTag4DeviceType(this->device_type())));
conf.mutable_copy_comm_net_conf();
return conf;
}
Expand Down
7 changes: 4 additions & 3 deletions oneflow/core/graph/logical_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/balanced_splitter.h"
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/graph/logical_graph.h"
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/operator/op_conf_util.h"
#include "oneflow/core/common/balanced_splitter.h"
#include "oneflow/core/job/global_for.h"

namespace oneflow {

Expand Down Expand Up @@ -63,7 +64,7 @@ void LogicalGraph::NaiveBuildFwStruct(
auto parallel_desc_ptr_it = name2parallel_desc.find(cur_op_conf.name());
CHECK(parallel_desc_ptr_it != name2parallel_desc.end());
const std::shared_ptr<ParallelDesc>& parallel_desc_ptr = parallel_desc_ptr_it->second;
cur_op_conf.set_device_type(parallel_desc_ptr->device_type());
cur_op_conf.set_device_tag(CHECK_JUST(DeviceTag4DeviceType(parallel_desc_ptr->device_type())));
std::shared_ptr<Operator> cur_op = ConstructOp(cur_op_conf, &GlobalJobDesc());
LogicalNode* cur_node = cur_op->NewProperLogicalNode();
AddAllocatedNode(cur_node);
Expand Down
3 changes: 2 additions & 1 deletion oneflow/core/graph/slice_boxing_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/graph/slice_boxing_task_node.h"

namespace oneflow {
Expand Down Expand Up @@ -102,7 +103,7 @@ void SliceBoxingTaskNode::SetOutShape(const Shape& shape) { out_shape_ = shape;

OperatorConf SliceBoxingTaskNode::GetBoxingOpConf() {
OperatorConf op_conf{};
op_conf.set_device_type(device_type());
op_conf.set_device_tag(CHECK_JUST(DeviceTag4DeviceType(device_type())));
SliceBoxingConf boxing_conf{};
*boxing_conf.mutable_lbi() = lbi_;
out_slice_.ToProto(boxing_conf.mutable_out_slice());
Expand Down
27 changes: 15 additions & 12 deletions oneflow/core/job/job_build_and_infer_ctx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/job/job_build_and_infer_ctx.h"
#include "oneflow/core/job_rewriter/op_graph_pass.h"
#include "oneflow/core/job_rewriter/autograd.h"
#include "oneflow/core/framework/config_def.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/job/mirrored_sig_infer_hint.h"
#include "oneflow/core/job/foreign_callback.h"
#include "oneflow/core/eager/eager_symbol_storage.h"
#include "oneflow/core/framework/config_def.h"
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/job/foreign_callback.h"
#include "oneflow/core/job/job_build_and_infer_ctx.h"
#include "oneflow/core/job/mirrored_sig_infer_hint.h"
#include "oneflow/core/job/scope.h"
#include <google/protobuf/text_format.h>
#include "oneflow/core/job_rewriter/autograd.h"
#include "oneflow/core/job_rewriter/op_graph_pass.h"
#include "oneflow/user/summary/summary_converter.h"

#include <google/protobuf/text_format.h>
#include <json.hpp>

namespace oneflow {
Expand Down Expand Up @@ -505,9 +507,9 @@ Maybe<OpAttribute> JobBuildAndInferCtx::AddAndInferOp(const OperatorConf& op_con
CHECK_OR_RETURN(op_name2op_.find(op_name) == op_name2op_.end())
<< JobBuildAndInferError::kOpNameExist << "op_name: " << op_name
<< " already exist in job: " << job_->job_conf().job_name();
CHECK_NE_OR_RETURN(op_conf.device_type(), DeviceType::kInvalidDevice)
<< JobBuildAndInferError::kOpConfDeviceTypeNoSet << "op_name: " << op_name
<< " not set device type";
CHECK_NE_OR_RETURN(op_conf.device_tag(), "invalid_device")
<< JobBuildAndInferError::kOpConfDeviceTagNoSet << "op_name: " << op_name
<< " not set device tag";

op_name2op_.emplace(op_name, ConstructOp(op_conf, job_desc));
Operator* op = op_name2op_.at(op_name).get();
Expand Down Expand Up @@ -836,7 +838,7 @@ Maybe<LogicalBlobId> LazyJobBuildAndInferCtx::FindOrCreateMirroredLbiFromCompati
lbi_vec->push_back(sub_lbi);
};
OperatorConf op_conf;
op_conf.set_device_type(parallel_desc.device_type());
op_conf.set_device_tag(CHECK_JUST(DeviceTag4DeviceType(parallel_desc.device_type())));
if (sbp.has_broadcast_parallel()) {
op_conf.set_name(kAutoMirroredBlobNamePrefix + "-DistributeClone-" + NewUniqueId());
auto* distribute_clone = op_conf.mutable_distribute_clone_conf();
Expand Down Expand Up @@ -890,7 +892,8 @@ Maybe<LogicalBlobId> EagerJobBuildAndInferCtx::FindOrCreateMirroredLbiFromCompat
CHECK_OR_RETURN(producer_op_conf.has_scope_symbol_id());
op_conf.set_scope_symbol_id(producer_op_conf.scope_symbol_id());
}
op_conf.set_device_type(parallel_desc.device_type());
// const char* device_tag = JUST(DeviceTag4DeviceType(parallel_desc.device_type()));
op_conf.set_device_tag(JUST(DeviceTag4DeviceType(parallel_desc.device_type())));
op_conf.set_name(kAutoMirroredBlobNamePrefix + "-CastToMirrored-" + NewUniqueId());
auto* cast_to_mirrored_conf = op_conf.mutable_cast_to_mirrored_conf();
cast_to_mirrored_conf->set_in(lbn);
Expand Down
3 changes: 1 addition & 2 deletions oneflow/core/job/parallel_desc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,9 @@ Maybe<void> ParallelDesc::CheckWithResourceDesc(const ResourceDesc& resource_des

ParallelConf ParallelDesc::GetParallelIdOnlyParallelConf(int64_t parallel_id) const {
ParallelConf parallel_conf;
const char* device_tag = CHECK_JUST(DeviceTag4DeviceType(device_type()));
std::string machine_id = std::to_string(MachineIdForParallelId(parallel_id));
std::string device_id = std::to_string(DeviceIdForParallelId(parallel_id));
parallel_conf.set_device_tag(device_tag);
parallel_conf.set_device_tag(CHECK_JUST(DeviceTag4DeviceType(device_type())));
parallel_conf.add_device_name(machine_id + ":" + device_id);
return parallel_conf;
}
Expand Down
7 changes: 4 additions & 3 deletions oneflow/core/job/scope.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/job/scope.h"
#include "oneflow/core/vm/symbol_storage.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/vm/symbol_storage.h"

namespace oneflow {

Expand All @@ -42,15 +43,15 @@ Maybe<const JobDesc*> Scope::job_desc() const {
}

Maybe<int64_t> Scope::GetParallelDescSymbolId(const OperatorConf& op_conf) const {
if (op_conf.device_type() == DeviceType::kCPU || IsCpuOnly(op_conf)) {
if (op_conf.device_tag() == "cpu" || IsCpuOnly(op_conf)) {
return scope_proto_.host_parallel_desc_symbol_id();
} else {
return scope_proto_.device_parallel_desc_symbol_id();
}
}

Maybe<const ParallelDesc*> Scope::GetParallelDesc(const OperatorConf& op_conf) const {
if (op_conf.device_type() == DeviceType::kCPU || IsCpuOnly(op_conf)) {
if (op_conf.device_tag() == "cpu" || IsCpuOnly(op_conf)) {
return host_parallel_desc_.get();
} else {
return device_parallel_desc_.get();
Expand Down

0 comments on commit 120a7f8

Please sign in to comment.