From 51108689d18cb61e23442b636d67326f15ef6a40 Mon Sep 17 00:00:00 2001 From: Weijun Huang Date: Tue, 21 Mar 2023 21:25:37 +0100 Subject: [PATCH] feat: add sort_dictionary --- arrow-cast/Cargo.toml | 1 + arrow-cast/src/cast.rs | 339 ++++++++++++++++++++++++++++++++++------- 2 files changed, 288 insertions(+), 52 deletions(-) diff --git a/arrow-cast/Cargo.toml b/arrow-cast/Cargo.toml index 53c62ffb60d3..c8468419bcc4 100644 --- a/arrow-cast/Cargo.toml +++ b/arrow-cast/Cargo.toml @@ -49,6 +49,7 @@ arrow-buffer = { version = "35.0.0", path = "../arrow-buffer" } arrow-data = { version = "35.0.0", path = "../arrow-data" } arrow-schema = { version = "35.0.0", path = "../arrow-schema" } arrow-select = { version = "35.0.0", path = "../arrow-select" } +arrow-ord = { version = "35.0.0", path = "../arrow-ord" } chrono = { version = "0.4.23", default-features = false, features = ["clock"] } num = { version = "0.4", default-features = false, features = ["std"] } lexical-core = { version = "^0.8", default-features = false, features = ["write-integers", "write-floats", "parse-integers", "parse-floats"] } diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs index 43048c2aba45..e696489dadcb 100644 --- a/arrow-cast/src/cast.rs +++ b/arrow-cast/src/cast.rs @@ -49,6 +49,7 @@ use arrow_array::{ }; use arrow_buffer::{i256, ArrowNativeType, Buffer, MutableBuffer}; use arrow_data::ArrayData; +use arrow_ord::sort::sort; use arrow_schema::*; use arrow_select::take::take; use num::cast::AsPrimitive; @@ -59,9 +60,17 @@ use num::{NumCast, ToPrimitive}; pub struct CastOptions { /// how to handle cast failures, either return NULL (safe=true) or return ERR (safe=false) pub safe: bool, + // if the result is a dictionary array, the dictionary will be sorted + pub sort_dictionary: bool, + // if the result is a dictionary array, the dictionary will be unique + pub pack_dictionary: bool, } -pub const DEFAULT_CAST_OPTIONS: CastOptions = CastOptions { safe: true }; +pub const DEFAULT_CAST_OPTIONS: CastOptions = CastOptions { + safe: true, + sort_dictionary: false, + pack_dictionary: false, +}; /// Return true if a value of type `from_type` can be cast into a /// value of `to_type`. Note that such as cast may be lossy. @@ -3251,7 +3260,12 @@ where V: ArrowPrimitiveType, { // attempt to cast the source array values to the target value type (the dictionary values type) - let cast_values = cast_with_options(array, dict_value_type, cast_options)?; + let mut cast_values = cast_with_options(array, dict_value_type, cast_options)?; + // sort the values if requested + if cast_options.sort_dictionary { + cast_values = sort(&cast_values, None)?; + } + let values = cast_values .as_any() .downcast_ref::>() @@ -3259,8 +3273,6 @@ where let mut b = PrimitiveDictionaryBuilder::::with_capacity(values.len(), values.len()); - - // copy each element one at a time for i in 0..values.len() { if values.is_null(i) { b.append_null(); @@ -3281,7 +3293,11 @@ where K: ArrowDictionaryKeyType, T: ByteArrayType, { - let cast_values = cast_with_options(array, &T::DATA_TYPE, cast_options)?; + let mut cast_values = cast_with_options(array, &T::DATA_TYPE, cast_options)?; + // sort the values if requested + if cast_options.sort_dictionary { + cast_values = sort(&cast_values, None)?; + } let values = cast_values .as_any() .downcast_ref::>() @@ -3542,7 +3558,11 @@ mod tests { } } - let cast_option = CastOptions { safe: false }; + let cast_option = CastOptions { + safe: false, + sort_dictionary: false, + pack_dictionary: false, + }; let casted_array_with_option = cast_with_options($INPUT_ARRAY, $OUTPUT_TYPE, &cast_option).unwrap(); let result_array = casted_array_with_option @@ -3750,8 +3770,15 @@ mod tests { let array = vec![Some(i128::MAX)]; let array = create_decimal_array(array, 38, 3).unwrap(); - let result = - cast_with_options(&array, &output_type, &CastOptions { safe: false }); + let result = cast_with_options( + &array, + &output_type, + &CastOptions { + safe: false, + sort_dictionary: false, + pack_dictionary: false, + }, + ); assert_eq!("Cast error: Cannot cast to Decimal128(38, 38). Overflowing on 170141183460469231731687303715884105727", result.unwrap_err().to_string()); } @@ -3764,8 +3791,15 @@ mod tests { let array = vec![Some(i128::MAX)]; let array = create_decimal_array(array, 38, 3).unwrap(); - let result = - cast_with_options(&array, &output_type, &CastOptions { safe: false }); + let result = cast_with_options( + &array, + &output_type, + &CastOptions { + safe: false, + sort_dictionary: false, + pack_dictionary: false, + }, + ); assert_eq!("Cast error: Cannot cast to Decimal256(76, 76). Overflowing on 170141183460469231731687303715884105727", result.unwrap_err().to_string()); } @@ -3797,8 +3831,15 @@ mod tests { assert!(can_cast_types(&input_type, &output_type)); let array = vec![Some(i256::from_i128(i128::MAX))]; let array = create_decimal256_array(array, 76, 5).unwrap(); - let result = - cast_with_options(&array, &output_type, &CastOptions { safe: false }); + let result = cast_with_options( + &array, + &output_type, + &CastOptions { + safe: false, + sort_dictionary: false, + pack_dictionary: false, + }, + ); assert_eq!("Cast error: Cannot cast to Decimal128(38, 7). Overflowing on 170141183460469231731687303715884105727", result.unwrap_err().to_string()); } @@ -3810,8 +3851,15 @@ mod tests { assert!(can_cast_types(&input_type, &output_type)); let array = vec![Some(i256::from_i128(i128::MAX))]; let array = create_decimal256_array(array, 76, 5).unwrap(); - let result = - cast_with_options(&array, &output_type, &CastOptions { safe: false }); + let result = cast_with_options( + &array, + &output_type, + &CastOptions { + safe: false, + sort_dictionary: false, + pack_dictionary: false, + }, + ); assert_eq!("Cast error: Cannot cast to Decimal256(76, 55). Overflowing on 170141183460469231731687303715884105727", result.unwrap_err().to_string()); } @@ -3957,30 +4005,58 @@ mod tests { // overflow test: out of range of max u8 let value_array: Vec> = vec![Some(51300)]; let array = create_decimal_array(value_array, 38, 2).unwrap(); - let casted_array = - cast_with_options(&array, &DataType::UInt8, &CastOptions { safe: false }); + let casted_array = cast_with_options( + &array, + &DataType::UInt8, + &CastOptions { + safe: false, + sort_dictionary: false, + pack_dictionary: false, + }, + ); assert_eq!( "Cast error: value of 513 is out of range UInt8".to_string(), casted_array.unwrap_err().to_string() ); - let casted_array = - cast_with_options(&array, &DataType::UInt8, &CastOptions { safe: true }); + let casted_array = cast_with_options( + &array, + &DataType::UInt8, + &CastOptions { + safe: true, + sort_dictionary: false, + pack_dictionary: false, + }, + ); assert!(casted_array.is_ok()); assert!(casted_array.unwrap().is_null(0)); // overflow test: out of range of max i8 let value_array: Vec> = vec![Some(24400)]; let array = create_decimal_array(value_array, 38, 2).unwrap(); - let casted_array = - cast_with_options(&array, &DataType::Int8, &CastOptions { safe: false }); + let casted_array = cast_with_options( + &array, + &DataType::Int8, + &CastOptions { + safe: false, + sort_dictionary: false, + pack_dictionary: false, + }, + ); assert_eq!( "Cast error: value of 244 is out of range Int8".to_string(), casted_array.unwrap_err().to_string() ); - let casted_array = - cast_with_options(&array, &DataType::Int8, &CastOptions { safe: true }); + let casted_array = cast_with_options( + &array, + &DataType::Int8, + &CastOptions { + safe: true, + sort_dictionary: false, + pack_dictionary: false, + }, + ); assert!(casted_array.is_ok()); assert!(casted_array.unwrap().is_null(0)); @@ -4136,15 +4212,29 @@ mod tests { // overflow test: out of range of max i8 let value_array: Vec> = vec![Some(i256::from_i128(24400))]; let array = create_decimal256_array(value_array, 38, 2).unwrap(); - let casted_array = - cast_with_options(&array, &DataType::Int8, &CastOptions { safe: false }); + let casted_array = cast_with_options( + &array, + &DataType::Int8, + &CastOptions { + safe: false, + sort_dictionary: false, + pack_dictionary: false, + }, + ); assert_eq!( "Cast error: value of 244 is out of range Int8".to_string(), casted_array.unwrap_err().to_string() ); - let casted_array = - cast_with_options(&array, &DataType::Int8, &CastOptions { safe: true }); + let casted_array = cast_with_options( + &array, + &DataType::Int8, + &CastOptions { + safe: true, + sort_dictionary: false, + pack_dictionary: false, + }, + ); assert!(casted_array.is_ok()); assert!(casted_array.unwrap().is_null(0)); @@ -4560,7 +4650,11 @@ mod tests { fn test_cast_int32_to_u8_with_error() { let array = Int32Array::from(vec![-5, 6, -7, 8, 100000000]); // overflow with the error - let cast_option = CastOptions { safe: false }; + let cast_option = CastOptions { + safe: false, + sort_dictionary: false, + pack_dictionary: false, + }; let result = cast_with_options(&array, &DataType::UInt8, &cast_option); assert!(result.is_err()); result.unwrap(); @@ -4686,8 +4780,15 @@ mod tests { #[test] fn test_cast_with_options_utf8_to_i32() { let array = StringArray::from(vec!["5", "6", "seven", "8", "9.1"]); - let result = - cast_with_options(&array, &DataType::Int32, &CastOptions { safe: false }); + let result = cast_with_options( + &array, + &DataType::Int32, + &CastOptions { + safe: false, + sort_dictionary: false, + pack_dictionary: false, + }, + ); match result { Ok(_) => panic!("expected error"), Err(e) => { @@ -4713,8 +4814,15 @@ mod tests { #[test] fn test_cast_with_options_utf8_to_bool() { let strings = StringArray::from(vec!["true", "false", "invalid", " Y ", ""]); - let casted = - cast_with_options(&strings, &DataType::Boolean, &CastOptions { safe: false }); + let casted = cast_with_options( + &strings, + &DataType::Boolean, + &CastOptions { + safe: false, + sort_dictionary: false, + pack_dictionary: false, + }, + ); match casted { Ok(_) => panic!("expected error"), Err(e) => { @@ -4928,7 +5036,11 @@ mod tests { } } - let options = CastOptions { safe: false }; + let options = CastOptions { + safe: false, + sort_dictionary: false, + pack_dictionary: false, + }; let err = cast_with_options(array, &to_type, &options).unwrap_err(); assert_eq!( err.to_string(), @@ -4958,7 +5070,11 @@ mod tests { assert!(c.is_null(1)); assert!(c.is_null(2)); - let options = CastOptions { safe: false }; + let options = CastOptions { + safe: false, + sort_dictionary: false, + pack_dictionary: false, + }; let err = cast_with_options(array, &to_type, &options).unwrap_err(); assert_eq!(err.to_string(), "Cast error: Cannot cast string 'Not a valid date' to value of Date32 type"); } @@ -4990,7 +5106,11 @@ mod tests { assert!(c.is_null(3)); assert!(c.is_null(4)); - let options = CastOptions { safe: false }; + let options = CastOptions { + safe: false, + sort_dictionary: false, + pack_dictionary: false, + }; let err = cast_with_options(array, &to_type, &options).unwrap_err(); assert_eq!(err.to_string(), "Cast error: Cannot cast string '08:08:61.091323414' to value of Time32(Second) type"); } @@ -5022,7 +5142,11 @@ mod tests { assert!(c.is_null(3)); assert!(c.is_null(4)); - let options = CastOptions { safe: false }; + let options = CastOptions { + safe: false, + sort_dictionary: false, + pack_dictionary: false, + }; let err = cast_with_options(array, &to_type, &options).unwrap_err(); assert_eq!(err.to_string(), "Cast error: Cannot cast string '08:08:61.091323414' to value of Time32(Millisecond) type"); } @@ -5048,7 +5172,11 @@ mod tests { assert!(c.is_null(1)); assert!(c.is_null(2)); - let options = CastOptions { safe: false }; + let options = CastOptions { + safe: false, + sort_dictionary: false, + pack_dictionary: false, + }; let err = cast_with_options(array, &to_type, &options).unwrap_err(); assert_eq!(err.to_string(), "Cast error: Cannot cast string 'Not a valid time' to value of Time64(Microsecond) type"); } @@ -5074,7 +5202,11 @@ mod tests { assert!(c.is_null(1)); assert!(c.is_null(2)); - let options = CastOptions { safe: false }; + let options = CastOptions { + safe: false, + sort_dictionary: false, + pack_dictionary: false, + }; let err = cast_with_options(array, &to_type, &options).unwrap_err(); assert_eq!(err.to_string(), "Cast error: Cannot cast string 'Not a valid time' to value of Time64(Nanosecond) type"); } @@ -5100,7 +5232,11 @@ mod tests { assert!(c.is_null(1)); assert!(c.is_null(2)); - let options = CastOptions { safe: false }; + let options = CastOptions { + safe: false, + sort_dictionary: false, + pack_dictionary: false, + }; let err = cast_with_options(array, &to_type, &options).unwrap_err(); assert_eq!(err.to_string(), "Cast error: Cannot cast string 'Not a valid date' to value of Date64 type"); } @@ -5111,7 +5247,11 @@ mod tests { let source_string_array = Arc::new(StringArray::from($data_vec.clone())) as ArrayRef; - let options = CastOptions { safe: true }; + let options = CastOptions { + safe: true, + sort_dictionary: false, + pack_dictionary: false, + }; let target_interval_array = cast_with_options( &source_string_array.clone(), @@ -5235,7 +5375,11 @@ mod tests { macro_rules! test_unsafe_string_to_interval_err { ($data_vec:expr, $interval_unit:expr, $error_msg:expr) => { let string_array = Arc::new(StringArray::from($data_vec.clone())) as ArrayRef; - let options = CastOptions { safe: false }; + let options = CastOptions { + safe: false, + sort_dictionary: false, + pack_dictionary: false, + }; let arrow_err = cast_with_options( &string_array.clone(), &DataType::Interval($interval_unit), @@ -5384,7 +5528,11 @@ mod tests { assert!(b.is_null(0)); // test overflow, unsafe cast let array = TimestampSecondArray::from(vec![Some(i64::MAX)]); - let options = CastOptions { safe: false }; + let options = CastOptions { + safe: false, + sort_dictionary: false, + pack_dictionary: false, + }; let b = cast_with_options(&array, &DataType::Date64, &options); assert!(b.is_err()); } @@ -6876,6 +7024,49 @@ mod tests { assert_eq!(cast_array.data_type(), &Int64); } + #[test] + fn test_cast_primitvie_sorted_dict() { + use DataType::*; + + let mut builder = PrimitiveBuilder::::new(); + builder.append_value(3); + builder.append_null(); + builder.append_value(1); + let array: ArrayRef = Arc::new(builder.finish()); + + let expected = vec!["null", "1", "3"]; + + // Cast to a dictionary (same value type, Int32) + let cast_type = Dictionary(Box::new(UInt8), Box::new(Int32)); + let cast_array = cast_with_options( + &array, + &cast_type, + &CastOptions { + safe: true, + sort_dictionary: true, + pack_dictionary: false, + }, + ) + .expect("cast failed"); + assert_eq!(array_to_strings(&cast_array), expected); + assert_eq!(cast_array.data_type(), &cast_type); + + // Cast to a dictionary (different value type, Int8) + let cast_type = Dictionary(Box::new(UInt8), Box::new(Int8)); + let cast_array = cast_with_options( + &array, + &cast_type, + &CastOptions { + safe: true, + sort_dictionary: true, + pack_dictionary: false, + }, + ) + .expect("cast failed"); + assert_eq!(array_to_strings(&cast_array), expected); + assert_eq!(cast_array.data_type(), &cast_type); + } + #[test] fn test_cast_primitive_array_to_dict() { use DataType::*; @@ -7335,7 +7526,11 @@ mod tests { let casted_array = cast_with_options( &array, &DataType::Decimal128(38, 30), - &CastOptions { safe: true }, + &CastOptions { + safe: true, + sort_dictionary: false, + pack_dictionary: false, + }, ); assert!(casted_array.is_ok()); assert!(casted_array.unwrap().is_null(0)); @@ -7343,7 +7538,11 @@ mod tests { let casted_array = cast_with_options( &array, &DataType::Decimal128(38, 30), - &CastOptions { safe: false }, + &CastOptions { + safe: false, + sort_dictionary: false, + pack_dictionary: false, + }, ); assert!(casted_array.is_err()); } @@ -7355,7 +7554,11 @@ mod tests { let casted_array = cast_with_options( &array, &DataType::Decimal256(76, 76), - &CastOptions { safe: true }, + &CastOptions { + safe: true, + sort_dictionary: false, + pack_dictionary: false, + }, ); assert!(casted_array.is_ok()); assert!(casted_array.unwrap().is_null(0)); @@ -7363,7 +7566,11 @@ mod tests { let casted_array = cast_with_options( &array, &DataType::Decimal256(76, 76), - &CastOptions { safe: false }, + &CastOptions { + safe: false, + sort_dictionary: false, + pack_dictionary: false, + }, ); assert!(casted_array.is_err()); } @@ -7375,7 +7582,11 @@ mod tests { let casted_array = cast_with_options( &array, &DataType::Decimal128(38, 30), - &CastOptions { safe: true }, + &CastOptions { + safe: true, + sort_dictionary: false, + pack_dictionary: false, + }, ); assert!(casted_array.is_ok()); assert!(casted_array.unwrap().is_null(0)); @@ -7383,7 +7594,11 @@ mod tests { let casted_array = cast_with_options( &array, &DataType::Decimal128(38, 30), - &CastOptions { safe: false }, + &CastOptions { + safe: false, + sort_dictionary: false, + pack_dictionary: false, + }, ); let err = casted_array.unwrap_err().to_string(); let expected_error = "Cast error: Cannot cast to Decimal128(38, 30)"; @@ -7400,7 +7615,11 @@ mod tests { let casted_array = cast_with_options( &array, &DataType::Decimal256(76, 50), - &CastOptions { safe: true }, + &CastOptions { + safe: true, + sort_dictionary: false, + pack_dictionary: false, + }, ); assert!(casted_array.is_ok()); assert!(casted_array.unwrap().is_null(0)); @@ -7408,7 +7627,11 @@ mod tests { let casted_array = cast_with_options( &array, &DataType::Decimal256(76, 50), - &CastOptions { safe: false }, + &CastOptions { + safe: false, + sort_dictionary: false, + pack_dictionary: false, + }, ); let err = casted_array.unwrap_err().to_string(); let expected_error = "Cast error: Cannot cast to Decimal256(76, 50)"; @@ -7734,7 +7957,11 @@ mod tests { let output_type = DataType::Decimal128(38, 2); let str_array = StringArray::from(vec!["4.4.5"]); let array = Arc::new(str_array) as ArrayRef; - let option = CastOptions { safe: false }; + let option = CastOptions { + safe: false, + sort_dictionary: false, + pack_dictionary: false, + }; let casted_err = cast_with_options(&array, &output_type, &option).unwrap_err(); assert!(casted_err .to_string() @@ -7961,7 +8188,11 @@ mod tests { let b = cast_with_options( &array, &DataType::Timestamp(TimeUnit::Nanosecond, Some(tz)), - &CastOptions { safe: false }, + &CastOptions { + safe: false, + sort_dictionary: false, + pack_dictionary: false, + }, ) .unwrap(); @@ -7989,7 +8220,11 @@ mod tests { let v1: &[u8] = b"\xFF invalid"; let v2: &[u8] = b"\x00 Foo"; let s = BinaryArray::from(vec![v1, v2]); - let options = CastOptions { safe: true }; + let options = CastOptions { + safe: true, + sort_dictionary: false, + pack_dictionary: false, + }; let array = cast_with_options(&s, &DataType::Utf8, &options).unwrap(); let a = as_string_array(array.as_ref()); a.data().validate_full().unwrap();