From c4008480b781379ac0451b9220300d83c054c60d Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro
Date: Wed, 29 Mar 2017 12:37:49 -0700
Subject: [PATCH] [SPARK-20009][SQL] Support DDL strings for defining schema in functions.from_json

## What changes were proposed in this pull request?
This PR adds `StructType.fromDDL`, which converts a DDL-formatted string into a `StructType`, for defining schemas in `functions.from_json`.
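A quick usage sketch of the two entry points this patch touches (as typed into `spark-shell`, where `spark.implicits._` is already in scope; the column name `value` and the sample JSON are illustrative):

```scala
import scala.collection.JavaConverters._
import org.apache.spark.sql.functions.from_json
import org.apache.spark.sql.types.StructType

// New helper: parse a DDL-formatted string into a StructType.
val schema = StructType.fromDDL("a INT, b STRING")

// The String-schema overload of from_json now accepts the same DDL form;
// a JSON-formatted schema string keeps working, since the DDL parser is
// only tried after DataType.fromJson fails.
val df = Seq("""{"a": 1, "b": "str"}""").toDS()
df.select(from_json($"value", "a INT, b STRING", Map.empty[String, String].asJava)).show()
```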
## How was this patch tested?
Added tests in `JsonFunctionsSuite`.

Author: Takeshi Yamamuro

Closes #17406 from maropu/SPARK-20009.
---
 .../apache/spark/sql/types/StructType.scala   |  6 ++
 .../spark/sql/types/DataTypeSuite.scala       | 85 ++++++++++++++-----
 .../org/apache/spark/sql/functions.scala      | 15 +++-
 .../apache/spark/sql/JsonFunctionsSuite.scala |  7 ++
 .../sql/sources/SimpleTextRelation.scala      |  2 +-
 5 files changed, 90 insertions(+), 25 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index 8d8b5b86d5aa1..54006e20a3eb6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -417,6 +417,12 @@ object StructType extends AbstractDataType {
     }
   }
 
+  /**
+   * Creates StructType for a given DDL-formatted string, which is a comma separated list of field
+   * definitions, e.g., a INT, b STRING.
+   */
+  def fromDDL(ddl: String): StructType = CatalystSqlParser.parseTableSchema(ddl)
+
   def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray)
 
   def apply(fields: java.util.List[StructField]): StructType = {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
index 61e1ec7c7ab35..05cb999af6a50 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
@@ -169,30 +169,72 @@ class DataTypeSuite extends SparkFunSuite {
     assert(!arrayType.existsRecursively(_.isInstanceOf[IntegerType]))
   }
 
-  def checkDataTypeJsonRepr(dataType: DataType): Unit = {
-    test(s"JSON - $dataType") {
+  def checkDataTypeFromJson(dataType: DataType): Unit = {
+    test(s"from Json - $dataType") {
       assert(DataType.fromJson(dataType.json) === dataType)
     }
   }
 
-  checkDataTypeJsonRepr(NullType)
-  checkDataTypeJsonRepr(BooleanType)
-  checkDataTypeJsonRepr(ByteType)
-  checkDataTypeJsonRepr(ShortType)
-  checkDataTypeJsonRepr(IntegerType)
-  checkDataTypeJsonRepr(LongType)
-  checkDataTypeJsonRepr(FloatType)
-  checkDataTypeJsonRepr(DoubleType)
-  checkDataTypeJsonRepr(DecimalType(10, 5))
-  checkDataTypeJsonRepr(DecimalType.SYSTEM_DEFAULT)
-  checkDataTypeJsonRepr(DateType)
-  checkDataTypeJsonRepr(TimestampType)
-  checkDataTypeJsonRepr(StringType)
-  checkDataTypeJsonRepr(BinaryType)
-  checkDataTypeJsonRepr(ArrayType(DoubleType, true))
-  checkDataTypeJsonRepr(ArrayType(StringType, false))
-  checkDataTypeJsonRepr(MapType(IntegerType, StringType, true))
-  checkDataTypeJsonRepr(MapType(IntegerType, ArrayType(DoubleType), false))
+  def checkDataTypeFromDDL(dataType: DataType): Unit = {
+    test(s"from DDL - $dataType") {
+      val parsed = StructType.fromDDL(s"a ${dataType.sql}")
+      val expected = new StructType().add("a", dataType)
+      assert(parsed.sameType(expected))
+    }
+  }
+
+  checkDataTypeFromJson(NullType)
+
+  checkDataTypeFromJson(BooleanType)
+  checkDataTypeFromDDL(BooleanType)
+
+  checkDataTypeFromJson(ByteType)
+  checkDataTypeFromDDL(ByteType)
+
+  checkDataTypeFromJson(ShortType)
+  checkDataTypeFromDDL(ShortType)
+
+  checkDataTypeFromJson(IntegerType)
+  checkDataTypeFromDDL(IntegerType)
+
+  checkDataTypeFromJson(LongType)
+  checkDataTypeFromDDL(LongType)
+
+  checkDataTypeFromJson(FloatType)
+  checkDataTypeFromDDL(FloatType)
+
+  checkDataTypeFromJson(DoubleType)
+  checkDataTypeFromDDL(DoubleType)
+
+  checkDataTypeFromJson(DecimalType(10, 5))
+  checkDataTypeFromDDL(DecimalType(10, 5))
+
+  checkDataTypeFromJson(DecimalType.SYSTEM_DEFAULT)
+  checkDataTypeFromDDL(DecimalType.SYSTEM_DEFAULT)
+
+  checkDataTypeFromJson(DateType)
+  checkDataTypeFromDDL(DateType)
+
+  checkDataTypeFromJson(TimestampType)
+  checkDataTypeFromDDL(TimestampType)
+
+  checkDataTypeFromJson(StringType)
+  checkDataTypeFromDDL(StringType)
+
+  checkDataTypeFromJson(BinaryType)
+  checkDataTypeFromDDL(BinaryType)
+
+  checkDataTypeFromJson(ArrayType(DoubleType, true))
+  checkDataTypeFromDDL(ArrayType(DoubleType, true))
+
+  checkDataTypeFromJson(ArrayType(StringType, false))
+  checkDataTypeFromDDL(ArrayType(StringType, false))
+
+  checkDataTypeFromJson(MapType(IntegerType, StringType, true))
+  checkDataTypeFromDDL(MapType(IntegerType, StringType, true))
+
+  checkDataTypeFromJson(MapType(IntegerType, ArrayType(DoubleType), false))
+  checkDataTypeFromDDL(MapType(IntegerType, ArrayType(DoubleType), false))
 
   val metadata = new MetadataBuilder()
     .putString("name", "age")
@@ -201,7 +243,8 @@ class DataTypeSuite extends SparkFunSuite {
     StructField("a", IntegerType, nullable = true),
     StructField("b", ArrayType(DoubleType), nullable = false),
     StructField("c", DoubleType, nullable = false, metadata)))
-  checkDataTypeJsonRepr(structType)
+  checkDataTypeFromJson(structType)
+  checkDataTypeFromDDL(structType)
 
   def checkDefaultSize(dataType: DataType, expectedDefaultSize: Int): Unit = {
     test(s"Check the default size of $dataType") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index acdb8e2d3edc8..0f9203065ef05 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -21,6 +21,7 @@ import scala.collection.JavaConverters._
 import scala.language.implicitConversions
 import scala.reflect.runtime.universe.{typeTag, TypeTag}
 import scala.util.Try
+import scala.util.control.NonFatal
 
 import org.apache.spark.annotation.{Experimental, InterfaceStability}
 import org.apache.spark.sql.catalyst.ScalaReflection
@@ -3055,13 +3056,21 @@ object functions {
    * with the specified schema. Returns `null`, in the case of an unparseable string.
    *
    * @param e a string column containing JSON data.
-   * @param schema the schema to use when parsing the json string as a json string
+   * @param schema the schema to use when parsing the json string as a json string. In Spark 2.1,
+   *               the user-provided schema has to be in JSON format. Since Spark 2.2, the DDL
+   *               format is also supported for the schema.
    *
    * @group collection_funcs
    * @since 2.1.0
    */
-  def from_json(e: Column, schema: String, options: java.util.Map[String, String]): Column =
-    from_json(e, DataType.fromJson(schema), options)
+  def from_json(e: Column, schema: String, options: java.util.Map[String, String]): Column = {
+    val dataType = try {
+      DataType.fromJson(schema)
+    } catch {
+      case NonFatal(_) => StructType.fromDDL(schema)
+    }
+    from_json(e, dataType, options)
+  }
 
   /**
    * (Scala-specific) Converts a column containing a `StructType` or `ArrayType` of `StructType`s
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
index 170c238c53438..8465e8d036a6d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
@@ -156,6 +156,13 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
       Row(Seq(Row(1, "a"), Row(2, null), Row(null, null))))
   }
 
+  test("from_json uses DDL strings for defining a schema") {
+    val df = Seq("""{"a": 1, "b": "haa"}""").toDS()
+    checkAnswer(
+      df.select(from_json($"value", "a INT, b STRING", new java.util.HashMap[String, String]())),
+      Row(Row(1, "haa")) :: Nil)
+  }
+
   test("to_json - struct") {
     val df = Seq(Tuple1(Tuple1(1))).toDF("a")
 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
index 1607c97cd6acb..9f4009bfe402a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
@@ -21,7 +21,7 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileStatus, Path}
 import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
 
-import org.apache.spark.sql.{sources, Row, SparkSession}
+import org.apache.spark.sql.{sources, SparkSession}
 import org.apache.spark.sql.catalyst.{expressions, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, GenericInternalRow, InterpretedPredicate, InterpretedProjection, JoinedRow, Literal}
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection