Skip to content

Commit

Permalink
[SPARK-9165] [SQL] codegen for CreateArray, CreateStruct and CreateNa…
Browse files Browse the repository at this point in the history
…medStruct

JIRA: https://issues.apache.org/jira/browse/SPARK-9165

Author: Yijie Shen <henry.yijieshen@gmail.com>

Closes #7537 from yjshen/array_struct_codegen and squashes the following commits:

3a6dce6 [Yijie Shen] use infix notation in createArray test
5e90f0a [Yijie Shen] resolve comments: classOf
39cefb8 [Yijie Shen] codegen for createArray createStruct & createNamedStruct
  • Loading branch information
yjshen authored and marmbrus committed Jul 22, 2015
1 parent 7652095 commit 86f80e2
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,18 @@

package org.apache.spark.sql.catalyst.expressions

import scala.collection.mutable

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._

/**
* Returns an Array containing the evaluation of all children expressions.
*/
case class CreateArray(children: Seq[Expression]) extends Expression with CodegenFallback {
case class CreateArray(children: Seq[Expression]) extends Expression {

/** True when no child reports itself as non-foldable (hence also true for zero children). */
override def foldable: Boolean = children.forall(child => child.foldable)

Expand All @@ -45,14 +47,31 @@ case class CreateArray(children: Seq[Expression]) extends Expression with Codege
children.map(_.eval(input))
}

/**
 * Produces the Java source text that builds this expression's array result.
 *
 * The emitted text first declares `ev.isNull` as `false` and allocates a
 * `scala.collection.mutable.ArraySeq` sized to the number of children. Then, for each child
 * in order, it appends the child's own generated code followed by an `if`/`else` that stores
 * either `null` or the child's value into the slot with the child's index.
 *
 * @param ctx context passed through to every child's `gen` call
 * @param ev  supplies the variable names used for this expression's null flag and value
 * @return the generated Java source
 */
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
  val seqType = classOf[mutable.ArraySeq[Any]].getName

  // Result declarations. The array itself is never null, only its elements may be.
  val declarations =
    s"\nboolean ${ev.isNull} = false;\n" +
    s"$seqType<Object> ${ev.primitive} = new $seqType<Object>(${children.size});\n"

  // One fragment per child: the child's code, then the null-aware store into its slot.
  val elementStores = children.zipWithIndex.map { case (child, ordinal) =>
    val childEv = child.gen(ctx)
    val store =
      s"\nif (${childEv.isNull}) {\n" +
      s"${ev.primitive}.update($ordinal, null);\n" +
      "} else {\n" +
      s"${ev.primitive}.update($ordinal, ${childEv.primitive});\n" +
      "}\n"
    childEv.code + store
  }

  declarations + elementStores.mkString("\n")
}

/** Always returns the constant name "array" for this expression. */
override def prettyName: String = {
  "array"
}
}

/**
* Returns a Row containing the evaluation of all children expressions.
* Supports code generation through `genCode`.
*/
case class CreateStruct(children: Seq[Expression]) extends Expression with CodegenFallback {
case class CreateStruct(children: Seq[Expression]) extends Expression {

/** True when no child reports itself as non-foldable (hence also true for zero children). */
override def foldable: Boolean = children.forall(child => child.foldable)

Expand All @@ -76,6 +95,24 @@ case class CreateStruct(children: Seq[Expression]) extends Expression with Codeg
InternalRow(children.map(_.eval(input)): _*)
}

/**
 * Produces the Java source text that builds this expression's row result.
 *
 * The emitted text first declares `ev.isNull` as `false` and allocates a `GenericMutableRow`
 * with one field per child. Then, for each child in order, it appends the child's own
 * generated code followed by an `if`/`else` that stores either `null` or the child's value
 * into the field with the child's index.
 *
 * @param ctx context passed through to every child's `gen` call
 * @param ev  supplies the variable names used for this expression's null flag and value
 * @return the generated Java source
 */
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
  val rowType = classOf[GenericMutableRow].getName

  // Result declarations. The row itself is never null, only its fields may be.
  val declarations =
    s"\nboolean ${ev.isNull} = false;\n" +
    s"final $rowType ${ev.primitive} = new $rowType(${children.size});\n"

  // One fragment per child: the child's code, then the null-aware store into its field.
  val fieldStores = children.zipWithIndex.map { case (child, ordinal) =>
    val childEv = child.gen(ctx)
    val store =
      s"\nif (${childEv.isNull}) {\n" +
      s"${ev.primitive}.update($ordinal, null);\n" +
      "} else {\n" +
      s"${ev.primitive}.update($ordinal, ${childEv.primitive});\n" +
      "}\n"
    childEv.code + store
  }

  declarations + fieldStores.mkString("\n")
}

/** Always returns the constant name "struct" for this expression. */
override def prettyName: String = {
  "struct"
}
}

Expand All @@ -84,7 +121,7 @@ case class CreateStruct(children: Seq[Expression]) extends Expression with Codeg
*
* @param children Seq(name1, val1, name2, val2, ...)
*/
case class CreateNamedStruct(children: Seq[Expression]) extends Expression with CodegenFallback {
case class CreateNamedStruct(children: Seq[Expression]) extends Expression {

// Splits the flat `children` layout (name1, val1, name2, val2, ...) into two parallel lists:
// the name expressions and the value expressions. Evaluated lazily on first access; a
// trailing group that is not exactly two elements does not match the pattern below.
private lazy val (nameExprs, valExprs) = {
  val nameValuePairs = children
    .grouped(2)
    .map { case Seq(name, value) => (name, value) }
    .toList
  nameValuePairs.unzip
}
Expand Down Expand Up @@ -122,5 +159,23 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression with
InternalRow(valExprs.map(_.eval(input)): _*)
}

/**
 * Produces the Java source text that builds this expression's row result.
 *
 * Only the value expressions (`valExprs`) contribute to the output. The emitted text first
 * declares `ev.isNull` as `false` and allocates a `GenericMutableRow` with one field per
 * value expression. Then, for each value expression in order, it appends that expression's
 * own generated code followed by an `if`/`else` that stores either `null` or the value into
 * the field with the matching index.
 *
 * @param ctx context passed through to every value expression's `gen` call
 * @param ev  supplies the variable names used for this expression's null flag and value
 * @return the generated Java source
 */
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
  val rowType = classOf[GenericMutableRow].getName

  // Result declarations. The row itself is never null, only its fields may be.
  val declarations =
    s"\nboolean ${ev.isNull} = false;\n" +
    s"final $rowType ${ev.primitive} = new $rowType(${valExprs.size});\n"

  // One fragment per value: the value's code, then the null-aware store into its field.
  val fieldStores = valExprs.zipWithIndex.map { case (valueExpr, ordinal) =>
    val valueEv = valueExpr.gen(ctx)
    val store =
      s"\nif (${valueEv.isNull}) {\n" +
      s"${ev.primitive}.update($ordinal, null);\n" +
      "} else {\n" +
      s"${ev.primitive}.update($ordinal, ${valueEv.primitive});\n" +
      "}\n"
    valueEv.code + store
  }

  declarations + fieldStores.mkString("\n")
}

/** Always returns the constant name "named_struct" for this expression. */
override def prettyName: String = {
  "named_struct"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,22 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(getArrayStructFields(nullArrayStruct, "a"), null)
}

test("CreateArray") {
  // The same five values, represented as Int, Long and String.
  val ints = Seq(5, 10, 15, 20, 25)
  val longs = ints.map(_.toLong)
  val strings = ints.map(_.toString)

  // Arrays built purely from non-null literals evaluate to the source sequences.
  checkEvaluation(CreateArray(ints.map(v => Literal(v))), ints, EmptyRow)
  checkEvaluation(CreateArray(longs.map(v => Literal(v))), longs, EmptyRow)
  checkEvaluation(CreateArray(strings.map(v => Literal(v))), strings, EmptyRow)

  // Typed null literals, appended as the last child of each array below.
  val nullInt = Literal.create(null, IntegerType)
  val nullLong = Literal.create(null, LongType)
  val nullString = Literal.create(null, StringType)

  // A null child is expected to show up as a null element at the same position.
  checkEvaluation(CreateArray(ints.map(v => Literal(v)) :+ nullInt), ints :+ null, EmptyRow)
  checkEvaluation(CreateArray(longs.map(v => Literal(v)) :+ nullLong), longs :+ null, EmptyRow)
  checkEvaluation(
    CreateArray(strings.map(v => Literal(v)) :+ nullString), strings :+ null, EmptyRow)
}

test("CreateStruct") {
val row = create_row(1, 2, 3)
val c1 = 'a.int.at(0)
Expand Down

0 comments on commit 86f80e2

Please sign in to comment.