From 3bfe7e3be0ebc57d81c033548a71650097ec9459 Mon Sep 17 00:00:00 2001 From: aljazerzen Date: Mon, 11 Apr 2022 13:15:46 +0200 Subject: [PATCH] Parse dialect & version --- Cargo.lock | 11 +- prql/Cargo.toml | 2 +- prql/src/ast.rs | 40 ++++++- prql/src/ast_fold.rs | 5 +- prql/src/parser.rs | 64 +++++++++- prql/src/prql.pest | 4 +- prql/src/semantic/context.rs | 2 +- prql/src/semantic/materializer.rs | 6 +- .../prql__parser__test__parse_query.snap | 2 + prql/src/translator.rs | 110 +++++++++++------- prql/tests/integration/compile.rs | 2 + 11 files changed, 197 insertions(+), 51 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7037182e2c8c..d780d5df4975 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -912,7 +912,7 @@ dependencies = [ "similar", "sqlformat", "sqlparser", - "strum_macros", + "strum", ] [[package]] @@ -1176,6 +1176,15 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" +[[package]] +name = "strum" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e96acfc1b70604b8b2f1ffa4c57e59176c7dbb05d556c71ecd2f5498a1dee7f8" +dependencies = [ + "strum_macros", +] + [[package]] name = "strum_macros" version = "0.24.0" diff --git a/prql/Cargo.toml b/prql/Cargo.toml index 61477331f4e1..3285e23c9ecd 100644 --- a/prql/Cargo.toml +++ b/prql/Cargo.toml @@ -27,7 +27,7 @@ pest = "^2.1" pest_derive = "^2.1" serde_yaml = "^0.8" sqlformat = "^0.1.8" -strum_macros = "^0.24" # for converting enum variants to string +strum = { version = "^0.24", features = ["std", "derive"] } # for converting enum variants to string [dependencies.clap] features = ["derive"] diff --git a/prql/src/ast.rs b/prql/src/ast.rs index d206257619dd..918203b89f84 100644 --- a/prql/src/ast.rs +++ b/prql/src/ast.rs @@ -1,7 +1,7 @@ use anyhow::{anyhow, bail, Result}; use enum_as_inner::EnumAsInner; use serde::{Deserialize, Serialize}; -use strum_macros::Display; +use strum::{self, Display, EnumString}; use crate::error::{Error, Reason, Span}; use crate::utils::*; @@ -44,10 +44,40 @@ pub enum Item { #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] pub struct Query { - // TODO: Add dialect & prql version onto Query. + pub version: Option, + #[serde(default)] + pub dialect: Dialect, pub nodes: Vec, } +#[derive(Debug, PartialEq, Clone, Serialize, Deserialize, EnumString)] +pub enum Dialect { + #[strum(serialize = "ansi")] + Ansi, + #[strum(serialize = "click_house")] + ClickHouse, + #[strum(serialize = "generic")] + Generic, + #[strum(serialize = "hive")] + Hive, + #[strum(serialize = "ms", serialize = "microsoft", serialize = "ms_sql_server")] + MsSql, + #[strum(serialize = "mysql")] + MySql, + #[strum(serialize = "postgresql", serialize = "pg")] + PostgreSql, + #[strum(serialize = "sqlite")] + SQLite, + #[strum(serialize = "snowflake")] + Snowflake, +} + +impl Default for Dialect { + fn default() -> Self { + Dialect::Generic + } +} + #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] pub struct ListItem(pub Node); @@ -273,3 +303,9 @@ impl From for anyhow::Error { anyhow!("Failed to convert {item:?}") } } + +impl Dialect { + pub fn use_top(&self) -> bool { + matches!(self, Dialect::MsSql) + } +} diff --git a/prql/src/ast_fold.rs b/prql/src/ast_fold.rs index 2b9ddb3493d3..b238c980ddd1 100644 --- a/prql/src/ast_fold.rs +++ b/prql/src/ast_fold.rs @@ -87,8 +87,9 @@ pub fn fold_item(fold: &mut T, item: Item) -> Result .map(|x| fold.fold_node(x.into_inner()).map(ListItem)) .try_collect()?, ), - Item::Query(Query { nodes: items }) => Item::Query(Query { - nodes: fold.fold_nodes(items)?, + Item::Query(query) => Item::Query(Query { + nodes: fold.fold_nodes(query.nodes)?, + ..query }), Item::InlinePipeline(InlinePipeline { value, functions }) => { Item::InlinePipeline(InlinePipeline { diff --git a/prql/src/parser.rs b/prql/src/parser.rs index c5f5083f06b7..e90ee8f22175 100644 --- a/prql/src/parser.rs +++ b/prql/src/parser.rs @@ -2,6 +2,8 @@ //! of pest pairs into a tree of AST Items. It has a small function to call into //! pest to get the parse tree / concrete syntaxt tree, and then a large //! function for turning that into PRQL AST. +use std::str::FromStr; + use anyhow::{anyhow, bail, Context, Result}; use itertools::Itertools; use pest::iterators::Pair; @@ -50,9 +52,55 @@ fn ast_of_parse_tree(pairs: Pairs) -> Result> { let span = pair.as_span(); let item = match pair.as_rule() { - Rule::query => Item::Query(Query { - nodes: ast_of_parse_tree(pair.into_inner())?, - }), + Rule::query => { + let mut parsed = ast_of_parse_tree(pair.into_inner())?; + + let has_def = parsed + .first() + .map(|p| matches!(p.item, Item::Query(_))) + .unwrap_or(false); + + Item::Query(if has_def { + let mut query = parsed.remove(0).item.into_query().unwrap(); + query.nodes = parsed; + query + } else { + Query { + dialect: Dialect::default(), + version: None, + nodes: parsed, + } + }) + } + Rule::query_def => { + let parsed = ast_of_parse_tree(pair.into_inner())?; + + let (_, [version, dialect]) = unpack_arguments(parsed, ["version", "dialect"]); + + let version = version + .map(|v| v.unwrap(|i| i.into_ident(), "string")) + .transpose()?; + + let dialect = if let Some(node) = dialect { + let span = node.span; + let dialect = node.unwrap(|i| i.into_ident(), "string")?; + Dialect::from_str(&dialect).map_err(|_| { + Error::new(Reason::NotFound { + name: dialect, + namespace: "dialect".to_string(), + }) + .with_span(span) + })? + } else { + Dialect::default() + }; + + Item::Query(Query { + nodes: vec![], + version, + dialect, + }) + } Rule::list => Item::List( ast_of_parse_tree(pair.into_inner())? .into_iter() @@ -912,6 +960,8 @@ take 20 assert_yaml_snapshot!(parse(r#"from mytable | select [a and b + c or d e and f]"#)?, @r###" --- + version: ~ + dialect: Generic nodes: - Pipeline: - From: @@ -1048,6 +1098,8 @@ take 20 assert_yaml_snapshot!(ast_of_string("func median x = (x | percentile 50)", Rule::query)?, @r###" --- Query: + version: ~ + dialect: Generic nodes: - FuncDef: name: median @@ -1125,6 +1177,8 @@ take 20 ] "#)?, @r###" --- + version: ~ + dialect: Generic nodes: - Pipeline: - From: @@ -1169,6 +1223,8 @@ select [ let result = parse(prql).unwrap(); assert_yaml_snapshot!(result, @r###" --- + version: ~ + dialect: Generic nodes: - Pipeline: - From: @@ -1204,6 +1260,8 @@ select [ let result = parse(prql).unwrap(); assert_yaml_snapshot!(result, @r###" --- + version: ~ + dialect: Generic nodes: - Pipeline: - From: diff --git a/prql/src/prql.pest b/prql/src/prql.pest index faa4bda5bd6c..eb7b28263e64 100644 --- a/prql/src/prql.pest +++ b/prql/src/prql.pest @@ -11,7 +11,9 @@ WHITESPACE = _{ " " | "\t" } // Need to exclude # in strings (and maybe confirm whether this the syntax we want) COMMENT = _{ "#" ~ (!NEWLINE ~ ANY) * } -query = { SOI ~ NEWLINE* ~ (( func_def | table | pipeline ) ~ ( NEWLINE+ | &EOI ))* ~ EOI } +query = { SOI ~ NEWLINE* ~ (query_def ~ NEWLINE+)? ~ (( func_def | table | pipeline ) ~ ( NEWLINE+ | &EOI ))* ~ EOI } + +query_def = { WHITESPACE* ~ "prql" ~ named_expr_simple* } func_def = { "func" ~ ident ~ func_def_params ~ "=" ~ expr } // TODO: we could force the named parameters to follow the positional ones here. diff --git a/prql/src/semantic/context.rs b/prql/src/semantic/context.rs index a55393e8e7c6..d95071500086 100644 --- a/prql/src/semantic/context.rs +++ b/prql/src/semantic/context.rs @@ -4,7 +4,7 @@ use serde::Deserialize; use serde::Serialize; use std::collections::HashMap; use std::collections::HashSet; -use strum_macros::Display; +use strum::Display; use crate::ast::*; use crate::error::Span; diff --git a/prql/src/semantic/materializer.rs b/prql/src/semantic/materializer.rs index ae548e0692bb..a27c15da6f0e 100644 --- a/prql/src/semantic/materializer.rs +++ b/prql/src/semantic/materializer.rs @@ -343,6 +343,8 @@ aggregate [ assert_yaml_snapshot!(ast, @r###" --- + version: ~ + dialect: Generic nodes: - FuncDef: name: count @@ -376,8 +378,10 @@ aggregate [ let diff = diff(&to_string(&ast)?, &to_string(&mat)?); assert!(!diff.is_empty()); assert_display_snapshot!(diff, @r###" - @@ -1,25 +1,8 @@ + @@ -1,27 +1,8 @@ --- + -version: ~ + -dialect: Generic -nodes: - - FuncDef: - name: count diff --git a/prql/src/snapshots/prql__parser__test__parse_query.snap b/prql/src/snapshots/prql__parser__test__parse_query.snap index 21e8766f6239..943810d30526 100644 --- a/prql/src/snapshots/prql__parser__test__parse_query.snap +++ b/prql/src/snapshots/prql__parser__test__parse_query.snap @@ -3,6 +3,8 @@ source: prql/src/parser.rs expression: "ast_of_string(r#\"\nfrom employees\nfilter country = \"USA\" # Each line transforms the previous result.\nderive [ # This adds columns / variables.\n gross_salary: salary + payroll_tax,\n gross_cost: gross_salary + benefits_cost # Variables can use other variables.\n]\nfilter gross_cost > 0\naggregate by:[title, country] [ # `by` are the columns to group by.\n average salary, # These are aggregation calcs run on each group.\n sum salary,\n average gross_salary,\n sum gross_salary,\n average gross_cost,\n sum_gross_cost: sum gross_cost,\n ct : count,\n]\nsort sum_gross_cost\nfilter ct > 200\ntake 20\n \"#.trim(),\n Rule::query)?" --- Query: + version: ~ + dialect: Generic nodes: - Pipeline: - From: diff --git a/prql/src/translator.rs b/prql/src/translator.rs index 973c69f19bcf..8f93cc23d0d9 100644 --- a/prql/src/translator.rs +++ b/prql/src/translator.rs @@ -11,11 +11,12 @@ use anyhow::{anyhow, bail, Result}; use itertools::Itertools; use sqlformat::{format, FormatOptions, QueryParams}; +use sqlparser::ast::Value; use sqlparser::ast::{ - self as sql_ast, Expr, FunctionArgExpr, Join, JoinConstraint, JoinOperator, ObjectName, - OrderByExpr, Select, SelectItem, SetExpr, TableAlias, TableFactor, TableWithJoins, Top, + self as sql_ast, Expr, Function, FunctionArg, FunctionArgExpr, Join, JoinConstraint, + JoinOperator, ObjectName, OrderByExpr, Select, SelectItem, SetExpr, TableAlias, TableFactor, + TableWithJoins, Top, }; -use sqlparser::ast::{Function, FunctionArg}; use std::collections::HashMap; use super::ast::*; @@ -83,10 +84,13 @@ pub fn translate_query(query: &Query) -> Result { let ctes = materialized; // convert each of the CTEs - let ctes: Vec<_> = ctes.into_iter().map(table_to_sql_cte).try_collect()?; + let ctes: Vec<_> = ctes + .into_iter() + .map(|t| table_to_sql_cte(t, &query.dialect)) + .try_collect()?; // convert main query - let mut main_query = sql_query_of_atomic_table(main_query)?; + let mut main_query = sql_query_of_atomic_table(main_query, &query.dialect)?; // attach CTEs if !ctes.is_empty() { @@ -145,14 +149,14 @@ fn separate_pipeline(query: &Query) -> Result<(Vec, Vec, Vec Result { +fn table_to_sql_cte(table: AtomicTable, dialect: &Dialect) -> Result { let alias = sql_ast::TableAlias { name: Item::Ident(table.name.clone()).try_into()?, columns: vec![], }; Ok(sql_ast::Cte { alias, - query: sql_query_of_atomic_table(table)?, + query: sql_query_of_atomic_table(table, dialect)?, from: None, }) } @@ -179,7 +183,7 @@ fn table_factor_of_table_ref(table_ref: &TableRef) -> TableFactor { } } -fn sql_query_of_atomic_table(table: AtomicTable) -> Result { +fn sql_query_of_atomic_table(table: AtomicTable, dialect: &Dialect) -> Result { // TODO: possibly do validation here? e.g. check there isn't more than one // `from`? Or do we rely on the caller for that? @@ -267,11 +271,11 @@ fn sql_query_of_atomic_table(table: AtomicTable) -> Result { .pipeline .iter() .filter_map(|t| match t { - Transform::Take(_) => Some(t.clone().try_into()), + Transform::Take(take) => Some(*take), _ => None, }) - .last() - .transpose()?; + .min() + .map(expr_of_i64); // Find the final sort (none of the others affect the result, and can be discarded). let order_by = table @@ -304,7 +308,11 @@ fn sql_query_of_atomic_table(table: AtomicTable) -> Result { Ok(sql_ast::Query { body: SetExpr::Select(Box::new(Select { distinct: false, - top: None, + top: if dialect.use_top() { + take.clone().map(top_of_expr) + } else { + None + }, projection: (table.select.0.into_iter()) .map(|n| n.item.try_into()) .try_collect()?, @@ -319,8 +327,7 @@ fn sql_query_of_atomic_table(table: AtomicTable) -> Result { })), order_by, with: None, - // TODO: when should this be `TOP` vs `LIMIT` (which is on the `Query` object?) - limit: take, + limit: if dialect.use_top() { None } else { take }, offset: None, fetch: None, }) @@ -455,33 +462,18 @@ impl TryFrom for SelectItem { } } -impl TryFrom for Expr { - type Error = anyhow::Error; - fn try_from(transformation: Transform) -> Result { - match transformation { - Transform::Take(take) => Ok( - // TODO: implement for number - Item::Raw(take.to_string()).try_into()?, - ), - _ => Err(anyhow!( - "Expr transformation currently only supported for Take" - )), - } - } +fn expr_of_i64(number: i64) -> Expr { + Expr::Value(Value::Number( + number.to_string(), + number.leading_zeros() < 32, + )) } -impl TryFrom for Top { - type Error = anyhow::Error; - fn try_from(transformation: Transform) -> Result { - match transformation { - Transform::Take(take) => Ok(Top { - // TODO: implement for number - quantity: Some(Item::Raw(take.to_string()).try_into()?), - with_ties: false, - percent: false, - }), - _ => Err(anyhow!("Top transformation only supported for Take")), - } +fn top_of_expr(take: Expr) -> Top { + Top { + quantity: Some(take), + with_ties: false, + percent: false, } } @@ -1162,4 +1154,44 @@ take 20 "###); Ok(()) } + + #[test] + fn test_dialects() -> Result<()> { + // Generic + let query: Query = parse( + r###" + prql dialect:generic + from Employees + select [FirstName] + take 3 + "###, + )?; + + assert_display_snapshot!((translate(&query)?), @r###" + SELECT + FirstName + FROM + Employees + LIMIT + 3 + "###); + + // SQL server + let query: Query = parse( + r###" + prql dialect:ms_sql_server + from Employees + select [FirstName] + take 3 + "###, + )?; + + assert_display_snapshot!((translate(&query)?), @r###" + SELECT + TOP (3) FirstName + FROM + Employees + "###); + Ok(()) + } } diff --git a/prql/tests/integration/compile.rs b/prql/tests/integration/compile.rs index 4d3547994ef6..2d309561ab57 100644 --- a/prql/tests/integration/compile.rs +++ b/prql/tests/integration/compile.rs @@ -6,6 +6,8 @@ fn parse_simple_string_to_ast() -> Result<()> { assert_eq!( serde_yaml::to_string(&parse("select 1")?)?, r#"--- +version: ~ +dialect: Generic nodes: - Pipeline: - Select: