Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion prql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pest = "^2.1"
pest_derive = "^2.1"
serde_yaml = "^0.8"
sqlformat = "^0.1.8"
strum_macros = "^0.24" # for converting enum variants to string
strum = { version = "^0.24", features = ["std", "derive"] } # for converting enum variants to string

[dependencies.clap]
features = ["derive"]
Expand Down
40 changes: 38 additions & 2 deletions prql/src/ast.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use anyhow::{anyhow, bail, Result};
use enum_as_inner::EnumAsInner;
use serde::{Deserialize, Serialize};
use strum_macros::Display;
use strum::{self, Display, EnumString};

use crate::error::{Error, Reason, Span};
use crate::utils::*;
Expand Down Expand Up @@ -44,10 +44,40 @@ pub enum Item {

#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub struct Query {
    /// PRQL language version declared by the query header, if any.
    pub version: Option<String>,
    /// Target SQL dialect. `#[serde(default)]` means a missing field
    /// deserializes to the `Default` impl's value (`Dialect::Generic`).
    #[serde(default)]
    pub dialect: Dialect,
    /// The top-level AST nodes making up the query body.
    pub nodes: Vec<Node>,
}

/// Target SQL dialect for a query.
///
/// `EnumString` (from strum) generates `FromStr`, so a dialect can be parsed
/// from its textual name; each `#[strum(serialize = ...)]` attribute lists an
/// accepted spelling, and some variants accept several aliases.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize, EnumString)]
pub enum Dialect {
    #[strum(serialize = "ansi")]
    Ansi,
    #[strum(serialize = "click_house")]
    ClickHouse,
    #[strum(serialize = "generic")]
    Generic,
    #[strum(serialize = "hive")]
    Hive,
    // Accepts "ms", "microsoft", or "ms_sql_server".
    #[strum(serialize = "ms", serialize = "microsoft", serialize = "ms_sql_server")]
    MsSql,
    #[strum(serialize = "mysql")]
    MySql,
    // Accepts "postgresql" or the short form "pg".
    #[strum(serialize = "postgresql", serialize = "pg")]
    PostgreSql,
    #[strum(serialize = "sqlite")]
    SQLite,
    #[strum(serialize = "snowflake")]
    Snowflake,
}

impl Default for Dialect {
fn default() -> Self {
Dialect::Generic
}
}

/// A single element of a list literal: a newtype wrapper around the node it holds.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub struct ListItem(pub Node);

Expand Down Expand Up @@ -273,3 +303,9 @@ impl From<Item> for anyhow::Error {
anyhow!("Failed to convert {item:?}")
}
}

impl Dialect {
pub fn use_top(&self) -> bool {
matches!(self, Dialect::MsSql)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possibly in the future we want to have each Dialect be a Struct that we implement a Trait for — for the moment this is excellent.

}
}
5 changes: 3 additions & 2 deletions prql/src/ast_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,9 @@ pub fn fold_item<T: ?Sized + AstFold>(fold: &mut T, item: Item) -> Result<Item>
.map(|x| fold.fold_node(x.into_inner()).map(ListItem))
.try_collect()?,
),
Item::Query(Query { nodes: items }) => Item::Query(Query {
nodes: fold.fold_nodes(items)?,
Item::Query(query) => Item::Query(Query {
nodes: fold.fold_nodes(query.nodes)?,
..query
}),
Item::InlinePipeline(InlinePipeline { value, functions }) => {
Item::InlinePipeline(InlinePipeline {
Expand Down
64 changes: 61 additions & 3 deletions prql/src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
//! of pest pairs into a tree of AST Items. It has a small function to call into
//! pest to get the parse tree / concrete syntax tree, and then a large
//! function for turning that into PRQL AST.
use std::str::FromStr;

use anyhow::{anyhow, bail, Context, Result};
use itertools::Itertools;
use pest::iterators::Pair;
Expand Down Expand Up @@ -50,9 +52,55 @@ fn ast_of_parse_tree(pairs: Pairs<Rule>) -> Result<Vec<Node>> {
let span = pair.as_span();

let item = match pair.as_rule() {
Rule::query => Item::Query(Query {
nodes: ast_of_parse_tree(pair.into_inner())?,
}),
Rule::query => {
let mut parsed = ast_of_parse_tree(pair.into_inner())?;

let has_def = parsed
.first()
.map(|p| matches!(p.item, Item::Query(_)))
.unwrap_or(false);

Item::Query(if has_def {
let mut query = parsed.remove(0).item.into_query().unwrap();
query.nodes = parsed;
query
} else {
Query {
dialect: Dialect::default(),
version: None,
nodes: parsed,
}
})
}
Rule::query_def => {
let parsed = ast_of_parse_tree(pair.into_inner())?;

let (_, [version, dialect]) = unpack_arguments(parsed, ["version", "dialect"]);

let version = version
.map(|v| v.unwrap(|i| i.into_ident(), "string"))
.transpose()?;

let dialect = if let Some(node) = dialect {
let span = node.span;
let dialect = node.unwrap(|i| i.into_ident(), "string")?;
Dialect::from_str(&dialect).map_err(|_| {
Error::new(Reason::NotFound {
name: dialect,
namespace: "dialect".to_string(),
})
.with_span(span)
})?
} else {
Dialect::default()
};

Item::Query(Query {
nodes: vec![],
version,
dialect,
})
}
Rule::list => Item::List(
ast_of_parse_tree(pair.into_inner())?
.into_iter()
Expand Down Expand Up @@ -912,6 +960,8 @@ take 20

assert_yaml_snapshot!(parse(r#"from mytable | select [a and b + c or d e and f]"#)?, @r###"
---
version: ~
dialect: Generic
nodes:
- Pipeline:
- From:
Expand Down Expand Up @@ -1048,6 +1098,8 @@ take 20
assert_yaml_snapshot!(ast_of_string("func median x = (x | percentile 50)", Rule::query)?, @r###"
---
Query:
version: ~
dialect: Generic
nodes:
- FuncDef:
name: median
Expand Down Expand Up @@ -1125,6 +1177,8 @@ take 20
]
"#)?, @r###"
---
version: ~
dialect: Generic
nodes:
- Pipeline:
- From:
Expand Down Expand Up @@ -1169,6 +1223,8 @@ select [
let result = parse(prql).unwrap();
assert_yaml_snapshot!(result, @r###"
---
version: ~
dialect: Generic
nodes:
- Pipeline:
- From:
Expand Down Expand Up @@ -1204,6 +1260,8 @@ select [
let result = parse(prql).unwrap();
assert_yaml_snapshot!(result, @r###"
---
version: ~
dialect: Generic
nodes:
- Pipeline:
- From:
Expand Down
4 changes: 3 additions & 1 deletion prql/src/prql.pest
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ WHITESPACE = _{ " " | "\t" }
// Need to exclude # in strings (and maybe confirm whether this is the syntax we want)
COMMENT = _{ "#" ~ (!NEWLINE ~ ANY) * }

query = { SOI ~ NEWLINE* ~ (( func_def | table | pipeline ) ~ ( NEWLINE+ | &EOI ))* ~ EOI }
query = { SOI ~ NEWLINE* ~ (query_def ~ NEWLINE+)? ~ (( func_def | table | pipeline ) ~ ( NEWLINE+ | &EOI ))* ~ EOI }

query_def = { WHITESPACE* ~ "prql" ~ named_expr_simple* }

func_def = { "func" ~ ident ~ func_def_params ~ "=" ~ expr }
// TODO: we could force the named parameters to follow the positional ones here.
Expand Down
2 changes: 1 addition & 1 deletion prql/src/semantic/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use serde::Deserialize;
use serde::Serialize;
use std::collections::HashMap;
use std::collections::HashSet;
use strum_macros::Display;
use strum::Display;

use crate::ast::*;
use crate::error::Span;
Expand Down
6 changes: 5 additions & 1 deletion prql/src/semantic/materializer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,8 @@ aggregate [

assert_yaml_snapshot!(ast, @r###"
---
version: ~
dialect: Generic
nodes:
- FuncDef:
name: count
Expand Down Expand Up @@ -376,8 +378,10 @@ aggregate [
let diff = diff(&to_string(&ast)?, &to_string(&mat)?);
assert!(!diff.is_empty());
assert_display_snapshot!(diff, @r###"
@@ -1,25 +1,8 @@
@@ -1,27 +1,8 @@
---
-version: ~
-dialect: Generic
-nodes:
- - FuncDef:
- name: count
Expand Down
2 changes: 2 additions & 0 deletions prql/src/snapshots/prql__parser__test__parse_query.snap
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ source: prql/src/parser.rs
expression: "ast_of_string(r#\"\nfrom employees\nfilter country = \"USA\" # Each line transforms the previous result.\nderive [ # This adds columns / variables.\n gross_salary: salary + payroll_tax,\n gross_cost: gross_salary + benefits_cost # Variables can use other variables.\n]\nfilter gross_cost > 0\naggregate by:[title, country] [ # `by` are the columns to group by.\n average salary, # These are aggregation calcs run on each group.\n sum salary,\n average gross_salary,\n sum gross_salary,\n average gross_cost,\n sum_gross_cost: sum gross_cost,\n ct : count,\n]\nsort sum_gross_cost\nfilter ct > 200\ntake 20\n \"#.trim(),\n Rule::query)?"
---
Query:
version: ~
dialect: Generic
nodes:
- Pipeline:
- From:
Expand Down
Loading