# [SPARK-30185][SQL] Implement Dataset.tail API
### What changes were proposed in this pull request?

This PR proposes a `tail` API.

For example:

```scala
scala> spark.range(10).head(5)
res1: Array[Long] = Array(0, 1, 2, 3, 4)
scala> spark.range(10).tail(5)
res2: Array[Long] = Array(5, 6, 7, 8, 9)
```

The implementation is similar to `head`, but reversed:

1. Run a job against the last partition and collect rows. If this is enough, return them as is.
2. If it is not enough, calculate the number of additional partitions to select based on `spark.sql.limit.scaleUpFactor`.
3. Run more jobs against as many additional partitions as calculated in step 2, scanning partitions in reverse order compared to `head`.
4. Repeat from step 2 until enough rows are collected or all partitions have been scanned (a simplified sketch of this loop follows below).

**Note that** we do not guarantee the natural ordering of a DataFrame in general - there are cases where it is deterministic and cases where it is not. We should probably document this caveat separately.
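
The following is a minimal, driver-side sketch of that loop on plain Scala collections, for illustration only - the names `tailSketch`, `partitions`, and `scaleUpFactor` are placeholders, and the real logic lives in `SparkPlan.executeTail`:

```scala
// Illustrative sketch only: scan partitions from the end, scaling up how many
// partitions are tried per round, and prepend so the original order is preserved.
def tailSketch(partitions: Seq[Seq[Long]], n: Int, scaleUpFactor: Int = 4): Seq[Long] = {
  val totalParts = partitions.length
  var buf = Seq.empty[Long]
  var partsScanned = 0
  var numPartsToTry = 1
  while (buf.length < n && partsScanned < totalParts) {
    if (partsScanned > 0) {
      // Step 2: not enough rows yet, so try more partitions in the next round.
      numPartsToTry = partsScanned * scaleUpFactor
    }
    // Step 3: pick the next batch of partition indices, counted from the end.
    val partsToScan = (partsScanned until math.min(partsScanned + numPartsToTry, totalParts))
      .map(p => (totalParts - 1) - p)
    for (p <- partsToScan if buf.length < n) {
      // Keep only the last rows of each partition and prepend them.
      buf = partitions(p).takeRight(n - buf.length) ++ buf
    }
    partsScanned += partsToScan.size
  }
  buf
}

// tailSketch(Seq(Seq(0L, 1L, 2L), Seq(3L, 4L), Seq(5L), Seq(6L, 7L, 8L, 9L)), 5)
// results in 5, 6, 7, 8, 9
```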

### Why are the changes needed?

Many other systems support taking data from the end, for instance pandas [1] and Python [2][3]. The Scala collections API also has `head` and `tail`.

On the other hand, Spark only provides a way to take data from the start (e.g., `DataFrame.head`).

This has been requested multiple times on the Spark user mailing list [4], StackOverflow [5][6], JIRA [7], and in third-party projects such as Koalas [8]. In addition, this missing API is mentioned explicitly from time to time in comparisons with other systems [9].

It seems we are missing a non-trivial use case in Spark, which motivated me to propose this API.
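
For context, a rough sketch of what users do today versus the proposed API (this assumes a spark-shell session with an active `spark`; the workaround also assumes an explicit ordering column, since the natural order is not guaranteed):

```scala
import org.apache.spark.sql.functions.col

// Common workaround before this API: sort descending, limit, then restore the order.
val lastFiveOld = spark.range(10).orderBy(col("id").desc).limit(5).orderBy(col("id")).collect()

// With the proposed API:
val lastFiveNew = spark.range(10).tail(5)  // Array(5, 6, 7, 8, 9)
```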

[1] https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.tail.html?highlight=tail#pandas.DataFrame.tail
[2] https://stackoverflow.com/questions/10532473/head-and-tail-in-one-line
[3] https://stackoverflow.com/questions/646644/how-to-get-last-items-of-a-list-in-python
[4] http://apache-spark-user-list.1001560.n3.nabble.com/RDD-tail-td4217.html
[5] https://stackoverflow.com/questions/39544796/how-to-select-last-row-and-also-how-to-access-pyspark-dataframe-by-index
[6] https://stackoverflow.com/questions/45406762/how-to-get-the-last-row-from-dataframe
[7] https://issues.apache.org/jira/browse/SPARK-26433
[8] databricks/koalas#343
[9] https://medium.com/chris_bour/6-differences-between-pandas-and-spark-dataframes-1380cec394d2

### Does this PR introduce any user-facing change?

No (this only adds a new API).

### How was this patch tested?

Unit tests were added, and the change was also tested manually.

Closes #26809 from HyukjinKwon/wip-tail.

Authored-by: HyukjinKwon <gurwls223@apache.org>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
HyukjinKwon committed Dec 30, 2019
1 parent f0fbbf0 commit 7079e87
Showing 14 changed files with 180 additions and 33 deletions.
@@ -66,20 +66,20 @@ trait CheckAnalysis extends PredicateHelper {
case _ => None
}

private def checkLimitClause(limitExpr: Expression): Unit = {
private def checkLimitLikeClause(name: String, limitExpr: Expression): Unit = {
limitExpr match {
case e if !e.foldable => failAnalysis(
"The limit expression must evaluate to a constant value, but got " +
s"The $name expression must evaluate to a constant value, but got " +
limitExpr.sql)
case e if e.dataType != IntegerType => failAnalysis(
s"The limit expression must be integer type, but got " +
s"The $name expression must be integer type, but got " +
e.dataType.catalogString)
case e =>
e.eval() match {
case null => failAnalysis(
s"The evaluated limit expression must not be null, but got ${limitExpr.sql}")
s"The evaluated $name expression must not be null, but got ${limitExpr.sql}")
case v: Int if v < 0 => failAnalysis(
s"The limit expression must be equal to or greater than 0, but got $v")
s"The $name expression must be equal to or greater than 0, but got $v")
case _ => // OK
}
}
@@ -324,9 +324,11 @@ trait CheckAnalysis extends PredicateHelper {
}
}

case GlobalLimit(limitExpr, _) => checkLimitClause(limitExpr)
case GlobalLimit(limitExpr, _) => checkLimitLikeClause("limit", limitExpr)

case LocalLimit(limitExpr, _) => checkLimitClause(limitExpr)
case LocalLimit(limitExpr, _) => checkLimitLikeClause("limit", limitExpr)

case Tail(limitExpr, _) => checkLimitLikeClause("tail", limitExpr)

case _: Union | _: SetOperation if operator.children.length > 1 =>
def dataTypes(plan: LogicalPlan): Seq[DataType] = plan.output.map(_.dataType)
@@ -810,6 +810,26 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends OrderPr
}
}

/**
* This is similar to [[Limit]] except:
*
* - It does not have separate plans for global/local because currently there is only a single
* implementation, which initially mimics both global and local tails. See
* `org.apache.spark.sql.execution.CollectTailExec` and
* `org.apache.spark.sql.execution.CollectLimitExec`
*
* - Currently, this plan can only be a root node.
*/
case class Tail(limitExpr: Expression, child: LogicalPlan) extends OrderPreservingUnaryNode {
override def output: Seq[Attribute] = child.output
override def maxRows: Option[Long] = {
limitExpr match {
case IntegerLiteral(limit) => Some(limit)
case _ => None
}
}
}

/**
* Aliased subquery.
*
12 changes: 12 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2793,6 +2793,18 @@ class Dataset[T] private[sql](
*/
def take(n: Int): Array[T] = head(n)

/**
* Returns the last `n` rows in the Dataset.
*
* Running tail requires moving data into the application's driver process, and doing so with
* a very large `n` can crash the driver process with OutOfMemoryError.
*
* @group action
* @since 3.0.0
*/
def tail(n: Int): Array[T] = withAction(
"tail", withTypedPlan(Tail(Literal(n), logicalPlan)).queryExecution)(collectFromPlan)

/**
* Returns the first `n` rows in the Dataset as a list.
*
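
To make the ordering caveat concrete, here is a small spark-shell sketch (illustrative only, not part of the diff; it assumes an active `spark` session):

```scala
import org.apache.spark.sql.functions.col

spark.range(10).tail(3)                         // Array(7, 8, 9): range has a deterministic order
spark.range(10).orderBy(col("id").desc).tail(3) // Array(2, 1, 0): last rows of an explicit sort

// After a shuffle such as repartition(), row order is not guaranteed,
// so the result of tail() is not deterministic.
spark.range(10).repartition(4).tail(3)
```
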
@@ -77,6 +77,12 @@ case class LocalTableScanExec(
taken
}

override def executeTail(limit: Int): Array[InternalRow] = {
val taken: Seq[InternalRow] = unsafeRows.takeRight(limit)
longMetric("numOutputRows").add(taken.size)
taken.toArray
}

// Input is already UnsafeRows.
override protected val createUnsafeProjection: Boolean = false

@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.util.concurrent.atomic.AtomicInteger

import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.{ArrayBuffer, ListBuffer}

import org.apache.spark.{broadcast, SparkEnv}
import org.apache.spark.internal.Logging
@@ -309,20 +309,38 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
* UnsafeRow is highly compressible (at least 8 bytes for any column), the byte array is also
* compressed.
*/
private def getByteArrayRdd(n: Int = -1): RDD[(Long, Array[Byte])] = {
private def getByteArrayRdd(
n: Int = -1, takeFromEnd: Boolean = false): RDD[(Long, Array[Byte])] = {
execute().mapPartitionsInternal { iter =>
var count = 0
val buffer = new Array[Byte](4 << 10) // 4K
val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
val bos = new ByteArrayOutputStream()
val out = new DataOutputStream(codec.compressedOutputStream(bos))
// `iter.hasNext` may produce one row and buffer it, we should only call it when the limit is
// not hit.
while ((n < 0 || count < n) && iter.hasNext) {
val row = iter.next().asInstanceOf[UnsafeRow]
out.writeInt(row.getSizeInBytes)
row.writeToStream(out, buffer)
count += 1

if (takeFromEnd && n > 0) {
// To collect the last n rows, we have to read the whole iterator anyway while keeping
// only the last n rows seen so far; otherwise we cannot tell where the end of the iterator is.
var last: Seq[UnsafeRow] = Seq.empty[UnsafeRow]
val slidingIter = iter.map(_.copy()).sliding(n)
while (slidingIter.hasNext) { last = slidingIter.next().asInstanceOf[Seq[UnsafeRow]] }
var i = 0
count = last.length
while (i < count) {
val row = last(i)
out.writeInt(row.getSizeInBytes)
row.writeToStream(out, buffer)
i += 1
}
} else {
// `iter.hasNext` may produce one row and buffer it, we should only call it when the
// limit is not hit.
while ((n < 0 || count < n) && iter.hasNext) {
val row = iter.next().asInstanceOf[UnsafeRow]
out.writeInt(row.getSizeInBytes)
row.writeToStream(out, buffer)
count += 1
}
}
out.writeInt(-1)
out.flush()
@@ -397,14 +415,23 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
*
* This is modeled after `RDD.take` but never runs any job locally on the driver.
*/
def executeTake(n: Int): Array[InternalRow] = {
def executeTake(n: Int): Array[InternalRow] = executeTake(n, takeFromEnd = false)

/**
* Runs this query returning the last `n` rows as an array.
*
* This is modeled after `RDD.take` but never runs any job locally on the driver.
*/
def executeTail(n: Int): Array[InternalRow] = executeTake(n, takeFromEnd = true)

private def executeTake(n: Int, takeFromEnd: Boolean): Array[InternalRow] = {
if (n == 0) {
return new Array[InternalRow](0)
}

val childRDD = getByteArrayRdd(n)
val childRDD = getByteArrayRdd(n, takeFromEnd)

val buf = new ArrayBuffer[InternalRow]
val buf = if (takeFromEnd) new ListBuffer[InternalRow] else new ArrayBuffer[InternalRow]
val totalParts = childRDD.partitions.length
var partsScanned = 0
while (buf.length < n && partsScanned < totalParts) {
@@ -426,23 +453,46 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
}
}

val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)
val parts = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)
val partsToScan = if (takeFromEnd) {
// Reverse partitions to scan. So, if parts was [1, 2, 3] in 200 partitions (0 to 199),
// it becomes [198, 197, 196].
parts.map(p => (totalParts - 1) - p)
} else {
parts
}
val sc = sqlContext.sparkContext
val res = sc.runJob(childRDD, (it: Iterator[(Long, Array[Byte])]) =>
if (it.hasNext) it.next() else (0L, Array.empty[Byte]), p)
if (it.hasNext) it.next() else (0L, Array.empty[Byte]), partsToScan)

var i = 0
while (buf.length < n && i < res.length) {
val rows = decodeUnsafeRows(res(i)._2)
val rowsToTake = if (n - buf.length >= res(i)._1) {
rows.toArray
} else {
rows.take(n - buf.length).toArray

if (takeFromEnd) {
while (buf.length < n && i < res.length) {
val rows = decodeUnsafeRows(res(i)._2)
if (n - buf.length >= res(i)._1) {
buf.prepend(rows.toArray[InternalRow]: _*)
} else {
val dropUntil = res(i)._1 - (n - buf.length)
// Same as Iterator.drop but this only takes a long.
var j: Long = 0L
while (j < dropUntil) { rows.next(); j += 1L}
buf.prepend(rows.toArray[InternalRow]: _*)
}
i += 1
}
} else {
while (buf.length < n && i < res.length) {
val rows = decodeUnsafeRows(res(i)._2)
if (n - buf.length >= res(i)._1) {
buf ++= rows.toArray[InternalRow]
} else {
buf ++= rows.take(n - buf.length).toArray[InternalRow]
}
i += 1
}
buf ++= rowsToTake
i += 1
}
partsScanned += p.size
partsScanned += partsToScan.size
}
buf.toArray
}
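
As an aside, the partition-local part of `getByteArrayRdd` keeps only the last `n` rows of each partition's iterator via `sliding(n)`; a standalone illustration of that trick (plain Scala, not part of the diff):

```scala
// Slide a window of size n over the iterator; the final window is the tail.
val n = 3
val iter = (0 until 10).iterator

var last: Seq[Int] = Seq.empty[Int]
val slidingIter = iter.sliding(n)
while (slidingIter.hasNext) {
  last = slidingIter.next()
}
// last == Seq(7, 8, 9); with fewer than n elements the single partial
// window holds everything, and an empty iterator leaves `last` empty.
```
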
@@ -89,6 +89,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil
case Limit(IntegerLiteral(limit), child) =>
CollectLimitExec(limit, planLater(child)) :: Nil
case Tail(IntegerLiteral(limit), child) =>
CollectTailExec(limit, planLater(child)) :: Nil
case other => planLater(other) :: Nil
}
case Limit(IntegerLiteral(limit), Sort(order, true, child))
@@ -224,6 +224,10 @@ case class AdaptiveSparkPlanExec(
getFinalPhysicalPlan().executeTake(n)
}

override def executeTail(n: Int): Array[InternalRow] = {
getFinalPhysicalPlan().executeTail(n)
}

override def doExecute(): RDD[InternalRow] = {
getFinalPhysicalPlan().execute()
}
@@ -97,6 +97,7 @@ abstract class QueryStageExec extends LeafExecNode {
override def outputOrdering: Seq[SortOrder] = plan.outputOrdering
override def executeCollect(): Array[InternalRow] = plan.executeCollect()
override def executeTake(n: Int): Array[InternalRow] = plan.executeTake(n)
override def executeTail(n: Int): Array[InternalRow] = plan.executeTail(n)
override def executeToIterator(): Iterator[InternalRow] = plan.executeToIterator()

override def doPrepare(): Unit = plan.prepare()
@@ -82,6 +82,10 @@ case class ExecutedCommandExec(cmd: RunnableCommand) extends LeafExecNode {

override def executeTake(limit: Int): Array[InternalRow] = sideEffectResult.take(limit).toArray

override def executeTail(limit: Int): Array[InternalRow] = {
sideEffectResult.takeRight(limit).toArray
}

protected override def doExecute(): RDD[InternalRow] = {
sqlContext.sparkContext.parallelize(sideEffectResult, 1)
}
@@ -119,6 +123,10 @@ case class DataWritingCommandExec(cmd: DataWritingCommand, child: SparkPlan)

override def executeTake(limit: Int): Array[InternalRow] = sideEffectResult.take(limit).toArray

override def executeTail(limit: Int): Array[InternalRow] = {
sideEffectResult.takeRight(limit).toArray
}

protected override def doExecute(): RDD[InternalRow] = {
sqlContext.sparkContext.parallelize(sideEffectResult, 1)
}
@@ -48,6 +48,8 @@ abstract class V2CommandExec extends LeafExecNode {

override def executeTake(limit: Int): Array[InternalRow] = result.take(limit).toArray

override def executeTail(limit: Int): Array[InternalRow] = result.takeRight(limit).toArray

protected override def doExecute(): RDD[InternalRow] = {
sqlContext.sparkContext.parallelize(result, 1)
}
22 changes: 22 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -65,6 +65,28 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends LimitExec {
}
}

/**
* Take the last `limit` elements and collect them to a single partition.
*
* This operator will be used when a logical `Tail` operation is the final operator in a
* logical plan, which happens when the user is collecting results back to the driver.
*/
case class CollectTailExec(limit: Int, child: SparkPlan) extends LimitExec {
override def output: Seq[Attribute] = child.output
override def outputPartitioning: Partitioning = SinglePartition
override def executeCollect(): Array[InternalRow] = child.executeTail(limit)
protected override def doExecute(): RDD[InternalRow] = {
// This is a somewhat hacky way to avoid a shuffle and a scan of all the data when this plan
// is executed as part of `Dataset.tail`.
// Since this execution plan and `execute` are currently called only when
// `Dataset.tail` is invoked, the jobs are always executed when they are supposed to be.

// If we use this execution plan separately like `Dataset.limit` without an actual
// job launch, we might just have to mimic the implementation of `CollectLimitExec`.
sparkContext.parallelize(executeCollect(), numSlices = 1)
}
}

object BaseLimitExec {
private val curId = new java.util.concurrent.atomic.AtomicInteger()

@@ -92,9 +92,10 @@ class DataFrameSuite extends QueryTest with SharedSparkSession {
assert(spark.emptyDataFrame.count() === 0)
}

test("head and take") {
test("head, take and tail") {
assert(testData.take(2) === testData.collect().take(2))
assert(testData.head(2) === testData.collect().take(2))
assert(testData.tail(2) === testData.collect().takeRight(2))
assert(testData.head(2).head.schema === testData.schema)
}

20 changes: 18 additions & 2 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -196,6 +196,11 @@ class DatasetSuite extends QueryTest with SharedSparkSession {
assert(ds.take(2) === Array(ClassData("a", 1), ClassData("b", 2)))
}

test("as case class - tail") {
val ds = Seq((1, "a"), (2, "b"), (3, "c")).toDF("b", "a").as[ClassData]
assert(ds.tail(2) === Array(ClassData("b", 2), ClassData("c", 3)))
}

test("as seq of case class - reorder fields by name") {
val df = spark.range(3).select(array(struct($"id".cast("int").as("b"), lit("a").as("a"))))
val ds = df.as[Seq[ClassData]]
@@ -1861,9 +1866,9 @@ class DatasetSuite extends QueryTest with SharedSparkSession {
}

test("groupBy.as") {
val df1 = Seq(DoubleData(1, "one"), DoubleData(2, "two"), DoubleData( 3, "three")).toDS()
val df1 = Seq(DoubleData(1, "one"), DoubleData(2, "two"), DoubleData(3, "three")).toDS()
.repartition($"id").sortWithinPartitions("id")
val df2 = Seq(DoubleData(5, "one"), DoubleData(1, "two"), DoubleData( 3, "three")).toDS()
val df2 = Seq(DoubleData(5, "one"), DoubleData(1, "two"), DoubleData(3, "three")).toDS()
.repartition($"id").sortWithinPartitions("id")

val df3 = df1.groupBy("id").as[Int, DoubleData]
@@ -1880,6 +1885,17 @@ class DatasetSuite extends QueryTest with SharedSparkSession {
}
assert(exchanges.size == 2)
}

test("tail with different numbers") {
Seq(0, 2, 5, 10, 50, 100, 1000).foreach { n =>
assert(spark.range(n).tail(6) === (math.max(n - 6, 0) until n))
}
}

test("tail should not accept minus value") {
val e = intercept[AnalysisException](spark.range(1).tail(-1))
assert(e.getMessage.contains("tail expression must be equal to or greater than 0"))
}
}

object AssertExecutionId {
@@ -34,6 +34,7 @@ class SparkPlanSuite extends QueryTest with SharedSparkSession {
intercept[IllegalStateException] { plan.executeToIterator() }
intercept[IllegalStateException] { plan.executeBroadcast() }
intercept[IllegalStateException] { plan.executeTake(1) }
intercept[IllegalStateException] { plan.executeTail(1) }
}

test("SPARK-23731 plans should be canonicalizable after being (de)serialized") {