diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index dca39fa833435..e0816b3e654bc 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -39,6 +39,8 @@ 'coalesce', 'countDistinct', 'explode', + 'format_number', + 'length', 'log2', 'md5', 'monotonicallyIncreasingId', @@ -47,7 +49,6 @@ 'sha1', 'sha2', 'sparkPartitionId', - 'strlen', 'struct', 'udf', 'when'] @@ -506,14 +507,28 @@ def sparkPartitionId(): @ignore_unicode_prefix @since(1.5) -def strlen(col): - """Calculates the length of a string expression. +def length(col): + """Calculates the length of a string or binary expression. - >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(strlen('a').alias('length')).collect() + >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(length('a').alias('length')).collect() [Row(length=3)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.strlen(_to_java_column(col))) + return Column(sc._jvm.functions.length(_to_java_column(col))) + + +@ignore_unicode_prefix +@since(1.5) +def format_number(col, d): + """Formats the number X to a format like '#,###,###.##', rounded to d decimal places, + and returns the result as a string. + :param col: the column name of the numeric value to be formatted + :param d: the N decimal places + >>> sqlContext.createDataFrame([(5,)], ['a']).select(format_number('a', 4).alias('v')).collect() + [Row(v=u'5.0000')] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.format_number(_to_java_column(col), d)) @ignore_unicode_prefix diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index d2678ce860701..e0beafe710079 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -152,11 +152,12 @@ object FunctionRegistry { expression[Base64]("base64"), expression[Encode]("encode"), expression[Decode]("decode"), - expression[StringInstr]("instr"), + expression[FormatNumber]("format_number"), expression[Lower]("lcase"), expression[Lower]("lower"), - expression[StringLength]("length"), + expression[Length]("length"), expression[Levenshtein]("levenshtein"), + expression[StringInstr]("instr"), expression[StringLocate]("locate"), expression[StringLPad]("lpad"), expression[StringTrimLeft]("ltrim"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 03b55ce5fe7cc..c64afe7b3f19a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql.catalyst.expressions +import java.text.DecimalFormat import java.util.Locale import java.util.regex.Pattern -import org.apache.commons.lang3.StringUtils - import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -553,17 +552,22 @@ case class Substring(str: Expression, pos: Expression, len: Expression) } /** - * A function that return the length of the given string expression. + * A function that return the length of the given string or binary expression. */ -case class StringLength(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Length(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def dataType: DataType = IntegerType - override def inputTypes: Seq[DataType] = Seq(StringType) + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) - protected override def nullSafeEval(string: Any): Any = - string.asInstanceOf[UTF8String].numChars + protected override def nullSafeEval(value: Any): Any = child.dataType match { + case StringType => value.asInstanceOf[UTF8String].numChars + case BinaryType => value.asInstanceOf[Array[Byte]].length + } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, c => s"($c).numChars()") + child.dataType match { + case StringType => defineCodeGen(ctx, ev, c => s"($c).numChars()") + case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length") + } } override def prettyName: String = "length" @@ -668,3 +672,77 @@ case class Encode(value: Expression, charset: Expression) } } +/** + * Formats the number X to a format like '#,###,###.##', rounded to D decimal places, + * and returns the result as a string. If D is 0, the result has no decimal point or + * fractional part. + */ +case class FormatNumber(x: Expression, d: Expression) + extends BinaryExpression with ExpectsInputTypes { + + override def left: Expression = x + override def right: Expression = d + override def dataType: DataType = StringType + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType) + + // Associated with the pattern, for the last d value, and we will update the + // pattern (DecimalFormat) once the new coming d value differ with the last one. + @transient + private var lastDValue: Int = -100 + + // A cached DecimalFormat, for performance concern, we will change it + // only if the d value changed. + @transient + private val pattern: StringBuffer = new StringBuffer() + + @transient + private val numberFormat: DecimalFormat = new DecimalFormat("") + + override def eval(input: InternalRow): Any = { + val xObject = x.eval(input) + if (xObject == null) { + return null + } + + val dObject = d.eval(input) + + if (dObject == null || dObject.asInstanceOf[Int] < 0) { + return null + } + val dValue = dObject.asInstanceOf[Int] + + if (dValue != lastDValue) { + // construct a new DecimalFormat only if a new dValue + pattern.delete(0, pattern.length()) + pattern.append("#,###,###,###,###,###,##0") + + // decimal place + if (dValue > 0) { + pattern.append(".") + + var i = 0 + while (i < dValue) { + i += 1 + pattern.append("0") + } + } + val dFormat = new DecimalFormat(pattern.toString()) + lastDValue = dValue; + numberFormat.applyPattern(dFormat.toPattern()) + } + + x.dataType match { + case ByteType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Byte])) + case ShortType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Short])) + case FloatType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Float])) + case IntegerType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Int])) + case LongType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Long])) + case DoubleType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Double])) + case _: DecimalType => + UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Decimal].toJavaBigDecimal)) + } + } + + override def prettyName: String = "format_number" +} + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala index b19f4ee37a109..5d7763bedf6bd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.{BinaryType, IntegerType, StringType} +import org.apache.spark.sql.types._ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -216,15 +216,6 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } - test("length for string") { - val a = 'a.string.at(0) - checkEvaluation(StringLength(Literal("abc")), 3, create_row("abdef")) - checkEvaluation(StringLength(a), 5, create_row("abdef")) - checkEvaluation(StringLength(a), 0, create_row("")) - checkEvaluation(StringLength(a), null, create_row(null)) - checkEvaluation(StringLength(Literal.create(null, StringType)), null, create_row("abdef")) - } - test("ascii for string") { val a = 'a.string.at(0) checkEvaluation(Ascii(Literal("efg")), 101, create_row("abdef")) @@ -426,4 +417,46 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( StringSplit(s1, s2), Seq("aa", "bb", "cc"), row1) } + + test("length for string / binary") { + val a = 'a.string.at(0) + val b = 'b.binary.at(0) + val bytes = Array[Byte](1, 2, 3, 1, 2) + val string = "abdef" + + // scalastyle:off + // non ascii characters are not allowed in the source code, so we disable the scalastyle. + checkEvaluation(Length(Literal("a花花c")), 4, create_row(string)) + // scalastyle:on + checkEvaluation(Length(Literal(bytes)), 5, create_row(Array[Byte]())) + + checkEvaluation(Length(a), 5, create_row(string)) + checkEvaluation(Length(b), 5, create_row(bytes)) + + checkEvaluation(Length(a), 0, create_row("")) + checkEvaluation(Length(b), 0, create_row(Array[Byte]())) + + checkEvaluation(Length(a), null, create_row(null)) + checkEvaluation(Length(b), null, create_row(null)) + + checkEvaluation(Length(Literal.create(null, StringType)), null, create_row(string)) + checkEvaluation(Length(Literal.create(null, BinaryType)), null, create_row(bytes)) + } + + test("number format") { + checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Byte]), Literal(3)), "4.000") + checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Short]), Literal(3)), "4.000") + checkEvaluation(FormatNumber(Literal(4.0f), Literal(3)), "4.000") + checkEvaluation(FormatNumber(Literal(4), Literal(3)), "4.000") + checkEvaluation(FormatNumber(Literal(12831273.23481d), Literal(3)), "12,831,273.235") + checkEvaluation(FormatNumber(Literal(12831273.83421d), Literal(0)), "12,831,274") + checkEvaluation(FormatNumber(Literal(123123324123L), Literal(3)), "123,123,324,123.000") + checkEvaluation(FormatNumber(Literal(123123324123L), Literal(-1)), null) + checkEvaluation( + FormatNumber( + Literal(Decimal(123123324123L) * Decimal(123123.21234d)), Literal(4)), + "15,159,339,180,002,773.2778") + checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal(3)), null) + checkEvaluation(FormatNumber(Literal.create(null, NullType), Literal(3)), null) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index c7deaca8437a1..d6da284a4c788 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1685,20 +1685,44 @@ object functions { ////////////////////////////////////////////////////////////////////////////////////////////// /** - * Computes the length of a given string value. + * Computes the length of a given string / binary value. * * @group string_funcs * @since 1.5.0 */ - def strlen(e: Column): Column = StringLength(e.expr) + def length(e: Column): Column = Length(e.expr) /** - * Computes the length of a given string column. + * Computes the length of a given string / binary column. * * @group string_funcs * @since 1.5.0 */ - def strlen(columnName: String): Column = strlen(Column(columnName)) + def length(columnName: String): Column = length(Column(columnName)) + + /** + * Formats the number X to a format like '#,###,###.##', rounded to d decimal places, + * and returns the result as a string. + * If d is 0, the result has no decimal point or fractional part. + * If d < 0, the result will be null. + * + * @group string_funcs + * @since 1.5.0 + */ + def format_number(x: Column, d: Int): Column = FormatNumber(x.expr, lit(d).expr) + + /** + * Formats the number X to a format like '#,###,###.##', rounded to d decimal places, + * and returns the result as a string. + * If d is 0, the result has no decimal point or fractional part. + * If d < 0, the result will be null. + * + * @group string_funcs + * @since 1.5.0 + */ + def format_number(columnXName: String, d: Int): Column = { + format_number(Column(columnXName), d) + } /** * Computes the Levenshtein distance of the two given strings. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 70bd78737f69c..6dccdd857b453 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -208,17 +208,6 @@ class DataFrameFunctionsSuite extends QueryTest { Row(2743272264L, 2180413220L)) } - test("string length function") { - val df = Seq(("abc", "")).toDF("a", "b") - checkAnswer( - df.select(strlen($"a"), strlen("b")), - Row(3, 0)) - - checkAnswer( - df.selectExpr("length(a)", "length(b)"), - Row(3, 0)) - } - test("Levenshtein distance") { val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r") checkAnswer(df.select(levenshtein("l", "r")), Seq(Row(3), Row(1))) @@ -433,11 +422,91 @@ class DataFrameFunctionsSuite extends QueryTest { val doubleData = Seq((7.2, 4.1)).toDF("a", "b") checkAnswer( doubleData.select(pmod('a, 'b)), - Seq(Row(3.1000000000000005)) // same as hive + Seq(Row(3.1000000000000005)) // same as hive ) checkAnswer( doubleData.select(pmod(lit(2), lit(Int.MaxValue))), Seq(Row(2)) ) } + + test("string / binary length function") { + val df = Seq(("123", Array[Byte](1, 2, 3, 4), 123)).toDF("a", "b", "c") + checkAnswer( + df.select(length($"a"), length("a"), length($"b"), length("b")), + Row(3, 3, 4, 4)) + + checkAnswer( + df.selectExpr("length(a)", "length(b)"), + Row(3, 4)) + + intercept[AnalysisException] { + checkAnswer( + df.selectExpr("length(c)"), // int type of the argument is unacceptable + Row("5.0000")) + } + } + + test("number format function") { + val tuple = + ("aa", 1.asInstanceOf[Byte], 2.asInstanceOf[Short], + 3.13223f, 4, 5L, 6.48173d, Decimal(7.128381)) + val df = + Seq(tuple) + .toDF( + "a", // string "aa" + "b", // byte 1 + "c", // short 2 + "d", // float 3.13223f + "e", // integer 4 + "f", // long 5L + "g", // double 6.48173d + "h") // decimal 7.128381 + + checkAnswer( + df.select( + format_number($"f", 4), + format_number("f", 4)), + Row("5.0000", "5.0000")) + + checkAnswer( + df.selectExpr("format_number(b, e)"), // convert the 1st argument to integer + Row("1.0000")) + + checkAnswer( + df.selectExpr("format_number(c, e)"), // convert the 1st argument to integer + Row("2.0000")) + + checkAnswer( + df.selectExpr("format_number(d, e)"), // convert the 1st argument to double + Row("3.1322")) + + checkAnswer( + df.selectExpr("format_number(e, e)"), // not convert anything + Row("4.0000")) + + checkAnswer( + df.selectExpr("format_number(f, e)"), // not convert anything + Row("5.0000")) + + checkAnswer( + df.selectExpr("format_number(g, e)"), // not convert anything + Row("6.4817")) + + checkAnswer( + df.selectExpr("format_number(h, e)"), // not convert anything + Row("7.1284")) + + intercept[AnalysisException] { + checkAnswer( + df.selectExpr("format_number(a, e)"), // string type of the 1st argument is unacceptable + Row("5.0000")) + } + + intercept[AnalysisException] { + checkAnswer( + df.selectExpr("format_number(e, g)"), // decimal type of the 2nd argument is unacceptable + Row("5.0000")) + } + } }