From 186155c1579a2ba53017177be7f1d5bb5e5137bb Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Wed, 25 Jan 2017 19:06:02 +0900 Subject: [PATCH] Fix a bug that wrongly used mllib vectors in some functions --- .../apache/spark/sql/hive/HivemallOps.scala | 6 +-- .../apache/spark/sql/hive/HivemallUtils.scala | 37 +++++++++---------- .../apache/spark/sql/hive/HiveUdfSuite.scala | 22 ++++------- 3 files changed, 29 insertions(+), 36 deletions(-) diff --git a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala b/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala index 9bde84f8f..f233a2a1e 100644 --- a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala +++ b/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala @@ -25,7 +25,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.HivemallFeature -import org.apache.spark.ml.linalg.{DenseVector => SDV, SparseVector => SSV, VectorUDT} +import org.apache.spark.ml.linalg.{DenseVector, SparseVector, VectorUDT} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute @@ -764,12 +764,12 @@ final class HivemallOps(df: DataFrame) extends Logging { StructField("feature", StringType) :: StructField("weight", DoubleType) :: Nil) val explodeFunc: Row => TraversableOnce[InternalRow] = (row: Row) => { row.get(0) match { - case dv: SDV => + case dv: DenseVector => dv.values.zipWithIndex.map { case (value, index) => InternalRow(UTF8String.fromString(s"$index"), value) } - case sv: SSV => + case sv: SparseVector => sv.values.zip(sv.indices).map { case (value, index) => InternalRow(UTF8String.fromString(s"$index"), value) diff --git a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallUtils.scala
b/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallUtils.scala index 692434773..b7b707142 100644 --- a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallUtils.scala +++ b/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallUtils.scala @@ -18,8 +18,7 @@ */ package org.apache.spark.sql.hive -import org.apache.spark.ml.linalg.{DenseVector => SDV, SparseVector => SSV, Vector => SV} -import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} +import org.apache.spark.ml.linalg.{BLAS, DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.sql.{Column, DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.expressions.UserDefinedFunction @@ -46,17 +45,12 @@ object HivemallUtils { @inline implicit def toStringArrayLiteral(i: Seq[String]): Column = Column(Literal.create(i, ArrayType(StringType))) - /** - * Transforms `org.apache.spark.ml.linalg.Vector` into Hivemall features. - */ - def to_hivemall_features: UserDefinedFunction = udf(_to_hivemall_features) - - private[hive] def _to_hivemall_features = (v: SV) => v match { - case dv: SDV => + def to_hivemall_features_func(): Vector => Array[String] = { + case dv: DenseVector => dv.values.zipWithIndex.map { case (value, index) => s"$index:$value" } - case sv: SSV => + case sv: SparseVector => sv.values.zip(sv.indices).map { case (value, index) => s"$index:$value" } @@ -64,21 +58,15 @@ object HivemallUtils { throw new IllegalArgumentException(s"Do not support vector type ${v.getClass}") } - /** - * Returns a new vector with `1.0` (bias) appended to the input vector. 
- * @group ftvec - */ - def append_bias: UserDefinedFunction = udf(_append_bias) - - private[hive] def _append_bias = (v: SV) => v match { - case dv: SDV => + def append_bias_func(): Vector => Vector = { + case dv: DenseVector => val inputValues = dv.values val inputLength = inputValues.length val outputValues = Array.ofDim[Double](inputLength + 1) System.arraycopy(inputValues, 0, outputValues, 0, inputLength) outputValues(inputLength) = 1.0 Vectors.dense(outputValues) - case sv: SSV => + case sv: SparseVector => val inputValues = sv.values val inputIndices = sv.indices val inputValuesLength = inputValues.length @@ -94,6 +82,17 @@ object HivemallUtils { throw new IllegalArgumentException(s"Do not support vector type ${v.getClass}") } + /** + * Transforms `org.apache.spark.ml.linalg.Vector` into Hivemall features. + */ + def to_hivemall_features: UserDefinedFunction = udf(to_hivemall_features_func) + + /** + * Returns a new vector with `1.0` (bias) appended to the input vector. + * @group ftvec + */ + def append_bias: UserDefinedFunction = udf(append_bias_func) + /** * Make up a function object from a Hivemall model. 
*/ diff --git a/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HiveUdfSuite.scala b/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HiveUdfSuite.scala index f8622c6b2..d53ef7352 100644 --- a/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HiveUdfSuite.scala +++ b/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HiveUdfSuite.scala @@ -118,7 +118,7 @@ final class HiveUdfWithVectorSuite extends VectorQueryTest { test("to_hivemall_features") { mllibTrainDf.createOrReplaceTempView("mllibTrainDf") - hiveContext.udf.register("to_hivemall_features", _to_hivemall_features) + hiveContext.udf.register("to_hivemall_features", to_hivemall_features_func) checkAnswer( sql( s""" @@ -134,16 +134,10 @@ final class HiveUdfWithVectorSuite extends VectorQueryTest { ) } - ignore("append_bias") { + test("append_bias") { mllibTrainDf.createOrReplaceTempView("mllibTrainDf") - hiveContext.udf.register("append_bias", _append_bias) - hiveContext.udf.register("to_hivemall_features", _to_hivemall_features) - /** - * TODO: This test throws an exception: - * Failed to analyze query: org.apache.spark.sql.AnalysisException: cannot resolve - * 'UDF(UDF(features))' due to data type mismatch: argument 1 requires vector type, - * however, 'UDF(features)' is of vector type.; line 2 pos 8 - */ + hiveContext.udf.register("append_bias", append_bias_func) + hiveContext.udf.register("to_hivemall_features", to_hivemall_features_func) checkAnswer( sql( s""" @@ -151,10 +145,10 @@ final class HiveUdfWithVectorSuite extends VectorQueryTest { | FROM mllibTrainDF """.stripMargin), Seq( - Row(Seq("0:1.0", "0:1.0", "2:2.0", "4:3.0")), - Row(Seq("0:1.0", "0:1.0", "3:1.5", "4:2.1", "6:1.2")), - Row(Seq("0:1.0", "0:1.1", "3:1.0", "4:2.3", "6:1.0")), - Row(Seq("0:1.0", "1:4.0", "3:5.0", "5:6.0")) + Row(Seq("0:1.0", "2:2.0", "4:3.0", "7:1.0")), + Row(Seq("0:1.0", "3:1.5", "4:2.1", "6:1.2", "7:1.0")), + Row(Seq("0:1.1", "3:1.0", "4:2.3", "6:1.0", "7:1.0")), + Row(Seq("1:4.0", "3:5.0", 
"5:6.0", "7:1.0")) ) ) }