Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GLUTEN-5620][CORE] Remove check_overflow and refactor code #5654

Merged
merged 3 commits into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ class CHTransformerApi extends TransformerApi with Logging {
args: java.lang.Object,
substraitExprName: String,
childNode: ExpressionNode,
childResultType: DataType,
dataType: DecimalType,
nullable: Boolean,
nullOnOverflow: Boolean): ExpressionNode = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -449,14 +449,12 @@ object VeloxBackendSettings extends BackendSettingsApi {
override def fallbackAggregateWithEmptyOutputChild(): Boolean = true

override def recreateJoinExecOnFallback(): Boolean = true
override def rescaleDecimalLiteral(): Boolean = true
override def rescaleDecimalArithmetic(): Boolean = true

/** Get the config prefix for each backend */
override def getBackendConfigPrefix(): String =
GlutenConfig.GLUTEN_CONFIG_PREFIX + VeloxBackend.BACKEND_NAME

override def rescaleDecimalIntegralExpression(): Boolean = true

override def shuffleSupportedCodec(): Set[String] = SHUFFLE_SUPPORTED_CODEC

override def resolveNativeConf(nativeConf: java.util.Map[String, String]): Unit = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,16 @@ class VeloxTransformerApi extends TransformerApi with Logging {
args: java.lang.Object,
substraitExprName: String,
childNode: ExpressionNode,
childResultType: DataType,
dataType: DecimalType,
nullable: Boolean,
nullOnOverflow: Boolean): ExpressionNode = {
val typeNode = ConverterUtils.getTypeNode(dataType, nullable)
ExpressionBuilder.makeCast(typeNode, childNode, !nullOnOverflow)
if (childResultType.equals(dataType)) {
childNode
} else {
val typeNode = ConverterUtils.getTypeNode(dataType, nullable)
ExpressionBuilder.makeCast(typeNode, childNode, !nullOnOverflow)
}
}

override def getNativePlanString(substraitPlan: Array[Byte], details: Boolean): String = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ trait BackendSettingsApi {
def supportShuffleWithProject(outputPartitioning: Partitioning, child: SparkPlan): Boolean = false
def utilizeShuffledHashJoinHint(): Boolean = false
def excludeScanExecFromCollapsedStage(): Boolean = false
def rescaleDecimalLiteral: Boolean = false
def rescaleDecimalArithmetic: Boolean = false

/**
* Whether to replace sort agg with hash agg., e.g., sort agg will be used in spark's planning for
Expand All @@ -106,8 +106,6 @@ trait BackendSettingsApi {
*/
def transformCheckOverflow: Boolean = true

def rescaleDecimalIntegralExpression(): Boolean = false

def shuffleSupportedCodec(): Set[String]

def needOutputSchemaForPlan(): Boolean = false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.connector.read.InputPartition
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, PartitionDirectory}
import org.apache.spark.sql.types.DecimalType
import org.apache.spark.sql.types.{DataType, DecimalType}
import org.apache.spark.util.collection.BitSet

import com.google.protobuf.{Any, Message}
Expand Down Expand Up @@ -69,6 +69,7 @@ trait TransformerApi {
args: java.lang.Object,
substraitExprName: String,
childNode: ExpressionNode,
childResultType: DataType,
dataType: DecimalType,
nullable: Boolean,
nullOnOverflow: Boolean): ExpressionNode
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,28 @@ object ExpressionConverter extends SQLConfHelper with Logging {
}
}

/**
 * Builds a [[DecimalArithmeticExpressionTransformer]] for a decimal binary arithmetic
 * expression. The operands are first passed through `DecimalArithmeticUtil.rescaleLiteral`,
 * `removeCastForDecimal` and `rescaleCastForDecimal`; the result type is then derived from the
 * operand types that remain after those rewrites.
 */
private def genRescaleDecimalTransformer(
    substraitName: String,
    b: BinaryArithmetic,
    attributeSeq: Seq[Attribute],
    expressionsMap: Map[Class[_], String]): DecimalArithmeticExpressionTransformer = {
  val literalRescaled = DecimalArithmeticUtil.rescaleLiteral(b)
  val strippedLeft = DecimalArithmeticUtil.removeCastForDecimal(literalRescaled.left)
  val strippedRight = DecimalArithmeticUtil.removeCastForDecimal(literalRescaled.right)
  val (lhs, rhs) = DecimalArithmeticUtil.rescaleCastForDecimal(strippedLeft, strippedRight)

  // Both operands must be decimals here; the casts below throw otherwise.
  val lhsType = lhs.dataType.asInstanceOf[DecimalType]
  val rhsType = rhs.dataType.asInstanceOf[DecimalType]
  val resultType = DecimalArithmeticUtil.getResultType(b, lhsType, rhsType)

  val lhsTransformer =
    replaceWithExpressionTransformerInternal(lhs, attributeSeq, expressionsMap)
  val rhsTransformer =
    replaceWithExpressionTransformerInternal(rhs, attributeSeq, expressionsMap)
  DecimalArithmeticExpressionTransformer(
    substraitName,
    lhsTransformer,
    rhsTransformer,
    resultType,
    b)
}

private def replaceWithExpressionTransformerInternal(
expr: Expression,
attributeSeq: Seq[Attribute],
Expand Down Expand Up @@ -492,7 +514,6 @@ object ExpressionConverter extends SQLConfHelper with Logging {
expr.children.map(
replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)),
expr)

case CheckOverflow(b: BinaryArithmetic, decimalType, _)
if !BackendsApiManager.getSettings.transformCheckOverflow &&
DecimalArithmeticUtil.isDecimalArithmetic(b) =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jinchengchenghh if you update the code at the following location:

      case b: BinaryArithmetic if DecimalArithmeticUtil.isDecimalArithmetic(b) =>
        DecimalArithmeticUtil.checkAllowDecimalArithmetic()
       ///...

The same logic should apply there.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For Velox, transformCheckOverflow is true, so it won't reach this branch. Rescaling and cast removal are only needed in the Velox backend.

Expand All @@ -507,55 +528,25 @@ object ExpressionConverter extends SQLConfHelper with Logging {
rightChild,
decimalType,
b)

case c: CheckOverflow =>
CheckOverflowTransformer(
substraitExprName,
replaceWithExpressionTransformerInternal(c.child, attributeSeq, expressionsMap),
c.child.dataType,
c)

case b: BinaryArithmetic if DecimalArithmeticUtil.isDecimalArithmetic(b) =>
DecimalArithmeticUtil.checkAllowDecimalArithmetic()
if (!BackendsApiManager.getSettings.transformCheckOverflow) {
val leftChild =
replaceWithExpressionTransformerInternal(b.left, attributeSeq, expressionsMap)
val rightChild =
replaceWithExpressionTransformerInternal(b.right, attributeSeq, expressionsMap)
DecimalArithmeticExpressionTransformer(
GenericExpressionTransformer(
substraitExprName,
leftChild,
rightChild,
b.dataType.asInstanceOf[DecimalType],
b)
} else {
val rescaleBinary = if (BackendsApiManager.getSettings.rescaleDecimalLiteral) {
DecimalArithmeticUtil.rescaleLiteral(b)
} else {
b
}
val (left, right) = DecimalArithmeticUtil.rescaleCastForDecimal(
DecimalArithmeticUtil.removeCastForDecimal(rescaleBinary.left),
DecimalArithmeticUtil.removeCastForDecimal(rescaleBinary.right))
val leftChild =
replaceWithExpressionTransformerInternal(left, attributeSeq, expressionsMap)
val rightChild =
replaceWithExpressionTransformerInternal(right, attributeSeq, expressionsMap)

val resultType = DecimalArithmeticUtil.getResultTypeForOperation(
DecimalArithmeticUtil.getOperationType(b),
DecimalArithmeticUtil
.getResultType(leftChild)
.getOrElse(left.dataType.asInstanceOf[DecimalType]),
DecimalArithmeticUtil
.getResultType(rightChild)
.getOrElse(right.dataType.asInstanceOf[DecimalType])
expr.children.map(
replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)),
expr
)
DecimalArithmeticExpressionTransformer(
substraitExprName,
leftChild,
rightChild,
resultType,
b)
} else {
// Without rescaling and cast removal the result is still correct on newer Spark versions,
// but it causes a performance regression in Velox.
genRescaleDecimalTransformer(substraitExprName, b, attributeSeq, expressionsMap)
}
case n: NaNvl =>
BackendsApiManager.getSparkPlanExecApiInstance.genNaNvlTransformer(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ case class PosExplodeTransformer(
case class CheckOverflowTransformer(
substraitExprName: String,
child: ExpressionTransformer,
childResultType: DataType,
original: CheckOverflow)
extends ExpressionTransformer {

Expand All @@ -160,6 +161,7 @@ case class CheckOverflowTransformer(
args,
substraitExprName,
child.doTransform(args),
childResultType,
original.dataType,
original.nullable,
original.nullOnOverflow)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,69 +18,40 @@ package org.apache.gluten.utils

import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.gluten.expression.{CheckOverflowTransformer, ChildTransformer, DecimalArithmeticExpressionTransformer, ExpressionTransformer}
import org.apache.gluten.expression.ExpressionConverter.conf

import org.apache.spark.sql.catalyst.analysis.DecimalPrecision
import org.apache.spark.sql.catalyst.expressions.{Add, BinaryArithmetic, Cast, Divide, Expression, Literal, Multiply, Pmod, PromotePrecision, Remainder, Subtract}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ByteType, Decimal, DecimalType, IntegerType, LongType, ShortType}

import scala.annotation.tailrec
import org.apache.spark.sql.utils.DecimalTypeUtil

object DecimalArithmeticUtil {

object OperationType extends Enumeration {
type Config = Value
val ADD, SUBTRACT, MULTIPLY, DIVIDE, MOD = Value
}

private val MIN_ADJUSTED_SCALE = 6
val MAX_PRECISION = 38

// Returns the result decimal type of a decimal arithmetic computing.
def getResultTypeForOperation(
operationType: OperationType.Config,
type1: DecimalType,
type2: DecimalType): DecimalType = {
def getResultType(expr: BinaryArithmetic, type1: DecimalType, type2: DecimalType): DecimalType = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we get here, that means there is no CheckOverflow around the decimal binary arithmetic. If there is no PromotePrecision, should the result decimal type be b.dataType?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, this can also be reached via CheckOverflow(child: BinaryArithmetic). If we match that case as in this version https://github.com/apache/incubator-gluten/compare/3c88fd40395c94505d76b5c146cd9498fe1a33b6..8110523f356e7356d1ff0a590610699f016e52d0, the tests pass but there is a performance regression, so I changed it to the current minimal-change version.

var resultScale = 0
var resultPrecision = 0
operationType match {
case OperationType.ADD =>
expr match {
case _: Add =>
resultScale = Math.max(type1.scale, type2.scale)
resultPrecision =
resultScale + Math.max(type1.precision - type1.scale, type2.precision - type2.scale) + 1
case OperationType.SUBTRACT =>
case _: Subtract =>
resultScale = Math.max(type1.scale, type2.scale)
resultPrecision =
resultScale + Math.max(type1.precision - type1.scale, type2.precision - type2.scale) + 1
case OperationType.MULTIPLY =>
case _: Multiply =>
resultScale = type1.scale + type2.scale
resultPrecision = type1.precision + type2.precision + 1
case OperationType.DIVIDE =>
resultScale = Math.max(MIN_ADJUSTED_SCALE, type1.scale + type2.precision + 1)
case _: Divide =>
resultScale =
Math.max(DecimalType.MINIMUM_ADJUSTED_SCALE, type1.scale + type2.precision + 1)
resultPrecision = type1.precision - type1.scale + type2.scale + resultScale
case OperationType.MOD =>
resultScale = Math.max(type1.scale, type2.scale)
resultPrecision =
Math.min(type1.precision - type1.scale, type2.precision - type2.scale + resultScale)
case other =>
throw new GlutenNotSupportException(s"$other is not supported.")
}
adjustScaleIfNeeded(resultPrecision, resultScale)
}

// Returns the adjusted decimal type when the precision is larger than the maximum.
private def adjustScaleIfNeeded(precision: Int, scale: Int): DecimalType = {
var typePrecision = precision
var typeScale = scale
if (precision > MAX_PRECISION) {
val minScale = Math.min(scale, MIN_ADJUSTED_SCALE)
val delta = precision - MAX_PRECISION
typePrecision = MAX_PRECISION
typeScale = Math.max(scale - delta, minScale)
}
DecimalType(typePrecision, typeScale)
DecimalTypeUtil.adjustPrecisionScale(resultPrecision, resultScale)
}

// If casting between DecimalType, unnecessary cast is skipped to avoid data loss,
Expand All @@ -98,18 +69,6 @@ object DecimalArithmeticUtil {
} else false
}

// Returns the operation type of a binary arithmetic expression.
def getOperationType(b: BinaryArithmetic): OperationType.Config = {
b match {
case _: Add => OperationType.ADD
case _: Subtract => OperationType.SUBTRACT
case _: Multiply => OperationType.MULTIPLY
case _: Divide => OperationType.DIVIDE
case other =>
throw new GlutenNotSupportException(s"$other is not supported.")
}
}

// For decimal * 10 case, dec will be Decimal(38, 18), then the result precision is wrong,
// so here we will get the real precision and scale of the literal.
private def getNewPrecisionScale(dec: Decimal): (Integer, Integer) = {
Expand Down Expand Up @@ -179,9 +138,7 @@ object DecimalArithmeticUtil {
if (isWiderType) (e1, newE2) else (e1, e2)
}

if (!BackendsApiManager.getSettings.rescaleDecimalIntegralExpression()) {
(left, right)
} else if (!isPromoteCast(left) && isPromoteCastIntegral(right)) {
if (!isPromoteCast(left) && isPromoteCastIntegral(right)) {
// Have removed PromotePrecision(Cast(DecimalType)).
// Decimal * cast int.
doScale(left, right)
Expand All @@ -202,66 +159,32 @@ object DecimalArithmeticUtil {
* @return
* expression removed child PromotePrecision->Cast
*/
def removeCastForDecimal(arithmeticExpr: Expression): Expression = {
arithmeticExpr match {
case precision: PromotePrecision =>
precision.child match {
case cast: Cast
if cast.dataType.isInstanceOf[DecimalType]
&& cast.child.dataType.isInstanceOf[DecimalType] =>
cast.child
case _ => arithmeticExpr
}
case _ => arithmeticExpr
}
// Unwraps PromotePrecision(Cast(x -> decimal)) to x when x is itself a decimal;
// every other expression is returned unchanged.
def removeCastForDecimal(arithmeticExpr: Expression): Expression = {
  arithmeticExpr match {
    case PromotePrecision(Cast(inner, _: DecimalType, _, _)) =>
      inner.dataType match {
        case _: DecimalType => inner
        case _ => arithmeticExpr
      }
    case _ => arithmeticExpr
  }
}

@tailrec
def getResultType(transformer: ExpressionTransformer): Option[DecimalType] = {
transformer match {
case ChildTransformer(child) =>
getResultType(child)
case CheckOverflowTransformer(_, _, original) =>
Some(original.dataType)
case DecimalArithmeticExpressionTransformer(_, _, _, resultType, _) =>
Some(resultType)
case _ => None
}
}

private def isPromoteCastIntegral(expr: Expression): Boolean = {
expr match {
case precision: PromotePrecision =>
precision.child match {
case cast: Cast if cast.dataType.isInstanceOf[DecimalType] =>
cast.child.dataType match {
case IntegerType | ByteType | ShortType | LongType => true
case _ => false
}
case _ => false
}
case _ => false
}
// True only for PromotePrecision(Cast(x -> decimal)) where x has type
// IntegerType, ByteType, ShortType or LongType.
private def isPromoteCastIntegral(expr: Expression): Boolean = {
  expr match {
    case PromotePrecision(Cast(source, _: DecimalType, _, _)) =>
      val sourceType = source.dataType
      Seq(IntegerType, ByteType, ShortType, LongType).exists(_ == sourceType)
    case _ =>
      false
  }
}

private def rescaleCastForOneSide(expr: Expression): Expression = {
expr match {
case precision: PromotePrecision =>
precision.child match {
case castInt: Cast
if castInt.dataType.isInstanceOf[DecimalType] &&
BackendsApiManager.getSettings.rescaleDecimalIntegralExpression() =>
castInt.child.dataType match {
case IntegerType | ByteType | ShortType =>
precision.withNewChildren(Seq(Cast(castInt.child, DecimalType(10, 0))))
case LongType =>
precision.withNewChildren(Seq(Cast(castInt.child, DecimalType(20, 0))))
case _ => expr
}
case _ => expr
}
case _ => expr
}
// For PromotePrecision(Cast(x -> decimal)) with an integral x, replaces the cast target with
// DecimalType(10, 0) for byte/short/int sources and DecimalType(20, 0) for long sources.
// Any other expression is returned unchanged.
private def rescaleCastForOneSide(expr: Expression): Expression = {
  expr match {
    case promoted @ PromotePrecision(Cast(source, _: DecimalType, _, _)) =>
      val narrowedType = source.dataType match {
        case IntegerType | ByteType | ShortType => Some(DecimalType(10, 0))
        case LongType => Some(DecimalType(20, 0))
        case _ => None
      }
      narrowedType.fold[Expression](expr) {
        target => promoted.withNewChildren(Seq(Cast(source, target)))
      }
    case _ =>
      expr
  }
}

private def checkIsWiderType(
Expand Down