[SPARK-38776][MLLIB][TESTS] Disable ANSI_ENABLED explicitly in `ALSSuite`

### What changes were proposed in this pull request?

This PR aims to disable `ANSI_ENABLED` explicitly in the following tests of `ALSSuite`.
```
test("ALS validate input dataset") {
test("input type validation") {
```
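
For context, a minimal sketch of the pattern this patch applies (see the diff below). `withSQLConf` comes from Spark's shared SQL test helpers and is available inside suites such as `ALSSuite`; it sets the given conf entries only for the duration of the block and restores the previous values afterwards, even if the body throws. `df` stands for the (item, user, rating) training DataFrame built in the test.
```scala
import org.apache.spark.ml.recommendation.ALS
import org.apache.spark.sql.internal.SQLConf

// Run the fit with ANSI mode explicitly off, regardless of the suite-wide
// default; the previous ANSI setting is restored when the block exits.
withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
  new ALS().setMaxIter(1).fit(df)  // Long ids are cast to Int without ANSI overflow errors
}
```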

### Why are the changes needed?

After SPARK-38490, these tests became flaky in the ANSI-mode GitHub Action job.

![Screen Shot 2022-04-03 at 12 07 29 AM](https://user-images.githubusercontent.com/9700541/161416006-7b76596f-c19a-4212-91d2-8602df569608.png)

- https://github.com/apache/spark/runs/5800714463?check_suite_focus=true
- https://github.com/apache/spark/runs/5803714260?check_suite_focus=true
- https://github.com/apache/spark/runs/5803745768?check_suite_focus=true

```
[info] ALSSuite:
...
[info] - ALS validate input dataset *** FAILED *** (2 seconds, 449 milliseconds)
[info]   Invalid Long: out of range "Job aborted due to stage failure: Task 0 in stage 100.0 failed 1 times, most recent failure: Lost task 0.0 in stage 100.0 (TID 348) (localhost executor driver):
org.apache.spark.SparkArithmeticException:
Casting 1231000000000 to int causes overflow.
To return NULL instead, use 'try_cast'.
If necessary set spark.sql.ansi.enabled to false to bypass this error.
```
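
The underlying behavior difference is easy to reproduce outside the suite. A minimal spark-shell style sketch (assuming a `SparkSession` named `spark`; not part of the patch):
```scala
import org.apache.spark.sql.functions.col
import spark.implicits._  // assumes the spark-shell session's implicits

// 1231000000000 does not fit in an Int (Int.MaxValue is 2147483647).
val wide = Seq(1231000000000L).toDF("item")

spark.conf.set("spark.sql.ansi.enabled", "true")
wide.select(col("item").cast("int")).collect()   // throws SparkArithmeticException: cast overflow

spark.conf.set("spark.sql.ansi.enabled", "false")
wide.select(col("item").cast("int")).collect()   // legacy behavior: the value silently wraps
```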

### Does this PR introduce _any_ user-facing change?

No. Both the bug and the fix are test-only.

### How was this patch tested?

Pass the CIs.
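
(For local verification, the suite can also be run on its own via the usual Spark developer workflow, e.g. `build/sbt "mllib/testOnly org.apache.spark.ml.recommendation.ALSSuite"`.)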

Closes #36051 from dongjoon-hyun/SPARK-38776.

Authored-by: Dongjoon Hyun <dongjoon@apache.org>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
dongjoon-hyun committed Apr 3, 2022
1 parent acbfd03 commit d18fd7b
Showing 1 changed file with 38 additions and 33 deletions.
```diff
mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -39,6 +39,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted}
 import org.apache.spark.sql.{DataFrame, Encoder, Row, SparkSession}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.StreamingQueryException
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
@@ -220,7 +221,9 @@ class ALSSuite extends MLTest with DefaultReadWriteTest with Logging {
         (1231L, 12L, 0.5),
         (1112L, 21L, 1.0)
       )).toDF("item", "user", "rating")
-      new ALS().setMaxIter(1).fit(df)
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
+        new ALS().setMaxIter(1).fit(df)
+      }
     }
 
     withClue("Valid Double Ids") {
@@ -719,40 +722,42 @@ class ALSSuite extends MLTest with DefaultReadWriteTest with Logging {
       (1, 1L, 1d, 0, 0L, 0d, 5.0)
     ).toDF("user", "user_big", "user_small", "item", "item_big", "item_small", "rating")
     val msg = "ALS only supports non-Null values"
-    withClue("fit should fail when ids exceed integer range. ") {
-      assert(intercept[Exception] {
-        als.fit(df.select(df("user_big").as("user"), df("item"), df("rating")))
-      }.getMessage.contains(msg))
-      assert(intercept[Exception] {
-        als.fit(df.select(df("user_small").as("user"), df("item"), df("rating")))
-      }.getMessage.contains(msg))
-      assert(intercept[Exception] {
-        als.fit(df.select(df("item_big").as("item"), df("user"), df("rating")))
-      }.getMessage.contains(msg))
-      assert(intercept[Exception] {
-        als.fit(df.select(df("item_small").as("item"), df("user"), df("rating")))
-      }.getMessage.contains(msg))
-    }
-    withClue("transform should fail when ids exceed integer range. ") {
-      val model = als.fit(df)
-      def testTransformIdExceedsIntRange[A : Encoder](dataFrame: DataFrame): Unit = {
-        val e1 = intercept[Exception] {
-          model.transform(dataFrame).collect()
-        }
-        TestUtils.assertExceptionMsg(e1, msg)
-        val e2 = intercept[StreamingQueryException] {
-          testTransformer[A](dataFrame, model, "prediction") { _ => }
-        }
-        TestUtils.assertExceptionMsg(e2, msg)
-      }
-      testTransformIdExceedsIntRange[(Long, Int)](df.select(df("user_big").as("user"),
-        df("item")))
-      testTransformIdExceedsIntRange[(Double, Int)](df.select(df("user_small").as("user"),
-        df("item")))
-      testTransformIdExceedsIntRange[(Long, Int)](df.select(df("item_big").as("item"),
-        df("user")))
-      testTransformIdExceedsIntRange[(Double, Int)](df.select(df("item_small").as("item"),
-        df("user")))
-    }
+    withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
+      withClue("fit should fail when ids exceed integer range. ") {
+        assert(intercept[Exception] {
+          als.fit(df.select(df("user_big").as("user"), df("item"), df("rating")))
+        }.getMessage.contains(msg))
+        assert(intercept[Exception] {
+          als.fit(df.select(df("user_small").as("user"), df("item"), df("rating")))
+        }.getMessage.contains(msg))
+        assert(intercept[Exception] {
+          als.fit(df.select(df("item_big").as("item"), df("user"), df("rating")))
+        }.getMessage.contains(msg))
+        assert(intercept[Exception] {
+          als.fit(df.select(df("item_small").as("item"), df("user"), df("rating")))
+        }.getMessage.contains(msg))
+      }
+      withClue("transform should fail when ids exceed integer range. ") {
+        val model = als.fit(df)
+        def testTransformIdExceedsIntRange[A : Encoder](dataFrame: DataFrame): Unit = {
+          val e1 = intercept[Exception] {
+            model.transform(dataFrame).collect()
+          }
+          TestUtils.assertExceptionMsg(e1, msg)
+          val e2 = intercept[StreamingQueryException] {
+            testTransformer[A](dataFrame, model, "prediction") { _ => }
+          }
+          TestUtils.assertExceptionMsg(e2, msg)
+        }
+        testTransformIdExceedsIntRange[(Long, Int)](df.select(df("user_big").as("user"),
+          df("item")))
+        testTransformIdExceedsIntRange[(Double, Int)](df.select(df("user_small").as("user"),
+          df("item")))
+        testTransformIdExceedsIntRange[(Long, Int)](df.select(df("item_big").as("item"),
+          df("user")))
+        testTransformIdExceedsIntRange[(Double, Int)](df.select(df("item_small").as("item"),
+          df("user")))
+      }
+    }
   }
```
