Skip to content

Commit

Permalink
[SPARK-23736][SQL] Adding more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mn-mikke authored and mn-mikke committed Apr 13, 2018
1 parent 944e0c9 commit 7f5124b
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ case class Concat(children: Seq[Expression]) extends Expression {
val arrayConcatClass = if (CodeGenerator.isPrimitiveType(elementType)) {
genCodeForPrimitiveArrayConcat(ctx, elementType)
} else {
genCodeForComplexArrayConcat(ctx)
genCodeForComplexArrayConcat(ctx, elementType)
}
(arrayConcatClass, s"ArrayData[] $args = new ArrayData[${evals.length}];")
}
Expand Down Expand Up @@ -451,7 +451,7 @@ case class Concat(children: Seq[Expression]) extends Expression {
| } else {
| $arrayData.set$primitiveValueTypeName(
| $counter[0],
| args[$idx].get$primitiveValueTypeName(z)
| ${CodeGenerator.getValue(s"args[$idx]", elementType, "z")}
| );
| }
| $counter[0]++;
Expand Down Expand Up @@ -482,7 +482,7 @@ case class Concat(children: Seq[Expression]) extends Expression {
|}""".stripMargin
}

private def genCodeForComplexArrayConcat(ctx: CodegenContext): String = {
private def genCodeForComplexArrayConcat(ctx: CodegenContext, elementType: DataType): String = {
val genericArrayClass = classOf[GenericArrayData].getName
val arrayData = ctx.freshName("arrayObjects")
val counter = ctx.freshName("counter")
Expand All @@ -492,7 +492,7 @@ case class Concat(children: Seq[Expression]) extends Expression {
val assignments = (0 until children.length).map { idx =>
s"""
|for (int z = 0; z < args[$idx].numElements(); z++) {
| $arrayData[$counter[0]] = args[$idx].array()[z];
| $arrayData[$counter[0]] = ${CodeGenerator.getValue(s"args[$idx]", elementType, "z")};
| $counter[0]++;
|}
""".stripMargin
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
(Seq(1, 0), Seq.empty[Int], Seq(2L), nseqi, Seq("a"), Seq.empty[String], Seq(null), nseqs)
).toDF("i1", "i2", "i3", "in", "s1", "s2", "s3", "sn")

val dummyFilter = (c: Column) => c.isNull || c.isNotNull // switch codeGen on

// Simple test cases
checkAnswer(
df.selectExpr("array(1, 2, 3L)"),
Expand All @@ -436,6 +438,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
df.select(concat($"i1", $"i2", $"i3")),
Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2)))
)
checkAnswer(
df.filter(dummyFilter($"i1")).select(concat($"i1", $"i2", $"i3")),
Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2)))
)
checkAnswer(
df.selectExpr("concat(array(1, null), i2, i3)"),
Seq(Row(Seq(1, null, 2, 3, 5, 6)), Row(Seq(1, null, 2)))
Expand All @@ -448,6 +454,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
df.selectExpr("concat(s1, s2, s3)"),
Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null)))
)
checkAnswer(
df.filter(dummyFilter($"s1"))select(concat($"s1", $"s2", $"s3")),
Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null)))
)

// Null test cases
checkAnswer(
Expand Down

0 comments on commit 7f5124b

Please sign in to comment.