// Patch summary (commentary only; everything from the first `diff --git` header onward is unchanged).
//
// Files touched:
//   * datafusion/expr/src/expr_fn.rs
//       - Adds `scalar_expr!(Substr, substring, string, position, count)` next to the existing
//         two-argument `scalar_expr!(Substr, substr, string, position)`; both lines name the
//         same `Substr` variant.
//       - Adds the matching `test_scalar_expr!(Substr, substring, string, position, count)` line
//         in `mod test`.
//   * datafusion/proto/src/from_proto.rs
//       - Adds `substring` to the `use datafusion_expr::{...}` list (the list is re-wrapped, which
//         is why several import lines show up as removed/added).
//       - Rewrites the `ScalarFunction::Substr` arm of `parse_expr` to branch on `args.len()`:
//         when it is greater than 2 the arm asserts that it is exactly 3 and calls `substring`
//         with `args[0]`, `args[1]`, `args[2]`; otherwise it calls `substr` with `args[0]` and
//         `args[1]` as before.
//   * datafusion/proto/src/lib.rs
//       - Imports `Substr` alongside `Sqrt` from `BuiltinScalarFunction`.
//       - Adds the `roundtrip_substr` test, which builds `Expr::ScalarFunction { fun: Substr, .. }`
//         once with two arguments and once with three, and passes each to `roundtrip_expr_test`.
//
// NOTE(review): `assert_eq!(args.len(), 3)` panics instead of returning `Err` when a message
//   carries more than three arguments. The other failure paths in this arm propagate errors with
//   `?`, so a returned error would be more consistent. The error constructor used by
//   from_proto.rs is not visible in this patch; confirm it before changing.
// NOTE(review): `&args[0]` / `&args[1]` are indexed without a length check in both branches;
//   presumably this panics when fewer than two arguments are present. The two-argument indexing
//   pre-dates this patch (see the removed lines). TODO confirm whether callers guarantee arity.
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index e0954963fade..9d8b05f2b155 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -403,6 +403,7 @@ scalar_expr!(SplitPart, split_part, expr, delimiter, index); scalar_expr!(StartsWith, starts_with, string, characters); scalar_expr!(Strpos, strpos, string, substring); scalar_expr!(Substr, substr, string, position); +scalar_expr!(Substr, substring, string, position, count); scalar_expr!(ToHex, to_hex, string); scalar_expr!(Translate, translate, string, from, to); scalar_expr!(Trim, trim, string); @@ -656,6 +657,7 @@ mod test { test_scalar_expr!(StartsWith, starts_with, string, characters); test_scalar_expr!(Strpos, strpos, string, substring); test_scalar_expr!(Substr, substr, string, position); + test_scalar_expr!(Substr, substring, string, position, count); test_scalar_expr!(ToHex, to_hex, string); test_scalar_expr!(Translate, translate, string, from, to); test_scalar_expr!(Trim, trim, string); diff --git a/datafusion/proto/src/from_proto.rs b/datafusion/proto/src/from_proto.rs index b29891746d09..79b477b3e1ef 100644 --- a/datafusion/proto/src/from_proto.rs +++ b/datafusion/proto/src/from_proto.rs @@ -39,10 +39,11 @@ use datafusion_expr::{ logical_plan::{PlanType, StringifiedPlan}, lower, lpad, ltrim, md5, now, nullif, octet_length, power, random, regexp_match, regexp_replace, repeat, replace, reverse, right, round, rpad, rtrim, sha224, sha256, - sha384, sha512, signum, sin, split_part, sqrt, starts_with, strpos, substr, tan, - to_hex, to_timestamp_micros, to_timestamp_millis, to_timestamp_seconds, translate, - trim, trunc, upper, AggregateFunction, Between, BuiltInWindowFunction, - BuiltinScalarFunction, Case, Expr, GetIndexedField, GroupingSet, + sha384, sha512, signum, sin, split_part, sqrt, starts_with, strpos, substr, + substring, tan, to_hex, to_timestamp_micros, to_timestamp_millis, + to_timestamp_seconds, translate, trim, trunc, upper, 
AggregateFunction, Between, + BuiltInWindowFunction, BuiltinScalarFunction, Case, Expr, GetIndexedField, + GroupingSet, GroupingSet::GroupingSets, Like, Operator, WindowFrame, WindowFrameBound, WindowFrameUnits, }; @@ -1137,10 +1138,21 @@ pub fn parse_expr( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, )), - ScalarFunction::Substr => Ok(substr( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - )), + ScalarFunction::Substr => { + if args.len() > 2 { + assert_eq!(args.len(), 3); + Ok(substring( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + parse_expr(&args[2], registry)?, + )) + } else { + Ok(substr( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )) + } + } ScalarFunction::ToHex => Ok(to_hex(parse_expr(&args[0], registry)?)), ScalarFunction::ToTimestampMillis => { Ok(to_timestamp_millis(parse_expr(&args[0], registry)?)) diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs index 5f665c9f25a4..7feae7965215 100644 --- a/datafusion/proto/src/lib.rs +++ b/datafusion/proto/src/lib.rs @@ -67,7 +67,8 @@ mod roundtrip_tests { use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNode}; use datafusion_expr::{ col, lit, Accumulator, AggregateFunction, AggregateState, - BuiltinScalarFunction::Sqrt, Expr, LogicalPlan, Operator, Volatility, + BuiltinScalarFunction::{Sqrt, Substr}, + Expr, LogicalPlan, Operator, Volatility, }; use prost::Message; use std::any::Any; @@ -1149,4 +1150,23 @@ mod roundtrip_tests { let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); } + + #[test] + fn roundtrip_substr() { + // substr(string, position) + let test_expr = Expr::ScalarFunction { + fun: Substr, + args: vec![col("col"), lit(1_i64)], + }; + + // substr(string, position, count) + let test_expr_with_count = Expr::ScalarFunction { + fun: Substr, + args: vec![col("col"), lit(1_i64), lit(1_i64)], + }; + + let ctx = SessionContext::new(); + 
roundtrip_expr_test(test_expr, ctx.clone()); + roundtrip_expr_test(test_expr_with_count, ctx); + } }