diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index 4eeda1f4c7d0..0e7b4f1e0870 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -19,7 +19,7 @@
 
 use crate::{OptimizerConfig, OptimizerRule};
 use arrow::datatypes::DataType;
-use datafusion_common::{DFField, DFSchema, Result};
+use datafusion_common::{DFField, DFSchema, DataFusionError, Result};
 use datafusion_expr::{
     col,
     expr::GroupingSet,
@@ -107,7 +107,7 @@ fn optimize(
             )?;
 
             Ok(LogicalPlan::Projection(Projection::try_new_with_schema(
-                new_expr.pop().unwrap(),
+                pop_expr(&mut new_expr)?,
                 Arc::new(new_input),
                 schema.clone(),
                 alias.clone(),
@@ -139,10 +139,16 @@ fn optimize(
                 optimizer_config,
             )?;
 
-            Ok(LogicalPlan::Filter(Filter {
-                predicate: new_expr.pop().unwrap().pop().unwrap(),
-                input: Arc::new(new_input),
-            }))
+            if let Some(predicate) = pop_expr(&mut new_expr)?.pop() {
+                Ok(LogicalPlan::Filter(Filter {
+                    predicate,
+                    input: Arc::new(new_input),
+                }))
+            } else {
+                Err(DataFusionError::Internal(
+                    "Failed to pop predicate expr".to_string(),
+                ))
+            }
         }
         LogicalPlan::Window(Window {
             input,
@@ -161,7 +167,7 @@ fn optimize(
 
             Ok(LogicalPlan::Window(Window {
                 input: Arc::new(new_input),
-                window_expr: new_expr.pop().unwrap(),
+                window_expr: pop_expr(&mut new_expr)?,
                 schema: schema.clone(),
             }))
         }
@@ -182,8 +188,8 @@ fn optimize(
                 optimizer_config,
             )?;
             // note the reversed pop order.
-            let new_aggr_expr = new_expr.pop().unwrap();
-            let new_group_expr = new_expr.pop().unwrap();
+            let new_aggr_expr = pop_expr(&mut new_expr)?;
+            let new_group_expr = pop_expr(&mut new_expr)?;
 
             Ok(LogicalPlan::Aggregate(Aggregate::try_new(
                 Arc::new(new_input),
@@ -204,7 +210,7 @@ fn optimize(
             )?;
 
             Ok(LogicalPlan::Sort(Sort {
-                expr: new_expr.pop().unwrap(),
+                expr: pop_expr(&mut new_expr)?,
                 input: Arc::new(new_input),
             }))
         }
@@ -241,6 +247,12 @@ fn optimize(
     }
 }
 
+fn pop_expr(new_expr: &mut Vec<Vec<Expr>>) -> Result<Vec<Expr>> {
+    new_expr
+        .pop()
+        .ok_or_else(|| DataFusionError::Internal("Failed to pop expression".to_string()))
+}
+
 fn to_arrays(
     expr: &[Expr],
     input: &LogicalPlan,
@@ -268,12 +280,20 @@ fn build_project_plan(
     let mut fields_set = HashSet::new();
 
     for id in affected_id {
-        let (expr, _, data_type) = expr_set.get(&id).unwrap();
-        // todo: check `nullable`
-        let field = DFField::new(None, &id, data_type.clone(), true);
-        fields_set.insert(field.name().to_owned());
-        fields.push(field);
-        project_exprs.push(expr.clone().alias(&id));
+        match expr_set.get(&id) {
+            Some((expr, _, data_type)) => {
+                // todo: check `nullable`
+                let field = DFField::new(None, &id, data_type.clone(), true);
+                fields_set.insert(field.name().to_owned());
+                fields.push(field);
+                project_exprs.push(expr.clone().alias(&id));
+            }
+            _ => {
+                return Err(DataFusionError::Internal(
+                    "expr_set invalid state".to_string(),
+                ))
+            }
+        }
     }
 
     for field in input.schema().fields() {
@@ -639,13 +659,19 @@ impl ExprRewriter for CommonSubexprRewriter<'_> {
             self.curr_index += 1;
             return Ok(RewriteRecursion::Skip);
         }
-        let (_, counter, _) = self.expr_set.get(curr_id).unwrap();
-        if *counter > 1 {
-            self.affected_id.insert(curr_id.clone());
-            Ok(RewriteRecursion::Mutate)
-        } else {
-            self.curr_index += 1;
-            Ok(RewriteRecursion::Skip)
+        match self.expr_set.get(curr_id) {
+            Some((_, counter, _)) => {
+                if *counter > 1 {
+                    self.affected_id.insert(curr_id.clone());
+                    Ok(RewriteRecursion::Mutate)
+                } else {
+                    self.curr_index += 1;
+                    Ok(RewriteRecursion::Skip)
+                }
+            }
+            _ => Err(DataFusionError::Internal(
+                "expr_set invalid state".to_string(),
+            )),
         }
     }
 
@@ -658,9 +684,12 @@ impl ExprRewriter for CommonSubexprRewriter<'_> {
         let (series_number, id) = &self.id_array[self.curr_index];
         self.curr_index += 1;
         // Skip sub-node of a replaced tree, or without identifier, or is not repeated expr.
+        let expr_set_item = self.expr_set.get(id).ok_or_else(|| {
+            DataFusionError::Internal("expr_set invalid state".to_string())
+        })?;
         if *series_number < self.max_series_number
             || id.is_empty()
-            || self.expr_set.get(id).unwrap().1 <= 1
+            || expr_set_item.1 <= 1
         {
             return Ok(expr);
         }