Truncate IPC record batch #2040

Merged, 11 commits, Jul 14, 2022
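For context: slicing a RecordBatch is zero-copy, so before this change the IPC writers serialized the parent arrays' full backing buffers even when only a small slice was written. This PR truncates the written buffers to the sliced range. A minimal sketch of the scenario (hypothetical usage code, not from this PR; the calls mirror the tests added below):

use std::sync::Arc;
use arrow::array::Int32Array;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::ipc::writer::StreamWriter;
use arrow::record_batch::RecordBatch;

fn main() -> arrow::error::Result<()> {
    // Build a large batch; slice(offset, len) is zero-copy and keeps
    // pointing at the 65536-row backing buffers.
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let values = Int32Array::from_iter_values(0..65536);
    let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(values)])?;
    let slice = batch.slice(2, 5);

    // With truncation, writing the slice emits only the 5 sliced rows,
    // not the whole parent buffer.
    let mut writer = StreamWriter::try_new(Vec::new(), &schema)?;
    writer.write(&slice)?;
    writer.finish()?;
    let bytes = writer.into_inner()?;
    assert!(bytes.len() < 1024); // exact size varies; far below the untruncated size
    Ok(())
}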
1 change: 1 addition & 0 deletions arrow/src/array/mod.rs
@@ -176,6 +176,7 @@ pub(crate) use self::data::layout;
pub use self::data::ArrayData;
pub use self::data::ArrayDataBuilder;
pub use self::data::ArrayDataRef;
pub(crate) use self::data::BufferSpec;

pub use self::array_binary::BinaryArray;
pub use self::array_binary::FixedSizeBinaryArray;
15 changes: 15 additions & 0 deletions arrow/src/datatypes/datatype.rs
@@ -721,6 +721,21 @@ impl DataType {
)
}

/// Returns true if this type is temporal (Date*, Time*, Timestamp, Duration, or Interval).
pub fn is_temporal(t: &DataType) -> bool {
use DataType::*;
matches!(
t,
Date32
| Date64
| Timestamp(_, _)
| Time32(_)
| Time64(_)
| Duration(_)
| Interval(_)
)
}

/// Returns true if this type is valid as a dictionary key
/// (e.g. [`super::ArrowDictionaryKeyType`])
pub fn is_dictionary_key_type(t: &DataType) -> bool {
293 changes: 285 additions & 8 deletions arrow/src/ipc/writer.rs
@@ -20,14 +20,17 @@
//! The `FileWriter` and `StreamWriter` have similar interfaces,
//! however the `FileWriter` expects a reader that supports `Seek`ing

use std::cmp::min;
use std::collections::HashMap;
use std::io::{BufWriter, Write};

use flatbuffers::FlatBufferBuilder;

use crate::array::{
as_large_list_array, as_list_array, as_map_array, as_struct_array, as_union_array,
-make_array, Array, ArrayData, ArrayRef, FixedSizeListArray,
+layout, make_array, Array, ArrayData, ArrayRef, BinaryArray, BufferBuilder,
+BufferSpec, FixedSizeListArray, GenericBinaryArray, GenericStringArray,
+LargeBinaryArray, LargeStringArray, OffsetSizeTrait, StringArray,
};
use crate::buffer::{Buffer, MutableBuffer};
use crate::datatypes::*;
@@ -861,6 +864,106 @@ fn has_validity_bitmap(data_type: &DataType, write_options: &IpcWriteOptions) ->
}
}

/// Returns whether the buffer needs truncating before being written: true when the spec
/// is not `BufferSpec::AlwaysNull` and the array is offset into, or shorter than, its
/// backing buffer.
#[inline]
fn buffer_need_truncate(
array_offset: usize,
buffer: &Buffer,
spec: &BufferSpec,
min_length: usize,
) -> bool {
spec != &BufferSpec::AlwaysNull && (array_offset != 0 || min_length < buffer.len())
}

/// Returns the byte width for a buffer spec, or 0 if the spec is not `BufferSpec::FixedWidth`.
#[inline]
fn get_buffer_element_width(spec: &BufferSpec) -> usize {
match spec {
BufferSpec::FixedWidth { byte_width } => *byte_width,
_ => 0,
}
}

/// Returns the total number of value bytes in a base binary array (`Binary`, `LargeBinary`, `Utf8`, or `LargeUtf8`).
fn get_binary_buffer_len(array_data: &ArrayData) -> usize {
if array_data.is_empty() {
return 0;
}
match array_data.data_type() {
DataType::Binary => {
let array: BinaryArray = array_data.clone().into();
let offsets = array.value_offsets();
(offsets[array_data.len()] - offsets[0]) as usize
}
DataType::LargeBinary => {
let array: LargeBinaryArray = array_data.clone().into();
let offsets = array.value_offsets();
(offsets[array_data.len()] - offsets[0]) as usize
}
DataType::Utf8 => {
let array: StringArray = array_data.clone().into();
let offsets = array.value_offsets();
(offsets[array_data.len()] - offsets[0]) as usize
}
DataType::LargeUtf8 => {
let array: LargeStringArray = array_data.clone().into();
let offsets = array.value_offsets();
(offsets[array_data.len()] - offsets[0]) as usize
}
_ => unreachable!(),
}
}

/// Rebases the value offsets of the given `ArrayData` to zero-based.
fn get_zero_based_value_offsets<OffsetSize: OffsetSizeTrait>(
array_data: &ArrayData,
) -> Buffer {
match array_data.data_type() {
DataType::Binary | DataType::LargeBinary => {
let array: GenericBinaryArray<OffsetSize> = array_data.clone().into();
let offsets = array.value_offsets();
let start_offset = offsets[0];

let mut builder = BufferBuilder::<OffsetSize>::new(array_data.len() + 1);
for x in offsets {
builder.append(*x - start_offset);
}

builder.finish()
}
DataType::Utf8 | DataType::LargeUtf8 => {
let array: GenericStringArray<OffsetSize> = array_data.clone().into();
let offsets = array.value_offsets();
let start_offset = offsets[0];

let mut builder = BufferBuilder::<OffsetSize>::new(array_data.len() + 1);
for x in offsets {
builder.append(*x - start_offset);
}

builder.finish()
}
_ => unreachable!(),
}
}

/// Returns the start offset of a base binary array.
fn get_buffer_offset<OffsetSize: OffsetSizeTrait>(array_data: &ArrayData) -> OffsetSize {
match array_data.data_type() {
DataType::Binary | DataType::LargeBinary => {
let array: GenericBinaryArray<OffsetSize> = array_data.clone().into();
let offsets = array.value_offsets();
offsets[0]
}
DataType::Utf8 | DataType::LargeUtf8 => {
let array: GenericStringArray<OffsetSize> = array_data.clone().into();
let offsets = array.value_offsets();
offsets[0]
}
_ => unreachable!(),
}
}

/// Write array data to a vector of bytes
#[allow(clippy::too_many_arguments)]
fn write_array_data(
@@ -891,15 +994,80 @@ fn write_array_data(
let buffer = buffer.with_bitset(num_bytes, true);
buffer.into()
}
-Some(buffer) => buffer.clone(),
+Some(buffer) => buffer.bit_slice(array_data.offset(), array_data.len()),
};

-offset = write_buffer(&null_buffer, buffers, arrow_data, offset);
+offset = write_buffer(null_buffer.as_slice(), buffers, arrow_data, offset);
}

-array_data.buffers().iter().for_each(|buffer| {
-offset = write_buffer(buffer, buffers, arrow_data, offset);
-});
+let data_type = array_data.data_type();
if matches!(
data_type,
DataType::Binary | DataType::LargeBinary | DataType::Utf8 | DataType::LargeUtf8
) {
let total_bytes = get_binary_buffer_len(array_data);
let value_buffer = &array_data.buffers()[1];
if buffer_need_truncate(
array_data.offset(),
value_buffer,
&BufferSpec::VariableWidth,
total_bytes,
) {
// Rebase offsets and truncate values
let (new_offsets, byte_offset) =
if matches!(data_type, DataType::Binary | DataType::Utf8) {
(
get_zero_based_value_offsets::<i32>(array_data),
get_buffer_offset::<i32>(array_data) as usize,
)
} else {
(
get_zero_based_value_offsets::<i64>(array_data),
get_buffer_offset::<i64>(array_data) as usize,
)
};

offset = write_buffer(new_offsets.as_slice(), buffers, arrow_data, offset);
[Review comment from a Contributor] Definitely something that could be left to another PR, but I do wonder if there might be some way to write the new offsets directly without buffering them first. I'm not very familiar with the IPC code so not sure how feasible this may be.


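A hedged sketch of that idea (hypothetical, not part of this PR; it assumes the file's existing pad_to_8 padding helper and the crate's ToByteSlice trait are in scope, and write_offsets_direct is an invented name):

fn write_offsets_direct<OffsetSize: OffsetSizeTrait>(
    offsets: &[OffsetSize],
    buffers: &mut Vec<ipc::Buffer>,
    arrow_data: &mut Vec<u8>,
    offset: i64,
) -> i64 {
    // Record the ipc::Buffer entry up front, as write_buffer does.
    let start = offsets[0];
    let len = offsets.len() * std::mem::size_of::<OffsetSize>();
    let pad_len = pad_to_8(len as u32);
    let total_len = (len + pad_len) as i64;
    buffers.push(ipc::Buffer::new(offset, total_len));
    for x in offsets {
        // Rebase each offset to zero and append its bytes directly,
        // skipping the intermediate BufferBuilder.
        arrow_data.extend_from_slice((*x - start).to_byte_slice());
    }
    arrow_data.extend_from_slice(&vec![0u8; pad_len][..]);
    offset + total_len
}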
let buffer_length = min(total_bytes, value_buffer.len() - byte_offset);
let buffer_slice =
&value_buffer.as_slice()[byte_offset..(byte_offset + buffer_length)];
offset = write_buffer(buffer_slice, buffers, arrow_data, offset);
} else {
array_data.buffers().iter().for_each(|buffer| {
offset = write_buffer(buffer.as_slice(), buffers, arrow_data, offset);
});
}
} else if DataType::is_numeric(data_type)
|| DataType::is_temporal(data_type)
|| matches!(
array_data.data_type(),
DataType::FixedSizeBinary(_) | DataType::Dictionary(_, _)
)
{
// Truncate values
assert!(array_data.buffers().len() == 1);

let buffer = &array_data.buffers()[0];
let layout = layout(data_type);
let spec = &layout.buffers[0];

let byte_width = get_buffer_element_width(spec);
let min_length = array_data.len() * byte_width;
if buffer_need_truncate(array_data.offset(), buffer, spec, min_length) {
let byte_offset = array_data.offset() * byte_width;
let buffer_length = min(min_length, buffer.len() - byte_offset);
let buffer_slice =
&buffer.as_slice()[byte_offset..(byte_offset + buffer_length)];
offset = write_buffer(buffer_slice, buffers, arrow_data, offset);
} else {
offset = write_buffer(buffer.as_slice(), buffers, arrow_data, offset);
}
} else {
array_data.buffers().iter().for_each(|buffer| {
offset = write_buffer(buffer, buffers, arrow_data, offset);
});
}

if !matches!(array_data.data_type(), DataType::Dictionary(_, _)) {
// recursively write out nested structures
@@ -923,7 +1091,7 @@

/// Write a buffer to a vector of bytes, and add its ipc::Buffer to a vector
fn write_buffer(
-buffer: &Buffer,
+buffer: &[u8],
buffers: &mut Vec<ipc::Buffer>,
arrow_data: &mut Vec<u8>,
offset: i64,
@@ -933,7 +1101,7 @@
let total_len: i64 = (len + pad_len) as i64;
// assert_eq!(len % 8, 0, "Buffer width not a multiple of 8 bytes");
buffers.push(ipc::Buffer::new(offset, total_len));
-arrow_data.extend_from_slice(buffer.as_slice());
+arrow_data.extend_from_slice(buffer);
arrow_data.extend_from_slice(&vec![0u8; pad_len][..]);
offset + total_len
}
@@ -1507,4 +1675,113 @@ mod tests {
IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap(),
);
}

fn serialize(record: &RecordBatch) -> Vec<u8> {
let buffer: Vec<u8> = Vec::new();
let mut stream_writer = StreamWriter::try_new(buffer, &record.schema()).unwrap();
stream_writer.write(record).unwrap();
stream_writer.finish().unwrap();
stream_writer.into_inner().unwrap()
}

fn deserialize(bytes: Vec<u8>) -> RecordBatch {
let mut stream_reader =
ipc::reader::StreamReader::try_new(std::io::Cursor::new(bytes), None)
.unwrap();
stream_reader.next().unwrap().unwrap()
}

#[test]
fn truncate_ipc_record_batch() {
fn create_batch(rows: usize) -> RecordBatch {
let schema = Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Utf8, false),
]);

let a = Int32Array::from_iter_values(0..rows as i32);
let b = StringArray::from_iter_values((0..rows).map(|i| i.to_string()));

RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])
.unwrap()
}

let big_record_batch = create_batch(65536);

let length = 5;
let small_record_batch = create_batch(length);

let offset = 2;
let record_batch_slice = big_record_batch.slice(offset, length);
assert!(
serialize(&big_record_batch).len() > serialize(&small_record_batch).len()
);
assert_eq!(
serialize(&small_record_batch).len(),
serialize(&record_batch_slice).len()
);

assert_eq!(
deserialize(serialize(&record_batch_slice)),
record_batch_slice
);
}

#[test]
fn truncate_ipc_record_batch_with_nulls() {
fn create_batch() -> RecordBatch {
let schema = Schema::new(vec![
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Utf8, true),
]);

let a = Int32Array::from(vec![Some(1), None, Some(1), None, Some(1)]);
let b = StringArray::from(vec![None, Some("a"), Some("a"), None, Some("a")]);

RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])
.unwrap()
}

let record_batch = create_batch();
let record_batch_slice = record_batch.slice(1, 2);
let deserialized_batch = deserialize(serialize(&record_batch_slice));

assert!(serialize(&record_batch).len() > serialize(&record_batch_slice).len());

assert!(deserialized_batch.column(0).is_null(0));
assert!(deserialized_batch.column(0).is_valid(1));
assert!(deserialized_batch.column(1).is_valid(0));
assert!(deserialized_batch.column(1).is_valid(1));

assert_eq!(record_batch_slice, deserialized_batch);
}

#[test]
fn truncate_ipc_dictionary_array() {
fn create_batch() -> RecordBatch {
let values: StringArray = [Some("foo"), Some("bar"), Some("baz")]
.into_iter()
.collect();
let keys: Int32Array =
[Some(0), Some(2), None, Some(1)].into_iter().collect();

let array = DictionaryArray::<Int32Type>::try_new(&keys, &values).unwrap();

let schema =
Schema::new(vec![Field::new("dict", array.data_type().clone(), true)]);

RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)]).unwrap()
}

let record_batch = create_batch();
let record_batch_slice = record_batch.slice(1, 2);
let deserialized_batch = deserialize(serialize(&record_batch_slice));

assert!(serialize(&record_batch).len() > serialize(&record_batch_slice).len());

assert!(deserialized_batch.column(0).is_valid(0));
assert!(deserialized_batch.column(0).is_null(1));

assert_eq!(record_batch_slice, deserialized_batch);
}
}