From b1110ae377c66bc3b666f9c287afdf4907bb1952 Mon Sep 17 00:00:00 2001 From: Miles Granger Date: Tue, 22 Nov 2022 11:57:08 +0100 Subject: [PATCH] ARROW-17989: [C++][Python] Enable struct_field kernel to accept string field names (#14495) Will close [ARROW-17989](https://issues.apache.org/jira/browse/ARROW-17989) Allows using names in `pc.struct_field` ```python In [1]: arr = pa.array([{'a': {'b': 1}, 'c': 2}]) In [2]: pc.struct_field(arr, 'c') Out[2]: [ 2 ] In [3]: pc.struct_field(arr, '.a.b') Out[3]: [ 1 ] # And indices as before... In [4]: pc.struct_field(arr, [0, 0]) Out[4]: [ 1 ] In [5]: ``` Lead-authored-by: Miles Granger Co-authored-by: Antoine Pitrou Co-authored-by: Joris Van den Bossche Signed-off-by: Joris Van den Bossche --- cpp/src/arrow/compute/api_scalar.cc | 11 ++- cpp/src/arrow/compute/api_scalar.h | 7 +- .../arrow/compute/kernels/scalar_nested.cc | 27 +++++-- .../compute/kernels/scalar_nested_test.cc | 46 +++++++++-- .../engine/substrait/expression_internal.cc | 18 +++-- cpp/src/arrow/type.cc | 78 ++++++++++++++----- cpp/src/arrow/type.h | 3 + cpp/src/arrow/type_test.cc | 52 ++++++++++++- python/pyarrow/_compute.pyx | 39 +++++++++- python/pyarrow/includes/libarrow.pxd | 6 ++ python/pyarrow/tests/test_compute.py | 32 ++++++-- 11 files changed, 267 insertions(+), 52 deletions(-) diff --git a/cpp/src/arrow/compute/api_scalar.cc b/cpp/src/arrow/compute/api_scalar.cc index 5de6eade5b859..425274043ed03 100644 --- a/cpp/src/arrow/compute/api_scalar.cc +++ b/cpp/src/arrow/compute/api_scalar.cc @@ -365,7 +365,7 @@ static auto kStrptimeOptionsType = GetFunctionOptionsType( DataMember("unit", &StrptimeOptions::unit), DataMember("error_is_null", &StrptimeOptions::error_is_null)); static auto kStructFieldOptionsType = GetFunctionOptionsType( - DataMember("indices", &StructFieldOptions::indices)); + DataMember("field_ref", &StructFieldOptions::field_ref)); static auto kTrimOptionsType = GetFunctionOptionsType( DataMember("characters", &TrimOptions::characters)); 
static auto kUtf8NormalizeOptionsType = GetFunctionOptionsType( @@ -578,8 +578,13 @@ StrptimeOptions::StrptimeOptions() : StrptimeOptions("", TimeUnit::MICRO, false) constexpr char StrptimeOptions::kTypeName[]; StructFieldOptions::StructFieldOptions(std::vector indices) - : FunctionOptions(internal::kStructFieldOptionsType), indices(std::move(indices)) {} -StructFieldOptions::StructFieldOptions() : StructFieldOptions(std::vector()) {} + : FunctionOptions(internal::kStructFieldOptionsType), field_ref(std::move(indices)) {} +StructFieldOptions::StructFieldOptions(std::initializer_list indices) + : FunctionOptions(internal::kStructFieldOptionsType), field_ref(std::move(indices)) {} +StructFieldOptions::StructFieldOptions(FieldRef ref) + : FunctionOptions(internal::kStructFieldOptionsType), field_ref(std::move(ref)) {} +StructFieldOptions::StructFieldOptions() + : FunctionOptions(internal::kStructFieldOptionsType) {} constexpr char StructFieldOptions::kTypeName[]; TrimOptions::TrimOptions(std::string characters) diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index f15d9c667f6e3..1c27757fcfc51 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -278,12 +278,13 @@ class ARROW_EXPORT SetLookupOptions : public FunctionOptions { class ARROW_EXPORT StructFieldOptions : public FunctionOptions { public: explicit StructFieldOptions(std::vector indices); + explicit StructFieldOptions(std::initializer_list); + explicit StructFieldOptions(FieldRef field_ref); StructFieldOptions(); static constexpr char const kTypeName[] = "StructFieldOptions"; - /// The child indices to extract. For instance, to get the 2nd child - /// of the 1st child of a struct or union, this would be {0, 1}. - std::vector indices; + /// The FieldRef specifying what to extract from struct or union. 
+ FieldRef field_ref; }; class ARROW_EXPORT StrptimeOptions : public FunctionOptions { diff --git a/cpp/src/arrow/compute/kernels/scalar_nested.cc b/cpp/src/arrow/compute/kernels/scalar_nested.cc index 5af6b78182cd3..fb1cd9220b14a 100644 --- a/cpp/src/arrow/compute/kernels/scalar_nested.cc +++ b/cpp/src/arrow/compute/kernels/scalar_nested.cc @@ -388,9 +388,17 @@ const FunctionDoc list_element_doc( struct StructFieldFunctor { static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { const auto& options = OptionsWrapper::Get(ctx); - std::shared_ptr current = MakeArray(batch[0].array.ToArrayData()); - for (const auto& index : options.indices) { + + FieldPath field_path; + if (options.field_ref.IsNested() || options.field_ref.IsName()) { + ARROW_ASSIGN_OR_RAISE(field_path, options.field_ref.FindOne(*current->type())); + } else { + DCHECK(options.field_ref.IsFieldPath()); + field_path = *options.field_ref.field_path(); + } + + for (const auto& index : field_path.indices()) { RETURN_NOT_OK(CheckIndex(index, *current->type())); switch (current->type()->id()) { case Type::STRUCT: { @@ -421,7 +429,8 @@ struct StructFieldFunctor { ArrayData(int32(), union_array.length(), {std::move(take_bitmap), union_array.value_offsets()}, kUnknownNullCount, union_array.offset())); - // Do not slice the child since the indices are relative to the unsliced array. + // Do not slice the child since the indices are relative to the unsliced + // array. 
ARROW_ASSIGN_OR_RAISE( Datum result, CallFunction("take", {union_array.field(index), std::move(take_indices)})); @@ -463,9 +472,17 @@ struct StructFieldFunctor { Result ResolveStructFieldType(KernelContext* ctx, const std::vector& types) { - const auto& options = OptionsWrapper::Get(ctx); + const auto& field_ref = OptionsWrapper::Get(ctx).field_ref; const DataType* type = types.front().type; - for (const auto& index : options.indices) { + + FieldPath field_path; + if (field_ref.IsNested() || field_ref.IsName()) { + ARROW_ASSIGN_OR_RAISE(field_path, field_ref.FindOne(*type)); + } else { + field_path = *field_ref.field_path(); + } + + for (const auto& index : field_path.indices()) { RETURN_NOT_OK(StructFieldFunctor::CheckIndex(index, *type)); type = type->field(index)->type().get(); } diff --git a/cpp/src/arrow/compute/kernels/scalar_nested_test.cc b/cpp/src/arrow/compute/kernels/scalar_nested_test.cc index ec1e7ceeae480..744f18890809a 100644 --- a/cpp/src/arrow/compute/kernels/scalar_nested_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_nested_test.cc @@ -261,6 +261,13 @@ TEST(TestScalarNested, StructField) { StructFieldOptions invalid2({2, 4}); StructFieldOptions invalid3({3}); StructFieldOptions invalid4({0, 1}); + + // Test using FieldRefs + StructFieldOptions extract0_field_ref_path(FieldRef(FieldPath({0}))); + StructFieldOptions extract0_field_ref_name(FieldRef("a")); + ASSERT_OK_AND_ASSIGN(auto field_ref, FieldRef::FromDotPath(".c.d")); + StructFieldOptions extract20_field_ref_nest(field_ref); + FieldVector fields = {field("a", int32()), field("b", utf8()), field("c", struct_({ field("d", int64()), @@ -278,16 +285,25 @@ TEST(TestScalarNested, StructField) { &extract0); CheckScalar("struct_field", {arr}, ArrayFromJSON(int64(), "[10, 11, 12, null]"), &extract20); + + CheckScalar("struct_field", {arr}, ArrayFromJSON(int32(), "[1, null, 3, null]"), + &extract0_field_ref_path); + CheckScalar("struct_field", {arr}, ArrayFromJSON(int32(), "[1, null, 3, null]"), + 
&extract0_field_ref_name); + CheckScalar("struct_field", {arr}, ArrayFromJSON(int64(), "[10, 11, 12, null]"), + &extract20_field_ref_nest); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("out-of-bounds field reference"), CallFunction("struct_field", {arr}, &invalid1)); EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, - ::testing::HasSubstr("out-of-bounds field reference"), + ::testing::HasSubstr("No match for FieldRef"), CallFunction("struct_field", {arr}, &invalid2)); EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("out-of-bounds field reference"), CallFunction("struct_field", {arr}, &invalid3)); - EXPECT_RAISES_WITH_MESSAGE_THAT(TypeError, ::testing::HasSubstr("cannot subscript"), + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, + ::testing::HasSubstr("No match for FieldRef"), CallFunction("struct_field", {arr}, &invalid4)); } { @@ -303,16 +319,25 @@ TEST(TestScalarNested, StructField) { &extract0); CheckScalar("struct_field", {arr}, ArrayFromJSON(int64(), "[null, null, null, 10]"), &extract20); + + CheckScalar("struct_field", {arr}, ArrayFromJSON(int32(), "[1, null, null, null]"), + &extract0_field_ref_path); + CheckScalar("struct_field", {arr}, ArrayFromJSON(int32(), "[1, null, null, null]"), + &extract0_field_ref_name); + CheckScalar("struct_field", {arr}, ArrayFromJSON(int64(), "[null, null, null, 10]"), + &extract20_field_ref_nest); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("out-of-bounds field reference"), CallFunction("struct_field", {arr}, &invalid1)); EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, - ::testing::HasSubstr("out-of-bounds field reference"), + ::testing::HasSubstr("No match for FieldRef"), CallFunction("struct_field", {arr}, &invalid2)); EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("out-of-bounds field reference"), CallFunction("struct_field", {arr}, &invalid3)); - EXPECT_RAISES_WITH_MESSAGE_THAT(TypeError, ::testing::HasSubstr("cannot subscript"), + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, + 
::testing::HasSubstr("No match for FieldRef"), CallFunction("struct_field", {arr}, &invalid4)); // Test edge cases for union representation @@ -352,16 +377,25 @@ TEST(TestScalarNested, StructField) { &extract0); CheckScalar("struct_field", {arr}, ArrayFromJSON(int64(), "[null, null, null, 10]"), &extract20); + + CheckScalar("struct_field", {arr}, ArrayFromJSON(int32(), "[1, null, null, null]"), + &extract0_field_ref_path); + CheckScalar("struct_field", {arr}, ArrayFromJSON(int32(), "[1, null, null, null]"), + &extract0_field_ref_name); + CheckScalar("struct_field", {arr}, ArrayFromJSON(int64(), "[null, null, null, 10]"), + &extract20_field_ref_nest); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("out-of-bounds field reference"), CallFunction("struct_field", {arr}, &invalid1)); EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, - ::testing::HasSubstr("out-of-bounds field reference"), + ::testing::HasSubstr("No match for FieldRef"), CallFunction("struct_field", {arr}, &invalid2)); EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("out-of-bounds field reference"), CallFunction("struct_field", {arr}, &invalid3)); - EXPECT_RAISES_WITH_MESSAGE_THAT(TypeError, ::testing::HasSubstr("cannot subscript"), + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, + ::testing::HasSubstr("No match for FieldRef"), CallFunction("struct_field", {arr}, &invalid4)); } { diff --git a/cpp/src/arrow/engine/substrait/expression_internal.cc b/cpp/src/arrow/engine/substrait/expression_internal.cc index 7495d1a34e1fa..b988bf195a23b 100644 --- a/cpp/src/arrow/engine/substrait/expression_internal.cc +++ b/cpp/src/arrow/engine/substrait/expression_internal.cc @@ -170,9 +170,10 @@ Result FromProto(const substrait::Expression& expr, out = compute::field_ref(FieldRef(*out_ref, index)); } else if (out->call() && out->call()->function_name == "struct_field") { // Nested StructFields on top of an arbitrary expression - std::static_pointer_cast( - out->call()->options) - 
->indices.push_back(index); + auto* field_options = + checked_cast(out->call()->options.get()); + field_options->field_ref = + FieldRef(std::move(field_options->field_ref), index); } else { // First StructField on top of an arbitrary expression out = compute::call("struct_field", {std::move(*out)}, @@ -1019,13 +1020,16 @@ Result> ToProto( if (call->function_name == "struct_field") { // catch the special case of calls convertible to a StructField + const auto& field_options = + checked_cast(*call->options); + const DataType& struct_type = *call->arguments[0].type(); + DCHECK_EQ(struct_type.id(), Type::STRUCT); + + ARROW_ASSIGN_OR_RAISE(auto field_path, field_options.field_ref.FindOne(struct_type)); out = std::move(arguments[0]); - for (int index : - checked_cast(*call->options) - .indices) { + for (int index : field_path.indices()) { ARROW_ASSIGN_OR_RAISE(out, MakeStructFieldReference(std::move(out), index)); } - return std::move(out); } diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index 6b63a1f8b7649..4247ac2360cac 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -1161,36 +1162,72 @@ Result> FieldPath::Get(const ArrayData& data) const { return FieldPathGetImpl::Get(this, data.child_data); } -FieldRef::FieldRef(FieldPath indices) : impl_(std::move(indices)) { - DCHECK_GT(std::get(impl_).indices().size(), 0); -} +FieldRef::FieldRef(FieldPath indices) : impl_(std::move(indices)) {} void FieldRef::Flatten(std::vector children) { + ARROW_CHECK(!children.empty()); + // flatten children struct Visitor { - void operator()(std::string&& name) { out->push_back(FieldRef(std::move(name))); } + void operator()(std::string&& name, std::vector* out) { + out->push_back(FieldRef(std::move(name))); + } - void operator()(FieldPath&& indices) { out->push_back(FieldRef(std::move(indices))); } + void operator()(FieldPath&& path, std::vector* out) { + if 
(path.indices().empty()) { + return; + } + out->push_back(FieldRef(std::move(path))); + } - void operator()(std::vector&& children) { - out->reserve(out->size() + children.size()); + void operator()(std::vector&& children, std::vector* out) { + if (children.empty()) { + return; + } + // First flatten children into temporary result + std::vector flattened_children; + flattened_children.reserve(children.size()); for (auto&& child : children) { - std::visit(*this, std::move(child.impl_)); + std::visit(std::bind(*this, std::placeholders::_1, &flattened_children), + std::move(child.impl_)); + } + // If all children are FieldPaths, concatenate them into a single FieldPath + int64_t n_indices = 0; + for (const auto& child : flattened_children) { + const FieldPath* path = child.field_path(); + if (!path) { + n_indices = -1; + break; + } + n_indices += static_cast(path->indices().size()); + } + if (n_indices == 0) { + return; + } else if (n_indices > 0) { + std::vector indices(n_indices); + auto out_indices = indices.begin(); + for (const auto& child : flattened_children) { + for (int index : *child.field_path()) { + *out_indices++ = index; + } + } + DCHECK_EQ(out_indices, indices.end()); + out->push_back(FieldRef(std::move(indices))); + } else { + // ... 
otherwise, just transfer them to the final result + out->insert(out->end(), std::move_iterator(flattened_children.begin()), + std::move_iterator(flattened_children.end())); } } - - std::vector* out; }; std::vector out; - Visitor visitor{&out}; - visitor(std::move(children)); + Visitor visitor; + visitor(std::move(children), &out); - DCHECK(!out.empty()); - DCHECK(std::none_of(out.begin(), out.end(), - [](const FieldRef& ref) { return ref.IsNested(); })); - - if (out.size() == 1) { + if (out.empty()) { + impl_ = std::vector(); + } else if (out.size() == 1) { impl_ = std::move(out[0].impl_); } else { impl_ = std::move(out); @@ -1199,7 +1236,7 @@ void FieldRef::Flatten(std::vector children) { Result FieldRef::FromDotPath(const std::string& dot_path_arg) { if (dot_path_arg.empty()) { - return Status::Invalid("Dot path was empty"); + return FieldRef(); } std::vector children; @@ -1449,6 +1486,11 @@ std::vector FieldRef::FindAll(const RecordBatch& batch) const { void PrintTo(const FieldRef& ref, std::ostream* os) { *os << ref.ToString(); } +std::ostream& operator<<(std::ostream& os, const FieldRef& ref) { + os << ref.ToString(); + return os; +} + // ---------------------------------------------------------------------- // Schema implementation diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index 3bb92bf26f230..415aaacf1c9ef 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -1851,6 +1851,9 @@ class ARROW_EXPORT FieldRef : public util::EqualityComparable { ARROW_EXPORT void PrintTo(const FieldRef& ref, std::ostream* os); +ARROW_EXPORT +std::ostream& operator<<(std::ostream& os, const FieldRef&); + // ---------------------------------------------------------------------- // Schema diff --git a/cpp/src/arrow/type_test.cc b/cpp/src/arrow/type_test.cc index c6ce1887de4b4..954ad63c8aa68 100644 --- a/cpp/src/arrow/type_test.cc +++ b/cpp/src/arrow/type_test.cc @@ -414,12 +414,26 @@ TEST(TestFieldRef, FromDotPath) { 
ASSERT_OK_AND_EQ(FieldRef(R"([y]\tho.\)"), FieldRef::FromDotPath(R"(.\[y\]\\tho\.\)")); - ASSERT_RAISES(Invalid, FieldRef::FromDotPath(R"()")); + ASSERT_OK_AND_EQ(FieldRef(), FieldRef::FromDotPath(R"()")); + ASSERT_RAISES(Invalid, FieldRef::FromDotPath(R"(alpha)")); ASSERT_RAISES(Invalid, FieldRef::FromDotPath(R"([134234)")); ASSERT_RAISES(Invalid, FieldRef::FromDotPath(R"([1stuf])")); } +TEST(TestFieldRef, DotPathRoundTrip) { + auto check_roundtrip = [](const FieldRef& ref) { + auto dot_path = ref.ToDotPath(); + ASSERT_OK_AND_EQ(ref, FieldRef::FromDotPath(dot_path)); + }; + + check_roundtrip(FieldRef()); + check_roundtrip(FieldRef("foo")); + check_roundtrip(FieldRef("foo", 1, "bar", 2, 3)); + check_roundtrip(FieldRef(1, 2, 3)); + check_roundtrip(FieldRef("foo", 1, FieldRef("bar", 2, 3), FieldRef())); +} + TEST(TestFieldPath, Nested) { auto f0 = field("alpha", int32()); auto f1_0 = field("alpha", int32()); @@ -456,6 +470,42 @@ TEST(TestFieldRef, Nested) { ElementsAre(FieldPath{2, 1, 0}, FieldPath{2, 1, 1})); } +TEST(TestFieldRef, Flatten) { + FieldRef ref; + + auto assert_name = [](const FieldRef& ref, const std::string& expected) { + ASSERT_TRUE(ref.IsName()); + ASSERT_EQ(*ref.name(), expected); + }; + + auto assert_path = [](const FieldRef& ref, const std::vector& expected) { + ASSERT_TRUE(ref.IsFieldPath()); + ASSERT_EQ(ref.field_path()->indices(), expected); + }; + + auto assert_nested = [](const FieldRef& ref, const std::vector& expected) { + ASSERT_TRUE(ref.IsNested()); + ASSERT_EQ(*ref.nested_refs(), expected); + }; + + assert_path(FieldRef(), {}); + assert_path(FieldRef(1, 2, 3), {1, 2, 3}); + // If all leaves are field paths, they are fully flattened + assert_path(FieldRef(1, FieldRef(2, 3)), {1, 2, 3}); + assert_path(FieldRef(1, FieldRef(2, 3), FieldRef(), FieldRef(FieldRef(4), FieldRef(5))), + {1, 2, 3, 4, 5}); + assert_path(FieldRef(FieldRef(), FieldRef(FieldRef(), FieldRef())), {}); + + assert_name(FieldRef("foo"), "foo"); + + // Nested empty field 
refs are optimized away + assert_nested(FieldRef("foo", 1, FieldRef(), FieldRef(FieldRef(), "bar")), + {FieldRef("foo"), FieldRef(1), FieldRef("bar")}); + // For now, subsequences of indices are not concatenated + assert_nested(FieldRef("foo", FieldRef("bar"), FieldRef(1, 2), FieldRef(3)), + {FieldRef("foo"), FieldRef("bar"), FieldRef(1, 2), FieldRef(3)}); +} + using TestSchema = ::testing::Test; TEST_F(TestSchema, Basics) { diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 659af0afba37c..c75c5bf189ba7 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -1361,7 +1361,37 @@ class MakeStructOptions(_MakeStructOptions): cdef class _StructFieldOptions(FunctionOptions): def _set_options(self, indices): - self.wrapped.reset(new CStructFieldOptions(indices)) + cdef: + CFieldRef field_ref + const CFieldRef* field_ref_ptr + + if isinstance(indices, (list, tuple)): + if len(indices): + indices = Expression._nested_field(tuple(indices)) + else: + # Allow empty indices; effectively return same array + self.wrapped.reset( + new CStructFieldOptions(indices)) + return + + if isinstance(indices, Expression): + field_ref_ptr = (indices).unwrap().field_ref() + if field_ref_ptr is NULL: + raise ValueError("Unable to get CFieldRef from Expression") + field_ref = deref(field_ref_ptr) + elif isinstance(indices, (bytes, str)): + if indices.startswith(b'.' if isinstance(indices, bytes) else '.'): + field_ref = GetResultValue( + CFieldRef.FromDotPath(tobytes(indices))) + else: + field_ref = CFieldRef(tobytes(indices)) + elif isinstance(indices, int): + field_ref = CFieldRef( indices) + else: + raise TypeError("Expected List[str], List[int], List[bytes], " + "Expression, bytes, str, or int.
" + f"Got: {type(indices)}") + self.wrapped.reset(new CStructFieldOptions(field_ref)) class StructFieldOptions(_StructFieldOptions): @@ -1370,7 +1400,7 @@ class StructFieldOptions(_StructFieldOptions): Parameters ---------- - indices : sequence of int + indices : List[str], List[bytes], List[int], Expression, bytes, str, or int List of indices for chained field lookup, for example `[4, 1]` will look up the second nested field in the fifth outer field. """ @@ -2442,7 +2472,10 @@ cdef class Expression(_Weakrefable): raise ValueError("nested field reference should be non-empty") nested.reserve(len(names)) for name in names: - nested.push_back(CFieldRef( tobytes(name))) + if isinstance(name, int): + nested.push_back(CFieldRef(name)) + else: + nested.push_back(CFieldRef( tobytes(name))) return Expression.wrap(CMakeFieldExpression(CFieldRef(move(nested)))) @staticmethod diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index bc82a420897d7..9cea340a3090e 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -434,6 +434,9 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: CFieldRef(c_string name) CFieldRef(int index) CFieldRef(vector[CFieldRef]) + + @staticmethod + CResult[CFieldRef] FromDotPath(c_string& dot_path) const c_string* name() const cdef cppclass CFieldRefHash" arrow::FieldRef::Hash": @@ -2291,7 +2294,9 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: cdef cppclass CStructFieldOptions \ "arrow::compute::StructFieldOptions"(CFunctionOptions): CStructFieldOptions(vector[int] indices) + CStructFieldOptions(CFieldRef field_ref) vector[int] indices + CFieldRef field_ref ctypedef enum CSortOrder" arrow::compute::SortOrder": CSortOrder_Ascending \ @@ -2496,6 +2501,7 @@ cdef extern from "arrow/compute/exec/expression.h" \ c_bool Equals(const CExpression& other) const c_string ToString() const CResult[CExpression] Bind(const CSchema&) + const CFieldRef* 
field_ref() const cdef CExpression CMakeScalarExpression \ "arrow::compute::literal"(shared_ptr[CScalar] value) diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 3d03c7d86a0c9..68b3303fe782f 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -2690,14 +2690,32 @@ def test_struct_fields_options(): c = pa.StructArray.from_arrays([a, b], ["a", "b"]) arr = pa.StructArray.from_arrays([a, c], ["a", "c"]) - assert pc.struct_field(arr, - indices=[1, 1]) == pa.array(["bar", None, ""]) - assert pc.struct_field(arr, [1, 1]) == pa.array(["bar", None, ""]) - assert pc.struct_field(arr, [0]) == pa.array([4, 5, 6], type=pa.int64()) + assert pc.struct_field(arr, '.c.b') == b + assert pc.struct_field(arr, b'.c.b') == b + assert pc.struct_field(arr, ['c', 'b']) == b + assert pc.struct_field(arr, [1, 'b']) == b + assert pc.struct_field(arr, (b'c', 'b')) == b + assert pc.struct_field(arr, pc.field(('c', 'b'))) == b + + assert pc.struct_field(arr, '.a') == a + assert pc.struct_field(arr, ['a']) == a + assert pc.struct_field(arr, 'a') == a + assert pc.struct_field(arr, pc.field(('a',))) == a + + assert pc.struct_field(arr, indices=[1, 1]) == b + assert pc.struct_field(arr, (1, 1)) == b + assert pc.struct_field(arr, [0]) == a assert pc.struct_field(arr, []) == arr - with pytest.raises(TypeError, match="an integer is required"): - pc.struct_field(arr, indices=['a']) + with pytest.raises(pa.ArrowInvalid, match="No match for FieldRef"): + pc.struct_field(arr, 'foo') + + with pytest.raises(pa.ArrowInvalid, match="No match for FieldRef"): + pc.struct_field(arr, '.c.foo') + + # drill into a non-struct array and continue to ask for a field + with pytest.raises(pa.ArrowInvalid, match="No match for FieldRef"): + pc.struct_field(arr, '.a.foo') # TODO: https://issues.apache.org/jira/browse/ARROW-14853 # assert pc.struct_field(arr) == arr @@ -2863,6 +2881,7 @@ def test_expression_construction(): false = 
pc.scalar(False) string = pc.scalar("string") field = pc.field("field") + nested_mixed_types = pc.field(b"a", 1, "b") nested_field = pc.field(("nested", "field")) nested_field2 = pc.field("nested", "field") @@ -2872,6 +2891,7 @@ def test_expression_construction(): field.cast(typ) == true field.isin([1, 2]) + nested_mixed_types.isin(["foo", "bar"]) nested_field.isin(["foo", "bar"]) nested_field2.isin(["foo", "bar"])