Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: adds parse options for SQL parser #3193

Merged
merged 6 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/cmd/src/cli/repl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,11 @@ impl Repl {
let start = Instant::now();

let output = if let Some(query_engine) = &self.query_engine {
let stmt = QueryLanguageParser::parse_sql(&sql)
.with_context(|_| ParseSqlSnafu { sql: sql.clone() })?;

let query_ctx = QueryContext::with(self.database.catalog(), self.database.schema());

let stmt = QueryLanguageParser::parse_sql(&sql, &query_ctx)
.with_context(|_| ParseSqlSnafu { sql: sql.clone() })?;

let plan = query_engine
.planner()
.plan(stmt, query_ctx)
Expand Down
1 change: 1 addition & 0 deletions src/datanode/src/region_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,7 @@ impl RegionServerInner {
} = request;
let region_id = RegionId::from_u64(region_id);

// Build query context from gRPC header
let ctx: QueryContextRef = header
.as_ref()
.map(|h| Arc::new(h.into()))
Expand Down
10 changes: 6 additions & 4 deletions src/frontend/src/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ use servers::server::{start_server, ServerHandlers};
use session::context::QueryContextRef;
use snafu::prelude::*;
use sql::dialect::Dialect;
use sql::parser::ParserContext;
use sql::parser::{ParseOptions, ParserContext};
use sql::statements::copy::CopyTable;
use sql::statements::statement::Statement;
use sqlparser::ast::ObjectName;
Expand Down Expand Up @@ -253,7 +253,7 @@ impl FrontendInstance for Instance {
}

fn parse_stmt(sql: &str, dialect: &(dyn Dialect + Send + Sync)) -> Result<Vec<Statement>> {
ParserContext::create_with_dialect(sql, dialect).context(ParseSqlSnafu)
ParserContext::create_with_dialect(sql, dialect, ParseOptions::default()).context(ParseSqlSnafu)
}

impl Instance {
Expand Down Expand Up @@ -414,8 +414,10 @@ impl PrometheusHandler for Instance {
.check_permission(query_ctx.current_user(), PermissionReq::PromQuery)
.context(AuthSnafu)?;

let stmt = QueryLanguageParser::parse_promql(query).with_context(|_| ParsePromQLSnafu {
query: query.clone(),
let stmt = QueryLanguageParser::parse_promql(query, &query_ctx).with_context(|_| {
ParsePromQLSnafu {
query: query.clone(),
}
})?;

let output = self
Expand Down
11 changes: 6 additions & 5 deletions src/operator/src/expr_factory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -411,18 +411,19 @@ pub(crate) fn to_alter_expr(
mod tests {
use session::context::QueryContext;
use sql::dialect::GreptimeDbDialect;
use sql::parser::ParserContext;
use sql::parser::{ParseOptions, ParserContext};
use sql::statements::statement::Statement;

use super::*;

#[test]
fn test_create_to_expr() {
let sql = "CREATE TABLE monitor (host STRING,ts TIMESTAMP,TIME INDEX (ts),PRIMARY KEY(host)) ENGINE=mito WITH(regions=1, ttl='3days', write_buffer_size='1024KB');";
let stmt = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {})
.unwrap()
.pop()
.unwrap();
let stmt =
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
.unwrap()
.pop()
.unwrap();

let Statement::CreateTable(create_table) = stmt else {
unreachable!()
Expand Down
9 changes: 7 additions & 2 deletions src/operator/src/statement/ddl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,7 @@ fn merge_options(mut table_opts: TableOptions, schema_opts: SchemaNameValue) ->
mod test {
use session::context::QueryContext;
use sql::dialect::GreptimeDbDialect;
use sql::parser::ParserContext;
use sql::parser::{ParseOptions, ParserContext};
use sql::statements::statement::Statement;

use super::*;
Expand Down Expand Up @@ -698,7 +698,12 @@ ENGINE=mito",
),
];
for (sql, expected) in cases {
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap();
let result = ParserContext::create_with_dialect(
sql,
&GreptimeDbDialect {},
ParseOptions::default(),
)
.unwrap();
match &result[0] {
Statement::CreateTable(c) => {
let expr = expr_factory::create_to_expr(c, QueryContext::arc()).unwrap();
Expand Down
6 changes: 3 additions & 3 deletions src/operator/src/statement/tql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@ impl StatementExecutor {
step: eval.step,
query: eval.query,
};
QueryLanguageParser::parse_promql(&promql).context(ParseQuerySnafu)?
QueryLanguageParser::parse_promql(&promql, &query_ctx).context(ParseQuerySnafu)?
}
Tql::Explain(explain) => {
let promql = PromQuery {
query: explain.query,
..PromQuery::default()
};
let params = HashMap::from([("name".to_string(), EXPLAIN_NODE_NAME.to_string())]);
QueryLanguageParser::parse_promql(&promql)
QueryLanguageParser::parse_promql(&promql, &query_ctx)
.context(ParseQuerySnafu)?
.post_process(params)
.unwrap()
Expand All @@ -56,7 +56,7 @@ impl StatementExecutor {
query: tql_analyze.query,
};
let params = HashMap::from([("name".to_string(), ANALYZE_NODE_NAME.to_string())]);
QueryLanguageParser::parse_promql(&promql)
QueryLanguageParser::parse_promql(&promql, &query_ctx)
.context(ParseQuerySnafu)?
.post_process(params)
.unwrap()
Expand Down
9 changes: 5 additions & 4 deletions src/query/src/datafusion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ mod tests {
use datatypes::vectors::{Helper, StringVectorBuilder, UInt32Vector, UInt64Vector, VectorRef};
use session::context::QueryContext;
use sql::dialect::GreptimeDbDialect;
use sql::parser::ParserContext;
use sql::parser::{ParseOptions, ParserContext};
use sql::statements::show::{ShowKind, ShowTables};
use sql::statements::statement::Statement;
use table::table::numbers::{NumbersTable, NUMBERS_TABLE_NAME};
Expand Down Expand Up @@ -547,7 +547,7 @@ mod tests {
let engine = create_test_engine().await;
let sql = "select sum(number) from numbers limit 20";

let stmt = QueryLanguageParser::parse_sql(sql).unwrap();
let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
let plan = engine
.planner()
.plan(stmt, QueryContext::arc())
Expand All @@ -569,7 +569,7 @@ mod tests {
let engine = create_test_engine().await;
let sql = "select sum(number) from numbers limit 20";

let stmt = QueryLanguageParser::parse_sql(sql).unwrap();
let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
let plan = engine
.planner()
.plan(stmt, QueryContext::arc())
Expand Down Expand Up @@ -643,7 +643,7 @@ mod tests {
let engine = create_test_engine().await;
let sql = "select sum(number) from numbers limit 20";

let stmt = QueryLanguageParser::parse_sql(sql).unwrap();
let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();

let plan = engine
.planner()
Expand Down Expand Up @@ -709,6 +709,7 @@ mod tests {
let statement = ParserContext::create_with_dialect(
"SHOW TABLES WHERE \"Tables\"='monitor'",
&GreptimeDbDialect {},
ParseOptions::default(),
)
.unwrap()[0]
.clone();
Expand Down
28 changes: 18 additions & 10 deletions src/query/src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ use common_error::status_code::StatusCode;
use promql_parser::parser::ast::{Extension as NodeExtension, ExtensionExpr};
use promql_parser::parser::Expr::Extension;
use promql_parser::parser::{EvalStmt, Expr, ValueType};
use session::context::QueryContextRef;
use snafu::{OptionExt, ResultExt};
use sql::dialect::GreptimeDbDialect;
use sql::parser::ParserContext;
use sql::parser::{ParseOptions, ParserContext};
use sql::statements::statement::Statement;

use crate::error::{
Expand Down Expand Up @@ -101,16 +102,19 @@ impl Default for PromQuery {
}
}

/// Query language parser, supports parsing SQL and PromQL
pub struct QueryLanguageParser {}

impl QueryLanguageParser {
pub fn parse_sql(sql: &str) -> Result<QueryStatement> {
/// Try to parse SQL with GreptimeDB dialect, return the statement when success.
pub fn parse_sql(sql: &str, _query_ctx: &QueryContextRef) -> Result<QueryStatement> {
let _timer = METRIC_PARSE_SQL_ELAPSED.start_timer();
let mut statement = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {})
.map_err(BoxedError::new)
.context(QueryParseSnafu {
query: sql.to_string(),
})?;
let mut statement =
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
.map_err(BoxedError::new)
.context(QueryParseSnafu {
query: sql.to_string(),
})?;
killme2008 marked this conversation as resolved.
Show resolved Hide resolved
if statement.len() != 1 {
MultipleStatementsSnafu {
query: sql.to_string(),
Expand All @@ -121,7 +125,8 @@ impl QueryLanguageParser {
}
}

pub fn parse_promql(query: &PromQuery) -> Result<QueryStatement> {
/// Try to parse PromQL, return the statement when success.
pub fn parse_promql(query: &PromQuery, _query_ctx: &QueryContextRef) -> Result<QueryStatement> {
let _timer = METRIC_PARSE_PROMQL_ELAPSED.start_timer();

let expr = promql_parser::parser::parse(&query.query)
Expand Down Expand Up @@ -165,6 +170,7 @@ impl QueryLanguageParser {
}

fn parse_promql_timestamp(timestamp: &str) -> Result<SystemTime> {
// FIXME(dennis): aware of timezone
// try rfc3339 format
let rfc3339_result = DateTime::parse_from_rfc3339(timestamp)
.context(ParseTimestampSnafu { raw: timestamp })
Expand Down Expand Up @@ -236,13 +242,15 @@ define_node_ast_extension!(Explain, ExplainExpr, Expr, EXPLAIN_NODE_NAME);

#[cfg(test)]
mod test {
use session::context::QueryContext;

use super::*;

// Detailed logic tests are covered in the parser crate.
#[test]
fn parse_sql_simple() {
let sql = "select * from t1";
let stmt = QueryLanguageParser::parse_sql(sql).unwrap();
let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
let expected = String::from("Sql(Query(Query { \
inner: Query { \
with: None, body: Select(Select { \
Expand Down Expand Up @@ -360,7 +368,7 @@ mod test {
})",
);

let result = QueryLanguageParser::parse_promql(&promql).unwrap();
let result = QueryLanguageParser::parse_promql(&promql, &QueryContext::arc()).unwrap();
assert_eq!(format!("{result:?}"), expected);
}
}
2 changes: 1 addition & 1 deletion src/query/src/range_select/plan_rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ mod test {
}

async fn do_query(sql: &str) -> Result<crate::plan::LogicalPlan> {
let stmt = QueryLanguageParser::parse_sql(sql).unwrap();
let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
let engine = create_test_engine().await;
engine.planner().plan(stmt, QueryContext::arc()).await
}
Expand Down
2 changes: 1 addition & 1 deletion src/query/src/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,7 @@ mod test {
variable: ObjectName(vec![Ident::new(variable)]),
};
let ctx = QueryContextBuilder::default()
.timezone(Timezone::from_tz_string(tz).unwrap())
.timezone(Arc::new(Timezone::from_tz_string(tz).unwrap()))
.build();
match show_variable(stmt, ctx) {
Ok(Output::RecordBatches(record)) => {
Expand Down
5 changes: 3 additions & 2 deletions src/query/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,14 @@ mod function;
mod pow;

async fn exec_selection(engine: QueryEngineRef, sql: &str) -> Vec<RecordBatch> {
let stmt = QueryLanguageParser::parse_sql(sql).unwrap();
let query_ctx = QueryContext::arc();
let stmt = QueryLanguageParser::parse_sql(sql, &query_ctx).unwrap();
let plan = engine
.planner()
.plan(stmt, QueryContext::arc())
.await
.unwrap();
let Output::Stream(stream) = engine.execute(plan, QueryContext::arc()).await.unwrap() else {
let Output::Stream(stream) = engine.execute(plan, query_ctx).await.unwrap() else {
unreachable!()
};
util::collect(stream).await.unwrap()
Expand Down
10 changes: 8 additions & 2 deletions src/query/src/tests/query_engine_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,14 +131,20 @@ async fn test_query_validate() -> Result<()> {
let factory = QueryEngineFactory::new_with_plugins(catalog_list, None, None, false, plugins);
let engine = factory.query_engine();

let stmt = QueryLanguageParser::parse_sql("select number from public.numbers").unwrap();
let stmt =
QueryLanguageParser::parse_sql("select number from public.numbers", &QueryContext::arc())
.unwrap();
assert!(engine
.planner()
.plan(stmt, QueryContext::arc())
.await
.is_ok());

let stmt = QueryLanguageParser::parse_sql("select number from wrongschema.numbers").unwrap();
let stmt = QueryLanguageParser::parse_sql(
"select number from wrongschema.numbers",
&QueryContext::arc(),
)
.unwrap();
assert!(engine
.planner()
.plan(stmt, QueryContext::arc())
Expand Down
15 changes: 13 additions & 2 deletions src/script/src/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use std::collections::HashMap;
use async_trait::async_trait;
use common_error::ext::ErrorExt;
use common_query::Output;
use session::context::{QueryContext, QueryContextRef};

#[async_trait]
pub trait Script {
Expand Down Expand Up @@ -57,8 +58,18 @@ pub trait ScriptEngine {
}

/// Evaluate script context
#[derive(Debug, Default)]
pub struct EvalContext {}
#[derive(Debug)]
pub struct EvalContext {
pub query_ctx: QueryContextRef,
}

impl Default for EvalContext {
fn default() -> Self {
Self {
query_ctx: QueryContext::arc(),
}
}
}

/// Compile script context
#[derive(Debug, Default)]
Expand Down