Parquet: read/write f16 for Arrow (#5003)
* Support for read/write f16 Parquet to Arrow

* Update parquet/src/arrow/arrow_writer/mod.rs

Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com>

* Update parquet/src/arrow/arrow_reader/mod.rs

Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com>

* Update test with null version

* Fix schema tests and parsing for f16

* f16 for record api

* Handle NaN for f16 statistics writing

* Revert formatting changes

* Fix num trait

* Fix half feature

* Handle writing signed zero statistics

* Bump parquet-testing and read new f16 files for test

---------

Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com>
Jefffrey and tustvold committed Nov 13, 2023
1 parent 924b6e9 commit 7ba36b0
Showing 18 changed files with 646 additions and 25 deletions.
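
As context for the diffs below, here is a minimal round-trip sketch of what this commit enables, modeled on the `test_float16_roundtrip` test added in `parquet/src/arrow/arrow_reader/mod.rs`. It assumes the `parquet` crate (with its `arrow` feature) plus the `arrow-array`, `arrow-schema`, `bytes`, and `half` crates; the standalone `main` wrapper is illustrative only.

```rust
use std::sync::Arc;

use arrow_array::{Float16Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
use bytes::Bytes;
use half::f16;
use parquet::arrow::arrow_reader::ParquetRecordBatchReader;
use parquet::arrow::ArrowWriter;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Float16 columns are stored as FIXED_LEN_BYTE_ARRAY(2) with the FLOAT16 logical type.
    let schema = Arc::new(Schema::new(vec![Field::new("h", DataType::Float16, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Float16Array::from_iter_values([
            f16::ONE,
            f16::NEG_ZERO,
            f16::PI,
        ]))],
    )?;

    // Write the batch to an in-memory Parquet file...
    let mut buf = Vec::new();
    let mut writer = ArrowWriter::try_new(&mut buf, schema, None)?;
    writer.write(&batch)?;
    writer.close()?;

    // ...and read it back; the values come back as a Float16Array of half::f16.
    let mut reader = ParquetRecordBatchReader::try_new(Bytes::from(buf), 1024)?;
    let read = reader.next().unwrap()?;
    assert_eq!(read, batch);
    Ok(())
}
```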
1 change: 1 addition & 0 deletions parquet/Cargo.toml
@@ -66,6 +66,7 @@ tokio = { version = "1.0", optional = true, default-features = false, features =
hashbrown = { version = "0.14", default-features = false }
twox-hash = { version = "1.6", default-features = false }
paste = { version = "1.0" }
half = { version = "2.1", default-features = false, features = ["num-traits"] }

[dev-dependencies]
base64 = { version = "0.21", default-features = false, features = ["std"] }
2 changes: 1 addition & 1 deletion parquet/regen.sh
@@ -17,7 +17,7 @@
# specific language governing permissions and limitations
# under the License.

REVISION=aeae80660c1d0c97314e9da837de1abdebd49c37
REVISION=46cc3a0647d301bb9579ca8dd2cc356caf2a72d2

SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]:-$0}")" && pwd)"

17 changes: 16 additions & 1 deletion parquet/src/arrow/array_reader/fixed_len_byte_array.rs
@@ -27,13 +27,14 @@ use crate::column::reader::decoder::{ColumnValueDecoder, ValuesBufferSlice};
use crate::errors::{ParquetError, Result};
use crate::schema::types::ColumnDescPtr;
use arrow_array::{
ArrayRef, Decimal128Array, Decimal256Array, FixedSizeBinaryArray,
ArrayRef, Decimal128Array, Decimal256Array, FixedSizeBinaryArray, Float16Array,
IntervalDayTimeArray, IntervalYearMonthArray,
};
use arrow_buffer::{i256, Buffer};
use arrow_data::ArrayDataBuilder;
use arrow_schema::{DataType as ArrowType, IntervalUnit};
use bytes::Bytes;
use half::f16;
use std::any::Any;
use std::ops::Range;
use std::sync::Arc;
@@ -88,6 +89,14 @@ pub fn make_fixed_len_byte_array_reader(
));
}
}
ArrowType::Float16 => {
if byte_length != 2 {
return Err(general_err!(
"float 16 type must be 2 bytes, got {}",
byte_length
));
}
}
_ => {
return Err(general_err!(
"invalid data type for fixed length byte array reader - {}",
@@ -208,6 +217,12 @@ impl ArrayReader for FixedLenByteArrayReader {
}
}
}
ArrowType::Float16 => Arc::new(
binary
.iter()
.map(|o| o.map(|b| f16::from_le_bytes(b[..2].try_into().unwrap())))
.collect::<Float16Array>(),
) as ArrayRef,
_ => Arc::new(binary) as ArrayRef,
};

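Both sides of this change are plain bit-pattern conversions: each Float16 value occupies a 2-byte little-endian FIXED_LEN_BYTE_ARRAY, which is why NaN payloads and the sign of zero survive a round trip. A small sketch of that mapping using only the `half` crate (the helper names are made up for illustration):

```rust
use half::f16;

// Illustrative helpers mirroring the conversions in this diff: the writer
// emits the 2-byte little-endian bit pattern, and the reader reconstructs
// the value with f16::from_le_bytes; no numeric conversion takes place.
fn to_flba(v: f16) -> [u8; 2] {
    v.to_le_bytes()
}

fn from_flba(bytes: &[u8]) -> f16 {
    f16::from_le_bytes(bytes[..2].try_into().unwrap())
}

fn main() {
    for v in [f16::PI, f16::NEG_ZERO, f16::NAN] {
        let round_tripped = from_flba(&to_flba(v));
        // Compare bit patterns so that NaN and -0.0 are checked exactly.
        assert_eq!(round_tripped.to_bits(), v.to_bits());
    }
}
```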
119 changes: 118 additions & 1 deletion parquet/src/arrow/arrow_reader/mod.rs
@@ -712,13 +712,14 @@ mod tests {
use std::sync::Arc;

use bytes::Bytes;
use half::f16;
use num::PrimInt;
use rand::{thread_rng, Rng, RngCore};
use tempfile::tempfile;

use arrow_array::builder::*;
use arrow_array::cast::AsArray;
use arrow_array::types::{Decimal128Type, Decimal256Type, DecimalType};
use arrow_array::types::{Decimal128Type, Decimal256Type, DecimalType, Float16Type};
use arrow_array::*;
use arrow_array::{RecordBatch, RecordBatchReader};
use arrow_buffer::{i256, ArrowNativeType, Buffer};
@@ -924,6 +925,66 @@ mod tests {
.unwrap();
}

#[test]
fn test_float16_roundtrip() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Field::new("float16", ArrowDataType::Float16, false),
Field::new("float16-nullable", ArrowDataType::Float16, true),
]));

let mut buf = Vec::with_capacity(1024);
let mut writer = ArrowWriter::try_new(&mut buf, schema.clone(), None)?;

let original = RecordBatch::try_new(
schema,
vec![
Arc::new(Float16Array::from_iter_values([
f16::EPSILON,
f16::MIN,
f16::MAX,
f16::NAN,
f16::INFINITY,
f16::NEG_INFINITY,
f16::ONE,
f16::NEG_ONE,
f16::ZERO,
f16::NEG_ZERO,
f16::E,
f16::PI,
f16::FRAC_1_PI,
])),
Arc::new(Float16Array::from(vec![
None,
None,
None,
Some(f16::NAN),
Some(f16::INFINITY),
Some(f16::NEG_INFINITY),
None,
None,
None,
None,
None,
None,
Some(f16::FRAC_1_PI),
])),
],
)?;

writer.write(&original)?;
writer.close()?;

let mut reader = ParquetRecordBatchReader::try_new(Bytes::from(buf), 1024)?;
let ret = reader.next().unwrap()?;
assert_eq!(ret, original);

// Ensure can be downcast to the correct type
ret.column(0).as_primitive::<Float16Type>();
ret.column(1).as_primitive::<Float16Type>();

Ok(())
}

struct RandFixedLenGen {}

impl RandGen<FixedLenByteArrayType> for RandFixedLenGen {
@@ -1255,6 +1316,62 @@
}
}

#[test]
fn test_read_float16_nonzeros_file() {
use arrow_array::Float16Array;
let testdata = arrow::util::test_util::parquet_test_data();
// see https://github.com/apache/parquet-testing/pull/40
let path = format!("{testdata}/float16_nonzeros_and_nans.parquet");
let file = File::open(path).unwrap();
let mut record_reader = ParquetRecordBatchReader::try_new(file, 32).unwrap();

let batch = record_reader.next().unwrap().unwrap();
assert_eq!(batch.num_rows(), 8);
let col = batch
.column(0)
.as_any()
.downcast_ref::<Float16Array>()
.unwrap();

let f16_two = f16::ONE + f16::ONE;

assert_eq!(col.null_count(), 1);
assert!(col.is_null(0));
assert_eq!(col.value(1), f16::ONE);
assert_eq!(col.value(2), -f16_two);
assert!(col.value(3).is_nan());
assert_eq!(col.value(4), f16::ZERO);
assert!(col.value(4).is_sign_positive());
assert_eq!(col.value(5), f16::NEG_ONE);
assert_eq!(col.value(6), f16::NEG_ZERO);
assert!(col.value(6).is_sign_negative());
assert_eq!(col.value(7), f16_two);
}

#[test]
fn test_read_float16_zeros_file() {
use arrow_array::Float16Array;
let testdata = arrow::util::test_util::parquet_test_data();
// see https://github.com/apache/parquet-testing/pull/40
let path = format!("{testdata}/float16_zeros_and_nans.parquet");
let file = File::open(path).unwrap();
let mut record_reader = ParquetRecordBatchReader::try_new(file, 32).unwrap();

let batch = record_reader.next().unwrap().unwrap();
assert_eq!(batch.num_rows(), 3);
let col = batch
.column(0)
.as_any()
.downcast_ref::<Float16Array>()
.unwrap();

assert_eq!(col.null_count(), 1);
assert!(col.is_null(0));
assert_eq!(col.value(1), f16::ZERO);
assert!(col.value(1).is_sign_positive());
assert!(col.value(2).is_nan());
}

/// Parameters for single_column_reader_test
#[derive(Clone)]
struct TestOptions {
16 changes: 16 additions & 0 deletions parquet/src/arrow/arrow_writer/mod.rs
@@ -771,6 +771,10 @@ fn write_leaf(writer: &mut ColumnWriter<'_>, levels: &ArrayLevels) -> Result<usi
.unwrap();
get_decimal_256_array_slice(array, indices)
}
ArrowDataType::Float16 => {
let array = column.as_primitive::<Float16Type>();
get_float_16_array_slice(array, indices)
}
_ => {
return Err(ParquetError::NYI(
"Attempting to write an Arrow type that is not yet implemented".to_string(),
@@ -867,6 +871,18 @@ fn get_decimal_256_array_slice(
values
}

fn get_float_16_array_slice(
array: &arrow_array::Float16Array,
indices: &[usize],
) -> Vec<FixedLenByteArray> {
let mut values = Vec::with_capacity(indices.len());
for i in indices {
let value = array.value(*i).to_le_bytes().to_vec();
values.push(FixedLenByteArray::from(ByteArray::from(value)));
}
values
}

fn get_fsb_array_slice(
array: &arrow_array::FixedSizeBinaryArray,
indices: &[usize],
17 changes: 15 additions & 2 deletions parquet/src/arrow/schema/mod.rs
@@ -373,7 +373,12 @@ fn arrow_to_parquet_type(field: &Field) -> Result<Type> {
.with_repetition(repetition)
.with_id(id)
.build(),
DataType::Float16 => Err(arrow_err!("Float16 arrays not supported")),
DataType::Float16 => Type::primitive_type_builder(name, PhysicalType::FIXED_LEN_BYTE_ARRAY)
.with_repetition(repetition)
.with_id(id)
.with_logical_type(Some(LogicalType::Float16))
.with_length(2)
.build(),
DataType::Float32 => Type::primitive_type_builder(name, PhysicalType::FLOAT)
.with_repetition(repetition)
.with_id(id)
@@ -604,9 +609,10 @@ mod tests {
REQUIRED INT32 uint8 (INTEGER(8,false));
REQUIRED INT32 uint16 (INTEGER(16,false));
REQUIRED INT32 int32;
REQUIRED INT64 int64 ;
REQUIRED INT64 int64;
OPTIONAL DOUBLE double;
OPTIONAL FLOAT float;
OPTIONAL FIXED_LEN_BYTE_ARRAY (2) float16 (FLOAT16);
OPTIONAL BINARY string (UTF8);
OPTIONAL BINARY string_2 (STRING);
OPTIONAL BINARY json (JSON);
@@ -628,6 +634,7 @@
Field::new("int64", DataType::Int64, false),
Field::new("double", DataType::Float64, true),
Field::new("float", DataType::Float32, true),
Field::new("float16", DataType::Float16, true),
Field::new("string", DataType::Utf8, true),
Field::new("string_2", DataType::Utf8, true),
Field::new("json", DataType::Utf8, true),
@@ -1303,6 +1310,7 @@
REQUIRED INT64 int64;
OPTIONAL DOUBLE double;
OPTIONAL FLOAT float;
OPTIONAL FIXED_LEN_BYTE_ARRAY (2) float16 (FLOAT16);
OPTIONAL BINARY string (UTF8);
REPEATED BOOLEAN bools;
OPTIONAL INT32 date (DATE);
@@ -1339,6 +1347,7 @@
Field::new("int64", DataType::Int64, false),
Field::new("double", DataType::Float64, true),
Field::new("float", DataType::Float32, true),
Field::new("float16", DataType::Float16, true),
Field::new("string", DataType::Utf8, true),
Field::new_list(
"bools",
@@ -1398,6 +1407,7 @@
REQUIRED INT64 int64;
OPTIONAL DOUBLE double;
OPTIONAL FLOAT float;
OPTIONAL FIXED_LEN_BYTE_ARRAY (2) float16 (FLOAT16);
OPTIONAL BINARY string (STRING);
OPTIONAL GROUP bools (LIST) {
REPEATED GROUP list {
@@ -1448,6 +1458,7 @@
Field::new("int64", DataType::Int64, false),
Field::new("double", DataType::Float64, true),
Field::new("float", DataType::Float32, true),
Field::new("float16", DataType::Float16, true),
Field::new("string", DataType::Utf8, true),
Field::new_list(
"bools",
@@ -1661,6 +1672,8 @@
vec![
Field::new("a", DataType::Int16, true),
Field::new("b", DataType::Float64, false),
Field::new("c", DataType::Float32, false),
Field::new("d", DataType::Float16, false),
]
.into(),
),
10 changes: 10 additions & 0 deletions parquet/src/arrow/schema/primitive.rs
@@ -304,6 +304,16 @@ fn from_fixed_len_byte_array(
// would be incorrect if all 12 bytes of the interval are populated
Ok(DataType::Interval(IntervalUnit::DayTime))
}
(Some(LogicalType::Float16), _) => {
if type_length == 2 {
Ok(DataType::Float16)
} else {
Err(ParquetError::General(
"FLOAT16 logical type must be Fixed Length Byte Array with length 2"
.to_string(),
))
}
}
_ => Ok(DataType::FixedSizeBinary(type_length)),
}
}
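
The commit message also calls out "Handle NaN for f16 statistics writing" and "Handle writing signed zero statistics"; those statistics changes live in files not shown above. A rough sketch of the rule those bullets point at, assuming the usual Parquet convention for floating-point min/max statistics (skip NaNs, widen signed zeros) rather than the commit's exact code:

```rust
use half::f16;

/// Sketch only: compute min/max statistics for a Float16 column the way the
/// Parquet floating-point convention suggests. NaNs are ignored, a zero min
/// is reported as -0.0 and a zero max as +0.0, so the statistics cover rows
/// with either signed-zero encoding.
fn f16_min_max(values: &[f16]) -> Option<(f16, f16)> {
    let mut iter = values.iter().copied().filter(|v| !v.is_nan());
    let first = iter.next()?;
    let (mut min, mut max) = (first, first);
    for v in iter {
        if v < min {
            min = v;
        }
        if v > max {
            max = v;
        }
    }
    // Widen signed zeros: -0.0 == +0.0 numerically, so force the sign here.
    if min == f16::ZERO && min.is_sign_positive() {
        min = f16::NEG_ZERO;
    }
    if max == f16::ZERO && max.is_sign_negative() {
        max = f16::ZERO;
    }
    Some((min, max))
}

fn main() {
    let (min, max) = f16_min_max(&[f16::NAN, f16::ONE, f16::ZERO]).unwrap();
    assert!(min == f16::ZERO && min.is_sign_negative()); // reported as -0.0
    assert_eq!(max, f16::ONE);
}
```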