Skip to content

Commit

Permalink
[SPARK-37475][SQL] Add scale parameter to floor and ceil functions
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

Adds `scale` parameter to `floor`/`ceil` functions in order to allow users to control the rounding position. This feature is proposed in the PR: #34593

### Why are the changes needed?

Currently we support Decimal RoundingModes : HALF_UP (round) and HALF_EVEN (bround). But we have use cases that needs RoundingMode.UP and RoundingMode.DOWN.

Floor and Ceil functions helps to do this but it doesn't support the position of the rounding. Adding scale parameter to the functions would help us control the rounding positions.

Snowflake supports `scale` parameter to `floor`/`ceil` :
` FLOOR( <input_expr> [, <scale_expr> ] )`

REF:
https://docs.snowflake.com/en/sql-reference/functions/floor.html

### Does this PR introduce _any_ user-facing change?

Now users can pass `scale` parameter to the `floor` and `ceil` functions.
 ```
     > SELECT floor(-0.1);
       -1.0
      > SELECT floor(5);
       5
      > SELECT floor(3.1411, 3);
       3.141
      > SELECT floor(3.1411, -3);
       1000.0

      > SELECT ceil(-0.1);
       0.0
      > SELECT ceil(5);
       5
      > SELECT ceil(3.1411, 3);
       3.142
      > SELECT ceil(3.1411, -3);
       1000.0

```
### How was this patch tested?

This patch was tested locally using unit test and git workflow.

Closes #34729 from sathiyapk/SPARK-37475-floor-ceil-scale.

Authored-by: Sathiya KUMAR <ext.sathiyaprabhu.kumar@sncf.fr>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
sathiyapk authored and cloud-fan committed Feb 21, 2022
1 parent e2796d2 commit 6242145
Show file tree
Hide file tree
Showing 8 changed files with 611 additions and 50 deletions.
Expand Up @@ -363,8 +363,8 @@ object FunctionRegistry {
expression[Bin]("bin"),
expression[BRound]("bround"),
expression[Cbrt]("cbrt"),
expression[Ceil]("ceil"),
expression[Ceil]("ceiling", true),
expressionBuilder("ceil", CeilExpressionBuilder),
expressionBuilder("ceiling", CeilExpressionBuilder, true),
expression[Cos]("cos"),
expression[Sec]("sec"),
expression[Cosh]("cosh"),
Expand All @@ -373,7 +373,7 @@ object FunctionRegistry {
expression[EulerNumber]("e"),
expression[Exp]("exp"),
expression[Expm1]("expm1"),
expression[Floor]("floor"),
expressionBuilder("floor", FloorExpressionBuilder),
expression[Factorial]("factorial"),
expression[Hex]("hex"),
expression[Hypot]("hypot"),
Expand Down Expand Up @@ -806,11 +806,14 @@ object FunctionRegistry {
}

private def expressionBuilder[T <: ExpressionBuilder : ClassTag](
name: String, builder: T): (String, (ExpressionInfo, FunctionBuilder)) = {
name: String, builder: T, setAlias: Boolean = false)
: (String, (ExpressionInfo, FunctionBuilder)) = {
val info = FunctionRegistryBase.expressionInfo[T](name, None)
val funcBuilder = (expressions: Seq[Expression]) => {
assert(expressions.forall(_.resolved), "function arguments must be resolved.")
builder.build(expressions)
val expr = builder.build(expressions)
if (setAlias) expr.setTagValue(FUNC_ALIAS, name)
expr
}
(name, (info, funcBuilder))
}
Expand Down
Expand Up @@ -21,11 +21,12 @@ import java.{lang => jl}
import java.util.Locale

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult}
import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, TypeCheckResult}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{NumberConverter, TypeUtils}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

Expand Down Expand Up @@ -238,17 +239,6 @@ case class Cbrt(child: Expression) extends UnaryMathExpression(math.cbrt, "CBRT"
override protected def withNewChildInternal(newChild: Expression): Cbrt = copy(child = newChild)
}

@ExpressionDescription(
usage = "_FUNC_(expr) - Returns the smallest integer not smaller than `expr`.",
examples = """
Examples:
> SELECT _FUNC_(-0.1);
0
> SELECT _FUNC_(5);
5
""",
since = "1.4.0",
group = "math_funcs")
case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL") {
override def dataType: DataType = child.dataType match {
case dt @ DecimalType.Fixed(_, 0) => dt
Expand Down Expand Up @@ -279,6 +269,77 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL"
override protected def withNewChildInternal(newChild: Expression): Ceil = copy(child = newChild)
}

trait CeilFloorExpressionBuilder extends ExpressionBuilder {
val functionName: String
def build(expressions: Seq[Expression]): Expression

def extractChildAndScaleParam(expressions: Seq[Expression]): (Expression, Expression) = {
val child = expressions(0)
val scale = expressions(1)
if (! (scale.foldable && scale.dataType == DataTypes.IntegerType)) {
throw QueryCompilationErrors.invalidScaleParameterRoundBase(functionName)
}
val scaleV = scale.eval(EmptyRow)
if (scaleV == null) {
throw QueryCompilationErrors.invalidScaleParameterRoundBase(functionName)
}
(child, scale)
}
}

@ExpressionDescription(
usage = """
_FUNC_(expr[, scale]) - Returns the smallest number after rounding up that is not smaller
than `expr`. A optional `scale` parameter can be specified to control the rounding behavior.""",
examples = """
Examples:
> SELECT _FUNC_(-0.1);
0
> SELECT _FUNC_(5);
5
> SELECT _FUNC_(3.1411, 3);
3.142
> SELECT _FUNC_(3.1411, -3);
1000
""",
since = "3.3.0",
group = "math_funcs")
object CeilExpressionBuilder extends CeilFloorExpressionBuilder {
val functionName: String = "ceil"

def build(expressions: Seq[Expression]): Expression = {
if (expressions.length == 1) {
Ceil(expressions.head)
} else if (expressions.length == 2) {
val (child, scale) = extractChildAndScaleParam(expressions)
RoundCeil(child, scale)
} else {
throw QueryCompilationErrors.invalidNumberOfFunctionParameters(functionName)
}
}
}

case class RoundCeil(child: Expression, scale: Expression)
extends RoundBase(child, scale, BigDecimal.RoundingMode.CEILING, "ROUND_CEILING")
with Serializable with ImplicitCastInputTypes {

override def inputTypes: Seq[AbstractDataType] = Seq(DecimalType, IntegerType)

override lazy val dataType: DataType = child.dataType match {
case DecimalType.Fixed(p, s) =>
if (_scale < 0) {
DecimalType(math.max(p, 1 - _scale), 0)
} else {
DecimalType(p, math.min(s, _scale))
}
case t => t
}

override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression)
: RoundCeil = copy(child = newLeft, scale = newRight)
override def nodeName: String = "ceil"
}

@ExpressionDescription(
usage = """
_FUNC_(expr) - Returns the cosine of `expr`, as if computed by
Expand Down Expand Up @@ -448,17 +509,6 @@ case class Expm1(child: Expression) extends UnaryMathExpression(StrictMath.expm1
override protected def withNewChildInternal(newChild: Expression): Expm1 = copy(child = newChild)
}

@ExpressionDescription(
usage = "_FUNC_(expr) - Returns the largest integer not greater than `expr`.",
examples = """
Examples:
> SELECT _FUNC_(-0.1);
-1
> SELECT _FUNC_(5);
5
""",
since = "1.4.0",
group = "math_funcs")
case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLOOR") {
override def dataType: DataType = child.dataType match {
case dt @ DecimalType.Fixed(_, 0) => dt
Expand All @@ -484,9 +534,62 @@ case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLO
case LongType => defineCodeGen(ctx, ev, c => s"$c")
case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))")
}
}
override protected def withNewChildInternal(newChild: Expression): Floor =
copy(child = newChild)
}

@ExpressionDescription(
usage = """
_FUNC_(expr[, scale]) - Returns the largest number after rounding down that is not greater
than `expr`. An optional `scale` parameter can be specified to control the rounding behavior.""",
examples = """
Examples:
> SELECT _FUNC_(-0.1);
-1
> SELECT _FUNC_(5);
5
> SELECT _FUNC_(3.1411, 3);
3.141
> SELECT _FUNC_(3.1411, -3);
0
""",
since = "3.3.0",
group = "math_funcs")
object FloorExpressionBuilder extends CeilFloorExpressionBuilder {
val functionName: String = "floor"

def build(expressions: Seq[Expression]): Expression = {
if (expressions.length == 1) {
Floor(expressions.head)
} else if (expressions.length == 2) {
val(child, scale) = extractChildAndScaleParam(expressions)
RoundFloor(child, scale)
} else {
throw QueryCompilationErrors.invalidNumberOfFunctionParameters(functionName)
}
}
}

override protected def withNewChildInternal(newChild: Expression): Floor = copy(child = newChild)
case class RoundFloor(child: Expression, scale: Expression)
extends RoundBase(child, scale, BigDecimal.RoundingMode.FLOOR, "ROUND_FLOOR")
with Serializable with ImplicitCastInputTypes {

override def inputTypes: Seq[AbstractDataType] = Seq(DecimalType, IntegerType)

override lazy val dataType: DataType = child.dataType match {
case DecimalType.Fixed(p, s) =>
if (_scale < 0) {
DecimalType(math.max(p, 1 - _scale), 0)
} else {
DecimalType(p, math.min(s, _scale))
}
case t => t
}

override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression)
: RoundFloor = copy(child = newLeft, scale = newRight)
override def nodeName: String = "floor"
}

object Factorial {
Expand Down Expand Up @@ -1375,7 +1478,7 @@ abstract class RoundBase(child: Expression, scale: Expression,
// avoid unnecessary `child` evaluation in both codegen and non-codegen eval
// by checking if scaleV == null as well.
private lazy val scaleV: Any = scale.eval(EmptyRow)
private lazy val _scale: Int = scaleV.asInstanceOf[Int]
protected lazy val _scale: Int = scaleV.asInstanceOf[Int]

override def eval(input: InternalRow): Any = {
if (scaleV == null) { // if scale is null, no need to eval its child at all
Expand All @@ -1393,10 +1496,14 @@ abstract class RoundBase(child: Expression, scale: Expression,
// not overriding since _scale is a constant int at runtime
def nullSafeEval(input1: Any): Any = {
dataType match {
case DecimalType.Fixed(_, s) =>
case DecimalType.Fixed(p, s) =>
val decimal = input1.asInstanceOf[Decimal]
// Overflow cannot happen, so no need to control nullOnOverflow
decimal.toPrecision(decimal.precision, s, mode)
if (_scale >= 0) {
// Overflow cannot happen, so no need to control nullOnOverflow
decimal.toPrecision(decimal.precision, s, mode)
} else {
Decimal(decimal.toBigDecimal.setScale(_scale, mode), p, s)
}
case ByteType =>
BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte
case ShortType =>
Expand Down Expand Up @@ -1426,12 +1533,18 @@ abstract class RoundBase(child: Expression, scale: Expression,
val ce = child.genCode(ctx)

val evaluationCode = dataType match {
case DecimalType.Fixed(_, s) =>
s"""
|${ev.value} = ${ce.value}.toPrecision(${ce.value}.precision(), $s,
| Decimal.$modeStr(), true);
|${ev.isNull} = ${ev.value} == null;
""".stripMargin
case DecimalType.Fixed(p, s) =>
if (_scale >= 0) {
s"""
${ev.value} = ${ce.value}.toPrecision(${ce.value}.precision(), $s,
Decimal.$modeStr(), true);
${ev.isNull} = ${ev.value} == null;"""
} else {
s"""
${ev.value} = new Decimal().set(${ce.value}.toBigDecimal()
.setScale(${_scale}, Decimal.$modeStr()), $p, $s);
${ev.isNull} = ${ev.value} == null;"""
}
case ByteType =>
if (_scale < 0) {
s"""
Expand Down
Expand Up @@ -2375,4 +2375,12 @@ object QueryCompilationErrors {
new AnalysisException(
"Sinks cannot request distribution and ordering in continuous execution mode")
}

def invalidScaleParameterRoundBase(function: String): Throwable = {
new AnalysisException(s"The 'scale' parameter of function '$function' must be an int constant.")
}

def invalidNumberOfFunctionParameters(function: String): Throwable = {
new AnalysisException(s"Invalid number of parameters to the function '$function'.")
}
}

0 comments on commit 6242145

Please sign in to comment.