From a38ca5a34b5bba76224956a6adaf9b09ce9fc735 Mon Sep 17 00:00:00 2001 From: Jefffrey <22608443+Jefffrey@users.noreply.github.com> Date: Sun, 29 Oct 2023 21:13:43 +1100 Subject: [PATCH 01/12] Support for read/write f16 Parquet to Arrow --- parquet/Cargo.toml | 1 + parquet/regen.sh | 2 +- .../array_reader/fixed_len_byte_array.rs | 54 ++--- parquet/src/arrow/arrow_reader/mod.rs | 48 +++++ parquet/src/arrow/arrow_writer/mod.rs | 19 ++ parquet/src/arrow/schema/mod.rs | 200 ++++++------------ parquet/src/arrow/schema/primitive.rs | 22 +- parquet/src/basic.rs | 15 +- parquet/src/file/statistics.rs | 4 + parquet/src/format.rs | 88 +++++++- parquet/src/schema/printer.rs | 1 + parquet/src/schema/types.rs | 1 + 12 files changed, 283 insertions(+), 172 deletions(-) diff --git a/parquet/Cargo.toml b/parquet/Cargo.toml index 659e2c0ee3a..ab4292e0e0e 100644 --- a/parquet/Cargo.toml +++ b/parquet/Cargo.toml @@ -66,6 +66,7 @@ tokio = { version = "1.0", optional = true, default-features = false, features = ["macros", "rt", "io-util"] } hashbrown = { version = "0.14", default-features = false } twox-hash = { version = "1.6", default-features = false } paste = { version = "1.0" } +half = { version = "2.1", default-features = false } [dev-dependencies] base64 = { version = "0.21", default-features = false, features = ["std"] } diff --git a/parquet/regen.sh b/parquet/regen.sh index b8c3549e232..91539634339 100755 --- a/parquet/regen.sh +++ b/parquet/regen.sh @@ -17,7 +17,7 @@ # specific language governing permissions and limitations # under the License. -REVISION=aeae80660c1d0c97314e9da837de1abdebd49c37 +REVISION=46cc3a0647d301bb9579ca8dd2cc356caf2a72d2 SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]:-$0}")" && pwd)" diff --git a/parquet/src/arrow/array_reader/fixed_len_byte_array.rs b/parquet/src/arrow/array_reader/fixed_len_byte_array.rs index b06091b6b57..07213083390 100644 --- a/parquet/src/arrow/array_reader/fixed_len_byte_array.rs +++ b/parquet/src/arrow/array_reader/fixed_len_byte_array.rs @@ -28,12 +28,13 @@ use crate::errors::{ParquetError, Result}; use crate::schema::types::ColumnDescPtr; use crate::util::memory::ByteBufferPtr; use arrow_array::{ - ArrayRef, Decimal128Array, Decimal256Array, FixedSizeBinaryArray, + ArrayRef, Decimal128Array, Decimal256Array, FixedSizeBinaryArray, Float16Array, IntervalDayTimeArray, IntervalYearMonthArray, }; use arrow_buffer::{i256, Buffer}; use arrow_data::ArrayDataBuilder; use arrow_schema::{DataType as ArrowType, IntervalUnit}; +use half::f16; use std::any::Any; use std::ops::Range; use std::sync::Arc; @@ -88,6 +89,14 @@ pub fn make_fixed_len_byte_array_reader( )); } } + ArrowType::Float16 => { + if byte_length != 2 { + return Err(general_err!( + "float 16 type must be 2 bytes, got {}", + byte_length + )); + } + } _ => { return Err(general_err!( "invalid data type for fixed length byte array reader - {}", @@ -153,11 +162,10 @@ impl ArrayReader for FixedLenByteArrayReader { fn consume_batch(&mut self) -> Result<ArrayRef> { let record_data = self.record_reader.consume_record_data(); - let array_data = - ArrayDataBuilder::new(ArrowType::FixedSizeBinary(self.byte_length as i32)) - .len(self.record_reader.num_values()) - .add_buffer(record_data) - .null_bit_buffer(self.record_reader.consume_bitmap_buffer()); + let array_data = ArrayDataBuilder::new(ArrowType::FixedSizeBinary(self.byte_length as i32)) + .len(self.record_reader.num_values()) + .add_buffer(record_data) + .null_bit_buffer(self.record_reader.consume_bitmap_buffer()); let binary = FixedSizeBinaryArray::from(unsafe { array_data.build_unchecked() });
@@ -188,19 +196,13 @@ impl ArrayReader for FixedLenByteArrayReader { IntervalUnit::YearMonth => Arc::new( binary .iter() - .map(|o| { - o.map(|b| i32::from_le_bytes(b[0..4].try_into().unwrap())) - }) + .map(|o| o.map(|b| i32::from_le_bytes(b[0..4].try_into().unwrap()))) .collect::<IntervalYearMonthArray>(), ) as ArrayRef, IntervalUnit::DayTime => Arc::new( binary .iter() - .map(|o| { - o.map(|b| { - i64::from_le_bytes(b[4..12].try_into().unwrap()) - }) - }) + .map(|o| o.map(|b| i64::from_le_bytes(b[4..12].try_into().unwrap()))) .collect::<IntervalDayTimeArray>(), ) as ArrayRef, IntervalUnit::MonthDayNano => { @@ -208,6 +210,12 @@ impl ArrayReader for FixedLenByteArrayReader { } } } + ArrowType::Float16 => Arc::new( + binary + .iter() + .map(|o| o.map(|b| f16::from_le_bytes(b[..2].try_into().unwrap()))) + .collect::<Float16Array>(), + ) as ArrayRef, _ => Arc::new(binary) as ArrayRef, }; @@ -278,9 +286,7 @@ impl ValuesBuffer for FixedLenByteArrayBuffer { let slice = self.buffer.as_slice_mut(); let values_range = read_offset..read_offset + values_read; - for (value_pos, level_pos) in - values_range.rev().zip(iter_set_bits_rev(valid_mask)) - { + for (value_pos, level_pos) in values_range.rev().zip(iter_set_bits_rev(valid_mask)) { debug_assert!(level_pos >= value_pos); if level_pos <= value_pos { break; } @@ -376,8 +382,7 @@ impl ColumnValueDecoder for ValueDecoder { let len = range.end - range.start; match self.decoder.as_mut().unwrap() { Decoder::Plain { offset, buf } => { - let to_read = - (len * self.byte_length).min(buf.len() - *offset) / self.byte_length; + let to_read = (len * self.byte_length).min(buf.len() - *offset) / self.byte_length; let end_offset = *offset + to_read * self.byte_length; out.buffer .extend_from_slice(&buf.as_ref()[*offset..end_offset]); @@ -470,15 +475,12 @@ mod tests { .build() .unwrap(); - let written = RecordBatch::try_from_iter([( - "list", - Arc::new(ListArray::from(data)) as ArrayRef, - )]) - .unwrap(); + let written = + RecordBatch::try_from_iter([("list", Arc::new(ListArray::from(data)) as ArrayRef)]) + .unwrap(); let mut buffer = Vec::with_capacity(1024); - let mut writer = - ArrowWriter::try_new(&mut buffer, written.schema(), None).unwrap(); + let mut writer = ArrowWriter::try_new(&mut buffer, written.schema(), None).unwrap(); writer.write(&written).unwrap(); writer.close().unwrap(); diff --git a/parquet/src/arrow/arrow_reader/mod.rs b/parquet/src/arrow/arrow_reader/mod.rs index 16cdf2934e6..9ed63086c65 100644 --- a/parquet/src/arrow/arrow_reader/mod.rs +++ b/parquet/src/arrow/arrow_reader/mod.rs @@ -712,6 +712,7 @@ mod tests { use std::sync::Arc; use bytes::Bytes; + use half::f16; use num::PrimInt; use rand::{thread_rng, Rng, RngCore}; use tempfile::tempfile; @@ -924,6 +925,53 @@ mod tests { .unwrap(); } + #[test] + fn test_float16_roundtrip() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new( + "float16", + ArrowDataType::Float16, + true, + )])); + + let mut buf = Vec::with_capacity(1024); + let mut writer = ArrowWriter::try_new(&mut buf, schema.clone(), None)?; + + let original = RecordBatch::try_new( + schema, + vec![Arc::new(Float16Array::from_iter_values([ + f16::EPSILON, + f16::INFINITY, + f16::MIN, + f16::MAX, + f16::NAN, + f16::INFINITY, + f16::NEG_INFINITY, + f16::ONE, + f16::NEG_ONE, + f16::ZERO, + f16::NEG_ZERO, + f16::E, + f16::PI, + f16::FRAC_1_PI, + ]))], + )?; + + writer.write(&original)?; + writer.close()?; + + let mut reader = ParquetRecordBatchReader::try_new(Bytes::from(buf), 1024)?; + let ret = reader.next().unwrap()?; + assert_eq!(ret, original); + + // Ensure can be downcast
to the correct type + ret.column(0) + .as_any() + .downcast_ref::<Float16Array>() + .unwrap(); + + Ok(()) + } + struct RandFixedLenGen {} impl RandGen<FixedLenByteArrayType> for RandFixedLenGen { diff --git a/parquet/src/arrow/arrow_writer/mod.rs b/parquet/src/arrow/arrow_writer/mod.rs index a9cd1afb247..df218f354d9 100644 --- a/parquet/src/arrow/arrow_writer/mod.rs +++ b/parquet/src/arrow/arrow_writer/mod.rs @@ -771,6 +771,13 @@ fn write_leaf(writer: &mut ColumnWriter<'_>, levels: &ArrayLevels) -> Result<usize> { + ArrowDataType::Float16 => { + let array = column + .as_any() + .downcast_ref::<Float16Array>() + .unwrap(); + get_float_16_array_slice(array, indices) + } _ => { return Err(ParquetError::NYI( "Attempting to write an Arrow type that is not yet implemented".to_string(), )) } @@ -867,6 +874,18 @@ fn get_decimal_256_array_slice( values } +fn get_float_16_array_slice( + array: &arrow_array::Float16Array, + indices: &[usize], +) -> Vec<FixedLenByteArray> { + let mut values = Vec::with_capacity(indices.len()); + for i in indices { + let value = array.value(*i).to_le_bytes().to_vec(); + values.push(FixedLenByteArray::from(ByteArray::from(value))); + } + values +} + fn get_fsb_array_slice( array: &arrow_array::FixedSizeBinaryArray, indices: &[usize], diff --git a/parquet/src/arrow/schema/mod.rs b/parquet/src/arrow/schema/mod.rs index d56cc42d431..1b44c012308 100644 --- a/parquet/src/arrow/schema/mod.rs +++ b/parquet/src/arrow/schema/mod.rs @@ -32,8 +32,7 @@ use arrow_ipc::writer; use arrow_schema::{DataType, Field, Fields, Schema, TimeUnit}; use crate::basic::{ - ConvertedType, LogicalType, Repetition, TimeUnit as ParquetTimeUnit, - Type as PhysicalType, + ConvertedType, LogicalType, Repetition, TimeUnit as ParquetTimeUnit, Type as PhysicalType, }; use crate::errors::{ParquetError, Result}; use crate::file::{metadata::KeyValue, properties::WriterProperties}; @@ -55,11 +54,7 @@ pub fn parquet_to_arrow_schema( parquet_schema: &SchemaDescriptor, key_value_metadata: Option<&Vec<KeyValue>>, ) -> Result<Schema> { - parquet_to_arrow_schema_by_columns( - parquet_schema, - ProjectionMask::all(), - key_value_metadata, - ) + parquet_to_arrow_schema_by_columns(parquet_schema, ProjectionMask::all(), key_value_metadata) } /// Convert parquet schema to arrow schema including optional metadata, @@ -199,10 +194,7 @@ fn encode_arrow_schema(schema: &Schema) -> String { /// Mutates writer metadata by storing the encoded Arrow schema. /// If there is an existing Arrow schema metadata, it is replaced. -pub(crate) fn add_encoded_arrow_schema_to_metadata( - schema: &Schema, - props: &mut WriterProperties, -) { +pub(crate) fn add_encoded_arrow_schema_to_metadata(schema: &Schema, props: &mut WriterProperties) { let encoded = encode_arrow_schema(schema); let schema_kv = KeyValue { @@ -270,16 +262,15 @@ fn parse_key_value_metadata( /// Convert parquet column schema to arrow field.
pub fn parquet_to_arrow_field(parquet_column: &ColumnDescriptor) -> Result<Field> { let field = complex::convert_type(&parquet_column.self_type_ptr())?; - let mut ret = Field::new( - parquet_column.name(), - field.arrow_type, - field.nullable, - ); + let mut ret = Field::new(parquet_column.name(), field.arrow_type, field.nullable); let basic_info = parquet_column.self_type().get_basic_info(); if basic_info.has_id() { let mut meta = HashMap::with_capacity(1); - meta.insert(PARQUET_FIELD_ID_META_KEY.to_string(), basic_info.id().to_string()); + meta.insert( + PARQUET_FIELD_ID_META_KEY.to_string(), + basic_info.id().to_string(), + ); ret.set_metadata(meta); } @@ -373,7 +364,12 @@ fn arrow_to_parquet_type(field: &Field) -> Result<Type> { .with_repetition(repetition) .with_id(id) .build(), - DataType::Float16 => Err(arrow_err!("Float16 arrays not supported")), + DataType::Float16 => Type::primitive_type_builder(name, PhysicalType::FIXED_LEN_BYTE_ARRAY) + .with_repetition(repetition) + .with_id(id) + .with_logical_type(Some(LogicalType::Float16)) + .with_length(2) + .build(), DataType::Float32 => Type::primitive_type_builder(name, PhysicalType::FLOAT) .with_repetition(repetition) .with_id(id) .build(), @@ -396,15 +392,9 @@ fn arrow_to_parquet_type(field: &Field) -> Result<Type> { is_adjusted_to_u_t_c: matches!(tz, Some(z) if !z.as_ref().is_empty()), unit: match time_unit { TimeUnit::Second => unreachable!(), - TimeUnit::Millisecond => { - ParquetTimeUnit::MILLIS(Default::default()) - } - TimeUnit::Microsecond => { - ParquetTimeUnit::MICROS(Default::default()) - } - TimeUnit::Nanosecond => { - ParquetTimeUnit::NANOS(Default::default()) - } + TimeUnit::Millisecond => ParquetTimeUnit::MILLIS(Default::default()), + TimeUnit::Microsecond => ParquetTimeUnit::MICROS(Default::default()), + TimeUnit::Nanosecond => ParquetTimeUnit::NANOS(Default::default()), }, })) .with_repetition(repetition) @@ -452,9 +442,7 @@ fn arrow_to_parquet_type(field: &Field) -> Result<Type> { .with_repetition(repetition) .with_id(id) .build(), - DataType::Duration(_) => { - Err(arrow_err!("Converting Duration to parquet not supported",)) - } + DataType::Duration(_) => Err(arrow_err!("Converting Duration to parquet not supported",)), DataType::Interval(_) => { Type::primitive_type_builder(name, PhysicalType::FIXED_LEN_BYTE_ARRAY) .with_converted_type(ConvertedType::INTERVAL) @@ -476,8 +464,7 @@ fn arrow_to_parquet_type(field: &Field) -> Result<Type> { .with_length(*length) .build() } - DataType::Decimal128(precision, scale) - | DataType::Decimal256(precision, scale) => { + DataType::Decimal128(precision, scale) | DataType::Decimal256(precision, scale) => { // Decimal precision determines the Parquet physical type to use.
// Following the: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#decimal let (physical_type, length) = if *precision > 1 && *precision <= 9 { @@ -524,9 +511,7 @@ fn arrow_to_parquet_type(field: &Field) -> Result { } DataType::Struct(fields) => { if fields.is_empty() { - return Err( - arrow_err!("Parquet does not support writing empty structs",), - ); + return Err(arrow_err!("Parquet does not support writing empty structs",)); } // recursively convert children to types/nodes let fields = fields @@ -604,9 +589,10 @@ mod tests { REQUIRED INT32 uint8 (INTEGER(8,false)); REQUIRED INT32 uint16 (INTEGER(16,false)); REQUIRED INT32 int32; - REQUIRED INT64 int64 ; + REQUIRED INT64 int64; OPTIONAL DOUBLE double; OPTIONAL FLOAT float; + OPTIONAL FIXED_LEN_BYTE_ARRAY (2) float16 (FLOAT16); OPTIONAL BINARY string (UTF8); OPTIONAL BINARY string_2 (STRING); OPTIONAL BINARY json (JSON); @@ -615,8 +601,7 @@ mod tests { let parquet_group_type = parse_message_type(message_type).unwrap(); let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); - let converted_arrow_schema = - parquet_to_arrow_schema(&parquet_schema, None).unwrap(); + let converted_arrow_schema = parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let arrow_fields = Fields::from(vec![ Field::new("boolean", DataType::Boolean, false), @@ -628,6 +613,7 @@ mod tests { Field::new("int64", DataType::Int64, false), Field::new("double", DataType::Float64, true), Field::new("float", DataType::Float32, true), + Field::new("float16", DataType::Float16, true), Field::new("string", DataType::Utf8, true), Field::new("string_2", DataType::Utf8, true), Field::new("json", DataType::Utf8, true), @@ -653,8 +639,7 @@ mod tests { let parquet_group_type = parse_message_type(message_type).unwrap(); let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); - let converted_arrow_schema = - parquet_to_arrow_schema(&parquet_schema, None).unwrap(); + let converted_arrow_schema = parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let arrow_fields = Fields::from(vec![ Field::new("decimal1", DataType::Decimal128(4, 2), false), @@ -680,8 +665,7 @@ mod tests { let parquet_group_type = parse_message_type(message_type).unwrap(); let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); - let converted_arrow_schema = - parquet_to_arrow_schema(&parquet_schema, None).unwrap(); + let converted_arrow_schema = parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let arrow_fields = Fields::from(vec![ Field::new("binary", DataType::Binary, false), @@ -702,8 +686,7 @@ mod tests { let parquet_group_type = parse_message_type(message_type).unwrap(); let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); - let converted_arrow_schema = - parquet_to_arrow_schema(&parquet_schema, None).unwrap(); + let converted_arrow_schema = parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let arrow_fields = Fields::from(vec![ Field::new("boolean", DataType::Boolean, false), @@ -711,12 +694,9 @@ mod tests { ]); assert_eq!(&arrow_fields, converted_arrow_schema.fields()); - let converted_arrow_schema = parquet_to_arrow_schema_by_columns( - &parquet_schema, - ProjectionMask::all(), - None, - ) - .unwrap(); + let converted_arrow_schema = + parquet_to_arrow_schema_by_columns(&parquet_schema, ProjectionMask::all(), None) + .unwrap(); assert_eq!(&arrow_fields, converted_arrow_schema.fields()); } @@ -914,8 +894,7 @@ mod tests { let parquet_group_type = parse_message_type(message_type).unwrap(); let 
parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); - let converted_arrow_schema = - parquet_to_arrow_schema(&parquet_schema, None).unwrap(); + let converted_arrow_schema = parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let converted_fields = converted_arrow_schema.fields(); assert_eq!(arrow_fields.len(), converted_fields.len()); @@ -993,8 +972,7 @@ mod tests { let parquet_group_type = parse_message_type(message_type).unwrap(); let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); - let converted_arrow_schema = - parquet_to_arrow_schema(&parquet_schema, None).unwrap(); + let converted_arrow_schema = parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let converted_fields = converted_arrow_schema.fields(); assert_eq!(arrow_fields.len(), converted_fields.len()); @@ -1088,8 +1066,7 @@ mod tests { let parquet_group_type = parse_message_type(message_type).unwrap(); let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); - let converted_arrow_schema = - parquet_to_arrow_schema(&parquet_schema, None).unwrap(); + let converted_arrow_schema = parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let converted_fields = converted_arrow_schema.fields(); assert_eq!(arrow_fields.len(), converted_fields.len()); @@ -1106,8 +1083,7 @@ mod tests { Field::new("leaf1", DataType::Boolean, false), Field::new("leaf2", DataType::Int32, false), ]); - let group1_struct = - Field::new("group1", DataType::Struct(group1_fields), false); + let group1_struct = Field::new("group1", DataType::Struct(group1_fields), false); arrow_fields.push(group1_struct); let leaf3_field = Field::new("leaf3", DataType::Int64, false); @@ -1126,8 +1102,7 @@ mod tests { let parquet_group_type = parse_message_type(message_type).unwrap(); let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); - let converted_arrow_schema = - parquet_to_arrow_schema(&parquet_schema, None).unwrap(); + let converted_arrow_schema = parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let converted_fields = converted_arrow_schema.fields(); assert_eq!(arrow_fields.len(), converted_fields.len()); @@ -1280,8 +1255,7 @@ mod tests { let parquet_group_type = parse_message_type(message_type).unwrap(); let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); - let converted_arrow_schema = - parquet_to_arrow_schema(&parquet_schema, None).unwrap(); + let converted_arrow_schema = parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let converted_fields = converted_arrow_schema.fields(); assert_eq!(arrow_fields.len(), converted_fields.len()); @@ -1303,6 +1277,7 @@ mod tests { REQUIRED INT64 int64; OPTIONAL DOUBLE double; OPTIONAL FLOAT float; + OPTIONAL FIXED_LEN_BYTE_ARRAY (2) float16 (FLOAT16); OPTIONAL BINARY string (UTF8); REPEATED BOOLEAN bools; OPTIONAL INT32 date (DATE); @@ -1339,6 +1314,7 @@ mod tests { Field::new("int64", DataType::Int64, false), Field::new("double", DataType::Float64, true), Field::new("float", DataType::Float32, true), + Field::new("float16", DataType::Float16, true), Field::new("string", DataType::Utf8, true), Field::new_list( "bools", @@ -1398,6 +1374,7 @@ mod tests { REQUIRED INT64 int64; OPTIONAL DOUBLE double; OPTIONAL FLOAT float; + OPTIONAL FIXED_LEN_BYTE_ARRAY (2) float16 (FLOAT16); OPTIONAL BINARY string (STRING); OPTIONAL GROUP bools (LIST) { REPEATED GROUP list { @@ -1448,6 +1425,7 @@ mod tests { Field::new("int64", DataType::Int64, false), Field::new("double", DataType::Float64, true), Field::new("float", 
DataType::Float32, true), + Field::new("float16", DataType::Float16, true), Field::new("string", DataType::Utf8, true), Field::new_list( "bools", @@ -1502,20 +1480,11 @@ mod tests { vec![ Field::new("bools", DataType::Boolean, false), Field::new("uint32", DataType::UInt32, false), - Field::new_list( - "int32", - Field::new("element", DataType::Int32, true), - false, - ), + Field::new_list("int32", Field::new("element", DataType::Int32, true), false), ], false, ), - Field::new_dictionary( - "dictionary_strings", - DataType::Int32, - DataType::Utf8, - false, - ), + Field::new_dictionary("dictionary_strings", DataType::Int32, DataType::Utf8, false), Field::new("decimal_int32", DataType::Decimal128(8, 2), false), Field::new("decimal_int64", DataType::Decimal128(16, 2), false), Field::new("decimal_fix_length", DataType::Decimal128(30, 2), false), @@ -1600,10 +1569,8 @@ mod tests { let schema = Schema::new_with_metadata( vec![ - Field::new("c1", DataType::Utf8, false).with_metadata(meta(&[ - ("Key", "Foo"), - (PARQUET_FIELD_ID_META_KEY, "2"), - ])), + Field::new("c1", DataType::Utf8, false) + .with_metadata(meta(&[("Key", "Foo"), (PARQUET_FIELD_ID_META_KEY, "2")])), Field::new("c2", DataType::Binary, false), Field::new("c3", DataType::FixedSizeBinary(3), false), Field::new("c4", DataType::Boolean, false), @@ -1621,10 +1588,7 @@ mod tests { ), Field::new( "c17", - DataType::Timestamp( - TimeUnit::Microsecond, - Some("Africa/Johannesburg".into()), - ), + DataType::Timestamp(TimeUnit::Microsecond, Some("Africa/Johannesburg".into())), false, ), Field::new( @@ -1636,10 +1600,8 @@ mod tests { Field::new("c20", DataType::Interval(IntervalUnit::YearMonth), false), Field::new_list( "c21", - Field::new("item", DataType::Boolean, true).with_metadata(meta(&[ - ("Key", "Bar"), - (PARQUET_FIELD_ID_META_KEY, "5"), - ])), + Field::new("item", DataType::Boolean, true) + .with_metadata(meta(&[("Key", "Bar"), (PARQUET_FIELD_ID_META_KEY, "5")])), false, ) .with_metadata(meta(&[(PARQUET_FIELD_ID_META_KEY, "4")])), @@ -1661,6 +1623,8 @@ mod tests { vec![ Field::new("a", DataType::Int16, true), Field::new("b", DataType::Float64, false), + Field::new("c", DataType::Float32, false), + Field::new("d", DataType::Float16, false), ] .into(), ), @@ -1687,10 +1651,7 @@ mod tests { // Field::new("c30", DataType::Duration(TimeUnit::Nanosecond), false), Field::new_dict( "c31", - DataType::Dictionary( - Box::new(DataType::Int32), - Box::new(DataType::Utf8), - ), + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), true, 123, true, @@ -1725,11 +1686,7 @@ mod tests { "c39", "key_value", Field::new("key", DataType::Utf8, false), - Field::new_list( - "value", - Field::new("element", DataType::Utf8, true), - true, - ), + Field::new_list("value", Field::new("element", DataType::Utf8, true), true), false, // fails to roundtrip keys_sorted true, ), @@ -1768,11 +1725,8 @@ mod tests { // write to an empty parquet file so that schema is serialized let file = tempfile::tempfile().unwrap(); - let writer = ArrowWriter::try_new( - file.try_clone().unwrap(), - Arc::new(schema.clone()), - None, - )?; + let writer = + ArrowWriter::try_new(file.try_clone().unwrap(), Arc::new(schema.clone()), None)?; writer.close()?; // read file back @@ -1831,33 +1785,23 @@ mod tests { }; let schema = Schema::new_with_metadata( vec![ - Field::new("c1", DataType::Utf8, true).with_metadata(meta(&[ - (PARQUET_FIELD_ID_META_KEY, "1"), - ])), - Field::new("c2", DataType::Utf8, true).with_metadata(meta(&[ - (PARQUET_FIELD_ID_META_KEY, "2"), - ])), 
+ Field::new("c1", DataType::Utf8, true) + .with_metadata(meta(&[(PARQUET_FIELD_ID_META_KEY, "1")])), + Field::new("c2", DataType::Utf8, true) + .with_metadata(meta(&[(PARQUET_FIELD_ID_META_KEY, "2")])), ], HashMap::new(), ); - let writer = ArrowWriter::try_new( - vec![], - Arc::new(schema.clone()), - None, - )?; + let writer = ArrowWriter::try_new(vec![], Arc::new(schema.clone()), None)?; let parquet_bytes = writer.into_inner()?; - let reader = crate::file::reader::SerializedFileReader::new( - bytes::Bytes::from(parquet_bytes), - )?; + let reader = + crate::file::reader::SerializedFileReader::new(bytes::Bytes::from(parquet_bytes))?; let schema_descriptor = reader.metadata().file_metadata().schema_descr_ptr(); // don't pass metadata so field ids are read from Parquet and not from serialized Arrow schema - let arrow_schema = crate::arrow::parquet_to_arrow_schema( - &schema_descriptor, - None, - )?; + let arrow_schema = crate::arrow::parquet_to_arrow_schema(&schema_descriptor, None)?; let parq_schema_descr = crate::arrow::arrow_to_parquet_schema(&arrow_schema)?; let parq_fields = parq_schema_descr.root_schema().get_fields(); @@ -1870,19 +1814,14 @@ mod tests { #[test] fn test_arrow_schema_roundtrip_lists() -> Result<()> { - let metadata: HashMap<String, String> = - [("Key".to_string(), "Value".to_string())] - .iter() - .cloned() - .collect(); + let metadata: HashMap<String, String> = [("Key".to_string(), "Value".to_string())] + .iter() + .cloned() + .collect(); let schema = Schema::new_with_metadata( vec![ - Field::new_list( - "c21", - Field::new("array", DataType::Boolean, true), - false, - ), + Field::new_list("c21", Field::new("array", DataType::Boolean, true), false), Field::new( "c22", DataType::FixedSizeList( ... @@ -1913,11 +1852,8 @@ mod tests { // write to an empty parquet file so that schema is serialized let file = tempfile::tempfile().unwrap(); - let writer = ArrowWriter::try_new( - file.try_clone().unwrap(), - Arc::new(schema.clone()), - None, - )?; + let writer = + ArrowWriter::try_new(file.try_clone().unwrap(), Arc::new(schema.clone()), None)?; writer.close()?; // read file back diff --git a/parquet/src/arrow/schema/primitive.rs b/parquet/src/arrow/schema/primitive.rs index 7d8b6a04ee8..447fe5fc3ab 100644 --- a/parquet/src/arrow/schema/primitive.rs +++ b/parquet/src/arrow/schema/primitive.rs @@ -15,9 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::basic::{ - ConvertedType, LogicalType, TimeUnit as ParquetTimeUnit, Type as PhysicalType, -}; +use crate::basic::{ConvertedType, LogicalType, TimeUnit as ParquetTimeUnit, Type as PhysicalType}; use crate::errors::{ParquetError, Result}; use crate::schema::types::{BasicTypeInfo, Type}; use arrow_schema::{DataType, IntervalUnit, TimeUnit, DECIMAL128_MAX_PRECISION}; @@ -158,9 +156,7 @@ fn from_int32(info: &BasicTypeInfo, scale: i32, precision: i32) -> Result<DataType> { (32, false) => Ok(DataType::UInt32), _ => Err(arrow_err!("Cannot create INT32 physical type from {:?}", t)), }, - (Some(LogicalType::Decimal { scale, precision }), _) => { - decimal_128_type(scale, precision) - } + (Some(LogicalType::Decimal { scale, precision }), _) => decimal_128_type(scale, precision), (Some(LogicalType::Date), _) => Ok(DataType::Date32), (Some(LogicalType::Time { unit, ..
}), _) => match unit { ParquetTimeUnit::MILLIS(_) => Ok(DataType::Time32(TimeUnit::Millisecond)), @@ -237,9 +233,7 @@ fn from_int64(info: &BasicTypeInfo, scale: i32, precision: i32) -> Result<DataType> { - (Some(LogicalType::Decimal { scale, precision }), _) => { - decimal_128_type(scale, precision) - } + (Some(LogicalType::Decimal { scale, precision }), _) => decimal_128_type(scale, precision), (None, ConvertedType::DECIMAL) => decimal_128_type(scale, precision), (logical, converted) => Err(arrow_err!( "Unable to convert parquet INT64 logical type {:?} or converted type {}", @@ -304,6 +298,16 @@ fn from_fixed_len_byte_array( // would be incorrect if all 12 bytes of the interval are populated Ok(DataType::Interval(IntervalUnit::DayTime)) } + (Some(LogicalType::Float16), _) => { + if type_length == 2 { + Ok(DataType::Float16) + } else { + Err(ParquetError::General( + "FLOAT16 logical type must be Fixed Length Byte Array with length 2" + .to_string(), + )) + } + } _ => Ok(DataType::FixedSizeBinary(type_length)), } } diff --git a/parquet/src/basic.rs b/parquet/src/basic.rs index ab71aa44169..4c9cb018d4e 100644 --- a/parquet/src/basic.rs +++ b/parquet/src/basic.rs @@ -194,6 +194,7 @@ pub enum LogicalType { Json, Bson, Uuid, + Float16, } // ---------------------------------------------------------------------- @@ -478,6 +479,7 @@ impl ColumnOrder { LogicalType::Timestamp { .. } => SortOrder::SIGNED, LogicalType::Unknown => SortOrder::UNDEFINED, LogicalType::Uuid => SortOrder::UNSIGNED, + LogicalType::Float16 => SortOrder::SIGNED, }, // Fall back to converted type None => Self::get_converted_sort_order(converted_type, physical_type), @@ -739,6 +741,7 @@ impl From<parquet::LogicalType> for LogicalType { parquet::LogicalType::JSON(_) => LogicalType::Json, parquet::LogicalType::BSON(_) => LogicalType::Bson, parquet::LogicalType::UUID(_) => LogicalType::Uuid, + parquet::LogicalType::FLOAT16(_) => LogicalType::Float16, } } } @@ -779,6 +782,7 @@ impl From<LogicalType> for parquet::LogicalType { LogicalType::Json => parquet::LogicalType::JSON(Default::default()), LogicalType::Bson => parquet::LogicalType::BSON(Default::default()), LogicalType::Uuid => parquet::LogicalType::UUID(Default::default()), + LogicalType::Float16 => parquet::LogicalType::FLOAT16(Default::default()), } } } @@ -826,10 +830,11 @@ impl From<Option<LogicalType>> for ConvertedType { (64, false) => ConvertedType::UINT_64, t => panic!("Integer type {t:?} is not supported"), }, - LogicalType::Unknown => ConvertedType::NONE, LogicalType::Json => ConvertedType::JSON, LogicalType::Bson => ConvertedType::BSON, - LogicalType::Uuid => ConvertedType::NONE, + LogicalType::Uuid | LogicalType::Float16 | LogicalType::Unknown => { + ConvertedType::NONE + } }, None => ConvertedType::NONE, } @@ -1075,6 +1080,7 @@ impl str::FromStr for LogicalType { "INTERVAL" => Err(general_err!( "Interval parquet logical type not yet supported" )), + "FLOAT16" => Ok(LogicalType::Float16), other => Err(general_err!("Invalid parquet logical type {}", other)), } } } @@ -1719,6 +1725,10 @@ mod tests { ConvertedType::from(Some(LogicalType::Enum)), ConvertedType::ENUM ); + assert_eq!( + ConvertedType::from(Some(LogicalType::Float16)), + ConvertedType::NONE + ); assert_eq!( ConvertedType::from(Some(LogicalType::Unknown)), ConvertedType::NONE ); @@ -2092,6 +2102,7 @@ mod tests { is_adjusted_to_u_t_c: true, unit: TimeUnit::NANOS(Default::default()), }, + LogicalType::Float16, ]; check_sort_order(signed, SortOrder::SIGNED); diff --git a/parquet/src/file/statistics.rs b/parquet/src/file/statistics.rs index b36e37a80c9..345fe7dd261 100644 --- a/parquet/src/file/statistics.rs +++
b/parquet/src/file/statistics.rs @@ -243,6 +243,8 @@ pub fn to_thrift(stats: Option<&Statistics>) -> Option<TStatistics> { distinct_count: stats.distinct_count().map(|value| value as i64), max_value: None, min_value: None, + is_max_value_exact: None, + is_min_value_exact: None, }; // Get min/max if set. @@ -607,6 +609,8 @@ mod tests { distinct_count: None, max_value: None, min_value: None, + is_max_value_exact: None, + is_min_value_exact: None, }; from_thrift(Type::INT32, Some(thrift_stats)).unwrap(); diff --git a/parquet/src/format.rs b/parquet/src/format.rs index 46adc39e640..4700b05dc28 100644 --- a/parquet/src/format.rs +++ b/parquet/src/format.rs @@ -657,16 +657,26 @@ pub struct Statistics { pub null_count: Option<i64>, /// count of distinct values occurring pub distinct_count: Option<i64>, - /// Min and max values for the column, determined by its ColumnOrder. + /// Lower and upper bound values for the column, determined by its ColumnOrder. + /// + /// These may be the actual minimum and maximum values found on a page or column + /// chunk, but can also be (more compact) values that do not exist on a page or + /// column chunk. For example, instead of storing "Blart Versenwald III", a writer + /// may set min_value="B", max_value="C". Such more compact values must still be + /// valid values within the column's logical type. /// /// Values are encoded using PLAIN encoding, except that variable-length byte /// arrays do not include a length prefix. pub max_value: Option<Vec<u8>>, pub min_value: Option<Vec<u8>>, + /// If true, max_value is the actual maximum value for a column + pub is_max_value_exact: Option<bool>, + /// If true, min_value is the actual minimum value for a column + pub is_min_value_exact: Option<bool>, } impl Statistics { - pub fn new<F1, F2, F3, F4, F5, F6>(max: F1, min: F2, null_count: F3, distinct_count: F4, max_value: F5, min_value: F6) -> Statistics where F1: Into<Option<Vec<u8>>>, F2: Into<Option<Vec<u8>>>, F3: Into<Option<i64>>, F4: Into<Option<i64>>, F5: Into<Option<Vec<u8>>>, F6: Into<Option<Vec<u8>>> { + pub fn new<F1, F2, F3, F4, F5, F6, F7, F8>(max: F1, min: F2, null_count: F3, distinct_count: F4, max_value: F5, min_value: F6, is_max_value_exact: F7, is_min_value_exact: F8) -> Statistics where F1: Into<Option<Vec<u8>>>, F2: Into<Option<Vec<u8>>>, F3: Into<Option<i64>>, F4: Into<Option<i64>>, F5: Into<Option<Vec<u8>>>, F6: Into<Option<Vec<u8>>>, F7: Into<Option<bool>>, F8: Into<Option<bool>> { Statistics { max: max.into(), min: min.into(), @@ -674,6 +684,8 @@ impl Statistics { distinct_count: distinct_count.into(), max_value: max_value.into(), min_value: min_value.into(), + is_max_value_exact: is_max_value_exact.into(), + is_min_value_exact: is_min_value_exact.into(), } } } @@ -687,6 +699,8 @@ impl crate::thrift::TSerializable for Statistics { let mut f_4: Option<i64> = None; let mut f_5: Option<Vec<u8>> = None; let mut f_6: Option<Vec<u8>> = None; + let mut f_7: Option<bool> = None; + let mut f_8: Option<bool> = None; loop { let field_ident = i_prot.read_field_begin()?; if field_ident.field_type == TType::Stop { break; } @@ -718,6 +732,14 @@ impl crate::thrift::TSerializable for Statistics { let val = i_prot.read_bytes()?; f_6 = Some(val); }, + 7 => { + let val = i_prot.read_bool()?; + f_7 = Some(val); + }, + 8 => { + let val = i_prot.read_bool()?; + f_8 = Some(val); + }, _ => { i_prot.skip(field_ident.field_type)?; }, @@ -732,6 +754,8 @@ impl crate::thrift::TSerializable for Statistics { distinct_count: f_4, max_value: f_5, min_value: f_6, + is_max_value_exact: f_7, + is_min_value_exact: f_8, }; Ok(ret) } @@ -768,6 +792,16 @@ impl crate::thrift::TSerializable for Statistics { o_prot.write_bytes(fld_var)?; o_prot.write_field_end()?
} + if let Some(fld_var) = self.is_max_value_exact { + o_prot.write_field_begin(&TFieldIdentifier::new("is_max_value_exact", TType::Bool, 7))?; + o_prot.write_bool(fld_var)?; + o_prot.write_field_end()? + } + if let Some(fld_var) = self.is_min_value_exact { + o_prot.write_field_begin(&TFieldIdentifier::new("is_min_value_exact", TType::Bool, 8))?; + o_prot.write_bool(fld_var)?; + o_prot.write_field_end()? + } o_prot.write_field_stop()?; o_prot.write_struct_end() } @@ -996,6 +1030,43 @@ impl crate::thrift::TSerializable for DateType { } } +// +// Float16Type +// + +#[derive(Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] +pub struct Float16Type { +} + +impl Float16Type { + pub fn new() -> Float16Type { + Float16Type {} + } +} + +impl crate::thrift::TSerializable for Float16Type { + fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<Float16Type> { + i_prot.read_struct_begin()?; + loop { + let field_ident = i_prot.read_field_begin()?; + if field_ident.field_type == TType::Stop { + break; + } + i_prot.skip(field_ident.field_type)?; + i_prot.read_field_end()?; + } + i_prot.read_struct_end()?; + let ret = Float16Type {}; + Ok(ret) + } + fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> { + let struct_ident = TStructIdentifier::new("Float16Type"); + o_prot.write_struct_begin(&struct_ident)?; + o_prot.write_field_stop()?; + o_prot.write_struct_end() + } +} + // // NullType // @@ -1640,6 +1711,7 @@ pub enum LogicalType { JSON(JsonType), BSON(BsonType), UUID(UUIDType), + FLOAT16(Float16Type), } impl crate::thrift::TSerializable for LogicalType { @@ -1745,6 +1817,13 @@ impl crate::thrift::TSerializable for LogicalType { } received_field_count += 1; }, + 15 => { + let val = Float16Type::read_from_in_protocol(i_prot)?; + if ret.is_none() { + ret = Some(LogicalType::FLOAT16(val)); + } + received_field_count += 1; + }, _ => { i_prot.skip(field_ident.field_type)?; received_field_count += 1; }, @@ -1844,6 +1923,11 @@ impl crate::thrift::TSerializable for LogicalType { f.write_to_out_protocol(o_prot)?; o_prot.write_field_end()?; }, + LogicalType::FLOAT16(ref f) => { + o_prot.write_field_begin(&TFieldIdentifier::new("FLOAT16", TType::Struct, 15))?; + f.write_to_out_protocol(o_prot)?; + o_prot.write_field_end()?; + }, } o_prot.write_field_stop()?; o_prot.write_struct_end() diff --git a/parquet/src/schema/printer.rs b/parquet/src/schema/printer.rs index fe4757d41ae..e15ba311be2 100644 --- a/parquet/src/schema/printer.rs +++ b/parquet/src/schema/printer.rs @@ -270,6 +270,7 @@ fn print_logical_and_converted( LogicalType::Enum => "ENUM".to_string(), LogicalType::List => "LIST".to_string(), LogicalType::Map => "MAP".to_string(), + LogicalType::Float16 => "FLOAT16".to_string(), LogicalType::Unknown => "UNKNOWN".to_string(), }, None => { diff --git a/parquet/src/schema/types.rs b/parquet/src/schema/types.rs index 11c73542095..597ed971d47 100644 --- a/parquet/src/schema/types.rs +++ b/parquet/src/schema/types.rs @@ -356,6 +356,7 @@ impl<'a> PrimitiveTypeBuilder<'a> { (LogicalType::Json, PhysicalType::BYTE_ARRAY) => {} (LogicalType::Bson, PhysicalType::BYTE_ARRAY) => {} (LogicalType::Uuid, PhysicalType::FIXED_LEN_BYTE_ARRAY) => {} + (LogicalType::Float16, PhysicalType::FIXED_LEN_BYTE_ARRAY) => {} (a, b) => { return Err(general_err!( "Cannot annotate {:?} from {} for field '{}'", From ef6664225fad38aea5c7cc146f1e77d99b3d7d01 Mon Sep 17 00:00:00 2001 From: Jeffrey <22608443+Jefffrey@users.noreply.github.com> Date: Tue, 31 Oct 2023 07:44:06 +1100 Subject: [PATCH 02/12] Update
parquet/src/arrow/arrow_writer/mod.rs Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> --- parquet/src/arrow/arrow_writer/mod.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/parquet/src/arrow/arrow_writer/mod.rs b/parquet/src/arrow/arrow_writer/mod.rs index df218f354d9..f46d3b751f5 100644 --- a/parquet/src/arrow/arrow_writer/mod.rs +++ b/parquet/src/arrow/arrow_writer/mod.rs @@ -772,10 +772,7 @@ fn write_leaf(writer: &mut ColumnWriter<'_>, levels: &ArrayLevels) -> Result<usize> { - let array = column - .as_any() - .downcast_ref::<Float16Array>() - .unwrap(); + let array = column.as_primitive::<Float16Type>(); get_float_16_array_slice(array, indices) } _ => { From 517aebe68df1984cecdb6e259b8cdc7c03a905e9 Mon Sep 17 00:00:00 2001 From: Jeffrey <22608443+Jefffrey@users.noreply.github.com> Date: Tue, 31 Oct 2023 07:44:14 +1100 Subject: [PATCH 03/12] Update parquet/src/arrow/arrow_reader/mod.rs Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> --- parquet/src/arrow/arrow_reader/mod.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/parquet/src/arrow/arrow_reader/mod.rs b/parquet/src/arrow/arrow_reader/mod.rs index 9ed63086c65..6d057494ef5 100644 --- a/parquet/src/arrow/arrow_reader/mod.rs +++ b/parquet/src/arrow/arrow_reader/mod.rs @@ -964,10 +964,7 @@ mod tests { assert_eq!(ret, original); // Ensure can be downcast to the correct type - ret.column(0) - .as_any() - .downcast_ref::<Float16Array>() - .unwrap(); + ret.column(0).as_primitive::<Float16Type>(); Ok(()) } From 64b0e50601b494e77398dfb5c0a70c1c475ec2ca Mon Sep 17 00:00:00 2001 From: Jefffrey <22608443+Jefffrey@users.noreply.github.com> Date: Tue, 31 Oct 2023 07:45:26 +1100 Subject: [PATCH 04/12] Update test with null version --- parquet/src/arrow/arrow_reader/mod.rs | 60 +++++++++++++++++---------- 1 file changed, 38 insertions(+), 22 deletions(-) diff --git a/parquet/src/arrow/arrow_reader/mod.rs b/parquet/src/arrow/arrow_reader/mod.rs index 6d057494ef5..947dfc6c9f5 100644 --- a/parquet/src/arrow/arrow_reader/mod.rs +++ b/parquet/src/arrow/arrow_reader/mod.rs @@ -719,7 +719,7 @@ mod tests { use arrow_array::builder::*; use arrow_array::cast::AsArray; - use arrow_array::types::{Decimal128Type, Decimal256Type, DecimalType}; + use arrow_array::types::{Decimal128Type, Decimal256Type, DecimalType, Float16Type}; use arrow_array::*; use arrow_array::{RecordBatch, RecordBatchReader}; use arrow_buffer::{i256, ArrowNativeType, Buffer}; @@ -927,33 +927,48 @@ mod tests { #[test] fn test_float16_roundtrip() -> Result<()> { - let schema = Arc::new(Schema::new(vec![Field::new( - "float16", - ArrowDataType::Float16, - true, - )])); + let schema = Arc::new(Schema::new(vec![ + Field::new("float16", ArrowDataType::Float16, false), + Field::new("float16-nullable", ArrowDataType::Float16, true), + ])); let mut buf = Vec::with_capacity(1024); let mut writer = ArrowWriter::try_new(&mut buf, schema.clone(), None)?; let original = RecordBatch::try_new( schema, - vec![Arc::new(Float16Array::from_iter_values([ - f16::EPSILON, - f16::INFINITY, - f16::MIN, - f16::MAX, - f16::NAN, - f16::INFINITY, - f16::NEG_INFINITY, - f16::ONE, - f16::NEG_ONE, - f16::ZERO, - f16::NEG_ZERO, - f16::E, - f16::PI, - f16::FRAC_1_PI, - ]))], + vec![ + Arc::new(Float16Array::from_iter_values([ + f16::EPSILON, + f16::MIN, + f16::MAX, + f16::NAN, + f16::INFINITY, + f16::NEG_INFINITY, + f16::ONE, + f16::NEG_ONE, + f16::ZERO, + f16::NEG_ZERO, + f16::E, + f16::PI, + f16::FRAC_1_PI, + ])), + Arc::new(Float16Array::from(vec![ + None, + None, +
None, + Some(f16::NAN), + Some(f16::INFINITY), + Some(f16::NEG_INFINITY), + None, + None, + None, + None, + None, + None, + Some(f16::FRAC_1_PI), + ])), + ], )?; writer.write(&original)?; writer.close()?; @@ -965,6 +980,7 @@ mod tests { // Ensure can be downcast to the correct type ret.column(0).as_primitive::<Float16Type>(); + ret.column(1).as_primitive::<Float16Type>(); Ok(()) } From bf43eea1c93c1c1a41df6ad3fa641f8a6ef1dd9b Mon Sep 17 00:00:00 2001 From: Jefffrey <22608443+Jefffrey@users.noreply.github.com> Date: Tue, 7 Nov 2023 12:31:15 +1100 Subject: [PATCH 05/12] Fix schema tests and parsing for f16 --- parquet/src/schema/parser.rs | 8 +++++++ parquet/src/schema/printer.rs | 9 +++++++ parquet/src/schema/types.rs | 45 ++++++++++++++++++++++++++++++++++- 3 files changed, 61 insertions(+), 1 deletion(-) diff --git a/parquet/src/schema/parser.rs b/parquet/src/schema/parser.rs index 5e213e3bb9e..dcef11aa66d 100644 --- a/parquet/src/schema/parser.rs +++ b/parquet/src/schema/parser.rs @@ -823,6 +823,7 @@ mod tests { message root { optional fixed_len_byte_array(5) f1 (DECIMAL(9, 3)); optional fixed_len_byte_array (16) f2 (DECIMAL (38, 18)); + optional fixed_len_byte_array (2) f3 (FLOAT16); } "; let message = parse(schema).unwrap(); @@ -855,6 +856,13 @@ mod tests { .build() .unwrap(), ), + Arc::new( + Type::primitive_type_builder("f3", PhysicalType::FIXED_LEN_BYTE_ARRAY) + .with_logical_type(Some(LogicalType::Float16)) + .with_length(2) + .build() + .unwrap(), + ), ]) .build() .unwrap(); diff --git a/parquet/src/schema/printer.rs b/parquet/src/schema/printer.rs index e15ba311be2..2dec8a5be9f 100644 --- a/parquet/src/schema/printer.rs +++ b/parquet/src/schema/printer.rs @@ -668,6 +668,15 @@ mod tests { .unwrap(), "OPTIONAL FIXED_LEN_BYTE_ARRAY (9) decimal (DECIMAL(19,4));", ), + ( + Type::primitive_type_builder("float16", PhysicalType::FIXED_LEN_BYTE_ARRAY) + .with_logical_type(Some(LogicalType::Float16)) + .with_length(2) + .with_repetition(Repetition::REQUIRED) + .build() + .unwrap(), + "REQUIRED FIXED_LEN_BYTE_ARRAY (2) float16 (FLOAT16);", + ), ]; types_and_strings.into_iter().for_each(|(field, expected)| { diff --git a/parquet/src/schema/types.rs b/parquet/src/schema/types.rs index 597ed971d47..2f36deffbab 100644 --- a/parquet/src/schema/types.rs +++ b/parquet/src/schema/types.rs @@ -356,7 +356,14 @@ impl<'a> PrimitiveTypeBuilder<'a> { (LogicalType::Json, PhysicalType::BYTE_ARRAY) => {} (LogicalType::Bson, PhysicalType::BYTE_ARRAY) => {} (LogicalType::Uuid, PhysicalType::FIXED_LEN_BYTE_ARRAY) => {} - (LogicalType::Float16, PhysicalType::FIXED_LEN_BYTE_ARRAY) => {} + (LogicalType::Float16, PhysicalType::FIXED_LEN_BYTE_ARRAY) + if self.length == 2 => {} + (LogicalType::Float16, PhysicalType::FIXED_LEN_BYTE_ARRAY) => { + return Err(general_err!( + "FLOAT16 cannot annotate field '{}' because it is not a FIXED_LEN_BYTE_ARRAY(2) field", + self.name + )) + } (a, b) => { return Err(general_err!( "Cannot annotate {:?} from {} for field '{}'", @@ -1505,6 +1512,41 @@ mod tests { "Parquet error: Invalid FIXED_LEN_BYTE_ARRAY length: -1 for field 'foo'" ); } + + result = Type::primitive_type_builder("foo", PhysicalType::FIXED_LEN_BYTE_ARRAY) + .with_repetition(Repetition::REQUIRED) + .with_logical_type(Some(LogicalType::Float16)) + .with_length(2) + .build(); + assert!(result.is_ok()); + + // Can't be other than FIXED_LEN_BYTE_ARRAY for physical type + result = Type::primitive_type_builder("foo", PhysicalType::FLOAT) + .with_repetition(Repetition::REQUIRED) + .with_logical_type(Some(LogicalType::Float16)) + .with_length(2) + .build(); +
assert!(result.is_err()); + if let Err(e) = result { + assert_eq!( + format!("{e}"), + "Parquet error: Cannot annotate Float16 from FLOAT for field 'foo'" + ); + } + + // Must have length 2 + result = Type::primitive_type_builder("foo", PhysicalType::FIXED_LEN_BYTE_ARRAY) + .with_repetition(Repetition::REQUIRED) + .with_logical_type(Some(LogicalType::Float16)) + .with_length(4) + .build(); + assert!(result.is_err()); + if let Err(e) = result { + assert_eq!( + format!("{e}"), + "Parquet error: FLOAT16 cannot annotate field 'foo' because it is not a FIXED_LEN_BYTE_ARRAY(2) field" + ); + } } #[test] @@ -1982,6 +2024,7 @@ mod tests { let message_type = " message conversions { REQUIRED INT64 id; + OPTIONAL FIXED_LEN_BYTE_ARRAY (2) f16 (FLOAT16); OPTIONAL group int_array_Array (LIST) { REPEATED group list { OPTIONAL group element (LIST) { From af39f80ab4ec7a0ea846e33a46c7653214ec0946 Mon Sep 17 00:00:00 2001 From: Jefffrey <22608443+Jefffrey@users.noreply.github.com> Date: Tue, 7 Nov 2023 13:59:52 +1100 Subject: [PATCH 06/12] f16 for record api --- parquet/src/record/api.rs | 88 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 85 insertions(+), 3 deletions(-) diff --git a/parquet/src/record/api.rs b/parquet/src/record/api.rs index c7a0b09c37e..44bf5494f33 100644 --- a/parquet/src/record/api.rs +++ b/parquet/src/record/api.rs @@ -20,9 +20,11 @@ use std::fmt; use chrono::{TimeZone, Utc}; +use half::f16; +use num::Float; use num_bigint::{BigInt, Sign}; -use crate::basic::{ConvertedType, Type as PhysicalType}; +use crate::basic::{ConvertedType, LogicalType, Type as PhysicalType}; use crate::data_type::{ByteArray, Decimal, Int96}; use crate::errors::{ParquetError, Result}; use crate::schema::types::ColumnDescPtr; @@ -121,6 +123,7 @@ pub trait RowAccessor { fn get_ushort(&self, i: usize) -> Result<u16>; fn get_uint(&self, i: usize) -> Result<u32>; fn get_ulong(&self, i: usize) -> Result<u64>; + fn get_float16(&self, i: usize) -> Result<f16>; fn get_float(&self, i: usize) -> Result<f32>; fn get_double(&self, i: usize) -> Result<f64>; fn get_timestamp_millis(&self, i: usize) -> Result<i64>; @@ -215,6 +218,8 @@ impl RowAccessor for Row { row_primitive_accessor!(get_ulong, ULong, u64); + row_primitive_accessor!(get_float16, Float16, f16); + row_primitive_accessor!(get_float, Float, f32); row_primitive_accessor!(get_double, Double, f64); @@ -293,6 +298,7 @@ pub trait ListAccessor { fn get_ushort(&self, i: usize) -> Result<u16>; fn get_uint(&self, i: usize) -> Result<u32>; fn get_ulong(&self, i: usize) -> Result<u64>; + fn get_float16(&self, i: usize) -> Result<f16>; fn get_float(&self, i: usize) -> Result<f32>; fn get_double(&self, i: usize) -> Result<f64>; fn get_timestamp_millis(&self, i: usize) -> Result<i64>; @@ -358,6 +364,8 @@ impl ListAccessor for List { list_primitive_accessor!(get_ulong, ULong, u64); + list_primitive_accessor!(get_float16, Float16, f16); + list_primitive_accessor!(get_float, Float, f32); list_primitive_accessor!(get_double, Double, f64); @@ -449,6 +457,8 @@ impl<'a> ListAccessor for MapList<'a> { map_list_primitive_accessor!(get_ulong, ULong, u64); + map_list_primitive_accessor!(get_float16, Float16, f16); + map_list_primitive_accessor!(get_float, Float, f32); map_list_primitive_accessor!(get_double, Double, f64); @@ -510,6 +520,8 @@ pub enum Field { UInt(u32), // Unsigned integer UINT_64. ULong(u64), + /// IEEE 16-bit floating point value. + Float16(f16), /// IEEE 32-bit floating point value. Float(f32), /// IEEE 64-bit floating point value.
@@ -552,6 +564,7 @@ impl Field { Field::UShort(_) => "UShort", Field::UInt(_) => "UInt", Field::ULong(_) => "ULong", + Field::Float16(_) => "Float16", Field::Float(_) => "Float", Field::Double(_) => "Double", Field::Decimal(_) => "Decimal", @@ -636,8 +649,8 @@ impl Field { Field::Double(value) } - /// Converts Parquet BYTE_ARRAY type with converted type into either UTF8 string or - /// array of bytes. + /// Converts Parquet BYTE_ARRAY type with converted type into a UTF8 + /// string, decimal, float16, or an array of bytes. #[inline] pub fn convert_byte_array(descr: &ColumnDescPtr, value: ByteArray) -> Result<Field> { let field = match descr.physical_type() { @@ -666,6 +679,16 @@ impl Field { descr.type_precision(), descr.type_scale(), )), + ConvertedType::NONE if descr.logical_type() == Some(LogicalType::Float16) => { + if value.len() != 2 { + return Err(general_err!( + "Error reading FIXED_LEN_BYTE_ARRAY as FLOAT16. Length must be 2, got {}", + value.len() + )); + } + let bytes = [value.data()[0], value.data()[1]]; + Field::Float16(f16::from_le_bytes(bytes)) + } ConvertedType::NONE => Field::Bytes(value), _ => nyi!(descr, value), }, @@ -690,6 +713,9 @@ impl Field { Field::UShort(n) => Value::Number(serde_json::Number::from(*n)), Field::UInt(n) => Value::Number(serde_json::Number::from(*n)), Field::ULong(n) => Value::Number(serde_json::Number::from(*n)), + Field::Float16(n) => serde_json::Number::from_f64(f64::from(*n)) + .map(Value::Number) + .unwrap_or(Value::Null), Field::Float(n) => serde_json::Number::from_f64(f64::from(*n)) .map(Value::Number) .unwrap_or(Value::Null), @@ -736,6 +762,15 @@ impl fmt::Display for Field { Field::UShort(value) => write!(f, "{value}"), Field::UInt(value) => write!(f, "{value}"), Field::ULong(value) => write!(f, "{value}"), + Field::Float16(value) => { + if !value.is_finite() { + write!(f, "{value}") + } else if value.trunc() == value { + write!(f, "{value}.0") + } else { + write!(f, "{value}") + } + } Field::Float(value) => { if !(1e-15..=1e19).contains(&value) { write!(f, "{value:E}") } else { @@ -1069,6 +1104,24 @@ mod tests { Field::Decimal(Decimal::from_bytes(value, 17, 5)) ); + // FLOAT16 + let descr = { + let tpe = PrimitiveTypeBuilder::new("col", PhysicalType::FIXED_LEN_BYTE_ARRAY) + .with_logical_type(Some(LogicalType::Float16)) + .with_length(2) + .build() + .unwrap(); + Arc::new(ColumnDescriptor::new( + Arc::new(tpe), + 0, + 0, + ColumnPath::from("col"), + )) + }; + let value = ByteArray::from(f16::PI); + let row = Field::convert_byte_array(&descr, value.clone()); + assert_eq!(row.unwrap(), Field::Float16(f16::PI)); + // NONE (FIXED_LEN_BYTE_ARRAY) let descr = make_column_descr![ PhysicalType::FIXED_LEN_BYTE_ARRAY, @@ -1145,6 +1198,18 @@ mod tests { check_datetime_conversion(2014, 11, 28, 21, 15, 12); } + #[test] + fn test_convert_float16_to_string() { + assert_eq!(format!("{}", Field::Float16(f16::ONE)), "1.0"); + assert_eq!(format!("{}", Field::Float16(f16::PI)), "3.140625"); + assert_eq!(format!("{}", Field::Float16(f16::MAX)), "65504.0"); + assert_eq!(format!("{}", Field::Float16(f16::NAN)), "NaN"); + assert_eq!(format!("{}", Field::Float16(f16::INFINITY)), "inf"); + assert_eq!(format!("{}", Field::Float16(f16::NEG_INFINITY)), "-inf"); + assert_eq!(format!("{}", Field::Float16(f16::ZERO)), "0.0"); + assert_eq!(format!("{}", Field::Float16(f16::NEG_ZERO)), "-0.0"); + } + #[test] fn test_convert_float_to_string() { assert_eq!(format!("{}", Field::Float(1.0)), "1.0"); @@ -1218,6 +1283,7 @@ mod tests { assert_eq!(format!("{}", Field::UShort(2)), "2");
assert_eq!(format!("{}", Field::UInt(3)), "3"); assert_eq!(format!("{}", Field::ULong(4)), "4"); + assert_eq!(format!("{}", Field::Float16(f16::E)), "2.71875"); assert_eq!(format!("{}", Field::Float(5.0)), "5.0"); assert_eq!(format!("{}", Field::Float(5.1234)), "5.1234"); assert_eq!(format!("{}", Field::Double(6.0)), "6.0"); @@ -1284,6 +1350,7 @@ mod tests { assert!(Field::UShort(2).is_primitive()); assert!(Field::UInt(3).is_primitive()); assert!(Field::ULong(4).is_primitive()); + assert!(Field::Float16(f16::E).is_primitive()); assert!(Field::Float(5.0).is_primitive()); assert!(Field::Float(5.1234).is_primitive()); assert!(Field::Double(6.0).is_primitive()); @@ -1344,6 +1411,7 @@ mod tests { ("15".to_string(), Field::TimestampMillis(1262391174000)), ("16".to_string(), Field::TimestampMicros(1262391174000000)), ("17".to_string(), Field::Decimal(Decimal::from_i32(4, 7, 2))), + ("18".to_string(), Field::Float16(f16::PI)), ]); assert_eq!("null", format!("{}", row.fmt(0))); @@ -1370,6 +1438,7 @@ mod tests { format!("{}", row.fmt(16)) ); assert_eq!("0.04", format!("{}", row.fmt(17))); + assert_eq!("3.140625", format!("{}", row.fmt(18))); } #[test] @@ -1429,6 +1498,7 @@ mod tests { Field::Bytes(ByteArray::from(vec![1, 2, 3, 4, 5])), ), ("o".to_string(), Field::Decimal(Decimal::from_i32(4, 7, 2))), + ("p".to_string(), Field::Float16(f16::from_f32(9.1))), ]); assert!(!row.get_bool(1).unwrap()); @@ -1445,6 +1515,7 @@ mod tests { assert_eq!("abc", row.get_string(12).unwrap()); assert_eq!(5, row.get_bytes(13).unwrap().len()); assert_eq!(7, row.get_decimal(14).unwrap().precision()); + assert!((f16::from_f32(9.1) - row.get_float16(15).unwrap()).abs() < f16::EPSILON); } #[test] @@ -1469,6 +1540,7 @@ mod tests { Field::Bytes(ByteArray::from(vec![1, 2, 3, 4, 5])), ), ("o".to_string(), Field::Decimal(Decimal::from_i32(4, 7, 2))), + ("p".to_string(), Field::Float16(f16::from_f32(9.1))), ]); for i in 0..row.len() { @@ -1583,6 +1655,9 @@ mod tests { let list = make_list(vec![Field::ULong(6), Field::ULong(7)]); assert_eq!(7, list.get_ulong(1).unwrap()); + let list = make_list(vec![Field::Float16(f16::PI)]); + assert!((f16::PI - list.get_float16(0).unwrap()).abs() < f16::EPSILON); + let list = make_list(vec![ Field::Float(8.1), Field::Float(9.2), @@ -1633,6 +1708,9 @@ mod tests { let list = make_list(vec![Field::ULong(6), Field::ULong(7)]); assert!(list.get_float(1).is_err()); + let list = make_list(vec![Field::Float16(f16::PI)]); + assert!(list.get_string(0).is_err()); + let list = make_list(vec![ Field::Float(8.1), Field::Float(9.2), @@ -1768,6 +1846,10 @@ mod tests { Field::ULong(4).to_json_value(), Value::Number(serde_json::Number::from(4)) ); + assert_eq!( + Field::Float16(f16::from_f32(5.0)).to_json_value(), + Value::Number(serde_json::Number::from_f64(5.0).unwrap()) + ); assert_eq!( Field::Float(5.0).to_json_value(), Value::Number(serde_json::Number::from_f64(5.0).unwrap()) From 40f3e5f4262c25bfe5e419908d06878bb0c085cd Mon Sep 17 00:00:00 2001 From: Jefffrey <22608443+Jefffrey@users.noreply.github.com> Date: Tue, 7 Nov 2023 15:19:23 +1100 Subject: [PATCH 07/12] Handle NaN for f16 statistics writing --- parquet/src/column/writer/encoder.rs | 4 +- parquet/src/column/writer/mod.rs | 149 +++++++++++++++++++++++++-- parquet/src/data_type.rs | 7 ++ 3 files changed, 152 insertions(+), 8 deletions(-) diff --git a/parquet/src/column/writer/encoder.rs b/parquet/src/column/writer/encoder.rs index 7bd4db30c3a..0cbcda5b485 100644 --- a/parquet/src/column/writer/encoder.rs +++ b/parquet/src/column/writer/encoder.rs @@ 
-290,7 +290,7 @@ where { let first = loop { let next = iter.next()?; - if !is_nan(next) { + if !is_nan(descr, next) { break next; } }; @@ -298,7 +298,7 @@ where let mut min = first; let mut max = first; for val in iter { - if is_nan(val) { + if is_nan(descr, val) { continue; } if compare_greater(descr, min, val) { diff --git a/parquet/src/column/writer/mod.rs b/parquet/src/column/writer/mod.rs index 84bf1911d89..e657992acc9 100644 --- a/parquet/src/column/writer/mod.rs +++ b/parquet/src/column/writer/mod.rs @@ -17,6 +17,8 @@ //! Contains column writer API. +use half::f16; + use crate::bloom_filter::Sbbf; use crate::format::{ColumnIndex, OffsetIndex}; use std::collections::{BTreeSet, VecDeque}; @@ -967,18 +969,23 @@ impl<'a, E: ColumnValueEncoder> GenericColumnWriter<'a, E> { } fn update_min<T: ParquetValueType>(descr: &ColumnDescriptor, val: &T, min: &mut Option<T>) { - update_stat::<T, _>(val, min, |cur| compare_greater(descr, cur, val)) + update_stat::<T, _>(descr, val, min, |cur| compare_greater(descr, cur, val)) } fn update_max<T: ParquetValueType>(descr: &ColumnDescriptor, val: &T, max: &mut Option<T>) { - update_stat::<T, _>(val, max, |cur| compare_greater(descr, val, cur)) + update_stat::<T, _>(descr, val, max, |cur| compare_greater(descr, val, cur)) } #[inline] #[allow(clippy::eq_op)] -fn is_nan<T: ParquetValueType>(val: &T) -> bool { +fn is_nan<T: ParquetValueType>(descr: &ColumnDescriptor, val: &T) -> bool { match T::PHYSICAL_TYPE { Type::FLOAT | Type::DOUBLE => val != val, + Type::FIXED_LEN_BYTE_ARRAY if descr.logical_type() == Some(LogicalType::Float16) => { + let val = val.as_bytes(); + let val = f16::from_le_bytes([val[0], val[1]]); + val.is_nan() + } _ => false, } } @@ -988,11 +995,15 @@ fn is_nan<T: ParquetValueType>(val: &T) -> bool { /// If `cur` is `None`, sets `cur` to `Some(val)`, otherwise calls `should_update` with /// the value of `cur`, and updates `cur` to `Some(val)` if it returns `true` -fn update_stat<T: ParquetValueType, F>(val: &T, cur: &mut Option<T>, should_update: F) -where +fn update_stat<T: ParquetValueType, F>( + descr: &ColumnDescriptor, + val: &T, + cur: &mut Option<T>, + should_update: F, +) where F: Fn(&T) -> bool, { - if is_nan(val) { + if is_nan(descr, val) { return; } @@ -1038,6 +1049,14 @@ fn compare_greater<T: ParquetValueType>(descr: &ColumnDescriptor, a: &T, b: &T) }; }; + if let Some(LogicalType::Float16) = descr.logical_type() { + let a = a.as_bytes(); + let a = f16::from_le_bytes([a[0], a[1]]); + let b = b.as_bytes(); + let b = f16::from_le_bytes([b[0], b[1]]); + return a > b; + } + a > b } @@ -1169,6 +1188,7 @@ fn increment_utf8(mut data: Vec<u8>) -> Option<Vec<u8>> { mod tests { use crate::{file::properties::DEFAULT_COLUMN_INDEX_TRUNCATE_LENGTH, format::BoundaryOrder}; use bytes::Bytes; + use half::f16; use rand::distributions::uniform::SampleUniform; use std::sync::Arc; @@ -2077,6 +2097,79 @@ mod tests { } } + #[test] + fn test_column_writer_check_float16_min_max() { + let input = [ + -f16::ONE, + f16::from_f32(3.0), + -f16::from_f32(2.0), + f16::from_f32(2.0), + ] + .into_iter() + .map(|s| ByteArray::from(s).into()) + .collect::<Vec<_>>(); + + let stats = float16_statistics_roundtrip(&input); + assert!(stats.has_min_max_set()); + assert!(stats.is_min_max_backwards_compatible()); + assert_eq!(stats.min(), &ByteArray::from(-f16::from_f32(2.0))); + assert_eq!(stats.max(), &ByteArray::from(f16::from_f32(3.0))); + } + + #[test] + fn test_column_writer_check_float16_nan_middle() { + let input = [f16::ONE, f16::NAN, f16::ONE + f16::ONE] + .into_iter() + .map(|s| ByteArray::from(s).into()) + .collect::<Vec<_>>(); + + let stats = float16_statistics_roundtrip(&input); + assert!(stats.has_min_max_set()); + assert!(stats.is_min_max_backwards_compatible());
+
+    #[test]
+    fn test_float16_statistics_nan_middle() {
+        let input = [f16::ONE, f16::NAN, f16::ONE + f16::ONE]
+            .into_iter()
+            .map(|s| ByteArray::from(s).into())
+            .collect::<Vec<_>>();
+
+        let stats = float16_statistics_roundtrip(&input);
+        assert!(stats.has_min_max_set());
+        assert!(stats.is_min_max_backwards_compatible());
+        assert_eq!(stats.min(), &ByteArray::from(f16::ONE));
+        assert_eq!(stats.max(), &ByteArray::from(f16::ONE + f16::ONE));
+    }
+
+    #[test]
+    fn test_float16_statistics_nan_start() {
+        let input = [f16::NAN, f16::ONE, f16::ONE + f16::ONE]
+            .into_iter()
+            .map(|s| ByteArray::from(s).into())
+            .collect::<Vec<_>>();
+
+        let stats = float16_statistics_roundtrip(&input);
+        assert!(stats.has_min_max_set());
+        assert!(stats.is_min_max_backwards_compatible());
+        assert_eq!(stats.min(), &ByteArray::from(f16::ONE));
+        assert_eq!(stats.max(), &ByteArray::from(f16::ONE + f16::ONE));
+    }
+
+    #[test]
+    fn test_float16_statistics_nan_only() {
+        let input = [f16::NAN, f16::NAN]
+            .into_iter()
+            .map(|s| ByteArray::from(s).into())
+            .collect::<Vec<_>>();
+
+        let stats = float16_statistics_roundtrip(&input);
+        assert!(!stats.has_min_max_set());
+        assert!(stats.is_min_max_backwards_compatible());
+    }
+
     #[test]
     fn test_float_statistics_nan_middle() {
         let stats = statistics_roundtrip::<FloatType>(&[1.0, f32::NAN, 2.0]);
@@ -2735,6 +2828,50 @@ mod tests {
         ColumnDescriptor::new(Arc::new(tpe), max_def_level, max_rep_level, path)
     }
 
+    fn float16_statistics_roundtrip(
+        values: &[FixedLenByteArray],
+    ) -> ValueStatistics<FixedLenByteArray> {
+        let page_writer = get_test_page_writer();
+        let props = Default::default();
+        let mut writer =
+            get_test_float16_column_writer::<FixedLenByteArrayType>(page_writer, 0, 0, props);
+        writer.write_batch(values, None, None).unwrap();
+
+        let metadata = writer.close().unwrap().metadata;
+        if let Some(Statistics::FixedLenByteArray(stats)) = metadata.statistics() {
+            stats.clone()
+        } else {
+            panic!("metadata missing statistics");
+        }
+    }
+
+    fn get_test_float16_column_writer<T: DataType>(
+        page_writer: Box<dyn PageWriter>,
+        max_def_level: i16,
+        max_rep_level: i16,
+        props: WriterPropertiesPtr,
+    ) -> ColumnWriterImpl<'static, T> {
+        let descr = Arc::new(get_test_float16_column_descr::<T>(
+            max_def_level,
+            max_rep_level,
+        ));
+        let column_writer = get_column_writer(descr, props, page_writer);
+        get_typed_column_writer::<T>(column_writer)
+    }
+
+    fn get_test_float16_column_descr<T: DataType>(
+        max_def_level: i16,
+        max_rep_level: i16,
+    ) -> ColumnDescriptor {
+        let path = ColumnPath::from("col");
+        let tpe = SchemaType::primitive_type_builder("col", T::get_physical_type())
+            .with_length(2)
+            .with_logical_type(Some(LogicalType::Float16))
+            .build()
+            .unwrap();
+        ColumnDescriptor::new(Arc::new(tpe), max_def_level, max_rep_level, path)
+    }
+
     /// Returns column writer for UINT32 Column provided as ConvertedType only
     fn get_test_unsigned_int_given_as_converted_column_writer<'a, T: DataType>(
         page_writer: Box<dyn PageWriter + 'a>,
diff --git a/parquet/src/data_type.rs b/parquet/src/data_type.rs
index 7e64478ed94..b1d52b75c72 100644
--- a/parquet/src/data_type.rs
+++ b/parquet/src/data_type.rs
@@ -18,6 +18,7 @@
 //! Data types that connect Parquet physical types with their Rust-specific
 //! representations.
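+// A quick sketch of the byte-level contract the new `From<f16> for ByteArray`
+// impl below relies on (illustration only, not part of this patch): a Float16
+// value occupies exactly two little-endian bytes, and a roundtrip through that
+// representation is lossless.
+//
+//     use half::f16;
+//
+//     let x = f16::from_f32(1.5);
+//     let bytes = x.to_le_bytes(); // [u8; 2], little-endian
+//     assert_eq!(f16::from_le_bytes(bytes), x);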
use bytes::Bytes; +use half::f16; use std::cmp::Ordering; use std::fmt; use std::mem; @@ -231,6 +232,12 @@ impl From for ByteArray { } } +impl From for ByteArray { + fn from(value: f16) -> Self { + Self::from(value.to_le_bytes().as_slice()) + } +} + impl PartialEq for ByteArray { fn eq(&self, other: &ByteArray) -> bool { match (&self.data, &other.data) { From 3a8bec0a11aa4c9d165ba627381efd7f1235f131 Mon Sep 17 00:00:00 2001 From: Jefffrey <22608443+Jefffrey@users.noreply.github.com> Date: Tue, 7 Nov 2023 15:30:06 +1100 Subject: [PATCH 08/12] Revert formatting changes --- .../array_reader/fixed_len_byte_array.rs | 37 ++-- parquet/src/arrow/schema/mod.rs | 183 +++++++++++++----- parquet/src/arrow/schema/primitive.rs | 12 +- 3 files changed, 164 insertions(+), 68 deletions(-) diff --git a/parquet/src/arrow/array_reader/fixed_len_byte_array.rs b/parquet/src/arrow/array_reader/fixed_len_byte_array.rs index 07213083390..324dbe21e1a 100644 --- a/parquet/src/arrow/array_reader/fixed_len_byte_array.rs +++ b/parquet/src/arrow/array_reader/fixed_len_byte_array.rs @@ -162,10 +162,11 @@ impl ArrayReader for FixedLenByteArrayReader { fn consume_batch(&mut self) -> Result { let record_data = self.record_reader.consume_record_data(); - let array_data = ArrayDataBuilder::new(ArrowType::FixedSizeBinary(self.byte_length as i32)) - .len(self.record_reader.num_values()) - .add_buffer(record_data) - .null_bit_buffer(self.record_reader.consume_bitmap_buffer()); + let array_data = + ArrayDataBuilder::new(ArrowType::FixedSizeBinary(self.byte_length as i32)) + .len(self.record_reader.num_values()) + .add_buffer(record_data) + .null_bit_buffer(self.record_reader.consume_bitmap_buffer()); let binary = FixedSizeBinaryArray::from(unsafe { array_data.build_unchecked() }); @@ -196,13 +197,19 @@ impl ArrayReader for FixedLenByteArrayReader { IntervalUnit::YearMonth => Arc::new( binary .iter() - .map(|o| o.map(|b| i32::from_le_bytes(b[0..4].try_into().unwrap()))) + .map(|o| { + o.map(|b| i32::from_le_bytes(b[0..4].try_into().unwrap())) + }) .collect::(), ) as ArrayRef, IntervalUnit::DayTime => Arc::new( binary .iter() - .map(|o| o.map(|b| i64::from_le_bytes(b[4..12].try_into().unwrap()))) + .map(|o| { + o.map(|b| { + i64::from_le_bytes(b[4..12].try_into().unwrap()) + }) + }) .collect::(), ) as ArrayRef, IntervalUnit::MonthDayNano => { @@ -286,7 +293,9 @@ impl ValuesBuffer for FixedLenByteArrayBuffer { let slice = self.buffer.as_slice_mut(); let values_range = read_offset..read_offset + values_read; - for (value_pos, level_pos) in values_range.rev().zip(iter_set_bits_rev(valid_mask)) { + for (value_pos, level_pos) in + values_range.rev().zip(iter_set_bits_rev(valid_mask)) + { debug_assert!(level_pos >= value_pos); if level_pos <= value_pos { break; @@ -382,7 +391,8 @@ impl ColumnValueDecoder for ValueDecoder { let len = range.end - range.start; match self.decoder.as_mut().unwrap() { Decoder::Plain { offset, buf } => { - let to_read = (len * self.byte_length).min(buf.len() - *offset) / self.byte_length; + let to_read = + (len * self.byte_length).min(buf.len() - *offset) / self.byte_length; let end_offset = *offset + to_read * self.byte_length; out.buffer .extend_from_slice(&buf.as_ref()[*offset..end_offset]); @@ -475,12 +485,15 @@ mod tests { .build() .unwrap(); - let written = - RecordBatch::try_from_iter([("list", Arc::new(ListArray::from(data)) as ArrayRef)]) - .unwrap(); + let written = RecordBatch::try_from_iter([( + "list", + Arc::new(ListArray::from(data)) as ArrayRef, + )]) + .unwrap(); let mut buffer = 
Vec::with_capacity(1024); - let mut writer = ArrowWriter::try_new(&mut buffer, written.schema(), None).unwrap(); + let mut writer = + ArrowWriter::try_new(&mut buffer, written.schema(), None).unwrap(); writer.write(&written).unwrap(); writer.close().unwrap(); diff --git a/parquet/src/arrow/schema/mod.rs b/parquet/src/arrow/schema/mod.rs index 1b44c012308..4c350c4b1d8 100644 --- a/parquet/src/arrow/schema/mod.rs +++ b/parquet/src/arrow/schema/mod.rs @@ -32,7 +32,8 @@ use arrow_ipc::writer; use arrow_schema::{DataType, Field, Fields, Schema, TimeUnit}; use crate::basic::{ - ConvertedType, LogicalType, Repetition, TimeUnit as ParquetTimeUnit, Type as PhysicalType, + ConvertedType, LogicalType, Repetition, TimeUnit as ParquetTimeUnit, + Type as PhysicalType, }; use crate::errors::{ParquetError, Result}; use crate::file::{metadata::KeyValue, properties::WriterProperties}; @@ -54,7 +55,11 @@ pub fn parquet_to_arrow_schema( parquet_schema: &SchemaDescriptor, key_value_metadata: Option<&Vec>, ) -> Result { - parquet_to_arrow_schema_by_columns(parquet_schema, ProjectionMask::all(), key_value_metadata) + parquet_to_arrow_schema_by_columns( + parquet_schema, + ProjectionMask::all(), + key_value_metadata, + ) } /// Convert parquet schema to arrow schema including optional metadata, @@ -194,7 +199,10 @@ fn encode_arrow_schema(schema: &Schema) -> String { /// Mutates writer metadata by storing the encoded Arrow schema. /// If there is an existing Arrow schema metadata, it is replaced. -pub(crate) fn add_encoded_arrow_schema_to_metadata(schema: &Schema, props: &mut WriterProperties) { +pub(crate) fn add_encoded_arrow_schema_to_metadata( + schema: &Schema, + props: &mut WriterProperties, +) { let encoded = encode_arrow_schema(schema); let schema_kv = KeyValue { @@ -262,15 +270,16 @@ fn parse_key_value_metadata( /// Convert parquet column schema to arrow field. 
pub fn parquet_to_arrow_field(parquet_column: &ColumnDescriptor) -> Result { let field = complex::convert_type(&parquet_column.self_type_ptr())?; - let mut ret = Field::new(parquet_column.name(), field.arrow_type, field.nullable); + let mut ret = Field::new( + parquet_column.name(), + field.arrow_type, + field.nullable, + ); let basic_info = parquet_column.self_type().get_basic_info(); if basic_info.has_id() { let mut meta = HashMap::with_capacity(1); - meta.insert( - PARQUET_FIELD_ID_META_KEY.to_string(), - basic_info.id().to_string(), - ); + meta.insert(PARQUET_FIELD_ID_META_KEY.to_string(), basic_info.id().to_string()); ret.set_metadata(meta); } @@ -392,9 +401,15 @@ fn arrow_to_parquet_type(field: &Field) -> Result { is_adjusted_to_u_t_c: matches!(tz, Some(z) if !z.as_ref().is_empty()), unit: match time_unit { TimeUnit::Second => unreachable!(), - TimeUnit::Millisecond => ParquetTimeUnit::MILLIS(Default::default()), - TimeUnit::Microsecond => ParquetTimeUnit::MICROS(Default::default()), - TimeUnit::Nanosecond => ParquetTimeUnit::NANOS(Default::default()), + TimeUnit::Millisecond => { + ParquetTimeUnit::MILLIS(Default::default()) + } + TimeUnit::Microsecond => { + ParquetTimeUnit::MICROS(Default::default()) + } + TimeUnit::Nanosecond => { + ParquetTimeUnit::NANOS(Default::default()) + } }, })) .with_repetition(repetition) @@ -442,7 +457,9 @@ fn arrow_to_parquet_type(field: &Field) -> Result { .with_repetition(repetition) .with_id(id) .build(), - DataType::Duration(_) => Err(arrow_err!("Converting Duration to parquet not supported",)), + DataType::Duration(_) => { + Err(arrow_err!("Converting Duration to parquet not supported",)) + } DataType::Interval(_) => { Type::primitive_type_builder(name, PhysicalType::FIXED_LEN_BYTE_ARRAY) .with_converted_type(ConvertedType::INTERVAL) @@ -464,7 +481,8 @@ fn arrow_to_parquet_type(field: &Field) -> Result { .with_length(*length) .build() } - DataType::Decimal128(precision, scale) | DataType::Decimal256(precision, scale) => { + DataType::Decimal128(precision, scale) + | DataType::Decimal256(precision, scale) => { // Decimal precision determines the Parquet physical type to use. 
// Following the: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#decimal let (physical_type, length) = if *precision > 1 && *precision <= 9 { @@ -511,7 +529,9 @@ fn arrow_to_parquet_type(field: &Field) -> Result { } DataType::Struct(fields) => { if fields.is_empty() { - return Err(arrow_err!("Parquet does not support writing empty structs",)); + return Err( + arrow_err!("Parquet does not support writing empty structs",), + ); } // recursively convert children to types/nodes let fields = fields @@ -601,7 +621,8 @@ mod tests { let parquet_group_type = parse_message_type(message_type).unwrap(); let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); - let converted_arrow_schema = parquet_to_arrow_schema(&parquet_schema, None).unwrap(); + let converted_arrow_schema = + parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let arrow_fields = Fields::from(vec![ Field::new("boolean", DataType::Boolean, false), @@ -639,7 +660,8 @@ mod tests { let parquet_group_type = parse_message_type(message_type).unwrap(); let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); - let converted_arrow_schema = parquet_to_arrow_schema(&parquet_schema, None).unwrap(); + let converted_arrow_schema = + parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let arrow_fields = Fields::from(vec![ Field::new("decimal1", DataType::Decimal128(4, 2), false), @@ -665,7 +687,8 @@ mod tests { let parquet_group_type = parse_message_type(message_type).unwrap(); let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); - let converted_arrow_schema = parquet_to_arrow_schema(&parquet_schema, None).unwrap(); + let converted_arrow_schema = + parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let arrow_fields = Fields::from(vec![ Field::new("binary", DataType::Binary, false), @@ -686,7 +709,8 @@ mod tests { let parquet_group_type = parse_message_type(message_type).unwrap(); let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); - let converted_arrow_schema = parquet_to_arrow_schema(&parquet_schema, None).unwrap(); + let converted_arrow_schema = + parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let arrow_fields = Fields::from(vec![ Field::new("boolean", DataType::Boolean, false), @@ -694,9 +718,12 @@ mod tests { ]); assert_eq!(&arrow_fields, converted_arrow_schema.fields()); - let converted_arrow_schema = - parquet_to_arrow_schema_by_columns(&parquet_schema, ProjectionMask::all(), None) - .unwrap(); + let converted_arrow_schema = parquet_to_arrow_schema_by_columns( + &parquet_schema, + ProjectionMask::all(), + None, + ) + .unwrap(); assert_eq!(&arrow_fields, converted_arrow_schema.fields()); } @@ -894,7 +921,8 @@ mod tests { let parquet_group_type = parse_message_type(message_type).unwrap(); let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); - let converted_arrow_schema = parquet_to_arrow_schema(&parquet_schema, None).unwrap(); + let converted_arrow_schema = + parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let converted_fields = converted_arrow_schema.fields(); assert_eq!(arrow_fields.len(), converted_fields.len()); @@ -972,7 +1000,8 @@ mod tests { let parquet_group_type = parse_message_type(message_type).unwrap(); let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); - let converted_arrow_schema = parquet_to_arrow_schema(&parquet_schema, None).unwrap(); + let converted_arrow_schema = + parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let converted_fields = 
converted_arrow_schema.fields(); assert_eq!(arrow_fields.len(), converted_fields.len()); @@ -1066,7 +1095,8 @@ mod tests { let parquet_group_type = parse_message_type(message_type).unwrap(); let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); - let converted_arrow_schema = parquet_to_arrow_schema(&parquet_schema, None).unwrap(); + let converted_arrow_schema = + parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let converted_fields = converted_arrow_schema.fields(); assert_eq!(arrow_fields.len(), converted_fields.len()); @@ -1083,7 +1113,8 @@ mod tests { Field::new("leaf1", DataType::Boolean, false), Field::new("leaf2", DataType::Int32, false), ]); - let group1_struct = Field::new("group1", DataType::Struct(group1_fields), false); + let group1_struct = + Field::new("group1", DataType::Struct(group1_fields), false); arrow_fields.push(group1_struct); let leaf3_field = Field::new("leaf3", DataType::Int64, false); @@ -1102,7 +1133,8 @@ mod tests { let parquet_group_type = parse_message_type(message_type).unwrap(); let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); - let converted_arrow_schema = parquet_to_arrow_schema(&parquet_schema, None).unwrap(); + let converted_arrow_schema = + parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let converted_fields = converted_arrow_schema.fields(); assert_eq!(arrow_fields.len(), converted_fields.len()); @@ -1255,7 +1287,8 @@ mod tests { let parquet_group_type = parse_message_type(message_type).unwrap(); let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); - let converted_arrow_schema = parquet_to_arrow_schema(&parquet_schema, None).unwrap(); + let converted_arrow_schema = + parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let converted_fields = converted_arrow_schema.fields(); assert_eq!(arrow_fields.len(), converted_fields.len()); @@ -1480,11 +1513,20 @@ mod tests { vec![ Field::new("bools", DataType::Boolean, false), Field::new("uint32", DataType::UInt32, false), - Field::new_list("int32", Field::new("element", DataType::Int32, true), false), + Field::new_list( + "int32", + Field::new("element", DataType::Int32, true), + false, + ), ], false, ), - Field::new_dictionary("dictionary_strings", DataType::Int32, DataType::Utf8, false), + Field::new_dictionary( + "dictionary_strings", + DataType::Int32, + DataType::Utf8, + false, + ), Field::new("decimal_int32", DataType::Decimal128(8, 2), false), Field::new("decimal_int64", DataType::Decimal128(16, 2), false), Field::new("decimal_fix_length", DataType::Decimal128(30, 2), false), @@ -1569,8 +1611,10 @@ mod tests { let schema = Schema::new_with_metadata( vec![ - Field::new("c1", DataType::Utf8, false) - .with_metadata(meta(&[("Key", "Foo"), (PARQUET_FIELD_ID_META_KEY, "2")])), + Field::new("c1", DataType::Utf8, false).with_metadata(meta(&[ + ("Key", "Foo"), + (PARQUET_FIELD_ID_META_KEY, "2"), + ])), Field::new("c2", DataType::Binary, false), Field::new("c3", DataType::FixedSizeBinary(3), false), Field::new("c4", DataType::Boolean, false), @@ -1588,7 +1632,10 @@ mod tests { ), Field::new( "c17", - DataType::Timestamp(TimeUnit::Microsecond, Some("Africa/Johannesburg".into())), + DataType::Timestamp( + TimeUnit::Microsecond, + Some("Africa/Johannesburg".into()), + ), false, ), Field::new( @@ -1600,8 +1647,10 @@ mod tests { Field::new("c20", DataType::Interval(IntervalUnit::YearMonth), false), Field::new_list( "c21", - Field::new("item", DataType::Boolean, true) - .with_metadata(meta(&[("Key", "Bar"), (PARQUET_FIELD_ID_META_KEY, 
"5")])), + Field::new("item", DataType::Boolean, true).with_metadata(meta(&[ + ("Key", "Bar"), + (PARQUET_FIELD_ID_META_KEY, "5"), + ])), false, ) .with_metadata(meta(&[(PARQUET_FIELD_ID_META_KEY, "4")])), @@ -1651,7 +1700,10 @@ mod tests { // Field::new("c30", DataType::Duration(TimeUnit::Nanosecond), false), Field::new_dict( "c31", - DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Utf8), + ), true, 123, true, @@ -1686,7 +1738,11 @@ mod tests { "c39", "key_value", Field::new("key", DataType::Utf8, false), - Field::new_list("value", Field::new("element", DataType::Utf8, true), true), + Field::new_list( + "value", + Field::new("element", DataType::Utf8, true), + true, + ), false, // fails to roundtrip keys_sorted true, ), @@ -1725,8 +1781,11 @@ mod tests { // write to an empty parquet file so that schema is serialized let file = tempfile::tempfile().unwrap(); - let writer = - ArrowWriter::try_new(file.try_clone().unwrap(), Arc::new(schema.clone()), None)?; + let writer = ArrowWriter::try_new( + file.try_clone().unwrap(), + Arc::new(schema.clone()), + None, + )?; writer.close()?; // read file back @@ -1785,23 +1844,33 @@ mod tests { }; let schema = Schema::new_with_metadata( vec![ - Field::new("c1", DataType::Utf8, true) - .with_metadata(meta(&[(PARQUET_FIELD_ID_META_KEY, "1")])), - Field::new("c2", DataType::Utf8, true) - .with_metadata(meta(&[(PARQUET_FIELD_ID_META_KEY, "2")])), + Field::new("c1", DataType::Utf8, true).with_metadata(meta(&[ + (PARQUET_FIELD_ID_META_KEY, "1"), + ])), + Field::new("c2", DataType::Utf8, true).with_metadata(meta(&[ + (PARQUET_FIELD_ID_META_KEY, "2"), + ])), ], HashMap::new(), ); - let writer = ArrowWriter::try_new(vec![], Arc::new(schema.clone()), None)?; + let writer = ArrowWriter::try_new( + vec![], + Arc::new(schema.clone()), + None, + )?; let parquet_bytes = writer.into_inner()?; - let reader = - crate::file::reader::SerializedFileReader::new(bytes::Bytes::from(parquet_bytes))?; + let reader = crate::file::reader::SerializedFileReader::new( + bytes::Bytes::from(parquet_bytes), + )?; let schema_descriptor = reader.metadata().file_metadata().schema_descr_ptr(); // don't pass metadata so field ids are read from Parquet and not from serialized Arrow schema - let arrow_schema = crate::arrow::parquet_to_arrow_schema(&schema_descriptor, None)?; + let arrow_schema = crate::arrow::parquet_to_arrow_schema( + &schema_descriptor, + None, + )?; let parq_schema_descr = crate::arrow::arrow_to_parquet_schema(&arrow_schema)?; let parq_fields = parq_schema_descr.root_schema().get_fields(); @@ -1814,14 +1883,19 @@ mod tests { #[test] fn test_arrow_schema_roundtrip_lists() -> Result<()> { - let metadata: HashMap = [("Key".to_string(), "Value".to_string())] - .iter() - .cloned() - .collect(); + let metadata: HashMap = + [("Key".to_string(), "Value".to_string())] + .iter() + .cloned() + .collect(); let schema = Schema::new_with_metadata( vec![ - Field::new_list("c21", Field::new("array", DataType::Boolean, true), false), + Field::new_list( + "c21", + Field::new("array", DataType::Boolean, true), + false, + ), Field::new( "c22", DataType::FixedSizeList( @@ -1852,8 +1926,11 @@ mod tests { // write to an empty parquet file so that schema is serialized let file = tempfile::tempfile().unwrap(); - let writer = - ArrowWriter::try_new(file.try_clone().unwrap(), Arc::new(schema.clone()), None)?; + let writer = ArrowWriter::try_new( + file.try_clone().unwrap(), + Arc::new(schema.clone()), 
+ None, + )?; writer.close()?; // read file back diff --git a/parquet/src/arrow/schema/primitive.rs b/parquet/src/arrow/schema/primitive.rs index 447fe5fc3ab..fdc744831a2 100644 --- a/parquet/src/arrow/schema/primitive.rs +++ b/parquet/src/arrow/schema/primitive.rs @@ -15,7 +15,9 @@ // specific language governing permissions and limitations // under the License. -use crate::basic::{ConvertedType, LogicalType, TimeUnit as ParquetTimeUnit, Type as PhysicalType}; +use crate::basic::{ + ConvertedType, LogicalType, TimeUnit as ParquetTimeUnit, Type as PhysicalType, +}; use crate::errors::{ParquetError, Result}; use crate::schema::types::{BasicTypeInfo, Type}; use arrow_schema::{DataType, IntervalUnit, TimeUnit, DECIMAL128_MAX_PRECISION}; @@ -156,7 +158,9 @@ fn from_int32(info: &BasicTypeInfo, scale: i32, precision: i32) -> Result Ok(DataType::UInt32), _ => Err(arrow_err!("Cannot create INT32 physical type from {:?}", t)), }, - (Some(LogicalType::Decimal { scale, precision }), _) => decimal_128_type(scale, precision), + (Some(LogicalType::Decimal { scale, precision }), _) => { + decimal_128_type(scale, precision) + } (Some(LogicalType::Date), _) => Ok(DataType::Date32), (Some(LogicalType::Time { unit, .. }), _) => match unit { ParquetTimeUnit::MILLIS(_) => Ok(DataType::Time32(TimeUnit::Millisecond)), @@ -233,7 +237,9 @@ fn from_int64(info: &BasicTypeInfo, scale: i32, precision: i32) -> Result decimal_128_type(scale, precision), + (Some(LogicalType::Decimal { scale, precision }), _) => { + decimal_128_type(scale, precision) + } (None, ConvertedType::DECIMAL) => decimal_128_type(scale, precision), (logical, converted) => Err(arrow_err!( "Unable to convert parquet INT64 logical type {:?} or converted type {}", From fa70501db26fd0fa6ad40f22efa0b643971b6ed5 Mon Sep 17 00:00:00 2001 From: Jefffrey <22608443+Jefffrey@users.noreply.github.com> Date: Tue, 7 Nov 2023 15:38:18 +1100 Subject: [PATCH 09/12] Fix num trait --- parquet/src/record/api.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parquet/src/record/api.rs b/parquet/src/record/api.rs index 44bf5494f33..e4f473562e0 100644 --- a/parquet/src/record/api.rs +++ b/parquet/src/record/api.rs @@ -21,7 +21,7 @@ use std::fmt; use chrono::{TimeZone, Utc}; use half::f16; -use num::Float; +use num::traits::Float; use num_bigint::{BigInt, Sign}; use crate::basic::{ConvertedType, LogicalType, Type as PhysicalType}; From 21cefd9ef25ff0bcc5e87763d7c895454f1a2304 Mon Sep 17 00:00:00 2001 From: Jefffrey <22608443+Jefffrey@users.noreply.github.com> Date: Tue, 7 Nov 2023 15:41:53 +1100 Subject: [PATCH 10/12] Fix half feature --- parquet/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parquet/Cargo.toml b/parquet/Cargo.toml index ab4292e0e0e..4c7ac4b79c6 100644 --- a/parquet/Cargo.toml +++ b/parquet/Cargo.toml @@ -66,7 +66,7 @@ tokio = { version = "1.0", optional = true, default-features = false, features = hashbrown = { version = "0.14", default-features = false } twox-hash = { version = "1.6", default-features = false } paste = { version = "1.0" } -half = { version = "2.1", default-features = false } +half = { version = "2.1", default-features = false, features = ["num-traits"] } [dev-dependencies] base64 = { version = "0.21", default-features = false, features = ["std"] } From b122f47f61c06e4af52ef8649e76720887ddc977 Mon Sep 17 00:00:00 2001 From: Jefffrey <22608443+Jefffrey@users.noreply.github.com> Date: Wed, 8 Nov 2023 07:20:12 +1100 Subject: [PATCH 11/12] Handle writing signed zero statistics --- 
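A note on the behavior this patch implements: the parquet-format statistics
specification recommends that a zero-valued min be written as -0.0 and a
zero-valued max as +0.0, since IEEE ordering treats the two zeros as equal and
an ordering-based min/max can end up holding either sign. Below is a minimal
sketch of that normalization for f16; `normalize_zero_bounds` is a hypothetical
illustration, not an API introduced by this patch:

    use half::f16;

    fn normalize_zero_bounds(min: f16, max: f16) -> (f16, f16) {
        // `== f16::ZERO` matches both +0.0 and -0.0, since the two zeros compare equal
        let min = if min == f16::ZERO { f16::NEG_ZERO } else { min };
        let max = if max == f16::ZERO { f16::ZERO } else { max };
        (min, max)
    }

    fn main() {
        let (min, max) = normalize_zero_bounds(f16::ZERO, f16::NEG_ZERO);
        assert_eq!(min.to_bits(), f16::NEG_ZERO.to_bits()); // sign bit set on the min
        assert_eq!(max.to_bits(), f16::ZERO.to_bits()); // sign bit cleared on the max
    }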
 parquet/src/column/writer/encoder.rs | 16 ++++--
 parquet/src/column/writer/mod.rs     | 56 ++++++++++++++++++++++++++++
 2 files changed, 68 insertions(+), 4 deletions(-)

diff --git a/parquet/src/column/writer/encoder.rs b/parquet/src/column/writer/encoder.rs
index 92d8f20d82f..dc26b204908 100644
--- a/parquet/src/column/writer/encoder.rs
+++ b/parquet/src/column/writer/encoder.rs
@@ -15,7 +15,9 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::basic::{Encoding, Type};
+use half::f16;
+
+use crate::basic::{Encoding, LogicalType, Type};
 use crate::bloom_filter::Sbbf;
 use crate::column::writer::{
     compare_greater, fallback_encoding, has_dictionary_support, is_nan, update_max, update_min,
@@ -317,14 +319,14 @@ where
     //
     // For max, it has similar logic but will be written as 0.0
     // (positive zero)
-    let min = replace_zero(min, -0.0);
-    let max = replace_zero(max, 0.0);
+    let min = replace_zero(min, descr, -0.0);
+    let max = replace_zero(max, descr, 0.0);
 
     Some((min, max))
 }
 
 #[inline]
-fn replace_zero<T: ParquetValueType>(val: &T, replace: f32) -> T {
+fn replace_zero<T: ParquetValueType>(val: &T, descr: &ColumnDescriptor, replace: f32) -> T {
     match T::PHYSICAL_TYPE {
         Type::FLOAT if f32::from_le_bytes(val.as_bytes().try_into().unwrap()) == 0.0 => {
             T::try_from_le_slice(&f32::to_le_bytes(replace)).unwrap()
         }
         Type::DOUBLE if f64::from_le_bytes(val.as_bytes().try_into().unwrap()) == 0.0 => {
             T::try_from_le_slice(&f64::to_le_bytes(replace as f64)).unwrap()
         }
+        Type::FIXED_LEN_BYTE_ARRAY
+            if descr.logical_type() == Some(LogicalType::Float16)
+                && f16::from_le_bytes(val.as_bytes().try_into().unwrap()) == f16::NEG_ZERO =>
+        {
+            T::try_from_le_slice(&f16::to_le_bytes(f16::from_f32(replace))).unwrap()
+        }
         _ => val.clone(),
     }
 }
diff --git a/parquet/src/column/writer/mod.rs b/parquet/src/column/writer/mod.rs
index 60427c9d332..ceaa6996794 100644
--- a/parquet/src/column/writer/mod.rs
+++ b/parquet/src/column/writer/mod.rs
@@ -2170,6 +2170,62 @@ mod tests {
         assert!(stats.is_min_max_backwards_compatible());
     }
 
+    #[test]
+    fn test_float16_statistics_zero_only() {
+        let input = [f16::ZERO]
+            .into_iter()
+            .map(|s| ByteArray::from(s).into())
+            .collect::<Vec<_>>();
+
+        let stats = float16_statistics_roundtrip(&input);
+        assert!(stats.has_min_max_set());
+        assert!(stats.is_min_max_backwards_compatible());
+        assert_eq!(stats.min(), &ByteArray::from(f16::NEG_ZERO));
+        assert_eq!(stats.max(), &ByteArray::from(f16::ZERO));
+    }
+
+    #[test]
+    fn test_float16_statistics_neg_zero_only() {
+        let input = [f16::NEG_ZERO]
+            .into_iter()
+            .map(|s| ByteArray::from(s).into())
+            .collect::<Vec<_>>();
+
+        let stats = float16_statistics_roundtrip(&input);
+        assert!(stats.has_min_max_set());
+        assert!(stats.is_min_max_backwards_compatible());
+        assert_eq!(stats.min(), &ByteArray::from(f16::NEG_ZERO));
+        assert_eq!(stats.max(), &ByteArray::from(f16::ZERO));
+    }
+
+    #[test]
+    fn test_float16_statistics_zero_min() {
+        let input = [f16::ZERO, f16::ONE, f16::NAN, f16::PI]
+            .into_iter()
+            .map(|s| ByteArray::from(s).into())
+            .collect::<Vec<_>>();
+
+        let stats = float16_statistics_roundtrip(&input);
+        assert!(stats.has_min_max_set());
+        assert!(stats.is_min_max_backwards_compatible());
+        assert_eq!(stats.min(), &ByteArray::from(f16::NEG_ZERO));
+        assert_eq!(stats.max(), &ByteArray::from(f16::PI));
+    }
+
+    #[test]
+    fn test_float16_statistics_neg_zero_max() {
+        let input = [f16::NEG_ZERO, f16::NEG_ONE, f16::NAN, -f16::PI]
+            .into_iter()
+            .map(|s| ByteArray::from(s).into())
+            .collect::<Vec<_>>();
+
+        let stats = float16_statistics_roundtrip(&input);
+        assert!(stats.has_min_max_set());
+        assert!(stats.is_min_max_backwards_compatible());
+        assert_eq!(stats.min(), &ByteArray::from(-f16::PI));
+        assert_eq!(stats.max(), &ByteArray::from(f16::ZERO));
+    }
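+
+    // Decoding the bounds again (sketch only, not part of this patch): the
+    // Float16 min/max travel as 2-byte FIXED_LEN_BYTE_ARRAY values, so a
+    // reader recovers them with the same little-endian convention the writer
+    // used, e.g. for the test directly above:
+    //
+    //     let min = f16::from_le_bytes(stats.min_bytes().try_into().unwrap());
+    //     assert_eq!(min, -f16::PI);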
+
     #[test]
     fn test_float_statistics_nan_middle() {
         let stats = statistics_roundtrip::<FloatType>(&[1.0, f32::NAN, 2.0]);

From 9d424ccdfeac6bf55c0eefd567acce816a7b7b6b Mon Sep 17 00:00:00 2001
From: Jefffrey <22608443+Jefffrey@users.noreply.github.com>
Date: Fri, 10 Nov 2023 08:10:03 +1100
Subject: [PATCH 12/12] Bump parquet-testing and read new f16 files for test

---
 parquet-testing                       |  2 +-
 parquet/src/arrow/arrow_reader/mod.rs | 56 +++++++++++++++++++++++++++
 2 files changed, 57 insertions(+), 1 deletion(-)

diff --git a/parquet-testing b/parquet-testing
index aafd3fc9df4..506afff9b69 160000
--- a/parquet-testing
+++ b/parquet-testing
@@ -1 +1 @@
-Subproject commit aafd3fc9df431c2625a514fb46626e5614f1d199
+Subproject commit 506afff9b6957ffe10d08470d467867d43e1bb91
diff --git a/parquet/src/arrow/arrow_reader/mod.rs b/parquet/src/arrow/arrow_reader/mod.rs
index 947dfc6c9f5..b9e9d289845 100644
--- a/parquet/src/arrow/arrow_reader/mod.rs
+++ b/parquet/src/arrow/arrow_reader/mod.rs
@@ -1316,6 +1316,62 @@ mod tests {
     }
 
+    #[test]
+    fn test_read_float16_nonzeros_file() {
+        use arrow_array::Float16Array;
+        let testdata = arrow::util::test_util::parquet_test_data();
+        // see https://github.com/apache/parquet-testing/pull/40
+        let path = format!("{testdata}/float16_nonzeros_and_nans.parquet");
+        let file = File::open(path).unwrap();
+        let mut record_reader = ParquetRecordBatchReader::try_new(file, 32).unwrap();
+
+        let batch = record_reader.next().unwrap().unwrap();
+        assert_eq!(batch.num_rows(), 8);
+        let col = batch
+            .column(0)
+            .as_any()
+            .downcast_ref::<Float16Array>()
+            .unwrap();
+
+        let f16_two = f16::ONE + f16::ONE;
+
+        assert_eq!(col.null_count(), 1);
+        assert!(col.is_null(0));
+        assert_eq!(col.value(1), f16::ONE);
+        assert_eq!(col.value(2), -f16_two);
+        assert!(col.value(3).is_nan());
+        assert_eq!(col.value(4), f16::ZERO);
+        assert!(col.value(4).is_sign_positive());
+        assert_eq!(col.value(5), f16::NEG_ONE);
+        assert_eq!(col.value(6), f16::NEG_ZERO);
+        assert!(col.value(6).is_sign_negative());
+        assert_eq!(col.value(7), f16_two);
+    }
+
+    #[test]
+    fn test_read_float16_zeros_file() {
+        use arrow_array::Float16Array;
+        let testdata = arrow::util::test_util::parquet_test_data();
+        // see https://github.com/apache/parquet-testing/pull/40
+        let path = format!("{testdata}/float16_zeros_and_nans.parquet");
+        let file = File::open(path).unwrap();
+        let mut record_reader = ParquetRecordBatchReader::try_new(file, 32).unwrap();
+
+        let batch = record_reader.next().unwrap().unwrap();
+        assert_eq!(batch.num_rows(), 3);
+        let col = batch
+            .column(0)
+            .as_any()
+            .downcast_ref::<Float16Array>()
+            .unwrap();
+
+        assert_eq!(col.null_count(), 1);
+        assert!(col.is_null(0));
+        assert_eq!(col.value(1), f16::ZERO);
+        assert!(col.value(1).is_sign_positive());
+        assert!(col.value(2).is_nan());
+    }
+
     /// Parameters for single_column_reader_test
     #[derive(Clone)]
     struct TestOptions {