[SPARK-5520][MLlib] Make FP-Growth implementation take generic item types (WIP)

Make the FPGrowth.run API take generic item types:
`def run[Item: ClassTag, Basket <: Iterable[Item]](data: RDD[Basket]): FPGrowthModel[Item]`
so that users can invoke it as run[String, Seq[String]], run[Int, Seq[Int]], run[Int, List[Int]], etc.

The Scala part is done, while the Java part is still in progress.
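For context, a minimal usage sketch of the generic API as it lands in the Scala diff below (`run` takes an `RDD[Array[Item]]` and infers `Item`); the transactions and the SparkContext `sc` are illustrative, not part of this change:

```scala
import org.apache.spark.mllib.fpm.FPGrowth

// Hypothetical transactions; assumes an existing SparkContext `sc`.
val transactions = sc.parallelize(Seq(
  Array("a", "b", "c"),
  Array("a", "b"),
  Array("b", "c")), numSlices = 2).cache()

val model = new FPGrowth()
  .setMinSupport(0.5)
  .setNumPartitions(2)
  .run(transactions) // Item is inferred as String

model.freqItemsets.collect().foreach { case (itemset, count) =>
  println(itemset.mkString("{", ", ", "}") + ": " + count)
}
```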

Author: Jacky Li <jacky.likun@huawei.com>
Author: Jacky Li <jackylk@users.noreply.github.com>
Author: Xiangrui Meng <meng@databricks.com>

Closes #4340 from jackylk/SPARK-5520-WIP and squashes the following commits:

f5acf84 [Jacky Li] Merge pull request #2 from mengxr/SPARK-5520
63073d0 [Xiangrui Meng] update to make generic FPGrowth Java-friendly
737d8bb [Jacky Li] fix scalastyle
793f85c [Jacky Li] add Java test case
7783351 [Jacky Li] add generic support in FPGrowth
jackylk authored and mengxr committed Feb 4, 2015
1 parent 068c0e2 commit e380d2d
Showing 3 changed files with 170 additions and 15 deletions.
50 changes: 36 additions & 14 deletions mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
@@ -18,14 +18,31 @@
package org.apache.spark.mllib.fpm

import java.{util => ju}
import java.lang.{Iterable => JavaIterable}

import scala.collection.mutable
import scala.collection.JavaConverters._
import scala.reflect.ClassTag

-import org.apache.spark.{SparkException, HashPartitioner, Logging, Partitioner}
+import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException}
import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

-class FPGrowthModel(val freqItemsets: RDD[(Array[String], Long)]) extends Serializable
+/**
+ * Model trained by [[FPGrowth]], which holds frequent itemsets.
+ * @param freqItemsets frequent itemsets, stored as an RDD of (itemset, frequency) pairs
+ * @tparam Item item type
+ */
+class FPGrowthModel[Item: ClassTag](
+    val freqItemsets: RDD[(Array[Item], Long)]) extends Serializable {
+
+  /** Returns frequent itemsets as a [[org.apache.spark.api.java.JavaPairRDD]]. */
+  def javaFreqItemsets(): JavaPairRDD[Array[Item], java.lang.Long] = {
+    JavaPairRDD.fromRDD(freqItemsets).asInstanceOf[JavaPairRDD[Array[Item], java.lang.Long]]
+  }
+}
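Reviewer note on the cast above, as a minimal sketch of why it is sound: a Scala `Long` stored in a generic position (such as inside a `Tuple2`) is boxed to `java.lang.Long` on the JVM, so re-viewing the pair RDD under the Java-friendly type changes no runtime representation:

```scala
// A Scala Long in a generic position is boxed to java.lang.Long,
// which is what makes the asInstanceOf in javaFreqItemsets safe.
val pair: (String, Any) = ("z", 5L)
assert(pair._2.isInstanceOf[java.lang.Long])
```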

/**
* This class implements Parallel FP-growth algorithm to do frequent pattern matching on input data.
@@ -69,7 +86,7 @@ class FPGrowth private (
* @param data input data set, each element contains a transaction
* @return an [[FPGrowthModel]]
*/
-def run(data: RDD[Array[String]]): FPGrowthModel = {
+def run[Item: ClassTag](data: RDD[Array[Item]]): FPGrowthModel[Item] = {
if (data.getStorageLevel == StorageLevel.NONE) {
logWarning("Input data is not cached.")
}
@@ -82,19 +99,24 @@
new FPGrowthModel(freqItemsets)
}

def run[Item, Basket <: JavaIterable[Item]](data: JavaRDD[Basket]): FPGrowthModel[Item] = {
implicit val tag = fakeClassTag[Item]
run(data.rdd.map(_.asScala.toArray))
}
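A note on `fakeClassTag` used above: Java callers cannot supply a Scala `ClassTag`, so the Java API substitutes one that erases everything to `AnyRef`. A sketch of the idea (the name `fakeClassTagSketch` is illustrative; the real definition lives in `JavaSparkContext`):

```scala
import scala.reflect.ClassTag

// Sketch: pretend every T erases to AnyRef. This is safe on the Java side
// because Java collections only ever hold boxed object references.
def fakeClassTagSketch[T]: ClassTag[T] = ClassTag.AnyRef.asInstanceOf[ClassTag[T]]
```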

/**
* Generates frequent items by filtering the input data using minimal support level.
* @param minCount minimum count for frequent itemsets
* @param partitioner partitioner used to distribute items
* @return array of frequent pattern ordered by their frequencies
*/
-private def genFreqItems(
-    data: RDD[Array[String]],
+private def genFreqItems[Item: ClassTag](
+    data: RDD[Array[Item]],
     minCount: Long,
-    partitioner: Partitioner): Array[String] = {
+    partitioner: Partitioner): Array[Item] = {
data.flatMap { t =>
val uniq = t.toSet
-if (t.length != uniq.size) {
+if (t.size != uniq.size) {
throw new SparkException(s"Items in a transaction must be unique but got ${t.toSeq}.")
}
t
@@ -114,11 +136,11 @@
* @param partitioner partitioner used to distribute transactions
* @return an RDD of (frequent itemset, count)
*/
-private def genFreqItemsets(
-    data: RDD[Array[String]],
+private def genFreqItemsets[Item: ClassTag](
+    data: RDD[Array[Item]],
     minCount: Long,
-    freqItems: Array[String],
-    partitioner: Partitioner): RDD[(Array[String], Long)] = {
+    freqItems: Array[Item],
+    partitioner: Partitioner): RDD[(Array[Item], Long)] = {
val itemToRank = freqItems.zipWithIndex.toMap
data.flatMap { transaction =>
genCondTransactions(transaction, itemToRank, partitioner)
@@ -139,9 +161,9 @@
* @param partitioner partitioner used to distribute transactions
* @return a map of (target partition, conditional transaction)
*/
-private def genCondTransactions(
-    transaction: Array[String],
-    itemToRank: Map[String, Int],
+private def genCondTransactions[Item: ClassTag](
+    transaction: Array[Item],
+    itemToRank: Map[Item, Int],
partitioner: Partitioner): mutable.Map[Int, Array[Int]] = {
val output = mutable.Map.empty[Int, Array[Int]]
// Filter the basket by frequent items pattern and sort their ranks.
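The rest of `genCondTransactions` is collapsed in this view. As a hedged sketch of the routing step in parallel FP-growth (names are illustrative, not the committed code): once a transaction is reduced to the sorted ranks of its frequent items, each suffix item sends the prefix ending at it to the partition that owns that item, so each partition can grow its conditional FP-tree independently:

```scala
import scala.collection.mutable

import org.apache.spark.Partitioner

// Sketch: for each suffix item, emit the prefix ending at it to the
// partition that owns the item; keep only the longest prefix per partition.
def condTransactionsSketch(
    ranks: Array[Int], // one transaction as sorted ranks of frequent items
    partitioner: Partitioner): mutable.Map[Int, Array[Int]] = {
  val output = mutable.Map.empty[Int, Array[Int]]
  var i = ranks.length - 1
  while (i >= 0) {
    val part = partitioner.getPartition(ranks(i))
    if (!output.contains(part)) {
      output(part) = ranks.slice(0, i + 1)
    }
    i -= 1
  }
  output
}
```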
84 changes: 84 additions & 0 deletions mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
@@ -0,0 +1,84 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.fpm;

import java.io.Serializable;
import java.util.ArrayList;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import com.google.common.collect.Lists;
import static org.junit.Assert.*;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

public class JavaFPGrowthSuite implements Serializable {
private transient JavaSparkContext sc;

@Before
public void setUp() {
sc = new JavaSparkContext("local", "JavaFPGrowth");
}

@After
public void tearDown() {
sc.stop();
sc = null;
}

@Test
public void runFPGrowth() {

@SuppressWarnings("unchecked")
JavaRDD<ArrayList<String>> rdd = sc.parallelize(Lists.newArrayList(
Lists.newArrayList("r z h k p".split(" ")),
Lists.newArrayList("z y x w v u t s".split(" ")),
Lists.newArrayList("s x o n r".split(" ")),
Lists.newArrayList("x z y m t s q e".split(" ")),
Lists.newArrayList("z".split(" ")),
Lists.newArrayList("x z y r q t p".split(" "))), 2);

FPGrowth fpg = new FPGrowth();

FPGrowthModel<String> model6 = fpg
.setMinSupport(0.9)
.setNumPartitions(1)
.run(rdd);
assertEquals(0, model6.javaFreqItemsets().count());

FPGrowthModel<String> model3 = fpg
.setMinSupport(0.5)
.setNumPartitions(2)
.run(rdd);
assertEquals(18, model3.javaFreqItemsets().count());

FPGrowthModel<String> model2 = fpg
.setMinSupport(0.3)
.setNumPartitions(4)
.run(rdd);
assertEquals(54, model2.javaFreqItemsets().count());

FPGrowthModel<String> model1 = fpg
.setMinSupport(0.1)
.setNumPartitions(8)
.run(rdd);
assertEquals(625, model1.javaFreqItemsets().count());
}
}
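A quick sanity check on the `model6` expectation above, assuming (as in `FPGrowth.run`) that an itemset is frequent once its count reaches `ceil(minSupport * numTransactions)`:

```scala
// With 6 transactions and minSupport 0.9, the count threshold is ceil(5.4) = 6.
// The most frequent item "z" occurs in only 5 of the 6 baskets, so no itemset
// can be frequent and the model is expected to be empty.
val minCount = math.ceil(0.9 * 6).toLong
assert(minCount == 6L)
```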
51 changes: 50 additions & 1 deletion mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
@@ -22,7 +22,8 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext

class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {

test("FP-Growth") {

test("FP-Growth using String type") {
val transactions = Seq(
"r z h k p",
"z y x w v u t s",
@@ -70,4 +71,52 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
.run(rdd)
assert(model1.freqItemsets.count() === 625)
}

test("FP-Growth using Int type") {
val transactions = Seq(
"1 2 3",
"1 2 3 4",
"5 4 3 2 1",
"6 5 4 3 2 1",
"2 4",
"1 3",
"1 7")
.map(_.split(" ").map(_.toInt).toArray)
val rdd = sc.parallelize(transactions, 2).cache()

val fpg = new FPGrowth()

val model6 = fpg
.setMinSupport(0.9)
.setNumPartitions(1)
.run(rdd)
assert(model6.freqItemsets.count() === 0)

val model3 = fpg
.setMinSupport(0.5)
.setNumPartitions(2)
.run(rdd)
assert(model3.freqItemsets.first()._1.getClass === Array(1).getClass,
"frequent itemsets should use primitive arrays")
val freqItemsets3 = model3.freqItemsets.collect().map { case (items, count) =>
(items.toSet, count)
}
val expected = Set(
(Set(1), 6L), (Set(2), 5L), (Set(3), 5L), (Set(4), 4L),
(Set(1, 2), 4L), (Set(1, 3), 5L), (Set(2, 3), 4L),
(Set(2, 4), 4L), (Set(1, 2, 3), 4L))
assert(freqItemsets3.toSet === expected)

val model2 = fpg
.setMinSupport(0.3)
.setNumPartitions(4)
.run(rdd)
assert(model2.freqItemsets.count() === 15)

val model1 = fpg
.setMinSupport(0.1)
.setNumPartitions(8)
.run(rdd)
assert(model1.freqItemsets.count() === 65)
}
}
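One reason the Int test asserts on `getClass`: with a `ClassTag[Int]` in scope, `Array[Item]` materializes as a primitive `int[]` rather than a boxed `Object[]`. A minimal illustration (`makeArray` is a throwaway helper, not part of this change):

```scala
import scala.reflect.ClassTag

// With a ClassTag available, toArray builds the unboxed primitive array type.
def makeArray[T: ClassTag](xs: T*): Array[T] = xs.toArray

assert(makeArray(1, 2, 3).getClass == classOf[Array[Int]]) // int[], not Object[]
```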
