[SPARK-11237][ML] Add pmml export for k-means in Spark ML #20907

Closed

Changes from 2 commits
@@ -1,2 +1,4 @@
org.apache.spark.ml.regression.InternalLinearRegressionModelWriter
org.apache.spark.ml.regression.PMMLLinearRegressionModelWriter
org.apache.spark.ml.clustering.InternalKMeansModelWriter
org.apache.spark.ml.clustering.PMMLKMeansModelWriter
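For context, these service entries are what let GeneralMLWriter discover the new K-Means writers: as far as I can tell, the (format, stageName) pair requested by a caller is resolved to a registered MLWriterFormat through java.util.ServiceLoader. A rough sketch of that lookup, not the actual Spark internals:

import java.util.ServiceLoader
import scala.collection.JavaConverters._
import org.apache.spark.ml.util.MLFormatRegister

// Illustrative only: find the registered writer whose declared format and
// stage name match what the caller asked for.
def findWriter(format: String, stageName: String): Option[MLFormatRegister] =
  ServiceLoader.load(classOf[MLFormatRegister]).asScala
    .find(r => r.format() == format && r.stageName() == stageName)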
75 changes: 50 additions & 25 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -17,11 +17,13 @@

package org.apache.spark.ml.clustering

import scala.collection.mutable

import org.apache.hadoop.fs.Path

import org.apache.spark.SparkException
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.{Estimator, Model, PipelineStage}
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
@@ -30,7 +32,7 @@ import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans
import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.storage.StorageLevel
@@ -103,8 +105,8 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
@Since("1.5.0")
class KMeansModel private[ml] (
@Since("1.5.0") override val uid: String,
private val parentModel: MLlibKMeansModel)
extends Model[KMeansModel] with KMeansParams with MLWritable {
private[clustering] val parentModel: MLlibKMeansModel)
extends Model[KMeansModel] with KMeansParams with GeneralMLWritable {

@Since("1.5.0")
override def copy(extra: ParamMap): KMeansModel = {
@@ -152,14 +154,14 @@ class KMeansModel private[ml] (
}

/**
* Returns a [[org.apache.spark.ml.util.MLWriter]] instance for this ML instance.
* Returns a [[org.apache.spark.ml.util.GeneralMLWriter]] instance for this ML instance.
*
* For [[KMeansModel]], this does NOT currently save the training [[summary]].
* An option to save [[summary]] may be added in the future.
*
*/
@Since("1.6.0")
override def write: MLWriter = new KMeansModel.KMeansModelWriter(this)
override def write: GeneralMLWriter = new GeneralMLWriter(this)

private var trainingSummary: Option[KMeansSummary] = None

@@ -185,6 +187,47 @@ class KMeansModel private[ml] (
}
}

/** Helper class for storing model data */
private case class ClusterData(clusterIdx: Int, clusterCenter: Vector)


/** A writer for KMeans that handles the "internal" (or default) format */
private class InternalKMeansModelWriter extends MLWriterFormat with MLFormatRegister {

override def format(): String = "internal"
override def stageName(): String = "org.apache.spark.ml.clustering.KMeansModel"

override def write(path: String, sparkSession: SparkSession,
optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
val instance = stage.asInstanceOf[KMeansModel]
val sc = sparkSession.sparkContext
// Save metadata and Params
DefaultParamsWriter.saveMetadata(instance, path, sc)
// Save model data: cluster centers
val data: Array[ClusterData] = instance.clusterCenters.zipWithIndex.map {
case (center, idx) =>
ClusterData(idx, center)
Member: doesn't this type change Data -> ClusterData change the schema of the output parquet file?

Contributor Author: I'm not 100% sure. I'll manually test we can load the old format first.

Contributor Author: Wait, no, this shouldn't change anything; we're saving this with a DataFrame and the schema is the same.
See the schema from 1: res3: org.apache.spark.sql.types.StructType = StructType(StructField(clusterIdx,IntegerType,false), StructField(clusterCenter,org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7,true)) and the new one: org.apache.spark.sql.types.StructType = StructType(StructField(clusterIdx,IntegerType,false), StructField(clusterCenter,org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7,true))

Member: 👍

}
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(data).repartition(1).write.parquet(dataPath)
}
}
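Regarding the schema question in the thread above, one way to double-check compatibility is to compare the parquet schema of a model saved by an older Spark build against one written by this new code. A minimal sketch, with hypothetical paths and assuming a spark SparkSession is in scope (e.g. in spark-shell):

// Hypothetical paths; compares the on-disk schema before and after the
// Data -> ClusterData rename. Field names and types should be identical.
val oldSchema = spark.read.parquet("/tmp/kmeans-model-old/data").schema
val newSchema = spark.read.parquet("/tmp/kmeans-model-new/data").schema
assert(oldSchema == newSchema)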

/** A writer for KMeans that handles the "pmml" format */
private class PMMLKMeansModelWriter extends MLWriterFormat with MLFormatRegister {

override def format(): String = "pmml"
override def stageName(): String = "org.apache.spark.ml.clustering.KMeansModel"

override def write(path: String, sparkSession: SparkSession,
optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
val instance = stage.asInstanceOf[KMeansModel]
val sc = sparkSession.sparkContext
instance.parentModel.toPMML(sc, path)
}
}
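With both writers registered, the intended usage should be to pick the output format on the writer returned by write, mirroring the existing linear-regression PMML export. A short sketch, assuming kmeansModel is a fitted KMeansModel and the paths are placeholders:

// Default ("internal") format, same on-disk layout as before:
kmeansModel.write.save("/tmp/kmeans-internal")
// PMML export; "pmml" resolves to PMMLKMeansModelWriter through MLFormatRegister:
kmeansModel.write.format("pmml").save("/tmp/kmeans-pmml")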


@Since("1.6.0")
object KMeansModel extends MLReadable[KMeansModel] {

@@ -194,30 +237,12 @@ object KMeansModel extends MLReadable[KMeansModel] {
@Since("1.6.0")
override def load(path: String): KMeansModel = super.load(path)

/** Helper class for storing model data */
private case class Data(clusterIdx: Int, clusterCenter: Vector)

/**
* We store all cluster centers in a single row and use this class to store model data by
* Spark 1.6 and earlier. A model can be loaded from such older data for backward compatibility.
*/
private case class OldData(clusterCenters: Array[OldVector])

/** [[MLWriter]] instance for [[KMeansModel]] */
private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter {

override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
DefaultParamsWriter.saveMetadata(instance, path, sc)
// Save model data: cluster centers
val data: Array[Data] = instance.clusterCenters.zipWithIndex.map { case (center, idx) =>
Data(idx, center)
}
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(data).repartition(1).write.parquet(dataPath)
}
}

private class KMeansModelReader extends MLReader[KMeansModel] {

/** Checked against metadata when loading model */
@@ -232,7 +257,7 @@ object KMeansModel extends MLReadable[KMeansModel] {
val dataPath = new Path(path, "data").toString

val clusterCenters = if (majorVersion(metadata.sparkVersion) >= 2) {
val data: Dataset[Data] = sparkSession.read.parquet(dataPath).as[Data]
val data: Dataset[ClusterData] = sparkSession.read.parquet(dataPath).as[ClusterData]
data.collect().sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML)
} else {
// Loads KMeansModel stored with the old format used by Spark 1.6 and earlier.
@@ -746,7 +746,7 @@ private class InternalLinearRegressionModelWriter

/** A writer for LinearRegression that handles the "pmml" format */
private class PMMLLinearRegressionModelWriter
extends MLWriterFormat with MLFormatRegister {
extends MLWriterFormat with MLFormatRegister {

override def format(): String = "pmml"

@@ -19,17 +19,22 @@ package org.apache.spark.ml.clustering

import scala.util.Random

import org.dmg.pmml.{ClusteringModel, PMML}

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans,
KMeansModel => MLlibKMeansModel}
import org.apache.spark.mllib.linalg.{Vectors => MLlibVectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}

private[clustering] case class TestRow(features: Vector)

class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest
with PMMLReadWriteTest {

final val k = 5
@transient var dataset: Dataset[_] = _
@@ -202,6 +207,26 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings,
KMeansSuite.allParamSettings, checkModelData)
}

test("pmml export") {
val clusterCenters = Array(
MLlibVectors.dense(1.0, 2.0, 6.0),
MLlibVectors.dense(1.0, 3.0, 0.0),
MLlibVectors.dense(1.0, 4.0, 6.0))
val oldKmeansModel = new MLlibKMeansModel(clusterCenters)
val kmeansModel = new KMeansModel("", oldKmeansModel)
def checkModel(pmml: PMML): Unit = {
// Check the header description is what we expect
assert(pmml.getHeader.getDescription === "k-means clustering")
// Check that the number of fields matches the single vector size
assert(pmml.getDataDictionary.getNumberOfFields === clusterCenters(0).size)
// This verifies that there is a model attached to the PMML object and that the model is a
// clustering one. It also verifies that the PMML model has the same number of clusters as the
// Spark model.
val pmmlClusteringModel = pmml.getModels.get(0).asInstanceOf[ClusteringModel]
assert(pmmlClusteringModel.getNumberOfClusters === clusterCenters.length)
}
Member: Isn't this missing a call to testPMMLWrite?

Contributor Author: Oh yeah :( Thanks for catching that.
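If the fix mirrors the linear regression suite's use of PMMLReadWriteTest, the missing line is presumably just a call at the end of the test, something like:

// Presumed missing call (mirroring the linear regression PMML test); signature
// taken from PMMLReadWriteTest as I understand it.
testPMMLWrite(sc, kmeansModel, checkModel)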

}
}

object KMeansSuite {