From 07327869fd0ea1a9dc4f32da8d10bde2ff231770 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Tue, 22 May 2018 19:52:26 +0800 Subject: [PATCH 1/3] init pr --- .../org/apache/spark/ml/fpm/PrefixSpan.scala | 109 +++++++++++++++--- .../apache/spark/ml/fpm/PrefixSpanSuite.scala | 28 +++-- 2 files changed, 110 insertions(+), 27 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala index 02168fee16dbf..669f39def69bf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala @@ -18,6 +18,8 @@ package org.apache.spark.ml.fpm import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.param._ +import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.fpm.{PrefixSpan => mllibPrefixSpan} import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.col @@ -35,7 +37,87 @@ import org.apache.spark.sql.types.{ArrayType, LongType, StructField, StructType} */ @Since("2.4.0") @Experimental -object PrefixSpan { +final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params { + + @Since("2.4.0") + def this() = this(Identifiable.randomUID("prefixSpan")) + + /** + * the minimal support level of the sequential pattern, any pattern that + * appears more than (minSupport * size-of-the-dataset) times will be output + * (default value: `0.1`). + * @group param + */ + @Since("2.4.0") + val minSupport = new DoubleParam(this, "minSupport", "the minimal support level of the " + + "sequential pattern, any pattern that appears more than (minSupport * size-of-the-dataset) " + + "times will be output", ParamValidators.gt(0.0)) + + /** @group getParam */ + @Since("2.4.0") + def getMinSupport: Double = $(minSupport) + + /** + * Set the minSupport parameter. + * Default is 1.0. 
+ * + * @group setParam + */ + @Since("1.3.0") + def setMinSupport(value: Double): this.type = set(minSupport, value) + + /** + * the maximal length of the sequential pattern + * (default value: `10`). + * @group param + */ + @Since("2.4.0") + val maxPatternLength = new IntParam(this, "maxPatternLength", + "the maximal length of the sequential pattern", + ParamValidators.gt(0)) + + /** @group getParam */ + @Since("2.4.0") + def getMaxPatternLength: Double = $(maxPatternLength) + + /** + * Set the maxPatternLength parameter. + * Default is 10. + * + * @group setParam + */ + @Since("2.4.0") + def setMaxPatternLength(value: Int): this.type = set(maxPatternLength, value) + + /** + * The maximum number of items (including delimiters used in the + * internal storage format) allowed in a projected database before + * local processing. If a projected database exceeds this size, another + * iteration of distributed prefix growth is run + * (default value: `32000000`). + * @group param + */ + @Since("2.4.0") + val maxLocalProjDBSize = new LongParam(this, "maxLocalProjDBSize", + "The maximum number of items (including delimiters used in the internal storage format) " + + "allowed in a projected database before local processing. If a projected database exceeds " + + "this size, another iteration of distributed prefix growth is run", + ParamValidators.gt(0)) + + /** @group getParam */ + @Since("2.4.0") + def getMaxLocalProjDBSize: Double = $(maxLocalProjDBSize) + + /** + * Set the maxLocalProjDBSize parameter. + * Default is 32000000. 
+ * + * @group setParam + */ + @Since("2.4.0") + def setMaxLocalProjDBSize(value: Long): this.type = set(maxLocalProjDBSize, value) + + setDefault(minSupport -> 0.1, maxPatternLength -> 10, maxLocalProjDBSize -> 32000000) /** * :: Experimental :: @@ -45,16 +127,6 @@ object PrefixSpan { * {{{Seq[Seq[_]]}}} type * @param sequenceCol the name of the sequence column in dataset, rows with nulls in this column * are ignored - * @param minSupport the minimal support level of the sequential pattern, any pattern that - * appears more than (minSupport * size-of-the-dataset) times will be output - * (recommended value: `0.1`). - * @param maxPatternLength the maximal length of the sequential pattern - * (recommended value: `10`). - * @param maxLocalProjDBSize The maximum number of items (including delimiters used in the - * internal storage format) allowed in a projected database before - * local processing. If a projected database exceeds this size, another - * iteration of distributed prefix growth is run - * (recommended value: `32000000`). * @return A `DataFrame` that contains columns of sequence and corresponding frequency. 
* The schema of it will be: * - `sequence: Seq[Seq[T]]` (T is the item type) @@ -63,10 +135,7 @@ object PrefixSpan { @Since("2.4.0") def findFrequentSequentialPatterns( dataset: Dataset[_], - sequenceCol: String, - minSupport: Double, - maxPatternLength: Int, - maxLocalProjDBSize: Long): DataFrame = { + sequenceCol: String): DataFrame = { val inputType = dataset.schema(sequenceCol).dataType require(inputType.isInstanceOf[ArrayType] && @@ -74,15 +143,14 @@ object PrefixSpan { s"The input column must be ArrayType and the array element type must also be ArrayType, " + s"but got $inputType.") - val data = dataset.select(sequenceCol) val sequences = data.where(col(sequenceCol).isNotNull).rdd .map(r => r.getAs[Seq[Seq[Any]]](0).map(_.toArray).toArray) val mllibPrefixSpan = new mllibPrefixSpan() - .setMinSupport(minSupport) - .setMaxPatternLength(maxPatternLength) - .setMaxLocalProjDBSize(maxLocalProjDBSize) + .setMinSupport($(minSupport)) + .setMaxPatternLength($(maxPatternLength)) + .setMaxLocalProjDBSize($(maxLocalProjDBSize)) val rows = mllibPrefixSpan.run(sequences).freqSequences.map(f => Row(f.sequence, f.freq)) val schema = StructType(Seq( @@ -93,4 +161,7 @@ object PrefixSpan { freqSequences } + @Since("2.4.0") + override def copy(extra: ParamMap): PrefixSpan = defaultCopy(extra) + } diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala index 9e538696cbcf7..b7c8d2a697d11 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala @@ -29,8 +29,11 @@ class PrefixSpanSuite extends MLTest { test("PrefixSpan projections with multiple partial starts") { val smallDataset = Seq(Seq(Seq(1, 2), Seq(1, 2, 3))).toDF("sequence") - val result = PrefixSpan.findFrequentSequentialPatterns(smallDataset, "sequence", - minSupport = 1.0, maxPatternLength = 2, maxLocalProjDBSize = 32000000) + val 
result = new PrefixSpan() + .setMinSupport(1.0) + .setMaxPatternLength(2) + .setMaxLocalProjDBSize(32000000) + .findFrequentSequentialPatterns(smallDataset, "sequence") .as[(Seq[Seq[Int]], Long)].collect() val expected = Array( (Seq(Seq(1)), 1L), @@ -90,8 +93,11 @@ class PrefixSpanSuite extends MLTest { test("PrefixSpan Integer type, variable-size itemsets") { val df = smallTestData.toDF("sequence") - val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence", - minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000) + val result = new PrefixSpan() + .setMinSupport(0.5) + .setMaxPatternLength(5) + .setMaxLocalProjDBSize(32000000) + .findFrequentSequentialPatterns(df, "sequence") .as[(Seq[Seq[Int]], Long)].collect() compareResults[Int](smallTestDataExpectedResult, result) @@ -99,8 +105,11 @@ class PrefixSpanSuite extends MLTest { test("PrefixSpan input row with nulls") { val df = (smallTestData :+ null).toDF("sequence") - val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence", - minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000) + val result = new PrefixSpan() + .setMinSupport(0.5) + .setMaxPatternLength(5) + .setMaxLocalProjDBSize(32000000) + .findFrequentSequentialPatterns(df, "sequence") .as[(Seq[Seq[Int]], Long)].collect() compareResults[Int](smallTestDataExpectedResult, result) @@ -111,8 +120,11 @@ class PrefixSpanSuite extends MLTest { val df = smallTestData .map(seq => seq.map(itemSet => itemSet.map(intToString))) .toDF("sequence") - val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence", - minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000) + val result = new PrefixSpan() + .setMinSupport(0.5) + .setMaxPatternLength(5) + .setMaxLocalProjDBSize(32000000) + .findFrequentSequentialPatterns(df, "sequence") .as[(Seq[Seq[String]], Long)].collect() val expected = smallTestDataExpectedResult.map { case (seq, freq) => From 90d71e84f36075aeaab19b496eee87792877c48b Mon Sep 17 
00:00:00 2001 From: WeichenXu Date: Wed, 23 May 2018 17:46:28 +0800 Subject: [PATCH 2/3] address comments --- .../org/apache/spark/ml/fpm/PrefixSpan.scala | 92 +++++++++---------- .../apache/spark/ml/fpm/PrefixSpanSuite.scala | 8 +- 2 files changed, 50 insertions(+), 50 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala index 669f39def69bf..8529cc252263e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala @@ -31,6 +31,8 @@ import org.apache.spark.sql.types.{ArrayType, LongType, StructField, StructType} * The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns * Efficiently by Prefix-Projected Pattern Growth * (see here). + * This class is not yet an Estimator/Transformer, use `findFrequentSequentialPatterns` method to + * run the PrefixSpan algorithm. * * @see Sequential Pattern Mining * (Wikipedia) @@ -43,81 +45,83 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params def this() = this(Identifiable.randomUID("prefixSpan")) /** - * the minimal support level of the sequential pattern, any pattern that - * appears more than (minSupport * size-of-the-dataset) times will be output - * (default value: `0.1`). + * Param for the minimal support level (default: `0.1`). + * Sequential patterns that appear more than (minSupport * size-of-the-dataset) times are + * identified as frequent sequential patterns. * @group param */ @Since("2.4.0") - val minSupport = new DoubleParam(this, "minSupport", "the minimal support level of the " + - "sequential pattern, any pattern that appears more than (minSupport * size-of-the-dataset) " + - "times will be output", ParamValidators.gt(0.0)) + val minSupport = new DoubleParam(this, "minSupport", "The minimal support level of the " + + "sequential pattern. 
Sequential pattern that appears more than " + "(minSupport * size-of-the-dataset) " + "times will be output.", ParamValidators.gtEq(0.0)) /** @group getParam */ @Since("2.4.0") def getMinSupport: Double = $(minSupport) - /** - * Set the minSupport parameter. - * Default is 1.0. - * - * @group setParam - */ - @Since("1.3.0") + /** @group setParam */ + @Since("2.4.0") def setMinSupport(value: Double): this.type = set(minSupport, value) /** - * the maximal length of the sequential pattern - * (default value: `10`). + * Param for the maximal pattern length (default: `10`). * @group param */ @Since("2.4.0") val maxPatternLength = new IntParam(this, "maxPatternLength", - "the maximal length of the sequential pattern", + "The maximal length of the sequential pattern.", ParamValidators.gt(0)) /** @group getParam */ @Since("2.4.0") - def getMaxPatternLength: Double = $(maxPatternLength) + def getMaxPatternLength: Int = $(maxPatternLength) - /** - * Set the maxPatternLength parameter. - * Default is 10. - * - * @group setParam - */ + /** @group setParam */ @Since("2.4.0") def setMaxPatternLength(value: Int): this.type = set(maxPatternLength, value) /** - * The maximum number of items (including delimiters used in the - * internal storage format) allowed in a projected database before - * local processing. If a projected database exceeds this size, another - * iteration of distributed prefix growth is run - * (default value: `32000000`). + * Param for the maximum number of items (including delimiters used in the internal storage + * format) allowed in a projected database before local processing (default: `32000000`). + * If a projected database exceeds this size, another iteration of distributed prefix growth + * is run. * @group param */ @Since("2.4.0") val maxLocalProjDBSize = new LongParam(this, "maxLocalProjDBSize", "The maximum number of items (including delimiters used in the internal storage format) " + "allowed in a projected database before local processing. 
If a projected database exceeds " + - "this size, another iteration of distributed prefix growth is run", + "this size, another iteration of distributed prefix growth is run.", ParamValidators.gt(0)) /** @group getParam */ @Since("2.4.0") - def getMaxLocalProjDBSize: Double = $(maxLocalProjDBSize) + def getMaxLocalProjDBSize: Long = $(maxLocalProjDBSize) + + /** @group setParam */ + @Since("2.4.0") + def setMaxLocalProjDBSize(value: Long): this.type = set(maxLocalProjDBSize, value) /** - * Set the maxLocalProjDBSize parameter. - * Default is 32000000. - * - * @group setParam + * Param for the name of the sequence column in dataset, rows with nulls in this column + * are ignored. + * @group param */ @Since("2.4.0") - def setMaxLocalProjDBSize(value: Long): this.type = set(maxLocalProjDBSize, value) + val sequenceCol = new Param[String](this, "sequenceCol", "The name of the sequence column in " + + "dataset, rows with nulls in this column are ignored.") - setDefault(minSupport -> 0.1, maxPatternLength -> 10, maxLocalProjDBSize -> 32000000) + /** @group getParam */ + @Since("2.4.0") + def getSequenceCol: String = $(sequenceCol) + + /** @group setParam */ + @Since("2.4.0") + def setSequenceCol(value: String): this.type = set(sequenceCol, value) + + setDefault(minSupport -> 0.1, maxPatternLength -> 10, maxLocalProjDBSize -> 32000000, + sequenceCol -> "sequence") /** * :: Experimental :: @@ -125,26 +129,22 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params * * @param dataset A dataset or a dataframe containing a sequence column which is * {{{Seq[Seq[_]]}}} type - * @param sequenceCol the name of the sequence column in dataset, rows with nulls in this column - * are ignored * @return A `DataFrame` that contains columns of sequence and corresponding frequency. 
* The schema of it will be: * - `sequence: Seq[Seq[T]]` (T is the item type) * - `freq: Long` */ @Since("2.4.0") - def findFrequentSequentialPatterns( - dataset: Dataset[_], - sequenceCol: String): DataFrame = { - - val inputType = dataset.schema(sequenceCol).dataType + def findFrequentSequentialPatterns(dataset: Dataset[_]): DataFrame = { + val sequenceColParam = $(sequenceCol) + val inputType = dataset.schema(sequenceColParam).dataType require(inputType.isInstanceOf[ArrayType] && inputType.asInstanceOf[ArrayType].elementType.isInstanceOf[ArrayType], s"The input column must be ArrayType and the array element type must also be ArrayType, " + s"but got $inputType.") - val data = dataset.select(sequenceCol) - val sequences = data.where(col(sequenceCol).isNotNull).rdd + val data = dataset.select(sequenceColParam) + val sequences = data.where(col(sequenceColParam).isNotNull).rdd .map(r => r.getAs[Seq[Seq[Any]]](0).map(_.toArray).toArray) val mllibPrefixSpan = new mllibPrefixSpan() @@ -154,7 +154,7 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params val rows = mllibPrefixSpan.run(sequences).freqSequences.map(f => Row(f.sequence, f.freq)) val schema = StructType(Seq( - StructField("sequence", dataset.schema(sequenceCol).dataType, nullable = false), + StructField("sequence", dataset.schema(sequenceColParam).dataType, nullable = false), StructField("freq", LongType, nullable = false))) val freqSequences = dataset.sparkSession.createDataFrame(rows, schema) diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala index b7c8d2a697d11..2252151af306b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala @@ -33,7 +33,7 @@ class PrefixSpanSuite extends MLTest { .setMinSupport(1.0) .setMaxPatternLength(2) .setMaxLocalProjDBSize(32000000) - 
.findFrequentSequentialPatterns(smallDataset, "sequence") + .findFrequentSequentialPatterns(smallDataset) .as[(Seq[Seq[Int]], Long)].collect() val expected = Array( (Seq(Seq(1)), 1L), @@ -97,7 +97,7 @@ class PrefixSpanSuite extends MLTest { .setMinSupport(0.5) .setMaxPatternLength(5) .setMaxLocalProjDBSize(32000000) - .findFrequentSequentialPatterns(df, "sequence") + .findFrequentSequentialPatterns(df) .as[(Seq[Seq[Int]], Long)].collect() compareResults[Int](smallTestDataExpectedResult, result) @@ -109,7 +109,7 @@ class PrefixSpanSuite extends MLTest { .setMinSupport(0.5) .setMaxPatternLength(5) .setMaxLocalProjDBSize(32000000) - .findFrequentSequentialPatterns(df, "sequence") + .findFrequentSequentialPatterns(df) .as[(Seq[Seq[Int]], Long)].collect() compareResults[Int](smallTestDataExpectedResult, result) @@ -124,7 +124,7 @@ class PrefixSpanSuite extends MLTest { .setMinSupport(0.5) .setMaxPatternLength(5) .setMaxLocalProjDBSize(32000000) - .findFrequentSequentialPatterns(df, "sequence") + .findFrequentSequentialPatterns(df) .as[(Seq[Seq[String]], Long)].collect() val expected = smallTestDataExpectedResult.map { case (seq, freq) => From 6e0c59fcca82e84abf5c56ca0db1d0100548c216 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Wed, 23 May 2018 17:51:24 +0800 Subject: [PATCH 3/3] minor update --- mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala index 8529cc252263e..41716c621ca98 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala @@ -104,8 +104,8 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params def setMaxLocalProjDBSize(value: Long): this.type = set(maxLocalProjDBSize, value) /** - * Param for the name of the sequence column in dataset, rows 
with nulls in this column - * are ignored. + * Param for the name of the sequence column in dataset (default "sequence"), rows with + * nulls in this column are ignored. * @group param */ @Since("2.4.0")