-
Notifications
You must be signed in to change notification settings - Fork 5.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
347 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,13 @@ | ||
cc_library(ddim SRCS ddim.cc) | ||
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) | ||
|
||
nv_test(dim_test SRCS dim_test.cu DEPS ddim) | ||
|
||
cc_test(variable_test SRCS variable_test.cc) | ||
# include generated protobuf headers | ||
INCLUDE_DIRECTORIES(${CMAKE_CURRENT_BINARY_DIR}) | ||
generate_protobuf_cpp(attr_proto_src attr_proto_hdr attr.proto) | ||
cc_library(attr_proto SRCS ${attr_proto_src}) | ||
generate_protobuf_cpp(attr_test_proto_src attr_test_proto_header attr_test.proto) | ||
message(STATUS ${attr_test_proto_src}) | ||
cc_test(attr_test SRCS ${attr_test_proto_src} attr_test.cc | ||
DEPS attr_proto protobuf glog gflags) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
syntax="proto3"; | ||
package paddle.framework; | ||
|
||
message Attribute { | ||
message ListValue { | ||
repeated int32 ints = 1; | ||
repeated float floats = 2; | ||
repeated string strings = 3; | ||
} | ||
|
||
oneof value { | ||
ListValue list = 1; | ||
int32 i = 2; | ||
float f = 3; | ||
string s = 4; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,183 @@ | ||
#pragma once | ||
#include <google/protobuf/map.h> | ||
#include <paddle/framework/error.h> | ||
#include <iterator> | ||
#include <string> | ||
#include <type_traits> | ||
#include "attr.pb.h" | ||
namespace paddle { | ||
namespace framework { | ||
using AttributeMap = google::protobuf::Map<std::string, Attribute>; | ||
|
||
class AttributeReader final { | ||
public: | ||
explicit AttributeReader(const AttributeMap& attrs) : attrs_(attrs) {} | ||
|
||
template <typename T> | ||
Error __must_check Get(const std::string& attributeName, T* attr) const; | ||
|
||
template <typename T> | ||
Error __must_check GetArray(const std::string& attributeName, | ||
std::vector<T>* array) const; | ||
|
||
private: | ||
const AttributeMap& attrs_; | ||
}; | ||
|
||
namespace details { | ||
template <typename Iterator, typename T> | ||
struct SetArrayImpl { | ||
Error __must_check operator()(AttributeMap* attrs, | ||
const std::string& attributeName, | ||
Iterator begin, Iterator end, bool overwrite); | ||
}; | ||
} // namespace details | ||
|
||
class AttributeWriter { | ||
public: | ||
explicit AttributeWriter(AttributeMap* attrs) : attrs_(attrs) {} | ||
|
||
template <typename T> | ||
Error __must_check Set(const std::string& attributeName, const T& attr, | ||
bool overwrite = false); | ||
|
||
template <typename Iterator> | ||
Error __must_check SetArray(const std::string& attributeName, Iterator begin, | ||
Iterator end, bool overwrite = false) { | ||
return details::SetArrayImpl< | ||
Iterator, typename std::iterator_traits<Iterator>::value_type>()( | ||
attrs_, attributeName, begin, end, overwrite); | ||
} | ||
|
||
template <typename T, typename Container = std::initializer_list<T>> | ||
Error __must_check SetArray(const std::string& attributeName, | ||
Container container, bool overwrite = false) { | ||
return SetArray(attributeName, container.begin(), container.end(), | ||
overwrite); | ||
} | ||
|
||
private: | ||
AttributeMap* attrs_; | ||
}; | ||
|
||
#define ATTR_READER_IMPL_PLAIN_TYPE(T, CASE, FIELD_NAME) \ | ||
template <> \ | ||
Error __must_check AttributeReader::Get<T>(const std::string& attributeName, \ | ||
T* attr) const { \ | ||
auto it = attrs_.find(attributeName); \ | ||
if (it == attrs_.end()) { \ | ||
return Error("Attribute %s not found", attributeName.c_str()); \ | ||
} \ | ||
if (it->second.value_case() != CASE) { \ | ||
return Error("Attribute should be in field " #FIELD_NAME); \ | ||
} \ | ||
*attr = it->second.FIELD_NAME(); \ | ||
return Error(); \ | ||
} | ||
|
||
ATTR_READER_IMPL_PLAIN_TYPE(int, Attribute::kI, i); | ||
ATTR_READER_IMPL_PLAIN_TYPE(float, Attribute::kF, f); | ||
ATTR_READER_IMPL_PLAIN_TYPE(std::string, Attribute::kS, s); | ||
|
||
#undef ATTR_READER_IMPL_PLAIN_TYPE | ||
|
||
#define ATTR_READER_IMPL_ARRAY_TYPE(T, FIELD_NAME) \ | ||
template <> \ | ||
Error __must_check AttributeReader::GetArray<T>( \ | ||
const std::string& attributeName, std::vector<T>* array) const { \ | ||
if (!array->empty()) { \ | ||
return Error("The output array must be empty."); \ | ||
} \ | ||
\ | ||
auto it = attrs_.find(attributeName); \ | ||
if (it == attrs_.end()) { \ | ||
return Error("Attribute %s not found", attributeName.c_str()); \ | ||
} \ | ||
\ | ||
auto& lst = it->second.list(); \ | ||
auto& field = lst.FIELD_NAME(); \ | ||
array->reserve(field.size()); \ | ||
std::copy(field.begin(), field.end(), std::back_inserter(*array)); \ | ||
return Error(); \ | ||
} | ||
|
||
ATTR_READER_IMPL_ARRAY_TYPE(float, floats); | ||
ATTR_READER_IMPL_ARRAY_TYPE(int, ints); | ||
ATTR_READER_IMPL_ARRAY_TYPE(std::string, strings); | ||
|
||
#undef ATTR_READER_IMPL_ARRAY_TYPE | ||
|
||
#define ATTR_WRITER_IMPL_PLAIN_TYPE(T, FIELD_NAME) \ | ||
template <> \ | ||
Error __must_check AttributeWriter::Set<T>(const std::string& attributeName, \ | ||
const T& attr, bool overwrite) { \ | ||
auto it = attrs_->find(attributeName); \ | ||
if (it != attrs_->end() && !overwrite) { \ | ||
return Error("Attribute %s has been set", attributeName.c_str()); \ | ||
} \ | ||
(*attrs_)[attributeName].set_##FIELD_NAME(attr); \ | ||
return Error(); \ | ||
} | ||
|
||
ATTR_WRITER_IMPL_PLAIN_TYPE(int, i); | ||
ATTR_WRITER_IMPL_PLAIN_TYPE(float, f); | ||
ATTR_WRITER_IMPL_PLAIN_TYPE(std::string, s); | ||
|
||
#undef ATTR_WRITER_IMPL_PLAIN_TYPE | ||
|
||
namespace details { | ||
template <typename T> | ||
void AppendToField(google::protobuf::RepeatedField<T>* field, const T& val) { | ||
field->Add(val); | ||
} | ||
template <typename T> | ||
void AppendToField(google::protobuf::RepeatedPtrField<T>* field, const T& val) { | ||
*(field->Add()) = val; | ||
} | ||
|
||
} // namespace details | ||
|
||
#define ATTR_WRITER_IMPL_ARRAY_TYPE(T, FIELD_NAME) \ | ||
namespace details { \ | ||
\ | ||
template <typename Iterator> \ | ||
struct SetArrayImpl<Iterator, T> { \ | ||
using VALUE_TYPE = typename std::iterator_traits<Iterator>::value_type; \ | ||
Error __must_check operator()(AttributeMap* attrs, \ | ||
const std::string& attributeName, \ | ||
Iterator begin, Iterator end, \ | ||
bool overwrite) { \ | ||
static_assert(std::is_same<VALUE_TYPE, T>::value, ""); \ | ||
auto it = attrs->find(attributeName); \ | ||
if (it != attrs->end() && !overwrite) { \ | ||
return Error("Attribute %s has been set", attributeName.c_str()); \ | ||
} \ | ||
\ | ||
if (it != attrs->end() && overwrite) { \ | ||
auto repeatedFieldPtr = \ | ||
it->second.mutable_list()->mutable_##FIELD_NAME(); \ | ||
repeatedFieldPtr->erase(repeatedFieldPtr->begin(), \ | ||
repeatedFieldPtr->end()); \ | ||
} \ | ||
auto lst = (*attrs)[attributeName].mutable_list(); \ | ||
auto elems = lst->mutable_##FIELD_NAME(); \ | ||
auto distance = std::distance(begin, end); \ | ||
if (std::is_integral<decltype(distance)>::value) { \ | ||
elems->Reserve(distance); \ | ||
} \ | ||
for (; begin != end; ++begin) { \ | ||
AppendToField(elems, *begin); \ | ||
} \ | ||
return Error(); \ | ||
} \ | ||
}; \ | ||
} | ||
|
||
ATTR_WRITER_IMPL_ARRAY_TYPE(float, floats); | ||
ATTR_WRITER_IMPL_ARRAY_TYPE(int, ints); | ||
ATTR_WRITER_IMPL_ARRAY_TYPE(std::string, strings); | ||
|
||
#undef ATTR_WRITER_IMPL_ARRAY_TYPE | ||
|
||
} // namespace framework | ||
} // namespace paddle |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
#include <gtest/gtest.h> | ||
#include <paddle/framework/attr_helper.h> | ||
#include <random> | ||
#include "attr_test.pb.h" | ||
|
||
TEST(AttrHelper, plainTypes) { | ||
paddle::framework::AttributeTestMessage msg; | ||
std::random_device dev; | ||
unsigned int seed = dev(); | ||
std::mt19937 rng(seed); | ||
std::uniform_int_distribution<int> distInt(-1000, 1000); | ||
std::uniform_real_distribution<float> distFloat(-1000.0f, 1000.0f); | ||
paddle::framework::AttributeWriter writer(msg.mutable_attrs()); | ||
paddle::framework::AttributeReader reader(msg.attrs()); | ||
for (size_t i = 0; i < 1000; ++i) { | ||
std::string key = "str_" + std::to_string(i); | ||
switch (i % 3) { | ||
case 0: | ||
ASSERT_TRUE(writer.Set<int>(key, distInt(rng)).isOK()); | ||
break; | ||
case 1: | ||
ASSERT_TRUE(writer.Set<float>(key, distFloat(rng)).isOK()); | ||
break; | ||
case 2: | ||
ASSERT_TRUE(writer | ||
.Set<std::string>( | ||
key, "random_str_" + std::to_string(distInt(rng))) | ||
.isOK()); | ||
break; | ||
default: | ||
ASSERT_TRUE(false); | ||
} | ||
} | ||
|
||
std::mt19937 rng2(seed); | ||
|
||
for (size_t i = 0; i < 1000; ++i) { | ||
std::string key = "str_" + std::to_string(i); | ||
int intVal; | ||
float floatVal; | ||
std::string strVal; | ||
|
||
switch (i % 3) { | ||
case 0: | ||
ASSERT_TRUE(reader.Get<int>(key, &intVal).isOK()); | ||
ASSERT_EQ(distInt(rng2), intVal); | ||
break; | ||
case 1: | ||
ASSERT_TRUE(reader.Get<float>(key, &floatVal).isOK()); | ||
ASSERT_EQ(distFloat(rng2), floatVal); | ||
break; | ||
case 2: | ||
ASSERT_TRUE(reader.Get<std::string>(key, &strVal).isOK()); | ||
ASSERT_EQ("random_str_" + std::to_string(distInt(rng2)), strVal); | ||
break; | ||
default: | ||
ASSERT_TRUE(false); | ||
} | ||
} | ||
} | ||
|
||
template <typename Container> | ||
inline void TestArrayImpl(const Container& container) { | ||
paddle::framework::AttributeTestMessage msg; | ||
paddle::framework::AttributeWriter writer(msg.mutable_attrs()); | ||
paddle::framework::AttributeReader reader(msg.attrs()); | ||
|
||
auto err = | ||
writer.SetArray<typename Container::value_type>("test_array", container); | ||
ASSERT_TRUE(err.isOK()); | ||
|
||
Container tmp; | ||
err = reader.GetArray("test_array", &tmp); | ||
ASSERT_TRUE(err.isOK()); | ||
ASSERT_EQ(container, tmp); | ||
} | ||
|
||
TEST(AttrHelper, array) { | ||
TestArrayImpl(std::vector<float>{0.7, 0.8, 0.9}); | ||
TestArrayImpl(std::vector<int>{-1, -2, 1, 2, 0}); | ||
TestArrayImpl(std::vector<std::string>{"a", "", "c", "e", "01304fksd"}); | ||
} |
Oops, something went wrong.