-
Notifications
You must be signed in to change notification settings - Fork 371
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[GLUTEN-5620][CORE] Remove check_overflow and refactor code #5654
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,69 +18,40 @@ package org.apache.gluten.utils | |
|
||
import org.apache.gluten.backendsapi.BackendsApiManager | ||
import org.apache.gluten.exception.GlutenNotSupportException | ||
import org.apache.gluten.expression.{CheckOverflowTransformer, ChildTransformer, DecimalArithmeticExpressionTransformer, ExpressionTransformer} | ||
import org.apache.gluten.expression.ExpressionConverter.conf | ||
|
||
import org.apache.spark.sql.catalyst.analysis.DecimalPrecision | ||
import org.apache.spark.sql.catalyst.expressions.{Add, BinaryArithmetic, Cast, Divide, Expression, Literal, Multiply, Pmod, PromotePrecision, Remainder, Subtract} | ||
import org.apache.spark.sql.internal.SQLConf | ||
import org.apache.spark.sql.types.{ByteType, Decimal, DecimalType, IntegerType, LongType, ShortType} | ||
|
||
import scala.annotation.tailrec | ||
import org.apache.spark.sql.utils.DecimalTypeUtil | ||
|
||
object DecimalArithmeticUtil { | ||
|
||
object OperationType extends Enumeration { | ||
type Config = Value | ||
val ADD, SUBTRACT, MULTIPLY, DIVIDE, MOD = Value | ||
} | ||
|
||
private val MIN_ADJUSTED_SCALE = 6 | ||
val MAX_PRECISION = 38 | ||
|
||
// Returns the result decimal type of a decimal arithmetic computing. | ||
def getResultTypeForOperation( | ||
operationType: OperationType.Config, | ||
type1: DecimalType, | ||
type2: DecimalType): DecimalType = { | ||
def getResultType(expr: BinaryArithmetic, type1: DecimalType, type2: DecimalType): DecimalType = { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we are going into here, that means there is no checkoverflow for decimal binary arithmetic. If there is no There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, this may run into by Checkoverflow(child: BinaryArithmetic), if we match this one like this version https://github.com/apache/incubator-gluten/compare/3c88fd40395c94505d76b5c146cd9498fe1a33b6..8110523f356e7356d1ff0a590610699f016e52d0, the tests can be passed, but performance regression, so I change it to current minimal change version. |
||
var resultScale = 0 | ||
var resultPrecision = 0 | ||
operationType match { | ||
case OperationType.ADD => | ||
expr match { | ||
case _: Add => | ||
resultScale = Math.max(type1.scale, type2.scale) | ||
resultPrecision = | ||
resultScale + Math.max(type1.precision - type1.scale, type2.precision - type2.scale) + 1 | ||
case OperationType.SUBTRACT => | ||
case _: Subtract => | ||
resultScale = Math.max(type1.scale, type2.scale) | ||
resultPrecision = | ||
resultScale + Math.max(type1.precision - type1.scale, type2.precision - type2.scale) + 1 | ||
case OperationType.MULTIPLY => | ||
case _: Multiply => | ||
resultScale = type1.scale + type2.scale | ||
resultPrecision = type1.precision + type2.precision + 1 | ||
case OperationType.DIVIDE => | ||
resultScale = Math.max(MIN_ADJUSTED_SCALE, type1.scale + type2.precision + 1) | ||
case _: Divide => | ||
resultScale = | ||
Math.max(DecimalType.MINIMUM_ADJUSTED_SCALE, type1.scale + type2.precision + 1) | ||
resultPrecision = type1.precision - type1.scale + type2.scale + resultScale | ||
case OperationType.MOD => | ||
resultScale = Math.max(type1.scale, type2.scale) | ||
resultPrecision = | ||
Math.min(type1.precision - type1.scale, type2.precision - type2.scale + resultScale) | ||
case other => | ||
throw new GlutenNotSupportException(s"$other is not supported.") | ||
} | ||
adjustScaleIfNeeded(resultPrecision, resultScale) | ||
} | ||
|
||
// Returns the adjusted decimal type when the precision is larger the maximum. | ||
private def adjustScaleIfNeeded(precision: Int, scale: Int): DecimalType = { | ||
var typePrecision = precision | ||
var typeScale = scale | ||
if (precision > MAX_PRECISION) { | ||
val minScale = Math.min(scale, MIN_ADJUSTED_SCALE) | ||
val delta = precision - MAX_PRECISION | ||
typePrecision = MAX_PRECISION | ||
typeScale = Math.max(scale - delta, minScale) | ||
} | ||
DecimalType(typePrecision, typeScale) | ||
DecimalTypeUtil.adjustPrecisionScale(resultPrecision, resultScale) | ||
} | ||
|
||
// If casting between DecimalType, unnecessary cast is skipped to avoid data loss, | ||
|
@@ -98,18 +69,6 @@ object DecimalArithmeticUtil { | |
} else false | ||
} | ||
|
||
// Returns the operation type of a binary arithmetic expression. | ||
def getOperationType(b: BinaryArithmetic): OperationType.Config = { | ||
b match { | ||
case _: Add => OperationType.ADD | ||
case _: Subtract => OperationType.SUBTRACT | ||
case _: Multiply => OperationType.MULTIPLY | ||
case _: Divide => OperationType.DIVIDE | ||
case other => | ||
throw new GlutenNotSupportException(s"$other is not supported.") | ||
} | ||
} | ||
|
||
// For decimal * 10 case, dec will be Decimal(38, 18), then the result precision is wrong, | ||
// so here we will get the real precision and scale of the literal. | ||
private def getNewPrecisionScale(dec: Decimal): (Integer, Integer) = { | ||
|
@@ -179,9 +138,7 @@ object DecimalArithmeticUtil { | |
if (isWiderType) (e1, newE2) else (e1, e2) | ||
} | ||
|
||
if (!BackendsApiManager.getSettings.rescaleDecimalIntegralExpression()) { | ||
(left, right) | ||
} else if (!isPromoteCast(left) && isPromoteCastIntegral(right)) { | ||
if (!isPromoteCast(left) && isPromoteCastIntegral(right)) { | ||
// Have removed PromotePrecision(Cast(DecimalType)). | ||
// Decimal * cast int. | ||
doScale(left, right) | ||
|
@@ -202,66 +159,32 @@ object DecimalArithmeticUtil { | |
* @return | ||
* expression removed child PromotePrecision->Cast | ||
*/ | ||
def removeCastForDecimal(arithmeticExpr: Expression): Expression = { | ||
arithmeticExpr match { | ||
case precision: PromotePrecision => | ||
precision.child match { | ||
case cast: Cast | ||
if cast.dataType.isInstanceOf[DecimalType] | ||
&& cast.child.dataType.isInstanceOf[DecimalType] => | ||
cast.child | ||
case _ => arithmeticExpr | ||
} | ||
case _ => arithmeticExpr | ||
} | ||
def removeCastForDecimal(arithmeticExpr: Expression): Expression = arithmeticExpr match { | ||
case PromotePrecision(_ @Cast(child, _: DecimalType, _, _)) | ||
if child.dataType.isInstanceOf[DecimalType] => | ||
child | ||
case _ => arithmeticExpr | ||
} | ||
|
||
@tailrec | ||
def getResultType(transformer: ExpressionTransformer): Option[DecimalType] = { | ||
transformer match { | ||
case ChildTransformer(child) => | ||
getResultType(child) | ||
case CheckOverflowTransformer(_, _, original) => | ||
Some(original.dataType) | ||
case DecimalArithmeticExpressionTransformer(_, _, _, resultType, _) => | ||
Some(resultType) | ||
case _ => None | ||
} | ||
} | ||
|
||
private def isPromoteCastIntegral(expr: Expression): Boolean = { | ||
expr match { | ||
case precision: PromotePrecision => | ||
precision.child match { | ||
case cast: Cast if cast.dataType.isInstanceOf[DecimalType] => | ||
cast.child.dataType match { | ||
case IntegerType | ByteType | ShortType | LongType => true | ||
case _ => false | ||
} | ||
case _ => false | ||
} | ||
case _ => false | ||
} | ||
private def isPromoteCastIntegral(expr: Expression): Boolean = expr match { | ||
case PromotePrecision(_ @Cast(child, _: DecimalType, _, _)) => | ||
child.dataType match { | ||
case IntegerType | ByteType | ShortType | LongType => true | ||
case _ => false | ||
} | ||
case _ => false | ||
} | ||
|
||
private def rescaleCastForOneSide(expr: Expression): Expression = { | ||
expr match { | ||
case precision: PromotePrecision => | ||
precision.child match { | ||
case castInt: Cast | ||
if castInt.dataType.isInstanceOf[DecimalType] && | ||
BackendsApiManager.getSettings.rescaleDecimalIntegralExpression() => | ||
castInt.child.dataType match { | ||
case IntegerType | ByteType | ShortType => | ||
precision.withNewChildren(Seq(Cast(castInt.child, DecimalType(10, 0)))) | ||
case LongType => | ||
precision.withNewChildren(Seq(Cast(castInt.child, DecimalType(20, 0)))) | ||
case _ => expr | ||
} | ||
case _ => expr | ||
} | ||
case _ => expr | ||
} | ||
private def rescaleCastForOneSide(expr: Expression): Expression = expr match { | ||
case precision @ PromotePrecision(_ @Cast(child, _: DecimalType, _, _)) => | ||
child.dataType match { | ||
case IntegerType | ByteType | ShortType => | ||
precision.withNewChildren(Seq(Cast(child, DecimalType(10, 0)))) | ||
case LongType => | ||
precision.withNewChildren(Seq(Cast(child, DecimalType(20, 0)))) | ||
case _ => expr | ||
} | ||
case _ => expr | ||
} | ||
|
||
private def checkIsWiderType( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jinchengchenghh if you update codes at the following
The same logical should apply there
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Velox transformCheckOverflow is true, so it won't go here. And rescale and remove cast only needs in velox backend