Skip to content

Commit

Permalink
address comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
gatorsmile committed Jul 8, 2016
1 parent 1abdbb9 commit 3036847
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ class Analyzer(
TimeWindowing ::
TypeCoercion.typeCoercionRules ++
extendedResolutionRules : _*),
Batch("LIMIT", Once,
ResolveLimits),
Batch("Nondeterministic", Once,
PullOutNondeterministic),
Batch("UDF", Once,
Expand Down Expand Up @@ -2044,6 +2046,21 @@ object EliminateUnions extends Rule[LogicalPlan] {
}
}

/**
 * Rewrites the limit expression of [[GlobalLimit]] and [[LocalLimit]] operators into an
 * integer literal when that expression is foldable and evaluates to a numeric value.
 *
 * Limit expressions that are not foldable, or that evaluate to a non-numeric value
 * (including null), are left untouched by this rule.
 */
object ResolveLimits extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case g @ GlobalLimit(limitExpr, _) if limitExpr.foldable =>
      // Evaluate the foldable expression exactly once; keep the operator as it is
      // when the result is not numeric.
      toIntLiteral(limitExpr.eval()).map(lit => g.copy(limitExpr = lit)).getOrElse(g)
    case l @ LocalLimit(limitExpr, _) if limitExpr.foldable =>
      toIntLiteral(limitExpr.eval()).map(lit => l.copy(limitExpr = lit)).getOrElse(l)
  }

  /**
   * Converts an evaluated limit value into an integer literal.
   *
   * @param value the result of evaluating a foldable limit expression; may be null
   * @return an [[IntegerType]] literal when `value` is a `java.lang.Number`, otherwise `None`
   */
  private def toIntLiteral(value: Any): Option[Literal] = value match {
    // NOTE(review): intValue() is a narrowing conversion, so fractional or out-of-range
    // values are silently truncated/wrapped rather than rejected -- confirm this is intended.
    case n: Number => Some(Literal(n.intValue(), IntegerType))
    // Covers null as well: a type pattern never matches null.
    case _ => None
  }
}

/**
* Cleans up unnecessary Aliases inside the plan. Basically we only need Alias as a top level
* expression in Project(project list) or Aggregate(aggregate expressions) or
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,13 @@ trait CheckAnalysis extends PredicateHelper {
if (!limitExpr.foldable) {
failAnalysis(
"The argument to the LIMIT clause must evaluate to a constant value. " +
s"Limit:${limitExpr.sql}")
s"Limit:${limitExpr.sql}")
}
limitExpr.eval() match {
case o: Int if o >= 0 => // OK
case o: Int => failAnalysis(
s"number_rows in limit clause must be equal to or greater than 0. number_rows:$o")
case o => failAnalysis(
s"number_rows in limit clause cannot be cast to integer:$o")
limitExpr match {
case IntegerLiteral(limit) if limit >= 0 => // OK
case IntegerLiteral(limit) => failAnalysis(
s"number_rows in limit clause must be equal to or greater than 0. number_rows:$limit")
case o => failAnalysis(s"""number_rows in limit clause cannot be cast to integer:"$o".""")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,7 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN
}
}
override lazy val statistics: Statistics = {
val limit = limitExpr.eval().asInstanceOf[Int]
val limit = limitExpr.eval().asInstanceOf[Number].intValue()
val sizeInBytes = if (limit == 0) {
// sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero
// (product of children).
Expand All @@ -680,7 +680,7 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo
}
}
override lazy val statistics: Statistics = {
val limit = limitExpr.eval().asInstanceOf[Int]
val limit = limitExpr.eval().asInstanceOf[Number].intValue()
val sizeInBytes = if (limit == 0) {
// sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero
// (product of children).
Expand Down
29 changes: 27 additions & 2 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,26 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
checkAnswer(
sql("SELECT * FROM mapData LIMIT 1"),
mapData.collect().take(1).map(Row.fromTuple).toSeq)

checkAnswer(
sql("SELECT * FROM mapData LIMIT CAST(1 AS Double)"),
mapData.collect().take(1).map(Row.fromTuple).toSeq)

checkAnswer(
sql("SELECT * FROM mapData LIMIT CAST(1 AS BYTE)"),
mapData.collect().take(1).map(Row.fromTuple).toSeq)

checkAnswer(
sql("SELECT * FROM mapData LIMIT CAST(1 AS LONG)"),
mapData.collect().take(1).map(Row.fromTuple).toSeq)

checkAnswer(
sql("SELECT * FROM mapData LIMIT CAST(1 AS SHORT)"),
mapData.collect().take(1).map(Row.fromTuple).toSeq)

checkAnswer(
sql("SELECT * FROM mapData LIMIT CAST(1 AS FLOAT)"),
mapData.collect().take(1).map(Row.fromTuple).toSeq)
}

test("non-foldable expressions in LIMIT") {
Expand All @@ -681,10 +701,15 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}

test("Limit: unable to evaluate and cast expressions in limit clauses to Int") {
val e = intercept[AnalysisException] {
var e = intercept[AnalysisException] {
sql("SELECT * FROM testData LIMIT true")
}.getMessage
assert(e.contains("number_rows in limit clause cannot be cast to integer:true"))
assert(e.contains("number_rows in limit clause cannot be cast to integer:\"true\""))

e = intercept[AnalysisException] {
sql("SELECT * FROM testData LIMIT 'a'")
}.getMessage
assert(e.contains("number_rows in limit clause cannot be cast to integer:\"a\""))
}

test("negative in LIMIT or TABLESAMPLE") {
Expand Down

0 comments on commit 3036847

Please sign in to comment.