Add minimumOccurence filtering to IDF
rnowling committed Sep 22, 2014
1 parent 56dae30 commit c0cc643
Showing 4 changed files with 101 additions and 4 deletions.
15 changes: 15 additions & 0 deletions docs/mllib-feature-extraction.md
@@ -82,6 +82,21 @@ tf.cache()
val idf = new IDF().fit(tf)
val tfidf: RDD[Vector] = idf.transform(tf)
{% endhighlight %}

MLlib's IDF implementation provides an option for ignoring terms which occur in fewer than a
minimum number of documents. In such cases, the IDF for these terms is set to 0. This feature
can be used by passing the `minimumOccurence` value to the IDF constructor.

{% highlight scala %}
import org.apache.spark.mllib.feature.IDF

// ... continue from the previous example
tf.cache()
val idf = new IDF(minimumOccurence = 2).fit(tf)
val tfidf: RDD[Vector] = idf.transform(tf)
{% endhighlight %}
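
As a rough sketch of the effect (a hypothetical corpus, not part of the guide's running
example; `sc` is the usual `SparkContext`): a term appearing in only one of three documents
gets an IDF of 0 under `minimumOccurence = 2`, so its TF-IDF is 0 in every document.

{% highlight scala %}
import org.apache.spark.mllib.feature.{HashingTF, IDF}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

// "rare" occurs in 1 of 3 documents; "spark" occurs in all 3.
val documents: RDD[Seq[String]] = sc.parallelize(Seq(
  Seq("spark", "rare"),
  Seq("spark", "mllib"),
  Seq("spark", "mllib")))
val tf: RDD[Vector] = new HashingTF().transform(documents)
tf.cache()

// The IDF entry for "rare" is 0 (document frequency 1 < 2),
// so its TF-IDF is 0 in every document.
val tfidf: RDD[Vector] = new IDF(minimumOccurence = 2).fit(tf).transform(tf)
{% endhighlight %}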


</div>
</div>

34 changes: 30 additions & 4 deletions mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
@@ -30,9 +30,20 @@ import org.apache.spark.rdd.RDD
* Inverse document frequency (IDF).
* The standard formulation is used: `idf = log((m + 1) / (d(t) + 1))`, where `m` is the total
* number of documents and `d(t)` is the number of documents that contain term `t`.
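*
* A quick numeric check of the formula (illustrative values, not from this patch):
* with m = 3 documents and a term present in d(t) = 1 of them,
* idf = log(4 / 2) = log(2) ≈ 0.693; with minimumOccurence = 2, the
* same term's IDF is instead forced to 0.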
*
* This implementation supports filtering out terms which do not appear in a minimum number
* of documents (controlled by the `minimumOccurence` parameter). For terms that appear in
* fewer than `minimumOccurence` documents, the IDF is set to 0, yielding TF-IDF values of 0.
*
* @param minimumOccurence minimum number of documents in which a term
* must appear for its IDF to be retained
*/
@Experimental
class IDF {
class IDF(minimumOccurence: Long) {

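/** The no-argument constructor defaults to minimumOccurence = 0, i.e., no terms are filtered. */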
def this() = this(0L)

// TODO: Allow different IDF formulations.

@@ -41,7 +52,7 @@ class IDF {
* @param dataset an RDD of term frequency vectors
*/
def fit(dataset: RDD[Vector]): IDFModel = {
val idf = dataset.treeAggregate(new IDF.DocumentFrequencyAggregator)(
val idf = dataset.treeAggregate(new IDF.DocumentFrequencyAggregator(minimumOccurence = minimumOccurence))(
seqOp = (df, v) => df.add(v),
combOp = (df1, df2) => df1.merge(df2)
).idf()
@@ -60,7 +71,7 @@
private object IDF {

/** Document frequency aggregator. */
class DocumentFrequencyAggregator extends Serializable {
class DocumentFrequencyAggregator(minimumOccurence: Long) extends Serializable {

/** number of documents */
private var m = 0L
@@ -123,7 +134,17 @@
val inv = new Array[Double](n)
var j = 0
while (j < n) {
inv(j) = math.log((m + 1.0)/ (df(j) + 1.0))
/*
* If the term is not present in the minimum
* number of documents, set IDF to 0. This
* will cause multiplication in IDFModel to
* set TF-IDF to 0.
*/
if (df(j) >= minimumOccurence) {
inv(j) = math.log((m + 1.0) / (df(j) + 1.0))
} else {
inv(j) = 0.0
}
j += 1
}
Vectors.dense(inv)
@@ -140,6 +161,11 @@ class IDFModel private[mllib] (val idf: Vector) extends Serializable {

/**
* Transforms term frequency (TF) vectors to TF-IDF vectors.
*
* If `minimumOccurence` was set for the IDF calculation,
* terms which occur in fewer than `minimumOccurence`
* documents will have a TF-IDF value of 0.
*
* @param dataset an RDD of term frequency vectors
* @return an RDD of TF-IDF vectors
*/
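To make the filtering rule above concrete, here is a minimal standalone sketch of the same
computation done by `DocumentFrequencyAggregator.idf()` (plain Scala, no Spark required;
`idfWithFiltering` is an illustrative helper, not part of this patch):

// df(j) is the number of documents containing term j; m is the total document count.
def idfWithFiltering(df: Array[Long], m: Long, minimumOccurence: Long): Array[Double] =
  df.map { d =>
    if (d >= minimumOccurence) math.log((m + 1.0) / (d + 1.0)) else 0.0
  }

// With m = 3 and df = [0, 3, 1, 2], minimumOccurence = 2 gives
// [0.0, log(4.0 / 4.0), 0.0, log(4.0 / 3.0)] = [0.0, 0.0, 0.0, 0.2877...],
// i.e., rare terms are zeroed out just like terms that appear in every document.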
22 changes: 22 additions & 0 deletions mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java
@@ -27,6 +27,8 @@
import org.junit.Test;
import com.google.common.collect.Lists;

import java.lang.reflect.Method;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
@@ -63,4 +65,24 @@ public void tfIdf() {
Assert.assertEquals(0.0, v.apply(indexOfThis), 1e-15);
}
}

@Test
public void tfIdfMinimumOccurence() {
// The tests are to check Java compatibility.
HashingTF tf = new HashingTF();
JavaRDD<ArrayList<String>> documents = sc.parallelize(Lists.newArrayList(
Lists.newArrayList("this is a sentence".split(" ")),
Lists.newArrayList("this is another sentence".split(" ")),
Lists.newArrayList("this is still a sentence".split(" "))), 2);
JavaRDD<Vector> termFreqs = tf.transform(documents);
termFreqs.collect();
IDF idf = new IDF(2);
JavaRDD<Vector> tfIdfs = idf.fit(termFreqs).transform(termFreqs);
List<Vector> localTfIdfs = tfIdfs.collect();
int indexOfThis = tf.indexOf("this");
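// "this" occurs in all three documents, so it passes the minimumOccurence = 2
// threshold, but its IDF is log((3 + 1) / (3 + 1)) = 0 anyway; hence TF-IDF 0.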
for (Vector v: localTfIdfs) {
Assert.assertEquals(0.0, v.apply(indexOfThis), 1e-15);
}
}

}
34 changes: 34 additions & 0 deletions mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala
@@ -54,4 +54,38 @@ class IDFSuite extends FunSuite with LocalSparkContext {
assert(tfidf2.indices === Array(1))
assert(tfidf2.values(0) ~== (1.0 * expected(1)) absTol 1e-12)
}

test("idf minimum occurence filtering") {
val n = 4
val localTermFrequencies = Seq(
Vectors.sparse(n, Array(1, 3), Array(1.0, 2.0)),
Vectors.dense(0.0, 1.0, 2.0, 3.0),
Vectors.sparse(n, Array(1), Array(1.0))
)
val m = localTermFrequencies.size
val termFrequencies = sc.parallelize(localTermFrequencies, 2)
val idf = new IDF(minimumOccurence = 1L)
val model = idf.fit(termFrequencies)
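// Document frequencies for the four terms are [0, 3, 1, 2]. With
// minimumOccurence = 1, only the df = 0 term is filtered to an IDF of 0
// (unfiltered it would be log(4 / 1)); this matches `expected` below.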
val expected = Vectors.dense(Array(0, 3, 1, 2).map { x =>
if (x > 0) {
math.log((m.toDouble + 1.0) / (x + 1.0))
} else {
0
}
})
assert(model.idf ~== expected absTol 1e-12)
val tfidf = model.transform(termFrequencies).cache().zipWithIndex().map(_.swap).collectAsMap()
assert(tfidf.size === 3)
val tfidf0 = tfidf(0L).asInstanceOf[SparseVector]
assert(tfidf0.indices === Array(1, 3))
assert(Vectors.dense(tfidf0.values) ~==
Vectors.dense(1.0 * expected(1), 2.0 * expected(3)) absTol 1e-12)
val tfidf1 = tfidf(1L).asInstanceOf[DenseVector]
assert(Vectors.dense(tfidf1.values) ~==
Vectors.dense(0.0, 1.0 * expected(1), 2.0 * expected(2), 3.0 * expected(3)) absTol 1e-12)
val tfidf2 = tfidf(2L).asInstanceOf[SparseVector]
assert(tfidf2.indices === Array(1))
assert(tfidf2.values(0) ~== (1.0 * expected(1)) absTol 1e-12)
}

}
