Skip to content

Commit

Permalink
ARROW-8249: [Rust] [DataFusion] Table API now uses LogicalPlanBuilder
Browse files Browse the repository at this point in the history
Table API now uses LogicalPlanBuilder for more concise and consistent code.

Closes #6748 from andygrove/ARROW-8249

Authored-by: Andy Grove <andygrove73@gmail.com>
Signed-off-by: Andy Grove <andygrove73@gmail.com>
  • Loading branch information
andygrove committed Mar 28, 2020
1 parent 8e40170 commit c49b960
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 111 deletions.
25 changes: 10 additions & 15 deletions rust/datafusion/examples/memory_table_api.rs
Expand Up @@ -26,11 +26,12 @@ use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;

use datafusion::datasource::MemTable;
use datafusion::error::Result;
use datafusion::execution::context::ExecutionContext;
use datafusion::logicalplan::{Expr, ScalarValue};

/// This example demonstrates basic uses of the Table API on an in-memory table
fn main() {
fn main() -> Result<()> {
// define a schema.
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Utf8, false),
Expand All @@ -44,31 +45,23 @@ fn main() {
Arc::new(StringArray::from(vec!["a", "b", "c", "d"])),
Arc::new(Int32Array::from(vec![1, 10, 10, 100])),
],
)
.unwrap();
)?;

// Declare a new context. In the Spark API, this corresponds to creating a new SparkSession.
let mut ctx = ExecutionContext::new();

// Declare a table in memory. In the Spark API, this corresponds to createDataFrame(...).
let provider = MemTable::new(schema, vec![batch]).unwrap();
let provider = MemTable::new(schema, vec![batch])?;
ctx.register_table("t", Box::new(provider));
let t = ctx.table("t").unwrap();
let t = ctx.table("t")?;

// construct an expression corresponding to "SELECT a, b FROM t WHERE b = 10" in SQL
let filter = t
.col("b")
.unwrap()
.eq(&Expr::Literal(ScalarValue::Int32(10)));
let filter = t.col("b")?.eq(&Expr::Literal(ScalarValue::Int32(10)));

let t = t
.select_columns(vec!["a", "b"])
.unwrap()
.filter(filter)
.unwrap();
let t = t.select_columns(vec!["a", "b"])?.filter(filter)?;

// execute
let results = t.collect(&mut ctx, 10).unwrap();
let results = t.collect(&mut ctx, 10)?;

// print results
results.iter().for_each(|batch| {
Expand All @@ -94,4 +87,6 @@ fn main() {
println!("{}, {}", c1.value(i), c2.value(i),);
}
});

Ok(())
}
7 changes: 5 additions & 2 deletions rust/datafusion/src/execution/context.rs
Expand Up @@ -216,13 +216,16 @@ impl ExecutionContext {
pub fn table(&mut self, table_name: &str) -> Result<Arc<dyn Table>> {
match self.datasources.get(table_name) {
Some(provider) => {
Ok(Arc::new(TableImpl::new(Arc::new(LogicalPlan::TableScan {
let table_scan = LogicalPlan::TableScan {
schema_name: "".to_string(),
table_name: table_name.to_string(),
table_schema: provider.schema().clone(),
projected_schema: provider.schema().clone(),
projection: None,
}))))
};
Ok(Arc::new(TableImpl::new(
&LogicalPlanBuilder::from(&table_scan).build()?,
)))
}
_ => Err(ExecutionError::General(format!(
"No table named '{}'",
Expand Down
122 changes: 33 additions & 89 deletions rust/datafusion/src/execution/table_impl.rs
Expand Up @@ -19,72 +19,54 @@

use std::sync::Arc;

use crate::arrow::datatypes::{DataType, Field, Schema};
use crate::arrow::datatypes::DataType;
use crate::arrow::record_batch::RecordBatch;
use crate::error::{ExecutionError, Result};
use crate::execution::context::ExecutionContext;
use crate::logicalplan::Expr::Literal;
use crate::logicalplan::ScalarValue;
use crate::logicalplan::{Expr, LogicalPlan};
use crate::logicalplan::{LogicalPlanBuilder, ScalarValue};
use crate::table::*;

/// Implementation of Table API
pub struct TableImpl {
plan: Arc<LogicalPlan>,
plan: LogicalPlan,
}

impl TableImpl {
/// Create a new Table based on an existing logical plan
pub fn new(plan: Arc<LogicalPlan>) -> Self {
Self { plan }
pub fn new(plan: &LogicalPlan) -> Self {
Self { plan: plan.clone() }
}
}

impl Table for TableImpl {
/// Apply a projection based on a list of column names
fn select_columns(&self, columns: Vec<&str>) -> Result<Arc<dyn Table>> {
let mut expr: Vec<Expr> = Vec::with_capacity(columns.len());
for column_name in columns {
let i = self.column_index(column_name)?;
expr.push(Expr::Column(i));
}
self.select(expr)
let exprs = columns
.iter()
.map(|name| {
self.plan
.schema()
.index_of(name.to_owned())
.and_then(|i| Ok(Expr::Column(i)))
.map_err(|e| e.into())
})
.collect::<Result<Vec<_>>>()?;
self.select(exprs)
}

/// Create a projection based on arbitrary expressions
fn select(&self, expr_list: Vec<Expr>) -> Result<Arc<dyn Table>> {
let schema = self.plan.schema();
let mut field: Vec<Field> = Vec::with_capacity(expr_list.len());

for expr in &expr_list {
match expr {
Expr::Column(i) => {
field.push(schema.field(*i).clone());
}
other => {
return Err(ExecutionError::NotImplemented(format!(
"Expr {:?} is not currently supported in this context",
other
)))
}
}
}

Ok(Arc::new(TableImpl::new(Arc::new(
LogicalPlan::Projection {
expr: expr_list.clone(),
input: self.plan.clone(),
schema: Arc::new(Schema::new(field)),
},
))))
let plan = LogicalPlanBuilder::from(&self.plan)
.project(expr_list)?
.build()?;
Ok(Arc::new(TableImpl::new(&plan)))
}

/// Create a selection based on a filter expression
fn filter(&self, expr: Expr) -> Result<Arc<dyn Table>> {
Ok(Arc::new(TableImpl::new(Arc::new(LogicalPlan::Selection {
expr,
input: self.plan.clone(),
}))))
let plan = LogicalPlanBuilder::from(&self.plan).filter(expr)?.build()?;
Ok(Arc::new(TableImpl::new(&plan)))
}

/// Perform an aggregate query
Expand All @@ -93,38 +75,23 @@ impl Table for TableImpl {
group_expr: Vec<Expr>,
aggr_expr: Vec<Expr>,
) -> Result<Arc<dyn Table>> {
Ok(Arc::new(TableImpl::new(Arc::new(LogicalPlan::Aggregate {
input: self.plan.clone(),
group_expr,
aggr_expr,
schema: Arc::new(Schema::new(vec![])),
}))))
let plan = LogicalPlanBuilder::from(&self.plan)
.aggregate(group_expr, aggr_expr)?
.build()?;
Ok(Arc::new(TableImpl::new(&plan)))
}

/// Limit the number of rows
fn limit(&self, n: usize) -> Result<Arc<dyn Table>> {
Ok(Arc::new(TableImpl::new(Arc::new(LogicalPlan::Limit {
expr: Literal(ScalarValue::UInt32(n as u32)),
input: self.plan.clone(),
schema: self.plan.schema().clone(),
}))))
fn limit(&self, n: u32) -> Result<Arc<dyn Table>> {
let plan = LogicalPlanBuilder::from(&self.plan)
.limit(Expr::Literal(ScalarValue::UInt32(n)))?
.build()?;
Ok(Arc::new(TableImpl::new(&plan)))
}

/// Return an expression representing a column within this table
fn col(&self, name: &str) -> Result<Expr> {
Ok(Expr::Column(self.column_index(name)?))
}

/// Return the index of a column within this table's schema
fn column_index(&self, name: &str) -> Result<usize> {
let schema = self.plan.schema();
match schema.column_with_name(name) {
Some((i, _)) => Ok(i),
_ => Err(ExecutionError::InvalidColumn(format!(
"No column named '{}'",
name
))),
}
Ok(Expr::Column(self.plan.schema().index_of(name)?))
}

/// Create an expression to represent the min() aggregate function
Expand Down Expand Up @@ -153,7 +120,7 @@ impl Table for TableImpl {
}

/// Convert to logical plan
fn to_logical_plan(&self) -> Arc<LogicalPlan> {
fn to_logical_plan(&self) -> LogicalPlan {
self.plan.clone()
}

Expand Down Expand Up @@ -195,14 +162,6 @@ mod tests {
use crate::execution::context::ExecutionContext;
use crate::test;

#[test]
fn column_index() {
let t = test_table();
assert_eq!(0, t.column_index("c1").unwrap());
assert_eq!(1, t.column_index("c2").unwrap());
assert_eq!(12, t.column_index("c13").unwrap());
}

#[test]
fn select_columns() -> Result<()> {
// build plan using Table API
Expand Down Expand Up @@ -235,21 +194,6 @@ mod tests {
Ok(())
}

#[test]
fn select_invalid_column() -> Result<()> {
let t = test_table();

match t.col("invalid_column_name") {
Ok(_) => panic!(),
Err(e) => assert_eq!(
"InvalidColumn(\"No column named \\\'invalid_column_name\\\'\")",
format!("{:?}", e)
),
}

Ok(())
}

#[test]
fn aggregate() -> Result<()> {
// build plan using Table API
Expand Down
7 changes: 2 additions & 5 deletions rust/datafusion/src/table.rs
Expand Up @@ -43,10 +43,10 @@ pub trait Table {
) -> Result<Arc<dyn Table>>;

/// limit the number of rows
fn limit(&self, n: usize) -> Result<Arc<dyn Table>>;
fn limit(&self, n: u32) -> Result<Arc<dyn Table>>;

/// Return the logical plan
fn to_logical_plan(&self) -> Arc<LogicalPlan>;
fn to_logical_plan(&self) -> LogicalPlan;

/// Return an expression representing a column within this table
fn col(&self, name: &str) -> Result<Expr>;
Expand All @@ -66,9 +66,6 @@ pub trait Table {
/// Create an expression to represent the count() aggregate function
fn count(&self, expr: &Expr) -> Result<Expr>;

/// Return the index of a column within this table's schema
fn column_index(&self, name: &str) -> Result<usize>;

/// Collects the result as a vector of RecordBatch.
fn collect(
&self,
Expand Down

0 comments on commit c49b960

Please sign in to comment.