Skip to content

Commit

Permalink
Cast integer to Long to avoid integer overflow in the IntegralDivide operator.
Browse files Browse the repository at this point in the history
Handled cast in TypeCoercion
  • Loading branch information
sandeep-katta committed May 22, 2020
1 parent cac15d9 commit b2eea1f
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 28 deletions.
Expand Up @@ -61,6 +61,7 @@ object TypeCoercion {
IfCoercion ::
StackCoercion ::
Division ::
IntegralDivision ::
ImplicitTypeCasts ::
DateTimeOperations ::
WindowFrameCoercion ::
Expand Down Expand Up @@ -684,6 +685,24 @@ object TypeCoercion {
}
}

/**
 * Widens the operands of `IntegralDivide` to Long.
 *
 * Since SPARK-16323 the integral division operator returns Long, so operands of type Byte,
 * Short and Integer are cast to Long. Dividing as Long avoids overflow in the narrower type,
 * e.g. `-2147483648 DIV -1`, whose result does not fit in an Integer.
 */
object IntegralDivision extends TypeCoercionRule {
  override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
    // Skip expressions whose children are not resolved yet.
    case e if !e.childrenResolved => e
    // Use `copy` so that only the two operands are replaced on the matched node.
    case d @ IntegralDivide(left, right) =>
      d.copy(left = castToLong(left), right = castToLong(right))
  }

  /**
   * Wraps `expr` in a `Cast` to `LongType` when its data type is Byte, Short or Integer.
   * Any other expression is returned unchanged.
   */
  def castToLong(expr: Expression): Expression = expr.dataType match {
    case _: ByteType | _: ShortType | _: IntegerType => Cast(expr, LongType)
    case _ => expr
  }
}

/**
* Coerces the type of different branches of a CASE WHEN statement to a common type.
*/
Expand Down
Expand Up @@ -17,13 +17,10 @@

package org.apache.spark.sql.catalyst.expressions

import scala.math.Integral

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator.JAVA_LONG
import org.apache.spark.sql.catalyst.util.{IntervalUtils, TypeUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -340,19 +337,11 @@ trait DivModLike extends BinaryArithmetic {
} else {
s"${eval2.value} == 0"
}
val isIntegralDiv = this.isInstanceOf[IntegralDivide]
// From SPARK-16323 IntegralDivision returns Long data type
val javaType = if (isIntegralDiv) JAVA_LONG else CodeGenerator.javaType(dataType)
val operandJavaType = if (isIntegralDiv) operandsDataType match {
case _: IntegerType => JAVA_LONG
case other => CodeGenerator.javaType(other)
} else {
CodeGenerator.javaType(operandsDataType)
}
val javaType = CodeGenerator.javaType(dataType)
val operation = if (operandsDataType.isInstanceOf[DecimalType]) {
decimalToDataTypeCodeGen(s"${eval1.value}.$decimalMethod(${eval2.value})")
} else {
s"($javaType)(((${operandJavaType})(${eval1.value})) $symbol ${eval2.value})"
s"($javaType)(${eval1.value} $symbol ${eval2.value})"
}
if (!left.nullable && !right.nullable) {
ev.copy(code = code"""
Expand Down Expand Up @@ -434,21 +423,13 @@ case class IntegralDivide(

private lazy val div: (Any, Any) => Any = {
val integral = left.dataType match {
// if it is of Integer type than cast it to Long
case _: IntegerType =>
implicitly[Integral[Long]].asInstanceOf[Integral[Any]]
case i: IntegralType =>
i.integral.asInstanceOf[Integral[Any]]
case d: DecimalType =>
d.asIntegral.asInstanceOf[Integral[Any]]
}
(x, y) => {
def convertToLong(data: Any): Any = data match {
case i: Integer => i.toLong
case other => other
}

val res = integral.quot(convertToLong(x), convertToLong(y))
val res = integral.quot(x, y)
if (res == null) {
null
} else {
Expand Down
Expand Up @@ -505,9 +505,4 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(e6, 0.toByte)
}
}

// Integer.MIN_VALUE div -1 does not fit in an Int, so the expected value is computed as a Long.
test("SPARK-31761: test integer overflow for (Divide) integral type ") {
  val dividend = Literal(Integer.MIN_VALUE)
  val divisor = Literal(-1)
  val expected = Integer.MIN_VALUE.toLong * -1
  checkEvaluation(IntegralDivide(dividend, divisor), expected)
}
}
Expand Up @@ -3495,6 +3495,13 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
assert(df4.schema.head.name === "randn(1)")
checkIfSeedExistsInExplain(df2)
}
}

test("SPARK-31761: test byte, short, integer overflow for (Divide) integral type ") {
  // Each query divides the minimum value of an integral type by -1; the expected quotient is
  // computed as a Long because it does not fit in the operand type.
  val queriesWithExpected = Seq(
    ("Select -2147483648 DIV -1", Integer.MIN_VALUE.toLong * -1),
    ("select CAST(-128 as Byte) DIV CAST (-1 as Byte)", Byte.MinValue.toLong * -1),
    ("select CAST(-32768 as short) DIV CAST (-1 as short)", Short.MinValue.toLong * -1))
  queriesWithExpected.foreach { case (query, expected) =>
    checkAnswer(sql(query), Seq(Row(expected)))
  }
}
}
case class Foo(bar: Option[String])

0 comments on commit b2eea1f

Please sign in to comment.