Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -428,8 +428,11 @@ case class KeyGroupedPartitioning(
}

// Distinct partition values of this partitioning, preserving first-seen order.
// The comparable-wrapper factory is resolved once up front so the per-row
// wrapping avoids repeated schema/ordering cache lookups.
// Fix: the stale pre-refactor line `.map(InternalRowComparableWrapper(_, expressions))`
// left over from the diff's old side is removed — keeping both map steps would
// wrap an already-wrapped value and not type-check.
lazy val uniquePartitionValues: Seq[InternalRow] = {
  val internalRowComparableFactory =
    InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
      expressions.map(_.dataType))
  partitionValues
    .map(internalRowComparableFactory)
    .distinct
    .map(_.row)
}
Expand All @@ -448,11 +451,14 @@ object KeyGroupedPartitioning {
val projectedPartitionValues = partitionValues.map(project(expressions, projectionPositions, _))
val projectedOriginalPartitionValues =
originalPartitionValues.map(project(expressions, projectionPositions, _))
val internalRowComparableFactory =
InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
projectedExpressions.map(_.dataType))

// Deduplicate the projected partition values through the shared wrapper factory
// (one cache lookup for the whole sequence instead of one per row).
// Fix: the diff's old-side chain (`.map(InternalRowComparableWrapper(_, projectedExpressions))
// .distinct.map(_.row)`) was fused in front of the new chain; only the
// factory-based chain is kept.
val finalPartitionValues = projectedPartitionValues
  .map(internalRowComparableFactory)
  .distinct
  .map(_.row)

KeyGroupedPartitioning(projectedExpressions, finalPartitionValues.length,
finalPartitionValues, projectedOriginalPartitionValues)
Expand Down Expand Up @@ -867,12 +873,14 @@ case class KeyGroupedShuffleSpec(
// transform functions.
// 4. the partition values from both sides are following the same order.
case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution, _) =>
  // `lazy` so the factory (two cache lookups) is only resolved when the cheaper
  // structural checks below all pass and partition values must be compared.
  lazy val internalRowComparableFactory =
    InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
      partitioning.expressions.map(_.dataType))
  // Fix: removed the stale old-side `InternalRowComparableWrapper(left, ...)
  // .equals(...)` lines that the diff paste fused into the forall body.
  distribution.clustering.length == otherDistribution.clustering.length &&
    numPartitions == other.numPartitions && areKeysCompatible(otherSpec) &&
    partitioning.partitionValues.zip(otherPartitioning.partitionValues).forall {
      case (left, right) =>
        internalRowComparableFactory(left).equals(internalRowComparableFactory(right))
    }
case ShuffleSpecCollection(specs) =>
specs.exists(isCompatibleWith)
Expand Down Expand Up @@ -957,15 +965,16 @@ case class KeyGroupedShuffleSpec(
object KeyGroupedShuffleSpec {
  /**
   * Applies per-column reducers to a partition value and wraps the reduced row for
   * comparison.
   *
   * Fix: the diff paste fused the old signature (`row, expressions, reducers`) with the
   * new one; only the new-side signature and body are kept.
   *
   * @param row the partition value to reduce
   * @param reducers per-column optional reducers; `None` leaves that column unchanged
   * @param dataTypes data types of the partition columns, used to extract column values
   * @param internalRowComparableWrapperFactory shared factory (see
   *        `InternalRowComparableWrapper.getInternalRowComparableWrapperFactory`) so hot
   *        paths avoid per-row cache lookups
   */
  def reducePartitionValue(
      row: InternalRow,
      reducers: Seq[Option[Reducer[_, _]]],
      dataTypes: Seq[DataType],
      internalRowComparableWrapperFactory: InternalRow => InternalRowComparableWrapper
    ): InternalRowComparableWrapper = {
    val partitionVals = row.toSeq(dataTypes)
    val reducedRow = partitionVals.zip(reducers).map {
      case (v, Some(reducer: Reducer[Any, Any])) => reducer.reduce(v)
      case (v, _) => v
    }.toArray
    internalRowComparableWrapperFactory(new GenericInternalRow(reducedRow))
  }
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.util
import scala.collection.mutable

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, Murmur3HashFunction, RowOrdering}
import org.apache.spark.sql.catalyst.expressions.{BaseOrdering, Expression, Murmur3HashFunction, RowOrdering}
import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition}
import org.apache.spark.sql.types.{DataType, StructField, StructType}
import org.apache.spark.util.NonFateSharingCache
Expand All @@ -33,11 +33,23 @@ import org.apache.spark.util.NonFateSharingCache
*
* @param dataTypes the data types for the row
*/
class InternalRowComparableWrapper(val row: InternalRow, val dataTypes: Seq[DataType]) {
import InternalRowComparableWrapper._
class InternalRowComparableWrapper private (
val row: InternalRow,
val dataTypes: Seq[DataType],
val structType: StructType,
val ordering: BaseOrdering) {

private val structType = structTypeCache.get(dataTypes)
private val ordering = orderingCache.get(dataTypes)
/**
 * Previous public two-argument constructor, kept for binary compatibility. Each call
 * performs two cache lookups (`structTypeCache`, `orderingCache`); prefer
 * `getInternalRowComparableWrapperFactory` in hot paths so those lookups happen once
 * per schema rather than once per wrapped row.
 */
@deprecated
def this(row: InternalRow, dataTypes: Seq[DataType]) = this(
  row,
  dataTypes,
  InternalRowComparableWrapper.structTypeCache.get(dataTypes),
  InternalRowComparableWrapper.orderingCache.get(dataTypes))

override def hashCode(): Int = Murmur3HashFunction.hash(
row,
Expand Down Expand Up @@ -96,12 +108,14 @@ object InternalRowComparableWrapper {
intersect: Boolean = false): Seq[InternalRowComparableWrapper] = {
val partitionDataTypes = partitionExpression.map(_.dataType)
val leftPartitionSet = new mutable.HashSet[InternalRowComparableWrapper]
val internalRowComparableWrapperFactory =
getInternalRowComparableWrapperFactory(partitionDataTypes)
leftPartitioning
.map(new InternalRowComparableWrapper(_, partitionDataTypes))
.map(internalRowComparableWrapperFactory)
.foreach(partition => leftPartitionSet.add(partition))
val rightPartitionSet = new mutable.HashSet[InternalRowComparableWrapper]
rightPartitioning
.map(new InternalRowComparableWrapper(_, partitionDataTypes))
.map(internalRowComparableWrapperFactory)
.foreach(partition => rightPartitionSet.add(partition))

val result = if (intersect) {
Expand All @@ -111,4 +125,12 @@ object InternalRowComparableWrapper {
}
result.toSeq
}

/**
 * Creates a shared factory for wrapping rows of the given schema, resolving the
 * cached struct type and ordering once so callers avoid per-row cache lookups.
 */
def getInternalRowComparableWrapperFactory(
    dataTypes: Seq[DataType]): InternalRow => InternalRowComparableWrapper = {
  // Hit the caches a single time; every wrapper produced by the returned
  // function reuses these resolved values.
  val cachedStructType = structTypeCache.get(dataTypes)
  val cachedOrdering = orderingCache.get(dataTypes)
  (row: InternalRow) =>
    new InternalRowComparableWrapper(row, dataTypes, cachedStructType, cachedOrdering)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,11 @@ object InternalRowComparableWrapperBenchmark extends BenchmarkBase {
val benchmark = new Benchmark("internal row comparable wrapper", partitionNum, output = output)

// Benchmarks deduplication via the shared wrapper factory.
// Fix: removed the stale old-side line
// `.map(new InternalRowComparableWrapper(_, Seq(IntegerType, IntegerType)))`
// that the diff paste fused in front of the factory-based map.
benchmark.addCase("toSet") { _ =>
  val internalRowComparableWrapperFactory =
    InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
      Seq(IntegerType, IntegerType))
  val distinct = partitions
    .map(internalRowComparableWrapperFactory)
    .toSet
  assert(distinct.size == bucketNum)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,14 @@ trait KeyGroupedPartitionedScan[T] {
}
case None =>
// Project and deduplicate partition values when join key positions are given.
// Fix: removed the stale old-side lines (`case Some(projectionPositions) =>
// basePartitioning.partitionValues.map { r =>` and
// `InternalRowComparableWrapper(projectedRow, expressions)`) fused in by the diff paste.
spjParams.joinKeyPositions match {
  case Some(projectionPositions) =>
    val internalRowComparableWrapperFactory =
      InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
        expressions.map(_.dataType))
    basePartitioning.partitionValues.map { r =>
      val projectedRow = KeyGroupedPartitioning.project(expressions,
        projectionPositions, r)
      internalRowComparableWrapperFactory(projectedRow)
    }.distinct.map(_.row)
  case _ => basePartitioning.partitionValues
}
Expand Down Expand Up @@ -83,11 +87,14 @@ trait KeyGroupedPartitionedScan[T] {
val (groupedPartitions, partExpressions) = spjParams.joinKeyPositions match {
case Some(projectPositions) =>
  val projectedExpressions = projectPositions.map(i => expressions(i))
  // Shared factory for the projected key schema — one cache lookup per grouping pass.
  val internalRowComparableWrapperFactory =
    InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
      projectedExpressions.map(_.dataType))
  // Fix: removed the stale old-side line
  // `InternalRowComparableWrapper(projectedRow, projectedExpressions)` fused in by the
  // diff paste; the factory call is the new-side grouping key.
  val parts = filteredPartitions.flatten.groupBy(part => {
    val row = partitionValueAccessor(part)
    val projectedRow = KeyGroupedPartitioning.project(
      expressions, projectPositions, row)
    internalRowComparableWrapperFactory(projectedRow)
  }).map { case (wrapper, splits) => (wrapper.row, splits) }.toSeq
  (parts, projectedExpressions)
case _ =>
Expand All @@ -99,10 +106,14 @@ trait KeyGroupedPartitionedScan[T] {
}

// Also re-group the partitions if we are reducing compatible partition expressions
val partitionDataTypes = partExpressions.map(_.dataType)
val internalRowComparableWrapperFactory =
InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(partitionDataTypes)
val finalGroupedPartitions = spjParams.reducers match {
case Some(reducers) =>
val result = groupedPartitions.groupBy { case (row, _) =>
KeyGroupedShuffleSpec.reducePartitionValue(row, partExpressions, reducers)
KeyGroupedShuffleSpec.reducePartitionValue(
row, reducers, partitionDataTypes, internalRowComparableWrapperFactory)
}.map { case (wrapper, splits) => (wrapper.row, splits.flatMap(_._2)) }.toSeq
val rowOrdering = RowOrdering.createNaturalAscendingOrdering(
partExpressions.map(_.dataType))
Expand All @@ -118,17 +129,15 @@ trait KeyGroupedPartitionedScan[T] {
// should contain.
// Map common (super-set) partition values to their expected split counts, then keep
// only the grouped partitions whose value appears in that set.
// Fix: removed the stale old-side lines (`InternalRowComparableWrapper(t._1, partExpressions)`
// and the two-line `contains(InternalRowComparableWrapper(partValues, ...))`) fused in by
// the diff paste.
val commonPartValuesMap = spjParams.commonPartitionValues
  .get
  .map(t => (internalRowComparableWrapperFactory(t._1), t._2))
  .toMap
val filteredGroupedPartitions = finalGroupedPartitions.filter {
  case (partValues, _) =>
    commonPartValuesMap.keySet.contains(internalRowComparableWrapperFactory(partValues))
}
val nestGroupedPartitions = filteredGroupedPartitions.map { case (partValue, splits) =>
// `commonPartValuesMap` should contain the part value since it's the super set.
val numSplits = commonPartValuesMap
.get(InternalRowComparableWrapper(partValue, partExpressions))
val numSplits = commonPartValuesMap.get(internalRowComparableWrapperFactory(partValue))
assert(numSplits.isDefined, s"Partition value $partValue does not exist in " +
"common partition values from Spark plan")

Expand All @@ -143,7 +152,7 @@ trait KeyGroupedPartitionedScan[T] {
// sides of a join will have the same number of partitions & splits.
splits.map(Seq(_)).padTo(numSplits.get, Seq.empty)
}
(InternalRowComparableWrapper(partValue, partExpressions), newSplits)
(internalRowComparableWrapperFactory(partValue), newSplits)
}

// Now fill missing partition keys with empty partitions
Expand All @@ -152,14 +161,14 @@ trait KeyGroupedPartitionedScan[T] {
case (partValue, numSplits) =>
// Use empty partition for those partition values that are not present.
partitionMapping.getOrElse(
InternalRowComparableWrapper(partValue, partExpressions),
internalRowComparableWrapperFactory(partValue),
Seq.fill(numSplits)(Seq.empty))
}
} else {
// either `commonPartitionValues` is not defined, or it is defined but
// `applyPartialClustering` is false.
val partitionMapping = finalGroupedPartitions.map { case (partValue, splits) =>
InternalRowComparableWrapper(partValue, partExpressions) -> splits
internalRowComparableWrapperFactory(partValue) -> splits
}.toMap

// In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there
Expand All @@ -168,8 +177,7 @@ trait KeyGroupedPartitionedScan[T] {
// partition values here so that grouped partitions won't get duplicated.
p.uniquePartitionValues.map { partValue =>
// Use empty partition for those partition values that are not present
partitionMapping.getOrElse(
InternalRowComparableWrapper(partValue, partExpressions), Seq.empty)
partitionMapping.getOrElse(internalRowComparableWrapperFactory(partValue), Seq.empty)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -579,15 +579,18 @@ case class EnsureRequirements(
// In partially clustered distribution, we should use un-grouped partition values
val spec = if (replicateLeftSide) rightSpec else leftSpec
val partValues = spec.partitioning.originalPartitionValues
val internalRowComparableWrapperFactory =
InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
partitionExprs.map(_.dataType))

// Count how many un-grouped partition values map to each distinct key, then bump the
// merged counts so both join sides agree on the number of partitions per key.
// Fix: removed the stale old-side lines using `InternalRowComparableWrapper(_, partitionExprs)`
// that the diff paste fused with the factory-based new side.
val numExpectedPartitions = partValues
  .map(internalRowComparableWrapperFactory)
  .groupBy(identity)
  .transform((_, v) => v.size)

mergedPartValues = mergedPartValues.map { case (partVal, numParts) =>
  (partVal, numExpectedPartitions.getOrElse(
    internalRowComparableWrapperFactory(partVal), numParts))
}

logInfo(log"After applying partially clustered distribution, there are " +
Expand Down Expand Up @@ -679,9 +682,15 @@ case class EnsureRequirements(
expressions: Seq[Expression],
reducers: Option[Seq[Option[Reducer[_, _]]]]) = {
// Reduce partition values when reducers are present; otherwise pass them through.
// Fix: removed the stale old-side arm (`case Some(reducers) => partValues.map { row =>
// KeyGroupedShuffleSpec.reducePartitionValue(row, expressions, reducers) } ...`) fused
// in by the diff paste; the new-side arm resolves one shared factory first.
reducers match {
  case Some(reducers) =>
    val partitionDataTypes = expressions.map(_.dataType)
    // One factory per reduction pass; reducePartitionValue then wraps each reduced
    // row without re-resolving the schema caches.
    val internalRowComparableWrapperFactory =
      InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
        partitionDataTypes)
    partValues.map { row =>
      KeyGroupedShuffleSpec.reducePartitionValue(
        row, reducers, partitionDataTypes, internalRowComparableWrapperFactory)
    }.distinct.map(_.row)
  case _ => partValues
}
}
Expand Down Expand Up @@ -737,15 +746,16 @@ case class EnsureRequirements(
rightPartitioning: Seq[InternalRow],
partitionExpression: Seq[Expression],
joinType: JoinType): Seq[InternalRow] = {
val internalRowComparableWrapperFactory =
InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
partitionExpression.map(_.dataType))

val merged = if (SQLConf.get.getConf(SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED)) {
// Choose merged partition keys by join type when partition filtering is enabled.
// Fix: removed the stale old-side LeftOuter/RightOuter arms
// (`leftPartitioning.map(InternalRowComparableWrapper(_, partitionExpression))`, etc.)
// fused in by the diff paste; the new-side arms use the shared factory.
joinType match {
  // Inner join: only keys present on both sides can produce output.
  case Inner => InternalRowComparableWrapper.mergePartitions(
    leftPartitioning, rightPartitioning, partitionExpression, intersect = true)
  // Outer joins must keep every key from the preserved side.
  case LeftOuter => leftPartitioning.map(internalRowComparableWrapperFactory)
  case RightOuter => rightPartitioning.map(internalRowComparableWrapperFactory)
  case _ => InternalRowComparableWrapper.mergePartitions(leftPartitioning,
    rightPartitioning, partitionExpression)
}
Expand Down