diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 1cbb49c7a1f7..5a8505dc6992 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -428,8 +428,11 @@ case class KeyGroupedPartitioning(
   }
 
   lazy val uniquePartitionValues: Seq[InternalRow] = {
+    val internalRowComparableFactory =
+      InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
+        expressions.map(_.dataType))
     partitionValues
-      .map(InternalRowComparableWrapper(_, expressions))
+      .map(internalRowComparableFactory)
       .distinct
       .map(_.row)
   }
@@ -448,11 +451,14 @@ object KeyGroupedPartitioning {
     val projectedPartitionValues = partitionValues.map(project(expressions, projectionPositions, _))
     val projectedOriginalPartitionValues =
       originalPartitionValues.map(project(expressions, projectionPositions, _))
+    val internalRowComparableFactory =
+      InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
+        projectedExpressions.map(_.dataType))
 
     val finalPartitionValues = projectedPartitionValues
-      .map(InternalRowComparableWrapper(_, projectedExpressions))
-      .distinct
-      .map(_.row)
+      .map(internalRowComparableFactory)
+      .distinct
+      .map(_.row)
 
     KeyGroupedPartitioning(projectedExpressions, finalPartitionValues.length,
       finalPartitionValues, projectedOriginalPartitionValues)
@@ -867,12 +873,14 @@ case class KeyGroupedShuffleSpec(
     //    transform functions.
     // 4. the partition values from both sides are following the same order.
     case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution, _) =>
+      lazy val internalRowComparableFactory =
+        InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
+          partitioning.expressions.map(_.dataType))
       distribution.clustering.length == otherDistribution.clustering.length &&
         numPartitions == other.numPartitions && areKeysCompatible(otherSpec) &&
         partitioning.partitionValues.zip(otherPartitioning.partitionValues).forall {
           case (left, right) =>
-            InternalRowComparableWrapper(left, partitioning.expressions)
-              .equals(InternalRowComparableWrapper(right, partitioning.expressions))
+            internalRowComparableFactory(left).equals(internalRowComparableFactory(right))
         }
 
     case ShuffleSpecCollection(specs) => specs.exists(isCompatibleWith)
@@ -957,15 +965,16 @@ case class KeyGroupedShuffleSpec(
 
 object KeyGroupedShuffleSpec {
   def reducePartitionValue(
       row: InternalRow,
-      expressions: Seq[Expression],
-      reducers: Seq[Option[Reducer[_, _]]]):
-    InternalRowComparableWrapper = {
-    val partitionVals = row.toSeq(expressions.map(_.dataType))
+      reducers: Seq[Option[Reducer[_, _]]],
+      dataTypes: Seq[DataType],
+      internalRowComparableWrapperFactory: InternalRow => InternalRowComparableWrapper
+  ): InternalRowComparableWrapper = {
+    val partitionVals = row.toSeq(dataTypes)
     val reducedRow = partitionVals.zip(reducers).map{
       case (v, Some(reducer: Reducer[Any, Any])) => reducer.reduce(v)
       case (v, _) => v
     }.toArray
-    InternalRowComparableWrapper(new GenericInternalRow(reducedRow), expressions)
+    internalRowComparableWrapperFactory(new GenericInternalRow(reducedRow))
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala
index ba3d65fea027..b9935d40ed98 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.util
 import scala.collection.mutable
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Expression, Murmur3HashFunction, RowOrdering}
+import org.apache.spark.sql.catalyst.expressions.{BaseOrdering, Expression, Murmur3HashFunction, RowOrdering}
 import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition}
 import org.apache.spark.sql.types.{DataType, StructField, StructType}
 import org.apache.spark.util.NonFateSharingCache
@@ -33,11 +33,23 @@ import org.apache.spark.util.NonFateSharingCache
  *
  * @param dataTypes the data types for the row
  */
-class InternalRowComparableWrapper(val row: InternalRow, val dataTypes: Seq[DataType]) {
-  import InternalRowComparableWrapper._
+class InternalRowComparableWrapper private (
+    val row: InternalRow,
+    val dataTypes: Seq[DataType],
+    val structType: StructType,
+    val ordering: BaseOrdering) {
 
-  private val structType = structTypeCache.get(dataTypes)
-  private val ordering = orderingCache.get(dataTypes)
+  /**
+   * Previous constructor for binary compatibility. Prefer using
+   * `getInternalRowComparableWrapperFactory` for the creation of InternalRowComparableWrapper's in
+   * hot paths to avoid excessive cache lookups.
+   */
+  @deprecated
+  def this(row: InternalRow, dataTypes: Seq[DataType]) = this(
+    row,
+    dataTypes,
+    InternalRowComparableWrapper.structTypeCache.get(dataTypes),
+    InternalRowComparableWrapper.orderingCache.get(dataTypes))
 
   override def hashCode(): Int = Murmur3HashFunction.hash(
     row,
@@ -96,12 +108,14 @@ object InternalRowComparableWrapper {
       intersect: Boolean = false): Seq[InternalRowComparableWrapper] = {
     val partitionDataTypes = partitionExpression.map(_.dataType)
     val leftPartitionSet = new mutable.HashSet[InternalRowComparableWrapper]
+    val internalRowComparableWrapperFactory =
+      getInternalRowComparableWrapperFactory(partitionDataTypes)
     leftPartitioning
-      .map(new InternalRowComparableWrapper(_, partitionDataTypes))
+      .map(internalRowComparableWrapperFactory)
       .foreach(partition => leftPartitionSet.add(partition))
     val rightPartitionSet = new mutable.HashSet[InternalRowComparableWrapper]
     rightPartitioning
-      .map(new InternalRowComparableWrapper(_, partitionDataTypes))
+      .map(internalRowComparableWrapperFactory)
       .foreach(partition => rightPartitionSet.add(partition))
 
     val result = if (intersect) {
@@ -111,4 +125,12 @@ object InternalRowComparableWrapper {
     }
     result.toSeq
   }
+
+  /** Creates a shared factory method for a given row schema to avoid excessive cache lookups. */
+  def getInternalRowComparableWrapperFactory(
+      dataTypes: Seq[DataType]): InternalRow => InternalRowComparableWrapper = {
+    val structType = structTypeCache.get(dataTypes)
+    val ordering = orderingCache.get(dataTypes)
+    row: InternalRow => new InternalRowComparableWrapper(row, dataTypes, structType, ordering)
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapperBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapperBenchmark.scala
index f3dd232129e8..764dac35f673 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapperBenchmark.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapperBenchmark.scala
@@ -48,8 +48,11 @@ object InternalRowComparableWrapperBenchmark extends BenchmarkBase {
     val benchmark = new Benchmark("internal row comparable wrapper", partitionNum, output = output)
 
     benchmark.addCase("toSet") { _ =>
+      val internalRowComparableWrapperFactory =
+        InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
+          Seq(IntegerType, IntegerType))
       val distinct = partitions
-        .map(new InternalRowComparableWrapper(_, Seq(IntegerType, IntegerType)))
+        .map(internalRowComparableWrapperFactory)
         .toSet
       assert(distinct.size == bucketNum)
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala
index 5a789179219a..10a6aaa2e185 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala
@@ -49,10 +49,14 @@ trait KeyGroupedPartitionedScan[T] {
         }
       case None =>
         spjParams.joinKeyPositions match {
-          case Some(projectionPositions) => basePartitioning.partitionValues.map { r =>
+          case Some(projectionPositions) =>
+            val internalRowComparableWrapperFactory =
+              InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
+                expressions.map(_.dataType))
+            basePartitioning.partitionValues.map { r =>
             val projectedRow = KeyGroupedPartitioning.project(expressions,
               projectionPositions, r)
-            InternalRowComparableWrapper(projectedRow, expressions)
+            internalRowComparableWrapperFactory(projectedRow)
           }.distinct.map(_.row)
           case _ => basePartitioning.partitionValues
         }
@@ -83,11 +87,14 @@ trait KeyGroupedPartitionedScan[T] {
     val (groupedPartitions, partExpressions) = spjParams.joinKeyPositions match {
       case Some(projectPositions) =>
         val projectedExpressions = projectPositions.map(i => expressions(i))
+        val internalRowComparableWrapperFactory =
+          InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
+            projectedExpressions.map(_.dataType))
         val parts = filteredPartitions.flatten.groupBy(part => {
           val row = partitionValueAccessor(part)
           val projectedRow = KeyGroupedPartitioning.project(
             expressions, projectPositions, row)
-          InternalRowComparableWrapper(projectedRow, projectedExpressions)
+          internalRowComparableWrapperFactory(projectedRow)
         }).map { case (wrapper, splits) => (wrapper.row, splits) }.toSeq
         (parts, projectedExpressions)
       case _ =>
@@ -99,10 +106,14 @@ trait KeyGroupedPartitionedScan[T] {
     }
 
     // Also re-group the partitions if we are reducing compatible partition expressions
+    val partitionDataTypes = partExpressions.map(_.dataType)
+    val internalRowComparableWrapperFactory =
+      InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(partitionDataTypes)
     val finalGroupedPartitions = spjParams.reducers match {
       case Some(reducers) =>
         val result = groupedPartitions.groupBy { case (row, _) =>
-          KeyGroupedShuffleSpec.reducePartitionValue(row, partExpressions, reducers)
+          KeyGroupedShuffleSpec.reducePartitionValue(
+            row, reducers, partitionDataTypes, internalRowComparableWrapperFactory)
         }.map { case (wrapper, splits) => (wrapper.row, splits.flatMap(_._2)) }.toSeq
         val rowOrdering = RowOrdering.createNaturalAscendingOrdering(
           partExpressions.map(_.dataType))
@@ -118,17 +129,15 @@ trait KeyGroupedPartitionedScan[T] {
           // should contain.
           val commonPartValuesMap = spjParams.commonPartitionValues
             .get
-            .map(t => (InternalRowComparableWrapper(t._1, partExpressions), t._2))
+            .map(t => (internalRowComparableWrapperFactory(t._1), t._2))
             .toMap
           val filteredGroupedPartitions = finalGroupedPartitions.filter {
             case (partValues, _) =>
-              commonPartValuesMap.keySet.contains(
-                InternalRowComparableWrapper(partValues, partExpressions))
+              commonPartValuesMap.keySet.contains(internalRowComparableWrapperFactory(partValues))
           }
 
           val nestGroupedPartitions = filteredGroupedPartitions.map { case (partValue, splits) =>
            // `commonPartValuesMap` should contain the part value since it's the super set.
-            val numSplits = commonPartValuesMap
-              .get(InternalRowComparableWrapper(partValue, partExpressions))
+            val numSplits = commonPartValuesMap.get(internalRowComparableWrapperFactory(partValue))
            assert(numSplits.isDefined, s"Partition value $partValue does not exist in " +
              "common partition values from Spark plan")
@@ -143,7 +152,7 @@ trait KeyGroupedPartitionedScan[T] {
              // sides of a join will have the same number of partitions & splits.
              splits.map(Seq(_)).padTo(numSplits.get, Seq.empty)
            }
-            (InternalRowComparableWrapper(partValue, partExpressions), newSplits)
+            (internalRowComparableWrapperFactory(partValue), newSplits)
          }
 
          // Now fill missing partition keys with empty partitions
@@ -152,14 +161,14 @@ trait KeyGroupedPartitionedScan[T] {
            case (partValue, numSplits) =>
              // Use empty partition for those partition values that are not present.
              partitionMapping.getOrElse(
-                InternalRowComparableWrapper(partValue, partExpressions),
+                internalRowComparableWrapperFactory(partValue),
                Seq.fill(numSplits)(Seq.empty))
          }
        } else {
          // either `commonPartitionValues` is not defined, or it is defined but
          // `applyPartialClustering` is false.
          val partitionMapping = finalGroupedPartitions.map { case (partValue, splits) =>
-            InternalRowComparableWrapper(partValue, partExpressions) -> splits
+            internalRowComparableWrapperFactory(partValue) -> splits
          }.toMap
 
          // In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there
@@ -168,8 +177,7 @@ trait KeyGroupedPartitionedScan[T] {
          // partition values here so that grouped partitions won't get duplicated.
          p.uniquePartitionValues.map { partValue =>
            // Use empty partition for those partition values that are not present
-            partitionMapping.getOrElse(
-              InternalRowComparableWrapper(partValue, partExpressions), Seq.empty)
+            partitionMapping.getOrElse(internalRowComparableWrapperFactory(partValue), Seq.empty)
          }
        }
      }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index b97d765afcf7..82b0d786f283 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -579,15 +579,18 @@ case class EnsureRequirements(
 
           // In partially clustered distribution, we should use un-grouped partition values
           val spec = if (replicateLeftSide) rightSpec else leftSpec
           val partValues = spec.partitioning.originalPartitionValues
+          val internalRowComparableWrapperFactory =
+            InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
+              partitionExprs.map(_.dataType))
           val numExpectedPartitions = partValues
-            .map(InternalRowComparableWrapper(_, partitionExprs))
+            .map(internalRowComparableWrapperFactory)
             .groupBy(identity)
             .transform((_, v) => v.size)
 
           mergedPartValues = mergedPartValues.map { case (partVal, numParts) =>
             (partVal, numExpectedPartitions.getOrElse(
-              InternalRowComparableWrapper(partVal, partitionExprs), numParts))
+              internalRowComparableWrapperFactory(partVal), numParts))
           }
 
           logInfo(log"After applying partially clustered distribution, there are " +
@@ -679,9 +682,15 @@ case class EnsureRequirements(
       expressions: Seq[Expression],
       reducers: Option[Seq[Option[Reducer[_, _]]]]) = {
     reducers match {
-      case Some(reducers) => partValues.map { row =>
-        KeyGroupedShuffleSpec.reducePartitionValue(row, expressions, reducers)
-      }.distinct.map(_.row)
+      case Some(reducers) =>
+        val partitionDataTypes = expressions.map(_.dataType)
+        val internalRowComparableWrapperFactory =
+          InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
+            partitionDataTypes)
+        partValues.map { row =>
+          KeyGroupedShuffleSpec.reducePartitionValue(
+            row, reducers, partitionDataTypes, internalRowComparableWrapperFactory)
+        }.distinct.map(_.row)
       case _ => partValues
     }
   }
@@ -737,15 +746,16 @@ case class EnsureRequirements(
       rightPartitioning: Seq[InternalRow],
       partitionExpression: Seq[Expression],
       joinType: JoinType): Seq[InternalRow] = {
+    val internalRowComparableWrapperFactory =
+      InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
+        partitionExpression.map(_.dataType))
     val merged = if (SQLConf.get.getConf(SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED)) {
       joinType match {
         case Inner => InternalRowComparableWrapper.mergePartitions(
           leftPartitioning, rightPartitioning, partitionExpression, intersect = true)
-        case LeftOuter => leftPartitioning.map(
-          InternalRowComparableWrapper(_, partitionExpression))
-        case RightOuter => rightPartitioning.map(
-          InternalRowComparableWrapper(_, partitionExpression))
+        case LeftOuter => leftPartitioning.map(internalRowComparableWrapperFactory)
+        case RightOuter => rightPartitioning.map(internalRowComparableWrapperFactory)
         case _ => InternalRowComparableWrapper.mergePartitions(leftPartitioning,
           rightPartitioning, partitionExpression)
       }
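
For context on the hot-path motivation called out in the retained constructor's scaladoc, here is a minimal usage sketch of the pattern the patch introduces: resolve the per-schema caches once via the new factory, then wrap many rows through it. Illustrative only; `FactoryUsageSketch` and `distinctPartitionValues` are hypothetical names, while `getInternalRowComparableWrapperFactory` is the method added above.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper
import org.apache.spark.sql.types.IntegerType

object FactoryUsageSketch {
  // Hypothetical helper: deduplicate partition rows with schema (int, int).
  def distinctPartitionValues(rows: Seq[InternalRow]): Seq[InternalRow] = {
    // One lookup of the cached StructType and ordering for this schema...
    val factory = InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
      Seq(IntegerType, IntegerType))
    // ...then every row is wrapped through the shared factory, instead of the
    // old per-row InternalRowComparableWrapper(_, expressions) construction.
    rows.map(factory).distinct.map(_.row)
  }
}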