Skip to content

Commit

Permalink
[SPARK-8226] [SQL] Add function shiftrightunsigned
Browse files Browse the repository at this point in the history
Author: zhichao.li <zhichao.li@intel.com>

Closes #7035 from zhichao-li/shiftRightUnsigned and squashes the following commits:

6bcca5a [zhichao.li] change coding style
3e9f5ae [zhichao.li] python style
d85ae0b [zhichao.li] add shiftrightunsigned
  • Loading branch information
zhichao-li authored and davies committed Jul 3, 2015
1 parent 2848f4d commit ab535b9
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 0 deletions.
13 changes: 13 additions & 0 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,19 @@ def shiftRight(col, numBits):
return Column(jc)


@since(1.5)
def shiftRightUnsigned(col, numBits):
    """Unsigned shift the given value numBits right.

    >>> sqlContext.createDataFrame([(-42,)], ['a']).select(shiftRightUnsigned('a', 1).alias('r'))\
    .collect()
    [Row(r=9223372036854775787)]
    """
    # Resolve the JVM-side functions namespace through the active SparkContext, then wrap
    # the Java column it returns in a Python Column.
    jvm_functions = SparkContext._active_spark_context._jvm.functions
    return Column(jvm_functions.shiftRightUnsigned(_to_java_column(col), numBits))


@since(1.4)
def sparkPartitionId():
"""A column for partition ID of the Spark task.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ object FunctionRegistry {
expression[Rint]("rint"),
expression[ShiftLeft]("shiftleft"),
expression[ShiftRight]("shiftright"),
expression[ShiftRightUnsigned]("shiftrightunsigned"),
expression[Signum]("sign"),
expression[Signum]("signum"),
expression[Sin]("sin"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,55 @@ case class ShiftRight(left: Expression, right: Expression) extends BinaryExpress
}
}

/**
 * Bitwise unsigned right shift (`>>>`) of `left` by `right` bits.
 *
 * Accepted input types: a long, integer, short or byte as the first argument together with an
 * integer as the second argument. A NullType on either side is accepted as well.
 *
 * The result type is long when the first argument is a long, and integer when it is an
 * integer, short or byte. Evaluation yields null as soon as either input evaluates to null.
 */
case class ShiftRightUnsigned(left: Expression, right: Expression) extends BinaryExpression {

  override def checkInputDataTypes(): TypeCheckResult = {
    (left.dataType, right.dataType) match {
      // NullType on either side is accepted regardless of the other side's type.
      case (NullType, _) | (_, NullType) =>
        TypeCheckResult.TypeCheckSuccess
      // Integral value shifted by an integer number of bits.
      case (LongType | IntegerType | ShortType | ByteType, IntegerType) =>
        TypeCheckResult.TypeCheckSuccess
      case _ =>
        TypeCheckResult.TypeCheckFailure(
          s"ShiftRightUnsigned expects long, integer, short or byte value as first argument and an " +
          s"integer value as second argument, not (${left.dataType}, ${right.dataType})")
    }
  }

  override def eval(input: InternalRow): Any = {
    val valueLeft = left.eval(input)
    if (valueLeft == null) {
      // The right child is deliberately not evaluated when the left one is already null.
      null
    } else {
      val valueRight = right.eval(input)
      if (valueRight == null) {
        null
      } else {
        val numBits: Int = valueRight.asInstanceOf[Integer]
        // Keep this match in tail position: the expected type Any stops the compiler from
        // widening the Int-valued branches to Long.
        valueLeft match {
          case l: Long => l >>> numBits
          case i: Integer => i >>> numBits
          case s: Short => s >>> numBits
          case b: Byte => b >>> numBits
        }
      }
    }
  }

  override def dataType: DataType = {
    left.dataType match {
      case LongType => LongType
      // Shifting a short or byte produces an Int, so all three map to IntegerType.
      case IntegerType | ShortType | ByteType => IntegerType
      case _ => NullType
    }
  }

  override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left >>> $right;")
  }
}

/**
* Performs the inverse operation of HEX.
* Resulting characters are returned as a byte array.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,19 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(ShiftRight(Literal(-42.toLong), Literal(1)), -21.toLong)
}

test("shift right unsigned") {
  // A null on the left, on the right, or on both sides must evaluate to null.
  checkEvaluation(ShiftRightUnsigned(Literal.create(null, IntegerType), Literal(1)), null)
  checkEvaluation(ShiftRightUnsigned(Literal(42), Literal.create(null, IntegerType)), null)
  // Previously this case built a ShiftRight, so the both-null path of ShiftRightUnsigned
  // was never exercised.
  checkEvaluation(
    ShiftRightUnsigned(Literal.create(null, IntegerType), Literal.create(null, IntegerType)),
    null)

  // Positive inputs of every supported integral type.
  checkEvaluation(ShiftRightUnsigned(Literal(42), Literal(1)), 21)
  checkEvaluation(ShiftRightUnsigned(Literal(42.toByte), Literal(1)), 21)
  checkEvaluation(ShiftRightUnsigned(Literal(42.toShort), Literal(1)), 21)
  checkEvaluation(ShiftRightUnsigned(Literal(42.toLong), Literal(1)), 21.toLong)

  // Negative long: the sign bit is shifted out, giving a large positive value.
  checkEvaluation(ShiftRightUnsigned(Literal(-42.toLong), Literal(1)), 9223372036854775787L)
}

test("hex") {
checkEvaluation(Hex(Literal(28)), "1C")
checkEvaluation(Hex(Literal(-28)), "FFFFFFFFFFFFFFE4")
Expand Down
20 changes: 20 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1343,6 +1343,26 @@ object functions {
*/
def shiftRight(e: Column, numBits: Int): Column = ShiftRight(e.expr, lit(numBits).expr)

/**
 * Unsigned shift the given value numBits right. If the given value is a long value,
 * it will return a long value else it will return an integer value.
 *
 * @param columnName name of the column holding the value to shift
 * @param numBits number of bit positions to shift by
 * @group math_funcs
 * @since 1.5.0
 */
def shiftRightUnsigned(columnName: String, numBits: Int): Column =
  shiftRightUnsigned(Column(columnName), numBits)

/**
 * Unsigned shift the given value numBits right. If the given value is a long value,
 * it will return a long value else it will return an integer value.
 *
 * @param e column holding the value to shift
 * @param numBits number of bit positions to shift by
 * @group math_funcs
 * @since 1.5.0
 */
def shiftRightUnsigned(e: Column, numBits: Int): Column =
  ShiftRightUnsigned(e.expr, lit(numBits).expr)

/**
* Shift the given value numBits right. If the given value is a long value, it will return
* a long value else it will return an integer value.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,23 @@ class MathExpressionsSuite extends QueryTest {
Row(21.toLong, 21, 21.toShort, 21.toByte, null))
}

test("shift right unsigned") {
  val df = Seq[(Long, Integer, Short, Byte, Integer, Integer)]((-42, 42, 42, 42, 42, null))
    .toDF("a", "b", "c", "d", "e", "f")

  // Both the Column API and the SQL expression API must produce this same row.
  // Column "e" is present in the frame but is not selected by either query.
  val expected = Row(9223372036854775787L, 21, 21.toShort, 21.toByte, null)

  // Column API.
  val viaColumns = df.select(
    shiftRightUnsigned('a, 1),
    shiftRightUnsigned('b, 1),
    shiftRightUnsigned('c, 1),
    shiftRightUnsigned('d, 1),
    shiftRightUnsigned('f, 1))
  checkAnswer(viaColumns, expected)

  // SQL expression API, resolved through the function registry.
  val sqlExprs = Seq("a", "b", "c", "d", "f").map(name => s"shiftRightUnsigned($name, 1)")
  checkAnswer(df.selectExpr(sqlExprs: _*), expected)
}

test("binary log") {
val df = Seq[(Integer, Integer)]((123, null)).toDF("a", "b")
checkAnswer(
Expand Down

0 comments on commit ab535b9

Please sign in to comment.