diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index 2ae6e9c26d86b..6a0378e611215 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -36,7 +36,7 @@ import org.apache.spark.storage.StorageLevel * :: Experimental :: * * Model trained by [[FPGrowth]], which holds frequent itemsets. - * @param freqItemsets frequent itemset, which is an RDD of [[FreqItemset]] + * @param freqItemsets frequent itemsets, which is an RDD of [[FreqItemset]] * @tparam Item item type */ @Experimental @@ -63,11 +63,11 @@ class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) ex class FPGrowth private ( private var minSupport: Double, private var numPartitions: Int, - private var mineSequences: Boolean) extends Logging with Serializable { + private var ordered: Boolean) extends Logging with Serializable { /** * Constructs a default instance with default parameters {minSupport: `0.3`, numPartitions: same - * as the input data, mineSequences: `false`}. + * as the input data, ordered: `false`}. */ def this() = this(0.3, -1, false) @@ -88,10 +88,11 @@ class FPGrowth private ( } /** - * Indicates whether to mine item-sets or item-sequences (default: false, mine item-sets). + * Indicates whether to mine itemsets (unordered) or sequences (ordered) (default: false, mine + * itemsets). */ - def setMineSequences(value: Boolean): this.type = { - this.mineSequences = value + def setOrdered(ordered: Boolean): this.type = { + this.ordered = ordered this } @@ -164,7 +165,7 @@ class FPGrowth private ( .flatMap { case (part, tree) => tree.extract(minCount, x => partitioner.getPartition(x) == part) }.map { case (ranks, count) => - new FreqItemset(ranks.map(i => freqItems(i)).toArray, count) + new FreqItemset(ranks.map(i => freqItems(i)).toArray, count, ordered) } } @@ -182,7 +183,7 @@ class FPGrowth private ( val output = mutable.Map.empty[Int, Array[Int]] // Filter the basket by frequent items pattern val filtered = transaction.flatMap(itemToRank.get) - if (!this.mineSequences) { // Ignore ordering if not mining sequences + if (!this.ordered) { ju.Arrays.sort(filtered) } // Generate conditional transactions @@ -210,9 +211,11 @@ object FPGrowth { * Frequent itemset. * @param items items in this itemset. Java users should call [[FreqItemset#javaItems]] instead. * @param freq frequency + * @param ordered indicates if items represents an itemset (false) or sequence (true) * @tparam Item item type */ - class FreqItemset[Item](val items: Array[Item], val freq: Long) extends Serializable { + class FreqItemset[Item](val items: Array[Item], val freq: Long, val ordered: Boolean) + extends Serializable { /** * Returns items in a Java List. diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala index d8887ea2c4ce4..36f381febf688 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { - test("FP-Growth frequent item-sets using String type") { + test("FP-Growth frequent itemsets using String type") { val transactions = Seq( "r z h k p", "z y x w v u t s", @@ -38,14 +38,14 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { val model6 = fpg .setMinSupport(0.9) .setNumPartitions(1) - .setMineSequences(false) + .setOrdered(false) .run(rdd) assert(model6.freqItemsets.count() === 0) val model3 = fpg .setMinSupport(0.5) .setNumPartitions(2) - .setMineSequences(false) + .setOrdered(false) .run(rdd) val freqItemsets3 = model3.freqItemsets.collect().map { itemset => (itemset.items.toSet, itemset.freq) @@ -63,19 +63,19 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { val model2 = fpg .setMinSupport(0.3) .setNumPartitions(4) - .setMineSequences(false) + .setOrdered(false) .run(rdd) assert(model2.freqItemsets.count() === 54) val model1 = fpg .setMinSupport(0.1) .setNumPartitions(8) - .setMineSequences(false) + .setOrdered(false) .run(rdd) assert(model1.freqItemsets.count() === 625) } - test("FP-Growth frequent item-sequences using String type"){ + test("FP-Growth frequent sequences using String type"){ val transactions = Seq( "r z h k p", "z y x w v u t s", @@ -91,21 +91,21 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { val model1 = fpg .setMinSupport(0.5) .setNumPartitions(2) - .setMineSequences(true) + .setOrdered(true) .run(rdd) val expected = Set( - (Set("r"), 3L), (Set("s"), 3L), (Set("t"), 3L), (Set("x"), 4L), (Set("y"), 3L), - (Set("z"), 5L), (Set("z", "y"), 3L), (Set("x", "t"), 3L), (Set("y", "t"), 3L), - (Set("z", "t"), 3L), (Set("z", "y", "t"), 3L) + (List("r"), 3L), (List("s"), 3L), (List("t"), 3L), (List("x"), 4L), (List("y"), 3L), + (List("z"), 5L), (List("z", "y"), 3L), (List("x", "t"), 3L), (List("y", "t"), 3L), + (List("z", "t"), 3L), (List("z", "y", "t"), 3L) ) - val freqItemsets1 = model1.freqItemsets.collect().map { itemset => - (itemset.items.toSet, itemset.freq) + val freqItemseqs1 = model1.freqItemsets.collect().map { itemset => + (itemset.items.toList, itemset.freq) }.toSet - assert(freqItemsets1 === expected) + assert(freqItemseqs1 === expected) } - test("FP-Growth frequent item-sets using Int type") { + test("FP-Growth frequent itemsets using Int type") { val transactions = Seq( "1 2 3", "1 2 3 4", @@ -122,14 +122,14 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { val model6 = fpg .setMinSupport(0.9) .setNumPartitions(1) - .setMineSequences(false) + .setOrdered(false) .run(rdd) assert(model6.freqItemsets.count() === 0) val model3 = fpg .setMinSupport(0.5) .setNumPartitions(2) - .setMineSequences(false) + .setOrdered(false) .run(rdd) assert(model3.freqItemsets.first().items.getClass === Array(1).getClass, "frequent itemsets should use primitive arrays") @@ -145,14 +145,14 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { val model2 = fpg .setMinSupport(0.3) .setNumPartitions(4) - .setMineSequences(false) + .setOrdered(false) .run(rdd) assert(model2.freqItemsets.count() === 15) val model1 = fpg .setMinSupport(0.1) .setNumPartitions(8) - .setMineSequences(false) + .setOrdered(false) .run(rdd) assert(model1.freqItemsets.count() === 65) }