diff --git a/cmake/external/protobuf.cmake b/cmake/external/protobuf.cmake index d43badc1da507..eb29431046ca7 100644 --- a/cmake/external/protobuf.cmake +++ b/cmake/external/protobuf.cmake @@ -39,12 +39,15 @@ macro(PROMPT_PROTOBUF_LIB) ADD_LIBRARY(protobuf_lite ${protobuf_LIBTYPE} IMPORTED GLOBAL) SET_PROPERTY(TARGET protobuf_lite PROPERTY IMPORTED_LOCATION ${PROTOBUF_LITE_LIBRARY}) - ADD_LIBRARY(protoc ${protobuf_LIBTYPE} IMPORTED GLOBAL) - SET_PROPERTY(TARGET protoc PROPERTY IMPORTED_LOCATION ${PROTOC_LIBRARY}) + ADD_LIBRARY(libprotoc ${protobuf_LIBTYPE} IMPORTED GLOBAL) + SET_PROPERTY(TARGET libprotoc PROPERTY IMPORTED_LOCATION ${PROTOC_LIBRARY}) + ADD_EXECUTABLE(protoc IMPORTED GLOBAL) + SET_PROPERTY(TARGET protoc PROPERTY IMPORTED_LOCATION ${PROTOBUF_PROTOC_EXECUTABLE}) FOREACH(dep ${protobuf_DEPS}) ADD_DEPENDENCIES(protobuf ${dep}) ADD_DEPENDENCIES(protobuf_lite ${dep}) + ADD_DEPENDENCIES(libprotoc ${dep}) ADD_DEPENDENCIES(protoc ${dep}) ENDFOREACH() diff --git a/cmake/flags.cmake b/cmake/flags.cmake index 7a996dea92b13..4c22db1d82285 100644 --- a/cmake/flags.cmake +++ b/cmake/flags.cmake @@ -109,7 +109,9 @@ set(COMMON_FLAGS -Wno-unused-function -Wno-error=literal-suffix -Wno-error=sign-compare - -Wno-error=unused-local-typedefs) + -Wno-error=unused-local-typedefs + -Wno-error=ignored-qualifiers # Warning in protobuf 3 Map.h + -Wno-error=no-enum-compare) # Warning in protobuf 3 Map.h set(GPU_COMMON_FLAGS -fPIC diff --git a/cmake/generic.cmake b/cmake/generic.cmake index 6839abc1a7bce..58b2f0a9c8926 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -171,7 +171,7 @@ function(cc_library TARGET_NAME) if (cc_library_DEPS) merge_static_libs(${TARGET_NAME} ${cc_library_DEPS}) else() - message(FATAL "Please specify source file or library in cc_library.") + message(FATAL_ERROR "Please specify source file or library in cc_library.") endif() endif(cc_library_SRCS) endfunction(cc_library) @@ -331,3 +331,42 @@ function(go_test TARGET_NAME) add_custom_target(${TARGET_NAME} ALL DEPENDS ${TARGET_NAME}_timestamp ${go_test_DEPS}) add_test(${TARGET_NAME} ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}) endfunction(go_test) + +# go_extern will download extern go project. +# go_extern(target_name extern_source) +# go_extern(go_redis github.com/hoisie/redis) +function(go_extern TARGET_NAME) + add_custom_target(${TARGET_NAME} env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} get ${ARGN}) +endfunction(go_extern) + + +function(generate_protobuf_cpp SRCS HDRS) + set(PROTO_FILES ${ARGN}) + set(${SRCS}) + set(${HDRS}) + foreach(FIL ${PROTO_FILES}) + get_filename_component(ABS_FIL ${FIL} ABSOLUTE) + get_filename_component(FIL_WE ${FIL} NAME_WE) + if(NOT PROTOBUF_GENERATE_CPP_APPEND_PATH) + get_filename_component(FIL_DIR ${FIL} DIRECTORY) + if(FIL_DIR) + set(FIL_WE "${FIL_DIR}/${FIL_WE}") + endif() + endif() + + list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}.pb.cc") + list(APPEND ${HDRS} "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}.pb.h") + + add_custom_command( + OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}.pb.cc" + "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}.pb.h" + COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} + ARGS "--cpp_out=${DLL_EXPORT_DECL}${CMAKE_CURRENT_BINARY_DIR}" "-I" ${CMAKE_CURRENT_SOURCE_DIR} ${ABS_FIL} + DEPENDS ${ABS_FIL} protoc + COMMENT "Running C++ protocol buffer compiler on ${FIL}" + VERBATIM ) + endforeach() + set_source_files_properties(${${SRCS}} ${${HDRS}} PROPERTIES GENERATED TRUE) + set(${SRCS} ${${SRCS}} PARENT_SCOPE) + set(${HDRS} ${${HDRS}} PARENT_SCOPE) +endfunction() diff --git a/paddle/api/CMakeLists.txt b/paddle/api/CMakeLists.txt index f2315e31cc06d..b57b74fde6319 100644 --- a/paddle/api/CMakeLists.txt +++ b/paddle/api/CMakeLists.txt @@ -10,7 +10,7 @@ set(API_SOURCES SequenceGenerator.cpp Trainer.cpp Util.cpp - Vector.cpp) + Vector.cpp ../framework/attr_test.cc) set(API_HEADER PaddleAPI.h Internal.h) diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index e3c3155aa902c..301c47ace44fe 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -1,6 +1,13 @@ cc_library(ddim SRCS ddim.cc) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) - nv_test(dim_test SRCS dim_test.cu DEPS ddim) cc_test(variable_test SRCS variable_test.cc) +# include generated protobuf headers +INCLUDE_DIRECTORIES(${CMAKE_CURRENT_BINARY_DIR}) +generate_protobuf_cpp(attr_proto_src attr_proto_hdr attr.proto) +cc_library(attr_proto SRCS ${attr_proto_src}) +generate_protobuf_cpp(attr_test_proto_src attr_test_proto_header attr_test.proto) +message(STATUS ${attr_test_proto_src}) +cc_test(attr_test SRCS ${attr_test_proto_src} attr_test.cc + DEPS attr_proto protobuf glog gflags) diff --git a/paddle/framework/attr.proto b/paddle/framework/attr.proto new file mode 100644 index 0000000000000..9df00f064fb7a --- /dev/null +++ b/paddle/framework/attr.proto @@ -0,0 +1,17 @@ +syntax="proto3"; +package paddle.framework; + +message Attribute { + message ListValue { + repeated int32 ints = 1; + repeated float floats = 2; + repeated string strings = 3; + } + + oneof value { + ListValue list = 1; + int32 i = 2; + float f = 3; + string s = 4; + } +} \ No newline at end of file diff --git a/paddle/framework/attr_helper.h b/paddle/framework/attr_helper.h new file mode 100644 index 0000000000000..ba346abe1475d --- /dev/null +++ b/paddle/framework/attr_helper.h @@ -0,0 +1,183 @@ +#pragma once +#include +#include +#include +#include +#include +#include "attr.pb.h" +namespace paddle { +namespace framework { +using AttributeMap = google::protobuf::Map; + +class AttributeReader final { + public: + explicit AttributeReader(const AttributeMap& attrs) : attrs_(attrs) {} + + template + Error __must_check Get(const std::string& attributeName, T* attr) const; + + template + Error __must_check GetArray(const std::string& attributeName, + std::vector* array) const; + + private: + const AttributeMap& attrs_; +}; + +namespace details { +template +struct SetArrayImpl { + Error __must_check operator()(AttributeMap* attrs, + const std::string& attributeName, + Iterator begin, Iterator end, bool overwrite); +}; +} // namespace details + +class AttributeWriter { + public: + explicit AttributeWriter(AttributeMap* attrs) : attrs_(attrs) {} + + template + Error __must_check Set(const std::string& attributeName, const T& attr, + bool overwrite = false); + + template + Error __must_check SetArray(const std::string& attributeName, Iterator begin, + Iterator end, bool overwrite = false) { + return details::SetArrayImpl< + Iterator, typename std::iterator_traits::value_type>()( + attrs_, attributeName, begin, end, overwrite); + } + + template > + Error __must_check SetArray(const std::string& attributeName, + Container container, bool overwrite = false) { + return SetArray(attributeName, container.begin(), container.end(), + overwrite); + } + + private: + AttributeMap* attrs_; +}; + +#define ATTR_READER_IMPL_PLAIN_TYPE(T, CASE, FIELD_NAME) \ + template <> \ + Error __must_check AttributeReader::Get(const std::string& attributeName, \ + T* attr) const { \ + auto it = attrs_.find(attributeName); \ + if (it == attrs_.end()) { \ + return Error("Attribute %s not found", attributeName.c_str()); \ + } \ + if (it->second.value_case() != CASE) { \ + return Error("Attribute should be in field " #FIELD_NAME); \ + } \ + *attr = it->second.FIELD_NAME(); \ + return Error(); \ + } + +ATTR_READER_IMPL_PLAIN_TYPE(int, Attribute::kI, i); +ATTR_READER_IMPL_PLAIN_TYPE(float, Attribute::kF, f); +ATTR_READER_IMPL_PLAIN_TYPE(std::string, Attribute::kS, s); + +#undef ATTR_READER_IMPL_PLAIN_TYPE + +#define ATTR_READER_IMPL_ARRAY_TYPE(T, FIELD_NAME) \ + template <> \ + Error __must_check AttributeReader::GetArray( \ + const std::string& attributeName, std::vector* array) const { \ + if (!array->empty()) { \ + return Error("The output array must be empty."); \ + } \ + \ + auto it = attrs_.find(attributeName); \ + if (it == attrs_.end()) { \ + return Error("Attribute %s not found", attributeName.c_str()); \ + } \ + \ + auto& lst = it->second.list(); \ + auto& field = lst.FIELD_NAME(); \ + array->reserve(field.size()); \ + std::copy(field.begin(), field.end(), std::back_inserter(*array)); \ + return Error(); \ + } + +ATTR_READER_IMPL_ARRAY_TYPE(float, floats); +ATTR_READER_IMPL_ARRAY_TYPE(int, ints); +ATTR_READER_IMPL_ARRAY_TYPE(std::string, strings); + +#undef ATTR_READER_IMPL_ARRAY_TYPE + +#define ATTR_WRITER_IMPL_PLAIN_TYPE(T, FIELD_NAME) \ + template <> \ + Error __must_check AttributeWriter::Set(const std::string& attributeName, \ + const T& attr, bool overwrite) { \ + auto it = attrs_->find(attributeName); \ + if (it != attrs_->end() && !overwrite) { \ + return Error("Attribute %s has been set", attributeName.c_str()); \ + } \ + (*attrs_)[attributeName].set_##FIELD_NAME(attr); \ + return Error(); \ + } + +ATTR_WRITER_IMPL_PLAIN_TYPE(int, i); +ATTR_WRITER_IMPL_PLAIN_TYPE(float, f); +ATTR_WRITER_IMPL_PLAIN_TYPE(std::string, s); + +#undef ATTR_WRITER_IMPL_PLAIN_TYPE + +namespace details { +template +void AppendToField(google::protobuf::RepeatedField* field, const T& val) { + field->Add(val); +} +template +void AppendToField(google::protobuf::RepeatedPtrField* field, const T& val) { + *(field->Add()) = val; +} + +} // namespace details + +#define ATTR_WRITER_IMPL_ARRAY_TYPE(T, FIELD_NAME) \ + namespace details { \ + \ + template \ + struct SetArrayImpl { \ + using VALUE_TYPE = typename std::iterator_traits::value_type; \ + Error __must_check operator()(AttributeMap* attrs, \ + const std::string& attributeName, \ + Iterator begin, Iterator end, \ + bool overwrite) { \ + static_assert(std::is_same::value, ""); \ + auto it = attrs->find(attributeName); \ + if (it != attrs->end() && !overwrite) { \ + return Error("Attribute %s has been set", attributeName.c_str()); \ + } \ + \ + if (it != attrs->end() && overwrite) { \ + auto repeatedFieldPtr = \ + it->second.mutable_list()->mutable_##FIELD_NAME(); \ + repeatedFieldPtr->erase(repeatedFieldPtr->begin(), \ + repeatedFieldPtr->end()); \ + } \ + auto lst = (*attrs)[attributeName].mutable_list(); \ + auto elems = lst->mutable_##FIELD_NAME(); \ + auto distance = std::distance(begin, end); \ + if (std::is_integral::value) { \ + elems->Reserve(distance); \ + } \ + for (; begin != end; ++begin) { \ + AppendToField(elems, *begin); \ + } \ + return Error(); \ + } \ + }; \ + } + +ATTR_WRITER_IMPL_ARRAY_TYPE(float, floats); +ATTR_WRITER_IMPL_ARRAY_TYPE(int, ints); +ATTR_WRITER_IMPL_ARRAY_TYPE(std::string, strings); + +#undef ATTR_WRITER_IMPL_ARRAY_TYPE + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/attr_test.cc b/paddle/framework/attr_test.cc new file mode 100644 index 0000000000000..2f15a048b40f7 --- /dev/null +++ b/paddle/framework/attr_test.cc @@ -0,0 +1,82 @@ +#include +#include +#include +#include "attr_test.pb.h" + +TEST(AttrHelper, plainTypes) { + paddle::framework::AttributeTestMessage msg; + std::random_device dev; + unsigned int seed = dev(); + std::mt19937 rng(seed); + std::uniform_int_distribution distInt(-1000, 1000); + std::uniform_real_distribution distFloat(-1000.0f, 1000.0f); + paddle::framework::AttributeWriter writer(msg.mutable_attrs()); + paddle::framework::AttributeReader reader(msg.attrs()); + for (size_t i = 0; i < 1000; ++i) { + std::string key = "str_" + std::to_string(i); + switch (i % 3) { + case 0: + ASSERT_TRUE(writer.Set(key, distInt(rng)).isOK()); + break; + case 1: + ASSERT_TRUE(writer.Set(key, distFloat(rng)).isOK()); + break; + case 2: + ASSERT_TRUE(writer + .Set( + key, "random_str_" + std::to_string(distInt(rng))) + .isOK()); + break; + default: + ASSERT_TRUE(false); + } + } + + std::mt19937 rng2(seed); + + for (size_t i = 0; i < 1000; ++i) { + std::string key = "str_" + std::to_string(i); + int intVal; + float floatVal; + std::string strVal; + + switch (i % 3) { + case 0: + ASSERT_TRUE(reader.Get(key, &intVal).isOK()); + ASSERT_EQ(distInt(rng2), intVal); + break; + case 1: + ASSERT_TRUE(reader.Get(key, &floatVal).isOK()); + ASSERT_EQ(distFloat(rng2), floatVal); + break; + case 2: + ASSERT_TRUE(reader.Get(key, &strVal).isOK()); + ASSERT_EQ("random_str_" + std::to_string(distInt(rng2)), strVal); + break; + default: + ASSERT_TRUE(false); + } + } +} + +template +inline void TestArrayImpl(const Container& container) { + paddle::framework::AttributeTestMessage msg; + paddle::framework::AttributeWriter writer(msg.mutable_attrs()); + paddle::framework::AttributeReader reader(msg.attrs()); + + auto err = + writer.SetArray("test_array", container); + ASSERT_TRUE(err.isOK()); + + Container tmp; + err = reader.GetArray("test_array", &tmp); + ASSERT_TRUE(err.isOK()); + ASSERT_EQ(container, tmp); +} + +TEST(AttrHelper, array) { + TestArrayImpl(std::vector{0.7, 0.8, 0.9}); + TestArrayImpl(std::vector{-1, -2, 1, 2, 0}); + TestArrayImpl(std::vector{"a", "", "c", "e", "01304fksd"}); +} \ No newline at end of file diff --git a/paddle/framework/attr_test.proto b/paddle/framework/attr_test.proto new file mode 100644 index 0000000000000..28e9615b05e44 --- /dev/null +++ b/paddle/framework/attr_test.proto @@ -0,0 +1,8 @@ +syntax="proto3"; +package paddle.framework; + +import "attr.proto"; + +message AttributeTestMessage { + map attrs = 1; +} \ No newline at end of file