2 changes: 2 additions & 0 deletions dev/sparktestsupport/modules.py
@@ -612,6 +612,7 @@ def __hash__(self):
"pyspark.sql.tests.test_readwriter",
"pyspark.sql.tests.test_serde",
"pyspark.sql.tests.test_session",
"pyspark.sql.tests.test_nearest_by_join",
"pyspark.sql.tests.test_subquery",
"pyspark.sql.tests.test_types",
"pyspark.sql.tests.test_geographytype",
@@ -1174,6 +1175,7 @@ def __hash__(self):
"pyspark.sql.tests.connect.test_parity_observation",
"pyspark.sql.tests.connect.test_parity_repartition",
"pyspark.sql.tests.connect.test_parity_stat",
"pyspark.sql.tests.connect.test_parity_nearest_by_join",
"pyspark.sql.tests.connect.test_parity_subquery",
"pyspark.sql.tests.connect.test_parity_types",
"pyspark.sql.tests.connect.test_parity_column",
4 changes: 3 additions & 1 deletion project/MimaExcludes.scala
@@ -60,7 +60,9 @@ object MimaExcludes {
// [SPARK-56330][CORE] Add TaskInterruptListener to TaskContext for interrupt notifications
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.addTaskInterruptListener"),
// [SPARK-56700][SS] Make DataStreamReader.name public
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.DataStreamReader.name")
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.DataStreamReader.name"),
// [SPARK-56395][SQL] Add NEAREST BY top-K ranking join
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.Dataset.nearestByJoin")
)

// Exclude rules for 4.1.x from 4.0.0
1 change: 1 addition & 0 deletions python/docs/source/reference/pyspark.sql/dataframe.rst
@@ -84,6 +84,7 @@ DataFrame
DataFrame.metadataColumn
DataFrame.melt
DataFrame.na
DataFrame.nearestByJoin
DataFrame.observe
DataFrame.offset
DataFrame.orderBy
28 changes: 28 additions & 0 deletions python/pyspark/errors/error-conditions.json
@@ -602,6 +602,34 @@
"Multiple pipeline spec files found in the directory `<dir_path>`. Please remove one or choose a particular one with the --spec argument."
]
},
"NEAREST_BY_JOIN": {
"message": [
"Invalid nearest-by join."
],
"sub_class": {
"NUM_RESULTS_OUT_OF_RANGE": {
"message": [
"The number of results <numResults> must be between <min> and <max>. Update the literal in `APPROX NEAREST <numResults> BY ...` (or `EXACT NEAREST <numResults> BY ...`) to fall within that range."
]
},
"UNSUPPORTED_DIRECTION": {
"message": [
"Unsupported nearest-by join direction '<direction>'. Supported nearest-by join directions include: <supported>."
]
},
"UNSUPPORTED_JOIN_TYPE": {
"message": [
"Unsupported nearest-by join type <joinType>. Supported types: <supported>."
]
},
"UNSUPPORTED_MODE": {
"message": [
"Unsupported nearest-by join mode '<mode>'. Supported modes include: <supported>."
]
}
},
"sqlState": "42604"
},
"NEGATIVE_VALUE": {
"message": [
"Value for `<arg_name>` must be greater than or equal to 0, got '<arg_value>'."
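For a concrete sense of how these conditions surface, here is a hypothetical reproduction (not a test from this PR): the inputs are throwaway stand-ins, and validation fires before the ranking expression is ever evaluated, so F.rand() is just a placeholder.

from pyspark.errors import AnalysisException
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
left = spark.range(3)
right = spark.range(3)

try:
    # numResults=0 violates the 1..100000 range the validator enforces.
    left.nearestByJoin(right, F.rand(), numResults=0, mode="exact", direction="distance")
except AnalysisException as e:
    # Prints the condition name, e.g. NEAREST_BY_JOIN.NUM_RESULTS_OUT_OF_RANGE;
    # the full message is rendered from the templates above.
    print(e.getErrorClass())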
15 changes: 15 additions & 0 deletions python/pyspark/sql/classic/dataframe.py
@@ -820,6 +820,21 @@ def lateralJoin(
jdf = self._jdf.lateralJoin(other._jdf, on._jc, how)
return DataFrame(jdf, self.sparkSession)

def nearestByJoin(
self,
other: ParentDataFrame,
rankingExpression: Column,
numResults: int,
mode: str,
direction: str,
*,
joinType: str = "inner",
) -> ParentDataFrame:
jdf = self._jdf.nearestByJoin(
other._jdf, rankingExpression._jc, int(numResults), mode, direction, joinType
)
return DataFrame(jdf, self.sparkSession)

# TODO(SPARK-22947): Fix the DataFrame API.
def _joinAsOf(
self,
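For context, a minimal end-to-end sketch of the new classic API. The data, column names, and ranking expression below are illustrative assumptions; the PR's own tests (pyspark.sql.tests.test_nearest_by_join) are the authoritative examples.

from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
queries = spark.createDataFrame([(1, 0.2), (2, 0.8)], ["qid", "q"])
items = spark.createDataFrame([(10, 0.25), (11, 0.9), (12, 0.5)], ["iid", "v"])

# For each left row, keep the two right rows with the smallest ranking value.
top2 = queries.nearestByJoin(
    items,
    rankingExpression=F.abs(queries.q - items.v),
    numResults=2,
    mode="exact",
    direction="distance",
    joinType="inner",
)
top2.show()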
24 changes: 24 additions & 0 deletions python/pyspark/sql/connect/dataframe.py
@@ -726,6 +726,30 @@ def lateralJoin(
session=self._session,
)

def nearestByJoin(
[Review comment, Contributor] we need Spark connect tests for nearestByJoin - see lateralJoin tests in DataFrameSubquerySuite and PlanGenerationTestSuite

self,
other: ParentDataFrame,
rankingExpression: Column,
numResults: int,
mode: str,
direction: str,
*,
joinType: str = "inner",
) -> ParentDataFrame:
other = self._check_same_session(other)
return DataFrame(
plan.NearestByJoin(
left=self._plan,
right=other._plan,
ranking_expression=rankingExpression,
num_results=int(numResults),
join_type=joinType,
mode=mode,
direction=direction,
),
session=self._session,
)

def _joinAsOf(
self,
other: ParentDataFrame,
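In the spirit of the review comment above, a sketch of what the Connect parity test module could look like, following the existing test_parity_* pattern (compare test_parity_subquery). The mixin name NearestByJoinTestsMixin is an assumption about how pyspark.sql.tests.test_nearest_by_join is organized.

# python/pyspark/sql/tests/connect/test_parity_nearest_by_join.py (sketch)
import unittest

from pyspark.sql.tests.test_nearest_by_join import NearestByJoinTestsMixin  # assumed name
from pyspark.testing.connectutils import ReusedConnectTestCase


class NearestByJoinParityTests(NearestByJoinTestsMixin, ReusedConnectTestCase):
    pass


if __name__ == "__main__":
    unittest.main()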
102 changes: 102 additions & 0 deletions python/pyspark/sql/connect/plan.py
@@ -1345,6 +1345,108 @@ def _repr_html_(self) -> str:
"""


# Acceptance lists for `nearestByJoin`. Must stay aligned with `NearestByJoinValidation` in
# `sql/api/.../catalyst/plans/NearestByJoinValidation.scala`.
_NEAREST_BY_JOIN_MAX_NUM_RESULTS = 100000
_NEAREST_BY_JOIN_SUPPORTED_JOIN_TYPES = frozenset({"inner", "leftouter", "left"})
_NEAREST_BY_JOIN_SUPPORTED_JOIN_TYPE_DISPLAY = "'INNER', 'LEFT OUTER'"
_NEAREST_BY_JOIN_SUPPORTED_MODES = ("approx", "exact")
_NEAREST_BY_JOIN_SUPPORTED_DIRECTIONS = ("distance", "similarity")


class NearestByJoin(LogicalPlan):
def __init__(
self,
left: Optional[LogicalPlan],
right: LogicalPlan,
ranking_expression: Column,
num_results: int,
join_type: str,
mode: str,
direction: str,
) -> None:
super().__init__(left, self._collect_references([ranking_expression]))
self.left = cast(LogicalPlan, left)
self.right = right
self.ranking_expression = ranking_expression
# Mirror of the Scala `Dataset.validateNearestByJoinArgs` validator -- raises the same
# `NEAREST_BY_JOIN.*` error classes the server would, so the user sees a consistent
# error regardless of where the check fires.
if num_results < 1 or num_results > _NEAREST_BY_JOIN_MAX_NUM_RESULTS:
raise AnalysisException(
errorClass="NEAREST_BY_JOIN.NUM_RESULTS_OUT_OF_RANGE",
messageParameters={
"numResults": str(num_results),
"min": "1",
"max": str(_NEAREST_BY_JOIN_MAX_NUM_RESULTS),
},
)
if join_type.lower().replace("_", "") not in _NEAREST_BY_JOIN_SUPPORTED_JOIN_TYPES:
raise AnalysisException(
errorClass="NEAREST_BY_JOIN.UNSUPPORTED_JOIN_TYPE",
messageParameters={
"joinType": join_type,
"supported": _NEAREST_BY_JOIN_SUPPORTED_JOIN_TYPE_DISPLAY,
},
)
if mode.lower() not in _NEAREST_BY_JOIN_SUPPORTED_MODES:
raise AnalysisException(
errorClass="NEAREST_BY_JOIN.UNSUPPORTED_MODE",
messageParameters={
"mode": mode,
"supported": "'" + "', '".join(_NEAREST_BY_JOIN_SUPPORTED_MODES) + "'",
},
)
if direction.lower() not in _NEAREST_BY_JOIN_SUPPORTED_DIRECTIONS:
raise AnalysisException(
errorClass="NEAREST_BY_JOIN.UNSUPPORTED_DIRECTION",
messageParameters={
"direction": direction,
"supported": "'" + "', '".join(_NEAREST_BY_JOIN_SUPPORTED_DIRECTIONS) + "'",
},
)
self.num_results = int(num_results)
self.join_type = join_type
self.mode = mode
self.direction = direction

def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
plan.nearest_by_join.left.CopyFrom(self.left.plan(session))
plan.nearest_by_join.right.CopyFrom(self.right.plan(session))
plan.nearest_by_join.ranking_expression.CopyFrom(self.ranking_expression.to_plan(session))
plan.nearest_by_join.num_results = self.num_results
plan.nearest_by_join.join_type = self.join_type
plan.nearest_by_join.mode = self.mode
plan.nearest_by_join.direction = self.direction
return self._with_relations(plan, session)

@property
def observations(self) -> Dict[str, "Observation"]:
return {**super().observations, **self.right.observations}

def print(self, indent: int = 0) -> str:
i = " " * indent
o = " " * (indent + LogicalPlan.INDENT)
n = indent + LogicalPlan.INDENT * 2
return (
f"{i}<NearestByJoin numResults={self.num_results} joinType={self.join_type} "
f"mode={self.mode} direction={self.direction}>\n{o}"
f"left=\n{self.left.print(n)}\n{o}right=\n{self.right.print(n)}"
)

def _repr_html_(self) -> str:
return f"""
<ul>
<li>
<b>NearestByJoin</b><br />
Left: {self.left._repr_html_()}
Right: {self.right._repr_html_()}
</li>
</ul>
"""


class SetOperation(LogicalPlan):
def __init__(
self,
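One detail worth calling out: the lower().replace("_", "") normalization above makes the client-side join-type check case- and underscore-insensitive. A standalone sketch of what that accepts and rejects:

# Mirrors the normalization in NearestByJoin.__init__; runnable on its own.
_SUPPORTED = frozenset({"inner", "leftouter", "left"})

for spelling in ("inner", "INNER", "left", "left_outer", "leftOuter", "LEFT_OUTER"):
    assert spelling.lower().replace("_", "") in _SUPPORTED

# Anything else, e.g. "full_outer", trips NEAREST_BY_JOIN.UNSUPPORTED_JOIN_TYPE.
assert "full_outer".lower().replace("_", "") not in _SUPPORTED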
350 changes: 176 additions & 174 deletions python/pyspark/sql/connect/proto/relations_pb2.py

(Large diff omitted: regenerated protobuf bindings.)

85 changes: 85 additions & 0 deletions python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -111,6 +111,7 @@ class Relation(google.protobuf.message.Message):
LATERAL_JOIN_FIELD_NUMBER: builtins.int
CHUNKED_CACHED_LOCAL_RELATION_FIELD_NUMBER: builtins.int
RELATION_CHANGES_FIELD_NUMBER: builtins.int
NEAREST_BY_JOIN_FIELD_NUMBER: builtins.int
FILL_NA_FIELD_NUMBER: builtins.int
DROP_NA_FIELD_NUMBER: builtins.int
REPLACE_FIELD_NUMBER: builtins.int
@@ -223,6 +224,8 @@
@property
def relation_changes(self) -> global___RelationChanges: ...
@property
def nearest_by_join(self) -> global___NearestByJoin: ...
@property
def fill_na(self) -> global___NAFill:
"""NA functions"""
@property
@@ -310,6 +313,7 @@
lateral_join: global___LateralJoin | None = ...,
chunked_cached_local_relation: global___ChunkedCachedLocalRelation | None = ...,
relation_changes: global___RelationChanges | None = ...,
nearest_by_join: global___NearestByJoin | None = ...,
fill_na: global___NAFill | None = ...,
drop_na: global___NADrop | None = ...,
replace: global___NAReplace | None = ...,
@@ -395,6 +399,8 @@
b"map_partitions",
"ml_relation",
b"ml_relation",
"nearest_by_join",
b"nearest_by_join",
"offset",
b"offset",
"parse",
@@ -524,6 +530,8 @@
b"map_partitions",
"ml_relation",
b"ml_relation",
"nearest_by_join",
b"nearest_by_join",
"offset",
b"offset",
"parse",
@@ -633,6 +641,7 @@
"lateral_join",
"chunked_cached_local_relation",
"relation_changes",
"nearest_by_join",
"fill_na",
"drop_na",
"replace",
@@ -4657,3 +4666,79 @@ class LateralJoin(google.protobuf.message.Message):
) -> None: ...

global___LateralJoin = LateralJoin

class NearestByJoin(google.protobuf.message.Message):
"""Relation of type [[NearestByJoin]].

For each row on the left side, returns up to `num_results` rows from the right side ordered
by `ranking_expression`.
"""

DESCRIPTOR: google.protobuf.descriptor.Descriptor

LEFT_FIELD_NUMBER: builtins.int
RIGHT_FIELD_NUMBER: builtins.int
RANKING_EXPRESSION_FIELD_NUMBER: builtins.int
NUM_RESULTS_FIELD_NUMBER: builtins.int
JOIN_TYPE_FIELD_NUMBER: builtins.int
MODE_FIELD_NUMBER: builtins.int
DIRECTION_FIELD_NUMBER: builtins.int
@property
def left(self) -> global___Relation:
"""(Required) Left (query) input relation."""
@property
def right(self) -> global___Relation:
"""(Required) Right (base) input relation."""
@property
def ranking_expression(self) -> pyspark.sql.connect.proto.expressions_pb2.Expression:
"""(Required) Scalar expression used to rank candidate rows on the right side."""
num_results: builtins.int
"""(Required) Maximum number of matches per left row. Must be between 1 and 100000."""
join_type: builtins.str
"""The following three fields use `string` (not typed enums) for parity with `AsOfJoin`,
which models analogous fields the same way. Validation happens server-side at planning time.

(Required) The join type. Must be one of: "inner", "leftouter".
"""
mode: builtins.str
"""(Required) Search algorithm contract. Must be one of: "approx", "exact"."""
direction: builtins.str
"""(Required) Ranking direction. Must be one of: "distance", "similarity"."""
def __init__(
self,
*,
left: global___Relation | None = ...,
right: global___Relation | None = ...,
ranking_expression: pyspark.sql.connect.proto.expressions_pb2.Expression | None = ...,
num_results: builtins.int = ...,
join_type: builtins.str = ...,
mode: builtins.str = ...,
direction: builtins.str = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
"left", b"left", "ranking_expression", b"ranking_expression", "right", b"right"
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"direction",
b"direction",
"join_type",
b"join_type",
"left",
b"left",
"mode",
b"mode",
"num_results",
b"num_results",
"ranking_expression",
b"ranking_expression",
"right",
b"right",
],
) -> None: ...

global___NearestByJoin = NearestByJoin
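To make the wire shape concrete, a small sketch that builds the message directly from the generated bindings. The child relations and the ranking Expression are left unset here because assembling them by hand is verbose; the NearestByJoin plan in plan.py above does that wiring.

from pyspark.sql.connect.proto import relations_pb2

msg = relations_pb2.NearestByJoin(
    num_results=10,
    join_type="leftouter",
    mode="approx",
    direction="similarity",
)
print(msg.HasField("left"))  # False: message-typed fields start unset
print(msg.num_results)       # 10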