diff --git a/arrow/src/array/array.rs b/arrow/src/array/array.rs index d715bc49eb61..4702179c7839 100644 --- a/arrow/src/array/array.rs +++ b/arrow/src/array/array.rs @@ -296,6 +296,7 @@ pub fn make_array(data: ArrayData) -> ArrayRef { DataType::List(_) => Arc::new(ListArray::from(data)) as ArrayRef, DataType::LargeList(_) => Arc::new(LargeListArray::from(data)) as ArrayRef, DataType::Struct(_) => Arc::new(StructArray::from(data)) as ArrayRef, + DataType::Map(_, _) => Arc::new(MapArray::from(data)) as ArrayRef, DataType::Union(_) => Arc::new(UnionArray::from(data)) as ArrayRef, DataType::FixedSizeList(_, _) => { Arc::new(FixedSizeListArray::from(data)) as ArrayRef @@ -452,6 +453,9 @@ pub fn new_null_array(data_type: &DataType, length: usize) -> ArrayRef { .map(|field| ArrayData::new_empty(field.data_type())) .collect(), )), + DataType::Map(field, _keys_sorted) => { + new_null_list_array::(data_type, field.data_type(), length) + } DataType::Union(_) => { unimplemented!("Creating null Union array not yet supported") } @@ -657,6 +661,28 @@ mod tests { } } + #[test] + fn test_null_map() { + let data_type = DataType::Map( + Box::new(Field::new( + "entry", + DataType::Struct(vec![ + Field::new("key", DataType::Utf8, false), + Field::new("key", DataType::Int32, true), + ]), + false, + )), + false, + ); + let array = new_null_array(&data_type, 9); + let a = array.as_any().downcast_ref::().unwrap(); + assert_eq!(a.len(), 9); + assert_eq!(a.value_offsets()[9], 0i32); + for i in 0..9 { + assert!(a.is_null(i)); + } + } + #[test] fn test_null_dictionary() { let values = vec![None, None, None, None, None, None, None, None, None] diff --git a/arrow/src/array/array_map.rs b/arrow/src/array/array_map.rs new file mode 100644 index 000000000000..b10c39e43b01 --- /dev/null +++ b/arrow/src/array/array_map.rs @@ -0,0 +1,421 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::fmt; +use std::mem; + +use super::make_array; +use super::{ + array::print_long_array, raw_pointer::RawPtrBox, Array, ArrayData, ArrayRef, +}; +use crate::datatypes::{ArrowNativeType, DataType}; +use crate::error::ArrowError; + +/// A nested array type where each record is a key-value map. +/// Keys should always be non-null, but values can be null. +/// +/// [MapArray] is physically a [ListArray] that has a [StructArray] +/// with 2 child fields. +pub struct MapArray { + data: ArrayData, + values: ArrayRef, + value_offsets: RawPtrBox, +} + +impl MapArray { + /// Returns a reference to the keys of this map. + pub fn keys(&self) -> ArrayRef { + make_array(self.values.data().child_data()[0].clone()) + } + + /// Returns a reference to the values of this map. + pub fn values(&self) -> ArrayRef { + make_array(self.values.data().child_data()[1].clone()) + } + + /// Returns the data type of the map's keys. + pub fn key_type(&self) -> DataType { + self.values.data().child_data()[0].data_type().clone() + } + + /// Returns the data type of the map's values. + pub fn value_type(&self) -> DataType { + self.values.data().child_data()[1].data_type().clone() + } + + /// Returns ith value of this map array. 
+ /// # Safety + /// Caller must ensure that the index is within the array bounds + pub unsafe fn value_unchecked(&self, i: usize) -> ArrayRef { + let end = *self.value_offsets().get_unchecked(i + 1); + let start = *self.value_offsets().get_unchecked(i); + self.values + .slice(start.to_usize().unwrap(), (end - start).to_usize().unwrap()) + } + + /// Returns ith value of this map array. + pub fn value(&self, i: usize) -> ArrayRef { + let end = self.value_offsets()[i + 1] as usize; + let start = self.value_offsets()[i] as usize; + self.values.slice(start, end - start) + } + + /// Returns the offset values in the offsets buffer + #[inline] + pub fn value_offsets(&self) -> &[i32] { + // Soundness + // pointer alignment & location is ensured by RawPtrBox + // buffer bounds/offset is ensured by the ArrayData instance. + unsafe { + std::slice::from_raw_parts( + self.value_offsets.as_ptr().add(self.data.offset()), + self.len() + 1, + ) + } + } + + /// Returns the length for value at index `i`. + #[inline] + pub fn value_length(&self, i: usize) -> i32 { + let offsets = self.value_offsets(); + offsets[i + 1] - offsets[i] + } +} + +impl From for MapArray { + fn from(data: ArrayData) -> Self { + Self::try_new_from_array_data(data) + .expect("Expected infallable creation of MapArray from ArrayData failed") + } +} + +impl MapArray { + fn try_new_from_array_data(data: ArrayData) -> Result { + if data.buffers().len() != 1 { + return Err(ArrowError::InvalidArgumentError( + format!("MapArray data should contain a single buffer only (value offsets), had {}", + data.len()))); + } + + if data.child_data().len() != 1 { + return Err(ArrowError::InvalidArgumentError(format!( + "MapArray should contain a single child array (values array), had {}", + data.child_data().len() + ))); + } + + let entries = data.child_data()[0].clone(); + + if let DataType::Struct(fields) = entries.data_type() { + if fields.len() != 2 { + return Err(ArrowError::InvalidArgumentError(format!( + "MapArray should 
contain a struct array with 2 fields, have {} fields", + fields.len() + ))); + } + } else { + return Err(ArrowError::InvalidArgumentError(format!( + "MapArray should contain a struct array child, found {:?}", + entries.data_type() + ))); + } + + let values = make_array(entries); + let value_offsets = data.buffers()[0].as_ptr(); + + let value_offsets = unsafe { RawPtrBox::::new(value_offsets) }; + unsafe { + if (*value_offsets.as_ptr().offset(0)) != 0 { + return Err(ArrowError::InvalidArgumentError(String::from( + "offsets do not start at zero", + ))); + } + } + Ok(Self { + data, + values, + value_offsets, + }) + } +} + +impl Array for MapArray { + fn as_any(&self) -> &dyn Any { + self + } + + fn data(&self) -> &ArrayData { + &self.data + } + + /// Returns the total number of bytes of memory occupied by the buffers owned by this [MapArray]. + fn get_buffer_memory_size(&self) -> usize { + self.data.get_buffer_memory_size() + } + + /// Returns the total number of bytes of memory occupied physically by this [MapArray]. 
+ fn get_array_memory_size(&self) -> usize { + self.data.get_array_memory_size() + mem::size_of_val(self) + } +} + +impl fmt::Debug for MapArray { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "MapArray\n[\n")?; + print_long_array(self, f, |array, index, f| { + fmt::Debug::fmt(&array.value(index), f) + })?; + write!(f, "]") + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use crate::{ + array::ArrayData, + array::{Int32Array, StructArray, UInt32Array}, + buffer::Buffer, + datatypes::Field, + datatypes::ToByteSlice, + }; + + use super::*; + + fn create_from_buffers() -> MapArray { + // Construct key and values + let keys_data = ArrayData::builder(DataType::Int32) + .len(8) + .add_buffer(Buffer::from(&[0, 1, 2, 3, 4, 5, 6, 7].to_byte_slice())) + .build(); + let values_data = ArrayData::builder(DataType::UInt32) + .len(8) + .add_buffer(Buffer::from( + &[0u32, 10, 20, 30, 40, 50, 60, 70].to_byte_slice(), + )) + .build(); + + // Construct a buffer for value offsets, for the nested array: + // [[0, 1, 2], [3, 4, 5], [6, 7]] + let entry_offsets = Buffer::from(&[0, 3, 6, 8].to_byte_slice()); + + let keys = Field::new("keys", DataType::Int32, false); + let values = Field::new("values", DataType::UInt32, false); + let entry_struct = StructArray::from(vec![ + (keys, make_array(keys_data)), + (values, make_array(values_data)), + ]); + + // Construct a map array from the above two + let map_data_type = DataType::Map( + Box::new(Field::new( + "entries", + entry_struct.data_type().clone(), + true, + )), + false, + ); + let map_data = ArrayData::builder(map_data_type) + .len(3) + .add_buffer(entry_offsets) + .add_child_data(entry_struct.data().clone()) + .build(); + MapArray::from(map_data) + } + + #[test] + fn test_map_array() { + // Construct key and values + let key_data = ArrayData::builder(DataType::Int32) + .len(8) + .add_buffer(Buffer::from(&[0, 1, 2, 3, 4, 5, 6, 7].to_byte_slice())) + .build(); + let value_data = 
ArrayData::builder(DataType::UInt32) + .len(8) + .add_buffer(Buffer::from( + &[0u32, 10, 20, 0, 40, 0, 60, 70].to_byte_slice(), + )) + .null_bit_buffer(Buffer::from(&[0b11010110])) + .build(); + + // Construct a buffer for value offsets, for the nested array: + // [[0, 1, 2], [3, 4, 5], [6, 7]] + let entry_offsets = Buffer::from(&[0, 3, 6, 8].to_byte_slice()); + + let keys_field = Field::new("keys", DataType::Int32, false); + let values_field = Field::new("values", DataType::UInt32, true); + let entry_struct = StructArray::from(vec![ + (keys_field.clone(), make_array(key_data)), + (values_field.clone(), make_array(value_data.clone())), + ]); + + // Construct a map array from the above two + let map_data_type = DataType::Map( + Box::new(Field::new( + "entries", + entry_struct.data_type().clone(), + true, + )), + false, + ); + let map_data = ArrayData::builder(map_data_type) + .len(3) + .add_buffer(entry_offsets) + .add_child_data(entry_struct.data().clone()) + .build(); + let map_array = MapArray::from(map_data); + + let values = map_array.values(); + assert_eq!(&value_data, values.data()); + assert_eq!(DataType::UInt32, map_array.value_type()); + assert_eq!(3, map_array.len()); + assert_eq!(0, map_array.null_count()); + assert_eq!(6, map_array.value_offsets()[2]); + assert_eq!(2, map_array.value_length(2)); + + let key_array = Arc::new(Int32Array::from(vec![0, 1, 2])) as ArrayRef; + let value_array = + Arc::new(UInt32Array::from(vec![None, Some(10u32), Some(20)])) as ArrayRef; + let struct_array = StructArray::from(vec![ + (keys_field.clone(), key_array), + (values_field.clone(), value_array), + ]); + assert_eq!( + struct_array, + StructArray::from(map_array.value(0).data().clone()) + ); + assert_eq!( + &struct_array, + unsafe { map_array.value_unchecked(0) } + .as_any() + .downcast_ref::() + .unwrap() + ); + for i in 0..3 { + assert!(map_array.is_valid(i)); + assert!(!map_array.is_null(i)); + } + + // Now test with a non-zero offset + let map_data = 
ArrayData::builder(map_array.data_type().clone()) + .len(3) + .offset(1) + .add_buffer(map_array.data().buffers()[0].clone()) + .add_child_data(map_array.data().child_data()[0].clone()) + .build(); + let map_array = MapArray::from(map_data); + + let values = map_array.values(); + assert_eq!(&value_data, values.data()); + assert_eq!(DataType::UInt32, map_array.value_type()); + assert_eq!(3, map_array.len()); + assert_eq!(0, map_array.null_count()); + assert_eq!(6, map_array.value_offsets()[1]); + assert_eq!(2, map_array.value_length(1)); + + let key_array = Arc::new(Int32Array::from(vec![3, 4, 5])) as ArrayRef; + let value_array = + Arc::new(UInt32Array::from(vec![None, Some(40), None])) as ArrayRef; + let struct_array = + StructArray::from(vec![(keys_field, key_array), (values_field, value_array)]); + assert_eq!( + &struct_array, + map_array + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + ); + assert_eq!( + &struct_array, + unsafe { map_array.value_unchecked(0) } + .as_any() + .downcast_ref::() + .unwrap() + ); + } + + #[test] + #[ignore = "Test fails because slice of > is still buggy"] + fn test_map_array_slice() { + let map_array = create_from_buffers(); + + let sliced_array = map_array.slice(1, 2); + assert_eq!(2, sliced_array.len()); + assert_eq!(1, sliced_array.offset()); + let sliced_array_data = sliced_array.data(); + for array_data in sliced_array_data.child_data() { + assert_eq!(array_data.offset(), 1); + } + + // Check offset and length for each non-null value. 
+ let sliced_map_array = sliced_array.as_any().downcast_ref::().unwrap(); + assert_eq!(3, sliced_map_array.value_offsets()[0]); + assert_eq!(3, sliced_map_array.value_length(0)); + assert_eq!(6, sliced_map_array.value_offsets()[1]); + assert_eq!(2, sliced_map_array.value_length(1)); + + // Construct key and values + let keys_data = ArrayData::builder(DataType::Int32) + .len(5) + .add_buffer(Buffer::from(&[3, 4, 5, 6, 7].to_byte_slice())) + .build(); + let values_data = ArrayData::builder(DataType::UInt32) + .len(5) + .add_buffer(Buffer::from(&[30u32, 40, 50, 60, 70].to_byte_slice())) + .build(); + + // Construct a buffer for value offsets, for the nested array: + // [[3, 4, 5], [6, 7]] + let entry_offsets = Buffer::from(&[0, 3, 5].to_byte_slice()); + + let keys = Field::new("keys", DataType::Int32, false); + let values = Field::new("values", DataType::UInt32, false); + let entry_struct = StructArray::from(vec![ + (keys, make_array(keys_data)), + (values, make_array(values_data)), + ]); + + // Construct a map array from the above two + let map_data_type = DataType::Map( + Box::new(Field::new( + "entries", + entry_struct.data_type().clone(), + true, + )), + false, + ); + let expected_map_data = ArrayData::builder(map_data_type) + .len(2) + .add_buffer(entry_offsets) + .add_child_data(entry_struct.data().clone()) + .build(); + let expected_map_array = MapArray::from(expected_map_data); + + assert_eq!(&expected_map_array, sliced_map_array) + } + + #[test] + #[should_panic(expected = "index out of bounds: the len is ")] + fn test_map_array_index_out_of_bound() { + let map_array = create_from_buffers(); + + map_array.value(map_array.len()); + } +} diff --git a/arrow/src/array/builder.rs b/arrow/src/array/builder.rs index 8f3f7305790d..fc0a5c807df1 100644 --- a/arrow/src/array/builder.rs +++ b/arrow/src/array/builder.rs @@ -1594,6 +1594,163 @@ impl StructBuilder { } } +#[derive(Debug)] +pub struct MapBuilder { + offsets_builder: BufferBuilder, + bitmap_builder: 
BooleanBufferBuilder, + field_names: MapFieldNames, + key_builder: K, + value_builder: V, + len: i32, +} + +#[derive(Debug, Clone)] +pub struct MapFieldNames { + pub entry: String, + pub key: String, + pub value: String, +} + +impl Default for MapFieldNames { + fn default() -> Self { + Self { + entry: "entries".to_string(), + key: "keys".to_string(), + value: "values".to_string(), + } + } +} + +impl MapBuilder { + pub fn new( + field_names: Option, + key_builder: K, + value_builder: V, + ) -> Self { + let capacity = key_builder.len(); + Self::with_capacity(field_names, key_builder, value_builder, capacity) + } + + pub fn with_capacity( + field_names: Option, + key_builder: K, + value_builder: V, + capacity: usize, + ) -> Self { + let mut offsets_builder = BufferBuilder::::new(capacity + 1); + let len = 0; + offsets_builder.append(len); + Self { + offsets_builder, + bitmap_builder: BooleanBufferBuilder::new(capacity), + field_names: field_names.unwrap_or_default(), + key_builder, + value_builder, + len, + } + } + + pub fn keys(&mut self) -> &mut K { + &mut self.key_builder + } + + pub fn values(&mut self) -> &mut V { + &mut self.value_builder + } + + /// Finish the current map array slot + #[inline] + pub fn append(&mut self, is_valid: bool) -> Result<()> { + if self.key_builder.len() != self.value_builder.len() { + return Err(ArrowError::InvalidArgumentError(format!( + "Cannot append to a map builder when its keys and values have unequal lengths of {} and {}", + self.key_builder.len(), + self.value_builder.len() + ))); + } + self.offsets_builder.append(self.key_builder.len() as i32); + self.bitmap_builder.append(is_valid); + self.len += 1; + Ok(()) + } + + pub fn finish(&mut self) -> MapArray { + let len = self.len(); + self.len = 0; + + // Build the keys + let keys_arr = self + .key_builder + .as_any_mut() + .downcast_mut::() + .unwrap() + .finish(); + let values_arr = self + .value_builder + .as_any_mut() + .downcast_mut::() + .unwrap() + .finish(); + + let 
keys_field = Field::new( + self.field_names.key.as_str(), + keys_arr.data_type().clone(), + false, // always non-nullable + ); + let values_field = Field::new( + self.field_names.value.as_str(), + values_arr.data_type().clone(), + true, + ); + + let struct_array = + StructArray::from(vec![(keys_field, keys_arr), (values_field, values_arr)]); + + let offset_buffer = self.offsets_builder.finish(); + let null_bit_buffer = self.bitmap_builder.finish(); + self.offsets_builder.append(self.len); + let map_field = Box::new(Field::new( + self.field_names.entry.as_str(), + struct_array.data_type().clone(), + false, // always non-nullable + )); + let data = ArrayData::builder(DataType::Map(map_field, false)) // TODO: support sorted keys + .len(len) + .add_buffer(offset_buffer) + .add_child_data(struct_array.data().clone()) + .null_bit_buffer(null_bit_buffer) + .build(); + + MapArray::from(data) + } +} + +impl ArrayBuilder for MapBuilder { + fn len(&self) -> usize { + self.len as usize + } + + fn is_empty(&self) -> bool { + self.len == 0 + } + + fn finish(&mut self) -> ArrayRef { + Arc::new(self.finish()) + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } + + fn into_box_any(self: Box) -> Box { + self + } +} + /// `FieldData` is a helper struct to track the state of the fields in the `UnionBuilder`. 
#[derive(Debug)] struct FieldData { @@ -3184,6 +3341,60 @@ mod tests { assert_eq!(0, builder.len()); } + #[test] + fn test_map_array_builder() { + let string_builder = StringBuilder::new(4); + let int_builder = Int32Builder::new(4); + + let mut builder = MapBuilder::new(None, string_builder, int_builder); + + let string_builder = builder.keys(); + string_builder.append_value("joe").unwrap(); + string_builder.append_null().unwrap(); + string_builder.append_null().unwrap(); + string_builder.append_value("mark").unwrap(); + + let int_builder = builder.values(); + int_builder.append_value(1).unwrap(); + int_builder.append_value(2).unwrap(); + int_builder.append_null().unwrap(); + int_builder.append_value(4).unwrap(); + + builder.append(true).unwrap(); + builder.append(false).unwrap(); + builder.append(true).unwrap(); + + let arr = builder.finish(); + + let map_data = arr.data(); + assert_eq!(3, map_data.len()); + assert_eq!(1, map_data.null_count()); + assert_eq!( + &Some(Bitmap::from(Buffer::from(&[5_u8]))), + map_data.null_bitmap() + ); + + let expected_string_data = ArrayData::builder(DataType::Utf8) + .len(4) + .null_bit_buffer(Buffer::from(&[9_u8])) + .add_buffer(Buffer::from_slice_ref(&[0, 3, 3, 3, 7])) + .add_buffer(Buffer::from_slice_ref(b"joemark")) + .build(); + + let expected_int_data = ArrayData::builder(DataType::Int32) + .len(4) + .null_bit_buffer(Buffer::from_slice_ref(&[11_u8])) + .add_buffer(Buffer::from_slice_ref(&[1, 2, 0, 4])) + .build(); + + assert_eq!(&expected_string_data, arr.keys().data()); + assert_eq!(&expected_int_data, arr.values().data()); + } + + // TODO: add a test that finishes building, after designing a spec-compliant + // way of inserting values to the map. 
+ // A map's keys shouldn't be repeated within a slot + #[test] fn test_struct_array_builder_from_schema() { let mut fields = Vec::new(); diff --git a/arrow/src/array/data.rs b/arrow/src/array/data.rs index 228f0221933f..cb389cacc7f6 100644 --- a/arrow/src/array/data.rs +++ b/arrow/src/array/data.rs @@ -126,7 +126,7 @@ pub(crate) fn new_buffers(data_type: &DataType, capacity: usize) -> [MutableBuff buffer.push(0i64); [buffer, MutableBuffer::new(capacity * mem::size_of::())] } - DataType::List(_) => { + DataType::List(_) | DataType::Map(_, _) => { // offset buffer always starts with a zero let mut buffer = MutableBuffer::new((1 + capacity) * mem::size_of::()); buffer.push(0i32); @@ -475,6 +475,9 @@ impl ArrayData { .iter() .map(|field| Self::new_empty(field.data_type())) .collect(), + DataType::Map(field, _) => { + vec![Self::new_empty(field.data_type())] + } DataType::Union(_) => unimplemented!(), DataType::Dictionary(_, data_type) => { vec![Self::new_empty(data_type)] diff --git a/arrow/src/array/equal/mod.rs b/arrow/src/array/equal/mod.rs index 4ddf4e473202..8368717c6747 100644 --- a/arrow/src/array/equal/mod.rs +++ b/arrow/src/array/equal/mod.rs @@ -22,7 +22,7 @@ use super::{ Array, ArrayData, BinaryOffsetSizeTrait, BooleanArray, DecimalArray, FixedSizeBinaryArray, FixedSizeListArray, GenericBinaryArray, GenericListArray, - GenericStringArray, NullArray, OffsetSizeTrait, PrimitiveArray, + GenericStringArray, MapArray, NullArray, OffsetSizeTrait, PrimitiveArray, StringOffsetSizeTrait, StructArray, }; @@ -117,6 +117,12 @@ impl PartialEq for GenericListArray { } } +impl PartialEq for MapArray { + fn eq(&self, other: &Self) -> bool { + equal(self.data(), other.data()) + } +} + impl PartialEq for FixedSizeListArray { fn eq(&self, other: &Self) -> bool { equal(self.data(), other.data()) @@ -246,6 +252,9 @@ fn equal_values( _ => unreachable!(), }, DataType::Float16 => unreachable!(), + DataType::Map(_, _) => { + list_equal::(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, 
rhs_start, len) + } } } diff --git a/arrow/src/array/equal/utils.rs b/arrow/src/array/equal/utils.rs index 2a1ce88d1474..1e33a867c83b 100644 --- a/arrow/src/array/equal/utils.rs +++ b/arrow/src/array/equal/utils.rs @@ -106,7 +106,7 @@ pub(super) fn child_logical_null_buffer( Bitmap::from(Buffer::from(vec![0b11111111; ceil])) }); match parent_data.data_type() { - DataType::List(_) => Some(logical_list_bitmap::( + DataType::List(_) | DataType::Map(_, _) => Some(logical_list_bitmap::( parent_data, parent_bitmap, self_null_bitmap, diff --git a/arrow/src/array/equal_json.rs b/arrow/src/array/equal_json.rs index 7120e6cf430e..adc33a7a1cd3 100644 --- a/arrow/src/array/equal_json.rs +++ b/arrow/src/array/equal_json.rs @@ -219,6 +219,38 @@ impl PartialEq for Value { } } +impl JsonEqual for MapArray { + fn equals_json(&self, json: &[&Value]) -> bool { + if self.len() != json.len() { + return false; + } + + (0..self.len()).all(|i| match json[i] { + Value::Array(v) => self.is_valid(i) && self.value(i).equals_json_values(v), + Value::Null => self.is_null(i) || self.value_length(i).eq(&0), + _ => false, + }) + } +} + +impl PartialEq for MapArray { + fn eq(&self, json: &Value) -> bool { + match json { + Value::Array(json_array) => self.equals_json_values(json_array), + _ => false, + } + } +} + +impl PartialEq for Value { + fn eq(&self, arrow: &MapArray) -> bool { + match self { + Value::Array(json_array) => arrow.equals_json_values(json_array), + _ => false, + } + } +} + impl JsonEqual for GenericBinaryArray { fn equals_json(&self, json: &[&Value]) -> bool { if self.len() != json.len() { diff --git a/arrow/src/array/mod.rs b/arrow/src/array/mod.rs index 69b65f41ad5a..bd791f96c64f 100644 --- a/arrow/src/array/mod.rs +++ b/arrow/src/array/mod.rs @@ -87,6 +87,7 @@ mod array_binary; mod array_boolean; mod array_dictionary; mod array_list; +mod array_map; mod array_primitive; mod array_string; mod array_struct; @@ -122,6 +123,7 @@ pub use self::array_dictionary::DictionaryArray; pub 
use self::array_list::FixedSizeListArray; pub use self::array_list::LargeListArray; pub use self::array_list::ListArray; +pub use self::array_map::MapArray; pub use self::array_primitive::PrimitiveArray; pub use self::array_string::LargeStringArray; pub use self::array_string::StringArray; diff --git a/arrow/src/datatypes/datatype.rs b/arrow/src/datatypes/datatype.rs index 122cbdd5e47d..1cbec341cf37 100644 --- a/arrow/src/datatypes/datatype.rs +++ b/arrow/src/datatypes/datatype.rs @@ -129,6 +129,20 @@ pub enum DataType { Dictionary(Box, Box), /// Decimal value with precision and scale Decimal(usize, usize), + /// A Map is a logical nested type that is represented as + /// + /// `List<entries: Struct<key: K, value: V>>` + /// + /// The keys and values are each respectively contiguous. + /// The key and value types are not constrained, but keys should be + /// hashable and unique. + /// Whether the keys are sorted can be set in the `bool` after the `Field`. + /// + /// In a field with Map type, the field has a child Struct field, which then + /// has two children: the first the key type and the second the value type. The names of the + /// child fields may be respectively "entries", "key", and "value", but this is + /// not enforced. + Map(Box, bool), } /// An absolute length of time in seconds, milliseconds, microseconds or nanoseconds. 
@@ -335,6 +349,16 @@ impl DataType { // return an empty `struct` type as its children aren't defined in the map Ok(DataType::Struct(vec![])) } + Some(s) if s == "map" => { + if let Some(Value::Bool(keys_sorted)) = map.get("keysSorted") { + // Return a map with an empty type as its children aren't defined in the map + Ok(DataType::Map(Box::new(default_field), *keys_sorted)) + } else { + Err(ArrowError::ParseError( + "Expecting a keysSorted for map".to_string(), + )) + } + } Some(other) => Err(ArrowError::ParseError(format!( "invalid or unsupported type name: {} in {:?}", other, json @@ -429,6 +453,9 @@ impl DataType { DataType::Decimal(precision, scale) => { json!({"name": "decimal", "precision": precision, "scale": scale}) } + DataType::Map(_, keys_sorted) => { + json!({"name": "map", "keysSorted": keys_sorted}) + } } } @@ -471,6 +498,10 @@ impl DataType { && a.data_type().equals_datatype(b.data_type()) }) } + ( + DataType::Map(a_field, a_is_sorted), + DataType::Map(b_field, b_is_sorted), + ) => a_field == b_field && a_is_sorted == b_is_sorted, _ => self == other, } } diff --git a/arrow/src/datatypes/field.rs b/arrow/src/datatypes/field.rs index 1cb8eb807539..497dbb389fd7 100644 --- a/arrow/src/datatypes/field.rs +++ b/arrow/src/datatypes/field.rs @@ -271,6 +271,35 @@ impl Field { )); } }, + DataType::Map(_, keys_sorted) => { + match map.get("children") { + Some(Value::Array(values)) if values.len() == 1 => { + let child = Self::from(&values[0])?; + // child must be a struct + match child.data_type() { + DataType::Struct(map_fields) if map_fields.len() == 2 => { + DataType::Map(Box::new(child), keys_sorted) + } + t => { + return Err(ArrowError::ParseError( + format!("Map children should be a struct with 2 fields, found {:?}", t) + )) + } + } + } + Some(_) => { + return Err(ArrowError::ParseError( + "Field 'children' must be an array with 1 element" + .to_string(), + )) + } + None => { + return Err(ArrowError::ParseError( + "Field missing 'children' 
attribute".to_string(), + )); + } + } + } _ => data_type, }; @@ -329,6 +358,9 @@ impl Field { DataType::List(field) => vec![field.to_json()], DataType::LargeList(field) => vec![field.to_json()], DataType::FixedSizeList(field, _) => vec![field.to_json()], + DataType::Map(field, _) => { + vec![field.to_json()] + } _ => vec![], }; match self.data_type() { @@ -468,6 +500,7 @@ impl Field { | DataType::Interval(_) | DataType::LargeList(_) | DataType::List(_) + | DataType::Map(_, _) | DataType::Dictionary(_, _) | DataType::FixedSizeList(_, _) | DataType::FixedSizeBinary(_) diff --git a/arrow/src/datatypes/mod.rs b/arrow/src/datatypes/mod.rs index 5da712684a43..9920cf95d3c6 100644 --- a/arrow/src/datatypes/mod.rs +++ b/arrow/src/datatypes/mod.rs @@ -207,6 +207,66 @@ mod tests { assert_eq!(value, f.to_json()); } + #[test] + fn map_field_to_json() { + let f = Field::new( + "my_map", + DataType::Map( + Box::new(Field::new( + "my_entries", + DataType::Struct(vec![ + Field::new("my_keys", DataType::Utf8, false), + Field::new("my_values", DataType::UInt16, true), + ]), + false, + )), + true, + ), + false, + ); + let value: Value = serde_json::from_str( + r#"{ + "name": "my_map", + "nullable": false, + "type": { + "name": "map", + "keysSorted": true + }, + "children": [ + { + "name": "my_entries", + "nullable": false, + "type": { + "name": "struct" + }, + "children": [ + { + "name": "my_keys", + "nullable": false, + "type": { + "name": "utf8" + }, + "children": [] + }, + { + "name": "my_values", + "nullable": true, + "type": { + "name": "int", + "bitWidth": 16, + "isSigned": false + }, + "children": [] + } + ] + } + ] + }"#, + ) + .unwrap(); + assert_eq!(value, f.to_json()); + } + #[test] fn primitive_field_to_json() { let f = Field::new("first_name", DataType::Utf8, false); @@ -269,6 +329,69 @@ mod tests { assert_eq!(expected, dt); } + #[test] + fn parse_map_from_json() { + let json = r#" + { + "name": "my_map", + "nullable": false, + "type": { + "name": "map", + "keysSorted": 
true + }, + "children": [ + { + "name": "my_entries", + "nullable": false, + "type": { + "name": "struct" + }, + "children": [ + { + "name": "my_keys", + "nullable": false, + "type": { + "name": "utf8" + }, + "children": [] + }, + { + "name": "my_values", + "nullable": true, + "type": { + "name": "int", + "bitWidth": 16, + "isSigned": false + }, + "children": [] + } + ] + } + ] + } + "#; + let value: Value = serde_json::from_str(json).unwrap(); + let dt = Field::from(&value).unwrap(); + + let expected = Field::new( + "my_map", + DataType::Map( + Box::new(Field::new( + "my_entries", + DataType::Struct(vec![ + Field::new("my_keys", DataType::Utf8, false), + Field::new("my_values", DataType::UInt16, true), + ]), + false, + )), + true, + ), + false, + ); + + assert_eq!(expected, dt); + } + #[test] fn parse_utf8_from_json() { let json = "{\"name\":\"utf8\"}"; @@ -396,6 +519,21 @@ mod tests { ))), true, ), + Field::new( + "c35", + DataType::Map( + Box::new(Field::new( + "my_entries", + DataType::Struct(vec![ + Field::new("my_keys", DataType::Utf8, false), + Field::new("my_values", DataType::UInt16, true), + ]), + false, + )), + true, + ), + false, + ), ], metadata, ); @@ -790,6 +928,43 @@ mod tests { ] } ] + }, + { + "name": "c35", + "nullable": false, + "type": { + "name": "map", + "keysSorted": true + }, + "children": [ + { + "name": "my_entries", + "nullable": false, + "type": { + "name": "struct" + }, + "children": [ + { + "name": "my_keys", + "nullable": false, + "type": { + "name": "utf8" + }, + "children": [] + }, + { + "name": "my_values", + "nullable": true, + "type": { + "name": "int", + "bitWidth": 16, + "isSigned": false + }, + "children": [] + } + ] + } + ] } ], "metadata" : { diff --git a/arrow/src/ipc/convert.rs b/arrow/src/ipc/convert.rs index 59d4d0b9089c..5244a387c90b 100644 --- a/arrow/src/ipc/convert.rs +++ b/arrow/src/ipc/convert.rs @@ -308,6 +308,14 @@ pub(crate) fn get_data_type(field: ipc::Field, may_be_dictionary: bool) -> DataT 
DataType::Struct(fields) } + ipc::Type::Map => { + let map = field.type_as_map().unwrap(); + let children = field.children().unwrap(); + if children.len() != 1 { + panic!("expect a map to have one child") + } + DataType::Map(Box::new(children.get(0).into()), map.keysSorted()) + } ipc::Type::Decimal => { let fsb = field.type_as_decimal().unwrap(); DataType::Decimal(fsb.precision() as usize, fsb.scale() as usize) @@ -624,6 +632,16 @@ pub(crate) fn get_fb_field_type<'a>( children: Some(fbb.create_vector(&children[..])), } } + Map(map_field, keys_sorted) => { + let child = build_field(fbb, map_field); + let mut field_type = ipc::MapBuilder::new(fbb); + field_type.add_keysSorted(*keys_sorted); + FBFieldType { + type_type: ipc::Type::Map, + type_: field_type.finish().as_union_value(), + children: Some(fbb.create_vector(&[child])), + } + } Dictionary(_, value_type) => { // In this library, the dictionary "type" is a logical construct. Here we // pass through to the value type, as we've already captured the index diff --git a/arrow/src/ipc/reader.rs b/arrow/src/ipc/reader.rs index 7bba3119f239..50e858f098a8 100644 --- a/arrow/src/ipc/reader.rs +++ b/arrow/src/ipc/reader.rs @@ -89,7 +89,7 @@ fn create_array( buffer_index += 2; array } - List(ref list_field) | LargeList(ref list_field) => { + List(ref list_field) | LargeList(ref list_field) | Map(ref list_field, _) => { let list_node = &nodes[node_index]; let list_buffers: Vec = buffers[buffer_index..buffer_index + 2] .iter() @@ -377,8 +377,19 @@ fn create_list_array( builder = builder.null_bit_buffer(buffers[0].clone()) } make_array(builder.build()) + } else if let DataType::Map(_, _) = *data_type { + let null_count = field_node.null_count() as usize; + let mut builder = ArrayData::builder(data_type.clone()) + .len(field_node.length() as usize) + .buffers(buffers[1..2].to_vec()) + .offset(0) + .child_data(vec![child_array.data().clone()]); + if null_count > 0 { + builder = builder.null_bit_buffer(buffers[0].clone()) + } + 
make_array(builder.build()) } else { - panic!("Cannot create list array from {:?}", data_type) + panic!("Cannot create list or map array from {:?}", data_type) } } @@ -931,6 +942,7 @@ mod tests { "generated_interval", "generated_datetime", "generated_dictionary", + "generated_map", "generated_nested", "generated_primitive_no_batches", "generated_primitive_zerolength", @@ -972,6 +984,7 @@ mod tests { "generated_interval", "generated_datetime", "generated_dictionary", + "generated_map", "generated_nested", "generated_null_trivial", "generated_null", @@ -999,6 +1012,7 @@ mod tests { "generated_interval", "generated_datetime", "generated_dictionary", + "generated_map", "generated_nested", "generated_primitive_no_batches", "generated_primitive_zerolength", @@ -1033,6 +1047,8 @@ mod tests { "generated_interval", "generated_datetime", "generated_dictionary", + "generated_map", + // "generated_map_non_canonical", "generated_nested", "generated_null_trivial", "generated_null", @@ -1064,6 +1080,8 @@ mod tests { "generated_interval", "generated_datetime", "generated_dictionary", + "generated_map", + // "generated_map_non_canonical", "generated_nested", "generated_null_trivial", "generated_null", diff --git a/arrow/src/ipc/writer.rs b/arrow/src/ipc/writer.rs index f342d6773fed..0376265f4f65 100644 --- a/arrow/src/ipc/writer.rs +++ b/arrow/src/ipc/writer.rs @@ -928,6 +928,7 @@ mod tests { "generated_interval", "generated_datetime", "generated_dictionary", + "generated_map", "generated_nested", "generated_primitive_no_batches", "generated_primitive_zerolength", @@ -979,6 +980,7 @@ mod tests { "generated_interval", "generated_datetime", "generated_dictionary", + "generated_map", "generated_nested", "generated_primitive_no_batches", "generated_primitive_zerolength", @@ -1031,6 +1033,7 @@ mod tests { "generated_dictionary", // "generated_duplicate_fieldnames", "generated_interval", + "generated_map", "generated_nested", // "generated_nested_large_offsets", "generated_null_trivial", 
@@ -1094,6 +1097,7 @@ mod tests { "generated_dictionary", // "generated_duplicate_fieldnames", "generated_interval", + "generated_map", "generated_nested", // "generated_nested_large_offsets", "generated_null_trivial", diff --git a/arrow/src/json/reader.rs b/arrow/src/json/reader.rs index 290ad4f23977..c4e847082c2b 100644 --- a/arrow/src/json/reader.rs +++ b/arrow/src/json/reader.rs @@ -47,6 +47,7 @@ use std::sync::Arc; use indexmap::map::IndexMap as HashMap; use indexmap::set::IndexSet as HashSet; +use serde_json::json; use serde_json::{map::Map as JsonMap, Value}; use crate::buffer::MutableBuffer; @@ -1282,6 +1283,12 @@ impl Decoder { .build(); Ok(make_array(data)) } + DataType::Map(map_field, _) => self.build_map_array( + rows, + field.name(), + field.data_type(), + map_field, + ), _ => Err(ArrowError::JsonError(format!( "{:?} type is not supported", field.data_type() @@ -1292,6 +1299,101 @@ impl Decoder { arrays } + fn build_map_array( + &self, + rows: &[Value], + field_name: &str, + map_type: &DataType, + struct_field: &Field, + ) -> Result { + // A map has the format {"key": "value"} where key is most commonly a string, + // but could be a string, number or boolean (🤷🏾‍♂️) (e.g. {1: "value"}). + // A map is also represented as a flattened contiguous array, with the number + // of key-value pairs being separated by a list offset. + // If row 1 has 2 key-value pairs, and row 2 has 3, the offsets would be + // [0, 2, 5]. 
+ // + // Thus we try to read a map by iterating through the keys and values + + let (key_field, value_field) = + if let DataType::Struct(fields) = struct_field.data_type() { + if fields.len() != 2 { + return Err(ArrowError::InvalidArgumentError(format!( + "DataType::Map expects a struct with 2 fields, found {} fields", + fields.len() + ))); + } + (&fields[0], &fields[1]) + } else { + return Err(ArrowError::InvalidArgumentError(format!( + "JSON map array builder expects a DataType::Map, found {:?}", + struct_field.data_type() + ))); + }; + let value_map_iter = rows.iter().map(|value| { + value + .get(field_name) + .map(|v| v.as_object().map(|map| (map, map.len() as i32))) + .flatten() + }); + let rows_len = rows.len(); + let mut list_offsets = Vec::with_capacity(rows_len + 1); + list_offsets.push(0i32); + let mut last_offset = 0; + let num_bytes = bit_util::ceil(rows_len, 8); + let mut list_bitmap = MutableBuffer::from_len_zeroed(num_bytes); + let null_data = list_bitmap.as_slice_mut(); + + let struct_rows = value_map_iter + .enumerate() + .filter_map(|(i, v)| match v { + Some((map, len)) => { + list_offsets.push(last_offset + len); + last_offset += len; + bit_util::set_bit(null_data, i); + Some(map.iter().map(|(k, v)| { + json!({ + key_field.name(): k, + value_field.name(): v + }) + })) + } + None => { + list_offsets.push(last_offset); + None + } + }) + .flatten() + .collect::>(); + + let struct_children = self.build_struct_array( + struct_rows.as_slice(), + &[key_field.clone(), value_field.clone()], + &[], + )?; + + Ok(make_array(ArrayData::new( + map_type.clone(), + rows_len, + None, + Some(list_bitmap.into()), + 0, + vec![Buffer::from_slice_ref(&list_offsets)], + vec![ArrayData::new( + struct_field.data_type().clone(), + struct_children[0].len(), + None, + None, + 0, + vec![], + struct_children + .into_iter() + .map(|array| array.data().clone()) + .collect(), + )], + ))) + } + #[inline(always)] fn build_dictionary_array( &self, @@ -2177,6 +2279,81 @@ mod tests 
{ assert_eq!(read.data_ref(), expected.data_ref()); } + #[test] + fn test_map_json_arrays() { + let account_field = Field::new("account", DataType::UInt16, false); + let value_list_type = + DataType::List(Box::new(Field::new("item", DataType::Utf8, false))); + let entries_struct_type = DataType::Struct(vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", value_list_type.clone(), true), + ]); + let stocks_field = Field::new( + "stocks", + DataType::Map( + Box::new(Field::new("entries", entries_struct_type.clone(), false)), + false, + ), + true, + ); + let schema = Arc::new(Schema::new(vec![account_field, stocks_field.clone()])); + let builder = ReaderBuilder::new().with_schema(schema).with_batch_size(64); + // Note: account 456 has 'long' twice, to show that the JSON reader will overwrite + // existing keys. This thus guarantees unique keys for the map + let json_content = r#" + {"account": 123, "stocks":{"long": ["$AAA", "$BBB"], "short": ["$CCC", "$D"]}} + {"account": 456, "stocks":{"long": null, "long": ["$AAA", "$CCC", "$D"], "short": null}} + {"account": 789, "stocks":{"hedged": ["$YYY"], "long": null, "short": ["$D"]}} + "#; + let mut reader = builder.build(Cursor::new(json_content)).unwrap(); + + // build expected output + let expected_accounts = UInt16Array::from(vec![123, 456, 789]); + + let expected_keys = StringArray::from(vec![ + "long", "short", "long", "short", "hedged", "long", "short", + ]) + .data() + .clone(); + let expected_value_array_data = StringArray::from(vec![ + "$AAA", "$BBB", "$CCC", "$D", "$AAA", "$CCC", "$D", "$YYY", "$D", + ]) + .data() + .clone(); + // Create the list that holds ["$_", "$_"] + let expected_values = ArrayDataBuilder::new(value_list_type) + .len(7) + .add_buffer(Buffer::from( + vec![0i32, 2, 4, 7, 7, 8, 8, 9].to_byte_slice(), + )) + .add_child_data(expected_value_array_data) + .null_bit_buffer(Buffer::from(vec![0b01010111])) + .build(); + let expected_stocks_entries_data = 
ArrayDataBuilder::new(entries_struct_type) + .len(7) + .add_child_data(expected_keys) + .add_child_data(expected_values) + .build(); + let expected_stocks_data = + ArrayDataBuilder::new(stocks_field.data_type().clone()) + .len(3) + .add_buffer(Buffer::from(vec![0i32, 2, 4, 7].to_byte_slice())) + .add_child_data(expected_stocks_entries_data) + .build(); + + let expected_stocks = make_array(expected_stocks_data); + + // compare with result from json reader + let batch = reader.next().unwrap().unwrap(); + assert_eq!(batch.num_rows(), 3); + assert_eq!(batch.num_columns(), 2); + let col1 = batch.column(0); + assert_eq!(col1.data(), expected_accounts.data()); + // Compare the map + let col2 = batch.column(1); + assert_eq!(col2.data(), expected_stocks.data()); + } + #[test] fn test_dictionary_from_json_basic_with_nulls() { let schema = Schema::new(vec![Field::new( diff --git a/arrow/src/util/integration_util.rs b/arrow/src/util/integration_util.rs index bac0e47a92f0..ada2494d3c2d 100644 --- a/arrow/src/util/integration_util.rs +++ b/arrow/src/util/integration_util.rs @@ -350,6 +350,10 @@ impl ArrowJsonBatch { let arr = arr.as_any().downcast_ref::().unwrap(); arr.equals_json(&json_array.iter().collect::>()[..]) } + DataType::Map(_, _) => { + let arr = arr.as_any().downcast_ref::().unwrap(); + arr.equals_json(&json_array.iter().collect::>()[..]) + } DataType::Decimal(_, _) => { let arr = arr.as_any().downcast_ref::().unwrap(); arr.equals_json(&json_array.iter().collect::>()[..]) @@ -492,6 +496,7 @@ fn json_from_col(col: &ArrowJsonColumn, data_type: &DataType) -> Vec { json_from_fixed_size_list_col(col, field.data_type(), *list_size as usize) } DataType::Struct(fields) => json_from_struct_col(col, fields), + DataType::Map(field, keys_sorted) => json_from_map_col(col, field, *keys_sorted), DataType::Int64 | DataType::UInt64 | DataType::Date64 @@ -649,6 +654,51 @@ fn json_from_fixed_size_list_col( values } +fn json_from_map_col( + col: &ArrowJsonColumn, + field: &Field, + 
_keys_sorted: bool, +) -> Vec { + let mut values = Vec::with_capacity(col.count); + + // get the inner array + let child = &col.children.clone().expect("list type must have children")[0]; + let offsets: Vec = col + .offset + .clone() + .unwrap() + .iter() + .map(|o| match o { + Value::String(s) => s.parse::().unwrap(), + Value::Number(n) => n.as_u64().unwrap() as usize, + _ => panic!( + "Offsets should be numbers or strings that are convertible to numbers" + ), + }) + .collect(); + + let inner = match field.data_type() { + DataType::Struct(fields) => json_from_struct_col(child, fields), + _ => panic!("Map child must be Struct"), + }; + + for i in 0..col.count { + match &col.validity { + Some(validity) => match &validity[i] { + 0 => values.push(Value::Null), + 1 => { + values.push(Value::Array(inner[offsets[i]..offsets[i + 1]].to_vec())) + } + _ => panic!("Validity data should be 0 or 1"), + }, + None => { + // Null type does not have a validity vector + } + } + } + + values +} #[cfg(test)] mod tests { use super::*; diff --git a/parquet/src/arrow/array_reader.rs b/parquet/src/arrow/array_reader.rs index f8cad6f05131..d3259c46bbad 100644 --- a/parquet/src/arrow/array_reader.rs +++ b/parquet/src/arrow/array_reader.rs @@ -27,7 +27,7 @@ use arrow::array::{ new_empty_array, Array, ArrayData, ArrayDataBuilder, ArrayRef, BinaryArray, BinaryBuilder, BooleanArray, BooleanBufferBuilder, BooleanBuilder, DecimalBuilder, FixedSizeBinaryArray, FixedSizeBinaryBuilder, GenericListArray, Int16BufferBuilder, - Int32Array, Int64Array, OffsetSizeTrait, PrimitiveArray, PrimitiveBuilder, + Int32Array, Int64Array, MapArray, OffsetSizeTrait, PrimitiveArray, PrimitiveBuilder, StringArray, StringBuilder, StructArray, }; use arrow::buffer::{Buffer, MutableBuffer}; @@ -924,6 +924,145 @@ impl ArrayReader for ListArrayReader { } } +/// Implementation of a map array reader. 
+pub struct MapArrayReader { + key_reader: Box, + value_reader: Box, + data_type: ArrowType, + map_def_level: i16, + map_rep_level: i16, + def_level_buffer: Option, + rep_level_buffer: Option, +} + +impl MapArrayReader { + pub fn new( + key_reader: Box, + value_reader: Box, + data_type: ArrowType, + def_level: i16, + rep_level: i16, + ) -> Self { + Self { + key_reader, + value_reader, + data_type, + map_def_level: rep_level, + map_rep_level: def_level, + def_level_buffer: None, + rep_level_buffer: None, + } + } +} + +impl ArrayReader for MapArrayReader { + fn as_any(&self) -> &dyn Any { + self + } + + fn get_data_type(&self) -> &ArrowType { + &self.data_type + } + + fn next_batch(&mut self, batch_size: usize) -> Result { + let key_array = self.key_reader.next_batch(batch_size)?; + let value_array = self.value_reader.next_batch(batch_size)?; + + // Check that key and value have the same lengths + let key_length = key_array.len(); + if key_length != value_array.len() { + return Err(general_err!( + "Map key and value should have the same lengths." 
+ )); + } + + let def_levels = self + .key_reader + .get_def_levels() + .ok_or_else(|| ArrowError("item_reader def levels are None.".to_string()))?; + let rep_levels = self + .key_reader + .get_rep_levels() + .ok_or_else(|| ArrowError("item_reader rep levels are None.".to_string()))?; + + if !((def_levels.len() == rep_levels.len()) && (rep_levels.len() == key_length)) { + return Err(ArrowError( + "Expected item_reader def_levels and rep_levels to be same length as batch".to_string(), + )); + } + + let entry_data_type = if let ArrowType::Map(field, _) = &self.data_type { + field.data_type().clone() + } else { + return Err(ArrowError("Expected a map arrow type".to_string())); + }; + + let entry_data = ArrayDataBuilder::new(entry_data_type) + .len(key_length) + .add_child_data(key_array.data().clone()) + .add_child_data(value_array.data().clone()) + .build(); + + let entry_len = rep_levels.iter().filter(|level| **level == 0).count(); + + // first item in each list has rep_level = 0, subsequent items have rep_level = 1 + let mut offsets: Vec = Vec::new(); + let mut cur_offset = 0; + def_levels.iter().zip(rep_levels).for_each(|(d, r)| { + if *r == 0 || d == &self.map_def_level { + offsets.push(cur_offset); + } + if d > &self.map_def_level { + cur_offset += 1; + } + }); + offsets.push(cur_offset); + + let num_bytes = bit_util::ceil(offsets.len(), 8); + // TODO: A useful optimization is to use the null count to fill with + // 0 or null, to reduce individual bits set in a loop. + // To favour dense data, set every slot to true, then unset + let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true); + let null_slice = null_buf.as_slice_mut(); + let mut list_index = 0; + for i in 0..rep_levels.len() { + // If the level is lower than empty, then the slot is null. + // When a list is non-nullable, its empty level = null level, + // so this automatically factors that in. 
+ if rep_levels[i] == 0 && def_levels[i] < self.map_def_level { + // should be empty list + bit_util::unset_bit(null_slice, list_index); + } + if rep_levels[i] == 0 { + list_index += 1; + } + } + let value_offsets = Buffer::from(&offsets.to_byte_slice()); + + // Now we can build array data + let array_data = ArrayDataBuilder::new(self.data_type.clone()) + .len(entry_len) + .add_buffer(value_offsets) + .null_bit_buffer(null_buf.into()) + .add_child_data(entry_data) + .build(); + + Ok(Arc::new(MapArray::from(array_data))) + } + + fn get_def_levels(&self) -> Option<&[i16]> { + self.def_level_buffer + .as_ref() + .map(|buf| unsafe { buf.typed_data() }) + } + + fn get_rep_levels(&self) -> Option<&[i16]> { + self.rep_level_buffer + .as_ref() + .map(|buf| unsafe { buf.typed_data() }) + } +} + /// Implementation of struct array reader. pub struct StructArrayReader { children: Vec>, @@ -1176,8 +1315,6 @@ impl<'a> TypeVisitor>, &'a ArrayReaderBuilderContext for ArrayReaderBuilder { /// Build array reader for primitive type. - /// Currently we don't have a list reader implementation, so repeated type is not - /// supported yet. fn visit_primitive( &mut self, cur_type: TypePtr, @@ -1251,15 +1388,87 @@ impl<'a> TypeVisitor>, &'a ArrayReaderBuilderContext } /// Build array reader for map type. - /// Currently this is not supported. 
fn visit_map( &mut self, - _cur_type: Arc, - _context: &'a ArrayReaderBuilderContext, + map_type: Arc, + context: &'a ArrayReaderBuilderContext, ) -> Result>> { - Err(ArrowError( - "Reading parquet map array into arrow is not supported yet!".to_string(), - )) + // Add map type to context + let mut new_context = context.clone(); + new_context.path.append(vec![map_type.name().to_string()]); + if let Repetition::OPTIONAL = map_type.get_basic_info().repetition() { + new_context.def_level += 1; + } + + // Add map entry (key_value) to context + let map_key_value = map_type.get_fields().first().ok_or_else(|| { + ArrowError("Map field must have a key_value entry".to_string()) + })?; + new_context + .path + .append(vec![map_key_value.name().to_string()]); + new_context.rep_level += 1; + + // Get key and value, and create context for each + let map_key = map_key_value + .get_fields() + .first() + .ok_or_else(|| ArrowError("Map entry must have a key".to_string()))?; + let map_value = map_key_value + .get_fields() + .get(1) + .ok_or_else(|| ArrowError("Map entry must have a value".to_string()))?; + + let key_reader = { + let mut key_context = new_context.clone(); + key_context.def_level += 1; + key_context.path.append(vec![map_key.name().to_string()]); + self.dispatch(map_key.clone(), &key_context)?.unwrap() + }; + let value_reader = { + let mut value_context = new_context.clone(); + if let Repetition::OPTIONAL = map_value.get_basic_info().repetition() { + value_context.def_level += 1; + } + self.dispatch(map_value.clone(), &value_context)?.unwrap() + }; + + let arrow_type = self + .arrow_schema + .field_with_name(map_type.name()) + .ok() + .map(|f| f.data_type().to_owned()) + .unwrap_or_else(|| { + ArrowType::Map( + Box::new(Field::new( + map_key_value.name(), + ArrowType::Struct(vec![ + Field::new( + map_key.name(), + key_reader.get_data_type().clone(), + false, + ), + Field::new( + map_value.name(), + value_reader.get_data_type().clone(), + map_value.is_optional(), + ), + 
]), + map_type.is_optional(), + )), + false, + ) + }); + + let key_array_reader: Box = Box::new(MapArrayReader::new( + key_reader, + value_reader, + arrow_type, + new_context.def_level, + new_context.rep_level, + )); + + Ok(Some(key_array_reader)) } /// Build array reader for list type. @@ -1269,10 +1478,11 @@ impl<'a> TypeVisitor>, &'a ArrayReaderBuilderContext item_type: Arc, context: &'a ArrayReaderBuilderContext, ) -> Result>> { - let list_child = &list_type + let mut list_child = &list_type .get_fields() .first() - .ok_or_else(|| ArrowError("List field must have a child.".to_string()))?; + .ok_or_else(|| ArrowError("List field must have a child.".to_string()))? + .clone(); let mut new_context = context.clone(); new_context.path.append(vec![list_type.name().to_string()]); @@ -1316,9 +1526,6 @@ impl<'a> TypeVisitor>, &'a ArrayReaderBuilderContext _ => { // a list is a group type with a single child. The list child's // name comes from the child's field name. - let mut list_child = list_type.get_fields().first().ok_or(ArrowError( - "List GroupType should have a field".to_string(), - ))?; // if the child's name is "list" and it has a child, then use this child if list_child.name() == "list" && !list_child.get_fields().is_empty() { list_child = list_child.get_fields().first().unwrap(); diff --git a/parquet/src/arrow/arrow_reader.rs b/parquet/src/arrow/arrow_reader.rs index 83fb0a2f7e95..761c5a6781bf 100644 --- a/parquet/src/arrow/arrow_reader.rs +++ b/parquet/src/arrow/arrow_reader.rs @@ -668,4 +668,20 @@ mod tests { batch.unwrap(); } } + + #[test] + fn test_read_maps() { + let testdata = arrow::util::test_util::parquet_test_data(); + let path = format!("{}/nested_maps.snappy.parquet", testdata); + let parquet_file_reader = + SerializedFileReader::try_from(File::open(&path).unwrap()).unwrap(); + let mut arrow_reader = ParquetFileArrowReader::new(Arc::new(parquet_file_reader)); + let record_batch_reader = arrow_reader + .get_record_reader(60) + .expect("Failed to 
/// Round-trips a record batch holding a single nullable map column.
#[test]
fn arrow_writer_map() {
    let entries_type = DataType::Struct(vec![
        Field::new("key", DataType::Utf8, false),
        Field::new("value", DataType::Utf8, true),
    ]);
    let map_type = DataType::Map(
        Box::new(Field::new("entries", entries_type, false)),
        false,
    );
    let schema = Arc::new(Schema::new(vec![Field::new("stocks", map_type, true)]));

    // Note: the input batch is produced with the JSON Arrow reader for brevity
    let json_content = r#"
    {"stocks":{"long": "$AAA", "short": "$BBB"}}
    {"stocks":{"long": null, "long": "$CCC", "short": null}}
    {"stocks":{"hedged": "$YYY", "long": null, "short": "$D"}}
    "#;
    let mut reader = arrow::json::ReaderBuilder::new()
        .with_schema(schema)
        .with_batch_size(64)
        .build(std::io::Cursor::new(json_content))
        .unwrap();

    let batch = reader.next().unwrap().unwrap();
    roundtrip("test_arrow_writer_map.parquet", batch, None);
}
@@ -40,7 +40,7 @@ //! //! \[1\] [parquet-format#nested-encoding](https://github.com/apache/parquet-format#nested-encoding) -use arrow::array::{make_array, ArrayRef, StructArray}; +use arrow::array::{make_array, ArrayRef, MapArray, StructArray}; use arrow::datatypes::{DataType, Field}; /// Keeps track of the level information per array that is needed to write an Arrow array to Parquet. @@ -234,13 +234,53 @@ impl LevelInfo { LevelType::Primitive(list_field.is_nullable()), )] } - DataType::List(_) | DataType::LargeList(_) | DataType::Struct(_) => { + DataType::List(_) + | DataType::LargeList(_) + | DataType::Struct(_) + | DataType::Map(_, _) => { list_level.calculate_array_levels(&child_array, list_field) } DataType::FixedSizeList(_, _) => unimplemented!(), DataType::Union(_) => unimplemented!(), } } + DataType::Map(map_field, _) => { + // Calculate the map level + let map_level = self.calculate_child_levels( + array_offsets, + array_mask, + // A map is treated like a list as it has repetition + LevelType::List(field.is_nullable()), + ); + + let map_array = array.as_any().downcast_ref::().unwrap(); + + let key_array = map_array.keys(); + let value_array = map_array.values(); + + if let DataType::Struct(fields) = map_field.data_type() { + let key_field = &fields[0]; + let value_field = &fields[1]; + + let mut map_levels = vec![]; + + // Get key levels + let mut key_levels = + map_level.calculate_array_levels(&key_array, key_field); + map_levels.append(&mut key_levels); + + let mut value_levels = + map_level.calculate_array_levels(&value_array, value_field); + map_levels.append(&mut value_levels); + + map_levels + } else { + panic!( + "Map field should be a struct, found {:?}", + map_field.data_type() + ); + } + } DataType::FixedSizeList(_, _) => unimplemented!(), DataType::Struct(struct_fields) => { let struct_array: &StructArray = array @@ -663,7 +703,7 @@ impl LevelInfo { }; ((0..=(len as i64)).collect(), array_mask) } - DataType::List(_) => { + DataType::List(_) | 
/// Checks the level information computed for the keys and values of a
/// non-nullable map column.
#[test]
fn test_map_array() {
    let entries_type = DataType::Struct(vec![
        Field::new("key", DataType::Utf8, false),
        Field::new("value", DataType::Utf8, true),
    ]);
    let stocks_field = Field::new(
        "stocks",
        DataType::Map(
            Box::new(Field::new("entries", entries_type, false)),
            false,
        ),
        // not nullable, so the keys have max level = 1
        false,
    );
    let schema = Arc::new(Schema::new(vec![stocks_field]));

    // Note: the input batch is produced with the JSON Arrow reader for brevity
    let json_content = r#"
    {"stocks":{"long": "$AAA", "short": "$BBB"}}
    {"stocks":{"long": null, "long": "$CCC", "short": null}}
    {"stocks":{"hedged": "$YYY", "long": null, "short": "$D"}}
    "#;
    let mut reader = arrow::json::ReaderBuilder::new()
        .with_schema(schema)
        .with_batch_size(64)
        .build(std::io::Cursor::new(json_content))
        .unwrap();
    let batch = reader.next().unwrap().unwrap();

    // Root level of a three-row batch.
    let batch_level = LevelInfo::new(0, 3);
    let expected_batch_level = LevelInfo {
        definition: vec![0; 3],
        repetition: None,
        array_offsets: (0..=3).collect(),
        array_mask: vec![true, true, true],
        max_definition: 0,
        level_type: LevelType::Root,
        offset: 0,
        length: 3,
    };
    assert_eq!(&batch_level, &expected_batch_level);

    // calculate the map's level: one entry for the keys, one for the values
    let batch_schema = batch.schema();
    let mut levels = vec![];
    for (array, field) in batch.columns().iter().zip(batch_schema.fields()) {
        let mut array_levels = batch_level.calculate_array_levels(array, field);
        levels.append(&mut array_levels);
    }
    assert_eq!(levels.len(), 2);

    // test key levels
    let key_level = levels.get(0).unwrap();
    let expected_key_level = LevelInfo {
        definition: vec![1; 7],
        repetition: Some(vec![0, 1, 0, 1, 0, 1, 1]),
        array_offsets: vec![0, 2, 4, 7],
        array_mask: vec![true; 7],
        max_definition: 1,
        level_type: LevelType::Primitive(false),
        offset: 0,
        length: 7,
    };
    assert_eq!(key_level, &expected_key_level);

    // test values levels
    let value_level = levels.get(1).unwrap();
    let expected_value_level = LevelInfo {
        definition: vec![2, 2, 2, 1, 2, 1, 2],
        repetition: Some(vec![0, 1, 0, 1, 0, 1, 1]),
        array_offsets: vec![0, 2, 4, 7],
        array_mask: vec![true, true, true, false, true, false, true],
        max_definition: 2,
        level_type: LevelType::Primitive(true),
        offset: 0,
        length: 7,
    };
    assert_eq!(value_level, &expected_value_level);
}
fn to_group_type(&self) -> Result> { - if self.is_repeated() { - self.to_struct().map(|opt| { - opt.map(|dt| { - DataType::List(Box::new(Field::new( - self.schema.name(), - dt, - self.is_nullable(), - ))) - }) - }) - } else { - match ( - self.schema.get_basic_info().logical_type(), - self.schema.get_basic_info().converted_type(), - ) { - (Some(LogicalType::LIST(_)), _) => self.to_list(), - (None, ConvertedType::LIST) => self.to_list(), - _ => self.to_struct(), + match ( + self.schema.get_basic_info().logical_type(), + self.schema.get_basic_info().converted_type(), + ) { + (Some(LogicalType::LIST(_)), _) | (_, ConvertedType::LIST) => self.to_list(), + (Some(LogicalType::MAP(_)), _) + | (_, ConvertedType::MAP) + | (_, ConvertedType::MAP_KEY_VALUE) => self.to_map(), + (_, _) => { + if self.is_repeated() { + self.to_struct().map(|opt| { + opt.map(|dt| { + DataType::List(Box::new(Field::new( + self.schema.name(), + dt, + self.is_nullable(), + ))) + }) + }) + } else { + self.to_struct() + } } } } @@ -916,6 +949,87 @@ impl ParquetTypeConverter<'_> { )), } } + + /// Converts a parquet map to arrow map. + /// + /// To fully understand this algorithm, please refer to + /// [parquet doc](https://github.com/apache/parquet-format/blob/master/LogicalTypes.md). + fn to_map(&self) -> Result> { + match self.schema { + Type::PrimitiveType { .. } => Err(ParquetError::General(format!( + "{:?} is a map type and can't be processed as primitive.", + self.schema + ))), + Type::GroupType { + basic_info: _, + fields, + } if fields.len() == 1 => { + let key_item = fields.first().unwrap(); + + let (key_type, value_type) = match key_item.as_ref() { + Type::PrimitiveType { .. } => { + return Err(ArrowError( + "A map can only have a group child type (key_values)." + .to_string(), + )) + } + Type::GroupType { + basic_info: _, + fields, + } => { + if fields.len() != 2 { + return Err(ArrowError(format!("Map type should have 2 fields, a key and value. 
Found {} fields", fields.len()))); + } else { + let nested_key = fields.first().unwrap(); + let nested_key_converter = self.clone_with_schema(nested_key); + + let nested_value = fields.last().unwrap(); + let nested_value_converter = + self.clone_with_schema(nested_value); + + ( + nested_key_converter.to_data_type()?.map(|d| { + Field::new( + nested_key.name(), + d, + nested_key.is_optional(), + ) + }), + nested_value_converter.to_data_type()?.map(|d| { + Field::new( + nested_value.name(), + d, + nested_value.is_optional(), + ) + }), + ) + } + } + }; + + match (key_type, value_type) { + (Some(key), Some(value)) => Ok(Some(DataType::Map( + Box::new(Field::new( + key_item.name(), + DataType::Struct(vec![key, value]), + false, + )), + false, // There is no information to tell if keys are sorted + ))), + (None, None) => Ok(None), + (None, Some(_)) => Err(ArrowError( + "Could not convert the map key to a valid datatype".to_string(), + )), + (Some(_), None) => Err(ArrowError( + "Could not convert the map value to a valid datatype".to_string(), + )), + } + } + _ => Err(ArrowError( + "Group element type of map can only contain one field.".to_string(), + )), + } + } } #[cfg(test)] @@ -1311,6 +1425,122 @@ mod tests { } } + #[test] + fn test_parquet_maps() { + let mut arrow_fields = Vec::new(); + + // LIST encoding example taken from parquet-format/LogicalTypes.md + let message_type = " + message test_schema { + REQUIRED group my_map1 (MAP) { + REPEATED group key_value { + REQUIRED binary key (UTF8); + OPTIONAL int32 value; + } + } + OPTIONAL group my_map2 (MAP) { + REPEATED group map { + REQUIRED binary str (UTF8); + REQUIRED int32 num; + } + } + OPTIONAL group my_map3 (MAP_KEY_VALUE) { + REPEATED group map { + REQUIRED binary key (UTF8); + OPTIONAL int32 value; + } + } + } + "; + + // // Map + // required group my_map (MAP) { + // repeated group key_value { + // required binary key (UTF8); + // optional int32 value; + // } + // } + { + arrow_fields.push(Field::new( + 
"my_map1", + DataType::Map( + Box::new(Field::new( + "key_value", + DataType::Struct(vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Int32, true), + ]), + false, + )), + false, + ), + false, + )); + } + + // // Map (nullable map, non-null values) + // optional group my_map (MAP) { + // repeated group map { + // required binary str (UTF8); + // required int32 num; + // } + // } + { + arrow_fields.push(Field::new( + "my_map2", + DataType::Map( + Box::new(Field::new( + "map", + DataType::Struct(vec![ + Field::new("str", DataType::Utf8, false), + Field::new("num", DataType::Int32, false), + ]), + false, + )), + false, + ), + true, + )); + } + + // // Map (nullable map, nullable values) + // optional group my_map (MAP_KEY_VALUE) { + // repeated group map { + // required binary key (UTF8); + // optional int32 value; + // } + // } + { + arrow_fields.push(Field::new( + "my_map3", + DataType::Map( + Box::new(Field::new( + "map", + DataType::Struct(vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Int32, true), + ]), + false, + )), + false, + ), + true, + )); + } + + let parquet_group_type = parse_message_type(message_type).unwrap(); + + let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); + let converted_arrow_schema = + parquet_to_arrow_schema(&parquet_schema, &None).unwrap(); + let converted_fields = converted_arrow_schema.fields(); + + assert_eq!(arrow_fields.len(), converted_fields.len()); + for i in 0..arrow_fields.len() { + assert_eq!(arrow_fields[i], converted_fields[i]); + } + } + #[test] fn test_nested_schema() { let mut arrow_fields = Vec::new(); @@ -1843,6 +2073,52 @@ mod tests { Field::new("c36", DataType::Decimal(2, 1), false), Field::new("c37", DataType::Decimal(50, 20), false), Field::new("c38", DataType::Decimal(18, 12), true), + Field::new( + "c39", + DataType::Map( + Box::new(Field::new( + "key_value", + DataType::Struct(vec![ + Field::new("key", DataType::Utf8, false), 
+ Field::new( + "value", + DataType::List(Box::new(Field::new( + "element", + DataType::Utf8, + true, + ))), + true, + ), + ]), + false, + )), + false, // fails to roundtrip keys_sorted + ), + true, + ), + Field::new( + "c40", + DataType::Map( + Box::new(Field::new( + "my_entries", + DataType::Struct(vec![ + Field::new("my_key", DataType::Utf8, false), + Field::new( + "my_value", + DataType::List(Box::new(Field::new( + "item", + DataType::Utf8, + true, + ))), + true, + ), + ]), + false, + )), + false, // fails to roundtrip keys_sorted + ), + true, + ), ], metadata, );