From 3193c7791996873abb3ec02f1e46ae324515e049 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Mon, 13 Jul 2015 18:30:49 +0800 Subject: [PATCH 1/7] count vectorizer estimator --- .../ml/feature/CountVectorizerModel.scala | 149 ++++++++++++++---- 1 file changed, 120 insertions(+), 29 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala index 6b77de89a0330..55a64162b17f0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala @@ -19,45 +19,136 @@ package org.apache.spark.ml.feature import scala.collection.mutable import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.UnaryTransformer -import org.apache.spark.ml.param.{ParamMap, ParamValidators, IntParam} -import org.apache.spark.ml.util.Identifiable -import org.apache.spark.mllib.linalg.{Vectors, VectorUDT, Vector} -import org.apache.spark.sql.types.{StringType, ArrayType, DataType} +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ +import org.apache.spark.sql.DataFrame + +/** + * Params for [[CountVectorizer]] and [[CountVectorizerModel]]. + */ +private[feature] trait CountVectorizerParams extends Params with HasInputCol with HasOutputCol { + + /** + * size of the vocabulary. + * If using Estimator, CountVectorizer will build a vocabulary that only consider the top + * vocabSize terms ordered by term frequency across the corpus. + * Default: 10000 + * @group param + */ + val vocabSize: IntParam = new IntParam(this, "vocabSize", "size of the vocabulary") + + /** @group getParam */ + def getVocabSize: Int = $(vocabSize) + + /** Validates and transforms the input schema. */ + protected def validateAndTransformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true)) + SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) + } + + override def validateParams(): Unit = { + require($(vocabSize) > 0, s"The vocabulary size (${$(vocabSize)}) must be above 0.") + } +} /** * :: Experimental :: - * Converts a text document to a sparse vector of token counts. - * @param vocabulary An Array over terms. Only the terms in the vocabulary will be counted. + * Extracts a vocabulary from document collections and generates a CountVectorizerModel. */ -@Experimental -class CountVectorizerModel (override val uid: String, val vocabulary: Array[String]) - extends UnaryTransformer[Seq[String], Vector, CountVectorizerModel] { +class CountVectorizer(override val uid: String) + extends Estimator[CountVectorizerModel] with CountVectorizerParams { - def this(vocabulary: Array[String]) = - this(Identifiable.randomUID("cntVec"), vocabulary) + def this() = this(Identifiable.randomUID("cntVec")) /** - * Corpus-specific filter to ignore scarce words in a document. For each document, terms with - * frequency (count) less than the given threshold are ignored. + * The minimum number of times a token must appear in the corpus to be included in the vocabulary. 
* Default: 1 * @group param */ - val minTermFreq: IntParam = new IntParam(this, "minTermFreq", - "minimum frequency (count) filter used to neglect scarce words (>= 1). For each document, " + - "terms with frequency less than the given threshold are ignored.", ParamValidators.gtEq(1)) + val minCount: IntParam = new IntParam(this, "minCount", + "minimum number of times a token must appear in the corpus to be included in the vocabulary." + , ParamValidators.gtEq(1)) + + /** @group getParam */ + def getMinCount: Int = $(minCount) /** @group setParam */ - def setMinTermFreq(value: Int): this.type = set(minTermFreq, value) + def setInputCol(value: String): this.type = set(inputCol, value) - /** @group getParam */ - def getMinTermFreq: Int = $(minTermFreq) + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** @group setParam */ + def setVocabSize(value: Int): this.type = set(vocabSize, value) + + /** @group setParam */ + def setMinCount(value: Int): this.type = set(minCount, value) - setDefault(minTermFreq -> 1) + setDefault(vocabSize -> 10000, minCount -> 1) - override protected def createTransformFunc: Seq[String] => Vector = { + override def fit(dataset: DataFrame): CountVectorizerModel = { + transformSchema(dataset.schema, logging = true) + val input = dataset.select($(inputCol)).map(_.getAs[Seq[String]](0)) + val min_count = $(minCount) + val vocab_size = $(vocabSize) + val wordCounts: RDD[(String, Long)] = input + .flatMap { case (tokens) => tokens.map(_ -> 1L) } + .reduceByKey(_ + _) + .filter(_._2 >= min_count) + wordCounts.cache() + val fullVocabSize = wordCounts.count() + val vocab: Array[String] = { + val tmpSortedWC: Array[(String, Long)] = if (fullVocabSize <= vocab_size) { + // Use all terms + wordCounts.collect().sortBy(_._2) + } else { + // Sort terms to select vocab + wordCounts.sortBy(_._2, ascending = false).take(vocab_size) + } + tmpSortedWC.map(_._1) + } + + require(vocab.length > 0, "The vocabulary size should be > 0. Adjust minCount as necessary.") + copyValues(new CountVectorizerModel(uid, vocab).setParent(this)) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + override def copy(extra: ParamMap): CountVectorizer = defaultCopy(extra) +} + +/** + * :: Experimental :: + * Converts a text document to a sparse vector of token counts. + * @param vocabulary An Array over terms. Only the terms in the vocabulary will be counted. 
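+ * (Illustration added by the editor, not part of the original patch: a model built as
+ * `new CountVectorizerModel(Array("a", "b", "c"))` maps "a", "b" and "c" to indices
+ * 0, 1 and 2 of the output vector, and silently ignores every other token.)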
+ */ +@Experimental +class CountVectorizerModel(override val uid: String, val vocabulary: Array[String]) + extends Model[CountVectorizerModel] with CountVectorizerParams { + + def this(vocabulary: Array[String]) = { + this(Identifiable.randomUID("cntVecModel"), vocabulary) + set(vocabSize, vocabulary.length) + } + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def transform(dataset: DataFrame): DataFrame = { val dict = vocabulary.zipWithIndex.toMap - document => + + val vectorizer = udf { (document: Seq[String]) => val termCounts = mutable.HashMap.empty[Int, Double] document.foreach { term => dict.get(term) match { @@ -65,15 +156,15 @@ class CountVectorizerModel (override val uid: String, val vocabulary: Array[Stri case None => // ignore terms not in the vocabulary } } - Vectors.sparse(dict.size, termCounts.filter(_._2 >= $(minTermFreq)).toSeq) - } + Vectors.sparse(dict.size, termCounts.toSeq) + } - override protected def validateInputType(inputType: DataType): Unit = { - require(inputType.sameType(ArrayType(StringType)), - s"Input type must be ArrayType(StringType) but got $inputType.") + dataset.withColumn($(outputCol), vectorizer(col($(inputCol)))) } - override protected def outputDataType: DataType = new VectorUDT() + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } override def copy(extra: ParamMap): CountVectorizerModel = { val copied = new CountVectorizerModel(uid, vocabulary) From 93e1ad494b837a4b82b95ac1001e1b1adafe4df1 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Tue, 14 Jul 2015 11:13:32 +0800 Subject: [PATCH 2/7] add more ut for estimator --- .../ml/feature/CountVectorizerModel.scala | 6 +- .../ml/feature/CountVectorizorSuite.scala | 56 +++++++++++++++---- 2 files changed, 46 insertions(+), 16 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala index 55a64162b17f0..283cf2aafa6b3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala @@ -67,7 +67,7 @@ class CountVectorizer(override val uid: String) def this() = this(Identifiable.randomUID("cntVec")) /** - * The minimum number of times a token must appear in the corpus to be included in the vocabulary. 
+ * The minimum number of times a token must appear in the corpus to be included in the vocabulary * Default: 1 * @group param */ @@ -106,7 +106,7 @@ class CountVectorizer(override val uid: String) val vocab: Array[String] = { val tmpSortedWC: Array[(String, Long)] = if (fullVocabSize <= vocab_size) { // Use all terms - wordCounts.collect().sortBy(_._2) + wordCounts.collect().sortBy(-_._2) } else { // Sort terms to select vocab wordCounts.sortBy(_._2, ascending = false).take(vocab_size) @@ -147,7 +147,6 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin override def transform(dataset: DataFrame): DataFrame = { val dict = vocabulary.zipWithIndex.toMap - val vectorizer = udf { (document: Seq[String]) => val termCounts = mutable.HashMap.empty[Int, Double] document.foreach { term => @@ -158,7 +157,6 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin } Vectors.sparse(dict.size, termCounts.toSeq) } - dataset.withColumn($(outputCol), vectorizer(col($(inputCol)))) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala index e90d9d4ef21ff..9479279a03a12 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala @@ -30,13 +30,13 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext { test("CountVectorizerModel common cases") { val df = sqlContext.createDataFrame(Seq( - (0, "a b c d".split(" ").toSeq, + (0, "a b c d".split("\\s+").toSeq, Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))), - (1, "a b b c d a".split(" ").toSeq, + (1, "a b b c d a".split("\\s+").toSeq, Vectors.sparse(4, Seq((0, 2.0), (1, 2.0), (2, 1.0), (3, 1.0)))), - (2, "a".split(" ").toSeq, Vectors.sparse(4, Seq((0, 1.0)))), - (3, "".split(" ").toSeq, Vectors.sparse(4, Seq())), // empty string - (4, "a notInDict d".split(" ").toSeq, + (2, "a".split("\\s+").toSeq, Vectors.sparse(4, Seq((0, 1.0)))), + (3, "".split("\\s+").toSeq, Vectors.sparse(4, Seq())), // empty string + (4, "a notInDict d".split("\\s+").toSeq, Vectors.sparse(4, Seq((0, 1.0), (3, 1.0)))) // with words not in vocabulary )).toDF("id", "words", "expected") val cv = new CountVectorizerModel(Array("a", "b", "c", "d")) @@ -50,17 +50,20 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext { } } - test("CountVectorizerModel with minTermFreq") { + test("CountVectorizer common cases") { val df = sqlContext.createDataFrame(Seq( - (0, "a a a b b c c c d ".split(" ").toSeq, Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))), - (1, "c c c c c c".split(" ").toSeq, Vectors.sparse(4, Seq((2, 6.0)))), - (2, "a".split(" ").toSeq, Vectors.sparse(4, Seq())), - (3, "e e e e e".split(" ").toSeq, Vectors.sparse(4, Seq()))) + (0, "a b c d e".split("\\s+").toSeq, + Vectors.sparse(5, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0), (4, 1.0)))), + (1, "a a a a a a".split("\\s+").toSeq, Vectors.sparse(5, Seq((0, 6.0)))), + (2, "c".split("\\s+").toSeq, Vectors.sparse(5, Seq((2, 1.0)))), + (3, "b b b b b".split("\\s+").toSeq, Vectors.sparse(5, Seq((1, 5.0))))) ).toDF("id", "words", "expected") - val cv = new CountVectorizerModel(Array("a", "b", "c", "d")) + val cv = new CountVectorizer() .setInputCol("words") .setOutputCol("features") - .setMinTermFreq(3) + .fit(df) + assert(cv.vocabulary.deep == Array("a", "b", "c", "d", "e").deep) + val output = 
cv.transform(df).collect() output.foreach { p => val features = p.getAs[Vector]("features") @@ -68,6 +71,35 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext { assert(features ~== expected absTol 1e-14) } } + + test("CountVectorizer vocabSize and minCount") { + val df = sqlContext.createDataFrame(Seq( + (0, "a a a a a".split("\\s+").toSeq, Vectors.sparse(3, Seq((0, 5.0)))), + (1, "b b b b".split("\\s+").toSeq, Vectors.sparse(3, Seq((1, 4.0)))), + (2, "c c c".split("\\s+").toSeq, Vectors.sparse(3, Seq((2, 3.0)))), + (3, "d d".split("\\s+").toSeq, Vectors.sparse(3, Seq()))) + ).toDF("id", "words", "expected") + val cvModel = new CountVectorizer() + .setInputCol("words") + .setOutputCol("features") + .setVocabSize(3) // limit vocab size to 3 + .fit(df) + assert(cvModel.vocabulary.deep == Array("a", "b", "c").deep) + + val cvModel2 = new CountVectorizer() + .setInputCol("words") + .setOutputCol("features") + .setMinCount(3) // ignore terms with count less than 3 + .fit(df) + assert(cvModel2.vocabulary.deep == Array("a", "b", "c").deep) + + val output = cvModel2.transform(df).collect() + output.foreach { p => + val features = p.getAs[Vector]("features") + val expected = p.getAs[Vector]("expected") + assert(features ~== expected absTol 1e-14) + } + } } From 589e93d3267b8a3f9ef1f1aba3aec22463dddfa6 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Fri, 24 Jul 2015 22:27:48 -0400 Subject: [PATCH 3/7] minor fix --- .../spark/ml/feature/CountVectorizerModel.scala | 10 ++++------ ...izorSuite.scala => CountVectorizerSuite.scala} | 15 +++++++++++++++ 2 files changed, 19 insertions(+), 6 deletions(-) rename mllib/src/test/scala/org/apache/spark/ml/feature/{CountVectorizorSuite.scala => CountVectorizerSuite.scala} (89%) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala index 283cf2aafa6b3..bc79cfac983a7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala @@ -59,7 +59,7 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit /** * :: Experimental :: - * Extracts a vocabulary from document collections and generates a CountVectorizerModel. + * Extracts a vocabulary from document collections and generates a [[CountVectorizerModel]]. 
 */
class CountVectorizer(override val uid: String)
  extends Estimator[CountVectorizerModel] with CountVectorizerParams {

  def this() = this(Identifiable.randomUID("cntVec"))

@@ -95,21 +95,19 @@ class CountVectorizer(override val uid: String)
   override def fit(dataset: DataFrame): CountVectorizerModel = {
     transformSchema(dataset.schema, logging = true)
     val input = dataset.select($(inputCol)).map(_.getAs[Seq[String]](0))
-    val min_count = $(minCount)
-    val vocab_size = $(vocabSize)
     val wordCounts: RDD[(String, Long)] = input
       .flatMap { case (tokens) => tokens.map(_ -> 1L) }
       .reduceByKey(_ + _)
-      .filter(_._2 >= min_count)
+      .filter(_._2 >= $(minCount))
     wordCounts.cache()
     val fullVocabSize = wordCounts.count()
     val vocab: Array[String] = {
-      val tmpSortedWC: Array[(String, Long)] = if (fullVocabSize <= vocab_size) {
+      val tmpSortedWC: Array[(String, Long)] = if (fullVocabSize <= $(vocabSize)) {
         // Use all terms
         wordCounts.collect().sortBy(-_._2)
       } else {
         // Sort terms to select vocab
-        wordCounts.sortBy(_._2, ascending = false).take(vocab_size)
+        wordCounts.sortBy(_._2, ascending = false).take($(vocabSize))
       }
       tmpSortedWC.map(_._1)
     }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
similarity index 89%
rename from mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala
rename to mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
index 9479279a03a12..2078858e4145a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
@@ -100,6 +100,21 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext {
       assert(features ~== expected absTol 1e-14)
     }
   }
+
+  test("CountVectorizer throws exception when vocab is empty") {
+    intercept[IllegalArgumentException] {
+      val df = sqlContext.createDataFrame(Seq(
+        (0, "a a b b c c".split("\\s+").toSeq),
+        (1, "aa bb cc".split("\\s+").toSeq))
+      ).toDF("id", "words")
+      val cvModel = new CountVectorizer()
+        .setInputCol("words")
+        .setOutputCol("features")
+        .setVocabSize(3) // limit vocab size to 3
+        .setMinCount(3)
+        .fit(df)
+    }
+  }
 }

From 0fe9f967412ea06b53015aa3ad28889497b85f87 Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley"
Date: Thu, 13 Aug 2015 23:28:16 -0700
Subject: [PATCH 4/7] Updates:
* Renamed "minCount" to "minTokenCount"
* Added "minTermFreq" back, including unit test
* Moved all Params to include in both Estimator and Model so that they can be
  viewed in either.
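
To make the two renamed filters concrete, here is a minimal usage sketch against
the API as of this patch (editor's addition, not part of the commit; it assumes a
SQLContext named sqlContext and toy data):

    import org.apache.spark.ml.feature.CountVectorizer

    val df = sqlContext.createDataFrame(Seq(
      (0, Seq("a", "a", "b")),
      (1, Seq("a", "c", "c", "c"))
    )).toDF("id", "words")

    val model = new CountVectorizer()
      .setInputCol("words")
      .setOutputCol("features")
      .setMinTokenCount(2) // corpus-wide count filter, applied while fitting
      .fit(df)             // vocabulary: terms with total count >= 2, i.e. "a" and "c"

    model.setMinTermFreq(2) // per-document count filter, applied at transform time
    model.transform(df).show()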
--- ...rizerModel.scala => CountVectorizer.scala} | 77 +++++++++++++------ .../ml/feature/CountVectorizerSuite.scala | 27 +++++-- 2 files changed, 74 insertions(+), 30 deletions(-) rename mllib/src/main/scala/org/apache/spark/ml/feature/{CountVectorizerModel.scala => CountVectorizer.scala} (73%) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala similarity index 73% rename from mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala rename to mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index bc79cfac983a7..544cad1c0622b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -35,49 +35,66 @@ import org.apache.spark.sql.DataFrame private[feature] trait CountVectorizerParams extends Params with HasInputCol with HasOutputCol { /** - * size of the vocabulary. - * If using Estimator, CountVectorizer will build a vocabulary that only consider the top + * Max size of the vocabulary. + * CountVectorizer will build a vocabulary that only considers the top * vocabSize terms ordered by term frequency across the corpus. + * * Default: 10000 * @group param */ - val vocabSize: IntParam = new IntParam(this, "vocabSize", "size of the vocabulary") + val vocabSize: IntParam = + new IntParam(this, "vocabSize", "size of the vocabulary", ParamValidators.gt(0)) /** @group getParam */ def getVocabSize: Int = $(vocabSize) + /** + * The minimum number of times a token must appear in the corpus to be included in the vocabulary. + * Note that this is not the same as document frequency: [[minTokenCount]] counts tokens including + * duplicates of terms, whereas document frequency counts unique terms. Support for document + * frequency will be added in the future. + * + * Default: 1 + * @group param + */ + val minTokenCount: IntParam = new IntParam(this, "minTokenCount", + "minimum number of times a token must appear in the corpus to be included in the vocabulary." + , ParamValidators.gtEq(1)) + + /** @group getParam */ + def getMinTokenCount: Int = $(minTokenCount) + /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true)) SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) } - override def validateParams(): Unit = { - require($(vocabSize) > 0, s"The vocabulary size (${$(vocabSize)}) must be above 0.") - } + /** + * Filter to ignore scarce words in a document. For each document, terms with + * frequency (count) less than the given threshold are ignored. + * Default: 1 + * @group param + */ + val minTermFreq: IntParam = new IntParam(this, "minTermFreq", + "minimum frequency (count) filter used to neglect scarce words (>= 1). For each document, " + + "terms with frequency less than the given threshold are ignored.", ParamValidators.gtEq(1)) + setDefault(minTermFreq -> 1) + + /** @group getParam */ + def getMinTermFreq: Int = $(minTermFreq) } /** * :: Experimental :: * Extracts a vocabulary from document collections and generates a [[CountVectorizerModel]]. 
*/ +@Experimental class CountVectorizer(override val uid: String) extends Estimator[CountVectorizerModel] with CountVectorizerParams { def this() = this(Identifiable.randomUID("cntVec")) - /** - * The minimum number of times a token must appear in the corpus to be included in the vocabulary - * Default: 1 - * @group param - */ - val minCount: IntParam = new IntParam(this, "minCount", - "minimum number of times a token must appear in the corpus to be included in the vocabulary." - , ParamValidators.gtEq(1)) - - /** @group getParam */ - def getMinCount: Int = $(minCount) - /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -88,31 +105,37 @@ class CountVectorizer(override val uid: String) def setVocabSize(value: Int): this.type = set(vocabSize, value) /** @group setParam */ - def setMinCount(value: Int): this.type = set(minCount, value) + def setMinTokenCount(value: Int): this.type = set(minTokenCount, value) - setDefault(vocabSize -> 10000, minCount -> 1) + /** @group setParam */ + def setMinTermFreq(value: Int): this.type = set(minTermFreq, value) + + setDefault(vocabSize -> 10000, minTokenCount -> 1) override def fit(dataset: DataFrame): CountVectorizerModel = { transformSchema(dataset.schema, logging = true) + val minCnt = $(minTokenCount) + val vocSize = $(vocabSize) val input = dataset.select($(inputCol)).map(_.getAs[Seq[String]](0)) val wordCounts: RDD[(String, Long)] = input .flatMap { case (tokens) => tokens.map(_ -> 1L) } .reduceByKey(_ + _) - .filter(_._2 >= $(minCount)) + .filter(_._2 >= minCnt) wordCounts.cache() val fullVocabSize = wordCounts.count() val vocab: Array[String] = { - val tmpSortedWC: Array[(String, Long)] = if (fullVocabSize <= $(vocabSize)) { + val tmpSortedWC: Array[(String, Long)] = if (fullVocabSize <= vocSize) { // Use all terms wordCounts.collect().sortBy(-_._2) } else { // Sort terms to select vocab - wordCounts.sortBy(_._2, ascending = false).take($(vocabSize)) + wordCounts.sortBy(_._2, ascending = false).take(vocSize) } tmpSortedWC.map(_._1) } - require(vocab.length > 0, "The vocabulary size should be > 0. Adjust minCount as necessary.") + require(vocab.length > 0, + "The vocabulary size should be > 0. 
Lower minTokenCount as necessary.") copyValues(new CountVectorizerModel(uid, vocab).setParent(this)) } @@ -143,8 +166,12 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) + /** @group setParam */ + def setMinTermFreq(value: Int): this.type = set(minTermFreq, value) + override def transform(dataset: DataFrame): DataFrame = { val dict = vocabulary.zipWithIndex.toMap + val minTF = $(minTermFreq) val vectorizer = udf { (document: Seq[String]) => val termCounts = mutable.HashMap.empty[Int, Double] document.foreach { term => @@ -153,7 +180,7 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin case None => // ignore terms not in the vocabulary } } - Vectors.sparse(dict.size, termCounts.toSeq) + Vectors.sparse(dict.size, termCounts.filter(_._2 >= minTF).toSeq) } dataset.withColumn($(outputCol), vectorizer(col($(inputCol)))) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala index 2078858e4145a..642e9802ee725 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala @@ -72,7 +72,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext { } } - test("CountVectorizer vocabSize and minCount") { + test("CountVectorizer vocabSize and minTokenCount") { val df = sqlContext.createDataFrame(Seq( (0, "a a a a a".split("\\s+").toSeq, Vectors.sparse(3, Seq((0, 5.0)))), (1, "b b b b".split("\\s+").toSeq, Vectors.sparse(3, Seq((1, 4.0)))), @@ -89,7 +89,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext { val cvModel2 = new CountVectorizer() .setInputCol("words") .setOutputCol("features") - .setMinCount(3) // ignore terms with count less than 3 + .setMinTokenCount(3) // ignore terms with count less than 3 .fit(df) assert(cvModel2.vocabulary.deep == Array("a", "b", "c").deep) @@ -111,10 +111,27 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext { .setInputCol("words") .setOutputCol("features") .setVocabSize(3) // limit vocab size to 3 - .setMinCount(3) + .setMinTokenCount(3) .fit(df) } } -} - + test("CountVectorizerModel with minTermFreq") { + val df = sqlContext.createDataFrame(Seq( + (0, "a a a b b c c c d ".split(" ").toSeq, Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))), + (1, "c c c c c c".split(" ").toSeq, Vectors.sparse(4, Seq((2, 6.0)))), + (2, "a".split(" ").toSeq, Vectors.sparse(4, Seq())), + (3, "e e e e e".split(" ").toSeq, Vectors.sparse(4, Seq()))) + ).toDF("id", "words", "expected") + val cv = new CountVectorizerModel(Array("a", "b", "c", "d")) + .setInputCol("words") + .setOutputCol("features") + .setMinTermFreq(3) + val output = cv.transform(df).collect() + output.foreach { p => + val features = p.getAs[Vector]("features") + val expected = p.getAs[Vector]("expected") + assert(features ~== expected absTol 1e-14) + } + } +} From 17b30097d0a553fc74cd0aef03ff8ca1ead4dfb7 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Fri, 14 Aug 2015 21:59:39 -0400 Subject: [PATCH 5/7] replace minTokenCount with minDocFreq --- .../spark/ml/feature/CountVectorizer.scala | 47 +++++++++--------- .../ml/feature/CountVectorizerSuite.scala | 49 ++++++++----------- 2 files changed, 46 insertions(+), 50 deletions(-) diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 544cad1c0622b..990188f2c163f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -39,30 +39,28 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit * CountVectorizer will build a vocabulary that only considers the top * vocabSize terms ordered by term frequency across the corpus. * - * Default: 10000 + * Default: 2^18^ * @group param */ val vocabSize: IntParam = - new IntParam(this, "vocabSize", "size of the vocabulary", ParamValidators.gt(0)) + new IntParam(this, "vocabSize", "max size of the vocabulary", ParamValidators.gt(0)) /** @group getParam */ def getVocabSize: Int = $(vocabSize) /** - * The minimum number of times a token must appear in the corpus to be included in the vocabulary. - * Note that this is not the same as document frequency: [[minTokenCount]] counts tokens including - * duplicates of terms, whereas document frequency counts unique terms. Support for document - * frequency will be added in the future. + * The minimum number of different documents a token must appear in to be included in the + * vocabulary. * * Default: 1 * @group param */ - val minTokenCount: IntParam = new IntParam(this, "minTokenCount", - "minimum number of times a token must appear in the corpus to be included in the vocabulary." + val minDocFreq: IntParam = new IntParam(this, "minDocFreq", + "The minimum number of documents a token must appear in to be included in the vocabulary" , ParamValidators.gtEq(1)) /** @group getParam */ - def getMinTokenCount: Int = $(minTokenCount) + def getMinDocFreq: Int = $(minDocFreq) /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { @@ -73,12 +71,15 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit /** * Filter to ignore scarce words in a document. For each document, terms with * frequency (count) less than the given threshold are ignored. + * Note that the parameter is only used in transform of [[CountVectorizerModel]] and does not + * affect fitting. * Default: 1 * @group param */ val minTermFreq: IntParam = new IntParam(this, "minTermFreq", "minimum frequency (count) filter used to neglect scarce words (>= 1). 
For each document, " + - "terms with frequency less than the given threshold are ignored.", ParamValidators.gtEq(1)) + "terms with count less than the given threshold are ignored.", ParamValidators.gtEq(1)) + setDefault(minTermFreq -> 1) /** @group getParam */ @@ -105,23 +106,24 @@ class CountVectorizer(override val uid: String) def setVocabSize(value: Int): this.type = set(vocabSize, value) /** @group setParam */ - def setMinTokenCount(value: Int): this.type = set(minTokenCount, value) - - /** @group setParam */ - def setMinTermFreq(value: Int): this.type = set(minTermFreq, value) + def setMinDocFreq(value: Int): this.type = set(minDocFreq, value) - setDefault(vocabSize -> 10000, minTokenCount -> 1) + setDefault(vocabSize -> (1 << 18), minDocFreq -> 1) override def fit(dataset: DataFrame): CountVectorizerModel = { transformSchema(dataset.schema, logging = true) - val minCnt = $(minTokenCount) + val minDf = $(minDocFreq) val vocSize = $(vocabSize) val input = dataset.select($(inputCol)).map(_.getAs[Seq[String]](0)) val wordCounts: RDD[(String, Long)] = input - .flatMap { case (tokens) => tokens.map(_ -> 1L) } - .reduceByKey(_ + _) - .filter(_._2 >= minCnt) - wordCounts.cache() + .flatMap { case (tokens) => + tokens.foldLeft(Map.empty[String, Long])( + (count, word) => count + (word -> (count.getOrElse(word, 0L) + 1))) + .map { case (word, count) => (word, (count, 1)) }} + .reduceByKey { case ((wc1, df1), (wc2, df2)) => (wc1 + wc2, df1 + df2) } + .filter { case (word, (wc, df)) => df >= minDf } + .map { case (word, (count, dfCount)) => (word, count) } + .cache() val fullVocabSize = wordCounts.count() val vocab: Array[String] = { val tmpSortedWC: Array[(String, Long)] = if (fullVocabSize <= vocSize) { @@ -135,7 +137,7 @@ class CountVectorizer(override val uid: String) } require(vocab.length > 0, - "The vocabulary size should be > 0. Lower minTokenCount as necessary.") + "The vocabulary size should be > 0. 
Lower minDocFreq as necessary.") copyValues(new CountVectorizerModel(uid, vocab).setParent(this)) } @@ -171,11 +173,12 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin override def transform(dataset: DataFrame): DataFrame = { val dict = vocabulary.zipWithIndex.toMap + val dictBr = dataset.sqlContext.sparkContext.broadcast(dict) val minTF = $(minTermFreq) val vectorizer = udf { (document: Seq[String]) => val termCounts = mutable.HashMap.empty[Int, Double] document.foreach { term => - dict.get(term) match { + dictBr.value.get(term) match { case Some(index) => termCounts.put(index, termCounts.getOrElse(index, 0.0) + 1.0) case None => // ignore terms not in the vocabulary } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala index 642e9802ee725..c556d8fec28a4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.Row class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -42,11 +43,9 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext { val cv = new CountVectorizerModel(Array("a", "b", "c", "d")) .setInputCol("words") .setOutputCol("features") - val output = cv.transform(df).collect() - output.foreach { p => - val features = p.getAs[Vector]("features") - val expected = p.getAs[Vector]("expected") - assert(features ~== expected absTol 1e-14) + cv.transform(df).select("features", "expected").collect().foreach { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) } } @@ -64,20 +63,18 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext { .fit(df) assert(cv.vocabulary.deep == Array("a", "b", "c", "d", "e").deep) - val output = cv.transform(df).collect() - output.foreach { p => - val features = p.getAs[Vector]("features") - val expected = p.getAs[Vector]("expected") - assert(features ~== expected absTol 1e-14) + cv.transform(df).select("features", "expected").collect().foreach { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) } } - test("CountVectorizer vocabSize and minTokenCount") { + test("CountVectorizer vocabSize and minDocFreq") { val df = sqlContext.createDataFrame(Seq( - (0, "a a a a a".split("\\s+").toSeq, Vectors.sparse(3, Seq((0, 5.0)))), - (1, "b b b b".split("\\s+").toSeq, Vectors.sparse(3, Seq((1, 4.0)))), - (2, "c c c".split("\\s+").toSeq, Vectors.sparse(3, Seq((2, 3.0)))), - (3, "d d".split("\\s+").toSeq, Vectors.sparse(3, Seq()))) + (0, "a b c d".split("\\s+").toSeq, Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))), + (1, "a b c".split("\\s+").toSeq, Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))), + (2, "a b".split("\\s+").toSeq, Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))), + (3, "a".split("\\s+").toSeq, Vectors.sparse(3, Seq((0, 1.0))))) ).toDF("id", "words", "expected") val cvModel = new CountVectorizer() .setInputCol("words") @@ -89,15 +86,13 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext { val cvModel2 = new CountVectorizer() .setInputCol("words") .setOutputCol("features") - 
.setMinTokenCount(3) // ignore terms with count less than 3 + .setMinDocFreq(3) // ignore terms with count less than 3 .fit(df) - assert(cvModel2.vocabulary.deep == Array("a", "b", "c").deep) + assert(cvModel2.vocabulary.deep == Array("a", "b").deep) - val output = cvModel2.transform(df).collect() - output.foreach { p => - val features = p.getAs[Vector]("features") - val expected = p.getAs[Vector]("expected") - assert(features ~== expected absTol 1e-14) + cvModel2.transform(df).select("features", "expected").collect().foreach { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) } } @@ -111,7 +106,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext { .setInputCol("words") .setOutputCol("features") .setVocabSize(3) // limit vocab size to 3 - .setMinTokenCount(3) + .setMinDocFreq(3) .fit(df) } } @@ -127,11 +122,9 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext { .setInputCol("words") .setOutputCol("features") .setMinTermFreq(3) - val output = cv.transform(df).collect() - output.foreach { p => - val features = p.getAs[Vector]("features") - val expected = p.getAs[Vector]("expected") - assert(features ~== expected absTol 1e-14) + cv.transform(df).select("features", "expected").collect().foreach { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) } } } From a9a9485ace02d98c8b054ea25766e10c86de76cc Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 17 Aug 2015 16:28:24 -0700 Subject: [PATCH 6/7] renamed docFreq to DF, termFreq to TF, and added fractional support. save broadcast as private var --- .../spark/ml/feature/CountVectorizer.scala | 110 ++++++++++++------ .../ml/feature/CountVectorizerSuite.scala | 93 ++++++++++----- 2 files changed, 138 insertions(+), 65 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 990188f2c163f..e6dffd2332d17 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -16,9 +16,8 @@ */ package org.apache.spark.ml.feature -import scala.collection.mutable - import org.apache.spark.annotation.Experimental +import org.apache.spark.broadcast.Broadcast import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util.{Identifiable, SchemaUtils} @@ -28,6 +27,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.sql.DataFrame +import org.apache.spark.util.collection.OpenHashMap /** * Params for [[CountVectorizer]] and [[CountVectorizerModel]]. @@ -49,18 +49,22 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit def getVocabSize: Int = $(vocabSize) /** - * The minimum number of different documents a token must appear in to be included in the - * vocabulary. + * Specifies the minimum number of different documents a term must appear in to be included + * in the vocabulary. + * If this is an integer >= 1, this specifies the number of documents the term must appear in; + * if this is a double in [0,1), then this specifies the fraction of documents. 
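+ * For example (editor's illustration, not in the original patch): with 20 documents,
+ * minDF = 2.0 keeps terms that appear in at least 2 documents, while minDF = 0.1
+ * keeps terms that appear in at least ceil(0.1 * 20) = 2 documents.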
* * Default: 1 * @group param */ - val minDocFreq: IntParam = new IntParam(this, "minDocFreq", - "The minimum number of documents a token must appear in to be included in the vocabulary" - , ParamValidators.gtEq(1)) + val minDF: DoubleParam = new DoubleParam(this, "minDF", "Specifies the minimum number of" + + " different documents a term must appear in to be included in the vocabulary." + + " If this is an integer >= 1, this specifies the number of documents the term must" + + " appear in; if this is a double in [0,1), then this specifies the fraction of documents.", + ParamValidators.gtEq(0.0)) /** @group getParam */ - def getMinDocFreq: Int = $(minDocFreq) + def getMinDF: Double = $(minDF) /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { @@ -69,21 +73,30 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit } /** - * Filter to ignore scarce words in a document. For each document, terms with - * frequency (count) less than the given threshold are ignored. + * Filter to ignore rare words in a document. For each document, terms with + * frequency/count less than the given threshold are ignored. + * If this is an integer >= 1, then this specifies a count (of times the term must appear + * in the document); + * if this is a double in [0,1), then this specifies a fraction (out of the document's token + * count). + * * Note that the parameter is only used in transform of [[CountVectorizerModel]] and does not * affect fitting. + * * Default: 1 * @group param */ - val minTermFreq: IntParam = new IntParam(this, "minTermFreq", - "minimum frequency (count) filter used to neglect scarce words (>= 1). For each document, " + - "terms with count less than the given threshold are ignored.", ParamValidators.gtEq(1)) + val minTF: DoubleParam = new DoubleParam(this, "minTF", "Filter to ignore rare words in" + + " a document. For each document, terms with frequency/count less than the given threshold are" + + " ignored. If this is an integer >= 1, then this specifies a count (of times the term must" + + " appear in the document); if this is a double in [0,1), then this specifies a fraction (out" + + " of the document's token count). 
Note that the parameter is only used in transform of" + + " CountVectorizerModel and does not affect fitting.", ParamValidators.gtEq(0.0)) - setDefault(minTermFreq -> 1) + setDefault(minTF -> 1) /** @group getParam */ - def getMinTermFreq: Int = $(minTermFreq) + def getMinTF: Double = $(minTF) } /** @@ -106,24 +119,35 @@ class CountVectorizer(override val uid: String) def setVocabSize(value: Int): this.type = set(vocabSize, value) /** @group setParam */ - def setMinDocFreq(value: Int): this.type = set(minDocFreq, value) + def setMinDF(value: Double): this.type = set(minDF, value) - setDefault(vocabSize -> (1 << 18), minDocFreq -> 1) + /** @group setParam */ + def setMinTF(value: Double): this.type = set(minTF, value) + + setDefault(vocabSize -> (1 << 18), minDF -> 1) override def fit(dataset: DataFrame): CountVectorizerModel = { transformSchema(dataset.schema, logging = true) - val minDf = $(minDocFreq) val vocSize = $(vocabSize) val input = dataset.select($(inputCol)).map(_.getAs[Seq[String]](0)) - val wordCounts: RDD[(String, Long)] = input - .flatMap { case (tokens) => - tokens.foldLeft(Map.empty[String, Long])( - (count, word) => count + (word -> (count.getOrElse(word, 0L) + 1))) - .map { case (word, count) => (word, (count, 1)) }} - .reduceByKey { case ((wc1, df1), (wc2, df2)) => (wc1 + wc2, df1 + df2) } - .filter { case (word, (wc, df)) => df >= minDf } - .map { case (word, (count, dfCount)) => (word, count) } - .cache() + val minDf: Long = if ($(minDF) >= 1.0) { + $(minDF).toLong + } else { + math.ceil($(minDF) * input.cache().count()).toLong + } + val wordCounts: RDD[(String, Long)] = input.flatMap { case (tokens) => + val wc = new OpenHashMap[String, Long] + tokens.foreach { w => + wc.changeValue(w, 1L, _ + 1L) + } + wc.map { case (word, count) => (word, (count, 1)) } + }.reduceByKey { case ((wc1, df1), (wc2, df2)) => + (wc1 + wc2, df1 + df2) + }.filter { case (word, (wc, df)) => + df >= minDf + }.map { case (word, (count, dfCount)) => + (word, count) + }.cache() val fullVocabSize = wordCounts.count() val vocab: Array[String] = { val tmpSortedWC: Array[(String, Long)] = if (fullVocabSize <= vocSize) { @@ -136,8 +160,7 @@ class CountVectorizer(override val uid: String) tmpSortedWC.map(_._1) } - require(vocab.length > 0, - "The vocabulary size should be > 0. Lower minDocFreq as necessary.") + require(vocab.length > 0, "The vocabulary size should be > 0. 
Lower minDF as necessary.") copyValues(new CountVectorizerModel(uid, vocab).setParent(this)) } @@ -169,21 +192,34 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin def setOutputCol(value: String): this.type = set(outputCol, value) /** @group setParam */ - def setMinTermFreq(value: Int): this.type = set(minTermFreq, value) + def setMinTF(value: Double): this.type = set(minTF, value) + + /** Dictionary created from [[vocabulary]] and its indices, broadcast once for [[transform()]] */ + private var broadcastDict: Option[Broadcast[Map[String, Int]]] = None override def transform(dataset: DataFrame): DataFrame = { - val dict = vocabulary.zipWithIndex.toMap - val dictBr = dataset.sqlContext.sparkContext.broadcast(dict) - val minTF = $(minTermFreq) + if (broadcastDict.isEmpty) { + val dict = vocabulary.zipWithIndex.toMap + broadcastDict = Some(dataset.sqlContext.sparkContext.broadcast(dict)) + } + val dictBr = broadcastDict.get + val minTf = $(minTF) val vectorizer = udf { (document: Seq[String]) => - val termCounts = mutable.HashMap.empty[Int, Double] + val termCounts = new OpenHashMap[Int, Double] + var tokenCount = 0L document.foreach { term => dictBr.value.get(term) match { - case Some(index) => termCounts.put(index, termCounts.getOrElse(index, 0.0) + 1.0) + case Some(index) => termCounts.changeValue(index, 1.0, _ + 1.0) case None => // ignore terms not in the vocabulary } + tokenCount += 1 + } + val effectiveMinTF = if (minTf >= 1.0) { + minTf + } else { + tokenCount * minTf } - Vectors.sparse(dict.size, termCounts.filter(_._2 >= minTF).toSeq) + Vectors.sparse(dictBr.value.size, termCounts.filter(_._2 >= effectiveMinTF).toSeq) } dataset.withColumn($(outputCol), vectorizer(col($(inputCol)))) } @@ -193,7 +229,7 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin } override def copy(extra: ParamMap): CountVectorizerModel = { - val copied = new CountVectorizerModel(uid, vocabulary) + val copied = new CountVectorizerModel(uid, vocabulary).setParent(parent) copyValues(copied, extra) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala index c556d8fec28a4..e192fa4850af0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala @@ -29,15 +29,17 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext { ParamsSuite.checkParams(new CountVectorizerModel(Array("empty"))) } + private def split(s: String): Seq[String] = s.split("\\s+") + test("CountVectorizerModel common cases") { val df = sqlContext.createDataFrame(Seq( - (0, "a b c d".split("\\s+").toSeq, + (0, split("a b c d"), Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))), - (1, "a b b c d a".split("\\s+").toSeq, + (1, split("a b b c d a"), Vectors.sparse(4, Seq((0, 2.0), (1, 2.0), (2, 1.0), (3, 1.0)))), - (2, "a".split("\\s+").toSeq, Vectors.sparse(4, Seq((0, 1.0)))), - (3, "".split("\\s+").toSeq, Vectors.sparse(4, Seq())), // empty string - (4, "a notInDict d".split("\\s+").toSeq, + (2, split("a"), Vectors.sparse(4, Seq((0, 1.0)))), + (3, split(""), Vectors.sparse(4, Seq())), // empty string + (4, split("a notInDict d"), Vectors.sparse(4, Seq((0, 1.0), (3, 1.0)))) // with words not in vocabulary )).toDF("id", "words", "expected") val cv = new CountVectorizerModel(Array("a", "b", "c", "d")) @@ -51,17 +53,17 @@ class 
CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext { test("CountVectorizer common cases") { val df = sqlContext.createDataFrame(Seq( - (0, "a b c d e".split("\\s+").toSeq, + (0, split("a b c d e"), Vectors.sparse(5, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0), (4, 1.0)))), - (1, "a a a a a a".split("\\s+").toSeq, Vectors.sparse(5, Seq((0, 6.0)))), - (2, "c".split("\\s+").toSeq, Vectors.sparse(5, Seq((2, 1.0)))), - (3, "b b b b b".split("\\s+").toSeq, Vectors.sparse(5, Seq((1, 5.0))))) + (1, split("a a a a a a"), Vectors.sparse(5, Seq((0, 6.0)))), + (2, split("c"), Vectors.sparse(5, Seq((2, 1.0)))), + (3, split("b b b b b"), Vectors.sparse(5, Seq((1, 5.0))))) ).toDF("id", "words", "expected") val cv = new CountVectorizer() .setInputCol("words") .setOutputCol("features") .fit(df) - assert(cv.vocabulary.deep == Array("a", "b", "c", "d", "e").deep) + assert(cv.vocabulary === Array("a", "b", "c", "d", "e")) cv.transform(df).select("features", "expected").collect().foreach { case Row(features: Vector, expected: Vector) => @@ -69,59 +71,94 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext { } } - test("CountVectorizer vocabSize and minDocFreq") { + test("CountVectorizer vocabSize and minDF") { val df = sqlContext.createDataFrame(Seq( - (0, "a b c d".split("\\s+").toSeq, Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))), - (1, "a b c".split("\\s+").toSeq, Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))), - (2, "a b".split("\\s+").toSeq, Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))), - (3, "a".split("\\s+").toSeq, Vectors.sparse(3, Seq((0, 1.0))))) + (0, split("a b c d"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))), + (1, split("a b c"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))), + (2, split("a b"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))), + (3, split("a"), Vectors.sparse(3, Seq((0, 1.0))))) ).toDF("id", "words", "expected") val cvModel = new CountVectorizer() .setInputCol("words") .setOutputCol("features") .setVocabSize(3) // limit vocab size to 3 .fit(df) - assert(cvModel.vocabulary.deep == Array("a", "b", "c").deep) + assert(cvModel.vocabulary === Array("a", "b", "c")) + // minDF: ignore terms with count less than 3 val cvModel2 = new CountVectorizer() .setInputCol("words") .setOutputCol("features") - .setMinDocFreq(3) // ignore terms with count less than 3 + .setMinDF(3) .fit(df) - assert(cvModel2.vocabulary.deep == Array("a", "b").deep) + assert(cvModel2.vocabulary === Array("a", "b")) cvModel2.transform(df).select("features", "expected").collect().foreach { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } + + // minDF: ignore terms with freq < 0.75 + val cvModel3 = new CountVectorizer() + .setInputCol("words") + .setOutputCol("features") + .setMinDF(3.0 / df.count()) + .fit(df) + assert(cvModel3.vocabulary === Array("a", "b")) + + cvModel3.transform(df).select("features", "expected").collect().foreach { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } } test("CountVectorizer throws exception when vocab is empty") { intercept[IllegalArgumentException] { val df = sqlContext.createDataFrame(Seq( - (0, "a a b b c c".split("\\s+").toSeq), - (1, "aa bb cc".split("\\s+").toSeq)) + (0, split("a a b b c c")), + (1, split("aa bb cc"))) ).toDF("id", "words") val cvModel = new CountVectorizer() .setInputCol("words") .setOutputCol("features") - .setVocabSize(3) // limit vocab size to 3 - .setMinDocFreq(3) + .setVocabSize(3) // limit vocab size to 3 + .setMinDF(3) .fit(df) } 
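      // Editor's note: every term above occurs in only one document, so
      // minDF = 3 filters the whole candidate vocabulary away and fit()
      // fails its require(vocab.length > 0, ...) check.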
 }

-  test("CountVectorizerModel with minTermFreq") {
+  test("CountVectorizerModel with minTF count") {
     val df = sqlContext.createDataFrame(Seq(
       (0, split("a a a b b c c c d "), Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))),
       (1, split("c c c c c c"), Vectors.sparse(4, Seq((2, 6.0)))),
       (2, split("a"), Vectors.sparse(4, Seq())),
       (3, split("e e e e e"), Vectors.sparse(4, Seq())))
     ).toDF("id", "words", "expected")
+
+    // minTF: count
     val cv = new CountVectorizerModel(Array("a", "b", "c", "d"))
       .setInputCol("words")
       .setOutputCol("features")
-      .setMinTermFreq(3)
+      .setMinTF(3)
     cv.transform(df).select("features", "expected").collect().foreach {
       case Row(features: Vector, expected: Vector) =>
         assert(features ~== expected absTol 1e-14)
     }
   }
+
+  test("CountVectorizerModel with minTF freq") {
+    val df = sqlContext.createDataFrame(Seq(
+      (0, split("a a a b b c c c d "), Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))),
+      (1, split("c c c c c c"), Vectors.sparse(4, Seq((2, 6.0)))),
+      (2, split("a"), Vectors.sparse(4, Seq((0, 1.0)))),
+      (3, split("e e e e e"), Vectors.sparse(4, Seq())))
+    ).toDF("id", "words", "expected")
+
+    // minTF: freq (a fraction of each document's token count, not a count)
+    val cv = new CountVectorizerModel(Array("a", "b", "c", "d"))
+      .setInputCol("words")
+      .setOutputCol("features")
+      .setMinTF(0.3)
+    cv.transform(df).select("features", "expected").collect().foreach {
+      case Row(features: Vector, expected: Vector) =>
+        assert(features ~== expected absTol 1e-14)
+    }
+  }
 }

From a37081636388546f7c84e429d38f206b014cc88e Mon Sep 17 00:00:00 2001
From: Yuhao Yang
Date: Wed, 19 Aug 2015 00:56:49 +0800
Subject: [PATCH 7/7] use minDF as Double

---
 .../scala/org/apache/spark/ml/feature/CountVectorizer.scala | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index e6dffd2332d17..49028e4b85064 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -130,10 +130,10 @@ class CountVectorizer(override val uid: String)
     transformSchema(dataset.schema, logging = true)
     val vocSize = $(vocabSize)
     val input = dataset.select($(inputCol)).map(_.getAs[Seq[String]](0))
-    val minDf: Long = if ($(minDF) >= 1.0) {
-      $(minDF).toLong
+    val minDf = if ($(minDF) >= 1.0) {
+      $(minDF)
     } else {
-      math.ceil($(minDF) * input.cache().count()).toLong
+      $(minDF) * input.cache().count()
     }
     val wordCounts: RDD[(String, Long)] = input.flatMap { case (tokens) =>
       val wc = new OpenHashMap[String, Long]
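
For reference, a short end-to-end sketch of the behavior this series converges on
(editor's addition, not part of any patch; the SQLContext and toy data are assumed):

    val docs = sqlContext.createDataFrame(Seq(
      (0, Seq("a", "a", "b")),
      (1, Seq("a", "c")),
      (2, Seq("a"))
    )).toDF("id", "words")

    val cvModel = new CountVectorizer()
      .setInputCol("words")
      .setOutputCol("features")
      .setMinDF(2)   // integer-valued: keep terms appearing in >= 2 documents
      .fit(docs)     // vocabulary: Array("a"); "b" and "c" occur in one document each

    cvModel.setMinTF(0.5) // fractional: a term must make up >= 50% of a document's tokens
    cvModel.transform(docs).show()
    // doc 0: "a" is 2 of 3 tokens (threshold 1.5) -> kept
    // doc 1: "a" is 1 of 2 tokens (threshold 1.0) -> kept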