Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Spark-8703] [ML] Add CountVectorizer as a ml transformer to convert document to words count vector #7084

Closed
wants to merge 9 commits into from
Closed
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.feature

import scala.collection.mutable

import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.{ParamMap, ParamValidators, IntParam}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg.{Vectors, VectorUDT, Vector}
import org.apache.spark.sql.types.{StringType, ArrayType, DataType}

/**
* :: Experimental ::
* Converts a text document to a sparse vector of token counts.
* @param vocabulary An Array over terms. Only the terms in the vocabulary will be counted.
*/
@Experimental
class CountVectorizer (override val uid: String, vocabulary: Array[String])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we make vocabulary be a val? That will be good when we make an Estimator version to let users access the dictionary.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great point!

extends UnaryTransformer[Seq[String], Vector, CountVectorizer] {

def this(vocabulary: Array[String]) = this(Identifiable.randomUID("countVectorizer"), vocabulary)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is probably fine for now, but I had some thoughts about having an empty constructor for including every word encountered if no vocabulary is provided. If it requires significant modification, we should make a separate JIRA for it.


/**
* Corpus-specific stop words filter. Terms with count less than the given threshold are ignored.
* Default: 1
* @group param
*/
val minTermCounts: IntParam = new IntParam(this, "minTermCounts",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should call this "minTermFreq". I agree that "count" is more accurate than "frequency," but when I discussed this previously with others, we decided to go with "frequency" since that is the common term in the literature and other libraries. I say "Freq" instead of "Frequency" to match IDF's param name.

Also, using "term frequency" in the doc should clarify that this threshold operates per-document.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the careful consideration.

"lower bound of effective term counts (>= 1)", ParamValidators.gtEq(1))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer long built-in documentation. Please make this built-in doc have as much detail as the Scala doc.


/** @group setParam */
def setMinTermCounts(value: Int): this.type = set(minTermCounts, value)

/** @group getParam */
def getMinTermCounts: Int = $(minTermCounts)

setDefault(minTermCounts -> 1)

override protected def createTransformFunc: Seq[String] => Vector = {
val dict = vocabulary.zipWithIndex.toMap
document =>
val termCounts = mutable.HashMap.empty[Int, Double]
document.foreach { term =>
dict.get(term) match {
case Some(index) => termCounts.put(index, termCounts.getOrElse(index, 0.0) + 1.0)
case None => // ignore terms not in the vocabulary
}
}
Vectors.sparse(dict.size, termCounts.filter(_._2 >= $(minTermCounts)).toSeq)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps do the filter before initializing Vectors.sparse to avoid trailing zeros if the filter result size is < dict.size?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I get this. Do you mean we should alter the size of the sparse vector dynamically? The vectors after transformation should have the same size for many algorithms to work properly.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, the trailing zeros should be pretty cheap with a sparse Vector anyways so this seems fine to me.

}

override protected def validateInputType(inputType: DataType): Unit = {
require(inputType.sameType(ArrayType(StringType)),
s"Input type must be Array type but got $inputType.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please say precise type ArrayType(StringType), or better yet use SchemaUtils.checkColumnType.

}

override protected def outputDataType: DataType = new VectorUDT()

override def copy(extra: ParamMap): CountVectorizer = {
val copied = new CountVectorizer(uid, vocabulary)
copyValues(copied, extra)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._

class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext {

test("params") {
ParamsSuite.checkParams(new CountVectorizer(Array("empty")))
}

test("CountVectorizer common cases") {
val df = sqlContext.createDataFrame(Seq(
(0, "a b c d".split(" ").toSeq),
(1, "a b b c d a".split(" ").toSeq),
(2, "a".split(" ").toSeq),
(3, "".split(" ").toSeq), // empty string
(3, "a notInDict d".split(" ").toSeq) // with words not in vocabulary
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

duplicate id

)).toDF("id", "words")
val cv = new CountVectorizer(Array("a", "b", "c", "d"))
.setInputCol("words")
.setOutputCol("features")
val output = cv.transform(df)
val features = output.select("features").collect()

val expected = Seq(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be matched with the doc ID to make sure the order isn't switched around? An easier way would be to include the expected results in the original DataFrame.

same for other test

Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0))),
Vectors.sparse(4, Seq((0, 2.0), (1, 2.0), (2, 1.0), (3, 1.0))),
Vectors.sparse(4, Seq((0, 1.0))),
Vectors.sparse(4, Seq()),
Vectors.sparse(4, Seq((0, 1.0), (3, 1.0))))

features.zip(expected).foreach(p =>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

scala style: use braces {} instead of parentheses for multiline closures

same for other test

assert(p._1.getAs[Vector](0) ~== p._2 absTol 1e-14)
)
}

test("CountVectorizer with minTermCounts") {
val df = sqlContext.createDataFrame(Seq(
(0, "a a a b b c c c d ".split(" ").toSeq),
(1, "c c c c c c".split(" ").toSeq),
(2, "a".split(" ").toSeq),
(3, "e e e e e".split(" ").toSeq)
)).toDF("id", "words")
val cv = new CountVectorizer(Array("a", "b", "c", "d"))
.setInputCol("words")
.setOutputCol("features")
.setMinTermCounts(3)
val output = cv.transform(df)
val features = output.select("features").collect()

val expected = Seq(
Vectors.sparse(4, Seq((0, 3.0), (2, 3.0))),
Vectors.sparse(4, Seq((2, 6.0))),
Vectors.sparse(4, Seq()),
Vectors.sparse(4, Seq()))

features.zip(expected).foreach(p =>
assert(p._1.getAs[Vector](0) ~== p._2 absTol 1e-14)
)
}
}