From 5c13418898bc4b88102ebb842ea110bd43062f6c Mon Sep 17 00:00:00 2001 From: Chirag Singh Date: Mon, 17 Nov 2025 08:56:51 -0800 Subject: [PATCH 1/4] fix --- .../plans/physical/partitioning.scala | 42 +++++++----- .../util/InternalRowComparableWrapper.scala | 64 ++++++++++++++++++- .../execution/KeyGroupedPartitionedScan.scala | 43 +++++++++---- .../exchange/EnsureRequirements.scala | 38 +++++++---- 4 files changed, 148 insertions(+), 39 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 1cbb49c7a1f7..4e80fa3199e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -23,10 +23,10 @@ import scala.collection.mutable import org.apache.spark.{SparkException, SparkUnsupportedOperationException} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper +import org.apache.spark.sql.catalyst.util.BoundInternalRowComparableWrapper import org.apache.spark.sql.connector.catalog.functions.Reducer import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{DataType, IntegerType} +import org.apache.spark.sql.types.{DataType, IntegerType, StructType} /** * Specifies how tuples that share common expressions will be distributed when a query is executed @@ -428,10 +428,13 @@ case class KeyGroupedPartitioning( } lazy val uniquePartitionValues: Seq[InternalRow] = { + val dataTypes = expressions.map(_.dataType) + val (structType, ordering) = + BoundInternalRowComparableWrapper.getStructTypeAndOrdering(dataTypes) partitionValues - .map(InternalRowComparableWrapper(_, expressions)) - .distinct - .map(_.row) + .map(new BoundInternalRowComparableWrapper(_, dataTypes, ordering, structType)) + .distinct + .map(_.row) } override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = @@ -448,11 +451,14 @@ object KeyGroupedPartitioning { val projectedPartitionValues = partitionValues.map(project(expressions, projectionPositions, _)) val projectedOriginalPartitionValues = originalPartitionValues.map(project(expressions, projectionPositions, _)) + val dataTypes = projectedExpressions.map(_.dataType) + val (structType, ordering) = + BoundInternalRowComparableWrapper.getStructTypeAndOrdering(dataTypes) val finalPartitionValues = projectedPartitionValues - .map(InternalRowComparableWrapper(_, projectedExpressions)) - .distinct - .map(_.row) + .map(new BoundInternalRowComparableWrapper(_, dataTypes, ordering, structType)) + .distinct + .map(_.row) KeyGroupedPartitioning(projectedExpressions, finalPartitionValues.length, finalPartitionValues, projectedOriginalPartitionValues) @@ -867,12 +873,16 @@ case class KeyGroupedShuffleSpec( // transform functions. // 4. the partition values from both sides are following the same order. 
case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution, _) => + lazy val dataTypes = partitioning.expressions.map(_.dataType) + lazy val (structType, ordering) = + BoundInternalRowComparableWrapper.getStructTypeAndOrdering(dataTypes) distribution.clustering.length == otherDistribution.clustering.length && numPartitions == other.numPartitions && areKeysCompatible(otherSpec) && partitioning.partitionValues.zip(otherPartitioning.partitionValues).forall { case (left, right) => - InternalRowComparableWrapper(left, partitioning.expressions) - .equals(InternalRowComparableWrapper(right, partitioning.expressions)) + new BoundInternalRowComparableWrapper(left, dataTypes, ordering, structType) + .equals( + new BoundInternalRowComparableWrapper(right, dataTypes, ordering, structType)) } case ShuffleSpecCollection(specs) => specs.exists(isCompatibleWith) @@ -957,15 +967,17 @@ case class KeyGroupedShuffleSpec( object KeyGroupedShuffleSpec { def reducePartitionValue( row: InternalRow, - expressions: Seq[Expression], - reducers: Seq[Option[Reducer[_, _]]]): - InternalRowComparableWrapper = { - val partitionVals = row.toSeq(expressions.map(_.dataType)) + reducers: Seq[Option[Reducer[_, _]]], + dataTypes: Seq[DataType], + ordering: BaseOrdering, + structType: StructType): BoundInternalRowComparableWrapper = { + val partitionVals = row.toSeq(dataTypes) val reducedRow = partitionVals.zip(reducers).map{ case (v, Some(reducer: Reducer[Any, Any])) => reducer.reduce(v) case (v, _) => v }.toArray - InternalRowComparableWrapper(new GenericInternalRow(reducedRow), expressions) + new BoundInternalRowComparableWrapper( + new GenericInternalRow(reducedRow), dataTypes, ordering, structType) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala index ba3d65fea027..5e28f25c45c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.util import scala.collection.mutable import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Expression, Murmur3HashFunction, RowOrdering} +import org.apache.spark.sql.catalyst.expressions.{BaseOrdering, Expression, Murmur3HashFunction, RowOrdering} import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition} import org.apache.spark.sql.types.{DataType, StructField, StructType} import org.apache.spark.util.NonFateSharingCache @@ -112,3 +112,65 @@ object InternalRowComparableWrapper { result.toSeq } } + +/** + * Effectively the same as [[InternalRowComparableWrapper]], but using a precomputed `ordering` + * and `structType` to avoid the cache lookup for each row. + */ +class BoundInternalRowComparableWrapper( + val row: InternalRow, + val dataTypes: Seq[DataType], + val ordering: BaseOrdering, + val structType: StructType) { + + override def hashCode(): Int = Murmur3HashFunction.hash( + row, + structType, + 42L, + isCollationAware = true, + // legacyCollationAwareHashing only matters when isCollationAware is false. 
+ legacyCollationAwareHashing = false).toInt + + override def equals(other: Any): Boolean = { + if (!other.isInstanceOf[BoundInternalRowComparableWrapper]) { + return false + } + val otherWrapper = other.asInstanceOf[BoundInternalRowComparableWrapper] + if (!otherWrapper.dataTypes.equals(this.dataTypes)) { + return false + } + ordering.compare(row, otherWrapper.row) == 0 + } +} + +object BoundInternalRowComparableWrapper { + /** Compute the schema and row ordering for a given list of data types. */ + def getStructTypeAndOrdering(dataTypes: Seq[DataType]): (StructType, BaseOrdering) = + StructType(dataTypes.map(t => StructField("f", t))) -> + RowOrdering.createNaturalAscendingOrdering(dataTypes) + + def mergePartitions( + leftPartitioning: Seq[InternalRow], + rightPartitioning: Seq[InternalRow], + partitionExpression: Seq[Expression], + intersect: Boolean = false): Seq[BoundInternalRowComparableWrapper] = { + val partitionDataTypes = partitionExpression.map(_.dataType) + val (structType, ordering) = getStructTypeAndOrdering(partitionDataTypes) + + val leftPartitionSet = new mutable.HashSet[BoundInternalRowComparableWrapper] + leftPartitioning + .map(new BoundInternalRowComparableWrapper(_, partitionDataTypes, ordering, structType)) + .foreach(partition => leftPartitionSet.add(partition)) + val rightPartitionSet = new mutable.HashSet[BoundInternalRowComparableWrapper] + rightPartitioning + .map(new BoundInternalRowComparableWrapper(_, partitionDataTypes, ordering, structType)) + .foreach(partition => rightPartitionSet.add(partition)) + + val result = if (intersect) { + leftPartitionSet.intersect(rightPartitionSet) + } else { + leftPartitionSet.union(rightPartitionSet) + } + result.toSeq + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala index 5a789179219a..1237d2047d7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.RowOrdering import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, KeyGroupedShuffleSpec} -import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper +import org.apache.spark.sql.catalyst.util.BoundInternalRowComparableWrapper import org.apache.spark.sql.execution.joins.StoragePartitionJoinParams /** Base trait for a data source scan capable of producing a key-grouped output. 
*/ @@ -50,9 +50,12 @@ trait KeyGroupedPartitionedScan[T] { case None => spjParams.joinKeyPositions match { case Some(projectionPositions) => basePartitioning.partitionValues.map { r => + val dataTypes = expressions.map(_.dataType) + val (structType, ordering) = + BoundInternalRowComparableWrapper.getStructTypeAndOrdering(dataTypes) val projectedRow = KeyGroupedPartitioning.project(expressions, projectionPositions, r) - InternalRowComparableWrapper(projectedRow, expressions) + new BoundInternalRowComparableWrapper(projectedRow, dataTypes, ordering, structType) }.distinct.map(_.row) case _ => basePartitioning.partitionValues } @@ -83,11 +86,14 @@ trait KeyGroupedPartitionedScan[T] { val (groupedPartitions, partExpressions) = spjParams.joinKeyPositions match { case Some(projectPositions) => val projectedExpressions = projectPositions.map(i => expressions(i)) + val projectedTypes = projectedExpressions.map(_.dataType) + val (structType, ordering) = + BoundInternalRowComparableWrapper.getStructTypeAndOrdering(projectedTypes) val parts = filteredPartitions.flatten.groupBy(part => { val row = partitionValueAccessor(part) val projectedRow = KeyGroupedPartitioning.project( expressions, projectPositions, row) - InternalRowComparableWrapper(projectedRow, projectedExpressions) + new BoundInternalRowComparableWrapper(projectedRow, projectedTypes, ordering, structType) }).map { case (wrapper, splits) => (wrapper.row, splits) }.toSeq (parts, projectedExpressions) case _ => @@ -99,10 +105,14 @@ trait KeyGroupedPartitionedScan[T] { } // Also re-group the partitions if we are reducing compatible partition expressions + val partitionDataTypes = partExpressions.map(_.dataType) + val (structType, ordering) = + BoundInternalRowComparableWrapper.getStructTypeAndOrdering(partitionDataTypes) val finalGroupedPartitions = spjParams.reducers match { case Some(reducers) => val result = groupedPartitions.groupBy { case (row, _) => - KeyGroupedShuffleSpec.reducePartitionValue(row, partExpressions, reducers) + KeyGroupedShuffleSpec.reducePartitionValue( + row, reducers, partitionDataTypes, ordering, structType) }.map { case (wrapper, splits) => (wrapper.row, splits.flatMap(_._2)) }.toSeq val rowOrdering = RowOrdering.createNaturalAscendingOrdering( partExpressions.map(_.dataType)) @@ -118,17 +128,21 @@ trait KeyGroupedPartitionedScan[T] { // should contain. val commonPartValuesMap = spjParams.commonPartitionValues .get - .map(t => (InternalRowComparableWrapper(t._1, partExpressions), t._2)) + .map(t => (new BoundInternalRowComparableWrapper( + t._1, partitionDataTypes, ordering, structType), t._2)) .toMap val filteredGroupedPartitions = finalGroupedPartitions.filter { case (partValues, _) => commonPartValuesMap.keySet.contains( - InternalRowComparableWrapper(partValues, partExpressions)) + new BoundInternalRowComparableWrapper( + partValues, partitionDataTypes, ordering, structType)) } val nestGroupedPartitions = filteredGroupedPartitions.map { case (partValue, splits) => // `commonPartValuesMap` should contain the part value since it's the super set. val numSplits = commonPartValuesMap - .get(InternalRowComparableWrapper(partValue, partExpressions)) + .get( + new BoundInternalRowComparableWrapper( + partValue, partitionDataTypes, ordering, structType)) assert(numSplits.isDefined, s"Partition value $partValue does not exist in " + "common partition values from Spark plan") @@ -143,7 +157,11 @@ trait KeyGroupedPartitionedScan[T] { // sides of a join will have the same number of partitions & splits. 
splits.map(Seq(_)).padTo(numSplits.get, Seq.empty) } - (InternalRowComparableWrapper(partValue, partExpressions), newSplits) + ( + new BoundInternalRowComparableWrapper( + partValue, partitionDataTypes, ordering, structType), + newSplits + ) } // Now fill missing partition keys with empty partitions @@ -152,14 +170,16 @@ trait KeyGroupedPartitionedScan[T] { case (partValue, numSplits) => // Use empty partition for those partition values that are not present. partitionMapping.getOrElse( - InternalRowComparableWrapper(partValue, partExpressions), + new BoundInternalRowComparableWrapper( + partValue, partitionDataTypes, ordering, structType), Seq.fill(numSplits)(Seq.empty)) } } else { // either `commonPartitionValues` is not defined, or it is defined but // `applyPartialClustering` is false. val partitionMapping = finalGroupedPartitions.map { case (partValue, splits) => - InternalRowComparableWrapper(partValue, partExpressions) -> splits + new BoundInternalRowComparableWrapper( + partValue, partitionDataTypes, ordering, structType) -> splits }.toMap // In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there @@ -169,7 +189,8 @@ trait KeyGroupedPartitionedScan[T] { p.uniquePartitionValues.map { partValue => // Use empty partition for those partition values that are not present partitionMapping.getOrElse( - InternalRowComparableWrapper(partValue, partExpressions), Seq.empty) + new BoundInternalRowComparableWrapper( + partValue, partitionDataTypes, ordering, structType), Seq.empty) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index b97d765afcf7..b493eb704bb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -20,13 +20,13 @@ package org.apache.spark.sql.execution.exchange import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.apache.spark.internal.{LogKeys} +import org.apache.spark.internal.LogKeys import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper +import org.apache.spark.sql.catalyst.util.{BoundInternalRowComparableWrapper} import org.apache.spark.sql.connector.catalog.functions.Reducer import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources.v2.BatchScanExec @@ -579,15 +579,21 @@ case class EnsureRequirements( // In partially clustered distribution, we should use un-grouped partition values val spec = if (replicateLeftSide) rightSpec else leftSpec val partValues = spec.partitioning.originalPartitionValues + val partitionDataTypes = partitionExprs.map(_.dataType) + val (structType, ordering) = + BoundInternalRowComparableWrapper.getStructTypeAndOrdering(partitionDataTypes) val numExpectedPartitions = partValues - .map(InternalRowComparableWrapper(_, partitionExprs)) + .map( + new BoundInternalRowComparableWrapper( + _, partitionDataTypes, ordering, structType)) .groupBy(identity) .transform((_, v) => v.size) mergedPartValues = mergedPartValues.map { case (partVal, numParts) => (partVal, numExpectedPartitions.getOrElse( - 
InternalRowComparableWrapper(partVal, partitionExprs), numParts)) + new BoundInternalRowComparableWrapper( + partVal, partitionDataTypes, ordering, structType), numParts)) } logInfo(log"After applying partially clustered distribution, there are " + @@ -679,9 +685,14 @@ case class EnsureRequirements( expressions: Seq[Expression], reducers: Option[Seq[Option[Reducer[_, _]]]]) = { reducers match { - case Some(reducers) => partValues.map { row => - KeyGroupedShuffleSpec.reducePartitionValue(row, expressions, reducers) - }.distinct.map(_.row) + case Some(reducers) => + val partitionDataTypes = expressions.map(_.dataType) + val (structType, ordering) = + BoundInternalRowComparableWrapper.getStructTypeAndOrdering(partitionDataTypes) + partValues.map { row => + KeyGroupedShuffleSpec.reducePartitionValue( + row, reducers, partitionDataTypes, ordering, structType) + }.distinct.map(_.row) case _ => partValues } } @@ -737,20 +748,23 @@ case class EnsureRequirements( rightPartitioning: Seq[InternalRow], partitionExpression: Seq[Expression], joinType: JoinType): Seq[InternalRow] = { + val partitionDataTypes = partitionExpression.map(_.dataType) + val (structType, ordering) = + BoundInternalRowComparableWrapper.getStructTypeAndOrdering(partitionDataTypes) val merged = if (SQLConf.get.getConf(SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED)) { joinType match { - case Inner => InternalRowComparableWrapper.mergePartitions( + case Inner => BoundInternalRowComparableWrapper.mergePartitions( leftPartitioning, rightPartitioning, partitionExpression, intersect = true) case LeftOuter => leftPartitioning.map( - InternalRowComparableWrapper(_, partitionExpression)) + new BoundInternalRowComparableWrapper(_, partitionDataTypes, ordering, structType)) case RightOuter => rightPartitioning.map( - InternalRowComparableWrapper(_, partitionExpression)) - case _ => InternalRowComparableWrapper.mergePartitions(leftPartitioning, + new BoundInternalRowComparableWrapper(_, partitionDataTypes, ordering, structType)) + case _ => BoundInternalRowComparableWrapper.mergePartitions(leftPartitioning, rightPartitioning, partitionExpression) } } else { - InternalRowComparableWrapper.mergePartitions(leftPartitioning, rightPartitioning, + BoundInternalRowComparableWrapper.mergePartitions(leftPartitioning, rightPartitioning, partitionExpression) } From aaef0846e9f164ceb4eda2dcbd482a488c5eea63 Mon Sep 17 00:00:00 2001 From: Chirag Singh Date: Mon, 17 Nov 2025 14:50:38 -0800 Subject: [PATCH 2/4] fix --- .../plans/physical/partitioning.scala | 31 +++---- .../util/InternalRowComparableWrapper.scala | 85 +++++-------------- ...nternalRowComparableWrapperBenchmark.scala | 5 +- .../execution/KeyGroupedPartitionedScan.scala | 51 +++++------ .../exchange/EnsureRequirements.scala | 39 ++++----- 5 files changed, 74 insertions(+), 137 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 4e80fa3199e6..f07a890eee65 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -23,10 +23,10 @@ import scala.collection.mutable import org.apache.spark.{SparkException, SparkUnsupportedOperationException} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import 
org.apache.spark.sql.catalyst.util.BoundInternalRowComparableWrapper +import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper import org.apache.spark.sql.connector.catalog.functions.Reducer import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{DataType, IntegerType, StructType} +import org.apache.spark.sql.types.{DataType, IntegerType} /** * Specifies how tuples that share common expressions will be distributed when a query is executed @@ -429,10 +429,10 @@ case class KeyGroupedPartitioning( lazy val uniquePartitionValues: Seq[InternalRow] = { val dataTypes = expressions.map(_.dataType) - val (structType, ordering) = - BoundInternalRowComparableWrapper.getStructTypeAndOrdering(dataTypes) + val internalRowComparableFactory = + InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(dataTypes) partitionValues - .map(new BoundInternalRowComparableWrapper(_, dataTypes, ordering, structType)) + .map(internalRowComparableFactory) .distinct .map(_.row) } @@ -452,11 +452,11 @@ object KeyGroupedPartitioning { val projectedOriginalPartitionValues = originalPartitionValues.map(project(expressions, projectionPositions, _)) val dataTypes = projectedExpressions.map(_.dataType) - val (structType, ordering) = - BoundInternalRowComparableWrapper.getStructTypeAndOrdering(dataTypes) + val internalRowComparableFactory = + InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(dataTypes) val finalPartitionValues = projectedPartitionValues - .map(new BoundInternalRowComparableWrapper(_, dataTypes, ordering, structType)) + .map(internalRowComparableFactory) .distinct .map(_.row) @@ -874,15 +874,13 @@ case class KeyGroupedShuffleSpec( // 4. the partition values from both sides are following the same order. 
case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution, _) => lazy val dataTypes = partitioning.expressions.map(_.dataType) - lazy val (structType, ordering) = - BoundInternalRowComparableWrapper.getStructTypeAndOrdering(dataTypes) + lazy val internalRowComparableFactory = + InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(dataTypes) distribution.clustering.length == otherDistribution.clustering.length && numPartitions == other.numPartitions && areKeysCompatible(otherSpec) && partitioning.partitionValues.zip(otherPartitioning.partitionValues).forall { case (left, right) => - new BoundInternalRowComparableWrapper(left, dataTypes, ordering, structType) - .equals( - new BoundInternalRowComparableWrapper(right, dataTypes, ordering, structType)) + internalRowComparableFactory(left).equals(internalRowComparableFactory(right)) } case ShuffleSpecCollection(specs) => specs.exists(isCompatibleWith) @@ -969,15 +967,14 @@ object KeyGroupedShuffleSpec { row: InternalRow, reducers: Seq[Option[Reducer[_, _]]], dataTypes: Seq[DataType], - ordering: BaseOrdering, - structType: StructType): BoundInternalRowComparableWrapper = { + internalRowComparableWrapperFactory: InternalRow => InternalRowComparableWrapper + ): InternalRowComparableWrapper = { val partitionVals = row.toSeq(dataTypes) val reducedRow = partitionVals.zip(reducers).map{ case (v, Some(reducer: Reducer[Any, Any])) => reducer.reduce(v) case (v, _) => v }.toArray - new BoundInternalRowComparableWrapper( - new GenericInternalRow(reducedRow), dataTypes, ordering, structType) + internalRowComparableWrapperFactory(new GenericInternalRow(reducedRow)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala index 5e28f25c45c8..8a8ef66d5c76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala @@ -33,11 +33,17 @@ import org.apache.spark.util.NonFateSharingCache * * @param dataTypes the data types for the row */ -class InternalRowComparableWrapper(val row: InternalRow, val dataTypes: Seq[DataType]) { - import InternalRowComparableWrapper._ +class InternalRowComparableWrapper private ( + val row: InternalRow, + val dataTypes: Seq[DataType], + val structType: StructType, + val ordering: BaseOrdering) { - private val structType = structTypeCache.get(dataTypes) - private val ordering = orderingCache.get(dataTypes) + def this(row: InternalRow, dataTypes: Seq[DataType]) = this( + row, + dataTypes, + InternalRowComparableWrapper.structTypeCache.get(dataTypes), + InternalRowComparableWrapper.orderingCache.get(dataTypes)) override def hashCode(): Int = Murmur3HashFunction.hash( row, @@ -96,12 +102,14 @@ object InternalRowComparableWrapper { intersect: Boolean = false): Seq[InternalRowComparableWrapper] = { val partitionDataTypes = partitionExpression.map(_.dataType) val leftPartitionSet = new mutable.HashSet[InternalRowComparableWrapper] + val internalRowComparableWrapperFactory = + getInternalRowComparableWrapperFactory(partitionDataTypes) leftPartitioning - .map(new InternalRowComparableWrapper(_, partitionDataTypes)) + .map(internalRowComparableWrapperFactory) .foreach(partition => leftPartitionSet.add(partition)) val rightPartitionSet = new mutable.HashSet[InternalRowComparableWrapper] rightPartitioning - 
.map(new InternalRowComparableWrapper(_, partitionDataTypes)) + .map(internalRowComparableWrapperFactory) .foreach(partition => rightPartitionSet.add(partition)) val result = if (intersect) { @@ -111,66 +119,11 @@ object InternalRowComparableWrapper { } result.toSeq } -} - -/** - * Effectively the same as [[InternalRowComparableWrapper]], but using a precomputed `ordering` - * and `structType` to avoid the cache lookup for each row. - */ -class BoundInternalRowComparableWrapper( - val row: InternalRow, - val dataTypes: Seq[DataType], - val ordering: BaseOrdering, - val structType: StructType) { - - override def hashCode(): Int = Murmur3HashFunction.hash( - row, - structType, - 42L, - isCollationAware = true, - // legacyCollationAwareHashing only matters when isCollationAware is false. - legacyCollationAwareHashing = false).toInt - - override def equals(other: Any): Boolean = { - if (!other.isInstanceOf[BoundInternalRowComparableWrapper]) { - return false - } - val otherWrapper = other.asInstanceOf[BoundInternalRowComparableWrapper] - if (!otherWrapper.dataTypes.equals(this.dataTypes)) { - return false - } - ordering.compare(row, otherWrapper.row) == 0 - } -} - -object BoundInternalRowComparableWrapper { - /** Compute the schema and row ordering for a given list of data types. */ - def getStructTypeAndOrdering(dataTypes: Seq[DataType]): (StructType, BaseOrdering) = - StructType(dataTypes.map(t => StructField("f", t))) -> - RowOrdering.createNaturalAscendingOrdering(dataTypes) - def mergePartitions( - leftPartitioning: Seq[InternalRow], - rightPartitioning: Seq[InternalRow], - partitionExpression: Seq[Expression], - intersect: Boolean = false): Seq[BoundInternalRowComparableWrapper] = { - val partitionDataTypes = partitionExpression.map(_.dataType) - val (structType, ordering) = getStructTypeAndOrdering(partitionDataTypes) - - val leftPartitionSet = new mutable.HashSet[BoundInternalRowComparableWrapper] - leftPartitioning - .map(new BoundInternalRowComparableWrapper(_, partitionDataTypes, ordering, structType)) - .foreach(partition => leftPartitionSet.add(partition)) - val rightPartitionSet = new mutable.HashSet[BoundInternalRowComparableWrapper] - rightPartitioning - .map(new BoundInternalRowComparableWrapper(_, partitionDataTypes, ordering, structType)) - .foreach(partition => rightPartitionSet.add(partition)) - - val result = if (intersect) { - leftPartitionSet.intersect(rightPartitionSet) - } else { - leftPartitionSet.union(rightPartitionSet) - } - result.toSeq + def getInternalRowComparableWrapperFactory( + dataTypes: Seq[DataType]): InternalRow => InternalRowComparableWrapper = { + val structType = structTypeCache.get(dataTypes) + val ordering = orderingCache.get(dataTypes) + row: InternalRow => new InternalRowComparableWrapper(row, dataTypes, structType, ordering) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapperBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapperBenchmark.scala index f3dd232129e8..764dac35f673 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapperBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapperBenchmark.scala @@ -48,8 +48,11 @@ object InternalRowComparableWrapperBenchmark extends BenchmarkBase { val benchmark = new Benchmark("internal row comparable wrapper", partitionNum, output = output) benchmark.addCase("toSet") { _ => + val 
internalRowComparableWrapperFactory = + InternalRowComparableWrapper.getInternalRowComparableWrapperFactory( + Seq(IntegerType, IntegerType)) val distinct = partitions - .map(new InternalRowComparableWrapper(_, Seq(IntegerType, IntegerType))) + .map(internalRowComparableWrapperFactory) .toSet assert(distinct.size == bucketNum) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala index 1237d2047d7a..31a6cb46b721 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.RowOrdering import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, KeyGroupedShuffleSpec} -import org.apache.spark.sql.catalyst.util.BoundInternalRowComparableWrapper +import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper import org.apache.spark.sql.execution.joins.StoragePartitionJoinParams /** Base trait for a data source scan capable of producing a key-grouped output. */ @@ -49,13 +49,14 @@ trait KeyGroupedPartitionedScan[T] { } case None => spjParams.joinKeyPositions match { - case Some(projectionPositions) => basePartitioning.partitionValues.map { r => + case Some(projectionPositions) => val dataTypes = expressions.map(_.dataType) - val (structType, ordering) = - BoundInternalRowComparableWrapper.getStructTypeAndOrdering(dataTypes) + val internalRowComparableWrapperFactory = + InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(dataTypes) + basePartitioning.partitionValues.map { r => val projectedRow = KeyGroupedPartitioning.project(expressions, projectionPositions, r) - new BoundInternalRowComparableWrapper(projectedRow, dataTypes, ordering, structType) + internalRowComparableWrapperFactory(projectedRow) }.distinct.map(_.row) case _ => basePartitioning.partitionValues } @@ -87,13 +88,13 @@ trait KeyGroupedPartitionedScan[T] { case Some(projectPositions) => val projectedExpressions = projectPositions.map(i => expressions(i)) val projectedTypes = projectedExpressions.map(_.dataType) - val (structType, ordering) = - BoundInternalRowComparableWrapper.getStructTypeAndOrdering(projectedTypes) + val internalRowComparableWrapperFactory = + InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(projectedTypes) val parts = filteredPartitions.flatten.groupBy(part => { val row = partitionValueAccessor(part) val projectedRow = KeyGroupedPartitioning.project( expressions, projectPositions, row) - new BoundInternalRowComparableWrapper(projectedRow, projectedTypes, ordering, structType) + internalRowComparableWrapperFactory(projectedRow) }).map { case (wrapper, splits) => (wrapper.row, splits) }.toSeq (parts, projectedExpressions) case _ => @@ -106,13 +107,13 @@ trait KeyGroupedPartitionedScan[T] { // Also re-group the partitions if we are reducing compatible partition expressions val partitionDataTypes = partExpressions.map(_.dataType) - val (structType, ordering) = - BoundInternalRowComparableWrapper.getStructTypeAndOrdering(partitionDataTypes) + val internalRowComparableWrapperFactory = + InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(partitionDataTypes) val finalGroupedPartitions = spjParams.reducers match { case 
Some(reducers) => val result = groupedPartitions.groupBy { case (row, _) => KeyGroupedShuffleSpec.reducePartitionValue( - row, reducers, partitionDataTypes, ordering, structType) + row, reducers, partitionDataTypes, internalRowComparableWrapperFactory) }.map { case (wrapper, splits) => (wrapper.row, splits.flatMap(_._2)) }.toSeq val rowOrdering = RowOrdering.createNaturalAscendingOrdering( partExpressions.map(_.dataType)) @@ -128,21 +129,15 @@ trait KeyGroupedPartitionedScan[T] { // should contain. val commonPartValuesMap = spjParams.commonPartitionValues .get - .map(t => (new BoundInternalRowComparableWrapper( - t._1, partitionDataTypes, ordering, structType), t._2)) + .map(t => (internalRowComparableWrapperFactory(t._1), t._2)) .toMap val filteredGroupedPartitions = finalGroupedPartitions.filter { case (partValues, _) => - commonPartValuesMap.keySet.contains( - new BoundInternalRowComparableWrapper( - partValues, partitionDataTypes, ordering, structType)) + commonPartValuesMap.keySet.contains(internalRowComparableWrapperFactory(partValues)) } val nestGroupedPartitions = filteredGroupedPartitions.map { case (partValue, splits) => // `commonPartValuesMap` should contain the part value since it's the super set. - val numSplits = commonPartValuesMap - .get( - new BoundInternalRowComparableWrapper( - partValue, partitionDataTypes, ordering, structType)) + val numSplits = commonPartValuesMap.get(internalRowComparableWrapperFactory(partValue)) assert(numSplits.isDefined, s"Partition value $partValue does not exist in " + "common partition values from Spark plan") @@ -157,11 +152,7 @@ trait KeyGroupedPartitionedScan[T] { // sides of a join will have the same number of partitions & splits. splits.map(Seq(_)).padTo(numSplits.get, Seq.empty) } - ( - new BoundInternalRowComparableWrapper( - partValue, partitionDataTypes, ordering, structType), - newSplits - ) + (internalRowComparableWrapperFactory(partValue), newSplits) } // Now fill missing partition keys with empty partitions @@ -170,16 +161,14 @@ trait KeyGroupedPartitionedScan[T] { case (partValue, numSplits) => // Use empty partition for those partition values that are not present. partitionMapping.getOrElse( - new BoundInternalRowComparableWrapper( - partValue, partitionDataTypes, ordering, structType), + internalRowComparableWrapperFactory(partValue), Seq.fill(numSplits)(Seq.empty)) } } else { // either `commonPartitionValues` is not defined, or it is defined but // `applyPartialClustering` is false. val partitionMapping = finalGroupedPartitions.map { case (partValue, splits) => - new BoundInternalRowComparableWrapper( - partValue, partitionDataTypes, ordering, structType) -> splits + internalRowComparableWrapperFactory(partValue) -> splits }.toMap // In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there @@ -188,9 +177,7 @@ trait KeyGroupedPartitionedScan[T] { // partition values here so that grouped partitions won't get duplicated. 
p.uniquePartitionValues.map { partValue => // Use empty partition for those partition values that are not present - partitionMapping.getOrElse( - new BoundInternalRowComparableWrapper( - partValue, partitionDataTypes, ordering, structType), Seq.empty) + partitionMapping.getOrElse(internalRowComparableWrapperFactory(partValue), Seq.empty) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index b493eb704bb3..9063cb17b852 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.util.{BoundInternalRowComparableWrapper} +import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper import org.apache.spark.sql.connector.catalog.functions.Reducer import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources.v2.BatchScanExec @@ -580,20 +580,18 @@ case class EnsureRequirements( val spec = if (replicateLeftSide) rightSpec else leftSpec val partValues = spec.partitioning.originalPartitionValues val partitionDataTypes = partitionExprs.map(_.dataType) - val (structType, ordering) = - BoundInternalRowComparableWrapper.getStructTypeAndOrdering(partitionDataTypes) + val internalRowComparableWrapperFactory = + InternalRowComparableWrapper.getInternalRowComparableWrapperFactory( + partitionDataTypes) val numExpectedPartitions = partValues - .map( - new BoundInternalRowComparableWrapper( - _, partitionDataTypes, ordering, structType)) + .map(internalRowComparableWrapperFactory) .groupBy(identity) .transform((_, v) => v.size) mergedPartValues = mergedPartValues.map { case (partVal, numParts) => (partVal, numExpectedPartitions.getOrElse( - new BoundInternalRowComparableWrapper( - partVal, partitionDataTypes, ordering, structType), numParts)) + internalRowComparableWrapperFactory(partVal), numParts)) } logInfo(log"After applying partially clustered distribution, there are " + @@ -687,11 +685,12 @@ case class EnsureRequirements( reducers match { case Some(reducers) => val partitionDataTypes = expressions.map(_.dataType) - val (structType, ordering) = - BoundInternalRowComparableWrapper.getStructTypeAndOrdering(partitionDataTypes) + val internalRowComparableWrapperFactory = + InternalRowComparableWrapper.getInternalRowComparableWrapperFactory( + partitionDataTypes) partValues.map { row => KeyGroupedShuffleSpec.reducePartitionValue( - row, reducers, partitionDataTypes, ordering, structType) + row, reducers, partitionDataTypes, internalRowComparableWrapperFactory) }.distinct.map(_.row) case _ => partValues } @@ -749,22 +748,20 @@ case class EnsureRequirements( partitionExpression: Seq[Expression], joinType: JoinType): Seq[InternalRow] = { val partitionDataTypes = partitionExpression.map(_.dataType) - val (structType, ordering) = - BoundInternalRowComparableWrapper.getStructTypeAndOrdering(partitionDataTypes) + val internalRowComparableWrapperFactory = + InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(partitionDataTypes) val merged = if (SQLConf.get.getConf(SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED)) { joinType match { - 
case Inner => BoundInternalRowComparableWrapper.mergePartitions( - leftPartitioning, rightPartitioning, partitionExpression, intersect = true) - case LeftOuter => leftPartitioning.map( - new BoundInternalRowComparableWrapper(_, partitionDataTypes, ordering, structType)) - case RightOuter => rightPartitioning.map( - new BoundInternalRowComparableWrapper(_, partitionDataTypes, ordering, structType)) - case _ => BoundInternalRowComparableWrapper.mergePartitions(leftPartitioning, + case Inner => InternalRowComparableWrapper.mergePartitions( + leftPartitioning, rightPartitioning, partitionExpression) + case LeftOuter => leftPartitioning.map(internalRowComparableWrapperFactory) + case RightOuter => rightPartitioning.map(internalRowComparableWrapperFactory) + case _ => InternalRowComparableWrapper.mergePartitions(leftPartitioning, rightPartitioning, partitionExpression) } } else { - BoundInternalRowComparableWrapper.mergePartitions(leftPartitioning, rightPartitioning, + InternalRowComparableWrapper.mergePartitions(leftPartitioning, rightPartitioning, partitionExpression) } From e234fdc5a2cd218acb50da29c61f52e014f63c76 Mon Sep 17 00:00:00 2001 From: Chirag Singh Date: Mon, 17 Nov 2025 15:00:00 -0800 Subject: [PATCH 3/4] fixes --- .../catalyst/plans/physical/partitioning.scala | 18 +++++++++--------- .../util/InternalRowComparableWrapper.scala | 7 +++++++ .../execution/KeyGroupedPartitionedScan.scala | 8 ++++---- .../exchange/EnsureRequirements.scala | 9 ++++----- 4 files changed, 24 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index f07a890eee65..5a8505dc6992 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -428,13 +428,13 @@ case class KeyGroupedPartitioning( } lazy val uniquePartitionValues: Seq[InternalRow] = { - val dataTypes = expressions.map(_.dataType) val internalRowComparableFactory = - InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(dataTypes) + InternalRowComparableWrapper.getInternalRowComparableWrapperFactory( + expressions.map(_.dataType)) partitionValues - .map(internalRowComparableFactory) - .distinct - .map(_.row) + .map(internalRowComparableFactory) + .distinct + .map(_.row) } override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = @@ -451,9 +451,9 @@ object KeyGroupedPartitioning { val projectedPartitionValues = partitionValues.map(project(expressions, projectionPositions, _)) val projectedOriginalPartitionValues = originalPartitionValues.map(project(expressions, projectionPositions, _)) - val dataTypes = projectedExpressions.map(_.dataType) val internalRowComparableFactory = - InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(dataTypes) + InternalRowComparableWrapper.getInternalRowComparableWrapperFactory( + projectedExpressions.map(_.dataType)) val finalPartitionValues = projectedPartitionValues .map(internalRowComparableFactory) @@ -873,9 +873,9 @@ case class KeyGroupedShuffleSpec( // transform functions. // 4. the partition values from both sides are following the same order. 
case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution, _) => - lazy val dataTypes = partitioning.expressions.map(_.dataType) lazy val internalRowComparableFactory = - InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(dataTypes) + InternalRowComparableWrapper.getInternalRowComparableWrapperFactory( + partitioning.expressions.map(_.dataType)) distribution.clustering.length == otherDistribution.clustering.length && numPartitions == other.numPartitions && areKeysCompatible(otherSpec) && partitioning.partitionValues.zip(otherPartitioning.partitionValues).forall { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala index 8a8ef66d5c76..b9935d40ed98 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala @@ -39,6 +39,12 @@ class InternalRowComparableWrapper private ( val structType: StructType, val ordering: BaseOrdering) { + /** + * Previous constructor for binary compatibility. Prefer using + * `getInternalRowComparableWrapperFactory` for the creation of InternalRowComparableWrapper's in + * hot paths to avoid excessive cache lookups. + */ + @deprecated def this(row: InternalRow, dataTypes: Seq[DataType]) = this( row, dataTypes, @@ -120,6 +126,7 @@ object InternalRowComparableWrapper { result.toSeq } + /** Creates a shared factory method for a given row schema to avoid excessive cache lookups. */ def getInternalRowComparableWrapperFactory( dataTypes: Seq[DataType]): InternalRow => InternalRowComparableWrapper = { val structType = structTypeCache.get(dataTypes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala index 31a6cb46b721..10a6aaa2e185 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala @@ -50,9 +50,9 @@ trait KeyGroupedPartitionedScan[T] { case None => spjParams.joinKeyPositions match { case Some(projectionPositions) => - val dataTypes = expressions.map(_.dataType) val internalRowComparableWrapperFactory = - InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(dataTypes) + InternalRowComparableWrapper.getInternalRowComparableWrapperFactory( + expressions.map(_.dataType)) basePartitioning.partitionValues.map { r => val projectedRow = KeyGroupedPartitioning.project(expressions, projectionPositions, r) @@ -87,9 +87,9 @@ trait KeyGroupedPartitionedScan[T] { val (groupedPartitions, partExpressions) = spjParams.joinKeyPositions match { case Some(projectPositions) => val projectedExpressions = projectPositions.map(i => expressions(i)) - val projectedTypes = projectedExpressions.map(_.dataType) val internalRowComparableWrapperFactory = - InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(projectedTypes) + InternalRowComparableWrapper.getInternalRowComparableWrapperFactory( + projectedExpressions.map(_.dataType)) val parts = filteredPartitions.flatten.groupBy(part => { val row = partitionValueAccessor(part) val projectedRow = KeyGroupedPartitioning.project( diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 9063cb17b852..9b96cb26d41d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.exchange import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.apache.spark.internal.LogKeys +import org.apache.spark.internal.{LogKeys} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ @@ -579,10 +579,9 @@ case class EnsureRequirements( // In partially clustered distribution, we should use un-grouped partition values val spec = if (replicateLeftSide) rightSpec else leftSpec val partValues = spec.partitioning.originalPartitionValues - val partitionDataTypes = partitionExprs.map(_.dataType) val internalRowComparableWrapperFactory = InternalRowComparableWrapper.getInternalRowComparableWrapperFactory( - partitionDataTypes) + partitionExprs.map(_.dataType)) val numExpectedPartitions = partValues .map(internalRowComparableWrapperFactory) @@ -747,9 +746,9 @@ case class EnsureRequirements( rightPartitioning: Seq[InternalRow], partitionExpression: Seq[Expression], joinType: JoinType): Seq[InternalRow] = { - val partitionDataTypes = partitionExpression.map(_.dataType) val internalRowComparableWrapperFactory = - InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(partitionDataTypes) + InternalRowComparableWrapper.getInternalRowComparableWrapperFactory( + partitionExpression.map(_.dataType)) val merged = if (SQLConf.get.getConf(SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED)) { joinType match { From c5ae5b79cdf85fbb50e49160100305f94f0ad5e7 Mon Sep 17 00:00:00 2001 From: Chirag Singh Date: Tue, 18 Nov 2025 08:29:43 -0800 Subject: [PATCH 4/4] fix --- .../spark/sql/execution/exchange/EnsureRequirements.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 9b96cb26d41d..82b0d786f283 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -753,7 +753,7 @@ case class EnsureRequirements( val merged = if (SQLConf.get.getConf(SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED)) { joinType match { case Inner => InternalRowComparableWrapper.mergePartitions( - leftPartitioning, rightPartitioning, partitionExpression) + leftPartitioning, rightPartitioning, partitionExpression, intersect = true) case LeftOuter => leftPartitioning.map(internalRowComparableWrapperFactory) case RightOuter => rightPartitioning.map(internalRowComparableWrapperFactory) case _ => InternalRowComparableWrapper.mergePartitions(leftPartitioning,
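
Note on the pattern this series converges on: constructing an
InternalRowComparableWrapper previously performed two NonFateSharingCache
lookups (structTypeCache and orderingCache) per row, which is wasteful on hot
paths such as storage-partitioned join planning, where thousands of partition
values share a single schema. getInternalRowComparableWrapperFactory performs
those lookups once per schema and returns a closure that binds the results
into the private primary constructor; the old public constructor is kept only
for binary compatibility and deprecated. (Patch 4 also restores the
`intersect = true` argument for inner joins that patch 2 inadvertently
dropped from mergePartitions.) The standalone sketch below models the same
idea; Row, DataType, ComparableWrapper, and computeOrdering are simplified
stand-ins for Spark's internal types, not the real API.

  // Simplified model: bind schema-dependent state into a factory once,
  // then reuse the closure for every row instead of re-resolving per row.
  object WrapperFactorySketch {
    type DataType = String // stand-in for Spark's DataType
    type Row = Seq[Any]    // stand-in for InternalRow

    // Stands in for the expensive, cached per-schema computation.
    private def computeOrdering(dataTypes: Seq[DataType]): Ordering[Row] =
      Ordering.by((r: Row) => r.mkString("|"))

    final class ComparableWrapper(val row: Row, ordering: Ordering[Row]) {
      // Hash and equality are derived from row content, so wrappers of
      // equal rows collapse in hash-based collections.
      override def hashCode(): Int = row.mkString("|").hashCode
      override def equals(other: Any): Boolean = other match {
        case w: ComparableWrapper => ordering.compare(row, w.row) == 0
        case _ => false
      }
    }

    // One schema-level lookup, shared by every row wrapped through the closure.
    def wrapperFactory(dataTypes: Seq[DataType]): Row => ComparableWrapper = {
      val ordering = computeOrdering(dataTypes)
      row => new ComparableWrapper(row, ordering)
    }

    def main(args: Array[String]): Unit = {
      val factory = wrapperFactory(Seq("int", "int"))
      val rows: Seq[Row] = Seq(Seq(1, 2), Seq(1, 2), Seq(3, 4))
      // Duplicate partition values collapse via equals/hashCode.
      assert(rows.map(factory).toSet.size == 2)
    }
  }

In the real change, this factory (an InternalRow => InternalRowComparableWrapper)
is threaded through reducePartitionValue, the SPJ grouping paths in
KeyGroupedPartitionedScan, and the partition merging in EnsureRequirements, so
the cache lookups happen once per schema instead of once per partition value.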