Skip to content

Commit

Permalink
feat: adds parse options for SQL parser (#3193)
Browse files Browse the repository at this point in the history
* feat: adds parse options for parser and timezone to scan request

* chore: remove timezone in ScanRequest

* feat: remove timezone in parse options and adds type checking to parititon columns

* fix: comment

* chore: apply suggestions

Co-authored-by: Yingwen <realevenyag@gmail.com>

* fix: format

---------

Co-authored-by: Yingwen <realevenyag@gmail.com>
  • Loading branch information
killme2008 and evenyag committed Jan 19, 2024
1 parent 632edd0 commit 5e89472
Show file tree
Hide file tree
Showing 49 changed files with 617 additions and 248 deletions.
6 changes: 3 additions & 3 deletions src/cmd/src/cli/repl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,11 @@ impl Repl {
let start = Instant::now();

let output = if let Some(query_engine) = &self.query_engine {
let stmt = QueryLanguageParser::parse_sql(&sql)
.with_context(|_| ParseSqlSnafu { sql: sql.clone() })?;

let query_ctx = QueryContext::with(self.database.catalog(), self.database.schema());

let stmt = QueryLanguageParser::parse_sql(&sql, &query_ctx)
.with_context(|_| ParseSqlSnafu { sql: sql.clone() })?;

let plan = query_engine
.planner()
.plan(stmt, query_ctx)
Expand Down
1 change: 1 addition & 0 deletions src/datanode/src/region_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,7 @@ impl RegionServerInner {
} = request;
let region_id = RegionId::from_u64(region_id);

// Build query context from gRPC header
let ctx: QueryContextRef = header
.as_ref()
.map(|h| Arc::new(h.into()))
Expand Down
10 changes: 6 additions & 4 deletions src/frontend/src/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ use servers::server::{start_server, ServerHandlers};
use session::context::QueryContextRef;
use snafu::prelude::*;
use sql::dialect::Dialect;
use sql::parser::ParserContext;
use sql::parser::{ParseOptions, ParserContext};
use sql::statements::copy::CopyTable;
use sql::statements::statement::Statement;
use sqlparser::ast::ObjectName;
Expand Down Expand Up @@ -253,7 +253,7 @@ impl FrontendInstance for Instance {
}

fn parse_stmt(sql: &str, dialect: &(dyn Dialect + Send + Sync)) -> Result<Vec<Statement>> {
ParserContext::create_with_dialect(sql, dialect).context(ParseSqlSnafu)
ParserContext::create_with_dialect(sql, dialect, ParseOptions::default()).context(ParseSqlSnafu)
}

impl Instance {
Expand Down Expand Up @@ -414,8 +414,10 @@ impl PrometheusHandler for Instance {
.check_permission(query_ctx.current_user(), PermissionReq::PromQuery)
.context(AuthSnafu)?;

let stmt = QueryLanguageParser::parse_promql(query).with_context(|_| ParsePromQLSnafu {
query: query.clone(),
let stmt = QueryLanguageParser::parse_promql(query, &query_ctx).with_context(|_| {
ParsePromQLSnafu {
query: query.clone(),
}
})?;

let output = self
Expand Down
11 changes: 6 additions & 5 deletions src/operator/src/expr_factory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -411,18 +411,19 @@ pub(crate) fn to_alter_expr(
mod tests {
use session::context::QueryContext;
use sql::dialect::GreptimeDbDialect;
use sql::parser::ParserContext;
use sql::parser::{ParseOptions, ParserContext};
use sql::statements::statement::Statement;

use super::*;

#[test]
fn test_create_to_expr() {
let sql = "CREATE TABLE monitor (host STRING,ts TIMESTAMP,TIME INDEX (ts),PRIMARY KEY(host)) ENGINE=mito WITH(regions=1, ttl='3days', write_buffer_size='1024KB');";
let stmt = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {})
.unwrap()
.pop()
.unwrap();
let stmt =
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
.unwrap()
.pop()
.unwrap();

let Statement::CreateTable(create_table) = stmt else {
unreachable!()
Expand Down
9 changes: 7 additions & 2 deletions src/operator/src/statement/ddl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,7 @@ fn merge_options(mut table_opts: TableOptions, schema_opts: SchemaNameValue) ->
mod test {
use session::context::QueryContext;
use sql::dialect::GreptimeDbDialect;
use sql::parser::ParserContext;
use sql::parser::{ParseOptions, ParserContext};
use sql::statements::statement::Statement;

use super::*;
Expand Down Expand Up @@ -698,7 +698,12 @@ ENGINE=mito",
),
];
for (sql, expected) in cases {
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap();
let result = ParserContext::create_with_dialect(
sql,
&GreptimeDbDialect {},
ParseOptions::default(),
)
.unwrap();
match &result[0] {
Statement::CreateTable(c) => {
let expr = expr_factory::create_to_expr(c, QueryContext::arc()).unwrap();
Expand Down
6 changes: 3 additions & 3 deletions src/operator/src/statement/tql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@ impl StatementExecutor {
step: eval.step,
query: eval.query,
};
QueryLanguageParser::parse_promql(&promql).context(ParseQuerySnafu)?
QueryLanguageParser::parse_promql(&promql, &query_ctx).context(ParseQuerySnafu)?
}
Tql::Explain(explain) => {
let promql = PromQuery {
query: explain.query,
..PromQuery::default()
};
let params = HashMap::from([("name".to_string(), EXPLAIN_NODE_NAME.to_string())]);
QueryLanguageParser::parse_promql(&promql)
QueryLanguageParser::parse_promql(&promql, &query_ctx)
.context(ParseQuerySnafu)?
.post_process(params)
.unwrap()
Expand All @@ -56,7 +56,7 @@ impl StatementExecutor {
query: tql_analyze.query,
};
let params = HashMap::from([("name".to_string(), ANALYZE_NODE_NAME.to_string())]);
QueryLanguageParser::parse_promql(&promql)
QueryLanguageParser::parse_promql(&promql, &query_ctx)
.context(ParseQuerySnafu)?
.post_process(params)
.unwrap()
Expand Down
9 changes: 5 additions & 4 deletions src/query/src/datafusion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ mod tests {
use datatypes::vectors::{Helper, StringVectorBuilder, UInt32Vector, UInt64Vector, VectorRef};
use session::context::QueryContext;
use sql::dialect::GreptimeDbDialect;
use sql::parser::ParserContext;
use sql::parser::{ParseOptions, ParserContext};
use sql::statements::show::{ShowKind, ShowTables};
use sql::statements::statement::Statement;
use table::table::numbers::{NumbersTable, NUMBERS_TABLE_NAME};
Expand Down Expand Up @@ -547,7 +547,7 @@ mod tests {
let engine = create_test_engine().await;
let sql = "select sum(number) from numbers limit 20";

let stmt = QueryLanguageParser::parse_sql(sql).unwrap();
let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
let plan = engine
.planner()
.plan(stmt, QueryContext::arc())
Expand All @@ -569,7 +569,7 @@ mod tests {
let engine = create_test_engine().await;
let sql = "select sum(number) from numbers limit 20";

let stmt = QueryLanguageParser::parse_sql(sql).unwrap();
let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
let plan = engine
.planner()
.plan(stmt, QueryContext::arc())
Expand Down Expand Up @@ -643,7 +643,7 @@ mod tests {
let engine = create_test_engine().await;
let sql = "select sum(number) from numbers limit 20";

let stmt = QueryLanguageParser::parse_sql(sql).unwrap();
let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();

let plan = engine
.planner()
Expand Down Expand Up @@ -709,6 +709,7 @@ mod tests {
let statement = ParserContext::create_with_dialect(
"SHOW TABLES WHERE \"Tables\"='monitor'",
&GreptimeDbDialect {},
ParseOptions::default(),
)
.unwrap()[0]
.clone();
Expand Down
26 changes: 16 additions & 10 deletions src/query/src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ use common_error::status_code::StatusCode;
use promql_parser::parser::ast::{Extension as NodeExtension, ExtensionExpr};
use promql_parser::parser::Expr::Extension;
use promql_parser::parser::{EvalStmt, Expr, ValueType};
use session::context::QueryContextRef;
use snafu::{OptionExt, ResultExt};
use sql::dialect::GreptimeDbDialect;
use sql::parser::ParserContext;
use sql::parser::{ParseOptions, ParserContext};
use sql::statements::statement::Statement;

use crate::error::{
Expand Down Expand Up @@ -101,16 +102,17 @@ impl Default for PromQuery {
}
}

/// Query language parser, supports parsing SQL and PromQL
pub struct QueryLanguageParser {}

impl QueryLanguageParser {
pub fn parse_sql(sql: &str) -> Result<QueryStatement> {
/// Try to parse SQL with GreptimeDB dialect, return the statement when success.
pub fn parse_sql(sql: &str, _query_ctx: &QueryContextRef) -> Result<QueryStatement> {
let _timer = METRIC_PARSE_SQL_ELAPSED.start_timer();
let mut statement = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {})
.map_err(BoxedError::new)
.context(QueryParseSnafu {
query: sql.to_string(),
})?;
let mut statement =
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
.map_err(BoxedError::new)
.context(QueryParseSnafu { query: sql })?;
if statement.len() != 1 {
MultipleStatementsSnafu {
query: sql.to_string(),
Expand All @@ -121,7 +123,8 @@ impl QueryLanguageParser {
}
}

pub fn parse_promql(query: &PromQuery) -> Result<QueryStatement> {
/// Try to parse PromQL, return the statement when success.
pub fn parse_promql(query: &PromQuery, _query_ctx: &QueryContextRef) -> Result<QueryStatement> {
let _timer = METRIC_PARSE_PROMQL_ELAPSED.start_timer();

let expr = promql_parser::parser::parse(&query.query)
Expand Down Expand Up @@ -165,6 +168,7 @@ impl QueryLanguageParser {
}

fn parse_promql_timestamp(timestamp: &str) -> Result<SystemTime> {
// FIXME(dennis): aware of timezone
// try rfc3339 format
let rfc3339_result = DateTime::parse_from_rfc3339(timestamp)
.context(ParseTimestampSnafu { raw: timestamp })
Expand Down Expand Up @@ -236,13 +240,15 @@ define_node_ast_extension!(Explain, ExplainExpr, Expr, EXPLAIN_NODE_NAME);

#[cfg(test)]
mod test {
use session::context::QueryContext;

use super::*;

// Detailed logic tests are covered in the parser crate.
#[test]
fn parse_sql_simple() {
let sql = "select * from t1";
let stmt = QueryLanguageParser::parse_sql(sql).unwrap();
let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
let expected = String::from("Sql(Query(Query { \
inner: Query { \
with: None, body: Select(Select { \
Expand Down Expand Up @@ -360,7 +366,7 @@ mod test {
})",
);

let result = QueryLanguageParser::parse_promql(&promql).unwrap();
let result = QueryLanguageParser::parse_promql(&promql, &QueryContext::arc()).unwrap();
assert_eq!(format!("{result:?}"), expected);
}
}
2 changes: 1 addition & 1 deletion src/query/src/range_select/plan_rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ mod test {
}

async fn do_query(sql: &str) -> Result<crate::plan::LogicalPlan> {
let stmt = QueryLanguageParser::parse_sql(sql).unwrap();
let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
let engine = create_test_engine().await;
engine.planner().plan(stmt, QueryContext::arc()).await
}
Expand Down
2 changes: 1 addition & 1 deletion src/query/src/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,7 @@ mod test {
variable: ObjectName(vec![Ident::new(variable)]),
};
let ctx = QueryContextBuilder::default()
.timezone(Timezone::from_tz_string(tz).unwrap())
.timezone(Arc::new(Timezone::from_tz_string(tz).unwrap()))
.build();
match show_variable(stmt, ctx) {
Ok(Output::RecordBatches(record)) => {
Expand Down
5 changes: 3 additions & 2 deletions src/query/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,14 @@ mod function;
mod pow;

async fn exec_selection(engine: QueryEngineRef, sql: &str) -> Vec<RecordBatch> {
let stmt = QueryLanguageParser::parse_sql(sql).unwrap();
let query_ctx = QueryContext::arc();
let stmt = QueryLanguageParser::parse_sql(sql, &query_ctx).unwrap();
let plan = engine
.planner()
.plan(stmt, QueryContext::arc())
.await
.unwrap();
let Output::Stream(stream) = engine.execute(plan, QueryContext::arc()).await.unwrap() else {
let Output::Stream(stream) = engine.execute(plan, query_ctx).await.unwrap() else {
unreachable!()
};
util::collect(stream).await.unwrap()
Expand Down
10 changes: 8 additions & 2 deletions src/query/src/tests/query_engine_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,14 +131,20 @@ async fn test_query_validate() -> Result<()> {
let factory = QueryEngineFactory::new_with_plugins(catalog_list, None, None, false, plugins);
let engine = factory.query_engine();

let stmt = QueryLanguageParser::parse_sql("select number from public.numbers").unwrap();
let stmt =
QueryLanguageParser::parse_sql("select number from public.numbers", &QueryContext::arc())
.unwrap();
assert!(engine
.planner()
.plan(stmt, QueryContext::arc())
.await
.is_ok());

let stmt = QueryLanguageParser::parse_sql("select number from wrongschema.numbers").unwrap();
let stmt = QueryLanguageParser::parse_sql(
"select number from wrongschema.numbers",
&QueryContext::arc(),
)
.unwrap();
assert!(engine
.planner()
.plan(stmt, QueryContext::arc())
Expand Down
15 changes: 13 additions & 2 deletions src/script/src/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use std::collections::HashMap;
use async_trait::async_trait;
use common_error::ext::ErrorExt;
use common_query::Output;
use session::context::{QueryContext, QueryContextRef};

#[async_trait]
pub trait Script {
Expand Down Expand Up @@ -57,8 +58,18 @@ pub trait ScriptEngine {
}

/// Evaluate script context
#[derive(Debug, Default)]
pub struct EvalContext {}
#[derive(Debug)]
pub struct EvalContext {
pub query_ctx: QueryContextRef,
}

impl Default for EvalContext {
fn default() -> Self {
Self {
query_ctx: QueryContext::arc(),
}
}
}

/// Compile script context
#[derive(Debug, Default)]
Expand Down

0 comments on commit 5e89472

Please sign in to comment.