Skip to content

Commit

Permalink
ARROW-17242: [C++][FlightRPC] Propagate RecordBatchReader::Close erro…
Browse files Browse the repository at this point in the history
…rs through Flight (#13738)

Authored-by: David Li <li.davidm96@gmail.com>
Signed-off-by: Yibo Cai <yibo.cai@arm.com>
  • Loading branch information
lidavidm authored and kszucs committed Jul 29, 2022
1 parent 5116973 commit db2173c
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 1 deletion.
5 changes: 5 additions & 0 deletions cpp/src/arrow/flight/server.cc
Expand Up @@ -323,6 +323,8 @@ class RecordBatchStream::RecordBatchStreamImpl {
}
}

Status Close() { return reader_->Close(); }

private:
Status GetNextDictionary(FlightPayload* payload) {
const auto& it = dictionaries_[dictionary_index_++];
Expand All @@ -344,6 +346,7 @@ class RecordBatchStream::RecordBatchStreamImpl {
FlightMetadataWriter::~FlightMetadataWriter() = default;

FlightDataStream::~FlightDataStream() {}
Status FlightDataStream::Close() { return Status::OK(); }

RecordBatchStream::RecordBatchStream(const std::shared_ptr<RecordBatchReader>& reader,
const ipc::IpcWriteOptions& options) {
Expand All @@ -352,6 +355,8 @@ RecordBatchStream::RecordBatchStream(const std::shared_ptr<RecordBatchReader>& r

RecordBatchStream::~RecordBatchStream() {}

Status RecordBatchStream::Close() { return impl_->Close(); }

std::shared_ptr<Schema> RecordBatchStream::schema() { return impl_->schema(); }

arrow::Result<FlightPayload> RecordBatchStream::GetSchemaPayload() {
Expand Down
3 changes: 3 additions & 0 deletions cpp/src/arrow/flight/server.h
Expand Up @@ -64,6 +64,8 @@ class ARROW_FLIGHT_EXPORT FlightDataStream {

ARROW_DEPRECATED("Deprecated in 8.0.0. Use Result-returning overload instead.")
Status Next(FlightPayload* payload) { return Next().Value(payload); }

virtual Status Close();
};

/// \brief A basic implementation of FlightDataStream that will provide
Expand All @@ -85,6 +87,7 @@ class ARROW_FLIGHT_EXPORT RecordBatchStream : public FlightDataStream {
arrow::Result<FlightPayload> GetSchemaPayload() override;

arrow::Result<FlightPayload> Next() override;
Status Close() override;

private:
class RecordBatchStreamImpl;
Expand Down
15 changes: 15 additions & 0 deletions cpp/src/arrow/flight/test_definitions.cc
Expand Up @@ -237,6 +237,21 @@ void DataTest::TestDoGetLargeBatch() {
Ticket ticket{"ticket-large-batch-1"};
CheckDoGet(ticket, expected_batches);
}
// Ensure FlightDataStream/RecordBatchStream::Close errors are propagated
void DataTest::TestFlightDataStreamError() {
Ticket ticket{"ticket-stream-error"};

ASSERT_OK_AND_ASSIGN(auto stream, client_->DoGet(ticket));
Status status;
while (true) {
FlightStreamChunk chunk;
status = stream->Next().Value(&chunk);
if (!chunk.data) break;
if (!status.ok()) break;
}
EXPECT_RAISES_WITH_MESSAGE_THAT(IOError, ::testing::HasSubstr("Expected error"),
status);
}
void DataTest::TestOverflowServerBatch() {
// Regression test for ARROW-13253
// N.B. this is rather a slow and memory-hungry test
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/arrow/flight/test_definitions.h
Expand Up @@ -76,6 +76,7 @@ class ARROW_FLIGHT_EXPORT DataTest : public FlightTest {
void TestDoGetFloats();
void TestDoGetDicts();
void TestDoGetLargeBatch();
void TestFlightDataStreamError();
void TestOverflowServerBatch();
void TestOverflowClientBatch();
void TestDoExchange();
Expand Down Expand Up @@ -107,6 +108,7 @@ class ARROW_FLIGHT_EXPORT DataTest : public FlightTest {
TEST_F(FIXTURE, TestDoGetFloats) { TestDoGetFloats(); } \
TEST_F(FIXTURE, TestDoGetDicts) { TestDoGetDicts(); } \
TEST_F(FIXTURE, TestDoGetLargeBatch) { TestDoGetLargeBatch(); } \
TEST_F(FIXTURE, TestFlightDataStreamError) { TestFlightDataStreamError(); } \
TEST_F(FIXTURE, TestOverflowServerBatch) { TestOverflowServerBatch(); } \
TEST_F(FIXTURE, TestOverflowClientBatch) { TestOverflowClientBatch(); } \
TEST_F(FIXTURE, TestDoExchange) { TestDoExchange(); } \
Expand Down
25 changes: 25 additions & 0 deletions cpp/src/arrow/flight/test_util.cc
Expand Up @@ -89,6 +89,25 @@ Status ResolveCurrentExecutable(fs::path* out) {
}
}

class ErrorRecordBatchReader : public RecordBatchReader {
public:
ErrorRecordBatchReader() : schema_(arrow::schema({})) {}

std::shared_ptr<Schema> schema() const override { return schema_; }

Status ReadNext(std::shared_ptr<RecordBatch>* out) override {
*out = nullptr;
return Status::OK();
}

Status Close() override {
// This should be propagated over DoGet to the client
return Status::IOError("Expected error");
}

private:
std::shared_ptr<Schema> schema_;
};
} // namespace

void TestServer::Start(const std::vector<std::string>& extra_args) {
Expand Down Expand Up @@ -225,6 +244,12 @@ class FlightTestServer : public FlightServerBase {
std::unique_ptr<FlightDataStream>(new RecordBatchStream(std::move(reader)));
return Status::OK();
}
if (request.ticket == "ticket-stream-error") {
auto reader = std::make_shared<ErrorRecordBatchReader>();
*data_stream =
std::unique_ptr<FlightDataStream>(new RecordBatchStream(std::move(reader)));
return Status::OK();
}

std::shared_ptr<RecordBatchReader> batch_reader;
RETURN_NOT_OK(GetBatchForFlight(request, &batch_reader));
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/arrow/flight/transport_server.cc
Expand Up @@ -293,7 +293,7 @@ Status ServerTransport::DoGet(const ServerCallContext& context, const Ticket& ti
if (!success) return Status::OK();
}
RETURN_NOT_OK(stream->WritesDone());
return Status::OK();
return data_stream->Close();
}

Status ServerTransport::DoPut(const ServerCallContext& context,
Expand Down

0 comments on commit db2173c

Please sign in to comment.