Skip to content

Commit

Permalink
ARROW-4815: [Rust] [DataFusion] Add support for SQL wilcard operator
Browse files Browse the repository at this point in the history
Closes #6716 from andygrove/ARROW-4815

Authored-by: Andy Grove <andygrove73@gmail.com>
Signed-off-by: Andy Grove <andygrove73@gmail.com>
  • Loading branch information
andygrove committed Mar 26, 2020
1 parent 25e8c2b commit 76c6424
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 52 deletions.
61 changes: 41 additions & 20 deletions rust/datafusion/src/logicalplan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,45 +232,50 @@ pub enum Expr {
/// The `DataType` the expression will yield
return_type: DataType,
},
/// Wildcard
Wildcard,
}

impl Expr {
/// Find the `DataType` for the expression
pub fn get_type(&self, schema: &Schema) -> DataType {
pub fn get_type(&self, schema: &Schema) -> Result<DataType> {
match self {
Expr::Alias(expr, _) => expr.get_type(schema),
Expr::Column(n) => schema.field(*n).data_type().clone(),
Expr::Literal(l) => l.get_datatype(),
Expr::Cast { data_type, .. } => data_type.clone(),
Expr::ScalarFunction { return_type, .. } => return_type.clone(),
Expr::AggregateFunction { return_type, .. } => return_type.clone(),
Expr::Not(_) => DataType::Boolean,
Expr::IsNull(_) => DataType::Boolean,
Expr::IsNotNull(_) => DataType::Boolean,
Expr::Column(n) => Ok(schema.field(*n).data_type().clone()),
Expr::Literal(l) => Ok(l.get_datatype()),
Expr::Cast { data_type, .. } => Ok(data_type.clone()),
Expr::ScalarFunction { return_type, .. } => Ok(return_type.clone()),
Expr::AggregateFunction { return_type, .. } => Ok(return_type.clone()),
Expr::Not(_) => Ok(DataType::Boolean),
Expr::IsNull(_) => Ok(DataType::Boolean),
Expr::IsNotNull(_) => Ok(DataType::Boolean),
Expr::BinaryExpr {
ref left,
ref right,
ref op,
} => match op {
Operator::Eq | Operator::NotEq => DataType::Boolean,
Operator::Lt | Operator::LtEq => DataType::Boolean,
Operator::Gt | Operator::GtEq => DataType::Boolean,
Operator::And | Operator::Or => DataType::Boolean,
Operator::Eq | Operator::NotEq => Ok(DataType::Boolean),
Operator::Lt | Operator::LtEq => Ok(DataType::Boolean),
Operator::Gt | Operator::GtEq => Ok(DataType::Boolean),
Operator::And | Operator::Or => Ok(DataType::Boolean),
_ => {
let left_type = left.get_type(schema);
let right_type = right.get_type(schema);
utils::get_supertype(&left_type, &right_type).unwrap()
let left_type = left.get_type(schema)?;
let right_type = right.get_type(schema)?;
utils::get_supertype(&left_type, &right_type)
}
},
Expr::Sort { ref expr, .. } => expr.get_type(schema),
Expr::Wildcard => Err(ExecutionError::General(
"Wildcard expressions are not valid in a logical query plan".to_owned(),
)),
}
}

/// Perform a type cast on the expression value.
///
/// Will `Err` if the type cast cannot be performed.
pub fn cast_to(&self, cast_to_type: &DataType, schema: &Schema) -> Result<Expr> {
let this_type = self.get_type(schema);
let this_type = self.get_type(schema)?;
if this_type == *cast_to_type {
Ok(self.clone())
} else if can_coerce_from(cast_to_type, &this_type) {
Expand Down Expand Up @@ -414,6 +419,7 @@ impl fmt::Debug for Expr {

write!(f, ")")
}
Expr::Wildcard => write!(f, "*"),
}
}
}
Expand Down Expand Up @@ -698,12 +704,27 @@ impl LogicalPlanBuilder {
/// Apply a projection
pub fn project(&self, expr: &Vec<Expr>) -> Result<Self> {
let input_schema = self.plan.schema();
let projected_expr = if expr.contains(&Expr::Wildcard) {
let mut expr_vec = vec![];
(0..expr.len()).for_each(|i| match &expr[i] {
Expr::Wildcard => {
(0..input_schema.fields().len())
.for_each(|i| expr_vec.push(col(i).clone()));
}
_ => expr_vec.push(expr[i].clone()),
});
expr_vec
} else {
expr.clone()
};

let schema =
Schema::new(utils::exprlist_to_fields(&expr, input_schema.as_ref())?);
let schema = Schema::new(utils::exprlist_to_fields(
&projected_expr,
input_schema.as_ref(),
)?);

Ok(Self::from(&LogicalPlan::Projection {
expr: expr.clone(),
expr: projected_expr,
input: Arc::new(self.plan.clone()),
schema: Arc::new(schema),
}))
Expand Down
13 changes: 8 additions & 5 deletions rust/datafusion/src/optimizer/projection_push_down.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ impl ProjectionPushDown {
schema,
} => {
// collect all columns referenced by projection expressions
utils::exprlist_to_column_indices(&expr, accum);
utils::exprlist_to_column_indices(&expr, accum)?;

// push projection down
let input = self.optimize_plan(&input, accum, mapping)?;
Expand All @@ -74,7 +74,7 @@ impl ProjectionPushDown {
}
LogicalPlan::Selection { expr, input } => {
// collect all columns referenced by filter expression
utils::expr_to_column_indices(expr, accum);
utils::expr_to_column_indices(expr, accum)?;

// push projection down
let input = self.optimize_plan(&input, accum, mapping)?;
Expand All @@ -94,8 +94,8 @@ impl ProjectionPushDown {
schema,
} => {
// collect all columns referenced by grouping and aggregate expressions
utils::exprlist_to_column_indices(&group_expr, accum);
utils::exprlist_to_column_indices(&aggr_expr, accum);
utils::exprlist_to_column_indices(&group_expr, accum)?;
utils::exprlist_to_column_indices(&aggr_expr, accum)?;

// push projection down
let input = self.optimize_plan(&input, accum, mapping)?;
Expand All @@ -117,7 +117,7 @@ impl ProjectionPushDown {
schema,
} => {
// collect all columns referenced by sort expressions
utils::exprlist_to_column_indices(&expr, accum);
utils::exprlist_to_column_indices(&expr, accum)?;

// push projection down
let input = self.optimize_plan(&input, accum, mapping)?;
Expand Down Expand Up @@ -271,6 +271,9 @@ impl ProjectionPushDown {
args: self.rewrite_exprs(args, mapping)?,
return_type: return_type.clone(),
}),
Expr::Wildcard => Err(ExecutionError::General(
"Wildcard expressions are not valid in a logical query plan".to_owned(),
)),
}
}

Expand Down
4 changes: 2 additions & 2 deletions rust/datafusion/src/optimizer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ fn rewrite_expr(expr: &Expr, schema: &Schema) -> Result<Expr> {
Expr::BinaryExpr { left, op, right } => {
let left = rewrite_expr(left, schema)?;
let right = rewrite_expr(right, schema)?;
let left_type = left.get_type(schema);
let right_type = right.get_type(schema);
let left_type = left.get_type(schema)?;
let right_type = right.get_type(schema)?;
if left_type == right_type {
Ok(Expr::BinaryExpr {
left: Arc::new(left),
Expand Down
39 changes: 27 additions & 12 deletions rust/datafusion/src/optimizer/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,38 +26,52 @@ use crate::logicalplan::Expr;

/// Recursively walk a list of expression trees, collecting the unique set of column
/// indexes referenced in the expression
pub fn exprlist_to_column_indices(expr: &Vec<Expr>, accum: &mut HashSet<usize>) {
expr.iter().for_each(|e| expr_to_column_indices(e, accum));
pub fn exprlist_to_column_indices(
expr: &Vec<Expr>,
accum: &mut HashSet<usize>,
) -> Result<()> {
for e in expr {
expr_to_column_indices(e, accum)?;
}
Ok(())
}

/// Recursively walk an expression tree, collecting the unique set of column indexes
/// referenced in the expression
pub fn expr_to_column_indices(expr: &Expr, accum: &mut HashSet<usize>) {
pub fn expr_to_column_indices(expr: &Expr, accum: &mut HashSet<usize>) -> Result<()> {
match expr {
Expr::Alias(expr, _) => expr_to_column_indices(expr, accum),
Expr::Column(i) => {
accum.insert(*i);
Ok(())
}
Expr::Literal(_) => {
// not needed
Ok(())
}
Expr::Literal(_) => { /* not needed */ }
Expr::Not(e) => expr_to_column_indices(e, accum),
Expr::IsNull(e) => expr_to_column_indices(e, accum),
Expr::IsNotNull(e) => expr_to_column_indices(e, accum),
Expr::BinaryExpr { left, right, .. } => {
expr_to_column_indices(left, accum);
expr_to_column_indices(right, accum);
expr_to_column_indices(left, accum)?;
expr_to_column_indices(right, accum)?;
Ok(())
}
Expr::Cast { expr, .. } => expr_to_column_indices(expr, accum),
Expr::Sort { expr, .. } => expr_to_column_indices(expr, accum),
Expr::AggregateFunction { args, .. } => exprlist_to_column_indices(args, accum),
Expr::ScalarFunction { args, .. } => exprlist_to_column_indices(args, accum),
Expr::Wildcard => Err(ExecutionError::General(
"Wildcard expressions are not valid in a logical query plan".to_owned(),
)),
}
}

/// Create field meta-data from an expression, for use in a result set schema
pub fn expr_to_field(e: &Expr, input_schema: &Schema) -> Result<Field> {
match e {
Expr::Alias(expr, name) => {
Ok(Field::new(name, expr.get_type(input_schema), true))
Ok(Field::new(name, expr.get_type(input_schema)?, true))
}
Expr::Column(i) => {
let input_schema_field_count = input_schema.fields().len();
Expand Down Expand Up @@ -89,8 +103,8 @@ pub fn expr_to_field(e: &Expr, input_schema: &Schema) -> Result<Field> {
ref right,
..
} => {
let left_type = left.get_type(input_schema);
let right_type = right.get_type(input_schema);
let left_type = left.get_type(input_schema)?;
let right_type = right.get_type(input_schema)?;
Ok(Field::new(
"binary_expr",
get_supertype(&left_type, &right_type).unwrap(),
Expand Down Expand Up @@ -235,23 +249,24 @@ mod tests {
use std::sync::Arc;

#[test]
fn test_collect_expr() {
fn test_collect_expr() -> Result<()> {
let mut accum: HashSet<usize> = HashSet::new();
expr_to_column_indices(
&Expr::Cast {
expr: Arc::new(Expr::Column(3)),
data_type: DataType::Float64,
},
&mut accum,
);
)?;
expr_to_column_indices(
&Expr::Cast {
expr: Arc::new(Expr::Column(3)),
data_type: DataType::Float64,
},
&mut accum,
);
)?;
assert_eq!(1, accum.len());
assert!(accum.contains(&3));
Ok(())
}
}
33 changes: 20 additions & 13 deletions rust/datafusion/src/sql/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,9 +287,7 @@ impl<S: SchemaProvider> SqlToRel<S> {
}
}

ASTNode::SQLWildcard => {
Err(ExecutionError::NotImplemented("SQL wildcard operator is not supported in projection - please use explicit column names".to_string()))
}
ASTNode::SQLWildcard => Ok(Expr::Wildcard),

ASTNode::SQLCast {
ref expr,
Expand All @@ -307,17 +305,17 @@ impl<S: SchemaProvider> SqlToRel<S> {
Ok(Expr::IsNotNull(Arc::new(self.sql_to_rex(expr, schema)?)))
}

ASTNode::SQLUnary{
ASTNode::SQLUnary {
ref operator,
ref expr,
} => {
match *operator {
SQLOperator::Not => Ok(Expr::Not(Arc::new(self.sql_to_rex(expr, schema)?))),
_ => Err(ExecutionError::InternalError(format!(
"SQL binary operator cannot be interpreted as a unary operator"
))),
} => match *operator {
SQLOperator::Not => {
Ok(Expr::Not(Arc::new(self.sql_to_rex(expr, schema)?)))
}
}
_ => Err(ExecutionError::InternalError(format!(
"SQL binary operator cannot be interpreted as a unary operator"
))),
},

ASTNode::SQLBinaryExpr {
ref left,
Expand Down Expand Up @@ -370,7 +368,7 @@ impl<S: SchemaProvider> SqlToRel<S> {

// return type is same as the argument type for these aggregate
// functions
let return_type = rex_args[0].get_type(schema).clone();
let return_type = rex_args[0].get_type(schema)?.clone();

Ok(Expr::AggregateFunction {
name: id.clone(),
Expand All @@ -387,7 +385,7 @@ impl<S: SchemaProvider> SqlToRel<S> {
}
ASTNode::SQLWildcard => {
Ok(Expr::Literal(ScalarValue::UInt8(1)))
},
}
_ => self.sql_to_rex(a, schema),
})
.collect::<Result<Vec<Expr>>>()?;
Expand Down Expand Up @@ -575,6 +573,15 @@ mod tests {
);
}

#[test]
fn test_wildcard() {
quick_test(
"SELECT * from person",
"Projection: #0, #1, #2, #3, #4, #5, #6\
\n TableScan: person projection=None",
);
}

#[test]
fn select_count_one() {
let sql = "SELECT COUNT(1) FROM person";
Expand Down

0 comments on commit 76c6424

Please sign in to comment.