Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 128 additions & 38 deletions arrow-ipc/src/tests/delta_dictionary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@ use crate::{
writer::FileWriter,
};
use arrow_array::{
Array, ArrayRef, DictionaryArray, RecordBatch, StringArray, builder::StringDictionaryBuilder,
Array, ArrayRef, DictionaryArray, ListArray, RecordBatch, StringArray, StructArray,
builder::{ArrayBuilder, ListBuilder, StringDictionaryBuilder, StructBuilder},
types::Int32Type,
};

use arrow_schema::{DataType, Field, Schema};
use std::io::Cursor;
use std::sync::Arc;
Expand All @@ -35,7 +37,7 @@ use std::sync::Arc;
fn test_zero_row_dict() {
let batches: &[&[&str]] = &[&[], &["A"], &[], &["B", "C"], &[]];
run_delta_sequence_test(
batches,
&build_batches(batches),
&[
MessageType::Dict(vec![]),
MessageType::RecordBatch,
Expand All @@ -48,7 +50,7 @@ fn test_zero_row_dict() {
);

run_resend_sequence_test(
batches,
&build_batches(batches),
&[
MessageType::Dict(vec![]),
MessageType::RecordBatch,
Expand All @@ -72,7 +74,7 @@ fn test_mixed_delta() {
];

run_delta_sequence_test(
batches,
&build_batches(batches),
&[
MessageType::Dict(str_vec(&["A"])),
MessageType::RecordBatch,
Expand All @@ -87,7 +89,7 @@ fn test_mixed_delta() {
);

run_resend_sequence_test(
batches,
&build_batches(batches),
&[
MessageType::Dict(str_vec(&["A"])),
MessageType::RecordBatch,
Expand All @@ -106,7 +108,7 @@ fn test_mixed_delta() {
fn test_disjoint_delta() {
let batches: &[&[&str]] = &[&["A"], &["B"], &["C", "E"]];
run_delta_sequence_test(
batches,
&build_batches(batches),
&[
MessageType::Dict(str_vec(&["A"])),
MessageType::RecordBatch,
Expand All @@ -118,7 +120,7 @@ fn test_disjoint_delta() {
);

run_resend_sequence_test(
batches,
&build_batches(batches),
&[
MessageType::Dict(str_vec(&["A"])),
MessageType::RecordBatch,
Expand All @@ -134,7 +136,7 @@ fn test_disjoint_delta() {
fn test_increasing_delta() {
let batches: &[&[&str]] = &[&["A"], &["A", "B"], &["A", "B", "C"]];
run_delta_sequence_test(
batches,
&build_batches(batches),
&[
MessageType::Dict(str_vec(&["A"])),
MessageType::RecordBatch,
Expand All @@ -146,7 +148,7 @@ fn test_increasing_delta() {
);

run_resend_sequence_test(
batches,
&build_batches(batches),
&[
MessageType::Dict(str_vec(&["A"])),
MessageType::RecordBatch,
Expand All @@ -162,7 +164,7 @@ fn test_increasing_delta() {
fn test_single_delta() {
let batches: &[&[&str]] = &[&["A", "B", "C"], &["D"]];
run_delta_sequence_test(
batches,
&build_batches(batches),
&[
MessageType::Dict(str_vec(&["A", "B", "C"])),
MessageType::RecordBatch,
Expand All @@ -172,7 +174,7 @@ fn test_single_delta() {
);

run_resend_sequence_test(
batches,
&build_batches(batches),
&[
MessageType::Dict(str_vec(&["A", "B", "C"])),
MessageType::RecordBatch,
Expand All @@ -186,7 +188,7 @@ fn test_single_delta() {
fn test_single_same_value_sequence() {
let batches: &[&[&str]] = &[&["A"], &["A"], &["A"], &["A"]];
run_delta_sequence_test(
batches,
&build_batches(batches),
&[
MessageType::Dict(str_vec(&["A"])),
MessageType::RecordBatch,
Expand All @@ -197,7 +199,7 @@ fn test_single_same_value_sequence() {
);

run_resend_sequence_test(
batches,
&build_batches(batches),
&[
MessageType::Dict(str_vec(&["A"])),
MessageType::RecordBatch,
Expand All @@ -216,7 +218,7 @@ fn str_vec(strings: &[&str]) -> Vec<String> {
fn test_multi_same_value_sequence() {
let batches: &[&[&str]] = &[&["A", "B", "C"], &["A", "B", "C"]];
run_delta_sequence_test(
batches,
&build_batches(batches),
&[
MessageType::Dict(str_vec(&["A", "B", "C"])),
MessageType::RecordBatch,
Expand All @@ -232,17 +234,17 @@ enum MessageType {
RecordBatch,
}

fn run_resend_sequence_test(batches: &[&[&str]], sequence: &[MessageType]) {
fn run_resend_sequence_test(batches: &[RecordBatch], sequence: &[MessageType]) {
let opts = IpcWriteOptions::default().with_dictionary_handling(DictionaryHandling::Resend);
run_sequence_test(batches, sequence, opts);
}

fn run_delta_sequence_test(batches: &[&[&str]], sequence: &[MessageType]) {
fn run_delta_sequence_test(batches: &[RecordBatch], sequence: &[MessageType]) {
let opts = IpcWriteOptions::default().with_dictionary_handling(DictionaryHandling::Delta);
run_sequence_test(batches, sequence, opts);
}

fn run_sequence_test(batches: &[&[&str]], sequence: &[MessageType], options: IpcWriteOptions) {
fn run_sequence_test(batches: &[RecordBatch], sequence: &[MessageType], options: IpcWriteOptions) {
let stream_buf = write_all_to_stream(options.clone(), batches);
let ipc_stream = get_ipc_message_stream(stream_buf);
for (message, expected) in ipc_stream.iter().zip(sequence.iter()) {
Expand Down Expand Up @@ -310,7 +312,7 @@ fn test_replace_same_length() {
&["A", "B", "C", "D", "E", "F"],
&["A", "G", "H", "I", "J", "K"],
];
run_parity_test(batches);
run_parity_test(&build_batches(batches));
}

#[test]
Expand All @@ -323,22 +325,34 @@ fn test_sparse_deltas() {
&["parquet", "B"],
&["123", "B", "C"],
];
run_parity_test(batches);
run_parity_test(&build_batches(batches));
}

#[test]
fn test_deltas_with_reset() {
// Dictionary resets at ["C", "D"]
let batches: &[&[&str]] = &[&["A"], &["A", "B"], &["C", "D"], &["A", "B", "C", "D"]];
run_parity_test(batches);
run_parity_test(&build_batches(batches));
}

/// FileWriter can only tolerate very specific patterns of delta dictionaries,
/// because the dictionary cannot be replaced/reset.
#[test]
fn test_deltas_with_file() {
let batches: &[&[&str]] = &[&["A"], &["A", "B"], &["A", "B", "C"], &["A", "B", "C", "D"]];
run_parity_test(batches);
run_parity_test(&build_batches(batches));
}

#[test]
fn test_deltas_with_in_struct() {
    // Monotonically growing dictionary ("A" .. "A","B","C","D") nested inside a
    // struct column; run_parity_test checks delta-stream, resend-stream and
    // file encodings all decode to the same dictionaries.
    let batches: &[&[&str]] = &[&["A"], &["A", "B"], &["A", "B", "C"], &["A", "B", "C", "D"]];
    run_parity_test(&build_struct_batches(batches));
}

#[test]
fn test_deltas_with_in_list() {
    // Same monotonically growing dictionary as the struct test, but nested
    // inside a list column instead; see run_parity_test for what is compared.
    let batches: &[&[&str]] = &[&["A"], &["A", "B"], &["A", "B", "C"], &["A", "B", "C", "D"]];
    run_parity_test(&build_list_batches(batches));
}

/// Encode all batches three times and compare all three for the same results
Expand All @@ -348,7 +362,7 @@ fn test_deltas_with_file() {
/// - Stream encoding without delta
/// - File encoding with delta (File format does not allow replacement
/// dictionaries)
fn run_parity_test(batches: &[&[&str]]) {
fn run_parity_test(batches: &[RecordBatch]) {
let delta_options =
IpcWriteOptions::default().with_dictionary_handling(DictionaryHandling::Delta);
let delta_stream_buf = write_all_to_stream(delta_options.clone(), batches);
Expand All @@ -368,16 +382,16 @@ fn run_parity_test(batches: &[&[&str]]) {
let (first_stream, other_streams) = streams.split_first_mut().unwrap();

for (idx, batch) in first_stream.by_ref().enumerate() {
let first_dict = extract_dictionary(batch);
let expected_values = batches[idx];
assert_eq!(expected_values, &dict_to_vec(first_dict.clone()));
let first_dict = extract_dictionary(&batch);
let expected_values = dict_to_vec(&extract_dictionary(&batches[idx]));
assert_eq!(expected_values, dict_to_vec(&first_dict));

for stream in other_streams.iter_mut() {
let next_batch = stream
.next()
.expect("All streams should yield same number of elements");
let next_dict = extract_dictionary(next_batch);
assert_eq!(expected_values, &dict_to_vec(next_dict.clone()));
let next_dict = extract_dictionary(&next_batch);
assert_eq!(expected_values, dict_to_vec(&next_dict));
assert_eq!(first_dict, next_dict);
}
}
Expand All @@ -390,7 +404,7 @@ fn run_parity_test(batches: &[&[&str]]) {
}
}

fn dict_to_vec(dict: DictionaryArray<Int32Type>) -> Vec<String> {
fn dict_to_vec(dict: &DictionaryArray<Int32Type>) -> Vec<String> {
dict.downcast_dict::<StringArray>()
.unwrap()
.into_iter()
Expand Down Expand Up @@ -418,35 +432,43 @@ fn get_file_batches(buf: Vec<u8>) -> Box<dyn Iterator<Item = RecordBatch>> {
)
}

fn extract_dictionary(batch: RecordBatch) -> DictionaryArray<arrow_array::types::Int32Type> {
batch
.column(0)
/// Pull the Int32-keyed dictionary array out of `batch`'s first column,
/// unwrapping one level of struct or list nesting if present.
fn extract_dictionary(batch: &RecordBatch) -> DictionaryArray<arrow_array::types::Int32Type> {
    let top = batch.column(0);

    // If the column is a struct, assume its first field holds the dictionary.
    let unwrapped = match top.as_any().downcast_ref::<StructArray>() {
        Some(struct_arr) => struct_arr.column(0),
        None => top,
    };

    // If the (possibly unwrapped) column is a list, the list values are the dictionary.
    let inner = match unwrapped.as_any().downcast_ref::<ListArray>() {
        Some(list_arr) => list_arr.values(),
        None => unwrapped,
    };

    inner
        .as_any()
        .downcast_ref::<DictionaryArray<arrow_array::types::Int32Type>>()
        .unwrap()
        .clone()
}

fn write_all_to_file(options: IpcWriteOptions, vals: &[&[&str]]) -> Vec<u8> {
let batches = build_batches(vals);
fn write_all_to_file(options: IpcWriteOptions, batches: &[RecordBatch]) -> Vec<u8> {
let mut buf: Vec<u8> = Vec::new();
let mut writer =
FileWriter::try_new_with_options(&mut buf, &batches[0].schema(), options).unwrap();
for batch in batches {
writer.write(&batch).unwrap();
writer.write(batch).unwrap();
}
writer.finish().unwrap();
buf
}

fn write_all_to_stream(options: IpcWriteOptions, vals: &[&[&str]]) -> Vec<u8> {
let batches = build_batches(vals);

fn write_all_to_stream(options: IpcWriteOptions, batches: &[RecordBatch]) -> Vec<u8> {
let mut buf: Vec<u8> = Vec::new();
let mut writer =
StreamWriter::try_new_with_options(&mut buf, &batches[0].schema(), options).unwrap();
for batch in batches {
writer.write(&batch).unwrap();
writer.write(batch).unwrap();
}

writer.finish().unwrap();
Expand Down Expand Up @@ -477,3 +499,71 @@ fn build_batch(

RecordBatch::try_new(schema.clone(), vec![Arc::new(array) as ArrayRef]).unwrap()
}

/// Build batches where the dictionary array is nested within a struct array.
/// The dictionary array is the first field within the struct.
fn build_struct_batches(vals: &[&[&str]]) -> Vec<RecordBatch> {
    // One struct row will be appended per input string across all batches.
    let capacity: usize = vals.iter().map(|v| v.len()).sum();

    let dict_field = Field::new(
        "struct",
        DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
        false,
    );

    // A single builder is shared across all batches so the nested dictionary
    // builder's state persists between them (see finish_preserve_values).
    let mut struct_builder = StructBuilder::from_fields(vec![dict_field], capacity);

    let mut batches = Vec::with_capacity(vals.len());
    for batch_vals in vals {
        batches.push(build_struct_batch(batch_vals, &mut struct_builder));
    }
    batches
}

/// Build a single-column `RecordBatch` from `vals`, where the column is a
/// struct whose only field is an Int32-keyed string dictionary.
///
/// The builder is shared across calls; `finish_preserve_values` presumably
/// retains accumulated dictionary values so later batches can be encoded as
/// delta dictionaries — confirm against the builder's docs.
fn build_struct_batch(vals: &[&str], struct_builder: &mut StructBuilder) -> RecordBatch {
    for &val in vals {
        // Append the string to the nested dictionary builder, then mark the
        // enclosing struct row as valid.
        let dict_builder = struct_builder
            .field_builder::<StringDictionaryBuilder<arrow_array::types::Int32Type>>(0)
            .unwrap();
        dict_builder.append_value(val);
        struct_builder.append(true);
    }

    let array = struct_builder.finish_preserve_values();

    // Schema mirrors the struct array's own data type.
    let schema = Arc::new(Schema::new(vec![Field::new(
        "dict",
        array.data_type().clone(),
        true,
    )]));

    // Move `schema` directly — the original cloned the Arc redundantly even
    // though it was never used again (clippy::redundant_clone).
    RecordBatch::try_new(schema, vec![Arc::new(array) as ArrayRef]).unwrap()
}

/// Builds batches where the dictionary array is nested within a list array.
fn build_list_batches(vals: &[&[&str]]) -> Vec<RecordBatch> {
    // One builder shared across all batches; its inner dictionary builder's
    // state carries over between them (see finish_preserve_values).
    let mut list_builder = ListBuilder::new(StringDictionaryBuilder::<Int32Type>::new());

    let mut batches = Vec::with_capacity(vals.len());
    for batch_vals in vals {
        batches.push(build_list_batch(batch_vals, &mut list_builder));
    }
    batches
}

/// Build a single-column `RecordBatch` from `vals`, where the column is a
/// list of Int32-keyed string dictionary values — one single-element list
/// per input string.
///
/// The builder is shared across calls; `finish_preserve_values` presumably
/// retains accumulated dictionary values so later batches can be encoded as
/// delta dictionaries — confirm against the builder's docs.
fn build_list_batch(
    vals: &[&str],
    list_builder: &mut ListBuilder<StringDictionaryBuilder<Int32Type>>,
) -> RecordBatch {
    for &val in vals {
        // Append the string to the inner dictionary builder, then close the
        // current (one-element) list entry.
        list_builder.values().append(val).unwrap();
        list_builder.append(true);
    }

    let array = list_builder.finish_preserve_values();

    // Schema mirrors the list array's own data type.
    let schema = Arc::new(Schema::new(vec![Field::new(
        "dict",
        array.data_type().clone(),
        true,
    )]));

    // Move `schema` directly — the original cloned the Arc redundantly even
    // though it was never used again (clippy::redundant_clone).
    RecordBatch::try_new(schema, vec![Arc::new(array) as ArrayRef]).unwrap()
}
2 changes: 1 addition & 1 deletion arrow-ipc/src/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1351,7 +1351,7 @@ impl<W: Write> RecordBatchWriter for FileWriter<W> {
/// // You must set `.with_dictionary_handling(DictionaryHandling::Delta)` to
/// // enable delta dictionaries in the writer
/// let options = IpcWriteOptions::default().with_dictionary_handling(DictionaryHandling::Delta);
/// let mut writer = StreamWriter::try_new(&mut stream, &schema).unwrap();
/// let mut writer = StreamWriter::try_new_with_options(&mut stream, &schema, options).unwrap();
///
/// // When writing the first batch, a dictionary message with 'a' and 'b' will be written
/// // prior to the record batch.
Expand Down
Loading