diff --git a/datafusion/optimizer/src/simplify_expressions.rs b/datafusion/optimizer/src/simplify_expressions.rs index a9970abf6b64..d528bd00f587 100644 --- a/datafusion/optimizer/src/simplify_expressions.rs +++ b/datafusion/optimizer/src/simplify_expressions.rs @@ -788,12 +788,22 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { op: Divide, right, }) if is_null(&right) => *right, - // A / A --> 1 (if a is not nullable) + // 0 / 0 -> null Expr::BinaryExpr(BinaryExpr { left, op: Divide, right, - }) if !info.nullable(&left)? && left == right => lit(1), + }) if is_zero(&left) && is_zero(&right) => { + Expr::Literal(ScalarValue::Int32(None)) + } + // A / 0 -> DivideByZero Error + Expr::BinaryExpr(BinaryExpr { + left, + op: Divide, + right, + }) if !info.nullable(&left)? && is_zero(&right) => { + return Err(DataFusionError::ArrowError(ArrowError::DivideByZero)) + } // // Rules for Modulo @@ -1179,13 +1189,25 @@ mod tests { } #[test] - fn test_simplify_divide_by_same_non_null() { - let expr = binary_expr(col("c2_non_null"), Operator::Divide, col("c2_non_null")); - let expected = lit(1); + fn test_simplify_divide_zero_by_zero() { + // 0 / 0 -> null + let expr = binary_expr(lit(0), Operator::Divide, lit(0)); + let expected = Expr::Literal(ScalarValue::Int32(None)); assert_eq!(simplify(expr), expected); } + #[test] + #[should_panic( + expected = "called `Result::unwrap()` on an `Err` value: ArrowError(DivideByZero)" + )] + fn test_simplify_divide_by_zero() { + // A / 0 -> DivideByZeroError + let expr = binary_expr(col("c2_non_null"), Operator::Divide, lit(0)); + + simplify(expr); + } + #[test] fn test_simplify_modulo_by_null() { let null = Expr::Literal(ScalarValue::Null);