From 87504855cf0fa3772007bb05fc95bdaec3c312c3 Mon Sep 17 00:00:00 2001 From: uncleGen Date: Mon, 6 Feb 2017 09:36:29 +0800 Subject: [PATCH 1/4] change the type of index from Int to Long --- .../catalyst/expressions/windowExpressions.scala | 8 ++++---- .../sql/execution/window/BoundOrdering.scala | 16 ++++++++++------ .../spark/sql/execution/window/WindowExec.scala | 6 +++--- .../execution/window/WindowFunctionFrame.scala | 16 ++++++++-------- .../spark/sql/expressions/WindowSpec.scala | 8 ++++---- 5 files changed, 29 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 07d294b108548..beeb0542cb3e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -140,8 +140,8 @@ sealed trait FrameBoundary { * Extractor for making working with frame boundaries easier. */ object FrameBoundary { - def apply(boundary: FrameBoundary): Option[Int] = unapply(boundary) - def unapply(boundary: FrameBoundary): Option[Int] = boundary match { + def apply(boundary: FrameBoundary): Option[Long] = unapply(boundary) + def unapply(boundary: FrameBoundary): Option[Long] = boundary match { case CurrentRow => Some(0) case ValuePreceding(offset) => Some(-offset) case ValueFollowing(offset) => Some(offset) @@ -163,7 +163,7 @@ case object UnboundedPreceding extends FrameBoundary { } /** PRECEDING boundary. */ -case class ValuePreceding(value: Int) extends FrameBoundary { +case class ValuePreceding(value: Long) extends FrameBoundary { def notFollows(other: FrameBoundary): Boolean = other match { case UnboundedPreceding => false case ValuePreceding(anotherValue) => value >= anotherValue @@ -189,7 +189,7 @@ case object CurrentRow extends FrameBoundary { } /** FOLLOWING boundary. */ -case class ValueFollowing(value: Int) extends FrameBoundary { +case class ValueFollowing(value: Long) extends FrameBoundary { def notFollows(other: FrameBoundary): Boolean = other match { case UnboundedPreceding => false case vp: ValuePreceding => false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/BoundOrdering.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/BoundOrdering.scala index d6a801954c1ac..9a0ec8ec4270f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/BoundOrdering.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/BoundOrdering.scala @@ -25,18 +25,22 @@ import org.apache.spark.sql.catalyst.expressions.Projection * Function for comparing boundary values. */ private[window] abstract class BoundOrdering { - def compare(inputRow: InternalRow, inputIndex: Int, outputRow: InternalRow, outputIndex: Int): Int + def compare( + inputRow: InternalRow, + inputIndex: Long, + outputRow: InternalRow, + outputIndex: Long): Long } /** * Compare the input index to the bound of the output index. 
*/ -private[window] final case class RowBoundOrdering(offset: Int) extends BoundOrdering { +private[window] final case class RowBoundOrdering(offset: Long) extends BoundOrdering { override def compare( inputRow: InternalRow, - inputIndex: Int, + inputIndex: Long, outputRow: InternalRow, - outputIndex: Int): Int = + outputIndex: Long): Long = inputIndex - (outputIndex + offset) } @@ -51,8 +55,8 @@ private[window] final case class RangeBoundOrdering( override def compare( inputRow: InternalRow, - inputIndex: Int, + inputIndex: Long, outputRow: InternalRow, - outputIndex: Int): Int = + outputIndex: Long): Long = ordering.compare(current(inputRow), bound(outputRow)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala index 80b87d5ffa797..ff6a7ab6a88df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala @@ -115,7 +115,7 @@ case class WindowExec( * @param offset with respect to the row. * @return a bound ordering object. */ - private[this] def createBoundOrdering(frameType: FrameType, offset: Int): BoundOrdering = { + private[this] def createBoundOrdering(frameType: FrameType, offset: Long): BoundOrdering = { frameType match { case RangeFrame => val (exprs, current, bound) = if (offset == 0) { @@ -159,7 +159,7 @@ case class WindowExec( * WindowExpressions and factory function for the WindowFrameFunction. */ private[this] lazy val windowFrameExpressionFactoryPairs = { - type FrameKey = (String, FrameType, Option[Int], Option[Int]) + type FrameKey = (String, FrameType, Option[Long], Option[Long]) type ExpressionBuffer = mutable.Buffer[Expression] val framedFunctions = mutable.Map.empty[FrameKey, (ExpressionBuffer, ExpressionBuffer)] @@ -379,7 +379,7 @@ case class WindowExec( } // Iteration - var rowIndex = 0 + var rowIndex = 0L var rowsSize = 0L override final def hasNext: Boolean = rowIndex < rowsSize || nextRowAvailable diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala index 70efc0f78ddb0..a9be025833516 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala @@ -40,7 +40,7 @@ private[window] abstract class WindowFunctionFrame { /** * Write the current results to the target row. */ - def write(index: Int, current: InternalRow): Unit + def write(index: Long, current: InternalRow): Unit } /** @@ -61,14 +61,14 @@ private[window] final class OffsetWindowFunctionFrame( expressions: Array[OffsetWindowFunction], inputSchema: Seq[Attribute], newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection, - offset: Int) + offset: Long) extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ private[this] var input: RowBuffer = null /** Index of the input row currently used for output. */ - private[this] var inputIndex = 0 + private[this] var inputIndex = 0L /** * Create the projection used when the offset row exists. 
@@ -114,7 +114,7 @@ private[window] final class OffsetWindowFunctionFrame( inputIndex = offset } - override def write(index: Int, current: InternalRow): Unit = { + override def write(index: Long, current: InternalRow): Unit = { if (inputIndex >= 0 && inputIndex < input.size) { val r = input.next() projection(r) @@ -173,7 +173,7 @@ private[window] final class SlidingWindowFunctionFrame( } /** Write the frame columns for the current row to the given target row. */ - override def write(index: Int, current: InternalRow): Unit = { + override def write(index: Long, current: InternalRow): Unit = { var bufferUpdated = index == 0 // Add all rows to the buffer for which the input row value is equal to or less than @@ -233,7 +233,7 @@ private[window] final class UnboundedWindowFunctionFrame( } /** Write the frame columns for the current row to the given target row. */ - override def write(index: Int, current: InternalRow): Unit = { + override def write(index: Long, current: InternalRow): Unit = { // Unfortunately we cannot assume that evaluation is deterministic. So we need to re-evaluate // for each row. processor.evaluate(target) @@ -281,7 +281,7 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame( } /** Write the frame columns for the current row to the given target row. */ - override def write(index: Int, current: InternalRow): Unit = { + override def write(index: Long, current: InternalRow): Unit = { var bufferUpdated = index == 0 // Add all rows to the aggregates for which the input row value is equal to or less than @@ -338,7 +338,7 @@ private[window] final class UnboundedFollowingWindowFunctionFrame( } /** Write the frame columns for the current row to the given target row. */ - override def write(index: Int, current: InternalRow): Unit = { + override def write(index: Long, current: InternalRow): Unit = { var bufferUpdated = index == 0 // Duplicate the input to have a new iterator diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index de7d7a1772753..3453e44a07ff1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -181,15 +181,15 @@ class WindowSpec private[sql]( val boundaryStart = start match { case 0 => CurrentRow case Long.MinValue => UnboundedPreceding - case x if x < 0 => ValuePreceding(-start.toInt) - case x if x > 0 => ValueFollowing(start.toInt) + case x if x < 0 => ValuePreceding(-start) + case x if x > 0 => ValueFollowing(start) } val boundaryEnd = end match { case 0 => CurrentRow case Long.MaxValue => UnboundedFollowing - case x if x < 0 => ValuePreceding(-end.toInt) - case x if x > 0 => ValueFollowing(end.toInt) + case x if x < 0 => ValuePreceding(-end) + case x if x > 0 => ValueFollowing(end) } new WindowSpec( From 942d8f87acb5be76a480f30b99c2dafc4bcdbaa9 Mon Sep 17 00:00:00 2001 From: uncleGen Date: Wed, 8 Feb 2017 15:54:32 +0800 Subject: [PATCH 2/4] fix and add unit test --- .../sql/execution/window/WindowExec.scala | 4 +- .../sql/DataFrameWindowFunctionsSuite.scala | 94 +++++++++++++++++++ 2 files changed, 96 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala index ff6a7ab6a88df..de312354b5e5f 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} -import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.LongType import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter /** @@ -135,7 +135,7 @@ case class WindowExec( case Ascending => offset } // Create the projection which returns the current 'value' modified by adding the offset. - val boundExpr = Add(expr, Cast(Literal.create(boundOffset, IntegerType), expr.dataType)) + val boundExpr = Add(expr, Cast(Literal.create(boundOffset, LongType), expr.dataType)) val bound = newMutableProjection(boundExpr :: Nil, child.output) (sortExpr :: Nil, current, bound) } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 1255c49104718..1d19cbe007c12 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -423,4 +423,98 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { df.select(selectList: _*).where($"value" < 2), Seq(Row(3, "1", null, 3.0, 4.0, 3.0), Row(5, "1", false, 4.0, 5.0, 5.0))) } + + test("SPARK-19451: Underlying integer overflow in Window function") { + val df = Seq((1L, "a"), (1L, "a"), (2L, "a"), (1L, "b"), (2L, "b"), (3L, "b")) + .toDF("id", "category") + df.createOrReplaceTempView("window_table") + + // range frames + checkAnswer( + df.select('id, 'category, sum("id").over(Window.partitionBy('category).orderBy('id) + .rangeBetween(-2160000000L, -1))), + Seq( + Row(1, "b", null), Row(2, "b", 1), Row(3, "b", 3), + Row(1, "a", null), Row(1, "a", null), Row(2, "a", 2))) + checkAnswer( + df.select('id, 'category, sum("id").over(Window.partitionBy('category).orderBy('id) + .rangeBetween(-2160000000L, 0))), + Seq( + Row(1, "b", 1), Row(2, "b", 3), Row(3, "b", 6), + Row(1, "a", 2), Row(1, "a", 2), Row(2, "a", 4))) + checkAnswer( + df.select('id, 'category, sum("id").over(Window.partitionBy('category).orderBy('id) + .rangeBetween(-2160000000L, 2))), + Seq( + Row(1, "b", 6), Row(2, "b", 6), Row(3, "b", 6), + Row(1, "a", 4), Row(1, "a", 4), Row(2, "a", 4))) + checkAnswer( + df.select('id, 'category, sum("id").over(Window.partitionBy('category).orderBy('id) + .rangeBetween(-2160000000L, 2160000000L))), + Seq( + Row(1, "b", 6), Row(2, "b", 6), Row(3, "b", 6), + Row(1, "a", 4), Row(1, "a", 4), Row(2, "a", 4))) + checkAnswer( + df.select('id, 'category, sum("id").over(Window.partitionBy('category).orderBy('id) + .rangeBetween(-1, 2160000000L))), + Seq( + Row(1, "b", 6), Row(2, "b", 6), Row(3, "b", 5), + Row(1, "a", 4), Row(1, "a", 4), Row(2, "a", 4))) + checkAnswer( + df.select('id, 'category, sum("id").over(Window.partitionBy('category).orderBy('id) + .rangeBetween(0, 2160000000L))), + Seq( + Row(1, "b", 6), Row(2, "b", 5), Row(3, "b", 3), + Row(1, "a", 4), Row(1, "a", 4), Row(2, "a", 2))) + checkAnswer( + df.select('id, 'category, sum("id").over(Window.partitionBy('category).orderBy('id) + .rangeBetween(2, 2160000000L))), + Seq( + Row(1, "b", 3), Row(2, "b", 
null), Row(3, "b", null), + Row(1, "a", null), Row(1, "a", null), Row(2, "a", null))) + + // row frames + checkAnswer( + df.select('id, 'category, sum("id").over(Window.partitionBy('category).orderBy('id) + .rowsBetween(-2160000000L, -1))), + Seq( + Row(1, "b", null), Row(2, "b", 1), Row(3, "b", 3), + Row(1, "a", null), Row(1, "a", 1), Row(2, "a", 2))) + checkAnswer( + df.select('id, 'category, sum("id").over(Window.partitionBy('category).orderBy('id) + .rowsBetween(-2160000000L, 0))), + Seq( + Row(1, "b", 1), Row(2, "b", 3), Row(3, "b", 6), + Row(1, "a", 1), Row(1, "a", 2), Row(2, "a", 4))) + checkAnswer( + df.select('id, 'category, sum("id").over(Window.partitionBy('category).orderBy('id) + .rowsBetween(-2160000000L, 2))), + Seq( + Row(1, "b", 6), Row(2, "b", 6), Row(3, "b", 6), + Row(1, "a", 4), Row(1, "a", 4), Row(2, "a", 4))) + checkAnswer( + df.select('id, 'category, sum("id").over(Window.partitionBy('category).orderBy('id) + .rowsBetween(-2160000000L, 2160000000L))), + Seq( + Row(1, "b", 6), Row(2, "b", 6), Row(3, "b", 6), + Row(1, "a", 4), Row(1, "a", 4), Row(2, "a", 4))) + checkAnswer( + df.select('id, 'category, sum("id").over(Window.partitionBy('category).orderBy('id) + .rowsBetween(-1, 2160000000L))), + Seq( + Row(1, "b", 6), Row(2, "b", 6), Row(3, "b", 5), + Row(1, "a", 4), Row(1, "a", 4), Row(2, "a", 3))) + checkAnswer( + df.select('id, 'category, sum("id").over(Window.partitionBy('category).orderBy('id) + .rowsBetween(0, 2160000000L))), + Seq( + Row(1, "b", 6), Row(2, "b", 5), Row(3, "b", 3), + Row(1, "a", 4), Row(1, "a", 3), Row(2, "a", 2))) + checkAnswer( + df.select('id, 'category, sum("id").over(Window.partitionBy('category).orderBy('id) + .rowsBetween(2, 2160000000L))), + Seq( + Row(1, "b", 3), Row(2, "b", null), Row(3, "b", null), + Row(1, "a", 2), Row(1, "a", null), Row(2, "a", null))) + } } From 7ae4e4845b5049ed5df68b57c340cf4c347f9d5e Mon Sep 17 00:00:00 2001 From: uncleGen Date: Tue, 14 Feb 2017 10:53:33 +0800 Subject: [PATCH 3/4] address the comment from hvanhovell --- .../expressions/windowExpressions.scala | 8 ++++---- .../sql/execution/window/BoundOrdering.scala | 16 ++++++--------- .../sql/execution/window/WindowExec.scala | 10 +++++----- .../window/WindowFunctionFrame.scala | 16 +++++++-------- .../spark/sql/expressions/WindowSpec.scala | 16 +++++++++------ .../sql/DataFrameWindowFunctionsSuite.scala | 20 +++++++++++++++++++ 6 files changed, 53 insertions(+), 33 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index beeb0542cb3e5..07d294b108548 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -140,8 +140,8 @@ sealed trait FrameBoundary { * Extractor for making working with frame boundaries easier. */ object FrameBoundary { - def apply(boundary: FrameBoundary): Option[Long] = unapply(boundary) - def unapply(boundary: FrameBoundary): Option[Long] = boundary match { + def apply(boundary: FrameBoundary): Option[Int] = unapply(boundary) + def unapply(boundary: FrameBoundary): Option[Int] = boundary match { case CurrentRow => Some(0) case ValuePreceding(offset) => Some(-offset) case ValueFollowing(offset) => Some(offset) @@ -163,7 +163,7 @@ case object UnboundedPreceding extends FrameBoundary { } /** PRECEDING boundary. 
*/ -case class ValuePreceding(value: Long) extends FrameBoundary { +case class ValuePreceding(value: Int) extends FrameBoundary { def notFollows(other: FrameBoundary): Boolean = other match { case UnboundedPreceding => false case ValuePreceding(anotherValue) => value >= anotherValue @@ -189,7 +189,7 @@ case object CurrentRow extends FrameBoundary { } /** FOLLOWING boundary. */ -case class ValueFollowing(value: Long) extends FrameBoundary { +case class ValueFollowing(value: Int) extends FrameBoundary { def notFollows(other: FrameBoundary): Boolean = other match { case UnboundedPreceding => false case vp: ValuePreceding => false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/BoundOrdering.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/BoundOrdering.scala index 9a0ec8ec4270f..d6a801954c1ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/BoundOrdering.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/BoundOrdering.scala @@ -25,22 +25,18 @@ import org.apache.spark.sql.catalyst.expressions.Projection * Function for comparing boundary values. */ private[window] abstract class BoundOrdering { - def compare( - inputRow: InternalRow, - inputIndex: Long, - outputRow: InternalRow, - outputIndex: Long): Long + def compare(inputRow: InternalRow, inputIndex: Int, outputRow: InternalRow, outputIndex: Int): Int } /** * Compare the input index to the bound of the output index. */ -private[window] final case class RowBoundOrdering(offset: Long) extends BoundOrdering { +private[window] final case class RowBoundOrdering(offset: Int) extends BoundOrdering { override def compare( inputRow: InternalRow, - inputIndex: Long, + inputIndex: Int, outputRow: InternalRow, - outputIndex: Long): Long = + outputIndex: Int): Int = inputIndex - (outputIndex + offset) } @@ -55,8 +51,8 @@ private[window] final case class RangeBoundOrdering( override def compare( inputRow: InternalRow, - inputIndex: Long, + inputIndex: Int, outputRow: InternalRow, - outputIndex: Long): Long = + outputIndex: Int): Int = ordering.compare(current(inputRow), bound(outputRow)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala index de312354b5e5f..80b87d5ffa797 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} -import org.apache.spark.sql.types.LongType +import org.apache.spark.sql.types.IntegerType import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter /** @@ -115,7 +115,7 @@ case class WindowExec( * @param offset with respect to the row. * @return a bound ordering object. */ - private[this] def createBoundOrdering(frameType: FrameType, offset: Long): BoundOrdering = { + private[this] def createBoundOrdering(frameType: FrameType, offset: Int): BoundOrdering = { frameType match { case RangeFrame => val (exprs, current, bound) = if (offset == 0) { @@ -135,7 +135,7 @@ case class WindowExec( case Ascending => offset } // Create the projection which returns the current 'value' modified by adding the offset. 
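// (Annotation, not part of the patch: after the revert below, boundOffset is
// an Int-typed literal again; Long offsets outside the Int range no longer
// reach this code, because WindowSpec.between, changed later in this patch,
// maps them to unbounded frames or rejects them up front.)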
- val boundExpr = Add(expr, Cast(Literal.create(boundOffset, LongType), expr.dataType)) + val boundExpr = Add(expr, Cast(Literal.create(boundOffset, IntegerType), expr.dataType)) val bound = newMutableProjection(boundExpr :: Nil, child.output) (sortExpr :: Nil, current, bound) } else { @@ -159,7 +159,7 @@ case class WindowExec( * WindowExpressions and factory function for the WindowFrameFunction. */ private[this] lazy val windowFrameExpressionFactoryPairs = { - type FrameKey = (String, FrameType, Option[Long], Option[Long]) + type FrameKey = (String, FrameType, Option[Int], Option[Int]) type ExpressionBuffer = mutable.Buffer[Expression] val framedFunctions = mutable.Map.empty[FrameKey, (ExpressionBuffer, ExpressionBuffer)] @@ -379,7 +379,7 @@ case class WindowExec( } // Iteration - var rowIndex = 0L + var rowIndex = 0 var rowsSize = 0L override final def hasNext: Boolean = rowIndex < rowsSize || nextRowAvailable diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala index a9be025833516..70efc0f78ddb0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala @@ -40,7 +40,7 @@ private[window] abstract class WindowFunctionFrame { /** * Write the current results to the target row. */ - def write(index: Long, current: InternalRow): Unit + def write(index: Int, current: InternalRow): Unit } /** @@ -61,14 +61,14 @@ private[window] final class OffsetWindowFunctionFrame( expressions: Array[OffsetWindowFunction], inputSchema: Seq[Attribute], newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection, - offset: Long) + offset: Int) extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ private[this] var input: RowBuffer = null /** Index of the input row currently used for output. */ - private[this] var inputIndex = 0L + private[this] var inputIndex = 0 /** * Create the projection used when the offset row exists. @@ -114,7 +114,7 @@ private[window] final class OffsetWindowFunctionFrame( inputIndex = offset } - override def write(index: Long, current: InternalRow): Unit = { + override def write(index: Int, current: InternalRow): Unit = { if (inputIndex >= 0 && inputIndex < input.size) { val r = input.next() projection(r) @@ -173,7 +173,7 @@ private[window] final class SlidingWindowFunctionFrame( } /** Write the frame columns for the current row to the given target row. */ - override def write(index: Long, current: InternalRow): Unit = { + override def write(index: Int, current: InternalRow): Unit = { var bufferUpdated = index == 0 // Add all rows to the buffer for which the input row value is equal to or less than @@ -233,7 +233,7 @@ private[window] final class UnboundedWindowFunctionFrame( } /** Write the frame columns for the current row to the given target row. */ - override def write(index: Long, current: InternalRow): Unit = { + override def write(index: Int, current: InternalRow): Unit = { // Unfortunately we cannot assume that evaluation is deterministic. So we need to re-evaluate // for each row. processor.evaluate(target) @@ -281,7 +281,7 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame( } /** Write the frame columns for the current row to the given target row. 
*/ - override def write(index: Long, current: InternalRow): Unit = { + override def write(index: Int, current: InternalRow): Unit = { var bufferUpdated = index == 0 // Add all rows to the aggregates for which the input row value is equal to or less than @@ -338,7 +338,7 @@ private[window] final class UnboundedFollowingWindowFunctionFrame( } /** Write the frame columns for the current row to the given target row. */ - override def write(index: Long, current: InternalRow): Unit = { + override def write(index: Int, current: InternalRow): Unit = { var bufferUpdated = index == 0 // Duplicate the input to have a new iterator diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index 3453e44a07ff1..e44a192ae3aea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -180,16 +180,20 @@ class WindowSpec private[sql]( private def between(typ: FrameType, start: Long, end: Long): WindowSpec = { val boundaryStart = start match { case 0 => CurrentRow - case Long.MinValue => UnboundedPreceding - case x if x < 0 => ValuePreceding(-start) - case x if x > 0 => ValueFollowing(start) + case x if x < Int.MinValue => UnboundedPreceding + case x if x < 0 && x >= Int.MinValue => ValuePreceding(-start.toInt) + case x if x > 0 && x <= Int.MaxValue => ValueFollowing(start.toInt) + case _ => throw new IllegalArgumentException(s"Boundary start($start) should not be " + + s"larger than Int.MaxValue(${Int.MaxValue}).") } val boundaryEnd = end match { case 0 => CurrentRow - case Long.MaxValue => UnboundedFollowing - case x if x < 0 => ValuePreceding(-end) - case x if x > 0 => ValueFollowing(end) + case x if x > Int.MaxValue => UnboundedFollowing + case x if x < 0 && x >= Int.MinValue => ValuePreceding(-end.toInt) + case x if x > 0 && x <= Int.MaxValue => ValueFollowing(end.toInt) + case _ => throw new IllegalArgumentException(s"Boundary end($end) should not be " + + s"smaller than Int.MinValue(${Int.MinValue}).") } new WindowSpec( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 1d19cbe007c12..44208ad4ce561 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -516,5 +516,25 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { Seq( Row(1, "b", 3), Row(2, "b", null), Row(3, "b", null), Row(1, "a", 2), Row(1, "a", null), Row(2, "a", null))) + try { + checkAnswer( + df.select('id, 'category, sum("id").over(Window.partitionBy('category).orderBy('id) + .rowsBetween(-3160000000L, -2160000000L))), + Seq()) + assert(false, "Boundary end should not be smaller than Int.MinValue(-2147483648).") + } catch { + case e: IllegalArgumentException => + // expected + } + try { + checkAnswer( + df.select('id, 'category, sum("id").over(Window.partitionBy('category).orderBy('id) + .rowsBetween(2160000000L, 3160000000L))), + Seq()) + assert(false, "Boundary start should not be larger than Int.MaxValue(2147483647).") + } catch { + case e: IllegalArgumentException => + // expected + } } } From c65de9a629f508cf5ee3304f5e7515606fbb1355 Mon Sep 17 00:00:00 2001 From: uncleGen Date: Wed, 15 Feb 2017 09:42:49 +0800 Subject: [PATCH 4/4] update 
doc --- .../org/apache/spark/sql/expressions/Window.scala | 12 ++++++------ .../apache/spark/sql/expressions/WindowSpec.scala | 14 +++++++------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala index f3cf3052ea3ea..55023e6d09191 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -148,9 +148,9 @@ object Window { * }}} * * @param start boundary start, inclusive. The frame is unbounded if this is - * the minimum long value (`Window.unboundedPreceding`). - * @param end boundary end, inclusive. The frame is unbounded if this is the - * maximum long value (`Window.unboundedFollowing`). + * less than minimum int value (`Window.unboundedPreceding`). + * @param end boundary end, inclusive. The frame is unbounded if this is larger + * than maximum int value (`Window.unboundedFollowing`). * @since 2.1.0 */ // Note: when updating the doc for this method, also update WindowSpec.rowsBetween. @@ -200,9 +200,9 @@ object Window { * }}} * * @param start boundary start, inclusive. The frame is unbounded if this is - * the minimum long value (`Window.unboundedPreceding`). - * @param end boundary end, inclusive. The frame is unbounded if this is the - * maximum long value (`Window.unboundedFollowing`). + * less than minimum int value (`Window.unboundedPreceding`). + * @param end boundary end, inclusive. The frame is unbounded if this is larger + * than maximum int value (`Window.unboundedFollowing`). * @since 2.1.0 */ // Note: when updating the doc for this method, also update WindowSpec.rangeBetween. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index e44a192ae3aea..3ba7c74d7e52e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -116,9 +116,9 @@ class WindowSpec private[sql]( * }}} * * @param start boundary start, inclusive. The frame is unbounded if this is - * the minimum long value (`Window.unboundedPreceding`). - * @param end boundary end, inclusive. The frame is unbounded if this is the - * maximum long value (`Window.unboundedFollowing`). + * less than minimum int value (`Window.unboundedPreceding`). + * @param end boundary end, inclusive. The frame is unbounded if this is larger + * than maximum int value (`Window.unboundedFollowing`). * @since 1.4.0 */ // Note: when updating the doc for this method, also update Window.rowsBetween. @@ -167,9 +167,9 @@ class WindowSpec private[sql]( * }}} * * @param start boundary start, inclusive. The frame is unbounded if this is - * the minimum long value (`Window.unboundedPreceding`). - * @param end boundary end, inclusive. The frame is unbounded if this is the - * maximum long value (`Window.unboundedFollowing`). + * less than minimum int value (`Window.unboundedPreceding`). + * @param end boundary end, inclusive. The frame is unbounded if this is larger + * than maximum int value (`Window.unboundedFollowing`). * @since 1.4.0 */ // Note: when updating the doc for this method, also update Window.rangeBetween. 
@@ -193,7 +193,7 @@ class WindowSpec private[sql]( case x if x < 0 && x >= Int.MinValue => ValuePreceding(-end.toInt) case x if x > 0 && x <= Int.MaxValue => ValueFollowing(end.toInt) case _ => throw new IllegalArgumentException(s"Boundary end($end) should not be " + - s"smaller than Int.MinValue(${Int.MinValue}).") + s"less than Int.MinValue(${Int.MinValue}).") } new WindowSpec(
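
A note for readers following the series: the failure mode all four patches revolve around is plain Long-to-Int narrowing of the user-supplied frame offsets. The snippet below is a minimal standalone sketch of that narrowing, using the same offset as the new tests; it is illustrative code, not code from the patches.

```scala
// Sketch of the overflow fixed by SPARK-19451 (illustrative, not from the patch).
object OffsetNarrowingSketch extends App {
  // The PRECEDING bound used throughout the new tests; its magnitude is just
  // past the Int range (Int.MaxValue is 2147483647).
  val start = -2160000000L

  // The pre-fix WindowSpec.between built ValuePreceding(-start.toInt).
  // Method application binds tighter than unary minus, so that is -(start.toInt):
  val narrowed = -start.toInt

  println(narrowed) // -2134967296: both magnitude and sign are wrong
  assert(narrowed == -2134967296L)

  // A PRECEDING boundary with a negative magnitude behaves like a FOLLOWING
  // boundary, so frames were silently wrong rather than failing loudly.
}
```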
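And the resolution the series converges on: patch 3 reverts the Long plumbing and instead validates the Long boundaries once, in WindowSpec.between, with patch 4 aligning the docs and the message wording. Below is that final boundary handling extracted into a self-contained sketch; the FrameBoundary case classes here are simplified stand-ins for Spark's internal ADT, and the object and method names are illustrative.

```scala
// Self-contained mirror of WindowSpec.between after patches 3 and 4.
sealed trait FrameBoundary
case object UnboundedPreceding extends FrameBoundary
case object UnboundedFollowing extends FrameBoundary
case object CurrentRow extends FrameBoundary
case class ValuePreceding(value: Int) extends FrameBoundary
case class ValueFollowing(value: Int) extends FrameBoundary

object BoundaryResolution {
  def resolve(start: Long, end: Long): (FrameBoundary, FrameBoundary) = {
    // A start below Int.MinValue can only widen the frame, so it degrades to
    // an unbounded frame; a start beyond Int.MaxValue is rejected outright.
    val boundaryStart = start match {
      case 0 => CurrentRow
      case x if x < Int.MinValue => UnboundedPreceding
      case x if x < 0 && x >= Int.MinValue => ValuePreceding(-start.toInt)
      case x if x > 0 && x <= Int.MaxValue => ValueFollowing(start.toInt)
      case _ => throw new IllegalArgumentException(s"Boundary start($start) should not be " +
        s"larger than Int.MaxValue(${Int.MaxValue}).")
    }
    // Symmetrically for the end boundary.
    val boundaryEnd = end match {
      case 0 => CurrentRow
      case x if x > Int.MaxValue => UnboundedFollowing
      case x if x < 0 && x >= Int.MinValue => ValuePreceding(-end.toInt)
      case x if x > 0 && x <= Int.MaxValue => ValueFollowing(end.toInt)
      case _ => throw new IllegalArgumentException(s"Boundary end($end) should not be " +
        s"less than Int.MinValue(${Int.MinValue}).")
    }
    (boundaryStart, boundaryEnd)
  }

  def main(args: Array[String]): Unit = {
    // The oversized PRECEDING bound from the tests now degrades gracefully:
    println(resolve(-2160000000L, 0)) // (UnboundedPreceding, CurrentRow)
    // An impossible frame is rejected, matching the exception the tests assert:
    // resolve(2160000000L, 3160000000L) // IllegalArgumentException
    // (Edge case the posted patch does not guard: negating Int.MinValue
    // overflows back to Int.MinValue, yielding a negative "preceding" value.)
    println(resolve(Int.MinValue.toLong, 0)) // (ValuePreceding(-2147483648), CurrentRow)
  }
}
```

Note the deliberate asymmetry with the old code: previously only the exact sentinels Long.MinValue and Long.MaxValue meant unbounded, whereas after patch 3 any boundary past the Int range on its own side is treated as unbounded, and only a boundary past the Int range on the impossible side raises IllegalArgumentException.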