Skip to content

Commit

Permalink
Add PMML export for KMeans model
Browse files Browse the repository at this point in the history
  • Loading branch information
holdenk committed Mar 26, 2018
1 parent 4431407 commit 25d6f77
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
org.apache.spark.ml.regression.InternalLinearRegressionModelWriter
org.apache.spark.ml.regression.PMMLLinearRegressionModelWriter
org.apache.spark.ml.clustering.InternalKMeansModelWriter
org.apache.spark.ml.clustering.InternalKMeansModelWriter
org.apache.spark.ml.clustering.PMMLKMeansModelWriter
17 changes: 16 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
@Since("1.5.0")
class KMeansModel private[ml] (
@Since("1.5.0") override val uid: String,
private val parentModel: MLlibKMeansModel)
private[clustering] val parentModel: MLlibKMeansModel)
extends Model[KMeansModel] with KMeansParams with GeneralMLWritable {

@Since("1.5.0")
Expand Down Expand Up @@ -213,6 +213,21 @@ private class InternalKMeansModelWriter extends MLWriterFormat with MLFormatRegi
}
}

/**
 * Writer format that exports a [[KMeansModel]] in the "pmml" format.
 *
 * The export itself is delegated to the wrapped MLlib model's `toPMML` method.
 */
private class PMMLKMeansModelWriter extends MLWriterFormat with MLFormatRegister {

  /** Short name of the output format handled by this writer. */
  override def format(): String = "pmml"

  /** Fully qualified class name of the stage this writer applies to. */
  override def stageName(): String = "org.apache.spark.ml.clustering.KMeansModel"

  /**
   * Exports `stage` as PMML to `path`.
   *
   * @param path destination handed to `toPMML`
   * @param sparkSession session whose SparkContext is used for the export
   * @param optionMap writer options; not consulted by this implementation
   * @param stage the stage to export; cast to [[KMeansModel]]
   */
  override def write(
      path: String,
      sparkSession: SparkSession,
      optionMap: mutable.Map[String, String],
      stage: PipelineStage): Unit = {
    val kmeansModel = stage.asInstanceOf[KMeansModel]
    kmeansModel.parentModel.toPMML(sparkSession.sparkContext, path)
  }
}


@Since("1.6.0")
object KMeansModel extends MLReadable[KMeansModel] {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,22 @@ package org.apache.spark.ml.clustering

import scala.util.Random

import org.dmg.pmml.{ClusteringModel, PMML}

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans,
KMeansModel => MLlibKMeansModel}
import org.apache.spark.mllib.linalg.{Vectors => MLlibVectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}

/** Row type with a single `features` vector column, visible only inside the clustering package. */
private[clustering] case class TestRow(features: Vector)

class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest
with PMMLReadWriteTest {

final val k = 5
@transient var dataset: Dataset[_] = _
Expand Down Expand Up @@ -202,6 +207,26 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings,
KMeansSuite.allParamSettings, checkModelData)
}

test("pmml export") {
  // Build a model directly from three fixed 3-dimensional centers so the expected
  // PMML contents are known without running a fit.
  val clusterCenters = Array(
    MLlibVectors.dense(1.0, 2.0, 6.0),
    MLlibVectors.dense(1.0, 3.0, 0.0),
    MLlibVectors.dense(1.0, 4.0, 6.0))
  val oldKmeansModel = new MLlibKMeansModel(clusterCenters)
  val kmeansModel = new KMeansModel("", oldKmeansModel)
  def checkModel(pmml: PMML): Unit = {
    // Check the header description is what we expect
    assert(pmml.getHeader.getDescription === "k-means clustering")
    // Check that the number of fields matches the size of a single center vector
    assert(pmml.getDataDictionary.getNumberOfFields === clusterCenters(0).size)
    // This verifies that there is a model attached to the pmml object and that the model is a
    // clustering one. It also verifies that the pmml model has the same number of clusters as
    // the spark model.
    val pmmlClusteringModel = pmml.getModels.get(0).asInstanceOf[ClusteringModel]
    assert(pmmlClusteringModel.getNumberOfClusters === clusterCenters.length)
  }
  // Actually run the export and the checks: without this call checkModel is never invoked
  // and the test passes vacuously. testPMMLWrite comes from the PMMLReadWriteTest mixin and
  // is expected to write the model as PMML and hand the parsed result to checkModel.
  testPMMLWrite(sc, kmeansModel, checkModel)
}
}

object KMeansSuite {
Expand Down

0 comments on commit 25d6f77

Please sign in to comment.