[SPARK-43046][SS][CONNECT] Implemented Python API dropDuplicatesWithinWatermark for Spark Connect

### What changes were proposed in this pull request?

Implemented the `dropDuplicatesWithinWatermark` Python API for Spark Connect. This change is based on a previous [commit](0e9e34c) that introduced the `dropDuplicatesWithinWatermark` API in Spark.

### Why are the changes needed?

We recently introduced the `dropDuplicatesWithinWatermark` API in Spark ([commit link](0e9e34c)). We want to bring Spark Connect to parity with it.

### Does this PR introduce _any_ user-facing change?

Yes, this introduces a new public API, `dropDuplicatesWithinWatermark`, in Spark Connect.
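For illustration, here is a minimal usage sketch of the new API (the session, source, and column names are placeholders; the method only works on streaming DataFrames that have a watermark defined):

```python
# Minimal sketch: drop duplicate events that arrive within the watermark delay.
# Assumes `spark` is a SparkSession (classic or Connect) with streaming support.
stream = spark.readStream.format("rate").load()  # rate source columns: timestamp, value

deduped = (
    stream
    .withWatermark("timestamp", "10 minutes")   # event-time watermark
    .dropDuplicatesWithinWatermark(["value"])   # dedup keys; all columns if omitted
)

query = deduped.writeStream.format("console").start()
```

Calling the method on a batch DataFrame raises `AnalysisException`, which one of the new Python tests below exercises.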

### How was this patch tested?

Added new test cases to the Scala and Python test suites.

Closes #40834 from bogao007/drop-dup-watermark.

Authored-by: bogao007 <bo.gao@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
bogao007 authored and HyukjinKwon committed Apr 22, 2023
1 parent 069b48e commit 4d76511
Showing 10 changed files with 183 additions and 83 deletions.
@@ -362,6 +362,9 @@ message Deduplicate {
//
// This field does not co-use with `column_names`.
optional bool all_columns_as_keys = 3;

// (Optional) Deduplicate within the time range of watermark.
optional bool within_watermark = 4;
}

// A relation that does not need to be qualified by name.
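As a rough illustration of the wire-format change, the new proto3 optional field can be set on the generated Python message like any other field. This is only a sketch and assumes a PySpark build whose `relations_pb2` has been regenerated from this definition:

```python
# Sketch: build a Deduplicate relation message with the new flag set.
from pyspark.sql.connect.proto import relations_pb2

msg = relations_pb2.Deduplicate(
    column_names=["id", "name"],  # deduplication keys
    within_watermark=True,        # new optional field (tag 4)
)
assert msg.HasField("within_watermark") and msg.within_watermark
```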
@@ -596,6 +596,17 @@ package object dsl {
.addAllColumnNames(colNames.asJava))
.build()

def deduplicateWithinWatermark(colNames: Seq[String]): Relation =
Relation
.newBuilder()
.setDeduplicate(
Deduplicate
.newBuilder()
.setInput(logicalPlan)
.addAllColumnNames(colNames.asJava)
.setWithinWatermark(true))
.build()

def distinct(): Relation =
Relation
.newBuilder()
@@ -46,7 +46,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils}
import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.{CollectMetrics, CommandResult, Deduplicate, DeserializeToObject, Except, Intersect, LocalRelation, LogicalPlan, MapPartitions, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TypedFilter, Union, Unpivot, UnresolvedHint}
import org.apache.spark.sql.catalyst.plans.logical.{CollectMetrics, CommandResult, Deduplicate, DeduplicateWithinWatermark, DeserializeToObject, Except, Intersect, LocalRelation, LogicalPlan, MapPartitions, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TypedFilter, Union, Unpivot, UnresolvedHint}
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils}
import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput, LiteralValueProtoConverter, StorageLevelProtoConverter, UdfPacket}
@@ -738,7 +738,8 @@ class SparkConnectPlanner(val session: SparkSession) {
val resolver = session.sessionState.analyzer.resolver
val allColumns = queryExecution.analyzed.output
if (rel.getAllColumnsAsKeys) {
Deduplicate(allColumns, queryExecution.analyzed)
if (rel.getWithinWatermark) DeduplicateWithinWatermark(allColumns, queryExecution.analyzed)
else Deduplicate(allColumns, queryExecution.analyzed)
} else {
val toGroupColumnNames = rel.getColumnNamesList.asScala.toSeq
val groupCols = toGroupColumnNames.flatMap { (colName: String) =>
@@ -750,7 +751,8 @@
}
cols
}
Deduplicate(groupCols, queryExecution.analyzed)
if (rel.getWithinWatermark) DeduplicateWithinWatermark(groupCols, queryExecution.analyzed)
else Deduplicate(groupCols, queryExecution.analyzed)
}
}

@@ -381,6 +381,36 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
assert(e2.getMessage.contains("either deduplicate on all columns or a subset of columns"))
}

test("Test invalid deduplicateWithinWatermark") {
val deduplicateWithinWatermark = proto.Deduplicate
.newBuilder()
.setInput(readRel)
.setAllColumnsAsKeys(true)
.addColumnNames("test")
.setWithinWatermark(true)

val e = intercept[InvalidPlanInput] {
transform(
proto.Relation.newBuilder
.setDeduplicate(deduplicateWithinWatermark)
.build())
}
assert(
e.getMessage.contains("Cannot deduplicate on both all columns and a subset of columns"))

val deduplicateWithinWatermark2 = proto.Deduplicate
.newBuilder()
.setInput(readRel)
.setWithinWatermark(true)
val e2 = intercept[InvalidPlanInput] {
transform(
proto.Relation.newBuilder
.setDeduplicate(deduplicateWithinWatermark2)
.build())
}
assert(e2.getMessage.contains("either deduplicate on all columns or a subset of columns"))
}

test("Test invalid intersect, except") {
// Except with union_by_name=true
val except = proto.SetOperation
@@ -348,6 +348,16 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
comparePlans(connectPlan2, sparkPlan2)
}

test("Test basic deduplicateWithinWatermark") {
val connectPlan = connectTestRelation.distinct()
val sparkPlan = sparkTestRelation.distinct()
comparePlans(connectPlan, sparkPlan)

val connectPlan2 = connectTestRelation.deduplicateWithinWatermark(Seq("id", "name"))
val sparkPlan2 = sparkTestRelation.dropDuplicatesWithinWatermark(Seq("id", "name"))
comparePlans(connectPlan2, sparkPlan2)
}

test("Test union, except, intersect") {
val connectPlan1 = connectTestRelation.except(connectTestRelation, isAll = false)
val sparkPlan1 = sparkTestRelation.except(sparkTestRelation)
22 changes: 20 additions & 2 deletions python/pyspark/sql/connect/dataframe.py
@@ -379,7 +379,26 @@ def dropDuplicates(self, subset: Optional[List[str]] = None) -> "DataFrame":
drop_duplicates = dropDuplicates

def dropDuplicatesWithinWatermark(self, subset: Optional[List[str]] = None) -> "DataFrame":
raise NotImplementedError("dropDuplicatesWithinWatermark() is not implemented.")
if subset is not None and not isinstance(subset, (list, tuple)):
raise PySparkTypeError(
error_class="NOT_LIST_OR_TUPLE",
message_parameters={"arg_name": "subset", "arg_type": type(subset).__name__},
)

if subset is None:
return DataFrame.withPlan(
plan.Deduplicate(child=self._plan, all_columns_as_keys=True, within_watermark=True),
session=self._session,
)
else:
return DataFrame.withPlan(
plan.Deduplicate(child=self._plan, column_names=subset, within_watermark=True),
session=self._session,
)

dropDuplicatesWithinWatermark.__doc__ = PySparkDataFrame.dropDuplicatesWithinWatermark.__doc__

drop_duplicates_within_watermark = dropDuplicatesWithinWatermark

def distinct(self) -> "DataFrame":
return DataFrame.withPlan(
@@ -595,7 +614,6 @@ def sample(
fraction: Optional[Union[int, float]] = None,
seed: Optional[int] = None,
) -> "DataFrame":

# For the cases below:
# sample(True, 0.5 [, seed])
# sample(True, fraction=0.5 [, seed])
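To recap the branching above, here is a short sketch of the call shapes on the client side; `sdf` stands for a hypothetical streaming DataFrame that already has a watermark set:

```python
# Each call builds a Deduplicate plan with within_watermark=True.
sdf.dropDuplicatesWithinWatermark()                   # all columns as keys
sdf.dropDuplicatesWithinWatermark(["id", "name"])     # subset of columns as keys
sdf.drop_duplicates_within_watermark(["id", "name"])  # snake_case alias

# A subset that is not a list or tuple is rejected before a plan is built:
# sdf.dropDuplicatesWithinWatermark("id")  # raises PySparkTypeError (NOT_LIST_OR_TUPLE)
```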
3 changes: 3 additions & 0 deletions python/pyspark/sql/connect/plan.py
@@ -609,16 +609,19 @@ def __init__(
child: Optional["LogicalPlan"],
all_columns_as_keys: bool = False,
column_names: Optional[List[str]] = None,
within_watermark: bool = False,
) -> None:
super().__init__(child)
self.all_columns_as_keys = all_columns_as_keys
self.column_names = column_names
self.within_watermark = within_watermark

def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.deduplicate.input.CopyFrom(self._child.plan(session))
plan.deduplicate.all_columns_as_keys = self.all_columns_as_keys
plan.deduplicate.within_watermark = self.within_watermark
if self.column_names is not None:
plan.deduplicate.column_names.extend(self.column_names)
return plan
152 changes: 76 additions & 76 deletions python/pyspark/sql/connect/proto/relations_pb2.py

Large diffs are not rendered by default.

17 changes: 17 additions & 0 deletions python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -1436,6 +1436,7 @@ class Deduplicate(google.protobuf.message.Message):
INPUT_FIELD_NUMBER: builtins.int
COLUMN_NAMES_FIELD_NUMBER: builtins.int
ALL_COLUMNS_AS_KEYS_FIELD_NUMBER: builtins.int
WITHIN_WATERMARK_FIELD_NUMBER: builtins.int
@property
def input(self) -> global___Relation:
"""(Required) Input relation for a Deduplicate."""
@@ -1452,41 +1453,57 @@
This field does not co-use with `column_names`.
"""
within_watermark: builtins.bool
"""(Optional) Deduplicate within the time range of watermark."""
def __init__(
self,
*,
input: global___Relation | None = ...,
column_names: collections.abc.Iterable[builtins.str] | None = ...,
all_columns_as_keys: builtins.bool | None = ...,
within_watermark: builtins.bool | None = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
"_all_columns_as_keys",
b"_all_columns_as_keys",
"_within_watermark",
b"_within_watermark",
"all_columns_as_keys",
b"all_columns_as_keys",
"input",
b"input",
"within_watermark",
b"within_watermark",
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"_all_columns_as_keys",
b"_all_columns_as_keys",
"_within_watermark",
b"_within_watermark",
"all_columns_as_keys",
b"all_columns_as_keys",
"column_names",
b"column_names",
"input",
b"input",
"within_watermark",
b"within_watermark",
],
) -> None: ...
@typing.overload
def WhichOneof(
self,
oneof_group: typing_extensions.Literal["_all_columns_as_keys", b"_all_columns_as_keys"],
) -> typing_extensions.Literal["all_columns_as_keys"] | None: ...
@typing.overload
def WhichOneof(
self, oneof_group: typing_extensions.Literal["_within_watermark", b"_within_watermark"]
) -> typing_extensions.Literal["within_watermark"] | None: ...

global___Deduplicate = Deduplicate

10 changes: 8 additions & 2 deletions python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -1213,6 +1213,14 @@ def test_deduplicate(self):
df.dropDuplicates(["name"]).toPandas(), df2.dropDuplicates(["name"]).toPandas()
)

def test_deduplicate_within_watermark_in_batch(self):
df = self.connect.read.table(self.tbl_name)
with self.assertRaisesRegex(
AnalysisException,
"dropDuplicatesWithinWatermark is not supported with batch DataFrames/DataSets",
):
df.dropDuplicatesWithinWatermark().toPandas()

def test_first(self):
# SPARK-41002: test `first` API in Python Client
df = self.connect.read.table(self.tbl_name)
@@ -1761,7 +1769,6 @@ def test_hint(self):
self.connect.read.table(self.tbl_name).hint("REPARTITION", "id", 3).toPandas()

def test_join_hint(self):

cdf1 = self.connect.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"])
cdf2 = self.connect.createDataFrame(
[Row(height=80, name="Tom"), Row(height=85, name="Bob")]
@@ -2284,7 +2291,6 @@ def test_crossjoin(self):
)

def test_grouped_data(self):

query = """
SELECT * FROM VALUES
('James', 'Sales', 3000, 2020),
