diff --git a/arrow/src/ipc/reader.rs b/arrow/src/ipc/reader.rs index 2a6619ec833..99c2dccb400 100644 --- a/arrow/src/ipc/reader.rs +++ b/arrow/src/ipc/reader.rs @@ -275,6 +275,120 @@ fn create_array( Ok((array, node_index, buffer_index)) } +/// Skip fields based on data types to advance `node_index` and `buffer_index`. +/// This function should be called when doing projection in fn `read_record_batch`. +/// The advancement logic references fn `create_array`. +fn skip_field( + nodes: &[ipc::FieldNode], + field: &Field, + data: &[u8], + buffers: &[ipc::Buffer], + dictionaries_by_id: &HashMap, + mut node_index: usize, + mut buffer_index: usize, +) -> Result<(usize, usize)> { + use DataType::*; + let data_type = field.data_type(); + match data_type { + Utf8 | Binary | LargeBinary | LargeUtf8 => { + node_index += 1; + buffer_index += 3; + } + FixedSizeBinary(_) => { + node_index += 1; + buffer_index += 2; + } + List(ref list_field) | LargeList(ref list_field) | Map(ref list_field, _) => { + node_index += 1; + buffer_index += 2; + let tuple = skip_field( + nodes, + list_field, + data, + buffers, + dictionaries_by_id, + node_index, + buffer_index, + )?; + node_index = tuple.0; + buffer_index = tuple.1; + } + FixedSizeList(ref list_field, _) => { + node_index += 1; + buffer_index += 1; + let tuple = skip_field( + nodes, + list_field, + data, + buffers, + dictionaries_by_id, + node_index, + buffer_index, + )?; + node_index = tuple.0; + buffer_index = tuple.1; + } + Struct(struct_fields) => { + node_index += 1; + buffer_index += 1; + + // skip for each field + for struct_field in struct_fields { + let tuple = skip_field( + nodes, + struct_field, + data, + buffers, + dictionaries_by_id, + node_index, + buffer_index, + )?; + node_index = tuple.0; + buffer_index = tuple.1; + } + } + Dictionary(_, _) => { + node_index += 1; + buffer_index += 2; + } + Union(fields, _field_type_ids, mode) => { + node_index += 1; + buffer_index += 1; + + match mode { + UnionMode::Dense => { + buffer_index += 1; + } + UnionMode::Sparse => {} + }; + + for field in fields { + let tuple = skip_field( + nodes, + field, + data, + buffers, + dictionaries_by_id, + node_index, + buffer_index, + )?; + + node_index = tuple.0; + buffer_index = tuple.1; + } + } + Null => { + node_index += 1; + // no buffer increases + } + _ => { + node_index += 1; + buffer_index += 2; + } + }; + Ok((node_index, buffer_index)) +} + /// Reads the correct number of buffers based on data type and null_count, and creates a /// primitive array ref fn create_primitive_array( @@ -493,21 +607,37 @@ pub fn read_record_batch( let mut arrays = vec![]; if let Some(projection) = projection { - let fields = schema.fields(); - for &index in projection { - let field = &fields[index]; - let triple = create_array( - field_nodes, - field, - buf, - buffers, - dictionaries_by_id, - node_index, - buffer_index, - )?; - node_index = triple.1; - buffer_index = triple.2; - arrays.push(triple.0); + // project fields + for (idx, field) in schema.fields().iter().enumerate() { + // Create array for projected field + if projection.contains(&idx) { + let triple = create_array( + field_nodes, + field, + buf, + buffers, + dictionaries_by_id, + node_index, + buffer_index, + )?; + node_index = triple.1; + buffer_index = triple.2; + arrays.push(triple.0); + } else { + // Skip field. + // This must be called to advance `node_index` and `buffer_index`. + let tuple = skip_field( + field_nodes, + field, + buf, + buffers, + dictionaries_by_id, + node_index, + buffer_index, + )?; + node_index = tuple.0; + buffer_index = tuple.1; + } } RecordBatch::try_new(Arc::new(schema.project(projection)?), arrays) @@ -1032,7 +1162,7 @@ mod tests { use flate2::read::GzDecoder; - use crate::datatypes::{ArrowNativeType, Int8Type}; + use crate::datatypes::{ArrowNativeType, Float64Type, Int32Type, Int8Type}; use crate::{datatypes, util::integration_util::*}; #[test] @@ -1260,6 +1390,169 @@ mod tests { }); } + fn create_test_projection_schema() -> Schema { + // define field types + let list_data_type = + DataType::List(Box::new(Field::new("item", DataType::Int32, true))); + + let fixed_size_list_data_type = DataType::FixedSizeList( + Box::new(Field::new("item", DataType::Int32, false)), + 3, + ); + + let key_type = DataType::Int8; + let value_type = DataType::Utf8; + let dict_data_type = + DataType::Dictionary(Box::new(key_type), Box::new(value_type)); + + let union_fileds = vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Float64, false), + ]; + let union_data_type = DataType::Union(union_fileds, vec![0, 1], UnionMode::Dense); + + let struct_fields = vec![ + Field::new("id", DataType::Int32, false), + Field::new( + "list", + DataType::List(Box::new(Field::new("item", DataType::Int8, true))), + false, + ), + ]; + let struct_data_type = DataType::Struct(struct_fields); + + // define schema + Schema::new(vec![ + Field::new("f0", DataType::UInt32, false), + Field::new("f1", DataType::Utf8, false), + Field::new("f2", DataType::Boolean, false), + Field::new("f3", union_data_type, true), + Field::new("f4", DataType::Null, true), + Field::new("f5", DataType::Float64, true), + Field::new("f6", list_data_type, false), + Field::new("f7", DataType::FixedSizeBinary(3), true), + Field::new("f8", fixed_size_list_data_type, false), + Field::new("f9", struct_data_type, false), + Field::new("f10", DataType::Boolean, false), + Field::new("f11", dict_data_type, false), + Field::new("f12", DataType::Utf8, false), + ]) + } + + fn create_test_projection_batch_data(schema: &Schema) -> RecordBatch { + // set test data for each column + let array0 = UInt32Array::from(vec![1, 2, 3]); + let array1 = StringArray::from(vec!["foo", "bar", "baz"]); + let array2 = BooleanArray::from(vec![true, false, true]); + + let mut union_builder = UnionBuilder::new_dense(3); + union_builder.append::("a", 1).unwrap(); + union_builder.append::("b", 10.1).unwrap(); + union_builder.append_null::("b").unwrap(); + let array3 = union_builder.build().unwrap(); + + let array4 = NullArray::new(3); + let array5 = Float64Array::from(vec![Some(1.1), None, Some(3.3)]); + let array6_values = vec![ + Some(vec![Some(10), Some(10), Some(10)]), + Some(vec![Some(20), Some(20), Some(20)]), + Some(vec![Some(30), Some(30)]), + ]; + let array6 = ListArray::from_iter_primitive::(array6_values); + let array7_values = vec![vec![11, 12, 13], vec![22, 23, 24], vec![33, 34, 35]]; + let array7 = + FixedSizeBinaryArray::try_from_iter(array7_values.into_iter()).unwrap(); + + let array8_values = ArrayData::builder(DataType::Int32) + .len(9) + .add_buffer(Buffer::from_slice_ref(&[ + 40, 41, 42, 43, 44, 45, 46, 47, 48, + ])) + .build() + .unwrap(); + let array8_data = ArrayData::builder(schema.field(8).data_type().clone()) + .len(3) + .add_child_data(array8_values) + .build() + .unwrap(); + let array8 = FixedSizeListArray::from(array8_data); + + let array9_id: ArrayRef = Arc::new(Int32Array::from(vec![1001, 1002, 1003])); + let array9_list: ArrayRef = + Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![Some(-10)]), + Some(vec![Some(-20), Some(-20), Some(-20)]), + Some(vec![Some(-30)]), + ])); + let array9 = ArrayDataBuilder::new(schema.field(9).data_type().clone()) + .add_child_data(array9_id.data().clone()) + .add_child_data(array9_list.data().clone()) + .len(3) + .build() + .unwrap(); + let array9: ArrayRef = Arc::new(StructArray::from(array9)); + + let array10 = BooleanArray::from(vec![false, false, true]); + + let array11_values = StringArray::from(vec!["x", "yy", "zzz"]); + let array11_keys = Int8Array::from_iter_values([1, 1, 2]); + let array11 = + DictionaryArray::::try_new(&array11_keys, &array11_values).unwrap(); + + let array12 = StringArray::from(vec!["a", "bb", "ccc"]); + + // create record batch + RecordBatch::try_new( + Arc::new(schema.clone()), + vec![ + Arc::new(array0), + Arc::new(array1), + Arc::new(array2), + Arc::new(array3), + Arc::new(array4), + Arc::new(array5), + Arc::new(array6), + Arc::new(array7), + Arc::new(array8), + Arc::new(array9), + Arc::new(array10), + Arc::new(array11), + Arc::new(array12), + ], + ) + .unwrap() + } + + #[test] + fn test_projection_array_values() { + // define schema + let schema = create_test_projection_schema(); + + // create record batch with test data + let batch = create_test_projection_batch_data(&schema); + + // write record batch in IPC format + let mut buf = Vec::new(); + { + let mut writer = ipc::writer::FileWriter::try_new(&mut buf, &schema).unwrap(); + writer.write(&batch).unwrap(); + writer.finish().unwrap(); + } + + // read record batch with projection + for index in 0..12 { + let projection = vec![index]; + let reader = + FileReader::try_new(std::io::Cursor::new(buf.clone()), Some(projection)); + let read_batch = reader.unwrap().next().unwrap().unwrap(); + let projected_column = read_batch.column(0); + let expected_column = batch.column(index); + + // check the projected column equals the expected column + assert_eq!(projected_column.as_ref(), expected_column.as_ref()); + } + } + #[test] fn test_arrow_single_float_row() { let schema = Schema::new(vec![