Skip to content
This repository has been archived by the owner on Sep 20, 2022. It is now read-only.

Commit

Permalink
Close #25: [HIVEMALL-34] Fix a bug to wrongly use mllib vectors in some functions
Browse files Browse the repository at this point in the history
  • Loading branch information
maropu committed Jan 26, 2017
1 parent 4688494 commit b909996
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
import org.apache.spark.internal.Logging
import org.apache.spark.ml.feature.HivemallFeature
import org.apache.spark.ml.linalg.{DenseVector => SDV, SparseVector => SSV, VectorUDT}
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, VectorUDT}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
Expand Down Expand Up @@ -764,12 +764,12 @@ final class HivemallOps(df: DataFrame) extends Logging {
StructField("feature", StringType) :: StructField("weight", DoubleType) :: Nil)
val explodeFunc: Row => TraversableOnce[InternalRow] = (row: Row) => {
row.get(0) match {
case dv: SDV =>
case dv: DenseVector =>
dv.values.zipWithIndex.map {
case (value, index) =>
InternalRow(UTF8String.fromString(s"$index"), value)
}
case sv: SSV =>
case sv: SparseVector =>
sv.values.zip(sv.indices).map {
case (value, index) =>
InternalRow(UTF8String.fromString(s"$index"), value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@
*/
package org.apache.spark.sql.hive

import org.apache.spark.ml.linalg.{DenseVector => SDV, SparseVector => SSV, Vector => SV}
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
import org.apache.spark.ml.linalg.{BLAS, DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.sql.{Column, DataFrame, Row}
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.expressions.UserDefinedFunction
Expand All @@ -46,39 +45,28 @@ object HivemallUtils {
@inline implicit def toStringArrayLiteral(i: Seq[String]): Column =
Column(Literal.create(i, ArrayType(StringType)))

/**
* Transforms `org.apache.spark.ml.linalg.Vector` into Hivemall features.
*/
def to_hivemall_features: UserDefinedFunction = udf(_to_hivemall_features)

private[hive] def _to_hivemall_features = (v: SV) => v match {
case dv: SDV =>
def to_hivemall_features_func(): Vector => Array[String] = {
case dv: DenseVector =>
dv.values.zipWithIndex.map {
case (value, index) => s"$index:$value"
}
case sv: SSV =>
case sv: SparseVector =>
sv.values.zip(sv.indices).map {
case (value, index) => s"$index:$value"
}
case v =>
throw new IllegalArgumentException(s"Do not support vector type ${v.getClass}")
}

/**
* Returns a new vector with `1.0` (bias) appended to the input vector.
* @group ftvec
*/
def append_bias: UserDefinedFunction = udf(_append_bias)

private[hive] def _append_bias = (v: SV) => v match {
case dv: SDV =>
def append_bias_func(): Vector => Vector = {
case dv: DenseVector =>
val inputValues = dv.values
val inputLength = inputValues.length
val outputValues = Array.ofDim[Double](inputLength + 1)
System.arraycopy(inputValues, 0, outputValues, 0, inputLength)
outputValues(inputLength) = 1.0
Vectors.dense(outputValues)
case sv: SSV =>
case sv: SparseVector =>
val inputValues = sv.values
val inputIndices = sv.indices
val inputValuesLength = inputValues.length
Expand All @@ -94,6 +82,17 @@ object HivemallUtils {
throw new IllegalArgumentException(s"Do not support vector type ${v.getClass}")
}

/**
* Transforms `org.apache.spark.ml.linalg.Vector` into Hivemall features.
*/
def to_hivemall_features: UserDefinedFunction = udf(to_hivemall_features_func)

/**
* Returns a new vector with `1.0` (bias) appended to the input vector.
* @group ftvec
*/
def append_bias: UserDefinedFunction = udf(append_bias_func)

/**
* Make up a function object from a Hivemall model.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ final class HiveUdfWithVectorSuite extends VectorQueryTest {

test("to_hivemall_features") {
mllibTrainDf.createOrReplaceTempView("mllibTrainDf")
hiveContext.udf.register("to_hivemall_features", _to_hivemall_features)
hiveContext.udf.register("to_hivemall_features", to_hivemall_features_func)
checkAnswer(
sql(
s"""
Expand All @@ -134,27 +134,21 @@ final class HiveUdfWithVectorSuite extends VectorQueryTest {
)
}

ignore("append_bias") {
test("append_bias") {
mllibTrainDf.createOrReplaceTempView("mllibTrainDf")
hiveContext.udf.register("append_bias", _append_bias)
hiveContext.udf.register("to_hivemall_features", _to_hivemall_features)
/**
* TODO: This test throws an exception:
* Failed to analyze query: org.apache.spark.sql.AnalysisException: cannot resolve
* 'UDF(UDF(features))' due to data type mismatch: argument 1 requires vector type,
* however, 'UDF(features)' is of vector type.; line 2 pos 8
*/
hiveContext.udf.register("append_bias", append_bias_func)
hiveContext.udf.register("to_hivemall_features", to_hivemall_features_func)
checkAnswer(
sql(
s"""
| SELECT to_hivemall_features(append_bias(features))
| FROM mllibTrainDF
""".stripMargin),
Seq(
Row(Seq("0:1.0", "0:1.0", "2:2.0", "4:3.0")),
Row(Seq("0:1.0", "0:1.0", "3:1.5", "4:2.1", "6:1.2")),
Row(Seq("0:1.0", "0:1.1", "3:1.0", "4:2.3", "6:1.0")),
Row(Seq("0:1.0", "1:4.0", "3:5.0", "5:6.0"))
Row(Seq("0:1.0", "2:2.0", "4:3.0", "7:1.0")),
Row(Seq("0:1.0", "3:1.5", "4:2.1", "6:1.2", "7:1.0")),
Row(Seq("0:1.1", "3:1.0", "4:2.3", "6:1.0", "7:1.0")),
Row(Seq("1:4.0", "3:5.0", "5:6.0", "7:1.0"))
)
)
}
Expand Down

0 comments on commit b909996

Please sign in to comment.