From afe626cccc158691ec500a0653a4907d281044bb Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 21 Nov 2015 16:28:55 +0800 Subject: [PATCH 1/4] Add misc function hash. --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../spark/sql/catalyst/expressions/misc.scala | 39 +++++++++++++++++++ .../expressions/MiscFunctionsSuite.scala | 36 ++++++++++++++++- .../org/apache/spark/sql/functions.scala | 9 +++++ .../spark/sql/DataFrameFunctionsSuite.scala | 17 ++++++++ 5 files changed, 101 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index f9c04d7ec0b0c..995161ebccb12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -120,6 +120,7 @@ object FunctionRegistry { expression[Coalesce]("coalesce"), expression[Explode]("explode"), expression[Greatest]("greatest"), + expression[Hash]("hash"), expression[If]("if"), expression[IsNaN]("isnan"), expression[IsNull]("isnull"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 0f6d02f2e00c2..22c8d19cd78bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -23,6 +23,7 @@ import java.util.zip.CRC32 import org.apache.commons.codec.digest.DigestUtils import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -45,6 +46,44 @@ case class Md5(child: Expression) extends UnaryExpression with ImplicitCastInput } } +case class Hash(children: Expression*) extends Expression with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = children.map(_.dataType) + override def dataType: DataType = IntegerType + + override def nullable: Boolean = children.exists(_.nullable) + + @transient + private lazy val extractProjection = GenerateUnsafeProjection.generate(children) + + def getProjection: UnsafeProjection = extractProjection + + override def eval(input: InternalRow): Any = { + extractProjection(input).hashCode + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val hashExpressionClassName = classOf[Hash].getName + val projectionClassName = classOf[UnsafeProjection].getName + + ctx.references += this + val hashExpressionTermIndex = ctx.references.size - 1 + + val projectionTerm = ctx.freshName("projection") + ctx.addMutableState(projectionClassName, projectionTerm, + s"this.$projectionTerm = ($projectionClassName)((($hashExpressionClassName)expressions" + + s"[$hashExpressionTermIndex]).getProjection());") + + s""" + boolean ${ev.isNull} = false; + Integer ${ev.value} = $projectionTerm.apply(${ctx.INPUT_ROW}).hashCode(); + if (${ev.value} == null) { + ${ev.isNull} = true; + } + """ + } +} + /** * A function that calculates the SHA-2 family of functions (SHA-224, SHA-256, SHA-384, and SHA-512) * and returns it as a hex string. The first argument is the string or binary to be hashed. The diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala index 75d17417e5a02..dfd544e2bea20 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala @@ -20,7 +20,9 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.commons.codec.digest.DigestUtils import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.types.{IntegerType, StringType, BinaryType} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types._ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -32,6 +34,38 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkConsistencyBetweenInterpretedAndCodegen(Md5, BinaryType) } + test("hash") { + def projection(exprs: Expression*): UnsafeProjection = + GenerateUnsafeProjection.generate(exprs) + def getHashCode(inputs: Expression*): Int = projection(inputs: _*)(null).hashCode + + checkEvaluation(Hash(Literal(3)), getHashCode(Literal(3))) + checkEvaluation(Hash(Literal(3L)), getHashCode(Literal(3L))) + checkEvaluation(Hash(Literal(3.7d)), getHashCode(Literal(3.7d))) + checkEvaluation(Hash(Literal(3.7f)), getHashCode(Literal(3.7f))) + val v1: Byte = 3 + val v2: Short = 3 + checkEvaluation(Hash(Literal(v1)), getHashCode(Literal(v1))) + checkEvaluation(Hash(Literal(v2)), getHashCode(Literal(v2))) + checkEvaluation(Hash(Literal(v1), Literal(v2)), getHashCode(Literal(v1), Literal(v2))) + checkEvaluation(Hash(Literal("ABC")), getHashCode(Literal("ABC"))) + checkEvaluation(Hash(Literal(true)), getHashCode(Literal(true))) + checkEvaluation(Hash(Literal.create(Decimal(3.7), DecimalType.Unlimited)), + getHashCode(Literal.create(Decimal(3.7), DecimalType.Unlimited))) + checkEvaluation(Hash(Literal.create(java.sql.Date.valueOf("1991-12-07"), DateType)), + getHashCode(Literal.create(java.sql.Date.valueOf("1991-12-07"), DateType))) + checkEvaluation( + Hash(Literal.create(java.sql.Timestamp.valueOf("1991-12-07 12:00:00"), TimestampType)), + getHashCode(Literal.create(java.sql.Timestamp.valueOf("1991-12-07 12:00:00"), TimestampType))) + checkEvaluation(Hash(Literal.create(Map[Int, Int](1 -> 2), MapType(IntegerType, IntegerType))), + getHashCode(Literal.create(Map[Int, Int](1 -> 2), MapType(IntegerType, IntegerType)))) + checkEvaluation(Hash(Literal.create(Seq[Byte](1, 2, 3, 4, 5, 6), ArrayType(ByteType))), + getHashCode(Literal.create(Seq[Byte](1, 2, 3, 4, 5, 6), ArrayType(ByteType)))) + checkEvaluation(Hash(Literal.create(Seq[Double](1.1, 2.2, 3.3, 4.4, 5.5, 6.6), + ArrayType(DoubleType))), getHashCode(Literal.create(Seq[Double](1.1, 2.2, 3.3, 4.4, 5.5, 6.6), + ArrayType(DoubleType)))) + } + test("sha1") { checkEvaluation(Sha1(Literal("ABC".getBytes)), "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8") checkEvaluation(Sha1(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index b27b1340cce46..48b69d8316fdc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1665,6 +1665,15 @@ object functions extends LegacyFunctions { // Misc functions ////////////////////////////////////////////////////////////////////////////////////////////// + /** + * Returns a hash value of the arguments. + * + * @group misc_funcs + * @since 1.6.0 + */ + @scala.annotation.varargs + def hash(exprs: Column*): Column = withExpr { Hash(exprs.map(_.expr): _*) } + /** * Calculates the MD5 digest of a binary column and returns the value * as a 32 character hex string. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index aff9efe4b2b16..63c060f3751fb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -166,6 +168,21 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row("902fbdd2b1df0c4f70b4a5d23525e932", "6ac1e56bc78f031059be7be854522c4c")) } + test("misc hash function") { + def projection(exprs: Expression*): UnsafeProjection = + GenerateUnsafeProjection.generate(exprs) + def getHashCode(inputs: Expression*): Int = projection(inputs: _*)(null).hashCode + + val df = Seq(("ABC", 3.7f)).toDF("a", "b") + checkAnswer( + df.select(hash($"a"), hash($"b")), + Row(getHashCode(Literal("ABC")), getHashCode(Literal(3.7f)))) + + checkAnswer( + df.selectExpr("hash(a)", "hash(b)"), + Row(getHashCode(Literal("ABC")), getHashCode(Literal(3.7f)))) + } + test("misc sha1 function") { val df = Seq(("ABC", "ABC".getBytes)).toDF("a", "b") checkAnswer( From 47072b8dcb0a6676a9d143bb4a46df014f3b9a10 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 22 Nov 2015 18:30:44 +0800 Subject: [PATCH 2/4] Try to be consistent with Hive's hash function. --- .../main/scala/org/apache/spark/sql/Row.scala | 38 ++++++++++++++----- .../sql/catalyst/encoders/RowEncoder.scala | 5 ++- .../codegen/GenerateProjection.scala | 1 + .../spark/sql/catalyst/expressions/misc.scala | 34 ++++++++++++++--- .../spark/sql/catalyst/expressions/rows.scala | 1 + .../expressions/MiscFunctionsSuite.scala | 38 +++++++------------ .../spark/sql/DataFrameFunctionsSuite.scala | 9 +---- 7 files changed, 78 insertions(+), 48 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index b14c66cc5ac88..1a6317571abfe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -424,16 +424,34 @@ trait Row extends Serializable { true } - override def hashCode: Int = { - // Using Scala's Seq hash code implementation. - var n = 0 - var h = MurmurHash3.seqSeed - val len = length - while (n < len) { - h = MurmurHash3.mix(h, apply(n).##) - n += 1 - } - MurmurHash3.finalizeHash(h, n) + override def hashCode: Int = hashCode(this) + + def hashCode(v: Any): Int = v match { + case null => 0 + case b: Boolean => if (b) 1 else 0 + case b: Byte => b.toInt + case s: Short => s.toInt + case i: Int => i + case l: Long => (l ^ (l >>> 32)).toInt + case f: Float => java.lang.Float.floatToIntBits(f) + case d: Double => + val b = java.lang.Double.doubleToLongBits(d) + (b ^ (b >>> 32)).toInt + case a: Array[Byte] => java.util.Arrays.hashCode(a) + case s: String => s.getBytes.foldLeft(0) { (acc, n) => acc * 31 + n } + case a: Array[_] => a.foldLeft(0) { (acc, n) => acc * 31 + hashCode(n) } + case s: Seq[_] => s.foldLeft(0) { (acc, n) => acc * 31 + hashCode(n) } + case m: Map[_, _] => + var r = 0 + m.foreach { case (k, v) => r += hashCode(k) ^ hashCode(v) } + r + case r: Row => + var res = 0 + for (i <- 0 until r.length) { + res = 31 * res + hashCode(r(i)) + } + res + case other => other.hashCode() } /* ---------------------- utility methods for Scala ---------------------- */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 4cda4824acdc3..fa553e7c5324c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -48,7 +48,7 @@ object RowEncoder { private def extractorsFor( inputObject: Expression, inputType: DataType): Expression = inputType match { - case BooleanType | ByteType | ShortType | IntegerType | LongType | + case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | BinaryType => inputObject case udt: UserDefinedType[_] => @@ -143,6 +143,7 @@ object RowEncoder { case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]]) case _: StructType => ObjectType(classOf[Row]) case udt: UserDefinedType[_] => ObjectType(udt.userClass) + case _: NullType => ObjectType(classOf[java.lang.Object]) } private def constructorFor(schema: StructType): Expression = { @@ -158,7 +159,7 @@ object RowEncoder { } private def constructorFor(input: Expression): Expression = input.dataType match { - case BooleanType | ByteType | ShortType | IntegerType | LongType | + case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | BinaryType => input case udt: UserDefinedType[_] => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index f229f2000d8e1..87b34c608498f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -129,6 +129,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { case DoubleType => s"(int)(Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32))" case BinaryType => s"java.util.Arrays.hashCode($col)" + case StringType => s"$col.toString().hashCode()" case _ => s"$col.hashCode()" } s"isNullAt($i) ? 0 : ($nonNull)" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 22c8d19cd78bd..2a07bf42e9c2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -22,8 +22,10 @@ import java.util.zip.CRC32 import org.apache.commons.codec.digest.DigestUtils +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.Row import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -54,17 +56,28 @@ case class Hash(children: Expression*) extends Expression with ImplicitCastInput override def nullable: Boolean = children.exists(_.nullable) @transient - private lazy val extractProjection = GenerateUnsafeProjection.generate(children) + private lazy val extractProjection = GenerateSafeProjection.generate(children) - def getProjection: UnsafeProjection = extractProjection + lazy val schema = StructType(children.zipWithIndex.map { case (e, idx) => + StructField(s"_c$idx", e.dataType) + }) + @transient + private lazy val encoder = RowEncoder(schema) + + def getProjection: Projection = extractProjection + def getEncoder: ExpressionEncoder[Row] = encoder override def eval(input: InternalRow): Any = { - extractProjection(input).hashCode + val internalRow: InternalRow = extractProjection(input) + encoder.fromRow(internalRow).hashCode } override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val hashExpressionClassName = classOf[Hash].getName - val projectionClassName = classOf[UnsafeProjection].getName + val projectionClassName = classOf[Projection].getName + val encoderClassName = classOf[ExpressionEncoder[Row]].getName + val internalRowClassName = classOf[InternalRow].getName + val rowClassName = classOf[Row].getName ctx.references += this val hashExpressionTermIndex = ctx.references.size - 1 @@ -74,9 +87,20 @@ case class Hash(children: Expression*) extends Expression with ImplicitCastInput s"this.$projectionTerm = ($projectionClassName)((($hashExpressionClassName)expressions" + s"[$hashExpressionTermIndex]).getProjection());") + val encoderTerm = ctx.freshName("encoder") + ctx.addMutableState(encoderClassName, encoderTerm, + s"this.$encoderTerm = ($encoderClassName)((($hashExpressionClassName)expressions" + + s"[$hashExpressionTermIndex]).getEncoder());") + + val internalRowTerm = ctx.freshName("internalRow") + val rowTerm = ctx.freshName("row") + s""" boolean ${ev.isNull} = false; - Integer ${ev.value} = $projectionTerm.apply(${ctx.INPUT_ROW}).hashCode(); + ${internalRowClassName} ${internalRowTerm} = + (${internalRowClassName})${projectionTerm}.apply(${ctx.INPUT_ROW}); + ${rowClassName} ${rowTerm} = (${rowClassName})${encoderTerm}.fromRow(${internalRowTerm}); + Integer ${ev.value} = ${rowTerm}.hashCode(); if (${ev.value} == null) { ${ev.isNull} = true; } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index cfc68fc00bea8..ec94bb106af6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -147,6 +147,7 @@ trait BaseGenericInternalRow extends InternalRow { val b = java.lang.Double.doubleToLongBits(d) (b ^ (b >>> 32)).toInt case a: Array[Byte] => java.util.Arrays.hashCode(a) + case s: UTF8String => s.toString().hashCode() case other => other.hashCode() } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala index dfd544e2bea20..c764a39ae66b9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala @@ -22,6 +22,7 @@ import org.apache.commons.codec.digest.DigestUtils import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.Row import org.apache.spark.sql.types._ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -35,35 +36,24 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("hash") { - def projection(exprs: Expression*): UnsafeProjection = - GenerateUnsafeProjection.generate(exprs) - def getHashCode(inputs: Expression*): Int = projection(inputs: _*)(null).hashCode - - checkEvaluation(Hash(Literal(3)), getHashCode(Literal(3))) - checkEvaluation(Hash(Literal(3L)), getHashCode(Literal(3L))) - checkEvaluation(Hash(Literal(3.7d)), getHashCode(Literal(3.7d))) - checkEvaluation(Hash(Literal(3.7f)), getHashCode(Literal(3.7f))) + checkEvaluation(Hash(Literal.create(null, NullType)), 0) + checkEvaluation(Hash(Literal(3)), 3) + checkEvaluation(Hash(Literal(3L)), 3) + checkEvaluation(Hash(Literal(3.7d)), -644612093) + checkEvaluation(Hash(Literal(3.7f)), 1080872141) val v1: Byte = 3 val v2: Short = 3 - checkEvaluation(Hash(Literal(v1)), getHashCode(Literal(v1))) - checkEvaluation(Hash(Literal(v2)), getHashCode(Literal(v2))) - checkEvaluation(Hash(Literal(v1), Literal(v2)), getHashCode(Literal(v1), Literal(v2))) - checkEvaluation(Hash(Literal("ABC")), getHashCode(Literal("ABC"))) - checkEvaluation(Hash(Literal(true)), getHashCode(Literal(true))) - checkEvaluation(Hash(Literal.create(Decimal(3.7), DecimalType.Unlimited)), - getHashCode(Literal.create(Decimal(3.7), DecimalType.Unlimited))) - checkEvaluation(Hash(Literal.create(java.sql.Date.valueOf("1991-12-07"), DateType)), - getHashCode(Literal.create(java.sql.Date.valueOf("1991-12-07"), DateType))) - checkEvaluation( - Hash(Literal.create(java.sql.Timestamp.valueOf("1991-12-07 12:00:00"), TimestampType)), - getHashCode(Literal.create(java.sql.Timestamp.valueOf("1991-12-07 12:00:00"), TimestampType))) + checkEvaluation(Hash(Literal(v1)), 3) + checkEvaluation(Hash(Literal(v2)), 3) + checkEvaluation(Hash(Literal(v1), Literal(v2)), 96) + checkEvaluation(Hash(Literal("ABC")), 64578) + checkEvaluation(Hash(Literal(true)), 1) checkEvaluation(Hash(Literal.create(Map[Int, Int](1 -> 2), MapType(IntegerType, IntegerType))), - getHashCode(Literal.create(Map[Int, Int](1 -> 2), MapType(IntegerType, IntegerType)))) + 3) checkEvaluation(Hash(Literal.create(Seq[Byte](1, 2, 3, 4, 5, 6), ArrayType(ByteType))), - getHashCode(Literal.create(Seq[Byte](1, 2, 3, 4, 5, 6), ArrayType(ByteType)))) + 30569571) checkEvaluation(Hash(Literal.create(Seq[Double](1.1, 2.2, 3.3, 4.4, 5.5, 6.6), - ArrayType(DoubleType))), getHashCode(Literal.create(Seq[Double](1.1, 2.2, 3.3, 4.4, 5.5, 6.6), - ArrayType(DoubleType)))) + ArrayType(DoubleType))), 540728227) } test("sha1") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 63c060f3751fb..fae8797abc270 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -169,18 +168,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } test("misc hash function") { - def projection(exprs: Expression*): UnsafeProjection = - GenerateUnsafeProjection.generate(exprs) - def getHashCode(inputs: Expression*): Int = projection(inputs: _*)(null).hashCode - val df = Seq(("ABC", 3.7f)).toDF("a", "b") checkAnswer( df.select(hash($"a"), hash($"b")), - Row(getHashCode(Literal("ABC")), getHashCode(Literal(3.7f)))) + Row(64578, 1080872141)) checkAnswer( df.selectExpr("hash(a)", "hash(b)"), - Row(getHashCode(Literal("ABC")), getHashCode(Literal(3.7f)))) + Row(64578, 1080872141)) } test("misc sha1 function") { From d9d28a36a7e6da35f456a47cb35adc7720afd51a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 23 Nov 2015 16:48:46 +0800 Subject: [PATCH 3/4] Fix python test. --- python/pyspark/sql/group.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 227f40bc3cf53..fbb96f464886d 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -74,11 +74,11 @@ def agg(self, *exprs): or a list of :class:`Column`. >>> gdf = df.groupBy(df.name) - >>> gdf.agg({"*": "count"}).collect() + >>> sorted(gdf.agg({"*": "count"}).collect(), key = lambda r: r.name) [Row(name=u'Alice', count(1)=1), Row(name=u'Bob', count(1)=1)] >>> from pyspark.sql import functions as F - >>> gdf.agg(F.min(df.age)).collect() + >>> sorted(gdf.agg(F.min(df.age)).collect(), key = lambda r: r.name) [Row(name=u'Alice', min(age)=2), Row(name=u'Bob', min(age)=5)] """ assert exprs, "exprs should not be empty" From e8d4b108c6d829841db045c6a626f1eb1bd80ee7 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 23 Nov 2015 18:52:25 +0800 Subject: [PATCH 4/4] No need to modify InternalRow's hashCode. --- python/pyspark/sql/group.py | 4 ++-- .../sql/catalyst/expressions/codegen/GenerateProjection.scala | 1 - .../org/apache/spark/sql/catalyst/expressions/rows.scala | 1 - 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index fbb96f464886d..227f40bc3cf53 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -74,11 +74,11 @@ def agg(self, *exprs): or a list of :class:`Column`. >>> gdf = df.groupBy(df.name) - >>> sorted(gdf.agg({"*": "count"}).collect(), key = lambda r: r.name) + >>> gdf.agg({"*": "count"}).collect() [Row(name=u'Alice', count(1)=1), Row(name=u'Bob', count(1)=1)] >>> from pyspark.sql import functions as F - >>> sorted(gdf.agg(F.min(df.age)).collect(), key = lambda r: r.name) + >>> gdf.agg(F.min(df.age)).collect() [Row(name=u'Alice', min(age)=2), Row(name=u'Bob', min(age)=5)] """ assert exprs, "exprs should not be empty" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 87b34c608498f..f229f2000d8e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -129,7 +129,6 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { case DoubleType => s"(int)(Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32))" case BinaryType => s"java.util.Arrays.hashCode($col)" - case StringType => s"$col.toString().hashCode()" case _ => s"$col.hashCode()" } s"isNullAt($i) ? 0 : ($nonNull)" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index ec94bb106af6d..cfc68fc00bea8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -147,7 +147,6 @@ trait BaseGenericInternalRow extends InternalRow { val b = java.lang.Double.doubleToLongBits(d) (b ^ (b >>> 32)).toInt case a: Array[Byte] => java.util.Arrays.hashCode(a) - case s: UTF8String => s.toString().hashCode() case other => other.hashCode() } }