Skip to content

Commit

Permalink
feat(ipc): Support writing dictionaries nested in structs and unions (#…
Browse files Browse the repository at this point in the history
…870) (#915)

* feat(ipc): Support for writing dictionaries nested in structs and unions

Dictionaries nested inside other types are lost when serializing a
RecordBatch for IPC, producing invalid Arrow data. This PR changes
encoded_batch to recursively find all dictionary fields within the schema
(currently only inside structs and unions) so that nested dictionaries are
properly serialized.

* address lint and clippy

Co-authored-by: Helgi Kristvin Sigurbjarnarson <helgikrs@gmail.com>
  • Loading branch information
alamb and helgikrs committed Nov 5, 2021
1 parent 03d95e6 commit 8540214
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 14 deletions.
1 change: 1 addition & 0 deletions arrow/src/array/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,4 @@ array_downcast_fn!(as_largestring_array, LargeStringArray);
array_downcast_fn!(as_boolean_array, BooleanArray);
array_downcast_fn!(as_null_array, NullArray);
array_downcast_fn!(as_struct_array, StructArray);
array_downcast_fn!(as_union_array, UnionArray);
2 changes: 1 addition & 1 deletion arrow/src/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ pub use self::ord::{build_compare, DynComparator};
pub use self::cast::{
as_boolean_array, as_dictionary_array, as_generic_binary_array,
as_generic_list_array, as_large_list_array, as_largestring_array, as_list_array,
as_null_array, as_primitive_array, as_string_array, as_struct_array,
as_null_array, as_primitive_array, as_string_array, as_struct_array, as_union_array,
};

// ------------------------------ C Data Interface ---------------------------
Expand Down
138 changes: 125 additions & 13 deletions arrow/src/ipc/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use std::io::{BufWriter, Write};

use flatbuffers::FlatBufferBuilder;

use crate::array::{ArrayData, ArrayRef};
use crate::array::{as_struct_array, as_union_array, ArrayData, ArrayRef};
use crate::buffer::{Buffer, MutableBuffer};
use crate::datatypes::*;
use crate::error::{ArrowError, Result};
Expand Down Expand Up @@ -137,20 +137,45 @@ impl IpcDataGenerator {
}
}

pub fn encoded_batch(
fn encode_dictionaries(
&self,
batch: &RecordBatch,
field: &Field,
column: &ArrayRef,
encoded_dictionaries: &mut Vec<EncodedData>,
dictionary_tracker: &mut DictionaryTracker,
write_options: &IpcWriteOptions,
) -> Result<(Vec<EncodedData>, EncodedData)> {
// TODO: handle nested dictionaries
let schema = batch.schema();
let mut encoded_dictionaries = Vec::with_capacity(schema.fields().len());

for (i, field) in schema.fields().iter().enumerate() {
let column = batch.column(i);

if let DataType::Dictionary(_key_type, _value_type) = column.data_type() {
) -> Result<()> {
// TODO: Handle other nested types (map, list, etc)
match column.data_type() {
DataType::Struct(fields) => {
let s = as_struct_array(column);
for (field, &column) in fields.iter().zip(s.columns().iter()) {
self.encode_dictionaries(
field,
column,
encoded_dictionaries,
dictionary_tracker,
write_options,
)?;
}
}
DataType::Union(fields) => {
let union = as_union_array(column);
for (field, ref column) in fields
.iter()
.enumerate()
.map(|(n, f)| (f, union.child(n as i8)))
{
self.encode_dictionaries(
field,
column,
encoded_dictionaries,
dictionary_tracker,
write_options,
)?;
}
}
DataType::Dictionary(_key_type, _value_type) => {
let dict_id = field
.dict_id()
.expect("All Dictionary types have `dict_id`");
Expand All @@ -167,10 +192,33 @@ impl IpcDataGenerator {
));
}
}
_ => (),
}

let encoded_message = self.record_batch_to_bytes(batch, write_options);
Ok(())
}

/// Encodes `batch` for IPC, returning the dictionary messages collected from
/// its columns together with the encoded record batch message.
///
/// Each top-level schema field is paired with the column at the same index
/// and passed to `encode_dictionaries`, which receives the dictionary list
/// and `dictionary_tracker` mutably. The first error it reports is returned
/// to the caller unchanged.
pub fn encoded_batch(
    &self,
    batch: &RecordBatch,
    dictionary_tracker: &mut DictionaryTracker,
    write_options: &IpcWriteOptions,
) -> Result<(Vec<EncodedData>, EncodedData)> {
    let schema = batch.schema();
    let fields = schema.fields();
    // Reserve one slot per top-level field, as a starting capacity.
    let mut dictionaries = Vec::with_capacity(fields.len());

    for (index, field) in fields.iter().enumerate() {
        self.encode_dictionaries(
            field,
            batch.column(index),
            &mut dictionaries,
            dictionary_tracker,
            write_options,
        )?;
    }

    // The batch itself is encoded only after all dictionaries were visited.
    let message = self.record_batch_to_bytes(batch, write_options);
    Ok((dictionaries, message))
}

Expand Down Expand Up @@ -1161,4 +1209,68 @@ mod tests {
let arrow_json: ArrowJson = serde_json::from_str(&s).unwrap();
arrow_json
}

/// A dictionary array nested inside a union column must be recorded by the
/// dictionary tracker when the batch is encoded.
#[test]
fn track_union_nested_dict() {
    // Id given to the nested dictionary field; checked against the tracker below.
    let dict_id = 2;

    let values: ArrayRef = Arc::new(
        vec!["a", "b", "a"]
            .into_iter()
            .collect::<DictionaryArray<Int32Type>>(),
    );
    let dict_field =
        Field::new_dict("dict", values.data_type().clone(), false, dict_id, false);

    // Three union slots, all of type id 0, pointing at child offsets 0, 1, 2.
    let type_ids = Buffer::from_slice_ref(&[0_i8, 0, 0]);
    let value_offsets = Buffer::from_slice_ref(&[0_i32, 1, 2]);
    let union_array = UnionArray::try_new(
        type_ids,
        Some(value_offsets),
        vec![(dict_field, values)],
        None,
    )
    .unwrap();

    let union_field = Field::new("union", union_array.data_type().clone(), false);
    let schema = Arc::new(Schema::new(vec![union_field]));
    let batch = RecordBatch::try_new(schema, vec![Arc::new(union_array)]).unwrap();

    let generator = IpcDataGenerator {};
    let mut tracker = DictionaryTracker::new(false);
    generator
        .encoded_batch(&batch, &mut tracker, &Default::default())
        .unwrap();

    // The nested dictionary must have been written to the dict tracker.
    assert!(tracker.written.contains_key(&dict_id));
}

/// A dictionary array nested inside a struct column must be recorded by the
/// dictionary tracker when the batch is encoded.
#[test]
fn track_struct_nested_dict() {
    // Id given to the nested dictionary field; checked against the tracker below.
    let dict_id = 2;

    let values: ArrayRef = Arc::new(
        vec!["a", "b", "a"]
            .into_iter()
            .collect::<DictionaryArray<Int32Type>>(),
    );
    let dict_field =
        Field::new_dict("dict", values.data_type().clone(), false, dict_id, false);

    // Single-child struct wrapping the dictionary column.
    let struct_column: ArrayRef =
        Arc::new(StructArray::from(vec![(dict_field, values)]));

    let struct_field = Field::new("struct", struct_column.data_type().clone(), false);
    let schema = Arc::new(Schema::new(vec![struct_field]));
    let batch = RecordBatch::try_new(schema, vec![struct_column]).unwrap();

    let generator = IpcDataGenerator {};
    let mut tracker = DictionaryTracker::new(false);
    generator
        .encoded_batch(&batch, &mut tracker, &Default::default())
        .unwrap();

    // The nested dictionary must have been written to the dict tracker.
    assert!(tracker.written.contains_key(&dict_id));
}
}

0 comments on commit 8540214

Please sign in to comment.