diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index 3b9d9c7e6044e..7b894f7498a21 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -466,6 +466,7 @@ def __init__( def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = proto.Relation() + plan.deduplicate.input.CopyFrom(self._child.plan(session)) plan.deduplicate.all_columns_as_keys = self.all_columns_as_keys if self.column_names is not None: plan.deduplicate.column_names.extend(self.column_names) diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py index e7a14c4bb05f0..fc2a2a73f985b 100644 --- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py +++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py @@ -226,6 +226,8 @@ def test_deduplicate(self): df = self.connect.readTable(table_name=self.tbl_name) distinct_plan = df.distinct()._plan.to_proto(self.connect) + self.assertTrue(distinct_plan.root.deduplicate.HasField("input"), "input must be set") + self.assertEqual(distinct_plan.root.deduplicate.all_columns_as_keys, True) self.assertEqual(len(distinct_plan.root.deduplicate.column_names), 0)