diff --git a/dev/release/00-prepare.sh b/dev/release/00-prepare.sh
index bfcfc83825499..90d318145ca06 100755
--- a/dev/release/00-prepare.sh
+++ b/dev/release/00-prepare.sh
@@ -100,9 +100,9 @@ update_versions() {
   cd "${SOURCE_DIR}/../../rust"
   sed -i.bak -E -e \
     "s/^version = \".+\"/version = \"${version}\"/g" \
-    arrow/Cargo.toml parquet/Cargo.toml
-  rm -f arrow/Cargo.toml.bak parquet/Cargo.toml.bak
-  git add arrow/Cargo.toml parquet/Cargo.toml
+    arrow/Cargo.toml parquet/Cargo.toml datafusion/Cargo.toml
+  rm -f arrow/Cargo.toml.bak parquet/Cargo.toml.bak datafusion/Cargo.toml.bak
+  git add arrow/Cargo.toml parquet/Cargo.toml datafusion/Cargo.toml
 
   # Update version number for parquet README
   sed -i.bak -E -e \
diff --git a/rust/Cargo.toml b/rust/Cargo.toml
index abfb71ada7951..c415ab7c8d7d1 100644
--- a/rust/Cargo.toml
+++ b/rust/Cargo.toml
@@ -19,4 +19,5 @@
 members = [
     "arrow",
     "parquet",
+    "datafusion",
 ]
\ No newline at end of file
diff --git a/rust/datafusion/Cargo.toml b/rust/datafusion/Cargo.toml
new file mode 100644
index 0000000000000..864243bce4cc1
--- /dev/null
+++ b/rust/datafusion/Cargo.toml
@@ -0,0 +1,50 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+[package]
+name = "datafusion"
+description = "DataFusion is an in-memory query engine that uses Apache Arrow as the memory model"
+version = "0.13.0-SNAPSHOT"
+homepage = "https://github.com/apache/arrow"
+repository = "https://github.com/apache/arrow"
+authors = ["Apache Arrow <dev@arrow.apache.org>"]
+license = "Apache-2.0"
+keywords = [ "arrow", "query", "sql" ]
+include = [
+    "src/**/*.rs",
+    "Cargo.toml",
+]
+edition = "2018"
+
+[lib]
+name = "datafusion"
+path = "src/lib.rs"
+
+[dependencies]
+clap = "2.31.2"
+fnv = "1.0.3"
+arrow = { path = "../arrow" }
+parquet = { path = "../parquet" }
+datafusion-rustyline = "2.0.0-alpha-20180628"
+serde = { version = "1.0.80", features = ["alloc", "rc"] }
+serde_derive = "1.0.80"
+serde_json = "1.0.33"
+sqlparser = "0.2.0"
+
+[dev-dependencies]
+criterion = "0.2.0"
+
diff --git a/rust/datafusion/README.md b/rust/datafusion/README.md
new file mode 100644
index 0000000000000..eade95e063536
--- /dev/null
+++ b/rust/datafusion/README.md
@@ -0,0 +1,94 @@
+<!---
+  Licensed to the Apache Software Foundation (ASF) under one
+  or more contributor license agreements. See the NOTICE file
+  distributed with this work for additional information
+  regarding copyright ownership. The ASF licenses this file
+  to you under the Apache License, Version 2.0 (the
+  "License"); you may not use this file except in compliance
+  with the License. You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing,
+  software distributed under the License is distributed on an
+  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+  KIND, either express or implied. See the License for the
+  specific language governing permissions and limitations
+  under the License.
+-->
+
+# DataFusion
+
+DataFusion is an in-memory query engine that uses Apache Arrow as the memory model.
+
+# Status
+
+The current code supports single-threaded execution of limited SQL queries (projection, selection, and aggregates) against CSV files. Parquet files will be supported shortly.
+
+Here is a brief example of running a SQL query against a CSV file. See the [examples](examples) directory for full examples.
+
+```rust
+fn main() {
+    // create local execution context
+    let mut ctx = ExecutionContext::new();
+
+    // define schema for data source (csv file)
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("city", DataType::Utf8, false),
+        Field::new("lat", DataType::Float64, false),
+        Field::new("lng", DataType::Float64, false),
+    ]));
+
+    // register csv file with the execution context
+    let csv_datasource = CsvDataSource::new("../../testing/data/csv/uk_cities.csv", schema.clone(), 1024);
+    ctx.register_datasource("cities", Rc::new(RefCell::new(csv_datasource)));
+
+    // simple projection and selection
+    let sql = "SELECT city, lat, lng FROM cities WHERE lat > 51.0 AND lat < 53";
+
+    // execute the query
+    let relation = ctx.sql(&sql).unwrap();
+
+    // display the relation
+    let mut results = relation.borrow_mut();
+
+    while let Some(batch) = results.next().unwrap() {
+
+        println!(
+            "RecordBatch has {} rows and {} columns",
+            batch.num_rows(),
+            batch.num_columns()
+        );
+
+        let city = batch
+            .column(0)
+            .as_any()
+            .downcast_ref::<BinaryArray>()
+            .unwrap();
+
+        let lat = batch
+            .column(1)
+            .as_any()
+            .downcast_ref::<Float64Array>()
+            .unwrap();
+
+        let lng = batch
+            .column(2)
+            .as_any()
+            .downcast_ref::<Float64Array>()
+            .unwrap();
+
+        for i in 0..batch.num_rows() {
+            let city_name: String = String::from_utf8(city.get_value(i).to_vec()).unwrap();
+
+            println!(
+                "City: {}, Latitude: {}, Longitude: {}",
+                city_name,
+                lat.value(i),
+                lng.value(i),
+            );
+        }
+    }
+}
+```
+
diff --git a/rust/datafusion/examples/csv_sql.rs b/rust/datafusion/examples/csv_sql.rs
new file mode 100644
index 0000000000000..40959cbb624f6
--- /dev/null
+++ b/rust/datafusion/examples/csv_sql.rs
@@ -0,0 +1,101 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::cell::RefCell;
+use std::rc::Rc;
+use std::sync::Arc;
+
+extern crate arrow;
+extern crate datafusion;
+
+use arrow::array::{BinaryArray, Float64Array};
+use arrow::datatypes::{DataType, Field, Schema};
+
+use datafusion::execution::context::ExecutionContext;
+use datafusion::execution::datasource::CsvDataSource;
+
+/// This example demonstrates executing a simple query against an Arrow data source and fetching results
+fn main() {
+    // create local execution context
+    let mut ctx = ExecutionContext::new();
+
+    // define schema for data source (csv file)
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("c1", DataType::Utf8, false),
+        Field::new("c2", DataType::UInt32, false),
+        Field::new("c3", DataType::Int8, false),
+        Field::new("c4", DataType::Int16, false),
+        Field::new("c5", DataType::Int32, false),
+        Field::new("c6", DataType::Int64, false),
+        Field::new("c7", DataType::UInt8, false),
+        Field::new("c8", DataType::UInt16, false),
+        Field::new("c9", DataType::UInt32, false),
+        Field::new("c10", DataType::UInt64, false),
+        Field::new("c11", DataType::Float32, false),
+        Field::new("c12", DataType::Float64, false),
+        Field::new("c13", DataType::Utf8, false),
+    ]));
+
+    // register csv file with the execution context
+    let csv_datasource = CsvDataSource::new(
+        "../../testing/data/csv/aggregate_test_100.csv",
+        schema.clone(),
+        1024,
+    );
+    ctx.register_datasource("aggregate_test_100", Rc::new(RefCell::new(csv_datasource)));
+
+    // simple projection and selection
+    let sql = "SELECT c1, MIN(c12), MAX(c12) FROM aggregate_test_100 WHERE c11 > 0.1 AND c11 < 0.9 GROUP BY c1";
+
+    // execute the query
+    let relation = ctx.sql(&sql).unwrap();
+
+    // display the relation
+    let mut results = relation.borrow_mut();
+
+    while let Some(batch) = results.next().unwrap() {
+        println!(
+            "RecordBatch has {} rows and {} columns",
+            batch.num_rows(),
+            batch.num_columns()
+        );
+
+        let c1 = batch
+            .column(0)
+            .as_any()
+            .downcast_ref::<BinaryArray>()
+            .unwrap();
+
+        let min = batch
+            .column(1)
+            .as_any()
+            .downcast_ref::<Float64Array>()
+            .unwrap();
+
+        let max = batch
+            .column(2)
+            .as_any()
+            .downcast_ref::<Float64Array>()
+            .unwrap();
+
+        for i in 0..batch.num_rows() {
+            let c1_value: String = String::from_utf8(c1.value(i).to_vec()).unwrap();
+
+            println!("{}, Min: {}, Max: {}", c1_value, min.value(i), max.value(i),);
+        }
+    }
+}
diff --git a/rust/datafusion/src/dfparser.rs b/rust/datafusion/src/dfparser.rs
new file mode 100644
index 0000000000000..7abbd8dcf10bd
--- /dev/null
+++ b/rust/datafusion/src/dfparser.rs
@@ -0,0 +1,220 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! SQL Parser
+//!
+//! Note that most SQL parsing is now delegated to the sqlparser crate, which handles ANSI SQL;
+//! this module contains the DataFusion-specific SQL extensions.
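As a reading aid, the DDL shape targeted by the `CreateExternalTable` path in `parse_prefix` below looks roughly like the following. This is a hedged sketch, not part of the diff: it assumes `dfparser` is exported from the crate root, and the exact column type syntax is whatever `sqlparser` 0.2's `parse_data_type` accepts.

```rust
use datafusion::dfparser::{DFASTNode, DFParser};

fn main() {
    // CREATE EXTERNAL TABLE <name> (<columns>) STORED AS CSV
    //     [WITH HEADER ROW | WITHOUT HEADER ROW] LOCATION '<path>'
    let sql = "CREATE EXTERNAL TABLE cities (city VARCHAR(100), lat DOUBLE, lng DOUBLE) \
               STORED AS CSV WITH HEADER ROW LOCATION 'testing/data/csv/uk_cities.csv'";

    match DFParser::parse_sql(sql.to_string()) {
        Ok(DFASTNode::CreateExternalTable { name, columns, .. }) => {
            println!("external table '{}' with {} columns", name, columns.len())
        }
        Ok(_) => println!("parsed as an ANSI SQL statement"),
        Err(e) => println!("parse error: {:?}", e),
    }
}
```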
+
+use sqlparser::dialect::*;
+use sqlparser::sqlast::*;
+use sqlparser::sqlparser::*;
+use sqlparser::sqltokenizer::*;
+
+macro_rules! parser_err {
+    ($MSG:expr) => {
+        Err(ParserError::ParserError($MSG.to_string()))
+    };
+}
+
+#[derive(Debug, Clone)]
+pub enum FileType {
+    NdJson,
+    Parquet,
+    CSV,
+}
+
+#[derive(Debug, Clone)]
+pub enum DFASTNode {
+    /// ANSI SQL AST node
+    ANSI(ASTNode),
+    /// DDL for creating an external table in DataFusion
+    CreateExternalTable {
+        /// Table name
+        name: String,
+        /// Optional schema
+        columns: Vec<SQLColumnDef>,
+        /// File type (Parquet, NDJSON, CSV)
+        file_type: FileType,
+        /// Header row?
+        header_row: bool,
+        /// Path to file
+        location: String,
+    },
+}
+
+/// SQL Parser
+pub struct DFParser {
+    parser: Parser,
+}
+
+impl DFParser {
+    /// Parse the specified tokens
+    pub fn new(sql: String) -> Result<Self, ParserError> {
+        let dialect = GenericSqlDialect {};
+        let mut tokenizer = Tokenizer::new(&dialect, &sql);
+        let tokens = tokenizer.tokenize()?;
+        Ok(DFParser {
+            parser: Parser::new(tokens),
+        })
+    }
+
+    /// Parse a SQL statement and produce an Abstract Syntax Tree (AST)
+    pub fn parse_sql(sql: String) -> Result<DFASTNode, ParserError> {
+        let mut parser = DFParser::new(sql)?;
+        parser.parse()
+    }
+
+    /// Parse a new expression
+    pub fn parse(&mut self) -> Result<DFASTNode, ParserError> {
+        self.parse_expr(0)
+    }
+
+    /// Parse tokens until the precedence changes
+    fn parse_expr(&mut self, precedence: u8) -> Result<DFASTNode, ParserError> {
+        let mut expr = self.parse_prefix()?;
+        loop {
+            let next_precedence = self.parser.get_next_precedence()?;
+            if precedence >= next_precedence {
+                break;
+            }
+
+            if let Some(infix_expr) = self.parse_infix(expr.clone(), next_precedence)? {
+                expr = infix_expr;
+            }
+        }
+        Ok(expr)
+    }
+
+    /// Parse an expression prefix
+    fn parse_prefix(&mut self) -> Result<DFASTNode, ParserError> {
+        if self
+            .parser
+            .parse_keywords(vec!["CREATE", "EXTERNAL", "TABLE"])
+        {
+            match self.parser.next_token() {
+                Some(Token::Identifier(id)) => {
+                    // parse optional column list (schema)
+                    let mut columns = vec![];
+                    if self.parser.consume_token(&Token::LParen) {
+                        loop {
+                            if let Some(Token::Identifier(column_name)) =
+                                self.parser.next_token()
+                            {
+                                if let Ok(data_type) = self.parser.parse_data_type() {
+                                    let allow_null = if self
+                                        .parser
+                                        .parse_keywords(vec!["NOT", "NULL"])
+                                    {
+                                        false
+                                    } else if self.parser.parse_keyword("NULL") {
+                                        true
+                                    } else {
+                                        true
+                                    };
+
+                                    match self.parser.peek_token() {
+                                        Some(Token::Comma) => {
+                                            self.parser.next_token();
+                                            columns.push(SQLColumnDef {
+                                                name: column_name,
+                                                data_type: data_type,
+                                                allow_null,
+                                                default: None,
+                                                is_primary: false,
+                                                is_unique: false,
+                                            });
+                                        }
+                                        Some(Token::RParen) => break,
+                                        _ => {
+                                            return parser_err!(
+                                                "Expected ',' or ')' after column definition"
+                                            );
+                                        }
+                                    }
+                                } else {
+                                    return parser_err!(
+                                        "Error parsing data type in column definition"
+                                    );
+                                }
+                            } else {
+                                return parser_err!("Error parsing column name");
+                            }
+                        }
+                    }
+
+                    //println!("Parsed {} column defs", columns.len());
+
+                    let mut headers = true;
+                    let file_type: FileType = if self
+                        .parser
+                        .parse_keywords(vec!["STORED", "AS", "CSV"])
+                    {
+                        if self.parser.parse_keywords(vec!["WITH", "HEADER", "ROW"]) {
+                            headers = true;
+                        } else if self
+                            .parser
+                            .parse_keywords(vec!["WITHOUT", "HEADER", "ROW"])
+                        {
+                            headers = false;
+                        }
+                        FileType::CSV
+                    } else if self.parser.parse_keywords(vec!["STORED", "AS", "NDJSON"]) {
+                        FileType::NdJson
+                    } else if self.parser.parse_keywords(vec!["STORED", "AS", "PARQUET"])
+                    {
+                        FileType::Parquet
+                    } else {
+                        return parser_err!(format!(
+                            "Expected 'STORED AS' clause, found {:?}",
+                            self.parser.peek_token()
+                        ));
+                    };
+
+                    let location: String = if self.parser.parse_keywords(vec!["LOCATION"])
+                    {
+                        self.parser.parse_literal_string()?
+                    } else {
+                        return parser_err!("Missing 'LOCATION' clause");
+                    };
+
+                    Ok(DFASTNode::CreateExternalTable {
+                        name: id,
+                        columns,
+                        file_type,
+                        header_row: headers,
+                        location,
+                    })
+                }
+                _ => parser_err!(format!(
+                    "Unexpected token after CREATE EXTERNAL TABLE: {:?}",
+                    self.parser.peek_token()
+                )),
+            }
+        } else {
+            Ok(DFASTNode::ANSI(self.parser.parse_prefix()?))
+        }
+    }
+
+    pub fn parse_infix(
+        &mut self,
+        _expr: DFASTNode,
+        _precedence: u8,
+    ) -> Result<Option<DFASTNode>, ParserError> {
+        unimplemented!()
+    }
+}
diff --git a/rust/datafusion/src/execution/aggregate.rs b/rust/datafusion/src/execution/aggregate.rs
new file mode 100644
index 0000000000000..5acf5fb63a188
--- /dev/null
+++ b/rust/datafusion/src/execution/aggregate.rs
@@ -0,0 +1,1214 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Execution of a simple aggregate relation containing MIN, MAX, COUNT, SUM aggregate functions
+//! with optional GROUP BY columns
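To make the accumulator contract in this file concrete, here is a minimal sketch using the module-private types defined below (`MinFunction`, `ScalarValue` accumulation); it is illustrative only and would compile only inside this module:

```rust
// Sketch: each aggregate expression folds scalar values into an
// AggregateFunction implementation such as MinFunction.
let mut min = MinFunction::new(&DataType::Float64);
min.accumulate_scalar(&Some(ScalarValue::Float64(0.5)));
min.accumulate_scalar(&Some(ScalarValue::Float64(0.2)));
match min.result() {
    Some(ScalarValue::Float64(v)) => assert_eq!(0.2, *v),
    _ => unreachable!(),
}
```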
+
+use std::cell::RefCell;
+use std::rc::Rc;
+use std::str;
+use std::sync::Arc;
+
+use arrow::array::*;
+use arrow::array_ops;
+use arrow::builder::*;
+use arrow::datatypes::{DataType, Schema};
+use arrow::record_batch::RecordBatch;
+
+use super::error::{ExecutionError, Result};
+use super::expression::{AggregateType, RuntimeExpr};
+use super::relation::Relation;
+use crate::logicalplan::ScalarValue;
+
+use fnv::FnvHashMap;
+
+/// An aggregate relation is made up of zero or more grouping expressions and one
+/// or more aggregate expressions
+pub struct AggregateRelation {
+    schema: Arc<Schema>,
+    input: Rc<RefCell<Relation>>,
+    group_expr: Vec<RuntimeExpr>,
+    aggr_expr: Vec<RuntimeExpr>,
+    end_of_results: bool,
+}
+
+impl AggregateRelation {
+    pub fn new(
+        schema: Arc<Schema>,
+        input: Rc<RefCell<Relation>>,
+        group_expr: Vec<RuntimeExpr>,
+        aggr_expr: Vec<RuntimeExpr>,
+    ) -> Self {
+        AggregateRelation {
+            schema,
+            input,
+            group_expr,
+            aggr_expr,
+            end_of_results: false,
+        }
+    }
+}
+
+/// Enumeration of types that can be used in a GROUP BY expression (all primitives except for
+/// floating point numerics)
+#[derive(Debug, PartialEq, Eq, Hash, Clone)]
+enum GroupByScalar {
+    UInt8(u8),
+    UInt16(u16),
+    UInt32(u32),
+    UInt64(u64),
+    Int8(i8),
+    Int16(i16),
+    Int32(i32),
+    Int64(i64),
+    Utf8(String),
+}
+
+/// Common trait for all aggregation functions
+trait AggregateFunction {
+    /// Get the function name (used for debugging)
+    fn name(&self) -> &str;
+    fn accumulate_scalar(&mut self, value: &Option<ScalarValue>);
+    fn result(&self) -> &Option<ScalarValue>;
+    fn data_type(&self) -> &DataType;
+}
+
+#[derive(Debug)]
+struct MinFunction {
+    data_type: DataType,
+    value: Option<ScalarValue>,
+}
+
+impl MinFunction {
+    fn new(data_type: &DataType) -> Self {
+        Self {
+            data_type: data_type.clone(),
+            value: None,
+        }
+    }
+}
+
+impl AggregateFunction for MinFunction {
+    fn name(&self) -> &str {
+        "min"
+    }
+
+    fn accumulate_scalar(&mut self, value: &Option<ScalarValue>) {
+        if self.value.is_none() {
+            self.value = value.clone();
+        } else if value.is_some() {
+            self.value = match (&self.value, value) {
+                (Some(ScalarValue::UInt8(a)), Some(ScalarValue::UInt8(b))) => {
+                    Some(ScalarValue::UInt8(*a.min(b)))
+                }
+                (Some(ScalarValue::UInt16(a)), Some(ScalarValue::UInt16(b))) => {
+                    Some(ScalarValue::UInt16(*a.min(b)))
+                }
+                (Some(ScalarValue::UInt32(a)), Some(ScalarValue::UInt32(b))) => {
+                    Some(ScalarValue::UInt32(*a.min(b)))
+                }
+                (Some(ScalarValue::UInt64(a)), Some(ScalarValue::UInt64(b))) => {
+                    Some(ScalarValue::UInt64(*a.min(b)))
+                }
+                (Some(ScalarValue::Int8(a)), Some(ScalarValue::Int8(b))) => {
+                    Some(ScalarValue::Int8(*a.min(b)))
+                }
+                (Some(ScalarValue::Int16(a)), Some(ScalarValue::Int16(b))) => {
+                    Some(ScalarValue::Int16(*a.min(b)))
+                }
+                (Some(ScalarValue::Int32(a)), Some(ScalarValue::Int32(b))) => {
+                    Some(ScalarValue::Int32(*a.min(b)))
+                }
+                (Some(ScalarValue::Int64(a)), Some(ScalarValue::Int64(b))) => {
+                    Some(ScalarValue::Int64(*a.min(b)))
+                }
+                (Some(ScalarValue::Float32(a)), Some(ScalarValue::Float32(b))) => {
+                    Some(ScalarValue::Float32(a.min(*b)))
+                }
+                (Some(ScalarValue::Float64(a)), Some(ScalarValue::Float64(b))) => {
+                    Some(ScalarValue::Float64(a.min(*b)))
+                }
+                _ => panic!("unsupported data type for MIN"),
+            }
+        }
+    }
+
+    fn result(&self) -> &Option<ScalarValue> {
+        &self.value
+    }
+
+    fn data_type(&self) -> &DataType {
+        &self.data_type
+    }
+}
+
+#[derive(Debug)]
+struct MaxFunction {
+    data_type: DataType,
+    value: Option<ScalarValue>,
+}
+
+impl MaxFunction {
+    fn new(data_type: &DataType) -> Self {
+        Self {
+            data_type: data_type.clone(),
+            value: None,
+        }
+    }
+}
+
+impl AggregateFunction for MaxFunction {
+    fn name(&self) -> &str {
+        "max"
+    }
+
+    fn accumulate_scalar(&mut self, value: &Option<ScalarValue>) {
+        if self.value.is_none() {
+            self.value = value.clone();
+        } else if value.is_some() {
+            self.value = match (&self.value, value) {
+                (Some(ScalarValue::UInt8(a)), Some(ScalarValue::UInt8(b))) => {
+                    Some(ScalarValue::UInt8(*a.max(b)))
+                }
+                (Some(ScalarValue::UInt16(a)), Some(ScalarValue::UInt16(b))) => {
+                    Some(ScalarValue::UInt16(*a.max(b)))
+                }
+                (Some(ScalarValue::UInt32(a)), Some(ScalarValue::UInt32(b))) => {
+                    Some(ScalarValue::UInt32(*a.max(b)))
+                }
+                (Some(ScalarValue::UInt64(a)), Some(ScalarValue::UInt64(b))) => {
+                    Some(ScalarValue::UInt64(*a.max(b)))
+                }
+                (Some(ScalarValue::Int8(a)), Some(ScalarValue::Int8(b))) => {
+                    Some(ScalarValue::Int8(*a.max(b)))
+                }
+                (Some(ScalarValue::Int16(a)), Some(ScalarValue::Int16(b))) => {
+                    Some(ScalarValue::Int16(*a.max(b)))
+                }
+                (Some(ScalarValue::Int32(a)), Some(ScalarValue::Int32(b))) => {
+                    Some(ScalarValue::Int32(*a.max(b)))
+                }
+                (Some(ScalarValue::Int64(a)), Some(ScalarValue::Int64(b))) => {
+                    Some(ScalarValue::Int64(*a.max(b)))
+                }
+                (Some(ScalarValue::Float32(a)), Some(ScalarValue::Float32(b))) => {
+                    Some(ScalarValue::Float32(a.max(*b)))
+                }
+                (Some(ScalarValue::Float64(a)), Some(ScalarValue::Float64(b))) => {
+                    Some(ScalarValue::Float64(a.max(*b)))
+                }
+                _ => panic!("unsupported data type for MAX"),
+            }
+        }
+    }
+
+    fn result(&self) -> &Option<ScalarValue> {
+        &self.value
+    }
+
+    fn data_type(&self) -> &DataType {
+        &self.data_type
+    }
+}
+
+#[derive(Debug)]
+struct SumFunction {
+    data_type: DataType,
+    value: Option<ScalarValue>,
+}
+
+impl SumFunction {
+    fn new(data_type: &DataType) -> Self {
+        Self {
+            data_type: data_type.clone(),
+            value: None,
+        }
+    }
+}
+
+impl AggregateFunction for SumFunction {
+    fn name(&self) -> &str {
+        "sum"
+    }
+
+    fn accumulate_scalar(&mut self, value: &Option<ScalarValue>) {
+        if self.value.is_none() {
+            self.value = value.clone();
+        } else if value.is_some() {
+            self.value = match (&self.value, value) {
+                (Some(ScalarValue::UInt8(a)), Some(ScalarValue::UInt8(b))) => {
+                    Some(ScalarValue::UInt8(*a + b))
+                }
+                (Some(ScalarValue::UInt16(a)), Some(ScalarValue::UInt16(b))) => {
+                    Some(ScalarValue::UInt16(*a + b))
+                }
+                (Some(ScalarValue::UInt32(a)), Some(ScalarValue::UInt32(b))) => {
+                    Some(ScalarValue::UInt32(*a + b))
+                }
+                (Some(ScalarValue::UInt64(a)), Some(ScalarValue::UInt64(b))) => {
+                    Some(ScalarValue::UInt64(*a + b))
+                }
+                (Some(ScalarValue::Int8(a)), Some(ScalarValue::Int8(b))) => {
+                    Some(ScalarValue::Int8(*a + b))
+                }
+                (Some(ScalarValue::Int16(a)), Some(ScalarValue::Int16(b))) => {
+                    Some(ScalarValue::Int16(*a + b))
+                }
+                (Some(ScalarValue::Int32(a)), Some(ScalarValue::Int32(b))) => {
+                    Some(ScalarValue::Int32(*a + b))
+                }
+                (Some(ScalarValue::Int64(a)), Some(ScalarValue::Int64(b))) => {
+                    Some(ScalarValue::Int64(*a + b))
+                }
+                (Some(ScalarValue::Float32(a)), Some(ScalarValue::Float32(b))) => {
+                    Some(ScalarValue::Float32(a + *b))
+                }
+                (Some(ScalarValue::Float64(a)), Some(ScalarValue::Float64(b))) => {
+                    Some(ScalarValue::Float64(a + *b))
+                }
+                _ => panic!("unsupported data type for SUM"),
+            }
+        }
+    }
+
+    fn result(&self) -> &Option<ScalarValue> {
+        &self.value
+    }
+
+    fn data_type(&self) -> &DataType {
+        &self.data_type
+    }
+}
+
+struct AccumulatorSet {
+    aggr_values: Vec<Rc<RefCell<AggregateFunction>>>,
+}
+
+impl AccumulatorSet {
+    fn accumulate_scalar(&mut self, i: usize, value: Option<ScalarValue>) {
+        let mut accumulator = self.aggr_values[i].borrow_mut();
+        accumulator.accumulate_scalar(&value);
+    }
+
+    fn values(&self) -> Vec<Option<ScalarValue>> {
+        self.aggr_values
+            .iter()
+            .map(|x| x.borrow().result().clone())
+            .collect()
+    }
+}
+
+#[derive(Debug)]
+struct MapEntry {
+    k: Vec<GroupByScalar>,
+    v: Vec<Option<ScalarValue>>,
+}
+
+/// Create an initial aggregate entry
+fn create_accumulators(aggr_expr: &Vec<RuntimeExpr>) -> Result<AccumulatorSet> {
+    let aggr_values: Vec<Rc<RefCell<AggregateFunction>>> = aggr_expr
+        .iter()
+        .map(|e| match e {
+            RuntimeExpr::AggregateFunction { ref f, ref t, .. } => match f {
+                AggregateType::Min => Ok(Rc::new(RefCell::new(MinFunction::new(t)))
+                    as Rc<RefCell<AggregateFunction>>),
+                AggregateType::Max => Ok(Rc::new(RefCell::new(MaxFunction::new(t)))
+                    as Rc<RefCell<AggregateFunction>>),
+                AggregateType::Sum => Ok(Rc::new(RefCell::new(SumFunction::new(t)))
+                    as Rc<RefCell<AggregateFunction>>),
+                _ => Err(ExecutionError::ExecutionError(
+                    "unsupported aggregate function".to_string(),
+                )),
+            },
+            _ => Err(ExecutionError::ExecutionError(
+                "invalid aggregate expression".to_string(),
+            )),
+        })
+        .collect::<Result<Vec<Rc<RefCell<AggregateFunction>>>>>()?;
+
+    Ok(AccumulatorSet { aggr_values })
+}
+
+fn array_min(array: ArrayRef, dt: &DataType) -> Result<Option<ScalarValue>> {
+    match dt {
+        DataType::UInt8 => {
+            match array_ops::min(array.as_any().downcast_ref::<UInt8Array>().unwrap()) {
+                Some(n) => Ok(Some(ScalarValue::UInt8(n))),
+                None => Ok(None),
+            }
+        }
+        DataType::UInt16 => {
+            match array_ops::min(array.as_any().downcast_ref::<UInt16Array>().unwrap()) {
+                Some(n) => Ok(Some(ScalarValue::UInt16(n))),
+                None => Ok(None),
+            }
+        }
+        DataType::UInt32 => {
+            match array_ops::min(array.as_any().downcast_ref::<UInt32Array>().unwrap()) {
+                Some(n) => Ok(Some(ScalarValue::UInt32(n))),
+                None => Ok(None),
+            }
+        }
+        DataType::UInt64 => {
+            match array_ops::min(array.as_any().downcast_ref::<UInt64Array>().unwrap()) {
+                Some(n) => Ok(Some(ScalarValue::UInt64(n))),
+                None => Ok(None),
+            }
+        }
+        DataType::Int8 => {
+            match array_ops::min(array.as_any().downcast_ref::<Int8Array>().unwrap()) {
+                Some(n) => Ok(Some(ScalarValue::Int8(n))),
+                None => Ok(None),
+            }
+        }
+        DataType::Int16 => {
+            match array_ops::min(array.as_any().downcast_ref::<Int16Array>().unwrap()) {
+                Some(n) => Ok(Some(ScalarValue::Int16(n))),
+                None => Ok(None),
+            }
+        }
+        DataType::Int32 => {
+            match array_ops::min(array.as_any().downcast_ref::<Int32Array>().unwrap()) {
+                Some(n) => Ok(Some(ScalarValue::Int32(n))),
+                None => Ok(None),
+            }
+        }
+        DataType::Int64 => {
+            match array_ops::min(array.as_any().downcast_ref::<Int64Array>().unwrap()) {
+                Some(n) => Ok(Some(ScalarValue::Int64(n))),
+                None => Ok(None),
+            }
+        }
+        DataType::Float32 => {
+            match array_ops::min(array.as_any().downcast_ref::<Float32Array>().unwrap()) {
+                Some(n) => Ok(Some(ScalarValue::Float32(n))),
+                None => Ok(None),
+            }
+        }
+        DataType::Float64 => {
+            match array_ops::min(array.as_any().downcast_ref::<Float64Array>().unwrap()) {
+                Some(n) => Ok(Some(ScalarValue::Float64(n))),
+                None => Ok(None),
+            }
+        }
+        _ => Err(ExecutionError::ExecutionError(
+            "Unsupported data type for MIN".to_string(),
+        )),
+    }
+}
+
+fn array_max(array: ArrayRef, dt: &DataType) -> Result<Option<ScalarValue>> {
+    match dt {
+        DataType::UInt8 => {
+            match array_ops::max(array.as_any().downcast_ref::<UInt8Array>().unwrap()) {
+                Some(n) => Ok(Some(ScalarValue::UInt8(n))),
+                None => Ok(None),
+            }
+        }
+        DataType::UInt16 => {
+            match array_ops::max(array.as_any().downcast_ref::<UInt16Array>().unwrap()) {
+                Some(n) => Ok(Some(ScalarValue::UInt16(n))),
+                None => Ok(None),
+            }
+        }
+        DataType::UInt32 => {
+            match array_ops::max(array.as_any().downcast_ref::<UInt32Array>().unwrap()) {
+                Some(n) => Ok(Some(ScalarValue::UInt32(n))),
+                None => Ok(None),
+            }
+        }
+        DataType::UInt64 => {
+            match array_ops::max(array.as_any().downcast_ref::<UInt64Array>().unwrap()) {
+                Some(n) => Ok(Some(ScalarValue::UInt64(n))),
+                None => Ok(None),
+            }
+        }
+        DataType::Int8 => {
+            match array_ops::max(array.as_any().downcast_ref::<Int8Array>().unwrap()) {
+                Some(n) => Ok(Some(ScalarValue::Int8(n))),
+                None => Ok(None),
+            }
+        }
+        DataType::Int16 => {
+            match array_ops::max(array.as_any().downcast_ref::<Int16Array>().unwrap()) {
+                Some(n) => Ok(Some(ScalarValue::Int16(n))),
+                None => Ok(None),
+            }
+        }
+        DataType::Int32 => {
+            match array_ops::max(array.as_any().downcast_ref::<Int32Array>().unwrap()) {
+                Some(n) => Ok(Some(ScalarValue::Int32(n))),
+                None => Ok(None),
+            }
+        }
+        DataType::Int64 => {
+            match array_ops::max(array.as_any().downcast_ref::<Int64Array>().unwrap()) {
+                Some(n) => Ok(Some(ScalarValue::Int64(n))),
+                None => Ok(None),
+            }
+        }
+        DataType::Float32 => {
+            match array_ops::max(array.as_any().downcast_ref::<Float32Array>().unwrap()) {
+                Some(n) => Ok(Some(ScalarValue::Float32(n))),
+                None => Ok(None),
+            }
+        }
+        DataType::Float64 => {
+            match array_ops::max(array.as_any().downcast_ref::<Float64Array>().unwrap()) {
+                Some(n) => Ok(Some(ScalarValue::Float64(n))),
+                None => Ok(None),
+            }
+        }
+        _ => Err(ExecutionError::ExecutionError(
+            "Unsupported data type for MAX".to_string(),
+        )),
+    }
+}
+
+fn array_sum(array: ArrayRef, dt: &DataType) -> Result<Option<ScalarValue>> {
+    match dt {
+        DataType::UInt8 => {
+            match array_ops::sum(array.as_any().downcast_ref::<UInt8Array>().unwrap()) {
+                Some(n) => Ok(Some(ScalarValue::UInt8(n))),
+                None => Ok(None),
+            }
+        }
+        DataType::UInt16 => {
+            match array_ops::sum(array.as_any().downcast_ref::<UInt16Array>().unwrap()) {
+                Some(n) => Ok(Some(ScalarValue::UInt16(n))),
+                None => Ok(None),
+            }
+        }
+        DataType::UInt32 => {
+            match array_ops::sum(array.as_any().downcast_ref::<UInt32Array>().unwrap()) {
+                Some(n) => Ok(Some(ScalarValue::UInt32(n))),
+                None => Ok(None),
+            }
+        }
+        DataType::UInt64 => {
+            match array_ops::sum(array.as_any().downcast_ref::<UInt64Array>().unwrap()) {
+                Some(n) => Ok(Some(ScalarValue::UInt64(n))),
+                None => Ok(None),
+            }
+        }
+        DataType::Int8 => {
+            match array_ops::sum(array.as_any().downcast_ref::<Int8Array>().unwrap()) {
+                Some(n) => Ok(Some(ScalarValue::Int8(n))),
+                None => Ok(None),
+            }
+        }
+        DataType::Int16 => {
+            match array_ops::sum(array.as_any().downcast_ref::<Int16Array>().unwrap()) {
+                Some(n) => Ok(Some(ScalarValue::Int16(n))),
+                None => Ok(None),
+            }
+        }
+        DataType::Int32 => {
+            match array_ops::sum(array.as_any().downcast_ref::<Int32Array>().unwrap()) {
+                Some(n) => Ok(Some(ScalarValue::Int32(n))),
+                None => Ok(None),
+            }
+        }
+        DataType::Int64 => {
+            match array_ops::sum(array.as_any().downcast_ref::<Int64Array>().unwrap()) {
+                Some(n) => Ok(Some(ScalarValue::Int64(n))),
+                None => Ok(None),
+            }
+        }
+        DataType::Float32 => {
+            match array_ops::sum(array.as_any().downcast_ref::<Float32Array>().unwrap()) {
+                Some(n) => Ok(Some(ScalarValue::Float32(n))),
+                None => Ok(None),
+            }
+        }
+        DataType::Float64 => {
+            match array_ops::sum(array.as_any().downcast_ref::<Float64Array>().unwrap()) {
+                Some(n) => Ok(Some(ScalarValue::Float64(n))),
+                None => Ok(None),
+            }
+        }
+        _ => Err(ExecutionError::ExecutionError(
+            "Unsupported data type for SUM".to_string(),
+        )),
+    }
+}
+
+fn update_accumulators(
+    batch: &RecordBatch,
+    row: usize,
+    accumulator_set: &mut AccumulatorSet,
+    aggr_expr: &Vec<RuntimeExpr>,
+) {
+    // update the accumulators
+    for j in 0..accumulator_set.aggr_values.len() {
+        match &aggr_expr[j] {
+            RuntimeExpr::AggregateFunction { args, t, .. } => {
+                // evaluate argument to aggregate function
+                match args[0](&batch) {
+                    Ok(array) => {
+                        let value: Option<ScalarValue> = match t {
+                            DataType::UInt8 => {
+                                let z =
+                                    array.as_any().downcast_ref::<UInt8Array>().unwrap();
+                                Some(ScalarValue::UInt8(z.value(row)))
+                            }
+                            DataType::UInt16 => {
+                                let z =
+                                    array.as_any().downcast_ref::<UInt16Array>().unwrap();
+                                Some(ScalarValue::UInt16(z.value(row)))
+                            }
+                            DataType::UInt32 => {
+                                let z =
+                                    array.as_any().downcast_ref::<UInt32Array>().unwrap();
+                                Some(ScalarValue::UInt32(z.value(row)))
+                            }
+                            DataType::UInt64 => {
+                                let z =
+                                    array.as_any().downcast_ref::<UInt64Array>().unwrap();
+                                Some(ScalarValue::UInt64(z.value(row)))
+                            }
+                            DataType::Int8 => {
+                                let z =
+                                    array.as_any().downcast_ref::<Int8Array>().unwrap();
+                                Some(ScalarValue::Int8(z.value(row)))
+                            }
+                            DataType::Int16 => {
+                                let z =
+                                    array.as_any().downcast_ref::<Int16Array>().unwrap();
+                                Some(ScalarValue::Int16(z.value(row)))
+                            }
+                            DataType::Int32 => {
+                                let z =
+                                    array.as_any().downcast_ref::<Int32Array>().unwrap();
+                                Some(ScalarValue::Int32(z.value(row)))
+                            }
+                            DataType::Int64 => {
+                                let z =
+                                    array.as_any().downcast_ref::<Int64Array>().unwrap();
+                                Some(ScalarValue::Int64(z.value(row)))
+                            }
+                            DataType::Float32 => {
+                                let z = array
+                                    .as_any()
+                                    .downcast_ref::<Float32Array>()
+                                    .unwrap();
+                                Some(ScalarValue::Float32(z.value(row)))
+                            }
+                            DataType::Float64 => {
+                                let z = array
+                                    .as_any()
+                                    .downcast_ref::<Float64Array>()
+                                    .unwrap();
+                                Some(ScalarValue::Float64(z.value(row)))
+                            }
+                            _ => panic!(),
+                        };
+                        accumulator_set.accumulate_scalar(j, value);
+                    }
+                    _ => panic!(),
+                }
+            }
+            _ => panic!(),
+        }
+    }
+}
+
+impl Relation for AggregateRelation {
+    fn next(&mut self) -> Result<Option<RecordBatch>> {
+        if self.end_of_results {
+            Ok(None)
+        } else {
+            self.end_of_results = true;
+            if self.group_expr.is_empty() {
+                self.without_group_by()
+            } else {
+                self.with_group_by()
+            }
+        }
+    }
+
+    fn schema(&self) -> &Arc<Schema> {
+        &self.schema
+    }
+}
+
+macro_rules! array_from_scalar {
+    ($BUILDER:ident, $TY:ident, $ACCUM:expr) => {{
+        let mut b = $BUILDER::new(1);
+        let mut err = false;
+        match $ACCUM.result() {
+            Some(ScalarValue::$TY(n)) => {
+                b.append_value(*n)?;
+            }
+            None => {
+                b.append_null()?;
+            }
+            Some(_) => {
+                err = true;
+            }
+        };
+        if err {
+            Err(ExecutionError::ExecutionError(
+                "unexpected type when creating array from scalar value".to_string(),
+            ))
+        } else {
+            Ok(Arc::new(b.finish()) as ArrayRef)
+        }
+    }};
+}
+
+/// Create array from `key` attribute in map entry (representing a grouping scalar value)
+macro_rules! group_array_from_map_entries {
+    ($BUILDER:ident, $TY:ident, $ENTRIES:expr, $COL_INDEX:expr) => {{
+        let mut builder = $BUILDER::new($ENTRIES.len());
+        let mut err = false;
+        for j in 0..$ENTRIES.len() {
+            match $ENTRIES[j].k[$COL_INDEX] {
+                GroupByScalar::$TY(n) => builder.append_value(n).unwrap(),
+                _ => err = true,
+            }
+        }
+        if err {
+            Err(ExecutionError::ExecutionError(
+                "unexpected type when creating array from aggregate map".to_string(),
+            ))
+        } else {
+            Ok(Arc::new(builder.finish()) as ArrayRef)
+        }
+    }};
+}
+
+/// Create array from `value` attribute in map entry (representing an aggregate scalar value)
+macro_rules! aggr_array_from_map_entries {
+    ($BUILDER:ident, $TY:ident, $ENTRIES:expr, $COL_INDEX:expr) => {{
+        let mut builder = $BUILDER::new($ENTRIES.len());
+        let mut err = false;
+        for j in 0..$ENTRIES.len() {
+            match $ENTRIES[j].v[$COL_INDEX] {
+                Some(ScalarValue::$TY(n)) => builder.append_value(n).unwrap(),
+                None => builder.append_null().unwrap(),
+                _ => err = true,
+            }
+        }
+        if err {
+            Err(ExecutionError::ExecutionError(
+                "unexpected type when creating array from aggregate map".to_string(),
+            ))
+        } else {
+            Ok(Arc::new(builder.finish()) as ArrayRef)
+        }
+    }};
+}
+
+impl AggregateRelation {
+    /// perform simple aggregate on entire columns without grouping logic
+    fn without_group_by(&mut self) -> Result<Option<RecordBatch>> {
+        let aggr_expr_count = self.aggr_expr.len();
+        let mut accumulator_set = create_accumulators(&self.aggr_expr)?;
+
+        while let Some(batch) = self.input.borrow_mut().next()? {
+            for i in 0..aggr_expr_count {
+                match &self.aggr_expr[i] {
+                    RuntimeExpr::AggregateFunction { f, args, t, .. } => {
+                        // evaluate argument to aggregate function
+                        match args[0](&batch) {
+                            Ok(array) => match f {
+                                AggregateType::Min => accumulator_set
+                                    .accumulate_scalar(i, array_min(array, &t)?),
+                                AggregateType::Max => accumulator_set
+                                    .accumulate_scalar(i, array_max(array, &t)?),
+                                AggregateType::Sum => accumulator_set
+                                    .accumulate_scalar(i, array_sum(array, &t)?),
+                                _ => {
+                                    return Err(ExecutionError::NotImplemented(
+                                        "Unsupported aggregate function".to_string(),
+                                    ));
+                                }
+                            },
+                            Err(_) => {
+                                return Err(ExecutionError::ExecutionError(
+                                    "Failed to evaluate argument to aggregate function"
+                                        .to_string(),
+                                ));
+                            }
+                        }
+                    }
+                    _ => {
+                        return Err(ExecutionError::General(
+                            "Invalid aggregate expression".to_string(),
+                        ));
+                    }
+                }
+            }
+        }
+
+        let mut result_columns: Vec<ArrayRef> = vec![];
+
+        for i in 0..aggr_expr_count {
+            let accum = accumulator_set.aggr_values[i].borrow();
+            match accum.data_type() {
+                DataType::UInt8 => {
+                    result_columns.push(array_from_scalar!(UInt8Builder, UInt8, accum)?)
+                }
+                DataType::UInt16 => {
+                    result_columns.push(array_from_scalar!(UInt16Builder, UInt16, accum)?)
+                }
+                DataType::UInt32 => {
+                    result_columns.push(array_from_scalar!(UInt32Builder, UInt32, accum)?)
+                }
+                DataType::UInt64 => {
+                    result_columns.push(array_from_scalar!(UInt64Builder, UInt64, accum)?)
+                }
+                DataType::Int8 => {
+                    result_columns.push(array_from_scalar!(Int8Builder, Int8, accum)?)
+                }
+                DataType::Int16 => {
+                    result_columns.push(array_from_scalar!(Int16Builder, Int16, accum)?)
+                }
+                DataType::Int32 => {
+                    result_columns.push(array_from_scalar!(Int32Builder, Int32, accum)?)
+                }
+                DataType::Int64 => {
+                    result_columns.push(array_from_scalar!(Int64Builder, Int64, accum)?)
+                }
+                DataType::Float32 => result_columns.push(array_from_scalar!(
+                    Float32Builder,
+                    Float32,
+                    accum
+                )?),
+                DataType::Float64 => result_columns.push(array_from_scalar!(
+                    Float64Builder,
+                    Float64,
+                    accum
+                )?),
+                _ => return Err(ExecutionError::NotImplemented("tbd".to_string())),
+            }
+        }
+
+        Ok(Some(RecordBatch::new(self.schema.clone(), result_columns)))
+    }
+
+    fn with_group_by(&mut self) -> Result<Option<RecordBatch>> {
+        //NOTE this whole method is currently very inefficient with too many per-row operations
+        // involving pattern matching and downcasting ... I'm sure this can be re-implemented in
+        // a much more efficient way that takes better advantage of Arrow
+
+        // create map to store aggregate results
+        let mut map: FnvHashMap<Vec<GroupByScalar>, Rc<RefCell<AccumulatorSet>>> =
+            FnvHashMap::default();
+
+        while let Some(batch) = self.input.borrow_mut().next()? {
+            // evaluate the group by expressions on this batch
+            let group_by_keys: Vec<ArrayRef> = self
+                .group_expr
+                .iter()
+                .map(|e| e.get_func()(&batch))
+                .collect::<Result<Vec<ArrayRef>>>()?;
+
+            // iterate over each row in the batch
+            for row in 0..batch.num_rows() {
+                // create key
+                let key: Vec<GroupByScalar> = group_by_keys
+                    .iter()
+                    .map(|col| match col.data_type() {
+                        DataType::UInt8 => {
+                            let array =
+                                col.as_any().downcast_ref::<UInt8Array>().unwrap();
+                            Ok(GroupByScalar::UInt8(array.value(row)))
+                        }
+                        DataType::UInt16 => {
+                            let array =
+                                col.as_any().downcast_ref::<UInt16Array>().unwrap();
+                            Ok(GroupByScalar::UInt16(array.value(row)))
+                        }
+                        DataType::UInt32 => {
+                            let array =
+                                col.as_any().downcast_ref::<UInt32Array>().unwrap();
+                            Ok(GroupByScalar::UInt32(array.value(row)))
+                        }
+                        DataType::UInt64 => {
+                            let array =
+                                col.as_any().downcast_ref::<UInt64Array>().unwrap();
+                            Ok(GroupByScalar::UInt64(array.value(row)))
+                        }
+                        DataType::Int8 => {
+                            let array = col.as_any().downcast_ref::<Int8Array>().unwrap();
+                            Ok(GroupByScalar::Int8(array.value(row)))
+                        }
+                        DataType::Int16 => {
+                            let array =
+                                col.as_any().downcast_ref::<Int16Array>().unwrap();
+                            Ok(GroupByScalar::Int16(array.value(row)))
+                        }
+                        DataType::Int32 => {
+                            let array =
+                                col.as_any().downcast_ref::<Int32Array>().unwrap();
+                            Ok(GroupByScalar::Int32(array.value(row)))
+                        }
+                        DataType::Int64 => {
+                            let array =
+                                col.as_any().downcast_ref::<Int64Array>().unwrap();
+                            Ok(GroupByScalar::Int64(array.value(row)))
+                        }
+                        DataType::Utf8 => {
+                            let array =
+                                col.as_any().downcast_ref::<BinaryArray>().unwrap();
+                            Ok(GroupByScalar::Utf8(String::from(
+                                str::from_utf8(array.value(row)).unwrap(),
+                            )))
+                        }
+                        _ => Err(ExecutionError::ExecutionError(
+                            "Unsupported GROUP BY data type".to_string(),
+                        )),
+                    })
+                    .collect::<Result<Vec<GroupByScalar>>>()?;
+
+                //TODO: find more elegant way to write this instead of hacking around ownership issues
+
+                let updated = match map.get(&key) {
+                    Some(entry) => {
+                        let mut accumulator_set = entry.borrow_mut();
+                        update_accumulators(
+                            &batch,
+                            row,
+                            &mut accumulator_set,
+                            &self.aggr_expr,
+                        );
+                        true
+                    }
+                    None => false,
+                };
+
+                if !updated {
+                    let accumulator_set =
+                        Rc::new(RefCell::new(create_accumulators(&self.aggr_expr)?));
+                    {
+                        let mut entry_mut = accumulator_set.borrow_mut();
+                        update_accumulators(&batch, row, &mut entry_mut, &self.aggr_expr);
+                    }
+                    map.insert(key.clone(), accumulator_set);
+                }
+            }
+        }
+
+        // convert the map to a vec to make it easier to build arrays
+        let entries: Vec<MapEntry> = map
+            .iter()
+            .map(|(k, v)| {
+                let x = v.borrow();
+                MapEntry {
+                    k: k.clone(),
+                    v: x.values(),
+                }
+            })
+            .collect();
+
+        // build the result arrays
+        let mut result_arrays: Vec<ArrayRef> =
+            Vec::with_capacity(self.group_expr.len() + self.aggr_expr.len());
+
+        // grouping values
+        for i in 0..self.group_expr.len() {
+            let array: Result<ArrayRef> = match self.group_expr[i].get_type() {
+                DataType::UInt8 => {
+                    group_array_from_map_entries!(UInt8Builder, UInt8, entries, i)
+                }
+                DataType::UInt16 => {
+                    group_array_from_map_entries!(UInt16Builder, UInt16, entries, i)
+                }
+                DataType::UInt32 => {
+                    group_array_from_map_entries!(UInt32Builder, UInt32, entries, i)
+                }
+                DataType::UInt64 => {
+                    group_array_from_map_entries!(UInt64Builder, UInt64, entries, i)
+                }
+                DataType::Int8 => {
+                    group_array_from_map_entries!(Int8Builder, Int8, entries, i)
+                }
+                DataType::Int16 => {
+                    group_array_from_map_entries!(Int16Builder, Int16, entries, i)
+                }
+                DataType::Int32 => {
+                    group_array_from_map_entries!(Int32Builder, Int32, entries, i)
+                }
+                DataType::Int64 => {
+                    group_array_from_map_entries!(Int64Builder, Int64, entries, i)
+                }
+                DataType::Utf8 => {
+                    let mut builder = BinaryBuilder::new(1);
+                    for j in 0..entries.len() {
+                        match &entries[j].k[i] {
+                            GroupByScalar::Utf8(s) => builder.append_string(&s).unwrap(),
+                            _ => {}
+                        }
+                    }
+                    Ok(Arc::new(builder.finish()) as ArrayRef)
+                }
+                _ => Err(ExecutionError::ExecutionError(
+                    "Unsupported group by expr".to_string(),
+                )),
+            };
+            result_arrays.push(array?);
+        }
+
+        // aggregate values
+        for i in 0..self.aggr_expr.len() {
+            let array = match self.aggr_expr[i].get_type() {
+                DataType::UInt8 => {
+                    aggr_array_from_map_entries!(UInt8Builder, UInt8, entries, i)
+                }
+                DataType::UInt16 => {
+                    aggr_array_from_map_entries!(UInt16Builder, UInt16, entries, i)
+                }
+                DataType::UInt32 => {
+                    aggr_array_from_map_entries!(UInt32Builder, UInt32, entries, i)
+                }
+                DataType::UInt64 => {
+                    aggr_array_from_map_entries!(UInt64Builder, UInt64, entries, i)
+                }
+                DataType::Int8 => {
+                    aggr_array_from_map_entries!(Int8Builder, Int8, entries, i)
+                }
+                DataType::Int16 => {
+                    aggr_array_from_map_entries!(Int16Builder, Int16, entries, i)
+                }
+                DataType::Int32 => {
+                    aggr_array_from_map_entries!(Int32Builder, Int32, entries, i)
+                }
+                DataType::Int64 => {
+                    aggr_array_from_map_entries!(Int64Builder, Int64, entries, i)
+                }
+                DataType::Float32 => {
+                    aggr_array_from_map_entries!(Float32Builder, Float32, entries, i)
+                }
+                DataType::Float64 => {
+                    aggr_array_from_map_entries!(Float64Builder, Float64, entries, i)
+                }
+                _ => Err(ExecutionError::ExecutionError(
+                    "Unsupported aggregate expr".to_string(),
+                )),
+            };
+            result_arrays.push(array?);
+        }
+
+        Ok(Some(RecordBatch::new(self.schema.clone(), result_arrays)))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::super::super::logicalplan::Expr;
+    use super::super::context::ExecutionContext;
+    use super::super::datasource::CsvDataSource;
+    use super::super::expression;
+    use super::super::relation::DataSourceRelation;
+    use super::*;
+    use arrow::datatypes::{DataType, Field, Schema};
+
+    #[test]
+    fn min_f64_group_by_string() {
+        let schema = aggr_test_schema();
+        let relation = load_csv("../../testing/data/csv/aggregate_test_100.csv", &schema);
+        let context = ExecutionContext::new();
+
+        let aggr_expr = vec![expression::compile_expr(
+            &context,
+            &Expr::AggregateFunction {
+                name: String::from("min"),
+                args: vec![Expr::Column(11)],
+                return_type: DataType::Float64,
+            },
+            &schema,
+        )
+        .unwrap()];
+
+        let aggr_schema = Arc::new(Schema::new(vec![Field::new(
+            "min_lat",
+            DataType::Float64,
+            false,
+        )]));
+
+        let mut projection =
+            AggregateRelation::new(aggr_schema, relation, vec![], aggr_expr);
+        let batch = projection.next().unwrap().unwrap();
+        assert_eq!(1, batch.num_columns());
+        let min_lat = batch
+            .column(0)
+            .as_any()
+            .downcast_ref::<Float64Array>()
+            .unwrap();
+        assert_eq!(0.01479305307777301, min_lat.value(0));
+    }
+
+    #[test]
+    fn max_f64_group_by_string() {
+        let schema = aggr_test_schema();
+        let relation = load_csv("../../testing/data/csv/aggregate_test_100.csv", &schema);
+        let context = ExecutionContext::new();
+
+        let aggr_expr = vec![expression::compile_expr(
+            &context,
+            &Expr::AggregateFunction {
+                name: String::from("max"),
+                args: vec![Expr::Column(11)],
+                return_type: DataType::Float64,
+            },
+            &schema,
+        )
+        .unwrap()];
+
+        let aggr_schema = Arc::new(Schema::new(vec![Field::new(
+            "max_lat",
+            DataType::Float64,
+            false,
+        )]));
+
+        let mut projection =
+            AggregateRelation::new(aggr_schema, relation, vec![], aggr_expr);
+        let batch = projection.next().unwrap().unwrap();
+        assert_eq!(1, batch.num_columns());
+        let max_lat = batch
+            .column(0)
+            .as_any()
+            .downcast_ref::<Float64Array>()
+            .unwrap();
+        assert_eq!(0.9965400387585364, max_lat.value(0));
+    }
+
+    #[test]
+    fn test_min_max_sum_f64_group_by_uint32() {
+        let schema = aggr_test_schema();
+        let relation = load_csv("../../testing/data/csv/aggregate_test_100.csv", &schema);
+
+        let context = ExecutionContext::new();
+
+        let group_by_expr =
+            expression::compile_expr(&context, &Expr::Column(1), &schema).unwrap();
+
+        let min_expr = expression::compile_expr(
+            &context,
+            &Expr::AggregateFunction {
+                name: String::from("min"),
+                args: vec![Expr::Column(11)],
+                return_type: DataType::Float64,
+            },
+            &schema,
+        )
+        .unwrap();
+
+        let max_expr = expression::compile_expr(
+            &context,
+            &Expr::AggregateFunction {
+                name: String::from("max"),
+                args: vec![Expr::Column(11)],
+                return_type: DataType::Float64,
+            },
+            &schema,
+        )
+        .unwrap();
+
+        let sum_expr = expression::compile_expr(
+            &context,
+            &Expr::AggregateFunction {
+                name: String::from("sum"),
+                args: vec![Expr::Column(11)],
+                return_type: DataType::Float64,
+            },
+            &schema,
+        )
+        .unwrap();
+
+        let aggr_schema = Arc::new(Schema::new(vec![
+            Field::new("c2", DataType::Int32, false),
+            Field::new("min", DataType::Float64, false),
+            Field::new("max", DataType::Float64, false),
+            Field::new("sum", DataType::Float64, false),
+        ]));
+
+        let mut projection = AggregateRelation::new(
+            aggr_schema,
+            relation,
+            vec![group_by_expr],
+            vec![min_expr, max_expr, sum_expr],
+        );
+        let batch = projection.next().unwrap().unwrap();
+        assert_eq!(4, batch.num_columns());
+        assert_eq!(5, batch.num_rows());
+
+        let a = batch
+            .column(0)
+            .as_any()
+            .downcast_ref::<UInt32Array>()
+            .unwrap();
+        let min = batch
+            .column(1)
+            .as_any()
+            .downcast_ref::<Float64Array>()
+            .unwrap();
+        let max = batch
+            .column(2)
+            .as_any()
+            .downcast_ref::<Float64Array>()
+            .unwrap();
+        let sum = batch
+            .column(3)
+            .as_any()
+            .downcast_ref::<Float64Array>()
+            .unwrap();
+
+        assert_eq!(4, a.value(0));
+        assert_eq!(0.02182578039211991, min.value(0));
+        assert_eq!(0.9237877978193884, max.value(0));
+        assert_eq!(9.253864188402662, sum.value(0));
+
+        assert_eq!(2, a.value(1));
+        assert_eq!(0.16301110515739792, min.value(1));
+        assert_eq!(0.991517828651004, max.value(1));
+        assert_eq!(14.400412325480858, sum.value(1));
+
+        assert_eq!(5, a.value(2));
+        assert_eq!(0.01479305307777301, min.value(2));
+        assert_eq!(0.9723580396501548, max.value(2));
+        assert_eq!(6.037181692266781, sum.value(2));
+    }
+
+    fn aggr_test_schema() -> Arc<Schema> {
+        Arc::new(Schema::new(vec![
+            Field::new("c1", DataType::Utf8, false),
+            Field::new("c2", DataType::UInt32, false),
+            Field::new("c3", DataType::Int8, false),
+            Field::new("c4", DataType::Int16, false),
+            Field::new("c5", DataType::Int32, false),
+            Field::new("c6", DataType::Int64, false),
+            Field::new("c7", DataType::UInt8, false),
+            Field::new("c8", DataType::UInt16, false),
+            Field::new("c9", DataType::UInt32, false),
+            Field::new("c10", DataType::UInt64, false),
+            Field::new("c11", DataType::Float32, false),
+            Field::new("c12", DataType::Float64, false),
+            Field::new("c13", DataType::Utf8, false),
+        ]))
+    }
+
+    fn load_csv(filename: &str, schema: &Arc<Schema>) -> Rc<RefCell<Relation>> {
+        let ds = CsvDataSource::new(filename, schema.clone(), 1024);
+        Rc::new(RefCell::new(DataSourceRelation::new(Rc::new(
+            RefCell::new(ds),
+        ))))
+    }
+
+}
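The `with_group_by` path above is classic hash aggregation: one accumulator set per distinct group key, updated row by row. A self-contained sketch of the pattern, simplified to a running minimum keyed by a single string (the real code keys on `Vec<GroupByScalar>` and stores an `AccumulatorSet`):

```rust
use fnv::FnvHashMap;

fn main() {
    let rows = vec![("a", 1.0_f64), ("b", 2.0), ("a", 0.5)];

    // one running aggregate per distinct group key
    let mut groups: FnvHashMap<String, f64> = FnvHashMap::default();
    for (key, value) in rows {
        let entry = groups.entry(key.to_string()).or_insert(std::f64::INFINITY);
        *entry = entry.min(value);
    }

    assert_eq!(Some(&0.5), groups.get("a"));
    assert_eq!(Some(&2.0), groups.get("b"));
}
```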
diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs
new file mode 100644
index 0000000000000..86d7c99c21900
--- /dev/null
+++ b/rust/datafusion/src/execution/context.rs
@@ -0,0 +1,228 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::cell::RefCell;
+use std::collections::HashMap;
+use std::rc::Rc;
+use std::sync::Arc;
+
+use arrow::datatypes::{Field, Schema};
+
+use super::super::dfparser::{DFASTNode, DFParser};
+use super::super::logicalplan::*;
+use super::super::sqlplanner::{SchemaProvider, SqlToRel};
+use super::aggregate::AggregateRelation;
+use super::datasource::DataSource;
+use super::error::{ExecutionError, Result};
+use super::expression::*;
+use super::filter::FilterRelation;
+use super::projection::ProjectRelation;
+use super::relation::{DataSourceRelation, Relation};
+
+pub struct ExecutionContext {
+    datasources: Rc<RefCell<HashMap<String, Rc<RefCell<DataSource>>>>>,
+}
+
+impl ExecutionContext {
+    pub fn new() -> Self {
+        Self {
+            datasources: Rc::new(RefCell::new(HashMap::new())),
+        }
+    }
+
+    pub fn sql(&mut self, sql: &str) -> Result<Rc<RefCell<Relation>>> {
+        let ast = DFParser::parse_sql(String::from(sql))?;
+
+        match ast {
+            DFASTNode::ANSI(ansi) => {
+                let schema_provider: Rc<SchemaProvider> =
+                    Rc::new(ExecutionContextSchemaProvider {
+                        datasources: self.datasources.clone(),
+                    });
+
+                // create a query planner
+                let query_planner = SqlToRel::new(schema_provider);
+
+                // plan the query (create a logical relational plan)
+                let plan = query_planner.sql_to_rel(&ansi)?;
+                //println!("Logical plan: {:?}", plan);
+
+                let optimized_plan = plan; //push_down_projection(&plan, &HashSet::new());
+                //println!("Optimized logical plan: {:?}", new_plan);
+
+                let relation = self.execute(&optimized_plan)?;
+
+                Ok(relation)
+            }
+            _ => unimplemented!(),
+        }
+    }
+
+    pub fn register_datasource(&mut self, name: &str, ds: Rc<RefCell<DataSource>>) {
+        self.datasources.borrow_mut().insert(name.to_string(), ds);
+    }
+
+    pub fn execute(&mut self, plan: &LogicalPlan) -> Result<Rc<RefCell<Relation>>> {
+        println!("Logical plan: {:?}", plan);
+
+        match *plan {
+            LogicalPlan::TableScan { ref table_name, .. } => {
+                match self.datasources.borrow().get(table_name) {
+                    Some(ds) => {
+                        //TODO: projection
+                        Ok(Rc::new(RefCell::new(DataSourceRelation::new(ds.clone()))))
+                    }
+                    _ => Err(ExecutionError::General(format!(
+                        "No table registered as '{}'",
+                        table_name
+                    ))),
+                }
+            }
+            LogicalPlan::Selection {
+                ref expr,
+                ref input,
+            } => {
+                let input_rel = self.execute(input)?;
+                let input_schema = input_rel.as_ref().borrow().schema().clone();
+                let runtime_expr = compile_scalar_expr(&self, expr, &input_schema)?;
+                let rel = FilterRelation::new(
+                    input_rel,
+                    runtime_expr, /* .get_func().clone() */
+                    input_schema,
+                );
+                Ok(Rc::new(RefCell::new(rel)))
+            }
+            LogicalPlan::Projection {
+                ref expr,
+                ref input,
+                ..
+            } => {
+                let input_rel = self.execute(input)?;
+
+                let input_schema = input_rel.as_ref().borrow().schema().clone();
+
+                let project_columns: Vec<Field> =
+                    exprlist_to_fields(&expr, &input_schema);
+
+                let project_schema = Arc::new(Schema::new(project_columns));
+
+                let compiled_expr: Result<Vec<RuntimeExpr>> = expr
+                    .iter()
+                    .map(|e| compile_scalar_expr(&self, e, &input_schema))
+                    .collect();
+
+                let rel = ProjectRelation::new(input_rel, compiled_expr?, project_schema);
+
+                Ok(Rc::new(RefCell::new(rel)))
+            }
+            LogicalPlan::Aggregate {
+                ref input,
+                ref group_expr,
+                ref aggr_expr,
+                ..
+            } => {
+                let input_rel = self.execute(&input)?;
+
+                let input_schema = input_rel.as_ref().borrow().schema().clone();
+
+                let compiled_group_expr_result: Result<Vec<RuntimeExpr>> = group_expr
+                    .iter()
+                    .map(|e| compile_scalar_expr(&self, e, &input_schema))
+                    .collect();
+                let compiled_group_expr = compiled_group_expr_result?;
+
+                let compiled_aggr_expr_result: Result<Vec<RuntimeExpr>> = aggr_expr
+                    .iter()
+                    .map(|e| compile_expr(&self, e, &input_schema))
+                    .collect();
+                let compiled_aggr_expr = compiled_aggr_expr_result?;
+
+                let rel = AggregateRelation::new(
+                    Arc::new(Schema::empty()), //(expr_to_field(&compiled_group_expr, &input_schema))),
+                    input_rel,
+                    compiled_group_expr,
+                    compiled_aggr_expr,
+                );
+
+                Ok(Rc::new(RefCell::new(rel)))
+            }
+
+            _ => unimplemented!(),
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+pub enum ExecutionResult {
+    Unit,
+    Count(usize),
+    Str(String),
+}
+
+pub fn expr_to_field(e: &Expr, input_schema: &Schema) -> Field {
+    match e {
+        Expr::Column(i) => input_schema.fields()[*i].clone(),
+        Expr::Literal(ref lit) => Field::new("lit", lit.get_datatype(), true),
+        Expr::ScalarFunction {
+            ref name,
+            ref return_type,
+            ..
+        } => Field::new(&name, return_type.clone(), true),
+        Expr::AggregateFunction {
+            ref name,
+            ref return_type,
+            ..
+        } => Field::new(&name, return_type.clone(), true),
+        Expr::Cast { ref data_type, .. } => Field::new("cast", data_type.clone(), true),
+        Expr::BinaryExpr {
+            ref left,
+            ref right,
+            ..
+        } => {
+            let left_type = left.get_type(input_schema);
+            let right_type = right.get_type(input_schema);
+            Field::new(
+                "binary_expr",
+                get_supertype(&left_type, &right_type).unwrap(),
+                true,
+            )
+        }
+        _ => unimplemented!("Cannot determine schema type for expression {:?}", e),
+    }
+}
+
+pub fn exprlist_to_fields(expr: &Vec<Expr>, input_schema: &Schema) -> Vec<Field> {
+    expr.iter()
+        .map(|e| expr_to_field(e, input_schema))
+        .collect()
+}
+
+struct ExecutionContextSchemaProvider {
+    datasources: Rc<RefCell<HashMap<String, Rc<RefCell<DataSource>>>>>,
+}
+impl SchemaProvider for ExecutionContextSchemaProvider {
+    fn get_table_meta(&self, name: &str) -> Option<Arc<Schema>> {
+        match self.datasources.borrow().get(name) {
+            Some(ds) => Some(ds.borrow().schema().clone()),
+            None => None,
+        }
+    }
+
+    fn get_function_meta(&self, _name: &str) -> Option<Arc<FunctionMeta>> {
+        unimplemented!()
+    }
+}
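A quick illustration of `expr_to_field`, which the context uses to derive output schema fields from logical expressions. This is a sketch under the assumption that `execution::context` and `logicalplan` are public module paths as laid out in this diff, and that arrow's `Field::name` returns `&String`:

```rust
use arrow::datatypes::{DataType, Field, Schema};
use datafusion::execution::context::expr_to_field;
use datafusion::logicalplan::Expr;

fn main() {
    let schema = Schema::new(vec![Field::new("c1", DataType::Utf8, false)]);

    // a column reference resolves to the corresponding input field
    let field = expr_to_field(&Expr::Column(0), &schema);
    assert_eq!("c1", field.name().as_str());
}
```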
diff --git a/rust/datafusion/src/execution/datasource.rs b/rust/datafusion/src/execution/datasource.rs
new file mode 100644
index 0000000000000..379632b95a37f
--- /dev/null
+++ b/rust/datafusion/src/execution/datasource.rs
@@ -0,0 +1,74 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Data sources
+
+use std::fs::File;
+use std::rc::Rc;
+use std::sync::Arc;
+
+use arrow::csv;
+use arrow::datatypes::Schema;
+use arrow::record_batch::RecordBatch;
+
+use super::error::Result;
+
+pub trait DataSource {
+    fn schema(&self) -> &Arc<Schema>;
+    fn next(&mut self) -> Result<Option<RecordBatch>>;
+}
+
+/// CSV data source
+pub struct CsvDataSource {
+    schema: Arc<Schema>,
+    reader: csv::Reader,
+}
+
+impl CsvDataSource {
+    pub fn new(filename: &str, schema: Arc<Schema>, batch_size: usize) -> Self {
+        let file = File::open(filename).unwrap();
+        let reader = csv::Reader::new(file, schema.clone(), true, batch_size, None);
+        Self { schema, reader }
+    }
+}
+
+impl DataSource for CsvDataSource {
+    fn schema(&self) -> &Arc<Schema> {
+        &self.schema
+    }
+
+    fn next(&mut self) -> Result<Option<RecordBatch>> {
+        Ok(self.reader.next()?)
+    }
+}
+
+#[derive(Serialize, Deserialize, Clone)]
+pub enum DataSourceMeta {
+    /// Represents a CSV file with a provided schema
+    CsvFile {
+        filename: String,
+        schema: Rc<Schema>,
+        has_header: bool,
+        projection: Option<Vec<usize>>,
+    },
+    /// Represents a Parquet file that contains schema information
+    ParquetFile {
+        filename: String,
+        schema: Rc<Schema>,
+        projection: Option<Vec<usize>>,
+    },
+}
diff --git a/rust/datafusion/src/execution/error.rs b/rust/datafusion/src/execution/error.rs
new file mode 100644
index 0000000000000..5b8d04d3dca34
--- /dev/null
+++ b/rust/datafusion/src/execution/error.rs
@@ -0,0 +1,69 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Error types
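The `From` conversions defined below are what let the rest of the crate use `?` on IO, Arrow, and parser errors inside functions returning this module's `Result` alias. A sketch with a hypothetical helper function (not part of the diff):

```rust
use datafusion::execution::error::Result;

// hypothetical helper: `?` converts std::io::Error into ExecutionError::IoError
fn read_query(path: &str) -> Result<String> {
    let sql = std::fs::read_to_string(path)?;
    Ok(sql)
}
```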
+
+use std::io::Error;
+use std::result;
+
+use arrow::error::ArrowError;
+
+use sqlparser::sqlparser::ParserError;
+
+pub type Result<T> = result::Result<T, ExecutionError>;
+
+#[derive(Debug)]
+pub enum ExecutionError {
+    IoError(Error),
+    ParserError(ParserError),
+    General(String),
+    InvalidColumn(String),
+    NotImplemented(String),
+    InternalError(String),
+    ArrowError(ArrowError),
+    ExecutionError(String),
+}
+
+impl From<Error> for ExecutionError {
+    fn from(e: Error) -> Self {
+        ExecutionError::IoError(e)
+    }
+}
+
+impl From<String> for ExecutionError {
+    fn from(e: String) -> Self {
+        ExecutionError::General(e)
+    }
+}
+
+impl From<&'static str> for ExecutionError {
+    fn from(e: &'static str) -> Self {
+        ExecutionError::General(e.to_string())
+    }
+}
+
+impl From<ArrowError> for ExecutionError {
+    fn from(e: ArrowError) -> Self {
+        ExecutionError::ArrowError(e)
+    }
+}
+
+impl From<ParserError> for ExecutionError {
+    fn from(e: ParserError) -> Self {
+        ExecutionError::ParserError(e)
+    }
+}
diff --git a/rust/datafusion/src/execution/expression.rs b/rust/datafusion/src/execution/expression.rs
new file mode 100644
index 0000000000000..108a8558b9adf
--- /dev/null
+++ b/rust/datafusion/src/execution/expression.rs
@@ -0,0 +1,516 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::rc::Rc;
+use std::sync::Arc;
+
+use arrow::array::*;
+use arrow::array_ops;
+use arrow::datatypes::{DataType, Schema};
+use arrow::record_batch::RecordBatch;
+
+use super::super::logicalplan::{Expr, Operator, ScalarValue};
+use super::context::ExecutionContext;
+use super::error::{ExecutionError, Result};
+
+/// Compiled Expression (basically just a closure to evaluate the expression at runtime)
+pub type CompiledExpr = Rc<Fn(&RecordBatch) -> Result<ArrayRef>>;
+
+pub type CompiledCastFunction = Rc<Fn(&RecordBatch) -> Result<ArrayRef>>;
+
+pub enum AggregateType {
+    Min,
+    Max,
+    Sum,
+    Count,
+    CountDistinct,
+    Avg,
+}
+
+/// Runtime expression
+pub enum RuntimeExpr {
+    Compiled {
+        name: String,
+        f: CompiledExpr,
+        t: DataType,
+    },
+    AggregateFunction {
+        name: String,
+        f: AggregateType,
+        args: Vec<CompiledExpr>,
+        t: DataType,
+    },
+}
+
+impl RuntimeExpr {
+    pub fn get_func(&self) -> CompiledExpr {
+        match self {
+            &RuntimeExpr::Compiled { ref f, .. } => f.clone(),
+            _ => panic!(),
+        }
+    }
+
+    pub fn get_name(&self) -> &String {
+        match self {
+            &RuntimeExpr::Compiled { ref name, .. } => name,
+            &RuntimeExpr::AggregateFunction { ref name, .. } => name,
+        }
+    }
+
+    pub fn get_type(&self) -> DataType {
+        match self {
+            &RuntimeExpr::Compiled { ref t, .. } => t.clone(),
+            &RuntimeExpr::AggregateFunction { ref t, .. } => t.clone(),
+        }
+    }
+}
+
+/// Compiles a scalar expression into a closure
+pub fn compile_expr(
+    ctx: &ExecutionContext,
+    expr: &Expr,
+    input_schema: &Schema,
+) -> Result<RuntimeExpr> {
+    match *expr {
+        Expr::AggregateFunction {
+            ref name,
+            ref args,
+            ref return_type,
+        } => {
+            assert_eq!(1, args.len());
+
+            let compiled_args: Result<Vec<RuntimeExpr>> = args
+                .iter()
+                .map(|e| compile_scalar_expr(&ctx, e, input_schema))
+                .collect();
+
+            let func = match name.to_lowercase().as_ref() {
+                "min" => Ok(AggregateType::Min),
+                "max" => Ok(AggregateType::Max),
+                "count" => Ok(AggregateType::Count),
+                "sum" => Ok(AggregateType::Sum),
+                _ => Err(ExecutionError::General(format!(
+                    "Unsupported aggregate function '{}'",
+                    name
+                ))),
+            };
+
+            Ok(RuntimeExpr::AggregateFunction {
+                name: name.to_string(),
+                f: func?,
+                args: compiled_args?
+                    .iter()
+                    .map(|e| e.get_func().clone())
+                    .collect(),
+                t: return_type.clone(),
+            })
+        }
+        _ => Ok(compile_scalar_expr(&ctx, expr, input_schema)?),
+    }
+}
+
+macro_rules! binary_op {
+    ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
+        let ll = $LEFT.as_any().downcast_ref::<$DT>().unwrap();
+        let rr = $RIGHT.as_any().downcast_ref::<$DT>().unwrap();
+        Ok(Arc::new(array_ops::$OP(&ll, &rr)?))
+    }};
+}
+
+macro_rules! math_ops {
+    ($LEFT:expr, $RIGHT:expr, $BATCH:expr, $OP:ident) => {{
+        let left_values = $LEFT.get_func()($BATCH)?;
+        let right_values = $RIGHT.get_func()($BATCH)?;
+        match (left_values.data_type(), right_values.data_type()) {
+            (DataType::Int8, DataType::Int8) => {
+                binary_op!(left_values, right_values, $OP, Int8Array)
+            }
+            (DataType::Int16, DataType::Int16) => {
+                binary_op!(left_values, right_values, $OP, Int16Array)
+            }
+            (DataType::Int32, DataType::Int32) => {
+                binary_op!(left_values, right_values, $OP, Int32Array)
+            }
+            (DataType::Int64, DataType::Int64) => {
+                binary_op!(left_values, right_values, $OP, Int64Array)
+            }
+            (DataType::UInt8, DataType::UInt8) => {
+                binary_op!(left_values, right_values, $OP, UInt8Array)
+            }
+            (DataType::UInt16, DataType::UInt16) => {
+                binary_op!(left_values, right_values, $OP, UInt16Array)
+            }
+            (DataType::UInt32, DataType::UInt32) => {
+                binary_op!(left_values, right_values, $OP, UInt32Array)
+            }
+            (DataType::UInt64, DataType::UInt64) => {
+                binary_op!(left_values, right_values, $OP, UInt64Array)
+            }
+            (DataType::Float32, DataType::Float32) => {
+                binary_op!(left_values, right_values, $OP, Float32Array)
+            }
+            (DataType::Float64, DataType::Float64) => {
+                binary_op!(left_values, right_values, $OP, Float64Array)
+            }
+            _ => Err(ExecutionError::ExecutionError(format!("math_ops"))),
+        }
+    }};
+}
+
+macro_rules! comparison_ops {
+    ($LEFT:expr, $RIGHT:expr, $BATCH:expr, $OP:ident) => {{
+        let left_values = $LEFT.get_func()($BATCH)?;
+        let right_values = $RIGHT.get_func()($BATCH)?;
+        match (left_values.data_type(), right_values.data_type()) {
+            (DataType::Int8, DataType::Int8) => {
+                binary_op!(left_values, right_values, $OP, Int8Array)
+            }
+            (DataType::Int16, DataType::Int16) => {
+                binary_op!(left_values, right_values, $OP, Int16Array)
+            }
+            (DataType::Int32, DataType::Int32) => {
+                binary_op!(left_values, right_values, $OP, Int32Array)
+            }
+            (DataType::Int64, DataType::Int64) => {
+                binary_op!(left_values, right_values, $OP, Int64Array)
+            }
+            (DataType::UInt8, DataType::UInt8) => {
+                binary_op!(left_values, right_values, $OP, UInt8Array)
+            }
+            (DataType::UInt16, DataType::UInt16) => {
+                binary_op!(left_values, right_values, $OP, UInt16Array)
+            }
+            (DataType::UInt32, DataType::UInt32) => {
+                binary_op!(left_values, right_values, $OP, UInt32Array)
+            }
+            (DataType::UInt64, DataType::UInt64) => {
+                binary_op!(left_values, right_values, $OP, UInt64Array)
+            }
+            (DataType::Float32, DataType::Float32) => {
+                binary_op!(left_values, right_values, $OP, Float32Array)
+            }
+            (DataType::Float64, DataType::Float64) => {
+                binary_op!(left_values, right_values, $OP, Float64Array)
+            }
+            //TODO other types
+            _ => Err(ExecutionError::ExecutionError(format!("comparison_ops"))),
+        }
+    }};
+}
+
+macro_rules! boolean_ops {
+    ($LEFT:expr, $RIGHT:expr, $BATCH:expr, $OP:ident) => {{
+        let left_values = $LEFT.get_func()($BATCH)?;
+        let right_values = $RIGHT.get_func()($BATCH)?;
+        Ok(Arc::new(array_ops::$OP(
+            left_values.as_any().downcast_ref::<BooleanArray>().unwrap(),
+            right_values
+                .as_any()
+                .downcast_ref::<BooleanArray>()
+                .unwrap(),
+        )?))
+    }};
+}
+
+macro_rules! literal_array {
+    ($VALUE:expr, $ARRAY_TYPE:ident, $TY:ident) => {{
+        let nn = *$VALUE;
+        Ok(RuntimeExpr::Compiled {
+            name: format!("{}", nn),
+            f: Rc::new(move |batch: &RecordBatch| {
+                let capacity = batch.num_rows();
+                let mut builder = $ARRAY_TYPE::builder(capacity);
+                for _ in 0..capacity {
+                    builder.append_value(nn)?;
+                }
+                let array = builder.finish();
+                Ok(Arc::new(array) as ArrayRef)
+            }),
+            t: DataType::$TY,
+        })
+    }};
+}
+
+/// Casts a column to an array with a different data type
+macro_rules! cast_column {
+    ($INDEX:expr, $FROM_TYPE:ty, $TO_TYPE:ident, $DT:ty) => {{
+        Rc::new(move |batch: &RecordBatch| {
+            // get data and cast to known type
+            match batch.column($INDEX).as_any().downcast_ref::<$FROM_TYPE>() {
+                Some(array) => {
+                    // create builder for desired type
+                    let mut builder = $TO_TYPE::builder(batch.num_rows());
+                    for i in 0..batch.num_rows() {
+                        if array.is_null(i) {
+                            builder.append_null()?;
+                        } else {
+                            builder.append_value(array.value(i) as $DT)?;
+                        }
+                    }
+                    Ok(Arc::new(builder.finish()) as ArrayRef)
+                }
+                None => Err(ExecutionError::InternalError(format!(
+                    "Column at index {} is not of expected type",
+                    $INDEX
+                ))),
+            }
+        })
+    }};
+}
+
+macro_rules!
cast_column_outer { + ($INDEX:expr, $FROM_TYPE:ty, $TO_TYPE:expr) => {{ + match $TO_TYPE { + DataType::UInt8 => cast_column!($INDEX, $FROM_TYPE, UInt8Array, u8), + DataType::UInt16 => cast_column!($INDEX, $FROM_TYPE, UInt16Array, u16), + DataType::UInt32 => cast_column!($INDEX, $FROM_TYPE, UInt32Array, u32), + DataType::UInt64 => cast_column!($INDEX, $FROM_TYPE, UInt64Array, u64), + DataType::Int8 => cast_column!($INDEX, $FROM_TYPE, Int8Array, i8), + DataType::Int16 => cast_column!($INDEX, $FROM_TYPE, Int16Array, i16), + DataType::Int32 => cast_column!($INDEX, $FROM_TYPE, Int32Array, i32), + DataType::Int64 => cast_column!($INDEX, $FROM_TYPE, Int64Array, i64), + DataType::Float32 => cast_column!($INDEX, $FROM_TYPE, Float32Array, f32), + DataType::Float64 => cast_column!($INDEX, $FROM_TYPE, Float64Array, f64), + _ => unimplemented!(), + } + }}; +} + +/// Compiles a scalar expression into a closure +pub fn compile_scalar_expr( + ctx: &ExecutionContext, + expr: &Expr, + input_schema: &Schema, +) -> Result { + match expr { + &Expr::Literal(ref value) => match value { + //NOTE: this is a temporary hack .. due to the way expressions like 'a > 1' are + // evaluated, currently the left and right are evaluated separately and must result + // in arrays and then the '>' operator is evaluated against the two arrays. This works + // but is dumb ... I intend to optimize this soon to add special handling for + // binary expressions that involve literal values to avoid creating arrays of literals + // filed as https://github.com/andygrove/datafusion/issues/191 + ScalarValue::Int8(n) => literal_array!(n, Int8Array, Int8), + ScalarValue::Int16(n) => literal_array!(n, Int16Array, Int16), + ScalarValue::Int32(n) => literal_array!(n, Int32Array, Int32), + ScalarValue::Int64(n) => literal_array!(n, Int64Array, Int64), + ScalarValue::UInt8(n) => literal_array!(n, UInt8Array, UInt8), + ScalarValue::UInt16(n) => literal_array!(n, UInt16Array, UInt16), + ScalarValue::UInt32(n) => literal_array!(n, UInt32Array, UInt32), + ScalarValue::UInt64(n) => literal_array!(n, UInt64Array, UInt64), + ScalarValue::Float32(n) => literal_array!(n, Float32Array, Float32), + ScalarValue::Float64(n) => literal_array!(n, Float64Array, Float64), + other => Err(ExecutionError::ExecutionError(format!( + "No support for literal type {:?}", + other + ))), + }, + &Expr::Column(index) => Ok(RuntimeExpr::Compiled { + name: input_schema.field(index).name().clone(), + f: Rc::new(move |batch: &RecordBatch| Ok((*batch.column(index)).clone())), + t: input_schema.field(index).data_type().clone(), + }), + &Expr::Cast { + ref expr, + ref data_type, + } => match expr.as_ref() { + &Expr::Column(index) => { + let col = input_schema.field(index); + Ok(RuntimeExpr::Compiled { + name: col.name().clone(), + t: col.data_type().clone(), + f: match col.data_type() { + DataType::Int8 => { + cast_column_outer!(index, Int8Array, &data_type) + } + DataType::Int16 => { + cast_column_outer!(index, Int16Array, &data_type) + } + DataType::Int32 => { + cast_column_outer!(index, Int32Array, &data_type) + } + DataType::Int64 => { + cast_column_outer!(index, Int64Array, &data_type) + } + DataType::UInt8 => { + cast_column_outer!(index, UInt8Array, &data_type) + } + DataType::UInt16 => { + cast_column_outer!(index, UInt16Array, &data_type) + } + DataType::UInt32 => { + cast_column_outer!(index, UInt32Array, &data_type) + } + DataType::UInt64 => { + cast_column_outer!(index, UInt64Array, &data_type) + } + DataType::Float32 => { + cast_column_outer!(index, Float32Array, 
&data_type) + } + DataType::Float64 => { + cast_column_outer!(index, Float64Array, &data_type) + } + _ => panic!("unsupported CAST operation"), /*TODO */ + /*Err(ExecutionError::NotImplemented(format!( + "CAST column from {:?} to {:?}", + col.data_type(), + data_type + )))*/ + }, + }) + } + &Expr::Literal(ref value) => { + //NOTE this is all very inefficient and needs to be optimized - tracking + // issue is https://github.com/andygrove/datafusion/issues/191 + match value { + ScalarValue::Int64(n) => { + let nn = *n; + match data_type { + DataType::Float64 => Ok(RuntimeExpr::Compiled { + name: "lit".to_string(), + f: Rc::new(move |batch: &RecordBatch| { + let mut b = Float64Array::builder(batch.num_rows()); + for _ in 0..batch.num_rows() { + b.append_value(nn as f64)?; + } + Ok(Arc::new(b.finish()) as ArrayRef) + }), + t: data_type.clone(), + }), + other => Err(ExecutionError::NotImplemented(format!( + "CAST from Int64 to {:?}", + other + ))), + } + } + other => Err(ExecutionError::NotImplemented(format!( + "CAST from {:?} to {:?}", + other, data_type + ))), + } + } + other => Err(ExecutionError::General(format!( + "CAST not implemented for expression {:?}", + other + ))), + }, + &Expr::BinaryExpr { + ref left, + ref op, + ref right, + } => { + let left_expr = compile_scalar_expr(ctx, left, input_schema)?; + let right_expr = compile_scalar_expr(ctx, right, input_schema)?; + let name = format!("{:?} {:?} {:?}", left, op, right); + let op_type = left_expr.get_type().clone(); + match op { + &Operator::Eq => Ok(RuntimeExpr::Compiled { + name, + f: Rc::new(move |batch: &RecordBatch| { + comparison_ops!(left_expr, right_expr, batch, eq) + }), + t: DataType::Boolean, + }), + &Operator::NotEq => Ok(RuntimeExpr::Compiled { + name, + f: Rc::new(move |batch: &RecordBatch| { + comparison_ops!(left_expr, right_expr, batch, neq) + }), + t: DataType::Boolean, + }), + &Operator::Lt => Ok(RuntimeExpr::Compiled { + name, + f: Rc::new(move |batch: &RecordBatch| { + comparison_ops!(left_expr, right_expr, batch, lt) + }), + t: DataType::Boolean, + }), + &Operator::LtEq => Ok(RuntimeExpr::Compiled { + name, + f: Rc::new(move |batch: &RecordBatch| { + comparison_ops!(left_expr, right_expr, batch, lt_eq) + }), + t: DataType::Boolean, + }), + &Operator::Gt => Ok(RuntimeExpr::Compiled { + name, + f: Rc::new(move |batch: &RecordBatch| { + comparison_ops!(left_expr, right_expr, batch, gt) + }), + t: DataType::Boolean, + }), + &Operator::GtEq => Ok(RuntimeExpr::Compiled { + name, + f: Rc::new(move |batch: &RecordBatch| { + comparison_ops!(left_expr, right_expr, batch, gt_eq) + }), + t: DataType::Boolean, + }), + &Operator::And => Ok(RuntimeExpr::Compiled { + name, + f: Rc::new(move |batch: &RecordBatch| { + boolean_ops!(left_expr, right_expr, batch, and) + }), + t: DataType::Boolean, + }), + &Operator::Or => Ok(RuntimeExpr::Compiled { + name, + f: Rc::new(move |batch: &RecordBatch| { + boolean_ops!(left_expr, right_expr, batch, or) + }), + t: DataType::Boolean, + }), + &Operator::Plus => Ok(RuntimeExpr::Compiled { + name, + f: Rc::new(move |batch: &RecordBatch| { + math_ops!(left_expr, right_expr, batch, add) + }), + t: op_type, + }), + &Operator::Minus => Ok(RuntimeExpr::Compiled { + name, + f: Rc::new(move |batch: &RecordBatch| { + math_ops!(left_expr, right_expr, batch, subtract) + }), + t: op_type, + }), + &Operator::Multiply => Ok(RuntimeExpr::Compiled { + name, + f: Rc::new(move |batch: &RecordBatch| { + math_ops!(left_expr, right_expr, batch, multiply) + }), + t: op_type, + }), + &Operator::Divide => 
Ok(RuntimeExpr::Compiled {
+                    name,
+                    f: Rc::new(move |batch: &RecordBatch| {
+                        math_ops!(left_expr, right_expr, batch, divide)
+                    }),
+                    t: op_type,
+                }),
+                other => Err(ExecutionError::ExecutionError(format!(
+                    "operator: {:?}",
+                    other
+                ))),
+            }
+        }
+        other => Err(ExecutionError::ExecutionError(format!(
+            "expression {:?}",
+            other
+        ))),
+    }
+}
diff --git a/rust/datafusion/src/execution/filter.rs b/rust/datafusion/src/execution/filter.rs
new file mode 100644
index 0000000000000..ba20dca036e5b
--- /dev/null
+++ b/rust/datafusion/src/execution/filter.rs
@@ -0,0 +1,211 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Execution of a filter (predicate)
+
+use std::cell::RefCell;
+use std::rc::Rc;
+use std::sync::Arc;
+
+use arrow::array::*;
+use arrow::datatypes::{DataType, Schema};
+use arrow::record_batch::RecordBatch;
+
+use super::error::{ExecutionError, Result};
+use super::expression::RuntimeExpr;
+use super::relation::Relation;
+
+pub struct FilterRelation {
+    schema: Arc<Schema>,
+    input: Rc<RefCell<Relation>>,
+    expr: RuntimeExpr,
+}
+
+impl FilterRelation {
+    pub fn new(
+        input: Rc<RefCell<Relation>>,
+        expr: RuntimeExpr,
+        schema: Arc<Schema>,
+    ) -> Self {
+        Self {
+            schema,
+            input,
+            expr,
+        }
+    }
+}
+
+impl Relation for FilterRelation {
+    fn next(&mut self) -> Result<Option<RecordBatch>> {
+        match self.input.borrow_mut().next()? {
+            Some(batch) => {
+                // evaluate the filter expression against the batch
+                match self.expr.get_func()(&batch)?
+                    .as_any()
+                    .downcast_ref::<BooleanArray>()
+                {
+                    Some(filter_bools) => {
+                        let filtered_columns: Result<Vec<ArrayRef>> = (0..batch
+                            .num_columns())
+                            .map(|i| filter(batch.column(i), &filter_bools))
+                            .collect();
+
+                        let filtered_batch: RecordBatch = RecordBatch::new(
+                            Arc::new(Schema::empty()),
+                            filtered_columns?,
+                        );
+
+                        Ok(Some(filtered_batch))
+                    }
+                    _ => Err(ExecutionError::ExecutionError(
+                        "Filter expression did not evaluate to boolean".to_string(),
+                    )),
+                }
+            }
+            None => Ok(None),
+        }
+    }
+
+    fn schema(&self) -> &Arc<Schema> {
+        &self.schema
+    }
+}
+
+//TODO: move into Arrow array_ops
+fn filter(array: &Arc<Array>, filter: &BooleanArray) -> Result<ArrayRef> {
+    let a = array.as_ref();
+
+    //TODO use macros
+    match a.data_type() {
+        DataType::UInt8 => {
+            let b = a.as_any().downcast_ref::<UInt8Array>().unwrap();
+            let mut builder = UInt8Array::builder(b.len());
+            for i in 0..b.len() {
+                if filter.value(i) {
+                    builder.append_value(b.value(i))?;
+                }
+            }
+            Ok(Arc::new(builder.finish()))
+        }
+        DataType::UInt16 => {
+            let b = a.as_any().downcast_ref::<UInt16Array>().unwrap();
+            let mut builder = UInt16Array::builder(b.len());
+            for i in 0..b.len() {
+                if filter.value(i) {
+                    builder.append_value(b.value(i))?;
+                }
+            }
+            Ok(Arc::new(builder.finish()))
+        }
+        DataType::UInt32 => {
+            let b = a.as_any().downcast_ref::<UInt32Array>().unwrap();
+            let mut builder = UInt32Array::builder(b.len());
+            for i in 0..b.len() {
+                if filter.value(i) {
+                    builder.append_value(b.value(i))?;
+                }
+            }
+            Ok(Arc::new(builder.finish()))
+        }
+        DataType::UInt64 => {
+            let b = a.as_any().downcast_ref::<UInt64Array>().unwrap();
+            let mut builder = UInt64Array::builder(b.len());
+            for i in 0..b.len() {
+                if filter.value(i) {
+                    builder.append_value(b.value(i))?;
+                }
+            }
+            Ok(Arc::new(builder.finish()))
+        }
+        DataType::Int8 => {
+            let b = a.as_any().downcast_ref::<Int8Array>().unwrap();
+            let mut builder = Int8Array::builder(b.len());
+            for i in 0..b.len() {
+                if filter.value(i) {
+                    builder.append_value(b.value(i))?;
+                }
+            }
+            Ok(Arc::new(builder.finish()))
+        }
+        DataType::Int16 => {
+            let b = a.as_any().downcast_ref::<Int16Array>().unwrap();
+            let mut builder = Int16Array::builder(b.len());
+            for i in 0..b.len() {
+                if filter.value(i) {
+                    builder.append_value(b.value(i))?;
+                }
+            }
+            Ok(Arc::new(builder.finish()))
+        }
+        DataType::Int32 => {
+            let b = a.as_any().downcast_ref::<Int32Array>().unwrap();
+            let mut builder = Int32Array::builder(b.len());
+            for i in 0..b.len() {
+                if filter.value(i) {
+                    builder.append_value(b.value(i))?;
+                }
+            }
+            Ok(Arc::new(builder.finish()))
+        }
+        DataType::Int64 => {
+            let b = a.as_any().downcast_ref::<Int64Array>().unwrap();
+            let mut builder = Int64Array::builder(b.len());
+            for i in 0..b.len() {
+                if filter.value(i) {
+                    builder.append_value(b.value(i))?;
+                }
+            }
+            Ok(Arc::new(builder.finish()))
+        }
+        DataType::Float32 => {
+            let b = a.as_any().downcast_ref::<Float32Array>().unwrap();
+            let mut builder = Float32Array::builder(b.len());
+            for i in 0..b.len() {
+                if filter.value(i) {
+                    builder.append_value(b.value(i))?;
+                }
+            }
+            Ok(Arc::new(builder.finish()))
+        }
+        DataType::Float64 => {
+            let b = a.as_any().downcast_ref::<Float64Array>().unwrap();
+            let mut builder = Float64Array::builder(b.len());
+            for i in 0..b.len() {
+                if filter.value(i) {
+                    builder.append_value(b.value(i))?;
+                }
+            }
+            Ok(Arc::new(builder.finish()))
+        }
+        DataType::Utf8 => {
+            //TODO: this is inefficient and we should improve the Arrow impl to help make this more concise
+            let b = a.as_any().downcast_ref::<BinaryArray>().unwrap();
+            let mut values: Vec<String> = Vec::with_capacity(b.len());
+            for i in 0..b.len() {
+                if filter.value(i) {
+                    values.push(b.get_string(i));
+                }
+            }
+            let tmp: Vec<&str> =
values.iter().map(|s| s.as_str()).collect(); + Ok(Arc::new(BinaryArray::from(tmp))) + } + other => Err(ExecutionError::ExecutionError(format!( + "filter not supported for {:?}", + other + ))), + } +} diff --git a/rust/datafusion/src/execution/mod.rs b/rust/datafusion/src/execution/mod.rs new file mode 100644 index 0000000000000..23144bb5173ca --- /dev/null +++ b/rust/datafusion/src/execution/mod.rs @@ -0,0 +1,26 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod aggregate; +pub mod context; +pub mod datasource; +pub mod error; +pub mod expression; +pub mod filter; +pub mod physicalplan; +pub mod projection; +pub mod relation; diff --git a/rust/datafusion/src/execution/physicalplan.rs b/rust/datafusion/src/execution/physicalplan.rs new file mode 100644 index 0000000000000..23aa4312bfbf2 --- /dev/null +++ b/rust/datafusion/src/execution/physicalplan.rs @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::super::logicalplan::LogicalPlan; +use std::rc::Rc; + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub enum PhysicalPlan { + /// Run a query and return the results to the client + Interactive { + plan: Rc, + }, + /// Execute a logical plan and write the output to a file + Write { + plan: Rc, + filename: String, + kind: String, + }, + Show { + plan: Rc, + count: usize, + }, +} diff --git a/rust/datafusion/src/execution/projection.rs b/rust/datafusion/src/execution/projection.rs new file mode 100644 index 0000000000000..6de5818ab1983 --- /dev/null +++ b/rust/datafusion/src/execution/projection.rs @@ -0,0 +1,130 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Execution of a projection
+
+use std::cell::RefCell;
+use std::rc::Rc;
+use std::sync::Arc;
+
+use arrow::array::ArrayRef;
+use arrow::datatypes::{Field, Schema};
+use arrow::record_batch::RecordBatch;
+
+use super::error::Result;
+use super::expression::RuntimeExpr;
+use super::relation::Relation;
+
+pub struct ProjectRelation {
+    schema: Arc<Schema>,
+    input: Rc<RefCell<Relation>>,
+    expr: Vec<RuntimeExpr>,
+}
+
+impl ProjectRelation {
+    pub fn new(
+        input: Rc<RefCell<Relation>>,
+        expr: Vec<RuntimeExpr>,
+        schema: Arc<Schema>,
+    ) -> Self {
+        ProjectRelation {
+            input,
+            expr,
+            schema,
+        }
+    }
+}
+
+impl Relation for ProjectRelation {
+    fn next(&mut self) -> Result<Option<RecordBatch>> {
+        match self.input.borrow_mut().next()? {
+            Some(batch) => {
+                let projected_columns: Result<Vec<ArrayRef>> =
+                    self.expr.iter().map(|e| e.get_func()(&batch)).collect();
+
+                let schema = Schema::new(
+                    self.expr
+                        .iter()
+                        .map(|e| Field::new(&e.get_name(), e.get_type(), true))
+                        .collect(),
+                );
+
+                let projected_batch: RecordBatch =
+                    RecordBatch::new(Arc::new(schema), projected_columns?);
+
+                Ok(Some(projected_batch))
+            }
+            None => Ok(None),
+        }
+    }
+
+    fn schema(&self) -> &Arc<Schema> {
+        &self.schema
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::super::super::logicalplan::Expr;
+    use super::super::context::ExecutionContext;
+    use super::super::datasource::CsvDataSource;
+    use super::super::expression;
+    use super::super::relation::DataSourceRelation;
+    use super::*;
+    use arrow::datatypes::{DataType, Field, Schema};
+
+    #[test]
+    fn project_first_column() {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("c1", DataType::Utf8, false),
+            Field::new("c2", DataType::UInt32, false),
+            Field::new("c3", DataType::Int8, false),
+            Field::new("c4", DataType::Int16, false),
+            Field::new("c5", DataType::Int32, false),
+            Field::new("c6", DataType::Int64, false),
+            Field::new("c7", DataType::UInt8, false),
+            Field::new("c8", DataType::UInt16, false),
+            Field::new("c9", DataType::UInt32, false),
+            Field::new("c10", DataType::UInt64, false),
+            Field::new("c11", DataType::Float32, false),
+            Field::new("c12", DataType::Float64, false),
+            Field::new("c13", DataType::Utf8, false),
+        ]));
+        let ds = CsvDataSource::new(
+            "../../testing/data/csv/aggregate_test_100.csv",
+            schema.clone(),
+            1024,
+        );
+        let relation = Rc::new(RefCell::new(DataSourceRelation::new(Rc::new(
+            RefCell::new(ds),
+        ))));
+        let context = ExecutionContext::new();
+
+        let projection_expr =
+            vec![
+                expression::compile_expr(&context, &Expr::Column(0), schema.as_ref())
+                    .unwrap(),
+            ];
+
+        let mut projection = ProjectRelation::new(relation, projection_expr, schema);
+        let batch = projection.next().unwrap().unwrap();
+        assert_eq!(1, batch.num_columns());
+
+        assert_eq!("c1", batch.schema().field(0).name());
+    }
+
+}
diff --git a/rust/datafusion/src/execution/relation.rs b/rust/datafusion/src/execution/relation.rs
new file mode 100644
index 0000000000000..88eefabba67a8
--- /dev/null
+++ b/rust/datafusion/src/execution/relation.rs
@@ -0,0 +1,57 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::cell::RefCell;
+use std::rc::Rc;
+use std::sync::Arc;
+
+use arrow::datatypes::Schema;
+use arrow::record_batch::RecordBatch;
+
+use super::datasource::DataSource;
+use super::error::Result;
+
+/// trait for all relations (a relation is essentially just an iterator over rows with
+/// a known schema)
+pub trait Relation {
+    fn next(&mut self) -> Result<Option<RecordBatch>>;
+
+    /// get the schema for this relation
+    fn schema(&self) -> &Arc<Schema>;
+}
+
+pub struct DataSourceRelation {
+    schema: Arc<Schema>,
+    ds: Rc<RefCell<DataSource>>,
+}
+
+impl DataSourceRelation {
+    pub fn new(ds: Rc<RefCell<DataSource>>) -> Self {
+        let schema = ds.borrow().schema().clone();
+        Self { ds, schema }
+    }
+}
+
+impl Relation for DataSourceRelation {
+    fn next(&mut self) -> Result<Option<RecordBatch>> {
+        self.ds.borrow_mut().next()
+    }
+
+    fn schema(&self) -> &Arc<Schema> {
+        &self.schema
+    }
+}
diff --git a/rust/datafusion/src/lib.rs b/rust/datafusion/src/lib.rs
new file mode 100644
index 0000000000000..efcad5d886925
--- /dev/null
+++ b/rust/datafusion/src/lib.rs
@@ -0,0 +1,30 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! DataFusion is a modern distributed compute platform implemented in Rust that uses Apache Arrow
+//! as the memory model
+
+extern crate arrow;
+#[macro_use]
+extern crate serde_derive;
+extern crate serde_json;
+extern crate sqlparser;
+
+pub mod dfparser;
+pub mod execution;
+pub mod logicalplan;
+pub mod sqlplanner;
diff --git a/rust/datafusion/src/logicalplan.rs b/rust/datafusion/src/logicalplan.rs
new file mode 100644
index 0000000000000..b3e6bda545996
--- /dev/null
+++ b/rust/datafusion/src/logicalplan.rs
@@ -0,0 +1,650 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Logical query plan + +use std::fmt; +use std::fmt::{Error, Formatter}; +use std::rc::Rc; +use std::sync::Arc; + +use arrow::datatypes::*; + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub enum FunctionType { + Scalar, + Aggregate, +} + +#[derive(Debug, Clone)] +pub struct FunctionMeta { + name: String, + args: Vec, + return_type: DataType, + function_type: FunctionType, +} + +impl FunctionMeta { + pub fn new( + name: String, + args: Vec, + return_type: DataType, + function_type: FunctionType, + ) -> Self { + FunctionMeta { + name, + args, + return_type, + function_type, + } + } + pub fn name(&self) -> &String { + &self.name + } + pub fn args(&self) -> &Vec { + &self.args + } + pub fn return_type(&self) -> &DataType { + &self.return_type + } + pub fn function_type(&self) -> &FunctionType { + &self.function_type + } +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub enum Operator { + Eq, + NotEq, + Lt, + LtEq, + Gt, + GtEq, + Plus, + Minus, + Multiply, + Divide, + Modulus, + And, + Or, + Not, + Like, + NotLike, +} + +impl Operator { + /// Get the result type of applying this operation to its left and right inputs + pub fn get_datatype(&self, l: &Expr, _r: &Expr, schema: &Schema) -> DataType { + //TODO: implement correctly, just go with left side for now + l.get_type(schema).clone() + } +} + +/// ScalarValue enumeration +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +pub enum ScalarValue { + Null, + Boolean(bool), + Float32(f32), + Float64(f64), + Int8(i8), + Int16(i16), + Int32(i32), + Int64(i64), + UInt8(u8), + UInt16(u16), + UInt32(u32), + UInt64(u64), + Utf8(Rc), + Struct(Vec), +} + +impl ScalarValue { + pub fn get_datatype(&self) -> DataType { + match *self { + ScalarValue::Boolean(_) => DataType::Boolean, + ScalarValue::UInt8(_) => DataType::UInt8, + ScalarValue::UInt16(_) => DataType::UInt16, + ScalarValue::UInt32(_) => DataType::UInt32, + ScalarValue::UInt64(_) => DataType::UInt64, + ScalarValue::Int8(_) => DataType::Int8, + ScalarValue::Int16(_) => DataType::Int16, + ScalarValue::Int32(_) => DataType::Int32, + ScalarValue::Int64(_) => DataType::Int64, + ScalarValue::Float32(_) => DataType::Float32, + ScalarValue::Float64(_) => DataType::Float64, + ScalarValue::Utf8(_) => DataType::Utf8, + ScalarValue::Struct(_) => unimplemented!(), + ScalarValue::Null => unimplemented!(), + } + } +} + +/// Relation Expression +#[derive(Serialize, Deserialize, Clone, PartialEq)] +pub enum Expr { + /// index into a value within the row or complex value + Column(usize), + /// literal value + Literal(ScalarValue), + /// binary expression e.g. 
"age > 21" + BinaryExpr { + left: Rc, + op: Operator, + right: Rc, + }, + /// unary IS NOT NULL + IsNotNull(Rc), + /// unary IS NULL + IsNull(Rc), + /// cast a value to a different type + Cast { expr: Rc, data_type: DataType }, + /// sort expression + Sort { expr: Rc, asc: bool }, + /// scalar function + ScalarFunction { + name: String, + args: Vec, + return_type: DataType, + }, + /// aggregate function + AggregateFunction { + name: String, + args: Vec, + return_type: DataType, + }, +} + +impl Expr { + pub fn get_type(&self, schema: &Schema) -> DataType { + match self { + Expr::Column(n) => schema.field(*n).data_type().clone(), + Expr::Literal(l) => l.get_datatype(), + Expr::Cast { data_type, .. } => data_type.clone(), + Expr::ScalarFunction { return_type, .. } => return_type.clone(), + Expr::AggregateFunction { return_type, .. } => return_type.clone(), + Expr::IsNull(_) => DataType::Boolean, + Expr::IsNotNull(_) => DataType::Boolean, + Expr::BinaryExpr { + ref left, + ref right, + ref op, + } => { + match op { + Operator::Eq | Operator::NotEq => DataType::Boolean, + Operator::Lt | Operator::LtEq => DataType::Boolean, + Operator::Gt | Operator::GtEq => DataType::Boolean, + Operator::And | Operator::Or => DataType::Boolean, + _ => { + let left_type = left.get_type(schema); + let right_type = right.get_type(schema); + get_supertype(&left_type, &right_type).unwrap_or(DataType::Utf8) //TODO ??? + } + } + } + Expr::Sort { ref expr, .. } => expr.get_type(schema), + } + } + + pub fn cast_to( + &self, + cast_to_type: &DataType, + schema: &Schema, + ) -> Result { + let this_type = self.get_type(schema); + if this_type == *cast_to_type { + Ok(self.clone()) + } else if can_coerce_from(cast_to_type, &this_type) { + Ok(Expr::Cast { + expr: Rc::new(self.clone()), + data_type: cast_to_type.clone(), + }) + } else { + Err(format!( + "Cannot automatically convert {:?} to {:?}", + this_type, cast_to_type + )) + } + } + + pub fn eq(&self, other: &Expr) -> Expr { + Expr::BinaryExpr { + left: Rc::new(self.clone()), + op: Operator::Eq, + right: Rc::new(other.clone()), + } + } + + pub fn not_eq(&self, other: &Expr) -> Expr { + Expr::BinaryExpr { + left: Rc::new(self.clone()), + op: Operator::NotEq, + right: Rc::new(other.clone()), + } + } + + pub fn gt(&self, other: &Expr) -> Expr { + Expr::BinaryExpr { + left: Rc::new(self.clone()), + op: Operator::Gt, + right: Rc::new(other.clone()), + } + } + + pub fn gt_eq(&self, other: &Expr) -> Expr { + Expr::BinaryExpr { + left: Rc::new(self.clone()), + op: Operator::GtEq, + right: Rc::new(other.clone()), + } + } + + pub fn lt(&self, other: &Expr) -> Expr { + Expr::BinaryExpr { + left: Rc::new(self.clone()), + op: Operator::Lt, + right: Rc::new(other.clone()), + } + } + + pub fn lt_eq(&self, other: &Expr) -> Expr { + Expr::BinaryExpr { + left: Rc::new(self.clone()), + op: Operator::LtEq, + right: Rc::new(other.clone()), + } + } +} + +impl fmt::Debug for Expr { + fn fmt(&self, f: &mut Formatter) -> Result<(), Error> { + match self { + Expr::Column(i) => write!(f, "#{}", i), + Expr::Literal(v) => write!(f, "{:?}", v), + Expr::Cast { expr, data_type } => { + write!(f, "CAST({:?} AS {:?})", expr, data_type) + } + Expr::IsNull(expr) => write!(f, "{:?} IS NULL", expr), + Expr::IsNotNull(expr) => write!(f, "{:?} IS NOT NULL", expr), + Expr::BinaryExpr { left, op, right } => { + write!(f, "{:?} {:?} {:?}", left, op, right) + } + Expr::Sort { expr, asc } => { + if *asc { + write!(f, "{:?} ASC", expr) + } else { + write!(f, "{:?} DESC", expr) + } + } + Expr::ScalarFunction { name, 
ref args, .. } => { + write!(f, "{}(", name)?; + for i in 0..args.len() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{:?}", args[i])?; + } + + write!(f, ")") + } + Expr::AggregateFunction { name, ref args, .. } => { + write!(f, "{}(", name)?; + for i in 0..args.len() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{:?}", args[i])?; + } + + write!(f, ")") + } + } + } +} + +/// The LogicalPlan represents different types of relations (such as Projection, Selection, etc) and +/// can be created by the SQL query planner and the DataFrame API. +#[derive(Serialize, Deserialize, Clone)] +pub enum LogicalPlan { + /// A Projection (essentially a SELECT with an expression list) + Projection { + expr: Vec, + input: Rc, + schema: Arc, + }, + /// A Selection (essentially a WHERE clause with a predicate expression) + Selection { expr: Expr, input: Rc }, + /// Represents a list of aggregate expressions with optional grouping expressions + Aggregate { + input: Rc, + group_expr: Vec, + aggr_expr: Vec, + schema: Arc, + }, + /// Represents a list of sort expressions to be applied to a relation + Sort { + expr: Vec, + input: Rc, + schema: Arc, + }, + /// A table scan against a table that has been registered on a context + TableScan { + schema_name: String, + table_name: String, + schema: Arc, + projection: Option>, + }, + /// An empty relation with an empty schema + EmptyRelation { schema: Arc }, +} + +impl LogicalPlan { + /// Get a reference to the logical plan's schema + pub fn schema(&self) -> &Arc { + match self { + LogicalPlan::EmptyRelation { schema } => &schema, + LogicalPlan::TableScan { schema, .. } => &schema, + LogicalPlan::Projection { schema, .. } => &schema, + LogicalPlan::Selection { input, .. } => input.schema(), + LogicalPlan::Aggregate { schema, .. } => &schema, + LogicalPlan::Sort { schema, .. } => &schema, + } + } +} + +impl LogicalPlan { + fn fmt_with_indent(&self, f: &mut Formatter, indent: usize) -> Result<(), Error> { + if indent > 0 { + writeln!(f)?; + for _ in 0..indent { + write!(f, " ")?; + } + } + match *self { + LogicalPlan::EmptyRelation { .. } => write!(f, "EmptyRelation"), + LogicalPlan::TableScan { + ref table_name, + ref projection, + .. + } => write!(f, "TableScan: {} projection={:?}", table_name, projection), + LogicalPlan::Projection { + ref expr, + ref input, + .. + } => { + write!(f, "Projection: ")?; + for i in 0..expr.len() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{:?}", expr[i])?; + } + input.fmt_with_indent(f, indent + 1) + } + LogicalPlan::Selection { + ref expr, + ref input, + .. + } => { + write!(f, "Selection: {:?}", expr)?; + input.fmt_with_indent(f, indent + 1) + } + LogicalPlan::Aggregate { + ref input, + ref group_expr, + ref aggr_expr, + .. + } => { + write!( + f, + "Aggregate: groupBy=[{:?}], aggr=[{:?}]", + group_expr, aggr_expr + )?; + input.fmt_with_indent(f, indent + 1) + } + LogicalPlan::Sort { + ref input, + ref expr, + .. + } => { + write!(f, "Sort: ")?; + for i in 0..expr.len() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{:?}", expr[i])?; + } + input.fmt_with_indent(f, indent + 1) + } + } + } +} + +impl fmt::Debug for LogicalPlan { + fn fmt(&self, f: &mut Formatter) -> Result<(), Error> { + self.fmt_with_indent(f, 0) + } +} + +//TODO move to Arrow DataType impl? 
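+// Illustrative example (added note, not part of the original patch): get_supertype
+// below resolves the result type of a mixed-type binary expression by consulting
+// the lookup table in _get_supertype with both argument orders, e.g.
+//   get_supertype(&DataType::Int32, &DataType::UInt16) == Some(DataType::Int32)
+//   get_supertype(&DataType::Utf8, &DataType::Int32) == None (caller must reject)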
+pub fn get_supertype(l: &DataType, r: &DataType) -> Option { + match _get_supertype(l, r) { + Some(dt) => Some(dt), + None => match _get_supertype(r, l) { + Some(dt) => Some(dt), + None => None, + }, + } +} + +fn _get_supertype(l: &DataType, r: &DataType) -> Option { + use self::DataType::*; + match (l, r) { + (UInt8, Int8) => Some(Int8), + (UInt8, Int16) => Some(Int16), + (UInt8, Int32) => Some(Int32), + (UInt8, Int64) => Some(Int64), + + (UInt16, Int16) => Some(Int16), + (UInt16, Int32) => Some(Int32), + (UInt16, Int64) => Some(Int64), + + (UInt32, Int32) => Some(Int32), + (UInt32, Int64) => Some(Int64), + + (UInt64, Int64) => Some(Int64), + + (Int8, UInt8) => Some(Int8), + + (Int16, UInt8) => Some(Int16), + (Int16, UInt16) => Some(Int16), + + (Int32, UInt8) => Some(Int32), + (Int32, UInt16) => Some(Int32), + (Int32, UInt32) => Some(Int32), + + (Int64, UInt8) => Some(Int64), + (Int64, UInt16) => Some(Int64), + (Int64, UInt32) => Some(Int64), + (Int64, UInt64) => Some(Int64), + + (UInt8, UInt8) => Some(UInt8), + (UInt8, UInt16) => Some(UInt16), + (UInt8, UInt32) => Some(UInt32), + (UInt8, UInt64) => Some(UInt64), + (UInt8, Float32) => Some(Float32), + (UInt8, Float64) => Some(Float64), + + (UInt16, UInt8) => Some(UInt16), + (UInt16, UInt16) => Some(UInt16), + (UInt16, UInt32) => Some(UInt32), + (UInt16, UInt64) => Some(UInt64), + (UInt16, Float32) => Some(Float32), + (UInt16, Float64) => Some(Float64), + + (UInt32, UInt8) => Some(UInt32), + (UInt32, UInt16) => Some(UInt32), + (UInt32, UInt32) => Some(UInt32), + (UInt32, UInt64) => Some(UInt64), + (UInt32, Float32) => Some(Float32), + (UInt32, Float64) => Some(Float64), + + (UInt64, UInt8) => Some(UInt64), + (UInt64, UInt16) => Some(UInt64), + (UInt64, UInt32) => Some(UInt64), + (UInt64, UInt64) => Some(UInt64), + (UInt64, Float32) => Some(Float32), + (UInt64, Float64) => Some(Float64), + + (Int8, Int8) => Some(Int8), + (Int8, Int16) => Some(Int16), + (Int8, Int32) => Some(Int32), + (Int8, Int64) => Some(Int64), + (Int8, Float32) => Some(Float32), + (Int8, Float64) => Some(Float64), + + (Int16, Int8) => Some(Int16), + (Int16, Int16) => Some(Int16), + (Int16, Int32) => Some(Int32), + (Int16, Int64) => Some(Int64), + (Int16, Float32) => Some(Float32), + (Int16, Float64) => Some(Float64), + + (Int32, Int8) => Some(Int32), + (Int32, Int16) => Some(Int32), + (Int32, Int32) => Some(Int32), + (Int32, Int64) => Some(Int64), + (Int32, Float32) => Some(Float32), + (Int32, Float64) => Some(Float64), + + (Int64, Int8) => Some(Int64), + (Int64, Int16) => Some(Int64), + (Int64, Int32) => Some(Int64), + (Int64, Int64) => Some(Int64), + (Int64, Float32) => Some(Float32), + (Int64, Float64) => Some(Float64), + + (Float32, Float32) => Some(Float32), + (Float32, Float64) => Some(Float64), + (Float64, Float32) => Some(Float64), + (Float64, Float64) => Some(Float64), + + (Utf8, Utf8) => Some(Utf8), + + (Boolean, Boolean) => Some(Boolean), + + _ => None, + } +} + +pub fn can_coerce_from(left: &DataType, other: &DataType) -> bool { + use self::DataType::*; + match left { + Int8 => match other { + Int8 => true, + _ => false, + }, + Int16 => match other { + Int8 | Int16 => true, + _ => false, + }, + Int32 => match other { + Int8 | Int16 | Int32 => true, + _ => false, + }, + Int64 => match other { + Int8 | Int16 | Int32 | Int64 => true, + _ => false, + }, + UInt8 => match other { + UInt8 => true, + _ => false, + }, + UInt16 => match other { + UInt8 | UInt16 => true, + _ => false, + }, + UInt32 => match other { + UInt8 | UInt16 | UInt32 => true, + _ => false, + }, 
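+        // Illustrative note (added, not part of the original patch): coercion is
+        // widening-only, e.g. can_coerce_from(&Int64, &Int32) is true while
+        // can_coerce_from(&Int32, &Int64) is false, so Expr::cast_to() never
+        // inserts a narrowing CAST.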
+ UInt64 => match other { + UInt8 | UInt16 | UInt32 | UInt64 => true, + _ => false, + }, + Float32 => match other { + Int8 | Int16 | Int32 | Int64 => true, + UInt8 | UInt16 | UInt32 | UInt64 => true, + Float32 => true, + _ => false, + }, + Float64 => match other { + Int8 | Int16 | Int32 | Int64 => true, + UInt8 | UInt16 | UInt32 | UInt64 => true, + Float32 | Float64 => true, + _ => false, + }, + _ => false, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json; + + #[test] + fn serialize_plan() { + let schema = Schema::new(vec![ + Field::new("first_name", DataType::Utf8, false), + Field::new("last_name", DataType::Utf8, false), + Field::new( + "address", + DataType::Struct(vec![ + Field::new("street", DataType::Utf8, false), + Field::new("zip", DataType::UInt16, false), + ]), + false, + ), + ]); + + let plan = LogicalPlan::TableScan { + schema_name: "".to_string(), + table_name: "people".to_string(), + schema: Arc::new(schema), + projection: Some(vec![0, 1, 4]), + }; + + let serialized = serde_json::to_string(&plan).unwrap(); + + assert_eq!( + "{\"TableScan\":{\ + \"schema_name\":\"\",\ + \"table_name\":\"people\",\ + \"schema\":{\"fields\":[\ + {\"name\":\"first_name\",\"data_type\":\"Utf8\",\"nullable\":false},\ + {\"name\":\"last_name\",\"data_type\":\"Utf8\",\"nullable\":false},\ + {\"name\":\"address\",\"data_type\":{\"Struct\":\ + [\ + {\"name\":\"street\",\"data_type\":\"Utf8\",\"nullable\":false},\ + {\"name\":\"zip\",\"data_type\":\"UInt16\",\"nullable\":false}]},\"nullable\":false}\ + ]},\ + \"projection\":[0,1,4]}}", + serialized + ); + } +} diff --git a/rust/datafusion/src/sqlplanner.rs b/rust/datafusion/src/sqlplanner.rs new file mode 100644 index 0000000000000..dcb69ebc33f0a --- /dev/null +++ b/rust/datafusion/src/sqlplanner.rs @@ -0,0 +1,687 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! SQL Query Planner (produces logical plan from SQL AST) + +use std::collections::HashSet; +use std::rc::Rc; +use std::string::String; +use std::sync::Arc; + +use super::execution::error::*; +use super::logicalplan::*; + +use arrow::datatypes::*; + +use sqlparser::sqlast::*; + +pub trait SchemaProvider { + fn get_table_meta(&self, name: &str) -> Option>; + fn get_function_meta(&self, name: &str) -> Option>; +} + +/// SQL query planner +pub struct SqlToRel { + schema_provider: Rc, +} + +impl SqlToRel { + /// Create a new query planner + pub fn new(schema_provider: Rc) -> Self { + SqlToRel { schema_provider } + } + + /// Generate a logic plan from a SQL AST node + pub fn sql_to_rel(&self, sql: &ASTNode) -> Result> { + match sql { + &ASTNode::SQLSelect { + ref projection, + ref relation, + ref selection, + ref order_by, + ref group_by, + ref having, + .. 
+ } => { + // parse the input relation so we have access to the row type + let input = match relation { + &Some(ref r) => self.sql_to_rel(r)?, + &None => Rc::new(LogicalPlan::EmptyRelation { + schema: Arc::new(Schema::empty()), + }), + }; + + let input_schema = input.schema(); + + // selection first + let selection_plan = match selection { + &Some(ref filter_expr) => Some(LogicalPlan::Selection { + expr: self.sql_to_rex(&filter_expr, &input_schema.clone())?, + input: input.clone(), + }), + _ => None, + }; + + let expr: Vec = projection + .iter() + .map(|e| self.sql_to_rex(&e, &input_schema)) + .collect::>>()?; + + // collect aggregate expressions + let aggr_expr: Vec = expr + .iter() + .filter(|e| match e { + Expr::AggregateFunction { .. } => true, + _ => false, + }) + .map(|e| e.clone()) + .collect(); + + if aggr_expr.len() > 0 { + let aggregate_input: Rc = match selection_plan { + Some(s) => Rc::new(s), + _ => input.clone(), + }; + + let group_expr: Vec = match group_by { + Some(gbe) => gbe + .iter() + .map(|e| self.sql_to_rex(&e, &input_schema)) + .collect::>>()?, + None => vec![], + }; + //println!("GROUP BY: {:?}", group_expr); + + let mut all_fields: Vec = group_expr.clone(); + aggr_expr.iter().for_each(|x| all_fields.push(x.clone())); + + let aggr_schema = + Schema::new(exprlist_to_fields(&all_fields, input_schema)); + + //TODO: selection, projection, everything else + Ok(Rc::new(LogicalPlan::Aggregate { + input: aggregate_input, + group_expr, + aggr_expr, + schema: Arc::new(aggr_schema), + })) + } else { + let projection_input: Rc = match selection_plan { + Some(s) => Rc::new(s), + _ => input.clone(), + }; + + let projection_schema = Arc::new(Schema::new(exprlist_to_fields( + &expr, + input_schema.as_ref(), + ))); + + let projection = LogicalPlan::Projection { + expr: expr, + input: projection_input, + schema: projection_schema.clone(), + }; + + if let &Some(_) = having { + return Err(ExecutionError::General( + "HAVING is not implemented yet".to_string(), + )); + } + + let order_by_plan = match order_by { + &Some(ref order_by_expr) => { + let input_schema = projection.schema(); + let order_by_rex: Result> = order_by_expr + .iter() + .map(|e| { + Ok(Expr::Sort { + expr: Rc::new( + self.sql_to_rex(&e.expr, &input_schema) + .unwrap(), + ), + asc: e.asc, + }) + }) + .collect(); + + LogicalPlan::Sort { + expr: order_by_rex?, + input: Rc::new(projection.clone()), + schema: input_schema.clone(), + } + } + _ => projection, + }; + + Ok(Rc::new(order_by_plan)) + } + } + + &ASTNode::SQLIdentifier(ref id) => { + match self.schema_provider.get_table_meta(id.as_ref()) { + Some(schema) => Ok(Rc::new(LogicalPlan::TableScan { + schema_name: String::from("default"), + table_name: id.clone(), + schema: schema.clone(), + projection: None, + })), + None => Err(ExecutionError::General(format!( + "no schema found for table {}", + id + ))), + } + } + + _ => Err(ExecutionError::ExecutionError(format!( + "sql_to_rel does not support this relation: {:?}", + sql + ))), + } + } + + /// Generate a relational expression from a SQL expression + pub fn sql_to_rex(&self, sql: &ASTNode, schema: &Schema) -> Result { + match sql { + &ASTNode::SQLValue(sqlparser::sqlast::Value::Long(n)) => { + Ok(Expr::Literal(ScalarValue::Int64(n))) + } + &ASTNode::SQLValue(sqlparser::sqlast::Value::Double(n)) => { + Ok(Expr::Literal(ScalarValue::Float64(n))) + } + &ASTNode::SQLValue(sqlparser::sqlast::Value::SingleQuotedString(ref s)) => { + Ok(Expr::Literal(ScalarValue::Utf8(Rc::new(s.clone())))) + } + + &ASTNode::SQLIdentifier(ref 
id) => { + match schema.fields().iter().position(|c| c.name().eq(id)) { + Some(index) => Ok(Expr::Column(index)), + None => Err(ExecutionError::ExecutionError(format!( + "Invalid identifier '{}' for schema {}", + id, + schema.to_string() + ))), + } + } + + &ASTNode::SQLWildcard => { + // schema.columns().iter().enumerate() + // .map(|(i,c)| Ok(Expr::Column(i))).collect() + unimplemented!("SQL wildcard operator is not supported in projection - please use explicit column names") + } + + &ASTNode::SQLCast { + ref expr, + ref data_type, + } => Ok(Expr::Cast { + expr: Rc::new(self.sql_to_rex(&expr, schema)?), + data_type: convert_data_type(data_type)?, + }), + + &ASTNode::SQLIsNull(ref expr) => { + Ok(Expr::IsNull(Rc::new(self.sql_to_rex(expr, schema)?))) + } + + &ASTNode::SQLIsNotNull(ref expr) => { + Ok(Expr::IsNotNull(Rc::new(self.sql_to_rex(expr, schema)?))) + } + + &ASTNode::SQLBinaryExpr { + ref left, + ref op, + ref right, + } => { + let operator = match op { + &SQLOperator::Gt => Operator::Gt, + &SQLOperator::GtEq => Operator::GtEq, + &SQLOperator::Lt => Operator::Lt, + &SQLOperator::LtEq => Operator::LtEq, + &SQLOperator::Eq => Operator::Eq, + &SQLOperator::NotEq => Operator::NotEq, + &SQLOperator::Plus => Operator::Plus, + &SQLOperator::Minus => Operator::Minus, + &SQLOperator::Multiply => Operator::Multiply, + &SQLOperator::Divide => Operator::Divide, + &SQLOperator::Modulus => Operator::Modulus, + &SQLOperator::And => Operator::And, + &SQLOperator::Or => Operator::Or, + &SQLOperator::Not => Operator::Not, + &SQLOperator::Like => Operator::Like, + &SQLOperator::NotLike => Operator::NotLike, + }; + + let left_expr = self.sql_to_rex(&left, &schema)?; + let right_expr = self.sql_to_rex(&right, &schema)?; + let left_type = left_expr.get_type(schema); + let right_type = right_expr.get_type(schema); + + match get_supertype(&left_type, &right_type) { + Some(supertype) => Ok(Expr::BinaryExpr { + left: Rc::new(left_expr.cast_to(&supertype, schema)?), + op: operator, + right: Rc::new(right_expr.cast_to(&supertype, schema)?), + }), + None => { + return Err(ExecutionError::General(format!( + "No common supertype found for binary operator {:?} \ + with input types {:?} and {:?}", + operator, left_type, right_type + ))); + } + } + } + + // &ASTNode::SQLOrderBy { ref expr, asc } => Ok(Expr::Sort { + // expr: Rc::new(self.sql_to_rex(&expr, &schema)?), + // asc, + // }), + &ASTNode::SQLFunction { ref id, ref args } => { + //TODO: fix this hack + match id.to_lowercase().as_ref() { + "min" | "max" | "sum" | "avg" => { + let rex_args = args + .iter() + .map(|a| self.sql_to_rex(a, schema)) + .collect::>>()?; + + // return type is same as the argument type for these aggregate functions + let return_type = rex_args[0].get_type(schema).clone(); + + Ok(Expr::AggregateFunction { + name: id.clone(), + args: rex_args, + return_type, + }) + } + "count" => { + let rex_args = args + .iter() + .map(|a| match a { + // this feels hacky but translate COUNT(1)/COUNT(*) to COUNT(first_column) + ASTNode::SQLValue(sqlparser::sqlast::Value::Long(1)) => { + Ok(Expr::Column(0)) + } + ASTNode::SQLWildcard => Ok(Expr::Column(0)), + _ => self.sql_to_rex(a, schema), + }) + .collect::>>()?; + + Ok(Expr::AggregateFunction { + name: id.clone(), + args: rex_args, + return_type: DataType::UInt64, + }) + } + _ => match self.schema_provider.get_function_meta(id) { + Some(fm) => { + let rex_args = args + .iter() + .map(|a| self.sql_to_rex(a, schema)) + .collect::>>()?; + + let mut safe_args: Vec = vec![]; + for i in 0..rex_args.len() { 
+ safe_args.push( + rex_args[i] + .cast_to(fm.args()[i].data_type(), schema)?, + ); + } + + Ok(Expr::ScalarFunction { + name: id.clone(), + args: safe_args, + return_type: fm.return_type().clone(), + }) + } + _ => Err(ExecutionError::General(format!( + "Invalid function '{}'", + id + ))), + }, + } + } + + _ => Err(ExecutionError::General(format!( + "Unsupported ast node {:?} in sqltorel", + sql + ))), + } + } +} + +/// Convert SQL data type to relational representation of data type +pub fn convert_data_type(sql: &SQLType) -> Result { + match sql { + SQLType::Boolean => Ok(DataType::Boolean), + SQLType::SmallInt => Ok(DataType::Int16), + SQLType::Int => Ok(DataType::Int32), + SQLType::BigInt => Ok(DataType::Int64), + SQLType::Float(_) | SQLType::Real => Ok(DataType::Float64), + SQLType::Double => Ok(DataType::Float64), + SQLType::Char(_) | SQLType::Varchar(_) => Ok(DataType::Utf8), + other => Err(ExecutionError::NotImplemented(format!( + "Unsupported SQL type {:?}", + other + ))), + } +} + +pub fn expr_to_field(e: &Expr, input_schema: &Schema) -> Field { + match e { + Expr::Column(i) => input_schema.fields()[*i].clone(), + Expr::Literal(ref lit) => Field::new("lit", lit.get_datatype(), true), + Expr::ScalarFunction { + ref name, + ref return_type, + .. + } => Field::new(name, return_type.clone(), true), + Expr::AggregateFunction { + ref name, + ref return_type, + .. + } => Field::new(name, return_type.clone(), true), + Expr::Cast { ref data_type, .. } => Field::new("cast", data_type.clone(), true), + Expr::BinaryExpr { + ref left, + ref right, + .. + } => { + let left_type = left.get_type(input_schema); + let right_type = right.get_type(input_schema); + Field::new( + "binary_expr", + get_supertype(&left_type, &right_type).unwrap(), + true, + ) + } + _ => unimplemented!("Cannot determine schema type for expression {:?}", e), + } +} + +pub fn exprlist_to_fields(expr: &Vec, input_schema: &Schema) -> Vec { + expr.iter() + .map(|e| expr_to_field(e, input_schema)) + .collect() +} + +fn collect_expr(e: &Expr, accum: &mut HashSet) { + match e { + Expr::Column(i) => { + accum.insert(*i); + } + Expr::Cast { ref expr, .. } => collect_expr(expr, accum), + Expr::Literal(_) => {} + Expr::IsNotNull(ref expr) => collect_expr(expr, accum), + Expr::IsNull(ref expr) => collect_expr(expr, accum), + Expr::BinaryExpr { + ref left, + ref right, + .. + } => { + collect_expr(left, accum); + collect_expr(right, accum); + } + Expr::AggregateFunction { ref args, .. } => { + args.iter().for_each(|e| collect_expr(e, accum)); + } + Expr::ScalarFunction { ref args, .. } => { + args.iter().for_each(|e| collect_expr(e, accum)); + } + Expr::Sort { ref expr, .. 
} => collect_expr(expr, accum), + } +} + +pub fn push_down_projection( + plan: &Rc, + projection: &HashSet, +) -> Rc { + //println!("push_down_projection() projection={:?}", projection); + match plan.as_ref() { + LogicalPlan::Aggregate { + ref input, + ref group_expr, + ref aggr_expr, + ref schema, + } => { + //TODO: apply projection first + let mut accum: HashSet = HashSet::new(); + group_expr.iter().for_each(|e| collect_expr(e, &mut accum)); + aggr_expr.iter().for_each(|e| collect_expr(e, &mut accum)); + Rc::new(LogicalPlan::Aggregate { + input: push_down_projection(&input, &accum), + group_expr: group_expr.clone(), + aggr_expr: aggr_expr.clone(), + schema: schema.clone(), + }) + } + LogicalPlan::Selection { + ref expr, + ref input, + } => { + let mut accum: HashSet = projection.clone(); + collect_expr(expr, &mut accum); + Rc::new(LogicalPlan::Selection { + expr: expr.clone(), + input: push_down_projection(&input, &accum), + }) + } + LogicalPlan::TableScan { + ref schema_name, + ref table_name, + ref schema, + .. + } => Rc::new(LogicalPlan::TableScan { + schema_name: schema_name.to_string(), + table_name: table_name.to_string(), + schema: schema.clone(), + projection: Some(projection.iter().cloned().collect()), + }), + LogicalPlan::Projection { .. } => plan.clone(), + LogicalPlan::Sort { .. } => plan.clone(), + LogicalPlan::EmptyRelation { .. } => plan.clone(), + } +} + +#[cfg(test)] +mod tests { + + use super::*; + use sqlparser::sqlparser::*; + + #[test] + fn select_no_relation() { + quick_test( + "SELECT 1", + "Projection: Int64(1)\ + \n EmptyRelation", + ); + } + + #[test] + fn select_scalar_func_with_literal_no_relation() { + quick_test( + "SELECT sqrt(9)", + "Projection: sqrt(CAST(Int64(9) AS Float64))\ + \n EmptyRelation", + ); + } + + #[test] + fn select_simple_selection() { + let sql = "SELECT id, first_name, last_name \ + FROM person WHERE state = 'CO'"; + let expected = "Projection: #0, #1, #2\ + \n Selection: #4 Eq Utf8(\"CO\")\ + \n TableScan: person projection=None"; + quick_test(sql, expected); + } + + #[test] + fn select_compound_selection() { + let sql = "SELECT id, first_name, last_name \ + FROM person WHERE state = 'CO' AND age >= 21 AND age <= 65"; + let expected = + "Projection: #0, #1, #2\ + \n Selection: #4 Eq Utf8(\"CO\") And CAST(#3 AS Int64) GtEq Int64(21) And CAST(#3 AS Int64) LtEq Int64(65)\ + \n TableScan: person projection=None"; + quick_test(sql, expected); + } + + #[test] + fn select_all_boolean_operators() { + let sql = "SELECT age, first_name, last_name \ + FROM person \ + WHERE age = 21 \ + AND age != 21 \ + AND age > 21 \ + AND age >= 21 \ + AND age < 65 \ + AND age <= 65"; + let expected = "Projection: #3, #1, #2\ + \n Selection: CAST(#3 AS Int64) Eq Int64(21) \ + And CAST(#3 AS Int64) NotEq Int64(21) \ + And CAST(#3 AS Int64) Gt Int64(21) \ + And CAST(#3 AS Int64) GtEq Int64(21) \ + And CAST(#3 AS Int64) Lt Int64(65) \ + And CAST(#3 AS Int64) LtEq Int64(65)\ + \n TableScan: person projection=None"; + quick_test(sql, expected); + } + + #[test] + fn select_simple_aggregate() { + quick_test( + "SELECT MIN(age) FROM person", + "Aggregate: groupBy=[[]], aggr=[[MIN(#3)]]\ + \n TableScan: person projection=None", + ); + } + + #[test] + fn test_sum_aggregate() { + quick_test( + "SELECT SUM(age) from person", + "Aggregate: groupBy=[[]], aggr=[[SUM(#3)]]\ + \n TableScan: person projection=None", + ); + } + + #[test] + fn select_simple_aggregate_with_groupby() { + quick_test( + "SELECT state, MIN(age), MAX(age) FROM person GROUP BY state", + "Aggregate: 
+    #[test]
+    fn select_simple_aggregate_with_groupby() {
+        quick_test(
+            "SELECT state, MIN(age), MAX(age) FROM person GROUP BY state",
+            "Aggregate: groupBy=[[#4]], aggr=[[MIN(#3), MAX(#3)]]\
+             \n  TableScan: person projection=None",
+        );
+    }
+
+    #[test]
+    fn select_count_one() {
+        let sql = "SELECT COUNT(1) FROM person";
+        let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(#0)]]\
+                        \n  TableScan: person projection=None";
+        quick_test(sql, expected);
+    }
+
+    #[test]
+    fn select_scalar_func() {
+        let sql = "SELECT sqrt(age) FROM person";
+        let expected = "Projection: sqrt(CAST(#3 AS Float64))\
+                        \n  TableScan: person projection=None";
+        quick_test(sql, expected);
+    }
+
+    #[test]
+    fn select_order_by() {
+        let sql = "SELECT id FROM person ORDER BY id";
+        let expected = "Sort: #0 ASC\
+                        \n  Projection: #0\
+                        \n    TableScan: person projection=None";
+        quick_test(sql, expected);
+    }
+
+    #[test]
+    fn select_order_by_desc() {
+        let sql = "SELECT id FROM person ORDER BY id DESC";
+        let expected = "Sort: #0 DESC\
+                        \n  Projection: #0\
+                        \n    TableScan: person projection=None";
+        quick_test(sql, expected);
+    }
+
+    #[test]
+    fn test_collect_expr() {
+        let mut accum: HashSet<usize> = HashSet::new();
+        collect_expr(
+            &Expr::Cast {
+                expr: Rc::new(Expr::Column(3)),
+                data_type: DataType::Float64,
+            },
+            &mut accum,
+        );
+        collect_expr(
+            &Expr::Cast {
+                expr: Rc::new(Expr::Column(3)),
+                data_type: DataType::Float64,
+            },
+            &mut accum,
+        );
+        println!("accum: {:?}", accum);
+        assert_eq!(1, accum.len());
+        assert!(accum.contains(&3));
+    }
+
+    /// Create logical plan, write with formatter, compare to expected output
+    fn quick_test(sql: &str, expected: &str) {
+        use sqlparser::dialect::*;
+        let dialect = GenericSqlDialect {};
+        let planner = SqlToRel::new(Rc::new(MockSchemaProvider {}));
+        let ast = Parser::parse_sql(&dialect, sql.to_string()).unwrap();
+        let plan = planner.sql_to_rel(&ast).unwrap();
+        assert_eq!(expected, format!("{:?}", plan));
+    }
+
+    struct MockSchemaProvider {}
+
+    impl SchemaProvider for MockSchemaProvider {
+        fn get_table_meta(&self, name: &str) -> Option<Arc<Schema>> {
+            match name {
+                "person" => Some(Arc::new(Schema::new(vec![
+                    Field::new("id", DataType::UInt32, false),
+                    Field::new("first_name", DataType::Utf8, false),
+                    Field::new("last_name", DataType::Utf8, false),
+                    Field::new("age", DataType::Int32, false),
+                    Field::new("state", DataType::Utf8, false),
+                    Field::new("salary", DataType::Float64, false),
+                ]))),
+                _ => None,
+            }
+        }
+
+        fn get_function_meta(&self, name: &str) -> Option<Arc<FunctionMeta>> {
+            match name {
+                "sqrt" => Some(Arc::new(FunctionMeta::new(
+                    "sqrt".to_string(),
+                    vec![Field::new("n", DataType::Float64, false)],
+                    DataType::Float64,
+                    FunctionType::Scalar,
+                ))),
+                _ => None,
+            }
+        }
+    }
+
+}
diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs
new file mode 100644
index 0000000000000..bd228087dcc6e
--- /dev/null
+++ b/rust/datafusion/tests/sql.rs
@@ -0,0 +1,191 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
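+//! End-to-end SQL tests: each test registers a CSV file from the Arrow
+//! `testing` submodule as a table, executes a query against it through the
+//! `ExecutionContext`, and compares the tab-delimited output with an
+//! expected string.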
+
+use std::cell::RefCell;
+use std::rc::Rc;
+use std::sync::Arc;
+
+extern crate arrow;
+extern crate datafusion;
+
+use arrow::array::*;
+use arrow::datatypes::{DataType, Field, Schema};
+
+use datafusion::execution::context::ExecutionContext;
+use datafusion::execution::datasource::CsvDataSource;
+use datafusion::execution::relation::Relation;
+
+#[test]
+fn csv_query_with_predicate() {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv(&mut ctx);
+    let sql = "SELECT c1, c12 FROM aggregate_test_100 WHERE c12 > 0.376 AND c12 < 0.4";
+    let actual = execute(&mut ctx, sql);
+    let expected = "\"e\"\t0.39144436569161134\n\"d\"\t0.38870280983958583\n".to_string();
+    assert_eq!(expected, actual);
+}
+
+#[test]
+fn csv_query_group_by_int_min_max() {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv(&mut ctx);
+    //TODO add ORDER BY once supported, to make this test deterministic
+    let sql = "SELECT c2, MIN(c12), MAX(c12) FROM aggregate_test_100 GROUP BY c2";
+    let actual = execute(&mut ctx, sql);
+    let expected = "4\t0.02182578039211991\t0.9237877978193884\n2\t0.16301110515739792\t0.991517828651004\n5\t0.01479305307777301\t0.9723580396501548\n3\t0.047343434291126085\t0.9293883502480845\n1\t0.05636955101974106\t0.9965400387585364\n".to_string();
+    assert_eq!(expected, actual);
+}
+
+#[test]
+fn csv_query_group_by_string_min_max() {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv(&mut ctx);
+    //TODO add ORDER BY once supported, to make this test deterministic
+    let sql = "SELECT c1, MIN(c12), MAX(c12) FROM aggregate_test_100 GROUP BY c1";
+    let actual = execute(&mut ctx, sql);
+    let expected =
+        "\"d\"\t0.061029375346466685\t0.9748360509016578\n\"c\"\t0.0494924465469434\t0.991517828651004\n\"b\"\t0.04893135681998029\t0.9185813970744787\n\"a\"\t0.02182578039211991\t0.9800193410444061\n\"e\"\t0.01479305307777301\t0.9965400387585364\n".to_string();
+    assert_eq!(expected, actual);
+}
+
+#[test]
+fn csv_query_cast() {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv(&mut ctx);
+    let sql = "SELECT CAST(c12 AS float) FROM aggregate_test_100 WHERE c12 > 0.376 AND c12 < 0.4";
+    let actual = execute(&mut ctx, sql);
+    let expected = "0.39144436569161134\n0.38870280983958583\n".to_string();
+    assert_eq!(expected, actual);
+}
+
+fn aggr_test_schema() -> Arc<Schema> {
+    Arc::new(Schema::new(vec![
+        Field::new("c1", DataType::Utf8, false),
+        Field::new("c2", DataType::UInt32, false),
+        Field::new("c3", DataType::Int8, false),
+        Field::new("c4", DataType::Int16, false),
+        Field::new("c5", DataType::Int32, false),
+        Field::new("c6", DataType::Int64, false),
+        Field::new("c7", DataType::UInt8, false),
+        Field::new("c8", DataType::UInt16, false),
+        Field::new("c9", DataType::UInt32, false),
+        Field::new("c10", DataType::UInt64, false),
+        Field::new("c11", DataType::Float32, false),
+        Field::new("c12", DataType::Float64, false),
+        Field::new("c13", DataType::Utf8, false),
+    ]))
+}
+
+fn register_aggregate_csv(ctx: &mut ExecutionContext) {
+    let schema = aggr_test_schema();
+    register_csv(
+        ctx,
+        "aggregate_test_100",
+        "../../testing/data/csv/aggregate_test_100.csv",
+        &schema,
+    );
+}
+
+fn register_csv(
+    ctx: &mut ExecutionContext,
+    name: &str,
+    filename: &str,
+    schema: &Arc<Schema>,
+) {
+    let csv_datasource = CsvDataSource::new(filename, schema.clone(), 1024);
+    ctx.register_datasource(name, Rc::new(RefCell::new(csv_datasource)));
+}
+
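+// Illustrative example (not itself a test): given a context with the CSV file
+// registered, `execute(&mut ctx, "SELECT c1, c12 FROM aggregate_test_100")`
+// returns one line per row, with column values separated by tab characters.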
+/// Execute a query and return the result set as a tab-delimited string
+fn execute(ctx: &mut ExecutionContext, sql: &str) -> String {
+    let results = ctx.sql(&sql).unwrap();
+    result_str(&results)
+}
+
+fn result_str(results: &Rc<RefCell<Relation>>) -> String {
+    let mut relation = results.borrow_mut();
+    let mut str = String::new();
+    while let Some(batch) = relation.next().unwrap() {
+        for row_index in 0..batch.num_rows() {
+            for column_index in 0..batch.num_columns() {
+                if column_index > 0 {
+                    str.push_str("\t");
+                }
+                let column = batch.column(column_index);
+
+                match column.data_type() {
+                    DataType::Int8 => {
+                        let array = column.as_any().downcast_ref::<Int8Array>().unwrap();
+                        str.push_str(&format!("{:?}", array.value(row_index)));
+                    }
+                    DataType::Int16 => {
+                        let array = column.as_any().downcast_ref::<Int16Array>().unwrap();
+                        str.push_str(&format!("{:?}", array.value(row_index)));
+                    }
+                    DataType::Int32 => {
+                        let array = column.as_any().downcast_ref::<Int32Array>().unwrap();
+                        str.push_str(&format!("{:?}", array.value(row_index)));
+                    }
+                    DataType::Int64 => {
+                        let array = column.as_any().downcast_ref::<Int64Array>().unwrap();
+                        str.push_str(&format!("{:?}", array.value(row_index)));
+                    }
+                    DataType::UInt8 => {
+                        let array = column.as_any().downcast_ref::<UInt8Array>().unwrap();
+                        str.push_str(&format!("{:?}", array.value(row_index)));
+                    }
+                    DataType::UInt16 => {
+                        let array =
+                            column.as_any().downcast_ref::<UInt16Array>().unwrap();
+                        str.push_str(&format!("{:?}", array.value(row_index)));
+                    }
+                    DataType::UInt32 => {
+                        let array =
+                            column.as_any().downcast_ref::<UInt32Array>().unwrap();
+                        str.push_str(&format!("{:?}", array.value(row_index)));
+                    }
+                    DataType::UInt64 => {
+                        let array =
+                            column.as_any().downcast_ref::<UInt64Array>().unwrap();
+                        str.push_str(&format!("{:?}", array.value(row_index)));
+                    }
+                    DataType::Float32 => {
+                        let array =
+                            column.as_any().downcast_ref::<Float32Array>().unwrap();
+                        str.push_str(&format!("{:?}", array.value(row_index)));
+                    }
+                    DataType::Float64 => {
+                        let array =
+                            column.as_any().downcast_ref::<Float64Array>().unwrap();
+                        str.push_str(&format!("{:?}", array.value(row_index)));
+                    }
+                    DataType::Utf8 => {
+                        // Utf8 columns are backed by BinaryArray in this
+                        // version of Arrow
+                        let array =
+                            column.as_any().downcast_ref::<BinaryArray>().unwrap();
+                        let s =
+                            String::from_utf8(array.value(row_index).to_vec()).unwrap();
+
+                        str.push_str(&format!("{:?}", s));
+                    }
+                    _ => str.push_str("???"),
+                }
+            }
+            str.push_str("\n");
+        }
+    }
+    str
+}
diff --git a/testing b/testing
index 6ee39a9d17b09..bf0abe442bf7e 160000
--- a/testing
+++ b/testing
@@ -1 +1 @@
-Subproject commit 6ee39a9d17b0902b7372c22b9ff823304c69d709
+Subproject commit bf0abe442bf7e313380452c8972692940f4e56b6