diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
index 0ea792081086d..45c19f521eadc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
@@ -41,54 +41,168 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
       maxPatternLength: Int,
       prefixes: List[Int],
       database: Iterable[Array[Int]]): Iterator[(List[Int], Long)] = {
-    if (prefixes.length == maxPatternLength || database.isEmpty) return Iterator.empty
-    val frequentItemAndCounts = getFreqItemAndCounts(minCount, database)
-    val filteredDatabase = database.map(x => x.filter(frequentItemAndCounts.contains))
-    frequentItemAndCounts.iterator.flatMap { case (item, count) =>
-      val newPrefixes = item :: prefixes
-      val newProjected = project(filteredDatabase, item)
-      Iterator.single((newPrefixes, count)) ++
-        run(minCount, maxPatternLength, newPrefixes, newProjected)
+    if (prefixes.count(_ != -1) == maxPatternLength || database.isEmpty) return Iterator.empty
+    val frequentPrefixAndCounts = getFreqPrefixAndCounts(minCount, prefixes, database)
+    frequentPrefixAndCounts.iterator.flatMap { case (prefix, count) =>
+      val newProjected = project(database, prefix)
+      Iterator.single((prefix, count)) ++
+        run(minCount, maxPatternLength, prefix, newProjected)
     }
   }
 
   /**
-   * Calculate suffix sequence immediately after the first occurrence of an item.
-   * @param item item to get suffix after
+   * Calculate suffix sequence immediately after the first occurrence of a prefix.
+   * @param prefix prefix to get suffix after
    * @param sequence sequence to extract suffix from
-   * @return suffix sequence
+   * @return (whether the prefix occurs in the sequence, suffix sequence after it)
    */
-  def getSuffix(item: Int, sequence: Array[Int]): Array[Int] = {
-    val index = sequence.indexOf(item)
+  def getSuffix(prefix: List[Int], sequence: Array[Int]): (Boolean, Array[Int]) = {
+    val element = getLastElement(prefix)
+    if (sequence.apply(0) != -3) {
+      if (element.length == 1) {
+        getSingleItemElementSuffix(element, sequence)
+      } else {
+        getMultiItemsElementSuffix(element, sequence)
+      }
+    } else {
+      if (element.length == 1) {
+        val firstElemPos = sequence.indexOf(-1)
+        if (firstElemPos == -1) {
+          (false, Array())
+        } else {
+          getSingleItemElementSuffix(element, sequence.drop(firstElemPos + 1))
+        }
+      } else {
+        val newSequence = element.take(element.length - 1) ++ sequence.drop(1)
+        getMultiItemsElementSuffix(element, newSequence)
+      }
+    }
+  }
+
+  private def getLastElement(prefix: List[Int]): Array[Int] = {
+    val pos = prefix.indexOf(-1)
+    if (pos == -1) {
+      prefix.reverse.toArray
+    } else {
+      prefix.take(pos).reverse.toArray
+    }
+  }
+
+  private def getSingleItemElementSuffix(
+      element: Array[Int],
+      sequence: Array[Int]): (Boolean, Array[Int]) = {
+    val index = sequence.indexOf(element.apply(0))
     if (index == -1) {
-      Array()
+      (false, Array())
+    } else if (index == sequence.length - 1) {
+      (true, Array())
+    } else if (sequence.apply(index + 1) == -1) {
+      (true, sequence.drop(index + 2))
+    } else {
+      (true, -3 +: sequence.drop(index + 1))
+    }
+  }
+
+  private def getMultiItemsElementSuffix(
+      element: Array[Int],
+      sequence: Array[Int]): (Boolean, Array[Int]) = {
+    var seqPos = 0
+    var found = false
+    while (seqPos < sequence.length && !found) {
+      var elemPos = 0
+      while (!found && elemPos < element.length &&
+        seqPos < sequence.length && sequence.apply(seqPos) != -1) {
+        if (element.apply(elemPos) == sequence.apply(seqPos)) {
+          elemPos += 1
+          seqPos += 1
+        } else {
+          seqPos += 1
+        }
+        found = elemPos == element.length
+      }
+      if (!found) seqPos += 1
+    }
+    if (found) {
+      if (seqPos == sequence.length) {
+        (true, Array())
+      } else if (sequence.apply(seqPos) == -1) {
+        (true, sequence.drop(seqPos + 1))
+      } else {
+        (true, -3 +: sequence.drop(seqPos))
+      }
     } else {
-      sequence.drop(index + 1)
+      (false, Array())
     }
   }
 
-  def project(database: Iterable[Array[Int]], prefix: Int): Iterable[Array[Int]] = {
+  def project(database: Iterable[Array[Int]], prefix: List[Int]): Iterable[Array[Int]] = {
     database
-      .map(getSuffix(prefix, _))
+      .map(getSuffix(prefix, _)._2)
      .filter(_.nonEmpty)
   }
 
   /**
-   * Generates frequent items by filtering the input data using minimal count level.
+   * Generates frequent prefixes by filtering the input data using minimal count level.
    * @param minCount the minimum count for an item to be frequent
+   * @param prefix the current prefix to extend
    * @param database database of sequences
-   * @return freq item to count map
+   * @return freq prefix to count map
    */
-  private def getFreqItemAndCounts(
+  private def getFreqPrefixAndCounts(
       minCount: Long,
-      database: Iterable[Array[Int]]): mutable.Map[Int, Long] = {
+      prefix: List[Int],
+      database: Iterable[Array[Int]]): mutable.Map[List[Int], Long] = {
     // TODO: use PrimitiveKeyOpenHashMap
-    val counts = mutable.Map[Int, Long]().withDefaultValue(0L)
+
+    // get frequent items
+    val freqItems = database
+      .flatMap(_.distinct.filter(x => x != -1 && x != -3))
+      .groupBy(x => x)
+      .mapValues(_.size)
+      .filter(_._2 >= minCount)
+      .map(_._1)
+    if (freqItems.isEmpty) return mutable.Map[List[Int], Long]()
+
+    // get prefixes and counts
+    val singleItemCounts = mutable.Map[Int, Long]().withDefaultValue(0L)
+    val multiItemsCounts = mutable.Map[Int, Long]().withDefaultValue(0L)
+    val prefixLastElement = getLastElement(prefix)
     database.foreach { sequence =>
-      sequence.distinct.foreach { item =>
-        counts(item) += 1L
+      if (sequence.apply(0) != -3) {
+        freqItems.foreach { item =>
+          if (getSingleItemElementSuffix(Array(item), sequence)._1) {
+            singleItemCounts(item) += 1
+          }
+          if (prefixLastElement.nonEmpty &&
+            getMultiItemsElementSuffix(prefixLastElement :+ item, sequence)._1) {
+            multiItemsCounts(item) += 1
+          }
+        }
+      } else {
+        val firstElemPos = sequence.indexOf(-1)
+        if (firstElemPos != -1) {
+          val newSequence = sequence.drop(firstElemPos + 1)
+          freqItems.foreach { item =>
+            if (getSingleItemElementSuffix(Array(item), newSequence)._1) {
+              singleItemCounts(item) += 1
+            }
+          }
+        }
+        val newSequence = prefixLastElement ++ sequence.drop(1)
+        freqItems.foreach { item =>
+          if (prefixLastElement.nonEmpty &&
+            getMultiItemsElementSuffix(prefixLastElement :+ item, newSequence)._1) {
+            multiItemsCounts(item) += 1
+          }
+        }
       }
     }
-    counts.filter(_._2 >= minCount)
+
+    if (prefix.nonEmpty) {
+      singleItemCounts.filter(_._2 >= minCount).map(x => (x._1 :: (-1 :: prefix), x._2)) ++
+        multiItemsCounts.filter(_._2 >= minCount).map(x => (x._1 :: prefix, x._2))
+    } else {
+      singleItemCounts.filter(_._2 >= minCount).map(x => (List(x._1), x._2))
+    }
   }
 }
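Note: the encoding assumed by the patch above is only implicit in the code: each
sequence is a flat Array[Int] in which positive integers are items, -1 closes an
element (itemset), and a leading -3 marks a projected suffix that starts inside a
partially matched element. A minimal, self-contained sketch of that convention,
using a hypothetical `decode` helper that is not part of the patch:

    // Hypothetical helper (illustration only): render the flat encoding as a
    // sequence of elements. Positive ints are items, -1 ends an element, and a
    // leading -3 (the internal partial-element marker) is dropped for display.
    object EncodingDemo {
      def decode(seq: Array[Int]): Seq[Seq[Int]] =
        seq.dropWhile(_ == -3)
          .foldLeft(Vector(Vector.empty[Int])) { (acc, x) =>
            if (x == -1) acc :+ Vector.empty[Int]   // -1: start a new element
            else acc.init :+ (acc.last :+ x)        // item: extend current element
          }
          .filter(_.nonEmpty)

      def main(args: Array[String]): Unit = {
        println(decode(Array(1, 2, -1, 3)))            // <(1 2)(3)>
        println(decode(Array(1, -1, 3, -1, 4, -1, 5))) // <(1)(3)(4)(5)>
        println(decode(Array(-3, 2, -1, 3)))           // suffix starting mid-element
      }
    }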
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
index e6752332cdeeb..3c48c6a0e5be9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
@@ -17,8 +17,6 @@
 
 package org.apache.spark.mllib.fpm
 
-import scala.collection.mutable.ArrayBuffer
-
 import org.apache.spark.Logging
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.rdd.RDD
@@ -50,7 +48,7 @@ class PrefixSpan private (
    * projected database exceeds this size, another iteration of distributed PrefixSpan is run.
    */
-  // TODO: make configurable with a better default value, 10000 may be too small
-  private val maxLocalProjDBSize: Long = 10000
+  // TODO: make configurable with a better default value
+  private val maxLocalProjDBSize: Long = 32000000L
 
   /**
    * Constructs a default instance with default parameters
@@ -108,20 +106,20 @@ class PrefixSpan private (
     // (Frequent items -> number of occurrences, all items here satisfy the `minSupport` threshold
     val freqItemCounts = sequences
-      .flatMap(seq => seq.distinct.map(item => (item, 1L)))
+      .flatMap(seq => seq.distinct.filter(_ != -1).map(item => (item, 1L)))
       .reduceByKey(_ + _)
       .filter(_._2 >= minCount)
       .collect()
 
     // Pairs of (length 1 prefix, suffix consisting of frequent items)
-    val itemSuffixPairs = {
+    val prefixSuffixPairs = {
       val freqItems = freqItemCounts.map(_._1).toSet
       sequences.flatMap { seq =>
-        val filteredSeq = seq.filter(freqItems.contains(_))
+        val filteredSeq = seq.filter(item => freqItems.contains(item) || item == -1)
         freqItems.flatMap { item =>
-          val candidateSuffix = LocalPrefixSpan.getSuffix(item, filteredSeq)
+          val candidateSuffix = LocalPrefixSpan.getSuffix(List(item), filteredSeq)._2
          candidateSuffix match {
-            case suffix if !suffix.isEmpty => Some((List(item), suffix))
+            case suffix if suffix.nonEmpty => Some((List(item), suffix))
             case _ => None
           }
         }
       }
@@ -133,11 +131,13 @@ class PrefixSpan private (
     var resultsAccumulator = freqItemCounts.map(x => (List(x._1), x._2))
 
-    // Remaining work to be locally and distributively processed respectfully
-    var (pairsForLocal, pairsForDistributed) = partitionByProjDBSize(itemSuffixPairs)
+    // Remaining work to be locally and distributively processed, respectively
+    var (pairsForLocal, pairsForDistributed) = partitionByProjDBSize(prefixSuffixPairs)
 
     // Continue processing until no pairs for distributed processing remain (i.e. all prefixes have
     // projected database sizes <= `maxLocalProjDBSize`)
-    while (pairsForDistributed.count() != 0) {
+    var patternLength = 1
+    while (pairsForDistributed.count() != 0 && patternLength < maxPatternLength) {
+      patternLength += 1
       val (nextPatternAndCounts, nextPrefixSuffixPairs) =
         extendPrefixes(minCount, pairsForDistributed)
       pairsForDistributed.unpersist()
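Note: the loop above alternates between distributed prefix extension and local
processing; partitionByProjDBSize (referenced above but not shown in this diff)
routes a prefix to local processing once its projected database fits under
maxLocalProjDBSize. A rough, hypothetical sketch of that size-based split,
simplified to plain collections (the real method operates on RDDs):

    // Hypothetical sketch (not the actual implementation): split
    // (prefix -> suffix) pairs by total projected database size per prefix.
    def partitionBySize[K](
        pairs: Seq[(K, Array[Int])],
        maxSize: Long): (Seq[(K, Array[Int])], Seq[(K, Array[Int])]) = {
      // Total size of each prefix's projected database, counted in items.
      val sizes = pairs
        .groupBy(_._1)
        .mapValues(_.map(_._2.length.toLong).sum)
      // Small projected databases go to local processing, large ones stay distributed.
      pairs.partition { case (prefix, _) => sizes(prefix) <= maxSize }
    }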
@@ -153,7 +153,7 @@ class PrefixSpan private (
       minCount, sc.parallelize(pairsForLocal, 1).groupByKey())
 
     (sc.parallelize(resultsAccumulator, 1) ++ remainingResults)
-      .map { case (pattern, count) => (pattern.toArray, count) }
+      .map { case (pattern, count) => (pattern.reverse.toArray, count) }
   }
 
@@ -195,36 +195,41 @@ class PrefixSpan private (
     // (length N prefix, item from suffix) pairs and their corresponding number of occurrences
     // Every (prefix :+ suffix) is guaranteed to have support exceeding `minSupport`
     val prefixItemPairAndCounts = prefixSuffixPairs
-      .flatMap { case (prefix, suffix) => suffix.distinct.map(y => ((prefix, y), 1L)) }
+      .flatMap { case (prefix, suffix) =>
+        suffix.distinct.filter(item => item != -1 && item != -3).map(y => ((prefix, y), 1L)) }
       .reduceByKey(_ + _)
       .filter(_._2 >= minCount)
 
-    // Map from prefix to set of possible next items from suffix
-    val prefixToNextItems = prefixItemPairAndCounts
+    // Map from prefix to set of possible next prefixes from suffix
+    val prefixToNextPrefixes = prefixItemPairAndCounts
       .keys
       .groupByKey()
-      .mapValues(_.toSet)
+      .map { case (prefix, items) =>
+        (prefix, items.flatMap(item => Array(item :: (-1 :: prefix), item :: prefix)).toSet) }
       .collect()
       .toMap
-
-    // Frequent patterns with length N+1 and their corresponding counts
-    val extendedPrefixAndCounts = prefixItemPairAndCounts
-      .map { case ((prefix, item), count) => (item :: prefix, count) }
 
     // Remaining work, all prefixes will have length N+1
-    val extendedPrefixAndSuffix = prefixSuffixPairs
-      .filter(x => prefixToNextItems.contains(x._1))
+    val extendedPrefixAndSuffixWithFlags = prefixSuffixPairs
       .flatMap { case (prefix, suffix) =>
-        val frequentNextItems = prefixToNextItems(prefix)
-        val filteredSuffix = suffix.filter(frequentNextItems.contains(_))
-        frequentNextItems.flatMap { item =>
-          LocalPrefixSpan.getSuffix(item, filteredSuffix) match {
-            case suffix if !suffix.isEmpty => Some(item :: prefix, suffix)
-            case _ => None
+        if (prefixToNextPrefixes.contains(prefix)) {
+          val frequentNextPrefixes = prefixToNextPrefixes(prefix)
+          frequentNextPrefixes.map { nextPrefix =>
+            val suffixWithFlag = LocalPrefixSpan.getSuffix(nextPrefix, suffix)
+            (nextPrefix, if (suffixWithFlag._1) 1L else 0L, suffixWithFlag._2)
           }
+        } else {
+          None
         }
-      }
+      }.persist(StorageLevel.MEMORY_AND_DISK)
+
+    val extendedPrefixAndCounts = extendedPrefixAndSuffixWithFlags
+      .map(x => (x._1, x._2))
+      .reduceByKey(_ + _)
+      .filter(_._2 >= minCount)
+
+    val extendedPrefixAndSuffix = extendedPrefixAndSuffixWithFlags
+      .map(x => (x._1, x._3))
+      .filter(_._2.nonEmpty)
+
+    extendedPrefixAndSuffixWithFlags.unpersist()
 
     (extendedPrefixAndCounts, extendedPrefixAndSuffix)
   }
@@ -240,9 +245,9 @@ class PrefixSpan private (
       data: RDD[(List[Int], Iterable[Array[Int]])]): RDD[(List[Int], Long)] = {
     data.flatMap { case (prefix, projDB) =>
-      LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList.reverse, projDB)
+      LocalPrefixSpan.run(minCount, maxPatternLength, prefix, projDB)
         .map { case (pattern: List[Int], count: Long) =>
-          (pattern.reverse, count)
+          (pattern, count)
         }
     }
   }
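Note: a minimal end-to-end usage sketch under the new encoding (assumes a live
SparkContext `sc`; the data here is made up, and items within an element are kept
in ascending order, which is what the element-matching code above appears to
expect). `run` at this point returns an RDD of (pattern, count) pairs:

    import org.apache.spark.mllib.fpm.PrefixSpan

    val sequences = sc.parallelize(Seq(
      Array(1, 2, -1, 3),       // <(1 2)(3)>
      Array(1, -1, 2, 3),       // <(1)(2 3)>
      Array(1, 2, -1, 2, 3)))   // <(1 2)(2 3)>

    val prefixSpan = new PrefixSpan()
      .setMinSupport(0.5)
      .setMaxPatternLength(5)
    prefixSpan.run(sequences).collect().foreach { case (pattern, count) =>
      println(pattern.mkString("[", ", ", "]") + s": $count")
    }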
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
index 6dd2dc926acc5..2b7d779bd5ddb 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
 
 class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
 
-  test("PrefixSpan using Integer type") {
+  test("PrefixSpan using Integer type (An element contains only one item)") {
 
     /*
       library("arulesSequences")
@@ -35,12 +35,12 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
      */
 
     val sequences = Array(
-      Array(1, 3, 4, 5),
-      Array(2, 3, 1),
-      Array(2, 4, 1),
-      Array(3, 1, 3, 4, 5),
-      Array(3, 4, 4, 3),
-      Array(6, 5, 3))
+      Array(1, -1, 3, -1, 4, -1, 5),
+      Array(2, -1, 3, -1, 1),
+      Array(2, -1, 4, -1, 1),
+      Array(3, -1, 1, -1, 3, -1, 4, -1, 5),
+      Array(3, -1, 4, -1, 4, -1, 3),
+      Array(6, -1, 5, -1, 3))
 
     val rdd = sc.parallelize(sequences, 2).cache()
 
@@ -50,23 +50,23 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
     val result1 = prefixspan.run(rdd)
     val expectedValue1 = Array(
       (Array(1), 4L),
-      (Array(1, 3), 2L),
-      (Array(1, 3, 4), 2L),
-      (Array(1, 3, 4, 5), 2L),
-      (Array(1, 3, 5), 2L),
-      (Array(1, 4), 2L),
-      (Array(1, 4, 5), 2L),
-      (Array(1, 5), 2L),
+      (Array(1, -1, 3), 2L),
+      (Array(1, -1, 3, -1, 4), 2L),
+      (Array(1, -1, 3, -1, 4, -1, 5), 2L),
+      (Array(1, -1, 3, -1, 5), 2L),
+      (Array(1, -1, 4), 2L),
+      (Array(1, -1, 4, -1, 5), 2L),
+      (Array(1, -1, 5), 2L),
       (Array(2), 2L),
-      (Array(2, 1), 2L),
+      (Array(2, -1, 1), 2L),
       (Array(3), 5L),
-      (Array(3, 1), 2L),
-      (Array(3, 3), 2L),
-      (Array(3, 4), 3L),
-      (Array(3, 4, 5), 2L),
-      (Array(3, 5), 2L),
+      (Array(3, -1, 1), 2L),
+      (Array(3, -1, 3), 2L),
+      (Array(3, -1, 4), 3L),
+      (Array(3, -1, 4, -1, 5), 2L),
+      (Array(3, -1, 5), 2L),
       (Array(4), 4L),
-      (Array(4, 5), 2L),
+      (Array(4, -1, 5), 2L),
       (Array(5), 3L)
     )
     assert(compareResults(expectedValue1, result1.collect()))
@@ -76,7 +76,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
     val expectedValue2 = Array(
       (Array(1), 4L),
       (Array(3), 5L),
-      (Array(3, 4), 3L),
+      (Array(3, -1, 4), 3L),
       (Array(4), 4L),
       (Array(5), 3L)
     )
@@ -86,23 +86,89 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
     val result3 = prefixspan.run(rdd)
     val expectedValue3 = Array(
       (Array(1), 4L),
-      (Array(1, 3), 2L),
-      (Array(1, 4), 2L),
-      (Array(1, 5), 2L),
-      (Array(2, 1), 2L),
+      (Array(1, -1, 3), 2L),
+      (Array(1, -1, 4), 2L),
+      (Array(1, -1, 5), 2L),
+      (Array(2, -1, 1), 2L),
       (Array(2), 2L),
       (Array(3), 5L),
-      (Array(3, 1), 2L),
-      (Array(3, 3), 2L),
-      (Array(3, 4), 3L),
-      (Array(3, 5), 2L),
+      (Array(3, -1, 1), 2L),
+      (Array(3, -1, 3), 2L),
+      (Array(3, -1, 4), 3L),
+      (Array(3, -1, 5), 2L),
       (Array(4), 4L),
-      (Array(4, 5), 2L),
+      (Array(4, -1, 5), 2L),
       (Array(5), 3L)
     )
     assert(compareResults(expectedValue3, result3.collect()))
   }
 
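Note: decoded with the -1 delimiter convention, the four sequences in the new test
below read as follows; this appears to be the example database from Pei et al.'s
PrefixSpan paper with items a..g renamed to 1..7:

    Array(1, -1, 1, 2, 3, -1, 1, 3, -1, 4, -1, 3, 6)  ==  <(1)(1 2 3)(1 3)(4)(3 6)>
    Array(1, 4, -1, 3, -1, 2, 3, -1, 1, 5)            ==  <(1 4)(3)(2 3)(1 5)>
    Array(5, 6, -1, 1, 2, -1, 4, 6, -1, 3, -1, 2)     ==  <(5 6)(1 2)(4 6)(3)(2)>
    Array(5, -1, 7, -1, 1, 6, -1, 3, -1, 2, -1, 3)    ==  <(5)(7)(1 6)(3)(2)(3)>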
+  test("PrefixSpan using Integer type (An element contains multiple items)") {
+    val sequences = Array(
+      Array(1, -1, 1, 2, 3, -1, 1, 3, -1, 4, -1, 3, 6),
+      Array(1, 4, -1, 3, -1, 2, 3, -1, 1, 5),
+      Array(5, 6, -1, 1, 2, -1, 4, 6, -1, 3, -1, 2),
+      Array(5, -1, 7, -1, 1, 6, -1, 3, -1, 2, -1, 3))
+    val rdd = sc.parallelize(sequences, 2).cache()
+    val prefixspan = new PrefixSpan().setMinSupport(0.5).setMaxPatternLength(5)
+    val result = prefixspan.run(rdd)
+    val expectedValue = Array(
+      (Array(1), 4L),
+      (Array(1, -1, 6), 2L),
+      (Array(1, -1, 1), 2L),
+      (Array(1, -1, 3), 4L),
+      (Array(1, -1, 3, -1, 1), 2L),
+      (Array(1, -1, 3, -1, 3), 3L),
+      (Array(1, -1, 3, -1, 2), 3L),
+      (Array(1, -1, 2), 4L),
+      (Array(1, -1, 2, -1, 1), 2L),
+      (Array(1, -1, 2, -1, 3), 2L),
+      (Array(1, -1, 2, 3), 2L),
+      (Array(1, -1, 2, 3, -1, 1), 2L),
+      (Array(1, 2), 2L),
+      (Array(1, 2, -1, 6), 2L),
+      (Array(1, 2, -1, 3), 2L),
+      (Array(1, 2, -1, 4), 2L),
+      (Array(1, 2, -1, 4, -1, 3), 2L),
+      (Array(1, -1, 4), 2L),
+      (Array(1, -1, 4, -1, 3), 2L),
+      (Array(2), 4L),
+      (Array(2, -1, 6), 2L),
+      (Array(2, -1, 1), 2L),
+      (Array(2, -1, 3), 3L),
+      (Array(2, 3), 2L),
+      (Array(2, 3, -1, 1), 2L),
+      (Array(2, -1, 4), 2L),
+      (Array(2, -1, 4, -1, 3), 2L),
+      (Array(3), 4L),
+      (Array(3, -1, 1), 2L),
+      (Array(3, -1, 3), 3L),
+      (Array(3, -1, 2), 3L),
+      (Array(4), 3L),
+      (Array(4, -1, 3), 3L),
+      (Array(4, -1, 3, -1, 2), 2L),
+      (Array(4, -1, 2), 2L),
+      (Array(5), 3L),
+      (Array(5, -1, 6), 2L),
+      (Array(5, -1, 6, -1, 3), 2L),
+      (Array(5, -1, 6, -1, 3, -1, 2), 2L),
+      (Array(5, -1, 6, -1, 2), 2L),
+      (Array(5, -1, 1), 2L),
+      (Array(5, -1, 1, -1, 3), 2L),
+      (Array(5, -1, 1, -1, 3, -1, 2), 2L),
+      (Array(5, -1, 1, -1, 2), 2L),
+      (Array(5, -1, 3), 2L),
+      (Array(5, -1, 3, -1, 2), 2L),
+      (Array(5, -1, 2), 2L),
+      (Array(5, -1, 2, -1, 3), 2L),
+      (Array(6), 3L),
+      (Array(6, -1, 3), 2L),
+      (Array(6, -1, 3, -1, 2), 2L),
+      (Array(6, -1, 2), 2L),
+      (Array(6, -1, 2, -1, 3), 2L))
+    assert(compareResults(expectedValue, result.collect()))
+  }
+
   private def compareResults(
       expectedValue: Array[(Array[Int], Long)],
       actualValue: Array[(Array[Int], Long)]): Boolean = {