Skip to content

Commit

Permalink
[SPARK-20114][ML] spark.ml parity for sequential pattern mining - Pre…
Browse files Browse the repository at this point in the history
…fixSpan

## What changes were proposed in this pull request?

PrefixSpan API for spark.ml. New implementation instead of #20810

## How was this patch tested?

TestSuite added.

Author: WeichenXu <weichen.xu@databricks.com>

Closes #20973 from WeichenXu123/prefixSpan2.
  • Loading branch information
WeichenXu123 authored and jkbradley committed May 7, 2018
1 parent f48bd6b commit 76ecd09
Show file tree
Hide file tree
Showing 3 changed files with 233 additions and 2 deletions.
96 changes: 96 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala
@@ -0,0 +1,96 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.fpm

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.mllib.fpm.{PrefixSpan => mllibPrefixSpan}
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{ArrayType, LongType, StructField, StructType}

/**
* :: Experimental ::
* A parallel PrefixSpan algorithm to mine frequent sequential patterns.
* The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns
* Efficiently by Prefix-Projected Pattern Growth
* (see <a href="http://doi.org/10.1109/ICDE.2001.914830">here</a>).
*
* @see <a href="https://en.wikipedia.org/wiki/Sequential_Pattern_Mining">Sequential Pattern Mining
* (Wikipedia)</a>
*/
@Since("2.4.0")
@Experimental
object PrefixSpan {

  /**
   * :: Experimental ::
   * Finds the complete set of frequent sequential patterns in the input sequences of itemsets.
   *
   * The heavy lifting is delegated to the RDD-based `org.apache.spark.mllib.fpm.PrefixSpan`;
   * this method only validates the input column, converts rows to nested arrays and wraps the
   * result back into a `DataFrame`.
   *
   * @param dataset A dataset or a dataframe containing a sequence column which is
   *                {{{Seq[Seq[_]]}}} type
   * @param sequenceCol the name of the sequence column in dataset, rows with nulls in this column
   *                    are ignored
   * @param minSupport the minimal support level of the sequential pattern, any pattern that
   *                   appears at least (minSupport * size-of-the-dataset) times will be output,
   *                   where size-of-the-dataset counts only the rows whose sequence is not null
   *                   (recommended value: `0.1`).
   * @param maxPatternLength the maximal length of the sequential pattern
   *                         (recommended value: `10`).
   * @param maxLocalProjDBSize The maximum number of items (including delimiters used in the
   *                           internal storage format) allowed in a projected database before
   *                           local processing. If a projected database exceeds this size, another
   *                           iteration of distributed prefix growth is run
   *                           (recommended value: `32000000`).
   * @return A `DataFrame` that contains columns of sequence and corresponding frequency.
   *         The schema of it will be:
   *          - `sequence: Seq[Seq[T]]` (T is the item type)
   *          - `freq: Long`
   * @throws IllegalArgumentException if `sequenceCol` is not an array-of-arrays column
   */
  @Since("2.4.0")
  def findFrequentSequentialPatterns(
      dataset: Dataset[_],
      sequenceCol: String,
      minSupport: Double,
      maxPatternLength: Int,
      maxLocalProjDBSize: Long): DataFrame = {

    // Looked up once: used both for validation and as the type of the output "sequence" column.
    val inputType = dataset.schema(sequenceCol).dataType
    val isNestedArray = inputType match {
      case ArrayType(_: ArrayType, _) => true
      case _ => false
    }
    require(isNestedArray,
      "The input column must be ArrayType and the array element type must also be ArrayType, " +
        s"but got $inputType.")

    // Rows with a null sequence are dropped before mining, so they do not count towards the
    // dataset size that minSupport is multiplied by.
    val sequences = dataset
      .select(sequenceCol)
      .where(col(sequenceCol).isNotNull)
      .rdd
      .map(row => row.getAs[Seq[Seq[Any]]](0).map(_.toArray).toArray)

    // Named differently from the `mllibPrefixSpan` import alias to avoid shadowing confusion.
    val miner = new mllibPrefixSpan()
      .setMinSupport(minSupport)
      .setMaxPatternLength(maxPatternLength)
      .setMaxLocalProjDBSize(maxLocalProjDBSize)

    val rows = miner.run(sequences).freqSequences.map(f => Row(f.sequence, f.freq))
    val schema = StructType(Seq(
      StructField("sequence", inputType, nullable = false),
      StructField("freq", LongType, nullable = false)))

    dataset.sparkSession.createDataFrame(rows, schema)
  }

}
Expand Up @@ -49,8 +49,7 @@ import org.apache.spark.storage.StorageLevel
*
* @param minSupport the minimal support level of the sequential pattern, any pattern that appears
* more than (minSupport * size-of-the-dataset) times will be output
* @param maxPatternLength the maximal length of the sequential pattern, any pattern that appears
* less than maxPatternLength will be output
* @param maxPatternLength the maximal length of the sequential pattern
* @param maxLocalProjDBSize The maximum number of items (including delimiters used in the internal
* storage format) allowed in a projected database before local
* processing. If a projected database exceeds this size, another
Expand Down
136 changes: 136 additions & 0 deletions mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala
@@ -0,0 +1,136 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.fpm

import org.apache.spark.ml.util.MLTest
import org.apache.spark.sql.DataFrame

/**
 * Tests for the DataFrame-based `PrefixSpan.findFrequentSequentialPatterns`.
 *
 * Results are compared as sets, ignoring both the order of the returned patterns and the order
 * of items inside each itemset.
 */
class PrefixSpanSuite extends MLTest {

  import testImplicits._

  test("PrefixSpan projections with multiple partial starts") {
    val smallDataset = Seq(Seq(Seq(1, 2), Seq(1, 2, 3))).toDF("sequence")
    val result = minePatterns(smallDataset, minSupport = 1.0, maxPatternLength = 2)
      .as[(Seq[Seq[Int]], Long)].collect()
    val expected = Array(
      (Seq(Seq(1)), 1L),
      (Seq(Seq(1, 2)), 1L),
      (Seq(Seq(1), Seq(1)), 1L),
      (Seq(Seq(1), Seq(2)), 1L),
      (Seq(Seq(1), Seq(3)), 1L),
      (Seq(Seq(1, 3)), 1L),
      (Seq(Seq(2)), 1L),
      (Seq(Seq(2, 3)), 1L),
      (Seq(Seq(2), Seq(1)), 1L),
      (Seq(Seq(2), Seq(2)), 1L),
      (Seq(Seq(2), Seq(3)), 1L),
      (Seq(Seq(3)), 1L))
    compareResults[Int](expected, result)
  }

  /*
    To verify expected results for `smallTestData`, create file "prefixSpanSeqs2" with content
    (format = (transactionID, idxInTransaction, numItemsinItemset, itemset)):
      1 1 2 1 2
      1 2 1 3
      2 1 1 1
      2 2 2 3 2
      2 3 2 1 2
      3 1 2 1 2
      3 2 1 5
      4 1 1 6
    In R, run:
      library("arulesSequences")
      prefixSpanSeqs = read_baskets("prefixSpanSeqs2", info = c("sequenceID","eventID","SIZE"))
      freqItemSeq = cspade(prefixSpanSeqs,
                           parameter = list(support = 0.5, maxlen = 5))
      resSeq = as(freqItemSeq, "data.frame")
      resSeq

         sequence support
      1     <{1}>    0.75
      2     <{2}>    0.75
      3     <{3}>    0.50
      4 <{1},{3}>    0.50
      5   <{1,2}>    0.75
   */
  val smallTestData = Seq(
    Seq(Seq(1, 2), Seq(3)),
    Seq(Seq(1), Seq(3, 2), Seq(1, 2)),
    Seq(Seq(1, 2), Seq(5)),
    Seq(Seq(6)))

  // Supports from the R output above, expressed as absolute counts over the 4 sequences.
  val smallTestDataExpectedResult = Array(
    (Seq(Seq(1)), 3L),
    (Seq(Seq(2)), 3L),
    (Seq(Seq(3)), 2L),
    (Seq(Seq(1), Seq(3)), 2L),
    (Seq(Seq(1, 2)), 3L)
  )

  test("PrefixSpan Integer type, variable-size itemsets") {
    val df = smallTestData.toDF("sequence")
    val result = minePatterns(df, minSupport = 0.5, maxPatternLength = 5)
      .as[(Seq[Seq[Int]], Long)].collect()

    compareResults[Int](smallTestDataExpectedResult, result)
  }

  test("PrefixSpan input row with nulls") {
    // The extra null row must be ignored: the expected counts are those of the 4 real sequences.
    val df = (smallTestData :+ null).toDF("sequence")
    val result = minePatterns(df, minSupport = 0.5, maxPatternLength = 5)
      .as[(Seq[Seq[Int]], Long)].collect()

    compareResults[Int](smallTestDataExpectedResult, result)
  }

  test("PrefixSpan String type, variable-size itemsets") {
    val intToString = (1 to 6).zip(Seq("a", "b", "c", "d", "e", "f")).toMap
    val df = smallTestData
      .map(seq => seq.map(itemSet => itemSet.map(intToString)))
      .toDF("sequence")
    val result = minePatterns(df, minSupport = 0.5, maxPatternLength = 5)
      .as[(Seq[Seq[String]], Long)].collect()

    val expected = smallTestDataExpectedResult.map { case (seq, freq) =>
      (seq.map(itemSet => itemSet.map(intToString)), freq)
    }
    compareResults[String](expected, result)
  }

  /**
   * Runs PrefixSpan on the "sequence" column of `df` with the projected-database size limit
   * shared by every test in this suite.
   */
  private def minePatterns(
      df: DataFrame,
      minSupport: Double,
      maxPatternLength: Int): DataFrame = {
    PrefixSpan.findFrequentSequentialPatterns(df, "sequence",
      minSupport = minSupport,
      maxPatternLength = maxPatternLength,
      maxLocalProjDBSize = 32000000L)
  }

  /**
   * Asserts that two pattern collections are equal, ignoring the order of the patterns and the
   * order of items within each itemset.
   */
  private def compareResults[Item](
      expectedValue: Array[(Seq[Seq[Item]], Long)],
      actualValue: Array[(Seq[Seq[Item]], Long)]): Unit = {
    // Itemsets become Sets and the whole result becomes a Set so that ordering is irrelevant.
    def normalize(patterns: Array[(Seq[Seq[Item]], Long)]): Set[(Seq[Set[Item]], Long)] =
      patterns.map { case (sequence, freq) => (sequence.map(_.toSet), freq) }.toSet

    val expectedSet = normalize(expectedValue)
    val actualSet = normalize(actualValue)
    assert(expectedSet === actualSet)
  }
}

0 comments on commit 76ecd09

Please sign in to comment.