Skip to content

Commit

Permalink
[SPARK-10577] [PYSPARK] DataFrame hint for broadcast join
Browse files Browse the repository at this point in the history
https://issues.apache.org/jira/browse/SPARK-10577

Author: Jian Feng <jzhang.chs@gmail.com>

Closes #8801 from Jianfeng-chs/master.

(cherry picked from commit 0180b84)
Signed-off-by: Reynold Xin <rxin@databricks.com>

Conflicts:
	python/pyspark/sql/tests.py
  • Loading branch information
snowmoon-zhang authored and rxin committed Oct 14, 2015
1 parent f366249 commit 30eea40
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 0 deletions.
9 changes: 9 additions & 0 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from pyspark.sql import since
from pyspark.sql.types import StringType
from pyspark.sql.column import Column, _to_java_column, _to_seq
from pyspark.sql.dataframe import DataFrame


def _create_function(name, doc=""):
Expand Down Expand Up @@ -190,6 +191,14 @@ def approxCountDistinct(col, rsd=None):
return Column(jc)


@since(1.6)
def broadcast(df):
    """Marks a DataFrame as small enough for use in broadcast joins.

    :param df: the :class:`DataFrame` to mark.
    :return: a new :class:`DataFrame` built from the marked JVM DataFrame and
        the SQL context of ``df``.
    """
    # The marking itself is delegated to the JVM-side ``functions.broadcast``,
    # reached through the gateway of the currently active SparkContext.
    jvm_functions = SparkContext._active_spark_context._jvm.functions
    marked_jdf = jvm_functions.broadcast(df._jdf)
    return DataFrame(marked_jdf, df.sql_ctx)


@since(1.4)
def coalesce(*cols):
"""Returns the first column that is not null.
Expand Down
27 changes: 27 additions & 0 deletions python/pyspark/sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1056,6 +1056,33 @@ def test_with_column_with_existing_name(self):
keys = self.df.withColumn("key", self.df.key).select("key").collect()
self.assertEqual([r.key for r in keys], list(range(100)))

# regression test for SPARK-10417
def test_column_iterator(self):
    """Iterating over ``self.df.key`` is expected to raise ``TypeError``."""

    def iterate_key():
        # A single step suffices; the loop exits immediately if nothing is raised.
        for _ in self.df.key:
            break

    self.assertRaises(TypeError, iterate_key)

# add test for SPARK-10577 (test broadcast join hint)
def test_functions_broadcast(self):
    """Exercise the ``broadcast`` join hint (SPARK-10577) against the executed plan."""
    from pyspark.sql.functions import broadcast

    def executed_plan(frame):
        # Executed plan object obtained through the underlying JVM DataFrame.
        return frame._jdf.queryExecution().executedPlan()

    def count_broadcast_joins(frame):
        # Number of times "BroadcastHashJoin" occurs in the plan's string form.
        return executed_plan(frame).toString().count("BroadcastHashJoin")

    columns = ("key", "value")
    left = self.sqlCtx.createDataFrame([(1, "1"), (2, "2")], columns)
    right = self.sqlCtx.createDataFrame([(1, "1"), (2, "2")], columns)

    # Equijoin on "key" with a hinted right side: exactly one broadcast hash join.
    self.assertEqual(1, count_broadcast_joins(left.join(broadcast(right), "key")))

    # Without a join key the hint must not yield a broadcast hash join.
    self.assertEqual(0, count_broadcast_joins(left.join(broadcast(right))))

    # Planning a hinted DataFrame that takes part in no join must not raise.
    executed_plan(broadcast(left))


class HiveContextSQLTests(ReusedPySparkTestCase):

Expand Down

0 comments on commit 30eea40

Please sign in to comment.