From 74c826429416493a6d1d0efdf83b0e561dc33591 Mon Sep 17 00:00:00 2001
From: Rui Wang
Date: Mon, 24 Oct 2022 10:50:55 +0800
Subject: [PATCH] [SPARK-40812][CONNECT][PYTHON][FOLLOW-UP] Improve Deduplicate
 in Python client

### What changes were proposed in this pull request?

Following up on https://github.com/apache/spark/pull/38276, this PR improves both the `distinct()` and `dropDuplicates()` DataFrame APIs in the Python client, both of which depend on the `Deduplicate` plan in the Connect proto.

### Why are the changes needed?

Improve API coverage.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

UT

Closes #38327 from amaliujia/python_deduplicate.

Authored-by: Rui Wang
Signed-off-by: Wenchen Fan
---
 python/pyspark/sql/connect/dataframe.py          | 41 +++++++++++++++++--
 python/pyspark/sql/connect/plan.py               | 40 +++++++++++++++++++
 .../sql/tests/connect/test_connect_plan_only.py  | 19 +++++++++
 3 files changed, 96 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index eabcf433ae9bc..2b7e3d520391d 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -157,11 +157,44 @@ def coalesce(self, num_partitions: int) -> "DataFrame":
     def describe(self, cols: List[ColumnRef]) -> Any:
         ...
 
+    def dropDuplicates(self, subset: Optional[List[str]] = None) -> "DataFrame":
+        """Return a new :class:`DataFrame` with duplicate rows removed,
+        optionally only deduplicating based on certain columns.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        subset : List of column names, optional
+            List of columns to use for duplicate comparison (default: all columns).
+
+        Returns
+        -------
+        :class:`DataFrame`
+            DataFrame without duplicated rows.
+        """
+        if subset is None:
+            return DataFrame.withPlan(
+                plan.Deduplicate(child=self._plan, all_columns_as_keys=True), session=self._session
+            )
+        else:
+            return DataFrame.withPlan(
+                plan.Deduplicate(child=self._plan, column_names=subset), session=self._session
+            )
+
     def distinct(self) -> "DataFrame":
-        """Returns all distinct rows."""
-        all_cols = self.columns
-        gf = self.groupBy(*all_cols)
-        return gf.agg()
+        """Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`.
+
+        .. versionadded:: 3.4.0
+
+        Returns
+        -------
+        :class:`DataFrame`
+            DataFrame with distinct rows.
+ """ + return DataFrame.withPlan( + plan.Deduplicate(child=self._plan, all_columns_as_keys=True), session=self._session + ) def drop(self, *cols: "ColumnOrString") -> "DataFrame": all_cols = self.columns diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index 297b15994d3bc..d6b6f9e3b67dd 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -327,6 +327,45 @@ def _repr_html_(self) -> str: """ +class Deduplicate(LogicalPlan): + def __init__( + self, + child: Optional["LogicalPlan"], + all_columns_as_keys: bool = False, + column_names: Optional[List[str]] = None, + ) -> None: + super().__init__(child) + self.all_columns_as_keys = all_columns_as_keys + self.column_names = column_names + + def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation: + assert self._child is not None + plan = proto.Relation() + plan.deduplicate.all_columns_as_keys = self.all_columns_as_keys + if self.column_names is not None: + plan.deduplicate.column_names.extend(self.column_names) + return plan + + def print(self, indent: int = 0) -> str: + c_buf = self._child.print(indent + LogicalPlan.INDENT) if self._child else "" + return ( + f"{' ' * indent}\n{c_buf}" + ) + + def _repr_html_(self) -> str: + return f""" +
+        <ul>
+           <li>
+              <b>Deduplicate</b><br />
+              all_columns_as_keys: {self.all_columns_as_keys} <br />
+              column_names: {self.column_names} <br />
+              {self._child_repr_()}
+           </li>
+        </ul>
+ """ + + class Sort(LogicalPlan): def __init__( self, child: Optional["LogicalPlan"], *columns: Union[SortOrder, ColumnRef, str] diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py index 3b609db7a028d..450f5c70fabad 100644 --- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py +++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py @@ -72,6 +72,25 @@ def test_sample(self): self.assertEqual(plan.root.sample.with_replacement, True) self.assertEqual(plan.root.sample.seed.seed, -1) + def test_deduplicate(self): + df = self.connect.readTable(table_name=self.tbl_name) + + distinct_plan = df.distinct()._plan.to_proto(self.connect) + self.assertEqual(distinct_plan.root.deduplicate.all_columns_as_keys, True) + self.assertEqual(len(distinct_plan.root.deduplicate.column_names), 0) + + deduplicate_on_all_columns_plan = df.dropDuplicates()._plan.to_proto(self.connect) + self.assertEqual(deduplicate_on_all_columns_plan.root.deduplicate.all_columns_as_keys, True) + self.assertEqual(len(deduplicate_on_all_columns_plan.root.deduplicate.column_names), 0) + + deduplicate_on_subset_columns_plan = df.dropDuplicates(["name", "height"])._plan.to_proto( + self.connect + ) + self.assertEqual( + deduplicate_on_subset_columns_plan.root.deduplicate.all_columns_as_keys, False + ) + self.assertEqual(len(deduplicate_on_subset_columns_plan.root.deduplicate.column_names), 2) + def test_relation_alias(self): df = self.connect.readTable(table_name=self.tbl_name) plan = df.alias("table_alias")._plan.to_proto(self.connect)