Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ object FunctionRegistry {
expression[Explode]("explode"),
expression[Greatest]("greatest"),
expression[If]("if"),
expression[Inline]("inline"),
expression[IsNaN]("isnan"),
expression[IfNull]("ifnull"),
expression[IsNull]("isnull"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,3 +195,38 @@ case class Explode(child: Expression) extends ExplodeBase(child, position = fals
extended = "> SELECT _FUNC_(array(10,20));\n 0\t10\n 1\t20")
// scalastyle:on line.size.limit
case class PosExplode(child: Expression) extends ExplodeBase(child, position = true)

/**
* Explodes an array of structs into a table.
*/
@ExpressionDescription(
usage = "_FUNC_(a) - Explodes an array of structs into a table.",
extended = "> SELECT _FUNC_(array(struct(1, 'a'), struct(2, 'b')));\n [1,a]\n [2,b]")
case class Inline(child: Expression) extends UnaryExpression with Generator with CodegenFallback {

override def children: Seq[Expression] = child :: Nil

override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
case ArrayType(et, _) if et.isInstanceOf[StructType] =>
TypeCheckResult.TypeCheckSuccess
case _ =>
TypeCheckResult.TypeCheckFailure(
s"input to function $prettyName should be array of struct type, not ${child.dataType}")
}

override def elementSchema: StructType = child.dataType match {
case ArrayType(et : StructType, _) => et
}

private lazy val numFields = elementSchema.fields.length

override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
val inputArray = child.eval(input).asInstanceOf[ArrayData]
if (inputArray == null) {
Nil
} else {
for (i <- 0 until inputArray.numElements())
yield inputArray.getStruct(i, numFields)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure how is the performance of for-yield, maybe it's safe to create an array manually and use while loop here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you, @cloud-fan . By the way, for about this, @rxin gave me an advice at the first commit of this PR.

we don't need to materialize the array, do we? We can create an iterator to return the results.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah i see, for-yield returns an iterator.

}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,53 +19,48 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.sql.types._

class GeneratorExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
private def checkTuple(actual: ExplodeBase, expected: Seq[InternalRow]): Unit = {
assert(actual.eval(null).toSeq === expected)
private def checkTuple(actual: Expression, expected: Seq[InternalRow]): Unit = {
assert(actual.eval(null).asInstanceOf[TraversableOnce[InternalRow]].toSeq === expected)
}

private final val int_array = Seq(1, 2, 3)
private final val str_array = Seq("a", "b", "c")
private final val empty_array = CreateArray(Seq.empty)
private final val int_array = CreateArray(Seq(1, 2, 3).map(Literal(_)))
private final val str_array = CreateArray(Seq("a", "b", "c").map(Literal(_)))

test("explode") {
val int_correct_answer = Seq(Seq(1), Seq(2), Seq(3))
val str_correct_answer = Seq(
Seq(UTF8String.fromString("a")),
Seq(UTF8String.fromString("b")),
Seq(UTF8String.fromString("c")))
val int_correct_answer = Seq(create_row(1), create_row(2), create_row(3))
val str_correct_answer = Seq(create_row("a"), create_row("b"), create_row("c"))

checkTuple(
Explode(CreateArray(Seq.empty)),
Seq.empty)
checkTuple(Explode(empty_array), Seq.empty)
checkTuple(Explode(int_array), int_correct_answer)
checkTuple(Explode(str_array), str_correct_answer)
}

checkTuple(
Explode(CreateArray(int_array.map(Literal(_)))),
int_correct_answer.map(InternalRow.fromSeq(_)))
test("posexplode") {
val int_correct_answer = Seq(create_row(0, 1), create_row(1, 2), create_row(2, 3))
val str_correct_answer = Seq(create_row(0, "a"), create_row(1, "b"), create_row(2, "c"))

checkTuple(
Explode(CreateArray(str_array.map(Literal(_)))),
str_correct_answer.map(InternalRow.fromSeq(_)))
checkTuple(PosExplode(CreateArray(Seq.empty)), Seq.empty)
checkTuple(PosExplode(int_array), int_correct_answer)
checkTuple(PosExplode(str_array), str_correct_answer)
}

test("posexplode") {
val int_correct_answer = Seq(Seq(0, 1), Seq(1, 2), Seq(2, 3))
val str_correct_answer = Seq(
Seq(0, UTF8String.fromString("a")),
Seq(1, UTF8String.fromString("b")),
Seq(2, UTF8String.fromString("c")))
test("inline") {
val correct_answer = Seq(create_row(0, "a"), create_row(1, "b"), create_row(2, "c"))

checkTuple(
PosExplode(CreateArray(Seq.empty)),
Inline(Literal.create(Array(), ArrayType(new StructType().add("id", LongType)))),
Seq.empty)

checkTuple(
PosExplode(CreateArray(int_array.map(Literal(_)))),
int_correct_answer.map(InternalRow.fromSeq(_)))

checkTuple(
PosExplode(CreateArray(str_array.map(Literal(_)))),
str_correct_answer.map(InternalRow.fromSeq(_)))
Inline(CreateArray(Seq(
CreateStruct(Seq(Literal(0), Literal("a"))),
CreateStruct(Seq(Literal(1), Literal("b"))),
CreateStruct(Seq(Literal(2), Literal("c")))
))),
correct_answer)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,64 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
exploded.join(exploded, exploded("i") === exploded("i")).agg(count("*")),
Row(3) :: Nil)
}

test("inline raises exception on array of null type") {
val m = intercept[AnalysisException] {
spark.range(2).selectExpr("inline(array())")
}.getMessage
assert(m.contains("data type mismatch"))
}

test("inline with empty table") {
checkAnswer(
spark.range(0).selectExpr("inline(array(struct(10, 100)))"),
Nil)
}

test("inline on literal") {
checkAnswer(
spark.range(2).selectExpr("inline(array(struct(10, 100), struct(20, 200), struct(30, 300)))"),
Row(10, 100) :: Row(20, 200) :: Row(30, 300) ::
Row(10, 100) :: Row(20, 200) :: Row(30, 300) :: Nil)
}

test("inline on column") {
val df = Seq((1, 2)).toDF("a", "b")

checkAnswer(
df.selectExpr("inline(array(struct(a), struct(a)))"),
Row(1) :: Row(1) :: Nil)

checkAnswer(
df.selectExpr("inline(array(struct(a, b), struct(a, b)))"),
Row(1, 2) :: Row(1, 2) :: Nil)

// Spark think [struct<a:int>, struct<b:int>] is heterogeneous due to name difference.
val m = intercept[AnalysisException] {
df.selectExpr("inline(array(struct(a), struct(b)))")
}.getMessage
assert(m.contains("data type mismatch"))

checkAnswer(
df.selectExpr("inline(array(struct(a), named_struct('a', b)))"),
Row(1) :: Row(2) :: Nil)

// Spark think [struct<a:int>, struct<col1:int>] is heterogeneous due to name difference.
val m2 = intercept[AnalysisException] {
df.selectExpr("inline(array(struct(a), struct(2)))")
}.getMessage
assert(m2.contains("data type mismatch"))

checkAnswer(
df.selectExpr("inline(array(struct(a), named_struct('a', 2)))"),
Row(1) :: Row(2) :: Nil)

checkAnswer(
df.selectExpr("struct(a)").selectExpr("inline(array(*))"),
Row(1) :: Nil)

checkAnswer(
df.selectExpr("array(struct(a), named_struct('a', b))").selectExpr("inline(*)"),
Row(1) :: Row(2) :: Nil)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -242,9 +242,6 @@ private[sql] class HiveSessionCatalog(
"map_keys", "map_values",
"parse_url", "percentile", "percentile_approx", "reflect", "sentences", "stack", "str_to_map",
"xpath", "xpath_double", "xpath_float", "xpath_int", "xpath_long",
"xpath_number", "xpath_short", "xpath_string",

// table generating function
"inline"
"xpath_number", "xpath_short", "xpath_string"
)
}