From dce70553cb0e5c25d1bb0a415929eb5066af964a Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Mon, 9 Mar 2015 22:12:59 +0800 Subject: [PATCH 1/4] add save/load for k-means for SPARK-5986 --- .../spark/mllib/clustering/KMeansModel.scala | 62 ++++++++++++++++++- .../spark/mllib/clustering/KMeansSuite.scala | 51 ++++++++++++++- 2 files changed, 109 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index 3b95a9e6936e8..d01af13413ef2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -17,15 +17,22 @@ package org.apache.spark.mllib.clustering +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.mllib.linalg._ +import org.apache.spark.mllib.util.{Loader, Saveable} +import org.apache.spark.mllib.util.Loader._ +import org.apache.spark.sql.SQLContext +import org.apache.spark.SparkContext import org.apache.spark.api.java.JavaRDD import org.apache.spark.rdd.RDD -import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.linalg.Vector /** * A clustering model for K-means. Each point belongs to the cluster with the closest center. */ -class KMeansModel (val clusterCenters: Array[Vector]) extends Serializable { +class KMeansModel (val clusterCenters: Array[Vector]) extends Saveable with Serializable { /** Total number of clusters. */ def k: Int = clusterCenters.length @@ -58,4 +65,53 @@ class KMeansModel (val clusterCenters: Array[Vector]) extends Serializable { private def clusterCentersWithNorm: Iterable[VectorWithNorm] = clusterCenters.map(new VectorWithNorm(_)) + + override def save(sc: SparkContext, path: String): Unit = { + KMeansModel.SaveLoadV1_0.save(sc, this, path) + } + + override protected def formatVersion: String = "1.0" +} + +object KMeansModel extends Loader[KMeansModel] { + override def load(sc: SparkContext, path: String): KMeansModel = { + KMeansModel.SaveLoadV1_0.load(sc, path) + } + + private[clustering] + object SaveLoadV1_0 { + + private val thisFormatVersion = "1.0" + + private[clustering] + val thisClassName = "org.apache.spark.mllib.clustering.KMeansModel" + + /** + * Saves a [[KMeansModel]], where user features are saved under `data/users` and + * product features are saved under `data/products`. 
+ */ + def save(sc: SparkContext, model: KMeansModel, path: String): Unit = { + val sqlContext = new SQLContext(sc) + val wrapper = new VectorUDT() + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + val dataRDD = sc.parallelize(model.clusterCenters).map(wrapper.serialize) + sqlContext.createDataFrame(dataRDD, wrapper.sqlType).saveAsParquetFile(Loader.dataPath(path)) + } + + def load(sc: SparkContext, path: String): KMeansModel = { + implicit val formats = DefaultFormats + val sqlContext = new SQLContext(sc) + val wrapper = new VectorUDT() + val (className, formatVersion, metadata) = loadMetadata(sc, path) + assert(className == thisClassName) + assert(formatVersion == thisFormatVersion) + val k = (metadata \ "k").extract[Int] + val centriods = sqlContext.parquetFile(dataPath(path)) + val localCentriods = centriods.collect() + assert(k == localCentriods.size) + new KMeansModel(localCentriods.map(wrapper.deserialize)) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index caee5917000aa..c3e26918f84bd 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -21,7 +21,8 @@ import scala.util.Random import org.scalatest.FunSuite -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.util.Utils +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ @@ -257,6 +258,54 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext { assert(predicts(0) != predicts(3)) } } + + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + Array(true, false).foreach { case selector => + val model = KMeansSuite.createModel(10, 3, selector) + // Save model, load it back, and compare. 
+ try { + model.save(sc, path) + val sameModel = KMeansModel.load(sc, path) + KMeansSuite.checkEqual(model, sameModel, selector) + } finally { + Utils.deleteRecursively(tempDir) + } + } + } +} + +object KMeansSuite extends FunSuite { + def createModel(dim: Int, k: Int, isSparse: Boolean): KMeansModel = { + val singlePoint = isSparse match { + case true => + Vectors.sparse(dim, Array.empty[Int], Array.empty[Double]) + case _ => + Vectors.dense(Array.fill[Double](dim)(0.0)) + } + new KMeansModel(Array.fill[Vector](k)(singlePoint)) + } + + def checkEqual(a: KMeansModel, b: KMeansModel, isSparse: Boolean): Unit = { + assert(a.k === b.k) + isSparse match { + case true => + a.clusterCenters.zip(b.clusterCenters).foreach { case (pointA, pointB) => + assert(pointA.asInstanceOf[SparseVector].size === pointB.asInstanceOf[SparseVector].size) + assert( + pointA.asInstanceOf[SparseVector].indices === pointB.asInstanceOf[SparseVector].indices) + assert( + pointA.asInstanceOf[SparseVector].values === pointB.asInstanceOf[SparseVector].values) + } + case _ => + a.clusterCenters.zip(b.clusterCenters).foreach { case (pointA, pointB) => + assert( + pointA.asInstanceOf[DenseVector].toArray === pointB.asInstanceOf[DenseVector].toArray) + } + } + } } class KMeansClusterSuite extends FunSuite with LocalClusterSparkContext { From b144216f741776fdfe4c8e95d63650bd46c659d5 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Mon, 9 Mar 2015 22:18:24 +0800 Subject: [PATCH 2/4] remove invalid comments --- .../scala/org/apache/spark/mllib/clustering/KMeansModel.scala | 4 ---- 1 file changed, 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index d01af13413ef2..b52f0767ca50f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -86,10 +86,6 @@ object KMeansModel extends Loader[KMeansModel] { private[clustering] val thisClassName = "org.apache.spark.mllib.clustering.KMeansModel" - /** - * Saves a [[KMeansModel]], where user features are saved under `data/users` and - * product features are saved under `data/products`. 
- */ def save(sc: SparkContext, model: KMeansModel, path: String): Unit = { val sqlContext = new SQLContext(sc) val wrapper = new VectorUDT() From cd390fd294f65d465abee8a066fc29d7958ae9ec Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Tue, 10 Mar 2015 06:31:36 +0800 Subject: [PATCH 3/4] add indexed point --- .../spark/mllib/clustering/KMeansModel.scala | 29 +++++++++++++------ .../spark/mllib/clustering/KMeansSuite.scala | 2 +- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index b52f0767ca50f..f229fa1b2ff52 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -21,13 +21,14 @@ import org.json4s._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ +import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.mllib.util.Loader._ -import org.apache.spark.sql.SQLContext -import org.apache.spark.SparkContext -import org.apache.spark.api.java.JavaRDD import org.apache.spark.rdd.RDD +import org.apache.spark.SparkContext +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.Row /** * A clustering model for K-means. Each point belongs to the cluster with the closest center. @@ -78,6 +79,14 @@ object KMeansModel extends Loader[KMeansModel] { KMeansModel.SaveLoadV1_0.load(sc, path) } + case class IndexedPoint(id: Int, point: Vector) + + object IndexedPoint { + def apply(r: Row): IndexedPoint = { + IndexedPoint(r.getInt(0), r.getAs[Vector](1)) + } + } + private[clustering] object SaveLoadV1_0 { @@ -88,26 +97,28 @@ object KMeansModel extends Loader[KMeansModel] { def save(sc: SparkContext, model: KMeansModel, path: String): Unit = { val sqlContext = new SQLContext(sc) - val wrapper = new VectorUDT() + import sqlContext.implicits._ val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k))) sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) - val dataRDD = sc.parallelize(model.clusterCenters).map(wrapper.serialize) - sqlContext.createDataFrame(dataRDD, wrapper.sqlType).saveAsParquetFile(Loader.dataPath(path)) + val dataRDD = sc.parallelize(model.clusterCenters.zipWithIndex).map { case (point, id) => + IndexedPoint(id, point) + }.toDF() + dataRDD.saveAsParquetFile(Loader.dataPath(path)) } def load(sc: SparkContext, path: String): KMeansModel = { implicit val formats = DefaultFormats val sqlContext = new SQLContext(sc) - val wrapper = new VectorUDT() val (className, formatVersion, metadata) = loadMetadata(sc, path) assert(className == thisClassName) assert(formatVersion == thisFormatVersion) val k = (metadata \ "k").extract[Int] val centriods = sqlContext.parquetFile(dataPath(path)) - val localCentriods = centriods.collect() + Loader.checkSchema[IndexedPoint](centriods.schema) + val localCentriods = centriods.map(IndexedPoint.apply).collect() assert(k == localCentriods.size) - new KMeansModel(localCentriods.map(wrapper.deserialize)) + new KMeansModel(localCentriods.sortBy(_.id).map(_.point)) } } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index c3e26918f84bd..0c596e41a5c0a 100644 --- 
a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -21,10 +21,10 @@ import scala.util.Random import org.scalatest.FunSuite -import org.apache.spark.util.Utils import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.Utils class KMeansSuite extends FunSuite with MLlibTestSparkContext { From 6dd74a0f57678bdbfc6654433047e96ff1801429 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Wed, 11 Mar 2015 09:19:08 +0800 Subject: [PATCH 4/4] rewrite some functions and classes --- .../spark/mllib/clustering/KMeansModel.scala | 21 ++++++++--------- .../spark/mllib/clustering/KMeansSuite.scala | 23 +++++++------------ 2 files changed, 18 insertions(+), 26 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index f229fa1b2ff52..707da537d238f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -22,9 +22,8 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.api.java.JavaRDD -import org.apache.spark.mllib.linalg._ +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.util.{Loader, Saveable} -import org.apache.spark.mllib.util.Loader._ import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext import org.apache.spark.sql.SQLContext @@ -79,11 +78,11 @@ object KMeansModel extends Loader[KMeansModel] { KMeansModel.SaveLoadV1_0.load(sc, path) } - case class IndexedPoint(id: Int, point: Vector) + private case class Cluster(id: Int, point: Vector) - object IndexedPoint { - def apply(r: Row): IndexedPoint = { - IndexedPoint(r.getInt(0), r.getAs[Vector](1)) + private object Cluster { + def apply(r: Row): Cluster = { + Cluster(r.getInt(0), r.getAs[Vector](1)) } } @@ -102,7 +101,7 @@ object KMeansModel extends Loader[KMeansModel] { ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k))) sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) val dataRDD = sc.parallelize(model.clusterCenters.zipWithIndex).map { case (point, id) => - IndexedPoint(id, point) + Cluster(id, point) }.toDF() dataRDD.saveAsParquetFile(Loader.dataPath(path)) } @@ -110,13 +109,13 @@ object KMeansModel extends Loader[KMeansModel] { def load(sc: SparkContext, path: String): KMeansModel = { implicit val formats = DefaultFormats val sqlContext = new SQLContext(sc) - val (className, formatVersion, metadata) = loadMetadata(sc, path) + val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) assert(className == thisClassName) assert(formatVersion == thisFormatVersion) val k = (metadata \ "k").extract[Int] - val centriods = sqlContext.parquetFile(dataPath(path)) - Loader.checkSchema[IndexedPoint](centriods.schema) - val localCentriods = centriods.map(IndexedPoint.apply).collect() + val centriods = sqlContext.parquetFile(Loader.dataPath(path)) + Loader.checkSchema[Cluster](centriods.schema) + val localCentriods = centriods.map(Cluster.apply).collect() assert(k == localCentriods.size) new KMeansModel(localCentriods.sortBy(_.id).map(_.point)) } diff --git 
a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index 0c596e41a5c0a..7bf250eb5a383 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -269,7 +269,7 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext { try { model.save(sc, path) val sameModel = KMeansModel.load(sc, path) - KMeansSuite.checkEqual(model, sameModel, selector) + KMeansSuite.checkEqual(model, sameModel) } finally { Utils.deleteRecursively(tempDir) } @@ -288,22 +288,15 @@ object KMeansSuite extends FunSuite { new KMeansModel(Array.fill[Vector](k)(singlePoint)) } - def checkEqual(a: KMeansModel, b: KMeansModel, isSparse: Boolean): Unit = { + def checkEqual(a: KMeansModel, b: KMeansModel): Unit = { assert(a.k === b.k) - isSparse match { - case true => - a.clusterCenters.zip(b.clusterCenters).foreach { case (pointA, pointB) => - assert(pointA.asInstanceOf[SparseVector].size === pointB.asInstanceOf[SparseVector].size) - assert( - pointA.asInstanceOf[SparseVector].indices === pointB.asInstanceOf[SparseVector].indices) - assert( - pointA.asInstanceOf[SparseVector].values === pointB.asInstanceOf[SparseVector].values) - } + a.clusterCenters.zip(b.clusterCenters).foreach { + case (ca: SparseVector, cb: SparseVector) => + assert(ca === cb) + case (ca: DenseVector, cb: DenseVector) => + assert(ca === cb) case _ => - a.clusterCenters.zip(b.clusterCenters).foreach { case (pointA, pointB) => - assert( - pointA.asInstanceOf[DenseVector].toArray === pointB.asInstanceOf[DenseVector].toArray) - } + throw new AssertionError("checkEqual failed since the two clusters were not identical.\n") } } }
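
Usage sketch (not part of the patches above): a minimal example of how the save/load API added by this series could be exercised. It assumes an existing SparkContext `sc` and a writable output path; the toy data, the output path, and the call to the pre-existing KMeans.train helper are illustrative only.

    import org.apache.spark.mllib.clustering.{KMeans, KMeansModel}
    import org.apache.spark.mllib.linalg.Vectors

    // Toy data: two well-separated groups of points (illustrative only).
    val data = sc.parallelize(Seq(
      Vectors.dense(0.0, 0.0), Vectors.dense(0.1, 0.1),
      Vectors.dense(9.0, 9.0), Vectors.dense(9.1, 9.1)))

    // Train a model, then persist it with the new API: metadata is written as
    // JSON text and the cluster centers as Parquet under the given path
    // (see SaveLoadV1_0 above).
    val model = KMeans.train(data, 2, 10)
    model.save(sc, "target/tmp/kmeansModel")

    // Load the model back and check it is equivalent to the original,
    // mirroring KMeansSuite.checkEqual.
    val sameModel = KMeansModel.load(sc, "target/tmp/kmeansModel")
    assert(model.k == sameModel.k)
    model.clusterCenters.zip(sameModel.clusterCenters).foreach { case (a, b) =>
      assert(a.toArray.toSeq == b.toArray.toSeq)
    }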