From 1a24082e22514d67ff387bbbf7b355b89ddc66a7 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 2 Jul 2015 00:03:26 -0700 Subject: [PATCH 1/4] Add Python API for hex and unhex --- python/pyspark/sql/functions.py | 28 ++++ .../catalyst/analysis/FunctionRegistry.scala | 2 +- .../spark/sql/catalyst/expressions/math.scala | 136 +++++++++--------- .../expressions/MathFunctionsSuite.scala | 7 +- .../org/apache/spark/sql/functions.scala | 2 +- 5 files changed, 102 insertions(+), 73 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index f9a15d4a66309..e551646d1fb6a 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -381,6 +381,34 @@ def randn(seed=None): return Column(jc) +@ignore_unicode_prefix +@since(1.5) +def hex(col): + """Computes hex value of the given column, which could be StringType, + BinaryType, IntegerType or LongType. + + >>> sqlContext.createDataFrame([('ABC', 3)], ['a', 'b']).select(hex('a'), hex('b')).collect() + [Row(hex(a)=u'414243', hex(b)=u'3')] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.hex(_to_java_column(col)) + return Column(jc) + + +@ignore_unicode_prefix +@since(1.5) +def unhex(col): + """Inverse of hex. Interprets each pair of characters as a hexadecimal number + and converts to the byte representation of number. 
+ + >>> sqlContext.createDataFrame([('414243',)], ['a']).select(unhex('a')).collect() + [Row(unhex(a)=u'ABC')] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.unhex(_to_java_column(col)) + return Column(jc) + + @ignore_unicode_prefix @since(1.5) def sha1(col): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 6f04298d4711b..453f24f5202f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -157,7 +157,7 @@ object FunctionRegistry { expression[Substring]("substr"), expression[Substring]("substring"), expression[Upper]("ucase"), - expression[UnHex]("unhex"), + expression[Unhex]("unhex"), expression[Upper]("upper") ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 8633eb06ffee4..a45185bfe616e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -259,30 +259,22 @@ case class Hex(child: Expression) extends UnaryExpression with Serializable { case LongType => hex(num.asInstanceOf[Long]) case IntegerType => hex(num.asInstanceOf[Integer].toLong) case BinaryType => hex(num.asInstanceOf[Array[Byte]]) - case StringType => hex(num.asInstanceOf[UTF8String]) + case StringType => hex(num.asInstanceOf[UTF8String].getBytes) } } } - /** - * Converts every character in s to two hex digits. 
- */ - private def hex(str: UTF8String): UTF8String = { - hex(str.getBytes) - } + private[this] val hexDigits = Array[Char]( + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F' + ).map(_.toByte) - private def hex(bytes: Array[Byte]): UTF8String = { - doHex(bytes, bytes.length) - } - - private def doHex(bytes: Array[Byte], length: Int): UTF8String = { + private[this] def hex(bytes: Array[Byte]): UTF8String = { + val length = bytes.length val value = new Array[Byte](length * 2) var i = 0 while (i < length) { - value(i * 2) = Character.toUpperCase(Character.forDigit( - (bytes(i) & 0xF0) >>> 4, 16)).toByte - value(i * 2 + 1) = Character.toUpperCase(Character.forDigit( - bytes(i) & 0x0F, 16)).toByte + value(i * 2) = hexDigits((bytes(i) & 0xF0) >> 4) + value(i * 2 + 1) = hexDigits((bytes(i) & 0x0F)) i += 1 } UTF8String.fromBytes(value) @@ -303,6 +295,66 @@ case class Hex(child: Expression) extends UnaryExpression with Serializable { } } +/** + * Performs the inverse operation of HEX. + * Resulting characters are returned as a byte array. + */ +case class Unhex(child: Expression) + extends UnaryExpression with AutoCastInputTypes with Serializable { + + override def nullable: Boolean = true + override def dataType: DataType = BinaryType + override def inputTypes: Seq[DataType] = Seq(BinaryType) + + override def eval(input: InternalRow): Any = { + val num = child.eval(input) + if (num == null) { + null + } else { + unhex(num.asInstanceOf[UTF8String].getBytes) + } + } + + // lookup table to translate '0' -> 0 ... 
'F'/'f' -> 15 + private[this] val unhexDigits = { + val array = Array.fill[Byte](128)(-1) + (0 to 9).foreach(i => array('0' + i) = i.toByte) + (0 to 5).foreach(i => array('A' + i) = (i + 10).toByte) + (0 to 5).foreach(i => array('a' + i) = (i + 10).toByte) + array + } + + private[this] def unhex(bytes: Array[Byte]): Array[Byte] = { + val out = new Array[Byte]((bytes.length + 1) >> 1) + var i = 0 + if ((bytes.length & 0x01) != 0) { + // padding with '0' + if (bytes(0) < 0) { + return null + } + val v = unhexDigits(bytes(0)) + if (v == -1) { + return null + } + out(0) = v + i += 1 + } + // two characters form the hex value. + while (i < bytes.length) { + if (bytes(i) < 0 || bytes(i + 1) < 0) { + return null + } + val first = unhexDigits(bytes(i)) + val second = unhexDigits(bytes(i + 1)) + if (first == -1 || second == -1) { + return null + } + out(i / 2) = (((first << 4) | second) & 0xFF).toByte + i += 2 + } + out + } +} //////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -351,58 +403,6 @@ case class Pow(left: Expression, right: Expression) } } -/** - * Performs the inverse operation of HEX. - * Resulting characters are returned as a byte array. 
- */ -case class UnHex(child: Expression) extends UnaryExpression with Serializable { - - override def dataType: DataType = BinaryType - - override def checkInputDataTypes(): TypeCheckResult = { - if (child.dataType.isInstanceOf[StringType] || child.dataType == NullType) { - TypeCheckResult.TypeCheckSuccess - } else { - TypeCheckResult.TypeCheckFailure(s"unHex accepts String type, not ${child.dataType}") - } - } - - override def eval(input: InternalRow): Any = { - val num = child.eval(input) - if (num == null) { - null - } else { - unhex(num.asInstanceOf[UTF8String].getBytes) - } - } - - private val unhexDigits = { - val array = Array.fill[Byte](128)(-1) - (0 to 9).foreach(i => array('0' + i) = i.toByte) - (0 to 5).foreach(i => array('A' + i) = (i + 10).toByte) - (0 to 5).foreach(i => array('a' + i) = (i + 10).toByte) - array - } - - private def unhex(inputBytes: Array[Byte]): Array[Byte] = { - var bytes = inputBytes - if ((bytes.length & 0x01) != 0) { - bytes = '0'.toByte +: bytes - } - val out = new Array[Byte](bytes.length >> 1) - // two characters form the hex value. 
- var i = 0 - while (i < bytes.length) { - val first = unhexDigits(bytes(i)) - val second = unhexDigits(bytes(i + 1)) - if (first == -1 || second == -1) { return null} - out(i / 2) = (((first << 4) | second) & 0xFF).toByte - i += 2 - } - out - } -} - case class Hypot(left: Expression, right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index b3345d7069159..f37643dcc71bf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -239,9 +239,10 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("unhex") { - checkEvaluation(UnHex(Literal("737472696E67")), "string".getBytes) - checkEvaluation(UnHex(Literal("")), new Array[Byte](0)) - checkEvaluation(UnHex(Literal("0")), Array[Byte](0)) + checkEvaluation(Unhex(Literal("737472696E67")), "string".getBytes) + checkEvaluation(Unhex(Literal("")), new Array[Byte](0)) + checkEvaluation(Unhex(Literal("F")), Array[Byte](15)) + checkEvaluation(Unhex(Literal("ff")), Array[Byte](-1)) } test("hypot") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index e6f623bdf39eb..fda15943ce7dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1060,7 +1060,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def unhex(column: Column): Column = UnHex(column.expr) + def unhex(column: Column): Column = Unhex(column.expr) /** * Inverse of hex. 
Interprets each pair of characters as a hexadecimal number From c3af78c1dad2390b6af4fe18159a65c5c9c1bd1a Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 2 Jul 2015 00:19:11 -0700 Subject: [PATCH 2/4] address comments --- .../apache/spark/sql/catalyst/expressions/math.scala | 2 +- .../sql/catalyst/expressions/MathFunctionsSuite.scala | 11 ++++++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index a45185bfe616e..2839a9a648911 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -300,7 +300,7 @@ case class Hex(child: Expression) extends UnaryExpression with Serializable { * Resulting characters are returned as a byte array. */ case class Unhex(child: Expression) - extends UnaryExpression with AutoCastInputTypes with Serializable { + extends UnaryExpression with ExpectsInputTypes with Serializable { override def nullable: Boolean = true override def dataType: DataType = BinaryType diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index f37643dcc71bf..eb9f49b479496 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.{DataType, DoubleType, LongType} +import org.apache.spark.sql.types._ class MathFunctionsSuite extends SparkFunSuite with 
ExpressionEvalHelper { @@ -226,11 +226,15 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("hex") { + checkEvaluation(Hex(Literal.create(null, IntegerType)), null) checkEvaluation(Hex(Literal(28)), "1C") checkEvaluation(Hex(Literal(-28)), "FFFFFFFFFFFFFFE4") + checkEvaluation(Hex(Literal.create(null, LongType)), null) checkEvaluation(Hex(Literal(100800200404L)), "177828FED4") checkEvaluation(Hex(Literal(-100800200404L)), "FFFFFFE887D7012C") + checkEvaluation(Hex(Literal.create(null, StringType)), null) checkEvaluation(Hex(Literal("helloHex")), "68656C6C6F486578") + checkEvaluation(Hex(Literal.create(null, BinaryType)), null) checkEvaluation(Hex(Literal("helloHex".getBytes())), "68656C6C6F486578") // scalastyle:off // Turn off scala style for non-ascii chars @@ -239,10 +243,15 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("unhex") { + checkEvaluation(Unhex(Literal.create(null, StringType)), null) checkEvaluation(Unhex(Literal("737472696E67")), "string".getBytes) checkEvaluation(Unhex(Literal("")), new Array[Byte](0)) checkEvaluation(Unhex(Literal("F")), Array[Byte](15)) checkEvaluation(Unhex(Literal("ff")), Array[Byte](-1)) + // scalastyle:off + // Turn off scala style for non-ascii chars + checkEvaluation(Unhex(Literal("E4B889E9878DE79A84")), "三重的".getBytes("UTF-8")) + // scalastyle:on } test("hypot") { From 25156b7a4911dc0ee9f41cabbfb529126288f68d Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 2 Jul 2015 10:31:30 -0700 Subject: [PATCH 3/4] address comments and fix test --- python/pyspark/sql/functions.py | 2 +- .../spark/sql/catalyst/expressions/math.scala | 40 +++++++++---------- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index e551646d1fb6a..dbd14d91c04c3 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -402,7 +402,7 @@ def unhex(col): and converts to the 
byte representation of number. >>> sqlContext.createDataFrame([('414243',)], ['a']).select(unhex('a')).collect() - [Row(unhex(a)=u'ABC')] + [Row(unhex(a)=bytearray(b'ABC'))] """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.unhex(_to_java_column(col)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 2839a9a648911..bb1b97904f10e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -228,6 +228,20 @@ case class Bin(child: Expression) } } +object Hex { + val hexDigits = Array[Char]( + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F' + ).map(_.toByte) + + // lookup table to translate '0' -> 0 ... 'F'/'f' -> 15 + val unhexDigits = { + val array = Array.fill[Byte](128)(-1) + (0 to 9).foreach(i => array('0' + i) = i.toByte) + (0 to 5).foreach(i => array('A' + i) = (i + 10).toByte) + (0 to 5).foreach(i => array('a' + i) = (i + 10).toByte) + array + } +} /** * If the argument is an INT or binary, hex returns the number as a STRING in hexadecimal format. 
@@ -264,17 +278,13 @@ case class Hex(child: Expression) extends UnaryExpression with Serializable { } } - private[this] val hexDigits = Array[Char]( - '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F' - ).map(_.toByte) - private[this] def hex(bytes: Array[Byte]): UTF8String = { val length = bytes.length val value = new Array[Byte](length * 2) var i = 0 while (i < length) { - value(i * 2) = hexDigits((bytes(i) & 0xF0) >> 4) - value(i * 2 + 1) = hexDigits((bytes(i) & 0x0F)) + value(i * 2) = Hex.hexDigits((bytes(i) & 0xF0) >> 4) + value(i * 2 + 1) = Hex.hexDigits(bytes(i) & 0x0F) i += 1 } UTF8String.fromBytes(value) @@ -287,8 +297,7 @@ case class Hex(child: Expression) extends UnaryExpression with Serializable { var len = 0 do { len += 1 - value(value.length - len) = Character.toUpperCase(Character - .forDigit((numBuf & 0xF).toInt, 16)).toByte + value(value.length - len) = Hex.hexDigits(numBuf & 0xF) numBuf >>>= 4 } while (numBuf != 0) UTF8String.fromBytes(Arrays.copyOfRange(value, value.length - len, value.length)) @@ -315,15 +324,6 @@ case class Unhex(child: Expression) } } - // lookup table to translate '0' -> 0 ... 
'F'/'f' -> 15 - private[this] val unhexDigits = { - val array = Array.fill[Byte](128)(-1) - (0 to 9).foreach(i => array('0' + i) = i.toByte) - (0 to 5).foreach(i => array('A' + i) = (i + 10).toByte) - (0 to 5).foreach(i => array('a' + i) = (i + 10).toByte) - array - } - private[this] def unhex(bytes: Array[Byte]): Array[Byte] = { val out = new Array[Byte]((bytes.length + 1) >> 1) var i = 0 @@ -332,7 +332,7 @@ case class Unhex(child: Expression) if (bytes(0) < 0) { return null } - val v = unhexDigits(bytes(0)) + val v = Hex.unhexDigits(bytes(0)) if (v == -1) { return null } @@ -344,8 +344,8 @@ case class Unhex(child: Expression) if (bytes(i) < 0 || bytes(i + 1) < 0) { return null } - val first = unhexDigits(bytes(i)) - val second = unhexDigits(bytes(i + 1)) + val first = Hex.unhexDigits(bytes(i)) + val second = Hex.unhexDigits(bytes(i + 1)) if (first == -1 || second == -1) { return null } From b31fc9a76c299fdd5c09210576157b374ebcd245 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 2 Jul 2015 11:29:14 -0700 Subject: [PATCH 4/4] Update math.scala --- .../scala/org/apache/spark/sql/catalyst/expressions/math.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index bb1b97904f10e..b4efc28b9e053 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -297,7 +297,7 @@ case class Hex(child: Expression) extends UnaryExpression with Serializable { var len = 0 do { len += 1 - value(value.length - len) = Hex.hexDigits(numBuf & 0xF) + value(value.length - len) = Hex.hexDigits((numBuf & 0xF).toInt) numBuf >>>= 4 } while (numBuf != 0) UTF8String.fromBytes(Arrays.copyOfRange(value, value.length - len, value.length))