From 86f80e2b4759e574fe3eb91695f81b644db87242 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Wed, 22 Jul 2015 12:19:59 -0700 Subject: [PATCH] [SPARK-9165] [SQL] codegen for CreateArray, CreateStruct and CreateNamedStruct JIRA: https://issues.apache.org/jira/browse/SPARK-9165 Author: Yijie Shen Closes #7537 from yjshen/array_struct_codegen and squashes the following commits: 3a6dce6 [Yijie Shen] use infix notion in createArray test 5e90f0a [Yijie Shen] resolve comments: classOf 39cefb8 [Yijie Shen] codegen for createArray createStruct & createNamedStruct --- .../expressions/complexTypeCreator.scala | 65 +++++++++++++++++-- .../expressions/ComplexTypeSuite.scala | 16 +++++ 2 files changed, 76 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index f9fd04c02aaef..20b1eaab8e303 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -17,16 +17,18 @@ package org.apache.spark.sql.catalyst.expressions +import scala.collection.mutable + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ /** * Returns an Array containing the evaluation of all children expressions. */ -case class CreateArray(children: Seq[Expression]) extends Expression with CodegenFallback { +case class CreateArray(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) @@ -45,14 +47,31 @@ case class CreateArray(children: Seq[Expression]) extends Expression with Codege children.map(_.eval(input)) } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val arraySeqClass = classOf[mutable.ArraySeq[Any]].getName + s""" + boolean ${ev.isNull} = false; + $arraySeqClass ${ev.primitive} = new $arraySeqClass(${children.size}); + """ + + children.zipWithIndex.map { case (e, i) => + val eval = e.gen(ctx) + eval.code + s""" + if (${eval.isNull}) { + ${ev.primitive}.update($i, null); + } else { + ${ev.primitive}.update($i, ${eval.primitive}); + } + """ + }.mkString("\n") + } + override def prettyName: String = "array" } /** * Returns a Row containing the evaluation of all children expressions. - * TODO: [[CreateStruct]] does not support codegen. */ -case class CreateStruct(children: Seq[Expression]) extends Expression with CodegenFallback { +case class CreateStruct(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) @@ -76,6 +95,24 @@ case class CreateStruct(children: Seq[Expression]) extends Expression with Codeg InternalRow(children.map(_.eval(input)): _*) } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val rowClass = classOf[GenericMutableRow].getName + s""" + boolean ${ev.isNull} = false; + final $rowClass ${ev.primitive} = new $rowClass(${children.size}); + """ + + children.zipWithIndex.map { case (e, i) => + val eval = e.gen(ctx) + eval.code + s""" + if (${eval.isNull}) { + ${ev.primitive}.update($i, null); + } else { + ${ev.primitive}.update($i, ${eval.primitive}); + } + """ + }.mkString("\n") + } + override def prettyName: String = "struct" } @@ -84,7 +121,7 @@ case class CreateStruct(children: Seq[Expression]) extends Expression with Codeg * * @param children Seq(name1, val1, name2, val2, ...) */ -case class CreateNamedStruct(children: Seq[Expression]) extends Expression with CodegenFallback { +case class CreateNamedStruct(children: Seq[Expression]) extends Expression { private lazy val (nameExprs, valExprs) = children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip @@ -122,5 +159,23 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression with InternalRow(valExprs.map(_.eval(input)): _*) } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val rowClass = classOf[GenericMutableRow].getName + s""" + boolean ${ev.isNull} = false; + final $rowClass ${ev.primitive} = new $rowClass(${valExprs.size}); + """ + + valExprs.zipWithIndex.map { case (e, i) => + val eval = e.gen(ctx) + eval.code + s""" + if (${eval.isNull}) { + ${ev.primitive}.update($i, null); + } else { + ${ev.primitive}.update($i, ${eval.primitive}); + } + """ + }.mkString("\n") + } + override def prettyName: String = "named_struct" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index e3042143632aa..a8aee8f634e03 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -117,6 +117,22 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(getArrayStructFields(nullArrayStruct, "a"), null) } + test("CreateArray") { + val intSeq = Seq(5, 10, 15, 20, 25) + val longSeq = intSeq.map(_.toLong) + val strSeq = intSeq.map(_.toString) + checkEvaluation(CreateArray(intSeq.map(Literal(_))), intSeq, EmptyRow) + checkEvaluation(CreateArray(longSeq.map(Literal(_))), longSeq, EmptyRow) + checkEvaluation(CreateArray(strSeq.map(Literal(_))), strSeq, EmptyRow) + + val intWithNull = intSeq.map(Literal(_)) :+ Literal.create(null, IntegerType) + val longWithNull = longSeq.map(Literal(_)) :+ Literal.create(null, LongType) + val strWithNull = strSeq.map(Literal(_)) :+ Literal.create(null, StringType) + checkEvaluation(CreateArray(intWithNull), intSeq :+ null, EmptyRow) + checkEvaluation(CreateArray(longWithNull), longSeq :+ null, EmptyRow) + checkEvaluation(CreateArray(strWithNull), strSeq :+ null, EmptyRow) + } + test("CreateStruct") { val row = create_row(1, 2, 3) val c1 = 'a.int.at(0)