Skip to content

Commit

Permalink
Update code
Browse files Browse the repository at this point in the history
  • Loading branch information
beliefer committed Oct 14, 2021
1 parent 7df29e5 commit d202787
Showing 1 changed file with 15 additions and 23 deletions.
Expand Up @@ -599,7 +599,13 @@ trait IntervalDivide {
}
}

def divideByZeroCheckCodegen(expr: Expression, value: String): String = expr.dataType match {
def divideByZeroCheck(dataType: DataType, num: Any): Unit = dataType match {
case _: DecimalType =>
if (num.asInstanceOf[Decimal].isZero) throw QueryExecutionErrors.divideByZeroError()
case _ => if (num == 0) throw QueryExecutionErrors.divideByZeroError()
}

def divideByZeroCheckCodegen(dataType: DataType, value: String): String = dataType match {
case _: DecimalType => s"if ($value.isZero()) throw QueryExecutionErrors.divideByZeroError();"
case _ => s"if ($value == 0) throw QueryExecutionErrors.divideByZeroError();"
}
Expand All @@ -617,13 +623,6 @@ case class DivideYMInterval(
override def inputTypes: Seq[AbstractDataType] = Seq(YearMonthIntervalType, NumericType)
override def dataType: DataType = YearMonthIntervalType()

@transient
private lazy val divideByZeroCheck: Any => Unit = right.dataType match {
case _: DecimalType => (num) =>
if (num.asInstanceOf[Decimal].isZero) throw QueryExecutionErrors.divideByZeroError()
case _ => (num) => if (num == 0) throw QueryExecutionErrors.divideByZeroError()
}

@transient
private lazy val evalFunc: (Int, Any) => Any = right.dataType match {
case LongType => (months: Int, num) =>
Expand All @@ -641,7 +640,7 @@ case class DivideYMInterval(

override def nullSafeEval(interval: Any, num: Any): Any = {
checkDivideOverflow(interval.asInstanceOf[Int], Int.MinValue, right, num)
divideByZeroCheck(num)
divideByZeroCheck(right.dataType, num)
evalFunc(interval.asInstanceOf[Int], num)
}

Expand All @@ -663,22 +662,22 @@ case class DivideYMInterval(
// Similarly to non-codegen code. The result of `divide(Int, Long, ...)` must fit to `Int`.
// Casting to `Int` is safe here.
s"""
|${divideByZeroCheckCodegen(right, n)}
|${divideByZeroCheckCodegen(right.dataType, n)}
|$checkIntegralDivideOverflow
|${ev.value} = ($javaType)$math.divide($m, $n, java.math.RoundingMode.HALF_UP);
""".stripMargin)
case _: DecimalType =>
nullSafeCodeGen(ctx, ev, (m, n) =>
s"""
|${divideByZeroCheckCodegen(right, n)}
|${divideByZeroCheckCodegen(right.dataType, n)}
|${ev.value} = ((new Decimal()).set($m).$$div($n)).toJavaBigDecimal()
| .setScale(0, java.math.RoundingMode.HALF_UP).intValueExact();
""".stripMargin)
case _: FractionalType =>
val math = classOf[DoubleMath].getName
nullSafeCodeGen(ctx, ev, (m, n) =>
s"""
|${divideByZeroCheckCodegen(right, n)}
|${divideByZeroCheckCodegen(right.dataType, n)}
|${ev.value} = $math.roundToInt($m / (double)$n, java.math.RoundingMode.HALF_UP);
""".stripMargin)
}
Expand All @@ -703,13 +702,6 @@ case class DivideDTInterval(
override def inputTypes: Seq[AbstractDataType] = Seq(DayTimeIntervalType, NumericType)
override def dataType: DataType = DayTimeIntervalType()

@transient
private lazy val divideByZeroCheck: Any => Unit = right.dataType match {
case _: DecimalType => (num) =>
if (num.asInstanceOf[Decimal].isZero) throw QueryExecutionErrors.divideByZeroError()
case _ => (num) => if (num == 0) throw QueryExecutionErrors.divideByZeroError()
}

@transient
private lazy val evalFunc: (Long, Any) => Any = right.dataType match {
case _: IntegralType => (micros: Long, num) =>
Expand All @@ -723,7 +715,7 @@ case class DivideDTInterval(

override def nullSafeEval(interval: Any, num: Any): Any = {
checkDivideOverflow(interval.asInstanceOf[Long], Long.MinValue, right, num)
divideByZeroCheck(num)
divideByZeroCheck(right.dataType, num)
evalFunc(interval.asInstanceOf[Long], num)
}

Expand All @@ -739,22 +731,22 @@ case class DivideDTInterval(
|""".stripMargin
nullSafeCodeGen(ctx, ev, (m, n) =>
s"""
|${divideByZeroCheckCodegen(right, n)}
|${divideByZeroCheckCodegen(right.dataType, n)}
|$checkIntegralDivideOverflow
|${ev.value} = $math.divide($m, $n, java.math.RoundingMode.HALF_UP);
""".stripMargin)
case _: DecimalType =>
nullSafeCodeGen(ctx, ev, (m, n) =>
s"""
|${divideByZeroCheckCodegen(right, n)}
|${divideByZeroCheckCodegen(right.dataType, n)}
|${ev.value} = ((new Decimal()).set($m).$$div($n)).toJavaBigDecimal()
| .setScale(0, java.math.RoundingMode.HALF_UP).longValueExact();
""".stripMargin)
case _: FractionalType =>
val math = classOf[DoubleMath].getName
nullSafeCodeGen(ctx, ev, (m, n) =>
s"""
|${divideByZeroCheckCodegen(right, n)}
|${divideByZeroCheckCodegen(right.dataType, n)}
|${ev.value} = $math.roundToLong($m / (double)$n, java.math.RoundingMode.HALF_UP);
""".stripMargin)
}
Expand Down

0 comments on commit d202787

Please sign in to comment.