From 9f5bb111ac0c80c0d45ed4a0a67af82c3add2be4 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 25 Oct 2016 23:27:46 +0800 Subject: [PATCH] complex type creator should always output unsafe format --- .../sql/catalyst/analysis/Analyzer.scala | 6 - .../sql/catalyst/expressions/Projection.scala | 10 +- .../codegen/GenerateUnsafeProjection.scala | 4 +- .../expressions/complexTypeCreator.scala | 280 +++++++----------- .../sql/catalyst/analysis/AnalysisSuite.scala | 2 - .../expressions/CodeGenerationSuite.scala | 30 +- .../expressions/ComplexTypeSuite.scala | 2 - .../expressions/ExpressionEvalHelper.scala | 44 ++- .../org/apache/spark/sql/DataFrameSuite.scala | 18 ++ 9 files changed, 173 insertions(+), 223 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index f8f4799322b3b..46e076d827e42 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2080,9 +2080,6 @@ object CleanupAliases extends Rule[LogicalPlan] { case c: CreateStruct if !stop => stop = true c.copy(children = c.children.map(trimNonTopLevelAliases)) - case c: CreateStructUnsafe if !stop => - stop = true - c.copy(children = c.children.map(trimNonTopLevelAliases)) case Alias(child, _) if !stop => child } } @@ -2121,9 +2118,6 @@ object CleanupAliases extends Rule[LogicalPlan] { case c: CreateStruct if !stop => stop = true c.copy(children = c.children.map(trimNonTopLevelAliases)) - case c: CreateStructUnsafe if !stop => - stop = true - c.copy(children = c.children.map(trimNonTopLevelAliases)) case Alias(child, _) if !stop => child } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 
a81fa1ce3adcc..76c372716ba1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -118,11 +118,7 @@ object UnsafeProjection { * Returns an UnsafeProjection for given sequence of Expressions (bounded). */ def create(exprs: Seq[Expression]): UnsafeProjection = { - val unsafeExprs = exprs.map(_ transform { - case CreateStruct(children) => CreateStructUnsafe(children) - case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) - }) - GenerateUnsafeProjection.generate(unsafeExprs) + GenerateUnsafeProjection.generate(exprs) } def create(expr: Expression): UnsafeProjection = create(Seq(expr)) @@ -144,10 +140,6 @@ object UnsafeProjection { inputSchema: Seq[Attribute], subexpressionEliminationEnabled: Boolean): UnsafeProjection = { val e = exprs.map(BindReferences.bindReference(_, inputSchema)) - .map(_ transform { - case CreateStruct(children) => CreateStructUnsafe(children) - case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) - }) GenerateUnsafeProjection.generate(e, subexpressionEliminationEnabled) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 7cc45372daa5a..67f081946dfec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -167,7 +167,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } // TODO: if the nullability of array element is correct, we can use it to save null check. 
- private def writeArrayToBuffer( + def writeArrayToBuffer( ctx: CodegenContext, input: String, elementType: DataType, @@ -245,7 +245,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } // TODO: if the nullability of value element is correct, we can use it to save null check. - private def writeMapToBuffer( + def writeMapToBuffer( ctx: CodegenContext, input: String, keyType: DataType, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 917aa0873130b..e640139f26114 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -36,7 +36,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { override def checkInputDataTypes(): TypeCheckResult = TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function array") - override def dataType: DataType = { + override def dataType: ArrayType = { ArrayType( children.headOption.map(_.dataType).getOrElse(NullType), containsNull = children.exists(_.nullable)) @@ -44,34 +44,53 @@ case class CreateArray(children: Seq[Expression]) extends Expression { override def nullable: Boolean = false + @transient private lazy val unsafeProj = + UnsafeProjection.create(BoundReference(0, dataType, false)) + override def eval(input: InternalRow): Any = { - new GenericArrayData(children.map(_.eval(input)).toArray) + val safeArray = new GenericArrayData(children.map(_.eval(input)).toArray) + unsafeProj(InternalRow(safeArray)).getArray(0) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val arrayClass = classOf[GenericArrayData].getName val values = ctx.freshName("values") + val safeArray = ctx.freshName("safeArray") ctx.addMutableState("Object[]", 
values, s"this.$values = null;") - ev.copy(code = s""" - final boolean ${ev.isNull} = false; - this.$values = new Object[${children.size}];""" + - ctx.splitExpressions( - ctx.INPUT_ROW, - children.zipWithIndex.map { case (e, i) => - val eval = e.genCode(ctx) - eval.code + s""" - if (${eval.isNull}) { - $values[$i] = null; - } else { - $values[$i] = ${eval.value}; - } - """ - }) + + val holder = ctx.freshName("holder") + val holderClass = classOf[BufferHolder].getName + ctx.addMutableState(holderClass, holder, + s"this.$holder = new $holderClass(new UnsafeRow(0));") + + val setValues = ctx.splitExpressions(ctx.INPUT_ROW, children.zipWithIndex.map { case (e, i) => + val eval = e.genCode(ctx) + s""" + ${eval.code} + if (${eval.isNull}) { + $values[$i] = null; + } else { + $values[$i] = ${eval.value}; + } + """ + }) + + // TODO(cloud-fan): should optimize for primitive array. + val writeUnsafeArray = GenerateUnsafeProjection.writeArrayToBuffer( + ctx, safeArray, dataType.elementType, holder) + val code = s""" - final ArrayData ${ev.value} = new $arrayClass($values); - this.$values = null; - """) + $holder.reset(); + $values = new Object[${children.size}]; + $setValues + final ArrayData $safeArray = new $arrayClass($values); + $writeUnsafeArray + final UnsafeArrayData ${ev.value} = new UnsafeArrayData(); + ${ev.value}.pointTo($holder.buffer, Platform.BYTE_ARRAY_OFFSET, $holder.totalSize()); + $values = null; + """ + + ev.copy(code = code, isNull = "false") } override def prettyName: String = "array" @@ -103,7 +122,7 @@ case class CreateMap(children: Seq[Expression]) extends Expression { } } - override def dataType: DataType = { + override def dataType: MapType = { MapType( keyType = keys.headOption.map(_.dataType).getOrElse(NullType), valueType = values.headOption.map(_.dataType).getOrElse(NullType), @@ -112,13 +131,18 @@ case class CreateMap(children: Seq[Expression]) extends Expression { override def nullable: Boolean = false + @transient private lazy val unsafeProj = 
+ UnsafeProjection.create(BoundReference(0, dataType, false)) + override def eval(input: InternalRow): Any = { val keyArray = keys.map(_.eval(input)).toArray if (keyArray.contains(null)) { throw new RuntimeException("Cannot use null as map key!") } val valueArray = values.map(_.eval(input)).toArray - new ArrayBasedMapData(new GenericArrayData(keyArray), new GenericArrayData(valueArray)) + val safeMap = + new ArrayBasedMapData(new GenericArrayData(keyArray), new GenericArrayData(valueArray)) + unsafeProj(InternalRow(safeMap)).getMap(0) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -126,46 +150,59 @@ case class CreateMap(children: Seq[Expression]) extends Expression { val mapClass = classOf[ArrayBasedMapData].getName val keyArray = ctx.freshName("keyArray") val valueArray = ctx.freshName("valueArray") + val safeMap = ctx.freshName("safeMap") ctx.addMutableState("Object[]", keyArray, s"this.$keyArray = null;") ctx.addMutableState("Object[]", valueArray, s"this.$valueArray = null;") + val holder = ctx.freshName("holder") + val holderClass = classOf[BufferHolder].getName + ctx.addMutableState(holderClass, holder, + s"this.$holder = new $holderClass(new UnsafeRow(0));") + + val setKeys = ctx.splitExpressions(ctx.INPUT_ROW, keys.zipWithIndex.map { case (e, i) => + val eval = e.genCode(ctx) + s""" + ${eval.code} + if (${eval.isNull}) { + throw new RuntimeException("Cannot use null as map key!"); + } else { + $keyArray[$i] = ${eval.value}; + } + """ + }) + + val setValues = ctx.splitExpressions(ctx.INPUT_ROW, values.zipWithIndex.map { case (e, i) => + val eval = e.genCode(ctx) + s""" + ${eval.code} + if (${eval.isNull}) { + $valueArray[$i] = null; + } else { + $valueArray[$i] = ${eval.value}; + } + """ + }) + val keyData = s"new $arrayClass($keyArray)" val valueData = s"new $arrayClass($valueArray)" - ev.copy(code = s""" - final boolean ${ev.isNull} = false; - $keyArray = new Object[${keys.size}]; - $valueArray = new Object[${values.size}];""" 
+ - ctx.splitExpressions( - ctx.INPUT_ROW, - keys.zipWithIndex.map { case (key, i) => - val eval = key.genCode(ctx) - s""" - ${eval.code} - if (${eval.isNull}) { - throw new RuntimeException("Cannot use null as map key!"); - } else { - $keyArray[$i] = ${eval.value}; - } - """ - }) + - ctx.splitExpressions( - ctx.INPUT_ROW, - values.zipWithIndex.map { case (value, i) => - val eval = value.genCode(ctx) - s""" - ${eval.code} - if (${eval.isNull}) { - $valueArray[$i] = null; - } else { - $valueArray[$i] = ${eval.value}; - } - """ - }) + + val writeUnsafeMap = GenerateUnsafeProjection.writeMapToBuffer( + ctx, safeMap, dataType.keyType, dataType.valueType, holder) + val code = s""" - final MapData ${ev.value} = new $mapClass($keyData, $valueData); - this.$keyArray = null; - this.$valueArray = null; - """) + $holder.reset(); + $keyArray = new Object[${keys.size}]; + $valueArray = new Object[${values.size}]; + $setKeys + $setValues + final MapData $safeMap = new $mapClass($keyData, $valueData); + $writeUnsafeMap + final UnsafeMapData ${ev.value} = new UnsafeMapData(); + ${ev.value}.pointTo($holder.buffer, Platform.BYTE_ARRAY_OFFSET, $holder.totalSize()); + $keyArray = null; + $valueArray = null; + """ + + ev.copy(code = code, isNull = "false") } override def prettyName: String = "map" @@ -194,33 +231,14 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { override def nullable: Boolean = false + @transient private lazy val unsafeProj = UnsafeProjection.create(dataType) + override def eval(input: InternalRow): Any = { - InternalRow(children.map(_.eval(input)): _*) + unsafeProj(InternalRow(children.map(_.eval(input)): _*)) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val rowClass = classOf[GenericInternalRow].getName - val values = ctx.freshName("values") - ctx.addMutableState("Object[]", values, s"this.$values = null;") - - ev.copy(code = s""" - boolean ${ev.isNull} = false; - this.$values = new 
Object[${children.size}];""" + - ctx.splitExpressions( - ctx.INPUT_ROW, - children.zipWithIndex.map { case (e, i) => - val eval = e.genCode(ctx) - eval.code + s""" - if (${eval.isNull}) { - $values[$i] = null; - } else { - $values[$i] = ${eval.value}; - }""" - }) + - s""" - final InternalRow ${ev.value} = new $rowClass($values); - this.$values = null; - """) + GenerateUnsafeProjection.createCode(ctx, children) } override def prettyName: String = "struct" @@ -284,116 +302,20 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { } } - override def eval(input: InternalRow): Any = { - InternalRow(valExprs.map(_.eval(input)): _*) - } - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val rowClass = classOf[GenericInternalRow].getName - val values = ctx.freshName("values") - ctx.addMutableState("Object[]", values, s"this.$values = null;") - - ev.copy(code = s""" - boolean ${ev.isNull} = false; - $values = new Object[${valExprs.size}];""" + - ctx.splitExpressions( - ctx.INPUT_ROW, - valExprs.zipWithIndex.map { case (e, i) => - val eval = e.genCode(ctx) - eval.code + s""" - if (${eval.isNull}) { - $values[$i] = null; - } else { - $values[$i] = ${eval.value}; - }""" - }) + - s""" - final InternalRow ${ev.value} = new $rowClass($values); - this.$values = null; - """) - } - - override def prettyName: String = "named_struct" -} - -/** - * Returns a Row containing the evaluation of all children expressions. This is a variant that - * returns UnsafeRow directly. The unsafe projection operator replaces [[CreateStruct]] with - * this expression automatically at runtime. 
- */ -case class CreateStructUnsafe(children: Seq[Expression]) extends Expression { - - override def foldable: Boolean = children.forall(_.foldable) - - override lazy val resolved: Boolean = childrenResolved - - override lazy val dataType: StructType = { - val fields = children.zipWithIndex.map { case (child, idx) => - child match { - case ne: NamedExpression => - StructField(ne.name, ne.dataType, ne.nullable, ne.metadata) - case _ => - StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty) - } - } - StructType(fields) - } - - override def nullable: Boolean = false + @transient private lazy val unsafeProj = UnsafeProjection.create(dataType) override def eval(input: InternalRow): Any = { - InternalRow(children.map(_.eval(input)): _*) + unsafeProj(InternalRow(valExprs.map(_.eval(input)): _*)) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val eval = GenerateUnsafeProjection.createCode(ctx, children) - ExprCode(code = eval.code, isNull = eval.isNull, value = eval.value) + GenerateUnsafeProjection.createCode(ctx, valExprs) } - override def prettyName: String = "struct_unsafe" + override def prettyName: String = "named_struct" } -/** - * Creates a struct with the given field names and values. This is a variant that returns - * UnsafeRow directly. The unsafe projection operator replaces [[CreateStruct]] with - * this expression automatically at runtime. - * - * @param children Seq(name1, val1, name2, val2, ...) 
- */ -case class CreateNamedStructUnsafe(children: Seq[Expression]) extends Expression { - - private lazy val (nameExprs, valExprs) = - children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip - - private lazy val names = nameExprs.map(_.eval(EmptyRow).toString) - - override lazy val dataType: StructType = { - val fields = names.zip(valExprs).map { - case (name, valExpr: NamedExpression) => - StructField(name, valExpr.dataType, valExpr.nullable, valExpr.metadata) - case (name, valExpr) => - StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty) - } - StructType(fields) - } - - override def foldable: Boolean = valExprs.forall(_.foldable) - - override def nullable: Boolean = false - - override def eval(input: InternalRow): Any = { - InternalRow(valExprs.map(_.eval(input)): _*) - } - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val eval = GenerateUnsafeProjection.createCode(ctx, valExprs) - ExprCode(code = eval.code, isNull = eval.isNull, value = eval.value) - } - - override def prettyName: String = "named_struct_unsafe" -} - /** * Creates a map after splitting the input text into key/value pairs using delimiters */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 590774c043040..1769ffb789951 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -219,8 +219,6 @@ class AnalysisSuite extends AnalysisTest { // CreateStruct is a special case that we should not trim Alias for it. 
plan = testRelation.select(CreateStruct(Seq(a, (a + 1).as("a+1"))).as("col")) checkAnalysis(plan, plan) - plan = testRelation.select(CreateStructUnsafe(Seq(a, (a + 1).as("a+1"))).as("col")) - checkAnalysis(plan, plan) } test("SPARK-10534: resolve attribute references in order by clause") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 0cb201e4dae3e..0cd556d5c2c4f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -71,7 +71,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType)) val expected = Seq.fill(length)(true) - if (!checkResult(actual, expected)) { + if (actual != expected) { fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected") } } @@ -102,9 +102,10 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val expressions = Seq(CreateArray(List.fill(length)(EqualTo(Literal(1), Literal(1))))) val plan = GenerateMutableProjection.generate(expressions) val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType)) - val expected = Seq(new GenericArrayData(Seq.fill(length)(true))) + assert(actual.length == 1) + val expected = new GenericArrayData(Seq.fill(length)(true)) - if (!checkResult(actual, expected)) { + if (!checkResult(actual.head, expected, expressions.head.dataType)) { fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected") } } @@ -116,12 +117,11 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { case (expr, i) => Seq(Literal(i), expr) })) val plan = 
GenerateMutableProjection.generate(expressions) - val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType)).map { - case m: ArrayBasedMapData => ArrayBasedMapData.toScalaMap(m) - } - val expected = (0 until length).map((_, true)).toMap :: Nil + val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType)) + assert(actual.length == 1) + val expected = ArrayBasedMapData((0 until length).toArray, Array.fill(length)(true)) - if (!checkResult(actual, expected)) { + if (!checkResult(actual.head, expected, expressions.head.dataType)) { fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected") } } @@ -131,9 +131,10 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val expressions = Seq(CreateStruct(List.fill(length)(EqualTo(Literal(1), Literal(1))))) val plan = GenerateMutableProjection.generate(expressions) val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType)) - val expected = Seq(InternalRow(Seq.fill(length)(true): _*)) + assert(actual.length == 1) + val expected = InternalRow(Seq.fill(length)(true): _*) - if (!checkResult(actual, expected)) { + if (!checkResult(actual.head, expected, expressions.head.dataType)) { fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected") } } @@ -146,9 +147,10 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { })) val plan = GenerateMutableProjection.generate(expressions) val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType)) - val expected = Seq(InternalRow(Seq.fill(length)(true): _*)) + assert(actual.length == 1) + val expected = InternalRow(Seq.fill(length)(true): _*) - if (!checkResult(actual, expected)) { + if (!checkResult(actual.head, expected, expressions.head.dataType)) { fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected") } } @@ -161,7 +163,7 @@ 
class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType)) val expected = Seq(Row.fromSeq(Seq.fill(length)(1))) - if (!checkResult(actual, expected)) { + if (actual != expected) { fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected") } } @@ -178,7 +180,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val expected = Seq.fill(length)( DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-07-24 07:00:00"))) - if (!checkResult(actual, expected)) { + if (actual != expected) { fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 0c307b2b8576b..923e471aa42ce 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -243,8 +243,6 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { val b = AttributeReference("b", IntegerType)() checkMetadata(CreateStruct(Seq(a, b))) checkMetadata(CreateNamedStruct(Seq("a", a, "b", b))) - checkMetadata(CreateStructUnsafe(Seq(a, b))) - checkMetadata(CreateNamedStructUnsafe(Seq("a", a, "b", b))) } test("StringToMap") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index f0c149c02b9aa..44afd953046c1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -27,8 +27,8 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} -import org.apache.spark.sql.catalyst.util.MapData -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} +import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} import org.apache.spark.util.Utils /** @@ -56,18 +56,44 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { * Check the equality between result of expression and expected value, it will handle * Array[Byte], Spread[Double], and MapData. */ - protected def checkResult(result: Any, expected: Any): Boolean = { + protected def checkResult(result: Any, expected: Any, dataType: DataType): Boolean = { (result, expected) match { case (result: Array[Byte], expected: Array[Byte]) => java.util.Arrays.equals(result, expected) case (result: Double, expected: Spread[Double @unchecked]) => expected.asInstanceOf[Spread[Double]].isWithin(result) - case (result: MapData, expected: MapData) => - result.keyArray() == expected.keyArray() && result.valueArray() == expected.valueArray() case (result: Double, expected: Double) => if (expected.isNaN) result.isNaN else expected == result case (result: Float, expected: Float) => if (expected.isNaN) result.isNaN else expected == result + case (result: ArrayData, expected: ArrayData) => + result.numElements == expected.numElements && { + val et = dataType.asInstanceOf[ArrayType].elementType + var isSame = true + var i = 0 + while (isSame && i < result.numElements()) { + isSame = checkResult(result.get(i, et), expected.get(i, et), et) + i += 1 + } + isSame + } + case (result: MapData, expected: MapData) => 
+ val kt = dataType.asInstanceOf[MapType].keyType + val vt = dataType.asInstanceOf[MapType].valueType + checkResult(result.keyArray(), expected.keyArray(), ArrayType(kt)) && + checkResult(result.valueArray(), expected.valueArray(), ArrayType(vt)) + case (result: InternalRow, expected: InternalRow) => + result.numFields == expected.numFields && { + val types = dataType.asInstanceOf[StructType].map(_.dataType) + var isSame = true + var i = 0 + while (isSame && i < result.numFields) { + val dt = types(i) + isSame = checkResult(result.get(i, dt), expected.get(i, dt), dt) + i += 1 + } + isSame + } case _ => result == expected } @@ -105,7 +131,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { val actual = try evaluate(expression, inputRow) catch { case e: Exception => fail(s"Exception evaluating $expression", e) } - if (!checkResult(actual, expected)) { + if (!checkResult(actual, expected, expression.dataType)) { val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" fail(s"Incorrect evaluation (codegen off): $expression, " + s"actual: $actual, " + @@ -123,7 +149,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { expression) val actual = plan(inputRow).get(0, expression.dataType) - if (!checkResult(actual, expected)) { + if (!checkResult(actual, expected, expression.dataType)) { val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" fail(s"Incorrect evaluation: $expression, actual: $actual, expected: $expected$input") } @@ -183,14 +209,14 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), expression) var actual = plan(inputRow).get(0, expression.dataType) - assert(checkResult(actual, expected)) + assert(checkResult(actual, expected, expression.dataType)) plan = generateProject( GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), expression) actual = 
FromUnsafeProjection(expression.dataType :: Nil)( plan(inputRow)).get(0, expression.dataType) - assert(checkResult(actual, expected)) + assert(checkResult(actual, expected, expression.dataType)) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 3fb7eeefba67f..b7867d7e1392a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -28,6 +28,7 @@ import org.scalatest.Matchers._ import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Literal} import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project, Union} import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.aggregate.HashAggregateExec @@ -1649,4 +1650,21 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { dates.except(widenTypedRows).collect() dates.intersect(widenTypedRows).collect() } + + test("complex type creator should always output unsafe format") { + // map is not tested here as it's not order-able. + val df = Seq((Array(1), 1 -> 1, 1)).toDF("array", "struct", "int") + + def namedStruct(field1: Column, field2: Column): Column = { + Column(CreateNamedStruct(Seq(Literal("_1"), field1.expr, Literal("_2"), field2.expr))) + } + + // test interpreted version, we will call `eval` when optimizing foldable expressions. + assert(df.filter($"array" === array(lit(1))).count() == 1) + assert(df.filter($"struct" === namedStruct(lit(1), lit(1))).count() == 1) + + // test codegen version + assert(df.filter($"array" === array($"int")).count() == 1) + assert(df.filter($"struct" === namedStruct($"int", $"int")).count() == 1) + } }