
Commit

make AssociationRules static and support generic
hhbyyh committed Feb 15, 2017
1 parent 5d7881c commit e141776
Showing 4 changed files with 125 additions and 252 deletions.
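In short, this commit deletes the separate spark.ml AssociationRules file, stores the fitted model as a DataFrame of frequent itemsets instead of wrapping an mllib model, and makes fitting and transformation generic in the item type. For orientation, a minimal usage sketch of the resulting API (the SparkSession `spark` and the toy transactions are assumptions for illustration, not part of the commit):

    import org.apache.spark.ml.fpm.FPGrowth

    // Toy transactions; after this change the items may be any array element type.
    val dataset = spark.createDataFrame(Seq(
      Tuple1(Array("a", "b", "c")),
      Tuple1(Array("a", "b")),
      Tuple1(Array("a", "c"))
    )).toDF("items")

    val model = new FPGrowth()
      .setFeaturesCol("items")
      .setMinSupport(0.5)
      .setMinConfidence(0.6)
      .fit(dataset)

    model.getFreqItemsets.show()      // DataFrame("items", "freq")
    model.getAssociationRules.show()  // DataFrame("antecedent", "consequent", "confidence")
    model.transform(dataset).show()   // appends the prediction column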
113 changes: 0 additions & 113 deletions mllib/src/main/scala/org/apache/spark/ml/fpm/AssociationRules.scala

This file was deleted.

179 changes: 102 additions & 77 deletions mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.fpm

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

import org.apache.hadoop.fs.Path

@@ -26,11 +27,11 @@ import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.fpm.{FPGrowth => MLlibFPGrowth, FPGrowthModel => MLlibFPGrowthModel}
import org.apache.spark.sql.{DataFrame, _}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{ArrayType, StringType, StructType}
import org.apache.spark.util.SizeEstimator
import org.apache.spark.mllib.fpm.{AssociationRules => MLlibAssociationRules,
FPGrowth => MLlibFPGrowth}
import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset
import org.apache.spark.sql._
import org.apache.spark.sql.types._

/**
* Common params for FPGrowth and FPGrowthModel
@@ -43,8 +44,10 @@ private[fpm] trait FPGrowthParams extends Params with HasFeaturesCol with HasPre
* @return output schema
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
SchemaUtils.checkColumnType(schema, $(featuresCol), new ArrayType(StringType, false))
SchemaUtils.appendColumn(schema, $(predictionCol), new ArrayType(StringType, false))
val inputType = schema($(featuresCol)).dataType
require(inputType.isInstanceOf[ArrayType],
s"The input column must be ArrayType, but got $inputType.")
SchemaUtils.appendColumn(schema, $(predictionCol), schema($(featuresCol)).dataType)
}
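A note on the relaxed check above: the old checkColumnType call rejected everything except ArrayType(StringType, false), while the new require accepts any array element type, and the prediction column now mirrors the input column's type. A small sketch of what now validates (the column name and element type are illustrative):

    import org.apache.spark.sql.types._

    // An Array[Int] input column passes the new check; the old one rejected it.
    val schema = StructType(Seq(StructField("items", ArrayType(IntegerType), nullable = true)))
    val inputType = schema("items").dataType
    require(inputType.isInstanceOf[ArrayType],
      s"The input column must be ArrayType, but got $inputType.")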

/**
Expand Down Expand Up @@ -84,6 +87,7 @@ private[fpm] trait FPGrowthParams extends Params with HasFeaturesCol with HasPre
val minConfidence: DoubleParam = new DoubleParam(this, "minConfidence",
"minimal confidence for generating Association Rule (Default: 0.8)",
ParamValidators.inRange(0.0, 1.0))
setDefault(minConfidence -> 0.8)

/** @group getParam */
@Since("2.2.0")
@@ -115,6 +119,12 @@ class FPGrowth @Since("2.2.0") (
@Since("2.2.0")
def setNumPartitions(value: Int): this.type = set(numPartitions, value)

/**
 * minConfidence has no effect during fitting.
 * @group setParam
 */
@Since("2.2.0")
def setMinConfidence(value: Double): this.type = set(minConfidence, value)

/** @group setParam */
@Since("2.2.0")
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
@@ -124,9 +134,22 @@
def setPredictionCol(value: String): this.type = set(predictionCol, value)

override def fit(dataset: Dataset[_]): FPGrowthModel = {
val data = dataset.select($(featuresCol)).rdd.map(r => r.getSeq[String](0).toArray)
genericFit(dataset)
}

private def genericFit[T: ClassTag](dataset: Dataset[_]): FPGrowthModel = {
val data = dataset.select($(featuresCol)).rdd.map(r => r.getSeq[T](0).toArray)
val parentModel = new MLlibFPGrowth().setMinSupport($(minSupport)).run(data)
copyValues(new FPGrowthModel(uid, parentModel)).setParent(this)
val rows = parentModel.freqItemsets
.map(f => (f.items, f.freq))
.map(cols => Row(cols._1, cols._2))

val dt = dataset.schema($(featuresCol)).dataType
val fields = Array(StructField("items", dt, nullable = false),
StructField("freq", LongType, nullable = false))
val schema = StructType(fields)
val frequentItems = dataset.sparkSession.createDataFrame(rows, schema).toDF("items", "freq")
copyValues(new FPGrowthModel(uid, frequentItems)).setParent(this)
}
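The Row-plus-StructType construction in genericFit is deliberate: the element type T is only known at runtime, so there is no Encoder for a typed Dataset, and the frequent itemsets are materialized through the untyped createDataFrame(RDD[Row], schema) API with the schema copied from the input column. A standalone sketch of the pattern, with illustrative names:

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.{DataFrame, Row, SparkSession}
    import org.apache.spark.sql.types._

    // Build DataFrame("items", "freq") when the items' type is only known
    // dynamically as a DataType, so no Encoder is available.
    def itemsetsToDF(spark: SparkSession, itemsType: DataType,
        itemsets: RDD[(Array[_], Long)]): DataFrame = {
      val rows = itemsets.map { case (items, freq) => Row(items, freq) }
      val schema = StructType(Array(
        StructField("items", itemsType, nullable = false),
        StructField("freq", LongType, nullable = false)))
      spark.createDataFrame(rows, schema)
    }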

@Since("2.2.0")
@@ -149,19 +172,18 @@ object FPGrowth extends DefaultParamsReadable[FPGrowth] {
* :: Experimental ::
* Model fitted by FPGrowth.
*
* @param parentModel a model trained by spark.mllib.fpm.FPGrowth
* @param freqItemsets frequent items in the format of DataFrame("items", "freq")
*/
@Since("2.2.0")
@Experimental
class FPGrowthModel private[ml] (
@Since("2.2.0") override val uid: String,
private val parentModel: MLlibFPGrowthModel[_])
val freqItemsets: DataFrame)
extends Model[FPGrowthModel] with FPGrowthParams with MLWritable {

/** @group setParam */
@Since("2.2.0")
def setMinConfidence(value: Double): this.type = set(minConfidence, value)
setDefault(minConfidence -> 0.8)

/** @group setParam */
@Since("2.2.0")
@@ -173,81 +195,50 @@ class FPGrowthModel private[ml] (

/**
* Get association rules fitted by AssociationRules using the minConfidence. Returns a dataframe
* with three fields, "antecedent", "consequent" and "confidence", with "antecedent" and
* "consequent" being Array[String] and the "confidence" being Double.
* with three fields, "antecedent", "consequent" and "confidence", where "antecedent" and
* "consequent" are arrays of the same type as the input items and "confidence" is Double.
*/
@Since("2.2.0")
@transient private lazy val associationRules: DataFrame = {
@transient lazy val getAssociationRules: DataFrame = {
val freqItems = getFreqItemsets
val associationRules = new AssociationRules()
.setMinConfidence($(minConfidence))
.setItemsCol("items")
.setFreqCol("freq")
associationRules.run(freqItems)
AssociationRules.getAssociationRulesFromFP(freqItems, "items", "freq", $(minConfidence))
}

/**
* Get association rules fitted by AssociationRules using the minConfidence, in the format
* of DataFrame("antecedent", "consequent", "confidence")
*/
@Since("2.2.0")
@transient lazy val getAssociationRules: DataFrame = associationRules

/**
* Get frequent items fitted by FPGrowth, in the format of DataFrame("items", "freq")
*/
@Since("2.2.0")
@transient lazy val getFreqItemsets: DataFrame = {
val sqlContext = SparkSession.builder().getOrCreate()
import sqlContext.implicits._
parentModel.freqItemsets.map(f => (f.items.map(_.toString), f.freq))
.toDF("items", "freq")
}
@transient val getFreqItemsets: DataFrame = freqItemsets

@Since("2.2.0")
override def transform(dataset: Dataset[_]): DataFrame = {
rulesTransform(dataset, associationRules)
genericTransform(dataset, getAssociationRules)
}

private def rulesTransform(dataset: Dataset[_], associationRules: DataFrame): DataFrame = {
import dataset.sparkSession.implicits._

val ruleSize = SizeEstimator.estimate(associationRules.rdd)
if (ruleSize < 1e8) {
val rules = associationRules.rdd.map(r =>
(r.getSeq[String](0), r.getSeq[String](1))
).collect()
val brRules = dataset.sparkSession.sparkContext.broadcast(rules)

// For each rule, examine the input items and summarize the consequents
val predictUDF = udf((items: Seq[String]) => {
val itemset = items.toSet
val rulesValue = brRules.value
rulesValue.flatMap {
r => if (r._1.forall(itemset.contains)) r._2 else Array.empty[String]
}.distinct
})
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))

} else {
val itemsRDD = dataset.select($(featuresCol)).rdd.map(r => r.getSeq[String](0))
.distinct().zipWithUniqueId().map(t => (t._2, t._1))
val rulesRDD = associationRules.rdd.map(r =>
(r.getSeq[String](0), r.getSeq[String](1))
)
val merged = itemsRDD.cartesian(rulesRDD).map {
case ((id, items), (antecedent, consequent)) =>
val consequents = if (antecedent.forall(items.contains(_))) consequent else Seq.empty
(id, consequents)
}.aggregateByKey(new ArrayBuffer[String])(
(ar, seq) => ar ++= seq, (ar, seq) => ar ++= seq)

val mapping = itemsRDD.join(merged).map {
case (id, (items, consequent)) => (items, consequent)
}.toDF($(featuresCol), $(predictionCol))

dataset.join(mapping, $(featuresCol))
}
private def genericTransform[T](dataset: Dataset[_], associationRules: DataFrame): DataFrame = {
// use unique id to perform the join and aggregateByKey
val itemsRDD = dataset.select($(featuresCol)).rdd.map(r => r.getSeq[T](0))
.distinct().zipWithUniqueId().map(_.swap).cache()
val rulesRDD = associationRules.rdd.map(r => (r.getSeq[T](0), r.getSeq[T](1)))

val itemsWithConsequents = itemsRDD.cartesian(rulesRDD).map {
case ((id, items), (antecedent, consequent)) =>
val itemSet = items.toSet
val consequents = if (antecedent.forall(itemSet.contains(_))) consequent else Seq.empty
(id, consequents)
}.aggregateByKey(new ArrayBuffer[T])(
(ar, seq) => ar ++= seq, (ar, seq) => ar ++= seq)

val mappingRDD = itemsRDD.join(itemsWithConsequents)
.map { case (id, (items, consequent)) => (items, consequent) }
.map(cols => Row(cols._1, cols._2))
val dt = dataset.schema($(featuresCol)).dataType
val fields = Array($(featuresCol), $(predictionCol))
.map(fieldName => StructField(fieldName, dt, nullable = true))
val schema = StructType(fields)
val mapping = dataset.sparkSession.createDataFrame(mappingRDD, schema)

dataset.join(mapping, $(featuresCol))
}
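The prediction logic in genericTransform reduces to a set test: a rule contributes its consequent to a basket exactly when the basket contains the rule's entire antecedent; distinct baskets are paired with every rule via cartesian, and the firing consequents are folded per basket with aggregateByKey. A toy driver-side illustration of the matching rule (data invented for the example):

    val basket = Set("a", "b", "c")
    val rules = Seq(
      (Seq("a", "b"), Seq("c")),  // fires: antecedent {a, b} is in the basket
      (Seq("d"), Seq("e"))        // does not fire: "d" is missing
    )
    val predicted = rules.flatMap { case (antecedent, consequent) =>
      if (antecedent.forall(basket.contains)) consequent else Seq.empty
    }.distinct
    // predicted == List("c")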

@Since("2.2.0")
@@ -257,7 +248,7 @@ class FPGrowthModel private[ml] (

@Since("2.2.0")
override def copy(extra: ParamMap): FPGrowthModel = {
val copied = new FPGrowthModel(uid, parentModel)
val copied = new FPGrowthModel(uid, freqItemsets)
copyValues(copied, extra).setParent(this.parent)
}

@@ -280,7 +271,7 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
// Save metadata and Params
DefaultParamsWriter.saveMetadata(instance, path, sc)
val dataPath = new Path(path, "data").toString
instance.parentModel.save(sc, dataPath)
instance.freqItemsets.write.save(dataPath)
}
}

@@ -292,11 +283,45 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
override def load(path: String): FPGrowthModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
val mllibModel = MLlibFPGrowthModel.load(sc, dataPath)
val model = new FPGrowthModel(metadata.uid, mllibModel)
val frequentItems = sparkSession.read.load(dataPath)
val model = new FPGrowthModel(metadata.uid, frequentItems)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}
}
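Since the model's only data is now the freqItemsets DataFrame, persistence becomes a plain DataFrame write and read instead of delegating to the mllib model's save/load. A minimal round trip, assuming a fitted model and a scratch path:

    model.write.overwrite().save("/tmp/fpgrowth-model")  // path is illustrative
    val restored = FPGrowthModel.load("/tmp/fpgrowth-model")
    restored.getFreqItemsets.show()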

private[fpm] object AssociationRules {

/**
* Computes the association rules with confidence above minConfidence.
* @param dataset DataFrame("items", "freq") containing frequent itemsets obtained from
*                algorithms like [[FPGrowth]].
* @param itemsCol column name for the frequent itemsets
* @param freqCol column name for the frequent itemset counts
* @param minConfidence minimum confidence for the resulting association rules
* @return a DataFrame("antecedent", "consequent", "confidence") containing the association
* rules.
*/
@Since("2.2.0")
def getAssociationRulesFromFP[T: ClassTag](dataset: Dataset[_],
itemsCol: String = "items",
freqCol: String = "freq",
minConfidence: Double = 0.8): DataFrame = {

val freqItemSetRdd = dataset.select(itemsCol, freqCol).rdd
.map(row => new FreqItemset(row.getSeq[T](0).toArray, row.getLong(1)))
val rows = new MLlibAssociationRules()
.setMinConfidence(minConfidence)
.run(freqItemSetRdd)
.map(r => Row(r.antecedent, r.consequent, r.confidence))

val dt = dataset.schema(itemsCol).dataType
val fields = Array("antecedent", "consequent")
.map(fieldName => StructField(fieldName, dt, nullable = false)) ++
Seq(StructField("confidence", DoubleType, nullable = false))
val schema = StructType(fields)
dataset.sparkSession.createDataFrame(rows, schema)
}
}
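Because the helper is private[fpm], it is callable only from within org.apache.spark.ml.fpm; a hypothetical call site there, deriving rules from a precomputed itemset table (the toy data is assumed):

    val freqItems = spark.createDataFrame(Seq(
      (Array("a", "b"), 3L),
      (Array("a"), 5L),
      (Array("b"), 4L)
    )).toDF("items", "freq")

    val rules = AssociationRules.getAssociationRulesFromFP[String](
      freqItems, itemsCol = "items", freqCol = "freq", minConfidence = 0.6)
    rules.show()  // DataFrame("antecedent", "consequent", "confidence")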
