Implement projection for arrow file / streams (#1339)
* Implement projection for arrow file / streams

* Tests

* Fix

* Fix

* Add test

* Add test

* Add link

* Undo change to existing test

* Update arrow/src/ipc/reader.rs

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>

* Use project

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
Dandandan and alamb committed Mar 9, 2022
1 parent f19d1ed commit 4bcc7a6
Showing 7 changed files with 135 additions and 43 deletions.
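The change threads an optional `projection: Option<Vec<usize>>` argument through `FileReader::try_new` and `StreamReader::try_new`; a reader built with `Some(indices)` yields batches containing only those columns, with a matching projected schema. A minimal usage sketch for the file reader, assuming a pre-existing IPC file (path and indices are illustrative):

```rust
use std::fs::File;

use arrow::error::Result;
use arrow::ipc::reader::FileReader;

fn main() -> Result<()> {
    // Open an existing Arrow IPC file (hypothetical path) and keep only column 0.
    let file = File::open("data.arrow")?;
    let reader = FileReader::try_new(file, Some(vec![0]))?;

    for batch in reader {
        // Each batch carries the projected, single-column schema.
        assert_eq!(batch?.num_columns(), 1);
    }
    Ok(())
}
```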
1 change: 1 addition & 0 deletions arrow-flight/src/utils.rs
@@ -69,6 +69,7 @@ pub fn flight_data_to_arrow_batch(
                 batch,
                 schema,
                 dictionaries_by_field,
+                None,
             )
         })?
 }
148 changes: 119 additions & 29 deletions arrow/src/ipc/reader.rs
@@ -465,6 +465,7 @@ pub fn read_record_batch(
     batch: ipc::RecordBatch,
     schema: SchemaRef,
     dictionaries: &[Option<ArrayRef>],
+    projection: Option<&[usize]>,
 ) -> Result<RecordBatch> {
     let buffers = batch.buffers().ok_or_else(|| {
         ArrowError::IoError("Unable to get buffers from IPC RecordBatch".to_string())
@@ -477,23 +478,43 @@ pub fn read_record_batch(
     let mut node_index = 0;
     let mut arrays = vec![];
 
-    // keep track of index as lists require more than one node
-    for field in schema.fields() {
-        let triple = create_array(
-            field_nodes,
-            field.data_type(),
-            buf,
-            buffers,
-            dictionaries,
-            node_index,
-            buffer_index,
-        )?;
-        node_index = triple.1;
-        buffer_index = triple.2;
-        arrays.push(triple.0);
-    }
+    if let Some(projection) = projection {
+        let fields = schema.fields();
+        for &index in projection {
+            let field = &fields[index];
+            let triple = create_array(
+                field_nodes,
+                field.data_type(),
+                buf,
+                buffers,
+                dictionaries,
+                node_index,
+                buffer_index,
+            )?;
+            node_index = triple.1;
+            buffer_index = triple.2;
+            arrays.push(triple.0);
+        }
 
-    RecordBatch::try_new(schema, arrays)
+        RecordBatch::try_new(Arc::new(schema.project(projection)?), arrays)
+    } else {
+        // keep track of index as lists require more than one node
+        for field in schema.fields() {
+            let triple = create_array(
+                field_nodes,
+                field.data_type(),
+                buf,
+                buffers,
+                dictionaries,
+                node_index,
+                buffer_index,
+            )?;
+            node_index = triple.1;
+            buffer_index = triple.2;
+            arrays.push(triple.0);
+        }
+        RecordBatch::try_new(schema, arrays)
+    }
 }
 
 /// Read the dictionary from the buffer and provided metadata,
@@ -532,6 +553,7 @@ pub fn read_dictionary(
                 batch.data().unwrap(),
                 Arc::new(schema),
                 dictionaries_by_field,
+                None,
             )?;
             Some(record_batch.column(0).clone())
         }
@@ -581,14 +603,17 @@ pub struct FileReader<R: Read + Seek> {
 
     /// Metadata version
     metadata_version: ipc::MetadataVersion,
+
+    /// Optional projection and projected_schema
+    projection: Option<(Vec<usize>, Schema)>,
 }
 
 impl<R: Read + Seek> FileReader<R> {
     /// Try to create a new file reader
     ///
     /// Returns errors if the file does not meet the Arrow Format header and footer
     /// requirements
-    pub fn try_new(reader: R) -> Result<Self> {
+    pub fn try_new(reader: R, projection: Option<Vec<usize>>) -> Result<Self> {
         let mut reader = BufReader::new(reader);
         // check if header and footer contain correct magic bytes
         let mut magic_buffer: [u8; 6] = [0; 6];
@@ -671,6 +696,13 @@ impl<R: Read + Seek> FileReader<R> {
                 }
             };
         }
+        let projection = match projection {
+            Some(projection_indices) => {
+                let schema = schema.project(&projection_indices)?;
+                Some((projection_indices, schema))
+            }
+            _ => None,
+        };
 
         Ok(Self {
             reader,
Expand All @@ -680,6 +712,7 @@ impl<R: Read + Seek> FileReader<R> {
total_blocks,
dictionaries_by_field,
metadata_version: footer.version(),
projection,
})
}

@@ -760,6 +793,8 @@ impl<R: Read + Seek> FileReader<R> {
                     batch,
                     self.schema(),
                     &self.dictionaries_by_field,
+                    self.projection.as_ref().map(|x| x.0.as_ref()),
+
                 ).map(Some)
             }
             ipc::MessageHeader::NONE => {
@@ -808,6 +843,9 @@ pub struct StreamReader<R: Read> {
     ///
     /// This value is set to `true` the first time the reader's `next()` returns `None`.
     finished: bool,
+
+    /// Optional projection
+    projection: Option<(Vec<usize>, Schema)>,
 }
 
 impl<R: Read> StreamReader<R> {
@@ -816,7 +854,7 @@ impl<R: Read> StreamReader<R> {
     /// The first message in the stream is the schema, the reader will fail if it does not
     /// encounter a schema.
     /// To check if the reader is done, use `is_finished(self)`
-    pub fn try_new(reader: R) -> Result<Self> {
+    pub fn try_new(reader: R, projection: Option<Vec<usize>>) -> Result<Self> {
         let mut reader = BufReader::new(reader);
         // determine metadata length
         let mut meta_size: [u8; 4] = [0; 4];
@@ -845,11 +883,19 @@ impl<R: Read> StreamReader<R> {
         // Create an array of optional dictionary value arrays, one per field.
         let dictionaries_by_field = vec![None; schema.fields().len()];
 
+        let projection = match projection {
+            Some(projection_indices) => {
+                let schema = schema.project(&projection_indices)?;
+                Some((projection_indices, schema))
+            }
+            _ => None,
+        };
         Ok(Self {
             reader,
             schema: Arc::new(schema),
             finished: false,
             dictionaries_by_field,
+            projection,
         })
     }
 
@@ -922,7 +968,7 @@ impl<R: Read> StreamReader<R> {
                 let mut buf = vec![0; message.bodyLength() as usize];
                 self.reader.read_exact(&mut buf)?;
 
-                read_record_batch(&buf, batch, self.schema(), &self.dictionaries_by_field).map(Some)
+                read_record_batch(&buf, batch, self.schema(), &self.dictionaries_by_field, self.projection.as_ref().map(|x| x.0.as_ref())).map(Some)
             }
             ipc::MessageHeader::DictionaryBatch => {
                 let batch = message.header_as_dictionary_batch().ok_or_else(|| {
@@ -998,7 +1044,7 @@ mod tests {
         ))
         .unwrap();
 
-        let mut reader = FileReader::try_new(file).unwrap();
+        let mut reader = FileReader::try_new(file, None).unwrap();
 
         // read expected JSON output
         let arrow_json = read_gzip_json(version, path);
@@ -1015,7 +1061,7 @@
             testdata
         ))
         .unwrap();
-        FileReader::try_new(file).unwrap();
+        FileReader::try_new(file, None).unwrap();
     }
 
     #[test]
@@ -1031,7 +1077,7 @@
             testdata
         ))
         .unwrap();
-        FileReader::try_new(file).unwrap();
+        FileReader::try_new(file, None).unwrap();
     }
 
     #[test]
@@ -1056,7 +1102,39 @@
         ))
         .unwrap();
 
-            FileReader::try_new(file).unwrap();
+            FileReader::try_new(file, None).unwrap();
         });
     }
 
+    #[test]
+    fn projection_should_work() {
+        // complementary to the previous test
+        let testdata = crate::util::test_util::arrow_test_data();
+        let paths = vec![
+            "generated_interval",
+            "generated_datetime",
+            // "generated_map", Err: Last offset 872415232 of Utf8 is larger than values length 52 (https://github.com/apache/arrow-rs/issues/859)
+            "generated_nested",
+            "generated_null_trivial",
+            "generated_null",
+            "generated_primitive_no_batches",
+            "generated_primitive_zerolength",
+            "generated_primitive",
+        ];
+        paths.iter().for_each(|path| {
+            let file = File::open(format!(
+                "{}/arrow-ipc-stream/integration/1.0.0-bigendian/{}.arrow_file",
+                testdata, path
+            ))
+            .unwrap();
+
+            let reader = FileReader::try_new(file, Some(vec![0])).unwrap();
+            let datatype_0 = reader.schema().fields()[0].data_type().clone();
+            reader.for_each(|batch| {
+                let batch = batch.unwrap();
+                assert_eq!(batch.columns().len(), 1);
+                assert_eq!(datatype_0, batch.schema().fields()[0].data_type().clone());
+            });
+        });
+    }
+
@@ -1083,7 +1161,7 @@
         ))
         .unwrap();
 
-        let mut reader = StreamReader::try_new(file).unwrap();
+        let mut reader = StreamReader::try_new(file, None).unwrap();
 
         // read expected JSON output
         let arrow_json = read_gzip_json(version, path);
@@ -1120,7 +1198,7 @@ mod tests {
         ))
         .unwrap();
 
-        let mut reader = FileReader::try_new(file).unwrap();
+        let mut reader = FileReader::try_new(file, None).unwrap();
 
         // read expected JSON output
         let arrow_json = read_gzip_json(version, path);
@@ -1153,7 +1231,7 @@
         ))
         .unwrap();
 
-        let mut reader = StreamReader::try_new(file).unwrap();
+        let mut reader = StreamReader::try_new(file, None).unwrap();
 
         // read expected JSON output
         let arrow_json = read_gzip_json(version, path);
@@ -1189,7 +1267,7 @@
 
         // read stream back
         let file = File::open("target/debug/testdata/float.stream").unwrap();
-        let reader = StreamReader::try_new(file).unwrap();
+        let reader = StreamReader::try_new(file, None).unwrap();
 
         reader.for_each(|batch| {
             let batch = batch.unwrap();
@@ -1211,7 +1289,19 @@
                     .value(0)
                     != 0.0
             );
-        })
+        });
+
+        let file = File::open("target/debug/testdata/float.stream").unwrap();
+
+        // Read with projection
+        let reader = StreamReader::try_new(file, Some(vec![0, 3])).unwrap();
+
+        reader.for_each(|batch| {
+            let batch = batch.unwrap();
+            assert_eq!(batch.schema().fields().len(), 2);
+            assert_eq!(batch.schema().fields()[0].data_type(), &DataType::Float32);
+            assert_eq!(batch.schema().fields()[1].data_type(), &DataType::Int32);
+        });
     }

@@ -1223,7 +1313,7 @@ fn roundtrip_ipc(rb: &RecordBatch) -> RecordBatch {
         drop(writer);
 
         let mut reader =
-            ipc::reader::FileReader::try_new(std::io::Cursor::new(buf)).unwrap();
+            ipc::reader::FileReader::try_new(std::io::Cursor::new(buf), None).unwrap();
         reader.next().unwrap().unwrap()
     }
 
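Internally, both readers resolve the projection once in `try_new`: `Schema::project` validates the indices and derives the projected schema, and `read_record_batch` then decodes only the requested fields. A small sketch of the schema-side behavior (field names are illustrative):

```rust
use arrow::datatypes::{DataType, Field, Schema};

fn main() {
    let schema = Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8, true),
        Field::new("c", DataType::Float64, true),
    ]);

    // `project` keeps the requested fields, in the requested order,
    // and errors on out-of-bounds indices.
    let projected = schema.project(&[2, 0]).unwrap();
    assert_eq!(projected.fields().len(), 2);
    assert_eq!(projected.field(0).name(), "c");
    assert_eq!(projected.field(1).name(), "a");
}
```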

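The stream reader takes the same parameter. A minimal sketch, assuming an IPC stream file written earlier (path and indices are illustrative):

```rust
use std::fs::File;

use arrow::error::Result;
use arrow::ipc::reader::StreamReader;

fn main() -> Result<()> {
    // Read an Arrow IPC stream (hypothetical path), keeping columns 0 and 1.
    let file = File::open("data.stream")?;
    let reader = StreamReader::try_new(file, Some(vec![0, 1]))?;

    for batch in reader {
        assert_eq!(batch?.num_columns(), 2);
    }
    Ok(())
}
```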