diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..b773e8e91 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,10 @@ +repos: + - repo: https://github.com/doublify/pre-commit-rust + rev: v1.0 + hooks: + - id: cargo-check + args: [ "--workspace" ] + - id: fmt + args: [ "--", "--check" ] + - id: clippy + args: [ "--", "-D", "warnings" ] diff --git a/Cargo.toml b/Cargo.toml index a8bd8b86b..c25256644 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,7 +35,7 @@ datafusion = { version = "43" } snafu = { version = "0.8.5", features = ["futures"] } [patch.crates-io] -datafusion = { git="https://github.com/Embucket/datafusion.git", rev = "bc5a978a4102391b2b7427dfdf94dd4e2667be49" } +datafusion = { git="https://github.com/Embucket/datafusion.git", rev = "cc37c5920463b0cb0b224fc7f567fd4ae0368ffe" } [workspace.lints.clippy] diff --git a/crates/control_plane/Cargo.toml b/crates/control_plane/Cargo.toml index baf7af042..bd61d41da 100644 --- a/crates/control_plane/Cargo.toml +++ b/crates/control_plane/Cargo.toml @@ -18,13 +18,13 @@ flatbuffers = { version = "24.3.25" } #iceberg-rest-catalog = { git = "https://github.com/JanKaul/iceberg-rust.git", rev = "836f11f" } #datafusion_iceberg = { git = "https://github.com/JanKaul/iceberg-rust.git", rev = "836f11f" } -datafusion = { git="https://github.com/Embucket/datafusion.git", rev = "bc5a978a4102391b2b7427dfdf94dd4e2667be49" } -datafusion-common = { git="https://github.com/Embucket/datafusion.git", rev = "bc5a978a4102391b2b7427dfdf94dd4e2667be49" } -datafusion-expr = { git="https://github.com/Embucket/datafusion.git", rev = "bc5a978a4102391b2b7427dfdf94dd4e2667be49" } +datafusion = { git="https://github.com/Embucket/datafusion.git", rev = "cc37c5920463b0cb0b224fc7f567fd4ae0368ffe" } +datafusion-common = { git="https://github.com/Embucket/datafusion.git", rev = "cc37c5920463b0cb0b224fc7f567fd4ae0368ffe" } +datafusion-expr = { git="https://github.com/Embucket/datafusion.git", rev = "cc37c5920463b0cb0b224fc7f567fd4ae0368ffe" } -iceberg-rust = { git = "https://github.com/Embucket/iceberg-rust.git", rev = "5d4211521085722de35b14c444da087c52309771" } -iceberg-rest-catalog = { git = "https://github.com/Embucket/iceberg-rust.git", rev = "5d4211521085722de35b14c444da087c52309771" } -datafusion_iceberg = { git = "https://github.com/Embucket/iceberg-rust.git", rev = "5d4211521085722de35b14c444da087c52309771" } +iceberg-rust = { git = "https://github.com/Embucket/iceberg-rust.git", rev = "68577441273eda894f1eb6d87b1c3e87dee0fdf6" } +iceberg-rest-catalog = { git = "https://github.com/Embucket/iceberg-rust.git", rev = "68577441273eda894f1eb6d87b1c3e87dee0fdf6" } +datafusion_iceberg = { git = "https://github.com/Embucket/iceberg-rust.git", rev = "68577441273eda894f1eb6d87b1c3e87dee0fdf6" } arrow = { version = "53" } arrow-json = { version = "53" } diff --git a/crates/nexus/src/http/dbt/error.rs b/crates/nexus/src/http/dbt/error.rs index dfdce9475..9e4b881ab 100644 --- a/crates/nexus/src/http/dbt/error.rs +++ b/crates/nexus/src/http/dbt/error.rs @@ -57,9 +57,24 @@ impl IntoResponse for DbtError { Self::NotImplemented => http::StatusCode::NOT_IMPLEMENTED, }; + let message = match &self { + Self::GZipDecompress { source } => format!("failed to decompress GZip body: {source}"), + Self::LoginRequestParse { source } => { + format!("failed to parse login request: {source}") + } + Self::QueryBodyParse { source } => format!("failed to parse query body: {source}"), + Self::InvalidWarehouseIdFormat { source } => format!("invalid warehouse_id: {source}"), + Self::ControlService { source } => source.to_string(), + Self::RowParse { source } => format!("failed to parse row JSON: {source}"), + Self::MissingAuthToken | Self::MissingDbtSession | Self::InvalidAuthData => { + "session error".to_string() + } + Self::NotImplemented => "feature not implemented".to_string(), + }; + let body = Json(JsonResponse { success: false, - message: Some(self.to_string()), + message: Some(message), data: None, code: Some(status_code.as_u16().to_string()), }); diff --git a/crates/nexus/src/http/dbt/handlers.rs b/crates/nexus/src/http/dbt/handlers.rs index 425daa65c..e5b84d39c 100644 --- a/crates/nexus/src/http/dbt/handlers.rs +++ b/crates/nexus/src/http/dbt/handlers.rs @@ -32,7 +32,7 @@ pub async fn login( //println!("Received login request: {:?}", query); //println!("Body data parameters: {:?}", body_json); - let token = uuid::Uuid::new_v4().to_string(); + let token = Uuid::new_v4().to_string(); let warehouses = state .control_svc diff --git a/crates/nexus/src/main.rs b/crates/nexus/src/main.rs index 80d323292..142d19b52 100644 --- a/crates/nexus/src/main.rs +++ b/crates/nexus/src/main.rs @@ -199,7 +199,7 @@ async fn buffer_and_print( body: B, ) -> Result where - B: axum::body::HttpBody, + B: axum::body::HttpBody + Send, B::Error: std::fmt::Display, { let bytes = match body.collect().await { diff --git a/crates/runtime/Cargo.toml b/crates/runtime/Cargo.toml index 1367e355e..be1f8afcc 100644 --- a/crates/runtime/Cargo.toml +++ b/crates/runtime/Cargo.toml @@ -14,13 +14,13 @@ serde = { workspace = true } serde_json = { workspace = true } object_store = { workspace = true } -datafusion = { git="https://github.com/Embucket/datafusion.git", rev = "bc5a978a4102391b2b7427dfdf94dd4e2667be49" } -datafusion-common = { git="https://github.com/Embucket/datafusion.git", rev = "bc5a978a4102391b2b7427dfdf94dd4e2667be49" } -datafusion-expr = { git="https://github.com/Embucket/datafusion.git", rev = "bc5a978a4102391b2b7427dfdf94dd4e2667be49" } +datafusion = { git="https://github.com/Embucket/datafusion.git", rev = "cc37c5920463b0cb0b224fc7f567fd4ae0368ffe" } +datafusion-common = { git="https://github.com/Embucket/datafusion.git", rev = "cc37c5920463b0cb0b224fc7f567fd4ae0368ffe" } +datafusion-expr = { git="https://github.com/Embucket/datafusion.git", rev = "cc37c5920463b0cb0b224fc7f567fd4ae0368ffe" } -iceberg-rust = { git = "https://github.com/Embucket/iceberg-rust.git", rev = "5d4211521085722de35b14c444da087c52309771" } -iceberg-rest-catalog = { git = "https://github.com/Embucket/iceberg-rust.git", rev = "5d4211521085722de35b14c444da087c52309771" } -datafusion_iceberg = { git = "https://github.com/Embucket/iceberg-rust.git", rev = "5d4211521085722de35b14c444da087c52309771" } +iceberg-rust = { git = "https://github.com/Embucket/iceberg-rust.git", rev = "68577441273eda894f1eb6d87b1c3e87dee0fdf6" } +iceberg-rest-catalog = { git = "https://github.com/Embucket/iceberg-rust.git", rev = "68577441273eda894f1eb6d87b1c3e87dee0fdf6" } +datafusion_iceberg = { git = "https://github.com/Embucket/iceberg-rust.git", rev = "68577441273eda894f1eb6d87b1c3e87dee0fdf6" } arrow = { version = "53" } arrow-json = { version = "53" } diff --git a/crates/runtime/src/datafusion/execution.rs b/crates/runtime/src/datafusion/execution.rs index 71d6029fa..adf72cbfd 100644 --- a/crates/runtime/src/datafusion/execution.rs +++ b/crates/runtime/src/datafusion/execution.rs @@ -7,10 +7,6 @@ use crate::datafusion::functions::register_udfs; use crate::datafusion::planner::ExtendedSqlToRel; use arrow::array::{RecordBatch, UInt64Array}; use arrow::datatypes::{DataType, Field, Schema as ArrowSchema}; -use datafusion::catalog::SchemaProvider; -use datafusion::catalog_common::information_schema::InformationSchemaProvider; -use datafusion::catalog_common::{ResolvedTableReference, TableReference}; -use datafusion::common::plan_datafusion_err; use datafusion::common::tree_node::{TransformedResult, TreeNode}; use datafusion::datasource::default_table_source::provider_as_source; use datafusion::execution::context::SessionContext; @@ -32,6 +28,7 @@ use iceberg_rust::spec::namespace::Namespace; use iceberg_rust::spec::schema::Schema; use iceberg_rust::spec::types::StructType; use snafu::ResultExt; +use sqlparser::ast::{MergeAction, MergeClauseKind, MergeInsertKind, Query as AstQuery}; use std::collections::hash_map::Entry; use std::collections::HashMap; use std::sync::Arc; @@ -76,14 +73,21 @@ impl SqlExecutor { Statement::CreateSchema { schema_name, .. } => { return self.create_schema(schema_name, warehouse_name).await; } - Statement::ShowSchemas { .. } - | Statement::ShowVariable { .. } - | Statement::Query { .. } => { + Statement::AlterTable { .. } + | Statement::StartTransaction { .. } + | Statement::Commit { .. } + | Statement::Insert { .. } + | Statement::Query { .. } + | Statement::ShowSchemas { .. } + | Statement::ShowVariable { .. } => { return Box::pin(self.execute_with_custom_plan(&query, warehouse_name)).await; } Statement::Drop { .. } => { return Box::pin(self.drop_table_query(&query, warehouse_name)).await; } + Statement::Merge { .. } => { + return Box::pin(self.merge_query(*s, warehouse_name)).await; + } _ => {} } } @@ -106,7 +110,7 @@ impl SqlExecutor { pub fn preprocess_query(&self, query: &str) -> String { // Replace field[0].subfield -> json_get(json_get(field, 0), 'subfield') // TODO: This regex should be a static allocation - let re = regex::Regex::new(r"(\w+)\[(\d+)]\.(\w+)").unwrap(); + let re = regex::Regex::new(r"(\w+)\[(\d+)][:\.](\w+)").unwrap(); let date_add = regex::Regex::new(r"(date|time|timestamp)(_?add)\(\s*([a-zA-Z]+),").unwrap(); let query = re @@ -246,13 +250,107 @@ impl SqlExecutor { created_entity_response().context(ih_error::ArrowSnafu) } else { Err(super::error::IcehutSQLError::DataFusion { - source: datafusion::error::DataFusionError::NotImplemented( + source: DataFusionError::NotImplemented( "Only CREATE TABLE statements are supported".to_string(), ), }) } } + pub async fn merge_query( + &self, + statement: Statement, + warehouse_name: &str, + ) -> IcehutSQLResult> { + if let Statement::Merge { + mut table, + mut source, + on, + clauses, + .. + } = statement + { + self.update_tables_in_table_factor(&mut table, warehouse_name); + self.update_tables_in_table_factor(&mut source, warehouse_name); + + let (target_table, target_alias) = self.get_table_with_alias(table); + let (source_table, _source_alias) = self.get_table_with_alias(source.clone()); + + let source_query = if let TableFactor::Derived { + subquery, + lateral, + alias, + } = source + { + source = TableFactor::Derived { + lateral, + subquery, + alias: None, + }; + alias.map_or_else(|| source.to_string(), |alias| format!("{source} {alias}")) + } else { + source.to_string() + }; + + // Prepare WHERE clause to filter unmatched records + let where_clause = self + .get_expr_where_clause(*on.clone(), target_alias.as_str()) + .iter() + .map(|v| format!("{v} IS NULL")) + .collect::>(); + let where_clause_str = if where_clause.is_empty() { + String::new() + } else { + format!(" WHERE {}", where_clause.join(" AND ")) + }; + + // Check NOT MATCHED for records to INSERT + let select_query = + format!("SELECT * FROM {source_query} JOIN {target_table} {target_alias} ON {on}{where_clause_str}"); + self.execute_with_custom_plan(&select_query, warehouse_name) + .await?; + + // Extract columns and values from clauses + let mut columns = Vec::new(); + let mut values = Vec::new(); + for clause in clauses { + if clause.clause_kind == MergeClauseKind::NotMatched { + if let MergeAction::Insert(insert) = clause.action { + columns = insert.columns; + if let MergeInsertKind::Values(values_insert) = insert.kind { + values = values_insert.rows.into_iter().flatten().collect(); + } + } + } + } + // Construct the INSERT statement + let insert_query = format!( + "INSERT INTO {} ({}) SELECT {} FROM {}", + target_table, + columns + .iter() + .map(ToString::to_string) + .collect::>() + .join(", "), + values + .iter() + .map(ToString::to_string) + .collect::>() + .join(", "), + source_table + ); + + self.execute_with_custom_plan(&insert_query, warehouse_name) + .await + } else { + Err(super::error::IcehutSQLError::DataFusion { + source: DataFusionError::NotImplemented( + "Only MERGE statements are supported".to_string(), + ), + }) + } + } + pub async fn drop_table_query( &self, query: &str, @@ -315,7 +413,7 @@ impl SqlExecutor { } _ => { return Err(super::error::IcehutSQLError::DataFusion { - source: datafusion::error::DataFusionError::NotImplemented( + source: DataFusionError::NotImplemented( "Only simple schema names are supported".to_string(), ), }); @@ -348,9 +446,9 @@ impl SqlExecutor { .context(super::error::DataFusionSnafu)?; //println!("References: {:?}", references); for reference in references { - let resolved = self.resolve_table_ref(reference); + let resolved = state.resolve_table_ref(reference); if let Entry::Vacant(v) = ctx_provider.tables.entry(resolved.to_string()) { - if let Ok(schema) = self.schema_for_ref(resolved.clone()) { + if let Ok(schema) = state.schema_for_ref(resolved.clone()) { if let Some(table) = schema .table(&resolved.table) .await @@ -390,48 +488,12 @@ impl SqlExecutor { .context(super::error::DataFusionSnafu) } else { Err(super::error::IcehutSQLError::DataFusion { - source: datafusion::error::DataFusionError::NotImplemented( + source: DataFusionError::NotImplemented( "Only SQL statements are supported".to_string(), ), }) } } - - pub fn resolve_table_ref( - &self, - table_ref: impl Into, - ) -> ResolvedTableReference { - let catalog = &self.ctx.state().config_options().catalog.clone(); - table_ref - .into() - .resolve(&catalog.default_catalog, &catalog.default_schema) - } - - pub fn schema_for_ref( - &self, - table_ref: impl Into, - ) -> IcehutSQLResult> { - let state = self.ctx.state(); - let resolved_ref = self.resolve_table_ref(table_ref); - if state.config().information_schema() && *resolved_ref.schema == *"information_schema" { - return Ok(Arc::new(InformationSchemaProvider::new( - state.catalog_list().clone(), - ))); - } - - // Need better error handling here instead of just DF errors - state - .catalog_list() - .catalog(&resolved_ref.catalog) - .ok_or_else(|| super::error::IcehutSQLError::DataFusion { - source: plan_datafusion_err!("failed to resolve catalog: {}", resolved_ref.catalog), - })? - .schema(&resolved_ref.schema) - .ok_or_else(|| super::error::IcehutSQLError::DataFusion { - source: plan_datafusion_err!("failed to resolve schema: {}", resolved_ref.schema), - }) - } - pub async fn execute_with_custom_plan( &self, query: &str, @@ -447,6 +509,30 @@ impl SqlExecutor { .context(super::error::DataFusionSnafu) } + #[allow(clippy::only_used_in_recursion)] + fn get_expr_where_clause(&self, expr: Expr, target_alias: &str) -> Vec { + match expr { + Expr::CompoundIdentifier(ident) => { + if ident.len() > 1 && ident[0].value == target_alias { + let ident_str = ident + .iter() + .map(|v| v.value.clone()) + .collect::>() + .join("."); + return vec![ident_str]; + } + vec![] + } + Expr::BinaryOp { left, right, .. } => { + let mut left_expr = self.get_expr_where_clause(*left, target_alias); + let right_expr = self.get_expr_where_clause(*right, target_alias); + left_expr.extend(right_expr); + left_expr + } + _ => vec![], + } + } + #[must_use] pub fn update_statement_references( &self, @@ -464,11 +550,40 @@ impl SqlExecutor { DFStatement::CreateExternalTable(modified_statement) } DFStatement::Statement(s) => match *s { + Statement::AlterTable { + name, + if_exists, + only, + operations, + location, + on_cluster, + } => { + let name = self.compress_database_name(name.0, warehouse_name); + let modified_statement = Statement::AlterTable { + name: ObjectName(name), + if_exists, + only, + operations, + location, + on_cluster, + }; + DFStatement::Statement(Box::new(modified_statement)) + } Statement::Insert(insert_statement) => { let table_name = self.compress_database_name(insert_statement.table_name.0, warehouse_name); + + let source = insert_statement.source.map_or_else( + || None, + |mut query| { + self.update_tables_in_query(query.as_mut(), warehouse_name); + Some(Box::new(AstQuery { ..*query })) + }, + ); + let modified_statement = Insert { table_name: ObjectName(table_name), + source, ..insert_statement }; DFStatement::Statement(Box::new(Statement::Insert(modified_statement))) @@ -548,6 +663,28 @@ impl SqlExecutor { table_name } + #[allow(clippy::only_used_in_recursion)] + fn get_table_with_alias(&self, factor: TableFactor) -> (ObjectName, String) { + match factor { + TableFactor::Table { name, alias, .. } => { + let target_alias = alias.map_or_else(String::new, |alias| alias.to_string()); + (name, target_alias) + } + TableFactor::Derived { + subquery, alias, .. + } => { + let target_alias = alias.map_or_else(String::new, |alias| alias.to_string()); + if let sqlparser::ast::SetExpr::Select(select) = subquery.body.as_ref() { + if let Some(table_with_joins) = select.from.first() { + return self.get_table_with_alias(table_with_joins.relation.clone()); + } + } + (ObjectName(vec![]), target_alias) + } + _ => (ObjectName(vec![]), String::new()), + } + } + fn update_tables_in_query(&self, query: &mut Query, warehouse_name: &str) { if let Some(with) = query.with.as_mut() { for cte in &mut with.cte_tables { @@ -556,7 +693,7 @@ impl SqlExecutor { } match query.body.as_mut() { - datafusion::sql::sqlparser::ast::SetExpr::Select(select) => { + sqlparser::ast::SetExpr::Select(select) => { for table_with_joins in &mut select.from { self.update_tables_in_table_with_joins(table_with_joins, warehouse_name); } @@ -565,7 +702,7 @@ impl SqlExecutor { self.update_tables_in_expr(expr, warehouse_name); } } - datafusion::sql::sqlparser::ast::SetExpr::Query(q) => { + sqlparser::ast::SetExpr::Query(q) => { self.update_tables_in_query(q, warehouse_name); } _ => {} diff --git a/crates/runtime/src/datafusion/planner.rs b/crates/runtime/src/datafusion/planner.rs index 8a4622233..84b28c02f 100644 --- a/crates/runtime/src/datafusion/planner.rs +++ b/crates/runtime/src/datafusion/planner.rs @@ -91,6 +91,9 @@ where let planner_context: &mut PlannerContext = &mut PlannerContext::new(); // Example: Custom handling for a specific statement match statement.clone() { + Statement::AlterTable { .. } + | Statement::StartTransaction { .. } + | Statement::Commit { .. } => Ok(LogicalPlan::default()), Statement::ShowSchemas { .. } => self.show_variable_to_plan(&["schemas".into()]), Statement::ShowVariable { variable } => self.show_variable_to_plan(&variable), Statement::CreateTable(CreateTableStatement { @@ -305,7 +308,7 @@ where let data_type = self.convert_data_type(&field.field_type)?; let field_name = field.field_name.as_ref().map_or_else( || Ident::new(format!("c{idx}")), - std::clone::Clone::clone, + Clone::clone, ); Ok(Arc::new(Field::new( self.ident_normalizer.normalize(field_name),