Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 58 additions & 17 deletions src/expression/range_detacher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,16 +221,11 @@ impl<'a> RangeDetacher<'a> {

None
}
(Some(binary), None) | (None, Some(binary)) => self.check_or(op, binary),
(Some(binary), None) | (None, Some(binary)) => self.check_and(op, binary),
},
ScalarExpression::Alias { expr, .. }
| ScalarExpression::TypeCast { expr, .. }
| ScalarExpression::Unary { expr, .. }
| ScalarExpression::In { expr, .. }
| ScalarExpression::Between { expr, .. }
| ScalarExpression::SubString { expr, .. } => self.detach(expr)?,
ScalarExpression::Position { expr, .. } => self.detach(expr)?,
ScalarExpression::Trim { expr, .. } => self.detach(expr)?,
ScalarExpression::Alias { expr, .. } | ScalarExpression::TypeCast { expr, .. } => {
self.detach(expr)?
}
ScalarExpression::IsNull { expr, negated, .. } => match expr.as_ref() {
ScalarExpression::ColumnRef { column, .. } => {
if let (Some(col_id), Some(col_table)) = (column.id(), column.table_name()) {
Expand Down Expand Up @@ -263,14 +258,20 @@ impl<'a> RangeDetacher<'a> {
| ScalarExpression::IfNull { .. }
| ScalarExpression::NullIf { .. }
| ScalarExpression::Coalesce { .. }
| ScalarExpression::CaseWhen { .. } => self.detach(expr)?,
| ScalarExpression::CaseWhen { .. } => None,
ScalarExpression::Tuple(_)
| ScalarExpression::TableFunction(_)
| ScalarExpression::Empty => unreachable!(),
},
ScalarExpression::Constant(_) | ScalarExpression::ColumnRef { .. } => None,
// FIXME: support [RangeDetacher::_detach]
ScalarExpression::Tuple(_)
ScalarExpression::Unary { .. }
| ScalarExpression::In { .. }
| ScalarExpression::Between { .. }
| ScalarExpression::SubString { .. }
| ScalarExpression::Position { .. }
| ScalarExpression::Trim { .. }
| ScalarExpression::Tuple(_)
| ScalarExpression::AggCall { .. }
| ScalarExpression::ScalaFunction(_)
| ScalarExpression::If { .. }
Expand Down Expand Up @@ -759,14 +760,13 @@ impl<'a> RangeDetacher<'a> {
})
}

/// check if: `c1 > c2 or c1 > 1` or `c2 > 1 or c1 > 1`
/// this case it makes no sense to just extract c1 > 1
fn check_or(&mut self, op: &BinaryOperator, binary: Range) -> Option<Range> {
if matches!(op, BinaryOperator::Or) {
return None;
/// Only conjunction can safely keep a range detached from one side of a binary expression.
fn check_and(&mut self, op: &BinaryOperator, binary: Range) -> Option<Range> {
if matches!(op, BinaryOperator::And) {
return Some(binary);
}

Some(binary)
None
}
}

Expand Down Expand Up @@ -1312,6 +1312,47 @@ mod test {
Ok(())
}

/// Checks which predicate shapes still yield a detached range on `t1.c1`:
/// a conjunction keeps the range found on one side, boolean-wrapped
/// comparisons are expected to detach as the negated comparison, and
/// predicates that are not a plain comparison on `c1` detach nothing.
#[test]
fn test_detach_only_conjunction_can_keep_partial_range() -> Result<(), DatabaseError> {
    let table_state = build_t1_table()?;
    // Plans `sql`, takes its filter operator and detaches the range for `c1`.
    let detach_c1 = |sql: &str| -> Result<Option<Range>, DatabaseError> {
        let plan = table_state.plan(sql)?;
        let op = plan_filter(plan)?.unwrap();
        RangeDetacher::new("t1", table_state.column_id_by_name("c1")).detach(&op.predicate)
    };

    // `c2 = 1 and c1 > 10`: the `c1 > 10` side survives the conjunction.
    let above_ten = Some(Range::Scope {
        min: Bound::Excluded(DataValue::Int32(10)),
        max: Bound::Unbounded,
    });
    assert_eq!(
        detach_c1("select * from t1 where c2 = 1 and c1 > 10")?,
        above_ten
    );

    // Every spelling of "not (c1 > 10)" must detach as `c1 <= 10`.
    let negated_range = Some(Range::Scope {
        min: Bound::Unbounded,
        max: Bound::Included(DataValue::Int32(10)),
    });
    for sql in [
        "select * from t1 where (c1 > 10) = false",
        "select * from t1 where (c1 > 10) != true",
        "select * from t1 where not (c1 > 10)",
    ] {
        assert_eq!(detach_c1(sql)?, negated_range);
    }

    // Comparing two boolean expressions, or testing one for NULL, gives no range.
    for sql in [
        "select * from t1 where (c1 > 10) = (c2 > 0)",
        "select * from t1 where (c1 > 10) is null",
    ] {
        assert_eq!(detach_c1(sql)?, None);
    }

    Ok(())
}

// Tips: `null` should be First
#[test]
fn test_detach_null_cases() -> Result<(), DatabaseError> {
Expand Down
113 changes: 105 additions & 8 deletions src/expression/simplify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,20 +121,35 @@ impl VisitorMut<'_> for Simplify {
ScalarExpression::Unary {
op,
expr: arg_expr,
evaluator,
ty,
..
} => {
let op = *op;
let ty = ty.clone();
let arg_expr = arg_expr.as_ref().clone();
if let Some(value) = expr.unpack_val() {
let child_expr = arg_expr.as_ref().clone();
let value = if let Some(value) = arg_expr.unpack_val() {
Some(if let Some(evaluator) = evaluator {
evaluator.0.unary_eval(&value)
} else {
unary_create(Cow::Borrowed(&ty), op)?.0.unary_eval(&value)
})
} else {
None
};

if let Some(value) = value {
let _ = mem::replace(expr, ScalarExpression::Constant(value));
} else if matches!(op, UnaryOperator::Not) {
if let Some(new_expr) = Self::take_negated_range_comparison(arg_expr) {
let _ = mem::replace(expr, new_expr);
self.visit(expr)?;
} else {
self.replaces
.push(Replace::Unary(ReplaceUnary { child_expr, op, ty }));
}
} else {
self.replaces.push(Replace::Unary(ReplaceUnary {
child_expr: arg_expr,
op,
ty,
}));
self.replaces
.push(Replace::Unary(ReplaceUnary { child_expr, op, ty }));
}
}
ScalarExpression::Binary {
Expand All @@ -149,6 +164,14 @@ impl VisitorMut<'_> for Simplify {
// `(c1 - 1) and (c1 + 2)` cannot fix!
self.fix_expr(right_expr, left_expr, op)?;

if let Some(new_expr) =
Self::take_bool_normalized_range_comparison(*op, left_expr, right_expr)
{
let _ = mem::replace(expr, new_expr);
self.visit(expr)?;
return Ok(());
}

if Self::is_arithmetic(op) {
match (
left_expr.unpack_bound_col(false),
Expand Down Expand Up @@ -317,6 +340,80 @@ impl Simplify {
)
}

/// Returns the logical complement of an ordering comparison operator.
///
/// `>` and `<=` are swapped, as are `>=` and `<`. Every other operator has no
/// complement here and yields `None`.
fn negate_range_comparison(op: BinaryOperator) -> Option<BinaryOperator> {
    let complement = match op {
        BinaryOperator::Gt => BinaryOperator::LtEq,
        BinaryOperator::GtEq => BinaryOperator::Lt,
        BinaryOperator::Lt => BinaryOperator::GtEq,
        BinaryOperator::LtEq => BinaryOperator::Gt,
        _ => return None,
    };

    Some(complement)
}

/// Moves `expr` out when it is a binary ordering comparison (`>`, `>=`, `<`, `<=`).
///
/// On success the slot is left holding `ScalarExpression::Empty` and the
/// original expression is returned unchanged. Any other expression is left in
/// place and `None` is returned.
fn take_range_comparison(expr: &mut Box<ScalarExpression>) -> Option<ScalarExpression> {
    // An operator counts as a range comparison exactly when it can be negated.
    let is_range_comparison = matches!(
        expr.as_ref(),
        ScalarExpression::Binary { op, .. } if Self::negate_range_comparison(*op).is_some()
    );
    if !is_range_comparison {
        return None;
    }

    Some(mem::replace(expr.as_mut(), ScalarExpression::Empty))
}

/// Moves `expr` out with its ordering comparison operator replaced by the
/// complement (for example `a > b` becomes `a <= b`).
///
/// On success the slot is left holding `ScalarExpression::Empty`. When `expr`
/// is not a binary expression, or its operator has no complement, nothing is
/// modified and `None` is returned.
///
/// NOTE(review): only `op` is rewritten; if `Binary` carries other state
/// derived from the operator it is left as-is — confirm against the variant.
fn take_negated_range_comparison(expr: &mut Box<ScalarExpression>) -> Option<ScalarExpression> {
    if let ScalarExpression::Binary { op, .. } = expr.as_mut() {
        // Bail out before touching `op` when the operator cannot be negated.
        let flipped = Self::negate_range_comparison(*op)?;
        *op = flipped;
        return Some(mem::replace(expr.as_mut(), ScalarExpression::Empty));
    }

    None
}

/// Extracts the value of a boolean literal; any other expression gives `None`.
fn boolean_constant(expr: &ScalarExpression) -> Option<bool> {
    if let ScalarExpression::Constant(DataValue::Boolean(flag)) = expr {
        Some(*flag)
    } else {
        None
    }
}

/// Takes the range comparison out of `expr`, as written when `positive` is
/// `true` and negated when it is `false`.
fn take_range_comparison_with_polarity(
    expr: &mut Box<ScalarExpression>,
    positive: bool,
) -> Option<ScalarExpression> {
    if !positive {
        return Self::take_negated_range_comparison(expr);
    }

    Self::take_range_comparison(expr)
}

/// Unwraps a range comparison that is compared against a boolean literal.
///
/// Handles `<cmp> = <bool>`, `<cmp> != <bool>` and the mirrored forms with the
/// literal on the left. The comparison is returned as written when the wrapper
/// asserts it is true, and negated when the wrapper asserts it is false.
///
/// Returns `None` when `op` is neither `=` nor `!=`, when neither operand is a
/// boolean literal, or when the non-literal operand is not a range comparison.
/// A literal on the right takes priority over one on the left.
fn take_bool_normalized_range_comparison(
    op: BinaryOperator,
    left_expr: &mut Box<ScalarExpression>,
    right_expr: &mut Box<ScalarExpression>,
) -> Option<ScalarExpression> {
    let is_eq = match op {
        BinaryOperator::Eq => true,
        BinaryOperator::NotEq => false,
        _ => return None,
    };

    // Pick the operand to unwrap together with the literal it is compared to.
    let (target, literal) = match (
        Self::boolean_constant(left_expr),
        Self::boolean_constant(right_expr),
    ) {
        (_, Some(literal)) => (left_expr, literal),
        (Some(literal), None) => (right_expr, literal),
        (None, None) => return None,
    };

    // `= true` / `!= false` keep the comparison; `= false` / `!= true` negate it.
    Self::take_range_comparison_with_polarity(target, literal == is_eq)
}

fn fix_expr(
&mut self,
left_expr: &mut Box<ScalarExpression>,
Expand Down
23 changes: 23 additions & 0 deletions src/optimizer/rule/normalization/simplification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,29 @@ mod test {
Ok(())
}

/// `(c1 > 10) = false`, `(c1 > 10) != true` and `not (c1 > 10)` must all be
/// planned with the same detached range on `c1`: everything up to and
/// including 10.
#[test]
fn test_simplify_filter_boolean_wrapped_range_comparison() -> Result<(), DatabaseError> {
    let table_state = build_t1_table()?;
    let at_most_ten = Some(Range::Scope {
        min: Bound::Unbounded,
        max: Bound::Included(DataValue::Int32(10)),
    });
    let queries = [
        "select * from t1 where (c1 > 10) = false",
        "select * from t1 where (c1 > 10) != true",
        "select * from t1 where not (c1 > 10)",
    ];

    for sql in queries {
        let plan = table_state.plan(sql)?;
        let detached = plan_filter(&plan, table_state.column_id_by_name("c1"))?;
        assert_eq!(detached, at_most_ten);
    }

    Ok(())
}

#[test]
fn test_simplify_filter_repeating_column() -> Result<(), DatabaseError> {
let table_state = build_t1_table()?;
Expand Down
24 changes: 24 additions & 0 deletions tests/slt/where_by_index.slt
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,30 @@ query IIT
select * from t1 where id <= 0 and id >= 3;
----

query IIT
select * from t1 where (id > 10) = false;
----
0 1 2
3 4 5
6 7 8
9 10 11

query IIT
select * from t1 where (id > 10) != true;
----
0 1 2
3 4 5
6 7 8
9 10 11

query IIT
select * from t1 where not (id > 10);
----
0 1 2
3 4 5
6 7 8
9 10 11

query IIT
select * from t1 where id >= 3 or id <= 9 limit 10;
----
Expand Down
15 changes: 15 additions & 0 deletions tests/slt/where_by_index_explain.slt
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,21 @@ explain select * from t1 where id <= 0 and id >= 3;
----
Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter ((t1.id <= 0) && (t1.id >= 3)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By pk_index => Dummy => (Sort Option: OrderBy: (t1.id Asc Nulls Last) ignore_prefix_len: 0)]

query T
explain select * from t1 where (id > 10) = false;
----
Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter (t1.id <= 10), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By pk_index => (-inf, 10] => (Sort Option: OrderBy: (t1.id Asc Nulls Last) ignore_prefix_len: 0)]

query T
explain select * from t1 where (id > 10) != true;
----
Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter (t1.id <= 10), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By pk_index => (-inf, 10] => (Sort Option: OrderBy: (t1.id Asc Nulls Last) ignore_prefix_len: 0)]

query T
explain select * from t1 where not (id > 10);
----
Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter (t1.id <= 10), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By pk_index => (-inf, 10] => (Sort Option: OrderBy: (t1.id Asc Nulls Last) ignore_prefix_len: 0)]

query T
explain select * from t1 where id >= 3 or id <= 9 limit 10;
----
Expand Down
Loading