
Add 3.5.1-SNAPSHOT Shim #9962

Merged · 17 commits · Dec 25, 2023
Conversation

@razajafri (Collaborator) commented Dec 5, 2023

This PR adds shims for Spark 3.5.1-SNAPSHOT.

Changes Made:

  • The following Shimplify command was run:
mvn generate-sources -Dshimplify=true -Dshimplify.move=true -Dshimplify.overwrite=true -Dshimplify.add.shim=351 -Dshimplify.add.base=350

The only files changed manually were pom.xml and ShimServiceProvider.scala, to add the SNAPSHOT version to the VERSIONNAMES. Some empty lines were also removed as a result of the Shimplify command above.

  • Added DecimalUtilShims.scala, which calls the appropriate multiplication method depending on the Spark version. In Spark 3.5.1 (and later versions) the multiplication doesn't perform an interim cast, and as part of a spark-rapids-jni PR another method called mul128 was added which skips the interim cast (a rough sketch follows this list).
  • Added ComputeSequenceSize.scala to shim the new method that computes the sequence size and ensures it is within the limit.
  • Made the relevant changes to GpuBatchScanExec to match the changes in Spark.
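
For readers unfamiliar with how the per-version dispatch works, here is a rough sketch of the DecimalUtilShims idea. It is illustrative only: the object layout, method name, signature, and source-directory placement are assumptions rather than the exact code in this PR; multiply128 is the pre-existing spark-rapids-jni entry point, and mul128 is the new one mentioned above (its signature is assumed to mirror multiply128).

```scala
import ai.rapids.cudf.{ColumnView, Table}
import com.nvidia.spark.rapids.jni.DecimalUtils

// Sketch of the spark351 copy of the shim. Shimplify compiles a different
// copy of this file per Spark version, so callers simply use
// DecimalUtilShims.multiply and get the version-appropriate behavior.
object DecimalUtilShims {
  // Spark 3.5.1 no longer performs the interim cast, so call the new
  // mul128 entry point. The pre-3.5.1 copy of this object would call
  // DecimalUtils.multiply128 instead, which keeps the interim cast.
  def multiply(lhs: ColumnView, rhs: ColumnView, outputScale: Int): Table =
    DecimalUtils.mul128(lhs, rhs, outputScale)
}
```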

Tests:
All integration tests were run locally

fixes #9258
fixes #9859
fixes #9875
fixes #9743

jlowe previously approved these changes Dec 5, 2023
andygrove previously approved these changes Dec 5, 2023
@sameerz added the feature request label Dec 5, 2023
@razajafri dismissed stale reviews from andygrove and jlowe via 175dc8d December 7, 2023 01:13
@tgravescs (Collaborator)

I'm assuming your decimal multiply change is related to #9859... If so, please make sure it fixes it all the way, or comment on that issue. The shim is very hard to read; one calls mul128, the other calls multiply128. I haven't gone and looked at those, but it's hard to even see that diff, so you should at the very least add a comment or point to the issue and explain.

@razajafri (Collaborator, Author)

> The shim is very hard to read; one calls mul128, the other calls multiply128. I haven't gone and looked at those, but it's hard to even see that diff, so you should at the very least add a comment or point to the issue and explain.

I will go ahead and put in some comments to highlight the change

@razajafri (Collaborator, Author)

> I'm assuming your decimal multiply change is related to #9859... If so, please make sure it fixes it all the way, or comment on that issue.

Discussed this offline. I missed the division bit of the puzzle. Will verify division and post an update here

@razajafri (Collaborator, Author)

> I'm assuming your decimal multiply change is related to #9859... If so, please make sure it fixes it all the way, or comment on that issue.

> Discussed this offline. I missed the division bit of the puzzle. Will verify division and post an update here

I have verified the Decimal division and we match Spark 3.5.1 output.

It turns out that we were always doing the right thing on the GPU for decimal division. So, to match Spark bug for bug, we should "fix" Databricks 330+ and Spark 340+ by returning the bad answer. I have created an issue for that here.

@razajafri (Collaborator, Author)

build

@razajafri (Collaborator, Author)

Premerge is failing due to an unrelated change:

[2023-12-18T20:02:45.179Z] [ERROR] /home/ubuntu/spark-rapids/tests/src/test/scala/org/apache/spark/sql/rapids/metrics/source/MockTaskContext.scala:69: overriding method getKillReason in class TaskContext of type ()Option[String];
[2023-12-18T20:02:45.180Z]  method getKillReason has weaker access privileges; it should be public
[2023-12-18T20:02:45.180Z] [ERROR]   override private[spark] def getKillReason() = None
[2023-12-18T20:02:45.180Z] [ERROR]                               ^
[2023-12-18T20:02:45.180Z] [ERROR] one error found
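
For context on what that compile error means, here is a hedged before/after sketch (not necessarily the actual fix, which happened outside this PR): the mock's override declares weaker visibility than TaskContext.getKillReason, so widening the override to public resolves the error.

```scala
// Before: fails to compile against a TaskContext where getKillReason is public,
// because the override narrows visibility to private[spark].
override private[spark] def getKillReason() = None

// After (sketch): match the parent's visibility; the result type is made explicit.
override def getKillReason(): Option[String] = None
```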

@razajafri (Collaborator, Author)

@NvTimLiu @pxLi why is premerge picking up Databricks 13.3?

@razajafri (Collaborator, Author)

build

@razajafri (Collaborator, Author)

build

@razajafri (Collaborator, Author)

> @NvTimLiu @pxLi why is premerge picking up Databricks 13.3?

Never mind, I saw that there is a function that checks whether there are any changes in the Databricks shims, in addition to checking for the word "databricks" in the PR title.

@razajafri (Collaborator, Author)

build

@razajafri changed the title from "Added 3.5.1-SNAPSHOT Shim" to "Add 3.5.1-SNAPSHOT Shim" Dec 20, 2023
@razajafri (Collaborator, Author)

build

@razajafri marked this pull request as ready for review December 20, 2023 16:46
@razajafri (Collaborator, Author)

build

@razajafri (Collaborator, Author)

I have reverted the tests for versions that we don't support yet. They will be added in other shims.

@razajafri (Collaborator, Author)

build

@razajafri (Collaborator, Author)

@andygrove can you PTAL?

andygrove previously approved these changes Dec 20, 2023
Comment on lines 57 to 58
throw RapidsErrorUtils.
arithmeticOverflowError("One or more rows overflow for Add operation.")

Let's leave formatting-only changes to dedicated PRs.

Comment on lines 85 to 111
    withResource(actualSize) { _ =>
      val mergedEquals = withResource(start.equalTo(stop)) { equals =>
        if (step.hasNulls) {
          // Also set the row to null where step is null.
          equals.mergeAndSetValidity(BinaryOp.BITWISE_AND, equals, step)
        } else {
          equals.incRefCount()
        }
      }
      withResource(mergedEquals) { _ =>
        mergedEquals.ifElse(one, actualSize)
      }
    }
  }

  withResource(sizeAsLong) { _ =>
    // check max size
    withResource(Scalar.fromInt(MAX_ROUNDED_ARRAY_LENGTH)) { maxLen =>
      withResource(sizeAsLong.lessOrEqualTo(maxLen)) { allValid =>
        require(isAllValidTrue(allValid),
          s"Too long sequence found. Should be <= $MAX_ROUNDED_ARRAY_LENGTH")
      }
    }
    // cast to int and return
    sizeAsLong.castTo(DType.INT32)
  }
}

The bottom portion (L85-L111 in 311 and L98-L126 in 351) differs only in the require message; let us refactor to minimize the shimming.
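
A hedged sketch of one way that refactor could look (the object name, method name, and parameters are hypothetical, and it assumes the same withResource/isAllValidTrue helpers used by the snippet above are in scope): keep the size check in common, unshimmed code and let each shim supply only its own overflow message.

```scala
import ai.rapids.cudf.{ColumnVector, DType, Scalar}

// Hypothetical shared helper: identical logic for every Spark version,
// with only the overflow error message differing per shim.
object SequenceSizeCommon {
  def checkSequenceSize(
      sizeAsLong: ColumnVector,
      maxLength: Int,
      tooLongMessage: String): ColumnVector = {
    // verify every computed size is within the allowed limit
    withResource(Scalar.fromInt(maxLength)) { maxLen =>
      withResource(sizeAsLong.lessOrEqualTo(maxLen)) { allValid =>
        require(isAllValidTrue(allValid), tooLongMessage)
      }
    }
    // cast to int and return
    sizeAsLong.castTo(DType.INT32)
  }
}
```

Each shim would then call SequenceSizeCommon.checkSequenceSize(sizeAsLong, MAX_ROUNDED_ARRAY_LENGTH, <version-specific message>).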


This file should be dropped thanks to #9902

@razajafri (Collaborator, Author)

Thanks for taking a look @gerashegalov

PTAL

@razajafri (Collaborator, Author)

build

@NvTimLiu (Collaborator) left a comment

LGTM

@NvTimLiu (Collaborator)

Do we plan to run nightly integration tests against spark-3.5.1-SNAPSHOT?

@razajafri merged commit 11a91d4 into NVIDIA:branch-24.02 Dec 25, 2023
39 checks passed
@razajafri deleted the SP-9258-351-shim branch December 25, 2023 07:18
@razajafri (Collaborator, Author)

> Do we plan to run nightly integration tests against spark-3.5.1-SNAPSHOT?

Yes, we do.

Labels: feature request (New feature or request)
Projects: none yet
7 participants