Skip to content

Commit

Permalink
ARROW-17989: [C++][Python] Enable struct_field kernel to accept strin…
Browse files Browse the repository at this point in the history
…g field names (#14495)

Will close [ARROW-17989](https://issues.apache.org/jira/browse/ARROW-17989)

Allows using names in `pc.struct_field`
```python
In [1]: arr = pa.array([{'a': {'b': 1}, 'c': 2}])

In [2]: pc.struct_field(arr, 'c')
Out[2]:
<pyarrow.lib.Int64Array object at 0x7f1442da3d60>
[
  2
]

In [3]: pc.struct_field(arr, '.a.b')
Out[3]:
<pyarrow.lib.Int64Array object at 0x7f14436d0f40>
[
  1
]

# And indices as before...
In [4]: pc.struct_field(arr, [0, 0])
Out[4]:
<pyarrow.lib.Int64Array object at 0x7f14436d0ee0>
[
  1
]

In [5]:
```

Lead-authored-by: Miles Granger <miles59923@gmail.com>
Co-authored-by: Antoine Pitrou <antoine@python.org>
Co-authored-by: Joris Van den Bossche <jorisvandenbossche@gmail.com>
Signed-off-by: Joris Van den Bossche <jorisvandenbossche@gmail.com>
  • Loading branch information
3 people committed Nov 22, 2022
1 parent f769f6b commit b1110ae
Show file tree
Hide file tree
Showing 11 changed files with 267 additions and 52 deletions.
11 changes: 8 additions & 3 deletions cpp/src/arrow/compute/api_scalar.cc
Expand Up @@ -365,7 +365,7 @@ static auto kStrptimeOptionsType = GetFunctionOptionsType<StrptimeOptions>(
DataMember("unit", &StrptimeOptions::unit),
DataMember("error_is_null", &StrptimeOptions::error_is_null));
static auto kStructFieldOptionsType = GetFunctionOptionsType<StructFieldOptions>(
DataMember("indices", &StructFieldOptions::indices));
DataMember("field_ref", &StructFieldOptions::field_ref));
static auto kTrimOptionsType = GetFunctionOptionsType<TrimOptions>(
DataMember("characters", &TrimOptions::characters));
static auto kUtf8NormalizeOptionsType = GetFunctionOptionsType<Utf8NormalizeOptions>(
Expand Down Expand Up @@ -578,8 +578,13 @@ StrptimeOptions::StrptimeOptions() : StrptimeOptions("", TimeUnit::MICRO, false)
constexpr char StrptimeOptions::kTypeName[];

StructFieldOptions::StructFieldOptions(std::vector<int> indices)
: FunctionOptions(internal::kStructFieldOptionsType), indices(std::move(indices)) {}
StructFieldOptions::StructFieldOptions() : StructFieldOptions(std::vector<int>()) {}
: FunctionOptions(internal::kStructFieldOptionsType), field_ref(std::move(indices)) {}
StructFieldOptions::StructFieldOptions(std::initializer_list<int> indices)
: FunctionOptions(internal::kStructFieldOptionsType), field_ref(std::move(indices)) {}
StructFieldOptions::StructFieldOptions(FieldRef ref)
: FunctionOptions(internal::kStructFieldOptionsType), field_ref(std::move(ref)) {}
StructFieldOptions::StructFieldOptions()
: FunctionOptions(internal::kStructFieldOptionsType) {}
constexpr char StructFieldOptions::kTypeName[];

TrimOptions::TrimOptions(std::string characters)
Expand Down
7 changes: 4 additions & 3 deletions cpp/src/arrow/compute/api_scalar.h
Expand Up @@ -278,12 +278,13 @@ class ARROW_EXPORT SetLookupOptions : public FunctionOptions {
class ARROW_EXPORT StructFieldOptions : public FunctionOptions {
public:
explicit StructFieldOptions(std::vector<int> indices);
explicit StructFieldOptions(std::initializer_list<int>);
explicit StructFieldOptions(FieldRef field_ref);
StructFieldOptions();
static constexpr char const kTypeName[] = "StructFieldOptions";

/// The child indices to extract. For instance, to get the 2nd child
/// of the 1st child of a struct or union, this would be {0, 1}.
std::vector<int> indices;
/// The FieldRef specifying what to extract from struct or union.
FieldRef field_ref;
};

class ARROW_EXPORT StrptimeOptions : public FunctionOptions {
Expand Down
27 changes: 22 additions & 5 deletions cpp/src/arrow/compute/kernels/scalar_nested.cc
Expand Up @@ -388,9 +388,17 @@ const FunctionDoc list_element_doc(
struct StructFieldFunctor {
static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
const auto& options = OptionsWrapper<StructFieldOptions>::Get(ctx);

std::shared_ptr<Array> current = MakeArray(batch[0].array.ToArrayData());
for (const auto& index : options.indices) {

FieldPath field_path;
if (options.field_ref.IsNested() || options.field_ref.IsName()) {
ARROW_ASSIGN_OR_RAISE(field_path, options.field_ref.FindOne(*current->type()));
} else {
DCHECK(options.field_ref.IsFieldPath());
field_path = *options.field_ref.field_path();
}

for (const auto& index : field_path.indices()) {
RETURN_NOT_OK(CheckIndex(index, *current->type()));
switch (current->type()->id()) {
case Type::STRUCT: {
Expand Down Expand Up @@ -421,7 +429,8 @@ struct StructFieldFunctor {
ArrayData(int32(), union_array.length(),
{std::move(take_bitmap), union_array.value_offsets()},
kUnknownNullCount, union_array.offset()));
// Do not slice the child since the indices are relative to the unsliced array.
// Do not slice the child since the indices are relative to the unsliced
// array.
ARROW_ASSIGN_OR_RAISE(
Datum result,
CallFunction("take", {union_array.field(index), std::move(take_indices)}));
Expand Down Expand Up @@ -463,9 +472,17 @@ struct StructFieldFunctor {

Result<TypeHolder> ResolveStructFieldType(KernelContext* ctx,
const std::vector<TypeHolder>& types) {
const auto& options = OptionsWrapper<StructFieldOptions>::Get(ctx);
const auto& field_ref = OptionsWrapper<StructFieldOptions>::Get(ctx).field_ref;
const DataType* type = types.front().type;
for (const auto& index : options.indices) {

FieldPath field_path;
if (field_ref.IsNested() || field_ref.IsName()) {
ARROW_ASSIGN_OR_RAISE(field_path, field_ref.FindOne(*type));
} else {
field_path = *field_ref.field_path();
}

for (const auto& index : field_path.indices()) {
RETURN_NOT_OK(StructFieldFunctor::CheckIndex(index, *type));
type = type->field(index)->type().get();
}
Expand Down
46 changes: 40 additions & 6 deletions cpp/src/arrow/compute/kernels/scalar_nested_test.cc
Expand Up @@ -261,6 +261,13 @@ TEST(TestScalarNested, StructField) {
StructFieldOptions invalid2({2, 4});
StructFieldOptions invalid3({3});
StructFieldOptions invalid4({0, 1});

// Test using FieldRefs
StructFieldOptions extract0_field_ref_path(FieldRef(FieldPath({0})));
StructFieldOptions extract0_field_ref_name(FieldRef("a"));
ASSERT_OK_AND_ASSIGN(auto field_ref, FieldRef::FromDotPath(".c.d"));
StructFieldOptions extract20_field_ref_nest(field_ref);

FieldVector fields = {field("a", int32()), field("b", utf8()),
field("c", struct_({
field("d", int64()),
Expand All @@ -278,16 +285,25 @@ TEST(TestScalarNested, StructField) {
&extract0);
CheckScalar("struct_field", {arr}, ArrayFromJSON(int64(), "[10, 11, 12, null]"),
&extract20);

CheckScalar("struct_field", {arr}, ArrayFromJSON(int32(), "[1, null, 3, null]"),
&extract0_field_ref_path);
CheckScalar("struct_field", {arr}, ArrayFromJSON(int32(), "[1, null, 3, null]"),
&extract0_field_ref_name);
CheckScalar("struct_field", {arr}, ArrayFromJSON(int64(), "[10, 11, 12, null]"),
&extract20_field_ref_nest);

EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
::testing::HasSubstr("out-of-bounds field reference"),
CallFunction("struct_field", {arr}, &invalid1));
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
::testing::HasSubstr("out-of-bounds field reference"),
::testing::HasSubstr("No match for FieldRef"),
CallFunction("struct_field", {arr}, &invalid2));
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
::testing::HasSubstr("out-of-bounds field reference"),
CallFunction("struct_field", {arr}, &invalid3));
EXPECT_RAISES_WITH_MESSAGE_THAT(TypeError, ::testing::HasSubstr("cannot subscript"),
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
::testing::HasSubstr("No match for FieldRef"),
CallFunction("struct_field", {arr}, &invalid4));
}
{
Expand All @@ -303,16 +319,25 @@ TEST(TestScalarNested, StructField) {
&extract0);
CheckScalar("struct_field", {arr}, ArrayFromJSON(int64(), "[null, null, null, 10]"),
&extract20);

CheckScalar("struct_field", {arr}, ArrayFromJSON(int32(), "[1, null, null, null]"),
&extract0_field_ref_path);
CheckScalar("struct_field", {arr}, ArrayFromJSON(int32(), "[1, null, null, null]"),
&extract0_field_ref_name);
CheckScalar("struct_field", {arr}, ArrayFromJSON(int64(), "[null, null, null, 10]"),
&extract20_field_ref_nest);

EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
::testing::HasSubstr("out-of-bounds field reference"),
CallFunction("struct_field", {arr}, &invalid1));
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
::testing::HasSubstr("out-of-bounds field reference"),
::testing::HasSubstr("No match for FieldRef"),
CallFunction("struct_field", {arr}, &invalid2));
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
::testing::HasSubstr("out-of-bounds field reference"),
CallFunction("struct_field", {arr}, &invalid3));
EXPECT_RAISES_WITH_MESSAGE_THAT(TypeError, ::testing::HasSubstr("cannot subscript"),
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
::testing::HasSubstr("No match for FieldRef"),
CallFunction("struct_field", {arr}, &invalid4));

// Test edge cases for union representation
Expand Down Expand Up @@ -352,16 +377,25 @@ TEST(TestScalarNested, StructField) {
&extract0);
CheckScalar("struct_field", {arr}, ArrayFromJSON(int64(), "[null, null, null, 10]"),
&extract20);

CheckScalar("struct_field", {arr}, ArrayFromJSON(int32(), "[1, null, null, null]"),
&extract0_field_ref_path);
CheckScalar("struct_field", {arr}, ArrayFromJSON(int32(), "[1, null, null, null]"),
&extract0_field_ref_name);
CheckScalar("struct_field", {arr}, ArrayFromJSON(int64(), "[null, null, null, 10]"),
&extract20_field_ref_nest);

EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
::testing::HasSubstr("out-of-bounds field reference"),
CallFunction("struct_field", {arr}, &invalid1));
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
::testing::HasSubstr("out-of-bounds field reference"),
::testing::HasSubstr("No match for FieldRef"),
CallFunction("struct_field", {arr}, &invalid2));
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
::testing::HasSubstr("out-of-bounds field reference"),
CallFunction("struct_field", {arr}, &invalid3));
EXPECT_RAISES_WITH_MESSAGE_THAT(TypeError, ::testing::HasSubstr("cannot subscript"),
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
::testing::HasSubstr("No match for FieldRef"),
CallFunction("struct_field", {arr}, &invalid4));
}
{
Expand Down
18 changes: 11 additions & 7 deletions cpp/src/arrow/engine/substrait/expression_internal.cc
Expand Up @@ -170,9 +170,10 @@ Result<compute::Expression> FromProto(const substrait::Expression& expr,
out = compute::field_ref(FieldRef(*out_ref, index));
} else if (out->call() && out->call()->function_name == "struct_field") {
// Nested StructFields on top of an arbitrary expression
std::static_pointer_cast<arrow::compute::StructFieldOptions>(
out->call()->options)
->indices.push_back(index);
auto* field_options =
checked_cast<compute::StructFieldOptions*>(out->call()->options.get());
field_options->field_ref =
FieldRef(std::move(field_options->field_ref), index);
} else {
// First StructField on top of an arbitrary expression
out = compute::call("struct_field", {std::move(*out)},
Expand Down Expand Up @@ -1019,13 +1020,16 @@ Result<std::unique_ptr<substrait::Expression>> ToProto(

if (call->function_name == "struct_field") {
// catch the special case of calls convertible to a StructField
const auto& field_options =
checked_cast<const compute::StructFieldOptions&>(*call->options);
const DataType& struct_type = *call->arguments[0].type();
DCHECK_EQ(struct_type.id(), Type::STRUCT);

ARROW_ASSIGN_OR_RAISE(auto field_path, field_options.field_ref.FindOne(struct_type));
out = std::move(arguments[0]);
for (int index :
checked_cast<const arrow::compute::StructFieldOptions&>(*call->options)
.indices) {
for (int index : field_path.indices()) {
ARROW_ASSIGN_OR_RAISE(out, MakeStructFieldReference(std::move(out), index));
}

return std::move(out);
}

Expand Down
78 changes: 60 additions & 18 deletions cpp/src/arrow/type.cc
Expand Up @@ -20,6 +20,7 @@
#include <algorithm>
#include <climits>
#include <cstddef>
#include <iterator>
#include <limits>
#include <memory>
#include <mutex>
Expand Down Expand Up @@ -1161,36 +1162,72 @@ Result<std::shared_ptr<ArrayData>> FieldPath::Get(const ArrayData& data) const {
return FieldPathGetImpl::Get(this, data.child_data);
}

FieldRef::FieldRef(FieldPath indices) : impl_(std::move(indices)) {
DCHECK_GT(std::get<FieldPath>(impl_).indices().size(), 0);
}
FieldRef::FieldRef(FieldPath indices) : impl_(std::move(indices)) {}

void FieldRef::Flatten(std::vector<FieldRef> children) {
ARROW_CHECK(!children.empty());

// flatten children
struct Visitor {
void operator()(std::string&& name) { out->push_back(FieldRef(std::move(name))); }
void operator()(std::string&& name, std::vector<FieldRef>* out) {
out->push_back(FieldRef(std::move(name)));
}

void operator()(FieldPath&& indices) { out->push_back(FieldRef(std::move(indices))); }
void operator()(FieldPath&& path, std::vector<FieldRef>* out) {
if (path.indices().empty()) {
return;
}
out->push_back(FieldRef(std::move(path)));
}

void operator()(std::vector<FieldRef>&& children) {
out->reserve(out->size() + children.size());
void operator()(std::vector<FieldRef>&& children, std::vector<FieldRef>* out) {
if (children.empty()) {
return;
}
// First flatten children into temporary result
std::vector<FieldRef> flattened_children;
flattened_children.reserve(children.size());
for (auto&& child : children) {
std::visit(*this, std::move(child.impl_));
std::visit(std::bind(*this, std::placeholders::_1, &flattened_children),
std::move(child.impl_));
}
// If all children are FieldPaths, concatenate them into a single FieldPath
int64_t n_indices = 0;
for (const auto& child : flattened_children) {
const FieldPath* path = child.field_path();
if (!path) {
n_indices = -1;
break;
}
n_indices += static_cast<int64_t>(path->indices().size());
}
if (n_indices == 0) {
return;
} else if (n_indices > 0) {
std::vector<int> indices(n_indices);
auto out_indices = indices.begin();
for (const auto& child : flattened_children) {
for (int index : *child.field_path()) {
*out_indices++ = index;
}
}
DCHECK_EQ(out_indices, indices.end());
out->push_back(FieldRef(std::move(indices)));
} else {
// ... otherwise, just transfer them to the final result
out->insert(out->end(), std::move_iterator(flattened_children.begin()),
std::move_iterator(flattened_children.end()));
}
}

std::vector<FieldRef>* out;
};

std::vector<FieldRef> out;
Visitor visitor{&out};
visitor(std::move(children));
Visitor visitor;
visitor(std::move(children), &out);

DCHECK(!out.empty());
DCHECK(std::none_of(out.begin(), out.end(),
[](const FieldRef& ref) { return ref.IsNested(); }));

if (out.size() == 1) {
if (out.empty()) {
impl_ = std::vector<int>();
} else if (out.size() == 1) {
impl_ = std::move(out[0].impl_);
} else {
impl_ = std::move(out);
Expand All @@ -1199,7 +1236,7 @@ void FieldRef::Flatten(std::vector<FieldRef> children) {

Result<FieldRef> FieldRef::FromDotPath(const std::string& dot_path_arg) {
if (dot_path_arg.empty()) {
return Status::Invalid("Dot path was empty");
return FieldRef();
}

std::vector<FieldRef> children;
Expand Down Expand Up @@ -1449,6 +1486,11 @@ std::vector<FieldPath> FieldRef::FindAll(const RecordBatch& batch) const {

void PrintTo(const FieldRef& ref, std::ostream* os) { *os << ref.ToString(); }

std::ostream& operator<<(std::ostream& os, const FieldRef& ref) {
os << ref.ToString();
return os;
}

// ----------------------------------------------------------------------
// Schema implementation

Expand Down
3 changes: 3 additions & 0 deletions cpp/src/arrow/type.h
Expand Up @@ -1851,6 +1851,9 @@ class ARROW_EXPORT FieldRef : public util::EqualityComparable<FieldRef> {

ARROW_EXPORT void PrintTo(const FieldRef& ref, std::ostream* os);

ARROW_EXPORT
std::ostream& operator<<(std::ostream& os, const FieldRef&);

// ----------------------------------------------------------------------
// Schema

Expand Down

0 comments on commit b1110ae

Please sign in to comment.