From 1b9508031b0a38fa482759a94ca74f75f29cc60e Mon Sep 17 00:00:00 2001 From: Scott Donnelly Date: Thu, 4 Apr 2024 20:49:02 +0100 Subject: [PATCH] feat: add BoundPredicateVisitor. Add AlwaysTrue and AlwaysFalse to Predicate --- crates/iceberg/src/expr/mod.rs | 7 +- crates/iceberg/src/expr/predicate.rs | 48 +++ .../expr/visitors/bound_predicate_visitor.rs | 363 ++++++++++++++++++ crates/iceberg/src/expr/visitors/mod.rs | 18 + 4 files changed, 432 insertions(+), 4 deletions(-) create mode 100644 crates/iceberg/src/expr/visitors/bound_predicate_visitor.rs create mode 100644 crates/iceberg/src/expr/visitors/mod.rs diff --git a/crates/iceberg/src/expr/mod.rs b/crates/iceberg/src/expr/mod.rs index 0d329682..5fcc3b5e 100644 --- a/crates/iceberg/src/expr/mod.rs +++ b/crates/iceberg/src/expr/mod.rs @@ -18,14 +18,13 @@ //! This module contains expressions. mod term; - -use std::fmt::{Display, Formatter}; - pub use term::*; mod predicate; +pub(crate) mod visitors; +pub use predicate::*; use crate::spec::SchemaRef; -pub use predicate::*; +use std::fmt::{Display, Formatter}; /// Predicate operators used in expressions. /// diff --git a/crates/iceberg/src/expr/predicate.rs b/crates/iceberg/src/expr/predicate.rs index da8a863d..0e6c52a0 100644 --- a/crates/iceberg/src/expr/predicate.rs +++ b/crates/iceberg/src/expr/predicate.rs @@ -116,6 +116,14 @@ impl UnaryExpression { debug_assert!(op.is_unary()); Self { op, term } } + + pub(crate) fn term(&self) -> &T { + &self.term + } + + pub(crate) fn op(&self) -> &PredicateOperator { + &self.op + } } /// Binary predicate, for example, `a > 10`. @@ -144,6 +152,18 @@ impl BinaryExpression { debug_assert!(op.is_binary()); Self { op, term, literal } } + + pub(crate) fn term(&self) -> &T { + &self.term + } + + pub(crate) fn op(&self) -> &PredicateOperator { + &self.op + } + + pub(crate) fn literal(&self) -> &Datum { + &self.literal + } } impl Display for BinaryExpression { @@ -191,6 +211,18 @@ impl SetExpression { debug_assert!(op.is_set()); Self { op, term, literals } } + + pub(crate) fn term(&self) -> &T { + &self.term + } + + pub(crate) fn op(&self) -> &PredicateOperator { + &self.op + } + + pub(crate) fn literals(&self) -> &FnvHashSet { + &self.literals + } } impl Bind for SetExpression { @@ -217,6 +249,10 @@ impl Display for SetExpression { /// Unbound predicate expression before binding to a schema. #[derive(Debug, PartialEq)] pub enum Predicate { + /// AlwaysTrue predicate, for example, `TRUE`. + AlwaysTrue, + /// AlwaysFalse predicate, for example, `FALSE`. + AlwaysFalse, /// And predicate, for example, `a > 10 AND b < 20`. And(LogicalExpression), /// Or predicate, for example, `a > 10 OR b < 20`. @@ -367,6 +403,8 @@ impl Bind for Predicate { bound_literals, ))) } + Predicate::AlwaysTrue => Ok(BoundPredicate::AlwaysTrue), + Predicate::AlwaysFalse => Ok(BoundPredicate::AlwaysFalse), } } } @@ -374,6 +412,12 @@ impl Bind for Predicate { impl Display for Predicate { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { + Predicate::AlwaysTrue => { + write!(f, "TRUE") + } + Predicate::AlwaysFalse => { + write!(f, "FALSE") + } Predicate::And(expr) => { write!(f, "({}) AND ({})", expr.inputs()[0], expr.inputs()[1]) } @@ -461,6 +505,8 @@ impl Predicate { /// ``` pub fn negate(self) -> Predicate { match self { + Predicate::AlwaysTrue => Predicate::AlwaysFalse, + Predicate::AlwaysFalse => Predicate::AlwaysTrue, Predicate::And(expr) => Predicate::Or(LogicalExpression::new( expr.inputs.map(|expr| Box::new(expr.negate())), )), @@ -525,6 +571,8 @@ impl Predicate { Predicate::Unary(expr) => Predicate::Unary(expr), Predicate::Binary(expr) => Predicate::Binary(expr), Predicate::Set(expr) => Predicate::Set(expr), + Predicate::AlwaysTrue => Predicate::AlwaysTrue, + Predicate::AlwaysFalse => Predicate::AlwaysFalse, } } } diff --git a/crates/iceberg/src/expr/visitors/bound_predicate_visitor.rs b/crates/iceberg/src/expr/visitors/bound_predicate_visitor.rs new file mode 100644 index 00000000..f3c11c97 --- /dev/null +++ b/crates/iceberg/src/expr/visitors/bound_predicate_visitor.rs @@ -0,0 +1,363 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::expr::{BoundPredicate, BoundReference, PredicateOperator}; +use crate::spec::Datum; +use crate::Result; +use fnv::FnvHashSet; + +pub trait BoundPredicateVisitor { + type T; + + fn visit(&mut self, predicate: &BoundPredicate) -> Result { + match predicate { + BoundPredicate::AlwaysTrue => self.always_true(), + BoundPredicate::AlwaysFalse => self.always_false(), + BoundPredicate::And(expr) => { + let [left_pred, right_pred] = expr.inputs(); + + let left_result = self.visit(left_pred)?; + let right_result = self.visit(right_pred)?; + + self.and(left_result, right_result) + } + BoundPredicate::Or(expr) => { + let [left_pred, right_pred] = expr.inputs(); + + let left_result = self.visit(left_pred)?; + let right_result = self.visit(right_pred)?; + + self.or(left_result, right_result) + } + BoundPredicate::Not(expr) => { + let [inner_pred] = expr.inputs(); + + let inner_result = self.visit(inner_pred)?; + + self.not(inner_result) + } + BoundPredicate::Unary(expr) => match expr.op() { + PredicateOperator::IsNull => self.is_null(expr.term()), + PredicateOperator::NotNull => self.not_null(expr.term()), + PredicateOperator::IsNan => self.is_nan(expr.term()), + PredicateOperator::NotNan => self.not_nan(expr.term()), + op => { + panic!("Unexpected op for unary predicate: {}", &op) + } + }, + BoundPredicate::Binary(expr) => { + let reference = expr.term(); + let literal = expr.literal(); + match expr.op() { + PredicateOperator::LessThan => self.less_than(reference, literal), + PredicateOperator::LessThanOrEq => self.less_than_or_eq(reference, literal), + PredicateOperator::GreaterThan => self.greater_than(reference, literal), + PredicateOperator::GreaterThanOrEq => { + self.greater_than_or_eq(reference, literal) + } + PredicateOperator::Eq => self.eq(reference, literal), + PredicateOperator::NotEq => self.not_eq(reference, literal), + PredicateOperator::StartsWith => self.starts_with(reference, literal), + PredicateOperator::NotStartsWith => self.not_starts_with(reference, literal), + op => { + panic!("Unexpected op for binary predicate: {}", &op) + } + } + } + BoundPredicate::Set(expr) => { + let reference = expr.term(); + let literals = expr.literals(); + match expr.op() { + PredicateOperator::In => self.r#in(reference, literals), + PredicateOperator::NotIn => self.not_in(reference, literals), + op => { + panic!("Unexpected op for set predicate: {}", &op) + } + } + } + } + } + + fn always_true(&mut self) -> Result; + fn always_false(&mut self) -> Result; + + fn and(&mut self, lhs: Self::T, rhs: Self::T) -> Result; + fn or(&mut self, lhs: Self::T, rhs: Self::T) -> Result; + fn not(&mut self, inner: Self::T) -> Result; + + fn is_null(&mut self, reference: &BoundReference) -> Result; + fn not_null(&mut self, reference: &BoundReference) -> Result; + fn is_nan(&mut self, reference: &BoundReference) -> Result; + fn not_nan(&mut self, reference: &BoundReference) -> Result; + + fn less_than(&mut self, reference: &BoundReference, literal: &Datum) -> Result; + fn less_than_or_eq(&mut self, reference: &BoundReference, literal: &Datum) -> Result; + fn greater_than(&mut self, reference: &BoundReference, literal: &Datum) -> Result; + fn greater_than_or_eq( + &mut self, + reference: &BoundReference, + literal: &Datum, + ) -> Result; + fn eq(&mut self, reference: &BoundReference, literal: &Datum) -> Result; + fn not_eq(&mut self, reference: &BoundReference, literal: &Datum) -> Result; + fn starts_with(&mut self, reference: &BoundReference, literal: &Datum) -> Result; + fn not_starts_with(&mut self, reference: &BoundReference, literal: &Datum) -> Result; + + fn r#in(&mut self, reference: &BoundReference, literals: &FnvHashSet) + -> Result; + fn not_in( + &mut self, + reference: &BoundReference, + literals: &FnvHashSet, + ) -> Result; +} + +#[cfg(test)] +mod tests { + use crate::expr::visitors::bound_predicate_visitor::BoundPredicateVisitor; + use crate::expr::Predicate::{AlwaysFalse, AlwaysTrue}; + use crate::expr::{Bind, BoundReference, Predicate}; + use crate::spec::{Datum, Schema, SchemaRef}; + use fnv::FnvHashSet; + use std::ops::Not; + use std::sync::Arc; + + struct TestEvaluator {} + impl BoundPredicateVisitor for TestEvaluator { + type T = bool; + + fn always_true(&mut self) -> crate::Result { + Ok(true) + } + + fn always_false(&mut self) -> crate::Result { + Ok(false) + } + + fn and(&mut self, lhs: Self::T, rhs: Self::T) -> crate::Result { + Ok(lhs && rhs) + } + + fn or(&mut self, lhs: Self::T, rhs: Self::T) -> crate::Result { + Ok(lhs || rhs) + } + + fn not(&mut self, inner: Self::T) -> crate::Result { + Ok(!inner) + } + + fn is_null(&mut self, _reference: &BoundReference) -> crate::Result { + Ok(true) + } + + fn not_null(&mut self, _reference: &BoundReference) -> crate::Result { + Ok(false) + } + + fn is_nan(&mut self, _reference: &BoundReference) -> crate::Result { + Ok(true) + } + + fn not_nan(&mut self, _reference: &BoundReference) -> crate::Result { + Ok(false) + } + + fn less_than( + &mut self, + _reference: &BoundReference, + _literal: &Datum, + ) -> crate::Result { + Ok(true) + } + + fn less_than_or_eq( + &mut self, + _reference: &BoundReference, + _literal: &Datum, + ) -> crate::Result { + Ok(false) + } + + fn greater_than( + &mut self, + _reference: &BoundReference, + _literal: &Datum, + ) -> crate::Result { + Ok(true) + } + + fn greater_than_or_eq( + &mut self, + _reference: &BoundReference, + _literal: &Datum, + ) -> crate::Result { + Ok(false) + } + + fn eq(&mut self, _reference: &BoundReference, _literal: &Datum) -> crate::Result { + Ok(true) + } + + fn not_eq(&mut self, _reference: &BoundReference, _literal: &Datum) -> crate::Result { + Ok(false) + } + + fn starts_with( + &mut self, + _reference: &BoundReference, + _literal: &Datum, + ) -> crate::Result { + Ok(true) + } + + fn not_starts_with( + &mut self, + _reference: &BoundReference, + _literal: &Datum, + ) -> crate::Result { + Ok(false) + } + + fn r#in( + &mut self, + _reference: &BoundReference, + _literals: &FnvHashSet, + ) -> crate::Result { + Ok(true) + } + + fn not_in( + &mut self, + _reference: &BoundReference, + _literals: &FnvHashSet, + ) -> crate::Result { + Ok(false) + } + } + + fn create_test_schema() -> SchemaRef { + let schema = Schema::builder().build().unwrap(); + + let schema_arc = Arc::new(schema); + schema_arc.clone() + } + + #[test] + fn test_default_default_always_true() { + let predicate = Predicate::AlwaysTrue; + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = test_evaluator.visit(&bound_predicate); + + assert!(result.unwrap()); + } + + #[test] + fn test_default_default_always_false() { + let predicate = Predicate::AlwaysFalse; + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = test_evaluator.visit(&bound_predicate); + + assert!(!result.unwrap()); + } + + #[test] + fn test_default_default_logical_and() { + let predicate = AlwaysTrue.and(AlwaysFalse); + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = test_evaluator.visit(&bound_predicate); + + assert!(!result.unwrap()); + + let predicate = AlwaysFalse.and(AlwaysFalse); + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = test_evaluator.visit(&bound_predicate); + + assert!(!result.unwrap()); + + let predicate = AlwaysTrue.and(AlwaysTrue); + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = test_evaluator.visit(&bound_predicate); + + assert!(result.unwrap()); + } + + #[test] + fn test_default_default_logical_or() { + let predicate = AlwaysTrue.or(AlwaysFalse); + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = test_evaluator.visit(&bound_predicate); + + assert!(result.unwrap()); + + let predicate = AlwaysFalse.or(AlwaysFalse); + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = test_evaluator.visit(&bound_predicate); + + assert!(!result.unwrap()); + + let predicate = AlwaysTrue.or(AlwaysTrue); + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = test_evaluator.visit(&bound_predicate); + + assert!(result.unwrap()); + } + + #[test] + fn test_default_default_not() { + let predicate = AlwaysFalse.not(); + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = test_evaluator.visit(&bound_predicate); + + assert!(result.unwrap()); + + let predicate = AlwaysTrue.not(); + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = test_evaluator.visit(&bound_predicate); + + assert!(!result.unwrap()); + } +} diff --git a/crates/iceberg/src/expr/visitors/mod.rs b/crates/iceberg/src/expr/visitors/mod.rs new file mode 100644 index 00000000..242a55c4 --- /dev/null +++ b/crates/iceberg/src/expr/visitors/mod.rs @@ -0,0 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub(crate) mod bound_predicate_visitor;