Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-9078] [SQL] Allow jdbc dialects to override the query used to check the table. #8676

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
val conn = JdbcUtils.createConnection(url, props)

try {
var tableExists = JdbcUtils.tableExists(conn, table)
var tableExists = JdbcUtils.tableExists(conn, url, table)

if (mode == SaveMode.Ignore && tableExists) {
return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,13 @@ object JdbcUtils extends Logging {
/**
* Returns true if the table already exists in the JDBC database.
*/
def tableExists(conn: Connection, table: String): Boolean = {
def tableExists(conn: Connection, url: String, table: String): Boolean = {
val dialect = JdbcDialects.get(url)

// Somewhat hacky, but there isn't a good way to identify whether a table exists for all
// SQL database systems, considering "table" could also include the database name.
Try(conn.prepareStatement(s"SELECT 1 FROM $table LIMIT 1").executeQuery().next()).isSuccess
// SQL database systems using JDBC meta data calls, considering "table" could also include
// the database name. Query used to find table exists can be overriden by the dialects.
Try(conn.prepareStatement(dialect.getTableExistsQuery(table)).executeQuery()).isSuccess
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,17 @@ abstract class JdbcDialect {
def quoteIdentifier(colName: String): String = {
s""""$colName""""
}

/**
* Get the SQL query that should be used to find if the given table exists. Dialects can
* override this method to return a query that works best in a particular database.
* @param table The name of the table.
* @return The SQL query to use for checking the table.
*/
def getTableExistsQuery(table: String): String = {
s"SELECT * FROM $table WHERE 1=0"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we should quote the table here actually

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually never mind we cannot quote it.

Copy link

@toddleo toddleo Sep 7, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rxin What's the specific reason table name cannot be quoted? We happen to have a table with dots and parenthesis in its name, planning to add surrounding backticks before passing it to Spark.

}

}

/**
Expand Down Expand Up @@ -198,6 +209,11 @@ case object PostgresDialect extends JdbcDialect {
case BooleanType => Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN))
case _ => None
}

override def getTableExistsQuery(table: String): String = {
s"SELECT 1 FROM $table LIMIT 1"
}

}

/**
Expand All @@ -222,6 +238,10 @@ case object MySQLDialect extends JdbcDialect {
override def quoteIdentifier(colName: String): String = {
s"`$colName`"
}

override def getTableExistsQuery(table: String): String = {
s"SELECT 1 FROM $table LIMIT 1"
}
}

/**
Expand Down
14 changes: 14 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -450,4 +450,18 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext
assert(db2Dialect.getJDBCType(StringType).map(_.databaseTypeDefinition).get == "CLOB")
assert(db2Dialect.getJDBCType(BooleanType).map(_.databaseTypeDefinition).get == "CHAR(1)")
}

test("table exists query by jdbc dialect") {
val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db")
val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db")
val db2 = JdbcDialects.get("jdbc:db2://127.0.0.1/db")
val h2 = JdbcDialects.get(url)
val table = "weblogs"
val defaultQuery = s"SELECT * FROM $table WHERE 1=0"
val limitQuery = s"SELECT 1 FROM $table LIMIT 1"
assert(MySQL.getTableExistsQuery(table) == limitQuery)
assert(Postgres.getTableExistsQuery(table) == limitQuery)
assert(db2.getTableExistsQuery(table) == defaultQuery)
assert(h2.getTableExistsQuery(table) == defaultQuery)
}
}