Add math function round
yjshen committed Jul 14, 2015
1 parent 408b384 commit 653d047
Showing 6 changed files with 183 additions and 2 deletions.
@@ -127,6 +127,7 @@ object FunctionRegistry {
    expression[Tanh]("tanh"),
    expression[ToDegrees]("degrees"),
    expression[ToRadians]("radians"),
    expression[Round]("round"),

    // misc functions
    expression[Md5]("md5"),
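
With the entry above in the registry, the analyzer can resolve round in SQL expression text. A hypothetical end-to-end check (a local Spark 1.4-era session; the session setup and names such as "round-demo" are illustrative, not part of this commit):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

// Minimal local session, only to exercise name resolution end to end.
val sc = new SparkContext(new SparkConf().setMaster("local[1]").setAppName("round-demo"))
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

// "round" in the expression text resolves through FunctionRegistry
// to the Round expression added below.
val df = Seq(Tuple1(3.14159265)).toDF("x")
df.selectExpr("round(x, 4)").show()  // expected value: 3.1416 (HALF_UP at scale 4)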
@@ -19,8 +19,11 @@ package org.apache.spark.sql.catalyst.expressions

import java.{lang => jl}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.BigDecimalConverter
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

@@ -520,3 +523,95 @@ case class Logarithm(left: Expression, right: Expression)
    """
  }
}

case class Round(children: Seq[Expression]) extends Expression {

  override def nullable: Boolean = true

  override def dataType: DataType = children.head.dataType match {
    case StringType | BinaryType => DoubleType
    case t => t
  }

  override def checkInputDataTypes(): TypeCheckResult = {
    if (children.size < 1 || children.size > 2) {
      return TypeCheckFailure(s"ROUND requires one or two arguments, got ${children.size}")
    }
    children.head.dataType match {
      case _: NumericType | NullType | BinaryType | StringType => // satisfies the requirement
      case dt =>
        return TypeCheckFailure("Only numeric, string or binary data types" +
          s" are allowed for ROUND function, got $dt")
    }
    if (children.size == 2) {
      children(1) match {
        case Literal(value, LongType) =>
          // A Long literal is accepted only if it fits in an Int, since the
          // scale is ultimately used as an Int.
          if (value.asInstanceOf[Long] < Int.MinValue || value.asInstanceOf[Long] > Int.MaxValue) {
            return TypeCheckFailure("ROUND scale argument out of allowed range")
          }
        case Literal(_, _: IntegralType) | Literal(_, NullType) => // satisfies the requirement
        case child =>
          if (child.find(_.isInstanceOf[AttributeReference]).isDefined) {
            return TypeCheckFailure("Only Integral Literal or Null Literal " +
              s"are allowed for ROUND scale argument, got ${child.dataType}")
          }
      }
    }
    TypeCheckSuccess
  }

  override def eval(input: InternalRow): Any = {
    val evalE1 = children.head.eval(input)
    if (evalE1 == null) {
      return null
    }

    var _scale: Int = 0
    if (children.size == 2) {
      val evalE2 = children(1).eval(input)
      if (evalE2 == null) {
        return null
      }
      // The scale literal may carry any integral type; narrow it to Int here
      // (checkInputDataTypes has already rejected out-of-range Longs).
      _scale = evalE2.asInstanceOf[Number].intValue()
    }

    children.head.dataType match {
      case _: DecimalType =>
        null // TODO: support rounding of Decimal values
      case ByteType =>
        round(evalE1.asInstanceOf[Byte], _scale)
      case ShortType =>
        round(evalE1.asInstanceOf[Short], _scale)
      case IntegerType =>
        round(evalE1.asInstanceOf[Int], _scale)
      case LongType =>
        round(evalE1.asInstanceOf[Long], _scale)
      case FloatType =>
        round(evalE1.asInstanceOf[Float], _scale)
      case DoubleType =>
        round(evalE1.asInstanceOf[Double], _scale)
      case StringType =>
        round(evalE1.asInstanceOf[UTF8String].toString, _scale)
      case BinaryType =>
        round(UTF8String.fromBytes(evalE1.asInstanceOf[Array[Byte]]).toString, _scale)
    }
  }

  private def round[T](input: T, scale: Int)(implicit bdc: BigDecimalConverter[T]): T = {
    input match {
      // NaN and infinities cannot be represented as BigDecimal; pass them through.
      case f: Float if f.isNaN || f.isInfinite => input
      case d: Double if d.isNaN || d.isInfinite => input
      case _ =>
        bdc.fromBigDecimal(bdc.toBigDecimal(input).setScale(scale, BigDecimal.RoundingMode.HALF_UP))
    }
  }

  private def round(input: String, scale: Int): Any = {
    // Strings are rounded as doubles; unparsable input yields null instead of failing.
    try round(input.toDouble, scale) catch {
      case _: NumberFormatException => null
    }
  }
}
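
For reference, the numeric path above boils down to BigDecimal.setScale with HALF_UP rounding, and the string path parses to Double first, turning unparsable input into null rather than an error. A plain-Scala sketch of both behaviors (standalone, not Catalyst code):

// Numeric path: HALF_UP at scale 4.
BigDecimal("3.14159265").setScale(4, BigDecimal.RoundingMode.HALF_UP)  // 3.1416

// String path: parse first; a NumberFormatException becomes null.
def roundString(s: String, scale: Int): Any =
  try BigDecimal(s.toDouble).setScale(scale, BigDecimal.RoundingMode.HALF_UP).toDouble
  catch { case _: NumberFormatException => null }

roundString("3.14159265", 4)  // 3.1416
roundString("pi", 4)          // null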
@@ -0,0 +1,60 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.util

trait BigDecimalConverter[T] {
  def toBigDecimal(in: T): BigDecimal
  def fromBigDecimal(bd: BigDecimal): T
}

/**
 * Helper type converters to work with BigDecimal
 * from http://stackoverflow.com/a/30979266/1115193
 */
object BigDecimalConverter {

  implicit object ByteConverter extends BigDecimalConverter[Byte] {
    def toBigDecimal(in: Byte): BigDecimal = BigDecimal(in)
    def fromBigDecimal(bd: BigDecimal): Byte = bd.toByte
  }

  implicit object ShortConverter extends BigDecimalConverter[Short] {
    def toBigDecimal(in: Short): BigDecimal = BigDecimal(in)
    def fromBigDecimal(bd: BigDecimal): Short = bd.toShort
  }

  implicit object IntConverter extends BigDecimalConverter[Int] {
    def toBigDecimal(in: Int): BigDecimal = BigDecimal(in)
    def fromBigDecimal(bd: BigDecimal): Int = bd.toInt
  }

  implicit object LongConverter extends BigDecimalConverter[Long] {
    def toBigDecimal(in: Long): BigDecimal = BigDecimal(in)
    def fromBigDecimal(bd: BigDecimal): Long = bd.toLong
  }

  implicit object FloatConverter extends BigDecimalConverter[Float] {
    def toBigDecimal(in: Float): BigDecimal = BigDecimal(in)
    def fromBigDecimal(bd: BigDecimal): Float = bd.toFloat
  }

  implicit object DoubleConverter extends BigDecimalConverter[Double] {
    def toBigDecimal(in: Double): BigDecimal = BigDecimal(in)
    def fromBigDecimal(bd: BigDecimal): Double = bd.toDouble
  }
}
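
These converters form a small typeclass: it is what lets Round's single generic round method accept every supported numeric type. A usage sketch (the roundTo helper is hypothetical, mirroring Round's private method):

import org.apache.spark.sql.catalyst.util.BigDecimalConverter

// The implicit converter chosen for T supplies both directions of the conversion.
def roundTo[T](in: T, scale: Int)(implicit bdc: BigDecimalConverter[T]): T =
  bdc.fromBigDecimal(bdc.toBigDecimal(in).setScale(scale, BigDecimal.RoundingMode.HALF_UP))

roundTo(2.5, 0)     // 3.0: Double in, Double out; HALF_UP breaks the tie upward
roundTo(1234L, -2)  // 1200L: a negative scale rounds to the nearest hundred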
@@ -171,4 +171,17 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
      CreateNamedStruct(Seq('a.string.at(0), "a", "b", 2.0)),
      "Odd position only allow foldable and not-null StringType expressions")
  }

  test("check types for ROUND") {
    assertError(Round(Seq()), "ROUND requires one or two arguments")
    assertError(Round(Seq(Literal(null), 'booleanField)),
      "Only Integral Literal or Null Literal are allowed for ROUND scale argument")
    assertError(Round(Seq(Literal(null), 'complexField)),
      "Only Integral Literal or Null Literal are allowed for ROUND scale argument")
    assertSuccess(Round(Seq(Literal(null), Literal(null))))
    assertError(Round(Seq('booleanField, 'intField)),
      "Only numeric, string or binary data types are allowed for ROUND function")
    assertError(Round(Seq(Literal(null), Literal(1L + Int.MaxValue))),
      "ROUND scale argument out of allowed range")
  }
}
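
The out-of-range case above works because the scale literal is validated as a Long before any narrowing to Int happens. In plain Scala:

val scale = 1L + Int.MaxValue  // 2147483648L, one past Int.MaxValue
scale > Int.MaxValue           // true, so checkInputDataTypes reports the range error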
@@ -336,4 +336,16 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
      null,
      create_row(null))
  }

test("round test") {
val piRounds = Seq(
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0,
3.1, 3.14, 3.142, 3.1416, 3.14159, 3.141593, 3.1415927, 3.14159265, 3.141592654,
3.1415926536, 3.14159265359, 3.14159265359, 3.1415926535898, 3.14159265358979,
3.141592653589793, 3.141592653589793)
(-16 to 16).zipWithIndex.foreach {
case (scale, i) =>
checkEvaluation(Round(Seq(3.141592653589793, scale)), piRounds(i), EmptyRow)
}
}
}
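
The sweep also exercises negative scales: rounding to the left of the decimal point collapses pi to 0.0 for every scale <= -1, and scale 0 keeps only the integer part. The same setScale semantics in plain Scala (a sketch, not suite code):

BigDecimal(math.Pi).setScale(0, BigDecimal.RoundingMode.HALF_UP).toDouble   // 3.0
BigDecimal(math.Pi).setScale(-1, BigDecimal.RoundingMode.HALF_UP).toDouble  // 0.0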
@@ -919,7 +919,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"udf_repeat",
"udf_rlike",
"udf_round",
// "udf_round_3", TODO: FIX THIS failed due to cast exception
"udf_round_3",
"udf_rpad",
"udf_rtrim",
"udf_second",
