[SPARK-30185][SQL] Implement Dataset.tail API #26809

Closed · wants to merge 2 commits
@@ -66,20 +66,20 @@ trait CheckAnalysis extends PredicateHelper {
case _ => None
}

private def checkLimitClause(limitExpr: Expression): Unit = {
private def checkLimitLikeClause(name: String, limitExpr: Expression): Unit = {
limitExpr match {
case e if !e.foldable => failAnalysis(
"The limit expression must evaluate to a constant value, but got " +
s"The $name expression must evaluate to a constant value, but got " +
limitExpr.sql)
case e if e.dataType != IntegerType => failAnalysis(
s"The limit expression must be integer type, but got " +
s"The $name expression must be integer type, but got " +
e.dataType.catalogString)
case e =>
e.eval() match {
case null => failAnalysis(
s"The evaluated limit expression must not be null, but got ${limitExpr.sql}")
s"The evaluated $name expression must not be null, but got ${limitExpr.sql}")
case v: Int if v < 0 => failAnalysis(
s"The limit expression must be equal to or greater than 0, but got $v")
s"The $name expression must be equal to or greater than 0, but got $v")
case _ => // OK
}
}
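For illustration (not part of the diff), a minimal sketch of how the generalized check above surfaces to a user calling the new `tail` API with a negative argument. It assumes an existing SparkSession named `spark`; the exception message mirrors the `failAnalysis` string in this hunk with `name = "tail"`:

```scala
import org.apache.spark.sql.AnalysisException

// Sketch only: assumes an existing SparkSession `spark`.
try {
  spark.range(10).tail(-1)
} catch {
  case e: AnalysisException =>
    // Expected, per the check above with name = "tail":
    // "The tail expression must be equal to or greater than 0, but got -1"
    println(e.getMessage)
}
```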
@@ -324,9 +324,11 @@ trait CheckAnalysis extends PredicateHelper {
}
}

case GlobalLimit(limitExpr, _) => checkLimitClause(limitExpr)
case GlobalLimit(limitExpr, _) => checkLimitLikeClause("limit", limitExpr)

case LocalLimit(limitExpr, _) => checkLimitClause(limitExpr)
case LocalLimit(limitExpr, _) => checkLimitLikeClause("limit", limitExpr)

case Tail(limitExpr, _) => checkLimitLikeClause("tail", limitExpr)

case _: Union | _: SetOperation if operator.children.length > 1 =>
def dataTypes(plan: LogicalPlan): Seq[DataType] = plan.output.map(_.dataType)
@@ -810,6 +810,26 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends OrderPr
}
}

/**
* This is similar to [[Limit]] except:
*
* - It does not have separate global/local plans because currently there is only a single
* implementation, which initially mimics both global and local tails. See
* `org.apache.spark.sql.execution.CollectTailExec` and
* `org.apache.spark.sql.execution.CollectLimitExec`
*
* - Currently, this plan can only be a root node.
*/
case class Tail(limitExpr: Expression, child: LogicalPlan) extends OrderPreservingUnaryNode {
override def output: Seq[Attribute] = child.output
override def maxRows: Option[Long] = {
limitExpr match {
case IntegerLiteral(limit) => Some(limit)
case _ => None
}
}
}

/**
* Aliased subquery.
*
12 changes: 12 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2793,6 +2793,18 @@ class Dataset[T] private[sql](
*/
def take(n: Int): Array[T] = head(n)

/**
* Returns the last `n` rows in the Dataset.
*
* Running tail requires moving data into the application's driver process, and doing so with
* a very large `n` can crash the driver process with OutOfMemoryError.
*
* @group action
* @since 3.0.0
*/
def tail(n: Int): Array[T] = withAction(
"tail", withTypedPlan(Tail(Literal(n), logicalPlan)).queryExecution)(collectFromPlan)

/**
* Returns the first `n` rows in the Dataset as a list.
*
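A quick usage sketch of the new API above (not part of the diff). It assumes an existing SparkSession named `spark`; the commented row shapes are illustrative:

```scala
// Sketch only — assumes an existing SparkSession `spark`.
val df = spark.range(10).toDF("id")

// tail(n) is an action: like collect()/take(n), it brings rows back to the driver.
val lastThree = df.tail(3)   // Array([7], [8], [9]) — the last 3 rows
val firstThree = df.take(3)  // Array([0], [1], [2]) — the first 3 rows, for comparison

println(lastThree.mkString(", "))
```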
@@ -77,6 +77,12 @@ case class LocalTableScanExec(
taken
}

override def executeTail(limit: Int): Array[InternalRow] = {
val taken: Seq[InternalRow] = unsafeRows.takeRight(limit)
longMetric("numOutputRows").add(taken.size)
taken.toArray
}

// Input is already UnsafeRows.
override protected val createUnsafeProjection: Boolean = false

@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.util.concurrent.atomic.AtomicInteger

import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.{ArrayBuffer, ListBuffer}

import org.apache.spark.{broadcast, SparkEnv}
import org.apache.spark.internal.Logging
@@ -309,20 +309,38 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
* UnsafeRow is highly compressible (at least 8 bytes for any column), the byte array is also
* compressed.
*/
private def getByteArrayRdd(n: Int = -1): RDD[(Long, Array[Byte])] = {
private def getByteArrayRdd(
n: Int = -1, takeFromEnd: Boolean = false): RDD[(Long, Array[Byte])] = {
execute().mapPartitionsInternal { iter =>
var count = 0
val buffer = new Array[Byte](4 << 10) // 4K
val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
val bos = new ByteArrayOutputStream()
val out = new DataOutputStream(codec.compressedOutputStream(bos))
// `iter.hasNext` may produce one row and buffer it, we should only call it when the limit is
// not hit.
while ((n < 0 || count < n) && iter.hasNext) {
val row = iter.next().asInstanceOf[UnsafeRow]
out.writeInt(row.getSizeInBytes)
row.writeToStream(out, buffer)
count += 1

if (takeFromEnd && n > 0) {
// To collect the last n rows, we have to read the whole iterator anyway while keeping
// only the last n rows. Otherwise, we cannot know where the end of the iterator is.
var last: Seq[UnsafeRow] = Seq.empty[UnsafeRow]
val slidingIter = iter.map(_.copy()).sliding(n)
while (slidingIter.hasNext) { last = slidingIter.next().asInstanceOf[Seq[UnsafeRow]] }
var i = 0
count = last.length
while (i < count) {
val row = last(i)
out.writeInt(row.getSizeInBytes)
row.writeToStream(out, buffer)
i += 1
}
} else {
// `iter.hasNext` may produce one row and buffer it, so we should only call it when the
// limit has not been hit.
while ((n < 0 || count < n) && iter.hasNext) {
val row = iter.next().asInstanceOf[UnsafeRow]
out.writeInt(row.getSizeInBytes)
row.writeToStream(out, buffer)
count += 1
}
}
out.writeInt(-1)
out.flush()
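The `takeFromEnd` branch above keeps only the last `n` rows of a one-pass iterator by sliding a window of size `n` over it and retaining the final window. A minimal, Spark-free sketch of that trick (object and method names are illustrative only):

```scala
object LastNSketch {
  // Keep only the last `n` elements of a one-pass iterator, as in the branch above.
  def lastN[T](iter: Iterator[T], n: Int): Seq[T] = {
    if (n <= 0) {
      Seq.empty[T]
    } else {
      var last: Seq[T] = Seq.empty[T]
      val windows = iter.sliding(n) // windows of up to n consecutive elements
      while (windows.hasNext) {
        last = windows.next() // only the final window survives the loop
      }
      last
    }
  }

  def main(args: Array[String]): Unit = {
    println(lastN(Iterator.range(1, 11), 3).mkString(", ")) // 8, 9, 10
    println(lastN(Iterator.range(1, 3), 5).mkString(", "))  // 1, 2 — fewer than n elements
  }
}
```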
@@ -397,14 +415,23 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
*
* This is modeled after `RDD.take` but never runs any job locally on the driver.
*/
def executeTake(n: Int): Array[InternalRow] = {
def executeTake(n: Int): Array[InternalRow] = executeTake(n, takeFromEnd = false)

/**
* Runs this query returning the last `n` rows as an array.
*
* This is modeled after `RDD.take` but never runs any job locally on the driver.
*/
def executeTail(n: Int): Array[InternalRow] = executeTake(n, takeFromEnd = true)

private def executeTake(n: Int, takeFromEnd: Boolean): Array[InternalRow] = {
if (n == 0) {
return new Array[InternalRow](0)
}

val childRDD = getByteArrayRdd(n)
val childRDD = getByteArrayRdd(n, takeFromEnd)

val buf = new ArrayBuffer[InternalRow]
val buf = if (takeFromEnd) new ListBuffer[InternalRow] else new ArrayBuffer[InternalRow]
val totalParts = childRDD.partitions.length
var partsScanned = 0
while (buf.length < n && partsScanned < totalParts) {
@@ -426,23 +453,46 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
}
}

val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)
val parts = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)
val partsToScan = if (takeFromEnd) {
// Reverse partitions to scan. So, if parts was [1, 2, 3] in 200 partitions (0 to 199),
// it becomes [198, 197, 196].
parts.map(p => (totalParts - 1) - p)
} else {
parts
}
val sc = sqlContext.sparkContext
val res = sc.runJob(childRDD, (it: Iterator[(Long, Array[Byte])]) =>
if (it.hasNext) it.next() else (0L, Array.empty[Byte]), p)
if (it.hasNext) it.next() else (0L, Array.empty[Byte]), partsToScan)

var i = 0
while (buf.length < n && i < res.length) {
val rows = decodeUnsafeRows(res(i)._2)
val rowsToTake = if (n - buf.length >= res(i)._1) {
rows.toArray
} else {
rows.take(n - buf.length).toArray

if (takeFromEnd) {
while (buf.length < n && i < res.length) {
val rows = decodeUnsafeRows(res(i)._2)
if (n - buf.length >= res(i)._1) {
buf.prepend(rows.toArray[InternalRow]: _*)
} else {
val dropUntil = res(i)._1 - (n - buf.length)
// Same as Iterator.drop, but Iterator.drop only accepts an Int whereas this count is a Long.
var j: Long = 0L
while (j < dropUntil) { rows.next(); j += 1L }
buf.prepend(rows.toArray[InternalRow]: _*)
}
i += 1
}
} else {
while (buf.length < n && i < res.length) {
val rows = decodeUnsafeRows(res(i)._2)
if (n - buf.length >= res(i)._1) {
buf ++= rows.toArray[InternalRow]
} else {
buf ++= rows.take(n - buf.length).toArray[InternalRow]
}
i += 1
}
buf ++= rowsToTake
i += 1
}
partsScanned += p.size
partsScanned += partsToScan.size
}
buf.toArray
}
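The reversed-scan logic above can be isolated into a small helper to see the index math. A Spark-free sketch (the standalone object and parameter names are illustrative only):

```scala
object ReversedPartitionsSketch {
  // Which physical partitions to scan in this round, mirroring the logic above.
  def partsToScan(
      partsScanned: Int,
      numPartsToTry: Int,
      totalParts: Int,
      takeFromEnd: Boolean): Seq[Int] = {
    val parts = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts))
    if (takeFromEnd) {
      // Scan from the last partition backwards when taking rows from the end.
      parts.map(p => (totalParts - 1) - p)
    } else {
      parts
    }
  }

  def main(args: Array[String]): Unit = {
    // With 200 partitions (0 to 199), logical parts [1, 2, 3] scanned from the end
    // map to physical partitions [198, 197, 196], matching the comment above.
    println(partsToScan(1, 3, 200, takeFromEnd = true).mkString(", "))  // 198, 197, 196
    println(partsToScan(1, 3, 200, takeFromEnd = false).mkString(", ")) // 1, 2, 3
  }
}
```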
@@ -89,6 +89,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil
case Limit(IntegerLiteral(limit), child) =>
CollectLimitExec(limit, planLater(child)) :: Nil
case Tail(IntegerLiteral(limit), child) =>
CollectTailExec(limit, planLater(child)) :: Nil
case other => planLater(other) :: Nil
}
case Limit(IntegerLiteral(limit), Sort(order, true, child))
@@ -224,6 +224,10 @@ case class AdaptiveSparkPlanExec(
getFinalPhysicalPlan().executeTake(n)
}

override def executeTail(n: Int): Array[InternalRow] = {
getFinalPhysicalPlan().executeTail(n)
}

override def doExecute(): RDD[InternalRow] = {
getFinalPhysicalPlan().execute()
}
@@ -97,6 +97,7 @@ abstract class QueryStageExec extends LeafExecNode {
override def outputOrdering: Seq[SortOrder] = plan.outputOrdering
override def executeCollect(): Array[InternalRow] = plan.executeCollect()
override def executeTake(n: Int): Array[InternalRow] = plan.executeTake(n)
override def executeTail(n: Int): Array[InternalRow] = plan.executeTail(n)
override def executeToIterator(): Iterator[InternalRow] = plan.executeToIterator()

override def doPrepare(): Unit = plan.prepare()
@@ -82,6 +82,10 @@ case class ExecutedCommandExec(cmd: RunnableCommand) extends LeafExecNode {

override def executeTake(limit: Int): Array[InternalRow] = sideEffectResult.take(limit).toArray

override def executeTail(limit: Int): Array[InternalRow] = {
sideEffectResult.takeRight(limit).toArray
}

protected override def doExecute(): RDD[InternalRow] = {
sqlContext.sparkContext.parallelize(sideEffectResult, 1)
}
@@ -119,6 +123,10 @@ case class DataWritingCommandExec(cmd: DataWritingCommand, child: SparkPlan)

override def executeTake(limit: Int): Array[InternalRow] = sideEffectResult.take(limit).toArray

override def executeTail(limit: Int): Array[InternalRow] = {
sideEffectResult.takeRight(limit).toArray
}

protected override def doExecute(): RDD[InternalRow] = {
sqlContext.sparkContext.parallelize(sideEffectResult, 1)
}
@@ -48,6 +48,8 @@ abstract class V2CommandExec extends LeafExecNode {

override def executeTake(limit: Int): Array[InternalRow] = result.take(limit).toArray

override def executeTail(limit: Int): Array[InternalRow] = result.takeRight(limit).toArray

protected override def doExecute(): RDD[InternalRow] = {
sqlContext.sparkContext.parallelize(result, 1)
}
22 changes: 22 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -65,6 +65,28 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends LimitExec {
}
}

/**
* Take the last `limit` elements and collect them to a single partition.
*
* This operator will be used when a logical `Tail` operation is the final operator in a
* logical plan, which happens when the user is collecting results back to the driver.
*/
case class CollectTailExec(limit: Int, child: SparkPlan) extends LimitExec {
override def output: Seq[Attribute] = child.output
override def outputPartitioning: Partitioning = SinglePartition
override def executeCollect(): Array[InternalRow] = child.executeTail(limit)
protected override def doExecute(): RDD[InternalRow] = {
// This is a bit of a hacky way to avoid a shuffle and scanning all the data when it is
// performed via `Dataset.tail`.
// Since this execution plan and `execute` are currently called only when
// `Dataset.tail` is invoked, the jobs are always executed when they are supposed to be.

// If this execution plan were used separately, like `Dataset.limit` without an actual
// job launch, we might just have to mimic the implementation of `CollectLimitExec`.
sparkContext.parallelize(executeCollect(), numSlices = 1)
}
}

object BaseLimitExec {
private val curId = new java.util.concurrent.atomic.AtomicInteger()

@@ -92,9 +92,10 @@ class DataFrameSuite extends QueryTest with SharedSparkSession {
assert(spark.emptyDataFrame.count() === 0)
}

test("head and take") {
test("head, take and tail") {
assert(testData.take(2) === testData.collect().take(2))
assert(testData.head(2) === testData.collect().take(2))
assert(testData.tail(2) === testData.collect().takeRight(2))
assert(testData.head(2).head.schema === testData.schema)
}

20 changes: 18 additions & 2 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -196,6 +196,11 @@ class DatasetSuite extends QueryTest with SharedSparkSession {
assert(ds.take(2) === Array(ClassData("a", 1), ClassData("b", 2)))
}

test("as case class - tail") {
val ds = Seq((1, "a"), (2, "b"), (3, "c")).toDF("b", "a").as[ClassData]
assert(ds.tail(2) === Array(ClassData("b", 2), ClassData("c", 3)))
}

test("as seq of case class - reorder fields by name") {
val df = spark.range(3).select(array(struct($"id".cast("int").as("b"), lit("a").as("a"))))
val ds = df.as[Seq[ClassData]]
@@ -1861,9 +1866,9 @@ class DatasetSuite extends QueryTest with SharedSparkSession {
}

test("groupBy.as") {
val df1 = Seq(DoubleData(1, "one"), DoubleData(2, "two"), DoubleData( 3, "three")).toDS()
val df1 = Seq(DoubleData(1, "one"), DoubleData(2, "two"), DoubleData(3, "three")).toDS()
.repartition($"id").sortWithinPartitions("id")
val df2 = Seq(DoubleData(5, "one"), DoubleData(1, "two"), DoubleData( 3, "three")).toDS()
val df2 = Seq(DoubleData(5, "one"), DoubleData(1, "two"), DoubleData(3, "three")).toDS()
.repartition($"id").sortWithinPartitions("id")

val df3 = df1.groupBy("id").as[Int, DoubleData]
@@ -1880,6 +1885,17 @@ }
}
assert(exchanges.size == 2)
}

test("tail with different numbers") {
Seq(0, 2, 5, 10, 50, 100, 1000).foreach { n =>
assert(spark.range(n).tail(6) === (math.max(n - 6, 0) until n))
}
}

test("tail should not accept minus value") {
val e = intercept[AnalysisException](spark.range(1).tail(-1))
assert(e.getMessage.contains("tail expression must be equal to or greater than 0"))
}
}

object AssertExecutionId {
@@ -34,6 +34,7 @@ class SparkPlanSuite extends QueryTest with SharedSparkSession {
intercept[IllegalStateException] { plan.executeToIterator() }
intercept[IllegalStateException] { plan.executeBroadcast() }
intercept[IllegalStateException] { plan.executeTake(1) }
intercept[IllegalStateException] { plan.executeTail(1) }
}

test("SPARK-23731 plans should be canonicalizable after being (de)serialized") {