From 647ddce6f1295e52087c3e468da1fa4548b1a846 Mon Sep 17 00:00:00 2001 From: Boyd Johnson Date: Thu, 4 Jan 2024 18:38:55 +0000 Subject: [PATCH] Add support for PostgreSQL Insert table aliases (#1069) --- src/ast/mod.rs | 10 ++ src/parser/mod.rs | 9 ++ tests/sqlparser_postgres.rs | 198 ++++++++++++++++++++++++++++++++++++ 3 files changed, 217 insertions(+) diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 1112236a1c..9ce331ec15 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -14,6 +14,7 @@ #[cfg(not(feature = "std"))] use alloc::{ boxed::Box, + format, string::{String, ToString}, vec::Vec, }; @@ -1428,6 +1429,8 @@ pub enum Statement { /// TABLE #[cfg_attr(feature = "visitor", visit(with = "visit_relation"))] table_name: ObjectName, + /// table_name as foo (for PostgreSQL) + table_alias: Option, /// COLUMNS columns: Vec, /// Overwrite (Hive) @@ -2400,6 +2403,7 @@ impl fmt::Display for Statement { ignore, into, table_name, + table_alias, overwrite, partitioned, columns, @@ -2411,6 +2415,12 @@ impl fmt::Display for Statement { replace_into, priority, } => { + let table_name = if let Some(alias) = table_alias { + format!("{table_name} AS {alias}") + } else { + table_name.to_string() + }; + if let Some(action) = or { write!(f, "INSERT OR {action} INTO {table_name} ")?; } else { diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 853ab3d176..45ac12de2a 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -7456,6 +7456,14 @@ impl<'a> Parser<'a> { // Hive lets you put table here regardless let table = self.parse_keyword(Keyword::TABLE); let table_name = self.parse_object_name()?; + + let table_alias = + if dialect_of!(self is PostgreSqlDialect) && self.parse_keyword(Keyword::AS) { + Some(self.parse_identifier()?) + } else { + None + }; + let is_mysql = dialect_of!(self is MySqlDialect); let (columns, partitioned, after_columns, source) = @@ -7530,6 +7538,7 @@ impl<'a> Parser<'a> { Ok(Statement::Insert { or, table_name, + table_alias, ignore, into, overwrite, diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index b075a9b4d8..6dd6688d50 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -3500,3 +3500,201 @@ fn parse_join_constraint_unnest_alias() { }] ); } + +#[test] +fn test_complex_postgres_insert_with_alias() { + let sql1 = "WITH existing AS (SELECT test_table.id FROM test_tables AS test_table WHERE (a = 12) AND (b = 34)), inserted AS (INSERT INTO test_tables AS test_table (id, a, b, c) VALUES (DEFAULT, 56, 78, 90) ON CONFLICT(a, b) DO UPDATE SET c = EXCLUDED.c WHERE (test_table.c <> EXCLUDED.c)) SELECT c FROM existing"; + + pg().verified_stmt(sql1); +} + +#[cfg(not(feature = "bigdecimal"))] +#[test] +fn test_simple_postgres_insert_with_alias() { + let sql2 = "INSERT INTO test_tables AS test_table (id, a) VALUES (DEFAULT, 123)"; + + let statement = pg().verified_stmt(sql2); + + assert_eq!( + statement, + Statement::Insert { + or: None, + ignore: false, + into: true, + table_name: ObjectName(vec![Ident { + value: "test_tables".to_string(), + quote_style: None + }]), + table_alias: Some(Ident { + value: "test_table".to_string(), + quote_style: None + }), + columns: vec![ + Ident { + value: "id".to_string(), + quote_style: None + }, + Ident { + value: "a".to_string(), + quote_style: None + } + ], + overwrite: false, + source: Some(Box::new(Query { + with: None, + body: Box::new(SetExpr::Values(Values { + explicit_row: false, + rows: vec![vec![ + Expr::Identifier(Ident { + value: "DEFAULT".to_string(), + quote_style: None + }), + Expr::Value(Value::Number("123".to_string(), false)) + ]] + })), + order_by: vec![], + limit: None, + limit_by: vec![], + offset: None, + fetch: None, + locks: vec![], + for_clause: None + })), + partitioned: None, + after_columns: vec![], + table: false, + on: None, + returning: None, + replace_into: false, + priority: None + } + ) +} + +#[cfg(feature = "bigdecimal")] +#[test] +fn test_simple_postgres_insert_with_alias() { + let sql2 = "INSERT INTO test_tables AS test_table (id, a) VALUES (DEFAULT, 123)"; + + let statement = pg().verified_stmt(sql2); + + assert_eq!( + statement, + Statement::Insert { + or: None, + ignore: false, + into: true, + table_name: ObjectName(vec![Ident { + value: "test_tables".to_string(), + quote_style: None + }]), + table_alias: Some(Ident { + value: "test_table".to_string(), + quote_style: None + }), + columns: vec![ + Ident { + value: "id".to_string(), + quote_style: None + }, + Ident { + value: "a".to_string(), + quote_style: None + } + ], + overwrite: false, + source: Some(Box::new(Query { + with: None, + body: Box::new(SetExpr::Values(Values { + explicit_row: false, + rows: vec![vec![ + Expr::Identifier(Ident { + value: "DEFAULT".to_string(), + quote_style: None + }), + Expr::Value(Value::Number( + bigdecimal::BigDecimal::new(123.into(), 0), + false + )) + ]] + })), + order_by: vec![], + limit: None, + limit_by: vec![], + offset: None, + fetch: None, + locks: vec![], + for_clause: None + })), + partitioned: None, + after_columns: vec![], + table: false, + on: None, + returning: None, + replace_into: false, + priority: None + } + ) +} + +#[test] +fn test_simple_insert_with_quoted_alias() { + let sql = r#"INSERT INTO test_tables AS "Test_Table" (id, a) VALUES (DEFAULT, '0123')"#; + + let statement = pg().verified_stmt(sql); + + assert_eq!( + statement, + Statement::Insert { + or: None, + ignore: false, + into: true, + table_name: ObjectName(vec![Ident { + value: "test_tables".to_string(), + quote_style: None + }]), + table_alias: Some(Ident { + value: "Test_Table".to_string(), + quote_style: Some('"') + }), + columns: vec![ + Ident { + value: "id".to_string(), + quote_style: None + }, + Ident { + value: "a".to_string(), + quote_style: None + } + ], + overwrite: false, + source: Some(Box::new(Query { + with: None, + body: Box::new(SetExpr::Values(Values { + explicit_row: false, + rows: vec![vec![ + Expr::Identifier(Ident { + value: "DEFAULT".to_string(), + quote_style: None + }), + Expr::Value(Value::SingleQuotedString("0123".to_string())) + ]] + })), + order_by: vec![], + limit: None, + limit_by: vec![], + offset: None, + fetch: None, + locks: vec![], + for_clause: None + })), + partitioned: None, + after_columns: vec![], + table: false, + on: None, + returning: None, + replace_into: false, + priority: None + } + ) +}