From 0d706fffb133c1d685e4aaa0b62758d0826bad62 Mon Sep 17 00:00:00 2001
From: jinxing
Date: Fri, 3 Nov 2017 17:39:48 +0800
Subject: [PATCH] Support processing array and map type using script

---
 .../spark/sql/execution/SparkSqlParser.scala  | 28 +++++---
 .../execution/ScriptTransformationExec.scala  | 41 +++++++++--
 .../sql/hive/execution/SQLQuerySuite.scala    | 70 +++++++++++++++++++
 3 files changed, 124 insertions(+), 15 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
index 6de9ea0efd2c6..990535e949de6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
@@ -1454,14 +1454,15 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
     def format(
         fmt: RowFormatContext,
         configKey: String,
-        defaultConfigValue: String): Format = fmt match {
+        defaultConfigValue: String,
+        isInFormat: Boolean): Format = fmt match {
       case c: RowFormatDelimitedContext =>
         // TODO we should use the visitRowFormatDelimited function here. However HiveScriptIOSchema
         // expects a seq of pairs in which the old parsers' token names are used as keys.
         // Transforming the result of visitRowFormatDelimited would be quite a bit messier than
         // retrieving the key value pairs ourselves.
         def entry(key: String, value: Token): Seq[(String, String)] = {
-          Option(value).map(t => key -> t.getText).toSeq
+          Option(value).toSeq.map(x => key -> string(x))
         }
         val entries = entry("TOK_TABLEROWFORMATFIELD", c.fieldsTerminatedBy) ++
           entry("TOK_TABLEROWFORMATCOLLITEMS", c.collectionItemsTerminatedBy) ++
@@ -1469,7 +1470,8 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
           entry("TOK_TABLEROWFORMATLINES", c.linesSeparatedBy) ++
           entry("TOK_TABLEROWFORMATNULL", c.nullDefinedAs)
 
-        (entries, None, Seq.empty, None)
+        val recordHandler = Option(conf.getConfString(configKey, defaultConfigValue))
+        (entries, None, Seq.empty, recordHandler)
 
       case c: RowFormatSerdeContext =>
         // Use a serde format.
@@ -1485,21 +1487,27 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
 
       case null =>
         // Use default (serde) format.
-        val name = conf.getConfString("hive.script.serde",
-          "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")
+        val name = if (isInFormat) {
+          conf.getConfString("hive.script.serde",
+            "org.apache.hadoop.hive.serde2.DelimitedJSONSerDe")
+        } else {
+          conf.getConfString("hive.script.serde",
+            "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")
+        }
         val props = Seq("field.delim" -> "\t")
         val recordHandler = Option(conf.getConfString(configKey, defaultConfigValue))
         (Nil, Option(name), props, recordHandler)
     }
 
-    val (inFormat, inSerdeClass, inSerdeProps, reader) =
+    val (inFormat, inSerdeClass, inSerdeProps, writer) =
       format(
-        inRowFormat, "hive.script.recordreader", "org.apache.hadoop.hive.ql.exec.TextRecordReader")
+        inRowFormat, "hive.script.recordwriter",
+        "org.apache.hadoop.hive.ql.exec.TextRecordWriter", isInFormat = true)
 
-    val (outFormat, outSerdeClass, outSerdeProps, writer) =
+    val (outFormat, outSerdeClass, outSerdeProps, reader) =
       format(
-        outRowFormat, "hive.script.recordwriter",
-        "org.apache.hadoop.hive.ql.exec.TextRecordWriter")
+        outRowFormat, "hive.script.recordreader",
+        "org.apache.hadoop.hive.ql.exec.TextRecordReader", isInFormat = false)
 
     ScriptInputOutputSchema(
       inFormat, outFormat,
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala
index d786a610f1535..34677e7281970 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.hive.execution
 
 import java.io._
 import java.nio.charset.StandardCharsets
+import java.util.Map.Entry
 import java.util.Properties
 import javax.annotation.Nullable
 
@@ -267,6 +268,33 @@ private class ScriptTransformationWriterThread(
   /** Contains the exception thrown while writing the parent iterator to the external process.
    */
   def exception: Option[Throwable] = Option(_exception)
+  private def buildJSONString(sb: StringBuilder, obj: Any): Unit = {
+    obj match {
+      case null =>
+        sb.append(ioschema.inputRowFormatMap("TOK_TABLEROWFORMATNULL"))
+      case list: java.util.List[_] =>
+        (0 until list.size()).foreach { i =>
+          if (i > 0) {
+            sb.append(ioschema.inputRowFormatMap("TOK_TABLEROWFORMATCOLLITEMS"))
+          }
+          buildJSONString(sb, list.get(i))
+        }
+      case map: java.util.Map[_, _] =>
+        val entries = map.entrySet().toArray()
+        (0 until entries.size).foreach { i =>
+          if (i > 0) {
+            sb.append(ioschema.inputRowFormatMap("TOK_TABLEROWFORMATCOLLITEMS"))
+          }
+          val entry = entries(i).asInstanceOf[Entry[_, _]]
+          buildJSONString(sb, entry.getKey)
+          sb.append(ioschema.inputRowFormatMap("TOK_TABLEROWFORMATMAPKEYS"))
+          buildJSONString(sb, entry.getValue)
+        }
+      case other =>
+        sb.append(other.toString)
+    }
+  }
+
 
   override def run(): Unit = Utils.logUncaughtExceptions {
     TaskContext.setTaskContext(taskContext)
@@ -279,16 +307,19 @@ private class ScriptTransformationWriterThread(
     val len = inputSchema.length
     try {
       iter.map(outputProjection).foreach { row =>
+        val values = row.asInstanceOf[GenericInternalRow].values.zip(ioschema.wrappers).map {
+          case (value, wrapper) => wrapper(value)
+        }
         if (inputSerde == null) {
           val data = if (len == 0) {
             ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")
           } else {
             val sb = new StringBuilder
-            sb.append(row.get(0, inputSchema(0)))
+            buildJSONString(sb, values(0))
             var i = 1
             while (i < len) {
               sb.append(ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"))
-              sb.append(row.get(i, inputSchema(i)))
+              buildJSONString(sb, values(i))
               i += 1
             }
             sb.append(ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES"))
@@ -296,9 +327,7 @@ private class ScriptTransformationWriterThread(
           }
           outputStream.write(data.getBytes(StandardCharsets.UTF_8))
         } else {
-          val writable = inputSerde.serialize(
-            row.asInstanceOf[GenericInternalRow].values, inputSoi)
-
+          val writable = inputSerde.serialize(values, inputSoi)
           if (scriptInputWriter != null) {
             scriptInputWriter.write(writable)
           } else {
@@ -370,8 +399,10 @@ case class HiveScriptIOSchema (
 
   val inputRowFormatMap = inputRowFormat.toMap.withDefault((k) => defaultFormat(k))
   val outputRowFormatMap = outputRowFormat.toMap.withDefault((k) => defaultFormat(k))
+  var wrappers: Seq[Any => Any] = _
 
   def initInputSerDe(input: Seq[Expression]): Option[(AbstractSerDe, ObjectInspector)] = {
+    wrappers = input.map(_.dataType).map(dt => (wrapperFor(toInspector(dt), dt)))
     inputSerdeClass.map { serdeClass =>
       val (columns, columnTypes) = parseAttrs(input)
       val serde = initSerDe(serdeClass, columns, columnTypes, inputSerdeProps)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index c11e37a516646..abd60f1bfd3fd 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -98,6 +98,76 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
     checkAnswer(query1, Row("x1_y1") :: Row("x2_y2") :: Nil)
   }
 
+  test("script: processing map type") {
+    assume(TestUtils.testCommandAvailable("/bin/bash"))
+    assume(TestUtils.testCommandAvailable("echo | sed"))
+    val scriptFilePath = getTestResourcePath("test_script.sh")
+    val df = Seq((Map("x0" -> "y0", "x1" -> "y1"), "z1"),
+      (Map("x2" -> "y2"), "z2")).toDF("c1", "c2")
+    df.createOrReplaceTempView("script_table")
+    val query = sql(
+      s"""
+        |SELECT TRANSFORM(c1, c2) USING 'bash $scriptFilePath' AS data FROM script_table
+      """.stripMargin)
+    checkAnswer(query, Row("""{"x0":"y0","x1":"y1"}_z1""") :: Row("""{"x2":"y2"}_z2""") :: Nil)
+  }
+
+  test("script: processing array type") {
+    assume(TestUtils.testCommandAvailable("/bin/bash"))
+    assume(TestUtils.testCommandAvailable("echo | sed"))
+    val scriptFilePath = getTestResourcePath("test_script.sh")
+    val df = Seq((Array(0, 1, 2), "x"),
+      (Array(3, 4, 5), "y")).toDF("c1", "c2")
+    df.createOrReplaceTempView("script_table")
+    val query = sql(
+      s"""
+        |SELECT TRANSFORM(c1, c2) USING 'bash $scriptFilePath' AS data FROM script_table
+      """.stripMargin)
+    checkAnswer(query, Row("""[0,1,2]_x""") :: Row("""[3,4,5]_y""") :: Nil)
+  }
+
+  test("script: processing with row format") {
+    val df = Seq((Map("x0" -> "y0", "x1" -> "y1"), "z1"),
+      (Map("x2" -> "y2"), "z2")).toDF("c1", "c2")
+    df.createOrReplaceTempView("script_table")
+    val query = sql(
+      s"""
+        |SELECT TRANSFORM(c1, c2)
+        |ROW FORMAT DELIMITED
+        |FIELDS TERMINATED BY '\t'
+        |COLLECTION ITEMS TERMINATED BY '|'
+        |MAP KEYS TERMINATED BY '#'
+        |USING 'cat'
+        |AS (col1 STRING, col2 STRING)
+        |FROM script_table
+      """.stripMargin)
+    checkAnswer(query, Row("x0#y0|x1#y1", "z1") :: Row("x2#y2", "z2") :: Nil)
+  }
+
+  test("script: processing struct type") {
+    val schema = StructType(
+      StructField("s", StructType(
+        StructField("a", ArrayType(IntegerType), true) ::
+        StructField("m", MapType(StringType, IntegerType)) :: Nil
+      )) :: Nil
+    )
+    val rows = Row(Row(Array(1, 2, 3), Map("x1" -> 1))) ::
+      Row(Row(Array(4, 5, 6), Map("x2" -> 2))) :: Nil
+
+    val rowRdd = sparkContext.parallelize(rows)
+
+    spark.createDataFrame(rowRdd, schema).createOrReplaceTempView("script_table")
+    val query = sql(
+      """
+        |SELECT TRANSFORM(s)
+        |USING 'cat'
+        |AS data
+        |FROM script_table
+      """.stripMargin)
+    checkAnswer(query, Row("""{"a":[1,2,3],"m":{"x1":1}}""") ::
+      Row("""{"a":[4,5,6],"m":{"x2":2}}""") :: Nil)
+  }
+
   test("SPARK-6835: udtf in lateral view") {
     val df = Seq((1, 1)).toDF("c1", "c2")
     df.createOrReplaceTempView("table1")
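
Note (not part of the patch): on the delimited, no-SerDe input path changed above, nested values are not written as literal JSON; buildJSONString joins array elements and map entries with the configured row-format delimiters. Below is a minimal standalone sketch of that rendering in Scala. The object name, the render helper, and the delimiter values are illustrative only; the '|' and '#' choices mirror the "script: processing with row format" test and are not built-in defaults.

object DelimitedRenderSketch {
  // Illustrative delimiters; in the patch they come from ioschema.inputRowFormatMap.
  val collItemsDelim = "|"   // stands in for TOK_TABLEROWFORMATCOLLITEMS
  val mapKeysDelim = "#"     // stands in for TOK_TABLEROWFORMATMAPKEYS
  val nullString = "\\N"     // stands in for TOK_TABLEROWFORMATNULL (assumed value)

  // Same shape as buildJSONString: arrays join elements with the collection-items
  // delimiter; maps join key and value with the map-keys delimiter and entries
  // with the collection-items delimiter.
  def render(obj: Any): String = obj match {
    case null => nullString
    case seq: Seq[_] => seq.map(render).mkString(collItemsDelim)
    case m: Map[_, _] =>
      m.map { case (k, v) => render(k) + mapKeysDelim + render(v) }.mkString(collItemsDelim)
    case other => other.toString
  }

  def main(args: Array[String]): Unit = {
    // Prints "x0#y0|x1#y1", matching the expected row in the row-format test above.
    println(render(Map("x0" -> "y0", "x1" -> "y1")))
  }
}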