Skip to content

Commit

Permalink
add shiftrightunsigned
Browse files Browse the repository at this point in the history
  • Loading branch information
zhichao-li committed Jul 3, 2015
1 parent a59d14f commit d85ae0b
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 0 deletions.
11 changes: 11 additions & 0 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,17 @@ def shiftRight(col, numBits):
jc = sc._jvm.functions.shiftRight(_to_java_column(col), numBits)
return Column(jc)

@since(1.5)
def shiftRightUnsigned(col, numBits):
    """Unsigned shift the given value numBits right.

    The column is converted with ``_to_java_column`` and the shift itself is
    performed by the JVM-side ``functions.shiftRightUnsigned``.

    >>> sqlContext.createDataFrame([(-42,)], ['a']).select(shiftRightUnsigned('a', 1).alias('r'))\
    .collect()
    [Row(r=9223372036854775787)]
    """
    jvm_functions = SparkContext._active_spark_context._jvm.functions
    shifted = jvm_functions.shiftRightUnsigned(_to_java_column(col), numBits)
    return Column(shifted)

@since(1.4)
def sparkPartitionId():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ object FunctionRegistry {
expression[Rint]("rint"),
expression[ShiftLeft]("shiftleft"),
expression[ShiftRight]("shiftright"),
expression[ShiftRightUnsigned]("shiftrightunsigned"),
expression[Signum]("sign"),
expression[Signum]("signum"),
expression[Sin]("sin"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,55 @@ case class ShiftRight(left: Expression, right: Expression) extends BinaryExpress
}
}

/**
 * Bitwise unsigned right shift (`>>>`) of `left` by `right` bits.
 *
 * The first argument must be a long, integer, short or byte and the second an integer; a
 * NullType argument on either side is accepted as well. A long input yields a long result,
 * while integer, short and byte inputs yield an integer result. Evaluation returns null when
 * either operand evaluates to null.
 */
case class ShiftRightUnsigned(left: Expression, right: Expression) extends BinaryExpression {

  /**
   * Accepts (long | integer | short | byte, integer) and any pair containing NullType;
   * every other combination is reported as a type check failure.
   */
  override def checkInputDataTypes(): TypeCheckResult = {
    (left.dataType, right.dataType) match {
      case (NullType, _) | (_, NullType) =>
        TypeCheckResult.TypeCheckSuccess
      case (LongType | IntegerType | ShortType | ByteType, IntegerType) =>
        TypeCheckResult.TypeCheckSuccess
      case _ =>
        TypeCheckResult.TypeCheckFailure(
          s"ShiftRightUnsigned expects long, integer, short or byte value as first argument and an " +
          s"integer value as second argument, not (${left.dataType}, ${right.dataType})")
    }
  }

  /**
   * Evaluates the left operand first; the right operand is only evaluated when the left one
   * is non-null. Returns null if either operand is null.
   */
  override def eval(input: InternalRow): Any = {
    val valueLeft = left.eval(input)
    if (valueLeft == null) {
      null
    } else {
      val valueRight = right.eval(input)
      if (valueRight == null) {
        null
      } else {
        // checkInputDataTypes only lets a non-null right operand through as IntegerType.
        val numBits = valueRight.asInstanceOf[Int]
        left.dataType match {
          case LongType => valueLeft.asInstanceOf[Long] >>> numBits
          // Short and Byte operands are widened by `>>>`, so these branches yield an Int,
          // matching the IntegerType reported by dataType.
          case IntegerType => valueLeft.asInstanceOf[Int] >>> numBits
          case ShortType => valueLeft.asInstanceOf[Short] >>> numBits
          case ByteType => valueLeft.asInstanceOf[Byte] >>> numBits
        }
      }
    }
  }

  /** LongType for a long input, IntegerType for integer/short/byte, NullType otherwise. */
  override def dataType: DataType = {
    left.dataType match {
      case LongType => LongType
      case IntegerType | ShortType | ByteType => IntegerType
      case _ => NullType
    }
  }

  /**
   * Generates a Java `>>>` on the operand values via nullSafeCodeGen
   * (presumably that helper emits the null checks -- confirm against BinaryExpression).
   */
  override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left >>> $right;")
  }
}

/**
* Performs the inverse operation of HEX.
* Resulting characters are returned as a byte array.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,19 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(ShiftRight(Literal(-42.toLong), Literal(1)), -21.toLong)
}

test("shift right unsigned") {
  // A null on either side, or on both, must produce a null result.
  checkEvaluation(ShiftRightUnsigned(Literal.create(null, IntegerType), Literal(1)), null)
  checkEvaluation(ShiftRightUnsigned(Literal(42), Literal.create(null, IntegerType)), null)
  // This case previously built ShiftRight by mistake, so the both-null path of
  // ShiftRightUnsigned was never exercised.
  checkEvaluation(
    ShiftRightUnsigned(Literal.create(null, IntegerType), Literal.create(null, IntegerType)),
    null)

  // Non-negative inputs across every accepted integral type.
  checkEvaluation(ShiftRightUnsigned(Literal(42), Literal(1)), 21)
  checkEvaluation(ShiftRightUnsigned(Literal(42.toByte), Literal(1)), 21)
  checkEvaluation(ShiftRightUnsigned(Literal(42.toShort), Literal(1)), 21)
  checkEvaluation(ShiftRightUnsigned(Literal(42.toLong), Literal(1)), 21.toLong)

  // Negative inputs: the vacated high bit is zero-filled, so the result is positive.
  checkEvaluation(ShiftRightUnsigned(Literal(-42), Literal(1)), 2147483627)
  checkEvaluation(ShiftRightUnsigned(Literal(-42.toLong), Literal(1)), 9223372036854775787L)
}

test("hex") {
checkEvaluation(Hex(Literal(28)), "1C")
checkEvaluation(Hex(Literal(-28)), "FFFFFFFFFFFFFFE4")
Expand Down
20 changes: 20 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1343,6 +1343,26 @@ object functions {
*/
def shiftRight(e: Column, numBits: Int): Column = ShiftRight(e.expr, lit(numBits).expr)

/**
 * Unsigned shift the given value numBits right. If the given value is a long value,
 * it will return a long value else it will return an integer value.
 *
 * This overload takes the column by name and delegates to the `Column` overload.
 *
 * @group math_funcs
 * @since 1.5.0
 */
def shiftRightUnsigned(columnName: String, numBits: Int): Column =
shiftRightUnsigned(Column(columnName), numBits)

/**
 * Unsigned shift the given value numBits right. If the given value is a long value,
 * it will return a long value else it will return an integer value.
 *
 * The shift amount is wrapped as a literal and the result is a `ShiftRightUnsigned`
 * expression over the column's expression.
 *
 * @group math_funcs
 * @since 1.5.0
 */
def shiftRightUnsigned(e: Column, numBits: Int): Column =
ShiftRightUnsigned(e.expr, lit(numBits).expr)

/**
* Shift the given value numBits right. If the given value is a long value, it will return
* a long value else it will return an integer value.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,23 @@ class MathExpressionsSuite extends QueryTest {
Row(21.toLong, 21, 21.toShort, 21.toByte, null))
}

test("shift right unsigned") {
  // One row covering long, integer, short and byte inputs; column "f" is null so that null
  // propagation is checked too. Column "e" is populated but not selected below.
  val df = Seq[(Long, Integer, Short, Byte, Integer, Integer)]((-42, 42, 42, 42, 42, null))
    .toDF("a", "b", "c", "d", "e", "f")

  // Both entry points below must produce this same row.
  val expected = Row(9223372036854775787L, 21, 21.toShort, 21.toByte, null)

  // DataFrame DSL entry point.
  checkAnswer(
    df.select(
      shiftRightUnsigned('a, 1),
      shiftRightUnsigned('b, 1),
      shiftRightUnsigned('c, 1),
      shiftRightUnsigned('d, 1),
      shiftRightUnsigned('f, 1)),
    expected)

  // SQL expression entry point, resolved by function name.
  checkAnswer(
    df.selectExpr(
      "shiftRightUnsigned(a, 1)",
      "shiftRightUnsigned(b, 1)",
      "shiftRightUnsigned(c, 1)",
      "shiftRightUnsigned(d, 1)",
      "shiftRightUnsigned(f, 1)"),
    expected)
}

test("binary log") {
val df = Seq[(Integer, Integer)]((123, null)).toDF("a", "b")
checkAnswer(
Expand Down

0 comments on commit d85ae0b

Please sign in to comment.