Skip to content

Commit

Permalink
Add spark.sql.binary.comparison.compatible.with.hive conf.
Browse files Browse the repository at this point in the history
  • Loading branch information
wangyum committed Sep 18, 2017
1 parent 3bec6a2 commit 844aec7
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 13 deletions.
7 changes: 7 additions & 0 deletions docs/sql-programming-guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -968,6 +968,13 @@ Configuration of Parquet can be done using the `setConf` method on `SparkSession
</p>
</td>
</tr>
<tr>
<td><code>spark.sql.binary.comparison.compatible.with.hive</code></td>
<td>true</td>
<td>
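When true, operands of different types in a binary comparison are coerced to a common type using Hive-compatible rules; when false, Spark's own coercion rules for binary comparisons are used.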
</td>
</tr>
</table>

## JSON Datasets
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,29 @@ object TypeCoercion {
* other is a Timestamp by making the target type to be String.
*/
val findCommonTypeForBinaryComparison: (DataType, DataType) => Option[DataType] = {
  // Coercion rules applied to binary comparisons when Hive-compatible comparison is
  // switched off.
  // Every pairing among string, date and timestamp (and string against null) is compared
  // as strings. This behaves as a user would expect because timestamp strings sort
  // lexicographically, i.e. TimeStamp(2013-01-01 00:00 ...) < "2014" = true
  // NOTE(review): an earlier comment on this value said "We should follow hive" and linked
  // https://github.com/apache/hive/blob/rel/storage-release-2.4.0/ql/src/java/
  // org/apache/hadoop/hive/ql/exec/FunctionRegistry.java#L781
  // but the rules below return StringType for string/date/timestamp pairs, so that link
  // looks like a point of contrast rather than the behaviour implemented here -- confirm.
  case (StringType, DateType | TimestampType | NullType) => Some(StringType)
  case (DateType | TimestampType | NullType, StringType) => Some(StringType)
  case (TimestampType, DateType) | (DateType, TimestampType) => Some(StringType)
  // A string compared with any other atomic type takes on that other type.
  case (_: StringType, other: AtomicType) if other != StringType => Some(other)
  case (other: AtomicType, _: StringType) if other != StringType => Some(other)
  // No special comparison rule for this pair.
  case _ => None
}

/**
* Finds the common type for a binary comparison by following Hive's coercion rules, see:
* https://github.com/apache/hive/blob/rel/storage-release-2.4.0/ql/src/java/
* org/apache/hadoop/hive/ql/exec/FunctionRegistry.java#L781
*/
val findCommonTypeCompatibleWithHive: (DataType, DataType) =>
Option[DataType] = {
case (StringType, DateType) => Some(DateType)
case (DateType, StringType) => Some(DateType)
case (StringType, TimestampType) => Some(TimestampType)
Expand Down Expand Up @@ -355,9 +375,15 @@ object TypeCoercion {
case p @ Equality(left @ TimestampType(), right @ StringType()) =>
p.makeCopy(Array(left, Cast(right, TimestampType)))
case p @ BinaryComparison(left, right)
if findCommonTypeForBinaryComparison(left.dataType, right.dataType).isDefined =>
if !plan.conf.binaryComparisonCompatibleWithHive &&
findCommonTypeForBinaryComparison(left.dataType, right.dataType).isDefined =>
val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType).get
p.makeCopy(Array(castExpr(left, commonType), castExpr(right, commonType)))
case p @ BinaryComparison(left, right)
if plan.conf.binaryComparisonCompatibleWithHive &&
findCommonTypeCompatibleWithHive(left.dataType, right.dataType).isDefined =>
val commonType = findCommonTypeCompatibleWithHive(left.dataType, right.dataType).get
p.makeCopy(Array(castExpr(left, commonType), castExpr(right, commonType)))

case Abs(e @ StringType()) => Abs(Cast(e, DoubleType))
case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
Expand Down Expand Up @@ -412,8 +438,13 @@ object TypeCoercion {
val rhs = sub.output

val commonTypes = lhs.zip(rhs).flatMap { case (l, r) =>
findCommonTypeForBinaryComparison(l.dataType, r.dataType)
.orElse(findTightestCommonType(l.dataType, r.dataType))
if (plan.conf.binaryComparisonCompatibleWithHive) {
findCommonTypeCompatibleWithHive(l.dataType, r.dataType)
.orElse(findTightestCommonType(l.dataType, r.dataType))
} else {
findCommonTypeForBinaryComparison(l.dataType, r.dataType)
.orElse(findTightestCommonType(l.dataType, r.dataType))
}
}

// The number of columns/expressions must match between LHS and RHS of an
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -925,6 +925,12 @@ object SQLConf {
.intConf
.createWithDefault(10000)

// Switch read by the type-coercion rules to pick which set of common-type rules is applied
// to the two sides of a binary comparison. Defaults to the Hive-compatible rules.
val BINARY_COMPARISON_COMPATIBLE_WITH_HIVE =
  buildConf("spark.sql.binary.comparison.compatible.with.hive")
    .doc("When true, operands of different types in a binary comparison are coerced to a " +
      "common type using Hive-compatible rules. When false, Spark's own coercion rules " +
      "for binary comparisons are used.")
    .booleanConf
    .createWithDefault(true)

object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
}
Expand Down Expand Up @@ -1203,6 +1209,9 @@ class SQLConf extends Serializable with Logging {

def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH)

/** Current value of the `spark.sql.binary.comparison.compatible.with.hive` setting. */
def binaryComparisonCompatibleWithHive: Boolean =
  getConf(BINARY_COMPARISON_COMPATIBLE_WITH_HIVE)

/** ********************** SQLConf functionality methods ************ */

/** Set Spark SQL configuration properties. */
Expand Down
34 changes: 27 additions & 7 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2684,18 +2684,38 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
val str2 = Int.MaxValue.toString + "1"
val str3 = "10"
Seq(str1, str2, str3).toDF("c1").createOrReplaceTempView("v")
checkAnswer(sql("SELECT c1 from v where c1 > 0"), Row(str1) :: Row(str2) :: Row(str3) :: Nil)
checkAnswer(sql("SELECT c1 from v where c1 > 0L"), Row(str1) :: Row(str2) :: Row(str3) :: Nil)
withSQLConf(SQLConf.BINARY_COMPARISON_COMPATIBLE_WITH_HIVE.key -> "true") {
checkAnswer(sql("SELECT c1 from v where c1 > 0"),
Row(str1) :: Row(str2) :: Row(str3) :: Nil)
checkAnswer(sql("SELECT c1 from v where c1 > 0L"),
Row(str1) :: Row(str2) :: Row(str3) :: Nil)
}

withSQLConf(SQLConf.BINARY_COMPARISON_COMPATIBLE_WITH_HIVE.key -> "false") {
checkAnswer(sql("SELECT c1 from v where c1 > 0"), Row(str3) :: Nil)
checkAnswer(sql("SELECT c1 from v where c1 > 0L"), Row(str2) :: Row(str3) :: Nil)
}
}
}

// Checks how a string column compares against numeric literals under each setting of
// SQLConf.BINARY_COMPARISON_COMPATIBLE_WITH_HIVE.
test("SPARK-21646: CommonTypeForBinaryComparison: DoubleType vs IntegerType") {
withTempView("v") {
// NOTE(review): the view "v" is registered here and again a few lines below with one
// extra row; the checks in between run outside any withSQLConf block, i.e. under the
// session's current setting -- confirm both registrations are intended.
Seq(("0", 1), ("-0.4", 2)).toDF("a", "b").createOrReplaceTempView("v")
checkAnswer(sql("SELECT a FROM v WHERE a=0"), Seq(Row("0")))
checkAnswer(sql("SELECT a FROM v WHERE a=0L"), Seq(Row("0")))
checkAnswer(sql("SELECT a FROM v WHERE a=0.0"), Seq(Row("0")))
checkAnswer(sql("SELECT a FROM v WHERE a=-0.4"), Seq(Row("-0.4")))
Seq(("0", 1), ("-0.4", 2), ("0.6", 3)).toDF("a", "b").createOrReplaceTempView("v")
// Flag on: each literal is expected to match only the row whose string is numerically
// equal to it, and only "0.6" is expected to count as greater than 0.
withSQLConf(SQLConf.BINARY_COMPARISON_COMPATIBLE_WITH_HIVE.key -> "true") {
checkAnswer(sql("SELECT a FROM v WHERE a = 0"), Seq(Row("0")))
checkAnswer(sql("SELECT a FROM v WHERE a = 0L"), Seq(Row("0")))
checkAnswer(sql("SELECT a FROM v WHERE a = 0.0"), Seq(Row("0")))
checkAnswer(sql("SELECT a FROM v WHERE a = -0.4"), Seq(Row("-0.4")))
checkAnswer(sql("SELECT count(*) FROM v WHERE a > 0"), Row(1) :: Nil)
}

// Flag off: comparing against the integer/long literal 0 is expected to match all three
// rows, and `a > 0` is expected to match none -- presumably the strings are cast to the
// literal's integral type and lose their fractional part; verify against TypeCoercion.
withSQLConf(SQLConf.BINARY_COMPARISON_COMPATIBLE_WITH_HIVE.key -> "false") {
checkAnswer(sql("SELECT a FROM v WHERE a = 0"), Seq(Row("0"), Row("-0.4"), Row("0.6")))
checkAnswer(sql("SELECT a FROM v WHERE a = 0L"), Seq(Row("0"), Row("-0.4"), Row("0.6")))
checkAnswer(sql("SELECT a FROM v WHERE a = 0.0"), Seq(Row("0")))
checkAnswer(sql("SELECT a FROM v WHERE a = -0.4"), Seq(Row("-0.4")))
checkAnswer(sql("SELECT count(*) FROM v WHERE a > 0"), Row(0) :: Nil)
}
}
}

Expand Down

0 comments on commit 844aec7

Please sign in to comment.