Skip to content

Commit

Permalink
GH-39163: [C++] Add missing data copy in StreamDecoder::Consume(data) (
Browse files Browse the repository at this point in the history
…#39164)

### Rationale for this change

We need to copy data for the metadata message because it may be used in subsequent `Consume(data)` calls. We can't assume that the `data` given to one `Consume(data)` call is still valid during a subsequent `Consume(data)` call.

We also need to copy buffered `data` because it's used in subsequent `Consume(data)` calls.

### What changes are included in this PR?

* Add missing copies.
* Clean up the existing buffer copy code.
* Change tests to use ephemeral `data` to detect this case.
* Add `copy_record_batch` option to `CollectListener` to copy decoded record batches.

### Are these changes tested?

Yes.

### Are there any user-facing changes?

Yes.

* Closes #39163 
* Closes: #39163

Authored-by: Sutou Kouhei <kou@clear-code.com>
Signed-off-by: Sutou Kouhei <kou@clear-code.com>
  • Loading branch information
kou committed Jan 6, 2024
1 parent 33c64ed commit 6ab7a18
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 15 deletions.
37 changes: 27 additions & 10 deletions cpp/src/arrow/ipc/message.cc
Original file line number Diff line number Diff line change
Expand Up @@ -626,10 +626,24 @@ class MessageDecoder::MessageDecoderImpl {
RETURN_NOT_OK(ConsumeMetadataLengthData(data, next_required_size_));
break;
case State::METADATA: {
auto buffer = std::make_shared<Buffer>(data, next_required_size_);
// We need to copy metadata because it's used in
// ConsumeBody(). ConsumeBody() may be called from another
// ConsumeData(). We can't assume that the given data for
// the current ConsumeData() call is still valid in the
// next ConsumeData() call. So we need to copy metadata
// here.
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> buffer,
AllocateBuffer(next_required_size_, pool_));
memcpy(buffer->mutable_data(), data, next_required_size_);
RETURN_NOT_OK(ConsumeMetadataBuffer(buffer));
} break;
case State::BODY: {
// We don't need to copy the given data for body because
// we can assume that a decoded record batch should be
// valid only in a listener_->OnMessageDecoded() call. If
// the passed message is needed to be valid after the
// call, it's a listener_'s responsibility. The listener_
// may copy the data for it.
auto buffer = std::make_shared<Buffer>(data, next_required_size_);
RETURN_NOT_OK(ConsumeBodyBuffer(buffer));
} break;
Expand All @@ -645,7 +659,12 @@ class MessageDecoder::MessageDecoderImpl {
return Status::OK();
}

chunks_.push_back(std::make_shared<Buffer>(data, size));
// We need to copy unused data because the given data for the
// current ConsumeData() call may be invalid in the next
// ConsumeData() call.
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> chunk, AllocateBuffer(size, pool_));
memcpy(chunk->mutable_data(), data, size);
chunks_.push_back(std::move(chunk));
buffered_size_ += size;
return ConsumeChunks();
}
Expand Down Expand Up @@ -830,8 +849,7 @@ class MessageDecoder::MessageDecoderImpl {
}
buffered_size_ -= next_required_size_;
} else {
ARROW_ASSIGN_OR_RAISE(auto metadata, AllocateBuffer(next_required_size_, pool_));
metadata_ = std::shared_ptr<Buffer>(metadata.release());
ARROW_ASSIGN_OR_RAISE(metadata_, AllocateBuffer(next_required_size_, pool_));
RETURN_NOT_OK(ConsumeDataChunks(next_required_size_, metadata_->mutable_data()));
}
return ConsumeMetadata();
Expand All @@ -846,9 +864,8 @@ class MessageDecoder::MessageDecoderImpl {
next_required_size_ = skip_body_ ? 0 : body_length;
RETURN_NOT_OK(listener_->OnBody());
if (next_required_size_ == 0) {
ARROW_ASSIGN_OR_RAISE(auto body, AllocateBuffer(0, pool_));
std::shared_ptr<Buffer> shared_body(body.release());
return ConsumeBody(&shared_body);
auto body = std::make_shared<Buffer>(nullptr, 0);
return ConsumeBody(&body);
} else {
return Status::OK();
}
Expand All @@ -872,10 +889,10 @@ class MessageDecoder::MessageDecoderImpl {
buffered_size_ -= used_size;
return Status::OK();
} else {
ARROW_ASSIGN_OR_RAISE(auto body, AllocateBuffer(next_required_size_, pool_));
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> body,
AllocateBuffer(next_required_size_, pool_));
RETURN_NOT_OK(ConsumeDataChunks(next_required_size_, body->mutable_data()));
std::shared_ptr<Buffer> shared_body(body.release());
return ConsumeBody(&shared_body);
return ConsumeBody(&body);
}
}

Expand Down
46 changes: 41 additions & 5 deletions cpp/src/arrow/ipc/read_write_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1330,11 +1330,44 @@ struct StreamWriterHelper {
std::shared_ptr<RecordBatchWriter> writer_;
};

// A CollectListener that deep-copies every buffer of each decoded record
// batch before collecting it.
//
// Decoded batches may reference the bytes passed to
// StreamDecoder::Consume(data), which are only guaranteed to be valid for
// the duration of the OnRecordBatchWithMetadataDecoded() call. Copying all
// buffers here lets tests feed the decoder ephemeral input and still detect
// any use-after-free of the caller-provided data.
class CopyCollectListener : public CollectListener {
 public:
  CopyCollectListener() : CollectListener() {}

  Status OnRecordBatchWithMetadataDecoded(
      RecordBatchWithMetadata record_batch_with_metadata) override {
    auto& record_batch = record_batch_with_metadata.batch;
    // Replace the buffers of every column so the collected batch no longer
    // aliases the decoder's input data. Iterate by const reference to avoid
    // per-column shared_ptr refcount churn.
    for (const auto& column_data : record_batch->column_data()) {
      ARROW_RETURN_NOT_OK(CopyArrayData(column_data));
    }
    return CollectListener::OnRecordBatchWithMetadataDecoded(record_batch_with_metadata);
  }

 private:
  // Recursively replaces data's buffers (including child arrays and the
  // dictionary, if any) with freshly allocated copies made through each
  // buffer's own memory manager. Takes the shared_ptr by const reference:
  // this helper only reads through it, so no ownership copy is needed.
  Status CopyArrayData(const std::shared_ptr<ArrayData>& data) {
    auto& buffers = data->buffers;
    for (size_t i = 0; i < buffers.size(); ++i) {
      auto& buffer = buffers[i];
      if (!buffer) {
        // Buffer slots may be null (e.g. an absent validity bitmap).
        continue;
      }
      ARROW_ASSIGN_OR_RAISE(buffers[i], Buffer::Copy(buffer, buffer->memory_manager()));
    }
    for (const auto& child_data : data->child_data) {
      ARROW_RETURN_NOT_OK(CopyArrayData(child_data));
    }
    if (data->dictionary) {
      ARROW_RETURN_NOT_OK(CopyArrayData(data->dictionary));
    }
    return Status::OK();
  }
};

struct StreamDecoderWriterHelper : public StreamWriterHelper {
Status ReadBatches(const IpcReadOptions& options, RecordBatchVector* out_batches,
ReadStats* out_stats = nullptr,
MetadataVector* out_metadata_list = nullptr) override {
auto listener = std::make_shared<CollectListener>();
auto listener = std::make_shared<CopyCollectListener>();
StreamDecoder decoder(listener, options);
RETURN_NOT_OK(DoConsume(&decoder));
*out_batches = listener->record_batches();
Expand All @@ -1358,7 +1391,10 @@ struct StreamDecoderWriterHelper : public StreamWriterHelper {

// Feeds the decoder the whole stream in a single Consume(data, size) call
// backed by a copy that dies when DoConsume() returns, so the decoder must
// not retain the raw pointer it was handed.
struct StreamDecoderDataWriterHelper : public StreamDecoderWriterHelper {
  Status DoConsume(StreamDecoder* decoder) override {
    // The copied bytes are valid only within this function body.
    ARROW_ASSIGN_OR_RAISE(auto ephemeral_data,
                          Buffer::Copy(buffer_, arrow::default_cpu_memory_manager()));
    return decoder->Consume(ephemeral_data->data(), ephemeral_data->size());
  }
};

Expand All @@ -1369,7 +1405,9 @@ struct StreamDecoderBufferWriterHelper : public StreamDecoderWriterHelper {
// Feeds the decoder one byte at a time, each byte coming from a short-lived
// copy, so the decoder must buffer any bytes it needs across calls.
struct StreamDecoderSmallChunksWriterHelper : public StreamDecoderWriterHelper {
  Status DoConsume(StreamDecoder* decoder) override {
    const int64_t last_offset = buffer_->size() - 1;
    for (int64_t offset = 0; offset < last_offset; ++offset) {
      // Each single-byte slice is valid only within this iteration.
      ARROW_ASSIGN_OR_RAISE(auto one_byte, buffer_->CopySlice(offset, 1));
      RETURN_NOT_OK(decoder->Consume(one_byte->data(), one_byte->size()));
    }
    return Status::OK();
  }
};
Expand Down Expand Up @@ -2172,7 +2210,6 @@ TEST(TestRecordBatchStreamReader, MalformedInput) {
ASSERT_RAISES(Invalid, RecordBatchStreamReader::Open(&garbage_reader));
}

namespace {
class EndlessCollectListener : public CollectListener {
public:
EndlessCollectListener() : CollectListener(), decoder_(nullptr) {}
Expand All @@ -2184,7 +2221,6 @@ class EndlessCollectListener : public CollectListener {
private:
StreamDecoder* decoder_;
};
}; // namespace

TEST(TestStreamDecoder, Reset) {
auto listener = std::make_shared<EndlessCollectListener>();
Expand Down

0 comments on commit 6ab7a18

Please sign in to comment.