diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 73d4ad0826c56..9d1e96572a26d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -244,11 +244,8 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic } else { s"${eval2.primitive} == 0" } - val method = if (left.dataType.isInstanceOf[DecimalType]) { - s".$decimalMethod" - } else { - s"$symbol" - } + val method = if (left.dataType.isInstanceOf[DecimalType]) s".$decimalMethod" else s" $symbol " + val javaType = ctx.javaType(left.dataType) eval1.code + eval2.code + s""" boolean ${ev.isNull} = false; @@ -256,8 +253,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic if (${eval1.isNull} || ${eval2.isNull} || $test) { ${ev.isNull} = true; } else { - ${ev.primitive} = - (${ctx.javaType(left.dataType)})(${eval1.primitive}$method(${eval2.primitive})); + ${ev.primitive} = ($javaType) (${eval1.primitive}$method(${eval2.primitive})); } """ } @@ -305,11 +301,8 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet } else { s"${eval2.primitive} == 0" } - val method = if (left.dataType.isInstanceOf[DecimalType]) { - s".$decimalMethod" - } else { - s"$symbol" - } + val method = if (left.dataType.isInstanceOf[DecimalType]) s".$decimalMethod" else s" $symbol " + val javaType = ctx.javaType(left.dataType) eval1.code + eval2.code + s""" boolean ${ev.isNull} = false; @@ -317,7 +310,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet if (${eval1.isNull} || ${eval2.isNull} || $test) { ${ev.isNull} = true; } else { - ${ev.primitive} = ${eval1.primitive}$method(${eval2.primitive}); + ${ev.primitive} = ($javaType) (${eval1.primitive}$method(${eval2.primitive})); } """ } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index a56a5ae70a95f..3f4843259e80b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -104,6 +104,17 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Divide(Literal(Decimal(1.0)), Literal(Decimal(2.0))), Decimal(0.5)) } + test("% (Remainder)") { + testNumericDataTypes { convert => + val left = Literal(convert(1)) + val right = Literal(convert(2)) + checkEvaluation(Remainder(left, right), convert(1)) + checkEvaluation(Remainder(Literal.create(null, left.dataType), right), null) + checkEvaluation(Remainder(left, Literal.create(null, right.dataType)), null) + checkEvaluation(Remainder(left, Literal(convert(0))), null) // mod by 0 + } + } + test("Abs") { testNumericDataTypes { convert => checkEvaluation(Abs(Literal(convert(0))), convert(0)) @@ -112,38 +123,6 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper } } - test("Divide") { - checkEvaluation(Divide(Literal(2), Literal(1)), 2) - checkEvaluation(Divide(Literal(1.0), Literal(2.0)), 0.5) - checkEvaluation(Divide(Literal(1), Literal(2)), 0) - checkEvaluation(Divide(Literal(1), Literal(0)), null) - checkEvaluation(Divide(Literal(1.0), Literal(0.0)), null) - checkEvaluation(Divide(Literal(0.0), Literal(0.0)), null) - checkEvaluation(Divide(Literal(0), Literal.create(null, IntegerType)), null) - checkEvaluation(Divide(Literal(1), Literal.create(null, IntegerType)), null) - checkEvaluation(Divide(Literal.create(null, IntegerType), Literal(0)), null) - checkEvaluation(Divide(Literal.create(null, DoubleType), Literal(0.0)), null) - checkEvaluation(Divide(Literal.create(null, IntegerType), Literal(1)), null) - checkEvaluation(Divide(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), - null) - } - - test("Remainder") { - checkEvaluation(Remainder(Literal(2), Literal(1)), 0) - checkEvaluation(Remainder(Literal(1.0), Literal(2.0)), 1.0) - checkEvaluation(Remainder(Literal(1), Literal(2)), 1) - checkEvaluation(Remainder(Literal(1), Literal(0)), null) - checkEvaluation(Remainder(Literal(1.0), Literal(0.0)), null) - checkEvaluation(Remainder(Literal(0.0), Literal(0.0)), null) - checkEvaluation(Remainder(Literal(0), Literal.create(null, IntegerType)), null) - checkEvaluation(Remainder(Literal(1), Literal.create(null, IntegerType)), null) - checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal(0)), null) - checkEvaluation(Remainder(Literal.create(null, DoubleType), Literal(0.0)), null) - checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal(1)), null) - checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), - null) - } - test("MaxOf") { checkEvaluation(MaxOf(1, 2), 2) checkEvaluation(MaxOf(2, 1), 2)