diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index f9fd04c02aaef..20b1eaab8e303 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -17,16 +17,18 @@ package org.apache.spark.sql.catalyst.expressions +import scala.collection.mutable + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ /** * Returns an Array containing the evaluation of all children expressions. */ -case class CreateArray(children: Seq[Expression]) extends Expression with CodegenFallback { +case class CreateArray(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) @@ -45,14 +47,31 @@ case class CreateArray(children: Seq[Expression]) extends Expression with Codege children.map(_.eval(input)) } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val arraySeqClass = classOf[mutable.ArraySeq[Any]].getName + s""" + boolean ${ev.isNull} = false; + $arraySeqClass ${ev.primitive} = new $arraySeqClass(${children.size}); + """ + + children.zipWithIndex.map { case (e, i) => + val eval = e.gen(ctx) + eval.code + s""" + if (${eval.isNull}) { + ${ev.primitive}.update($i, null); + } else { + ${ev.primitive}.update($i, ${eval.primitive}); + } + """ + }.mkString("\n") + } + override def prettyName: String = "array" } /** * Returns a Row containing the evaluation of all children expressions. - * TODO: [[CreateStruct]] does not support codegen. */ -case class CreateStruct(children: Seq[Expression]) extends Expression with CodegenFallback { +case class CreateStruct(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) @@ -76,6 +95,24 @@ case class CreateStruct(children: Seq[Expression]) extends Expression with Codeg InternalRow(children.map(_.eval(input)): _*) } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val rowClass = classOf[GenericMutableRow].getName + s""" + boolean ${ev.isNull} = false; + final $rowClass ${ev.primitive} = new $rowClass(${children.size}); + """ + + children.zipWithIndex.map { case (e, i) => + val eval = e.gen(ctx) + eval.code + s""" + if (${eval.isNull}) { + ${ev.primitive}.update($i, null); + } else { + ${ev.primitive}.update($i, ${eval.primitive}); + } + """ + }.mkString("\n") + } + override def prettyName: String = "struct" } @@ -84,7 +121,7 @@ case class CreateStruct(children: Seq[Expression]) extends Expression with Codeg * * @param children Seq(name1, val1, name2, val2, ...) */ -case class CreateNamedStruct(children: Seq[Expression]) extends Expression with CodegenFallback { +case class CreateNamedStruct(children: Seq[Expression]) extends Expression { private lazy val (nameExprs, valExprs) = children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip @@ -122,5 +159,23 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression with InternalRow(valExprs.map(_.eval(input)): _*) } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val rowClass = classOf[GenericMutableRow].getName + s""" + boolean ${ev.isNull} = false; + final $rowClass ${ev.primitive} = new $rowClass(${valExprs.size}); + """ + + valExprs.zipWithIndex.map { case (e, i) => + val eval = e.gen(ctx) + eval.code + s""" + if (${eval.isNull}) { + ${ev.primitive}.update($i, null); + } else { + ${ev.primitive}.update($i, ${eval.primitive}); + } + """ + }.mkString("\n") + } + override def prettyName: String = "named_struct" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index e3042143632aa..a8aee8f634e03 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -117,6 +117,22 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(getArrayStructFields(nullArrayStruct, "a"), null) } + test("CreateArray") { + val intSeq = Seq(5, 10, 15, 20, 25) + val longSeq = intSeq.map(_.toLong) + val strSeq = intSeq.map(_.toString) + checkEvaluation(CreateArray(intSeq.map(Literal(_))), intSeq, EmptyRow) + checkEvaluation(CreateArray(longSeq.map(Literal(_))), longSeq, EmptyRow) + checkEvaluation(CreateArray(strSeq.map(Literal(_))), strSeq, EmptyRow) + + val intWithNull = intSeq.map(Literal(_)) :+ Literal.create(null, IntegerType) + val longWithNull = longSeq.map(Literal(_)) :+ Literal.create(null, LongType) + val strWithNull = strSeq.map(Literal(_)) :+ Literal.create(null, StringType) + checkEvaluation(CreateArray(intWithNull), intSeq :+ null, EmptyRow) + checkEvaluation(CreateArray(longWithNull), longSeq :+ null, EmptyRow) + checkEvaluation(CreateArray(strWithNull), strSeq :+ null, EmptyRow) + } + test("CreateStruct") { val row = create_row(1, 2, 3) val c1 = 'a.int.at(0)