From f6968a49c5b7faba3b160180c36f37f7a59542b8 Mon Sep 17 00:00:00 2001
From: "zhichao.li"
Date: Thu, 30 Jul 2015 10:34:31 +0800
Subject: [PATCH 1/3] give script a default serde

---
 .../org/apache/spark/sql/hive/HiveQl.scala    |  3 +-
 .../hive/execution/ScriptTransformation.scala | 96 ++++++++-----------
 .../sql/hive/execution/SQLQuerySuite.scala    | 12 +++
 3 files changed, 53 insertions(+), 58 deletions(-)

diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index e6df64d2642b..8a0c79fd2cb0 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.hive
 import java.sql.Date

 import org.apache.hadoop.hive.conf.HiveConf
+import org.apache.hadoop.hive.conf.HiveConf.ConfVars
 import org.apache.hadoop.hive.serde.serdeConstants
 import org.apache.hadoop.hive.ql.{ErrorMsg, Context}
 import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, FunctionInfo}
@@ -893,7 +894,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
         }
         (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), serdeProps)

-      case Nil => (Nil, None, Nil)
+      case Nil => (Nil, Option(hiveConf().getVar(ConfVars.HIVESCRIPTSERDE)), Nil)
     }

     val (inRowFormat, inSerdeClass, inSerdeProps) = matchSerDe(inputSerdeClause)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
index 741c705e2a25..edfd8102312f 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
@@ -27,11 +27,11 @@ import scala.util.control.NonFatal
 import org.apache.hadoop.hive.serde.serdeConstants
 import org.apache.hadoop.hive.serde2.AbstractSerDe
 import org.apache.hadoop.hive.serde2.objectinspector._
+import org.apache.hadoop.io.Writable

 import org.apache.spark.{TaskContext, Logging}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema
 import org.apache.spark.sql.execution._
@@ -82,6 +82,7 @@ case class ScriptTransformation(

     // This nullability is a performance optimization in order to avoid an Option.foreach() call
     // inside of a loop
+    @Nullable
     val (inputSerde, inputSoi) = ioschema.initInputSerDe(input).getOrElse((null, null))

     // This new thread will consume the ScriptTransformation's input rows and write them to the
@@ -106,9 +107,16 @@ case class ScriptTransformation(

     val reader = new BufferedReader(new InputStreamReader(inputStream))
     val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors {
-      var cacheRow: InternalRow = null
       var curLine: String = null
-      var eof: Boolean = false
+      var cacheRow: InternalRow = null
+      val scriptOutputStream = new DataInputStream(inputStream)
+      var scriptOutputWritable: Writable = null
+      val reusedWritableObject: Writable = if (null != outputSerde) {
+        outputSerde.getSerializedClass().newInstance
+      } else {
+        null
+      }
+      val mutableRow = new SpecificMutableRow(output.map(_.dataType))

       override def hasNext: Boolean = {
         if (outputSerde == null) {
@@ -125,45 +133,20 @@ case class ScriptTransformation(
           } else {
             true
           }
-        } else {
-          if (eof) {
-            if (writerThread.exception.isDefined) {
-              throw writerThread.exception.get
-            }
-            false
-          } else {
+        } else if (scriptOutputWritable == null) {
+          scriptOutputWritable = reusedWritableObject
+          try {
+            scriptOutputWritable.readFields(scriptOutputStream)
             true
+          } catch {
+            case _: EOFException =>
+              if (writerThread.exception.isDefined) {
+                throw writerThread.exception.get
+              }
+              false
           }
-        }
-      }
-
-      def deserialize(): InternalRow = {
-        if (cacheRow != null) return cacheRow
-
-        val mutableRow = new SpecificMutableRow(output.map(_.dataType))
-        try {
-          val dataInputStream = new DataInputStream(inputStream)
-          val writable = outputSerde.getSerializedClass().newInstance
-          writable.readFields(dataInputStream)
-
-          val raw = outputSerde.deserialize(writable)
-          val dataList = outputSoi.getStructFieldsDataAsList(raw)
-          val fieldList = outputSoi.getAllStructFieldRefs()
-
-          var i = 0
-          dataList.foreach( element => {
-            if (element == null) {
-              mutableRow.setNullAt(i)
-            } else {
-              mutableRow(i) = unwrap(element, fieldList(i).getFieldObjectInspector)
-            }
-            i += 1
-          })
-          mutableRow
-        } catch {
-          case e: EOFException =>
-            eof = true
-            null
+        } else {
+          true
         }
       }

@@ -171,7 +154,6 @@ case class ScriptTransformation(
         if (!hasNext) {
           throw new NoSuchElementException
         }
-
         if (outputSerde == null) {
           val prevLine = curLine
           curLine = reader.readLine()
@@ -185,12 +167,20 @@ case class ScriptTransformation(
              .asInstanceOf[Array[Any]])
          }
        } else {
-          val ret = deserialize()
-          if (!eof) {
-            cacheRow = null
-            cacheRow = deserialize()
+          val raw = outputSerde.deserialize(scriptOutputWritable)
+          scriptOutputWritable = null
+          val dataList = outputSoi.getStructFieldsDataAsList(raw)
+          val fieldList = outputSoi.getAllStructFieldRefs()
+          var i = 0
+          while (i < dataList.size()) {
+            if (dataList(i) == null) {
+              mutableRow.setNullAt(i)
+            } else {
+              mutableRow(i) = unwrap(dataList(i), fieldList(i).getFieldObjectInspector)
+            }
+            i += 1
           }
-          ret
+          mutableRow
         }
       }
     }
@@ -320,17 +310,9 @@ case class HiveScriptIOSchema (
   }

   private def parseAttrs(attrs: Seq[Expression]): (Seq[String], Seq[DataType]) = {
-    val columns = attrs.map {
-      case aref: AttributeReference => aref.name
-      case e: NamedExpression => e.name
-      case _ => null
-    }
+    val columns = attrs.zipWithIndex.map { e => s"${e._1.prettyName}_${e._2}" }

-    val columnTypes = attrs.map {
-      case aref: AttributeReference => aref.dataType
-      case e: NamedExpression => e.dataType
-      case _ => null
-    }
+    val columnTypes = attrs.map { _.dataType }

     (columns, columnTypes)
   }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index c4923d83e48f..c8aae3f4e0da 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -725,6 +725,18 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils {
       .queryExecution.toRdd.count())
   }

+  test("test script transform data type") {
+    val data = (1 to 5).map { i => (i, i) }
+    data.toDF("key", "value").registerTempTable("test")
+    checkAnswer(
+      sql(
+        """FROM
+          |(FROM test SELECT TRANSFORM(key, value) USING 'cat' AS (thing1 int, thing2 string)) t
+          |SELECT thing1 + 1
+        """.stripMargin),
+      (2 to 6).map(i => Row(i)))
+  }
+
   test("window function: udaf with aggregate expressin") {
     val data = Seq(
       WindowData(1, "a", 5),

From b9252a8f14f2c174f2fab3f2d3d66caf232e3dec Mon Sep 17 00:00:00 2001
From: "zhichao.li"
Date: Thu, 30 Jul 2015 12:13:32 +0800
Subject: [PATCH 2/3] delete cacheRow

---
 .../apache/spark/sql/hive/execution/ScriptTransformation.scala | 1 -
 1 file changed, 1 deletion(-)

diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
index edfd8102312f..be73a3b1e27e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
@@ -108,7 +108,6 @@ case class ScriptTransformation(
     val reader = new BufferedReader(new InputStreamReader(inputStream))
     val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors {
       var curLine: String = null
-      var cacheRow: InternalRow = null
       val scriptOutputStream = new DataInputStream(inputStream)
       var scriptOutputWritable: Writable = null
       val reusedWritableObject: Writable = if (null != outputSerde) {

From a36cc7cd24ab1495880957b8f718576c9dc6f990 Mon Sep 17 00:00:00 2001
From: "zhichao.li"
Date: Tue, 4 Aug 2015 09:43:53 +0800
Subject: [PATCH 3/3] style

---
 .../spark/sql/hive/execution/ScriptTransformation.scala | 7 ++-----
 .../apache/spark/sql/hive/execution/SQLQuerySuite.scala | 6 ++----
 2 files changed, 4 insertions(+), 9 deletions(-)

diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
index be73a3b1e27e..a2a492f89261 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
@@ -82,7 +82,6 @@ case class ScriptTransformation(

     // This nullability is a performance optimization in order to avoid an Option.foreach() call
     // inside of a loop
-    @Nullable
     val (inputSerde, inputSoi) = ioschema.initInputSerDe(input).getOrElse((null, null))

     // This new thread will consume the ScriptTransformation's input rows and write them to the
@@ -309,10 +308,8 @@ case class HiveScriptIOSchema (

   }

   private def parseAttrs(attrs: Seq[Expression]): (Seq[String], Seq[DataType]) = {
-    val columns = attrs.zipWithIndex.map { e => s"${e._1.prettyName}_${e._2}" }
-
-    val columnTypes = attrs.map { _.dataType }
-
+    val columns = attrs.zipWithIndex.map(e => s"${e._1.prettyName}_${e._2}")
+    val columnTypes = attrs.map(_.dataType)
     (columns, columnTypes)
   }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index c8aae3f4e0da..0c168259a4e4 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -729,12 +729,10 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils {
     val data = (1 to 5).map { i => (i, i) }
     data.toDF("key", "value").registerTempTable("test")
     checkAnswer(
-      sql(
-        """FROM
+      sql("""FROM
           |(FROM test SELECT TRANSFORM(key, value) USING 'cat' AS (thing1 int, thing2 string)) t
          |SELECT thing1 + 1
-        """.stripMargin),
-      (2 to 6).map(i => Row(i)))
+        """.stripMargin), (2 to 6).map(i => Row(i)))
   }

   test("window function: udaf with aggregate expressin") {
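
Note on the user-visible behavior (editor's sketch, not part of the patches): with the default script serde wired in via ConfVars.HIVESCRIPTSERDE (normally LazySimpleSerDe), a TRANSFORM without an explicit SERDE clause now honors the column types declared in its AS clause, which is exactly what the new "test script transform data type" test exercises. A minimal way to try this outside the test harness, assuming a Spark 1.5-era HiveContext bound to a value named sqlContext (that name and the table name are illustrative, not from the patches):

    // Sketch only: mirrors the new test added in patch 1.
    // Assumes sqlContext: HiveContext, with its implicits imported
    // so that Seq#toDF is available.
    import sqlContext.implicits._

    val data = (1 to 5).map(i => (i, i))
    data.toDF("key", "value").registerTempTable("test")

    // thing1 is declared int in the AS clause; with the default serde the
    // value comes back typed, so the arithmetic below evaluates as integers.
    val df = sqlContext.sql(
      """FROM
        |(FROM test SELECT TRANSFORM(key, value) USING 'cat' AS (thing1 int, thing2 string)) t
        |SELECT thing1 + 1
      """.stripMargin)
    df.collect()  // expected: Row(2), Row(3), Row(4), Row(5), Row(6)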