Add 3.5.1-SNAPSHOT Shim (#9962)
* Added 351 snapshot shim

* Signing off

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* updated 2.13 pom.xml

* Fixed Decimal 128 Multiplication

* added comments to highlight the change in multiply128

* Removed the import alias

* Fixed Sequence size limit check

* changes to use the new multiply128. Changed the name of the shim to reflect the method name

* Handle empty partitions

* pulled in change from #10070

* Fixed the overflow check for addition and subtraction

* Updated test conditions

* Revert "Updated test conditions"

This reverts commit 533504f.

* addressed review comments

* renamed

---------

Signed-off-by: Raza Jafri <rjafri@nvidia.com>
razajafri committed Dec 25, 2023
1 parent bb235c9 commit 11a91d4
Showing 176 changed files with 641 additions and 83 deletions.
@@ -34,6 +34,7 @@
{"spark": "341"}
{"spark": "341db"}
{"spark": "350"}
{"spark": "351"}
spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.tests.datagen

5 changes: 3 additions & 2 deletions integration_tests/src/main/python/collection_ops_test.py
@@ -20,7 +20,7 @@
from string_test import mk_str_gen
import pyspark.sql.functions as f
import pyspark.sql.utils
from spark_session import with_cpu_session, with_gpu_session
from spark_session import with_cpu_session, with_gpu_session, is_before_spark_351
from conftest import get_datagen_seed
from marks import allow_non_gpu

@@ -326,11 +326,12 @@ def test_sequence_illegal_boundaries(start_gen, stop_gen, step_gen):
@pytest.mark.parametrize('stop_gen', sequence_too_long_length_gens, ids=idfn)
@allow_non_gpu(*non_utc_allow)
def test_sequence_too_long_sequence(stop_gen):
msg = "Too long sequence" if is_before_spark_351() else "Unsuccessful try to create array with"
assert_gpu_and_cpu_error(
# To avoid OOM, reduce the row number to 1, it is enough to verify this case.
lambda spark:unary_op_df(spark, stop_gen, 1).selectExpr(
"sequence(0, a)").collect(),
conf = {}, error_message = "Too long sequence")
conf = {}, error_message = msg)

def get_sequence_cases_mixed_df(spark, length=2048):
# Generate the sequence data following the 3 rules mixed in a single dataset.
3 changes: 3 additions & 0 deletions integration_tests/src/main/python/spark_session.py
@@ -196,6 +196,9 @@ def is_before_spark_341():
def is_before_spark_350():
return spark_version() < "3.5.0"

def is_before_spark_351():
return spark_version() < "3.5.1"

def is_spark_320_or_later():
return spark_version() >= "3.2.0"

22 changes: 22 additions & 0 deletions pom.xml
@@ -554,6 +554,26 @@
<module>delta-lake/delta-stub</module>
</modules>
</profile>
<profile>
<id>release351</id>
<activation>
<property>
<name>buildver</name>
<value>351</value>
</property>
</activation>
<properties>
<buildver>351</buildver>
<spark.version>${spark351.version}</spark.version>
<spark.test.version>${spark351.version}</spark.test.version>
<parquet.hadoop.version>1.13.1</parquet.hadoop.version>
<iceberg.version>${spark330.iceberg.version}</iceberg.version>
<slf4j.version>2.0.7</slf4j.version>
</properties>
<modules>
<module>delta-lake/delta-stub</module>
</modules>
</profile>
<profile>
<id>source-javadoc</id>
<build>
@@ -718,6 +738,7 @@
<spark332db.version>3.3.2-databricks</spark332db.version>
<spark341db.version>3.4.1-databricks</spark341db.version>
<spark350.version>3.5.0</spark350.version>
<spark351.version>3.5.1-SNAPSHOT</spark351.version>
<mockito.version>3.12.4</mockito.version>
<scala.plugin.version>4.3.0</scala.plugin.version>
<maven.install.plugin.version>3.1.1</maven.install.plugin.version>
@@ -767,6 +788,7 @@
350
</noSnapshot.buildvers>
<snapshot.buildvers>
351
</snapshot.buildvers>
<databricks.buildvers>
321db,
22 changes: 22 additions & 0 deletions scala2.13/pom.xml
@@ -554,6 +554,26 @@
<module>delta-lake/delta-stub</module>
</modules>
</profile>
<profile>
<id>release351</id>
<activation>
<property>
<name>buildver</name>
<value>351</value>
</property>
</activation>
<properties>
<buildver>351</buildver>
<spark.version>${spark351.version}</spark.version>
<spark.test.version>${spark351.version}</spark.test.version>
<parquet.hadoop.version>1.13.1</parquet.hadoop.version>
<iceberg.version>${spark330.iceberg.version}</iceberg.version>
<slf4j.version>2.0.7</slf4j.version>
</properties>
<modules>
<module>delta-lake/delta-stub</module>
</modules>
</profile>
<profile>
<id>source-javadoc</id>
<build>
@@ -718,6 +738,7 @@
<spark332db.version>3.3.2-databricks</spark332db.version>
<spark341db.version>3.4.1-databricks</spark341db.version>
<spark350.version>3.5.0</spark350.version>
<spark351.version>3.5.1-SNAPSHOT</spark351.version>
<mockito.version>3.12.4</mockito.version>
<scala.plugin.version>4.3.0</scala.plugin.version>
<maven.install.plugin.version>3.1.1</maven.install.plugin.version>
@@ -767,6 +788,7 @@
350
</noSnapshot.buildvers>
<snapshot.buildvers>
351
</snapshot.buildvers>
<databricks.buildvers>
321db,
@@ -23,7 +23,7 @@ import ai.rapids.cudf.ast.BinaryOperator
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.Arm.withResource
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.shims.{GpuTypeShims, ShimExpression, SparkShimImpl}
import com.nvidia.spark.rapids.shims.{DecimalMultiply128, GpuTypeShims, ShimExpression, SparkShimImpl}

import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.{ComplexTypeMergingExpression, ExpectsInputTypes, Expression, NullIntolerant}
@@ -38,7 +38,8 @@ object AddOverflowChecks {
lhs: BinaryOperable,
rhs: BinaryOperable,
ret: ColumnVector): Unit = {
// Check overflow. It is true when both arguments have the opposite sign of the result.
// Check overflow. It is true when both arguments have the same sign and
// the sign of the result is different from the sign of the arguments.
// Which is equal to "((x ^ r) & (y ^ r)) < 0" in the form of arithmetic.
val signCV = withResource(ret.bitXor(lhs)) { lXor =>
withResource(ret.bitXor(rhs)) { rXor =>
@@ -54,7 +55,7 @@
withResource(signDiff.any()) { any =>
if (any.isValid && any.getBoolean) {
throw RapidsErrorUtils.arithmeticOverflowError(
"One or more rows overflow for Add operation."
"One or more rows overflow for Add operation."
)
}
}
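For reference, the addition check above can be exercised on plain two's-complement Longs; a minimal standalone sketch (not part of the diff, names are illustrative):

```scala
// Illustration of the addition overflow check used by AddOverflowChecks,
// on plain Longs instead of cudf columns.
def addOverflows(x: Long, y: Long): Boolean = {
  val r = x + y // wraps around on overflow
  // Overflow occurred iff both operands share a sign that differs from the result's sign.
  ((x ^ r) & (y ^ r)) < 0L
}

assert(addOverflows(Long.MaxValue, 1L)) // wraps to Long.MinValue
assert(!addOverflows(1L, 2L))
```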
@@ -109,6 +110,35 @@
}
}

object SubtractOverflowChecks {
def basicOpOverflowCheck(
lhs: BinaryOperable,
rhs: BinaryOperable,
ret: ColumnVector): Unit = {
// Check overflow. It is true if the arguments have different signs and
// the sign of the result is different from the sign of x.
// Which is equal to "((x ^ y) & (x ^ r)) < 0" in the form of arithmetic.
val signCV = withResource(lhs.bitXor(rhs)) { xyXor =>
withResource(lhs.bitXor(ret)) { xrXor =>
xyXor.bitAnd(xrXor)
}
}
val signDiffCV = withResource(signCV) { sign =>
withResource(Scalar.fromInt(0)) { zero =>
sign.lessThan(zero)
}
}
withResource(signDiffCV) { signDiff =>
withResource(signDiff.any()) { any =>
if (any.isValid && any.getBoolean) {
throw RapidsErrorUtils.
arithmeticOverflowError("One or more rows overflow for Subtract operation.")
}
}
}
}
}
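The subtraction variant uses a different mask because subtraction can only overflow when the operands have different signs; a matching standalone sketch (again illustrative, not from the diff):

```scala
// Illustration of the subtraction overflow check used by SubtractOverflowChecks.
def subtractOverflows(x: Long, y: Long): Boolean = {
  val r = x - y // wraps around on overflow
  // Overflow occurred iff the operands differ in sign and the result's sign differs from x's.
  ((x ^ y) & (x ^ r)) < 0L
}

assert(subtractOverflows(Long.MinValue, 1L)) // wraps to Long.MaxValue
assert(!subtractOverflows(1L, 2L))
```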

object GpuAnsi {
def needBasicOpOverflowCheck(dt: DataType): Boolean =
dt.isInstanceOf[IntegralType]
@@ -289,35 +319,6 @@ abstract class GpuSubtractBase extends CudfBinaryArithmetic with Serializable {
override def binaryOp: BinaryOp = BinaryOp.SUB
override def astOperator: Option[BinaryOperator] = Some(ast.BinaryOperator.SUB)

private[this] def basicOpOverflowCheck(
lhs: BinaryOperable,
rhs: BinaryOperable,
ret: ColumnVector): Unit = {
// Check overflow. It is true if the arguments have different signs and
// the sign of the result is different from the sign of x.
// Which is equal to "((x ^ y) & (x ^ r)) < 0" in the form of arithmetic.

val signCV = withResource(lhs.bitXor(rhs)) { xyXor =>
withResource(lhs.bitXor(ret)) { xrXor =>
xyXor.bitAnd(xrXor)
}
}
val signDiffCV = withResource(signCV) { sign =>
withResource(Scalar.fromInt(0)) { zero =>
sign.lessThan(zero)
}
}
withResource(signDiffCV) { signDiff =>
withResource(signDiff.any()) { any =>
if (any.isValid && any.getBoolean) {
throw RapidsErrorUtils.arithmeticOverflowError(
"One or more rows overflow for Subtract operation."
)
}
}
}
}

private[this] def decimalOpOverflowCheck(
lhs: BinaryOperable,
rhs: BinaryOperable,
@@ -367,7 +368,7 @@ abstract class GpuSubtractBase extends CudfBinaryArithmetic with Serializable {
GpuTypeShims.isSupportedYearMonthType(dataType)) {
// For day time interval, Spark throws an exception when overflow,
// regardless of whether `SQLConf.get.ansiEnabled` is true or false
basicOpOverflowCheck(lhs, rhs, ret)
SubtractOverflowChecks.basicOpOverflowCheck(lhs, rhs, ret)
}

if (dataType.isInstanceOf[DecimalType]) {
@@ -452,7 +453,7 @@ trait GpuDecimalMultiplyBase extends GpuExpression {
rhs.getBase.castTo(DType.create(DType.DTypeEnum.DECIMAL128, rhs.getBase.getType.getScale))
}
withResource(castRhs) { castRhs =>
com.nvidia.spark.rapids.jni.DecimalUtils.multiply128(castLhs, castRhs, -dataType.scale)
DecimalMultiply128(castLhs, castRhs, -dataType.scale)
}
}
val retCol = withResource(retTab) { retTab =>
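The direct call to com.nvidia.spark.rapids.jni.DecimalUtils.multiply128 is now routed through a DecimalMultiply128 shim so each Spark version can bind to the appropriate JNI entry point. A hypothetical sketch of what a pre-3.5.1 shim could look like (the real shim source is not shown in this diff, and its package and signature may differ):

```scala
package com.nvidia.spark.rapids.shims

import ai.rapids.cudf.{ColumnView, Table}

// Hypothetical pre-3.5.1 shim: simply delegates to the JNI helper that was
// called directly before this change. A 3.5.1 shim could bind differently.
object DecimalMultiply128 {
  def apply(lhs: ColumnView, rhs: ColumnView, outputScale: Int): Table =
    com.nvidia.spark.rapids.jni.DecimalUtils.multiply128(lhs, rhs, outputScale)
}
```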
@@ -25,7 +25,7 @@ import com.nvidia.spark.rapids.Arm._
import com.nvidia.spark.rapids.ArrayIndexUtils.firstIndexAndNumElementUnchecked
import com.nvidia.spark.rapids.BoolUtils.isAllValidTrue
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.shims.ShimExpression
import com.nvidia.spark.rapids.shims.{GetSequenceSize, ShimExpression}

import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.{ElementAt, ExpectsInputTypes, Expression, ImplicitCastInputTypes, NamedExpression, NullIntolerant, RowOrdering, Sequence, TimeZoneAwareExpression}
@@ -1311,7 +1311,7 @@

object GpuSequenceUtil {

private def checkSequenceInputs(
def checkSequenceInputs(
start: ColumnVector,
stop: ColumnVector,
step: ColumnVector): Unit = {
@@ -1371,62 +1371,38 @@ object GpuSequenceUtil {
*
* The returned column should be closed.
*/
def computeSequenceSizes(
def computeSequenceSize(
start: ColumnVector,
stop: ColumnVector,
step: ColumnVector): ColumnVector = {
checkSequenceInputs(start, stop, step)

// Spark's algorithm to get the length (aka size)
// ``` Scala
// size = if (start == stop) 1L else 1L + (stop.toLong - start.toLong) / estimatedStep.toLong
// require(size <= MAX_ROUNDED_ARRAY_LENGTH,
// s"Too long sequence: $len. Should be <= $MAX_ROUNDED_ARRAY_LENGTH")
// size.toInt
// ```
val sizeAsLong = withResource(Scalar.fromLong(1L)) { one =>
val diff = withResource(stop.castTo(DType.INT64)) { stopAsLong =>
withResource(start.castTo(DType.INT64)) { startAsLong =>
stopAsLong.sub(startAsLong)
}
}
val quotient = withResource(diff) { _ =>
withResource(step.castTo(DType.INT64)) { stepAsLong =>
diff.div(stepAsLong)
val actualSize = GetSequenceSize(start, stop, step)
val sizeAsLong = withResource(actualSize) { _ =>
val mergedEquals = withResource(start.equalTo(stop)) { equals =>
if (step.hasNulls) {
// Also set the row to null where step is null.
equals.mergeAndSetValidity(BinaryOp.BITWISE_AND, equals, step)
} else {
equals.incRefCount()
}
}
// actualSize = 1L + (stop.toLong - start.toLong) / estimatedStep.toLong
val actualSize = withResource(quotient) { quotient =>
quotient.add(one, DType.INT64)
}
withResource(actualSize) { _ =>
val mergedEquals = withResource(start.equalTo(stop)) { equals =>
if (step.hasNulls) {
// Also set the row to null where step is null.
equals.mergeAndSetValidity(BinaryOp.BITWISE_AND, equals, step)
} else {
equals.incRefCount()
}
}
withResource(mergedEquals) { _ =>
withResource(mergedEquals) { _ =>
withResource(Scalar.fromLong(1L)) { one =>
mergedEquals.ifElse(one, actualSize)
}
}
}

withResource(sizeAsLong) { _ =>
// check max size
withResource(Scalar.fromInt(MAX_ROUNDED_ARRAY_LENGTH)) { maxLen =>
withResource(sizeAsLong.lessOrEqualTo(maxLen)) { allValid =>
require(isAllValidTrue(allValid),
s"Too long sequence found. Should be <= $MAX_ROUNDED_ARRAY_LENGTH")
require(isAllValidTrue(allValid), GetSequenceSize.TOO_LONG_SEQUENCE)
}
}
// cast to int and return
sizeAsLong.castTo(DType.INT32)
}
}

}
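The inlined size computation that was removed above is now hidden behind a GetSequenceSize shim, which lets Spark 3.5.1 both compute the size differently and surface its new error message ("Unsuccessful try to create array with..."). A sketch of a pre-3.5.1 shim reconstructed from the removed lines (the actual shim source is not shown in this diff):

```scala
package com.nvidia.spark.rapids.shims

import ai.rapids.cudf.{ColumnVector, DType, Scalar}
import com.nvidia.spark.rapids.Arm.withResource

import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

// Sketch of a pre-3.5.1 GetSequenceSize shim, reconstructed from the removed code above.
object GetSequenceSize {
  // Message expected by the updated integration test for Spark < 3.5.1.
  val TOO_LONG_SEQUENCE = s"Too long sequence found. Should be <= $MAX_ROUNDED_ARRAY_LENGTH"

  // actualSize = 1L + (stop.toLong - start.toLong) / estimatedStep.toLong
  def apply(start: ColumnVector, stop: ColumnVector, step: ColumnVector): ColumnVector = {
    val diff = withResource(stop.castTo(DType.INT64)) { stopAsLong =>
      withResource(start.castTo(DType.INT64)) { startAsLong =>
        stopAsLong.sub(startAsLong)
      }
    }
    val quotient = withResource(diff) { _ =>
      withResource(step.castTo(DType.INT64)) { stepAsLong =>
        diff.div(stepAsLong)
      }
    }
    withResource(quotient) { _ =>
      withResource(Scalar.fromLong(1L)) { one =>
        quotient.add(one, DType.INT64)
      }
    }
  }
}
```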

case class GpuSequence(start: Expression, stop: Expression, stepOpt: Option[Expression],
@@ -1460,7 +1436,7 @@ case class GpuSequence(start: Expression, stop: Expression, stepOpt: Option[Expr
val steps = stepGpuColOpt.map(_.getBase.incRefCount())
.getOrElse(defaultStepsFunc(startCol, stopCol))
closeOnExcept(steps) { _ =>
(computeSequenceSizes(startCol, stopCol, steps), steps)
(computeSequenceSize(startCol, stopCol, steps), steps)
}
}

@@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.rapids.execution.python

import scala.collection.JavaConverters.seqAsJavaListConverter
@@ -33,6 +33,7 @@
{"spark": "340"}
{"spark": "341"}
{"spark": "350"}
{"spark": "351"}
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.shims

@@ -33,6 +33,7 @@
{"spark": "340"}
{"spark": "341"}
{"spark": "350"}
{"spark": "351"}
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.shims

