[SPARK-40970][CONNECT][PYTHON][COLUMN] Support List for Join's on argument

### What changes were proposed in this pull request?

This PR adds support for `join` on a list of columns: the `on` argument of `DataFrame.join` in the Spark Connect Python client now also accepts a `List[Column]` of join conditions.
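For reference, a minimal usage sketch of the new behavior (table and column names here are illustrative, not from this PR); a list of `Column` conditions is combined into a single join condition:

```python
# Illustrative sketch only: table/column names are made up for this example.
left = spark.read.table("left_table")
right = spark.read.table("right_table")

# `on` may now be a list of Column expressions; they are AND-ed into one condition.
joined = left.join(
    other=right,
    on=[left.id == right.col1, left.name == right.col2],
    how="inner",
)
```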

### Why are the changes needed?

API coverage: brings the Spark Connect Python client's `join` closer to the existing PySpark `DataFrame.join` API.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

UT (unit tests added in `test_connect_basic.py` and `test_connect_plan_only.py`).

Closes #38866 from amaliujia/join_condition_list_final.

Authored-by: Rui Wang <rui.wang@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
amaliujia authored and HyukjinKwon committed Dec 2, 2022
1 parent 3fc8a90 commit 5d90f98
Showing 4 changed files with 46 additions and 4 deletions.
2 changes: 1 addition & 1 deletion python/pyspark/sql/connect/dataframe.py
@@ -468,7 +468,7 @@ def take(self, num: int) -> List[Row]:
def join(
self,
other: "DataFrame",
on: Optional[Union[str, List[str], Column]] = None,
on: Optional[Union[str, List[str], Column, List[Column]]] = None,
how: Optional[str] = None,
) -> "DataFrame":
if self._plan is None:
11 changes: 8 additions & 3 deletions python/pyspark/sql/connect/plan.py
@@ -16,6 +16,7 @@
#

from typing import Any, List, Optional, Sequence, Union, cast, TYPE_CHECKING, Mapping, Dict
import functools
import pandas
import pyarrow as pa
import pyspark.sql.connect.proto as proto
@@ -675,7 +676,7 @@ def __init__(
self,
left: Optional["LogicalPlan"],
right: "LogicalPlan",
on: Optional[Union[str, List[str], Column]],
on: Optional[Union[str, List[str], Column, List[Column]]],
how: Optional[str],
) -> None:
super().__init__(left)
@@ -721,8 +722,12 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation:
rel.join.using_columns.append(self.on)
else:
rel.join.join_condition.CopyFrom(self.to_attr_or_expression(self.on, session))
else:
rel.join.using_columns.extend(self.on)
elif len(self.on) > 0:
if isinstance(self.on[0], str):
rel.join.using_columns.extend(cast(str, self.on))
else:
merge_column = functools.reduce(lambda c1, c2: c1 & c2, self.on)
rel.join.join_condition.CopyFrom(cast(Column, merge_column).to_plan(session))
rel.join.join_type = self.how
return rel

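The core of the planner change above is folding a list of `Column` conditions into one conjunction with `functools.reduce` and `&`. A self-contained toy sketch of that folding step (the `Cond` class below is a made-up stand-in for `Column`, whose real `&` builds an AND expression):

```python
import functools
from dataclasses import dataclass


# Toy stand-in for a Column expression; the real Column's `&` builds an AND expression tree.
@dataclass
class Cond:
    text: str

    def __and__(self, other: "Cond") -> "Cond":
        return Cond(f"({self.text} AND {other.text})")


conds = [Cond("id = col1"), Cond("name = col2"), Cond("age = col3")]

# Same left fold as in plan.py: ((c1 & c2) & c3)
merged = functools.reduce(lambda c1, c2: c1 & c2, conds)
print(merged.text)  # ((id = col1 AND name = col2) AND age = col3)
```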
32 changes: 32 additions & 0 deletions python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -70,6 +70,7 @@ def setUpClass(cls: Any):
cls.df_text = cls.sc.parallelize(cls.testDataStr).toDF()

cls.tbl_name = "test_connect_basic_table_1"
cls.tbl_name2 = "test_connect_basic_table_2"
cls.tbl_name_empty = "test_connect_basic_table_empty"

# Cleanup test data
@@ -90,6 +91,8 @@ def spark_connect_load_test_data(cls: Any):
# Since we might create multiple Spark sessions, we need to create global temporary view
# that is specifically maintained in the "global_temp" schema.
df.write.saveAsTable(cls.tbl_name)
df2 = cls.spark.createDataFrame([(x, f"{x}") for x in range(100)], ["col1", "col2"])
df2.write.saveAsTable(cls.tbl_name2)
empty_table_schema = StructType(
[
StructField("firstname", StringType(), True),
@@ -104,6 +107,7 @@ def spark_connect_load_test_data(cls: Any):
@classmethod
def spark_connect_clean_up_test_data(cls: Any) -> None:
cls.spark.sql("DROP TABLE IF EXISTS {}".format(cls.tbl_name))
cls.spark.sql("DROP TABLE IF EXISTS {}".format(cls.tbl_name2))
cls.spark.sql("DROP TABLE IF EXISTS {}".format(cls.tbl_name_empty))


@@ -114,6 +118,34 @@ def test_simple_read(self):
# Check that the limit is applied
self.assertEqual(len(data.index), 10)

def test_join_condition_column_list_columns(self):
left_connect_df = self.connect.read.table(self.tbl_name)
right_connect_df = self.connect.read.table(self.tbl_name2)
left_spark_df = self.spark.read.table(self.tbl_name)
right_spark_df = self.spark.read.table(self.tbl_name2)
joined_plan = left_connect_df.join(
other=right_connect_df, on=left_connect_df.id == right_connect_df.col1, how="inner"
)
joined_plan2 = left_spark_df.join(
other=right_spark_df, on=left_spark_df.id == right_spark_df.col1, how="inner"
)
self.assert_eq(joined_plan.toPandas(), joined_plan2.toPandas())

joined_plan3 = left_connect_df.join(
other=right_connect_df,
on=[
left_connect_df.id == right_connect_df.col1,
left_connect_df.name == right_connect_df.col2,
],
how="inner",
)
joined_plan4 = left_spark_df.join(
other=right_spark_df,
on=[left_spark_df.id == right_spark_df.col1, left_spark_df.name == right_spark_df.col2],
how="inner",
)
self.assert_eq(joined_plan3.toPandas(), joined_plan4.toPandas())

def test_columns(self):
# SPARK-41036: test `columns` API for python client.
df = self.connect.read.table(self.tbl_name)
5 changes: 5 additions & 0 deletions python/pyspark/sql/tests/connect/test_connect_plan_only.py
@@ -59,6 +59,11 @@ def test_join_condition(self):
other=right_input, on=left_input.name == right_input.name
)._plan.to_proto(self.connect)
self.assertIsNotNone(plan.root.join.join_condition)
plan = left_input.join(
other=right_input,
on=[left_input.name == right_input.name, left_input.age == right_input.age],
)._plan.to_proto(self.connect)
self.assertIsNotNone(plan.root.join.join_condition)

def test_crossjoin(self):
# SPARK-41227: Test CrossJoin
