From ddbff7974583832f4ed7f8a04fdf61cb6caf4961 Mon Sep 17 00:00:00 2001
From: Eugene Golovan
Date: Tue, 20 Nov 2018 12:39:57 +0200
Subject: [PATCH] [SPARK-26077][SQL] Reserved SQL words are not escaped by JDBC
 writer for table name

---
 .../datasources/jdbc/JDBCOptions.scala        |  4 ++-
 .../datasources/jdbc/JdbcUtils.scala          | 16 +++++----
 .../spark/sql/jdbc/AggregatedDialect.scala    |  8 +++++
 .../apache/spark/sql/jdbc/JdbcDialects.scala  | 27 +++++++++++++--
 .../apache/spark/sql/jdbc/MySQLDialect.scala  |  4 +--
 .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 33 +++++++++++++++----
 6 files changed, 72 insertions(+), 20 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
index 7dfbb9d8b5c05..06ee0793eb73c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
@@ -21,6 +21,7 @@ import java.sql.{Connection, DriverManager}
 import java.util.{Locale, Properties}
 
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.jdbc.JdbcDialects
 import org.apache.spark.sql.types.StructType
 
 /**
@@ -81,7 +82,8 @@ class JDBCOptions(
       if (name.isEmpty) {
         throw new IllegalArgumentException(s"Option '$JDBC_TABLE_NAME' can not be empty.")
       } else {
-        name.trim
+        val dialect = JdbcDialects.get(url)
+        dialect.quoteIdentifier(name.trim)
       }
     case (None, Some(subquery)) =>
      if (subquery.isEmpty) {
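With this change the quoting happens once, at option-parsing time: tableOrQuery already carries the dialect-quoted name, so every downstream consumer sees the escaped form. A minimal sketch of the resulting behaviour, mirroring the new JDBCSuite test later in this patch (the MySQL URL and table name are hypothetical, and constructing JDBCOptions assumes a JDBC driver matching the URL is resolvable on the classpath):

    import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions

    // "order" is a reserved word in MySQL; MySQLDialect quotes each
    // dot-separated part of the name with its backtick quote character.
    val options = new JDBCOptions(Map(
      "url" -> "jdbc:mysql://localhost:3306/temp",
      "dbtable" -> "temp.order"))
    assert(options.tableOrQuery == "`temp`.`order`")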
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index edea549748b47..2e0f09365fb2e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -74,7 +74,8 @@ object JdbcUtils extends Logging {
     // SQL database systems using JDBC meta data calls, considering "table" could also include
     // the database name. Query used to find table exists can be overridden by the dialects.
     Try {
-      val statement = conn.prepareStatement(dialect.getTableExistsQuery(options.table))
+      val table = dialect.quoteIdentifier(options.table)
+      val statement = conn.prepareStatement(dialect.getTableExistsQuery(table))
       try {
         statement.setQueryTimeout(options.queryTimeout)
         statement.executeQuery()
@@ -88,10 +89,11 @@ object JdbcUtils extends Logging {
    * Drops a table from the JDBC database.
    */
   def dropTable(conn: Connection, table: String, options: JDBCOptions): Unit = {
+    val dialect = JdbcDialects.get(options.url)
     val statement = conn.createStatement
     try {
       statement.setQueryTimeout(options.queryTimeout)
-      statement.executeUpdate(s"DROP TABLE $table")
+      statement.executeUpdate(s"DROP TABLE ${dialect.quoteIdentifier(table)}")
     } finally {
       statement.close()
     }
@@ -105,10 +107,11 @@ object JdbcUtils extends Logging {
     val statement = conn.createStatement
     try {
       statement.setQueryTimeout(options.queryTimeout)
+      val table = dialect.quoteIdentifier(options.table)
       val truncateQuery = if (options.isCascadeTruncate.isDefined) {
-        dialect.getTruncateQuery(options.table, options.isCascadeTruncate)
+        dialect.getTruncateQuery(table, options.isCascadeTruncate)
       } else {
-        dialect.getTruncateQuery(options.table)
+        dialect.getTruncateQuery(table)
       }
       statement.executeUpdate(truncateQuery)
     } finally {
@@ -150,7 +153,7 @@ object JdbcUtils extends Logging {
       }.mkString(",")
     }
     val placeholders = rddSchema.fields.map(_ => "?").mkString(",")
-    s"INSERT INTO $table ($columns) VALUES ($placeholders)"
+    s"INSERT INTO ${dialect.quoteIdentifier(table)} ($columns) VALUES ($placeholders)"
   }
 
   /**
@@ -848,11 +851,12 @@ object JdbcUtils extends Logging {
       df, options.url, options.createTableColumnTypes)
     val table = options.table
     val createTableOptions = options.createTableOptions
+    val dialect = JdbcDialects.get(options.url)
     // Create the table if the table does not exist.
     // To allow certain options to append when create a new table, which can be
     // table_options or partition_options.
     // E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8"
-    val sql = s"CREATE TABLE $table ($strSchema) $createTableOptions"
+    val sql = s"CREATE TABLE ${dialect.quoteIdentifier(table)} ($strSchema) $createTableOptions"
     val statement = conn.createStatement
     try {
       statement.setQueryTimeout(options.queryTimeout)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala
index 3a3246a1b1d13..575c845f953af 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala
@@ -42,10 +42,18 @@ private class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect
     dialects.flatMap(_.getJDBCType(dt)).headOption
   }
 
+  override def getIdentifierQuoteCharacter: String = {
+    dialects.head.getIdentifierQuoteCharacter
+  }
+
   override def quoteIdentifier(colName: String): String = {
     dialects.head.quoteIdentifier(colName)
   }
 
+  override def quoteSingleIdentifier(colName: String): String = {
+    dialects.head.quoteSingleIdentifier(colName)
+  }
+
   override def getTableExistsQuery(table: String): String = {
     dialects.head.getTableExistsQuery(table)
   }
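With the helpers above, every statement the writer builds (DROP, TRUNCATE, INSERT, CREATE TABLE) interpolates the dialect-quoted name instead of the raw string, and AggregatedDialect simply delegates both new methods to its first dialect. A sketch of what the dropTable change produces, using only the public dialect registry (no connection is opened; the table name is hypothetical):

    import org.apache.spark.sql.jdbc.JdbcDialects

    val dialect = JdbcDialects.get("jdbc:mysql://localhost:3306/temp")
    // Mirrors the interpolation now done in JdbcUtils.dropTable:
    val dropSql = s"DROP TABLE ${dialect.quoteIdentifier("temp.order")}"
    assert(dropSql == "DROP TABLE `temp`.`order`")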
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index f76c1fae562c6..bd92accb90b25 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -87,11 +87,32 @@ abstract class JdbcDialect extends Serializable {
   def getJDBCType(dt: DataType): Option[JdbcType] = None
 
   /**
-   * Quotes the identifier. This is used to put quotes around the identifier in case the column
-   * name is a reserved keyword, or in case it contains characters that require quotes (e.g. space).
+   * Gets the character used for identifier quoting.
+   */
+  def getIdentifierQuoteCharacter: String = """""""
+
+  /**
+   * Quotes the identifier. This is used to put quotes around the identifier in case
+   * the table or column name is a reserved keyword, or in case it contains characters
+   * that require quotes (e.g. space).
    */
   def quoteIdentifier(colName: String): String = {
-    s""""$colName""""
+    if (colName.startsWith("(")) {
+      // Assume this is a subquery and leave it untouched.
+      colName
+    } else if (colName.contains(".")) {
+      colName.split("\\.").map(quoteSingleIdentifier).mkString(".")
+    } else {
+      quoteSingleIdentifier(colName)
+    }
+  }
+
+  /**
+   * Quotes a single identifier (a single part of a dot-separated name).
+   */
+  def quoteSingleIdentifier(colName: String): String = {
+    val quoteChar = getIdentifierQuoteCharacter
+    quoteChar + colName.replace(quoteChar, "") + quoteChar
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
index b2cff7877d8b5..1f1b3541aa94b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
@@ -37,9 +37,7 @@ private case object MySQLDialect extends JdbcDialect {
     } else None
   }
 
-  override def quoteIdentifier(colName: String): String = {
-    s"`$colName`"
-  }
+  override def getIdentifierQuoteCharacter: String = "`"
 
   override def getTableExistsQuery(table: String): String = {
     s"SELECT 1 FROM $table LIMIT 1"
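The default implementation now distinguishes three cases: a name starting with a parenthesis is treated as a subquery and passed through untouched, a dotted name is split and each part quoted separately, and a plain name is quoted as a whole; quoteSingleIdentifier strips any embedded quote characters rather than escaping them. A few illustrative cases against the default double-quote behaviour (assuming an H2 URL, which no registered dialect claims, so the base JdbcDialect implementation applies):

    import org.apache.spark.sql.jdbc.JdbcDialects

    val dialect = JdbcDialects.get("jdbc:h2:mem:testdb")
    assert(dialect.quoteIdentifier("order") == "\"order\"")
    assert(dialect.quoteIdentifier("TEST.PEOPLE") == "\"TEST\".\"PEOPLE\"")
    // A leading parenthesis marks a subquery, which is left as-is.
    assert(dialect.quoteIdentifier("(SELECT 1) as t2") == "(SELECT 1) as t2")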
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 7fa0e7fc162ca..497b5812a788b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.DataSourceScanExec
 import org.apache.spark.sql.execution.command.ExplainCommand
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartition, JDBCRDD, JDBCRelation, JdbcUtils}
+import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils.getInsertStatement
 import org.apache.spark.sql.execution.metric.InputOutputMetricsHelper
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources._
@@ -506,7 +507,7 @@ class JDBCSuite extends QueryTest
   test("Partitioning on column where numPartitions is zero") {
     val res = spark.read.jdbc(
       url = urlWithUserAndPass,
-      table = "TEST.seq",
+      table = "TEST.SEQ",
       columnName = "id",
       lowerBound = 0,
       upperBound = 4,
@@ -520,7 +521,7 @@ class JDBCSuite extends QueryTest
   test("Partitioning on column where numPartitions are more than the number of total rows") {
     val res = spark.read.jdbc(
       url = urlWithUserAndPass,
-      table = "TEST.seq",
+      table = "TEST.SEQ",
       columnName = "id",
       lowerBound = 1,
       upperBound = 5,
@@ -534,7 +535,7 @@ class JDBCSuite extends QueryTest
   test("Partitioning on column where lowerBound is equal to upperBound") {
     val res = spark.read.jdbc(
       url = urlWithUserAndPass,
-      table = "TEST.seq",
+      table = "TEST.SEQ",
       columnName = "id",
       lowerBound = 5,
       upperBound = 5,
@@ -549,7 +550,7 @@ class JDBCSuite extends QueryTest
     val e = intercept[IllegalArgumentException] {
       spark.read.jdbc(
         url = urlWithUserAndPass,
-        table = "TEST.seq",
+        table = "TEST.SEQ",
         columnName = "id",
         lowerBound = 5,
         upperBound = 1,
@@ -941,7 +942,7 @@ class JDBCSuite extends QueryTest
     // test the JdbcRelation toString output
     df.queryExecution.analyzed.collect {
       case r: LogicalRelation =>
-        assert(r.relation.toString == "JDBCRelation(TEST.PEOPLE) [numPartitions=3]")
+        assert(r.relation.toString == """JDBCRelation("TEST"."PEOPLE") [numPartitions=3]""")
     }
   }
 
@@ -1040,6 +1041,24 @@ class JDBCSuite extends QueryTest
     assert(schema.contains("`order` TEXT"))
   }
 
+  test("SPARK-26077: Reserved SQL words are not escaped by JDBC writer for table name") {
+    val df = spark.createDataset(Seq("a", "b", "c")).toDF("order")
+    val url = "jdbc:mysql://localhost:3306/temp"
+    val table = "temp.test"
+    val rddSchema = df.schema
+    val dialect = JdbcDialects.get(url)
+
+    val options1 = new JDBCOptions(Map("url" -> url, "dbtable" -> table))
+    assert(options1.tableOrQuery == "`temp`.`test`")
+
+    val t2 = "(SELECT 1) as t2"
+    val options2 = new JDBCOptions(Map("url" -> url, "dbtable" -> t2))
+    assert(options2.tableOrQuery == t2)
+
+    val insertStmt = getInsertStatement(table, rddSchema, None, isCaseSensitive = true, dialect)
+    assert(insertStmt.contains("`temp`.`test`"))
+  }
+
   test("SPARK-18141: Predicates on quoted column names in the jdbc data source") {
     assert(sql("SELECT * FROM mixedCaseCols WHERE Id < 1").collect().size == 0)
     assert(sql("SELECT * FROM mixedCaseCols WHERE Id <= 1").collect().size == 1)
@@ -1144,7 +1163,7 @@ class JDBCSuite extends QueryTest
   test("SPARK-19318: Connection properties keys should be case-sensitive.") {
     def testJdbcOptions(options: JDBCOptions): Unit = {
       // Spark JDBC data source options are case-insensitive
-      assert(options.tableOrQuery == "t1")
+      assert(options.tableOrQuery == """"t1"""")
       // When we convert it to properties, it should be case-sensitive.
       assert(options.asProperties.size == 3)
       assert(options.asProperties.get("customkey") == null)
@@ -1454,7 +1473,7 @@ class JDBCSuite extends QueryTest
   test("SPARK-24288: Enable preventing predicate pushdown") {
-    val table = "test.people"
+    val table = "TEST.PEOPLE"
     val df = spark.read.format("jdbc")
       .option("Url", urlWithUserAndPass)
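End to end, the fix means a DataFrame can be written to a table whose name or column names collide with reserved words. A usage sketch, under the assumptions that a MySQL server is reachable at the hypothetical URL below, the driver is on the classpath, the credentials are placeholders, and spark is an active SparkSession:

    import spark.implicits._

    val df = spark.createDataset(Seq("a", "b", "c")).toDF("order")
    df.write
      .format("jdbc")
      .option("url", "jdbc:mysql://localhost:3306/temp")
      .option("dbtable", "temp.test")
      .option("user", "root")
      .option("password", "password")
      .save()
    // The writer now emits CREATE TABLE `temp`.`test` (`order` TEXT)
    // and INSERT INTO `temp`.`test` (`order`) VALUES (?) instead of
    // unquoted statements that MySQL would reject.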