Skip to content

Commit

Permalink
ARROW-10642: [R] Can't get Table from RecordBatchReader with 0 batches
Browse files Browse the repository at this point in the history
Closes #8956 from nealrichardson/zero-batches

Authored-by: Neal Richardson <neal.p.richardson@gmail.com>
Signed-off-by: Neal Richardson <neal.p.richardson@gmail.com>
  • Loading branch information
nealrichardson committed Dec 21, 2020
1 parent 38ba81b commit a2e7d3a
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 33 deletions.
8 changes: 4 additions & 4 deletions r/R/arrowExports.R

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions r/R/record-batch-reader.R
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ RecordBatchReader <- R6Class("RecordBatchReader", inherit = ArrowObject,
RecordBatchStreamReader <- R6Class("RecordBatchStreamReader", inherit = RecordBatchReader,
public = list(
batches = function() ipc___RecordBatchStreamReader__batches(self),
read_table = function() Table__from_RecordBatchStreamReader(self)
read_table = function() Table__from_RecordBatchReader(self)
)
)
RecordBatchStreamReader$create <- function(stream) {
Expand All @@ -128,7 +128,8 @@ RecordBatchStreamReader$create <- function(stream) {
#' @format NULL
#' @export
RecordBatchFileReader <- R6Class("RecordBatchFileReader", inherit = ArrowObject,
# Why doesn't this inherit from RecordBatchReader?
# Why doesn't this inherit from RecordBatchReader in C++?
# Origin: https://github.com/apache/arrow/pull/679
public = list(
get_batch = function(i) {
ipc___RecordBatchFileReader__ReadRecordBatch(self, i)
Expand Down
26 changes: 13 additions & 13 deletions r/src/arrowExports.cpp

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 10 additions & 14 deletions r/src/recordbatchreader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,19 @@ std::shared_ptr<arrow::ipc::RecordBatchFileReader> ipc___RecordBatchFileReader__
return ValueOrStop(arrow::ipc::RecordBatchFileReader::Open(file));
}

// Collect everything `reader` yields into a single arrow::Table.
//
// The work is done by RecordBatchReader::ReadAll(); a non-OK Status is passed
// to StopIfNotOk(). Accepts any arrow::RecordBatchReader, so it also serves
// readers that derive from it.
// [[arrow::export]]
std::shared_ptr<arrow::Table> Table__from_RecordBatchReader(
    const std::shared_ptr<arrow::RecordBatchReader>& reader) {
  std::shared_ptr<arrow::Table> result;
  const auto status = reader->ReadAll(&result);
  StopIfNotOk(status);
  return result;
}

// [[arrow::export]]
std::shared_ptr<arrow::Table> Table__from_RecordBatchFileReader(
const std::shared_ptr<arrow::ipc::RecordBatchFileReader>& reader) {
// RecordBatchStreamReader inherits from RecordBatchReader
// but RecordBatchFileReader apparently does not
int num_batches = reader->num_record_batches();
std::vector<std::shared_ptr<arrow::RecordBatch>> batches(num_batches);
for (int i = 0; i < num_batches; i++) {
Expand All @@ -100,20 +110,6 @@ std::shared_ptr<arrow::Table> Table__from_RecordBatchFileReader(
return ValueOrStop(arrow::Table::FromRecordBatches(std::move(batches)));
}

// Read every batch from an IPC stream reader into a single arrow::Table.
//
// Delegates to RecordBatchReader::ReadAll() instead of collecting batches by
// hand and calling Table::FromRecordBatches(batches). The batches-only
// overload has no schema to fall back on, so a stream that contains only a
// schema (zero record batches) produced an error rather than an empty Table
// (ARROW-10642). A non-OK Status is passed to StopIfNotOk().
// [[arrow::export]]
std::shared_ptr<arrow::Table> Table__from_RecordBatchStreamReader(
    const std::shared_ptr<arrow::ipc::RecordBatchStreamReader>& reader) {
  std::shared_ptr<arrow::Table> table;
  StopIfNotOk(reader->ReadAll(&table));
  return table;
}

// [[arrow::export]]
cpp11::list ipc___RecordBatchFileReader__batches(
const std::shared_ptr<arrow::ipc::RecordBatchFileReader>& reader) {
Expand Down
6 changes: 6 additions & 0 deletions r/tests/testthat/test-Table.R
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,12 @@ test_that("table() auto splices (ARROW-5718)", {
expect_equivalent(as.data.frame(tab3), df)
})

test_that("Validation when creating table with schema (ARROW-10953)", {
# Build a Table from an empty data.frame while supplying an explicit
# one-field schema.
tab <- Table$create(data.frame(), schema = schema(a = int32()))
# skip() comes after Table$create(), so the construction above still runs and
# only the expectation below is disabled.
# NOTE(review): presumably it is the as.data.frame() conversion that
# segfaults -- confirm against ARROW-10953 before re-enabling.
skip("This segfaults")
# Expected once re-enabled: 0 rows, 1 column.
expect_identical(dim(as.data.frame(tab)), c(0L, 1L))
})

test_that("==.Table", {
tab1 <- Table$create(x = 1:2, y = c("a", "b"))
tab2 <- Table$create(x = 1:2, y = c("a", "b"))
Expand Down
43 changes: 43 additions & 0 deletions r/tests/testthat/test-record-batch-reader.R
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,36 @@ test_that("RecordBatchFileReader / Writer", {
expect_equal(reader$num_record_batches, 3)
})

test_that("StreamReader read_table", {
  # Round trip through an in-memory IPC stream: three writes go in, and
  # read_table() on a stream reader must return them as one Table.
  # `batch`, `tab` and `tbl` are fixtures defined outside this block.
  stream <- BufferOutputStream$create()
  stream_writer <- RecordBatchStreamWriter$create(stream, batch$schema)
  expect_is(stream_writer, "RecordBatchWriter")

  stream_writer$write(batch)
  stream_writer$write(tab)
  stream_writer$write(tbl)
  stream_writer$close()

  # The buffer returned by finish() is what the reader consumes.
  ipc_bytes <- stream$finish()
  combined <- RecordBatchStreamReader$create(ipc_bytes)$read_table()

  # The combined Table is expected to have 30 rows and 2 columns.
  expect_identical(dim(combined), c(30L, 2L))
})

test_that("FileReader read_table", {
  # Round trip through an in-memory IPC file: three writes go in, and
  # read_table() on a file reader must return them as one Table.
  # `batch`, `tab` and `tbl` are fixtures defined outside this block.
  stream <- BufferOutputStream$create()
  file_writer <- RecordBatchFileWriter$create(stream, batch$schema)
  expect_is(file_writer, "RecordBatchWriter")

  file_writer$write(batch)
  file_writer$write(tab)
  file_writer$write(tbl)
  file_writer$close()

  # The buffer returned by finish() is what the reader consumes.
  ipc_bytes <- stream$finish()
  combined <- RecordBatchFileReader$create(ipc_bytes)$read_table()

  # The combined Table is expected to have 30 rows and 2 columns.
  expect_identical(dim(combined), c(30L, 2L))
})

test_that("MetadataFormat", {
expect_identical(get_ipc_metadata_version(5), 4L)
expect_identical(get_ipc_metadata_version("V4"), 3L)
Expand All @@ -97,3 +127,16 @@ test_that("MetadataFormat", {
'"45" is not a valid IPC MetadataVersion'
)
})

test_that("reader with 0 batches", {
  # IPC stream containing only a schema (ARROW-10642)
  stream <- BufferOutputStream$create()
  # Open a writer for a single int32 field and close it without writing
  # any batch.
  empty_writer <- RecordBatchStreamWriter$create(stream, schema(a = int32()))
  empty_writer$close()
  ipc_bytes <- stream$finish()

  stream_reader <- RecordBatchStreamReader$create(ipc_bytes)
  result <- stream_reader$read_table()

  # Still a Table, with the schema's one column and no rows.
  expect_is(result, "Table")
  expect_identical(dim(result), c(0L, 1L))
})

0 comments on commit a2e7d3a

Please sign in to comment.