diff --git a/parquet-testing b/parquet-testing index aafd3fc9df4..506afff9b69 160000 --- a/parquet-testing +++ b/parquet-testing @@ -1 +1 @@ -Subproject commit aafd3fc9df431c2625a514fb46626e5614f1d199 +Subproject commit 506afff9b6957ffe10d08470d467867d43e1bb91 diff --git a/parquet/Cargo.toml b/parquet/Cargo.toml index e5f5e1652b8..bdcbcb81cfc 100644 --- a/parquet/Cargo.toml +++ b/parquet/Cargo.toml @@ -66,6 +66,7 @@ tokio = { version = "1.0", optional = true, default-features = false, features = hashbrown = { version = "0.14", default-features = false } twox-hash = { version = "1.6", default-features = false } paste = { version = "1.0" } +half = { version = "2.1", default-features = false, features = ["num-traits"] } [dev-dependencies] base64 = { version = "0.21", default-features = false, features = ["std"] } diff --git a/parquet/regen.sh b/parquet/regen.sh index b8c3549e232..91539634339 100755 --- a/parquet/regen.sh +++ b/parquet/regen.sh @@ -17,7 +17,7 @@ # specific language governing permissions and limitations # under the License. -REVISION=aeae80660c1d0c97314e9da837de1abdebd49c37 +REVISION=46cc3a0647d301bb9579ca8dd2cc356caf2a72d2 SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]:-$0}")" && pwd)" diff --git a/parquet/src/arrow/array_reader/fixed_len_byte_array.rs b/parquet/src/arrow/array_reader/fixed_len_byte_array.rs index 3b1a50ebcce..b846997d36b 100644 --- a/parquet/src/arrow/array_reader/fixed_len_byte_array.rs +++ b/parquet/src/arrow/array_reader/fixed_len_byte_array.rs @@ -27,13 +27,14 @@ use crate::column::reader::decoder::{ColumnValueDecoder, ValuesBufferSlice}; use crate::errors::{ParquetError, Result}; use crate::schema::types::ColumnDescPtr; use arrow_array::{ - ArrayRef, Decimal128Array, Decimal256Array, FixedSizeBinaryArray, + ArrayRef, Decimal128Array, Decimal256Array, FixedSizeBinaryArray, Float16Array, IntervalDayTimeArray, IntervalYearMonthArray, }; use arrow_buffer::{i256, Buffer}; use arrow_data::ArrayDataBuilder; use arrow_schema::{DataType as ArrowType, IntervalUnit}; use bytes::Bytes; +use half::f16; use std::any::Any; use std::ops::Range; use std::sync::Arc; @@ -88,6 +89,14 @@ pub fn make_fixed_len_byte_array_reader( )); } } + ArrowType::Float16 => { + if byte_length != 2 { + return Err(general_err!( + "float 16 type must be 2 bytes, got {}", + byte_length + )); + } + } _ => { return Err(general_err!( "invalid data type for fixed length byte array reader - {}", @@ -208,6 +217,12 @@ impl ArrayReader for FixedLenByteArrayReader { } } } + ArrowType::Float16 => Arc::new( + binary + .iter() + .map(|o| o.map(|b| f16::from_le_bytes(b[..2].try_into().unwrap()))) + .collect::(), + ) as ArrayRef, _ => Arc::new(binary) as ArrayRef, }; diff --git a/parquet/src/arrow/arrow_reader/mod.rs b/parquet/src/arrow/arrow_reader/mod.rs index 16cdf2934e6..b9e9d289845 100644 --- a/parquet/src/arrow/arrow_reader/mod.rs +++ b/parquet/src/arrow/arrow_reader/mod.rs @@ -712,13 +712,14 @@ mod tests { use std::sync::Arc; use bytes::Bytes; + use half::f16; use num::PrimInt; use rand::{thread_rng, Rng, RngCore}; use tempfile::tempfile; use arrow_array::builder::*; use arrow_array::cast::AsArray; - use arrow_array::types::{Decimal128Type, Decimal256Type, DecimalType}; + use arrow_array::types::{Decimal128Type, Decimal256Type, DecimalType, Float16Type}; use arrow_array::*; use arrow_array::{RecordBatch, RecordBatchReader}; use arrow_buffer::{i256, ArrowNativeType, Buffer}; @@ -924,6 +925,66 @@ mod tests { .unwrap(); } + #[test] + fn test_float16_roundtrip() -> Result<()> { + let schema = 
Arc::new(Schema::new(vec![ + Field::new("float16", ArrowDataType::Float16, false), + Field::new("float16-nullable", ArrowDataType::Float16, true), + ])); + + let mut buf = Vec::with_capacity(1024); + let mut writer = ArrowWriter::try_new(&mut buf, schema.clone(), None)?; + + let original = RecordBatch::try_new( + schema, + vec![ + Arc::new(Float16Array::from_iter_values([ + f16::EPSILON, + f16::MIN, + f16::MAX, + f16::NAN, + f16::INFINITY, + f16::NEG_INFINITY, + f16::ONE, + f16::NEG_ONE, + f16::ZERO, + f16::NEG_ZERO, + f16::E, + f16::PI, + f16::FRAC_1_PI, + ])), + Arc::new(Float16Array::from(vec![ + None, + None, + None, + Some(f16::NAN), + Some(f16::INFINITY), + Some(f16::NEG_INFINITY), + None, + None, + None, + None, + None, + None, + Some(f16::FRAC_1_PI), + ])), + ], + )?; + + writer.write(&original)?; + writer.close()?; + + let mut reader = ParquetRecordBatchReader::try_new(Bytes::from(buf), 1024)?; + let ret = reader.next().unwrap()?; + assert_eq!(ret, original); + + // Ensure can be downcast to the correct type + ret.column(0).as_primitive::(); + ret.column(1).as_primitive::(); + + Ok(()) + } + struct RandFixedLenGen {} impl RandGen for RandFixedLenGen { @@ -1255,6 +1316,62 @@ mod tests { } } + #[test] + fn test_read_float16_nonzeros_file() { + use arrow_array::Float16Array; + let testdata = arrow::util::test_util::parquet_test_data(); + // see https://github.com/apache/parquet-testing/pull/40 + let path = format!("{testdata}/float16_nonzeros_and_nans.parquet"); + let file = File::open(path).unwrap(); + let mut record_reader = ParquetRecordBatchReader::try_new(file, 32).unwrap(); + + let batch = record_reader.next().unwrap().unwrap(); + assert_eq!(batch.num_rows(), 8); + let col = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + + let f16_two = f16::ONE + f16::ONE; + + assert_eq!(col.null_count(), 1); + assert!(col.is_null(0)); + assert_eq!(col.value(1), f16::ONE); + assert_eq!(col.value(2), -f16_two); + assert!(col.value(3).is_nan()); + assert_eq!(col.value(4), f16::ZERO); + assert!(col.value(4).is_sign_positive()); + assert_eq!(col.value(5), f16::NEG_ONE); + assert_eq!(col.value(6), f16::NEG_ZERO); + assert!(col.value(6).is_sign_negative()); + assert_eq!(col.value(7), f16_two); + } + + #[test] + fn test_read_float16_zeros_file() { + use arrow_array::Float16Array; + let testdata = arrow::util::test_util::parquet_test_data(); + // see https://github.com/apache/parquet-testing/pull/40 + let path = format!("{testdata}/float16_zeros_and_nans.parquet"); + let file = File::open(path).unwrap(); + let mut record_reader = ParquetRecordBatchReader::try_new(file, 32).unwrap(); + + let batch = record_reader.next().unwrap().unwrap(); + assert_eq!(batch.num_rows(), 3); + let col = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(col.null_count(), 1); + assert!(col.is_null(0)); + assert_eq!(col.value(1), f16::ZERO); + assert!(col.value(1).is_sign_positive()); + assert!(col.value(2).is_nan()); + } + /// Parameters for single_column_reader_test #[derive(Clone)] struct TestOptions { diff --git a/parquet/src/arrow/arrow_writer/mod.rs b/parquet/src/arrow/arrow_writer/mod.rs index eca1dea791b..ea7b1eee99b 100644 --- a/parquet/src/arrow/arrow_writer/mod.rs +++ b/parquet/src/arrow/arrow_writer/mod.rs @@ -771,6 +771,10 @@ fn write_leaf(writer: &mut ColumnWriter<'_>, levels: &ArrayLevels) -> Result { + let array = column.as_primitive::(); + get_float_16_array_slice(array, indices) + } _ => { return Err(ParquetError::NYI( "Attempting to write an Arrow type that 
is not yet implemented".to_string(), @@ -867,6 +871,18 @@ fn get_decimal_256_array_slice( values } +fn get_float_16_array_slice( + array: &arrow_array::Float16Array, + indices: &[usize], +) -> Vec { + let mut values = Vec::with_capacity(indices.len()); + for i in indices { + let value = array.value(*i).to_le_bytes().to_vec(); + values.push(FixedLenByteArray::from(ByteArray::from(value))); + } + values +} + fn get_fsb_array_slice( array: &arrow_array::FixedSizeBinaryArray, indices: &[usize], diff --git a/parquet/src/arrow/schema/mod.rs b/parquet/src/arrow/schema/mod.rs index d56cc42d431..4c350c4b1d8 100644 --- a/parquet/src/arrow/schema/mod.rs +++ b/parquet/src/arrow/schema/mod.rs @@ -373,7 +373,12 @@ fn arrow_to_parquet_type(field: &Field) -> Result { .with_repetition(repetition) .with_id(id) .build(), - DataType::Float16 => Err(arrow_err!("Float16 arrays not supported")), + DataType::Float16 => Type::primitive_type_builder(name, PhysicalType::FIXED_LEN_BYTE_ARRAY) + .with_repetition(repetition) + .with_id(id) + .with_logical_type(Some(LogicalType::Float16)) + .with_length(2) + .build(), DataType::Float32 => Type::primitive_type_builder(name, PhysicalType::FLOAT) .with_repetition(repetition) .with_id(id) @@ -604,9 +609,10 @@ mod tests { REQUIRED INT32 uint8 (INTEGER(8,false)); REQUIRED INT32 uint16 (INTEGER(16,false)); REQUIRED INT32 int32; - REQUIRED INT64 int64 ; + REQUIRED INT64 int64; OPTIONAL DOUBLE double; OPTIONAL FLOAT float; + OPTIONAL FIXED_LEN_BYTE_ARRAY (2) float16 (FLOAT16); OPTIONAL BINARY string (UTF8); OPTIONAL BINARY string_2 (STRING); OPTIONAL BINARY json (JSON); @@ -628,6 +634,7 @@ mod tests { Field::new("int64", DataType::Int64, false), Field::new("double", DataType::Float64, true), Field::new("float", DataType::Float32, true), + Field::new("float16", DataType::Float16, true), Field::new("string", DataType::Utf8, true), Field::new("string_2", DataType::Utf8, true), Field::new("json", DataType::Utf8, true), @@ -1303,6 +1310,7 @@ mod tests { REQUIRED INT64 int64; OPTIONAL DOUBLE double; OPTIONAL FLOAT float; + OPTIONAL FIXED_LEN_BYTE_ARRAY (2) float16 (FLOAT16); OPTIONAL BINARY string (UTF8); REPEATED BOOLEAN bools; OPTIONAL INT32 date (DATE); @@ -1339,6 +1347,7 @@ mod tests { Field::new("int64", DataType::Int64, false), Field::new("double", DataType::Float64, true), Field::new("float", DataType::Float32, true), + Field::new("float16", DataType::Float16, true), Field::new("string", DataType::Utf8, true), Field::new_list( "bools", @@ -1398,6 +1407,7 @@ mod tests { REQUIRED INT64 int64; OPTIONAL DOUBLE double; OPTIONAL FLOAT float; + OPTIONAL FIXED_LEN_BYTE_ARRAY (2) float16 (FLOAT16); OPTIONAL BINARY string (STRING); OPTIONAL GROUP bools (LIST) { REPEATED GROUP list { @@ -1448,6 +1458,7 @@ mod tests { Field::new("int64", DataType::Int64, false), Field::new("double", DataType::Float64, true), Field::new("float", DataType::Float32, true), + Field::new("float16", DataType::Float16, true), Field::new("string", DataType::Utf8, true), Field::new_list( "bools", @@ -1661,6 +1672,8 @@ mod tests { vec![ Field::new("a", DataType::Int16, true), Field::new("b", DataType::Float64, false), + Field::new("c", DataType::Float32, false), + Field::new("d", DataType::Float16, false), ] .into(), ), diff --git a/parquet/src/arrow/schema/primitive.rs b/parquet/src/arrow/schema/primitive.rs index 7d8b6a04ee8..fdc744831a2 100644 --- a/parquet/src/arrow/schema/primitive.rs +++ b/parquet/src/arrow/schema/primitive.rs @@ -304,6 +304,16 @@ fn from_fixed_len_byte_array( // would be incorrect if all 12 
bytes of the interval are populated Ok(DataType::Interval(IntervalUnit::DayTime)) } + (Some(LogicalType::Float16), _) => { + if type_length == 2 { + Ok(DataType::Float16) + } else { + Err(ParquetError::General( + "FLOAT16 logical type must be Fixed Length Byte Array with length 2" + .to_string(), + )) + } + } _ => Ok(DataType::FixedSizeBinary(type_length)), } } diff --git a/parquet/src/basic.rs b/parquet/src/basic.rs index 3c8602b8022..2327e1d84b4 100644 --- a/parquet/src/basic.rs +++ b/parquet/src/basic.rs @@ -194,6 +194,7 @@ pub enum LogicalType { Json, Bson, Uuid, + Float16, } // ---------------------------------------------------------------------- @@ -505,6 +506,7 @@ impl ColumnOrder { LogicalType::Timestamp { .. } => SortOrder::SIGNED, LogicalType::Unknown => SortOrder::UNDEFINED, LogicalType::Uuid => SortOrder::UNSIGNED, + LogicalType::Float16 => SortOrder::SIGNED, }, // Fall back to converted type None => Self::get_converted_sort_order(converted_type, physical_type), @@ -766,6 +768,7 @@ impl From for LogicalType { parquet::LogicalType::JSON(_) => LogicalType::Json, parquet::LogicalType::BSON(_) => LogicalType::Bson, parquet::LogicalType::UUID(_) => LogicalType::Uuid, + parquet::LogicalType::FLOAT16(_) => LogicalType::Float16, } } } @@ -806,6 +809,7 @@ impl From for parquet::LogicalType { LogicalType::Json => parquet::LogicalType::JSON(Default::default()), LogicalType::Bson => parquet::LogicalType::BSON(Default::default()), LogicalType::Uuid => parquet::LogicalType::UUID(Default::default()), + LogicalType::Float16 => parquet::LogicalType::FLOAT16(Default::default()), } } } @@ -853,10 +857,11 @@ impl From> for ConvertedType { (64, false) => ConvertedType::UINT_64, t => panic!("Integer type {t:?} is not supported"), }, - LogicalType::Unknown => ConvertedType::NONE, LogicalType::Json => ConvertedType::JSON, LogicalType::Bson => ConvertedType::BSON, - LogicalType::Uuid => ConvertedType::NONE, + LogicalType::Uuid | LogicalType::Float16 | LogicalType::Unknown => { + ConvertedType::NONE + } }, None => ConvertedType::NONE, } @@ -1102,6 +1107,7 @@ impl str::FromStr for LogicalType { "INTERVAL" => Err(general_err!( "Interval parquet logical type not yet supported" )), + "FLOAT16" => Ok(LogicalType::Float16), other => Err(general_err!("Invalid parquet logical type {}", other)), } } @@ -1746,6 +1752,10 @@ mod tests { ConvertedType::from(Some(LogicalType::Enum)), ConvertedType::ENUM ); + assert_eq!( + ConvertedType::from(Some(LogicalType::Float16)), + ConvertedType::NONE + ); assert_eq!( ConvertedType::from(Some(LogicalType::Unknown)), ConvertedType::NONE @@ -2119,6 +2129,7 @@ mod tests { is_adjusted_to_u_t_c: true, unit: TimeUnit::NANOS(Default::default()), }, + LogicalType::Float16, ]; check_sort_order(signed, SortOrder::SIGNED); diff --git a/parquet/src/column/writer/encoder.rs b/parquet/src/column/writer/encoder.rs index 2273ae77744..d0720dd2430 100644 --- a/parquet/src/column/writer/encoder.rs +++ b/parquet/src/column/writer/encoder.rs @@ -16,8 +16,9 @@ // under the License. 
 use bytes::Bytes;
+use half::f16;
 
-use crate::basic::{Encoding, Type};
+use crate::basic::{Encoding, LogicalType, Type};
 use crate::bloom_filter::Sbbf;
 use crate::column::writer::{
     compare_greater, fallback_encoding, has_dictionary_support, is_nan, update_max, update_min,
@@ -291,7 +292,7 @@ where
 {
     let first = loop {
         let next = iter.next()?;
-        if !is_nan(next) {
+        if !is_nan(descr, next) {
             break next;
         }
     };
@@ -299,7 +300,7 @@ where
     let mut min = first;
     let mut max = first;
     for val in iter {
-        if is_nan(val) {
+        if is_nan(descr, val) {
             continue;
         }
         if compare_greater(descr, min, val) {
@@ -318,14 +319,14 @@ where
     //
     // For max, it has similar logic but will be written as 0.0
     // (positive zero)
-    let min = replace_zero(min, -0.0);
-    let max = replace_zero(max, 0.0);
+    let min = replace_zero(min, descr, -0.0);
+    let max = replace_zero(max, descr, 0.0);
 
     Some((min, max))
 }
 
 #[inline]
-fn replace_zero<T: ParquetValueType>(val: &T, replace: f32) -> T {
+fn replace_zero<T: ParquetValueType>(val: &T, descr: &ColumnDescriptor, replace: f32) -> T {
     match T::PHYSICAL_TYPE {
         Type::FLOAT if f32::from_le_bytes(val.as_bytes().try_into().unwrap()) == 0.0 => {
             T::try_from_le_slice(&f32::to_le_bytes(replace)).unwrap()
@@ -333,6 +334,12 @@ fn replace_zero<T: ParquetValueType>(val: &T, replace: f32) -> T {
         Type::DOUBLE if f64::from_le_bytes(val.as_bytes().try_into().unwrap()) == 0.0 => {
             T::try_from_le_slice(&f64::to_le_bytes(replace as f64)).unwrap()
         }
+        Type::FIXED_LEN_BYTE_ARRAY
+            if descr.logical_type() == Some(LogicalType::Float16)
+                && f16::from_le_bytes(val.as_bytes().try_into().unwrap()) == f16::NEG_ZERO =>
+        {
+            T::try_from_le_slice(&f16::to_le_bytes(f16::from_f32(replace))).unwrap()
+        }
         _ => val.clone(),
     }
 }
diff --git a/parquet/src/column/writer/mod.rs b/parquet/src/column/writer/mod.rs
index 60db90c5d46..a917c486498 100644
--- a/parquet/src/column/writer/mod.rs
+++ b/parquet/src/column/writer/mod.rs
@@ -18,6 +18,7 @@
 //! Contains column writer API.
 use bytes::Bytes;
+use half::f16;
 
 use crate::bloom_filter::Sbbf;
 use crate::format::{ColumnIndex, OffsetIndex};
@@ -968,18 +969,23 @@ impl<'a, E: ColumnValueEncoder> GenericColumnWriter<'a, E> {
 }
 
 fn update_min<T: ParquetValueType>(descr: &ColumnDescriptor, val: &T, min: &mut Option<T>) {
-    update_stat::<T, _>(val, min, |cur| compare_greater(descr, cur, val))
+    update_stat::<T, _>(descr, val, min, |cur| compare_greater(descr, cur, val))
 }
 
 fn update_max<T: ParquetValueType>(descr: &ColumnDescriptor, val: &T, max: &mut Option<T>) {
-    update_stat::<T, _>(val, max, |cur| compare_greater(descr, val, cur))
+    update_stat::<T, _>(descr, val, max, |cur| compare_greater(descr, val, cur))
 }
 
 #[inline]
 #[allow(clippy::eq_op)]
-fn is_nan<T: ParquetValueType>(val: &T) -> bool {
+fn is_nan<T: ParquetValueType>(descr: &ColumnDescriptor, val: &T) -> bool {
     match T::PHYSICAL_TYPE {
         Type::FLOAT | Type::DOUBLE => val != val,
+        Type::FIXED_LEN_BYTE_ARRAY if descr.logical_type() == Some(LogicalType::Float16) => {
+            let val = val.as_bytes();
+            let val = f16::from_le_bytes([val[0], val[1]]);
+            val.is_nan()
+        }
         _ => false,
     }
 }
@@ -989,11 +995,15 @@ fn is_nan<T: ParquetValueType>(val: &T) -> bool {
 /// If `cur` is `None`, sets `cur` to `Some(val)`, otherwise calls `should_update` with
 /// the value of `cur`, and updates `cur` to `Some(val)` if it returns `true`
-fn update_stat<T: ParquetValueType, F>(val: &T, cur: &mut Option<T>, should_update: F)
-where
+fn update_stat<T: ParquetValueType, F>(
+    descr: &ColumnDescriptor,
+    val: &T,
+    cur: &mut Option<T>,
+    should_update: F,
+) where
     F: Fn(&T) -> bool,
 {
-    if is_nan(val) {
+    if is_nan(descr, val) {
         return;
     }
@@ -1039,6 +1049,14 @@ fn compare_greater<T: ParquetValueType>(descr: &ColumnDescriptor, a: &T, b: &T)
         };
     };
 
+    if let Some(LogicalType::Float16) = descr.logical_type() {
+        let a = a.as_bytes();
+        let a = f16::from_le_bytes([a[0], a[1]]);
+        let b = b.as_bytes();
+        let b = f16::from_le_bytes([b[0], b[1]]);
+        return a > b;
+    }
+
     a > b
 }
 
@@ -1170,6 +1188,7 @@ fn increment_utf8(mut data: Vec<u8>) -> Option<Vec<u8>> {
 mod tests {
     use crate::{file::properties::DEFAULT_COLUMN_INDEX_TRUNCATE_LENGTH, format::BoundaryOrder};
     use bytes::Bytes;
+    use half::f16;
     use rand::distributions::uniform::SampleUniform;
     use std::sync::Arc;
 
@@ -2078,6 +2097,135 @@
         }
     }
 
+    #[test]
+    fn test_column_writer_check_float16_min_max() {
+        let input = [
+            -f16::ONE,
+            f16::from_f32(3.0),
+            -f16::from_f32(2.0),
+            f16::from_f32(2.0),
+        ]
+        .into_iter()
+        .map(|s| ByteArray::from(s).into())
+        .collect::<Vec<_>>();
+
+        let stats = float16_statistics_roundtrip(&input);
+        assert!(stats.has_min_max_set());
+        assert!(stats.is_min_max_backwards_compatible());
+        assert_eq!(stats.min(), &ByteArray::from(-f16::from_f32(2.0)));
+        assert_eq!(stats.max(), &ByteArray::from(f16::from_f32(3.0)));
+    }
+
+    #[test]
+    fn test_column_writer_check_float16_nan_middle() {
+        let input = [f16::ONE, f16::NAN, f16::ONE + f16::ONE]
+            .into_iter()
+            .map(|s| ByteArray::from(s).into())
+            .collect::<Vec<_>>();
+
+        let stats = float16_statistics_roundtrip(&input);
+        assert!(stats.has_min_max_set());
+        assert!(stats.is_min_max_backwards_compatible());
+        assert_eq!(stats.min(), &ByteArray::from(f16::ONE));
+        assert_eq!(stats.max(), &ByteArray::from(f16::ONE + f16::ONE));
+    }
+
+    #[test]
+    fn test_float16_statistics_nan_middle() {
+        let input = [f16::ONE, f16::NAN, f16::ONE + f16::ONE]
+            .into_iter()
+            .map(|s| ByteArray::from(s).into())
+            .collect::<Vec<_>>();
+
+        let stats = float16_statistics_roundtrip(&input);
+        assert!(stats.has_min_max_set());
+        assert!(stats.is_min_max_backwards_compatible());
+        assert_eq!(stats.min(), &ByteArray::from(f16::ONE));
+        assert_eq!(stats.max(), &ByteArray::from(f16::ONE + f16::ONE));
+    }
+
+    #[test]
+    fn test_float16_statistics_nan_start() {
+        let input = [f16::NAN, f16::ONE, f16::ONE + f16::ONE]
+            .into_iter()
+            .map(|s| ByteArray::from(s).into())
+            .collect::<Vec<_>>();
+
+        let stats = float16_statistics_roundtrip(&input);
+        assert!(stats.has_min_max_set());
+        assert!(stats.is_min_max_backwards_compatible());
+        assert_eq!(stats.min(), &ByteArray::from(f16::ONE));
+        assert_eq!(stats.max(), &ByteArray::from(f16::ONE + f16::ONE));
+    }
+
+    #[test]
+    fn test_float16_statistics_nan_only() {
+        let input = [f16::NAN, f16::NAN]
+            .into_iter()
+            .map(|s| ByteArray::from(s).into())
+            .collect::<Vec<_>>();
+
+        let stats = float16_statistics_roundtrip(&input);
+        assert!(!stats.has_min_max_set());
+        assert!(stats.is_min_max_backwards_compatible());
+    }
+
+    #[test]
+    fn test_float16_statistics_zero_only() {
+        let input = [f16::ZERO]
+            .into_iter()
+            .map(|s| ByteArray::from(s).into())
+            .collect::<Vec<_>>();
+
+        let stats = float16_statistics_roundtrip(&input);
+        assert!(stats.has_min_max_set());
+        assert!(stats.is_min_max_backwards_compatible());
+        assert_eq!(stats.min(), &ByteArray::from(f16::NEG_ZERO));
+        assert_eq!(stats.max(), &ByteArray::from(f16::ZERO));
+    }
+
+    #[test]
+    fn test_float16_statistics_neg_zero_only() {
+        let input = [f16::NEG_ZERO]
+            .into_iter()
+            .map(|s| ByteArray::from(s).into())
+            .collect::<Vec<_>>();
+
+        let stats = float16_statistics_roundtrip(&input);
+        assert!(stats.has_min_max_set());
+        assert!(stats.is_min_max_backwards_compatible());
+        assert_eq!(stats.min(), &ByteArray::from(f16::NEG_ZERO));
+        assert_eq!(stats.max(), &ByteArray::from(f16::ZERO));
+    }
+
+    #[test]
+    fn test_float16_statistics_zero_min() {
+        let input = [f16::ZERO, f16::ONE, f16::NAN, f16::PI]
+            .into_iter()
+            .map(|s| ByteArray::from(s).into())
+            .collect::<Vec<_>>();
+
+        let stats = float16_statistics_roundtrip(&input);
+        assert!(stats.has_min_max_set());
+        assert!(stats.is_min_max_backwards_compatible());
+        assert_eq!(stats.min(), &ByteArray::from(f16::NEG_ZERO));
+        assert_eq!(stats.max(), &ByteArray::from(f16::PI));
+    }
+
+    #[test]
+    fn test_float16_statistics_neg_zero_max() {
+        let input = [f16::NEG_ZERO, f16::NEG_ONE, f16::NAN, -f16::PI]
+            .into_iter()
+            .map(|s| ByteArray::from(s).into())
+            .collect::<Vec<_>>();
+
+        let stats = float16_statistics_roundtrip(&input);
+        assert!(stats.has_min_max_set());
+        assert!(stats.is_min_max_backwards_compatible());
+        assert_eq!(stats.min(), &ByteArray::from(-f16::PI));
+        assert_eq!(stats.max(), &ByteArray::from(f16::ZERO));
+    }
+
     #[test]
     fn test_float_statistics_nan_middle() {
         let stats = statistics_roundtrip::<FloatType>(&[1.0, f32::NAN, 2.0]);
@@ -2850,6 +2998,50 @@
         ColumnDescriptor::new(Arc::new(tpe), max_def_level, max_rep_level, path)
     }
 
+    fn float16_statistics_roundtrip(
+        values: &[FixedLenByteArray],
+    ) -> ValueStatistics<FixedLenByteArray> {
+        let page_writer = get_test_page_writer();
+        let props = Default::default();
+        let mut writer =
+            get_test_float16_column_writer::<FixedLenByteArrayType>(page_writer, 0, 0, props);
+        writer.write_batch(values, None, None).unwrap();
+
+        let metadata = writer.close().unwrap().metadata;
+        if let Some(Statistics::FixedLenByteArray(stats)) = metadata.statistics() {
+            stats.clone()
+        } else {
+            panic!("metadata missing statistics");
+        }
+    }
+
+    fn get_test_float16_column_writer<T: DataType>(
+        page_writer: Box<dyn PageWriter>,
+        max_def_level: i16,
+        max_rep_level: i16,
+        props: WriterPropertiesPtr,
+    ) -> ColumnWriterImpl<'static, T> {
+        let descr = Arc::new(get_test_float16_column_descr::<T>(
+            max_def_level,
+            max_rep_level,
+        ));
+        let column_writer = get_column_writer(descr, props, page_writer);
+        
get_typed_column_writer::(column_writer) + } + + fn get_test_float16_column_descr( + max_def_level: i16, + max_rep_level: i16, + ) -> ColumnDescriptor { + let path = ColumnPath::from("col"); + let tpe = SchemaType::primitive_type_builder("col", T::get_physical_type()) + .with_length(2) + .with_logical_type(Some(LogicalType::Float16)) + .build() + .unwrap(); + ColumnDescriptor::new(Arc::new(tpe), max_def_level, max_rep_level, path) + } + /// Returns column writer for UINT32 Column provided as ConvertedType only fn get_test_unsigned_int_given_as_converted_column_writer<'a, T: DataType>( page_writer: Box, diff --git a/parquet/src/data_type.rs b/parquet/src/data_type.rs index b895c250701..86da7a3acee 100644 --- a/parquet/src/data_type.rs +++ b/parquet/src/data_type.rs @@ -18,6 +18,7 @@ //! Data types that connect Parquet physical types with their Rust-specific //! representations. use bytes::Bytes; +use half::f16; use std::cmp::Ordering; use std::fmt; use std::mem; @@ -225,6 +226,12 @@ impl From for ByteArray { } } +impl From for ByteArray { + fn from(value: f16) -> Self { + Self::from(value.to_le_bytes().as_slice()) + } +} + impl PartialEq for ByteArray { fn eq(&self, other: &ByteArray) -> bool { match (&self.data, &other.data) { diff --git a/parquet/src/file/statistics.rs b/parquet/src/file/statistics.rs index b36e37a80c9..345fe7dd261 100644 --- a/parquet/src/file/statistics.rs +++ b/parquet/src/file/statistics.rs @@ -243,6 +243,8 @@ pub fn to_thrift(stats: Option<&Statistics>) -> Option { distinct_count: stats.distinct_count().map(|value| value as i64), max_value: None, min_value: None, + is_max_value_exact: None, + is_min_value_exact: None, }; // Get min/max if set. @@ -607,6 +609,8 @@ mod tests { distinct_count: None, max_value: None, min_value: None, + is_max_value_exact: None, + is_min_value_exact: None, }; from_thrift(Type::INT32, Some(thrift_stats)).unwrap(); diff --git a/parquet/src/format.rs b/parquet/src/format.rs index 46adc39e640..4700b05dc28 100644 --- a/parquet/src/format.rs +++ b/parquet/src/format.rs @@ -657,16 +657,26 @@ pub struct Statistics { pub null_count: Option, /// count of distinct values occurring pub distinct_count: Option, - /// Min and max values for the column, determined by its ColumnOrder. + /// Lower and upper bound values for the column, determined by its ColumnOrder. + /// + /// These may be the actual minimum and maximum values found on a page or column + /// chunk, but can also be (more compact) values that do not exist on a page or + /// column chunk. For example, instead of storing "Blart Versenwald III", a writer + /// may set min_value="B", max_value="C". Such more compact values must still be + /// valid values within the column's logical type. /// /// Values are encoded using PLAIN encoding, except that variable-length byte /// arrays do not include a length prefix. 
pub max_value: Option>, pub min_value: Option>, + /// If true, max_value is the actual maximum value for a column + pub is_max_value_exact: Option, + /// If true, min_value is the actual minimum value for a column + pub is_min_value_exact: Option, } impl Statistics { - pub fn new(max: F1, min: F2, null_count: F3, distinct_count: F4, max_value: F5, min_value: F6) -> Statistics where F1: Into>>, F2: Into>>, F3: Into>, F4: Into>, F5: Into>>, F6: Into>> { + pub fn new(max: F1, min: F2, null_count: F3, distinct_count: F4, max_value: F5, min_value: F6, is_max_value_exact: F7, is_min_value_exact: F8) -> Statistics where F1: Into>>, F2: Into>>, F3: Into>, F4: Into>, F5: Into>>, F6: Into>>, F7: Into>, F8: Into> { Statistics { max: max.into(), min: min.into(), @@ -674,6 +684,8 @@ impl Statistics { distinct_count: distinct_count.into(), max_value: max_value.into(), min_value: min_value.into(), + is_max_value_exact: is_max_value_exact.into(), + is_min_value_exact: is_min_value_exact.into(), } } } @@ -687,6 +699,8 @@ impl crate::thrift::TSerializable for Statistics { let mut f_4: Option = None; let mut f_5: Option> = None; let mut f_6: Option> = None; + let mut f_7: Option = None; + let mut f_8: Option = None; loop { let field_ident = i_prot.read_field_begin()?; if field_ident.field_type == TType::Stop { @@ -718,6 +732,14 @@ impl crate::thrift::TSerializable for Statistics { let val = i_prot.read_bytes()?; f_6 = Some(val); }, + 7 => { + let val = i_prot.read_bool()?; + f_7 = Some(val); + }, + 8 => { + let val = i_prot.read_bool()?; + f_8 = Some(val); + }, _ => { i_prot.skip(field_ident.field_type)?; }, @@ -732,6 +754,8 @@ impl crate::thrift::TSerializable for Statistics { distinct_count: f_4, max_value: f_5, min_value: f_6, + is_max_value_exact: f_7, + is_min_value_exact: f_8, }; Ok(ret) } @@ -768,6 +792,16 @@ impl crate::thrift::TSerializable for Statistics { o_prot.write_bytes(fld_var)?; o_prot.write_field_end()? } + if let Some(fld_var) = self.is_max_value_exact { + o_prot.write_field_begin(&TFieldIdentifier::new("is_max_value_exact", TType::Bool, 7))?; + o_prot.write_bool(fld_var)?; + o_prot.write_field_end()? + } + if let Some(fld_var) = self.is_min_value_exact { + o_prot.write_field_begin(&TFieldIdentifier::new("is_min_value_exact", TType::Bool, 8))?; + o_prot.write_bool(fld_var)?; + o_prot.write_field_end()? 
+ } o_prot.write_field_stop()?; o_prot.write_struct_end() } @@ -996,6 +1030,43 @@ impl crate::thrift::TSerializable for DateType { } } +// +// Float16Type +// + +#[derive(Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] +pub struct Float16Type { +} + +impl Float16Type { + pub fn new() -> Float16Type { + Float16Type {} + } +} + +impl crate::thrift::TSerializable for Float16Type { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { + i_prot.read_struct_begin()?; + loop { + let field_ident = i_prot.read_field_begin()?; + if field_ident.field_type == TType::Stop { + break; + } + i_prot.skip(field_ident.field_type)?; + i_prot.read_field_end()?; + } + i_prot.read_struct_end()?; + let ret = Float16Type {}; + Ok(ret) + } + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { + let struct_ident = TStructIdentifier::new("Float16Type"); + o_prot.write_struct_begin(&struct_ident)?; + o_prot.write_field_stop()?; + o_prot.write_struct_end() + } +} + // // NullType // @@ -1640,6 +1711,7 @@ pub enum LogicalType { JSON(JsonType), BSON(BsonType), UUID(UUIDType), + FLOAT16(Float16Type), } impl crate::thrift::TSerializable for LogicalType { @@ -1745,6 +1817,13 @@ impl crate::thrift::TSerializable for LogicalType { } received_field_count += 1; }, + 15 => { + let val = Float16Type::read_from_in_protocol(i_prot)?; + if ret.is_none() { + ret = Some(LogicalType::FLOAT16(val)); + } + received_field_count += 1; + }, _ => { i_prot.skip(field_ident.field_type)?; received_field_count += 1; @@ -1844,6 +1923,11 @@ impl crate::thrift::TSerializable for LogicalType { f.write_to_out_protocol(o_prot)?; o_prot.write_field_end()?; }, + LogicalType::FLOAT16(ref f) => { + o_prot.write_field_begin(&TFieldIdentifier::new("FLOAT16", TType::Struct, 15))?; + f.write_to_out_protocol(o_prot)?; + o_prot.write_field_end()?; + }, } o_prot.write_field_stop()?; o_prot.write_struct_end() diff --git a/parquet/src/record/api.rs b/parquet/src/record/api.rs index c7a0b09c37e..e4f473562e0 100644 --- a/parquet/src/record/api.rs +++ b/parquet/src/record/api.rs @@ -20,9 +20,11 @@ use std::fmt; use chrono::{TimeZone, Utc}; +use half::f16; +use num::traits::Float; use num_bigint::{BigInt, Sign}; -use crate::basic::{ConvertedType, Type as PhysicalType}; +use crate::basic::{ConvertedType, LogicalType, Type as PhysicalType}; use crate::data_type::{ByteArray, Decimal, Int96}; use crate::errors::{ParquetError, Result}; use crate::schema::types::ColumnDescPtr; @@ -121,6 +123,7 @@ pub trait RowAccessor { fn get_ushort(&self, i: usize) -> Result; fn get_uint(&self, i: usize) -> Result; fn get_ulong(&self, i: usize) -> Result; + fn get_float16(&self, i: usize) -> Result; fn get_float(&self, i: usize) -> Result; fn get_double(&self, i: usize) -> Result; fn get_timestamp_millis(&self, i: usize) -> Result; @@ -215,6 +218,8 @@ impl RowAccessor for Row { row_primitive_accessor!(get_ulong, ULong, u64); + row_primitive_accessor!(get_float16, Float16, f16); + row_primitive_accessor!(get_float, Float, f32); row_primitive_accessor!(get_double, Double, f64); @@ -293,6 +298,7 @@ pub trait ListAccessor { fn get_ushort(&self, i: usize) -> Result; fn get_uint(&self, i: usize) -> Result; fn get_ulong(&self, i: usize) -> Result; + fn get_float16(&self, i: usize) -> Result; fn get_float(&self, i: usize) -> Result; fn get_double(&self, i: usize) -> Result; fn get_timestamp_millis(&self, i: usize) -> Result; @@ -358,6 +364,8 @@ impl ListAccessor for List { list_primitive_accessor!(get_ulong, ULong, u64); + 
list_primitive_accessor!(get_float16, Float16, f16); + list_primitive_accessor!(get_float, Float, f32); list_primitive_accessor!(get_double, Double, f64); @@ -449,6 +457,8 @@ impl<'a> ListAccessor for MapList<'a> { map_list_primitive_accessor!(get_ulong, ULong, u64); + map_list_primitive_accessor!(get_float16, Float16, f16); + map_list_primitive_accessor!(get_float, Float, f32); map_list_primitive_accessor!(get_double, Double, f64); @@ -510,6 +520,8 @@ pub enum Field { UInt(u32), // Unsigned integer UINT_64. ULong(u64), + /// IEEE 16-bit floating point value. + Float16(f16), /// IEEE 32-bit floating point value. Float(f32), /// IEEE 64-bit floating point value. @@ -552,6 +564,7 @@ impl Field { Field::UShort(_) => "UShort", Field::UInt(_) => "UInt", Field::ULong(_) => "ULong", + Field::Float16(_) => "Float16", Field::Float(_) => "Float", Field::Double(_) => "Double", Field::Decimal(_) => "Decimal", @@ -636,8 +649,8 @@ impl Field { Field::Double(value) } - /// Converts Parquet BYTE_ARRAY type with converted type into either UTF8 string or - /// array of bytes. + /// Converts Parquet BYTE_ARRAY type with converted type into a UTF8 + /// string, decimal, float16, or an array of bytes. #[inline] pub fn convert_byte_array(descr: &ColumnDescPtr, value: ByteArray) -> Result { let field = match descr.physical_type() { @@ -666,6 +679,16 @@ impl Field { descr.type_precision(), descr.type_scale(), )), + ConvertedType::NONE if descr.logical_type() == Some(LogicalType::Float16) => { + if value.len() != 2 { + return Err(general_err!( + "Error reading FIXED_LEN_BYTE_ARRAY as FLOAT16. Length must be 2, got {}", + value.len() + )); + } + let bytes = [value.data()[0], value.data()[1]]; + Field::Float16(f16::from_le_bytes(bytes)) + } ConvertedType::NONE => Field::Bytes(value), _ => nyi!(descr, value), }, @@ -690,6 +713,9 @@ impl Field { Field::UShort(n) => Value::Number(serde_json::Number::from(*n)), Field::UInt(n) => Value::Number(serde_json::Number::from(*n)), Field::ULong(n) => Value::Number(serde_json::Number::from(*n)), + Field::Float16(n) => serde_json::Number::from_f64(f64::from(*n)) + .map(Value::Number) + .unwrap_or(Value::Null), Field::Float(n) => serde_json::Number::from_f64(f64::from(*n)) .map(Value::Number) .unwrap_or(Value::Null), @@ -736,6 +762,15 @@ impl fmt::Display for Field { Field::UShort(value) => write!(f, "{value}"), Field::UInt(value) => write!(f, "{value}"), Field::ULong(value) => write!(f, "{value}"), + Field::Float16(value) => { + if !value.is_finite() { + write!(f, "{value}") + } else if value.trunc() == value { + write!(f, "{value}.0") + } else { + write!(f, "{value}") + } + } Field::Float(value) => { if !(1e-15..=1e19).contains(&value) { write!(f, "{value:E}") @@ -1069,6 +1104,24 @@ mod tests { Field::Decimal(Decimal::from_bytes(value, 17, 5)) ); + // FLOAT16 + let descr = { + let tpe = PrimitiveTypeBuilder::new("col", PhysicalType::FIXED_LEN_BYTE_ARRAY) + .with_logical_type(Some(LogicalType::Float16)) + .with_length(2) + .build() + .unwrap(); + Arc::new(ColumnDescriptor::new( + Arc::new(tpe), + 0, + 0, + ColumnPath::from("col"), + )) + }; + let value = ByteArray::from(f16::PI); + let row = Field::convert_byte_array(&descr, value.clone()); + assert_eq!(row.unwrap(), Field::Float16(f16::PI)); + // NONE (FIXED_LEN_BYTE_ARRAY) let descr = make_column_descr![ PhysicalType::FIXED_LEN_BYTE_ARRAY, @@ -1145,6 +1198,18 @@ mod tests { check_datetime_conversion(2014, 11, 28, 21, 15, 12); } + #[test] + fn test_convert_float16_to_string() { + assert_eq!(format!("{}", 
Field::Float16(f16::ONE)), "1.0"); + assert_eq!(format!("{}", Field::Float16(f16::PI)), "3.140625"); + assert_eq!(format!("{}", Field::Float16(f16::MAX)), "65504.0"); + assert_eq!(format!("{}", Field::Float16(f16::NAN)), "NaN"); + assert_eq!(format!("{}", Field::Float16(f16::INFINITY)), "inf"); + assert_eq!(format!("{}", Field::Float16(f16::NEG_INFINITY)), "-inf"); + assert_eq!(format!("{}", Field::Float16(f16::ZERO)), "0.0"); + assert_eq!(format!("{}", Field::Float16(f16::NEG_ZERO)), "-0.0"); + } + #[test] fn test_convert_float_to_string() { assert_eq!(format!("{}", Field::Float(1.0)), "1.0"); @@ -1218,6 +1283,7 @@ mod tests { assert_eq!(format!("{}", Field::UShort(2)), "2"); assert_eq!(format!("{}", Field::UInt(3)), "3"); assert_eq!(format!("{}", Field::ULong(4)), "4"); + assert_eq!(format!("{}", Field::Float16(f16::E)), "2.71875"); assert_eq!(format!("{}", Field::Float(5.0)), "5.0"); assert_eq!(format!("{}", Field::Float(5.1234)), "5.1234"); assert_eq!(format!("{}", Field::Double(6.0)), "6.0"); @@ -1284,6 +1350,7 @@ mod tests { assert!(Field::UShort(2).is_primitive()); assert!(Field::UInt(3).is_primitive()); assert!(Field::ULong(4).is_primitive()); + assert!(Field::Float16(f16::E).is_primitive()); assert!(Field::Float(5.0).is_primitive()); assert!(Field::Float(5.1234).is_primitive()); assert!(Field::Double(6.0).is_primitive()); @@ -1344,6 +1411,7 @@ mod tests { ("15".to_string(), Field::TimestampMillis(1262391174000)), ("16".to_string(), Field::TimestampMicros(1262391174000000)), ("17".to_string(), Field::Decimal(Decimal::from_i32(4, 7, 2))), + ("18".to_string(), Field::Float16(f16::PI)), ]); assert_eq!("null", format!("{}", row.fmt(0))); @@ -1370,6 +1438,7 @@ mod tests { format!("{}", row.fmt(16)) ); assert_eq!("0.04", format!("{}", row.fmt(17))); + assert_eq!("3.140625", format!("{}", row.fmt(18))); } #[test] @@ -1429,6 +1498,7 @@ mod tests { Field::Bytes(ByteArray::from(vec![1, 2, 3, 4, 5])), ), ("o".to_string(), Field::Decimal(Decimal::from_i32(4, 7, 2))), + ("p".to_string(), Field::Float16(f16::from_f32(9.1))), ]); assert!(!row.get_bool(1).unwrap()); @@ -1445,6 +1515,7 @@ mod tests { assert_eq!("abc", row.get_string(12).unwrap()); assert_eq!(5, row.get_bytes(13).unwrap().len()); assert_eq!(7, row.get_decimal(14).unwrap().precision()); + assert!((f16::from_f32(9.1) - row.get_float16(15).unwrap()).abs() < f16::EPSILON); } #[test] @@ -1469,6 +1540,7 @@ mod tests { Field::Bytes(ByteArray::from(vec![1, 2, 3, 4, 5])), ), ("o".to_string(), Field::Decimal(Decimal::from_i32(4, 7, 2))), + ("p".to_string(), Field::Float16(f16::from_f32(9.1))), ]); for i in 0..row.len() { @@ -1583,6 +1655,9 @@ mod tests { let list = make_list(vec![Field::ULong(6), Field::ULong(7)]); assert_eq!(7, list.get_ulong(1).unwrap()); + let list = make_list(vec![Field::Float16(f16::PI)]); + assert!((f16::PI - list.get_float16(0).unwrap()).abs() < f16::EPSILON); + let list = make_list(vec![ Field::Float(8.1), Field::Float(9.2), @@ -1633,6 +1708,9 @@ mod tests { let list = make_list(vec![Field::ULong(6), Field::ULong(7)]); assert!(list.get_float(1).is_err()); + let list = make_list(vec![Field::Float16(f16::PI)]); + assert!(list.get_string(0).is_err()); + let list = make_list(vec![ Field::Float(8.1), Field::Float(9.2), @@ -1768,6 +1846,10 @@ mod tests { Field::ULong(4).to_json_value(), Value::Number(serde_json::Number::from(4)) ); + assert_eq!( + Field::Float16(f16::from_f32(5.0)).to_json_value(), + Value::Number(serde_json::Number::from_f64(5.0).unwrap()) + ); assert_eq!( Field::Float(5.0).to_json_value(), 
Value::Number(serde_json::Number::from_f64(5.0).unwrap()) diff --git a/parquet/src/schema/parser.rs b/parquet/src/schema/parser.rs index 5e213e3bb9e..dcef11aa66d 100644 --- a/parquet/src/schema/parser.rs +++ b/parquet/src/schema/parser.rs @@ -823,6 +823,7 @@ mod tests { message root { optional fixed_len_byte_array(5) f1 (DECIMAL(9, 3)); optional fixed_len_byte_array (16) f2 (DECIMAL (38, 18)); + optional fixed_len_byte_array (2) f3 (FLOAT16); } "; let message = parse(schema).unwrap(); @@ -855,6 +856,13 @@ mod tests { .build() .unwrap(), ), + Arc::new( + Type::primitive_type_builder("f3", PhysicalType::FIXED_LEN_BYTE_ARRAY) + .with_logical_type(Some(LogicalType::Float16)) + .with_length(2) + .build() + .unwrap(), + ), ]) .build() .unwrap(); diff --git a/parquet/src/schema/printer.rs b/parquet/src/schema/printer.rs index fe4757d41ae..2dec8a5be9f 100644 --- a/parquet/src/schema/printer.rs +++ b/parquet/src/schema/printer.rs @@ -270,6 +270,7 @@ fn print_logical_and_converted( LogicalType::Enum => "ENUM".to_string(), LogicalType::List => "LIST".to_string(), LogicalType::Map => "MAP".to_string(), + LogicalType::Float16 => "FLOAT16".to_string(), LogicalType::Unknown => "UNKNOWN".to_string(), }, None => { @@ -667,6 +668,15 @@ mod tests { .unwrap(), "OPTIONAL FIXED_LEN_BYTE_ARRAY (9) decimal (DECIMAL(19,4));", ), + ( + Type::primitive_type_builder("float16", PhysicalType::FIXED_LEN_BYTE_ARRAY) + .with_logical_type(Some(LogicalType::Float16)) + .with_length(2) + .with_repetition(Repetition::REQUIRED) + .build() + .unwrap(), + "REQUIRED FIXED_LEN_BYTE_ARRAY (2) float16 (FLOAT16);", + ), ]; types_and_strings.into_iter().for_each(|(field, expected)| { diff --git a/parquet/src/schema/types.rs b/parquet/src/schema/types.rs index 11c73542095..2f36deffbab 100644 --- a/parquet/src/schema/types.rs +++ b/parquet/src/schema/types.rs @@ -356,6 +356,14 @@ impl<'a> PrimitiveTypeBuilder<'a> { (LogicalType::Json, PhysicalType::BYTE_ARRAY) => {} (LogicalType::Bson, PhysicalType::BYTE_ARRAY) => {} (LogicalType::Uuid, PhysicalType::FIXED_LEN_BYTE_ARRAY) => {} + (LogicalType::Float16, PhysicalType::FIXED_LEN_BYTE_ARRAY) + if self.length == 2 => {} + (LogicalType::Float16, PhysicalType::FIXED_LEN_BYTE_ARRAY) => { + return Err(general_err!( + "FLOAT16 cannot annotate field '{}' because it is not a FIXED_LEN_BYTE_ARRAY(2) field", + self.name + )) + } (a, b) => { return Err(general_err!( "Cannot annotate {:?} from {} for field '{}'", @@ -1504,6 +1512,41 @@ mod tests { "Parquet error: Invalid FIXED_LEN_BYTE_ARRAY length: -1 for field 'foo'" ); } + + result = Type::primitive_type_builder("foo", PhysicalType::FIXED_LEN_BYTE_ARRAY) + .with_repetition(Repetition::REQUIRED) + .with_logical_type(Some(LogicalType::Float16)) + .with_length(2) + .build(); + assert!(result.is_ok()); + + // Can't be other than FIXED_LEN_BYTE_ARRAY for physical type + result = Type::primitive_type_builder("foo", PhysicalType::FLOAT) + .with_repetition(Repetition::REQUIRED) + .with_logical_type(Some(LogicalType::Float16)) + .with_length(2) + .build(); + assert!(result.is_err()); + if let Err(e) = result { + assert_eq!( + format!("{e}"), + "Parquet error: Cannot annotate Float16 from FLOAT for field 'foo'" + ); + } + + // Must have length 2 + result = Type::primitive_type_builder("foo", PhysicalType::FIXED_LEN_BYTE_ARRAY) + .with_repetition(Repetition::REQUIRED) + .with_logical_type(Some(LogicalType::Float16)) + .with_length(4) + .build(); + assert!(result.is_err()); + if let Err(e) = result { + assert_eq!( + format!("{e}"), + "Parquet error: FLOAT16 
cannot annotate field 'foo' because it is not a FIXED_LEN_BYTE_ARRAY(2) field" + ); + } } #[test] @@ -1981,6 +2024,7 @@ mod tests { let message_type = " message conversions { REQUIRED INT64 id; + OPTIONAL FIXED_LEN_BYTE_ARRAY (2) f16 (FLOAT16); OPTIONAL group int_array_Array (LIST) { REPEATED group list { OPTIONAL group element (LIST) {
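For reference, a minimal sketch (not part of the patch) of exercising the new Float16 support end to end; the column name and the /tmp output path are illustrative, and the flow mirrors the round-trip test added in parquet/src/arrow/arrow_reader/mod.rs above:

use std::fs::File;
use std::sync::Arc;

use arrow_array::{ArrayRef, Float16Array, RecordBatch};
use half::f16;
use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
use parquet::arrow::ArrowWriter;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Float16 columns are stored as FIXED_LEN_BYTE_ARRAY(2) with the FLOAT16 logical type.
    let values: ArrayRef = Arc::new(Float16Array::from(vec![
        f16::from_f32(1.5),
        f16::NAN,
        f16::NEG_ZERO,
    ]));
    let batch = RecordBatch::try_from_iter([("h", values)])?;

    // Write: the Arrow writer maps ArrowType::Float16 onto that physical/logical type pair.
    let file = File::create("/tmp/float16_example.parquet")?;
    let mut writer = ArrowWriter::try_new(file, batch.schema(), None)?;
    writer.write(&batch)?;
    writer.close()?;

    // Read back: the reader reconstructs a Float16Array from the little-endian 2-byte values.
    let file = File::open("/tmp/float16_example.parquet")?;
    let reader = ParquetRecordBatchReaderBuilder::try_new(file)?.build()?;
    for batch in reader {
        println!("{:?}", batch?.column(0));
    }
    Ok(())
}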