Skip to content

Commit

Permalink
Remove redundant match and use create_row() in testsuites.
Browse files Browse the repository at this point in the history
  • Loading branch information
dongjoon-hyun committed Jul 1, 2016
1 parent 9382f64 commit c43a187
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._

/**
Expand Down Expand Up @@ -221,16 +221,15 @@ case class Inline(child: Expression) extends UnaryExpression with Generator with
})
}

private lazy val ncol = elementSchema.fields.length
private lazy val numFields = elementSchema.fields.length

override def eval(input: InternalRow): TraversableOnce[InternalRow] = child.dataType match {
case ArrayType(et : StructType, _) =>
val inputArray = child.eval(input).asInstanceOf[ArrayData]
if (inputArray == null) {
Nil
} else {
for (i <- 0 until inputArray.numElements())
yield inputArray.getStruct(i, ncol)
}
override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
val inputArray = child.eval(input).asInstanceOf[ArrayData]
if (inputArray == null) {
Nil
} else {
for (i <- 0 until inputArray.numElements())
yield inputArray.getStruct(i, numFields)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,64 +20,39 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

class GeneratorExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
private def checkTuple(actual: Expression, expected: Seq[InternalRow]): Unit = {
assert(actual.eval(null).asInstanceOf[TraversableOnce[InternalRow]].toSeq === expected)
}

private final val int_array = Seq(1, 2, 3)
private final val str_array = Seq("a", "b", "c")
private final val empty_array = CreateArray(Seq.empty)
private final val int_array = CreateArray(Seq(1, 2, 3).map(Literal(_)))
private final val str_array = CreateArray(Seq("a", "b", "c").map(Literal(_)))

test("explode") {
val int_correct_answer = Seq(Seq(1), Seq(2), Seq(3))
val str_correct_answer = Seq(
Seq(UTF8String.fromString("a")),
Seq(UTF8String.fromString("b")),
Seq(UTF8String.fromString("c")))
val int_correct_answer = Seq(create_row(1), create_row(2), create_row(3))
val str_correct_answer = Seq(create_row("a"), create_row("b"), create_row("c"))

checkTuple(
Explode(CreateArray(Seq.empty)),
Seq.empty)

checkTuple(
Explode(CreateArray(int_array.map(Literal(_)))),
int_correct_answer.map(InternalRow.fromSeq(_)))

checkTuple(
Explode(CreateArray(str_array.map(Literal(_)))),
str_correct_answer.map(InternalRow.fromSeq(_)))
checkTuple(Explode(empty_array), Seq.empty)
checkTuple(Explode(int_array), int_correct_answer)
checkTuple(Explode(str_array), str_correct_answer)
}

test("posexplode") {
val int_correct_answer = Seq(Seq(0, 1), Seq(1, 2), Seq(2, 3))
val str_correct_answer = Seq(
Seq(0, UTF8String.fromString("a")),
Seq(1, UTF8String.fromString("b")),
Seq(2, UTF8String.fromString("c")))
val int_correct_answer = Seq(create_row(0, 1), create_row(1, 2), create_row(2, 3))
val str_correct_answer = Seq(create_row(0, "a"), create_row(1, "b"), create_row(2, "c"))

checkTuple(
PosExplode(CreateArray(Seq.empty)),
Seq.empty)

checkTuple(
PosExplode(CreateArray(int_array.map(Literal(_)))),
int_correct_answer.map(InternalRow.fromSeq(_)))

checkTuple(
PosExplode(CreateArray(str_array.map(Literal(_)))),
str_correct_answer.map(InternalRow.fromSeq(_)))
checkTuple(PosExplode(CreateArray(Seq.empty)), Seq.empty)
checkTuple(PosExplode(int_array), int_correct_answer)
checkTuple(PosExplode(str_array), str_correct_answer)
}

test("inline") {
val correct_answer = Seq(
Seq(0, UTF8String.fromString("a")),
Seq(1, UTF8String.fromString("b")),
Seq(2, UTF8String.fromString("c")))
val correct_answer = Seq(create_row(0, "a"), create_row(1, "b"), create_row(2, "c"))

checkTuple(
Inline(Literal.create(Array(), ArrayType(StructType(Seq(StructField("id1", LongType)))))),
Inline(Literal.create(Array(), ArrayType(new StructType().add("id", LongType)))),
Seq.empty)

checkTuple(
Expand All @@ -86,6 +61,6 @@ class GeneratorExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
CreateStruct(Seq(Literal(1), Literal("b"))),
CreateStruct(Seq(Literal(2), Literal("c")))
))),
correct_answer.map(InternalRow.fromSeq(_)))
correct_answer)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,8 @@

package org.apache.spark.sql

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StringType}

class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
import testImplicits._
Expand Down Expand Up @@ -92,17 +90,19 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
Row(3) :: Nil)
}

test("inline with empty table or empty array") {
checkAnswer(
spark.range(0).selectExpr("inline(array(struct(10, 100)))"),
Nil)

test("inline raises exception on empty array") {
val m = intercept[AnalysisException] {
spark.range(2).selectExpr("inline(array())")
}.getMessage
assert(m.contains("data type mismatch"))
}

test("inline with empty table") {
checkAnswer(
spark.range(0).selectExpr("inline(array(struct(10, 100)))"),
Nil)
}

test("inline on literal") {
checkAnswer(
spark.range(2).selectExpr("inline(array(struct(10, 100), struct(20, 200), struct(30, 300)))"),
Expand Down

0 comments on commit c43a187

Please sign in to comment.