Skip to content

Commit

Permalink
extends HashingTF
Browse files Browse the repository at this point in the history
  • Loading branch information
hhbyyh committed Jun 30, 2015
1 parent 809fb59 commit 7ee1c31
Showing 1 changed file with 12 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,24 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.feature

import scala.collection.mutable

import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg.{Vectors, VectorUDT, Vector}
import org.apache.spark.sql.types.{StringType, ArrayType, DataType}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._

/**
* :: Experimental ::
* Converts a text document to a sparse vector of token counts.
* @param vocabulary An Array over terms. Only the terms in the vocabulary will be counted.
*/
@Experimental
class CountVectorizer (override val uid: String, vocabulary: Array[String])
extends UnaryTransformer[Seq[String], Vector, CountVectorizer] {
class CountVectorizer (override val uid: String, vocabulary: Array[String]) extends HashingTF{

def this(vocabulary: Array[String]) = this(Identifiable.randomUID("countVectorizer"), vocabulary)

Expand All @@ -43,36 +41,31 @@ class CountVectorizer (override val uid: String, vocabulary: Array[String])
* @group param
*/
val minTermCounts: IntParam = new IntParam(this, "minTermCounts",
"lower bound of effective term counts (>= 0)", ParamValidators.gtEq(1))
"lower bound of effective term counts (>= 1)", ParamValidators.gtEq(1))

/** @group setParam */
def setMinTermCounts(value: Int): this.type = set(minTermCounts, value)

/** @group getParam */
def getMinTermCounts: Int = $(minTermCounts)

setDefault(minTermCounts -> 1)
setDefault(minTermCounts -> 1, numFeatures -> vocabulary.size)

override protected def createTransformFunc: Seq[String] => Vector = {
override def transform(dataset: DataFrame): DataFrame = {
val dict = vocabulary.zipWithIndex.toMap
document =>
val t = udf { terms: Seq[String] =>
val termCounts = mutable.HashMap.empty[Int, Double]
document.foreach { term =>
terms.foreach { term =>
val index = dict.getOrElse(term, -1)
if (index >= 0) {
termCounts.put(index, termCounts.getOrElse(index, 0.0) + 1.0)
}
}
Vectors.sparse(dict.size, termCounts.filter(_._2 >= $(minTermCounts)).toSeq)
}
dataset.withColumn($(outputCol), t(col($(inputCol))))
}

override protected def validateInputType(inputType: DataType): Unit = {
require(inputType.sameType(ArrayType(StringType)),
s"Input type must be Array type but got $inputType.")
}

override protected def outputDataType: DataType = new VectorUDT()

override def copy(extra: ParamMap): CountVectorizer = {
val copied = new CountVectorizer(uid, vocabulary)
copyValues(copied, extra)
Expand Down

0 comments on commit 7ee1c31

Please sign in to comment.