From 05f0626e8aeeec07776d92644eb2e8078320e436 Mon Sep 17 00:00:00 2001
From: Asif Shahid
Date: Wed, 3 Feb 2016 12:43:31 -0800
Subject: [PATCH] [SPARK-13116][SQL] TungstenAggregate, though it is supposedly
 capable of processing both unsafe & safe rows, fails if the input is safe
 rows. Allow the targeted mutable row to be set with field values if the
 target row is Unsafe.

---
 .../sql/catalyst/expressions/Projection.scala | 49 +++++++++++++++++--
 1 file changed, 46 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 053e612f3ecb5..4b63789ca22e5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types._
 
 /**
  * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions.
@@ -67,13 +67,65 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu
     case n: Nondeterministic => n.setInitialValues()
     case _ =>
   })
 
+  // Whether the row handed to `target` is an UnsafeRow and must be written via `setters`.
+  private[this] var targetUnsafe = false
+  type UnsafeSetter = (UnsafeRow, Any) => Unit
+  // Per-field writers for an UnsafeRow target; created the first time one is targeted.
+  private[this] var setters: Array[UnsafeSetter] = _
   private[this] val exprArray = expressions.toArray
   private[this] var mutableRow: MutableRow = new GenericMutableRow(exprArray.length)
   def currentValue: InternalRow = mutableRow
+
+  /**
+   * Builds the writer that stores a value into field `ordinal` of an [[UnsafeRow]] target.
+   *
+   * @param ordinal  index of the output field the returned setter writes
+   * @param dataType data type of the expression producing that field
+   * @return a setter that writes the value, or marks the field null when the value is null
+   * @throws UnsupportedOperationException if `dataType` is not one of the types handled here
+   */
+  private def unsafeSetter(ordinal: Int, dataType: DataType): UnsafeSetter = {
+    // Casting a null `Any` to a primitive yields 0/false, so a null result must be turned into
+    // setNullAt explicitly instead of being written as a zero value.
+    def nullSafe(write: UnsafeSetter): UnsafeSetter = (target: UnsafeRow, value: Any) =>
+      if (value == null) target.setNullAt(ordinal) else write(target, value)
+
+    dataType match {
+      case BooleanType => nullSafe((t, v) => t.setBoolean(ordinal, v.asInstanceOf[Boolean]))
+      case ByteType => nullSafe((t, v) => t.setByte(ordinal, v.asInstanceOf[Byte]))
+      case ShortType => nullSafe((t, v) => t.setShort(ordinal, v.asInstanceOf[Short]))
+      case IntegerType | DateType => nullSafe((t, v) => t.setInt(ordinal, v.asInstanceOf[Int]))
+      case LongType | TimestampType =>
+        nullSafe((t, v) => t.setLong(ordinal, v.asInstanceOf[Long]))
+      case FloatType => nullSafe((t, v) => t.setFloat(ordinal, v.asInstanceOf[Float]))
+      case DoubleType => nullSafe((t, v) => t.setDouble(ordinal, v.asInstanceOf[Double]))
+      case NullType => (target: UnsafeRow, _: Any) => target.setNullAt(ordinal)
+      case dt: DecimalType =>
+        // Pass the declared precision of the column rather than the precision of the value,
+        // and hand nulls to setDecimal as well (see UnsafeRow.setDecimal).
+        (target: UnsafeRow, value: Any) =>
+          target.setDecimal(ordinal, value.asInstanceOf[Decimal], dt.precision)
+      case other =>
+        throw new UnsupportedOperationException(
+          s"InterpretedMutableProjection cannot write $other into an UnsafeRow target")
+    }
+  }
 
   override def target(row: MutableRow): MutableProjection = {
     mutableRow = row
+    // An UnsafeRow target is written through the typed `setters`; build them on first use.
+    targetUnsafe = row match {
+      case _: UnsafeRow =>
+        if (setters == null) {
+          setters = Array.tabulate(exprArray.length) { i =>
+            unsafeSetter(i, exprArray(i).dataType)
+          }
+        }
+        true
+      case _ => false
+    }
+
     this
   }
 
@@ -86,7 +138,12 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu
     }
     i = 0
     while (i < exprArray.length) {
-      mutableRow(i) = buffer(i)
+      if (targetUnsafe) {
+        // `setters` was populated by `target` when it saw the UnsafeRow.
+        setters(i)(mutableRow.asInstanceOf[UnsafeRow], buffer(i))
+      } else {
+        mutableRow(i) = buffer(i)
+      }
       i += 1
     }
     mutableRow