[SPARK-20830][PYSPARK][SQL] Add explode_outer and posexplode_outer
## What changes were proposed in this pull request?

Add Python wrappers for `o.a.s.sql.functions.explode_outer` and `o.a.s.sql.functions.posexplode_outer`.
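
For illustration, a minimal sketch of how the new wrappers differ from their non-`outer` counterparts (the `SparkSession` setup here is an assumed demo scaffold; the data mirrors the doctests in the diff below):

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import explode, explode_outer, posexplode_outer

# Assumed local session for the demo; any active SparkSession works.
spark = SparkSession.builder.master("local[1]").appName("outer-demo").getOrCreate()

df = spark.createDataFrame(
    [(1, ["foo", "bar"]), (2, []), (3, None)],
    ("id", "an_array"),
)

# explode drops rows whose array is null or empty, so only id=1 survives.
df.select("id", explode("an_array")).show()

# explode_outer keeps those rows and emits null for the missing element.
df.select("id", explode_outer("an_array")).show()

# posexplode_outer additionally returns the element position (null when absent).
df.select("id", posexplode_outer("an_array")).show()
```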

## How was this patch tested?

Unit tests, doctests.
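
Locally, both can be exercised through Spark's Python test runner, e.g. `./python/run-tests --modules=pyspark-sql` (assuming that script and flag in the checked-out tree), which runs the module doctests alongside `pyspark.sql.tests`.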

Author: zero323 <zero323@users.noreply.github.com>

Closes #18049 from zero323/SPARK-20830.
zero323 authored and ueshin committed Jun 21, 2017
1 parent ba78514 commit 215281d
Showing 2 changed files with 83 additions and 2 deletions.
65 changes: 65 additions & 0 deletions python/pyspark/sql/functions.py
@@ -1727,6 +1727,71 @@ def posexplode(col):
     return Column(jc)
 
 
+@since(2.3)
+def explode_outer(col):
+    """Returns a new row for each element in the given array or map.
+    Unlike explode, if the array/map is null or empty then null is produced.
+
+    >>> df = spark.createDataFrame(
+    ...     [(1, ["foo", "bar"], {"x": 1.0}), (2, [], {}), (3, None, None)],
+    ...     ("id", "an_array", "a_map")
+    ... )
+    >>> df.select("id", "an_array", explode_outer("a_map")).show()
+    +---+----------+----+-----+
+    | id|  an_array| key|value|
+    +---+----------+----+-----+
+    |  1|[foo, bar]|   x|  1.0|
+    |  2|        []|null| null|
+    |  3|      null|null| null|
+    +---+----------+----+-----+
+
+    >>> df.select("id", "a_map", explode_outer("an_array")).show()
+    +---+-------------+----+
+    | id|        a_map| col|
+    +---+-------------+----+
+    |  1|Map(x -> 1.0)| foo|
+    |  1|Map(x -> 1.0)| bar|
+    |  2|        Map()|null|
+    |  3|         null|null|
+    +---+-------------+----+
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.explode_outer(_to_java_column(col))
+    return Column(jc)
+
+
+@since(2.3)
+def posexplode_outer(col):
+    """Returns a new row for each element with position in the given array or map.
+    Unlike posexplode, if the array/map is null or empty then the row (null, null) is produced.
+
+    >>> df = spark.createDataFrame(
+    ...     [(1, ["foo", "bar"], {"x": 1.0}), (2, [], {}), (3, None, None)],
+    ...     ("id", "an_array", "a_map")
+    ... )
+    >>> df.select("id", "an_array", posexplode_outer("a_map")).show()
+    +---+----------+----+----+-----+
+    | id|  an_array| pos| key|value|
+    +---+----------+----+----+-----+
+    |  1|[foo, bar]|   0|   x|  1.0|
+    |  2|        []|null|null| null|
+    |  3|      null|null|null| null|
+    +---+----------+----+----+-----+
+
+    >>> df.select("id", "a_map", posexplode_outer("an_array")).show()
+    +---+-------------+----+----+
+    | id|        a_map| pos| col|
+    +---+-------------+----+----+
+    |  1|Map(x -> 1.0)|   0| foo|
+    |  1|Map(x -> 1.0)|   1| bar|
+    |  2|        Map()|null|null|
+    |  3|         null|null|null|
+    +---+-------------+----+----+
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.posexplode_outer(_to_java_column(col))
+    return Column(jc)
+
+
 @ignore_unicode_prefix
 @since(1.6)
 def get_json_object(col, path):
20 changes: 18 additions & 2 deletions python/pyspark/sql/tests.py
@@ -258,8 +258,12 @@ def test_column_name_encoding(self):
         self.assertTrue(isinstance(columns[1], str))
 
     def test_explode(self):
-        from pyspark.sql.functions import explode
-        d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})]
+        from pyspark.sql.functions import explode, explode_outer, posexplode_outer
+        d = [
+            Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"}),
+            Row(a=1, intlist=[], mapfield={}),
+            Row(a=1, intlist=None, mapfield=None),
+        ]
         rdd = self.sc.parallelize(d)
         data = self.spark.createDataFrame(rdd)

@@ -272,6 +276,18 @@ def test_explode(self):
         self.assertEqual(result[0][0], "a")
         self.assertEqual(result[0][1], "b")
 
+        result = [tuple(x) for x in data.select(posexplode_outer("intlist")).collect()]
+        self.assertEqual(result, [(0, 1), (1, 2), (2, 3), (None, None), (None, None)])
+
+        result = [tuple(x) for x in data.select(posexplode_outer("mapfield")).collect()]
+        self.assertEqual(result, [(0, 'a', 'b'), (None, None, None), (None, None, None)])
+
+        result = [x[0] for x in data.select(explode_outer("intlist")).collect()]
+        self.assertEqual(result, [1, 2, 3, None, None])
+
+        result = [tuple(x) for x in data.select(explode_outer("mapfield")).collect()]
+        self.assertEqual(result, [('a', 'b'), (None, None), (None, None)])
+
     def test_and_in_expression(self):
         self.assertEqual(4, self.df.filter((self.df.key <= 10) & (self.df.value <= "2")).count())
         self.assertRaises(ValueError, lambda: (self.df.key <= 10) and (self.df.value <= "2"))
