Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-8245][SQL] FormatNumber/Length Support for Expression #7034

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
'coalesce',
'countDistinct',
'explode',
'format_number',
'length',
'log2',
'md5',
'monotonicallyIncreasingId',
Expand All @@ -47,7 +49,6 @@
'sha1',
'sha2',
'sparkPartitionId',
'strlen',
'struct',
'udf',
'when']
Expand Down Expand Up @@ -506,14 +507,28 @@ def sparkPartitionId():

@ignore_unicode_prefix
@since(1.5)
def strlen(col):
"""Calculates the length of a string expression.
def length(col):
"""Calculates the length of a string or binary expression.

>>> sqlContext.createDataFrame([('ABC',)], ['a']).select(strlen('a').alias('length')).collect()
>>> sqlContext.createDataFrame([('ABC',)], ['a']).select(length('a').alias('length')).collect()
[Row(length=3)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.strlen(_to_java_column(col)))
return Column(sc._jvm.functions.length(_to_java_column(col)))


@ignore_unicode_prefix
@since(1.5)
def format_number(col, d):
"""Formats the number X to a format like '#,###,###.##', rounded to d decimal places,
and returns the result as a string.
:param col: the column name of the numeric value to be formatted
:param d: the N decimal places
>>> sqlContext.createDataFrame([(5,)], ['a']).select(format_number('a', 4).alias('v')).collect()
[Row(v=u'5.0000')]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.format_number(_to_java_column(col), d))


@ignore_unicode_prefix
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,12 @@ object FunctionRegistry {
expression[Base64]("base64"),
expression[Encode]("encode"),
expression[Decode]("decode"),
expression[StringInstr]("instr"),
expression[FormatNumber]("format_number"),
expression[Lower]("lcase"),
expression[Lower]("lower"),
expression[StringLength]("length"),
expression[Length]("length"),
expression[Levenshtein]("levenshtein"),
expression[StringInstr]("instr"),
expression[StringLocate]("locate"),
expression[StringLPad]("lpad"),
expression[StringTrimLeft]("ltrim"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,10 @@

package org.apache.spark.sql.catalyst.expressions

import java.text.DecimalFormat
import java.util.Locale
import java.util.regex.Pattern

import org.apache.commons.lang3.StringUtils

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.catalyst.expressions.codegen._
Expand Down Expand Up @@ -553,17 +552,22 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
}

/**
* A function that return the length of the given string expression.
* A function that return the length of the given string or binary expression.
*/
case class StringLength(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
case class Length(child: Expression) extends UnaryExpression with ExpectsInputTypes {
override def dataType: DataType = IntegerType
override def inputTypes: Seq[DataType] = Seq(StringType)
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType))

protected override def nullSafeEval(string: Any): Any =
string.asInstanceOf[UTF8String].numChars
protected override def nullSafeEval(value: Any): Any = child.dataType match {
case StringType => value.asInstanceOf[UTF8String].numChars
case BinaryType => value.asInstanceOf[Array[Byte]].length
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
defineCodeGen(ctx, ev, c => s"($c).numChars()")
child.dataType match {
case StringType => defineCodeGen(ctx, ev, c => s"($c).numChars()")
case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length")
}
}

override def prettyName: String = "length"
Expand Down Expand Up @@ -668,3 +672,77 @@ case class Encode(value: Expression, charset: Expression)
}
}

/**
* Formats the number X to a format like '#,###,###.##', rounded to D decimal places,
* and returns the result as a string. If D is 0, the result has no decimal point or
* fractional part.
*/
case class FormatNumber(x: Expression, d: Expression)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

override prettyName

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: this is done

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry, yes, it's done, but in the end of this class code.

extends BinaryExpression with ExpectsInputTypes {

override def left: Expression = x
override def right: Expression = d
override def dataType: DataType = StringType
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType)

// Associated with the pattern, for the last d value, and we will update the
// pattern (DecimalFormat) once the new coming d value differ with the last one.
@transient
private var lastDValue: Int = -100

// A cached DecimalFormat, for performance concern, we will change it
// only if the d value changed.
@transient
private val pattern: StringBuffer = new StringBuffer()

@transient
private val numberFormat: DecimalFormat = new DecimalFormat("")

override def eval(input: InternalRow): Any = {
val xObject = x.eval(input)
if (xObject == null) {
return null
}

val dObject = d.eval(input)

if (dObject == null || dObject.asInstanceOf[Int] < 0) {
return null
}
val dValue = dObject.asInstanceOf[Int]

if (dValue != lastDValue) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it'd be great to document what's happening here. from what i can tell we are caching the last pattern in order to avoid constant allocating lots of objects.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added some comments in the description of the lastDValue. hopefully people will not get confused.

// construct a new DecimalFormat only if a new dValue
pattern.delete(0, pattern.length())
pattern.append("#,###,###,###,###,###,##0")

// decimal place
if (dValue > 0) {
pattern.append(".")

var i = 0
while (i < dValue) {
i += 1
pattern.append("0")
}
}
val dFormat = new DecimalFormat(pattern.toString())
lastDValue = dValue;
numberFormat.applyPattern(dFormat.toPattern())
}

x.dataType match {
case ByteType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Byte]))
case ShortType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Short]))
case FloatType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Float]))
case IntegerType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Int]))
case LongType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Long]))
case DoubleType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Double]))
case _: DecimalType =>
UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Decimal].toJavaBigDecimal))
}
}

override def prettyName: String = "format_number"
}

Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types.{BinaryType, IntegerType, StringType}
import org.apache.spark.sql.types._


class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
Expand Down Expand Up @@ -216,15 +216,6 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}

test("length for string") {
val a = 'a.string.at(0)
checkEvaluation(StringLength(Literal("abc")), 3, create_row("abdef"))
checkEvaluation(StringLength(a), 5, create_row("abdef"))
checkEvaluation(StringLength(a), 0, create_row(""))
checkEvaluation(StringLength(a), null, create_row(null))
checkEvaluation(StringLength(Literal.create(null, StringType)), null, create_row("abdef"))
}

test("ascii for string") {
val a = 'a.string.at(0)
checkEvaluation(Ascii(Literal("efg")), 101, create_row("abdef"))
Expand Down Expand Up @@ -426,4 +417,46 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(
StringSplit(s1, s2), Seq("aa", "bb", "cc"), row1)
}

test("length for string / binary") {
val a = 'a.string.at(0)
val b = 'b.binary.at(0)
val bytes = Array[Byte](1, 2, 3, 1, 2)
val string = "abdef"

// scalastyle:off
// non ascii characters are not allowed in the source code, so we disable the scalastyle.
checkEvaluation(Length(Literal("a花花c")), 4, create_row(string))
// scalastyle:on
checkEvaluation(Length(Literal(bytes)), 5, create_row(Array[Byte]()))

checkEvaluation(Length(a), 5, create_row(string))
checkEvaluation(Length(b), 5, create_row(bytes))

checkEvaluation(Length(a), 0, create_row(""))
checkEvaluation(Length(b), 0, create_row(Array[Byte]()))

checkEvaluation(Length(a), null, create_row(null))
checkEvaluation(Length(b), null, create_row(null))

checkEvaluation(Length(Literal.create(null, StringType)), null, create_row(string))
checkEvaluation(Length(Literal.create(null, BinaryType)), null, create_row(bytes))
}

test("number format") {
checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Byte]), Literal(3)), "4.000")
checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Short]), Literal(3)), "4.000")
checkEvaluation(FormatNumber(Literal(4.0f), Literal(3)), "4.000")
checkEvaluation(FormatNumber(Literal(4), Literal(3)), "4.000")
checkEvaluation(FormatNumber(Literal(12831273.23481d), Literal(3)), "12,831,273.235")
checkEvaluation(FormatNumber(Literal(12831273.83421d), Literal(0)), "12,831,274")
checkEvaluation(FormatNumber(Literal(123123324123L), Literal(3)), "123,123,324,123.000")
checkEvaluation(FormatNumber(Literal(123123324123L), Literal(-1)), null)
checkEvaluation(
FormatNumber(
Literal(Decimal(123123324123L) * Decimal(123123.21234d)), Literal(4)),
"15,159,339,180,002,773.2778")
checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal(3)), null)
checkEvaluation(FormatNumber(Literal.create(null, NullType), Literal(3)), null)
}
}
32 changes: 28 additions & 4 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1685,20 +1685,44 @@ object functions {
//////////////////////////////////////////////////////////////////////////////////////////////

/**
* Computes the length of a given string value.
* Computes the length of a given string / binary value.
*
* @group string_funcs
* @since 1.5.0
*/
def strlen(e: Column): Column = StringLength(e.expr)
def length(e: Column): Column = Length(e.expr)

/**
* Computes the length of a given string column.
* Computes the length of a given string / binary column.
*
* @group string_funcs
* @since 1.5.0
*/
def strlen(columnName: String): Column = strlen(Column(columnName))
def length(columnName: String): Column = length(Column(columnName))

/**
* Formats the number X to a format like '#,###,###.##', rounded to d decimal places,
* and returns the result as a string.
* If d is 0, the result has no decimal point or fractional part.
* If d < 0, the result will be null.
*
* @group string_funcs
* @since 1.5.0
*/
def format_number(x: Column, d: Int): Column = FormatNumber(x.expr, lit(d).expr)

/**
* Formats the number X to a format like '#,###,###.##', rounded to d decimal places,
* and returns the result as a string.
* If d is 0, the result has no decimal point or fractional part.
* If d < 0, the result will be null.
*
* @group string_funcs
* @since 1.5.0
*/
def format_number(columnXName: String, d: Int): Column = {
format_number(Column(columnXName), d)
}

/**
* Computes the Levenshtein distance of the two given strings.
Expand Down
Loading