diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala
index 0b75907670..06e9d2278a 100644
--- a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala
+++ b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala
@@ -20,7 +20,8 @@
 package org.apache.comet.shims
 
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
+import org.apache.spark.sql.catalyst.expressions.json.StructsToJsonEvaluator
+import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.types.StringTypeWithCollation
 import org.apache.spark.sql.types.{BinaryType, BooleanType, DataTypes, StringType}
@@ -114,6 +115,19 @@ trait CometExprShim extends CommonStringExprs {
       case k: KnownNotContainsNull =>
         exprToProtoInternal(k.child, inputs, binding)
 
+      // In Spark 4.0, StructsToJson is a RuntimeReplaceable whose replacement is
+      // Invoke(Literal(StructsToJsonEvaluator), "evaluate", ...). Reconstruct the
+      // original StructsToJson and recurse so support-level checks apply.
+      case i: Invoke =>
+        (i.targetObject, i.functionName, i.arguments) match {
+          case (Literal(evaluator: StructsToJsonEvaluator, _), "evaluate", Seq(child)) =>
+            exprToProtoInternal(
+              StructsToJson(evaluator.options, child, evaluator.timeZoneId),
+              inputs,
+              binding)
+          case _ => None
+        }
+
       case _ => None
     }
   }
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index c3d66340c4..9091b871c5 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -2271,8 +2271,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   }
 
   test("to_json") {
-    // TODO fix for Spark 4.0.0
-    assume(!isSpark40Plus)
     withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[StructsToJson]) -> "true") {
       Seq(true, false).foreach { dictionaryEnabled =>
         withParquetTable(
@@ -2298,8 +2296,6 @@
   }
 
   test("to_json escaping of field names and string values") {
-    // TODO fix for Spark 4.0.0
-    assume(!isSpark40Plus)
     withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[StructsToJson]) -> "true") {
       val gen = new DataGenerator(new Random(42))
       val chars = "\\'\"abc\t\r\n\f\b"
@@ -2329,8 +2325,6 @@
   }
 
   test("to_json unicode") {
-    // TODO fix for Spark 4.0.0
-    assume(!isSpark40Plus)
     withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[StructsToJson]) -> "true") {
       Seq(true, false).foreach { dictionaryEnabled =>
         withParquetTable(
diff --git a/spark/src/test/scala/org/apache/comet/CometJsonExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometJsonExpressionSuite.scala
index 64c330dbdd..c2c1ce3663 100644
--- a/spark/src/test/scala/org/apache/comet/CometJsonExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometJsonExpressionSuite.scala
@@ -30,7 +30,6 @@
 import org.apache.spark.sql.catalyst.expressions.{JsonToStructs, StructsToJson}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.functions._
 
-import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus
 import org.apache.comet.serde.CometStructsToJson
 import org.apache.comet.testing.{DataGenOptions, ParquetGenerator, SchemaGenOptions}
@@ -48,7 +47,6 @@
   }
 
   test("to_json - all supported types") {
-    assume(!isSpark40Plus)
     withTempDir { dir =>
       val path = new Path(dir.toURI.toString, "test.parquet")
       val filename = path.toString