[SPARK-18538][SQL] Fix Concurrent Table Fetching Using DataFrameReader JDBC APIs

### What changes were proposed in this pull request?
The following two `DataFrameReader` JDBC APIs ignore the user-specified degree of parallelism:

```Scala
  def jdbc(
      url: String,
      table: String,
      columnName: String,
      lowerBound: Long,
      upperBound: Long,
      numPartitions: Int,
      connectionProperties: Properties): DataFrame
```

```Scala
  def jdbc(
      url: String,
      table: String,
      predicates: Array[String],
      connectionProperties: Properties): DataFrame
```
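
For reference, a minimal sketch of how these two overloads are invoked (assumes an active `SparkSession` named `spark`; the H2 URL, table name, bounds, and predicates below are illustrative values borrowed from the test fixtures, not part of the patch):

```Scala
import java.util.Properties

// Assumed connection details; any JDBC-accessible table works the same way.
val url = "jdbc:h2:mem:testdb0;user=testUser;password=testPass"
val props = new Properties()

// Range-partitioned read: 3 partitions, each covering a stride of THEID in [0, 4).
val byColumn = spark.read.jdbc(
  url, "TEST.PEOPLE", "THEID",
  lowerBound = 0L, upperBound = 4L, numPartitions = 3,
  connectionProperties = props)

// Predicate-partitioned read: one partition per WHERE clause.
val byPredicates = spark.read.jdbc(
  url, "TEST.PEOPLE", Array("THEID < 2", "THEID >= 2"), props)

// Before this fix the requested partitioning was dropped and both reads ran as a
// single partition; afterwards they fetch with 3 and 2 partitions respectively.
```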

This PR fixes both issues. To make the behavior easy to verify, the plan output of the `EXPLAIN` command is improved by adding `numPartitions` to the `JDBCRelation` node.

Before the fix:
```
== Physical Plan ==
*Scan JDBCRelation(TEST.PEOPLE) [NAME#1896,THEID#1897] ReadSchema: struct<NAME:string,THEID:int>
```

After the fix:
```
== Physical Plan ==
*Scan JDBCRelation(TEST.PEOPLE) [numPartitions=3] [NAME#1896,THEID#1897] ReadSchema: struct<NAME:string,THEID:int>
```
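
With a partitioned DataFrame such as `byColumn` from the sketch above, the new annotation is visible directly (sketch; expression IDs will differ per run):

```Scala
byColumn.explain()
// == Physical Plan ==
// *Scan JDBCRelation(TEST.PEOPLE) [numPartitions=3] [NAME#...,THEID#...] ReadSchema: struct<NAME:string,THEID:int>
```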
### How was this patch tested?
Added verification logic to all the test cases for JDBC concurrent fetching.
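
As a rough standalone check of the same property (distinct from the `checkNumPartitions` helper added to `JDBCSuite` below, which inspects the analyzed plan), the resulting parallelism can also be observed at the RDD level, assuming the `byColumn` and `byPredicates` DataFrames from the earlier sketch:

```Scala
// One RDD partition per generated WHERE clause.
assert(byColumn.rdd.getNumPartitions == 3)
assert(byPredicates.rdd.getNumPartitions == 2)
```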

Author: gatorsmile <gatorsmile@gmail.com>

Closes #15975 from gatorsmile/jdbc.
gatorsmile authored and cloud-fan committed Dec 1, 2016
1 parent 2eb6764 commit b28fe4a
Showing 6 changed files with 81 additions and 50 deletions.
37 changes: 19 additions & 18 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -159,7 +159,11 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* @since 1.4.0
*/
def jdbc(url: String, table: String, properties: Properties): DataFrame = {
jdbc(url, table, JDBCRelation.columnPartition(null), properties)
// properties should override settings in extraOptions.
this.extraOptions = this.extraOptions ++ properties.asScala
// explicit url and dbtable should override all
this.extraOptions += (JDBCOptions.JDBC_URL -> url, JDBCOptions.JDBC_TABLE_NAME -> table)
format("jdbc").load()
}

/**
@@ -177,7 +181,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* @param upperBound the maximum value of `columnName` used to decide partition stride.
* @param numPartitions the number of partitions. This, along with `lowerBound` (inclusive),
* `upperBound` (exclusive), form partition strides for generated WHERE
* clause expressions used to split the column `columnName` evenly.
* clause expressions used to split the column `columnName` evenly. When
* the input is less than 1, the number is set to 1.
* @param connectionProperties JDBC database connection arguments, a list of arbitrary string
* tag/value. Normally at least a "user" and "password" property
* should be included. "fetchsize" can be used to control the
@@ -192,9 +197,13 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
upperBound: Long,
numPartitions: Int,
connectionProperties: Properties): DataFrame = {
val partitioning = JDBCPartitioningInfo(columnName, lowerBound, upperBound, numPartitions)
val parts = JDBCRelation.columnPartition(partitioning)
jdbc(url, table, parts, connectionProperties)
// columnName, lowerBound, upperBound and numPartitions override settings in extraOptions.
this.extraOptions ++= Map(
JDBCOptions.JDBC_PARTITION_COLUMN -> columnName,
JDBCOptions.JDBC_LOWER_BOUND -> lowerBound.toString,
JDBCOptions.JDBC_UPPER_BOUND -> upperBound.toString,
JDBCOptions.JDBC_NUM_PARTITIONS -> numPartitions.toString)
jdbc(url, table, connectionProperties)
}

/**
@@ -220,22 +229,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
table: String,
predicates: Array[String],
connectionProperties: Properties): DataFrame = {
// connectionProperties should override settings in extraOptions.
val params = extraOptions.toMap ++ connectionProperties.asScala.toMap
val options = new JDBCOptions(url, table, params)
val parts: Array[Partition] = predicates.zipWithIndex.map { case (part, i) =>
JDBCPartition(part, i) : Partition
}
jdbc(url, table, parts, connectionProperties)
}

private def jdbc(
url: String,
table: String,
parts: Array[Partition],
connectionProperties: Properties): DataFrame = {
// connectionProperties should override settings in extraOptions.
this.extraOptions = this.extraOptions ++ connectionProperties.asScala
// explicit url and dbtable should override all
this.extraOptions += ("url" -> url, "dbtable" -> table)
format("jdbc").load()
val relation = JDBCRelation(parts, options)(sparkSession)
sparkSession.baseRelationToDataFrame(relation)
}

/**
@@ -76,9 +76,6 @@ class JDBCOptions(

// the number of partitions
val numPartitions = parameters.get(JDBC_NUM_PARTITIONS).map(_.toInt)
require(numPartitions.isEmpty || numPartitions.get > 0,
s"Invalid value `${numPartitions.get}` for parameter `$JDBC_NUM_PARTITIONS`. " +
"The minimum value is 1.")

// ------------------------------------------------------------
// Optional parameters only for reading
@@ -137,7 +137,8 @@ private[sql] case class JDBCRelation(
}

override def toString: String = {
val partitioningInfo = if (parts.nonEmpty) s" [numPartitions=${parts.length}]" else ""
// credentials should not be included in the plan output, table information is sufficient.
s"JDBCRelation(${jdbcOptions.table})"
s"JDBCRelation(${jdbcOptions.table})" + partitioningInfo
}
}
@@ -657,7 +657,7 @@ object JdbcUtils extends Logging {
df: DataFrame,
url: String,
table: String,
options: JDBCOptions) {
options: JDBCOptions): Unit = {
val dialect = JdbcDialects.get(url)
val nullTypes: Array[Int] = df.schema.fields.map { field =>
getJdbcType(field.dataType, dialect).jdbcNullType
@@ -667,13 +667,13 @@
val getConnection: () => Connection = createConnectionFactory(options)
val batchSize = options.batchSize
val isolationLevel = options.isolationLevel
val numPartitions = options.numPartitions
val repartitionedDF =
if (numPartitions.isDefined && numPartitions.get < df.rdd.getNumPartitions) {
df.coalesce(numPartitions.get)
} else {
df
}
val repartitionedDF = options.numPartitions match {
case Some(n) if n <= 0 => throw new IllegalArgumentException(
s"Invalid value `$n` for parameter `${JDBCOptions.JDBC_NUM_PARTITIONS}` in table writing " +
"via JDBC. The minimum value is 1.")
case Some(n) if n < df.rdd.getNumPartitions => df.coalesce(n)
case _ => df
}
repartitionedDF.foreachPartition(iterator => savePartition(
getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect, isolationLevel)
)
67 changes: 48 additions & 19 deletions sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -24,12 +24,12 @@ import java.util.{Calendar, GregorianCalendar, Properties}
import org.h2.jdbc.JdbcSQLException
import org.scalatest.{BeforeAndAfter, PrivateMethodTester}

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.execution.DataSourceScanExec
import org.apache.spark.sql.execution.command.ExplainCommand
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JdbcUtils}
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation, JdbcUtils}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
@@ -209,6 +209,16 @@ class JDBCSuite extends SparkFunSuite
conn.close()
}

// Check whether the tables are fetched in the expected degree of parallelism
def checkNumPartitions(df: DataFrame, expectedNumPartitions: Int): Unit = {
val jdbcRelations = df.queryExecution.analyzed.collect {
case LogicalRelation(r: JDBCRelation, _, _) => r
}
assert(jdbcRelations.length == 1)
assert(jdbcRelations.head.parts.length == expectedNumPartitions,
s"Expecting a JDBCRelation with $expectedNumPartitions partitions, but got:`$jdbcRelations`")
}

test("SELECT *") {
assert(sql("SELECT * FROM foobar").collect().size === 3)
}
@@ -313,13 +323,23 @@ class JDBCSuite extends SparkFunSuite
}

test("SELECT * partitioned") {
assert(sql("SELECT * FROM parts").collect().size == 3)
val df = sql("SELECT * FROM parts")
checkNumPartitions(df, expectedNumPartitions = 3)
assert(df.collect().length == 3)
}

test("SELECT WHERE (simple predicates) partitioned") {
assert(sql("SELECT * FROM parts WHERE THEID < 1").collect().size === 0)
assert(sql("SELECT * FROM parts WHERE THEID != 2").collect().size === 2)
assert(sql("SELECT THEID FROM parts WHERE THEID = 1").collect().size === 1)
val df1 = sql("SELECT * FROM parts WHERE THEID < 1")
checkNumPartitions(df1, expectedNumPartitions = 3)
assert(df1.collect().length === 0)

val df2 = sql("SELECT * FROM parts WHERE THEID != 2")
checkNumPartitions(df2, expectedNumPartitions = 3)
assert(df2.collect().length === 2)

val df3 = sql("SELECT THEID FROM parts WHERE THEID = 1")
checkNumPartitions(df3, expectedNumPartitions = 3)
assert(df3.collect().length === 1)
}

test("SELECT second field partitioned") {
@@ -370,24 +390,27 @@ class JDBCSuite extends SparkFunSuite
}

test("Partitioning via JDBCPartitioningInfo API") {
assert(
spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties())
.collect().length === 3)
val df = spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties())
checkNumPartitions(df, expectedNumPartitions = 3)
assert(df.collect().length === 3)
}

test("Partitioning via list-of-where-clauses API") {
val parts = Array[String]("THEID < 2", "THEID >= 2")
assert(spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties())
.collect().length === 3)
val df = spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties())
checkNumPartitions(df, expectedNumPartitions = 2)
assert(df.collect().length === 3)
}

test("Partitioning on column that might have null values.") {
assert(
spark.read.jdbc(urlWithUserAndPass, "TEST.EMP", "theid", 0, 4, 3, new Properties())
.collect().length === 4)
assert(
spark.read.jdbc(urlWithUserAndPass, "TEST.EMP", "THEID", 0, 4, 3, new Properties())
.collect().length === 4)
val df = spark.read.jdbc(urlWithUserAndPass, "TEST.EMP", "theid", 0, 4, 3, new Properties())
checkNumPartitions(df, expectedNumPartitions = 3)
assert(df.collect().length === 4)

val df2 = spark.read.jdbc(urlWithUserAndPass, "TEST.EMP", "THEID", 0, 4, 3, new Properties())
checkNumPartitions(df2, expectedNumPartitions = 3)
assert(df2.collect().length === 4)

// partitioning on a nullable quoted column
assert(
spark.read.jdbc(urlWithUserAndPass, "TEST.EMP", """"Dept"""", 0, 4, 3, new Properties())
@@ -404,6 +427,7 @@ class JDBCSuite extends SparkFunSuite
numPartitions = 0,
connectionProperties = new Properties()
)
checkNumPartitions(res, expectedNumPartitions = 1)
assert(res.count() === 8)
}

@@ -417,6 +441,7 @@ class JDBCSuite extends SparkFunSuite
numPartitions = 10,
connectionProperties = new Properties()
)
checkNumPartitions(res, expectedNumPartitions = 4)
assert(res.count() === 8)
}

@@ -430,6 +455,7 @@ class JDBCSuite extends SparkFunSuite
numPartitions = 4,
connectionProperties = new Properties()
)
checkNumPartitions(res, expectedNumPartitions = 1)
assert(res.count() === 8)
}

@@ -450,7 +476,9 @@ class JDBCSuite extends SparkFunSuite
}

test("SELECT * on partitioned table with a nullable partition column") {
assert(sql("SELECT * FROM nullparts").collect().size == 4)
val df = sql("SELECT * FROM nullparts")
checkNumPartitions(df, expectedNumPartitions = 3)
assert(df.collect().length == 4)
}

test("H2 integral types") {
@@ -722,7 +750,8 @@ class JDBCSuite extends SparkFunSuite
}
// test the JdbcRelation toString output
df.queryExecution.analyzed.collect {
case r: LogicalRelation => assert(r.relation.toString == "JDBCRelation(TEST.PEOPLE)")
case r: LogicalRelation =>
assert(r.relation.toString == "JDBCRelation(TEST.PEOPLE) [numPartitions=3]")
}
}

@@ -319,9 +319,12 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
df.write.format("jdbc")
.option("dbtable", "TEST.SAVETEST")
.option("url", url1)
.option("user", "testUser")
.option("password", "testPass")
.option(s"${JDBCOptions.JDBC_NUM_PARTITIONS}", "0")
.save()
}.getMessage
assert(e.contains("Invalid value `0` for parameter `numPartitions`. The minimum value is 1"))
assert(e.contains("Invalid value `0` for parameter `numPartitions` in table writing " +
"via JDBC. The minimum value is 1."))
}
}
