Skip to content

Commit

Permalink
[SPARK-8223][SPARK-8224] right and left bit shift
Browse files Browse the repository at this point in the history
  • Loading branch information
tarekbecker committed Jul 2, 2015
1 parent 365c140 commit ac7fe9d
Show file tree
Hide file tree
Showing 6 changed files with 264 additions and 1 deletion.
27 changes: 27 additions & 0 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,33 @@ def sha2(col, numBits):
return Column(jc)


@since(1.5)
def shiftLeft(col, numBits):
    """Shift the given value numBits left. Returns int for tinyint, smallint and int input,
    and bigint for bigint input.

    >>> sqlContext.createDataFrame([(21,)], ['a']).select(shiftLeft('a', 1).alias('r')).collect()
    [Row(r=42)]
    """
    # Delegate to the JVM-side functions.shiftLeft and wrap the resulting Java column.
    jvm_functions = SparkContext._active_spark_context._jvm.functions
    return Column(jvm_functions.shiftLeft(_to_java_column(col), numBits))


@since(1.5)
def shiftRight(col, numBits):
    """Shift the given value numBits right. Returns int for tinyint, smallint and int input,
    and bigint for bigint input.

    >>> sqlContext.createDataFrame([(42,)], ['a']).select(shiftRight('a', 1).alias('r')).collect()
    [Row(r=21)]
    """
    # Delegate to the JVM-side functions.shiftRight and wrap the resulting Java column.
    jvm_functions = SparkContext._active_spark_context._jvm.functions
    return Column(jvm_functions.shiftRight(_to_java_column(col), numBits))


@since(1.4)
def sparkPartitionId():
"""A column for partition ID of the Spark task.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ object FunctionRegistry {
expression[Pow]("power"),
expression[UnaryPositive]("positive"),
expression[Rint]("rint"),
expression[ShiftLeft]("shiftleft"),
expression[ShiftRight]("shiftright"),
expression[Signum]("sign"),
expression[Signum]("signum"),
expression[Sin]("sin"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,142 @@ case class Pow(left: Expression, right: Expression)
}
}

/**
 * Bitwise left shift (`<<`) of an integral value.
 *
 * `left` is the value being shifted and `right` is the number of bit positions. A long input
 * produces a long result; integer, short and byte inputs produce an integer result. The
 * result is null whenever either operand evaluates to null.
 */
case class ShiftLeft(left: Expression, right: Expression) extends Expression {

  override def nullable: Boolean = true

  override def children: Seq[Expression] = Seq(left, right)

  /**
   * Accepts a long, integer, short or byte value combined with an integer shift amount.
   * A NullType operand on either side is accepted as well.
   */
  override def checkInputDataTypes(): TypeCheckResult = {
    // The match is the method's result: every branch yields its outcome directly, so a
    // successful check cannot fall through to the failure case.
    (left.dataType, right.dataType) match {
      case (NullType, _) | (_, NullType) =>
        TypeCheckResult.TypeCheckSuccess
      case (LongType | IntegerType | ShortType | ByteType, IntegerType) =>
        TypeCheckResult.TypeCheckSuccess
      case _ =>
        TypeCheckResult.TypeCheckFailure(
          s"ShiftLeft expects long, integer, short or byte value as first argument and an " +
            s"integer value as second argument, not (${left.dataType}, ${right.dataType})")
    }
  }

  /**
   * Evaluates `left << right` for the given row. The shift amount is only evaluated when the
   * value to shift is non-null; a null on either side gives null.
   */
  override def eval(input: InternalRow): Any = {
    val valueLeft = left.eval(input)
    if (valueLeft != null) {
      val valueRight = right.eval(input)
      if (valueRight != null) {
        valueLeft match {
          case l: Long => l << valueRight.asInstanceOf[Integer]
          case i: Integer => i << valueRight.asInstanceOf[Integer]
          // Short and byte operands are widened, so these two branches produce an Int,
          // in line with dataType.
          case s: Short => s << valueRight.asInstanceOf[Integer]
          case b: Byte => b << valueRight.asInstanceOf[Integer]
        }
      } else {
        null
      }
    } else {
      null
    }
  }

  /** LongType for long input, IntegerType for integer/short/byte input, NullType otherwise. */
  override def dataType: DataType = {
    left.dataType match {
      case LongType => LongType
      case IntegerType | ShortType | ByteType => IntegerType
      case _ => NullType
    }
  }

  /**
   * Emits Java code with the same null handling as eval: the code for the shift amount only
   * runs when the shifted value is non-null, and a null shift amount marks the result null.
   */
  override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    val eval1 = left.gen(ctx)
    val eval2 = right.gen(ctx)
    s"""
      ${eval1.code}
      boolean ${ev.isNull} = ${eval1.isNull};
      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
      if (!${ev.isNull}) {
        ${eval2.code}
        if (!${eval2.isNull}) {
          ${ev.primitive} = ${eval1.primitive} << ${eval2.primitive};
        } else {
          ${ev.isNull} = true;
        }
      }
    """
  }

  override def toString: String = s"ShiftLeft($left, $right)"
}

/**
 * Bitwise right shift (`>>`) of an integral value.
 *
 * `left` is the value being shifted and `right` is the number of bit positions. A long input
 * produces a long result; integer, short and byte inputs produce an integer result. The
 * result is null whenever either operand evaluates to null.
 */
case class ShiftRight(left: Expression, right: Expression) extends Expression {

  override def nullable: Boolean = true

  override def children: Seq[Expression] = Seq(left, right)

  /**
   * Accepts a long, integer, short or byte value combined with an integer shift amount.
   * A NullType operand on either side is accepted as well.
   */
  override def checkInputDataTypes(): TypeCheckResult =
    (left.dataType, right.dataType) match {
      case (NullType, _) | (_, NullType) =>
        TypeCheckResult.TypeCheckSuccess
      case (LongType | IntegerType | ShortType | ByteType, IntegerType) =>
        TypeCheckResult.TypeCheckSuccess
      case _ =>
        TypeCheckResult.TypeCheckFailure(
          s"ShiftRight expects long, integer, short or byte value as first argument and an " +
            s"integer value as second argument, not (${left.dataType}, ${right.dataType})")
    }

  /**
   * Evaluates `left >> right` for the given row. The shift amount is only evaluated when the
   * value to shift is non-null; a null on either side gives null.
   */
  override def eval(input: InternalRow): Any = {
    val operand = left.eval(input)
    if (operand == null) {
      null
    } else {
      val amount = right.eval(input)
      if (amount == null) {
        null
      } else {
        // Kept as a def so the cast only runs inside a branch that actually matched.
        def bits: Int = amount.asInstanceOf[Integer]
        operand match {
          case v: Long => v >> bits
          case v: Int => v >> bits
          // Short and byte operands are widened, so these two branches produce an Int.
          case v: Short => v >> bits
          case v: Byte => v >> bits
        }
      }
    }
  }

  /** LongType for long input, IntegerType for integer/short/byte input, NullType otherwise. */
  override def dataType: DataType = left.dataType match {
    case LongType => LongType
    case IntegerType | ShortType | ByteType => IntegerType
    case _ => NullType
  }

  /**
   * Emits Java code with the same null handling as eval: the code for the shift amount only
   * runs when the shifted value is non-null, and a null shift amount marks the result null.
   */
  override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    val valueCode = left.gen(ctx)
    val bitsCode = right.gen(ctx)
    s"""
      ${valueCode.code}
      boolean ${ev.isNull} = ${valueCode.isNull};
      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
      if (!${ev.isNull}) {
        ${bitsCode.code}
        if (!${bitsCode.isNull}) {
          ${ev.primitive} = ${valueCode.primitive} >> ${bitsCode.primitive};
        } else {
          ${ev.isNull} = true;
        }
      }
    """
  }

  override def toString: String = s"ShiftRight($left, $right)"
}

/**
 * Two-argument math expression that hands `math.hypot` and the label "HYPOT" to
 * BinaryMathExpression.
 */
case class Hypot(left: Expression, right: Expression)
  extends BinaryMathExpression(math.hypot, "HYPOT")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types.{DataType, DoubleType, LongType}
import org.apache.spark.sql.types.{IntegerType, DataType, DoubleType, LongType}

class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {

Expand Down Expand Up @@ -225,6 +225,32 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNull = true)
}

test("shift left") {
  val nullInt = Literal.create(null, IntegerType)
  val byOne = Literal(1)

  // A null on either side, or on both, evaluates to null.
  checkEvaluation(ShiftLeft(nullInt, byOne), null)
  checkEvaluation(ShiftLeft(Literal(21), nullInt), null)
  checkEvaluation(ShiftLeft(nullInt, nullInt), null)

  // One case per accepted integral width.
  checkEvaluation(ShiftLeft(Literal(21), byOne), 42)
  checkEvaluation(ShiftLeft(Literal(21.toByte), byOne), 42)
  checkEvaluation(ShiftLeft(Literal(21.toShort), byOne), 42)
  checkEvaluation(ShiftLeft(Literal(21L), byOne), 42L)

  // Negative long input.
  checkEvaluation(ShiftLeft(Literal(-21L), byOne), -42L)
}

test("shift right") {
  val nullInt = Literal.create(null, IntegerType)
  val byOne = Literal(1)

  // A null on either side, or on both, evaluates to null.
  checkEvaluation(ShiftRight(nullInt, byOne), null)
  checkEvaluation(ShiftRight(Literal(42), nullInt), null)
  checkEvaluation(ShiftRight(nullInt, nullInt), null)

  // One case per accepted integral width.
  checkEvaluation(ShiftRight(Literal(42), byOne), 21)
  checkEvaluation(ShiftRight(Literal(42.toByte), byOne), 21)
  checkEvaluation(ShiftRight(Literal(42.toShort), byOne), 21)
  checkEvaluation(ShiftRight(Literal(42L), byOne), 21L)

  // Negative long input.
  checkEvaluation(ShiftRight(Literal(-42L), byOne), -21L)
}

test("hex") {
checkEvaluation(Hex(Literal(28)), "1C")
checkEvaluation(Hex(Literal(-28)), "FFFFFFFFFFFFFFE4")
Expand Down
38 changes: 38 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1280,6 +1280,44 @@ object functions {
*/
def rint(columnName: String): Column = rint(Column(columnName))

/**
 * Shift the given value numBits left. Returns int for tinyint, smallint and int input, and
 * bigint for bigint input.
 *
 * @group math_funcs
 * @since 1.5.0
 */
def shiftLeft(e: Column, numBits: Integer): Column = {
  val value = e.expr
  val bits = lit(numBits).expr
  ShiftLeft(value, bits)
}

/**
 * Shift the value of the named column numBits left. Returns int for tinyint, smallint and
 * int input, and bigint for bigint input.
 *
 * @group math_funcs
 * @since 1.5.0
 */
def shiftLeft(columnName: String, numBits: Integer): Column = {
  val column = Column(columnName)
  shiftLeft(column, numBits)
}

/**
 * Shift the given value numBits right. Returns int for tinyint, smallint and int input, and
 * bigint for bigint input.
 *
 * @group math_funcs
 * @since 1.5.0
 */
def shiftRight(e: Column, numBits: Integer): Column = {
  val value = e.expr
  val bits = lit(numBits).expr
  ShiftRight(value, bits)
}

/**
 * Shift the value of the named column numBits right. Returns int for tinyint, smallint and
 * int input, and bigint for bigint input.
 *
 * @group math_funcs
 * @since 1.5.0
 */
def shiftRight(columnName: String, numBits: Integer): Column = {
  val column = Column(columnName)
  shiftRight(column, numBits)
}

/**
* Computes the signum of the given value.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,40 @@ class MathExpressionsSuite extends QueryTest {
// Runs the suite's testOneToOneNonNegativeMathFunction helper, pairing `log1p` with
// `math.log1p`.
test("log1p") {
  testOneToOneNonNegativeMathFunction(log1p, math.log1p)
}

test("shift left") {
  val df = Seq[(Long, Integer, Short, Byte, Integer, Integer)]((21, 21, 21, 21, 21, null))
    .toDF("a", "b", "c", "d", "e", "f")

  // DataFrame API: one column per accepted integral width (a-d), a null shift amount (e)
  // and a null input value (f).
  checkAnswer(
    df.select(
      shiftLeft('a, 1), shiftLeft('b, 1), shiftLeft('c, 1), shiftLeft('d, 1),
      shiftLeft('e, null), shiftLeft('f, 1)),
    Row(42.toLong, 42, 42.toShort, 42.toByte, null, null))

  // SQL expressions: the same six cases. The third entry targets column c so that the
  // short column is exercised through this path as well.
  checkAnswer(
    df.selectExpr(
      "shiftLeft(a, 1)", "shiftLeft(b, 1)", "shiftLeft(c, 1)", "shiftLeft(d, 1)",
      "shiftLeft(e, null)", "shiftLeft(f, 1)"),
    Row(42.toLong, 42, 42.toShort, 42.toByte, null, null))
}

test("shift right") {
  val input = Seq[(Long, Integer, Short, Byte, Integer, Integer)]((42, 42, 42, 42, 42, null))
    .toDF("a", "b", "c", "d", "e", "f")
  // Both access paths below are expected to produce this single row.
  val expected = Row(21.toLong, 21, 21.toShort, 21.toByte, null, null)

  // DataFrame API: one column per accepted integral width (a-d), a null shift amount (e)
  // and a null input value (f).
  checkAnswer(
    input.select(
      shiftRight('a, 1), shiftRight('b, 1), shiftRight('c, 1), shiftRight('d, 1),
      shiftRight('e, null), shiftRight('f, 1)),
    expected)

  // SQL expressions: the same six cases.
  checkAnswer(
    input.selectExpr(
      "shiftRight(a, 1)", "shiftRight(b, 1)", "shiftRight(c, 1)", "shiftRight(d, 1)",
      "shiftRight(e, null)", "shiftRight(f, 1)"),
    expected)
}

test("binary log") {
val df = Seq[(Integer, Integer)]((123, null)).toDF("a", "b")
Expand Down

0 comments on commit ac7fe9d

Please sign in to comment.