[SPARK-5986][MLLib] Add save/load for k-means #4951

@@ -17,15 +17,22 @@

package org.apache.spark.mllib.clustering

import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.Row

/**
 * A clustering model for K-means. Each point belongs to the cluster with the closest center.
 */
class KMeansModel (val clusterCenters: Array[Vector]) extends Saveable with Serializable {

  /** Total number of clusters. */
  def k: Int = clusterCenters.length
@@ -58,4 +65,59 @@

  private def clusterCentersWithNorm: Iterable[VectorWithNorm] =
    clusterCenters.map(new VectorWithNorm(_))

  override def save(sc: SparkContext, path: String): Unit = {
    KMeansModel.SaveLoadV1_0.save(sc, this, path)
  }

  override protected def formatVersion: String = "1.0"
}
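// Usage sketch (illustrative, not part of this patch): with the Saveable mix-in
// above and the Loader companion below, a trained model can be written out and
// read back. The training call and output path here are placeholders.
//
//   val model = KMeans.train(points, 2, 20)   // points: RDD[Vector], k = 2, 20 iterations
//   model.save(sc, "/tmp/kmeans-model")
//   val sameModel = KMeansModel.load(sc, "/tmp/kmeans-model")
//   require(sameModel.k == model.k)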

object KMeansModel extends Loader[KMeansModel] {
  override def load(sc: SparkContext, path: String): KMeansModel = {
    KMeansModel.SaveLoadV1_0.load(sc, path)
  }

  private case class Cluster(id: Int, point: Vector)

  private object Cluster {
    def apply(r: Row): Cluster = {
      Cluster(r.getInt(0), r.getAs[Vector](1))
    }
  }

  private[clustering]
  object SaveLoadV1_0 {

    private val thisFormatVersion = "1.0"

    private[clustering]
    val thisClassName = "org.apache.spark.mllib.clustering.KMeansModel"

    def save(sc: SparkContext, model: KMeansModel, path: String): Unit = {
      val sqlContext = new SQLContext(sc)
      import sqlContext.implicits._
      val metadata = compact(render(
        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k)))
      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
      val dataRDD = sc.parallelize(model.clusterCenters.zipWithIndex).map { case (point, id) =>
        Cluster(id, point)
      }.toDF()
      dataRDD.saveAsParquetFile(Loader.dataPath(path))
    }

    def load(sc: SparkContext, path: String): KMeansModel = {
      implicit val formats = DefaultFormats
      val sqlContext = new SQLContext(sc)
      val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
      assert(className == thisClassName)
      assert(formatVersion == thisFormatVersion)
      val k = (metadata \ "k").extract[Int]
      val centroids = sqlContext.parquetFile(Loader.dataPath(path))
      Loader.checkSchema[Cluster](centroids.schema)
      val localCentroids = centroids.map(Cluster.apply).collect()
      assert(k == localCentroids.size)
      new KMeansModel(localCentroids.sortBy(_.id).map(_.point))
    }
  }
}
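Under this V1.0 scheme, save() writes two things beneath the target path: a one-line JSON metadata record (class name, format version, and k) and the cluster centers as a Parquet file of (id, point) rows. A minimal sketch of reading that metadata back by hand, assuming Loader.metadataPath(path) resolves to "<path>/metadata" (the directory name is an assumption, not spelled out in this diff):

import org.json4s._
import org.json4s.jackson.JsonMethods._

implicit val formats = DefaultFormats
// The metadata is a single compact JSON line written by SaveLoadV1_0.save.
val metadataJson = sc.textFile(path + "/metadata").first()
val parsed = parse(metadataJson)
val k = (parsed \ "k").extract[Int]                   // number of cluster centers
val version = (parsed \ "version").extract[String]    // "1.0"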
@@ -21,9 +21,10 @@ import scala.util.Random

import org.scalatest.FunSuite

import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils

class KMeansSuite extends FunSuite with MLlibTestSparkContext {

@@ -257,6 +258,47 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext {
      assert(predicts(0) != predicts(3))
    }
  }

test("model save/load") {
val tempDir = Utils.createTempDir()
val path = tempDir.toURI.toString

Array(true, false).foreach { case selector =>
val model = KMeansSuite.createModel(10, 3, selector)
// Save model, load it back, and compare.
try {
model.save(sc, path)
val sameModel = KMeansModel.load(sc, path)
KMeansSuite.checkEqual(model, sameModel)
} finally {
Utils.deleteRecursively(tempDir)
}
}
}
}

object KMeansSuite extends FunSuite {
  def createModel(dim: Int, k: Int, isSparse: Boolean): KMeansModel = {
    val singlePoint = isSparse match {
      case true =>
        Vectors.sparse(dim, Array.empty[Int], Array.empty[Double])
      case _ =>
        Vectors.dense(Array.fill[Double](dim)(0.0))
    }
    new KMeansModel(Array.fill[Vector](k)(singlePoint))
  }

  def checkEqual(a: KMeansModel, b: KMeansModel): Unit = {
    assert(a.k === b.k)
    a.clusterCenters.zip(b.clusterCenters).foreach {
      case (ca: SparseVector, cb: SparseVector) =>
        assert(ca === cb)
      case (ca: DenseVector, cb: DenseVector) =>
        assert(ca === cb)
      case _ =>
        throw new AssertionError("checkEqual failed since the two clusters were not identical.\n")
    }
  }
}

class KMeansClusterSuite extends FunSuite with LocalClusterSparkContext {