Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.internal.SQLConf

/**
* Replaces a logical [[NearestByJoin]] operator with a `Generate(Inline(...))` over an
Expand Down Expand Up @@ -71,7 +72,9 @@ import org.apache.spark.sql.catalyst.rules._
object RewriteNearestByJoin extends Rule[LogicalPlan] {
private lazy val random = new scala.util.Random()

def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
def apply(plan: LogicalPlan): LogicalPlan = {
if (SQLConf.get.nearestByStreamingHeapEnabled) return plan
plan.transformUp {
case j @ NearestByJoin(left, right, joinType, _, numResults, rankingExpression, direction) =>
// 1. Tag each left row with a unique id so that rows from the same left row can later be
// grouped together after the cross-join with `right`.
Expand Down Expand Up @@ -151,4 +154,5 @@ object RewriteNearestByJoin extends Rule[LogicalPlan] {
// 5. Final `Project` pinning the output schema to `NearestByJoin.output`.
Project(j.output, generate)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2321,6 +2321,15 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

// Internal feature flag (default: false) selecting how a NearestByJoin is executed.
// When true, RewriteNearestByJoin leaves the plan untouched and the NearestByJoinSelection
// strategy plans a StreamingNearestByJoinExec instead of the cross-product + aggregate rewrite.
val NEAREST_BY_STREAMING_HEAP_ENABLED =
buildConf("spark.sql.join.nearestBy.streamingHeap.enabled")
.internal()
.doc("When true, NearestByJoin uses a streaming heap operator instead of the " +
"cross-product + aggregate rewrite.")
.version("4.1.0")
.booleanConf
.createWithDefault(false)

val ORDER_BY_ORDINAL = buildConf("spark.sql.orderByOrdinal")
.doc("When true, the ordinal numbers are treated as the position in the select list. " +
"When false, the ordinal numbers in order/sort by clause are ignored.")
Expand Down Expand Up @@ -8157,6 +8166,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {

def crossJoinEnabled: Boolean = getConf(SQLConf.CROSS_JOINS_ENABLED)

/** Current value of [[SQLConf.NEAREST_BY_STREAMING_HEAP_ENABLED]]. */
def nearestByStreamingHeapEnabled: Boolean = getConf(SQLConf.NEAREST_BY_STREAMING_HEAP_ENABLED)

override def sessionLocalTimeZone: String = getConf(SQLConf.SESSION_LOCAL_TIMEZONE)

def jsonGeneratorIgnoreNullFields: Boolean = getConf(SQLConf.JSON_GENERATOR_IGNORE_NULL_FIELDS)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class SparkPlanner(val session: SparkSession, val experimentalMethods: Experimen
Window ::
WindowGroupLimit ::
JoinSelection ::
NearestByJoinSelection ::
InMemoryScans ::
SparkScripts ::
Pipelines ::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,21 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
* Supports both equi-joins and non-equi-joins.
* Supports only inner like joins.
*/

/**
 * Plans a logical [[NearestByJoin]] as a `StreamingNearestByJoinExec` when
 * `conf.nearestByStreamingHeapEnabled` is true. Every other plan, and every [[NearestByJoin]]
 * while the flag is off, yields no physical plan from this strategy.
 *
 * NOTE(review): this object appears to be placed between the scaladoc of another strategy
 * (the comment ending just above) and `JoinSelection`; consider moving it so that comment
 * stays attached to the object it describes.
 */
object NearestByJoinSelection extends Strategy {
  def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case nearest: NearestByJoin if conf.nearestByStreamingHeapEnabled =>
      val physical = joins.StreamingNearestByJoinExec(
        left = planLater(nearest.left),
        right = planLater(nearest.right),
        joinType = nearest.joinType,
        numResults = nearest.numResults,
        rankingExpression = nearest.rankingExpression,
        direction = nearest.direction)
      physical :: Nil

    case _ =>
      Nil
  }
}

object JoinSelection extends Strategy with JoinSelectionHelper {
private val hintErrorHandler = conf.hintErrorHandler

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.joins

import java.util.{PriorityQueue => JPriorityQueue}

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, NearestByDirection, NearestByDistance}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics

/**
 * Physical operator for NearestByJoin that avoids materializing the full cross product.
 *
 * The right side is broadcast to every task as an `Array[InternalRow]`. For each left row the
 * operator scans all right rows, evaluates `rankingExpression` on the joined (left, right) row
 * and keeps the best `numResults` candidates in a bounded heap, then emits them best-first.
 *
 * Right rows whose ranking value is NULL or NaN are never considered a match. With
 * [[LeftOuter]], a left row that ends up with no match is emitted once with all right columns
 * set to null; with any other join type such a left row produces no output.
 *
 * NOTE(review): the ranking value is read with `getDouble`, so this operator assumes
 * `rankingExpression` evaluates to a double -- confirm that the analyzer guarantees this.
 *
 * @param left              streamed side; its partitioning is preserved
 * @param right             side that is broadcast and scanned once per left row
 * @param joinType          only [[LeftOuter]] changes behaviour (unmatched left rows are kept)
 * @param numResults        maximum number of right rows emitted per left row (k)
 * @param rankingExpression expression over left ++ right attributes producing the ranking value
 * @param direction         [[NearestByDistance]] keeps the k smallest ranking values; any other
 *                          direction keeps the k largest
 */
case class StreamingNearestByJoinExec(
    left: SparkPlan,
    right: SparkPlan,
    joinType: JoinType,
    numResults: Int,
    rankingExpression: Expression,
    direction: NearestByDirection) extends BinaryExecNode {

  /** Left columns followed by right columns, all marked nullable. */
  override def output: Seq[Attribute] =
    left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))

  override lazy val metrics = Map(
    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))

  /** The left side has no distribution requirement; the right side must be broadcast as rows. */
  override def requiredChildDistribution: Seq[Distribution] =
    UnspecifiedDistribution :: BroadcastDistribution(IdentityBroadcastMode) :: Nil

  override def outputPartitioning: Partitioning = left.outputPartitioning

  // Matches are emitted best-first per left row, but no ordering is declared to the planner.
  override def outputOrdering: Seq[SortOrder] = Nil

  protected override def doExecute(): RDD[InternalRow] = {
    val broadcastedRight = right.executeBroadcast[Array[InternalRow]]()
    val numOutput = longMetric("numOutputRows")
    val k = numResults
    val isDistance = direction == NearestByDistance
    val keepUnmatchedLeft = joinType == LeftOuter
    val leftOutput = left.output
    val rightOutput = right.output
    val rankExpr = rankingExpression
    val allOutput = output

    left.execute().mapPartitionsInternal { leftIter =>
      val rightRows = broadcastedRight.value
      if (rightRows.isEmpty && !keepUnmatchedLeft) {
        // Nothing to match against and unmatched left rows are not preserved.
        Iterator.empty
      } else {
        val joinedRow = new JoinedRow
        val rankingProj = UnsafeProjection.create(Seq(rankExpr), leftOutput ++ rightOutput)
        val resultProj = UnsafeProjection.create(allOutput, allOutput)
        // Shared all-null right side for unmatched left rows. Sharing is safe because every
        // emitted row is a copy of the projection result.
        val nullRight = new GenericInternalRow(rightOutput.size)

        // The heap head is always the *worst* retained candidate, so that polling after an
        // insert that overflows k drops exactly the candidate that should not be kept:
        //  - distance:  head = largest ranking value  -> the k smallest survive
        //  - otherwise: head = smallest ranking value -> the k largest survive
        val worstFirst = new java.util.Comparator[(InternalRow, Double)] {
          override def compare(a: (InternalRow, Double), b: (InternalRow, Double)): Int =
            if (isDistance) java.lang.Double.compare(b._2, a._2)
            else java.lang.Double.compare(a._2, b._2)
        }

        leftIter.flatMap { leftRow =>
          val heap = new JPriorityQueue[(InternalRow, Double)](k + 1, worstFirst)

          var i = 0
          while (i < rightRows.length) {
            val rightRow = rightRows(i)
            joinedRow(leftRow, rightRow)
            val rank = rankingProj(joinedRow)
            // A NULL ranking value must be checked explicitly: reading it with getDouble would
            // yield an arbitrary number that then competes with real candidates.
            if (!rank.isNullAt(0)) {
              val rankingValue = rank.getDouble(0)
              if (!java.lang.Double.isNaN(rankingValue)) {
                // Broadcast rows are only read by this operator, so the heap keeps references
                // instead of copying every candidate.
                heap.offer((rightRow, rankingValue))
                if (heap.size() > k) {
                  heap.poll()
                }
              }
            }
            i += 1
          }

          if (heap.isEmpty) {
            if (keepUnmatchedLeft) {
              // Emit the left row once, padded with null right columns.
              joinedRow(leftRow, nullRight)
              numOutput += 1
              Iterator.single(resultProj(joinedRow).copy())
            } else {
              Iterator.empty
            }
          } else {
            // poll() returns worst-first; fill the array back to front to get best-first.
            val matches = new Array[InternalRow](heap.size())
            var idx = matches.length - 1
            while (idx >= 0) {
              matches(idx) = heap.poll()._1
              idx -= 1
            }
            matches.iterator.map { rightRow =>
              joinedRow(leftRow, rightRow)
              numOutput += 1
              // The projection reuses its output buffer, so each emitted row is copied.
              resultProj(joinedRow).copy()
            }
          }
        }
      }
    }
  }

  override protected def withNewChildrenInternal(
      newLeft: SparkPlan, newRight: SparkPlan): StreamingNearestByJoinExec =
    copy(left = newLeft, right = newRight)
}
Loading