From 6a25da2cc96a744aaf047280ac414e5ff4515434 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?=
Date: Sat, 19 Aug 2017 10:24:14 +0800
Subject: [PATCH] ENH: implement HashingTF in ml

---
 .../apache/spark/ml/feature/HashingTF.scala   | 24 +++++++++++++++----
 1 file changed, 19 insertions(+), 5 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index db432b6fefaff..795bc0fc6f5b4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -20,13 +20,15 @@ package org.apache.spark.ml.feature
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.linalg.Vectors
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
 import org.apache.spark.ml.util._
-import org.apache.spark.mllib.feature
+import org.apache.spark.mllib.feature.HashingTF.murmur3Hash
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions.{col, udf}
 import org.apache.spark.sql.types.{ArrayType, StructType}
+import org.apache.spark.util.Utils
 
 /**
  * Maps a sequence of terms to their term frequencies using the hashing trick.
@@ -40,6 +42,8 @@ import org.apache.spark.sql.types.{ArrayType, StructType}
 class HashingTF @Since("1.4.0") (@Since("1.4.0") override val uid: String)
   extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable {
 
+  private[this] val hashFunc: Any => Int = murmur3Hash
+
   @Since("1.2.0")
   def this() = this(Identifiable.randomUID("hashingTF"))
 
@@ -93,11 +97,21 @@ class HashingTF @Since("1.4.0") (@Since("1.4.0") override val uid: String)
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     val outputSchema = transformSchema(dataset.schema)
-    val hashingTF = new feature.HashingTF($(numFeatures)).setBinary($(binary))
-    // TODO: Make the hashingTF.transform natively in ml framework to avoid extra conversion.
-    val t = udf { terms: Seq[_] => hashingTF.transform(terms).asML }
+    val hashUDF = udf { (terms: Seq[_]) =>
+      val ids = terms.map { term =>
+        Utils.nonNegativeMod(hashFunc(term), $(numFeatures))
+      }
+
+      val termFrequencies: Seq[(Int, Double)] = if ($(binary)) {
+        ids.distinct.map(x => x -> 1.0)
+      } else {
+        ids.groupBy(identity).mapValues(_.size.toDouble).toSeq
+      }
+
+      Vectors.sparse($(numFeatures), termFrequencies)
+    }
     val metadata = outputSchema($(outputCol)).metadata
-    dataset.select(col("*"), t(col($(inputCol))).as($(outputCol), metadata))
+    dataset.select(col("*"), hashUDF(col($(inputCol))).as($(outputCol), metadata))
   }
 
   @Since("1.4.0")
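
Note (not part of the patch): the new hashUDF computes term frequencies directly in ml, which is why the old TODO about avoiding the mllib-to-ml Vector conversion can be dropped. A minimal standalone sketch of the same hashing-trick computation is below, for illustration only; it uses scala.util.hashing.MurmurHash3 and a local nonNegativeMod as stand-ins for Spark's private murmur3Hash and Utils.nonNegativeMod, so the bucket indices it produces will not match Spark's output exactly.

// Standalone sketch of the term-frequency logic in the new udf (illustration only).
import scala.util.hashing.MurmurHash3

object HashingTFSketch {
  // Map a possibly negative hash into [0, mod), mirroring what Utils.nonNegativeMod does.
  def nonNegativeMod(x: Int, mod: Int): Int = {
    val raw = x % mod
    raw + (if (raw < 0) mod else 0)
  }

  // Hash each term into one of numFeatures buckets, then either count
  // occurrences (term frequency) or mark presence (binary = true).
  def transform(terms: Seq[String], numFeatures: Int, binary: Boolean): Map[Int, Double] = {
    val ids = terms.map(t => nonNegativeMod(MurmurHash3.stringHash(t), numFeatures))
    if (binary) ids.distinct.map(_ -> 1.0).toMap
    else ids.groupBy(identity).map { case (id, occ) => id -> occ.size.toDouble }
  }

  def main(args: Array[String]): Unit = {
    val tf = transform(Seq("spark", "hashing", "spark"), numFeatures = 16, binary = false)
    println(tf)  // e.g. Map(<bucket of "spark"> -> 2.0, <bucket of "hashing"> -> 1.0)
  }
}

In the patch itself the resulting (index, frequency) pairs are handed to Vectors.sparse($(numFeatures), termFrequencies), so the output column stays an ml SparseVector with no intermediate mllib Vector.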