Skip to content

Commit

Permalink
Implement attribute
Browse files Browse the repository at this point in the history
  • Loading branch information
reyoung committed Jun 27, 2017
1 parent dcece75 commit e815fe2
Show file tree
Hide file tree
Showing 9 changed files with 347 additions and 6 deletions.
7 changes: 5 additions & 2 deletions cmake/external/protobuf.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,15 @@ macro(PROMPT_PROTOBUF_LIB)
ADD_LIBRARY(protobuf_lite ${protobuf_LIBTYPE} IMPORTED GLOBAL)
SET_PROPERTY(TARGET protobuf_lite PROPERTY IMPORTED_LOCATION ${PROTOBUF_LITE_LIBRARY})

ADD_LIBRARY(protoc ${protobuf_LIBTYPE} IMPORTED GLOBAL)
SET_PROPERTY(TARGET protoc PROPERTY IMPORTED_LOCATION ${PROTOC_LIBRARY})
ADD_LIBRARY(libprotoc ${protobuf_LIBTYPE} IMPORTED GLOBAL)
SET_PROPERTY(TARGET libprotoc PROPERTY IMPORTED_LOCATION ${PROTOC_LIBRARY})
ADD_EXECUTABLE(protoc IMPORTED GLOBAL)
SET_PROPERTY(TARGET protoc PROPERTY IMPORTED_LOCATION ${PROTOBUF_PROTOC_EXECUTABLE})

FOREACH(dep ${protobuf_DEPS})
ADD_DEPENDENCIES(protobuf ${dep})
ADD_DEPENDENCIES(protobuf_lite ${dep})
ADD_DEPENDENCIES(libprotoc ${dep})
ADD_DEPENDENCIES(protoc ${dep})
ENDFOREACH()

Expand Down
4 changes: 3 additions & 1 deletion cmake/flags.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ set(COMMON_FLAGS
-Wno-unused-function
-Wno-error=literal-suffix
-Wno-error=sign-compare
-Wno-error=unused-local-typedefs)
-Wno-error=unused-local-typedefs
-Wno-error=ignored-qualifiers # Warning in protobuf 3 Map.h
-Wno-error=no-enum-compare) # Warning in protobuf 3 Map.h

set(GPU_COMMON_FLAGS
-fPIC
Expand Down
41 changes: 40 additions & 1 deletion cmake/generic.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ function(cc_library TARGET_NAME)
if (cc_library_DEPS)
merge_static_libs(${TARGET_NAME} ${cc_library_DEPS})
else()
message(FATAL "Please specify source file or library in cc_library.")
message(FATAL_ERROR "Please specify source file or library in cc_library.")
endif()
endif(cc_library_SRCS)
endfunction(cc_library)
Expand Down Expand Up @@ -331,3 +331,42 @@ function(go_test TARGET_NAME)
add_custom_target(${TARGET_NAME} ALL DEPENDS ${TARGET_NAME}_timestamp ${go_test_DEPS})
add_test(${TARGET_NAME} ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME})
endfunction(go_test)

# go_extern will download extern go project.
# go_extern(target_name extern_source)
# go_extern(go_redis github.com/hoisie/redis)
function(go_extern TARGET_NAME)
add_custom_target(${TARGET_NAME} env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} get ${ARGN})
endfunction(go_extern)


function(generate_protobuf_cpp SRCS HDRS)
set(PROTO_FILES ${ARGN})
set(${SRCS})
set(${HDRS})
foreach(FIL ${PROTO_FILES})
get_filename_component(ABS_FIL ${FIL} ABSOLUTE)
get_filename_component(FIL_WE ${FIL} NAME_WE)
if(NOT PROTOBUF_GENERATE_CPP_APPEND_PATH)
get_filename_component(FIL_DIR ${FIL} DIRECTORY)
if(FIL_DIR)
set(FIL_WE "${FIL_DIR}/${FIL_WE}")
endif()
endif()

list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}.pb.cc")
list(APPEND ${HDRS} "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}.pb.h")

add_custom_command(
OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}.pb.cc"
"${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}.pb.h"
COMMAND ${PROTOBUF_PROTOC_EXECUTABLE}
ARGS "--cpp_out=${DLL_EXPORT_DECL}${CMAKE_CURRENT_BINARY_DIR}" "-I" ${CMAKE_CURRENT_SOURCE_DIR} ${ABS_FIL}
DEPENDS ${ABS_FIL} protoc
COMMENT "Running C++ protocol buffer compiler on ${FIL}"
VERBATIM )
endforeach()
set_source_files_properties(${${SRCS}} ${${HDRS}} PROPERTIES GENERATED TRUE)
set(${SRCS} ${${SRCS}} PARENT_SCOPE)
set(${HDRS} ${${HDRS}} PARENT_SCOPE)
endfunction()
2 changes: 1 addition & 1 deletion paddle/api/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ set(API_SOURCES
SequenceGenerator.cpp
Trainer.cpp
Util.cpp
Vector.cpp)
Vector.cpp ../framework/attr_test.cc)
set(API_HEADER
PaddleAPI.h
Internal.h)
Expand Down
9 changes: 8 additions & 1 deletion paddle/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
cc_library(ddim SRCS ddim.cc)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)

nv_test(dim_test SRCS dim_test.cu DEPS ddim)

cc_test(variable_test SRCS variable_test.cc)
# include generated protobuf headers
INCLUDE_DIRECTORIES(${CMAKE_CURRENT_BINARY_DIR})
generate_protobuf_cpp(attr_proto_src attr_proto_hdr attr.proto)
cc_library(attr_proto SRCS ${attr_proto_src})
generate_protobuf_cpp(attr_test_proto_src attr_test_proto_header attr_test.proto)
message(STATUS ${attr_test_proto_src})
cc_test(attr_test SRCS ${attr_test_proto_src} attr_test.cc
DEPS attr_proto protobuf glog gflags)
17 changes: 17 additions & 0 deletions paddle/framework/attr.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
syntax="proto3";
package paddle.framework;

message Attribute {
message ListValue {
repeated int32 ints = 1;
repeated float floats = 2;
repeated string strings = 3;
}

oneof value {
ListValue list = 1;
int32 i = 2;
float f = 3;
string s = 4;
}
}
183 changes: 183 additions & 0 deletions paddle/framework/attr_helper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
#pragma once
#include <google/protobuf/map.h>
#include <paddle/framework/error.h>
#include <iterator>
#include <string>
#include <type_traits>
#include "attr.pb.h"
namespace paddle {
namespace framework {
using AttributeMap = google::protobuf::Map<std::string, Attribute>;

class AttributeReader final {
public:
explicit AttributeReader(const AttributeMap& attrs) : attrs_(attrs) {}

template <typename T>
Error __must_check Get(const std::string& attributeName, T* attr) const;

template <typename T>
Error __must_check GetArray(const std::string& attributeName,
std::vector<T>* array) const;

private:
const AttributeMap& attrs_;
};

namespace details {
template <typename Iterator, typename T>
struct SetArrayImpl {
Error __must_check operator()(AttributeMap* attrs,
const std::string& attributeName,
Iterator begin, Iterator end, bool overwrite);
};
} // namespace details

class AttributeWriter {
public:
explicit AttributeWriter(AttributeMap* attrs) : attrs_(attrs) {}

template <typename T>
Error __must_check Set(const std::string& attributeName, const T& attr,
bool overwrite = false);

template <typename Iterator>
Error __must_check SetArray(const std::string& attributeName, Iterator begin,
Iterator end, bool overwrite = false) {
return details::SetArrayImpl<
Iterator, typename std::iterator_traits<Iterator>::value_type>()(
attrs_, attributeName, begin, end, overwrite);
}

template <typename T, typename Container = std::initializer_list<T>>
Error __must_check SetArray(const std::string& attributeName,
Container container, bool overwrite = false) {
return SetArray(attributeName, container.begin(), container.end(),
overwrite);
}

private:
AttributeMap* attrs_;
};

#define ATTR_READER_IMPL_PLAIN_TYPE(T, CASE, FIELD_NAME) \
template <> \
Error __must_check AttributeReader::Get<T>(const std::string& attributeName, \
T* attr) const { \
auto it = attrs_.find(attributeName); \
if (it == attrs_.end()) { \
return Error("Attribute %s not found", attributeName.c_str()); \
} \
if (it->second.value_case() != CASE) { \
return Error("Attribute should be in field " #FIELD_NAME); \
} \
*attr = it->second.FIELD_NAME(); \
return Error(); \
}

ATTR_READER_IMPL_PLAIN_TYPE(int, Attribute::kI, i);
ATTR_READER_IMPL_PLAIN_TYPE(float, Attribute::kF, f);
ATTR_READER_IMPL_PLAIN_TYPE(std::string, Attribute::kS, s);

#undef ATTR_READER_IMPL_PLAIN_TYPE

#define ATTR_READER_IMPL_ARRAY_TYPE(T, FIELD_NAME) \
template <> \
Error __must_check AttributeReader::GetArray<T>( \
const std::string& attributeName, std::vector<T>* array) const { \
if (!array->empty()) { \
return Error("The output array must be empty."); \
} \
\
auto it = attrs_.find(attributeName); \
if (it == attrs_.end()) { \
return Error("Attribute %s not found", attributeName.c_str()); \
} \
\
auto& lst = it->second.list(); \
auto& field = lst.FIELD_NAME(); \
array->reserve(field.size()); \
std::copy(field.begin(), field.end(), std::back_inserter(*array)); \
return Error(); \
}

ATTR_READER_IMPL_ARRAY_TYPE(float, floats);
ATTR_READER_IMPL_ARRAY_TYPE(int, ints);
ATTR_READER_IMPL_ARRAY_TYPE(std::string, strings);

#undef ATTR_READER_IMPL_ARRAY_TYPE

#define ATTR_WRITER_IMPL_PLAIN_TYPE(T, FIELD_NAME) \
template <> \
Error __must_check AttributeWriter::Set<T>(const std::string& attributeName, \
const T& attr, bool overwrite) { \
auto it = attrs_->find(attributeName); \
if (it != attrs_->end() && !overwrite) { \
return Error("Attribute %s has been set", attributeName.c_str()); \
} \
(*attrs_)[attributeName].set_##FIELD_NAME(attr); \
return Error(); \
}

ATTR_WRITER_IMPL_PLAIN_TYPE(int, i);
ATTR_WRITER_IMPL_PLAIN_TYPE(float, f);
ATTR_WRITER_IMPL_PLAIN_TYPE(std::string, s);

#undef ATTR_WRITER_IMPL_PLAIN_TYPE

namespace details {
template <typename T>
void AppendToField(google::protobuf::RepeatedField<T>* field, const T& val) {
field->Add(val);
}
template <typename T>
void AppendToField(google::protobuf::RepeatedPtrField<T>* field, const T& val) {
*(field->Add()) = val;
}

} // namespace details

#define ATTR_WRITER_IMPL_ARRAY_TYPE(T, FIELD_NAME) \
namespace details { \
\
template <typename Iterator> \
struct SetArrayImpl<Iterator, T> { \
using VALUE_TYPE = typename std::iterator_traits<Iterator>::value_type; \
Error __must_check operator()(AttributeMap* attrs, \
const std::string& attributeName, \
Iterator begin, Iterator end, \
bool overwrite) { \
static_assert(std::is_same<VALUE_TYPE, T>::value, ""); \
auto it = attrs->find(attributeName); \
if (it != attrs->end() && !overwrite) { \
return Error("Attribute %s has been set", attributeName.c_str()); \
} \
\
if (it != attrs->end() && overwrite) { \
auto repeatedFieldPtr = \
it->second.mutable_list()->mutable_##FIELD_NAME(); \
repeatedFieldPtr->erase(repeatedFieldPtr->begin(), \
repeatedFieldPtr->end()); \
} \
auto lst = (*attrs)[attributeName].mutable_list(); \
auto elems = lst->mutable_##FIELD_NAME(); \
auto distance = std::distance(begin, end); \
if (std::is_integral<decltype(distance)>::value) { \
elems->Reserve(distance); \
} \
for (; begin != end; ++begin) { \
AppendToField(elems, *begin); \
} \
return Error(); \
} \
}; \
}

ATTR_WRITER_IMPL_ARRAY_TYPE(float, floats);
ATTR_WRITER_IMPL_ARRAY_TYPE(int, ints);
ATTR_WRITER_IMPL_ARRAY_TYPE(std::string, strings);

#undef ATTR_WRITER_IMPL_ARRAY_TYPE

} // namespace framework
} // namespace paddle
82 changes: 82 additions & 0 deletions paddle/framework/attr_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#include <gtest/gtest.h>
#include <paddle/framework/attr_helper.h>
#include <random>
#include "attr_test.pb.h"

TEST(AttrHelper, plainTypes) {
paddle::framework::AttributeTestMessage msg;
std::random_device dev;
unsigned int seed = dev();
std::mt19937 rng(seed);
std::uniform_int_distribution<int> distInt(-1000, 1000);
std::uniform_real_distribution<float> distFloat(-1000.0f, 1000.0f);
paddle::framework::AttributeWriter writer(msg.mutable_attrs());
paddle::framework::AttributeReader reader(msg.attrs());
for (size_t i = 0; i < 1000; ++i) {
std::string key = "str_" + std::to_string(i);
switch (i % 3) {
case 0:
ASSERT_TRUE(writer.Set<int>(key, distInt(rng)).isOK());
break;
case 1:
ASSERT_TRUE(writer.Set<float>(key, distFloat(rng)).isOK());
break;
case 2:
ASSERT_TRUE(writer
.Set<std::string>(
key, "random_str_" + std::to_string(distInt(rng)))
.isOK());
break;
default:
ASSERT_TRUE(false);
}
}

std::mt19937 rng2(seed);

for (size_t i = 0; i < 1000; ++i) {
std::string key = "str_" + std::to_string(i);
int intVal;
float floatVal;
std::string strVal;

switch (i % 3) {
case 0:
ASSERT_TRUE(reader.Get<int>(key, &intVal).isOK());
ASSERT_EQ(distInt(rng2), intVal);
break;
case 1:
ASSERT_TRUE(reader.Get<float>(key, &floatVal).isOK());
ASSERT_EQ(distFloat(rng2), floatVal);
break;
case 2:
ASSERT_TRUE(reader.Get<std::string>(key, &strVal).isOK());
ASSERT_EQ("random_str_" + std::to_string(distInt(rng2)), strVal);
break;
default:
ASSERT_TRUE(false);
}
}
}

template <typename Container>
inline void TestArrayImpl(const Container& container) {
paddle::framework::AttributeTestMessage msg;
paddle::framework::AttributeWriter writer(msg.mutable_attrs());
paddle::framework::AttributeReader reader(msg.attrs());

auto err =
writer.SetArray<typename Container::value_type>("test_array", container);
ASSERT_TRUE(err.isOK());

Container tmp;
err = reader.GetArray("test_array", &tmp);
ASSERT_TRUE(err.isOK());
ASSERT_EQ(container, tmp);
}

TEST(AttrHelper, array) {
TestArrayImpl(std::vector<float>{0.7, 0.8, 0.9});
TestArrayImpl(std::vector<int>{-1, -2, 1, 2, 0});
TestArrayImpl(std::vector<std::string>{"a", "", "c", "e", "01304fksd"});
}
Loading

0 comments on commit e815fe2

Please sign in to comment.