Skip to content

Commit

Permalink
ARROW-7941: [Rust] [DataFusion] Add support for named columns in logical plan
Browse files Browse the repository at this point in the history

This PR adds support for unresolved columns in the logical plan so that users can add columns by name rather than index. There is a new optimizer rule that will resolve these columns and replace them with indices in the plan.

This PR also:

- Removes pointless `Arc`s from the optimizer rules
- Optimizer rules now leverage `LogicalPlanBuilder` for much more concise and readable code

Closes #6730 from andygrove/ARROW-7941

Authored-by: Andy Grove <andygrove73@gmail.com>
Signed-off-by: Andy Grove <andygrove73@gmail.com>
  • Loading branch information
andygrove committed Mar 28, 2020
1 parent 4e680c4 commit 8e40170
Show file tree
Hide file tree
Showing 10 changed files with 386 additions and 267 deletions.
42 changes: 21 additions & 21 deletions rust/datafusion/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ use crate::execution::table_impl::TableImpl;
use crate::logicalplan::*;
use crate::optimizer::optimizer::OptimizerRule;
use crate::optimizer::projection_push_down::ProjectionPushDown;
use crate::optimizer::resolve_columns::ResolveColumnsRule;
use crate::optimizer::type_coercion::TypeCoercionRule;
use crate::sql::parser::{DFASTNode, DFParser, FileType};
use crate::sql::planner::{SchemaProvider, SqlToRel};
Expand Down Expand Up @@ -231,12 +232,13 @@ impl ExecutionContext {
}

/// Optimize the logical plan by applying optimizer rules
pub fn optimize(&self, plan: &LogicalPlan) -> Result<Arc<LogicalPlan>> {
pub fn optimize(&self, plan: &LogicalPlan) -> Result<LogicalPlan> {
let rules: Vec<Box<dyn OptimizerRule>> = vec![
Box::new(ResolveColumnsRule::new()),
Box::new(ProjectionPushDown::new()),
Box::new(TypeCoercionRule::new()),
];
let mut plan = Arc::new(plan.clone());
let mut plan = plan.clone();
for mut rule in rules {
plan = rule.optimize(&plan)?;
}
Expand All @@ -246,10 +248,10 @@ impl ExecutionContext {
/// Create a physical plan from a logical plan
pub fn create_physical_plan(
&mut self,
logical_plan: &Arc<LogicalPlan>,
logical_plan: &LogicalPlan,
batch_size: usize,
) -> Result<Arc<dyn ExecutionPlan>> {
match logical_plan.as_ref() {
match logical_plan {
LogicalPlan::TableScan {
table_name,
projection,
Expand Down Expand Up @@ -435,9 +437,10 @@ impl ExecutionContext {
))),
}
}
_ => Err(ExecutionError::NotImplemented(
"Unsupported aggregate expression".to_string(),
)),
other => Err(ExecutionError::General(format!(
"Invalid aggregate expression '{:?}'",
other
))),
}
}

Expand Down Expand Up @@ -731,22 +734,19 @@ mod tests {
let mut ctx = create_ctx(&tmp_dir, 1)?;

let schema = Arc::new(Schema::new(vec![
Field::new("c1", DataType::UInt32, false),
Field::new("c2", DataType::UInt64, false),
Field::new("state", DataType::Utf8, false),
Field::new("salary", DataType::UInt32, false),
]));

let plan = LogicalPlanBuilder::scan(
"default",
"test",
schema.as_ref(),
Some(vec![0, 1]),
)?
.aggregate(
vec![col(0)],
vec![aggregate_expr("SUM", col(1), DataType::Int32)],
)?
.project(vec![col(0), col(1).alias("total_salary")])?
.build()?;
let plan = LogicalPlanBuilder::scan("default", "test", schema.as_ref(), None)?
.aggregate(
vec![col("state")],
vec![aggregate_expr("SUM", col("salary"), DataType::UInt32)],
)?
.project(vec![col("state"), col_index(1).alias("total_salary")])?
.build()?;

let plan = ctx.optimize(&plan)?;

let physical_plan = ctx.create_physical_plan(&Arc::new(plan), 1024)?;
assert_eq!("c1", physical_plan.schema().field(0).name().as_str());
Expand Down
44 changes: 28 additions & 16 deletions rust/datafusion/src/logicalplan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ pub enum Expr {
Alias(Arc<Expr>, String),
/// index into a value within the row or complex value
Column(usize),
/// Reference to column by name
UnresolvedColumn(String),
/// literal value
Literal(ScalarValue),
/// binary expression e.g. "age > 21"
Expand Down Expand Up @@ -242,6 +244,9 @@ impl Expr {
match self {
Expr::Alias(expr, _) => expr.get_type(schema),
Expr::Column(n) => Ok(schema.field(*n).data_type().clone()),
Expr::UnresolvedColumn(name) => {
Ok(schema.field_with_name(&name)?.data_type().clone())
}
Expr::Literal(l) => Ok(l.get_datatype()),
Expr::Cast { data_type, .. } => Ok(data_type.clone()),
Expr::ScalarFunction { return_type, .. } => Ok(return_type.clone()),
Expand Down Expand Up @@ -356,11 +361,16 @@ impl Expr {
}
}

/// Create a column expression
pub fn col(index: usize) -> Expr {
/// Create a column expression based on a column index
pub fn col_index(index: usize) -> Expr {
Expr::Column(index)
}

/// Create a column expression based on a column name.
///
/// Wraps an owned copy of `name` in `Expr::UnresolvedColumn`. No schema
/// lookup is performed here, so the name is not validated at construction
/// time.
///
/// NOTE(review): the name is presumably replaced by a column index later by
/// the `ResolveColumnsRule` optimizer rule — confirm against that rule.
pub fn col(name: &str) -> Expr {
Expr::UnresolvedColumn(name.to_owned())
}

/// Create a literal string expression
pub fn lit_str(str: &str) -> Expr {
Expr::Literal(ScalarValue::Utf8(str.to_owned()))
Expand All @@ -380,6 +390,7 @@ impl fmt::Debug for Expr {
match self {
Expr::Alias(expr, alias) => write!(f, "{:?} AS {}", expr, alias),
Expr::Column(i) => write!(f, "#{}", i),
Expr::UnresolvedColumn(name) => write!(f, "#{}", name),
Expr::Literal(v) => write!(f, "{:?}", v),
Expr::Cast { expr, data_type } => {
write!(f, "CAST({:?} AS {:?})", expr, data_type)
Expand Down Expand Up @@ -709,7 +720,7 @@ impl LogicalPlanBuilder {
(0..expr.len()).for_each(|i| match &expr[i] {
Expr::Wildcard => {
(0..input_schema.fields().len())
.for_each(|i| expr_vec.push(col(i).clone()));
.for_each(|i| expr_vec.push(col_index(i).clone()));
}
_ => expr_vec.push(expr[i].clone()),
});
Expand Down Expand Up @@ -791,8 +802,8 @@ mod tests {
&employee_schema(),
Some(vec![0, 3]),
)?
.filter(col(1).eq(&lit_str("CO")))?
.project(vec![col(0)])?
.filter(col("id").eq(&lit_str("CO")))?
.project(vec![col("id")])?
.build()?;

// prove that a plan can be passed to a thread
Expand All @@ -812,13 +823,13 @@ mod tests {
&employee_schema(),
Some(vec![0, 3]),
)?
.filter(col(1).eq(&lit_str("CO")))?
.project(vec![col(0)])?
.filter(col("state").eq(&lit_str("CO")))?
.project(vec![col("id")])?
.build()?;

let expected = "Projection: #0\n \
Selection: #1 Eq Utf8(\"CO\")\n \
TableScan: employee.csv projection=Some([0, 3])";
let expected = "Projection: #id\
\n Selection: #state Eq Utf8(\"CO\")\
\n TableScan: employee.csv projection=Some([0, 3])";

assert_eq!(expected, format!("{:?}", plan));

Expand All @@ -834,15 +845,16 @@ mod tests {
Some(vec![3, 4]),
)?
.aggregate(
vec![col(0)],
vec![aggregate_expr("SUM", col(1), DataType::Int32)],
vec![col("state")],
vec![aggregate_expr("SUM", col("salary"), DataType::Int32)
.alias("total_salary")],
)?
.project(vec![col(0), col(1).alias("total_salary")])?
.project(vec![col("state"), col("total_salary")])?
.build()?;

let expected = "Projection: #0, #1 AS total_salary\
\n Aggregate: groupBy=[[#0]], aggr=[[SUM(#1)]]\
\n TableScan: employee.csv projection=Some([3, 4])";
let expected = "Projection: #state, #total_salary\
\n Aggregate: groupBy=[[#state]], aggr=[[SUM(#salary) AS total_salary]]\
\n TableScan: employee.csv projection=Some([3, 4])";

assert_eq!(expected, format!("{:?}", plan));

Expand Down
1 change: 1 addition & 0 deletions rust/datafusion/src/optimizer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,6 @@

pub mod optimizer;
pub mod projection_push_down;
pub mod resolve_columns;
pub mod type_coercion;
pub mod utils;
3 changes: 1 addition & 2 deletions rust/datafusion/src/optimizer/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@

use crate::error::Result;
use crate::logicalplan::LogicalPlan;
use std::sync::Arc;

/// An optimizer rule performs a transformation on a logical plan to produce an optimized
/// logical plan.
pub trait OptimizerRule {
/// Perform optimizations on the plan
fn optimize(&mut self, plan: &LogicalPlan) -> Result<Arc<LogicalPlan>>;
fn optimize(&mut self, plan: &LogicalPlan) -> Result<LogicalPlan>;
}
Loading

0 comments on commit 8e40170

Please sign in to comment.