diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoin.scala index 3d45855cd60da..e920bbfffc550 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoin.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoin.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.rules._ * +- Aggregate [__qid], * [first(left.col0) AS left.col0, ..., first(left.colN-1) AS left.colN-1, * max_by(struct(right.*), expr, k) AS _matches] - * +- Join LeftOuter + * +- Join Inner // or LeftOuter for `LEFT OUTER NEAREST BY` * :- Project [left.*, uuid() AS __qid] * : +- left * +- right @@ -79,18 +79,18 @@ object RewriteNearestByJoin extends Rule[LogicalPlan] { val taggedLeft = Project(left.output :+ qidAlias, left) val qidAttr = qidAlias.toAttribute - // 2. LEFT OUTER-join the tagged left with right (no join condition). LEFT OUTER - // (rather than INNER) preserves left rows even when `right` is empty, so that a - // `LEFT OUTER NEAREST BY` query still returns those rows with `NULL` right-side - // columns after the aggregate + inline below. When `right` is non-empty every left - // row already has right-row pairings, so LEFT OUTER and INNER are equivalent. + // 2. Join the tagged left with right (no join condition), using the user's join type. + // For `LEFT OUTER`, left rows with no right-side match are preserved with `NULL` + // right-side columns through the aggregate + inline below; for `INNER`, such rows + // are dropped. When `right` is non-empty every left row already has right-row + // pairings, so `LEFT OUTER` and `INNER` are equivalent in that case. // // This synthetic join is an unconditioned cross-product, so `NEAREST BY` queries // are subject to `CheckCartesianProducts` and will be rejected when the user has // set `spark.sql.crossJoin.enabled = false`. That is intentional: if the user has // opted out of cross-products, the NEAREST BY rewrite -- which is itself a bounded // cross-product today -- should not silently bypass that choice. - val join = Join(taggedLeft, right, LeftOuter, None, JoinHint.NONE) + val join = Join(taggedLeft, right, joinType, None, JoinHint.NONE) val (aggInput, rankingForAgg) = if (!rankingExpression.deterministic) { val rankingAlias = Alias(rankingExpression, "__ranking__")() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoinSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoinSuite.scala index 650bdc7a6c358..729b58394d4bc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoinSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoinSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, CreateStruct, Inline, Literal, Rand, Uuid} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, First, MaxMinByK} -import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, NearestByDistance, NearestBySimilarity, PlanTest} +import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftOuter, NearestByDistance, NearestBySimilarity, PlanTest} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Generate, Join, JoinHint, LocalRelation, NearestByJoin, Project} import org.apache.spark.sql.types.IntegerType @@ -41,10 +41,10 @@ class RewriteNearestByJoinSuite extends PlanTest { numResults: Int, ranking: org.apache.spark.sql.catalyst.expressions.Expression, reverse: Boolean, - outer: Boolean) = { + joinType: JoinType) = { val qidAlias = Alias(Uuid(Some(0L)), "__qid")() val taggedLeft = Project(left.output :+ qidAlias, left) - val join = Join(taggedLeft, right, LeftOuter, None, JoinHint.NONE) + val join = Join(taggedLeft, right, joinType, None, JoinHint.NONE) val rightStruct = CreateStruct(right.output) val topKAgg = MaxMinByK( @@ -66,7 +66,7 @@ class RewriteNearestByJoinSuite extends PlanTest { val generate = Generate( Inline(matchesAlias.toAttribute), unrequiredChildIndex = Seq(aggregate.output.indexOf(matchesAlias.toAttribute)), - outer = outer, + outer = joinType == LeftOuter, qualifier = None, generatorOutput = generatorOutput, child = aggregate) @@ -89,7 +89,7 @@ class RewriteNearestByJoinSuite extends PlanTest { val expected = expectedRewrite( left, right, 5, ranking = left.output(0) + right.output(0), - reverse = false, outer = false) + reverse = false, joinType = Inner) comparePlans(normalizeUuidSeed(rewritten), expected, checkAnalysis = false) } @@ -106,7 +106,7 @@ class RewriteNearestByJoinSuite extends PlanTest { val expected = expectedRewrite( left, right, 3, ranking = left.output(0) - right.output(0), - reverse = true, outer = false) + reverse = true, joinType = Inner) comparePlans(normalizeUuidSeed(rewritten), expected, checkAnalysis = false) } @@ -123,7 +123,7 @@ class RewriteNearestByJoinSuite extends PlanTest { val expected = expectedRewrite( left, right, 1, ranking = left.output(0) + right.output(0), - reverse = false, outer = true) + reverse = false, joinType = LeftOuter) comparePlans(normalizeUuidSeed(rewritten), expected, checkAnalysis = false) } @@ -140,11 +140,38 @@ class RewriteNearestByJoinSuite extends PlanTest { val expected = expectedRewrite( left, right, 2, ranking = left.output(0) - right.output(0), - reverse = true, outer = true) + reverse = true, joinType = LeftOuter) comparePlans(normalizeUuidSeed(rewritten), expected, checkAnalysis = false) } + test("synthetic Join uses the user's joinType") { + // Locks in that the rewrite's synthetic Join carries the user's `joinType` + // (Inner or LeftOuter). + val left = LocalRelation($"a".int, $"b".int) + val right = LocalRelation($"x".int, $"y".int) + Seq(Inner, LeftOuter).foreach { joinType => + val query = NearestByJoin( + left, right, joinType, approx = true, numResults = 1, + rankingExpression = left.output(0) + right.output(0), + direction = NearestBySimilarity) + + val rewritten = RewriteNearestByJoin(query.analyze) + val syntheticJoin = rewritten.collect { case j: Join => j } + assert(syntheticJoin.size == 1, + s"expected exactly one synthetic Join in the rewritten plan, got ${syntheticJoin.size}") + assert(syntheticJoin.head.joinType == joinType, + s"expected synthetic Join to use $joinType, got ${syntheticJoin.head.joinType}") + + val generate = rewritten.collect { case g: Generate => g } + assert(generate.size == 1, + s"expected exactly one Generate in the rewritten plan, got ${generate.size}") + val expectedOuter = joinType == LeftOuter + assert(generate.head.outer == expectedOuter, + s"expected Generate.outer == $expectedOuter for $joinType, got ${generate.head.outer}") + } + } + test("EXACT (approx = false) produces the same rewrite as APPROX") { // Locks in the current invariant that APPROX and EXACT lower through the same // brute-force rewrite. If a future change diverges them (e.g. an APPROX-only @@ -160,7 +187,7 @@ class RewriteNearestByJoinSuite extends PlanTest { val expected = expectedRewrite( left, right, 5, ranking = left.output(0) + right.output(0), - reverse = false, outer = false) + reverse = false, joinType = Inner) comparePlans(normalizeUuidSeed(rewritten), expected, checkAnalysis = false) } @@ -177,7 +204,7 @@ class RewriteNearestByJoinSuite extends PlanTest { val expected = expectedRewrite( left, right, 1, ranking = left.output(0) + right.output(0), - reverse = false, outer = false) + reverse = false, joinType = Inner) comparePlans(normalizeUuidSeed(rewritten), expected, checkAnalysis = false) } @@ -194,7 +221,7 @@ class RewriteNearestByJoinSuite extends PlanTest { val expected = expectedRewrite( left, right, NearestByJoin.MaxNumResults, ranking = left.output(0) + right.output(0), - reverse = false, outer = false) + reverse = false, joinType = Inner) comparePlans(normalizeUuidSeed(rewritten), expected, checkAnalysis = false) } @@ -214,7 +241,7 @@ class RewriteNearestByJoinSuite extends PlanTest { val expected = expectedRewrite( t, tDup, 1, ranking = t.output(0) + tDup.output(0), - reverse = false, outer = false) + reverse = false, joinType = Inner) comparePlans(normalizeUuidSeed(rewritten), expected, checkAnalysis = false) } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/join-nearest-by.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/join-nearest-by.sql.out index 7a795123cdcc7..48819f1723109 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/join-nearest-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/join-nearest-by.sql.out @@ -129,6 +129,27 @@ Project [user_id#x, product#x] +- LocalRelation [col1#x, col2#x] +-- !query +SELECT u.user_id, p.product +FROM users u INNER JOIN (SELECT * FROM products WHERE false) p + APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore) +-- !query analysis +Project [user_id#x, product#x] ++- NearestByJoin Inner, true, 1, -abs((score#x - pscore#x)), NearestBySimilarity + :- SubqueryAlias u + : +- SubqueryAlias spark_catalog.default.users + : +- View (`spark_catalog`.`default`.`users`, [user_id#x, score#x]) + : +- Project [cast(col1#x as int) AS user_id#x, cast(col2#x as decimal(3,1)) AS score#x] + : +- LocalRelation [col1#x, col2#x] + +- SubqueryAlias p + +- Project [product#x, pscore#x] + +- Filter false + +- SubqueryAlias spark_catalog.default.products + +- View (`spark_catalog`.`default`.`products`, [product#x, pscore#x]) + +- Project [cast(col1#x as string) AS product#x, cast(col2#x as decimal(3,1)) AS pscore#x] + +- LocalRelation [col1#x, col2#x] + + -- !query SELECT u.user_id, p.product FROM users u INNER JOIN products p diff --git a/sql/core/src/test/resources/sql-tests/inputs/join-nearest-by.sql b/sql/core/src/test/resources/sql-tests/inputs/join-nearest-by.sql index 20b9b2fb73169..6b3dc63d28e3c 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/join-nearest-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/join-nearest-by.sql @@ -36,6 +36,11 @@ SELECT u.user_id, p.product FROM users u LEFT OUTER JOIN (SELECT * FROM products WHERE false) p APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore); +-- INNER JOIN with NEAREST BY, empty right side +SELECT u.user_id, p.product +FROM users u INNER JOIN (SELECT * FROM products WHERE false) p + APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore); + -- Explicit INNER keyword SELECT u.user_id, p.product FROM users u INNER JOIN products p diff --git a/sql/core/src/test/resources/sql-tests/results/join-nearest-by.sql.out b/sql/core/src/test/resources/sql-tests/results/join-nearest-by.sql.out index 286c61723b280..d06fb53686e78 100644 --- a/sql/core/src/test/resources/sql-tests/results/join-nearest-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/join-nearest-by.sql.out @@ -90,6 +90,16 @@ struct 3 NULL +-- !query +SELECT u.user_id, p.product +FROM users u INNER JOIN (SELECT * FROM products WHERE false) p + APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore) +-- !query schema +struct +-- !query output + + + -- !query SELECT u.user_id, p.product FROM users u INNER JOIN products p @@ -286,12 +296,12 @@ AdaptiveSparkPlan isFinalPlan=false +- SortAggregate(key=[__qid#x], functions=[partial_first(user_id#x, false), partial_max_by(named_struct(product, product#x, pscore, pscore#x), __ranking__#x, 1, false, 0, 0)]) +- Sort [__qid#x ASC NULLS FIRST], false, 0 +- Project [user_id#x, __qid#x, product#x, pscore#x, (rand(0) + cast(pscore#x as double)) AS __ranking__#x] - +- BroadcastNestedLoopJoin BuildRight, LeftOuter - :- Project [col1#x AS user_id#x, uuid(Some(x)) AS __qid#x] - : +- LocalTableScan [col1#x, col2#x] - +- BroadcastExchange IdentityBroadcastMode, [plan_id=x] - +- Project [col1#x AS product#x, col2#x AS pscore#x] - +- LocalTableScan [col1#x, col2#x] + +- BroadcastNestedLoopJoin BuildLeft, Inner + :- BroadcastExchange IdentityBroadcastMode, [plan_id=x] + : +- Project [col1#x AS user_id#x, uuid(Some(x)) AS __qid#x] + : +- LocalTableScan [col1#x, col2#x] + +- Project [col1#x AS product#x, col2#x AS pscore#x] + +- LocalTableScan [col1#x, col2#x] -- !query @@ -313,7 +323,7 @@ AdaptiveSparkPlan isFinalPlan=false +- Exchange hashpartitioning(__qid#x, 4), ENSURE_REQUIREMENTS, [plan_id=x] +- SortAggregate(key=[__qid#x], functions=[partial_first(user_id#x, false), partial_min_by(named_struct(product, product#x, pscore, pscore#x), abs((score#x - pscore#x)), 2, true, 0, 0)]) +- Sort [__qid#x ASC NULLS FIRST], false, 0 - +- BroadcastNestedLoopJoin BuildRight, LeftOuter + +- BroadcastNestedLoopJoin BuildRight, Inner :- Filter (user_id#x > 1) : +- Project [col1#x AS user_id#x, col2#x AS score#x, uuid(Some(x)) AS __qid#x] : +- LocalTableScan [col1#x, col2#x] @@ -342,7 +352,7 @@ AdaptiveSparkPlan isFinalPlan=false +- Exchange hashpartitioning(__qid#x, 4), ENSURE_REQUIREMENTS, [plan_id=x] +- SortAggregate(key=[__qid#x], functions=[partial_first(user_id#x, false), partial_min_by(named_struct(product, product#x, pscore, pscore#x), abs((score#x - pscore#x)), 2, true, 0, 0)]) +- Sort [__qid#x ASC NULLS FIRST], false, 0 - +- BroadcastNestedLoopJoin BuildRight, LeftOuter + +- BroadcastNestedLoopJoin BuildRight, Inner :- Project [col1#x AS user_id#x, col2#x AS score#x, uuid(Some(x)) AS __qid#x] : +- LocalTableScan [col1#x, col2#x] +- BroadcastExchange IdentityBroadcastMode, [plan_id=x]