From d18fd7bcbdfe028e2e985ec6a8ec2f78bd5599c4 Mon Sep 17 00:00:00 2001
From: Dongjoon Hyun
Date: Sun, 3 Apr 2022 10:10:15 -0700
Subject: [PATCH] [SPARK-38776][MLLIB][TESTS] Disable ANSI_ENABLED explicitly
 in `ALSSuite`

### What changes were proposed in this pull request?

This PR aims to disable `ANSI_ENABLED` explicitly in the following tests of `ALSSuite`:

```
test("ALS validate input dataset") {
test("input type validation") {
```

### Why are the changes needed?

After SPARK-38490, these tests became flaky in the ANSI-mode GitHub Action jobs.

![Screen Shot 2022-04-03 at 12 07 29 AM](https://user-images.githubusercontent.com/9700541/161416006-7b76596f-c19a-4212-91d2-8602df569608.png)

- https://github.com/apache/spark/runs/5800714463?check_suite_focus=true
- https://github.com/apache/spark/runs/5803714260?check_suite_focus=true
- https://github.com/apache/spark/runs/5803745768?check_suite_focus=true

```
[info] ALSSuite:
...
[info] - ALS validate input dataset *** FAILED *** (2 seconds, 449 milliseconds)
[info]   Invalid Long: out of range "Job aborted due to stage failure: Task 0 in stage 100.0 failed 1 times, most recent failure: Lost task 0.0 in stage 100.0 (TID 348) (localhost executor driver): org.apache.spark.SparkArithmeticException: Casting 1231000000000 to int causes overflow. To return NULL instead, use 'try_cast'. If necessary set spark.sql.ansi.enabled to false to bypass this error.
```
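For context, here is a minimal sketch (not part of this patch) of the cast behavior behind the failure, assuming a `spark-shell` session where `spark` is the active `SparkSession`:

```scala
import spark.implicits._
import org.apache.spark.sql.functions.col

// The out-of-range value from the failure log above.
val df = Seq(1231000000000L).toDF("item")

// With ANSI mode on, casting a Long outside Int range throws
// SparkArithmeticException ("Casting 1231000000000 to int causes overflow"),
// so the test assertion sees that error instead of the message it expects.
spark.conf.set("spark.sql.ansi.enabled", "true")
// df.select(col("item").cast("int")).collect()  // would throw here

// With ANSI mode off, the cast silently wraps around like a Java (int) cast,
// which is the legacy behavior these `ALSSuite` assertions were written against.
spark.conf.set("spark.sql.ansi.enabled", "false")
df.select(col("item").cast("int")).collect()
```

The fix below pins the conf only for the affected test bodies via `withSQLConf`, a test helper that sets the given entries, runs the block, and restores the previous values afterwards, so the rest of the suite still runs under whatever ANSI setting the CI job selects.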
") { - assert(intercept[Exception] { - als.fit(df.select(df("user_big").as("user"), df("item"), df("rating"))) - }.getMessage.contains(msg)) - assert(intercept[Exception] { - als.fit(df.select(df("user_small").as("user"), df("item"), df("rating"))) - }.getMessage.contains(msg)) - assert(intercept[Exception] { - als.fit(df.select(df("item_big").as("item"), df("user"), df("rating"))) - }.getMessage.contains(msg)) - assert(intercept[Exception] { - als.fit(df.select(df("item_small").as("item"), df("user"), df("rating"))) - }.getMessage.contains(msg)) - } - withClue("transform should fail when ids exceed integer range. ") { - val model = als.fit(df) - def testTransformIdExceedsIntRange[A : Encoder](dataFrame: DataFrame): Unit = { - val e1 = intercept[Exception] { - model.transform(dataFrame).collect() - } - TestUtils.assertExceptionMsg(e1, msg) - val e2 = intercept[StreamingQueryException] { - testTransformer[A](dataFrame, model, "prediction") { _ => } + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + withClue("fit should fail when ids exceed integer range. ") { + assert(intercept[Exception] { + als.fit(df.select(df("user_big").as("user"), df("item"), df("rating"))) + }.getMessage.contains(msg)) + assert(intercept[Exception] { + als.fit(df.select(df("user_small").as("user"), df("item"), df("rating"))) + }.getMessage.contains(msg)) + assert(intercept[Exception] { + als.fit(df.select(df("item_big").as("item"), df("user"), df("rating"))) + }.getMessage.contains(msg)) + assert(intercept[Exception] { + als.fit(df.select(df("item_small").as("item"), df("user"), df("rating"))) + }.getMessage.contains(msg)) + } + withClue("transform should fail when ids exceed integer range. ") { + val model = als.fit(df) + def testTransformIdExceedsIntRange[A : Encoder](dataFrame: DataFrame): Unit = { + val e1 = intercept[Exception] { + model.transform(dataFrame).collect() + } + TestUtils.assertExceptionMsg(e1, msg) + val e2 = intercept[StreamingQueryException] { + testTransformer[A](dataFrame, model, "prediction") { _ => } + } + TestUtils.assertExceptionMsg(e2, msg) } - TestUtils.assertExceptionMsg(e2, msg) + testTransformIdExceedsIntRange[(Long, Int)](df.select(df("user_big").as("user"), + df("item"))) + testTransformIdExceedsIntRange[(Double, Int)](df.select(df("user_small").as("user"), + df("item"))) + testTransformIdExceedsIntRange[(Long, Int)](df.select(df("item_big").as("item"), + df("user"))) + testTransformIdExceedsIntRange[(Double, Int)](df.select(df("item_small").as("item"), + df("user"))) } - testTransformIdExceedsIntRange[(Long, Int)](df.select(df("user_big").as("user"), - df("item"))) - testTransformIdExceedsIntRange[(Double, Int)](df.select(df("user_small").as("user"), - df("item"))) - testTransformIdExceedsIntRange[(Long, Int)](df.select(df("item_big").as("item"), - df("user"))) - testTransformIdExceedsIntRange[(Double, Int)](df.select(df("item_small").as("item"), - df("user"))) } }