From 7ec8f7fb79474f894f83b047974f8a4bdcde8e68 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Tue, 21 Apr 2015 11:36:39 -0700 Subject: [PATCH 01/12] Added math functions for DataFrames added rint toDeg and toRad --- .../sql/catalyst/expressions/Expression.scala | 6 + .../catalyst/expressions/mathfunctions.scala | 272 +++++++++++ .../org/apache/spark/sql/functions.scala | 2 +- .../org/apache/spark/sql/mathfunctions.scala | 455 ++++++++++++++++++ .../spark/sql/ColumnExpressionSuite.scala | 173 ++++++- .../scala/org/apache/spark/sql/TestData.scala | 5 + 6 files changed, 911 insertions(+), 2 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfunctions.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 4e3bbc06a5b4c..b56a1815b9037 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -89,6 +89,12 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express override def toString: String = s"($left $symbol $right)" } +abstract class BinaryFunctionExpression extends Expression with trees.BinaryNode[Expression] { + self: Product => + + override def foldable: Boolean = left.foldable && right.foldable +} + abstract class LeafExpression extends Expression with trees.LeafNode[Expression] { self: Product => } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfunctions.scala new file mode 100644 index 0000000000000..a499bc4441f3d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfunctions.scala @@ -0,0 +1,272 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.analysis.UnresolvedException +import org.apache.spark.sql.types._ + +trait MathematicalExpression extends UnaryExpression with Serializable { self: Product => + type EvaluatedType = Any + + override def dataType: DataType = DoubleType + override def foldable: Boolean = child.foldable + override def nullable: Boolean = true + + lazy val numeric = child.dataType match { + case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] + case other => sys.error(s"Type $other does not support numeric operations") + } +} + +abstract class MathematicalExpressionForDouble(f: Double => Double) + extends MathematicalExpression { self: Product => + + override def eval(input: Row): Any = { + val evalE = child.eval(input) + if (evalE == null) { + null + } else { + f(numeric.toDouble(evalE)) + } + } +} + +abstract class MathematicalExpressionForInt(f: Int => Int) + extends MathematicalExpression { self: Product => + + override def dataType: DataType = IntegerType + + override def eval(input: Row): Any = { + val evalE = child.eval(input) + if (evalE == null) { + null + } else { + f(numeric.toInt(evalE)) + } + } +} + +abstract class MathematicalExpressionForFloat(f: Float => Float) + extends MathematicalExpression { self: Product => + + override def dataType: DataType = FloatType + + override def eval(input: Row): Any = { + val evalE = child.eval(input) + if (evalE == null) { + null + } else { + f(numeric.toFloat(evalE)) + } + } +} + +abstract class MathematicalExpressionForLong(f: Long => Long) + extends MathematicalExpression { self: Product => + + override def dataType: DataType = LongType + + override def eval(input: Row): Any = { + val evalE = child.eval(input) + if (evalE == null) { + null + } else { + f(numeric.toLong(evalE)) + } + } +} + +case class Sin(child: Expression) extends MathematicalExpressionForDouble(math.sin) { + override def toString: String = s"SIN($child)" +} + +case class Asin(child: Expression) extends MathematicalExpressionForDouble(math.asin) { + override def toString: String = s"ASIN($child)" +} + +case class Sinh(child: Expression) extends MathematicalExpressionForDouble(math.sinh) { + override def toString: String = s"SINH($child)" +} + +case class Cos(child: Expression) extends MathematicalExpressionForDouble(math.cos) { + override def toString: String = s"COS($child)" +} + +case class Acos(child: Expression) extends MathematicalExpressionForDouble(math.acos) { + override def toString: String = s"ACOS($child)" +} + +case class Cosh(child: Expression) extends MathematicalExpressionForDouble(math.cosh) { + override def toString: String = s"COSH($child)" +} + +case class Tan(child: Expression) extends MathematicalExpressionForDouble(math.tan) { + override def toString: String = s"TAN($child)" +} + +case class Atan(child: Expression) extends MathematicalExpressionForDouble(math.atan) { + override def toString: String = s"ATAN($child)" +} + +case class Tanh(child: Expression) extends MathematicalExpressionForDouble(math.tanh) { + override def toString: String = s"TANH($child)" +} + +case class Ceil(child: Expression) extends MathematicalExpressionForDouble(math.ceil) { + override def toString: String = s"CEIL($child)" +} + +case class Floor(child: Expression) extends MathematicalExpressionForDouble(math.floor) { + override def toString: String = s"FLOOR($child)" +} + +case class Rint(child: Expression) extends MathematicalExpressionForDouble(math.rint) { + override def toString: String 
= s"RINT($child)" +} + +case class Cbrt(child: Expression) extends MathematicalExpressionForDouble(math.cbrt) { + override def toString: String = s"CBRT($child)" +} + +case class Signum(child: Expression) extends MathematicalExpressionForDouble(math.signum) { + override def toString: String = s"SIGNUM($child)" +} + +case class ISignum(child: Expression) extends MathematicalExpressionForInt(math.signum) { + override def toString: String = s"ISIGNUM($child)" +} + +case class FSignum(child: Expression) extends MathematicalExpressionForFloat(math.signum) { + override def toString: String = s"FSIGNUM($child)" +} + +case class LSignum(child: Expression) extends MathematicalExpressionForLong(math.signum) { + override def toString: String = s"LSIGNUM($child)" +} + +case class ToDegrees(child: Expression) extends MathematicalExpressionForDouble(math.toDegrees) { + override def toString: String = s"TODEG($child)" +} + +case class ToRadians(child: Expression) extends MathematicalExpressionForDouble(math.toRadians) { + override def toString: String = s"TORAD($child)" +} + +case class Log(child: Expression) extends MathematicalExpressionForDouble(math.log) { + override def toString: String = s"LOG($child)" + + override def eval(input: Row): Any = { + val evalE = child.eval(input) + if (evalE == null) { + null + } else { + val value = numeric.toDouble(evalE) + if (value < 0) null + else math.log(value) + } + } +} + +case class Log10(child: Expression) extends MathematicalExpressionForDouble(math.log10) { + override def toString: String = s"LOG10($child)" + + override def eval(input: Row): Any = { + val evalE = child.eval(input) + if (evalE == null) { + null + } else { + val value = numeric.toDouble(evalE) + if (value < 0) null + else math.log10(value) + } + } +} + +case class Log1p(child: Expression) extends MathematicalExpressionForDouble(math.log1p) { + override def toString: String = s"LOG1P($child)" + + override def eval(input: Row): Any = { + val evalE = child.eval(input) + if (evalE == null) { + null + } else { + val value = numeric.toDouble(evalE) + if (value < -1) null + else math.log1p(value) + } + } +} + +case class Exp(child: Expression) extends MathematicalExpressionForDouble(math.exp) { + override def toString: String = s"EXP($child)" +} + +case class Expm1(child: Expression) extends MathematicalExpressionForDouble(math.expm1) { + override def toString: String = s"EXPM1($child)" +} + +abstract class BinaryMathExpression(f: (Double, Double) => Double) + extends BinaryFunctionExpression with Serializable { self: Product => + type EvaluatedType = Any + + def nullable: Boolean = left.nullable || right.nullable + + override lazy val resolved = + left.resolved && right.resolved && + left.dataType == right.dataType && + !DecimalType.isFixed(left.dataType) + + def dataType: DataType = { + if (!resolved) { + throw new UnresolvedException(this, + s"datatype. 
Can not resolve due to differing types ${left.dataType}, ${right.dataType}") + } + left.dataType + } + + lazy val numeric = dataType match { + case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] + case other => sys.error(s"Type $other does not support numeric operations") + } + + override def eval(input: Row): Any = { + val evalE1 = left.eval(input) + if(evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + f(numeric.toDouble(evalE1), numeric.toDouble(evalE2)) + } + } + } +} + +case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow) { + override def toString: String = s"POW($left, $right)" +} + +case class Hypot(left: Expression, right: Expression) extends BinaryMathExpression(math.hypot) { + override def toString: String = s"HYPOT($left, $right)" +} + +case class Atan2(left: Expression, right: Expression) extends BinaryMathExpression(math.atan2) { + override def toString: String = s"ATAN2($left, $right)" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index ff91e1d74bc2c..4c1862dd65f6e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -338,7 +338,7 @@ object functions { def sqrt(e: Column): Column = Sqrt(e.expr) /** - * Computes the absolutle value. + * Computes the absolute value. * * @group normal_funcs */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala new file mode 100644 index 0000000000000..6b1035bcfaeb1 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala @@ -0,0 +1,455 @@ +package org.apache.spark.sql + +import scala.language.implicitConversions + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.expressions._ + +/** + * :: Experimental :: + * Mathematical Functions available for [[DataFrame]]. + * + * @groupname double_funcs Functions that require DoubleType as an input + * @groupname int_funcs Functions that require IntegerType as an input + * @groupname float_funcs Functions that require FloatType as an input + * @groupname long_funcs Functions that require LongType as an input + */ +@Experimental +// scalastyle:off +object mathfunctions { +// scalastyle:on + + private[this] implicit def toColumn(expr: Expression): Column = Column(expr) + + /** + * Computes the sine of the given value. + * + * @group double_funcs + */ + def sin(e: Column): Column = Sin(e.expr) + + /** + * Computes the sine of the given column. + * + * @group double_funcs + */ + def sin(columnName: String): Column = sin(Column(columnName)) + + /** + * Computes the sine inverse of the given value; the returned angle is in the range + * -pi/2 through pi/2. + * + * @group double_funcs + */ + def asin(e: Column): Column = Asin(e.expr) + + /** + * Computes the sine inverse of the given column; the returned angle is in the range + * -pi/2 through pi/2. + * + * @group double_funcs + */ + def asin(columnName: String): Column = asin(Column(columnName)) + + /** + * Computes the hyperbolic sine of the given value. + * + * @group double_funcs + */ + def sinh(e: Column): Column = Sinh(e.expr) + + /** + * Computes the hyperbolic sine of the given column. 
+ * + * @group double_funcs + */ + def sinh(columnName: String): Column = sinh(Column(columnName)) + + /** + * Computes the cosine of the given value. + * + * @group double_funcs + */ + def cos(e: Column): Column = Cos(e.expr) + + /** + * Computes the cosine of the given column. + * + * @group double_funcs + */ + def cos(columnName: String): Column = cos(Column(columnName)) + + /** + * Computes the cosine inverse of the given value; the returned angle is in the range + * 0.0 through pi. + * + * @group double_funcs + */ + def acos(e: Column): Column = Acos(e.expr) + + /** + * Computes the cosine inverse of the given column; the returned angle is in the range + * 0.0 through pi. + * + * @group double_funcs + */ + def acos(columnName: String): Column = acos(Column(columnName)) + + /** + * Computes the hyperbolic cosine of the given value. + * + * @group double_funcs + */ + def cosh(e: Column): Column = Cosh(e.expr) + + /** + * Computes the hyperbolic cosine of the given column. + * + * @group double_funcs + */ + def cosh(columnName: String): Column = cosh(Column(columnName)) + + /** + * Computes the tangent of the given value. + * + * @group double_funcs + */ + def tan(e: Column): Column = Tan(e.expr) + + /** + * Computes the tangent of the given column. + * + * @group double_funcs + */ + def tan(columnName: String): Column = tan(Column(columnName)) + + /** + * Computes the tangent inverse of the given value. + * + * @group double_funcs + */ + def atan(e: Column): Column = Atan(e.expr) + + /** + * Computes the tangent inverse of the given column. + * + * @group double_funcs + */ + def atan(columnName: String): Column = atan(Column(columnName)) + + /** + * Computes the hyperbolic tangent of the given value. + * + * @group double_funcs + */ + def tanh(e: Column): Column = Tanh(e.expr) + + /** + * Computes the hyperbolic tangent of the given column. + * + * @group double_funcs + */ + def tanh(columnName: String): Column = tanh(Column(columnName)) + + /** + * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. + * + * @group double_funcs + */ + def toDeg(e: Column): Column = ToDegrees(e.expr) + + /** + * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. + * + * @group double_funcs + */ + def toDeg(columnName: String): Column = toDeg(Column(columnName)) + + /** + * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. + * + * @group double_funcs + */ + def toRad(e: Column): Column = ToRadians(e.expr) + + /** + * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. + * + * @group double_funcs + */ + def toRad(columnName: String): Column = toRad(Column(columnName)) + + /** + * Computes the ceiling of the given value. + * + * @group double_funcs + */ + def ceil(e: Column): Column = Ceil(e.expr) + + /** + * Computes the ceiling of the given column. + * + * @group double_funcs + */ + def ceil(columnName: String): Column = ceil(Column(columnName)) + + /** + * Computes the floor of the given value. + * + * @group double_funcs + */ + def floor(e: Column): Column = Floor(e.expr) + + /** + * Computes the floor of the given column. + * + * @group double_funcs + */ + def floor(columnName: String): Column = floor(Column(columnName)) + + /** + * Returns the double value that is closest in value to the argument and + * is equal to a mathematical integer. 
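+ * Ties between two integers are resolved toward the even one, matching the behaviour
+ * of `java.lang.Math.rint`.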
+ * + * @group double_funcs + */ + def rint(e: Column): Column = Rint(e.expr) + + /** + * Returns the double value that is closest in value to the argument and + * is equal to a mathematical integer. + * + * @group double_funcs + */ + def rint(columnName: String): Column = rint(Column(columnName)) + + /** + * Computes the cube-root of the given value. + * + * @group double_funcs + */ + def cbrt(e: Column): Column = Cbrt(e.expr) + + /** + * Computes the cube-root of the given column. + * + * @group double_funcs + */ + def cbrt(columnName: String): Column = cbrt(Column(columnName)) + + /** + * Computes the signum of the given value. + * + * @group double_funcs + */ + def signum(e: Column): Column = Signum(e.expr) + + /** + * Computes the signum of the given column. + * + * @group double_funcs + */ + def signum(columnName: String): Column = signum(Column(columnName)) + + /** + * Computes the signum of the given value. For IntegerType. + * + * @group int_funcs + */ + def isignum(e: Column): Column = ISignum(e.expr) + + /** + * Computes the signum of the given column. For IntegerType. + * + * @group int_funcs + */ + def isignum(columnName: String): Column = isignum(Column(columnName)) + + /** + * Computes the signum of the given value. For FloatType. + * + * @group float_funcs + */ + def fsignum(e: Column): Column = FSignum(e.expr) + + /** + * Computes the signum of the given column. For FloatType. + * + * @group float_funcs + */ + def fsignum(columnName: String): Column = fsignum(Column(columnName)) + + /** + * Computes the signum of the given value. For LongType. + * + * @group long_funcs + */ + def lsignum(e: Column): Column = LSignum(e.expr) + + /** + * Computes the signum of the given column. For FloatType. + * + * @group long_funcs + */ + def lsignum(columnName: String): Column = lsignum(Column(columnName)) + + /** + * Computes the natural logarithm of the given value. + * + * @group double_funcs + */ + def log(e: Column): Column = Log(e.expr) + + /** + * Computes the natural logarithm of the given column. + * + * @group double_funcs + */ + def log(columnName: String): Column = log(Column(columnName)) + + /** + * Computes the logarithm of the given value in Base 10. + * + * @group double_funcs + */ + def log10(e: Column): Column = Log10(e.expr) + + /** + * Computes the logarithm of the given value in Base 10. + * + * @group double_funcs + */ + def log10(columnName: String): Column = log10(Column(columnName)) + + /** + * Computes the natural logarithm of the given value plus one. + * + * @group double_funcs + */ + def log1p(e: Column): Column = Log1p(e.expr) + + /** + * Computes the natural logarithm of the given column plus one. + * + * @group double_funcs + */ + def log1p(columnName: String): Column = log1p(Column(columnName)) + + /** + * Computes the exponential of the given value. + * + * @group double_funcs + */ + def exp(e: Column): Column = Exp(e.expr) + + /** + * Computes the exponential of the given column. + * + * @group double_funcs + */ + def exp(columnName: String): Column = exp(Column(columnName)) + + /** + * Computes the exponential of the given value minus one. + * + * @group double_funcs + */ + def expm1(e: Column): Column = Expm1(e.expr) + + /** + * Computes the exponential of the given column. + * + * @group double_funcs + */ + def expm1(columnName: String): Column = expm1(Column(columnName)) + + /** + * Returns the value of the first argument raised to the power of the second argument. 
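+ * For example (illustrative only, assuming a DataFrame `df` with numeric columns "a" and "b"):
+ * {{{
+ *   df.select(pow(col("a"), col("b")))
+ * }}}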
+ * + * @group double_funcs + */ + def pow(l: Column, r: Column): Column = Pow(l.expr, r.expr) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group double_funcs + */ + def pow(l: Column, rightName: String): Column = pow(l, Column(rightName)) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group double_funcs + */ + def pow(leftName: String, r: Column): Column = pow(Column(leftName), r) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group double_funcs + */ + def pow(leftName: String, rightName: String): Column = pow(Column(leftName), Column(rightName)) + + /** + * Computes sqrt(a^2^ + b^2^) without intermediate overflow or underflow. + * + * @group double_funcs + */ + def hypot(l: Column, r: Column): Column = Hypot(l.expr, r.expr) + + /** + * Computes sqrt(a^2^ + b^2^) without intermediate overflow or underflow. + * + * @group double_funcs + */ + def hypot(l: Column, rightName: String): Column = hypot(l, Column(rightName)) + + /** + * Computes sqrt(a^2^ + b^2^) without intermediate overflow or underflow. + * + * @group double_funcs + */ + def hypot(leftName: String, r: Column): Column = hypot(Column(leftName), r) + + /** + * Computes sqrt(a^2^ + b^2^) without intermediate overflow or underflow. + * + * @group double_funcs + */ + def hypot(leftName: String, rightName: String): Column = + hypot(Column(leftName), Column(rightName)) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group double_funcs + */ + def atan2(l: Column, r: Column): Column = Atan2(l.expr, r.expr) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group double_funcs + */ + def atan2(l: Column, rightName: String): Column = atan2(l, Column(rightName)) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group double_funcs + */ + def atan2(leftName: String, r: Column): Column = atan2(Column(leftName), r) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). 
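+ * As with `java.lang.Math.atan2`, the first argument is the ordinate (y coordinate)
+ * and the second argument is the abscissa (x coordinate).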
+ * + * @group double_funcs + */ + def atan2(leftName: String, rightName: String): Column = + atan2(Column(leftName), Column(rightName)) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index bc8fae100db6a..85f619fc68bcd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -18,11 +18,11 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ +import org.apache.spark.sql.mathfunctions._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ - class ColumnExpressionSuite extends QueryTest { import org.apache.spark.sql.TestData._ @@ -331,4 +331,175 @@ class ColumnExpressionSuite extends QueryTest { assert(schema("value").metadata === Metadata.empty) assert(schema("abc").metadata === metadata) } + + def testOneToOneMathFunction[@specialized(Int, Double, Float, Long) T] + (c: Column => Column, f: T => T): Unit = { + checkAnswer( + doubleData.select(c('a)).orderBy('a.asc), + (1 to 100).map(n => Row(f((n * 0.02 - 1).asInstanceOf[T]))) + ) + + checkAnswer( + doubleData.select(c('b)).orderBy('b.desc), + (1 to 100).map(n => Row(f((-n * 0.02 + 1).asInstanceOf[T]))) + ) + + checkAnswer( + doubleData.select(c(lit(null))), + (1 to 100).map(_ => Row(null)) + ) + } + + test("sin") { + testOneToOneMathFunction(sin, math.sin) + } + + test("asin") { + testOneToOneMathFunction(asin, math.asin) + } + + test("sinh") { + testOneToOneMathFunction(sinh, math.sinh) + } + + test("cos") { + testOneToOneMathFunction(cos, math.cos) + } + + test("acos") { + testOneToOneMathFunction(acos, math.acos) + } + + test("cosh") { + testOneToOneMathFunction(cosh, math.cosh) + } + + test("tan") { + testOneToOneMathFunction(tan, math.tan) + } + + test("atan") { + testOneToOneMathFunction(atan, math.atan) + } + + test("tanh") { + testOneToOneMathFunction(tanh, math.tanh) + } + + test("toDeg") { + testOneToOneMathFunction(toDeg, math.toDegrees) + } + + test("toRad") { + testOneToOneMathFunction(toRad, math.toRadians) + } + + test("cbrt") { + testOneToOneMathFunction(cbrt, math.cbrt) + } + + test("ceil") { + testOneToOneMathFunction(ceil, math.ceil) + } + + test("floor") { + testOneToOneMathFunction(floor, math.floor) + } + + test("rint") { + testOneToOneMathFunction(rint, math.rint) + } + + test("exp") { + testOneToOneMathFunction(exp, math.exp) + } + + test("expm1") { + testOneToOneMathFunction(expm1, math.expm1) + } + + test("signum") { + testOneToOneMathFunction[Double](signum, math.signum) + } + + test("isignum") { + testOneToOneMathFunction[Int](isignum, math.signum) + } + + test("fsignum") { + testOneToOneMathFunction[Float](fsignum, math.signum) + } + + test("lsignum") { + testOneToOneMathFunction[Long](lsignum, math.signum) + } + + def testTwoToOneMathFunction( + c: (Column, Column) => Column, + f: (Double, Double) => Double): Unit = { + checkAnswer( + testData2.select(c('a, 'a)).orderBy('a.asc), + testData2.collect().toSeq.map(r => Row(f(r.getInt(0), r.getInt(0)))) + ) + + checkAnswer( + testData2.select(c('a, 'b)).orderBy('a.asc), + testData2.collect().toSeq.map(r => Row(f(r.getInt(0), r.getInt(1)))) + ) + + val nonNull = nullInts.collect().toSeq.filter(r => r.get(0) != null) + + checkAnswer( + nullInts.select(c('a, 'a)).orderBy('a.asc), + Row(null) +: nonNull.map(r => Row(f(r.getInt(0), 
r.getInt(0)))) + ) + } + + test("pow") { + testTwoToOneMathFunction(pow, math.pow) + } + + test("hypot") { + testTwoToOneMathFunction(hypot, math.hypot) + } + + test("atan2") { + testTwoToOneMathFunction(atan2, math.atan2) + } + + def testOneToOneNonNegativeMathFunction(c: Column => Column, f: Double => Double): Unit = { + checkAnswer( + testData.select(c('key)).orderBy('key.asc), + (1 to 100).map(n => Row(f(n))) + ) + + if (f(-1) === math.log1p(-1)) { + checkAnswer( + negativeData.select(c('key)).orderBy('key.desc), + Row(Double.NegativeInfinity) +: (2 to 100).map(n => Row(null)) + ) + } else { + checkAnswer( + negativeData.select(c('key)).orderBy('key.desc), + (1 to 100).map(n => Row(null)) + ) + } + + checkAnswer( + testData.select(c(lit(null))), + (1 to 100).map(_ => Row(null)) + ) + } + + test("log") { + testOneToOneNonNegativeMathFunction(log, math.log) + } + + test("log10") { + testOneToOneNonNegativeMathFunction(log10, math.log10) + } + + test("log1p") { + testOneToOneNonNegativeMathFunction(log1p, math.log1p) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 225b51bd73d6c..04575d2b65f0f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -57,6 +57,11 @@ object TestData { TestData2(3, 2) :: Nil, 2).toDF() testData2.registerTempTable("testData2") + case class DoubleData(a: Double, b: Double) + val doubleData = TestSQLContext.sparkContext.parallelize( + (1 to 100).map(i => DoubleData(i * 0.02 - 1, i * -0.02 + 1))).toDF() + doubleData.registerTempTable("doubleData") + case class DecimalData(a: BigDecimal, b: BigDecimal) val decimalData = From 8e28fff0333e923c4f67bee26f5dd0a2fe444fed Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Tue, 21 Apr 2015 12:07:27 -0700 Subject: [PATCH 02/12] Added apache header --- .../org/apache/spark/sql/mathfunctions.scala | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala index 6b1035bcfaeb1..df94380935f83 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + package org.apache.spark.sql import scala.language.implicitConversions From 937d5a5ad24c2787758be5c1e87c21bf482c2dc3 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Tue, 21 Apr 2015 12:50:28 -0700 Subject: [PATCH 03/12] use doubles instead of ints --- .../org/apache/spark/sql/ColumnExpressionSuite.scala | 8 ++++---- .../src/test/scala/org/apache/spark/sql/TestData.scala | 4 ++++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 85f619fc68bcd..4a11de7593813 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -438,13 +438,13 @@ class ColumnExpressionSuite extends QueryTest { c: (Column, Column) => Column, f: (Double, Double) => Double): Unit = { checkAnswer( - testData2.select(c('a, 'a)).orderBy('a.asc), - testData2.collect().toSeq.map(r => Row(f(r.getInt(0), r.getInt(0)))) + nnDoubleData.select(c('a, 'a)).orderBy('a.asc), + nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), r.getDouble(0)))) ) checkAnswer( - testData2.select(c('a, 'b)).orderBy('a.asc), - testData2.collect().toSeq.map(r => Row(f(r.getInt(0), r.getInt(1)))) + nnDoubleData.select(c('a, 'b)).orderBy('a.asc), + nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), r.getDouble(1)))) ) val nonNull = nullInts.collect().toSeq.filter(r => r.get(0) != null) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 04575d2b65f0f..0932633839d1b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -62,6 +62,10 @@ object TestData { (1 to 100).map(i => DoubleData(i * 0.02 - 1, i * -0.02 + 1))).toDF() doubleData.registerTempTable("doubleData") + val nnDoubleData = TestSQLContext.sparkContext.parallelize( + (1 to 100).map(i => DoubleData(i * 0.01, i * -0.01))).toDF() + nnDoubleData.registerTempTable("nnDoubleData") + case class DecimalData(a: BigDecimal, b: BigDecimal) val decimalData = From fa68dbeeda3de9c221e75a066a0e6bfc93f37db6 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Tue, 21 Apr 2015 13:30:39 -0700 Subject: [PATCH 04/12] added double specific test data --- .../apache/spark/sql/ColumnExpressionSuite.scala | 6 +++--- .../test/scala/org/apache/spark/sql/TestData.scala | 13 ++++++++++++- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 4a11de7593813..88c3acf8fc8aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -447,11 +447,11 @@ class ColumnExpressionSuite extends QueryTest { nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), r.getDouble(1)))) ) - val nonNull = nullInts.collect().toSeq.filter(r => r.get(0) != null) + val nonNull = nullDoubles.collect().toSeq.filter(r => r.get(0) != null) checkAnswer( - nullInts.select(c('a, 'a)).orderBy('a.asc), - Row(null) +: nonNull.map(r => Row(f(r.getInt(0), r.getInt(0)))) + nullDoubles.select(c('a, 'a)).orderBy('a.asc), + Row(null) +: nonNull.map(r => Row(f(r.getDouble(0), r.getDouble(0)))) ) } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 0932633839d1b..45f4852db8ebd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import java.lang.{Double => JavaDouble} import java.sql.Timestamp import org.apache.spark.sql.catalyst.plans.logical @@ -57,7 +58,7 @@ object TestData { TestData2(3, 2) :: Nil, 2).toDF() testData2.registerTempTable("testData2") - case class DoubleData(a: Double, b: Double) + case class DoubleData(a: JavaDouble, b: JavaDouble) val doubleData = TestSQLContext.sparkContext.parallelize( (1 to 100).map(i => DoubleData(i * 0.02 - 1, i * -0.02 + 1))).toDF() doubleData.registerTempTable("doubleData") @@ -155,6 +156,16 @@ object TestData { ).toDF() nullInts.registerTempTable("nullInts") + case class NullDoubles(a: JavaDouble) + val nullDoubles = + TestSQLContext.sparkContext.parallelize( + NullDoubles(1.0) :: + NullDoubles(2.0) :: + NullDoubles(3.0) :: + NullDoubles(null) :: Nil + ).toDF() + nullDoubles.registerTempTable("nullDoubles") + val allNulls = TestSQLContext.sparkContext.parallelize( NullInts(null) :: From 534cc11de8efcd8ea6d3ca6e5452f88d752bbba7 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Wed, 22 Apr 2015 13:21:18 -0700 Subject: [PATCH 05/12] added more tests, addressed comments --- .../catalyst/expressions/mathfunctions.scala | 186 ++++++------------ .../ExpressionEvaluationSuite.scala | 164 +++++++++++++++ .../org/apache/spark/sql/mathfunctions.scala | 89 +++++++++ .../apache/spark/sql/JavaDataFrameSuite.java | 9 + .../spark/sql/ColumnExpressionSuite.scala | 17 +- 5 files changed, 332 insertions(+), 133 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfunctions.scala index a499bc4441f3d..539855f5482e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfunctions.scala @@ -20,12 +20,14 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.types._ -trait MathematicalExpression extends UnaryExpression with Serializable { self: Product => +abstract class MathematicalExpression(name: String) extends UnaryExpression with Serializable { + self: Product => type EvaluatedType = Any override def dataType: DataType = DoubleType override def foldable: Boolean = child.foldable override def nullable: Boolean = true + override def toString: String = s"$name($child)" lazy val numeric = child.dataType match { case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] @@ -33,22 +35,22 @@ trait MathematicalExpression extends UnaryExpression with Serializable { self: P } } -abstract class MathematicalExpressionForDouble(f: Double => Double) - extends MathematicalExpression { self: Product => - +abstract class MathematicalExpressionForDouble(f: Double => Double, name: String) + extends MathematicalExpression(name) { self: Product => override def eval(input: Row): Any = { val evalE = child.eval(input) if (evalE == null) { null } else { - f(numeric.toDouble(evalE)) + val result = f(numeric.toDouble(evalE)) + if (result.isNaN) null + else result } } } -abstract class MathematicalExpressionForInt(f: Int => 
Int) - extends MathematicalExpression { self: Product => - +abstract class MathematicalExpressionForInt(f: Int => Int, name: String) + extends MathematicalExpression(name) { self: Product => override def dataType: DataType = IntegerType override def eval(input: Row): Any = { @@ -61,8 +63,8 @@ abstract class MathematicalExpressionForInt(f: Int => Int) } } -abstract class MathematicalExpressionForFloat(f: Float => Float) - extends MathematicalExpression { self: Product => +abstract class MathematicalExpressionForFloat(f: Float => Float, name: String) + extends MathematicalExpression(name) { self: Product => override def dataType: DataType = FloatType @@ -71,13 +73,15 @@ abstract class MathematicalExpressionForFloat(f: Float => Float) if (evalE == null) { null } else { - f(numeric.toFloat(evalE)) + val result = f(numeric.toFloat(evalE)) + if (result.isNaN) null + else result } } } -abstract class MathematicalExpressionForLong(f: Long => Long) - extends MathematicalExpression { self: Product => +abstract class MathematicalExpressionForLong(f: Long => Long, name: String) + extends MathematicalExpression(name) { self: Product => override def dataType: DataType = LongType @@ -91,140 +95,62 @@ abstract class MathematicalExpressionForLong(f: Long => Long) } } -case class Sin(child: Expression) extends MathematicalExpressionForDouble(math.sin) { - override def toString: String = s"SIN($child)" -} +case class Sin(child: Expression) extends MathematicalExpressionForDouble(math.sin, "SIN") -case class Asin(child: Expression) extends MathematicalExpressionForDouble(math.asin) { - override def toString: String = s"ASIN($child)" -} +case class Asin(child: Expression) extends MathematicalExpressionForDouble(math.asin, "ASIN") -case class Sinh(child: Expression) extends MathematicalExpressionForDouble(math.sinh) { - override def toString: String = s"SINH($child)" -} +case class Sinh(child: Expression) extends MathematicalExpressionForDouble(math.sinh, "SINH") -case class Cos(child: Expression) extends MathematicalExpressionForDouble(math.cos) { - override def toString: String = s"COS($child)" -} +case class Cos(child: Expression) extends MathematicalExpressionForDouble(math.cos, "COS") -case class Acos(child: Expression) extends MathematicalExpressionForDouble(math.acos) { - override def toString: String = s"ACOS($child)" -} +case class Acos(child: Expression) extends MathematicalExpressionForDouble(math.acos, "ACOS") -case class Cosh(child: Expression) extends MathematicalExpressionForDouble(math.cosh) { - override def toString: String = s"COSH($child)" -} +case class Cosh(child: Expression) extends MathematicalExpressionForDouble(math.cosh, "COSH") -case class Tan(child: Expression) extends MathematicalExpressionForDouble(math.tan) { - override def toString: String = s"TAN($child)" -} +case class Tan(child: Expression) extends MathematicalExpressionForDouble(math.tan, "TAN") -case class Atan(child: Expression) extends MathematicalExpressionForDouble(math.atan) { - override def toString: String = s"ATAN($child)" -} +case class Atan(child: Expression) extends MathematicalExpressionForDouble(math.atan, "ATAN") -case class Tanh(child: Expression) extends MathematicalExpressionForDouble(math.tanh) { - override def toString: String = s"TANH($child)" -} +case class Tanh(child: Expression) extends MathematicalExpressionForDouble(math.tanh, "TANH") -case class Ceil(child: Expression) extends MathematicalExpressionForDouble(math.ceil) { - override def toString: String = s"CEIL($child)" -} +case class Ceil(child: 
Expression) extends MathematicalExpressionForDouble(math.ceil, "CEIL") -case class Floor(child: Expression) extends MathematicalExpressionForDouble(math.floor) { - override def toString: String = s"FLOOR($child)" -} +case class Floor(child: Expression) extends MathematicalExpressionForDouble(math.floor, "FLOOR") -case class Rint(child: Expression) extends MathematicalExpressionForDouble(math.rint) { - override def toString: String = s"RINT($child)" -} +case class Rint(child: Expression) extends MathematicalExpressionForDouble(math.rint, "ROUND") -case class Cbrt(child: Expression) extends MathematicalExpressionForDouble(math.cbrt) { - override def toString: String = s"CBRT($child)" -} +case class Cbrt(child: Expression) extends MathematicalExpressionForDouble(math.cbrt, "CBRT") -case class Signum(child: Expression) extends MathematicalExpressionForDouble(math.signum) { - override def toString: String = s"SIGNUM($child)" -} +case class Signum(child: Expression) extends MathematicalExpressionForDouble(math.signum, "SIGNUM") -case class ISignum(child: Expression) extends MathematicalExpressionForInt(math.signum) { - override def toString: String = s"ISIGNUM($child)" -} +case class ISignum(child: Expression) extends MathematicalExpressionForInt(math.signum, "ISIGNUM") -case class FSignum(child: Expression) extends MathematicalExpressionForFloat(math.signum) { - override def toString: String = s"FSIGNUM($child)" -} +case class FSignum(child: Expression) extends MathematicalExpressionForFloat(math.signum, "FSIGNUM") -case class LSignum(child: Expression) extends MathematicalExpressionForLong(math.signum) { - override def toString: String = s"LSIGNUM($child)" -} +case class LSignum(child: Expression) extends MathematicalExpressionForLong(math.signum, "LSIGNUM") -case class ToDegrees(child: Expression) extends MathematicalExpressionForDouble(math.toDegrees) { - override def toString: String = s"TODEG($child)" -} +case class ToDegrees(child: Expression) + extends MathematicalExpressionForDouble(math.toDegrees, "DEGREES") -case class ToRadians(child: Expression) extends MathematicalExpressionForDouble(math.toRadians) { - override def toString: String = s"TORAD($child)" -} +case class ToRadians(child: Expression) + extends MathematicalExpressionForDouble(math.toRadians, "RADIANS") -case class Log(child: Expression) extends MathematicalExpressionForDouble(math.log) { - override def toString: String = s"LOG($child)" +case class Log(child: Expression) extends MathematicalExpressionForDouble(math.log, "LOG") - override def eval(input: Row): Any = { - val evalE = child.eval(input) - if (evalE == null) { - null - } else { - val value = numeric.toDouble(evalE) - if (value < 0) null - else math.log(value) - } - } -} - -case class Log10(child: Expression) extends MathematicalExpressionForDouble(math.log10) { - override def toString: String = s"LOG10($child)" - - override def eval(input: Row): Any = { - val evalE = child.eval(input) - if (evalE == null) { - null - } else { - val value = numeric.toDouble(evalE) - if (value < 0) null - else math.log10(value) - } - } -} - -case class Log1p(child: Expression) extends MathematicalExpressionForDouble(math.log1p) { - override def toString: String = s"LOG1P($child)" +case class Log10(child: Expression) extends MathematicalExpressionForDouble(math.log10, "LOG10") - override def eval(input: Row): Any = { - val evalE = child.eval(input) - if (evalE == null) { - null - } else { - val value = numeric.toDouble(evalE) - if (value < -1) null - else math.log1p(value) - } - } 
-} +case class Log1p(child: Expression) extends MathematicalExpressionForDouble(math.log1p, "LOG1P") -case class Exp(child: Expression) extends MathematicalExpressionForDouble(math.exp) { - override def toString: String = s"EXP($child)" -} +case class Exp(child: Expression) extends MathematicalExpressionForDouble(math.exp, "EXP") -case class Expm1(child: Expression) extends MathematicalExpressionForDouble(math.expm1) { - override def toString: String = s"EXPM1($child)" -} +case class Expm1(child: Expression) extends MathematicalExpressionForDouble(math.expm1, "EXPM1") -abstract class BinaryMathExpression(f: (Double, Double) => Double) +abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) extends BinaryFunctionExpression with Serializable { self: Product => type EvaluatedType = Any def nullable: Boolean = left.nullable || right.nullable + override def toString: String = s"$name($left, $right)" override lazy val resolved = left.resolved && right.resolved && @@ -246,27 +172,27 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double) override def eval(input: Row): Any = { val evalE1 = left.eval(input) - if(evalE1 == null) { + if (evalE1 == null) { null } else { val evalE2 = right.eval(input) if (evalE2 == null) { null } else { - f(numeric.toDouble(evalE1), numeric.toDouble(evalE2)) + val result = f(numeric.toDouble(evalE1), numeric.toDouble(evalE2)) + if (result.isNaN) null + else result } } } } -case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow) { - override def toString: String = s"POW($left, $right)" -} +case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") -case class Hypot(left: Expression, right: Expression) extends BinaryMathExpression(math.hypot) { - override def toString: String = s"HYPOT($left, $right)" -} +case class Hypot( + left: Expression, + right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT") -case class Atan2(left: Expression, right: Expression) extends BinaryMathExpression(math.atan2) { - override def toString: String = s"ATAN2($left, $right)" -} +case class Atan2( + left: Expression, + right: Expression) extends BinaryMathExpression(math.atan2, "ATAN2") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 76298f03c94ae..7a30d49c44f96 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -1152,6 +1152,170 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(c1 ^ c2, 3, row) checkEvaluation(~c1, -2, row) } + + /** + * Used for testing math functions for DataFrames. 
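+ * (Despite the name, these checks evaluate the underlying catalyst expressions
+ * directly against literal inputs rather than going through the DataFrame API.)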
+ * @param c The DataFrame function + * @param f The functions in scala.math + * @param domain The set of values to run the function with + * @param expectNull Whether the given values should return null or not + * @tparam T Generic type for primitives + */ + def unaryMathFunctionEvaluation[@specialized(Int, Double, Float, Long) T]( + c: Expression => Expression, + f: T => T, + domain: Iterable[T] = (-20 to 20).map(_ * 0.1), + expectNull: Boolean = false): Unit = { + if (expectNull) { + domain.foreach { value => + checkEvaluation(c(Literal(value)), null, EmptyRow) + } + } else { + domain.foreach { value => + checkEvaluation(c(Literal(value)), f(value), EmptyRow) + } + } + checkEvaluation(c(Literal.create(null, DoubleType)), null, create_row(null)) + } + + test("sin") { + unaryMathFunctionEvaluation(Sin, math.sin) + } + + test("asin") { + unaryMathFunctionEvaluation(Asin, math.asin, (-10 to 10).map(_ * 0.1)) + unaryMathFunctionEvaluation(Asin, math.asin, (11 to 20).map(_ * 0.1), true) + } + + test("sinh") { + unaryMathFunctionEvaluation(Sinh, math.sinh) + } + + test("cos") { + unaryMathFunctionEvaluation(Cos, math.cos) + } + + test("acos") { + unaryMathFunctionEvaluation(Acos, math.acos, (-10 to 10).map(_ * 0.1)) + unaryMathFunctionEvaluation(Acos, math.acos, (11 to 20).map(_ * 0.1), true) + } + + test("cosh") { + unaryMathFunctionEvaluation(Cosh, math.cosh) + } + + test("tan") { + unaryMathFunctionEvaluation(Tan, math.tan) + } + + test("atan") { + unaryMathFunctionEvaluation(Atan, math.atan) + } + + test("tanh") { + unaryMathFunctionEvaluation(Tanh, math.tanh) + } + + test("toDeg") { + unaryMathFunctionEvaluation(ToDegrees, math.toDegrees) + } + + test("toRad") { + unaryMathFunctionEvaluation(ToRadians, math.toRadians) + } + + test("cbrt") { + unaryMathFunctionEvaluation(Cbrt, math.cbrt) + } + + test("ceil") { + unaryMathFunctionEvaluation(Ceil, math.ceil) + } + + test("floor") { + unaryMathFunctionEvaluation(Floor, math.floor) + } + + test("rint") { + unaryMathFunctionEvaluation(Rint, math.rint) + } + + test("exp") { + unaryMathFunctionEvaluation(Exp, math.exp) + } + + test("expm1") { + unaryMathFunctionEvaluation(Expm1, math.expm1) + } + + test("signum") { + unaryMathFunctionEvaluation[Double](Signum, math.signum) + } + + test("isignum") { + unaryMathFunctionEvaluation[Int](ISignum, math.signum, (-5 to 5)) + } + + test("fsignum") { + unaryMathFunctionEvaluation[Float](FSignum, math.signum, (-5 to 5).map(_.toFloat)) + } + + test("lsignum") { + unaryMathFunctionEvaluation[Long](LSignum, math.signum, (5 to 5).map(_.toLong)) + } + + test("log") { + unaryMathFunctionEvaluation(Log, math.log, (0 to 20).map(_ * 0.1)) + unaryMathFunctionEvaluation(Log, math.log, (-5 to -1).map(_ * 0.1), true) + } + + test("log10") { + unaryMathFunctionEvaluation(Log10, math.log10, (0 to 20).map(_ * 0.1)) + unaryMathFunctionEvaluation(Log10, math.log10, (-5 to -1).map(_ * 0.1), true) + } + + test("log1p") { + unaryMathFunctionEvaluation(Log1p, math.log1p, (-1 to 20).map(_ * 0.1)) + unaryMathFunctionEvaluation(Log1p, math.log1p, (-10 to -2).map(_ * 1.0), true) + } + + /** + * Used for testing math functions for DataFrames. 
+ * @param c The DataFrame function + * @param f The functions in scala.math + * @param domain The set of values to run the function with + */ + def binaryMathFunctionEvaluation( + c: (Expression, Expression) => Expression, + f: (Double, Double) => Double, + domain: Iterable[(Double, Double)] = (-20 to 20).map(v => (v * 0.1, v * -0.1)), + expectNull: Boolean = false): Unit = { + if (expectNull) { + domain.foreach { case (v1, v2) => + checkEvaluation(c(v1, v2), null, create_row(null)) + } + } else { + domain.foreach { case (v1, v2) => + checkEvaluation(c(v1, v2), f(v1, v2), EmptyRow) + checkEvaluation(c(v2, v1), f(v2, v1), EmptyRow) + } + } + checkEvaluation(c(Literal.create(null, DoubleType), 1.0), null, create_row(null)) + checkEvaluation(c(1.0, Literal.create(null, DoubleType)), null, create_row(null)) + } + + test("pow") { + binaryMathFunctionEvaluation(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0))) + binaryMathFunctionEvaluation(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), true) + } + + test("hypot") { + binaryMathFunctionEvaluation(Hypot, math.hypot) + } + + test("atan2") { + binaryMathFunctionEvaluation(Atan2, math.atan2) + } } // TODO: Make the tests work with codegen. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala index df94380935f83..4f91db9d7dba0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala @@ -21,6 +21,7 @@ import scala.language.implicitConversions import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.functions.lit /** * :: Experimental :: @@ -408,6 +409,34 @@ object mathfunctions { */ def pow(leftName: String, rightName: String): Column = pow(Column(leftName), Column(rightName)) + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group double_funcs + */ + def pow(l: Column, r: Double): Column = pow(l, lit(r).expr) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group double_funcs + */ + def pow(leftName: String, r: Double): Column = pow(Column(leftName), r) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group double_funcs + */ + def pow(l: Double, r: Column): Column = pow(lit(l).expr, r) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group double_funcs + */ + def pow(l: Double, rightName: String): Column = pow(l, Column(rightName)) + /** * Computes sqrt(a^2^ + b^2^) without intermediate overflow or underflow. * @@ -437,6 +466,34 @@ object mathfunctions { def hypot(leftName: String, rightName: String): Column = hypot(Column(leftName), Column(rightName)) + /** + * Computes sqrt(a^2^ + b^2^) without intermediate overflow or underflow. + * + * @group double_funcs + */ + def hypot(l: Column, r: Double): Column = hypot(l, lit(r).expr) + + /** + * Computes sqrt(a^2^ + b^2^) without intermediate overflow or underflow. + * + * @group double_funcs + */ + def hypot(leftName: String, r: Double): Column = hypot(Column(leftName), r) + + /** + * Computes sqrt(a^2^ + b^2^) without intermediate overflow or underflow. 
+ * + * @group double_funcs + */ + def hypot(l: Double, r: Column): Column = hypot(lit(l).expr, r) + + /** + * Computes sqrt(a^2^ + b^2^) without intermediate overflow or underflow. + * + * @group double_funcs + */ + def hypot(l: Double, rightName: String): Column = hypot(l, Column(rightName)) + /** * Returns the angle theta from the conversion of rectangular coordinates (x, y) to * polar coordinates (r, theta). @@ -469,4 +526,36 @@ object mathfunctions { */ def atan2(leftName: String, rightName: String): Column = atan2(Column(leftName), Column(rightName)) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group double_funcs + */ + def atan2(l: Column, r: Double): Column = atan2(l, lit(r).expr) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group double_funcs + */ + def atan2(leftName: String, r: Double): Column = atan2(Column(leftName), r) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group double_funcs + */ + def atan2(l: Double, r: Column): Column = atan2(lit(l).expr, r) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group double_funcs + */ + def atan2(l: Double, rightName: String): Column = atan2(l, Column(rightName)) } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 6d0fbe83c2f36..4a2619ef7e926 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -36,6 +36,7 @@ import org.apache.spark.sql.types.*; import static org.apache.spark.sql.functions.*; +import static org.apache.spark.sql.mathfunctions.*; public class JavaDataFrameSuite { private transient JavaSparkContext jsc; @@ -93,6 +94,14 @@ public void testVarargMethods() { df.groupBy().agg(countDistinct("key", "value")); df.groupBy().agg(countDistinct(col("key"), col("value"))); df.select(coalesce(col("key"))); + + // Varargs with mathfunctions + DataFrame df2 = context.table("doubleData"); + df2.select(exp("a"), exp("b")); + df2.select(exp(log("a"))); + df2.select(pow("a", "a"), pow("b", 2.0)); + df2.select(pow(col("a"), col("b")), exp("b")); + df2.select(sin("a"), acos("b")); } @Ignore diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 88c3acf8fc8aa..b6b77092c6cf6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -436,6 +436,7 @@ class ColumnExpressionSuite extends QueryTest { def testTwoToOneMathFunction( c: (Column, Column) => Column, + d: (Column, Double) => Column, f: (Double, Double) => Double): Unit = { checkAnswer( nnDoubleData.select(c('a, 'a)).orderBy('a.asc), @@ -447,6 +448,16 @@ class ColumnExpressionSuite extends QueryTest { nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), r.getDouble(1)))) ) + checkAnswer( + nnDoubleData.select(d('a, 2.0)).orderBy('a.asc), + nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), 2.0))) + ) + + checkAnswer( + nnDoubleData.select(d('a, -0.5)).orderBy('a.asc), + 
nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), -0.5))) + ) + val nonNull = nullDoubles.collect().toSeq.filter(r => r.get(0) != null) checkAnswer( @@ -456,15 +467,15 @@ class ColumnExpressionSuite extends QueryTest { } test("pow") { - testTwoToOneMathFunction(pow, math.pow) + testTwoToOneMathFunction(pow, pow, math.pow) } test("hypot") { - testTwoToOneMathFunction(hypot, math.hypot) + testTwoToOneMathFunction(hypot, hypot, math.hypot) } test("atan2") { - testTwoToOneMathFunction(atan2, math.atan2) + testTwoToOneMathFunction(atan2, atan2, math.atan2) } def testOneToOneNonNegativeMathFunction(c: Column => Column, f: Double => Double): Unit = { From 029e739cfad92e565c465e9e2655c59be2263c32 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Wed, 22 Apr 2015 14:34:45 -0700 Subject: [PATCH 06/12] fixed atan2 test --- .../catalyst/expressions/mathfunctions.scala | 19 ++++++++++++++++++- .../ExpressionEvaluationSuite.scala | 4 ++-- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfunctions.scala index 539855f5482e9..984027f762bac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfunctions.scala @@ -195,4 +195,21 @@ case class Hypot( case class Atan2( left: Expression, - right: Expression) extends BinaryMathExpression(math.atan2, "ATAN2") + right: Expression) extends BinaryMathExpression(math.atan2, "ATAN2") { + override def eval(input: Row): Any = { + val evalE1 = left.eval(input) + if (evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + // With codegen, the values returned by -0.0 and 0.0 are different. 
Handled with +0.0 + val result = math.atan2(numeric.toDouble(evalE1) + 0.0, numeric.toDouble(evalE2) + 0.0) + if (result.isNaN) null + else result + } + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 7a30d49c44f96..9f126e81a0091 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -1296,8 +1296,8 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } } else { domain.foreach { case (v1, v2) => - checkEvaluation(c(v1, v2), f(v1, v2), EmptyRow) - checkEvaluation(c(v2, v1), f(v2, v1), EmptyRow) + checkEvaluation(c(v1, v2), f(v1 + 0.0, v2 + 0.0), EmptyRow) + checkEvaluation(c(v2, v1), f(v2 + 0.0, v1 + 0.0), EmptyRow) } } checkEvaluation(c(Literal.create(null, DoubleType), 1.0), null, create_row(null)) From b084e1055274f8fb87c794a95d56e5b7850200bd Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 27 Apr 2015 17:53:03 -0700 Subject: [PATCH 07/12] Addressed code review --- .../catalyst/analysis/ExpectsInputTypes.scala | 13 ++ .../catalyst/analysis/HiveTypeCoercion.scala | 26 +++ .../sql/catalyst/expressions/Expression.scala | 5 + .../expressions/mathfuncs/binary.scala | 94 ++++++++ .../unary.scala} | 190 ++++++--------- .../ExpressionEvaluationSuite.scala | 1 + .../org/apache/spark/sql/mathfunctions.scala | 1 + .../spark/sql/ColumnExpressionSuite.scala | 183 --------------- .../spark/sql/MathExpressionsSuite.scala | 216 ++++++++++++++++++ .../scala/org/apache/spark/sql/TestData.scala | 20 -- 10 files changed, 429 insertions(+), 320 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ExpectsInputTypes.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/{mathfunctions.scala => mathfuncs/unary.scala} (50%) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ExpectsInputTypes.scala new file mode 100644 index 0000000000000..b5fe031448004 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ExpectsInputTypes.scala @@ -0,0 +1,13 @@ +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.types.DataType + +/** + * For expressions that require a specific `DataType` as input should implement this trait + * so that the proper type conversions can be performed in the analyzer. 
+ */ +trait ExpectsInputTypes { + + def expectedChildTypes: Seq[DataType] + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 35c7f00d4e42a..c31832cb68996 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -79,6 +79,7 @@ trait HiveTypeCoercion { CaseWhenCoercion :: Division :: PropagateTypes :: + ExpectedInputConversion :: Nil /** @@ -643,4 +644,29 @@ trait HiveTypeCoercion { } } + /** + * Casts types according to the expected input types for Expressions that have the trait + * `ExpectsInputTypes`. + */ + object ExpectedInputConversion extends Rule[LogicalPlan] { + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case q: LogicalPlan => q transformExpressions { + // Skip nodes who's children have not been resolved yet. + case e if !e.childrenResolved => e + + case e: ExpectsInputTypes if e.children.map(_.dataType) != e.expectedChildTypes => + val newC = (e.children, e.children.map(_.dataType), e.expectedChildTypes).zipped.map { + case (child, actual, expected) => + if (actual == expected) { + child + } else { + Cast(child, expected) + } + } + e.withNewChildren(newC) + } + } + } + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index b56a1815b9037..3b0ca9f1a2085 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -89,6 +89,11 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express override def toString: String = s"($left $symbol $right)" } +/** + * This class is for expressions that use math functions like `pow`, `max`, `hypot`, and `atan2`, + * which can't be expressed as `$left $symbol $right`. They are expressed as + * `$func($left, $right)`. + */ abstract class BinaryFunctionExpression extends Expression with trees.BinaryNode[Expression] { self: Product => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala new file mode 100644 index 0000000000000..1febe56b2f0f7 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions.mathfuncs + +import math._ + +import org.apache.spark.sql.catalyst.analysis.{ExpectsInputTypes, UnresolvedException} +import org.apache.spark.sql.catalyst.expressions.{BinaryFunctionExpression, Expression, Row} +import org.apache.spark.sql.types._ + +/** + * A binary expression specifically for math functions that take two `Double`s as input and returns + * a `Double`. + * @param f The math function. + * @param name The short name of the function + */ +abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) + extends BinaryFunctionExpression with Serializable with ExpectsInputTypes { self: Product => + type EvaluatedType = Any + + override def expectedChildTypes: Seq[DataType] = Seq(DoubleType, DoubleType) + + override def nullable: Boolean = left.nullable || right.nullable + override def toString: String = s"$name($left, $right)" + + override lazy val resolved = + left.resolved && right.resolved && + left.dataType == right.dataType && + !DecimalType.isFixed(left.dataType) + + override def dataType: DataType = { + if (!resolved) { + throw new UnresolvedException(this, + s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") + } + left.dataType + } + + override def eval(input: Row): Any = { + val evalE1 = left.eval(input) + if (evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + val result = f(evalE1.asInstanceOf[Double], evalE2.asInstanceOf[Double]) + if (result.isNaN) null else result + } + } + } +} + +case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(pow, "POWER") + +case class Hypot( + left: Expression, + right: Expression) extends BinaryMathExpression(hypot, "HYPOT") + +case class Atan2( + left: Expression, + right: Expression) extends BinaryMathExpression(atan2, "ATAN2") { + override def eval(input: Row): Any = { + val evalE1 = left.eval(input) + if (evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + // With codegen, the values returned by -0.0 and 0.0 are different. Handled with +0.0 + val result = atan2(evalE1.asInstanceOf[Double] + 0.0, evalE2.asInstanceOf[Double] + 0.0) + if (result.isNaN) null else result + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala similarity index 50% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfunctions.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala index 984027f762bac..ada478e20a026 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala @@ -15,12 +15,21 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.catalyst.expressions +package org.apache.spark.sql.catalyst.expressions.mathfuncs -import org.apache.spark.sql.catalyst.analysis.UnresolvedException +import math._ + +import org.apache.spark.sql.catalyst.analysis.ExpectsInputTypes +import org.apache.spark.sql.catalyst.expressions.{Expression, Row, UnaryExpression} import org.apache.spark.sql.types._ -abstract class MathematicalExpression(name: String) extends UnaryExpression with Serializable { +/** + * A unary expression specifically for math functions. Math Functions expect a specific type of + * input format, therefore these functions extend `ExpectsInputTypes`. + * @param name The short name of the function + */ +abstract class MathematicalExpression(name: String) + extends UnaryExpression with Serializable with ExpectsInputTypes { self: Product => type EvaluatedType = Any @@ -28,188 +37,135 @@ abstract class MathematicalExpression(name: String) extends UnaryExpression with override def foldable: Boolean = child.foldable override def nullable: Boolean = true override def toString: String = s"$name($child)" - - lazy val numeric = child.dataType match { - case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] - case other => sys.error(s"Type $other does not support numeric operations") - } } +/** + * A unary expression specifically for math functions that take a `Double` as input and return + * a `Double`. + * @param f The math function. + * @param name The short name of the function + */ abstract class MathematicalExpressionForDouble(f: Double => Double, name: String) extends MathematicalExpression(name) { self: Product => + + override def expectedChildTypes: Seq[DataType] = Seq(DoubleType) + override def eval(input: Row): Any = { val evalE = child.eval(input) if (evalE == null) { null } else { - val result = f(numeric.toDouble(evalE)) - if (result.isNaN) null - else result + val result = f(evalE.asInstanceOf[Double]) + if (result.isNaN) null else result } } } +/** + * A unary expression specifically for math functions that take an `Int` as input and return + * an `Int`. + * @param f The math function. + * @param name The short name of the function + */ abstract class MathematicalExpressionForInt(f: Int => Int, name: String) extends MathematicalExpression(name) { self: Product => + override def dataType: DataType = IntegerType + override def expectedChildTypes: Seq[DataType] = Seq(IntegerType) override def eval(input: Row): Any = { val evalE = child.eval(input) - if (evalE == null) { - null - } else { - f(numeric.toInt(evalE)) - } + if (evalE == null) null else f(evalE.asInstanceOf[Int]) } } +/** + * A unary expression specifically for math functions that take a `Float` as input and return + * a `Float`. + * @param f The math function. + * @param name The short name of the function + */ abstract class MathematicalExpressionForFloat(f: Float => Float, name: String) extends MathematicalExpression(name) { self: Product => override def dataType: DataType = FloatType + override def expectedChildTypes: Seq[DataType] = Seq(FloatType) override def eval(input: Row): Any = { val evalE = child.eval(input) if (evalE == null) { null } else { - val result = f(numeric.toFloat(evalE)) - if (result.isNaN) null - else result + val result = f(evalE.asInstanceOf[Float]) + if (result.isNaN) null else result } } } +/** + * A unary expression specifically for math functions that take a `Long` as input and return + * a `Long`. + * @param f The math function. 
+ * @param name The short name of the function + */ abstract class MathematicalExpressionForLong(f: Long => Long, name: String) extends MathematicalExpression(name) { self: Product => override def dataType: DataType = LongType + override def expectedChildTypes: Seq[DataType] = Seq(LongType) override def eval(input: Row): Any = { val evalE = child.eval(input) - if (evalE == null) { - null - } else { - f(numeric.toLong(evalE)) - } + if (evalE == null) null else f(evalE.asInstanceOf[Long]) } } -case class Sin(child: Expression) extends MathematicalExpressionForDouble(math.sin, "SIN") +case class Sin(child: Expression) extends MathematicalExpressionForDouble(sin, "SIN") -case class Asin(child: Expression) extends MathematicalExpressionForDouble(math.asin, "ASIN") +case class Asin(child: Expression) extends MathematicalExpressionForDouble(asin, "ASIN") -case class Sinh(child: Expression) extends MathematicalExpressionForDouble(math.sinh, "SINH") +case class Sinh(child: Expression) extends MathematicalExpressionForDouble(sinh, "SINH") -case class Cos(child: Expression) extends MathematicalExpressionForDouble(math.cos, "COS") +case class Cos(child: Expression) extends MathematicalExpressionForDouble(cos, "COS") -case class Acos(child: Expression) extends MathematicalExpressionForDouble(math.acos, "ACOS") +case class Acos(child: Expression) extends MathematicalExpressionForDouble(acos, "ACOS") -case class Cosh(child: Expression) extends MathematicalExpressionForDouble(math.cosh, "COSH") +case class Cosh(child: Expression) extends MathematicalExpressionForDouble(cosh, "COSH") -case class Tan(child: Expression) extends MathematicalExpressionForDouble(math.tan, "TAN") +case class Tan(child: Expression) extends MathematicalExpressionForDouble(tan, "TAN") -case class Atan(child: Expression) extends MathematicalExpressionForDouble(math.atan, "ATAN") +case class Atan(child: Expression) extends MathematicalExpressionForDouble(atan, "ATAN") -case class Tanh(child: Expression) extends MathematicalExpressionForDouble(math.tanh, "TANH") +case class Tanh(child: Expression) extends MathematicalExpressionForDouble(tanh, "TANH") -case class Ceil(child: Expression) extends MathematicalExpressionForDouble(math.ceil, "CEIL") +case class Ceil(child: Expression) extends MathematicalExpressionForDouble(ceil, "CEIL") -case class Floor(child: Expression) extends MathematicalExpressionForDouble(math.floor, "FLOOR") +case class Floor(child: Expression) extends MathematicalExpressionForDouble(floor, "FLOOR") -case class Rint(child: Expression) extends MathematicalExpressionForDouble(math.rint, "ROUND") +case class Rint(child: Expression) extends MathematicalExpressionForDouble(rint, "ROUND") -case class Cbrt(child: Expression) extends MathematicalExpressionForDouble(math.cbrt, "CBRT") +case class Cbrt(child: Expression) extends MathematicalExpressionForDouble(cbrt, "CBRT") -case class Signum(child: Expression) extends MathematicalExpressionForDouble(math.signum, "SIGNUM") +case class Signum(child: Expression) extends MathematicalExpressionForDouble(signum, "SIGNUM") -case class ISignum(child: Expression) extends MathematicalExpressionForInt(math.signum, "ISIGNUM") +case class ISignum(child: Expression) extends MathematicalExpressionForInt(signum, "ISIGNUM") -case class FSignum(child: Expression) extends MathematicalExpressionForFloat(math.signum, "FSIGNUM") +case class FSignum(child: Expression) extends MathematicalExpressionForFloat(signum, "FSIGNUM") -case class LSignum(child: Expression) extends 
MathematicalExpressionForLong(math.signum, "LSIGNUM") +case class LSignum(child: Expression) extends MathematicalExpressionForLong(signum, "LSIGNUM") case class ToDegrees(child: Expression) - extends MathematicalExpressionForDouble(math.toDegrees, "DEGREES") + extends MathematicalExpressionForDouble(toDegrees, "DEGREES") case class ToRadians(child: Expression) - extends MathematicalExpressionForDouble(math.toRadians, "RADIANS") - -case class Log(child: Expression) extends MathematicalExpressionForDouble(math.log, "LOG") - -case class Log10(child: Expression) extends MathematicalExpressionForDouble(math.log10, "LOG10") + extends MathematicalExpressionForDouble(toRadians, "RADIANS") -case class Log1p(child: Expression) extends MathematicalExpressionForDouble(math.log1p, "LOG1P") +case class Log(child: Expression) extends MathematicalExpressionForDouble(log, "LOG") -case class Exp(child: Expression) extends MathematicalExpressionForDouble(math.exp, "EXP") +case class Log10(child: Expression) extends MathematicalExpressionForDouble(log10, "LOG10") -case class Expm1(child: Expression) extends MathematicalExpressionForDouble(math.expm1, "EXPM1") - -abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) - extends BinaryFunctionExpression with Serializable { self: Product => - type EvaluatedType = Any +case class Log1p(child: Expression) extends MathematicalExpressionForDouble(log1p, "LOG1P") - def nullable: Boolean = left.nullable || right.nullable - override def toString: String = s"$name($left, $right)" +case class Exp(child: Expression) extends MathematicalExpressionForDouble(exp, "EXP") - override lazy val resolved = - left.resolved && right.resolved && - left.dataType == right.dataType && - !DecimalType.isFixed(left.dataType) - - def dataType: DataType = { - if (!resolved) { - throw new UnresolvedException(this, - s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") - } - left.dataType - } - - lazy val numeric = dataType match { - case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] - case other => sys.error(s"Type $other does not support numeric operations") - } - - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if (evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - val result = f(numeric.toDouble(evalE1), numeric.toDouble(evalE2)) - if (result.isNaN) null - else result - } - } - } -} - -case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") - -case class Hypot( - left: Expression, - right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT") - -case class Atan2( - left: Expression, - right: Expression) extends BinaryMathExpression(math.atan2, "ATAN2") { - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if (evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - // With codegen, the values returned by -0.0 and 0.0 are different. 
Handled with +0.0 - val result = math.atan2(numeric.toDouble(evalE1) + 0.0, numeric.toDouble(evalE2) + 0.0) - if (result.isNaN) null - else result - } - } - } -} +case class Expm1(child: Expression) extends MathematicalExpressionForDouble(expm1, "EXPM1") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 9f126e81a0091..5390ce43c6639 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -28,6 +28,7 @@ import org.scalatest.Matchers._ import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.mathfuncs._ import org.apache.spark.sql.types._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala index 4f91db9d7dba0..952ea1d26801b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala @@ -21,6 +21,7 @@ import scala.language.implicitConversions import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.mathfuncs._ import org.apache.spark.sql.functions.lit /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index b6b77092c6cf6..d99d428879063 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ -import org.apache.spark.sql.mathfunctions._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ @@ -331,186 +330,4 @@ class ColumnExpressionSuite extends QueryTest { assert(schema("value").metadata === Metadata.empty) assert(schema("abc").metadata === metadata) } - - def testOneToOneMathFunction[@specialized(Int, Double, Float, Long) T] - (c: Column => Column, f: T => T): Unit = { - checkAnswer( - doubleData.select(c('a)).orderBy('a.asc), - (1 to 100).map(n => Row(f((n * 0.02 - 1).asInstanceOf[T]))) - ) - - checkAnswer( - doubleData.select(c('b)).orderBy('b.desc), - (1 to 100).map(n => Row(f((-n * 0.02 + 1).asInstanceOf[T]))) - ) - - checkAnswer( - doubleData.select(c(lit(null))), - (1 to 100).map(_ => Row(null)) - ) - } - - test("sin") { - testOneToOneMathFunction(sin, math.sin) - } - - test("asin") { - testOneToOneMathFunction(asin, math.asin) - } - - test("sinh") { - testOneToOneMathFunction(sinh, math.sinh) - } - - test("cos") { - testOneToOneMathFunction(cos, math.cos) - } - - test("acos") { - testOneToOneMathFunction(acos, math.acos) - } - - test("cosh") { - testOneToOneMathFunction(cosh, math.cosh) - } - - test("tan") { - testOneToOneMathFunction(tan, math.tan) - } - - test("atan") { - testOneToOneMathFunction(atan, math.atan) - } - - test("tanh") { - testOneToOneMathFunction(tanh, math.tanh) - } - - test("toDeg") { - 
testOneToOneMathFunction(toDeg, math.toDegrees) - } - - test("toRad") { - testOneToOneMathFunction(toRad, math.toRadians) - } - - test("cbrt") { - testOneToOneMathFunction(cbrt, math.cbrt) - } - - test("ceil") { - testOneToOneMathFunction(ceil, math.ceil) - } - - test("floor") { - testOneToOneMathFunction(floor, math.floor) - } - - test("rint") { - testOneToOneMathFunction(rint, math.rint) - } - - test("exp") { - testOneToOneMathFunction(exp, math.exp) - } - - test("expm1") { - testOneToOneMathFunction(expm1, math.expm1) - } - - test("signum") { - testOneToOneMathFunction[Double](signum, math.signum) - } - - test("isignum") { - testOneToOneMathFunction[Int](isignum, math.signum) - } - - test("fsignum") { - testOneToOneMathFunction[Float](fsignum, math.signum) - } - - test("lsignum") { - testOneToOneMathFunction[Long](lsignum, math.signum) - } - - def testTwoToOneMathFunction( - c: (Column, Column) => Column, - d: (Column, Double) => Column, - f: (Double, Double) => Double): Unit = { - checkAnswer( - nnDoubleData.select(c('a, 'a)).orderBy('a.asc), - nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), r.getDouble(0)))) - ) - - checkAnswer( - nnDoubleData.select(c('a, 'b)).orderBy('a.asc), - nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), r.getDouble(1)))) - ) - - checkAnswer( - nnDoubleData.select(d('a, 2.0)).orderBy('a.asc), - nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), 2.0))) - ) - - checkAnswer( - nnDoubleData.select(d('a, -0.5)).orderBy('a.asc), - nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), -0.5))) - ) - - val nonNull = nullDoubles.collect().toSeq.filter(r => r.get(0) != null) - - checkAnswer( - nullDoubles.select(c('a, 'a)).orderBy('a.asc), - Row(null) +: nonNull.map(r => Row(f(r.getDouble(0), r.getDouble(0)))) - ) - } - - test("pow") { - testTwoToOneMathFunction(pow, pow, math.pow) - } - - test("hypot") { - testTwoToOneMathFunction(hypot, hypot, math.hypot) - } - - test("atan2") { - testTwoToOneMathFunction(atan2, atan2, math.atan2) - } - - def testOneToOneNonNegativeMathFunction(c: Column => Column, f: Double => Double): Unit = { - checkAnswer( - testData.select(c('key)).orderBy('key.asc), - (1 to 100).map(n => Row(f(n))) - ) - - if (f(-1) === math.log1p(-1)) { - checkAnswer( - negativeData.select(c('key)).orderBy('key.desc), - Row(Double.NegativeInfinity) +: (2 to 100).map(n => Row(null)) - ) - } else { - checkAnswer( - negativeData.select(c('key)).orderBy('key.desc), - (1 to 100).map(n => Row(null)) - ) - } - - checkAnswer( - testData.select(c(lit(null))), - (1 to 100).map(_ => Row(null)) - ) - } - - test("log") { - testOneToOneNonNegativeMathFunction(log, math.log) - } - - test("log10") { - testOneToOneNonNegativeMathFunction(log10, math.log10) - } - - test("log1p") { - testOneToOneNonNegativeMathFunction(log1p, math.log1p) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala new file mode 100644 index 0000000000000..45f141316fce3 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -0,0 +1,216 @@ +package org.apache.spark.sql + +import java.lang.{Double => JavaDouble} + +import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.mathfunctions._ +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext.implicits._ + +private[this] object MathExpressionsTestData { + + case class DoubleData(a: JavaDouble, b: JavaDouble) + val 
doubleData = TestSQLContext.sparkContext.parallelize( + (1 to 10).map(i => DoubleData(i * 0.2 - 1, i * -0.2 + 1))).toDF() + + val nnDoubleData = TestSQLContext.sparkContext.parallelize( + (1 to 10).map(i => DoubleData(i * 0.1, i * -0.1))).toDF() + + case class NullDoubles(a: JavaDouble) + val nullDoubles = + TestSQLContext.sparkContext.parallelize( + NullDoubles(1.0) :: + NullDoubles(2.0) :: + NullDoubles(3.0) :: + NullDoubles(null) :: Nil + ).toDF() +} + +class MathExpressionsSuite extends QueryTest { + + import MathExpressionsTestData._ + + def testOneToOneMathFunction[@specialized(Int, Long, Float, Double) T]( + c: Column => Column, + f: T => T): Unit = { + checkAnswer( + doubleData.select(c('a)), + (1 to 10).map(n => Row(f((n * 0.2 - 1).asInstanceOf[T]))) + ) + + checkAnswer( + doubleData.select(c('b)), + (1 to 10).map(n => Row(f((-n * 0.2 + 1).asInstanceOf[T]))) + ) + + checkAnswer( + doubleData.select(c(lit(null))), + (1 to 10).map(_ => Row(null)) + ) + } + + def testOneToOneNonNegativeMathFunction(c: Column => Column, f: Double => Double): Unit = { + checkAnswer( + nnDoubleData.select(c('a)), + (1 to 10).map(n => Row(f(n * 0.1))) + ) + + if (f(-1) === math.log1p(-1)) { + checkAnswer( + nnDoubleData.select(c('b)), + (1 to 9).map(n => Row(f(n * -0.1))) :+ Row(Double.NegativeInfinity) + ) + } else { + checkAnswer( + nnDoubleData.select(c('b)), + (1 to 10).map(n => Row(null)) + ) + } + + checkAnswer( + nnDoubleData.select(c(lit(null))), + (1 to 10).map(_ => Row(null)) + ) + } + + def testTwoToOneMathFunction( + c: (Column, Column) => Column, + d: (Column, Double) => Column, + f: (Double, Double) => Double): Unit = { + checkAnswer( + nnDoubleData.select(c('a, 'a)), + nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), r.getDouble(0)))) + ) + + checkAnswer( + nnDoubleData.select(c('a, 'b)), + nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), r.getDouble(1)))) + ) + + checkAnswer( + nnDoubleData.select(d('a, 2.0)), + nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), 2.0))) + ) + + checkAnswer( + nnDoubleData.select(d('a, -0.5)), + nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), -0.5))) + ) + + val nonNull = nullDoubles.collect().toSeq.filter(r => r.get(0) != null) + + checkAnswer( + nullDoubles.select(c('a, 'a)).orderBy('a.asc), + Row(null) +: nonNull.map(r => Row(f(r.getDouble(0), r.getDouble(0)))) + ) + } + + test("sin") { + testOneToOneMathFunction(sin, math.sin) + } + + test("asin") { + testOneToOneMathFunction(asin, math.asin) + } + + test("sinh") { + testOneToOneMathFunction(sinh, math.sinh) + } + + test("cos") { + testOneToOneMathFunction(cos, math.cos) + } + + test("acos") { + testOneToOneMathFunction(acos, math.acos) + } + + test("cosh") { + testOneToOneMathFunction(cosh, math.cosh) + } + + test("tan") { + testOneToOneMathFunction(tan, math.tan) + } + + test("atan") { + testOneToOneMathFunction(atan, math.atan) + } + + test("tanh") { + testOneToOneMathFunction(tanh, math.tanh) + } + + test("toDeg") { + testOneToOneMathFunction(toDeg, math.toDegrees) + } + + test("toRad") { + testOneToOneMathFunction(toRad, math.toRadians) + } + + test("cbrt") { + testOneToOneMathFunction(cbrt, math.cbrt) + } + + test("ceil") { + testOneToOneMathFunction(ceil, math.ceil) + } + + test("floor") { + testOneToOneMathFunction(floor, math.floor) + } + + test("rint") { + testOneToOneMathFunction(rint, math.rint) + } + + test("exp") { + testOneToOneMathFunction(exp, math.exp) + } + + test("expm1") { + testOneToOneMathFunction(expm1, math.expm1) + } + + test("signum") 
{ + testOneToOneMathFunction[Double](signum, math.signum) + } + + test("isignum") { + testOneToOneMathFunction[Int](isignum, math.signum) + } + + test("fsignum") { + testOneToOneMathFunction[Float](fsignum, math.signum) + } + + test("lsignum") { + testOneToOneMathFunction[Long](lsignum, math.signum) + } + + test("pow") { + testTwoToOneMathFunction(pow, pow, math.pow) + } + + test("hypot") { + testTwoToOneMathFunction(hypot, hypot, math.hypot) + } + + test("atan2") { + testTwoToOneMathFunction(atan2, atan2, math.atan2) + } + + test("log") { + testOneToOneNonNegativeMathFunction(log, math.log) + } + + test("log10") { + testOneToOneNonNegativeMathFunction(log10, math.log10) + } + + test("log1p") { + testOneToOneNonNegativeMathFunction(log1p, math.log1p) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 45f4852db8ebd..225b51bd73d6c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql -import java.lang.{Double => JavaDouble} import java.sql.Timestamp import org.apache.spark.sql.catalyst.plans.logical @@ -58,15 +57,6 @@ object TestData { TestData2(3, 2) :: Nil, 2).toDF() testData2.registerTempTable("testData2") - case class DoubleData(a: JavaDouble, b: JavaDouble) - val doubleData = TestSQLContext.sparkContext.parallelize( - (1 to 100).map(i => DoubleData(i * 0.02 - 1, i * -0.02 + 1))).toDF() - doubleData.registerTempTable("doubleData") - - val nnDoubleData = TestSQLContext.sparkContext.parallelize( - (1 to 100).map(i => DoubleData(i * 0.01, i * -0.01))).toDF() - nnDoubleData.registerTempTable("nnDoubleData") - case class DecimalData(a: BigDecimal, b: BigDecimal) val decimalData = @@ -156,16 +146,6 @@ object TestData { ).toDF() nullInts.registerTempTable("nullInts") - case class NullDoubles(a: JavaDouble) - val nullDoubles = - TestSQLContext.sparkContext.parallelize( - NullDoubles(1.0) :: - NullDoubles(2.0) :: - NullDoubles(3.0) :: - NullDoubles(null) :: Nil - ).toDF() - nullDoubles.registerTempTable("nullDoubles") - val allNulls = TestSQLContext.sparkContext.parallelize( NullInts(null) :: From 2761f08fa301b962de53f978f772b2a68b27d863 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 27 Apr 2015 18:56:12 -0700 Subject: [PATCH 08/12] addressed review v2 --- .../catalyst/analysis/ExpectsInputTypes.scala | 13 ----------- .../sql/catalyst/expressions/Expression.scala | 21 ++++++++--------- .../sql/catalyst/expressions/arithmetic.scala | 12 +++++----- .../expressions/mathfuncs/binary.scala | 23 +++++++++---------- .../expressions/mathfuncs/unary.scala | 3 +-- 5 files changed, 28 insertions(+), 44 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ExpectsInputTypes.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ExpectsInputTypes.scala deleted file mode 100644 index b5fe031448004..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ExpectsInputTypes.scala +++ /dev/null @@ -1,13 +0,0 @@ -package org.apache.spark.sql.catalyst.analysis - -import org.apache.spark.sql.types.DataType - -/** - * For expressions that require a specific `DataType` as input should implement this trait - * so that the proper type conversions can be performed in the analyzer. 
- */ -trait ExpectsInputTypes { - - def expectedChildTypes: Seq[DataType] - -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 3b0ca9f1a2085..1d71c1b4b0c7c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -89,17 +89,6 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express override def toString: String = s"($left $symbol $right)" } -/** - * This class is for expressions that use math functions like `pow`, `max`, `hypot`, and `atan2`, - * which can't be expressed as `$left $symbol $right`. They are expressed as - * `$func($left, $right)`. - */ -abstract class BinaryFunctionExpression extends Expression with trees.BinaryNode[Expression] { - self: Product => - - override def foldable: Boolean = left.foldable && right.foldable -} - abstract class LeafExpression extends Expression with trees.LeafNode[Expression] { self: Product => } @@ -120,3 +109,13 @@ case class GroupExpression(children: Seq[Expression]) extends Expression { override def foldable: Boolean = false override def dataType: DataType = throw new UnsupportedOperationException } + +/** + * Expressions that require a specific `DataType` as input should implement this trait + * so that the proper type conversions can be performed in the analyzer. + */ +trait ExpectsInputTypes { + + def expectedChildTypes: Seq[DataType] + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 140ccd8d3796f..b78ce77fb762c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -83,8 +83,8 @@ abstract class BinaryArithmetic extends BinaryExpression { def dataType: DataType = { if (!resolved) { - throw new UnresolvedException(this, - s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") + throw new UnresolvedException(this, "Unresolved datatype. Can not resolve due to " + + s"differing types ${left.dataType}, ${right.dataType}") } left.dataType } @@ -339,8 +339,8 @@ case class MaxOf(left: Expression, right: Expression) extends Expression { override def dataType: DataType = { if (!resolved) { - throw new UnresolvedException(this, - s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") + throw new UnresolvedException(this, "Unresolved datatype. Can not resolve due to " + + s"differing types ${left.dataType}, ${right.dataType}") } left.dataType } @@ -384,8 +384,8 @@ case class MinOf(left: Expression, right: Expression) extends Expression { override def dataType: DataType = { if (!resolved) { - throw new UnresolvedException(this, - s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") + throw new UnresolvedException(this, "Unresolved datatype. 
Can not resolve due to " + + s"differing types ${left.dataType}, ${right.dataType}") } left.dataType } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala index 1febe56b2f0f7..d8c6c18b151c0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala @@ -17,10 +17,8 @@ package org.apache.spark.sql.catalyst.expressions.mathfuncs -import math._ - -import org.apache.spark.sql.catalyst.analysis.{ExpectsInputTypes, UnresolvedException} -import org.apache.spark.sql.catalyst.expressions.{BinaryFunctionExpression, Expression, Row} +import org.apache.spark.sql.catalyst.analysis.UnresolvedException +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, BinaryExpression, Expression, Row} import org.apache.spark.sql.types._ /** @@ -30,9 +28,9 @@ import org.apache.spark.sql.types._ * @param name The short name of the function */ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) - extends BinaryFunctionExpression with Serializable with ExpectsInputTypes { self: Product => + extends BinaryExpression with Serializable with ExpectsInputTypes { self: Product => type EvaluatedType = Any - + override def symbol: String = null override def expectedChildTypes: Seq[DataType] = Seq(DoubleType, DoubleType) override def nullable: Boolean = left.nullable || right.nullable @@ -45,8 +43,8 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) override def dataType: DataType = { if (!resolved) { - throw new UnresolvedException(this, - s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") + throw new UnresolvedException(this, "Unresolved datatype. Can not resolve due to " + + s"differing types ${left.dataType}, ${right.dataType}") } left.dataType } @@ -67,15 +65,15 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) } } -case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(pow, "POWER") +case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") case class Hypot( left: Expression, - right: Expression) extends BinaryMathExpression(hypot, "HYPOT") + right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT") case class Atan2( left: Expression, - right: Expression) extends BinaryMathExpression(atan2, "ATAN2") { + right: Expression) extends BinaryMathExpression(math.atan2, "ATAN2") { override def eval(input: Row): Any = { val evalE1 = left.eval(input) if (evalE1 == null) { @@ -86,7 +84,8 @@ case class Atan2( null } else { // With codegen, the values returned by -0.0 and 0.0 are different. 
Handled with +0.0 - val result = atan2(evalE1.asInstanceOf[Double] + 0.0, evalE2.asInstanceOf[Double] + 0.0) + val result = math.atan2(evalE1.asInstanceOf[Double] + 0.0, + evalE2.asInstanceOf[Double] + 0.0) if (result.isNaN) null else result } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala index ada478e20a026..05ea42251037d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala @@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.mathfuncs import math._ -import org.apache.spark.sql.catalyst.analysis.ExpectsInputTypes -import org.apache.spark.sql.catalyst.expressions.{Expression, Row, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, Row, UnaryExpression} import org.apache.spark.sql.types._ /** From b26c5fbda5d07a536e621372fdfcbc972bc11b60 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 27 Apr 2015 19:04:09 -0700 Subject: [PATCH 09/12] addressed review v2.1 --- .../org/apache/spark/sql/mathfunctions.scala | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala index 952ea1d26801b..84f62bf47f955 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala @@ -439,28 +439,28 @@ object mathfunctions { def pow(l: Double, rightName: String): Column = pow(l, Column(rightName)) /** - * Computes sqrt(a^2^ + b^2^) without intermediate overflow or underflow. + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. * * @group double_funcs */ def hypot(l: Column, r: Column): Column = Hypot(l.expr, r.expr) /** - * Computes sqrt(a^2^ + b^2^) without intermediate overflow or underflow. + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. * * @group double_funcs */ def hypot(l: Column, rightName: String): Column = hypot(l, Column(rightName)) /** - * Computes sqrt(a^2^ + b^2^) without intermediate overflow or underflow. + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. * * @group double_funcs */ def hypot(leftName: String, r: Column): Column = hypot(Column(leftName), r) /** - * Computes sqrt(a^2^ + b^2^) without intermediate overflow or underflow. + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. * * @group double_funcs */ @@ -468,28 +468,28 @@ object mathfunctions { hypot(Column(leftName), Column(rightName)) /** - * Computes sqrt(a^2^ + b^2^) without intermediate overflow or underflow. + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. * * @group double_funcs */ def hypot(l: Column, r: Double): Column = hypot(l, lit(r).expr) /** - * Computes sqrt(a^2^ + b^2^) without intermediate overflow or underflow. + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. * * @group double_funcs */ def hypot(leftName: String, r: Double): Column = hypot(Column(leftName), r) /** - * Computes sqrt(a^2^ + b^2^) without intermediate overflow or underflow. + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. 
* * @group double_funcs */ def hypot(l: Double, r: Column): Column = hypot(lit(l).expr, r) /** - * Computes sqrt(a^2^ + b^2^) without intermediate overflow or underflow. + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. * * @group double_funcs */ From e5f0d139dba063cc07f8cee92156cdbb719978bb Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 27 Apr 2015 20:34:43 -0700 Subject: [PATCH 10/12] addressed code review v2.2 --- .../catalyst/analysis/HiveTypeCoercion.scala | 29 +++++------ .../expressions/mathfuncs/unary.scala | 50 +++++++++---------- .../spark/sql/MathExpressionsSuite.scala | 17 +++++++ 3 files changed, 54 insertions(+), 42 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index c31832cb68996..2fddaf4e620cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -650,23 +650,20 @@ trait HiveTypeCoercion { */ object ExpectedInputConversion extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case q: LogicalPlan => q transformExpressions { - // Skip nodes who's children have not been resolved yet. - case e if !e.childrenResolved => e + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + // Skip nodes who's children have not been resolved yet. + case e if !e.childrenResolved => e - case e: ExpectsInputTypes if e.children.map(_.dataType) != e.expectedChildTypes => - val newC = (e.children, e.children.map(_.dataType), e.expectedChildTypes).zipped.map { - case (child, actual, expected) => - if (actual == expected) { - child - } else { - Cast(child, expected) - } - } - e.withNewChildren(newC) - } + case e: ExpectsInputTypes if e.children.map(_.dataType) != e.expectedChildTypes => + val newC = (e.children, e.children.map(_.dataType), e.expectedChildTypes).zipped.map { + case (child, actual, expected) => + if (actual == expected) { + child + } else { + Cast(child, expected) + } + } + e.withNewChildren(newC) } } - } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala index 05ea42251037d..96cb77d487529 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.expressions.mathfuncs -import math._ - import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, Row, UnaryExpression} import org.apache.spark.sql.types._ @@ -119,52 +117,52 @@ abstract class MathematicalExpressionForLong(f: Long => Long, name: String) } } -case class Sin(child: Expression) extends MathematicalExpressionForDouble(sin, "SIN") +case class Sin(child: Expression) extends MathematicalExpressionForDouble(math.sin, "SIN") -case class Asin(child: Expression) extends MathematicalExpressionForDouble(asin, "ASIN") +case class Asin(child: Expression) extends MathematicalExpressionForDouble(math.asin, "ASIN") -case class Sinh(child: Expression) extends MathematicalExpressionForDouble(sinh, "SINH") +case class Sinh(child: Expression) extends 
MathematicalExpressionForDouble(math.sinh, "SINH") -case class Cos(child: Expression) extends MathematicalExpressionForDouble(cos, "COS") +case class Cos(child: Expression) extends MathematicalExpressionForDouble(math.cos, "COS") -case class Acos(child: Expression) extends MathematicalExpressionForDouble(acos, "ACOS") +case class Acos(child: Expression) extends MathematicalExpressionForDouble(math.acos, "ACOS") -case class Cosh(child: Expression) extends MathematicalExpressionForDouble(cosh, "COSH") +case class Cosh(child: Expression) extends MathematicalExpressionForDouble(math.cosh, "COSH") -case class Tan(child: Expression) extends MathematicalExpressionForDouble(tan, "TAN") +case class Tan(child: Expression) extends MathematicalExpressionForDouble(math.tan, "TAN") -case class Atan(child: Expression) extends MathematicalExpressionForDouble(atan, "ATAN") +case class Atan(child: Expression) extends MathematicalExpressionForDouble(math.atan, "ATAN") -case class Tanh(child: Expression) extends MathematicalExpressionForDouble(tanh, "TANH") +case class Tanh(child: Expression) extends MathematicalExpressionForDouble(math.tanh, "TANH") -case class Ceil(child: Expression) extends MathematicalExpressionForDouble(ceil, "CEIL") +case class Ceil(child: Expression) extends MathematicalExpressionForDouble(math.ceil, "CEIL") -case class Floor(child: Expression) extends MathematicalExpressionForDouble(floor, "FLOOR") +case class Floor(child: Expression) extends MathematicalExpressionForDouble(math.floor, "FLOOR") -case class Rint(child: Expression) extends MathematicalExpressionForDouble(rint, "ROUND") +case class Rint(child: Expression) extends MathematicalExpressionForDouble(math.rint, "ROUND") -case class Cbrt(child: Expression) extends MathematicalExpressionForDouble(cbrt, "CBRT") +case class Cbrt(child: Expression) extends MathematicalExpressionForDouble(math.cbrt, "CBRT") -case class Signum(child: Expression) extends MathematicalExpressionForDouble(signum, "SIGNUM") +case class Signum(child: Expression) extends MathematicalExpressionForDouble(math.signum, "SIGNUM") -case class ISignum(child: Expression) extends MathematicalExpressionForInt(signum, "ISIGNUM") +case class ISignum(child: Expression) extends MathematicalExpressionForInt(math.signum, "ISIGNUM") -case class FSignum(child: Expression) extends MathematicalExpressionForFloat(signum, "FSIGNUM") +case class FSignum(child: Expression) extends MathematicalExpressionForFloat(math.signum, "FSIGNUM") -case class LSignum(child: Expression) extends MathematicalExpressionForLong(signum, "LSIGNUM") +case class LSignum(child: Expression) extends MathematicalExpressionForLong(math.signum, "LSIGNUM") case class ToDegrees(child: Expression) - extends MathematicalExpressionForDouble(toDegrees, "DEGREES") + extends MathematicalExpressionForDouble(math.toDegrees, "DEGREES") case class ToRadians(child: Expression) - extends MathematicalExpressionForDouble(toRadians, "RADIANS") + extends MathematicalExpressionForDouble(math.toRadians, "RADIANS") -case class Log(child: Expression) extends MathematicalExpressionForDouble(log, "LOG") +case class Log(child: Expression) extends MathematicalExpressionForDouble(math.log, "LOG") -case class Log10(child: Expression) extends MathematicalExpressionForDouble(log10, "LOG10") +case class Log10(child: Expression) extends MathematicalExpressionForDouble(math.log10, "LOG10") -case class Log1p(child: Expression) extends MathematicalExpressionForDouble(log1p, "LOG1P") +case class Log1p(child: Expression) extends 
MathematicalExpressionForDouble(math.log1p, "LOG1P") -case class Exp(child: Expression) extends MathematicalExpressionForDouble(exp, "EXP") +case class Exp(child: Expression) extends MathematicalExpressionForDouble(math.exp, "EXP") -case class Expm1(child: Expression) extends MathematicalExpressionForDouble(expm1, "EXPM1") +case class Expm1(child: Expression) extends MathematicalExpressionForDouble(math.expm1, "EXPM1") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 45f141316fce3..561553cc925cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.sql import java.lang.{Double => JavaDouble} From 836a098dbbfc2b671592847dff37a5cefaeaea2a Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 27 Apr 2015 21:13:52 -0700 Subject: [PATCH 11/12] fixed test and addressed small comment --- .../spark/sql/catalyst/analysis/HiveTypeCoercion.scala | 6 +----- .../java/test/org/apache/spark/sql/JavaDataFrameSuite.java | 2 +- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 2fddaf4e620cc..73c9a1c7afdad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -657,11 +657,7 @@ trait HiveTypeCoercion { case e: ExpectsInputTypes if e.children.map(_.dataType) != e.expectedChildTypes => val newC = (e.children, e.children.map(_.dataType), e.expectedChildTypes).zipped.map { case (child, actual, expected) => - if (actual == expected) { - child - } else { - Cast(child, expected) - } + if (actual == expected) child else Cast(child, expected) } e.withNewChildren(newC) } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index b201dfcf4cae1..e5c9504d21042 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -101,7 +101,7 @@ public void testVarargMethods() { df.select(coalesce(col("key"))); // Varargs with mathfunctions - DataFrame df2 = context.table("doubleData"); + DataFrame df2 = context.table("testData2"); df2.select(exp("a"), exp("b")); df2.select(exp(log("a"))); df2.select(pow("a", "a"), 
pow("b", 2.0)); From fb271536a68cf3f7ff267953098ce305512c65d0 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 27 Apr 2015 21:19:04 -0700 Subject: [PATCH 12/12] reverted exception message --- .../spark/sql/catalyst/expressions/arithmetic.scala | 12 ++++++------ .../sql/catalyst/expressions/mathfuncs/binary.scala | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index b78ce77fb762c..140ccd8d3796f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -83,8 +83,8 @@ abstract class BinaryArithmetic extends BinaryExpression { def dataType: DataType = { if (!resolved) { - throw new UnresolvedException(this, "Unresolved datatype. Can not resolve due to " + - s"differing types ${left.dataType}, ${right.dataType}") + throw new UnresolvedException(this, + s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") } left.dataType } @@ -339,8 +339,8 @@ case class MaxOf(left: Expression, right: Expression) extends Expression { override def dataType: DataType = { if (!resolved) { - throw new UnresolvedException(this, "Unresolved datatype. Can not resolve due to " + - s"differing types ${left.dataType}, ${right.dataType}") + throw new UnresolvedException(this, + s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") } left.dataType } @@ -384,8 +384,8 @@ case class MinOf(left: Expression, right: Expression) extends Expression { override def dataType: DataType = { if (!resolved) { - throw new UnresolvedException(this, "Unresolved datatype. Can not resolve due to " + - s"differing types ${left.dataType}, ${right.dataType}") + throw new UnresolvedException(this, + s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") } left.dataType } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala index d8c6c18b151c0..5b4d912a64f71 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala @@ -43,8 +43,8 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) override def dataType: DataType = { if (!resolved) { - throw new UnresolvedException(this, "Unresolved datatype. Can not resolve due to " + - s"differing types ${left.dataType}, ${right.dataType}") + throw new UnresolvedException(this, + s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") } left.dataType }