Skip to content

Commit

Permalink
feat: add BoundPredicateVisitor. Add AlwaysTrue and AlwaysFalse to Pr…
Browse files Browse the repository at this point in the history
…edicate
  • Loading branch information
sdd committed Apr 6, 2024
1 parent 301a0af commit 9ae2e80
Show file tree
Hide file tree
Showing 4 changed files with 329 additions and 4 deletions.
7 changes: 3 additions & 4 deletions crates/iceberg/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,13 @@
//! This module contains expressions.

mod term;

use std::fmt::{Display, Formatter};

pub use term::*;
mod predicate;
pub(crate) mod visitors;
pub use predicate::*;

use crate::spec::SchemaRef;
pub use predicate::*;
use std::fmt::{Display, Formatter};

/// Predicate operators used in expressions.
///
Expand Down
48 changes: 48 additions & 0 deletions crates/iceberg/src/expr/predicate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,14 @@ impl<T> UnaryExpression<T> {
debug_assert!(op.is_unary());
Self { op, term }
}

pub(crate) fn term(&self) -> &T {
&self.term
}

pub(crate) fn op(&self) -> &PredicateOperator {
&self.op
}
}

/// Binary predicate, for example, `a > 10`.
Expand Down Expand Up @@ -144,6 +152,18 @@ impl<T> BinaryExpression<T> {
debug_assert!(op.is_binary());
Self { op, term, literal }
}

pub(crate) fn term(&self) -> &T {
&self.term
}

pub(crate) fn op(&self) -> &PredicateOperator {
&self.op
}

pub(crate) fn literal(&self) -> &Datum {
&self.literal
}
}

impl<T: Display> Display for BinaryExpression<T> {
Expand Down Expand Up @@ -191,6 +211,18 @@ impl<T> SetExpression<T> {
debug_assert!(op.is_set());
Self { op, term, literals }
}

pub(crate) fn term(&self) -> &T {
&self.term
}

pub(crate) fn op(&self) -> &PredicateOperator {
&self.op
}

pub(crate) fn literals(&self) -> &FnvHashSet<Datum> {
&self.literals
}
}

impl<T: Bind> Bind for SetExpression<T> {
Expand All @@ -217,6 +249,10 @@ impl<T: Display + Debug> Display for SetExpression<T> {
/// Unbound predicate expression before binding to a schema.
#[derive(Debug, PartialEq)]
pub enum Predicate {
/// AlwaysTrue predicate, for example, `TRUE`.
AlwaysTrue,
/// AlwaysFalse predicate, for example, `FALSE`.
AlwaysFalse,
/// And predicate, for example, `a > 10 AND b < 20`.
And(LogicalExpression<Predicate, 2>),
/// Or predicate, for example, `a > 10 OR b < 20`.
Expand Down Expand Up @@ -367,13 +403,21 @@ impl Bind for Predicate {
bound_literals,
)))
}
Predicate::AlwaysTrue => Ok(BoundPredicate::AlwaysTrue),
Predicate::AlwaysFalse => Ok(BoundPredicate::AlwaysFalse),
}
}
}

impl Display for Predicate {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Predicate::AlwaysTrue => {
write!(f, "TRUE")
}
Predicate::AlwaysFalse => {
write!(f, "FALSE")
}
Predicate::And(expr) => {
write!(f, "({}) AND ({})", expr.inputs()[0], expr.inputs()[1])
}
Expand Down Expand Up @@ -461,6 +505,8 @@ impl Predicate {
/// ```
pub fn negate(self) -> Predicate {
match self {
Predicate::AlwaysTrue => Predicate::AlwaysFalse,
Predicate::AlwaysFalse => Predicate::AlwaysTrue,
Predicate::And(expr) => Predicate::Or(LogicalExpression::new(
expr.inputs.map(|expr| Box::new(expr.negate())),
)),
Expand Down Expand Up @@ -525,6 +571,8 @@ impl Predicate {
Predicate::Unary(expr) => Predicate::Unary(expr),
Predicate::Binary(expr) => Predicate::Binary(expr),
Predicate::Set(expr) => Predicate::Set(expr),
Predicate::AlwaysTrue => Predicate::AlwaysTrue,
Predicate::AlwaysFalse => Predicate::AlwaysFalse,
}
}
}
Expand Down
260 changes: 260 additions & 0 deletions crates/iceberg/src/expr/visitors/bound_predicate_visitor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,260 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::expr::{BoundPredicate, BoundReference, PredicateOperator};
use crate::spec::Datum;
use crate::Result;
use fnv::FnvHashSet;

pub enum OpLiteral<'a> {
None,
Single(&'a Datum),
Set(&'a FnvHashSet<Datum>),
}

pub trait BoundPredicateVisitor {
type T;

fn always_true(&mut self) -> Result<Self::T>;
fn always_false(&mut self) -> Result<Self::T>;

fn and(&mut self, lhs: Self::T, rhs: Self::T) -> Result<Self::T>;
fn or(&mut self, lhs: Self::T, rhs: Self::T) -> Result<Self::T>;
fn not(&mut self, inner: Self::T) -> Result<Self::T>;

fn op(
&mut self,
op: &PredicateOperator,
reference: Option<&BoundReference>,
literal: OpLiteral,
) -> Result<Self::T>;
}

pub(crate) fn visit<V: BoundPredicateVisitor>(
visitor: &mut V,
predicate: &BoundPredicate,
) -> Result<V::T> {
match predicate {
BoundPredicate::AlwaysTrue => visitor.always_true(),
BoundPredicate::AlwaysFalse => visitor.always_false(),
BoundPredicate::And(expr) => {
let [left_pred, right_pred] = expr.inputs();

let left_result = visit(visitor, left_pred)?;
let right_result = visit(visitor, right_pred)?;

visitor.and(left_result, right_result)
}
BoundPredicate::Or(expr) => {
let [left_pred, right_pred] = expr.inputs();

let left_result = visit(visitor, left_pred)?;
let right_result = visit(visitor, right_pred)?;

visitor.or(left_result, right_result)
}
BoundPredicate::Not(expr) => {
let [inner_pred] = expr.inputs();

let inner_result = visit(visitor, inner_pred)?;

visitor.not(inner_result)
}
BoundPredicate::Unary(expr) => visitor.op(expr.op(), None, OpLiteral::None),
BoundPredicate::Binary(expr) => visitor.op(
expr.op(),
Some(expr.term()),
OpLiteral::Single(expr.literal()),
),
BoundPredicate::Set(expr) => visitor.op(
expr.op(),
Some(expr.term()),
OpLiteral::Set(expr.literals()),
),
}
}

#[cfg(test)]
mod tests {
use crate::expr::visitors::bound_predicate_visitor::{visit, BoundPredicateVisitor, OpLiteral};
use crate::expr::Predicate::{AlwaysFalse, AlwaysTrue};
use crate::expr::{Bind, BoundReference, Predicate, PredicateOperator};
use crate::spec::{Schema, SchemaRef};
use std::ops::Not;
use std::sync::Arc;

struct TestEvaluator {}
impl BoundPredicateVisitor for TestEvaluator {
type T = bool;

fn always_true(&mut self) -> crate::Result<Self::T> {
Ok(true)
}

fn always_false(&mut self) -> crate::Result<Self::T> {
Ok(false)
}

fn and(&mut self, lhs: Self::T, rhs: Self::T) -> crate::Result<Self::T> {
Ok(lhs && rhs)
}

fn or(&mut self, lhs: Self::T, rhs: Self::T) -> crate::Result<Self::T> {
Ok(lhs || rhs)
}

fn not(&mut self, inner: Self::T) -> crate::Result<Self::T> {
Ok(!inner)
}

fn op(
&mut self,
op: &PredicateOperator,
_reference: Option<&BoundReference>,
_literal: OpLiteral,
) -> crate::Result<Self::T> {
Ok(match op {
PredicateOperator::IsNull => true,
PredicateOperator::NotNull => false,
PredicateOperator::IsNan => true,
PredicateOperator::NotNan => false,
PredicateOperator::LessThan => true,
PredicateOperator::LessThanOrEq => false,
PredicateOperator::GreaterThan => true,
PredicateOperator::GreaterThanOrEq => false,
PredicateOperator::Eq => true,
PredicateOperator::NotEq => false,
PredicateOperator::StartsWith => true,
PredicateOperator::NotStartsWith => false,
PredicateOperator::In => false,
PredicateOperator::NotIn => true,
})
}
}

fn create_test_schema() -> SchemaRef {
let schema = Schema::builder().build().unwrap();

let schema_arc = Arc::new(schema);
schema_arc.clone()
}

#[test]
fn test_default_default_always_true() {
let predicate = Predicate::AlwaysTrue;
let bound_predicate = predicate.bind(create_test_schema(), false).unwrap();

let mut test_evaluator = TestEvaluator {};

let result = visit(&mut test_evaluator, &bound_predicate);

assert!(result.unwrap());
}

#[test]
fn test_default_default_always_false() {
let predicate = Predicate::AlwaysFalse;
let bound_predicate = predicate.bind(create_test_schema(), false).unwrap();

let mut test_evaluator = TestEvaluator {};

let result = visit(&mut test_evaluator, &bound_predicate);

assert!(!result.unwrap());
}

#[test]
fn test_default_default_logical_and() {
let predicate = AlwaysTrue.and(AlwaysFalse);
let bound_predicate = predicate.bind(create_test_schema(), false).unwrap();

let mut test_evaluator = TestEvaluator {};

let result = visit(&mut test_evaluator, &bound_predicate);

assert!(!result.unwrap());

let predicate = AlwaysFalse.and(AlwaysFalse);
let bound_predicate = predicate.bind(create_test_schema(), false).unwrap();

let mut test_evaluator = TestEvaluator {};

let result = visit(&mut test_evaluator, &bound_predicate);

assert!(!result.unwrap());

let predicate = AlwaysTrue.and(AlwaysTrue);
let bound_predicate = predicate.bind(create_test_schema(), false).unwrap();

let mut test_evaluator = TestEvaluator {};

let result = visit(&mut test_evaluator, &bound_predicate);

assert!(result.unwrap());
}

#[test]
fn test_default_default_logical_or() {
let predicate = AlwaysTrue.or(AlwaysFalse);
let bound_predicate = predicate.bind(create_test_schema(), false).unwrap();

let mut test_evaluator = TestEvaluator {};

let result = visit(&mut test_evaluator, &bound_predicate);

assert!(result.unwrap());

let predicate = AlwaysFalse.or(AlwaysFalse);
let bound_predicate = predicate.bind(create_test_schema(), false).unwrap();

let mut test_evaluator = TestEvaluator {};

let result = visit(&mut test_evaluator, &bound_predicate);

assert!(!result.unwrap());

let predicate = AlwaysTrue.or(AlwaysTrue);
let bound_predicate = predicate.bind(create_test_schema(), false).unwrap();

let mut test_evaluator = TestEvaluator {};

let result = visit(&mut test_evaluator, &bound_predicate);

assert!(result.unwrap());
}

#[test]
fn test_default_default_not() {
let predicate = AlwaysFalse.not();
let bound_predicate = predicate.bind(create_test_schema(), false).unwrap();

let mut test_evaluator = TestEvaluator {};

let result = visit(&mut test_evaluator, &bound_predicate);

assert!(result.unwrap());

let predicate = AlwaysTrue.not();
let bound_predicate = predicate.bind(create_test_schema(), false).unwrap();

let mut test_evaluator = TestEvaluator {};

let result = visit(&mut test_evaluator, &bound_predicate);

assert!(!result.unwrap());
}
}

0 comments on commit 9ae2e80

Please sign in to comment.