From 9b4b47157144a8e5d006975ecd4a33a552ecf824 Mon Sep 17 00:00:00 2001 From: Renat Valiullin Date: Sun, 3 Feb 2019 20:08:59 -0600 Subject: [PATCH 01/21] ARROW-4454: [C++] fix unused parameter warnings Author: Renat Valiullin Closes #3544 from rip-nsk/ARROW-4454 and squashes the following commits: 984fc623f fix unused parameter warnings --- cpp/src/arrow/array/builder_primitive.h | 2 +- cpp/src/parquet/exception.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/array/builder_primitive.h b/cpp/src/arrow/array/builder_primitive.h index d17a13013ceae..bf3ec914bec5a 100644 --- a/cpp/src/arrow/array/builder_primitive.h +++ b/cpp/src/arrow/array/builder_primitive.h @@ -37,7 +37,7 @@ class ARROW_EXPORT NullBuilder : public ArrayBuilder { return Status::OK(); } - Status Append(std::nullptr_t value) { return AppendNull(); } + Status Append(std::nullptr_t) { return AppendNull(); } Status FinishInternal(std::shared_ptr* out) override; }; diff --git a/cpp/src/parquet/exception.h b/cpp/src/parquet/exception.h index 65e12af47a7c2..90f6d0302d311 100644 --- a/cpp/src/parquet/exception.h +++ b/cpp/src/parquet/exception.h @@ -77,7 +77,7 @@ class ParquetException : public std::exception { explicit ParquetException(const std::string& msg) : msg_(msg) {} - explicit ParquetException(const char* msg, std::exception& e) : msg_(msg) {} + explicit ParquetException(const char* msg, std::exception&) : msg_(msg) {} ~ParquetException() throw() override {} From 9bd5294277e516c0f7a9caaa9ff2f1346f7297b3 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 4 Feb 2019 16:03:31 +0100 Subject: [PATCH 02/21] ARROW-4459: [Testing] Add arrow-testing repo as submodule This PR adds the arrow-testing repo as a submodule in the directory `testing`. Author: Andy Grove Closes #3547 from andygrove/ARROW-4459 and squashes the following commits: 522ac269 Add arrow-testing submodule --- .gitmodules | 3 +++ testing | 1 + 2 files changed, 4 insertions(+) create mode 160000 testing diff --git a/.gitmodules b/.gitmodules index 71722b21777e6..6efc4871542cb 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,6 @@ [submodule "cpp/submodules/parquet-testing"] path = cpp/submodules/parquet-testing url = https://github.com/apache/parquet-testing.git +[submodule "testing"] + path = testing + url = https://github.com/apache/arrow-testing diff --git a/testing b/testing new file mode 160000 index 0000000000000..6ee39a9d17b09 --- /dev/null +++ b/testing @@ -0,0 +1 @@ +Subproject commit 6ee39a9d17b0902b7372c22b9ff823304c69d709 From 9f34a417a37f6d1f7061cf793292123d9749f9a9 Mon Sep 17 00:00:00 2001 From: Micah Kornfield Date: Mon, 4 Feb 2019 16:28:58 +0100 Subject: [PATCH 03/21] ARROW-4436: [Documentation] Update building.rst to reflect pyarrow req Make note of the fact that pyarrow has to be installed for sphinx to successfully build the documentation. Author: Micah Kornfield Closes #3539 from emkornfield/update_build_doc and squashes the following commits: 8a3e5418 Fix install instructions 9468abe5 convert to note --- docs/source/building.rst | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/docs/source/building.rst b/docs/source/building.rst index c6ff97424fcfb..2239a197e1fc8 100644 --- a/docs/source/building.rst +++ b/docs/source/building.rst @@ -59,7 +59,15 @@ These two steps are mandatory and must be executed in order. doxygen popd -#. Build the complete documentation using Sphinx +#. Build the complete documentation using Sphinx. + + .. 
note:: + + This step requires the the pyarrow library is installed + in your python environment. One way to accomplish + this is to follow the build instructions at :ref:`development` + and then run `python setup.py install` in arrow/python + (it is best to do this in a dedicated conda/virtual environment). .. code-block:: shell From 07ab9cf29ac65e119975d285e5d68be09e28562e Mon Sep 17 00:00:00 2001 From: Eric Erhardt Date: Mon, 4 Feb 2019 16:34:57 +0100 Subject: [PATCH 04/21] ARROW-4435: Minor fixups to csharp .sln and .csproj file 1. Add .sln file to the repo. 2. Minor clean up of .csproj to move properties out of the Configuration/Platform PropertyGroups, which is the standard way of setting properties in SDK style projects. Author: Eric Erhardt Closes #3529 from eerhardt/SolutionProjectFixUp and squashes the following commits: 88a7c6ef Add RAT exclude for csharp sln file e9036b38 Minor fixups to csharp .sln and .csproj file --- csharp/.gitignore | 5 +++- csharp/Apache.Arrow.sln | 31 +++++++++++++++++++++ csharp/src/Apache.Arrow/Apache.Arrow.csproj | 14 ++-------- dev/release/rat_exclude_files.txt | 1 + 4 files changed, 39 insertions(+), 12 deletions(-) create mode 100644 csharp/Apache.Arrow.sln diff --git a/csharp/.gitignore b/csharp/.gitignore index e5b411f791302..42835683969ae 100644 --- a/csharp/.gitignore +++ b/csharp/.gitignore @@ -261,4 +261,7 @@ __pycache__/ *.pyc # Project-specific -artifacts/ \ No newline at end of file +artifacts/ + +# add .sln files back because they are ignored by the root .gitignore file +!*.sln \ No newline at end of file diff --git a/csharp/Apache.Arrow.sln b/csharp/Apache.Arrow.sln new file mode 100644 index 0000000000000..53b463c27136a --- /dev/null +++ b/csharp/Apache.Arrow.sln @@ -0,0 +1,31 @@ + +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio 15 +VisualStudioVersion = 15.0.28307.357 +MinimumVisualStudioVersion = 10.0.40219.1 +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Apache.Arrow", "src\Apache.Arrow\Apache.Arrow.csproj", "{BA6B2B0D-EAAE-4183-8A39-1B9CF571F71F}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Apache.Arrow.Tests", "test\Apache.Arrow.Tests\Apache.Arrow.Tests.csproj", "{9CCEC01B-E67A-4726-BE72-7B514F76163F}" +EndProject +Global + GlobalSection(SolutionConfigurationPlatforms) = preSolution + Debug|Any CPU = Debug|Any CPU + Release|Any CPU = Release|Any CPU + EndGlobalSection + GlobalSection(ProjectConfigurationPlatforms) = postSolution + {BA6B2B0D-EAAE-4183-8A39-1B9CF571F71F}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {BA6B2B0D-EAAE-4183-8A39-1B9CF571F71F}.Debug|Any CPU.Build.0 = Debug|Any CPU + {BA6B2B0D-EAAE-4183-8A39-1B9CF571F71F}.Release|Any CPU.ActiveCfg = Release|Any CPU + {BA6B2B0D-EAAE-4183-8A39-1B9CF571F71F}.Release|Any CPU.Build.0 = Release|Any CPU + {9CCEC01B-E67A-4726-BE72-7B514F76163F}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {9CCEC01B-E67A-4726-BE72-7B514F76163F}.Debug|Any CPU.Build.0 = Debug|Any CPU + {9CCEC01B-E67A-4726-BE72-7B514F76163F}.Release|Any CPU.ActiveCfg = Release|Any CPU + {9CCEC01B-E67A-4726-BE72-7B514F76163F}.Release|Any CPU.Build.0 = Release|Any CPU + EndGlobalSection + GlobalSection(SolutionProperties) = preSolution + HideSolutionNode = FALSE + EndGlobalSection + GlobalSection(ExtensibilityGlobals) = postSolution + SolutionGuid = {FD0BB617-6031-4844-B99D-B331E335B572} + EndGlobalSection +EndGlobal diff --git a/csharp/src/Apache.Arrow/Apache.Arrow.csproj b/csharp/src/Apache.Arrow/Apache.Arrow.csproj index adc21c9edc07b..c2d73ec3b1709 100644 --- 
a/csharp/src/Apache.Arrow/Apache.Arrow.csproj +++ b/csharp/src/Apache.Arrow/Apache.Arrow.csproj @@ -1,9 +1,11 @@ - + netstandard1.3 + 7.2 + true Apache Apache Arrow library 2018 Apache Software Foundation @@ -15,16 +17,6 @@ 0.0.1 - - 7.2 - true - - - - true - 7.2 - - diff --git a/dev/release/rat_exclude_files.txt b/dev/release/rat_exclude_files.txt index 4866ec2aa3c30..6bd62c417c703 100644 --- a/dev/release/rat_exclude_files.txt +++ b/dev/release/rat_exclude_files.txt @@ -164,6 +164,7 @@ c_glib/gtk-doc.make csharp/.gitattributes csharp/src/Apache.Arrow/Flatbuf/* csharp/build/Common.props +csharp/Apache.Arrow.sln csharp/src/Apache.Arrow/Apache.Arrow.csproj csharp/src/Apache.Arrow/Properties/Resources.Designer.cs csharp/src/Apache.Arrow/Properties/Resources.resx From 7f96b6feb1d12d2bccba70b7c1aadc52ff4e337e Mon Sep 17 00:00:00 2001 From: Yuhong Guo Date: Mon, 4 Feb 2019 16:39:40 +0100 Subject: [PATCH 05/21] ARROW-3481: [Java] Fix java building failure with Maven 3.5.4 This problem happens in some rare case. When there are arrow lib build by old maven, but now current maven is updated to 3.5.4. This problem could happen and raise a building failure. Author: Yuhong Guo Closes #2738 from guoyuhong/fixJavaBuild and squashes the following commits: 74665109 Fix java build with Maven 3.5.4 --- java/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/java/pom.xml b/java/pom.xml index 9093bfa46d7db..57005b9622d64 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -368,7 +368,7 @@ maven-enforcer-plugin - 3.0.0-M1 + 3.0.0-M2 maven-surefire-plugin From 29f14cac4ed37afc139295565fe51533f6d2fde9 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 4 Feb 2019 12:12:13 -0600 Subject: [PATCH 06/21] ARROW-4263: [Rust] Donate DataFusion This PR is to donate the DataFusion source code (assuming that the vote passes!) 
Author: Andy Grove Closes #3399 from andygrove/ARROW-4263 and squashes the following commits: 990d06f09 formatting 6603091a3 update path again, update testing submodule 38fa63b22 remove test csv file, update tests to use test data from new testing submodule 16e4cffea remove test csv file, update tests to use test data from new testing submodule 91f6e90d4 update example to use new data file 4ebeee59f formatting ae88a90ba convert tests to use new test data file that was randomly generated d7bea8e19 update test to use uk_cities.csv and remove people.csv 061d788b5 remove unused test data files f60e50d9f remove unused test data files, manually recreate uk_cities.csv because I can't trace where the original data came from 28d914a2e Update 00-prepare.sh to handle datafusion versioning c4e1a2614 DataFusion Donation --- dev/release/00-prepare.sh | 6 +- rust/Cargo.toml | 1 + rust/datafusion/Cargo.toml | 50 + rust/datafusion/README.md | 94 ++ rust/datafusion/examples/csv_sql.rs | 101 ++ rust/datafusion/src/dfparser.rs | 220 +++ rust/datafusion/src/execution/aggregate.rs | 1214 +++++++++++++++++ rust/datafusion/src/execution/context.rs | 228 ++++ rust/datafusion/src/execution/datasource.rs | 74 + rust/datafusion/src/execution/error.rs | 69 + rust/datafusion/src/execution/expression.rs | 516 +++++++ rust/datafusion/src/execution/filter.rs | 211 +++ rust/datafusion/src/execution/mod.rs | 26 + rust/datafusion/src/execution/physicalplan.rs | 37 + rust/datafusion/src/execution/projection.rs | 130 ++ rust/datafusion/src/execution/relation.rs | 57 + rust/datafusion/src/lib.rs | 30 + rust/datafusion/src/logicalplan.rs | 650 +++++++++ rust/datafusion/src/sqlplanner.rs | 687 ++++++++++ rust/datafusion/tests/sql.rs | 191 +++ testing | 2 +- 21 files changed, 4590 insertions(+), 4 deletions(-) create mode 100644 rust/datafusion/Cargo.toml create mode 100644 rust/datafusion/README.md create mode 100644 rust/datafusion/examples/csv_sql.rs create mode 100644 rust/datafusion/src/dfparser.rs create mode 100644 rust/datafusion/src/execution/aggregate.rs create mode 100644 rust/datafusion/src/execution/context.rs create mode 100644 rust/datafusion/src/execution/datasource.rs create mode 100644 rust/datafusion/src/execution/error.rs create mode 100644 rust/datafusion/src/execution/expression.rs create mode 100644 rust/datafusion/src/execution/filter.rs create mode 100644 rust/datafusion/src/execution/mod.rs create mode 100644 rust/datafusion/src/execution/physicalplan.rs create mode 100644 rust/datafusion/src/execution/projection.rs create mode 100644 rust/datafusion/src/execution/relation.rs create mode 100644 rust/datafusion/src/lib.rs create mode 100644 rust/datafusion/src/logicalplan.rs create mode 100644 rust/datafusion/src/sqlplanner.rs create mode 100644 rust/datafusion/tests/sql.rs diff --git a/dev/release/00-prepare.sh b/dev/release/00-prepare.sh index bfcfc83825499..90d318145ca06 100755 --- a/dev/release/00-prepare.sh +++ b/dev/release/00-prepare.sh @@ -100,9 +100,9 @@ update_versions() { cd "${SOURCE_DIR}/../../rust" sed -i.bak -E -e \ "s/^version = \".+\"/version = \"${version}\"/g" \ - arrow/Cargo.toml parquet/Cargo.toml - rm -f arrow/Cargo.toml.bak parquet/Cargo.toml.bak - git add arrow/Cargo.toml parquet/Cargo.toml + arrow/Cargo.toml parquet/Cargo.toml datafusion/Cargo.toml + rm -f arrow/Cargo.toml.bak parquet/Cargo.toml.bak datafusion/Cargo.toml.bak + git add arrow/Cargo.toml parquet/Cargo.toml datafusion/Cargo.toml # Update version number for parquet README sed -i.bak -E -e \ diff --git a/rust/Cargo.toml 
b/rust/Cargo.toml index abfb71ada7951..c415ab7c8d7d1 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -19,4 +19,5 @@ members = [ "arrow", "parquet", + "datafusion", ] \ No newline at end of file diff --git a/rust/datafusion/Cargo.toml b/rust/datafusion/Cargo.toml new file mode 100644 index 0000000000000..864243bce4cc1 --- /dev/null +++ b/rust/datafusion/Cargo.toml @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "datafusion" +description = "DataFusion is an in-memory query engine that uses Apache Arrow as the memory model" +version = "0.13.0-SNAPSHOT" +homepage = "https://github.com/apache/arrow" +repository = "https://github.com/apache/arrow" +authors = ["Apache Arrow "] +license = "Apache-2.0" +keywords = [ "arrow", "query", "sql" ] +include = [ + "src/**/*.rs", + "Cargo.toml", +] +edition = "2018" + +[lib] +name = "datafusion" +path = "src/lib.rs" + +[dependencies] +clap = "2.31.2" +fnv = "1.0.3" +arrow = { path = "../arrow" } +parquet = { path = "../parquet" } +datafusion-rustyline = "2.0.0-alpha-20180628" +serde = { version = "1.0.80", features = ["alloc", "rc"] } +serde_derive = "1.0.80" +serde_json = "1.0.33" +sqlparser = "0.2.0" + +[dev-dependencies] +criterion = "0.2.0" + diff --git a/rust/datafusion/README.md b/rust/datafusion/README.md new file mode 100644 index 0000000000000..eade95e063536 --- /dev/null +++ b/rust/datafusion/README.md @@ -0,0 +1,94 @@ + + +# DataFusion + +DataFusion is an in-memory query engine that uses Apache Arrow as the memory model + +# Status + +The current code supports single-threaded execution of limited SQL queries (projection, selection, and aggregates) against CSV files. Parquet files will be supported shortly. + +Here is a brief example for running a SQL query against a CSV file. See the [examples](examples) directory for full examples. 
+ +```rust +fn main() { + // create local execution context + let mut ctx = ExecutionContext::new(); + + // define schema for data source (csv file) + let schema = Arc::new(Schema::new(vec![ + Field::new("city", DataType::Utf8, false), + Field::new("lat", DataType::Float64, false), + Field::new("lng", DataType::Float64, false), + ])); + + // register csv file with the execution context + let csv_datasource = CsvDataSource::new("../../testing/data/csv/uk_cities.csv", schema.clone(), 1024); + ctx.register_datasource("cities", Rc::new(RefCell::new(csv_datasource))); + + // simple projection and selection + let sql = "SELECT city, lat, lng FROM cities WHERE lat > 51.0 AND lat < 53"; + + // execute the query + let relation = ctx.sql(&sql).unwrap(); + + // display the relation + let mut results = relation.borrow_mut(); + + while let Some(batch) = results.next().unwrap() { + + println!( + "RecordBatch has {} rows and {} columns", + batch.num_rows(), + batch.num_columns() + ); + + let city = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + + let lat = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + + let lng = batch + .column(2) + .as_any() + .downcast_ref::() + .unwrap(); + + for i in 0..batch.num_rows() { + let city_name: String = String::from_utf8(city.get_value(i).to_vec()).unwrap(); + + println!( + "City: {}, Latitude: {}, Longitude: {}", + city_name, + lat.value(i), + lng.value(i), + ); + } + } +} +``` + diff --git a/rust/datafusion/examples/csv_sql.rs b/rust/datafusion/examples/csv_sql.rs new file mode 100644 index 0000000000000..40959cbb624f6 --- /dev/null +++ b/rust/datafusion/examples/csv_sql.rs @@ -0,0 +1,101 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::cell::RefCell; +use std::rc::Rc; +use std::sync::Arc; + +extern crate arrow; +extern crate datafusion; + +use arrow::array::{BinaryArray, Float64Array}; +use arrow::datatypes::{DataType, Field, Schema}; + +use datafusion::execution::context::ExecutionContext; +use datafusion::execution::datasource::CsvDataSource; + +/// This example demonstrates executing a simple query against an Arrow data source and fetching results +fn main() { + // create local execution context + let mut ctx = ExecutionContext::new(); + + // define schema for data source (csv file) + let schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Utf8, false), + Field::new("c2", DataType::UInt32, false), + Field::new("c3", DataType::Int8, false), + Field::new("c4", DataType::Int16, false), + Field::new("c5", DataType::Int32, false), + Field::new("c6", DataType::Int64, false), + Field::new("c7", DataType::UInt8, false), + Field::new("c8", DataType::UInt16, false), + Field::new("c9", DataType::UInt32, false), + Field::new("c10", DataType::UInt64, false), + Field::new("c11", DataType::Float32, false), + Field::new("c12", DataType::Float64, false), + Field::new("c13", DataType::Utf8, false), + ])); + + // register csv file with the execution context + let csv_datasource = CsvDataSource::new( + "../../testing/data/csv/aggregate_test_100.csv", + schema.clone(), + 1024, + ); + ctx.register_datasource("aggregate_test_100", Rc::new(RefCell::new(csv_datasource))); + + // simple projection and selection + let sql = "SELECT c1, MIN(c12), MAX(c12) FROM aggregate_test_100 WHERE c11 > 0.1 AND c11 < 0.9 GROUP BY c1"; + + // execute the query + let relation = ctx.sql(&sql).unwrap(); + + // display the relation + let mut results = relation.borrow_mut(); + + while let Some(batch) = results.next().unwrap() { + println!( + "RecordBatch has {} rows and {} columns", + batch.num_rows(), + batch.num_columns() + ); + + let c1 = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + + let min = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + + let max = batch + .column(2) + .as_any() + .downcast_ref::() + .unwrap(); + + for i in 0..batch.num_rows() { + let c1_value: String = String::from_utf8(c1.value(i).to_vec()).unwrap(); + + println!("{}, Min: {}, Max: {}", c1_value, min.value(i), max.value(i),); + } + } +} diff --git a/rust/datafusion/src/dfparser.rs b/rust/datafusion/src/dfparser.rs new file mode 100644 index 0000000000000..7abbd8dcf10bd --- /dev/null +++ b/rust/datafusion/src/dfparser.rs @@ -0,0 +1,220 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! SQL Parser +//! +//! Note that most SQL parsing is now delegated to the sqlparser crate, which handles ANSI SQL but +//! this module contains DataFusion-specific SQL extensions. 
+ +use sqlparser::dialect::*; +use sqlparser::sqlast::*; +use sqlparser::sqlparser::*; +use sqlparser::sqltokenizer::*; + +macro_rules! parser_err { + ($MSG:expr) => { + Err(ParserError::ParserError($MSG.to_string())) + }; +} + +#[derive(Debug, Clone)] +pub enum FileType { + NdJson, + Parquet, + CSV, +} + +#[derive(Debug, Clone)] +pub enum DFASTNode { + /// ANSI SQL AST node + ANSI(ASTNode), + /// DDL for creating an external table in DataFusion + CreateExternalTable { + /// Table name + name: String, + /// Optional schema + columns: Vec, + /// File type (Parquet, NDJSON, CSV) + file_type: FileType, + /// Header row? + header_row: bool, + /// Path to file + location: String, + }, +} + +/// SQL Parser +pub struct DFParser { + parser: Parser, +} + +impl DFParser { + /// Parse the specified tokens + pub fn new(sql: String) -> Result { + let dialect = GenericSqlDialect {}; + let mut tokenizer = Tokenizer::new(&dialect, &sql); + let tokens = tokenizer.tokenize()?; + Ok(DFParser { + parser: Parser::new(tokens), + }) + } + + /// Parse a SQL statement and produce an Abstract Syntax Tree (AST) + pub fn parse_sql(sql: String) -> Result { + let mut parser = DFParser::new(sql)?; + parser.parse() + } + + /// Parse a new expression + pub fn parse(&mut self) -> Result { + self.parse_expr(0) + } + + /// Parse tokens until the precedence changes + fn parse_expr(&mut self, precedence: u8) -> Result { + let mut expr = self.parse_prefix()?; + loop { + let next_precedence = self.parser.get_next_precedence()?; + if precedence >= next_precedence { + break; + } + + if let Some(infix_expr) = self.parse_infix(expr.clone(), next_precedence)? { + expr = infix_expr; + } + } + Ok(expr) + } + + /// Parse an expression prefix + fn parse_prefix(&mut self) -> Result { + if self + .parser + .parse_keywords(vec!["CREATE", "EXTERNAL", "TABLE"]) + { + match self.parser.next_token() { + Some(Token::Identifier(id)) => { + // parse optional column list (schema) + let mut columns = vec![]; + if self.parser.consume_token(&Token::LParen) { + loop { + if let Some(Token::Identifier(column_name)) = + self.parser.next_token() + { + if let Ok(data_type) = self.parser.parse_data_type() { + let allow_null = if self + .parser + .parse_keywords(vec!["NOT", "NULL"]) + { + false + } else if self.parser.parse_keyword("NULL") { + true + } else { + true + }; + + match self.parser.peek_token() { + Some(Token::Comma) => { + self.parser.next_token(); + columns.push(SQLColumnDef { + name: column_name, + data_type: data_type, + allow_null, + default: None, + is_primary: false, + is_unique: false, + }); + } + Some(Token::RParen) => break, + _ => { + return parser_err!( + "Expected ',' or ')' after column definition" + ); + } + } + } else { + return parser_err!( + "Error parsing data type in column definition" + ); + } + } else { + return parser_err!("Error parsing column name"); + } + } + } + + //println!("Parsed {} column defs", columns.len()); + + let mut headers = true; + let file_type: FileType = if self + .parser + .parse_keywords(vec!["STORED", "AS", "CSV"]) + { + if self.parser.parse_keywords(vec!["WITH", "HEADER", "ROW"]) { + headers = true; + } else if self + .parser + .parse_keywords(vec!["WITHOUT", "HEADER", "ROW"]) + { + headers = false; + } + FileType::CSV + } else if self.parser.parse_keywords(vec!["STORED", "AS", "NDJSON"]) { + FileType::NdJson + } else if self.parser.parse_keywords(vec!["STORED", "AS", "PARQUET"]) + { + FileType::Parquet + } else { + return parser_err!(format!( + "Expected 'STORED AS' clause, found {:?}", + 
self.parser.peek_token() + )); + }; + + let location: String = if self.parser.parse_keywords(vec!["LOCATION"]) + { + self.parser.parse_literal_string()? + } else { + return parser_err!("Missing 'LOCATION' clause"); + }; + + Ok(DFASTNode::CreateExternalTable { + name: id, + columns, + file_type, + header_row: headers, + location, + }) + } + _ => parser_err!(format!( + "Unexpected token after CREATE EXTERNAL TABLE: {:?}", + self.parser.peek_token() + )), + } + } else { + Ok(DFASTNode::ANSI(self.parser.parse_prefix()?)) + } + } + + pub fn parse_infix( + &mut self, + _expr: DFASTNode, + _precedence: u8, + ) -> Result, ParserError> { + unimplemented!() + } +} diff --git a/rust/datafusion/src/execution/aggregate.rs b/rust/datafusion/src/execution/aggregate.rs new file mode 100644 index 0000000000000..5acf5fb63a188 --- /dev/null +++ b/rust/datafusion/src/execution/aggregate.rs @@ -0,0 +1,1214 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Execution of a simple aggregate relation containing MIN, MAX, COUNT, SUM aggregate functions +//! 
with optional GROUP BY columns + +use std::cell::RefCell; +use std::rc::Rc; +use std::str; +use std::sync::Arc; + +use arrow::array::*; +use arrow::array_ops; +use arrow::builder::*; +use arrow::datatypes::{DataType, Schema}; +use arrow::record_batch::RecordBatch; + +use super::error::{ExecutionError, Result}; +use super::expression::{AggregateType, RuntimeExpr}; +use super::relation::Relation; +use crate::logicalplan::ScalarValue; + +use fnv::FnvHashMap; + +/// An aggregate relation is made up of zero or more grouping expressions and one +/// or more aggregate expressions +pub struct AggregateRelation { + schema: Arc, + input: Rc>, + group_expr: Vec, + aggr_expr: Vec, + end_of_results: bool, +} + +impl AggregateRelation { + pub fn new( + schema: Arc, + input: Rc>, + group_expr: Vec, + aggr_expr: Vec, + ) -> Self { + AggregateRelation { + schema, + input, + group_expr, + aggr_expr, + end_of_results: false, + } + } +} + +/// Enumeration of types that can be used in a GROUP BY expression (all primitives except for +/// floating point numerics) +#[derive(Debug, PartialEq, Eq, Hash, Clone)] +enum GroupByScalar { + UInt8(u8), + UInt16(u16), + UInt32(u32), + UInt64(u64), + Int8(i8), + Int16(i16), + Int32(i32), + Int64(i64), + Utf8(String), +} + +/// Common trait for all aggregation functions +trait AggregateFunction { + /// Get the function name (used for debugging) + fn name(&self) -> &str; + fn accumulate_scalar(&mut self, value: &Option); + fn result(&self) -> &Option; + fn data_type(&self) -> &DataType; +} + +#[derive(Debug)] +struct MinFunction { + data_type: DataType, + value: Option, +} + +impl MinFunction { + fn new(data_type: &DataType) -> Self { + Self { + data_type: data_type.clone(), + value: None, + } + } +} + +impl AggregateFunction for MinFunction { + fn name(&self) -> &str { + "min" + } + + fn accumulate_scalar(&mut self, value: &Option) { + if self.value.is_none() { + self.value = value.clone(); + } else if value.is_some() { + self.value = match (&self.value, value) { + (Some(ScalarValue::UInt8(a)), Some(ScalarValue::UInt8(b))) => { + Some(ScalarValue::UInt8(*a.min(b))) + } + (Some(ScalarValue::UInt16(a)), Some(ScalarValue::UInt16(b))) => { + Some(ScalarValue::UInt16(*a.min(b))) + } + (Some(ScalarValue::UInt32(a)), Some(ScalarValue::UInt32(b))) => { + Some(ScalarValue::UInt32(*a.min(b))) + } + (Some(ScalarValue::UInt64(a)), Some(ScalarValue::UInt64(b))) => { + Some(ScalarValue::UInt64(*a.min(b))) + } + (Some(ScalarValue::Int8(a)), Some(ScalarValue::Int8(b))) => { + Some(ScalarValue::Int8(*a.min(b))) + } + (Some(ScalarValue::Int16(a)), Some(ScalarValue::Int16(b))) => { + Some(ScalarValue::Int16(*a.min(b))) + } + (Some(ScalarValue::Int32(a)), Some(ScalarValue::Int32(b))) => { + Some(ScalarValue::Int32(*a.min(b))) + } + (Some(ScalarValue::Int64(a)), Some(ScalarValue::Int64(b))) => { + Some(ScalarValue::Int64(*a.min(b))) + } + (Some(ScalarValue::Float32(a)), Some(ScalarValue::Float32(b))) => { + Some(ScalarValue::Float32(a.min(*b))) + } + (Some(ScalarValue::Float64(a)), Some(ScalarValue::Float64(b))) => { + Some(ScalarValue::Float64(a.min(*b))) + } + _ => panic!("unsupported data type for MIN"), + } + } + } + + fn result(&self) -> &Option { + &self.value + } + + fn data_type(&self) -> &DataType { + &self.data_type + } +} + +#[derive(Debug)] +struct MaxFunction { + data_type: DataType, + value: Option, +} + +impl MaxFunction { + fn new(data_type: &DataType) -> Self { + Self { + data_type: data_type.clone(), + value: None, + } + } +} + +impl AggregateFunction for MaxFunction { + fn 
name(&self) -> &str { + "max" + } + + fn accumulate_scalar(&mut self, value: &Option) { + if self.value.is_none() { + self.value = value.clone(); + } else if value.is_some() { + self.value = match (&self.value, value) { + (Some(ScalarValue::UInt8(a)), Some(ScalarValue::UInt8(b))) => { + Some(ScalarValue::UInt8(*a.max(b))) + } + (Some(ScalarValue::UInt16(a)), Some(ScalarValue::UInt16(b))) => { + Some(ScalarValue::UInt16(*a.max(b))) + } + (Some(ScalarValue::UInt32(a)), Some(ScalarValue::UInt32(b))) => { + Some(ScalarValue::UInt32(*a.max(b))) + } + (Some(ScalarValue::UInt64(a)), Some(ScalarValue::UInt64(b))) => { + Some(ScalarValue::UInt64(*a.max(b))) + } + (Some(ScalarValue::Int8(a)), Some(ScalarValue::Int8(b))) => { + Some(ScalarValue::Int8(*a.max(b))) + } + (Some(ScalarValue::Int16(a)), Some(ScalarValue::Int16(b))) => { + Some(ScalarValue::Int16(*a.max(b))) + } + (Some(ScalarValue::Int32(a)), Some(ScalarValue::Int32(b))) => { + Some(ScalarValue::Int32(*a.max(b))) + } + (Some(ScalarValue::Int64(a)), Some(ScalarValue::Int64(b))) => { + Some(ScalarValue::Int64(*a.max(b))) + } + (Some(ScalarValue::Float32(a)), Some(ScalarValue::Float32(b))) => { + Some(ScalarValue::Float32(a.max(*b))) + } + (Some(ScalarValue::Float64(a)), Some(ScalarValue::Float64(b))) => { + Some(ScalarValue::Float64(a.max(*b))) + } + _ => panic!("unsupported data type for MAX"), + } + } + } + + fn result(&self) -> &Option { + &self.value + } + + fn data_type(&self) -> &DataType { + &self.data_type + } +} + +#[derive(Debug)] +struct SumFunction { + data_type: DataType, + value: Option, +} + +impl SumFunction { + fn new(data_type: &DataType) -> Self { + Self { + data_type: data_type.clone(), + value: None, + } + } +} + +impl AggregateFunction for SumFunction { + fn name(&self) -> &str { + "sum" + } + + fn accumulate_scalar(&mut self, value: &Option) { + if self.value.is_none() { + self.value = value.clone(); + } else if value.is_some() { + self.value = match (&self.value, value) { + (Some(ScalarValue::UInt8(a)), Some(ScalarValue::UInt8(b))) => { + Some(ScalarValue::UInt8(*a + b)) + } + (Some(ScalarValue::UInt16(a)), Some(ScalarValue::UInt16(b))) => { + Some(ScalarValue::UInt16(*a + b)) + } + (Some(ScalarValue::UInt32(a)), Some(ScalarValue::UInt32(b))) => { + Some(ScalarValue::UInt32(*a + b)) + } + (Some(ScalarValue::UInt64(a)), Some(ScalarValue::UInt64(b))) => { + Some(ScalarValue::UInt64(*a + b)) + } + (Some(ScalarValue::Int8(a)), Some(ScalarValue::Int8(b))) => { + Some(ScalarValue::Int8(*a + b)) + } + (Some(ScalarValue::Int16(a)), Some(ScalarValue::Int16(b))) => { + Some(ScalarValue::Int16(*a + b)) + } + (Some(ScalarValue::Int32(a)), Some(ScalarValue::Int32(b))) => { + Some(ScalarValue::Int32(*a + b)) + } + (Some(ScalarValue::Int64(a)), Some(ScalarValue::Int64(b))) => { + Some(ScalarValue::Int64(*a + b)) + } + (Some(ScalarValue::Float32(a)), Some(ScalarValue::Float32(b))) => { + Some(ScalarValue::Float32(a + *b)) + } + (Some(ScalarValue::Float64(a)), Some(ScalarValue::Float64(b))) => { + Some(ScalarValue::Float64(a + *b)) + } + _ => panic!("unsupported data type for SUM"), + } + } + } + + fn result(&self) -> &Option { + &self.value + } + + fn data_type(&self) -> &DataType { + &self.data_type + } +} + +struct AccumulatorSet { + aggr_values: Vec>>, +} + +impl AccumulatorSet { + fn accumulate_scalar(&mut self, i: usize, value: Option) { + let mut accumulator = self.aggr_values[i].borrow_mut(); + accumulator.accumulate_scalar(&value); + } + + fn values(&self) -> Vec> { + self.aggr_values + .iter() + .map(|x| 
x.borrow().result().clone()) + .collect() + } +} + +#[derive(Debug)] +struct MapEntry { + k: Vec, + v: Vec>, +} + +/// Create an initial aggregate entry +fn create_accumulators(aggr_expr: &Vec) -> Result { + let aggr_values: Vec>> = aggr_expr + .iter() + .map(|e| match e { + RuntimeExpr::AggregateFunction { ref f, ref t, .. } => match f { + AggregateType::Min => Ok(Rc::new(RefCell::new(MinFunction::new(t))) + as Rc>), + AggregateType::Max => Ok(Rc::new(RefCell::new(MaxFunction::new(t))) + as Rc>), + AggregateType::Sum => Ok(Rc::new(RefCell::new(SumFunction::new(t))) + as Rc>), + _ => Err(ExecutionError::ExecutionError( + "unsupported aggregate function".to_string(), + )), + }, + _ => Err(ExecutionError::ExecutionError( + "invalid aggregate expression".to_string(), + )), + }) + .collect::>>>>()?; + + Ok(AccumulatorSet { aggr_values }) +} + +fn array_min(array: ArrayRef, dt: &DataType) -> Result> { + match dt { + DataType::UInt8 => { + match array_ops::min(array.as_any().downcast_ref::().unwrap()) { + Some(n) => Ok(Some(ScalarValue::UInt8(n))), + None => Ok(None), + } + } + DataType::UInt16 => { + match array_ops::min(array.as_any().downcast_ref::().unwrap()) { + Some(n) => Ok(Some(ScalarValue::UInt16(n))), + None => Ok(None), + } + } + DataType::UInt32 => { + match array_ops::min(array.as_any().downcast_ref::().unwrap()) { + Some(n) => Ok(Some(ScalarValue::UInt32(n))), + None => Ok(None), + } + } + DataType::UInt64 => { + match array_ops::min(array.as_any().downcast_ref::().unwrap()) { + Some(n) => Ok(Some(ScalarValue::UInt64(n))), + None => Ok(None), + } + } + DataType::Int8 => { + match array_ops::min(array.as_any().downcast_ref::().unwrap()) { + Some(n) => Ok(Some(ScalarValue::Int8(n))), + None => Ok(None), + } + } + DataType::Int16 => { + match array_ops::min(array.as_any().downcast_ref::().unwrap()) { + Some(n) => Ok(Some(ScalarValue::Int16(n))), + None => Ok(None), + } + } + DataType::Int32 => { + match array_ops::min(array.as_any().downcast_ref::().unwrap()) { + Some(n) => Ok(Some(ScalarValue::Int32(n))), + None => Ok(None), + } + } + DataType::Int64 => { + match array_ops::min(array.as_any().downcast_ref::().unwrap()) { + Some(n) => Ok(Some(ScalarValue::Int64(n))), + None => Ok(None), + } + } + DataType::Float32 => { + match array_ops::min(array.as_any().downcast_ref::().unwrap()) { + Some(n) => Ok(Some(ScalarValue::Float32(n))), + None => Ok(None), + } + } + DataType::Float64 => { + match array_ops::min(array.as_any().downcast_ref::().unwrap()) { + Some(n) => Ok(Some(ScalarValue::Float64(n))), + None => Ok(None), + } + } + _ => Err(ExecutionError::ExecutionError( + "Unsupported data type for MIN".to_string(), + )), + } +} + +fn array_max(array: ArrayRef, dt: &DataType) -> Result> { + match dt { + DataType::UInt8 => { + match array_ops::max(array.as_any().downcast_ref::().unwrap()) { + Some(n) => Ok(Some(ScalarValue::UInt8(n))), + None => Ok(None), + } + } + DataType::UInt16 => { + match array_ops::max(array.as_any().downcast_ref::().unwrap()) { + Some(n) => Ok(Some(ScalarValue::UInt16(n))), + None => Ok(None), + } + } + DataType::UInt32 => { + match array_ops::max(array.as_any().downcast_ref::().unwrap()) { + Some(n) => Ok(Some(ScalarValue::UInt32(n))), + None => Ok(None), + } + } + DataType::UInt64 => { + match array_ops::max(array.as_any().downcast_ref::().unwrap()) { + Some(n) => Ok(Some(ScalarValue::UInt64(n))), + None => Ok(None), + } + } + DataType::Int8 => { + match array_ops::max(array.as_any().downcast_ref::().unwrap()) { + Some(n) => Ok(Some(ScalarValue::Int8(n))), + None 
=> Ok(None), + } + } + DataType::Int16 => { + match array_ops::max(array.as_any().downcast_ref::().unwrap()) { + Some(n) => Ok(Some(ScalarValue::Int16(n))), + None => Ok(None), + } + } + DataType::Int32 => { + match array_ops::max(array.as_any().downcast_ref::().unwrap()) { + Some(n) => Ok(Some(ScalarValue::Int32(n))), + None => Ok(None), + } + } + DataType::Int64 => { + match array_ops::max(array.as_any().downcast_ref::().unwrap()) { + Some(n) => Ok(Some(ScalarValue::Int64(n))), + None => Ok(None), + } + } + DataType::Float32 => { + match array_ops::max(array.as_any().downcast_ref::().unwrap()) { + Some(n) => Ok(Some(ScalarValue::Float32(n))), + None => Ok(None), + } + } + DataType::Float64 => { + match array_ops::max(array.as_any().downcast_ref::().unwrap()) { + Some(n) => Ok(Some(ScalarValue::Float64(n))), + None => Ok(None), + } + } + _ => Err(ExecutionError::ExecutionError( + "Unsupported data type for MAX".to_string(), + )), + } +} + +fn array_sum(array: ArrayRef, dt: &DataType) -> Result> { + match dt { + DataType::UInt8 => { + match array_ops::sum(array.as_any().downcast_ref::().unwrap()) { + Some(n) => Ok(Some(ScalarValue::UInt8(n))), + None => Ok(None), + } + } + DataType::UInt16 => { + match array_ops::sum(array.as_any().downcast_ref::().unwrap()) { + Some(n) => Ok(Some(ScalarValue::UInt16(n))), + None => Ok(None), + } + } + DataType::UInt32 => { + match array_ops::sum(array.as_any().downcast_ref::().unwrap()) { + Some(n) => Ok(Some(ScalarValue::UInt32(n))), + None => Ok(None), + } + } + DataType::UInt64 => { + match array_ops::sum(array.as_any().downcast_ref::().unwrap()) { + Some(n) => Ok(Some(ScalarValue::UInt64(n))), + None => Ok(None), + } + } + DataType::Int8 => { + match array_ops::sum(array.as_any().downcast_ref::().unwrap()) { + Some(n) => Ok(Some(ScalarValue::Int8(n))), + None => Ok(None), + } + } + DataType::Int16 => { + match array_ops::sum(array.as_any().downcast_ref::().unwrap()) { + Some(n) => Ok(Some(ScalarValue::Int16(n))), + None => Ok(None), + } + } + DataType::Int32 => { + match array_ops::sum(array.as_any().downcast_ref::().unwrap()) { + Some(n) => Ok(Some(ScalarValue::Int32(n))), + None => Ok(None), + } + } + DataType::Int64 => { + match array_ops::sum(array.as_any().downcast_ref::().unwrap()) { + Some(n) => Ok(Some(ScalarValue::Int64(n))), + None => Ok(None), + } + } + DataType::Float32 => { + match array_ops::sum(array.as_any().downcast_ref::().unwrap()) { + Some(n) => Ok(Some(ScalarValue::Float32(n))), + None => Ok(None), + } + } + DataType::Float64 => { + match array_ops::sum(array.as_any().downcast_ref::().unwrap()) { + Some(n) => Ok(Some(ScalarValue::Float64(n))), + None => Ok(None), + } + } + _ => Err(ExecutionError::ExecutionError( + "Unsupported data type for SUM".to_string(), + )), + } +} + +fn update_accumulators( + batch: &RecordBatch, + row: usize, + accumulator_set: &mut AccumulatorSet, + aggr_expr: &Vec, +) { + // update the accumulators + for j in 0..accumulator_set.aggr_values.len() { + match &aggr_expr[j] { + RuntimeExpr::AggregateFunction { args, t, .. 
} => { + // evaluate argument to aggregate function + match args[0](&batch) { + Ok(array) => { + let value: Option = match t { + DataType::UInt8 => { + let z = + array.as_any().downcast_ref::().unwrap(); + Some(ScalarValue::UInt8(z.value(row))) + } + DataType::UInt16 => { + let z = + array.as_any().downcast_ref::().unwrap(); + Some(ScalarValue::UInt16(z.value(row))) + } + DataType::UInt32 => { + let z = + array.as_any().downcast_ref::().unwrap(); + Some(ScalarValue::UInt32(z.value(row))) + } + DataType::UInt64 => { + let z = + array.as_any().downcast_ref::().unwrap(); + Some(ScalarValue::UInt64(z.value(row))) + } + DataType::Int8 => { + let z = + array.as_any().downcast_ref::().unwrap(); + Some(ScalarValue::Int8(z.value(row))) + } + DataType::Int16 => { + let z = + array.as_any().downcast_ref::().unwrap(); + Some(ScalarValue::Int16(z.value(row))) + } + DataType::Int32 => { + let z = + array.as_any().downcast_ref::().unwrap(); + Some(ScalarValue::Int32(z.value(row))) + } + DataType::Int64 => { + let z = + array.as_any().downcast_ref::().unwrap(); + Some(ScalarValue::Int64(z.value(row))) + } + DataType::Float32 => { + let z = array + .as_any() + .downcast_ref::() + .unwrap(); + Some(ScalarValue::Float32(z.value(row))) + } + DataType::Float64 => { + let z = array + .as_any() + .downcast_ref::() + .unwrap(); + Some(ScalarValue::Float64(z.value(row))) + } + _ => panic!(), + }; + accumulator_set.accumulate_scalar(j, value); + } + _ => panic!(), + } + } + _ => panic!(), + } + } +} + +impl Relation for AggregateRelation { + fn next(&mut self) -> Result> { + if self.end_of_results { + Ok(None) + } else { + self.end_of_results = true; + if self.group_expr.is_empty() { + self.without_group_by() + } else { + self.with_group_by() + } + } + } + + fn schema(&self) -> &Arc { + &self.schema + } +} + +macro_rules! array_from_scalar { + ($BUILDER:ident, $TY:ident, $ACCUM:expr) => {{ + let mut b = $BUILDER::new(1); + let mut err = false; + match $ACCUM.result() { + Some(ScalarValue::$TY(n)) => { + b.append_value(*n)?; + } + None => { + b.append_null()?; + } + Some(_) => { + err = true; + } + }; + if err { + Err(ExecutionError::ExecutionError( + "unexpected type when creating array from scalar value".to_string(), + )) + } else { + Ok(Arc::new(b.finish()) as ArrayRef) + } + }}; +} + +/// Create array from `key` attribute in map entry (representing a grouping scalar value) +macro_rules! group_array_from_map_entries { + ($BUILDER:ident, $TY:ident, $ENTRIES:expr, $COL_INDEX:expr) => {{ + let mut builder = $BUILDER::new($ENTRIES.len()); + let mut err = false; + for j in 0..$ENTRIES.len() { + match $ENTRIES[j].k[$COL_INDEX] { + GroupByScalar::$TY(n) => builder.append_value(n).unwrap(), + _ => err = true, + } + } + if err { + Err(ExecutionError::ExecutionError( + "unexpected type when creating array from aggregate map".to_string(), + )) + } else { + Ok(Arc::new(builder.finish()) as ArrayRef) + } + }}; +} + +/// Create array from `value` attribute in map entry (representing an aggregate scalar value) +macro_rules! 
aggr_array_from_map_entries { + ($BUILDER:ident, $TY:ident, $ENTRIES:expr, $COL_INDEX:expr) => {{ + let mut builder = $BUILDER::new($ENTRIES.len()); + let mut err = false; + for j in 0..$ENTRIES.len() { + match $ENTRIES[j].v[$COL_INDEX] { + Some(ScalarValue::$TY(n)) => builder.append_value(n).unwrap(), + None => builder.append_null().unwrap(), + _ => err = true, + } + } + if err { + Err(ExecutionError::ExecutionError( + "unexpected type when creating array from aggregate map".to_string(), + )) + } else { + Ok(Arc::new(builder.finish()) as ArrayRef) + } + }}; +} + +impl AggregateRelation { + /// perform simple aggregate on entire columns without grouping logic + fn without_group_by(&mut self) -> Result> { + let aggr_expr_count = self.aggr_expr.len(); + let mut accumulator_set = create_accumulators(&self.aggr_expr)?; + + while let Some(batch) = self.input.borrow_mut().next()? { + for i in 0..aggr_expr_count { + match &self.aggr_expr[i] { + RuntimeExpr::AggregateFunction { f, args, t, .. } => { + // evaluate argument to aggregate function + match args[0](&batch) { + Ok(array) => match f { + AggregateType::Min => accumulator_set + .accumulate_scalar(i, array_min(array, &t)?), + AggregateType::Max => accumulator_set + .accumulate_scalar(i, array_max(array, &t)?), + AggregateType::Sum => accumulator_set + .accumulate_scalar(i, array_sum(array, &t)?), + _ => { + return Err(ExecutionError::NotImplemented( + "Unsupported aggregate function".to_string(), + )); + } + }, + Err(_) => { + return Err(ExecutionError::ExecutionError( + "Failed to evaluate argument to aggregate function" + .to_string(), + )); + } + } + } + _ => { + return Err(ExecutionError::General( + "Invalid aggregate expression".to_string(), + )); + } + } + } + } + + let mut result_columns: Vec = vec![]; + + for i in 0..aggr_expr_count { + let accum = accumulator_set.aggr_values[i].borrow(); + match accum.data_type() { + DataType::UInt8 => { + result_columns.push(array_from_scalar!(UInt8Builder, UInt8, accum)?) + } + DataType::UInt16 => { + result_columns.push(array_from_scalar!(UInt16Builder, UInt16, accum)?) + } + DataType::UInt32 => { + result_columns.push(array_from_scalar!(UInt32Builder, UInt32, accum)?) + } + DataType::UInt64 => { + result_columns.push(array_from_scalar!(UInt64Builder, UInt64, accum)?) + } + DataType::Int8 => { + result_columns.push(array_from_scalar!(Int8Builder, Int8, accum)?) + } + DataType::Int16 => { + result_columns.push(array_from_scalar!(Int16Builder, Int16, accum)?) + } + DataType::Int32 => { + result_columns.push(array_from_scalar!(Int32Builder, Int32, accum)?) + } + DataType::Int64 => { + result_columns.push(array_from_scalar!(Int64Builder, Int64, accum)?) + } + DataType::Float32 => result_columns.push(array_from_scalar!( + Float32Builder, + Float32, + accum + )?), + DataType::Float64 => result_columns.push(array_from_scalar!( + Float64Builder, + Float64, + accum + )?), + _ => return Err(ExecutionError::NotImplemented("tbd".to_string())), + } + } + + Ok(Some(RecordBatch::new(self.schema.clone(), result_columns))) + } + + fn with_group_by(&mut self) -> Result> { + //NOTE this whole method is currently very inefficient with too many per-row operations + // involving pattern matching and downcasting ... I'm sure this can be re-implemented in + // a much more efficient way that takes better advantage of Arrow + + // create map to store aggregate results + let mut map: FnvHashMap, Rc>> = + FnvHashMap::default(); + + while let Some(batch) = self.input.borrow_mut().next()? 
{ + // evaulate the group by expressions on this batch + let group_by_keys: Vec = self + .group_expr + .iter() + .map(|e| e.get_func()(&batch)) + .collect::>>()?; + + // iterate over each row in the batch + for row in 0..batch.num_rows() { + // create key + let key: Vec = group_by_keys + .iter() + .map(|col| match col.data_type() { + DataType::UInt8 => { + let array = + col.as_any().downcast_ref::().unwrap(); + Ok(GroupByScalar::UInt8(array.value(row))) + } + DataType::UInt16 => { + let array = + col.as_any().downcast_ref::().unwrap(); + Ok(GroupByScalar::UInt16(array.value(row))) + } + DataType::UInt32 => { + let array = + col.as_any().downcast_ref::().unwrap(); + Ok(GroupByScalar::UInt32(array.value(row))) + } + DataType::UInt64 => { + let array = + col.as_any().downcast_ref::().unwrap(); + Ok(GroupByScalar::UInt64(array.value(row))) + } + DataType::Int8 => { + let array = col.as_any().downcast_ref::().unwrap(); + Ok(GroupByScalar::Int8(array.value(row))) + } + DataType::Int16 => { + let array = + col.as_any().downcast_ref::().unwrap(); + Ok(GroupByScalar::Int16(array.value(row))) + } + DataType::Int32 => { + let array = + col.as_any().downcast_ref::().unwrap(); + Ok(GroupByScalar::Int32(array.value(row))) + } + DataType::Int64 => { + let array = + col.as_any().downcast_ref::().unwrap(); + Ok(GroupByScalar::Int64(array.value(row))) + } + DataType::Utf8 => { + let array = + col.as_any().downcast_ref::().unwrap(); + Ok(GroupByScalar::Utf8(String::from( + str::from_utf8(array.value(row)).unwrap(), + ))) + } + _ => Err(ExecutionError::ExecutionError( + "Unsupported GROUP BY data type".to_string(), + )), + }) + .collect::>>()?; + + //TODO: find more elegant way to write this instead of hacking around ownership issues + + let updated = match map.get(&key) { + Some(entry) => { + let mut accumulator_set = entry.borrow_mut(); + update_accumulators( + &batch, + row, + &mut accumulator_set, + &self.aggr_expr, + ); + true + } + None => false, + }; + + if !updated { + let accumulator_set = + Rc::new(RefCell::new(create_accumulators(&self.aggr_expr)?)); + { + let mut entry_mut = accumulator_set.borrow_mut(); + update_accumulators(&batch, row, &mut entry_mut, &self.aggr_expr); + } + map.insert(key.clone(), accumulator_set); + } + } + } + + // convert the map to a vec to make it easier to build arrays + let entries: Vec = map + .iter() + .map(|(k, v)| { + let x = v.borrow(); + MapEntry { + k: k.clone(), + v: x.values(), + } + }) + .collect(); + + // build the result arrays + let mut result_arrays: Vec = + Vec::with_capacity(self.group_expr.len() + self.aggr_expr.len()); + + // grouping values + for i in 0..self.group_expr.len() { + let array: Result = match self.group_expr[i].get_type() { + DataType::UInt8 => { + group_array_from_map_entries!(UInt8Builder, UInt8, entries, i) + } + DataType::UInt16 => { + group_array_from_map_entries!(UInt16Builder, UInt16, entries, i) + } + DataType::UInt32 => { + group_array_from_map_entries!(UInt32Builder, UInt32, entries, i) + } + DataType::UInt64 => { + group_array_from_map_entries!(UInt64Builder, UInt64, entries, i) + } + DataType::Int8 => { + group_array_from_map_entries!(Int8Builder, Int8, entries, i) + } + DataType::Int16 => { + group_array_from_map_entries!(Int16Builder, Int16, entries, i) + } + DataType::Int32 => { + group_array_from_map_entries!(Int32Builder, Int32, entries, i) + } + DataType::Int64 => { + group_array_from_map_entries!(Int64Builder, Int64, entries, i) + } + DataType::Utf8 => { + let mut builder = BinaryBuilder::new(1); + for j in 
0..entries.len() { + match &entries[j].k[i] { + GroupByScalar::Utf8(s) => builder.append_string(&s).unwrap(), + _ => {} + } + } + Ok(Arc::new(builder.finish()) as ArrayRef) + } + _ => Err(ExecutionError::ExecutionError( + "Unsupported group by expr".to_string(), + )), + }; + result_arrays.push(array?); + } + + // aggregate values + for i in 0..self.aggr_expr.len() { + let array = match self.aggr_expr[i].get_type() { + DataType::UInt8 => { + aggr_array_from_map_entries!(UInt8Builder, UInt8, entries, i) + } + DataType::UInt16 => { + aggr_array_from_map_entries!(UInt16Builder, UInt16, entries, i) + } + DataType::UInt32 => { + aggr_array_from_map_entries!(UInt32Builder, UInt32, entries, i) + } + DataType::UInt64 => { + aggr_array_from_map_entries!(UInt64Builder, UInt64, entries, i) + } + DataType::Int8 => { + group_array_from_map_entries!(Int8Builder, Int8, entries, i) + } + DataType::Int16 => { + aggr_array_from_map_entries!(Int16Builder, Int16, entries, i) + } + DataType::Int32 => { + aggr_array_from_map_entries!(Int32Builder, Int32, entries, i) + } + DataType::Int64 => { + aggr_array_from_map_entries!(Int64Builder, Int64, entries, i) + } + DataType::Float32 => { + aggr_array_from_map_entries!(Float32Builder, Float32, entries, i) + } + DataType::Float64 => { + aggr_array_from_map_entries!(Float64Builder, Float64, entries, i) + } + _ => Err(ExecutionError::ExecutionError( + "Unsupported aggregate expr".to_string(), + )), + }; + result_arrays.push(array?); + } + + Ok(Some(RecordBatch::new(self.schema.clone(), result_arrays))) + } +} + +#[cfg(test)] +mod tests { + use super::super::super::logicalplan::Expr; + use super::super::context::ExecutionContext; + use super::super::datasource::CsvDataSource; + use super::super::expression; + use super::super::relation::DataSourceRelation; + use super::*; + use arrow::datatypes::{DataType, Field, Schema}; + + #[test] + fn min_f64_group_by_string() { + let schema = aggr_test_schema(); + let relation = load_csv("../../testing/data/csv/aggregate_test_100.csv", &schema); + let context = ExecutionContext::new(); + + let aggr_expr = vec![expression::compile_expr( + &context, + &Expr::AggregateFunction { + name: String::from("min"), + args: vec![Expr::Column(11)], + return_type: DataType::Float64, + }, + &schema, + ) + .unwrap()]; + + let aggr_schema = Arc::new(Schema::new(vec![Field::new( + "min_lat", + DataType::Float64, + false, + )])); + + let mut projection = + AggregateRelation::new(aggr_schema, relation, vec![], aggr_expr); + let batch = projection.next().unwrap().unwrap(); + assert_eq!(1, batch.num_columns()); + let min_lat = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(0.01479305307777301, min_lat.value(0)); + } + + #[test] + fn max_f64_group_by_string() { + let schema = aggr_test_schema(); + let relation = load_csv("../../testing/data/csv/aggregate_test_100.csv", &schema); + let context = ExecutionContext::new(); + + let aggr_expr = vec![expression::compile_expr( + &context, + &Expr::AggregateFunction { + name: String::from("max"), + args: vec![Expr::Column(11)], + return_type: DataType::Float64, + }, + &schema, + ) + .unwrap()]; + + let aggr_schema = Arc::new(Schema::new(vec![Field::new( + "max_lat", + DataType::Float64, + false, + )])); + + let mut projection = + AggregateRelation::new(aggr_schema, relation, vec![], aggr_expr); + let batch = projection.next().unwrap().unwrap(); + assert_eq!(1, batch.num_columns()); + let max_lat = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + 
assert_eq!(0.9965400387585364, max_lat.value(0)); + } + + #[test] + fn test_min_max_sum_f64_group_by_uint32() { + let schema = aggr_test_schema(); + let relation = load_csv("../../testing/data/csv/aggregate_test_100.csv", &schema); + + let context = ExecutionContext::new(); + + let group_by_expr = + expression::compile_expr(&context, &Expr::Column(1), &schema).unwrap(); + + let min_expr = expression::compile_expr( + &context, + &Expr::AggregateFunction { + name: String::from("min"), + args: vec![Expr::Column(11)], + return_type: DataType::Float64, + }, + &schema, + ) + .unwrap(); + + let max_expr = expression::compile_expr( + &context, + &Expr::AggregateFunction { + name: String::from("max"), + args: vec![Expr::Column(11)], + return_type: DataType::Float64, + }, + &schema, + ) + .unwrap(); + + let sum_expr = expression::compile_expr( + &context, + &Expr::AggregateFunction { + name: String::from("sum"), + args: vec![Expr::Column(11)], + return_type: DataType::Float64, + }, + &schema, + ) + .unwrap(); + + let aggr_schema = Arc::new(Schema::new(vec![ + Field::new("c2", DataType::Int32, false), + Field::new("min", DataType::Float64, false), + Field::new("max", DataType::Float64, false), + Field::new("sum", DataType::Float64, false), + ])); + + let mut projection = AggregateRelation::new( + aggr_schema, + relation, + vec![group_by_expr], + vec![min_expr, max_expr, sum_expr], + ); + let batch = projection.next().unwrap().unwrap(); + assert_eq!(4, batch.num_columns()); + assert_eq!(5, batch.num_rows()); + + let a = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let min = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + let max = batch + .column(2) + .as_any() + .downcast_ref::() + .unwrap(); + let sum = batch + .column(3) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(4, a.value(0)); + assert_eq!(0.02182578039211991, min.value(0)); + assert_eq!(0.9237877978193884, max.value(0)); + assert_eq!(9.253864188402662, sum.value(0)); + + assert_eq!(2, a.value(1)); + assert_eq!(0.16301110515739792, min.value(1)); + assert_eq!(0.991517828651004, max.value(1)); + assert_eq!(14.400412325480858, sum.value(1)); + + assert_eq!(5, a.value(2)); + assert_eq!(0.01479305307777301, min.value(2)); + assert_eq!(0.9723580396501548, max.value(2)); + assert_eq!(6.037181692266781, sum.value(2)); + } + + fn aggr_test_schema() -> Arc { + Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Utf8, false), + Field::new("c2", DataType::UInt32, false), + Field::new("c3", DataType::Int8, false), + Field::new("c3", DataType::Int16, false), + Field::new("c4", DataType::Int32, false), + Field::new("c5", DataType::Int64, false), + Field::new("c6", DataType::UInt8, false), + Field::new("c7", DataType::UInt16, false), + Field::new("c8", DataType::UInt32, false), + Field::new("c9", DataType::UInt64, false), + Field::new("c10", DataType::Float32, false), + Field::new("c11", DataType::Float64, false), + Field::new("c12", DataType::Utf8, false), + ])) + } + + fn load_csv(filename: &str, schema: &Arc) -> Rc> { + let ds = CsvDataSource::new(filename, schema.clone(), 1024); + Rc::new(RefCell::new(DataSourceRelation::new(Rc::new( + RefCell::new(ds), + )))) + } + +} diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs new file mode 100644 index 0000000000000..86d7c99c21900 --- /dev/null +++ b/rust/datafusion/src/execution/context.rs @@ -0,0 +1,228 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license 
agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::cell::RefCell; +use std::collections::HashMap; +use std::rc::Rc; +use std::sync::Arc; + +use arrow::datatypes::{Field, Schema}; + +use super::super::dfparser::{DFASTNode, DFParser}; +use super::super::logicalplan::*; +use super::super::sqlplanner::{SchemaProvider, SqlToRel}; +use super::aggregate::AggregateRelation; +use super::datasource::DataSource; +use super::error::{ExecutionError, Result}; +use super::expression::*; +use super::filter::FilterRelation; +use super::projection::ProjectRelation; +use super::relation::{DataSourceRelation, Relation}; + +pub struct ExecutionContext { + datasources: Rc>>>>, +} + +impl ExecutionContext { + pub fn new() -> Self { + Self { + datasources: Rc::new(RefCell::new(HashMap::new())), + } + } + + pub fn sql(&mut self, sql: &str) -> Result>> { + let ast = DFParser::parse_sql(String::from(sql))?; + + match ast { + DFASTNode::ANSI(ansi) => { + let schema_provider: Rc = + Rc::new(ExecutionContextSchemaProvider { + datasources: self.datasources.clone(), + }); + + // create a query planner + let query_planner = SqlToRel::new(schema_provider); + + // plan the query (create a logical relational plan) + let plan = query_planner.sql_to_rel(&ansi)?; + //println!("Logical plan: {:?}", plan); + + let optimized_plan = plan; //push_down_projection(&plan, &HashSet::new()); + //println!("Optimized logical plan: {:?}", new_plan); + + let relation = self.execute(&optimized_plan)?; + + Ok(relation) + } + _ => unimplemented!(), + } + } + + pub fn register_datasource(&mut self, name: &str, ds: Rc>) { + self.datasources.borrow_mut().insert(name.to_string(), ds); + } + + pub fn execute(&mut self, plan: &LogicalPlan) -> Result>> { + println!("Logical plan: {:?}", plan); + + match *plan { + LogicalPlan::TableScan { ref table_name, .. } => { + match self.datasources.borrow().get(table_name) { + Some(ds) => { + //TODO: projection + Ok(Rc::new(RefCell::new(DataSourceRelation::new(ds.clone())))) + } + _ => Err(ExecutionError::General(format!( + "No table registered as '{}'", + table_name + ))), + } + } + LogicalPlan::Selection { + ref expr, + ref input, + } => { + let input_rel = self.execute(input)?; + let input_schema = input_rel.as_ref().borrow().schema().clone(); + let runtime_expr = compile_scalar_expr(&self, expr, &input_schema)?; + let rel = FilterRelation::new( + input_rel, + runtime_expr, /* .get_func().clone() */ + input_schema, + ); + Ok(Rc::new(RefCell::new(rel))) + } + LogicalPlan::Projection { + ref expr, + ref input, + .. 
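// A minimal usage sketch for the ExecutionContext above: register a CSV-backed
// table and plan/execute a SQL query against it. The file path, table name,
// column list, and query are illustrative; imports and error handling are
// elided.
let schema = Arc::new(Schema::new(vec![
    Field::new("c1", DataType::Utf8, false),
    Field::new("c2", DataType::UInt32, false),
]));
let csv = CsvDataSource::new("path/to/example.csv", schema.clone(), 1024);
let mut ctx = ExecutionContext::new();
ctx.register_datasource("t", Rc::new(RefCell::new(csv)));
let results = ctx.sql("SELECT c1 FROM t").unwrap();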
+ } => { + let input_rel = self.execute(input)?; + + let input_schema = input_rel.as_ref().borrow().schema().clone(); + + let project_columns: Vec = + exprlist_to_fields(&expr, &input_schema); + + let project_schema = Arc::new(Schema::new(project_columns)); + + let compiled_expr: Result> = expr + .iter() + .map(|e| compile_scalar_expr(&self, e, &input_schema)) + .collect(); + + let rel = ProjectRelation::new(input_rel, compiled_expr?, project_schema); + + Ok(Rc::new(RefCell::new(rel))) + } + LogicalPlan::Aggregate { + ref input, + ref group_expr, + ref aggr_expr, + .. + } => { + let input_rel = self.execute(&input)?; + + let input_schema = input_rel.as_ref().borrow().schema().clone(); + + let compiled_group_expr_result: Result> = group_expr + .iter() + .map(|e| compile_scalar_expr(&self, e, &input_schema)) + .collect(); + let compiled_group_expr = compiled_group_expr_result?; + + let compiled_aggr_expr_result: Result> = aggr_expr + .iter() + .map(|e| compile_expr(&self, e, &input_schema)) + .collect(); + let compiled_aggr_expr = compiled_aggr_expr_result?; + + let rel = AggregateRelation::new( + Arc::new(Schema::empty()), //(expr_to_field(&compiled_group_expr, &input_schema))), + input_rel, + compiled_group_expr, + compiled_aggr_expr, + ); + + Ok(Rc::new(RefCell::new(rel))) + } + + _ => unimplemented!(), + } + } +} + +#[derive(Debug, Clone)] +pub enum ExecutionResult { + Unit, + Count(usize), + Str(String), +} + +pub fn expr_to_field(e: &Expr, input_schema: &Schema) -> Field { + match e { + Expr::Column(i) => input_schema.fields()[*i].clone(), + Expr::Literal(ref lit) => Field::new("lit", lit.get_datatype(), true), + Expr::ScalarFunction { + ref name, + ref return_type, + .. + } => Field::new(&name, return_type.clone(), true), + Expr::AggregateFunction { + ref name, + ref return_type, + .. + } => Field::new(&name, return_type.clone(), true), + Expr::Cast { ref data_type, .. } => Field::new("cast", data_type.clone(), true), + Expr::BinaryExpr { + ref left, + ref right, + .. + } => { + let left_type = left.get_type(input_schema); + let right_type = right.get_type(input_schema); + Field::new( + "binary_expr", + get_supertype(&left_type, &right_type).unwrap(), + true, + ) + } + _ => unimplemented!("Cannot determine schema type for expression {:?}", e), + } +} + +pub fn exprlist_to_fields(expr: &Vec, input_schema: &Schema) -> Vec { + expr.iter() + .map(|e| expr_to_field(e, input_schema)) + .collect() +} + +struct ExecutionContextSchemaProvider { + datasources: Rc>>>>, +} +impl SchemaProvider for ExecutionContextSchemaProvider { + fn get_table_meta(&self, name: &str) -> Option> { + match self.datasources.borrow().get(name) { + Some(ds) => Some(ds.borrow().schema().clone()), + None => None, + } + } + + fn get_function_meta(&self, _name: &str) -> Option> { + unimplemented!() + } +} diff --git a/rust/datafusion/src/execution/datasource.rs b/rust/datafusion/src/execution/datasource.rs new file mode 100644 index 0000000000000..379632b95a37f --- /dev/null +++ b/rust/datafusion/src/execution/datasource.rs @@ -0,0 +1,74 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Data sources + +use std::fs::File; +use std::rc::Rc; +use std::sync::Arc; + +use arrow::csv; +use arrow::datatypes::Schema; +use arrow::record_batch::RecordBatch; + +use super::error::Result; + +pub trait DataSource { + fn schema(&self) -> &Arc; + fn next(&mut self) -> Result>; +} + +/// CSV data source +pub struct CsvDataSource { + schema: Arc, + reader: csv::Reader, +} + +impl CsvDataSource { + pub fn new(filename: &str, schema: Arc, batch_size: usize) -> Self { + let file = File::open(filename).unwrap(); + let reader = csv::Reader::new(file, schema.clone(), true, batch_size, None); + Self { schema, reader } + } +} + +impl DataSource for CsvDataSource { + fn schema(&self) -> &Arc { + &self.schema + } + + fn next(&mut self) -> Result> { + Ok(self.reader.next()?) + } +} + +#[derive(Serialize, Deserialize, Clone)] +pub enum DataSourceMeta { + /// Represents a CSV file with a provided schema + CsvFile { + filename: String, + schema: Rc, + has_header: bool, + projection: Option>, + }, + /// Represents a Parquet file that contains schema information + ParquetFile { + filename: String, + schema: Rc, + projection: Option>, + }, +} diff --git a/rust/datafusion/src/execution/error.rs b/rust/datafusion/src/execution/error.rs new file mode 100644 index 0000000000000..5b8d04d3dca34 --- /dev/null +++ b/rust/datafusion/src/execution/error.rs @@ -0,0 +1,69 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
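// The DataSource trait above only requires schema() and next(); a source that
// yields a single in-memory RecordBatch and then signals end-of-stream could
// look roughly like this (the type name is illustrative):
struct SingleBatchSource {
    schema: Arc<Schema>,
    batch: Option<RecordBatch>,
}

impl DataSource for SingleBatchSource {
    fn schema(&self) -> &Arc<Schema> {
        &self.schema
    }

    fn next(&mut self) -> Result<Option<RecordBatch>> {
        // returns Some(batch) on the first call and None afterwards
        Ok(self.batch.take())
    }
}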
Error types + +use std::io::Error; +use std::result; + +use arrow::error::ArrowError; + +use sqlparser::sqlparser::ParserError; + +pub type Result = result::Result; + +#[derive(Debug)] +pub enum ExecutionError { + IoError(Error), + ParserError(ParserError), + General(String), + InvalidColumn(String), + NotImplemented(String), + InternalError(String), + ArrowError(ArrowError), + ExecutionError(String), +} + +impl From for ExecutionError { + fn from(e: Error) -> Self { + ExecutionError::IoError(e) + } +} + +impl From for ExecutionError { + fn from(e: String) -> Self { + ExecutionError::General(e) + } +} + +impl From<&'static str> for ExecutionError { + fn from(e: &'static str) -> Self { + ExecutionError::General(e.to_string()) + } +} + +impl From for ExecutionError { + fn from(e: ArrowError) -> Self { + ExecutionError::ArrowError(e) + } +} + +impl From for ExecutionError { + fn from(e: ParserError) -> Self { + ExecutionError::ParserError(e) + } +} diff --git a/rust/datafusion/src/execution/expression.rs b/rust/datafusion/src/execution/expression.rs new file mode 100644 index 0000000000000..108a8558b9adf --- /dev/null +++ b/rust/datafusion/src/execution/expression.rs @@ -0,0 +1,516 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::rc::Rc; +use std::sync::Arc; + +use arrow::array::*; +use arrow::array_ops; +use arrow::datatypes::{DataType, Schema}; +use arrow::record_batch::RecordBatch; + +use super::super::logicalplan::{Expr, Operator, ScalarValue}; +use super::context::ExecutionContext; +use super::error::{ExecutionError, Result}; + +/// Compiled Expression (basically just a closure to evaluate the expression at runtime) +pub type CompiledExpr = Rc Result>; + +pub type CompiledCastFunction = Rc Result>; + +pub enum AggregateType { + Min, + Max, + Sum, + Count, + CountDistinct, + Avg, +} + +/// Runtime expression +pub enum RuntimeExpr { + Compiled { + name: String, + f: CompiledExpr, + t: DataType, + }, + AggregateFunction { + name: String, + f: AggregateType, + args: Vec, + t: DataType, + }, +} + +impl RuntimeExpr { + pub fn get_func(&self) -> CompiledExpr { + match self { + &RuntimeExpr::Compiled { ref f, .. } => f.clone(), + _ => panic!(), + } + } + + pub fn get_name(&self) -> &String { + match self { + &RuntimeExpr::Compiled { ref name, .. } => name, + &RuntimeExpr::AggregateFunction { ref name, .. } => name, + } + } + + pub fn get_type(&self) -> DataType { + match self { + &RuntimeExpr::Compiled { ref t, .. } => t.clone(), + &RuntimeExpr::AggregateFunction { ref t, .. 
} => t.clone(), + } + } +} + +/// Compiles a scalar expression into a closure +pub fn compile_expr( + ctx: &ExecutionContext, + expr: &Expr, + input_schema: &Schema, +) -> Result { + match *expr { + Expr::AggregateFunction { + ref name, + ref args, + ref return_type, + } => { + assert_eq!(1, args.len()); + + let compiled_args: Result> = args + .iter() + .map(|e| compile_scalar_expr(&ctx, e, input_schema)) + .collect(); + + let func = match name.to_lowercase().as_ref() { + "min" => Ok(AggregateType::Min), + "max" => Ok(AggregateType::Max), + "count" => Ok(AggregateType::Count), + "sum" => Ok(AggregateType::Sum), + _ => Err(ExecutionError::General(format!( + "Unsupported aggregate function '{}'", + name + ))), + }; + + Ok(RuntimeExpr::AggregateFunction { + name: name.to_string(), + f: func?, + args: compiled_args? + .iter() + .map(|e| e.get_func().clone()) + .collect(), + t: return_type.clone(), + }) + } + _ => Ok(compile_scalar_expr(&ctx, expr, input_schema)?), + } +} + +macro_rules! binary_op { + ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ + let ll = $LEFT.as_any().downcast_ref::<$DT>().unwrap(); + let rr = $RIGHT.as_any().downcast_ref::<$DT>().unwrap(); + Ok(Arc::new(array_ops::$OP(&ll, &rr)?)) + }}; +} + +macro_rules! math_ops { + ($LEFT:expr, $RIGHT:expr, $BATCH:expr, $OP:ident) => {{ + let left_values = $LEFT.get_func()($BATCH)?; + let right_values = $RIGHT.get_func()($BATCH)?; + match (left_values.data_type(), right_values.data_type()) { + (DataType::Int8, DataType::Int8) => { + binary_op!(left_values, right_values, $OP, Int8Array) + } + (DataType::Int16, DataType::Int16) => { + binary_op!(left_values, right_values, $OP, Int16Array) + } + (DataType::Int32, DataType::Int32) => { + binary_op!(left_values, right_values, $OP, Int32Array) + } + (DataType::Int64, DataType::Int64) => { + binary_op!(left_values, right_values, $OP, Int64Array) + } + (DataType::UInt8, DataType::UInt8) => { + binary_op!(left_values, right_values, $OP, UInt8Array) + } + (DataType::UInt16, DataType::UInt16) => { + binary_op!(left_values, right_values, $OP, UInt16Array) + } + (DataType::UInt32, DataType::UInt32) => { + binary_op!(left_values, right_values, $OP, UInt32Array) + } + (DataType::UInt64, DataType::UInt64) => { + binary_op!(left_values, right_values, $OP, UInt64Array) + } + (DataType::Float32, DataType::Float32) => { + binary_op!(left_values, right_values, $OP, Float32Array) + } + (DataType::Float64, DataType::Float64) => { + binary_op!(left_values, right_values, $OP, Float64Array) + } + _ => Err(ExecutionError::ExecutionError(format!("math_ops"))), + } + }}; +} + +macro_rules! 
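// Expanded by hand for a pair of Int32 inputs, the math_ops!/binary_op! pair
// above reduces to roughly the following: evaluate both child closures,
// downcast the resulting ArrayRefs to the concrete array type, and call the
// matching arrow::array_ops kernel (shown for Operator::Plus; `left_values`
// and `right_values` are assumed to be in scope):
let ll = left_values.as_any().downcast_ref::<Int32Array>().unwrap();
let rr = right_values.as_any().downcast_ref::<Int32Array>().unwrap();
let sum: ArrayRef = Arc::new(array_ops::add(&ll, &rr)?);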
comparison_ops { + ($LEFT:expr, $RIGHT:expr, $BATCH:expr, $OP:ident) => {{ + let left_values = $LEFT.get_func()($BATCH)?; + let right_values = $RIGHT.get_func()($BATCH)?; + match (left_values.data_type(), right_values.data_type()) { + (DataType::Int8, DataType::Int8) => { + binary_op!(left_values, right_values, $OP, Int8Array) + } + (DataType::Int16, DataType::Int16) => { + binary_op!(left_values, right_values, $OP, Int16Array) + } + (DataType::Int32, DataType::Int32) => { + binary_op!(left_values, right_values, $OP, Int32Array) + } + (DataType::Int64, DataType::Int64) => { + binary_op!(left_values, right_values, $OP, Int64Array) + } + (DataType::UInt8, DataType::UInt8) => { + binary_op!(left_values, right_values, $OP, UInt8Array) + } + (DataType::UInt16, DataType::UInt16) => { + binary_op!(left_values, right_values, $OP, UInt16Array) + } + (DataType::UInt32, DataType::UInt32) => { + binary_op!(left_values, right_values, $OP, UInt32Array) + } + (DataType::UInt64, DataType::UInt64) => { + binary_op!(left_values, right_values, $OP, UInt64Array) + } + (DataType::Float32, DataType::Float32) => { + binary_op!(left_values, right_values, $OP, Float32Array) + } + (DataType::Float64, DataType::Float64) => { + binary_op!(left_values, right_values, $OP, Float64Array) + } + //TODO other types + _ => Err(ExecutionError::ExecutionError(format!("comparison_ops"))), + } + }}; +} + +macro_rules! boolean_ops { + ($LEFT:expr, $RIGHT:expr, $BATCH:expr, $OP:ident) => {{ + let left_values = $LEFT.get_func()($BATCH)?; + let right_values = $RIGHT.get_func()($BATCH)?; + Ok(Arc::new(array_ops::$OP( + left_values.as_any().downcast_ref::().unwrap(), + right_values + .as_any() + .downcast_ref::() + .unwrap(), + )?)) + }}; +} + +macro_rules! literal_array { + ($VALUE:expr, $ARRAY_TYPE:ident, $TY:ident) => {{ + let nn = *$VALUE; + Ok(RuntimeExpr::Compiled { + name: format!("{}", nn), + f: Rc::new(move |batch: &RecordBatch| { + let capacity = batch.num_rows(); + let mut builder = $ARRAY_TYPE::builder(capacity); + for _ in 0..capacity { + builder.append_value(nn)?; + } + let array = builder.finish(); + Ok(Arc::new(array) as ArrayRef) + }), + t: DataType::$TY, + }) + }}; +} + +/// Casts a column to an array with a different data type +macro_rules! cast_column { + ($INDEX:expr, $FROM_TYPE:ty, $TO_TYPE:ident, $DT:ty) => {{ + Rc::new(move |batch: &RecordBatch| { + // get data and cast to known type + match batch.column($INDEX).as_any().downcast_ref::<$FROM_TYPE>() { + Some(array) => { + // create builder for desired type + let mut builder = $TO_TYPE::builder(batch.num_rows()); + for i in 0..batch.num_rows() { + if array.is_null(i) { + builder.append_null()?; + } else { + builder.append_value(array.value(i) as $DT)?; + } + } + Ok(Arc::new(builder.finish()) as ArrayRef) + } + None => Err(ExecutionError::InternalError(format!( + "Column at index {} is not of expected type", + $INDEX + ))), + } + }) + }}; +} + +macro_rules! 
cast_column_outer { + ($INDEX:expr, $FROM_TYPE:ty, $TO_TYPE:expr) => {{ + match $TO_TYPE { + DataType::UInt8 => cast_column!($INDEX, $FROM_TYPE, UInt8Array, u8), + DataType::UInt16 => cast_column!($INDEX, $FROM_TYPE, UInt16Array, u16), + DataType::UInt32 => cast_column!($INDEX, $FROM_TYPE, UInt32Array, u32), + DataType::UInt64 => cast_column!($INDEX, $FROM_TYPE, UInt64Array, u64), + DataType::Int8 => cast_column!($INDEX, $FROM_TYPE, Int8Array, i8), + DataType::Int16 => cast_column!($INDEX, $FROM_TYPE, Int16Array, i16), + DataType::Int32 => cast_column!($INDEX, $FROM_TYPE, Int32Array, i32), + DataType::Int64 => cast_column!($INDEX, $FROM_TYPE, Int64Array, i64), + DataType::Float32 => cast_column!($INDEX, $FROM_TYPE, Float32Array, f32), + DataType::Float64 => cast_column!($INDEX, $FROM_TYPE, Float64Array, f64), + _ => unimplemented!(), + } + }}; +} + +/// Compiles a scalar expression into a closure +pub fn compile_scalar_expr( + ctx: &ExecutionContext, + expr: &Expr, + input_schema: &Schema, +) -> Result { + match expr { + &Expr::Literal(ref value) => match value { + //NOTE: this is a temporary hack .. due to the way expressions like 'a > 1' are + // evaluated, currently the left and right are evaluated separately and must result + // in arrays and then the '>' operator is evaluated against the two arrays. This works + // but is dumb ... I intend to optimize this soon to add special handling for + // binary expressions that involve literal values to avoid creating arrays of literals + // filed as https://github.com/andygrove/datafusion/issues/191 + ScalarValue::Int8(n) => literal_array!(n, Int8Array, Int8), + ScalarValue::Int16(n) => literal_array!(n, Int16Array, Int16), + ScalarValue::Int32(n) => literal_array!(n, Int32Array, Int32), + ScalarValue::Int64(n) => literal_array!(n, Int64Array, Int64), + ScalarValue::UInt8(n) => literal_array!(n, UInt8Array, UInt8), + ScalarValue::UInt16(n) => literal_array!(n, UInt16Array, UInt16), + ScalarValue::UInt32(n) => literal_array!(n, UInt32Array, UInt32), + ScalarValue::UInt64(n) => literal_array!(n, UInt64Array, UInt64), + ScalarValue::Float32(n) => literal_array!(n, Float32Array, Float32), + ScalarValue::Float64(n) => literal_array!(n, Float64Array, Float64), + other => Err(ExecutionError::ExecutionError(format!( + "No support for literal type {:?}", + other + ))), + }, + &Expr::Column(index) => Ok(RuntimeExpr::Compiled { + name: input_schema.field(index).name().clone(), + f: Rc::new(move |batch: &RecordBatch| Ok((*batch.column(index)).clone())), + t: input_schema.field(index).data_type().clone(), + }), + &Expr::Cast { + ref expr, + ref data_type, + } => match expr.as_ref() { + &Expr::Column(index) => { + let col = input_schema.field(index); + Ok(RuntimeExpr::Compiled { + name: col.name().clone(), + t: col.data_type().clone(), + f: match col.data_type() { + DataType::Int8 => { + cast_column_outer!(index, Int8Array, &data_type) + } + DataType::Int16 => { + cast_column_outer!(index, Int16Array, &data_type) + } + DataType::Int32 => { + cast_column_outer!(index, Int32Array, &data_type) + } + DataType::Int64 => { + cast_column_outer!(index, Int64Array, &data_type) + } + DataType::UInt8 => { + cast_column_outer!(index, UInt8Array, &data_type) + } + DataType::UInt16 => { + cast_column_outer!(index, UInt16Array, &data_type) + } + DataType::UInt32 => { + cast_column_outer!(index, UInt32Array, &data_type) + } + DataType::UInt64 => { + cast_column_outer!(index, UInt64Array, &data_type) + } + DataType::Float32 => { + cast_column_outer!(index, Float32Array, 
&data_type) + } + DataType::Float64 => { + cast_column_outer!(index, Float64Array, &data_type) + } + _ => panic!("unsupported CAST operation"), /*TODO */ + /*Err(ExecutionError::NotImplemented(format!( + "CAST column from {:?} to {:?}", + col.data_type(), + data_type + )))*/ + }, + }) + } + &Expr::Literal(ref value) => { + //NOTE this is all very inefficient and needs to be optimized - tracking + // issue is https://github.com/andygrove/datafusion/issues/191 + match value { + ScalarValue::Int64(n) => { + let nn = *n; + match data_type { + DataType::Float64 => Ok(RuntimeExpr::Compiled { + name: "lit".to_string(), + f: Rc::new(move |batch: &RecordBatch| { + let mut b = Float64Array::builder(batch.num_rows()); + for _ in 0..batch.num_rows() { + b.append_value(nn as f64)?; + } + Ok(Arc::new(b.finish()) as ArrayRef) + }), + t: data_type.clone(), + }), + other => Err(ExecutionError::NotImplemented(format!( + "CAST from Int64 to {:?}", + other + ))), + } + } + other => Err(ExecutionError::NotImplemented(format!( + "CAST from {:?} to {:?}", + other, data_type + ))), + } + } + other => Err(ExecutionError::General(format!( + "CAST not implemented for expression {:?}", + other + ))), + }, + &Expr::BinaryExpr { + ref left, + ref op, + ref right, + } => { + let left_expr = compile_scalar_expr(ctx, left, input_schema)?; + let right_expr = compile_scalar_expr(ctx, right, input_schema)?; + let name = format!("{:?} {:?} {:?}", left, op, right); + let op_type = left_expr.get_type().clone(); + match op { + &Operator::Eq => Ok(RuntimeExpr::Compiled { + name, + f: Rc::new(move |batch: &RecordBatch| { + comparison_ops!(left_expr, right_expr, batch, eq) + }), + t: DataType::Boolean, + }), + &Operator::NotEq => Ok(RuntimeExpr::Compiled { + name, + f: Rc::new(move |batch: &RecordBatch| { + comparison_ops!(left_expr, right_expr, batch, neq) + }), + t: DataType::Boolean, + }), + &Operator::Lt => Ok(RuntimeExpr::Compiled { + name, + f: Rc::new(move |batch: &RecordBatch| { + comparison_ops!(left_expr, right_expr, batch, lt) + }), + t: DataType::Boolean, + }), + &Operator::LtEq => Ok(RuntimeExpr::Compiled { + name, + f: Rc::new(move |batch: &RecordBatch| { + comparison_ops!(left_expr, right_expr, batch, lt_eq) + }), + t: DataType::Boolean, + }), + &Operator::Gt => Ok(RuntimeExpr::Compiled { + name, + f: Rc::new(move |batch: &RecordBatch| { + comparison_ops!(left_expr, right_expr, batch, gt) + }), + t: DataType::Boolean, + }), + &Operator::GtEq => Ok(RuntimeExpr::Compiled { + name, + f: Rc::new(move |batch: &RecordBatch| { + comparison_ops!(left_expr, right_expr, batch, gt_eq) + }), + t: DataType::Boolean, + }), + &Operator::And => Ok(RuntimeExpr::Compiled { + name, + f: Rc::new(move |batch: &RecordBatch| { + boolean_ops!(left_expr, right_expr, batch, and) + }), + t: DataType::Boolean, + }), + &Operator::Or => Ok(RuntimeExpr::Compiled { + name, + f: Rc::new(move |batch: &RecordBatch| { + boolean_ops!(left_expr, right_expr, batch, or) + }), + t: DataType::Boolean, + }), + &Operator::Plus => Ok(RuntimeExpr::Compiled { + name, + f: Rc::new(move |batch: &RecordBatch| { + math_ops!(left_expr, right_expr, batch, add) + }), + t: op_type, + }), + &Operator::Minus => Ok(RuntimeExpr::Compiled { + name, + f: Rc::new(move |batch: &RecordBatch| { + math_ops!(left_expr, right_expr, batch, subtract) + }), + t: op_type, + }), + &Operator::Multiply => Ok(RuntimeExpr::Compiled { + name, + f: Rc::new(move |batch: &RecordBatch| { + math_ops!(left_expr, right_expr, batch, multiply) + }), + t: op_type, + }), + &Operator::Divide => 
Ok(RuntimeExpr::Compiled { + name, + f: Rc::new(move |batch: &RecordBatch| { + math_ops!(left_expr, right_expr, batch, divide) + }), + t: op_type, + }), + other => Err(ExecutionError::ExecutionError(format!( + "operator: {:?}", + other + ))), + } + } + other => Err(ExecutionError::ExecutionError(format!( + "expression {:?}", + other + ))), + } +} diff --git a/rust/datafusion/src/execution/filter.rs b/rust/datafusion/src/execution/filter.rs new file mode 100644 index 0000000000000..ba20dca036e5b --- /dev/null +++ b/rust/datafusion/src/execution/filter.rs @@ -0,0 +1,211 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Execution of a filter (predicate) + +use std::cell::RefCell; +use std::rc::Rc; +use std::sync::Arc; + +use arrow::array::*; +use arrow::datatypes::{DataType, Schema}; +use arrow::record_batch::RecordBatch; + +use super::error::{ExecutionError, Result}; +use super::expression::RuntimeExpr; +use super::relation::Relation; + +pub struct FilterRelation { + schema: Arc, + input: Rc>, + expr: RuntimeExpr, +} + +impl FilterRelation { + pub fn new( + input: Rc>, + expr: RuntimeExpr, + schema: Arc, + ) -> Self { + Self { + schema, + input, + expr, + } + } +} + +impl Relation for FilterRelation { + fn next(&mut self) -> Result> { + match self.input.borrow_mut().next()? { + Some(batch) => { + // evaluate the filter expression against the batch + match self.expr.get_func()(&batch)? 
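// `self.expr.get_func()(&batch)` above runs the compiled closure against the
// current batch and yields an ArrayRef; for a comparison expression that array
// is expected to be a BooleanArray. A sketch of building and evaluating such a
// predicate, assuming `ctx`, `schema` (with an Int64 column 0) and `batch` are
// in scope:
let predicate = Expr::Column(0).gt(&Expr::Literal(ScalarValue::Int64(10)));
let compiled = compile_scalar_expr(&ctx, &predicate, &schema)?;
let mask = compiled.get_func()(&batch)?; // ArrayRef holding a BooleanArray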
+ .as_any() + .downcast_ref::() + { + Some(filter_bools) => { + let filtered_columns: Result> = (0..batch + .num_columns()) + .map(|i| filter(batch.column(i), &filter_bools)) + .collect(); + + let filtered_batch: RecordBatch = RecordBatch::new( + Arc::new(Schema::empty()), + filtered_columns?, + ); + + Ok(Some(filtered_batch)) + } + _ => Err(ExecutionError::ExecutionError( + "Filter expression did not evaluate to boolean".to_string(), + )), + } + } + None => Ok(None), + } + } + + fn schema(&self) -> &Arc { + &self.schema + } +} + +//TODO: move into Arrow array_ops +fn filter(array: &Arc, filter: &BooleanArray) -> Result { + let a = array.as_ref(); + + //TODO use macros + match a.data_type() { + DataType::UInt8 => { + let b = a.as_any().downcast_ref::().unwrap(); + let mut builder = UInt8Array::builder(b.len()); + for i in 0..b.len() { + if filter.value(i) { + builder.append_value(b.value(i))?; + } + } + Ok(Arc::new(builder.finish())) + } + DataType::UInt16 => { + let b = a.as_any().downcast_ref::().unwrap(); + let mut builder = UInt16Array::builder(b.len()); + for i in 0..b.len() { + if filter.value(i) { + builder.append_value(b.value(i))?; + } + } + Ok(Arc::new(builder.finish())) + } + DataType::UInt32 => { + let b = a.as_any().downcast_ref::().unwrap(); + let mut builder = UInt32Array::builder(b.len()); + for i in 0..b.len() { + if filter.value(i) { + builder.append_value(b.value(i))?; + } + } + Ok(Arc::new(builder.finish())) + } + DataType::UInt64 => { + let b = a.as_any().downcast_ref::().unwrap(); + let mut builder = UInt64Array::builder(b.len()); + for i in 0..b.len() { + if filter.value(i) { + builder.append_value(b.value(i))?; + } + } + Ok(Arc::new(builder.finish())) + } + DataType::Int8 => { + let b = a.as_any().downcast_ref::().unwrap(); + let mut builder = Int8Array::builder(b.len()); + for i in 0..b.len() { + if filter.value(i) { + builder.append_value(b.value(i))?; + } + } + Ok(Arc::new(builder.finish())) + } + DataType::Int16 => { + let b = a.as_any().downcast_ref::().unwrap(); + let mut builder = Int16Array::builder(b.len()); + for i in 0..b.len() { + if filter.value(i) { + builder.append_value(b.value(i))?; + } + } + Ok(Arc::new(builder.finish())) + } + DataType::Int32 => { + let b = a.as_any().downcast_ref::().unwrap(); + let mut builder = Int32Array::builder(b.len()); + for i in 0..b.len() { + if filter.value(i) { + builder.append_value(b.value(i))?; + } + } + Ok(Arc::new(builder.finish())) + } + DataType::Int64 => { + let b = a.as_any().downcast_ref::().unwrap(); + let mut builder = Int64Array::builder(b.len()); + for i in 0..b.len() { + if filter.value(i) { + builder.append_value(b.value(i))?; + } + } + Ok(Arc::new(builder.finish())) + } + DataType::Float32 => { + let b = a.as_any().downcast_ref::().unwrap(); + let mut builder = Float32Array::builder(b.len()); + for i in 0..b.len() { + if filter.value(i) { + builder.append_value(b.value(i))?; + } + } + Ok(Arc::new(builder.finish())) + } + DataType::Float64 => { + let b = a.as_any().downcast_ref::().unwrap(); + let mut builder = Float64Array::builder(b.len()); + for i in 0..b.len() { + if filter.value(i) { + builder.append_value(b.value(i))?; + } + } + Ok(Arc::new(builder.finish())) + } + DataType::Utf8 => { + //TODO: this is inefficient and we should improve the Arrow impl to help make this more concise + let b = a.as_any().downcast_ref::().unwrap(); + let mut values: Vec = Vec::with_capacity(b.len()); + for i in 0..b.len() { + if filter.value(i) { + values.push(b.get_string(i)); + } + } + let tmp: Vec<&str> = 
values.iter().map(|s| s.as_str()).collect(); + Ok(Arc::new(BinaryArray::from(tmp))) + } + other => Err(ExecutionError::ExecutionError(format!( + "filter not supported for {:?}", + other + ))), + } +} diff --git a/rust/datafusion/src/execution/mod.rs b/rust/datafusion/src/execution/mod.rs new file mode 100644 index 0000000000000..23144bb5173ca --- /dev/null +++ b/rust/datafusion/src/execution/mod.rs @@ -0,0 +1,26 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod aggregate; +pub mod context; +pub mod datasource; +pub mod error; +pub mod expression; +pub mod filter; +pub mod physicalplan; +pub mod projection; +pub mod relation; diff --git a/rust/datafusion/src/execution/physicalplan.rs b/rust/datafusion/src/execution/physicalplan.rs new file mode 100644 index 0000000000000..23aa4312bfbf2 --- /dev/null +++ b/rust/datafusion/src/execution/physicalplan.rs @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::super::logicalplan::LogicalPlan; +use std::rc::Rc; + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub enum PhysicalPlan { + /// Run a query and return the results to the client + Interactive { + plan: Rc, + }, + /// Execute a logical plan and write the output to a file + Write { + plan: Rc, + filename: String, + kind: String, + }, + Show { + plan: Rc, + count: usize, + }, +} diff --git a/rust/datafusion/src/execution/projection.rs b/rust/datafusion/src/execution/projection.rs new file mode 100644 index 0000000000000..6de5818ab1983 --- /dev/null +++ b/rust/datafusion/src/execution/projection.rs @@ -0,0 +1,130 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
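// The module-private filter() helper above keeps only the rows whose mask bit
// is set, rebuilding each column with the matching builder. A small in-module
// sketch of its behaviour (the values are illustrative):
let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
let mask = BooleanArray::from(vec![true, false, true]);
let kept = filter(&values, &mask)?; // ArrayRef containing [1, 3]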
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Execution of a projection + +use std::cell::RefCell; +use std::rc::Rc; +use std::sync::Arc; + +use arrow::array::ArrayRef; +use arrow::datatypes::{Field, Schema}; +use arrow::record_batch::RecordBatch; + +use super::error::Result; +use super::expression::RuntimeExpr; +use super::relation::Relation; + +pub struct ProjectRelation { + schema: Arc, + input: Rc>, + expr: Vec, +} + +impl ProjectRelation { + pub fn new( + input: Rc>, + expr: Vec, + schema: Arc, + ) -> Self { + ProjectRelation { + input, + expr, + schema, + } + } +} + +impl Relation for ProjectRelation { + fn next(&mut self) -> Result> { + match self.input.borrow_mut().next()? { + Some(batch) => { + let projected_columns: Result> = + self.expr.iter().map(|e| e.get_func()(&batch)).collect(); + + let schema = Schema::new( + self.expr + .iter() + .map(|e| Field::new(&e.get_name(), e.get_type(), true)) + .collect(), + ); + + let projected_batch: RecordBatch = + RecordBatch::new(Arc::new(schema), projected_columns?); + + Ok(Some(projected_batch)) + } + None => Ok(None), + } + } + + fn schema(&self) -> &Arc { + &self.schema + } +} + +#[cfg(test)] +mod tests { + use super::super::super::logicalplan::Expr; + use super::super::context::ExecutionContext; + use super::super::datasource::CsvDataSource; + use super::super::expression; + use super::super::relation::DataSourceRelation; + use super::*; + use arrow::datatypes::{DataType, Field, Schema}; + + #[test] + fn project_first_column() { + let schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Utf8, false), + Field::new("c2", DataType::UInt32, false), + Field::new("c3", DataType::Int8, false), + Field::new("c3", DataType::Int16, false), + Field::new("c4", DataType::Int32, false), + Field::new("c5", DataType::Int64, false), + Field::new("c6", DataType::UInt8, false), + Field::new("c7", DataType::UInt16, false), + Field::new("c8", DataType::UInt32, false), + Field::new("c9", DataType::UInt64, false), + Field::new("c10", DataType::Float32, false), + Field::new("c11", DataType::Float64, false), + Field::new("c12", DataType::Utf8, false), + ])); + let ds = CsvDataSource::new( + "../../testing/data/csv/aggregate_test_100.csv", + schema.clone(), + 1024, + ); + let relation = Rc::new(RefCell::new(DataSourceRelation::new(Rc::new( + RefCell::new(ds), + )))); + let context = ExecutionContext::new(); + + let projection_expr = + vec![ + expression::compile_expr(&context, &Expr::Column(0), schema.as_ref()) + .unwrap(), + ]; + + let mut projection = ProjectRelation::new(relation, projection_expr, schema); + let batch = projection.next().unwrap().unwrap(); + assert_eq!(1, batch.num_columns()); + + assert_eq!("c1", batch.schema().field(0).name()); + } + +} diff --git a/rust/datafusion/src/execution/relation.rs b/rust/datafusion/src/execution/relation.rs new file mode 100644 index 0000000000000..88eefabba67a8 --- /dev/null +++ b/rust/datafusion/src/execution/relation.rs @@ -0,0 +1,57 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::cell::RefCell; +use std::rc::Rc; +use std::sync::Arc; + +use arrow::datatypes::Schema; +use arrow::record_batch::RecordBatch; + +use super::datasource::DataSource; +use super::error::Result; + +/// trait for all relations (a relation is essentially just an iterator over rows with +/// a known schema) +pub trait Relation { + fn next(&mut self) -> Result>; + + /// get the schema for this relation + fn schema(&self) -> &Arc; +} + +pub struct DataSourceRelation { + schema: Arc, + ds: Rc>, +} + +impl DataSourceRelation { + pub fn new(ds: Rc>) -> Self { + let schema = ds.borrow().schema().clone(); + Self { ds, schema } + } +} + +impl Relation for DataSourceRelation { + fn next(&mut self) -> Result> { + self.ds.borrow_mut().next() + } + + fn schema(&self) -> &Arc { + &self.schema + } +} diff --git a/rust/datafusion/src/lib.rs b/rust/datafusion/src/lib.rs new file mode 100644 index 0000000000000..efcad5d886925 --- /dev/null +++ b/rust/datafusion/src/lib.rs @@ -0,0 +1,30 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! DataFusion is a modern distributed compute platform implemented in Rust that uses Apache Arrow +//! as the memory model + +extern crate arrow; +#[macro_use] +extern crate serde_derive; +extern crate serde_json; +extern crate sqlparser; + +pub mod dfparser; +pub mod execution; +pub mod logicalplan; +pub mod sqlplanner; diff --git a/rust/datafusion/src/logicalplan.rs b/rust/datafusion/src/logicalplan.rs new file mode 100644 index 0000000000000..b3e6bda545996 --- /dev/null +++ b/rust/datafusion/src/logicalplan.rs @@ -0,0 +1,650 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
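// A Relation is drained by calling next() until it returns Ok(None). Combined
// with the ExecutionContext sketched earlier, consuming a query result looks
// roughly like this (table and query are illustrative):
let relation = ctx.sql("SELECT c1 FROM t")?;
while let Some(batch) = relation.borrow_mut().next()? {
    println!("got {} rows, {} columns", batch.num_rows(), batch.num_columns());
}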
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Logical query plan + +use std::fmt; +use std::fmt::{Error, Formatter}; +use std::rc::Rc; +use std::sync::Arc; + +use arrow::datatypes::*; + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub enum FunctionType { + Scalar, + Aggregate, +} + +#[derive(Debug, Clone)] +pub struct FunctionMeta { + name: String, + args: Vec, + return_type: DataType, + function_type: FunctionType, +} + +impl FunctionMeta { + pub fn new( + name: String, + args: Vec, + return_type: DataType, + function_type: FunctionType, + ) -> Self { + FunctionMeta { + name, + args, + return_type, + function_type, + } + } + pub fn name(&self) -> &String { + &self.name + } + pub fn args(&self) -> &Vec { + &self.args + } + pub fn return_type(&self) -> &DataType { + &self.return_type + } + pub fn function_type(&self) -> &FunctionType { + &self.function_type + } +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub enum Operator { + Eq, + NotEq, + Lt, + LtEq, + Gt, + GtEq, + Plus, + Minus, + Multiply, + Divide, + Modulus, + And, + Or, + Not, + Like, + NotLike, +} + +impl Operator { + /// Get the result type of applying this operation to its left and right inputs + pub fn get_datatype(&self, l: &Expr, _r: &Expr, schema: &Schema) -> DataType { + //TODO: implement correctly, just go with left side for now + l.get_type(schema).clone() + } +} + +/// ScalarValue enumeration +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +pub enum ScalarValue { + Null, + Boolean(bool), + Float32(f32), + Float64(f64), + Int8(i8), + Int16(i16), + Int32(i32), + Int64(i64), + UInt8(u8), + UInt16(u16), + UInt32(u32), + UInt64(u64), + Utf8(Rc), + Struct(Vec), +} + +impl ScalarValue { + pub fn get_datatype(&self) -> DataType { + match *self { + ScalarValue::Boolean(_) => DataType::Boolean, + ScalarValue::UInt8(_) => DataType::UInt8, + ScalarValue::UInt16(_) => DataType::UInt16, + ScalarValue::UInt32(_) => DataType::UInt32, + ScalarValue::UInt64(_) => DataType::UInt64, + ScalarValue::Int8(_) => DataType::Int8, + ScalarValue::Int16(_) => DataType::Int16, + ScalarValue::Int32(_) => DataType::Int32, + ScalarValue::Int64(_) => DataType::Int64, + ScalarValue::Float32(_) => DataType::Float32, + ScalarValue::Float64(_) => DataType::Float64, + ScalarValue::Utf8(_) => DataType::Utf8, + ScalarValue::Struct(_) => unimplemented!(), + ScalarValue::Null => unimplemented!(), + } + } +} + +/// Relation Expression +#[derive(Serialize, Deserialize, Clone, PartialEq)] +pub enum Expr { + /// index into a value within the row or complex value + Column(usize), + /// literal value + Literal(ScalarValue), + /// binary expression e.g. 
"age > 21" + BinaryExpr { + left: Rc, + op: Operator, + right: Rc, + }, + /// unary IS NOT NULL + IsNotNull(Rc), + /// unary IS NULL + IsNull(Rc), + /// cast a value to a different type + Cast { expr: Rc, data_type: DataType }, + /// sort expression + Sort { expr: Rc, asc: bool }, + /// scalar function + ScalarFunction { + name: String, + args: Vec, + return_type: DataType, + }, + /// aggregate function + AggregateFunction { + name: String, + args: Vec, + return_type: DataType, + }, +} + +impl Expr { + pub fn get_type(&self, schema: &Schema) -> DataType { + match self { + Expr::Column(n) => schema.field(*n).data_type().clone(), + Expr::Literal(l) => l.get_datatype(), + Expr::Cast { data_type, .. } => data_type.clone(), + Expr::ScalarFunction { return_type, .. } => return_type.clone(), + Expr::AggregateFunction { return_type, .. } => return_type.clone(), + Expr::IsNull(_) => DataType::Boolean, + Expr::IsNotNull(_) => DataType::Boolean, + Expr::BinaryExpr { + ref left, + ref right, + ref op, + } => { + match op { + Operator::Eq | Operator::NotEq => DataType::Boolean, + Operator::Lt | Operator::LtEq => DataType::Boolean, + Operator::Gt | Operator::GtEq => DataType::Boolean, + Operator::And | Operator::Or => DataType::Boolean, + _ => { + let left_type = left.get_type(schema); + let right_type = right.get_type(schema); + get_supertype(&left_type, &right_type).unwrap_or(DataType::Utf8) //TODO ??? + } + } + } + Expr::Sort { ref expr, .. } => expr.get_type(schema), + } + } + + pub fn cast_to( + &self, + cast_to_type: &DataType, + schema: &Schema, + ) -> Result { + let this_type = self.get_type(schema); + if this_type == *cast_to_type { + Ok(self.clone()) + } else if can_coerce_from(cast_to_type, &this_type) { + Ok(Expr::Cast { + expr: Rc::new(self.clone()), + data_type: cast_to_type.clone(), + }) + } else { + Err(format!( + "Cannot automatically convert {:?} to {:?}", + this_type, cast_to_type + )) + } + } + + pub fn eq(&self, other: &Expr) -> Expr { + Expr::BinaryExpr { + left: Rc::new(self.clone()), + op: Operator::Eq, + right: Rc::new(other.clone()), + } + } + + pub fn not_eq(&self, other: &Expr) -> Expr { + Expr::BinaryExpr { + left: Rc::new(self.clone()), + op: Operator::NotEq, + right: Rc::new(other.clone()), + } + } + + pub fn gt(&self, other: &Expr) -> Expr { + Expr::BinaryExpr { + left: Rc::new(self.clone()), + op: Operator::Gt, + right: Rc::new(other.clone()), + } + } + + pub fn gt_eq(&self, other: &Expr) -> Expr { + Expr::BinaryExpr { + left: Rc::new(self.clone()), + op: Operator::GtEq, + right: Rc::new(other.clone()), + } + } + + pub fn lt(&self, other: &Expr) -> Expr { + Expr::BinaryExpr { + left: Rc::new(self.clone()), + op: Operator::Lt, + right: Rc::new(other.clone()), + } + } + + pub fn lt_eq(&self, other: &Expr) -> Expr { + Expr::BinaryExpr { + left: Rc::new(self.clone()), + op: Operator::LtEq, + right: Rc::new(other.clone()), + } + } +} + +impl fmt::Debug for Expr { + fn fmt(&self, f: &mut Formatter) -> Result<(), Error> { + match self { + Expr::Column(i) => write!(f, "#{}", i), + Expr::Literal(v) => write!(f, "{:?}", v), + Expr::Cast { expr, data_type } => { + write!(f, "CAST({:?} AS {:?})", expr, data_type) + } + Expr::IsNull(expr) => write!(f, "{:?} IS NULL", expr), + Expr::IsNotNull(expr) => write!(f, "{:?} IS NOT NULL", expr), + Expr::BinaryExpr { left, op, right } => { + write!(f, "{:?} {:?} {:?}", left, op, right) + } + Expr::Sort { expr, asc } => { + if *asc { + write!(f, "{:?} ASC", expr) + } else { + write!(f, "{:?} DESC", expr) + } + } + Expr::ScalarFunction { name, 
ref args, .. } => { + write!(f, "{}(", name)?; + for i in 0..args.len() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{:?}", args[i])?; + } + + write!(f, ")") + } + Expr::AggregateFunction { name, ref args, .. } => { + write!(f, "{}(", name)?; + for i in 0..args.len() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{:?}", args[i])?; + } + + write!(f, ")") + } + } + } +} + +/// The LogicalPlan represents different types of relations (such as Projection, Selection, etc) and +/// can be created by the SQL query planner and the DataFrame API. +#[derive(Serialize, Deserialize, Clone)] +pub enum LogicalPlan { + /// A Projection (essentially a SELECT with an expression list) + Projection { + expr: Vec, + input: Rc, + schema: Arc, + }, + /// A Selection (essentially a WHERE clause with a predicate expression) + Selection { expr: Expr, input: Rc }, + /// Represents a list of aggregate expressions with optional grouping expressions + Aggregate { + input: Rc, + group_expr: Vec, + aggr_expr: Vec, + schema: Arc, + }, + /// Represents a list of sort expressions to be applied to a relation + Sort { + expr: Vec, + input: Rc, + schema: Arc, + }, + /// A table scan against a table that has been registered on a context + TableScan { + schema_name: String, + table_name: String, + schema: Arc, + projection: Option>, + }, + /// An empty relation with an empty schema + EmptyRelation { schema: Arc }, +} + +impl LogicalPlan { + /// Get a reference to the logical plan's schema + pub fn schema(&self) -> &Arc { + match self { + LogicalPlan::EmptyRelation { schema } => &schema, + LogicalPlan::TableScan { schema, .. } => &schema, + LogicalPlan::Projection { schema, .. } => &schema, + LogicalPlan::Selection { input, .. } => input.schema(), + LogicalPlan::Aggregate { schema, .. } => &schema, + LogicalPlan::Sort { schema, .. } => &schema, + } + } +} + +impl LogicalPlan { + fn fmt_with_indent(&self, f: &mut Formatter, indent: usize) -> Result<(), Error> { + if indent > 0 { + writeln!(f)?; + for _ in 0..indent { + write!(f, " ")?; + } + } + match *self { + LogicalPlan::EmptyRelation { .. } => write!(f, "EmptyRelation"), + LogicalPlan::TableScan { + ref table_name, + ref projection, + .. + } => write!(f, "TableScan: {} projection={:?}", table_name, projection), + LogicalPlan::Projection { + ref expr, + ref input, + .. + } => { + write!(f, "Projection: ")?; + for i in 0..expr.len() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{:?}", expr[i])?; + } + input.fmt_with_indent(f, indent + 1) + } + LogicalPlan::Selection { + ref expr, + ref input, + .. + } => { + write!(f, "Selection: {:?}", expr)?; + input.fmt_with_indent(f, indent + 1) + } + LogicalPlan::Aggregate { + ref input, + ref group_expr, + ref aggr_expr, + .. + } => { + write!( + f, + "Aggregate: groupBy=[{:?}], aggr=[{:?}]", + group_expr, aggr_expr + )?; + input.fmt_with_indent(f, indent + 1) + } + LogicalPlan::Sort { + ref input, + ref expr, + .. + } => { + write!(f, "Sort: ")?; + for i in 0..expr.len() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{:?}", expr[i])?; + } + input.fmt_with_indent(f, indent + 1) + } + } + } +} + +impl fmt::Debug for LogicalPlan { + fn fmt(&self, f: &mut Formatter) -> Result<(), Error> { + self.fmt_with_indent(f, 0) + } +} + +//TODO move to Arrow DataType impl? 
+pub fn get_supertype(l: &DataType, r: &DataType) -> Option { + match _get_supertype(l, r) { + Some(dt) => Some(dt), + None => match _get_supertype(r, l) { + Some(dt) => Some(dt), + None => None, + }, + } +} + +fn _get_supertype(l: &DataType, r: &DataType) -> Option { + use self::DataType::*; + match (l, r) { + (UInt8, Int8) => Some(Int8), + (UInt8, Int16) => Some(Int16), + (UInt8, Int32) => Some(Int32), + (UInt8, Int64) => Some(Int64), + + (UInt16, Int16) => Some(Int16), + (UInt16, Int32) => Some(Int32), + (UInt16, Int64) => Some(Int64), + + (UInt32, Int32) => Some(Int32), + (UInt32, Int64) => Some(Int64), + + (UInt64, Int64) => Some(Int64), + + (Int8, UInt8) => Some(Int8), + + (Int16, UInt8) => Some(Int16), + (Int16, UInt16) => Some(Int16), + + (Int32, UInt8) => Some(Int32), + (Int32, UInt16) => Some(Int32), + (Int32, UInt32) => Some(Int32), + + (Int64, UInt8) => Some(Int64), + (Int64, UInt16) => Some(Int64), + (Int64, UInt32) => Some(Int64), + (Int64, UInt64) => Some(Int64), + + (UInt8, UInt8) => Some(UInt8), + (UInt8, UInt16) => Some(UInt16), + (UInt8, UInt32) => Some(UInt32), + (UInt8, UInt64) => Some(UInt64), + (UInt8, Float32) => Some(Float32), + (UInt8, Float64) => Some(Float64), + + (UInt16, UInt8) => Some(UInt16), + (UInt16, UInt16) => Some(UInt16), + (UInt16, UInt32) => Some(UInt32), + (UInt16, UInt64) => Some(UInt64), + (UInt16, Float32) => Some(Float32), + (UInt16, Float64) => Some(Float64), + + (UInt32, UInt8) => Some(UInt32), + (UInt32, UInt16) => Some(UInt32), + (UInt32, UInt32) => Some(UInt32), + (UInt32, UInt64) => Some(UInt64), + (UInt32, Float32) => Some(Float32), + (UInt32, Float64) => Some(Float64), + + (UInt64, UInt8) => Some(UInt64), + (UInt64, UInt16) => Some(UInt64), + (UInt64, UInt32) => Some(UInt64), + (UInt64, UInt64) => Some(UInt64), + (UInt64, Float32) => Some(Float32), + (UInt64, Float64) => Some(Float64), + + (Int8, Int8) => Some(Int8), + (Int8, Int16) => Some(Int16), + (Int8, Int32) => Some(Int32), + (Int8, Int64) => Some(Int64), + (Int8, Float32) => Some(Float32), + (Int8, Float64) => Some(Float64), + + (Int16, Int8) => Some(Int16), + (Int16, Int16) => Some(Int16), + (Int16, Int32) => Some(Int32), + (Int16, Int64) => Some(Int64), + (Int16, Float32) => Some(Float32), + (Int16, Float64) => Some(Float64), + + (Int32, Int8) => Some(Int32), + (Int32, Int16) => Some(Int32), + (Int32, Int32) => Some(Int32), + (Int32, Int64) => Some(Int64), + (Int32, Float32) => Some(Float32), + (Int32, Float64) => Some(Float64), + + (Int64, Int8) => Some(Int64), + (Int64, Int16) => Some(Int64), + (Int64, Int32) => Some(Int64), + (Int64, Int64) => Some(Int64), + (Int64, Float32) => Some(Float32), + (Int64, Float64) => Some(Float64), + + (Float32, Float32) => Some(Float32), + (Float32, Float64) => Some(Float64), + (Float64, Float32) => Some(Float64), + (Float64, Float64) => Some(Float64), + + (Utf8, Utf8) => Some(Utf8), + + (Boolean, Boolean) => Some(Boolean), + + _ => None, + } +} + +pub fn can_coerce_from(left: &DataType, other: &DataType) -> bool { + use self::DataType::*; + match left { + Int8 => match other { + Int8 => true, + _ => false, + }, + Int16 => match other { + Int8 | Int16 => true, + _ => false, + }, + Int32 => match other { + Int8 | Int16 | Int32 => true, + _ => false, + }, + Int64 => match other { + Int8 | Int16 | Int32 | Int64 => true, + _ => false, + }, + UInt8 => match other { + UInt8 => true, + _ => false, + }, + UInt16 => match other { + UInt8 | UInt16 => true, + _ => false, + }, + UInt32 => match other { + UInt8 | UInt16 | UInt32 => true, + _ => false, + }, 
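// A few worked examples of the coercion table above, written as assertions:
assert_eq!(Some(DataType::Int32), get_supertype(&DataType::UInt16, &DataType::Int32));
assert_eq!(Some(DataType::Float64), get_supertype(&DataType::Int64, &DataType::Float64));
assert_eq!(None, get_supertype(&DataType::Utf8, &DataType::Int32));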
+ UInt64 => match other { + UInt8 | UInt16 | UInt32 | UInt64 => true, + _ => false, + }, + Float32 => match other { + Int8 | Int16 | Int32 | Int64 => true, + UInt8 | UInt16 | UInt32 | UInt64 => true, + Float32 => true, + _ => false, + }, + Float64 => match other { + Int8 | Int16 | Int32 | Int64 => true, + UInt8 | UInt16 | UInt32 | UInt64 => true, + Float32 | Float64 => true, + _ => false, + }, + _ => false, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json; + + #[test] + fn serialize_plan() { + let schema = Schema::new(vec![ + Field::new("first_name", DataType::Utf8, false), + Field::new("last_name", DataType::Utf8, false), + Field::new( + "address", + DataType::Struct(vec![ + Field::new("street", DataType::Utf8, false), + Field::new("zip", DataType::UInt16, false), + ]), + false, + ), + ]); + + let plan = LogicalPlan::TableScan { + schema_name: "".to_string(), + table_name: "people".to_string(), + schema: Arc::new(schema), + projection: Some(vec![0, 1, 4]), + }; + + let serialized = serde_json::to_string(&plan).unwrap(); + + assert_eq!( + "{\"TableScan\":{\ + \"schema_name\":\"\",\ + \"table_name\":\"people\",\ + \"schema\":{\"fields\":[\ + {\"name\":\"first_name\",\"data_type\":\"Utf8\",\"nullable\":false},\ + {\"name\":\"last_name\",\"data_type\":\"Utf8\",\"nullable\":false},\ + {\"name\":\"address\",\"data_type\":{\"Struct\":\ + [\ + {\"name\":\"street\",\"data_type\":\"Utf8\",\"nullable\":false},\ + {\"name\":\"zip\",\"data_type\":\"UInt16\",\"nullable\":false}]},\"nullable\":false}\ + ]},\ + \"projection\":[0,1,4]}}", + serialized + ); + } +} diff --git a/rust/datafusion/src/sqlplanner.rs b/rust/datafusion/src/sqlplanner.rs new file mode 100644 index 0000000000000..dcb69ebc33f0a --- /dev/null +++ b/rust/datafusion/src/sqlplanner.rs @@ -0,0 +1,687 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! SQL Query Planner (produces logical plan from SQL AST) + +use std::collections::HashSet; +use std::rc::Rc; +use std::string::String; +use std::sync::Arc; + +use super::execution::error::*; +use super::logicalplan::*; + +use arrow::datatypes::*; + +use sqlparser::sqlast::*; + +pub trait SchemaProvider { + fn get_table_meta(&self, name: &str) -> Option>; + fn get_function_meta(&self, name: &str) -> Option>; +} + +/// SQL query planner +pub struct SqlToRel { + schema_provider: Rc, +} + +impl SqlToRel { + /// Create a new query planner + pub fn new(schema_provider: Rc) -> Self { + SqlToRel { schema_provider } + } + + /// Generate a logic plan from a SQL AST node + pub fn sql_to_rel(&self, sql: &ASTNode) -> Result> { + match sql { + &ASTNode::SQLSelect { + ref projection, + ref relation, + ref selection, + ref order_by, + ref group_by, + ref having, + .. 
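// SqlToRel resolves table names through the SchemaProvider trait above. A
// minimal provider backed by one hard-coded table might look like this; the
// table name and schema are illustrative, and Arc<FunctionMeta> is assumed as
// the wrapper type for function metadata:
struct SingleTableProvider {
    schema: Arc<Schema>,
}

impl SchemaProvider for SingleTableProvider {
    fn get_table_meta(&self, name: &str) -> Option<Arc<Schema>> {
        if name == "t" {
            Some(self.schema.clone())
        } else {
            None
        }
    }

    fn get_function_meta(&self, _name: &str) -> Option<Arc<FunctionMeta>> {
        // no scalar/aggregate function registry in this sketch
        None
    }
}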
+ } => { + // parse the input relation so we have access to the row type + let input = match relation { + &Some(ref r) => self.sql_to_rel(r)?, + &None => Rc::new(LogicalPlan::EmptyRelation { + schema: Arc::new(Schema::empty()), + }), + }; + + let input_schema = input.schema(); + + // selection first + let selection_plan = match selection { + &Some(ref filter_expr) => Some(LogicalPlan::Selection { + expr: self.sql_to_rex(&filter_expr, &input_schema.clone())?, + input: input.clone(), + }), + _ => None, + }; + + let expr: Vec = projection + .iter() + .map(|e| self.sql_to_rex(&e, &input_schema)) + .collect::>>()?; + + // collect aggregate expressions + let aggr_expr: Vec = expr + .iter() + .filter(|e| match e { + Expr::AggregateFunction { .. } => true, + _ => false, + }) + .map(|e| e.clone()) + .collect(); + + if aggr_expr.len() > 0 { + let aggregate_input: Rc = match selection_plan { + Some(s) => Rc::new(s), + _ => input.clone(), + }; + + let group_expr: Vec = match group_by { + Some(gbe) => gbe + .iter() + .map(|e| self.sql_to_rex(&e, &input_schema)) + .collect::>>()?, + None => vec![], + }; + //println!("GROUP BY: {:?}", group_expr); + + let mut all_fields: Vec = group_expr.clone(); + aggr_expr.iter().for_each(|x| all_fields.push(x.clone())); + + let aggr_schema = + Schema::new(exprlist_to_fields(&all_fields, input_schema)); + + //TODO: selection, projection, everything else + Ok(Rc::new(LogicalPlan::Aggregate { + input: aggregate_input, + group_expr, + aggr_expr, + schema: Arc::new(aggr_schema), + })) + } else { + let projection_input: Rc = match selection_plan { + Some(s) => Rc::new(s), + _ => input.clone(), + }; + + let projection_schema = Arc::new(Schema::new(exprlist_to_fields( + &expr, + input_schema.as_ref(), + ))); + + let projection = LogicalPlan::Projection { + expr: expr, + input: projection_input, + schema: projection_schema.clone(), + }; + + if let &Some(_) = having { + return Err(ExecutionError::General( + "HAVING is not implemented yet".to_string(), + )); + } + + let order_by_plan = match order_by { + &Some(ref order_by_expr) => { + let input_schema = projection.schema(); + let order_by_rex: Result> = order_by_expr + .iter() + .map(|e| { + Ok(Expr::Sort { + expr: Rc::new( + self.sql_to_rex(&e.expr, &input_schema) + .unwrap(), + ), + asc: e.asc, + }) + }) + .collect(); + + LogicalPlan::Sort { + expr: order_by_rex?, + input: Rc::new(projection.clone()), + schema: input_schema.clone(), + } + } + _ => projection, + }; + + Ok(Rc::new(order_by_plan)) + } + } + + &ASTNode::SQLIdentifier(ref id) => { + match self.schema_provider.get_table_meta(id.as_ref()) { + Some(schema) => Ok(Rc::new(LogicalPlan::TableScan { + schema_name: String::from("default"), + table_name: id.clone(), + schema: schema.clone(), + projection: None, + })), + None => Err(ExecutionError::General(format!( + "no schema found for table {}", + id + ))), + } + } + + _ => Err(ExecutionError::ExecutionError(format!( + "sql_to_rel does not support this relation: {:?}", + sql + ))), + } + } + + /// Generate a relational expression from a SQL expression + pub fn sql_to_rex(&self, sql: &ASTNode, schema: &Schema) -> Result { + match sql { + &ASTNode::SQLValue(sqlparser::sqlast::Value::Long(n)) => { + Ok(Expr::Literal(ScalarValue::Int64(n))) + } + &ASTNode::SQLValue(sqlparser::sqlast::Value::Double(n)) => { + Ok(Expr::Literal(ScalarValue::Float64(n))) + } + &ASTNode::SQLValue(sqlparser::sqlast::Value::SingleQuotedString(ref s)) => { + Ok(Expr::Literal(ScalarValue::Utf8(Rc::new(s.clone())))) + } + + &ASTNode::SQLIdentifier(ref 
id) => { + match schema.fields().iter().position(|c| c.name().eq(id)) { + Some(index) => Ok(Expr::Column(index)), + None => Err(ExecutionError::ExecutionError(format!( + "Invalid identifier '{}' for schema {}", + id, + schema.to_string() + ))), + } + } + + &ASTNode::SQLWildcard => { + // schema.columns().iter().enumerate() + // .map(|(i,c)| Ok(Expr::Column(i))).collect() + unimplemented!("SQL wildcard operator is not supported in projection - please use explicit column names") + } + + &ASTNode::SQLCast { + ref expr, + ref data_type, + } => Ok(Expr::Cast { + expr: Rc::new(self.sql_to_rex(&expr, schema)?), + data_type: convert_data_type(data_type)?, + }), + + &ASTNode::SQLIsNull(ref expr) => { + Ok(Expr::IsNull(Rc::new(self.sql_to_rex(expr, schema)?))) + } + + &ASTNode::SQLIsNotNull(ref expr) => { + Ok(Expr::IsNotNull(Rc::new(self.sql_to_rex(expr, schema)?))) + } + + &ASTNode::SQLBinaryExpr { + ref left, + ref op, + ref right, + } => { + let operator = match op { + &SQLOperator::Gt => Operator::Gt, + &SQLOperator::GtEq => Operator::GtEq, + &SQLOperator::Lt => Operator::Lt, + &SQLOperator::LtEq => Operator::LtEq, + &SQLOperator::Eq => Operator::Eq, + &SQLOperator::NotEq => Operator::NotEq, + &SQLOperator::Plus => Operator::Plus, + &SQLOperator::Minus => Operator::Minus, + &SQLOperator::Multiply => Operator::Multiply, + &SQLOperator::Divide => Operator::Divide, + &SQLOperator::Modulus => Operator::Modulus, + &SQLOperator::And => Operator::And, + &SQLOperator::Or => Operator::Or, + &SQLOperator::Not => Operator::Not, + &SQLOperator::Like => Operator::Like, + &SQLOperator::NotLike => Operator::NotLike, + }; + + let left_expr = self.sql_to_rex(&left, &schema)?; + let right_expr = self.sql_to_rex(&right, &schema)?; + let left_type = left_expr.get_type(schema); + let right_type = right_expr.get_type(schema); + + match get_supertype(&left_type, &right_type) { + Some(supertype) => Ok(Expr::BinaryExpr { + left: Rc::new(left_expr.cast_to(&supertype, schema)?), + op: operator, + right: Rc::new(right_expr.cast_to(&supertype, schema)?), + }), + None => { + return Err(ExecutionError::General(format!( + "No common supertype found for binary operator {:?} \ + with input types {:?} and {:?}", + operator, left_type, right_type + ))); + } + } + } + + // &ASTNode::SQLOrderBy { ref expr, asc } => Ok(Expr::Sort { + // expr: Rc::new(self.sql_to_rex(&expr, &schema)?), + // asc, + // }), + &ASTNode::SQLFunction { ref id, ref args } => { + //TODO: fix this hack + match id.to_lowercase().as_ref() { + "min" | "max" | "sum" | "avg" => { + let rex_args = args + .iter() + .map(|a| self.sql_to_rex(a, schema)) + .collect::>>()?; + + // return type is same as the argument type for these aggregate functions + let return_type = rex_args[0].get_type(schema).clone(); + + Ok(Expr::AggregateFunction { + name: id.clone(), + args: rex_args, + return_type, + }) + } + "count" => { + let rex_args = args + .iter() + .map(|a| match a { + // this feels hacky but translate COUNT(1)/COUNT(*) to COUNT(first_column) + ASTNode::SQLValue(sqlparser::sqlast::Value::Long(1)) => { + Ok(Expr::Column(0)) + } + ASTNode::SQLWildcard => Ok(Expr::Column(0)), + _ => self.sql_to_rex(a, schema), + }) + .collect::>>()?; + + Ok(Expr::AggregateFunction { + name: id.clone(), + args: rex_args, + return_type: DataType::UInt64, + }) + } + _ => match self.schema_provider.get_function_meta(id) { + Some(fm) => { + let rex_args = args + .iter() + .map(|a| self.sql_to_rex(a, schema)) + .collect::>>()?; + + let mut safe_args: Vec = vec![]; + for i in 0..rex_args.len() { 
+ safe_args.push( + rex_args[i] + .cast_to(fm.args()[i].data_type(), schema)?, + ); + } + + Ok(Expr::ScalarFunction { + name: id.clone(), + args: safe_args, + return_type: fm.return_type().clone(), + }) + } + _ => Err(ExecutionError::General(format!( + "Invalid function '{}'", + id + ))), + }, + } + } + + _ => Err(ExecutionError::General(format!( + "Unsupported ast node {:?} in sqltorel", + sql + ))), + } + } +} + +/// Convert SQL data type to relational representation of data type +pub fn convert_data_type(sql: &SQLType) -> Result { + match sql { + SQLType::Boolean => Ok(DataType::Boolean), + SQLType::SmallInt => Ok(DataType::Int16), + SQLType::Int => Ok(DataType::Int32), + SQLType::BigInt => Ok(DataType::Int64), + SQLType::Float(_) | SQLType::Real => Ok(DataType::Float64), + SQLType::Double => Ok(DataType::Float64), + SQLType::Char(_) | SQLType::Varchar(_) => Ok(DataType::Utf8), + other => Err(ExecutionError::NotImplemented(format!( + "Unsupported SQL type {:?}", + other + ))), + } +} + +pub fn expr_to_field(e: &Expr, input_schema: &Schema) -> Field { + match e { + Expr::Column(i) => input_schema.fields()[*i].clone(), + Expr::Literal(ref lit) => Field::new("lit", lit.get_datatype(), true), + Expr::ScalarFunction { + ref name, + ref return_type, + .. + } => Field::new(name, return_type.clone(), true), + Expr::AggregateFunction { + ref name, + ref return_type, + .. + } => Field::new(name, return_type.clone(), true), + Expr::Cast { ref data_type, .. } => Field::new("cast", data_type.clone(), true), + Expr::BinaryExpr { + ref left, + ref right, + .. + } => { + let left_type = left.get_type(input_schema); + let right_type = right.get_type(input_schema); + Field::new( + "binary_expr", + get_supertype(&left_type, &right_type).unwrap(), + true, + ) + } + _ => unimplemented!("Cannot determine schema type for expression {:?}", e), + } +} + +pub fn exprlist_to_fields(expr: &Vec, input_schema: &Schema) -> Vec { + expr.iter() + .map(|e| expr_to_field(e, input_schema)) + .collect() +} + +fn collect_expr(e: &Expr, accum: &mut HashSet) { + match e { + Expr::Column(i) => { + accum.insert(*i); + } + Expr::Cast { ref expr, .. } => collect_expr(expr, accum), + Expr::Literal(_) => {} + Expr::IsNotNull(ref expr) => collect_expr(expr, accum), + Expr::IsNull(ref expr) => collect_expr(expr, accum), + Expr::BinaryExpr { + ref left, + ref right, + .. + } => { + collect_expr(left, accum); + collect_expr(right, accum); + } + Expr::AggregateFunction { ref args, .. } => { + args.iter().for_each(|e| collect_expr(e, accum)); + } + Expr::ScalarFunction { ref args, .. } => { + args.iter().for_each(|e| collect_expr(e, accum)); + } + Expr::Sort { ref expr, .. 
} => collect_expr(expr, accum), + } +} + +pub fn push_down_projection( + plan: &Rc, + projection: &HashSet, +) -> Rc { + //println!("push_down_projection() projection={:?}", projection); + match plan.as_ref() { + LogicalPlan::Aggregate { + ref input, + ref group_expr, + ref aggr_expr, + ref schema, + } => { + //TODO: apply projection first + let mut accum: HashSet = HashSet::new(); + group_expr.iter().for_each(|e| collect_expr(e, &mut accum)); + aggr_expr.iter().for_each(|e| collect_expr(e, &mut accum)); + Rc::new(LogicalPlan::Aggregate { + input: push_down_projection(&input, &accum), + group_expr: group_expr.clone(), + aggr_expr: aggr_expr.clone(), + schema: schema.clone(), + }) + } + LogicalPlan::Selection { + ref expr, + ref input, + } => { + let mut accum: HashSet = projection.clone(); + collect_expr(expr, &mut accum); + Rc::new(LogicalPlan::Selection { + expr: expr.clone(), + input: push_down_projection(&input, &accum), + }) + } + LogicalPlan::TableScan { + ref schema_name, + ref table_name, + ref schema, + .. + } => Rc::new(LogicalPlan::TableScan { + schema_name: schema_name.to_string(), + table_name: table_name.to_string(), + schema: schema.clone(), + projection: Some(projection.iter().cloned().collect()), + }), + LogicalPlan::Projection { .. } => plan.clone(), + LogicalPlan::Sort { .. } => plan.clone(), + LogicalPlan::EmptyRelation { .. } => plan.clone(), + } +} + +#[cfg(test)] +mod tests { + + use super::*; + use sqlparser::sqlparser::*; + + #[test] + fn select_no_relation() { + quick_test( + "SELECT 1", + "Projection: Int64(1)\ + \n EmptyRelation", + ); + } + + #[test] + fn select_scalar_func_with_literal_no_relation() { + quick_test( + "SELECT sqrt(9)", + "Projection: sqrt(CAST(Int64(9) AS Float64))\ + \n EmptyRelation", + ); + } + + #[test] + fn select_simple_selection() { + let sql = "SELECT id, first_name, last_name \ + FROM person WHERE state = 'CO'"; + let expected = "Projection: #0, #1, #2\ + \n Selection: #4 Eq Utf8(\"CO\")\ + \n TableScan: person projection=None"; + quick_test(sql, expected); + } + + #[test] + fn select_compound_selection() { + let sql = "SELECT id, first_name, last_name \ + FROM person WHERE state = 'CO' AND age >= 21 AND age <= 65"; + let expected = + "Projection: #0, #1, #2\ + \n Selection: #4 Eq Utf8(\"CO\") And CAST(#3 AS Int64) GtEq Int64(21) And CAST(#3 AS Int64) LtEq Int64(65)\ + \n TableScan: person projection=None"; + quick_test(sql, expected); + } + + #[test] + fn select_all_boolean_operators() { + let sql = "SELECT age, first_name, last_name \ + FROM person \ + WHERE age = 21 \ + AND age != 21 \ + AND age > 21 \ + AND age >= 21 \ + AND age < 65 \ + AND age <= 65"; + let expected = "Projection: #3, #1, #2\ + \n Selection: CAST(#3 AS Int64) Eq Int64(21) \ + And CAST(#3 AS Int64) NotEq Int64(21) \ + And CAST(#3 AS Int64) Gt Int64(21) \ + And CAST(#3 AS Int64) GtEq Int64(21) \ + And CAST(#3 AS Int64) Lt Int64(65) \ + And CAST(#3 AS Int64) LtEq Int64(65)\ + \n TableScan: person projection=None"; + quick_test(sql, expected); + } + + #[test] + fn select_simple_aggregate() { + quick_test( + "SELECT MIN(age) FROM person", + "Aggregate: groupBy=[[]], aggr=[[MIN(#3)]]\ + \n TableScan: person projection=None", + ); + } + + #[test] + fn test_sum_aggregate() { + quick_test( + "SELECT SUM(age) from person", + "Aggregate: groupBy=[[]], aggr=[[SUM(#3)]]\ + \n TableScan: person projection=None", + ); + } + + #[test] + fn select_simple_aggregate_with_groupby() { + quick_test( + "SELECT state, MIN(age), MAX(age) FROM person GROUP BY state", + "Aggregate: 
groupBy=[[#4]], aggr=[[MIN(#3), MAX(#3)]]\ + \n TableScan: person projection=None", + ); + } + + #[test] + fn select_count_one() { + let sql = "SELECT COUNT(1) FROM person"; + let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(#0)]]\ + \n TableScan: person projection=None"; + quick_test(sql, expected); + } + + #[test] + fn select_scalar_func() { + let sql = "SELECT sqrt(age) FROM person"; + let expected = "Projection: sqrt(CAST(#3 AS Float64))\ + \n TableScan: person projection=None"; + quick_test(sql, expected); + } + + #[test] + fn select_order_by() { + let sql = "SELECT id FROM person ORDER BY id"; + let expected = "Sort: #0 ASC\ + \n Projection: #0\ + \n TableScan: person projection=None"; + quick_test(sql, expected); + } + + #[test] + fn select_order_by_desc() { + let sql = "SELECT id FROM person ORDER BY id DESC"; + let expected = "Sort: #0 DESC\ + \n Projection: #0\ + \n TableScan: person projection=None"; + quick_test(sql, expected); + } + + #[test] + fn test_collect_expr() { + let mut accum: HashSet = HashSet::new(); + collect_expr( + &Expr::Cast { + expr: Rc::new(Expr::Column(3)), + data_type: DataType::Float64, + }, + &mut accum, + ); + collect_expr( + &Expr::Cast { + expr: Rc::new(Expr::Column(3)), + data_type: DataType::Float64, + }, + &mut accum, + ); + println!("accum: {:?}", accum); + assert_eq!(1, accum.len()); + assert!(accum.contains(&3)); + } + + /// Create logical plan, write with formatter, compare to expected output + fn quick_test(sql: &str, expected: &str) { + use sqlparser::dialect::*; + let dialect = GenericSqlDialect {}; + let planner = SqlToRel::new(Rc::new(MockSchemaProvider {})); + let ast = Parser::parse_sql(&dialect, sql.to_string()).unwrap(); + let plan = planner.sql_to_rel(&ast).unwrap(); + assert_eq!(expected, format!("{:?}", plan)); + } + + struct MockSchemaProvider {} + + impl SchemaProvider for MockSchemaProvider { + fn get_table_meta(&self, name: &str) -> Option> { + match name { + "person" => Some(Arc::new(Schema::new(vec![ + Field::new("id", DataType::UInt32, false), + Field::new("first_name", DataType::Utf8, false), + Field::new("last_name", DataType::Utf8, false), + Field::new("age", DataType::Int32, false), + Field::new("state", DataType::Utf8, false), + Field::new("salary", DataType::Float64, false), + ]))), + _ => None, + } + } + + fn get_function_meta(&self, name: &str) -> Option> { + match name { + "sqrt" => Some(Arc::new(FunctionMeta::new( + "sqrt".to_string(), + vec![Field::new("n", DataType::Float64, false)], + DataType::Float64, + FunctionType::Scalar, + ))), + _ => None, + } + } + } + +} diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs new file mode 100644 index 0000000000000..bd228087dcc6e --- /dev/null +++ b/rust/datafusion/tests/sql.rs @@ -0,0 +1,191 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::cell::RefCell; +use std::rc::Rc; +use std::sync::Arc; + +extern crate arrow; +extern crate datafusion; + +use arrow::array::*; +use arrow::datatypes::{DataType, Field, Schema}; + +use datafusion::execution::context::ExecutionContext; +use datafusion::execution::datasource::CsvDataSource; +use datafusion::execution::relation::Relation; + +#[test] +fn csv_query_with_predicate() { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx); + let sql = "SELECT c1, c12 FROM aggregate_test_100 WHERE c12 > 0.376 AND c12 < 0.4"; + let actual = execute(&mut ctx, sql); + let expected = "\"e\"\t0.39144436569161134\n\"d\"\t0.38870280983958583\n".to_string(); + assert_eq!(expected, actual); +} + +#[test] +fn csv_query_group_by_int_min_max() { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx); + //TODO add ORDER BY once supported, to make this test determistic + let sql = "SELECT c2, MIN(c12), MAX(c12) FROM aggregate_test_100 GROUP BY c2"; + let actual = execute(&mut ctx, sql); + let expected = "4\t0.02182578039211991\t0.9237877978193884\n2\t0.16301110515739792\t0.991517828651004\n5\t0.01479305307777301\t0.9723580396501548\n3\t0.047343434291126085\t0.9293883502480845\n1\t0.05636955101974106\t0.9965400387585364\n".to_string(); + assert_eq!(expected, actual); +} + +#[test] +fn csv_query_group_by_string_min_max() { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx); + //TODO add ORDER BY once supported, to make this test determistic + let sql = "SELECT c2, MIN(c12), MAX(c12) FROM aggregate_test_100 GROUP BY c1"; + let actual = execute(&mut ctx, sql); + let expected = + "\"d\"\t0.061029375346466685\t0.9748360509016578\n\"c\"\t0.0494924465469434\t0.991517828651004\n\"b\"\t0.04893135681998029\t0.9185813970744787\n\"a\"\t0.02182578039211991\t0.9800193410444061\n\"e\"\t0.01479305307777301\t0.9965400387585364\n".to_string(); + assert_eq!(expected, actual); +} + +#[test] +fn csv_query_cast() { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx); + let sql = "SELECT CAST(c12 AS float) FROM aggregate_test_100 WHERE c12 > 0.376 AND c12 < 0.4"; + let actual = execute(&mut ctx, sql); + let expected = "0.39144436569161134\n0.38870280983958583\n".to_string(); + assert_eq!(expected, actual); +} + +fn aggr_test_schema() -> Arc { + Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Utf8, false), + Field::new("c2", DataType::UInt32, false), + Field::new("c3", DataType::Int8, false), + Field::new("c4", DataType::Int16, false), + Field::new("c5", DataType::Int32, false), + Field::new("c6", DataType::Int64, false), + Field::new("c7", DataType::UInt8, false), + Field::new("c8", DataType::UInt16, false), + Field::new("c9", DataType::UInt32, false), + Field::new("c10", DataType::UInt64, false), + Field::new("c11", DataType::Float32, false), + Field::new("c12", DataType::Float64, false), + Field::new("c13", DataType::Utf8, false), + ])) +} + +fn register_aggregate_csv(ctx: &mut ExecutionContext) { + let schema = aggr_test_schema(); + register_csv( + ctx, + "aggregate_test_100", + "../../testing/data/csv/aggregate_test_100.csv", + &schema, + ); +} + +fn register_csv( + ctx: &mut ExecutionContext, + name: &str, + filename: &str, + schema: &Arc, +) { + let csv_datasource = CsvDataSource::new(filename, schema.clone(), 1024); + ctx.register_datasource(name, Rc::new(RefCell::new(csv_datasource))); +} + +/// Execute query and 
return result set as tab delimited string +fn execute(ctx: &mut ExecutionContext, sql: &str) -> String { + let results = ctx.sql(&sql).unwrap(); + result_str(&results) +} + +fn result_str(results: &Rc>) -> String { + let mut relation = results.borrow_mut(); + let mut str = String::new(); + while let Some(batch) = relation.next().unwrap() { + for row_index in 0..batch.num_rows() { + for column_index in 0..batch.num_columns() { + if column_index > 0 { + str.push_str("\t"); + } + let column = batch.column(column_index); + + match column.data_type() { + DataType::Int8 => { + let array = column.as_any().downcast_ref::().unwrap(); + str.push_str(&format!("{:?}", array.value(row_index))); + } + DataType::Int16 => { + let array = column.as_any().downcast_ref::().unwrap(); + str.push_str(&format!("{:?}", array.value(row_index))); + } + DataType::Int32 => { + let array = column.as_any().downcast_ref::().unwrap(); + str.push_str(&format!("{:?}", array.value(row_index))); + } + DataType::Int64 => { + let array = column.as_any().downcast_ref::().unwrap(); + str.push_str(&format!("{:?}", array.value(row_index))); + } + DataType::UInt8 => { + let array = column.as_any().downcast_ref::().unwrap(); + str.push_str(&format!("{:?}", array.value(row_index))); + } + DataType::UInt16 => { + let array = + column.as_any().downcast_ref::().unwrap(); + str.push_str(&format!("{:?}", array.value(row_index))); + } + DataType::UInt32 => { + let array = + column.as_any().downcast_ref::().unwrap(); + str.push_str(&format!("{:?}", array.value(row_index))); + } + DataType::UInt64 => { + let array = + column.as_any().downcast_ref::().unwrap(); + str.push_str(&format!("{:?}", array.value(row_index))); + } + DataType::Float32 => { + let array = + column.as_any().downcast_ref::().unwrap(); + str.push_str(&format!("{:?}", array.value(row_index))); + } + DataType::Float64 => { + let array = + column.as_any().downcast_ref::().unwrap(); + str.push_str(&format!("{:?}", array.value(row_index))); + } + DataType::Utf8 => { + let array = + column.as_any().downcast_ref::().unwrap(); + let s = + String::from_utf8(array.value(row_index).to_vec()).unwrap(); + + str.push_str(&format!("{:?}", s)); + } + _ => str.push_str("???"), + } + } + str.push_str("\n"); + } + } + str +} diff --git a/testing b/testing index 6ee39a9d17b09..bf0abe442bf7e 160000 --- a/testing +++ b/testing @@ -1 +1 @@ -Subproject commit 6ee39a9d17b0902b7372c22b9ff823304c69d709 +Subproject commit bf0abe442bf7e313380452c8972692940f4e56b6 From 170a7ff9dff8909311eb7849db1cf6a64ac8c4e3 Mon Sep 17 00:00:00 2001 From: "Korn, Uwe" Date: Mon, 4 Feb 2019 13:20:34 -0600 Subject: [PATCH 07/21] ARROW-4471: [C++] Pass AR and RANLIB to all external projects With the latest updates I have problems linking to gbenchmark on OSX, this fixes it. 
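
For context, a minimal sketch of the pattern this change applies (the `my_dep_ep` target and `MY_DEP_SOURCE_URL` variable below are hypothetical; `EP_COMMON_CMAKE_ARGS` is the existing variable in `ThirdpartyToolchain.cmake`):

```cmake
include(ExternalProject)

# Forward the archiver/ranlib chosen for the main build into the common
# external-project arguments, so every external project indexes its static
# libraries with the same toolchain as the parent build.
if (CMAKE_AR)
  set(EP_COMMON_CMAKE_ARGS ${EP_COMMON_CMAKE_ARGS}
      -DCMAKE_AR=${CMAKE_AR})
endif()

if (CMAKE_RANLIB)
  set(EP_COMMON_CMAKE_ARGS ${EP_COMMON_CMAKE_ARGS}
      -DCMAKE_RANLIB=${CMAKE_RANLIB})
endif()

# Any external project that consumes the common args then inherits AR/RANLIB.
ExternalProject_Add(my_dep_ep
  URL ${MY_DEP_SOURCE_URL}
  CMAKE_ARGS ${EP_COMMON_CMAKE_ARGS})
```

A plausible reading of the OSX failure is that an external project otherwise picks up a different `ar`/`ranlib` than the parent build (system tools versus the conda-provided binutils), so passing them explicitly keeps the toolchains consistent.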
Author: Korn, Uwe Closes #3550 from xhochy/ARROW-4471 and squashes the following commits: b0608b0a7 Use linebreaks to separate arguments 9c2071446 ARROW-4471: Pass AR and RANLIB to all external projects --- cpp/CMakeLists.txt | 2 ++ cpp/cmake_modules/ThirdpartyToolchain.cmake | 10 ++++++++++ 2 files changed, 12 insertions(+) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index e0dbcd305e92e..707514a34cb84 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -58,6 +58,8 @@ include(FindPkgConfig) include(GNUInstallDirs) +cmake_policy(SET CMP0025 NEW) + # Compatibility with CMake 3.1 if(POLICY CMP0054) # http://www.cmake.org/cmake/help/v3.1/policy/CMP0054.html diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake index ff2252528fdf3..5f3b54ca3d92f 100644 --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake @@ -355,6 +355,16 @@ set(EP_COMMON_CMAKE_ARGS ${EP_COMMON_CMAKE_ARGS} -DCMAKE_CXX_FLAGS=${EP_CXX_FLAGS} -DCMAKE_CXX_FLAGS_${UPPERCASE_BUILD_TYPE}=${EP_CXX_FLAGS}) +if (CMAKE_AR) + set(EP_COMMON_CMAKE_ARGS ${EP_COMMON_CMAKE_ARGS} + -DCMAKE_AR=${CMAKE_AR}) +endif() + +if (CMAKE_RANLIB) + set(EP_COMMON_CMAKE_ARGS ${EP_COMMON_CMAKE_ARGS} + -DCMAKE_RANLIB=${CMAKE_RANLIB}) +endif() + if (NOT ARROW_VERBOSE_THIRDPARTY_BUILD) set(EP_LOG_OPTIONS LOG_CONFIGURE 1 From 2e161228bcbc77ad5395c51e8fdbb7f410c199a2 Mon Sep 17 00:00:00 2001 From: Wes McKinney Date: Mon, 4 Feb 2019 13:53:48 -0600 Subject: [PATCH 08/21] ARROW-4469: [CI] Pin conda-forge binutils version to 2.31 for now Seeing if this fixes the failures we are seeing. The conda-forge binutils package was just updated 2 days ago Author: Wes McKinney Closes #3554 from wesm/ARROW-4469 and squashes the following commits: dfa1b3fde Pin conda-forge binutils version to 2.31 for now --- ci/travis_env_common.sh | 1 + ci/travis_install_toolchain.sh | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/ci/travis_env_common.sh b/ci/travis_env_common.sh index 5f70535b42c6c..90d956b4d82ef 100755 --- a/ci/travis_env_common.sh +++ b/ci/travis_env_common.sh @@ -21,6 +21,7 @@ export NODE_NO_WARNINGS=1 export MINICONDA=$HOME/miniconda export CONDA_PKGS_DIRS=$HOME/.conda_packages +export CONDA_BINUTILS_VERSION=2.31 export ARROW_CPP_DIR=$TRAVIS_BUILD_DIR/cpp export ARROW_PYTHON_DIR=$TRAVIS_BUILD_DIR/python diff --git a/ci/travis_install_toolchain.sh b/ci/travis_install_toolchain.sh index 7ba1f79e009b2..506a04fa9d24a 100755 --- a/ci/travis_install_toolchain.sh +++ b/ci/travis_install_toolchain.sh @@ -34,7 +34,7 @@ if [ ! -e $CPP_TOOLCHAIN ]; then CONDA_LABEL=" -c conda-forge/label/cf201901" else # Use newer binutils when linking against conda-provided libraries - CONDA_PACKAGES="$CONDA_PACKAGES binutils" + CONDA_PACKAGES="$CONDA_PACKAGES binutils=$CONDA_BINUTILS_VERSION" fi fi From b0ac2342bfeda278da5caba6c50d1b4f7a5ae432 Mon Sep 17 00:00:00 2001 From: ptaylor Date: Mon, 4 Feb 2019 15:13:19 -0600 Subject: [PATCH 09/21] ARROW-4442: [JS] Add explicit type annotation to Chunked typeId getter Closes https://issues.apache.org/jira/browse/ARROW-4442 Typescript is generating an overly broad type for the `typeId` property of the ChunkedVector class, leading to a type mismatch and failure to infer `Chunked` is a `Vector`: ```ts let col: Vector; col = new Chunked(new Utf8()); ^ /* Argument of type 'Chunked' is not assignable to parameter of type 'Vector'. Type 'Chunked' is not assignable to type 'Vector'. Types of property 'typeId' are incompatible. 
Type 'Type' is not assignable to type 'Type.Utf8'. */ ``` The type being generated currently is: ```ts readonly typeId: import("../enum").Type; ``` but it should be: ```ts readonly typeId: T['TType']; ``` The fix is to add an explicit return annotation to the Chunked `typeId` getter. Unfortunately this only affects the generated typings (`.d.ts` files) and not the library source, so it's difficult to test. We can look into whether there are any flags to trigger stricter type checking of the compiled code in the unit tests, but I don't know any off the top of my head. Author: ptaylor Closes #3538 from trxcllnt/js/add-chunked-typeId-annoation and squashes the following commits: 077f38343 add explicit type annotation to Chunked typeId getter --- js/src/vector/chunked.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/src/vector/chunked.ts b/js/src/vector/chunked.ts index 0d26dec8f8ad1..0882f7f225dd4 100644 --- a/js/src/vector/chunked.ts +++ b/js/src/vector/chunked.ts @@ -70,7 +70,7 @@ export class Chunked public get type() { return this._type; } public get length() { return this._length; } public get chunks() { return this._chunks; } - public get typeId() { return this._type.typeId; } + public get typeId(): T['TType'] { return this._type.typeId; } public get data(): Data { return this._chunks[0] ? this._chunks[0].data : null; } From 53bf5bfe85cde60eaf0b1f276ac7e11e87967c8a Mon Sep 17 00:00:00 2001 From: Wes McKinney Date: Mon, 4 Feb 2019 15:14:24 -0600 Subject: [PATCH 10/21] ARROW-4440: [C++] Revert recent changes to flatbuffers EP causing flakiness Author: Wes McKinney Closes #3552 from wesm/ARROW-4440 and squashes the following commits: f3d5d2a28 Fix brotli header copy step 908a9e7b8 Revert use of common thirdparty install directory 371ee0c44 Use EP_COMMON_CMAKE_ARGS 282e16489 Revert recent changes to flatbuffers EP causing flakiness --- cpp/cmake_modules/ThirdpartyToolchain.cmake | 84 ++++++++++----------- cpp/src/arrow/memory_pool.cc | 2 +- 2 files changed, 41 insertions(+), 45 deletions(-) diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake index 5f3b54ca3d92f..5ee0ddfd55914 100644 --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake @@ -17,8 +17,6 @@ add_custom_target(toolchain) -set(THIRDPARTY_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/arrow_thirdparty") - # ---------------------------------------------------------------------- # Toolchain linkage options @@ -555,17 +553,18 @@ include_directories(SYSTEM ${Boost_INCLUDE_DIR}) # Google double-conversion if("${DOUBLE_CONVERSION_HOME}" STREQUAL "") - set(DOUBLE_CONVERSION_HOME "${THIRDPARTY_PREFIX}") - set(DOUBLE_CONVERSION_INCLUDE_DIR "${THIRDPARTY_PREFIX}/include") - set(DOUBLE_CONVERSION_STATIC_LIB "${THIRDPARTY_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}double-conversion${CMAKE_STATIC_LIBRARY_SUFFIX}") + set(DOUBLE_CONVERSION_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/double-conversion_ep/src/double-conversion_ep") + set(DOUBLE_CONVERSION_HOME "${DOUBLE_CONVERSION_PREFIX}") + set(DOUBLE_CONVERSION_INCLUDE_DIR "${DOUBLE_CONVERSION_PREFIX}/include") + set(DOUBLE_CONVERSION_STATIC_LIB "${DOUBLE_CONVERSION_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}double-conversion${CMAKE_STATIC_LIBRARY_SUFFIX}") set(DOUBLE_CONVERSION_CMAKE_ARGS ${EP_COMMON_CMAKE_ARGS} - "-DCMAKE_INSTALL_PREFIX=${THIRDPARTY_PREFIX}") + "-DCMAKE_INSTALL_PREFIX=${DOUBLE_CONVERSION_PREFIX}") ExternalProject_Add(double-conversion_ep ${EP_LOG_OPTIONS} - INSTALL_DIR 
${THIRDPARTY_PREFIX} + INSTALL_DIR ${DOUBLE_CONVERSION_PREFIX} URL ${DOUBLE_CONVERSION_SOURCE_URL} CMAKE_ARGS ${DOUBLE_CONVERSION_CMAKE_ARGS} BUILD_BYPRODUCTS "${DOUBLE_CONVERSION_STATIC_LIB}") @@ -608,7 +607,8 @@ if(ARROW_NEED_GFLAGS) # gflags (formerly Googleflags) command line parsing if("${GFLAGS_HOME}" STREQUAL "") set(GFLAGS_CMAKE_CXX_FLAGS ${EP_CXX_FLAGS}) - set(GFLAGS_PREFIX "${THIRDPARTY_PREFIX}") + + set(GFLAGS_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/gflags_ep-prefix/src/gflags_ep") set(GFLAGS_HOME "${GFLAGS_PREFIX}") set(GFLAGS_INCLUDE_DIR "${GFLAGS_PREFIX}/include") if(MSVC) @@ -666,7 +666,7 @@ if(ARROW_BUILD_TESTS OR ARROW_BUILD_BENCHMARKS) -Wno-ignored-attributes) endif() - set(GTEST_PREFIX "${THIRDPARTY_PREFIX}") + set(GTEST_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/googletest_ep-prefix/src/googletest_ep") set(GTEST_INCLUDE_DIR "${GTEST_PREFIX}/include") set(GTEST_STATIC_LIB "${GTEST_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}gtest${CMAKE_STATIC_LIBRARY_SUFFIX}") @@ -758,7 +758,7 @@ if(ARROW_BUILD_BENCHMARKS) set(GBENCHMARK_CMAKE_CXX_FLAGS "${GBENCHMARK_CMAKE_CXX_FLAGS} -stdlib=libc++") endif() - set(GBENCHMARK_PREFIX "${THIRDPARTY_PREFIX}") + set(GBENCHMARK_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/gbenchmark_ep/src/gbenchmark_ep-install") set(GBENCHMARK_INCLUDE_DIR "${GBENCHMARK_PREFIX}/include") set(GBENCHMARK_STATIC_LIB "${GBENCHMARK_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}benchmark${CMAKE_STATIC_LIBRARY_SUFFIX}") set(GBENCHMARK_VENDORED 1) @@ -794,12 +794,13 @@ endif() if (ARROW_WITH_RAPIDJSON) # RapidJSON, header only dependency if("${RAPIDJSON_HOME}" STREQUAL "") - set(RAPIDJSON_HOME "${THIRDPARTY_PREFIX}") + set(RAPIDJSON_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/rapidjson_ep/src/rapidjson_ep-install") + set(RAPIDJSON_HOME "${RAPIDJSON_PREFIX}") set(RAPIDJSON_CMAKE_ARGS -DRAPIDJSON_BUILD_DOC=OFF -DRAPIDJSON_BUILD_EXAMPLES=OFF -DRAPIDJSON_BUILD_TESTS=OFF - "-DCMAKE_INSTALL_PREFIX=${THIRDPARTY_PREFIX}") + "-DCMAKE_INSTALL_PREFIX=${RAPIDJSON_PREFIX}") ExternalProject_Add(rapidjson_ep ${EP_LOG_OPTIONS} @@ -820,7 +821,7 @@ if (ARROW_WITH_RAPIDJSON) ## Flatbuffers if("${FLATBUFFERS_HOME}" STREQUAL "") - set(FLATBUFFERS_PREFIX "${THIRDPARTY_PREFIX}") + set(FLATBUFFERS_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/flatbuffers_ep-prefix/src/flatbuffers_ep-install") if (MSVC) set(FLATBUFFERS_CMAKE_CXX_FLAGS /EHsc) else() @@ -830,12 +831,11 @@ if (ARROW_WITH_RAPIDJSON) ExternalProject_Add(flatbuffers_ep URL ${FLATBUFFERS_SOURCE_URL} CMAKE_ARGS + ${EP_COMMON_CMAKE_ARGS} + -DCMAKE_BUILD_TYPE=RELEASE "-DCMAKE_CXX_FLAGS=${FLATBUFFERS_CMAKE_CXX_FLAGS}" "-DCMAKE_INSTALL_PREFIX:PATH=${FLATBUFFERS_PREFIX}" "-DFLATBUFFERS_BUILD_TESTS=OFF" - "-DCMAKE_BUILD_TYPE=RELEASE" - "-DCMAKE_CXX_FLAGS_${UPPERCASE_BUILD_TYPE}=${EP_CXX_FLAGS}" - "-DCMAKE_C_FLAGS_${UPPERCASE_BUILD_TYPE}=${EP_C_FLAGS}" ${EP_LOG_OPTIONS}) set(FLATBUFFERS_INCLUDE_DIR "${FLATBUFFERS_PREFIX}/include") @@ -866,7 +866,7 @@ if (ARROW_JEMALLOC) # find_package(jemalloc) set(ARROW_JEMALLOC_USE_SHARED OFF) - set(JEMALLOC_PREFIX "${THIRDPARTY_PREFIX}") + set(JEMALLOC_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/jemalloc_ep-prefix/src/jemalloc_ep/dist/") set(JEMALLOC_HOME "${JEMALLOC_PREFIX}") set(JEMALLOC_INCLUDE_DIR "${JEMALLOC_PREFIX}/include") set(JEMALLOC_SHARED_LIB "${JEMALLOC_PREFIX}/lib/libjemalloc${CMAKE_SHARED_LIBRARY_SUFFIX}") @@ -885,7 +885,7 @@ if (ARROW_JEMALLOC) # Don't use the include directory directly so that we can point to a path # that is unique to our codebase. 
- include_directories(SYSTEM "${CMAKE_CURRENT_BINARY_DIR}") + include_directories(SYSTEM "${CMAKE_CURRENT_BINARY_DIR}/jemalloc_ep-prefix/src/") ADD_THIRDPARTY_LIB(jemalloc STATIC_LIB ${JEMALLOC_STATIC_LIB} @@ -947,7 +947,7 @@ if (ARROW_WITH_ZLIB) ADD_THIRDPARTY_LIB(zlib SHARED_LIB ${ZLIB_SHARED_LIB}) set(ZLIB_LIBRARY zlib_shared) else() - set(ZLIB_PREFIX "${THIRDPARTY_PREFIX}") + set(ZLIB_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/zlib_ep/src/zlib_ep-install") set(ZLIB_HOME "${ZLIB_PREFIX}") set(ZLIB_INCLUDE_DIR "${ZLIB_PREFIX}/include") if (MSVC) @@ -983,7 +983,7 @@ if (ARROW_WITH_SNAPPY) # Snappy if("${SNAPPY_HOME}" STREQUAL "") - set(SNAPPY_PREFIX "${THIRDPARTY_PREFIX}") + set(SNAPPY_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/snappy_ep/src/snappy_ep-install") set(SNAPPY_HOME "${SNAPPY_PREFIX}") set(SNAPPY_INCLUDE_DIR "${SNAPPY_PREFIX}/include") if (MSVC) @@ -1051,7 +1051,7 @@ if (ARROW_WITH_BROTLI) # Brotli if("${BROTLI_HOME}" STREQUAL "") - set(BROTLI_PREFIX "${THIRDPARTY_PREFIX}") + set(BROTLI_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/brotli_ep/src/brotli_ep-install") set(BROTLI_HOME "${BROTLI_PREFIX}") set(BROTLI_INCLUDE_DIR "${BROTLI_PREFIX}/include") if (MSVC) @@ -1078,7 +1078,7 @@ if (ARROW_WITH_BROTLI) ExternalProject_Get_Property(brotli_ep SOURCE_DIR) ExternalProject_Add_Step(brotli_ep headers_copy - COMMAND xcopy /E /I include ..\\..\\..\\arrow_thirdparty\\include /Y + COMMAND xcopy /E /I include ..\\..\\..\\brotli_ep\\src\\brotli_ep-install\\include /Y DEPENDEES build WORKING_DIRECTORY ${SOURCE_DIR}) endif() @@ -1176,7 +1176,7 @@ if (ARROW_WITH_ZSTD) # ZSTD if("${ZSTD_HOME}" STREQUAL "") - set(ZSTD_PREFIX "${THIRDPARTY_PREFIX}") + set(ZSTD_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/zstd_ep-install") set(ZSTD_INCLUDE_DIR "${ZSTD_PREFIX}/include") set(ZSTD_CMAKE_ARGS -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} @@ -1236,7 +1236,7 @@ endif() if (ARROW_GANDIVA) # re2 if ("${RE2_HOME}" STREQUAL "") - set (RE2_PREFIX "${THIRDPARTY_PREFIX}") + set (RE2_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/re2_ep-install") set (RE2_HOME "${RE2_PREFIX}") set (RE2_INCLUDE_DIR "${RE2_PREFIX}/include") set (RE2_STATIC_LIB "${RE2_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}re2${CMAKE_STATIC_LIBRARY_SUFFIX}") @@ -1277,7 +1277,7 @@ endif () if (ARROW_WITH_PROTOBUF) # protobuf if ("${PROTOBUF_HOME}" STREQUAL "") - set (PROTOBUF_PREFIX "${THIRDPARTY_PREFIX}") + set (PROTOBUF_PREFIX "${THIRDPARTY_DIR}/protobuf_ep-install") set (PROTOBUF_HOME "${PROTOBUF_PREFIX}") set (PROTOBUF_INCLUDE_DIR "${PROTOBUF_PREFIX}/include") set (PROTOBUF_STATIC_LIB "${PROTOBUF_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}protobuf${CMAKE_STATIC_LIBRARY_SUFFIX}") @@ -1325,7 +1325,7 @@ endif() if (ARROW_WITH_GRPC) if ("${CARES_HOME}" STREQUAL "") set(CARES_VENDORED 1) - set(CARES_PREFIX "${THIRDPARTY_PREFIX}") + set(CARES_PREFIX "${THIRDPARTY_DIR}/cares_ep-install") set(CARES_HOME "${CARES_PREFIX}") set(CARES_INCLUDE_DIR "${CARES_PREFIX}/include") @@ -1361,7 +1361,7 @@ if (ARROW_WITH_GRPC) if ("${GRPC_HOME}" STREQUAL "") set(GRPC_VENDORED 1) set(GRPC_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/grpc_ep-prefix/src/grpc_ep-build") - set(GRPC_PREFIX "${THIRDPARTY_PREFIX}") + set(GRPC_PREFIX "${THIRDPARTY_DIR}/grpc_ep-install") set(GRPC_HOME "${GRPC_PREFIX}") set(GRPC_INCLUDE_DIR "${GRPC_PREFIX}/include") set(GRPC_CMAKE_ARGS ${EP_COMMON_CMAKE_ARGS} @@ -1374,32 +1374,28 @@ if (ARROW_WITH_GRPC) set(GRPC_STATIC_LIBRARY_ADDRESS_SORTING "${GRPC_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}address_sorting${CMAKE_STATIC_LIBRARY_SUFFIX}") set(GRPC_CPP_PLUGIN "${GRPC_PREFIX}/bin/grpc_cpp_plugin") 
- set(GRPC_CMAKE_PREFIX "${THIRDPARTY_PREFIX}") + set(GRPC_CMAKE_PREFIX) add_custom_target(grpc_dependencies) if (CARES_VENDORED) add_dependencies(grpc_dependencies cares_ep) - else() - set(GRPC_CMAKE_PREFIX "${GRPC_CMAKE_PREFIX};${CARES_HOME}") endif() if (GFLAGS_VENDORED) add_dependencies(grpc_dependencies gflags_ep) - else() - set(GRPC_CMAKE_PREFIX "${GRPC_CMAKE_PREFIX};${GFLAGS_HOME}") endif() if (PROTOBUF_VENDORED) add_dependencies(grpc_dependencies protobuf_ep) - else() - set(GRPC_CMAKE_PREFIX "${GRPC_CMAKE_PREFIX};${PROTOBUF_HOME}") endif() + set(GRPC_CMAKE_PREFIX "${GRPC_CMAKE_PREFIX};${CARES_HOME}") + set(GRPC_CMAKE_PREFIX "${GRPC_CMAKE_PREFIX};${GFLAGS_HOME}") + set(GRPC_CMAKE_PREFIX "${GRPC_CMAKE_PREFIX};${PROTOBUF_HOME}") + # ZLIB is never vendored - if(NOT "${ZLIB_HOME}" STREQUAL "") - set(GRPC_CMAKE_PREFIX "${GRPC_CMAKE_PREFIX};${ZLIB_HOME}") - endif() + set(GRPC_CMAKE_PREFIX "${GRPC_CMAKE_PREFIX};${ZLIB_HOME}") if (RAPIDJSON_VENDORED) add_dependencies(grpc_dependencies rapidjson_ep) @@ -1479,7 +1475,7 @@ endif() if (ARROW_ORC) # orc if ("${ORC_HOME}" STREQUAL "") - set(ORC_PREFIX "${THIRDPARTY_PREFIX}") + set(ORC_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/orc_ep-install") set(ORC_HOME "${ORC_PREFIX}") set(ORC_INCLUDE_DIR "${ORC_PREFIX}/include") set(ORC_STATIC_LIB "${ORC_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}orc${CMAKE_STATIC_LIBRARY_SUFFIX}") @@ -1554,7 +1550,7 @@ if (ARROW_WITH_THRIFT) find_package(Thrift) if (NOT THRIFT_FOUND) - set(THRIFT_PREFIX "${THIRDPARTY_PREFIX}") + set(THRIFT_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/thrift_ep/src/thrift_ep-install") set(THRIFT_HOME "${THRIFT_PREFIX}") set(THRIFT_INCLUDE_DIR "${THRIFT_PREFIX}/include") set(THRIFT_COMPILER "${THRIFT_PREFIX}/bin/thrift") @@ -1608,7 +1604,7 @@ if (NOT THRIFT_FOUND) if (MSVC) set(WINFLEXBISON_VERSION 2.4.9) - set(WINFLEXBISON_PREFIX "${THIRDPARTY_PREFIX}") + set(WINFLEXBISON_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/winflexbison_ep/src/winflexbison_ep-install") ExternalProject_Add(winflexbison_ep URL https://github.com/lexxmark/winflexbison/releases/download/v.${WINFLEXBISON_VERSION}/win_flex_bison-${WINFLEXBISON_VERSION}.zip URL_HASH MD5=a2e979ea9928fbf8567e995e9c0df765 @@ -1688,9 +1684,9 @@ endif() # ARROW_HIVESERVER2 if (ARROW_USE_GLOG) if("${GLOG_HOME}" STREQUAL "") - set(GLOG_PREFIX "${THIRDPARTY_PREFIX}") - set(GLOG_INCLUDE_DIR "${GLOG_PREFIX}/include") - set(GLOG_STATIC_LIB "${GLOG_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}glog${CMAKE_STATIC_LIBRARY_SUFFIX}") + set(GLOG_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/glog_ep-prefix/src/glog_ep") + set(GLOG_INCLUDE_DIR "${GLOG_BUILD_DIR}/include") + set(GLOG_STATIC_LIB "${GLOG_BUILD_DIR}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}glog${CMAKE_STATIC_LIBRARY_SUFFIX}") set(GLOG_CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC") set(GLOG_CMAKE_C_FLAGS "${EP_C_FLAGS} -fPIC") if (Threads::Threads) @@ -1706,7 +1702,7 @@ if (ARROW_USE_GLOG) endif() set(GLOG_CMAKE_ARGS ${EP_COMMON_CMAKE_ARGS} - "-DCMAKE_INSTALL_PREFIX=${GLOG_PREFIX}" + "-DCMAKE_INSTALL_PREFIX=${GLOG_BUILD_DIR}" -DBUILD_SHARED_LIBS=OFF -DBUILD_TESTING=OFF -DWITH_GFLAGS=OFF diff --git a/cpp/src/arrow/memory_pool.cc b/cpp/src/arrow/memory_pool.cc index 103771bf527a7..3e0366a19da41 100644 --- a/cpp/src/arrow/memory_pool.cc +++ b/cpp/src/arrow/memory_pool.cc @@ -32,7 +32,7 @@ // Needed to support jemalloc 3 and 4 #define JEMALLOC_MANGLE // Explicitly link to our version of jemalloc -#include "arrow_thirdparty/include/jemalloc/jemalloc.h" +#include "jemalloc_ep/dist/include/jemalloc/jemalloc.h" #endif namespace arrow { From 
a6ae3486ed7f4022b663e077757d6364fc701a3e Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 4 Feb 2019 16:47:22 -0800 Subject: [PATCH 11/21] ARROW-4475: [Python] Fix recursive serialization of self-containing objects This is a regression from https://github.com/apache/arrow/pull/3423, the recursion depth was incremented for arrays and dicts but not for lists, tuples and sets. I also added a regression test for this. Author: Philipp Moritz Closes #3556 from pcmoritz/recursive-serialization and squashes the following commits: f83cff99 fix 903788ea fix serialization of objects that contain themselves --- cpp/src/arrow/python/serialize.cc | 6 ++-- python/pyarrow/tests/test_serialization.py | 36 ++++++++++++++++++++++ 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/cpp/src/arrow/python/serialize.cc b/cpp/src/arrow/python/serialize.cc index 4dd4c04a6ccb5..3ccdfc8eee5b2 100644 --- a/cpp/src/arrow/python/serialize.cc +++ b/cpp/src/arrow/python/serialize.cc @@ -203,19 +203,19 @@ class SequenceBuilder { Status AppendList(PyObject* context, PyObject* list, int32_t recursion_depth, SerializedPyObject* blobs_out) { return AppendSequence(context, list, PythonType::LIST, lists_, list_values_, - recursion_depth, blobs_out); + recursion_depth + 1, blobs_out); } Status AppendTuple(PyObject* context, PyObject* tuple, int32_t recursion_depth, SerializedPyObject* blobs_out) { return AppendSequence(context, tuple, PythonType::TUPLE, tuples_, tuple_values_, - recursion_depth, blobs_out); + recursion_depth + 1, blobs_out); } Status AppendSet(PyObject* context, PyObject* set, int32_t recursion_depth, SerializedPyObject* blobs_out) { return AppendSequence(context, set, PythonType::SET, sets_, set_values_, - recursion_depth, blobs_out); + recursion_depth + 1, blobs_out); } Status AppendDict(PyObject* context, PyObject* dict, int32_t recursion_depth, diff --git a/python/pyarrow/tests/test_serialization.py b/python/pyarrow/tests/test_serialization.py index 5f76ff62d200d..c37e8b6a7815a 100644 --- a/python/pyarrow/tests/test_serialization.py +++ b/python/pyarrow/tests/test_serialization.py @@ -814,3 +814,39 @@ def test_serialization_determinism(): buf1 = pa.serialize(obj).to_buffer() buf2 = pa.serialize(obj).to_buffer() assert buf1.to_pybytes() == buf2.to_pybytes() + + +def test_serialize_recursive_objects(): + class ClassA(object): + pass + + # Make a list that contains itself. + lst = [] + lst.append(lst) + + # Make an object that contains itself as a field. + a1 = ClassA() + a1.field = a1 + + # Make two objects that contain each other as fields. + a2 = ClassA() + a3 = ClassA() + a2.field = a3 + a3.field = a2 + + # Make a dictionary that contains itself. + d1 = {} + d1["key"] = d1 + + # Make a numpy array that contains itself. + arr = np.array([None], dtype=object) + arr[0] = arr + + # Create a list of recursive objects. + recursive_objects = [lst, a1, a2, a3, d1, arr] + + # Check that exceptions are thrown when we serialize the recursive + # objects. + for obj in recursive_objects: + with pytest.raises(Exception): + pa.serialize(obj).deserialize() From 0c55b25c84119af59320eab0b0625da9ce987294 Mon Sep 17 00:00:00 2001 From: Anurag Khandelwal Date: Mon, 4 Feb 2019 18:24:13 -0800 Subject: [PATCH 12/21] ARROW-4294: [C++] [Plasma] Add support for evicting Plasma objects to external store https://issues.apache.org/jira/browse/ARROW-4294 Note: this PR was previously at https://github.com/apache/arrow/pull/3432, which was closed since its commit history was broken. 
Currently, when Plasma needs storage space for additional objects, it evicts objects by deleting them from the Plasma store. This is a problem when it isn't possible to reconstruct the object or reconstructing it is expensive. This patch adds support for a pluggable external store that Plasma can evict objects to when it runs out of memory. Author: Anurag Khandelwal Author: Philipp Moritz Closes #3482 from anuragkh/plasma_evict_to_external_store and squashes the following commits: 631671561 remove external store worker, simplify interface 6fbc55b08 Revert "Add an eviction buffer to allow asynchronous evictions" 4f2c02ce3 Revert "Minor fix" 1bc1dbed4 Revert "format fix" 7b662bee4 Revert "Remove timeout for external store test tearDown" 25663df30 Remove timeout for external store test tearDown 7945cc951 format fix 0d7263936 Minor fix 957efb5f0 Add an eviction buffer to allow asynchronous evictions 896d895bd Fixes 7ae486794 Merge branch 'master' into plasma_evict_to_external_store 1af2f8bce Fix cpplint issues 04e173085 Merge branch 'master' into plasma_evict_to_external_store 301e575ea Fix uses of ARROW_CHECK_OK/ARROW_CHECK 69a56abcc Fix documentation errrors c19c5767d Add documentation for notify flag f3fad8086 Fix external store worker intialization 9081596c4 Clean up formatting issues f5cc95c72 Add lint exclusion for external_store_worker, since it uses mutex ffd1f0e6c Extend plasma eviction changes to python module 8afc9fb2f Kill only the plasma_store_server that we started be315677b Add test for testing evictions/unevictions a43445aee Update serialization test 58a995318 Add support for evicting/un-evicting Plasma objects to/from external store --- cpp/src/plasma/CMakeLists.txt | 7 +- cpp/src/plasma/common.h | 4 +- cpp/src/plasma/external_store.cc | 63 ++++++ cpp/src/plasma/external_store.h | 123 ++++++++++++ cpp/src/plasma/hash_table_store.cc | 58 ++++++ cpp/src/plasma/hash_table_store.h | 53 +++++ cpp/src/plasma/store.cc | 210 +++++++++++++++----- cpp/src/plasma/store.h | 23 ++- cpp/src/plasma/test/client_tests.cc | 15 +- cpp/src/plasma/test/external_store_tests.cc | 139 +++++++++++++ python/pyarrow/_plasma.pyx | 1 + python/pyarrow/plasma.py | 6 +- python/pyarrow/tests/test_plasma.py | 58 ++++++ 13 files changed, 692 insertions(+), 68 deletions(-) create mode 100644 cpp/src/plasma/external_store.cc create mode 100644 cpp/src/plasma/external_store.h create mode 100644 cpp/src/plasma/hash_table_store.cc create mode 100644 cpp/src/plasma/hash_table_store.h create mode 100644 cpp/src/plasma/test/external_store_tests.cc diff --git a/cpp/src/plasma/CMakeLists.txt b/cpp/src/plasma/CMakeLists.txt index 53af8c531aad8..fd25aef11297d 100644 --- a/cpp/src/plasma/CMakeLists.txt +++ b/cpp/src/plasma/CMakeLists.txt @@ -125,9 +125,11 @@ if ("${COMPILER_FAMILY}" STREQUAL "gcc") " -Wno-conversion") endif() +list(APPEND PLASMA_EXTERNAL_STORE_SOURCES "external_store.cc" "hash_table_store.cc") + # We use static libraries for the plasma_store_server executable so that it can # be copied around and used in different locations. 
-add_executable(plasma_store_server store.cc) +add_executable(plasma_store_server ${PLASMA_EXTERNAL_STORE_SOURCES} store.cc) target_link_libraries(plasma_store_server plasma_static ${PLASMA_STATIC_LINK_LIBS}) add_dependencies(plasma plasma_store_server) @@ -214,3 +216,6 @@ ADD_PLASMA_TEST(test/serialization_tests ADD_PLASMA_TEST(test/client_tests EXTRA_LINK_LIBS plasma_shared ${PLASMA_LINK_LIBS} EXTRA_DEPENDENCIES plasma_store_server) +ADD_PLASMA_TEST(test/external_store_tests + EXTRA_LINK_LIBS plasma_shared ${PLASMA_LINK_LIBS} + EXTRA_DEPENDENCIES plasma_store_server) diff --git a/cpp/src/plasma/common.h b/cpp/src/plasma/common.h index dfbd90c3aa553..6f4cef5becb62 100644 --- a/cpp/src/plasma/common.h +++ b/cpp/src/plasma/common.h @@ -69,7 +69,9 @@ enum class ObjectState : int { /// Object was created but not sealed in the local Plasma Store. PLASMA_CREATED = 1, /// Object is sealed and stored in the local Plasma Store. - PLASMA_SEALED + PLASMA_SEALED = 2, + /// Object is evicted to external store. + PLASMA_EVICTED = 3, }; namespace internal { diff --git a/cpp/src/plasma/external_store.cc b/cpp/src/plasma/external_store.cc new file mode 100644 index 0000000000000..8cfbad179ba61 --- /dev/null +++ b/cpp/src/plasma/external_store.cc @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include + +#include "arrow/util/memory.h" + +#include "plasma/external_store.h" + +namespace plasma { + +Status ExternalStores::ExtractStoreName(const std::string& endpoint, + std::string* store_name) { + size_t off = endpoint.find_first_of(':'); + if (off == std::string::npos) { + return Status::Invalid("Malformed endpoint " + endpoint); + } + *store_name = endpoint.substr(0, off); + return Status::OK(); +} + +void ExternalStores::RegisterStore(const std::string& store_name, + std::shared_ptr store) { + Stores().insert({store_name, store}); +} + +void ExternalStores::DeregisterStore(const std::string& store_name) { + auto it = Stores().find(store_name); + if (it == Stores().end()) { + return; + } + Stores().erase(it); +} + +std::shared_ptr ExternalStores::GetStore(const std::string& store_name) { + auto it = Stores().find(store_name); + if (it == Stores().end()) { + return nullptr; + } + return it->second; +} + +ExternalStores::StoreMap& ExternalStores::Stores() { + static auto* external_stores = new StoreMap(); + return *external_stores; +} + +} // namespace plasma diff --git a/cpp/src/plasma/external_store.h b/cpp/src/plasma/external_store.h new file mode 100644 index 0000000000000..feca466587a20 --- /dev/null +++ b/cpp/src/plasma/external_store.h @@ -0,0 +1,123 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef EXTERNAL_STORE_H +#define EXTERNAL_STORE_H + +#include +#include +#include +#include + +#include "plasma/client.h" + +namespace plasma { + +// ==== The external store ==== +// +// This file contains declaration for all functions that need to be implemented +// for an external storage service so that objects evicted from Plasma store +// can be written to it. + +class ExternalStore { + public: + /// Default constructor. + ExternalStore() = default; + + /// Virtual destructor. + virtual ~ExternalStore() = default; + + /// Connect to the local plasma store. Return the resulting connection. + /// + /// \param endpoint The name of the endpoint to connect to the external + /// storage service. While the formatting of the endpoint name is + /// specific to the implementation of the external store, it always + /// starts with {store-name}://, where {store-name} is the name of the + /// external store. + /// + /// \return The return status. + virtual Status Connect(const std::string& endpoint) = 0; + + /// This method will be called whenever an object in the Plasma store needs + /// to be evicted to the external store. + /// + /// This API is experimental and might change in the future. + /// + /// \param ids The IDs of the objects to put. + /// \param data The object data to put. + /// \return The return status. + virtual Status Put(const std::vector& ids, + const std::vector>& data) = 0; + + /// This method will be called whenever an evicted object in the external + /// store store needs to be accessed. + /// + /// This API is experimental and might change in the future. + /// + /// \param ids The IDs of the objects to get. + /// \param buffers List of buffers the data should be written to. + /// \return The return status. + virtual Status Get(const std::vector& ids, + std::vector> buffers) = 0; +}; + +class ExternalStores { + public: + typedef std::unordered_map> StoreMap; + /// Extracts the external store name from the external store endpoint. + /// + /// \param endpoint The endpoint for the external store. + /// \param[out] store_name The name of the external store. + /// \return The return status. + static Status ExtractStoreName(const std::string& endpoint, std::string* store_name); + + /// Register a new external store. + /// + /// \param store_name Name of the new external store. + /// \param store The new external store object. + static void RegisterStore(const std::string& store_name, + std::shared_ptr store); + + /// Remove an external store from the registry. + /// + /// \param store_name Name of the external store to remove. + static void DeregisterStore(const std::string& store_name); + + /// Obtain the external store given its name. + /// + /// \param store_name Name of the external store. + /// \return The external store object. 
+ static std::shared_ptr GetStore(const std::string& store_name); + + private: + /// Obtain mapping between external store names and store instances. + /// + /// \return Mapping between external store names and store instances. + static StoreMap& Stores(); +}; + +#define REGISTER_EXTERNAL_STORE(name, store) \ + class store##Class { \ + public: \ + store##Class() { ExternalStores::RegisterStore(name, std::make_shared()); } \ + ~store##Class() { ExternalStores::DeregisterStore(name); } \ + }; \ + store##Class singleton_##store = store##Class() + +} // namespace plasma + +#endif // EXTERNAL_STORE_H diff --git a/cpp/src/plasma/hash_table_store.cc b/cpp/src/plasma/hash_table_store.cc new file mode 100644 index 0000000000000..b77d3693fb206 --- /dev/null +++ b/cpp/src/plasma/hash_table_store.cc @@ -0,0 +1,58 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include + +#include "arrow/util/logging.h" + +#include "plasma/hash_table_store.h" + +namespace plasma { + +Status HashTableStore::Connect(const std::string& endpoint) { return Status::OK(); } + +Status HashTableStore::Put(const std::vector& ids, + const std::vector>& data) { + for (size_t i = 0; i < ids.size(); ++i) { + table_[ids[i]] = data[i]->ToString(); + } + return Status::OK(); +} + +Status HashTableStore::Get(const std::vector& ids, + std::vector> buffers) { + ARROW_CHECK(ids.size() == buffers.size()); + for (size_t i = 0; i < ids.size(); ++i) { + bool valid; + HashTable::iterator result; + { + result = table_.find(ids[i]); + valid = result != table_.end(); + } + if (valid) { + ARROW_CHECK(buffers[i]->size() == static_cast(result->second.size())); + std::memcpy(buffers[i]->mutable_data(), result->second.data(), + result->second.size()); + } + } + return Status::OK(); +} + +REGISTER_EXTERNAL_STORE("hashtable", HashTableStore); + +} // namespace plasma diff --git a/cpp/src/plasma/hash_table_store.h b/cpp/src/plasma/hash_table_store.h new file mode 100644 index 0000000000000..766088bd4a659 --- /dev/null +++ b/cpp/src/plasma/hash_table_store.h @@ -0,0 +1,53 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef HASH_TABLE_STORE_H +#define HASH_TABLE_STORE_H + +#include +#include +#include +#include + +#include "plasma/external_store.h" + +namespace plasma { + +// This is a sample implementation for an external store, for illustration +// purposes only. + +class HashTableStore : public ExternalStore { + public: + HashTableStore() = default; + + Status Connect(const std::string& endpoint) override; + + Status Get(const std::vector& ids, + std::vector> buffers) override; + + Status Put(const std::vector& ids, + const std::vector>& data) override; + + private: + typedef std::unordered_map HashTable; + + HashTable table_; +}; + +} // namespace plasma + +#endif // HASH_TABLE_STORE_H diff --git a/cpp/src/plasma/store.cc b/cpp/src/plasma/store.cc index 745e336049e8b..05495b7571d07 100644 --- a/cpp/src/plasma/store.cc +++ b/cpp/src/plasma/store.cc @@ -108,8 +108,10 @@ GetRequest::GetRequest(Client* client, const std::vector& object_ids) Client::Client(int fd) : fd(fd), notification_fd(-1) {} -PlasmaStore::PlasmaStore(EventLoop* loop, std::string directory, bool hugepages_enabled) - : loop_(loop), eviction_policy_(&store_info_) { +PlasmaStore::PlasmaStore(EventLoop* loop, std::string directory, bool hugepages_enabled, + const std::string& socket_name, + std::shared_ptr external_store) + : loop_(loop), eviction_policy_(&store_info_), external_store_(external_store) { store_info_.directory = directory; store_info_.hugepages_enabled = hugepages_enabled; #ifdef PLASMA_CUDA @@ -136,7 +138,7 @@ void PlasmaStore::AddToClientObjectIds(const ObjectID& object_id, ObjectTableEnt // Tell the eviction policy that this object is being used. std::vector objects_to_evict; eviction_policy_.BeginObjectAccess(object_id, &objects_to_evict); - DeleteObjects(objects_to_evict); + EvictObjects(objects_to_evict); } // Increase reference count. entry->ref_count++; @@ -145,16 +147,9 @@ void PlasmaStore::AddToClientObjectIds(const ObjectID& object_id, ObjectTableEnt client->object_ids.insert(object_id); } -// Create a new object buffer in the hash table. -PlasmaError PlasmaStore::CreateObject(const ObjectID& object_id, int64_t data_size, - int64_t metadata_size, int device_num, - Client* client, PlasmaObject* result) { - ARROW_LOG(DEBUG) << "creating object " << object_id.hex(); - if (store_info_.objects.count(object_id) != 0) { - // There is already an object with the same ID in the Plasma Store, so - // ignore this requst. - return PlasmaError::ObjectExists; - } +// Allocate memory +uint8_t* PlasmaStore::AllocateMemory(int device_num, size_t size, int* fd, + int64_t* map_size, ptrdiff_t* offset) { // Try to evict objects until there is enough space. uint8_t* pointer = nullptr; #ifdef PLASMA_CUDA @@ -173,18 +168,16 @@ PlasmaError PlasmaStore::CreateObject(const ObjectID& object_id, int64_t data_si // it is not guaranteed that the corresponding pointer in the client will be // 64-byte aligned, but in practice it often will be. if (device_num == 0) { - pointer = reinterpret_cast( - PlasmaAllocator::Memalign(kBlockSize, data_size + metadata_size)); - if (pointer == nullptr) { + pointer = reinterpret_cast(PlasmaAllocator::Memalign(kBlockSize, size)); + if (!pointer) { // Tell the eviction policy how much space we need to create this object. 
std::vector objects_to_evict; - bool success = - eviction_policy_.RequireSpace(data_size + metadata_size, &objects_to_evict); - DeleteObjects(objects_to_evict); + bool success = eviction_policy_.RequireSpace(size, &objects_to_evict); + EvictObjects(objects_to_evict); // Return an error to the client if not enough space could be freed to // create the object. if (!success) { - return PlasmaError::OutOfMemory; + return nullptr; } } else { break; @@ -196,16 +189,39 @@ PlasmaError PlasmaStore::CreateObject(const ObjectID& object_id, int64_t data_si #endif } } + if (device_num == 0) { + GetMallocMapinfo(pointer, fd, map_size, offset); + ARROW_CHECK(*fd != -1); + } + return pointer; +} + +// Create a new object buffer in the hash table. +PlasmaError PlasmaStore::CreateObject(const ObjectID& object_id, int64_t data_size, + int64_t metadata_size, int device_num, + Client* client, PlasmaObject* result) { + ARROW_LOG(DEBUG) << "creating object " << object_id.hex(); + auto entry = GetObjectTableEntry(&store_info_, object_id); + if (entry != nullptr) { + // There is already an object with the same ID in the Plasma Store, so + // ignore this requst. + return PlasmaError::ObjectExists; + } + int fd = -1; int64_t map_size = 0; ptrdiff_t offset = 0; - if (device_num == 0) { - GetMallocMapinfo(pointer, &fd, &map_size, &offset); - assert(fd != -1); + uint8_t* pointer = + AllocateMemory(device_num, data_size + metadata_size, &fd, &map_size, &offset); + if (!pointer) { + return PlasmaError::OutOfMemory; + } + if (!entry) { + auto ptr = std::unique_ptr(new ObjectTableEntry()); + entry = store_info_.objects.emplace(object_id, std::move(ptr)).first->second.get(); + entry->data_size = data_size; + entry->metadata_size = metadata_size; } - auto entry = std::unique_ptr(new ObjectTableEntry()); - entry->data_size = data_size; - entry->metadata_size = metadata_size; entry->pointer = pointer; // TODO(pcm): Set the other fields. entry->fd = fd; @@ -221,7 +237,7 @@ PlasmaError PlasmaStore::CreateObject(const ObjectID& object_id, int64_t data_si result->ipc_handle = entry->ipc_handle; } #endif - store_info_.objects[object_id] = std::move(entry); + result->store_fd = fd; result->data_offset = offset; result->metadata_offset = offset + data_size; @@ -386,7 +402,8 @@ void PlasmaStore::ProcessGetRequest(Client* client, int64_t timeout_ms) { // Create a get request for this object. auto get_req = new GetRequest(client, object_ids); - + std::vector evicted_ids; + std::vector evicted_entries; for (auto object_id : object_ids) { // Check if this object is already present locally. If so, record that the // object is being used and mark it as accounted for. @@ -398,6 +415,26 @@ void PlasmaStore::ProcessGetRequest(Client* client, // If necessary, record that this client is using this object. In the case // where entry == NULL, this will be called from SealObject. 
AddToClientObjectIds(object_id, entry, client); + } else if (entry && entry->state == ObjectState::PLASMA_EVICTED) { + // Make sure the object pointer is not already allocated + ARROW_CHECK(!entry->pointer); + + entry->pointer = AllocateMemory(0, /* Only support device_num = 0 */ + entry->data_size + entry->metadata_size, &entry->fd, + &entry->map_size, &entry->offset); + if (entry->pointer) { + entry->state = ObjectState::PLASMA_CREATED; + entry->create_time = std::time(nullptr); + eviction_policy_.ObjectCreated(object_id); + AddToClientObjectIds(object_id, store_info_.objects[object_id].get(), client); + evicted_ids.push_back(object_id); + evicted_entries.push_back(entry); + } else { + // We are out of memory an cannot allocate memory for this object. + // Change the state of the object back to PLASMA_EVICTED so some + // other request can try again. + entry->state = ObjectState::PLASMA_EVICTED; + } } else { // Add a placeholder plasma object to the get request to indicate that the // object is not present. This will be parsed by the client. We set the @@ -408,6 +445,33 @@ void PlasmaStore::ProcessGetRequest(Client* client, } } + if (!evicted_ids.empty()) { + unsigned char digest[kDigestSize]; + std::vector> buffers; + for (size_t i = 0; i < evicted_ids.size(); ++i) { + ARROW_CHECK(evicted_entries[i]->pointer != nullptr); + buffers.emplace_back(new arrow::MutableBuffer(evicted_entries[i]->pointer, + evicted_entries[i]->data_size)); + } + if (external_store_->Get(evicted_ids, buffers).ok()) { + for (size_t i = 0; i < evicted_ids.size(); ++i) { + evicted_entries[i]->state = ObjectState::PLASMA_SEALED; + std::memcpy(&evicted_entries[i]->digest[0], &digest[0], kDigestSize); + evicted_entries[i]->construct_duration = + std::time(nullptr) - evicted_entries[i]->create_time; + PlasmaObject_init(&get_req->objects[evicted_ids[i]], evicted_entries[i]); + get_req->num_satisfied += 1; + } + } else { + // We tried to get the objects from the external store, but could not get them. + // Set the state of these objects back to PLASMA_EVICTED so some other request + // can try again. + for (size_t i = 0; i < evicted_ids.size(); ++i) { + evicted_entries[i]->state = ObjectState::PLASMA_EVICTED; + } + } + } + // If all of the objects are present already or if the timeout is 0, return to // the client. if (get_req->num_satisfied == get_req->num_objects_to_wait_for || timeout_ms == 0) { @@ -437,12 +501,12 @@ int PlasmaStore::RemoveFromClientObjectIds(const ObjectID& object_id, // Tell the eviction policy that this object is no longer being used. std::vector objects_to_evict; eviction_policy_.EndObjectAccess(object_id, &objects_to_evict); - DeleteObjects(objects_to_evict); + EvictObjects(objects_to_evict); } else { // Above code does not really delete an object. Instead, it just put an // object to LRU cache which will be cleaned when the memory is not enough. deletion_cache_.erase(object_id); - DeleteObjects({object_id}); + EvictObjects({object_id}); } } // Return 1 to indicate that the client was removed. @@ -463,7 +527,8 @@ void PlasmaStore::ReleaseObject(const ObjectID& object_id, Client* client) { // Check if an object is present. ObjectStatus PlasmaStore::ContainsObject(const ObjectID& object_id) { auto entry = GetObjectTableEntry(&store_info_, object_id); - return entry && (entry->state == ObjectState::PLASMA_SEALED) + return entry && (entry->state == ObjectState::PLASMA_SEALED || + entry->state == ObjectState::PLASMA_EVICTED) ? 
ObjectStatus::OBJECT_FOUND : ObjectStatus::OBJECT_NOT_FOUND; } @@ -480,6 +545,7 @@ void PlasmaStore::SealObject(const ObjectID& object_id, unsigned char digest[]) std::memcpy(&entry->digest[0], &digest[0], kDigestSize); // Set object construction duration. entry->construct_duration = std::time(nullptr) - entry->create_time; + // Inform all subscribers that a new object has been sealed. ObjectInfoT info; info.object_id = object_id.binary(); @@ -545,25 +611,47 @@ PlasmaError PlasmaStore::DeleteObject(ObjectID& object_id) { return PlasmaError::OK; } -void PlasmaStore::DeleteObjects(const std::vector& object_ids) { +void PlasmaStore::EvictObjects(const std::vector& object_ids) { + std::vector> evicted_object_data; + std::vector evicted_entries; for (const auto& object_id : object_ids) { - ARROW_LOG(DEBUG) << "deleting object " << object_id.hex(); + ARROW_LOG(DEBUG) << "evicting object " << object_id.hex(); auto entry = GetObjectTableEntry(&store_info_, object_id); // TODO(rkn): This should probably not fail, but should instead throw an // error. Maybe we should also support deleting objects that have been // created but not sealed. - ARROW_CHECK(entry != nullptr) - << "To delete an object it must be in the object table."; + ARROW_CHECK(entry != nullptr) << "To evict an object it must be in the object table."; ARROW_CHECK(entry->state == ObjectState::PLASMA_SEALED) - << "To delete an object it must have been sealed."; + << "To evict an object it must have been sealed."; ARROW_CHECK(entry->ref_count == 0) - << "To delete an object, there must be no clients currently using it."; - store_info_.objects.erase(object_id); - // Inform all subscribers that the object has been deleted. - fb::ObjectInfoT notification; - notification.object_id = object_id.binary(); - notification.is_deletion = true; - PushNotification(¬ification); + << "To evict an object, there must be no clients currently using it."; + + // If there is a backing external store, then mark object for eviction to + // external store, free the object data pointer and keep a placeholder + // entry in ObjectTable + if (external_store_) { + evicted_object_data.push_back(std::make_shared( + entry->pointer, entry->data_size + entry->metadata_size)); + evicted_entries.push_back(entry); + } else { + // If there is no backing external store, just erase the object entry + // and send a deletion notification. + store_info_.objects.erase(object_id); + // Inform all subscribers that the object has been deleted. 
+ fb::ObjectInfoT notification; + notification.object_id = object_id.binary(); + notification.is_deletion = true; + PushNotification(¬ification); + } + } + + if (external_store_ && !object_ids.empty()) { + ARROW_CHECK_OK(external_store_->Put(object_ids, evicted_object_data)); + for (auto entry : evicted_entries) { + PlasmaAllocator::Free(entry->pointer, entry->data_size + entry->metadata_size); + entry->pointer = nullptr; + entry->state = ObjectState::PLASMA_EVICTED; + } } } @@ -869,7 +957,7 @@ Status PlasmaStore::ProcessMessage(Client* client) { std::vector objects_to_evict; int64_t num_bytes_evicted = eviction_policy_.ChooseObjectsToEvict(num_bytes, &objects_to_evict); - DeleteObjects(objects_to_evict); + EvictObjects(objects_to_evict); HANDLE_SIGPIPE(SendEvictReply(client->fd, num_bytes_evicted), client->fd); } break; case fb::MessageType::PlasmaSubscribeRequest: @@ -894,10 +982,12 @@ class PlasmaStoreRunner { public: PlasmaStoreRunner() {} - void Start(char* socket_name, std::string directory, bool hugepages_enabled) { + void Start(char* socket_name, std::string directory, bool hugepages_enabled, + std::shared_ptr external_store) { // Create the event loop. loop_.reset(new EventLoop); - store_.reset(new PlasmaStore(loop_.get(), directory, hugepages_enabled)); + store_.reset(new PlasmaStore(loop_.get(), directory, hugepages_enabled, socket_name, + external_store)); plasma_config = store_->GetPlasmaStoreInfo(); // We are using a single memory-mapped file by mallocing and freeing a single @@ -945,15 +1035,15 @@ void HandleSignal(int signal) { } } -void StartServer(char* socket_name, std::string plasma_directory, - bool hugepages_enabled) { +void StartServer(char* socket_name, std::string plasma_directory, bool hugepages_enabled, + std::shared_ptr external_store) { // Ignore SIGPIPE signals. If we don't do this, then when we attempt to write // to a client that has already died, the store could die. signal(SIGPIPE, SIG_IGN); g_runner.reset(new PlasmaStoreRunner()); signal(SIGTERM, HandleSignal); - g_runner->Start(socket_name, plasma_directory, hugepages_enabled); + g_runner->Start(socket_name, plasma_directory, hugepages_enabled, external_store); } } // namespace plasma @@ -964,14 +1054,18 @@ int main(int argc, char* argv[]) { char* socket_name = nullptr; // Directory where plasma memory mapped files are stored. 
std::string plasma_directory; + std::string external_store_endpoint; bool hugepages_enabled = false; int64_t system_memory = -1; int c; - while ((c = getopt(argc, argv, "s:m:d:h")) != -1) { + while ((c = getopt(argc, argv, "s:m:d:e:h")) != -1) { switch (c) { case 'd': plasma_directory = std::string(optarg); break; + case 'e': + external_store_endpoint = std::string(optarg); + break; case 'h': hugepages_enabled = true; break; @@ -1038,8 +1132,22 @@ int main(int argc, char* argv[]) { SetMallocGranularity(1024 * 1024 * 1024); // 1 GB } #endif + // Get external store + std::shared_ptr external_store{nullptr}; + if (!external_store_endpoint.empty()) { + std::string name; + ARROW_CHECK_OK( + plasma::ExternalStores::ExtractStoreName(external_store_endpoint, &name)); + external_store = plasma::ExternalStores::GetStore(name); + if (external_store == nullptr) { + ARROW_LOG(FATAL) << "No such external store \"" << name << "\""; + return -1; + } + ARROW_LOG(DEBUG) << "connecting to external store..."; + ARROW_CHECK_OK(external_store->Connect(external_store_endpoint)); + } ARROW_LOG(DEBUG) << "starting server listening on " << socket_name; - plasma::StartServer(socket_name, plasma_directory, hugepages_enabled); + plasma::StartServer(socket_name, plasma_directory, hugepages_enabled, external_store); plasma::g_runner->Shutdown(); plasma::g_runner = nullptr; diff --git a/cpp/src/plasma/store.h b/cpp/src/plasma/store.h index a5c586b7f53f0..7105c513ebb2e 100644 --- a/cpp/src/plasma/store.h +++ b/cpp/src/plasma/store.h @@ -28,7 +28,9 @@ #include "plasma/common.h" #include "plasma/events.h" #include "plasma/eviction_policy.h" +#include "plasma/external_store.h" #include "plasma/plasma.h" +#include "plasma/protocol.h" namespace arrow { class Status; @@ -75,7 +77,9 @@ class PlasmaStore { using NotificationMap = std::unordered_map; // TODO: PascalCase PlasmaStore methods. - PlasmaStore(EventLoop* loop, std::string directory, bool hugetlbfs_enabled); + PlasmaStore(EventLoop* loop, std::string directory, bool hugepages_enabled, + const std::string& socket_name, + std::shared_ptr external_store); ~PlasmaStore(); @@ -125,11 +129,10 @@ class PlasmaStore { /// - PlasmaError::ObjectInUse, if the object is in use. PlasmaError DeleteObject(ObjectID& object_id); - /// Delete objects that have been created in the hash table. This should only - /// be called on objects that are returned by the eviction policy to evict. + /// Evict objects returned by the eviction policy. /// - /// @param object_ids Object IDs of the objects to be deleted. - void DeleteObjects(const std::vector& object_ids); + /// @param object_ids Object IDs of the objects to be evicted. + void EvictObjects(const std::vector& object_ids); /// Process a get request from a client. This method assumes that we will /// eventually have these objects sealed. If one of the objects has not yet @@ -149,8 +152,7 @@ class PlasmaStore { /// /// @param object_id Object ID of the object to be sealed. /// @param digest The digest of the object. This is used to tell if two - /// objects - /// with the same object ID are the same. + /// objects with the same object ID are the same. void SealObject(const ObjectID& object_id, unsigned char digest[]); /// Check if the plasma store contains an object: @@ -210,6 +212,9 @@ class PlasmaStore { int RemoveFromClientObjectIds(const ObjectID& object_id, ObjectTableEntry* entry, Client* client); + uint8_t* AllocateMemory(int device_num, size_t size, int* fd, int64_t* map_size, + ptrdiff_t* offset); + /// Event loop of the plasma store. 
EventLoop* loop_; /// The plasma store information, including the object tables, that is exposed @@ -233,6 +238,10 @@ class PlasmaStore { std::unordered_map> connected_clients_; std::unordered_set deletion_cache_; + + /// Manages worker threads for handling asynchronous/multi-threaded requests + /// for reading/writing data to/from external store. + std::shared_ptr external_store_; #ifdef PLASMA_CUDA arrow::cuda::CudaDeviceManager* manager_; #endif diff --git a/cpp/src/plasma/test/client_tests.cc b/cpp/src/plasma/test/client_tests.cc index 1678e27f90f58..90ab70fe09c36 100644 --- a/cpp/src/plasma/test/client_tests.cc +++ b/cpp/src/plasma/test/client_tests.cc @@ -59,9 +59,9 @@ class TestPlasmaStore : public ::testing::Test { std::string plasma_directory = test_executable.substr(0, test_executable.find_last_of("/")); - std::string plasma_command = plasma_directory + - "/plasma_store_server -m 10000000 -s " + - store_socket_name_ + " 1> /dev/null 2> /dev/null &"; + std::string plasma_command = + plasma_directory + "/plasma_store_server -m 10000000 -s " + store_socket_name_ + + " 1> /dev/null 2> /dev/null & " + "echo $! > " + store_socket_name_ + ".pid"; system(plasma_command.c_str()); ARROW_CHECK_OK(client_.Connect(store_socket_name_, "")); ARROW_CHECK_OK(client2_.Connect(store_socket_name_, "")); @@ -69,15 +69,16 @@ class TestPlasmaStore : public ::testing::Test { virtual void TearDown() { ARROW_CHECK_OK(client_.Disconnect()); ARROW_CHECK_OK(client2_.Disconnect()); - // Kill all plasma_store processes - // TODO should only kill the processes we launched + // Kill plasma_store process that we started #ifdef COVERAGE_BUILD // Ask plasma_store to exit gracefully and give it time to write out // coverage files - system("killall -TERM plasma_store_server"); + std::string plasma_term_command = "kill -TERM `cat " + store_socket_name_ + ".pid`"; + system(plasma_term_command.c_str()); std::this_thread::sleep_for(std::chrono::milliseconds(200)); #endif - system("killall -KILL plasma_store_server"); + std::string plasma_kill_command = "kill -KILL `cat " + store_socket_name_ + ".pid`"; + system(plasma_kill_command.c_str()); } void CreateObject(PlasmaClient& client, const ObjectID& object_id, diff --git a/cpp/src/plasma/test/external_store_tests.cc b/cpp/src/plasma/test/external_store_tests.cc new file mode 100644 index 0000000000000..33d3bd1dca9b0 --- /dev/null +++ b/cpp/src/plasma/test/external_store_tests.cc @@ -0,0 +1,139 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include + +#include "arrow/test-util.h" + +#include "plasma/client.h" +#include "plasma/common.h" +#include "plasma/external_store.h" +#include "plasma/plasma.h" +#include "plasma/protocol.h" +#include "plasma/test-util.h" + +namespace plasma { + +std::string external_test_executable; // NOLINT + +void AssertObjectBufferEqual(const ObjectBuffer& object_buffer, + const std::string& metadata, const std::string& data) { + arrow::AssertBufferEqual(*object_buffer.metadata, metadata); + arrow::AssertBufferEqual(*object_buffer.data, data); +} + +class TestPlasmaStoreWithExternal : public ::testing::Test { + public: + // TODO(pcm): At the moment, stdout of the test gets mixed up with + // stdout of the object store. Consider changing that. + void SetUp() override { + uint64_t seed = std::chrono::high_resolution_clock::now().time_since_epoch().count(); + std::mt19937 rng(static_cast(seed)); + std::string store_index = std::to_string(rng()); + store_socket_name_ = "/tmp/store_with_external" + store_index; + + std::string plasma_directory = + external_test_executable.substr(0, external_test_executable.find_last_of('/')); + std::string plasma_command = plasma_directory + + "/plasma_store_server -m 1024000 -e " + + "hashtable://test -s " + store_socket_name_ + + " 1> /tmp/log.stdout 2> /tmp/log.stderr & " + + "echo $! > " + store_socket_name_ + ".pid"; + system(plasma_command.c_str()); + ARROW_CHECK_OK(client_.Connect(store_socket_name_, "")); + } + void TearDown() override { + ARROW_CHECK_OK(client_.Disconnect()); + // Kill plasma_store process that we started +#ifdef COVERAGE_BUILD + // Ask plasma_store to exit gracefully and give it time to write out + // coverage files + std::string plasma_term_command = "kill -TERM `cat " + store_socket_name_ + ".pid`"; + system(plasma_term_command.c_str()); + std::this_thread::sleep_for(std::chrono::milliseconds(200)); +#endif + std::string plasma_kill_command = "kill -KILL `cat " + store_socket_name_ + ".pid`"; + system(plasma_kill_command.c_str()); + } + + protected: + PlasmaClient client_; + std::string store_socket_name_; +}; + +TEST_F(TestPlasmaStoreWithExternal, EvictionTest) { + std::vector object_ids; + std::string data(100 * 1024, 'x'); + std::string metadata; + for (int i = 0; i < 20; i++) { + ObjectID object_id = random_object_id(); + object_ids.push_back(object_id); + + // Test for object non-existence. + bool has_object; + ARROW_CHECK_OK(client_.Contains(object_id, &has_object)); + ASSERT_FALSE(has_object); + + // Test for the object being in local Plasma store. + // Create and seal the object. + ARROW_CHECK_OK(client_.CreateAndSeal(object_id, data, metadata)); + // Test that the client can get the object. + ARROW_CHECK_OK(client_.Contains(object_id, &has_object)); + ASSERT_TRUE(has_object); + } + + for (int i = 0; i < 20; i++) { + // Since we are accessing objects sequentially, every object we + // access would be a cache "miss" owing to LRU eviction. + // Try and access the object from the plasma store first, and then try + // external store on failure. This should succeed to fetch the object. + // However, it may evict the next few objects. 
+ std::vector object_buffers; + ARROW_CHECK_OK(client_.Get({object_ids[i]}, -1, &object_buffers)); + ASSERT_EQ(object_buffers.size(), 1); + ASSERT_EQ(object_buffers[0].device_num, 0); + ASSERT_TRUE(object_buffers[0].data); + AssertObjectBufferEqual(object_buffers[0], metadata, data); + } + + // Make sure we still cannot fetch objects that do not exist + std::vector object_buffers; + ARROW_CHECK_OK(client_.Get({random_object_id()}, 100, &object_buffers)); + ASSERT_EQ(object_buffers.size(), 1); + ASSERT_EQ(object_buffers[0].device_num, 0); + ASSERT_EQ(object_buffers[0].data, nullptr); + ASSERT_EQ(object_buffers[0].metadata, nullptr); +} + +} // namespace plasma + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + plasma::external_test_executable = std::string(argv[0]); + return RUN_ALL_TESTS(); +} diff --git a/python/pyarrow/_plasma.pyx b/python/pyarrow/_plasma.pyx index 4f64f202cef67..b3868dcac82e3 100644 --- a/python/pyarrow/_plasma.pyx +++ b/python/pyarrow/_plasma.pyx @@ -373,6 +373,7 @@ cdef class PlasmaClient: The number of milliseconds that the get call should block before timing out and returning. Pass -1 if the call should block and 0 if the call should return immediately. + with_meta : bool Returns ------- diff --git a/python/pyarrow/plasma.py b/python/pyarrow/plasma.py index a6ab362536d00..13b3eec39db79 100644 --- a/python/pyarrow/plasma.py +++ b/python/pyarrow/plasma.py @@ -78,7 +78,8 @@ def build_plasma_tensorflow_op(): @contextlib.contextmanager def start_plasma_store(plasma_store_memory, use_valgrind=False, use_profiler=False, - plasma_directory=None, use_hugepages=False): + plasma_directory=None, use_hugepages=False, + external_store=None): """Start a plasma store process. Args: plasma_store_memory (int): Capacity of the plasma store in bytes. @@ -89,6 +90,7 @@ def start_plasma_store(plasma_store_memory, plasma_directory (str): Directory where plasma memory mapped files will be stored. use_hugepages (bool): True if the plasma store should use huge pages. + external_store (str): External store to use for evicted objects. Return: A tuple of the name of the plasma store socket and the process ID of the plasma store process. @@ -108,6 +110,8 @@ def start_plasma_store(plasma_store_memory, command += ["-d", plasma_directory] if use_hugepages: command += ["-h"] + if external_store is not None: + command += ["-e", external_store] stdout_file = None stderr_file = None if use_valgrind: diff --git a/python/pyarrow/tests/test_plasma.py b/python/pyarrow/tests/test_plasma.py index bcb467aab8e3e..ef53bab539cc1 100644 --- a/python/pyarrow/tests/test_plasma.py +++ b/python/pyarrow/tests/test_plasma.py @@ -37,6 +37,7 @@ DEFAULT_PLASMA_STORE_MEMORY = 10 ** 8 USE_VALGRIND = os.getenv("PLASMA_VALGRIND") == "1" +EXTERNAL_STORE = "hashtable://test" SMALL_OBJECT_SIZE = 9000 @@ -919,6 +920,63 @@ def client_get_multiple(plasma_store_name): break +@pytest.mark.plasma +class TestEvictionToExternalStore(object): + + def setup_method(self, test_method): + import pyarrow.plasma as plasma + # Start Plasma store. + self.plasma_store_ctx = plasma.start_plasma_store( + plasma_store_memory=1000 * 1024, + use_valgrind=USE_VALGRIND, + external_store=EXTERNAL_STORE) + self.plasma_store_name, self.p = self.plasma_store_ctx.__enter__() + # Connect to Plasma. + self.plasma_client = plasma.connect(self.plasma_store_name) + + def teardown_method(self, test_method): + try: + # Check that the Plasma store is still alive. 
+ assert self.p.poll() is None + self.p.send_signal(signal.SIGTERM) + if sys.version_info >= (3, 3): + self.p.wait(timeout=5) + else: + self.p.wait() + finally: + self.plasma_store_ctx.__exit__(None, None, None) + + def test_eviction(self): + client = self.plasma_client + + object_ids = [random_object_id() for _ in range(0, 20)] + data = b'x' * 100 * 1024 + metadata = b'' + + for i in range(0, 20): + # Test for object non-existence. + assert not client.contains(object_ids[i]) + + # Create and seal the object. + client.create_and_seal(object_ids[i], data, metadata) + + # Test that the client can get the object. + assert client.contains(object_ids[i]) + + for i in range(0, 20): + # Since we are accessing objects sequentially, every object we + # access would be a cache "miss" owing to LRU eviction. + # Try and access the object from the plasma store first, and then + # try external store on failure. This should succeed to fetch the + # object. However, it may evict the next few objects. + [result] = client.get_buffers([object_ids[i]]) + assert result.to_pybytes() == data + + # Make sure we still cannot fetch objects that do not exist + [result] = client.get_buffers([random_object_id()], timeout_ms=100) + assert result is None + + @pytest.mark.plasma def test_object_id_size(): import pyarrow.plasma as plasma From c93e2ae14f4d35cf367817a2d2c65024c50dbecc Mon Sep 17 00:00:00 2001 From: Michael Pigott Date: Tue, 5 Feb 2019 10:18:41 +0100 Subject: [PATCH 13/21] ARROW-3923: [Java] JDBC Time Fetches Without Timezone https://issues.apache.org/jira/browse/ARROW-3923 Hello! I was reading through the JDBC source code and I noticed that a java.util.Calendar was required for creating an Arrow Schema and Arrow Vectors from a JDBC ResultSet, when none is required. This change makes the Calendar optional. Unit Tests: The existing SureFire plugin configuration uses a UTC calendar for the database, which is the default Calendar in the existing code. Likewise, no changes to the unit tests are required to provide adequate coverage for the change. Author: Michael Pigott Author: Mike Pigott Closes #3066 from mikepigott/jdbc-timestamp-no-calendar and squashes the following commits: 4d95da0a ARROW-3923: Supporting a null Calendar in the config, and reverting the breaking change. cd9a2306 Merge branch 'master' into jdbc-timestamp-no-calendar 509a1cc5 Merge pull request #5 from apache/master 789c8c84 Merge pull request #4 from apache/master e5b19eee Merge pull request #3 from apache/master 3b17c297 Merge pull request #2 from apache/master 881c6c83 Merge pull request #1 from apache/master 089cff4d Format fixes a58a4a5f Fixing calendar usage. e12832a3 Allowing for timestamps without a time zone. 
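For illustration only (this sketch is not part of the patch), the snippet below shows what the relaxed requirement allows, using only the public entry points touched by this change. It assumes an in-memory H2 database is available on the classpath; the connection URL, table name, and class name are placeholders.

```
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.Statement;
import java.util.Calendar;
import java.util.Locale;
import java.util.TimeZone;

import org.apache.arrow.adapter.jdbc.JdbcToArrow;
import org.apache.arrow.adapter.jdbc.JdbcToArrowConfig;
import org.apache.arrow.adapter.jdbc.JdbcToArrowConfigBuilder;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;

public class JdbcCalendarSketch {
  public static void main(String[] args) throws Exception {
    // Placeholder H2 in-memory database; any JDBC source would do.
    try (Connection conn = DriverManager.getConnection("jdbc:h2:mem:sketch");
         Statement stmt = conn.createStatement()) {
      stmt.execute("CREATE TABLE t (id INT, ts TIMESTAMP)");
      stmt.execute("INSERT INTO t VALUES (1, CURRENT_TIMESTAMP)");

      // After this change a Calendar is optional: with null, timestamps are
      // read exactly as the JDBC driver reports them.
      try (ResultSet rs = stmt.executeQuery("SELECT * FROM t");
           VectorSchemaRoot root = JdbcToArrow.sqlToArrow(rs, (Calendar) null)) {
        System.out.println(root.contentToTSVString());
      }

      // Supplying an explicit Calendar still works and behaves as before.
      Calendar utc = Calendar.getInstance(TimeZone.getTimeZone("UTC"), Locale.ROOT);
      JdbcToArrowConfig config =
          new JdbcToArrowConfigBuilder(new RootAllocator(Integer.MAX_VALUE), utc).build();
      try (ResultSet rs = stmt.executeQuery("SELECT * FROM t");
           VectorSchemaRoot root = JdbcToArrow.sqlToArrow(rs, config)) {
        System.out.println(root.contentToTSVString());
      }
    }
  }
}
```

Passing a Calendar through JdbcToArrowConfigBuilder remains the way to pin Date, Time, and Timestamp values to a specific time zone.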
--- .../arrow/adapter/jdbc/JdbcToArrow.java | 6 +-- .../arrow/adapter/jdbc/JdbcToArrowConfig.java | 4 +- .../jdbc/JdbcToArrowConfigBuilder.java | 11 ++---- .../arrow/adapter/jdbc/JdbcToArrowUtils.java | 38 ++++++++++++++----- .../adapter/jdbc/JdbcToArrowConfigTest.java | 15 +++++--- 5 files changed, 46 insertions(+), 28 deletions(-) diff --git a/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcToArrow.java b/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcToArrow.java index fd320367f77b6..ddf70df9ad2ce 100644 --- a/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcToArrow.java +++ b/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcToArrow.java @@ -179,13 +179,12 @@ public static VectorSchemaRoot sqlToArrow(ResultSet resultSet, BaseAllocator all * For the given JDBC {@link ResultSet}, fetch the data from Relational DB and convert it to Arrow objects. * * @param resultSet ResultSet to use to fetch the data from underlying database - * @param calendar Calendar instance to use for Date, Time and Timestamp datasets. + * @param calendar Calendar instance to use for Date, Time and Timestamp datasets, or null if none. * @return Arrow Data Objects {@link VectorSchemaRoot} * @throws SQLException on error */ public static VectorSchemaRoot sqlToArrow(ResultSet resultSet, Calendar calendar) throws SQLException, IOException { Preconditions.checkNotNull(resultSet, "JDBC ResultSet object can not be null"); - Preconditions.checkNotNull(calendar, "Calendar object can not be null"); return sqlToArrow(resultSet, new JdbcToArrowConfig(new RootAllocator(Integer.MAX_VALUE), calendar)); } @@ -195,7 +194,7 @@ public static VectorSchemaRoot sqlToArrow(ResultSet resultSet, Calendar calendar * * @param resultSet ResultSet to use to fetch the data from underlying database * @param allocator Memory allocator to use. - * @param calendar Calendar instance to use for Date, Time and Timestamp datasets. + * @param calendar Calendar instance to use for Date, Time and Timestamp datasets, or null if none. * @return Arrow Data Objects {@link VectorSchemaRoot} * @throws SQLException on error */ @@ -203,7 +202,6 @@ public static VectorSchemaRoot sqlToArrow(ResultSet resultSet, BaseAllocator all throws SQLException, IOException { Preconditions.checkNotNull(resultSet, "JDBC ResultSet object can not be null"); Preconditions.checkNotNull(allocator, "Memory Allocator object can not be null"); - Preconditions.checkNotNull(calendar, "Calendar object can not be null"); return sqlToArrow(resultSet, new JdbcToArrowConfig(allocator, calendar)); } diff --git a/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcToArrowConfig.java b/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcToArrowConfig.java index 59813a830cbed..e9fcffb36b666 100644 --- a/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcToArrowConfig.java +++ b/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcToArrowConfig.java @@ -48,7 +48,6 @@ public final class JdbcToArrowConfig { */ JdbcToArrowConfig(BaseAllocator allocator, Calendar calendar) { Preconditions.checkNotNull(allocator, "Memory allocator cannot be null"); - Preconditions.checkNotNull(calendar, "Calendar object can not be null"); this.allocator = allocator; this.calendar = calendar; @@ -56,7 +55,8 @@ public final class JdbcToArrowConfig { /** * The calendar to use when defining Arrow Timestamp fields - * and retrieving time-based fields from the database. 
+ * and retrieving {@link Date}, {@link Time}, or {@link Timestamp} + * data types from the {@link ResultSet}, or null if not converting. * @return the calendar. */ public Calendar getCalendar() { diff --git a/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcToArrowConfigBuilder.java b/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcToArrowConfigBuilder.java index df97c3a975196..9ba69639905ce 100644 --- a/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcToArrowConfigBuilder.java +++ b/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcToArrowConfigBuilder.java @@ -32,7 +32,7 @@ public class JdbcToArrowConfigBuilder { /** * Default constructor for the JdbcToArrowConfigBuilder}. - * Use the setter methods for the allocator and calendar; both must be + * Use the setter methods for the allocator and calendar; the allocator must be * set. Otherwise, {@link #build()} will throw a {@link NullPointerException}. */ public JdbcToArrowConfigBuilder() { @@ -41,9 +41,9 @@ public JdbcToArrowConfigBuilder() { } /** - * Constructor for the JdbcToArrowConfigBuilder. Both the - * allocator and calendar are required. A {@link NullPointerException} - * will be thrown if one of the arguments is null. + * Constructor for the JdbcToArrowConfigBuilder. The + * allocator is required, and a {@link NullPointerException} + * will be thrown if it is null. *

* The allocator is used to construct Arrow vectors from the JDBC ResultSet. * The calendar is used to determine the time zone of {@link java.sql.Timestamp} @@ -59,7 +59,6 @@ public JdbcToArrowConfigBuilder(BaseAllocator allocator, Calendar calendar) { this(); Preconditions.checkNotNull(allocator, "Memory allocator cannot be null"); - Preconditions.checkNotNull(calendar, "Calendar object can not be null"); this.allocator = allocator; this.calendar = calendar; @@ -82,10 +81,8 @@ public JdbcToArrowConfigBuilder setAllocator(BaseAllocator allocator) { * Arrow schema, and reading time-based fields from the JDBC ResultSet. * * @param calendar the calendar to set. - * @exception NullPointerExeption if calendar is null. */ public JdbcToArrowConfigBuilder setCalendar(Calendar calendar) { - Preconditions.checkNotNull(calendar, "Calendar object can not be null"); this.calendar = calendar; return this; } diff --git a/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcToArrowUtils.java b/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcToArrowUtils.java index d48cfe2197b0c..b6adbbc7334a4 100644 --- a/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcToArrowUtils.java +++ b/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcToArrowUtils.java @@ -240,15 +240,15 @@ private static void allocateVectors(VectorSchemaRoot root, int size) { * * @param rs ResultSet to use to fetch the data from underlying database * @param root Arrow {@link VectorSchemaRoot} object to populate - * @param calendar The calendar to use when reading time-based data. + * @param calendar The calendar to use when reading {@link Date}, {@link Time}, or {@link Timestamp} + * data types from the {@link ResultSet}, or null if not converting. 
* @throws SQLException on error */ public static void jdbcToArrowVectors(ResultSet rs, VectorSchemaRoot root, Calendar calendar) throws SQLException, IOException { Preconditions.checkNotNull(rs, "JDBC ResultSet object can't be null"); - Preconditions.checkNotNull(root, "Vector Schema cannot be null"); - Preconditions.checkNotNull(calendar, "Calendar object can't be null"); + Preconditions.checkNotNull(root, "JDBC ResultSet object can't be null"); jdbcToArrowVectors(rs, root, new JdbcToArrowConfig(new RootAllocator(0), calendar)); } @@ -274,6 +274,8 @@ public static void jdbcToArrowVectors(ResultSet rs, VectorSchemaRoot root, JdbcT allocateVectors(root, DEFAULT_BUFFER_SIZE); + final Calendar calendar = config.getCalendar(); + int rowCount = 0; while (rs.next()) { for (int i = 1; i <= columnCount; i++) { @@ -324,17 +326,35 @@ public static void jdbcToArrowVectors(ResultSet rs, VectorSchemaRoot root, JdbcT rs.getString(i), !rs.wasNull(), rowCount); break; case Types.DATE: - updateVector((DateMilliVector) root.getVector(columnName), - rs.getDate(i, config.getCalendar()), !rs.wasNull(), rowCount); + final Date date; + if (calendar != null) { + date = rs.getDate(i, calendar); + } else { + date = rs.getDate(i); + } + + updateVector((DateMilliVector) root.getVector(columnName), date, !rs.wasNull(), rowCount); break; case Types.TIME: - updateVector((TimeMilliVector) root.getVector(columnName), - rs.getTime(i, config.getCalendar()), !rs.wasNull(), rowCount); + final Time time; + if (calendar != null) { + time = rs.getTime(i, calendar); + } else { + time = rs.getTime(i); + } + + updateVector((TimeMilliVector) root.getVector(columnName), time, !rs.wasNull(), rowCount); break; case Types.TIMESTAMP: + final Timestamp ts; + if (calendar != null) { + ts = rs.getTimestamp(i, calendar); + } else { + ts = rs.getTimestamp(i); + } + // TODO: Need to handle precision such as milli, micro, nano - updateVector((TimeStampVector) root.getVector(columnName), - rs.getTimestamp(i, config.getCalendar()), !rs.wasNull(), rowCount); + updateVector((TimeStampVector) root.getVector(columnName), ts, !rs.wasNull(), rowCount); break; case Types.BINARY: case Types.VARBINARY: diff --git a/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/JdbcToArrowConfigTest.java b/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/JdbcToArrowConfigTest.java index b4f92fa417026..1d02c888f8537 100644 --- a/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/JdbcToArrowConfigTest.java +++ b/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/JdbcToArrowConfigTest.java @@ -42,14 +42,16 @@ public void testBuilderNullArguments() { new JdbcToArrowConfigBuilder(null, null); } - @Test(expected = NullPointerException.class) public void testConfigNullCalendar() { - new JdbcToArrowConfig(allocator, null); + JdbcToArrowConfig config = new JdbcToArrowConfig(allocator, null); + assertNull(config.getCalendar()); } - @Test(expected = NullPointerException.class) + @Test public void testBuilderNullCalendar() { - new JdbcToArrowConfigBuilder(allocator, null); + JdbcToArrowConfigBuilder builder = new JdbcToArrowConfigBuilder(allocator, null); + JdbcToArrowConfig config = builder.build(); + assertNull(config.getCalendar()); } @Test(expected = NullPointerException.class) @@ -68,10 +70,11 @@ public void testSetNullAllocator() { builder.setAllocator(null); } - @Test(expected = NullPointerException.class) + @Test public void testSetNullCalendar() { JdbcToArrowConfigBuilder builder = new JdbcToArrowConfigBuilder(allocator, 
calendar); - builder.setCalendar(null); + JdbcToArrowConfig config = builder.setCalendar(null).build(); + assertNull(config.getCalendar()); } @Test From 927cfeff875e557e28649891ea20ca38cb9d1536 Mon Sep 17 00:00:00 2001 From: ptaylor Date: Tue, 5 Feb 2019 11:00:10 +0100 Subject: [PATCH 14/21] ARROW-4477: [JS] remove constructor override in the bignum mixins Removes the constructor override leftover from the original `Object.create()` implementation of the bignum mixins. Closes https://issues.apache.org/jira/browse/ARROW-4477 Author: ptaylor Closes #3557 from trxcllnt/js/fix-bn-constructor and squashes the following commits: 3a468594 remove constructor override in the bignum mixins --- js/src/util/bn.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/js/src/util/bn.ts b/js/src/util/bn.ts index 74e3e82cea7e9..c49c342044200 100644 --- a/js/src/util/bn.ts +++ b/js/src/util/bn.ts @@ -38,9 +38,9 @@ const BigNumNMixin = { }; /** @ignore */ -const SignedBigNumNMixin: any = Object.assign({}, BigNumNMixin, { signed: true, constructor: undefined }); +const SignedBigNumNMixin: any = Object.assign({}, BigNumNMixin, { signed: true }); /** @ignore */ -const UnsignedBigNumNMixin: any = Object.assign({}, BigNumNMixin, { signed: false, constructor: undefined }); +const UnsignedBigNumNMixin: any = Object.assign({}, BigNumNMixin, { signed: false }); /** @ignore */ export class BN { From 5f0ff7fcce18cc2afed8cf1165696946adb517a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Saint-Jacques?= Date: Tue, 5 Feb 2019 14:37:08 +0100 Subject: [PATCH 15/21] ARROW-3239: [C++] Implement simple random array generation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This implement the following API. ``` random::RandomArrayGenerator rand(seed); auto bool_array = rand.Boolean(num_rows, 0.75, null_prob); auto u8_array = rand.Int8(num_rows, 0, 100, null_prob); ``` Author: François Saint-Jacques Closes #3533 from fsaintjacques/ARROW-3239-random-utils and squashes the following commits: a806b1ff Add ARROW_EXPORT to RandomArrayGenerator 63d9103b Fix GenerateOptions seed type 59c3a3bb Add undef to macro 22eca801 Handle special case with MSVC 728aadcd Fix downcasting issues 4840ac0e ARROW-3239: Implement simple random array generation --- cpp/src/arrow/CMakeLists.txt | 2 +- cpp/src/arrow/ipc/json-test.cc | 43 +++--- cpp/src/arrow/ipc/read-write-benchmark.cc | 33 +---- cpp/src/arrow/ipc/test-common.h | 21 +-- cpp/src/arrow/test-random.cc | 149 +++++++++++++++++++ cpp/src/arrow/test-random.h | 169 ++++++++++++++++++++++ cpp/src/arrow/test-util.h | 5 +- 7 files changed, 354 insertions(+), 68 deletions(-) create mode 100644 cpp/src/arrow/test-random.cc create mode 100644 cpp/src/arrow/test-random.h diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 1dba5898c0a7a..c65824f5385be 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -243,7 +243,7 @@ endif() if (ARROW_BUILD_TESTS OR ARROW_BUILD_BENCHMARKS) # that depend on gtest ADD_ARROW_LIB(arrow_testing - SOURCES test-util.cc + SOURCES test-util.cc test-random.cc OUTPUTS ARROW_TESTING_LIBRARIES DEPENDENCIES ${GTEST_LIBRARY} SHARED_LINK_LIBS arrow_shared ${GTEST_LIBRARY} diff --git a/cpp/src/arrow/ipc/json-test.cc b/cpp/src/arrow/ipc/json-test.cc index 47a0a29a14e79..bea6fbbb112f0 100644 --- a/cpp/src/arrow/ipc/json-test.cc +++ b/cpp/src/arrow/ipc/json-test.cc @@ -32,6 +32,7 @@ #include "arrow/ipc/test-common.h" #include "arrow/memory_pool.h" #include 
"arrow/record_batch.h" +#include "arrow/test-random.h" #include "arrow/test-util.h" #include "arrow/type.h" #include "arrow/type_traits.h" @@ -216,48 +217,38 @@ TEST(TestJsonArrayWriter, Unions) { // Data generation for test case below void MakeBatchArrays(const std::shared_ptr& schema, const int num_rows, std::vector>* arrays) { - std::vector is_valid; - random_is_valid(num_rows, 0.25, &is_valid); + const float null_prob = 0.25f; + random::RandomArrayGenerator rand(0x564a3bf0); - std::vector v1_values; - std::vector v2_values; - - randint(num_rows, 0, 100, &v1_values); - randint(num_rows, 0, 100, &v2_values); - - std::shared_ptr v1; - ArrayFromVector(is_valid, v1_values, &v1); - - std::shared_ptr v2; - ArrayFromVector(is_valid, v2_values, &v2); + *arrays = {rand.Boolean(num_rows, 0.75, null_prob), + rand.Int8(num_rows, 0, 100, null_prob), + rand.Int32(num_rows, -1000, 1000, null_prob), + rand.UInt64(num_rows, 0, 1UL << 16, null_prob)}; static const int kBufferSize = 10; static uint8_t buffer[kBufferSize]; static uint32_t seed = 0; StringBuilder string_builder; for (int i = 0; i < num_rows; ++i) { - if (!is_valid[i]) { - ASSERT_OK(string_builder.AppendNull()); - } else { - random_ascii(kBufferSize, seed++, buffer); - ASSERT_OK(string_builder.Append(buffer, kBufferSize)); - } + random_ascii(kBufferSize, seed++, buffer); + ASSERT_OK(string_builder.Append(buffer, kBufferSize)); } std::shared_ptr v3; ASSERT_OK(string_builder.Finish(&v3)); - arrays->emplace_back(v1); - arrays->emplace_back(v2); arrays->emplace_back(v3); } TEST(TestJsonFileReadWrite, BasicRoundTrip) { - auto v1_type = int8(); - auto v2_type = int32(); - auto v3_type = utf8(); + auto v1_type = boolean(); + auto v2_type = int8(); + auto v3_type = int32(); + auto v4_type = uint64(); + auto v5_type = utf8(); auto schema = - ::arrow::schema({field("f1", v1_type), field("f2", v2_type), field("f3", v3_type)}); + ::arrow::schema({field("f1", v1_type), field("f2", v2_type), field("f3", v3_type), + field("f4", v4_type), field("f5", v5_type)}); std::unique_ptr writer; ASSERT_OK(JsonWriter::Open(schema, &writer)); @@ -289,7 +280,7 @@ TEST(TestJsonFileReadWrite, BasicRoundTrip) { for (int i = 0; i < nbatches; ++i) { std::shared_ptr batch; ASSERT_OK(reader->ReadRecordBatch(i, &batch)); - ASSERT_TRUE(batch->Equals(*batches[i])); + ASSERT_RECORD_BATCHES_EQUAL(*batch, *batches[i]); } } diff --git a/cpp/src/arrow/ipc/read-write-benchmark.cc b/cpp/src/arrow/ipc/read-write-benchmark.cc index ace2965b9001c..359cd0eb6ba8f 100644 --- a/cpp/src/arrow/ipc/read-write-benchmark.cc +++ b/cpp/src/arrow/ipc/read-write-benchmark.cc @@ -24,34 +24,15 @@ #include "arrow/api.h" #include "arrow/io/memory.h" #include "arrow/ipc/api.h" +#include "arrow/test-random.h" #include "arrow/test-util.h" namespace arrow { -template std::shared_ptr MakeRecordBatch(int64_t total_size, int64_t num_fields) { - using T = typename TYPE::c_type; - size_t itemsize = sizeof(T); - int64_t length = total_size / num_fields / itemsize; - - auto type = TypeTraits::type_singleton(); - - std::vector is_valid; - random_is_valid(length, 0.1, &is_valid); - - std::vector values; - randint(length, 0, 100, &values); - - typename TypeTraits::BuilderType builder(type, default_memory_pool()); - for (size_t i = 0; i < values.size(); ++i) { - if (is_valid[i]) { - ABORT_NOT_OK(builder.Append(values[i])); - } else { - ABORT_NOT_OK(builder.AppendNull()); - } - } - std::shared_ptr array; - ABORT_NOT_OK(builder.Finish(&array)); + int64_t length = total_size / num_fields / sizeof(int64_t); + 
random::RandomArrayGenerator rand(0x4f32a908); + auto type = arrow::int64(); ArrayVector arrays; std::vector> fields; @@ -59,7 +40,7 @@ std::shared_ptr MakeRecordBatch(int64_t total_size, int64_t num_fie std::stringstream ss; ss << "f" << i; fields.push_back(field(ss.str(), type)); - arrays.push_back(array); + arrays.push_back(rand.Int64(length, 0, 100, 0.1)); } auto schema = std::make_shared(fields); @@ -72,7 +53,7 @@ static void BM_WriteRecordBatch(benchmark::State& state) { // NOLINT non-const std::shared_ptr buffer; ABORT_NOT_OK(AllocateResizableBuffer(kTotalSize & 2, &buffer)); - auto record_batch = MakeRecordBatch(kTotalSize, state.range(0)); + auto record_batch = MakeRecordBatch(kTotalSize, state.range(0)); while (state.KeepRunning()) { io::BufferOutputStream stream(buffer); @@ -93,7 +74,7 @@ static void BM_ReadRecordBatch(benchmark::State& state) { // NOLINT non-const r std::shared_ptr buffer; ABORT_NOT_OK(AllocateResizableBuffer(kTotalSize & 2, &buffer)); - auto record_batch = MakeRecordBatch(kTotalSize, state.range(0)); + auto record_batch = MakeRecordBatch(kTotalSize, state.range(0)); io::BufferOutputStream stream(buffer); diff --git a/cpp/src/arrow/ipc/test-common.h b/cpp/src/arrow/ipc/test-common.h index 4f7de26e35e16..c9f577d60ed04 100644 --- a/cpp/src/arrow/ipc/test-common.h +++ b/cpp/src/arrow/ipc/test-common.h @@ -32,6 +32,7 @@ #include "arrow/pretty_print.h" #include "arrow/record_batch.h" #include "arrow/status.h" +#include "arrow/test-random.h" #include "arrow/test-util.h" #include "arrow/type.h" #include "arrow/util/bit-util.h" @@ -67,20 +68,12 @@ const auto kListListInt32 = list(kListInt32); Status MakeRandomInt32Array(int64_t length, bool include_nulls, MemoryPool* pool, std::shared_ptr* out, uint32_t seed = 0) { - std::shared_ptr data; - RETURN_NOT_OK(MakeRandomBuffer(length, pool, &data, seed)); - Int32Builder builder(int32(), pool); - RETURN_NOT_OK(builder.Resize(length)); - if (include_nulls) { - std::shared_ptr valid_bytes; - RETURN_NOT_OK(MakeRandomByteBuffer(length, pool, &valid_bytes)); - RETURN_NOT_OK(builder.AppendValues(reinterpret_cast(data->data()), - length, valid_bytes->data())); - return builder.Finish(out); - } - RETURN_NOT_OK( - builder.AppendValues(reinterpret_cast(data->data()), length)); - return builder.Finish(out); + random::RandomArrayGenerator rand(seed); + const double null_probability = include_nulls ? 0.5 : 0.0; + + *out = rand.Int32(length, 0, 1000, null_probability); + + return Status::OK(); } Status MakeRandomListArray(const std::shared_ptr& child_array, int num_lists, diff --git a/cpp/src/arrow/test-random.cc b/cpp/src/arrow/test-random.cc new file mode 100644 index 0000000000000..cb35bfd4282f8 --- /dev/null +++ b/cpp/src/arrow/test-random.cc @@ -0,0 +1,149 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/test-random.h" + +#include +#include +#include +#include + +#include + +#include "arrow/array.h" +#include "arrow/buffer.h" +#include "arrow/test-util.h" +#include "arrow/type.h" +#include "arrow/type_fwd.h" +#include "arrow/type_traits.h" +#include "arrow/util/bit-util.h" + +namespace arrow { +namespace random { + +template +struct GenerateOptions { + GenerateOptions(SeedType seed, ValueType min, ValueType max, double probability) + : min_(min), max_(max), seed_(seed), probability_(probability) {} + + void GenerateData(uint8_t* buffer, size_t n) { + std::default_random_engine rng(seed_++); + DistributionType dist(min_, max_); + + ValueType* data = reinterpret_cast(buffer); + + // A static cast is required due to the int16 -> int8 handling. + std::generate(data, data + n, + [&dist, &rng] { return static_cast(dist(rng)); }); + } + + void GenerateBitmap(uint8_t* buffer, size_t n, int64_t* null_count) { + int64_t count = 0; + std::default_random_engine rng(seed_++); + std::bernoulli_distribution dist(1.0 - probability_); + + for (size_t i = 0; i < n; i++) { + if (dist(rng)) { + BitUtil::SetBit(buffer, i); + } else { + count++; + } + } + + if (null_count != nullptr) *null_count = count; + } + + ValueType min_; + ValueType max_; + SeedType seed_; + double probability_; +}; + +std::shared_ptr RandomArrayGenerator::Boolean(int64_t size, double probability, + double null_probability) { + // The boolean generator does not care about the value distribution since it + // only calls the GenerateBitmap method. + using GenOpt = GenerateOptions>; + + std::vector> buffers{2}; + // Need 2 distinct generators such that probabilities are not shared. + GenOpt value_gen(seed(), 0, 1, probability); + GenOpt null_gen(seed(), 0, 1, null_probability); + + int64_t null_count = 0; + ABORT_NOT_OK(AllocateEmptyBitmap(size, &buffers[0])); + null_gen.GenerateBitmap(buffers[0]->mutable_data(), size, &null_count); + + ABORT_NOT_OK(AllocateEmptyBitmap(size, &buffers[1])); + value_gen.GenerateBitmap(buffers[1]->mutable_data(), size, nullptr); + + auto array_data = ArrayData::Make(arrow::boolean(), size, buffers, null_count); + return std::make_shared(array_data); +} + +template +static std::shared_ptr> GenerateNumericArray(int64_t size, + OptionType options) { + using CType = typename ArrowType::c_type; + auto type = TypeTraits::type_singleton(); + std::vector> buffers{2}; + + int64_t null_count = 0; + ABORT_NOT_OK(AllocateEmptyBitmap(size, &buffers[0])); + options.GenerateBitmap(buffers[0]->mutable_data(), size, &null_count); + + ABORT_NOT_OK(AllocateBuffer(sizeof(CType) * size, &buffers[1])) + options.GenerateData(buffers[1]->mutable_data(), size); + + auto array_data = ArrayData::Make(type, size, buffers, null_count); + return std::make_shared>(array_data); +} + +#define PRIMITIVE_RAND_IMPL(Name, CType, ArrowType, Distribution) \ + std::shared_ptr RandomArrayGenerator::Name(int64_t size, CType min, CType max, \ + double probability) { \ + using OptionType = GenerateOptions; \ + OptionType options(seed(), min, max, probability); \ + return GenerateNumericArray(size, options); \ + } + +#define PRIMITIVE_RAND_INTEGER_IMPL(Name, CType, ArrowType) \ + PRIMITIVE_RAND_IMPL(Name, CType, ArrowType, std::uniform_int_distribution) + +// Visual Studio does not implement uniform_int_distribution for char types. 
+PRIMITIVE_RAND_IMPL(UInt8, uint8_t, UInt8Type, std::uniform_int_distribution) +PRIMITIVE_RAND_IMPL(Int8, int8_t, Int8Type, std::uniform_int_distribution) + +PRIMITIVE_RAND_INTEGER_IMPL(UInt16, uint16_t, UInt16Type) +PRIMITIVE_RAND_INTEGER_IMPL(Int16, int16_t, Int16Type) +PRIMITIVE_RAND_INTEGER_IMPL(UInt32, uint32_t, UInt32Type) +PRIMITIVE_RAND_INTEGER_IMPL(Int32, int32_t, Int32Type) +PRIMITIVE_RAND_INTEGER_IMPL(UInt64, uint64_t, UInt64Type) +PRIMITIVE_RAND_INTEGER_IMPL(Int64, int64_t, Int64Type) + +#define PRIMITIVE_RAND_FLOAT_IMPL(Name, CType, ArrowType) \ + PRIMITIVE_RAND_IMPL(Name, CType, ArrowType, std::uniform_real_distribution) + +PRIMITIVE_RAND_FLOAT_IMPL(Float32, float, FloatType) +PRIMITIVE_RAND_FLOAT_IMPL(Float64, double, DoubleType) + +#undef PRIMITIVE_RAND_INTEGER_IMPL +#undef PRIMITIVE_RAND_FLOAT_IMPL +#undef PRIMITIVE_RAND_IMPL + +} // namespace random +} // namespace arrow diff --git a/cpp/src/arrow/test-random.h b/cpp/src/arrow/test-random.h new file mode 100644 index 0000000000000..dc57dcab0251f --- /dev/null +++ b/cpp/src/arrow/test-random.h @@ -0,0 +1,169 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include +#include +#include +#include + +#include "arrow/util/visibility.h" + +namespace arrow { + +class Array; + +namespace random { + +using SeedType = std::random_device::result_type; +constexpr SeedType kSeedMax = std::numeric_limits::max(); + +class ARROW_EXPORT RandomArrayGenerator { + public: + explicit RandomArrayGenerator(SeedType seed) + : seed_distribution_(static_cast(1), kSeedMax), seed_rng_(seed) {} + + /// \brief Generates a random BooleanArray + /// + /// \param[in] size the size of the array to generate + /// \param[in] probability the estimated number of active bits + /// \param[in] null_probability the probability of a row being null + /// + /// \return a generated Array + std::shared_ptr Boolean(int64_t size, double probability, + double null_probability); + + /// \brief Generates a random UInt8Array + /// + /// \param[in] size the size of the array to generate + /// \param[in] min the lower bound of the uniform distribution + /// \param[in] max the upper bound of the uniform distribution + /// \param[in] null_probability the probability of a row being null + /// + /// \return a generated Array + std::shared_ptr UInt8(int64_t size, uint8_t min, uint8_t max, + double null_probability); + + /// \brief Generates a random Int8Array + /// + /// \param[in] size the size of the array to generate + /// \param[in] min the lower bound of the uniform distribution + /// \param[in] max the upper bound of the uniform distribution + /// \param[in] null_probability the probability of a row being null + /// + /// \return a generated Array + std::shared_ptr Int8(int64_t size, int8_t min, int8_t max, + double null_probability); + + /// \brief Generates a random UInt16Array + /// + /// \param[in] size the size of the array to generate + /// \param[in] min the lower bound of the uniform distribution + /// \param[in] max the upper bound of the uniform distribution + /// \param[in] null_probability the probability of a row being null + /// + /// \return a generated Array + std::shared_ptr UInt16(int64_t size, uint16_t min, uint16_t max, + double null_probability); + + /// \brief Generates a random Int16Array + /// + /// \param[in] size the size of the array to generate + /// \param[in] min the lower bound of the uniform distribution + /// \param[in] max the upper bound of the uniform distribution + /// \param[in] null_probability the probability of a row being null + /// + /// \return a generated Array + std::shared_ptr Int16(int64_t size, int16_t min, int16_t max, + double null_probability); + + /// \brief Generates a random UInt32Array + /// + /// \param[in] size the size of the array to generate + /// \param[in] min the lower bound of the uniform distribution + /// \param[in] max the upper bound of the uniform distribution + /// \param[in] null_probability the probability of a row being null + /// + /// \return a generated Array + std::shared_ptr UInt32(int64_t size, uint32_t min, uint32_t max, + double null_probability); + + /// \brief Generates a random Int32Array + /// + /// \param[in] size the size of the array to generate + /// \param[in] min the lower bound of the uniform distribution + /// \param[in] max the upper bound of the uniform distribution + /// \param[in] null_probability the probability of a row being null + /// + /// \return a generated Array + std::shared_ptr Int32(int64_t size, int32_t min, int32_t max, + double null_probability); + + /// \brief Generates a random UInt64Array + /// + /// \param[in] size the size of the array to generate + /// 
\param[in] min the lower bound of the uniform distribution + /// \param[in] max the upper bound of the uniform distribution + /// \param[in] null_probability the probability of a row being null + /// + /// \return a generated Array + std::shared_ptr UInt64(int64_t size, uint64_t min, uint64_t max, + double null_probability); + + /// \brief Generates a random Int64Array + /// + /// \param[in] size the size of the array to generate + /// \param[in] min the lower bound of the uniform distribution + /// \param[in] max the upper bound of the uniform distribution + /// \param[in] null_probability the probability of a row being null + /// + /// \return a generated Array + std::shared_ptr Int64(int64_t size, int64_t min, int64_t max, + double null_probability); + + /// \brief Generates a random FloatArray + /// + /// \param[in] size the size of the array to generate + /// \param[in] min the lower bound of the uniform distribution + /// \param[in] max the upper bound of the uniform distribution + /// \param[in] null_probability the probability of a row being null + /// + /// \return a generated Array + std::shared_ptr Float32(int64_t size, float min, float max, + double null_probability); + + /// \brief Generates a random DoubleArray + /// + /// \param[in] size the size of the array to generate + /// \param[in] min the lower bound of the uniform distribution + /// \param[in] max the upper bound of the uniform distribution + /// \param[in] null_probability the probability of a row being null + /// + /// \return a generated Array + std::shared_ptr Float64(int64_t size, double min, double max, + double null_probability); + + private: + SeedType seed() { return seed_distribution_(seed_rng_); } + + std::uniform_int_distribution seed_distribution_; + std::default_random_engine seed_rng_; +}; + +} // namespace random +} // namespace arrow diff --git a/cpp/src/arrow/test-util.h b/cpp/src/arrow/test-util.h index 713ff38ca5283..546cc4e261ee8 100644 --- a/cpp/src/arrow/test-util.h +++ b/cpp/src/arrow/test-util.h @@ -100,7 +100,7 @@ class Table; using ArrayVector = std::vector>; -#define ASSERT_ARRAYS_EQUAL(LEFT, RIGHT) \ +#define ASSERT_PP_EQUAL(LEFT, RIGHT) \ do { \ if (!(LEFT).Equals((RIGHT))) { \ std::stringstream pp_result; \ @@ -112,6 +112,9 @@ using ArrayVector = std::vector>; } \ } while (false) +#define ASSERT_ARRAYS_EQUAL(lhs, rhs) ASSERT_PP_EQUAL(lhs, rhs) +#define ASSERT_RECORD_BATCHES_EQUAL(lhs, rhs) ASSERT_PP_EQUAL(lhs, rhs) + template void randint(int64_t N, T lower, T upper, std::vector* out) { const int random_seed = 0; From 7ce26553b8ce78085751e4a4ae603d4043abf337 Mon Sep 17 00:00:00 2001 From: Wes McKinney Date: Tue, 5 Feb 2019 15:15:43 +0100 Subject: [PATCH 16/21] PARQUET-1521: [C++] Use pure virtual interfaces for parquet::TypedColumnWriter, remove use of 'extern template class' This follows corresponding work in TypedColumnReader. 
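
As a point of reference, client code drives the writer through the public API roughly as sketched below; the `WriteExampleColumn` helper and the sample values are illustrative, while the `ColumnWriter::Make`, `WriteBatch`, and `Close` calls are the public surface this refactor keeps, with the surrounding setup assumed to happen elsewhere (for example as in column-io-benchmark.cc):

```cpp
#include <memory>
#include <vector>

#include "parquet/column_writer.h"

// Sketch: writing an INT64 column through the public API. `metadata`, `pager`,
// and `properties` are assumed to be created elsewhere; only the calls on the
// writer below are part of the stable public surface.
void WriteExampleColumn(parquet::ColumnChunkMetaDataBuilder* metadata,
                        std::unique_ptr<parquet::PageWriter> pager,
                        const parquet::WriterProperties* properties) {
  std::shared_ptr<parquet::ColumnWriter> writer =
      parquet::ColumnWriter::Make(metadata, std::move(pager), properties);
  auto int64_writer = std::static_pointer_cast<parquet::Int64Writer>(writer);

  std::vector<int64_t> values = {1, 2, 3};
  std::vector<int16_t> def_levels(values.size(), 1);
  std::vector<int16_t> rep_levels(values.size(), 0);
  int64_writer->WriteBatch(static_cast<int64_t>(values.size()), def_levels.data(),
                           rep_levels.data(), values.data());
  int64_writer->Close();
}
```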
The public API is unchanged as can be verified by lack of changes to the unit tests Author: Wes McKinney Closes #3551 from wesm/PARQUET-1521 and squashes the following commits: aa6687a9 Fix clang warnings 33555044 Print build warning level b657ac93 Fix parquet-column-io-benchmark 61204dec Refactor TypedColumnWriter implementation to be based on pure virtual interface, remove use of extern template class --- cpp/cmake_modules/SetupCxxFlags.cmake | 2 + cpp/src/parquet/column-io-benchmark.cc | 26 +- cpp/src/parquet/column_writer.cc | 450 ++++++++++++++++--------- cpp/src/parquet/column_writer.h | 208 ++---------- 4 files changed, 350 insertions(+), 336 deletions(-) diff --git a/cpp/cmake_modules/SetupCxxFlags.cmake b/cpp/cmake_modules/SetupCxxFlags.cmake index 44ca22f5dacb2..43dab02bb613c 100644 --- a/cpp/cmake_modules/SetupCxxFlags.cmake +++ b/cpp/cmake_modules/SetupCxxFlags.cmake @@ -111,6 +111,8 @@ if (NOT BUILD_WARNING_LEVEL) endif(NOT BUILD_WARNING_LEVEL) string(TOUPPER ${BUILD_WARNING_LEVEL} BUILD_WARNING_LEVEL) +message(STATUS "Arrow build warning level: ${BUILD_WARNING_LEVEL}") + if ("${BUILD_WARNING_LEVEL}" STREQUAL "CHECKIN") # Pre-checkin builds if ("${COMPILER_FAMILY}" STREQUAL "msvc") diff --git a/cpp/src/parquet/column-io-benchmark.cc b/cpp/src/parquet/column-io-benchmark.cc index c648d562649d1..762bcb74c01bc 100644 --- a/cpp/src/parquet/column-io-benchmark.cc +++ b/cpp/src/parquet/column-io-benchmark.cc @@ -30,14 +30,15 @@ using schema::PrimitiveNode; namespace benchmark { -std::unique_ptr BuildWriter(int64_t output_size, OutputStream* dst, +std::shared_ptr BuildWriter(int64_t output_size, OutputStream* dst, ColumnChunkMetaDataBuilder* metadata, ColumnDescriptor* schema, const WriterProperties* properties) { std::unique_ptr pager = PageWriter::Open(dst, Compression::UNCOMPRESSED, metadata); - return std::unique_ptr(new Int64Writer( - metadata, std::move(pager), false /*use_dictionary*/, Encoding::PLAIN, properties)); + std::shared_ptr writer = + ColumnWriter::Make(metadata, std::move(pager), properties); + return std::static_pointer_cast(writer); } std::shared_ptr Int64Schema(Repetition::type repetition) { @@ -65,14 +66,17 @@ static void BM_WriteInt64Column(::benchmark::State& state) { std::vector definition_levels(state.range(0), 1); std::vector repetition_levels(state.range(0), 0); std::shared_ptr schema = Int64Schema(repetition); - WriterProperties::Builder builder; - std::shared_ptr properties = builder.compression(codec)->build(); + std::shared_ptr properties = WriterProperties::Builder() + .compression(codec) + ->encoding(Encoding::PLAIN) + ->disable_dictionary() + ->build(); auto metadata = ColumnChunkMetaDataBuilder::Make( properties, schema.get(), reinterpret_cast(&thrift_metadata)); while (state.KeepRunning()) { InMemoryOutputStream stream; - std::unique_ptr writer = BuildWriter( + std::shared_ptr writer = BuildWriter( state.range(0), &stream, metadata.get(), schema.get(), properties.get()); writer->WriteBatch(values.size(), definition_levels.data(), repetition_levels.data(), values.data()); @@ -125,13 +129,17 @@ static void BM_ReadInt64Column(::benchmark::State& state) { std::vector definition_levels(state.range(0), 1); std::vector repetition_levels(state.range(0), 0); std::shared_ptr schema = Int64Schema(repetition); - WriterProperties::Builder builder; - std::shared_ptr properties = builder.compression(codec)->build(); + std::shared_ptr properties = WriterProperties::Builder() + .compression(codec) + ->encoding(Encoding::PLAIN) + ->disable_dictionary() + ->build(); 
+ auto metadata = ColumnChunkMetaDataBuilder::Make( properties, schema.get(), reinterpret_cast(&thrift_metadata)); InMemoryOutputStream stream; - std::unique_ptr writer = BuildWriter( + std::shared_ptr writer = BuildWriter( state.range(0), &stream, metadata.get(), schema.get(), properties.get()); writer->WriteBatch(values.size(), definition_levels.data(), repetition_levels.data(), values.data()); diff --git a/cpp/src/parquet/column_writer.cc b/cpp/src/parquet/column_writer.cc index 0919a3f1d7a65..47a125620b738 100644 --- a/cpp/src/parquet/column_writer.cc +++ b/cpp/src/parquet/column_writer.cc @@ -353,57 +353,148 @@ std::shared_ptr default_writer_properties() { return default_writer_properties; } -ColumnWriter::ColumnWriter(ColumnChunkMetaDataBuilder* metadata, - std::unique_ptr pager, bool has_dictionary, - Encoding::type encoding, const WriterProperties* properties) - : metadata_(metadata), - descr_(metadata->descr()), - pager_(std::move(pager)), - has_dictionary_(has_dictionary), - encoding_(encoding), - properties_(properties), - allocator_(properties->memory_pool()), - num_buffered_values_(0), - num_buffered_encoded_values_(0), - rows_written_(0), - total_bytes_written_(0), - total_compressed_bytes_(0), - closed_(false), - fallback_(false) { - definition_levels_sink_.reset(new InMemoryOutputStream(allocator_)); - repetition_levels_sink_.reset(new InMemoryOutputStream(allocator_)); - definition_levels_rle_ = - std::static_pointer_cast(AllocateBuffer(allocator_, 0)); - repetition_levels_rle_ = - std::static_pointer_cast(AllocateBuffer(allocator_, 0)); - uncompressed_data_ = - std::static_pointer_cast(AllocateBuffer(allocator_, 0)); - if (pager_->has_compressor()) { - compressed_data_ = +class ColumnWriterImpl { + public: + ColumnWriterImpl(ColumnChunkMetaDataBuilder* metadata, + std::unique_ptr pager, const bool use_dictionary, + Encoding::type encoding, const WriterProperties* properties) + : metadata_(metadata), + descr_(metadata->descr()), + pager_(std::move(pager)), + has_dictionary_(use_dictionary), + encoding_(encoding), + properties_(properties), + allocator_(properties->memory_pool()), + num_buffered_values_(0), + num_buffered_encoded_values_(0), + rows_written_(0), + total_bytes_written_(0), + total_compressed_bytes_(0), + closed_(false), + fallback_(false) { + definition_levels_sink_.reset(new InMemoryOutputStream(allocator_)); + repetition_levels_sink_.reset(new InMemoryOutputStream(allocator_)); + definition_levels_rle_ = + std::static_pointer_cast(AllocateBuffer(allocator_, 0)); + repetition_levels_rle_ = std::static_pointer_cast(AllocateBuffer(allocator_, 0)); + uncompressed_data_ = + std::static_pointer_cast(AllocateBuffer(allocator_, 0)); + if (pager_->has_compressor()) { + compressed_data_ = + std::static_pointer_cast(AllocateBuffer(allocator_, 0)); + } } -} -void ColumnWriter::InitSinks() { - definition_levels_sink_->Clear(); - repetition_levels_sink_->Clear(); -} + virtual ~ColumnWriterImpl() = default; -void ColumnWriter::WriteDefinitionLevels(int64_t num_levels, const int16_t* levels) { - DCHECK(!closed_); - definition_levels_sink_->Write(reinterpret_cast(levels), - sizeof(int16_t) * num_levels); -} + int64_t Close(); -void ColumnWriter::WriteRepetitionLevels(int64_t num_levels, const int16_t* levels) { - DCHECK(!closed_); - repetition_levels_sink_->Write(reinterpret_cast(levels), - sizeof(int16_t) * num_levels); -} + protected: + virtual std::shared_ptr GetValuesBuffer() = 0; + + // Serializes Dictionary Page if enabled + virtual void WriteDictionaryPage() = 0; + + 
// Plain-encoded statistics of the current page + virtual EncodedStatistics GetPageStatistics() = 0; + + // Plain-encoded statistics of the whole chunk + virtual EncodedStatistics GetChunkStatistics() = 0; + + // Merges page statistics into chunk statistics, then resets the values + virtual void ResetPageStatistics() = 0; + + // Adds Data Pages to an in memory buffer in dictionary encoding mode + // Serializes the Data Pages in other encoding modes + void AddDataPage(); + + // Serializes Data Pages + void WriteDataPage(const CompressedDataPage& page); + + // Write multiple definition levels + void WriteDefinitionLevels(int64_t num_levels, const int16_t* levels) { + DCHECK(!closed_); + definition_levels_sink_->Write(reinterpret_cast(levels), + sizeof(int16_t) * num_levels); + } + + // Write multiple repetition levels + void WriteRepetitionLevels(int64_t num_levels, const int16_t* levels) { + DCHECK(!closed_); + repetition_levels_sink_->Write(reinterpret_cast(levels), + sizeof(int16_t) * num_levels); + } + + // RLE encode the src_buffer into dest_buffer and return the encoded size + int64_t RleEncodeLevels(const Buffer& src_buffer, ResizableBuffer* dest_buffer, + int16_t max_level); + + // Serialize the buffered Data Pages + void FlushBufferedDataPages(); + + ColumnChunkMetaDataBuilder* metadata_; + const ColumnDescriptor* descr_; + + std::unique_ptr pager_; + + bool has_dictionary_; + Encoding::type encoding_; + const WriterProperties* properties_; + + LevelEncoder level_encoder_; + + ::arrow::MemoryPool* allocator_; + + // The total number of values stored in the data page. This is the maximum of + // the number of encoded definition levels or encoded values. For + // non-repeated, required columns, this is equal to the number of encoded + // values. For repeated or optional values, there may be fewer data values + // than levels, and this tells you how many encoded levels there are in that + // case. + int64_t num_buffered_values_; + + // The total number of stored values. For repeated or optional values, this + // number may be lower than num_buffered_values_. 
+ int64_t num_buffered_encoded_values_; + + // Total number of rows written with this ColumnWriter + int rows_written_; + + // Records the total number of bytes written by the serializer + int64_t total_bytes_written_; + + // Records the current number of compressed bytes in a column + int64_t total_compressed_bytes_; + + // Flag to check if the Writer has been closed + bool closed_; + + // Flag to infer if dictionary encoding has fallen back to PLAIN + bool fallback_; + + std::unique_ptr definition_levels_sink_; + std::unique_ptr repetition_levels_sink_; + + std::shared_ptr definition_levels_rle_; + std::shared_ptr repetition_levels_rle_; + + std::shared_ptr uncompressed_data_; + std::shared_ptr compressed_data_; + + std::vector data_pages_; + + private: + void InitSinks() { + definition_levels_sink_->Clear(); + repetition_levels_sink_->Clear(); + } +}; // return the size of the encoded buffer -int64_t ColumnWriter::RleEncodeLevels(const Buffer& src_buffer, - ResizableBuffer* dest_buffer, int16_t max_level) { +int64_t ColumnWriterImpl::RleEncodeLevels(const Buffer& src_buffer, + ResizableBuffer* dest_buffer, + int16_t max_level) { // TODO: This only works with due to some RLE specifics int64_t rle_size = LevelEncoder::MaxBufferSize(Encoding::RLE, max_level, static_cast(num_buffered_values_)) + @@ -425,7 +516,7 @@ int64_t ColumnWriter::RleEncodeLevels(const Buffer& src_buffer, return encoded_size; } -void ColumnWriter::AddDataPage() { +void ColumnWriterImpl::AddDataPage() { int64_t definition_levels_rle_size = 0; int64_t repetition_levels_rle_size = 0; @@ -493,11 +584,11 @@ void ColumnWriter::AddDataPage() { num_buffered_encoded_values_ = 0; } -void ColumnWriter::WriteDataPage(const CompressedDataPage& page) { +void ColumnWriterImpl::WriteDataPage(const CompressedDataPage& page) { total_bytes_written_ += pager_->WriteDataPage(page); } -int64_t ColumnWriter::Close() { +int64_t ColumnWriterImpl::Close() { if (!closed_) { closed_ = true; if (has_dictionary_ && !fallback_) { @@ -525,7 +616,7 @@ int64_t ColumnWriter::Close() { return total_bytes_written_; } -void ColumnWriter::FlushBufferedDataPages() { +void ColumnWriterImpl::FlushBufferedDataPages() { // Write all outstanding data to a new page if (num_buffered_values_ > 0) { AddDataPage(); @@ -540,47 +631,123 @@ void ColumnWriter::FlushBufferedDataPages() { // ---------------------------------------------------------------------- // TypedColumnWriter -template -TypedColumnWriter::TypedColumnWriter(ColumnChunkMetaDataBuilder* metadata, - std::unique_ptr pager, - const bool use_dictionary, - Encoding::type encoding, - const WriterProperties* properties) - : ColumnWriter(metadata, std::move(pager), use_dictionary, encoding, properties) { - current_encoder_ = MakeEncoder(Type::type_num, encoding, use_dictionary, descr_, - properties->memory_pool()); - - if (properties->statistics_enabled(descr_->path()) && - (SortOrder::UNKNOWN != descr_->sort_order())) { - page_statistics_ = std::unique_ptr(new TypedStats(descr_, allocator_)); - chunk_statistics_ = std::unique_ptr(new TypedStats(descr_, allocator_)); +template +class TypedColumnWriterImpl : public ColumnWriterImpl, public TypedColumnWriter { + public: + using T = typename DType::c_type; + + TypedColumnWriterImpl(ColumnChunkMetaDataBuilder* metadata, + std::unique_ptr pager, const bool use_dictionary, + Encoding::type encoding, const WriterProperties* properties) + : ColumnWriterImpl(metadata, std::move(pager), use_dictionary, encoding, + properties) { + current_encoder_ = 
MakeEncoder(DType::type_num, encoding, use_dictionary, descr_, + properties->memory_pool()); + + if (properties->statistics_enabled(descr_->path()) && + (SortOrder::UNKNOWN != descr_->sort_order())) { + page_statistics_ = std::unique_ptr(new TypedStats(descr_, allocator_)); + chunk_statistics_ = std::unique_ptr(new TypedStats(descr_, allocator_)); + } + } + + int64_t Close() override { return ColumnWriterImpl::Close(); } + + void WriteBatch(int64_t num_values, const int16_t* def_levels, + const int16_t* rep_levels, const T* values) override; + + void WriteBatchSpaced(int64_t num_values, const int16_t* def_levels, + const int16_t* rep_levels, const uint8_t* valid_bits, + int64_t valid_bits_offset, const T* values) override; + + int64_t EstimatedBufferedValueBytes() const override { + return current_encoder_->EstimatedDataEncodedSize(); } -} + + protected: + std::shared_ptr GetValuesBuffer() override { + return current_encoder_->FlushValues(); + } + void WriteDictionaryPage() override; + + // Checks if the Dictionary Page size limit is reached + // If the limit is reached, the Dictionary and Data Pages are serialized + // The encoding is switched to PLAIN + void CheckDictionarySizeLimit(); + + EncodedStatistics GetPageStatistics() override { + EncodedStatistics result; + if (page_statistics_) result = page_statistics_->Encode(); + return result; + } + + EncodedStatistics GetChunkStatistics() override { + EncodedStatistics result; + if (chunk_statistics_) result = chunk_statistics_->Encode(); + return result; + } + + void ResetPageStatistics() override; + + Type::type type() const override { return descr_->physical_type(); } + + const ColumnDescriptor* descr() const override { return descr_; } + + int64_t rows_written() const override { return rows_written_; } + + int64_t total_compressed_bytes() const override { return total_compressed_bytes_; } + + int64_t total_bytes_written() const override { return total_bytes_written_; } + + const WriterProperties* properties() override { return properties_; } + + private: + inline int64_t WriteMiniBatch(int64_t num_values, const int16_t* def_levels, + const int16_t* rep_levels, const T* values); + + inline int64_t WriteMiniBatchSpaced(int64_t num_values, const int16_t* def_levels, + const int16_t* rep_levels, + const uint8_t* valid_bits, + int64_t valid_bits_offset, const T* values, + int64_t* num_spaced_written); + + // Write values to a temporary buffer before they are encoded into pages + void WriteValues(int64_t num_values, const T* values); + void WriteValuesSpaced(int64_t num_values, const uint8_t* valid_bits, + int64_t valid_bits_offset, const T* values); + + using ValueEncoderType = typename EncodingTraits::Encoder; + std::unique_ptr current_encoder_; + + typedef TypedRowGroupStatistics TypedStats; + std::unique_ptr page_statistics_; + std::unique_ptr chunk_statistics_; +}; // Only one Dictionary Page is written. // Fallback to PLAIN if dictionary page limit is reached. 
-template -void TypedColumnWriter::CheckDictionarySizeLimit() { +template +void TypedColumnWriterImpl::CheckDictionarySizeLimit() { // We have to dynamic cast here because TypedEncoder as some compilers // don't want to cast through virtual inheritance - auto dict_encoder = dynamic_cast*>(current_encoder_.get()); + auto dict_encoder = dynamic_cast*>(current_encoder_.get()); if (dict_encoder->dict_encoded_size() >= properties_->dictionary_pagesize_limit()) { WriteDictionaryPage(); // Serialize the buffered Dictionary Indicies FlushBufferedDataPages(); fallback_ = true; // Only PLAIN encoding is supported for fallback in V1 - current_encoder_ = MakeEncoder(Type::type_num, Encoding::PLAIN, false, descr_, + current_encoder_ = MakeEncoder(DType::type_num, Encoding::PLAIN, false, descr_, properties_->memory_pool()); encoding_ = Encoding::PLAIN; } } -template -void TypedColumnWriter::WriteDictionaryPage() { +template +void TypedColumnWriterImpl::WriteDictionaryPage() { // We have to dynamic cast here because TypedEncoder as some compilers // don't want to cast through virtual inheritance - auto dict_encoder = dynamic_cast*>(current_encoder_.get()); + auto dict_encoder = dynamic_cast*>(current_encoder_.get()); DCHECK(dict_encoder); std::shared_ptr buffer = AllocateBuffer(properties_->memory_pool(), dict_encoder->dict_encoded_size()); @@ -591,81 +758,22 @@ void TypedColumnWriter::WriteDictionaryPage() { total_bytes_written_ += pager_->WriteDictionaryPage(page); } -template -EncodedStatistics TypedColumnWriter::GetPageStatistics() { - EncodedStatistics result; - if (page_statistics_) result = page_statistics_->Encode(); - return result; -} - -template -EncodedStatistics TypedColumnWriter::GetChunkStatistics() { - EncodedStatistics result; - if (chunk_statistics_) result = chunk_statistics_->Encode(); - return result; -} - -template -void TypedColumnWriter::ResetPageStatistics() { +template +void TypedColumnWriterImpl::ResetPageStatistics() { if (chunk_statistics_ != nullptr) { chunk_statistics_->Merge(*page_statistics_); page_statistics_->Reset(); } } -// ---------------------------------------------------------------------- -// Dynamic column writer constructor - -std::shared_ptr ColumnWriter::Make(ColumnChunkMetaDataBuilder* metadata, - std::unique_ptr pager, - const WriterProperties* properties) { - const ColumnDescriptor* descr = metadata->descr(); - const bool use_dictionary = properties->dictionary_enabled(descr->path()) && - descr->physical_type() != Type::BOOLEAN; - Encoding::type encoding = properties->encoding(descr->path()); - if (use_dictionary) { - encoding = properties->dictionary_index_encoding(); - } - switch (descr->physical_type()) { - case Type::BOOLEAN: - return std::make_shared(metadata, std::move(pager), use_dictionary, - encoding, properties); - case Type::INT32: - return std::make_shared(metadata, std::move(pager), use_dictionary, - encoding, properties); - case Type::INT64: - return std::make_shared(metadata, std::move(pager), use_dictionary, - encoding, properties); - case Type::INT96: - return std::make_shared(metadata, std::move(pager), use_dictionary, - encoding, properties); - case Type::FLOAT: - return std::make_shared(metadata, std::move(pager), use_dictionary, - encoding, properties); - case Type::DOUBLE: - return std::make_shared(metadata, std::move(pager), use_dictionary, - encoding, properties); - case Type::BYTE_ARRAY: - return std::make_shared(metadata, std::move(pager), use_dictionary, - encoding, properties); - case Type::FIXED_LEN_BYTE_ARRAY: - return 
std::make_shared( - metadata, std::move(pager), use_dictionary, encoding, properties); - default: - ParquetException::NYI("type reader not implemented"); - } - // Unreachable code, but supress compiler warning - return std::shared_ptr(nullptr); -} - // ---------------------------------------------------------------------- // Instantiate templated classes template -inline int64_t TypedColumnWriter::WriteMiniBatch(int64_t num_values, - const int16_t* def_levels, - const int16_t* rep_levels, - const T* values) { +int64_t TypedColumnWriterImpl::WriteMiniBatch(int64_t num_values, + const int16_t* def_levels, + const int16_t* rep_levels, + const T* values) { int64_t values_to_write = 0; // If the field is required and non-repeated, there are no definition levels if (descr_->max_definition_level() > 0) { @@ -722,7 +830,7 @@ inline int64_t TypedColumnWriter::WriteMiniBatch(int64_t num_values, } template -inline int64_t TypedColumnWriter::WriteMiniBatchSpaced( +int64_t TypedColumnWriterImpl::WriteMiniBatchSpaced( int64_t num_levels, const int16_t* def_levels, const int16_t* rep_levels, const uint8_t* valid_bits, int64_t valid_bits_offset, const T* values, int64_t* num_spaced_written) { @@ -793,8 +901,10 @@ inline int64_t TypedColumnWriter::WriteMiniBatchSpaced( } template -void TypedColumnWriter::WriteBatch(int64_t num_values, const int16_t* def_levels, - const int16_t* rep_levels, const T* values) { +void TypedColumnWriterImpl::WriteBatch(int64_t num_values, + const int16_t* def_levels, + const int16_t* rep_levels, + const T* values) { // We check for DataPage limits only after we have inserted the values. If a user // writes a large number of values, the DataPage size can be much above the limit. // The purpose of this chunking is to bound this. Even if a user writes large number @@ -817,7 +927,7 @@ void TypedColumnWriter::WriteBatch(int64_t num_values, const int16_t* def } template -void TypedColumnWriter::WriteBatchSpaced( +void TypedColumnWriterImpl::WriteBatchSpaced( int64_t num_values, const int16_t* def_levels, const int16_t* rep_levels, const uint8_t* valid_bits, int64_t valid_bits_offset, const T* values) { // We check for DataPage limits only after we have inserted the values. 
If a user @@ -845,27 +955,63 @@ void TypedColumnWriter::WriteBatchSpaced( } template -void TypedColumnWriter::WriteValues(int64_t num_values, const T* values) { +void TypedColumnWriterImpl::WriteValues(int64_t num_values, const T* values) { dynamic_cast(current_encoder_.get()) ->Put(values, static_cast(num_values)); } template -void TypedColumnWriter::WriteValuesSpaced(int64_t num_values, - const uint8_t* valid_bits, - int64_t valid_bits_offset, - const T* values) { +void TypedColumnWriterImpl::WriteValuesSpaced(int64_t num_values, + const uint8_t* valid_bits, + int64_t valid_bits_offset, + const T* values) { dynamic_cast(current_encoder_.get()) ->PutSpaced(values, static_cast(num_values), valid_bits, valid_bits_offset); } -template class PARQUET_TEMPLATE_EXPORT TypedColumnWriter; -template class PARQUET_TEMPLATE_EXPORT TypedColumnWriter; -template class PARQUET_TEMPLATE_EXPORT TypedColumnWriter; -template class PARQUET_TEMPLATE_EXPORT TypedColumnWriter; -template class PARQUET_TEMPLATE_EXPORT TypedColumnWriter; -template class PARQUET_TEMPLATE_EXPORT TypedColumnWriter; -template class PARQUET_TEMPLATE_EXPORT TypedColumnWriter; -template class PARQUET_TEMPLATE_EXPORT TypedColumnWriter; +// ---------------------------------------------------------------------- +// Dynamic column writer constructor + +std::shared_ptr ColumnWriter::Make(ColumnChunkMetaDataBuilder* metadata, + std::unique_ptr pager, + const WriterProperties* properties) { + const ColumnDescriptor* descr = metadata->descr(); + const bool use_dictionary = properties->dictionary_enabled(descr->path()) && + descr->physical_type() != Type::BOOLEAN; + Encoding::type encoding = properties->encoding(descr->path()); + if (use_dictionary) { + encoding = properties->dictionary_index_encoding(); + } + switch (descr->physical_type()) { + case Type::BOOLEAN: + return std::make_shared>( + metadata, std::move(pager), use_dictionary, encoding, properties); + case Type::INT32: + return std::make_shared>( + metadata, std::move(pager), use_dictionary, encoding, properties); + case Type::INT64: + return std::make_shared>( + metadata, std::move(pager), use_dictionary, encoding, properties); + case Type::INT96: + return std::make_shared>( + metadata, std::move(pager), use_dictionary, encoding, properties); + case Type::FLOAT: + return std::make_shared>( + metadata, std::move(pager), use_dictionary, encoding, properties); + case Type::DOUBLE: + return std::make_shared>( + metadata, std::move(pager), use_dictionary, encoding, properties); + case Type::BYTE_ARRAY: + return std::make_shared>( + metadata, std::move(pager), use_dictionary, encoding, properties); + case Type::FIXED_LEN_BYTE_ARRAY: + return std::make_shared>( + metadata, std::move(pager), use_dictionary, encoding, properties); + default: + ParquetException::NYI("type reader not implemented"); + } + // Unreachable code, but supress compiler warning + return std::shared_ptr(nullptr); +} } // namespace parquet diff --git a/cpp/src/parquet/column_writer.h b/cpp/src/parquet/column_writer.h index 254bf0dd02e50..5b9efb43ae07c 100644 --- a/cpp/src/parquet/column_writer.h +++ b/cpp/src/parquet/column_writer.h @@ -105,147 +105,47 @@ class PARQUET_EXPORT PageWriter { static constexpr int WRITE_BATCH_SIZE = 1000; class PARQUET_EXPORT ColumnWriter { public: - ColumnWriter(ColumnChunkMetaDataBuilder*, std::unique_ptr, - bool has_dictionary, Encoding::type encoding, - const WriterProperties* properties); - virtual ~ColumnWriter() = default; static std::shared_ptr Make(ColumnChunkMetaDataBuilder*, 
std::unique_ptr, const WriterProperties* properties); - Type::type type() const { return descr_->physical_type(); } - - const ColumnDescriptor* descr() const { return descr_; } - - /** - * Closes the ColumnWriter, commits any buffered values to pages. - * - * @return Total size of the column in bytes - */ - int64_t Close(); - - int64_t rows_written() const { return rows_written_; } - - // Only considers the size of the compressed pages + page header - // Some values might be still buffered an not written to a page yet - int64_t total_compressed_bytes() const { return total_compressed_bytes_; } - - int64_t total_bytes_written() const { return total_bytes_written_; } - - const WriterProperties* properties() { return properties_; } - - protected: - virtual std::shared_ptr GetValuesBuffer() = 0; - - // Serializes Dictionary Page if enabled - virtual void WriteDictionaryPage() = 0; - - // Checks if the Dictionary Page size limit is reached - // If the limit is reached, the Dictionary and Data Pages are serialized - // The encoding is switched to PLAIN - - virtual void CheckDictionarySizeLimit() = 0; + /// \brief Closes the ColumnWriter, commits any buffered values to pages. + /// \return Total size of the column in bytes + virtual int64_t Close() = 0; - // Plain-encoded statistics of the current page - virtual EncodedStatistics GetPageStatistics() = 0; + /// \brief The physical Parquet type of the column + virtual Type::type type() const = 0; - // Plain-encoded statistics of the whole chunk - virtual EncodedStatistics GetChunkStatistics() = 0; - - // Merges page statistics into chunk statistics, then resets the values - virtual void ResetPageStatistics() = 0; - - // Adds Data Pages to an in memory buffer in dictionary encoding mode - // Serializes the Data Pages in other encoding modes - void AddDataPage(); - - // Serializes Data Pages - void WriteDataPage(const CompressedDataPage& page); - - // Write multiple definition levels - void WriteDefinitionLevels(int64_t num_levels, const int16_t* levels); - - // Write multiple repetition levels - void WriteRepetitionLevels(int64_t num_levels, const int16_t* levels); - - // RLE encode the src_buffer into dest_buffer and return the encoded size - int64_t RleEncodeLevels(const Buffer& src_buffer, ResizableBuffer* dest_buffer, - int16_t max_level); - - // Serialize the buffered Data Pages - void FlushBufferedDataPages(); - - ColumnChunkMetaDataBuilder* metadata_; - const ColumnDescriptor* descr_; - - std::unique_ptr pager_; - - bool has_dictionary_; - Encoding::type encoding_; - const WriterProperties* properties_; + /// \brief The schema for the column + virtual const ColumnDescriptor* descr() const = 0; - LevelEncoder level_encoder_; + /// \brief The number of rows written so far + virtual int64_t rows_written() const = 0; - ::arrow::MemoryPool* allocator_; + /// \brief The total size of the compressed pages + page headers. Some values + /// might be still buffered an not written to a page yet + virtual int64_t total_compressed_bytes() const = 0; - // The total number of values stored in the data page. This is the maximum of - // the number of encoded definition levels or encoded values. For - // non-repeated, required columns, this is equal to the number of encoded - // values. For repeated or optional values, there may be fewer data values - // than levels, and this tells you how many encoded levels there are in that - // case. 
- int64_t num_buffered_values_; + /// \brief The total number of bytes written as serialized data and + /// dictionary pages to the ColumnChunk so far + virtual int64_t total_bytes_written() const = 0; - // The total number of stored values. For repeated or optional values, this - // number may be lower than num_buffered_values_. - int64_t num_buffered_encoded_values_; - - // Total number of rows written with this ColumnWriter - int rows_written_; - - // Records the total number of bytes written by the serializer - int64_t total_bytes_written_; - - // Records the current number of compressed bytes in a column - int64_t total_compressed_bytes_; - - // Flag to check if the Writer has been closed - bool closed_; - - // Flag to infer if dictionary encoding has fallen back to PLAIN - bool fallback_; - - std::unique_ptr definition_levels_sink_; - std::unique_ptr repetition_levels_sink_; - - std::shared_ptr definition_levels_rle_; - std::shared_ptr repetition_levels_rle_; - - std::shared_ptr uncompressed_data_; - std::shared_ptr compressed_data_; - - std::vector data_pages_; - - private: - void InitSinks(); + /// \brief The file-level writer properties + virtual const WriterProperties* properties() = 0; }; // API to write values to a single column. This is the main client facing API. template -class PARQUET_TEMPLATE_CLASS_EXPORT TypedColumnWriter : public ColumnWriter { +class TypedColumnWriter : public ColumnWriter { public: - typedef typename DType::c_type T; - - TypedColumnWriter(ColumnChunkMetaDataBuilder* metadata, - std::unique_ptr pager, const bool use_dictionary, - Encoding::type encoding, const WriterProperties* properties); + using T = typename DType::c_type; // Write a batch of repetition levels, definition levels, and values to the // column. - void WriteBatch(int64_t num_values, const int16_t* def_levels, - const int16_t* rep_levels, const T* values); + virtual void WriteBatch(int64_t num_values, const int16_t* def_levels, + const int16_t* rep_levels, const T* values) = 0; /// Write a batch of repetition levels, definition levels, and values to the /// column. @@ -273,63 +173,21 @@ class PARQUET_TEMPLATE_CLASS_EXPORT TypedColumnWriter : public ColumnWriter { /// @param values The values in the lowest nested level including /// spacing for nulls on the lowest levels; input has the length /// of the number of rows on the lowest nesting level. 
- void WriteBatchSpaced(int64_t num_values, const int16_t* def_levels, - const int16_t* rep_levels, const uint8_t* valid_bits, - int64_t valid_bits_offset, const T* values); + virtual void WriteBatchSpaced(int64_t num_values, const int16_t* def_levels, + const int16_t* rep_levels, const uint8_t* valid_bits, + int64_t valid_bits_offset, const T* values) = 0; // Estimated size of the values that are not written to a page yet - int64_t EstimatedBufferedValueBytes() const { - return current_encoder_->EstimatedDataEncodedSize(); - } - - protected: - std::shared_ptr GetValuesBuffer() override { - return current_encoder_->FlushValues(); - } - void WriteDictionaryPage() override; - void CheckDictionarySizeLimit() override; - EncodedStatistics GetPageStatistics() override; - EncodedStatistics GetChunkStatistics() override; - void ResetPageStatistics() override; - - private: - int64_t WriteMiniBatch(int64_t num_values, const int16_t* def_levels, - const int16_t* rep_levels, const T* values); - - int64_t WriteMiniBatchSpaced(int64_t num_values, const int16_t* def_levels, - const int16_t* rep_levels, const uint8_t* valid_bits, - int64_t valid_bits_offset, const T* values, - int64_t* num_spaced_written); - - // Write values to a temporary buffer before they are encoded into pages - void WriteValues(int64_t num_values, const T* values); - void WriteValuesSpaced(int64_t num_values, const uint8_t* valid_bits, - int64_t valid_bits_offset, const T* values); - - using ValueEncoderType = typename EncodingTraits::Encoder; - std::unique_ptr current_encoder_; - - typedef TypedRowGroupStatistics TypedStats; - std::unique_ptr page_statistics_; - std::unique_ptr chunk_statistics_; + virtual int64_t EstimatedBufferedValueBytes() const = 0; }; -typedef TypedColumnWriter BoolWriter; -typedef TypedColumnWriter Int32Writer; -typedef TypedColumnWriter Int64Writer; -typedef TypedColumnWriter Int96Writer; -typedef TypedColumnWriter FloatWriter; -typedef TypedColumnWriter DoubleWriter; -typedef TypedColumnWriter ByteArrayWriter; -typedef TypedColumnWriter FixedLenByteArrayWriter; - -PARQUET_EXTERN_TEMPLATE TypedColumnWriter; -PARQUET_EXTERN_TEMPLATE TypedColumnWriter; -PARQUET_EXTERN_TEMPLATE TypedColumnWriter; -PARQUET_EXTERN_TEMPLATE TypedColumnWriter; -PARQUET_EXTERN_TEMPLATE TypedColumnWriter; -PARQUET_EXTERN_TEMPLATE TypedColumnWriter; -PARQUET_EXTERN_TEMPLATE TypedColumnWriter; -PARQUET_EXTERN_TEMPLATE TypedColumnWriter; +using BoolWriter = TypedColumnWriter; +using Int32Writer = TypedColumnWriter; +using Int64Writer = TypedColumnWriter; +using Int96Writer = TypedColumnWriter; +using FloatWriter = TypedColumnWriter; +using DoubleWriter = TypedColumnWriter; +using ByteArrayWriter = TypedColumnWriter; +using FixedLenByteArrayWriter = TypedColumnWriter; } // namespace parquet From c693dded7182a64818754cba1811f4fc96d44f7e Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 5 Feb 2019 08:31:10 -0600 Subject: [PATCH 17/21] ARROW-4460: [Website] DataFusion Blog Post Author: Andy Grove Closes #3548 from andygrove/ARROW-4460 and squashes the following commits: 5b3a77091 minor edit b5570608d Remove section on Gandiva, change date to today, link to JIRA for adding Parquet support 29c3dbec8 update blog post with link to open Rust issues in confluence a6fd60683 Minor edits 067cb0431 Add example to blog post 2a735fdcc Draft of DataFusion announcement --- site/_posts/2019-02-04-datafusion-donation.md | 119 ++++++++++++++++++ 1 file changed, 119 insertions(+) create mode 100644 site/_posts/2019-02-04-datafusion-donation.md diff 
--git a/site/_posts/2019-02-04-datafusion-donation.md b/site/_posts/2019-02-04-datafusion-donation.md new file mode 100644 index 0000000000000..9a7806cf00ccf --- /dev/null +++ b/site/_posts/2019-02-04-datafusion-donation.md @@ -0,0 +1,119 @@ +--- +layout: post +title: "DataFusion: A Rust-native Query Engine for Apache Arrow" +date: "2019-02-04 00:00:00 -0600" +author: agrove +categories: [application] +--- + + +We are excited to announce that [DataFusion](https://github.com/apache/arrow/tree/master/rust/datafusion) has been donated to the Apache Arrow project. DataFusion is an in-memory query engine for the Rust implementation of Apache Arrow. + +Although DataFusion was started two years ago, it was recently re-implemented to be Arrow-native and currently has limited capabilities but does support SQL queries against iterators of RecordBatch and has support for CSV files. There are plans to [add support for Parquet files](https://issues.apache.org/jira/browse/ARROW-4466). + +SQL support is limited to projection (`SELECT`), selection (`WHERE`), and simple aggregates (`MIN`, `MAX`, `SUM`) with an optional `GROUP BY` clause. + +Supported expressions are identifiers, literals, simple math operations (`+`, `-`, `*`, `/`), binary expressions (`AND`, `OR`), equality and comparison operators (`=`, `!=`, `<`, `<=`, `>=`, `>`), and `CAST(expr AS type)`. + +## Example + +The following example demonstrates running a simple aggregate SQL query against a CSV file. + +```rust +// create execution context +let mut ctx = ExecutionContext::new(); + +// define schema for data source (csv file) +let schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Utf8, false), + Field::new("c2", DataType::UInt32, false), + Field::new("c3", DataType::Int8, false), + Field::new("c4", DataType::Int16, false), + Field::new("c5", DataType::Int32, false), + Field::new("c6", DataType::Int64, false), + Field::new("c7", DataType::UInt8, false), + Field::new("c8", DataType::UInt16, false), + Field::new("c9", DataType::UInt32, false), + Field::new("c10", DataType::UInt64, false), + Field::new("c11", DataType::Float32, false), + Field::new("c12", DataType::Float64, false), + Field::new("c13", DataType::Utf8, false), +])); + +// register csv file with the execution context +let csv_datasource = + CsvDataSource::new("test/data/aggregate_test_100.csv", schema.clone(), 1024); +ctx.register_datasource("aggregate_test_100", Rc::new(RefCell::new(csv_datasource))); + +let sql = "SELECT c1, MIN(c12), MAX(c12) FROM aggregate_test_100 WHERE c11 > 0.1 AND c11 < 0.9 GROUP BY c1"; + +// execute the query +let relation = ctx.sql(&sql).unwrap(); +let mut results = relation.borrow_mut(); + +// iterate over the results +while let Some(batch) = results.next().unwrap() { + println!( + "RecordBatch has {} rows and {} columns", + batch.num_rows(), + batch.num_columns() + ); + + let c1 = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + + let min = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + + let max = batch + .column(2) + .as_any() + .downcast_ref::() + .unwrap(); + + for i in 0..batch.num_rows() { + let c1_value: String = String::from_utf8(c1.value(i).to_vec()).unwrap(); + println!("{}, Min: {}, Max: {}", c1_value, min.value(i), max.value(i),); + } +} +``` + +## Roadmap + +The roadmap for DataFusion will depend on interest from the Rust community, but here are some of the short term items that are planned: + +- Extending test coverage of the existing functionality +- Adding support for Parquet data 
sources +- Implementing more SQL features such as `JOIN`, `ORDER BY` and `LIMIT` +- Implement a DataFrame API as an alternative to SQL +- Adding support for partitioning and parallel query execution using Rust's async and await functionality +- Creating a Docker image to make it easy to use DataFusion as a standalone query tool for interactive and batch queries + +## Contributors Welcome! + +If you are excited about being able to use Rust for data science and would like to contribute to this work then there are many ways to get involved. The simplest way to get started is to try out DataFusion against your own data sources and file bug reports for any issues that you find. You could also check out the current [list of issues](https://cwiki.apache.org/confluence/display/ARROW/Rust+JIRA+Dashboard) and have a go at fixing one. You can also join the [user mailing list](http://mail-archives.apache.org/mod_mbox/arrow-user/) to ask questions. + + From 9af5a707b4527dc146095d610ea01e6bc3d34ce3 Mon Sep 17 00:00:00 2001 From: Wes McKinney Date: Tue, 5 Feb 2019 08:41:30 -0600 Subject: [PATCH 18/21] ARROW-4472: [Website][Python] Blog post about string memory use work in Arrow 0.12 This blog shows how we were able to significant improve performance and memory use in common cases when converting from the Arrow string memory layout to pandas's native memory model based on NumPy arrays of Python objects. Author: Wes McKinney Closes #3553 from wesm/python-string-memory-0.12 and squashes the following commits: f0d684d7f Update publication date 2bbb92d42 Fix some base urls c624e5545 Draft blog post about string memory use work in Arrow 0.12 --- .../_posts/2019-01-25-r-spark-improvements.md | 4 +- .../2019-02-05-python-string-memory-0.12.md | 246 ++++++++++++++++++ site/img/20190205-arrow-string.png | Bin 0 -> 4127 bytes site/img/20190205-numpy-string.png | Bin 0 -> 13714 bytes 4 files changed, 248 insertions(+), 2 deletions(-) create mode 100644 site/_posts/2019-02-05-python-string-memory-0.12.md create mode 100644 site/img/20190205-arrow-string.png create mode 100644 site/img/20190205-numpy-string.png diff --git a/site/_posts/2019-01-25-r-spark-improvements.md b/site/_posts/2019-01-25-r-spark-improvements.md index 23fba426f4dc3..982514ab6aba5 100644 --- a/site/_posts/2019-01-25-r-spark-improvements.md +++ b/site/_posts/2019-01-25-r-spark-improvements.md @@ -95,7 +95,7 @@ microbenchmark::microbenchmark( ```

[image: "Copying data with R into Spark with and without Arrow", base URL updated]
@@ -131,7 +131,7 @@ Unit: seconds ```
[image: "Collecting data with R from Spark with and without Arrow", base URL updated]
diff --git a/site/_posts/2019-02-05-python-string-memory-0.12.md b/site/_posts/2019-02-05-python-string-memory-0.12.md new file mode 100644 index 0000000000000..842cc23f73e3a --- /dev/null +++ b/site/_posts/2019-02-05-python-string-memory-0.12.md @@ -0,0 +1,246 @@ +--- +layout: post +title: "Reducing Python String Memory Use in Apache Arrow 0.12" +date: "2019-02-05 07:00:00 -0600" +author: wesm +categories: [application] +--- + + +Python users who upgrade to recently released `pyarrow` 0.12 may find that +their applications use significantly less memory when converting Arrow string +data to pandas format. This includes using `pyarrow.parquet.read_table` and +`pandas.read_parquet`. This article details some of what is going on under the +hood, and why Python applications dealing with large amounts of strings are +prone to memory use problems. + +## Why Python strings can use a lot of memory + +Let's start with some possibly surprising facts. I'm going to create an empty +`bytes` object and an empty `str` (unicode) object in Python 3.7: + +``` +In [1]: val = b'' + +In [2]: unicode_val = u'' +``` + +The `sys.getsizeof` function accurately reports the number of bytes used by +built-in Python objects. You might be surprised to find that: + + +``` +In [4]: import sys +In [5]: sys.getsizeof(val) +Out[5]: 33 + +In [6]: sys.getsizeof(unicode_val) +Out[6]: 49 +``` + +Since strings in Python are nul-terminated, we can infer that a bytes object +has 32 bytes of overhead while unicode has 48 bytes. One must also account for +`PyObject*` pointer references to the objects, so the actual overhead is 40 and +56 bytes, respectively. With large strings and text, this overhead may not +matter much, but when you have a lot of small strings, such as those arising +from reading a CSV or Apache Parquet file, they can take up an unexpected +amount of memory. pandas represents strings in NumPy arrays of `PyObject*` +pointers, so the total memory used by a unique unicode string is + +``` +8 (PyObject*) + 48 (Python C struct) + string_length + 1 +``` + +Suppose that we read a CSV file with + +* 1 column +* 1 million rows +* Each value in the column is a string with 10 characters + +On disk this file would take approximately 10MB. Read into memory, however, it +could take up over 60MB, as a 10 character string object takes up 67 bytes in a +`pandas.Series`. + +## How Apache Arrow represents strings + +While a Python unicode string can have 57 bytes of overhead, a string in the +Arrow columnar format has only 4 (32 bits) or 4.125 (33 bits) bytes of +overhead. 32-bit integer offsets encodes the position and size of a string +value in a contiguous chunk of memory: + +
+[figure: Apache Arrow string memory layout (site/img/20190205-arrow-string.png)]
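
To make the layout concrete, here is a small sketch using the Arrow C++ API that `pyarrow` wraps; the example strings are arbitrary and error handling is omitted:

```cpp
#include <iostream>
#include <memory>
#include <string>

#include "arrow/api.h"

int main() {
  // Build ["foo", "", "hello"]. All characters land in one contiguous values
  // buffer ("foohello"), and a second buffer holds the int32 offsets 0, 3, 3, 8.
  arrow::StringBuilder builder;
  // Status returns are ignored for brevity in this sketch.
  builder.Append(std::string("foo"));
  builder.Append(std::string(""));
  builder.Append(std::string("hello"));

  std::shared_ptr<arrow::Array> array;
  builder.Finish(&array);

  auto strings = std::static_pointer_cast<arrow::StringArray>(array);
  for (int64_t i = 0; i < strings->length(); ++i) {
    // value_offset(i) is the 32-bit offset of value i into the values buffer.
    std::cout << "offset=" << strings->value_offset(i) << " value='"
              << strings->GetString(i) << "'" << std::endl;
  }
  return 0;
}
```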
+ +When you call `table.to_pandas()` or `array.to_pandas()` with `pyarrow`, we +have to convert this compact string representation back to pandas's +Python-based strings. This can use a huge amount of memory when we have a large +number of small strings. It is a quite common occurrence when working with web +analytics data, which compresses to a compact size when stored in the Parquet +columnar file format. + +Note that the Arrow string memory format has other benefits beyond memory +use. It is also much more efficient for analytics due to the guarantee of data +locality; all strings are next to each other in memory. In the case of pandas +and Python strings, the string data can be located anywhere in the process +heap. Arrow PMC member Uwe Korn did some work to [extend pandas with Arrow +string arrays][1] for improved performance and memory use. + +## Reducing pandas memory use when converting from Arrow + +For many years, the `pandas.read_csv` function has relied on a trick to limit +the amount of string memory allocated. Because pandas uses arrays of +`PyObject*` pointers to refer to objects in the Python heap, we can avoid +creating multiple strings with the same value, instead reusing existing objects +and incrementing their reference counts. + +Schematically, we have the following: + +
+[figure: pandas string memory optimization (site/img/20190205-numpy-string.png)]
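
The idea is to keep a hash table from each distinct string value to the Python object already created for it, sketched below in C++ against the Python C API; the `ConvertWithDedup` function is an illustration of the approach, not the actual `arrow/python` conversion code:

```cpp
#include <Python.h>

#include <string>
#include <unordered_map>
#include <vector>

// Sketch of the deduplication idea: create one PyObject per distinct string
// value and bump its reference count for every repeated row. Error handling
// is omitted.
std::vector<PyObject*> ConvertWithDedup(const std::vector<std::string>& values) {
  std::unordered_map<std::string, PyObject*> memo;
  std::vector<PyObject*> out;
  out.reserve(values.size());
  for (const auto& v : values) {
    auto it = memo.find(v);
    if (it == memo.end()) {
      // First time this value is seen: build the Python string once.
      PyObject* obj =
          PyUnicode_FromStringAndSize(v.data(), static_cast<Py_ssize_t>(v.size()));
      it = memo.emplace(v, obj).first;
    } else {
      // Seen before: reuse the existing object, one new reference per row.
      Py_INCREF(it->second);
    }
    out.push_back(it->second);
  }
  return out;
}
```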
+ +In `pyarrow` 0.12, we have implemented this when calling `to_pandas`. It +requires using a hash table to deduplicate the Arrow string data as it's being +converted to pandas. Hashing data is not free, but counterintuitively it can be +faster in addition to being vastly more memory efficient in the common case in +analytics where we have table columns with many instances of the same string +values. + +## Memory and Performance Benchmarks + +We can use the `memory_profiler` Python package to easily get process memory +usage within a running Python application. + +``` +import memory_profiler +def mem(): + return memory_profiler.memory_usage()[0] +``` + +In a new application I have: + +``` +In [7]: mem() +Out[7]: 86.21875 +``` + +I will generate approximate 1 gigabyte of string data represented as Python +strings with length 10. The `pandas.util.testing` module has a handy `rands` +function for generating random strings. Here is the data generation function: + +```python +from pandas.util.testing import rands +def generate_strings(length, nunique, string_length=10): + unique_values = [rands(string_length) for i in range(nunique)] + values = unique_values * (length // nunique) + return values +``` + +This generates a certain number of unique strings, then duplicates then to +yield the desired number of total strings. So I'm going to create 100 million +strings with only 10000 unique values: + +``` +In [8]: values = generate_strings(100000000, 10000) + +In [9]: mem() +Out[9]: 852.140625 +``` + +100 million `PyObject*` values is only 745 MB, so this increase of a little +over 770 MB is consistent with what we know so far. Now I'm going to convert +this to Arrow format: + +``` +In [11]: arr = pa.array(values) + +In [12]: mem() +Out[12]: 2276.9609375 +``` + +Since `pyarrow` exactly accounts for all of its memory allocations, we also +check that + +``` +In [13]: pa.total_allocated_bytes() +Out[13]: 1416777280 +``` + +Since each string takes about 14 bytes (10 bytes plus 4 bytes of overhead), +this is what we expect. + +Now, converting `arr` back to pandas is where things get tricky. The _minimum_ +amount of memory that pandas can use is a little under 800 MB as above as we +need 100 million `PyObject*` values, which are 8 bytes each. + +``` +In [14]: arr_as_pandas = arr.to_pandas() + +In [15]: mem() +Out[15]: 3041.78125 +``` + +Doing the math, we used 765 MB which seems right. We can disable the string +deduplication logic by passing `deduplicate_objects=False` to `to_pandas`: + +``` +In [16]: arr_as_pandas_no_dedup = arr.to_pandas(deduplicate_objects=False) + +In [17]: mem() +Out[17]: 10006.95703125 +``` + +Without object deduplication, we use 6965 megabytes, or an average of 73 bytes +per value. This is a little bit higher than the theoretical size of 67 bytes +computed above. + +One of the more surprising results is that the new behavior is about twice as fast: + +``` +In [18]: %time arr_as_pandas_time = arr.to_pandas() +CPU times: user 2.94 s, sys: 213 ms, total: 3.15 s +Wall time: 3.14 s + +In [19]: %time arr_as_pandas_no_dedup_time = arr.to_pandas(deduplicate_objects=False) +CPU times: user 4.19 s, sys: 2.04 s, total: 6.23 s +Wall time: 6.21 s +``` + +The reason for this is that creating so many Python objects is more expensive +than hashing the 10 byte values and looking them up in a hash table. + +Note that when you convert Arrow data with mostly unique values back to pandas, +the memory use benefits here won't have as much of an impact. 
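
Rounding aside, a back-of-the-envelope check consistent with these measurements:

$$
\underbrace{10^8 \times 8\ \text{bytes}}_{\texttt{PyObject*}\ \text{pointers}} \approx 763\ \text{MB}
\qquad\text{versus}\qquad
\underbrace{10^8 \times 73\ \text{bytes}}_{\text{one object per value}} \approx 6965\ \text{MB},
$$

with the deduplicated case adding only about $10^4 \times (48 + 10 + 1)\ \text{bytes} \approx 0.6\ \text{MB}$ of actual string objects.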
+ +## Takeaways + +In Apache Arrow, our goal is to develop computational tools to operate natively +on the cache- and SIMD-friendly efficient Arrow columnar format. In the +meantime, though, we recognize that users have legacy applications using the +native memory layout of pandas or other analytics tools. We will do our best to +provide fast and memory-efficient interoperability with pandas and other +popular libraries. + +[1]: https://www.slideshare.net/xhochy/extending-pandas-using-apache-arrow-and-numba \ No newline at end of file diff --git a/site/img/20190205-arrow-string.png b/site/img/20190205-arrow-string.png new file mode 100644 index 0000000000000000000000000000000000000000..066386b861887f2426541e7cab43aa2e3c6535c5 GIT binary patch literal 4127 zcmb_fcU03&m$r-26d_1a6zK$zsubZ969lON0tpGd_g(}QML^?^79kWVzH|t@NJ0pm zh!}c#X;OsHlui(k9%1p_{dV_{J>Qsu(KX))j1`k;;@?Jb~;smRX zHpKYEiBrHa$1|Ni#tqcg`{U`zNY51d&`0l>{`=qrUHvCL!F8GS%;{4n|5N|}BdqL~ z_Bu0qXEu*R^73Clefrb|rgQ8p7dS7m3Gi^%r3UhFUfOI;%Y1GpEzDnpaaEHQiNa?5 zQC7d&gqv@`TI#C!q8~nUGH4lDglWnKO|P zi-ktTmpAo}5nn^^N@9G>Xh1&s-MD$2hiHJYfyoJq^*?t0 zxG}joOXvA(`^@vpR6N6LG)+82R^<&=mr%n5pJ#ELJuF`vZa8s*tw;xQ-!x=wW!x1t zh7dUCbLxtpK1iGC!hv*L`14}p5;EBhRgz;F@be%%SycJtg@njkU;Qg-JA98Sn%d(l zA_!L_ZZQ2*DE2HEQ6i~jVt2b^idIk74pUq7Tdf>M&RTKMS;3zKPXXmjK>6#Bf?&b_ z2^la@ua|(d0mvd?M@1c#x03qZugU6#c_b83e?3S*HGuGOrrm4DSG#%H#nJimhGgi6 zIyTl<0jGiTyAP(amrY^p0UWODr2{ozzt*j&LnT$Sfh%|V!dUAJsE_q__y|R!6M0q# zmS)Lf(K#?0y5_qsAY~OSfOfCDg9XVfetBA*U7_t*6+o<}glMv_pzHVuV#IgK%7nB( z&A0{NazVQW)|17&D~*`O(|$RWKcK$*2^V*y%3f>w;^J+c{Hkx8u&BW>5>4f^>la2! zrN)qsPuj1cz+9);I~)sXzzr}tM5mSBY@*RWv$bC7yspwe^B~@S#ZRwP$ku#|59~2X9PnN z;E{XDd~-Fd!b(PO)2GF88(h=@T=>a?xV)8$02b36NYt?S&&@H!kG?;*=(lUPp014l zEb{MkT@R55T{$!)=^^7D6$@W+eC(E`l)p^iwkW*}-xC{5!qGzQ3Ph5uTnIF#J)vEx zWXT!_n9#tzO1q9d%`#4O<~~?Sp~e=#6{N+IL+%GRzH$K!Nj8z=xrJX3rqPBzp&wMe zM>eWi0axRiIJY0j@$9*3+H_Ts3mTM|0r!Wj41s&dU5^r=GRfv>T%P&2As0Y4zqFkw zE@Hn-3<-L`>{egF9Q>-n_D-!ZB|E*dv+$)Sq7*y65xjspydS52Hv40spZ+?sG`?pZ z$^SkcY5D|?PA`d#yP{YiMq+Rg=^v=ShMXl9kBZug9i^N`z2mf;d$fzAAwuvGig21j zaPG9s>-r;H?}Uo8ckD=SM*TLb3gaP6mIj97H#uKA6V{w@fpbPzCir{Ji0S2Gmqus# zgk2zcq1`(9B3H1A*TV;)cwe>{d@l$Wqg60r2O2WdxwSEd$T21B=go;|a;gK}&c4d? 
literal 0
HcmV?d00001

diff --git a/site/img/20190205-numpy-string.png b/site/img/20190205-numpy-string.png
new file mode 100644
index 0000000000000000000000000000000000000000..ed048b02708674eccbfc7c6a918f5982e896068b
GIT binary patch
literal 13714
zomzyT6rpqjRV!DoDc&OnI6+GY`SvWiHCI}2{r!6KyiH;MlDcZyuYQ!*(WfnW*f?Xw z7OqH2OF!m36bwJK4gDU`=t+l|mU-3BQk$@(l1Pz@HQZlQ0Q`{D(CEBc@@B?{uQr9J z{XIvTzh@5ahyi~*l#^!a_%K zPJPW;Tc8^G0AuuU9P*j{VCoz|a1MP;CVt-|yH3Q`GaaW&xZ#FtPXQa;XVWJNbLLjf z=N)&EuF)!Y)+g(n%3f<8*t4IympP~A`dqyy#OLRD@%M{e3~=NI0E&Y+$pj5-9Y|&M zsM)CBj+$91j4f(ajNPL?HLF2Yxeq83i((FTf2`;c%++q5d zLVV%9gIk6Vn&Y{lpn5k)7A*+n+%x}}UHWXb3g51XHNOVHPH0yD4P9m&GBc?n!Vbl3 z|yH~5eM}Gb4Uj*@nlv<~b))C%RW_f=yU>>dk+i!dsj#Y%E&yDKI zSHzLpj54_fXnxCDLgY(8Thl(osy}0z6UCi(L|uLIi_5A_*oae-4e-elsIM3mB_pE72v2(1NGvW1ztw@RSa<1ucJ$e^L zjZt0CMmwz9uH5X@^Lz=qrHmt}*gH7CFxk_Y2~FhOZI~?6FRfwP$>d67`wW0nPwjoi z4-Rf(o%*@mYS#O%9{P!Ds3=ustIEgb&SnK+T<&gh#A>e#c2M~jHEuET5sKZ+A{sES zL6`xp_Ib_)v=h!8V{NkZaObyb6wzL(QdTo|qfgT7@cUIH^lfe9qEgujmDTrxM)|i3 zC1rQqc60mmWDm4cZY#5(h!oqcvBfU6&w1&$ZWE9T_`|5988QrZLe_e@KsLUCLn;h| zJCD7X3XR9qt_O-mk{1w2eF0%E$9k2zdN+k0p*@CdDz~gP)v=H|2eRGJYQafwVS0ia zUtgi;JUlB)Nn)&Z-i>jmC>qwY^h*w^jErf1;7dy{h=%`&J~7Taw=52J%UV}UO@d~@ zRHK=BtL_#%PH9gI%R0&8#g7dSqAHDi_gR^=pAau+teduL)Sw$X-9rV?RpAdUqAj4^J!7 zW<~wVRmxi^hh%+KB%{d>-u?zYohbS<^02ccJRL!knugG9=ZlGW(QkSRF&gWlNIlNm zRIc8qMwpsqyynm)J1?Q}Qg2e@*rAq6#}=&bBKdJZ*X9> zZI%tSbyap5@vvys@2+9BP>x+nwdM7aG%%UZx5%}1)GiiN>-^`S+*6Pg@7iMI5F1?u zk8zZ+U-D>dB0su<>l?LFY55#C>m`)cx+?ef=6Q}STcp;Ao@Kt=*P`vMZ!i6w4xVPOPGxZ4C{c3H{r1U zcuJM07h@*=zi5Gl5Bm+C58oh&Vv#XHQrar$b5hm_W`5ENw3G~zlkft|~t3x~eePomkML?(X& z(;8y}El{M{H%C}ZKfu{H(-5ie#Udngh^+12?tdrST-i6`Q{697ks_zNgKWnWt0!W1 z$A!UvF@Owy9b)2ltGhLn466^m4`-Xn^CfamUEqZtiJx9v$Iln8y$c0kpY@+Jh- zrB|&xYlf)h8yl~)KHgJs4oXUu4m(~ihzbQCqx?*vukgL#Cdp-v%6iBTPm0?)7xemH zkK}NsP?{C~3y0?Y1+P~3-_D%@>)%wIuBpXbZQ|jA!Q~0I))a>_fykeI=ZCKWhA4OQ^T!vCEeC1C`tAQk``5jEULk{V41yCm7!+WCs- zMsh6%LX$mzHrKz4o>}I6o&30BQDV}>6ovu(ewyq5SuV8W?`0!FFz`@_6Spkr4! z(?E|6YULou(9_IEq5IvyqM={rbYQ6E+Ngv_Kn&D*#^REY=xKN+5kwQ<4}Uknya5N# zHg~7vx5aD~eTFzC)ds#cR_0|#9esHdb|X<-KI-!M19FjdMirYQ&?22E!Qd;g4rlJWs^n4~PuT|Rwzm}2U)aemG0Vrl*i-&7%3v{Ow` zdWbCZHKmkg@s#%a=ocW~L_NY2cgPG}C-ioIv48h6cZf({Lr~2dyjSSC2G;S#>b&$m zR`t#}TAAKm<*`5+l8A*wHrJP{{E@0~J#nLe zx{CKA+RZHx8_PnB#oD5TQBDXb0UahE=tAB0x_mr3G-S4KNpI{eK`dUTu{KQYF|3us^w2{b@z zUS`pK_ZKUrKFTx94ixlSo%AXd+T&{pfy9spi+sTbgQ^cIML1(y3OQIzXm1`|Eo8sU zyO)~ss^Fp4Umc~rSV%(>%q4{k;0!FSnDoaYBugL$nZJw(f$};U-t2QU%?iy78Ohz4 zz{R&cHry|~To#pF4B_rfLRkvm-yVvL=`l2$BK;&b!5(dkez{GXCOroY#^u3(hHLgY zgG4l)U6w|P;vkL3_xz+%`asHNy|aQGxAy^MHc_+9taW|ap%i^iCQgS=GlecPJv$TC zVA*(Wp`4g! RC-}ccRYgs3-D~rR{{vLVGEo2k literal 0 HcmV?d00001 From e533a9e3f04eeb29dbd58bc0246be7d404a643da Mon Sep 17 00:00:00 2001 From: Wes McKinney Date: Tue, 5 Feb 2019 09:18:32 -0600 Subject: [PATCH 19/21] [Website] Edits to Python string blog post Change-Id: I4208e6d42fc6040313de7a01f897fc22db490c43 --- site/_posts/2019-02-05-python-string-memory-0.12.md | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/site/_posts/2019-02-05-python-string-memory-0.12.md b/site/_posts/2019-02-05-python-string-memory-0.12.md index 842cc23f73e3a..0979274461736 100644 --- a/site/_posts/2019-02-05-python-string-memory-0.12.md +++ b/site/_posts/2019-02-05-python-string-memory-0.12.md @@ -130,10 +130,10 @@ values. ## Memory and Performance Benchmarks -We can use the `memory_profiler` Python package to easily get process memory -usage within a running Python application. +We can use the [`memory_profiler`][2] Python package to easily get process +memory usage within a running Python application. -``` +```python import memory_profiler def mem(): return memory_profiler.memory_usage()[0] @@ -243,4 +243,5 @@ native memory layout of pandas or other analytics tools. 
We will do our best to provide fast and memory-efficient interoperability with pandas and other popular libraries. -[1]: https://www.slideshare.net/xhochy/extending-pandas-using-apache-arrow-and-numba \ No newline at end of file +[1]: https://www.slideshare.net/xhochy/extending-pandas-using-apache-arrow-and-numba +[2]: https://pypi.org/project/memory-profiler/ \ No newline at end of file From 623deef5347cddd2c9f9e8d39b66a60e2a59d89d Mon Sep 17 00:00:00 2001 From: Renat Valiullin Date: Tue, 5 Feb 2019 09:54:22 -0600 Subject: [PATCH 20/21] PARQUET-1525: [C++] remove dependency on getopt in parquet tools Author: Renat Valiullin Closes #3545 from rip-nsk/ARROW-4456 and squashes the following commits: 939a4fb3e Add parquet tools to ci/cpp-msvc-build-main 03422aa99 remove dependency on getopt in parquet tools --- ci/cpp-msvc-build-main.bat | 1 + cpp/tools/parquet/parquet-dump-schema.cc | 27 ++++++++---------------- 2 files changed, 10 insertions(+), 18 deletions(-) diff --git a/ci/cpp-msvc-build-main.bat b/ci/cpp-msvc-build-main.bat index 779af154bedb0..5b1842d2f943a 100644 --- a/ci/cpp-msvc-build-main.bat +++ b/ci/cpp-msvc-build-main.bat @@ -57,6 +57,7 @@ cmake -G "%GENERATOR%" %CMAKE_ARGS% ^ -DCMAKE_CXX_FLAGS_RELEASE="/MD %CMAKE_CXX_FLAGS_RELEASE%" ^ -DARROW_GANDIVA=%ARROW_BUILD_GANDIVA% ^ -DARROW_PARQUET=ON ^ + -DPARQUET_BUILD_EXECUTABLES=ON ^ -DARROW_PYTHON=ON ^ .. || exit /B cmake --build . --target install --config %CONFIGURATION% || exit /B diff --git a/cpp/tools/parquet/parquet-dump-schema.cc b/cpp/tools/parquet/parquet-dump-schema.cc index 7b6c1b160aacc..0d7c2428f449d 100644 --- a/cpp/tools/parquet/parquet-dump-schema.cc +++ b/cpp/tools/parquet/parquet-dump-schema.cc @@ -15,39 +15,30 @@ // specific language governing permissions and limitations // under the License. -#include #include #include "parquet/api/reader.h" #include "parquet/api/schema.h" int main(int argc, char** argv) { - static struct option options[] = { - {"help", no_argument, nullptr, 'h'} - }; bool help_flag = false; - int opt_index; - do { - opt_index = getopt_long(argc, argv, "h", options, nullptr); - switch (opt_index) { - case '?': - case 'h': + std::string filename; + + for (int i = 1; i < argc; i++) { + if (!std::strcmp(argv[i], "-?") || !std::strcmp(argv[i], "-h") || + !std::strcmp(argv[i], "--help")) { help_flag = true; - opt_index = -1; - break; + } else { + filename = argv[i]; } - } while (opt_index != -1); - argc -= optind; - argv += optind; + } - if (argc != 1 || help_flag) { + if (argc != 2 || help_flag) { std::cerr << "Usage: parquet-dump-schema [-h] [--help]" << " " << std::endl; return -1; } - std::string filename = argv[0]; - try { std::unique_ptr reader = parquet::ParquetFileReader::OpenFile(filename); From 4004b725c952187522dd77cdc3151dca664e2148 Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 5 Feb 2019 18:48:24 +0100 Subject: [PATCH 21/21] ARROW-3289: [C++] Implement Flight DoPut Implements server/client side DoPut in C++ and extends the integration tests to exercise this. We may want a different API for client-side DoPut that exposes any potential server response; I made it give the client a RecordBatchWriter to be symmetric with DoGet for now though. 
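For illustration only, a client-side upload with the new API might look roughly like the
sketch below. The host/port, the "example.arrow" path, the `PutOneBatch` wrapper, and the
caller-supplied `batch` are placeholders for this example and are not part of this change;
the calls themselves (`Connect`, `DoPut`, `WriteRecordBatch`, `Close`) follow the signatures
added in this patch.

```cpp
#include <memory>

#include "arrow/flight/client.h"
#include "arrow/record_batch.h"
#include "arrow/status.h"

// Upload a single record batch to a running Flight server. The endpoint and
// path below are stand-ins for whatever the caller actually uses.
arrow::Status PutOneBatch(const std::shared_ptr<arrow::RecordBatch>& batch) {
  std::unique_ptr<arrow::flight::FlightClient> client;
  RETURN_NOT_OK(arrow::flight::FlightClient::Connect("localhost", 31337, &client));

  arrow::flight::FlightDescriptor descriptor{
      arrow::flight::FlightDescriptor::PATH, "", {"example.arrow"}};

  // DoPut sends the descriptor and schema first, then hands back a writer
  // for the record batches themselves.
  std::unique_ptr<arrow::flight::FlightPutWriter> writer;
  RETURN_NOT_OK(client->DoPut(descriptor, batch->schema(), &writer));
  RETURN_NOT_OK(writer->WriteRecordBatch(*batch));

  // Close() finishes the gRPC stream and surfaces any server-side error.
  return writer->Close();
}
```

Returning a RecordBatchWriter keeps the upload path symmetric with DoGet, at the cost of
hiding any server response until Close() is called.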
Author: David Li Closes #3524 from lihalite/arrow-3289 and squashes the following commits: 13fb29af Document why VectorUnloader must align batches in Flight f32c0b25 Indicate error to client in DoPut if no message sent cd567820 Warn about undefined behavior in Flight source 1f816e8c Move serialization helpers out of gRPC namespace 21b315a5 Hide FlightPutWriter from public interface for now 6edf2e2b Introduce FlightPutWriter 58d6936e Enable building with non-CMake c-ares cfa4ca5e Properly quote arguments to gRPC CMake build 302dd334 Explicitly link Protobuf for Flight 419ad688 Log (de)serialization failures in Flight fast-path 562b8613 Factor out FlightData->Message conversion 65d6ba2f Clean up C++ Flight integration client 3e185cb9 Add convenience to parse JSON from file 111b3e6b Fix style/lint issues 3cb51bad Test all returned locations in Flight integration tests 905ef38f Implement C++ Flight DoPut 138141fd Fix FromProto for FlightDescriptor a11a5acf Don't hang in Flight DoPut if server sends exception b3ac01ab Align RecordBatch on client side in Flight DoPut 846df730 Implement put in Java Flight integration server --- cpp/cmake_modules/Findc-ares.cmake | 108 ++++++ cpp/cmake_modules/ThirdpartyToolchain.cmake | 25 +- cpp/src/arrow/flight/CMakeLists.txt | 2 + cpp/src/arrow/flight/client.cc | 284 +++++++------- cpp/src/arrow/flight/client.h | 35 +- cpp/src/arrow/flight/internal.cc | 17 +- cpp/src/arrow/flight/internal.h | 2 + .../arrow/flight/serialization-internal.cc | 37 ++ cpp/src/arrow/flight/serialization-internal.h | 345 ++++++++++++++++++ cpp/src/arrow/flight/server.cc | 218 +++++------ cpp/src/arrow/flight/server.h | 18 +- .../arrow/flight/test-integration-client.cc | 124 +++++-- .../arrow/flight/test-integration-server.cc | 102 +++--- cpp/src/arrow/flight/types.h | 6 - cpp/src/arrow/ipc/json.cc | 13 + cpp/src/arrow/ipc/json.h | 13 + integration/integration_test.py | 41 +-- .../org/apache/arrow/flight/FlightClient.java | 7 +- .../arrow/flight/example/InMemoryStore.java | 1 + .../integration/IntegrationTestClient.java | 82 +++-- .../integration/IntegrationTestServer.java | 128 ++----- 21 files changed, 1039 insertions(+), 569 deletions(-) create mode 100644 cpp/cmake_modules/Findc-ares.cmake create mode 100644 cpp/src/arrow/flight/serialization-internal.cc create mode 100644 cpp/src/arrow/flight/serialization-internal.h diff --git a/cpp/cmake_modules/Findc-ares.cmake b/cpp/cmake_modules/Findc-ares.cmake new file mode 100644 index 0000000000000..1366ce33fa790 --- /dev/null +++ b/cpp/cmake_modules/Findc-ares.cmake @@ -0,0 +1,108 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Tries to find c-ares headers and libraries. +# +# Usage of this module as follows: +# +# find_package(c-ares) +# +# Variables used by this module, they can change the default behaviour and need +# to be set before calling find_package: +# +# CARES_HOME - When set, this path is inspected instead of standard library +# locations as the root of the c-ares installation. 
+# The environment variable CARES_HOME overrides this variable. +# +# - Find CARES +# This module defines +# CARES_INCLUDE_DIR, directory containing headers +# CARES_SHARED_LIB, path to c-ares's shared library +# CARES_FOUND, whether c-ares has been found + +if( NOT "${CARES_HOME}" STREQUAL "") + file( TO_CMAKE_PATH "${CARES_HOME}" _native_path ) + list( APPEND _cares_roots ${_native_path} ) +elseif ( CARES_HOME ) + list( APPEND _cares_roots ${CARES_HOME} ) +endif() + +if (MSVC) + set(CARES_LIB_NAME cares.lib) +else () + set(CARES_LIB_NAME + ${CMAKE_SHARED_LIBRARY_PREFIX}cares${CMAKE_SHARED_LIBRARY_SUFFIX}) + set(CARES_STATIC_LIB_NAME + ${CMAKE_STATIC_LIBRARY_PREFIX}cares${CMAKE_STATIC_LIBRARY_SUFFIX}) +endif () + +# Try the parameterized roots, if they exist +if (_cares_roots) + find_path(CARES_INCLUDE_DIR NAMES ares.h + PATHS ${_cares_roots} NO_DEFAULT_PATH + PATH_SUFFIXES "include") + find_library(CARES_SHARED_LIB + NAMES ${CARES_LIB_NAME} + PATHS ${_cares_roots} NO_DEFAULT_PATH + PATH_SUFFIXES "lib") + find_library(CARES_STATIC_LIB + NAMES ${CARES_STATIC_LIB_NAME} + PATHS ${_cares_roots} NO_DEFAULT_PATH + PATH_SUFFIXES "lib") +else () + pkg_check_modules(PKG_CARES cares) + if (PKG_CARES_FOUND) + set(CARES_INCLUDE_DIR ${PKG_CARES_INCLUDEDIR}) + find_library(CARES_SHARED_LIB + NAMES ${CARES_LIB_NAME} + PATHS ${PKG_CARES_LIBDIR} NO_DEFAULT_PATH) + else () + find_path(CARES_INCLUDE_DIR NAMES cares.h) + find_library(CARES_SHARED_LIB NAMES ${CARES_LIB_NAME}) + endif () +endif () + +if (CARES_INCLUDE_DIR AND CARES_SHARED_LIB) + set(CARES_FOUND TRUE) +else () + set(CARES_FOUND FALSE) +endif () + +if (CARES_FOUND) + if (NOT CARES_FIND_QUIETLY) + if (CARES_SHARED_LIB) + message(STATUS "Found the c-ares shared library: ${CARES_SHARED_LIB}") + endif () + endif () +else () + if (NOT CARES_FIND_QUIETLY) + set(CARES_ERR_MSG "Could not find the c-ares library. 
Looked in ") + if ( _cares_roots ) + set(CARES_ERR_MSG "${CARES_ERR_MSG} ${_cares_roots}.") + else () + set(CARES_ERR_MSG "${CARES_ERR_MSG} system search paths.") + endif () + if (CARES_FIND_REQUIRED) + message(FATAL_ERROR "${CARES_ERR_MSG}") + else (CARES_FIND_REQUIRED) + message(STATUS "${CARES_ERR_MSG}") + endif (CARES_FIND_REQUIRED) + endif () +endif () + +mark_as_advanced( + CARES_INCLUDE_DIR + CARES_LIBRARIES + CARES_SHARED_LIB + CARES_STATIC_LIB +) diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake index 5ee0ddfd55914..9bd6cb68c70e7 100644 --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake @@ -1347,12 +1347,7 @@ if (ARROW_WITH_GRPC) BUILD_BYPRODUCTS "${CARES_STATIC_LIB}") else() set(CARES_VENDORED 0) - find_package(c-ares REQUIRED - PATHS ${CARES_HOME} - NO_DEFAULT_PATH) - if(TARGET c-ares::cares) - get_property(CARES_STATIC_LIB TARGET c-ares::cares_static PROPERTY LOCATION) - endif() + find_package(c-ares REQUIRED) endif() message(STATUS "c-ares library: ${CARES_STATIC_LIB}") @@ -1406,15 +1401,15 @@ if (ARROW_WITH_GRPC) set(GRPC_CMAKE_ARGS -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} - -DCMAKE_PREFIX_PATH="${GRPC_PREFIX_PATH_ALT_SEP}" - "-DgRPC_CARES_PROVIDER=package" - "-DgRPC_GFLAGS_PROVIDER=package" - "-DgRPC_PROTOBUF_PROVIDER=package" - "-DgRPC_SSL_PROVIDER=package" - "-DgRPC_ZLIB_PROVIDER=package" - "-DCMAKE_CXX_FLAGS=${EP_CXX_FLAGS}" - "-DCMAKE_C_FLAGS=${EP_C_FLAGS}" - "-DCMAKE_INSTALL_PREFIX=${GRPC_PREFIX}" + -DCMAKE_PREFIX_PATH='${GRPC_PREFIX_PATH_ALT_SEP}' + '-DgRPC_CARES_PROVIDER=package' + '-DgRPC_GFLAGS_PROVIDER=package' + '-DgRPC_PROTOBUF_PROVIDER=package' + '-DgRPC_SSL_PROVIDER=package' + '-DgRPC_ZLIB_PROVIDER=package' + '-DCMAKE_CXX_FLAGS=${EP_CXX_FLAGS}' + '-DCMAKE_C_FLAGS=${EP_C_FLAGS}' + '-DCMAKE_INSTALL_PREFIX=${GRPC_PREFIX}' -DCMAKE_INSTALL_LIBDIR=lib -DBUILD_SHARED_LIBS=OFF) diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt index b8b4d8d336365..1cbef6cf81808 100644 --- a/cpp/src/arrow/flight/CMakeLists.txt +++ b/cpp/src/arrow/flight/CMakeLists.txt @@ -21,6 +21,7 @@ add_custom_target(arrow_flight) ARROW_INSTALL_ALL_HEADERS("arrow/flight") SET(ARROW_FLIGHT_STATIC_LINK_LIBS + protobuf_static grpc_grpcpp_static grpc_grpc_static grpc_gpr_static @@ -69,6 +70,7 @@ set(ARROW_FLIGHT_SRCS Flight.pb.cc Flight.grpc.pb.cc internal.cc + serialization-internal.cc server.cc types.cc ) diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index e25c1875d669f..a58c2b5933225 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -22,12 +22,12 @@ #include #include -#include "google/protobuf/io/coded_stream.h" -#include "google/protobuf/wire_format_lite.h" -#include "grpc/byte_buffer_reader.h" #include "grpcpp/grpcpp.h" +#include "arrow/ipc/dictionary.h" +#include "arrow/ipc/metadata-internal.h" #include "arrow/ipc/reader.h" +#include "arrow/ipc/writer.h" #include "arrow/record_batch.h" #include "arrow/status.h" #include "arrow/type.h" @@ -36,161 +36,11 @@ #include "arrow/flight/Flight.grpc.pb.h" #include "arrow/flight/Flight.pb.h" #include "arrow/flight/internal.h" +#include "arrow/flight/serialization-internal.h" -namespace pb = arrow::flight::protocol; - -namespace arrow { -namespace flight { - -/// Internal, not user-visible type used for memory-efficient reads from gRPC -/// stream -struct FlightData { - /// Used only for puts, may be null - std::unique_ptr descriptor; - - /// Non-length-prefixed 
Message header as described in format/Message.fbs - std::shared_ptr metadata; - - /// Message body - std::shared_ptr body; -}; - -} // namespace flight -} // namespace arrow - -namespace grpc { - -// Customizations to gRPC for more efficient deserialization of FlightData - -using google::protobuf::internal::WireFormatLite; -using google::protobuf::io::CodedInputStream; - -using arrow::flight::FlightData; - -bool ReadBytesZeroCopy(const std::shared_ptr& source_data, - CodedInputStream* input, std::shared_ptr* out) { - uint32_t length; - if (!input->ReadVarint32(&length)) { - return false; - } - *out = arrow::SliceBuffer(source_data, input->CurrentPosition(), - static_cast(length)); - return input->Skip(static_cast(length)); -} - -// Internal wrapper for gRPC ByteBuffer so its memory can be exposed to Arrow -// consumers with zero-copy -class GrpcBuffer : public arrow::MutableBuffer { - public: - GrpcBuffer(grpc_slice slice, bool incref) - : MutableBuffer(GRPC_SLICE_START_PTR(slice), - static_cast(GRPC_SLICE_LENGTH(slice))), - slice_(incref ? grpc_slice_ref(slice) : slice) {} - - ~GrpcBuffer() override { - // Decref slice - grpc_slice_unref(slice_); - } - - static arrow::Status Wrap(ByteBuffer* cpp_buf, std::shared_ptr* out) { - // These types are guaranteed by static assertions in gRPC to have the same - // in-memory representation - - auto buffer = *reinterpret_cast(cpp_buf); - - // This part below is based on the Flatbuffers gRPC SerializationTraits in - // flatbuffers/grpc.h - - // Check if this is a single uncompressed slice. - if ((buffer->type == GRPC_BB_RAW) && - (buffer->data.raw.compression == GRPC_COMPRESS_NONE) && - (buffer->data.raw.slice_buffer.count == 1)) { - // If it is, then we can reference the `grpc_slice` directly. - grpc_slice slice = buffer->data.raw.slice_buffer.slices[0]; - - // Increment reference count so this memory remains valid - *out = std::make_shared(slice, true); - } else { - // Otherwise, we need to use `grpc_byte_buffer_reader_readall` to read - // `buffer` into a single contiguous `grpc_slice`. The gRPC reader gives - // us back a new slice with the refcount already incremented. 
- grpc_byte_buffer_reader reader; - if (!grpc_byte_buffer_reader_init(&reader, buffer)) { - return arrow::Status::IOError("Internal gRPC error reading from ByteBuffer"); - } - grpc_slice slice = grpc_byte_buffer_reader_readall(&reader); - grpc_byte_buffer_reader_destroy(&reader); - - // Steal the slice reference - *out = std::make_shared(slice, false); - } - - return arrow::Status::OK(); - } - - private: - grpc_slice slice_; -}; - -// Read internal::FlightData from grpc::ByteBuffer containing FlightData -// protobuf without copying -template <> -class SerializationTraits { - public: - static Status Serialize(const FlightData& msg, ByteBuffer** buffer, bool* own_buffer) { - return Status(StatusCode::UNIMPLEMENTED, - "internal::FlightData serialization not implemented"); - } - - static Status Deserialize(ByteBuffer* buffer, FlightData* out) { - if (!buffer) { - return Status(StatusCode::INTERNAL, "No payload"); - } +using arrow::ipc::internal::IpcPayload; - std::shared_ptr wrapped_buffer; - GRPC_RETURN_NOT_OK(GrpcBuffer::Wrap(buffer, &wrapped_buffer)); - - auto buffer_length = static_cast(wrapped_buffer->size()); - CodedInputStream pb_stream(wrapped_buffer->data(), buffer_length); - - // TODO(wesm): The 2-parameter version of this function is deprecated - pb_stream.SetTotalBytesLimit(buffer_length, -1 /* no threshold */); - - // This is the bytes remaining when using CodedInputStream like this - while (pb_stream.BytesUntilTotalBytesLimit()) { - const uint32_t tag = pb_stream.ReadTag(); - const int field_number = WireFormatLite::GetTagFieldNumber(tag); - switch (field_number) { - case pb::FlightData::kFlightDescriptorFieldNumber: { - pb::FlightDescriptor pb_descriptor; - if (!pb_descriptor.ParseFromCodedStream(&pb_stream)) { - return Status(StatusCode::INTERNAL, "Unable to parse FlightDescriptor"); - } - } break; - case pb::FlightData::kDataHeaderFieldNumber: { - if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->metadata)) { - return Status(StatusCode::INTERNAL, "Unable to read FlightData metadata"); - } - } break; - case pb::FlightData::kDataBodyFieldNumber: { - if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->body)) { - return Status(StatusCode::INTERNAL, "Unable to read FlightData body"); - } - } break; - default: - DCHECK(false) << "cannot happen"; - } - } - buffer->Clear(); - - // TODO(wesm): Where and when should we verify that the FlightData is not - // malformed or missing components? - - return Status::OK; - } -}; - -} // namespace grpc +namespace pb = arrow::flight::protocol; namespace arrow { namespace flight { @@ -225,9 +75,12 @@ class FlightStreamReader : public RecordBatchReader { } // For customizing read path for better memory/serialization efficiency + // XXX this cast is undefined behavior auto custom_reader = reinterpret_cast*>(stream_.get()); - if (custom_reader->Read(&data)) { + // Explicitly specify the override to invoke - otherwise compiler + // may invoke through vtable (not updated by reinterpret_cast) + if (custom_reader->grpc::ClientReader::Read(&data)) { std::unique_ptr message; // Validate IPC message @@ -259,6 +112,82 @@ class FlightStreamReader : public RecordBatchReader { std::unique_ptr> stream_; }; +class FlightClient; + +/// \brief A RecordBatchWriter implementation that writes to a Flight +/// DoPut stream. 
+class FlightPutWriter::FlightPutWriterImpl : public ipc::RecordBatchWriter { + public: + explicit FlightPutWriterImpl(std::unique_ptr rpc, + const FlightDescriptor& descriptor, + const std::shared_ptr& schema, + MemoryPool* pool = default_memory_pool()) + : rpc_(std::move(rpc)), descriptor_(descriptor), schema_(schema), pool_(pool) {} + + Status WriteRecordBatch(const RecordBatch& batch, bool allow_64bit = false) override { + IpcPayload payload; + RETURN_NOT_OK(ipc::internal::GetRecordBatchPayload(batch, pool_, &payload)); + // XXX this cast is undefined behavior + auto custom_writer = reinterpret_cast*>(writer_.get()); + // Explicitly specify the override to invoke - otherwise compiler + // may invoke through vtable (not updated by reinterpret_cast) + if (!custom_writer->grpc::ClientWriter::Write(payload, + grpc::WriteOptions())) { + std::stringstream ss; + ss << "Could not write record batch to stream: " + << rpc_->context.debug_error_string(); + return Status::IOError(ss.str()); + } + return Status::OK(); + } + + Status Close() override { + bool finished_writes = writer_->WritesDone(); + RETURN_NOT_OK(internal::FromGrpcStatus(writer_->Finish())); + if (!finished_writes) { + return Status::UnknownError( + "Could not finish writing record batches before closing"); + } + return Status::OK(); + } + + void set_memory_pool(MemoryPool* pool) override { pool_ = pool; } + + private: + /// \brief Set the gRPC writer backing this Flight stream. + /// \param [in] writer the gRPC writer + void set_stream(std::unique_ptr> writer) { + writer_ = std::move(writer); + } + + // TODO: there isn't a way to access this as a user. + protocol::PutResult response; + std::unique_ptr rpc_; + FlightDescriptor descriptor_; + std::shared_ptr schema_; + std::unique_ptr> writer_; + MemoryPool* pool_; + + // We need to reference some fields + friend class FlightClient; +}; + +FlightPutWriter::~FlightPutWriter() {} + +FlightPutWriter::FlightPutWriter(std::unique_ptr impl) { + impl_ = std::move(impl); +} + +Status FlightPutWriter::WriteRecordBatch(const RecordBatch& batch, bool allow_64bit) { + return impl_->WriteRecordBatch(batch, allow_64bit); +} + +Status FlightPutWriter::Close() { return impl_->Close(); } + +void FlightPutWriter::set_memory_pool(MemoryPool* pool) { + return impl_->set_memory_pool(pool); +} + class FlightClient::FlightClientImpl { public: Status Connect(const std::string& host, int port) { @@ -364,8 +293,38 @@ class FlightClient::FlightClientImpl { return Status::OK(); } - Status DoPut(std::unique_ptr* stream) { - return Status::NotImplemented("DoPut"); + Status DoPut(const FlightDescriptor& descriptor, const std::shared_ptr& schema, + std::unique_ptr* stream) { + std::unique_ptr rpc(new ClientRpc); + std::unique_ptr out( + new FlightPutWriter::FlightPutWriterImpl(std::move(rpc), descriptor, schema)); + std::unique_ptr> write_stream( + stub_->DoPut(&out->rpc_->context, &out->response)); + + // First write the descriptor and schema to the stream. 
+ pb::FlightData descriptor_message; + RETURN_NOT_OK( + internal::ToProto(descriptor, descriptor_message.mutable_flight_descriptor())); + + std::shared_ptr header_buf; + RETURN_NOT_OK(Buffer::FromString("", &header_buf)); + ipc::DictionaryMemo dictionary_memo; + RETURN_NOT_OK(ipc::SerializeSchema(*schema, out->pool_, &header_buf)); + RETURN_NOT_OK( + ipc::internal::WriteSchemaMessage(*schema, &dictionary_memo, &header_buf)); + descriptor_message.set_data_header(header_buf->ToString()); + + if (!write_stream->Write(descriptor_message, grpc::WriteOptions())) { + std::stringstream ss; + ss << "Could not write descriptor and schema to stream: " + << rpc->context.debug_error_string(); + return Status::IOError(ss.str()); + } + + out->set_stream(std::move(write_stream)); + *stream = + std::unique_ptr(new FlightPutWriter(std::move(out))); + return Status::OK(); } private: @@ -410,9 +369,10 @@ Status FlightClient::DoGet(const Ticket& ticket, const std::shared_ptr& return impl_->DoGet(ticket, schema, stream); } -Status FlightClient::DoPut(const Schema& schema, - std::unique_ptr* stream) { - return Status::NotImplemented("DoPut"); +Status FlightClient::DoPut(const FlightDescriptor& descriptor, + const std::shared_ptr& schema, + std::unique_ptr* stream) { + return impl_->DoPut(descriptor, schema, stream); } } // namespace flight diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h index 53bb1755b2995..ef960417b024a 100644 --- a/cpp/src/arrow/flight/client.h +++ b/cpp/src/arrow/flight/client.h @@ -24,6 +24,7 @@ #include #include +#include "arrow/ipc/writer.h" #include "arrow/status.h" #include "arrow/util/visibility.h" @@ -37,6 +38,8 @@ class Schema; namespace flight { +class FlightPutWriter; + /// \brief Client class for Arrow Flight RPC services (gRPC-based). /// API experimental for now class ARROW_EXPORT FlightClient { @@ -86,7 +89,7 @@ class ARROW_EXPORT FlightClient { /// \brief Given a flight ticket and schema, request to be sent the /// stream. Returns record batch stream reader - /// \param[in] ticket + /// \param[in] ticket The flight ticket to use /// \param[in] schema the schema of the stream data as computed by /// GetFlightInfo /// \param[out] stream the returned RecordBatchReader @@ -94,12 +97,15 @@ class ARROW_EXPORT FlightClient { Status DoGet(const Ticket& ticket, const std::shared_ptr& schema, std::unique_ptr* stream); - /// \brief Initiate DoPut RPC, returns FlightPutWriter interface to - /// write. Not yet implemented - /// \param[in] schema the schema of the stream data - /// \param[out] stream the created stream to write record batches to + /// \brief Upload data to a Flight described by the given + /// descriptor. The caller must call Close() on the returned stream + /// once they are done writing. 
+ /// \param[in] descriptor the descriptor of the stream + /// \param[in] schema the schema for the data to upload + /// \param[out] stream a writer to write record batches to /// \return Status - Status DoPut(const Schema& schema, std::unique_ptr* stream); + Status DoPut(const FlightDescriptor& descriptor, const std::shared_ptr& schema, + std::unique_ptr* stream); private: FlightClient(); @@ -107,5 +113,22 @@ class ARROW_EXPORT FlightClient { std::unique_ptr impl_; }; +/// \brief An interface to upload record batches to a Flight server +class ARROW_EXPORT FlightPutWriter : public ipc::RecordBatchWriter { + public: + ~FlightPutWriter(); + + Status WriteRecordBatch(const RecordBatch& batch, bool allow_64bit = false) override; + Status Close() override; + void set_memory_pool(MemoryPool* pool) override; + + private: + class FlightPutWriterImpl; + explicit FlightPutWriter(std::unique_ptr impl); + std::unique_ptr impl_; + + friend class FlightClient; +}; + } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/internal.cc b/cpp/src/arrow/flight/internal.cc index b4c6b2addcc11..a614450e8d0a0 100644 --- a/cpp/src/arrow/flight/internal.cc +++ b/cpp/src/arrow/flight/internal.cc @@ -131,6 +131,21 @@ void ToProto(const Ticket& ticket, pb::Ticket* pb_ticket) { pb_ticket->set_ticket(ticket.ticket); } +// FlightData + +Status FromProto(const pb::FlightData& pb_data, FlightDescriptor* descriptor, + std::unique_ptr* message) { + RETURN_NOT_OK(internal::FromProto(pb_data.flight_descriptor(), descriptor)); + const std::string& header = pb_data.data_header(); + const std::string& body = pb_data.data_body(); + std::shared_ptr header_buf = Buffer::Wrap(header.data(), header.size()); + std::shared_ptr body_buf = Buffer::Wrap(body.data(), body.size()); + if (header_buf == nullptr || body_buf == nullptr) { + return Status::UnknownError("Could not create buffers from protobuf"); + } + return ipc::Message::Open(header_buf, body_buf, message); +} + // FlightEndpoint Status FromProto(const pb::FlightEndpoint& pb_endpoint, FlightEndpoint* endpoint) { @@ -156,7 +171,7 @@ Status FromProto(const pb::FlightDescriptor& pb_descriptor, FlightDescriptor* descriptor) { if (pb_descriptor.type() == pb::FlightDescriptor::PATH) { descriptor->type = FlightDescriptor::PATH; - descriptor->path.resize(pb_descriptor.path_size()); + descriptor->path.reserve(pb_descriptor.path_size()); for (int i = 0; i < pb_descriptor.path_size(); ++i) { descriptor->path.emplace_back(pb_descriptor.path(i)); } diff --git a/cpp/src/arrow/flight/internal.h b/cpp/src/arrow/flight/internal.h index bae1eedfa9c66..7f9bda138cbb1 100644 --- a/cpp/src/arrow/flight/internal.h +++ b/cpp/src/arrow/flight/internal.h @@ -57,6 +57,8 @@ Status FromProto(const pb::Result& pb_result, Result* result); Status FromProto(const pb::Criteria& pb_criteria, Criteria* criteria); Status FromProto(const pb::Location& pb_location, Location* location); Status FromProto(const pb::Ticket& pb_ticket, Ticket* ticket); +Status FromProto(const pb::FlightData& pb_data, FlightDescriptor* descriptor, + std::unique_ptr* message); Status FromProto(const pb::FlightDescriptor& pb_descr, FlightDescriptor* descr); Status FromProto(const pb::FlightEndpoint& pb_endpoint, FlightEndpoint* endpoint); Status FromProto(const pb::FlightGetInfo& pb_info, FlightInfo::Data* info); diff --git a/cpp/src/arrow/flight/serialization-internal.cc b/cpp/src/arrow/flight/serialization-internal.cc new file mode 100644 index 0000000000000..194a7b5bc0c30 --- /dev/null +++ 
b/cpp/src/arrow/flight/serialization-internal.cc @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/flight/serialization-internal.h" + +namespace arrow { +namespace flight { +namespace internal { + +bool ReadBytesZeroCopy(const std::shared_ptr& source_data, + CodedInputStream* input, std::shared_ptr* out) { + uint32_t length; + if (!input->ReadVarint32(&length)) { + return false; + } + *out = arrow::SliceBuffer(source_data, input->CurrentPosition(), + static_cast(length)); + return input->Skip(static_cast(length)); +} + +} // namespace internal +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/serialization-internal.h b/cpp/src/arrow/flight/serialization-internal.h new file mode 100644 index 0000000000000..d4254d606d40f --- /dev/null +++ b/cpp/src/arrow/flight/serialization-internal.h @@ -0,0 +1,345 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// (De)serialization utilities that hook into gRPC, efficiently +// handling Arrow-encoded data in a gRPC call. 
+ +#pragma once + +#include +#include + +#include "google/protobuf/io/coded_stream.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/wire_format_lite.h" +#include "grpc/byte_buffer_reader.h" +#include "grpcpp/grpcpp.h" + +#include "arrow/ipc/writer.h" +#include "arrow/record_batch.h" +#include "arrow/status.h" +#include "arrow/util/logging.h" + +#include "arrow/flight/Flight.grpc.pb.h" +#include "arrow/flight/Flight.pb.h" +#include "arrow/flight/internal.h" +#include "arrow/flight/types.h" + +namespace pb = arrow::flight::protocol; + +using arrow::ipc::internal::IpcPayload; + +constexpr int64_t kInt32Max = std::numeric_limits::max(); + +namespace arrow { +namespace flight { + +/// Internal, not user-visible type used for memory-efficient reads from gRPC +/// stream +struct FlightData { + /// Used only for puts, may be null + std::unique_ptr descriptor; + + /// Non-length-prefixed Message header as described in format/Message.fbs + std::shared_ptr metadata; + + /// Message body + std::shared_ptr body; +}; + +namespace internal { + +using google::protobuf::io::CodedInputStream; +using google::protobuf::io::CodedOutputStream; + +// More efficient writing of FlightData to gRPC output buffer +// Implementation of ZeroCopyOutputStream that writes to a fixed-size buffer +class FixedSizeProtoWriter : public ::google::protobuf::io::ZeroCopyOutputStream { + public: + explicit FixedSizeProtoWriter(grpc_slice slice) + : slice_(slice), + bytes_written_(0), + total_size_(static_cast(GRPC_SLICE_LENGTH(slice))) {} + + bool Next(void** data, int* size) override { + // Consume the whole slice + *data = GRPC_SLICE_START_PTR(slice_) + bytes_written_; + *size = total_size_ - bytes_written_; + bytes_written_ = total_size_; + return true; + } + + void BackUp(int count) override { bytes_written_ -= count; } + + int64_t ByteCount() const override { return bytes_written_; } + + private: + grpc_slice slice_; + int bytes_written_; + int total_size_; +}; + +bool ReadBytesZeroCopy(const std::shared_ptr& source_data, + CodedInputStream* input, std::shared_ptr* out); + +// Internal wrapper for gRPC ByteBuffer so its memory can be exposed to Arrow +// consumers with zero-copy +class GrpcBuffer : public arrow::MutableBuffer { + public: + GrpcBuffer(grpc_slice slice, bool incref) + : MutableBuffer(GRPC_SLICE_START_PTR(slice), + static_cast(GRPC_SLICE_LENGTH(slice))), + slice_(incref ? grpc_slice_ref(slice) : slice) {} + + ~GrpcBuffer() override { + // Decref slice + grpc_slice_unref(slice_); + } + + static arrow::Status Wrap(grpc::ByteBuffer* cpp_buf, + std::shared_ptr* out) { + // These types are guaranteed by static assertions in gRPC to have the same + // in-memory representation + + auto buffer = *reinterpret_cast(cpp_buf); + + // This part below is based on the Flatbuffers gRPC SerializationTraits in + // flatbuffers/grpc.h + + // Check if this is a single uncompressed slice. + if ((buffer->type == GRPC_BB_RAW) && + (buffer->data.raw.compression == GRPC_COMPRESS_NONE) && + (buffer->data.raw.slice_buffer.count == 1)) { + // If it is, then we can reference the `grpc_slice` directly. + grpc_slice slice = buffer->data.raw.slice_buffer.slices[0]; + + // Increment reference count so this memory remains valid + *out = std::make_shared(slice, true); + } else { + // Otherwise, we need to use `grpc_byte_buffer_reader_readall` to read + // `buffer` into a single contiguous `grpc_slice`. The gRPC reader gives + // us back a new slice with the refcount already incremented. 
+ grpc_byte_buffer_reader reader; + if (!grpc_byte_buffer_reader_init(&reader, buffer)) { + return arrow::Status::IOError("Internal gRPC error reading from ByteBuffer"); + } + grpc_slice slice = grpc_byte_buffer_reader_readall(&reader); + grpc_byte_buffer_reader_destroy(&reader); + + // Steal the slice reference + *out = std::make_shared(slice, false); + } + + return arrow::Status::OK(); + } + + private: + grpc_slice slice_; +}; + +} // namespace internal + +} // namespace flight +} // namespace arrow + +namespace grpc { + +using arrow::flight::FlightData; +using arrow::flight::internal::FixedSizeProtoWriter; +using arrow::flight::internal::GrpcBuffer; +using arrow::flight::internal::ReadBytesZeroCopy; + +using google::protobuf::internal::WireFormatLite; +using google::protobuf::io::CodedInputStream; +using google::protobuf::io::CodedOutputStream; + +// Helper to log status code, as gRPC doesn't expose why +// (de)serialization fails +inline Status FailSerialization(Status status) { + if (!status.ok()) { + ARROW_LOG(WARNING) << "Error deserializing Flight message: " + << status.error_message(); + } + return status; +} + +inline arrow::Status FailSerialization(arrow::Status status) { + if (!status.ok()) { + ARROW_LOG(WARNING) << "Error deserializing Flight message: " << status.ToString(); + } + return status; +} + +// Read internal::FlightData from grpc::ByteBuffer containing FlightData +// protobuf without copying +template <> +class SerializationTraits { + public: + static Status Serialize(const FlightData& msg, ByteBuffer** buffer, bool* own_buffer) { + return FailSerialization(Status( + StatusCode::UNIMPLEMENTED, "internal::FlightData serialization not implemented")); + } + + static Status Deserialize(ByteBuffer* buffer, FlightData* out) { + if (!buffer) { + return FailSerialization(Status(StatusCode::INTERNAL, "No payload")); + } + + std::shared_ptr wrapped_buffer; + GRPC_RETURN_NOT_OK(FailSerialization(GrpcBuffer::Wrap(buffer, &wrapped_buffer))); + + auto buffer_length = static_cast(wrapped_buffer->size()); + CodedInputStream pb_stream(wrapped_buffer->data(), buffer_length); + + // TODO(wesm): The 2-parameter version of this function is deprecated + pb_stream.SetTotalBytesLimit(buffer_length, -1 /* no threshold */); + + // This is the bytes remaining when using CodedInputStream like this + while (pb_stream.BytesUntilTotalBytesLimit()) { + const uint32_t tag = pb_stream.ReadTag(); + const int field_number = WireFormatLite::GetTagFieldNumber(tag); + switch (field_number) { + case pb::FlightData::kFlightDescriptorFieldNumber: { + pb::FlightDescriptor pb_descriptor; + if (!pb_descriptor.ParseFromCodedStream(&pb_stream)) { + return FailSerialization( + Status(StatusCode::INTERNAL, "Unable to parse FlightDescriptor")); + } + } break; + case pb::FlightData::kDataHeaderFieldNumber: { + if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->metadata)) { + return FailSerialization( + Status(StatusCode::INTERNAL, "Unable to read FlightData metadata")); + } + } break; + case pb::FlightData::kDataBodyFieldNumber: { + if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->body)) { + return FailSerialization( + Status(StatusCode::INTERNAL, "Unable to read FlightData body")); + } + } break; + default: + DCHECK(false) << "cannot happen"; + } + } + buffer->Clear(); + + // TODO(wesm): Where and when should we verify that the FlightData is not + // malformed or missing components? 
+ + return Status::OK; + } +}; + +// Write FlightData to a grpc::ByteBuffer without extra copying +template <> +class SerializationTraits { + public: + static grpc::Status Deserialize(ByteBuffer* buffer, IpcPayload* out) { + return FailSerialization(grpc::Status(grpc::StatusCode::UNIMPLEMENTED, + "IpcPayload deserialization not implemented")); + } + + static grpc::Status Serialize(const IpcPayload& msg, ByteBuffer* out, + bool* own_buffer) { + size_t total_size = 0; + + DCHECK_LT(msg.metadata->size(), kInt32Max); + const int32_t metadata_size = static_cast(msg.metadata->size()); + + // 1 byte for metadata tag + total_size += 1 + WireFormatLite::LengthDelimitedSize(metadata_size); + + int64_t body_size = 0; + for (const auto& buffer : msg.body_buffers) { + // Buffer may be null when the row length is zero, or when all + // entries are invalid. + if (!buffer) continue; + + body_size += buffer->size(); + + const int64_t remainder = buffer->size() % 8; + if (remainder) { + body_size += 8 - remainder; + } + } + + // 2 bytes for body tag + // Only written when there are body buffers + if (msg.body_length > 0) { + total_size += + 2 + WireFormatLite::LengthDelimitedSize(static_cast(body_size)); + } + + // TODO(wesm): messages over 2GB unlikely to be yet supported + if (total_size > kInt32Max) { + return FailSerialization( + grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, + "Cannot send record batches exceeding 2GB yet")); + } + + // Allocate slice, assign to output buffer + grpc::Slice slice(total_size); + + // XXX(wesm): for debugging + // std::cout << "Writing record batch with total size " << total_size << std::endl; + + FixedSizeProtoWriter writer(*reinterpret_cast(&slice)); + CodedOutputStream pb_stream(&writer); + + // Write header + WireFormatLite::WriteTag(pb::FlightData::kDataHeaderFieldNumber, + WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &pb_stream); + pb_stream.WriteVarint32(metadata_size); + pb_stream.WriteRawMaybeAliased(msg.metadata->data(), + static_cast(msg.metadata->size())); + + // Don't write tag if there are no body buffers + if (msg.body_length > 0) { + // Write body + WireFormatLite::WriteTag(pb::FlightData::kDataBodyFieldNumber, + WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &pb_stream); + pb_stream.WriteVarint32(static_cast(body_size)); + + constexpr uint8_t kPaddingBytes[8] = {0}; + + for (const auto& buffer : msg.body_buffers) { + // Buffer may be null when the row length is zero, or when all + // entries are invalid. 
+ if (!buffer) continue; + + pb_stream.WriteRawMaybeAliased(buffer->data(), static_cast(buffer->size())); + + // Write padding if not multiple of 8 + const int remainder = static_cast(buffer->size() % 8); + if (remainder) { + pb_stream.WriteRawMaybeAliased(kPaddingBytes, 8 - remainder); + } + } + } + + DCHECK_EQ(static_cast(total_size), pb_stream.ByteCount()); + + // Hand off the slice to the returned ByteBuffer + grpc::ByteBuffer tmp(&slice, 1); + out->Swap(&tmp); + *own_buffer = true; + return grpc::Status::OK; + } +}; + +} // namespace grpc diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc index 018c079501f2f..ac5b53532866f 100644 --- a/cpp/src/arrow/flight/server.cc +++ b/cpp/src/arrow/flight/server.cc @@ -18,15 +18,12 @@ #include "arrow/flight/server.h" #include -#include #include #include -#include "google/protobuf/io/coded_stream.h" -#include "google/protobuf/io/zero_copy_stream.h" -#include "google/protobuf/wire_format_lite.h" #include "grpcpp/grpcpp.h" +#include "arrow/ipc/reader.h" #include "arrow/ipc/writer.h" #include "arrow/record_batch.h" #include "arrow/status.h" @@ -35,6 +32,7 @@ #include "arrow/flight/Flight.grpc.pb.h" #include "arrow/flight/Flight.pb.h" #include "arrow/flight/internal.h" +#include "arrow/flight/serialization-internal.h" #include "arrow/flight/types.h" using FlightService = arrow::flight::protocol::FlightService; @@ -47,145 +45,64 @@ using ServerWriter = grpc::ServerWriter; namespace pb = arrow::flight::protocol; -constexpr int64_t kInt32Max = std::numeric_limits::max(); - -namespace grpc { - -using google::protobuf::internal::WireFormatLite; -using google::protobuf::io::CodedOutputStream; +namespace arrow { +namespace flight { -// More efficient writing of FlightData to gRPC output buffer -// Implementation of ZeroCopyOutputStream that writes to a fixed-size buffer -class FixedSizeProtoWriter : public ::google::protobuf::io::ZeroCopyOutputStream { - public: - explicit FixedSizeProtoWriter(grpc_slice slice) - : slice_(slice), - bytes_written_(0), - total_size_(static_cast(GRPC_SLICE_LENGTH(slice))) {} - - bool Next(void** data, int* size) override { - // Consume the whole slice - *data = GRPC_SLICE_START_PTR(slice_) + bytes_written_; - *size = total_size_ - bytes_written_; - bytes_written_ = total_size_; - return true; +#define CHECK_ARG_NOT_NULL(VAL, MESSAGE) \ + if (VAL == nullptr) { \ + return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, MESSAGE); \ } - void BackUp(int count) override { bytes_written_ -= count; } - - int64_t ByteCount() const override { return bytes_written_; } - - private: - grpc_slice slice_; - int bytes_written_; - int total_size_; -}; - -// Write FlightData to a grpc::ByteBuffer without extra copying -template <> -class SerializationTraits { +class FlightMessageReaderImpl : public FlightMessageReader { public: - static grpc::Status Deserialize(ByteBuffer* buffer, IpcPayload* out) { - return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, - "IpcPayload deserialization not implemented"); - } - - static grpc::Status Serialize(const IpcPayload& msg, ByteBuffer* out, - bool* own_buffer) { - size_t total_size = 0; - - DCHECK_LT(msg.metadata->size(), kInt32Max); - const int32_t metadata_size = static_cast(msg.metadata->size()); - - // 1 byte for metadata tag - total_size += 1 + WireFormatLite::LengthDelimitedSize(metadata_size); - - int64_t body_size = 0; - for (const auto& buffer : msg.body_buffers) { - // Buffer may be null when the row length is zero, or when all - // entries are invalid. 
- if (!buffer) continue; - - body_size += buffer->size(); - - const int64_t remainder = buffer->size() % 8; - if (remainder) { - body_size += 8 - remainder; - } - } - - // 2 bytes for body tag - // Only written when there are body buffers - if (msg.body_length > 0) { - total_size += - 2 + WireFormatLite::LengthDelimitedSize(static_cast(body_size)); + FlightMessageReaderImpl(const FlightDescriptor& descriptor, + std::shared_ptr schema, + grpc::ServerReader* reader) + : descriptor_(descriptor), + schema_(schema), + reader_(reader), + stream_finished_(false) {} + + const FlightDescriptor& descriptor() const override { return descriptor_; } + + std::shared_ptr schema() const override { return schema_; } + + Status ReadNext(std::shared_ptr* out) override { + if (stream_finished_) { + *out = nullptr; + return Status::OK(); } - // TODO(wesm): messages over 2GB unlikely to be yet supported - if (total_size > kInt32Max) { - return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, - "Cannot send record batches exceeding 2GB yet"); - } - - // Allocate slice, assign to output buffer - grpc::Slice slice(total_size); - - // XXX(wesm): for debugging - // std::cout << "Writing record batch with total size " << total_size << std::endl; - - FixedSizeProtoWriter writer(*reinterpret_cast(&slice)); - CodedOutputStream pb_stream(&writer); - - // Write header - WireFormatLite::WriteTag(pb::FlightData::kDataHeaderFieldNumber, - WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &pb_stream); - pb_stream.WriteVarint32(metadata_size); - pb_stream.WriteRawMaybeAliased(msg.metadata->data(), - static_cast(msg.metadata->size())); - - // Don't write tag if there are no body buffers - if (msg.body_length > 0) { - // Write body - WireFormatLite::WriteTag(pb::FlightData::kDataBodyFieldNumber, - WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &pb_stream); - pb_stream.WriteVarint32(static_cast(body_size)); - - constexpr uint8_t kPaddingBytes[8] = {0}; - - for (const auto& buffer : msg.body_buffers) { - // Buffer may be null when the row length is zero, or when all - // entries are invalid. 
- if (!buffer) continue; - - pb_stream.WriteRawMaybeAliased(buffer->data(), static_cast(buffer->size())); - - // Write padding if not multiple of 8 - const int remainder = static_cast(buffer->size() % 8); - if (remainder) { - pb_stream.WriteRawMaybeAliased(kPaddingBytes, 8 - remainder); - } + // XXX this cast is undefined behavior + auto custom_reader = reinterpret_cast*>(reader_); + + FlightData data; + // Explicitly specify the override to invoke - otherwise compiler + // may invoke through vtable (not updated by reinterpret_cast) + if (custom_reader->grpc::ServerReader::Read(&data)) { + std::unique_ptr message; + + // Validate IPC message + RETURN_NOT_OK(ipc::Message::Open(data.metadata, data.body, &message)); + if (message->type() == ipc::Message::Type::RECORD_BATCH) { + return ipc::ReadRecordBatch(*message, schema_, out); + } else { + return Status(StatusCode::Invalid, "Unrecognized message in Flight stream"); } + } else { + // Stream is completed + stream_finished_ = true; + *out = nullptr; + return Status::OK(); } - - DCHECK_EQ(static_cast(total_size), pb_stream.ByteCount()); - - // Hand off the slice to the returned ByteBuffer - grpc::ByteBuffer tmp(&slice, 1); - out->Swap(&tmp); - *own_buffer = true; - return grpc::Status::OK; } -}; - -} // namespace grpc - -namespace arrow { -namespace flight { -#define CHECK_ARG_NOT_NULL(VAL, MESSAGE) \ - if (VAL == nullptr) { \ - return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, MESSAGE); \ - } + private: + FlightDescriptor descriptor_; + std::shared_ptr schema_; + grpc::ServerReader* reader_; + bool stream_finished_; +}; // This class glues an implementation of FlightServerBase together with the // gRPC service definition, so the latter is not exposed in the public API @@ -268,6 +185,7 @@ class FlightServiceImpl : public FlightService::Service { GRPC_RETURN_NOT_OK(server_->DoGet(ticket, &data_stream)); // Requires ServerWriter customization in grpc_customizations.h + // XXX this cast is undefined behavior auto custom_writer = reinterpret_cast*>(writer); // Write the schema as the first message in the stream @@ -276,7 +194,10 @@ class FlightServiceImpl : public FlightService::Service { ipc::DictionaryMemo dictionary_memo; GRPC_RETURN_NOT_OK(ipc::internal::GetSchemaPayload( *data_stream->schema(), pool, &dictionary_memo, &schema_payload)); - custom_writer->Write(schema_payload, grpc::WriteOptions()); + // Explicitly specify the override to invoke - otherwise compiler + // may invoke through vtable (not updated by reinterpret_cast) + custom_writer->grpc::ServerWriter::Write(schema_payload, + grpc::WriteOptions()); while (true) { IpcPayload payload; @@ -293,7 +214,30 @@ class FlightServiceImpl : public FlightService::Service { grpc::Status DoPut(ServerContext* context, grpc::ServerReader* reader, pb::PutResult* response) { - return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, ""); + // Get metadata + pb::FlightData data; + if (reader->Read(&data)) { + FlightDescriptor descriptor; + // Message only lives as long as data + std::unique_ptr message; + GRPC_RETURN_NOT_OK(internal::FromProto(data, &descriptor, &message)); + + if (!message || message->type() != ipc::Message::Type::SCHEMA) { + return internal::ToGrpcStatus( + Status(StatusCode::Invalid, "DoPut must start with schema/descriptor")); + } else { + std::shared_ptr schema; + GRPC_RETURN_NOT_OK(ipc::ReadSchema(*message, &schema)); + + auto message_reader = std::unique_ptr( + new FlightMessageReaderImpl(descriptor, schema, reader)); + return 
internal::ToGrpcStatus(server_->DoPut(std::move(message_reader))); + } + } else { + return internal::ToGrpcStatus( + Status(StatusCode::Invalid, + "Client provided malformed message or did not provide message")); + } } grpc::Status ListActions(ServerContext* context, const pb::Empty* request, @@ -376,6 +320,10 @@ Status FlightServerBase::DoGet(const Ticket& request, return Status::NotImplemented("NYI"); } +Status FlightServerBase::DoPut(std::unique_ptr reader) { + return Status::NotImplemented("NYI"); +} + Status FlightServerBase::DoAction(const Action& action, std::unique_ptr* result) { return Status::NotImplemented("NYI"); diff --git a/cpp/src/arrow/flight/server.h b/cpp/src/arrow/flight/server.h index b3b8239132b7a..b2e8b02be8e7d 100644 --- a/cpp/src/arrow/flight/server.h +++ b/cpp/src/arrow/flight/server.h @@ -29,6 +29,7 @@ #include "arrow/flight/types.h" #include "arrow/ipc/dictionary.h" +#include "arrow/record_batch.h" namespace arrow { @@ -68,9 +69,9 @@ class ARROW_EXPORT FlightDataStream { /// \brief A basic implementation of FlightDataStream that will provide /// a sequence of FlightData messages to be written to a gRPC stream -/// \param[in] reader produces a sequence of record batches class ARROW_EXPORT RecordBatchStream : public FlightDataStream { public: + /// \param[in] reader produces a sequence of record batches explicit RecordBatchStream(const std::shared_ptr& reader); std::shared_ptr schema() override; @@ -81,6 +82,13 @@ class ARROW_EXPORT RecordBatchStream : public FlightDataStream { std::shared_ptr reader_; }; +/// \brief A reader for IPC payloads uploaded by a client +class ARROW_EXPORT FlightMessageReader : public RecordBatchReader { + public: + /// \brief Get the descriptor for this upload. + virtual const FlightDescriptor& descriptor() const = 0; +}; + /// \brief Skeleton RPC server implementation which can be used to create /// custom servers by implementing its abstract methods class ARROW_EXPORT FlightServerBase { @@ -90,8 +98,7 @@ class ARROW_EXPORT FlightServerBase { /// \brief Run an insecure server on localhost at the indicated port. Block /// until server is shut down or otherwise terminates - /// \param[in] port - /// \return Status + /// \param[in] port the port to bind to void Run(int port); /// \brief Shut down the server. Can be called from signal handler or another @@ -125,7 +132,10 @@ class ARROW_EXPORT FlightServerBase { /// \return Status virtual Status DoGet(const Ticket& request, std::unique_ptr* stream); - // virtual Status DoPut(std::unique_ptr* reader) = 0; + /// \brief Process a stream of IPC payloads sent from a client + /// \param[in] reader a sequence of uploaded record batches + /// \return Status + virtual Status DoPut(std::unique_ptr reader); /// \brief Execute an action, return stream of zero or more results /// \param[in] action the action to execute, with type and body diff --git a/cpp/src/arrow/flight/test-integration-client.cc b/cpp/src/arrow/flight/test-integration-client.cc index 267025a451cc7..62522833f4ba3 100644 --- a/cpp/src/arrow/flight/test-integration-client.cc +++ b/cpp/src/arrow/flight/test-integration-client.cc @@ -15,12 +15,11 @@ // specific language governing permissions and limitations // under the License. -// Client implementation for Flight integration testing. Requests the given -// path from the Flight server, which reads that file and sends it as a stream -// to the client. The client writes the server stream to the IPC file format at -// the given output file path. 
The integration test script then uses the -// existing integration test tools to compare the output binary with the -// original JSON +// Client implementation for Flight integration testing. Loads +// RecordBatches from the given JSON file and uploads them to the +// Flight server, which stores the data and schema in memory. The +// client then requests the data from the server and compares it to +// the data originally uploaded. #include #include @@ -31,6 +30,7 @@ #include "arrow/io/test-common.h" #include "arrow/ipc/json.h" #include "arrow/record_batch.h" +#include "arrow/table.h" #include "arrow/flight/server.h" #include "arrow/flight/test-util.h" @@ -38,7 +38,60 @@ DEFINE_string(host, "localhost", "Server port to connect to"); DEFINE_int32(port, 31337, "Server port to connect to"); DEFINE_string(path, "", "Resource path to request"); -DEFINE_string(output, "", "Where to write requested resource"); + +/// \brief Helper to read a RecordBatchReader into a Table. +arrow::Status ReadToTable(std::unique_ptr& reader, + std::shared_ptr* retrieved_data) { + std::vector> retrieved_chunks; + std::shared_ptr chunk; + while (true) { + RETURN_NOT_OK(reader->ReadNext(&chunk)); + if (chunk == nullptr) break; + retrieved_chunks.push_back(chunk); + } + return arrow::Table::FromRecordBatches(reader->schema(), retrieved_chunks, + retrieved_data); +} + +/// \brief Helper to read a JsonReader into a Table. +arrow::Status ReadToTable(std::unique_ptr& reader, + std::shared_ptr* retrieved_data) { + std::vector> retrieved_chunks; + std::shared_ptr chunk; + for (int i = 0; i < reader->num_record_batches(); i++) { + RETURN_NOT_OK(reader->ReadRecordBatch(i, &chunk)); + retrieved_chunks.push_back(chunk); + } + return arrow::Table::FromRecordBatches(reader->schema(), retrieved_chunks, + retrieved_data); +} + +/// \brief Helper to copy a RecordBatchReader to a RecordBatchWriter. +arrow::Status CopyReaderToWriter(std::unique_ptr& reader, + arrow::ipc::RecordBatchWriter& writer) { + while (true) { + std::shared_ptr chunk; + RETURN_NOT_OK(reader->ReadNext(&chunk)); + if (chunk == nullptr) break; + RETURN_NOT_OK(writer.WriteRecordBatch(*chunk)); + } + return writer.Close(); +} + +/// \brief Helper to read a flight into a Table. +arrow::Status ConsumeFlightLocation(const arrow::flight::Location& location, + const arrow::flight::Ticket& ticket, + const std::shared_ptr& schema, + std::shared_ptr* retrieved_data) { + std::unique_ptr read_client; + RETURN_NOT_OK( + arrow::flight::FlightClient::Connect(location.host, location.port, &read_client)); + + std::unique_ptr stream; + RETURN_NOT_OK(read_client->DoGet(ticket, schema, &stream)); + + return ReadToTable(stream, retrieved_data); +} int main(int argc, char** argv) { gflags::SetUsageMessage("Integration testing client for Flight."); @@ -49,6 +102,25 @@ int main(int argc, char** argv) { arrow::flight::FlightDescriptor descr{ arrow::flight::FlightDescriptor::PATH, "", {FLAGS_path}}; + + // 1. Put the data to the server. 
+ std::unique_ptr reader; + std::cout << "Opening JSON file '" << FLAGS_path << "'" << std::endl; + std::shared_ptr in_file; + ABORT_NOT_OK(arrow::io::ReadableFile::Open(FLAGS_path, &in_file)); + ABORT_NOT_OK(arrow::ipc::internal::json::JsonReader::Open(arrow::default_memory_pool(), + in_file, &reader)); + + std::shared_ptr original_data; + ABORT_NOT_OK(ReadToTable(reader, &original_data)); + + std::unique_ptr write_stream; + ABORT_NOT_OK(client->DoPut(descr, reader->schema(), &write_stream)); + std::unique_ptr table_reader( + new arrow::TableBatchReader(*original_data)); + ABORT_NOT_OK(CopyReaderToWriter(table_reader, *write_stream)); + + // 2. Get the ticket for the data. std::unique_ptr info; ABORT_NOT_OK(client->GetFlightInfo(descr, &info)); @@ -60,23 +132,27 @@ int main(int argc, char** argv) { return -1; } - arrow::flight::Ticket ticket = info->endpoints()[0].ticket; - std::unique_ptr stream; - ABORT_NOT_OK(client->DoGet(ticket, schema, &stream)); - - std::shared_ptr out_file; - ABORT_NOT_OK(arrow::io::FileOutputStream::Open(FLAGS_output, &out_file)); - std::shared_ptr writer; - ABORT_NOT_OK(arrow::ipc::RecordBatchFileWriter::Open(out_file.get(), schema, &writer)); - - std::shared_ptr chunk; - while (true) { - ABORT_NOT_OK(stream->ReadNext(&chunk)); - if (chunk == nullptr) break; - ABORT_NOT_OK(writer->WriteRecordBatch(*chunk)); + for (const arrow::flight::FlightEndpoint& endpoint : info->endpoints()) { + const auto& ticket = endpoint.ticket; + + auto locations = endpoint.locations; + if (locations.size() == 0) { + locations = {arrow::flight::Location{FLAGS_host, FLAGS_port}}; + } + + for (const auto location : locations) { + std::cout << "Verifying location " << location.host << ':' << location.port + << std::endl; + // 3. Download the data from the server. + std::shared_ptr retrieved_data; + ABORT_NOT_OK(ConsumeFlightLocation(location, ticket, schema, &retrieved_data)); + + // 4. Validate that the data is equal. + if (!original_data->Equals(*retrieved_data)) { + std::cerr << "Data does not match!" 
<< std::endl; + return 1; + } + } } - - ABORT_NOT_OK(writer->Close()); - return 0; } diff --git a/cpp/src/arrow/flight/test-integration-server.cc b/cpp/src/arrow/flight/test-integration-server.cc index 80813e7f19a4c..7e201a031943d 100644 --- a/cpp/src/arrow/flight/test-integration-server.cc +++ b/cpp/src/arrow/flight/test-integration-server.cc @@ -27,6 +27,7 @@ #include "arrow/io/test-common.h" #include "arrow/ipc/json.h" #include "arrow/record_batch.h" +#include "arrow/table.h" #include "arrow/flight/server.h" #include "arrow/flight/test-util.h" @@ -36,57 +37,7 @@ DEFINE_int32(port, 31337, "Server port to listen on"); namespace arrow { namespace flight { -class JsonReaderRecordBatchStream : public FlightDataStream { - public: - explicit JsonReaderRecordBatchStream( - std::unique_ptr&& reader) - : index_(0), pool_(default_memory_pool()), reader_(std::move(reader)) {} - - std::shared_ptr schema() override { return reader_->schema(); } - - Status Next(ipc::internal::IpcPayload* payload) override { - if (index_ >= reader_->num_record_batches()) { - // Signal that iteration is over - payload->metadata = nullptr; - return Status::OK(); - } - - std::shared_ptr batch; - RETURN_NOT_OK(reader_->ReadRecordBatch(index_, &batch)); - index_++; - - if (!batch) { - // Signal that iteration is over - payload->metadata = nullptr; - return Status::OK(); - } else { - return ipc::internal::GetRecordBatchPayload(*batch, pool_, payload); - } - } - - private: - int index_; - MemoryPool* pool_; - std::unique_ptr reader_; -}; - class FlightIntegrationTestServer : public FlightServerBase { - Status ReadJson(const std::string& json_path, - std::unique_ptr* out) { - std::shared_ptr in_file; - std::cout << "Opening JSON file '" << json_path << "'" << std::endl; - RETURN_NOT_OK(io::ReadableFile::Open(json_path, &in_file)); - - int64_t file_size = 0; - RETURN_NOT_OK(in_file->GetSize(&file_size)); - - std::shared_ptr json_buffer; - RETURN_NOT_OK(in_file->Read(file_size, &json_buffer)); - - RETURN_NOT_OK(arrow::ipc::internal::json::JsonReader::Open(json_buffer, out)); - return Status::OK(); - } - Status GetFlightInfo(const FlightDescriptor& request, std::unique_ptr* info) override { if (request.type == FlightDescriptor::PATH) { @@ -94,16 +45,19 @@ class FlightIntegrationTestServer : public FlightServerBase { return Status::Invalid("Invalid path"); } - std::unique_ptr reader; - RETURN_NOT_OK(ReadJson(request.path.back(), &reader)); + auto data = uploaded_chunks.find(request.path[0]); + if (data == uploaded_chunks.end()) { + return Status::KeyError("Could not find flight.", request.path[0]); + } + auto flight = data->second; - FlightEndpoint endpoint1({{request.path.back()}, {}}); + FlightEndpoint endpoint1({{request.path[0]}, {}}); FlightInfo::Data flight_data; - RETURN_NOT_OK(internal::SchemaToString(*reader->schema(), &flight_data.schema)); + RETURN_NOT_OK(internal::SchemaToString(*flight->schema(), &flight_data.schema)); flight_data.descriptor = request; flight_data.endpoints = {endpoint1}; - flight_data.total_records = reader->num_record_batches(); + flight_data.total_records = flight->num_rows(); flight_data.total_bytes = -1; FlightInfo value(flight_data); @@ -116,14 +70,44 @@ class FlightIntegrationTestServer : public FlightServerBase { Status DoGet(const Ticket& request, std::unique_ptr* data_stream) override { - std::unique_ptr reader; - RETURN_NOT_OK(ReadJson(request.ticket, &reader)); + auto data = uploaded_chunks.find(request.ticket); + if (data == uploaded_chunks.end()) { + return Status::KeyError("Could not 
find flight.", request.ticket); + } + auto flight = data->second; - *data_stream = std::unique_ptr( - new JsonReaderRecordBatchStream(std::move(reader))); + *data_stream = std::unique_ptr(new RecordBatchStream( + std::shared_ptr(new TableBatchReader(*flight)))); return Status::OK(); } + + Status DoPut(std::unique_ptr reader) override { + const FlightDescriptor& descriptor = reader->descriptor(); + + if (descriptor.type != FlightDescriptor::DescriptorType::PATH) { + return Status::Invalid("Must specify a path"); + } else if (descriptor.path.size() < 1) { + return Status::Invalid("Must specify a path"); + } + + std::string key = descriptor.path[0]; + + std::vector> retrieved_chunks; + std::shared_ptr chunk; + while (true) { + RETURN_NOT_OK(reader->ReadNext(&chunk)); + if (chunk == nullptr) break; + retrieved_chunks.push_back(chunk); + } + std::shared_ptr retrieved_data; + RETURN_NOT_OK(arrow::Table::FromRecordBatches(reader->schema(), retrieved_chunks, + &retrieved_data)); + uploaded_chunks[key] = retrieved_data; + return Status::OK(); + } + + std::unordered_map> uploaded_chunks; }; } // namespace flight diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h index 0362105bbc592..e4251bdd5d21b 100644 --- a/cpp/src/arrow/flight/types.h +++ b/cpp/src/arrow/flight/types.h @@ -152,12 +152,6 @@ class FlightInfo { mutable bool reconstructed_schema_; }; -// TODO(wesm): NYI -class ARROW_EXPORT FlightPutWriter { - public: - virtual ~FlightPutWriter() = default; -}; - /// \brief An iterator to FlightInfo instances returned by ListFlights class ARROW_EXPORT FlightListing { public: diff --git a/cpp/src/arrow/ipc/json.cc b/cpp/src/arrow/ipc/json.cc index 61c242ca2dbbb..56fe31cbf5691 100644 --- a/cpp/src/arrow/ipc/json.cc +++ b/cpp/src/arrow/ipc/json.cc @@ -22,6 +22,7 @@ #include #include "arrow/buffer.h" +#include "arrow/io/file.h" #include "arrow/ipc/json-internal.h" #include "arrow/memory_pool.h" #include "arrow/record_batch.h" @@ -157,6 +158,18 @@ Status JsonReader::Open(MemoryPool* pool, const std::shared_ptr& data, return (*reader)->impl_->ParseAndReadSchema(); } +Status JsonReader::Open(MemoryPool* pool, + const std::shared_ptr& in_file, + std::unique_ptr* reader) { + int64_t file_size = 0; + RETURN_NOT_OK(in_file->GetSize(&file_size)); + + std::shared_ptr json_buffer; + RETURN_NOT_OK(in_file->Read(file_size, &json_buffer)); + + return Open(pool, json_buffer, reader); +} + std::shared_ptr JsonReader::schema() const { return impl_->schema(); } int JsonReader::num_record_batches() const { return impl_->num_record_batches(); } diff --git a/cpp/src/arrow/ipc/json.h b/cpp/src/arrow/ipc/json.h index 5c00555de8ec0..aeed7070fe96e 100644 --- a/cpp/src/arrow/ipc/json.h +++ b/cpp/src/arrow/ipc/json.h @@ -33,6 +33,10 @@ class MemoryPool; class RecordBatch; class Schema; +namespace io { +class ReadableFile; +} // namespace io + namespace ipc { namespace internal { namespace json { @@ -95,6 +99,15 @@ class ARROW_EXPORT JsonReader { static Status Open(const std::shared_ptr& data, std::unique_ptr* reader); + /// \brief Create a new JSON reader from a file + /// + /// \param[in] pool a MemoryPool to use for buffer allocations + /// \param[in] in_file a ReadableFile containing JSON data + /// \param[out] reader the returned reader object + /// \return Status + static Status Open(MemoryPool* pool, const std::shared_ptr& in_file, + std::unique_ptr* reader); + /// \brief Return the schema read from the JSON std::shared_ptr schema() const; diff --git a/integration/integration_test.py 
b/integration/integration_test.py index 0bced26f15acd..e7e8edda6ddf1 100644 --- a/integration/integration_test.py +++ b/integration/integration_test.py @@ -1004,28 +1004,15 @@ def _compare_flight_implementations(self, producer, consumer): ) print('##########################################################') - for json_path in self.json_files: - print('==========================================================') - print('Testing file {0}'.format(json_path)) - print('==========================================================') - - name = os.path.splitext(os.path.basename(json_path))[0] + with producer.flight_server(): + for json_path in self.json_files: + print('=' * 58) + print('Testing file {0}'.format(json_path)) + print('=' * 58) - file_id = guid()[:8] - - with producer.flight_server(): - # Have the client request the file - consumer_file_path = os.path.join( - self.temp_dir, - file_id + '_' + name + '.consumer_requested_file') - consumer.flight_request(producer.FLIGHT_PORT, - json_path, consumer_file_path) - - # Validate the file - print('-- Validating file') - consumer.validate(json_path, consumer_file_path) - - # TODO: also have the client upload the file + # Have the client upload the file, then download and + # compare + consumer.flight_request(producer.FLIGHT_PORT, json_path) class Tester(object): @@ -1053,7 +1040,7 @@ def validate(self, json_path, arrow_path): def flight_server(self): raise NotImplementedError - def flight_request(self, port, json_path, arrow_path): + def flight_request(self, port, json_path): raise NotImplementedError @@ -1122,12 +1109,11 @@ def file_to_stream(self, file_path, stream_path): print(' '.join(cmd)) run_cmd(cmd) - def flight_request(self, port, json_path, arrow_path): + def flight_request(self, port, json_path): cmd = ['java', '-cp', self.ARROW_FLIGHT_JAR, self.ARROW_FLIGHT_CLIENT, '-port', str(port), - '-j', json_path, - '-a', arrow_path] + '-j', json_path] if self.debug: print(' '.join(cmd)) run_cmd(cmd) @@ -1230,15 +1216,14 @@ def flight_server(self): server.terminate() server.wait(5) - def flight_request(self, port, json_path, arrow_path): + def flight_request(self, port, json_path): cmd = self.FLIGHT_CLIENT_CMD + [ '-port=' + str(port), '-path=' + json_path, - '-output=' + arrow_path ] if self.debug: print(' '.join(cmd)) - subprocess.run(cmd) + run_cmd(cmd) class JSTester(Tester): diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java index ad7c7e28da242..bd126b5ea203c 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java @@ -127,7 +127,9 @@ public ClientStreamListener startPut(FlightDescriptor descriptor, VectorSchemaRo // send the schema to start. ArrowMessage message = new ArrowMessage(descriptor.toProtocol(), root.getSchema()); observer.onNext(message); - return new PutObserver(new VectorUnloader(root, true, false), observer, resultObserver.getFuture()); + return new PutObserver(new VectorUnloader( + root, true /* include # of nulls in vectors */, true /* must align buffers to be C++-compatible */), + observer, resultObserver.getFuture()); } public FlightInfo getInfo(FlightDescriptor descriptor) { @@ -211,7 +213,8 @@ public PutObserver(VectorUnloader unloader, ClientCallStreamObserver listener) { listener.onNext(new ActionType("get", "pull a stream. 
Action must be done via standard get mechanism")); listener.onNext(new ActionType("put", "push a stream. Action must be done via standard get mechanism")); listener.onNext(new ActionType("drop", "delete a flight. Action body is a JSON encoded path.")); + listener.onCompleted(); } @Override diff --git a/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java index 803a56c6c1afe..ed450074a767a 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java @@ -18,8 +18,8 @@ package org.apache.arrow.flight.example.integration; import java.io.File; -import java.io.FileOutputStream; import java.io.IOException; +import java.util.Collections; import java.util.List; import org.apache.arrow.flight.FlightClient; @@ -30,9 +30,11 @@ import org.apache.arrow.flight.Location; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.VectorLoader; import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.dictionary.DictionaryProvider; -import org.apache.arrow.vector.ipc.ArrowFileWriter; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.ipc.JsonFileReader; +import org.apache.arrow.vector.util.Validator; import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.CommandLineParser; import org.apache.commons.cli.DefaultParser; @@ -48,10 +50,9 @@ class IntegrationTestClient { private IntegrationTestClient() { options = new Options(); - options.addOption("a", "arrow", true, "arrow file"); options.addOption("j", "json", true, "json file"); options.addOption("host", true, "The host to connect to."); - options.addOption("port", true, "The port to connect to." 
); + options.addOption("port", true, "The port to connect to."); } public static void main(String[] args) { @@ -64,7 +65,7 @@ public static void main(String[] args) { } } - static void fatalError(String message, Throwable e) { + private static void fatalError(String message, Throwable e) { System.err.println(message); System.err.println(e.getMessage()); LOGGER.error(message, e); @@ -72,36 +73,65 @@ static void fatalError(String message, Throwable e) { } private void run(String[] args) throws ParseException, IOException { - CommandLineParser parser = new DefaultParser(); - CommandLine cmd = parser.parse(options, args, false); - - String fileName = cmd.getOptionValue("arrow"); - if (fileName == null) { - throw new IllegalArgumentException("missing arrow file parameter"); - } - File arrowFile = new File(fileName); - if (arrowFile.exists()) { - throw new IllegalArgumentException("arrow file already exists: " + arrowFile.getAbsolutePath()); - } + final CommandLineParser parser = new DefaultParser(); + final CommandLine cmd = parser.parse(options, args, false); final String host = cmd.getOptionValue("host", "localhost"); final int port = Integer.parseInt(cmd.getOptionValue("port", "31337")); final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); - FlightClient client = new FlightClient(allocator, new Location(host, port)); - FlightInfo info = client.getInfo(FlightDescriptor.path(cmd.getOptionValue("json"))); + final FlightClient client = new FlightClient(allocator, new Location(host, port)); + + final String inputPath = cmd.getOptionValue("j"); + + // 1. Read data from JSON and upload to server. + FlightDescriptor descriptor = FlightDescriptor.path(inputPath); + VectorSchemaRoot jsonRoot; + try (JsonFileReader reader = new JsonFileReader(new File(inputPath), allocator); + VectorSchemaRoot root = VectorSchemaRoot.create(reader.start(), allocator)) { + jsonRoot = VectorSchemaRoot.create(root.getSchema(), allocator); + VectorUnloader unloader = new VectorUnloader(root); + VectorLoader jsonLoader = new VectorLoader(jsonRoot); + FlightClient.ClientStreamListener stream = client.startPut(descriptor, root); + while (reader.read(root)) { + stream.putNext(); + jsonLoader.load(unloader.getRecordBatch()); + root.clear(); + } + stream.completed(); + // Need to call this, or exceptions from the server get swallowed + stream.getResult(); + } + + // 2. Get the ticket for the data. + FlightInfo info = client.getInfo(descriptor); List endpoints = info.getEndpoints(); if (endpoints.isEmpty()) { throw new RuntimeException("No endpoints returned from Flight server."); } - FlightStream stream = client.getStream(info.getEndpoints().get(0).getTicket()); - try (VectorSchemaRoot root = stream.getRoot(); - FileOutputStream fileOutputStream = new FileOutputStream(arrowFile); - ArrowFileWriter arrowWriter = new ArrowFileWriter(root, new DictionaryProvider.MapDictionaryProvider(), - fileOutputStream.getChannel())) { - while (stream.next()) { - arrowWriter.writeBatch(); + for (FlightEndpoint endpoint : info.getEndpoints()) { + // 3. Download the data from the server. 
+ List locations = endpoint.getLocations(); + if (locations.size() == 0) { + locations = Collections.singletonList(new Location(host, port)); + } + for (Location location : locations) { + System.out.println("Verifying location " + location.getHost() + ":" + location.getPort()); + FlightClient readClient = new FlightClient(allocator, location); + FlightStream stream = readClient.getStream(endpoint.getTicket()); + VectorSchemaRoot downloadedRoot; + try (VectorSchemaRoot root = stream.getRoot()) { + downloadedRoot = VectorSchemaRoot.create(root.getSchema(), allocator); + VectorLoader loader = new VectorLoader(downloadedRoot); + VectorUnloader unloader = new VectorUnloader(root); + while (stream.next()) { + loader.load(unloader.getRecordBatch()); + } + } + + // 4. Validate the data. + Validator.compareVectorSchemaRoot(jsonRoot, downloadedRoot); } } } diff --git a/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java index 7b45e53a149be..eff2f5d4126cc 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java @@ -17,30 +17,11 @@ package org.apache.arrow.flight.example.integration; -import java.io.File; -import java.nio.charset.StandardCharsets; -import java.util.Collections; -import java.util.concurrent.Callable; - -import org.apache.arrow.flight.Action; -import org.apache.arrow.flight.ActionType; -import org.apache.arrow.flight.Criteria; -import org.apache.arrow.flight.FlightDescriptor; -import org.apache.arrow.flight.FlightEndpoint; -import org.apache.arrow.flight.FlightInfo; -import org.apache.arrow.flight.FlightProducer; -import org.apache.arrow.flight.FlightServer; -import org.apache.arrow.flight.FlightStream; import org.apache.arrow.flight.Location; -import org.apache.arrow.flight.Result; -import org.apache.arrow.flight.Ticket; -import org.apache.arrow.flight.auth.ServerAuthHandler; -import org.apache.arrow.flight.impl.Flight; +import org.apache.arrow.flight.example.ExampleFlightServer; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; -import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.ipc.JsonFileReader; -import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.util.AutoCloseables; import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.CommandLineParser; import org.apache.commons.cli.DefaultParser; @@ -48,6 +29,7 @@ import org.apache.commons.cli.ParseException; class IntegrationTestServer { + private static final org.slf4j.Logger LOGGER = org.slf4j.LoggerFactory.getLogger(IntegrationTestServer.class); private final Options options; private IntegrationTestServer() { @@ -58,17 +40,25 @@ private IntegrationTestServer() { private void run(String[] args) throws Exception { CommandLineParser parser = new DefaultParser(); CommandLine cmd = parser.parse(options, args, false); + final int port = Integer.parseInt(cmd.getOptionValue("port", "31337")); final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); - final int port = Integer.parseInt(cmd.getOptionValue("port", "31337")); - try (final IntegrationFlightProducer producer = new IntegrationFlightProducer(allocator); - final FlightServer server = new FlightServer(allocator, port, producer, ServerAuthHandler.NO_OP)) { - server.start(); 
- // Print out message for integration test script - System.out.println("Server listening on localhost:" + server.getPort()); - while (true) { - Thread.sleep(30000); + final ExampleFlightServer efs = new ExampleFlightServer(allocator, new Location("localhost", port)); + efs.start(); + // Print out message for integration test script + System.out.println("Server listening on localhost:" + port); + + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + try { + System.out.println("\nExiting..."); + AutoCloseables.close(efs, allocator); + } catch (Exception e) { + e.printStackTrace(); } + })); + + while (true) { + Thread.sleep(30000); } } @@ -76,81 +66,17 @@ public static void main(String[] args) { try { new IntegrationTestServer().run(args); } catch (ParseException e) { - IntegrationTestClient.fatalError("Error parsing arguments", e); + fatalError("Error parsing arguments", e); } catch (Exception e) { - IntegrationTestClient.fatalError("Runtime error", e); + fatalError("Runtime error", e); } } - static class IntegrationFlightProducer implements FlightProducer, AutoCloseable { - private final BufferAllocator allocator; - - IntegrationFlightProducer(BufferAllocator allocator) { - this.allocator = allocator; - } - - @Override - public void close() { - allocator.close(); - } - - @Override - public void getStream(Ticket ticket, ServerStreamListener listener) { - String path = new String(ticket.getBytes(), StandardCharsets.UTF_8); - File inputFile = new File(path); - try (JsonFileReader reader = new JsonFileReader(inputFile, allocator)) { - Schema schema = reader.start(); - try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { - listener.start(root); - while (reader.read(root)) { - listener.putNext(); - } - listener.completed(); - } - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - @Override - public void listFlights(Criteria criteria, StreamListener listener) { - listener.onCompleted(); - } - - @Override - public FlightInfo getFlightInfo(FlightDescriptor descriptor) { - if (descriptor.isCommand()) { - throw new UnsupportedOperationException("Commands not supported."); - } - if (descriptor.getPath().size() < 1) { - throw new IllegalArgumentException("Must provide a path."); - } - String path = descriptor.getPath().get(0); - File inputFile = new File(path); - try (JsonFileReader reader = new JsonFileReader(inputFile, allocator)) { - Schema schema = reader.start(); - return new FlightInfo(schema, descriptor, - Collections.singletonList(new FlightEndpoint(new Ticket(path.getBytes()), - new Location("localhost", 31338))), - 0, 0); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - @Override - public Callable acceptPut(FlightStream flightStream) { - return null; - } - - @Override - public Result doAction(Action action) { - return null; - } - - @Override - public void listActions(StreamListener listener) { - listener.onCompleted(); - } + private static void fatalError(String message, Throwable e) { + System.err.println(message); + System.err.println(e.getMessage()); + LOGGER.error(message, e); + System.exit(1); } + }
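
For reference, below is a compact, self-contained sketch of the server-side pattern these patches introduce: an in-memory Flight server whose DoPut drains the client's uploaded record batches into a Table keyed by the first descriptor path segment, and whose DoGet serves that Table back through a RecordBatchStream. This is an illustration modeled on the FlightIntegrationTestServer changes above, not part of the patch itself; the class name InMemoryFlightServer and the tables_ member are invented for the example, and the template parameters and exact signatures are reconstructed from this era of the Arrow Flight C++ API and may differ slightly.

    // Minimal sketch (assumption: 2019-era Arrow Flight C++ API as used in the
    // integration test server above; names InMemoryFlightServer / tables_ are
    // invented for illustration).
    #include <memory>
    #include <string>
    #include <unordered_map>
    #include <vector>

    #include "arrow/flight/server.h"
    #include "arrow/record_batch.h"
    #include "arrow/status.h"
    #include "arrow/table.h"

    namespace flight = arrow::flight;

    class InMemoryFlightServer : public flight::FlightServerBase {
     public:
      arrow::Status DoPut(
          std::unique_ptr<flight::FlightMessageReader> reader) override {
        const flight::FlightDescriptor& descriptor = reader->descriptor();
        if (descriptor.type != flight::FlightDescriptor::DescriptorType::PATH ||
            descriptor.path.size() < 1) {
          return arrow::Status::Invalid("Must specify a path");
        }
        // Drain the uploaded stream; a null batch signals end of stream,
        // mirroring FlightMessageReaderImpl::ReadNext above.
        std::vector<std::shared_ptr<arrow::RecordBatch>> chunks;
        std::shared_ptr<arrow::RecordBatch> chunk;
        while (true) {
          RETURN_NOT_OK(reader->ReadNext(&chunk));
          if (chunk == nullptr) break;
          chunks.push_back(chunk);
        }
        std::shared_ptr<arrow::Table> table;
        RETURN_NOT_OK(
            arrow::Table::FromRecordBatches(reader->schema(), chunks, &table));
        tables_[descriptor.path[0]] = table;
        return arrow::Status::OK();
      }

      arrow::Status DoGet(const flight::Ticket& request,
                          std::unique_ptr<flight::FlightDataStream>* stream) override {
        auto it = tables_.find(request.ticket);
        if (it == tables_.end()) {
          return arrow::Status::KeyError("Could not find flight.", request.ticket);
        }
        // RecordBatchStream adapts a RecordBatchReader into a sequence of
        // FlightData messages, as in the integration server's DoGet.
        std::shared_ptr<arrow::RecordBatchReader> batch_reader(
            new arrow::TableBatchReader(*it->second));
        stream->reset(new flight::RecordBatchStream(batch_reader));
        return arrow::Status::OK();
      }

     private:
      std::unordered_map<std::string, std::shared_ptr<arrow::Table>> tables_;
    };

Keeping the uploads in a plain map mirrors the integration server's uploaded_chunks member; a production server would additionally need synchronization and eviction, which the integration code deliberately leaves out.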