sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -21,6 +21,7 @@ import java.util.Properties
 
 import scala.collection.mutable.ArrayBuffer
 
+import org.apache.spark.internal.Logging
 import org.apache.spark.Partition
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext}
@@ -36,7 +37,7 @@ private[sql] case class JDBCPartitioningInfo(
     upperBound: Long,
     numPartitions: Int)
 
-private[sql] object JDBCRelation {
+private[sql] object JDBCRelation extends Logging {
   /**
    * Given a partitioning schematic (a column of integral type, a number of
    * partitions, and upper and lower bounds on the column's value), generate
@@ -52,29 +53,46 @@ private[sql] object JDBCRelation {
    * @return an array of partitions with where clause for each partition
    */
  def columnPartition(partitioning: JDBCPartitioningInfo): Array[Partition] = {
-    if (partitioning == null) return Array[Partition](JDBCPartition(null, 0))
+    if (partitioning == null || partitioning.numPartitions <= 1 ||
+      partitioning.lowerBound == partitioning.upperBound) {
+      return Array[Partition](JDBCPartition(null, 0))
+    }
 
-    val numPartitions = partitioning.numPartitions
-    val column = partitioning.column
-    if (numPartitions == 1) return Array[Partition](JDBCPartition(null, 0))
+    val lowerBound = partitioning.lowerBound
+    val upperBound = partitioning.upperBound
+    require(lowerBound <= upperBound,
+      "Operation not allowed: the lower bound of partitioning column is larger than the upper " +
+      s"bound. Lower bound: $lowerBound; Upper bound: $upperBound")
+
+    val numPartitions =
+      if ((upperBound - lowerBound) >= partitioning.numPartitions) {
+        partitioning.numPartitions
+      } else {
+        logWarning("The number of partitions is reduced because the difference between " +
+          "upper bound and lower bound is less than the specified number of partitions. " +
+          s"Updated number of partitions: ${upperBound - lowerBound}; Input number of " +
+          s"partitions: ${partitioning.numPartitions}; Lower bound: $lowerBound; " +
+          s"Upper bound: $upperBound.")
+        upperBound - lowerBound
+      }
     // Overflow and silliness can happen if you subtract then divide.
     // Here we get a little roundoff, but that's (hopefully) OK.
-    val stride: Long = (partitioning.upperBound / numPartitions
-      - partitioning.lowerBound / numPartitions)
+    val stride: Long = upperBound / numPartitions - lowerBound / numPartitions
+    val column = partitioning.column
     var i: Int = 0
-    var currentValue: Long = partitioning.lowerBound
+    var currentValue: Long = lowerBound
     var ans = new ArrayBuffer[Partition]()
     while (i < numPartitions) {
-      val lowerBound = if (i != 0) s"$column >= $currentValue" else null
+      val lBound = if (i != 0) s"$column >= $currentValue" else null
       currentValue += stride
-      val upperBound = if (i != numPartitions - 1) s"$column < $currentValue" else null
+      val uBound = if (i != numPartitions - 1) s"$column < $currentValue" else null
       val whereClause =
-        if (upperBound == null) {
-          lowerBound
-        } else if (lowerBound == null) {
-          s"$upperBound or $column is null"
+        if (uBound == null) {
+          lBound
+        } else if (lBound == null) {
+          s"$uBound or $column is null"
         } else {
-          s"$lowerBound AND $upperBound"
+          s"$lBound AND $uBound"
         }
       ans += JDBCPartition(whereClause, i)
       i = i + 1
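For intuition, here is a minimal, dependency-free sketch of the clause generation this hunk ends up with. `PartitionSketch`, its `whereClauses` helper, and the sample column `id` with bounds 0..4 are illustrative only, not part of the PR:

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical sketch of the patched columnPartition logic, without the
// Spark Partition machinery. Names and sample values are made up.
object PartitionSketch {
  def whereClauses(column: String, lower: Long, upper: Long, requested: Int): Seq[String] = {
    // Mirrors the new require(...) guard.
    require(lower <= upper, s"Lower bound: $lower; Upper bound: $upper")
    // Mirrors the new early return: one partition, no WHERE clause.
    if (requested <= 1 || lower == upper) return Seq(null)
    // Mirrors the new clamping: never more partitions than (upper - lower).
    val n = math.min(requested.toLong, upper - lower)
    // Divide before subtracting, as the patch does, to avoid Long overflow.
    val stride = upper / n - lower / n
    var current = lower
    val ans = new ArrayBuffer[String]()
    var i = 0L
    while (i < n) {
      val lBound = if (i != 0) s"$column >= $current" else null
      current += stride
      val uBound = if (i != n - 1) s"$column < $current" else null
      ans += (if (uBound == null) lBound
              else if (lBound == null) s"$uBound or $column is null"
              else s"$lBound AND $uBound")
      i += 1
    }
    ans
  }

  def main(args: Array[String]): Unit = {
    // 10 partitions requested over bounds 0..4: only 4 clauses come back.
    whereClauses("id", 0L, 4L, 10).foreach(println)
    // id < 1 or id is null
    // id >= 1 AND id < 2
    // id >= 2 AND id < 3
    // id >= 3
  }
}
```

Computing the stride as `upper / n - lower / n` rather than `(upper - lower) / n` is the overflow dodge the pre-existing comment alludes to; the cost is a little rounding, and the last partition absorbs any remainder through its open upper bound.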
sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala (65 additions, 0 deletions)
@@ -184,6 +184,16 @@ class JDBCSuite extends SparkFunSuite
"insert into test.emp values ('kathy', null, null)").executeUpdate()
conn.commit()

conn.prepareStatement(
"create table test.seq(id INTEGER)").executeUpdate()
(0 to 6).foreach { value =>
conn.prepareStatement(
s"insert into test.seq values ($value)").executeUpdate()
}
conn.prepareStatement(
"insert into test.seq values (null)").executeUpdate()
conn.commit()

sql(
s"""
|CREATE TEMPORARY TABLE nullparts
@@ -373,6 +383,61 @@ class JDBCSuite extends SparkFunSuite
       .collect().length === 4)
   }
 
+  test("Partitioning on column where numPartitions is zero") {
+    val res = spark.read.jdbc(
+      url = urlWithUserAndPass,
+      table = "TEST.seq",
+      columnName = "id",
+      lowerBound = 0,
+      upperBound = 4,
+      numPartitions = 0,
+      connectionProperties = new Properties
+    )
+    assert(res.count() === 8)
+  }
+
+  test("Partitioning on column where numPartitions are more than the number of total rows") {
+    val res = spark.read.jdbc(
+      url = urlWithUserAndPass,
+      table = "TEST.seq",
+      columnName = "id",
+      lowerBound = 1,
+      upperBound = 5,
+      numPartitions = 10,
+      connectionProperties = new Properties
+    )
+    assert(res.count() === 8)
+  }
+
+  test("Partitioning on column where lowerBound is equal to upperBound") {
+    val res = spark.read.jdbc(
+      url = urlWithUserAndPass,
+      table = "TEST.seq",
+      columnName = "id",
+      lowerBound = 5,
+      upperBound = 5,
+      numPartitions = 4,
+      connectionProperties = new Properties
+    )
+    assert(res.count() === 8)
+  }
+
+  test("Partitioning on column where lowerBound is larger than upperBound") {
+    val e = intercept[IllegalArgumentException] {
+      spark.read.jdbc(
+        url = urlWithUserAndPass,
+        table = "TEST.seq",
+        columnName = "id",
+        lowerBound = 5,
+        upperBound = 1,
+        numPartitions = 3,
+        connectionProperties = new Properties
+      )
+    }.getMessage
+    assert(e.contains("Operation not allowed: the lower bound of partitioning column " +
+      "is larger than the upper bound. Lower bound: 5; Upper bound: 1"))
+  }
+
   test("SELECT * on partitioned table with a nullable partition column") {
     assert(sql("SELECT * FROM nullparts").collect().size == 4)
   }
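Note the fixture arithmetic behind these assertions: the new setup leaves TEST.seq with eight rows (the integers 0 through 6 plus one NULL), so every partitioning variant must still report count() === 8. Below is a sketch of the same end-to-end behaviour outside the suite; it assumes an H2 driver on the classpath and a TEST.seq table like the fixture's, and the URL is a placeholder, not the suite's `urlWithUserAndPass`:

```scala
import java.util.Properties

import org.apache.spark.sql.SparkSession

// Illustrative only: assumes TEST.seq exists behind `url` with ids 0..6
// plus one NULL row, as in the suite's fixture above.
object ReducedPartitionsDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("jdbc-demo").getOrCreate()
    val url = "jdbc:h2:mem:testdb0;user=testUser;password=testPass" // placeholder

    // Requests 10 partitions, but upperBound - lowerBound is only 4, so the
    // reader should log the new warning and fall back to 4 partitions.
    val df = spark.read.jdbc(
      url = url,
      table = "TEST.seq",
      columnName = "id",
      lowerBound = 1,
      upperBound = 5,
      numPartitions = 10,
      connectionProperties = new Properties)

    println(df.rdd.getNumPartitions) // expected: 4, not 10
    println(df.count())              // expected: 8 (NULL ids fall in partition 0)
    spark.stop()
  }
}
```

The first generated partition carries `or id is null`, which is why the NULL row is never dropped — the same property the pre-existing `nullparts` test guards.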