diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index f492b2e31d0b..933b51353d06 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -155,6 +155,8 @@ jobs: --health-retries 5 steps: - uses: actions/checkout@v2 + with: + submodules: true - uses: actions/setup-python@v2 with: python-version: "3.8" @@ -167,7 +169,22 @@ jobs: # make sure psql can access the server echo "$POSTGRES_HOST:$POSTGRES_PORT:$POSTGRES_DB:$POSTGRES_USER:$POSTGRES_PASSWORD" | tee ~/.pgpass chmod 0600 ~/.pgpass - psql -d "$POSTGRES_DB" -h "$POSTGRES_HOST" -p "$POSTGRES_PORT" -U "$POSTGRES_USER" -c 'select now() as now' + psql -d "$POSTGRES_DB" -h "$POSTGRES_HOST" -p "$POSTGRES_PORT" -U "$POSTGRES_USER" -c 'CREATE TABLE IF NOT EXISTS test ( + c1 character varying NOT NULL, + c2 integer NOT NULL, + c3 smallint NOT NULL, + c4 smallint NOT NULL, + c5 integer NOT NULL, + c6 bigint NOT NULL, + c7 smallint NOT NULL, + c8 integer NOT NULL, + c9 bigint NOT NULL, + c10 character varying NOT NULL, + c11 double precision NOT NULL, + c12 double precision NOT NULL, + c13 character varying NOT NULL + );' + psql -d "$POSTGRES_DB" -h "$POSTGRES_HOST" -p "$POSTGRES_PORT" -U "$POSTGRES_USER" -c "\copy test FROM '$(pwd)/testing/data/csv/aggregate_test_100.csv' WITH (FORMAT csv, HEADER true);" env: POSTGRES_HOST: localhost POSTGRES_PORT: ${{ job.services.postgres.ports[5432] }} diff --git a/datafusion-cli/src/main.rs b/datafusion-cli/src/main.rs index 5b35880580b2..083710f6dd19 100644 --- a/datafusion-cli/src/main.rs +++ b/datafusion-cli/src/main.rs @@ -58,9 +58,10 @@ pub async fn main() { ) .arg( Arg::with_name("file") - .help("Execute commands from file, then exit") + .help("Execute commands from file(s), then exit") .short("f") .long("file") + .multiple(true) .validator(is_valid_file) .takes_value(true), ) @@ -112,22 +113,25 @@ pub async fn main() { let quiet = matches.is_present("quiet"); let print_options = PrintOptions { format, quiet }; - if let Some(file_path) = matches.value_of("file") { - let file = File::open(file_path) - .unwrap_or_else(|err| panic!("cannot open file '{}': {}", file_path, err)); - let mut reader = BufReader::new(file); - exec_from_lines(&mut reader, execution_config, print_options).await; + if let Some(file_paths) = matches.values_of("file") { + let files = file_paths + .map(|file_path| File::open(file_path).unwrap()) + .collect::>(); + let mut ctx = ExecutionContext::with_config(execution_config); + for file in files { + let mut reader = BufReader::new(file); + exec_from_lines(&mut ctx, &mut reader, print_options.clone()).await; + } } else { exec_from_repl(execution_config, print_options).await; } } async fn exec_from_lines( + ctx: &mut ExecutionContext, reader: &mut BufReader, - execution_config: ExecutionConfig, print_options: PrintOptions, ) { - let mut ctx = ExecutionContext::with_config(execution_config); let mut query = "".to_owned(); for line in reader.lines() { @@ -139,7 +143,7 @@ async fn exec_from_lines( let line = line.trim_end(); query.push_str(line); if line.ends_with(';') { - match exec_and_print(&mut ctx, print_options.clone(), query).await { + match exec_and_print(ctx, print_options.clone(), query).await { Ok(_) => {} Err(err) => println!("{:?}", err), } @@ -156,7 +160,7 @@ async fn exec_from_lines( // run the left over query if the last statement doesn't contain ‘;’ if !query.is_empty() { - match exec_and_print(&mut ctx, print_options, query).await { + match exec_and_print(ctx, print_options, query).await { Ok(_) => {} Err(err) => println!("{:?}", err), } diff --git a/integration-tests/create_test_table.sql b/integration-tests/create_test_table.sql new file mode 100644 index 000000000000..89b08611d1c0 --- /dev/null +++ b/integration-tests/create_test_table.sql @@ -0,0 +1,34 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at + +-- http://www.apache.org/licenses/LICENSE-2.0 + +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +CREATE EXTERNAL TABLE test ( + c1 VARCHAR NOT NULL, + c2 INT NOT NULL, + c3 SMALLINT NOT NULL, + c4 SMALLINT NOT NULL, + c5 INT NOT NULL, + c6 BIGINT NOT NULL, + c7 SMALLINT NOT NULL, + c8 INT NOT NULL, + c9 BIGINT NOT NULL, + c10 VARCHAR NOT NULL, + c11 FLOAT NOT NULL, + c12 DOUBLE NOT NULL, + c13 VARCHAR NOT NULL +) +STORED AS CSV +WITH HEADER ROW +LOCATION 'testing/data/csv/aggregate_test_100.csv'; diff --git a/integration-tests/sqls/simple_aggregation.sql b/integration-tests/sqls/simple_aggregation.sql new file mode 100644 index 000000000000..cbe37ed4ba31 --- /dev/null +++ b/integration-tests/sqls/simple_aggregation.sql @@ -0,0 +1,24 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at + +-- http://www.apache.org/licenses/LICENSE-2.0 + +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +SELECT + count(*) AS count_all, + count(c3) AS count_c3, + avg(c3) AS avg, + sum(c3) AS sum, + max(c3) AS max, + min(c3) AS min +FROM test; diff --git a/integration-tests/sqls/simple_group_by.sql b/integration-tests/sqls/simple_group_by.sql new file mode 100644 index 000000000000..11fe1cce406f --- /dev/null +++ b/integration-tests/sqls/simple_group_by.sql @@ -0,0 +1,27 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at + +-- http://www.apache.org/licenses/LICENSE-2.0 + +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + + +select + c2, + sum(c3) sum_c3, + avg(c3) avg_c3, + max(c3) max_c3, + min(c3) min_c3, + count(c3) count_c3 +from test +group by c2 +order by c2; diff --git a/integration-tests/test_psql_parity.py b/integration-tests/test_psql_parity.py index 204f9063297e..f4967b8457e4 100644 --- a/integration-tests/test_psql_parity.py +++ b/integration-tests/test_psql_parity.py @@ -32,12 +32,16 @@ ) ] +CREATE_TABLE_SQL_FILE = "integration-tests/create_test_table.sql" + def generate_csv_from_datafusion(fname: str): return subprocess.check_output( [ "./target/debug/datafusion-cli", "-f", + CREATE_TABLE_SQL_FILE, + "-f", fname, "--format", "csv", @@ -70,7 +74,7 @@ class PsqlParityTest(unittest.TestCase): def test_parity(self): root = Path(os.path.dirname(__file__)) / "sqls" files = set(root.glob("*.sql")) - self.assertEqual(len(files), 2, msg="tests are missed") + self.assertEqual(len(files), 4, msg="tests are missed") for fname in files: with self.subTest(fname=fname): datafusion_output = pd.read_csv(