diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs index 06768b5631ca..a5caad176558 100644 --- a/datafusion/core/src/dataframe.rs +++ b/datafusion/core/src/dataframe.rs @@ -827,7 +827,6 @@ impl TableProvider for DataFrame { #[cfg(test)] mod tests { - use arrow::array::Int32Array; use std::vec; use super::*; diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index 86d005776353..8bb1d95a48a6 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -111,19 +111,15 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { let right = create_physical_name(right, false)?; Ok(format!("{} {} {}", left, op, right)) } - Expr::Case { - expr, - when_then_expr, - else_expr, - } => { + Expr::Case(case) => { let mut name = "CASE ".to_string(); - if let Some(e) = expr { + if let Some(e) = &case.expr { let _ = write!(name, "{:?} ", e); } - for (w, t) in when_then_expr { + for (w, t) in &case.when_then_expr { let _ = write!(name, "WHEN {:?} THEN {:?} ", w, t); } - if let Some(e) = else_expr { + if let Some(e) = &case.else_expr { let _ = write!(name, "ELSE {:?} ", e); } name += "END"; diff --git a/datafusion/expr/src/conditional_expressions.rs b/datafusion/expr/src/conditional_expressions.rs index 0c5104a4bfce..28ac6e8cd0c1 100644 --- a/datafusion/expr/src/conditional_expressions.rs +++ b/datafusion/expr/src/conditional_expressions.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use crate::expr::Case; ///! Conditional expressions use crate::{expr_schema::ExprSchemable, Expr}; use arrow::datatypes::DataType; @@ -108,16 +109,15 @@ impl CaseBuilder { } } - Ok(Expr::Case { - expr: self.expr.clone(), - when_then_expr: self - .when_expr + Ok(Expr::Case(Case::new( + self.expr.clone(), + self.when_expr .iter() .zip(self.then_expr.iter()) .map(|(w, t)| (Box::new(w.clone()), Box::new(t.clone()))) .collect(), - else_expr: self.else_expr.clone(), - }) + self.else_expr.clone(), + ))) } } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 008a2c454d83..c131682a8ef6 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -176,14 +176,7 @@ pub enum Expr { /// [WHEN ...] /// [ELSE result] /// END - Case { - /// Optional base expression that can be compared to literal values in the "when" expressions - expr: Option>, - /// One or more when/then expressions - when_then_expr: Vec<(Box, Box)>, - /// Optional "else" expression - else_expr: Option>, - }, + Case(Case), /// Casts the expression to a given type and will return a runtime error if the expression cannot be cast. /// This expression is guaranteed to have a fixed type. Cast { @@ -292,6 +285,32 @@ pub enum Expr { GroupingSet(GroupingSet), } +/// CASE expression +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct Case { + /// Optional base expression that can be compared to literal values in the "when" expressions + pub expr: Option>, + /// One or more when/then expressions + pub when_then_expr: Vec<(Box, Box)>, + /// Optional "else" expression + pub else_expr: Option>, +} + +impl Case { + /// Create a new Case expression + pub fn new( + expr: Option>, + when_then_expr: Vec<(Box, Box)>, + else_expr: Option>, + ) -> Self { + Self { + expr, + when_then_expr, + else_expr, + } + } +} + /// Grouping sets /// See https://www.postgresql.org/docs/current/queries-table-expressions.html#QUERIES-GROUPING-SETS /// for Postgres definition. @@ -601,20 +620,15 @@ impl fmt::Debug for Expr { Expr::Column(c) => write!(f, "{}", c), Expr::ScalarVariable(_, var_names) => write!(f, "{}", var_names.join(".")), Expr::Literal(v) => write!(f, "{:?}", v), - Expr::Case { - expr, - when_then_expr, - else_expr, - .. - } => { + Expr::Case(case) => { write!(f, "CASE ")?; - if let Some(e) = expr { + if let Some(e) = &case.expr { write!(f, "{:?} ", e)?; } - for (w, t) in when_then_expr { + for (w, t) in &case.when_then_expr { write!(f, "WHEN {:?} THEN {:?} ", w, t)?; } - if let Some(e) = else_expr { + if let Some(e) = &case.else_expr { write!(f, "ELSE {:?} ", e)?; } write!(f, "END") @@ -957,22 +971,18 @@ fn create_name(e: &Expr) -> Result { ); Ok(s) } - Expr::Case { - expr, - when_then_expr, - else_expr, - } => { + Expr::Case(case) => { let mut name = "CASE ".to_string(); - if let Some(e) = expr { + if let Some(e) = &case.expr { let e = create_name(e)?; let _ = write!(name, "{} ", e); } - for (w, t) in when_then_expr { + for (w, t) in &case.when_then_expr { let when = create_name(w)?; let then = create_name(t)?; let _ = write!(name, "WHEN {} THEN {} ", when, then); } - if let Some(e) = else_expr { + if let Some(e) = &case.else_expr { let e = create_name(e)?; let _ = write!(name, "ELSE {} ", e); } diff --git a/datafusion/expr/src/expr_rewriter.rs b/datafusion/expr/src/expr_rewriter.rs index 6bdb54522c63..427fcf170174 100644 --- a/datafusion/expr/src/expr_rewriter.rs +++ b/datafusion/expr/src/expr_rewriter.rs @@ -17,7 +17,7 @@ //! Expression rewriter -use crate::expr::GroupingSet; +use crate::expr::{Case, GroupingSet}; use crate::logical_plan::{Aggregate, Projection}; use crate::utils::{from_plan, grouping_set_to_exprlist}; use crate::{Expr, ExprSchemable, LogicalPlan}; @@ -184,13 +184,10 @@ impl ExprRewritable for Expr { high: rewrite_boxed(high, rewriter)?, negated, }, - Expr::Case { - expr, - when_then_expr, - else_expr, - } => { - let expr = rewrite_option_box(expr, rewriter)?; - let when_then_expr = when_then_expr + Expr::Case(case) => { + let expr = rewrite_option_box(case.expr, rewriter)?; + let when_then_expr = case + .when_then_expr .into_iter() .map(|(when, then)| { Ok(( @@ -200,13 +197,9 @@ impl ExprRewritable for Expr { }) .collect::>>()?; - let else_expr = rewrite_option_box(else_expr, rewriter)?; + let else_expr = rewrite_option_box(case.else_expr, rewriter)?; - Expr::Case { - expr, - when_then_expr, - else_expr, - } + Expr::Case(Case::new(expr, when_then_expr, else_expr)) } Expr::Cast { expr, data_type } => Expr::Cast { expr: rewrite_boxed(expr, rewriter)?, diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 88d767366cef..5442a24212d2 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -59,7 +59,7 @@ impl ExprSchemable for Expr { Expr::Column(c) => Ok(schema.data_type(c)?.clone()), Expr::ScalarVariable(ty, _) => Ok(ty.clone()), Expr::Literal(l) => Ok(l.get_datatype()), - Expr::Case { when_then_expr, .. } => when_then_expr[0].1.get_type(schema), + Expr::Case(case) => case.when_then_expr[0].1.get_type(schema), Expr::Cast { data_type, .. } | Expr::TryCast { data_type, .. } => { Ok(data_type.clone()) } @@ -164,19 +164,16 @@ impl ExprSchemable for Expr { | Expr::InList { expr, .. } => expr.nullable(input_schema), Expr::Column(c) => input_schema.nullable(c), Expr::Literal(value) => Ok(value.is_null()), - Expr::Case { - when_then_expr, - else_expr, - .. - } => { + Expr::Case(case) => { // this expression is nullable if any of the input expressions are nullable - let then_nullable = when_then_expr + let then_nullable = case + .when_then_expr .iter() .map(|(_, t)| t.nullable(input_schema)) .collect::>>()?; if then_nullable.contains(&true) { Ok(true) - } else if let Some(e) = else_expr { + } else if let Some(e) = &case.else_expr { e.nullable(input_schema) } else { // CASE produces NULL if there is no `else` expr diff --git a/datafusion/expr/src/expr_visitor.rs b/datafusion/expr/src/expr_visitor.rs index 3885456cc3a9..f362a759eb3a 100644 --- a/datafusion/expr/src/expr_visitor.rs +++ b/datafusion/expr/src/expr_visitor.rs @@ -153,24 +153,20 @@ impl ExprVisitable for Expr { let visitor = low.accept(visitor)?; high.accept(visitor) } - Expr::Case { - expr, - when_then_expr, - else_expr, - } => { - let visitor = if let Some(expr) = expr.as_ref() { + Expr::Case(case) => { + let visitor = if let Some(expr) = case.expr.as_ref() { expr.accept(visitor) } else { Ok(visitor) }?; - let visitor = when_then_expr.iter().try_fold( + let visitor = case.when_then_expr.iter().try_fold( visitor, |visitor, (when, then)| { let visitor = when.accept(visitor)?; then.accept(visitor) }, )?; - if let Some(else_expr) = else_expr.as_ref() { + if let Some(else_expr) = case.else_expr.as_ref() { else_expr.accept(visitor) } else { Ok(visitor) diff --git a/datafusion/optimizer/src/simplify_expressions.rs b/datafusion/optimizer/src/simplify_expressions.rs index bb17a9925cd2..c96f6eea7151 100644 --- a/datafusion/optimizer/src/simplify_expressions.rs +++ b/datafusion/optimizer/src/simplify_expressions.rs @@ -490,7 +490,7 @@ impl<'a> ConstEvaluator<'a> { | Expr::Like { .. } | Expr::ILike { .. } | Expr::SimilarTo { .. } - | Expr::Case { .. } + | Expr::Case(_) | Expr::Cast { .. } | Expr::TryCast { .. } | Expr::InList { .. } @@ -848,20 +848,17 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { // // Note: the rationale for this rewrite is that the expr can then be further // simplified using the existing rules for AND/OR - Case { - expr: None, - when_then_expr, - else_expr, - } if !when_then_expr.is_empty() - && when_then_expr.len() < 3 // The rewrite is O(n!) so limit to small number - && info.is_boolean_type(&when_then_expr[0].1)? => + Case(case) + if !case.when_then_expr.is_empty() + && case.when_then_expr.len() < 3 // The rewrite is O(n!) so limit to small number + && info.is_boolean_type(&case.when_then_expr[0].1)? => { // The disjunction of all the when predicates encountered so far let mut filter_expr = lit(false); // The disjunction of all the cases let mut out_expr = lit(false); - for (when, then) in when_then_expr { + for (when, then) in case.when_then_expr { let case_expr = when .as_ref() .clone() @@ -872,7 +869,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { filter_expr = filter_expr.or(*when); } - if let Some(else_expr) = else_expr { + if let Some(else_expr) = case.else_expr { let case_expr = filter_expr.not().and(*else_expr); out_expr = out_expr.or(case_expr); } @@ -974,6 +971,7 @@ mod tests { use arrow::array::{ArrayRef, Int32Array}; use chrono::{DateTime, TimeZone, Utc}; use datafusion_common::{DFField, ToDFSchema}; + use datafusion_expr::expr::Case; use datafusion_expr::logical_plan::table_scan; use datafusion_expr::{ and, binary_expr, call_fn, col, create_udf, lit, lit_timestamp_nano, @@ -1700,14 +1698,14 @@ mod tests { // --> // false assert_eq!( - simplify(Expr::Case { - expr: None, - when_then_expr: vec![( + simplify(Expr::Case(Case::new( + None, + vec![( Box::new(col("c2").not_eq(lit(false))), Box::new(lit("ok").eq(lit("not_ok"))), )], - else_expr: Some(Box::new(col("c2").eq(lit(true)))), - }), + Some(Box::new(col("c2").eq(lit(true)))), + ))), col("c2").not().and(col("c2")) // #1716 ); @@ -1720,14 +1718,14 @@ mod tests { // Need to call simplify 2x due to // https://github.com/apache/arrow-datafusion/issues/1160 assert_eq!( - simplify(simplify(Expr::Case { - expr: None, - when_then_expr: vec![( + simplify(simplify(Expr::Case(Case::new( + None, + vec![( Box::new(col("c2").not_eq(lit(false))), Box::new(lit("ok").eq(lit("ok"))), )], - else_expr: Some(Box::new(col("c2").eq(lit(true)))), - })), + Some(Box::new(col("c2").eq(lit(true)))), + )))), col("c2").or(col("c2").not().and(col("c2"))) // #1716 ); @@ -1738,14 +1736,11 @@ mod tests { // Need to call simplify 2x due to // https://github.com/apache/arrow-datafusion/issues/1160 assert_eq!( - simplify(simplify(Expr::Case { - expr: None, - when_then_expr: vec![( - Box::new(col("c2").is_null()), - Box::new(lit(true)), - )], - else_expr: Some(Box::new(col("c2"))), - })), + simplify(simplify(Expr::Case(Case::new( + None, + vec![(Box::new(col("c2").is_null()), Box::new(lit(true)),)], + Some(Box::new(col("c2"))), + )))), col("c2") .is_null() .or(col("c2").is_not_null().and(col("c2"))) @@ -1759,14 +1754,14 @@ mod tests { // Need to call simplify 2x due to // https://github.com/apache/arrow-datafusion/issues/1160 assert_eq!( - simplify(simplify(Expr::Case { - expr: None, - when_then_expr: vec![ + simplify(simplify(Expr::Case(Case::new( + None, + vec![ (Box::new(col("c1")), Box::new(lit(true)),), (Box::new(col("c2")), Box::new(lit(false)),), ], - else_expr: Some(Box::new(lit(true))), - })), + Some(Box::new(lit(true))), + )))), col("c1").or(col("c1").not().and(col("c2").not())) ); @@ -1778,14 +1773,14 @@ mod tests { // Need to call simplify 2x due to // https://github.com/apache/arrow-datafusion/issues/1160 assert_eq!( - simplify(simplify(Expr::Case { - expr: None, - when_then_expr: vec![ + simplify(simplify(Expr::Case(Case::new( + None, + vec![ (Box::new(col("c1")), Box::new(lit(true)),), (Box::new(col("c2")), Box::new(lit(false)),), ], - else_expr: Some(Box::new(lit(true))), - })), + Some(Box::new(lit(true))), + )))), col("c1").or(col("c1").not().and(col("c2").not())) ); } diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs index f0470da8710c..fcfe6eaaa84d 100644 --- a/datafusion/optimizer/src/type_coercion.rs +++ b/datafusion/optimizer/src/type_coercion.rs @@ -20,6 +20,7 @@ use crate::{OptimizerConfig, OptimizerRule}; use arrow::datatypes::DataType; use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result}; +use datafusion_expr::expr::Case; use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion}; use datafusion_expr::logical_plan::Subquery; use datafusion_expr::type_coercion::binary::{coerce_types, comparison_coercion}; @@ -357,18 +358,15 @@ impl ExprRewriter for TypeCoercionRewriter { } } } - Expr::Case { - expr, - when_then_expr, - else_expr, - } => { + Expr::Case(case) => { // all the result of then and else should be convert to a common data type, // if they can be coercible to a common data type, return error. - let then_types = when_then_expr + let then_types = case + .when_then_expr .iter() .map(|when_then| when_then.1.get_type(&self.schema)) .collect::>>()?; - let else_type = match &else_expr { + let else_type = match &case.else_expr { None => Ok(None), Some(expr) => expr.get_type(&self.schema).map(Some), }?; @@ -380,24 +378,20 @@ impl ExprRewriter for TypeCoercionRewriter { then_types, else_type ))), Some(data_type) => { - let left = when_then_expr + let left = case.when_then_expr .into_iter() .map(|(when, then)| { let then = then.cast_to(&data_type, &self.schema)?; Ok((when, Box::new(then))) }) .collect::>>()?; - let right = match else_expr { + let right = match &case.else_expr { None => None, Some(expr) => { - Some(Box::new(expr.cast_to(&data_type, &self.schema)?)) + Some(Box::new(expr.clone().cast_to(&data_type, &self.schema)?)) } }; - Ok(Expr::Case { - expr, - when_then_expr: left, - else_expr: right, - }) + Ok(Expr::Case(Case::new(case.expr,left,right))) } } } diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 0964d64805b8..993891884589 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -221,13 +221,8 @@ pub fn create_physical_expr( binary_expr(expr.as_ref().clone(), op, pattern.as_ref().clone()); create_physical_expr(&bin_expr, input_dfschema, input_schema, execution_props) } - Expr::Case { - expr, - when_then_expr, - else_expr, - .. - } => { - let expr: Option> = if let Some(e) = expr { + Expr::Case(case) => { + let expr: Option> = if let Some(e) = &case.expr { Some(create_physical_expr( e.as_ref(), input_dfschema, @@ -237,7 +232,8 @@ pub fn create_physical_expr( } else { None }; - let when_expr = when_then_expr + let when_expr = case + .when_then_expr .iter() .map(|(w, _)| { create_physical_expr( @@ -248,7 +244,8 @@ pub fn create_physical_expr( ) }) .collect::>>()?; - let then_expr = when_then_expr + let then_expr = case + .when_then_expr .iter() .map(|(_, t)| { create_physical_expr( @@ -265,16 +262,17 @@ pub fn create_physical_expr( .zip(then_expr.iter()) .map(|(w, t)| (w.clone(), t.clone())) .collect(); - let else_expr: Option> = if let Some(e) = else_expr { - Some(create_physical_expr( - e.as_ref(), - input_dfschema, - input_schema, - execution_props, - )?) - } else { - None - }; + let else_expr: Option> = + if let Some(e) = &case.else_expr { + Some(create_physical_expr( + e.as_ref(), + input_dfschema, + input_schema, + execution_props, + )?) + } else { + None + }; Ok(expressions::case(expr, when_then_expr, else_expr)?) } Expr::Cast { expr, data_type } => expressions::cast( diff --git a/datafusion/proto/src/from_proto.rs b/datafusion/proto/src/from_proto.rs index 3eeb30edf649..208c24036262 100644 --- a/datafusion/proto/src/from_proto.rs +++ b/datafusion/proto/src/from_proto.rs @@ -31,8 +31,8 @@ use datafusion::logical_plan::FunctionRegistry; use datafusion_common::{ Column, DFField, DFSchema, DFSchemaRef, DataFusionError, ScalarValue, }; -use datafusion_expr::expr::GroupingSet; use datafusion_expr::expr::GroupingSet::GroupingSets; +use datafusion_expr::expr::{Case, GroupingSet}; use datafusion_expr::{ abs, acos, array, ascii, asin, atan, atan2, bit_length, btrim, ceil, character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, date_bin, @@ -1023,11 +1023,11 @@ pub fn parse_expr( Ok((Box::new(when_expr), Box::new(then_expr))) }) .collect::, Box)>, Error>>()?; - Ok(Expr::Case { - expr: parse_optional_expr(&case.expr, registry)?.map(Box::new), + Ok(Expr::Case(Case::new( + parse_optional_expr(&case.expr, registry)?.map(Box::new), when_then_expr, - else_expr: parse_optional_expr(&case.else_expr, registry)?.map(Box::new), - }) + parse_optional_expr(&case.else_expr, registry)?.map(Box::new), + ))) } ExprType::Cast(cast) => { let expr = Box::new(parse_required_expr(&cast.expr, registry, "expr")?); diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs index e3b6c848a2b1..7a495077f26f 100644 --- a/datafusion/proto/src/lib.rs +++ b/datafusion/proto/src/lib.rs @@ -63,7 +63,7 @@ mod roundtrip_tests { use datafusion::physical_plan::functions::make_scalar_function; use datafusion::prelude::{create_udf, CsvReadOptions, SessionContext}; use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue}; - use datafusion_expr::expr::GroupingSet; + use datafusion_expr::expr::{Case, GroupingSet}; use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNode}; use datafusion_expr::{ col, lit, Accumulator, AggregateFunction, AggregateState, @@ -970,11 +970,11 @@ mod roundtrip_tests { #[test] fn roundtrip_case() { - let test_expr = Expr::Case { - expr: Some(Box::new(lit(1.0_f32))), - when_then_expr: vec![(Box::new(lit(2.0_f32)), Box::new(lit(3.0_f32)))], - else_expr: Some(Box::new(lit(4.0_f32))), - }; + let test_expr = Expr::Case(Case::new( + Some(Box::new(lit(1.0_f32))), + vec![(Box::new(lit(2.0_f32)), Box::new(lit(3.0_f32)))], + Some(Box::new(lit(4.0_f32))), + )); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); @@ -982,11 +982,11 @@ mod roundtrip_tests { #[test] fn roundtrip_case_with_null() { - let test_expr = Expr::Case { - expr: Some(Box::new(lit(1.0_f32))), - when_then_expr: vec![(Box::new(lit(2.0_f32)), Box::new(lit(3.0_f32)))], - else_expr: Some(Box::new(Expr::Literal(ScalarValue::Null))), - }; + let test_expr = Expr::Case(Case::new( + Some(Box::new(lit(1.0_f32))), + vec![(Box::new(lit(2.0_f32)), Box::new(lit(3.0_f32)))], + Some(Box::new(Expr::Literal(ScalarValue::Null))), + )); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs index 47b779fffc74..7b70821ec724 100644 --- a/datafusion/proto/src/to_proto.rs +++ b/datafusion/proto/src/to_proto.rs @@ -771,12 +771,8 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { expr_type: Some(ExprType::Between(expr)), } } - Expr::Case { - expr, - when_then_expr, - else_expr, - } => { - let when_then_expr = when_then_expr + Expr::Case(case) => { + let when_then_expr = case.when_then_expr .iter() .map(|(w, t)| { Ok(protobuf::WhenThen { @@ -786,12 +782,12 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { }) .collect::, Error>>()?; let expr = Box::new(protobuf::CaseNode { - expr: match expr { + expr: match &case.expr { Some(e) => Some(Box::new(e.as_ref().try_into()?)), None => None, }, when_then_expr, - else_expr: match else_expr { + else_expr: match &case.else_expr { Some(e) => Some(Box::new(e.as_ref().try_into()?)), None => None, }, diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 9a4e29228932..58b65af596ef 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -51,7 +51,7 @@ use crate::utils::{make_decimal_type, normalize_ident, resolve_columns}; use datafusion_common::{ field_not_found, Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, }; -use datafusion_expr::expr::GroupingSet; +use datafusion_expr::expr::{Case, GroupingSet}; use datafusion_expr::logical_plan::builder::project_with_alias; use datafusion_expr::logical_plan::{Filter, Subquery}; use datafusion_expr::Expr::Alias; @@ -1872,15 +1872,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { None }; - Ok(Expr::Case { + Ok(Expr::Case(Case::new( expr, - when_then_expr: when_expr + when_expr .iter() .zip(then_expr.iter()) .map(|(w, t)| (Box::new(w.to_owned()), Box::new(t.to_owned()))) .collect(), else_expr, - }) + ))) } SQLExpr::Cast { diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index eb58509d0960..952ef31106fd 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -21,7 +21,7 @@ use arrow::datatypes::{DataType, DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE use sqlparser::ast::Ident; use datafusion_common::{DataFusionError, Result, ScalarValue}; -use datafusion_expr::expr::GroupingSet; +use datafusion_expr::expr::{Case, GroupingSet}; use datafusion_expr::utils::{expr_as_column_expr, find_column_exprs}; use datafusion_expr::{Expr, LogicalPlan}; use std::collections::HashMap; @@ -268,18 +268,14 @@ where pattern: Box::new(clone_with_replacement(pattern, replacement_fn)?), escape_char: *escape_char, }), - Expr::Case { - expr: case_expr_opt, - when_then_expr, - else_expr: else_expr_opt, - } => Ok(Expr::Case { - expr: match case_expr_opt { + Expr::Case(case) => Ok(Expr::Case(Case::new( + match &case.expr { Some(case_expr) => { Some(Box::new(clone_with_replacement(case_expr, replacement_fn)?)) } None => None, }, - when_then_expr: when_then_expr + case.when_then_expr .iter() .map(|(a, b)| { Ok(( @@ -288,13 +284,13 @@ where )) }) .collect::>>()?, - else_expr: match else_expr_opt { + match &case.else_expr { Some(else_expr) => { Some(Box::new(clone_with_replacement(else_expr, replacement_fn)?)) } None => None, }, - }), + ))), Expr::ScalarFunction { fun, args } => Ok(Expr::ScalarFunction { fun: fun.clone(), args: args