diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 94f5b0480b65..7fee7c61fcd8 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -58,6 +58,10 @@ pub trait ExprSchemable { fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> Result; /// Given a schema, return the type and nullability of the expr + #[deprecated( + since = "51.0.0", + note = "Use `to_field().1.is_nullable` and `to_field().1.data_type()` directly instead" + )] fn data_type_and_nullable(&self, schema: &dyn ExprSchema) -> Result<(DataType, bool)>; } @@ -150,7 +154,7 @@ impl ExprSchemable for Expr { } } Expr::ScalarFunction(_func) => { - let (return_type, _) = self.data_type_and_nullable(schema)?; + let return_type = self.to_field(schema)?.1.data_type().clone(); Ok(return_type) } Expr::WindowFunction(window_function) => self @@ -350,7 +354,9 @@ impl ExprSchemable for Expr { } Expr::Cast(Cast { expr, .. }) => expr.nullable(input_schema), Expr::ScalarFunction(_func) => { - let (_, nullable) = self.data_type_and_nullable(input_schema)?; + let field = self.to_field(input_schema)?.1; + + let nullable = field.is_nullable(); Ok(nullable) } Expr::AggregateFunction(AggregateFunction { func, .. }) => { @@ -530,9 +536,14 @@ impl ExprSchemable for Expr { ref right, ref op, }) => { - let (lhs_type, lhs_nullable) = left.data_type_and_nullable(schema)?; - let (rhs_type, rhs_nullable) = right.data_type_and_nullable(schema)?; - let mut coercer = BinaryTypeCoercer::new(&lhs_type, op, &rhs_type); + let (left_field, right_field) = + (left.to_field(schema)?.1, right.to_field(schema)?.1); + + let (lhs_type, lhs_nullable) = + (left_field.data_type(), left_field.is_nullable()); + let (rhs_type, rhs_nullable) = + (right_field.data_type(), right_field.is_nullable()); + let mut coercer = BinaryTypeCoercer::new(lhs_type, op, rhs_type); coercer.set_lhs_spans(left.spans().cloned().unwrap_or_default()); coercer.set_rhs_spans(right.spans().cloned().unwrap_or_default()); Ok(Arc::new(Field::new( @@ -1127,16 +1138,18 @@ mod tests { ), )); + let field = expr.to_field(&schema).unwrap().1; assert_eq!( - expr.data_type_and_nullable(&schema).unwrap(), - (DataType::Utf8, true) + (field.data_type(), field.is_nullable()), + (&DataType::Utf8, true) ); assert_eq!(placeholder_meta, expr.metadata(&schema).unwrap()); let expr_alias = expr.alias("a placeholder by any other name"); + let expr_alias_field = expr_alias.to_field(&schema).unwrap().1; assert_eq!( - expr_alias.data_type_and_nullable(&schema).unwrap(), - (DataType::Utf8, true) + (expr_alias_field.data_type(), expr_alias_field.is_nullable()), + (&DataType::Utf8, true) ); assert_eq!(placeholder_meta, expr_alias.metadata(&schema).unwrap()); @@ -1145,14 +1158,17 @@ mod tests { "".to_string(), Some(Field::new("", DataType::Utf8, false).into()), )); + let expr_field = expr.to_field(&schema).unwrap().1; assert_eq!( - expr.data_type_and_nullable(&schema).unwrap(), - (DataType::Utf8, false) + (expr_field.data_type(), expr_field.is_nullable()), + (&DataType::Utf8, false) ); + let expr_alias = expr.alias("a placeholder by any other name"); + let expr_alias_field = expr_alias.to_field(&schema).unwrap().1; assert_eq!( - expr_alias.data_type_and_nullable(&schema).unwrap(), - (DataType::Utf8, false) + (expr_alias_field.data_type(), expr_alias_field.is_nullable()), + (&DataType::Utf8, false) ); } diff --git a/datafusion/sql/src/expr/unary_op.rs b/datafusion/sql/src/expr/unary_op.rs index e0c94543f601..f63140230b60 100644 --- a/datafusion/sql/src/expr/unary_op.rs +++ b/datafusion/sql/src/expr/unary_op.rs @@ -38,10 +38,11 @@ impl SqlToRel<'_, S> { UnaryOperator::Plus => { let operand = self.sql_expr_to_logical_expr(expr, schema, planner_context)?; - let (data_type, _) = operand.data_type_and_nullable(schema)?; + let field = operand.to_field(schema)?.1; + let data_type = field.data_type(); if data_type.is_numeric() - || is_interval(&data_type) - || is_timestamp(&data_type) + || is_interval(data_type) + || is_timestamp(data_type) { Ok(operand) } else { diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 91ab2e003c87..7a677ef8a8b2 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -417,8 +417,8 @@ impl RecursiveUnnestRewriter<'_> { // This is due to the fact that unnest transformation should keep the original // column name as is, to comply with group by and order by let placeholder_column = Column::from_name(placeholder_name.clone()); - - let (data_type, _) = expr_in_unnest.data_type_and_nullable(self.input_schema)?; + let field = expr_in_unnest.to_field(self.input_schema)?.1; + let data_type = field.data_type(); match data_type { DataType::Struct(inner_fields) => { @@ -432,12 +432,10 @@ impl RecursiveUnnestRewriter<'_> { ); self.columns_unnestings .insert(Column::from_name(placeholder_name.clone()), None); - Ok( - get_struct_unnested_columns(&placeholder_name, &inner_fields) - .into_iter() - .map(Expr::Column) - .collect(), - ) + Ok(get_struct_unnested_columns(&placeholder_name, inner_fields) + .into_iter() + .map(Expr::Column) + .collect()) } DataType::List(_) | DataType::FixedSizeList(_, _) @@ -478,8 +476,8 @@ impl TreeNodeRewriter for RecursiveUnnestRewriter<'_> { /// is used to detect if some recursive unnest expr exists (e.g **unnest(unnest(unnest(3d column))))** fn f_down(&mut self, expr: Expr) -> Result> { if let Expr::Unnest(ref unnest_expr) = expr { - let (data_type, _) = - unnest_expr.expr.data_type_and_nullable(self.input_schema)?; + let field = unnest_expr.expr.to_field(self.input_schema)?.1; + let data_type = field.data_type(); self.consecutive_unnest.push(Some(unnest_expr.clone())); // if expr inside unnest is a struct, do not consider // the next unnest as consecutive unnest (if any) diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/mod.rs b/datafusion/substrait/src/logical_plan/consumer/expr/mod.rs index ba892259852a..b1dc5ab70f81 100644 --- a/datafusion/substrait/src/logical_plan/consumer/expr/mod.rs +++ b/datafusion/substrait/src/logical_plan/consumer/expr/mod.rs @@ -143,9 +143,7 @@ pub async fn from_substrait_extended_expr( let expr = consumer .consume_expression(scalar_expr, &input_schema) .await?; - let (output_type, expected_nullability) = - expr.data_type_and_nullable(&input_schema)?; - let output_field = Field::new("", output_type, expected_nullability); + let output_field = expr.to_field(&input_schema)?.1; let mut names_idx = 0; let output_field = rename_field( &output_field,