From 2b1589628f3ac8c74e7adfd9dd360208e6a763ea Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 12 Jul 2022 11:49:03 +0900 Subject: [PATCH] [SPARK-39748][SQL][SS] Include the origin logical plan for LogicalRDD if it comes from DataFrame --- .../scala/org/apache/spark/sql/Dataset.scala | 1 + .../spark/sql/execution/ExistingRDD.scala | 23 +++++++--- .../streaming/sources/ForeachBatchSink.scala | 22 ++++++++- .../sources/ForeachBatchSinkSuite.scala | 45 ++++++++++++++++++- 4 files changed, 83 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 39d33d80261df..f45c27d300769 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -705,6 +705,7 @@ class Dataset[T] private[sql]( LogicalRDD( logicalPlan.output, internalRdd, + None, outputPartitioning, physicalPlan.outputOrdering, isStreaming diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 1ab183fe843ff..bf9ef6991e3e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -83,10 +83,16 @@ case class ExternalRDDScanExec[T]( } } -/** Logical plan node for scanning data from an RDD of InternalRow. */ +/** + * Logical plan node for scanning data from an RDD of InternalRow. + * + * It is advised to set the field `originLogicalPlan` if the RDD is directly built from DataFrame, + * as the stats can be inherited from `originLogicalPlan`. 
+ */ case class LogicalRDD( output: Seq[Attribute], rdd: RDD[InternalRow], + originLogicalPlan: Option[LogicalPlan] = None, outputPartitioning: Partitioning = UnknownPartitioning(0), override val outputOrdering: Seq[SortOrder] = Nil, override val isStreaming: Boolean = false)(session: SparkSession) @@ -113,6 +119,7 @@ case class LogicalRDD( LogicalRDD( output.map(rewrite), rdd, + originLogicalPlan, rewrittenPartitioning, rewrittenOrdering, isStreaming @@ -121,11 +128,15 @@ case class LogicalRDD( override protected def stringArgs: Iterator[Any] = Iterator(output, isStreaming) - override def computeStats(): Statistics = Statistics( - // TODO: Instead of returning a default value here, find a way to return a meaningful size - // estimate for RDDs. See PR 1238 for more discussions. - sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) - ) + override def computeStats(): Statistics = { + originLogicalPlan.map(_.stats).getOrElse { + Statistics( + // TODO: Instead of returning a default value here, find a way to return a meaningful size + // estimate for RDDs. See PR 1238 for more discussions. + sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) + ) + } + } } /** Physical plan node for scanning data from an RDD of InternalRow. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala index 0893875aff5d5..1c6bca241af4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.streaming.sources import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.LogicalRDD import org.apache.spark.sql.execution.streaming.Sink import org.apache.spark.sql.streaming.DataStreamWriter @@ -27,11 +29,29 @@ class ForeachBatchSink[T](batchWriter: (Dataset[T], Long) => Unit, encoder: Expr override def addBatch(batchId: Long, data: DataFrame): Unit = { val rdd = data.queryExecution.toRdd + val executedPlan = data.queryExecution.executedPlan + val node = LogicalRDD( + data.schema.toAttributes, + rdd, + Some(eliminateWriteMarkerNode(data.queryExecution.analyzed)), + executedPlan.outputPartitioning, + executedPlan.outputOrdering)(data.sparkSession) implicit val enc = encoder - val ds = data.sparkSession.internalCreateDataFrame(rdd, data.schema).as[T] + val ds = Dataset.ofRows(data.sparkSession, node).as[T] batchWriter(ds, batchId) } + /** + * ForeachBatchSink reuses the logical plan of `data`, which breaks the contract of + * Sink.addBatch that `data` should only be used to "collect" the output data. + * We therefore have to eliminate the write marker node here, which is otherwise done in a + * streaming-specific optimization rule. 
+ */ + private def eliminateWriteMarkerNode(plan: LogicalPlan): LogicalPlan = plan match { + case node: WriteToMicroBatchDataSourceV1 => node.child + case node => node + } + override def toString(): String = "ForeachBatchSink" } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala index ce98e2e6a5bb6..dbac4af90c07f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala @@ -22,7 +22,8 @@ import scala.language.implicitConversions import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.encoders.RowEncoder -import org.apache.spark.sql.execution.SerializeFromObjectExec +import org.apache.spark.sql.execution.{LogicalRDD, SerializeFromObjectExec} +import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming._ @@ -185,6 +186,48 @@ class ForeachBatchSinkSuite extends StreamTest { assertPlan(mem2, dsUntyped) } + test("Leaf node of Dataset in foreachBatch should carry over origin logical plan") { + def assertPlan[T](stream: MemoryStream[Int], ds: Dataset[T]): Unit = { + var planAsserted = false + + val writer: (Dataset[T], Long) => Unit = { case (df, _) => + df.logicalPlan.collectLeaves().head match { + case l: LogicalRDD => + assert(l.originLogicalPlan.nonEmpty, "Origin logical plan should be available in " + + "LogicalRDD") + l.originLogicalPlan.get.collectLeaves().head match { + case _: StreamingDataSourceV2Relation => // pass + case p => + fail("Expect StreamingDataSourceV2Relation in the leaf node of origin " + + s"logical plan! 
Actual: $p") + } + + case p => + fail(s"Expect LogicalRDD in the leaf node of Dataset! Actual: $p") + } + planAsserted = true + } + + stream.addData(1, 2, 3, 4, 5) + + val query = ds.writeStream.trigger(Trigger.Once()).foreachBatch(writer).start() + query.awaitTermination() + + assert(planAsserted, "ForeachBatch writer should be called!") + } + + // typed + val mem = MemoryStream[Int] + val ds = mem.toDS.map(_ + 1) + assertPlan(mem, ds) + + // untyped + val mem2 = MemoryStream[Int] + val dsUntyped = mem2.toDF().selectExpr("value + 1 as value") + assertPlan(mem2, dsUntyped) + } + + // ============== Helper classes and methods ================= private class ForeachBatchTester[T: Encoder](memoryStream: MemoryStream[Int]) {