Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,7 @@ class Dataset[T] private[sql](
LogicalRDD(
logicalPlan.output,
internalRdd,
None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not set the logical plan here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That was a silly omission on my part. Thanks for catching it!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

outputPartitioning,
physicalPlan.outputOrdering,
isStreaming
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,16 @@ case class ExternalRDDScanExec[T](
}
}

/** Logical plan node for scanning data from an RDD of InternalRow. */
/**
* Logical plan node for scanning data from an RDD of InternalRow.
*
* It is advised to set the field `originLogicalPlan` if the RDD is built directly from a
* DataFrame, so that the statistics can be inherited from `originLogicalPlan`.
*/
case class LogicalRDD(
output: Seq[Attribute],
rdd: RDD[InternalRow],
originLogicalPlan: Option[LogicalPlan] = None,
outputPartitioning: Partitioning = UnknownPartitioning(0),
override val outputOrdering: Seq[SortOrder] = Nil,
override val isStreaming: Boolean = false)(session: SparkSession)
Expand All @@ -113,6 +119,7 @@ case class LogicalRDD(
LogicalRDD(
output.map(rewrite),
rdd,
originLogicalPlan,
rewrittenPartitioning,
rewrittenOrdering,
isStreaming
Expand All @@ -121,11 +128,15 @@ case class LogicalRDD(

// Limit the string representation to `output` and `isStreaming`; the RDD and physical
// properties are omitted (NOTE(review): presumably to keep plan strings concise and
// deterministic — confirm against TreeNode.simpleString usage).
override protected def stringArgs: Iterator[Any] = Iterator(output, isStreaming)

override def computeStats(): Statistics = Statistics(
// TODO: Instead of returning a default value here, find a way to return a meaningful size
// estimate for RDDs. See PR 1238 for more discussions.
sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes)
)
/**
 * Statistics for this RDD scan: inherited from `originLogicalPlan` when one was
 * provided, otherwise a conservative session-wide default size estimate.
 */
override def computeStats(): Statistics = originLogicalPlan match {
  case Some(origin) =>
    // The RDD was materialized from a known logical plan; reuse its statistics.
    origin.stats
  case None =>
    // TODO: Instead of returning a default value here, find a way to return a meaningful size
    // estimate for RDDs. See PR 1238 for more discussions.
    Statistics(sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes))
}
}

/** Physical plan node for scanning data from an RDD of InternalRow. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.streaming.sources

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.LogicalRDD
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.streaming.DataStreamWriter

Expand All @@ -27,11 +29,29 @@ class ForeachBatchSink[T](batchWriter: (Dataset[T], Long) => Unit, encoder: Expr

/**
 * Wraps the micro-batch output in a fresh `LogicalRDD` (carrying over the analyzed
 * plan as `originLogicalPlan`, minus the write-marker node) and hands the resulting
 * Dataset to the user's batch writer function.
 */
override def addBatch(batchId: Long, data: DataFrame): Unit = {
  val execution = data.queryExecution
  // Preserve the physical properties of the executed batch so the rebuilt Dataset
  // does not lose partitioning/ordering information.
  val physical = execution.executedPlan
  val originPlan = eliminateWriteMarkerNode(execution.analyzed)
  val logicalRdd = LogicalRDD(
    data.schema.toAttributes,
    execution.toRdd,
    Some(originPlan),
    physical.outputPartitioning,
    physical.outputOrdering)(data.sparkSession)
  implicit val enc = encoder
  val batchDs = Dataset.ofRows(data.sparkSession, logicalRdd).as[T]
  batchWriter(batchDs, batchId)
}

/**
 * Strips the topmost `WriteToMicroBatchDataSourceV1` marker node from `plan`, if present;
 * otherwise returns `plan` unchanged.
 *
 * ForeachBatchSink reuses the logical plan of `data`, which bends the contract of
 * `Sink.addBatch` (where `data` is meant to be used only to collect the output data).
 * The marker node is normally removed by a streaming-specific optimization rule, so it
 * must also be eliminated here before the plan is carried over as `originLogicalPlan`.
 */
private def eliminateWriteMarkerNode(plan: LogicalPlan): LogicalPlan = plan match {
  case node: WriteToMicroBatchDataSourceV1 => node.child
  case node => node
}

// Fixed display name for this sink.
override def toString(): String = "ForeachBatchSink"
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ import scala.language.implicitConversions

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.execution.SerializeFromObjectExec
import org.apache.spark.sql.execution.{LogicalRDD, SerializeFromObjectExec}
import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.functions._
import org.apache.spark.sql.streaming._
Expand Down Expand Up @@ -185,6 +186,48 @@ class ForeachBatchSinkSuite extends StreamTest {
assertPlan(mem2, dsUntyped)
}

test("Leaf node of Dataset in foreachBatch should carry over origin logical plan") {
  /**
   * Runs `ds` through foreachBatch once and asserts that the Dataset handed to the
   * writer is backed by a `LogicalRDD` whose `originLogicalPlan` has the streaming
   * source relation as its leaf.
   */
  def assertPlan[T](stream: MemoryStream[Int], ds: Dataset[T]): Unit = {
    // The writer runs on the streaming query's thread while the final assert runs on
    // the test thread; mark the flag @volatile so the write is guaranteed to be
    // visible after awaitTermination().
    @volatile var planAsserted = false

    val writer: (Dataset[T], Long) => Unit = { case (df, _) =>
      df.logicalPlan.collectLeaves().head match {
        case l: LogicalRDD =>
          assert(l.originLogicalPlan.nonEmpty, "Origin logical plan should be available in " +
            "LogicalRDD")
          l.originLogicalPlan.get.collectLeaves().head match {
            case _: StreamingDataSourceV2Relation => // pass
            case p =>
              fail("Expect StreamingDataSourceV2Relation in the leaf node of origin " +
                s"logical plan! Actual: $p")
          }

        case p =>
          fail(s"Expect LogicalRDD in the leaf node of Dataset! Actual: $p")
      }
      planAsserted = true
    }

    stream.addData(1, 2, 3, 4, 5)

    // Trigger.Once() processes exactly one micro-batch, so the writer fires once.
    val query = ds.writeStream.trigger(Trigger.Once()).foreachBatch(writer).start()
    query.awaitTermination()

    assert(planAsserted, "ForeachBatch writer should be called!")
  }

  // typed
  val mem = MemoryStream[Int]
  val ds = mem.toDS.map(_ + 1)
  assertPlan(mem, ds)

  // untyped
  val mem2 = MemoryStream[Int]
  val dsUntyped = mem2.toDF().selectExpr("value + 1 as value")
  assertPlan(mem2, dsUntyped)
}


// ============== Helper classes and methods =================

private class ForeachBatchTester[T: Encoder](memoryStream: MemoryStream[Int]) {
Expand Down