EnsureRequirements.scala
@@ -70,7 +70,16 @@ case class EnsureRequirements(
       case (child, distribution) =>
         val numPartitions = distribution.requiredNumPartitions
           .getOrElse(conf.numShufflePartitions)
-        ShuffleExchangeExec(distribution.createPartitioning(numPartitions), child, shuffleOrigin)
+        distribution match {
+          case _: StatefulOpClusteredDistribution =>
+            ShuffleExchangeExec(
+              distribution.createPartitioning(numPartitions), child,
+              REQUIRED_BY_STATEFUL_OPERATOR)
+
+          case _ =>
+            ShuffleExchangeExec(
+              distribution.createPartitioning(numPartitions), child, shuffleOrigin)
+        }
     }

     // Get the indexes of children which have specified distribution requirements and need to be
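The point of tagging these exchanges is that later rules can recognize them and leave their partitioning alone: a stateful operator's partitioning must line up with the state store layout fixed at the query's first run. A minimal sketch of that consumer side (the helper name `canChangeNumPartitions` is illustrative, not part of this PR):

import org.apache.spark.sql.execution.exchange.{REQUIRED_BY_STATEFUL_OPERATOR, ShuffleOrigin}

// Hypothetical helper: a rule that rewrites shuffle partitioning (for
// example, AQE partition coalescing) could gate on the origin like this.
def canChangeNumPartitions(origin: ShuffleOrigin): Boolean = origin match {
  case REQUIRED_BY_STATEFUL_OPERATOR => false // partitioning is pinned to the state store
  case _ => true
}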
ShuffleExchangeExec.scala
@@ -177,6 +177,11 @@ case object REBALANCE_PARTITIONS_BY_NONE extends ShuffleOrigin
 // the output needs to be partitioned by the given columns.
 case object REBALANCE_PARTITIONS_BY_COL extends ShuffleOrigin

+// Indicates that the shuffle operator was added by the internal `EnsureRequirements` rule, but
+// was required by a stateful operator. The physical partitioning is static and Spark shouldn't
+// change it.
+case object REQUIRED_BY_STATEFUL_OPERATOR extends ShuffleOrigin
+
 /**
  * Performs a shuffle that will result in the desired partitioning.
  */
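A `ShuffleOrigin` is recorded on each `ShuffleExchangeExec` when the exchange is created, so any later pass can ask why a shuffle exists. A small sketch under that assumption, collecting the exchanges whose partitioning must not be altered (the helper name `pinnedShuffles` is made up for illustration):

import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.exchange.{REQUIRED_BY_STATEFUL_OPERATOR, ShuffleExchangeExec}

// Walk a physical plan and keep the shuffles introduced for stateful
// operators; their partition count must stay fixed across runs.
def pinnedShuffles(plan: SparkPlan): Seq[ShuffleExchangeExec] =
  plan.collect {
    case s: ShuffleExchangeExec if s.shuffleOrigin == REQUIRED_BY_STATEFUL_OPERATOR => s
  }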
StreamingQuerySuite.scala
@@ -43,7 +43,7 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Complete
 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
 import org.apache.spark.sql.connector.read.InputPartition
 import org.apache.spark.sql.connector.read.streaming.{Offset => OffsetV2, ReadLimit}
-import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
+import org.apache.spark.sql.execution.exchange.{REQUIRED_BY_STATEFUL_OPERATOR, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.sources.{MemorySink, TestForeachWriter}
 import org.apache.spark.sql.functions._

@@ -1448,6 +1448,28 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
     }
   }

+  test("SPARK-49905 shuffle added by stateful operator should use the shuffle origin " +
+    "`REQUIRED_BY_STATEFUL_OPERATOR`") {
+    val inputData = MemoryStream[Int]
+
+    // Use streaming aggregation as an example - all stateful operators use the same
+    // distribution, `StatefulOpClusteredDistribution`.
+    val df = inputData.toDF().groupBy("value").count()
+
+    testStream(df, OutputMode.Update())(
+      AddData(inputData, 1, 2, 3, 1, 2, 3),
+      CheckAnswer((1, 2), (2, 2), (3, 2)),
+      Execute { qe =>
+        val shuffles = qe.lastExecution.executedPlan.collect {
+          case s: ShuffleExchangeExec => s
+        }
+
+        assert(shuffles.nonEmpty, "No shuffle exchange found in the query plan")
+        assert(shuffles.head.shuffleOrigin === REQUIRED_BY_STATEFUL_OPERATOR)
+      }
+    )
+  }
+
   private def checkAppendOutputModeException(df: DataFrame): Unit = {
     withTempDir { outputDir =>
       withTempDir { checkpointDir =>
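The test drives everything through the `StreamTest` harness: `Execute` hands the running `StreamExecution` to the closure, and `lastExecution` is the plan of the most recent micro-batch. The same inspection can be done against any started query, sketched below for a spark-shell session with an active `SparkSession` named `spark` (note that `StreamingQueryWrapper` and `lastExecution` are internal APIs and may change between releases):

import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWrapper}
import org.apache.spark.sql.streaming.OutputMode
import spark.implicits._

implicit val ctx = spark.sqlContext
val input = MemoryStream[Int]
val query = input.toDF().groupBy("value").count()
  .writeStream
  .outputMode(OutputMode.Update())
  .format("noop") // discard results; only the plan is of interest
  .start()
input.addData(1, 2, 3, 1, 2, 3)
query.processAllAvailable()

// Print the origin of every shuffle in the last micro-batch's plan.
query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution
  .executedPlan.collect { case s: ShuffleExchangeExec => s }
  .foreach(s => println(s.shuffleOrigin))

query.stop()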