diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionUtils.java
index 07e13610aa950..c930d5aa77094 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionUtils.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionUtils.java
@@ -21,12 +21,22 @@
 import java.util.ArrayList;
 import java.util.List;
 
+import scala.Option;
+
+import com.fasterxml.jackson.core.JsonFactory;
 import com.fasterxml.jackson.core.JsonParser;
 import com.fasterxml.jackson.core.JsonToken;
 
 import org.apache.spark.sql.catalyst.expressions.SharedFactory;
 import org.apache.spark.sql.catalyst.json.CreateJacksonParser;
+import org.apache.spark.sql.catalyst.json.JSONOptions;
+import org.apache.spark.sql.catalyst.json.JsonInferSchema;
 import org.apache.spark.sql.catalyst.util.GenericArrayData;
+import org.apache.spark.sql.internal.SQLConf;
+import org.apache.spark.sql.types.ArrayType;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
 import org.apache.spark.unsafe.types.UTF8String;
 
 public class JsonExpressionUtils {
@@ -86,4 +96,35 @@ public static GenericArrayData jsonObjectKeys(UTF8String json) {
       return null;
     }
   }
+
+  public static UTF8String schemaOfJson(
+      JsonFactory jsonFactory,
+      JSONOptions jsonOptions,
+      JsonInferSchema jsonInferSchema,
+      UTF8String json) {
+    DataType schema;
+    try (JsonParser jsonParser = CreateJacksonParser.utf8String(jsonFactory, json)) {
+      jsonParser.nextToken();
+      // To match with schema inference from JSON datasource.
+      DataType inferSchema = jsonInferSchema.inferField(jsonParser);
+      if (inferSchema instanceof StructType) {
+        Option<DataType> canonicalType = jsonInferSchema.canonicalizeType(inferSchema, jsonOptions);
+        schema = canonicalType.isDefined() ?
+          canonicalType.get() : new StructType(new StructField[0]);
+      } else if (inferSchema instanceof ArrayType at && at.elementType() instanceof StructType et) {
+        Option<DataType> canonicalType = jsonInferSchema.canonicalizeType(et, jsonOptions)
+          .map(dt -> ArrayType.apply(dt, at.containsNull()));
+        schema = canonicalType.isDefined() ? canonicalType.get() :
+          ArrayType.apply(new StructType(new StructField[0]), at.containsNull());
+      } else {
+        Option<DataType> canonicalType = jsonInferSchema.canonicalizeType(inferSchema, jsonOptions);
+        schema = canonicalType.isDefined() ?
+          canonicalType.get() : SQLConf.get().defaultStringType();
+      }
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+
+    return UTF8String.fromString(schema.sql());
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
index 08cb03edb78b6..38b927f5bbf38 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
@@ -32,11 +32,11 @@ import org.apache.spark.sql.internal.types.{AbstractMapType, StringTypeWithCaseA
 import org.apache.spark.sql.types.{DataType, MapType, StringType, StructType, VariantType}
 import org.apache.spark.unsafe.types.UTF8String
 
-object ExprUtils extends QueryErrorsBase {
+object ExprUtils extends EvalHelper with QueryErrorsBase {
 
   def evalTypeExpr(exp: Expression): DataType = {
     if (exp.foldable) {
-      exp.eval() match {
+      prepareForEval(exp).eval() match {
         case s: UTF8String if s != null =>
           val dataType = DataType.parseTypeWithFallback(
             s.toString,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
index e01531cc821c9..2b882cfd1419f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
@@ -878,7 +878,9 @@ case class StructsToJson(
 case class SchemaOfJson(
     child: Expression,
     options: Map[String, String])
-  extends UnaryExpression with CodegenFallback with QueryErrorsBase {
+  extends UnaryExpression
+  with RuntimeReplaceable
+  with QueryErrorsBase {
 
   def this(child: Expression) = this(child, Map.empty[String, String])
 
@@ -919,26 +921,20 @@ case class SchemaOfJson(
     }
   }
 
-  override def eval(v: InternalRow): Any = {
-    val dt = Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, json)) { parser =>
-      parser.nextToken()
-      // To match with schema inference from JSON datasource.
-      jsonInferSchema.inferField(parser) match {
-        case st: StructType =>
-          jsonInferSchema.canonicalizeType(st, jsonOptions).getOrElse(StructType(Nil))
-        case at: ArrayType if at.elementType.isInstanceOf[StructType] =>
-          jsonInferSchema
-            .canonicalizeType(at.elementType, jsonOptions)
-            .map(ArrayType(_, containsNull = at.containsNull))
-            .getOrElse(ArrayType(StructType(Nil), containsNull = at.containsNull))
-        case other: DataType =>
-          jsonInferSchema.canonicalizeType(other, jsonOptions).getOrElse(
-            SQLConf.get.defaultStringType)
-      }
-    }
+  @transient private lazy val jsonFactoryObjectType = ObjectType(classOf[JsonFactory])
+  @transient private lazy val jsonOptionsObjectType = ObjectType(classOf[JSONOptions])
+  @transient private lazy val jsonInferSchemaObjectType = ObjectType(classOf[JsonInferSchema])
 
-    UTF8String.fromString(dt.sql)
-  }
+  override def replacement: Expression = StaticInvoke(
+    classOf[JsonExpressionUtils],
+    dataType,
+    "schemaOfJson",
+    Seq(Literal(jsonFactory, jsonFactoryObjectType),
+      Literal(jsonOptions, jsonOptionsObjectType),
+      Literal(jsonInferSchema, jsonInferSchemaObjectType),
+      child),
+    Seq(jsonFactoryObjectType, jsonOptionsObjectType, jsonInferSchemaObjectType, child.dataType)
+  )
 
   override def prettyName: String = "schema_of_json"
 
diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_json.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_json.explain
index 8ec799bc58084..96f2da0cee0dd 100644
--- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_json.explain
+++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_json.explain
@@ -1,2 +1,2 @@
-Project [schema_of_json([{"col":01}]) AS schema_of_json([{"col":01}])#0]
+Project [static_invoke(JsonExpressionUtils.schemaOfJson(com.fasterxml.jackson.core.JsonFactory, org.apache.spark.sql.catalyst.json.JSONOptions, org.apache.spark.sql.catalyst.json.JsonInferSchema, [{"col":01}])) AS schema_of_json([{"col":01}])#0]
 +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_json_with_options.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_json_with_options.explain
index 13867949177a4..96f2da0cee0dd 100644
--- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_json_with_options.explain
+++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_json_with_options.explain
@@ -1,2 +1,2 @@
-Project [schema_of_json([{"col":01}], (allowNumericLeadingZeros,true)) AS schema_of_json([{"col":01}])#0]
+Project [static_invoke(JsonExpressionUtils.schemaOfJson(com.fasterxml.jackson.core.JsonFactory, org.apache.spark.sql.catalyst.json.JSONOptions, org.apache.spark.sql.catalyst.json.JsonInferSchema, [{"col":01}])) AS schema_of_json([{"col":01}])#0]
 +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
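Reviewer note: a minimal usage sketch of the behaviour this rewrite must preserve. It assumes a running `SparkSession` named `spark`; the commented results are what the three inference branches in `JsonExpressionUtils.schemaOfJson` should produce (struct, array of struct, other data type), not captured output.

```scala
import org.apache.spark.sql.functions.{lit, schema_of_json}

// schema_of_json is still foldable; it now evaluates through the StaticInvoke
// replacement on JsonExpressionUtils.schemaOfJson instead of the old
// CodegenFallback eval() on SchemaOfJson.
val row = spark.range(1).select(
  schema_of_json(lit("""{"a": 1, "b": "x"}""")), // struct branch
  schema_of_json(lit("""[{"col": 1}]""")),       // array-of-struct branch
  schema_of_json(lit("""[1, 2, 3]"""))           // "other data type" branch
).head()

// Assumed results, matching the branches above:
//   STRUCT<a: BIGINT, b: STRING>
//   ARRAY<STRUCT<col: BIGINT>>
//   ARRAY<BIGINT>
```

The golden-file changes above only reflect the new `static_invoke(...)` rendering in explain output; the resulting schema strings are expected to stay identical.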