Adopt changes from JNI for casting from float to decimal #10917

Draft · wants to merge 3 commits into branch-24.08
18 changes: 17 additions & 1 deletion integration_tests/src/main/python/cast_test.py
@@ -261,6 +261,22 @@ def test_cast_long_to_decimal_overflow():
        lambda spark : unary_op_df(spark, long_gen).select(
            f.col('a').cast(DecimalType(18, -1))))

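# Round-trip doubles through each target decimal type and back to double so the CPU and
# GPU results can be compared under approximate_float; the cases cover negative scale,
# zero scale, scale == precision, and precisions beyond DECIMAL64.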
@approximate_float
@pytest.mark.parametrize('to_type', [
    DecimalType(7, 1),
    DecimalType(9, 9),
    DecimalType(15, 2),
    DecimalType(15, 15),
    DecimalType(30, 3),
    DecimalType(5, -3),
    DecimalType(3, 0)], ids=idfn)
def test_cast_float_to_decimal(to_type):
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark : unary_op_df(spark, double_gen).select(
            f.col('a'), f.col('a').cast(to_type).cast('double')),
        conf = {'spark.rapids.sql.castFloatToDecimal.enabled': 'true',
                'spark.rapids.sql.castDecimalToFloat.enabled': 'true'})

# casting these types to string should pass
basic_gens_for_cast_to_string = [ByteGen, ShortGen, IntegerGen, LongGen, StringGen, BooleanGen, DateGen, TimestampGen]
basic_array_struct_gens_for_cast_to_string = [f() for f in basic_gens_for_cast_to_string] + [null_gen] + decimal_gens
@@ -306,7 +322,7 @@ def test_cast_array_to_string(data_gen, legacy):
    _assert_cast_to_string_equal(
        data_gen,
        {"spark.sql.legacy.castComplexTypesToString.enabled": legacy})

def test_cast_float_to_string():
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: unary_op_df(spark, FloatGen()).selectExpr("cast(cast(a as string) as float)"),
143 changes: 54 additions & 89 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
@@ -22,7 +22,7 @@ import java.util.Optional

import scala.collection.mutable.ArrayBuffer

import ai.rapids.cudf.{BinaryOp, CaptureGroups, ColumnVector, ColumnView, DecimalUtils, DType, RegexProgram, Scalar}
import ai.rapids.cudf.{BinaryOp, CaptureGroups, ColumnVector, ColumnView, DType, RegexProgram, Scalar}
import ai.rapids.cudf
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
import com.nvidia.spark.rapids.RapidsPluginImplicits._
@@ -192,7 +192,7 @@ object CastOptions {
val ARITH_ANSI_OPTIONS = new CastOptions(false, true, false)
val TO_PRETTY_STRING_OPTIONS = ToPrettyStringOptions

def getArithmeticCastOptions(failOnError: Boolean): CastOptions =
if (failOnError) ARITH_ANSI_OPTIONS else DEFAULT_CAST_OPTIONS

object ToPrettyStringOptions extends CastOptions(false, false, false,
@@ -628,7 +628,7 @@ object GpuCast {
case (TimestampType, DateType) if options.timeZoneId.isDefined =>
val zoneId = DateTimeUtils.getZoneId(options.timeZoneId.get)
withResource(GpuTimeZoneDB.fromUtcTimestampToTimestamp(input.asInstanceOf[ColumnVector],
zoneId.normalized())) {
shifted => shifted.castTo(GpuColumnVector.getNonNestedRapidsType(toDataType))
}
case _ =>
@@ -707,37 +707,37 @@
* @param inclusiveMin Whether the min value is included in the valid range or not
* @param inclusiveMax Whether the max value is included in the valid range or not
*/
  private def replaceOutOfRangeValues(values: ColumnView,
      minValue: => Scalar,
      maxValue: => Scalar,
      replaceValue: => Scalar,
      inclusiveMin: Boolean,
      inclusiveMax: Boolean): ColumnVector = {

    withResource(minValue) { minValue =>
      withResource(maxValue) { maxValue =>
        val minPredicate = if (inclusiveMin) {
          values.lessThan(minValue)
        } else {
          values.lessOrEqualTo(minValue)
        }
        withResource(minPredicate) { minPredicate =>
          val maxPredicate = if (inclusiveMax) {
            values.greaterThan(maxValue)
          } else {
            values.greaterOrEqualTo(maxValue)
          }
          withResource(maxPredicate) { maxPredicate =>
            withResource(maxPredicate.or(minPredicate)) { rangePredicate =>
              withResource(replaceValue) { nullScalar =>
                rangePredicate.ifElse(nullScalar, values)
              }
            }
          }
        }
      }
    }
  }

def castToString(
input: ColumnView,
@@ -1639,64 +1639,29 @@
dt: DecimalType,
ansiMode: Boolean): ColumnVector = {

    // Approach to minimize the difference between the CPU cast and the GPU cast:
    // step 1. cast input to FLOAT64 (if necessary)
    // step 2. cast FLOAT64 to a container DECIMAL (which keeps one more digit for rounding)
    // step 3. perform HALF_UP rounding on the container DECIMAL
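    // For example (illustrative): the double literal 1.3 is stored as roughly
    // 1.2999999999999998, so truncating straight to scale 1 would yield 1.2; keeping one
    // extra digit in the container (1.29 at scale 2) and rounding HALF_UP at scale 1
    // recovers 1.3.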
    val checkedInput = withResource(input.castTo(DType.FLOAT64)) { double =>
      val roundedDouble = double.round(dt.scale, cudf.RoundMode.HALF_UP)
      withResource(roundedDouble) { rounded =>
        // We rely on the container DECIMAL to perform more precise rounding, so we have
        // to take the extra space cost of the container into account when we run the
        // bound check.
        val containerScaleBound = DType.DECIMAL128_MAX_PRECISION - (dt.scale + 1)
        val bound = math.pow(10, (dt.precision - dt.scale) min containerScaleBound)
        if (ansiMode) {
          assertValuesInRange[Double](rounded,
            minValue = -bound,
            maxValue = bound,
            inclusiveMin = false,
            inclusiveMax = false)
          rounded.incRefCount()
        } else {
          replaceOutOfRangeValues(rounded,
            minValue = Scalar.fromDouble(-bound),
            maxValue = Scalar.fromDouble(bound),
            inclusiveMin = false,
            inclusiveMax = false,
            replaceValue = Scalar.fromNull(DType.FLOAT64))
        }
      }
    }

    withResource(checkedInput) { checked =>
      val targetType = DecimalUtil.createCudfDecimal(dt)
      // If the target scale reaches DECIMAL128_MAX_PRECISION, the container DECIMAL
      // cannot be created because of precision overflow. In this case, we perform the
      // cast directly.
      val casted = if (DType.DECIMAL128_MAX_PRECISION == dt.scale) {
        checked.castTo(targetType)

    if (DType.DECIMAL128_MAX_PRECISION == dt.scale) {
      if (ansiMode) {
        val bound = math.pow(10, dt.precision - dt.scale)
        assertValuesInRange[Double](input,
          minValue = -bound,
          maxValue = bound,
          inclusiveMin = false,
          inclusiveMax = false)
      }
      input.castTo(targetType)
    } else {
      // Increase the precision by one along with the scale in case of overflow, which may
      // lead to an upcast of the cuDF decimal type. If the precision already hits the max,
      // it is safe to increase the scale alone because we have checked and replaced
      // out-of-range values.
      val containerType = DecimalUtils.createDecimalType(
        dt.precision + 1 min DType.DECIMAL128_MAX_PRECISION, dt.scale + 1)
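      // e.g. (illustrative): DecimalType(9, 2) gets the container DECIMAL(10, 3), which
      // upcasts the cuDF type from DECIMAL32 to DECIMAL64; DecimalType(38, 2) keeps
      // precision 38 and only widens the scale to 3.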
      withResource(checked.castTo(containerType)) { container =>
        withResource(container.round(dt.scale, cudf.RoundMode.HALF_UP)) { rd =>
          // The cast here is for cases where the cuDF decimal type got promoted to
          // precision + 1; we need to convert back to the original cuDF type to stay
          // aligned with the target precision.
          rd.castTo(targetType)
        }
      }
      // Cast NaN values to nulls
      withResource(casted) { casted =>
        withResource(input.isNan) { inputIsNan =>
          withResource(Scalar.fromNull(targetType)) { nullScalar =>
            inputIsNan.ifElse(nullScalar, casted)
          }
      val converted =
        com.nvidia.spark.rapids.jni.DecimalUtils.floatingPointToDecimal(input, targetType,
          dt.precision)
      if (ansiMode && converted.hasInvalid) {
        // Close the result column before throwing so it does not leak.
        converted.column.close()
        throw RapidsErrorUtils.arithmeticOverflowError(OVERFLOW_MESSAGE)
      }
      converted.column
    }
  }

}
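To see the new flow in isolation, here is a minimal sketch of a standalone helper around the same JNI entry point. It assumes the result of floatingPointToDecimal exposes the column and hasInvalid members used in the diff above, and that rows the kernel cannot represent come back as nulls (matching the non-ANSI behavior the deleted code implemented by hand); the helper name, its signature, and the error message are hypothetical.

import ai.rapids.cudf.{ColumnVector, ColumnView, DType}
import com.nvidia.spark.rapids.jni.DecimalUtils

// Hypothetical helper mirroring the new cast path: a single JNI call replaces the
// old round / bound-check / NaN-replacement pipeline.
def floatToDecimal(input: ColumnView, precision: Int, scale: Int,
    ansiMode: Boolean): ColumnVector = {
  // cuDF encodes a decimal's scale as a negative exponent.
  val targetType = DType.create(DType.DTypeEnum.DECIMAL128, -scale)
  val converted = DecimalUtils.floatingPointToDecimal(input, targetType, precision)
  if (ansiMode && converted.hasInvalid) {
    converted.column.close() // release the result before throwing to avoid a leak
    throw new ArithmeticException("value out of range for the target decimal type")
  }
  converted.column // overflowed and NaN rows are already null here
}

Under those assumptions, everything the old path did in several passes, HALF_UP rounding through a container decimal, out-of-range replacement, and NaN-to-null conversion, collapses into the single floatingPointToDecimal call, which is why those helpers could be deleted.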

def fixDecimalBounds(input: ColumnView,
@@ -1901,4 +1866,4 @@ case class GpuCast(

override def doColumnar(input: GpuColumnVector): ColumnVector =
doCast(input.getBase, input.dataType(), dataType, options)
}