Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 66 additions & 9 deletions src/fury/encoder/row_encode_trait.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "fury/meta/type_traits.h"
#include "fury/row/writer.h"
#include <string_view>
#include <type_traits>
#include <utility>

namespace fury {
Expand Down Expand Up @@ -67,7 +68,19 @@ inline constexpr bool IsString =
meta::IsOneOf<T, std::string, std::string_view>::value;

template <typename T>
inline constexpr bool IsClassButNotBuiltin = std::is_class_v<T> && !IsString<T>;
inline constexpr bool IsArray = meta::IsIterable<T> && !IsString<T>;

template <typename T>
inline constexpr bool IsClassButNotBuiltin =
std::is_class_v<T> && !(IsString<T> || IsArray<T>);

inline decltype(auto) GetChildType(RowWriter &writer, int index) {
return writer.schema()->field(index)->type();
}

inline decltype(auto) GetChildType(ArrayWriter &writer, int index) {
return writer.type()->field(0)->type();
}

} // namespace details

Expand Down Expand Up @@ -106,8 +119,10 @@ struct RowEncodeTrait<
return details::ArrowSchemaBasicType<std::remove_cv_t<T>>::value();
}

template <typename V>
static void Write(V &&, const T &value, RowWriter &writer, int index) {
template <typename V, typename W,
std::enable_if_t<meta::IsOneOf<W, RowWriter, ArrayWriter>::value,
int> = 0>
static void Write(V &&, const T &value, W &writer, int index) {
writer.Write(index, value);
}
};
Expand All @@ -117,8 +132,10 @@ struct RowEncodeTrait<
T, std::enable_if_t<details::IsString<std::remove_cv_t<T>>>> {
static auto Type() { return arrow::utf8(); }

template <typename V>
static void Write(V &&, const T &value, RowWriter &writer, int index) {
template <typename V, typename W,
std::enable_if_t<meta::IsOneOf<W, RowWriter, ArrayWriter>::value,
int> = 0>
static void Write(V &&, const T &value, W &writer, int index) {
writer.WriteString(index, value);
}
};
Expand Down Expand Up @@ -165,13 +182,14 @@ struct RowEncodeTrait<
std::make_index_sequence<FieldInfo::Size>());
}

template <typename V>
static void Write(V &&visitor, const T &value, RowWriter &writer, int index) {
template <typename V, typename W,
std::enable_if_t<meta::IsOneOf<W, RowWriter, ArrayWriter>::value,
int> = 0>
static void Write(V &&visitor, const T &value, W &writer, int index) {
auto offset = writer.cursor();

auto inner_writer = std::make_unique<RowWriter>(
arrow::schema(writer.schema()->field(index)->type()->fields()),
&writer);
arrow::schema(details::GetChildType(writer, index)->fields()), &writer);

inner_writer->Reset();
RowEncodeTrait<T>::Write(std::forward<V>(visitor), value,
Expand All @@ -184,6 +202,45 @@ struct RowEncodeTrait<
}
};

template <typename T>
struct RowEncodeTrait<T,
std::enable_if_t<details::IsArray<std::remove_cv_t<T>>>> {
static auto Type() {
return arrow::list(RowEncodeTrait<meta::GetValueType<T>>::Type());
}

template <typename V>
static void Write(V &&visitor, const T &value, ArrayWriter &writer) {
int index = 0;
for (const auto &v : value) {
RowEncodeTrait<meta::GetValueType<T>>::Write(std::forward<V>(visitor), v,
writer, index);
++index;
}
}

template <typename V, typename W,
std::enable_if_t<meta::IsOneOf<W, RowWriter, ArrayWriter>::value,
int> = 0>
static void Write(V &&visitor, const T &value, W &writer, int index) {
auto offset = writer.cursor();

auto inner_writer = std::make_unique<ArrayWriter>(
std::dynamic_pointer_cast<arrow::ListType>(
details::GetChildType(writer, index)),
&writer);

inner_writer->Reset(value.size());
RowEncodeTrait<T>::Write(std::forward<V>(visitor), value,
*inner_writer.get());

writer.SetOffsetAndSize(index, offset, writer.cursor() - offset);

std::forward<V>(visitor).template Visit<std::remove_cv_t<T>>(
std::move(inner_writer));
}
};

} // namespace encoder

} // namespace fury
103 changes: 102 additions & 1 deletion src/fury/encoder/row_encode_trait_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
*/

#include "gtest/gtest.h"
#include <memory>
#include <type_traits>

#include "fury/encoder/row_encode_trait.h"
#include "src/fury/row/writer.h"
#include "fury/row/writer.h"

namespace fury {

Expand Down Expand Up @@ -146,6 +147,106 @@ TEST(RowEncodeTrait, NestedStruct) {
ASSERT_EQ(y_schema->field(2)->type()->name(), "bool");
}

TEST(RowEncodeTrait, SimpleArray) {
std::vector<int> a{10, 20, 30};

auto type = encoder::RowEncodeTrait<decltype(a)>::Type();

ASSERT_EQ(type->name(), "list");
ASSERT_EQ(type->field(0)->type()->name(), "int32");

ArrayWriter writer(std::dynamic_pointer_cast<arrow::ListType>(type));
writer.Reset(a.size());

encoder::RowEncodeTrait<decltype(a)>::Write(encoder::EmptyWriteVisitor{}, a,
writer);

auto array = writer.CopyToArrayData();
ASSERT_EQ(array->GetInt32(0), 10);
ASSERT_EQ(array->GetInt32(1), 20);
ASSERT_EQ(array->GetInt32(2), 30);
}

TEST(RowEncodeTrait, StructInArray) {
std::vector<A> a{{233, 1.1, false}, {234, 3.14, true}};

auto type = encoder::RowEncodeTrait<decltype(a)>::Type();

ASSERT_EQ(type->name(), "list");
ASSERT_EQ(type->field(0)->type()->name(), "struct");

ArrayWriter writer(std::dynamic_pointer_cast<arrow::ListType>(type));
writer.Reset(a.size());

encoder::RowEncodeTrait<decltype(a)>::Write(encoder::EmptyWriteVisitor{}, a,
writer);

auto array = writer.CopyToArrayData();

auto row1 = array->GetStruct(0);
ASSERT_EQ(row1->GetInt32(0), 233);
ASSERT_FLOAT_EQ(row1->GetFloat(1), 1.1);
ASSERT_EQ(row1->GetBoolean(2), false);

auto row2 = array->GetStruct(1);
ASSERT_EQ(row2->GetInt32(0), 234);
ASSERT_FLOAT_EQ(row2->GetFloat(1), 3.14);
ASSERT_EQ(row2->GetBoolean(2), true);
}

struct E {
int a;
std::vector<int> b;
};

FURY_FIELD_INFO(E, a, b);

TEST(RowEncodeTrait, ArrayInStruct) {
E e{233, {10, 20, 30}};

auto type = encoder::RowEncodeTrait<decltype(e)>::Type();

ASSERT_EQ(type->name(), "struct");
ASSERT_EQ(type->field(0)->type()->name(), "int32");
ASSERT_EQ(type->field(1)->type()->name(), "list");

RowWriter writer(encoder::RowEncodeTrait<decltype(e)>::Schema());
writer.Reset();

encoder::RowEncodeTrait<decltype(e)>::Write(encoder::EmptyWriteVisitor{}, e,
writer);

auto row = writer.ToRow();
ASSERT_EQ(row->GetInt32(0), 233);

ASSERT_EQ(row->GetArray(1)->GetInt32(0), 10);
ASSERT_EQ(row->GetArray(1)->GetInt32(1), 20);
ASSERT_EQ(row->GetArray(1)->GetInt32(2), 30);
}

TEST(RowEncodeTrait, ArrayInArray) {
std::vector<std::vector<int>> a{{10}, {20, 30}, {40, 50, 60}};

auto type = encoder::RowEncodeTrait<decltype(a)>::Type();

ASSERT_EQ(type->name(), "list");
ASSERT_EQ(type->field(0)->type()->name(), "list");

ArrayWriter writer(std::dynamic_pointer_cast<arrow::ListType>(type));
writer.Reset(a.size());

encoder::RowEncodeTrait<decltype(a)>::Write(encoder::EmptyWriteVisitor{}, a,
writer);

auto array = writer.CopyToArrayData();
ASSERT_EQ(array->GetArray(0)->GetInt32(0), 10);
ASSERT_EQ(array->GetArray(1)->GetInt32(0), 20);
ASSERT_EQ(array->GetArray(1)->GetInt32(1), 30);
ASSERT_EQ(array->GetArray(2)->GetInt32(0), 40);
ASSERT_EQ(array->GetArray(2)->GetInt32(1), 50);
ASSERT_EQ(array->GetArray(2)->GetInt32(2), 60);
}

} // namespace test

} // namespace fury
Expand Down
4 changes: 2 additions & 2 deletions src/fury/encoder/row_encoder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
#include <type_traits>

#include "fury/encoder/row_encode_trait.h"
#include "src/fury/encoder/row_encoder.h"
#include "src/fury/row/writer.h"
#include "fury/encoder/row_encoder.h"
#include "fury/row/writer.h"

namespace fury {

Expand Down
26 changes: 26 additions & 0 deletions src/fury/meta/type_traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#pragma once

#include <iterator>
#include <type_traits>

namespace fury {
Expand Down Expand Up @@ -71,6 +72,31 @@ template <typename T, typename... Args>
using EnableIfIsOneOf =
typename std::enable_if<IsOneOf<T, Args...>::value, T>::type;

namespace details {
using std::begin;
using std::end;

template <typename T,
typename U = std::void_t<
decltype(*begin(std::declval<T &>()),
++std::declval<decltype(begin(std::declval<T &>())) &>(),
begin(std::declval<T &>()) != end(std::declval<T &>()))>>
std::true_type IsIterableImpl(int);

template <typename T> std::false_type IsIterableImpl(...);

template <typename T> struct GetValueTypeImpl {
using type = std::remove_reference_t<decltype(*begin(std::declval<T &>()))>;
};
} // namespace details

template <typename T>
constexpr inline bool IsIterable =
decltype(details::IsIterableImpl<T>(0))::value;

template <typename T>
using GetValueType = typename details::GetValueTypeImpl<T>::type;

} // namespace meta

} // namespace fury
17 changes: 17 additions & 0 deletions src/fury/meta/type_traits_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
*/

#include "gtest/gtest.h"
#include <deque>
#include <initializer_list>
#include <list>

#include "fury/meta/field_info.h"
#include "src/fury/meta/type_traits.h"
Expand Down Expand Up @@ -60,6 +63,20 @@ TEST(Meta, IsUnique) {
static_assert(!IsUnique<1, false, true, &A::x, 1>::value);
}

TEST(Meta, IsIterable) {
static_assert(IsIterable<std::vector<int>>);
static_assert(IsIterable<std::vector<std::vector<int>>>);
static_assert(IsIterable<std::deque<float>>);
static_assert(IsIterable<std::list<int>>);
static_assert(IsIterable<std::set<int>>);
static_assert(IsIterable<std::map<int, std::vector<unsigned>>>);
static_assert(IsIterable<struct A[10]>);
static_assert(IsIterable<float[2][2]>);
static_assert(IsIterable<std::initializer_list<A>>);
static_assert(IsIterable<std::string>);
static_assert(IsIterable<std::string_view>);
}

} // namespace test

} // namespace fury
Expand Down
2 changes: 2 additions & 0 deletions src/fury/row/writer.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,8 @@ class ArrayWriter : public Writer {

int size() { return cursor() - starting_offset_; }

std::shared_ptr<arrow::ListType> type() { return type_; }

private:
std::shared_ptr<arrow::ListType> type_;
int element_size_;
Expand Down