diff --git a/cpp/src/arrow/dataset/dataset.cc b/cpp/src/arrow/dataset/dataset.cc index 6faaa953bb3fb..eb307681e916a 100644 --- a/cpp/src/arrow/dataset/dataset.cc +++ b/cpp/src/arrow/dataset/dataset.cc @@ -15,18 +15,19 @@ // specific language governing permissions and limitations // under the License. -#include "arrow/dataset/dataset.h" - #include #include +#include "arrow/dataset/dataset.h" #include "arrow/dataset/dataset_internal.h" #include "arrow/dataset/scanner.h" #include "arrow/table.h" +#include "arrow/util/async_generator.h" #include "arrow/util/bit_util.h" #include "arrow/util/iterator.h" #include "arrow/util/logging.h" #include "arrow/util/make_unique.h" +#include "arrow/util/thread_pool.h" namespace arrow { @@ -160,6 +161,33 @@ Result Dataset::GetFragments(compute::Expression predicate) { : MakeEmptyIterator>(); } +Result Dataset::GetFragmentsAsync() { + return GetFragmentsAsync(compute::literal(true)); +} + +Result Dataset::GetFragmentsAsync(compute::Expression predicate) { + ARROW_ASSIGN_OR_RAISE( + predicate, SimplifyWithGuarantee(std::move(predicate), partition_expression_)); + return predicate.IsSatisfiable() + ? GetFragmentsAsyncImpl(std::move(predicate), + arrow::internal::GetCpuThreadPool()) + : MakeEmptyGenerator>(); +} + +// Default impl delegating the work to `GetFragmentsImpl` and wrapping it into +// BackgroundGenerator/TransferredGenerator, which offloads potentially +// IO-intensive work to the default IO thread pool and then transfers the control +// back to the specified executor. +Result Dataset::GetFragmentsAsyncImpl( + compute::Expression predicate, arrow::internal::Executor* executor) { + ARROW_ASSIGN_OR_RAISE(auto iter, GetFragmentsImpl(std::move(predicate))); + ARROW_ASSIGN_OR_RAISE( + auto background_gen, + MakeBackgroundGenerator(std::move(iter), io::default_io_context().executor())); + auto transferred_gen = MakeTransferredGenerator(std::move(background_gen), executor); + return transferred_gen; +} + struct VectorRecordBatchGenerator : InMemoryDataset::RecordBatchGenerator { explicit VectorRecordBatchGenerator(RecordBatchVector batches) : batches_(std::move(batches)) {} diff --git a/cpp/src/arrow/dataset/dataset.h b/cpp/src/arrow/dataset/dataset.h index 62181b60ba423..3a5030b6be8f8 100644 --- a/cpp/src/arrow/dataset/dataset.h +++ b/cpp/src/arrow/dataset/dataset.h @@ -29,10 +29,16 @@ #include "arrow/compute/exec/expression.h" #include "arrow/dataset/type_fwd.h" #include "arrow/dataset/visibility.h" +#include "arrow/util/async_generator_fwd.h" #include "arrow/util/macros.h" #include "arrow/util/mutex.h" namespace arrow { + +namespace internal { +class Executor; +} // namespace internal + namespace dataset { using RecordBatchGenerator = std::function>()>; @@ -134,6 +140,8 @@ class ARROW_DS_EXPORT InMemoryFragment : public Fragment { /// @} +using FragmentGenerator = AsyncGenerator>; + /// \brief A container of zero or more Fragments. /// /// A Dataset acts as a union of Fragments, e.g. files deeply nested in a @@ -148,6 +156,10 @@ class ARROW_DS_EXPORT Dataset : public std::enable_shared_from_this { Result GetFragments(compute::Expression predicate); Result GetFragments(); + /// \brief Async versions of `GetFragments`. + Result GetFragmentsAsync(compute::Expression predicate); + Result GetFragmentsAsync(); + const std::shared_ptr& schema() const { return schema_; } /// \brief An expression which evaluates to true for all data viewed by this Dataset. @@ -174,6 +186,18 @@ class ARROW_DS_EXPORT Dataset : public std::enable_shared_from_this { Dataset(std::shared_ptr schema, compute::Expression partition_expression); virtual Result GetFragmentsImpl(compute::Expression predicate) = 0; + /// \brief Default non-virtual implementation method for the base + /// `GetFragmentsAsyncImpl` method, which creates a fragment generator for + /// the dataset, possibly filtering results with a predicate (forwarding to + /// the synchronous `GetFragmentsImpl` method and moving the computations + /// to the background, using the IO thread pool). + /// + /// Currently, `executor` is always the same as `internal::GetCPUThreadPool()`, + /// which means the results from the underlying fragment generator will be + /// transfered to the default CPU thread pool. The generator itself is + /// offloaded to run on the default IO thread pool. + virtual Result GetFragmentsAsyncImpl( + compute::Expression predicate, arrow::internal::Executor* executor); std::shared_ptr schema_; compute::Expression partition_expression_ = compute::literal(true); diff --git a/cpp/src/arrow/dataset/dataset_test.cc b/cpp/src/arrow/dataset/dataset_test.cc index cb155d7b962fd..5d199823474de 100644 --- a/cpp/src/arrow/dataset/dataset_test.cc +++ b/cpp/src/arrow/dataset/dataset_test.cc @@ -146,6 +146,34 @@ TEST_F(TestInMemoryDataset, HandlesDifferingSchemas) { scanner->ToTable()); } +TEST_F(TestInMemoryDataset, GetFragmentsSync) { + constexpr int64_t kBatchSize = 1024; + constexpr int64_t kNumberBatches = 16; + + SetSchema({field("i32", int32()), field("f64", float64())}); + auto batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_); + auto reader = ConstantArrayGenerator::Repeat(kNumberBatches, batch); + + auto dataset = std::make_shared( + schema_, RecordBatchVector{static_cast(kNumberBatches), batch}); + + AssertDatasetFragmentsEqual(reader.get(), dataset.get()); +} + +TEST_F(TestInMemoryDataset, GetFragmentsAsync) { + constexpr int64_t kBatchSize = 1024; + constexpr int64_t kNumberBatches = 16; + + SetSchema({field("i32", int32()), field("f64", float64())}); + auto batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_); + auto reader = ConstantArrayGenerator::Repeat(kNumberBatches, batch); + + auto dataset = std::make_shared( + schema_, RecordBatchVector{static_cast(kNumberBatches), batch}); + + AssertDatasetAsyncFragmentsEqual(reader.get(), dataset.get()); +} + class TestUnionDataset : public DatasetFixtureMixin {}; TEST_F(TestUnionDataset, ReplaceSchema) { diff --git a/cpp/src/arrow/dataset/test_util.h b/cpp/src/arrow/dataset/test_util.h index 05a9869389604..fb54dc3a91ab8 100644 --- a/cpp/src/arrow/dataset/test_util.h +++ b/cpp/src/arrow/dataset/test_util.h @@ -167,7 +167,7 @@ class DatasetFixtureMixin : public ::testing::Test { void AssertFragmentEquals(RecordBatchReader* expected, Fragment* fragment, bool ensure_drained = true) { ASSERT_OK_AND_ASSIGN(auto batch_gen, fragment->ScanBatchesAsync(options_)); - AssertScanTaskEquals(expected, batch_gen); + AssertScanTaskEquals(expected, batch_gen, ensure_drained); if (ensure_drained) { EnsureRecordBatchReaderDrained(expected); @@ -191,6 +191,22 @@ class DatasetFixtureMixin : public ::testing::Test { } } + void AssertDatasetAsyncFragmentsEqual(RecordBatchReader* expected, Dataset* dataset, + bool ensure_drained = true) { + ASSERT_OK_AND_ASSIGN(auto predicate, options_->filter.Bind(*dataset->schema())); + ASSERT_OK_AND_ASSIGN(auto gen, dataset->GetFragmentsAsync(predicate)) + + ASSERT_FINISHES_OK(VisitAsyncGenerator( + std::move(gen), [this, expected](const std::shared_ptr& f) { + AssertFragmentEquals(expected, f.get(), false /*ensure_drained*/); + return Status::OK(); + })); + + if (ensure_drained) { + EnsureRecordBatchReaderDrained(expected); + } + } + /// \brief Ensure that record batches found in reader are equals to the /// record batches yielded by a scanner. void AssertScannerEquals(RecordBatchReader* expected, Scanner* scanner, diff --git a/cpp/src/arrow/util/async_generator.h b/cpp/src/arrow/util/async_generator.h index d4a9c2829a77d..0d51208ac72e3 100644 --- a/cpp/src/arrow/util/async_generator.h +++ b/cpp/src/arrow/util/async_generator.h @@ -25,6 +25,7 @@ #include #include +#include "arrow/util/async_generator_fwd.h" #include "arrow/util/async_util.h" #include "arrow/util/functional.h" #include "arrow/util/future.h" @@ -66,9 +67,6 @@ namespace arrow { // until all outstanding futures have completed. Generators that spawn multiple // concurrent futures may need to hold onto an error while other concurrent futures wrap // up. -template -using AsyncGenerator = std::function()>; - template struct IterationTraits> { /// \brief by default when iterating through a sequence of AsyncGenerator, diff --git a/cpp/src/arrow/util/async_generator_fwd.h b/cpp/src/arrow/util/async_generator_fwd.h new file mode 100644 index 0000000000000..f3c5bf9ef6f52 --- /dev/null +++ b/cpp/src/arrow/util/async_generator_fwd.h @@ -0,0 +1,71 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "arrow/type_fwd.h" + +namespace arrow { + +template +using AsyncGenerator = std::function()>; + +template +class MappingGenerator; + +template +class SequencingGenerator; + +template +class TransformingGenerator; + +template +class SerialReadaheadGenerator; + +template +class ReadaheadGenerator; + +template +class PushGenerator; + +template +class MergedGenerator; + +template +struct Enumerated; + +template +class EnumeratingGenerator; + +template +class TransferringGenerator; + +template +class BackgroundGenerator; + +template +class GeneratorIterator; + +template +struct CancellableGenerator; + +template +class DefaultIfEmptyGenerator; + +} // namespace arrow