Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions datafusion/functions/src/unicode/substr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use crate::utils::{make_scalar_function, utf8_to_str_type};
/// Scalar UDF implementation for the `substr` function.
#[derive(Debug)]
pub struct SubstrFunc {
/// Type signature of the function; the constructor builds it with `Volatility::Immutable`.
signature: Signature,
/// Alternative names for the function; the constructor sets this to `["substring"]`.
aliases: Vec<String>,
}

impl Default for SubstrFunc {
Expand All @@ -53,6 +54,7 @@ impl SubstrFunc {
],
Volatility::Immutable,
),
aliases: vec![String::from("substring")],
}
}
}
Expand Down Expand Up @@ -81,6 +83,10 @@ impl ScalarUDFImpl for SubstrFunc {
other => exec_err!("Unsupported data type {other:?} for function substr"),
}
}

/// Returns the alternative names stored on this function (borrowed from `self`).
fn aliases(&self) -> &[String] {
    self.aliases.as_slice()
}
}

/// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).)
Expand Down
60 changes: 30 additions & 30 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,36 +88,36 @@ use substrait::proto::{
};
use substrait::proto::{FunctionArgument, SortField};

pub fn name_to_op(name: &str) -> Result<Operator> {
pub fn name_to_op(name: &str) -> Option<Operator> {
match name {
"equal" => Ok(Operator::Eq),
"not_equal" => Ok(Operator::NotEq),
"lt" => Ok(Operator::Lt),
"lte" => Ok(Operator::LtEq),
"gt" => Ok(Operator::Gt),
"gte" => Ok(Operator::GtEq),
"add" => Ok(Operator::Plus),
"subtract" => Ok(Operator::Minus),
"multiply" => Ok(Operator::Multiply),
"divide" => Ok(Operator::Divide),
"mod" => Ok(Operator::Modulo),
"and" => Ok(Operator::And),
"or" => Ok(Operator::Or),
"is_distinct_from" => Ok(Operator::IsDistinctFrom),
"is_not_distinct_from" => Ok(Operator::IsNotDistinctFrom),
"regex_match" => Ok(Operator::RegexMatch),
"regex_imatch" => Ok(Operator::RegexIMatch),
"regex_not_match" => Ok(Operator::RegexNotMatch),
"regex_not_imatch" => Ok(Operator::RegexNotIMatch),
"bitwise_and" => Ok(Operator::BitwiseAnd),
"bitwise_or" => Ok(Operator::BitwiseOr),
"str_concat" => Ok(Operator::StringConcat),
"at_arrow" => Ok(Operator::AtArrow),
"arrow_at" => Ok(Operator::ArrowAt),
"bitwise_xor" => Ok(Operator::BitwiseXor),
"bitwise_shift_right" => Ok(Operator::BitwiseShiftRight),
"bitwise_shift_left" => Ok(Operator::BitwiseShiftLeft),
_ => not_impl_err!("Unsupported function name: {name:?}"),
"equal" => Some(Operator::Eq),
"not_equal" => Some(Operator::NotEq),
"lt" => Some(Operator::Lt),
"lte" => Some(Operator::LtEq),
"gt" => Some(Operator::Gt),
"gte" => Some(Operator::GtEq),
"add" => Some(Operator::Plus),
"subtract" => Some(Operator::Minus),
"multiply" => Some(Operator::Multiply),
"divide" => Some(Operator::Divide),
"mod" => Some(Operator::Modulo),
"and" => Some(Operator::And),
"or" => Some(Operator::Or),
"is_distinct_from" => Some(Operator::IsDistinctFrom),
"is_not_distinct_from" => Some(Operator::IsNotDistinctFrom),
"regex_match" => Some(Operator::RegexMatch),
"regex_imatch" => Some(Operator::RegexIMatch),
"regex_not_match" => Some(Operator::RegexNotMatch),
"regex_not_imatch" => Some(Operator::RegexNotIMatch),
"bitwise_and" => Some(Operator::BitwiseAnd),
"bitwise_or" => Some(Operator::BitwiseOr),
"str_concat" => Some(Operator::StringConcat),
"at_arrow" => Some(Operator::AtArrow),
"arrow_at" => Some(Operator::ArrowAt),
"bitwise_xor" => Some(Operator::BitwiseXor),
"bitwise_shift_right" => Some(Operator::BitwiseShiftRight),
"bitwise_shift_left" => Some(Operator::BitwiseShiftLeft),
_ => None,
}
}

Expand Down Expand Up @@ -1124,7 +1124,7 @@ pub async fn from_substrait_rex(
Ok(Arc::new(Expr::ScalarFunction(
expr::ScalarFunction::new_udf(func.to_owned(), args),
)))
} else if let Ok(op) = name_to_op(fn_name) {
} else if let Some(op) = name_to_op(fn_name) {
if f.arguments.len() < 2 {
return not_impl_err!(
"Expect at least two arguments for binary operator {op:?}, the provided number of operators is {:?}",
Expand Down
32 changes: 20 additions & 12 deletions datafusion/substrait/src/logical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -818,7 +818,7 @@ pub fn to_substrait_agg_measure(
for arg in args {
arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extension_info)?)) });
}
let function_anchor = _register_function(fun.to_string(), extension_info);
let function_anchor = register_function(fun.to_string(), extension_info);
Ok(Measure {
measure: Some(AggregateFunction {
function_reference: function_anchor,
Expand Down Expand Up @@ -849,7 +849,7 @@ pub fn to_substrait_agg_measure(
for arg in args {
arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extension_info)?)) });
}
let function_anchor = _register_function(fun.name().to_string(), extension_info);
let function_anchor = register_function(fun.name().to_string(), extension_info);
Ok(Measure {
measure: Some(AggregateFunction {
function_reference: function_anchor,
Expand Down Expand Up @@ -917,7 +917,7 @@ fn to_substrait_sort_field(
}
}

fn _register_function(
fn register_function(
function_name: String,
extension_info: &mut (
Vec<extensions::SimpleExtensionDeclaration>,
Expand All @@ -926,6 +926,14 @@ fn _register_function(
) -> u32 {
let (function_extensions, function_set) = extension_info;
let function_name = function_name.to_lowercase();

// Some functions are named differently in Substrait default extensions than in DF
// Rename those to match the Substrait extensions for interoperability
let function_name = match function_name.as_str() {
"substr" => "substring".to_string(),
_ => function_name,
};

// To prevent ambiguous references between ScalarFunctions and AggregateFunctions,
// a plan-relative identifier starting from 0 is used as the function_anchor.
// The consumer is responsible for correctly registering <function_anchor,function_name>
Expand Down Expand Up @@ -969,7 +977,7 @@ pub fn make_binary_op_scalar_func(
),
) -> Expression {
let function_anchor =
_register_function(operator_to_name(op).to_string(), extension_info);
register_function(operator_to_name(op).to_string(), extension_info);
Expression {
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
function_reference: function_anchor,
Expand Down Expand Up @@ -1044,7 +1052,7 @@ pub fn to_substrait_rex(

if *negated {
let function_anchor =
_register_function("not".to_string(), extension_info);
register_function("not".to_string(), extension_info);

Ok(Expression {
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
Expand Down Expand Up @@ -1076,7 +1084,7 @@ pub fn to_substrait_rex(
}

let function_anchor =
_register_function(fun.name().to_string(), extension_info);
register_function(fun.name().to_string(), extension_info);
Ok(Expression {
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
function_reference: function_anchor,
Expand Down Expand Up @@ -1252,7 +1260,7 @@ pub fn to_substrait_rex(
null_treatment: _,
}) => {
// function reference
let function_anchor = _register_function(fun.to_string(), extension_info);
let function_anchor = register_function(fun.to_string(), extension_info);
// arguments
let mut arguments: Vec<FunctionArgument> = vec![];
for arg in args {
Expand Down Expand Up @@ -1330,7 +1338,7 @@ pub fn to_substrait_rex(
};
if *negated {
let function_anchor =
_register_function("not".to_string(), extension_info);
register_function("not".to_string(), extension_info);

Ok(Expression {
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
Expand Down Expand Up @@ -1727,9 +1735,9 @@ fn make_substrait_like_expr(
),
) -> Result<Expression> {
let function_anchor = if ignore_case {
_register_function("ilike".to_string(), extension_info)
register_function("ilike".to_string(), extension_info)
} else {
_register_function("like".to_string(), extension_info)
register_function("like".to_string(), extension_info)
};
let expr = to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?;
let pattern = to_substrait_rex(ctx, pattern, schema, col_ref_offset, extension_info)?;
Expand Down Expand Up @@ -1759,7 +1767,7 @@ fn make_substrait_like_expr(
};

if negated {
let function_anchor = _register_function("not".to_string(), extension_info);
let function_anchor = register_function("not".to_string(), extension_info);

Ok(Expression {
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
Expand Down Expand Up @@ -2128,7 +2136,7 @@ fn to_substrait_unary_scalar_fn(
HashMap<String, u32>,
),
) -> Result<Expression> {
let function_anchor = _register_function(fn_name.to_string(), extension_info);
let function_anchor = register_function(fn_name.to_string(), extension_info);
let substrait_expr =
to_substrait_rex(ctx, arg, schema, col_ref_offset, extension_info)?;

Expand Down
2 changes: 1 addition & 1 deletion datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ async fn simple_scalar_function_pow() -> Result<()> {

#[tokio::test]
async fn simple_scalar_function_substr() -> Result<()> {
roundtrip("SELECT * FROM data WHERE a = SUBSTR('datafusion', 0, 3)").await
roundtrip("SELECT SUBSTR(f, 1, 3) FROM data").await
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a reason to change the test?

Maybe we could add this particular query as an additional query (to show the existing behavior is not changed). Something like

Suggested change
roundtrip("SELECT SUBSTR(f, 1, 3) FROM data").await
roundtrip("SELECT * FROM data WHERE a = SUBSTR('datafusion', 0, 3)").await
roundtrip("SELECT SUBSTR(f, 1, 3) FROM data").await

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a reason to change the test?

Yes - the original query gets optimized by DF into SELECT * FROM data WHERE a = "da" before being converted into Substrait, i.e. the whole SUBSTRING call is optimized away:

Filter: CAST(data.a AS Utf8) = Utf8("da")
  TableScan: data projection=[a, b, c, d, e, f, g], partial_filters=[CAST(data.a AS Utf8) = Utf8("da")]

Maybe we could add this particular query as an additional query (to show the existing behavior is not changed).

I can add it back, but it doesn't really test what it tries to test 😅 Given that, would you still like to have it?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the explanation

No need to change the PR

The issue seems to be that DataFusion partially evaluates the expression

}

#[tokio::test]
Expand Down
8 changes: 8 additions & 0 deletions docs/source/user-guide/sql/scalar_functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -1132,6 +1132,14 @@ substr(str, start_pos[, length])
- **length**: Number of characters to extract.
If not specified, returns the rest of the string after the start position.

#### Aliases

- substring

### `substring`

_Alias of [substr](#substr)._

### `translate`

Translates characters in a string to specified translation characters.
Expand Down