diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index b5c1462096737..3ee66c95edb99 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -39,6 +39,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted} import org.apache.spark.sql.{DataFrame, Encoder, Row, SparkSession} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.functions.col +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.StreamingQueryException import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel @@ -220,7 +221,9 @@ class ALSSuite extends MLTest with DefaultReadWriteTest with Logging { (1231L, 12L, 0.5), (1112L, 21L, 1.0) )).toDF("item", "user", "rating") - new ALS().setMaxIter(1).fit(df) + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + new ALS().setMaxIter(1).fit(df) + } } withClue("Valid Double Ids") { @@ -719,40 +722,42 @@ class ALSSuite extends MLTest with DefaultReadWriteTest with Logging { (1, 1L, 1d, 0, 0L, 0d, 5.0) ).toDF("user", "user_big", "user_small", "item", "item_big", "item_small", "rating") val msg = "ALS only supports non-Null values" - withClue("fit should fail when ids exceed integer range. ") { - assert(intercept[Exception] { - als.fit(df.select(df("user_big").as("user"), df("item"), df("rating"))) - }.getMessage.contains(msg)) - assert(intercept[Exception] { - als.fit(df.select(df("user_small").as("user"), df("item"), df("rating"))) - }.getMessage.contains(msg)) - assert(intercept[Exception] { - als.fit(df.select(df("item_big").as("item"), df("user"), df("rating"))) - }.getMessage.contains(msg)) - assert(intercept[Exception] { - als.fit(df.select(df("item_small").as("item"), df("user"), df("rating"))) - }.getMessage.contains(msg)) - } - withClue("transform should fail when ids exceed integer range. ") { - val model = als.fit(df) - def testTransformIdExceedsIntRange[A : Encoder](dataFrame: DataFrame): Unit = { - val e1 = intercept[Exception] { - model.transform(dataFrame).collect() - } - TestUtils.assertExceptionMsg(e1, msg) - val e2 = intercept[StreamingQueryException] { - testTransformer[A](dataFrame, model, "prediction") { _ => } + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + withClue("fit should fail when ids exceed integer range. ") { + assert(intercept[Exception] { + als.fit(df.select(df("user_big").as("user"), df("item"), df("rating"))) + }.getMessage.contains(msg)) + assert(intercept[Exception] { + als.fit(df.select(df("user_small").as("user"), df("item"), df("rating"))) + }.getMessage.contains(msg)) + assert(intercept[Exception] { + als.fit(df.select(df("item_big").as("item"), df("user"), df("rating"))) + }.getMessage.contains(msg)) + assert(intercept[Exception] { + als.fit(df.select(df("item_small").as("item"), df("user"), df("rating"))) + }.getMessage.contains(msg)) + } + withClue("transform should fail when ids exceed integer range. ") { + val model = als.fit(df) + def testTransformIdExceedsIntRange[A : Encoder](dataFrame: DataFrame): Unit = { + val e1 = intercept[Exception] { + model.transform(dataFrame).collect() + } + TestUtils.assertExceptionMsg(e1, msg) + val e2 = intercept[StreamingQueryException] { + testTransformer[A](dataFrame, model, "prediction") { _ => } + } + TestUtils.assertExceptionMsg(e2, msg) } - TestUtils.assertExceptionMsg(e2, msg) + testTransformIdExceedsIntRange[(Long, Int)](df.select(df("user_big").as("user"), + df("item"))) + testTransformIdExceedsIntRange[(Double, Int)](df.select(df("user_small").as("user"), + df("item"))) + testTransformIdExceedsIntRange[(Long, Int)](df.select(df("item_big").as("item"), + df("user"))) + testTransformIdExceedsIntRange[(Double, Int)](df.select(df("item_small").as("item"), + df("user"))) } - testTransformIdExceedsIntRange[(Long, Int)](df.select(df("user_big").as("user"), - df("item"))) - testTransformIdExceedsIntRange[(Double, Int)](df.select(df("user_small").as("user"), - df("item"))) - testTransformIdExceedsIntRange[(Long, Int)](df.select(df("item_big").as("item"), - df("user"))) - testTransformIdExceedsIntRange[(Double, Int)](df.select(df("item_small").as("item"), - df("user"))) } }