Skip to content

Commit

Permalink
ARROW-7966: [FlightRPC][C++] Validate individual batches in integration
Browse files Browse the repository at this point in the history
Closes #6662 from lidavidm/arrow-7966

Authored-by: David Li <li.davidm96@gmail.com>
Signed-off-by: Antoine Pitrou <antoine@python.org>
  • Loading branch information
lidavidm authored and pitrou committed Mar 19, 2020
1 parent 70b0921 commit f7d3923
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 70 deletions.
118 changes: 58 additions & 60 deletions cpp/src/arrow/flight/test_integration_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,52 +44,26 @@ DEFINE_string(host, "localhost", "Server port to connect to");
DEFINE_int32(port, 31337, "Server port to connect to");
DEFINE_string(path, "", "Resource path to request");

/// \brief Helper to read a MetadataRecordBatchReader into a Table.
arrow::Status ReadToTable(arrow::flight::MetadataRecordBatchReader& reader,
std::shared_ptr<arrow::Table>* retrieved_data) {
// For integration testing, we expect the server numbers the
// batches, to test the application metadata part of the spec.
std::vector<std::shared_ptr<arrow::RecordBatch>> retrieved_chunks;
arrow::flight::FlightStreamChunk chunk;
int counter = 0;
while (true) {
RETURN_NOT_OK(reader.Next(&chunk));
if (!chunk.data) break;
retrieved_chunks.push_back(chunk.data);
if (std::to_string(counter) != chunk.app_metadata->ToString()) {
return arrow::Status::Invalid(
"Expected metadata value: " + std::to_string(counter) +
" but got: " + chunk.app_metadata->ToString());
}
counter++;
}
return arrow::Table::FromRecordBatches(reader.schema(), retrieved_chunks,
retrieved_data);
}

/// \brief Read every record batch from a JsonReader, validating each one.
///
/// \param[in] reader the JSON test-data reader to drain
/// \param[out] chunks receives the validated batches, in file order
/// \return Status::Invalid if any batch fails full validation
arrow::Status ReadBatches(std::unique_ptr<arrow::ipc::internal::json::JsonReader>& reader,
                          std::vector<std::shared_ptr<arrow::RecordBatch>>* chunks) {
  for (int i = 0; i < reader->num_record_batches(); i++) {
    std::shared_ptr<arrow::RecordBatch> chunk;
    RETURN_NOT_OK(reader->ReadRecordBatch(i, &chunk));
    // Validate eagerly so a malformed batch is caught here, before it is
    // uploaded and compared against the server round-trip.
    RETURN_NOT_OK(chunk->ValidateFull());
    chunks->push_back(chunk);
  }
  return arrow::Status::OK();
}

/// \brief Upload the contents of a RecordBatchReader to a Flight
/// server, validating the application metadata on the side.
arrow::Status UploadReaderToFlight(arrow::RecordBatchReader* reader,
arrow::flight::FlightStreamWriter& writer,
arrow::flight::FlightMetadataReader& metadata_reader) {
/// \brief Upload a list of batches to a Flight server, validating
/// the application metadata on the side.
arrow::Status UploadBatchesToFlight(
const std::vector<std::shared_ptr<arrow::RecordBatch>>& chunks,
arrow::flight::FlightStreamWriter& writer,
arrow::flight::FlightMetadataReader& metadata_reader) {
int counter = 0;
while (true) {
std::shared_ptr<arrow::RecordBatch> chunk;
RETURN_NOT_OK(reader->ReadNext(&chunk));
if (chunk == nullptr) break;
for (const auto& chunk : chunks) {
std::shared_ptr<arrow::Buffer> metadata =
arrow::Buffer::FromString(std::to_string(counter));
RETURN_NOT_OK(writer.WriteWithMetadata(*chunk, metadata));
Expand All @@ -105,18 +79,50 @@ arrow::Status UploadReaderToFlight(arrow::RecordBatchReader* reader,
return writer.Close();
}

/// \brief Stream the given Flight from the server and compare it, batch by
/// batch, against the expected batches.
///
/// Verifies three things per batch: it equals the corresponding expected
/// batch, it passes full validation, and its app_metadata carries the batch
/// index as a decimal string (the integration-test numbering contract).
/// Also fails if the stream yields fewer or more batches than expected.
///
/// \param[in] location the Flight server location to connect to
/// \param[in] ticket the ticket identifying the stream to fetch
/// \param[in] retrieved_data the expected batches, in order
arrow::Status ConsumeFlightLocation(
    const arrow::flight::Location& location, const arrow::flight::Ticket& ticket,
    const std::vector<std::shared_ptr<arrow::RecordBatch>>& retrieved_data) {
  std::unique_ptr<arrow::flight::FlightClient> read_client;
  RETURN_NOT_OK(arrow::flight::FlightClient::Connect(location, &read_client));

  std::unique_ptr<arrow::flight::FlightStreamReader> stream;
  RETURN_NOT_OK(read_client->DoGet(ticket, &stream));

  int counter = 0;
  const int expected = static_cast<int>(retrieved_data.size());
  for (const auto& original_batch : retrieved_data) {
    arrow::flight::FlightStreamChunk chunk;
    RETURN_NOT_OK(stream->Next(&chunk));
    if (chunk.data == nullptr) {
      return arrow::Status::Invalid("Got fewer batches than expected, received so far: ",
                                    counter, " expected ", expected);
    }

    if (!original_batch->Equals(*chunk.data)) {
      return arrow::Status::Invalid("Batch ", counter, " does not match");
    }

    const auto st = chunk.data->ValidateFull();
    if (!st.ok()) {
      return arrow::Status::Invalid("Batch ", counter, " is not valid: ", st.ToString());
    }

    // Guard before dereferencing: the server is required to attach metadata
    // to every batch, but a missing buffer must fail cleanly, not crash.
    if (!chunk.app_metadata) {
      return arrow::Status::Invalid("Batch ", counter, " is missing app_metadata");
    }
    if (std::to_string(counter) != chunk.app_metadata->ToString()) {
      return arrow::Status::Invalid(
          "Expected metadata value: " + std::to_string(counter) +
          " but got: " + chunk.app_metadata->ToString());
    }
    counter++;
  }

  // The stream must be exhausted exactly at `expected` batches.
  arrow::flight::FlightStreamChunk chunk;
  RETURN_NOT_OK(stream->Next(&chunk));
  if (chunk.data != nullptr) {
    return arrow::Status::Invalid("Got more batches than the expected ", expected);
  }

  return arrow::Status::OK();
}

int main(int argc, char** argv) {
Expand All @@ -138,15 +144,14 @@ int main(int argc, char** argv) {
ABORT_NOT_OK(arrow::ipc::internal::json::JsonReader::Open(arrow::default_memory_pool(),
in_file, &reader));

std::shared_ptr<arrow::Table> original_data;
ABORT_NOT_OK(ReadToTable(reader, &original_data));
std::shared_ptr<arrow::Schema> original_schema = reader->schema();
std::vector<std::shared_ptr<arrow::RecordBatch>> original_data;
ABORT_NOT_OK(ReadBatches(reader, &original_data));

std::unique_ptr<arrow::flight::FlightStreamWriter> write_stream;
std::unique_ptr<arrow::flight::FlightMetadataReader> metadata_reader;
ABORT_NOT_OK(client->DoPut(descr, reader->schema(), &write_stream, &metadata_reader));
std::unique_ptr<arrow::RecordBatchReader> table_reader(
new arrow::TableBatchReader(*original_data));
ABORT_NOT_OK(UploadReaderToFlight(table_reader.get(), *write_stream, *metadata_reader));
ABORT_NOT_OK(client->DoPut(descr, original_schema, &write_stream, &metadata_reader));
ABORT_NOT_OK(UploadBatchesToFlight(original_data, *write_stream, *metadata_reader));

// 2. Get the ticket for the data.
std::unique_ptr<arrow::flight::FlightInfo> info;
Expand All @@ -171,15 +176,8 @@ int main(int argc, char** argv) {

for (const auto location : locations) {
std::cout << "Verifying location " << location.ToString() << std::endl;
// 3. Download the data from the server.
std::shared_ptr<arrow::Table> retrieved_data;
ABORT_NOT_OK(ConsumeFlightLocation(location, ticket, schema, &retrieved_data));

// 4. Validate that the data is equal.
if (!original_data->Equals(*retrieved_data)) {
std::cerr << "Data does not match!" << std::endl;
return 1;
}
// 3. Stream data from the server, comparing individual batches.
ABORT_NOT_OK(ConsumeFlightLocation(location, ticket, original_data));
}
}
return 0;
Expand Down
49 changes: 39 additions & 10 deletions cpp/src/arrow/flight/test_integration_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,33 @@ DEFINE_int32(port, 31337, "Server port to listen on");
namespace arrow {
namespace flight {

/// \brief The schema and batches of a single uploaded Flight, as stored by
/// the integration test server between DoPut and a later DoGet.
struct IntegrationDataset {
  // Schema shared by every batch in `chunks`.
  std::shared_ptr<Schema> schema;
  // The uploaded record batches, in upload order.
  std::vector<std::shared_ptr<RecordBatch>> chunks;
};

/// \brief A RecordBatchReader that yields a fixed in-memory list of record
/// batches (an IntegrationDataset) in order, then signals end-of-stream.
class RecordBatchListReader : public RecordBatchReader {
 public:
  // Take the dataset by value and move it into place: the original copied
  // the member (`dataset_(dataset)`), duplicating the whole vector of
  // shared_ptrs and paying an atomic refcount bump per batch.
  explicit RecordBatchListReader(IntegrationDataset dataset)
      : dataset_(std::move(dataset)), current_(0) {}

  std::shared_ptr<Schema> schema() const override { return dataset_.schema; }

  /// \brief Yield the next batch, or set *batch to nullptr once exhausted
  /// (the RecordBatchReader end-of-stream convention).
  Status ReadNext(std::shared_ptr<RecordBatch>* batch) override {
    if (current_ >= dataset_.chunks.size()) {
      *batch = nullptr;
      return Status::OK();
    }
    *batch = dataset_.chunks[current_];
    current_++;
    return Status::OK();
  }

 private:
  IntegrationDataset dataset_;
  // Index of the next chunk to yield; size_t to match vector::size().
  size_t current_;
};

class FlightIntegrationTestServer : public FlightServerBase {
Status GetFlightInfo(const ServerCallContext& context, const FlightDescriptor& request,
std::unique_ptr<FlightInfo>* info) override {
Expand All @@ -57,10 +84,13 @@ class FlightIntegrationTestServer : public FlightServerBase {
FlightEndpoint endpoint1({{request.path[0]}, {}});

FlightInfo::Data flight_data;
RETURN_NOT_OK(internal::SchemaToString(*flight->schema(), &flight_data.schema));
RETURN_NOT_OK(internal::SchemaToString(*flight.schema, &flight_data.schema));
flight_data.descriptor = request;
flight_data.endpoints = {endpoint1};
flight_data.total_records = flight->num_rows();
flight_data.total_records = 0;
for (const auto& chunk : flight.chunks) {
flight_data.total_records += chunk->num_rows();
}
flight_data.total_bytes = -1;
FlightInfo value(flight_data);

Expand All @@ -81,7 +111,7 @@ class FlightIntegrationTestServer : public FlightServerBase {

*data_stream = std::unique_ptr<FlightDataStream>(
new NumberingStream(std::unique_ptr<FlightDataStream>(new RecordBatchStream(
std::shared_ptr<RecordBatchReader>(new TableBatchReader(*flight))))));
std::shared_ptr<RecordBatchReader>(new RecordBatchListReader(flight))))));

return Status::OK();
}
Expand All @@ -99,24 +129,23 @@ class FlightIntegrationTestServer : public FlightServerBase {

std::string key = descriptor.path[0];

std::vector<std::shared_ptr<arrow::RecordBatch>> retrieved_chunks;
IntegrationDataset dataset;
dataset.schema = reader->schema();
arrow::flight::FlightStreamChunk chunk;
while (true) {
RETURN_NOT_OK(reader->Next(&chunk));
if (chunk.data == nullptr) break;
retrieved_chunks.push_back(chunk.data);
RETURN_NOT_OK(chunk.data->ValidateFull());
dataset.chunks.push_back(chunk.data);
if (chunk.app_metadata) {
RETURN_NOT_OK(writer->WriteMetadata(*chunk.app_metadata));
}
}
std::shared_ptr<arrow::Table> retrieved_data;
RETURN_NOT_OK(arrow::Table::FromRecordBatches(reader->schema(), retrieved_chunks,
&retrieved_data));
uploaded_chunks[key] = retrieved_data;
uploaded_chunks[key] = dataset;
return Status::OK();
}

std::unordered_map<std::string, std::shared_ptr<arrow::Table>> uploaded_chunks;
std::unordered_map<std::string, IntegrationDataset> uploaded_chunks;
};

} // namespace flight
Expand Down

0 comments on commit f7d3923

Please sign in to comment.