Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
sunchao committed Jun 11, 2021
1 parent 4180692 commit 3f3ad6c
Show file tree
Hide file tree
Showing 23 changed files with 2,090 additions and 1,959 deletions.
1 change: 1 addition & 0 deletions python/pyspark/sql/tests/test_pandas_cogrouped_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ def right_assign_key(key, l, r):
.groupby('id') \
.cogroup(right.groupby('id')) \
.applyInPandas(right_assign_key, 'id long, k int, v int, key long') \
.sort(['id']) \
.toPandas()

expected = left.toPandas() if isLeft else right.toPandas()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.plans.physical

import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{DataType, IntegerType}

Expand Down Expand Up @@ -87,31 +89,6 @@ case class ClusteredDistribution(
}
}

/**
 * A [[Distribution]] whose tuples are clustered by the hash of the given `expressions`. Since the
 * hash function is fixed to `HashPartitioning.partitionIdExpression`, [[HashPartitioning]] is the
 * only partitioning able to satisfy it.
 *
 * The guarantee is strictly stronger than the one given by [[ClusteredDistribution]]: for a given
 * tuple and partition count, the partition that must hold the tuple is fully determined.
 */
case class HashClusteredDistribution(
    expressions: Seq[Expression],
    requiredNumPartitions: Option[Int] = None) extends Distribution {

  // An empty key list is rejected up front; the message points callers to AllTuples instead.
  require(
    expressions != Nil,
    "The expressions for hash of a HashClusteredDistribution should not be Nil. " +
      "An AllTuples should be used to represent a distribution that only has " +
      "a single partition.")

  /**
   * Builds the [[HashPartitioning]] over `expressions` with `numPartitions` partitions, asserting
   * first that `numPartitions` agrees with `requiredNumPartitions` when the latter is defined.
   */
  override def createPartitioning(numPartitions: Int): Partitioning = {
    val partitionCountAccepted = requiredNumPartitions.forall(_ == numPartitions)
    // The message is only evaluated when the assertion fails, i.e. when the option is defined.
    assert(partitionCountAccepted,
      s"This HashClusteredDistribution requires ${requiredNumPartitions.get} partitions, but " +
        s"the actual number of partitions is $numPartitions.")
    HashPartitioning(expressions, numPartitions)
  }
}

/**
* Represents data where tuples have been ordered according to the `ordering`
* [[Expression Expressions]]. Its requirement is defined as the following:
Expand Down Expand Up @@ -171,6 +148,24 @@ trait Partitioning {
required.requiredNumPartitions.forall(_ == numPartitions) && satisfies0(required)
}

/**
 * Tells whether this partitioning is compatible with `other`. When two [[Partitioning]]s each
 * satisfy their own required distribution (see [[satisfies]]) and are also compatible with one
 * another, their partitions are treated as co-partitioned, which lets Spark skip a data shuffle.
 *
 * If `other` is a [[PartitioningCollection]], the question is turned around and asked of each of
 * its members with the roles (and distributions) swapped; a single compatible member suffices.
 * Every other case is delegated to [[isCompatibleWith0]].
 *
 * Note for implementors: the relation is expected to be an equivalence relation, i.e. reflexive,
 * symmetric and transitive.
 */
final def isCompatibleWith(
    distribution: Distribution,
    other: Partitioning,
    otherDistribution: Distribution): Boolean = {
  other match {
    case collection: PartitioningCollection =>
      collection.partitionings.exists { member =>
        member.isCompatibleWith(otherDistribution, this, distribution)
      }
    case single =>
      isCompatibleWith0(distribution, single, otherDistribution)
  }
}

/**
* The actual method that defines whether this [[Partitioning]] can satisfy the given
* [[Distribution]], after the `numPartitions` check.
Expand All @@ -184,6 +179,15 @@ trait Partitioning {
case AllTuples => numPartitions == 1
case _ => false
}

/**
 * The hook that actually decides whether this [[Partitioning]] is compatible with `other`; it is
 * what [[isCompatibleWith]] delegates to. By default it always returns false.
 */
protected def isCompatibleWith0(
    distribution: Distribution,
    other: Partitioning,
    otherDistribution: Distribution): Boolean = {
  // Nothing is considered compatible unless a concrete partitioning overrides this method.
  false
}
}

/**
 * A [[Partitioning]] that records only the number of partitions. It overrides nothing from the
 * trait, so it relies entirely on the inherited default behaviour.
 */
case class UnknownPartitioning(numPartitions: Int) extends Partitioning
Expand All @@ -202,6 +206,14 @@ case object SinglePartition extends Partitioning {
case _: BroadcastDistribution => false
case _ => true
}

/**
 * A single partition is compatible only with another [[SinglePartition]]; the distributions
 * passed in play no part in the decision.
 */
override def isCompatibleWith0(
    distribution: Distribution,
    other: Partitioning,
    otherDistribution: Distribution): Boolean = {
  SinglePartition == other
}
}

/**
Expand All @@ -219,17 +231,52 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
/**
 * Accepts `required` when the inherited check already does, or when:
 *  - `required` is a [[HashClusteredDistribution]] whose expressions match this partitioning's
 *    expressions one by one, in order and in number (compared with `semanticEquals`);
 *  - `required` is a [[ClusteredDistribution]] and every partitioning expression is semantically
 *    equal to at least one of the required clustering expressions.
 */
override def satisfies0(required: Distribution): Boolean = {
  def satisfiedByHashing: Boolean = required match {
    case hashed: HashClusteredDistribution =>
      // Same length and pairwise semantic equality.
      expressions.corresponds(hashed.expressions) { (mine, theirs) =>
        mine.semanticEquals(theirs)
      }
    case ClusteredDistribution(clusterKeys, _) =>
      expressions.forall { partKey =>
        clusterKeys.exists(clusterKey => clusterKey.semanticEquals(partKey))
      }
    case _ =>
      false
  }

  super.satisfies0(required) || satisfiedByHashing
}

/**
 * Decides whether this [[HashPartitioning]] is compatible with `other`.
 *
 * The answer can only be true when both `distribution` and `otherDistribution` are
 * [[ClusteredDistribution]]s and `other` is itself a [[HashPartitioning]]. In that case the two
 * partitionings must have the same number of expressions, and each partitioning expression must
 * occupy the same positions within its own distribution's clustering keys on both sides.
 *
 * NOTE(review): `numPartitions` of the two partitionings is not compared here; presumably that
 * is validated elsewhere -- confirm.
 *
 * @param distribution the distribution required of this partitioning
 * @param other the partitioning to compare against
 * @param otherDistribution the distribution required of `other`
 * @return true if the two sides line up as described above, false otherwise
 */
override def isCompatibleWith0(
    distribution: Distribution,
    other: Partitioning,
    otherDistribution: Distribution): Boolean = (distribution, otherDistribution) match {
  case (thisDist: ClusteredDistribution, thatDist: ClusteredDistribution) =>
    // For each expression in the `HashPartitioning` that has occurrences in
    // `ClusteredDistribution`, returns a mapping from its index in the partitioning to the
    // indexes where it appears in the distribution.
    // For instance, if `partitioning` is `[a, b]` and `distribution` is `[a, a, b]`, then the
    // result mapping could be `{ 0 -> (0, 1), 1 -> (2) }`.
    def indexMap(
        distribution: ClusteredDistribution,
        partitioning: HashPartitioning): mutable.Map[Int, mutable.BitSet] = {
      val result = mutable.Map.empty[Int, mutable.BitSet]
      val expressionToIndex = partitioning.expressions.zipWithIndex.toMap
      distribution.clustering.zipWithIndex.foreach { case (distKey, distKeyIdx) =>
        // Distribution keys with no semantically equal partitioning expression are ignored.
        expressionToIndex
          .find { case (partKey, _) => partKey.semanticEquals(distKey) }
          .foreach { case (_, partIdx) =>
            result.getOrElseUpdate(partIdx, mutable.BitSet.empty).add(distKeyIdx)
          }
      }
      result
    }

    other match {
      case that: HashPartitioning =>
        // we need to check:
        //  1. both partitioning have the same number of expressions
        //  2. each corresponding expression in both partitioning is used in the same positions
        //     of the corresponding distribution.
        this.expressions.length == that.expressions.length &&
          indexMap(thisDist, this) == indexMap(thatDist, that)
      case _ =>
        false
    }
  case _ =>
    false
}

/**
* Returns an expression that will produce a valid partition ID(i.e. non-negative and is less
* than numPartitions) based on hashing expressions.
Expand Down Expand Up @@ -330,6 +377,12 @@ case class PartitioningCollection(partitionings: Seq[Partitioning])
/**
 * A collection satisfies `required` as soon as one of its member partitionings does.
 */
override def satisfies0(required: Distribution): Boolean = {
  partitionings.exists { member => member.satisfies(required) }
}

/**
 * A collection is compatible with `other` as soon as one of its member partitionings is, using
 * the same `distribution` and `otherDistribution` for every member.
 */
override def isCompatibleWith0(
    distribution: Distribution,
    other: Partitioning,
    otherDistribution: Distribution): Boolean = {
  partitionings.exists { member =>
    member.isCompatibleWith(distribution, other, otherDistribution)
  }
}

/**
 * Renders the member partitionings joined by " or " and wrapped in parentheses.
 */
override def toString: String = {
  val rendered = partitionings.map(member => member.toString)
  rendered.mkString("(", " or ", ")")
}
Expand Down

0 comments on commit 3f3ad6c

Please sign in to comment.