From 815f8e7dcec50f3755571d950f4f40d49489e91b Mon Sep 17 00:00:00 2001 From: kould Date: Sun, 3 May 2026 22:13:25 +0800 Subject: [PATCH] Normalize boolean-wrapped range predicates --- src/expression/range_detacher.rs | 75 +++++++++--- src/expression/simplify.rs | 113 ++++++++++++++++-- .../rule/normalization/simplification.rs | 23 ++++ tests/slt/where_by_index.slt | 24 ++++ tests/slt/where_by_index_explain.slt | 15 +++ 5 files changed, 225 insertions(+), 25 deletions(-) diff --git a/src/expression/range_detacher.rs b/src/expression/range_detacher.rs index ca431c2a..3265e50b 100644 --- a/src/expression/range_detacher.rs +++ b/src/expression/range_detacher.rs @@ -221,16 +221,11 @@ impl<'a> RangeDetacher<'a> { None } - (Some(binary), None) | (None, Some(binary)) => self.check_or(op, binary), + (Some(binary), None) | (None, Some(binary)) => self.check_and(op, binary), }, - ScalarExpression::Alias { expr, .. } - | ScalarExpression::TypeCast { expr, .. } - | ScalarExpression::Unary { expr, .. } - | ScalarExpression::In { expr, .. } - | ScalarExpression::Between { expr, .. } - | ScalarExpression::SubString { expr, .. } => self.detach(expr)?, - ScalarExpression::Position { expr, .. } => self.detach(expr)?, - ScalarExpression::Trim { expr, .. } => self.detach(expr)?, + ScalarExpression::Alias { expr, .. } | ScalarExpression::TypeCast { expr, .. } => { + self.detach(expr)? + } ScalarExpression::IsNull { expr, negated, .. } => match expr.as_ref() { ScalarExpression::ColumnRef { column, .. } => { if let (Some(col_id), Some(col_table)) = (column.id(), column.table_name()) { @@ -263,14 +258,20 @@ impl<'a> RangeDetacher<'a> { | ScalarExpression::IfNull { .. } | ScalarExpression::NullIf { .. } | ScalarExpression::Coalesce { .. } - | ScalarExpression::CaseWhen { .. } => self.detach(expr)?, + | ScalarExpression::CaseWhen { .. } => None, ScalarExpression::Tuple(_) | ScalarExpression::TableFunction(_) | ScalarExpression::Empty => unreachable!(), }, ScalarExpression::Constant(_) | ScalarExpression::ColumnRef { .. } => None, // FIXME: support [RangeDetacher::_detach] - ScalarExpression::Tuple(_) + ScalarExpression::Unary { .. } + | ScalarExpression::In { .. } + | ScalarExpression::Between { .. } + | ScalarExpression::SubString { .. } + | ScalarExpression::Position { .. } + | ScalarExpression::Trim { .. } + | ScalarExpression::Tuple(_) | ScalarExpression::AggCall { .. } | ScalarExpression::ScalaFunction(_) | ScalarExpression::If { .. } @@ -759,14 +760,13 @@ impl<'a> RangeDetacher<'a> { }) } - /// check if: `c1 > c2 or c1 > 1` or `c2 > 1 or c1 > 1` - /// this case it makes no sense to just extract c1 > 1 - fn check_or(&mut self, op: &BinaryOperator, binary: Range) -> Option { - if matches!(op, BinaryOperator::Or) { - return None; + /// Only conjunction can safely keep a range detached from one side of a binary expression. + fn check_and(&mut self, op: &BinaryOperator, binary: Range) -> Option { + if matches!(op, BinaryOperator::And) { + return Some(binary); } - Some(binary) + None } } @@ -1312,6 +1312,47 @@ mod test { Ok(()) } + #[test] + fn test_detach_only_conjunction_can_keep_partial_range() -> Result<(), DatabaseError> { + let table_state = build_t1_table()?; + let detach_c1 = |sql: &str| -> Result, DatabaseError> { + let plan = table_state.plan(sql)?; + let op = plan_filter(plan)?.unwrap(); + RangeDetacher::new("t1", table_state.column_id_by_name("c1")).detach(&op.predicate) + }; + + assert_eq!( + detach_c1("select * from t1 where c2 = 1 and c1 > 10")?, + Some(Range::Scope { + min: Bound::Excluded(DataValue::Int32(10)), + max: Bound::Unbounded, + }) + ); + let negated_range = Some(Range::Scope { + min: Bound::Unbounded, + max: Bound::Included(DataValue::Int32(10)), + }); + assert_eq!( + detach_c1("select * from t1 where (c1 > 10) = false")?, + negated_range.clone() + ); + assert_eq!( + detach_c1("select * from t1 where (c1 > 10) != true")?, + negated_range.clone() + ); + assert_eq!( + detach_c1("select * from t1 where not (c1 > 10)")?, + negated_range + ); + assert_eq!( + detach_c1("select * from t1 where (c1 > 10) = (c2 > 0)")?, + None + ); + assert_eq!(detach_c1("select * from t1 where (c1 > 10) is null")?, None); + + Ok(()) + } + // Tips: `null` should be First #[test] fn test_detach_null_cases() -> Result<(), DatabaseError> { diff --git a/src/expression/simplify.rs b/src/expression/simplify.rs index b03476a6..78592181 100644 --- a/src/expression/simplify.rs +++ b/src/expression/simplify.rs @@ -121,20 +121,35 @@ impl VisitorMut<'_> for Simplify { ScalarExpression::Unary { op, expr: arg_expr, + evaluator, ty, - .. } => { let op = *op; let ty = ty.clone(); - let arg_expr = arg_expr.as_ref().clone(); - if let Some(value) = expr.unpack_val() { + let child_expr = arg_expr.as_ref().clone(); + let value = if let Some(value) = arg_expr.unpack_val() { + Some(if let Some(evaluator) = evaluator { + evaluator.0.unary_eval(&value) + } else { + unary_create(Cow::Borrowed(&ty), op)?.0.unary_eval(&value) + }) + } else { + None + }; + + if let Some(value) = value { let _ = mem::replace(expr, ScalarExpression::Constant(value)); + } else if matches!(op, UnaryOperator::Not) { + if let Some(new_expr) = Self::take_negated_range_comparison(arg_expr) { + let _ = mem::replace(expr, new_expr); + self.visit(expr)?; + } else { + self.replaces + .push(Replace::Unary(ReplaceUnary { child_expr, op, ty })); + } } else { - self.replaces.push(Replace::Unary(ReplaceUnary { - child_expr: arg_expr, - op, - ty, - })); + self.replaces + .push(Replace::Unary(ReplaceUnary { child_expr, op, ty })); } } ScalarExpression::Binary { @@ -149,6 +164,14 @@ impl VisitorMut<'_> for Simplify { // `(c1 - 1) and (c1 + 2)` cannot fix! self.fix_expr(right_expr, left_expr, op)?; + if let Some(new_expr) = + Self::take_bool_normalized_range_comparison(*op, left_expr, right_expr) + { + let _ = mem::replace(expr, new_expr); + self.visit(expr)?; + return Ok(()); + } + if Self::is_arithmetic(op) { match ( left_expr.unpack_bound_col(false), @@ -317,6 +340,80 @@ impl Simplify { ) } + fn negate_range_comparison(op: BinaryOperator) -> Option { + match op { + BinaryOperator::Gt => Some(BinaryOperator::LtEq), + BinaryOperator::GtEq => Some(BinaryOperator::Lt), + BinaryOperator::Lt => Some(BinaryOperator::GtEq), + BinaryOperator::LtEq => Some(BinaryOperator::Gt), + _ => None, + } + } + + fn take_range_comparison(expr: &mut Box) -> Option { + match expr.as_ref() { + ScalarExpression::Binary { op, .. } if Self::negate_range_comparison(*op).is_some() => { + Some(mem::replace(expr.as_mut(), ScalarExpression::Empty)) + } + _ => None, + } + } + + fn take_negated_range_comparison(expr: &mut Box) -> Option { + match expr.as_mut() { + ScalarExpression::Binary { op, .. } => { + *op = Self::negate_range_comparison(*op)?; + Some(mem::replace(expr.as_mut(), ScalarExpression::Empty)) + } + _ => None, + } + } + + fn boolean_constant(expr: &ScalarExpression) -> Option { + match expr { + ScalarExpression::Constant(DataValue::Boolean(value)) => Some(*value), + _ => None, + } + } + + fn take_range_comparison_with_polarity( + expr: &mut Box, + positive: bool, + ) -> Option { + if positive { + Self::take_range_comparison(expr) + } else { + Self::take_negated_range_comparison(expr) + } + } + + fn take_bool_normalized_range_comparison( + op: BinaryOperator, + left_expr: &mut Box, + right_expr: &mut Box, + ) -> Option { + let is_eq = matches!(op, BinaryOperator::Eq); + let is_not_eq = matches!(op, BinaryOperator::NotEq); + if !is_eq && !is_not_eq { + return None; + } + + if let Some(value) = Self::boolean_constant(right_expr) { + return Self::take_range_comparison_with_polarity( + left_expr, + if is_eq { value } else { !value }, + ); + } + if let Some(value) = Self::boolean_constant(left_expr) { + return Self::take_range_comparison_with_polarity( + right_expr, + if is_eq { value } else { !value }, + ); + } + + None + } + fn fix_expr( &mut self, left_expr: &mut Box, diff --git a/src/optimizer/rule/normalization/simplification.rs b/src/optimizer/rule/normalization/simplification.rs index fe8ad75e..720802c9 100644 --- a/src/optimizer/rule/normalization/simplification.rs +++ b/src/optimizer/rule/normalization/simplification.rs @@ -306,6 +306,29 @@ mod test { Ok(()) } + #[test] + fn test_simplify_filter_boolean_wrapped_range_comparison() -> Result<(), DatabaseError> { + let table_state = build_t1_table()?; + let expected = Some(Range::Scope { + min: Bound::Unbounded, + max: Bound::Included(DataValue::Int32(10)), + }); + + for sql in [ + "select * from t1 where (c1 > 10) = false", + "select * from t1 where (c1 > 10) != true", + "select * from t1 where not (c1 > 10)", + ] { + let plan = table_state.plan(sql)?; + assert_eq!( + plan_filter(&plan, table_state.column_id_by_name("c1"))?, + expected + ); + } + + Ok(()) + } + #[test] fn test_simplify_filter_repeating_column() -> Result<(), DatabaseError> { let table_state = build_t1_table()?; diff --git a/tests/slt/where_by_index.slt b/tests/slt/where_by_index.slt index a960bd9b..860b15d2 100644 --- a/tests/slt/where_by_index.slt +++ b/tests/slt/where_by_index.slt @@ -84,6 +84,30 @@ query IIT select * from t1 where id <= 0 and id >= 3; ---- +query IIT +select * from t1 where (id > 10) = false; +---- +0 1 2 +3 4 5 +6 7 8 +9 10 11 + +query IIT +select * from t1 where (id > 10) != true; +---- +0 1 2 +3 4 5 +6 7 8 +9 10 11 + +query IIT +select * from t1 where not (id > 10); +---- +0 1 2 +3 4 5 +6 7 8 +9 10 11 + query IIT select * from t1 where id >= 3 or id <= 9 limit 10; ---- diff --git a/tests/slt/where_by_index_explain.slt b/tests/slt/where_by_index_explain.slt index f31f5f6b..87785219 100644 --- a/tests/slt/where_by_index_explain.slt +++ b/tests/slt/where_by_index_explain.slt @@ -69,6 +69,21 @@ explain select * from t1 where id <= 0 and id >= 3; ---- Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter ((t1.id <= 0) && (t1.id >= 3)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By pk_index => Dummy => (Sort Option: OrderBy: (t1.id Asc Nulls Last) ignore_prefix_len: 0)] +query T +explain select * from t1 where (id > 10) = false; +---- +Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter (t1.id <= 10), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By pk_index => (-inf, 10] => (Sort Option: OrderBy: (t1.id Asc Nulls Last) ignore_prefix_len: 0)] + +query T +explain select * from t1 where (id > 10) != true; +---- +Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter (t1.id <= 10), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By pk_index => (-inf, 10] => (Sort Option: OrderBy: (t1.id Asc Nulls Last) ignore_prefix_len: 0)] + +query T +explain select * from t1 where not (id > 10); +---- +Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter (t1.id <= 10), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By pk_index => (-inf, 10] => (Sort Option: OrderBy: (t1.id Asc Nulls Last) ignore_prefix_len: 0)] + query T explain select * from t1 where id >= 3 or id <= 9 limit 10; ----