[SPARK-16289][SQL] Implement posexplode table generating function #13971

Closed

dongjoon-hyun wants to merge 8 commits
1 change: 1 addition & 0 deletions R/pkg/NAMESPACE
@@ -234,6 +234,7 @@ exportMethods("%in%",
              "over",
              "percent_rank",
              "pmod",
              "posexplode",
              "quarter",
              "rand",
              "randn",
17 changes: 17 additions & 0 deletions R/pkg/R/functions.R
@@ -2934,3 +2934,20 @@ setMethod("sort_array",
            jc <- callJStatic("org.apache.spark.sql.functions", "sort_array", x@jc, asc)
            column(jc)
          })

#' posexplode
#'
#' Creates a new row for each element with position in the given array or map column.
#'
#' @rdname posexplode
#' @name posexplode
#' @family collection_funcs
#' @export
#' @examples \dontrun{posexplode(df$c)}
#' @note posexplode since 2.1.0
setMethod("posexplode",
signature(x = "Column"),
function(x) {
jc <- callJStatic("org.apache.spark.sql.functions", "posexplode", x@jc)
column(jc)
})
4 changes: 4 additions & 0 deletions R/pkg/R/generics.R
@@ -1050,6 +1050,10 @@ setGeneric("percent_rank", function(x) { standardGeneric("percent_rank") })
#' @export
setGeneric("pmod", function(y, x) { standardGeneric("pmod") })

#' @rdname posexplode
#' @export
setGeneric("posexplode", function(x) { standardGeneric("posexplode") })

#' @rdname quarter
#' @export
setGeneric("quarter", function(x) { standardGeneric("quarter") })
2 changes: 1 addition & 1 deletion R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -1065,7 +1065,7 @@ test_that("column functions", {
  c4 <- explode(c) + expm1(c) + factorial(c) + first(c) + floor(c) + hex(c)
  c5 <- hour(c) + initcap(c) + last(c) + last_day(c) + length(c)
  c6 <- log(c) + (c) + log1p(c) + log2(c) + lower(c) + ltrim(c) + max(c) + md5(c)
- c7 <- mean(c) + min(c) + month(c) + negate(c) + quarter(c)
+ c7 <- mean(c) + min(c) + month(c) + negate(c) + posexplode(c) + quarter(c)
  c8 <- reverse(c) + rint(c) + round(c) + rtrim(c) + sha1(c) + monotonically_increasing_id()
  c9 <- signum(c) + sin(c) + sinh(c) + size(c) + stddev(c) + soundex(c) + sqrt(c) + sum(c)
  c10 <- sumDistinct(c) + tan(c) + tanh(c) + toDegrees(c) + toRadians(c)
21 changes: 21 additions & 0 deletions python/pyspark/sql/functions.py
@@ -1637,6 +1637,27 @@ def explode(col):
    return Column(jc)


@since(2.1)
def posexplode(col):
    """Returns a new row for each element with position in the given array or map.

    >>> from pyspark.sql import Row
    >>> eDF = spark.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})])
    >>> eDF.select(posexplode(eDF.intlist)).collect()
    [Row(pos=0, col=1), Row(pos=1, col=2), Row(pos=2, col=3)]

    >>> eDF.select(posexplode(eDF.mapfield)).show()
    +---+---+-----+
    |pos|key|value|
    +---+---+-----+
    |  0|  a|    b|
    +---+---+-----+
    """
    sc = SparkContext._active_spark_context
    jc = sc._jvm.functions.posexplode(_to_java_column(col))
    return Column(jc)

Contributor: cc @rxin, is posexplode a special Hive fallback function that we need to register? The other ones don't get registered in functions.

Member Author: For this one, I thought the reason is that explode is already registered; posexplode is its counterpart.

Contributor: Yeah, this one is probably fine. I wouldn't register the other ones.

Member Author: Thank you for reconfirming!


@ignore_unicode_prefix
@since(1.6)
def get_json_object(col, path):
Expand Down
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -175,6 +175,7 @@ object FunctionRegistry {
expression[NullIf]("nullif"),
expression[Nvl]("nvl"),
expression[Nvl2]("nvl2"),
expression[PosExplode]("posexplode"),
expression[Rand]("rand"),
expression[Randn]("randn"),
expression[CreateStruct]("struct"),
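With the registry entry above in place, posexplode resolves from SQL text as well as from the DataFrame API. A minimal sketch of the expected behavior (assuming a Spark 2.1+ SparkSession named `spark`; illustrative, not part of the diff):

// posexplode is resolved through FunctionRegistry, so plain SQL works.
spark.sql("SELECT posexplode(array(10, 20))").show()
// +---+---+
// |pos|col|
// +---+---+
// |  0| 10|
// |  1| 20|
// +---+---+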
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -94,13 +94,10 @@ case class UserDefinedGenerator(
}

/**
- * Given an input array produces a sequence of rows for each value in the array.
+ * A base class for Explode and PosExplode
  */
- // scalastyle:off line.size.limit
- @ExpressionDescription(
-   usage = "_FUNC_(a) - Separates the elements of array a into multiple rows, or the elements of a map into multiple rows and columns.")
- // scalastyle:on line.size.limit
- case class Explode(child: Expression) extends UnaryExpression with Generator with CodegenFallback {
+ abstract class ExplodeBase(child: Expression, position: Boolean)
+   extends UnaryExpression with Generator with CodegenFallback with Serializable {

  override def children: Seq[Expression] = child :: Nil

@@ -115,9 +112,26 @@

  // hive-compatible default alias for explode function ("col" for array, "key", "value" for map)
  override def elementSchema: StructType = child.dataType match {
-   case ArrayType(et, containsNull) => new StructType().add("col", et, containsNull)
+   case ArrayType(et, containsNull) =>
+     if (position) {
+       new StructType()
+         .add("pos", IntegerType, false)
+         .add("col", et, containsNull)
+     } else {
+       new StructType()
+         .add("col", et, containsNull)
+     }
    case MapType(kt, vt, valueContainsNull) =>
-     new StructType().add("key", kt, false).add("value", vt, valueContainsNull)
+     if (position) {
+       new StructType()
+         .add("pos", IntegerType, false)
+         .add("key", kt, false)
+         .add("value", vt, valueContainsNull)
+     } else {
+       new StructType()
+         .add("key", kt, false)
+         .add("value", vt, valueContainsNull)
+     }
  }

  override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
@@ -129,7 +143,7 @@
      } else {
        val rows = new Array[InternalRow](inputArray.numElements())
        inputArray.foreach(et, (i, e) => {
-         rows(i) = InternalRow(e)
+         rows(i) = if (position) InternalRow(i, e) else InternalRow(e)
        })
        rows
      }
@@ -141,11 +155,43 @@
        val rows = new Array[InternalRow](inputMap.numElements())
        var i = 0
        inputMap.foreach(kt, vt, (k, v) => {
-         rows(i) = InternalRow(k, v)
+         rows(i) = if (position) InternalRow(i, k, v) else InternalRow(k, v)
          i += 1
        })
        rows
      }
    }
  }
}

/**
 * Given an input array produces a sequence of rows for each value in the array.
 *
 * {{{
 *   SELECT explode(array(10,20)) ->
 *   10
 *   20
 * }}}
 */
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(a) - Separates the elements of array a into multiple rows, or the elements of map a into multiple rows and columns.",
extended = "> SELECT _FUNC_(array(10,20));\n 10\n 20")
// scalastyle:on line.size.limit
case class Explode(child: Expression) extends ExplodeBase(child, position = false)

/**
 * Given an input array produces a sequence of rows for each position and value in the array.
 *
 * {{{
 *   SELECT posexplode(array(10,20)) ->
 *   0  10
 *   1  20
 * }}}
 */

Contributor: By the way, since the expression description might be difficult to see without line wrapping, it'd also be better to put an example here.

Contributor: Also, you don't need the { }.
Contributor: An example would be useful.

// scalastyle:off line.size.limit
@ExpressionDescription(
  usage = "_FUNC_(a) - Separates the elements of array a into multiple rows with positions, or the elements of a map into multiple rows and columns with positions.",
  extended = "> SELECT _FUNC_(array(10,20));\n 0\t10\n 1\t20")
// scalastyle:on line.size.limit
case class PosExplode(child: Expression) extends ExplodeBase(child, position = true)
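To make the two elementSchema branches above concrete, here is a sketch of what PosExplode produces for an array and a map column through the DataFrame API (a hypothetical spark-shell session with a SparkSession named `spark`; not part of the diff):

import spark.implicits._
import org.apache.spark.sql.functions.posexplode

// Array input: pos is a non-nullable int, col keeps the element type and nullability.
val arrDF = Seq((1, Seq("a", "b"))).toDF("id", "letters")
arrDF.select(posexplode($"letters")).printSchema()
// root
//  |-- pos: integer (nullable = false)
//  |-- col: string (nullable = true)

// Map input: pos, then the Hive-compatible key/value columns.
val mapDF = Seq((1, Map("x" -> "y"))).toDF("id", "m")
mapDF.select(posexplode($"m")).show()
// +---+---+-----+
// |pos|key|value|
// +---+---+-----+
// |  0|  x|    y|
// +---+---+-----+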
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -166,6 +166,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
    assertError(new Murmur3Hash(Nil), "function hash requires at least one argument")
    assertError(Explode('intField),
      "input to function explode should be array or map type")
    assertError(PosExplode('intField),
      "input to function explode should be array or map type")
  }

  test("check types for CreateNamedStruct") {
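Since PosExplode inherits ExplodeBase's input type check, an ill-typed call fails analysis with the same "explode" message, as the assertions above verify. A sketch of the user-visible behavior (same hypothetical session as the earlier snippets; not part of the diff):

import spark.implicits._
import org.apache.spark.sql.functions.posexplode

// Fails analysis: input to function explode should be array or map type.
Seq((1, 2)).toDF("a", "b").select(posexplode($"b"))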
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.unsafe.types.UTF8String

class GeneratorExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
  private def checkTuple(actual: ExplodeBase, expected: Seq[InternalRow]): Unit = {
    assert(actual.eval(null).toSeq === expected)
  }

Contributor: We have checkEvaluation for this purpose; how about using that?

Member Author: Oh, thank you for the review too, @cloud-fan. Do we have an example of using checkEvaluation to check a generator, i.e. multiple InternalRows? I thought checkEvaluation was only for a single row, e.g. values, arrays, maps.

Contributor: checkEvaluation takes Any as the expected result, so I don't think it is used only for a single row. Have you tried passing a Seq[Row] to checkEvaluation? If it doesn't work, is it possible to improve checkEvaluation so that it can handle this case? Thanks.

Member Author: Sure, @cloud-fan. In fact, I tried everything you told me, in many ways, because I trust you. :)

Member Author (@dongjoon-hyun, Jun 30, 2016): As evidence, let me show the results for the simplest case.

checkEvaluation(Explode(CreateArray(Seq.empty)), Seq.empty[Row])
checkEvaluation(Explode(CreateArray(Seq.empty)), Seq.empty[InternalRow])
checkEvaluation(Explode(CreateArray(Seq.empty)), Seq.empty)

All of the above return the following:

Incorrect evaluation (codegen off): explode(array()), actual: InternalRow;(), expected: []

Here is the body of checkEvaluation; the comments mark the limitations I found.

// 1. This generally converts a `Seq[Any]` into a `GenericArrayData`.
val catalystValue = CatalystTypeConverters.convertToCatalyst(expected)
checkEvaluationWithoutCodegen(expression, catalystValue, inputRow)

// 2. Here, `val actual = plan(inputRow).get(0, expression.dataType)` is called to try casting to `expression.dataType`.
checkEvaluationWithGeneratedMutableProjection(expression, catalystValue, inputRow)

if (GenerateUnsafeProjection.canSupport(expression.dataType)) {
  // 3. Here, `val unsafeRow = plan(inputRow)` assumes a single row.
  checkEvalutionWithUnsafeProjection(expression, catalystValue, inputRow)
}

// 4. Here, `checkResult` fails at `result == expected`.
checkEvaluationWithOptimization(expression, catalystValue, inputRow)

In short, every step of checkEvaluation depends heavily on the single-row assumption. If we want to change this, we should do it in a separate issue since it's not trivial.

Member Author: If I haven't misunderstood, it's definitely a valuable issue to investigate further. If we can upgrade checkEvaluation later, we can unify this PR's test cases with checkEvaluation.

Contributor: Let's not change it for now. We also don't want test code to become so complicated that it is no longer obvious what's going on.

Member Author: Yep. Thank you. I'll investigate it later.

  private final val int_array = Seq(1, 2, 3)
  private final val str_array = Seq("a", "b", "c")

  test("explode") {
    val int_correct_answer = Seq(Seq(1), Seq(2), Seq(3))
    val str_correct_answer = Seq(
      Seq(UTF8String.fromString("a")),
      Seq(UTF8String.fromString("b")),
      Seq(UTF8String.fromString("c")))

    checkTuple(
      Explode(CreateArray(Seq.empty)),
      Seq.empty)

    checkTuple(
      Explode(CreateArray(int_array.map(Literal(_)))),
      int_correct_answer.map(InternalRow.fromSeq(_)))

    checkTuple(
      Explode(CreateArray(str_array.map(Literal(_)))),
      str_correct_answer.map(InternalRow.fromSeq(_)))
  }

  test("posexplode") {
    val int_correct_answer = Seq(Seq(0, 1), Seq(1, 2), Seq(2, 3))
    val str_correct_answer = Seq(
      Seq(0, UTF8String.fromString("a")),
      Seq(1, UTF8String.fromString("b")),
      Seq(2, UTF8String.fromString("c")))

    checkTuple(
      PosExplode(CreateArray(Seq.empty)),
      Seq.empty)

    checkTuple(
      PosExplode(CreateArray(int_array.map(Literal(_)))),
      int_correct_answer.map(InternalRow.fromSeq(_)))

    checkTuple(
      PosExplode(CreateArray(str_array.map(Literal(_)))),
      str_correct_answer.map(InternalRow.fromSeq(_)))
  }
}
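The suite above exercises array inputs only; a map-typed case would follow the same checkTuple pattern. A sketch of what such a test could look like (CreateMap is assumed from the same expressions package; illustrative, not part of the diff):

checkTuple(
  PosExplode(CreateMap(Seq(Literal("a"), Literal("b")))),
  Seq(InternalRow(0, UTF8String.fromString("a"), UTF8String.fromString("b"))))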
1 change: 1 addition & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -159,6 +159,7 @@ class Column(protected[sql] val expr: Expression) extends Logging {
      // Leave an unaliased generator with an empty list of names since the analyzer will generate
      // the correct defaults after the nested expression's type has been resolved.
      case explode: Explode => MultiAlias(explode, Nil)
      case explode: PosExplode => MultiAlias(explode, Nil)

      case jt: JsonTuple => MultiAlias(jt, Nil)

8 changes: 8 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2721,6 +2721,14 @@
   */
  def explode(e: Column): Column = withExpr { Explode(e.expr) }

  /**
   * Creates a new row for each element with position in the given array or map column.
   *
   * @group collection_funcs
   * @since 2.1.0
   */
  def posexplode(e: Column): Column = withExpr { PosExplode(e.expr) }
Contributor: Need to add this to Python too.


  /**
   * Extracts json object from a json string based on json path specified, and returns json string
   * of the extracted json object. It will return null if the input json string is invalid.
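Together with the Column.scala change above, an unaliased posexplode picks up its default pos/col names from the analyzer, while explicit aliases still override them. A sketch under the same assumptions as the earlier snippets (hypothetical session; not part of the diff):

import spark.implicits._
import org.apache.spark.sql.functions.posexplode

val df = Seq((1, Seq(10, 20))).toDF("a", "intList")

// Default output columns are pos and col, per elementSchema.
df.select(posexplode($"intList")).show()

// Multi-column aliasing replaces both defaults.
df.select(posexplode($"intList").as(Seq("i", "v"))).show()
// +---+---+
// |  i|  v|
// +---+---+
// |  0| 10|
// |  1| 20|
// +---+---+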
sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -122,66 +122,6 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
    assert(newCol.expr.asInstanceOf[NamedExpression].metadata.getString("key") === "value")
  }

- test("single explode") {
-   val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
-   checkAnswer(
-     df.select(explode('intList)),
-     Row(1) :: Row(2) :: Row(3) :: Nil)
- }
-
- test("explode and other columns") {
-   val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
-
-   checkAnswer(
-     df.select($"a", explode('intList)),
-     Row(1, 1) ::
-     Row(1, 2) ::
-     Row(1, 3) :: Nil)
-
-   checkAnswer(
-     df.select($"*", explode('intList)),
-     Row(1, Seq(1, 2, 3), 1) ::
-     Row(1, Seq(1, 2, 3), 2) ::
-     Row(1, Seq(1, 2, 3), 3) :: Nil)
- }
-
- test("aliased explode") {
-   val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
-
-   checkAnswer(
-     df.select(explode('intList).as('int)).select('int),
-     Row(1) :: Row(2) :: Row(3) :: Nil)
-
-   checkAnswer(
-     df.select(explode('intList).as('int)).select(sum('int)),
-     Row(6) :: Nil)
- }
-
- test("explode on map") {
-   val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")
-
-   checkAnswer(
-     df.select(explode('map)),
-     Row("a", "b"))
- }
-
- test("explode on map with aliases") {
-   val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")
-
-   checkAnswer(
-     df.select(explode('map).as("key1" :: "value1" :: Nil)).select("key1", "value1"),
-     Row("a", "b"))
- }
-
- test("self join explode") {
-   val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
-   val exploded = df.select(explode('intList).as('i))
-
-   checkAnswer(
-     exploded.join(exploded, exploded("i") === exploded("i")).agg(count("*")),
-     Row(3) :: Nil)
- }

  test("collect on column produced by a binary operator") {
    val df = Seq((1, 2, 3)).toDF("a", "b", "c")
    checkAnswer(df.select(df("a") + df("b")), Seq(Row(3)))