From 2026df418d4977bedb8592f28d08568e166c67b3 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Thu, 14 Mar 2024 09:42:01 -0400 Subject: [PATCH 1/7] Fix to_timestamp benchmark --- datafusion/functions/benches/to_timestamp.rs | 173 ++++++++++--------- 1 file changed, 92 insertions(+), 81 deletions(-) diff --git a/datafusion/functions/benches/to_timestamp.rs b/datafusion/functions/benches/to_timestamp.rs index c83824526442..31d609dee9bc 100644 --- a/datafusion/functions/benches/to_timestamp.rs +++ b/datafusion/functions/benches/to_timestamp.rs @@ -17,97 +17,108 @@ extern crate criterion; +use std::sync::Arc; + +use arrow_array::builder::StringBuilder; +use arrow_array::ArrayRef; use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use datafusion_expr::lit; -use datafusion_functions::expr_fn::to_timestamp; +use datafusion_expr::ColumnarValue; +use datafusion_functions::datetime::to_timestamp; fn criterion_benchmark(c: &mut Criterion) { c.bench_function("to_timestamp_no_formats", |b| { - let inputs = vec![ - lit("1997-01-31T09:26:56.123Z"), - lit("1997-01-31T09:26:56.123-05:00"), - lit("1997-01-31 09:26:56.123-05:00"), - lit("2023-01-01 04:05:06.789 -08"), - lit("1997-01-31T09:26:56.123"), - lit("1997-01-31 09:26:56.123"), - lit("1997-01-31 09:26:56"), - lit("1997-01-31 13:26:56"), - lit("1997-01-31 13:26:56+04:00"), - lit("1997-01-31"), - ]; + let mut inputs = StringBuilder::new(); + inputs.append_value("1997-01-31T09:26:56.123Z"); + inputs.append_value("1997-01-31T09:26:56.123-05:00"); + inputs.append_value("1997-01-31 09:26:56.123-05:00"); + inputs.append_value("2023-01-01 04:05:06.789 -08"); + inputs.append_value("1997-01-31T09:26:56.123"); + inputs.append_value("1997-01-31 09:26:56.123"); + inputs.append_value("1997-01-31 09:26:56"); + inputs.append_value("1997-01-31 13:26:56"); + inputs.append_value("1997-01-31 13:26:56+04:00"); + inputs.append_value("1997-01-31"); + + let string_array = ColumnarValue::Array(Arc::new(inputs.finish()) as 
ArrayRef); + b.iter(|| { - for i in inputs.iter() { - black_box(to_timestamp(vec![i.clone()])); - } - }); + black_box( + to_timestamp() + .invoke(&[string_array.clone()]) + .expect("to_timestamp should work on valid values"), + ) + }) }); c.bench_function("to_timestamp_with_formats", |b| { - let mut inputs = vec![]; - let mut format1 = vec![]; - let mut format2 = vec![]; - let mut format3 = vec![]; - - inputs.push(lit("1997-01-31T09:26:56.123Z")); - format1.push(lit("%+")); - format2.push(lit("%c")); - format3.push(lit("%Y-%m-%dT%H:%M:%S%.f%Z")); - - inputs.push(lit("1997-01-31T09:26:56.123-05:00")); - format1.push(lit("%+")); - format2.push(lit("%c")); - format3.push(lit("%Y-%m-%dT%H:%M:%S%.f%z")); - - inputs.push(lit("1997-01-31 09:26:56.123-05:00")); - format1.push(lit("%+")); - format2.push(lit("%c")); - format3.push(lit("%Y-%m-%d %H:%M:%S%.f%Z")); - - inputs.push(lit("2023-01-01 04:05:06.789 -08")); - format1.push(lit("%+")); - format2.push(lit("%c")); - format3.push(lit("%Y-%m-%d %H:%M:%S%.f %#z")); - - inputs.push(lit("1997-01-31T09:26:56.123")); - format1.push(lit("%+")); - format2.push(lit("%c")); - format3.push(lit("%Y-%m-%dT%H:%M:%S%.f")); - - inputs.push(lit("1997-01-31 09:26:56.123")); - format1.push(lit("%+")); - format2.push(lit("%c")); - format3.push(lit("%Y-%m-%d %H:%M:%S%.f")); - - inputs.push(lit("1997-01-31 09:26:56")); - format1.push(lit("%+")); - format2.push(lit("%c")); - format3.push(lit("%Y-%m-%d %H:%M:%S")); - - inputs.push(lit("1997-01-31 092656")); - format1.push(lit("%+")); - format2.push(lit("%c")); - format3.push(lit("%Y-%m-%d %H%M%S")); - - inputs.push(lit("1997-01-31 092656+04:00")); - format1.push(lit("%+")); - format2.push(lit("%c")); - format3.push(lit("%Y-%m-%d %H%M%S%:z")); - - inputs.push(lit("Sun Jul 8 00:34:60 2001")); - format1.push(lit("%+")); - format2.push(lit("%c")); - format3.push(lit("%Y-%m-%d 00:00:00")); - + let mut inputs = StringBuilder::new(); + let mut format1_builder = StringBuilder::with_capacity(2, 10); + let 
mut format2_builder = StringBuilder::with_capacity(2, 10); + let mut format3_builder = StringBuilder::with_capacity(2, 10); + + inputs.append_value("1997-01-31T09:26:56.123Z"); + format1_builder.append_value("%+"); + format2_builder.append_value("%c"); + format3_builder.append_value("%Y-%m-%dT%H:%M:%S%.f%Z"); + + inputs.append_value("1997-01-31T09:26:56.123-05:00"); + format1_builder.append_value("%+"); + format2_builder.append_value("%c"); + format3_builder.append_value("%Y-%m-%dT%H:%M:%S%.f%z"); + + inputs.append_value("1997-01-31 09:26:56.123-05:00"); + format1_builder.append_value("%+"); + format2_builder.append_value("%c"); + format3_builder.append_value("%Y-%m-%d %H:%M:%S%.f%Z"); + + inputs.append_value("2023-01-01 04:05:06.789 -08"); + format1_builder.append_value("%+"); + format2_builder.append_value("%c"); + format3_builder.append_value("%Y-%m-%d %H:%M:%S%.f %#z"); + + inputs.append_value("1997-01-31T09:26:56.123"); + format1_builder.append_value("%+"); + format2_builder.append_value("%c"); + format3_builder.append_value("%Y-%m-%dT%H:%M:%S%.f"); + + inputs.append_value("1997-01-31 09:26:56.123"); + format1_builder.append_value("%+"); + format2_builder.append_value("%c"); + format3_builder.append_value("%Y-%m-%d %H:%M:%S%.f"); + + inputs.append_value("1997-01-31 09:26:56"); + format1_builder.append_value("%+"); + format2_builder.append_value("%c"); + format3_builder.append_value("%Y-%m-%d %H:%M:%S"); + + inputs.append_value("1997-01-31 092656"); + format1_builder.append_value("%+"); + format2_builder.append_value("%c"); + format3_builder.append_value("%Y-%m-%d %H%M%S"); + + inputs.append_value("1997-01-31 092656+04:00"); + format1_builder.append_value("%+"); + format2_builder.append_value("%c"); + format3_builder.append_value("%Y-%m-%d %H%M%S%:z"); + + inputs.append_value("Sun Jul 8 00:34:60 2001"); + format1_builder.append_value("%+"); + format2_builder.append_value("%c"); + format3_builder.append_value("%Y-%m-%d 00:00:00"); + + let args = [ + 
ColumnarValue::Array(Arc::new(inputs.finish()) as ArrayRef), + ColumnarValue::Array(Arc::new(format1_builder.finish()) as ArrayRef), + ColumnarValue::Array(Arc::new(format2_builder.finish()) as ArrayRef), + ColumnarValue::Array(Arc::new(format3_builder.finish()) as ArrayRef), + ]; b.iter(|| { - inputs.iter().enumerate().for_each(|(idx, i)| { - black_box(to_timestamp(vec![ - i.clone(), - format1.get(idx).unwrap().clone(), - format2.get(idx).unwrap().clone(), - format3.get(idx).unwrap().clone(), - ])); - }) + black_box( + to_timestamp() + .invoke(&args.clone()) + .expect("to_timestamp should work on valid values"), + ) }) }); } From 6a450b4fa2caa523cb42580c912f742dd1a1ed2b Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Mon, 18 Mar 2024 10:43:02 -0400 Subject: [PATCH 2/7] Remove reference to simd and nightly build as simd is no longer an available feature in DataFusion and building with nightly may not be a good recommendation when getting started. --- docs/source/user-guide/example-usage.md | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/docs/source/user-guide/example-usage.md b/docs/source/user-guide/example-usage.md index 1c5c8f49a16a..c5eefbdaf156 100644 --- a/docs/source/user-guide/example-usage.md +++ b/docs/source/user-guide/example-usage.md @@ -240,17 +240,11 @@ async fn main() -> datafusion::error::Result<()> { } ``` -Finally, in order to build with the `simd` optimization `cargo nightly` is required. - -```shell -rustup toolchain install nightly -``` - Based on the instruction set architecture you are building on you will want to configure the `target-cpu` as well, ideally with `native` or at least `avx2`. 
```shell -RUSTFLAGS='-C target-cpu=native' cargo +nightly run --release +RUSTFLAGS='-C target-cpu=native' cargo run --release ``` ## Enable backtraces From a94a4f6c3317e8e952d34d968996fbd603cd0c2e Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Fri, 22 Mar 2024 22:20:35 -0400 Subject: [PATCH 3/7] Fixed missing trim() function. --- datafusion/functions/src/string/mod.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/datafusion/functions/src/string/mod.rs b/datafusion/functions/src/string/mod.rs index 63026092f39a..517869a25682 100644 --- a/datafusion/functions/src/string/mod.rs +++ b/datafusion/functions/src/string/mod.rs @@ -72,6 +72,11 @@ pub mod expr_fn { super::to_hex().call(vec![arg1]) } + #[doc = "Removes all characters, spaces by default, from both sides of a string"] + pub fn trim(args: Vec<Expr>) -> Expr { + super::btrim().call(args) + } + #[doc = "Converts a string to uppercase."] pub fn upper(arg1: Expr) -> Expr { super::upper().call(vec![arg1]) From e3860fa52a6118720d42b74305bc92b2ace58f43 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Wed, 27 Mar 2024 12:08:02 -0400 Subject: [PATCH 4/7] Create unicode module in datafusion/functions/src/unicode and unicode_expressions feature flag, move char_length function --- datafusion-cli/Cargo.lock | 1 + datafusion/core/Cargo.toml | 1 + .../tests/dataframe/dataframe_functions.rs | 1 + datafusion/expr/src/built_in_function.rs | 14 +- datafusion/expr/src/expr_fn.rs | 8 - datafusion/functions/Cargo.toml | 4 + datafusion/functions/src/lib.rs | 9 + datafusion/functions/src/string/ascii.rs | 2 +- datafusion/functions/src/string/bit_length.rs | 4 +- datafusion/functions/src/string/btrim.rs | 1 + datafusion/functions/src/string/chr.rs | 2 +- datafusion/functions/src/string/common.rs | 158 +--------------- .../functions/src/string/levenshtein.rs | 3 +- datafusion/functions/src/string/lower.rs | 8 +- datafusion/functions/src/string/ltrim.rs | 3 +- .../functions/src/string/octet_length.rs | 13 +- 
datafusion/functions/src/string/overlay.rs | 2 +- datafusion/functions/src/string/repeat.rs | 4 +- datafusion/functions/src/string/replace.rs | 2 +- datafusion/functions/src/string/rtrim.rs | 1 + datafusion/functions/src/string/split_part.rs | 4 +- .../functions/src/string/starts_with.rs | 9 +- datafusion/functions/src/string/to_hex.rs | 9 +- datafusion/functions/src/string/upper.rs | 3 +- .../functions/src/unicode/character_length.rs | 176 +++++++++++++++++ datafusion/functions/src/unicode/mod.rs | 55 ++++++ datafusion/functions/src/utils.rs | 178 ++++++++++++++++++ datafusion/physical-expr/src/functions.rs | 70 ------- .../physical-expr/src/unicode_expressions.rs | 23 --- datafusion/proto/proto/datafusion.proto | 2 +- datafusion/proto/src/generated/pbjson.rs | 3 - datafusion/proto/src/generated/prost.rs | 4 +- .../proto/src/logical_plan/from_proto.rs | 8 +- datafusion/proto/src/logical_plan/to_proto.rs | 1 - datafusion/sql/Cargo.toml | 1 + datafusion/sql/tests/sql_integration.rs | 15 +- 36 files changed, 484 insertions(+), 318 deletions(-) create mode 100644 datafusion/functions/src/unicode/character_length.rs create mode 100644 datafusion/functions/src/unicode/mod.rs create mode 100644 datafusion/functions/src/utils.rs diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 2f1d95d639d4..424dda7fdc61 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1273,6 +1273,7 @@ dependencies = [ "md-5", "regex", "sha2", + "unicode-segmentation", "uuid", ] diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 1e5c0d748e3d..de03579975a2 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -70,6 +70,7 @@ unicode_expressions = [ "datafusion-physical-expr/unicode_expressions", "datafusion-optimizer/unicode_expressions", "datafusion-sql/unicode_expressions", + "datafusion-functions/unicode_expressions", ] [dependencies] diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs 
b/datafusion/core/tests/dataframe/dataframe_functions.rs index 6ebd64c9b628..4371cce856ce 100644 --- a/datafusion/core/tests/dataframe/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -37,6 +37,7 @@ use datafusion::assert_batches_eq; use datafusion_common::DFSchema; use datafusion_expr::expr::Alias; use datafusion_expr::{approx_median, cast, ExprSchemable}; +use datafusion_functions::unicode::expr_fn::character_length; fn test_schema() -> SchemaRef { Arc::new(Schema::new(vec![ diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index bb0f79f8eca4..eefbc131a27b 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -103,8 +103,6 @@ pub enum BuiltinScalarFunction { Cot, // string functions - /// character_length - CharacterLength, /// concat Concat, /// concat_ws @@ -218,7 +216,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Cbrt => Volatility::Immutable, BuiltinScalarFunction::Cot => Volatility::Immutable, BuiltinScalarFunction::Trunc => Volatility::Immutable, - BuiltinScalarFunction::CharacterLength => Volatility::Immutable, BuiltinScalarFunction::Concat => Volatility::Immutable, BuiltinScalarFunction::ConcatWithSeparator => Volatility::Immutable, BuiltinScalarFunction::EndsWith => Volatility::Immutable, @@ -257,9 +254,6 @@ impl BuiltinScalarFunction { // the return type of the built in function. // Some built-in functions' return type depends on the incoming type. 
match self { - BuiltinScalarFunction::CharacterLength => { - utf8_to_int_type(&input_expr_types[0], "character_length") - } BuiltinScalarFunction::Coalesce => { // COALESCE has multiple args and they might get coerced, get a preview of this let coerced_types = data_types(input_expr_types, &self.signature()); @@ -367,9 +361,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Coalesce => { Signature::variadic_equal(self.volatility()) } - BuiltinScalarFunction::CharacterLength - | BuiltinScalarFunction::InitCap - | BuiltinScalarFunction::Reverse => { + BuiltinScalarFunction::InitCap | BuiltinScalarFunction::Reverse => { Signature::uniform(1, vec![Utf8, LargeUtf8], self.volatility()) } BuiltinScalarFunction::Lpad | BuiltinScalarFunction::Rpad => { @@ -584,10 +576,6 @@ impl BuiltinScalarFunction { // conditional functions BuiltinScalarFunction::Coalesce => &["coalesce"], - // string functions - BuiltinScalarFunction::CharacterLength => { - &["character_length", "char_length", "length"] - } BuiltinScalarFunction::Concat => &["concat"], BuiltinScalarFunction::ConcatWithSeparator => &["concat_ws"], BuiltinScalarFunction::EndsWith => &["ends_with"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 0ea946288e0f..654464798625 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -577,13 +577,6 @@ scalar_expr!(Power, power, base exponent, "`base` raised to the power of `expone scalar_expr!(Atan2, atan2, y x, "inverse tangent of a division given in the argument"); scalar_expr!(Log, log, base x, "logarithm of a `x` for a particular `base`"); -// string functions -scalar_expr!( - CharacterLength, - character_length, - string, - "the number of characters in the `string`" -); scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase"); scalar_expr!(Left, left, string n, "returns the first `n` characters in the `string`"); 
scalar_expr!(Reverse, reverse, string, "reverses the `string`"); @@ -1032,7 +1025,6 @@ mod test { test_scalar_expr!(Nanvl, nanvl, x, y); test_scalar_expr!(Iszero, iszero, input); - test_scalar_expr!(CharacterLength, character_length, string); test_scalar_expr!(Gcd, gcd, arg_1, arg_2); test_scalar_expr!(Lcm, lcm, arg_1, arg_2); test_scalar_expr!(InitCap, initcap, string); diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 81050dfddf66..0cab0276ff4b 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -43,6 +43,7 @@ default = [ "regex_expressions", "crypto_expressions", "string_expressions", + "unicode_expressions", ] # enable encode/decode functions encoding_expressions = ["base64", "hex"] @@ -52,6 +53,8 @@ math_expressions = [] regex_expressions = ["regex"] # enable string functions string_expressions = [] +# enable unicode functions +unicode_expressions = ["unicode-segmentation"] [lib] name = "datafusion_functions" @@ -75,6 +78,7 @@ log = { workspace = true } md-5 = { version = "^0.10.0", optional = true } regex = { version = "1.8", optional = true } sha2 = { version = "^0.10.1", optional = true } +unicode-segmentation = { version = "^1.7.1", optional = true } uuid = { version = "1.7", features = ["v4"] } [dev-dependencies] diff --git a/datafusion/functions/src/lib.rs b/datafusion/functions/src/lib.rs index f469b343e144..2a00839dc532 100644 --- a/datafusion/functions/src/lib.rs +++ b/datafusion/functions/src/lib.rs @@ -124,6 +124,12 @@ make_stub_package!(regex, "regex_expressions"); pub mod crypto; make_stub_package!(crypto, "crypto_expressions"); +#[cfg(feature = "unicode_expressions")] +pub mod unicode; +make_stub_package!(unicode, "unicode_expressions"); + +mod utils; + /// Fluent-style API for creating `Expr`s pub mod expr_fn { #[cfg(feature = "core_expressions")] @@ -140,6 +146,8 @@ pub mod expr_fn { pub use super::regex::expr_fn::*; #[cfg(feature = "string_expressions")] pub use 
super::string::expr_fn::*; + #[cfg(feature = "unicode_expressions")] + pub use super::unicode::expr_fn::*; } /// Registers all enabled packages with a [`FunctionRegistry`] @@ -151,6 +159,7 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { .chain(math::functions()) .chain(regex::functions()) .chain(crypto::functions()) + .chain(unicode::functions()) .chain(string::functions()); all_functions.try_for_each(|udf| { diff --git a/datafusion/functions/src/string/ascii.rs b/datafusion/functions/src/string/ascii.rs index 5bd77833a935..9a07f4c19cf1 100644 --- a/datafusion/functions/src/string/ascii.rs +++ b/datafusion/functions/src/string/ascii.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::string::common::make_scalar_function; +use crate::utils::make_scalar_function; use arrow::array::Int32Array; use arrow::array::{ArrayRef, OffsetSizeTrait}; use arrow::datatypes::DataType; diff --git a/datafusion/functions/src/string/bit_length.rs b/datafusion/functions/src/string/bit_length.rs index 9f612751584e..6a200471d42d 100644 --- a/datafusion/functions/src/string/bit_length.rs +++ b/datafusion/functions/src/string/bit_length.rs @@ -15,16 +15,16 @@ // specific language governing permissions and limitations // under the License. 
-use arrow::compute::kernels::length::bit_length; use std::any::Any; +use arrow::compute::kernels::length::bit_length; use arrow::datatypes::DataType; use datafusion_common::{exec_err, Result, ScalarValue}; use datafusion_expr::{ColumnarValue, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; -use crate::string::common::*; +use crate::utils::utf8_to_int_type; #[derive(Debug)] pub(super) struct BitLengthFunc { diff --git a/datafusion/functions/src/string/btrim.rs b/datafusion/functions/src/string/btrim.rs index de1c9cc69b72..573a23d07021 100644 --- a/datafusion/functions/src/string/btrim.rs +++ b/datafusion/functions/src/string/btrim.rs @@ -26,6 +26,7 @@ use datafusion_expr::{ColumnarValue, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; use crate::string::common::*; +use crate::utils::{make_scalar_function, utf8_to_str_type}; /// Returns the longest string with leading and trailing characters removed. If the characters are not specified, whitespace is removed. /// btrim('xyxtrimyyx', 'xyz') = 'trim' diff --git a/datafusion/functions/src/string/chr.rs b/datafusion/functions/src/string/chr.rs index df3b803ba659..d1f8dc398a2b 100644 --- a/datafusion/functions/src/string/chr.rs +++ b/datafusion/functions/src/string/chr.rs @@ -29,7 +29,7 @@ use datafusion_common::{exec_err, Result}; use datafusion_expr::{ColumnarValue, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; -use crate::string::common::*; +use crate::utils::make_scalar_function; /// Returns the character with the given code. chr(0) is disallowed because text data types cannot store that character. 
/// chr(65) = 'A' diff --git a/datafusion/functions/src/string/common.rs b/datafusion/functions/src/string/common.rs index 339f4e6c1a23..276aad121df2 100644 --- a/datafusion/functions/src/string/common.rs +++ b/datafusion/functions/src/string/common.rs @@ -24,8 +24,7 @@ use arrow::datatypes::DataType; use datafusion_common::cast::as_generic_string_array; use datafusion_common::Result; use datafusion_common::{exec_err, ScalarValue}; -use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation}; -use datafusion_physical_expr::functions::Hint; +use datafusion_expr::ColumnarValue; pub(crate) enum TrimType { Left, @@ -98,52 +97,6 @@ pub(crate) fn general_trim( } } -/// Creates a function to identify the optimal return type of a string function given -/// the type of its first argument. -/// -/// If the input type is `LargeUtf8` or `LargeBinary` the return type is -/// `$largeUtf8Type`, -/// -/// If the input type is `Utf8` or `Binary` the return type is `$utf8Type`, -macro_rules! get_optimal_return_type { - ($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => { - pub(crate) fn $FUNC(arg_type: &DataType, name: &str) -> Result<DataType> { - Ok(match arg_type { - // LargeBinary inputs are automatically coerced to Utf8 - DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type, - // Binary inputs are automatically coerced to Utf8 - DataType::Utf8 | DataType::Binary => $utf8Type, - DataType::Null => DataType::Null, - DataType::Dictionary(_, value_type) => match **value_type { - DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type, - DataType::Utf8 | DataType::Binary => $utf8Type, - DataType::Null => DataType::Null, - _ => { - return datafusion_common::exec_err!( - "The {} function can only accept strings, but got {:?}.", - name.to_uppercase(), - **value_type - ); - } - }, - data_type => { - return datafusion_common::exec_err!( - "The {} function can only accept strings, but got {:?}.", - name.to_uppercase(), - data_type - ); - } - }) - } - }; -} - -// 
`utf8_to_str_type`: returns either a Utf8 or LargeUtf8 based on the input type size. -get_optimal_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8); - -// `utf8_to_int_type`: returns either a Int32 or Int64 based on the input type size. -get_optimal_return_type!(utf8_to_int_type, DataType::Int64, DataType::Int32); - /// applies a unary expression to `args[0]` that is expected to be downcastable to /// a `GenericStringArray` and returns a `GenericStringArray` (which may have a different offset) /// # Errors @@ -221,112 +174,3 @@ where }, } } - -pub(super) fn make_scalar_function<F>( - inner: F, - hints: Vec<Hint>, -) -> ScalarFunctionImplementation -where - F: Fn(&[ArrayRef]) -> Result<ArrayRef> + Sync + Send + 'static, -{ - Arc::new(move |args: &[ColumnarValue]| { - // first, identify if any of the arguments is an Array. If yes, store its `len`, - // as any scalar will need to be converted to an array of len `len`. - let len = args - .iter() - .fold(Option::<usize>::None, |acc, arg| match arg { - ColumnarValue::Scalar(_) => acc, - ColumnarValue::Array(a) => Some(a.len()), - }); - - let is_scalar = len.is_none(); - - let inferred_length = len.unwrap_or(1); - let args = args - .iter() - .zip(hints.iter().chain(std::iter::repeat(&Hint::Pad))) - .map(|(arg, hint)| { - // Decide on the length to expand this scalar to depending - // on the given hints. 
- let expansion_len = match hint { - Hint::AcceptsSingular => 1, - Hint::Pad => inferred_length, - }; - arg.clone().into_array(expansion_len) - }) - .collect::<Result<Vec<_>>>()?; - - let result = (inner)(&args); - if is_scalar { - // If all inputs are scalar, keeps output as scalar - let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); - result.map(ColumnarValue::Scalar) - } else { - result.map(ColumnarValue::Array) - } - }) -} - -#[cfg(test)] -pub mod test { - /// $FUNC ScalarUDFImpl to test - /// $ARGS arguments (vec) to pass to function - /// $EXPECTED a Result - /// $EXPECTED_TYPE is the expected value type - /// $EXPECTED_DATA_TYPE is the expected result type - /// $ARRAY_TYPE is the column type after function applied - macro_rules! test_function { - ($FUNC:expr, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $EXPECTED_DATA_TYPE:expr, $ARRAY_TYPE:ident) => { - let expected: Result<Option<$EXPECTED_TYPE>> = $EXPECTED; - let func = $FUNC; - - let type_array = $ARGS.iter().map(|arg| arg.data_type()).collect::<Vec<_>>(); - let return_type = func.return_type(&type_array); - - match expected { - Ok(expected) => { - assert_eq!(return_type.is_ok(), true); - assert_eq!(return_type.unwrap(), $EXPECTED_DATA_TYPE); - - let result = func.invoke($ARGS); - assert_eq!(result.is_ok(), true); - - let len = $ARGS - .iter() - .fold(Option::<usize>::None, |acc, arg| match arg { - ColumnarValue::Scalar(_) => acc, - ColumnarValue::Array(a) => Some(a.len()), - }); - let inferred_length = len.unwrap_or(1); - let result = result.unwrap().clone().into_array(inferred_length).expect("Failed to convert to array"); - let result = result.as_any().downcast_ref::<$ARRAY_TYPE>().expect("Failed to convert to type"); - - // value is correct - match expected { - Some(v) => assert_eq!(result.value(0), v), - None => assert!(result.is_null(0)), - }; - } - Err(expected_error) => { - if return_type.is_err() { - match return_type { - Ok(_) => assert!(false, "expected error"), - Err(error) => { 
datafusion_common::assert_contains!(expected_error.strip_backtrace(), error.strip_backtrace()); } - } - } - else { - // invoke is expected error - cannot use .expect_err() due to Debug not being implemented - match func.invoke($ARGS) { - Ok(_) => assert!(false, "expected error"), - Err(error) => { - assert!(expected_error.strip_backtrace().starts_with(&error.strip_backtrace())); - } - } - } - } - }; - }; - } - - pub(crate) use test_function; -} diff --git a/datafusion/functions/src/string/levenshtein.rs b/datafusion/functions/src/string/levenshtein.rs index b5de4b28948f..8f497e73e393 100644 --- a/datafusion/functions/src/string/levenshtein.rs +++ b/datafusion/functions/src/string/levenshtein.rs @@ -21,6 +21,7 @@ use std::sync::Arc; use arrow::array::{ArrayRef, Int32Array, Int64Array, OffsetSizeTrait}; use arrow::datatypes::DataType; +use crate::utils::{make_scalar_function, utf8_to_int_type}; use datafusion_common::cast::as_generic_string_array; use datafusion_common::utils::datafusion_strsim; use datafusion_common::{exec_err, Result}; @@ -28,8 +29,6 @@ use datafusion_expr::ColumnarValue; use datafusion_expr::TypeSignature::*; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; -use crate::string::common::{make_scalar_function, utf8_to_int_type}; - #[derive(Debug)] pub(super) struct LevenshteinFunc { signature: Signature, diff --git a/datafusion/functions/src/string/lower.rs b/datafusion/functions/src/string/lower.rs index 42bda0470067..327772bd808d 100644 --- a/datafusion/functions/src/string/lower.rs +++ b/datafusion/functions/src/string/lower.rs @@ -15,12 +15,16 @@ // specific language governing permissions and limitations // under the License. 
-use crate::string::common::{handle, utf8_to_str_type}; +use std::any::Any; + use arrow::datatypes::DataType; + use datafusion_common::Result; use datafusion_expr::ColumnarValue; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; -use std::any::Any; + +use crate::string::common::handle; +use crate::utils::utf8_to_str_type; #[derive(Debug)] pub(super) struct LowerFunc { diff --git a/datafusion/functions/src/string/ltrim.rs b/datafusion/functions/src/string/ltrim.rs index 535ffb14f5f5..e6926e5bd56e 100644 --- a/datafusion/functions/src/string/ltrim.rs +++ b/datafusion/functions/src/string/ltrim.rs @@ -15,9 +15,9 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{ArrayRef, OffsetSizeTrait}; use std::any::Any; +use arrow::array::{ArrayRef, OffsetSizeTrait}; use arrow::datatypes::DataType; use datafusion_common::{exec_err, Result}; @@ -26,6 +26,7 @@ use datafusion_expr::{ColumnarValue, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; use crate::string::common::*; +use crate::utils::{make_scalar_function, utf8_to_str_type}; /// Returns the longest string with leading characters removed. If the characters are not specified, whitespace is removed. /// ltrim('zzzytest', 'xyz') = 'test' diff --git a/datafusion/functions/src/string/octet_length.rs b/datafusion/functions/src/string/octet_length.rs index 36a62fbe4e38..639bf6cb48a9 100644 --- a/datafusion/functions/src/string/octet_length.rs +++ b/datafusion/functions/src/string/octet_length.rs @@ -15,16 +15,16 @@ // specific language governing permissions and limitations // under the License. 
-use arrow::compute::kernels::length::length; use std::any::Any; +use arrow::compute::kernels::length::length; use arrow::datatypes::DataType; use datafusion_common::{exec_err, Result, ScalarValue}; use datafusion_expr::{ColumnarValue, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; -use crate::string::common::*; +use crate::utils::utf8_to_int_type; #[derive(Debug)] pub(super) struct OctetLengthFunc { @@ -86,14 +86,17 @@ impl ScalarUDFImpl for OctetLengthFunc { #[cfg(test)] mod tests { - use crate::string::common::test::test_function; - use crate::string::octet_length::OctetLengthFunc; + use std::sync::Arc; + use arrow::array::{Array, Int32Array, StringArray}; use arrow::datatypes::DataType::Int32; + use datafusion_common::ScalarValue; use datafusion_common::{exec_err, Result}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; - use std::sync::Arc; + + use crate::string::octet_length::OctetLengthFunc; + use crate::utils::test::test_function; #[test] fn test_functions() -> Result<()> { diff --git a/datafusion/functions/src/string/overlay.rs b/datafusion/functions/src/string/overlay.rs index d7cc0da8068e..8b9cc03afc4d 100644 --- a/datafusion/functions/src/string/overlay.rs +++ b/datafusion/functions/src/string/overlay.rs @@ -27,7 +27,7 @@ use datafusion_expr::TypeSignature::*; use datafusion_expr::{ColumnarValue, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; -use crate::string::common::*; +use crate::utils::{make_scalar_function, utf8_to_str_type}; #[derive(Debug)] pub(super) struct OverlayFunc { diff --git a/datafusion/functions/src/string/repeat.rs b/datafusion/functions/src/string/repeat.rs index 83bc929cb9a4..f4319af0a5c4 100644 --- a/datafusion/functions/src/string/repeat.rs +++ b/datafusion/functions/src/string/repeat.rs @@ -27,7 +27,7 @@ use datafusion_expr::TypeSignature::*; use datafusion_expr::{ColumnarValue, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; -use crate::string::common::*; +use 
crate::utils::{make_scalar_function, utf8_to_str_type}; #[derive(Debug)] pub(super) struct RepeatFunc { @@ -99,8 +99,8 @@ mod tests { use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; - use crate::string::common::test::test_function; use crate::string::repeat::RepeatFunc; + use crate::utils::test::test_function; #[test] fn test_functions() -> Result<()> { diff --git a/datafusion/functions/src/string/replace.rs b/datafusion/functions/src/string/replace.rs index e35244296090..e869ac205440 100644 --- a/datafusion/functions/src/string/replace.rs +++ b/datafusion/functions/src/string/replace.rs @@ -27,7 +27,7 @@ use datafusion_expr::TypeSignature::*; use datafusion_expr::{ColumnarValue, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; -use crate::string::common::*; +use crate::utils::{make_scalar_function, utf8_to_str_type}; #[derive(Debug)] pub(super) struct ReplaceFunc { diff --git a/datafusion/functions/src/string/rtrim.rs b/datafusion/functions/src/string/rtrim.rs index 17d2f8234b34..d04d15ce8847 100644 --- a/datafusion/functions/src/string/rtrim.rs +++ b/datafusion/functions/src/string/rtrim.rs @@ -26,6 +26,7 @@ use datafusion_expr::{ColumnarValue, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; use crate::string::common::*; +use crate::utils::{make_scalar_function, utf8_to_str_type}; /// Returns the longest string with trailing characters removed. If the characters are not specified, whitespace is removed. 
/// rtrim('testxxzx', 'xyz') = 'test' diff --git a/datafusion/functions/src/string/split_part.rs b/datafusion/functions/src/string/split_part.rs index af201e90fcf6..0aa968a1ef5b 100644 --- a/datafusion/functions/src/string/split_part.rs +++ b/datafusion/functions/src/string/split_part.rs @@ -27,7 +27,7 @@ use datafusion_expr::TypeSignature::*; use datafusion_expr::{ColumnarValue, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; -use crate::string::common::*; +use crate::utils::{make_scalar_function, utf8_to_str_type}; #[derive(Debug)] pub(super) struct SplitPartFunc { @@ -117,8 +117,8 @@ mod tests { use datafusion_common::{exec_err, Result}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; - use crate::string::common::test::test_function; use crate::string::split_part::SplitPartFunc; + use crate::utils::test::test_function; #[test] fn test_functions() -> Result<()> { diff --git a/datafusion/functions/src/string/starts_with.rs b/datafusion/functions/src/string/starts_with.rs index 4450b9d332a0..f1b03907f8d8 100644 --- a/datafusion/functions/src/string/starts_with.rs +++ b/datafusion/functions/src/string/starts_with.rs @@ -15,15 +15,18 @@ // specific language governing permissions and limitations // under the License. -use crate::string::common::make_scalar_function; +use std::any::Any; +use std::sync::Arc; + use arrow::array::{ArrayRef, OffsetSizeTrait}; use arrow::datatypes::DataType; + use datafusion_common::{cast::as_generic_string_array, internal_err, Result}; use datafusion_expr::ColumnarValue; use datafusion_expr::TypeSignature::*; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; -use std::any::Any; -use std::sync::Arc; + +use crate::utils::make_scalar_function; /// Returns true if string starts with prefix. 
/// starts_with('alphabet', 'alph') = 't' diff --git a/datafusion/functions/src/string/to_hex.rs b/datafusion/functions/src/string/to_hex.rs index 1bdece3f7af8..ab320c68d493 100644 --- a/datafusion/functions/src/string/to_hex.rs +++ b/datafusion/functions/src/string/to_hex.rs @@ -15,18 +15,21 @@ // specific language governing permissions and limitations // under the License. -use crate::string::common::make_scalar_function; +use std::any::Any; +use std::sync::Arc; + use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; use arrow::datatypes::{ ArrowNativeType, ArrowPrimitiveType, DataType, Int32Type, Int64Type, }; + use datafusion_common::cast::as_primitive_array; use datafusion_common::Result; use datafusion_common::{exec_err, plan_err}; use datafusion_expr::ColumnarValue; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; -use std::any::Any; -use std::sync::Arc; + +use crate::utils::make_scalar_function; /// Converts the number to its equivalent hexadecimal representation. /// to_hex(2147483647) = '7fffffff' diff --git a/datafusion/functions/src/string/upper.rs b/datafusion/functions/src/string/upper.rs index a0c910ebb2c8..066174abf277 100644 --- a/datafusion/functions/src/string/upper.rs +++ b/datafusion/functions/src/string/upper.rs @@ -15,7 +15,8 @@ // specific language governing permissions and limitations // under the License. -use crate::string::common::{handle, utf8_to_str_type}; +use crate::string::common::handle; +use crate::utils::utf8_to_str_type; use arrow::datatypes::DataType; use datafusion_common::Result; use datafusion_expr::ColumnarValue; diff --git a/datafusion/functions/src/unicode/character_length.rs b/datafusion/functions/src/unicode/character_length.rs new file mode 100644 index 000000000000..51331bf9a586 --- /dev/null +++ b/datafusion/functions/src/unicode/character_length.rs @@ -0,0 +1,176 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::utils::{make_scalar_function, utf8_to_int_type}; +use arrow::array::{ + ArrayRef, ArrowPrimitiveType, GenericStringArray, OffsetSizeTrait, PrimitiveArray, +}; +use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; +use datafusion_common::cast::as_generic_string_array; +use datafusion_common::exec_err; +use datafusion_common::Result; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; +use std::sync::Arc; + +#[derive(Debug)] +pub(super) struct CharacterLengthFunc { + signature: Signature, + aliases: Vec, +} + +impl CharacterLengthFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::uniform( + 1, + vec![Utf8, LargeUtf8], + Volatility::Immutable, + ), + aliases: vec![String::from("length"), String::from("char_length")], + } + } +} + +impl ScalarUDFImpl for CharacterLengthFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "character_length" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_int_type(&arg_types[0], "character_length") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => { + 
make_scalar_function(character_length::, vec![])(args) + } + DataType::LargeUtf8 => { + make_scalar_function(character_length::, vec![])(args) + } + other => { + exec_err!("Unsupported data type {other:?} for function character_length") + } + } + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// Returns number of characters in the string. +/// character_length('josé') = 4 +/// The implementation counts UTF-8 code points to count the number of characters +fn character_length(args: &[ArrayRef]) -> Result +where + T::Native: OffsetSizeTrait, +{ + let string_array: &GenericStringArray = + as_generic_string_array::(&args[0])?; + + let result = string_array + .iter() + .map(|string| { + string.map(|string: &str| { + T::Native::from_usize(string.chars().count()) + .expect("should not fail as string.chars will always return integer") + }) + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + +#[cfg(test)] +mod tests { + use crate::unicode::character_length::CharacterLengthFunc; + use crate::utils::test::test_function; + use arrow::array::{Array, Int32Array}; + use arrow::datatypes::DataType::Int32; + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + #[test] + fn test_functions() -> Result<()> { + #[cfg(feature = "unicode_expressions")] + test_function!( + CharacterLengthFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + String::from("chars") + )))], + Ok(Some(5)), + i32, + Int32, + Int32Array + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + CharacterLengthFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + String::from("josé") + )))], + Ok(Some(4)), + i32, + Int32, + Int32Array + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + CharacterLengthFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + String::from("") + )))], + Ok(Some(0)), + i32, + Int32, + Int32Array + ); + #[cfg(feature = "unicode_expressions")] + 
test_function!( + CharacterLengthFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(None))], + Ok(None), + i32, + Int32, + Int32Array + ); + #[cfg(not(feature = "unicode_expressions"))] + test_function!( + CharacterLengthFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("josé"))))], + internal_err!( + "function character_length requires compilation with feature flag: unicode_expressions." + ), + i32, + Int32, + Int32Array + ); + + Ok(()) + } +} diff --git a/datafusion/functions/src/unicode/mod.rs b/datafusion/functions/src/unicode/mod.rs new file mode 100644 index 000000000000..291de3843903 --- /dev/null +++ b/datafusion/functions/src/unicode/mod.rs @@ -0,0 +1,55 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
"unicode" DataFusion functions + +use std::sync::Arc; + +use datafusion_expr::ScalarUDF; + +mod character_length; + +// create UDFs +make_udf_function!( + character_length::CharacterLengthFunc, + CHARACTER_LENGTH, + character_length +); + +pub mod expr_fn { + use datafusion_expr::Expr; + + #[doc = "the number of characters in the `string`"] + pub fn char_length(string: Expr) -> Expr { + character_length(string) + } + + #[doc = "the number of characters in the `string`"] + pub fn character_length(string: Expr) -> Expr { + super::character_length().call(vec![string]) + } + + #[doc = "the number of characters in the `string`"] + pub fn length(string: Expr) -> Expr { + character_length(string) + } +} + +/// Return a list of all functions in this package +pub fn functions() -> Vec> { + vec![character_length()] +} diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs new file mode 100644 index 000000000000..f45deafdb37a --- /dev/null +++ b/datafusion/functions/src/utils.rs @@ -0,0 +1,178 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::array::ArrayRef; +use arrow::datatypes::DataType; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation}; +use datafusion_physical_expr::functions::Hint; +use std::sync::Arc; + +/// Creates a function to identify the optimal return type of a string function given +/// the type of its first argument. +/// +/// If the input type is `LargeUtf8` or `LargeBinary` the return type is +/// `$largeUtf8Type`, +/// +/// If the input type is `Utf8` or `Binary` the return type is `$utf8Type`, +macro_rules! get_optimal_return_type { + ($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => { + pub(crate) fn $FUNC(arg_type: &DataType, name: &str) -> Result { + Ok(match arg_type { + // LargeBinary inputs are automatically coerced to Utf8 + DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type, + // Binary inputs are automatically coerced to Utf8 + DataType::Utf8 | DataType::Binary => $utf8Type, + DataType::Null => DataType::Null, + DataType::Dictionary(_, value_type) => match **value_type { + DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type, + DataType::Utf8 | DataType::Binary => $utf8Type, + DataType::Null => DataType::Null, + _ => { + return datafusion_common::exec_err!( + "The {} function can only accept strings, but got {:?}.", + name.to_uppercase(), + **value_type + ); + } + }, + data_type => { + return datafusion_common::exec_err!( + "The {} function can only accept strings, but got {:?}.", + name.to_uppercase(), + data_type + ); + } + }) + } + }; +} + +// `utf8_to_str_type`: returns either a Utf8 or LargeUtf8 based on the input type size. +get_optimal_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8); + +// `utf8_to_int_type`: returns either a Int32 or Int64 based on the input type size. 
+get_optimal_return_type!(utf8_to_int_type, DataType::Int64, DataType::Int32); + +pub(super) fn make_scalar_function( + inner: F, + hints: Vec, +) -> ScalarFunctionImplementation +where + F: Fn(&[ArrayRef]) -> Result + Sync + Send + 'static, +{ + Arc::new(move |args: &[ColumnarValue]| { + // first, identify if any of the arguments is an Array. If yes, store its `len`, + // as any scalar will need to be converted to an array of len `len`. + let len = args + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }); + + let is_scalar = len.is_none(); + + let inferred_length = len.unwrap_or(1); + let args = args + .iter() + .zip(hints.iter().chain(std::iter::repeat(&Hint::Pad))) + .map(|(arg, hint)| { + // Decide on the length to expand this scalar to depending + // on the given hints. + let expansion_len = match hint { + Hint::AcceptsSingular => 1, + Hint::Pad => inferred_length, + }; + arg.clone().into_array(expansion_len) + }) + .collect::>>()?; + + let result = (inner)(&args); + if is_scalar { + // If all inputs are scalar, keeps output as scalar + let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); + result.map(ColumnarValue::Scalar) + } else { + result.map(ColumnarValue::Array) + } + }) +} + +#[cfg(test)] +pub mod test { + /// $FUNC ScalarUDFImpl to test + /// $ARGS arguments (vec) to pass to function + /// $EXPECTED a Result + /// $EXPECTED_TYPE is the expected value type + /// $EXPECTED_DATA_TYPE is the expected result type + /// $ARRAY_TYPE is the column type after function applied + macro_rules! 
test_function { + ($FUNC:expr, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $EXPECTED_DATA_TYPE:expr, $ARRAY_TYPE:ident) => { + let expected: Result> = $EXPECTED; + let func = $FUNC; + + let type_array = $ARGS.iter().map(|arg| arg.data_type()).collect::>(); + let return_type = func.return_type(&type_array); + + match expected { + Ok(expected) => { + assert_eq!(return_type.is_ok(), true); + assert_eq!(return_type.unwrap(), $EXPECTED_DATA_TYPE); + + let result = func.invoke($ARGS); + assert_eq!(result.is_ok(), true); + + let len = $ARGS + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }); + let inferred_length = len.unwrap_or(1); + let result = result.unwrap().clone().into_array(inferred_length).expect("Failed to convert to array"); + let result = result.as_any().downcast_ref::<$ARRAY_TYPE>().expect("Failed to convert to type"); + + // value is correct + match expected { + Some(v) => assert_eq!(result.value(0), v), + None => assert!(result.is_null(0)), + }; + } + Err(expected_error) => { + if return_type.is_err() { + match return_type { + Ok(_) => assert!(false, "expected error"), + Err(error) => { datafusion_common::assert_contains!(expected_error.strip_backtrace(), error.strip_backtrace()); } + } + } + else { + // invoke is expected error - cannot use .expect_err() due to Debug not being implemented + match func.invoke($ARGS) { + Ok(_) => assert!(false, "expected error"), + Err(error) => { + assert!(expected_error.strip_backtrace().starts_with(&error.strip_backtrace())); + } + } + } + } + }; + }; + } + + pub(crate) use test_function; +} diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index cd9bba63d624..9adc8536341d 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -254,29 +254,6 @@ pub fn create_physical_fun( Arc::new(|args| 
make_scalar_function_inner(math_expressions::cot)(args)) } // string functions - BuiltinScalarFunction::CharacterLength => { - Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - character_length, - Int32Type, - "character_length" - ); - make_scalar_function_inner(func)(args) - } - DataType::LargeUtf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - character_length, - Int64Type, - "character_length" - ); - make_scalar_function_inner(func)(args) - } - other => exec_err!( - "Unsupported data type {other:?} for function character_length" - ), - }) - } BuiltinScalarFunction::Coalesce => Arc::new(conditional_expressions::coalesce), BuiltinScalarFunction::Concat => Arc::new(string_expressions::concat), BuiltinScalarFunction::ConcatWithSeparator => Arc::new(|args| { @@ -595,53 +572,6 @@ mod tests { #[test] fn test_functions() -> Result<()> { - #[cfg(feature = "unicode_expressions")] - test_function!( - CharacterLength, - &[lit("chars")], - Ok(Some(5)), - i32, - Int32, - Int32Array - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - CharacterLength, - &[lit("josé")], - Ok(Some(4)), - i32, - Int32, - Int32Array - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - CharacterLength, - &[lit("")], - Ok(Some(0)), - i32, - Int32, - Int32Array - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - CharacterLength, - &[lit(ScalarValue::Utf8(None))], - Ok(None), - i32, - Int32, - Int32Array - ); - #[cfg(not(feature = "unicode_expressions"))] - test_function!( - CharacterLength, - &[lit("josé")], - internal_err!( - "function character_length requires compilation with feature flag: unicode_expressions." 
- ), - i32, - Int32, - Int32Array - ); test_function!( Concat, &[lit("aa"), lit("bb"), lit("cc"),], diff --git a/datafusion/physical-expr/src/unicode_expressions.rs b/datafusion/physical-expr/src/unicode_expressions.rs index 8ec9e062d9b7..c7e4b7d7c443 100644 --- a/datafusion/physical-expr/src/unicode_expressions.rs +++ b/datafusion/physical-expr/src/unicode_expressions.rs @@ -36,29 +36,6 @@ use datafusion_common::{ exec_err, Result, }; -/// Returns number of characters in the string. -/// character_length('josé') = 4 -/// The implementation counts UTF-8 code points to count the number of characters -pub fn character_length(args: &[ArrayRef]) -> Result -where - T::Native: OffsetSizeTrait, -{ - let string_array: &GenericStringArray = - as_generic_string_array::(&args[0])?; - - let result = string_array - .iter() - .map(|string| { - string.map(|string: &str| { - T::Native::from_usize(string.chars().count()) - .expect("should not fail as string.chars will always return integer") - }) - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) -} - /// Returns first n characters in the string, or when n is negative, returns all but last |n| characters. 
/// left('abcde', 2) = 'ab' /// The implementation uses UTF-8 code points as characters diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index f405ecf976be..766ca6633ee1 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -565,7 +565,7 @@ enum ScalarFunction { // RegexpMatch = 21; // 22 was BitLength // 23 was Btrim - CharacterLength = 24; + // 24 was CharacterLength // 25 was Chr Concat = 26; ConcatWithSeparator = 27; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 0d22ba5db773..f2814956ef1b 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22928,7 +22928,6 @@ impl serde::Serialize for ScalarFunction { Self::Sin => "Sin", Self::Sqrt => "Sqrt", Self::Trunc => "Trunc", - Self::CharacterLength => "CharacterLength", Self::Concat => "Concat", Self::ConcatWithSeparator => "ConcatWithSeparator", Self::InitCap => "InitCap", @@ -22988,7 +22987,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Sin", "Sqrt", "Trunc", - "CharacterLength", "Concat", "ConcatWithSeparator", "InitCap", @@ -23077,7 +23075,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Sin" => Ok(ScalarFunction::Sin), "Sqrt" => Ok(ScalarFunction::Sqrt), "Trunc" => Ok(ScalarFunction::Trunc), - "CharacterLength" => Ok(ScalarFunction::CharacterLength), "Concat" => Ok(ScalarFunction::Concat), "ConcatWithSeparator" => Ok(ScalarFunction::ConcatWithSeparator), "InitCap" => Ok(ScalarFunction::InitCap), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 07c3fad15373..ecc94fcdaf99 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2864,7 +2864,7 @@ pub enum ScalarFunction { /// RegexpMatch = 21; /// 22 was BitLength /// 23 was Btrim - CharacterLength = 24, + /// 24 was CharacterLength /// 25 was 
Chr Concat = 26, ConcatWithSeparator = 27, @@ -3001,7 +3001,6 @@ impl ScalarFunction { ScalarFunction::Sin => "Sin", ScalarFunction::Sqrt => "Sqrt", ScalarFunction::Trunc => "Trunc", - ScalarFunction::CharacterLength => "CharacterLength", ScalarFunction::Concat => "Concat", ScalarFunction::ConcatWithSeparator => "ConcatWithSeparator", ScalarFunction::InitCap => "InitCap", @@ -3055,7 +3054,6 @@ impl ScalarFunction { "Sin" => Some(Self::Sin), "Sqrt" => Some(Self::Sqrt), "Trunc" => Some(Self::Trunc), - "CharacterLength" => Some(Self::CharacterLength), "Concat" => Some(Self::Concat), "ConcatWithSeparator" => Some(Self::ConcatWithSeparator), "InitCap" => Some(Self::InitCap), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 4b9874bf8f65..19edd71a3a80 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -48,8 +48,8 @@ use datafusion_expr::expr::Unnest; use datafusion_expr::expr::{Alias, Placeholder}; use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ - acosh, asinh, atan, atan2, atanh, cbrt, ceil, character_length, coalesce, - concat_expr, concat_ws_expr, cos, cosh, cot, degrees, ends_with, exp, + acosh, asinh, atan, atan2, atanh, cbrt, ceil, coalesce, concat_expr, concat_ws_expr, + cos, cosh, cot, degrees, ends_with, exp, expr::{self, InList, Sort, WindowFunction}, factorial, find_in_set, floor, gcd, initcap, iszero, lcm, left, ln, log, log10, log2, logical_plan::{PlanType, StringifiedPlan}, @@ -450,7 +450,6 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Concat => Self::Concat, ScalarFunction::Log2 => Self::Log2, ScalarFunction::Signum => Self::Signum, - ScalarFunction::CharacterLength => Self::CharacterLength, ScalarFunction::ConcatWithSeparator => Self::ConcatWithSeparator, ScalarFunction::EndsWith => Self::EndsWith, ScalarFunction::InitCap => 
Self::InitCap, @@ -1372,9 +1371,6 @@ pub fn parse_expr( ScalarFunction::Signum => { Ok(signum(parse_expr(&args[0], registry, codec)?)) } - ScalarFunction::CharacterLength => { - Ok(character_length(parse_expr(&args[0], registry, codec)?)) - } ScalarFunction::InitCap => { Ok(initcap(parse_expr(&args[0], registry, codec)?)) } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 1335d511a0ea..11fc7362c75d 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1442,7 +1442,6 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Concat => Self::Concat, BuiltinScalarFunction::Log2 => Self::Log2, BuiltinScalarFunction::Signum => Self::Signum, - BuiltinScalarFunction::CharacterLength => Self::CharacterLength, BuiltinScalarFunction::ConcatWithSeparator => Self::ConcatWithSeparator, BuiltinScalarFunction::EndsWith => Self::EndsWith, BuiltinScalarFunction::InitCap => Self::InitCap, diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index ca2c1a240c21..b9f6dc259eb7 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -49,6 +49,7 @@ strum = { version = "0.26.1", features = ["derive"] } [dev-dependencies] ctor = { workspace = true } +datafusion-functions = { workspace = true, default-features = true } env_logger = { workspace = true } paste = "^1.0" rstest = { workspace = true } diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 448a9c54202e..101c31039c7e 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -38,6 +38,7 @@ use datafusion_sql::{ planner::{ContextProvider, ParserOptions, PlannerContext, SqlToRel}, }; +use datafusion_functions::unicode; use rstest::rstest; use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; use sqlparser::parser::Parser; @@ -88,7 
+89,7 @@ fn parse_decimals() { fn parse_ident_normalization() { let test_data = [ ( - "SELECT LENGTH('str')", + "SELECT CHARACTER_LENGTH('str')", "Ok(Projection: character_length(Utf8(\"str\"))\n EmptyRelation)", false, ), @@ -2688,6 +2689,7 @@ fn logical_plan_with_dialect_and_options( options: ParserOptions, ) -> Result { let context = MockContextProvider::default() + .with_udf(unicode::character_length().as_ref().clone()) .with_udf(make_udf( "nullif", vec![DataType::Int32, DataType::Int32], @@ -4508,26 +4510,27 @@ fn test_field_not_found_window_function() { #[test] fn test_parse_escaped_string_literal_value() { - let sql = r"SELECT length('\r\n') AS len"; + let sql = r"SELECT character_length('\r\n') AS len"; let expected = "Projection: character_length(Utf8(\"\\r\\n\")) AS len\ \n EmptyRelation"; quick_test(sql, expected); - let sql = r"SELECT length(E'\r\n') AS len"; + let sql = r"SELECT character_length(E'\r\n') AS len"; let expected = "Projection: character_length(Utf8(\"\r\n\")) AS len\ \n EmptyRelation"; quick_test(sql, expected); - let sql = r"SELECT length(E'\445') AS len, E'\x4B' AS hex, E'\u0001' AS unicode"; + let sql = + r"SELECT character_length(E'\445') AS len, E'\x4B' AS hex, E'\u0001' AS unicode"; let expected = "Projection: character_length(Utf8(\"%\")) AS len, Utf8(\"\u{004b}\") AS hex, Utf8(\"\u{0001}\") AS unicode\ \n EmptyRelation"; quick_test(sql, expected); - let sql = r"SELECT length(E'\000') AS len"; + let sql = r"SELECT character_length(E'\000') AS len"; assert_eq!( logical_plan(sql).unwrap_err().strip_backtrace(), - "SQL error: TokenizerError(\"Unterminated encoded string literal at Line: 1, Column 15\")" + "SQL error: TokenizerError(\"Unterminated encoded string literal at Line: 1, Column 25\")" ) } From 47eac75b5ca4973c18846c0dd0bc38feac4eb1f0 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Thu, 28 Mar 2024 10:43:15 -0400 Subject: [PATCH 5/7] move Left, Lpad, Reverse, Right, Rpad functions to datafusion_functions --- 
datafusion/expr/src/built_in_function.rs | 50 +- datafusion/expr/src/expr_fn.rs | 21 - datafusion/functions/src/unicode/left.rs | 245 +++++++ datafusion/functions/src/unicode/lpad.rs | 383 +++++++++++ datafusion/functions/src/unicode/mod.rs | 44 +- datafusion/functions/src/unicode/reverse.rs | 159 +++++ datafusion/functions/src/unicode/right.rs | 247 +++++++ datafusion/functions/src/unicode/rpad.rs | 375 +++++++++++ datafusion/physical-expr/src/functions.rs | 606 ------------------ datafusion/physical-expr/src/planner.rs | 4 +- .../physical-expr/src/unicode_expressions.rs | 263 +------- datafusion/proto/proto/datafusion.proto | 10 +- datafusion/proto/src/generated/pbjson.rs | 15 - datafusion/proto/src/generated/prost.rs | 20 +- .../proto/src/logical_plan/from_proto.rs | 53 +- datafusion/proto/src/logical_plan/to_proto.rs | 5 - 16 files changed, 1484 insertions(+), 1016 deletions(-) create mode 100644 datafusion/functions/src/unicode/left.rs create mode 100644 datafusion/functions/src/unicode/lpad.rs create mode 100644 datafusion/functions/src/unicode/reverse.rs create mode 100644 datafusion/functions/src/unicode/right.rs create mode 100644 datafusion/functions/src/unicode/rpad.rs diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index eefbc131a27b..196d278dc70e 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -111,18 +111,8 @@ pub enum BuiltinScalarFunction { EndsWith, /// initcap InitCap, - /// left - Left, - /// lpad - Lpad, /// random Random, - /// reverse - Reverse, - /// right - Right, - /// rpad - Rpad, /// strpos Strpos, /// substr @@ -220,12 +210,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ConcatWithSeparator => Volatility::Immutable, BuiltinScalarFunction::EndsWith => Volatility::Immutable, BuiltinScalarFunction::InitCap => Volatility::Immutable, - BuiltinScalarFunction::Left => Volatility::Immutable, - BuiltinScalarFunction::Lpad => 
Volatility::Immutable, BuiltinScalarFunction::Radians => Volatility::Immutable, - BuiltinScalarFunction::Reverse => Volatility::Immutable, - BuiltinScalarFunction::Right => Volatility::Immutable, - BuiltinScalarFunction::Rpad => Volatility::Immutable, BuiltinScalarFunction::Strpos => Volatility::Immutable, BuiltinScalarFunction::Substr => Volatility::Immutable, BuiltinScalarFunction::Translate => Volatility::Immutable, @@ -264,17 +249,8 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::InitCap => { utf8_to_str_type(&input_expr_types[0], "initcap") } - BuiltinScalarFunction::Left => utf8_to_str_type(&input_expr_types[0], "left"), - BuiltinScalarFunction::Lpad => utf8_to_str_type(&input_expr_types[0], "lpad"), BuiltinScalarFunction::Pi => Ok(Float64), BuiltinScalarFunction::Random => Ok(Float64), - BuiltinScalarFunction::Reverse => { - utf8_to_str_type(&input_expr_types[0], "reverse") - } - BuiltinScalarFunction::Right => { - utf8_to_str_type(&input_expr_types[0], "right") - } - BuiltinScalarFunction::Rpad => utf8_to_str_type(&input_expr_types[0], "rpad"), BuiltinScalarFunction::EndsWith => Ok(Boolean), BuiltinScalarFunction::Strpos => { utf8_to_int_type(&input_expr_types[0], "strpos/instr/position") @@ -361,28 +337,9 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Coalesce => { Signature::variadic_equal(self.volatility()) } - BuiltinScalarFunction::InitCap | BuiltinScalarFunction::Reverse => { + BuiltinScalarFunction::InitCap => { Signature::uniform(1, vec![Utf8, LargeUtf8], self.volatility()) } - BuiltinScalarFunction::Lpad | BuiltinScalarFunction::Rpad => { - Signature::one_of( - vec![ - Exact(vec![Utf8, Int64]), - Exact(vec![LargeUtf8, Int64]), - Exact(vec![Utf8, Int64, Utf8]), - Exact(vec![LargeUtf8, Int64, Utf8]), - Exact(vec![Utf8, Int64, LargeUtf8]), - Exact(vec![LargeUtf8, Int64, LargeUtf8]), - ], - self.volatility(), - ) - } - BuiltinScalarFunction::Left | BuiltinScalarFunction::Right => { - Signature::one_of( - vec![Exact(vec![Utf8, Int64]), 
Exact(vec![LargeUtf8, Int64])], - self.volatility(), - ) - } BuiltinScalarFunction::EndsWith | BuiltinScalarFunction::Strpos => { Signature::one_of( @@ -580,11 +537,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ConcatWithSeparator => &["concat_ws"], BuiltinScalarFunction::EndsWith => &["ends_with"], BuiltinScalarFunction::InitCap => &["initcap"], - BuiltinScalarFunction::Left => &["left"], - BuiltinScalarFunction::Lpad => &["lpad"], - BuiltinScalarFunction::Reverse => &["reverse"], - BuiltinScalarFunction::Right => &["right"], - BuiltinScalarFunction::Rpad => &["rpad"], BuiltinScalarFunction::Strpos => &["strpos", "instr", "position"], BuiltinScalarFunction::Substr => &["substr"], BuiltinScalarFunction::Translate => &["translate"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 654464798625..21dab72855e5 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -578,25 +578,11 @@ scalar_expr!(Atan2, atan2, y x, "inverse tangent of a division given in the argu scalar_expr!(Log, log, base x, "logarithm of a `x` for a particular `base`"); scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase"); -scalar_expr!(Left, left, string n, "returns the first `n` characters in the `string`"); -scalar_expr!(Reverse, reverse, string, "reverses the `string`"); -scalar_expr!(Right, right, string n, "returns the last `n` characters in the `string`"); scalar_expr!(EndsWith, ends_with, string suffix, "whether the `string` ends with the `suffix`"); scalar_expr!(Strpos, strpos, string substring, "finds the position from where the `substring` matches the `string`"); scalar_expr!(Substr, substr, string position, "substring from the `position` to the end"); scalar_expr!(Substr, substring, string position length, "substring from the `position` with `length` characters"); scalar_expr!(Translate, translate, string from to, 
"replaces the characters in `from` with the counterpart in `to`"); -//use vec as parameter -nary_scalar_expr!( - Lpad, - lpad, - "fill up a string to the length by prepending the characters" -); -nary_scalar_expr!( - Rpad, - rpad, - "fill up a string to the length by appending the characters" -); nary_scalar_expr!(Coalesce, coalesce, "returns `coalesce(args...)`, which evaluates to the value of the first [Expr] which is not NULL"); //there is a func concat_ws before, so use concat_ws_expr as name.c nary_scalar_expr!( @@ -1028,13 +1014,6 @@ mod test { test_scalar_expr!(Gcd, gcd, arg_1, arg_2); test_scalar_expr!(Lcm, lcm, arg_1, arg_2); test_scalar_expr!(InitCap, initcap, string); - test_scalar_expr!(Left, left, string, count); - test_nary_scalar_expr!(Lpad, lpad, string, count); - test_nary_scalar_expr!(Lpad, lpad, string, count, characters); - test_scalar_expr!(Reverse, reverse, string); - test_scalar_expr!(Right, right, string, count); - test_nary_scalar_expr!(Rpad, rpad, string, count); - test_nary_scalar_expr!(Rpad, rpad, string, count, characters); test_scalar_expr!(EndsWith, ends_with, string, characters); test_scalar_expr!(Strpos, strpos, string, substring); test_scalar_expr!(Substr, substr, string, position); diff --git a/datafusion/functions/src/unicode/left.rs b/datafusion/functions/src/unicode/left.rs new file mode 100644 index 000000000000..76da56abc19e --- /dev/null +++ b/datafusion/functions/src/unicode/left.rs @@ -0,0 +1,245 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::cmp::Ordering; +use std::sync::Arc; + +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; + +use datafusion_common::cast::{as_generic_string_array, as_int64_array}; +use datafusion_common::exec_err; +use datafusion_common::Result; +use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + +use crate::utils::{make_scalar_function, utf8_to_str_type}; + +#[derive(Debug)] +pub(super) struct LeftFunc { + signature: Signature, +} + +impl LeftFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![Exact(vec![Utf8, Int64]), Exact(vec![LargeUtf8, Int64])], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for LeftFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "left" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "left") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(left::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(left::, vec![])(args), + other => exec_err!("Unsupported data type {other:?} for function left"), + } + } +} + +/// Returns first n characters in the string, or when n is negative, returns all but last |n| characters. 
+/// left('abcde', 2) = 'ab'
+/// The implementation counts Unicode code points (`char`s), not bytes or grapheme clusters
+pub fn left<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
+    let string_array = as_generic_string_array::<T>(&args[0])?;
+    let n_array = as_int64_array(&args[1])?;
+    let result = string_array
+        .iter()
+        .zip(n_array.iter())
+        .map(|(string, n)| match (string, n) {
+            (Some(string), Some(n)) => match n.cmp(&0) {
+                Ordering::Less => {
+                    let len = string.chars().count() as i64;
+                    Some(if len + n > 0 { // cannot overflow, unlike `n.abs()` at i64::MIN
+                        string.chars().take((len + n) as usize).collect::<String>()
+                    } else {
+                        "".to_string()
+                    })
+                }
+                Ordering::Equal => Some("".to_string()),
+                Ordering::Greater => {
+                    Some(string.chars().take(n as usize).collect::<String>())
+                }
+            },
+            _ => None,
+        })
+        .collect::<GenericStringArray<T>>();
+
+    Ok(Arc::new(result) as ArrayRef)
+}
+
+#[cfg(test)]
+mod tests {
+    use arrow::array::{Array, StringArray};
+    use arrow::datatypes::DataType::Utf8;
+
+    use datafusion_common::{Result, ScalarValue};
+    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
+
+    use crate::unicode::left::LeftFunc;
+    use crate::utils::test::test_function;
+
+    #[test]
+    fn test_functions() -> Result<()> {
+        #[cfg(feature = "unicode_expressions")]
+        test_function!(
+            LeftFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("abcde")))),
+                ColumnarValue::Scalar(ScalarValue::Int64(Some(2))),
+            ],
+            Ok(Some("ab")),
+            &str,
+            Utf8,
+            StringArray
+        );
+        #[cfg(feature = "unicode_expressions")]
+        test_function!(
+            LeftFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("abcde")))),
+                ColumnarValue::Scalar(ScalarValue::Int64(Some(200))),
+            ],
+            Ok(Some("abcde")),
+            &str,
+            Utf8,
+            StringArray
+        );
+        #[cfg(feature = "unicode_expressions")]
+        test_function!(
+            LeftFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("abcde")))),
+                ColumnarValue::Scalar(ScalarValue::Int64(Some(-2))),
+            ],
+            Ok(Some("abc")),
+            &str,
+            Utf8,
+            StringArray
+        );
+        #[cfg(feature = "unicode_expressions")]
+        test_function!(
+            LeftFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("abcde")))),
+                ColumnarValue::Scalar(ScalarValue::Int64(Some(-200))),
+            ],
+            Ok(Some("")),
+            &str,
+            Utf8,
+            StringArray
+        );
+        #[cfg(feature = "unicode_expressions")]
+        test_function!(
+            LeftFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("abcde")))),
+                ColumnarValue::Scalar(ScalarValue::Int64(Some(0))),
+            ],
+            Ok(Some("")),
+            &str,
+            Utf8,
+            StringArray
+        );
+        #[cfg(feature = "unicode_expressions")]
+        test_function!(
+            LeftFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
+                ColumnarValue::Scalar(ScalarValue::Int64(Some(2))),
+            ],
+            Ok(None),
+            &str,
+            Utf8,
+            StringArray
+        );
+        #[cfg(feature = "unicode_expressions")]
+        test_function!(
+            LeftFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("abcde")))),
+                ColumnarValue::Scalar(ScalarValue::Int64(None)),
+            ],
+            Ok(None),
+            &str,
+            Utf8,
+            StringArray
+        );
+        #[cfg(feature = "unicode_expressions")]
+        test_function!(
+            LeftFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("joséésoj")))),
+                ColumnarValue::Scalar(ScalarValue::Int64(Some(5))),
+            ],
+            Ok(Some("joséé")),
+            &str,
+            Utf8,
+            StringArray
+        );
+        #[cfg(feature = "unicode_expressions")]
+        test_function!(
+            LeftFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("joséésoj")))),
+                ColumnarValue::Scalar(ScalarValue::Int64(Some(-3))),
+            ],
+            Ok(Some("joséé")),
+            &str,
+            Utf8,
+            StringArray
+        );
+        #[cfg(not(feature = "unicode_expressions"))]
+        test_function!(
+            LeftFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("abcde")))),
+                ColumnarValue::Scalar(ScalarValue::Int64(Some(2))),
+            ],
+            internal_err!(
+                "function left requires compilation with feature flag: unicode_expressions."
+            ),
+            &str,
+            Utf8,
+            StringArray
+        );
+
+        Ok(())
+    }
+}
diff --git a/datafusion/functions/src/unicode/lpad.rs b/datafusion/functions/src/unicode/lpad.rs
new file mode 100644
index 000000000000..a0968b36920f
--- /dev/null
+++ b/datafusion/functions/src/unicode/lpad.rs
@@ -0,0 +1,383 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+ +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; +use datafusion_common::cast::{as_generic_string_array, as_int64_array}; +use unicode_segmentation::UnicodeSegmentation; + +use crate::utils::{make_scalar_function, utf8_to_str_type}; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + +#[derive(Debug)] +pub(super) struct LPadFunc { + signature: Signature, +} + +impl LPadFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![ + Exact(vec![Utf8, Int64]), + Exact(vec![LargeUtf8, Int64]), + Exact(vec![Utf8, Int64, Utf8]), + Exact(vec![LargeUtf8, Int64, Utf8]), + Exact(vec![Utf8, Int64, LargeUtf8]), + Exact(vec![LargeUtf8, Int64, LargeUtf8]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for LPadFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "lpad" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "lpad") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(lpad::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(lpad::, vec![])(args), + other => exec_err!("Unsupported data type {other:?} for function lpad"), + } + } +} + +/// Extends the string to length 'length' by prepending the characters fill (a space by default). If the string is already longer than length then it is truncated (on the right). 
+/// lpad('hi', 5, 'xy') = 'xyxhi' +pub fn lpad(args: &[ArrayRef]) -> Result { + match args.len() { + 2 => { + let string_array = as_generic_string_array::(&args[0])?; + let length_array = as_int64_array(&args[1])?; + + let result = string_array + .iter() + .zip(length_array.iter()) + .map(|(string, length)| match (string, length) { + (Some(string), Some(length)) => { + if length > i32::MAX as i64 { + return exec_err!( + "lpad requested length {length} too large" + ); + } + + let length = if length < 0 { 0 } else { length as usize }; + if length == 0 { + Ok(Some("".to_string())) + } else { + let graphemes = string.graphemes(true).collect::>(); + if length < graphemes.len() { + Ok(Some(graphemes[..length].concat())) + } else { + let mut s: String = " ".repeat(length - graphemes.len()); + s.push_str(string); + Ok(Some(s)) + } + } + } + _ => Ok(None), + }) + .collect::>>()?; + + Ok(Arc::new(result) as ArrayRef) + } + 3 => { + let string_array = as_generic_string_array::(&args[0])?; + let length_array = as_int64_array(&args[1])?; + let fill_array = as_generic_string_array::(&args[2])?; + + let result = string_array + .iter() + .zip(length_array.iter()) + .zip(fill_array.iter()) + .map(|((string, length), fill)| match (string, length, fill) { + (Some(string), Some(length), Some(fill)) => { + if length > i32::MAX as i64 { + return exec_err!( + "lpad requested length {length} too large" + ); + } + + let length = if length < 0 { 0 } else { length as usize }; + if length == 0 { + Ok(Some("".to_string())) + } else { + let graphemes = string.graphemes(true).collect::>(); + let fill_chars = fill.chars().collect::>(); + + if length < graphemes.len() { + Ok(Some(graphemes[..length].concat())) + } else if fill_chars.is_empty() { + Ok(Some(string.to_string())) + } else { + let mut s = string.to_string(); + let mut char_vector = + Vec::::with_capacity(length - graphemes.len()); + for l in 0..length - graphemes.len() { + char_vector.push( + *fill_chars.get(l % 
fill_chars.len()).unwrap(), + ); + } + s.insert_str( + 0, + char_vector.iter().collect::().as_str(), + ); + Ok(Some(s)) + } + } + } + _ => Ok(None), + }) + .collect::>>()?; + + Ok(Arc::new(result) as ArrayRef) + } + other => exec_err!( + "lpad was called with {other} arguments. It requires at least 2 and at most 3." + ), + } +} + +#[cfg(test)] +mod tests { + use arrow::array::{Array, StringArray}; + use arrow::datatypes::DataType::Utf8; + + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::unicode::lpad::LPadFunc; + use crate::utils::test::test_function; + + #[test] + fn test_functions() -> Result<()> { + #[cfg(feature = "unicode_expressions")] + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("josé")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ], + Ok(Some(" josé")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("hi")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ], + Ok(Some(" hi")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("hi")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(0))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("hi")))), + ColumnarValue::Scalar(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(feature = 
"unicode_expressions")] + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("hi")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("xy")))), + ], + Ok(Some("xyxhi")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("hi")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(21))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("abcdef")))), + ], + Ok(Some("abcdefabcdefabcdefahi")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("hi")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(" ")))), + ], + Ok(Some(" hi")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("hi")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))), + ], + Ok(Some("hi")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("xy")))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("hi")))), + ColumnarValue::Scalar(ScalarValue::Int64(None)), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("xy")))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + 
#[cfg(feature = "unicode_expressions")] + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("hi")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("josé")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(10))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("xy")))), + ], + Ok(Some("xyxyxyjosé")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("josé")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(10))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("éñ")))), + ], + Ok(Some("éñéñéñjosé")), + &str, + Utf8, + StringArray + ); + #[cfg(not(feature = "unicode_expressions"))] + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("josé")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ], + internal_err!( + "function lpad requires compilation with feature flag: unicode_expressions." 
+ ), + &str, + Utf8, + StringArray + ); + Ok(()) + } +} diff --git a/datafusion/functions/src/unicode/mod.rs b/datafusion/functions/src/unicode/mod.rs index 291de3843903..ea4e70a92199 100644 --- a/datafusion/functions/src/unicode/mod.rs +++ b/datafusion/functions/src/unicode/mod.rs @@ -22,6 +22,11 @@ use std::sync::Arc; use datafusion_expr::ScalarUDF; mod character_length; +mod left; +mod lpad; +mod reverse; +mod right; +mod rpad; // create UDFs make_udf_function!( @@ -29,6 +34,11 @@ make_udf_function!( CHARACTER_LENGTH, character_length ); +make_udf_function!(left::LeftFunc, LEFT, left); +make_udf_function!(lpad::LPadFunc, LPAD, lpad); +make_udf_function!(right::RightFunc, RIGHT, right); +make_udf_function!(reverse::ReverseFunc, REVERSE, reverse); +make_udf_function!(rpad::RPadFunc, RPAD, rpad); pub mod expr_fn { use datafusion_expr::Expr; @@ -47,9 +57,41 @@ pub mod expr_fn { pub fn length(string: Expr) -> Expr { character_length(string) } + + #[doc = "returns the first `n` characters in the `string`"] + pub fn left(string: Expr, n: Expr) -> Expr { + super::left().call(vec![string, n]) + } + + #[doc = "fill up a string to the length by prepending the characters"] + pub fn lpad(args: Vec) -> Expr { + super::lpad().call(args) + } + + #[doc = "reverses the `string`"] + pub fn reverse(string: Expr) -> Expr { + super::reverse().call(vec![string]) + } + + #[doc = "returns the last `n` characters in the `string`"] + pub fn right(string: Expr, n: Expr) -> Expr { + super::right().call(vec![string, n]) + } + + #[doc = "fill up a string to the length by appending the characters"] + pub fn rpad(args: Vec) -> Expr { + super::rpad().call(args) + } } /// Return a list of all functions in this package pub fn functions() -> Vec> { - vec![character_length()] + vec![ + character_length(), + left(), + lpad(), + reverse(), + right(), + rpad(), + ] } diff --git a/datafusion/functions/src/unicode/reverse.rs b/datafusion/functions/src/unicode/reverse.rs new file mode 100644 index 
000000000000..e1996fcb39c4 --- /dev/null +++ b/datafusion/functions/src/unicode/reverse.rs @@ -0,0 +1,159 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; + +use datafusion_common::cast::as_generic_string_array; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + +use crate::utils::{make_scalar_function, utf8_to_str_type}; + +#[derive(Debug)] +pub(super) struct ReverseFunc { + signature: Signature, +} + +impl ReverseFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::uniform( + 1, + vec![Utf8, LargeUtf8], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for ReverseFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "reverse" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "reverse") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => 
make_scalar_function(reverse::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(reverse::, vec![])(args), + other => { + exec_err!("Unsupported data type {other:?} for function reverse") + } + } + } +} + +/// Reverses the order of the characters in the string. +/// reverse('abcde') = 'edcba' +/// The implementation uses UTF-8 code points as characters +pub fn reverse(args: &[ArrayRef]) -> Result { + let string_array = as_generic_string_array::(&args[0])?; + + let result = string_array + .iter() + .map(|string| string.map(|string: &str| string.chars().rev().collect::())) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + +#[cfg(test)] +mod tests { + use arrow::array::{Array, StringArray}; + use arrow::datatypes::DataType::Utf8; + + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::unicode::reverse::ReverseFunc; + use crate::utils::test::test_function; + + #[test] + fn test_functions() -> Result<()> { + #[cfg(feature = "unicode_expressions")] + test_function!( + ReverseFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + String::from("abcde") + )))], + Ok(Some("edcba")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + ReverseFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + String::from("loẅks") + )))], + Ok(Some("sk̈wol")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + ReverseFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + String::from("loẅks") + )))], + Ok(Some("sk̈wol")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + ReverseFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(None))], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(not(feature = "unicode_expressions"))] + test_function!( + ReverseFunc::new(), + 
&[ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("abcde"))))], + internal_err!( + "function reverse requires compilation with feature flag: unicode_expressions." + ), + &str, + Utf8, + StringArray + ); + + Ok(()) + } +} diff --git a/datafusion/functions/src/unicode/right.rs b/datafusion/functions/src/unicode/right.rs new file mode 100644 index 000000000000..5eddf7b37bf0 --- /dev/null +++ b/datafusion/functions/src/unicode/right.rs @@ -0,0 +1,247 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::any::Any; +use std::cmp::{max, Ordering}; +use std::sync::Arc; + +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; + +use datafusion_common::cast::{as_generic_string_array, as_int64_array}; +use datafusion_common::exec_err; +use datafusion_common::Result; +use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + +use crate::utils::{make_scalar_function, utf8_to_str_type}; + +#[derive(Debug)] +pub(super) struct RightFunc { + signature: Signature, +} + +impl RightFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![Exact(vec![Utf8, Int64]), Exact(vec![LargeUtf8, Int64])], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for RightFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "right" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "right") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(right::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(right::, vec![])(args), + other => exec_err!("Unsupported data type {other:?} for function right"), + } + } +} + +/// Returns last n characters in the string, or when n is negative, returns all but first |n| characters. 
+/// right('abcde', 2) = 'de'
+/// The implementation counts Unicode code points (`char`s), not bytes or grapheme clusters
+pub fn right<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
+    let string_array = as_generic_string_array::<T>(&args[0])?;
+    let n_array = as_int64_array(&args[1])?;
+
+    let result = string_array
+        .iter()
+        .zip(n_array.iter())
+        .map(|(string, n)| match (string, n) {
+            (Some(string), Some(n)) => match n.cmp(&0) {
+                Ordering::Less => Some(
+                    string
+                        .chars()
+                        .skip(n.unsigned_abs() as usize)
+                        .collect::<String>(),
+                ),
+                Ordering::Equal => Some("".to_string()),
+                Ordering::Greater => Some(
+                    string
+                        .chars()
+                        .skip(max(string.chars().count() as i64 - n, 0) as usize)
+                        .collect::<String>(),
+                ),
+            },
+            _ => None,
+        })
+        .collect::<GenericStringArray<T>>();
+
+    Ok(Arc::new(result) as ArrayRef)
+}
+
+#[cfg(test)]
+mod tests {
+    use arrow::array::{Array, StringArray};
+    use arrow::datatypes::DataType::Utf8;
+
+    use datafusion_common::{Result, ScalarValue};
+    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
+
+    use crate::unicode::right::RightFunc;
+    use crate::utils::test::test_function;
+
+    #[test]
+    fn test_functions() -> Result<()> {
+        #[cfg(feature = "unicode_expressions")]
+        test_function!(
+            RightFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("abcde")))),
+                ColumnarValue::Scalar(ScalarValue::Int64(Some(2))),
+            ],
+            Ok(Some("de")),
+            &str,
+            Utf8,
+            StringArray
+        );
+        #[cfg(feature = "unicode_expressions")]
+        test_function!(
+            RightFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("abcde")))),
+                ColumnarValue::Scalar(ScalarValue::Int64(Some(200))),
+            ],
+            Ok(Some("abcde")),
+            &str,
+            Utf8,
+            StringArray
+        );
+        #[cfg(feature = "unicode_expressions")]
+        test_function!(
+            RightFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("abcde")))),
+                ColumnarValue::Scalar(ScalarValue::Int64(Some(-2))),
+            ],
+            Ok(Some("cde")),
+            &str,
+            Utf8,
+            StringArray
+        );
+        #[cfg(feature = "unicode_expressions")]
+        test_function!(
+            RightFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("abcde")))),
+                ColumnarValue::Scalar(ScalarValue::Int64(Some(-200))),
+            ],
+            Ok(Some("")),
+            &str,
+            Utf8,
+            StringArray
+        );
+        #[cfg(feature = "unicode_expressions")]
+        test_function!(
+            RightFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("abcde")))),
+                ColumnarValue::Scalar(ScalarValue::Int64(Some(0))),
+            ],
+            Ok(Some("")),
+            &str,
+            Utf8,
+            StringArray
+        );
+        #[cfg(feature = "unicode_expressions")]
+        test_function!(
+            RightFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
+                ColumnarValue::Scalar(ScalarValue::Int64(Some(2))),
+            ],
+            Ok(None),
+            &str,
+            Utf8,
+            StringArray
+        );
+        #[cfg(feature = "unicode_expressions")]
+        test_function!(
+            RightFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("abcde")))),
+                ColumnarValue::Scalar(ScalarValue::Int64(None)),
+            ],
+            Ok(None),
+            &str,
+            Utf8,
+            StringArray
+        );
+        #[cfg(feature = "unicode_expressions")]
+        test_function!(
+            RightFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("joséésoj")))),
+                ColumnarValue::Scalar(ScalarValue::Int64(Some(5))),
+            ],
+            Ok(Some("éésoj")),
+            &str,
+            Utf8,
+            StringArray
+        );
+        #[cfg(feature = "unicode_expressions")]
+        test_function!(
+            RightFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("joséésoj")))),
+                ColumnarValue::Scalar(ScalarValue::Int64(Some(-3))),
+            ],
+            Ok(Some("éésoj")),
+            &str,
+            Utf8,
+            StringArray
+        );
+        #[cfg(not(feature = "unicode_expressions"))]
+        test_function!(
+            RightFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("abcde")))),
+                ColumnarValue::Scalar(ScalarValue::Int64(Some(2))),
+            ],
+            internal_err!(
+                "function right requires compilation with feature flag: unicode_expressions."
+            ),
+            &str,
+            Utf8,
+            StringArray
+        );
+
+        Ok(())
+    }
+}
diff --git a/datafusion/functions/src/unicode/rpad.rs b/datafusion/functions/src/unicode/rpad.rs
new file mode 100644
index 000000000000..352b2f823008
--- /dev/null
+++ b/datafusion/functions/src/unicode/rpad.rs
@@ -0,0 +1,375 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+ +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; +use datafusion_common::cast::{as_generic_string_array, as_int64_array}; +use unicode_segmentation::UnicodeSegmentation; + +use crate::utils::{make_scalar_function, utf8_to_str_type}; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + +#[derive(Debug)] +pub(super) struct RPadFunc { + signature: Signature, +} + +impl RPadFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![ + Exact(vec![Utf8, Int64]), + Exact(vec![LargeUtf8, Int64]), + Exact(vec![Utf8, Int64, Utf8]), + Exact(vec![LargeUtf8, Int64, Utf8]), + Exact(vec![Utf8, Int64, LargeUtf8]), + Exact(vec![LargeUtf8, Int64, LargeUtf8]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for RPadFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "rpad" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "rpad") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(rpad::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(rpad::, vec![])(args), + other => exec_err!("Unsupported data type {other:?} for function rpad"), + } + } +} + +/// Extends the string to length 'length' by appending the characters fill (a space by default). If the string is already longer than length then it is truncated. 
+/// rpad('hi', 5, 'xy') = 'hixyx' +pub fn rpad(args: &[ArrayRef]) -> Result { + match args.len() { + 2 => { + let string_array = as_generic_string_array::(&args[0])?; + let length_array = as_int64_array(&args[1])?; + + let result = string_array + .iter() + .zip(length_array.iter()) + .map(|(string, length)| match (string, length) { + (Some(string), Some(length)) => { + if length > i32::MAX as i64 { + return exec_err!( + "rpad requested length {length} too large" + ); + } + + let length = if length < 0 { 0 } else { length as usize }; + if length == 0 { + Ok(Some("".to_string())) + } else { + let graphemes = string.graphemes(true).collect::>(); + if length < graphemes.len() { + Ok(Some(graphemes[..length].concat())) + } else { + let mut s = string.to_string(); + s.push_str(" ".repeat(length - graphemes.len()).as_str()); + Ok(Some(s)) + } + } + } + _ => Ok(None), + }) + .collect::>>()?; + Ok(Arc::new(result) as ArrayRef) + } + 3 => { + let string_array = as_generic_string_array::(&args[0])?; + let length_array = as_int64_array(&args[1])?; + let fill_array = as_generic_string_array::(&args[2])?; + + let result = string_array + .iter() + .zip(length_array.iter()) + .zip(fill_array.iter()) + .map(|((string, length), fill)| match (string, length, fill) { + (Some(string), Some(length), Some(fill)) => { + if length > i32::MAX as i64 { + return exec_err!( + "rpad requested length {length} too large" + ); + } + + let length = if length < 0 { 0 } else { length as usize }; + let graphemes = string.graphemes(true).collect::>(); + let fill_chars = fill.chars().collect::>(); + + if length < graphemes.len() { + Ok(Some(graphemes[..length].concat())) + } else if fill_chars.is_empty() { + Ok(Some(string.to_string())) + } else { + let mut s = string.to_string(); + let mut char_vector = + Vec::::with_capacity(length - graphemes.len()); + for l in 0..length - graphemes.len() { + char_vector + .push(*fill_chars.get(l % fill_chars.len()).unwrap()); + } + 
s.push_str(char_vector.iter().collect::().as_str()); + Ok(Some(s)) + } + } + _ => Ok(None), + }) + .collect::>>()?; + + Ok(Arc::new(result) as ArrayRef) + } + other => exec_err!( + "rpad was called with {other} arguments. It requires at least 2 and at most 3." + ), + } +} + +#[cfg(test)] +mod tests { + use arrow::array::{Array, StringArray}; + use arrow::datatypes::DataType::Utf8; + + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::unicode::rpad::RPadFunc; + use crate::utils::test::test_function; + + #[test] + fn test_functions() -> Result<()> { + #[cfg(feature = "unicode_expressions")] + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("josé")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ], + Ok(Some("josé ")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("hi")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ], + Ok(Some("hi ")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("hi")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(0))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("hi")))), + ColumnarValue::Scalar(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + RPadFunc::new(), + 
&[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("hi")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("xy")))), + ], + Ok(Some("hixyx")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("hi")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(21))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("abcdef")))), + ], + Ok(Some("hiabcdefabcdefabcdefa")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("hi")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(" ")))), + ], + Ok(Some("hi ")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("hi")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))), + ], + Ok(Some("hi")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("xy")))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("hi")))), + ColumnarValue::Scalar(ScalarValue::Int64(None)), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("xy")))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + 
RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("hi")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("josé")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(10))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("xy")))), + ], + Ok(Some("joséxyxyxy")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("josé")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(10))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("éñ")))), + ], + Ok(Some("josééñéñéñ")), + &str, + Utf8, + StringArray + ); + #[cfg(not(feature = "unicode_expressions"))] + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("josé")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ], + internal_err!( + "function rpad requires compilation with feature flag: unicode_expressions." 
+ ), + &str, + Utf8, + StringArray + ); + + Ok(()) + } +} diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 9adc8536341d..c1b4900e399a 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -270,67 +270,6 @@ pub fn create_physical_fun( exec_err!("Unsupported data type {other:?} for function initcap") } }), - BuiltinScalarFunction::Left => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_unicode_expressions_feature_flag!(left, i32, "left"); - make_scalar_function_inner(func)(args) - } - DataType::LargeUtf8 => { - let func = invoke_if_unicode_expressions_feature_flag!(left, i64, "left"); - make_scalar_function_inner(func)(args) - } - other => exec_err!("Unsupported data type {other:?} for function left"), - }), - BuiltinScalarFunction::Lpad => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_unicode_expressions_feature_flag!(lpad, i32, "lpad"); - make_scalar_function_inner(func)(args) - } - DataType::LargeUtf8 => { - let func = invoke_if_unicode_expressions_feature_flag!(lpad, i64, "lpad"); - make_scalar_function_inner(func)(args) - } - other => exec_err!("Unsupported data type {other:?} for function lpad"), - }), - BuiltinScalarFunction::Reverse => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = - invoke_if_unicode_expressions_feature_flag!(reverse, i32, "reverse"); - make_scalar_function_inner(func)(args) - } - DataType::LargeUtf8 => { - let func = - invoke_if_unicode_expressions_feature_flag!(reverse, i64, "reverse"); - make_scalar_function_inner(func)(args) - } - other => { - exec_err!("Unsupported data type {other:?} for function reverse") - } - }), - BuiltinScalarFunction::Right => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = - invoke_if_unicode_expressions_feature_flag!(right, i32, "right"); - 
make_scalar_function_inner(func)(args) - } - DataType::LargeUtf8 => { - let func = - invoke_if_unicode_expressions_feature_flag!(right, i64, "right"); - make_scalar_function_inner(func)(args) - } - other => exec_err!("Unsupported data type {other:?} for function right"), - }), - BuiltinScalarFunction::Rpad => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_unicode_expressions_feature_flag!(rpad, i32, "rpad"); - make_scalar_function_inner(func)(args) - } - DataType::LargeUtf8 => { - let func = invoke_if_unicode_expressions_feature_flag!(rpad, i64, "rpad"); - make_scalar_function_inner(func)(args) - } - other => exec_err!("Unsupported data type {other:?} for function rpad"), - }), BuiltinScalarFunction::EndsWith => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function_inner(string_expressions::ends_with::)(args) @@ -691,551 +630,6 @@ mod tests { Utf8, StringArray ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Left, - &[lit("abcde"), lit(ScalarValue::Int8(Some(2))),], - Ok(Some("ab")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Left, - &[lit("abcde"), lit(ScalarValue::Int64(Some(200))),], - Ok(Some("abcde")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Left, - &[lit("abcde"), lit(ScalarValue::Int64(Some(-2))),], - Ok(Some("abc")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Left, - &[lit("abcde"), lit(ScalarValue::Int64(Some(-200))),], - Ok(Some("")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Left, - &[lit("abcde"), lit(ScalarValue::Int64(Some(0))),], - Ok(Some("")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Left, - &[ - lit(ScalarValue::Utf8(None)), - lit(ScalarValue::Int64(Some(2))), - ], - Ok(None), - &str, - 
Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Left, - &[lit("abcde"), lit(ScalarValue::Int64(None)),], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Left, - &[lit("joséésoj"), lit(ScalarValue::Int64(Some(5))),], - Ok(Some("joséé")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Left, - &[lit("joséésoj"), lit(ScalarValue::Int64(Some(-3))),], - Ok(Some("joséé")), - &str, - Utf8, - StringArray - ); - #[cfg(not(feature = "unicode_expressions"))] - test_function!( - Left, - &[ - lit("abcde"), - lit(ScalarValue::Int8(Some(2))), - ], - internal_err!( - "function left requires compilation with feature flag: unicode_expressions." - ), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[lit("josé"), lit(ScalarValue::Int64(Some(5))),], - Ok(Some(" josé")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[lit("hi"), lit(ScalarValue::Int64(Some(5))),], - Ok(Some(" hi")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[lit("hi"), lit(ScalarValue::Int64(Some(0))),], - Ok(Some("")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[lit("hi"), lit(ScalarValue::Int64(None)),], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[ - lit(ScalarValue::Utf8(None)), - lit(ScalarValue::Int64(Some(5))), - ], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[lit("hi"), lit(ScalarValue::Int64(Some(5))), lit("xy"),], - Ok(Some("xyxhi")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[lit("hi"), 
lit(ScalarValue::Int64(Some(21))), lit("abcdef"),], - Ok(Some("abcdefabcdefabcdefahi")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[lit("hi"), lit(ScalarValue::Int64(Some(5))), lit(" "),], - Ok(Some(" hi")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[lit("hi"), lit(ScalarValue::Int64(Some(5))), lit(""),], - Ok(Some("hi")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[ - lit(ScalarValue::Utf8(None)), - lit(ScalarValue::Int64(Some(5))), - lit("xy"), - ], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[lit("hi"), lit(ScalarValue::Int64(None)), lit("xy"),], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[ - lit("hi"), - lit(ScalarValue::Int64(Some(5))), - lit(ScalarValue::Utf8(None)), - ], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[lit("josé"), lit(ScalarValue::Int64(Some(10))), lit("xy"),], - Ok(Some("xyxyxyjosé")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[lit("josé"), lit(ScalarValue::Int64(Some(10))), lit("éñ"),], - Ok(Some("éñéñéñjosé")), - &str, - Utf8, - StringArray - ); - #[cfg(not(feature = "unicode_expressions"))] - test_function!( - Lpad, - &[ - lit("josé"), - lit(ScalarValue::Int64(Some(5))), - ], - internal_err!( - "function lpad requires compilation with feature flag: unicode_expressions." 
- ), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Reverse, - &[lit("abcde")], - Ok(Some("edcba")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Reverse, - &[lit("loẅks")], - Ok(Some("sk̈wol")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Reverse, - &[lit("loẅks")], - Ok(Some("sk̈wol")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Reverse, - &[lit(ScalarValue::Utf8(None))], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(not(feature = "unicode_expressions"))] - test_function!( - Reverse, - &[lit("abcde")], - internal_err!( - "function reverse requires compilation with feature flag: unicode_expressions." - ), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Right, - &[lit("abcde"), lit(ScalarValue::Int8(Some(2))),], - Ok(Some("de")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Right, - &[lit("abcde"), lit(ScalarValue::Int64(Some(200))),], - Ok(Some("abcde")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Right, - &[lit("abcde"), lit(ScalarValue::Int64(Some(-2))),], - Ok(Some("cde")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Right, - &[lit("abcde"), lit(ScalarValue::Int64(Some(-200))),], - Ok(Some("")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Right, - &[lit("abcde"), lit(ScalarValue::Int64(Some(0))),], - Ok(Some("")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Right, - &[ - lit(ScalarValue::Utf8(None)), - lit(ScalarValue::Int64(Some(2))), - ], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - 
test_function!( - Right, - &[lit("abcde"), lit(ScalarValue::Int64(None)),], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Right, - &[lit("joséésoj"), lit(ScalarValue::Int64(Some(5))),], - Ok(Some("éésoj")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Right, - &[lit("joséésoj"), lit(ScalarValue::Int64(Some(-3))),], - Ok(Some("éésoj")), - &str, - Utf8, - StringArray - ); - #[cfg(not(feature = "unicode_expressions"))] - test_function!( - Right, - &[ - lit("abcde"), - lit(ScalarValue::Int8(Some(2))), - ], - internal_err!( - "function right requires compilation with feature flag: unicode_expressions." - ), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Rpad, - &[lit("josé"), lit(ScalarValue::Int64(Some(5))),], - Ok(Some("josé ")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Rpad, - &[lit("hi"), lit(ScalarValue::Int64(Some(5))),], - Ok(Some("hi ")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Rpad, - &[lit("hi"), lit(ScalarValue::Int64(Some(0))),], - Ok(Some("")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Rpad, - &[lit("hi"), lit(ScalarValue::Int64(None)),], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Rpad, - &[ - lit(ScalarValue::Utf8(None)), - lit(ScalarValue::Int64(Some(5))), - ], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Rpad, - &[lit("hi"), lit(ScalarValue::Int64(Some(5))), lit("xy"),], - Ok(Some("hixyx")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Rpad, - &[lit("hi"), lit(ScalarValue::Int64(Some(21))), lit("abcdef"),], - Ok(Some("hiabcdefabcdefabcdefa")), - 
&str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Rpad, - &[lit("hi"), lit(ScalarValue::Int64(Some(5))), lit(" "),], - Ok(Some("hi ")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Rpad, - &[lit("hi"), lit(ScalarValue::Int64(Some(5))), lit(""),], - Ok(Some("hi")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Rpad, - &[ - lit(ScalarValue::Utf8(None)), - lit(ScalarValue::Int64(Some(5))), - lit("xy"), - ], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Rpad, - &[lit("hi"), lit(ScalarValue::Int64(None)), lit("xy"),], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Rpad, - &[ - lit("hi"), - lit(ScalarValue::Int64(Some(5))), - lit(ScalarValue::Utf8(None)), - ], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Rpad, - &[lit("josé"), lit(ScalarValue::Int64(Some(10))), lit("xy"),], - Ok(Some("joséxyxyxy")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Rpad, - &[lit("josé"), lit(ScalarValue::Int64(Some(10))), lit("éñ"),], - Ok(Some("josééñéñéñ")), - &str, - Utf8, - StringArray - ); - #[cfg(not(feature = "unicode_expressions"))] - test_function!( - Rpad, - &[ - lit("josé"), - lit(ScalarValue::Int64(Some(5))), - ], - internal_err!( - "function rpad requires compilation with feature flag: unicode_expressions." 
- ), - &str, - Utf8, - StringArray - ); test_function!( EndsWith, &[lit("alphabet"), lit("alph"),], diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 319d9ca2269a..0dbea09ffb51 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -335,11 +335,11 @@ mod tests { use arrow_array::{ArrayRef, BooleanArray, RecordBatch, StringArray}; use arrow_schema::{DataType, Field, Schema}; use datafusion_common::{DFSchema, Result}; - use datafusion_expr::{col, left, Literal}; + use datafusion_expr::{col, lit}; #[test] fn test_create_physical_expr_scalar_input_output() -> Result<()> { - let expr = col("letter").eq(left("APACHE".lit(), 1i64.lit())); + let expr = col("letter").eq(lit("A")); let schema = Schema::new(vec![Field::new("letter", DataType::Utf8, false)]); let df_schema = DFSchema::try_from_qualified_schema("data", &schema)?; diff --git a/datafusion/physical-expr/src/unicode_expressions.rs b/datafusion/physical-expr/src/unicode_expressions.rs index c7e4b7d7c443..faff21111a61 100644 --- a/datafusion/physical-expr/src/unicode_expressions.rs +++ b/datafusion/physical-expr/src/unicode_expressions.rs @@ -21,7 +21,7 @@ //! Unicode expressions -use std::cmp::{max, Ordering}; +use std::cmp::max; use std::sync::Arc; use arrow::{ @@ -36,267 +36,6 @@ use datafusion_common::{ exec_err, Result, }; -/// Returns first n characters in the string, or when n is negative, returns all but last |n| characters. 
-/// left('abcde', 2) = 'ab' -/// The implementation uses UTF-8 code points as characters -pub fn left(args: &[ArrayRef]) -> Result { - let string_array = as_generic_string_array::(&args[0])?; - let n_array = as_int64_array(&args[1])?; - let result = string_array - .iter() - .zip(n_array.iter()) - .map(|(string, n)| match (string, n) { - (Some(string), Some(n)) => match n.cmp(&0) { - Ordering::Less => { - let len = string.chars().count() as i64; - Some(if n.abs() < len { - string.chars().take((len + n) as usize).collect::() - } else { - "".to_string() - }) - } - Ordering::Equal => Some("".to_string()), - Ordering::Greater => { - Some(string.chars().take(n as usize).collect::()) - } - }, - _ => None, - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) -} - -/// Extends the string to length 'length' by prepending the characters fill (a space by default). If the string is already longer than length then it is truncated (on the right). -/// lpad('hi', 5, 'xy') = 'xyxhi' -pub fn lpad(args: &[ArrayRef]) -> Result { - match args.len() { - 2 => { - let string_array = as_generic_string_array::(&args[0])?; - let length_array = as_int64_array(&args[1])?; - - let result = string_array - .iter() - .zip(length_array.iter()) - .map(|(string, length)| match (string, length) { - (Some(string), Some(length)) => { - if length > i32::MAX as i64 { - return exec_err!( - "lpad requested length {length} too large" - ); - } - - let length = if length < 0 { 0 } else { length as usize }; - if length == 0 { - Ok(Some("".to_string())) - } else { - let graphemes = string.graphemes(true).collect::>(); - if length < graphemes.len() { - Ok(Some(graphemes[..length].concat())) - } else { - let mut s: String = " ".repeat(length - graphemes.len()); - s.push_str(string); - Ok(Some(s)) - } - } - } - _ => Ok(None), - }) - .collect::>>()?; - - Ok(Arc::new(result) as ArrayRef) - } - 3 => { - let string_array = as_generic_string_array::(&args[0])?; - let length_array = as_int64_array(&args[1])?; - 
let fill_array = as_generic_string_array::(&args[2])?; - - let result = string_array - .iter() - .zip(length_array.iter()) - .zip(fill_array.iter()) - .map(|((string, length), fill)| match (string, length, fill) { - (Some(string), Some(length), Some(fill)) => { - if length > i32::MAX as i64 { - return exec_err!( - "lpad requested length {length} too large" - ); - } - - let length = if length < 0 { 0 } else { length as usize }; - if length == 0 { - Ok(Some("".to_string())) - } else { - let graphemes = string.graphemes(true).collect::>(); - let fill_chars = fill.chars().collect::>(); - - if length < graphemes.len() { - Ok(Some(graphemes[..length].concat())) - } else if fill_chars.is_empty() { - Ok(Some(string.to_string())) - } else { - let mut s = string.to_string(); - let mut char_vector = - Vec::::with_capacity(length - graphemes.len()); - for l in 0..length - graphemes.len() { - char_vector.push( - *fill_chars.get(l % fill_chars.len()).unwrap(), - ); - } - s.insert_str( - 0, - char_vector.iter().collect::().as_str(), - ); - Ok(Some(s)) - } - } - } - _ => Ok(None), - }) - .collect::>>()?; - - Ok(Arc::new(result) as ArrayRef) - } - other => exec_err!( - "lpad was called with {other} arguments. It requires at least 2 and at most 3." - ), - } -} - -/// Reverses the order of the characters in the string. -/// reverse('abcde') = 'edcba' -/// The implementation uses UTF-8 code points as characters -pub fn reverse(args: &[ArrayRef]) -> Result { - let string_array = as_generic_string_array::(&args[0])?; - - let result = string_array - .iter() - .map(|string| string.map(|string: &str| string.chars().rev().collect::())) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) -} - -/// Returns last n characters in the string, or when n is negative, returns all but first |n| characters. 
-/// right('abcde', 2) = 'de' -/// The implementation uses UTF-8 code points as characters -pub fn right(args: &[ArrayRef]) -> Result { - let string_array = as_generic_string_array::(&args[0])?; - let n_array = as_int64_array(&args[1])?; - - let result = string_array - .iter() - .zip(n_array.iter()) - .map(|(string, n)| match (string, n) { - (Some(string), Some(n)) => match n.cmp(&0) { - Ordering::Less => Some( - string - .chars() - .skip(n.unsigned_abs() as usize) - .collect::(), - ), - Ordering::Equal => Some("".to_string()), - Ordering::Greater => Some( - string - .chars() - .skip(max(string.chars().count() as i64 - n, 0) as usize) - .collect::(), - ), - }, - _ => None, - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) -} - -/// Extends the string to length 'length' by appending the characters fill (a space by default). If the string is already longer than length then it is truncated. -/// rpad('hi', 5, 'xy') = 'hixyx' -pub fn rpad(args: &[ArrayRef]) -> Result { - match args.len() { - 2 => { - let string_array = as_generic_string_array::(&args[0])?; - let length_array = as_int64_array(&args[1])?; - - let result = string_array - .iter() - .zip(length_array.iter()) - .map(|(string, length)| match (string, length) { - (Some(string), Some(length)) => { - if length > i32::MAX as i64 { - return exec_err!( - "rpad requested length {length} too large" - ); - } - - let length = if length < 0 { 0 } else { length as usize }; - if length == 0 { - Ok(Some("".to_string())) - } else { - let graphemes = string.graphemes(true).collect::>(); - if length < graphemes.len() { - Ok(Some(graphemes[..length].concat())) - } else { - let mut s = string.to_string(); - s.push_str(" ".repeat(length - graphemes.len()).as_str()); - Ok(Some(s)) - } - } - } - _ => Ok(None), - }) - .collect::>>()?; - Ok(Arc::new(result) as ArrayRef) - } - 3 => { - let string_array = as_generic_string_array::(&args[0])?; - let length_array = as_int64_array(&args[1])?; - let fill_array = 
as_generic_string_array::(&args[2])?; - - let result = string_array - .iter() - .zip(length_array.iter()) - .zip(fill_array.iter()) - .map(|((string, length), fill)| match (string, length, fill) { - (Some(string), Some(length), Some(fill)) => { - if length > i32::MAX as i64 { - return exec_err!( - "rpad requested length {length} too large" - ); - } - - let length = if length < 0 { 0 } else { length as usize }; - let graphemes = string.graphemes(true).collect::>(); - let fill_chars = fill.chars().collect::>(); - - if length < graphemes.len() { - Ok(Some(graphemes[..length].concat())) - } else if fill_chars.is_empty() { - Ok(Some(string.to_string())) - } else { - let mut s = string.to_string(); - let mut char_vector = - Vec::::with_capacity(length - graphemes.len()); - for l in 0..length - graphemes.len() { - char_vector - .push(*fill_chars.get(l % fill_chars.len()).unwrap()); - } - s.push_str(char_vector.iter().collect::().as_str()); - Ok(Some(s)) - } - } - _ => Ok(None), - }) - .collect::>>()?; - - Ok(Arc::new(result) as ArrayRef) - } - other => exec_err!( - "rpad was called with {other} arguments. It requires at least 2 and at most 3." - ), - } -} - /// Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.) 
/// strpos('high', 'ig') = 2 /// The implementation uses UTF-8 code points as characters diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 766ca6633ee1..6319372d98d2 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -572,8 +572,8 @@ enum ScalarFunction { // 28 was DatePart // 29 was DateTrunc InitCap = 30; - Left = 31; - Lpad = 32; + // 31 was Left + // 32 was Lpad // 33 was Lower // 34 was Ltrim // 35 was MD5 @@ -583,9 +583,9 @@ enum ScalarFunction { // 39 was RegexpReplace // 40 was Repeat // 41 was Replace - Reverse = 42; - Right = 43; - Rpad = 44; + // 42 was Reverse + // 43 was Right + // 44 was Rpad // 45 was Rtrim // 46 was SHA224 // 47 was SHA256 diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index f2814956ef1b..7281bc9dc263 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22931,12 +22931,7 @@ impl serde::Serialize for ScalarFunction { Self::Concat => "Concat", Self::ConcatWithSeparator => "ConcatWithSeparator", Self::InitCap => "InitCap", - Self::Left => "Left", - Self::Lpad => "Lpad", Self::Random => "Random", - Self::Reverse => "Reverse", - Self::Right => "Right", - Self::Rpad => "Rpad", Self::Strpos => "Strpos", Self::Substr => "Substr", Self::Translate => "Translate", @@ -22990,12 +22985,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Concat", "ConcatWithSeparator", "InitCap", - "Left", - "Lpad", "Random", - "Reverse", - "Right", - "Rpad", "Strpos", "Substr", "Translate", @@ -23078,12 +23068,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Concat" => Ok(ScalarFunction::Concat), "ConcatWithSeparator" => Ok(ScalarFunction::ConcatWithSeparator), "InitCap" => Ok(ScalarFunction::InitCap), - "Left" => Ok(ScalarFunction::Left), - "Lpad" => Ok(ScalarFunction::Lpad), "Random" => Ok(ScalarFunction::Random), - "Reverse" => 
Ok(ScalarFunction::Reverse), - "Right" => Ok(ScalarFunction::Right), - "Rpad" => Ok(ScalarFunction::Rpad), "Strpos" => Ok(ScalarFunction::Strpos), "Substr" => Ok(ScalarFunction::Substr), "Translate" => Ok(ScalarFunction::Translate), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index ecc94fcdaf99..2fe89efb9cea 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2871,8 +2871,8 @@ pub enum ScalarFunction { /// 28 was DatePart /// 29 was DateTrunc InitCap = 30, - Left = 31, - Lpad = 32, + /// 31 was Left + /// 32 was Lpad /// 33 was Lower /// 34 was Ltrim /// 35 was MD5 @@ -2882,9 +2882,9 @@ pub enum ScalarFunction { /// 39 was RegexpReplace /// 40 was Repeat /// 41 was Replace - Reverse = 42, - Right = 43, - Rpad = 44, + /// 42 was Reverse + /// 43 was Right + /// 44 was Rpad /// 45 was Rtrim /// 46 was SHA224 /// 47 was SHA256 @@ -3004,12 +3004,7 @@ impl ScalarFunction { ScalarFunction::Concat => "Concat", ScalarFunction::ConcatWithSeparator => "ConcatWithSeparator", ScalarFunction::InitCap => "InitCap", - ScalarFunction::Left => "Left", - ScalarFunction::Lpad => "Lpad", ScalarFunction::Random => "Random", - ScalarFunction::Reverse => "Reverse", - ScalarFunction::Right => "Right", - ScalarFunction::Rpad => "Rpad", ScalarFunction::Strpos => "Strpos", ScalarFunction::Substr => "Substr", ScalarFunction::Translate => "Translate", @@ -3057,12 +3052,7 @@ impl ScalarFunction { "Concat" => Some(Self::Concat), "ConcatWithSeparator" => Some(Self::ConcatWithSeparator), "InitCap" => Some(Self::InitCap), - "Left" => Some(Self::Left), - "Lpad" => Some(Self::Lpad), "Random" => Some(Self::Random), - "Reverse" => Some(Self::Reverse), - "Right" => Some(Self::Right), - "Rpad" => Some(Self::Rpad), "Strpos" => Some(Self::Strpos), "Substr" => Some(Self::Substr), "Translate" => Some(Self::Translate), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs 
b/datafusion/proto/src/logical_plan/from_proto.rs index 19edd71a3a80..2c6f2e479b24 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -17,18 +17,6 @@ use std::sync::Arc; -use crate::protobuf::{ - self, - plan_type::PlanTypeEnum::{ - AnalyzedLogicalPlan, FinalAnalyzedLogicalPlan, FinalLogicalPlan, - FinalPhysicalPlan, FinalPhysicalPlanWithStats, InitialLogicalPlan, - InitialPhysicalPlan, InitialPhysicalPlanWithStats, OptimizedLogicalPlan, - OptimizedPhysicalPlan, - }, - AnalyzedLogicalPlanType, CubeNode, GroupingSetNode, OptimizedLogicalPlanType, - OptimizedPhysicalPlanType, PlaceholderNode, RollupNode, -}; - use arrow::{ array::AsArray, buffer::Buffer, @@ -38,6 +26,7 @@ use arrow::{ }, ipc::{reader::read_record_batch, root_as_message}, }; + use datafusion::execution::registry::FunctionRegistry; use datafusion_common::{ arrow_datafusion_err, internal_err, plan_datafusion_err, Column, Constraint, @@ -51,17 +40,29 @@ use datafusion_expr::{ acosh, asinh, atan, atan2, atanh, cbrt, ceil, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, degrees, ends_with, exp, expr::{self, InList, Sort, WindowFunction}, - factorial, find_in_set, floor, gcd, initcap, iszero, lcm, left, ln, log, log10, log2, + factorial, find_in_set, floor, gcd, initcap, iszero, lcm, ln, log, log10, log2, logical_plan::{PlanType, StringifiedPlan}, - lpad, nanvl, pi, power, radians, random, reverse, right, round, rpad, signum, sin, - sinh, sqrt, strpos, substr, substr_index, substring, translate, trunc, - AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, - Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, + nanvl, pi, power, radians, random, round, signum, sin, sinh, sqrt, strpos, substr, + substr_index, substring, translate, trunc, AggregateFunction, Between, BinaryExpr, + BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, + GetIndexedField, GroupingSet, 
GroupingSet::GroupingSets, JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, }; +use crate::protobuf::{ + self, + plan_type::PlanTypeEnum::{ + AnalyzedLogicalPlan, FinalAnalyzedLogicalPlan, FinalLogicalPlan, + FinalPhysicalPlan, FinalPhysicalPlanWithStats, InitialLogicalPlan, + InitialPhysicalPlan, InitialPhysicalPlanWithStats, OptimizedLogicalPlan, + OptimizedPhysicalPlan, + }, + AnalyzedLogicalPlanType, CubeNode, GroupingSetNode, OptimizedLogicalPlanType, + OptimizedPhysicalPlanType, PlaceholderNode, RollupNode, +}; + use super::LogicalExtensionCodec; #[derive(Debug)] @@ -453,12 +454,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::ConcatWithSeparator => Self::ConcatWithSeparator, ScalarFunction::EndsWith => Self::EndsWith, ScalarFunction::InitCap => Self::InitCap, - ScalarFunction::Left => Self::Left, - ScalarFunction::Lpad => Self::Lpad, ScalarFunction::Random => Self::Random, - ScalarFunction::Reverse => Self::Reverse, - ScalarFunction::Right => Self::Right, - ScalarFunction::Rpad => Self::Rpad, ScalarFunction::Strpos => Self::Strpos, ScalarFunction::Substr => Self::Substr, ScalarFunction::Translate => Self::Translate, @@ -1382,26 +1378,13 @@ pub fn parse_expr( parse_expr(&args[0], registry, codec)?, parse_expr(&args[1], registry, codec)?, )), - ScalarFunction::Left => Ok(left( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - )), ScalarFunction::Random => Ok(random()), - ScalarFunction::Reverse => { - Ok(reverse(parse_expr(&args[0], registry, codec)?)) - } - ScalarFunction::Right => Ok(right( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - )), ScalarFunction::Concat => { Ok(concat_expr(parse_exprs(args, registry, codec)?)) } ScalarFunction::ConcatWithSeparator => { Ok(concat_ws_expr(parse_exprs(args, registry, codec)?)) } - ScalarFunction::Lpad => Ok(lpad(parse_exprs(args, registry, codec)?)), - 
ScalarFunction::Rpad => Ok(rpad(parse_exprs(args, registry, codec)?)), ScalarFunction::EndsWith => Ok(ends_with( parse_expr(&args[0], registry, codec)?, parse_expr(&args[1], registry, codec)?, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 11fc7362c75d..ea682a5a22f8 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1445,12 +1445,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::ConcatWithSeparator => Self::ConcatWithSeparator, BuiltinScalarFunction::EndsWith => Self::EndsWith, BuiltinScalarFunction::InitCap => Self::InitCap, - BuiltinScalarFunction::Left => Self::Left, - BuiltinScalarFunction::Lpad => Self::Lpad, BuiltinScalarFunction::Random => Self::Random, - BuiltinScalarFunction::Reverse => Self::Reverse, - BuiltinScalarFunction::Right => Self::Right, - BuiltinScalarFunction::Rpad => Self::Rpad, BuiltinScalarFunction::Strpos => Self::Strpos, BuiltinScalarFunction::Substr => Self::Substr, BuiltinScalarFunction::Translate => Self::Translate, From d3fac7bb49c13338019f1cc6ba5c9a77c3244372 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Thu, 28 Mar 2024 12:56:07 -0400 Subject: [PATCH 6/7] move strpos, substr functions to datafusion_functions --- datafusion/expr/src/built_in_function.rs | 36 +- datafusion/expr/src/expr_fn.rs | 6 - datafusion/functions/src/unicode/mod.rs | 31 ++ datafusion/functions/src/unicode/strpos.rs | 121 ++++++ datafusion/functions/src/unicode/substr.rs | 411 ++++++++++++++++++ datafusion/physical-expr/src/functions.rs | 258 +---------- .../physical-expr/src/unicode_expressions.rs | 95 ---- datafusion/proto/Cargo.toml | 1 + datafusion/proto/proto/datafusion.proto | 4 +- datafusion/proto/src/generated/pbjson.rs | 6 - datafusion/proto/src/generated/prost.rs | 8 +- .../proto/src/logical_plan/from_proto.rs | 29 +- datafusion/proto/src/logical_plan/to_proto.rs | 2 - 
.../tests/cases/roundtrip_logical_plan.rs | 29 +- datafusion/proto/tests/cases/serialize.rs | 5 +- datafusion/sql/src/expr/mod.rs | 9 +- datafusion/sql/src/expr/substring.rs | 16 +- datafusion/sqllogictest/test_files/scalar.slt | 2 +- 18 files changed, 617 insertions(+), 452 deletions(-) create mode 100644 datafusion/functions/src/unicode/strpos.rs create mode 100644 datafusion/functions/src/unicode/substr.rs diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 196d278dc70e..423fc11c1d8c 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -113,10 +113,6 @@ pub enum BuiltinScalarFunction { InitCap, /// random Random, - /// strpos - Strpos, - /// substr - Substr, /// translate Translate, /// substr_index @@ -211,8 +207,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::EndsWith => Volatility::Immutable, BuiltinScalarFunction::InitCap => Volatility::Immutable, BuiltinScalarFunction::Radians => Volatility::Immutable, - BuiltinScalarFunction::Strpos => Volatility::Immutable, - BuiltinScalarFunction::Substr => Volatility::Immutable, BuiltinScalarFunction::Translate => Volatility::Immutable, BuiltinScalarFunction::SubstrIndex => Volatility::Immutable, BuiltinScalarFunction::FindInSet => Volatility::Immutable, @@ -252,12 +246,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Pi => Ok(Float64), BuiltinScalarFunction::Random => Ok(Float64), BuiltinScalarFunction::EndsWith => Ok(Boolean), - BuiltinScalarFunction::Strpos => { - utf8_to_int_type(&input_expr_types[0], "strpos/instr/position") - } - BuiltinScalarFunction::Substr => { - utf8_to_str_type(&input_expr_types[0], "substr") - } BuiltinScalarFunction::SubstrIndex => { utf8_to_str_type(&input_expr_types[0], "substr_index") } @@ -341,24 +329,12 @@ impl BuiltinScalarFunction { Signature::uniform(1, vec![Utf8, LargeUtf8], self.volatility()) } - BuiltinScalarFunction::EndsWith | BuiltinScalarFunction::Strpos => { 
- Signature::one_of( - vec![ - Exact(vec![Utf8, Utf8]), - Exact(vec![Utf8, LargeUtf8]), - Exact(vec![LargeUtf8, Utf8]), - Exact(vec![LargeUtf8, LargeUtf8]), - ], - self.volatility(), - ) - } - - BuiltinScalarFunction::Substr => Signature::one_of( + BuiltinScalarFunction::EndsWith => Signature::one_of( vec![ - Exact(vec![Utf8, Int64]), - Exact(vec![LargeUtf8, Int64]), - Exact(vec![Utf8, Int64, Int64]), - Exact(vec![LargeUtf8, Int64, Int64]), + Exact(vec![Utf8, Utf8]), + Exact(vec![Utf8, LargeUtf8]), + Exact(vec![LargeUtf8, Utf8]), + Exact(vec![LargeUtf8, LargeUtf8]), ], self.volatility(), ), @@ -537,8 +513,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ConcatWithSeparator => &["concat_ws"], BuiltinScalarFunction::EndsWith => &["ends_with"], BuiltinScalarFunction::InitCap => &["initcap"], - BuiltinScalarFunction::Strpos => &["strpos", "instr", "position"], - BuiltinScalarFunction::Substr => &["substr"], BuiltinScalarFunction::Translate => &["translate"], BuiltinScalarFunction::SubstrIndex => &["substr_index", "substring_index"], BuiltinScalarFunction::FindInSet => &["find_in_set"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 21dab72855e5..09170ae639ff 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -579,9 +579,6 @@ scalar_expr!(Log, log, base x, "logarithm of a `x` for a particular `base`"); scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase"); scalar_expr!(EndsWith, ends_with, string suffix, "whether the `string` ends with the `suffix`"); -scalar_expr!(Strpos, strpos, string substring, "finds the position from where the `substring` matches the `string`"); -scalar_expr!(Substr, substr, string position, "substring from the `position` to the end"); -scalar_expr!(Substr, substring, string position length, "substring from the `position` with `length` characters"); scalar_expr!(Translate, 
translate, string from to, "replaces the characters in `from` with the counterpart in `to`"); nary_scalar_expr!(Coalesce, coalesce, "returns `coalesce(args...)`, which evaluates to the value of the first [Expr] which is not NULL"); //there is a func concat_ws before, so use concat_ws_expr as name.c @@ -1015,9 +1012,6 @@ mod test { test_scalar_expr!(Lcm, lcm, arg_1, arg_2); test_scalar_expr!(InitCap, initcap, string); test_scalar_expr!(EndsWith, ends_with, string, characters); - test_scalar_expr!(Strpos, strpos, string, substring); - test_scalar_expr!(Substr, substr, string, position); - test_scalar_expr!(Substr, substring, string, position, count); test_scalar_expr!(Translate, translate, string, from, to); test_scalar_expr!(SubstrIndex, substr_index, string, delimiter, count); test_scalar_expr!(FindInSet, find_in_set, string, stringlist); diff --git a/datafusion/functions/src/unicode/mod.rs b/datafusion/functions/src/unicode/mod.rs index ea4e70a92199..ddab0d1e27c9 100644 --- a/datafusion/functions/src/unicode/mod.rs +++ b/datafusion/functions/src/unicode/mod.rs @@ -27,6 +27,8 @@ mod lpad; mod reverse; mod right; mod rpad; +mod strpos; +mod substr; // create UDFs make_udf_function!( @@ -39,6 +41,8 @@ make_udf_function!(lpad::LPadFunc, LPAD, lpad); make_udf_function!(right::RightFunc, RIGHT, right); make_udf_function!(reverse::ReverseFunc, REVERSE, reverse); make_udf_function!(rpad::RPadFunc, RPAD, rpad); +make_udf_function!(strpos::StrposFunc, STRPOS, strpos); +make_udf_function!(substr::SubstrFunc, SUBSTR, substr); pub mod expr_fn { use datafusion_expr::Expr; @@ -53,6 +57,11 @@ pub mod expr_fn { super::character_length().call(vec![string]) } + #[doc = "finds the position from where the `substring` matches the `string`"] + pub fn instr(string: Expr, substring: Expr) -> Expr { + strpos(string, substring) + } + #[doc = "the number of characters in the `string`"] pub fn length(string: Expr) -> Expr { character_length(string) @@ -68,6 +77,11 @@ pub mod expr_fn { 
super::lpad().call(args) } + #[doc = "finds the position from where the `substring` matches the `string`"] + pub fn position(string: Expr, substring: Expr) -> Expr { + strpos(string, substring) + } + #[doc = "reverses the `string`"] pub fn reverse(string: Expr) -> Expr { super::reverse().call(vec![string]) @@ -82,6 +96,21 @@ pub mod expr_fn { pub fn rpad(args: Vec) -> Expr { super::rpad().call(args) } + + #[doc = "finds the position from where the `substring` matches the `string`"] + pub fn strpos(string: Expr, substring: Expr) -> Expr { + super::strpos().call(vec![string, substring]) + } + + #[doc = "substring from the `position` to the end"] + pub fn substr(string: Expr, position: Expr) -> Expr { + super::substr().call(vec![string, position]) + } + + #[doc = "substring from the `position` with `length` characters"] + pub fn substring(string: Expr, position: Expr, length: Expr) -> Expr { + super::substr().call(vec![string, position, length]) + } } /// Return a list of all functions in this package @@ -93,5 +122,7 @@ pub fn functions() -> Vec> { reverse(), right(), rpad(), + strpos(), + substr(), ] } diff --git a/datafusion/functions/src/unicode/strpos.rs b/datafusion/functions/src/unicode/strpos.rs new file mode 100644 index 000000000000..1e8bfa37d40e --- /dev/null +++ b/datafusion/functions/src/unicode/strpos.rs @@ -0,0 +1,121 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ + ArrayRef, ArrowPrimitiveType, GenericStringArray, OffsetSizeTrait, PrimitiveArray, +}; +use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; + +use datafusion_common::cast::as_generic_string_array; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + +use crate::utils::{make_scalar_function, utf8_to_int_type}; + +#[derive(Debug)] +pub(super) struct StrposFunc { + signature: Signature, + aliases: Vec, +} + +impl StrposFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8]), + Exact(vec![Utf8, LargeUtf8]), + Exact(vec![LargeUtf8, Utf8]), + Exact(vec![LargeUtf8, LargeUtf8]), + ], + Volatility::Immutable, + ), + aliases: vec![String::from("instr"), String::from("position")], + } + } +} + +impl ScalarUDFImpl for StrposFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "strpos" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_int_type(&arg_types[0], "strpos/instr/position") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(strpos::, vec![])(args), + DataType::LargeUtf8 => { + make_scalar_function(strpos::, vec![])(args) + } + other => exec_err!("Unsupported data type {other:?} for function 
strpos"), + } + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.) +/// strpos('high', 'ig') = 2 +/// The implementation uses UTF-8 code points as characters +fn strpos(args: &[ArrayRef]) -> Result +where + T::Native: OffsetSizeTrait, +{ + let string_array: &GenericStringArray = + as_generic_string_array::(&args[0])?; + + let substring_array: &GenericStringArray = + as_generic_string_array::(&args[1])?; + + let result = string_array + .iter() + .zip(substring_array.iter()) + .map(|(string, substring)| match (string, substring) { + (Some(string), Some(substring)) => { + // the find method returns the byte index of the substring + // Next, we count the number of the chars until that byte + T::Native::from_usize( + string + .find(substring) + .map(|x| string[..x].chars().count() + 1) + .unwrap_or(0), + ) + } + _ => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} diff --git a/datafusion/functions/src/unicode/substr.rs b/datafusion/functions/src/unicode/substr.rs new file mode 100644 index 000000000000..7afe8204768a --- /dev/null +++ b/datafusion/functions/src/unicode/substr.rs @@ -0,0 +1,411 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::cmp::max; +use std::sync::Arc; + +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; + +use datafusion_common::cast::{as_generic_string_array, as_int64_array}; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + +use crate::utils::{make_scalar_function, utf8_to_str_type}; + +#[derive(Debug)] +pub(super) struct SubstrFunc { + signature: Signature, +} + +impl SubstrFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![ + Exact(vec![Utf8, Int64]), + Exact(vec![LargeUtf8, Int64]), + Exact(vec![Utf8, Int64, Int64]), + Exact(vec![LargeUtf8, Int64, Int64]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SubstrFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "substr" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "substr") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(substr::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(substr::, vec![])(args), + other => exec_err!("Unsupported data type {other:?} for function substr"), + } + } +} + +/// Extracts the substring of string starting at the start'th character, and extending for count characters if that 
is specified. (Same as substring(string from start for count).) +/// substr('alphabet', 3) = 'phabet' +/// substr('alphabet', 3, 2) = 'ph' +/// The implementation uses UTF-8 code points as characters +pub fn substr(args: &[ArrayRef]) -> Result { + match args.len() { + 2 => { + let string_array = as_generic_string_array::(&args[0])?; + let start_array = as_int64_array(&args[1])?; + + let result = string_array + .iter() + .zip(start_array.iter()) + .map(|(string, start)| match (string, start) { + (Some(string), Some(start)) => { + if start <= 0 { + Some(string.to_string()) + } else { + Some(string.chars().skip(start as usize - 1).collect()) + } + } + _ => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + 3 => { + let string_array = as_generic_string_array::(&args[0])?; + let start_array = as_int64_array(&args[1])?; + let count_array = as_int64_array(&args[2])?; + + let result = string_array + .iter() + .zip(start_array.iter()) + .zip(count_array.iter()) + .map(|((string, start), count)| match (string, start, count) { + (Some(string), Some(start), Some(count)) => { + if count < 0 { + exec_err!( + "negative substring length not allowed: substr(, {start}, {count})" + ) + } else { + let skip = max(0, start - 1); + let count = max(0, count + (if start < 1 {start - 1} else {0})); + Ok(Some(string.chars().skip(skip as usize).take(count as usize).collect::())) + } + } + _ => Ok(None), + }) + .collect::>>()?; + + Ok(Arc::new(result) as ArrayRef) + } + other => { + exec_err!("substr was called with {other} arguments. 
It requires 2 or 3.") + } + } +} + +#[cfg(test)] +mod tests { + use arrow::array::{Array, StringArray}; + use arrow::datatypes::DataType::Utf8; + + use datafusion_common::{exec_err, Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::unicode::substr::SubstrFunc; + use crate::utils::test::test_function; + + #[test] + fn test_functions() -> Result<()> { + #[cfg(feature = "unicode_expressions")] + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(0))), + ], + Ok(Some("alphabet")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("joséésoj")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ], + Ok(Some("ésoj")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("joséésoj")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(-5))), + ], + Ok(Some("joséésoj")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), + ], + Ok(Some("alphabet")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), + ], + Ok(Some("lphabet")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + 
ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), + ], + Ok(Some("phabet")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(-3))), + ], + Ok(Some("alphabet")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(30))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), + ], + Ok(Some("ph")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(20))), + ], + Ok(Some("phabet")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(0))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ], + Ok(Some("alph")), + &str, + Utf8, + StringArray + ); + // starting from 5 (10 + -5) + #[cfg(feature = "unicode_expressions")] + 
test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(-5))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(10))), + ], + Ok(Some("alph")), + &str, + Utf8, + StringArray + ); + // starting from -1 (4 + -5) + #[cfg(feature = "unicode_expressions")] + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(-5))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + // starting from 0 (5 + -5) + #[cfg(feature = "unicode_expressions")] + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(-5))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::Int64(None)), + ColumnarValue::Scalar(ScalarValue::Int64(Some(20))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), + ColumnarValue::Scalar(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(-1))), + ], + exec_err!("negative substring length not allowed: substr(, 1, -1)"), + &str, + 
Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("joséésoj")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), + ], + Ok(Some("és")), + &str, + Utf8, + StringArray + ); + #[cfg(not(feature = "unicode_expressions"))] + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(0))), + ], + internal_err!( + "function substr requires compilation with feature flag: unicode_expressions." + ), + &str, + Utf8, + StringArray + ); + + Ok(()) + } +} diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index c1b4900e399a..513dd71d4074 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -281,34 +281,6 @@ pub fn create_physical_fun( exec_err!("Unsupported data type {other:?} for function ends_with") } }), - BuiltinScalarFunction::Strpos => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - strpos, Int32Type, "strpos" - ); - make_scalar_function_inner(func)(args) - } - DataType::LargeUtf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - strpos, Int64Type, "strpos" - ); - make_scalar_function_inner(func)(args) - } - other => exec_err!("Unsupported data type {other:?} for function strpos"), - }), - BuiltinScalarFunction::Substr => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = - invoke_if_unicode_expressions_feature_flag!(substr, i32, "substr"); - make_scalar_function_inner(func)(args) - } - DataType::LargeUtf8 => { - let func = - invoke_if_unicode_expressions_feature_flag!(substr, i64, "substr"); - make_scalar_function_inner(func)(args) - } - other => exec_err!("Unsupported 
data type {other:?} for function substr"), - }), BuiltinScalarFunction::Translate => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!( @@ -450,7 +422,7 @@ mod tests { }; use datafusion_common::cast::as_uint64_array; - use datafusion_common::{exec_err, internal_err, plan_err}; + use datafusion_common::{internal_err, plan_err}; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::type_coercion::functions::data_types; use datafusion_expr::Signature; @@ -663,234 +635,6 @@ mod tests { BooleanArray ); #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[lit("alphabet"), lit(ScalarValue::Int64(Some(0))),], - Ok(Some("alphabet")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[lit("joséésoj"), lit(ScalarValue::Int64(Some(5))),], - Ok(Some("ésoj")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[lit("joséésoj"), lit(ScalarValue::Int64(Some(-5))),], - Ok(Some("joséésoj")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[lit("alphabet"), lit(ScalarValue::Int64(Some(1))),], - Ok(Some("alphabet")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[lit("alphabet"), lit(ScalarValue::Int64(Some(2))),], - Ok(Some("lphabet")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[lit("alphabet"), lit(ScalarValue::Int64(Some(3))),], - Ok(Some("phabet")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[lit("alphabet"), lit(ScalarValue::Int64(Some(-3))),], - Ok(Some("alphabet")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[lit("alphabet"), 
lit(ScalarValue::Int64(Some(30))),], - Ok(Some("")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[lit("alphabet"), lit(ScalarValue::Int64(None)),], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[ - lit("alphabet"), - lit(ScalarValue::Int64(Some(3))), - lit(ScalarValue::Int64(Some(2))), - ], - Ok(Some("ph")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[ - lit("alphabet"), - lit(ScalarValue::Int64(Some(3))), - lit(ScalarValue::Int64(Some(20))), - ], - Ok(Some("phabet")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[ - lit("alphabet"), - lit(ScalarValue::Int64(Some(0))), - lit(ScalarValue::Int64(Some(5))), - ], - Ok(Some("alph")), - &str, - Utf8, - StringArray - ); - // starting from 5 (10 + -5) - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[ - lit("alphabet"), - lit(ScalarValue::Int64(Some(-5))), - lit(ScalarValue::Int64(Some(10))), - ], - Ok(Some("alph")), - &str, - Utf8, - StringArray - ); - // starting from -1 (4 + -5) - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[ - lit("alphabet"), - lit(ScalarValue::Int64(Some(-5))), - lit(ScalarValue::Int64(Some(4))), - ], - Ok(Some("")), - &str, - Utf8, - StringArray - ); - // starting from 0 (5 + -5) - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[ - lit("alphabet"), - lit(ScalarValue::Int64(Some(-5))), - lit(ScalarValue::Int64(Some(5))), - ], - Ok(Some("")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[ - lit("alphabet"), - lit(ScalarValue::Int64(None)), - lit(ScalarValue::Int64(Some(20))), - ], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[ - 
lit("alphabet"), - lit(ScalarValue::Int64(Some(3))), - lit(ScalarValue::Int64(None)), - ], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[ - lit("alphabet"), - lit(ScalarValue::Int64(Some(1))), - lit(ScalarValue::Int64(Some(-1))), - ], - exec_err!("negative substring length not allowed: substr(, 1, -1)"), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[ - lit("joséésoj"), - lit(ScalarValue::Int64(Some(5))), - lit(ScalarValue::Int64(Some(2))), - ], - Ok(Some("és")), - &str, - Utf8, - StringArray - ); - #[cfg(not(feature = "unicode_expressions"))] - test_function!( - Substr, - &[ - lit("alphabet"), - lit(ScalarValue::Int64(Some(0))), - ], - internal_err!( - "function substr requires compilation with feature flag: unicode_expressions." - ), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] test_function!( Translate, &[lit("12345"), lit("143"), lit("ax"),], diff --git a/datafusion/physical-expr/src/unicode_expressions.rs b/datafusion/physical-expr/src/unicode_expressions.rs index faff21111a61..ecbd1ea320d4 100644 --- a/datafusion/physical-expr/src/unicode_expressions.rs +++ b/datafusion/physical-expr/src/unicode_expressions.rs @@ -21,7 +21,6 @@ //! Unicode expressions -use std::cmp::max; use std::sync::Arc; use arrow::{ @@ -36,100 +35,6 @@ use datafusion_common::{ exec_err, Result, }; -/// Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.) 
-/// strpos('high', 'ig') = 2 -/// The implementation uses UTF-8 code points as characters -pub fn strpos(args: &[ArrayRef]) -> Result -where - T::Native: OffsetSizeTrait, -{ - let string_array: &GenericStringArray = - as_generic_string_array::(&args[0])?; - - let substring_array: &GenericStringArray = - as_generic_string_array::(&args[1])?; - - let result = string_array - .iter() - .zip(substring_array.iter()) - .map(|(string, substring)| match (string, substring) { - (Some(string), Some(substring)) => { - // the find method returns the byte index of the substring - // Next, we count the number of the chars until that byte - T::Native::from_usize( - string - .find(substring) - .map(|x| string[..x].chars().count() + 1) - .unwrap_or(0), - ) - } - _ => None, - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) -} - -/// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).) 
-/// substr('alphabet', 3) = 'phabet' -/// substr('alphabet', 3, 2) = 'ph' -/// The implementation uses UTF-8 code points as characters -pub fn substr(args: &[ArrayRef]) -> Result { - match args.len() { - 2 => { - let string_array = as_generic_string_array::(&args[0])?; - let start_array = as_int64_array(&args[1])?; - - let result = string_array - .iter() - .zip(start_array.iter()) - .map(|(string, start)| match (string, start) { - (Some(string), Some(start)) => { - if start <= 0 { - Some(string.to_string()) - } else { - Some(string.chars().skip(start as usize - 1).collect()) - } - } - _ => None, - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) - } - 3 => { - let string_array = as_generic_string_array::(&args[0])?; - let start_array = as_int64_array(&args[1])?; - let count_array = as_int64_array(&args[2])?; - - let result = string_array - .iter() - .zip(start_array.iter()) - .zip(count_array.iter()) - .map(|((string, start), count)| match (string, start, count) { - (Some(string), Some(start), Some(count)) => { - if count < 0 { - exec_err!( - "negative substring length not allowed: substr(, {start}, {count})" - ) - } else { - let skip = max(0, start - 1); - let count = max(0, count + (if start < 1 {start - 1} else {0})); - Ok(Some(string.chars().skip(skip as usize).take(count as usize).collect::())) - } - } - _ => Ok(None), - }) - .collect::>>()?; - - Ok(Arc::new(result) as ArrayRef) - } - other => { - exec_err!("substr was called with {other} arguments. It requires 2 or 3.") - } - } -} - /// Replaces each character in string that matches a character in the from set with the corresponding character in the to set. If from is longer than to, occurrences of the extra characters in from are deleted. 
/// translate('12345', '143', 'ax') = 'a2x5' pub fn translate(args: &[ArrayRef]) -> Result { diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index f5297aefcd1c..bec2b8c53a7a 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -54,6 +54,7 @@ serde = { version = "1.0", optional = true } serde_json = { workspace = true, optional = true } [dev-dependencies] +datafusion-functions = { workspace = true, default-features = true } doc-comment = { workspace = true } strum = { version = "0.26.1", features = ["derive"] } tokio = { workspace = true, features = ["rt-multi-thread"] } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 6319372d98d2..3a187eabe836 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -593,8 +593,8 @@ enum ScalarFunction { // 49 was SHA512 // 50 was SplitPart // StartsWith = 51; - Strpos = 52; - Substr = 53; + // 52 was Strpos + // 53 was Substr // ToHex = 54; // 55 was ToTimestamp // 56 was ToTimestampMillis diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 7281bc9dc263..07b91b26d60b 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22932,8 +22932,6 @@ impl serde::Serialize for ScalarFunction { Self::ConcatWithSeparator => "ConcatWithSeparator", Self::InitCap => "InitCap", Self::Random => "Random", - Self::Strpos => "Strpos", - Self::Substr => "Substr", Self::Translate => "Translate", Self::Coalesce => "Coalesce", Self::Power => "Power", @@ -22986,8 +22984,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ConcatWithSeparator", "InitCap", "Random", - "Strpos", - "Substr", "Translate", "Coalesce", "Power", @@ -23069,8 +23065,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ConcatWithSeparator" => Ok(ScalarFunction::ConcatWithSeparator), "InitCap" => Ok(ScalarFunction::InitCap), 
"Random" => Ok(ScalarFunction::Random), - "Strpos" => Ok(ScalarFunction::Strpos), - "Substr" => Ok(ScalarFunction::Substr), "Translate" => Ok(ScalarFunction::Translate), "Coalesce" => Ok(ScalarFunction::Coalesce), "Power" => Ok(ScalarFunction::Power), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 2fe89efb9cea..babeccec595f 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2892,8 +2892,8 @@ pub enum ScalarFunction { /// 49 was SHA512 /// 50 was SplitPart /// StartsWith = 51; - Strpos = 52, - Substr = 53, + /// 52 was Strpos + /// 53 was Substr /// ToHex = 54; /// 55 was ToTimestamp /// 56 was ToTimestampMillis @@ -3005,8 +3005,6 @@ impl ScalarFunction { ScalarFunction::ConcatWithSeparator => "ConcatWithSeparator", ScalarFunction::InitCap => "InitCap", ScalarFunction::Random => "Random", - ScalarFunction::Strpos => "Strpos", - ScalarFunction::Substr => "Substr", ScalarFunction::Translate => "Translate", ScalarFunction::Coalesce => "Coalesce", ScalarFunction::Power => "Power", @@ -3053,8 +3051,6 @@ impl ScalarFunction { "ConcatWithSeparator" => Some(Self::ConcatWithSeparator), "InitCap" => Some(Self::InitCap), "Random" => Some(Self::Random), - "Strpos" => Some(Self::Strpos), - "Substr" => Some(Self::Substr), "Translate" => Some(Self::Translate), "Coalesce" => Some(Self::Coalesce), "Power" => Some(Self::Power), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 2c6f2e479b24..ff3d6773d512 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -42,10 +42,10 @@ use datafusion_expr::{ expr::{self, InList, Sort, WindowFunction}, factorial, find_in_set, floor, gcd, initcap, iszero, lcm, ln, log, log10, log2, logical_plan::{PlanType, StringifiedPlan}, - nanvl, pi, power, radians, random, round, signum, sin, sinh, sqrt, strpos, substr, - 
substr_index, substring, translate, trunc, AggregateFunction, Between, BinaryExpr, - BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, - GetIndexedField, GroupingSet, + nanvl, pi, power, radians, random, round, signum, sin, sinh, sqrt, substr_index, + translate, trunc, AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, + BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, GetIndexedField, + GroupingSet, GroupingSet::GroupingSets, JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, @@ -455,8 +455,6 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::EndsWith => Self::EndsWith, ScalarFunction::InitCap => Self::InitCap, ScalarFunction::Random => Self::Random, - ScalarFunction::Strpos => Self::Strpos, - ScalarFunction::Substr => Self::Substr, ScalarFunction::Translate => Self::Translate, ScalarFunction::Coalesce => Self::Coalesce, ScalarFunction::Pi => Self::Pi, @@ -1389,25 +1387,6 @@ pub fn parse_expr( parse_expr(&args[0], registry, codec)?, parse_expr(&args[1], registry, codec)?, )), - ScalarFunction::Strpos => Ok(strpos( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - )), - ScalarFunction::Substr => { - if args.len() > 2 { - assert_eq!(args.len(), 3); - Ok(substring( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - parse_expr(&args[2], registry, codec)?, - )) - } else { - Ok(substr( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - )) - } - } ScalarFunction::Translate => Ok(translate( parse_expr(&args[0], registry, codec)?, parse_expr(&args[1], registry, codec)?, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index ea682a5a22f8..89d49c5658a2 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1446,8 +1446,6 @@ 
impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::EndsWith => Self::EndsWith, BuiltinScalarFunction::InitCap => Self::InitCap, BuiltinScalarFunction::Random => Self::Random, - BuiltinScalarFunction::Strpos => Self::Strpos, - BuiltinScalarFunction::Substr => Self::Substr, BuiltinScalarFunction::Translate => Self::Translate, BuiltinScalarFunction::Coalesce => Self::Coalesce, BuiltinScalarFunction::Pi => Self::Pi, diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 3c43f100750f..3a47f556c0f3 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -34,8 +34,8 @@ use datafusion::test_util::{TestTableFactory, TestTableProvider}; use datafusion_common::config::{FormatOptions, TableOptions}; use datafusion_common::scalar::ScalarStructBuilder; use datafusion_common::{ - internal_err, not_impl_err, plan_err, DFField, DFSchema, DFSchemaRef, - DataFusionError, FileType, Result, ScalarValue, + internal_datafusion_err, internal_err, not_impl_err, plan_err, DFField, DFSchema, + DFSchemaRef, DataFusionError, FileType, Result, ScalarValue, }; use datafusion_expr::dml::CopyTo; use datafusion_expr::expr::{ @@ -44,8 +44,7 @@ use datafusion_expr::expr::{ }; use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore}; use datafusion_expr::{ - col, create_udaf, lit, Accumulator, AggregateFunction, - BuiltinScalarFunction::{Sqrt, Substr}, + col, create_udaf, lit, Accumulator, AggregateFunction, BuiltinScalarFunction::Sqrt, ColumnarValue, Expr, ExprSchemable, LogicalPlan, Operator, PartitionEvaluator, ScalarUDF, ScalarUDFImpl, Signature, TryCast, Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF, @@ -60,6 +59,7 @@ use datafusion_proto::logical_plan::LogicalExtensionCodec; use datafusion_proto::logical_plan::{from_proto, 
DefaultLogicalExtensionCodec}; use datafusion_proto::protobuf; +use datafusion::execution::FunctionRegistry; use prost::Message; #[cfg(feature = "json")] @@ -1863,17 +1863,28 @@ fn roundtrip_cube() { #[test] fn roundtrip_substr() { + let ctx = SessionContext::new(); + + let fun = ctx + .state() + .udf("substr") + .map_err(|e| { + internal_datafusion_err!("Unable to find expected 'substr' function: {e:?}") + }) + .unwrap(); + // substr(string, position) - let test_expr = - Expr::ScalarFunction(ScalarFunction::new(Substr, vec![col("col"), lit(1_i64)])); + let test_expr = Expr::ScalarFunction(ScalarFunction::new_udf( + fun.clone(), + vec![col("col"), lit(1_i64)], + )); // substr(string, position, count) - let test_expr_with_count = Expr::ScalarFunction(ScalarFunction::new( - Substr, + let test_expr_with_count = Expr::ScalarFunction(ScalarFunction::new_udf( + fun, vec![col("col"), lit(1_i64), lit(1_i64)], )); - let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx.clone()); roundtrip_expr_test(test_expr_with_count, ctx); } diff --git a/datafusion/proto/tests/cases/serialize.rs b/datafusion/proto/tests/cases/serialize.rs index d4a1ab44a6ea..972382b841d5 100644 --- a/datafusion/proto/tests/cases/serialize.rs +++ b/datafusion/proto/tests/cases/serialize.rs @@ -260,10 +260,7 @@ fn test_expression_serialization_roundtrip() { let lit = Expr::Literal(ScalarValue::Utf8(None)); for builtin_fun in BuiltinScalarFunction::iter() { // default to 4 args (though some exprs like substr have error checking) - let num_args = match builtin_fun { - BuiltinScalarFunction::Substr => 3, - _ => 4, - }; + let num_args = 4; let args: Vec<_> = std::iter::repeat(&lit).take(num_args).cloned().collect(); let expr = Expr::ScalarFunction(ScalarFunction::new(builtin_fun, args)); diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index d1fc03194997..43bf2d871564 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -823,12 +823,17 @@ 
impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { - let fun = BuiltinScalarFunction::Strpos; + let fun = self + .context_provider + .get_function_meta("strpos") + .ok_or_else(|| { + internal_datafusion_err!("Unable to find expected 'strpos' function") + })?; let substr = self.sql_expr_to_logical_expr(substr_expr, schema, planner_context)?; let fullstr = self.sql_expr_to_logical_expr(str_expr, schema, planner_context)?; let args = vec![fullstr, substr]; - Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))) + Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, args))) } fn sql_agg_with_filter_to_expr( &self, diff --git a/datafusion/sql/src/expr/substring.rs b/datafusion/sql/src/expr/substring.rs index a5d1abf0f265..f58c6f3b94d0 100644 --- a/datafusion/sql/src/expr/substring.rs +++ b/datafusion/sql/src/expr/substring.rs @@ -16,10 +16,10 @@ // under the License. use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use datafusion_common::plan_err; +use datafusion_common::{internal_datafusion_err, plan_err}; use datafusion_common::{DFSchema, Result, ScalarValue}; use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::{BuiltinScalarFunction, Expr}; +use datafusion_expr::Expr; use sqlparser::ast::Expr as SQLExpr; impl<'a, S: ContextProvider> SqlToRel<'a, S> { @@ -68,9 +68,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } }; - Ok(Expr::ScalarFunction(ScalarFunction::new( - BuiltinScalarFunction::Substr, - args, - ))) + let fun = self + .context_provider + .get_function_meta("substr") + .ok_or_else(|| { + internal_datafusion_err!("Unable to find expected 'substr' function") + })?; + + Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, args))) } } diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index a77a2bf4059c..20c8b3d25fdd 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ 
b/datafusion/sqllogictest/test_files/scalar.slt @@ -2087,7 +2087,7 @@ select position('' in '') 1 -query error DataFusion error: Error during planning: The STRPOS/INSTR/POSITION function can only accept strings, but got Int64. +query error DataFusion error: Execution error: The STRPOS/INSTR/POSITION function can only accept strings, but got Int64. select position(1 in 1) From 9fce44309615971a892dbba77bf55f2f14e74954 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Thu, 28 Mar 2024 21:21:45 -0400 Subject: [PATCH 7/7] Cleanup tests --- datafusion/functions/src/unicode/substr.rs | 113 +++++++++------------ 1 file changed, 47 insertions(+), 66 deletions(-) diff --git a/datafusion/functions/src/unicode/substr.rs b/datafusion/functions/src/unicode/substr.rs index 7afe8204768a..403157e2a85a 100644 --- a/datafusion/functions/src/unicode/substr.rs +++ b/datafusion/functions/src/unicode/substr.rs @@ -150,107 +150,98 @@ mod tests { #[test] fn test_functions() -> Result<()> { - #[cfg(feature = "unicode_expressions")] test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(0))), + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from(0i64)), ], Ok(Some("alphabet")), &str, Utf8, StringArray ); - #[cfg(feature = "unicode_expressions")] test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("joséésoj")))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ColumnarValue::Scalar(ScalarValue::from("joséésoj")), + ColumnarValue::Scalar(ScalarValue::from(5i64)), ], Ok(Some("ésoj")), &str, Utf8, StringArray ); - #[cfg(feature = "unicode_expressions")] test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("joséésoj")))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(-5))), + ColumnarValue::Scalar(ScalarValue::from("joséésoj")), + 
ColumnarValue::Scalar(ScalarValue::from(-5i64)), ], Ok(Some("joséésoj")), &str, Utf8, StringArray ); - #[cfg(feature = "unicode_expressions")] test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from(1i64)), ], Ok(Some("alphabet")), &str, Utf8, StringArray ); - #[cfg(feature = "unicode_expressions")] test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from(2i64)), ], Ok(Some("lphabet")), &str, Utf8, StringArray ); - #[cfg(feature = "unicode_expressions")] test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from(3i64)), ], Ok(Some("phabet")), &str, Utf8, StringArray ); - #[cfg(feature = "unicode_expressions")] test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(-3))), + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from(-3i64)), ], Ok(Some("alphabet")), &str, Utf8, StringArray ); - #[cfg(feature = "unicode_expressions")] test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(30))), + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from(30i64)), ], Ok(Some("")), &str, Utf8, StringArray ); - #[cfg(feature = "unicode_expressions")] test_function!( 
SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::Int64(None)), ], Ok(None), @@ -258,39 +249,36 @@ mod tests { Utf8, StringArray ); - #[cfg(feature = "unicode_expressions")] test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from(3i64)), + ColumnarValue::Scalar(ScalarValue::from(2i64)), ], Ok(Some("ph")), &str, Utf8, StringArray ); - #[cfg(feature = "unicode_expressions")] test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(20))), + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from(3i64)), + ColumnarValue::Scalar(ScalarValue::from(20i64)), ], Ok(Some("phabet")), &str, Utf8, StringArray ); - #[cfg(feature = "unicode_expressions")] test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(0))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from(0i64)), + ColumnarValue::Scalar(ScalarValue::from(5i64)), ], Ok(Some("alph")), &str, @@ -298,13 +286,12 @@ mod tests { StringArray ); // starting from 5 (10 + -5) - #[cfg(feature = "unicode_expressions")] test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(-5))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(10))), 
+ ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from(-5i64)), + ColumnarValue::Scalar(ScalarValue::from(10i64)), ], Ok(Some("alph")), &str, @@ -312,13 +299,12 @@ mod tests { StringArray ); // starting from -1 (4 + -5) - #[cfg(feature = "unicode_expressions")] test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(-5))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from(-5i64)), + ColumnarValue::Scalar(ScalarValue::from(4i64)), ], Ok(Some("")), &str, @@ -326,38 +312,35 @@ mod tests { StringArray ); // starting from 0 (5 + -5) - #[cfg(feature = "unicode_expressions")] test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(-5))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from(-5i64)), + ColumnarValue::Scalar(ScalarValue::from(5i64)), ], Ok(Some("")), &str, Utf8, StringArray ); - #[cfg(feature = "unicode_expressions")] test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::Int64(None)), - ColumnarValue::Scalar(ScalarValue::Int64(Some(20))), + ColumnarValue::Scalar(ScalarValue::from(20i64)), ], Ok(None), &str, Utf8, StringArray ); - #[cfg(feature = "unicode_expressions")] test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from(3i64)), 
ColumnarValue::Scalar(ScalarValue::Int64(None)), ], Ok(None), @@ -365,26 +348,24 @@ mod tests { Utf8, StringArray ); - #[cfg(feature = "unicode_expressions")] test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(-1))), + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from(1i64)), + ColumnarValue::Scalar(ScalarValue::from(-1i64)), ], exec_err!("negative substring length not allowed: substr(, 1, -1)"), &str, Utf8, StringArray ); - #[cfg(feature = "unicode_expressions")] test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("joséésoj")))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), + ColumnarValue::Scalar(ScalarValue::from("joséésoj")), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ColumnarValue::Scalar(ScalarValue::from(2i64)), ], Ok(Some("és")), &str, @@ -395,8 +376,8 @@ mod tests { test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(0))), + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from(0i64)), ], internal_err!( "function substr requires compilation with feature flag: unicode_expressions."