Skip to content

Commit

Permalink
Fix reading dictionaries from nested structs in ipc StreamReader (#…
Browse files Browse the repository at this point in the history
…1550)

* Fix reading dictionaries from nested structs in ipc `StreamReader`

* Fix clippy error

* Apply review comment about field naming in test
  • Loading branch information
dispanser committed Apr 13, 2022
1 parent c9549bb commit ffb9b0b
Showing 1 changed file with 46 additions and 1 deletion.
47 changes: 46 additions & 1 deletion arrow/src/ipc/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -881,7 +881,7 @@ impl<R: Read> StreamReader<R> {
let schema = ipc::convert::fb_to_schema(ipc_schema);

// Create an array of optional dictionary value arrays, one per field.
let dictionaries_by_field = vec![None; schema.fields().len()];
let dictionaries_by_field = vec![None; schema.all_fields().len()];

let projection = match projection {
Some(projection_indices) => {
Expand Down Expand Up @@ -1317,6 +1317,19 @@ mod tests {
reader.next().unwrap().unwrap()
}

fn roundtrip_ipc_stream(rb: &RecordBatch) -> RecordBatch {
let mut buf = Vec::new();
let mut writer =
ipc::writer::StreamWriter::try_new(&mut buf, &rb.schema()).unwrap();
writer.write(rb).unwrap();
writer.finish().unwrap();
drop(writer);

let mut reader =
ipc::reader::StreamReader::try_new(std::io::Cursor::new(buf), None).unwrap();
reader.next().unwrap().unwrap()
}

#[test]
fn test_roundtrip_nested_dict() {
let inner: DictionaryArray<datatypes::Int32Type> =
Expand Down Expand Up @@ -1394,4 +1407,36 @@ mod tests {
let arrow_json: ArrowJson = serde_json::from_str(&s).unwrap();
arrow_json
}

#[test]
fn test_roundtrip_stream_nested_dict() {
let xs = vec!["AA", "BB", "AA", "CC", "BB"];
let dict = Arc::new(
xs.clone()
.into_iter()
.collect::<DictionaryArray<datatypes::Int8Type>>(),
);
let string_array: ArrayRef = Arc::new(StringArray::from(xs.clone()));
let struct_array = StructArray::from(vec![
(Field::new("f2.1", DataType::Utf8, false), string_array),
(
Field::new("f2.2_struct", dict.data_type().clone(), false),
dict.clone() as ArrayRef,
),
]);
let schema = Arc::new(Schema::new(vec![
Field::new("f1_string", DataType::Utf8, false),
Field::new("f2_struct", struct_array.data_type().clone(), false),
]));
let input_batch = RecordBatch::try_new(
schema,
vec![
Arc::new(StringArray::from(xs.clone())),
Arc::new(struct_array),
],
)
.unwrap();
let output_batch = roundtrip_ipc_stream(&input_batch);
assert_eq!(input_batch, output_batch);
}
}

0 comments on commit ffb9b0b

Please sign in to comment.