Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-8999][MLlib]Support non-temporal sequence in PrefixSpan (Array[Int]) #7646

Closed
wants to merge 14 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 135 additions & 23 deletions mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,54 +41,166 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
maxPatternLength: Int,
prefixes: List[Int],
database: Iterable[Array[Int]]): Iterator[(List[Int], Long)] = {
if (prefixes.length == maxPatternLength || database.isEmpty) return Iterator.empty
val frequentItemAndCounts = getFreqItemAndCounts(minCount, database)
val filteredDatabase = database.map(x => x.filter(frequentItemAndCounts.contains))
frequentItemAndCounts.iterator.flatMap { case (item, count) =>
val newPrefixes = item :: prefixes
val newProjected = project(filteredDatabase, item)
Iterator.single((newPrefixes, count)) ++
run(minCount, maxPatternLength, newPrefixes, newProjected)
if (prefixes.count(_ != -1) == maxPatternLength || database.isEmpty) return Iterator.empty
val frequentPrefixAndCounts = getFreqPrefixAndCounts(minCount, prefixes, database)
frequentPrefixAndCounts.iterator.flatMap { case (prefix, count) =>
val newProjected = project(database, prefix)
Iterator.single((prefix, count)) ++
run(minCount, maxPatternLength, prefix, newProjected)
}
}

/**
* Calculate suffix sequence immediately after the first occurrence of an item.
* @param item item to get suffix after
* Calculate suffix sequence immediately after the first occurrence of a prefix.
* @param prefix prefix to get suffix after
* @param sequence sequence to extract suffix from
* @return suffix sequence
*/
def getSuffix(item: Int, sequence: Array[Int]): Array[Int] = {
val index = sequence.indexOf(item)
def getSuffix(prefix: List[Int], sequence: Array[Int]): (Boolean, Array[Int]) = {
  // Only the last element of the prefix matters for matching; the rest has
  // already been consumed by earlier projections.
  val element = getLastElement(prefix)
  // A leading -3 marks a sequence whose first element was only partially
  // consumed by the previous projection (see getSingleItemElementSuffix /
  // getMultiItemsElementSuffix, which emit that sentinel).
  if (sequence.apply(0) != -3) {
    if (element.length == 1) {
      // Single-item element: simple first-occurrence scan.
      getSingleItemElementSuffix(element, sequence)
    } else {
      // Multi-item element: ordered containment scan within one element.
      getMultiItemsElementSuffix(element, sequence)
    }
  } else {
    if (element.length == 1) {
      // A single-item element cannot match inside the partially consumed first
      // element, so skip ahead to the first -1 delimiter (element boundary).
      val firstElemPos = sequence.indexOf(-1)
      if (firstElemPos == -1) {
        // No further complete element exists: no match, empty suffix.
        (false, Array())
      } else {
        getSingleItemElementSuffix(element, sequence.drop(firstElemPos + 1))
      }
    } else {
      // Re-attach the already-matched items of the element (all but its last)
      // in front of the partial first element, then run the multi-item scan.
      val newSequence = element.take(element.length - 1) ++ sequence.drop(1)
      getMultiItemsElementSuffix(element, newSequence)
    }
  }
}

/**
 * Extracts the last element of a prefix: the items that precede the first -1
 * delimiter, restored to original order (the prefix list appears to be kept in
 * reverse, newest item first — it is built with `::` and reversed on output).
 */
private def getLastElement(prefix: List[Int]): Array[Int] = {
  // A missing -1 separator means the whole prefix is one element;
  // takeWhile covers both the delimited and undelimited cases.
  prefix.takeWhile(_ != -1).reverse.toArray
}

/**
 * Projects `sequence` past the first occurrence of a single-item element.
 *
 * Fix: removed a stray `Array()` statement that was dead-expression residue
 * from the pre-change version of this method (it appeared inside the first
 * branch, before the real `(false, Array())` result).
 *
 * @param element single-item element to search for (only index 0 is used)
 * @param sequence sequence to project
 * @return (matched?, suffix); the suffix starts with the -3 sentinel when the
 *         match ended mid-element (i.e. not followed by a -1 delimiter)
 */
private def getSingleItemElementSuffix(
    element: Array[Int],
    sequence: Array[Int]): (Boolean, Array[Int]) = {
  val index = sequence.indexOf(element.apply(0))
  if (index == -1) {
    // Item absent: no match, empty suffix.
    (false, Array())
  } else if (index == sequence.length - 1) {
    // Item is the last entry: match found, suffix is empty.
    (true, Array())
  } else if (sequence.apply(index + 1) == -1) {
    // Item closes its element: suffix begins at the next element.
    (true, sequence.drop(index + 2))
  } else {
    // Match ended inside an element: prepend the -3 sentinel so the caller
    // knows the first element of the suffix is partially consumed.
    (true, -3 +: sequence.drop(index + 1))
  }
}

/**
 * Projects `sequence` past the first element that contains all items of
 * `element` in order (an ordered containment scan restarted at each -1
 * element boundary).
 *
 * Fixes:
 *  1. Removed the stray `sequence.drop(index + 1)` line in the not-found
 *     branch — residue from the pre-change diff that referenced an undefined
 *     `index` and would not compile.
 *  2. Guarded `sequence.apply(seqPos)` when the match ends exactly at the end
 *     of the sequence (seqPos == sequence.length), which previously threw
 *     ArrayIndexOutOfBoundsException; it now returns (true, Array()) —
 *     consistent with getSingleItemElementSuffix's end-of-sequence case.
 *
 * @param element multi-item element to locate (items in order)
 * @param sequence sequence to project
 * @return (matched?, suffix); the suffix starts with the -3 sentinel when the
 *         match ended mid-element
 */
private def getMultiItemsElementSuffix(
    element: Array[Int],
    sequence: Array[Int]): (Boolean, Array[Int]) = {
  var seqPos = 0
  var found = false
  while (seqPos < sequence.length && !found) {
    // Scan one element of the sequence (up to the next -1) for the items of
    // `element`, in order; non-matching items inside the element are skipped.
    var elemPos = 0
    while (!found && elemPos < element.length &&
      seqPos < sequence.length && sequence.apply(seqPos) != -1) {
      if (element.apply(elemPos) == sequence.apply(seqPos)) {
        elemPos += 1
        seqPos += 1
      } else {
        seqPos += 1
      }
      found = elemPos == element.length
    }
    // Element exhausted without a full match: step past the -1 delimiter and
    // restart the scan in the next element.
    if (!found) seqPos += 1
  }
  if (found) {
    if (seqPos >= sequence.length) {
      // Match ends exactly at the end of the sequence: empty suffix.
      (true, Array())
    } else if (sequence.apply(seqPos) == -1) {
      // Match closes its element: suffix begins at the next element.
      (true, sequence.drop(seqPos + 1))
    } else {
      // Match ended mid-element: flag the suffix with the -3 sentinel.
      (true, -3 +: sequence.drop(seqPos))
    }
  } else {
    (false, Array())
  }
}

/**
 * Projects every sequence of the database on `prefix`, keeping only non-empty
 * suffixes.
 *
 * Fix: removed the interleaved pre-change diff line
 * `.map(getSuffix(prefix, _))` — with both map lines present the second map
 * was applied to the (Boolean, Array[Int]) tuples produced by the first and
 * the expression did not type-check.
 */
def project(database: Iterable[Array[Int]], prefix: List[Int]): Iterable[Array[Int]] = {
  database
    .map(getSuffix(prefix, _)._2) // discard the matched? flag; keep the suffix
    .filter(_.nonEmpty)
}

/**
* Generates frequent items by filtering the input data using minimal count level.
* Generates frequent prefixes by filtering the input data using the minimal count level.
* @param minCount the minimum count for an item to be frequent
* @param prefix the current prefix to be extended with frequent items
* @param database database of sequences
* @return freq item to count map
*/
/**
 * Finds all frequent one-item extensions of `prefix` together with their
 * support counts.
 *
 * Fix: removed interleaved residue of the pre-change implementation that made
 * the method invalid — the old method name line, the old signature line
 * (`database: ...): mutable.Map[Int, Long] = {`), the old
 * `val counts = mutable.Map[Int, Long]()...` declaration, the old inner
 * `sequence.distinct.foreach { item => counts(item) += 1L` loop (which left
 * braces unbalanced), and the trailing old `counts.filter(_._2 >= minCount)`.
 *
 * @param minCount the minimum count for an extension to be frequent
 * @param prefix the current prefix to be extended (reversed, -1-delimited)
 * @param database database of projected sequences
 * @return map from extended prefix to its count; an extension either starts a
 *         new element (item :: -1 :: prefix) or joins the last element of the
 *         prefix (item :: prefix)
 */
private def getFreqPrefixAndCounts(
    minCount: Long,
    prefix: List[Int],
    database: Iterable[Array[Int]]): mutable.Map[List[Int], Long] = {
  // TODO: use PrimitiveKeyOpenHashMap
  // Candidate items: frequent across the database, ignoring the -1 element
  // delimiter and the -3 partial-element sentinel.
  val freqItems = database
    .flatMap(_.distinct.filter(x => x != -1 && x != -3))
    .groupBy(x => x)
    .mapValues(_.size)
    .filter(_._2 >= minCount)
    .map(_._1)
  if (freqItems.isEmpty) return mutable.Map[List[Int], Long]()

  // Count, per item: support as a new single-item element, and support as an
  // additional item inside the prefix's last element.
  val singleItemCounts = mutable.Map[Int, Long]().withDefaultValue(0L)
  val multiItemsCounts = mutable.Map[Int, Long]().withDefaultValue(0L)
  val prefixLastElement = getLastElement(prefix)
  database.foreach { sequence =>
    if (sequence.apply(0) != -3) {
      freqItems.foreach { item =>
        if (getSingleItemElementSuffix(Array(item), sequence)._1) {
          singleItemCounts(item) += 1
        }
        if (prefixLastElement.nonEmpty &&
          getMultiItemsElementSuffix(prefixLastElement :+ item, sequence)._1) {
          multiItemsCounts(item) += 1
        }
      }
    } else {
      // Sequence starts with a partially consumed element: a brand-new element
      // can only start after the first -1 delimiter ...
      val firstElemPos = sequence.indexOf(-1)
      if (firstElemPos != -1) {
        val newSequence = sequence.drop(firstElemPos + 1)
        freqItems.foreach { item =>
          if (getSingleItemElementSuffix(Array(item), newSequence)._1) {
            singleItemCounts(item) += 1
          }
        }
      }
      // ... while an in-element extension is matched against the partial first
      // element with the already-consumed items of the prefix re-attached.
      val newSequence = prefixLastElement ++ sequence.drop(1)
      freqItems.foreach { item =>
        if (prefixLastElement.nonEmpty &&
          getMultiItemsElementSuffix(prefixLastElement :+ item, newSequence)._1) {
          multiItemsCounts(item) += 1
        }
      }
    }
  }

  if (prefix.nonEmpty) {
    // New-element extensions are separated from the prefix by -1; in-element
    // extensions join the prefix's last element directly.
    singleItemCounts.filter(_._2 >= minCount).map(x => (x._1 :: (-1 :: prefix), x._2)) ++
      multiItemsCounts.filter(_._2 >= minCount).map(x => (x._1 :: prefix, x._2))
  } else {
    // Empty prefix: only length-1 prefixes are possible.
    singleItemCounts.filter(_._2 >= minCount).map(x => (List(x._1), x._2))
  }
}
}
67 changes: 36 additions & 31 deletions mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

package org.apache.spark.mllib.fpm

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD
Expand Down Expand Up @@ -50,7 +48,7 @@ class PrefixSpan private (
* projected database exceeds this size, another iteration of distributed PrefixSpan is run.
*/
// TODO: make configurable with a better default value, 10000 may be too small
private val maxLocalProjDBSize: Long = 10000
private val maxLocalProjDBSize: Long = 32000000L

/**
* Constructs a default instance with default parameters
Expand Down Expand Up @@ -108,20 +106,20 @@ class PrefixSpan private (

// (Frequent items -> number of occurrences, all items here satisfy the `minSupport` threshold
val freqItemCounts = sequences
.flatMap(seq => seq.distinct.map(item => (item, 1L)))
.flatMap(seq => seq.distinct.filter(_ != -1).map(item => (item, 1L)))
.reduceByKey(_ + _)
.filter(_._2 >= minCount)
.collect()

// Pairs of (length 1 prefix, suffix consisting of frequent items)
val itemSuffixPairs = {
val prefixSuffixPairs = {
val freqItems = freqItemCounts.map(_._1).toSet
sequences.flatMap { seq =>
val filteredSeq = seq.filter(freqItems.contains(_))
val filteredSeq = seq.filter(item => freqItems.contains(item) || item == -1)
freqItems.flatMap { item =>
val candidateSuffix = LocalPrefixSpan.getSuffix(item, filteredSeq)
val candidateSuffix = LocalPrefixSpan.getSuffix(List(item), filteredSeq)._2
candidateSuffix match {
case suffix if !suffix.isEmpty => Some((List(item), suffix))
case suffix if suffix.nonEmpty => Some((List(item), suffix))
case _ => None
}
}
Expand All @@ -133,11 +131,13 @@ class PrefixSpan private (
var resultsAccumulator = freqItemCounts.map(x => (List(x._1), x._2))

// Remaining work to be locally and distributively processed respectfully
var (pairsForLocal, pairsForDistributed) = partitionByProjDBSize(itemSuffixPairs)
var (pairsForLocal, pairsForDistributed) = partitionByProjDBSize(prefixSuffixPairs)

// Continue processing until no pairs for distributed processing remain (i.e. all prefixes have
// projected database sizes <= `maxLocalProjDBSize`)
while (pairsForDistributed.count() != 0) {
var patternLength = 1
while (pairsForDistributed.count() != 0 && patternLength < maxPatternLength) {
patternLength += 1
val (nextPatternAndCounts, nextPrefixSuffixPairs) =
extendPrefixes(minCount, pairsForDistributed)
pairsForDistributed.unpersist()
Expand All @@ -153,7 +153,7 @@ class PrefixSpan private (
minCount, sc.parallelize(pairsForLocal, 1).groupByKey())

(sc.parallelize(resultsAccumulator, 1) ++ remainingResults)
.map { case (pattern, count) => (pattern.toArray, count) }
.map { case (pattern, count) => (pattern.reverse.toArray, count) }
}


Expand Down Expand Up @@ -195,36 +195,41 @@ class PrefixSpan private (
// (length N prefix, item from suffix) pairs and their corresponding number of occurrences
// Every (prefix :+ suffix) is guaranteed to have support exceeding `minSupport`
val prefixItemPairAndCounts = prefixSuffixPairs
.flatMap { case (prefix, suffix) => suffix.distinct.map(y => ((prefix, y), 1L)) }
.flatMap { case (prefix, suffix) =>
suffix.distinct.filter(item => item != -1 && item != -3).map(y => ((prefix, y), 1L)) }
.reduceByKey(_ + _)
.filter(_._2 >= minCount)

// Map from prefix to set of possible next items from suffix
val prefixToNextItems = prefixItemPairAndCounts
// Map from prefix to the set of possible next prefixes drawn from its suffixes
val prefixToNextPrefixes = prefixItemPairAndCounts
.keys
.groupByKey()
.mapValues(_.toSet)
.map { case (prefix, items) =>
(prefix, items.flatMap(item => Array(item :: (-1 :: prefix), item :: prefix)).toSet) }
.collect()
.toMap


// Frequent patterns with length N+1 and their corresponding counts
val extendedPrefixAndCounts = prefixItemPairAndCounts
.map { case ((prefix, item), count) => (item :: prefix, count) }

// Remaining work, all prefixes will have length N+1
val extendedPrefixAndSuffix = prefixSuffixPairs
.filter(x => prefixToNextItems.contains(x._1))
val extendedPrefixAndSuffixWithFlags = prefixSuffixPairs
.flatMap { case (prefix, suffix) =>
val frequentNextItems = prefixToNextItems(prefix)
val filteredSuffix = suffix.filter(frequentNextItems.contains(_))
frequentNextItems.flatMap { item =>
LocalPrefixSpan.getSuffix(item, filteredSuffix) match {
case suffix if !suffix.isEmpty => Some(item :: prefix, suffix)
case _ => None
if (prefixToNextPrefixes.contains(prefix)) {
val frequentNextPrefixes = prefixToNextPrefixes(prefix)
frequentNextPrefixes.map { nextPrefix =>
val suffixWithFlag = LocalPrefixSpan.getSuffix(nextPrefix, suffix)
(nextPrefix, if (suffixWithFlag._1) 1L else 0L, suffixWithFlag._2)
}
} else {
None
}
}
}.persist(StorageLevel.MEMORY_AND_DISK)
val extendedPrefixAndCounts = extendedPrefixAndSuffixWithFlags
.map(x => (x._1, x._2))
.reduceByKey(_ + _)
.filter(_._2 >= minCount)
val extendedPrefixAndSuffix = extendedPrefixAndSuffixWithFlags
.map(x => (x._1, x._3))
.filter(_._2.nonEmpty)
extendedPrefixAndSuffixWithFlags.unpersist()

(extendedPrefixAndCounts, extendedPrefixAndSuffix)
}
Expand All @@ -240,9 +245,9 @@ class PrefixSpan private (
data: RDD[(List[Int], Iterable[Array[Int]])]): RDD[(List[Int], Long)] = {
data.flatMap {
case (prefix, projDB) =>
LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList.reverse, projDB)
LocalPrefixSpan.run(minCount, maxPatternLength, prefix, projDB)
.map { case (pattern: List[Int], count: Long) =>
(pattern.reverse, count)
(pattern, count)
}
}
}
Expand Down
Loading