Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ARROW-5945: [Rust] [DataFusion] Table trait can now be used to build real queries #4875

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion rust/datafusion/src/execution/context.rs
Expand Up @@ -66,6 +66,7 @@ impl ExecutionContext {
/// of RecordBatch instances)
pub fn sql(&mut self, sql: &str, batch_size: usize) -> Result<Rc<RefCell<Relation>>> {
let plan = self.create_logical_plan(sql)?;
let plan = self.optimize(&plan)?;
Ok(self.execute(&plan, batch_size)?)
}

Expand All @@ -86,7 +87,7 @@ impl ExecutionContext {
// plan the query (create a logical relational plan)
let plan = query_planner.sql_to_rel(&ansi)?;

Ok(self.optimize(&plan)?)
Ok(plan)
}
DFASTNode::CreateExternalTable {
name,
Expand Down
290 changes: 263 additions & 27 deletions rust/datafusion/src/execution/table_impl.rs
Expand Up @@ -19,7 +19,7 @@

use std::sync::Arc;

use crate::arrow::datatypes::{Field, Schema};
use crate::arrow::datatypes::{DataType, Field, Schema};
use crate::error::{ExecutionError, Result};
use crate::logicalplan::Expr::Literal;
use crate::logicalplan::ScalarValue;
Expand All @@ -41,34 +41,64 @@ impl TableImpl {
impl Table for TableImpl {
/// Apply a projection based on a list of column names
fn select_columns(&self, columns: Vec<&str>) -> Result<Arc<Table>> {
let schema = self.plan.schema();
let mut projection_index: Vec<usize> = Vec::with_capacity(columns.len());
let mut expr: Vec<Expr> = Vec::with_capacity(columns.len());
for column_name in columns {
let i = self.column_index(column_name)?;
expr.push(Expr::Column(i));
}
self.select(expr)
}

/// Create a projection based on arbitrary expressions
fn select(&self, expr_list: Vec<Expr>) -> Result<Arc<Table>> {
let schema = self.plan.schema();
let mut field: Vec<Field> = Vec::with_capacity(expr_list.len());

for column in columns {
match schema.column_with_name(column) {
Some((i, _)) => {
projection_index.push(i);
expr.push(Expr::Column(i));
for expr in &expr_list {
match expr {
Expr::Column(i) => {
field.push(schema.field(*i).clone());
}
_ => {
return Err(ExecutionError::InvalidColumn(format!(
"No column named '{}'",
column
)));
other => {
return Err(ExecutionError::NotImplemented(format!(
"Expr {:?} is not currently supported in this context",
other
)))
}
}
}

Ok(Arc::new(TableImpl::new(Arc::new(
LogicalPlan::Projection {
expr,
expr: expr_list.clone(),
input: self.plan.clone(),
schema: projection(&schema, &projection_index)?,
schema: Arc::new(Schema::new(field)),
},
))))
}

/// Create a selection based on a filter expression
fn filter(&self, expr: Expr) -> Result<Arc<Table>> {
Ok(Arc::new(TableImpl::new(Arc::new(LogicalPlan::Selection {
expr,
input: self.plan.clone(),
}))))
}

/// Perform an aggregate query
fn aggregate(
&self,
group_expr: Vec<Expr>,
aggr_expr: Vec<Expr>,
) -> Result<Arc<Table>> {
Ok(Arc::new(TableImpl::new(Arc::new(LogicalPlan::Aggregate {
input: self.plan.clone(),
group_expr,
aggr_expr,
schema: Arc::new(Schema::new(vec![])),
}))))
}

/// Limit the number of rows
fn limit(&self, n: usize) -> Result<Arc<Table>> {
Ok(Arc::new(TableImpl::new(Arc::new(LogicalPlan::Limit {
Expand All @@ -78,24 +108,230 @@ impl Table for TableImpl {
}))))
}

/// Return an expression representing a column within this table
fn col(&self, name: &str) -> Result<Expr> {
Ok(Expr::Column(self.column_index(name)?))
}

/// Return the index of a column within this table's schema
fn column_index(&self, name: &str) -> Result<usize> {
let schema = self.plan.schema();
match schema.column_with_name(name) {
Some((i, _)) => Ok(i),
_ => Err(ExecutionError::InvalidColumn(format!(
"No column named '{}'",
name
))),
}
}

/// Create an expression to represent the min() aggregate function
fn min(&self, expr: &Expr) -> Result<Expr> {
self.aggregate_expr("MIN", expr)
}

/// Create an expression to represent the max() aggregate function
fn max(&self, expr: &Expr) -> Result<Expr> {
self.aggregate_expr("MAX", expr)
}

/// Create an expression to represent the sum() aggregate function
fn sum(&self, expr: &Expr) -> Result<Expr> {
self.aggregate_expr("SUM", expr)
}

/// Create an expression to represent the avg() aggregate function
fn avg(&self, expr: &Expr) -> Result<Expr> {
self.aggregate_expr("AVG", expr)
}

/// Create an expression to represent the count() aggregate function
fn count(&self, expr: &Expr) -> Result<Expr> {
self.aggregate_expr("COUNT", expr)
}

/// Convert to logical plan
fn to_logical_plan(&self) -> Arc<LogicalPlan> {
self.plan.clone()
}
}

/// Create a new schema by applying a projection to this schema's fields
fn projection(schema: &Schema, projection: &Vec<usize>) -> Result<Arc<Schema>> {
let mut fields: Vec<Field> = Vec::with_capacity(projection.len());
for i in projection {
if *i < schema.fields().len() {
fields.push(schema.field(*i).clone());
} else {
return Err(ExecutionError::InvalidColumn(format!(
"Invalid column index {} in projection",
i
)));
impl TableImpl {
/// Determine the data type for a given expression
fn get_data_type(&self, expr: &Expr) -> Result<DataType> {
match expr {
Expr::Column(i) => Ok(self.plan.schema().field(*i).data_type().clone()),
_ => Err(ExecutionError::General(format!(
"Could not determine data type for expr {:?}",
expr
))),
}
}

/// Create an expression to represent a named aggregate function
fn aggregate_expr(&self, name: &str, expr: &Expr) -> Result<Expr> {
let return_type = self.get_data_type(expr)?;
Ok(Expr::AggregateFunction {
name: name.to_string(),
args: vec![expr.clone()],
return_type,
})
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::execution::context::ExecutionContext;
use std::env;

#[test]
fn column_index() {
let t = test_table();
assert_eq!(0, t.column_index("c1").unwrap());
assert_eq!(1, t.column_index("c2").unwrap());
assert_eq!(12, t.column_index("c13").unwrap());
}

#[test]
fn select_columns() -> Result<()> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice to know we can return Result<()> here. We can avoid lots of unwraps in all the unit tests we have.

// build plan using Table API
let t = test_table();
let t2 = t.select_columns(vec!["c1", "c2", "c11"])?;
let plan = t2.to_logical_plan();

// build query using SQL
let sql_plan = create_plan("SELECT c1, c2, c11 FROM aggregate_test_100")?;

// the two plans should be identical
assert_same_plan(&plan, &sql_plan);

Ok(())
}

#[test]
fn select_expr() -> Result<()> {
// build plan using Table API
let t = test_table();
let t2 = t.select(vec![t.col("c1")?, t.col("c2")?, t.col("c11")?])?;
let plan = t2.to_logical_plan();

// build query using SQL
let sql_plan = create_plan("SELECT c1, c2, c11 FROM aggregate_test_100")?;

// the two plans should be identical
assert_same_plan(&plan, &sql_plan);

Ok(())
}

#[test]
fn select_invalid_column() -> Result<()> {
let t = test_table();

match t.col("invalid_column_name") {
Ok(_) => panic!(),
Err(e) => assert_eq!(
"InvalidColumn(\"No column named \\\'invalid_column_name\\\'\")",
format!("{:?}", e)
),
}

Ok(())
}
Ok(Arc::new(Schema::new(fields)))

#[test]
fn aggregate() -> Result<()> {
// build plan using Table API
let t = test_table();
let group_expr = vec![t.col("c1")?];
let c12 = t.col("c12")?;
let aggr_expr = vec![
t.min(&c12)?,
t.max(&c12)?,
t.avg(&c12)?,
t.sum(&c12)?,
t.count(&c12)?,
];

let t2 = t.aggregate(group_expr.clone(), aggr_expr.clone())?;

let plan = t2.to_logical_plan();

// build same plan using SQL API
let sql = "SELECT c1, MIN(c12), MAX(c12), AVG(c12), SUM(c12), COUNT(c12) \
FROM aggregate_test_100 \
GROUP BY c1";
let sql_plan = create_plan(sql)?;

// the two plans should be identical
assert_same_plan(&plan, &sql_plan);

Ok(())
}

#[test]
fn limit() -> Result<()> {
// build query using Table API
let t = test_table();
let t2 = t.select_columns(vec!["c1", "c2", "c11"])?.limit(10)?;
let plan = t2.to_logical_plan();

// build query using SQL
let sql_plan =
create_plan("SELECT c1, c2, c11 FROM aggregate_test_100 LIMIT 10")?;

// the two plans should be identical
assert_same_plan(&plan, &sql_plan);

Ok(())
}

/// Compare the formatted string representation of two plans for equality
fn assert_same_plan(plan1: &LogicalPlan, plan2: &LogicalPlan) {
assert_eq!(format!("{:?}", plan1), format!("{:?}", plan2));
}

/// Create a logical plan from a SQL query
fn create_plan(sql: &str) -> Result<Arc<LogicalPlan>> {
let mut ctx = ExecutionContext::new();
register_aggregate_csv(&mut ctx);
ctx.create_logical_plan(sql)
}

fn test_table() -> Arc<dyn Table + 'static> {
let mut ctx = ExecutionContext::new();
register_aggregate_csv(&mut ctx);
ctx.table("aggregate_test_100").unwrap()
}

fn register_aggregate_csv(ctx: &mut ExecutionContext) {
let schema = aggr_test_schema();
let testdata = env::var("ARROW_TEST_DATA").expect("ARROW_TEST_DATA not defined");
ctx.register_csv(
"aggregate_test_100",
&format!("{}/csv/aggregate_test_100.csv", testdata),
&schema,
true,
);
}

fn aggr_test_schema() -> Arc<Schema> {
Arc::new(Schema::new(vec![
Field::new("c1", DataType::Utf8, false),
Field::new("c2", DataType::UInt32, false),
Field::new("c3", DataType::Int8, false),
Field::new("c4", DataType::Int16, false),
Field::new("c5", DataType::Int32, false),
Field::new("c6", DataType::Int64, false),
Field::new("c7", DataType::UInt8, false),
Field::new("c8", DataType::UInt16, false),
Field::new("c9", DataType::UInt32, false),
Field::new("c10", DataType::UInt64, false),
Field::new("c11", DataType::Float32, false),
Field::new("c12", DataType::Float64, false),
Field::new("c13", DataType::Utf8, false),
]))
}

}
1 change: 1 addition & 0 deletions rust/datafusion/src/optimizer/type_coercion.rs
Expand Up @@ -74,6 +74,7 @@ impl OptimizerRule for TypeCoercionRule {
LogicalPlan::TableScan { .. } => Ok(Arc::new(plan.clone())),
LogicalPlan::EmptyRelation { .. } => Ok(Arc::new(plan.clone())),
LogicalPlan::Limit { .. } => Ok(Arc::new(plan.clone())),
LogicalPlan::CreateExternalTable { .. } => Ok(Arc::new(plan.clone())),
other => Err(ExecutionError::NotImplemented(format!(
"Type coercion optimizer rule does not support relation: {:?}",
other
Expand Down
13 changes: 11 additions & 2 deletions rust/datafusion/src/sql/planner.rs
Expand Up @@ -174,8 +174,17 @@ impl SqlToRel {
let limit_plan = match limit {
&Some(ref limit_expr) => {
let input_schema = order_by_plan.schema();
let limit_rex =
self.sql_to_rex(&limit_expr, &input_schema.clone())?;

let limit_rex = match self
.sql_to_rex(&limit_expr, &input_schema.clone())?
{
Expr::Literal(ScalarValue::Int64(n)) => {
Ok(Expr::Literal(ScalarValue::UInt32(n as u32)))
}
_ => Err(ExecutionError::General(
"Unexpected expression for LIMIT clause".to_string(),
)),
}?;

LogicalPlan::Limit {
expr: limit_rex,
Expand Down