Skip to content

Commit

Permalink
[SPARK-29854] lpad and rpad built-in functions should show an error or throw an exception for an invalid length value
Browse files Browse the repository at this point in the history
  • Loading branch information
07ARB committed Dec 28, 2019
1 parent 7adf886 commit 2793ac6
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 0 deletions.
Expand Up @@ -26,6 +26,7 @@ import scala.collection.mutable.ArrayBuffer

import org.apache.commons.codec.binary.{Base64 => CommonsBase64}

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
Expand Down Expand Up @@ -1227,6 +1228,19 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression = Litera
override def dataType: DataType = StringType
override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, StringType)

/**
 * Runs the inherited input type check and escalates one specific failure case.
 *
 * When the inherited check fails, `len` is non-null, and the textual form of `len`
 * cannot be parsed as an `Int`, an [[AnalysisException]] carrying the failed check
 * is thrown. In every other case the inherited result is returned unchanged, which
 * includes returning a failure when `len` does parse as an `Int`.
 *
 * @return the result of `super.checkInputDataTypes()`
 * @throws AnalysisException if the inherited check failed and `len.toString` is not
 *                           parseable as an `Int`
 */
override def checkInputDataTypes(): TypeCheckResult = {
  val inputTypeCheck = super.checkInputDataTypes()
  if (inputTypeCheck.isFailure && len != null) {
    // Only the parse outcome matters here; the parsed number itself is never used.
    val lenParsesAsInt =
      try {
        len.toString.toInt
        true
      } catch {
        case _: NumberFormatException => false
      }
    if (!lenParsesAsInt) {
      throw new AnalysisException(s"Invalid argument, $inputTypeCheck")
    }
  }
  inputTypeCheck
}

/**
 * Left-pads the evaluated string argument by delegating to `UTF8String.lpad`.
 *
 * @param str evaluated string input, cast to `UTF8String`
 * @param len evaluated length input, cast to `Int`
 * @param pad evaluated padding input, cast to `UTF8String`
 * @return whatever `UTF8String.lpad` returns for these arguments
 */
override def nullSafeEval(str: Any, len: Any, pad: Any): Any = {
  val source = str.asInstanceOf[UTF8String]
  val targetLength = len.asInstanceOf[Int]
  val filler = pad.asInstanceOf[UTF8String]
  source.lpad(targetLength, filler)
}
Expand Down Expand Up @@ -1268,6 +1282,19 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression = Litera
override def dataType: DataType = StringType
override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, StringType)

/**
 * Runs the inherited input type check and escalates one specific failure case.
 *
 * When the inherited check fails, `len` is non-null, and the textual form of `len`
 * cannot be parsed as an `Int`, an [[AnalysisException]] carrying the failed check
 * is thrown. In every other case the inherited result is returned unchanged, which
 * includes returning a failure when `len` does parse as an `Int`.
 *
 * @return the result of `super.checkInputDataTypes()`
 * @throws AnalysisException if the inherited check failed and `len.toString` is not
 *                           parseable as an `Int`
 */
override def checkInputDataTypes(): TypeCheckResult = {
  val inputTypeCheck = super.checkInputDataTypes()
  if (inputTypeCheck.isFailure && len != null) {
    // Only the parse outcome matters here; the parsed number itself is never used.
    val lenParsesAsInt =
      try {
        len.toString.toInt
        true
      } catch {
        case _: NumberFormatException => false
      }
    if (!lenParsesAsInt) {
      throw new AnalysisException(s"Invalid argument, $inputTypeCheck")
    }
  }
  inputTypeCheck
}

/**
 * Right-pads the evaluated string argument by delegating to `UTF8String.rpad`.
 *
 * @param str evaluated string input, cast to `UTF8String`
 * @param len evaluated length input, cast to `Int`
 * @param pad evaluated padding input, cast to `UTF8String`
 * @return whatever `UTF8String.rpad` returns for these arguments
 */
override def nullSafeEval(str: Any, len: Any, pad: Any): Any = {
  val source = str.asInstanceOf[UTF8String]
  val targetLength = len.asInstanceOf[Int]
  val filler = pad.asInstanceOf[UTF8String]
  source.rpad(targetLength, filler)
}
Expand Down
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolderSparkSubmitSuite.{assert, intercept}
import org.apache.spark.sql.types._

class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
Expand Down Expand Up @@ -720,6 +721,16 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(StringRPad(s1, s2, s3), null, row5)
checkEvaluation(StringRPad(Literal("hi"), Literal(5)), "hi ")
checkEvaluation(StringRPad(Literal("hi"), Literal(1)), "h")

assert(intercept[AnalysisException] {
checkEvaluation(StringRPad(Literal("hi"), Literal("invalidLength")), "Exception")
}.getMessage.contains("Invalid argument, TypeCheckFailure(argument 2 " +
"requires int type, however, ''invalidLength'' is of string type.);"))

assert(intercept[AnalysisException] {
checkEvaluation(StringLPad(Literal("hi"), Literal("invalidLength")), "Exception")
}.getMessage.contains("Invalid argument, TypeCheckFailure(argument 2 " +
"requires int type, however, ''invalidLength'' is of string type.);"))
}

test("REPEAT") {
Expand Down

0 comments on commit 2793ac6

Please sign in to comment.